diff --git a/alembic/versions/3f4e5a1b8c9d_create_nutrition_service_tables.py b/alembic/versions/3f4e5a1b8c9d_create_nutrition_service_tables.py new file mode 100644 index 0000000..7850473 --- /dev/null +++ b/alembic/versions/3f4e5a1b8c9d_create_nutrition_service_tables.py @@ -0,0 +1,151 @@ +"""Create nutrition service tables + +Revision ID: 3f4e5a1b8c9d +Revises: 2ede6d343f7c +Create Date: 2025-10-16 23:01:02.123456 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '3f4e5a1b8c9d' +down_revision = '49846a45b6b0' +branch_labels = None +depends_on = None + + +def upgrade(): + # Таблица продуктов питания + op.create_table( + 'food_items', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('uuid', postgresql.UUID(as_uuid=True), nullable=False), + sa.Column('fatsecret_id', sa.String(length=50), nullable=True), + sa.Column('name', sa.String(length=255), nullable=False), + sa.Column('brand', sa.String(length=255), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('food_type', sa.String(length=50), nullable=True), + sa.Column('serving_size', sa.String(length=100), nullable=True), + sa.Column('serving_weight_grams', sa.Float(), nullable=True), + sa.Column('calories', sa.Float(), nullable=True), + sa.Column('protein_grams', sa.Float(), nullable=True), + sa.Column('fat_grams', sa.Float(), nullable=True), + sa.Column('carbs_grams', sa.Float(), nullable=True), + sa.Column('fiber_grams', sa.Float(), nullable=True), + sa.Column('sugar_grams', sa.Float(), nullable=True), + sa.Column('sodium_mg', sa.Float(), nullable=True), + sa.Column('cholesterol_mg', sa.Float(), nullable=True), + sa.Column('ingredients', sa.Text(), nullable=True), + sa.Column('is_verified', sa.Boolean(), nullable=True), + sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.Column('updated_at', sa.TIMESTAMP(timezone=True), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_food_items_fatsecret_id'), 'food_items', ['fatsecret_id'], unique=True) + op.create_index(op.f('ix_food_items_name'), 'food_items', ['name'], unique=False) + op.create_index(op.f('ix_food_items_uuid'), 'food_items', ['uuid'], unique=True) + + # Таблица записей пользователя о потреблении пищи + op.create_table( + 'user_nutrition_entries', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('uuid', postgresql.UUID(as_uuid=True), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('entry_date', sa.Date(), nullable=False), + sa.Column('meal_type', sa.String(length=50), nullable=False), + sa.Column('food_item_id', sa.Integer(), nullable=True), + sa.Column('custom_food_name', sa.String(length=255), nullable=True), + sa.Column('quantity', sa.Float(), nullable=False), + sa.Column('unit', sa.String(length=50), nullable=True), + sa.Column('calories', sa.Float(), nullable=True), + sa.Column('protein_grams', sa.Float(), nullable=True), + sa.Column('fat_grams', sa.Float(), nullable=True), + sa.Column('carbs_grams', sa.Float(), nullable=True), + sa.Column('notes', sa.Text(), nullable=True), + sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.Column('updated_at', sa.TIMESTAMP(timezone=True), nullable=True), + sa.ForeignKeyConstraint(['food_item_id'], ['food_items.id'], ), + sa.PrimaryKeyConstraint('id') + ) + 
op.create_index(op.f('ix_user_nutrition_entries_entry_date'), 'user_nutrition_entries', ['entry_date'], unique=False) + op.create_index(op.f('ix_user_nutrition_entries_user_id'), 'user_nutrition_entries', ['user_id'], unique=False) + op.create_index(op.f('ix_user_nutrition_entries_uuid'), 'user_nutrition_entries', ['uuid'], unique=True) + + # Таблица для отслеживания потребления воды + op.create_table( + 'water_intake', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('uuid', postgresql.UUID(as_uuid=True), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('entry_date', sa.Date(), nullable=False), + sa.Column('amount_ml', sa.Integer(), nullable=False), + sa.Column('entry_time', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.Column('notes', sa.Text(), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_water_intake_entry_date'), 'water_intake', ['entry_date'], unique=False) + op.create_index(op.f('ix_water_intake_user_id'), 'water_intake', ['user_id'], unique=False) + op.create_index(op.f('ix_water_intake_uuid'), 'water_intake', ['uuid'], unique=True) + + # Таблица для отслеживания физической активности + op.create_table( + 'user_activity_entries', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('uuid', postgresql.UUID(as_uuid=True), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('entry_date', sa.Date(), nullable=False), + sa.Column('activity_type', sa.String(length=100), nullable=False), + sa.Column('duration_minutes', sa.Integer(), nullable=False), + sa.Column('calories_burned', sa.Float(), nullable=True), + sa.Column('distance_km', sa.Float(), nullable=True), + sa.Column('steps', sa.Integer(), nullable=True), + sa.Column('intensity', sa.String(length=20), nullable=True), + sa.Column('notes', sa.Text(), nullable=True), + sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_user_activity_entries_entry_date'), 'user_activity_entries', ['entry_date'], unique=False) + op.create_index(op.f('ix_user_activity_entries_user_id'), 'user_activity_entries', ['user_id'], unique=False) + op.create_index(op.f('ix_user_activity_entries_uuid'), 'user_activity_entries', ['uuid'], unique=True) + + # Таблица для хранения целей пользователя по питанию и активности + op.create_table( + 'nutrition_goals', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('daily_calorie_goal', sa.Integer(), nullable=True), + sa.Column('protein_goal_grams', sa.Integer(), nullable=True), + sa.Column('fat_goal_grams', sa.Integer(), nullable=True), + sa.Column('carbs_goal_grams', sa.Integer(), nullable=True), + sa.Column('water_goal_ml', sa.Integer(), nullable=True), + sa.Column('activity_goal_minutes', sa.Integer(), nullable=True), + sa.Column('weight_goal_kg', sa.Float(), nullable=True), + sa.Column('goal_type', sa.String(length=50), nullable=True), + sa.Column('created_at', sa.TIMESTAMP(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.Column('updated_at', sa.TIMESTAMP(timezone=True), nullable=True), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_nutrition_goals_user_id'), 'nutrition_goals', ['user_id'], unique=True) + + +def downgrade(): + op.drop_index(op.f('ix_nutrition_goals_user_id'), table_name='nutrition_goals') + op.drop_table('nutrition_goals') + 
op.drop_index(op.f('ix_user_activity_entries_uuid'), table_name='user_activity_entries') + op.drop_index(op.f('ix_user_activity_entries_user_id'), table_name='user_activity_entries') + op.drop_index(op.f('ix_user_activity_entries_entry_date'), table_name='user_activity_entries') + op.drop_table('user_activity_entries') + op.drop_index(op.f('ix_water_intake_uuid'), table_name='water_intake') + op.drop_index(op.f('ix_water_intake_user_id'), table_name='water_intake') + op.drop_index(op.f('ix_water_intake_entry_date'), table_name='water_intake') + op.drop_table('water_intake') + op.drop_index(op.f('ix_user_nutrition_entries_uuid'), table_name='user_nutrition_entries') + op.drop_index(op.f('ix_user_nutrition_entries_user_id'), table_name='user_nutrition_entries') + op.drop_index(op.f('ix_user_nutrition_entries_entry_date'), table_name='user_nutrition_entries') + op.drop_table('user_nutrition_entries') + op.drop_index(op.f('ix_food_items_uuid'), table_name='food_items') + op.drop_index(op.f('ix_food_items_name'), table_name='food_items') + op.drop_index(op.f('ix_food_items_fatsecret_id'), table_name='food_items') + op.drop_table('food_items') \ No newline at end of file diff --git a/alembic/versions/a2e71842cf5a_add_nutrition_service_tables.py b/alembic/versions/a2e71842cf5a_add_nutrition_service_tables.py new file mode 100644 index 0000000..957d6d9 --- /dev/null +++ b/alembic/versions/a2e71842cf5a_add_nutrition_service_tables.py @@ -0,0 +1,217 @@ +"""Add nutrition service tables + +Revision ID: a2e71842cf5a +Revises: c78a12db4567 +Create Date: 2025-10-16 10:00:00.000000 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "a2e71842cf5a" +down_revision = "c78a12db4567" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + + # Создание таблицы food_items + op.create_table( + "food_items", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("uuid", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("fatsecret_id", sa.String(length=50), nullable=True), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("brand", sa.String(length=255), nullable=True), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("food_type", sa.String(length=50), nullable=True), + sa.Column("serving_size", sa.String(length=100), nullable=True), + sa.Column("serving_weight_grams", sa.Float(), nullable=True), + sa.Column("calories", sa.Float(), nullable=True), + sa.Column("protein_grams", sa.Float(), nullable=True), + sa.Column("fat_grams", sa.Float(), nullable=True), + sa.Column("carbs_grams", sa.Float(), nullable=True), + sa.Column("fiber_grams", sa.Float(), nullable=True), + sa.Column("sugar_grams", sa.Float(), nullable=True), + sa.Column("sodium_mg", sa.Float(), nullable=True), + sa.Column("cholesterol_mg", sa.Float(), nullable=True), + sa.Column("ingredients", sa.Text(), nullable=True), + sa.Column("is_verified", sa.Boolean(), nullable=True), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column("updated_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_food_items_fatsecret_id"), "food_items", ["fatsecret_id"], unique=True + ) + op.create_index(op.f("ix_food_items_uuid"), "food_items", ["uuid"], unique=True) + + # Создание таблицы user_nutrition_entries + op.create_table( + "user_nutrition_entries", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("uuid", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("entry_date", sa.Date(), nullable=False), + sa.Column("meal_type", sa.String(length=50), nullable=False), + sa.Column("food_item_id", sa.Integer(), nullable=True), + sa.Column("custom_food_name", sa.String(length=255), nullable=True), + sa.Column("quantity", sa.Float(), nullable=False), + sa.Column("unit", sa.String(length=50), nullable=True), + sa.Column("calories", sa.Float(), nullable=True), + sa.Column("protein_grams", sa.Float(), nullable=True), + sa.Column("fat_grams", sa.Float(), nullable=True), + sa.Column("carbs_grams", sa.Float(), nullable=True), + sa.Column("notes", sa.Text(), nullable=True), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column("updated_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.ForeignKeyConstraint(["food_item_id"], ["food_items.id"],), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_user_nutrition_entries_entry_date"), + "user_nutrition_entries", + ["entry_date"], + unique=False, + ) + op.create_index( + op.f("ix_user_nutrition_entries_user_id"), + "user_nutrition_entries", + ["user_id"], + unique=False, + ) + op.create_index( + op.f("ix_user_nutrition_entries_uuid"), + "user_nutrition_entries", + ["uuid"], + unique=True + ) + + # Создание таблицы water_intake + op.create_table( + "water_intake", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("uuid", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("entry_date", sa.Date(), nullable=False), + sa.Column("amount_ml", sa.Integer(), nullable=False), + sa.Column( + "entry_time", + sa.TIMESTAMP(timezone=True), 
+ server_default=sa.text("now()"), + nullable=False, + ), + sa.Column("notes", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_water_intake_entry_date"), "water_intake", ["entry_date"], unique=False + ) + op.create_index( + op.f("ix_water_intake_user_id"), "water_intake", ["user_id"], unique=False + ) + op.create_index(op.f("ix_water_intake_uuid"), "water_intake", ["uuid"], unique=True) + + # Создание таблицы user_activity_entries + op.create_table( + "user_activity_entries", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("uuid", postgresql.UUID(as_uuid=True), nullable=True), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("entry_date", sa.Date(), nullable=False), + sa.Column("activity_type", sa.String(length=100), nullable=False), + sa.Column("duration_minutes", sa.Integer(), nullable=False), + sa.Column("calories_burned", sa.Float(), nullable=True), + sa.Column("distance_km", sa.Float(), nullable=True), + sa.Column("steps", sa.Integer(), nullable=True), + sa.Column("intensity", sa.String(length=20), nullable=True), + sa.Column("notes", sa.Text(), nullable=True), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_user_activity_entries_entry_date"), + "user_activity_entries", + ["entry_date"], + unique=False, + ) + op.create_index( + op.f("ix_user_activity_entries_user_id"), + "user_activity_entries", + ["user_id"], + unique=False, + ) + op.create_index( + op.f("ix_user_activity_entries_uuid"), + "user_activity_entries", + ["uuid"], + unique=True + ) + + # Создание таблицы nutrition_goals + op.create_table( + "nutrition_goals", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("daily_calorie_goal", sa.Integer(), nullable=True), + sa.Column("protein_goal_grams", sa.Integer(), nullable=True), + sa.Column("fat_goal_grams", sa.Integer(), nullable=True), + sa.Column("carbs_goal_grams", sa.Integer(), nullable=True), + sa.Column("water_goal_ml", sa.Integer(), nullable=True), + sa.Column("activity_goal_minutes", sa.Integer(), nullable=True), + sa.Column("weight_goal_kg", sa.Float(), nullable=True), + sa.Column("goal_type", sa.String(length=50), nullable=True), + sa.Column( + "created_at", + sa.TIMESTAMP(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.Column("updated_at", sa.TIMESTAMP(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_nutrition_goals_user_id"), "nutrition_goals", ["user_id"], unique=True + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_index(op.f("ix_nutrition_goals_user_id"), table_name="nutrition_goals") + op.drop_table("nutrition_goals") + op.drop_index(op.f("ix_user_activity_entries_uuid"), table_name="user_activity_entries") + op.drop_index(op.f("ix_user_activity_entries_user_id"), table_name="user_activity_entries") + op.drop_index(op.f("ix_user_activity_entries_entry_date"), table_name="user_activity_entries") + op.drop_table("user_activity_entries") + op.drop_index(op.f("ix_water_intake_uuid"), table_name="water_intake") + op.drop_index(op.f("ix_water_intake_user_id"), table_name="water_intake") + op.drop_index(op.f("ix_water_intake_entry_date"), table_name="water_intake") + op.drop_table("water_intake") + op.drop_index(op.f("ix_user_nutrition_entries_uuid"), table_name="user_nutrition_entries") + op.drop_index(op.f("ix_user_nutrition_entries_user_id"), table_name="user_nutrition_entries") + op.drop_index(op.f("ix_user_nutrition_entries_entry_date"), table_name="user_nutrition_entries") + op.drop_table("user_nutrition_entries") + op.drop_index(op.f("ix_food_items_uuid"), table_name="food_items") + op.drop_index(op.f("ix_food_items_fatsecret_id"), table_name="food_items") + op.drop_table("food_items") + # ### end Alembic commands ### \ No newline at end of file diff --git a/docker-compose.prod.yml b/docker-compose.prod.yml index e7bfa6d..87a3e77 100644 --- a/docker-compose.prod.yml +++ b/docker-compose.prod.yml @@ -25,6 +25,7 @@ services: - LOCATION_SERVICE_URL=http://location-service-1:8003,http://location-service-2:8003 - CALENDAR_SERVICE_URL=http://calendar-service-1:8004,http://calendar-service-2:8004 - NOTIFICATION_SERVICE_URL=http://notification-service-1:8005,http://notification-service-2:8005 + - NUTRITION_SERVICE_URL=http://nutrition-service-1:8006,http://nutrition-service-2:8006 - REDIS_URL=redis://redis-cluster:6379/0 depends_on: - redis-cluster @@ -47,6 +48,7 @@ services: - LOCATION_SERVICE_URL=http://location-service-1:8003,http://location-service-2:8003 - CALENDAR_SERVICE_URL=http://calendar-service-1:8004,http://calendar-service-2:8004 - NOTIFICATION_SERVICE_URL=http://notification-service-1:8005,http://notification-service-2:8005 + - NUTRITION_SERVICE_URL=http://nutrition-service-1:8006,http://nutrition-service-2:8006 - REDIS_URL=redis://redis-cluster:6379/0 depends_on: - redis-cluster @@ -286,4 +288,48 @@ volumes: kafka_3_data: zookeeper_data: prometheus_data: - grafana_data: \ No newline at end of file + grafana_data: + # Nutrition Service Cluster + nutrition-service-1: + image: women-safety/nutrition-service:${TAG:-latest} + environment: + - NODE_ID=1 + - DATABASE_URL=postgresql://postgres:${POSTGRES_PASSWORD}@postgres-primary:5432/women_safety_prod + - DATABASE_REPLICA_URL=postgresql://postgres:${POSTGRES_PASSWORD}@postgres-replica:5432/women_safety_prod + - REDIS_URL=redis://redis-cluster:6379/5 + - FATSECRET_CLIENT_ID=${FATSECRET_CLIENT_ID} + - FATSECRET_CLIENT_SECRET=${FATSECRET_CLIENT_SECRET} + depends_on: + - postgres-primary + - redis-cluster + restart: always + deploy: + resources: + limits: + cpus: '1.0' + memory: 2G + reservations: + cpus: '0.5' + memory: 512M + + nutrition-service-2: + image: women-safety/nutrition-service:${TAG:-latest} + environment: + - NODE_ID=2 + - DATABASE_URL=postgresql://postgres:${POSTGRES_PASSWORD}@postgres-primary:5432/women_safety_prod + - DATABASE_REPLICA_URL=postgresql://postgres:${POSTGRES_PASSWORD}@postgres-replica:5432/women_safety_prod + - REDIS_URL=redis://redis-cluster:6379/5 + - FATSECRET_CLIENT_ID=${FATSECRET_CLIENT_ID} + - 
FATSECRET_CLIENT_SECRET=${FATSECRET_CLIENT_SECRET} + depends_on: + - postgres-primary + - redis-cluster + restart: always + deploy: + resources: + limits: + cpus: '1.0' + memory: 2G + reservations: + cpus: '0.5' + memory: 512M \ No newline at end of file diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 81ea97c..023f3e0 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -29,12 +29,14 @@ services: - LOCATION_SERVICE_URL=http://location-service:8003 - CALENDAR_SERVICE_URL=http://calendar-service:8004 - NOTIFICATION_SERVICE_URL=http://notification-service:8005 + - NUTRITION_SERVICE_URL=http://nutrition-service:8006 depends_on: - user-service - emergency-service - location-service - calendar-service - notification-service + - nutrition-service user-service: image: women-safety/user-service:latest @@ -96,5 +98,18 @@ services: - postgres - redis + nutrition-service: + image: women-safety/nutrition-service:latest + ports: + - "8006:8006" + environment: + - DATABASE_URL=postgresql://postgres:postgres@postgres:5432/women_safety_test + - REDIS_URL=redis://redis:6379/5 + - FATSECRET_CLIENT_ID=test-fatsecret-client-id + - FATSECRET_CLIENT_SECRET=test-fatsecret-client-secret + depends_on: + - postgres + - redis + volumes: postgres_test_data: \ No newline at end of file diff --git a/docs/API.md b/docs/API.md index 0b4b3fe..a5aadd2 100644 --- a/docs/API.md +++ b/docs/API.md @@ -6,6 +6,26 @@ The Women's Safety App provides a comprehensive API for managing user profiles, **Base URL:** `http://localhost:8000` (API Gateway) +## Swagger Documentation + +Интерактивная документация API доступна через Swagger UI по следующим URL: + +- API Gateway: `http://localhost:8000/docs` +- User Service: `http://localhost:8001/docs` +- Emergency Service: `http://localhost:8002/docs` +- Location Service: `http://localhost:8003/docs` +- Calendar Service: `http://localhost:8004/docs` +- Notification Service: `http://localhost:8005/docs` +- Nutrition Service: `http://localhost:8006/docs` + +Документация в формате ReDoc доступна по адресам: + +- API Gateway: `http://localhost:8000/redoc` +- User Service: `http://localhost:8001/redoc` +- (и т.д. для остальных сервисов) + +> **Примечание**: Swagger-документация для каждого сервиса доступна только при запущенном соответствующем сервисе. Если сервис не запущен, страница документации будет недоступна. + ## Authentication All endpoints except registration and login require JWT authentication. @@ -15,6 +35,29 @@ All endpoints except registration and login require JWT authentication. Authorization: Bearer ``` +### Testing with Swagger UI + +Для тестирования API через Swagger UI: + +1. Запустите необходимые сервисы: + ```bash + ./start_services.sh + ``` + +2. Откройте Swagger UI в браузере: + ``` + http://localhost:8000/docs + ``` + +3. Получите JWT-токен через эндпоинты `/api/v1/auth/login` или `/api/v1/auth/register` + +4. Авторизуйтесь в Swagger UI: + - Нажмите на кнопку "Authorize" в правом верхнем углу + - Введите полученный JWT-токен в формате: `Bearer ` + - Нажмите "Authorize" + +5. 
Теперь вы можете тестировать все защищенные эндпоинты + ## API Endpoints ### 🔐 Authentication @@ -247,6 +290,109 @@ Authorization: Bearer } ``` +### 🍎 Nutrition Services + +#### Search Food Items +```http +GET /api/v1/nutrition/foods?query=apple +Authorization: Bearer +``` + +**Response:** +```json +{ + "results": [ + { + "food_id": "123456", + "name": "Apple, raw, with skin", + "brand": "", + "calories": 52, + "serving_size": "100g", + "nutrients": { + "carbohydrates": 13.8, + "protein": 0.3, + "fat": 0.2, + "fiber": 2.4 + } + }, + { + "food_id": "789012", + "name": "Apple juice, unsweetened", + "brand": "Example Brand", + "calories": 46, + "serving_size": "100ml", + "nutrients": { + "carbohydrates": 11.2, + "protein": 0.1, + "fat": 0.1, + "fiber": 0.2 + } + } + ] +} +``` + +#### Add Nutrition Entry +```http +POST /api/v1/nutrition/entries +Authorization: Bearer +``` + +**Body:** +```json +{ + "food_id": "123456", + "date": "2025-10-16", + "meal_type": "lunch", + "quantity": 1.0, + "serving_size": "100g", + "notes": "Red apple" +} +``` + +#### Get Daily Nutrition Summary +```http +GET /api/v1/nutrition/daily-summary?date=2025-10-16 +Authorization: Bearer +``` + +**Response:** +```json +{ + "date": "2025-10-16", + "total_calories": 1578, + "total_carbohydrates": 175.3, + "total_proteins": 78.2, + "total_fats": 52.8, + "total_water": 1200, + "entries": [ + { + "id": 123, + "food_name": "Apple, raw, with skin", + "meal_type": "lunch", + "calories": 52, + "quantity": 1.0, + "serving_size": "100g" + } + ] +} +``` + +#### Track Water Intake +```http +POST /api/v1/nutrition/water +Authorization: Bearer +``` + +**Body:** +```json +{ + "date": "2025-10-16", + "amount_ml": 250, + "time": "12:30:00" +} +``` + ### 📊 System Status #### Check Service Health diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 7536a0c..1ce7753 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -25,16 +25,16 @@ This document describes the microservices architecture of the Women's Safety App │ Request Routing) │ └───────────────────────────┘ │ - ┌─────────────┬──────────────┼──────────────┬─────────────┐ - │ │ │ │ │ -┌─────────┐ ┌─────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ -│ User │ │Emergency│ │ Location │ │ Calendar │ │Notification │ -│Service │ │Service │ │ Service │ │ Service │ │ Service │ -│:8001 │ │:8002 │ │ :8003 │ │ :8004 │ │ :8005 │ -└─────────┘ └─────────┘ └─────────────┘ └─────────────┘ └─────────────┘ - │ │ │ │ │ - └─────────────┼──────────────┼──────────────┼─────────────┘ - │ │ │ + ┌─────────────┬──────────────┼──────────────┬─────────────┬─────────────┐ + │ │ │ │ │ │ +┌─────────┐ ┌─────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ +│ User │ │Emergency│ │ Location │ │ Calendar │ │Notification │ │ Nutrition │ +│Service │ │Service │ │ Service │ │ Service │ │ Service │ │ Service │ +│:8001 │ │:8002 │ │ :8003 │ │ :8004 │ │ :8005 │ │ :8006 │ +└─────────┘ └─────────┘ └─────────────┘ └─────────────┘ └─────────────┘ └─────────────┘ + │ │ │ │ │ │ + └─────────────┼──────────────┼──────────────┼─────────────┼─────────────┘ + │ │ │ │ ┌────────────────────────────────────────────────┐ │ Message Bus │ │ (Kafka/RabbitMQ) │ diff --git a/docs/FATSECRET_API.md b/docs/FATSECRET_API.md new file mode 100644 index 0000000..719dcc2 --- /dev/null +++ b/docs/FATSECRET_API.md @@ -0,0 +1,228 @@ +# Работа с FatSecret API в проекте + +Этот документ описывает, как используется API FatSecret для получения данных о продуктах питания и их пищевой ценности в нашем проекте. 
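In the codebase this API is wrapped by `services/nutrition_service/fatsecret_client.py` (added later in this diff). The sketch below shows how that wrapper is typically driven; the class and method names come from the module itself, while the standalone `asyncio` entry point and the printed fields are illustrative only.

```python
import asyncio

from services.nutrition_service.fatsecret_client import FatSecretClient


async def demo() -> None:
    # The client picks up the FATSECRET_* keys from shared.config.settings.
    client = FatSecretClient()

    # foods.search returns the raw FatSecret JSON payload.
    results = await client.search_foods("apple", page_number=0, max_results=5)
    food_list = results["foods"]["food"]
    if isinstance(food_list, dict):  # a single hit comes back as a dict, not a list
        food_list = [food_list]

    # food.get.v2 for the first hit, normalized into the food_items column layout.
    details = await client.get_food_details(food_list[0]["food_id"])
    parsed = await client.parse_food_data(details)
    print(parsed["name"], parsed["calories"], parsed["serving_size"])


asyncio.run(demo())
```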
+ +## Настройка API + +### Ключи API +Для работы с FatSecret API необходимы следующие ключи: +- `FATSECRET_CLIENT_ID` - ID клиента +- `FATSECRET_CLIENT_SECRET` - секрет клиента +- `FATSECRET_CUSTOMER_KEY` - ключ пользователя (используется как альтернатива CLIENT_ID) + +Эти ключи хранятся в `.env` файле проекта и загружаются в конфигурацию через модуль `shared/config.py`. + +### Методы аутентификации + +FatSecret API поддерживает два метода аутентификации: +1. **OAuth 2.0** - требует прокси-сервер с белым списком IP для запроса токенов (не работает в нашем тестовом окружении) +2. **OAuth 1.0** - работает напрямую и подписывает каждый запрос (рекомендуется использовать) + +## Примеры использования API + +### Поиск продуктов питания +```python +def search_food(query, max_results=5): + """Поиск продуктов по названию""" + # URL для API + url = "https://platform.fatsecret.com/rest/server.api" + + # Параметры запроса + params = { + 'method': 'foods.search', + 'search_expression': query, + 'max_results': max_results, + 'format': 'json' + } + + # Подписываем запрос с помощью OAuth 1.0 + oauth_params = generate_oauth_params("GET", url, params) + + # Отправляем запрос + response = requests.get(url, params=oauth_params) + + if response.status_code == 200: + return response.json() + return None +``` + +### Получение информации о продукте +```python +def get_food_details(food_id): + """Получение подробной информации о продукте по ID""" + # URL для API + url = "https://platform.fatsecret.com/rest/server.api" + + # Параметры запроса + params = { + 'method': 'food.get', + 'food_id': food_id, + 'format': 'json' + } + + # Подписываем запрос с помощью OAuth 1.0 + oauth_params = generate_oauth_params("GET", url, params) + + # Отправляем запрос + response = requests.get(url, params=oauth_params) + + if response.status_code == 200: + return response.json() + return None +``` + +## Генерация OAuth 1.0 подписи + +```python +def generate_oauth_params(http_method, url, params): + """Создание и подписание OAuth 1.0 параметров""" + # Текущее время в секундах + timestamp = str(int(time.time())) + # Случайная строка для nonce + nonce = ''.join([str(random.randint(0, 9)) for _ in range(8)]) + + # Базовый набор параметров OAuth + oauth_params = { + 'oauth_consumer_key': FATSECRET_KEY, + 'oauth_nonce': nonce, + 'oauth_signature_method': 'HMAC-SHA1', + 'oauth_timestamp': timestamp, + 'oauth_version': '1.0' + } + + # Объединяем с параметрами запроса + all_params = {**params, **oauth_params} + + # Сортируем параметры по ключу + sorted_params = sorted(all_params.items()) + + # Создаем строку параметров для подписи + param_string = "&".join([f"{urllib.parse.quote(str(k), safe='')}={urllib.parse.quote(str(v), safe='')}" + for k, v in sorted_params]) + + # Создаем строку для подписи + signature_base = f"{http_method}&{urllib.parse.quote(url, safe='')}&{urllib.parse.quote(param_string, safe='')}" + + # Создаем ключ для подписи + signing_key = f"{urllib.parse.quote(str(FATSECRET_SECRET), safe='')}&" + + # Создаем HMAC-SHA1 подпись + signature = base64.b64encode( + hmac.new( + signing_key.encode(), + signature_base.encode(), + hashlib.sha1 + ).digest() + ).decode() + + # Добавляем подпись к параметрам OAuth + all_params['oauth_signature'] = signature + + return all_params +``` + +## Формат ответа API + +### Поиск продуктов +Структура ответа от метода `foods.search`: +```json +{ + "foods": { + "max_results": "5", + "total_results": "1000", + "page_number": "0", + "food": [ + { + "food_id": "35718", + "food_name": "Apples", + 
"food_description": "Per 100g - Calories: 52kcal | Fat: 0.17g | Carbs: 13.81g | Protein: 0.26g", + "food_url": "https://www.fatsecret.com/calories-nutrition/usda/apples?portionid=34128" + }, + // ...другие продукты + ] + } +} +``` + +### Информация о продукте +Структура ответа от метода `food.get`: +```json +{ + "food": { + "food_id": "35718", + "food_name": "Apples", + "food_type": "Generic", + "servings": { + "serving": [ + { + "serving_id": "34128", + "serving_description": "100g", + "calories": "52", + "carbohydrate": "13.81", + "protein": "0.26", + "fat": "0.17", + // другие пищевые вещества + }, + // другие варианты порций + ] + } + } +} +``` + +## Ограничения API + +1. Функциональность поиска на русском языке может быть недоступна в базовой версии API +2. Ограничение на количество запросов в месяц (зависит от уровня доступа) +3. OAuth 2.0 требует прокси-сервера, настроенного на определенные IP-адреса + +## Тестирование API + +Для тестирования API можно использовать готовый тестовый скрипт, который находится в корне проекта: + +```bash +# Активировать виртуальное окружение +source venv/bin/activate + +# Запустить тестовый скрипт +python test_fatsecret_api_oauth1.py +``` + +Вы также можете использовать этот скрипт как шаблон для написания собственных тестов. Примеры использования: + +```python +# Импортировать функции из тестового скрипта +from test_fatsecret_api_oauth1 import search_food, process_search_results + +# Поиск продуктов на английском +result = search_food("chicken breast") +process_search_results(result) + +# Поиск продуктов на русском +result = search_food("яблоко", locale="ru_RU") +process_search_results(result) +``` + +### Примеры команд для тестирования через cURL + +Для тестирования API через cURL можно использовать следующие команды: + +```bash +# Поиск продуктов через nutrition service (требуется авторизация) +curl -X POST http://localhost:8006/api/v1/nutrition/search \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer YOUR_JWT_TOKEN" \ + -d '{"query": "apple", "max_results": 5}' + +# Прямое тестирование FatSecret API (OAuth 1.0) +curl -X GET "https://platform.fatsecret.com/rest/server.api?method=foods.search&search_expression=apple&max_results=5&format=json&oauth_consumer_key=YOUR_CONSUMER_KEY&oauth_signature_method=HMAC-SHA1&oauth_timestamp=TIMESTAMP&oauth_nonce=NONCE&oauth_version=1.0&oauth_signature=YOUR_SIGNATURE" +``` + +> **Примечание:** Для выполнения прямого запроса к FatSecret API через cURL необходимо сгенерировать правильную OAuth 1.0 подпись. Рекомендуется использовать скрипт `test_fatsecret_api_oauth1.py` вместо этого. + +## Рекомендации по использованию + +1. Использовать OAuth 1.0 для аутентификации, так как он работает без дополнительной инфраструктуры +2. Кэшировать результаты запросов, чтобы снизить нагрузку на API +3. Обрабатывать возможные ошибки API и предоставлять пользователям понятные сообщения +4. Использовать английские запросы для поиска, так как база данных в основном на английском языке \ No newline at end of file diff --git a/docs/NUTRITION_API.md b/docs/NUTRITION_API.md new file mode 100644 index 0000000..168ff5d --- /dev/null +++ b/docs/NUTRITION_API.md @@ -0,0 +1,593 @@ +# API Сервиса Питания (Nutrition Service) + +Сервис питания предоставляет API для работы с данными о питании, включая поиск продуктов питания, добавление продуктов в дневник питания, отслеживание потребления воды и физической активности. 
+ +## Основные функции + +- Поиск продуктов питания через FatSecret API +- Отслеживание потребления пищи и питательных веществ +- Учет потребления воды +- Отслеживание физической активности +- Установка и отслеживание целей по питанию и активности + +## Базовый URL + +``` +http://localhost:8006/api/v1/nutrition/ +``` + +## Swagger-документация + +Интерактивная документация API доступна через Swagger UI по следующим URL: + +``` +http://localhost:8006/docs +``` + +или через ReDoc: + +``` +http://localhost:8006/redoc +``` + +> **Примечание**: Swagger-документация доступна только при запущенном сервисе питания. Если сервис не запущен, страница документации будет недоступна. + +### Использование Swagger UI + +1. Откройте URL `http://localhost:8006/docs` в браузере +2. Авторизуйтесь с помощью кнопки "Authorize" в верхней части страницы: + - Введите ваш JWT токен в формате: `Bearer ` + - Нажмите "Authorize" +3. Теперь вы можете тестировать все эндпоинты API непосредственно через Swagger UI: + - Выберите нужный эндпоинт + - Заполните параметры запроса + - Нажмите "Execute" для отправки запроса + +![Swagger UI Example](https://swagger.io/swagger/media/images/swagger-ui-example.png) + +## Эндпоинты + +### Поиск продуктов + +#### Поиск по названию + +```http +POST /api/v1/nutrition/search +``` + +Параметры запроса: +```json +{ + "query": "яблоко", + "page_number": 0, + "max_results": 10 +} +``` + +Ответ: +```json +[ + { + "id": 1, + "uuid": "123e4567-e89b-12d3-a456-426614174000", + "fatsecret_id": "35718", + "name": "Apple", + "brand": null, + "description": "A common fruit", + "food_type": "Generic", + "serving_size": "100g", + "serving_weight_grams": 100.0, + "calories": 52.0, + "protein_grams": 0.26, + "fat_grams": 0.17, + "carbs_grams": 13.81, + "fiber_grams": 2.4, + "sugar_grams": 10.39, + "sodium_mg": 1.0, + "cholesterol_mg": 0.0, + "ingredients": null, + "is_verified": true, + "created_at": "2025-10-16T23:10:00" + } +] +``` + +#### Получение информации о продукте + +```http +GET /api/v1/nutrition/food/{food_id} +``` + +Ответ: +```json +{ + "id": 1, + "uuid": "123e4567-e89b-12d3-a456-426614174000", + "fatsecret_id": "35718", + "name": "Apple", + "brand": null, + "description": "A common fruit", + "food_type": "Generic", + "serving_size": "100g", + "serving_weight_grams": 100.0, + "calories": 52.0, + "protein_grams": 0.26, + "fat_grams": 0.17, + "carbs_grams": 13.81, + "fiber_grams": 2.4, + "sugar_grams": 10.39, + "sodium_mg": 1.0, + "cholesterol_mg": 0.0, + "ingredients": null, + "is_verified": true, + "created_at": "2025-10-16T23:10:00" +} +``` + +### Дневник питания + +#### Добавление записи в дневник питания + +```http +POST /api/v1/nutrition/diary +``` + +Параметры запроса: +```json +{ + "food_item_id": 1, + "entry_date": "2025-10-16", + "meal_type": "breakfast", + "quantity": 1.5, + "unit": "piece", + "notes": "Morning apple" +} +``` + +Ответ: +```json +{ + "id": 1, + "uuid": "123e4567-e89b-12d3-a456-426614174000", + "user_id": 42, + "entry_date": "2025-10-16", + "meal_type": "breakfast", + "food_item_id": 1, + "custom_food_name": null, + "quantity": 1.5, + "unit": "piece", + "calories": 78.0, + "protein_grams": 0.39, + "fat_grams": 0.255, + "carbs_grams": 20.715, + "notes": "Morning apple", + "created_at": "2025-10-16T23:15:00" +} +``` + +#### Получение записей дневника за день + +```http +GET /api/v1/nutrition/diary?date=2025-10-16 +``` + +Ответ: +```json +[ + { + "id": 1, + "uuid": "123e4567-e89b-12d3-a456-426614174000", + "user_id": 42, + "entry_date": "2025-10-16", + "meal_type": 
"breakfast", + "food_item_id": 1, + "custom_food_name": null, + "quantity": 1.5, + "unit": "piece", + "calories": 78.0, + "protein_grams": 0.39, + "fat_grams": 0.255, + "carbs_grams": 20.715, + "notes": "Morning apple", + "created_at": "2025-10-16T23:15:00" + } +] +``` + +#### Получение сводки за день + +```http +GET /api/v1/nutrition/summary?date=2025-10-16 +``` + +Ответ: +```json +{ + "date": "2025-10-16", + "total_calories": 2150.5, + "total_protein": 85.2, + "total_fat": 65.4, + "total_carbs": 275.3, + "water_consumed_ml": 1500, + "activity_minutes": 45, + "calories_burned": 350, + "entries_by_meal": { + "breakfast": [ + { + "id": 1, + "food_name": "Apple", + "quantity": 1.5, + "unit": "piece", + "calories": 78.0 + } + ], + "lunch": [...], + "dinner": [...], + "snack": [...] + } +} +``` + +### Потребление воды + +#### Добавление записи о потреблении воды + +```http +POST /api/v1/nutrition/water +``` + +Параметры запроса: +```json +{ + "amount_ml": 250, + "entry_date": "2025-10-16", + "notes": "Morning glass" +} +``` + +Ответ: +```json +{ + "id": 1, + "uuid": "123e4567-e89b-12d3-a456-426614174000", + "user_id": 42, + "entry_date": "2025-10-16", + "amount_ml": 250, + "entry_time": "2025-10-16T08:30:00", + "notes": "Morning glass" +} +``` + +#### Получение записей о потреблении воды за день + +```http +GET /api/v1/nutrition/water?date=2025-10-16 +``` + +Ответ: +```json +[ + { + "id": 1, + "uuid": "123e4567-e89b-12d3-a456-426614174000", + "user_id": 42, + "entry_date": "2025-10-16", + "amount_ml": 250, + "entry_time": "2025-10-16T08:30:00", + "notes": "Morning glass" + }, + { + "id": 2, + "uuid": "223e4567-e89b-12d3-a456-426614174001", + "user_id": 42, + "entry_date": "2025-10-16", + "amount_ml": 500, + "entry_time": "2025-10-16T12:15:00", + "notes": "Lunch" + } +] +``` + +### Физическая активность + +#### Добавление записи о физической активности + +```http +POST /api/v1/nutrition/activity +``` + +Параметры запроса: +```json +{ + "entry_date": "2025-10-16", + "activity_type": "running", + "duration_minutes": 30, + "distance_km": 5.2, + "intensity": "medium", + "notes": "Morning run" +} +``` + +Ответ: +```json +{ + "id": 1, + "uuid": "123e4567-e89b-12d3-a456-426614174000", + "user_id": 42, + "entry_date": "2025-10-16", + "activity_type": "running", + "duration_minutes": 30, + "calories_burned": 300.5, + "distance_km": 5.2, + "steps": null, + "intensity": "medium", + "notes": "Morning run", + "created_at": "2025-10-16T09:00:00" +} +``` + +#### Получение записей об активности за день + +```http +GET /api/v1/nutrition/activity?date=2025-10-16 +``` + +Ответ: +```json +[ + { + "id": 1, + "uuid": "123e4567-e89b-12d3-a456-426614174000", + "user_id": 42, + "entry_date": "2025-10-16", + "activity_type": "running", + "duration_minutes": 30, + "calories_burned": 300.5, + "distance_km": 5.2, + "steps": null, + "intensity": "medium", + "notes": "Morning run", + "created_at": "2025-10-16T09:00:00" + } +] +``` + +### Цели по питанию и активности + +#### Установка целей + +```http +POST /api/v1/nutrition/goals +``` + +Параметры запроса: +```json +{ + "daily_calorie_goal": 2000, + "protein_goal_grams": 100, + "fat_goal_grams": 65, + "carbs_goal_grams": 250, + "water_goal_ml": 2500, + "activity_goal_minutes": 45, + "weight_goal_kg": 75.5, + "goal_type": "lose_weight" +} +``` + +Ответ: +```json +{ + "id": 1, + "user_id": 42, + "daily_calorie_goal": 2000, + "protein_goal_grams": 100, + "fat_goal_grams": 65, + "carbs_goal_grams": 250, + "water_goal_ml": 2500, + "activity_goal_minutes": 45, + "weight_goal_kg": 
75.5, + "goal_type": "lose_weight", + "created_at": "2025-10-16T10:00:00", + "updated_at": "2025-10-16T10:00:00" +} +``` + +#### Получение текущих целей + +```http +GET /api/v1/nutrition/goals +``` + +Ответ: +```json +{ + "id": 1, + "user_id": 42, + "daily_calorie_goal": 2000, + "protein_goal_grams": 100, + "fat_goal_grams": 65, + "carbs_goal_grams": 250, + "water_goal_ml": 2500, + "activity_goal_minutes": 45, + "weight_goal_kg": 75.5, + "goal_type": "lose_weight", + "created_at": "2025-10-16T10:00:00", + "updated_at": "2025-10-16T10:00:00" +} +``` + +## Коды ошибок + +| Код | Описание | +|-----|----------| +| 400 | Некорректный запрос | +| 401 | Не авторизован | +| 403 | Доступ запрещен | +| 404 | Ресурс не найден | +| 500 | Внутренняя ошибка сервера | + +## Аутентификация + +Все запросы к API требуют JWT-токен в заголовке Authorization: + +``` +Authorization: Bearer +``` + +Токен можно получить через сервис авторизации (User Service) по эндпоинту `/api/v1/auth/login`. + +## Интеграции + +Сервис питания интегрирован с API FatSecret для получения данных о продуктах питания и их пищевой ценности. Работа с API FatSecret осуществляется через OAuth 1.0 аутентификацию с использованием ключей, указанных в конфигурации приложения. + +## Тестирование API + +### Тестирование через Swagger UI + +Самый простой способ протестировать API - использовать встроенный интерфейс Swagger UI: + +1. Убедитесь, что сервис питания запущен: +```bash +# Запуск всех сервисов +./start_services.sh +``` + +2. Откройте в браузере URL: `http://localhost:8006/docs` + +3. Авторизуйтесь: + - Нажмите на кнопку "Authorize" в правом верхнем углу + - Введите ваш JWT токен в формате `Bearer ` + - Нажмите "Authorize" + +4. Теперь вы можете интерактивно тестировать все эндпоинты: + - Выберите нужный эндпоинт + - Заполните параметры запроса + - Нажмите "Execute" + - Просмотрите результат запроса и код ответа + +### Настройка и запуск через CLI + +1. Убедитесь, что все необходимые сервисы запущены: +```bash +# Запуск всех сервисов +./start_services.sh +``` + +2. Получите токен аутентификации: +```bash +# Регистрация нового пользователя +curl -X POST http://localhost:8001/api/v1/auth/register -H "Content-Type: application/json" -d '{ + "email": "test_user@example.com", + "username": "test_user", + "password": "Test123!", + "first_name": "Test", + "last_name": "User", + "phone": "+79991234567" +}' | jq + +# Вход и получение токена +curl -X POST http://localhost:8001/api/v1/auth/login -H "Content-Type: application/json" -d '{ + "username": "test_user", + "password": "Test123!" +}' | jq +``` + +3. 
Сохраните полученный токен в переменную для дальнейшего использования: +```bash +export TOKEN="ваш_полученный_jwt_токен" +``` + +### Примеры запросов + +#### Поиск продуктов + +```bash +# Поиск продуктов по названию +curl -X POST http://localhost:8006/api/v1/nutrition/search \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $TOKEN" \ + -d '{ + "query": "apple", + "max_results": 5 + }' | jq +``` + +#### Работа с дневником питания + +```bash +# Добавление записи в дневник питания +curl -X POST http://localhost:8006/api/v1/nutrition/diary \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $TOKEN" \ + -d '{ + "food_item_id": 1, + "entry_date": "2025-10-16", + "meal_type": "breakfast", + "quantity": 1.5, + "unit": "piece", + "notes": "Morning apple" + }' | jq + +# Получение дневника за день +curl -X GET http://localhost:8006/api/v1/nutrition/diary?date=2025-10-16 \ + -H "Authorization: Bearer $TOKEN" | jq +``` + +#### Работа с водой + +```bash +# Добавление записи о потреблении воды +curl -X POST http://localhost:8006/api/v1/nutrition/water \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $TOKEN" \ + -d '{ + "amount_ml": 250, + "entry_date": "2025-10-16", + "notes": "Morning glass" + }' | jq + +# Получение записей о потреблении воды за день +curl -X GET http://localhost:8006/api/v1/nutrition/water?date=2025-10-16 \ + -H "Authorization: Bearer $TOKEN" | jq +``` + +#### Работа с активностью + +```bash +# Добавление записи о физической активности +curl -X POST http://localhost:8006/api/v1/nutrition/activity \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $TOKEN" \ + -d '{ + "entry_date": "2025-10-16", + "activity_type": "running", + "duration_minutes": 30, + "distance_km": 5.2, + "intensity": "medium", + "notes": "Morning run" + }' | jq + +# Получение записей об активности за день +curl -X GET http://localhost:8006/api/v1/nutrition/activity?date=2025-10-16 \ + -H "Authorization: Bearer $TOKEN" | jq +``` + +### Автоматизированное тестирование + +В папке `tests` есть скрипты для автоматизированного тестирования API: + +```bash +# Запуск всех тестов для nutrition service +cd tests +./test_nutrition_service.sh + +# Запуск тестов через Python +python test_nutrition_api.py +``` + +Для непосредственного тестирования FatSecret API можно использовать скрипт в корне проекта: + +```bash +# Тестирование FatSecret API +python test_fatsecret_api_oauth1.py +``` \ No newline at end of file diff --git a/docs/NUTRITION_SERVICE_API.md b/docs/NUTRITION_SERVICE_API.md new file mode 100644 index 0000000..367078b --- /dev/null +++ b/docs/NUTRITION_SERVICE_API.md @@ -0,0 +1,188 @@ +# Nutrition Service API Documentation + +## Overview + +Nutrition Service предоставляет API для отслеживания питания, подсчета калорий и получения информации о продуктах питания через интеграцию с FatSecret API. Сервис позволяет пользователям контролировать свой рацион и отслеживать потребление воды. + +**Base URL:** `/api/v1/nutrition` + +## Authentication + +Все эндпоинты требуют JWT аутентификацию. 
+ +**Headers:** +``` +Authorization: Bearer +``` + +## API Endpoints + +### 🔍 Поиск продуктов + +#### Найти продукты по названию +```http +GET /api/v1/nutrition/foods?query=яблоко +Authorization: Bearer +``` + +**Параметры:** +- `query` (string, required): Поисковый запрос для поиска продуктов +- `page` (number, optional): Номер страницы результатов, по умолчанию 1 +- `page_size` (number, optional): Количество результатов на странице, по умолчанию 20 + +**Response:** +```json +{ + "results": [ + { + "food_id": "123456", + "name": "Яблоко, сырое, с кожурой", + "brand": "", + "calories": 52, + "serving_size": "100г", + "nutrients": { + "carbohydrates": 13.8, + "protein": 0.3, + "fat": 0.2, + "fiber": 2.4 + } + } + ], + "total": 25, + "page": 1, + "page_size": 20 +} +``` + +### 📝 Записи о питании + +#### Добавить запись о питании +```http +POST /api/v1/nutrition/entries +Authorization: Bearer +``` + +**Body:** +```json +{ + "food_id": "123456", + "date": "2025-10-16", + "meal_type": "lunch", + "quantity": 1.0, + "serving_size": "100г", + "notes": "Красное яблоко" +} +``` + +**Варианты типов приема пищи (meal_type):** +- `breakfast` - завтрак +- `lunch` - обед +- `dinner` - ужин +- `snack` - перекус + +#### Получить записи о питании +```http +GET /api/v1/nutrition/entries?date=2025-10-16 +Authorization: Bearer +``` + +**Параметры:** +- `date` (string, optional): Дата в формате YYYY-MM-DD +- `start_date` (string, optional): Начальная дата для получения записей за период +- `end_date` (string, optional): Конечная дата для получения записей за период +- `meal_type` (string, optional): Фильтр по типу приема пищи + +#### Удалить запись о питании +```http +DELETE /api/v1/nutrition/entries/{entry_id} +Authorization: Bearer +``` + +### 💧 Отслеживание воды + +#### Добавить запись о потреблении воды +```http +POST /api/v1/nutrition/water +Authorization: Bearer +``` + +**Body:** +```json +{ + "date": "2025-10-16", + "amount_ml": 250, + "time": "12:30:00" +} +``` + +#### Получить записи о потреблении воды +```http +GET /api/v1/nutrition/water?date=2025-10-16 +Authorization: Bearer +``` + +### 📊 Сводки и статистика + +#### Получить дневную сводку по питанию +```http +GET /api/v1/nutrition/daily-summary?date=2025-10-16 +Authorization: Bearer +``` + +**Response:** +```json +{ + "date": "2025-10-16", + "total_calories": 1578, + "total_carbohydrates": 175.3, + "total_proteins": 78.2, + "total_fats": 52.8, + "total_water": 1200, + "entries": [ + { + "id": 123, + "food_name": "Яблоко, сырое, с кожурой", + "meal_type": "lunch", + "calories": 52, + "quantity": 1.0, + "serving_size": "100г" + } + ] +} +``` + +#### Получить недельную аналитику +```http +GET /api/v1/nutrition/weekly-summary?start_date=2025-10-10 +Authorization: Bearer +``` + +## Интеграция с FatSecret API + +Сервис использует FatSecret API для получения информации о питательной ценности продуктов. Ключи API хранятся в конфигурации сервера и не требуют дополнительной настройки со стороны клиента. 
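docs/FATSECRET_API.md recommends caching FatSecret responses, and the service is configured with its own Redis database (`redis://.../5`). This diff does not show whether the service currently caches, so the helper below is a hypothetical sketch of how such a layer could sit in front of `FatSecretClient.get_food_details`; the key format and TTL are assumptions.

```python
import json

import redis.asyncio as redis

from services.nutrition_service.fatsecret_client import FatSecretClient

CACHE_TTL_SECONDS = 24 * 60 * 60  # nutrition facts change rarely, so a long TTL is acceptable


async def cached_food_details(r: redis.Redis, client: FatSecretClient, fatsecret_id: str) -> dict:
    """Return food.get.v2 data, hitting FatSecret only on a cache miss."""
    key = f"fatsecret:food:{fatsecret_id}"
    cached = await r.get(key)
    if cached:
        return json.loads(cached)

    details = await client.get_food_details(fatsecret_id)
    await r.set(key, json.dumps(details), ex=CACHE_TTL_SECONDS)
    return details


# Usage sketch (inside async code):
#   r = redis.from_url(settings.REDIS_URL)
#   data = await cached_food_details(r, FatSecretClient(), "35718")
```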
+ +## Примеры использования + +### JavaScript +```javascript +// Пример поиска продуктов +async function searchFoods(query) { + const response = await fetch(`http://localhost:8000/api/v1/nutrition/foods?query=${query}`, { + headers: { 'Authorization': `Bearer ${token}` } + }); + return response.json(); +} + +// Пример добавления записи о питании +async function addNutritionEntry(entryData) { + const response = await fetch('http://localhost:8000/api/v1/nutrition/entries', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${token}` + }, + body: JSON.stringify(entryData) + }); + return response.json(); +} +``` \ No newline at end of file diff --git a/docs/PROJECT_STRUCTURE.md b/docs/PROJECT_STRUCTURE.md index 84d5e27..990f4d9 100644 --- a/docs/PROJECT_STRUCTURE.md +++ b/docs/PROJECT_STRUCTURE.md @@ -20,8 +20,13 @@ women-safety-backend/ │ ├── 📁 calendar_service/ │ │ ├── main.py # Calendar Service (8004) │ │ └── models.py # Calendar models -│ └── 📁 notification_service/ -│ └── main.py # Notification Service (8005) +│ ├── 📁 notification_service/ +│ │ └── main.py # Notification Service (8005) +│ └── 📁 nutrition_service/ +│ ├── main.py # Nutrition Service (8006) +│ ├── models.py # Nutrition models +│ ├── schemas.py # Nutrition schemas +│ └── fatsecret_client.py # FatSecret API client │ ├── 📁 shared/ # Общие компоненты │ ├── config.py # Конфигурация приложения diff --git a/integrate_nutrition_service.sh b/integrate_nutrition_service.sh new file mode 100755 index 0000000..692f548 --- /dev/null +++ b/integrate_nutrition_service.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +# Скрипт для интеграции сервиса питания в docker-compose.prod.yml + +echo "Интеграция сервиса питания в docker-compose.prod.yml..." + +# Находим место для вставки сервиса питания (после последнего определения сервиса) +LAST_SERVICE=$(grep -n "^ [a-zA-Z].*:" docker-compose.prod.yml | tail -1 | cut -d':' -f1) + +# Вставляем определение сервиса питания после последнего сервиса и перед volumes +sed -i "${LAST_SERVICE}r nutrition-service-prod.yml" docker-compose.prod.yml + +echo "Готово! 
Сервис питания добавлен в docker-compose.prod.yml" \ No newline at end of file diff --git a/nutrition-service-prod.yml b/nutrition-service-prod.yml new file mode 100644 index 0000000..da64214 --- /dev/null +++ b/nutrition-service-prod.yml @@ -0,0 +1,44 @@ + # Nutrition Service Cluster + nutrition-service-1: + image: women-safety/nutrition-service:${TAG:-latest} + environment: + - NODE_ID=1 + - DATABASE_URL=postgresql://postgres:${POSTGRES_PASSWORD}@postgres-primary:5432/women_safety_prod + - DATABASE_REPLICA_URL=postgresql://postgres:${POSTGRES_PASSWORD}@postgres-replica:5432/women_safety_prod + - REDIS_URL=redis://redis-cluster:6379/5 + - FATSECRET_CLIENT_ID=${FATSECRET_CLIENT_ID} + - FATSECRET_CLIENT_SECRET=${FATSECRET_CLIENT_SECRET} + depends_on: + - postgres-primary + - redis-cluster + restart: always + deploy: + resources: + limits: + cpus: '1.0' + memory: 2G + reservations: + cpus: '0.5' + memory: 512M + + nutrition-service-2: + image: women-safety/nutrition-service:${TAG:-latest} + environment: + - NODE_ID=2 + - DATABASE_URL=postgresql://postgres:${POSTGRES_PASSWORD}@postgres-primary:5432/women_safety_prod + - DATABASE_REPLICA_URL=postgresql://postgres:${POSTGRES_PASSWORD}@postgres-replica:5432/women_safety_prod + - REDIS_URL=redis://redis-cluster:6379/5 + - FATSECRET_CLIENT_ID=${FATSECRET_CLIENT_ID} + - FATSECRET_CLIENT_SECRET=${FATSECRET_CLIENT_SECRET} + depends_on: + - postgres-primary + - redis-cluster + restart: always + deploy: + resources: + limits: + cpus: '1.0' + memory: 2G + reservations: + cpus: '0.5' + memory: 512M \ No newline at end of file diff --git a/services/api_gateway/main.py b/services/api_gateway/main.py index 26eccb2..9c066a4 100644 --- a/services/api_gateway/main.py +++ b/services/api_gateway/main.py @@ -59,6 +59,7 @@ SERVICES = { "location": os.getenv("LOCATION_SERVICE_URL", "http://localhost:8003"), "calendar": os.getenv("CALENDAR_SERVICE_URL", "http://localhost:8004"), "notifications": os.getenv("NOTIFICATION_SERVICE_URL", "http://localhost:8005"), + "nutrition": os.getenv("NUTRITION_SERVICE_URL", "http://localhost:8006"), } # Rate limiting (simple in-memory implementation) @@ -732,6 +733,7 @@ async def root(): "location": "/api/v1/locations/update, /api/v1/locations/safe-places", "calendar": "/api/v1/calendar/entries, /api/v1/calendar/cycle-overview", "notifications": "/api/v1/notifications/devices, /api/v1/notifications/history", + "nutrition": "/api/v1/nutrition/foods, /api/v1/nutrition/daily-summary", }, "docs": "/docs", } diff --git a/venv/lib/python3.12/site-packages/celery/contrib/django/__init__.py b/services/nutrition_service/__init__.py similarity index 100% rename from venv/lib/python3.12/site-packages/celery/contrib/django/__init__.py rename to services/nutrition_service/__init__.py diff --git a/services/nutrition_service/fatsecret_client.py b/services/nutrition_service/fatsecret_client.py new file mode 100644 index 0000000..1286148 --- /dev/null +++ b/services/nutrition_service/fatsecret_client.py @@ -0,0 +1,199 @@ +import base64 +import hashlib +import hmac +import json +import logging +import os +import random +import time +import urllib.parse +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Any, Union + +import httpx + +from shared.config import settings + +logger = logging.getLogger(__name__) + + +class FatSecretClient: + """Клиент для работы с API FatSecret""" + + BASE_URL = "https://platform.fatsecret.com/rest/server.api" + + def __init__(self): + """Инициализация клиента для работы с API FatSecret""" + 
# Используем CUSTOMER_KEY для OAuth 1.0, если он доступен, иначе CLIENT_ID + self.api_key = settings.FATSECRET_CUSTOMER_KEY or settings.FATSECRET_CLIENT_ID + self.api_secret = settings.FATSECRET_CLIENT_SECRET + + # Логируем информацию о ключах (без полного раскрытия) + logger.info(f"FatSecretClient initialized with key: {self.api_key[:8]}...") + + def _generate_oauth_params(self, http_method: str, url: str, params: Dict[str, Any]) -> Dict[str, Any]: + """Создание и подписание OAuth 1.0 параметров""" + # Текущее время в секундах + timestamp = str(int(time.time())) + # Случайная строка для nonce + nonce = ''.join([str(random.randint(0, 9)) for _ in range(8)]) + + # Базовый набор параметров OAuth + oauth_params = { + 'oauth_consumer_key': self.api_key, + 'oauth_nonce': nonce, + 'oauth_signature_method': 'HMAC-SHA1', + 'oauth_timestamp': timestamp, + 'oauth_version': '1.0' + } + + # Объединяем с параметрами запроса + all_params = {**params, **oauth_params} + + # Сортируем параметры по ключу + sorted_params = sorted(all_params.items()) + + # Создаем строку параметров для подписи + param_string = "&".join([ + f"{urllib.parse.quote(str(k), safe='')}={urllib.parse.quote(str(v), safe='')}" + for k, v in sorted_params + ]) + + # Создаем строку для подписи + signature_base = f"{http_method}&{urllib.parse.quote(url, safe='')}&{urllib.parse.quote(param_string, safe='')}" + + # Создаем ключ для подписи + signing_key = f"{urllib.parse.quote(str(self.api_secret), safe='')}&" + + # Создаем HMAC-SHA1 подпись + signature = base64.b64encode( + hmac.new( + signing_key.encode(), + signature_base.encode(), + hashlib.sha1 + ).digest() + ).decode() + + # Добавляем подпись к параметрам OAuth + all_params['oauth_signature'] = signature + + return all_params + + async def search_foods(self, query: str, page_number: int = 0, max_results: int = 10) -> Dict[str, Any]: + """Поиск продуктов по запросу""" + params = { + 'method': 'foods.search', + 'search_expression': query, + 'page_number': str(page_number), + 'max_results': str(max_results), + 'format': 'json' + } + + # Получаем подписанные OAuth параметры + oauth_params = self._generate_oauth_params("GET", self.BASE_URL, params) + + try: + async with httpx.AsyncClient() as client: + response = await client.get( + self.BASE_URL, + params=oauth_params + ) + response.raise_for_status() + return response.json() + except Exception as e: + logger.error(f"Error searching foods: {e}") + raise + + async def get_food_details(self, food_id: Union[str, int]) -> Dict[str, Any]: + """Получить детальную информацию о продукте по ID""" + params = { + 'method': 'food.get.v2', + 'food_id': str(food_id), + 'format': 'json' + } + + # Получаем подписанные OAuth параметры + oauth_params = self._generate_oauth_params("GET", self.BASE_URL, params) + + try: + async with httpx.AsyncClient() as client: + response = await client.get( + self.BASE_URL, + params=oauth_params + ) + response.raise_for_status() + return response.json() + except Exception as e: + logger.error(f"Error getting food details: {e}") + raise + + async def parse_food_data(self, food_json: Dict[str, Any]) -> Dict[str, Any]: + """Разбирает данные о продукте из API в более удобный формат""" + try: + food = food_json.get('food', {}) + + # Извлечение основной информации о продукте + food_id = food.get('food_id') + food_name = food.get('food_name', '') + food_type = food.get('food_type', '') + brand_name = food.get('brand_name', '') + + # Обработка информации о питании + servings = food.get('servings', {}).get('serving', []) + + # 
Если есть только одна порция, преобразуем ее в список + if isinstance(servings, dict): + servings = [servings] + + # Берем первую порцию по умолчанию (обычно это 100г или стандартная порция) + serving_data = {} + for serving in servings: + if serving.get('is_default_serving', 0) == "1" or serving.get('serving_description', '').lower() == '100g': + serving_data = serving + break + + # Если не нашли стандартную порцию, берем первую + if not serving_data and servings: + serving_data = servings[0] + + # Извлечение данных о пищевой ценности + serving_description = serving_data.get('serving_description', '') + serving_amount = serving_data.get('metric_serving_amount', serving_data.get('serving_amount', '')) + serving_unit = serving_data.get('metric_serving_unit', serving_data.get('serving_unit', '')) + + # Формирование читаемого текста размера порции + serving_size = f"{serving_amount} {serving_unit}" if serving_amount and serving_unit else serving_description + + # Извлечение данных о пищевой ценности + calories = float(serving_data.get('calories', 0) or 0) + protein = float(serving_data.get('protein', 0) or 0) + fat = float(serving_data.get('fat', 0) or 0) + carbs = float(serving_data.get('carbohydrate', 0) or 0) + fiber = float(serving_data.get('fiber', 0) or 0) + sugar = float(serving_data.get('sugar', 0) or 0) + sodium = float(serving_data.get('sodium', 0) or 0) + cholesterol = float(serving_data.get('cholesterol', 0) or 0) + + # Формирование результата + result = { + "fatsecret_id": food_id, + "name": food_name, + "brand": brand_name, + "food_type": food_type, + "serving_size": serving_size, + "serving_weight_grams": float(serving_amount) if serving_unit == 'g' else None, + "calories": calories, + "protein_grams": protein, + "fat_grams": fat, + "carbs_grams": carbs, + "fiber_grams": fiber, + "sugar_grams": sugar, + "sodium_mg": sodium, + "cholesterol_mg": cholesterol, + "is_verified": True + } + + return result + except Exception as e: + logger.error(f"Error parsing food data: {e}") + raise \ No newline at end of file diff --git a/services/nutrition_service/main.py b/services/nutrition_service/main.py new file mode 100644 index 0000000..6bfb0da --- /dev/null +++ b/services/nutrition_service/main.py @@ -0,0 +1,462 @@ +from datetime import date, datetime, timedelta +from typing import Dict, List, Optional, Any + +from fastapi import Depends, FastAPI, HTTPException, Query, Path, status +from fastapi.middleware.cors import CORSMiddleware +from sqlalchemy import and_, desc, select, func, text +from sqlalchemy.ext.asyncio import AsyncSession + +from services.nutrition_service.models import ( + FoodItem, UserNutritionEntry, WaterIntake, + UserActivityEntry, NutritionGoal +) +from services.nutrition_service.schemas import ( + FoodItemCreate, FoodItemResponse, UserNutritionEntryCreate, + UserNutritionEntryResponse, WaterIntakeCreate, WaterIntakeResponse, + UserActivityEntryCreate, UserActivityEntryResponse, + NutritionGoalCreate, NutritionGoalResponse, + FoodSearchQuery, FoodDetailsQuery, DailyNutritionSummary +) +from services.nutrition_service.fatsecret_client import FatSecretClient +from shared.auth import get_current_user_from_token +from shared.config import settings +from shared.database import get_db + +app = FastAPI(title="Nutrition Service", version="1.0.0") + +# CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=settings.CORS_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Создаем клиент FatSecret +fatsecret_client = 
FatSecretClient() + + +@app.get("/health") +async def health_check(): + """Health check endpoint""" + return {"status": "healthy", "service": "nutrition_service"} + + +# Эндпоинты для работы с API FatSecret +@app.post("/api/v1/nutrition/search", response_model=List[FoodItemResponse]) +async def search_foods( + search_query: FoodSearchQuery, + user_data: dict = Depends(get_current_user_from_token), + db: AsyncSession = Depends(get_db) +): + """Поиск продуктов питания по запросу в FatSecret API""" + try: + # Вызов API FatSecret для поиска продуктов + search_results = await fatsecret_client.search_foods( + search_query.query, + search_query.page_number, + search_query.max_results + ) + + # Обработка результатов поиска + foods = [] + if 'foods' in search_results and 'food' in search_results['foods']: + food_list = search_results['foods']['food'] + # Если результат всего один, API возвращает словарь вместо списка + if isinstance(food_list, dict): + food_list = [food_list] + + for food in food_list: + # Получение деталей о продукте + food_details = await fatsecret_client.get_food_details(food['food_id']) + parsed_food = await fatsecret_client.parse_food_data(food_details) + + # Проверяем, существует ли продукт в базе данных + query = select(FoodItem).where(FoodItem.fatsecret_id == parsed_food['fatsecret_id']) + result = await db.execute(query) + db_food = result.scalars().first() + + # Если продукт не существует, сохраняем его + if not db_food: + db_food = FoodItem(**parsed_food) + db.add(db_food) + await db.commit() + await db.refresh(db_food) + + foods.append(FoodItemResponse.model_validate(db_food)) + + return foods + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error searching foods: {str(e)}" + ) + + +@app.get("/api/v1/nutrition/food/{food_id}", response_model=FoodItemResponse) +async def get_food_details( + food_id: int = Path(..., description="ID продукта в базе данных"), + user_data: dict = Depends(get_current_user_from_token), + db: AsyncSession = Depends(get_db) +): + """Получение детальной информации о продукте по ID из базы данных""" + query = select(FoodItem).where(FoodItem.id == food_id) + result = await db.execute(query) + food = result.scalars().first() + + if not food: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Food item not found" + ) + + return FoodItemResponse.model_validate(food) + + +@app.get("/api/v1/nutrition/fatsecret/{fatsecret_id}", response_model=FoodItemResponse) +async def get_food_by_fatsecret_id( + fatsecret_id: str = Path(..., description="ID продукта в FatSecret"), + user_data: dict = Depends(get_current_user_from_token), + db: AsyncSession = Depends(get_db) +): + """Получение детальной информации о продукте по FatSecret ID""" + # Проверяем, есть ли продукт в нашей базе данных + query = select(FoodItem).where(FoodItem.fatsecret_id == fatsecret_id) + result = await db.execute(query) + food = result.scalars().first() + + # Если продукт не найден в базе, запрашиваем его с FatSecret API + if not food: + try: + food_details = await fatsecret_client.get_food_details(fatsecret_id) + parsed_food = await fatsecret_client.parse_food_data(food_details) + + food = FoodItem(**parsed_food) + db.add(food) + await db.commit() + await db.refresh(food) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error fetching food details: {str(e)}" + ) + + return FoodItemResponse.model_validate(food) + + +# Эндпоинты для работы с 
записями питания пользователя +@app.post("/api/v1/nutrition/entries", response_model=UserNutritionEntryResponse) +async def create_nutrition_entry( + entry_data: UserNutritionEntryCreate, + user_data: dict = Depends(get_current_user_from_token), + db: AsyncSession = Depends(get_db) +): + """Создание новой записи о питании пользователя""" + # Получаем ID пользователя из токена + user_id = user_data["user_id"] + + # Если указан ID продукта, проверяем его наличие + food_item = None + if entry_data.food_item_id: + query = select(FoodItem).where(FoodItem.id == entry_data.food_item_id) + result = await db.execute(query) + food_item = result.scalars().first() + + if not food_item: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Food item not found" + ) + + # Создаем данные для записи + nutrition_data = entry_data.model_dump(exclude={"food_item_id"}) + nutrition_entry = UserNutritionEntry(**nutrition_data, user_id=user_id) + + if food_item: + nutrition_entry.food_item_id = food_item.id + + # Если питательные данные не указаны, рассчитываем их на основе продукта + if not entry_data.calories and food_item.calories: + nutrition_entry.calories = food_item.calories * entry_data.quantity + if not entry_data.protein_grams and food_item.protein_grams: + nutrition_entry.protein_grams = food_item.protein_grams * entry_data.quantity + if not entry_data.fat_grams and food_item.fat_grams: + nutrition_entry.fat_grams = food_item.fat_grams * entry_data.quantity + if not entry_data.carbs_grams and food_item.carbs_grams: + nutrition_entry.carbs_grams = food_item.carbs_grams * entry_data.quantity + + db.add(nutrition_entry) + await db.commit() + await db.refresh(nutrition_entry) + + return UserNutritionEntryResponse.model_validate(nutrition_entry) + + +@app.get("/api/v1/nutrition/entries", response_model=List[UserNutritionEntryResponse]) +async def get_user_nutrition_entries( + start_date: date = Query(..., description="Начальная дата для выборки"), + end_date: date = Query(..., description="Конечная дата для выборки"), + user_data: dict = Depends(get_current_user_from_token), + db: AsyncSession = Depends(get_db) +): + """Получение записей о питании пользователя за указанный период""" + user_id = user_data["user_id"] + + query = ( + select(UserNutritionEntry) + .where( + and_( + UserNutritionEntry.user_id == user_id, + UserNutritionEntry.entry_date >= start_date, + UserNutritionEntry.entry_date <= end_date + ) + ) + .order_by(UserNutritionEntry.entry_date, UserNutritionEntry.meal_type) + ) + + result = await db.execute(query) + entries = result.scalars().all() + + return [UserNutritionEntryResponse.model_validate(entry) for entry in entries] + + +# Эндпоинты для работы с записями о потреблении воды +@app.post("/api/v1/nutrition/water", response_model=WaterIntakeResponse) +async def create_water_intake( + intake_data: WaterIntakeCreate, + user_data: dict = Depends(get_current_user_from_token), + db: AsyncSession = Depends(get_db) +): + """Создание новой записи о потреблении воды""" + user_id = user_data["user_id"] + + water_intake = WaterIntake(**intake_data.model_dump(), user_id=user_id) + db.add(water_intake) + await db.commit() + await db.refresh(water_intake) + + return WaterIntakeResponse.model_validate(water_intake) + + +@app.get("/api/v1/nutrition/water", response_model=List[WaterIntakeResponse]) +async def get_user_water_intake( + start_date: date = Query(..., description="Начальная дата для выборки"), + end_date: date = Query(..., description="Конечная дата для выборки"), + 
user_data: dict = Depends(get_current_user_from_token), + db: AsyncSession = Depends(get_db) +): + """Получение записей о потреблении воды за указанный период""" + user_id = user_data["user_id"] + + query = ( + select(WaterIntake) + .where( + and_( + WaterIntake.user_id == user_id, + WaterIntake.entry_date >= start_date, + WaterIntake.entry_date <= end_date + ) + ) + .order_by(WaterIntake.entry_date, WaterIntake.entry_time) + ) + + result = await db.execute(query) + entries = result.scalars().all() + + return [WaterIntakeResponse.model_validate(entry) for entry in entries] + + +# Эндпоинты для работы с записями о физической активности +@app.post("/api/v1/nutrition/activity", response_model=UserActivityEntryResponse) +async def create_activity_entry( + activity_data: UserActivityEntryCreate, + user_data: dict = Depends(get_current_user_from_token), + db: AsyncSession = Depends(get_db) +): + """Создание новой записи о физической активности""" + user_id = user_data["user_id"] + + # Если не указаны сожженные калории, рассчитываем примерно + if not activity_data.calories_burned: + # Простой расчет на основе типа активности и продолжительности + # Точный расчет требует больше параметров (вес, рост, возраст, пол) + activity_intensity = { + "walking": 5, # ккал/мин + "running": 10, + "cycling": 8, + "swimming": 9, + "yoga": 4, + "weight_training": 6, + "hiit": 12, + "pilates": 5, + } + + activity_type = activity_data.activity_type.lower() + intensity = activity_intensity.get(activity_type, 5) # По умолчанию 5 ккал/мин + + # Увеличиваем интенсивность в зависимости от указанной интенсивности + if activity_data.intensity == "high": + intensity *= 1.5 + elif activity_data.intensity == "low": + intensity *= 0.8 + + calories_burned = intensity * activity_data.duration_minutes + activity_data.calories_burned = round(calories_burned, 1) + + activity_entry = UserActivityEntry(**activity_data.model_dump(), user_id=user_id) + db.add(activity_entry) + await db.commit() + await db.refresh(activity_entry) + + return UserActivityEntryResponse.model_validate(activity_entry) + + +@app.get("/api/v1/nutrition/activity", response_model=List[UserActivityEntryResponse]) +async def get_user_activities( + start_date: date = Query(..., description="Начальная дата для выборки"), + end_date: date = Query(..., description="Конечная дата для выборки"), + user_data: dict = Depends(get_current_user_from_token), + db: AsyncSession = Depends(get_db) +): + """Получение записей о физической активности за указанный период""" + user_id = user_data["user_id"] + + query = ( + select(UserActivityEntry) + .where( + and_( + UserActivityEntry.user_id == user_id, + UserActivityEntry.entry_date >= start_date, + UserActivityEntry.entry_date <= end_date + ) + ) + .order_by(UserActivityEntry.entry_date, UserActivityEntry.created_at) + ) + + result = await db.execute(query) + entries = result.scalars().all() + + return [UserActivityEntryResponse.model_validate(entry) for entry in entries] + + +# Эндпоинты для работы с целями питания +@app.post("/api/v1/nutrition/goals", response_model=NutritionGoalResponse) +async def create_or_update_nutrition_goals( + goal_data: NutritionGoalCreate, + user_data: dict = Depends(get_current_user_from_token), + db: AsyncSession = Depends(get_db) +): + """Создание или обновление целей по питанию и активности""" + user_id = user_data["user_id"] + + # Проверяем, существуют ли уже цели для пользователя + query = select(NutritionGoal).where(NutritionGoal.user_id == user_id) + result = await db.execute(query) + 
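+    # user_id is unique in nutrition_goals, so each user has at most one row:
+    # update it in place if it exists, otherwise create it (simple upsert).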
existing_goal = result.scalars().first() + + if existing_goal: + # Обновляем существующую цель + for key, value in goal_data.model_dump(exclude_unset=True).items(): + setattr(existing_goal, key, value) + await db.commit() + await db.refresh(existing_goal) + return NutritionGoalResponse.model_validate(existing_goal) + else: + # Создаем новую цель + new_goal = NutritionGoal(**goal_data.model_dump(), user_id=user_id) + db.add(new_goal) + await db.commit() + await db.refresh(new_goal) + return NutritionGoalResponse.model_validate(new_goal) + + +@app.get("/api/v1/nutrition/goals", response_model=NutritionGoalResponse) +async def get_nutrition_goals( + user_data: dict = Depends(get_current_user_from_token), + db: AsyncSession = Depends(get_db) +): + """Получение целей пользователя по питанию и активности""" + user_id = user_data["user_id"] + + query = select(NutritionGoal).where(NutritionGoal.user_id == user_id) + result = await db.execute(query) + goal = result.scalars().first() + + if not goal: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Nutrition goals not found for this user" + ) + + return NutritionGoalResponse.model_validate(goal) + + +# Сводные отчеты +@app.get("/api/v1/nutrition/summary/daily", response_model=DailyNutritionSummary) +async def get_daily_nutrition_summary( + target_date: date = Query(..., description="Дата для получения сводки"), + user_data: dict = Depends(get_current_user_from_token), + db: AsyncSession = Depends(get_db) +): + """Получение дневной сводки по питанию, потреблению воды и физической активности""" + user_id = user_data["user_id"] + + # Запрос записей о питании + meals_query = select(UserNutritionEntry).where( + and_( + UserNutritionEntry.user_id == user_id, + UserNutritionEntry.entry_date == target_date + ) + ).order_by(UserNutritionEntry.meal_type) + + meals_result = await db.execute(meals_query) + meals = meals_result.scalars().all() + + # Запрос записей о воде + water_query = select(WaterIntake).where( + and_( + WaterIntake.user_id == user_id, + WaterIntake.entry_date == target_date + ) + ).order_by(WaterIntake.entry_time) + + water_result = await db.execute(water_query) + water_entries = water_result.scalars().all() + + # Запрос записей об активности + activity_query = select(UserActivityEntry).where( + and_( + UserActivityEntry.user_id == user_id, + UserActivityEntry.entry_date == target_date + ) + ).order_by(UserActivityEntry.created_at) + + activity_result = await db.execute(activity_query) + activity_entries = activity_result.scalars().all() + + # Расчет суммарных значений + total_calories = sum(meal.calories or 0 for meal in meals) + total_protein = sum(meal.protein_grams or 0 for meal in meals) + total_fat = sum(meal.fat_grams or 0 for meal in meals) + total_carbs = sum(meal.carbs_grams or 0 for meal in meals) + total_water = sum(water.amount_ml for water in water_entries) + total_activity = sum(activity.duration_minutes for activity in activity_entries) + calories_burned = sum(activity.calories_burned or 0 for activity in activity_entries) + + # Формирование ответа + summary = DailyNutritionSummary( + date=target_date, + total_calories=total_calories, + total_protein_grams=total_protein, + total_fat_grams=total_fat, + total_carbs_grams=total_carbs, + total_water_ml=total_water, + total_activity_minutes=total_activity, + estimated_calories_burned=calories_burned, + meals=[UserNutritionEntryResponse.model_validate(meal) for meal in meals], + water_entries=[WaterIntakeResponse.model_validate(water) for water in 
water_entries], + activity_entries=[UserActivityEntryResponse.model_validate(activity) for activity in activity_entries] + ) + + return summary \ No newline at end of file diff --git a/services/nutrition_service/models.py b/services/nutrition_service/models.py new file mode 100644 index 0000000..792415f --- /dev/null +++ b/services/nutrition_service/models.py @@ -0,0 +1,146 @@ +import uuid + +from sqlalchemy import Boolean, Column, Date, Float, Integer, String, Text, ForeignKey +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.sql import func +from sqlalchemy.sql.expression import text +from sqlalchemy.sql.sqltypes import TIMESTAMP + +from shared.database import BaseModel + + +class FoodItem(BaseModel): + """Модель для хранения информации о продуктах питания""" + __tablename__ = "food_items" + + uuid = Column(UUID(as_uuid=True), default=uuid.uuid4, unique=True, index=True) + + # Основная информация о продукте + fatsecret_id = Column(String(50), unique=True, index=True, nullable=True) # ID продукта в FatSecret + name = Column(String(255), nullable=False) + brand = Column(String(255), nullable=True) + description = Column(Text, nullable=True) + food_type = Column(String(50), nullable=True) # generic, branded, etc. + serving_size = Column(String(100), nullable=True) # e.g. "1 cup" or "100g" + serving_weight_grams = Column(Float, nullable=True) + + # Пищевая ценность на порцию + calories = Column(Float, nullable=True) # kcal + protein_grams = Column(Float, nullable=True) + fat_grams = Column(Float, nullable=True) + carbs_grams = Column(Float, nullable=True) + fiber_grams = Column(Float, nullable=True) + sugar_grams = Column(Float, nullable=True) + sodium_mg = Column(Float, nullable=True) + cholesterol_mg = Column(Float, nullable=True) + + # Дополнительная информация + ingredients = Column(Text, nullable=True) + is_verified = Column(Boolean, default=False) # Проверенные данные или пользовательские + created_at = Column(TIMESTAMP(timezone=True), nullable=False, server_default=func.now()) + updated_at = Column(TIMESTAMP(timezone=True), onupdate=func.now()) + + def __repr__(self): + return f"" + + +class UserNutritionEntry(BaseModel): + """Модель для хранения записей пользователя о потреблении пищи""" + __tablename__ = "user_nutrition_entries" + + uuid = Column(UUID(as_uuid=True), default=uuid.uuid4, unique=True, index=True) + user_id = Column(Integer, nullable=False, index=True) # Связь с таблицей пользователей + + # Информация о приеме пищи + entry_date = Column(Date, nullable=False, index=True) + meal_type = Column(String(50), nullable=False) # breakfast, lunch, dinner, snack + + food_item_id = Column(Integer, ForeignKey("food_items.id"), nullable=True) + custom_food_name = Column(String(255), nullable=True) # Если продукт не из базы + + # Количество + quantity = Column(Float, nullable=False, default=1.0) + unit = Column(String(50), nullable=True) # g, ml, oz, piece, etc. 
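+    # The macro values below are denormalized: they are computed once when the
+    # entry is created (per-serving values of the linked food item multiplied
+    # by quantity) and stored with the entry rather than recalculated on read.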
+ + # Рассчитанная пищевая ценность для данного количества + calories = Column(Float, nullable=True) + protein_grams = Column(Float, nullable=True) + fat_grams = Column(Float, nullable=True) + carbs_grams = Column(Float, nullable=True) + + # Метаданные + notes = Column(Text, nullable=True) + created_at = Column(TIMESTAMP(timezone=True), nullable=False, server_default=func.now()) + updated_at = Column(TIMESTAMP(timezone=True), onupdate=func.now()) + + def __repr__(self): + return f"" + + +class WaterIntake(BaseModel): + """Модель для отслеживания потребления воды""" + __tablename__ = "water_intake" + + uuid = Column(UUID(as_uuid=True), default=uuid.uuid4, unique=True, index=True) + user_id = Column(Integer, nullable=False, index=True) # Связь с таблицей пользователей + + entry_date = Column(Date, nullable=False, index=True) + amount_ml = Column(Integer, nullable=False) # Количество в миллилитрах + entry_time = Column(TIMESTAMP(timezone=True), nullable=False, server_default=func.now()) + + notes = Column(Text, nullable=True) + + def __repr__(self): + return f"" + + +class UserActivityEntry(BaseModel): + """Модель для отслеживания физической активности""" + __tablename__ = "user_activity_entries" + + uuid = Column(UUID(as_uuid=True), default=uuid.uuid4, unique=True, index=True) + user_id = Column(Integer, nullable=False, index=True) # Связь с таблицей пользователей + + entry_date = Column(Date, nullable=False, index=True) + activity_type = Column(String(100), nullable=False) # walking, running, yoga, etc. + + duration_minutes = Column(Integer, nullable=False) + calories_burned = Column(Float, nullable=True) # Расчетное количество сожженных калорий + + # Дополнительные параметры активности + distance_km = Column(Float, nullable=True) # Для активностей с расстоянием + steps = Column(Integer, nullable=True) # Для ходьбы + intensity = Column(String(20), nullable=True) # low, medium, high + + notes = Column(Text, nullable=True) + created_at = Column(TIMESTAMP(timezone=True), nullable=False, server_default=func.now()) + + def __repr__(self): + return f"" + + +class NutritionGoal(BaseModel): + """Модель для хранения целей пользователя по питанию и активности""" + __tablename__ = "nutrition_goals" + + user_id = Column(Integer, nullable=False, index=True, unique=True) # Связь с таблицей пользователей + + # Цели по калориям и макронутриентам + daily_calorie_goal = Column(Integer, nullable=True) + protein_goal_grams = Column(Integer, nullable=True) + fat_goal_grams = Column(Integer, nullable=True) + carbs_goal_grams = Column(Integer, nullable=True) + + # Цели по воде и активности + water_goal_ml = Column(Integer, nullable=True, default=2000) # Стандартно 2 литра + activity_goal_minutes = Column(Integer, nullable=True, default=30) # Минимум 30 минут активности + + # Цель по весу и предпочтения + weight_goal_kg = Column(Float, nullable=True) + goal_type = Column(String(50), nullable=True) # lose_weight, maintain, gain_weight, health + + created_at = Column(TIMESTAMP(timezone=True), nullable=False, server_default=func.now()) + updated_at = Column(TIMESTAMP(timezone=True), onupdate=func.now()) + + def __repr__(self): + return f"" \ No newline at end of file diff --git a/services/nutrition_service/schemas.py b/services/nutrition_service/schemas.py new file mode 100644 index 0000000..7eae744 --- /dev/null +++ b/services/nutrition_service/schemas.py @@ -0,0 +1,203 @@ +from datetime import date +from enum import Enum +from typing import List, Optional + +from pydantic import BaseModel, Field, root_validator + + 
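+# The enums below mirror the free-form String columns in models.py
+# (meal_type, intensity, goal_type) so the API only accepts known values.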
+class MealType(str, Enum): + BREAKFAST = "breakfast" + LUNCH = "lunch" + DINNER = "dinner" + SNACK = "snack" + + +class ActivityIntensity(str, Enum): + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + + +class GoalType(str, Enum): + LOSE_WEIGHT = "lose_weight" + MAINTAIN = "maintain" + GAIN_WEIGHT = "gain_weight" + HEALTH = "health" + + +# Схемы для FoodItem +class FoodItemBase(BaseModel): + name: str + brand: Optional[str] = None + description: Optional[str] = None + food_type: Optional[str] = None + serving_size: Optional[str] = None + serving_weight_grams: Optional[float] = None + calories: Optional[float] = None + protein_grams: Optional[float] = None + fat_grams: Optional[float] = None + carbs_grams: Optional[float] = None + fiber_grams: Optional[float] = None + sugar_grams: Optional[float] = None + sodium_mg: Optional[float] = None + cholesterol_mg: Optional[float] = None + ingredients: Optional[str] = None + + +class FoodItemCreate(FoodItemBase): + fatsecret_id: Optional[str] = None + is_verified: bool = False + + +class FoodItemResponse(FoodItemBase): + id: int + uuid: str + fatsecret_id: Optional[str] = None + is_verified: bool + created_at: str + updated_at: Optional[str] = None + + class Config: + from_attributes = True + + +# Схемы для UserNutritionEntry +class UserNutritionEntryBase(BaseModel): + entry_date: date + meal_type: MealType + quantity: float = Field(gt=0) + unit: Optional[str] = None + notes: Optional[str] = None + + +class UserNutritionEntryCreate(UserNutritionEntryBase): + food_item_id: Optional[int] = None + custom_food_name: Optional[str] = None + calories: Optional[float] = None + protein_grams: Optional[float] = None + fat_grams: Optional[float] = None + carbs_grams: Optional[float] = None + + @root_validator(skip_on_failure=True) + def check_food_info(cls, values): + food_item_id = values.get("food_item_id") + custom_food_name = values.get("custom_food_name") + + if food_item_id is None and not custom_food_name: + raise ValueError("Either food_item_id or custom_food_name must be provided") + return values + + +class UserNutritionEntryResponse(UserNutritionEntryBase): + id: int + uuid: str + user_id: int + food_item_id: Optional[int] = None + custom_food_name: Optional[str] = None + calories: Optional[float] = None + protein_grams: Optional[float] = None + fat_grams: Optional[float] = None + carbs_grams: Optional[float] = None + created_at: str + + class Config: + from_attributes = True + + +# Схемы для WaterIntake +class WaterIntakeBase(BaseModel): + entry_date: date + amount_ml: int = Field(gt=0) + notes: Optional[str] = None + + +class WaterIntakeCreate(WaterIntakeBase): + pass + + +class WaterIntakeResponse(WaterIntakeBase): + id: int + uuid: str + user_id: int + entry_time: str + + class Config: + from_attributes = True + + +# Схемы для UserActivityEntry +class UserActivityEntryBase(BaseModel): + entry_date: date + activity_type: str + duration_minutes: int = Field(gt=0) + distance_km: Optional[float] = None + steps: Optional[int] = None + intensity: Optional[ActivityIntensity] = None + notes: Optional[str] = None + + +class UserActivityEntryCreate(UserActivityEntryBase): + calories_burned: Optional[float] = None + + +class UserActivityEntryResponse(UserActivityEntryBase): + id: int + uuid: str + user_id: int + calories_burned: Optional[float] = None + created_at: str + + class Config: + from_attributes = True + + +# Схемы для NutritionGoal +class NutritionGoalBase(BaseModel): + daily_calorie_goal: Optional[int] = None + protein_goal_grams: 
Optional[int] = None + fat_goal_grams: Optional[int] = None + carbs_goal_grams: Optional[int] = None + water_goal_ml: Optional[int] = None + activity_goal_minutes: Optional[int] = None + weight_goal_kg: Optional[float] = None + goal_type: Optional[GoalType] = None + + +class NutritionGoalCreate(NutritionGoalBase): + pass + + +class NutritionGoalResponse(NutritionGoalBase): + id: int + user_id: int + created_at: str + updated_at: Optional[str] = None + + class Config: + from_attributes = True + + +# Схемы для запросов к FatSecret API +class FoodSearchQuery(BaseModel): + query: str + page_number: int = 0 + max_results: int = 10 + + +class FoodDetailsQuery(BaseModel): + food_id: str + + +# Схемы для сводных данных +class DailyNutritionSummary(BaseModel): + date: date + total_calories: float = 0 + total_protein_grams: float = 0 + total_fat_grams: float = 0 + total_carbs_grams: float = 0 + total_water_ml: int = 0 + total_activity_minutes: int = 0 + estimated_calories_burned: float = 0 + meals: List[UserNutritionEntryResponse] = [] + water_entries: List[WaterIntakeResponse] = [] + activity_entries: List[UserActivityEntryResponse] = [] \ No newline at end of file diff --git a/services/user_service/main.py b/services/user_service/main.py index 23de4b9..b4c5105 100644 --- a/services/user_service/main.py +++ b/services/user_service/main.py @@ -85,11 +85,6 @@ async def register_user(user_data: UserCreate, db: AsyncSession = Depends(get_db try: hashed_password = get_password_hash(user_data.password) except ValueError as e: - if "password cannot be longer than 72 bytes" in str(e): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Password is too long. Please use a shorter password (max 70 characters)." - ) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=f"Password validation error: {str(e)}" diff --git a/services/user_service/schemas.py b/services/user_service/schemas.py index 6ded99b..2b3024b 100644 --- a/services/user_service/schemas.py +++ b/services/user_service/schemas.py @@ -41,15 +41,15 @@ class UserBase(BaseModel): class UserCreate(UserBase): - password: str = Field(..., min_length=8, max_length=70, description="Password (will be truncated to 72 bytes for bcrypt compatibility)") + password: str = Field(..., min_length=8, description="Password for user registration") @field_validator("password") @classmethod def validate_password_bytes(cls, v): - """Ensure password doesn't exceed bcrypt's 72-byte limit.""" - password_bytes = v.encode('utf-8') - if len(password_bytes) > 72: - raise ValueError("Password is too long when encoded as UTF-8 (max 72 bytes for bcrypt)") + """Basic validation for password.""" + # Только проверка минимальной длины + if not v or len(v.strip()) < 8: + raise ValueError("Password must be at least 8 characters") return v @@ -102,17 +102,15 @@ class UserResponse(UserBase): class UserLogin(BaseModel): email: Optional[EmailStr] = None username: Optional[str] = None - password: str = Field(..., max_length=70, description="Password (will be truncated to 72 bytes for bcrypt compatibility)") + password: str = Field(..., min_length=1, description="Password for authentication") @field_validator("password") @classmethod def validate_password_bytes(cls, v): - """Ensure password doesn't exceed bcrypt's 72-byte limit.""" + """Basic password validation.""" if not v or len(v.strip()) == 0: raise ValueError("Password cannot be empty") - password_bytes = v.encode('utf-8') - if len(password_bytes) > 72: - raise ValueError("Password is too long when 
encoded as UTF-8 (max 72 bytes for bcrypt)") + # Не делаем проверку на максимальную длину - passlib/bcrypt сам справится с ограничениями return v @field_validator("username") diff --git a/shared/auth.py b/shared/auth.py index 9205894..9aedcd4 100644 --- a/shared/auth.py +++ b/shared/auth.py @@ -18,8 +18,13 @@ from shared.config import settings # Suppress bcrypt version warnings logging.getLogger("passlib").setLevel(logging.ERROR) -# Password hashing -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +# Password hashing - настройка bcrypt с более надежными параметрами +pwd_context = CryptContext( + schemes=["bcrypt"], + deprecated="auto", + bcrypt__rounds=12, # Стандартное количество раундов + bcrypt__truncate_error=False # Не вызывать ошибку при длинных паролях, а просто обрезать +) # Bearer token scheme security = HTTPBearer() @@ -28,29 +33,32 @@ security = HTTPBearer() def verify_password(plain_password: str, hashed_password: str) -> bool: """Verify a password against its hash. Handle bcrypt compatibility issues.""" try: - # Truncate password to 72 bytes for consistency - password_bytes = plain_password.encode('utf-8') - if len(password_bytes) > 72: - plain_password = password_bytes[:72].decode('utf-8', errors='ignore') - return pwd_context.verify(plain_password, hashed_password) + # Увеличим подробность логов + logging.info(f"Verifying password length: {len(plain_password)} chars") + + # Проверяем пароль с помощью passlib и логируем результат + result = pwd_context.verify(plain_password, hashed_password) + logging.info(f"Password verification result: {result}") + return result except Exception as e: - logging.error(f"Error verifying password: {e}") + logging.error(f"Error verifying password: {e}, hash_type: {hashed_password[:10]}...") return False def get_password_hash(password: str) -> str: - """Get password hash. Truncate password to 72 bytes if necessary for bcrypt compatibility.""" + """Get password hash. Let passlib handle bcrypt compatibility.""" try: - # bcrypt has a 72-byte limit, so truncate if necessary - password_bytes = password.encode('utf-8') - if len(password_bytes) > 72: - logging.warning("Password exceeds bcrypt limit of 72 bytes. Truncating.") - password = password_bytes[:70].decode('utf-8', errors='ignore') - return pwd_context.hash(password) + # Увеличим подробность логов + logging.info(f"Hashing password length: {len(password)} chars") + + # bcrypt автоматически ограничит длину пароля до 72 байт + hashed = pwd_context.hash(password) + logging.info("Password hashed successfully") + return hashed except Exception as e: - # Handle bcrypt compatibility issues + # Логируем ошибку и пробрасываем исключение logging.error(f"Error hashing password: {e}") - raise ValueError("Password hashing failed. 
Please use a shorter password.") + raise ValueError(f"Password hashing failed: {str(e)}") def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: diff --git a/shared/config.py b/shared/config.py index 47ce2f6..200b31f 100644 --- a/shared/config.py +++ b/shared/config.py @@ -41,6 +41,11 @@ class Settings(BaseSettings): # External Services FCM_SERVER_KEY: Optional[str] = None + + # FatSecret API для данных о питании + FATSECRET_CLIENT_ID: str = "56342dd56fc74b26afb49d65b8f84c16" + FATSECRET_CLIENT_SECRET: str = "fae178f189dc44ddb368cabe9069c0e3" + FATSECRET_CUSTOMER_KEY: Optional[str] = None # Исправляем опечатку в имени параметра # Security CORS_ORIGINS: list = ["*"] # Change in production diff --git a/start_services.sh b/start_services.sh index 15f9459..f4645b9 100755 --- a/start_services.sh +++ b/start_services.sh @@ -89,6 +89,10 @@ echo -e "${YELLOW}Starting Notification Service (port 8005)...${NC}" python -m uvicorn services.notification_service.main:app --port 8005 & NOTIFICATION_PID=$! +echo -e "${YELLOW}Starting Nutrition Service (port 8006)...${NC}" +python -m uvicorn services.nutrition_service.main:app --port 8006 & +NUTRITION_PID=$! + # Wait a bit for services to start sleep 5 @@ -102,6 +106,7 @@ echo $EMERGENCY_PID > emergency_service.pid echo $LOCATION_PID > location_service.pid echo $CALENDAR_PID > calendar_service.pid echo $NOTIFICATION_PID > notification_service.pid +echo $NUTRITION_PID > nutrition_service.pid echo $GATEWAY_PID > api_gateway.pid echo -e "${GREEN}🎉 All services started successfully!${NC}" @@ -112,6 +117,7 @@ echo -e " 🚨 Emergency Service: http://localhost:8002" echo -e " 📍 Location Service: http://localhost:8003" echo -e " 📅 Calendar Service: http://localhost:8004" echo -e " 🔔 Notification Service: http://localhost:8005" +echo -e " 🍎 Nutrition Service: http://localhost:8006" echo -e "${GREEN}📖 API Documentation: http://localhost:8000/docs${NC}" # Keep script running and show logs @@ -127,6 +133,7 @@ cleanup() { if [ -f "location_service.pid" ]; then kill "$(cat location_service.pid)" 2>/dev/null && rm location_service.pid; fi if [ -f "calendar_service.pid" ]; then kill "$(cat calendar_service.pid)" 2>/dev/null && rm calendar_service.pid; fi if [ -f "notification_service.pid" ]; then kill "$(cat notification_service.pid)" 2>/dev/null && rm notification_service.pid; fi + if [ -f "nutrition_service.pid" ]; then kill "$(cat nutrition_service.pid)" 2>/dev/null && rm nutrition_service.pid; fi if [ -f "api_gateway.pid" ]; then kill "$(cat api_gateway.pid)" 2>/dev/null && rm api_gateway.pid; fi echo -e "${GREEN}✅ All services stopped${NC}" diff --git a/start_services_no_docker.sh b/start_services_no_docker.sh index aa27d62..9f14559 100755 --- a/start_services_no_docker.sh +++ b/start_services_no_docker.sh @@ -51,6 +51,7 @@ cleanup() { kill_port 8003 kill_port 8004 kill_port 8005 + kill_port 8006 echo "✅ All services stopped" exit 0 } @@ -66,6 +67,7 @@ kill_port 8002 kill_port 8003 kill_port 8004 kill_port 8005 +kill_port 8006 echo "⏳ Waiting for ports to be freed..." sleep 3 @@ -94,6 +96,10 @@ echo "Starting Calendar Service (port 8004)..." echo "Starting Notification Service (port 8005)..." (cd services/notification_service && PYTHONPATH="${PWD}/../..:${PYTHONPATH}" python -m uvicorn main:app --host 0.0.0.0 --port 8005 --reload) & +# Start Nutrition Service +echo "Starting Nutrition Service (port 8006)..." 
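+# Same pattern as the other services: run from the service directory with the
+# repository root on PYTHONPATH so the shared/ package resolves.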
+(cd services/nutrition_service && PYTHONPATH="${PWD}/../..:${PYTHONPATH}" python -m uvicorn main:app --host 0.0.0.0 --port 8006 --reload) & + # Start API Gateway echo "Starting API Gateway (port 8000)..." (cd services/api_gateway && PYTHONPATH="${PWD}/../..:${PYTHONPATH}" python -m uvicorn main:app --host 0.0.0.0 --port 8000 --reload) & @@ -110,6 +116,7 @@ echo " 🚨 Emergency Service: http://localhost:8002" echo " 📍 Location Service: http://localhost:8003" echo " 📅 Calendar Service: http://localhost:8004" echo " 🔔 Notification Service: http://localhost:8005" +echo " 🍎 Nutrition Service: http://localhost:8006" echo "" echo "📖 API Documentation: http://localhost:8000/docs" echo "📊 Monitoring services... Press Ctrl+C to stop all services" diff --git a/stop_services.sh b/stop_services.sh index a46ac91..40c9a4e 100755 --- a/stop_services.sh +++ b/stop_services.sh @@ -42,8 +42,14 @@ if [ -f "notification_service.pid" ]; then echo -e "${GREEN}✅ Notification Service stopped${NC}" fi +if [ -f "nutrition_service.pid" ]; then + kill "$(cat nutrition_service.pid)" 2>/dev/null + rm nutrition_service.pid + echo -e "${GREEN}✅ Nutrition Service stopped${NC}" +fi + if [ -f "api_gateway.pid" ]; then - kill $(cat api_gateway.pid) 2>/dev/null + kill "$(cat api_gateway.pid)" 2>/dev/null rm api_gateway.pid echo -e "${GREEN}✅ API Gateway stopped${NC}" fi diff --git a/test_fatsecret_api.py b/test_fatsecret_api.py new file mode 100755 index 0000000..104e652 --- /dev/null +++ b/test_fatsecret_api.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python3 +""" +Скрипт для тестирования API FatSecret +Выполняет тестовые запросы к API FatSecret с использованием ключей из конфигурации приложения +""" + +import os +import json +import time +import base64 +import asyncio +import httpx +import urllib.parse +import hmac +import hashlib +from datetime import datetime +from dotenv import load_dotenv + + +# Загружаем .env файл +current_dir = os.path.dirname(os.path.abspath(__file__)) +env_path = os.path.join(current_dir, ".env") +load_dotenv(env_path) +print(f"✅ Loaded .env from: {env_path}") + +# Получаем API ключи из переменных окружения +FATSECRET_CLIENT_ID = os.environ.get("FATSECRET_CLIENT_ID") +FATSECRET_CLIENT_SECRET = os.environ.get("FATSECRET_CLIENT_SECRET") +FATSECRET_CUSTOMER_KEY = os.environ.get("FATSECRET_CUSTOMER_KEY") + +if not FATSECRET_CLIENT_ID or not FATSECRET_CLIENT_SECRET: + raise ValueError("FatSecret API keys not found in .env file") + +print(f"🔑 Using FatSecret API keys: CLIENT_ID={FATSECRET_CLIENT_ID[:8]}...") +if FATSECRET_CUSTOMER_KEY: + print(f"🔑 Using CUSTOMER_KEY={FATSECRET_CUSTOMER_KEY[:8]}...") + + +class FatSecretClient: + """Клиент для работы с API FatSecret""" + + BASE_URL = "https://platform.fatsecret.com/rest/server.api" + + def __init__(self, client_id, client_secret): + self.client_id = client_id + self.client_secret = client_secret + self.access_token = None + self.token_expires = 0 + + async def get_access_token(self): + """Получение OAuth 2.0 токена для доступа к API""" + now = time.time() + + # Если у нас уже есть токен и он не истек, используем его + if self.access_token and self.token_expires > now + 60: + return self.access_token + + print("🔄 Getting new access token...") + + # Подготовка запроса на получение токена + auth_header = base64.b64encode(f"{self.client_id}:{self.client_secret}".encode()).decode() + + print(f"🔑 Using client_id: {self.client_id}") + # Не печатаем секрет полностью, только первые несколько символов для отладки + print(f"🔑 Using client_secret: {self.client_secret[:5]}...") 
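+        # Exchange the client credentials for a Bearer token via the OAuth 2.0
+        # client_credentials grant; the token is cached and reused until about
+        # 60 seconds before the expires_in window runs out (see the check above).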
+ + async with httpx.AsyncClient() as client: + response = await client.post( + "https://oauth.fatsecret.com/connect/token", + headers={ + "Authorization": f"Basic {auth_header}", + "Content-Type": "application/x-www-form-urlencoded" + }, + data={ + "grant_type": "client_credentials", + "scope": "basic premier" + } + ) + + # Проверяем успешность запроса + if response.status_code != 200: + print(f"❌ Error getting token: {response.status_code}") + print(response.text) + raise Exception(f"Failed to get token: {response.status_code}") + + token_data = response.json() + self.access_token = token_data["access_token"] + self.token_expires = now + token_data["expires_in"] + + print(f"✅ Got token, expires in {token_data['expires_in']} seconds") + return self.access_token + + async def search_food(self, query, page=0, max_results=10): + """Поиск продуктов по названию""" + token = await self.get_access_token() + + async with httpx.AsyncClient() as client: + response = await client.post( + self.BASE_URL, + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + }, + json={ + "method": "foods.search", + "search_expression": query, + "page_number": page, + "max_results": max_results, + "format": "json" + } + ) + + if response.status_code != 200: + print(f"❌ Error searching food: {response.status_code}") + print(response.text) + raise Exception(f"Failed to search food: {response.status_code}") + + return response.json() + + async def get_food(self, food_id): + """Получение детальной информации о продукте по ID""" + token = await self.get_access_token() + + async with httpx.AsyncClient() as client: + response = await client.post( + self.BASE_URL, + headers={ + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + }, + json={ + "method": "food.get", + "food_id": food_id, + "format": "json" + } + ) + + if response.status_code != 200: + print(f"❌ Error getting food details: {response.status_code}") + print(response.text) + raise Exception(f"Failed to get food details: {response.status_code}") + + return response.json() + + +async def run_tests(): + """Выполнение тестовых запросов к API FatSecret""" + client = FatSecretClient(FATSECRET_CLIENT_ID, FATSECRET_CLIENT_SECRET) + + # Тест 1: Поиск продуктов + print("\n🔍 Testing food search...") + search_queries = ["apple", "bread", "chicken breast", "молоко"] + + for query in search_queries: + print(f"\n📋 Searching for: {query}") + try: + result = await client.search_food(query) + + # Проверяем структуру ответа + if "foods" not in result: + print(f"❌ Unexpected response format: {result}") + continue + + # Если нет результатов + if "food" not in result["foods"]: + print(f"⚠️ No results found for '{query}'") + continue + + food_list = result["foods"]["food"] + if not isinstance(food_list, list): + food_list = [food_list] # Если только один результат, оборачиваем в список + + print(f"✅ Found {len(food_list)} results") + + # Выводим первые 3 результата + first_food_id = None + for i, food in enumerate(food_list[:3]): + food_name = food.get("food_name", "Unknown") + food_id = food.get("food_id", "Unknown") + food_desc = food.get("food_description", "No description") + + print(f" {i+1}. 
[{food_id}] {food_name}") + print(f" {food_desc}") + + # Сохраняем ID первого продукта для следующего теста + if i == 0: + first_food_id = food_id + except Exception as e: + print(f"❌ Error during search: {e}") + + # Тест 2: Получение информации о продукте + found_food_id = None + for query in search_queries: + try: + result = await client.search_food(query) + if "foods" in result and "food" in result["foods"]: + food_list = result["foods"]["food"] + if not isinstance(food_list, list): + food_list = [food_list] + if food_list: + found_food_id = food_list[0].get("food_id") + break + except: + continue + + if found_food_id: + print(f"\n🔍 Testing food details for ID: {found_food_id}") + try: + result = await client.get_food(found_food_id) + + if "food" not in result: + print(f"❌ Unexpected response format: {result}") + else: + food = result["food"] + food_name = food.get("food_name", "Unknown") + brand = food.get("brand_name", "Generic") + + print(f"✅ Got details for: {food_name} [{brand}]") + + # Выводим информацию о пищевой ценности + if "servings" in food: + servings = food["servings"] + if "serving" in servings: + serving_data = servings["serving"] + if not isinstance(serving_data, list): + serving_data = [serving_data] + + print("\n📊 Nutrition info per serving:") + for i, serving in enumerate(serving_data[:2]): # Выводим до 2 видов порций + serving_desc = serving.get("serving_description", "Standard") + calories = serving.get("calories", "N/A") + protein = serving.get("protein", "N/A") + carbs = serving.get("carbohydrate", "N/A") + fat = serving.get("fat", "N/A") + + print(f" Serving {i+1}: {serving_desc}") + print(f" Calories: {calories}") + print(f" Protein: {protein}g") + print(f" Carbohydrates: {carbs}g") + print(f" Fat: {fat}g") + except Exception as e: + print(f"❌ Error getting food details: {e}") + + +if __name__ == "__main__": + print("🚀 Starting FatSecret API test...") + asyncio.run(run_tests()) + print("\n✅ Test completed!") \ No newline at end of file diff --git a/test_fatsecret_api_oauth1.py b/test_fatsecret_api_oauth1.py new file mode 100755 index 0000000..35cff1b --- /dev/null +++ b/test_fatsecret_api_oauth1.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 +""" +Скрипт для тестирования API FatSecret с использованием OAuth 1.0 +""" + +import os +import time +import hmac +import base64 +import random +import hashlib +import urllib.parse +import requests +from dotenv import load_dotenv + +# Загружаем .env файл +current_dir = os.path.dirname(os.path.abspath(__file__)) +env_path = os.path.join(current_dir, ".env") +load_dotenv(env_path) +print(f"✅ Loaded .env from: {env_path}") + +# Получаем API ключи из переменных окружения +FATSECRET_KEY = os.environ.get("FATSECRET_CUSTOMER_KEY") or os.environ.get("FATSECRET_CLIENT_ID") +FATSECRET_SECRET = os.environ.get("FATSECRET_CLIENT_SECRET") + +if not FATSECRET_KEY or not FATSECRET_SECRET: + raise ValueError("FatSecret API keys not found in .env file") + +print(f"🔑 Using FatSecret API keys: KEY={FATSECRET_KEY[:8]}...") +print(f"🔑 Using FatSecret SECRET (first few chars): {FATSECRET_SECRET[:5]}...") + + +def generate_oauth_params(http_method, url, params): + """Создание и подписание OAuth 1.0 параметров""" + # Текущее время в секундах + timestamp = str(int(time.time())) + # Случайная строка для nonce + nonce = ''.join([str(random.randint(0, 9)) for _ in range(8)]) + + # Базовый набор параметров OAuth + oauth_params = { + 'oauth_consumer_key': FATSECRET_KEY, + 'oauth_nonce': nonce, + 'oauth_signature_method': 'HMAC-SHA1', + 'oauth_timestamp': 
timestamp, + 'oauth_version': '1.0' + } + + # Объединяем с параметрами запроса + all_params = {**params, **oauth_params} + + # Сортируем параметры по ключу + sorted_params = sorted(all_params.items()) + + # Создаем строку параметров для подписи + param_string = "&".join([f"{urllib.parse.quote(str(k))}={urllib.parse.quote(str(v))}" + for k, v in sorted_params]) + + # Создаем строку для подписи + signature_base = f"{http_method}&{urllib.parse.quote(url, safe='')}&{urllib.parse.quote(param_string, safe='')}" + + # Создаем ключ для подписи + signing_key = f"{urllib.parse.quote(str(FATSECRET_SECRET), safe='')}&" + + # Создаем HMAC-SHA1 подпись + signature = base64.b64encode( + hmac.new( + signing_key.encode(), + signature_base.encode(), + hashlib.sha1 + ).digest() + ).decode() + + # Добавляем подпись к параметрам OAuth + all_params['oauth_signature'] = signature + + return all_params + + +def search_food(query, max_results=5, locale=None): + """Поиск продуктов по названию с использованием OAuth 1.0""" + print(f"\n🔍 Searching for '{query}'{' with locale ' + locale if locale else ''}...") + + # URL для API + url = "https://platform.fatsecret.com/rest/server.api" + + # Параметры запроса + params = { + 'method': 'foods.search', + 'search_expression': query, + 'max_results': max_results, + 'format': 'json' + } + + # Добавляем локаль если указана + if locale: + params['language'] = locale + + # Получаем подписанные OAuth параметры + oauth_params = generate_oauth_params("GET", url, params) + + try: + # Отправляем запрос + response = requests.get(url, params=oauth_params) + + print(f"📥 Response status code: {response.status_code}") + + if response.status_code == 200: + print("✅ Search successful!") + result = response.json() + return result + else: + print(f"❌ Error during search: {response.status_code}") + print(response.text) + return None + except Exception as e: + print(f"❌ Exception during search: {e}") + return None + + +def process_search_results(result): + """Обработка и вывод результатов поиска""" + if not result or "foods" not in result: + print("❌ No valid results found") + return + + foods_data = result["foods"] + + if "food" not in foods_data: + print("⚠️ No food items found") + return + + food_list = foods_data["food"] + if not isinstance(food_list, list): + food_list = [food_list] # Если только один результат, оборачиваем в список + + print(f"📊 Found {len(food_list)} results") + + # Выводим первые 3 результата + for i, food in enumerate(food_list[:3]): + food_name = food.get("food_name", "Unknown") + food_id = food.get("food_id", "Unknown") + food_desc = food.get("food_description", "No description") + + print(f" {i+1}. 
[{food_id}] {food_name}") + print(f" {food_desc}") + + +def main(): + """Основная функция для тестирования API FatSecret""" + print("\n🚀 Starting FatSecret API test with OAuth 1.0...\n") + + # Тестируем поиск продуктов на английском + search_queries = ["PowerAde", "Americano", "Coca-Cola", "chicken breast"] + + for query in search_queries: + result = search_food(query) + if result: + process_search_results(result) + + # Тестируем поиск продуктов на русском + russian_queries = ["Барни", "хлеб", "яблоко"] + + for query in russian_queries: + result = search_food(query, locale="ru_RU") + if result: + process_search_results(result) + + print("\n✅ Test completed!") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/test_fatsecret_api_v2.py b/test_fatsecret_api_v2.py new file mode 100755 index 0000000..5e95b75 --- /dev/null +++ b/test_fatsecret_api_v2.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 +""" +Скрипт для тестирования API FatSecret +Выполняет тестовые запросы к API FatSecret с использованием ключей из конфигурации приложения +""" + +import os +import json +import time +import base64 +import requests +import urllib.parse +from dotenv import load_dotenv + +# Загружаем .env файл +current_dir = os.path.dirname(os.path.abspath(__file__)) +env_path = os.path.join(current_dir, ".env") +load_dotenv(env_path) +print(f"✅ Loaded .env from: {env_path}") + +# Получаем API ключи из переменных окружения +FATSECRET_CLIENT_ID = os.environ.get("FATSECRET_CLIENT_ID") +FATSECRET_CLIENT_SECRET = os.environ.get("FATSECRET_CLIENT_SECRET") + +if not FATSECRET_CLIENT_ID or not FATSECRET_CLIENT_SECRET: + raise ValueError("FatSecret API keys not found in .env file") + +print(f"🔑 Using FatSecret API keys: CLIENT_ID={FATSECRET_CLIENT_ID[:8]}...") +customer_key = os.environ.get("FATSECRET_CUSTOMER_KEY") +if customer_key: + print(f"🔑 Using CUSTOMER_KEY={customer_key[:8]}...") + + +def get_oauth_token(): + """Получение OAuth 2.0 токена для доступа к API""" + print("🔄 Getting OAuth token...") + + # Создаем заголовок авторизации с Base64-кодированными ID и секретом + auth_string = f"{FATSECRET_CLIENT_ID}:{FATSECRET_CLIENT_SECRET}" + auth_header = base64.b64encode(auth_string.encode()).decode() + + # Полный вывод учетных данных для диагностики + print(f"🔑 CLIENT_ID: {FATSECRET_CLIENT_ID}") + if FATSECRET_CLIENT_SECRET: + print(f"🔑 CLIENT_SECRET (first few chars): {FATSECRET_CLIENT_SECRET[:5]}...") + else: + print("⚠️ CLIENT_SECRET is missing!") + print(f"🔑 Authorization header: Basic {auth_header}") + + # Выполняем запрос на получение токена + token_url = "https://oauth.fatsecret.com/connect/token" + headers = { + "Authorization": f"Basic {auth_header}", + "Content-Type": "application/x-www-form-urlencoded" + } + data = { + "grant_type": "client_credentials", + "scope": "basic" + } + + print("📤 Sending request with headers:") + for key, value in headers.items(): + print(f" {key}: {value if key != 'Authorization' else value[:30]}...") + print("📤 Sending request with data:") + for key, value in data.items(): + print(f" {key}: {value}") + + try: + response = requests.post(token_url, headers=headers, data=data) + + # Дополнительная информация о запросе + print(f"📥 Response status code: {response.status_code}") + print(f"📥 Response headers: {dict(response.headers)}") + + # Проверяем успешность запроса + if response.status_code == 200: + token_data = response.json() + access_token = token_data.get("access_token") + expires_in = token_data.get("expires_in") + print(f"✅ Got token, expires in 
{expires_in} seconds") + return access_token + else: + print(f"❌ Error getting token: {response.status_code}") + print(f"❌ Error response: {response.text}") + return None + except Exception as e: + print(f"❌ Exception getting token: {e}") + return None + + +def search_food(token, query, max_results=5): + """Поиск продуктов по названию""" + if not token: + print("⚠️ No token available, cannot search") + return None + + print(f"🔍 Searching for '{query}'...") + + api_url = "https://platform.fatsecret.com/rest/server.api" + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + } + + params = { + "method": "foods.search", + "search_expression": query, + "max_results": max_results, + "format": "json" + } + + try: + response = requests.post(api_url, headers=headers, json=params) + + if response.status_code == 200: + print(f"✅ Search successful") + result = response.json() + return result + else: + print(f"❌ Error searching: {response.status_code}") + print(response.text) + return None + except Exception as e: + print(f"❌ Exception during search: {e}") + return None + + +def get_food_details(token, food_id): + """Получение информации о продукте по ID""" + if not token: + print("⚠️ No token available, cannot get food details") + return None + + print(f"🔍 Getting details for food ID: {food_id}") + + api_url = "https://platform.fatsecret.com/rest/server.api" + headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json" + } + + params = { + "method": "food.get", + "food_id": food_id, + "format": "json" + } + + try: + response = requests.post(api_url, headers=headers, json=params) + + if response.status_code == 200: + print(f"✅ Got food details") + result = response.json() + return result + else: + print(f"❌ Error getting food details: {response.status_code}") + print(response.text) + return None + except Exception as e: + print(f"❌ Exception getting food details: {e}") + return None + + +def main(): + """Основная функция для тестирования API FatSecret""" + print("\n🚀 Starting FatSecret API test...\n") + + # Получаем токен доступа + token = get_oauth_token() + + if not token: + print("❌ Failed to get OAuth token, exiting") + return + + # Тестируем поиск продуктов + print("\n--- 📋 Testing Food Search ---") + search_queries = ["apple", "bread", "chicken breast", "молоко"] + first_food_id = None + + for query in search_queries: + result = search_food(token, query) + + if result and "foods" in result: + foods_data = result["foods"] + + if "food" not in foods_data: + print(f"⚠️ No results found for '{query}'") + continue + + food_list = foods_data["food"] + if not isinstance(food_list, list): + food_list = [food_list] # Если только один результат, оборачиваем в список + + print(f"📊 Found {len(food_list)} results") + + # Выводим первые 3 результата + for i, food in enumerate(food_list[:3]): + food_name = food.get("food_name", "Unknown") + food_id = food.get("food_id", "Unknown") + food_desc = food.get("food_description", "No description") + + print(f" {i+1}. 
[{food_id}] {food_name}") + print(f" {food_desc}") + + # Сохраняем ID первого продукта для следующего теста + if not first_food_id and food_list: + first_food_id = food_list[0].get("food_id") + + # Тестируем получение информации о продукте + if first_food_id: + print("\n--- 🍎 Testing Food Details ---") + food_details = get_food_details(token, first_food_id) + + if food_details and "food" in food_details: + food = food_details["food"] + food_name = food.get("food_name", "Unknown") + brand = food.get("brand_name", "Generic") + + print(f"📝 Details for: {food_name} [{brand}]") + + # Выводим информацию о пищевой ценности + if "servings" in food: + servings = food["servings"] + if "serving" in servings: + serving_data = servings["serving"] + if not isinstance(serving_data, list): + serving_data = [serving_data] + + print("\n📊 Nutrition info per serving:") + for i, serving in enumerate(serving_data[:2]): # Выводим до 2 видов порций + serving_desc = serving.get("serving_description", "Standard") + calories = serving.get("calories", "N/A") + protein = serving.get("protein", "N/A") + carbs = serving.get("carbohydrate", "N/A") + fat = serving.get("fat", "N/A") + + print(f" Serving {i+1}: {serving_desc}") + print(f" Calories: {calories}") + print(f" Protein: {protein}g") + print(f" Carbohydrates: {carbs}g") + print(f" Fat: {fat}g") + + print("\n✅ Test completed!") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/setup_mobile_test.py b/tests/setup_mobile_test.py similarity index 100% rename from setup_mobile_test.py rename to tests/setup_mobile_test.py diff --git a/simple_test.py b/tests/simple_test.py similarity index 100% rename from simple_test.py rename to tests/simple_test.py diff --git a/test_calendar_mobile.py b/tests/test_calendar_mobile.py similarity index 100% rename from test_calendar_mobile.py rename to tests/test_calendar_mobile.py diff --git a/test_debug_endpoint.sh b/tests/test_debug_endpoint.sh similarity index 100% rename from test_debug_endpoint.sh rename to tests/test_debug_endpoint.sh diff --git a/test_mobile_api.py b/tests/test_mobile_api.py similarity index 100% rename from test_mobile_api.py rename to tests/test_mobile_api.py diff --git a/test_mobile_endpoint.py b/tests/test_mobile_endpoint.py similarity index 100% rename from test_mobile_endpoint.py rename to tests/test_mobile_endpoint.py diff --git a/test_mobile_endpoints.py b/tests/test_mobile_endpoints.py similarity index 100% rename from test_mobile_endpoints.py rename to tests/test_mobile_endpoints.py diff --git a/tests/test_nutrition_api.py b/tests/test_nutrition_api.py new file mode 100755 index 0000000..fd2b143 --- /dev/null +++ b/tests/test_nutrition_api.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python3 +""" +Скрипт для тестирования API сервиса питания (Nutrition Service) +""" + +import os +import sys +import json +import requests +from datetime import datetime +from dotenv import load_dotenv + +# Загружаем .env файл +current_dir = os.path.dirname(os.path.abspath(__file__)) +env_path = os.path.join(current_dir, ".env") +load_dotenv(env_path) +print(f"✅ Загружен .env из: {env_path}") + +# Базовый URL API +BASE_URL = os.environ.get("NUTRITION_API_URL", "http://localhost:8006/api/v1/nutrition") +AUTH_URL = os.environ.get("AUTH_API_URL", "http://localhost:8001/api/v1/auth") + +# Настройки для тестовых данных +TEST_USER = { + "username": "test_nutrition_user", + "password": "Test123!", + "email": "test_nutrition@example.com", + "first_name": "Test", + "last_name": "Nutrition", + "phone": 
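# --- Editor's note: illustrative sketch, not part of the diff above. ---
# The food-details test pulls each macro out of the "serving" dict by hand and
# prints it. A small container makes the parsed shape explicit; ServingInfo is a
# hypothetical name, and the field keys follow the ones read in the loop above.
from dataclasses import dataclass

@dataclass
class ServingInfo:
    description: str
    calories: float | None
    protein_g: float | None
    carbs_g: float | None
    fat_g: float | None

    @classmethod
    def from_api(cls, serving: dict) -> "ServingInfo":
        def _num(key):
            # FatSecret returns numbers as strings; treat missing/bad values as None.
            try:
                return float(serving.get(key))
            except (TypeError, ValueError):
                return None
        return cls(
            description=serving.get("serving_description", "Standard"),
            calories=_num("calories"),
            protein_g=_num("protein"),
            carbs_g=_num("carbohydrate"),
            fat_g=_num("fat"),
        )
# --- End of editor's sketch; the diff continues below. ---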
"+79991234999" +} + +def get_auth_token(): + """Получение токена авторизации""" + print("\n🔑 Получаем токен авторизации...") + + # Пытаемся сначала войти + try: + login_data = { + "username": TEST_USER["username"], + "password": TEST_USER["password"] + } + + login_response = requests.post( + f"{AUTH_URL}/login", + json=login_data + ) + + if login_response.status_code == 200: + token = login_response.json().get("access_token") + print("✅ Успешный вход в систему!") + return token + except Exception as e: + print(f"⚠️ Ошибка при попытке входа: {e}") + + # Если вход не удался, пробуем регистрацию + try: + register_response = requests.post( + f"{AUTH_URL}/register", + json=TEST_USER + ) + + if register_response.status_code == 201: + print("✅ Пользователь успешно зарегистрирован!") + + # Теперь входим с новыми учетными данными + login_data = { + "username": TEST_USER["username"], + "password": TEST_USER["password"] + } + + login_response = requests.post( + f"{AUTH_URL}/login", + json=login_data + ) + + if login_response.status_code == 200: + token = login_response.json().get("access_token") + print("✅ Успешный вход в систему!") + return token + except Exception as e: + print(f"❌ Ошибка при регистрации: {e}") + + print("❌ Не удалось получить токен авторизации") + return None + +def search_food(token, query="apple", max_results=5): + """Поиск продуктов питания""" + print(f"\n🔍 Поиск продуктов по запросу '{query}'...") + + headers = {"Authorization": f"Bearer {token}"} + data = { + "query": query, + "max_results": max_results + } + + try: + response = requests.post( + f"{BASE_URL}/search", + json=data, + headers=headers + ) + + print(f"📥 Код ответа: {response.status_code}") + + if response.status_code == 200: + results = response.json() + print(f"✅ Найдено продуктов: {len(results)}") + + # Выводим первые 3 результата + for i, food in enumerate(results[:3]): + print(f" {i+1}. 
[{food.get('id')}] {food.get('name')}") + print(f" {food.get('description')}") + print(f" Калории: {food.get('calories')} ккал/100г") + + return results + else: + print(f"❌ Ошибка при поиске: {response.status_code}") + print(response.text) + return None + except Exception as e: + print(f"❌ Исключение при поиске: {e}") + return None + +def add_diary_entry(token, food_id=1): + """Добавление записи в дневник питания""" + print(f"\n📝 Добавление записи в дневник питания (продукт ID: {food_id})...") + + headers = {"Authorization": f"Bearer {token}"} + today = datetime.now().strftime("%Y-%m-%d") + + data = { + "food_item_id": food_id, + "entry_date": today, + "meal_type": "breakfast", + "quantity": 1.0, + "unit": "piece", + "notes": "Тестовая запись" + } + + try: + response = requests.post( + f"{BASE_URL}/diary", + json=data, + headers=headers + ) + + print(f"📥 Код ответа: {response.status_code}") + + if response.status_code in [200, 201]: + result = response.json() + print("✅ Запись успешно добавлена в дневник питания!") + print(f" ID записи: {result.get('id')}") + print(f" Калории: {result.get('calories')} ккал") + return result + else: + print(f"❌ Ошибка при добавлении записи: {response.status_code}") + print(response.text) + return None + except Exception as e: + print(f"❌ Исключение при добавлении записи: {e}") + return None + +def get_diary_entries(token): + """Получение записей дневника за текущий день""" + print("\n📋 Получение записей дневника питания...") + + headers = {"Authorization": f"Bearer {token}"} + today = datetime.now().strftime("%Y-%m-%d") + + try: + response = requests.get( + f"{BASE_URL}/diary?date={today}", + headers=headers + ) + + print(f"📥 Код ответа: {response.status_code}") + + if response.status_code == 200: + results = response.json() + print(f"✅ Получено записей: {len(results)}") + + # Выводим записи + for i, entry in enumerate(results): + print(f" {i+1}. 
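# --- Editor's note: illustrative sketch, not part of the diff above. ---
# The diary test only prints each entry. Totalling the returned entries locally
# gives a figure that the daily-summary step further down could be compared
# against; sum_diary_calories is a hypothetical helper, not part of the script.
def sum_diary_calories(entries: list[dict]) -> float:
    """Total the 'calories' field over a list of diary entries, treating nulls as 0."""
    return sum(float(entry.get("calories") or 0.0) for entry in entries)

# Example usage:
#   entries = get_diary_entries(token) or []
#   print(f"Local diary total: {sum_diary_calories(entries)} kcal")
# --- End of editor's sketch; the diff continues below. ---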
Прием пищи: {entry.get('meal_type')}") + print(f" Продукт ID: {entry.get('food_item_id')}") + print(f" Калории: {entry.get('calories')} ккал") + + return results + else: + print(f"❌ Ошибка при получении записей: {response.status_code}") + print(response.text) + return None + except Exception as e: + print(f"❌ Исключение при получении записей: {e}") + return None + +def add_water_entry(token, amount_ml=250): + """Добавление записи о потреблении воды""" + print(f"\n💧 Добавление записи о потреблении воды ({amount_ml} мл)...") + + headers = {"Authorization": f"Bearer {token}"} + today = datetime.now().strftime("%Y-%m-%d") + + data = { + "amount_ml": amount_ml, + "entry_date": today, + "notes": "Тестовая запись" + } + + try: + response = requests.post( + f"{BASE_URL}/water", + json=data, + headers=headers + ) + + print(f"📥 Код ответа: {response.status_code}") + + if response.status_code in [200, 201]: + result = response.json() + print("✅ Запись о потреблении воды успешно добавлена!") + print(f" ID записи: {result.get('id')}") + print(f" Объем: {result.get('amount_ml')} мл") + return result + else: + print(f"❌ Ошибка при добавлении записи о воде: {response.status_code}") + print(response.text) + return None + except Exception as e: + print(f"❌ Исключение при добавлении записи о воде: {e}") + return None + +def add_activity_entry(token): + """Добавление записи о физической активности""" + print("\n🏃‍♀️ Добавление записи о физической активности...") + + headers = {"Authorization": f"Bearer {token}"} + today = datetime.now().strftime("%Y-%m-%d") + + data = { + "entry_date": today, + "activity_type": "walking", + "duration_minutes": 30, + "distance_km": 2.5, + "intensity": "medium", + "notes": "Тестовая активность" + } + + try: + response = requests.post( + f"{BASE_URL}/activity", + json=data, + headers=headers + ) + + print(f"📥 Код ответа: {response.status_code}") + + if response.status_code in [200, 201]: + result = response.json() + print("✅ Запись о физической активности успешно добавлена!") + print(f" ID записи: {result.get('id')}") + print(f" Тип: {result.get('activity_type')}") + print(f" Продолжительность: {result.get('duration_minutes')} мин") + print(f" Потрачено калорий: {result.get('calories_burned')} ккал") + return result + else: + print(f"❌ Ошибка при добавлении записи об активности: {response.status_code}") + print(response.text) + return None + except Exception as e: + print(f"❌ Исключение при добавлении записи об активности: {e}") + return None + +def get_daily_summary(token): + """Получение дневной сводки""" + print("\n📊 Получение сводки за день...") + + headers = {"Authorization": f"Bearer {token}"} + today = datetime.now().strftime("%Y-%m-%d") + + try: + response = requests.get( + f"{BASE_URL}/summary?date={today}", + headers=headers + ) + + print(f"📥 Код ответа: {response.status_code}") + + if response.status_code == 200: + result = response.json() + print("✅ Сводка за день успешно получена!") + print(f" Всего калорий: {result.get('total_calories')} ккал") + print(f" Всего белка: {result.get('total_protein')} г") + print(f" Всего жиров: {result.get('total_fat')} г") + print(f" Всего углеводов: {result.get('total_carbs')} г") + print(f" Потреблено воды: {result.get('water_consumed_ml')} мл") + print(f" Активность: {result.get('activity_minutes')} мин") + print(f" Сожжено калорий: {result.get('calories_burned')} ккал") + return result + else: + print(f"❌ Ошибка при получении сводки: {response.status_code}") + print(response.text) + return None + except Exception as e: + 
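# --- Editor's note: illustrative sketch, not part of the diff above. ---
# get_daily_summary() below only prints the aggregates. Comparing total_calories
# with the locally computed diary total (see the sum_diary_calories sketch
# earlier) would turn the script into an assertion-style test; check_summary and
# the 1 kcal tolerance are assumptions, not documented service behaviour.
def check_summary(summary: dict, expected_calories: float, tolerance: float = 1.0) -> bool:
    """Compare the service's total_calories with a locally computed total."""
    total = float(summary.get("total_calories") or 0.0)
    ok = abs(total - expected_calories) <= tolerance
    print(f"{'✅' if ok else '❌'} total_calories={total}, expected≈{expected_calories}")
    return ok
# --- End of editor's sketch; the diff continues below. ---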
print(f"❌ Исключение при получении сводки: {e}") + return None + +def main(): + """Основная функция для тестирования API сервиса питания""" + print("\n🚀 Запуск тестирования API сервиса питания...\n") + + # Получаем токен авторизации + token = get_auth_token() + if not token: + print("❌ Невозможно продолжить тестирование без авторизации") + sys.exit(1) + + # Выполняем поиск продуктов + search_results = search_food(token, "apple") + + if search_results and len(search_results) > 0: + # Используем первый найденный продукт для дальнейшего тестирования + food_id = search_results[0].get("id") + + # Добавляем запись в дневник питания + add_diary_entry(token, food_id) + + # Получаем записи дневника + get_diary_entries(token) + else: + # Если поиск не дал результатов, продолжаем тестирование с предполагаемым ID продукта + print("⚠️ Используем предполагаемый ID продукта для дальнейших тестов") + add_diary_entry(token, 1) + + # Добавляем записи о воде и активности + add_water_entry(token) + add_activity_entry(token) + + # Получаем дневную сводку + get_daily_summary(token) + + print("\n✅ Тестирование API сервиса питания завершено!") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/test_nutrition_service.sh b/tests/test_nutrition_service.sh new file mode 100755 index 0000000..8dcbf5a --- /dev/null +++ b/tests/test_nutrition_service.sh @@ -0,0 +1,189 @@ +#!/bin/bash + +# Скрипт для тестирования API сервиса питания через cURL + +# Настройки +API_BASE_URL="http://localhost:8006/api/v1/nutrition" +AUTH_URL="http://localhost:8001/api/v1/auth" +TODAY=$(date +"%Y-%m-%d") +TEST_USERNAME="test_nutrition_user" +TEST_PASSWORD="Test123!" + +# Цветной вывод +GREEN='\033[0;32m' +RED='\033[0;31m' +BLUE='\033[0;34m' +YELLOW='\033[0;33m' +NC='\033[0m' # No Color + +echo -e "${BLUE}🚀 Запуск тестов для Nutrition Service API${NC}" +echo "---------------------------------------------" + +# Шаг 1: Авторизация и получение токена +echo -e "${BLUE}📝 Шаг 1: Получение токена авторизации${NC}" + +# Попытка входа +login_response=$(curl -s -X POST "${AUTH_URL}/login" \ + -H "Content-Type: application/json" \ + -d '{ + "username": "'"${TEST_USERNAME}"'", + "password": "'"${TEST_PASSWORD}"'" + }') + +# Проверяем, успешен ли вход +if [[ $login_response == *"access_token"* ]]; then + TOKEN=$(echo $login_response | grep -o '"access_token":"[^"]*' | sed 's/"access_token":"//') + echo -e "${GREEN}✅ Вход успешен!${NC}" +else + echo -e "${YELLOW}⚠️ Вход не удался, пробуем регистрацию...${NC}" + + # Пробуем зарегистрировать пользователя + curl -s -X POST "${AUTH_URL}/register" \ + -H "Content-Type: application/json" \ + -d '{ + "email": "'"${TEST_USERNAME}@example.com"'", + "username": "'"${TEST_USERNAME}"'", + "password": "'"${TEST_PASSWORD}"'", + "first_name": "Test", + "last_name": "Nutrition", + "phone": "+79991234999" + }' > /dev/null + + # После регистрации пробуем войти снова + login_response=$(curl -s -X POST "${AUTH_URL}/login" \ + -H "Content-Type: application/json" \ + -d '{ + "username": "'"${TEST_USERNAME}"'", + "password": "'"${TEST_PASSWORD}"'" + }') + + if [[ $login_response == *"access_token"* ]]; then + TOKEN=$(echo $login_response | grep -o '"access_token":"[^"]*' | sed 's/"access_token":"//') + echo -e "${GREEN}✅ Регистрация и вход успешны!${NC}" + else + echo -e "${RED}❌ Не удалось получить токен авторизации${NC}" + echo "Ответ сервера: $login_response" + exit 1 + fi +fi + +# Шаг 2: Поиск продуктов +echo -e "\n${BLUE}📝 Шаг 2: Поиск продуктов${NC}" +search_response=$(curl -s -X POST 
"${API_BASE_URL}/search" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer ${TOKEN}" \ + -d '{ + "query": "apple", + "max_results": 5 + }') + +echo "Результат поиска:" +echo "$search_response" | grep -o '"name":"[^"]*' | head -3 | sed 's/"name":"/- /' +echo "..." + +# Получаем ID первого продукта из результатов поиска +FOOD_ID=$(echo $search_response | grep -o '"id":[0-9]*' | head -1 | sed 's/"id"://') +if [[ -z "$FOOD_ID" ]]; then + echo -e "${YELLOW}⚠️ Не удалось получить ID продукта, используем значение по умолчанию${NC}" + FOOD_ID=1 +else + echo -e "${GREEN}✅ Получен ID продукта: ${FOOD_ID}${NC}" +fi + +# Шаг 3: Добавление записи в дневник питания +echo -e "\n${BLUE}📝 Шаг 3: Добавление записи в дневник питания${NC}" +diary_response=$(curl -s -X POST "${API_BASE_URL}/diary" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer ${TOKEN}" \ + -d '{ + "food_item_id": '"${FOOD_ID}"', + "entry_date": "'"${TODAY}"'", + "meal_type": "breakfast", + "quantity": 1.5, + "unit": "piece", + "notes": "Тестовая запись" + }') + +if [[ $diary_response == *"id"* ]]; then + echo -e "${GREEN}✅ Запись добавлена в дневник питания${NC}" + echo "Детали записи:" + echo "$diary_response" | grep -o '"calories":[0-9.]*' | sed 's/"calories":/Калории: /' +else + echo -e "${RED}❌ Ошибка при добавлении записи в дневник${NC}" + echo "Ответ сервера: $diary_response" +fi + +# Шаг 4: Получение записей дневника +echo -e "\n${BLUE}📝 Шаг 4: Получение записей дневника${NC}" +get_diary_response=$(curl -s -X GET "${API_BASE_URL}/diary?date=${TODAY}" \ + -H "Authorization: Bearer ${TOKEN}") + +if [[ $get_diary_response == *"meal_type"* ]]; then + echo -e "${GREEN}✅ Записи дневника успешно получены${NC}" + echo "Количество записей: $(echo $get_diary_response | grep -o '"meal_type"' | wc -l)" +else + echo -e "${YELLOW}⚠️ Нет записей в дневнике или ошибка получения${NC}" + echo "Ответ сервера: $get_diary_response" +fi + +# Шаг 5: Добавление записи о потреблении воды +echo -e "\n${BLUE}📝 Шаг 5: Добавление записи о потреблении воды${NC}" +water_response=$(curl -s -X POST "${API_BASE_URL}/water" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer ${TOKEN}" \ + -d '{ + "amount_ml": 250, + "entry_date": "'"${TODAY}"'", + "notes": "Утренний стакан воды" + }') + +if [[ $water_response == *"id"* ]]; then + echo -e "${GREEN}✅ Запись о потреблении воды добавлена${NC}" + echo "Детали записи:" + echo "$water_response" | grep -o '"amount_ml":[0-9]*' | sed 's/"amount_ml":/Объем (мл): /' +else + echo -e "${RED}❌ Ошибка при добавлении записи о воде${NC}" + echo "Ответ сервера: $water_response" +fi + +# Шаг 6: Добавление записи о физической активности +echo -e "\n${BLUE}📝 Шаг 6: Добавление записи о физической активности${NC}" +activity_response=$(curl -s -X POST "${API_BASE_URL}/activity" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer ${TOKEN}" \ + -d '{ + "entry_date": "'"${TODAY}"'", + "activity_type": "running", + "duration_minutes": 30, + "distance_km": 5.2, + "intensity": "medium", + "notes": "Утренняя пробежка" + }') + +if [[ $activity_response == *"id"* ]]; then + echo -e "${GREEN}✅ Запись о физической активности добавлена${NC}" + echo "Детали записи:" + echo "$activity_response" | grep -o '"duration_minutes":[0-9]*' | sed 's/"duration_minutes":/Продолжительность (мин): /' + echo "$activity_response" | grep -o '"calories_burned":[0-9.]*' | sed 's/"calories_burned":/Сожжено калорий: /' +else + echo -e "${RED}❌ Ошибка при добавлении записи об активности${NC}" + echo 
"Ответ сервера: $activity_response" +fi + +# Шаг 7: Получение сводки за день +echo -e "\n${BLUE}📝 Шаг 7: Получение сводки за день${NC}" +summary_response=$(curl -s -X GET "${API_BASE_URL}/summary?date=${TODAY}" \ + -H "Authorization: Bearer ${TOKEN}") + +if [[ $summary_response == *"total_calories"* ]]; then + echo -e "${GREEN}✅ Дневная сводка успешно получена${NC}" + echo "Детали сводки:" + echo "$summary_response" | grep -o '"total_calories":[0-9.]*' | sed 's/"total_calories":/Всего калорий: /' + echo "$summary_response" | grep -o '"water_consumed_ml":[0-9]*' | sed 's/"water_consumed_ml":/Потреблено воды (мл): /' + echo "$summary_response" | grep -o '"activity_minutes":[0-9]*' | sed 's/"activity_minutes":/Минуты активности: /' +else + echo -e "${YELLOW}⚠️ Нет данных для сводки или ошибка получения${NC}" + echo "Ответ сервера: $summary_response" +fi + +echo -e "\n${GREEN}✅ Тестирование API сервиса питания завершено!${NC}" \ No newline at end of file diff --git a/test_standalone.py b/tests/test_standalone.py similarity index 100% rename from test_standalone.py rename to tests/test_standalone.py diff --git a/venv/bin/alembic b/venv/bin/alembic index 826ade7..6882a08 100755 --- a/venv/bin/alembic +++ b/venv/bin/alembic @@ -1,4 +1,4 @@ -#!/home/trevor/dev/chat/venv/bin/python +#!/home/trevor/dev/chat/venv/bin/python3.12 # -*- coding: utf-8 -*- import re import sys diff --git a/venv/bin/celery b/venv/bin/celery index 604c4b1..5280d04 100755 --- a/venv/bin/celery +++ b/venv/bin/celery @@ -1,4 +1,4 @@ -#!/home/trevor/dev/chat/venv/bin/python +#!/home/trevor/dev/chat/venv/bin/python3.12 # -*- coding: utf-8 -*- import re import sys diff --git a/venv/bin/dotenv b/venv/bin/dotenv index 5afe262..a4d2ce6 100755 --- a/venv/bin/dotenv +++ b/venv/bin/dotenv @@ -1,4 +1,4 @@ -#!/home/trevor/dev/chat/venv/bin/python +#!/home/trevor/dev/chat/venv/bin/python3.12 # -*- coding: utf-8 -*- import re import sys diff --git a/venv/bin/fastapi b/venv/bin/fastapi deleted file mode 100755 index 0e19d96..0000000 --- a/venv/bin/fastapi +++ /dev/null @@ -1,8 +0,0 @@ -#!/home/trevor/dev/chat/venv/bin/python -# -*- coding: utf-8 -*- -import re -import sys -from fastapi.cli import main -if __name__ == '__main__': - sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) - sys.exit(main()) diff --git a/venv/bin/uvicorn b/venv/bin/uvicorn index 9c6469f..e1a62c9 100755 --- a/venv/bin/uvicorn +++ b/venv/bin/uvicorn @@ -1,4 +1,4 @@ -#!/home/trevor/dev/chat/venv/bin/python +#!/home/trevor/dev/chat/venv/bin/python3.12 # -*- coding: utf-8 -*- import re import sys diff --git a/venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/RECORD b/venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/RECORD deleted file mode 100644 index 84334ea..0000000 --- a/venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/RECORD +++ /dev/null @@ -1,33 +0,0 @@ -PyJWT-2.10.1.dist-info/AUTHORS.rst,sha256=klzkNGECnu2_VY7At89_xLBF3vUSDruXk3xwgUBxzwc,322 -PyJWT-2.10.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -PyJWT-2.10.1.dist-info/LICENSE,sha256=eXp6ICMdTEM-nxkR2xcx0GtYKLmPSZgZoDT3wPVvXOU,1085 -PyJWT-2.10.1.dist-info/METADATA,sha256=EkewF6D6KU8SGaaQzVYfxUUU1P_gs_dp1pYTkoYvAx8,3990 -PyJWT-2.10.1.dist-info/RECORD,, -PyJWT-2.10.1.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -PyJWT-2.10.1.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91 -PyJWT-2.10.1.dist-info/top_level.txt,sha256=RP5DHNyJbMq2ka0FmfTgoSaQzh7e3r5XuCWCO8a00k8,4 
-jwt/__init__.py,sha256=VB2vFKuboTjcDGeZ8r-UqK_dz3NsQSQEqySSICby8Xg,1711 -jwt/__pycache__/__init__.cpython-312.pyc,, -jwt/__pycache__/algorithms.cpython-312.pyc,, -jwt/__pycache__/api_jwk.cpython-312.pyc,, -jwt/__pycache__/api_jws.cpython-312.pyc,, -jwt/__pycache__/api_jwt.cpython-312.pyc,, -jwt/__pycache__/exceptions.cpython-312.pyc,, -jwt/__pycache__/help.cpython-312.pyc,, -jwt/__pycache__/jwk_set_cache.cpython-312.pyc,, -jwt/__pycache__/jwks_client.cpython-312.pyc,, -jwt/__pycache__/types.cpython-312.pyc,, -jwt/__pycache__/utils.cpython-312.pyc,, -jwt/__pycache__/warnings.cpython-312.pyc,, -jwt/algorithms.py,sha256=cKr-XEioe0mBtqJMCaHEswqVOA1Z8Purt5Sb3Bi-5BE,30409 -jwt/api_jwk.py,sha256=6F1r7rmm8V5qEnBKA_xMjS9R7VoANe1_BL1oD2FrAjE,4451 -jwt/api_jws.py,sha256=aM8vzqQf6mRrAw7bRy-Moj_pjWsKSVQyYK896AfMjJU,11762 -jwt/api_jwt.py,sha256=OGT4hok1l5A6FH_KdcrU5g6u6EQ8B7em0r9kGM9SYgA,14512 -jwt/exceptions.py,sha256=bUIOJ-v9tjopTLS-FYOTc3kFx5WP5IZt7ksN_HE1G9Q,1211 -jwt/help.py,sha256=vFdNzjQoAch04XCMYpCkyB2blaqHAGAqQrtf9nSPkdk,1808 -jwt/jwk_set_cache.py,sha256=hBKmN-giU7-G37L_XKgc_OZu2ah4wdbj1ZNG_GkoSE8,959 -jwt/jwks_client.py,sha256=p9b-IbQqo2tEge9Zit3oSPBFNePqwho96VLbnUrHUWs,4259 -jwt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -jwt/types.py,sha256=VnhGv_VFu5a7_mrPoSCB7HaNLrJdhM8Sq1sSfEg0gLU,99 -jwt/utils.py,sha256=hxOjvDBheBYhz-RIPiEz7Q88dSUSTMzEdKE_Ww2VdJw,3640 -jwt/warnings.py,sha256=50XWOnyNsIaqzUJTk6XHNiIDykiL763GYA92MjTKmok,59 diff --git a/venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/AUTHORS.rst b/venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/AUTHORS.rst similarity index 100% rename from venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/AUTHORS.rst rename to venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/AUTHORS.rst diff --git a/venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/INSTALLER diff --git a/venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/LICENSE b/venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/LICENSE similarity index 100% rename from venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/LICENSE rename to venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/LICENSE diff --git a/venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/METADATA b/venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/METADATA similarity index 67% rename from venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/METADATA rename to venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/METADATA index f31b700..b329a46 100644 --- a/venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/METADATA +++ b/venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/METADATA @@ -1,45 +1,47 @@ Metadata-Version: 2.1 Name: PyJWT -Version: 2.10.1 +Version: 2.8.0 Summary: JSON Web Token implementation in Python -Author-email: Jose Padilla +Home-page: https://github.com/jpadilla/pyjwt +Author: Jose Padilla +Author-email: hello@jpadilla.com License: MIT -Project-URL: Homepage, https://github.com/jpadilla/pyjwt Keywords: json,jwt,security,signing,token,web Classifier: Development Status :: 5 - Production/Stable Classifier: Intended Audience :: Developers -Classifier: License :: OSI Approved :: MIT License Classifier: Natural Language :: English +Classifier: License :: OSI Approved :: MIT 
License Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 3 Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Classifier: Programming Language :: Python :: 3.11 -Classifier: Programming Language :: Python :: 3.12 -Classifier: Programming Language :: Python :: 3.13 Classifier: Topic :: Utilities -Requires-Python: >=3.9 +Requires-Python: >=3.7 Description-Content-Type: text/x-rst License-File: LICENSE License-File: AUTHORS.rst +Requires-Dist: typing-extensions ; python_version <= "3.7" Provides-Extra: crypto -Requires-Dist: cryptography>=3.4.0; extra == "crypto" +Requires-Dist: cryptography (>=3.4.0) ; extra == 'crypto' Provides-Extra: dev -Requires-Dist: coverage[toml]==5.0.4; extra == "dev" -Requires-Dist: cryptography>=3.4.0; extra == "dev" -Requires-Dist: pre-commit; extra == "dev" -Requires-Dist: pytest<7.0.0,>=6.0.0; extra == "dev" -Requires-Dist: sphinx; extra == "dev" -Requires-Dist: sphinx-rtd-theme; extra == "dev" -Requires-Dist: zope.interface; extra == "dev" +Requires-Dist: sphinx (<5.0.0,>=4.5.0) ; extra == 'dev' +Requires-Dist: sphinx-rtd-theme ; extra == 'dev' +Requires-Dist: zope.interface ; extra == 'dev' +Requires-Dist: cryptography (>=3.4.0) ; extra == 'dev' +Requires-Dist: pytest (<7.0.0,>=6.0.0) ; extra == 'dev' +Requires-Dist: coverage[toml] (==5.0.4) ; extra == 'dev' +Requires-Dist: pre-commit ; extra == 'dev' Provides-Extra: docs -Requires-Dist: sphinx; extra == "docs" -Requires-Dist: sphinx-rtd-theme; extra == "docs" -Requires-Dist: zope.interface; extra == "docs" +Requires-Dist: sphinx (<5.0.0,>=4.5.0) ; extra == 'docs' +Requires-Dist: sphinx-rtd-theme ; extra == 'docs' +Requires-Dist: zope.interface ; extra == 'docs' Provides-Extra: tests -Requires-Dist: coverage[toml]==5.0.4; extra == "tests" -Requires-Dist: pytest<7.0.0,>=6.0.0; extra == "tests" +Requires-Dist: pytest (<7.0.0,>=6.0.0) ; extra == 'tests' +Requires-Dist: coverage[toml] (==5.0.4) ; extra == 'tests' PyJWT ===== @@ -61,12 +63,11 @@ A Python implementation of `RFC 7519 `_. Or Sponsor ------- -.. |auth0-logo| image:: https://github.com/user-attachments/assets/ee98379e-ee76-4bcb-943a-e25c4ea6d174 - :width: 160px ++--------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| |auth0-logo| | If you want to quickly add secure token-based authentication to Python projects, feel free to check Auth0's Python SDK and free plan at `auth0.com/developers `_. 
| ++--------------+-----------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -+--------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ -| |auth0-logo| | If you want to quickly add secure token-based authentication to Python projects, feel free to check Auth0's Python SDK and free plan at `auth0.com/signup `_. | -+--------------+-----------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +.. |auth0-logo| image:: https://user-images.githubusercontent.com/83319/31722733-de95bbde-b3ea-11e7-96bf-4f4e8f915588.png Installing ---------- diff --git a/venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/RECORD b/venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/RECORD new file mode 100644 index 0000000..d77ecb7 --- /dev/null +++ b/venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/RECORD @@ -0,0 +1,33 @@ +PyJWT-2.8.0.dist-info/AUTHORS.rst,sha256=klzkNGECnu2_VY7At89_xLBF3vUSDruXk3xwgUBxzwc,322 +PyJWT-2.8.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +PyJWT-2.8.0.dist-info/LICENSE,sha256=eXp6ICMdTEM-nxkR2xcx0GtYKLmPSZgZoDT3wPVvXOU,1085 +PyJWT-2.8.0.dist-info/METADATA,sha256=pV2XZjvithGcVesLHWAv0J4T5t8Qc66fip2sbxwoz1o,4160 +PyJWT-2.8.0.dist-info/RECORD,, +PyJWT-2.8.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +PyJWT-2.8.0.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92 +PyJWT-2.8.0.dist-info/top_level.txt,sha256=RP5DHNyJbMq2ka0FmfTgoSaQzh7e3r5XuCWCO8a00k8,4 +jwt/__init__.py,sha256=mV9lg6n4-0xiqCKaE1eEPC9a4j6sEkEYQcKghULE7kU,1670 +jwt/__pycache__/__init__.cpython-312.pyc,, +jwt/__pycache__/algorithms.cpython-312.pyc,, +jwt/__pycache__/api_jwk.cpython-312.pyc,, +jwt/__pycache__/api_jws.cpython-312.pyc,, +jwt/__pycache__/api_jwt.cpython-312.pyc,, +jwt/__pycache__/exceptions.cpython-312.pyc,, +jwt/__pycache__/help.cpython-312.pyc,, +jwt/__pycache__/jwk_set_cache.cpython-312.pyc,, +jwt/__pycache__/jwks_client.cpython-312.pyc,, +jwt/__pycache__/types.cpython-312.pyc,, +jwt/__pycache__/utils.cpython-312.pyc,, +jwt/__pycache__/warnings.cpython-312.pyc,, +jwt/algorithms.py,sha256=RDsv5Lm3bzwsiWT3TynT7JR41R6H6s_fWUGOIqd9x_I,29800 +jwt/api_jwk.py,sha256=HPxVqgBZm7RTaEXydciNBCuYNKDYOC_prTdaN9toGbo,4196 +jwt/api_jws.py,sha256=da17RrDe0PDccTbx3rx2lLezEG_c_YGw_vVHa335IOk,11099 +jwt/api_jwt.py,sha256=yF9DwF1kt3PA5n_TiU0OmHd0LtPHfe4JCE1XOfKPjw0,12638 +jwt/exceptions.py,sha256=KDC3M7cTrpR4OQXVURlVMThem0pfANSgBxRz-ttivmo,1046 +jwt/help.py,sha256=Jrp84fG43sCwmSIaDtY08I6ZR2VE7NhrTff89tYSE40,1749 +jwt/jwk_set_cache.py,sha256=hBKmN-giU7-G37L_XKgc_OZu2ah4wdbj1ZNG_GkoSE8,959 +jwt/jwks_client.py,sha256=9W8JVyGByQgoLbBN1u5iY1_jlgfnnukeOBTpqaM_9SE,4222 +jwt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +jwt/types.py,sha256=VnhGv_VFu5a7_mrPoSCB7HaNLrJdhM8Sq1sSfEg0gLU,99 +jwt/utils.py,sha256=PAI05_8MHQCxWQTDlwN0hTtTIT2DTTZ28mm1x6-26UY,3903 +jwt/warnings.py,sha256=50XWOnyNsIaqzUJTk6XHNiIDykiL763GYA92MjTKmok,59 
diff --git a/venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/REQUESTED b/venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/REQUESTED similarity index 100% rename from venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/REQUESTED rename to venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/REQUESTED diff --git a/venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/WHEEL b/venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/WHEEL similarity index 65% rename from venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/WHEEL rename to venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/WHEEL index e7fa31b..1f37c02 100644 --- a/venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/WHEEL +++ b/venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/WHEEL @@ -1,5 +1,5 @@ Wheel-Version: 1.0 -Generator: setuptools (80.9.0) +Generator: bdist_wheel (0.40.0) Root-Is-Purelib: true Tag: py3-none-any diff --git a/venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/top_level.txt b/venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/top_level.txt similarity index 100% rename from venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/top_level.txt rename to venv/lib/python3.12/site-packages/PyJWT-2.8.0.dist-info/top_level.txt diff --git a/venv/lib/python3.12/site-packages/aiofiles-24.1.0.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/aiofiles-24.1.0.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/INSTALLER diff --git a/venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/licenses/LICENSE b/venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/LICENSE similarity index 94% rename from venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/licenses/LICENSE rename to venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/LICENSE index dfe1a4d..7bf9bbe 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/licenses/LICENSE +++ b/venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/LICENSE @@ -1,4 +1,4 @@ -Copyright 2005-2025 SQLAlchemy authors and contributors . +Copyright 2005-2023 SQLAlchemy authors and contributors . 
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in diff --git a/venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/METADATA b/venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/METADATA similarity index 73% rename from venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/METADATA rename to venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/METADATA index d34d362..9f2808f 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/METADATA +++ b/venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/METADATA @@ -1,6 +1,6 @@ -Metadata-Version: 2.4 +Metadata-Version: 2.1 Name: SQLAlchemy -Version: 2.0.43 +Version: 2.0.23 Summary: Database Abstraction Library Home-page: https://www.sqlalchemy.org Author: Mike Bayer @@ -10,6 +10,7 @@ Project-URL: Documentation, https://docs.sqlalchemy.org Project-URL: Issue Tracker, https://github.com/sqlalchemy/sqlalchemy/ Classifier: Development Status :: 5 - Production/Stable Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: MIT License Classifier: Operating System :: OS Independent Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 3 @@ -18,70 +19,67 @@ Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Classifier: Programming Language :: Python :: 3.11 -Classifier: Programming Language :: Python :: 3.12 -Classifier: Programming Language :: Python :: 3.13 Classifier: Programming Language :: Python :: Implementation :: CPython Classifier: Programming Language :: Python :: Implementation :: PyPy Classifier: Topic :: Database :: Front-Ends Requires-Python: >=3.7 Description-Content-Type: text/x-rst License-File: LICENSE -Requires-Dist: importlib-metadata; python_version < "3.8" -Requires-Dist: greenlet>=1; python_version < "3.14" and (platform_machine == "aarch64" or (platform_machine == "ppc64le" or (platform_machine == "x86_64" or (platform_machine == "amd64" or (platform_machine == "AMD64" or (platform_machine == "win32" or platform_machine == "WIN32")))))) -Requires-Dist: typing-extensions>=4.6.0 -Provides-Extra: asyncio -Requires-Dist: greenlet>=1; extra == "asyncio" -Provides-Extra: mypy -Requires-Dist: mypy>=0.910; extra == "mypy" -Provides-Extra: mssql -Requires-Dist: pyodbc; extra == "mssql" -Provides-Extra: mssql-pymssql -Requires-Dist: pymssql; extra == "mssql-pymssql" -Provides-Extra: mssql-pyodbc -Requires-Dist: pyodbc; extra == "mssql-pyodbc" -Provides-Extra: mysql -Requires-Dist: mysqlclient>=1.4.0; extra == "mysql" -Provides-Extra: mysql-connector -Requires-Dist: mysql-connector-python; extra == "mysql-connector" -Provides-Extra: mariadb-connector -Requires-Dist: mariadb!=1.1.10,!=1.1.2,!=1.1.5,>=1.0.1; extra == "mariadb-connector" -Provides-Extra: oracle -Requires-Dist: cx_oracle>=8; extra == "oracle" -Provides-Extra: oracle-oracledb -Requires-Dist: oracledb>=1.0.1; extra == "oracle-oracledb" -Provides-Extra: postgresql -Requires-Dist: psycopg2>=2.7; extra == "postgresql" -Provides-Extra: postgresql-pg8000 -Requires-Dist: pg8000>=1.29.1; extra == "postgresql-pg8000" -Provides-Extra: postgresql-asyncpg -Requires-Dist: greenlet>=1; extra == "postgresql-asyncpg" -Requires-Dist: asyncpg; extra == "postgresql-asyncpg" -Provides-Extra: postgresql-psycopg2binary -Requires-Dist: psycopg2-binary; extra == 
"postgresql-psycopg2binary" -Provides-Extra: postgresql-psycopg2cffi -Requires-Dist: psycopg2cffi; extra == "postgresql-psycopg2cffi" -Provides-Extra: postgresql-psycopg -Requires-Dist: psycopg>=3.0.7; extra == "postgresql-psycopg" -Provides-Extra: postgresql-psycopgbinary -Requires-Dist: psycopg[binary]>=3.0.7; extra == "postgresql-psycopgbinary" -Provides-Extra: pymysql -Requires-Dist: pymysql; extra == "pymysql" +Requires-Dist: typing-extensions >=4.2.0 +Requires-Dist: greenlet !=0.4.17 ; platform_machine == "aarch64" or (platform_machine == "ppc64le" or (platform_machine == "x86_64" or (platform_machine == "amd64" or (platform_machine == "AMD64" or (platform_machine == "win32" or platform_machine == "WIN32"))))) +Requires-Dist: importlib-metadata ; python_version < "3.8" Provides-Extra: aiomysql -Requires-Dist: greenlet>=1; extra == "aiomysql" -Requires-Dist: aiomysql>=0.2.0; extra == "aiomysql" +Requires-Dist: greenlet !=0.4.17 ; extra == 'aiomysql' +Requires-Dist: aiomysql >=0.2.0 ; extra == 'aiomysql' Provides-Extra: aioodbc -Requires-Dist: greenlet>=1; extra == "aioodbc" -Requires-Dist: aioodbc; extra == "aioodbc" -Provides-Extra: asyncmy -Requires-Dist: greenlet>=1; extra == "asyncmy" -Requires-Dist: asyncmy!=0.2.4,!=0.2.6,>=0.2.3; extra == "asyncmy" +Requires-Dist: greenlet !=0.4.17 ; extra == 'aioodbc' +Requires-Dist: aioodbc ; extra == 'aioodbc' Provides-Extra: aiosqlite -Requires-Dist: greenlet>=1; extra == "aiosqlite" -Requires-Dist: aiosqlite; extra == "aiosqlite" -Requires-Dist: typing_extensions!=3.10.0.1; extra == "aiosqlite" +Requires-Dist: greenlet !=0.4.17 ; extra == 'aiosqlite' +Requires-Dist: aiosqlite ; extra == 'aiosqlite' +Requires-Dist: typing-extensions !=3.10.0.1 ; extra == 'aiosqlite' +Provides-Extra: asyncio +Requires-Dist: greenlet !=0.4.17 ; extra == 'asyncio' +Provides-Extra: asyncmy +Requires-Dist: greenlet !=0.4.17 ; extra == 'asyncmy' +Requires-Dist: asyncmy !=0.2.4,!=0.2.6,>=0.2.3 ; extra == 'asyncmy' +Provides-Extra: mariadb_connector +Requires-Dist: mariadb !=1.1.2,!=1.1.5,>=1.0.1 ; extra == 'mariadb_connector' +Provides-Extra: mssql +Requires-Dist: pyodbc ; extra == 'mssql' +Provides-Extra: mssql_pymssql +Requires-Dist: pymssql ; extra == 'mssql_pymssql' +Provides-Extra: mssql_pyodbc +Requires-Dist: pyodbc ; extra == 'mssql_pyodbc' +Provides-Extra: mypy +Requires-Dist: mypy >=0.910 ; extra == 'mypy' +Provides-Extra: mysql +Requires-Dist: mysqlclient >=1.4.0 ; extra == 'mysql' +Provides-Extra: mysql_connector +Requires-Dist: mysql-connector-python ; extra == 'mysql_connector' +Provides-Extra: oracle +Requires-Dist: cx-oracle >=8 ; extra == 'oracle' +Provides-Extra: oracle_oracledb +Requires-Dist: oracledb >=1.0.1 ; extra == 'oracle_oracledb' +Provides-Extra: postgresql +Requires-Dist: psycopg2 >=2.7 ; extra == 'postgresql' +Provides-Extra: postgresql_asyncpg +Requires-Dist: greenlet !=0.4.17 ; extra == 'postgresql_asyncpg' +Requires-Dist: asyncpg ; extra == 'postgresql_asyncpg' +Provides-Extra: postgresql_pg8000 +Requires-Dist: pg8000 >=1.29.1 ; extra == 'postgresql_pg8000' +Provides-Extra: postgresql_psycopg +Requires-Dist: psycopg >=3.0.7 ; extra == 'postgresql_psycopg' +Provides-Extra: postgresql_psycopg2binary +Requires-Dist: psycopg2-binary ; extra == 'postgresql_psycopg2binary' +Provides-Extra: postgresql_psycopg2cffi +Requires-Dist: psycopg2cffi ; extra == 'postgresql_psycopg2cffi' +Provides-Extra: postgresql_psycopgbinary +Requires-Dist: psycopg[binary] >=3.0.7 ; extra == 'postgresql_psycopgbinary' +Provides-Extra: pymysql +Requires-Dist: 
pymysql ; extra == 'pymysql' Provides-Extra: sqlcipher -Requires-Dist: sqlcipher3_binary; extra == "sqlcipher" -Dynamic: license-file +Requires-Dist: sqlcipher3-binary ; extra == 'sqlcipher' SQLAlchemy ========== diff --git a/venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/RECORD b/venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/RECORD new file mode 100644 index 0000000..f8e50d8 --- /dev/null +++ b/venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/RECORD @@ -0,0 +1,530 @@ +SQLAlchemy-2.0.23.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +SQLAlchemy-2.0.23.dist-info/LICENSE,sha256=2lSTeluT1aC-5eJXO8vhkzf93qCSeV_mFXLrv3tNdIU,1100 +SQLAlchemy-2.0.23.dist-info/METADATA,sha256=znDChLueFNPCOPuNix-FfY7FG6aQOCM-lQwwN-cPLQs,9551 +SQLAlchemy-2.0.23.dist-info/RECORD,, +SQLAlchemy-2.0.23.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +SQLAlchemy-2.0.23.dist-info/WHEEL,sha256=JmQLNqDEfvnYMfsIaVeSP3fmUcYDwmF12m3QYW0c7QQ,152 +SQLAlchemy-2.0.23.dist-info/top_level.txt,sha256=rp-ZgB7D8G11ivXON5VGPjupT1voYmWqkciDt5Uaw_Q,11 +sqlalchemy/__init__.py,sha256=DjKCAltzrHGfaVdXVeFJpBmTaX6JmyloHANzewBUWo4,12708 +sqlalchemy/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/__pycache__/events.cpython-312.pyc,, +sqlalchemy/__pycache__/exc.cpython-312.pyc,, +sqlalchemy/__pycache__/inspection.cpython-312.pyc,, +sqlalchemy/__pycache__/log.cpython-312.pyc,, +sqlalchemy/__pycache__/schema.cpython-312.pyc,, +sqlalchemy/__pycache__/types.cpython-312.pyc,, +sqlalchemy/connectors/__init__.py,sha256=uKUYWQoXyleIyjWBuh7gzgnazJokx3DaasKJbFOfQGA,476 +sqlalchemy/connectors/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/connectors/__pycache__/aioodbc.cpython-312.pyc,, +sqlalchemy/connectors/__pycache__/asyncio.cpython-312.pyc,, +sqlalchemy/connectors/__pycache__/pyodbc.cpython-312.pyc,, +sqlalchemy/connectors/aioodbc.py,sha256=QiafuN9bx_wcIs8tByLftTmGAegXPoFPwUaxCDU_ZQA,5737 +sqlalchemy/connectors/asyncio.py,sha256=ZZmJSFT50u-GEjZzytQOdB_tkBFxi3XPWRrNhs_nASc,6139 +sqlalchemy/connectors/pyodbc.py,sha256=NskMydn26ZkHL8aQ1V3L4WIAWin3zwJ5VEnlHvAD1DE,8453 +sqlalchemy/cyextension/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sqlalchemy/cyextension/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/cyextension/collections.cpython-312-x86_64-linux-gnu.so,sha256=qPSMnyXVSLYHMr_ot_ZK7yEYadhTuT8ryb6eTMFFWrM,1947440 +sqlalchemy/cyextension/collections.pyx,sha256=KDI5QTOyYz9gDl-3d7MbGMA0Kc-wxpJqnLmCaUmQy2U,12323 +sqlalchemy/cyextension/immutabledict.cpython-312-x86_64-linux-gnu.so,sha256=J9m0gK6R8PGR36jxAKx415VxX0-0fqvbQAP9-DDU1qA,811232 +sqlalchemy/cyextension/immutabledict.pxd,sha256=oc8BbnQwDg7pWAdThB-fzu8s9_ViOe1Ds-8T0r0POjI,41 +sqlalchemy/cyextension/immutabledict.pyx,sha256=aQJPZKjcqbO8jHDqpC9F-v-ew2qAjUscc5CntaheZUk,3285 +sqlalchemy/cyextension/processors.cpython-312-x86_64-linux-gnu.so,sha256=WOLcEWRNXn4UtJGhzF5B1h7JpPPfn-ziQMT0lkhobQE,533968 +sqlalchemy/cyextension/processors.pyx,sha256=0swFIBdR19x1kPRe-dijBaLW898AhH6QJizbv4ho9pk,1545 +sqlalchemy/cyextension/resultproxy.cpython-312-x86_64-linux-gnu.so,sha256=bte73oURZXuV7YvkjyGo-OjRCnSgYukqDp5KM9-Z8xY,626112 +sqlalchemy/cyextension/resultproxy.pyx,sha256=cDtMjLTdC47g7cME369NSOCck3JwG2jwZ6j25no3_gw,2477 +sqlalchemy/cyextension/util.cpython-312-x86_64-linux-gnu.so,sha256=8yMbb069NQN1b6yAsCBCMpbX94sH4iLs61vPNxd0bOg,958760 +sqlalchemy/cyextension/util.pyx,sha256=lv03p63oVn23jLhMI4_RYGewUnJfh-4FkrNMEFL7A3Y,2289 
+sqlalchemy/dialects/__init__.py,sha256=hLsgIEomunlp4mNLnvjCQTLOnBVva8N7IT2-RYrN2_4,1770 +sqlalchemy/dialects/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/dialects/__pycache__/_typing.cpython-312.pyc,, +sqlalchemy/dialects/_typing.py,sha256=P2ML2o4b_bWAAy3zbdoUjx3vXsMNwpiOblef8ThCxlM,648 +sqlalchemy/dialects/mssql/__init__.py,sha256=CYbbydyMSLjUq8vY1siNStd4lvjVXod8ddeDS6ELHLk,1871 +sqlalchemy/dialects/mssql/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/dialects/mssql/__pycache__/aioodbc.cpython-312.pyc,, +sqlalchemy/dialects/mssql/__pycache__/base.cpython-312.pyc,, +sqlalchemy/dialects/mssql/__pycache__/information_schema.cpython-312.pyc,, +sqlalchemy/dialects/mssql/__pycache__/json.cpython-312.pyc,, +sqlalchemy/dialects/mssql/__pycache__/provision.cpython-312.pyc,, +sqlalchemy/dialects/mssql/__pycache__/pymssql.cpython-312.pyc,, +sqlalchemy/dialects/mssql/__pycache__/pyodbc.cpython-312.pyc,, +sqlalchemy/dialects/mssql/aioodbc.py,sha256=ncj3yyfvW91o3g19GB5s1I0oaZKUO_P-R2nwnLF0t9E,2013 +sqlalchemy/dialects/mssql/base.py,sha256=l9vX6fK6DJEYA00N9uDnvSbqfgvxXfYUn2C4AF5T920,133649 +sqlalchemy/dialects/mssql/information_schema.py,sha256=ll0zAupJ4cPvhi9v5hTi7PQLU1lae4o6eQ5Vg7gykXQ,8074 +sqlalchemy/dialects/mssql/json.py,sha256=B0m6H08CKuk-yomDHcCwfQbVuVN2WLufuVueA_qb1NQ,4573 +sqlalchemy/dialects/mssql/provision.py,sha256=x7XRSQDxz4jz2uIpqwhuIXpL9bic0Vw7Mhy39HOkyqY,5013 +sqlalchemy/dialects/mssql/pymssql.py,sha256=BfJp9t-IQabqWXySJBmP9pwNTWnJqbjA2jJM9M4XeWc,4029 +sqlalchemy/dialects/mssql/pyodbc.py,sha256=qwZ8ByOTZ1WObjxeOravoJBSBX-s4RJ_PZ5VJ_Ch5Ws,27048 +sqlalchemy/dialects/mysql/__init__.py,sha256=btLABiNnmbWt9ziW-XgVWEB1qHWQcSFz7zxZNw4m_LY,2144 +sqlalchemy/dialects/mysql/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/aiomysql.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/asyncmy.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/base.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/cymysql.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/dml.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/enumerated.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/expression.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/json.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/mariadb.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/mariadbconnector.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/mysqlconnector.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/mysqldb.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/provision.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/pymysql.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/pyodbc.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/reflection.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/reserved_words.cpython-312.pyc,, +sqlalchemy/dialects/mysql/__pycache__/types.cpython-312.pyc,, +sqlalchemy/dialects/mysql/aiomysql.py,sha256=Zb-_F9Pzl0t-fT1bZwbNNne6jjCUqBXxeizbhMFPqls,9750 +sqlalchemy/dialects/mysql/asyncmy.py,sha256=zqupDz7AJihjv3E8w_4XAtq95d8stdrETNx60MLNVr0,9819 +sqlalchemy/dialects/mysql/base.py,sha256=q-DzkR_txwDTeWTEByzHAoIArYU3Bb5HT2Bnmuw7WIM,120688 +sqlalchemy/dialects/mysql/cymysql.py,sha256=5CQVJAlqQ3pT4IDGSQJH2hCzj-EWjUitA21MLqJwEEs,2291 +sqlalchemy/dialects/mysql/dml.py,sha256=qw0ZweHbMsbNyVSfC17HqylCnf7XAuIjtgofiWABT8k,7636 +sqlalchemy/dialects/mysql/enumerated.py,sha256=1L2J2wT6nQEmRS4z-jzZpoi44IqIaHgBRZZB9m55czo,8439 
+sqlalchemy/dialects/mysql/expression.py,sha256=WW5G2XPwqJfXjuzHBt4BRP0pCLcPJkPD1mvZX1g0JL0,4066 +sqlalchemy/dialects/mysql/json.py,sha256=JlSFBAHhJ9JmV-3azH80xkLgeh7g6A6DVyNVCNZiKPU,2260 +sqlalchemy/dialects/mysql/mariadb.py,sha256=Sugyngvo6j6SfFFuJ23rYeFWEPdZ9Ji9guElsk_1WSQ,844 +sqlalchemy/dialects/mysql/mariadbconnector.py,sha256=F1VPosecC1hDZqjzZI29j4GUduyU4ewPwb-ekBQva5w,8725 +sqlalchemy/dialects/mysql/mysqlconnector.py,sha256=5glmkPhD_KP-Mci8ZXBr4yzqH1MDfzCJ9F_kZNyXcGo,5666 +sqlalchemy/dialects/mysql/mysqldb.py,sha256=R5BDiXiHX5oFuAOzyxZ6TYUTGzly-dulMeQLkeia6kk,9649 +sqlalchemy/dialects/mysql/provision.py,sha256=uPT6-BIoP_12XLmWAza1TDFNhOVVJ3rmQoMH7nvh-Vg,3226 +sqlalchemy/dialects/mysql/pymysql.py,sha256=d2-00IPoyEDkR9REQTE-DGEQrGshUq_0G5liZ5FiSEM,4032 +sqlalchemy/dialects/mysql/pyodbc.py,sha256=mkOvumrxpmAi6noZlkaTVKz2F7G5vLh2vx0cZSn9VTA,4288 +sqlalchemy/dialects/mysql/reflection.py,sha256=ak6E-eCP9346ixnILYNJcrRYblWbIT0sjXf4EqmfBsY,22556 +sqlalchemy/dialects/mysql/reserved_words.py,sha256=DsPHsW3vwOrvU7bv3Nbfact2Z_jyZ9xUTT-mdeQvqxo,9145 +sqlalchemy/dialects/mysql/types.py,sha256=i8DpRkOL1QhPErZ25AmCQOuFLciWhdjNL3I0CeHEhdY,24258 +sqlalchemy/dialects/oracle/__init__.py,sha256=pjk1aWi9XFCAHWNSJzSzmoIcL32-AkU_1J9IV4PtwpA,1318 +sqlalchemy/dialects/oracle/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/dialects/oracle/__pycache__/base.cpython-312.pyc,, +sqlalchemy/dialects/oracle/__pycache__/cx_oracle.cpython-312.pyc,, +sqlalchemy/dialects/oracle/__pycache__/dictionary.cpython-312.pyc,, +sqlalchemy/dialects/oracle/__pycache__/oracledb.cpython-312.pyc,, +sqlalchemy/dialects/oracle/__pycache__/provision.cpython-312.pyc,, +sqlalchemy/dialects/oracle/__pycache__/types.cpython-312.pyc,, +sqlalchemy/dialects/oracle/base.py,sha256=u55_R9NrCRijud7ioHMxT-r0MSW0gMFjOwbrDdPgFsc,118036 +sqlalchemy/dialects/oracle/cx_oracle.py,sha256=L0GvcB6xb0-zyv5dx3bpQCeptp0KSqH6g9FUQ4y-d-g,55108 +sqlalchemy/dialects/oracle/dictionary.py,sha256=iUoyFEFM8z0sfVWR2n_nnre14kaQkV_syKO0R5Dos4M,19487 +sqlalchemy/dialects/oracle/oracledb.py,sha256=_-fUQ94xai80B7v9WLVGoGDIv8u54nVspBdyGEyI76g,3457 +sqlalchemy/dialects/oracle/provision.py,sha256=5cvIc3yTWxz4AIRYxcesbRJ1Ft-zT9GauQ911yPnN2o,8055 +sqlalchemy/dialects/oracle/types.py,sha256=TeOhUW5W9qZC8SaJ-9b3u6OvOPOarNq4MmCQ7l3wWX0,8204 +sqlalchemy/dialects/postgresql/__init__.py,sha256=bZEPsLbRtB7s6TMQAHCIzKBgkxUa3eDXvCkeARua37E,3734 +sqlalchemy/dialects/postgresql/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/__pycache__/_psycopg_common.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/__pycache__/array.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/__pycache__/asyncpg.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/__pycache__/base.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/__pycache__/dml.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/__pycache__/ext.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/__pycache__/hstore.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/__pycache__/json.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/__pycache__/named_types.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/__pycache__/operators.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/__pycache__/pg8000.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/__pycache__/pg_catalog.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/__pycache__/provision.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/__pycache__/psycopg.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/__pycache__/psycopg2.cpython-312.pyc,, 
+sqlalchemy/dialects/postgresql/__pycache__/psycopg2cffi.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/__pycache__/ranges.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/__pycache__/types.cpython-312.pyc,, +sqlalchemy/dialects/postgresql/_psycopg_common.py,sha256=U3aWzbKD3VOj6Z6r-4IsIQmtjGGIB4RDZH6NXfd8Xz0,5655 +sqlalchemy/dialects/postgresql/array.py,sha256=tLyU9GDAeIypNhjTuFQUYbaTeijVM1VVJS6UdzzXXn4,13682 +sqlalchemy/dialects/postgresql/asyncpg.py,sha256=XNaoOZ5Da4-jUTaES1zEOTEW3WG8UKyVCoIS3LsFhzE,39967 +sqlalchemy/dialects/postgresql/base.py,sha256=DGhaquFJWDQL7wIvQ2EE57LxD7zGR06BKQxvNZHFLgY,175634 +sqlalchemy/dialects/postgresql/dml.py,sha256=_He69efdpDA5gGmBsE7Lo4ViSi3QnR38BiFmrR1tw6k,11203 +sqlalchemy/dialects/postgresql/ext.py,sha256=oPP22Pq-n2lMmQ8ahifYmsmzRhSiSv1RV-xrTT0gycw,16253 +sqlalchemy/dialects/postgresql/hstore.py,sha256=q5x0npbAMI8cdRFGTMwLoWFj9P1G9DUkw5OEUCfTXpI,11532 +sqlalchemy/dialects/postgresql/json.py,sha256=panGtnEbcirQDy4yR2huWydFqa_Kmv8xhpLyf-SSRWE,11203 +sqlalchemy/dialects/postgresql/named_types.py,sha256=zNoHsP3nVq5xxA7SOQ6LLDwYZEHFciZ-nDjw_I9f_G0,17092 +sqlalchemy/dialects/postgresql/operators.py,sha256=MB40xq1124OnhUzkvtbnTmxEiey0VxMOYyznF96wwhI,2799 +sqlalchemy/dialects/postgresql/pg8000.py,sha256=w6pJ3LaIKWmnwvB0Pr1aTJX5OKNtG5RNClVfkE019vU,18620 +sqlalchemy/dialects/postgresql/pg_catalog.py,sha256=0lLnIgxfCrqkx_LNijMxo0trNLsodcd8KwretZIj4uM,8875 +sqlalchemy/dialects/postgresql/provision.py,sha256=oxyAzs8_PhuK0ChivXC3l2Nldih3_HKffvGsZqD8XWI,5509 +sqlalchemy/dialects/postgresql/psycopg.py,sha256=YMubzQHMYN1By8QJScIPb_PwNiACv6srddQ6nX6WltQ,22238 +sqlalchemy/dialects/postgresql/psycopg2.py,sha256=3Xci4bTA2BvhrZAQa727uFWdaXEZmvfD-Z-upE3NyQE,31592 +sqlalchemy/dialects/postgresql/psycopg2cffi.py,sha256=2EOuDwBetfvelcPoTzSwOHe6X8lTwaYH7znNzXJt9eM,1739 +sqlalchemy/dialects/postgresql/ranges.py,sha256=yHB1BRlUreQPZB3VEn0KMMLf02zjf5jjYdmg4N4S2Sw,30220 +sqlalchemy/dialects/postgresql/types.py,sha256=l24rs8_nK4vqLyQC0aUkf4S7ecw6T_7Pgq50Icc5CBs,7292 +sqlalchemy/dialects/sqlite/__init__.py,sha256=wnZ9vtfm0QXmth1jiGiubFgRiKxIoQoNthb1bp4FhCs,1173 +sqlalchemy/dialects/sqlite/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/dialects/sqlite/__pycache__/aiosqlite.cpython-312.pyc,, +sqlalchemy/dialects/sqlite/__pycache__/base.cpython-312.pyc,, +sqlalchemy/dialects/sqlite/__pycache__/dml.cpython-312.pyc,, +sqlalchemy/dialects/sqlite/__pycache__/json.cpython-312.pyc,, +sqlalchemy/dialects/sqlite/__pycache__/provision.cpython-312.pyc,, +sqlalchemy/dialects/sqlite/__pycache__/pysqlcipher.cpython-312.pyc,, +sqlalchemy/dialects/sqlite/__pycache__/pysqlite.cpython-312.pyc,, +sqlalchemy/dialects/sqlite/aiosqlite.py,sha256=GZJioZLot0D5CQ6ovPQoqv2iV8FAFm3G75lEFCzopoE,12296 +sqlalchemy/dialects/sqlite/base.py,sha256=YYEB5BeuemLC3FAR7EB8vA0zoUOwHTKoF_srvnAStps,96785 +sqlalchemy/dialects/sqlite/dml.py,sha256=PYESBj8Ip7bGs_Fi7QjbWLXLnU9a-SbP96JZiUoZNHg,8434 +sqlalchemy/dialects/sqlite/json.py,sha256=XFPwSdNx0DxDfxDZn7rmGGqsAgL4vpJbjjGaA73WruQ,2533 +sqlalchemy/dialects/sqlite/provision.py,sha256=O4JDoybdb2RBblXErEVPE2P_5xHab927BQItJa203zU,5383 +sqlalchemy/dialects/sqlite/pysqlcipher.py,sha256=_JuOCoic--ehAGkCgnwUUKKTs6xYoBGag4Y_WkQUDwU,5347 +sqlalchemy/dialects/sqlite/pysqlite.py,sha256=xBg6DKqvml5cCGxVSAQxR1dcMvso8q4uyXs2m4WLzz0,27891 +sqlalchemy/dialects/type_migration_guidelines.txt,sha256=-uHNdmYFGB7bzUNT6i8M5nb4j6j9YUKAtW4lcBZqsMg,8239 +sqlalchemy/engine/__init__.py,sha256=fJCAl5P7JH9iwjuWo72_3LOIzWWhTnvXqzpAmm_T0fY,2818 +sqlalchemy/engine/__pycache__/__init__.cpython-312.pyc,, 
+sqlalchemy/engine/__pycache__/_py_processors.cpython-312.pyc,, +sqlalchemy/engine/__pycache__/_py_row.cpython-312.pyc,, +sqlalchemy/engine/__pycache__/_py_util.cpython-312.pyc,, +sqlalchemy/engine/__pycache__/base.cpython-312.pyc,, +sqlalchemy/engine/__pycache__/characteristics.cpython-312.pyc,, +sqlalchemy/engine/__pycache__/create.cpython-312.pyc,, +sqlalchemy/engine/__pycache__/cursor.cpython-312.pyc,, +sqlalchemy/engine/__pycache__/default.cpython-312.pyc,, +sqlalchemy/engine/__pycache__/events.cpython-312.pyc,, +sqlalchemy/engine/__pycache__/interfaces.cpython-312.pyc,, +sqlalchemy/engine/__pycache__/mock.cpython-312.pyc,, +sqlalchemy/engine/__pycache__/processors.cpython-312.pyc,, +sqlalchemy/engine/__pycache__/reflection.cpython-312.pyc,, +sqlalchemy/engine/__pycache__/result.cpython-312.pyc,, +sqlalchemy/engine/__pycache__/row.cpython-312.pyc,, +sqlalchemy/engine/__pycache__/strategies.cpython-312.pyc,, +sqlalchemy/engine/__pycache__/url.cpython-312.pyc,, +sqlalchemy/engine/__pycache__/util.cpython-312.pyc,, +sqlalchemy/engine/_py_processors.py,sha256=RSVKm9YppSBDSCEi8xvbZdRCP9EsCYfbyEg9iDCMCiI,3744 +sqlalchemy/engine/_py_row.py,sha256=Zdta0JGa7V2aV04L7nzXUEp-H1gpresKyBlneQu60pk,3549 +sqlalchemy/engine/_py_util.py,sha256=5m3MZbEqnUwP5kK_ghisFpzcXgBwSxTSkBEFB6afiD8,2245 +sqlalchemy/engine/base.py,sha256=RbIfWZ1Otyb4VzMYjDpK5BiDIE8QZwa4vQgRX0yCa28,122246 +sqlalchemy/engine/characteristics.py,sha256=YvMgrUVAt3wsSiQ0K8l44yBjFlMK3MGajxhg50t5yFM,2344 +sqlalchemy/engine/create.py,sha256=8372TLpy4FOAIZ9WmuNkx1v9DPgwpoCAH9P7LNXZCwY,32629 +sqlalchemy/engine/cursor.py,sha256=6e1Tp63r0Kt-P4pEaYR7wUew2aClTdKAEI-FoAAxJxE,74405 +sqlalchemy/engine/default.py,sha256=bi--ytxYJ0EtsCudl38owGtytnwTHX-PjlsYTFe8LpA,84065 +sqlalchemy/engine/events.py,sha256=PQyc_sbmqks6pqyN7xitO658KdKzzJWfW1TKYwEd5vo,37392 +sqlalchemy/engine/interfaces.py,sha256=pAFYR15f1Z_-qdzTYI4mAm8IYbD6maLBKbG3pBaJ8Us,112824 +sqlalchemy/engine/mock.py,sha256=ki4ud7YrUrzP2katdkxlJGFUKB2kS7cZZAHK5xWsNF8,4179 +sqlalchemy/engine/processors.py,sha256=ENN6XwndxJPW-aXPu_3NzAZsy5SvNznHoa1Qn29ERAw,2383 +sqlalchemy/engine/reflection.py,sha256=2aakNheQJNMUXZbhY8s1NtqGoGWTxM2THkJlMMfiX_s,75125 +sqlalchemy/engine/result.py,sha256=shRAsboHPTvKR38ryGgC4KLcUeVTbABSlWzAfOUKVZs,77841 +sqlalchemy/engine/row.py,sha256=doiXKaUI6s6OkfqPIwNyTPLllxJfR8HYgEI8ve9VYe0,11955 +sqlalchemy/engine/strategies.py,sha256=HjCj_FHQOgkkhhtnVmcOEuHI_cftNo3P0hN5zkhZvDc,442 +sqlalchemy/engine/url.py,sha256=_WNE7ia0JIPRc1PLY_jSA3F7bB5kp1gzuzkc5eoKviA,30694 +sqlalchemy/engine/util.py,sha256=3-ENI9S-3KLWr0GW27uWQfsvCJwMBGTKbykkKPUgiAE,5667 +sqlalchemy/event/__init__.py,sha256=CSBMp0yu5joTC6tWvx40B4p87N7oGKxC-ZLx2ULKUnQ,997 +sqlalchemy/event/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/event/__pycache__/api.cpython-312.pyc,, +sqlalchemy/event/__pycache__/attr.cpython-312.pyc,, +sqlalchemy/event/__pycache__/base.cpython-312.pyc,, +sqlalchemy/event/__pycache__/legacy.cpython-312.pyc,, +sqlalchemy/event/__pycache__/registry.cpython-312.pyc,, +sqlalchemy/event/api.py,sha256=nQAvPK1jrLpmu8aKCUtc-vYWcIuG-1FgAtp3GRkfIiI,8227 +sqlalchemy/event/attr.py,sha256=NMe_sPQTju2PE-f68C8TcKJGW-Gxyi1CLXumAmE368Y,20438 +sqlalchemy/event/base.py,sha256=Cr_PNJlCYJSU3rtT8DkplyjBRb-E2Wa3OAeK9woFJkk,14980 +sqlalchemy/event/legacy.py,sha256=OpPqE64xk1OYjLW1scvc6iijhoa5GZJt5f7-beWhgOc,8211 +sqlalchemy/event/registry.py,sha256=Zig9q2Galo8kO2aqr7a2rNAhmIkdJ-ntHSEcM5MfSgw,10833 +sqlalchemy/events.py,sha256=pRcPKKsPQHGPH_pvTtKRmzuEIy-QHCtkUiZl4MUbxKs,536 
+sqlalchemy/exc.py,sha256=4SMKOJtz7_SWt5vskCSeXSi4ZlFyL4jh53Q8sk4-ODQ,24011 +sqlalchemy/ext/__init__.py,sha256=w4h7EpXjKPr0LD4yHa0pDCfrvleU3rrX7mgyb8RuDYQ,322 +sqlalchemy/ext/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/ext/__pycache__/associationproxy.cpython-312.pyc,, +sqlalchemy/ext/__pycache__/automap.cpython-312.pyc,, +sqlalchemy/ext/__pycache__/baked.cpython-312.pyc,, +sqlalchemy/ext/__pycache__/compiler.cpython-312.pyc,, +sqlalchemy/ext/__pycache__/horizontal_shard.cpython-312.pyc,, +sqlalchemy/ext/__pycache__/hybrid.cpython-312.pyc,, +sqlalchemy/ext/__pycache__/indexable.cpython-312.pyc,, +sqlalchemy/ext/__pycache__/instrumentation.cpython-312.pyc,, +sqlalchemy/ext/__pycache__/mutable.cpython-312.pyc,, +sqlalchemy/ext/__pycache__/orderinglist.cpython-312.pyc,, +sqlalchemy/ext/__pycache__/serializer.cpython-312.pyc,, +sqlalchemy/ext/associationproxy.py,sha256=5voNXWIJYGt6c8mwuSA6alm3SmEHOZ-CVK8ikgfzk8s,65960 +sqlalchemy/ext/asyncio/__init__.py,sha256=iG_0TmBO1pCB316WS-p17AImwqRtUoaKo7UphYZ7bYw,1317 +sqlalchemy/ext/asyncio/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/ext/asyncio/__pycache__/base.cpython-312.pyc,, +sqlalchemy/ext/asyncio/__pycache__/engine.cpython-312.pyc,, +sqlalchemy/ext/asyncio/__pycache__/exc.cpython-312.pyc,, +sqlalchemy/ext/asyncio/__pycache__/result.cpython-312.pyc,, +sqlalchemy/ext/asyncio/__pycache__/scoping.cpython-312.pyc,, +sqlalchemy/ext/asyncio/__pycache__/session.cpython-312.pyc,, +sqlalchemy/ext/asyncio/base.py,sha256=PXF4YqfRi2-mADAtaL2_-Uv7CzoBVojPbzyA5phJ9To,8959 +sqlalchemy/ext/asyncio/engine.py,sha256=h4pe3ixuX6YfI97B5QWo2V4_CCCnOvM_EHPZhX19Mgc,47796 +sqlalchemy/ext/asyncio/exc.py,sha256=1hCdOKzvSryc_YE4jgj0l9JASOmZXutdzShEYPiLbGI,639 +sqlalchemy/ext/asyncio/result.py,sha256=zETerVB53gql1DL6tkO_JiqeU-m1OM-8kX0ULxmoL_I,30554 +sqlalchemy/ext/asyncio/scoping.py,sha256=cBNluB7n_lwdAAo6pySbvNRqPN7UBzwQHZ6XhRDyWgA,52685 +sqlalchemy/ext/asyncio/session.py,sha256=yWwhI5i_yVWjykxmxkcP3-xmw3UpoGYNhHZL8sYXQMA,62998 +sqlalchemy/ext/automap.py,sha256=7p13-VpN0MOM525r7pmEnftedya9l5G-Ei_cFXZfpTc,61431 +sqlalchemy/ext/baked.py,sha256=R8ZAxiVN6eH50AJu0O3TtFXNE1tnRkMlSj3AvkcWFhY,17818 +sqlalchemy/ext/compiler.py,sha256=h7eR0NcPJ4F_k8YGRP3R9YX75Y9pgiVxoCjRyvceF7g,20391 +sqlalchemy/ext/declarative/__init__.py,sha256=VJu8S1efxil20W48fJlpDn6gHorOudn5p3-lF72WcJ8,1818 +sqlalchemy/ext/declarative/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/ext/declarative/__pycache__/extensions.cpython-312.pyc,, +sqlalchemy/ext/declarative/extensions.py,sha256=vwZjudPFA_mao1U04-RZCaU_tvPMBgQa5OTmSI7K7SU,19547 +sqlalchemy/ext/horizontal_shard.py,sha256=eh14W8QWHYH22PL1l5qF_ad9Fyh1WAFjKi_vNfsme94,16766 +sqlalchemy/ext/hybrid.py,sha256=98D72WBmlileYBtEKMSNF9l-bwRavThSV8-LyB2gjo0,52499 +sqlalchemy/ext/indexable.py,sha256=RkG9BKwil-TqDjVBM14ML9c-geUrHxtRKpYkSJEwGHA,11028 +sqlalchemy/ext/instrumentation.py,sha256=rjjSbTGilYeGLdyEWV932TfTaGxiVP44_RajinANk54,15723 +sqlalchemy/ext/mutable.py,sha256=d3Pp8PcAVN4pHN9rhc1ReXBWe0Q70Q5S1klFoYGyDPA,37393 +sqlalchemy/ext/mypy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sqlalchemy/ext/mypy/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/ext/mypy/__pycache__/apply.cpython-312.pyc,, +sqlalchemy/ext/mypy/__pycache__/decl_class.cpython-312.pyc,, +sqlalchemy/ext/mypy/__pycache__/infer.cpython-312.pyc,, +sqlalchemy/ext/mypy/__pycache__/names.cpython-312.pyc,, +sqlalchemy/ext/mypy/__pycache__/plugin.cpython-312.pyc,, +sqlalchemy/ext/mypy/__pycache__/util.cpython-312.pyc,, 
+sqlalchemy/ext/mypy/apply.py,sha256=uUES4grydYtKykLKlxzJeBXeGe8kfWou9_rzEyEkfp0,10503 +sqlalchemy/ext/mypy/decl_class.py,sha256=Ls2Efh4kEhle6Z4VMz0GRBgGQTYs2fHr5b4DfuDj44c,17377 +sqlalchemy/ext/mypy/infer.py,sha256=si720RW6iGxMRZNP5tcaIxA1_ehFp215TzxVXaLjglU,19364 +sqlalchemy/ext/mypy/names.py,sha256=tch4f5fDmdv4AWWFzXgGZdCpxmae59XRPT02KyMvrEI,10625 +sqlalchemy/ext/mypy/plugin.py,sha256=fLXDukvZqbJ0JJCOoyZAuOniYZ_F1YT-l9gKppu8SEs,9750 +sqlalchemy/ext/mypy/util.py,sha256=TlEQq4bcs8ARLL3PoFS8Qw6oYFeMqcGnWTeJ7NsPPFk,9408 +sqlalchemy/ext/orderinglist.py,sha256=8Vcg7UUkLg-QbYAbLVDSqu-5REkR6L-FLLhCYsHYxCQ,14384 +sqlalchemy/ext/serializer.py,sha256=ox6dbMOBmFR0H2RQFt17mcYBOGKgn1cNVFfqY8-jpgQ,6178 +sqlalchemy/future/__init__.py,sha256=79DZx3v7TQZpkS_qThlmuCOm1a9UK2ObNZhyMmjfNB0,516 +sqlalchemy/future/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/future/__pycache__/engine.cpython-312.pyc,, +sqlalchemy/future/engine.py,sha256=6uOpOedIqiT1-3qJSJIlv9_raMJU8NTkhQwN_Ngg8kI,499 +sqlalchemy/inspection.py,sha256=i3aR-IV101YU8D9TA8Pxb2wi08QZuJ34sMy6L5M__rY,5145 +sqlalchemy/log.py,sha256=aSlZ8DFHkOuI-AMmaOUUYtS9zGPadi_7tAo98QpUOiY,8634 +sqlalchemy/orm/__init__.py,sha256=cBn0aPWyDFY4ya-cHRshQBcuThk1smTUCTrlp6LHdlE,8463 +sqlalchemy/orm/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/_orm_constructors.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/_typing.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/attributes.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/base.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/bulk_persistence.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/clsregistry.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/collections.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/context.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/decl_api.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/decl_base.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/dependency.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/descriptor_props.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/dynamic.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/evaluator.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/events.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/exc.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/identity.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/instrumentation.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/interfaces.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/loading.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/mapped_collection.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/mapper.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/path_registry.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/persistence.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/properties.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/query.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/relationships.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/scoping.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/session.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/state.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/state_changes.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/strategies.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/strategy_options.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/sync.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/unitofwork.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/util.cpython-312.pyc,, +sqlalchemy/orm/__pycache__/writeonly.cpython-312.pyc,, +sqlalchemy/orm/_orm_constructors.py,sha256=_7_GY6qw2sA-GG_WXLz1GOO-0qC-SCBeA43GhVuS2Qw,99803 +sqlalchemy/orm/_typing.py,sha256=oRUJVAGpU3_DhSkIb1anXgneweVIARjB51HlPhMNfcM,5015 
+sqlalchemy/orm/attributes.py,sha256=NFhYheqqu2VcXmKTdcvQKiRR_6qo0rHLK7nda7rpviA,92578 +sqlalchemy/orm/base.py,sha256=iZXsygk4fn8wd7wx1iXn_PfnGDY7d41YRfS0mC_q5vE,27700 +sqlalchemy/orm/bulk_persistence.py,sha256=S9VK5a6GSqnw3z7O5UG5OOnc9WxzmS_ooDkA5JmCIsY,69878 +sqlalchemy/orm/clsregistry.py,sha256=4J-kKshmLOEyx3VBqREm2k_XY0cer4zwUoHJT3n5Xmw,17949 +sqlalchemy/orm/collections.py,sha256=0AZFr9us9MiHo_Xcyi7DUsN02jSBERUOd-jIK8qQ1DA,52159 +sqlalchemy/orm/context.py,sha256=VyJl1ZJ5OnJUACKlM-bPLyyoqu4tyaKKdxeC-QF4EuU,111698 +sqlalchemy/orm/decl_api.py,sha256=a2Cyvjh6j5BlXJQ2i0jpQx7xkeI_6xo5MMxr0d2ndQY,63589 +sqlalchemy/orm/decl_base.py,sha256=g9xW9G-n9iStMI0i3i-9Rt4LDRW8--3iCCRPlWF6Cko,81660 +sqlalchemy/orm/dependency.py,sha256=g3R_1H_OGzagXFeen3Irm3c1lO3yeXGdGa0muUZgZAk,47583 +sqlalchemy/orm/descriptor_props.py,sha256=SdrfVu05zhWLGe_DnBlgbU6e5sWkkfBTirH9Nrr1MLk,37176 +sqlalchemy/orm/dynamic.py,sha256=pYlMIrpp80Ex4KByqdyhx0x0kIrl_cIADwkeVxvYu4s,9798 +sqlalchemy/orm/evaluator.py,sha256=jPjVrP7XbVOG6aXTCBREq0rF3oNHLqB4XAT-gt_cpaA,11925 +sqlalchemy/orm/events.py,sha256=fGnUHwDTV9FTiifB2mmIJispwPbIT4mZongRJD7uiw4,127258 +sqlalchemy/orm/exc.py,sha256=A3wvZVs5sC5XCef4LoTUBG-UfhmliFpU9rYMdS2t_To,7356 +sqlalchemy/orm/identity.py,sha256=gRiuQSrurHGEAJXH9QGYioXL49Im5EGcYQ-IKUEpHmQ,9249 +sqlalchemy/orm/instrumentation.py,sha256=o1mTv5gCgl9d-SRvEXXjl8rzl8uBasRL3bpDgWg9P58,24337 +sqlalchemy/orm/interfaces.py,sha256=RW7bBXGWtZHY2wXFOSqtvYm6UDl7yHZUyRX_6Yd3GfQ,48395 +sqlalchemy/orm/loading.py,sha256=F1ZEHTPBglmznST2nGj_0ARccoFgTyaOOwjcqpYeuvM,57366 +sqlalchemy/orm/mapped_collection.py,sha256=ZgYHaF37yo6-gZ7Da1Gg25rMgG2GynAy-RJoDhljV5g,19698 +sqlalchemy/orm/mapper.py,sha256=kyq4pBkTvvEqlW4H4XK_ktP1sOiALNAycgvF5f-xtqw,170969 +sqlalchemy/orm/path_registry.py,sha256=olyutgn0uNB7Wi32YNQx9ZHV6sUgV3TbyGplfSxfZ6g,25938 +sqlalchemy/orm/persistence.py,sha256=qr1jUgo-NZ0tLa5eIis2271QDt4KNJwYlYU_9CaKNhQ,60545 +sqlalchemy/orm/properties.py,sha256=dt1Gy06pbRY6zgm4QGR9nU6z2WCyoTZWBJYKpUhLq_c,29095 +sqlalchemy/orm/query.py,sha256=VBSD0k15xU_XykggvLGAwGdwNglBAoBKbOk8qAoMKdI,117714 +sqlalchemy/orm/relationships.py,sha256=wrHyICb8A5qPoyxf-nITQVJ13kCNr2MedDqEY8QMSt8,127816 +sqlalchemy/orm/scoping.py,sha256=75iPEWDFhPcIXgl8EUd_sPTCL6punfegEaTRE5mP3e8,78835 +sqlalchemy/orm/session.py,sha256=TeBcZNdY4HWQFdXNCIqbsQTtkvfJkBweMzvA9p3BiPA,193279 +sqlalchemy/orm/state.py,sha256=EaWkVNWHaDeJ_FZGXHakSamUk51BXmtMWLGdFhlJmh8,37536 +sqlalchemy/orm/state_changes.py,sha256=pqkjSDOR6H5BufMKdzFUIatDp3DY90SovOJiJ1k6Ayw,6815 +sqlalchemy/orm/strategies.py,sha256=V0o-1kB1IVTxhOGqGtRyjddZqAbPdsl_h-k0N3MKCGo,114052 +sqlalchemy/orm/strategy_options.py,sha256=EmgH28uMQhwwBCDVcXmywLk_Q8AbpnK02seMsMV4nmc,84102 +sqlalchemy/orm/sync.py,sha256=5Nt_OqP4IfhAtHwFRar4dw-YjLENRLvp4d3jDC4wpnw,5749 +sqlalchemy/orm/unitofwork.py,sha256=Wk5YZocBbxe4m1wU2aFQ7gY1Cp5CROi13kDEM1iOSz4,27033 +sqlalchemy/orm/util.py,sha256=7hCRYbQjqhWJTkrPf_NXY9zF_18VWTpyguu-nfYfc6c,80340 +sqlalchemy/orm/writeonly.py,sha256=WCPXCAwHqVCfhVWXQEFCP3OocIiHgqNJ5KnuJwSgGq4,22329 +sqlalchemy/pool/__init__.py,sha256=CIv4b6ctueY7w3sML_LxyLKAdl59esYOhz3O7W5w7WE,1815 +sqlalchemy/pool/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/pool/__pycache__/base.cpython-312.pyc,, +sqlalchemy/pool/__pycache__/events.cpython-312.pyc,, +sqlalchemy/pool/__pycache__/impl.cpython-312.pyc,, +sqlalchemy/pool/base.py,sha256=wuwKIak5d_4-TqKI2RFN8OYMEyOvV0djnoSVR8gbxAQ,52249 +sqlalchemy/pool/events.py,sha256=IcWfORKbHM69Z9FdPJlXI7-NIhQrR9O_lg59tiUdTRU,13148 
+sqlalchemy/pool/impl.py,sha256=vU0n82a7uxdE34p3hU7cvUDA5QDy9MkIv1COT4kYFP8,17724 +sqlalchemy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sqlalchemy/schema.py,sha256=mt74CGCBtfv_qI1_6zzNFMexYGyWDj2Jkh-XdH4kEWI,3194 +sqlalchemy/sql/__init__.py,sha256=jAQx9rwhyPhoSjntM1BZSElJiMRmLowGThJVDGvExSU,5820 +sqlalchemy/sql/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/_dml_constructors.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/_elements_constructors.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/_orm_types.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/_py_util.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/_selectable_constructors.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/_typing.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/annotation.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/base.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/cache_key.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/coercions.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/compiler.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/crud.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/ddl.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/default_comparator.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/dml.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/elements.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/events.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/expression.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/functions.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/lambdas.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/naming.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/operators.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/roles.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/schema.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/selectable.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/sqltypes.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/traversals.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/type_api.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/util.cpython-312.pyc,, +sqlalchemy/sql/__pycache__/visitors.cpython-312.pyc,, +sqlalchemy/sql/_dml_constructors.py,sha256=hoNyINY3FNi1ZQajR6lbcRN7oYsNghM1wuzzVWxIv3c,3867 +sqlalchemy/sql/_elements_constructors.py,sha256=-qksx59Gqhmzxo1xByPtZZboNvL8uYcCN14pjHYHxL8,62914 +sqlalchemy/sql/_orm_types.py,sha256=_vR3_HQYgZR_of6_ZpTQByie2gaVScxQjVAVWAP3Ztg,620 +sqlalchemy/sql/_py_util.py,sha256=iiwgX3dQhOjdB5-10jtgHPIdibUqGk49bC1qdZMBpYI,2173 +sqlalchemy/sql/_selectable_constructors.py,sha256=RDqgejqiUuU12Be1jBpMIx_YdJho8fhKfnMoJLPFTFE,18812 +sqlalchemy/sql/_typing.py,sha256=C8kNZQ3TIpM-Q12Of3tTaESB1UxIfRME_lXouqgwMT8,12252 +sqlalchemy/sql/annotation.py,sha256=pTNidcQatCar6H1I9YAoPP1e6sOewaJ15B7_-7ykZOE,18271 +sqlalchemy/sql/base.py,sha256=dVvZoPoa3pb6iuwTU4QoCvVWQPyHZthaekl5J2zV_SU,73928 +sqlalchemy/sql/cache_key.py,sha256=Dl163qHjTkMCa5LTipZud8X3w0d8DvdIvGvv4AqriHE,32823 +sqlalchemy/sql/coercions.py,sha256=ju8xEi7b9G_GzxaQ6Nwu0cFIWFZ--ottIVfdiuhHY7Y,40553 +sqlalchemy/sql/compiler.py,sha256=9Wx423H72Yq7NHR8cmMAH6GpMCJmghs1L85YJqs_Lng,268763 +sqlalchemy/sql/crud.py,sha256=nyAPlmvuyWxMqSBdWPffC5P3CGXTQKK0bJoDbNgB3iQ,56457 +sqlalchemy/sql/ddl.py,sha256=XuUhulJLvvPjU4nYD6N42QLg8rEgquD6Jwn_yIHZejk,45542 +sqlalchemy/sql/default_comparator.py,sha256=SE0OaK1BlY0RinQ21ZXJOUGkO00oGv6GMMmAH-4iNTQ,16663 +sqlalchemy/sql/dml.py,sha256=eftbzdFJgMk7NV0BHKfK4dQ2R7XsyyJn6fCgYFJ0KNQ,65728 +sqlalchemy/sql/elements.py,sha256=dsNa2K57RygsGoaWuTMPp2QQ6SU3uZXSMW6CLGBbcIY,171208 +sqlalchemy/sql/events.py,sha256=xe3vJ6pQJau3dJWBAY0zU7Lz52UKuMrpLycriLm3AWA,18301 
+sqlalchemy/sql/expression.py,sha256=baMnCH04jeE8E3tA2TovXlsREocA2j3fdHKnzOB8H4U,7586 +sqlalchemy/sql/functions.py,sha256=AcI_KstJxeLw6rEXx6QnIgR2rq4Ru6RXMbq4EIIUURA,55319 +sqlalchemy/sql/lambdas.py,sha256=EfDdUBi5cSmkjz8pQCSRo858UWQCFNZxXkM-1qS0CgU,49281 +sqlalchemy/sql/naming.py,sha256=l8udFP2wvXLgehIB0uF2KXwpkXSVSREDk6fLCH9F-XY,6865 +sqlalchemy/sql/operators.py,sha256=BYATjkBQLJAmwHAlGUSV-dv9RLtGw_ziAvFbKDrN4YU,76107 +sqlalchemy/sql/roles.py,sha256=71zm_xpRkUdnu-WzG6lxQVnFHwvUjf6X6e3kRIkbzAs,7686 +sqlalchemy/sql/schema.py,sha256=TOBTbcRY6ehosJEcpYn2NX0_UGZP9lfFs-o8lJVc5tI,228104 +sqlalchemy/sql/selectable.py,sha256=9dO2yhN83zjna7nPjOE1hcvGyJGjc_lj5SAz7SP5CBQ,233041 +sqlalchemy/sql/sqltypes.py,sha256=_0FpFLH0AFueb3TIB5Vcx9nXWDNj31XFQTP0u8OXnSo,126540 +sqlalchemy/sql/traversals.py,sha256=7b98JSeLxqecmGHhhLXT_2M4QMke6W-xCci5RXndhxI,33521 +sqlalchemy/sql/type_api.py,sha256=D9Kq-ppwZvlNmxaHqvVmM8IVg4n6_erzJpVioye9WKE,83823 +sqlalchemy/sql/util.py,sha256=lBEAf_-eRepTErOBCp1PbEMZDYdJqAiK1GemQtgojYo,48175 +sqlalchemy/sql/visitors.py,sha256=KD1qOYm6RdftCufVGB8q6jFTIZIQKS3zPCg78cVV0mQ,36427 +sqlalchemy/testing/__init__.py,sha256=9M2SMxBBLJ8xLUWXNCWDzkcvOqFznWcJzrSd712vATU,3126 +sqlalchemy/testing/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/testing/__pycache__/assertions.cpython-312.pyc,, +sqlalchemy/testing/__pycache__/assertsql.cpython-312.pyc,, +sqlalchemy/testing/__pycache__/asyncio.cpython-312.pyc,, +sqlalchemy/testing/__pycache__/config.cpython-312.pyc,, +sqlalchemy/testing/__pycache__/engines.cpython-312.pyc,, +sqlalchemy/testing/__pycache__/entities.cpython-312.pyc,, +sqlalchemy/testing/__pycache__/exclusions.cpython-312.pyc,, +sqlalchemy/testing/__pycache__/pickleable.cpython-312.pyc,, +sqlalchemy/testing/__pycache__/profiling.cpython-312.pyc,, +sqlalchemy/testing/__pycache__/provision.cpython-312.pyc,, +sqlalchemy/testing/__pycache__/requirements.cpython-312.pyc,, +sqlalchemy/testing/__pycache__/schema.cpython-312.pyc,, +sqlalchemy/testing/__pycache__/util.cpython-312.pyc,, +sqlalchemy/testing/__pycache__/warnings.cpython-312.pyc,, +sqlalchemy/testing/assertions.py,sha256=lNNZ-gfF4TDRXmB7hZDdch7JYZRb_qWGeqWDFKtopx0,31439 +sqlalchemy/testing/assertsql.py,sha256=EIVk3i5qjiSI63c1ikTPoGhulZl88SSeOS2VNo1LJvM,16817 +sqlalchemy/testing/asyncio.py,sha256=cAw68tzu3h5wjdIKfOqhFATcbMb38XeK0ThjIalUHuQ,3728 +sqlalchemy/testing/config.py,sha256=MZOWz7wqzc1pbwHWSAR0RJkt2C-SD6ox-nYY7VHdi_U,12030 +sqlalchemy/testing/engines.py,sha256=w5-0FbanItRsOt6x4n7wM_OnToCzJnrvZZ2hk5Yzng8,13355 +sqlalchemy/testing/entities.py,sha256=rysywsnjXHlIIC-uv0L7-fLmTAuNpHJvcSd1HeAdY5M,3354 +sqlalchemy/testing/exclusions.py,sha256=uoYLEwyNOK1eR8rpfOZ2Q3dxgY0akM-RtsIFML-FPrY,12444 +sqlalchemy/testing/fixtures/__init__.py,sha256=9snVns5A7g28LqC6gqQuO4xRBoJzdnf068GQ6Cae75I,1198 +sqlalchemy/testing/fixtures/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/testing/fixtures/__pycache__/base.cpython-312.pyc,, +sqlalchemy/testing/fixtures/__pycache__/mypy.cpython-312.pyc,, +sqlalchemy/testing/fixtures/__pycache__/orm.cpython-312.pyc,, +sqlalchemy/testing/fixtures/__pycache__/sql.cpython-312.pyc,, +sqlalchemy/testing/fixtures/base.py,sha256=OayRr25soCqj1_yc665D5XbWWzFCm7Xl9Txtps953p4,12256 +sqlalchemy/testing/fixtures/mypy.py,sha256=7fWVZzYzNjqmLIoFa-MmXSGDPS3eZYFXlH-WxaxBDDY,11845 +sqlalchemy/testing/fixtures/orm.py,sha256=x27qjpK54JETATcYuiphtW-HXRy8ej8h3aCDkeQXPfY,6095 +sqlalchemy/testing/fixtures/sql.py,sha256=Q7Qq0n4qTT681nWt5DqjThopgjv5BB2KmSmrmAxUqHM,15704 
+sqlalchemy/testing/pickleable.py,sha256=B9dXGF7E2PywB67SngHPjSMIBDTFhyAV4rkDUcyMulk,2833 +sqlalchemy/testing/plugin/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +sqlalchemy/testing/plugin/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/testing/plugin/__pycache__/bootstrap.cpython-312.pyc,, +sqlalchemy/testing/plugin/__pycache__/plugin_base.cpython-312.pyc,, +sqlalchemy/testing/plugin/__pycache__/pytestplugin.cpython-312.pyc,, +sqlalchemy/testing/plugin/bootstrap.py,sha256=GrBB27KbswjE3Tt-zJlj6uSqGh9N-_CXkonnJSSBz84,1437 +sqlalchemy/testing/plugin/plugin_base.py,sha256=4SizjghFdDddt5o5gQ16Nw0bJHrtuBa4smxJcea-ti8,21573 +sqlalchemy/testing/plugin/pytestplugin.py,sha256=yh4PP406O0TwPMDzpJHpcNdU2WHXCLYI10F3oOLePjE,27295 +sqlalchemy/testing/profiling.py,sha256=HPjYvRLT1nD90FCZ7AA8j9ygkMtf1SGA47Xze2QPueo,10148 +sqlalchemy/testing/provision.py,sha256=w4F_ceGHPpWHUeh6cVcE5ktCC-ISrGc2yOSnXauOd5U,14200 +sqlalchemy/testing/requirements.py,sha256=gkviA8f5p4qdoDwAK791I4oGvnEqlm0ZZwJZpJzobFY,51393 +sqlalchemy/testing/schema.py,sha256=OSfMoIJ7ORbevGkeJdrKcTrQ0s7wXebuCU08mC1Y9jA,6513 +sqlalchemy/testing/suite/__init__.py,sha256=_firVc2uS3TMZ3vH2baQzNb17ubM78RHtb9kniSybmk,476 +sqlalchemy/testing/suite/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/testing/suite/__pycache__/test_cte.cpython-312.pyc,, +sqlalchemy/testing/suite/__pycache__/test_ddl.cpython-312.pyc,, +sqlalchemy/testing/suite/__pycache__/test_deprecations.cpython-312.pyc,, +sqlalchemy/testing/suite/__pycache__/test_dialect.cpython-312.pyc,, +sqlalchemy/testing/suite/__pycache__/test_insert.cpython-312.pyc,, +sqlalchemy/testing/suite/__pycache__/test_reflection.cpython-312.pyc,, +sqlalchemy/testing/suite/__pycache__/test_results.cpython-312.pyc,, +sqlalchemy/testing/suite/__pycache__/test_rowcount.cpython-312.pyc,, +sqlalchemy/testing/suite/__pycache__/test_select.cpython-312.pyc,, +sqlalchemy/testing/suite/__pycache__/test_sequence.cpython-312.pyc,, +sqlalchemy/testing/suite/__pycache__/test_types.cpython-312.pyc,, +sqlalchemy/testing/suite/__pycache__/test_unicode_ddl.cpython-312.pyc,, +sqlalchemy/testing/suite/__pycache__/test_update_delete.cpython-312.pyc,, +sqlalchemy/testing/suite/test_cte.py,sha256=O5idVeBnHm9zdiG3tuCBUn4hYU_TA63-6LNnRygr8g0,6205 +sqlalchemy/testing/suite/test_ddl.py,sha256=xWimTjggpTe3S1Xfmt_IPofTXkUUcKuVSVCIfIyGMbA,11785 +sqlalchemy/testing/suite/test_deprecations.py,sha256=XI8ZU1NxC-6uvPDImaaq9O7Ov6MF5gmy-yk3TfesLAo,5082 +sqlalchemy/testing/suite/test_dialect.py,sha256=HUpHZb7pnHbsoRpDLONpsCO_oWhBgjglU9pBO-EOUw4,22673 +sqlalchemy/testing/suite/test_insert.py,sha256=Wm_pW0qqUNV1Fs7mXoxtmaTHMQGmaVDgDsYgZs1jlxM,18308 +sqlalchemy/testing/suite/test_reflection.py,sha256=Nd4Ao_J3Sr-VeAeWbUe3gs6STPvik9DC37WkyJc-PVg,106205 +sqlalchemy/testing/suite/test_results.py,sha256=Hd6R4jhBNNQSp0xGa8wwTgpw-XUrCEZ3dWXpoZ4_DKs,15687 +sqlalchemy/testing/suite/test_rowcount.py,sha256=zhKVv0ibFSQmnE5luLwgHAn840zOJ6HxtkR3oL995cs,7652 +sqlalchemy/testing/suite/test_select.py,sha256=QHsBX16EZpxlEZZLM0pMNcwayPU0dig39McKwiiith0,58325 +sqlalchemy/testing/suite/test_sequence.py,sha256=c80CBWrU930GPnPfr9TCRbTTuITR7BpIactncLIj2XU,9672 +sqlalchemy/testing/suite/test_types.py,sha256=QjV48MqR7dB8UVzt56UL2z7Nt28-IhywX3DKuQeLYsY,65429 +sqlalchemy/testing/suite/test_unicode_ddl.py,sha256=7obItCpFt4qlWaDqe25HWgQT6FoUhgz1W7_Xycfz9Xk,5887 +sqlalchemy/testing/suite/test_update_delete.py,sha256=1hT0BTxB4SNipd6hnVlMnq25dLtQQoXov7z7UR0Sgi8,3658 +sqlalchemy/testing/util.py,sha256=Wsu4GZgCW6wX9mmxfiffhDz1cZm3778OB3LtiWNgb3Y,14080 
+sqlalchemy/testing/warnings.py,sha256=pmfT33PF1q1PI7DdHOsup3LxHq1AC4-aYl1oL8HmrYo,1546 +sqlalchemy/types.py,sha256=DgBpPaT-vtsn6_glx5wocrIhR2A1vy56SQNRY3NiPUw,3168 +sqlalchemy/util/__init__.py,sha256=Bh0SkfkeCsz6-rbDmC41lAWOuCvKCiXVZthN2cWJEXk,8245 +sqlalchemy/util/__pycache__/__init__.cpython-312.pyc,, +sqlalchemy/util/__pycache__/_collections.cpython-312.pyc,, +sqlalchemy/util/__pycache__/_concurrency_py3k.cpython-312.pyc,, +sqlalchemy/util/__pycache__/_has_cy.cpython-312.pyc,, +sqlalchemy/util/__pycache__/_py_collections.cpython-312.pyc,, +sqlalchemy/util/__pycache__/compat.cpython-312.pyc,, +sqlalchemy/util/__pycache__/concurrency.cpython-312.pyc,, +sqlalchemy/util/__pycache__/deprecations.cpython-312.pyc,, +sqlalchemy/util/__pycache__/langhelpers.cpython-312.pyc,, +sqlalchemy/util/__pycache__/preloaded.cpython-312.pyc,, +sqlalchemy/util/__pycache__/queue.cpython-312.pyc,, +sqlalchemy/util/__pycache__/tool_support.cpython-312.pyc,, +sqlalchemy/util/__pycache__/topological.cpython-312.pyc,, +sqlalchemy/util/__pycache__/typing.cpython-312.pyc,, +sqlalchemy/util/_collections.py,sha256=FYqVQg3CaqiEd21OFN1pNCfFbQ8gvlchW_TMtihSFNE,20169 +sqlalchemy/util/_concurrency_py3k.py,sha256=31vs1oXaLzeTRgmOXRrWToRQskWmJk-CBs3-JxSTcck,8223 +sqlalchemy/util/_has_cy.py,sha256=XMkeqCDGmhkd0uuzpCdyELz7gOjHxyFQ1AIlc5NneoY,1229 +sqlalchemy/util/_py_collections.py,sha256=cYjsYLCLBy5jdGBJATLJCmtfzr_AaJ-HKTUN8OdAzxY,16630 +sqlalchemy/util/compat.py,sha256=FkeHnW9asJYJvNmxVltee8jQNwQSdVRdKJlVRRInJI4,9388 +sqlalchemy/util/concurrency.py,sha256=ZxcQYOKy-GBsQkPmCrBO5MzMpqW3JZme2Hiyqpbt9uc,2284 +sqlalchemy/util/deprecations.py,sha256=pr9DSAf1ECqDk7X7F6TNc1jrhOeFihL33uEb5Wt2_T0,11971 +sqlalchemy/util/langhelpers.py,sha256=CQQP2Q9c68nL5mcWL-Q38-INrtoDHDnBmq7QhnWyEDM,64980 +sqlalchemy/util/preloaded.py,sha256=KKNLJEqChDW1TNUsM_TzKu7JYEA3kkuh2N-quM_2_Y4,5905 +sqlalchemy/util/queue.py,sha256=ITejs6KS4Hz_ojrss2oFeUO9MoIeR3qWmZQ8J7yyrNU,10205 +sqlalchemy/util/tool_support.py,sha256=epm8MzDZpVmhE6LIjrjJrP8BUf12Wab2m28A9lGq95s,5969 +sqlalchemy/util/topological.py,sha256=hjJWL3C_B7Rpv9s7jj7wcTckcZUSkxc6xRDhiN1xyec,3458 +sqlalchemy/util/typing.py,sha256=ESYm4oQtt-SarN04YTXCgovXT8tFupMiPmuGCDCMEIc,15831 diff --git a/venv/lib/python3.12/site-packages/aiofiles-24.1.0.dist-info/REQUESTED b/venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/REQUESTED similarity index 100% rename from venv/lib/python3.12/site-packages/aiofiles-24.1.0.dist-info/REQUESTED rename to venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/REQUESTED diff --git a/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.10.dist-info/WHEEL b/venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/WHEEL similarity index 78% rename from venv/lib/python3.12/site-packages/psycopg2_binary-2.9.10.dist-info/WHEEL rename to venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/WHEEL index 3e81182..c5825c5 100644 --- a/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.10.dist-info/WHEEL +++ b/venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/WHEEL @@ -1,5 +1,5 @@ Wheel-Version: 1.0 -Generator: setuptools (75.1.0) +Generator: bdist_wheel (0.41.3) Root-Is-Purelib: false Tag: cp312-cp312-manylinux_2_17_x86_64 Tag: cp312-cp312-manylinux2014_x86_64 diff --git a/venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/top_level.txt b/venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/top_level.txt similarity index 100% rename from venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/top_level.txt rename 
to venv/lib/python3.12/site-packages/SQLAlchemy-2.0.23.dist-info/top_level.txt diff --git a/venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/aiofiles-23.2.1.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/aiofiles-23.2.1.dist-info/INSTALLER diff --git a/venv/lib/python3.12/site-packages/aiofiles-24.1.0.dist-info/METADATA b/venv/lib/python3.12/site-packages/aiofiles-23.2.1.dist-info/METADATA similarity index 87% rename from venv/lib/python3.12/site-packages/aiofiles-24.1.0.dist-info/METADATA rename to venv/lib/python3.12/site-packages/aiofiles-23.2.1.dist-info/METADATA index 942d74c..61b6e3e 100644 --- a/venv/lib/python3.12/site-packages/aiofiles-24.1.0.dist-info/METADATA +++ b/venv/lib/python3.12/site-packages/aiofiles-23.2.1.dist-info/METADATA @@ -1,6 +1,6 @@ -Metadata-Version: 2.3 +Metadata-Version: 2.1 Name: aiofiles -Version: 24.1.0 +Version: 23.2.1 Summary: File support for asyncio. Project-URL: Changelog, https://github.com/Tinche/aiofiles#history Project-URL: Bug Tracker, https://github.com/Tinche/aiofiles/issues @@ -13,15 +13,15 @@ Classifier: Development Status :: 5 - Production/Stable Classifier: Framework :: AsyncIO Classifier: License :: OSI Approved :: Apache Software License Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3.7 Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Classifier: Programming Language :: Python :: 3.11 Classifier: Programming Language :: Python :: 3.12 -Classifier: Programming Language :: Python :: 3.13 Classifier: Programming Language :: Python :: Implementation :: CPython Classifier: Programming Language :: Python :: Implementation :: PyPy -Requires-Python: >=3.8 +Requires-Python: >=3.7 Description-Content-Type: text/markdown # aiofiles: file support for asyncio @@ -135,8 +135,6 @@ several useful `os` functions that deal with files: - `listdir` - `scandir` - `access` -- `getcwd` -- `path.abspath` - `path.exists` - `path.isfile` - `path.isdir` @@ -178,50 +176,25 @@ as desired. The return type also needs to be registered with the ```python aiofiles.threadpool.wrap.register(mock.MagicMock)( - lambda *args, **kwargs: aiofiles.threadpool.AsyncBufferedIOBase(*args, **kwargs) -) + lambda *args, **kwargs: threadpool.AsyncBufferedIOBase(*args, **kwargs)) async def test_stuff(): - write_data = 'data' - read_file_chunks = [ - b'file chunks 1', - b'file chunks 2', - b'file chunks 3', - b'', - ] - file_chunks_iter = iter(read_file_chunks) + data = 'data' + mock_file = mock.MagicMock() - mock_file_stream = mock.MagicMock( - read=lambda *args, **kwargs: next(file_chunks_iter) - ) - - with mock.patch('aiofiles.threadpool.sync_open', return_value=mock_file_stream) as mock_open: + with mock.patch('aiofiles.threadpool.sync_open', return_value=mock_file) as mock_open: async with aiofiles.open('filename', 'w') as f: - await f.write(write_data) - assert f.read() == b'file chunks 1' + await f.write(data) - mock_file_stream.write.assert_called_once_with(write_data) + mock_file.write.assert_called_once_with(data) ``` ### History -#### 24.1.0 (2024-06-24) - -- Import `os.link` conditionally to fix importing on android. - [#175](https://github.com/Tinche/aiofiles/issues/175) -- Remove spurious items from `aiofiles.os.__all__` when running on Windows. 
-- Switch to more modern async idioms: Remove types.coroutine and make AiofilesContextManager an awaitable instead a coroutine. -- Add `aiofiles.os.path.abspath` and `aiofiles.os.getcwd`. - [#174](https://github.com/Tinche/aiofiles/issues/181) -- _aiofiles_ is now tested on Python 3.13 too. - [#184](https://github.com/Tinche/aiofiles/pull/184) -- Dropped Python 3.7 support. If you require it, use version 23.2.1. - #### 23.2.1 (2023-08-09) - Import `os.statvfs` conditionally to fix importing on non-UNIX systems. [#171](https://github.com/Tinche/aiofiles/issues/171) [#172](https://github.com/Tinche/aiofiles/pull/172) -- aiofiles is now also tested on Windows. #### 23.2.0 (2023-08-09) diff --git a/venv/lib/python3.12/site-packages/aiofiles-24.1.0.dist-info/RECORD b/venv/lib/python3.12/site-packages/aiofiles-23.2.1.dist-info/RECORD similarity index 59% rename from venv/lib/python3.12/site-packages/aiofiles-24.1.0.dist-info/RECORD rename to venv/lib/python3.12/site-packages/aiofiles-23.2.1.dist-info/RECORD index 7a9df6a..5a1d8d0 100644 --- a/venv/lib/python3.12/site-packages/aiofiles-24.1.0.dist-info/RECORD +++ b/venv/lib/python3.12/site-packages/aiofiles-23.2.1.dist-info/RECORD @@ -1,23 +1,23 @@ -aiofiles-24.1.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -aiofiles-24.1.0.dist-info/METADATA,sha256=CvUJx21XclgI1Lp5Bt_4AyJesRYg0xCSx4exJZVmaSA,10708 -aiofiles-24.1.0.dist-info/RECORD,, -aiofiles-24.1.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -aiofiles-24.1.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87 -aiofiles-24.1.0.dist-info/licenses/LICENSE,sha256=y16Ofl9KOYjhBjwULGDcLfdWBfTEZRXnduOspt-XbhQ,11325 -aiofiles-24.1.0.dist-info/licenses/NOTICE,sha256=EExY0dRQvWR0wJ2LZLwBgnM6YKw9jCU-M0zegpRSD_E,55 +aiofiles-23.2.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +aiofiles-23.2.1.dist-info/METADATA,sha256=cot28p_PNjdl_MK--l9Qu2e6QOv9OxdHrKbjLmYf9Uw,9673 +aiofiles-23.2.1.dist-info/RECORD,, +aiofiles-23.2.1.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +aiofiles-23.2.1.dist-info/WHEEL,sha256=KGYbc1zXlYddvwxnNty23BeaKzh7YuoSIvIMO4jEhvw,87 +aiofiles-23.2.1.dist-info/licenses/LICENSE,sha256=y16Ofl9KOYjhBjwULGDcLfdWBfTEZRXnduOspt-XbhQ,11325 +aiofiles-23.2.1.dist-info/licenses/NOTICE,sha256=EExY0dRQvWR0wJ2LZLwBgnM6YKw9jCU-M0zegpRSD_E,55 aiofiles/__init__.py,sha256=1iAMJQyJtX3LGIS0AoFTJeO1aJ_RK2jpBSBhg0VoIrE,344 aiofiles/__pycache__/__init__.cpython-312.pyc,, aiofiles/__pycache__/base.cpython-312.pyc,, aiofiles/__pycache__/os.cpython-312.pyc,, aiofiles/__pycache__/ospath.cpython-312.pyc,, -aiofiles/base.py,sha256=zo0FgkCqZ5aosjvxqIvDf2t-RFg1Lc6X8P6rZ56p6fQ,1784 -aiofiles/os.py,sha256=0DrsG-eH4h7xRzglv9pIWsQuzqe7ZhVYw5FQS18fIys,1153 -aiofiles/ospath.py,sha256=WaYelz_k6ykAFRLStr4bqYIfCVQ-5GGzIqIizykbY2Q,794 +aiofiles/base.py,sha256=rZwA151Ji8XlBkzvDmcF1CgDTY2iKNuJMfvNlM0s0E0,2684 +aiofiles/os.py,sha256=zuFGaIyGCGUuFb7trFFEm6SLdCRqTFsSV0mY6SO8z3M,970 +aiofiles/ospath.py,sha256=zqG2VFzRb6yYiIOWipqsdgvZmoMTFvZmBdkxkAl1FT4,764 aiofiles/tempfile/__init__.py,sha256=hFSNTOjOUv371Ozdfy6FIxeln46Nm3xOVh4ZR3Q94V0,10244 aiofiles/tempfile/__pycache__/__init__.cpython-312.pyc,, aiofiles/tempfile/__pycache__/temptypes.cpython-312.pyc,, aiofiles/tempfile/temptypes.py,sha256=ddEvNjMLVlr7WUILCe6ypTqw77yREeIonTk16Uw_NVs,2093 -aiofiles/threadpool/__init__.py,sha256=kt0hwwx3bLiYtnA1SORhW8mJ6z4W9Xr7MbY80UIJJrI,3133 
+aiofiles/threadpool/__init__.py,sha256=c_aexl1t193iKdPZaolPEEbHDrQ0RrsH_HTAToMPQBo,3171 aiofiles/threadpool/__pycache__/__init__.cpython-312.pyc,, aiofiles/threadpool/__pycache__/binary.cpython-312.pyc,, aiofiles/threadpool/__pycache__/text.cpython-312.pyc,, diff --git a/venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/REQUESTED b/venv/lib/python3.12/site-packages/aiofiles-23.2.1.dist-info/REQUESTED similarity index 100% rename from venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/REQUESTED rename to venv/lib/python3.12/site-packages/aiofiles-23.2.1.dist-info/REQUESTED diff --git a/venv/lib/python3.12/site-packages/aiofiles-24.1.0.dist-info/WHEEL b/venv/lib/python3.12/site-packages/aiofiles-23.2.1.dist-info/WHEEL similarity index 67% rename from venv/lib/python3.12/site-packages/aiofiles-24.1.0.dist-info/WHEEL rename to venv/lib/python3.12/site-packages/aiofiles-23.2.1.dist-info/WHEEL index cdd68a4..9a7c9d3 100644 --- a/venv/lib/python3.12/site-packages/aiofiles-24.1.0.dist-info/WHEEL +++ b/venv/lib/python3.12/site-packages/aiofiles-23.2.1.dist-info/WHEEL @@ -1,4 +1,4 @@ Wheel-Version: 1.0 -Generator: hatchling 1.25.0 +Generator: hatchling 1.17.1 Root-Is-Purelib: true Tag: py3-none-any diff --git a/venv/lib/python3.12/site-packages/aiofiles-24.1.0.dist-info/licenses/LICENSE b/venv/lib/python3.12/site-packages/aiofiles-23.2.1.dist-info/licenses/LICENSE similarity index 100% rename from venv/lib/python3.12/site-packages/aiofiles-24.1.0.dist-info/licenses/LICENSE rename to venv/lib/python3.12/site-packages/aiofiles-23.2.1.dist-info/licenses/LICENSE diff --git a/venv/lib/python3.12/site-packages/aiofiles-24.1.0.dist-info/licenses/NOTICE b/venv/lib/python3.12/site-packages/aiofiles-23.2.1.dist-info/licenses/NOTICE similarity index 100% rename from venv/lib/python3.12/site-packages/aiofiles-24.1.0.dist-info/licenses/NOTICE rename to venv/lib/python3.12/site-packages/aiofiles-23.2.1.dist-info/licenses/NOTICE diff --git a/venv/lib/python3.12/site-packages/aiofiles/base.py b/venv/lib/python3.12/site-packages/aiofiles/base.py index 64f7d6b..07f2c2e 100644 --- a/venv/lib/python3.12/site-packages/aiofiles/base.py +++ b/venv/lib/python3.12/site-packages/aiofiles/base.py @@ -1,6 +1,6 @@ """Various base classes.""" -from collections.abc import Awaitable -from contextlib import AbstractAsyncContextManager +from types import coroutine +from collections.abc import Coroutine from asyncio import get_running_loop @@ -45,22 +45,66 @@ class AsyncIndirectBase(AsyncBase): pass # discard writes -class AiofilesContextManager(Awaitable, AbstractAsyncContextManager): - """An adjusted async context manager for aiofiles.""" - +class _ContextManager(Coroutine): __slots__ = ("_coro", "_obj") def __init__(self, coro): self._coro = coro self._obj = None + def send(self, value): + return self._coro.send(value) + + def throw(self, typ, val=None, tb=None): + if val is None: + return self._coro.throw(typ) + elif tb is None: + return self._coro.throw(typ, val) + else: + return self._coro.throw(typ, val, tb) + + def close(self): + return self._coro.close() + + @property + def gi_frame(self): + return self._coro.gi_frame + + @property + def gi_running(self): + return self._coro.gi_running + + @property + def gi_code(self): + return self._coro.gi_code + + def __next__(self): + return self.send(None) + + @coroutine + def __iter__(self): + resp = yield from self._coro + return resp + def __await__(self): - if self._obj is None: - self._obj = yield from self._coro.__await__() - return self._obj + resp = yield from 
self._coro + return resp + + async def __anext__(self): + resp = await self._coro + return resp async def __aenter__(self): - return await self + self._obj = await self._coro + return self._obj + + async def __aexit__(self, exc_type, exc, tb): + self._obj.close() + self._obj = None + + +class AiofilesContextManager(_ContextManager): + """An adjusted async context manager for aiofiles.""" async def __aexit__(self, exc_type, exc_val, exc_tb): await get_running_loop().run_in_executor( diff --git a/venv/lib/python3.12/site-packages/aiofiles/os.py b/venv/lib/python3.12/site-packages/aiofiles/os.py index 92243fa..29bc748 100644 --- a/venv/lib/python3.12/site-packages/aiofiles/os.py +++ b/venv/lib/python3.12/site-packages/aiofiles/os.py @@ -1,5 +1,4 @@ """Async executor versions of file functions from the os module.""" - import os from . import ospath as path @@ -8,6 +7,7 @@ from .ospath import wrap __all__ = [ "path", "stat", + "statvfs", "rename", "renames", "replace", @@ -17,20 +17,15 @@ __all__ = [ "makedirs", "rmdir", "removedirs", + "link", "symlink", "readlink", "listdir", "scandir", "access", + "sendfile", "wrap", - "getcwd", ] -if hasattr(os, "link"): - __all__ += ["link"] -if hasattr(os, "sendfile"): - __all__ += ["sendfile"] -if hasattr(os, "statvfs"): - __all__ += ["statvfs"] stat = wrap(os.stat) @@ -43,15 +38,13 @@ mkdir = wrap(os.mkdir) makedirs = wrap(os.makedirs) rmdir = wrap(os.rmdir) removedirs = wrap(os.removedirs) +link = wrap(os.link) symlink = wrap(os.symlink) readlink = wrap(os.readlink) listdir = wrap(os.listdir) scandir = wrap(os.scandir) access = wrap(os.access) -getcwd = wrap(os.getcwd) -if hasattr(os, "link"): - link = wrap(os.link) if hasattr(os, "sendfile"): sendfile = wrap(os.sendfile) if hasattr(os, "statvfs"): diff --git a/venv/lib/python3.12/site-packages/aiofiles/ospath.py b/venv/lib/python3.12/site-packages/aiofiles/ospath.py index 387d68d..5f32a43 100644 --- a/venv/lib/python3.12/site-packages/aiofiles/ospath.py +++ b/venv/lib/python3.12/site-packages/aiofiles/ospath.py @@ -1,5 +1,4 @@ """Async executor versions of file functions from the os.path module.""" - import asyncio from functools import partial, wraps from os import path @@ -27,4 +26,3 @@ getatime = wrap(path.getatime) getctime = wrap(path.getctime) samefile = wrap(path.samefile) sameopenfile = wrap(path.sameopenfile) -abspath = wrap(path.abspath) diff --git a/venv/lib/python3.12/site-packages/aiofiles/threadpool/__init__.py b/venv/lib/python3.12/site-packages/aiofiles/threadpool/__init__.py index e543283..a1cc673 100644 --- a/venv/lib/python3.12/site-packages/aiofiles/threadpool/__init__.py +++ b/venv/lib/python3.12/site-packages/aiofiles/threadpool/__init__.py @@ -10,6 +10,7 @@ from io import ( FileIO, TextIOBase, ) +from types import coroutine from ..base import AiofilesContextManager from .binary import ( @@ -62,7 +63,8 @@ def open( ) -async def _open( +@coroutine +def _open( file, mode="r", buffering=-1, @@ -89,7 +91,7 @@ async def _open( closefd=closefd, opener=opener, ) - f = await loop.run_in_executor(executor, cb) + f = yield from loop.run_in_executor(executor, cb) return wrap(f, loop=loop, executor=executor) diff --git a/venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/INSTALLER diff --git 
a/venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/licenses/LICENSE b/venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/LICENSE similarity index 95% rename from venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/licenses/LICENSE rename to venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/LICENSE index ab4bb16..74b9ce3 100644 --- a/venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/licenses/LICENSE +++ b/venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/LICENSE @@ -1,4 +1,4 @@ -Copyright 2009-2025 Michael Bayer. +Copyright 2009-2023 Michael Bayer. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in @@ -16,4 +16,4 @@ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. +SOFTWARE. \ No newline at end of file diff --git a/venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/METADATA b/venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/METADATA similarity index 92% rename from venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/METADATA rename to venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/METADATA index c2aa6c3..0d01974 100644 --- a/venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/METADATA +++ b/venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/METADATA @@ -1,10 +1,11 @@ -Metadata-Version: 2.4 +Metadata-Version: 2.1 Name: alembic -Version: 1.16.5 +Version: 1.12.1 Summary: A database migration tool for SQLAlchemy. 
-Author-email: Mike Bayer -License-Expression: MIT -Project-URL: Homepage, https://alembic.sqlalchemy.org +Home-page: https://alembic.sqlalchemy.org +Author: Mike Bayer +Author-email: mike_mp@zzzcomputing.com +License: MIT Project-URL: Documentation, https://alembic.sqlalchemy.org/en/latest/ Project-URL: Changelog, https://alembic.sqlalchemy.org/en/latest/changelog.html Project-URL: Source, https://github.com/sqlalchemy/alembic/ @@ -12,27 +13,27 @@ Project-URL: Issue Tracker, https://github.com/sqlalchemy/alembic/issues/ Classifier: Development Status :: 5 - Production/Stable Classifier: Intended Audience :: Developers Classifier: Environment :: Console +Classifier: License :: OSI Approved :: MIT License Classifier: Operating System :: OS Independent Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 -Classifier: Programming Language :: Python :: 3.11 -Classifier: Programming Language :: Python :: 3.12 -Classifier: Programming Language :: Python :: 3.13 Classifier: Programming Language :: Python :: Implementation :: CPython Classifier: Programming Language :: Python :: Implementation :: PyPy Classifier: Topic :: Database :: Front-Ends -Requires-Python: >=3.9 +Requires-Python: >=3.7 Description-Content-Type: text/x-rst License-File: LICENSE -Requires-Dist: SQLAlchemy>=1.4.0 +Requires-Dist: SQLAlchemy >=1.3.0 Requires-Dist: Mako -Requires-Dist: typing-extensions>=4.12 -Requires-Dist: tomli; python_version < "3.11" +Requires-Dist: typing-extensions >=4 +Requires-Dist: importlib-metadata ; python_version < "3.9" +Requires-Dist: importlib-resources ; python_version < "3.9" Provides-Extra: tz -Requires-Dist: tzdata; extra == "tz" -Dynamic: license-file +Requires-Dist: python-dateutil ; extra == 'tz' Alembic is a database migrations tool written by the author of `SQLAlchemy `_. 
A migrations tool diff --git a/venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/RECORD b/venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/RECORD new file mode 100644 index 0000000..a26054d --- /dev/null +++ b/venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/RECORD @@ -0,0 +1,149 @@ +../../../bin/alembic,sha256=kheZTewTBSd6rruOpyoj8QhFdGKiaj38MUFgBD5whig,238 +alembic-1.12.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +alembic-1.12.1.dist-info/LICENSE,sha256=soUmiob0QW6vTQWyrjiAwVb3xZqPk1pAK8BW6vszrwg,1058 +alembic-1.12.1.dist-info/METADATA,sha256=D9-LeKL0unLPg2JKmlFMB5NAxt9N9y-8oVEGOUHbQnU,7306 +alembic-1.12.1.dist-info/RECORD,, +alembic-1.12.1.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +alembic-1.12.1.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92 +alembic-1.12.1.dist-info/entry_points.txt,sha256=aykM30soxwGN0pB7etLc1q0cHJbL9dy46RnK9VX4LLw,48 +alembic-1.12.1.dist-info/top_level.txt,sha256=FwKWd5VsPFC8iQjpu1u9Cn-JnK3-V1RhUCmWqz1cl-s,8 +alembic/__init__.py,sha256=gczqgDgBRw3aV70aNeH6WGu0WdASQf_YiChV12qCRRI,75 +alembic/__main__.py,sha256=373m7-TBh72JqrSMYviGrxCHZo-cnweM8AGF8A22PmY,78 +alembic/__pycache__/__init__.cpython-312.pyc,, +alembic/__pycache__/__main__.cpython-312.pyc,, +alembic/__pycache__/command.cpython-312.pyc,, +alembic/__pycache__/config.cpython-312.pyc,, +alembic/__pycache__/context.cpython-312.pyc,, +alembic/__pycache__/environment.cpython-312.pyc,, +alembic/__pycache__/migration.cpython-312.pyc,, +alembic/__pycache__/op.cpython-312.pyc,, +alembic/autogenerate/__init__.py,sha256=4IHgWH89pForRq-yCDZhGjjVtsfGX5ECWNPuUs8nGUk,351 +alembic/autogenerate/__pycache__/__init__.cpython-312.pyc,, +alembic/autogenerate/__pycache__/api.cpython-312.pyc,, +alembic/autogenerate/__pycache__/compare.cpython-312.pyc,, +alembic/autogenerate/__pycache__/render.cpython-312.pyc,, +alembic/autogenerate/__pycache__/rewriter.cpython-312.pyc,, +alembic/autogenerate/api.py,sha256=MNn0Xtmj44aMFjfiR0LMkbxOynHyiyaRBnrj5EkImm4,21967 +alembic/autogenerate/compare.py,sha256=gSCjxrkQAl0rJD6o9Ln8wNxGVNU6FrWzKZYVkH5Tmac,47042 +alembic/autogenerate/render.py,sha256=Fik2aPZEIxOlTCrBd0UiPxnX5SFG__CvfXqMWoJr6lw,34475 +alembic/autogenerate/rewriter.py,sha256=Osba8GFVeqiX1ypGJW7Axt0ui2EROlaFtVZdMFbhzZ0,7384 +alembic/command.py,sha256=ze4pYvKpB-FtF8rduY6F6n3XHqeA-15iXaaEDeNHVzI,21588 +alembic/config.py,sha256=68e1nmYU5Nfh0bNRqRWUygSilDl1p0G_U1zZ8ifgmD8,21931 +alembic/context.py,sha256=hK1AJOQXJ29Bhn276GYcosxeG7pC5aZRT5E8c4bMJ4Q,195 +alembic/context.pyi,sha256=FLsT0be_vO_ozlC05EJkWR5olDPoTVq-7tgtoM5wSAw,31463 +alembic/ddl/__init__.py,sha256=xXr1W6PePe0gCLwR42ude0E6iru9miUFc1fCeQN4YP8,137 +alembic/ddl/__pycache__/__init__.cpython-312.pyc,, +alembic/ddl/__pycache__/base.cpython-312.pyc,, +alembic/ddl/__pycache__/impl.cpython-312.pyc,, +alembic/ddl/__pycache__/mssql.cpython-312.pyc,, +alembic/ddl/__pycache__/mysql.cpython-312.pyc,, +alembic/ddl/__pycache__/oracle.cpython-312.pyc,, +alembic/ddl/__pycache__/postgresql.cpython-312.pyc,, +alembic/ddl/__pycache__/sqlite.cpython-312.pyc,, +alembic/ddl/base.py,sha256=cCY3NldMRggrKd9bZ0mFRBE9GNDaAy0UJcM3ey4Utgw,9638 +alembic/ddl/impl.py,sha256=Z3GpNM2KwBpfl1UCam1YsYbSd0mQzRigOKQhUCLIPgE,25564 +alembic/ddl/mssql.py,sha256=0k26xnUSZNj3qCHEMzRFbaWgUzKcV07I3_-Ns47VhO0,14105 +alembic/ddl/mysql.py,sha256=ff8OE0zQ8YYjAgltBbtjQkDR-g9z65DNeFjEMm4sX6c,16675 +alembic/ddl/oracle.py,sha256=E0VaZaUM_5mwqNiJVA3zOAK-cuHVVIv_-NmUbH1JuGQ,6097 
+alembic/ddl/postgresql.py,sha256=aO8pcVN5ycw1wG2m1RRt8dQUD1KgRa6T4rSzg9FPCkU,26457 +alembic/ddl/sqlite.py,sha256=9q7NAxyeFwn9kWwQSc9RLeMFSos8waM7x9lnXdByh44,7613 +alembic/environment.py,sha256=MM5lPayGT04H3aeng1H7GQ8HEAs3VGX5yy6mDLCPLT4,43 +alembic/migration.py,sha256=MV6Fju6rZtn2fTREKzXrCZM6aIBGII4OMZFix0X-GLs,41 +alembic/op.py,sha256=flHtcsVqOD-ZgZKK2pv-CJ5Cwh-KJ7puMUNXzishxLw,167 +alembic/op.pyi,sha256=ldQBwAfzm_-ZsC3nizMuGoD34hjMKb4V_-Q1rR8q8LI,48591 +alembic/operations/__init__.py,sha256=e0KQSZAgLpTWvyvreB7DWg7RJV_MWSOPVDgCqsd2FzY,318 +alembic/operations/__pycache__/__init__.cpython-312.pyc,, +alembic/operations/__pycache__/base.cpython-312.pyc,, +alembic/operations/__pycache__/batch.cpython-312.pyc,, +alembic/operations/__pycache__/ops.cpython-312.pyc,, +alembic/operations/__pycache__/schemaobj.cpython-312.pyc,, +alembic/operations/__pycache__/toimpl.cpython-312.pyc,, +alembic/operations/base.py,sha256=2so4KisDNuOLw0CRiZqorIHrhuenpVoFbn3B0sNvDic,72471 +alembic/operations/batch.py,sha256=uMvGJDlcTs0GSHasg4Gsdv1YcXeLOK_1lkRl3jk1ezY,26954 +alembic/operations/ops.py,sha256=aP9Uz36k98O_Y-njKIAifyvyhi0g2zU6_igKMos91_s,93539 +alembic/operations/schemaobj.py,sha256=-tWad8pgWUNWucbpTnPuFK_EEl913C0RADJhlBnrjhc,9393 +alembic/operations/toimpl.py,sha256=K8nUmojtL94tyLSWdDD-e94IbghZ19k55iBIMvzMm5E,6993 +alembic/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +alembic/runtime/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +alembic/runtime/__pycache__/__init__.cpython-312.pyc,, +alembic/runtime/__pycache__/environment.cpython-312.pyc,, +alembic/runtime/__pycache__/migration.cpython-312.pyc,, +alembic/runtime/environment.py,sha256=qaerrw5jB7zYliNnCvIziaju4-tvQ451MuGW8PHnfvw,41019 +alembic/runtime/migration.py,sha256=5UtTI_T0JtYzt6ZpeUhannMZOvXWiEymKFOpeCefaPY,49407 +alembic/script/__init__.py,sha256=lSj06O391Iy5avWAiq8SPs6N8RBgxkSPjP8wpXcNDGg,100 +alembic/script/__pycache__/__init__.cpython-312.pyc,, +alembic/script/__pycache__/base.cpython-312.pyc,, +alembic/script/__pycache__/revision.cpython-312.pyc,, +alembic/script/__pycache__/write_hooks.cpython-312.pyc,, +alembic/script/base.py,sha256=90SpT8wyTMTUuS0Svsy5YIoqJSrR-6CtYSzStmRvFT0,37174 +alembic/script/revision.py,sha256=DE0nwvDOzdFo843brvnhs1DfP0jRC5EVQHrNihC7PUQ,61471 +alembic/script/write_hooks.py,sha256=Nqj4zz3sm97kAPOpK1m-i2znJchiybO_TWT50oljlJw,4917 +alembic/templates/async/README,sha256=ISVtAOvqvKk_5ThM5ioJE-lMkvf9IbknFUFVU_vPma4,58 +alembic/templates/async/__pycache__/env.cpython-312.pyc,, +alembic/templates/async/alembic.ini.mako,sha256=k3IyGDG15Rp1JDweC0TiDauaKYNvj3clrGfhw6oV6MI,3505 +alembic/templates/async/env.py,sha256=zbOCf3Y7w2lg92hxSwmG1MM_7y56i_oRH4AKp0pQBYo,2389 +alembic/templates/async/script.py.mako,sha256=MEqL-2qATlST9TAOeYgscMn1uy6HUS9NFvDgl93dMj8,635 +alembic/templates/generic/README,sha256=MVlc9TYmr57RbhXET6QxgyCcwWP7w-vLkEsirENqiIQ,38 +alembic/templates/generic/__pycache__/env.cpython-312.pyc,, +alembic/templates/generic/alembic.ini.mako,sha256=gZWFmH2A9sP0i7cxEDhJFkjGtTKUXaVna8QAbIaRqxk,3614 +alembic/templates/generic/env.py,sha256=TLRWOVW3Xpt_Tpf8JFzlnoPn_qoUu8UV77Y4o9XD6yI,2103 +alembic/templates/generic/script.py.mako,sha256=MEqL-2qATlST9TAOeYgscMn1uy6HUS9NFvDgl93dMj8,635 +alembic/templates/multidb/README,sha256=dWLDhnBgphA4Nzb7sNlMfCS3_06YqVbHhz-9O5JNqyI,606 +alembic/templates/multidb/__pycache__/env.cpython-312.pyc,, +alembic/templates/multidb/alembic.ini.mako,sha256=j_Y0yuZVoHy7sTPgSPd8DmbT2ItvAdWs7trYZSOmFnw,3708 
+alembic/templates/multidb/env.py,sha256=6zNjnW8mXGUk7erTsAvrfhvqoczJ-gagjVq1Ypg2YIQ,4230 +alembic/templates/multidb/script.py.mako,sha256=N06nMtNSwHkgl0EBXDyMt8njp9tlOesR583gfq21nbY,1090 +alembic/testing/__init__.py,sha256=kOxOh5nwmui9d-_CCq9WA4Udwy7ITjm453w74CTLZDo,1159 +alembic/testing/__pycache__/__init__.cpython-312.pyc,, +alembic/testing/__pycache__/assertions.cpython-312.pyc,, +alembic/testing/__pycache__/env.cpython-312.pyc,, +alembic/testing/__pycache__/fixtures.cpython-312.pyc,, +alembic/testing/__pycache__/requirements.cpython-312.pyc,, +alembic/testing/__pycache__/schemacompare.cpython-312.pyc,, +alembic/testing/__pycache__/util.cpython-312.pyc,, +alembic/testing/__pycache__/warnings.cpython-312.pyc,, +alembic/testing/assertions.py,sha256=1CbJk8c8-WO9eJ0XJ0jJvMsNRLUrXV41NOeIJUAlOBk,5015 +alembic/testing/env.py,sha256=zJacVb_z6uLs2U1TtkmnFH9P3_F-3IfYbVv4UEPOvfo,10754 +alembic/testing/fixtures.py,sha256=NyP4wE_dFN9ZzSGiBagRu1cdzkka03nwJYJYHYrrkSY,9112 +alembic/testing/plugin/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +alembic/testing/plugin/__pycache__/__init__.cpython-312.pyc,, +alembic/testing/plugin/__pycache__/bootstrap.cpython-312.pyc,, +alembic/testing/plugin/bootstrap.py,sha256=9C6wtjGrIVztZ928w27hsQE0KcjDLIUtUN3dvZKsMVk,50 +alembic/testing/requirements.py,sha256=WByOiJxn2crazIXPq6-0cfqV95cfd9vP_ZQ1Cf2l8hY,4841 +alembic/testing/schemacompare.py,sha256=7_4_0Y4UvuMiZ66pz1RC_P8Z1kYOP-R4Y5qUcNmcMKA,4535 +alembic/testing/suite/__init__.py,sha256=MvE7-hwbaVN1q3NM-ztGxORU9dnIelUCINKqNxewn7Y,288 +alembic/testing/suite/__pycache__/__init__.cpython-312.pyc,, +alembic/testing/suite/__pycache__/_autogen_fixtures.cpython-312.pyc,, +alembic/testing/suite/__pycache__/test_autogen_comments.cpython-312.pyc,, +alembic/testing/suite/__pycache__/test_autogen_computed.cpython-312.pyc,, +alembic/testing/suite/__pycache__/test_autogen_diffs.cpython-312.pyc,, +alembic/testing/suite/__pycache__/test_autogen_fks.cpython-312.pyc,, +alembic/testing/suite/__pycache__/test_autogen_identity.cpython-312.pyc,, +alembic/testing/suite/__pycache__/test_environment.cpython-312.pyc,, +alembic/testing/suite/__pycache__/test_op.cpython-312.pyc,, +alembic/testing/suite/_autogen_fixtures.py,sha256=cDq1pmzHe15S6dZPGNC6sqFaCQ3hLT_oPV2IDigUGQ0,9880 +alembic/testing/suite/test_autogen_comments.py,sha256=aEGqKUDw4kHjnDk298aoGcQvXJWmZXcIX_2FxH4cJK8,6283 +alembic/testing/suite/test_autogen_computed.py,sha256=qJeBpc8urnwTFvbwWrSTIbHVkRUuCXP-dKaNbUK2U2U,6077 +alembic/testing/suite/test_autogen_diffs.py,sha256=T4SR1n_kmcOKYhR4W1-dA0e5sddJ69DSVL2HW96kAkE,8394 +alembic/testing/suite/test_autogen_fks.py,sha256=AqFmb26Buex167HYa9dZWOk8x-JlB1OK3bwcvvjDFaU,32927 +alembic/testing/suite/test_autogen_identity.py,sha256=kcuqngG7qXAKPJDX4U8sRzPKHEJECHuZ0DtuaS6tVkk,5824 +alembic/testing/suite/test_environment.py,sha256=w9F0xnLEbALeR8k6_-Tz6JHvy91IqiTSypNasVzXfZQ,11877 +alembic/testing/suite/test_op.py,sha256=2XQCdm_NmnPxHGuGj7hmxMzIhKxXNotUsKdACXzE1mM,1343 +alembic/testing/util.py,sha256=CQrcQDA8fs_7ME85z5ydb-Bt70soIIID-qNY1vbR2dg,3350 +alembic/testing/warnings.py,sha256=RxA7x_8GseANgw07Us8JN_1iGbANxaw6_VitX2ZGQH4,1078 +alembic/util/__init__.py,sha256=cPF_jjFx7YRBByHHDqW3wxCIHsqnGfncEr_i238aduY,1202 +alembic/util/__pycache__/__init__.cpython-312.pyc,, +alembic/util/__pycache__/compat.cpython-312.pyc,, +alembic/util/__pycache__/editor.cpython-312.pyc,, +alembic/util/__pycache__/exc.cpython-312.pyc,, +alembic/util/__pycache__/langhelpers.cpython-312.pyc,, +alembic/util/__pycache__/messaging.cpython-312.pyc,, 
+alembic/util/__pycache__/pyfiles.cpython-312.pyc,, +alembic/util/__pycache__/sqla_compat.cpython-312.pyc,, +alembic/util/compat.py,sha256=WN8jPPFB9ri_uuEM1HEaN1ak3RJc_H3x8NqvtFkoXuM,2279 +alembic/util/editor.py,sha256=JIz6_BdgV8_oKtnheR6DZoB7qnrHrlRgWjx09AsTsUw,2546 +alembic/util/exc.py,sha256=KQTru4zcgAmN4IxLMwLFS56XToUewaXB7oOLcPNjPwg,98 +alembic/util/langhelpers.py,sha256=ZFGyGygHRbztOeajpajppyhd-Gp4PB5slMuvCFVrnmg,8591 +alembic/util/messaging.py,sha256=B6T-loMhIOY3OTbG47Ywp1Df9LZn18PgjwpwLrD1VNg,3042 +alembic/util/pyfiles.py,sha256=95J01FChN0j2uP3p72mjaOQvh5wC6XbdGtTDK8oEzsQ,3373 +alembic/util/sqla_compat.py,sha256=94MHlkj43y-QQySz5dCUiJUNOPr3BF9TQ_BrP6ey-8w,18906 diff --git a/venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/REQUESTED b/venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/REQUESTED similarity index 100% rename from venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/REQUESTED rename to venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/REQUESTED diff --git a/venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/WHEEL b/venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/WHEEL similarity index 65% rename from venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/WHEEL rename to venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/WHEEL index e7fa31b..7e68873 100644 --- a/venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/WHEEL +++ b/venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/WHEEL @@ -1,5 +1,5 @@ Wheel-Version: 1.0 -Generator: setuptools (80.9.0) +Generator: bdist_wheel (0.41.2) Root-Is-Purelib: true Tag: py3-none-any diff --git a/venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/entry_points.txt b/venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/entry_points.txt similarity index 100% rename from venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/entry_points.txt rename to venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/entry_points.txt diff --git a/venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/top_level.txt b/venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/top_level.txt similarity index 100% rename from venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/top_level.txt rename to venv/lib/python3.12/site-packages/alembic-1.12.1.dist-info/top_level.txt diff --git a/venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/RECORD b/venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/RECORD deleted file mode 100644 index 5108e5d..0000000 --- a/venv/lib/python3.12/site-packages/alembic-1.16.5.dist-info/RECORD +++ /dev/null @@ -1,163 +0,0 @@ -../../../bin/alembic,sha256=_J6yD4KtWGrilKk3GrsJKTd-33Dqp4ejOp_LNh0fQNs,234 -alembic-1.16.5.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -alembic-1.16.5.dist-info/METADATA,sha256=_hKTp0jnKI77a2esxmoCXgv5t2U8hDZS7yZDRkDBl0k,7265 -alembic-1.16.5.dist-info/RECORD,, -alembic-1.16.5.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -alembic-1.16.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91 -alembic-1.16.5.dist-info/entry_points.txt,sha256=aykM30soxwGN0pB7etLc1q0cHJbL9dy46RnK9VX4LLw,48 -alembic-1.16.5.dist-info/licenses/LICENSE,sha256=NeqcNBmyYfrxvkSMT0fZJVKBv2s2tf_qVQUiJ9S6VN4,1059 -alembic-1.16.5.dist-info/top_level.txt,sha256=FwKWd5VsPFC8iQjpu1u9Cn-JnK3-V1RhUCmWqz1cl-s,8 -alembic/__init__.py,sha256=H_hItDeyDOrQAHc1AFoYXIRN3O3FSxw4zSNiVzz2JlM,63 
-alembic/__main__.py,sha256=373m7-TBh72JqrSMYviGrxCHZo-cnweM8AGF8A22PmY,78 -alembic/__pycache__/__init__.cpython-312.pyc,, -alembic/__pycache__/__main__.cpython-312.pyc,, -alembic/__pycache__/command.cpython-312.pyc,, -alembic/__pycache__/config.cpython-312.pyc,, -alembic/__pycache__/context.cpython-312.pyc,, -alembic/__pycache__/environment.cpython-312.pyc,, -alembic/__pycache__/migration.cpython-312.pyc,, -alembic/__pycache__/op.cpython-312.pyc,, -alembic/autogenerate/__init__.py,sha256=ntmUTXhjLm4_zmqIwyVaECdpPDn6_u1yM9vYk6-553E,543 -alembic/autogenerate/__pycache__/__init__.cpython-312.pyc,, -alembic/autogenerate/__pycache__/api.cpython-312.pyc,, -alembic/autogenerate/__pycache__/compare.cpython-312.pyc,, -alembic/autogenerate/__pycache__/render.cpython-312.pyc,, -alembic/autogenerate/__pycache__/rewriter.cpython-312.pyc,, -alembic/autogenerate/api.py,sha256=L4qkapSJO1Ypymx8HsjLl0vFFt202agwMYsQbIe6ZtI,22219 -alembic/autogenerate/compare.py,sha256=LRTxNijEBvcTauuUXuJjC6Sg_gUn33FCYBTF0neZFwE,45979 -alembic/autogenerate/render.py,sha256=ceQL8nk8m2kBtQq5gtxtDLR9iR0Sck8xG_61Oez-Sqs,37270 -alembic/autogenerate/rewriter.py,sha256=NIASSS-KaNKPmbm1k4pE45aawwjSh1Acf6eZrOwnUGM,7814 -alembic/command.py,sha256=pZPQUGSxCjFu7qy0HMe02HJmByM0LOqoiK2AXKfRO3A,24855 -alembic/config.py,sha256=nfwN_OOFPpee-OY4o10DANh7VG_E4O7bdW00Wx8NNKY,34237 -alembic/context.py,sha256=hK1AJOQXJ29Bhn276GYcosxeG7pC5aZRT5E8c4bMJ4Q,195 -alembic/context.pyi,sha256=fdeFNTRc0bUgi7n2eZWVFh6NG-TzIv_0gAcapbfHnKY,31773 -alembic/ddl/__init__.py,sha256=Df8fy4Vn_abP8B7q3x8gyFwEwnLw6hs2Ljt_bV3EZWE,152 -alembic/ddl/__pycache__/__init__.cpython-312.pyc,, -alembic/ddl/__pycache__/_autogen.cpython-312.pyc,, -alembic/ddl/__pycache__/base.cpython-312.pyc,, -alembic/ddl/__pycache__/impl.cpython-312.pyc,, -alembic/ddl/__pycache__/mssql.cpython-312.pyc,, -alembic/ddl/__pycache__/mysql.cpython-312.pyc,, -alembic/ddl/__pycache__/oracle.cpython-312.pyc,, -alembic/ddl/__pycache__/postgresql.cpython-312.pyc,, -alembic/ddl/__pycache__/sqlite.cpython-312.pyc,, -alembic/ddl/_autogen.py,sha256=Blv2RrHNyF4cE6znCQXNXG5T9aO-YmiwD4Fz-qfoaWA,9275 -alembic/ddl/base.py,sha256=A1f89-rCZvqw-hgWmBbIszRqx94lL6gKLFXE9kHettA,10478 -alembic/ddl/impl.py,sha256=UL8-iza7CJk_T73lr5fjDLdhxEL56uD-AEjtmESAbLk,30439 -alembic/ddl/mssql.py,sha256=NzORSIDHUll_g6iH4IyMTXZU1qjKzXrpespKrjWnfLY,14216 -alembic/ddl/mysql.py,sha256=LSfwiABdT54sKY_uQ-h6RvjbGiG-1vCSDkO3ECeq3qM,18383 -alembic/ddl/oracle.py,sha256=669YlkcZihlXFbnXhH2krdrvDry8q5pcUGfoqkg_R6Y,6243 -alembic/ddl/postgresql.py,sha256=S7uye2NDSHLwV3w8SJ2Q9DLbcvQIxQfJ3EEK6JqyNag,29950 -alembic/ddl/sqlite.py,sha256=u5tJgRUiY6bzVltl_NWlI6cy23v8XNagk_9gPI6Lnns,8006 -alembic/environment.py,sha256=MM5lPayGT04H3aeng1H7GQ8HEAs3VGX5yy6mDLCPLT4,43 -alembic/migration.py,sha256=MV6Fju6rZtn2fTREKzXrCZM6aIBGII4OMZFix0X-GLs,41 -alembic/op.py,sha256=flHtcsVqOD-ZgZKK2pv-CJ5Cwh-KJ7puMUNXzishxLw,167 -alembic/op.pyi,sha256=PQ4mKNp7EXrjVdIWQRoGiBSVke4PPxTc9I6qF8ZGGZE,50711 -alembic/operations/__init__.py,sha256=e0KQSZAgLpTWvyvreB7DWg7RJV_MWSOPVDgCqsd2FzY,318 -alembic/operations/__pycache__/__init__.cpython-312.pyc,, -alembic/operations/__pycache__/base.cpython-312.pyc,, -alembic/operations/__pycache__/batch.cpython-312.pyc,, -alembic/operations/__pycache__/ops.cpython-312.pyc,, -alembic/operations/__pycache__/schemaobj.cpython-312.pyc,, -alembic/operations/__pycache__/toimpl.cpython-312.pyc,, -alembic/operations/base.py,sha256=npw1iFboTlEsaQS0b7mb2SEHsRDV4GLQqnjhcfma6Nk,75157 
-alembic/operations/batch.py,sha256=1UmCFcsFWObinQWFRWoGZkjynl54HKpldbPs67aR4wg,26923 -alembic/operations/ops.py,sha256=ftsFgcZIctxRDiuGgkQsaFHsMlRP7cLq7Dj_seKVBnQ,96276 -alembic/operations/schemaobj.py,sha256=Wp-bBe4a8lXPTvIHJttBY0ejtpVR5Jvtb2kI-U2PztQ,9468 -alembic/operations/toimpl.py,sha256=rgufuSUNwpgrOYzzY3Q3ELW1rQv2fQbQVokXgnIYIrs,7503 -alembic/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -alembic/runtime/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -alembic/runtime/__pycache__/__init__.cpython-312.pyc,, -alembic/runtime/__pycache__/environment.cpython-312.pyc,, -alembic/runtime/__pycache__/migration.cpython-312.pyc,, -alembic/runtime/environment.py,sha256=L6bDW1dvw8L4zwxlTG8KnT0xcCgLXxUfdRpzqlJoFjo,41479 -alembic/runtime/migration.py,sha256=lu9_z_qyWmNzSM52_FgdXP_G52PTmTTeOeMBQAkQTFg,49997 -alembic/script/__init__.py,sha256=lSj06O391Iy5avWAiq8SPs6N8RBgxkSPjP8wpXcNDGg,100 -alembic/script/__pycache__/__init__.cpython-312.pyc,, -alembic/script/__pycache__/base.cpython-312.pyc,, -alembic/script/__pycache__/revision.cpython-312.pyc,, -alembic/script/__pycache__/write_hooks.cpython-312.pyc,, -alembic/script/base.py,sha256=4jINClsNNwQIvnf4Kwp9JPAMrANLXdLItylXmcMqAkI,36896 -alembic/script/revision.py,sha256=BQcJoMCIXtSJRLCvdasgLOtCx9O7A8wsSym1FsqLW4s,62307 -alembic/script/write_hooks.py,sha256=uQWAtguSCrxU_k9d87NX19y6EzyjJRRQ5HS9cyPnK9o,5092 -alembic/templates/async/README,sha256=ISVtAOvqvKk_5ThM5ioJE-lMkvf9IbknFUFVU_vPma4,58 -alembic/templates/async/__pycache__/env.cpython-312.pyc,, -alembic/templates/async/alembic.ini.mako,sha256=Bgi4WkaHYsT7xvsX-4WOGkcXKFroNoQLaUvZA23ZwGs,4864 -alembic/templates/async/env.py,sha256=zbOCf3Y7w2lg92hxSwmG1MM_7y56i_oRH4AKp0pQBYo,2389 -alembic/templates/async/script.py.mako,sha256=04kgeBtNMa4cCnG8CfQcKt6P6rnloIfj8wy0u_DBydM,704 -alembic/templates/generic/README,sha256=MVlc9TYmr57RbhXET6QxgyCcwWP7w-vLkEsirENqiIQ,38 -alembic/templates/generic/__pycache__/env.cpython-312.pyc,, -alembic/templates/generic/alembic.ini.mako,sha256=LCpLL02bi9Qr3KRTEj9NbQqAu0ckUmYBwPtrMtQkv-Y,4864 -alembic/templates/generic/env.py,sha256=TLRWOVW3Xpt_Tpf8JFzlnoPn_qoUu8UV77Y4o9XD6yI,2103 -alembic/templates/generic/script.py.mako,sha256=04kgeBtNMa4cCnG8CfQcKt6P6rnloIfj8wy0u_DBydM,704 -alembic/templates/multidb/README,sha256=dWLDhnBgphA4Nzb7sNlMfCS3_06YqVbHhz-9O5JNqyI,606 -alembic/templates/multidb/__pycache__/env.cpython-312.pyc,, -alembic/templates/multidb/alembic.ini.mako,sha256=rIp1LTdE1xcoFT2G7X72KshzYjUTRrHTvnkvFL___-8,5190 -alembic/templates/multidb/env.py,sha256=6zNjnW8mXGUk7erTsAvrfhvqoczJ-gagjVq1Ypg2YIQ,4230 -alembic/templates/multidb/script.py.mako,sha256=ZbCXMkI5Wj2dwNKcxuVGkKZ7Iav93BNx_bM4zbGi3c8,1235 -alembic/templates/pyproject/README,sha256=dMhIiFoeM7EdeaOXBs3mVQ6zXACMyGXDb_UBB6sGRA0,60 -alembic/templates/pyproject/__pycache__/env.cpython-312.pyc,, -alembic/templates/pyproject/alembic.ini.mako,sha256=bQnEoydnLOUgg9vNbTOys4r5MaW8lmwYFXSrlfdEEkw,782 -alembic/templates/pyproject/env.py,sha256=TLRWOVW3Xpt_Tpf8JFzlnoPn_qoUu8UV77Y4o9XD6yI,2103 -alembic/templates/pyproject/pyproject.toml.mako,sha256=Gf16ZR9OMG9zDlFO5PVQlfiL1DTKwSA--sTNzK7Lba0,2852 -alembic/templates/pyproject/script.py.mako,sha256=04kgeBtNMa4cCnG8CfQcKt6P6rnloIfj8wy0u_DBydM,704 -alembic/templates/pyproject_async/README,sha256=2Q5XcEouiqQ-TJssO9805LROkVUd0F6d74rTnuLrifA,45 -alembic/templates/pyproject_async/__pycache__/env.cpython-312.pyc,, -alembic/templates/pyproject_async/alembic.ini.mako,sha256=bQnEoydnLOUgg9vNbTOys4r5MaW8lmwYFXSrlfdEEkw,782 
-alembic/templates/pyproject_async/env.py,sha256=zbOCf3Y7w2lg92hxSwmG1MM_7y56i_oRH4AKp0pQBYo,2389 -alembic/templates/pyproject_async/pyproject.toml.mako,sha256=Gf16ZR9OMG9zDlFO5PVQlfiL1DTKwSA--sTNzK7Lba0,2852 -alembic/templates/pyproject_async/script.py.mako,sha256=04kgeBtNMa4cCnG8CfQcKt6P6rnloIfj8wy0u_DBydM,704 -alembic/testing/__init__.py,sha256=PTMhi_2PZ1T_3atQS2CIr0V4YRZzx_doKI-DxKdQS44,1297 -alembic/testing/__pycache__/__init__.cpython-312.pyc,, -alembic/testing/__pycache__/assertions.cpython-312.pyc,, -alembic/testing/__pycache__/env.cpython-312.pyc,, -alembic/testing/__pycache__/fixtures.cpython-312.pyc,, -alembic/testing/__pycache__/requirements.cpython-312.pyc,, -alembic/testing/__pycache__/schemacompare.cpython-312.pyc,, -alembic/testing/__pycache__/util.cpython-312.pyc,, -alembic/testing/__pycache__/warnings.cpython-312.pyc,, -alembic/testing/assertions.py,sha256=qcqf3tRAUe-A12NzuK_yxlksuX9OZKRC5E8pKIdBnPg,5302 -alembic/testing/env.py,sha256=pka7fjwOC8hYL6X0XE4oPkJpy_1WX01bL7iP7gpO_4I,11551 -alembic/testing/fixtures.py,sha256=fOzsRF8SW6CWpAH0sZpUHcgsJjun9EHnp4k2S3Lq5eU,9920 -alembic/testing/plugin/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -alembic/testing/plugin/__pycache__/__init__.cpython-312.pyc,, -alembic/testing/plugin/__pycache__/bootstrap.cpython-312.pyc,, -alembic/testing/plugin/bootstrap.py,sha256=9C6wtjGrIVztZ928w27hsQE0KcjDLIUtUN3dvZKsMVk,50 -alembic/testing/requirements.py,sha256=gNnnvgPCuiqKeHmiNymdQuYIjQ0BrxiPxu_in4eHEsc,4180 -alembic/testing/schemacompare.py,sha256=N5UqSNCOJetIKC4vKhpYzQEpj08XkdgIoqBmEPQ3tlc,4838 -alembic/testing/suite/__init__.py,sha256=MvE7-hwbaVN1q3NM-ztGxORU9dnIelUCINKqNxewn7Y,288 -alembic/testing/suite/__pycache__/__init__.cpython-312.pyc,, -alembic/testing/suite/__pycache__/_autogen_fixtures.cpython-312.pyc,, -alembic/testing/suite/__pycache__/test_autogen_comments.cpython-312.pyc,, -alembic/testing/suite/__pycache__/test_autogen_computed.cpython-312.pyc,, -alembic/testing/suite/__pycache__/test_autogen_diffs.cpython-312.pyc,, -alembic/testing/suite/__pycache__/test_autogen_fks.cpython-312.pyc,, -alembic/testing/suite/__pycache__/test_autogen_identity.cpython-312.pyc,, -alembic/testing/suite/__pycache__/test_environment.cpython-312.pyc,, -alembic/testing/suite/__pycache__/test_op.cpython-312.pyc,, -alembic/testing/suite/_autogen_fixtures.py,sha256=Drrz_FKb9KDjq8hkwxtPkJVY1sCY7Biw-Muzb8kANp8,13480 -alembic/testing/suite/test_autogen_comments.py,sha256=aEGqKUDw4kHjnDk298aoGcQvXJWmZXcIX_2FxH4cJK8,6283 -alembic/testing/suite/test_autogen_computed.py,sha256=-5wran56qXo3afAbSk8cuSDDpbQweyJ61RF-GaVuZbA,4126 -alembic/testing/suite/test_autogen_diffs.py,sha256=T4SR1n_kmcOKYhR4W1-dA0e5sddJ69DSVL2HW96kAkE,8394 -alembic/testing/suite/test_autogen_fks.py,sha256=AqFmb26Buex167HYa9dZWOk8x-JlB1OK3bwcvvjDFaU,32927 -alembic/testing/suite/test_autogen_identity.py,sha256=kcuqngG7qXAKPJDX4U8sRzPKHEJECHuZ0DtuaS6tVkk,5824 -alembic/testing/suite/test_environment.py,sha256=OwD-kpESdLoc4byBrGrXbZHvqtPbzhFCG4W9hJOJXPQ,11877 -alembic/testing/suite/test_op.py,sha256=2XQCdm_NmnPxHGuGj7hmxMzIhKxXNotUsKdACXzE1mM,1343 -alembic/testing/util.py,sha256=CQrcQDA8fs_7ME85z5ydb-Bt70soIIID-qNY1vbR2dg,3350 -alembic/testing/warnings.py,sha256=cDDWzvxNZE6x9dME2ACTXSv01G81JcIbE1GIE_s1kvg,831 -alembic/util/__init__.py,sha256=_Zj_xp6ssKLyoLHUFzmKhnc8mhwXW8D8h7qyX-wO56M,1519 -alembic/util/__pycache__/__init__.cpython-312.pyc,, -alembic/util/__pycache__/compat.cpython-312.pyc,, -alembic/util/__pycache__/editor.cpython-312.pyc,, 
-alembic/util/__pycache__/exc.cpython-312.pyc,, -alembic/util/__pycache__/langhelpers.cpython-312.pyc,, -alembic/util/__pycache__/messaging.cpython-312.pyc,, -alembic/util/__pycache__/pyfiles.cpython-312.pyc,, -alembic/util/__pycache__/sqla_compat.cpython-312.pyc,, -alembic/util/compat.py,sha256=Vt5xCn5Y675jI4seKNBV4IVnCl9V4wyH3OBI2w7U0EY,4248 -alembic/util/editor.py,sha256=JIz6_BdgV8_oKtnheR6DZoB7qnrHrlRgWjx09AsTsUw,2546 -alembic/util/exc.py,sha256=ZBlTQ8g-Jkb1iYFhFHs9djilRz0SSQ0Foc5SSoENs5o,564 -alembic/util/langhelpers.py,sha256=LpOcovnhMnP45kTt8zNJ4BHpyQrlF40OL6yDXjqKtsE,10026 -alembic/util/messaging.py,sha256=3bEBoDy4EAXETXAvArlYjeMITXDTgPTu6ZoE3ytnzSw,3294 -alembic/util/pyfiles.py,sha256=kOBjZEytRkBKsQl0LAj2sbKJMQazjwQ_5UeMKSIvVFo,4730 -alembic/util/sqla_compat.py,sha256=9OYPTf-GCultAIuv1PoiaqYXAApZQxUOqjrOaeJDAik,14790 diff --git a/venv/lib/python3.12/site-packages/alembic/__init__.py b/venv/lib/python3.12/site-packages/alembic/__init__.py index 302a806..c5870fb 100644 --- a/venv/lib/python3.12/site-packages/alembic/__init__.py +++ b/venv/lib/python3.12/site-packages/alembic/__init__.py @@ -1,4 +1,6 @@ +import sys + from . import context from . import op -__version__ = "1.16.5" +__version__ = "1.12.1" diff --git a/venv/lib/python3.12/site-packages/alembic/autogenerate/__init__.py b/venv/lib/python3.12/site-packages/alembic/autogenerate/__init__.py index 445ddb2..cd2ed1c 100644 --- a/venv/lib/python3.12/site-packages/alembic/autogenerate/__init__.py +++ b/venv/lib/python3.12/site-packages/alembic/autogenerate/__init__.py @@ -1,10 +1,10 @@ -from .api import _render_migration_diffs as _render_migration_diffs -from .api import compare_metadata as compare_metadata -from .api import produce_migrations as produce_migrations -from .api import render_python_code as render_python_code -from .api import RevisionContext as RevisionContext -from .compare import _produce_net_changes as _produce_net_changes -from .compare import comparators as comparators -from .render import render_op_text as render_op_text -from .render import renderers as renderers -from .rewriter import Rewriter as Rewriter +from .api import _render_migration_diffs +from .api import compare_metadata +from .api import produce_migrations +from .api import render_python_code +from .api import RevisionContext +from .compare import _produce_net_changes +from .compare import comparators +from .render import render_op_text +from .render import renderers +from .rewriter import Rewriter diff --git a/venv/lib/python3.12/site-packages/alembic/autogenerate/api.py b/venv/lib/python3.12/site-packages/alembic/autogenerate/api.py index 811462e..7282487 100644 --- a/venv/lib/python3.12/site-packages/alembic/autogenerate/api.py +++ b/venv/lib/python3.12/site-packages/alembic/autogenerate/api.py @@ -17,7 +17,6 @@ from . import compare from . import render from .. 
import util from ..operations import ops -from ..util import sqla_compat """Provide the 'autogenerate' feature which can produce migration operations automatically.""" @@ -28,7 +27,6 @@ if TYPE_CHECKING: from sqlalchemy.engine import Inspector from sqlalchemy.sql.schema import MetaData from sqlalchemy.sql.schema import SchemaItem - from sqlalchemy.sql.schema import Table from ..config import Config from ..operations.ops import DowngradeOps @@ -166,7 +164,6 @@ def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any: """ migration_script = produce_migrations(context, metadata) - assert migration_script.upgrade_ops is not None return migration_script.upgrade_ops.as_diffs() @@ -277,7 +274,7 @@ class AutogenContext: """Maintains configuration and state that's specific to an autogenerate operation.""" - metadata: Union[MetaData, Sequence[MetaData], None] = None + metadata: Optional[MetaData] = None """The :class:`~sqlalchemy.schema.MetaData` object representing the destination. @@ -332,8 +329,8 @@ class AutogenContext: def __init__( self, migration_context: MigrationContext, - metadata: Union[MetaData, Sequence[MetaData], None] = None, - opts: Optional[Dict[str, Any]] = None, + metadata: Optional[MetaData] = None, + opts: Optional[dict] = None, autogenerate: bool = True, ) -> None: if ( @@ -443,7 +440,7 @@ class AutogenContext: def run_object_filters( self, object_: SchemaItem, - name: sqla_compat._ConstraintName, + name: Optional[str], type_: NameFilterType, reflected: bool, compare_to: Optional[SchemaItem], @@ -467,7 +464,7 @@ class AutogenContext: run_filters = run_object_filters @util.memoized_property - def sorted_tables(self) -> List[Table]: + def sorted_tables(self): """Return an aggregate of the :attr:`.MetaData.sorted_tables` collection(s). @@ -483,7 +480,7 @@ class AutogenContext: return result @util.memoized_property - def table_key_to_table(self) -> Dict[str, Table]: + def table_key_to_table(self): """Return an aggregate of the :attr:`.MetaData.tables` dictionaries. The :attr:`.MetaData.tables` collection is a dictionary of table key @@ -494,7 +491,7 @@ class AutogenContext: objects contain the same table key, an exception is raised. 
""" - result: Dict[str, Table] = {} + result = {} for m in util.to_list(self.metadata): intersect = set(result).intersection(set(m.tables)) if intersect: @@ -596,9 +593,9 @@ class RevisionContext: migration_script = self.generated_revisions[-1] if not getattr(migration_script, "_needs_render", False): migration_script.upgrade_ops_list[-1].upgrade_token = upgrade_token - migration_script.downgrade_ops_list[-1].downgrade_token = ( - downgrade_token - ) + migration_script.downgrade_ops_list[ + -1 + ].downgrade_token = downgrade_token migration_script._needs_render = True else: migration_script._upgrade_ops.append( diff --git a/venv/lib/python3.12/site-packages/alembic/autogenerate/compare.py b/venv/lib/python3.12/site-packages/alembic/autogenerate/compare.py index a9adda1..a24a75d 100644 --- a/venv/lib/python3.12/site-packages/alembic/autogenerate/compare.py +++ b/venv/lib/python3.12/site-packages/alembic/autogenerate/compare.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls -# mypy: no-warn-return-any, allow-any-generics - from __future__ import annotations import contextlib @@ -10,12 +7,12 @@ from typing import Any from typing import cast from typing import Dict from typing import Iterator +from typing import List from typing import Mapping from typing import Optional from typing import Set from typing import Tuple from typing import TYPE_CHECKING -from typing import TypeVar from typing import Union from sqlalchemy import event @@ -24,15 +21,10 @@ from sqlalchemy import schema as sa_schema from sqlalchemy import text from sqlalchemy import types as sqltypes from sqlalchemy.sql import expression -from sqlalchemy.sql.elements import conv -from sqlalchemy.sql.schema import ForeignKeyConstraint -from sqlalchemy.sql.schema import Index -from sqlalchemy.sql.schema import UniqueConstraint from sqlalchemy.util import OrderedSet +from alembic.ddl.base import _fk_spec from .. 
import util -from ..ddl._autogen import is_index_sig -from ..ddl._autogen import is_uq_sig from ..operations import ops from ..util import sqla_compat @@ -43,7 +35,10 @@ if TYPE_CHECKING: from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.schema import Column + from sqlalchemy.sql.schema import ForeignKeyConstraint + from sqlalchemy.sql.schema import Index from sqlalchemy.sql.schema import Table + from sqlalchemy.sql.schema import UniqueConstraint from alembic.autogenerate.api import AutogenContext from alembic.ddl.impl import DefaultImpl @@ -51,8 +46,6 @@ if TYPE_CHECKING: from alembic.operations.ops import MigrationScript from alembic.operations.ops import ModifyTableOps from alembic.operations.ops import UpgradeOps - from ..ddl._autogen import _constraint_sig - log = logging.getLogger(__name__) @@ -217,7 +210,7 @@ def _compare_tables( (inspector), # fmt: on ) - _InspectorConv(inspector).reflect_table(t, include_columns=None) + sqla_compat._reflect_table(inspector, t) if autogen_context.run_object_filters(t, tname, "table", True, None): modify_table_ops = ops.ModifyTableOps(tname, [], schema=s) @@ -247,8 +240,7 @@ def _compare_tables( _compat_autogen_column_reflect(inspector), # fmt: on ) - _InspectorConv(inspector).reflect_table(t, include_columns=None) - + sqla_compat._reflect_table(inspector, t) conn_column_info[(s, tname)] = t for s, tname in sorted(existing_tables, key=lambda x: (x[0] or "", x[1])): @@ -437,56 +429,102 @@ def _compare_columns( log.info("Detected removed column '%s.%s'", name, cname) -_C = TypeVar("_C", bound=Union[UniqueConstraint, ForeignKeyConstraint, Index]) +class _constraint_sig: + const: Union[UniqueConstraint, ForeignKeyConstraint, Index] - -class _InspectorConv: - __slots__ = ("inspector",) - - def __init__(self, inspector): - self.inspector = inspector - - def _apply_reflectinfo_conv(self, consts): - if not consts: - return consts - for const in consts: - if const["name"] is not None and not isinstance( - const["name"], conv - ): - const["name"] = conv(const["name"]) - return consts - - def _apply_constraint_conv(self, consts): - if not consts: - return consts - for const in consts: - if const.name is not None and not isinstance(const.name, conv): - const.name = conv(const.name) - return consts - - def get_indexes(self, *args, **kw): - return self._apply_reflectinfo_conv( - self.inspector.get_indexes(*args, **kw) + def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]: + return sqla_compat._get_constraint_final_name( + self.const, context.dialect ) - def get_unique_constraints(self, *args, **kw): - return self._apply_reflectinfo_conv( - self.inspector.get_unique_constraints(*args, **kw) + def __eq__(self, other): + return self.const == other.const + + def __ne__(self, other): + return self.const != other.const + + def __hash__(self) -> int: + return hash(self.const) + + +class _uq_constraint_sig(_constraint_sig): + is_index = False + is_unique = True + + def __init__(self, const: UniqueConstraint, impl: DefaultImpl) -> None: + self.const = const + self.name = const.name + self.sig = ("UNIQUE_CONSTRAINT",) + impl.create_unique_constraint_sig( + const ) - def get_foreign_keys(self, *args, **kw): - return self._apply_reflectinfo_conv( - self.inspector.get_foreign_keys(*args, **kw) + @property + def column_names(self) -> List[str]: + return [col.name for col in self.const.columns] + + +class _ix_constraint_sig(_constraint_sig): + is_index = True + + def __init__(self, const: 
Index, impl: DefaultImpl) -> None: + self.const = const + self.name = const.name + self.sig = ("INDEX",) + impl.create_index_sig(const) + self.is_unique = bool(const.unique) + + def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]: + return sqla_compat._get_constraint_final_name( + self.const, context.dialect ) - def reflect_table(self, table, *, include_columns): - self.inspector.reflect_table(table, include_columns=include_columns) + @property + def column_names(self) -> Union[List[quoted_name], List[None]]: + return sqla_compat._get_index_column_names(self.const) - # I had a cool version of this using _ReflectInfo, however that doesn't - # work in 1.4 and it's not public API in 2.x. Then this is just a two - # liner. So there's no competition... - self._apply_constraint_conv(table.constraints) - self._apply_constraint_conv(table.indexes) + +class _fk_constraint_sig(_constraint_sig): + def __init__( + self, const: ForeignKeyConstraint, include_options: bool = False + ) -> None: + self.const = const + self.name = const.name + + ( + self.source_schema, + self.source_table, + self.source_columns, + self.target_schema, + self.target_table, + self.target_columns, + onupdate, + ondelete, + deferrable, + initially, + ) = _fk_spec(const) + + self.sig: Tuple[Any, ...] = ( + self.source_schema, + self.source_table, + tuple(self.source_columns), + self.target_schema, + self.target_table, + tuple(self.target_columns), + ) + if include_options: + self.sig += ( + (None if onupdate.lower() == "no action" else onupdate.lower()) + if onupdate + else None, + (None if ondelete.lower() == "no action" else ondelete.lower()) + if ondelete + else None, + # convert initially + deferrable into one three-state value + "initially_deferrable" + if initially and initially.lower() == "deferred" + else "deferrable" + if deferrable + else "not deferrable", + ) @comparators.dispatch_for("table") @@ -523,34 +561,34 @@ def _compare_indexes_and_uniques( if conn_table is not None: # 1b. ... 
and from connection, if the table exists - try: - conn_uniques = _InspectorConv(inspector).get_unique_constraints( - tname, schema=schema - ) - - supports_unique_constraints = True - except NotImplementedError: - pass - except TypeError: - # number of arguments is off for the base - # method in SQLAlchemy due to the cache decorator - # not being present - pass - else: - conn_uniques = [ # type:ignore[assignment] - uq - for uq in conn_uniques - if autogen_context.run_name_filters( - uq["name"], - "unique_constraint", - {"table_name": tname, "schema_name": schema}, + if hasattr(inspector, "get_unique_constraints"): + try: + conn_uniques = inspector.get_unique_constraints( # type:ignore[assignment] # noqa + tname, schema=schema ) - ] - for uq in conn_uniques: - if uq.get("duplicates_index"): - unique_constraints_duplicate_unique_indexes = True + supports_unique_constraints = True + except NotImplementedError: + pass + except TypeError: + # number of arguments is off for the base + # method in SQLAlchemy due to the cache decorator + # not being present + pass + else: + conn_uniques = [ # type:ignore[assignment] + uq + for uq in conn_uniques + if autogen_context.run_name_filters( + uq["name"], + "unique_constraint", + {"table_name": tname, "schema_name": schema}, + ) + ] + for uq in conn_uniques: + if uq.get("duplicates_index"): + unique_constraints_duplicate_unique_indexes = True try: - conn_indexes = _InspectorConv(inspector).get_indexes( + conn_indexes = inspector.get_indexes( # type:ignore[assignment] tname, schema=schema ) except NotImplementedError: @@ -601,7 +639,7 @@ def _compare_indexes_and_uniques( # 3. give the dialect a chance to omit indexes and constraints that # we know are either added implicitly by the DB or that the DB # can't accurately report on - impl.correct_for_autogen_constraints( + autogen_context.migration_context.impl.correct_for_autogen_constraints( conn_uniques, # type: ignore[arg-type] conn_indexes, # type: ignore[arg-type] metadata_unique_constraints, @@ -613,31 +651,31 @@ def _compare_indexes_and_uniques( # Index and UniqueConstraint so we can easily work with them # interchangeably metadata_unique_constraints_sig = { - impl._create_metadata_constraint_sig(uq) - for uq in metadata_unique_constraints + _uq_constraint_sig(uq, impl) for uq in metadata_unique_constraints } metadata_indexes_sig = { - impl._create_metadata_constraint_sig(ix) for ix in metadata_indexes + _ix_constraint_sig(ix, impl) for ix in metadata_indexes } conn_unique_constraints = { - impl._create_reflected_constraint_sig(uq) for uq in conn_uniques + _uq_constraint_sig(uq, impl) for uq in conn_uniques } - conn_indexes_sig = { - impl._create_reflected_constraint_sig(ix) for ix in conn_indexes - } + conn_indexes_sig = {_ix_constraint_sig(ix, impl) for ix in conn_indexes} # 5. 
index things by name, for those objects that have names metadata_names = { cast(str, c.md_name_to_sql_name(autogen_context)): c - for c in metadata_unique_constraints_sig.union(metadata_indexes_sig) - if c.is_named + for c in metadata_unique_constraints_sig.union( + metadata_indexes_sig # type:ignore[arg-type] + ) + if isinstance(c, _ix_constraint_sig) + or sqla_compat._constraint_is_named(c.const, autogen_context.dialect) } - conn_uniques_by_name: Dict[sqla_compat._ConstraintName, _constraint_sig] - conn_indexes_by_name: Dict[sqla_compat._ConstraintName, _constraint_sig] + conn_uniques_by_name: Dict[sqla_compat._ConstraintName, _uq_constraint_sig] + conn_indexes_by_name: Dict[sqla_compat._ConstraintName, _ix_constraint_sig] conn_uniques_by_name = {c.name: c for c in conn_unique_constraints} conn_indexes_by_name = {c.name: c for c in conn_indexes_sig} @@ -656,12 +694,13 @@ def _compare_indexes_and_uniques( # 6. index things by "column signature", to help with unnamed unique # constraints. - conn_uniques_by_sig = {uq.unnamed: uq for uq in conn_unique_constraints} + conn_uniques_by_sig = {uq.sig: uq for uq in conn_unique_constraints} metadata_uniques_by_sig = { - uq.unnamed: uq for uq in metadata_unique_constraints_sig + uq.sig: uq for uq in metadata_unique_constraints_sig } + metadata_indexes_by_sig = {ix.sig: ix for ix in metadata_indexes_sig} unnamed_metadata_uniques = { - uq.unnamed: uq + uq.sig: uq for uq in metadata_unique_constraints_sig if not sqla_compat._constraint_is_named( uq.const, autogen_context.dialect @@ -676,18 +715,18 @@ def _compare_indexes_and_uniques( # 4. The backend may double up indexes as unique constraints and # vice versa (e.g. MySQL, Postgresql) - def obj_added(obj: _constraint_sig): - if is_index_sig(obj): + def obj_added(obj): + if obj.is_index: if autogen_context.run_object_filters( obj.const, obj.name, "index", False, None ): modify_ops.ops.append(ops.CreateIndexOp.from_index(obj.const)) log.info( - "Detected added index %r on '%s'", + "Detected added index '%s' on %s", obj.name, - obj.column_names, + ", ".join(["'%s'" % obj.column_names]), ) - elif is_uq_sig(obj): + else: if not supports_unique_constraints: # can't report unique indexes as added if we don't # detect them @@ -702,15 +741,13 @@ def _compare_indexes_and_uniques( ops.AddConstraintOp.from_constraint(obj.const) ) log.info( - "Detected added unique constraint %r on '%s'", + "Detected added unique constraint '%s' on %s", obj.name, - obj.column_names, + ", ".join(["'%s'" % obj.column_names]), ) - else: - assert False - def obj_removed(obj: _constraint_sig): - if is_index_sig(obj): + def obj_removed(obj): + if obj.is_index: if obj.is_unique and not supports_unique_constraints: # many databases double up unique constraints # as unique indexes. 
without that list we can't @@ -721,8 +758,10 @@ def _compare_indexes_and_uniques( obj.const, obj.name, "index", True, None ): modify_ops.ops.append(ops.DropIndexOp.from_index(obj.const)) - log.info("Detected removed index %r on %r", obj.name, tname) - elif is_uq_sig(obj): + log.info( + "Detected removed index '%s' on '%s'", obj.name, tname + ) + else: if is_create_table or is_drop_table: # if the whole table is being dropped, we don't need to # consider unique constraint separately @@ -734,40 +773,33 @@ def _compare_indexes_and_uniques( ops.DropConstraintOp.from_constraint(obj.const) ) log.info( - "Detected removed unique constraint %r on %r", + "Detected removed unique constraint '%s' on '%s'", obj.name, tname, ) - else: - assert False - - def obj_changed( - old: _constraint_sig, - new: _constraint_sig, - msg: str, - ): - if is_index_sig(old): - assert is_index_sig(new) + def obj_changed(old, new, msg): + if old.is_index: if autogen_context.run_object_filters( new.const, new.name, "index", False, old.const ): log.info( - "Detected changed index %r on %r: %s", old.name, tname, msg + "Detected changed index '%s' on '%s':%s", + old.name, + tname, + ", ".join(msg), ) modify_ops.ops.append(ops.DropIndexOp.from_index(old.const)) modify_ops.ops.append(ops.CreateIndexOp.from_index(new.const)) - elif is_uq_sig(old): - assert is_uq_sig(new) - + else: if autogen_context.run_object_filters( new.const, new.name, "unique_constraint", False, old.const ): log.info( - "Detected changed unique constraint %r on %r: %s", + "Detected changed unique constraint '%s' on '%s':%s", old.name, tname, - msg, + ", ".join(msg), ) modify_ops.ops.append( ops.DropConstraintOp.from_constraint(old.const) @@ -775,24 +807,18 @@ def _compare_indexes_and_uniques( modify_ops.ops.append( ops.AddConstraintOp.from_constraint(new.const) ) - else: - assert False for removed_name in sorted(set(conn_names).difference(metadata_names)): - conn_obj = conn_names[removed_name] - if ( - is_uq_sig(conn_obj) - and conn_obj.unnamed in unnamed_metadata_uniques - ): + conn_obj: Union[_ix_constraint_sig, _uq_constraint_sig] = conn_names[ + removed_name + ] + if not conn_obj.is_index and conn_obj.sig in unnamed_metadata_uniques: continue elif removed_name in doubled_constraints: conn_uq, conn_idx = doubled_constraints[removed_name] if ( - all( - conn_idx.unnamed != meta_idx.unnamed - for meta_idx in metadata_indexes_sig - ) - and conn_uq.unnamed not in metadata_uniques_by_sig + conn_idx.sig not in metadata_indexes_by_sig + and conn_uq.sig not in metadata_uniques_by_sig ): obj_removed(conn_uq) obj_removed(conn_idx) @@ -804,36 +830,30 @@ def _compare_indexes_and_uniques( if existing_name in doubled_constraints: conn_uq, conn_idx = doubled_constraints[existing_name] - if is_index_sig(metadata_obj): + if metadata_obj.is_index: conn_obj = conn_idx else: conn_obj = conn_uq else: conn_obj = conn_names[existing_name] - if type(conn_obj) != type(metadata_obj): + if conn_obj.is_index != metadata_obj.is_index: obj_removed(conn_obj) obj_added(metadata_obj) else: - comparison = metadata_obj.compare_to_reflected(conn_obj) + msg = [] + if conn_obj.is_unique != metadata_obj.is_unique: + msg.append( + " unique=%r to unique=%r" + % (conn_obj.is_unique, metadata_obj.is_unique) + ) + if conn_obj.sig != metadata_obj.sig: + msg.append( + " expression %r to %r" % (conn_obj.sig, metadata_obj.sig) + ) - if comparison.is_different: - # constraint are different - obj_changed(conn_obj, metadata_obj, comparison.message) - elif comparison.is_skip: - # constraint cannot be 
compared, skip them - thing = ( - "index" if is_index_sig(conn_obj) else "unique constraint" - ) - log.info( - "Cannot compare %s %r, assuming equal and skipping. %s", - thing, - conn_obj.name, - comparison.message, - ) - else: - # constraint are equal - assert comparison.is_equal + if msg: + obj_changed(conn_obj, metadata_obj, msg) for added_name in sorted(set(metadata_names).difference(conn_names)): obj = metadata_names[added_name] @@ -873,7 +893,7 @@ def _correct_for_uq_duplicates_uix( } unnamed_metadata_uqs = { - impl._create_metadata_constraint_sig(cons).unnamed + _uq_constraint_sig(cons, impl).sig for name, cons in metadata_cons_names if name is None } @@ -897,9 +917,7 @@ def _correct_for_uq_duplicates_uix( for overlap in uqs_dupe_indexes: if overlap not in metadata_uq_names: if ( - impl._create_reflected_constraint_sig( - uqs_dupe_indexes[overlap] - ).unnamed + _uq_constraint_sig(uqs_dupe_indexes[overlap], impl).sig not in unnamed_metadata_uqs ): conn_unique_constraints.discard(uqs_dupe_indexes[overlap]) @@ -1035,7 +1053,7 @@ def _normalize_computed_default(sqltext: str) -> str: """ - return re.sub(r"[ \(\)'\"`\[\]\t\r\n]", "", sqltext).lower() + return re.sub(r"[ \(\)'\"`\[\]]", "", sqltext).lower() def _compare_computed_default( @@ -1119,15 +1137,27 @@ def _compare_server_default( return False if sqla_compat._server_default_is_computed(metadata_default): - return _compare_computed_default( # type:ignore[func-returns-value] - autogen_context, - alter_column_op, - schema, - tname, - cname, - conn_col, - metadata_col, - ) + # return False in case of a computed column as the server + # default. Note that DDL for adding or removing "GENERATED AS" from + # an existing column is not currently known for any backend. + # Once SQLAlchemy can reflect "GENERATED" as the "computed" element, + # we would also want to ignore and/or warn for changes vs. the + # metadata (or support backend specific DDL if applicable). 
+ if not sqla_compat.has_computed_reflection: + return False + + else: + return ( + _compare_computed_default( # type:ignore[func-returns-value] + autogen_context, + alter_column_op, + schema, + tname, + cname, + conn_col, + metadata_col, + ) + ) if sqla_compat._server_default_is_computed(conn_col_default): _warn_computed_not_supported(tname, cname) return False @@ -1213,8 +1243,8 @@ def _compare_foreign_keys( modify_table_ops: ModifyTableOps, schema: Optional[str], tname: Union[quoted_name, str], - conn_table: Table, - metadata_table: Table, + conn_table: Optional[Table], + metadata_table: Optional[Table], ) -> None: # if we're doing CREATE TABLE, all FKs are created # inline within the table def @@ -1230,9 +1260,7 @@ def _compare_foreign_keys( conn_fks_list = [ fk - for fk in _InspectorConv(inspector).get_foreign_keys( - tname, schema=schema - ) + for fk in inspector.get_foreign_keys(tname, schema=schema) if autogen_context.run_name_filters( fk["name"], "foreign_key_constraint", @@ -1240,11 +1268,14 @@ def _compare_foreign_keys( ) ] - conn_fks = { - _make_foreign_key(const, conn_table) for const in conn_fks_list - } + backend_reflects_fk_options = bool( + conn_fks_list and "options" in conn_fks_list[0] + ) - impl = autogen_context.migration_context.impl + conn_fks = { + _make_foreign_key(const, conn_table) # type: ignore[arg-type] + for const in conn_fks_list + } # give the dialect a chance to correct the FKs to match more # closely @@ -1253,24 +1284,17 @@ def _compare_foreign_keys( ) metadata_fks_sig = { - impl._create_metadata_constraint_sig(fk) for fk in metadata_fks + _fk_constraint_sig(fk, include_options=backend_reflects_fk_options) + for fk in metadata_fks } conn_fks_sig = { - impl._create_reflected_constraint_sig(fk) for fk in conn_fks + _fk_constraint_sig(fk, include_options=backend_reflects_fk_options) + for fk in conn_fks } - # check if reflected FKs include options, indicating the backend - # can reflect FK options - if conn_fks_list and "options" in conn_fks_list[0]: - conn_fks_by_sig = {c.unnamed: c for c in conn_fks_sig} - metadata_fks_by_sig = {c.unnamed: c for c in metadata_fks_sig} - else: - # otherwise compare by sig without options added - conn_fks_by_sig = {c.unnamed_no_options: c for c in conn_fks_sig} - metadata_fks_by_sig = { - c.unnamed_no_options: c for c in metadata_fks_sig - } + conn_fks_by_sig = {c.sig: c for c in conn_fks_sig} + metadata_fks_by_sig = {c.sig: c for c in metadata_fks_sig} metadata_fks_by_name = { c.name: c for c in metadata_fks_sig if c.name is not None diff --git a/venv/lib/python3.12/site-packages/alembic/autogenerate/render.py b/venv/lib/python3.12/site-packages/alembic/autogenerate/render.py index 7f32838..9c84cd6 100644 --- a/venv/lib/python3.12/site-packages/alembic/autogenerate/render.py +++ b/venv/lib/python3.12/site-packages/alembic/autogenerate/render.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls -# mypy: no-warn-return-any, allow-any-generics - from __future__ import annotations from io import StringIO @@ -18,9 +15,7 @@ from mako.pygen import PythonPrinter from sqlalchemy import schema as sa_schema from sqlalchemy import sql from sqlalchemy import types as sqltypes -from sqlalchemy.sql.base import _DialectArgView from sqlalchemy.sql.elements import conv -from sqlalchemy.sql.elements import Label from sqlalchemy.sql.elements import quoted_name from .. 
import util @@ -30,8 +25,7 @@ from ..util import sqla_compat if TYPE_CHECKING: from typing import Literal - from sqlalchemy import Computed - from sqlalchemy import Identity + from sqlalchemy.sql.base import DialectKWArgs from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.schema import CheckConstraint @@ -51,6 +45,8 @@ if TYPE_CHECKING: from alembic.config import Config from alembic.operations.ops import MigrationScript from alembic.operations.ops import ModifyTableOps + from alembic.util.sqla_compat import Computed + from alembic.util.sqla_compat import Identity MAX_PYTHON_ARGS = 255 @@ -168,31 +164,21 @@ def _render_modify_table( def _render_create_table_comment( autogen_context: AutogenContext, op: ops.CreateTableCommentOp ) -> str: - if autogen_context._has_batch: - templ = ( - "{prefix}create_table_comment(\n" - "{indent}{comment},\n" - "{indent}existing_comment={existing}\n" - ")" - ) - else: - templ = ( - "{prefix}create_table_comment(\n" - "{indent}'{tname}',\n" - "{indent}{comment},\n" - "{indent}existing_comment={existing},\n" - "{indent}schema={schema}\n" - ")" - ) + templ = ( + "{prefix}create_table_comment(\n" + "{indent}'{tname}',\n" + "{indent}{comment},\n" + "{indent}existing_comment={existing},\n" + "{indent}schema={schema}\n" + ")" + ) return templ.format( prefix=_alembic_autogenerate_prefix(autogen_context), tname=op.table_name, comment="%r" % op.comment if op.comment is not None else None, - existing=( - "%r" % op.existing_comment - if op.existing_comment is not None - else None - ), + existing="%r" % op.existing_comment + if op.existing_comment is not None + else None, schema="'%s'" % op.schema if op.schema is not None else None, indent=" ", ) @@ -202,28 +188,19 @@ def _render_create_table_comment( def _render_drop_table_comment( autogen_context: AutogenContext, op: ops.DropTableCommentOp ) -> str: - if autogen_context._has_batch: - templ = ( - "{prefix}drop_table_comment(\n" - "{indent}existing_comment={existing}\n" - ")" - ) - else: - templ = ( - "{prefix}drop_table_comment(\n" - "{indent}'{tname}',\n" - "{indent}existing_comment={existing},\n" - "{indent}schema={schema}\n" - ")" - ) + templ = ( + "{prefix}drop_table_comment(\n" + "{indent}'{tname}',\n" + "{indent}existing_comment={existing},\n" + "{indent}schema={schema}\n" + ")" + ) return templ.format( prefix=_alembic_autogenerate_prefix(autogen_context), tname=op.table_name, - existing=( - "%r" % op.existing_comment - if op.existing_comment is not None - else None - ), + existing="%r" % op.existing_comment + if op.existing_comment is not None + else None, schema="'%s'" % op.schema if op.schema is not None else None, indent=" ", ) @@ -280,9 +257,6 @@ def _add_table(autogen_context: AutogenContext, op: ops.CreateTableOp) -> str: prefixes = ", ".join("'%s'" % p for p in table._prefixes) text += ",\nprefixes=[%s]" % prefixes - if op.if_not_exists is not None: - text += ",\nif_not_exists=%r" % bool(op.if_not_exists) - text += "\n)" return text @@ -295,20 +269,16 @@ def _drop_table(autogen_context: AutogenContext, op: ops.DropTableOp) -> str: } if op.schema: text += ", schema=%r" % _ident(op.schema) - - if op.if_exists is not None: - text += ", if_exists=%r" % bool(op.if_exists) - text += ")" return text def _render_dialect_kwargs_items( - autogen_context: AutogenContext, dialect_kwargs: _DialectArgView + autogen_context: AutogenContext, item: DialectKWArgs ) -> list[str]: return [ f"{key}={_render_potential_expr(val, autogen_context)}" - for key, val in 
dialect_kwargs.items() + for key, val in item.dialect_kwargs.items() ] @@ -331,9 +301,7 @@ def _add_index(autogen_context: AutogenContext, op: ops.CreateIndexOp) -> str: assert index.table is not None - opts = _render_dialect_kwargs_items(autogen_context, index.dialect_kwargs) - if op.if_not_exists is not None: - opts.append("if_not_exists=%r" % bool(op.if_not_exists)) + opts = _render_dialect_kwargs_items(autogen_context, index) text = tmpl % { "prefix": _alembic_autogenerate_prefix(autogen_context), "name": _render_gen_name(autogen_context, index.name), @@ -342,11 +310,9 @@ def _add_index(autogen_context: AutogenContext, op: ops.CreateIndexOp) -> str: _get_index_rendered_expressions(index, autogen_context) ), "unique": index.unique or False, - "schema": ( - (", schema=%r" % _ident(index.table.schema)) - if index.table.schema - else "" - ), + "schema": (", schema=%r" % _ident(index.table.schema)) + if index.table.schema + else "", "kwargs": ", " + ", ".join(opts) if opts else "", } return text @@ -365,9 +331,7 @@ def _drop_index(autogen_context: AutogenContext, op: ops.DropIndexOp) -> str: "%(prefix)sdrop_index(%(name)r, " "table_name=%(table_name)r%(schema)s%(kwargs)s)" ) - opts = _render_dialect_kwargs_items(autogen_context, index.dialect_kwargs) - if op.if_exists is not None: - opts.append("if_exists=%r" % bool(op.if_exists)) + opts = _render_dialect_kwargs_items(autogen_context, index) text = tmpl % { "prefix": _alembic_autogenerate_prefix(autogen_context), "name": _render_gen_name(autogen_context, op.index_name), @@ -389,7 +353,6 @@ def _add_unique_constraint( def _add_fk_constraint( autogen_context: AutogenContext, op: ops.CreateForeignKeyOp ) -> str: - constraint = op.to_constraint() args = [repr(_render_gen_name(autogen_context, op.constraint_name))] if not autogen_context._has_batch: args.append(repr(_ident(op.source_table))) @@ -419,16 +382,9 @@ def _add_fk_constraint( if value is not None: args.append("%s=%r" % (k, value)) - dialect_kwargs = _render_dialect_kwargs_items( - autogen_context, constraint.dialect_kwargs - ) - - return "%(prefix)screate_foreign_key(%(args)s%(dialect_kwargs)s)" % { + return "%(prefix)screate_foreign_key(%(args)s)" % { "prefix": _alembic_autogenerate_prefix(autogen_context), "args": ", ".join(args), - "dialect_kwargs": ( - ", " + ", ".join(dialect_kwargs) if dialect_kwargs else "" - ), } @@ -450,7 +406,7 @@ def _drop_constraint( name = _render_gen_name(autogen_context, op.constraint_name) schema = _ident(op.schema) if op.schema else None type_ = _ident(op.constraint_type) if op.constraint_type else None - if_exists = op.if_exists + params_strs = [] params_strs.append(repr(name)) if not autogen_context._has_batch: @@ -459,47 +415,32 @@ def _drop_constraint( params_strs.append(f"schema={schema!r}") if type_ is not None: params_strs.append(f"type_={type_!r}") - if if_exists is not None: - params_strs.append(f"if_exists={if_exists}") return f"{prefix}drop_constraint({', '.join(params_strs)})" @renderers.dispatch_for(ops.AddColumnOp) def _add_column(autogen_context: AutogenContext, op: ops.AddColumnOp) -> str: - schema, tname, column, if_not_exists = ( - op.schema, - op.table_name, - op.column, - op.if_not_exists, - ) + schema, tname, column = op.schema, op.table_name, op.column if autogen_context._has_batch: template = "%(prefix)sadd_column(%(column)s)" else: template = "%(prefix)sadd_column(%(tname)r, %(column)s" if schema: template += ", schema=%(schema)r" - if if_not_exists is not None: - template += ", if_not_exists=%(if_not_exists)r" template += ")" 
text = template % { "prefix": _alembic_autogenerate_prefix(autogen_context), "tname": tname, "column": _render_column(column, autogen_context), "schema": schema, - "if_not_exists": if_not_exists, } return text @renderers.dispatch_for(ops.DropColumnOp) def _drop_column(autogen_context: AutogenContext, op: ops.DropColumnOp) -> str: - schema, tname, column_name, if_exists = ( - op.schema, - op.table_name, - op.column_name, - op.if_exists, - ) + schema, tname, column_name = op.schema, op.table_name, op.column_name if autogen_context._has_batch: template = "%(prefix)sdrop_column(%(cname)r)" @@ -507,8 +448,6 @@ def _drop_column(autogen_context: AutogenContext, op: ops.DropColumnOp) -> str: template = "%(prefix)sdrop_column(%(tname)r, %(cname)r" if schema: template += ", schema=%(schema)r" - if if_exists is not None: - template += ", if_exists=%(if_exists)r" template += ")" text = template % { @@ -516,7 +455,6 @@ def _drop_column(autogen_context: AutogenContext, op: ops.DropColumnOp) -> str: "tname": _ident(tname), "cname": _ident(column_name), "schema": _ident(schema), - "if_exists": if_exists, } return text @@ -531,7 +469,6 @@ def _alter_column( type_ = op.modify_type nullable = op.modify_nullable comment = op.modify_comment - newname = op.modify_name autoincrement = op.kw.get("autoincrement", None) existing_type = op.existing_type existing_nullable = op.existing_nullable @@ -560,8 +497,6 @@ def _alter_column( rendered = _render_server_default(server_default, autogen_context) text += ",\n%sserver_default=%s" % (indent, rendered) - if newname is not None: - text += ",\n%snew_column_name=%r" % (indent, newname) if type_ is not None: text += ",\n%stype_=%s" % (indent, _repr_type(type_, autogen_context)) if nullable is not None: @@ -614,28 +549,23 @@ def _render_potential_expr( value: Any, autogen_context: AutogenContext, *, - wrap_in_element: bool = True, + wrap_in_text: bool = True, is_server_default: bool = False, is_index: bool = False, ) -> str: if isinstance(value, sql.ClauseElement): - sql_text = autogen_context.migration_context.impl.render_ddl_sql_expr( - value, is_server_default=is_server_default, is_index=is_index - ) - if wrap_in_element: - prefix = _sqlalchemy_autogenerate_prefix(autogen_context) - element = "literal_column" if is_index else "text" - value_str = f"{prefix}{element}({sql_text!r})" - if ( - is_index - and isinstance(value, Label) - and type(value.name) is str - ): - return value_str + f".label({value.name!r})" - else: - return value_str + if wrap_in_text: + template = "%(prefix)stext(%(sql)r)" else: - return repr(sql_text) + template = "%(sql)r" + + return template % { + "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), + "sql": autogen_context.migration_context.impl.render_ddl_sql_expr( + value, is_server_default=is_server_default, is_index=is_index + ), + } + else: return repr(value) @@ -644,11 +574,9 @@ def _get_index_rendered_expressions( idx: Index, autogen_context: AutogenContext ) -> List[str]: return [ - ( - repr(_ident(getattr(exp, "name", None))) - if isinstance(exp, sa_schema.Column) - else _render_potential_expr(exp, autogen_context, is_index=True) - ) + repr(_ident(getattr(exp, "name", None))) + if isinstance(exp, sa_schema.Column) + else _render_potential_expr(exp, autogen_context, is_index=True) for exp in idx.expressions ] @@ -663,18 +591,16 @@ def _uq_constraint( has_batch = autogen_context._has_batch if constraint.deferrable: - opts.append(("deferrable", constraint.deferrable)) + opts.append(("deferrable", str(constraint.deferrable))) if 
constraint.initially: - opts.append(("initially", constraint.initially)) + opts.append(("initially", str(constraint.initially))) if not has_batch and alter and constraint.table.schema: opts.append(("schema", _ident(constraint.table.schema))) if not alter and constraint.name: opts.append( ("name", _render_gen_name(autogen_context, constraint.name)) ) - dialect_options = _render_dialect_kwargs_items( - autogen_context, constraint.dialect_kwargs - ) + dialect_options = _render_dialect_kwargs_items(autogen_context, constraint) if alter: args = [repr(_render_gen_name(autogen_context, constraint.name))] @@ -778,7 +704,7 @@ def _render_column( + [ "%s=%s" % (key, _render_potential_expr(val, autogen_context)) - for key, val in column.kwargs.items() + for key, val in sqla_compat._column_kwargs(column).items() ] ) ), @@ -813,8 +739,6 @@ def _render_server_default( return _render_potential_expr( default.arg, autogen_context, is_server_default=True ) - elif isinstance(default, sa_schema.FetchedValue): - return _render_fetched_value(autogen_context) if isinstance(default, str) and repr_: default = repr(re.sub(r"^'|'$", "", default)) @@ -826,7 +750,7 @@ def _render_computed( computed: Computed, autogen_context: AutogenContext ) -> str: text = _render_potential_expr( - computed.sqltext, autogen_context, wrap_in_element=False + computed.sqltext, autogen_context, wrap_in_text=False ) kwargs = {} @@ -852,12 +776,6 @@ def _render_identity( } -def _render_fetched_value(autogen_context: AutogenContext) -> str: - return "%(prefix)sFetchedValue()" % { - "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), - } - - def _repr_type( type_: TypeEngine, autogen_context: AutogenContext, @@ -876,10 +794,7 @@ def _repr_type( mod = type(type_).__module__ imports = autogen_context.imports - - if not _skip_variants and sqla_compat._type_has_variants(type_): - return _render_Variant_type(type_, autogen_context) - elif mod.startswith("sqlalchemy.dialects"): + if mod.startswith("sqlalchemy.dialects"): match = re.match(r"sqlalchemy\.dialects\.(\w+)", mod) assert match is not None dname = match.group(1) @@ -891,6 +806,8 @@ def _repr_type( return "%s.%r" % (dname, type_) elif impl_rt: return impl_rt + elif not _skip_variants and sqla_compat._type_has_variants(type_): + return _render_Variant_type(type_, autogen_context) elif mod.startswith("sqlalchemy."): if "_render_%s_type" % type_.__visit_name__ in globals(): fn = globals()["_render_%s_type" % type_.__visit_name__] @@ -917,7 +834,7 @@ def _render_Variant_type( ) -> str: base_type, variant_mapping = sqla_compat._get_variant_mapping(type_) base = _repr_type(base_type, autogen_context, _skip_variants=True) - assert base is not None and base is not False # type: ignore[comparison-overlap] # noqa:E501 + assert base is not None and base is not False for dialect in sorted(variant_mapping): typ = variant_mapping[dialect] base += ".with_variant(%s, %r)" % ( @@ -1008,13 +925,13 @@ def _render_primary_key( def _fk_colspec( fk: ForeignKey, metadata_schema: Optional[str], - namespace_metadata: Optional[MetaData], + namespace_metadata: MetaData, ) -> str: """Implement a 'safe' version of ForeignKey._get_colspec() that won't fail if the remote table can't be resolved. """ - colspec = fk._get_colspec() + colspec = fk._get_colspec() # type:ignore[attr-defined] tokens = colspec.split(".") tname, colname = tokens[-2:] @@ -1032,10 +949,7 @@ def _fk_colspec( # the FK constraint needs to be rendered in terms of the column # name. 
- if ( - namespace_metadata is not None - and table_fullname in namespace_metadata.tables - ): + if table_fullname in namespace_metadata.tables: col = namespace_metadata.tables[table_fullname].c.get(colname) if col is not None: colname = _ident(col.name) # type: ignore[assignment] @@ -1066,7 +980,7 @@ def _populate_render_fk_opts( def _render_foreign_key( constraint: ForeignKeyConstraint, autogen_context: AutogenContext, - namespace_metadata: Optional[MetaData], + namespace_metadata: MetaData, ) -> Optional[str]: rendered = _user_defined_render("foreign_key", constraint, autogen_context) if rendered is not False: @@ -1080,16 +994,15 @@ def _render_foreign_key( _populate_render_fk_opts(constraint, opts) - apply_metadata_schema = ( - namespace_metadata.schema if namespace_metadata is not None else None - ) + apply_metadata_schema = namespace_metadata.schema return ( "%(prefix)sForeignKeyConstraint([%(cols)s], " "[%(refcols)s], %(args)s)" % { "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), "cols": ", ".join( - repr(_ident(f.parent.name)) for f in constraint.elements + "%r" % _ident(cast("Column", f.parent).name) + for f in constraint.elements ), "refcols": ", ".join( repr(_fk_colspec(f, apply_metadata_schema, namespace_metadata)) @@ -1130,10 +1043,12 @@ def _render_check_constraint( # ideally SQLAlchemy would give us more of a first class # way to detect this. if ( - constraint._create_rule - and hasattr(constraint._create_rule, "target") + constraint._create_rule # type:ignore[attr-defined] + and hasattr( + constraint._create_rule, "target" # type:ignore[attr-defined] + ) and isinstance( - constraint._create_rule.target, + constraint._create_rule.target, # type:ignore[attr-defined] sqltypes.TypeEngine, ) ): @@ -1145,13 +1060,11 @@ def _render_check_constraint( ) return "%(prefix)sCheckConstraint(%(sqltext)s%(opts)s)" % { "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), - "opts": ( - ", " + (", ".join("%s=%s" % (k, v) for k, v in opts)) - if opts - else "" - ), + "opts": ", " + (", ".join("%s=%s" % (k, v) for k, v in opts)) + if opts + else "", "sqltext": _render_potential_expr( - constraint.sqltext, autogen_context, wrap_in_element=False + constraint.sqltext, autogen_context, wrap_in_text=False ), } @@ -1163,10 +1076,7 @@ def _execute_sql(autogen_context: AutogenContext, op: ops.ExecuteSQLOp) -> str: "Autogenerate rendering of SQL Expression language constructs " "not supported here; please use a plain SQL string" ) - return "{prefix}execute({sqltext!r})".format( - prefix=_alembic_autogenerate_prefix(autogen_context), - sqltext=op.sqltext, - ) + return "op.execute(%r)" % op.sqltext renderers = default_renderers.branch() diff --git a/venv/lib/python3.12/site-packages/alembic/autogenerate/rewriter.py b/venv/lib/python3.12/site-packages/alembic/autogenerate/rewriter.py index 1d44b5c..68a93dd 100644 --- a/venv/lib/python3.12/site-packages/alembic/autogenerate/rewriter.py +++ b/venv/lib/python3.12/site-packages/alembic/autogenerate/rewriter.py @@ -4,7 +4,7 @@ from typing import Any from typing import Callable from typing import Iterator from typing import List -from typing import Tuple +from typing import Optional from typing import Type from typing import TYPE_CHECKING from typing import Union @@ -16,18 +16,12 @@ if TYPE_CHECKING: from ..operations.ops import AddColumnOp from ..operations.ops import AlterColumnOp from ..operations.ops import CreateTableOp - from ..operations.ops import DowngradeOps from ..operations.ops import MigrateOperation from ..operations.ops import 
MigrationScript from ..operations.ops import ModifyTableOps from ..operations.ops import OpContainer - from ..operations.ops import UpgradeOps + from ..runtime.environment import _GetRevArg from ..runtime.migration import MigrationContext - from ..script.revision import _GetRevArg - -ProcessRevisionDirectiveFn = Callable[ - ["MigrationContext", "_GetRevArg", List["MigrationScript"]], None -] class Rewriter: @@ -58,21 +52,15 @@ class Rewriter: _traverse = util.Dispatcher() - _chained: Tuple[Union[ProcessRevisionDirectiveFn, Rewriter], ...] = () + _chained: Optional[Rewriter] = None def __init__(self) -> None: self.dispatch = util.Dispatcher() - def chain( - self, - other: Union[ - ProcessRevisionDirectiveFn, - Rewriter, - ], - ) -> Rewriter: + def chain(self, other: Rewriter) -> Rewriter: """Produce a "chain" of this :class:`.Rewriter` to another. - This allows two or more rewriters to operate serially on a stream, + This allows two rewriters to operate serially on a stream, e.g.:: writer1 = autogenerate.Rewriter() @@ -101,7 +89,7 @@ class Rewriter: """ wr = self.__class__.__new__(self.__class__) wr.__dict__.update(self.__dict__) - wr._chained += (other,) + wr._chained = other return wr def rewrites( @@ -113,7 +101,7 @@ class Rewriter: Type[CreateTableOp], Type[ModifyTableOps], ], - ) -> Callable[..., Any]: + ) -> Callable: """Register a function as rewriter for a given type. The function should receive three arguments, which are @@ -158,8 +146,8 @@ class Rewriter: directives: List[MigrationScript], ) -> None: self.process_revision_directives(context, revision, directives) - for process_revision_directives in self._chained: - process_revision_directives(context, revision, directives) + if self._chained: + self._chained(context, revision, directives) @_traverse.dispatch_for(ops.MigrationScript) def _traverse_script( @@ -168,7 +156,7 @@ class Rewriter: revision: _GetRevArg, directive: MigrationScript, ) -> None: - upgrade_ops_list: List[UpgradeOps] = [] + upgrade_ops_list = [] for upgrade_ops in directive.upgrade_ops_list: ret = self._traverse_for(context, revision, upgrade_ops) if len(ret) != 1: @@ -176,10 +164,9 @@ class Rewriter: "Can only return single object for UpgradeOps traverse" ) upgrade_ops_list.append(ret[0]) - directive.upgrade_ops = upgrade_ops_list - downgrade_ops_list: List[DowngradeOps] = [] + downgrade_ops_list = [] for downgrade_ops in directive.downgrade_ops_list: ret = self._traverse_for(context, revision, downgrade_ops) if len(ret) != 1: diff --git a/venv/lib/python3.12/site-packages/alembic/command.py b/venv/lib/python3.12/site-packages/alembic/command.py index 8e48547..dbaa9cf 100644 --- a/venv/lib/python3.12/site-packages/alembic/command.py +++ b/venv/lib/python3.12/site-packages/alembic/command.py @@ -1,9 +1,6 @@ -# mypy: allow-untyped-defs, allow-untyped-calls - from __future__ import annotations import os -import pathlib from typing import List from typing import Optional from typing import TYPE_CHECKING @@ -13,7 +10,6 @@ from . import autogenerate as autogen from . import util from .runtime.environment import EnvironmentContext from .script import ScriptDirectory -from .util import compat if TYPE_CHECKING: from alembic.config import Config @@ -22,7 +18,7 @@ if TYPE_CHECKING: from .runtime.environment import ProcessRevisionDirectiveFn -def list_templates(config: Config) -> None: +def list_templates(config: Config): """List available templates. :param config: a :class:`.Config` object. 
@@ -30,10 +26,12 @@ def list_templates(config: Config) -> None: """ config.print_stdout("Available templates:\n") - for tempname in config._get_template_path().iterdir(): - with (tempname / "README").open() as readme: + for tempname in os.listdir(config.get_template_directory()): + with open( + os.path.join(config.get_template_directory(), tempname, "README") + ) as readme: synopsis = next(readme).rstrip() - config.print_stdout("%s - %s", tempname.name, synopsis) + config.print_stdout("%s - %s", tempname, synopsis) config.print_stdout("\nTemplates are used via the 'init' command, e.g.:") config.print_stdout("\n alembic init --template generic ./scripts") @@ -49,7 +47,7 @@ def init( :param config: a :class:`.Config` object. - :param directory: string path of the target directory. + :param directory: string path of the target directory :param template: string name of the migration environment template to use. @@ -59,136 +57,65 @@ def init( """ - directory_path = pathlib.Path(directory) - if directory_path.exists() and list(directory_path.iterdir()): + if os.access(directory, os.F_OK) and os.listdir(directory): raise util.CommandError( - "Directory %s already exists and is not empty" % directory_path + "Directory %s already exists and is not empty" % directory ) - template_path = config._get_template_path() / template + template_dir = os.path.join(config.get_template_directory(), template) + if not os.access(template_dir, os.F_OK): + raise util.CommandError("No such template %r" % template) - if not template_path.exists(): - raise util.CommandError(f"No such template {template_path}") - - # left as os.access() to suit unit test mocking - if not os.access(directory_path, os.F_OK): + if not os.access(directory, os.F_OK): with util.status( - f"Creating directory {directory_path.absolute()}", + f"Creating directory {os.path.abspath(directory)!r}", **config.messaging_opts, ): - os.makedirs(directory_path) + os.makedirs(directory) - versions = directory_path / "versions" + versions = os.path.join(directory, "versions") with util.status( - f"Creating directory {versions.absolute()}", + f"Creating directory {os.path.abspath(versions)!r}", **config.messaging_opts, ): os.makedirs(versions) - if not directory_path.is_absolute(): - # for non-absolute path, state config file in .ini / pyproject - # as relative to the %(here)s token, which is where the config - # file itself would be + script = ScriptDirectory(directory) - if config._config_file_path is not None: - rel_dir = compat.path_relative_to( - directory_path.absolute(), - config._config_file_path.absolute().parent, - walk_up=True, - ) - ini_script_location_directory = ("%(here)s" / rel_dir).as_posix() - if config._toml_file_path is not None: - rel_dir = compat.path_relative_to( - directory_path.absolute(), - config._toml_file_path.absolute().parent, - walk_up=True, - ) - toml_script_location_directory = ("%(here)s" / rel_dir).as_posix() - - else: - ini_script_location_directory = directory_path.as_posix() - toml_script_location_directory = directory_path.as_posix() - - script = ScriptDirectory(directory_path) - - has_toml = False - - config_file: pathlib.Path | None = None - - for file_path in template_path.iterdir(): - file_ = file_path.name + config_file: str | None = None + for file_ in os.listdir(template_dir): + file_path = os.path.join(template_dir, file_) if file_ == "alembic.ini.mako": assert config.config_file_name is not None - config_file = pathlib.Path(config.config_file_name).absolute() - if config_file.exists(): + config_file = 
os.path.abspath(config.config_file_name) + if os.access(config_file, os.F_OK): util.msg( - f"File {config_file} already exists, skipping", + f"File {config_file!r} already exists, skipping", **config.messaging_opts, ) else: script._generate_template( - file_path, - config_file, - script_location=ini_script_location_directory, + file_path, config_file, script_location=directory ) - elif file_ == "pyproject.toml.mako": - has_toml = True - assert config._toml_file_path is not None - toml_path = config._toml_file_path.absolute() - - if toml_path.exists(): - # left as open() to suit unit test mocking - with open(toml_path, "rb") as f: - toml_data = compat.tomllib.load(f) - if "tool" in toml_data and "alembic" in toml_data["tool"]: - - util.msg( - f"File {toml_path} already exists " - "and already has a [tool.alembic] section, " - "skipping", - ) - continue - script._append_template( - file_path, - toml_path, - script_location=toml_script_location_directory, - ) - else: - script._generate_template( - file_path, - toml_path, - script_location=toml_script_location_directory, - ) - - elif file_path.is_file(): - output_file = directory_path / file_ + elif os.path.isfile(file_path): + output_file = os.path.join(directory, file_) script._copy_file(file_path, output_file) if package: for path in [ - directory_path.absolute() / "__init__.py", - versions.absolute() / "__init__.py", + os.path.join(os.path.abspath(directory), "__init__.py"), + os.path.join(os.path.abspath(versions), "__init__.py"), ]: - with util.status(f"Adding {path!s}", **config.messaging_opts): - # left as open() to suit unit test mocking + with util.status(f"Adding {path!r}", **config.messaging_opts): with open(path, "w"): pass assert config_file is not None - - if has_toml: - util.msg( - f"Please edit configuration settings in {toml_path} and " - "configuration/connection/logging " - f"settings in {config_file} before proceeding.", - **config.messaging_opts, - ) - else: - util.msg( - "Please edit configuration/connection/logging " - f"settings in {config_file} before proceeding.", - **config.messaging_opts, - ) + util.msg( + "Please edit configuration/connection/logging " + f"settings in {config_file!r} before proceeding.", + **config.messaging_opts, + ) def revision( @@ -199,7 +126,7 @@ def revision( head: str = "head", splice: bool = False, branch_label: Optional[_RevIdType] = None, - version_path: Union[str, os.PathLike[str], None] = None, + version_path: Optional[str] = None, rev_id: Optional[str] = None, depends_on: Optional[str] = None, process_revision_directives: Optional[ProcessRevisionDirectiveFn] = None, @@ -245,7 +172,7 @@ def revision( will be applied to the structure generated by the revision process where it can be altered programmatically. Note that unlike all the other parameters, this option is only available via programmatic - use of :func:`.command.revision`. + use of :func:`.command.revision` """ @@ -269,9 +196,7 @@ def revision( process_revision_directives=process_revision_directives, ) - environment = util.asbool( - config.get_alembic_option("revision_environment") - ) + environment = util.asbool(config.get_main_option("revision_environment")) if autogenerate: environment = True @@ -365,15 +290,10 @@ def check(config: "Config") -> None: # the revision_context now has MigrationScript structure(s) present. 
migration_script = revision_context.generated_revisions[-1] - diffs = [] - for upgrade_ops in migration_script.upgrade_ops_list: - diffs.extend(upgrade_ops.as_diffs()) - + diffs = migration_script.upgrade_ops.as_diffs() if diffs: raise util.AutogenerateDiffsDetected( - f"New upgrade operations detected: {diffs}", - revision_context=revision_context, - diffs=diffs, + f"New upgrade operations detected: {diffs}" ) else: config.print_stdout("No new upgrade operations detected.") @@ -390,11 +310,9 @@ def merge( :param config: a :class:`.Config` instance - :param revisions: The revisions to merge. + :param message: string message to apply to the revision - :param message: string message to apply to the revision. - - :param branch_label: string label name to apply to the new revision. + :param branch_label: string label name to apply to the new revision :param rev_id: hardcoded revision identifier instead of generating a new one. @@ -411,9 +329,7 @@ def merge( # e.g. multiple databases } - environment = util.asbool( - config.get_alembic_option("revision_environment") - ) + environment = util.asbool(config.get_main_option("revision_environment")) if environment: @@ -449,10 +365,9 @@ def upgrade( :param config: a :class:`.Config` instance. - :param revision: string revision target or range for --sql mode. May be - ``"heads"`` to target the most recent revision(s). + :param revision: string revision target or range for --sql mode - :param sql: if True, use ``--sql`` mode. + :param sql: if True, use ``--sql`` mode :param tag: an arbitrary "tag" that can be intercepted by custom ``env.py`` scripts via the :meth:`.EnvironmentContext.get_tag_argument` @@ -493,10 +408,9 @@ def downgrade( :param config: a :class:`.Config` instance. - :param revision: string revision target or range for --sql mode. May - be ``"base"`` to target the first revision. + :param revision: string revision target or range for --sql mode - :param sql: if True, use ``--sql`` mode. + :param sql: if True, use ``--sql`` mode :param tag: an arbitrary "tag" that can be intercepted by custom ``env.py`` scripts via the :meth:`.EnvironmentContext.get_tag_argument` @@ -530,13 +444,12 @@ def downgrade( script.run_env() -def show(config: Config, rev: str) -> None: +def show(config, rev): """Show the revision(s) denoted by the given symbol. :param config: a :class:`.Config` instance. - :param rev: string revision target. May be ``"current"`` to show the - revision(s) currently applied in the database. + :param revision: string revision target """ @@ -566,7 +479,7 @@ def history( :param config: a :class:`.Config` instance. - :param rev_range: string revision range. + :param rev_range: string revision range :param verbose: output in verbose mode. @@ -586,7 +499,7 @@ def history( base = head = None environment = ( - util.asbool(config.get_alembic_option("revision_environment")) + util.asbool(config.get_main_option("revision_environment")) or indicate_current ) @@ -625,9 +538,7 @@ def history( _display_history(config, script, base, head) -def heads( - config: Config, verbose: bool = False, resolve_dependencies: bool = False -) -> None: +def heads(config, verbose=False, resolve_dependencies=False): """Show current available heads in the script directory. :param config: a :class:`.Config` instance. @@ -652,7 +563,7 @@ def heads( ) -def branches(config: Config, verbose: bool = False) -> None: +def branches(config, verbose=False): """Show current branch points. :param config: a :class:`.Config` instance. 
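The ``check`` hunk above changes how detected operations are reported; for reference, the command is usually driven programmatically in a pattern like this sketch (the ini path is an assumption, not taken from this repository)::

    from alembic import command, util
    from alembic.config import Config

    cfg = Config("alembic.ini")  # assumed config location
    try:
        # raises when autogenerate finds operations that are not yet
        # captured in a revision file
        command.check(cfg)
    except util.AutogenerateDiffsDetected as err:
        print(f"pending model changes detected: {err}")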
@@ -722,9 +633,7 @@ def stamp( :param config: a :class:`.Config` instance. :param revision: target revision or list of revisions. May be a list - to indicate stamping of multiple branch heads; may be ``"base"`` - to remove all revisions from the table or ``"heads"`` to stamp the - most recent revision(s). + to indicate stamping of multiple branch heads. .. note:: this parameter is called "revisions" in the command line interface. @@ -814,7 +723,7 @@ def ensure_version(config: Config, sql: bool = False) -> None: :param config: a :class:`.Config` instance. - :param sql: use ``--sql`` mode. + :param sql: use ``--sql`` mode .. versionadded:: 1.7.6 diff --git a/venv/lib/python3.12/site-packages/alembic/config.py b/venv/lib/python3.12/site-packages/alembic/config.py index b8c60a4..55b5811 100644 --- a/venv/lib/python3.12/site-packages/alembic/config.py +++ b/venv/lib/python3.12/site-packages/alembic/config.py @@ -5,8 +5,6 @@ from argparse import Namespace from configparser import ConfigParser import inspect import os -from pathlib import Path -import re import sys from typing import Any from typing import cast @@ -14,8 +12,6 @@ from typing import Dict from typing import Mapping from typing import Optional from typing import overload -from typing import Protocol -from typing import Sequence from typing import TextIO from typing import Union @@ -25,7 +21,6 @@ from . import __version__ from . import command from . import util from .util import compat -from .util.pyfiles import _preserving_path_as_str class Config: @@ -75,20 +70,7 @@ class Config: alembic_cfg.attributes['connection'] = connection command.upgrade(alembic_cfg, "head") - :param file\_: name of the .ini file to open if an ``alembic.ini`` is - to be used. This should refer to the ``alembic.ini`` file, either as - a filename or a full path to the file. This filename if passed must refer - to an **ini file in ConfigParser format** only. - - :param toml\_file: name of the pyproject.toml file to open if a - ``pyproject.toml`` file is to be used. This should refer to the - ``pyproject.toml`` file, either as a filename or a full path to the file. - This file must be in toml format. Both :paramref:`.Config.file\_` and - :paramref:`.Config.toml\_file` may be passed simultaneously, or - exclusively. - - .. versionadded:: 1.16.0 - + :param file\_: name of the .ini file to open. :param ini_section: name of the main Alembic section within the .ini file :param output_buffer: optional file-like input buffer which @@ -98,13 +80,12 @@ class Config: Defaults to ``sys.stdout``. :param config_args: A dictionary of keys and values that will be used - for substitution in the alembic config file, as well as the pyproject.toml - file, depending on which / both are used. The dictionary as given is - **copied** to two new, independent dictionaries, stored locally under the - attributes ``.config_args`` and ``.toml_args``. Both of these - dictionaries will also be populated with the replacement variable - ``%(here)s``, which refers to the location of the .ini and/or .toml file - as appropriate. + for substitution in the alembic config file. The dictionary as given + is **copied** to a new one, stored locally as the attribute + ``.config_args``. When the :attr:`.Config.file_config` attribute is + first invoked, the replacement variable ``here`` will be added to this + dictionary before the dictionary is passed to ``ConfigParser()`` + to parse the .ini file. 
:param attributes: optional dictionary of arbitrary Python keys/values, which will be populated into the :attr:`.Config.attributes` dictionary. @@ -118,27 +99,20 @@ class Config: def __init__( self, file_: Union[str, os.PathLike[str], None] = None, - toml_file: Union[str, os.PathLike[str], None] = None, ini_section: str = "alembic", output_buffer: Optional[TextIO] = None, stdout: TextIO = sys.stdout, cmd_opts: Optional[Namespace] = None, config_args: Mapping[str, Any] = util.immutabledict(), - attributes: Optional[Dict[str, Any]] = None, + attributes: Optional[dict] = None, ) -> None: """Construct a new :class:`.Config`""" - self.config_file_name = ( - _preserving_path_as_str(file_) if file_ else None - ) - self.toml_file_name = ( - _preserving_path_as_str(toml_file) if toml_file else None - ) + self.config_file_name = file_ self.config_ini_section = ini_section self.output_buffer = output_buffer self.stdout = stdout self.cmd_opts = cmd_opts self.config_args = dict(config_args) - self.toml_args = dict(config_args) if attributes: self.attributes.update(attributes) @@ -154,28 +128,9 @@ class Config: """ - config_file_name: Optional[str] = None + config_file_name: Union[str, os.PathLike[str], None] = None """Filesystem path to the .ini file in use.""" - toml_file_name: Optional[str] = None - """Filesystem path to the pyproject.toml file in use. - - .. versionadded:: 1.16.0 - - """ - - @property - def _config_file_path(self) -> Optional[Path]: - if self.config_file_name is None: - return None - return Path(self.config_file_name) - - @property - def _toml_file_path(self) -> Optional[Path]: - if self.toml_file_name is None: - return None - return Path(self.toml_file_name) - config_ini_section: str = None # type:ignore[assignment] """Name of the config file section to read basic configuration from. Defaults to ``alembic``, that is the ``[alembic]`` section @@ -185,7 +140,7 @@ class Config: """ @util.memoized_property - def attributes(self) -> Dict[str, Any]: + def attributes(self): """A Python dictionary for storage of additional state. @@ -204,7 +159,7 @@ class Config: """ return {} - def print_stdout(self, text: str, *arg: Any) -> None: + def print_stdout(self, text: str, *arg) -> None: """Render a message to standard out. When :meth:`.Config.print_stdout` is called with additional args @@ -228,48 +183,28 @@ class Config: util.write_outstream(self.stdout, output, "\n", **self.messaging_opts) @util.memoized_property - def file_config(self) -> ConfigParser: + def file_config(self): """Return the underlying ``ConfigParser`` object. Direct access to the .ini file is available here, though the :meth:`.Config.get_section` and :meth:`.Config.get_main_option` methods provide a possibly simpler interface.
""" - if self._config_file_path: - here = self._config_file_path.absolute().parent + if self.config_file_name: + here = os.path.abspath(os.path.dirname(self.config_file_name)) else: - here = Path() - self.config_args["here"] = here.as_posix() + here = "" + self.config_args["here"] = here file_config = ConfigParser(self.config_args) - if self._config_file_path: - compat.read_config_parser(file_config, [self._config_file_path]) + if self.config_file_name: + compat.read_config_parser(file_config, [self.config_file_name]) else: file_config.add_section(self.config_ini_section) return file_config - @util.memoized_property - def toml_alembic_config(self) -> Mapping[str, Any]: - """Return a dictionary of the [tool.alembic] section from - pyproject.toml""" - - if self._toml_file_path and self._toml_file_path.exists(): - - here = self._toml_file_path.absolute().parent - self.toml_args["here"] = here.as_posix() - - with open(self._toml_file_path, "rb") as f: - toml_data = compat.tomllib.load(f) - data = toml_data.get("tool", {}).get("alembic", {}) - if not isinstance(data, dict): - raise util.CommandError("Incorrect TOML format") - return data - - else: - return {} - def get_template_directory(self) -> str: """Return the directory where Alembic setup templates are found. @@ -279,24 +214,14 @@ class Config: """ import alembic - package_dir = Path(alembic.__file__).absolute().parent - return str(package_dir / "templates") - - def _get_template_path(self) -> Path: - """Return the directory where Alembic setup templates are found. - - This method is used by the alembic ``init`` and ``list_templates`` - commands. - - .. versionadded:: 1.16.0 - - """ - return Path(self.get_template_directory()) + package_dir = os.path.abspath(os.path.dirname(alembic.__file__)) + return os.path.join(package_dir, "templates") @overload def get_section( self, name: str, default: None = ... - ) -> Optional[Dict[str, str]]: ... + ) -> Optional[Dict[str, str]]: + ... # "default" here could also be a TypeVar # _MT = TypeVar("_MT", bound=Mapping[str, str]), @@ -304,12 +229,14 @@ class Config: @overload def get_section( self, name: str, default: Dict[str, str] - ) -> Dict[str, str]: ... + ) -> Dict[str, str]: + ... @overload def get_section( self, name: str, default: Mapping[str, str] - ) -> Union[Dict[str, str], Mapping[str, str]]: ... + ) -> Union[Dict[str, str], Mapping[str, str]]: + ... def get_section( self, name: str, default: Optional[Mapping[str, str]] = None @@ -353,12 +280,6 @@ class Config: The value here will override whatever was in the .ini file. - Does **NOT** consume from the pyproject.toml file. - - .. seealso:: - - :meth:`.Config.get_alembic_option` - includes pyproject support - :param section: name of the section :param name: name of the value @@ -391,122 +312,25 @@ class Config: return default @overload - def get_main_option(self, name: str, default: str) -> str: ... + def get_main_option(self, name: str, default: str) -> str: + ... @overload - def get_main_option( - self, name: str, default: Optional[str] = None - ) -> Optional[str]: ... - def get_main_option( self, name: str, default: Optional[str] = None ) -> Optional[str]: + ... + + def get_main_option(self, name, default=None): """Return an option from the 'main' section of the .ini file. This defaults to being a key from the ``[alembic]`` section, unless the ``-n/--name`` flag were used to indicate a different section. - Does **NOT** consume from the pyproject.toml file. - - .. 
seealso:: - - :meth:`.Config.get_alembic_option` - includes pyproject support - """ return self.get_section_option(self.config_ini_section, name, default) - @overload - def get_alembic_option(self, name: str, default: str) -> str: ... - - @overload - def get_alembic_option( - self, name: str, default: Optional[str] = None - ) -> Optional[str]: ... - - def get_alembic_option( - self, name: str, default: Optional[str] = None - ) -> Union[ - None, str, list[str], dict[str, str], list[dict[str, str]], int - ]: - """Return an option from the "[alembic]" or "[tool.alembic]" section - of the configparser-parsed .ini file (e.g. ``alembic.ini``) or - toml-parsed ``pyproject.toml`` file. - - The value returned is expected to be None, string, list of strings, - or dictionary of strings. Within each type of string value, the - ``%(here)s`` token is substituted out with the absolute path of the - ``pyproject.toml`` file, as are other tokens which are extracted from - the :paramref:`.Config.config_args` dictionary. - - Searches always prioritize the configparser namespace first, before - searching in the toml namespace. - - If Alembic was run using the ``-n/--name`` flag to indicate an - alternate main section name, this is taken into account **only** for - the configparser-parsed .ini file. The section name in toml is always - ``[tool.alembic]``. - - - .. versionadded:: 1.16.0 - - """ - - if self.file_config.has_option(self.config_ini_section, name): - return self.file_config.get(self.config_ini_section, name) - else: - return self._get_toml_config_value(name, default=default) - - def get_alembic_boolean_option(self, name: str) -> bool: - if self.file_config.has_option(self.config_ini_section, name): - return ( - self.file_config.get(self.config_ini_section, name) == "true" - ) - else: - value = self.toml_alembic_config.get(name, False) - if not isinstance(value, bool): - raise util.CommandError( - f"boolean value expected for TOML parameter {name!r}" - ) - return value - - def _get_toml_config_value( - self, name: str, default: Optional[Any] = None - ) -> Union[ - None, str, list[str], dict[str, str], list[dict[str, str]], int - ]: - USE_DEFAULT = object() - value: Union[None, str, list[str], dict[str, str], int] = ( - self.toml_alembic_config.get(name, USE_DEFAULT) - ) - if value is USE_DEFAULT: - return default - if value is not None: - if isinstance(value, str): - value = value % (self.toml_args) - elif isinstance(value, list): - if value and isinstance(value[0], dict): - value = [ - {k: v % (self.toml_args) for k, v in dv.items()} - for dv in value - ] - else: - value = cast( - "list[str]", [v % (self.toml_args) for v in value] - ) - elif isinstance(value, dict): - value = cast( - "dict[str, str]", - {k: v % (self.toml_args) for k, v in value.items()}, - ) - elif isinstance(value, int): - return value - else: - raise util.CommandError( - f"unsupported TOML value type for key: {name!r}" - ) - return value - @util.memoized_property def messaging_opts(self) -> MessagingOptions: """The messaging options.""" @@ -517,313 +341,179 @@ class Config: ), ) - def _get_file_separator_char(self, *names: str) -> Optional[str]: - for name in names: - separator = self.get_main_option(name) - if separator is not None: - break - else: - return None - - split_on_path = { - "space": " ", - "newline": "\n", - "os": os.pathsep, - ":": ":", - ";": ";", - } - - try: - sep = split_on_path[separator] - except KeyError as ke: - raise ValueError( - "'%s' is not a valid value for %s; " - "expected 'space', 'newline', 'os', 
':', ';'" - % (separator, name) - ) from ke - else: - if name == "version_path_separator": - util.warn_deprecated( - "The version_path_separator configuration parameter " - "is deprecated; please use path_separator" - ) - return sep - - def get_version_locations_list(self) -> Optional[list[str]]: - - version_locations_str = self.file_config.get( - self.config_ini_section, "version_locations", fallback=None - ) - - if version_locations_str: - split_char = self._get_file_separator_char( - "path_separator", "version_path_separator" - ) - - if split_char is None: - - # legacy behaviour for backwards compatibility - util.warn_deprecated( - "No path_separator found in configuration; " - "falling back to legacy splitting on spaces/commas " - "for version_locations. Consider adding " - "path_separator=os to Alembic config." - ) - - _split_on_space_comma = re.compile(r", *|(?: +)") - return _split_on_space_comma.split(version_locations_str) - else: - return [ - x.strip() - for x in version_locations_str.split(split_char) - if x - ] - else: - return cast( - "list[str]", - self._get_toml_config_value("version_locations", None), - ) - - def get_prepend_sys_paths_list(self) -> Optional[list[str]]: - prepend_sys_path_str = self.file_config.get( - self.config_ini_section, "prepend_sys_path", fallback=None - ) - - if prepend_sys_path_str: - split_char = self._get_file_separator_char("path_separator") - - if split_char is None: - - # legacy behaviour for backwards compatibility - util.warn_deprecated( - "No path_separator found in configuration; " - "falling back to legacy splitting on spaces, commas, " - "and colons for prepend_sys_path. Consider adding " - "path_separator=os to Alembic config." - ) - - _split_on_space_comma_colon = re.compile(r", *|(?: +)|\:") - return _split_on_space_comma_colon.split(prepend_sys_path_str) - else: - return [ - x.strip() - for x in prepend_sys_path_str.split(split_char) - if x - ] - else: - return cast( - "list[str]", - self._get_toml_config_value("prepend_sys_path", None), - ) - - def get_hooks_list(self) -> list[PostWriteHookConfig]: - - hooks: list[PostWriteHookConfig] = [] - - if not self.file_config.has_section("post_write_hooks"): - toml_hook_config = cast( - "list[dict[str, str]]", - self._get_toml_config_value("post_write_hooks", []), - ) - for cfg in toml_hook_config: - opts = dict(cfg) - opts["_hook_name"] = opts.pop("name") - hooks.append(opts) - - else: - _split_on_space_comma = re.compile(r", *|(?: +)") - ini_hook_config = self.get_section("post_write_hooks", {}) - names = _split_on_space_comma.split( - ini_hook_config.get("hooks", "") - ) - - for name in names: - if not name: - continue - opts = { - key[len(name) + 1 :]: ini_hook_config[key] - for key in ini_hook_config - if key.startswith(name + ".") - } - - opts["_hook_name"] = name - hooks.append(opts) - - return hooks - - -PostWriteHookConfig = Mapping[str, str] - class MessagingOptions(TypedDict, total=False): quiet: bool -class CommandFunction(Protocol): - """A function that may be registered in the CLI as an alembic command. - It must be a named function and it must accept a :class:`.Config` object - as the first argument. - - .. versionadded:: 1.15.3 - - """ - - __name__: str - - def __call__(self, config: Config, *args: Any, **kwargs: Any) -> Any: ... 
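The ``CommandLine`` class reworked in the hunks that follow can also be invoked directly instead of through the console script; a minimal sketch, with argument values assumed for illustration::

    from alembic.config import CommandLine

    # equivalent to running "alembic -c alembic.ini history --verbose";
    # CommandLine builds its argparse subcommands by inspecting the
    # public functions of alembic.command and dispatches to them
    CommandLine(prog="alembic").main(
        argv=["-c", "alembic.ini", "history", "--verbose"]
    )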
- - class CommandLine: - """Provides the command line interface to Alembic.""" - def __init__(self, prog: Optional[str] = None) -> None: self._generate_args(prog) - _KWARGS_OPTS = { - "template": ( - "-t", - "--template", - dict( - default="generic", - type=str, - help="Setup template for use with 'init'", - ), - ), - "message": ( - "-m", - "--message", - dict(type=str, help="Message string to use with 'revision'"), - ), - "sql": ( - "--sql", - dict( - action="store_true", - help="Don't emit SQL to database - dump to " - "standard output/file instead. See docs on " - "offline mode.", - ), - ), - "tag": ( - "--tag", - dict( - type=str, - help="Arbitrary 'tag' name - can be used by " - "custom env.py scripts.", - ), - ), - "head": ( - "--head", - dict( - type=str, - help="Specify head revision or @head " - "to base new revision on.", - ), - ), - "splice": ( - "--splice", - dict( - action="store_true", - help="Allow a non-head revision as the 'head' to splice onto", - ), - ), - "depends_on": ( - "--depends-on", - dict( - action="append", - help="Specify one or more revision identifiers " - "which this revision should depend on.", - ), - ), - "rev_id": ( - "--rev-id", - dict( - type=str, - help="Specify a hardcoded revision id instead of " - "generating one", - ), - ), - "version_path": ( - "--version-path", - dict( - type=str, - help="Specify specific path from config for version file", - ), - ), - "branch_label": ( - "--branch-label", - dict( - type=str, - help="Specify a branch label to apply to the new revision", - ), - ), - "verbose": ( - "-v", - "--verbose", - dict(action="store_true", help="Use more verbose output"), - ), - "resolve_dependencies": ( - "--resolve-dependencies", - dict( - action="store_true", - help="Treat dependency versions as down revisions", - ), - ), - "autogenerate": ( - "--autogenerate", - dict( - action="store_true", - help="Populate revision script with candidate " - "migration operations, based on comparison " - "of database to model.", - ), - ), - "rev_range": ( - "-r", - "--rev-range", - dict( - action="store", - help="Specify a revision range; format is [start]:[end]", - ), - ), - "indicate_current": ( - "-i", - "--indicate-current", - dict( - action="store_true", - help="Indicate the current revision", - ), - ), - "purge": ( - "--purge", - dict( - action="store_true", - help="Unconditionally erase the version table before stamping", - ), - ), - "package": ( - "--package", - dict( - action="store_true", - help="Write empty __init__.py files to the " - "environment and version locations", - ), - ), - } - _POSITIONAL_OPTS = { - "directory": dict(help="location of scripts directory"), - "revision": dict( - help="revision identifier", - ), - "revisions": dict( - nargs="+", - help="one or more revisions, or 'heads' for all heads", - ), - } - _POSITIONAL_TRANSLATIONS: dict[Any, dict[str, str]] = { - command.stamp: {"revision": "revisions"} - } - def _generate_args(self, prog: Optional[str]) -> None: + def add_options(fn, parser, positional, kwargs): + kwargs_opts = { + "template": ( + "-t", + "--template", + dict( + default="generic", + type=str, + help="Setup template for use with 'init'", + ), + ), + "message": ( + "-m", + "--message", + dict( + type=str, help="Message string to use with 'revision'" + ), + ), + "sql": ( + "--sql", + dict( + action="store_true", + help="Don't emit SQL to database - dump to " + "standard output/file instead. 
See docs on " + "offline mode.", + ), + ), + "tag": ( + "--tag", + dict( + type=str, + help="Arbitrary 'tag' name - can be used by " + "custom env.py scripts.", + ), + ), + "head": ( + "--head", + dict( + type=str, + help="Specify head revision or @head " + "to base new revision on.", + ), + ), + "splice": ( + "--splice", + dict( + action="store_true", + help="Allow a non-head revision as the " + "'head' to splice onto", + ), + ), + "depends_on": ( + "--depends-on", + dict( + action="append", + help="Specify one or more revision identifiers " + "which this revision should depend on.", + ), + ), + "rev_id": ( + "--rev-id", + dict( + type=str, + help="Specify a hardcoded revision id instead of " + "generating one", + ), + ), + "version_path": ( + "--version-path", + dict( + type=str, + help="Specify specific path from config for " + "version file", + ), + ), + "branch_label": ( + "--branch-label", + dict( + type=str, + help="Specify a branch label to apply to the " + "new revision", + ), + ), + "verbose": ( + "-v", + "--verbose", + dict(action="store_true", help="Use more verbose output"), + ), + "resolve_dependencies": ( + "--resolve-dependencies", + dict( + action="store_true", + help="Treat dependency versions as down revisions", + ), + ), + "autogenerate": ( + "--autogenerate", + dict( + action="store_true", + help="Populate revision script with candidate " + "migration operations, based on comparison " + "of database to model.", + ), + ), + "rev_range": ( + "-r", + "--rev-range", + dict( + action="store", + help="Specify a revision range; " + "format is [start]:[end]", + ), + ), + "indicate_current": ( + "-i", + "--indicate-current", + dict( + action="store_true", + help="Indicate the current revision", + ), + ), + "purge": ( + "--purge", + dict( + action="store_true", + help="Unconditionally erase the version table " + "before stamping", + ), + ), + "package": ( + "--package", + dict( + action="store_true", + help="Write empty __init__.py files to the " + "environment and version locations", + ), + ), + } + positional_help = { + "directory": "location of scripts directory", + "revision": "revision identifier", + "revisions": "one or more revisions, or 'heads' for all heads", + } + for arg in kwargs: + if arg in kwargs_opts: + args = kwargs_opts[arg] + args, kw = args[0:-1], args[-1] + parser.add_argument(*args, **kw) + + for arg in positional: + if ( + arg == "revisions" + or fn in positional_translations + and positional_translations[fn][arg] == "revisions" + ): + subparser.add_argument( + "revisions", + nargs="+", + help=positional_help.get("revisions"), + ) + else: + subparser.add_argument(arg, help=positional_help.get(arg)) + parser = ArgumentParser(prog=prog) parser.add_argument( @@ -832,19 +522,17 @@ class CommandLine: parser.add_argument( "-c", "--config", - action="append", + type=str, + default=os.environ.get("ALEMBIC_CONFIG", "alembic.ini"), help="Alternate config file; defaults to value of " - 'ALEMBIC_CONFIG environment variable, or "alembic.ini". ' - "May also refer to pyproject.toml file. 
May be specified twice " - "to reference both files separately", + 'ALEMBIC_CONFIG environment variable, or "alembic.ini"', ) parser.add_argument( "-n", "--name", type=str, default="alembic", - help="Name of section in .ini file to use for Alembic config " - "(only applies to configparser config, not toml)", + help="Name of section in .ini file to " "use for Alembic config", ) parser.add_argument( "-x", @@ -864,80 +552,47 @@ class CommandLine: action="store_true", help="Do not log to std output.", ) + subparsers = parser.add_subparsers() - self.subparsers = parser.add_subparsers() - alembic_commands = ( - cast(CommandFunction, fn) - for fn in (getattr(command, name) for name in dir(command)) + positional_translations = {command.stamp: {"revision": "revisions"}} + + for fn in [getattr(command, n) for n in dir(command)]: if ( inspect.isfunction(fn) and fn.__name__[0] != "_" and fn.__module__ == "alembic.command" - ) - ) - - for fn in alembic_commands: - self.register_command(fn) - - self.parser = parser - - def register_command(self, fn: CommandFunction) -> None: - """Registers a function as a CLI subcommand. The subcommand name - matches the function name, the arguments are extracted from the - signature and the help text is read from the docstring. - - .. versionadded:: 1.15.3 - - .. seealso:: - - :ref:`custom_commandline` - """ - - positional, kwarg, help_text = self._inspect_function(fn) - - subparser = self.subparsers.add_parser(fn.__name__, help=help_text) - subparser.set_defaults(cmd=(fn, positional, kwarg)) - - for arg in kwarg: - if arg in self._KWARGS_OPTS: - kwarg_opt = self._KWARGS_OPTS[arg] - args, opts = kwarg_opt[0:-1], kwarg_opt[-1] - subparser.add_argument(*args, **opts) # type:ignore - - for arg in positional: - opts = self._POSITIONAL_OPTS.get(arg, {}) - subparser.add_argument(arg, **opts) # type:ignore - - def _inspect_function(self, fn: CommandFunction) -> tuple[Any, Any, str]: - spec = compat.inspect_getfullargspec(fn) - if spec[3] is not None: - positional = spec[0][1 : -len(spec[3])] - kwarg = spec[0][-len(spec[3]) :] - else: - positional = spec[0][1:] - kwarg = [] - - if fn in self._POSITIONAL_TRANSLATIONS: - positional = [ - self._POSITIONAL_TRANSLATIONS[fn].get(name, name) - for name in positional - ] - - # parse first line(s) of helptext without a line break - help_ = fn.__doc__ - if help_: - help_lines = [] - for line in help_.split("\n"): - if not line.strip(): - break + ): + spec = compat.inspect_getfullargspec(fn) + if spec[3] is not None: + positional = spec[0][1 : -len(spec[3])] + kwarg = spec[0][-len(spec[3]) :] else: - help_lines.append(line.strip()) - else: - help_lines = [] + positional = spec[0][1:] + kwarg = [] - help_text = " ".join(help_lines) + if fn in positional_translations: + positional = [ + positional_translations[fn].get(name, name) + for name in positional + ] - return positional, kwarg, help_text + # parse first line(s) of helptext without a line break + help_ = fn.__doc__ + if help_: + help_text = [] + for line in help_.split("\n"): + if not line.strip(): + break + else: + help_text.append(line.strip()) + else: + help_text = [] + subparser = subparsers.add_parser( + fn.__name__, help=" ".join(help_text) + ) + add_options(fn, subparser, positional, kwarg) + subparser.set_defaults(cmd=(fn, positional, kwarg)) + self.parser = parser def run_cmd(self, config: Config, options: Namespace) -> None: fn, positional, kwarg = options.cmd @@ -954,69 +609,22 @@ class CommandLine: else: util.err(str(e), **config.messaging_opts) - def 
_inis_from_config(self, options: Namespace) -> tuple[str, str]: - names = options.config - - alembic_config_env = os.environ.get("ALEMBIC_CONFIG") - if ( - alembic_config_env - and os.path.basename(alembic_config_env) == "pyproject.toml" - ): - default_pyproject_toml = alembic_config_env - default_alembic_config = "alembic.ini" - elif alembic_config_env: - default_pyproject_toml = "pyproject.toml" - default_alembic_config = alembic_config_env - else: - default_alembic_config = "alembic.ini" - default_pyproject_toml = "pyproject.toml" - - if not names: - return default_pyproject_toml, default_alembic_config - - toml = ini = None - - for name in names: - if os.path.basename(name) == "pyproject.toml": - if toml is not None: - raise util.CommandError( - "pyproject.toml indicated more than once" - ) - toml = name - else: - if ini is not None: - raise util.CommandError( - "only one ini file may be indicated" - ) - ini = name - - return toml if toml else default_pyproject_toml, ( - ini if ini else default_alembic_config - ) - - def main(self, argv: Optional[Sequence[str]] = None) -> None: - """Executes the command line with the provided arguments.""" + def main(self, argv=None): options = self.parser.parse_args(argv) if not hasattr(options, "cmd"): # see http://bugs.python.org/issue9253, argparse # behavior changed incompatibly in py3.3 self.parser.error("too few arguments") else: - toml, ini = self._inis_from_config(options) cfg = Config( - file_=ini, - toml_file=toml, + file_=options.config, ini_section=options.name, cmd_opts=options, ) self.run_cmd(cfg, options) -def main( - argv: Optional[Sequence[str]] = None, - prog: Optional[str] = None, - **kwargs: Any, -) -> None: +def main(argv=None, prog=None, **kwargs): """The console runner function for Alembic.""" CommandLine(prog=prog).main(argv=argv) diff --git a/venv/lib/python3.12/site-packages/alembic/context.pyi b/venv/lib/python3.12/site-packages/alembic/context.pyi index 9117c31..f37f246 100644 --- a/venv/lib/python3.12/site-packages/alembic/context.pyi +++ b/venv/lib/python3.12/site-packages/alembic/context.pyi @@ -5,6 +5,7 @@ from __future__ import annotations from typing import Any from typing import Callable from typing import Collection +from typing import ContextManager from typing import Dict from typing import Iterable from typing import List @@ -13,14 +14,11 @@ from typing import Mapping from typing import MutableMapping from typing import Optional from typing import overload -from typing import Sequence from typing import TextIO from typing import Tuple from typing import TYPE_CHECKING from typing import Union -from typing_extensions import ContextManager - if TYPE_CHECKING: from sqlalchemy.engine.base import Connection from sqlalchemy.engine.url import URL @@ -41,9 +39,7 @@ if TYPE_CHECKING: ### end imports ### -def begin_transaction() -> ( - Union[_ProxyTransaction, ContextManager[None, Optional[bool]]] -): +def begin_transaction() -> Union[_ProxyTransaction, ContextManager[None]]: """Return a context manager that will enclose an operation within a "transaction", as defined by the environment's offline @@ -101,7 +97,7 @@ def configure( tag: Optional[str] = None, template_args: Optional[Dict[str, Any]] = None, render_as_batch: bool = False, - target_metadata: Union[MetaData, Sequence[MetaData], None] = None, + target_metadata: Optional[MetaData] = None, include_name: Optional[ Callable[ [ @@ -163,8 +159,8 @@ def configure( MigrationContext, Column[Any], Column[Any], - TypeEngine[Any], - TypeEngine[Any], + TypeEngine, + TypeEngine, 
], Optional[bool], ], @@ -639,8 +635,7 @@ def configure( """ def execute( - sql: Union[Executable, str], - execution_options: Optional[Dict[str, Any]] = None, + sql: Union[Executable, str], execution_options: Optional[dict] = None ) -> None: """Execute the given SQL using the current change context. @@ -763,11 +758,7 @@ def get_x_argument( The return value is a list, returned directly from the ``argparse`` structure. If ``as_dictionary=True`` is passed, the ``x`` arguments are parsed using ``key=value`` format into a dictionary that is - then returned. If there is no ``=`` in the argument, value is an empty - string. - - .. versionchanged:: 1.13.1 Support ``as_dictionary=True`` when - arguments are passed without the ``=`` symbol. + then returned. For example, to support passing a database URL on the command line, the standard ``env.py`` script can be modified like this:: @@ -809,7 +800,7 @@ def is_offline_mode() -> bool: """ -def is_transactional_ddl() -> bool: +def is_transactional_ddl(): """Return True if the context is configured to expect a transactional DDL capable backend. diff --git a/venv/lib/python3.12/site-packages/alembic/ddl/__init__.py b/venv/lib/python3.12/site-packages/alembic/ddl/__init__.py index f2f72b3..cfcc47e 100644 --- a/venv/lib/python3.12/site-packages/alembic/ddl/__init__.py +++ b/venv/lib/python3.12/site-packages/alembic/ddl/__init__.py @@ -3,4 +3,4 @@ from . import mysql from . import oracle from . import postgresql from . import sqlite -from .impl import DefaultImpl as DefaultImpl +from .impl import DefaultImpl diff --git a/venv/lib/python3.12/site-packages/alembic/ddl/_autogen.py b/venv/lib/python3.12/site-packages/alembic/ddl/_autogen.py deleted file mode 100644 index 74715b1..0000000 --- a/venv/lib/python3.12/site-packages/alembic/ddl/_autogen.py +++ /dev/null @@ -1,329 +0,0 @@ -# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls -# mypy: no-warn-return-any, allow-any-generics - -from __future__ import annotations - -from typing import Any -from typing import ClassVar -from typing import Dict -from typing import Generic -from typing import NamedTuple -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Type -from typing import TYPE_CHECKING -from typing import TypeVar -from typing import Union - -from sqlalchemy.sql.schema import Constraint -from sqlalchemy.sql.schema import ForeignKeyConstraint -from sqlalchemy.sql.schema import Index -from sqlalchemy.sql.schema import UniqueConstraint -from typing_extensions import TypeGuard - -from .. 
import util -from ..util import sqla_compat - -if TYPE_CHECKING: - from typing import Literal - - from alembic.autogenerate.api import AutogenContext - from alembic.ddl.impl import DefaultImpl - -CompareConstraintType = Union[Constraint, Index] - -_C = TypeVar("_C", bound=CompareConstraintType) - -_clsreg: Dict[str, Type[_constraint_sig]] = {} - - -class ComparisonResult(NamedTuple): - status: Literal["equal", "different", "skip"] - message: str - - @property - def is_equal(self) -> bool: - return self.status == "equal" - - @property - def is_different(self) -> bool: - return self.status == "different" - - @property - def is_skip(self) -> bool: - return self.status == "skip" - - @classmethod - def Equal(cls) -> ComparisonResult: - """the constraints are equal.""" - return cls("equal", "The two constraints are equal") - - @classmethod - def Different(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult: - """the constraints are different for the provided reason(s).""" - return cls("different", ", ".join(util.to_list(reason))) - - @classmethod - def Skip(cls, reason: Union[str, Sequence[str]]) -> ComparisonResult: - """the constraint cannot be compared for the provided reason(s). - - The message is logged, but the constraints will be otherwise - considered equal, meaning that no migration command will be - generated. - """ - return cls("skip", ", ".join(util.to_list(reason))) - - -class _constraint_sig(Generic[_C]): - const: _C - - _sig: Tuple[Any, ...] - name: Optional[sqla_compat._ConstraintNameDefined] - - impl: DefaultImpl - - _is_index: ClassVar[bool] = False - _is_fk: ClassVar[bool] = False - _is_uq: ClassVar[bool] = False - - _is_metadata: bool - - def __init_subclass__(cls) -> None: - cls._register() - - @classmethod - def _register(cls): - raise NotImplementedError() - - def __init__( - self, is_metadata: bool, impl: DefaultImpl, const: _C - ) -> None: - raise NotImplementedError() - - def compare_to_reflected( - self, other: _constraint_sig[Any] - ) -> ComparisonResult: - assert self.impl is other.impl - assert self._is_metadata - assert not other._is_metadata - - return self._compare_to_reflected(other) - - def _compare_to_reflected( - self, other: _constraint_sig[_C] - ) -> ComparisonResult: - raise NotImplementedError() - - @classmethod - def from_constraint( - cls, is_metadata: bool, impl: DefaultImpl, constraint: _C - ) -> _constraint_sig[_C]: - # these could be cached by constraint/impl, however, if the - # constraint is modified in place, then the sig is wrong. the mysql - # impl currently does this, and if we fixed that we can't be sure - # someone else might do it too, so play it safe. 
- sig = _clsreg[constraint.__visit_name__](is_metadata, impl, constraint) - return sig - - def md_name_to_sql_name(self, context: AutogenContext) -> Optional[str]: - return sqla_compat._get_constraint_final_name( - self.const, context.dialect - ) - - @util.memoized_property - def is_named(self): - return sqla_compat._constraint_is_named(self.const, self.impl.dialect) - - @util.memoized_property - def unnamed(self) -> Tuple[Any, ...]: - return self._sig - - @util.memoized_property - def unnamed_no_options(self) -> Tuple[Any, ...]: - raise NotImplementedError() - - @util.memoized_property - def _full_sig(self) -> Tuple[Any, ...]: - return (self.name,) + self.unnamed - - def __eq__(self, other) -> bool: - return self._full_sig == other._full_sig - - def __ne__(self, other) -> bool: - return self._full_sig != other._full_sig - - def __hash__(self) -> int: - return hash(self._full_sig) - - -class _uq_constraint_sig(_constraint_sig[UniqueConstraint]): - _is_uq = True - - @classmethod - def _register(cls) -> None: - _clsreg["unique_constraint"] = cls - - is_unique = True - - def __init__( - self, - is_metadata: bool, - impl: DefaultImpl, - const: UniqueConstraint, - ) -> None: - self.impl = impl - self.const = const - self.name = sqla_compat.constraint_name_or_none(const.name) - self._sig = tuple(sorted([col.name for col in const.columns])) - self._is_metadata = is_metadata - - @property - def column_names(self) -> Tuple[str, ...]: - return tuple([col.name for col in self.const.columns]) - - def _compare_to_reflected( - self, other: _constraint_sig[_C] - ) -> ComparisonResult: - assert self._is_metadata - metadata_obj = self - conn_obj = other - - assert is_uq_sig(conn_obj) - return self.impl.compare_unique_constraint( - metadata_obj.const, conn_obj.const - ) - - -class _ix_constraint_sig(_constraint_sig[Index]): - _is_index = True - - name: sqla_compat._ConstraintName - - @classmethod - def _register(cls) -> None: - _clsreg["index"] = cls - - def __init__( - self, is_metadata: bool, impl: DefaultImpl, const: Index - ) -> None: - self.impl = impl - self.const = const - self.name = const.name - self.is_unique = bool(const.unique) - self._is_metadata = is_metadata - - def _compare_to_reflected( - self, other: _constraint_sig[_C] - ) -> ComparisonResult: - assert self._is_metadata - metadata_obj = self - conn_obj = other - - assert is_index_sig(conn_obj) - return self.impl.compare_indexes(metadata_obj.const, conn_obj.const) - - @util.memoized_property - def has_expressions(self): - return sqla_compat.is_expression_index(self.const) - - @util.memoized_property - def column_names(self) -> Tuple[str, ...]: - return tuple([col.name for col in self.const.columns]) - - @util.memoized_property - def column_names_optional(self) -> Tuple[Optional[str], ...]: - return tuple( - [getattr(col, "name", None) for col in self.const.expressions] - ) - - @util.memoized_property - def is_named(self): - return True - - @util.memoized_property - def unnamed(self): - return (self.is_unique,) + self.column_names_optional - - -class _fk_constraint_sig(_constraint_sig[ForeignKeyConstraint]): - _is_fk = True - - @classmethod - def _register(cls) -> None: - _clsreg["foreign_key_constraint"] = cls - - def __init__( - self, - is_metadata: bool, - impl: DefaultImpl, - const: ForeignKeyConstraint, - ) -> None: - self._is_metadata = is_metadata - - self.impl = impl - self.const = const - - self.name = sqla_compat.constraint_name_or_none(const.name) - - ( - self.source_schema, - self.source_table, - self.source_columns, - 
self.target_schema, - self.target_table, - self.target_columns, - onupdate, - ondelete, - deferrable, - initially, - ) = sqla_compat._fk_spec(const) - - self._sig: Tuple[Any, ...] = ( - self.source_schema, - self.source_table, - tuple(self.source_columns), - self.target_schema, - self.target_table, - tuple(self.target_columns), - ) + ( - ( - (None if onupdate.lower() == "no action" else onupdate.lower()) - if onupdate - else None - ), - ( - (None if ondelete.lower() == "no action" else ondelete.lower()) - if ondelete - else None - ), - # convert initially + deferrable into one three-state value - ( - "initially_deferrable" - if initially and initially.lower() == "deferred" - else "deferrable" if deferrable else "not deferrable" - ), - ) - - @util.memoized_property - def unnamed_no_options(self): - return ( - self.source_schema, - self.source_table, - tuple(self.source_columns), - self.target_schema, - self.target_table, - tuple(self.target_columns), - ) - - -def is_index_sig(sig: _constraint_sig) -> TypeGuard[_ix_constraint_sig]: - return sig._is_index - - -def is_uq_sig(sig: _constraint_sig) -> TypeGuard[_uq_constraint_sig]: - return sig._is_uq - - -def is_fk_sig(sig: _constraint_sig) -> TypeGuard[_fk_constraint_sig]: - return sig._is_fk diff --git a/venv/lib/python3.12/site-packages/alembic/ddl/base.py b/venv/lib/python3.12/site-packages/alembic/ddl/base.py index ad2847e..339db0c 100644 --- a/venv/lib/python3.12/site-packages/alembic/ddl/base.py +++ b/venv/lib/python3.12/site-packages/alembic/ddl/base.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls -# mypy: no-warn-return-any, allow-any-generics - from __future__ import annotations import functools @@ -25,8 +22,6 @@ from ..util.sqla_compat import _table_for_constraint # noqa if TYPE_CHECKING: from typing import Any - from sqlalchemy import Computed - from sqlalchemy import Identity from sqlalchemy.sql.compiler import Compiled from sqlalchemy.sql.compiler import DDLCompiler from sqlalchemy.sql.elements import TextClause @@ -35,11 +30,14 @@ if TYPE_CHECKING: from sqlalchemy.sql.type_api import TypeEngine from .impl import DefaultImpl + from ..util.sqla_compat import Computed + from ..util.sqla_compat import Identity _ServerDefault = Union["TextClause", "FetchedValue", "Function[Any]", str] class AlterTable(DDLElement): + """Represent an ALTER TABLE statement. 
Only the string name and optional schema name of the table @@ -154,24 +152,17 @@ class AddColumn(AlterTable): name: str, column: Column[Any], schema: Optional[Union[quoted_name, str]] = None, - if_not_exists: Optional[bool] = None, ) -> None: super().__init__(name, schema=schema) self.column = column - self.if_not_exists = if_not_exists class DropColumn(AlterTable): def __init__( - self, - name: str, - column: Column[Any], - schema: Optional[str] = None, - if_exists: Optional[bool] = None, + self, name: str, column: Column[Any], schema: Optional[str] = None ) -> None: super().__init__(name, schema=schema) self.column = column - self.if_exists = if_exists class ColumnComment(AlterColumn): @@ -196,9 +187,7 @@ def visit_rename_table( def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str: return "%s %s" % ( alter_table(compiler, element.table_name, element.schema), - add_column( - compiler, element.column, if_not_exists=element.if_not_exists, **kw - ), + add_column(compiler, element.column, **kw), ) @@ -206,9 +195,7 @@ def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str: def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str: return "%s %s" % ( alter_table(compiler, element.table_name, element.schema), - drop_column( - compiler, element.column.name, if_exists=element.if_exists, **kw - ), + drop_column(compiler, element.column.name, **kw), ) @@ -248,11 +235,9 @@ def visit_column_default( return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), - ( - "SET DEFAULT %s" % format_server_default(compiler, element.default) - if element.default is not None - else "DROP DEFAULT" - ), + "SET DEFAULT %s" % format_server_default(compiler, element.default) + if element.default is not None + else "DROP DEFAULT", ) @@ -310,13 +295,9 @@ def format_server_default( compiler: DDLCompiler, default: Optional[_ServerDefault], ) -> str: - # this can be updated to use compiler.render_default_string - # for SQLAlchemy 2.0 and above; not in 1.4 - default_str = compiler.get_column_default_string( + return compiler.get_column_default_string( Column("x", Integer, server_default=default) ) - assert default_str is not None - return default_str def format_type(compiler: DDLCompiler, type_: TypeEngine) -> str: @@ -331,29 +312,16 @@ def alter_table( return "ALTER TABLE %s" % format_table_name(compiler, name, schema) -def drop_column( - compiler: DDLCompiler, name: str, if_exists: Optional[bool] = None, **kw -) -> str: - return "DROP COLUMN %s%s" % ( - "IF EXISTS " if if_exists else "", - format_column_name(compiler, name), - ) +def drop_column(compiler: DDLCompiler, name: str, **kw) -> str: + return "DROP COLUMN %s" % format_column_name(compiler, name) def alter_column(compiler: DDLCompiler, name: str) -> str: return "ALTER COLUMN %s" % format_column_name(compiler, name) -def add_column( - compiler: DDLCompiler, - column: Column[Any], - if_not_exists: Optional[bool] = None, - **kw, -) -> str: - text = "ADD COLUMN %s%s" % ( - "IF NOT EXISTS " if if_not_exists else "", - compiler.get_column_specification(column, **kw), - ) +def add_column(compiler: DDLCompiler, column: Column[Any], **kw) -> str: + text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw) const = " ".join( compiler.process(constraint) for constraint in column.constraints diff --git a/venv/lib/python3.12/site-packages/alembic/ddl/impl.py b/venv/lib/python3.12/site-packages/alembic/ddl/impl.py index 
d352f12..8a7c75d 100644 --- a/venv/lib/python3.12/site-packages/alembic/ddl/impl.py +++ b/venv/lib/python3.12/site-packages/alembic/ddl/impl.py @@ -1,9 +1,6 @@ -# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls -# mypy: no-warn-return-any, allow-any-generics - from __future__ import annotations -import logging +from collections import namedtuple import re from typing import Any from typing import Callable @@ -11,7 +8,6 @@ from typing import Dict from typing import Iterable from typing import List from typing import Mapping -from typing import NamedTuple from typing import Optional from typing import Sequence from typing import Set @@ -21,18 +17,10 @@ from typing import TYPE_CHECKING from typing import Union from sqlalchemy import cast -from sqlalchemy import Column -from sqlalchemy import MetaData -from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import schema -from sqlalchemy import String -from sqlalchemy import Table from sqlalchemy import text -from . import _autogen from . import base -from ._autogen import _constraint_sig as _constraint_sig -from ._autogen import ComparisonResult as ComparisonResult from .. import util from ..util import sqla_compat @@ -46,10 +34,13 @@ if TYPE_CHECKING: from sqlalchemy.engine.reflection import Inspector from sqlalchemy.sql import ClauseElement from sqlalchemy.sql import Executable + from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import quoted_name + from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Constraint from sqlalchemy.sql.schema import ForeignKeyConstraint from sqlalchemy.sql.schema import Index + from sqlalchemy.sql.schema import Table from sqlalchemy.sql.schema import UniqueConstraint from sqlalchemy.sql.selectable import TableClause from sqlalchemy.sql.type_api import TypeEngine @@ -59,8 +50,6 @@ if TYPE_CHECKING: from ..operations.batch import ApplyBatchImpl from ..operations.batch import BatchOperationsImpl -log = logging.getLogger(__name__) - class ImplMeta(type): def __init__( @@ -77,8 +66,11 @@ class ImplMeta(type): _impls: Dict[str, Type[DefaultImpl]] = {} +Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"]) + class DefaultImpl(metaclass=ImplMeta): + """Provide the entrypoint for major migration operations, including database-specific behavioral variances. @@ -138,40 +130,6 @@ class DefaultImpl(metaclass=ImplMeta): self.output_buffer.write(text + "\n\n") self.output_buffer.flush() - def version_table_impl( - self, - *, - version_table: str, - version_table_schema: Optional[str], - version_table_pk: bool, - **kw: Any, - ) -> Table: - """Generate a :class:`.Table` object which will be used as the - structure for the Alembic version table. - - Third party dialects may override this hook to provide an alternate - structure for this :class:`.Table`; requirements are only that it - be named based on the ``version_table`` parameter and contains - at least a single string-holding column named ``version_num``. - - .. 
versionadded:: 1.14 - - """ - vt = Table( - version_table, - MetaData(), - Column("version_num", String(32), nullable=False), - schema=version_table_schema, - ) - if version_table_pk: - vt.append_constraint( - PrimaryKeyConstraint( - "version_num", name=f"{version_table}_pkc" - ) - ) - - return vt - def requires_recreate_in_batch( self, batch_op: BatchOperationsImpl ) -> bool: @@ -203,15 +161,16 @@ class DefaultImpl(metaclass=ImplMeta): def _exec( self, construct: Union[Executable, str], - execution_options: Optional[Mapping[str, Any]] = None, - multiparams: Optional[Sequence[Mapping[str, Any]]] = None, - params: Mapping[str, Any] = util.immutabledict(), + execution_options: Optional[dict[str, Any]] = None, + multiparams: Sequence[dict] = (), + params: Dict[str, Any] = util.immutabledict(), ) -> Optional[CursorResult]: if isinstance(construct, str): construct = text(construct) if self.as_sql: - if multiparams is not None or params: - raise TypeError("SQL parameters not allowed with as_sql") + if multiparams or params: + # TODO: coverage + raise Exception("Execution arguments not allowed with as_sql") compile_kw: dict[str, Any] if self.literal_binds and not isinstance( @@ -234,16 +193,11 @@ class DefaultImpl(metaclass=ImplMeta): assert conn is not None if execution_options: conn = conn.execution_options(**execution_options) + if params: + assert isinstance(multiparams, tuple) + multiparams += (params,) - if params and multiparams is not None: - raise TypeError( - "Can't send params and multiparams at the same time" - ) - - if multiparams: - return conn.execute(construct, multiparams) - else: - return conn.execute(construct, params) + return conn.execute(construct, multiparams) def execute( self, @@ -256,11 +210,8 @@ class DefaultImpl(metaclass=ImplMeta): self, table_name: str, column_name: str, - *, nullable: Optional[bool] = None, - server_default: Optional[ - Union[_ServerDefault, Literal[False]] - ] = False, + server_default: Union[_ServerDefault, Literal[False]] = False, name: Optional[str] = None, type_: Optional[TypeEngine] = None, schema: Optional[str] = None, @@ -371,40 +322,25 @@ class DefaultImpl(metaclass=ImplMeta): self, table_name: str, column: Column[Any], - *, schema: Optional[Union[str, quoted_name]] = None, - if_not_exists: Optional[bool] = None, ) -> None: - self._exec( - base.AddColumn( - table_name, - column, - schema=schema, - if_not_exists=if_not_exists, - ) - ) + self._exec(base.AddColumn(table_name, column, schema=schema)) def drop_column( self, table_name: str, column: Column[Any], - *, schema: Optional[str] = None, - if_exists: Optional[bool] = None, **kw, ) -> None: - self._exec( - base.DropColumn( - table_name, column, schema=schema, if_exists=if_exists - ) - ) + self._exec(base.DropColumn(table_name, column, schema=schema)) def add_constraint(self, const: Any) -> None: if const._create_rule is None or const._create_rule(self): self._exec(schema.AddConstraint(const)) - def drop_constraint(self, const: Constraint, **kw: Any) -> None: - self._exec(schema.DropConstraint(const, **kw)) + def drop_constraint(self, const: Constraint) -> None: + self._exec(schema.DropConstraint(const)) def rename_table( self, @@ -416,11 +352,11 @@ class DefaultImpl(metaclass=ImplMeta): base.RenameTable(old_table_name, new_table_name, schema=schema) ) - def create_table(self, table: Table, **kw: Any) -> None: + def create_table(self, table: Table) -> None: table.dispatch.before_create( table, self.connection, checkfirst=False, _ddl_runner=self ) - self._exec(schema.CreateTable(table, 
**kw)) + self._exec(schema.CreateTable(table)) table.dispatch.after_create( table, self.connection, checkfirst=False, _ddl_runner=self ) @@ -439,11 +375,11 @@ class DefaultImpl(metaclass=ImplMeta): if comment and with_comment: self.create_column_comment(column) - def drop_table(self, table: Table, **kw: Any) -> None: + def drop_table(self, table: Table) -> None: table.dispatch.before_drop( table, self.connection, checkfirst=False, _ddl_runner=self ) - self._exec(schema.DropTable(table, **kw)) + self._exec(schema.DropTable(table)) table.dispatch.after_drop( table, self.connection, checkfirst=False, _ddl_runner=self ) @@ -457,7 +393,7 @@ class DefaultImpl(metaclass=ImplMeta): def drop_table_comment(self, table: Table) -> None: self._exec(schema.DropTableComment(table)) - def create_column_comment(self, column: Column[Any]) -> None: + def create_column_comment(self, column: ColumnElement[Any]) -> None: self._exec(schema.SetColumnComment(column)) def drop_index(self, index: Index, **kw: Any) -> None: @@ -476,19 +412,15 @@ class DefaultImpl(metaclass=ImplMeta): if self.as_sql: for row in rows: self._exec( - table.insert() - .inline() - .values( + sqla_compat._insert_inline(table).values( **{ - k: ( - sqla_compat._literal_bindparam( - k, v, type_=table.c[k].type - ) - if not isinstance( - v, sqla_compat._literal_bindparam - ) - else v + k: sqla_compat._literal_bindparam( + k, v, type_=table.c[k].type ) + if not isinstance( + v, sqla_compat._literal_bindparam + ) + else v for k, v in row.items() } ) @@ -496,13 +428,16 @@ class DefaultImpl(metaclass=ImplMeta): else: if rows: if multiinsert: - self._exec(table.insert().inline(), multiparams=rows) + self._exec( + sqla_compat._insert_inline(table), multiparams=rows + ) else: for row in rows: - self._exec(table.insert().inline().values(**row)) + self._exec( + sqla_compat._insert_inline(table).values(**row) + ) def _tokenize_column_type(self, column: Column) -> Params: - definition: str definition = self.dialect.type_compiler.process(column.type).lower() # tokenize the SQLAlchemy-generated version of a type, so that @@ -517,9 +452,9 @@ class DefaultImpl(metaclass=ImplMeta): # varchar character set utf8 # - tokens: List[str] = re.findall(r"[\w\-_]+|\(.+?\)", definition) + tokens = re.findall(r"[\w\-_]+|\(.+?\)", definition) - term_tokens: List[str] = [] + term_tokens = [] paren_term = None for token in tokens: @@ -531,7 +466,6 @@ class DefaultImpl(metaclass=ImplMeta): params = Params(term_tokens[0], term_tokens[1:], [], {}) if paren_term: - term: str for term in re.findall("[^(),]+", paren_term): if "=" in term: key, val = term.split("=") @@ -708,7 +642,7 @@ class DefaultImpl(metaclass=ImplMeta): diff, ignored = _compare_identity_options( metadata_identity, inspector_identity, - schema.Identity(), + sqla_compat.Identity(), skip={"always"}, ) @@ -730,96 +664,15 @@ class DefaultImpl(metaclass=ImplMeta): bool(diff) or bool(metadata_identity) != bool(inspector_identity), ) - def _compare_index_unique( - self, metadata_index: Index, reflected_index: Index - ) -> Optional[str]: - conn_unique = bool(reflected_index.unique) - meta_unique = bool(metadata_index.unique) - if conn_unique != meta_unique: - return f"unique={conn_unique} to unique={meta_unique}" - else: - return None + def create_index_sig(self, index: Index) -> Tuple[Any, ...]: + # order of col matters in an index + return tuple(col.name for col in index.columns) - def _create_metadata_constraint_sig( - self, constraint: _autogen._C, **opts: Any - ) -> _constraint_sig[_autogen._C]: - return 
_constraint_sig.from_constraint(True, self, constraint, **opts) - - def _create_reflected_constraint_sig( - self, constraint: _autogen._C, **opts: Any - ) -> _constraint_sig[_autogen._C]: - return _constraint_sig.from_constraint(False, self, constraint, **opts) - - def compare_indexes( - self, - metadata_index: Index, - reflected_index: Index, - ) -> ComparisonResult: - """Compare two indexes by comparing the signature generated by - ``create_index_sig``. - - This method returns a ``ComparisonResult``. - """ - msg: List[str] = [] - unique_msg = self._compare_index_unique( - metadata_index, reflected_index - ) - if unique_msg: - msg.append(unique_msg) - m_sig = self._create_metadata_constraint_sig(metadata_index) - r_sig = self._create_reflected_constraint_sig(reflected_index) - - assert _autogen.is_index_sig(m_sig) - assert _autogen.is_index_sig(r_sig) - - # The assumption is that the index have no expression - for sig in m_sig, r_sig: - if sig.has_expressions: - log.warning( - "Generating approximate signature for index %s. " - "The dialect " - "implementation should either skip expression indexes " - "or provide a custom implementation.", - sig.const, - ) - - if m_sig.column_names != r_sig.column_names: - msg.append( - f"expression {r_sig.column_names} to {m_sig.column_names}" - ) - - if msg: - return ComparisonResult.Different(msg) - else: - return ComparisonResult.Equal() - - def compare_unique_constraint( - self, - metadata_constraint: UniqueConstraint, - reflected_constraint: UniqueConstraint, - ) -> ComparisonResult: - """Compare two unique constraints by comparing the two signatures. - - The arguments are two tuples that contain the unique constraint and - the signatures generated by ``create_unique_constraint_sig``. - - This method returns a ``ComparisonResult``. 
- """ - metadata_tup = self._create_metadata_constraint_sig( - metadata_constraint - ) - reflected_tup = self._create_reflected_constraint_sig( - reflected_constraint - ) - - meta_sig = metadata_tup.unnamed - conn_sig = reflected_tup.unnamed - if conn_sig != meta_sig: - return ComparisonResult.Different( - f"expression {conn_sig} to {meta_sig}" - ) - else: - return ComparisonResult.Equal() + def create_unique_constraint_sig( + self, const: UniqueConstraint + ) -> Tuple[Any, ...]: + # order of col does not matters in an unique constraint + return tuple(sorted([col.name for col in const.columns])) def _skip_functional_indexes(self, metadata_indexes, conn_indexes): conn_indexes_by_name = {c.name: c for c in conn_indexes} @@ -844,13 +697,6 @@ class DefaultImpl(metaclass=ImplMeta): return reflected_object.get("dialect_options", {}) -class Params(NamedTuple): - token0: str - tokens: List[str] - args: List[str] - kwargs: Dict[str, str] - - def _compare_identity_options( metadata_io: Union[schema.Identity, schema.Sequence, None], inspector_io: Union[schema.Identity, schema.Sequence, None], @@ -889,13 +735,12 @@ def _compare_identity_options( set(meta_d).union(insp_d), ) if sqla_compat.identity_has_dialect_kwargs: - assert hasattr(default_io, "dialect_kwargs") # use only the dialect kwargs in inspector_io since metadata_io # can have options for many backends check_dicts( getattr(metadata_io, "dialect_kwargs", {}), getattr(inspector_io, "dialect_kwargs", {}), - default_io.dialect_kwargs, + default_io.dialect_kwargs, # type: ignore[union-attr] getattr(inspector_io, "dialect_kwargs", {}), ) diff --git a/venv/lib/python3.12/site-packages/alembic/ddl/mssql.py b/venv/lib/python3.12/site-packages/alembic/ddl/mssql.py index 5376da5..9b0fff8 100644 --- a/venv/lib/python3.12/site-packages/alembic/ddl/mssql.py +++ b/venv/lib/python3.12/site-packages/alembic/ddl/mssql.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls -# mypy: no-warn-return-any, allow-any-generics - from __future__ import annotations import re @@ -12,6 +9,7 @@ from typing import TYPE_CHECKING from typing import Union from sqlalchemy import types as sqltypes +from sqlalchemy.ext.compiler import compiles from sqlalchemy.schema import Column from sqlalchemy.schema import CreateIndex from sqlalchemy.sql.base import Executable @@ -32,7 +30,6 @@ from .base import RenameTable from .impl import DefaultImpl from .. 
import util from ..util import sqla_compat -from ..util.sqla_compat import compiles if TYPE_CHECKING: from typing import Literal @@ -83,11 +80,10 @@ class MSSQLImpl(DefaultImpl): if self.as_sql and self.batch_separator: self.static_output(self.batch_separator) - def alter_column( + def alter_column( # type:ignore[override] self, table_name: str, column_name: str, - *, nullable: Optional[bool] = None, server_default: Optional[ Union[_ServerDefault, Literal[False]] @@ -203,7 +199,6 @@ class MSSQLImpl(DefaultImpl): self, table_name: str, column: Column[Any], - *, schema: Optional[str] = None, **kw, ) -> None: diff --git a/venv/lib/python3.12/site-packages/alembic/ddl/mysql.py b/venv/lib/python3.12/site-packages/alembic/ddl/mysql.py index 3d7cf21..32ced49 100644 --- a/venv/lib/python3.12/site-packages/alembic/ddl/mysql.py +++ b/venv/lib/python3.12/site-packages/alembic/ddl/mysql.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls -# mypy: no-warn-return-any, allow-any-generics - from __future__ import annotations import re @@ -11,9 +8,7 @@ from typing import Union from sqlalchemy import schema from sqlalchemy import types as sqltypes -from sqlalchemy.sql import elements -from sqlalchemy.sql import functions -from sqlalchemy.sql import operators +from sqlalchemy.ext.compiler import compiles from .base import alter_table from .base import AlterColumn @@ -25,16 +20,16 @@ from .base import format_column_name from .base import format_server_default from .impl import DefaultImpl from .. import util +from ..autogenerate import compare from ..util import sqla_compat +from ..util.sqla_compat import _is_mariadb from ..util.sqla_compat import _is_type_bound -from ..util.sqla_compat import compiles if TYPE_CHECKING: from typing import Literal from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler from sqlalchemy.sql.ddl import DropConstraint - from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.schema import Constraint from sqlalchemy.sql.type_api import TypeEngine @@ -51,40 +46,12 @@ class MySQLImpl(DefaultImpl): ) type_arg_extract = [r"character set ([\w\-_]+)", r"collate ([\w\-_]+)"] - def render_ddl_sql_expr( - self, - expr: ClauseElement, - is_server_default: bool = False, - is_index: bool = False, - **kw: Any, - ) -> str: - # apply Grouping to index expressions; - # see https://github.com/sqlalchemy/sqlalchemy/blob/ - # 36da2eaf3e23269f2cf28420ae73674beafd0661/ - # lib/sqlalchemy/dialects/mysql/base.py#L2191 - if is_index and ( - isinstance(expr, elements.BinaryExpression) - or ( - isinstance(expr, elements.UnaryExpression) - and expr.modifier not in (operators.desc_op, operators.asc_op) - ) - or isinstance(expr, functions.FunctionElement) - ): - expr = elements.Grouping(expr) - - return super().render_ddl_sql_expr( - expr, is_server_default=is_server_default, is_index=is_index, **kw - ) - - def alter_column( + def alter_column( # type:ignore[override] self, table_name: str, column_name: str, - *, nullable: Optional[bool] = None, - server_default: Optional[ - Union[_ServerDefault, Literal[False]] - ] = False, + server_default: Union[_ServerDefault, Literal[False]] = False, name: Optional[str] = None, type_: Optional[TypeEngine] = None, schema: Optional[str] = None, @@ -125,29 +92,21 @@ class MySQLImpl(DefaultImpl): column_name, schema=schema, newname=name if name is not None else column_name, - nullable=( - nullable - if nullable is not None - else ( - existing_nullable - if existing_nullable is not None - else True - ) - ), + 
nullable=nullable + if nullable is not None + else existing_nullable + if existing_nullable is not None + else True, type_=type_ if type_ is not None else existing_type, - default=( - server_default - if server_default is not False - else existing_server_default - ), - autoincrement=( - autoincrement - if autoincrement is not None - else existing_autoincrement - ), - comment=( - comment if comment is not False else existing_comment - ), + default=server_default + if server_default is not False + else existing_server_default, + autoincrement=autoincrement + if autoincrement is not None + else existing_autoincrement, + comment=comment + if comment is not False + else existing_comment, ) ) elif ( @@ -162,29 +121,21 @@ class MySQLImpl(DefaultImpl): column_name, schema=schema, newname=name if name is not None else column_name, - nullable=( - nullable - if nullable is not None - else ( - existing_nullable - if existing_nullable is not None - else True - ) - ), + nullable=nullable + if nullable is not None + else existing_nullable + if existing_nullable is not None + else True, type_=type_ if type_ is not None else existing_type, - default=( - server_default - if server_default is not False - else existing_server_default - ), - autoincrement=( - autoincrement - if autoincrement is not None - else existing_autoincrement - ), - comment=( - comment if comment is not False else existing_comment - ), + default=server_default + if server_default is not False + else existing_server_default, + autoincrement=autoincrement + if autoincrement is not None + else existing_autoincrement, + comment=comment + if comment is not False + else existing_comment, ) ) elif server_default is not False: @@ -197,7 +148,6 @@ class MySQLImpl(DefaultImpl): def drop_constraint( self, const: Constraint, - **kw: Any, ) -> None: if isinstance(const, schema.CheckConstraint) and _is_type_bound(const): return @@ -207,11 +157,12 @@ class MySQLImpl(DefaultImpl): def _is_mysql_allowed_functional_default( self, type_: Optional[TypeEngine], - server_default: Optional[Union[_ServerDefault, Literal[False]]], + server_default: Union[_ServerDefault, Literal[False]], ) -> bool: return ( type_ is not None - and type_._type_affinity is sqltypes.DateTime + and type_._type_affinity # type:ignore[attr-defined] + is sqltypes.DateTime and server_default is not None ) @@ -321,12 +272,10 @@ class MySQLImpl(DefaultImpl): def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks): conn_fk_by_sig = { - self._create_reflected_constraint_sig(fk).unnamed_no_options: fk - for fk in conn_fks + compare._fk_constraint_sig(fk).sig: fk for fk in conn_fks } metadata_fk_by_sig = { - self._create_metadata_constraint_sig(fk).unnamed_no_options: fk - for fk in metadata_fks + compare._fk_constraint_sig(fk).sig: fk for fk in metadata_fks } for sig in set(conn_fk_by_sig).intersection(metadata_fk_by_sig): @@ -358,7 +307,7 @@ class MySQLAlterDefault(AlterColumn): self, name: str, column_name: str, - default: Optional[_ServerDefault], + default: _ServerDefault, schema: Optional[str] = None, ) -> None: super(AlterColumn, self).__init__(name, schema=schema) @@ -416,11 +365,9 @@ def _mysql_alter_default( return "%s ALTER COLUMN %s %s" % ( alter_table(compiler, element.table_name, element.schema), format_column_name(compiler, element.column_name), - ( - "SET DEFAULT %s" % format_server_default(compiler, element.default) - if element.default is not None - else "DROP DEFAULT" - ), + "SET DEFAULT %s" % format_server_default(compiler, element.default) + if element.default 
is not None + else "DROP DEFAULT", ) @@ -507,7 +454,7 @@ def _mysql_drop_constraint( # note that SQLAlchemy as of 1.2 does not yet support # DROP CONSTRAINT for MySQL/MariaDB, so we implement fully # here. - if compiler.dialect.is_mariadb: + if _is_mariadb(compiler.dialect): return "ALTER TABLE %s DROP CONSTRAINT %s" % ( compiler.preparer.format_table(constraint.table), compiler.preparer.format_constraint(constraint), diff --git a/venv/lib/python3.12/site-packages/alembic/ddl/oracle.py b/venv/lib/python3.12/site-packages/alembic/ddl/oracle.py index eac9912..e56bb21 100644 --- a/venv/lib/python3.12/site-packages/alembic/ddl/oracle.py +++ b/venv/lib/python3.12/site-packages/alembic/ddl/oracle.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls -# mypy: no-warn-return-any, allow-any-generics - from __future__ import annotations import re @@ -8,6 +5,7 @@ from typing import Any from typing import Optional from typing import TYPE_CHECKING +from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql import sqltypes from .base import AddColumn @@ -24,7 +22,6 @@ from .base import format_type from .base import IdentityColumnDefault from .base import RenameTable from .impl import DefaultImpl -from ..util.sqla_compat import compiles if TYPE_CHECKING: from sqlalchemy.dialects.oracle.base import OracleDDLCompiler @@ -141,11 +138,9 @@ def visit_column_default( return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), - ( - "DEFAULT %s" % format_server_default(compiler, element.default) - if element.default is not None - else "DEFAULT NULL" - ), + "DEFAULT %s" % format_server_default(compiler, element.default) + if element.default is not None + else "DEFAULT NULL", ) diff --git a/venv/lib/python3.12/site-packages/alembic/ddl/postgresql.py b/venv/lib/python3.12/site-packages/alembic/ddl/postgresql.py index 90ecf70..949e256 100644 --- a/venv/lib/python3.12/site-packages/alembic/ddl/postgresql.py +++ b/venv/lib/python3.12/site-packages/alembic/ddl/postgresql.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls -# mypy: no-warn-return-any, allow-any-generics - from __future__ import annotations import logging @@ -16,19 +13,18 @@ from typing import TYPE_CHECKING from typing import Union from sqlalchemy import Column -from sqlalchemy import Float -from sqlalchemy import Identity from sqlalchemy import literal_column from sqlalchemy import Numeric -from sqlalchemy import select from sqlalchemy import text from sqlalchemy import types as sqltypes from sqlalchemy.dialects.postgresql import BIGINT from sqlalchemy.dialects.postgresql import ExcludeConstraint from sqlalchemy.dialects.postgresql import INTEGER from sqlalchemy.schema import CreateIndex +from sqlalchemy.sql import operators from sqlalchemy.sql.elements import ColumnClause from sqlalchemy.sql.elements import TextClause +from sqlalchemy.sql.elements import UnaryExpression from sqlalchemy.sql.functions import FunctionElement from sqlalchemy.types import NULLTYPE @@ -36,12 +32,12 @@ from .base import alter_column from .base import alter_table from .base import AlterColumn from .base import ColumnComment +from .base import compiles from .base import format_column_name from .base import format_table_name from .base import format_type from .base import IdentityColumnDefault from .base import RenameTable -from .impl import ComparisonResult from .impl import DefaultImpl from .. 
import util from ..autogenerate import render @@ -50,8 +46,6 @@ from ..operations import schemaobj from ..operations.base import BatchOperations from ..operations.base import Operations from ..util import sqla_compat -from ..util.sqla_compat import compiles - if TYPE_CHECKING: from typing import Literal @@ -136,28 +130,25 @@ class PostgresqlImpl(DefaultImpl): metadata_default = metadata_column.server_default.arg if isinstance(metadata_default, str): - if not isinstance(inspector_column.type, (Numeric, Float)): + if not isinstance(inspector_column.type, Numeric): metadata_default = re.sub(r"^'|'$", "", metadata_default) metadata_default = f"'{metadata_default}'" metadata_default = literal_column(metadata_default) # run a real compare against the server - conn = self.connection - assert conn is not None - return not conn.scalar( - select(literal_column(conn_col_default) == metadata_default) + return not self.connection.scalar( + sqla_compat._select( + literal_column(conn_col_default) == metadata_default + ) ) - def alter_column( + def alter_column( # type:ignore[override] self, table_name: str, column_name: str, - *, nullable: Optional[bool] = None, - server_default: Optional[ - Union[_ServerDefault, Literal[False]] - ] = False, + server_default: Union[_ServerDefault, Literal[False]] = False, name: Optional[str] = None, type_: Optional[TypeEngine] = None, schema: Optional[str] = None, @@ -223,8 +214,7 @@ class PostgresqlImpl(DefaultImpl): "join pg_class t on t.oid=d.refobjid " "join pg_attribute a on a.attrelid=t.oid and " "a.attnum=d.refobjsubid " - "where c.relkind='S' and " - "c.oid=cast(:seqname as regclass)" + "where c.relkind='S' and c.relname=:seqname" ), seqname=seq_match.group(1), ).first() @@ -262,60 +252,62 @@ class PostgresqlImpl(DefaultImpl): if not sqla_compat.sqla_2: self._skip_functional_indexes(metadata_indexes, conn_indexes) - # pg behavior regarding modifiers - # | # | compiled sql | returned sql | regexp. group is removed | - # | - | ---------------- | -----------------| ------------------------ | - # | 1 | nulls first | nulls first | - | - # | 2 | nulls last | | (? str: + def _cleanup_index_expr( + self, index: Index, expr: str, remove_suffix: str + ) -> str: + # start = expr expr = expr.lower().replace('"', "").replace("'", "") if index.table is not None: # should not be needed, since include_table=False is in compile expr = expr.replace(f"{index.table.name.lower()}.", "") + while expr and expr[0] == "(" and expr[-1] == ")": + expr = expr[1:-1] if "::" in expr: # strip :: cast. 
types can have spaces in them expr = re.sub(r"(::[\w ]+\w)", "", expr) - while expr and expr[0] == "(" and expr[-1] == ")": - expr = expr[1:-1] + if remove_suffix and expr.endswith(remove_suffix): + expr = expr[: -len(remove_suffix)] - # NOTE: when parsing the connection expression this cleanup could - # be skipped - for rs in self._default_modifiers_re: - if match := rs.search(expr): - start, end = match.span(1) - expr = expr[:start] + expr[end:] - break - - while expr and expr[0] == "(" and expr[-1] == ")": - expr = expr[1:-1] - - # strip casts - cast_re = re.compile(r"cast\s*\(") - if cast_re.match(expr): - expr = cast_re.sub("", expr) - # remove the as type - expr = re.sub(r"as\s+[^)]+\)", "", expr) - # remove spaces - expr = expr.replace(" ", "") + # print(f"START: {start} END: {expr}") return expr - def _dialect_options( + def _default_modifiers(self, exp: ClauseElement) -> str: + to_remove = "" + while isinstance(exp, UnaryExpression): + if exp.modifier is None: + exp = exp.element + else: + op = exp.modifier + if isinstance(exp.element, UnaryExpression): + inner_op = exp.element.modifier + else: + inner_op = None + if inner_op is None: + if op == operators.asc_op: + # default is asc + to_remove = " asc" + elif op == operators.nullslast_op: + # default is nulls last + to_remove = " nulls last" + else: + if ( + inner_op == operators.asc_op + and op == operators.nullslast_op + ): + # default is asc nulls last + to_remove = " asc nulls last" + elif ( + inner_op == operators.desc_op + and op == operators.nullsfirst_op + ): + # default for desc is nulls first + to_remove = " nulls first" + break + return to_remove + + def _dialect_sig( self, item: Union[Index, UniqueConstraint] ) -> Tuple[Any, ...]: # only the positive case is returned by sqlalchemy reflection so @@ -324,93 +316,25 @@ class PostgresqlImpl(DefaultImpl): return ("nulls_not_distinct",) return () - def compare_indexes( - self, - metadata_index: Index, - reflected_index: Index, - ) -> ComparisonResult: - msg = [] - unique_msg = self._compare_index_unique( - metadata_index, reflected_index - ) - if unique_msg: - msg.append(unique_msg) - m_exprs = metadata_index.expressions - r_exprs = reflected_index.expressions - if len(m_exprs) != len(r_exprs): - msg.append(f"expression number {len(r_exprs)} to {len(m_exprs)}") - if msg: - # no point going further, return early - return ComparisonResult.Different(msg) - skip = [] - for pos, (m_e, r_e) in enumerate(zip(m_exprs, r_exprs), 1): - m_compile = self._compile_element(m_e) - m_text = self._cleanup_index_expr(metadata_index, m_compile) - # print(f"META ORIG: {m_compile!r} CLEANUP: {m_text!r}") - r_compile = self._compile_element(r_e) - r_text = self._cleanup_index_expr(metadata_index, r_compile) - # print(f"CONN ORIG: {r_compile!r} CLEANUP: {r_text!r}") - if m_text == r_text: - continue # expressions these are equal - elif m_compile.strip().endswith("_ops") and ( - " " in m_compile or ")" in m_compile # is an expression - ): - skip.append( - f"expression #{pos} {m_compile!r} detected " - "as including operator clause." - ) - util.warn( - f"Expression #{pos} {m_compile!r} in index " - f"{reflected_index.name!r} detected to include " - "an operator clause. Expression compare cannot proceed. 
" - "Please move the operator clause to the " - "``postgresql_ops`` dict to enable proper compare " - "of the index expressions: " - "https://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#operator-classes", # noqa: E501 - ) - else: - msg.append(f"expression #{pos} {r_compile!r} to {m_compile!r}") - - m_options = self._dialect_options(metadata_index) - r_options = self._dialect_options(reflected_index) - if m_options != r_options: - msg.extend(f"options {r_options} to {m_options}") - - if msg: - return ComparisonResult.Different(msg) - elif skip: - # if there are other changes detected don't skip the index - return ComparisonResult.Skip(skip) - else: - return ComparisonResult.Equal() - - def compare_unique_constraint( - self, - metadata_constraint: UniqueConstraint, - reflected_constraint: UniqueConstraint, - ) -> ComparisonResult: - metadata_tup = self._create_metadata_constraint_sig( - metadata_constraint - ) - reflected_tup = self._create_reflected_constraint_sig( - reflected_constraint - ) - - meta_sig = metadata_tup.unnamed - conn_sig = reflected_tup.unnamed - if conn_sig != meta_sig: - return ComparisonResult.Different( - f"expression {conn_sig} to {meta_sig}" + def create_index_sig(self, index: Index) -> Tuple[Any, ...]: + return tuple( + self._cleanup_index_expr( + index, + *( + (e, "") + if isinstance(e, str) + else (self._compile_element(e), self._default_modifiers(e)) + ), ) + for e in index.expressions + ) + self._dialect_sig(index) - metadata_do = self._dialect_options(metadata_tup.const) - conn_do = self._dialect_options(reflected_tup.const) - if metadata_do != conn_do: - return ComparisonResult.Different( - f"expression {conn_do} to {metadata_do}" - ) - - return ComparisonResult.Equal() + def create_unique_constraint_sig( + self, const: UniqueConstraint + ) -> Tuple[Any, ...]: + return tuple( + sorted([col.name for col in const.columns]) + ) + self._dialect_sig(const) def adjust_reflected_dialect_options( self, reflected_options: Dict[str, Any], kind: str @@ -421,9 +345,7 @@ class PostgresqlImpl(DefaultImpl): options.pop("postgresql_include", None) return options - def _compile_element(self, element: Union[ClauseElement, str]) -> str: - if isinstance(element, str): - return element + def _compile_element(self, element: ClauseElement) -> str: return element.compile( dialect=self.dialect, compile_kwargs={"literal_binds": True, "include_table": False}, @@ -590,7 +512,7 @@ def visit_identity_column( ) else: text += "SET %s " % compiler.get_identity_options( - Identity(**{attr: getattr(identity, attr)}) + sqla_compat.Identity(**{attr: getattr(identity, attr)}) ) return text @@ -634,8 +556,9 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp): return cls( constraint.name, constraint_table.name, - [ # type: ignore - (expr, op) for expr, name, op in constraint._render_exprs + [ + (expr, op) + for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa ], where=cast("ColumnElement[bool] | None", constraint.where), schema=constraint_table.schema, @@ -662,7 +585,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp): expr, name, oper, - ) in excl._render_exprs: + ) in excl._render_exprs: # type:ignore[attr-defined] t.append_column(Column(name, NULLTYPE)) t.append_constraint(excl) return excl @@ -720,7 +643,7 @@ class CreateExcludeConstraintOp(ops.AddConstraintOp): constraint_name: str, *elements: Any, **kw: Any, - ) -> Optional[Table]: + ): """Issue a "create exclude constraint" instruction using the current batch migration context. 
@@ -792,13 +715,10 @@ def _exclude_constraint( args = [ "(%s, %r)" % ( - _render_potential_column( - sqltext, # type:ignore[arg-type] - autogen_context, - ), + _render_potential_column(sqltext, autogen_context), opstring, ) - for sqltext, name, opstring in constraint._render_exprs + for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa ] if constraint.where is not None: args.append( @@ -850,5 +770,5 @@ def _render_potential_column( return render._render_potential_expr( value, autogen_context, - wrap_in_element=isinstance(value, (TextClause, FunctionElement)), + wrap_in_text=isinstance(value, (TextClause, FunctionElement)), ) diff --git a/venv/lib/python3.12/site-packages/alembic/ddl/sqlite.py b/venv/lib/python3.12/site-packages/alembic/ddl/sqlite.py index 5f14133..c6186c6 100644 --- a/venv/lib/python3.12/site-packages/alembic/ddl/sqlite.py +++ b/venv/lib/python3.12/site-packages/alembic/ddl/sqlite.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls -# mypy: no-warn-return-any, allow-any-generics - from __future__ import annotations import re @@ -11,19 +8,16 @@ from typing import TYPE_CHECKING from typing import Union from sqlalchemy import cast -from sqlalchemy import Computed from sqlalchemy import JSON from sqlalchemy import schema from sqlalchemy import sql +from sqlalchemy.ext.compiler import compiles from .base import alter_table -from .base import ColumnName -from .base import format_column_name from .base import format_table_name from .base import RenameTable from .impl import DefaultImpl from .. import util -from ..util.sqla_compat import compiles if TYPE_CHECKING: from sqlalchemy.engine.reflection import Inspector @@ -65,7 +59,7 @@ class SQLiteImpl(DefaultImpl): ) and isinstance(col.server_default.arg, sql.ClauseElement): return True elif ( - isinstance(col.server_default, Computed) + isinstance(col.server_default, util.sqla_compat.Computed) and col.server_default.persisted ): return True @@ -77,13 +71,13 @@ class SQLiteImpl(DefaultImpl): def add_constraint(self, const: Constraint): # attempt to distinguish between an # auto-gen constraint and an explicit one - if const._create_rule is None: + if const._create_rule is None: # type:ignore[attr-defined] raise NotImplementedError( "No support for ALTER of constraints in SQLite dialect. " "Please refer to the batch mode feature which allows for " "SQLite migrations using a copy-and-move strategy." ) - elif const._create_rule(self): + elif const._create_rule(self): # type:ignore[attr-defined] util.warn( "Skipping unsupported ALTER for " "creation of implicit constraint. " @@ -91,8 +85,8 @@ class SQLiteImpl(DefaultImpl): "SQLite migrations using a copy-and-move strategy." ) - def drop_constraint(self, const: Constraint, **kw: Any): - if const._create_rule is None: + def drop_constraint(self, const: Constraint): + if const._create_rule is None: # type:ignore[attr-defined] raise NotImplementedError( "No support for ALTER of constraints in SQLite dialect. 
" "Please refer to the batch mode feature which allows for " @@ -183,7 +177,8 @@ class SQLiteImpl(DefaultImpl): new_type: TypeEngine, ) -> None: if ( - existing.type._type_affinity is not new_type._type_affinity + existing.type._type_affinity # type:ignore[attr-defined] + is not new_type._type_affinity # type:ignore[attr-defined] and not isinstance(new_type, JSON) ): existing_transfer["expr"] = cast( @@ -210,15 +205,6 @@ def visit_rename_table( ) -@compiles(ColumnName, "sqlite") -def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str: - return "%s RENAME COLUMN %s TO %s" % ( - alter_table(compiler, element.table_name, element.schema), - format_column_name(compiler, element.column_name), - format_column_name(compiler, element.newname), - ) - - # @compiles(AddColumn, 'sqlite') # def visit_add_column(element, compiler, **kw): # return "%s %s" % ( diff --git a/venv/lib/python3.12/site-packages/alembic/op.pyi b/venv/lib/python3.12/site-packages/alembic/op.pyi index 8cdf759..944b5ae 100644 --- a/venv/lib/python3.12/site-packages/alembic/op.pyi +++ b/venv/lib/python3.12/site-packages/alembic/op.pyi @@ -12,7 +12,6 @@ from typing import List from typing import Literal from typing import Mapping from typing import Optional -from typing import overload from typing import Sequence from typing import Tuple from typing import Type @@ -27,6 +26,7 @@ if TYPE_CHECKING: from sqlalchemy.sql.elements import conv from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.expression import TableClause + from sqlalchemy.sql.functions import Function from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Computed from sqlalchemy.sql.schema import Identity @@ -35,36 +35,16 @@ if TYPE_CHECKING: from sqlalchemy.sql.type_api import TypeEngine from sqlalchemy.util import immutabledict - from .operations.base import BatchOperations - from .operations.ops import AddColumnOp - from .operations.ops import AddConstraintOp - from .operations.ops import AlterColumnOp - from .operations.ops import AlterTableOp - from .operations.ops import BulkInsertOp - from .operations.ops import CreateIndexOp - from .operations.ops import CreateTableCommentOp - from .operations.ops import CreateTableOp - from .operations.ops import DropColumnOp - from .operations.ops import DropConstraintOp - from .operations.ops import DropIndexOp - from .operations.ops import DropTableCommentOp - from .operations.ops import DropTableOp - from .operations.ops import ExecuteSQLOp + from .operations.ops import BatchOperations from .operations.ops import MigrateOperation from .runtime.migration import MigrationContext from .util.sqla_compat import _literal_bindparam _T = TypeVar("_T") -_C = TypeVar("_C", bound=Callable[..., Any]) - ### end imports ### def add_column( - table_name: str, - column: Column[Any], - *, - schema: Optional[str] = None, - if_not_exists: Optional[bool] = None, + table_name: str, column: Column[Any], *, schema: Optional[str] = None ) -> None: """Issue an "add column" instruction using the current migration context. @@ -141,10 +121,6 @@ def add_column( quoting of the schema outside of the default behavior, use the SQLAlchemy construct :class:`~sqlalchemy.sql.elements.quoted_name`. - :param if_not_exists: If True, adds IF NOT EXISTS operator - when creating the new column for compatible dialects - - .. 
versionadded:: 1.16.0 """ @@ -154,14 +130,12 @@ def alter_column( *, nullable: Optional[bool] = None, comment: Union[str, Literal[False], None] = False, - server_default: Union[ - str, bool, Identity, Computed, TextClause, None - ] = False, + server_default: Any = False, new_column_name: Optional[str] = None, - type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None, - existing_type: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None, + type_: Union[TypeEngine, Type[TypeEngine], None] = None, + existing_type: Union[TypeEngine, Type[TypeEngine], None] = None, existing_server_default: Union[ - str, bool, Identity, Computed, TextClause, None + str, bool, Identity, Computed, None ] = False, existing_nullable: Optional[bool] = None, existing_comment: Optional[str] = None, @@ -256,7 +230,7 @@ def batch_alter_table( table_name: str, schema: Optional[str] = None, recreate: Literal["auto", "always", "never"] = "auto", - partial_reordering: Optional[Tuple[Any, ...]] = None, + partial_reordering: Optional[tuple] = None, copy_from: Optional[Table] = None, table_args: Tuple[Any, ...] = (), table_kwargs: Mapping[str, Any] = immutabledict({}), @@ -403,7 +377,7 @@ def batch_alter_table( def bulk_insert( table: Union[Table, TableClause], - rows: List[Dict[str, Any]], + rows: List[dict], *, multiinsert: bool = True, ) -> None: @@ -659,7 +633,7 @@ def create_foreign_key( def create_index( index_name: Optional[str], table_name: str, - columns: Sequence[Union[str, TextClause, ColumnElement[Any]]], + columns: Sequence[Union[str, TextClause, Function[Any]]], *, schema: Optional[str] = None, unique: bool = False, @@ -756,12 +730,7 @@ def create_primary_key( """ -def create_table( - table_name: str, - *columns: SchemaItem, - if_not_exists: Optional[bool] = None, - **kw: Any, -) -> Table: +def create_table(table_name: str, *columns: SchemaItem, **kw: Any) -> Table: r"""Issue a "create table" instruction using the current migration context. @@ -832,10 +801,6 @@ def create_table( quoting of the schema outside of the default behavior, use the SQLAlchemy construct :class:`~sqlalchemy.sql.elements.quoted_name`. - :param if_not_exists: If True, adds IF NOT EXISTS operator when - creating the new table. - - .. versionadded:: 1.13.3 :param \**kw: Other keyword arguments are passed to the underlying :class:`sqlalchemy.schema.Table` object created for the command. @@ -935,11 +900,6 @@ def drop_column( quoting of the schema outside of the default behavior, use the SQLAlchemy construct :class:`~sqlalchemy.sql.elements.quoted_name`. - :param if_exists: If True, adds IF EXISTS operator when - dropping the new column for compatible dialects - - .. versionadded:: 1.16.0 - :param mssql_drop_check: Optional boolean. When ``True``, on Microsoft SQL Server only, first drop the CHECK constraint on the column using a @@ -961,6 +921,7 @@ def drop_column( then exec's a separate DROP CONSTRAINT for that default. Only works if the column has exactly one FK constraint which refers to it, at the moment. + """ def drop_constraint( @@ -969,7 +930,6 @@ def drop_constraint( type_: Optional[str] = None, *, schema: Optional[str] = None, - if_exists: Optional[bool] = None, ) -> None: r"""Drop a constraint of the given name, typically via DROP CONSTRAINT. @@ -981,10 +941,6 @@ def drop_constraint( quoting of the schema outside of the default behavior, use the SQLAlchemy construct :class:`~sqlalchemy.sql.elements.quoted_name`. - :param if_exists: If True, adds IF EXISTS operator when - dropping the constraint - - .. 
versionadded:: 1.16.0 """ @@ -1025,11 +981,7 @@ def drop_index( """ def drop_table( - table_name: str, - *, - schema: Optional[str] = None, - if_exists: Optional[bool] = None, - **kw: Any, + table_name: str, *, schema: Optional[str] = None, **kw: Any ) -> None: r"""Issue a "drop table" instruction using the current migration context. @@ -1044,10 +996,6 @@ def drop_table( quoting of the schema outside of the default behavior, use the SQLAlchemy construct :class:`~sqlalchemy.sql.elements.quoted_name`. - :param if_exists: If True, adds IF EXISTS operator when - dropping the table. - - .. versionadded:: 1.13.3 :param \**kw: Other keyword arguments are passed to the underlying :class:`sqlalchemy.schema.Table` object created for the command. @@ -1184,7 +1132,7 @@ def f(name: str) -> conv: names will be converted along conventions. If the ``target_metadata`` contains the naming convention ``{"ck": "ck_bool_%(table_name)s_%(constraint_name)s"}``, then the - output of the following:: + output of the following: op.add_column("t", "x", Boolean(name="x")) @@ -1214,7 +1162,7 @@ def get_context() -> MigrationContext: """ -def implementation_for(op_cls: Any) -> Callable[[_C], _C]: +def implementation_for(op_cls: Any) -> Callable[..., Any]: """Register an implementation for a given :class:`.MigrateOperation`. This is part of the operation extensibility API. @@ -1226,7 +1174,7 @@ def implementation_for(op_cls: Any) -> Callable[[_C], _C]: """ def inline_literal( - value: Union[str, int], type_: Optional[TypeEngine[Any]] = None + value: Union[str, int], type_: Optional[TypeEngine] = None ) -> _literal_bindparam: r"""Produce an 'inline literal' expression, suitable for using in an INSERT, UPDATE, or DELETE statement. @@ -1270,27 +1218,6 @@ def inline_literal( """ -@overload -def invoke(operation: CreateTableOp) -> Table: ... -@overload -def invoke( - operation: Union[ - AddConstraintOp, - DropConstraintOp, - CreateIndexOp, - DropIndexOp, - AddColumnOp, - AlterColumnOp, - AlterTableOp, - CreateTableCommentOp, - DropTableCommentOp, - DropColumnOp, - BulkInsertOp, - DropTableOp, - ExecuteSQLOp, - ], -) -> None: ... -@overload def invoke(operation: MigrateOperation) -> Any: """Given a :class:`.MigrateOperation`, invoke it in terms of this :class:`.Operations` instance. @@ -1299,7 +1226,7 @@ def invoke(operation: MigrateOperation) -> Any: def register_operation( name: str, sourcename: Optional[str] = None -) -> Callable[[Type[_T]], Type[_T]]: +) -> Callable[[_T], _T]: """Register a new operation for this class. 
This method is normally used to add new operations diff --git a/venv/lib/python3.12/site-packages/alembic/operations/base.py b/venv/lib/python3.12/site-packages/alembic/operations/base.py index 26c3272..e3207be 100644 --- a/venv/lib/python3.12/site-packages/alembic/operations/base.py +++ b/venv/lib/python3.12/site-packages/alembic/operations/base.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-calls - from __future__ import annotations from contextlib import contextmanager @@ -12,9 +10,7 @@ from typing import Dict from typing import Iterator from typing import List # noqa from typing import Mapping -from typing import NoReturn from typing import Optional -from typing import overload from typing import Sequence # noqa from typing import Tuple from typing import Type # noqa @@ -43,6 +39,7 @@ if TYPE_CHECKING: from sqlalchemy.sql.expression import ColumnElement from sqlalchemy.sql.expression import TableClause from sqlalchemy.sql.expression import TextClause + from sqlalchemy.sql.functions import Function from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Computed from sqlalchemy.sql.schema import Identity @@ -50,28 +47,12 @@ if TYPE_CHECKING: from sqlalchemy.types import TypeEngine from .batch import BatchOperationsImpl - from .ops import AddColumnOp - from .ops import AddConstraintOp - from .ops import AlterColumnOp - from .ops import AlterTableOp - from .ops import BulkInsertOp - from .ops import CreateIndexOp - from .ops import CreateTableCommentOp - from .ops import CreateTableOp - from .ops import DropColumnOp - from .ops import DropConstraintOp - from .ops import DropIndexOp - from .ops import DropTableCommentOp - from .ops import DropTableOp - from .ops import ExecuteSQLOp from .ops import MigrateOperation from ..ddl import DefaultImpl from ..runtime.migration import MigrationContext __all__ = ("Operations", "BatchOperations") _T = TypeVar("_T") -_C = TypeVar("_C", bound=Callable[..., Any]) - class AbstractOperations(util.ModuleClsProxy): """Base class for Operations and BatchOperations. @@ -105,7 +86,7 @@ class AbstractOperations(util.ModuleClsProxy): @classmethod def register_operation( cls, name: str, sourcename: Optional[str] = None - ) -> Callable[[Type[_T]], Type[_T]]: + ) -> Callable[[_T], _T]: """Register a new operation for this class. This method is normally used to add new operations @@ -122,7 +103,7 @@ class AbstractOperations(util.ModuleClsProxy): """ - def register(op_cls: Type[_T]) -> Type[_T]: + def register(op_cls): if sourcename is None: fn = getattr(op_cls, name) source_name = fn.__name__ @@ -141,11 +122,8 @@ class AbstractOperations(util.ModuleClsProxy): *spec, formatannotation=formatannotation_fwdref ) num_defaults = len(spec[3]) if spec[3] else 0 - - defaulted_vals: Tuple[Any, ...] - if num_defaults: - defaulted_vals = tuple(name_args[0 - num_defaults :]) + defaulted_vals = name_args[0 - num_defaults :] else: defaulted_vals = () @@ -186,7 +164,7 @@ class AbstractOperations(util.ModuleClsProxy): globals_ = dict(globals()) globals_.update({"op_cls": op_cls}) - lcl: Dict[str, Any] = {} + lcl = {} exec(func_text, globals_, lcl) setattr(cls, name, lcl[name]) @@ -202,7 +180,7 @@ class AbstractOperations(util.ModuleClsProxy): return register @classmethod - def implementation_for(cls, op_cls: Any) -> Callable[[_C], _C]: + def implementation_for(cls, op_cls: Any) -> Callable[..., Any]: """Register an implementation for a given :class:`.MigrateOperation`. This is part of the operation extensibility API. 
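The register_operation() and implementation_for() hooks touched in the op.pyi and base.py hunks are Alembic's operation-plugin API. A minimal custom operation, following the pattern from the Alembic documentation, would look roughly like the sketch below; the CreateSequenceOp name and the raw SQL are illustrative and not part of this diff:

    from alembic.operations import MigrateOperation, Operations

    @Operations.register_operation("create_sequence")
    class CreateSequenceOp(MigrateOperation):
        """Hypothetical operation used only to illustrate the plugin API."""

        def __init__(self, sequence_name, schema=None):
            self.sequence_name = sequence_name
            self.schema = schema

        @classmethod
        def create_sequence(cls, operations, sequence_name, **kw):
            # Exposed to migration scripts as op.create_sequence(...)
            return operations.invoke(CreateSequenceOp(sequence_name, **kw))

    @Operations.implementation_for(CreateSequenceOp)
    def create_sequence(operations, operation):
        # Dispatched whenever the operation above is invoked.
        operations.execute("CREATE SEQUENCE %s" % operation.sequence_name)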
@@ -213,7 +191,7 @@ class AbstractOperations(util.ModuleClsProxy): """ - def decorate(fn: _C) -> _C: + def decorate(fn): cls._to_impl.dispatch_for(op_cls)(fn) return fn @@ -235,7 +213,7 @@ class AbstractOperations(util.ModuleClsProxy): table_name: str, schema: Optional[str] = None, recreate: Literal["auto", "always", "never"] = "auto", - partial_reordering: Optional[Tuple[Any, ...]] = None, + partial_reordering: Optional[tuple] = None, copy_from: Optional[Table] = None, table_args: Tuple[Any, ...] = (), table_kwargs: Mapping[str, Any] = util.immutabledict(), @@ -404,32 +382,6 @@ class AbstractOperations(util.ModuleClsProxy): return self.migration_context - @overload - def invoke(self, operation: CreateTableOp) -> Table: ... - - @overload - def invoke( - self, - operation: Union[ - AddConstraintOp, - DropConstraintOp, - CreateIndexOp, - DropIndexOp, - AddColumnOp, - AlterColumnOp, - AlterTableOp, - CreateTableCommentOp, - DropTableCommentOp, - DropColumnOp, - BulkInsertOp, - DropTableOp, - ExecuteSQLOp, - ], - ) -> None: ... - - @overload - def invoke(self, operation: MigrateOperation) -> Any: ... - def invoke(self, operation: MigrateOperation) -> Any: """Given a :class:`.MigrateOperation`, invoke it in terms of this :class:`.Operations` instance. @@ -464,7 +416,7 @@ class AbstractOperations(util.ModuleClsProxy): names will be converted along conventions. If the ``target_metadata`` contains the naming convention ``{"ck": "ck_bool_%(table_name)s_%(constraint_name)s"}``, then the - output of the following:: + output of the following: op.add_column("t", "x", Boolean(name="x")) @@ -618,7 +570,6 @@ class Operations(AbstractOperations): column: Column[Any], *, schema: Optional[str] = None, - if_not_exists: Optional[bool] = None, ) -> None: """Issue an "add column" instruction using the current migration context. @@ -695,10 +646,6 @@ class Operations(AbstractOperations): quoting of the schema outside of the default behavior, use the SQLAlchemy construct :class:`~sqlalchemy.sql.elements.quoted_name`. - :param if_not_exists: If True, adds IF NOT EXISTS operator - when creating the new column for compatible dialects - - .. versionadded:: 1.16.0 """ # noqa: E501 ... @@ -710,16 +657,12 @@ class Operations(AbstractOperations): *, nullable: Optional[bool] = None, comment: Union[str, Literal[False], None] = False, - server_default: Union[ - str, bool, Identity, Computed, TextClause, None - ] = False, + server_default: Any = False, new_column_name: Optional[str] = None, - type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None, - existing_type: Union[ - TypeEngine[Any], Type[TypeEngine[Any]], None - ] = None, + type_: Union[TypeEngine, Type[TypeEngine], None] = None, + existing_type: Union[TypeEngine, Type[TypeEngine], None] = None, existing_server_default: Union[ - str, bool, Identity, Computed, TextClause, None + str, bool, Identity, Computed, None ] = False, existing_nullable: Optional[bool] = None, existing_comment: Optional[str] = None, @@ -813,7 +756,7 @@ class Operations(AbstractOperations): def bulk_insert( self, table: Union[Table, TableClause], - rows: List[Dict[str, Any]], + rows: List[dict], *, multiinsert: bool = True, ) -> None: @@ -1080,7 +1023,7 @@ class Operations(AbstractOperations): self, index_name: Optional[str], table_name: str, - columns: Sequence[Union[str, TextClause, ColumnElement[Any]]], + columns: Sequence[Union[str, TextClause, Function[Any]]], *, schema: Optional[str] = None, unique: bool = False, @@ -1181,11 +1124,7 @@ class Operations(AbstractOperations): ... 
def create_table( - self, - table_name: str, - *columns: SchemaItem, - if_not_exists: Optional[bool] = None, - **kw: Any, + self, table_name: str, *columns: SchemaItem, **kw: Any ) -> Table: r"""Issue a "create table" instruction using the current migration context. @@ -1257,10 +1196,6 @@ class Operations(AbstractOperations): quoting of the schema outside of the default behavior, use the SQLAlchemy construct :class:`~sqlalchemy.sql.elements.quoted_name`. - :param if_not_exists: If True, adds IF NOT EXISTS operator when - creating the new table. - - .. versionadded:: 1.13.3 :param \**kw: Other keyword arguments are passed to the underlying :class:`sqlalchemy.schema.Table` object created for the command. @@ -1366,11 +1301,6 @@ class Operations(AbstractOperations): quoting of the schema outside of the default behavior, use the SQLAlchemy construct :class:`~sqlalchemy.sql.elements.quoted_name`. - :param if_exists: If True, adds IF EXISTS operator when - dropping the new column for compatible dialects - - .. versionadded:: 1.16.0 - :param mssql_drop_check: Optional boolean. When ``True``, on Microsoft SQL Server only, first drop the CHECK constraint on the column using a @@ -1392,6 +1322,7 @@ class Operations(AbstractOperations): then exec's a separate DROP CONSTRAINT for that default. Only works if the column has exactly one FK constraint which refers to it, at the moment. + """ # noqa: E501 ... @@ -1402,7 +1333,6 @@ class Operations(AbstractOperations): type_: Optional[str] = None, *, schema: Optional[str] = None, - if_exists: Optional[bool] = None, ) -> None: r"""Drop a constraint of the given name, typically via DROP CONSTRAINT. @@ -1414,10 +1344,6 @@ class Operations(AbstractOperations): quoting of the schema outside of the default behavior, use the SQLAlchemy construct :class:`~sqlalchemy.sql.elements.quoted_name`. - :param if_exists: If True, adds IF EXISTS operator when - dropping the constraint - - .. versionadded:: 1.16.0 """ # noqa: E501 ... @@ -1461,12 +1387,7 @@ class Operations(AbstractOperations): ... def drop_table( - self, - table_name: str, - *, - schema: Optional[str] = None, - if_exists: Optional[bool] = None, - **kw: Any, + self, table_name: str, *, schema: Optional[str] = None, **kw: Any ) -> None: r"""Issue a "drop table" instruction using the current migration context. @@ -1481,10 +1402,6 @@ class Operations(AbstractOperations): quoting of the schema outside of the default behavior, use the SQLAlchemy construct :class:`~sqlalchemy.sql.elements.quoted_name`. - :param if_exists: If True, adds IF EXISTS operator when - dropping the table. - - .. versionadded:: 1.13.3 :param \**kw: Other keyword arguments are passed to the underlying :class:`sqlalchemy.schema.Table` object created for the command. @@ -1643,7 +1560,7 @@ class BatchOperations(AbstractOperations): impl: BatchOperationsImpl - def _noop(self, operation: Any) -> NoReturn: + def _noop(self, operation): raise NotImplementedError( "The %s method does not apply to a batch table alter operation." % operation @@ -1660,7 +1577,6 @@ class BatchOperations(AbstractOperations): *, insert_before: Optional[str] = None, insert_after: Optional[str] = None, - if_not_exists: Optional[bool] = None, ) -> None: """Issue an "add column" instruction using the current batch migration context. 
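The BatchOperations hunks above (and the earlier SQLite ones that point to the "copy-and-move strategy") are exercised through op.batch_alter_table(). A typical batch block in a migration script looks roughly like the following; the table and column names are made up for illustration:

    import sqlalchemy as sa
    from alembic import op

    # Directives recorded inside the block are applied together; on dialects
    # that cannot ALTER in place (e.g. SQLite) they go through the recreate
    # path implemented in operations/batch.py.
    with op.batch_alter_table("example_table", recreate="auto") as batch_op:
        batch_op.add_column(
            sa.Column("status", sa.String(length=20), nullable=True)
        )
        batch_op.alter_column(
            "title", existing_type=sa.String(length=255), nullable=False
        )
        batch_op.drop_column("legacy_flag")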
@@ -1680,10 +1596,8 @@ class BatchOperations(AbstractOperations): comment: Union[str, Literal[False], None] = False, server_default: Any = False, new_column_name: Optional[str] = None, - type_: Union[TypeEngine[Any], Type[TypeEngine[Any]], None] = None, - existing_type: Union[ - TypeEngine[Any], Type[TypeEngine[Any]], None - ] = None, + type_: Union[TypeEngine, Type[TypeEngine], None] = None, + existing_type: Union[TypeEngine, Type[TypeEngine], None] = None, existing_server_default: Union[ str, bool, Identity, Computed, None ] = False, @@ -1738,7 +1652,7 @@ class BatchOperations(AbstractOperations): def create_exclude_constraint( self, constraint_name: str, *elements: Any, **kw: Any - ) -> Optional[Table]: + ): """Issue a "create exclude constraint" instruction using the current batch migration context. @@ -1754,7 +1668,7 @@ class BatchOperations(AbstractOperations): def create_foreign_key( self, - constraint_name: Optional[str], + constraint_name: str, referent_table: str, local_cols: List[str], remote_cols: List[str], @@ -1804,7 +1718,7 @@ class BatchOperations(AbstractOperations): ... def create_primary_key( - self, constraint_name: Optional[str], columns: List[str] + self, constraint_name: str, columns: List[str] ) -> None: """Issue a "create primary key" instruction using the current batch migration context. diff --git a/venv/lib/python3.12/site-packages/alembic/operations/batch.py b/venv/lib/python3.12/site-packages/alembic/operations/batch.py index fe183e9..8c88e88 100644 --- a/venv/lib/python3.12/site-packages/alembic/operations/batch.py +++ b/venv/lib/python3.12/site-packages/alembic/operations/batch.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls -# mypy: no-warn-return-any, allow-any-generics - from __future__ import annotations from typing import Any @@ -18,10 +15,9 @@ from sqlalchemy import Index from sqlalchemy import MetaData from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import schema as sql_schema -from sqlalchemy import select from sqlalchemy import Table from sqlalchemy import types as sqltypes -from sqlalchemy.sql.schema import SchemaEventTarget +from sqlalchemy.events import SchemaEventTarget from sqlalchemy.util import OrderedDict from sqlalchemy.util import topological @@ -32,9 +28,11 @@ from ..util.sqla_compat import _copy_expression from ..util.sqla_compat import _ensure_scope_for_ddl from ..util.sqla_compat import _fk_is_self_referential from ..util.sqla_compat import _idx_table_bound_expressions +from ..util.sqla_compat import _insert_inline from ..util.sqla_compat import _is_type_bound from ..util.sqla_compat import _remove_column_from_collection from ..util.sqla_compat import _resolve_for_variant +from ..util.sqla_compat import _select from ..util.sqla_compat import constraint_name_defined from ..util.sqla_compat import constraint_name_string @@ -376,7 +374,7 @@ class ApplyBatchImpl: for idx_existing in self.indexes.values(): # this is a lift-and-move from Table.to_metadata - if idx_existing._column_flag: + if idx_existing._column_flag: # type: ignore continue idx_copy = Index( @@ -405,7 +403,9 @@ class ApplyBatchImpl: def _setup_referent( self, metadata: MetaData, constraint: ForeignKeyConstraint ) -> None: - spec = constraint.elements[0]._get_colspec() + spec = constraint.elements[ + 0 + ]._get_colspec() # type:ignore[attr-defined] parts = spec.split(".") tname = parts[-2] if len(parts) == 3: @@ -448,15 +448,13 @@ class ApplyBatchImpl: try: op_impl._exec( - self.new_table.insert() - .inline() - 
.from_select( + _insert_inline(self.new_table).from_select( list( k for k, transfer in self.column_transfers.items() if "expr" in transfer ), - select( + _select( *[ transfer["expr"] for transfer in self.column_transfers.values() @@ -548,7 +546,9 @@ class ApplyBatchImpl: else: sql_schema.DefaultClause( server_default # type: ignore[arg-type] - )._set_parent(existing) + )._set_parent( # type:ignore[attr-defined] + existing + ) if autoincrement is not None: existing.autoincrement = bool(autoincrement) diff --git a/venv/lib/python3.12/site-packages/alembic/operations/ops.py b/venv/lib/python3.12/site-packages/alembic/operations/ops.py index c9b1526..711d7ab 100644 --- a/venv/lib/python3.12/site-packages/alembic/operations/ops.py +++ b/venv/lib/python3.12/site-packages/alembic/operations/ops.py @@ -1,13 +1,10 @@ from __future__ import annotations from abc import abstractmethod -import os -import pathlib import re from typing import Any from typing import Callable from typing import cast -from typing import Dict from typing import FrozenSet from typing import Iterator from typing import List @@ -18,7 +15,6 @@ from typing import Set from typing import Tuple from typing import Type from typing import TYPE_CHECKING -from typing import TypeVar from typing import Union from sqlalchemy.types import NULLTYPE @@ -37,6 +33,7 @@ if TYPE_CHECKING: from sqlalchemy.sql.elements import conv from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.elements import TextClause + from sqlalchemy.sql.functions import Function from sqlalchemy.sql.schema import CheckConstraint from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Computed @@ -56,9 +53,6 @@ if TYPE_CHECKING: from ..runtime.migration import MigrationContext from ..script.revision import _RevIdType -_T = TypeVar("_T", bound=Any) -_AC = TypeVar("_AC", bound="AddConstraintOp") - class MigrateOperation: """base class for migration command and organization objects. @@ -76,7 +70,7 @@ class MigrateOperation: """ @util.memoized_property - def info(self) -> Dict[Any, Any]: + def info(self): """A dictionary that may be used to store arbitrary information along with this :class:`.MigrateOperation` object. 
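For context on the ApplyBatchImpl hunk just above: the batch recreate copies rows from the existing table into the freshly built replacement with an INSERT ... FROM SELECT. A rough, simplified sketch of that copy step, ignoring the column_transfers expression handling and the sqla_compat shims the real code goes through, might look like:

    from sqlalchemy import Table, select

    def copy_rows(conn, old_table: Table, new_table: Table) -> None:
        # Copy every column present in both tables; columns that exist only
        # in the replacement table are left to their defaults.
        common = [c.name for c in new_table.c if c.name in old_table.c]
        conn.execute(
            new_table.insert().from_select(
                common,
                select(*[old_table.c[name] for name in common]),
            )
        )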
@@ -98,14 +92,12 @@ class AddConstraintOp(MigrateOperation): add_constraint_ops = util.Dispatcher() @property - def constraint_type(self) -> str: + def constraint_type(self): raise NotImplementedError() @classmethod - def register_add_constraint( - cls, type_: str - ) -> Callable[[Type[_AC]], Type[_AC]]: - def go(klass: Type[_AC]) -> Type[_AC]: + def register_add_constraint(cls, type_: str) -> Callable: + def go(klass): cls.add_constraint_ops.dispatch_for(type_)(klass.from_constraint) return klass @@ -113,7 +105,7 @@ class AddConstraintOp(MigrateOperation): @classmethod def from_constraint(cls, constraint: Constraint) -> AddConstraintOp: - return cls.add_constraint_ops.dispatch(constraint.__visit_name__)( # type: ignore[no-any-return] # noqa: E501 + return cls.add_constraint_ops.dispatch(constraint.__visit_name__)( constraint ) @@ -142,14 +134,12 @@ class DropConstraintOp(MigrateOperation): type_: Optional[str] = None, *, schema: Optional[str] = None, - if_exists: Optional[bool] = None, _reverse: Optional[AddConstraintOp] = None, ) -> None: self.constraint_name = constraint_name self.table_name = table_name self.constraint_type = type_ self.schema = schema - self.if_exists = if_exists self._reverse = _reverse def reverse(self) -> AddConstraintOp: @@ -207,7 +197,6 @@ class DropConstraintOp(MigrateOperation): type_: Optional[str] = None, *, schema: Optional[str] = None, - if_exists: Optional[bool] = None, ) -> None: r"""Drop a constraint of the given name, typically via DROP CONSTRAINT. @@ -219,20 +208,10 @@ class DropConstraintOp(MigrateOperation): quoting of the schema outside of the default behavior, use the SQLAlchemy construct :class:`~sqlalchemy.sql.elements.quoted_name`. - :param if_exists: If True, adds IF EXISTS operator when - dropping the constraint - - .. 
versionadded:: 1.16.0 """ - op = cls( - constraint_name, - table_name, - type_=type_, - schema=schema, - if_exists=if_exists, - ) + op = cls(constraint_name, table_name, type_=type_, schema=schema) return operations.invoke(op) @classmethod @@ -363,7 +342,7 @@ class CreatePrimaryKeyOp(AddConstraintOp): def batch_create_primary_key( cls, operations: BatchOperations, - constraint_name: Optional[str], + constraint_name: str, columns: List[str], ) -> None: """Issue a "create primary key" instruction using the @@ -419,7 +398,7 @@ class CreateUniqueConstraintOp(AddConstraintOp): uq_constraint = cast("UniqueConstraint", constraint) - kw: Dict[str, Any] = {} + kw: dict = {} if uq_constraint.deferrable: kw["deferrable"] = uq_constraint.deferrable if uq_constraint.initially: @@ -553,7 +532,7 @@ class CreateForeignKeyOp(AddConstraintOp): @classmethod def from_constraint(cls, constraint: Constraint) -> CreateForeignKeyOp: fk_constraint = cast("ForeignKeyConstraint", constraint) - kw: Dict[str, Any] = {} + kw: dict = {} if fk_constraint.onupdate: kw["onupdate"] = fk_constraint.onupdate if fk_constraint.ondelete: @@ -695,7 +674,7 @@ class CreateForeignKeyOp(AddConstraintOp): def batch_create_foreign_key( cls, operations: BatchOperations, - constraint_name: Optional[str], + constraint_name: str, referent_table: str, local_cols: List[str], remote_cols: List[str], @@ -918,9 +897,9 @@ class CreateIndexOp(MigrateOperation): def from_index(cls, index: Index) -> CreateIndexOp: assert index.table is not None return cls( - index.name, + index.name, # type: ignore[arg-type] index.table.name, - index.expressions, + sqla_compat._get_index_expressions(index), schema=index.table.schema, unique=index.unique, **index.kwargs, @@ -947,7 +926,7 @@ class CreateIndexOp(MigrateOperation): operations: Operations, index_name: Optional[str], table_name: str, - columns: Sequence[Union[str, TextClause, ColumnElement[Any]]], + columns: Sequence[Union[str, TextClause, Function[Any]]], *, schema: Optional[str] = None, unique: bool = False, @@ -1075,7 +1054,6 @@ class DropIndexOp(MigrateOperation): table_name=index.table.name, schema=index.table.schema, _reverse=CreateIndexOp.from_index(index), - unique=index.unique, **index.kwargs, ) @@ -1173,7 +1151,6 @@ class CreateTableOp(MigrateOperation): columns: Sequence[SchemaItem], *, schema: Optional[str] = None, - if_not_exists: Optional[bool] = None, _namespace_metadata: Optional[MetaData] = None, _constraints_included: bool = False, **kw: Any, @@ -1181,7 +1158,6 @@ class CreateTableOp(MigrateOperation): self.table_name = table_name self.columns = columns self.schema = schema - self.if_not_exists = if_not_exists self.info = kw.pop("info", {}) self.comment = kw.pop("comment", None) self.prefixes = kw.pop("prefixes", None) @@ -1206,7 +1182,7 @@ class CreateTableOp(MigrateOperation): return cls( table.name, - list(table.c) + list(table.constraints), + list(table.c) + list(table.constraints), # type:ignore[arg-type] schema=table.schema, _namespace_metadata=_namespace_metadata, # given a Table() object, this Table will contain full Index() @@ -1244,7 +1220,6 @@ class CreateTableOp(MigrateOperation): operations: Operations, table_name: str, *columns: SchemaItem, - if_not_exists: Optional[bool] = None, **kw: Any, ) -> Table: r"""Issue a "create table" instruction using the current migration @@ -1317,10 +1292,6 @@ class CreateTableOp(MigrateOperation): quoting of the schema outside of the default behavior, use the SQLAlchemy construct :class:`~sqlalchemy.sql.elements.quoted_name`. 
- :param if_not_exists: If True, adds IF NOT EXISTS operator when - creating the new table. - - .. versionadded:: 1.13.3 :param \**kw: Other keyword arguments are passed to the underlying :class:`sqlalchemy.schema.Table` object created for the command. @@ -1328,7 +1299,7 @@ class CreateTableOp(MigrateOperation): to the parameters given. """ - op = cls(table_name, columns, if_not_exists=if_not_exists, **kw) + op = cls(table_name, columns, **kw) return operations.invoke(op) @@ -1341,13 +1312,11 @@ class DropTableOp(MigrateOperation): table_name: str, *, schema: Optional[str] = None, - if_exists: Optional[bool] = None, table_kw: Optional[MutableMapping[Any, Any]] = None, _reverse: Optional[CreateTableOp] = None, ) -> None: self.table_name = table_name self.schema = schema - self.if_exists = if_exists self.table_kw = table_kw or {} self.comment = self.table_kw.pop("comment", None) self.info = self.table_kw.pop("info", None) @@ -1394,9 +1363,9 @@ class DropTableOp(MigrateOperation): info=self.info.copy() if self.info else {}, prefixes=list(self.prefixes) if self.prefixes else [], schema=self.schema, - _constraints_included=( - self._reverse._constraints_included if self._reverse else False - ), + _constraints_included=self._reverse._constraints_included + if self._reverse + else False, **self.table_kw, ) return t @@ -1408,7 +1377,6 @@ class DropTableOp(MigrateOperation): table_name: str, *, schema: Optional[str] = None, - if_exists: Optional[bool] = None, **kw: Any, ) -> None: r"""Issue a "drop table" instruction using the current @@ -1424,15 +1392,11 @@ class DropTableOp(MigrateOperation): quoting of the schema outside of the default behavior, use the SQLAlchemy construct :class:`~sqlalchemy.sql.elements.quoted_name`. - :param if_exists: If True, adds IF EXISTS operator when - dropping the table. - - .. versionadded:: 1.13.3 :param \**kw: Other keyword arguments are passed to the underlying :class:`sqlalchemy.schema.Table` object created for the command. 
""" - op = cls(table_name, schema=schema, if_exists=if_exists, table_kw=kw) + op = cls(table_name, schema=schema, table_kw=kw) operations.invoke(op) @@ -1570,7 +1534,7 @@ class CreateTableCommentOp(AlterTableOp): ) return operations.invoke(op) - def reverse(self) -> Union[CreateTableCommentOp, DropTableCommentOp]: + def reverse(self): """Reverses the COMMENT ON operation against a table.""" if self.existing_comment is None: return DropTableCommentOp( @@ -1586,16 +1550,14 @@ class CreateTableCommentOp(AlterTableOp): schema=self.schema, ) - def to_table( - self, migration_context: Optional[MigrationContext] = None - ) -> Table: + def to_table(self, migration_context=None): schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.table( self.table_name, schema=self.schema, comment=self.comment ) - def to_diff_tuple(self) -> Tuple[Any, ...]: + def to_diff_tuple(self): return ("add_table_comment", self.to_table(), self.existing_comment) @@ -1667,20 +1629,18 @@ class DropTableCommentOp(AlterTableOp): ) return operations.invoke(op) - def reverse(self) -> CreateTableCommentOp: + def reverse(self): """Reverses the COMMENT ON operation against a table.""" return CreateTableCommentOp( self.table_name, self.existing_comment, schema=self.schema ) - def to_table( - self, migration_context: Optional[MigrationContext] = None - ) -> Table: + def to_table(self, migration_context=None): schema_obj = schemaobj.SchemaObjects(migration_context) return schema_obj.table(self.table_name, schema=self.schema) - def to_diff_tuple(self) -> Tuple[Any, ...]: + def to_diff_tuple(self): return ("remove_table_comment", self.to_table()) @@ -1855,16 +1815,12 @@ class AlterColumnOp(AlterTableOp): *, nullable: Optional[bool] = None, comment: Optional[Union[str, Literal[False]]] = False, - server_default: Union[ - str, bool, Identity, Computed, TextClause, None - ] = False, + server_default: Any = False, new_column_name: Optional[str] = None, - type_: Optional[Union[TypeEngine[Any], Type[TypeEngine[Any]]]] = None, - existing_type: Optional[ - Union[TypeEngine[Any], Type[TypeEngine[Any]]] - ] = None, - existing_server_default: Union[ - str, bool, Identity, Computed, TextClause, None + type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None, + existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None, + existing_server_default: Optional[ + Union[str, bool, Identity, Computed] ] = False, existing_nullable: Optional[bool] = None, existing_comment: Optional[str] = None, @@ -1982,10 +1938,8 @@ class AlterColumnOp(AlterTableOp): comment: Optional[Union[str, Literal[False]]] = False, server_default: Any = False, new_column_name: Optional[str] = None, - type_: Optional[Union[TypeEngine[Any], Type[TypeEngine[Any]]]] = None, - existing_type: Optional[ - Union[TypeEngine[Any], Type[TypeEngine[Any]]] - ] = None, + type_: Optional[Union[TypeEngine, Type[TypeEngine]]] = None, + existing_type: Optional[Union[TypeEngine, Type[TypeEngine]]] = None, existing_server_default: Optional[ Union[str, bool, Identity, Computed] ] = False, @@ -2049,31 +2003,27 @@ class AddColumnOp(AlterTableOp): column: Column[Any], *, schema: Optional[str] = None, - if_not_exists: Optional[bool] = None, **kw: Any, ) -> None: super().__init__(table_name, schema=schema) self.column = column - self.if_not_exists = if_not_exists self.kw = kw def reverse(self) -> DropColumnOp: - op = DropColumnOp.from_column_and_tablename( + return DropColumnOp.from_column_and_tablename( self.schema, self.table_name, self.column ) - op.if_exists = 
self.if_not_exists - return op def to_diff_tuple( self, ) -> Tuple[str, Optional[str], str, Column[Any]]: return ("add_column", self.schema, self.table_name, self.column) - def to_column(self) -> Column[Any]: + def to_column(self) -> Column: return self.column @classmethod - def from_column(cls, col: Column[Any]) -> AddColumnOp: + def from_column(cls, col: Column) -> AddColumnOp: return cls(col.table.name, col, schema=col.table.schema) @classmethod @@ -2093,7 +2043,6 @@ class AddColumnOp(AlterTableOp): column: Column[Any], *, schema: Optional[str] = None, - if_not_exists: Optional[bool] = None, ) -> None: """Issue an "add column" instruction using the current migration context. @@ -2170,19 +2119,10 @@ class AddColumnOp(AlterTableOp): quoting of the schema outside of the default behavior, use the SQLAlchemy construct :class:`~sqlalchemy.sql.elements.quoted_name`. - :param if_not_exists: If True, adds IF NOT EXISTS operator - when creating the new column for compatible dialects - - .. versionadded:: 1.16.0 """ - op = cls( - table_name, - column, - schema=schema, - if_not_exists=if_not_exists, - ) + op = cls(table_name, column, schema=schema) return operations.invoke(op) @classmethod @@ -2193,7 +2133,6 @@ class AddColumnOp(AlterTableOp): *, insert_before: Optional[str] = None, insert_after: Optional[str] = None, - if_not_exists: Optional[bool] = None, ) -> None: """Issue an "add column" instruction using the current batch migration context. @@ -2214,7 +2153,6 @@ class AddColumnOp(AlterTableOp): operations.impl.table_name, column, schema=operations.impl.schema, - if_not_exists=if_not_exists, **kw, ) return operations.invoke(op) @@ -2231,14 +2169,12 @@ class DropColumnOp(AlterTableOp): column_name: str, *, schema: Optional[str] = None, - if_exists: Optional[bool] = None, _reverse: Optional[AddColumnOp] = None, **kw: Any, ) -> None: super().__init__(table_name, schema=schema) self.column_name = column_name self.kw = kw - self.if_exists = if_exists self._reverse = _reverse def to_diff_tuple( @@ -2258,11 +2194,9 @@ class DropColumnOp(AlterTableOp): "original column is not present" ) - op = AddColumnOp.from_column_and_tablename( + return AddColumnOp.from_column_and_tablename( self.schema, self.table_name, self._reverse.column ) - op.if_not_exists = self.if_exists - return op @classmethod def from_column_and_tablename( @@ -2280,7 +2214,7 @@ class DropColumnOp(AlterTableOp): def to_column( self, migration_context: Optional[MigrationContext] = None - ) -> Column[Any]: + ) -> Column: if self._reverse is not None: return self._reverse.column schema_obj = schemaobj.SchemaObjects(migration_context) @@ -2309,11 +2243,6 @@ class DropColumnOp(AlterTableOp): quoting of the schema outside of the default behavior, use the SQLAlchemy construct :class:`~sqlalchemy.sql.elements.quoted_name`. - :param if_exists: If True, adds IF EXISTS operator when - dropping the new column for compatible dialects - - .. versionadded:: 1.16.0 - :param mssql_drop_check: Optional boolean. When ``True``, on Microsoft SQL Server only, first drop the CHECK constraint on the column using a @@ -2335,6 +2264,7 @@ class DropColumnOp(AlterTableOp): then exec's a separate DROP CONSTRAINT for that default. Only works if the column has exactly one FK constraint which refers to it, at the moment. 
+ """ op = cls(table_name, column_name, schema=schema, **kw) @@ -2368,7 +2298,7 @@ class BulkInsertOp(MigrateOperation): def __init__( self, table: Union[Table, TableClause], - rows: List[Dict[str, Any]], + rows: List[dict], *, multiinsert: bool = True, ) -> None: @@ -2381,7 +2311,7 @@ class BulkInsertOp(MigrateOperation): cls, operations: Operations, table: Union[Table, TableClause], - rows: List[Dict[str, Any]], + rows: List[dict], *, multiinsert: bool = True, ) -> None: @@ -2677,7 +2607,7 @@ class UpgradeOps(OpContainer): self.upgrade_token = upgrade_token def reverse_into(self, downgrade_ops: DowngradeOps) -> DowngradeOps: - downgrade_ops.ops[:] = list( + downgrade_ops.ops[:] = list( # type:ignore[index] reversed([op.reverse() for op in self.ops]) ) return downgrade_ops @@ -2704,7 +2634,7 @@ class DowngradeOps(OpContainer): super().__init__(ops=ops) self.downgrade_token = downgrade_token - def reverse(self) -> UpgradeOps: + def reverse(self): return UpgradeOps( ops=list(reversed([op.reverse() for op in self.ops])) ) @@ -2735,8 +2665,6 @@ class MigrationScript(MigrateOperation): """ _needs_render: Optional[bool] - _upgrade_ops: List[UpgradeOps] - _downgrade_ops: List[DowngradeOps] def __init__( self, @@ -2749,7 +2677,7 @@ class MigrationScript(MigrateOperation): head: Optional[str] = None, splice: Optional[bool] = None, branch_label: Optional[_RevIdType] = None, - version_path: Union[str, os.PathLike[str], None] = None, + version_path: Optional[str] = None, depends_on: Optional[_RevIdType] = None, ) -> None: self.rev_id = rev_id @@ -2758,15 +2686,13 @@ class MigrationScript(MigrateOperation): self.head = head self.splice = splice self.branch_label = branch_label - self.version_path = ( - pathlib.Path(version_path).as_posix() if version_path else None - ) + self.version_path = version_path self.depends_on = depends_on self.upgrade_ops = upgrade_ops self.downgrade_ops = downgrade_ops @property - def upgrade_ops(self) -> Optional[UpgradeOps]: + def upgrade_ops(self): """An instance of :class:`.UpgradeOps`. .. seealso:: @@ -2785,15 +2711,13 @@ class MigrationScript(MigrateOperation): return self._upgrade_ops[0] @upgrade_ops.setter - def upgrade_ops( - self, upgrade_ops: Union[UpgradeOps, List[UpgradeOps]] - ) -> None: + def upgrade_ops(self, upgrade_ops): self._upgrade_ops = util.to_list(upgrade_ops) for elem in self._upgrade_ops: assert isinstance(elem, UpgradeOps) @property - def downgrade_ops(self) -> Optional[DowngradeOps]: + def downgrade_ops(self): """An instance of :class:`.DowngradeOps`. .. 
seealso:: @@ -2812,9 +2736,7 @@ class MigrationScript(MigrateOperation): return self._downgrade_ops[0] @downgrade_ops.setter - def downgrade_ops( - self, downgrade_ops: Union[DowngradeOps, List[DowngradeOps]] - ) -> None: + def downgrade_ops(self, downgrade_ops): self._downgrade_ops = util.to_list(downgrade_ops) for elem in self._downgrade_ops: assert isinstance(elem, DowngradeOps) diff --git a/venv/lib/python3.12/site-packages/alembic/operations/schemaobj.py b/venv/lib/python3.12/site-packages/alembic/operations/schemaobj.py index 59c1002..799f113 100644 --- a/venv/lib/python3.12/site-packages/alembic/operations/schemaobj.py +++ b/venv/lib/python3.12/site-packages/alembic/operations/schemaobj.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls -# mypy: no-warn-return-any, allow-any-generics - from __future__ import annotations from typing import Any @@ -223,12 +220,10 @@ class SchemaObjects: t = sa_schema.Table(name, m, *cols, **kw) constraints = [ - ( - sqla_compat._copy(elem, target_table=t) - if getattr(elem, "parent", None) is not t - and getattr(elem, "parent", None) is not None - else elem - ) + sqla_compat._copy(elem, target_table=t) + if getattr(elem, "parent", None) is not t + and getattr(elem, "parent", None) is not None + else elem for elem in columns if isinstance(elem, (Constraint, Index)) ] @@ -279,8 +274,10 @@ class SchemaObjects: ForeignKey. """ - if isinstance(fk._colspec, str): - table_key, cname = fk._colspec.rsplit(".", 1) + if isinstance(fk._colspec, str): # type:ignore[attr-defined] + table_key, cname = fk._colspec.rsplit( # type:ignore[attr-defined] + ".", 1 + ) sname, tname = self._parse_table_key(table_key) if table_key not in metadata.tables: rel_t = sa_schema.Table(tname, metadata, schema=sname) diff --git a/venv/lib/python3.12/site-packages/alembic/operations/toimpl.py b/venv/lib/python3.12/site-packages/alembic/operations/toimpl.py index c18ec79..ba974b6 100644 --- a/venv/lib/python3.12/site-packages/alembic/operations/toimpl.py +++ b/venv/lib/python3.12/site-packages/alembic/operations/toimpl.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls -# mypy: no-warn-return-any, allow-any-generics - from typing import TYPE_CHECKING from sqlalchemy import schema as sa_schema @@ -79,11 +76,8 @@ def alter_column( @Operations.implementation_for(ops.DropTableOp) def drop_table(operations: "Operations", operation: "ops.DropTableOp") -> None: - kw = {} - if operation.if_exists is not None: - kw["if_exists"] = operation.if_exists operations.impl.drop_table( - operation.to_table(operations.migration_context), **kw + operation.to_table(operations.migration_context) ) @@ -93,11 +87,7 @@ def drop_column( ) -> None: column = operation.to_column(operations.migration_context) operations.impl.drop_column( - operation.table_name, - column, - schema=operation.schema, - if_exists=operation.if_exists, - **operation.kw, + operation.table_name, column, schema=operation.schema, **operation.kw ) @@ -108,6 +98,9 @@ def create_index( idx = operation.to_index(operations.migration_context) kw = {} if operation.if_not_exists is not None: + if not sqla_2: + raise NotImplementedError("SQLAlchemy 2.0+ required") + kw["if_not_exists"] = operation.if_not_exists operations.impl.create_index(idx, **kw) @@ -116,6 +109,9 @@ def create_index( def drop_index(operations: "Operations", operation: "ops.DropIndexOp") -> None: kw = {} if operation.if_exists is not None: + if not sqla_2: + raise NotImplementedError("SQLAlchemy 2.0+ 
required") + kw["if_exists"] = operation.if_exists operations.impl.drop_index( @@ -128,11 +124,8 @@ def drop_index(operations: "Operations", operation: "ops.DropIndexOp") -> None: def create_table( operations: "Operations", operation: "ops.CreateTableOp" ) -> "Table": - kw = {} - if operation.if_not_exists is not None: - kw["if_not_exists"] = operation.if_not_exists table = operation.to_table(operations.migration_context) - operations.impl.create_table(table, **kw) + operations.impl.create_table(table) return table @@ -172,13 +165,7 @@ def add_column(operations: "Operations", operation: "ops.AddColumnOp") -> None: column = _copy(column) t = operations.schema_obj.table(table_name, column, schema=schema) - operations.impl.add_column( - table_name, - column, - schema=schema, - if_not_exists=operation.if_not_exists, - **kw, - ) + operations.impl.add_column(table_name, column, schema=schema, **kw) for constraint in t.constraints: if not isinstance(constraint, sa_schema.PrimaryKeyConstraint): @@ -208,19 +195,13 @@ def create_constraint( def drop_constraint( operations: "Operations", operation: "ops.DropConstraintOp" ) -> None: - kw = {} - if operation.if_exists is not None: - if not sqla_2: - raise NotImplementedError("SQLAlchemy 2.0 required") - kw["if_exists"] = operation.if_exists operations.impl.drop_constraint( operations.schema_obj.generic_constraint( operation.constraint_name, operation.table_name, operation.constraint_type, schema=operation.schema, - ), - **kw, + ) ) diff --git a/venv/lib/python3.12/site-packages/alembic/runtime/environment.py b/venv/lib/python3.12/site-packages/alembic/runtime/environment.py index 80ca2b6..7640f56 100644 --- a/venv/lib/python3.12/site-packages/alembic/runtime/environment.py +++ b/venv/lib/python3.12/site-packages/alembic/runtime/environment.py @@ -3,13 +3,13 @@ from __future__ import annotations from typing import Any from typing import Callable from typing import Collection +from typing import ContextManager from typing import Dict from typing import List from typing import Mapping from typing import MutableMapping from typing import Optional from typing import overload -from typing import Sequence from typing import TextIO from typing import Tuple from typing import TYPE_CHECKING @@ -17,7 +17,6 @@ from typing import Union from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import FetchedValue -from typing_extensions import ContextManager from typing_extensions import Literal from .migration import _ProxyTransaction @@ -108,6 +107,7 @@ CompareType = Callable[ class EnvironmentContext(util.ModuleClsProxy): + """A configurational facade made available in an ``env.py`` script. The :class:`.EnvironmentContext` acts as a *facade* to the more @@ -227,9 +227,9 @@ class EnvironmentContext(util.ModuleClsProxy): has been configured. """ - return self.context_opts.get("as_sql", False) # type: ignore[no-any-return] # noqa: E501 + return self.context_opts.get("as_sql", False) - def is_transactional_ddl(self) -> bool: + def is_transactional_ddl(self): """Return True if the context is configured to expect a transactional DDL capable backend. @@ -341,17 +341,18 @@ class EnvironmentContext(util.ModuleClsProxy): return self.context_opts.get("tag", None) @overload - def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]: ... + def get_x_argument(self, as_dictionary: Literal[False]) -> List[str]: + ... @overload - def get_x_argument( - self, as_dictionary: Literal[True] - ) -> Dict[str, str]: ... 
+ def get_x_argument(self, as_dictionary: Literal[True]) -> Dict[str, str]: + ... @overload def get_x_argument( self, as_dictionary: bool = ... - ) -> Union[List[str], Dict[str, str]]: ... + ) -> Union[List[str], Dict[str, str]]: + ... def get_x_argument( self, as_dictionary: bool = False @@ -365,11 +366,7 @@ class EnvironmentContext(util.ModuleClsProxy): The return value is a list, returned directly from the ``argparse`` structure. If ``as_dictionary=True`` is passed, the ``x`` arguments are parsed using ``key=value`` format into a dictionary that is - then returned. If there is no ``=`` in the argument, value is an empty - string. - - .. versionchanged:: 1.13.1 Support ``as_dictionary=True`` when - arguments are passed without the ``=`` symbol. + then returned. For example, to support passing a database URL on the command line, the standard ``env.py`` script can be modified like this:: @@ -403,12 +400,7 @@ class EnvironmentContext(util.ModuleClsProxy): else: value = [] if as_dictionary: - dict_value = {} - for arg in value: - x_key, _, x_value = arg.partition("=") - dict_value[x_key] = x_value - value = dict_value - + value = dict(arg.split("=", 1) for arg in value) return value def configure( @@ -424,7 +416,7 @@ class EnvironmentContext(util.ModuleClsProxy): tag: Optional[str] = None, template_args: Optional[Dict[str, Any]] = None, render_as_batch: bool = False, - target_metadata: Union[MetaData, Sequence[MetaData], None] = None, + target_metadata: Optional[MetaData] = None, include_name: Optional[IncludeNameFn] = None, include_object: Optional[IncludeObjectFn] = None, include_schemas: bool = False, @@ -948,7 +940,7 @@ class EnvironmentContext(util.ModuleClsProxy): def execute( self, sql: Union[Executable, str], - execution_options: Optional[Dict[str, Any]] = None, + execution_options: Optional[dict] = None, ) -> None: """Execute the given SQL using the current change context. @@ -976,7 +968,7 @@ class EnvironmentContext(util.ModuleClsProxy): def begin_transaction( self, - ) -> Union[_ProxyTransaction, ContextManager[None, Optional[bool]]]: + ) -> Union[_ProxyTransaction, ContextManager[None]]: """Return a context manager that will enclose an operation within a "transaction", as defined by the environment's offline diff --git a/venv/lib/python3.12/site-packages/alembic/runtime/migration.py b/venv/lib/python3.12/site-packages/alembic/runtime/migration.py index c1c7b0f..24e3d64 100644 --- a/venv/lib/python3.12/site-packages/alembic/runtime/migration.py +++ b/venv/lib/python3.12/site-packages/alembic/runtime/migration.py @@ -1,6 +1,3 @@ -# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls -# mypy: no-warn-return-any, allow-any-generics - from __future__ import annotations from contextlib import contextmanager @@ -11,6 +8,7 @@ from typing import Any from typing import Callable from typing import cast from typing import Collection +from typing import ContextManager from typing import Dict from typing import Iterable from typing import Iterator @@ -23,11 +21,13 @@ from typing import Union from sqlalchemy import Column from sqlalchemy import literal_column -from sqlalchemy import select +from sqlalchemy import MetaData +from sqlalchemy import PrimaryKeyConstraint +from sqlalchemy import String +from sqlalchemy import Table from sqlalchemy.engine import Engine from sqlalchemy.engine import url as sqla_url from sqlalchemy.engine.strategies import MockEngineStrategy -from typing_extensions import ContextManager from .. import ddl from .. 
import util @@ -83,6 +83,7 @@ class _ProxyTransaction: class MigrationContext: + """Represent the database state made available to a migration script. @@ -175,11 +176,7 @@ class MigrationContext: opts["output_encoding"], ) else: - self.output_buffer = opts.get( - "output_buffer", sys.stdout - ) # type:ignore[assignment] # noqa: E501 - - self.transactional_ddl = transactional_ddl + self.output_buffer = opts.get("output_buffer", sys.stdout) self._user_compare_type = opts.get("compare_type", True) self._user_compare_server_default = opts.get( @@ -191,6 +188,18 @@ class MigrationContext: self.version_table_schema = version_table_schema = opts.get( "version_table_schema", None ) + self._version = Table( + version_table, + MetaData(), + Column("version_num", String(32), nullable=False), + schema=version_table_schema, + ) + if opts.get("version_table_pk", True): + self._version.append_constraint( + PrimaryKeyConstraint( + "version_num", name="%s_pkc" % version_table + ) + ) self._start_from_rev: Optional[str] = opts.get("starting_rev") self.impl = ddl.DefaultImpl.get_by_dialect(dialect)( @@ -201,23 +210,14 @@ class MigrationContext: self.output_buffer, opts, ) - - self._version = self.impl.version_table_impl( - version_table=version_table, - version_table_schema=version_table_schema, - version_table_pk=opts.get("version_table_pk", True), - ) - log.info("Context impl %s.", self.impl.__class__.__name__) if self.as_sql: log.info("Generating static SQL") log.info( "Will assume %s DDL.", - ( - "transactional" - if self.impl.transactional_ddl - else "non-transactional" - ), + "transactional" + if self.impl.transactional_ddl + else "non-transactional", ) @classmethod @@ -342,9 +342,9 @@ class MigrationContext: # except that it will not know it's in "autocommit" and will # emit deprecation warnings when an autocommit action takes # place. - self.connection = self.impl.connection = ( - base_connection.execution_options(isolation_level="AUTOCOMMIT") - ) + self.connection = ( + self.impl.connection + ) = base_connection.execution_options(isolation_level="AUTOCOMMIT") # sqlalchemy future mode will "autobegin" in any case, so take # control of that "transaction" here @@ -372,7 +372,7 @@ class MigrationContext: def begin_transaction( self, _per_migration: bool = False - ) -> Union[_ProxyTransaction, ContextManager[None, Optional[bool]]]: + ) -> Union[_ProxyTransaction, ContextManager[None]]: """Begin a logical transaction for migration operations. This method is used within an ``env.py`` script to demarcate where @@ -521,7 +521,7 @@ class MigrationContext: start_from_rev = None elif start_from_rev is not None and self.script: start_from_rev = [ - self.script.get_revision(sfr).revision + cast("Script", self.script.get_revision(sfr)).revision for sfr in util.to_list(start_from_rev) if sfr not in (None, "base") ] @@ -536,10 +536,7 @@ class MigrationContext: return () assert self.connection is not None return tuple( - row[0] - for row in self.connection.execute( - select(self._version.c.version_num) - ) + row[0] for row in self.connection.execute(self._version.select()) ) def _ensure_version_table(self, purge: bool = False) -> None: @@ -655,7 +652,7 @@ class MigrationContext: def execute( self, sql: Union[Executable, str], - execution_options: Optional[Dict[str, Any]] = None, + execution_options: Optional[dict] = None, ) -> None: """Execute a SQL construct or string statement. 
@@ -1003,11 +1000,6 @@ class MigrationStep: is_upgrade: bool migration_fn: Any - if TYPE_CHECKING: - - @property - def doc(self) -> Optional[str]: ... - @property def name(self) -> str: return self.migration_fn.__name__ @@ -1056,9 +1048,13 @@ class RevisionStep(MigrationStep): self.revision = revision self.is_upgrade = is_upgrade if is_upgrade: - self.migration_fn = revision.module.upgrade + self.migration_fn = ( + revision.module.upgrade # type:ignore[attr-defined] + ) else: - self.migration_fn = revision.module.downgrade + self.migration_fn = ( + revision.module.downgrade # type:ignore[attr-defined] + ) def __repr__(self): return "RevisionStep(%r, is_upgrade=%r)" % ( @@ -1074,7 +1070,7 @@ class RevisionStep(MigrationStep): ) @property - def doc(self) -> Optional[str]: + def doc(self) -> str: return self.revision.doc @property @@ -1172,18 +1168,7 @@ class RevisionStep(MigrationStep): } return tuple(set(self.to_revisions).difference(ancestors)) else: - # for each revision we plan to return, compute its ancestors - # (excluding self), and remove those from the final output since - # they are already accounted for. - ancestors = { - r.revision - for to_revision in self.to_revisions - for r in self.revision_map._get_ancestor_nodes( - self.revision_map.get_revisions(to_revision), check=False - ) - if r.revision != to_revision - } - return tuple(set(self.to_revisions).difference(ancestors)) + return self.to_revisions def unmerge_branch_idents( self, heads: Set[str] @@ -1298,7 +1283,7 @@ class StampStep(MigrationStep): def __eq__(self, other): return ( isinstance(other, StampStep) - and other.from_revisions == self.from_revisions + and other.from_revisions == self.revisions and other.to_revisions == self.to_revisions and other.branch_move == self.branch_move and self.is_upgrade == other.is_upgrade diff --git a/venv/lib/python3.12/site-packages/alembic/script/base.py b/venv/lib/python3.12/site-packages/alembic/script/base.py index 9429231..d0f9abb 100644 --- a/venv/lib/python3.12/site-packages/alembic/script/base.py +++ b/venv/lib/python3.12/site-packages/alembic/script/base.py @@ -3,7 +3,6 @@ from __future__ import annotations from contextlib import contextmanager import datetime import os -from pathlib import Path import re import shutil import sys @@ -12,6 +11,7 @@ from typing import Any from typing import cast from typing import Iterator from typing import List +from typing import Mapping from typing import Optional from typing import Sequence from typing import Set @@ -23,9 +23,7 @@ from . import revision from . import write_hooks from .. 
import util from ..runtime import migration -from ..util import compat from ..util import not_none -from ..util.pyfiles import _preserving_path_as_str if TYPE_CHECKING: from .revision import _GetRevArg @@ -33,28 +31,26 @@ if TYPE_CHECKING: from .revision import Revision from ..config import Config from ..config import MessagingOptions - from ..config import PostWriteHookConfig from ..runtime.migration import RevisionStep from ..runtime.migration import StampStep try: - if compat.py39: - from zoneinfo import ZoneInfo - from zoneinfo import ZoneInfoNotFoundError - else: - from backports.zoneinfo import ZoneInfo # type: ignore[import-not-found,no-redef] # noqa: E501 - from backports.zoneinfo import ZoneInfoNotFoundError # type: ignore[no-redef] # noqa: E501 + from dateutil import tz except ImportError: - ZoneInfo = None # type: ignore[assignment, misc] + tz = None # type: ignore[assignment] _sourceless_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)(c|o)?$") _only_source_rev_file = re.compile(r"(?!\.\#|__init__)(.*\.py)$") _legacy_rev = re.compile(r"([a-f0-9]+)\.py$") _slug_re = re.compile(r"\w+") _default_file_template = "%(rev)s_%(slug)s" +_split_on_space_comma = re.compile(r", *|(?: +)") + +_split_on_space_comma_colon = re.compile(r", *|(?: +)|\:") class ScriptDirectory: + """Provides operations upon an Alembic script directory. This object is useful to get information as to current revisions, @@ -76,55 +72,40 @@ class ScriptDirectory: def __init__( self, - dir: Union[str, os.PathLike[str]], # noqa: A002 + dir: str, # noqa file_template: str = _default_file_template, truncate_slug_length: Optional[int] = 40, - version_locations: Optional[ - Sequence[Union[str, os.PathLike[str]]] - ] = None, + version_locations: Optional[List[str]] = None, sourceless: bool = False, output_encoding: str = "utf-8", timezone: Optional[str] = None, - hooks: list[PostWriteHookConfig] = [], + hook_config: Optional[Mapping[str, str]] = None, recursive_version_locations: bool = False, messaging_opts: MessagingOptions = cast( "MessagingOptions", util.EMPTY_DICT ), ) -> None: - self.dir = _preserving_path_as_str(dir) - self.version_locations = [ - _preserving_path_as_str(p) for p in version_locations or () - ] + self.dir = dir self.file_template = file_template + self.version_locations = version_locations self.truncate_slug_length = truncate_slug_length or 40 self.sourceless = sourceless self.output_encoding = output_encoding self.revision_map = revision.RevisionMap(self._load_revisions) self.timezone = timezone - self.hooks = hooks + self.hook_config = hook_config self.recursive_version_locations = recursive_version_locations self.messaging_opts = messaging_opts if not os.access(dir, os.F_OK): raise util.CommandError( - f"Path doesn't exist: {dir}. Please use " + "Path doesn't exist: %r. Please use " "the 'init' command to create a new " - "scripts folder." + "scripts folder." % os.path.abspath(dir) ) @property def versions(self) -> str: - """return a single version location based on the sole path passed - within version_locations. - - If multiple version locations are configured, an error is raised. 
- - - """ - return str(self._singular_version_location) - - @util.memoized_property - def _singular_version_location(self) -> Path: loc = self._version_locations if len(loc) > 1: raise util.CommandError("Multiple version_locations present") @@ -132,31 +113,40 @@ class ScriptDirectory: return loc[0] @util.memoized_property - def _version_locations(self) -> Sequence[Path]: + def _version_locations(self): if self.version_locations: return [ - util.coerce_resource_to_filename(location).absolute() + os.path.abspath(util.coerce_resource_to_filename(location)) for location in self.version_locations ] else: - return [Path(self.dir, "versions").absolute()] + return (os.path.abspath(os.path.join(self.dir, "versions")),) def _load_revisions(self) -> Iterator[Script]: - paths = [vers for vers in self._version_locations if vers.exists()] + if self.version_locations: + paths = [ + vers + for vers in self._version_locations + if os.path.exists(vers) + ] + else: + paths = [self.versions] dupes = set() for vers in paths: for file_path in Script._list_py_dir(self, vers): - real_path = file_path.resolve() + real_path = os.path.realpath(file_path) if real_path in dupes: util.warn( - f"File {real_path} loaded twice! ignoring. " - "Please ensure version_locations is unique." + "File %s loaded twice! ignoring. Please ensure " + "version_locations is unique." % real_path ) continue dupes.add(real_path) - script = Script._from_path(self, real_path) + filename = os.path.basename(real_path) + dir_name = os.path.dirname(real_path) + script = Script._from_filename(self, dir_name, filename) if script is None: continue yield script @@ -170,36 +160,74 @@ class ScriptDirectory: present. """ - script_location = config.get_alembic_option("script_location") + script_location = config.get_main_option("script_location") if script_location is None: raise util.CommandError( - "No 'script_location' key found in configuration." + "No 'script_location' key " "found in configuration." 
) truncate_slug_length: Optional[int] - tsl = config.get_alembic_option("truncate_slug_length") + tsl = config.get_main_option("truncate_slug_length") if tsl is not None: truncate_slug_length = int(tsl) else: truncate_slug_length = None - prepend_sys_path = config.get_prepend_sys_paths_list() - if prepend_sys_path: - sys.path[:0] = prepend_sys_path + version_locations_str = config.get_main_option("version_locations") + version_locations: Optional[List[str]] + if version_locations_str: + version_path_separator = config.get_main_option( + "version_path_separator" + ) - rvl = config.get_alembic_boolean_option("recursive_version_locations") + split_on_path = { + None: None, + "space": " ", + "os": os.pathsep, + ":": ":", + ";": ";", + } + + try: + split_char: Optional[str] = split_on_path[ + version_path_separator + ] + except KeyError as ke: + raise ValueError( + "'%s' is not a valid value for " + "version_path_separator; " + "expected 'space', 'os', ':', ';'" % version_path_separator + ) from ke + else: + if split_char is None: + # legacy behaviour for backwards compatibility + version_locations = _split_on_space_comma.split( + version_locations_str + ) + else: + version_locations = [ + x for x in version_locations_str.split(split_char) if x + ] + else: + version_locations = None + + prepend_sys_path = config.get_main_option("prepend_sys_path") + if prepend_sys_path: + sys.path[:0] = list( + _split_on_space_comma_colon.split(prepend_sys_path) + ) + + rvl = config.get_main_option("recursive_version_locations") == "true" return ScriptDirectory( util.coerce_resource_to_filename(script_location), - file_template=config.get_alembic_option( + file_template=config.get_main_option( "file_template", _default_file_template ), truncate_slug_length=truncate_slug_length, - sourceless=config.get_alembic_boolean_option("sourceless"), - output_encoding=config.get_alembic_option( - "output_encoding", "utf-8" - ), - version_locations=config.get_version_locations_list(), - timezone=config.get_alembic_option("timezone"), - hooks=config.get_hooks_list(), + sourceless=config.get_main_option("sourceless") == "true", + output_encoding=config.get_main_option("output_encoding", "utf-8"), + version_locations=version_locations, + timezone=config.get_main_option("timezone"), + hook_config=config.get_section("post_write_hooks", {}), recursive_version_locations=rvl, messaging_opts=config.messaging_opts, ) @@ -269,22 +297,24 @@ class ScriptDirectory: ): yield cast(Script, rev) - def get_revisions(self, id_: _GetRevArg) -> Tuple[Script, ...]: + def get_revisions(self, id_: _GetRevArg) -> Tuple[Optional[Script], ...]: """Return the :class:`.Script` instance with the given rev identifier, symbolic name, or sequence of identifiers. """ with self._catch_revision_errors(): return cast( - Tuple[Script, ...], + Tuple[Optional[Script], ...], self.revision_map.get_revisions(id_), ) - def get_all_current(self, id_: Tuple[str, ...]) -> Set[Script]: + def get_all_current(self, id_: Tuple[str, ...]) -> Set[Optional[Script]]: with self._catch_revision_errors(): - return cast(Set[Script], self.revision_map._get_all_current(id_)) + return cast( + Set[Optional[Script]], self.revision_map._get_all_current(id_) + ) - def get_revision(self, id_: str) -> Script: + def get_revision(self, id_: str) -> Optional[Script]: """Return the :class:`.Script` instance with the given rev id. .. 
seealso:: @@ -294,7 +324,7 @@ class ScriptDirectory: """ with self._catch_revision_errors(): - return cast(Script, self.revision_map.get_revision(id_)) + return cast(Optional[Script], self.revision_map.get_revision(id_)) def as_revision_number( self, id_: Optional[str] @@ -549,37 +579,24 @@ class ScriptDirectory: util.load_python_file(self.dir, "env.py") @property - def env_py_location(self) -> str: - return str(Path(self.dir, "env.py")) + def env_py_location(self): + return os.path.abspath(os.path.join(self.dir, "env.py")) - def _append_template(self, src: Path, dest: Path, **kw: Any) -> None: + def _generate_template(self, src: str, dest: str, **kw: Any) -> None: with util.status( - f"Appending to existing {dest.absolute()}", - **self.messaging_opts, - ): - util.template_to_file( - src, - dest, - self.output_encoding, - append_with_newlines=True, - **kw, - ) - - def _generate_template(self, src: Path, dest: Path, **kw: Any) -> None: - with util.status( - f"Generating {dest.absolute()}", **self.messaging_opts + f"Generating {os.path.abspath(dest)}", **self.messaging_opts ): util.template_to_file(src, dest, self.output_encoding, **kw) - def _copy_file(self, src: Path, dest: Path) -> None: + def _copy_file(self, src: str, dest: str) -> None: with util.status( - f"Generating {dest.absolute()}", **self.messaging_opts + f"Generating {os.path.abspath(dest)}", **self.messaging_opts ): shutil.copy(src, dest) - def _ensure_directory(self, path: Path) -> None: - path = path.absolute() - if not path.exists(): + def _ensure_directory(self, path: str) -> None: + path = os.path.abspath(path) + if not os.path.exists(path): with util.status( f"Creating directory {path}", **self.messaging_opts ): @@ -587,27 +604,25 @@ class ScriptDirectory: def _generate_create_date(self) -> datetime.datetime: if self.timezone is not None: - if ZoneInfo is None: + if tz is None: raise util.CommandError( - "Python >= 3.9 is required for timezone support or " - "the 'backports.zoneinfo' package must be installed." 
+ "The library 'python-dateutil' is required " + "for timezone support" ) # First, assume correct capitalization - try: - tzinfo = ZoneInfo(self.timezone) - except ZoneInfoNotFoundError: - tzinfo = None + tzinfo = tz.gettz(self.timezone) if tzinfo is None: - try: - tzinfo = ZoneInfo(self.timezone.upper()) - except ZoneInfoNotFoundError: - raise util.CommandError( - "Can't locate timezone: %s" % self.timezone - ) from None - - create_date = datetime.datetime.now( - tz=datetime.timezone.utc - ).astimezone(tzinfo) + # Fall back to uppercase + tzinfo = tz.gettz(self.timezone.upper()) + if tzinfo is None: + raise util.CommandError( + "Can't locate timezone: %s" % self.timezone + ) + create_date = ( + datetime.datetime.utcnow() + .replace(tzinfo=tz.tzutc()) + .astimezone(tzinfo) + ) else: create_date = datetime.datetime.now() return create_date @@ -619,8 +634,7 @@ class ScriptDirectory: head: Optional[_RevIdType] = None, splice: Optional[bool] = False, branch_labels: Optional[_RevIdType] = None, - version_path: Union[str, os.PathLike[str], None] = None, - file_template: Optional[str] = None, + version_path: Optional[str] = None, depends_on: Optional[_RevIdType] = None, **kw: Any, ) -> Optional[Script]: @@ -661,7 +675,7 @@ class ScriptDirectory: self.revision_map.get_revisions(head), ) for h in heads: - assert h != "base" # type: ignore[comparison-overlap] + assert h != "base" if len(set(heads)) != len(heads): raise util.CommandError("Duplicate head revisions specified") @@ -673,7 +687,7 @@ class ScriptDirectory: for head_ in heads: if head_ is not None: assert isinstance(head_, Script) - version_path = head_._script_path.parent + version_path = os.path.dirname(head_.path) break else: raise util.CommandError( @@ -681,19 +695,16 @@ class ScriptDirectory: "please specify --version-path" ) else: - version_path = self._singular_version_location - else: - version_path = Path(version_path) + version_path = self.versions - assert isinstance(version_path, Path) - norm_path = version_path.absolute() + norm_path = os.path.normpath(os.path.abspath(version_path)) for vers_path in self._version_locations: - if vers_path.absolute() == norm_path: + if os.path.normpath(vers_path) == norm_path: break else: raise util.CommandError( - f"Path {version_path} is not represented in current " - "version locations" + "Path %s is not represented in current " + "version locations" % version_path ) if self.version_locations: @@ -714,11 +725,9 @@ class ScriptDirectory: if depends_on: with self._catch_revision_errors(): resolved_depends_on = [ - ( - dep - if dep in rev.branch_labels # maintain branch labels - else rev.revision - ) # resolve partial revision identifiers + dep + if dep in rev.branch_labels # maintain branch labels + else rev.revision # resolve partial revision identifiers for rev, dep in [ (not_none(self.revision_map.get_revision(dep)), dep) for dep in util.to_list(depends_on) @@ -728,7 +737,7 @@ class ScriptDirectory: resolved_depends_on = None self._generate_template( - Path(self.dir, "script.py.mako"), + os.path.join(self.dir, "script.py.mako"), path, up_revision=str(revid), down_revision=revision.tuple_rev_as_scalar( @@ -742,7 +751,7 @@ class ScriptDirectory: **kw, ) - post_write_hooks = self.hooks + post_write_hooks = self.hook_config if post_write_hooks: write_hooks._run_hooks(path, post_write_hooks) @@ -765,11 +774,11 @@ class ScriptDirectory: def _rev_path( self, - path: Union[str, os.PathLike[str]], + path: str, rev_id: str, message: Optional[str], create_date: datetime.datetime, - ) -> Path: + ) -> 
str: epoch = int(create_date.timestamp()) slug = "_".join(_slug_re.findall(message or "")).lower() if len(slug) > self.truncate_slug_length: @@ -788,10 +797,11 @@ class ScriptDirectory: "second": create_date.second, } ) - return Path(path) / filename + return os.path.join(path, filename) class Script(revision.Revision): + """Represent a single revision file in a ``versions/`` directory. The :class:`.Script` instance is returned by methods @@ -799,17 +809,12 @@ class Script(revision.Revision): """ - def __init__( - self, - module: ModuleType, - rev_id: str, - path: Union[str, os.PathLike[str]], - ): + def __init__(self, module: ModuleType, rev_id: str, path: str): self.module = module - self.path = _preserving_path_as_str(path) + self.path = path super().__init__( rev_id, - module.down_revision, + module.down_revision, # type: ignore[attr-defined] branch_labels=util.to_tuple( getattr(module, "branch_labels", None), default=() ), @@ -824,10 +829,6 @@ class Script(revision.Revision): path: str """Filesystem path of the script.""" - @property - def _script_path(self) -> Path: - return Path(self.path) - _db_current_indicator: Optional[bool] = None """Utility variable which when set will cause string output to indicate this is a "current" version in some database""" @@ -846,9 +847,9 @@ class Script(revision.Revision): if doc: if hasattr(self.module, "_alembic_source_encoding"): doc = doc.decode( # type: ignore[attr-defined] - self.module._alembic_source_encoding + self.module._alembic_source_encoding # type: ignore[attr-defined] # noqa ) - return doc.strip() + return doc.strip() # type: ignore[union-attr] else: return "" @@ -888,7 +889,7 @@ class Script(revision.Revision): ) return entry - def __str__(self) -> str: + def __str__(self): return "%s -> %s%s%s%s, %s" % ( self._format_down_revision(), self.revision, @@ -922,11 +923,9 @@ class Script(revision.Revision): if head_indicators or tree_indicators: text += "%s%s%s" % ( " (head)" if self._is_real_head else "", - ( - " (effective head)" - if self.is_head and not self._is_real_head - else "" - ), + " (effective head)" + if self.is_head and not self._is_real_head + else "", " (current)" if self._db_current_indicator else "", ) if tree_indicators: @@ -960,33 +959,36 @@ class Script(revision.Revision): return util.format_as_comma(self._versioned_down_revisions) @classmethod - def _list_py_dir( - cls, scriptdir: ScriptDirectory, path: Path - ) -> List[Path]: + def _from_path( + cls, scriptdir: ScriptDirectory, path: str + ) -> Optional[Script]: + dir_, filename = os.path.split(path) + return cls._from_filename(scriptdir, dir_, filename) + + @classmethod + def _list_py_dir(cls, scriptdir: ScriptDirectory, path: str) -> List[str]: paths = [] - for root, dirs, files in compat.path_walk(path, top_down=True): - if root.name.endswith("__pycache__"): + for root, dirs, files in os.walk(path, topdown=True): + if root.endswith("__pycache__"): # a special case - we may include these files # if a `sourceless` option is specified continue for filename in sorted(files): - paths.append(root / filename) + paths.append(os.path.join(root, filename)) if scriptdir.sourceless: # look for __pycache__ - py_cache_path = root / "__pycache__" - if py_cache_path.exists(): + py_cache_path = os.path.join(root, "__pycache__") + if os.path.exists(py_cache_path): # add all files from __pycache__ whose filename is not # already in the names we got from the version directory. 
# add as relative paths including __pycache__ token - names = { - Path(filename).name.split(".")[0] for filename in files - } + names = {filename.split(".")[0] for filename in files} paths.extend( - py_cache_path / pyc - for pyc in py_cache_path.iterdir() - if pyc.name.split(".")[0] not in names + os.path.join(py_cache_path, pyc) + for pyc in os.listdir(py_cache_path) + if pyc.split(".")[0] not in names ) if not scriptdir.recursive_version_locations: @@ -1001,13 +1003,9 @@ class Script(revision.Revision): return paths @classmethod - def _from_path( - cls, scriptdir: ScriptDirectory, path: Union[str, os.PathLike[str]] + def _from_filename( + cls, scriptdir: ScriptDirectory, dir_: str, filename: str ) -> Optional[Script]: - - path = Path(path) - dir_, filename = path.parent, path.name - if scriptdir.sourceless: py_match = _sourceless_rev_file.match(filename) else: @@ -1025,8 +1023,8 @@ class Script(revision.Revision): is_c = is_o = False if is_o or is_c: - py_exists = (dir_ / py_filename).exists() - pyc_exists = (dir_ / (py_filename + "c")).exists() + py_exists = os.path.exists(os.path.join(dir_, py_filename)) + pyc_exists = os.path.exists(os.path.join(dir_, py_filename + "c")) # prefer .py over .pyc because we'd like to get the # source encoding; prefer .pyc over .pyo because we'd like to @@ -1042,14 +1040,14 @@ class Script(revision.Revision): m = _legacy_rev.match(filename) if not m: raise util.CommandError( - "Could not determine revision id from " - f"filename {filename}. " + "Could not determine revision id from filename %s. " "Be sure the 'revision' variable is " "declared inside the script (please see 'Upgrading " "from Alembic 0.1 to 0.2' in the documentation)." + % filename ) else: revision = m.group(1) else: revision = module.revision - return Script(module, revision, dir_ / filename) + return Script(module, revision, os.path.join(dir_, filename)) diff --git a/venv/lib/python3.12/site-packages/alembic/script/revision.py b/venv/lib/python3.12/site-packages/alembic/script/revision.py index 587e904..0350264 100644 --- a/venv/lib/python3.12/site-packages/alembic/script/revision.py +++ b/venv/lib/python3.12/site-packages/alembic/script/revision.py @@ -14,7 +14,6 @@ from typing import Iterator from typing import List from typing import Optional from typing import overload -from typing import Protocol from typing import Sequence from typing import Set from typing import Tuple @@ -48,17 +47,6 @@ _relative_destination = re.compile(r"(?:(.+?)@)?(\w+)?((?:\+|-)\d+)") _revision_illegal_chars = ["@", "-", "+"] -class _CollectRevisionsProtocol(Protocol): - def __call__( - self, - upper: _RevisionIdentifierType, - lower: _RevisionIdentifierType, - inclusive: bool, - implicit_base: bool, - assert_relative_length: bool, - ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase], ...]]: ... 
- - class RevisionError(Exception): pass @@ -408,7 +396,7 @@ class RevisionMap: for rev in self._get_ancestor_nodes( [revision], include_dependencies=False, - map_=map_, + map_=cast(_RevisionMapType, map_), ): if rev is revision: continue @@ -719,11 +707,9 @@ class RevisionMap: resolved_target = target resolved_test_against_revs = [ - ( - self._revision_for_ident(test_against_rev) - if not isinstance(test_against_rev, Revision) - else test_against_rev - ) + self._revision_for_ident(test_against_rev) + if not isinstance(test_against_rev, Revision) + else test_against_rev for test_against_rev in util.to_tuple( test_against_revs, default=() ) @@ -805,7 +791,7 @@ class RevisionMap: The iterator yields :class:`.Revision` objects. """ - fn: _CollectRevisionsProtocol + fn: Callable if select_for_downgrade: fn = self._collect_downgrade_revisions else: @@ -832,7 +818,7 @@ class RevisionMap: ) -> Iterator[Any]: if omit_immediate_dependencies: - def fn(rev: Revision) -> Iterable[str]: + def fn(rev): if rev not in targets: return rev._all_nextrev else: @@ -840,12 +826,12 @@ class RevisionMap: elif include_dependencies: - def fn(rev: Revision) -> Iterable[str]: + def fn(rev): return rev._all_nextrev else: - def fn(rev: Revision) -> Iterable[str]: + def fn(rev): return rev.nextrev return self._iterate_related_revisions( @@ -861,12 +847,12 @@ class RevisionMap: ) -> Iterator[Revision]: if include_dependencies: - def fn(rev: Revision) -> Iterable[str]: + def fn(rev): return rev._normalized_down_revisions else: - def fn(rev: Revision) -> Iterable[str]: + def fn(rev): return rev._versioned_down_revisions return self._iterate_related_revisions( @@ -875,7 +861,7 @@ class RevisionMap: def _iterate_related_revisions( self, - fn: Callable[[Revision], Iterable[str]], + fn: Callable, targets: Collection[Optional[_RevisionOrBase]], map_: Optional[_RevisionMapType], check: bool = False, @@ -937,7 +923,7 @@ class RevisionMap: id_to_rev = self._revision_map - def get_ancestors(rev_id: str) -> Set[str]: + def get_ancestors(rev_id): return { r.revision for r in self._get_ancestor_nodes([id_to_rev[rev_id]]) @@ -1017,9 +1003,9 @@ class RevisionMap: # each time but it was getting complicated current_heads[current_candidate_idx] = heads_to_add[0] current_heads.extend(heads_to_add[1:]) - ancestors_by_idx[current_candidate_idx] = ( - get_ancestors(heads_to_add[0]) - ) + ancestors_by_idx[ + current_candidate_idx + ] = get_ancestors(heads_to_add[0]) ancestors_by_idx.extend( get_ancestors(head) for head in heads_to_add[1:] ) @@ -1055,7 +1041,7 @@ class RevisionMap: children: Sequence[Optional[_RevisionOrBase]] for _ in range(abs(steps)): if steps > 0: - assert initial != "base" # type: ignore[comparison-overlap] + assert initial != "base" # Walk up walk_up = [ is_revision(rev) @@ -1069,7 +1055,7 @@ class RevisionMap: children = walk_up else: # Walk down - if initial == "base": # type: ignore[comparison-overlap] + if initial == "base": children = () else: children = self.get_revisions( @@ -1184,13 +1170,9 @@ class RevisionMap: branch_label = symbol # Walk down the tree to find downgrade target. rev = self._walk( - start=( - self.get_revision(symbol) - if branch_label is None - else self.get_revision( - "%s@%s" % (branch_label, symbol) - ) - ), + start=self.get_revision(symbol) + if branch_label is None + else self.get_revision("%s@%s" % (branch_label, symbol)), steps=rel_int, no_overwalk=assert_relative_length, ) @@ -1207,7 +1189,7 @@ class RevisionMap: # No relative destination given, revision specified is absolute. 
branch_label, _, symbol = target.rpartition("@") if not branch_label: - branch_label = None + branch_label = None # type:ignore[assignment] return branch_label, self.get_revision(symbol) def _parse_upgrade_target( @@ -1308,13 +1290,9 @@ class RevisionMap: ) return ( self._walk( - start=( - self.get_revision(symbol) - if branch_label is None - else self.get_revision( - "%s@%s" % (branch_label, symbol) - ) - ), + start=self.get_revision(symbol) + if branch_label is None + else self.get_revision("%s@%s" % (branch_label, symbol)), steps=relative, no_overwalk=assert_relative_length, ), @@ -1323,11 +1301,11 @@ class RevisionMap: def _collect_downgrade_revisions( self, upper: _RevisionIdentifierType, - lower: _RevisionIdentifierType, + target: _RevisionIdentifierType, inclusive: bool, implicit_base: bool, assert_relative_length: bool, - ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase], ...]]: + ) -> Any: """ Compute the set of current revisions specified by :upper, and the downgrade target specified by :target. Return all dependents of target @@ -1338,7 +1316,7 @@ class RevisionMap: branch_label, target_revision = self._parse_downgrade_target( current_revisions=upper, - target=lower, + target=target, assert_relative_length=assert_relative_length, ) if target_revision == "base": @@ -1430,7 +1408,7 @@ class RevisionMap: inclusive: bool, implicit_base: bool, assert_relative_length: bool, - ) -> Tuple[Set[Revision], Tuple[Revision, ...]]: + ) -> Tuple[Set[Revision], Tuple[Optional[_RevisionOrBase]]]: """ Compute the set of required revisions specified by :upper, and the current set of active revisions specified by :lower. Find the @@ -1522,7 +1500,7 @@ class RevisionMap: ) needs.intersection_update(lower_descendents) - return needs, tuple(targets) + return needs, tuple(targets) # type:ignore[return-value] def _get_all_current( self, id_: Tuple[str, ...] @@ -1703,13 +1681,15 @@ class Revision: @overload -def tuple_rev_as_scalar(rev: None) -> None: ... +def tuple_rev_as_scalar(rev: None) -> None: + ... @overload def tuple_rev_as_scalar( - rev: Union[Tuple[_T, ...], List[_T]], -) -> Union[_T, Tuple[_T, ...], List[_T]]: ... + rev: Union[Tuple[_T, ...], List[_T]] +) -> Union[_T, Tuple[_T, ...], List[_T]]: + ... def tuple_rev_as_scalar( diff --git a/venv/lib/python3.12/site-packages/alembic/script/write_hooks.py b/venv/lib/python3.12/site-packages/alembic/script/write_hooks.py index f40bb35..b44ce64 100644 --- a/venv/lib/python3.12/site-packages/alembic/script/write_hooks.py +++ b/venv/lib/python3.12/site-packages/alembic/script/write_hooks.py @@ -1,10 +1,5 @@ -# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls -# mypy: no-warn-return-any, allow-any-generics - from __future__ import annotations -import importlib.util -import os import shlex import subprocess import sys @@ -12,16 +7,13 @@ from typing import Any from typing import Callable from typing import Dict from typing import List +from typing import Mapping from typing import Optional -from typing import TYPE_CHECKING from typing import Union from .. 
import util from ..util import compat -from ..util.pyfiles import _preserving_path_as_str -if TYPE_CHECKING: - from ..config import PostWriteHookConfig REVISION_SCRIPT_TOKEN = "REVISION_SCRIPT_FILENAME" @@ -48,19 +40,16 @@ def register(name: str) -> Callable: def _invoke( - name: str, - revision_path: Union[str, os.PathLike[str]], - options: PostWriteHookConfig, + name: str, revision: str, options: Mapping[str, Union[str, int]] ) -> Any: """Invokes the formatter registered for the given name. :param name: The name of a formatter in the registry - :param revision: string path to the revision file + :param revision: A :class:`.MigrationRevision` instance :param options: A dict containing kwargs passed to the specified formatter. :raises: :class:`alembic.util.CommandError` """ - revision_path = _preserving_path_as_str(revision_path) try: hook = _registry[name] except KeyError as ke: @@ -68,28 +57,36 @@ def _invoke( f"No formatter with name '{name}' registered" ) from ke else: - return hook(revision_path, options) + return hook(revision, options) -def _run_hooks( - path: Union[str, os.PathLike[str]], hooks: list[PostWriteHookConfig] -) -> None: +def _run_hooks(path: str, hook_config: Mapping[str, str]) -> None: """Invoke hooks for a generated revision.""" - for hook in hooks: - name = hook["_hook_name"] + from .base import _split_on_space_comma + + names = _split_on_space_comma.split(hook_config.get("hooks", "")) + + for name in names: + if not name: + continue + opts = { + key[len(name) + 1 :]: hook_config[key] + for key in hook_config + if key.startswith(name + ".") + } + opts["_hook_name"] = name try: - type_ = hook["type"] + type_ = opts["type"] except KeyError as ke: raise util.CommandError( - f"Key '{name}.type' (or 'type' in toml) is required " - f"for post write hook {name!r}" + f"Key {name}.type is required for post write hook {name!r}" ) from ke else: with util.status( f"Running post write hook {name!r}", newline=True ): - _invoke(type_, path, hook) + _invoke(type_, path, opts) def _parse_cmdline_options(cmdline_options_str: str, path: str) -> List[str]: @@ -113,35 +110,17 @@ def _parse_cmdline_options(cmdline_options_str: str, path: str) -> List[str]: return cmdline_options_list -def _get_required_option(options: dict, name: str) -> str: - try: - return options[name] - except KeyError as ke: - raise util.CommandError( - f"Key {options['_hook_name']}.{name} is required for post " - f"write hook {options['_hook_name']!r}" - ) from ke - - -def _run_hook( - path: str, options: dict, ignore_output: bool, command: List[str] -) -> None: - cwd: Optional[str] = options.get("cwd", None) - cmdline_options_str = options.get("options", "") - cmdline_options_list = _parse_cmdline_options(cmdline_options_str, path) - - kw: Dict[str, Any] = {} - if ignore_output: - kw["stdout"] = kw["stderr"] = subprocess.DEVNULL - - subprocess.run([*command, *cmdline_options_list], cwd=cwd, **kw) - - @register("console_scripts") def console_scripts( path: str, options: dict, ignore_output: bool = False ) -> None: - entrypoint_name = _get_required_option(options, "entrypoint") + try: + entrypoint_name = options["entrypoint"] + except KeyError as ke: + raise util.CommandError( + f"Key {options['_hook_name']}.entrypoint is required for post " + f"write hook {options['_hook_name']!r}" + ) from ke for entry in compat.importlib_metadata_get("console_scripts"): if entry.name == entrypoint_name: impl: Any = entry @@ -150,27 +129,48 @@ def console_scripts( raise util.CommandError( f"Could not find entrypoint 
console_scripts.{entrypoint_name}" ) + cwd: Optional[str] = options.get("cwd", None) + cmdline_options_str = options.get("options", "") + cmdline_options_list = _parse_cmdline_options(cmdline_options_str, path) - command = [ - sys.executable, - "-c", - f"import {impl.module}; {impl.module}.{impl.attr}()", - ] - _run_hook(path, options, ignore_output, command) + kw: Dict[str, Any] = {} + if ignore_output: + kw["stdout"] = kw["stderr"] = subprocess.DEVNULL + + subprocess.run( + [ + sys.executable, + "-c", + f"import {impl.module}; {impl.module}.{impl.attr}()", + ] + + cmdline_options_list, + cwd=cwd, + **kw, + ) @register("exec") def exec_(path: str, options: dict, ignore_output: bool = False) -> None: - executable = _get_required_option(options, "executable") - _run_hook(path, options, ignore_output, command=[executable]) + try: + executable = options["executable"] + except KeyError as ke: + raise util.CommandError( + f"Key {options['_hook_name']}.executable is required for post " + f"write hook {options['_hook_name']!r}" + ) from ke + cwd: Optional[str] = options.get("cwd", None) + cmdline_options_str = options.get("options", "") + cmdline_options_list = _parse_cmdline_options(cmdline_options_str, path) + kw: Dict[str, Any] = {} + if ignore_output: + kw["stdout"] = kw["stderr"] = subprocess.DEVNULL -@register("module") -def module(path: str, options: dict, ignore_output: bool = False) -> None: - module_name = _get_required_option(options, "module") - - if importlib.util.find_spec(module_name) is None: - raise util.CommandError(f"Could not find module {module_name}") - - command = [sys.executable, "-m", module_name] - _run_hook(path, options, ignore_output, command) + subprocess.run( + [ + executable, + *cmdline_options_list, + ], + cwd=cwd, + **kw, + ) diff --git a/venv/lib/python3.12/site-packages/alembic/templates/async/alembic.ini.mako b/venv/lib/python3.12/site-packages/alembic/templates/async/alembic.ini.mako index 67acc6d..bc9f2d5 100644 --- a/venv/lib/python3.12/site-packages/alembic/templates/async/alembic.ini.mako +++ b/venv/lib/python3.12/site-packages/alembic/templates/async/alembic.ini.mako @@ -1,32 +1,27 @@ # A generic, single database configuration. [alembic] -# path to migration scripts. -# this is typically a path given in POSIX (e.g. forward slashes) -# format, relative to the token %(here)s which refers to the location of this -# ini file +# path to migration scripts script_location = ${script_location} # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s # Uncomment the line below if you want the files to be prepended with date and time -# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file -# for all available tokens # file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s # sys.path path, will be prepended to sys.path if present. -# defaults to the current working directory. for multiple paths, the path separator -# is defined by "path_separator" below. +# defaults to the current working directory. prepend_sys_path = . # timezone to use when rendering the date within the migration file # as well as the filename. -# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. 
-# Any required deps can installed by adding `alembic[tz]` to the pip requirements -# string value is passed to ZoneInfo() +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() # leave blank for localtime # timezone = -# max length of characters to apply to the "slug" field +# max length of characters to apply to the +# "slug" field # truncate_slug_length = 40 # set to 'true' to run the environment during @@ -39,38 +34,20 @@ prepend_sys_path = . # sourceless = false # version location specification; This defaults -# to /versions. When using multiple version +# to ${script_location}/versions. When using multiple version # directories, initial revisions must be specified with --version-path. -# The path separator used here should be the separator specified by "path_separator" -# below. -# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:${script_location}/versions -# path_separator; This indicates what character is used to split lists of file -# paths, including version_locations and prepend_sys_path within configparser -# files such as alembic.ini. -# The default rendered in new alembic.ini files is "os", which uses os.pathsep -# to provide os-dependent path splitting. +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: # -# Note that in order to support legacy alembic.ini files, this default does NOT -# take place if path_separator is not present in alembic.ini. If this -# option is omitted entirely, fallback logic is as follows: -# -# 1. Parsing of the version_locations option falls back to using the legacy -# "version_path_separator" key, which if absent then falls back to the legacy -# behavior of splitting on spaces and/or commas. -# 2. Parsing of the prepend_sys_path option falls back to the legacy -# behavior of splitting on spaces, commas, or colons. -# -# Valid values for path_separator are: -# -# path_separator = : -# path_separator = ; -# path_separator = space -# path_separator = newline -# -# Use os.pathsep. Default configuration used for new projects. -path_separator = os - +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. # set to 'true' to search source files recursively # in each "version_locations" directory @@ -81,9 +58,6 @@ path_separator = os # are written from script.py.mako # output_encoding = utf-8 -# database URL. This is consumed by the user-maintained env.py script only. -# other means of configuring database URLs may be customized within the env.py -# file. 
sqlalchemy.url = driver://user:pass@localhost/dbname @@ -98,20 +72,13 @@ sqlalchemy.url = driver://user:pass@localhost/dbname # black.entrypoint = black # black.options = -l 79 REVISION_SCRIPT_FILENAME -# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module -# hooks = ruff -# ruff.type = module -# ruff.module = ruff -# ruff.options = check --fix REVISION_SCRIPT_FILENAME - -# Alternatively, use the exec runner to execute a binary found on your PATH +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary # hooks = ruff # ruff.type = exec -# ruff.executable = ruff -# ruff.options = check --fix REVISION_SCRIPT_FILENAME +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME -# Logging configuration. This is also consumed by the user-maintained -# env.py script only. +# Logging configuration [loggers] keys = root,sqlalchemy,alembic @@ -122,12 +89,12 @@ keys = console keys = generic [logger_root] -level = WARNING +level = WARN handlers = console qualname = [logger_sqlalchemy] -level = WARNING +level = WARN handlers = qualname = sqlalchemy.engine diff --git a/venv/lib/python3.12/site-packages/alembic/templates/async/script.py.mako b/venv/lib/python3.12/site-packages/alembic/templates/async/script.py.mako index 1101630..fbc4b07 100644 --- a/venv/lib/python3.12/site-packages/alembic/templates/async/script.py.mako +++ b/venv/lib/python3.12/site-packages/alembic/templates/async/script.py.mako @@ -13,16 +13,14 @@ ${imports if imports else ""} # revision identifiers, used by Alembic. revision: str = ${repr(up_revision)} -down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} def upgrade() -> None: - """Upgrade schema.""" ${upgrades if upgrades else "pass"} def downgrade() -> None: - """Downgrade schema.""" ${downgrades if downgrades else "pass"} diff --git a/venv/lib/python3.12/site-packages/alembic/templates/generic/alembic.ini.mako b/venv/lib/python3.12/site-packages/alembic/templates/generic/alembic.ini.mako index bb93d0e..c18ddb4 100644 --- a/venv/lib/python3.12/site-packages/alembic/templates/generic/alembic.ini.mako +++ b/venv/lib/python3.12/site-packages/alembic/templates/generic/alembic.ini.mako @@ -1,10 +1,7 @@ # A generic, single database configuration. [alembic] -# path to migration scripts. -# this is typically a path given in POSIX (e.g. forward slashes) -# format, relative to the token %(here)s which refers to the location of this -# ini file +# path to migration scripts script_location = ${script_location} # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s @@ -14,20 +11,19 @@ script_location = ${script_location} # file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s # sys.path path, will be prepended to sys.path if present. -# defaults to the current working directory. for multiple paths, the path separator -# is defined by "path_separator" below. +# defaults to the current working directory. prepend_sys_path = . - # timezone to use when rendering the date within the migration file # as well as the filename. -# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. 
-# Any required deps can installed by adding `alembic[tz]` to the pip requirements -# string value is passed to ZoneInfo() +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() # leave blank for localtime # timezone = -# max length of characters to apply to the "slug" field +# max length of characters to apply to the +# "slug" field # truncate_slug_length = 40 # set to 'true' to run the environment during @@ -40,37 +36,20 @@ prepend_sys_path = . # sourceless = false # version location specification; This defaults -# to /versions. When using multiple version +# to ${script_location}/versions. When using multiple version # directories, initial revisions must be specified with --version-path. -# The path separator used here should be the separator specified by "path_separator" -# below. -# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:${script_location}/versions -# path_separator; This indicates what character is used to split lists of file -# paths, including version_locations and prepend_sys_path within configparser -# files such as alembic.ini. -# The default rendered in new alembic.ini files is "os", which uses os.pathsep -# to provide os-dependent path splitting. +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: # -# Note that in order to support legacy alembic.ini files, this default does NOT -# take place if path_separator is not present in alembic.ini. If this -# option is omitted entirely, fallback logic is as follows: -# -# 1. Parsing of the version_locations option falls back to using the legacy -# "version_path_separator" key, which if absent then falls back to the legacy -# behavior of splitting on spaces and/or commas. -# 2. Parsing of the prepend_sys_path option falls back to the legacy -# behavior of splitting on spaces, commas, or colons. -# -# Valid values for path_separator are: -# -# path_separator = : -# path_separator = ; -# path_separator = space -# path_separator = newline -# -# Use os.pathsep. Default configuration used for new projects. -path_separator = os +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. # set to 'true' to search source files recursively # in each "version_locations" directory @@ -81,9 +60,6 @@ path_separator = os # are written from script.py.mako # output_encoding = utf-8 -# database URL. This is consumed by the user-maintained env.py script only. -# other means of configuring database URLs may be customized within the env.py -# file. 
sqlalchemy.url = driver://user:pass@localhost/dbname @@ -98,20 +74,13 @@ sqlalchemy.url = driver://user:pass@localhost/dbname # black.entrypoint = black # black.options = -l 79 REVISION_SCRIPT_FILENAME -# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module -# hooks = ruff -# ruff.type = module -# ruff.module = ruff -# ruff.options = check --fix REVISION_SCRIPT_FILENAME - -# Alternatively, use the exec runner to execute a binary found on your PATH +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary # hooks = ruff # ruff.type = exec -# ruff.executable = ruff -# ruff.options = check --fix REVISION_SCRIPT_FILENAME +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME -# Logging configuration. This is also consumed by the user-maintained -# env.py script only. +# Logging configuration [loggers] keys = root,sqlalchemy,alembic @@ -122,12 +91,12 @@ keys = console keys = generic [logger_root] -level = WARNING +level = WARN handlers = console qualname = [logger_sqlalchemy] -level = WARNING +level = WARN handlers = qualname = sqlalchemy.engine diff --git a/venv/lib/python3.12/site-packages/alembic/templates/generic/script.py.mako b/venv/lib/python3.12/site-packages/alembic/templates/generic/script.py.mako index 1101630..fbc4b07 100644 --- a/venv/lib/python3.12/site-packages/alembic/templates/generic/script.py.mako +++ b/venv/lib/python3.12/site-packages/alembic/templates/generic/script.py.mako @@ -13,16 +13,14 @@ ${imports if imports else ""} # revision identifiers, used by Alembic. revision: str = ${repr(up_revision)} -down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} def upgrade() -> None: - """Upgrade schema.""" ${upgrades if upgrades else "pass"} def downgrade() -> None: - """Downgrade schema.""" ${downgrades if downgrades else "pass"} diff --git a/venv/lib/python3.12/site-packages/alembic/templates/multidb/alembic.ini.mako b/venv/lib/python3.12/site-packages/alembic/templates/multidb/alembic.ini.mako index a662983..a9ea075 100644 --- a/venv/lib/python3.12/site-packages/alembic/templates/multidb/alembic.ini.mako +++ b/venv/lib/python3.12/site-packages/alembic/templates/multidb/alembic.ini.mako @@ -1,10 +1,7 @@ # a multi-database configuration. [alembic] -# path to migration scripts. -# this is typically a path given in POSIX (e.g. forward slashes) -# format, relative to the token %(here)s which refers to the location of this -# ini file +# path to migration scripts script_location = ${script_location} # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s @@ -14,19 +11,19 @@ script_location = ${script_location} # file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s # sys.path path, will be prepended to sys.path if present. -# defaults to the current working directory. for multiple paths, the path separator -# is defined by "path_separator" below. +# defaults to the current working directory. prepend_sys_path = . # timezone to use when rendering the date within the migration file # as well as the filename. -# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. 
-# Any required deps can installed by adding `alembic[tz]` to the pip requirements -# string value is passed to ZoneInfo() +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() # leave blank for localtime # timezone = -# max length of characters to apply to the "slug" field +# max length of characters to apply to the +# "slug" field # truncate_slug_length = 40 # set to 'true' to run the environment during @@ -39,37 +36,20 @@ prepend_sys_path = . # sourceless = false # version location specification; This defaults -# to /versions. When using multiple version +# to ${script_location}/versions. When using multiple version # directories, initial revisions must be specified with --version-path. -# The path separator used here should be the separator specified by "path_separator" -# below. -# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:${script_location}/versions -# path_separator; This indicates what character is used to split lists of file -# paths, including version_locations and prepend_sys_path within configparser -# files such as alembic.ini. -# The default rendered in new alembic.ini files is "os", which uses os.pathsep -# to provide os-dependent path splitting. +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: # -# Note that in order to support legacy alembic.ini files, this default does NOT -# take place if path_separator is not present in alembic.ini. If this -# option is omitted entirely, fallback logic is as follows: -# -# 1. Parsing of the version_locations option falls back to using the legacy -# "version_path_separator" key, which if absent then falls back to the legacy -# behavior of splitting on spaces and/or commas. -# 2. Parsing of the prepend_sys_path option falls back to the legacy -# behavior of splitting on spaces, commas, or colons. -# -# Valid values for path_separator are: -# -# path_separator = : -# path_separator = ; -# path_separator = space -# path_separator = newline -# -# Use os.pathsep. Default configuration used for new projects. -path_separator = os +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. # set to 'true' to search source files recursively # in each "version_locations" directory @@ -80,13 +60,6 @@ path_separator = os # are written from script.py.mako # output_encoding = utf-8 -# for multiple database configuration, new named sections are added -# which each include a distinct ``sqlalchemy.url`` entry. A custom value -# ``databases`` is added which indicates a listing of the per-database sections. -# The ``databases`` entry as well as the URLs present in the ``[engine1]`` -# and ``[engine2]`` sections continue to be consumed by the user-maintained env.py -# script only. 
- databases = engine1, engine2 [engine1] @@ -106,20 +79,13 @@ sqlalchemy.url = driver://user:pass@localhost/dbname2 # black.entrypoint = black # black.options = -l 79 REVISION_SCRIPT_FILENAME -# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module -# hooks = ruff -# ruff.type = module -# ruff.module = ruff -# ruff.options = check --fix REVISION_SCRIPT_FILENAME - -# Alternatively, use the exec runner to execute a binary found on your PATH +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary # hooks = ruff # ruff.type = exec -# ruff.executable = ruff -# ruff.options = check --fix REVISION_SCRIPT_FILENAME +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME -# Logging configuration. This is also consumed by the user-maintained -# env.py script only. +# Logging configuration [loggers] keys = root,sqlalchemy,alembic @@ -130,12 +96,12 @@ keys = console keys = generic [logger_root] -level = WARNING +level = WARN handlers = console qualname = [logger_sqlalchemy] -level = WARNING +level = WARN handlers = qualname = sqlalchemy.engine diff --git a/venv/lib/python3.12/site-packages/alembic/templates/multidb/script.py.mako b/venv/lib/python3.12/site-packages/alembic/templates/multidb/script.py.mako index 8e667d8..6108b8a 100644 --- a/venv/lib/python3.12/site-packages/alembic/templates/multidb/script.py.mako +++ b/venv/lib/python3.12/site-packages/alembic/templates/multidb/script.py.mako @@ -16,18 +16,16 @@ ${imports if imports else ""} # revision identifiers, used by Alembic. revision: str = ${repr(up_revision)} -down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} def upgrade(engine_name: str) -> None: - """Upgrade schema.""" globals()["upgrade_%s" % engine_name]() def downgrade(engine_name: str) -> None: - """Downgrade schema.""" globals()["downgrade_%s" % engine_name]() <% @@ -40,12 +38,10 @@ def downgrade(engine_name: str) -> None: % for db_name in re.split(r',\s*', db_names): def upgrade_${db_name}() -> None: - """Upgrade ${db_name} schema.""" ${context.get("%s_upgrades" % db_name, "pass")} def downgrade_${db_name}() -> None: - """Downgrade ${db_name} schema.""" ${context.get("%s_downgrades" % db_name, "pass")} % endfor diff --git a/venv/lib/python3.12/site-packages/alembic/templates/pyproject/README b/venv/lib/python3.12/site-packages/alembic/templates/pyproject/README deleted file mode 100644 index fdacc05..0000000 --- a/venv/lib/python3.12/site-packages/alembic/templates/pyproject/README +++ /dev/null @@ -1 +0,0 @@ -pyproject configuration, based on the generic configuration. \ No newline at end of file diff --git a/venv/lib/python3.12/site-packages/alembic/templates/pyproject/alembic.ini.mako b/venv/lib/python3.12/site-packages/alembic/templates/pyproject/alembic.ini.mako deleted file mode 100644 index 3d10f0e..0000000 --- a/venv/lib/python3.12/site-packages/alembic/templates/pyproject/alembic.ini.mako +++ /dev/null @@ -1,44 +0,0 @@ -# A generic, single database configuration. - -[alembic] - -# database URL. This is consumed by the user-maintained env.py script only. -# other means of configuring database URLs may be customized within the env.py -# file. 
-sqlalchemy.url = driver://user:pass@localhost/dbname - - -# Logging configuration -[loggers] -keys = root,sqlalchemy,alembic - -[handlers] -keys = console - -[formatters] -keys = generic - -[logger_root] -level = WARNING -handlers = console -qualname = - -[logger_sqlalchemy] -level = WARNING -handlers = -qualname = sqlalchemy.engine - -[logger_alembic] -level = INFO -handlers = -qualname = alembic - -[handler_console] -class = StreamHandler -args = (sys.stderr,) -level = NOTSET -formatter = generic - -[formatter_generic] -format = %(levelname)-5.5s [%(name)s] %(message)s -datefmt = %H:%M:%S diff --git a/venv/lib/python3.12/site-packages/alembic/templates/pyproject/env.py b/venv/lib/python3.12/site-packages/alembic/templates/pyproject/env.py deleted file mode 100644 index 36112a3..0000000 --- a/venv/lib/python3.12/site-packages/alembic/templates/pyproject/env.py +++ /dev/null @@ -1,78 +0,0 @@ -from logging.config import fileConfig - -from sqlalchemy import engine_from_config -from sqlalchemy import pool - -from alembic import context - -# this is the Alembic Config object, which provides -# access to the values within the .ini file in use. -config = context.config - -# Interpret the config file for Python logging. -# This line sets up loggers basically. -if config.config_file_name is not None: - fileConfig(config.config_file_name) - -# add your model's MetaData object here -# for 'autogenerate' support -# from myapp import mymodel -# target_metadata = mymodel.Base.metadata -target_metadata = None - -# other values from the config, defined by the needs of env.py, -# can be acquired: -# my_important_option = config.get_main_option("my_important_option") -# ... etc. - - -def run_migrations_offline() -> None: - """Run migrations in 'offline' mode. - - This configures the context with just a URL - and not an Engine, though an Engine is acceptable - here as well. By skipping the Engine creation - we don't even need a DBAPI to be available. - - Calls to context.execute() here emit the given string to the - script output. - - """ - url = config.get_main_option("sqlalchemy.url") - context.configure( - url=url, - target_metadata=target_metadata, - literal_binds=True, - dialect_opts={"paramstyle": "named"}, - ) - - with context.begin_transaction(): - context.run_migrations() - - -def run_migrations_online() -> None: - """Run migrations in 'online' mode. - - In this scenario we need to create an Engine - and associate a connection with the context. - - """ - connectable = engine_from_config( - config.get_section(config.config_ini_section, {}), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) - - with connectable.connect() as connection: - context.configure( - connection=connection, target_metadata=target_metadata - ) - - with context.begin_transaction(): - context.run_migrations() - - -if context.is_offline_mode(): - run_migrations_offline() -else: - run_migrations_online() diff --git a/venv/lib/python3.12/site-packages/alembic/templates/pyproject/pyproject.toml.mako b/venv/lib/python3.12/site-packages/alembic/templates/pyproject/pyproject.toml.mako deleted file mode 100644 index e68cef3..0000000 --- a/venv/lib/python3.12/site-packages/alembic/templates/pyproject/pyproject.toml.mako +++ /dev/null @@ -1,82 +0,0 @@ -[tool.alembic] - -# path to migration scripts. -# this is typically a path given in POSIX (e.g. 
forward slashes) -# format, relative to the token %(here)s which refers to the location of this -# ini file -script_location = "${script_location}" - -# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s -# Uncomment the line below if you want the files to be prepended with date and time -# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file -# for all available tokens -# file_template = "%%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s" - -# additional paths to be prepended to sys.path. defaults to the current working directory. -prepend_sys_path = [ - "." -] - -# timezone to use when rendering the date within the migration file -# as well as the filename. -# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. -# Any required deps can installed by adding `alembic[tz]` to the pip requirements -# string value is passed to ZoneInfo() -# leave blank for localtime -# timezone = - -# max length of characters to apply to the "slug" field -# truncate_slug_length = 40 - -# set to 'true' to run the environment during -# the 'revision' command, regardless of autogenerate -# revision_environment = false - -# set to 'true' to allow .pyc and .pyo files without -# a source .py file to be detected as revisions in the -# versions/ directory -# sourceless = false - -# version location specification; This defaults -# to /versions. When using multiple version -# directories, initial revisions must be specified with --version-path. -# version_locations = [ -# "%(here)s/alembic/versions", -# "%(here)s/foo/bar" -# ] - - -# set to 'true' to search source files recursively -# in each "version_locations" directory -# new in Alembic version 1.10 -# recursive_version_locations = false - -# the output encoding used when revision files -# are written from script.py.mako -# output_encoding = "utf-8" - -# This section defines scripts or Python functions that are run -# on newly generated revision scripts. See the documentation for further -# detail and examples -# [[tool.alembic.post_write_hooks]] -# format using "black" - use the console_scripts runner, -# against the "black" entrypoint -# name = "black" -# type = "console_scripts" -# entrypoint = "black" -# options = "-l 79 REVISION_SCRIPT_FILENAME" -# -# [[tool.alembic.post_write_hooks]] -# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module -# name = "ruff" -# type = "module" -# module = "ruff" -# options = "check --fix REVISION_SCRIPT_FILENAME" -# -# [[tool.alembic.post_write_hooks]] -# Alternatively, use the exec runner to execute a binary found on your PATH -# name = "ruff" -# type = "exec" -# executable = "ruff" -# options = "check --fix REVISION_SCRIPT_FILENAME" - diff --git a/venv/lib/python3.12/site-packages/alembic/templates/pyproject/script.py.mako b/venv/lib/python3.12/site-packages/alembic/templates/pyproject/script.py.mako deleted file mode 100644 index 1101630..0000000 --- a/venv/lib/python3.12/site-packages/alembic/templates/pyproject/script.py.mako +++ /dev/null @@ -1,28 +0,0 @@ -"""${message} - -Revision ID: ${up_revision} -Revises: ${down_revision | comma,n} -Create Date: ${create_date} - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa -${imports if imports else ""} - -# revision identifiers, used by Alembic. 
-revision: str = ${repr(up_revision)} -down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} -branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} -depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} - - -def upgrade() -> None: - """Upgrade schema.""" - ${upgrades if upgrades else "pass"} - - -def downgrade() -> None: - """Downgrade schema.""" - ${downgrades if downgrades else "pass"} diff --git a/venv/lib/python3.12/site-packages/alembic/templates/pyproject_async/README b/venv/lib/python3.12/site-packages/alembic/templates/pyproject_async/README deleted file mode 100644 index dfd718d..0000000 --- a/venv/lib/python3.12/site-packages/alembic/templates/pyproject_async/README +++ /dev/null @@ -1 +0,0 @@ -pyproject configuration, with an async dbapi. \ No newline at end of file diff --git a/venv/lib/python3.12/site-packages/alembic/templates/pyproject_async/alembic.ini.mako b/venv/lib/python3.12/site-packages/alembic/templates/pyproject_async/alembic.ini.mako deleted file mode 100644 index 3d10f0e..0000000 --- a/venv/lib/python3.12/site-packages/alembic/templates/pyproject_async/alembic.ini.mako +++ /dev/null @@ -1,44 +0,0 @@ -# A generic, single database configuration. - -[alembic] - -# database URL. This is consumed by the user-maintained env.py script only. -# other means of configuring database URLs may be customized within the env.py -# file. -sqlalchemy.url = driver://user:pass@localhost/dbname - - -# Logging configuration -[loggers] -keys = root,sqlalchemy,alembic - -[handlers] -keys = console - -[formatters] -keys = generic - -[logger_root] -level = WARNING -handlers = console -qualname = - -[logger_sqlalchemy] -level = WARNING -handlers = -qualname = sqlalchemy.engine - -[logger_alembic] -level = INFO -handlers = -qualname = alembic - -[handler_console] -class = StreamHandler -args = (sys.stderr,) -level = NOTSET -formatter = generic - -[formatter_generic] -format = %(levelname)-5.5s [%(name)s] %(message)s -datefmt = %H:%M:%S diff --git a/venv/lib/python3.12/site-packages/alembic/templates/pyproject_async/env.py b/venv/lib/python3.12/site-packages/alembic/templates/pyproject_async/env.py deleted file mode 100644 index 9f2d519..0000000 --- a/venv/lib/python3.12/site-packages/alembic/templates/pyproject_async/env.py +++ /dev/null @@ -1,89 +0,0 @@ -import asyncio -from logging.config import fileConfig - -from sqlalchemy import pool -from sqlalchemy.engine import Connection -from sqlalchemy.ext.asyncio import async_engine_from_config - -from alembic import context - -# this is the Alembic Config object, which provides -# access to the values within the .ini file in use. -config = context.config - -# Interpret the config file for Python logging. -# This line sets up loggers basically. -if config.config_file_name is not None: - fileConfig(config.config_file_name) - -# add your model's MetaData object here -# for 'autogenerate' support -# from myapp import mymodel -# target_metadata = mymodel.Base.metadata -target_metadata = None - -# other values from the config, defined by the needs of env.py, -# can be acquired: -# my_important_option = config.get_main_option("my_important_option") -# ... etc. - - -def run_migrations_offline() -> None: - """Run migrations in 'offline' mode. - - This configures the context with just a URL - and not an Engine, though an Engine is acceptable - here as well. By skipping the Engine creation - we don't even need a DBAPI to be available. 
- - Calls to context.execute() here emit the given string to the - script output. - - """ - url = config.get_main_option("sqlalchemy.url") - context.configure( - url=url, - target_metadata=target_metadata, - literal_binds=True, - dialect_opts={"paramstyle": "named"}, - ) - - with context.begin_transaction(): - context.run_migrations() - - -def do_run_migrations(connection: Connection) -> None: - context.configure(connection=connection, target_metadata=target_metadata) - - with context.begin_transaction(): - context.run_migrations() - - -async def run_async_migrations() -> None: - """In this scenario we need to create an Engine - and associate a connection with the context. - - """ - - connectable = async_engine_from_config( - config.get_section(config.config_ini_section, {}), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) - - async with connectable.connect() as connection: - await connection.run_sync(do_run_migrations) - - await connectable.dispose() - - -def run_migrations_online() -> None: - """Run migrations in 'online' mode.""" - - asyncio.run(run_async_migrations()) - - -if context.is_offline_mode(): - run_migrations_offline() -else: - run_migrations_online() diff --git a/venv/lib/python3.12/site-packages/alembic/templates/pyproject_async/pyproject.toml.mako b/venv/lib/python3.12/site-packages/alembic/templates/pyproject_async/pyproject.toml.mako deleted file mode 100644 index e68cef3..0000000 --- a/venv/lib/python3.12/site-packages/alembic/templates/pyproject_async/pyproject.toml.mako +++ /dev/null @@ -1,82 +0,0 @@ -[tool.alembic] - -# path to migration scripts. -# this is typically a path given in POSIX (e.g. forward slashes) -# format, relative to the token %(here)s which refers to the location of this -# ini file -script_location = "${script_location}" - -# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s -# Uncomment the line below if you want the files to be prepended with date and time -# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file -# for all available tokens -# file_template = "%%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s" - -# additional paths to be prepended to sys.path. defaults to the current working directory. -prepend_sys_path = [ - "." -] - -# timezone to use when rendering the date within the migration file -# as well as the filename. -# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. -# Any required deps can installed by adding `alembic[tz]` to the pip requirements -# string value is passed to ZoneInfo() -# leave blank for localtime -# timezone = - -# max length of characters to apply to the "slug" field -# truncate_slug_length = 40 - -# set to 'true' to run the environment during -# the 'revision' command, regardless of autogenerate -# revision_environment = false - -# set to 'true' to allow .pyc and .pyo files without -# a source .py file to be detected as revisions in the -# versions/ directory -# sourceless = false - -# version location specification; This defaults -# to /versions. When using multiple version -# directories, initial revisions must be specified with --version-path. 
-# version_locations = [ -# "%(here)s/alembic/versions", -# "%(here)s/foo/bar" -# ] - - -# set to 'true' to search source files recursively -# in each "version_locations" directory -# new in Alembic version 1.10 -# recursive_version_locations = false - -# the output encoding used when revision files -# are written from script.py.mako -# output_encoding = "utf-8" - -# This section defines scripts or Python functions that are run -# on newly generated revision scripts. See the documentation for further -# detail and examples -# [[tool.alembic.post_write_hooks]] -# format using "black" - use the console_scripts runner, -# against the "black" entrypoint -# name = "black" -# type = "console_scripts" -# entrypoint = "black" -# options = "-l 79 REVISION_SCRIPT_FILENAME" -# -# [[tool.alembic.post_write_hooks]] -# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module -# name = "ruff" -# type = "module" -# module = "ruff" -# options = "check --fix REVISION_SCRIPT_FILENAME" -# -# [[tool.alembic.post_write_hooks]] -# Alternatively, use the exec runner to execute a binary found on your PATH -# name = "ruff" -# type = "exec" -# executable = "ruff" -# options = "check --fix REVISION_SCRIPT_FILENAME" - diff --git a/venv/lib/python3.12/site-packages/alembic/templates/pyproject_async/script.py.mako b/venv/lib/python3.12/site-packages/alembic/templates/pyproject_async/script.py.mako deleted file mode 100644 index 1101630..0000000 --- a/venv/lib/python3.12/site-packages/alembic/templates/pyproject_async/script.py.mako +++ /dev/null @@ -1,28 +0,0 @@ -"""${message} - -Revision ID: ${up_revision} -Revises: ${down_revision | comma,n} -Create Date: ${create_date} - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa -${imports if imports else ""} - -# revision identifiers, used by Alembic. 
-revision: str = ${repr(up_revision)} -down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} -branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} -depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} - - -def upgrade() -> None: - """Upgrade schema.""" - ${upgrades if upgrades else "pass"} - - -def downgrade() -> None: - """Downgrade schema.""" - ${downgrades if downgrades else "pass"} diff --git a/venv/lib/python3.12/site-packages/alembic/testing/__init__.py b/venv/lib/python3.12/site-packages/alembic/testing/__init__.py index 3291508..0407adf 100644 --- a/venv/lib/python3.12/site-packages/alembic/testing/__init__.py +++ b/venv/lib/python3.12/site-packages/alembic/testing/__init__.py @@ -9,15 +9,12 @@ from sqlalchemy.testing import uses_deprecated from sqlalchemy.testing.config import combinations from sqlalchemy.testing.config import fixture from sqlalchemy.testing.config import requirements as requires -from sqlalchemy.testing.config import Variation -from sqlalchemy.testing.config import variation from .assertions import assert_raises from .assertions import assert_raises_message from .assertions import emits_python_deprecation_warning from .assertions import eq_ from .assertions import eq_ignore_whitespace -from .assertions import expect_deprecated from .assertions import expect_raises from .assertions import expect_raises_message from .assertions import expect_sqlalchemy_deprecated diff --git a/venv/lib/python3.12/site-packages/alembic/testing/assertions.py b/venv/lib/python3.12/site-packages/alembic/testing/assertions.py index 898fbd1..ec9593b 100644 --- a/venv/lib/python3.12/site-packages/alembic/testing/assertions.py +++ b/venv/lib/python3.12/site-packages/alembic/testing/assertions.py @@ -8,7 +8,6 @@ from typing import Dict from sqlalchemy import exc as sa_exc from sqlalchemy.engine import default -from sqlalchemy.engine import URL from sqlalchemy.testing.assertions import _expect_warnings from sqlalchemy.testing.assertions import eq_ # noqa from sqlalchemy.testing.assertions import is_ # noqa @@ -18,6 +17,8 @@ from sqlalchemy.testing.assertions import is_true # noqa from sqlalchemy.testing.assertions import ne_ # noqa from sqlalchemy.util import decorator +from ..util import sqla_compat + def _assert_proper_exception_context(exception): """assert that any exception we're catching does not have a __context__ @@ -73,9 +74,7 @@ class _ErrorContainer: @contextlib.contextmanager -def _expect_raises( - except_cls, msg=None, check_context=False, text_exact=False -): +def _expect_raises(except_cls, msg=None, check_context=False): ec = _ErrorContainer() if check_context: are_we_already_in_a_traceback = sys.exc_info()[0] @@ -86,10 +85,7 @@ def _expect_raises( ec.error = err success = True if msg is not None: - if text_exact: - assert str(err) == msg, f"{msg} != {err}" - else: - assert re.search(msg, str(err), re.UNICODE), f"{msg} !~ {err}" + assert re.search(msg, str(err), re.UNICODE), f"{msg} !~ {err}" if check_context and not are_we_already_in_a_traceback: _assert_proper_exception_context(err) print(str(err).encode("utf-8")) @@ -102,12 +98,8 @@ def expect_raises(except_cls, check_context=True): return _expect_raises(except_cls, check_context=check_context) -def expect_raises_message( - except_cls, msg, check_context=True, text_exact=False -): - return _expect_raises( - except_cls, msg=msg, check_context=check_context, text_exact=text_exact - ) +def expect_raises_message(except_cls, msg, check_context=True): + return 
_expect_raises(except_cls, msg=msg, check_context=check_context) def eq_ignore_whitespace(a, b, msg=None): @@ -126,7 +118,7 @@ def _get_dialect(name): if name is None or name == "default": return default.DefaultDialect() else: - d = URL.create(name).get_dialect()() + d = sqla_compat._create_url(name).get_dialect()() if name == "postgresql": d.implicit_returning = True @@ -167,10 +159,6 @@ def emits_python_deprecation_warning(*messages): return decorate -def expect_deprecated(*messages, **kw): - return _expect_warnings(DeprecationWarning, messages, **kw) - - def expect_sqlalchemy_deprecated(*messages, **kw): return _expect_warnings(sa_exc.SADeprecationWarning, messages, **kw) diff --git a/venv/lib/python3.12/site-packages/alembic/testing/env.py b/venv/lib/python3.12/site-packages/alembic/testing/env.py index 72a5e42..5df7ef8 100644 --- a/venv/lib/python3.12/site-packages/alembic/testing/env.py +++ b/venv/lib/python3.12/site-packages/alembic/testing/env.py @@ -1,6 +1,5 @@ import importlib.machinery import os -from pathlib import Path import shutil import textwrap @@ -17,7 +16,7 @@ from ..script import ScriptDirectory def _get_staging_directory(): if provision.FOLLOWER_IDENT: - return f"scratch_{provision.FOLLOWER_IDENT}" + return "scratch_%s" % provision.FOLLOWER_IDENT else: return "scratch" @@ -25,7 +24,7 @@ def _get_staging_directory(): def staging_env(create=True, template="generic", sourceless=False): cfg = _testing_config() if create: - path = _join_path(_get_staging_directory(), "scripts") + path = os.path.join(_get_staging_directory(), "scripts") assert not os.path.exists(path), ( "staging directory %s already exists; poor cleanup?" % path ) @@ -48,7 +47,7 @@ def staging_env(create=True, template="generic", sourceless=False): "pep3147_everything", ), sourceless make_sourceless( - _join_path(path, "env.py"), + os.path.join(path, "env.py"), "pep3147" if "pep3147" in sourceless else "simple", ) @@ -64,14 +63,14 @@ def clear_staging_env(): def script_file_fixture(txt): - dir_ = _join_path(_get_staging_directory(), "scripts") - path = _join_path(dir_, "script.py.mako") + dir_ = os.path.join(_get_staging_directory(), "scripts") + path = os.path.join(dir_, "script.py.mako") with open(path, "w") as f: f.write(txt) def env_file_fixture(txt): - dir_ = _join_path(_get_staging_directory(), "scripts") + dir_ = os.path.join(_get_staging_directory(), "scripts") txt = ( """ from alembic import context @@ -81,7 +80,7 @@ config = context.config + txt ) - path = _join_path(dir_, "env.py") + path = os.path.join(dir_, "env.py") pyc_path = util.pyc_file_from_path(path) if pyc_path: os.unlink(pyc_path) @@ -91,26 +90,26 @@ config = context.config def _sqlite_file_db(tempname="foo.db", future=False, scope=None, **options): - dir_ = _join_path(_get_staging_directory(), "scripts") + dir_ = os.path.join(_get_staging_directory(), "scripts") url = "sqlite:///%s/%s" % (dir_, tempname) - if scope: + if scope and util.sqla_14: options["scope"] = scope return testing_util.testing_engine(url=url, future=future, options=options) def _sqlite_testing_config(sourceless=False, future=False): - dir_ = _join_path(_get_staging_directory(), "scripts") - url = f"sqlite:///{dir_}/foo.db" + dir_ = os.path.join(_get_staging_directory(), "scripts") + url = "sqlite:///%s/foo.db" % dir_ sqlalchemy_future = future or ("future" in config.db.__class__.__module__) return _write_config_file( - f""" + """ [alembic] -script_location = {dir_} -sqlalchemy.url = {url} -sourceless = {"true" if sourceless else "false"} -{"sqlalchemy.future = 
true" if sqlalchemy_future else ""} +script_location = %s +sqlalchemy.url = %s +sourceless = %s +%s [loggers] keys = root,sqlalchemy @@ -119,7 +118,7 @@ keys = root,sqlalchemy keys = console [logger_root] -level = WARNING +level = WARN handlers = console qualname = @@ -141,25 +140,29 @@ keys = generic format = %%(levelname)-5.5s [%%(name)s] %%(message)s datefmt = %%H:%%M:%%S """ + % ( + dir_, + url, + "true" if sourceless else "false", + "sqlalchemy.future = true" if sqlalchemy_future else "", + ) ) def _multi_dir_testing_config(sourceless=False, extra_version_location=""): - dir_ = _join_path(_get_staging_directory(), "scripts") + dir_ = os.path.join(_get_staging_directory(), "scripts") sqlalchemy_future = "future" in config.db.__class__.__module__ url = "sqlite:///%s/foo.db" % dir_ return _write_config_file( - f""" + """ [alembic] -script_location = {dir_} -sqlalchemy.url = {url} -sqlalchemy.future = {"true" if sqlalchemy_future else "false"} -sourceless = {"true" if sourceless else "false"} -path_separator = space -version_locations = %(here)s/model1/ %(here)s/model2/ %(here)s/model3/ \ -{extra_version_location} +script_location = %s +sqlalchemy.url = %s +sqlalchemy.future = %s +sourceless = %s +version_locations = %%(here)s/model1/ %%(here)s/model2/ %%(here)s/model3/ %s [loggers] keys = root @@ -168,7 +171,7 @@ keys = root keys = console [logger_root] -level = WARNING +level = WARN handlers = console qualname = @@ -185,63 +188,26 @@ keys = generic format = %%(levelname)-5.5s [%%(name)s] %%(message)s datefmt = %%H:%%M:%%S """ - ) - - -def _no_sql_pyproject_config(dialect="postgresql", directives=""): - """use a postgresql url with no host so that - connections guaranteed to fail""" - dir_ = _join_path(_get_staging_directory(), "scripts") - - return _write_toml_config( - f""" -[tool.alembic] -script_location ="{dir_}" -{textwrap.dedent(directives)} - - """, - f""" -[alembic] -sqlalchemy.url = {dialect}:// - -[loggers] -keys = root - -[handlers] -keys = console - -[logger_root] -level = WARNING -handlers = console -qualname = - -[handler_console] -class = StreamHandler -args = (sys.stderr,) -level = NOTSET -formatter = generic - -[formatters] -keys = generic - -[formatter_generic] -format = %%(levelname)-5.5s [%%(name)s] %%(message)s -datefmt = %%H:%%M:%%S - -""", + % ( + dir_, + url, + "true" if sqlalchemy_future else "false", + "true" if sourceless else "false", + extra_version_location, + ) ) def _no_sql_testing_config(dialect="postgresql", directives=""): """use a postgresql url with no host so that connections guaranteed to fail""" - dir_ = _join_path(_get_staging_directory(), "scripts") + dir_ = os.path.join(_get_staging_directory(), "scripts") return _write_config_file( - f""" + """ [alembic] -script_location ={dir_} -sqlalchemy.url = {dialect}:// -{directives} +script_location = %s +sqlalchemy.url = %s:// +%s [loggers] keys = root @@ -250,7 +216,7 @@ keys = root keys = console [logger_root] -level = WARNING +level = WARN handlers = console qualname = @@ -268,16 +234,10 @@ format = %%(levelname)-5.5s [%%(name)s] %%(message)s datefmt = %%H:%%M:%%S """ + % (dir_, dialect, directives) ) -def _write_toml_config(tomltext, initext): - cfg = _write_config_file(initext) - with open(cfg.toml_file_name, "w") as f: - f.write(tomltext) - return cfg - - def _write_config_file(text): cfg = _testing_config() with open(cfg.config_file_name, "w") as f: @@ -290,10 +250,7 @@ def _testing_config(): if not os.access(_get_staging_directory(), os.F_OK): os.mkdir(_get_staging_directory()) - return Config( 
- _join_path(_get_staging_directory(), "test_alembic.ini"), - _join_path(_get_staging_directory(), "pyproject.toml"), - ) + return Config(os.path.join(_get_staging_directory(), "test_alembic.ini")) def write_script( @@ -313,7 +270,9 @@ def write_script( script = Script._from_path(scriptdir, path) old = scriptdir.revision_map.get_revision(script.revision) if old.down_revision != script.down_revision: - raise Exception("Can't change down_revision on a refresh operation.") + raise Exception( + "Can't change down_revision " "on a refresh operation." + ) scriptdir.revision_map.add_revision(script, _replace=True) if sourceless: @@ -353,9 +312,9 @@ def three_rev_fixture(cfg): write_script( script, a, - f"""\ + """\ "Rev A" -revision = '{a}' +revision = '%s' down_revision = None from alembic import op @@ -368,7 +327,8 @@ def upgrade(): def downgrade(): op.execute("DROP STEP 1") -""", +""" + % a, ) script.generate_revision(b, "revision b", refresh=True, head=a) @@ -398,10 +358,10 @@ def downgrade(): write_script( script, c, - f"""\ + """\ "Rev C" -revision = '{c}' -down_revision = '{b}' +revision = '%s' +down_revision = '%s' from alembic import op @@ -413,7 +373,8 @@ def upgrade(): def downgrade(): op.execute("DROP STEP 3") -""", +""" + % (c, b), ) return a, b, c @@ -435,10 +396,10 @@ def multi_heads_fixture(cfg, a, b, c): write_script( script, d, - f"""\ + """\ "Rev D" -revision = '{d}' -down_revision = '{b}' +revision = '%s' +down_revision = '%s' from alembic import op @@ -450,7 +411,8 @@ def upgrade(): def downgrade(): op.execute("DROP STEP 4") -""", +""" + % (d, b), ) script.generate_revision( @@ -459,10 +421,10 @@ def downgrade(): write_script( script, e, - f"""\ + """\ "Rev E" -revision = '{e}' -down_revision = '{d}' +revision = '%s' +down_revision = '%s' from alembic import op @@ -474,7 +436,8 @@ def upgrade(): def downgrade(): op.execute("DROP STEP 5") -""", +""" + % (e, d), ) script.generate_revision( @@ -483,10 +446,10 @@ def downgrade(): write_script( script, f, - f"""\ + """\ "Rev F" -revision = '{f}' -down_revision = '{b}' +revision = '%s' +down_revision = '%s' from alembic import op @@ -498,7 +461,8 @@ def upgrade(): def downgrade(): op.execute("DROP STEP 6") -""", +""" + % (f, b), ) return d, e, f @@ -507,25 +471,25 @@ def downgrade(): def _multidb_testing_config(engines): """alembic.ini fixture to work exactly with the 'multidb' template""" - dir_ = _join_path(_get_staging_directory(), "scripts") + dir_ = os.path.join(_get_staging_directory(), "scripts") sqlalchemy_future = "future" in config.db.__class__.__module__ databases = ", ".join(engines.keys()) engines = "\n\n".join( - f"[{key}]\nsqlalchemy.url = {value.url}" + "[%s]\n" "sqlalchemy.url = %s" % (key, value.url) for key, value in engines.items() ) return _write_config_file( - f""" + """ [alembic] -script_location = {dir_} +script_location = %s sourceless = false -sqlalchemy.future = {"true" if sqlalchemy_future else "false"} -databases = {databases} +sqlalchemy.future = %s +databases = %s -{engines} +%s [loggers] keys = root @@ -533,7 +497,7 @@ keys = root keys = console [logger_root] -level = WARNING +level = WARN handlers = console qualname = @@ -550,8 +514,5 @@ keys = generic format = %%(levelname)-5.5s [%%(name)s] %%(message)s datefmt = %%H:%%M:%%S """ + % (dir_, "true" if sqlalchemy_future else "false", databases, engines) ) - - -def _join_path(base: str, *more: str): - return str(Path(base).joinpath(*more).as_posix()) diff --git a/venv/lib/python3.12/site-packages/alembic/testing/fixtures.py 
b/venv/lib/python3.12/site-packages/alembic/testing/fixtures.py index 61bcd7e..4b83a74 100644 --- a/venv/lib/python3.12/site-packages/alembic/testing/fixtures.py +++ b/venv/lib/python3.12/site-packages/alembic/testing/fixtures.py @@ -3,14 +3,11 @@ from __future__ import annotations import configparser from contextlib import contextmanager import io -import os import re -import shutil from typing import Any from typing import Dict from sqlalchemy import Column -from sqlalchemy import create_mock_engine from sqlalchemy import inspect from sqlalchemy import MetaData from sqlalchemy import String @@ -20,19 +17,20 @@ from sqlalchemy import text from sqlalchemy.testing import config from sqlalchemy.testing import mock from sqlalchemy.testing.assertions import eq_ -from sqlalchemy.testing.fixtures import FutureEngineMixin from sqlalchemy.testing.fixtures import TablesTest as SQLAlchemyTablesTest from sqlalchemy.testing.fixtures import TestBase as SQLAlchemyTestBase import alembic from .assertions import _get_dialect -from .env import _get_staging_directory from ..environment import EnvironmentContext from ..migration import MigrationContext from ..operations import Operations from ..util import sqla_compat +from ..util.sqla_compat import create_mock_engine +from ..util.sqla_compat import sqla_14 from ..util.sqla_compat import sqla_2 + testing_config = configparser.ConfigParser() testing_config.read(["test.cfg"]) @@ -40,31 +38,6 @@ testing_config.read(["test.cfg"]) class TestBase(SQLAlchemyTestBase): is_sqlalchemy_future = sqla_2 - @testing.fixture() - def clear_staging_dir(self): - yield - location = _get_staging_directory() - for filename in os.listdir(location): - file_path = os.path.join(location, filename) - if os.path.isfile(file_path) or os.path.islink(file_path): - os.unlink(file_path) - elif os.path.isdir(file_path): - shutil.rmtree(file_path) - - @contextmanager - def pushd(self, dirname): - current_dir = os.getcwd() - try: - os.chdir(dirname) - yield - finally: - os.chdir(current_dir) - - @testing.fixture() - def pop_alembic_config_env(self): - yield - os.environ.pop("ALEMBIC_CONFIG", None) - @testing.fixture() def ops_context(self, migration_context): with migration_context.begin_transaction(_per_migration=True): @@ -76,12 +49,6 @@ class TestBase(SQLAlchemyTestBase): connection, opts=dict(transaction_per_migration=True) ) - @testing.fixture - def as_sql_migration_context(self, connection): - return MigrationContext.configure( - connection, opts=dict(transaction_per_migration=True, as_sql=True) - ) - @testing.fixture def connection(self): with config.db.connect() as conn: @@ -92,6 +59,14 @@ class TablesTest(TestBase, SQLAlchemyTablesTest): pass +if sqla_14: + from sqlalchemy.testing.fixtures import FutureEngineMixin +else: + + class FutureEngineMixin: # type:ignore[no-redef] + __requires__ = ("sqlalchemy_14",) + + FutureEngineMixin.is_sqlalchemy_future = True @@ -209,8 +184,12 @@ def op_fixture( opts["as_sql"] = as_sql if literal_binds: opts["literal_binds"] = literal_binds + if not sqla_14 and dialect == "mariadb": + ctx_dialect = _get_dialect("mysql") + ctx_dialect.server_version_info = (10, 4, 0, "MariaDB") - ctx_dialect = _get_dialect(dialect) + else: + ctx_dialect = _get_dialect(dialect) if native_boolean is not None: ctx_dialect.supports_native_boolean = native_boolean # this is new as of SQLAlchemy 1.2.7 and is used by SQL Server, @@ -289,11 +268,9 @@ class AlterColRoundTripFixture: "x", column.name, existing_type=column.type, - existing_server_default=( - column.server_default - 
if column.server_default is not None - else False - ), + existing_server_default=column.server_default + if column.server_default is not None + else False, existing_nullable=True if column.nullable else False, # existing_comment=column.comment, nullable=to_.get("nullable", None), @@ -321,13 +298,9 @@ class AlterColRoundTripFixture: new_col["type"], new_col.get("default", None), compare.get("type", old_col["type"]), - ( - compare["server_default"].text - if "server_default" in compare - else ( - column.server_default.arg.text - if column.server_default is not None - else None - ) - ), + compare["server_default"].text + if "server_default" in compare + else column.server_default.arg.text + if column.server_default is not None + else None, ) diff --git a/venv/lib/python3.12/site-packages/alembic/testing/requirements.py b/venv/lib/python3.12/site-packages/alembic/testing/requirements.py index 8b63c16..2107da4 100644 --- a/venv/lib/python3.12/site-packages/alembic/testing/requirements.py +++ b/venv/lib/python3.12/site-packages/alembic/testing/requirements.py @@ -1,6 +1,7 @@ from sqlalchemy.testing.requirements import Requirements from alembic import util +from alembic.util import sqla_compat from ..testing import exclusions @@ -73,6 +74,13 @@ class SuiteRequirements(Requirements): def reflects_fk_options(self): return exclusions.closed() + @property + def sqlalchemy_14(self): + return exclusions.skip_if( + lambda config: not util.sqla_14, + "SQLAlchemy 1.4 or greater required", + ) + @property def sqlalchemy_1x(self): return exclusions.skip_if( @@ -87,18 +95,6 @@ class SuiteRequirements(Requirements): "SQLAlchemy 2.x test", ) - @property - def asyncio(self): - def go(config): - try: - import greenlet # noqa: F401 - except ImportError: - return False - else: - return True - - return exclusions.only_if(go) - @property def comments(self): return exclusions.only_if( @@ -113,6 +109,26 @@ class SuiteRequirements(Requirements): def computed_columns(self): return exclusions.closed() + @property + def computed_columns_api(self): + return exclusions.only_if( + exclusions.BooleanPredicate(sqla_compat.has_computed) + ) + + @property + def computed_reflects_normally(self): + return exclusions.only_if( + exclusions.BooleanPredicate(sqla_compat.has_computed_reflection) + ) + + @property + def computed_reflects_as_server_default(self): + return exclusions.closed() + + @property + def computed_doesnt_reflect_as_server_default(self): + return exclusions.closed() + @property def autoincrement_on_composite_pk(self): return exclusions.closed() @@ -174,3 +190,9 @@ class SuiteRequirements(Requirements): @property def identity_columns_alter(self): return exclusions.closed() + + @property + def identity_columns_api(self): + return exclusions.only_if( + exclusions.BooleanPredicate(sqla_compat.has_identity) + ) diff --git a/venv/lib/python3.12/site-packages/alembic/testing/schemacompare.py b/venv/lib/python3.12/site-packages/alembic/testing/schemacompare.py index 204cc4d..c063499 100644 --- a/venv/lib/python3.12/site-packages/alembic/testing/schemacompare.py +++ b/venv/lib/python3.12/site-packages/alembic/testing/schemacompare.py @@ -1,7 +1,6 @@ from itertools import zip_longest from sqlalchemy import schema -from sqlalchemy.sql.elements import ClauseList class CompareTable: @@ -61,14 +60,6 @@ class CompareIndex: def __ne__(self, other): return not self.__eq__(other) - def __repr__(self): - expr = ClauseList(*self.index.expressions) - try: - expr_str = expr.compile().string - except Exception: - expr_str = str(expr) - 
return f"" - class CompareCheckConstraint: def __init__(self, constraint): diff --git a/venv/lib/python3.12/site-packages/alembic/testing/suite/_autogen_fixtures.py b/venv/lib/python3.12/site-packages/alembic/testing/suite/_autogen_fixtures.py index ed4acb2..d838ebe 100644 --- a/venv/lib/python3.12/site-packages/alembic/testing/suite/_autogen_fixtures.py +++ b/venv/lib/python3.12/site-packages/alembic/testing/suite/_autogen_fixtures.py @@ -14,7 +14,6 @@ from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy import Numeric -from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import Text @@ -150,118 +149,6 @@ class ModelOne: return m -class NamingConvModel: - __requires__ = ("unique_constraint_reflection",) - configure_opts = {"conv_all_constraint_names": True} - naming_convention = { - "ix": "ix_%(column_0_label)s", - "uq": "uq_%(table_name)s_%(constraint_name)s", - "ck": "ck_%(table_name)s_%(constraint_name)s", - "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", - "pk": "pk_%(table_name)s", - } - - @classmethod - def _get_db_schema(cls): - # database side - assume all constraints have a name that - # we would assume here is a "db generated" name. need to make - # sure these all render with op.f(). - m = MetaData() - Table( - "x1", - m, - Column("q", Integer), - Index("db_x1_index_q", "q"), - PrimaryKeyConstraint("q", name="db_x1_primary_q"), - ) - Table( - "x2", - m, - Column("q", Integer), - Column("p", ForeignKey("x1.q", name="db_x2_foreign_q")), - CheckConstraint("q > 5", name="db_x2_check_q"), - ) - Table( - "x3", - m, - Column("q", Integer), - Column("r", Integer), - Column("s", Integer), - UniqueConstraint("q", name="db_x3_unique_q"), - ) - Table( - "x4", - m, - Column("q", Integer), - PrimaryKeyConstraint("q", name="db_x4_primary_q"), - ) - Table( - "x5", - m, - Column("q", Integer), - Column("p", ForeignKey("x4.q", name="db_x5_foreign_q")), - Column("r", Integer), - Column("s", Integer), - PrimaryKeyConstraint("q", name="db_x5_primary_q"), - UniqueConstraint("r", name="db_x5_unique_r"), - CheckConstraint("s > 5", name="db_x5_check_s"), - ) - # SQLite and it's "no names needed" thing. bleh. - # we can't have a name for these so you'll see "None" for the name. 
- Table( - "unnamed_sqlite", - m, - Column("q", Integer), - Column("r", Integer), - PrimaryKeyConstraint("q"), - UniqueConstraint("r"), - ) - return m - - @classmethod - def _get_model_schema(cls): - from sqlalchemy.sql.naming import conv - - m = MetaData(naming_convention=cls.naming_convention) - Table( - "x1", m, Column("q", Integer, primary_key=True), Index(None, "q") - ) - Table( - "x2", - m, - Column("q", Integer), - Column("p", ForeignKey("x1.q")), - CheckConstraint("q > 5", name="token_x2check1"), - ) - Table( - "x3", - m, - Column("q", Integer), - Column("r", Integer), - Column("s", Integer), - UniqueConstraint("r", name="token_x3r"), - UniqueConstraint("s", name=conv("userdef_x3_unique_s")), - ) - Table( - "x4", - m, - Column("q", Integer, primary_key=True), - Index("userdef_x4_idx_q", "q"), - ) - Table( - "x6", - m, - Column("q", Integer, primary_key=True), - Column("p", ForeignKey("x4.q")), - Column("r", Integer), - Column("s", Integer), - UniqueConstraint("r", name="token_x6r"), - CheckConstraint("s > 5", "token_x6check1"), - CheckConstraint("s < 20", conv("userdef_x6_check_s")), - ) - return m - - class _ComparesFKs: def _assert_fk_diff( self, diff --git a/venv/lib/python3.12/site-packages/alembic/testing/suite/test_autogen_computed.py b/venv/lib/python3.12/site-packages/alembic/testing/suite/test_autogen_computed.py index fe7eb7a..01a89a1 100644 --- a/venv/lib/python3.12/site-packages/alembic/testing/suite/test_autogen_computed.py +++ b/venv/lib/python3.12/site-packages/alembic/testing/suite/test_autogen_computed.py @@ -6,7 +6,9 @@ from sqlalchemy import Table from ._autogen_fixtures import AutogenFixtureTest from ... import testing +from ...testing import config from ...testing import eq_ +from ...testing import exclusions from ...testing import is_ from ...testing import is_true from ...testing import mock @@ -61,8 +63,18 @@ class AutogenerateComputedTest(AutogenFixtureTest, TestBase): c = diffs[0][3] eq_(c.name, "foo") - is_true(isinstance(c.computed, sa.Computed)) - is_true(isinstance(c.server_default, sa.Computed)) + if config.requirements.computed_reflects_normally.enabled: + is_true(isinstance(c.computed, sa.Computed)) + else: + is_(c.computed, None) + + if config.requirements.computed_reflects_as_server_default.enabled: + is_true(isinstance(c.server_default, sa.DefaultClause)) + eq_(str(c.server_default.arg.text), "5") + elif config.requirements.computed_reflects_normally.enabled: + is_true(isinstance(c.computed, sa.Computed)) + else: + is_(c.computed, None) @testing.combinations( lambda: (None, sa.Computed("bar*5")), @@ -73,6 +85,7 @@ class AutogenerateComputedTest(AutogenFixtureTest, TestBase): ), lambda: (sa.Computed("bar*5"), sa.Computed("bar * 42")), ) + @config.requirements.computed_reflects_normally def test_cant_change_computed_warning(self, test_case): arg_before, arg_after = testing.resolve_lambda(test_case, **locals()) m1 = MetaData() @@ -111,7 +124,10 @@ class AutogenerateComputedTest(AutogenFixtureTest, TestBase): lambda: (None, None), lambda: (sa.Computed("5"), sa.Computed("5")), lambda: (sa.Computed("bar*5"), sa.Computed("bar*5")), - lambda: (sa.Computed("bar*5"), sa.Computed("bar * \r\n\t5")), + ( + lambda: (sa.Computed("bar*5"), None), + config.requirements.computed_doesnt_reflect_as_server_default, + ), ) def test_computed_unchanged(self, test_case): arg_before, arg_after = testing.resolve_lambda(test_case, **locals()) @@ -142,3 +158,46 @@ class AutogenerateComputedTest(AutogenFixtureTest, TestBase): eq_(mock_warn.mock_calls, []) eq_(list(diffs), []) 
+ + @config.requirements.computed_reflects_as_server_default + def test_remove_computed_default_on_computed(self): + """Asserts the current behavior which is that on PG and Oracle, + the GENERATED ALWAYS AS is reflected as a server default which we can't + tell is actually "computed", so these come out as a modification to + the server default. + + """ + m1 = MetaData() + m2 = MetaData() + + Table( + "user", + m1, + Column("id", Integer, primary_key=True), + Column("bar", Integer), + Column("foo", Integer, sa.Computed("bar + 42")), + ) + + Table( + "user", + m2, + Column("id", Integer, primary_key=True), + Column("bar", Integer), + Column("foo", Integer), + ) + + diffs = self._fixture(m1, m2) + + eq_(diffs[0][0][0], "modify_default") + eq_(diffs[0][0][2], "user") + eq_(diffs[0][0][3], "foo") + old = diffs[0][0][-2] + new = diffs[0][0][-1] + + is_(new, None) + is_true(isinstance(old, sa.DefaultClause)) + + if exclusions.against(config, "postgresql"): + eq_(str(old.arg.text), "(bar + 42)") + elif exclusions.against(config, "oracle"): + eq_(str(old.arg.text), '"BAR"+42') diff --git a/venv/lib/python3.12/site-packages/alembic/testing/suite/test_environment.py b/venv/lib/python3.12/site-packages/alembic/testing/suite/test_environment.py index df2d9af..8c86859 100644 --- a/venv/lib/python3.12/site-packages/alembic/testing/suite/test_environment.py +++ b/venv/lib/python3.12/site-packages/alembic/testing/suite/test_environment.py @@ -24,9 +24,9 @@ class MigrationTransactionTest(TestBase): self.context = MigrationContext.configure( dialect=conn.dialect, opts=opts ) - self.context.output_buffer = self.context.impl.output_buffer = ( - io.StringIO() - ) + self.context.output_buffer = ( + self.context.impl.output_buffer + ) = io.StringIO() else: self.context = MigrationContext.configure( connection=conn, opts=opts diff --git a/venv/lib/python3.12/site-packages/alembic/testing/warnings.py b/venv/lib/python3.12/site-packages/alembic/testing/warnings.py index 86d45a0..e87136b 100644 --- a/venv/lib/python3.12/site-packages/alembic/testing/warnings.py +++ b/venv/lib/python3.12/site-packages/alembic/testing/warnings.py @@ -10,6 +10,8 @@ import warnings from sqlalchemy import exc as sa_exc +from ..util import sqla_14 + def setup_filters(): """Set global warning behavior for the test suite.""" @@ -21,6 +23,13 @@ def setup_filters(): # some selected deprecations... 
warnings.filterwarnings("error", category=DeprecationWarning) + if not sqla_14: + # 1.3 uses pkg_resources in PluginLoader + warnings.filterwarnings( + "ignore", + "pkg_resources is deprecated as an API", + DeprecationWarning, + ) try: import pytest except ImportError: diff --git a/venv/lib/python3.12/site-packages/alembic/util/__init__.py b/venv/lib/python3.12/site-packages/alembic/util/__init__.py index 1d3a217..3c1e27c 100644 --- a/venv/lib/python3.12/site-packages/alembic/util/__init__.py +++ b/venv/lib/python3.12/site-packages/alembic/util/__init__.py @@ -1,29 +1,35 @@ -from .editor import open_in_editor as open_in_editor -from .exc import AutogenerateDiffsDetected as AutogenerateDiffsDetected -from .exc import CommandError as CommandError -from .langhelpers import _with_legacy_names as _with_legacy_names -from .langhelpers import asbool as asbool -from .langhelpers import dedupe_tuple as dedupe_tuple -from .langhelpers import Dispatcher as Dispatcher -from .langhelpers import EMPTY_DICT as EMPTY_DICT -from .langhelpers import immutabledict as immutabledict -from .langhelpers import memoized_property as memoized_property -from .langhelpers import ModuleClsProxy as ModuleClsProxy -from .langhelpers import not_none as not_none -from .langhelpers import rev_id as rev_id -from .langhelpers import to_list as to_list -from .langhelpers import to_tuple as to_tuple -from .langhelpers import unique_list as unique_list -from .messaging import err as err -from .messaging import format_as_comma as format_as_comma -from .messaging import msg as msg -from .messaging import obfuscate_url_pw as obfuscate_url_pw -from .messaging import status as status -from .messaging import warn as warn -from .messaging import warn_deprecated as warn_deprecated -from .messaging import write_outstream as write_outstream -from .pyfiles import coerce_resource_to_filename as coerce_resource_to_filename -from .pyfiles import load_python_file as load_python_file -from .pyfiles import pyc_file_from_path as pyc_file_from_path -from .pyfiles import template_to_file as template_to_file -from .sqla_compat import sqla_2 as sqla_2 +from .editor import open_in_editor +from .exc import AutogenerateDiffsDetected +from .exc import CommandError +from .langhelpers import _with_legacy_names +from .langhelpers import asbool +from .langhelpers import dedupe_tuple +from .langhelpers import Dispatcher +from .langhelpers import EMPTY_DICT +from .langhelpers import immutabledict +from .langhelpers import memoized_property +from .langhelpers import ModuleClsProxy +from .langhelpers import not_none +from .langhelpers import rev_id +from .langhelpers import to_list +from .langhelpers import to_tuple +from .langhelpers import unique_list +from .messaging import err +from .messaging import format_as_comma +from .messaging import msg +from .messaging import obfuscate_url_pw +from .messaging import status +from .messaging import warn +from .messaging import write_outstream +from .pyfiles import coerce_resource_to_filename +from .pyfiles import load_python_file +from .pyfiles import pyc_file_from_path +from .pyfiles import template_to_file +from .sqla_compat import has_computed +from .sqla_compat import sqla_13 +from .sqla_compat import sqla_14 +from .sqla_compat import sqla_2 + + +if not sqla_13: + raise CommandError("SQLAlchemy 1.3.0 or greater is required.") diff --git a/venv/lib/python3.12/site-packages/alembic/util/compat.py b/venv/lib/python3.12/site-packages/alembic/util/compat.py index 131f16a..31e0208 100644 --- 
a/venv/lib/python3.12/site-packages/alembic/util/compat.py +++ b/venv/lib/python3.12/site-packages/alembic/util/compat.py @@ -1,37 +1,22 @@ -# mypy: no-warn-unused-ignores - from __future__ import annotations from configparser import ConfigParser import io import os -from pathlib import Path import sys import typing -from typing import Any -from typing import Iterator -from typing import List -from typing import Optional from typing import Sequence from typing import Union -if True: - # zimports hack for too-long names - from sqlalchemy.util import ( # noqa: F401 - inspect_getfullargspec as inspect_getfullargspec, - ) - from sqlalchemy.util.compat import ( # noqa: F401 - inspect_formatargspec as inspect_formatargspec, - ) +from sqlalchemy.util import inspect_getfullargspec # noqa +from sqlalchemy.util.compat import inspect_formatargspec # noqa is_posix = os.name == "posix" -py314 = sys.version_info >= (3, 14) -py313 = sys.version_info >= (3, 13) -py312 = sys.version_info >= (3, 12) py311 = sys.version_info >= (3, 11) py310 = sys.version_info >= (3, 10) py39 = sys.version_info >= (3, 9) +py38 = sys.version_info >= (3, 8) # produce a wrapper that allows encoded text to stream @@ -43,82 +28,24 @@ class EncodedIO(io.TextIOWrapper): if py39: - from importlib import resources as _resources - - importlib_resources = _resources - from importlib import metadata as _metadata - - importlib_metadata = _metadata - from importlib.metadata import EntryPoint as EntryPoint + from importlib import resources as importlib_resources + from importlib import metadata as importlib_metadata + from importlib.metadata import EntryPoint else: import importlib_resources # type:ignore # noqa import importlib_metadata # type:ignore # noqa from importlib_metadata import EntryPoint # type:ignore # noqa -if py311: - import tomllib as tomllib -else: - import tomli as tomllib # type: ignore # noqa - - -if py312: - - def path_walk( - path: Path, *, top_down: bool = True - ) -> Iterator[tuple[Path, list[str], list[str]]]: - return Path.walk(path) - - def path_relative_to( - path: Path, other: Path, *, walk_up: bool = False - ) -> Path: - return path.relative_to(other, walk_up=walk_up) - -else: - - def path_walk( - path: Path, *, top_down: bool = True - ) -> Iterator[tuple[Path, list[str], list[str]]]: - for root, dirs, files in os.walk(path, topdown=top_down): - yield Path(root), dirs, files - - def path_relative_to( - path: Path, other: Path, *, walk_up: bool = False - ) -> Path: - """ - Calculate the relative path of 'path' with respect to 'other', - optionally allowing 'path' to be outside the subtree of 'other'. 
- - OK I used AI for this, sorry - - """ - try: - return path.relative_to(other) - except ValueError: - if walk_up: - other_ancestors = list(other.parents) + [other] - for ancestor in other_ancestors: - try: - return path.relative_to(ancestor) - except ValueError: - continue - raise ValueError( - f"{path} is not in the same subtree as {other}" - ) - else: - raise - def importlib_metadata_get(group: str) -> Sequence[EntryPoint]: ep = importlib_metadata.entry_points() if hasattr(ep, "select"): - return ep.select(group=group) + return ep.select(group=group) # type: ignore else: return ep.get(group, ()) # type: ignore -def formatannotation_fwdref( - annotation: Any, base_module: Optional[Any] = None -) -> str: +def formatannotation_fwdref(annotation, base_module=None): """vendored from python 3.7""" # copied over _formatannotation from sqlalchemy 2.0 @@ -139,7 +66,7 @@ def formatannotation_fwdref( def read_config_parser( file_config: ConfigParser, file_argument: Sequence[Union[str, os.PathLike[str]]], -) -> List[str]: +) -> list[str]: if py310: return file_config.read(file_argument, encoding="locale") else: diff --git a/venv/lib/python3.12/site-packages/alembic/util/exc.py b/venv/lib/python3.12/site-packages/alembic/util/exc.py index c790e18..0d0496b 100644 --- a/venv/lib/python3.12/site-packages/alembic/util/exc.py +++ b/venv/lib/python3.12/site-packages/alembic/util/exc.py @@ -1,25 +1,6 @@ -from __future__ import annotations - -from typing import Any -from typing import List -from typing import Tuple -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from alembic.autogenerate import RevisionContext - - class CommandError(Exception): pass class AutogenerateDiffsDetected(CommandError): - def __init__( - self, - message: str, - revision_context: RevisionContext, - diffs: List[Tuple[Any, ...]], - ) -> None: - super().__init__(message) - self.revision_context = revision_context - self.diffs = diffs + pass diff --git a/venv/lib/python3.12/site-packages/alembic/util/langhelpers.py b/venv/lib/python3.12/site-packages/alembic/util/langhelpers.py index 80d88cb..34d48bc 100644 --- a/venv/lib/python3.12/site-packages/alembic/util/langhelpers.py +++ b/venv/lib/python3.12/site-packages/alembic/util/langhelpers.py @@ -5,46 +5,33 @@ from collections.abc import Iterable import textwrap from typing import Any from typing import Callable -from typing import cast from typing import Dict from typing import List from typing import Mapping -from typing import MutableMapping -from typing import NoReturn from typing import Optional from typing import overload from typing import Sequence -from typing import Set from typing import Tuple -from typing import Type -from typing import TYPE_CHECKING from typing import TypeVar from typing import Union import uuid import warnings -from sqlalchemy.util import asbool as asbool # noqa: F401 -from sqlalchemy.util import immutabledict as immutabledict # noqa: F401 -from sqlalchemy.util import to_list as to_list # noqa: F401 -from sqlalchemy.util import unique_list as unique_list +from sqlalchemy.util import asbool # noqa +from sqlalchemy.util import immutabledict # noqa +from sqlalchemy.util import memoized_property # noqa +from sqlalchemy.util import to_list # noqa +from sqlalchemy.util import unique_list # noqa from .compat import inspect_getfullargspec -if True: - # zimports workaround :( - from sqlalchemy.util import ( # noqa: F401 - memoized_property as memoized_property, - ) - EMPTY_DICT: Mapping[Any, Any] = immutabledict() -_T = TypeVar("_T", bound=Any) - -_C = 
TypeVar("_C", bound=Callable[..., Any]) +_T = TypeVar("_T") class _ModuleClsMeta(type): - def __setattr__(cls, key: str, value: Callable[..., Any]) -> None: + def __setattr__(cls, key: str, value: Callable) -> None: super().__setattr__(key, value) cls._update_module_proxies(key) # type: ignore @@ -58,13 +45,9 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta): """ - _setups: Dict[ - Type[Any], - Tuple[ - Set[str], - List[Tuple[MutableMapping[str, Any], MutableMapping[str, Any]]], - ], - ] = collections.defaultdict(lambda: (set(), [])) + _setups: Dict[type, Tuple[set, list]] = collections.defaultdict( + lambda: (set(), []) + ) @classmethod def _update_module_proxies(cls, name: str) -> None: @@ -87,33 +70,18 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta): del globals_[attr_name] @classmethod - def create_module_class_proxy( - cls, - globals_: MutableMapping[str, Any], - locals_: MutableMapping[str, Any], - ) -> None: + def create_module_class_proxy(cls, globals_, locals_): attr_names, modules = cls._setups[cls] modules.append((globals_, locals_)) cls._setup_proxy(globals_, locals_, attr_names) @classmethod - def _setup_proxy( - cls, - globals_: MutableMapping[str, Any], - locals_: MutableMapping[str, Any], - attr_names: Set[str], - ) -> None: + def _setup_proxy(cls, globals_, locals_, attr_names): for methname in dir(cls): cls._add_proxied_attribute(methname, globals_, locals_, attr_names) @classmethod - def _add_proxied_attribute( - cls, - methname: str, - globals_: MutableMapping[str, Any], - locals_: MutableMapping[str, Any], - attr_names: Set[str], - ) -> None: + def _add_proxied_attribute(cls, methname, globals_, locals_, attr_names): if not methname.startswith("_"): meth = getattr(cls, methname) if callable(meth): @@ -124,15 +92,10 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta): attr_names.add(methname) @classmethod - def _create_method_proxy( - cls, - name: str, - globals_: MutableMapping[str, Any], - locals_: MutableMapping[str, Any], - ) -> Callable[..., Any]: + def _create_method_proxy(cls, name, globals_, locals_): fn = getattr(cls, name) - def _name_error(name: str, from_: Exception) -> NoReturn: + def _name_error(name, from_): raise NameError( "Can't invoke function '%s', as the proxy object has " "not yet been " @@ -156,9 +119,7 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta): translations, ) - def translate( - fn_name: str, spec: Any, translations: Any, args: Any, kw: Any - ) -> Any: + def translate(fn_name, spec, translations, args, kw): return_kw = {} return_args = [] @@ -215,15 +176,15 @@ class ModuleClsProxy(metaclass=_ModuleClsMeta): "doc": fn.__doc__, } ) - lcl: MutableMapping[str, Any] = {} + lcl = {} - exec(func_text, cast("Dict[str, Any]", globals_), lcl) - return cast("Callable[..., Any]", lcl[name]) + exec(func_text, globals_, lcl) + return lcl[name] -def _with_legacy_names(translations: Any) -> Any: - def decorate(fn: _C) -> _C: - fn._legacy_translations = translations # type: ignore[attr-defined] +def _with_legacy_names(translations): + def decorate(fn): + fn._legacy_translations = translations return fn return decorate @@ -234,22 +195,21 @@ def rev_id() -> str: @overload -def to_tuple(x: Any, default: Tuple[Any, ...]) -> Tuple[Any, ...]: ... +def to_tuple(x: Any, default: tuple) -> tuple: + ... @overload -def to_tuple(x: None, default: Optional[_T] = ...) -> _T: ... +def to_tuple(x: None, default: Optional[_T] = None) -> _T: + ... @overload -def to_tuple( - x: Any, default: Optional[Tuple[Any, ...]] = None -) -> Tuple[Any, ...]: ... 
+def to_tuple(x: Any, default: Optional[tuple] = None) -> tuple: + ... -def to_tuple( - x: Any, default: Optional[Tuple[Any, ...]] = None -) -> Optional[Tuple[Any, ...]]: +def to_tuple(x, default=None): if x is None: return default elif isinstance(x, str): @@ -266,13 +226,13 @@ def dedupe_tuple(tup: Tuple[str, ...]) -> Tuple[str, ...]: class Dispatcher: def __init__(self, uselist: bool = False) -> None: - self._registry: Dict[Tuple[Any, ...], Any] = {} + self._registry: Dict[tuple, Any] = {} self.uselist = uselist def dispatch_for( self, target: Any, qualifier: str = "default" - ) -> Callable[[_C], _C]: - def decorate(fn: _C) -> _C: + ) -> Callable: + def decorate(fn): if self.uselist: self._registry.setdefault((target, qualifier), []).append(fn) else: @@ -284,7 +244,7 @@ class Dispatcher: def dispatch(self, obj: Any, qualifier: str = "default") -> Any: if isinstance(obj, str): - targets: Sequence[Any] = [obj] + targets: Sequence = [obj] elif isinstance(obj, type): targets = obj.__mro__ else: @@ -299,13 +259,11 @@ class Dispatcher: raise ValueError("no dispatch function for object: %s" % obj) def _fn_or_list( - self, fn_or_list: Union[List[Callable[..., Any]], Callable[..., Any]] - ) -> Callable[..., Any]: + self, fn_or_list: Union[List[Callable], Callable] + ) -> Callable: if self.uselist: - def go(*arg: Any, **kw: Any) -> None: - if TYPE_CHECKING: - assert isinstance(fn_or_list, Sequence) + def go(*arg, **kw): for fn in fn_or_list: fn(*arg, **kw) diff --git a/venv/lib/python3.12/site-packages/alembic/util/messaging.py b/venv/lib/python3.12/site-packages/alembic/util/messaging.py index 4c08f16..35592c0 100644 --- a/venv/lib/python3.12/site-packages/alembic/util/messaging.py +++ b/venv/lib/python3.12/site-packages/alembic/util/messaging.py @@ -5,7 +5,6 @@ from contextlib import contextmanager import logging import sys import textwrap -from typing import Iterator from typing import Optional from typing import TextIO from typing import Union @@ -13,6 +12,8 @@ import warnings from sqlalchemy.engine import url +from . 
import sqla_compat + log = logging.getLogger(__name__) # disable "no handler found" errors @@ -52,9 +53,7 @@ def write_outstream( @contextmanager -def status( - status_msg: str, newline: bool = False, quiet: bool = False -) -> Iterator[None]: +def status(status_msg: str, newline: bool = False, quiet: bool = False): msg(status_msg + " ...", newline, flush=True, quiet=quiet) try: yield @@ -67,24 +66,21 @@ def status( write_outstream(sys.stdout, " done\n") -def err(message: str, quiet: bool = False) -> None: +def err(message: str, quiet: bool = False): log.error(message) msg(f"FAILED: {message}", quiet=quiet) sys.exit(-1) def obfuscate_url_pw(input_url: str) -> str: - return url.make_url(input_url).render_as_string(hide_password=True) + u = url.make_url(input_url) + return sqla_compat.url_render_as_string(u, hide_password=True) def warn(msg: str, stacklevel: int = 2) -> None: warnings.warn(msg, UserWarning, stacklevel=stacklevel) -def warn_deprecated(msg: str, stacklevel: int = 2) -> None: - warnings.warn(msg, DeprecationWarning, stacklevel=stacklevel) - - def msg( msg: str, newline: bool = True, flush: bool = False, quiet: bool = False ) -> None: @@ -96,17 +92,11 @@ def msg( write_outstream(sys.stdout, "\n") else: # left indent output lines - indent = " " - lines = textwrap.wrap( - msg, - TERMWIDTH, - initial_indent=indent, - subsequent_indent=indent, - ) + lines = textwrap.wrap(msg, TERMWIDTH) if len(lines) > 1: for line in lines[0:-1]: - write_outstream(sys.stdout, line, "\n") - write_outstream(sys.stdout, lines[-1], ("\n" if newline else "")) + write_outstream(sys.stdout, " ", line, "\n") + write_outstream(sys.stdout, " ", lines[-1], ("\n" if newline else "")) if flush: sys.stdout.flush() diff --git a/venv/lib/python3.12/site-packages/alembic/util/pyfiles.py b/venv/lib/python3.12/site-packages/alembic/util/pyfiles.py index 6b75d57..e757673 100644 --- a/venv/lib/python3.12/site-packages/alembic/util/pyfiles.py +++ b/venv/lib/python3.12/site-packages/alembic/util/pyfiles.py @@ -6,13 +6,9 @@ import importlib import importlib.machinery import importlib.util import os -import pathlib import re import tempfile -from types import ModuleType -from typing import Any from typing import Optional -from typing import Union from mako import exceptions from mako.template import Template @@ -22,14 +18,9 @@ from .exc import CommandError def template_to_file( - template_file: Union[str, os.PathLike[str]], - dest: Union[str, os.PathLike[str]], - output_encoding: str, - *, - append_with_newlines: bool = False, - **kw: Any, + template_file: str, dest: str, output_encoding: str, **kw ) -> None: - template = Template(filename=_preserving_path_as_str(template_file)) + template = Template(filename=template_file) try: output = template.render_unicode(**kw).encode(output_encoding) except: @@ -45,13 +36,11 @@ def template_to_file( "template-oriented traceback." % fname ) else: - with open(dest, "ab" if append_with_newlines else "wb") as f: - if append_with_newlines: - f.write("\n\n".encode(output_encoding)) + with open(dest, "wb") as f: f.write(output) -def coerce_resource_to_filename(fname_or_resource: str) -> pathlib.Path: +def coerce_resource_to_filename(fname: str) -> str: """Interpret a filename as either a filesystem location or as a package resource. @@ -59,9 +48,8 @@ def coerce_resource_to_filename(fname_or_resource: str) -> pathlib.Path: are interpreted as resources and coerced to a file location. 
""" - # TODO: there seem to be zero tests for the package resource codepath - if not os.path.isabs(fname_or_resource) and ":" in fname_or_resource: - tokens = fname_or_resource.split(":") + if not os.path.isabs(fname) and ":" in fname: + tokens = fname.split(":") # from https://importlib-resources.readthedocs.io/en/latest/migration.html#pkg-resources-resource-filename # noqa E501 @@ -71,48 +59,37 @@ def coerce_resource_to_filename(fname_or_resource: str) -> pathlib.Path: ref = compat.importlib_resources.files(tokens[0]) for tok in tokens[1:]: ref = ref / tok - fname_or_resource = file_manager.enter_context( # type: ignore[assignment] # noqa: E501 + fname = file_manager.enter_context( # type: ignore[assignment] compat.importlib_resources.as_file(ref) ) - return pathlib.Path(fname_or_resource) + return fname -def pyc_file_from_path( - path: Union[str, os.PathLike[str]], -) -> Optional[pathlib.Path]: +def pyc_file_from_path(path: str) -> Optional[str]: """Given a python source path, locate the .pyc.""" - pathpath = pathlib.Path(path) - candidate = pathlib.Path( - importlib.util.cache_from_source(pathpath.as_posix()) - ) - if candidate.exists(): + candidate = importlib.util.cache_from_source(path) + if os.path.exists(candidate): return candidate # even for pep3147, fall back to the old way of finding .pyc files, # to support sourceless operation - ext = pathpath.suffix + filepath, ext = os.path.splitext(path) for ext in importlib.machinery.BYTECODE_SUFFIXES: - if pathpath.with_suffix(ext).exists(): - return pathpath.with_suffix(ext) + if os.path.exists(filepath + ext): + return filepath + ext else: return None -def load_python_file( - dir_: Union[str, os.PathLike[str]], filename: Union[str, os.PathLike[str]] -) -> ModuleType: +def load_python_file(dir_: str, filename: str): """Load a file from the given path as a Python module.""" - dir_ = pathlib.Path(dir_) - filename_as_path = pathlib.Path(filename) - filename = filename_as_path.name - module_id = re.sub(r"\W", "_", filename) - path = dir_ / filename - ext = path.suffix + path = os.path.join(dir_, filename) + _, ext = os.path.splitext(filename) if ext == ".py": - if path.exists(): + if os.path.exists(path): module = load_module_py(module_id, path) else: pyc_path = pyc_file_from_path(path) @@ -122,32 +99,12 @@ def load_python_file( module = load_module_py(module_id, pyc_path) elif ext in (".pyc", ".pyo"): module = load_module_py(module_id, path) - else: - assert False return module -def load_module_py( - module_id: str, path: Union[str, os.PathLike[str]] -) -> ModuleType: +def load_module_py(module_id: str, path: str): spec = importlib.util.spec_from_file_location(module_id, path) assert spec module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) # type: ignore return module - - -def _preserving_path_as_str(path: Union[str, os.PathLike[str]]) -> str: - """receive str/pathlike and return a string. - - Does not convert an incoming string path to a Path first, to help with - unit tests that are doing string path round trips without OS-specific - processing if not necessary. 
- - """ - if isinstance(path, str): - return path - elif isinstance(path, pathlib.PurePath): - return str(path) - else: - return str(pathlib.Path(path)) diff --git a/venv/lib/python3.12/site-packages/alembic/util/sqla_compat.py b/venv/lib/python3.12/site-packages/alembic/util/sqla_compat.py index a909ead..3f175cf 100644 --- a/venv/lib/python3.12/site-packages/alembic/util/sqla_compat.py +++ b/venv/lib/python3.12/site-packages/alembic/util/sqla_compat.py @@ -1,27 +1,24 @@ -# mypy: allow-untyped-defs, allow-incomplete-defs, allow-untyped-calls -# mypy: no-warn-return-any, allow-any-generics - from __future__ import annotations import contextlib import re from typing import Any -from typing import Callable from typing import Dict from typing import Iterable from typing import Iterator +from typing import Mapping from typing import Optional -from typing import Protocol -from typing import Set -from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union from sqlalchemy import __version__ +from sqlalchemy import inspect from sqlalchemy import schema from sqlalchemy import sql from sqlalchemy import types as sqltypes +from sqlalchemy.engine import url +from sqlalchemy.ext.compiler import compiles from sqlalchemy.schema import CheckConstraint from sqlalchemy.schema import Column from sqlalchemy.schema import ForeignKeyConstraint @@ -29,33 +26,31 @@ from sqlalchemy.sql import visitors from sqlalchemy.sql.base import DialectKWArgs from sqlalchemy.sql.elements import BindParameter from sqlalchemy.sql.elements import ColumnClause +from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.elements import TextClause from sqlalchemy.sql.elements import UnaryExpression -from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME # type: ignore[attr-defined] # noqa: E501 from sqlalchemy.sql.visitors import traverse from typing_extensions import TypeGuard if TYPE_CHECKING: - from sqlalchemy import ClauseElement - from sqlalchemy import Identity from sqlalchemy import Index from sqlalchemy import Table from sqlalchemy.engine import Connection from sqlalchemy.engine import Dialect from sqlalchemy.engine import Transaction + from sqlalchemy.engine.reflection import Inspector from sqlalchemy.sql.base import ColumnCollection from sqlalchemy.sql.compiler import SQLCompiler + from sqlalchemy.sql.dml import Insert from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.schema import Constraint from sqlalchemy.sql.schema import SchemaItem + from sqlalchemy.sql.selectable import Select + from sqlalchemy.sql.selectable import TableClause _CE = TypeVar("_CE", bound=Union["ColumnElement[Any]", "SchemaItem"]) -class _CompilerProtocol(Protocol): - def __call__(self, element: Any, compiler: Any, **kw: Any) -> str: ... 
- - def _safe_int(value: str) -> Union[int, str]: try: return int(value) @@ -66,65 +61,90 @@ def _safe_int(value: str) -> Union[int, str]: _vers = tuple( [_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)] ) +sqla_13 = _vers >= (1, 3) +sqla_14 = _vers >= (1, 4) # https://docs.sqlalchemy.org/en/latest/changelog/changelog_14.html#change-0c6e0cc67dfe6fac5164720e57ef307d sqla_14_18 = _vers >= (1, 4, 18) sqla_14_26 = _vers >= (1, 4, 26) sqla_2 = _vers >= (2,) sqlalchemy_version = __version__ -if TYPE_CHECKING: +try: + from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME +except ImportError: + from sqlalchemy.sql.elements import _NONE_NAME as _NONE_NAME # type: ignore # noqa: E501 - def compiles( - element: Type[ClauseElement], *dialects: str - ) -> Callable[[_CompilerProtocol], _CompilerProtocol]: ... +class _Unsupported: + "Placeholder for unsupported SQLAlchemy classes" + + +try: + from sqlalchemy import Computed +except ImportError: + if not TYPE_CHECKING: + + class Computed(_Unsupported): + pass + + has_computed = False + has_computed_reflection = False else: - from sqlalchemy.ext.compiler import compiles # noqa: I100,I202 + has_computed = True + has_computed_reflection = _vers >= (1, 3, 16) +try: + from sqlalchemy import Identity +except ImportError: + if not TYPE_CHECKING: -identity_has_dialect_kwargs = issubclass(schema.Identity, DialectKWArgs) + class Identity(_Unsupported): + pass + has_identity = False +else: + identity_has_dialect_kwargs = issubclass(Identity, DialectKWArgs) -def _get_identity_options_dict( - identity: Union[Identity, schema.Sequence, None], - dialect_kwargs: bool = False, -) -> Dict[str, Any]: - if identity is None: - return {} - elif identity_has_dialect_kwargs: - assert hasattr(identity, "_as_dict") - as_dict = identity._as_dict() - if dialect_kwargs: - assert isinstance(identity, DialectKWArgs) - as_dict.update(identity.dialect_kwargs) - else: - as_dict = {} - if isinstance(identity, schema.Identity): - # always=None means something different than always=False - as_dict["always"] = identity.always - if identity.on_null is not None: - as_dict["on_null"] = identity.on_null - # attributes common to Identity and Sequence - attrs = ( - "start", - "increment", - "minvalue", - "maxvalue", - "nominvalue", - "nomaxvalue", - "cycle", - "cache", - "order", - ) - as_dict.update( - { - key: getattr(identity, key, None) - for key in attrs - if getattr(identity, key, None) is not None - } - ) - return as_dict + def _get_identity_options_dict( + identity: Union[Identity, schema.Sequence, None], + dialect_kwargs: bool = False, + ) -> Dict[str, Any]: + if identity is None: + return {} + elif identity_has_dialect_kwargs: + as_dict = identity._as_dict() # type: ignore + if dialect_kwargs: + assert isinstance(identity, DialectKWArgs) + as_dict.update(identity.dialect_kwargs) + else: + as_dict = {} + if isinstance(identity, Identity): + # always=None means something different than always=False + as_dict["always"] = identity.always + if identity.on_null is not None: + as_dict["on_null"] = identity.on_null + # attributes common to Identity and Sequence + attrs = ( + "start", + "increment", + "minvalue", + "maxvalue", + "nominvalue", + "nomaxvalue", + "cycle", + "cache", + "order", + ) + as_dict.update( + { + key: getattr(identity, key, None) + for key in attrs + if getattr(identity, key, None) is not None + } + ) + return as_dict + has_identity = True if sqla_2: from sqlalchemy.sql.base import _NoneName @@ -133,6 +153,7 @@ else: _ConstraintName = Union[None, str, 
_NoneName] + _ConstraintNameDefined = Union[str, _NoneName] @@ -142,11 +163,15 @@ def constraint_name_defined( return name is _NONE_NAME or isinstance(name, (str, _NoneName)) -def constraint_name_string(name: _ConstraintName) -> TypeGuard[str]: +def constraint_name_string( + name: _ConstraintName, +) -> TypeGuard[str]: return isinstance(name, str) -def constraint_name_or_none(name: _ConstraintName) -> Optional[str]: +def constraint_name_or_none( + name: _ConstraintName, +) -> Optional[str]: return name if constraint_name_string(name) else None @@ -176,10 +201,17 @@ def _ensure_scope_for_ddl( yield +def url_render_as_string(url, hide_password=True): + if sqla_14: + return url.render_as_string(hide_password=hide_password) + else: + return url.__to_string__(hide_password=hide_password) + + def _safe_begin_connection_transaction( connection: Connection, ) -> Transaction: - transaction = connection.get_transaction() + transaction = _get_connection_transaction(connection) if transaction: return transaction else: @@ -189,7 +221,7 @@ def _safe_begin_connection_transaction( def _safe_commit_connection_transaction( connection: Connection, ) -> None: - transaction = connection.get_transaction() + transaction = _get_connection_transaction(connection) if transaction: transaction.commit() @@ -197,7 +229,7 @@ def _safe_commit_connection_transaction( def _safe_rollback_connection_transaction( connection: Connection, ) -> None: - transaction = connection.get_transaction() + transaction = _get_connection_transaction(connection) if transaction: transaction.rollback() @@ -218,34 +250,70 @@ def _idx_table_bound_expressions(idx: Index) -> Iterable[ColumnElement[Any]]: def _copy(schema_item: _CE, **kw) -> _CE: if hasattr(schema_item, "_copy"): - return schema_item._copy(**kw) + return schema_item._copy(**kw) # type: ignore[union-attr] else: return schema_item.copy(**kw) # type: ignore[union-attr] +def _get_connection_transaction( + connection: Connection, +) -> Optional[Transaction]: + if sqla_14: + return connection.get_transaction() + else: + r = connection._root # type: ignore[attr-defined] + return r._Connection__transaction + + +def _create_url(*arg, **kw) -> url.URL: + if hasattr(url.URL, "create"): + return url.URL.create(*arg, **kw) + else: + return url.URL(*arg, **kw) + + def _connectable_has_table( connectable: Connection, tablename: str, schemaname: Union[str, None] ) -> bool: - return connectable.dialect.has_table(connectable, tablename, schemaname) + if sqla_14: + return inspect(connectable).has_table(tablename, schemaname) + else: + return connectable.dialect.has_table( + connectable, tablename, schemaname + ) def _exec_on_inspector(inspector, statement, **params): - with inspector._operation_context() as conn: - return conn.execute(statement, params) + if sqla_14: + with inspector._operation_context() as conn: + return conn.execute(statement, params) + else: + return inspector.bind.execute(statement, params) def _nullability_might_be_unset(metadata_column): - from sqlalchemy.sql import schema + if not sqla_14: + return metadata_column.nullable + else: + from sqlalchemy.sql import schema - return metadata_column._user_defined_nullable is schema.NULL_UNSPECIFIED + return ( + metadata_column._user_defined_nullable is schema.NULL_UNSPECIFIED + ) def _server_default_is_computed(*server_default) -> bool: - return any(isinstance(sd, schema.Computed) for sd in server_default) + if not has_computed: + return False + else: + return any(isinstance(sd, Computed) for sd in server_default) def 
_server_default_is_identity(*server_default) -> bool: - return any(isinstance(sd, schema.Identity) for sd in server_default) + if not sqla_14: + return False + else: + return any(isinstance(sd, Identity) for sd in server_default) def _table_for_constraint(constraint: Constraint) -> Table: @@ -266,6 +334,15 @@ def _columns_for_constraint(constraint): return list(constraint.columns) +def _reflect_table(inspector: Inspector, table: Table) -> None: + if sqla_14: + return inspector.reflect_table(table, None) + else: + return inspector.reflecttable( # type: ignore[attr-defined] + table, None + ) + + def _resolve_for_variant(type_, dialect): if _type_has_variants(type_): base_type, mapping = _get_variant_mapping(type_) @@ -274,7 +351,7 @@ def _resolve_for_variant(type_, dialect): return type_ -if hasattr(sqltypes.TypeEngine, "_variant_mapping"): # 2.0 +if hasattr(sqltypes.TypeEngine, "_variant_mapping"): def _type_has_variants(type_): return bool(type_._variant_mapping) @@ -291,12 +368,7 @@ else: return type_.impl, type_.mapping -def _fk_spec(constraint: ForeignKeyConstraint) -> Any: - if TYPE_CHECKING: - assert constraint.columns is not None - assert constraint.elements is not None - assert isinstance(constraint.parent, Table) - +def _fk_spec(constraint): source_columns = [ constraint.columns[key].name for key in constraint.column_keys ] @@ -325,7 +397,7 @@ def _fk_spec(constraint: ForeignKeyConstraint) -> Any: def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool: - spec = constraint.elements[0]._get_colspec() + spec = constraint.elements[0]._get_colspec() # type: ignore[attr-defined] tokens = spec.split(".") tokens.pop(-1) # colname tablekey = ".".join(tokens) @@ -337,13 +409,13 @@ def _is_type_bound(constraint: Constraint) -> bool: # this deals with SQLAlchemy #3260, don't copy CHECK constraints # that will be generated by the type. # new feature added for #3260 - return constraint._type_bound + return constraint._type_bound # type: ignore[attr-defined] def _find_columns(clause): """locate Column objects within the given expression.""" - cols: Set[ColumnElement[Any]] = set() + cols = set() traverse(clause, {}, {"column": cols.add}) return cols @@ -430,7 +502,7 @@ class _textual_index_element(sql.ColumnElement): self.fake_column = schema.Column(self.text.text, sqltypes.NULLTYPE) table.append_column(self.fake_column) - def get_children(self, **kw): + def get_children(self): return [self.fake_column] @@ -452,44 +524,116 @@ def _render_literal_bindparam( return compiler.render_literal_bindparam(element, **kw) +def _get_index_expressions(idx): + return list(idx.expressions) + + +def _get_index_column_names(idx): + return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)] + + +def _column_kwargs(col: Column) -> Mapping: + if sqla_13: + return col.kwargs + else: + return {} + + def _get_constraint_final_name( constraint: Union[Index, Constraint], dialect: Optional[Dialect] ) -> Optional[str]: if constraint.name is None: return None assert dialect is not None - # for SQLAlchemy 1.4 we would like to have the option to expand - # the use of "deferred" names for constraints as well as to have - # some flexibility with "None" name and similar; make use of new - # SQLAlchemy API to return what would be the final compiled form of - # the name for this dialect. 
- return dialect.identifier_preparer.format_constraint( - constraint, _alembic_quote=False - ) + if sqla_14: + # for SQLAlchemy 1.4 we would like to have the option to expand + # the use of "deferred" names for constraints as well as to have + # some flexibility with "None" name and similar; make use of new + # SQLAlchemy API to return what would be the final compiled form of + # the name for this dialect. + return dialect.identifier_preparer.format_constraint( + constraint, _alembic_quote=False + ) + else: + # prior to SQLAlchemy 1.4, work around quoting logic to get at the + # final compiled name without quotes. + if hasattr(constraint.name, "quote"): + # might be quoted_name, might be truncated_name, keep it the + # same + quoted_name_cls: type = type(constraint.name) + else: + quoted_name_cls = quoted_name + + new_name = quoted_name_cls(str(constraint.name), quote=False) + constraint = constraint.__class__(name=new_name) + + if isinstance(constraint, schema.Index): + # name should not be quoted. + d = dialect.ddl_compiler(dialect, None) # type: ignore[arg-type] + return d._prepared_index_name( # type: ignore[attr-defined] + constraint + ) + else: + # name should not be quoted. + return dialect.identifier_preparer.format_constraint(constraint) def _constraint_is_named( constraint: Union[Constraint, Index], dialect: Optional[Dialect] ) -> bool: - if constraint.name is None: - return False - assert dialect is not None - name = dialect.identifier_preparer.format_constraint( - constraint, _alembic_quote=False - ) - return name is not None + if sqla_14: + if constraint.name is None: + return False + assert dialect is not None + name = dialect.identifier_preparer.format_constraint( + constraint, _alembic_quote=False + ) + return name is not None + else: + return constraint.name is not None + + +def _is_mariadb(mysql_dialect: Dialect) -> bool: + if sqla_14: + return mysql_dialect.is_mariadb # type: ignore[attr-defined] + else: + return bool( + mysql_dialect.server_version_info + and mysql_dialect._is_mariadb # type: ignore[attr-defined] + ) + + +def _mariadb_normalized_version_info(mysql_dialect): + return mysql_dialect._mariadb_normalized_version_info + + +def _insert_inline(table: Union[TableClause, Table]) -> Insert: + if sqla_14: + return table.insert().inline() + else: + return table.insert(inline=True) # type: ignore[call-arg] + + +if sqla_14: + from sqlalchemy import create_mock_engine + from sqlalchemy import select as _select +else: + from sqlalchemy import create_engine + + def create_mock_engine(url, executor, **kw): # type: ignore[misc] + return create_engine( + "postgresql://", strategy="mock", executor=executor + ) + + def _select(*columns, **kw) -> Select: # type: ignore[no-redef] + return sql.select(list(columns), **kw) # type: ignore[call-overload] def is_expression_index(index: Index) -> bool: + expr: Any for expr in index.expressions: - if is_expression(expr): + while isinstance(expr, UnaryExpression): + expr = expr.element + if not isinstance(expr, ColumnClause) or expr.is_literal: return True return False - - -def is_expression(expr: Any) -> bool: - while isinstance(expr, UnaryExpression): - expr = expr.element - if not isinstance(expr, ColumnClause) or expr.is_literal: - return True - return False diff --git a/venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/AUTHORS b/venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/AUTHORS similarity index 100% rename from venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/AUTHORS rename to 
venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/AUTHORS diff --git a/venv/lib/python3.12/site-packages/bcrypt-5.0.0.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/bcrypt-5.0.0.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/INSTALLER diff --git a/venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/LICENSE b/venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/LICENSE similarity index 100% rename from venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/LICENSE rename to venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/LICENSE diff --git a/venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/METADATA b/venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/METADATA similarity index 72% rename from venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/METADATA rename to venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/METADATA index d9f971e..08124ba 100644 --- a/venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/METADATA +++ b/venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/METADATA @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: asyncpg -Version: 0.30.0 +Version: 0.29.0 Summary: An asyncio PostgreSQL driver Author-email: MagicStack Inc License: Apache License, Version 2.0 @@ -25,22 +25,14 @@ Requires-Python: >=3.8.0 Description-Content-Type: text/x-rst License-File: LICENSE License-File: AUTHORS -Requires-Dist: async-timeout>=4.0.3; python_version < "3.11.0" +Requires-Dist: async-timeout >=4.0.3 ; python_version < "3.12.0" Provides-Extra: docs -Requires-Dist: Sphinx~=8.1.3; extra == "docs" -Requires-Dist: sphinx-rtd-theme>=1.2.2; extra == "docs" -Provides-Extra: gssauth -Requires-Dist: gssapi; platform_system != "Windows" and extra == "gssauth" -Requires-Dist: sspilib; platform_system == "Windows" and extra == "gssauth" +Requires-Dist: Sphinx ~=5.3.0 ; extra == 'docs' +Requires-Dist: sphinxcontrib-asyncio ~=0.3.0 ; extra == 'docs' +Requires-Dist: sphinx-rtd-theme >=1.2.2 ; extra == 'docs' Provides-Extra: test -Requires-Dist: flake8~=6.1; extra == "test" -Requires-Dist: flake8-pyi~=24.1.0; extra == "test" -Requires-Dist: distro~=1.9.0; extra == "test" -Requires-Dist: mypy~=1.8.0; extra == "test" -Requires-Dist: uvloop>=0.15.3; (platform_system != "Windows" and python_version < "3.14.0") and extra == "test" -Requires-Dist: gssapi; platform_system == "Linux" and extra == "test" -Requires-Dist: k5test; platform_system == "Linux" and extra == "test" -Requires-Dist: sspilib; platform_system == "Windows" and extra == "test" +Requires-Dist: flake8 ~=6.1 ; extra == 'test' +Requires-Dist: uvloop >=0.15.3 ; (platform_system != "Windows" and python_version < "3.12.0") and extra == 'test' asyncpg -- A fast PostgreSQL Database Client Library for Python/asyncio ======================================================================= @@ -58,9 +50,8 @@ framework. You can read more about asyncpg in an introductory `blog post `_. asyncpg requires Python 3.8 or later and is supported for PostgreSQL -versions 9.5 to 17. Other PostgreSQL versions or other databases -implementing the PostgreSQL protocol *may* work, but are not being -actively tested. +versions 9.5 to 16. Older PostgreSQL versions or other databases implementing +the PostgreSQL protocol *may* work, but are not being actively tested. 
Documentation @@ -103,18 +94,11 @@ This enables asyncpg to have easy-to-use support for: Installation ------------ -asyncpg is available on PyPI. When not using GSSAPI/SSPI authentication it -has no dependencies. Use pip to install:: +asyncpg is available on PyPI and has no dependencies. +Use pip to install:: $ pip install asyncpg -If you need GSSAPI/SSPI authentication, use:: - - $ pip install 'asyncpg[gssauth]' - -For more details, please `see the documentation -`_. - Basic Usage ----------- @@ -133,7 +117,8 @@ Basic Usage ) await conn.close() - asyncio.run(run()) + loop = asyncio.get_event_loop() + loop.run_until_complete(run()) License diff --git a/venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/RECORD b/venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/RECORD similarity index 68% rename from venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/RECORD rename to venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/RECORD index 4a1e719..0346fde 100644 --- a/venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/RECORD +++ b/venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/RECORD @@ -1,12 +1,12 @@ -asyncpg-0.30.0.dist-info/AUTHORS,sha256=gIYYcUuWiSZS93lstwQtCT56St1NtKg-fikn8ourw64,130 -asyncpg-0.30.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -asyncpg-0.30.0.dist-info/LICENSE,sha256=2SItc_2sUJkhdAdu-gT0T2-82dVhVafHCS6YdXBCpvY,11466 -asyncpg-0.30.0.dist-info/METADATA,sha256=60MN0tXDvcPtxahUC1vxSP8-dS5hYDtir_YIbY2NCkQ,5010 -asyncpg-0.30.0.dist-info/RECORD,, -asyncpg-0.30.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -asyncpg-0.30.0.dist-info/WHEEL,sha256=OVgtqZzfzIXXtylXP90gxCZ6CKBCwKYyHM8PpMEjN1M,151 -asyncpg-0.30.0.dist-info/top_level.txt,sha256=DdhVhpzCq49mykkHNag6i9zuJx05_tx4CMZymM1F8dU,8 -asyncpg/__init__.py,sha256=bzD31aMekbKR9waMXuAxIYFbmrQ-S1Mttjmru_sSjo8,647 +asyncpg-0.29.0.dist-info/AUTHORS,sha256=gIYYcUuWiSZS93lstwQtCT56St1NtKg-fikn8ourw64,130 +asyncpg-0.29.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +asyncpg-0.29.0.dist-info/LICENSE,sha256=2SItc_2sUJkhdAdu-gT0T2-82dVhVafHCS6YdXBCpvY,11466 +asyncpg-0.29.0.dist-info/METADATA,sha256=_xxlp3Q6M3HJGWcW4cnzhtcswIBd0n7IztyBiZe4Pj0,4356 +asyncpg-0.29.0.dist-info/RECORD,, +asyncpg-0.29.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +asyncpg-0.29.0.dist-info/WHEEL,sha256=JmQLNqDEfvnYMfsIaVeSP3fmUcYDwmF12m3QYW0c7QQ,152 +asyncpg-0.29.0.dist-info/top_level.txt,sha256=DdhVhpzCq49mykkHNag6i9zuJx05_tx4CMZymM1F8dU,8 +asyncpg/__init__.py,sha256=jOW3EoH2dDw1bsrd4qipodmPJsEN6D5genWdyqhB7e8,563 asyncpg/__pycache__/__init__.cpython-312.pyc,, asyncpg/__pycache__/_asyncio_compat.cpython-312.pyc,, asyncpg/__pycache__/_version.cpython-312.pyc,, @@ -23,23 +23,23 @@ asyncpg/__pycache__/serverversion.cpython-312.pyc,, asyncpg/__pycache__/transaction.cpython-312.pyc,, asyncpg/__pycache__/types.cpython-312.pyc,, asyncpg/__pycache__/utils.cpython-312.pyc,, -asyncpg/_asyncio_compat.py,sha256=pXF_aF4o_AqxNql0sPnuGdoe5sSSwQxHpKWF6ShZTbo,2540 -asyncpg/_testbase/__init__.py,sha256=IzMqfgI5gtOxajneoeWyoI4NtmE5sp7S5dXmU0gwwB8,16499 +asyncpg/_asyncio_compat.py,sha256=VgUVf12ztecdiiAMjpS53R_XizOQMKXJBRK-9iCG6cI,2299 +asyncpg/_testbase/__init__.py,sha256=Sj6bhG3a8k5hqp1eFv7I6IkfulcvCXbd1y4tvfz5WQk,16066 asyncpg/_testbase/__pycache__/__init__.cpython-312.pyc,, asyncpg/_testbase/__pycache__/fuzzer.cpython-312.pyc,, asyncpg/_testbase/fuzzer.py,sha256=3Uxdu0YXei-7JZMCuCI3bxKMdnbuossV-KC68GG-AS4,9804 
-asyncpg/_version.py,sha256=MLgciqpbfndZJPsc0fi_WNdVVcsn3Wobpaw0WiaRvEo,641 -asyncpg/cluster.py,sha256=s_HmtiEGJqJ6GQWa6_zmfe11fZ29OpOtMT6Ufcu-g0g,24476 -asyncpg/compat.py,sha256=ebs2IeJw82rY9m0ZCmOYUqry_2nF3zqTi3tsWP5FT2o,2459 -asyncpg/connect_utils.py,sha256=vaVSrnmko33wPjw1X5wlbooF0FTeFlN5b50burZuUWc,36923 -asyncpg/connection.py,sha256=EFlI_1VIkSFzSszsUCCl0eFJITT-5McSuAVmWJyCy-Y,98545 +asyncpg/_version.py,sha256=vGtvByhKF_7cyfQ46GVcrEyZ0o87ts1ofOzkmLgbmFg,576 +asyncpg/cluster.py,sha256=Bna0wFKj9tACcD4Uxjv9eeo5EwAEeJi4t5YVbN434ao,23283 +asyncpg/compat.py,sha256=mQmQgtRgu1clS-Aqiz76g1tHH9qXIRK_xJ7sokx-Y2U,1769 +asyncpg/connect_utils.py,sha256=xZE61cj1Afwm_VyKSDmWHcYwDCwIIk66OXq9MBHyH8M,34979 +asyncpg/connection.py,sha256=f30Jo8XllatqjavvlrkNCcgnIaKnNTQvf32NVJB3ExM,95227 asyncpg/connresource.py,sha256=tBAidNpEhbDvrMOKQbwn3ZNgIVAtsVxARxTnwj5fk-Q,1384 asyncpg/cursor.py,sha256=rKeSIJMW5mUpvsian6a1MLrLoEwbkYTZsmZtEgwFT6s,9160 -asyncpg/exceptions/__init__.py,sha256=FXUYDFQw9gxE3mVz99FmsldYxivLUMtTIhXzu5tZ7Pk,29157 +asyncpg/exceptions/__init__.py,sha256=yZXt3k0lHuF-5czqfBcsMfhxgI5fXAT31hSTn7_fiMM,28826 asyncpg/exceptions/__pycache__/__init__.cpython-312.pyc,, asyncpg/exceptions/__pycache__/_base.cpython-312.pyc,, asyncpg/exceptions/_base.py,sha256=u62xv69n4AHO1xr35FjdgZhYvqdeb_mkQKyp-ip_AyQ,9260 -asyncpg/introspection.py,sha256=biiHj5yQMB8RGch2TiH2TPocN3OO6_GasyijFYxgUOM,9215 +asyncpg/introspection.py,sha256=0oyQXJF6WHpVMq7K_8VIOMVTlGde71cFCA_9NkuDgcQ,8957 asyncpg/pgproto/__init__.pxd,sha256=uUIkKuI6IGnQ5tZXtrjOC_13qjp9MZOwewKlrxKFzPY,213 asyncpg/pgproto/__init__.py,sha256=uUIkKuI6IGnQ5tZXtrjOC_13qjp9MZOwewKlrxKFzPY,213 asyncpg/pgproto/__pycache__/__init__.cpython-312.pyc,, @@ -70,44 +70,42 @@ asyncpg/pgproto/debug.pxd,sha256=SuLG2tteWe3cXnS0czRTTNnnm2QGgG02icp_6G_X9Yw,263 asyncpg/pgproto/frb.pxd,sha256=B2s2dw-SkzfKWeLEWzVLTkjjYYW53pazPcVNH3vPxAk,1212 asyncpg/pgproto/frb.pyx,sha256=7bipWSBXebweq3JBFlCvSwa03fIZGLkKPqWbJ8VFWFI,409 asyncpg/pgproto/hton.pxd,sha256=Swx5ry82iWYO9Ok4fRa_b7cLSrIPyxNYlyXm-ncYweo,953 -asyncpg/pgproto/pgproto.cpython-312-x86_64-linux-gnu.so,sha256=pq0nrGmFE6y2VQgcWlKcFODTl9h9We00i1xQT55RYdE,3131904 +asyncpg/pgproto/pgproto.cpython-312-x86_64-linux-gnu.so,sha256=niR6XwwgUbpcrq6BrfQXz0NgIq2fn9xyMgzPq2yrACY,2849672 asyncpg/pgproto/pgproto.pxd,sha256=QUUxWiHKdKfFxdDT0czSvOFsA4b59MJRR6WlUbJFgPg,430 -asyncpg/pgproto/pgproto.pyi,sha256=W5nuATmpHFfhRF7Hnjt5Vuvr1lBJ-xkJ8nIvEYE1N1E,275 asyncpg/pgproto/pgproto.pyx,sha256=bK75qfRQlofzO8dDzJ2mHUE0wLeXSsc5SLeAGvyXSeE,1249 asyncpg/pgproto/tohex.pxd,sha256=fQVaxBu6dBw2P_ROR8MSPVDlVep0McKi69fdQBLhifI,361 asyncpg/pgproto/types.py,sha256=wzJgyDJ63Eu2TJym0EhhEr6-D9iIV3cdlzab11sgRS0,13014 asyncpg/pgproto/uuid.pyx,sha256=PrQIvQKJJItsYFpwZtDCcR9Z_DIbEi_MUt6tQjnVaYI,9943 -asyncpg/pool.py,sha256=oZh4JC01xizpa3MQSJ4mcOW71Nb_jYWluY_Dm2549fg,41296 -asyncpg/prepared_stmt.py,sha256=YfOSeQavN1c1o5SajD9ylTCLHpNV5plGBEw9ku8KyBk,9752 -asyncpg/protocol/__init__.py,sha256=c-b07Si_DGN9rqiCUAmR9RaCUCy_LiJ4lqHCb0yMBRI,340 +asyncpg/pool.py,sha256=VilAdZmMrodLmu7xeYk2ExoJRFUzk4ORT4kdxMMVE64,38168 +asyncpg/prepared_stmt.py,sha256=jay1C7UISpmXmotWkUXgdRidgtSdvmaCxlGZ6xlNGEM,8992 +asyncpg/protocol/__init__.py,sha256=6mxFfJskIjmKjSxxOybsuHY68wa2BlqY3z0VWG1BT4g,304 asyncpg/protocol/__pycache__/__init__.cpython-312.pyc,, asyncpg/protocol/codecs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 asyncpg/protocol/codecs/__pycache__/__init__.cpython-312.pyc,, asyncpg/protocol/codecs/array.pyx,sha256=1S_6xdgxllG8_1Lb68XdPkH1QgF63gAAmjh091Q7Dyk,29486 
asyncpg/protocol/codecs/base.pxd,sha256=NfDsh60UZX-gVThlj8rzGmLRqMbXAYqSJsAwKTcZ1Cg,6224 -asyncpg/protocol/codecs/base.pyx,sha256=C1SPRtSdYbshnvZOHJVj9Gp30VSj5z6nQRBUoPgj2IU,33464 +asyncpg/protocol/codecs/base.pyx,sha256=V8-mmRPV3eFn2jUmdFILFNNsjUaFH8_x4S5IJ7OjtCM,33475 asyncpg/protocol/codecs/pgproto.pyx,sha256=5PDv1JT_nXbDbHtYVrGCcZN3CxzQdgwqlXT8GpyMamk,17175 asyncpg/protocol/codecs/range.pyx,sha256=-P-acyY2e5TlEtjqbkeH28PYk-DGLxqbmzKDFGL5BbI,6359 asyncpg/protocol/codecs/record.pyx,sha256=l17HPv3ZeZzvDMXmh-FTdOQ0LxqaQsge_4hlmnGaf6s,2362 asyncpg/protocol/codecs/textutils.pyx,sha256=UmTt1Zs5N2oLVDMTSlSe1zAFt5q4_4akbXZoS6HSPO8,2011 asyncpg/protocol/consts.pxi,sha256=VT7NLBpLgPUvcUbPflrX84I79JZiFg4zFzBK28nCRZo,381 -asyncpg/protocol/coreproto.pxd,sha256=77yJqaBMGWHmxyihZIFfyVgfzICF9jLwKSvtuCoE8rM,6215 -asyncpg/protocol/coreproto.pyx,sha256=sMvXqxnppthc_LJYibMAJts0IfEPgYVs4nwXmY3v-IY,41037 +asyncpg/protocol/coreproto.pxd,sha256=ozuSON07EOnWmJI4v3gtTjD18APpZfk1WfnoWLZ53as,6149 +asyncpg/protocol/coreproto.pyx,sha256=UprN-4_PaJFN82fCCA2tE0t_i_dShyTdtsbymOYGnfE,38015 asyncpg/protocol/cpythonx.pxd,sha256=VX71g4PiwXWGTY-BzBPm7S-AiX5ySRrY40qAggH-BIA,613 asyncpg/protocol/encodings.pyx,sha256=QegnSON5y-a0aQFD9zFbhAzhYTbKYj-vl3VGiyqIU3I,1644 asyncpg/protocol/pgtypes.pxi,sha256=w8Mb6N7Z58gxPYWZkj5lwk0PRW7oBTIf9fo0MvPzm4c,6924 asyncpg/protocol/prepared_stmt.pxd,sha256=GhHzJgQMehpWg0i3XSmbkJH6G5nnnmdNCf2EU_gXhDY,1115 -asyncpg/protocol/prepared_stmt.pyx,sha256=wfo57hwGrghO3-0o7OxABV2heL2Fb0teENUZNmMj6aI,13058 -asyncpg/protocol/protocol.cpython-312-x86_64-linux-gnu.so,sha256=mIYQ9YlP2JuiyV4448aKNqyUZgEVjBiVXlsSzKxM41k,9439904 -asyncpg/protocol/protocol.pxd,sha256=yOVFbkD7mA8VK5IGIJ4dGTyvHKWZTQOFfCFNfdeUdK8,1927 -asyncpg/protocol/protocol.pyi,sha256=Dg0-ZTvLCXc3g3aCvEHvSKVzRp63Q-9iceiqTSQMr2g,9732 -asyncpg/protocol/protocol.pyx,sha256=V99Dm45e8vgV3qSa-jmS2YypntSymrznLtyxoveU7jI,34850 +asyncpg/protocol/prepared_stmt.pyx,sha256=fbhQpVuDFEQ1GOw--sZdrD-iOkTvU5JXFOlxKpTe36c,13052 +asyncpg/protocol/protocol.cpython-312-x86_64-linux-gnu.so,sha256=92ZeyBjeYWIwDLtGqPCzTYw-94N8V-3BiTguwL6iNu4,8713328 +asyncpg/protocol/protocol.pxd,sha256=0Y1NFvnR3N0rmvBMUAocYi4U9RbAyg-6qkoqOgy53Fg,1950 +asyncpg/protocol/protocol.pyx,sha256=2EN1Aq45eR3pGQjQciafqFQzi4ilLqDLP2LpLWM3wVE,34824 asyncpg/protocol/record/__init__.pxd,sha256=KJyCfN_ST2yyEDnUS3PfipeIEYmY8CVTeOwFPcUcVNc,495 asyncpg/protocol/scram.pxd,sha256=t_nkicIS_4AzxyHoq-aYUNrFNv8O0W7E090HfMAIuno,1299 asyncpg/protocol/scram.pyx,sha256=nT_Rawg6h3OrRWDBwWN7lju5_hnOmXpwWFWVrb3l_dQ,14594 asyncpg/protocol/settings.pxd,sha256=8DTwZ5mi0aAUJRWE6SUIRDhWFGFis1mj8lcA8hNFTL0,1066 -asyncpg/protocol/settings.pyx,sha256=yICjZF5FXwfmdxQBg-1qO0XbpLvZL11-c3aMbiwM7oo,3777 -asyncpg/serverversion.py,sha256=WwlqBJkXZHvvnFluubCjPoaX_7OqjR8QgiOe90w6C9E,2133 +asyncpg/protocol/settings.pyx,sha256=Z_GsQoRKzqBeztO8AJMTbv_xpT-mk8LgLfvQ2l-W7cY,3795 +asyncpg/serverversion.py,sha256=xdxEy45U9QGhpfTp3c4g6jSJ3NEb4lsDcTe3qvFNDQg,1790 asyncpg/transaction.py,sha256=uAJok6Shx7-Kdt5l4NX-GJtLxVJSPXTOJUryGdbIVG8,8497 -asyncpg/types.py,sha256=2x-nAVdfk41PA83DyYcWxkUNXsiGLotGkMX0gVpuFoY,5520 -asyncpg/utils.py,sha256=Y0vATexoIHFkpWURlqnlUZUacc4F1iZJ9rWJ3654OnM,1495 +asyncpg/types.py,sha256=msRSL9mXKPWjVXMi0yrk5vhVwQp9Sdwyfcp_zz8ZkNU,4653 +asyncpg/utils.py,sha256=NWmcsmYORwc4rjJvwrUqJrv1lP2Qq5c-v139LBv2ZVQ,1367 diff --git a/venv/lib/python3.12/site-packages/bcrypt-5.0.0.dist-info/REQUESTED b/venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/REQUESTED similarity index 100% rename from 
venv/lib/python3.12/site-packages/bcrypt-5.0.0.dist-info/REQUESTED rename to venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/REQUESTED diff --git a/venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/WHEEL b/venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/WHEEL similarity index 78% rename from venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/WHEEL rename to venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/WHEEL index 057fef6..c5825c5 100644 --- a/venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/WHEEL +++ b/venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/WHEEL @@ -1,5 +1,5 @@ Wheel-Version: 1.0 -Generator: setuptools (75.2.0) +Generator: bdist_wheel (0.41.3) Root-Is-Purelib: false Tag: cp312-cp312-manylinux_2_17_x86_64 Tag: cp312-cp312-manylinux2014_x86_64 diff --git a/venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/top_level.txt b/venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/top_level.txt similarity index 100% rename from venv/lib/python3.12/site-packages/asyncpg-0.30.0.dist-info/top_level.txt rename to venv/lib/python3.12/site-packages/asyncpg-0.29.0.dist-info/top_level.txt diff --git a/venv/lib/python3.12/site-packages/asyncpg/__init__.py b/venv/lib/python3.12/site-packages/asyncpg/__init__.py index e8811a9..e8cd11e 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/__init__.py +++ b/venv/lib/python3.12/site-packages/asyncpg/__init__.py @@ -4,7 +4,6 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 -from __future__ import annotations from .connection import connect, Connection # NOQA from .exceptions import * # NOQA @@ -15,10 +14,6 @@ from .types import * # NOQA from ._version import __version__ # NOQA -from . import exceptions - -__all__: tuple[str, ...] = ( - 'connect', 'create_pool', 'Pool', 'Record', 'Connection' -) +__all__ = ('connect', 'create_pool', 'Pool', 'Record', 'Connection') __all__ += exceptions.__all__ # NOQA diff --git a/venv/lib/python3.12/site-packages/asyncpg/_asyncio_compat.py b/venv/lib/python3.12/site-packages/asyncpg/_asyncio_compat.py index a211d0a..ad7dfd8 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/_asyncio_compat.py +++ b/venv/lib/python3.12/site-packages/asyncpg/_asyncio_compat.py @@ -4,25 +4,18 @@ # # SPDX-License-Identifier: PSF-2.0 -from __future__ import annotations import asyncio import functools import sys -import typing - -if typing.TYPE_CHECKING: - from . import compat if sys.version_info < (3, 11): from async_timeout import timeout as timeout_ctx else: from asyncio import timeout as timeout_ctx -_T = typing.TypeVar('_T') - -async def wait_for(fut: compat.Awaitable[_T], timeout: float | None) -> _T: +async def wait_for(fut, timeout): """Wait for the single Future or coroutine to complete, with timeout. Coroutine will be wrapped in Task. 
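The _asyncio_compat hunk above drops the 0.30 type annotations but keeps the same version switch: asyncio.timeout on Python 3.11+, the third-party async_timeout package otherwise. A small sketch of that selection pattern used from application code; the fetch_with_deadline helper is an illustration, not part of asyncpg::

    import sys

    if sys.version_info < (3, 11):
        from async_timeout import timeout as timeout_ctx
    else:
        from asyncio import timeout as timeout_ctx

    async def fetch_with_deadline(coro, deadline: float):
        # Cancel the awaited coroutine if it has not finished
        # within `deadline` seconds.
        async with timeout_ctx(deadline):
            return await coro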
@@ -72,7 +65,7 @@ async def wait_for(fut: compat.Awaitable[_T], timeout: float | None) -> _T: return await fut -async def _cancel_and_wait(fut: asyncio.Future[_T]) -> None: +async def _cancel_and_wait(fut): """Cancel the *fut* future or task and wait until it completes.""" loop = asyncio.get_running_loop() @@ -89,6 +82,6 @@ async def _cancel_and_wait(fut: asyncio.Future[_T]) -> None: fut.remove_done_callback(cb) -def _release_waiter(waiter: asyncio.Future[typing.Any], *args: object) -> None: +def _release_waiter(waiter, *args): if not waiter.done(): waiter.set_result(None) diff --git a/venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py b/venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py index 95775e1..7aca834 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py +++ b/venv/lib/python3.12/site-packages/asyncpg/_testbase/__init__.py @@ -117,22 +117,10 @@ class TestCase(unittest.TestCase, metaclass=TestCaseMeta): self.__unhandled_exceptions = [] def tearDown(self): - excs = [] - for exc in self.__unhandled_exceptions: - if isinstance(exc, ConnectionResetError): - texc = traceback.TracebackException.from_exception( - exc, lookup_lines=False) - if texc.stack[-1].name == "_call_connection_lost": - # On Windows calling socket.shutdown may raise - # ConnectionResetError, which happens in the - # finally block of _call_connection_lost. - continue - excs.append(exc) - - if excs: + if self.__unhandled_exceptions: formatted = [] - for i, context in enumerate(excs): + for i, context in enumerate(self.__unhandled_exceptions): formatted.append(self._format_loop_exception(context, i + 1)) self.fail( @@ -226,6 +214,13 @@ def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None): return cluster +def _start_cluster(ClusterCls, cluster_kwargs, server_settings, + initdb_options=None): + cluster = _init_cluster(ClusterCls, cluster_kwargs, initdb_options) + cluster.start(port='dynamic', server_settings=server_settings) + return cluster + + def _get_initdb_options(initdb_options=None): if not initdb_options: initdb_options = {} @@ -249,12 +244,8 @@ def _init_default_cluster(initdb_options=None): _default_cluster = pg_cluster.RunningCluster() else: _default_cluster = _init_cluster( - pg_cluster.TempCluster, - cluster_kwargs={ - "data_dir_suffix": ".apgtest", - }, - initdb_options=_get_initdb_options(initdb_options), - ) + pg_cluster.TempCluster, cluster_kwargs={}, + initdb_options=_get_initdb_options(initdb_options)) return _default_cluster @@ -271,7 +262,6 @@ def create_pool(dsn=None, *, max_size=10, max_queries=50000, max_inactive_connection_lifetime=60.0, - connect=None, setup=None, init=None, loop=None, @@ -281,18 +271,12 @@ def create_pool(dsn=None, *, **connect_kwargs): return pool_class( dsn, - min_size=min_size, - max_size=max_size, - max_queries=max_queries, - loop=loop, - connect=connect, - setup=setup, - init=init, + min_size=min_size, max_size=max_size, + max_queries=max_queries, loop=loop, setup=setup, init=init, max_inactive_connection_lifetime=max_inactive_connection_lifetime, connection_class=connection_class, record_class=record_class, - **connect_kwargs, - ) + **connect_kwargs) class ClusterTestCase(TestCase): diff --git a/venv/lib/python3.12/site-packages/asyncpg/_version.py b/venv/lib/python3.12/site-packages/asyncpg/_version.py index 245eee7..64da11d 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/_version.py +++ b/venv/lib/python3.12/site-packages/asyncpg/_version.py @@ -10,8 +10,4 @@ # supported platforms, publish the packages 
on PyPI, merge the PR # to the target branch, create a Git tag pointing to the commit. -from __future__ import annotations - -import typing - -__version__: typing.Final = '0.30.0' +__version__ = '0.29.0' diff --git a/venv/lib/python3.12/site-packages/asyncpg/cluster.py b/venv/lib/python3.12/site-packages/asyncpg/cluster.py index 606c2ea..4467cc2 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/cluster.py +++ b/venv/lib/python3.12/site-packages/asyncpg/cluster.py @@ -9,11 +9,9 @@ import asyncio import os import os.path import platform -import random import re import shutil import socket -import string import subprocess import sys import tempfile @@ -47,29 +45,6 @@ def find_available_port(): sock.close() -def _world_readable_mkdtemp(suffix=None, prefix=None, dir=None): - name = "".join(random.choices(string.ascii_lowercase, k=8)) - if dir is None: - dir = tempfile.gettempdir() - if prefix is None: - prefix = tempfile.gettempprefix() - if suffix is None: - suffix = "" - fn = os.path.join(dir, prefix + name + suffix) - os.mkdir(fn, 0o755) - return fn - - -def _mkdtemp(suffix=None, prefix=None, dir=None): - if _system == 'Windows' and os.environ.get("GITHUB_ACTIONS"): - # Due to mitigations introduced in python/cpython#118486 - # when Python runs in a session created via an SSH connection - # tempfile.mkdtemp creates directories that are not accessible. - return _world_readable_mkdtemp(suffix, prefix, dir) - else: - return tempfile.mkdtemp(suffix, prefix, dir) - - class ClusterError(Exception): pass @@ -147,13 +122,9 @@ class Cluster: else: extra_args = [] - os.makedirs(self._data_dir, exist_ok=True) process = subprocess.run( [self._pg_ctl, 'init', '-D', self._data_dir] + extra_args, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - cwd=self._data_dir, - ) + stdout=subprocess.PIPE, stderr=subprocess.STDOUT) output = process.stdout @@ -228,10 +199,7 @@ class Cluster: process = subprocess.run( [self._pg_ctl, 'start', '-D', self._data_dir, '-o', ' '.join(extra_args)], - stdout=stdout, - stderr=subprocess.STDOUT, - cwd=self._data_dir, - ) + stdout=stdout, stderr=subprocess.STDOUT) if process.returncode != 0: if process.stderr: @@ -250,10 +218,7 @@ class Cluster: self._daemon_process = \ subprocess.Popen( [self._postgres, '-D', self._data_dir, *extra_args], - stdout=stdout, - stderr=subprocess.STDOUT, - cwd=self._data_dir, - ) + stdout=stdout, stderr=subprocess.STDOUT) self._daemon_pid = self._daemon_process.pid @@ -267,10 +232,7 @@ class Cluster: process = subprocess.run( [self._pg_ctl, 'reload', '-D', self._data_dir], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - cwd=self._data_dir, - ) + stdout=subprocess.PIPE, stderr=subprocess.PIPE) stderr = process.stderr @@ -283,10 +245,7 @@ class Cluster: process = subprocess.run( [self._pg_ctl, 'stop', '-D', self._data_dir, '-t', str(wait), '-m', 'fast'], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - cwd=self._data_dir, - ) + stdout=subprocess.PIPE, stderr=subprocess.PIPE) stderr = process.stderr @@ -624,9 +583,9 @@ class TempCluster(Cluster): def __init__(self, *, data_dir_suffix=None, data_dir_prefix=None, data_dir_parent=None, pg_config_path=None): - self._data_dir = _mkdtemp(suffix=data_dir_suffix, - prefix=data_dir_prefix, - dir=data_dir_parent) + self._data_dir = tempfile.mkdtemp(suffix=data_dir_suffix, + prefix=data_dir_prefix, + dir=data_dir_parent) super().__init__(self._data_dir, pg_config_path=pg_config_path) diff --git a/venv/lib/python3.12/site-packages/asyncpg/compat.py 
b/venv/lib/python3.12/site-packages/asyncpg/compat.py index 57eec65..3eec9eb 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/compat.py +++ b/venv/lib/python3.12/site-packages/asyncpg/compat.py @@ -4,26 +4,22 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 -from __future__ import annotations -import enum import pathlib import platform import typing import sys -if typing.TYPE_CHECKING: - import asyncio -SYSTEM: typing.Final = platform.uname().system +SYSTEM = platform.uname().system -if sys.platform == 'win32': +if SYSTEM == 'Windows': import ctypes.wintypes - CSIDL_APPDATA: typing.Final = 0x001a + CSIDL_APPDATA = 0x001a - def get_pg_home_directory() -> pathlib.Path | None: + def get_pg_home_directory() -> typing.Optional[pathlib.Path]: # We cannot simply use expanduser() as that returns the user's # home directory, whereas Postgres stores its config in # %AppData% on Windows. @@ -35,14 +31,14 @@ if sys.platform == 'win32': return pathlib.Path(buf.value) / 'postgresql' else: - def get_pg_home_directory() -> pathlib.Path | None: + def get_pg_home_directory() -> typing.Optional[pathlib.Path]: try: return pathlib.Path.home() except (RuntimeError, KeyError): return None -async def wait_closed(stream: asyncio.StreamWriter) -> None: +async def wait_closed(stream): # Not all asyncio versions have StreamWriter.wait_closed(). if hasattr(stream, 'wait_closed'): try: @@ -53,13 +49,6 @@ async def wait_closed(stream: asyncio.StreamWriter) -> None: pass -if sys.version_info < (3, 12): - def markcoroutinefunction(c): # type: ignore - pass -else: - from inspect import markcoroutinefunction # noqa: F401 - - if sys.version_info < (3, 12): from ._asyncio_compat import wait_for as wait_for # noqa: F401 else: @@ -70,19 +59,3 @@ if sys.version_info < (3, 11): from ._asyncio_compat import timeout_ctx as timeout # noqa: F401 else: from asyncio import timeout as timeout # noqa: F401 - -if sys.version_info < (3, 9): - from typing import ( # noqa: F401 - Awaitable as Awaitable, - ) -else: - from collections.abc import ( # noqa: F401 - Awaitable as Awaitable, - ) - -if sys.version_info < (3, 11): - class StrEnum(str, enum.Enum): - __str__ = str.__str__ - __repr__ = enum.Enum.__repr__ -else: - from enum import StrEnum as StrEnum # noqa: F401 diff --git a/venv/lib/python3.12/site-packages/asyncpg/connect_utils.py b/venv/lib/python3.12/site-packages/asyncpg/connect_utils.py index 4890d00..414231f 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/connect_utils.py +++ b/venv/lib/python3.12/site-packages/asyncpg/connect_utils.py @@ -45,11 +45,6 @@ class SSLMode(enum.IntEnum): return getattr(cls, sslmode.replace('-', '_')) -class SSLNegotiation(compat.StrEnum): - postgres = "postgres" - direct = "direct" - - _ConnectionParameters = collections.namedtuple( 'ConnectionParameters', [ @@ -58,11 +53,9 @@ _ConnectionParameters = collections.namedtuple( 'database', 'ssl', 'sslmode', - 'ssl_negotiation', + 'direct_tls', 'server_settings', 'target_session_attrs', - 'krbsrvname', - 'gsslib', ]) @@ -268,13 +261,12 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: def _parse_connect_dsn_and_args(*, dsn, host, port, user, password, passfile, database, ssl, direct_tls, server_settings, - target_session_attrs, krbsrvname, gsslib): + target_session_attrs): # `auth_hosts` is the version of host information for the purposes # of reading the pgpass file. 
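In this 0.29 code path there is no sslnegotiation/SSLNegotiation handling; direct TLS is requested with the plain boolean direct_tls argument of connect(). A hedged sketch of how a caller passes those options, with the DSN and certificate setup as placeholders::

    import ssl

    import asyncpg

    async def connect_direct_tls(dsn: str) -> asyncpg.Connection:
        ctx = ssl.create_default_context()
        # direct_tls=True skips the STARTTLS round-trip and opens a TLS
        # connection immediately, matching the branch shown later in this file.
        return await asyncpg.connect(dsn, ssl=ctx, direct_tls=True)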
auth_hosts = None sslcert = sslkey = sslrootcert = sslcrl = sslpassword = None ssl_min_protocol_version = ssl_max_protocol_version = None - sslnegotiation = None if dsn: parsed = urllib.parse.urlparse(dsn) @@ -368,9 +360,6 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if 'sslrootcert' in query: sslrootcert = query.pop('sslrootcert') - if 'sslnegotiation' in query: - sslnegotiation = query.pop('sslnegotiation') - if 'sslcrl' in query: sslcrl = query.pop('sslcrl') @@ -394,16 +383,6 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if target_session_attrs is None: target_session_attrs = dsn_target_session_attrs - if 'krbsrvname' in query: - val = query.pop('krbsrvname') - if krbsrvname is None: - krbsrvname = val - - if 'gsslib' in query: - val = query.pop('gsslib') - if gsslib is None: - gsslib = val - if query: if server_settings is None: server_settings = query @@ -512,36 +491,13 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if ssl is None and have_tcp_addrs: ssl = 'prefer' - if direct_tls is not None: - sslneg = ( - SSLNegotiation.direct if direct_tls else SSLNegotiation.postgres - ) - else: - if sslnegotiation is None: - sslnegotiation = os.environ.get("PGSSLNEGOTIATION") - - if sslnegotiation is not None: - try: - sslneg = SSLNegotiation(sslnegotiation) - except ValueError: - modes = ', '.join( - m.name.replace('_', '-') - for m in SSLNegotiation - ) - raise exceptions.ClientConfigurationError( - f'`sslnegotiation` parameter must be one of: {modes}' - ) from None - else: - sslneg = SSLNegotiation.postgres - if isinstance(ssl, (str, SSLMode)): try: sslmode = SSLMode.parse(ssl) except AttributeError: modes = ', '.join(m.name.replace('_', '-') for m in SSLMode) raise exceptions.ClientConfigurationError( - '`sslmode` parameter must be one of: {}'.format(modes) - ) from None + '`sslmode` parameter must be one of: {}'.format(modes)) # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html if sslmode < SSLMode.allow: @@ -694,24 +650,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, ) ) from None - if krbsrvname is None: - krbsrvname = os.getenv('PGKRBSRVNAME') - - if gsslib is None: - gsslib = os.getenv('PGGSSLIB') - if gsslib is None: - gsslib = 'sspi' if _system == 'Windows' else 'gssapi' - if gsslib not in {'gssapi', 'sspi'}: - raise exceptions.ClientConfigurationError( - "gsslib parameter must be either 'gssapi' or 'sspi'" - ", got {!r}".format(gsslib)) - params = _ConnectionParameters( user=user, password=password, database=database, ssl=ssl, - sslmode=sslmode, ssl_negotiation=sslneg, + sslmode=sslmode, direct_tls=direct_tls, server_settings=server_settings, - target_session_attrs=target_session_attrs, - krbsrvname=krbsrvname, gsslib=gsslib) + target_session_attrs=target_session_attrs) return addrs, params @@ -722,7 +665,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, max_cached_statement_lifetime, max_cacheable_statement_size, ssl, direct_tls, server_settings, - target_session_attrs, krbsrvname, gsslib): + target_session_attrs): local_vars = locals() for var_name in {'max_cacheable_statement_size', 'max_cached_statement_lifetime', @@ -751,8 +694,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, password=password, passfile=passfile, ssl=ssl, direct_tls=direct_tls, database=database, server_settings=server_settings, - target_session_attrs=target_session_attrs, - krbsrvname=krbsrvname, gsslib=gsslib) + target_session_attrs=target_session_attrs) config = 
_ClientConfiguration( command_timeout=command_timeout, @@ -914,9 +856,9 @@ async def __connect_addr( # UNIX socket connector = loop.create_unix_connection(proto_factory, addr) - elif params.ssl and params.ssl_negotiation is SSLNegotiation.direct: - # if ssl and ssl_negotiation is `direct`, skip STARTTLS and perform - # direct SSL connection + elif params.ssl and params.direct_tls: + # if ssl and direct_tls are given, skip STARTTLS and perform direct + # SSL connection connector = loop.create_connection( proto_factory, *addr, ssl=params.ssl ) diff --git a/venv/lib/python3.12/site-packages/asyncpg/connection.py b/venv/lib/python3.12/site-packages/asyncpg/connection.py index 3a86466..0367e36 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/connection.py +++ b/venv/lib/python3.12/site-packages/asyncpg/connection.py @@ -231,8 +231,9 @@ class Connection(metaclass=ConnectionMeta): :param callable callback: A callable or a coroutine function receiving one argument: - **record**, a LoggedQuery containing `query`, `args`, `timeout`, - `elapsed`, `exception`, `conn_addr`, and `conn_params`. + **record**: a LoggedQuery containing `query`, `args`, `timeout`, + `elapsed`, `exception`, `conn_addr`, and + `conn_params`. .. versionadded:: 0.29.0 """ @@ -756,44 +757,6 @@ class Connection(metaclass=ConnectionMeta): return None return data[0] - async def fetchmany( - self, query, args, *, timeout: float=None, record_class=None - ): - """Run a query for each sequence of arguments in *args* - and return the results as a list of :class:`Record`. - - :param query: - Query to execute. - :param args: - An iterable containing sequences of arguments for the query. - :param float timeout: - Optional timeout value in seconds. - :param type record_class: - If specified, the class to use for records returned by this method. - Must be a subclass of :class:`~asyncpg.Record`. If not specified, - a per-connection *record_class* is used. - - :return list: - A list of :class:`~asyncpg.Record` instances. If specified, the - actual type of list elements would be *record_class*. - - Example: - - .. code-block:: pycon - - >>> rows = await con.fetchmany(''' - ... INSERT INTO mytab (a, b) VALUES ($1, $2) RETURNING a; - ... ''', [('x', 1), ('y', 2), ('z', 3)]) - >>> rows - [, , ] - - .. versionadded:: 0.30.0 - """ - self._check_open() - return await self._executemany( - query, args, timeout, return_rows=True, record_class=record_class - ) - async def copy_from_table(self, table_name, *, output, columns=None, schema_name=None, timeout=None, format=None, oids=None, delimiter=None, @@ -837,7 +800,7 @@ class Connection(metaclass=ConnectionMeta): ... output='file.csv', format='csv') ... print(result) ... - >>> asyncio.run(run()) + >>> asyncio.get_event_loop().run_until_complete(run()) 'COPY 100' .. _`COPY statement documentation`: @@ -906,7 +869,7 @@ class Connection(metaclass=ConnectionMeta): ... output='file.csv', format='csv') ... print(result) ... - >>> asyncio.run(run()) + >>> asyncio.get_event_loop().run_until_complete(run()) 'COPY 10' .. _`COPY statement documentation`: @@ -982,7 +945,7 @@ class Connection(metaclass=ConnectionMeta): ... 'mytable', source='datafile.tbl') ... print(result) ... - >>> asyncio.run(run()) + >>> asyncio.get_event_loop().run_until_complete(run()) 'COPY 140000' .. _`COPY statement documentation`: @@ -1064,7 +1027,7 @@ class Connection(metaclass=ConnectionMeta): ... (2, 'ham', 'spam')]) ... print(result) ... 
- >>> asyncio.run(run()) + >>> asyncio.get_event_loop().run_until_complete(run()) 'COPY 2' Asynchronous record iterables are also supported: @@ -1082,7 +1045,7 @@ class Connection(metaclass=ConnectionMeta): ... 'mytable', records=record_gen(100)) ... print(result) ... - >>> asyncio.run(run()) + >>> asyncio.get_event_loop().run_until_complete(run()) 'COPY 100' .. versionadded:: 0.11.0 @@ -1342,7 +1305,7 @@ class Connection(metaclass=ConnectionMeta): ... print(result) ... print(datetime.datetime(2002, 1, 1) + result) ... - >>> asyncio.run(run()) + >>> asyncio.get_event_loop().run_until_complete(run()) relativedelta(years=+2, months=+3, days=+1) 2004-04-02 00:00:00 @@ -1515,10 +1478,11 @@ class Connection(metaclass=ConnectionMeta): self._abort() self._cleanup() - async def _reset(self): + async def reset(self, *, timeout=None): self._check_open() self._listeners.clear() self._log_listeners.clear() + reset_query = self._get_reset_query() if self._protocol.is_in_transaction() or self._top_xact is not None: if self._top_xact is None or not self._top_xact._managed: @@ -1530,36 +1494,10 @@ class Connection(metaclass=ConnectionMeta): }) self._top_xact = None - await self.execute("ROLLBACK") + reset_query = 'ROLLBACK;\n' + reset_query - async def reset(self, *, timeout=None): - """Reset the connection state. - - Calling this will reset the connection session state to a state - resembling that of a newly obtained connection. Namely, an open - transaction (if any) is rolled back, open cursors are closed, - all `LISTEN `_ - registrations are removed, all session configuration - variables are reset to their default values, and all advisory locks - are released. - - Note that the above describes the default query returned by - :meth:`Connection.get_reset_query`. If one overloads the method - by subclassing ``Connection``, then this method will do whatever - the overloaded method returns, except open transactions are always - terminated and any callbacks registered by - :meth:`Connection.add_listener` or :meth:`Connection.add_log_listener` - are removed. - - :param float timeout: - A timeout for resetting the connection. If not specified, defaults - to no timeout. - """ - async with compat.timeout(timeout): - await self._reset() - reset_query = self.get_reset_query() - if reset_query: - await self.execute(reset_query) + if reset_query: + await self.execute(reset_query, timeout=timeout) def _abort(self): # Put the connection into the aborted state. @@ -1720,15 +1658,7 @@ class Connection(metaclass=ConnectionMeta): con_ref = self._proxy return con_ref - def get_reset_query(self): - """Return the query sent to server on connection release. - - The query returned by this method is used by :meth:`Connection.reset`, - which is, in turn, used by :class:`~asyncpg.pool.Pool` before making - the connection available to another acquirer. - - .. versionadded:: 0.30.0 - """ + def _get_reset_query(self): if self._reset_query is not None: return self._reset_query @@ -1842,7 +1772,7 @@ class Connection(metaclass=ConnectionMeta): ... await con.execute('LOCK TABLE tbl') ... await change_type(con) ... - >>> asyncio.run(run()) + >>> asyncio.get_event_loop().run_until_complete(run()) .. versionadded:: 0.14.0 """ @@ -1879,8 +1809,9 @@ class Connection(metaclass=ConnectionMeta): :param callable callback: A callable or a coroutine function receiving one argument: - **record**, a LoggedQuery containing `query`, `args`, `timeout`, - `elapsed`, `exception`, `conn_addr`, and `conn_params`. 
+ **record**: a LoggedQuery containing `query`, `args`, `timeout`, + `elapsed`, `exception`, `conn_addr`, and + `conn_params`. Example: @@ -1967,27 +1898,17 @@ class Connection(metaclass=ConnectionMeta): ) return result, stmt - async def _executemany( - self, - query, - args, - timeout, - return_rows=False, - record_class=None, - ): + async def _executemany(self, query, args, timeout): executor = lambda stmt, timeout: self._protocol.bind_execute_many( state=stmt, args=args, portal_name='', timeout=timeout, - return_rows=return_rows, ) timeout = self._protocol._get_timeout(timeout) with self._stmt_exclusive_section: with self._time_and_log(query, args, timeout): - result, _ = await self._do_execute( - query, executor, timeout, record_class=record_class - ) + result, _ = await self._do_execute(query, executor, timeout) return result async def _do_execute( @@ -2082,13 +2003,11 @@ async def connect(dsn=None, *, max_cacheable_statement_size=1024 * 15, command_timeout=None, ssl=None, - direct_tls=None, + direct_tls=False, connection_class=Connection, record_class=protocol.Record, server_settings=None, - target_session_attrs=None, - krbsrvname=None, - gsslib=None): + target_session_attrs=None): r"""A coroutine to establish a connection to a PostgreSQL server. The connection parameters may be specified either as a connection @@ -2113,7 +2032,7 @@ async def connect(dsn=None, *, .. note:: The URI must be *valid*, which means that all components must - be properly quoted with :py:func:`urllib.parse.quote_plus`, and + be properly quoted with :py:func:`urllib.parse.quote`, and any literal IPv6 addresses must be enclosed in square brackets. For example: @@ -2316,14 +2235,6 @@ async def connect(dsn=None, *, or the value of the ``PGTARGETSESSIONATTRS`` environment variable, or ``"any"`` if neither is specified. - :param str krbsrvname: - Kerberos service name to use when authenticating with GSSAPI. This - must match the server configuration. Defaults to 'postgres'. - - :param str gsslib: - GSS library to use for GSSAPI/SSPI authentication. Can be 'gssapi' - or 'sspi'. Defaults to 'sspi' on Windows and 'gssapi' otherwise. - :return: A :class:`~asyncpg.connection.Connection` instance. Example: @@ -2337,7 +2248,7 @@ async def connect(dsn=None, *, ... types = await con.fetch('SELECT * FROM pg_type') ... print(types) ... 
- >>> asyncio.run(run()) + >>> asyncio.get_event_loop().run_until_complete(run()) [ bool: +def is_scalar_type(typeinfo) -> bool: return ( typeinfo['kind'] in SCALAR_TYPE_KINDS and not typeinfo['elemtype'] ) -def is_domain_type(typeinfo: protocol.Record) -> bool: - return typeinfo['kind'] == b'd' # type: ignore[no-any-return] +def is_domain_type(typeinfo) -> bool: + return typeinfo['kind'] == b'd' -def is_composite_type(typeinfo: protocol.Record) -> bool: - return typeinfo['kind'] == b'c' # type: ignore[no-any-return] +def is_composite_type(typeinfo) -> bool: + return typeinfo['kind'] == b'c' diff --git a/venv/lib/python3.12/site-packages/asyncpg/pgproto/pgproto.cpython-312-x86_64-linux-gnu.so b/venv/lib/python3.12/site-packages/asyncpg/pgproto/pgproto.cpython-312-x86_64-linux-gnu.so index f4b4c5a..2377746 100755 Binary files a/venv/lib/python3.12/site-packages/asyncpg/pgproto/pgproto.cpython-312-x86_64-linux-gnu.so and b/venv/lib/python3.12/site-packages/asyncpg/pgproto/pgproto.cpython-312-x86_64-linux-gnu.so differ diff --git a/venv/lib/python3.12/site-packages/asyncpg/pgproto/pgproto.pyi b/venv/lib/python3.12/site-packages/asyncpg/pgproto/pgproto.pyi deleted file mode 100644 index 24cc630..0000000 --- a/venv/lib/python3.12/site-packages/asyncpg/pgproto/pgproto.pyi +++ /dev/null @@ -1,13 +0,0 @@ -import codecs -import typing -import uuid - -class CodecContext: - def get_text_codec(self) -> codecs.CodecInfo: ... - -class ReadBuffer: ... -class WriteBuffer: ... -class BufferError(Exception): ... - -class UUID(uuid.UUID): - def __init__(self, inp: typing.AnyStr) -> None: ... diff --git a/venv/lib/python3.12/site-packages/asyncpg/pool.py b/venv/lib/python3.12/site-packages/asyncpg/pool.py index e3898d5..06e698d 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/pool.py +++ b/venv/lib/python3.12/site-packages/asyncpg/pool.py @@ -33,8 +33,7 @@ class PoolConnectionProxyMeta(type): if not inspect.isfunction(meth): continue - iscoroutine = inspect.iscoroutinefunction(meth) - wrapper = mcls._wrap_connection_method(attrname, iscoroutine) + wrapper = mcls._wrap_connection_method(attrname) wrapper = functools.update_wrapper(wrapper, meth) dct[attrname] = wrapper @@ -44,7 +43,7 @@ class PoolConnectionProxyMeta(type): return super().__new__(mcls, name, bases, dct) @staticmethod - def _wrap_connection_method(meth_name, iscoroutine): + def _wrap_connection_method(meth_name): def call_con_method(self, *args, **kwargs): # This method will be owned by PoolConnectionProxy class. if self._con is None: @@ -56,9 +55,6 @@ class PoolConnectionProxyMeta(type): meth = getattr(self._con.__class__, meth_name) return meth(self._con, *args, **kwargs) - if iscoroutine: - compat.markcoroutinefunction(call_con_method) - return call_con_method @@ -210,12 +206,7 @@ class PoolConnectionHolder: if budget is not None: budget -= time.monotonic() - started - if self._pool._reset is not None: - async with compat.timeout(budget): - await self._con._reset() - await self._pool._reset(self._con) - else: - await self._con.reset(timeout=budget) + await self._con.reset(timeout=budget) except (Exception, asyncio.CancelledError) as ex: # If the `reset` call failed, terminate the connection. 
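With the custom reset hook gone in 0.29, the holder always falls back to Connection.reset() as shown here, and per-connection customisation is limited to the init/setup callbacks. A minimal sketch of creating a pool in that style; the DSN and the SET TIME ZONE statement are illustrative placeholders::

    import asyncpg

    async def make_pool(dsn: str) -> asyncpg.Pool:
        async def init(con: asyncpg.Connection) -> None:
            # Runs once for every new physical connection in the pool.
            await con.execute("SET TIME ZONE 'UTC'")

        return await asyncpg.create_pool(dsn, min_size=1, max_size=10, init=init)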
# A new one will be created when `acquire` is called @@ -318,7 +309,7 @@ class Pool: __slots__ = ( '_queue', '_loop', '_minsize', '_maxsize', - '_init', '_connect', '_reset', '_connect_args', '_connect_kwargs', + '_init', '_connect_args', '_connect_kwargs', '_holders', '_initialized', '_initializing', '_closing', '_closed', '_connection_class', '_record_class', '_generation', '_setup', '_max_queries', '_max_inactive_connection_lifetime' @@ -329,10 +320,8 @@ class Pool: max_size, max_queries, max_inactive_connection_lifetime, - connect=None, - setup=None, - init=None, - reset=None, + setup, + init, loop, connection_class, record_class, @@ -392,22 +381,18 @@ class Pool: self._closing = False self._closed = False self._generation = 0 - - self._connect = connect if connect is not None else connection.connect + self._init = init self._connect_args = connect_args self._connect_kwargs = connect_kwargs self._setup = setup - self._init = init - self._reset = reset - self._max_queries = max_queries self._max_inactive_connection_lifetime = \ max_inactive_connection_lifetime async def _async__init__(self): if self._initialized: - return self + return if self._initializing: raise exceptions.InterfaceError( 'pool is being initialized in another task') @@ -514,25 +499,13 @@ class Pool: self._connect_kwargs = connect_kwargs async def _get_new_connection(self): - con = await self._connect( + con = await connection.connect( *self._connect_args, loop=self._loop, connection_class=self._connection_class, record_class=self._record_class, **self._connect_kwargs, ) - if not isinstance(con, self._connection_class): - good = self._connection_class - good_n = f'{good.__module__}.{good.__name__}' - bad = type(con) - if bad.__module__ == "builtins": - bad_n = bad.__name__ - else: - bad_n = f'{bad.__module__}.{bad.__name__}' - raise exceptions.InterfaceError( - "expected pool connect callback to return an instance of " - f"'{good_n}', got " f"'{bad_n}'" - ) if self._init is not None: try: @@ -632,22 +605,6 @@ class Pool: record_class=record_class ) - async def fetchmany(self, query, args, *, timeout=None, record_class=None): - """Run a query for each sequence of arguments in *args* - and return the results as a list of :class:`Record`. - - Pool performs this operation using one of its connections. Other than - that, it behaves identically to - :meth:`Connection.fetchmany() - `. - - .. versionadded:: 0.30.0 - """ - async with self.acquire() as con: - return await con.fetchmany( - query, args, timeout=timeout, record_class=record_class - ) - async def copy_from_table( self, table_name, @@ -1040,10 +997,8 @@ def create_pool(dsn=None, *, max_size=10, max_queries=50000, max_inactive_connection_lifetime=300.0, - connect=None, setup=None, init=None, - reset=None, loop=None, connection_class=connection.Connection, record_class=protocol.Record, @@ -1124,16 +1079,9 @@ def create_pool(dsn=None, *, Number of seconds after which inactive connections in the pool will be closed. Pass ``0`` to disable this mechanism. - :param coroutine connect: - A coroutine that is called instead of - :func:`~asyncpg.connection.connect` whenever the pool needs to make a - new connection. Must return an instance of type specified by - *connection_class* or :class:`~asyncpg.connection.Connection` if - *connection_class* was not specified. - :param coroutine setup: A coroutine to prepare a connection right before it is returned - from :meth:`Pool.acquire()`. An example use + from :meth:`Pool.acquire() `. 
An example use case would be to automatically set up notifications listeners for all connections of a pool. @@ -1145,25 +1093,6 @@ def create_pool(dsn=None, *, or :meth:`Connection.set_type_codec() <\ asyncpg.connection.Connection.set_type_codec>`. - :param coroutine reset: - A coroutine to reset a connection before it is returned to the pool by - :meth:`Pool.release()`. The function is supposed - to reset any changes made to the database session so that the next - acquirer gets the connection in a well-defined state. - - The default implementation calls :meth:`Connection.reset() <\ - asyncpg.connection.Connection.reset>`, which runs the following:: - - SELECT pg_advisory_unlock_all(); - CLOSE ALL; - UNLISTEN *; - RESET ALL; - - The exact reset query is determined by detected server capabilities, - and a custom *reset* implementation can obtain the default query - by calling :meth:`Connection.get_reset_query() <\ - asyncpg.connection.Connection.get_reset_query>`. - :param loop: An asyncio event loop instance. If ``None``, the default event loop will be used. @@ -1190,22 +1119,12 @@ def create_pool(dsn=None, *, .. versionchanged:: 0.22.0 Added the *record_class* parameter. - - .. versionchanged:: 0.30.0 - Added the *connect* and *reset* parameters. """ return Pool( dsn, connection_class=connection_class, record_class=record_class, - min_size=min_size, - max_size=max_size, - max_queries=max_queries, - loop=loop, - connect=connect, - setup=setup, - init=init, - reset=reset, + min_size=min_size, max_size=max_size, + max_queries=max_queries, loop=loop, setup=setup, init=init, max_inactive_connection_lifetime=max_inactive_connection_lifetime, - **connect_kwargs, - ) + **connect_kwargs) diff --git a/venv/lib/python3.12/site-packages/asyncpg/prepared_stmt.py b/venv/lib/python3.12/site-packages/asyncpg/prepared_stmt.py index d66a5ad..8e241d6 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/prepared_stmt.py +++ b/venv/lib/python3.12/site-packages/asyncpg/prepared_stmt.py @@ -147,8 +147,8 @@ class PreparedStatement(connresource.ConnectionResource): # will discard any output that a SELECT would return, other # side effects of the statement will happen as usual. If you # wish to use EXPLAIN ANALYZE on an INSERT, UPDATE, DELETE, - # MERGE, CREATE TABLE AS, or EXECUTE statement without letting - # the command affect your data, use this approach: + # CREATE TABLE AS, or EXECUTE statement without letting the + # command affect your data, use this approach: # BEGIN; # EXPLAIN ANALYZE ...; # ROLLBACK; @@ -210,27 +210,6 @@ class PreparedStatement(connresource.ConnectionResource): return None return data[0] - @connresource.guarded - async def fetchmany(self, args, *, timeout=None): - """Execute the statement and return a list of :class:`Record` objects. - - :param args: Query arguments. - :param float timeout: Optional timeout value in seconds. - - :return: A list of :class:`Record` instances. - - .. versionadded:: 0.30.0 - """ - return await self.__do_execute( - lambda protocol: protocol.bind_execute_many( - self._state, - args, - portal_name='', - timeout=timeout, - return_rows=True, - ) - ) - @connresource.guarded async def executemany(self, args, *, timeout: float=None): """Execute the statement for each sequence of arguments in *args*. 
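Since 0.29 has no fetchmany() on connections, pools, or prepared statements, bulk parameter sets go through executemany(), which runs the statement once per argument tuple and discards row output. A short sketch against the mytab table used in the docstrings above; the table, DSN, and bulk_insert helper are placeholders::

    import asyncpg

    async def bulk_insert(dsn: str, rows: list[tuple[str, int]]) -> None:
        con = await asyncpg.connect(dsn)
        try:
            stmt = await con.prepare('INSERT INTO mytab (a, b) VALUES ($1, $2)')
            await stmt.executemany(rows)
        finally:
            await con.close()

The 0.29 documentation drives such coroutines with asyncio.get_event_loop().run_until_complete(...), though asyncio.run() also works on the Python 3.12 interpreter this venv targets.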
@@ -243,12 +222,7 @@ class PreparedStatement(connresource.ConnectionResource): """ return await self.__do_execute( lambda protocol: protocol.bind_execute_many( - self._state, - args, - portal_name='', - timeout=timeout, - return_rows=False, - )) + self._state, args, '', timeout)) async def __do_execute(self, executor): protocol = self._connection._protocol diff --git a/venv/lib/python3.12/site-packages/asyncpg/protocol/__init__.py b/venv/lib/python3.12/site-packages/asyncpg/protocol/__init__.py index af9287b..8b3e06a 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/protocol/__init__.py +++ b/venv/lib/python3.12/site-packages/asyncpg/protocol/__init__.py @@ -6,6 +6,4 @@ # flake8: NOQA -from __future__ import annotations - from .protocol import Protocol, Record, NO_TIMEOUT, BUILTIN_TYPE_NAME_MAP diff --git a/venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/base.pyx b/venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/base.pyx index e8b44c7..c269e37 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/base.pyx +++ b/venv/lib/python3.12/site-packages/asyncpg/protocol/codecs/base.pyx @@ -483,7 +483,7 @@ cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl: cdef class DataCodecConfig: - def __init__(self): + def __init__(self, cache_key): # Codec instance cache for derived types: # composites, arrays, ranges, domains and their combinations. self._derived_type_codecs = {} diff --git a/venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pxd b/venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pxd index 34c7c71..7ce4f57 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pxd +++ b/venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pxd @@ -51,6 +51,16 @@ cdef enum AuthenticationMessage: AUTH_SASL_FINAL = 12 +AUTH_METHOD_NAME = { + AUTH_REQUIRED_KERBEROS: 'kerberosv5', + AUTH_REQUIRED_PASSWORD: 'password', + AUTH_REQUIRED_PASSWORDMD5: 'md5', + AUTH_REQUIRED_GSS: 'gss', + AUTH_REQUIRED_SASL: 'scram-sha-256', + AUTH_REQUIRED_SSPI: 'sspi', +} + + cdef enum ResultType: RESULT_OK = 1 RESULT_FAILED = 2 @@ -86,13 +96,10 @@ cdef class CoreProtocol: object transport - object address # Instance of _ConnectionParameters object con_params # Instance of SCRAMAuthentication SCRAMAuthentication scram - # Instance of gssapi.SecurityContext or sspilib.SecurityContext - object gss_ctx readonly int32_t backend_pid readonly int32_t backend_secret @@ -138,10 +145,6 @@ cdef class CoreProtocol: cdef _auth_password_message_md5(self, bytes salt) cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods) cdef _auth_password_message_sasl_continue(self, bytes server_response) - cdef _auth_gss_init_gssapi(self) - cdef _auth_gss_init_sspi(self, bint negotiate) - cdef _auth_gss_get_service(self) - cdef _auth_gss_step(self, bytes server_response) cdef _write(self, buf) cdef _writelines(self, list buffers) @@ -171,7 +174,7 @@ cdef class CoreProtocol: cdef _bind_execute(self, str portal_name, str stmt_name, WriteBuffer bind_data, int32_t limit) cdef bint _bind_execute_many(self, str portal_name, str stmt_name, - object bind_data, bint return_rows) + object bind_data) cdef bint _bind_execute_many_more(self, bint first=*) cdef _bind_execute_many_fail(self, object error, bint first=*) cdef _bind(self, str portal_name, str stmt_name, diff --git a/venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx b/venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx index 1985787..64afe93 100644 --- 
a/venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx +++ b/venv/lib/python3.12/site-packages/asyncpg/protocol/coreproto.pyx @@ -11,20 +11,9 @@ import hashlib include "scram.pyx" -cdef dict AUTH_METHOD_NAME = { - AUTH_REQUIRED_KERBEROS: 'kerberosv5', - AUTH_REQUIRED_PASSWORD: 'password', - AUTH_REQUIRED_PASSWORDMD5: 'md5', - AUTH_REQUIRED_GSS: 'gss', - AUTH_REQUIRED_SASL: 'scram-sha-256', - AUTH_REQUIRED_SSPI: 'sspi', -} - - cdef class CoreProtocol: - def __init__(self, addr, con_params): - self.address = addr + def __init__(self, con_params): # type of `con_params` is `_ConnectionParameters` self.buffer = ReadBuffer() self.user = con_params.user @@ -37,9 +26,6 @@ cdef class CoreProtocol: self.encoding = 'utf-8' # type of `scram` is `SCRAMAuthentcation` self.scram = None - # type of `gss_ctx` is `gssapi.SecurityContext` or - # `sspilib.SecurityContext` - self.gss_ctx = None self._reset_result() @@ -633,35 +619,22 @@ cdef class CoreProtocol: 'could not verify server signature for ' 'SCRAM authentciation: scram-sha-256', ) - self.scram = None - elif status in (AUTH_REQUIRED_GSS, AUTH_REQUIRED_SSPI): - # AUTH_REQUIRED_SSPI is the same as AUTH_REQUIRED_GSS, except that - # it uses protocol negotiation with SSPI clients. Both methods use - # AUTH_REQUIRED_GSS_CONTINUE for subsequent authentication steps. - if self.gss_ctx is not None: - self.result_type = RESULT_FAILED - self.result = apg_exc.InterfaceError( - 'duplicate GSSAPI/SSPI authentication request') - else: - if self.con_params.gsslib == 'gssapi': - self._auth_gss_init_gssapi() - else: - self._auth_gss_init_sspi(status == AUTH_REQUIRED_SSPI) - self.auth_msg = self._auth_gss_step(None) - - elif status == AUTH_REQUIRED_GSS_CONTINUE: - server_response = self.buffer.consume_message() - self.auth_msg = self._auth_gss_step(server_response) + elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED, + AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE, + AUTH_REQUIRED_SSPI): + self.result_type = RESULT_FAILED + self.result = apg_exc.InterfaceError( + 'unsupported authentication method requested by the ' + 'server: {!r}'.format(AUTH_METHOD_NAME[status])) else: self.result_type = RESULT_FAILED self.result = apg_exc.InterfaceError( 'unsupported authentication method requested by the ' - 'server: {!r}'.format(AUTH_METHOD_NAME.get(status, status))) + 'server: {}'.format(status)) - if status not in (AUTH_SASL_CONTINUE, AUTH_SASL_FINAL, - AUTH_REQUIRED_GSS_CONTINUE): + if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]: self.buffer.discard_message() cdef _auth_password_message_cleartext(self): @@ -718,59 +691,6 @@ cdef class CoreProtocol: return msg - cdef _auth_gss_init_gssapi(self): - try: - import gssapi - except ModuleNotFoundError: - raise apg_exc.InterfaceError( - 'gssapi module not found; please install asyncpg[gssauth] to ' - 'use asyncpg with Kerberos/GSSAPI/SSPI authentication' - ) from None - - service_name, host = self._auth_gss_get_service() - self.gss_ctx = gssapi.SecurityContext( - name=gssapi.Name( - f'{service_name}@{host}', gssapi.NameType.hostbased_service), - usage='initiate') - - cdef _auth_gss_init_sspi(self, bint negotiate): - try: - import sspilib - except ModuleNotFoundError: - raise apg_exc.InterfaceError( - 'sspilib module not found; please install asyncpg[gssauth] to ' - 'use asyncpg with Kerberos/GSSAPI/SSPI authentication' - ) from None - - service_name, host = self._auth_gss_get_service() - self.gss_ctx = sspilib.ClientSecurityContext( - target_name=f'{service_name}/{host}', - 
credential=sspilib.UserCredential( - protocol='Negotiate' if negotiate else 'Kerberos')) - - cdef _auth_gss_get_service(self): - service_name = self.con_params.krbsrvname or 'postgres' - if isinstance(self.address, str): - raise apg_exc.InternalClientError( - 'GSSAPI/SSPI authentication is only supported for TCP/IP ' - 'connections') - - return service_name, self.address[0] - - cdef _auth_gss_step(self, bytes server_response): - cdef: - WriteBuffer msg - - token = self.gss_ctx.step(server_response) - if not token: - self.gss_ctx = None - return None - msg = WriteBuffer.new_message(b'p') - msg.write_bytes(token) - msg.end_message() - - return msg - cdef _parse_msg_ready_for_query(self): cdef char status = self.buffer.read_byte() @@ -1020,12 +940,12 @@ cdef class CoreProtocol: self._send_bind_message(portal_name, stmt_name, bind_data, limit) cdef bint _bind_execute_many(self, str portal_name, str stmt_name, - object bind_data, bint return_rows): + object bind_data): self._ensure_connected() self._set_state(PROTOCOL_BIND_EXECUTE_MANY) - self.result = [] if return_rows else None - self._discard_data = not return_rows + self.result = None + self._discard_data = True self._execute_iter = bind_data self._execute_portal_name = portal_name self._execute_stmt_name = stmt_name diff --git a/venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pyx b/venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pyx index cb0afa2..7335825 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pyx +++ b/venv/lib/python3.12/site-packages/asyncpg/protocol/prepared_stmt.pyx @@ -142,7 +142,7 @@ cdef class PreparedStatementState: # that the user tried to parametrize a statement that does # not support parameters. hint += (r' Note that parameters are supported only in' - r' SELECT, INSERT, UPDATE, DELETE, MERGE and VALUES' + r' SELECT, INSERT, UPDATE, DELETE, and VALUES' r' statements, and will *not* work in statements ' r' like CREATE VIEW or DECLARE CURSOR.') diff --git a/venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.cpython-312-x86_64-linux-gnu.so b/venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.cpython-312-x86_64-linux-gnu.so index 1d84afb..da7e65e 100755 Binary files a/venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.cpython-312-x86_64-linux-gnu.so and b/venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.cpython-312-x86_64-linux-gnu.so differ diff --git a/venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pxd b/venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pxd index cd221fb..a9ac8d5 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pxd +++ b/venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pxd @@ -31,6 +31,7 @@ cdef class BaseProtocol(CoreProtocol): cdef: object loop + object address ConnectionSettings settings object cancel_sent_waiter object cancel_waiter diff --git a/venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pyi b/venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pyi deleted file mode 100644 index b81d13c..0000000 --- a/venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pyi +++ /dev/null @@ -1,300 +0,0 @@ -import asyncio -import asyncio.protocols -import hmac -from codecs import CodecInfo -from collections.abc import Callable, Iterable, Iterator, Sequence -from hashlib import md5, sha256 -from typing import ( - Any, - ClassVar, - Final, - Generic, - Literal, - NewType, - TypeVar, - final, - overload, -) -from typing_extensions 
import TypeAlias - -import asyncpg.pgproto.pgproto - -from ..connect_utils import _ConnectionParameters -from ..pgproto.pgproto import WriteBuffer -from ..types import Attribute, Type - -_T = TypeVar('_T') -_Record = TypeVar('_Record', bound=Record) -_OtherRecord = TypeVar('_OtherRecord', bound=Record) -_PreparedStatementState = TypeVar( - '_PreparedStatementState', bound=PreparedStatementState[Any] -) - -_NoTimeoutType = NewType('_NoTimeoutType', object) -_TimeoutType: TypeAlias = float | None | _NoTimeoutType - -BUILTIN_TYPE_NAME_MAP: Final[dict[str, int]] -BUILTIN_TYPE_OID_MAP: Final[dict[int, str]] -NO_TIMEOUT: Final[_NoTimeoutType] - -hashlib_md5 = md5 - -@final -class ConnectionSettings(asyncpg.pgproto.pgproto.CodecContext): - __pyx_vtable__: Any - def __init__(self, conn_key: object) -> None: ... - def add_python_codec( - self, - typeoid: int, - typename: str, - typeschema: str, - typeinfos: Iterable[object], - typekind: str, - encoder: Callable[[Any], Any], - decoder: Callable[[Any], Any], - format: object, - ) -> Any: ... - def clear_type_cache(self) -> None: ... - def get_data_codec( - self, oid: int, format: object = ..., ignore_custom_codec: bool = ... - ) -> Any: ... - def get_text_codec(self) -> CodecInfo: ... - def register_data_types(self, types: Iterable[object]) -> None: ... - def remove_python_codec( - self, typeoid: int, typename: str, typeschema: str - ) -> None: ... - def set_builtin_type_codec( - self, - typeoid: int, - typename: str, - typeschema: str, - typekind: str, - alias_to: str, - format: object = ..., - ) -> Any: ... - def __getattr__(self, name: str) -> Any: ... - def __reduce__(self) -> Any: ... - -@final -class PreparedStatementState(Generic[_Record]): - closed: bool - prepared: bool - name: str - query: str - refs: int - record_class: type[_Record] - ignore_custom_codec: bool - __pyx_vtable__: Any - def __init__( - self, - name: str, - query: str, - protocol: BaseProtocol[Any], - record_class: type[_Record], - ignore_custom_codec: bool, - ) -> None: ... - def _get_parameters(self) -> tuple[Type, ...]: ... - def _get_attributes(self) -> tuple[Attribute, ...]: ... - def _init_types(self) -> set[int]: ... - def _init_codecs(self) -> None: ... - def attach(self) -> None: ... - def detach(self) -> None: ... - def mark_closed(self) -> None: ... - def mark_unprepared(self) -> None: ... - def __reduce__(self) -> Any: ... - -class CoreProtocol: - backend_pid: Any - backend_secret: Any - __pyx_vtable__: Any - def __init__(self, addr: object, con_params: _ConnectionParameters) -> None: ... - def is_in_transaction(self) -> bool: ... - def __reduce__(self) -> Any: ... - -class BaseProtocol(CoreProtocol, Generic[_Record]): - queries_count: Any - is_ssl: bool - __pyx_vtable__: Any - def __init__( - self, - addr: object, - connected_fut: object, - con_params: _ConnectionParameters, - record_class: type[_Record], - loop: object, - ) -> None: ... - def set_connection(self, connection: object) -> None: ... - def get_server_pid(self, *args: object, **kwargs: object) -> int: ... - def get_settings(self, *args: object, **kwargs: object) -> ConnectionSettings: ... - def get_record_class(self) -> type[_Record]: ... - def abort(self) -> None: ... - async def bind( - self, - state: PreparedStatementState[_OtherRecord], - args: Sequence[object], - portal_name: str, - timeout: _TimeoutType, - ) -> Any: ... 
- @overload - async def bind_execute( - self, - state: PreparedStatementState[_OtherRecord], - args: Sequence[object], - portal_name: str, - limit: int, - return_extra: Literal[False], - timeout: _TimeoutType, - ) -> list[_OtherRecord]: ... - @overload - async def bind_execute( - self, - state: PreparedStatementState[_OtherRecord], - args: Sequence[object], - portal_name: str, - limit: int, - return_extra: Literal[True], - timeout: _TimeoutType, - ) -> tuple[list[_OtherRecord], bytes, bool]: ... - @overload - async def bind_execute( - self, - state: PreparedStatementState[_OtherRecord], - args: Sequence[object], - portal_name: str, - limit: int, - return_extra: bool, - timeout: _TimeoutType, - ) -> list[_OtherRecord] | tuple[list[_OtherRecord], bytes, bool]: ... - async def bind_execute_many( - self, - state: PreparedStatementState[_OtherRecord], - args: Iterable[Sequence[object]], - portal_name: str, - timeout: _TimeoutType, - ) -> None: ... - async def close(self, timeout: _TimeoutType) -> None: ... - def _get_timeout(self, timeout: _TimeoutType) -> float | None: ... - def _is_cancelling(self) -> bool: ... - async def _wait_for_cancellation(self) -> None: ... - async def close_statement( - self, state: PreparedStatementState[_OtherRecord], timeout: _TimeoutType - ) -> Any: ... - async def copy_in(self, *args: object, **kwargs: object) -> str: ... - async def copy_out(self, *args: object, **kwargs: object) -> str: ... - async def execute(self, *args: object, **kwargs: object) -> Any: ... - def is_closed(self, *args: object, **kwargs: object) -> Any: ... - def is_connected(self, *args: object, **kwargs: object) -> Any: ... - def data_received(self, data: object) -> None: ... - def connection_made(self, transport: object) -> None: ... - def connection_lost(self, exc: Exception | None) -> None: ... - def pause_writing(self, *args: object, **kwargs: object) -> Any: ... - @overload - async def prepare( - self, - stmt_name: str, - query: str, - timeout: float | None = ..., - *, - state: _PreparedStatementState, - ignore_custom_codec: bool = ..., - record_class: None, - ) -> _PreparedStatementState: ... - @overload - async def prepare( - self, - stmt_name: str, - query: str, - timeout: float | None = ..., - *, - state: None = ..., - ignore_custom_codec: bool = ..., - record_class: type[_OtherRecord], - ) -> PreparedStatementState[_OtherRecord]: ... - async def close_portal(self, portal_name: str, timeout: _TimeoutType) -> None: ... - async def query(self, *args: object, **kwargs: object) -> str: ... - def resume_writing(self, *args: object, **kwargs: object) -> Any: ... - def __reduce__(self) -> Any: ... - -@final -class Codec: - __pyx_vtable__: Any - def __reduce__(self) -> Any: ... - -class DataCodecConfig: - __pyx_vtable__: Any - def __init__(self) -> None: ... - def add_python_codec( - self, - typeoid: int, - typename: str, - typeschema: str, - typekind: str, - typeinfos: Iterable[object], - encoder: Callable[[ConnectionSettings, WriteBuffer, object], object], - decoder: Callable[..., object], - format: object, - xformat: object, - ) -> Any: ... - def add_types(self, types: Iterable[object]) -> Any: ... - def clear_type_cache(self) -> None: ... - def declare_fallback_codec(self, oid: int, name: str, schema: str) -> Codec: ... - def remove_python_codec( - self, typeoid: int, typename: str, typeschema: str - ) -> Any: ... - def set_builtin_type_codec( - self, - typeoid: int, - typename: str, - typeschema: str, - typekind: str, - alias_to: str, - format: object = ..., - ) -> Any: ... 
- def __reduce__(self) -> Any: ... - -class Protocol(BaseProtocol[_Record], asyncio.protocols.Protocol): ... - -class Record: - @overload - def get(self, key: str) -> Any | None: ... - @overload - def get(self, key: str, default: _T) -> Any | _T: ... - def items(self) -> Iterator[tuple[str, Any]]: ... - def keys(self) -> Iterator[str]: ... - def values(self) -> Iterator[Any]: ... - @overload - def __getitem__(self, index: str) -> Any: ... - @overload - def __getitem__(self, index: int) -> Any: ... - @overload - def __getitem__(self, index: slice) -> tuple[Any, ...]: ... - def __iter__(self) -> Iterator[Any]: ... - def __contains__(self, x: object) -> bool: ... - def __len__(self) -> int: ... - -class Timer: - def __init__(self, budget: float | None) -> None: ... - def __enter__(self) -> None: ... - def __exit__(self, et: object, e: object, tb: object) -> None: ... - def get_remaining_budget(self) -> float: ... - def has_budget_greater_than(self, amount: float) -> bool: ... - -@final -class SCRAMAuthentication: - AUTHENTICATION_METHODS: ClassVar[list[str]] - DEFAULT_CLIENT_NONCE_BYTES: ClassVar[int] - DIGEST = sha256 - REQUIREMENTS_CLIENT_FINAL_MESSAGE: ClassVar[list[str]] - REQUIREMENTS_CLIENT_PROOF: ClassVar[list[str]] - SASLPREP_PROHIBITED: ClassVar[tuple[Callable[[str], bool], ...]] - authentication_method: bytes - authorization_message: bytes | None - client_channel_binding: bytes - client_first_message_bare: bytes | None - client_nonce: bytes | None - client_proof: bytes | None - password_salt: bytes | None - password_iterations: int - server_first_message: bytes | None - server_key: hmac.HMAC | None - server_nonce: bytes | None diff --git a/venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pyx b/venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pyx index bd2ad05..b43b0e9 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pyx +++ b/venv/lib/python3.12/site-packages/asyncpg/protocol/protocol.pyx @@ -75,7 +75,7 @@ NO_TIMEOUT = object() cdef class BaseProtocol(CoreProtocol): def __init__(self, addr, connected_fut, con_params, record_class: type, loop): # type of `con_params` is `_ConnectionParameters` - CoreProtocol.__init__(self, addr, con_params) + CoreProtocol.__init__(self, con_params) self.loop = loop self.transport = None @@ -83,7 +83,8 @@ cdef class BaseProtocol(CoreProtocol): self.cancel_waiter = None self.cancel_sent_waiter = None - self.settings = ConnectionSettings((addr, con_params.database)) + self.address = addr + self.settings = ConnectionSettings((self.address, con_params.database)) self.record_class = record_class self.statement = None @@ -212,7 +213,6 @@ cdef class BaseProtocol(CoreProtocol): args, portal_name: str, timeout, - return_rows: bool, ): if self.cancel_waiter is not None: await self.cancel_waiter @@ -238,8 +238,7 @@ cdef class BaseProtocol(CoreProtocol): more = self._bind_execute_many( portal_name, state.name, - arg_bufs, - return_rows) # network op + arg_bufs) # network op self.last_query = state.query self.statement = state diff --git a/venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pyx b/venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pyx index 2b53566..8e6591b 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pyx +++ b/venv/lib/python3.12/site-packages/asyncpg/protocol/settings.pyx @@ -11,12 +11,12 @@ from asyncpg import exceptions @cython.final cdef class ConnectionSettings(pgproto.CodecContext): - def __cinit__(self): + def __cinit__(self, conn_key): self._encoding 
= 'utf-8' self._is_utf8 = True self._settings = {} self._codec = codecs.lookup('utf-8') - self._data_codecs = DataCodecConfig() + self._data_codecs = DataCodecConfig(conn_key) cdef add_setting(self, str name, str val): self._settings[name] = val diff --git a/venv/lib/python3.12/site-packages/asyncpg/serverversion.py b/venv/lib/python3.12/site-packages/asyncpg/serverversion.py index ee9647b..31568a2 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/serverversion.py +++ b/venv/lib/python3.12/site-packages/asyncpg/serverversion.py @@ -4,14 +4,12 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 -from __future__ import annotations import re -import typing from .types import ServerVersion -version_regex: typing.Final = re.compile( +version_regex = re.compile( r"(Postgre[^\s]*)?\s*" r"(?P[0-9]+)\.?" r"((?P[0-9]+)\.?)?" @@ -21,15 +19,7 @@ version_regex: typing.Final = re.compile( ) -class _VersionDict(typing.TypedDict): - major: int - minor: int | None - micro: int | None - releaselevel: str | None - serial: int | None - - -def split_server_version_string(version_string: str) -> ServerVersion: +def split_server_version_string(version_string): version_match = version_regex.search(version_string) if version_match is None: @@ -38,17 +28,17 @@ def split_server_version_string(version_string: str) -> ServerVersion: f'version from "{version_string}"' ) - version: _VersionDict = version_match.groupdict() # type: ignore[assignment] # noqa: E501 + version = version_match.groupdict() for ver_key, ver_value in version.items(): # Cast all possible versions parts to int try: - version[ver_key] = int(ver_value) # type: ignore[literal-required, call-overload] # noqa: E501 + version[ver_key] = int(ver_value) except (TypeError, ValueError): pass - if version["major"] < 10: + if version.get("major") < 10: return ServerVersion( - version["major"], + version.get("major"), version.get("minor") or 0, version.get("micro") or 0, version.get("releaselevel") or "final", @@ -62,7 +52,7 @@ def split_server_version_string(version_string: str) -> ServerVersion: # want to keep that behaviour consistent, i.e not fail # a major version check due to a bugfix release. return ServerVersion( - version["major"], + version.get("major"), 0, version.get("minor") or 0, version.get("releaselevel") or "final", diff --git a/venv/lib/python3.12/site-packages/asyncpg/types.py b/venv/lib/python3.12/site-packages/asyncpg/types.py index 7a24e24..bd5813f 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/types.py +++ b/venv/lib/python3.12/site-packages/asyncpg/types.py @@ -4,18 +4,14 @@ # This module is part of asyncpg and is released under # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 -from __future__ import annotations -import typing +import collections from asyncpg.pgproto.types import ( BitString, Point, Path, Polygon, Box, Line, LineSegment, Circle, ) -if typing.TYPE_CHECKING: - from typing_extensions import Self - __all__ = ( 'Type', 'Attribute', 'Range', 'BitString', 'Point', 'Path', 'Polygon', @@ -23,13 +19,7 @@ __all__ = ( ) -class Type(typing.NamedTuple): - oid: int - name: str - kind: str - schema: str - - +Type = collections.namedtuple('Type', ['oid', 'name', 'kind', 'schema']) Type.__doc__ = 'Database data type.' Type.oid.__doc__ = 'OID of the type.' Type.name.__doc__ = 'Type name. For example "int2".' @@ -38,61 +28,25 @@ Type.kind.__doc__ = \ Type.schema.__doc__ = 'Name of the database schema that defines the type.' 
-class Attribute(typing.NamedTuple): - name: str - type: Type - - +Attribute = collections.namedtuple('Attribute', ['name', 'type']) Attribute.__doc__ = 'Database relation attribute.' Attribute.name.__doc__ = 'Attribute name.' Attribute.type.__doc__ = 'Attribute data type :class:`asyncpg.types.Type`.' -class ServerVersion(typing.NamedTuple): - major: int - minor: int - micro: int - releaselevel: str - serial: int - - +ServerVersion = collections.namedtuple( + 'ServerVersion', ['major', 'minor', 'micro', 'releaselevel', 'serial']) ServerVersion.__doc__ = 'PostgreSQL server version tuple.' -class _RangeValue(typing.Protocol): - def __eq__(self, __value: object) -> bool: - ... - - def __lt__(self, __other: _RangeValue) -> bool: - ... - - def __gt__(self, __other: _RangeValue) -> bool: - ... - - -_RV = typing.TypeVar('_RV', bound=_RangeValue) - - -class Range(typing.Generic[_RV]): +class Range: """Immutable representation of PostgreSQL `range` type.""" - __slots__ = ('_lower', '_upper', '_lower_inc', '_upper_inc', '_empty') + __slots__ = '_lower', '_upper', '_lower_inc', '_upper_inc', '_empty' - _lower: _RV | None - _upper: _RV | None - _lower_inc: bool - _upper_inc: bool - _empty: bool - - def __init__( - self, - lower: _RV | None = None, - upper: _RV | None = None, - *, - lower_inc: bool = True, - upper_inc: bool = False, - empty: bool = False - ) -> None: + def __init__(self, lower=None, upper=None, *, + lower_inc=True, upper_inc=False, + empty=False): self._empty = empty if empty: self._lower = self._upper = None @@ -104,34 +58,34 @@ class Range(typing.Generic[_RV]): self._upper_inc = upper is not None and upper_inc @property - def lower(self) -> _RV | None: + def lower(self): return self._lower @property - def lower_inc(self) -> bool: + def lower_inc(self): return self._lower_inc @property - def lower_inf(self) -> bool: + def lower_inf(self): return self._lower is None and not self._empty @property - def upper(self) -> _RV | None: + def upper(self): return self._upper @property - def upper_inc(self) -> bool: + def upper_inc(self): return self._upper_inc @property - def upper_inf(self) -> bool: + def upper_inf(self): return self._upper is None and not self._empty @property - def isempty(self) -> bool: + def isempty(self): return self._empty - def _issubset_lower(self, other: Self) -> bool: + def _issubset_lower(self, other): if other._lower is None: return True if self._lower is None: @@ -142,7 +96,7 @@ class Range(typing.Generic[_RV]): and (other._lower_inc or not self._lower_inc) ) - def _issubset_upper(self, other: Self) -> bool: + def _issubset_upper(self, other): if other._upper is None: return True if self._upper is None: @@ -153,7 +107,7 @@ class Range(typing.Generic[_RV]): and (other._upper_inc or not self._upper_inc) ) - def issubset(self, other: Self) -> bool: + def issubset(self, other): if self._empty: return True if other._empty: @@ -161,13 +115,13 @@ class Range(typing.Generic[_RV]): return self._issubset_lower(other) and self._issubset_upper(other) - def issuperset(self, other: Self) -> bool: + def issuperset(self, other): return other.issubset(self) - def __bool__(self) -> bool: + def __bool__(self): return not self._empty - def __eq__(self, other: object) -> bool: + def __eq__(self, other): if not isinstance(other, Range): return NotImplemented @@ -178,14 +132,14 @@ class Range(typing.Generic[_RV]): self._upper_inc, self._empty ) == ( - other._lower, # pyright: ignore [reportUnknownMemberType] - other._upper, # pyright: ignore [reportUnknownMemberType] + other._lower, + 
other._upper, other._lower_inc, other._upper_inc, other._empty ) - def __hash__(self) -> int: + def __hash__(self): return hash(( self._lower, self._upper, @@ -194,7 +148,7 @@ class Range(typing.Generic[_RV]): self._empty )) - def __repr__(self) -> str: + def __repr__(self): if self._empty: desc = 'empty' else: diff --git a/venv/lib/python3.12/site-packages/asyncpg/utils.py b/venv/lib/python3.12/site-packages/asyncpg/utils.py index 5c1ca69..3940e04 100644 --- a/venv/lib/python3.12/site-packages/asyncpg/utils.py +++ b/venv/lib/python3.12/site-packages/asyncpg/utils.py @@ -42,11 +42,4 @@ async def _mogrify(conn, query, args): # Finally, replace $n references with text values. return re.sub( - r"\$(\d+)\b", - lambda m: ( - textified[int(m.group(1)) - 1] - if textified[int(m.group(1)) - 1] is not None - else "NULL" - ), - query, - ) + r'\$(\d+)\b', lambda m: textified[int(m.group(1)) - 1], query) diff --git a/venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/bcrypt-4.0.1.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/bcrypt-4.0.1.dist-info/INSTALLER diff --git a/venv/lib/python3.12/site-packages/bcrypt-5.0.0.dist-info/licenses/LICENSE b/venv/lib/python3.12/site-packages/bcrypt-4.0.1.dist-info/LICENSE similarity index 100% rename from venv/lib/python3.12/site-packages/bcrypt-5.0.0.dist-info/licenses/LICENSE rename to venv/lib/python3.12/site-packages/bcrypt-4.0.1.dist-info/LICENSE diff --git a/venv/lib/python3.12/site-packages/bcrypt-5.0.0.dist-info/METADATA b/venv/lib/python3.12/site-packages/bcrypt-4.0.1.dist-info/METADATA similarity index 79% rename from venv/lib/python3.12/site-packages/bcrypt-5.0.0.dist-info/METADATA rename to venv/lib/python3.12/site-packages/bcrypt-4.0.1.dist-info/METADATA index 629a992..789f784 100644 --- a/venv/lib/python3.12/site-packages/bcrypt-5.0.0.dist-info/METADATA +++ b/venv/lib/python3.12/site-packages/bcrypt-4.0.1.dist-info/METADATA @@ -1,32 +1,30 @@ -Metadata-Version: 2.4 +Metadata-Version: 2.1 Name: bcrypt -Version: 5.0.0 +Version: 4.0.1 Summary: Modern password hashing for your software and your servers -Author-email: The Python Cryptographic Authority developers -License: Apache-2.0 -Project-URL: homepage, https://github.com/pyca/bcrypt/ +Home-page: https://github.com/pyca/bcrypt/ +Author: The Python Cryptographic Authority developers +Author-email: cryptography-dev@python.org +License: Apache License, Version 2.0 +Platform: UNKNOWN Classifier: Development Status :: 5 - Production/Stable Classifier: License :: OSI Approved :: Apache Software License Classifier: Programming Language :: Python :: Implementation :: CPython Classifier: Programming Language :: Python :: Implementation :: PyPy Classifier: Programming Language :: Python :: 3 Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 -Classifier: Programming Language :: Python :: 3.11 -Classifier: Programming Language :: Python :: 3.12 -Classifier: Programming Language :: Python :: 3.13 -Classifier: Programming Language :: Python :: 3.14 -Classifier: Programming Language :: Python :: Free Threading :: 3 - Stable -Requires-Python: >=3.8 +Requires-Python: >=3.6 Description-Content-Type: 
text/x-rst License-File: LICENSE Provides-Extra: tests -Requires-Dist: pytest!=3.3.0,>=3.2.1; extra == "tests" +Requires-Dist: pytest (!=3.3.0,>=3.2.1) ; extra == 'tests' Provides-Extra: typecheck -Requires-Dist: mypy; extra == "typecheck" -Dynamic: license-file +Requires-Dist: mypy ; extra == 'typecheck' bcrypt ====== @@ -47,7 +45,7 @@ Installation To install bcrypt, simply: -.. code:: console +.. code:: bash $ pip install bcrypt @@ -56,19 +54,19 @@ compiler and a Rust compiler (the minimum supported Rust version is 1.56.0). For Debian and Ubuntu, the following command will ensure that the required dependencies are installed: -.. code:: console +.. code:: bash $ sudo apt-get install build-essential cargo For Fedora and RHEL-derivatives, the following command will ensure that the required dependencies are installed: -.. code:: console +.. code:: bash $ sudo yum install gcc cargo For Alpine, the following command will ensure that the required dependencies are installed: -.. code:: console +.. code:: bash $ apk add --update musl-dev gcc cargo @@ -81,62 +79,6 @@ While bcrypt remains an acceptable choice for password storage, depending on you Changelog ========= -5.0.0 ------ - -* Bumped MSRV to 1.74. -* Added support for Python 3.14 and free-threaded Python 3.14. -* Added support for Windows on ARM. -* Passing ``hashpw`` a password longer than 72 bytes now raises a - ``ValueError``. Previously the password was silently truncated, following the - behavior of the original OpenBSD ``bcrypt`` implementation. - -4.3.0 ------ - -* Dropped support for Python 3.7. -* We now support free-threaded Python 3.13. -* We now support PyPy 3.11. -* We now publish wheels for free-threaded Python 3.13, for PyPy 3.11 on - ``manylinux``, and for ARMv7l on ``manylinux``. - -4.2.1 ------ - -* Bump Rust dependency versions - this should resolve crashes on Python 3.13 - free-threaded builds. -* We no longer build ``manylinux`` wheels for PyPy 3.9. - -4.2.0 ------ - -* Bump Rust dependency versions -* Removed the ``BCRYPT_ALLOW_RUST_163`` environment variable. - -4.1.3 ------ - -* Bump Rust dependency versions - -4.1.2 ------ - -* Publish both ``py37`` and ``py39`` wheels. This should resolve some errors - relating to initializing a module multiple times per process. - -4.1.1 ------ - -* Fixed the type signature on the ``kdf`` method. -* Fixed packaging bug on Windows. -* Fixed incompatibility with passlib package detection assumptions. - -4.1.0 ------ - -* Dropped support for Python 3.6. -* Bumped MSRV to 1.64. (Note: Rust 1.63 can be used by setting the ``BCRYPT_ALLOW_RUST_163`` environment variable) - 4.0.1 ----- @@ -329,7 +271,12 @@ Compatibility ------------- This library should be compatible with py-bcrypt and it will run on Python -3.8+ (including free-threaded builds), and PyPy 3. +3.6+, and PyPy 3. + +C Code +------ + +This library uses code from OpenBSD. Security -------- @@ -341,3 +288,5 @@ identify a vulnerability, we ask you to contact us privately. .. _`standard library`: https://docs.python.org/3/library/hashlib.html#hashlib.scrypt .. _`argon2_cffi`: https://argon2-cffi.readthedocs.io .. 
_`cryptography`: https://cryptography.io/en/latest/hazmat/primitives/key-derivation-functions/#cryptography.hazmat.primitives.kdf.scrypt.Scrypt + + diff --git a/venv/lib/python3.12/site-packages/bcrypt-4.0.1.dist-info/RECORD b/venv/lib/python3.12/site-packages/bcrypt-4.0.1.dist-info/RECORD new file mode 100644 index 0000000..217979e --- /dev/null +++ b/venv/lib/python3.12/site-packages/bcrypt-4.0.1.dist-info/RECORD @@ -0,0 +1,14 @@ +bcrypt-4.0.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +bcrypt-4.0.1.dist-info/LICENSE,sha256=gXPVwptPlW1TJ4HSuG5OMPg-a3h43OGMkZRR1rpwfJA,10850 +bcrypt-4.0.1.dist-info/METADATA,sha256=peZwWFa95xnpp4NiIE7gJkV01CTkbVXIzoEN66SXd3c,8972 +bcrypt-4.0.1.dist-info/RECORD,, +bcrypt-4.0.1.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +bcrypt-4.0.1.dist-info/WHEEL,sha256=ZXaM-AC_dnzk1sUAdQV_bMrIMG6zI-GthFaEkNkWsgU,112 +bcrypt-4.0.1.dist-info/top_level.txt,sha256=BkR_qBzDbSuycMzHWE1vzXrfYecAzUVmQs6G2CukqNI,7 +bcrypt/__about__.py,sha256=F7i0CQOa8G3Yjw1T71jQv8yi__Z_4TzLyZJv1GFqVx0,1320 +bcrypt/__init__.py,sha256=EpUdbfHaiHlSoaM-SSUB6MOgNpWOIkS0ZrjxogPIRLM,3781 +bcrypt/__pycache__/__about__.cpython-312.pyc,, +bcrypt/__pycache__/__init__.cpython-312.pyc,, +bcrypt/_bcrypt.abi3.so,sha256=_T-y5IrekziUzkYio4hWH7Xzw92XBKewSLd8kmERhGU,1959696 +bcrypt/_bcrypt.pyi,sha256=O-vvHdooGyAxIkdKemVqOzBF5aMhh0evPSaDMgETgEk,214 +bcrypt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 diff --git a/venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/REQUESTED b/venv/lib/python3.12/site-packages/bcrypt-4.0.1.dist-info/REQUESTED similarity index 100% rename from venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/REQUESTED rename to venv/lib/python3.12/site-packages/bcrypt-4.0.1.dist-info/REQUESTED diff --git a/venv/lib/python3.12/site-packages/bcrypt-4.0.1.dist-info/WHEEL b/venv/lib/python3.12/site-packages/bcrypt-4.0.1.dist-info/WHEEL new file mode 100644 index 0000000..dc50279 --- /dev/null +++ b/venv/lib/python3.12/site-packages/bcrypt-4.0.1.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.37.1) +Root-Is-Purelib: false +Tag: cp36-abi3-manylinux_2_28_x86_64 + diff --git a/venv/lib/python3.12/site-packages/bcrypt-5.0.0.dist-info/top_level.txt b/venv/lib/python3.12/site-packages/bcrypt-4.0.1.dist-info/top_level.txt similarity index 100% rename from venv/lib/python3.12/site-packages/bcrypt-5.0.0.dist-info/top_level.txt rename to venv/lib/python3.12/site-packages/bcrypt-4.0.1.dist-info/top_level.txt diff --git a/venv/lib/python3.12/site-packages/bcrypt-5.0.0.dist-info/RECORD b/venv/lib/python3.12/site-packages/bcrypt-5.0.0.dist-info/RECORD deleted file mode 100644 index 2692cc7..0000000 --- a/venv/lib/python3.12/site-packages/bcrypt-5.0.0.dist-info/RECORD +++ /dev/null @@ -1,12 +0,0 @@ -bcrypt-5.0.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -bcrypt-5.0.0.dist-info/METADATA,sha256=yV1BfLlI6udlVy23eNbzDa62DSEbUrlWvlLBCI6UAdI,10524 -bcrypt-5.0.0.dist-info/RECORD,, -bcrypt-5.0.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -bcrypt-5.0.0.dist-info/WHEEL,sha256=WieEZvWpc0Erab6-NfTu9412g-GcE58js6gvBn3Q7B4,111 -bcrypt-5.0.0.dist-info/licenses/LICENSE,sha256=gXPVwptPlW1TJ4HSuG5OMPg-a3h43OGMkZRR1rpwfJA,10850 -bcrypt-5.0.0.dist-info/top_level.txt,sha256=BkR_qBzDbSuycMzHWE1vzXrfYecAzUVmQs6G2CukqNI,7 -bcrypt/__init__.py,sha256=cv-NupIX6P7o6A4PK_F0ur6IZoDr3GnvyzFO9k16wKQ,1000 
-bcrypt/__init__.pyi,sha256=ITUCB9mPVU8sKUbJQMDUH5YfQXZb1O55F9qvKZR_o8I,333 -bcrypt/__pycache__/__init__.cpython-312.pyc,, -bcrypt/_bcrypt.abi3.so,sha256=oFwJu4Gq44FqJDttx_oWpypfuUQ30BkCWzD2FhojdYw,631768 -bcrypt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 diff --git a/venv/lib/python3.12/site-packages/bcrypt-5.0.0.dist-info/WHEEL b/venv/lib/python3.12/site-packages/bcrypt-5.0.0.dist-info/WHEEL deleted file mode 100644 index eb203c1..0000000 --- a/venv/lib/python3.12/site-packages/bcrypt-5.0.0.dist-info/WHEEL +++ /dev/null @@ -1,5 +0,0 @@ -Wheel-Version: 1.0 -Generator: setuptools (80.9.0) -Root-Is-Purelib: false -Tag: cp39-abi3-manylinux_2_34_x86_64 - diff --git a/venv/lib/python3.12/site-packages/bcrypt/__about__.py b/venv/lib/python3.12/site-packages/bcrypt/__about__.py new file mode 100644 index 0000000..020b748 --- /dev/null +++ b/venv/lib/python3.12/site-packages/bcrypt/__about__.py @@ -0,0 +1,41 @@ +# Author:: Donald Stufft () +# Copyright:: Copyright (c) 2013 Donald Stufft +# License:: Apache License, Version 2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals + +__all__ = [ + "__title__", + "__summary__", + "__uri__", + "__version__", + "__author__", + "__email__", + "__license__", + "__copyright__", +] + +__title__ = "bcrypt" +__summary__ = "Modern password hashing for your software and your servers" +__uri__ = "https://github.com/pyca/bcrypt/" + +__version__ = "4.0.1" + +__author__ = "The Python Cryptographic Authority developers" +__email__ = "cryptography-dev@python.org" + +__license__ = "Apache License, Version 2.0" +__copyright__ = "Copyright 2013-2022 {0}".format(__author__) diff --git a/venv/lib/python3.12/site-packages/bcrypt/__init__.py b/venv/lib/python3.12/site-packages/bcrypt/__init__.py index 81a92fd..1f2886f 100644 --- a/venv/lib/python3.12/site-packages/bcrypt/__init__.py +++ b/venv/lib/python3.12/site-packages/bcrypt/__init__.py @@ -1,3 +1,7 @@ +# Author:: Donald Stufft () +# Copyright:: Copyright (c) 2013 Donald Stufft +# License:: Apache License, Version 2.0 +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -9,8 +13,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import absolute_import +from __future__ import division -from ._bcrypt import ( +import hmac +import os +import warnings + +from .__about__ import ( __author__, __copyright__, __email__, @@ -18,26 +28,100 @@ from ._bcrypt import ( __summary__, __title__, __uri__, - checkpw, - gensalt, - hashpw, - kdf, -) -from ._bcrypt import ( - __version_ex__ as __version__, + __version__, ) +from . 
import _bcrypt # noqa: I100 + __all__ = [ - "__author__", - "__copyright__", - "__email__", - "__license__", - "__summary__", "__title__", + "__summary__", "__uri__", "__version__", - "checkpw", + "__author__", + "__email__", + "__license__", + "__copyright__", "gensalt", "hashpw", "kdf", + "checkpw", ] + + +def gensalt(rounds: int = 12, prefix: bytes = b"2b") -> bytes: + if prefix not in (b"2a", b"2b"): + raise ValueError("Supported prefixes are b'2a' or b'2b'") + + if rounds < 4 or rounds > 31: + raise ValueError("Invalid rounds") + + salt = os.urandom(16) + output = _bcrypt.encode_base64(salt) + + return ( + b"$" + + prefix + + b"$" + + ("%2.2u" % rounds).encode("ascii") + + b"$" + + output + ) + + +def hashpw(password: bytes, salt: bytes) -> bytes: + if isinstance(password, str) or isinstance(salt, str): + raise TypeError("Strings must be encoded before hashing") + + # bcrypt originally suffered from a wraparound bug: + # http://www.openwall.com/lists/oss-security/2012/01/02/4 + # This bug was corrected in the OpenBSD source by truncating inputs to 72 + # bytes on the updated prefix $2b$, but leaving $2a$ unchanged for + # compatibility. However, pyca/bcrypt 2.0.0 *did* correctly truncate inputs + # on $2a$, so we do it here to preserve compatibility with 2.0.0 + password = password[:72] + + return _bcrypt.hashpass(password, salt) + + +def checkpw(password: bytes, hashed_password: bytes) -> bool: + if isinstance(password, str) or isinstance(hashed_password, str): + raise TypeError("Strings must be encoded before checking") + + ret = hashpw(password, hashed_password) + return hmac.compare_digest(ret, hashed_password) + + +def kdf( + password: bytes, + salt: bytes, + desired_key_bytes: int, + rounds: int, + ignore_few_rounds: bool = False, +) -> bytes: + if isinstance(password, str) or isinstance(salt, str): + raise TypeError("Strings must be encoded before hashing") + + if len(password) == 0 or len(salt) == 0: + raise ValueError("password and salt must not be empty") + + if desired_key_bytes <= 0 or desired_key_bytes > 512: + raise ValueError("desired_key_bytes must be 1-512") + + if rounds < 1: + raise ValueError("rounds must be 1 or more") + + if rounds < 50 and not ignore_few_rounds: + # They probably think bcrypt.kdf()'s rounds parameter is logarithmic, + # expecting this value to be slow enough (it probably would be if this + # were bcrypt). Emit a warning. + warnings.warn( + ( + "Warning: bcrypt.kdf() called with only {0} round(s). " + "This few is not secure: the parameter is linear, like PBKDF2." + ).format(rounds), + UserWarning, + stacklevel=2, + ) + + return _bcrypt.pbkdf(password, salt, rounds, desired_key_bytes) diff --git a/venv/lib/python3.12/site-packages/bcrypt/__init__.pyi b/venv/lib/python3.12/site-packages/bcrypt/__init__.pyi deleted file mode 100644 index 12e4a2e..0000000 --- a/venv/lib/python3.12/site-packages/bcrypt/__init__.pyi +++ /dev/null @@ -1,10 +0,0 @@ -def gensalt(rounds: int = 12, prefix: bytes = b"2b") -> bytes: ... -def hashpw(password: bytes, salt: bytes) -> bytes: ... -def checkpw(password: bytes, hashed_password: bytes) -> bool: ... -def kdf( - password: bytes, - salt: bytes, - desired_key_bytes: int, - rounds: int, - ignore_few_rounds: bool = False, -) -> bytes: ... 
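
The bcrypt 4.0.1 ``__init__.py`` restored in the hunk above implements ``gensalt``, ``hashpw``, ``checkpw`` and ``kdf`` in Python on top of the ``_bcrypt`` extension module (note that this ``hashpw`` silently truncates passwords to 72 bytes, whereas the 5.0.0 release removed by this diff raises ``ValueError``). A minimal usage sketch of that API, assuming only the functions shown above; the password and salt literals are illustrative:

.. code:: python

    import bcrypt

    password = b"correct horse battery staple"   # must be bytes, not str
    hashed = bcrypt.hashpw(password, bcrypt.gensalt(rounds=12))

    # checkpw re-hashes the candidate against the salt embedded in the stored
    # hash and compares the results in constant time (hmac.compare_digest).
    assert bcrypt.checkpw(password, hashed)
    assert not bcrypt.checkpw(b"wrong guess", hashed)

    # kdf's rounds parameter is linear, not logarithmic; values below 50
    # trigger the UserWarning defined in the wrapper above.
    key = bcrypt.kdf(password, salt=b"0123456789abcdef",
                     desired_key_bytes=32, rounds=100)
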
diff --git a/venv/lib/python3.12/site-packages/bcrypt/_bcrypt.abi3.so b/venv/lib/python3.12/site-packages/bcrypt/_bcrypt.abi3.so index 4806fec..5651953 100755 Binary files a/venv/lib/python3.12/site-packages/bcrypt/_bcrypt.abi3.so and b/venv/lib/python3.12/site-packages/bcrypt/_bcrypt.abi3.so differ diff --git a/venv/lib/python3.12/site-packages/bcrypt/_bcrypt.pyi b/venv/lib/python3.12/site-packages/bcrypt/_bcrypt.pyi new file mode 100644 index 0000000..640e913 --- /dev/null +++ b/venv/lib/python3.12/site-packages/bcrypt/_bcrypt.pyi @@ -0,0 +1,7 @@ +import typing + +def encode_base64(data: bytes) -> bytes: ... +def hashpass(password: bytes, salt: bytes) -> bytes: ... +def pbkdf( + password: bytes, salt: bytes, rounds: int, desired_key_bytes: int +) -> bytes: ... diff --git a/venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/INSTALLER diff --git a/venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/licenses/LICENSE b/venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/LICENSE similarity index 100% rename from venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/licenses/LICENSE rename to venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/LICENSE diff --git a/venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/METADATA b/venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/METADATA similarity index 74% rename from venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/METADATA rename to venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/METADATA index 6322f88..cab80e6 100644 --- a/venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/METADATA +++ b/venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/METADATA @@ -1,6 +1,6 @@ -Metadata-Version: 2.4 +Metadata-Version: 2.1 Name: celery -Version: 5.5.3 +Version: 5.3.4 Summary: Distributed Task Queue. 
Home-page: https://docs.celeryq.dev/ Author: Ask Solem @@ -24,136 +24,110 @@ Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Classifier: Programming Language :: Python :: 3.11 -Classifier: Programming Language :: Python :: 3.12 -Classifier: Programming Language :: Python :: 3.13 Classifier: Programming Language :: Python :: Implementation :: CPython Classifier: Programming Language :: Python :: Implementation :: PyPy Classifier: Operating System :: OS Independent Requires-Python: >=3.8 License-File: LICENSE -Requires-Dist: billiard<5.0,>=4.2.1 -Requires-Dist: kombu<5.6,>=5.5.2 -Requires-Dist: vine<6.0,>=5.1.0 -Requires-Dist: click<9.0,>=8.1.2 -Requires-Dist: click-didyoumean>=0.3.0 -Requires-Dist: click-repl>=0.2.0 -Requires-Dist: click-plugins>=1.1.1 -Requires-Dist: backports.zoneinfo[tzdata]>=0.2.1; python_version < "3.9" -Requires-Dist: python-dateutil>=2.8.2 +Requires-Dist: billiard (<5.0,>=4.1.0) +Requires-Dist: kombu (<6.0,>=5.3.2) +Requires-Dist: vine (<6.0,>=5.0.0) +Requires-Dist: click (<9.0,>=8.1.2) +Requires-Dist: click-didyoumean (>=0.3.0) +Requires-Dist: click-repl (>=0.2.0) +Requires-Dist: click-plugins (>=1.1.1) +Requires-Dist: tzdata (>=2022.7) +Requires-Dist: python-dateutil (>=2.8.2) +Requires-Dist: importlib-metadata (>=3.6) ; python_version < "3.8" +Requires-Dist: backports.zoneinfo (>=0.2.1) ; python_version < "3.9" Provides-Extra: arangodb -Requires-Dist: pyArango>=2.0.2; extra == "arangodb" +Requires-Dist: pyArango (>=2.0.2) ; extra == 'arangodb' Provides-Extra: auth -Requires-Dist: cryptography==44.0.2; extra == "auth" +Requires-Dist: cryptography (==41.0.3) ; extra == 'auth' Provides-Extra: azureblockblob -Requires-Dist: azure-storage-blob>=12.15.0; extra == "azureblockblob" -Requires-Dist: azure-identity>=1.19.0; extra == "azureblockblob" +Requires-Dist: azure-storage-blob (>=12.15.0) ; extra == 'azureblockblob' Provides-Extra: brotli -Requires-Dist: brotlipy>=0.7.0; platform_python_implementation == "PyPy" and extra == "brotli" -Requires-Dist: brotli>=1.0.0; platform_python_implementation == "CPython" and extra == "brotli" +Requires-Dist: brotli (>=1.0.0) ; (platform_python_implementation == "CPython") and extra == 'brotli' +Requires-Dist: brotlipy (>=0.7.0) ; (platform_python_implementation == "PyPy") and extra == 'brotli' Provides-Extra: cassandra -Requires-Dist: cassandra-driver<4,>=3.25.0; extra == "cassandra" +Requires-Dist: cassandra-driver (<4,>=3.25.0) ; extra == 'cassandra' Provides-Extra: consul -Requires-Dist: python-consul2==0.1.5; extra == "consul" +Requires-Dist: python-consul2 (==0.1.5) ; extra == 'consul' Provides-Extra: cosmosdbsql -Requires-Dist: pydocumentdb==2.3.5; extra == "cosmosdbsql" +Requires-Dist: pydocumentdb (==2.3.5) ; extra == 'cosmosdbsql' Provides-Extra: couchbase -Requires-Dist: couchbase>=3.0.0; (platform_python_implementation != "PyPy" and (platform_system != "Windows" or python_version < "3.10")) and extra == "couchbase" +Requires-Dist: couchbase (>=3.0.0) ; (platform_python_implementation != "PyPy" and (platform_system != "Windows" or python_version < "3.10")) and extra == 'couchbase' Provides-Extra: couchdb -Requires-Dist: pycouchdb==1.16.0; extra == "couchdb" +Requires-Dist: pycouchdb (==1.14.2) ; extra == 'couchdb' Provides-Extra: django -Requires-Dist: Django>=2.2.28; extra == "django" +Requires-Dist: Django (>=2.2.28) ; extra == 'django' Provides-Extra: dynamodb -Requires-Dist: boto3>=1.26.143; extra == "dynamodb" 
+Requires-Dist: boto3 (>=1.26.143) ; extra == 'dynamodb' Provides-Extra: elasticsearch -Requires-Dist: elasticsearch<=8.17.2; extra == "elasticsearch" -Requires-Dist: elastic-transport<=8.17.1; extra == "elasticsearch" +Requires-Dist: elasticsearch (<8.0) ; extra == 'elasticsearch' Provides-Extra: eventlet -Requires-Dist: eventlet>=0.32.0; python_version < "3.10" and extra == "eventlet" +Requires-Dist: eventlet (>=0.32.0) ; (python_version < "3.10") and extra == 'eventlet' Provides-Extra: gevent -Requires-Dist: gevent>=1.5.0; extra == "gevent" -Provides-Extra: gcs -Requires-Dist: google-cloud-storage>=2.10.0; extra == "gcs" -Requires-Dist: google-cloud-firestore==2.20.1; extra == "gcs" -Requires-Dist: grpcio==1.67.0; extra == "gcs" +Requires-Dist: gevent (>=1.5.0) ; extra == 'gevent' Provides-Extra: librabbitmq -Requires-Dist: librabbitmq>=2.0.0; python_version < "3.11" and extra == "librabbitmq" +Requires-Dist: librabbitmq (>=2.0.0) ; (python_version < "3.11") and extra == 'librabbitmq' Provides-Extra: memcache -Requires-Dist: pylibmc==1.6.3; platform_system != "Windows" and extra == "memcache" +Requires-Dist: pylibmc (==1.6.3) ; (platform_system != "Windows") and extra == 'memcache' Provides-Extra: mongodb -Requires-Dist: kombu[mongodb]; extra == "mongodb" +Requires-Dist: pymongo[srv] (>=4.0.2) ; extra == 'mongodb' Provides-Extra: msgpack -Requires-Dist: kombu[msgpack]; extra == "msgpack" +Requires-Dist: msgpack (==1.0.5) ; extra == 'msgpack' Provides-Extra: pymemcache -Requires-Dist: python-memcached>=1.61; extra == "pymemcache" -Provides-Extra: pydantic -Requires-Dist: pydantic>=2.4; extra == "pydantic" +Requires-Dist: python-memcached (==1.59) ; extra == 'pymemcache' Provides-Extra: pyro -Requires-Dist: pyro4==4.82; python_version < "3.11" and extra == "pyro" +Requires-Dist: pyro4 (==4.82) ; (python_version < "3.11") and extra == 'pyro' Provides-Extra: pytest -Requires-Dist: pytest-celery[all]<1.3.0,>=1.2.0; extra == "pytest" +Requires-Dist: pytest-celery (==0.0.0) ; extra == 'pytest' Provides-Extra: redis -Requires-Dist: kombu[redis]; extra == "redis" +Requires-Dist: redis (!=4.5.5,<5.0.0,>=4.5.2) ; extra == 'redis' Provides-Extra: s3 -Requires-Dist: boto3>=1.26.143; extra == "s3" +Requires-Dist: boto3 (>=1.26.143) ; extra == 's3' Provides-Extra: slmq -Requires-Dist: softlayer_messaging>=1.0.3; extra == "slmq" +Requires-Dist: softlayer-messaging (>=1.0.3) ; extra == 'slmq' Provides-Extra: solar -Requires-Dist: ephem==4.2; platform_python_implementation != "PyPy" and extra == "solar" +Requires-Dist: ephem (==4.1.4) ; (platform_python_implementation != "PyPy") and extra == 'solar' Provides-Extra: sqlalchemy -Requires-Dist: kombu[sqlalchemy]; extra == "sqlalchemy" +Requires-Dist: sqlalchemy (<2.1,>=1.4.48) ; extra == 'sqlalchemy' Provides-Extra: sqs -Requires-Dist: boto3>=1.26.143; extra == "sqs" -Requires-Dist: urllib3>=1.26.16; extra == "sqs" -Requires-Dist: kombu[sqs]>=5.5.0; extra == "sqs" +Requires-Dist: boto3 (>=1.26.143) ; extra == 'sqs' +Requires-Dist: urllib3 (>=1.26.16) ; extra == 'sqs' +Requires-Dist: kombu[sqs] (>=5.3.0) ; extra == 'sqs' +Requires-Dist: pycurl (>=7.43.0.5) ; (sys_platform != "win32" and platform_python_implementation == "CPython") and extra == 'sqs' Provides-Extra: tblib -Requires-Dist: tblib>=1.5.0; python_version >= "3.8.0" and extra == "tblib" -Requires-Dist: tblib>=1.3.0; python_version < "3.8.0" and extra == "tblib" +Requires-Dist: tblib (>=1.3.0) ; (python_version < "3.8.0") and extra == 'tblib' +Requires-Dist: tblib (>=1.5.0) ; (python_version >= 
"3.8.0") and extra == 'tblib' Provides-Extra: yaml -Requires-Dist: kombu[yaml]; extra == "yaml" +Requires-Dist: PyYAML (>=3.10) ; extra == 'yaml' Provides-Extra: zookeeper -Requires-Dist: kazoo>=1.3.1; extra == "zookeeper" +Requires-Dist: kazoo (>=1.3.1) ; extra == 'zookeeper' Provides-Extra: zstd -Requires-Dist: zstandard==0.23.0; extra == "zstd" -Dynamic: author -Dynamic: author-email -Dynamic: classifier -Dynamic: description -Dynamic: home-page -Dynamic: keywords -Dynamic: license -Dynamic: license-file -Dynamic: platform -Dynamic: project-url -Dynamic: provides-extra -Dynamic: requires-dist -Dynamic: requires-python -Dynamic: summary +Requires-Dist: zstandard (==0.21.0) ; extra == 'zstd' .. image:: https://docs.celeryq.dev/en/latest/_images/celery-banner-small.png |build-status| |coverage| |license| |wheel| |semgrep| |pyversion| |pyimp| |ocbackerbadge| |ocsponsorbadge| -:Version: 5.5.3 (immunity) +:Version: 5.3.4 (emerald-rush) :Web: https://docs.celeryq.dev/en/stable/index.html :Download: https://pypi.org/project/celery/ :Source: https://github.com/celery/celery/ -:DeepWiki: |deepwiki| :Keywords: task, queue, job, async, rabbitmq, amqp, redis, python, distributed, actors Donations ========= -Open Collective ---------------- +This project relies on your generous donations. -.. image:: https://opencollective.com/static/images/opencollectivelogo-footer-n.svg - :alt: Open Collective logo - :width: 200px +If you are using Celery to create a commercial product, please consider becoming our `backer`_ or our `sponsor`_ to ensure Celery's future. -`Open Collective `_ is our community-powered funding platform that fuels Celery's -ongoing development. Your sponsorship directly supports improvements, maintenance, and innovative features that keep -Celery robust and reliable. +.. _`backer`: https://opencollective.com/celery#backer +.. _`sponsor`: https://opencollective.com/celery#sponsor For enterprise ============== @@ -162,47 +136,6 @@ Available as part of the Tidelift Subscription. The maintainers of ``celery`` and thousands of other packages are working with Tidelift to deliver commercial support and maintenance for the open source dependencies you use to build your applications. Save time, reduce risk, and improve code health, while paying the maintainers of the exact dependencies you use. `Learn more. `_ -Sponsors -======== - -Blacksmith ----------- - -.. image:: ./docs/images/blacksmith-logo-white-on-black.svg - :alt: Blacksmith logo - :width: 240px - -`Official Announcement `_ - -Upstash -------- - -.. image:: https://upstash.com/logo/upstash-dark-bg.svg - :alt: Upstash logo - :width: 200px - -`Upstash `_ offers a serverless Redis database service, -providing a seamless solution for Celery users looking to leverage -serverless architectures. Upstash's serverless Redis service is designed -with an eventual consistency model and durable storage, facilitated -through a multi-tier storage architecture. - -Dragonfly ---------- - -.. image:: https://github.com/celery/celery/raw/main/docs/images/dragonfly.svg - :alt: Dragonfly logo - :width: 150px - -`Dragonfly `_ is a drop-in Redis replacement that cuts costs and boosts performance. -Designed to fully utilize the power of modern cloud hardware and deliver on the data demands of modern applications, -Dragonfly frees developers from the limits of traditional in-memory data stores. - - - -.. 
|oc-sponsor-1| image:: https://opencollective.com/celery/sponsor/0/avatar.svg - :target: https://opencollective.com/celery/sponsor/0/website - What's a Task Queue? ==================== @@ -235,10 +168,10 @@ in such a way that the client enqueues an URL to be requested by a worker. What do I need? =============== -Celery version 5.5.x runs on: +Celery version 5.3.4 runs on: -- Python (3.8, 3.9, 3.10, 3.11, 3.12, 3.13) -- PyPy3.9+ (v7.3.12+) +- Python (3.8, 3.9, 3.10, 3.11) +- PyPy3.8+ (v7.3.11+) This is the version of celery which will support Python 3.8 or newer. @@ -269,7 +202,7 @@ Get Started =========== If this is the first time you're trying to use Celery, or you're -new to Celery v5.5.x coming from previous versions then you should read our +new to Celery v5.3.4 coming from previous versions then you should read our getting started tutorials: - `First steps with Celery`_ @@ -333,7 +266,7 @@ It supports... - **Message Transports** - - RabbitMQ_, Redis_, Amazon SQS, Google Pub/Sub + - RabbitMQ_, Redis_, Amazon SQS - **Concurrency** @@ -345,7 +278,6 @@ It supports... - memcached - SQLAlchemy, Django ORM - Apache Cassandra, IronCache, Elasticsearch - - Google Cloud Storage - **Serialization** @@ -379,8 +311,6 @@ integration packages: +--------------------+------------------------+ | `Tornado`_ | `tornado-celery`_ | +--------------------+------------------------+ - | `FastAPI`_ | not needed | - +--------------------+------------------------+ The integration packages aren't strictly necessary, but they can make development easier, and sometimes they add important hooks like closing @@ -397,7 +327,6 @@ database connections at ``fork``. .. _`web2py-celery`: https://code.google.com/p/web2py-celery/ .. _`Tornado`: https://www.tornadoweb.org/ .. _`tornado-celery`: https://github.com/mher/tornado-celery/ -.. _`FastAPI`: https://fastapi.tiangolo.com/ .. _celery-documentation: @@ -407,6 +336,8 @@ Documentation The `latest documentation`_ is hosted at Read The Docs, containing user guides, tutorials, and an API reference. +最新的中文文档托管在 https://www.celerycn.io/ 中,包含用户指南、教程、API接口等。 + .. _`latest documentation`: https://docs.celeryq.dev/en/latest/ .. _celery-installation: @@ -496,9 +427,6 @@ Transports and Backends :``celery[s3]``: for using S3 Storage as a result backend. -:``celery[gcs]``: - for using Google Cloud Storage as a result backend. - :``celery[couchbase]``: for using Couchbase as a result backend. @@ -535,10 +463,6 @@ Transports and Backends You should probably not use this in your requirements, it's here for informational purposes only. -:``celery[gcpubsub]``: - for using Google Pub/Sub as a message transport. - - .. _celery-installing-from-source: @@ -668,6 +592,19 @@ Thank you to all our backers! 🙏 [`Become a backer`_] .. |oc-backers| image:: https://opencollective.com/celery/backers.svg?width=890 :target: https://opencollective.com/celery#backers +Sponsors +-------- + +Support this project by becoming a sponsor. Your logo will show up here with a +link to your website. [`Become a sponsor`_] + +.. _`Become a sponsor`: https://opencollective.com/celery#sponsor + +|oc-sponsors| + +.. |oc-sponsors| image:: https://opencollective.com/celery/sponsor/0/avatar.svg + :target: https://opencollective.com/celery/sponsor/0/website + .. _license: License @@ -716,8 +653,3 @@ file in the top distribution directory for the full license text. .. |downloads| image:: https://pepy.tech/badge/celery :alt: Downloads :target: https://pepy.tech/project/celery - -.. 
|deepwiki| image:: https://devin.ai/assets/deepwiki-badge.png - :alt: Ask http://DeepWiki.com - :target: https://deepwiki.com/celery/celery - :width: 125px diff --git a/venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/RECORD b/venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/RECORD similarity index 71% rename from venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/RECORD rename to venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/RECORD index cd077eb..3ee987c 100644 --- a/venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/RECORD +++ b/venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/RECORD @@ -1,13 +1,13 @@ -../../../bin/celery,sha256=lKonuVsJ65W3NAxVZLFxKbSmaI38YHsNYMByAh7uwuw,235 -celery-5.5.3.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -celery-5.5.3.dist-info/METADATA,sha256=0LXMLl9irDLbUsh7Ot_bnv4HOw8PZlC-Ow3BYsuN8zY,22953 -celery-5.5.3.dist-info/RECORD,, -celery-5.5.3.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -celery-5.5.3.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91 -celery-5.5.3.dist-info/entry_points.txt,sha256=FkfFPVffdhqvYOPHkpE85ki09ni0e906oNdWLdN7z_Q,48 -celery-5.5.3.dist-info/licenses/LICENSE,sha256=w1jN938ou6tQ1KdU4SMRgznBUjA0noK_Zkic7OOsCTo,2717 -celery-5.5.3.dist-info/top_level.txt,sha256=sQQ-a5HNsZIi2A8DiKQnB1HODFMfmrzIAZIE8t_XiOA,7 -celery/__init__.py,sha256=W4mGD3BD5qK5bwfyMUd9RqkMyE41_4BALYzmnEnPD_M,5945 +../../../bin/celery,sha256=0OpNT6_Y6Sx11-WiH9_RMtvjbzkIK39w-IsUCkdOmRA,239 +celery-5.3.4.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +celery-5.3.4.dist-info/LICENSE,sha256=w1jN938ou6tQ1KdU4SMRgznBUjA0noK_Zkic7OOsCTo,2717 +celery-5.3.4.dist-info/METADATA,sha256=VwAVQZ0Kl2NxLaXXqYf8PcnptX9fakvtAmI2xHeTqdo,21051 +celery-5.3.4.dist-info/RECORD,, +celery-5.3.4.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +celery-5.3.4.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92 +celery-5.3.4.dist-info/entry_points.txt,sha256=FkfFPVffdhqvYOPHkpE85ki09ni0e906oNdWLdN7z_Q,48 +celery-5.3.4.dist-info/top_level.txt,sha256=sQQ-a5HNsZIi2A8DiKQnB1HODFMfmrzIAZIE8t_XiOA,7 +celery/__init__.py,sha256=N18V32hIC7cyR2Wp-uucng-ZXRTBlbBqrANrslxVudE,5949 celery/__main__.py,sha256=0iT3WCc80mA88XhdAxTpt_g6TFRgmwHSc9GG-HiPzkE,409 celery/__pycache__/__init__.cpython-312.pyc,, celery/__pycache__/__main__.cpython-312.pyc,, @@ -40,21 +40,21 @@ celery/app/__pycache__/routes.cpython-312.pyc,, celery/app/__pycache__/task.cpython-312.pyc,, celery/app/__pycache__/trace.cpython-312.pyc,, celery/app/__pycache__/utils.cpython-312.pyc,, -celery/app/amqp.py,sha256=jlXBDiFRZqJqu4r2YlSQVzTQM_TkpST82WPBPshZ-nE,23582 +celery/app/amqp.py,sha256=SWV-lr5zv1PJjGMyWQZlbJ0ToaQrzfIpZdOYEaGWgqs,23151 celery/app/annotations.py,sha256=93zuKNCE7pcMD3K5tM5HMeVCQ5lfJR_0htFpottgOeU,1445 celery/app/autoretry.py,sha256=PfSi8sb77jJ57ler-Y5ffdqDWvHMKFgQ_bpVD5937tc,2506 -celery/app/backends.py,sha256=lOQJcKva66fNqfYBuDAcCZIpbHGNKbqsE_hLlB_XdnA,2746 -celery/app/base.py,sha256=nn54l1hjtlXQBqIhHKSBNTZqKhIO961xSklMyyV4Xfw,55932 +celery/app/backends.py,sha256=__GqdylFJSa9G_JDSdXdsygfe7FjK7fgn4fZgetdUMw,2702 +celery/app/base.py,sha256=o68aTkvYf8JoYQWl7j3vtXAP5CiPK4Iwh-5MKgVXRmo,50088 celery/app/builtins.py,sha256=gnOyE07M8zgxatTmb0D0vKztx1sQZaRi_hO_d-FLNUs,6673 -celery/app/control.py,sha256=iWy_E2l1BWX8WtxA5OoW2QtHOrJIJL7OIukkEh85CTo,29231 -celery/app/defaults.py,sha256=Hbcck1I99lT8cLdh-JACZQUDCeYrbV_gPIj9sClEaWg,15647 
+celery/app/control.py,sha256=La-b_hQGnyWxoM5PIMr-aIzeyasRKkfNJXRvznMHjjk,29170 +celery/app/defaults.py,sha256=XzImSLArwDREJWJbgt1bDz-Cgdxtq9cBfSixa85IQ0Y,15014 celery/app/events.py,sha256=9ZyjdhUVvrt6xLdOMOVTPN7gjydLWQGNr4hvFoProuA,1326 -celery/app/log.py,sha256=pSW4hbrH6M_e1CNXYQ8Dxkst7XM5JzfBJvM8R9QnlJQ,9102 +celery/app/log.py,sha256=uAlmoLQH347P1WroX13J2XolenmcyBIi2a-aD6kMnZk,9067 celery/app/registry.py,sha256=imdGUFb9CS4iiZ1pxAwcQAbe1JKKjyv9WTy94qHHQvk,2001 -celery/app/routes.py,sha256=phoACykZ3ESCNXh5X1oAwQwilGu-0wp5TUi_cahogx8,4551 -celery/app/task.py,sha256=ySJ9-7mkb8PSkFw3JMQBL3W12vzYSvR5utCW_JGdIBE,44274 -celery/app/trace.py,sha256=w0qM9MGHeJzOFJREq5m4obPJ6D5hCMHAJwmGdB6n8PM,27551 -celery/app/utils.py,sha256=eZG28T4SMQNUOWpNVQHFozOzvVmYvU9ST9etbhCQXrg,13171 +celery/app/routes.py,sha256=DMdr5nmEnqJWXkLFIzWWxM2sz9ZYeA--8FeSaxKcBCg,4527 +celery/app/task.py,sha256=4bknTqa3yZ_0VFVb_aX9glA3YwCmpAP1KzCOV2x7p6A,43278 +celery/app/trace.py,sha256=cblXI8oJIU_CmJYvvES6BzcRsW9t6NguQuzDmOzdKWY,28434 +celery/app/utils.py,sha256=52e5u-PUJbwEHtNr_XdpJNnuHdC9c2q6FPkiBu_1SmY,13160 celery/apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 celery/apps/__pycache__/__init__.cpython-312.pyc,, celery/apps/__pycache__/beat.cpython-312.pyc,, @@ -62,7 +62,7 @@ celery/apps/__pycache__/multi.cpython-312.pyc,, celery/apps/__pycache__/worker.cpython-312.pyc,, celery/apps/beat.py,sha256=BX7NfHO_BYy9OuVTcSnyrOTVS1eshFctHDpYGfgKT5A,5724 celery/apps/multi.py,sha256=1pujkm0isInjAR9IHno5JucuWcwZAJ1mtqJU1DVkJQo,16360 -celery/apps/worker.py,sha256=o2nJ53_rmgYtlzgW7dTL63InMqaOQjTAZwLT1YGzh0U,20297 +celery/apps/worker.py,sha256=B1_uXLtclcrQAVHupd9B8pXubk4TCOIytGbWIsEioeQ,13208 celery/backends/__init__.py,sha256=1kN92df1jDp3gC6mrGEZI2eE-kOEUIKdOOHRAdry2a0,23 celery/backends/__pycache__/__init__.cpython-312.pyc,, celery/backends/__pycache__/arangodb.cpython-312.pyc,, @@ -78,36 +78,34 @@ celery/backends/__pycache__/couchdb.cpython-312.pyc,, celery/backends/__pycache__/dynamodb.cpython-312.pyc,, celery/backends/__pycache__/elasticsearch.cpython-312.pyc,, celery/backends/__pycache__/filesystem.cpython-312.pyc,, -celery/backends/__pycache__/gcs.cpython-312.pyc,, celery/backends/__pycache__/mongodb.cpython-312.pyc,, celery/backends/__pycache__/redis.cpython-312.pyc,, celery/backends/__pycache__/rpc.cpython-312.pyc,, celery/backends/__pycache__/s3.cpython-312.pyc,, celery/backends/arangodb.py,sha256=aMwuBglVJxigWN8L9NWh-q2NjPQegw__xgRcTMLf5eU,5937 celery/backends/asynchronous.py,sha256=1_tCrURDVg0FvZhRzlRGYwTmsdWK14nBzvPulhwJeR4,10309 -celery/backends/azureblockblob.py,sha256=vMg80FGC1hRQhYYGHIjlFi_Qa8Fb3ktt0xP_vkH5LzQ,6071 -celery/backends/base.py,sha256=w2UPVsGasypjCd4rdGkOo9blIsoTZWrhuPuaWg_nfYQ,44038 +celery/backends/azureblockblob.py,sha256=7jbjTmChq_uJlvzg06dp9q9-sMHKuS0Z3LyjXjgycdk,5127 +celery/backends/base.py,sha256=A4rgCmGvCjlLqfJGuQydE4Dft9WGUfKTqa79FAIUAsk,43970 celery/backends/cache.py,sha256=_o9EBmBByNsbI_UF-PJ5W0u-qwcJ37Q5jaIrApPO4q8,4831 -celery/backends/cassandra.py,sha256=QkXkaYShcf34jBrXe_JJfzx1cj8uXoSRTOAc49cw3Jk,9014 +celery/backends/cassandra.py,sha256=xB5z3JtNqmnaQY8bjst-PR1dnNgZrX8lKwEQpYiRhv8,9006 celery/backends/consul.py,sha256=oAB_94ftS95mjycQ4YL4zIdA-tGmwFyq3B0OreyBPNQ,3816 celery/backends/cosmosdbsql.py,sha256=XdCVCjxO71XhsgiM9DueJngmKx_tE0erexHf37-JhqE,6777 celery/backends/couchbase.py,sha256=fyyihfJNW6hWgVlHKuTCHkzWlDjkzWQAWhgW3GJzAds,3393 celery/backends/couchdb.py,sha256=M_z0zgNFPwFw89paa5kIQ9x9o7VRPwuKCLZgoFhFDpA,2935 
-celery/backends/database/__init__.py,sha256=NBdfiaYwWxpGlcP-baWnr18r3leH_b4OW_QsbJMYpSo,8133 +celery/backends/database/__init__.py,sha256=GMBZQy0B1igxHOXP-YoYKkr0FOuxAwesYi6MFz8wRdQ,7751 celery/backends/database/__pycache__/__init__.cpython-312.pyc,, celery/backends/database/__pycache__/models.cpython-312.pyc,, celery/backends/database/__pycache__/session.cpython-312.pyc,, -celery/backends/database/models.py,sha256=j9e_XbXgLfUcRofbhGkVjrVgYQg5UY08vDQ6jmWIk7M,3394 +celery/backends/database/models.py,sha256=_6WZMv53x8I1iBRCa4hY35LaBUeLIZJzDusjvS-8aAg,3351 celery/backends/database/session.py,sha256=3zu7XwYoE52aS6dsSmJanqlvS6ssjet7hSNUbliwnLo,3011 -celery/backends/dynamodb.py,sha256=DGMQ3LbwgZDIm7bp-8_B4QzgvBSR9KS1VNi6piSrLJM,19580 -celery/backends/elasticsearch.py,sha256=26c6z6X08p69cue6-WoQHJNY71Xmq6voaAx3GQ79Vgw,9582 -celery/backends/filesystem.py,sha256=dmxlaTUZP62r2QDCi2n6-7EaPBBSwJWhUPpd2IRmqf0,3777 -celery/backends/gcs.py,sha256=U_ayh1uIR8J_v5nGR9wEeq-80OesKjoeOW4YBrXpJiU,12411 -celery/backends/mongodb.py,sha256=iCeU6WusM7tDm0LHf_3nU7Xn_FQ7r4Xm0FGRzyIqFu0,11438 -celery/backends/redis.py,sha256=d5lTIivhaPqi2ZFX9WQx0YVR4MKx01mWcKNK5BqwBHI,26531 -celery/backends/rpc.py,sha256=3hFLwM_-uAXwZfzDRP5nGVWX4v-w9D0KvyWASdbcbBI,12077 +celery/backends/dynamodb.py,sha256=sEb4TOcrEFOvFU19zRSmXZ-taNDJgbb0_R-4KpNRgcg,17179 +celery/backends/elasticsearch.py,sha256=nseWGjMB49OkHn4LbZLjlo2GLSoHCZOFObklrFsWNW4,8319 +celery/backends/filesystem.py,sha256=Q-8RCPG7TaDVJEOnwMfS8Ggygc8BYcKuBljwzwOegec,3776 +celery/backends/mongodb.py,sha256=XIL1oYEao-YpbmE0CB_sGYP_FJnSP8_CZNouBicxcrg,11419 +celery/backends/redis.py,sha256=wnl45aMLf4SSmX2JDEiFIlnNaKY3I6PBjJeL7adEuCA,26389 +celery/backends/rpc.py,sha256=Pfzjpz7znOfmHRERuQfOlTW-entAsl803oc1-EWpnTY,12077 celery/backends/s3.py,sha256=MUL4-bEHCcTL53XXyb020zyLYTr44DDjOh6BXtkp9lQ,2752 -celery/beat.py,sha256=sIXY81GRrSMcwfgvWCxE4pxandh-XBhReCXvjKOk42o,24544 +celery/beat.py,sha256=j_ZEA73B7NWvlGVbXVcLeOq_tFk0JNT4HiAVdvH7HG4,24455 celery/bin/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 celery/bin/__pycache__/__init__.cpython-312.pyc,, celery/bin/__pycache__/amqp.cpython-312.pyc,, @@ -128,11 +126,11 @@ celery/bin/__pycache__/shell.cpython-312.pyc,, celery/bin/__pycache__/upgrade.cpython-312.pyc,, celery/bin/__pycache__/worker.cpython-312.pyc,, celery/bin/amqp.py,sha256=LTO0FZzKs2Z0MBxkccaDG-dQEsmbaLLhKp-0gR4HdQA,10023 -celery/bin/base.py,sha256=yK_iZpyKbwZQ9ciLRzcrkasw9I-GRa_YB_2EVxK11To,9174 +celery/bin/base.py,sha256=mmF-aIFRXOBdjczGFePXORK2YdxLI-cpsnVrDcNSmAw,8525 celery/bin/beat.py,sha256=qijjERLGEHITaVSGkFgxTxtPYOwl0LUANkC2s2UmNAk,2592 celery/bin/call.py,sha256=_4co_yn2gM5uGP77FjeVqfa7w6VmrEDGSCLPSXYRp-w,2370 -celery/bin/celery.py,sha256=80j70fqa-1TcAYwMN4eysk7fevTqbDy2kx5GNApDxoU,7595 -celery/bin/control.py,sha256=grohiNzi7AQ9l1T9Eed36eU7TKwF2llAs0Cl8VnI8aU,8645 +celery/bin/celery.py,sha256=UW5KmKDphrt7SpyGLnZY16fc6_XI6BdSVdrxb_Vvi3U,7440 +celery/bin/control.py,sha256=nr_kFxalRvKqC2pgJmQVNmRxktnqfStlpRM51I9pXS4,7058 celery/bin/events.py,sha256=fDemvULNVhgG7WiGC-nRnX3yDy4eXTaq8he7T4mD6Jk,2794 celery/bin/graph.py,sha256=Ld2dKSxIdWHxFXrjsTXAUBj6jb02AVGyTPXDUZA_gvo,5796 celery/bin/list.py,sha256=2OKPiXn6sgum_02RH1d_TBoXcpNcNsooT98Ht9pWuaY,1058 @@ -145,7 +143,7 @@ celery/bin/shell.py,sha256=D4Oiw9lEyF-xHJ3fJ5_XckgALDrsDTYlsycT1p4156E,4839 celery/bin/upgrade.py,sha256=EBzSm8hb0n6DXMzG5sW5vC4j6WHYbfrN2Fx83s30i1M,3064 celery/bin/worker.py,sha256=cdYBrO2P3HoNzuPwXIJH4GAMu1KlLTEYF40EkVu0veo,12886 
celery/bootsteps.py,sha256=49bMT6CB0LPOK6-i8dLp7Hpko_WaLJ9yWlCWF3Ai2XI,12277 -celery/canvas.py,sha256=2pCVzN6OaLSRQXfm6LmcwDHn2ecCrl6fdd7pkTcSFxk,96992 +celery/canvas.py,sha256=O3S3p0p8K8m4kcy47h4n-hM92Ye9kg870aQEPzJYfXQ,95808 celery/concurrency/__init__.py,sha256=CivIIzjLWHEJf9Ed0QFSTCOxNaWpunFDTzC2jzw3yE0,1457 celery/concurrency/__pycache__/__init__.cpython-312.pyc,, celery/concurrency/__pycache__/asynpool.cpython-312.pyc,, @@ -155,10 +153,10 @@ celery/concurrency/__pycache__/gevent.cpython-312.pyc,, celery/concurrency/__pycache__/prefork.cpython-312.pyc,, celery/concurrency/__pycache__/solo.cpython-312.pyc,, celery/concurrency/__pycache__/thread.cpython-312.pyc,, -celery/concurrency/asynpool.py,sha256=xACoE2WAc05gSxJpljzoxnu-xjR_wBrys3rmCvpT1pk,51822 +celery/concurrency/asynpool.py,sha256=3hlvqZ99tHXzqZZglwoBAOHNbHZ8zVBWd9soWYQrro8,51471 celery/concurrency/base.py,sha256=atOLC90FY7who__TonZbpd2awbOinkgWSx3m15Mg1WI,4706 celery/concurrency/eventlet.py,sha256=i4Xn3Kqg0cxbMyw7_aCTVCi7EOA5aLEiRdkb1xMTpvM,5126 -celery/concurrency/gevent.py,sha256=fiPNf6a380aJOmarkcYSG9FJsSH0DGZS8EjWfIuAhz8,4953 +celery/concurrency/gevent.py,sha256=oExJqOLAWSlV2JlzNnDL22GPlwEpg7ExPJBZMNP4CC8,3387 celery/concurrency/prefork.py,sha256=vdnfeiUtnxa2ZcPSBB-pI6Mwqb2jm8dl-fH_XHPEo6M,5850 celery/concurrency/solo.py,sha256=H9ZaV-RxC30M1YUCjQvLnbDQCTLafwGyC4g4nwqz3uM,754 celery/concurrency/thread.py,sha256=rMpruen--ePsdPoqz9mDwswu5GY3avji_eG-7AAY53I,1807 @@ -170,10 +168,6 @@ celery/contrib/__pycache__/pytest.cpython-312.pyc,, celery/contrib/__pycache__/rdb.cpython-312.pyc,, celery/contrib/__pycache__/sphinx.cpython-312.pyc,, celery/contrib/abortable.py,sha256=ffr47ovGoIUO2gMMSrJwWPP6MSyk3_S1XuS02KxRMu4,5003 -celery/contrib/django/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -celery/contrib/django/__pycache__/__init__.cpython-312.pyc,, -celery/contrib/django/__pycache__/task.cpython-312.pyc,, -celery/contrib/django/task.py,sha256=2-CeHXNq4VRMgUoRMsRLnMFJ-yj2C2WB8nfSNNw58-o,727 celery/contrib/migrate.py,sha256=EvvNWhrykV3lTkZHOghofwemZ-_sixKG97XUyQbS9Dc,14361 celery/contrib/pytest.py,sha256=ztbqIZ0MuWRLTA-RT6k5BKVvuuk2-HPoFD9-q3uHo-s,6754 celery/contrib/rdb.py,sha256=BKorafe3KkOj-tt-bEL39R74u2njv-_7rRHfRajr3Ss,5005 @@ -189,7 +183,7 @@ celery/contrib/testing/app.py,sha256=lvW-YY2H18B60mA5SQetO3CzTI7jKQRsZXGthR27hxE celery/contrib/testing/manager.py,sha256=WnvWLdVJQfSap5rVSKO8NV2gBzWsczmi5Fr3Hp-85-4,8605 celery/contrib/testing/mocks.py,sha256=mcWdsxpTvaWkG-QBGnETLcdevl-bzaq3eSOSsGo2y6w,4182 celery/contrib/testing/tasks.py,sha256=pJM3aabw7udcppz4QNeUg1-6nlnbklrT-hP5JXmL-gM,208 -celery/contrib/testing/worker.py,sha256=RUDXaEaRng6_WD-rydaGziolGEBZ1zhiUiHdCR9DmLA,7217 +celery/contrib/testing/worker.py,sha256=91V-7MfPw7FZC5pBLwvNgJ_ykA5h1QO0DRV1Bu_nI7Q,7051 celery/events/__init__.py,sha256=9d2cviCw5zIsZ3AvQJkx77HPTlxmVIahRR7Qa54nQnU,477 celery/events/__pycache__/__init__.cpython-312.pyc,, celery/events/__pycache__/cursesmon.cpython-312.pyc,, @@ -202,7 +196,7 @@ celery/events/__pycache__/state.cpython-312.pyc,, celery/events/cursesmon.py,sha256=GfQQSJwaMKtZawPsvvQ6qGv7f613hMhAJspDa1hz9OM,17961 celery/events/dispatcher.py,sha256=7b3-3d_6ukvRNajyfiHMX1YvoWNIzaB6zS3-zEUQhG4,8987 celery/events/dumper.py,sha256=7zOVmAVfG2HXW79Fuvpo_0C2cjztTzgIXnaiUc4NL8c,3116 -celery/events/event.py,sha256=JiIqTm_if7OixGHw_RMCJZM3XkIVmmOXa0pdEA2gulA,1750 +celery/events/event.py,sha256=nt1yRUzDrYp9YLbsIJD3eo_AoMhT5sQtZAX-vEkq4Q8,1736 celery/events/receiver.py,sha256=7dVvezYkBQOtyI-rH77-5QDJztPLB933VF7NgmezSuU,4998 
celery/events/snapshot.py,sha256=OLQuxx1af29LKnYKDoTesnPfK_5dFx3zCZ7JSdg9t7A,3294 celery/events/state.py,sha256=DdYeAw7hGGFTMc4HRMb0MkizlkJryaysV3t8lXbxhD4,25648 @@ -210,35 +204,34 @@ celery/exceptions.py,sha256=FrlxQiodRtx0RrJfgQo5ZMYTJ8BShrJkteSH29TCUKM,9086 celery/fixups/__init__.py,sha256=7ctNaKHiOa2fVePcdKPU9J-_bQ0k1jFHaoZlCHXY0vU,14 celery/fixups/__pycache__/__init__.cpython-312.pyc,, celery/fixups/__pycache__/django.cpython-312.pyc,, -celery/fixups/django.py,sha256=hdjdpvdZ6v7sx52ri0oS7rIzxC7kMGIX9zOXPK1Lrd4,7427 +celery/fixups/django.py,sha256=Px_oC0wTednDePOV-B9ZokMJJbYAsKhgs0zSH5tKRXA,7161 celery/loaders/__init__.py,sha256=LnRTWk8pz2r7BUj2VUJiBstPjSBwCP0gUDRkbchGW24,490 celery/loaders/__pycache__/__init__.cpython-312.pyc,, celery/loaders/__pycache__/app.cpython-312.pyc,, celery/loaders/__pycache__/base.cpython-312.pyc,, celery/loaders/__pycache__/default.cpython-312.pyc,, celery/loaders/app.py,sha256=xqRpRDJkGmTW21N_7zx5F4Na-GCTbNs6Q6tGfInnZnU,199 -celery/loaders/base.py,sha256=bZ-SwMNLIwhPNxigNJTOukd21QoKNfM8sSRb2C_NWL8,9147 +celery/loaders/base.py,sha256=l2V-9ObaY-TQHSmmouLizOeqrTGtSq7Wvzl0CrPgVZs,8825 celery/loaders/default.py,sha256=TZq6zR4tg_20sVJAuSwSBLVRHRyfevHkHhUYrNRYkTU,1520 -celery/local.py,sha256=aTPsyEVONXA9g2Wt30j66HnlkFiIyud8RKusIQnZJ5I,16039 -celery/platforms.py,sha256=DDCGCp8yt6f_DrZPSiCjWbju2HJCsFWfk0ytSf-BDxA,25610 -celery/result.py,sha256=fBtnxntU8Qzsd8nk3ODIEyR3vtXDXO_SFCY8VimuIMI,35612 -celery/schedules.py,sha256=ATDKxf_yzojN5awmjpS1YkFk-wWCDCc60uBt7GBJO5s,33030 +celery/local.py,sha256=8iy7CIvQRZMw4958J0SjMHcVwW7AIbkaIpBztdS5wiQ,16087 +celery/platforms.py,sha256=CIpGvQoOTrtJluX3BThBvC0iZdj0vwXgCNiOuWVqar8,25290 +celery/result.py,sha256=r4mdMl2Bts3v-1ukZTKvYd1J1SzC6-7ug12SGi9_Gek,35529 +celery/schedules.py,sha256=g40h0m5_0JfM6Rc0CH7TjyK1MC3Cf6M2rDRmGkS8hxs,32003 celery/security/__init__.py,sha256=I1px-x5-19O-FcCQm1AHHfVB6Pp-bauwbZ-C1fxGJyc,2363 celery/security/__pycache__/__init__.cpython-312.pyc,, celery/security/__pycache__/certificate.cpython-312.pyc,, celery/security/__pycache__/key.cpython-312.pyc,, celery/security/__pycache__/serialization.cpython-312.pyc,, celery/security/__pycache__/utils.cpython-312.pyc,, -celery/security/certificate.py,sha256=lopB0DY2fn8uEWz780bqTXPtbEcJTL_OEcO_yeQZWRs,4030 +celery/security/certificate.py,sha256=Jm-XWVQpzJxB52n4V-zHKO3YsNrlkyFpXiYhzB3QJsk,4008 celery/security/key.py,sha256=NbocdV_aJjQMZs9DJZrStpTnkFZw_K8SICEMwalsPqI,1189 -celery/security/serialization.py,sha256=ZGK6MFpphQgue7Rl3XA0n14f91o-JvAXJBbJuTaANgc,3832 +celery/security/serialization.py,sha256=yyCQV8YzHwXr0Ht1KJ9-neUSAZJf2tuzKkpndKpvXqs,4248 celery/security/utils.py,sha256=VJuWxLZFKXQXzlBczuxo94wXWSULnXwbO_5ul_hwse0,845 celery/signals.py,sha256=z2T4UqrODczbaRFAyoNzO0th4lt_jMWzlxnrBh_MUCI,4384 celery/states.py,sha256=CYEkbmDJmMHf2RzTFtafPcu8EBG5wAYz8mt4NduYc7U,3324 celery/utils/__init__.py,sha256=lIJjBxvXCspC-ib-XasdEPlB0xAQc16P0eOPb0gWsL0,935 celery/utils/__pycache__/__init__.cpython-312.pyc,, celery/utils/__pycache__/abstract.cpython-312.pyc,, -celery/utils/__pycache__/annotations.cpython-312.pyc,, celery/utils/__pycache__/collections.cpython-312.pyc,, celery/utils/__pycache__/debug.cpython-312.pyc,, celery/utils/__pycache__/deprecated.cpython-312.pyc,, @@ -249,7 +242,6 @@ celery/utils/__pycache__/iso8601.cpython-312.pyc,, celery/utils/__pycache__/log.cpython-312.pyc,, celery/utils/__pycache__/nodenames.cpython-312.pyc,, celery/utils/__pycache__/objects.cpython-312.pyc,, -celery/utils/__pycache__/quorum_queues.cpython-312.pyc,, 
celery/utils/__pycache__/saferepr.cpython-312.pyc,, celery/utils/__pycache__/serialization.cpython-312.pyc,, celery/utils/__pycache__/sysinfo.cpython-312.pyc,, @@ -259,33 +251,31 @@ celery/utils/__pycache__/threads.cpython-312.pyc,, celery/utils/__pycache__/time.cpython-312.pyc,, celery/utils/__pycache__/timer2.cpython-312.pyc,, celery/utils/abstract.py,sha256=xN2Qr-TEp12P8AYO6WigxFr5p8kJPUUb0f5UX3FtHjI,2874 -celery/utils/annotations.py,sha256=04zURyjqjDIeLp6ui_I_HdC259Ww6UVAZLmAiUjR3vQ,2084 -celery/utils/collections.py,sha256=KsRWWGePZQelCUHMEvA_pVexh6HpZo1Y1JfCG-rM1f8,25432 +celery/utils/collections.py,sha256=IQH-QPk2en-C04TA_3zH-6bCPdC93eTscGGx-UT_bEw,25454 celery/utils/debug.py,sha256=9g5U0NlTvlP9OFwjxfyXgihfzD-Kk_fcy7QDjhkqapw,4709 celery/utils/deprecated.py,sha256=4asPe222TWJh8mcL53Ob6Y7XROPgqv23nCR-EUHJoBo,3620 celery/utils/dispatch/__init__.py,sha256=s0_ZpvFWXw1cecEue1vj-MpOPQUPE41g5s-YsjnX6mo,74 celery/utils/dispatch/__pycache__/__init__.cpython-312.pyc,, celery/utils/dispatch/__pycache__/signal.cpython-312.pyc,, -celery/utils/dispatch/signal.py,sha256=P1feenrOM5u9OtWV-MCIZTNgjglRJMBH2MgrxHuZ2Bg,13859 +celery/utils/dispatch/signal.py,sha256=LcmfBabnRAOR-wiADWQfBT-gN3Lzi29JpAcCvMLNNX4,13603 celery/utils/functional.py,sha256=TimJEByjq8NtocfSwfEUHoic6G5kCYim3Cl_V84Nnyk,12017 celery/utils/graph.py,sha256=oP25YXsQfND-VwF-MGolOGX0GbReIzVc9SJfIP1rUIc,9041 -celery/utils/imports.py,sha256=K02ZiqLZwGVCYEMnjdIilkuq7n4EnqzFArN6yqEBbC0,5126 -celery/utils/iso8601.py,sha256=0T7k3yiD4AfnUs9GsE2jMk-mDIn5d5011GS0kleUrVo,2916 -celery/utils/log.py,sha256=QCdpoulAOKEZ9TeGRFdrJhbOzLYyhLYcoZd3LUYwUuI,8756 -celery/utils/nodenames.py,sha256=t1qv6YYEkFfGg4j3dvz1IyzvTzV66NZNygSWVhOokiY,3163 +celery/utils/imports.py,sha256=SlTvyvy_91RU-XMgDogLEZiPQytdblura6TLfI34CkA,5032 +celery/utils/iso8601.py,sha256=BIjBHQDYhRWgUPO2PJuQIZr6v1M7bOek8Q7VMbYcQvE,2871 +celery/utils/log.py,sha256=vCbO8Jk0oPdiXCSHTM4plJ83xdfF1qJgg-JUyqbUXXE,8757 +celery/utils/nodenames.py,sha256=URBwdtWR_CF8Ldf6tjxE4y7rl0KxFFD36HjjZcrwQ5Y,2858 celery/utils/objects.py,sha256=NZ_Nx0ehrJut91sruAI2kVGyjhaDQR_ntTmF9Om_SI8,4215 -celery/utils/quorum_queues.py,sha256=HVc01iGI8-g4Esuc6h5hI__JelZLX9ZEKmLsmWsMMEs,705 -celery/utils/saferepr.py,sha256=_5DeQi5UuvPLVEJPpPS-EwtHoISgHYxeKO0NwQ4GGL0,9022 +celery/utils/saferepr.py,sha256=3S99diwXefbcJS5UwRHzn7ZoPuiY9LlZg9ph_Sb872Y,8945 celery/utils/serialization.py,sha256=5e1Blvm8GtkNn3LoDObRN9THJRRVVgmp4OFt0eh1AJM,8209 celery/utils/static/__init__.py,sha256=KwDq8hA-Xd721HldwJJ34ExwrIEyngEoSIzeAnqc5CA,299 celery/utils/static/__pycache__/__init__.cpython-312.pyc,, celery/utils/static/celery_128.png,sha256=8NmZxCALQPp3KVOsOPfJVaNLvwwLYqiS5ViOc6x0SGU,2556 -celery/utils/sysinfo.py,sha256=TbRElxGr1HWDhZB3gvFVJXb2NKFX48RDLFDRqFx26VI,1264 -celery/utils/term.py,sha256=UejfpiJxJd8Lu-wgcsuo_u_01xhmvw6d8sSkXMdk-Ek,5209 +celery/utils/sysinfo.py,sha256=LYdGzxbF357PrYNw31_9f8CEvrldtb0VAWIFclBtCnA,1085 +celery/utils/term.py,sha256=xUQR7vXr_f1-X-TG5o4eAnPGmrh5RM6ffXsdKEaMo6Y,4534 celery/utils/text.py,sha256=e9d5mDgGmyG6xc7PKfmFVnGoGj9DAocJ13uTSZ4Xyqw,5844 celery/utils/threads.py,sha256=_SVLpXSiQQNd2INSaMNC2rGFZHjNDs-lV-NnlWLLz1k,9552 -celery/utils/time.py,sha256=phv7idn7QgGUJedtlBzuRqdKj_b5bruBrv4cfUcmioI,15770 -celery/utils/timer2.py,sha256=hwSESQR33EzeqWtZbNdpqj7mTbSKKIi5ZvUrv_3Lov4,5541 +celery/utils/time.py,sha256=vE2m8q54MQ39-1MPUK5sNyWy0AyN4pyNOR6jhMleXEE,14987 +celery/utils/timer2.py,sha256=xv_7x_bDtILx4regqEm1ppQNenozSwOXi-21qQ4EJG4,4813 celery/worker/__init__.py,sha256=EKUgWOMq_1DfWb-OaAWv4rNLd7gi91aidefMjHMoxzI,95 
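# --- Illustrative sketch, not part of the patch above ---
# The RECORD entries listed here follow the wheel format
# "relative/path,sha256=<urlsafe-base64 digest, padding stripped>,<size in bytes>";
# the *.pyc entries carry no digest. A small checker such as this can confirm that
# an installed file still matches its entry. The helper name `verify_record_line`
# and the `site_packages` argument are assumptions made for the example.
import base64
import hashlib
from pathlib import Path


def verify_record_line(site_packages: Path, record_line: str) -> bool:
    """Return True if the file named in a RECORD line matches its digest and size."""
    path, digest_field, size_field = record_line.strip().rsplit(",", 2)
    if not digest_field:
        return True  # entries like "celery/apps/__pycache__/beat.cpython-312.pyc,," are unhashed
    algo, _, expected = digest_field.partition("=")
    data = (site_packages / path).read_bytes()
    actual = base64.urlsafe_b64encode(hashlib.new(algo, data).digest()).rstrip(b"=").decode()
    return actual == expected and len(data) == int(size_field)


# Example:
# verify_record_line(Path("venv/lib/python3.12/site-packages"),
#                    "celery/bootsteps.py,sha256=49bMT6CB0LPOK6-i8dLp7Hpko_WaLJ9yWlCWF3Ai2XI,12277")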
celery/worker/__pycache__/__init__.cpython-312.pyc,, celery/worker/__pycache__/autoscale.cpython-312.pyc,, @@ -306,7 +296,6 @@ celery/worker/consumer/__pycache__/agent.cpython-312.pyc,, celery/worker/consumer/__pycache__/connection.cpython-312.pyc,, celery/worker/consumer/__pycache__/consumer.cpython-312.pyc,, celery/worker/consumer/__pycache__/control.cpython-312.pyc,, -celery/worker/consumer/__pycache__/delayed_delivery.cpython-312.pyc,, celery/worker/consumer/__pycache__/events.cpython-312.pyc,, celery/worker/consumer/__pycache__/gossip.cpython-312.pyc,, celery/worker/consumer/__pycache__/heart.cpython-312.pyc,, @@ -314,19 +303,18 @@ celery/worker/consumer/__pycache__/mingle.cpython-312.pyc,, celery/worker/consumer/__pycache__/tasks.cpython-312.pyc,, celery/worker/consumer/agent.py,sha256=bThS8ZVeuybAyqNe8jmdN6RgaJhDq0llewosGrO85-c,525 celery/worker/consumer/connection.py,sha256=a7g23wmzevkEiMjjjD8Kt4scihf_NgkpR4gcuksys9M,1026 -celery/worker/consumer/consumer.py,sha256=7lFbFwgbSFGM1Bw-Nj-5NG0ZcC3cIUJRw9ocqyKt-XY,30164 +celery/worker/consumer/consumer.py,sha256=j88iy-6bT5aZNv2NZDjUoHegPHP3cKT4HXZLxI82H4c,28866 celery/worker/consumer/control.py,sha256=0NiJ9P-AHdv134mXkgRgU9hfhdJ_P7HKb7z9A4Xqa2Q,946 -celery/worker/consumer/delayed_delivery.py,sha256=OO3OOq6jkaR2W5o_hdRrWqJ82y5kILAt-JdeadmtnjM,8666 celery/worker/consumer/events.py,sha256=FgDwbV0Jbj9aWPbV3KAUtsXZq4JvZEfrWfnrYgvkMgo,2054 -celery/worker/consumer/gossip.py,sha256=LI8FsUFbNaUQyn600CHcksNbS_jFWzFhgU4fYEt7HhI,6863 +celery/worker/consumer/gossip.py,sha256=g-WJL2rr_q9aM_SaTUrQlPj2ONf8vHs2LvmyRQtDMEU,6833 celery/worker/consumer/heart.py,sha256=IenkkliKk6sAk2a1NfYyh-doNDlmFWGRiaJd5e8ALpI,930 -celery/worker/consumer/mingle.py,sha256=TtQDjAcrJLTDOT14v_QPsV8x_LNo7ZFzkL06LaIazd4,2531 -celery/worker/consumer/tasks.py,sha256=H0NDWrE_VP6zGGBXC02uS3Sf0Lx7Rt0NCSLDRRYC5oY,2703 -celery/worker/control.py,sha256=bcVf7t8RjMpHSBc-LlAw3eu-Dn1fdq4B24OVCW5IY5E,19921 +celery/worker/consumer/mingle.py,sha256=UG8K6sXF1KUJXNiJ4eMHUMIg4_7K1tDWqYRNfd9Nz9k,2519 +celery/worker/consumer/tasks.py,sha256=PwNqAZHJGQakiymFa4q6wbpmDCp3UtSN_7fd5jgATRk,1960 +celery/worker/control.py,sha256=30azpxShUHNuKevEsJG47zQ11ldrEaaq5yatUvQT23U,19884 celery/worker/heartbeat.py,sha256=sTV_d0RB9M6zsXIvLZ7VU6teUfX3IK1ITynDpxMS298,2107 -celery/worker/loops.py,sha256=qGlz-rWkmfUQCZ2TYM3Gpc_f2ihCUAuC1ENZeWDutwM,4599 +celery/worker/loops.py,sha256=W9ayCwYXOA0aCxPPotXc49uA_n7CnMsDRPJVUNb8bZM,4433 celery/worker/pidbox.py,sha256=LcQsKDkd8Z93nQxk0SOLulB8GLEfIjPkN-J0pGk7dfM,3630 -celery/worker/request.py,sha256=IHVVP7zJMEPNvFqLKLXR6wJebS3aLmXjzk9KdR9Esaw,27333 +celery/worker/request.py,sha256=MF7RsVmm4JrybOhnQZguxDcIpEuefdOTMxADDoJvg70,27229 celery/worker/state.py,sha256=_nQgvGeoahKz_TJCx7Tr20kKrNtDgaBA78eA17hA-8s,8583 celery/worker/strategy.py,sha256=MSznfZXkqD6WZRSaanIRZvg-f41DSAc2WgTVUIljh0c,7324 -celery/worker/worker.py,sha256=ivruJ2WK5JyvF7rLYuuMHfVklifOrrQl71lx6g4WUmM,15755 +celery/worker/worker.py,sha256=rNopjWdAzb9Ksszjw9WozvCA5nkDQnbp0n11MeLAitc,14460 diff --git a/venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/REQUESTED b/venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/REQUESTED similarity index 100% rename from venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/REQUESTED rename to venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/REQUESTED diff --git a/venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/WHEEL b/venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/WHEEL similarity index 65% rename from 
venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/WHEEL rename to venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/WHEEL index da097d6..1f37c02 100644 --- a/venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/WHEEL +++ b/venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/WHEEL @@ -1,5 +1,5 @@ Wheel-Version: 1.0 -Generator: setuptools (80.4.0) +Generator: bdist_wheel (0.40.0) Root-Is-Purelib: true Tag: py3-none-any diff --git a/venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/entry_points.txt b/venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/entry_points.txt similarity index 100% rename from venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/entry_points.txt rename to venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/entry_points.txt diff --git a/venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/top_level.txt b/venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/top_level.txt similarity index 100% rename from venv/lib/python3.12/site-packages/celery-5.5.3.dist-info/top_level.txt rename to venv/lib/python3.12/site-packages/celery-5.3.4.dist-info/top_level.txt diff --git a/venv/lib/python3.12/site-packages/celery/__init__.py b/venv/lib/python3.12/site-packages/celery/__init__.py index d291dec..e11a18c 100644 --- a/venv/lib/python3.12/site-packages/celery/__init__.py +++ b/venv/lib/python3.12/site-packages/celery/__init__.py @@ -15,9 +15,9 @@ from collections import namedtuple # Lazy loading from . import local -SERIES = 'immunity' +SERIES = 'emerald-rush' -__version__ = '5.5.3' +__version__ = '5.3.4' __author__ = 'Ask Solem' __contact__ = 'auvipy@gmail.com' __homepage__ = 'https://docs.celeryq.dev/' diff --git a/venv/lib/python3.12/site-packages/celery/app/amqp.py b/venv/lib/python3.12/site-packages/celery/app/amqp.py index 8dcec36..9e52af4 100644 --- a/venv/lib/python3.12/site-packages/celery/app/amqp.py +++ b/venv/lib/python3.12/site-packages/celery/app/amqp.py @@ -249,13 +249,9 @@ class AMQP: if max_priority is None: max_priority = conf.task_queue_max_priority if not queues and conf.task_default_queue: - queue_arguments = None - if conf.task_default_queue_type == 'quorum': - queue_arguments = {'x-queue-type': 'quorum'} queues = (Queue(conf.task_default_queue, exchange=self.default_exchange, - routing_key=default_routing_key, - queue_arguments=queue_arguments),) + routing_key=default_routing_key),) autoexchange = (self.autoexchange if autoexchange is None else autoexchange) return self.queues_cls( @@ -289,7 +285,7 @@ class AMQP: create_sent_event=False, root_id=None, parent_id=None, shadow=None, chain=None, now=None, timezone=None, origin=None, ignore_result=False, argsrepr=None, kwargsrepr=None, stamped_headers=None, - replaced_task_nesting=0, **options): + **options): args = args or () kwargs = kwargs or {} @@ -343,7 +339,6 @@ class AMQP: 'kwargsrepr': kwargsrepr, 'origin': origin or anon_nodename(), 'ignore_result': ignore_result, - 'replaced_task_nesting': replaced_task_nesting, 'stamped_headers': stamped_headers, 'stamps': stamps, } @@ -467,8 +462,7 @@ class AMQP: retry=None, retry_policy=None, serializer=None, delivery_mode=None, compression=None, declare=None, - headers=None, exchange_type=None, - timeout=None, confirm_timeout=None, **kwargs): + headers=None, exchange_type=None, **kwargs): retry = default_retry if retry is None else retry headers2, properties, body, sent_event = message if headers: @@ -529,7 +523,6 @@ class AMQP: retry=retry, retry_policy=_rp, delivery_mode=delivery_mode, declare=declare, 
headers=headers2, - timeout=timeout, confirm_timeout=confirm_timeout, **properties ) if after_receivers: diff --git a/venv/lib/python3.12/site-packages/celery/app/backends.py b/venv/lib/python3.12/site-packages/celery/app/backends.py index a274b85..5481528 100644 --- a/venv/lib/python3.12/site-packages/celery/app/backends.py +++ b/venv/lib/python3.12/site-packages/celery/app/backends.py @@ -34,7 +34,6 @@ BACKEND_ALIASES = { 'azureblockblob': 'celery.backends.azureblockblob:AzureBlockBlobBackend', 'arangodb': 'celery.backends.arangodb:ArangoDbBackend', 's3': 'celery.backends.s3:S3Backend', - 'gs': 'celery.backends.gcs:GCSBackend', } diff --git a/venv/lib/python3.12/site-packages/celery/app/base.py b/venv/lib/python3.12/site-packages/celery/app/base.py index a4d1c4c..cfd71c6 100644 --- a/venv/lib/python3.12/site-packages/celery/app/base.py +++ b/venv/lib/python3.12/site-packages/celery/app/base.py @@ -1,23 +1,17 @@ """Actual App instance implementation.""" -import functools -import importlib import inspect import os import sys import threading -import typing import warnings from collections import UserDict, defaultdict, deque from datetime import datetime -from datetime import timezone as datetime_timezone from operator import attrgetter from click.exceptions import Exit -from dateutil.parser import isoparse -from kombu import Exchange, pools +from kombu import pools from kombu.clocks import LamportClock from kombu.common import oid_from -from kombu.transport.native_delayed_delivery import calculate_routing_key from kombu.utils.compat import register_after_fork from kombu.utils.objects import cached_property from kombu.utils.uuid import uuid @@ -38,8 +32,6 @@ from celery.utils.log import get_logger from celery.utils.objects import FallbackContext, mro_lookup from celery.utils.time import maybe_make_aware, timezone, to_utc -from ..utils.annotations import annotation_is_class, annotation_issubclass, get_optional_arg -from ..utils.quorum_queues import detect_quorum_queues # Load all builtin tasks from . import backends, builtins # noqa from .annotations import prepare as prepare_annotations @@ -49,10 +41,6 @@ from .registry import TaskRegistry from .utils import (AppPickler, Settings, _new_key_to_old, _old_key_to_new, _unpickle_app, _unpickle_app_v2, appstr, bugreport, detect_settings) -if typing.TYPE_CHECKING: # pragma: no cover # codecov does not capture this - # flake8 marks the BaseModel import as unused, because the actual typehint is quoted. 
- from pydantic import BaseModel # noqa: F401 - __all__ = ('Celery',) logger = get_logger(__name__) @@ -102,70 +90,6 @@ def _after_fork_cleanup_app(app): logger.info('after forker raised exception: %r', exc, exc_info=1) -def pydantic_wrapper( - app: "Celery", - task_fun: typing.Callable[..., typing.Any], - task_name: str, - strict: bool = True, - context: typing.Optional[typing.Dict[str, typing.Any]] = None, - dump_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None -): - """Wrapper to validate arguments and serialize return values using Pydantic.""" - try: - pydantic = importlib.import_module('pydantic') - except ModuleNotFoundError as ex: - raise ImproperlyConfigured('You need to install pydantic to use pydantic model serialization.') from ex - - BaseModel: typing.Type['BaseModel'] = pydantic.BaseModel # noqa: F811 # only defined when type checking - - if context is None: - context = {} - if dump_kwargs is None: - dump_kwargs = {} - dump_kwargs.setdefault('mode', 'json') - - task_signature = inspect.signature(task_fun) - - @functools.wraps(task_fun) - def wrapper(*task_args, **task_kwargs): - # Validate task parameters if type hinted as BaseModel - bound_args = task_signature.bind(*task_args, **task_kwargs) - for arg_name, arg_value in bound_args.arguments.items(): - arg_annotation = task_signature.parameters[arg_name].annotation - - optional_arg = get_optional_arg(arg_annotation) - if optional_arg is not None and arg_value is not None: - arg_annotation = optional_arg - - if annotation_issubclass(arg_annotation, BaseModel): - bound_args.arguments[arg_name] = arg_annotation.model_validate( - arg_value, - strict=strict, - context={**context, 'celery_app': app, 'celery_task_name': task_name}, - ) - - # Call the task with (potentially) converted arguments - returned_value = task_fun(*bound_args.args, **bound_args.kwargs) - - # Dump Pydantic model if the returned value is an instance of pydantic.BaseModel *and* its - # class matches the typehint - return_annotation = task_signature.return_annotation - optional_return_annotation = get_optional_arg(return_annotation) - if optional_return_annotation is not None: - return_annotation = optional_return_annotation - - if ( - annotation_is_class(return_annotation) - and isinstance(returned_value, BaseModel) - and isinstance(returned_value, return_annotation) - ): - return returned_value.model_dump(**dump_kwargs) - - return returned_value - - return wrapper - - class PendingConfiguration(UserDict, AttributeDictMixin): # `app.conf` will be of this type before being explicitly configured, # meaning the app can keep any configuration set directly @@ -314,12 +238,6 @@ class Celery: self.loader_cls = loader or self._get_default_loader() self.log_cls = log or self.log_cls self.control_cls = control or self.control_cls - self._custom_task_cls_used = ( - # Custom task class provided as argument - bool(task_cls) - # subclass of Celery with a task_cls attribute - or self.__class__ is not Celery and hasattr(self.__class__, 'task_cls') - ) self.task_cls = task_cls or self.task_cls self.set_as_current = set_as_current self.registry_cls = symbol_by_name(self.registry_cls) @@ -515,7 +433,6 @@ class Celery: if shared: def cons(app): return app._task_from_fun(fun, **opts) - cons.__name__ = fun.__name__ connect_on_app_finalize(cons) if not lazy or self.finalized: @@ -544,27 +461,13 @@ class Celery: def type_checker(self, fun, bound=False): return staticmethod(head_from_fun(fun, bound=bound)) - def _task_from_fun( - self, - fun, - name=None, - base=None, - 
bind=False, - pydantic: bool = False, - pydantic_strict: bool = False, - pydantic_context: typing.Optional[typing.Dict[str, typing.Any]] = None, - pydantic_dump_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None, - **options, - ): + def _task_from_fun(self, fun, name=None, base=None, bind=False, **options): if not self.finalized and not self.autofinalize: raise RuntimeError('Contract breach: app not finalized') name = name or self.gen_task_name(fun.__name__, fun.__module__) base = base or self.Task if name not in self._tasks: - if pydantic is True: - fun = pydantic_wrapper(self, fun, name, pydantic_strict, pydantic_context, pydantic_dump_kwargs) - run = fun if bind else staticmethod(fun) task = type(fun.__name__, (base,), dict({ 'app': self, @@ -808,7 +711,7 @@ class Celery: retries=0, chord=None, reply_to=None, time_limit=None, soft_time_limit=None, root_id=None, parent_id=None, route_name=None, - shadow=None, chain=None, task_type=None, replaced_task_nesting=0, **options): + shadow=None, chain=None, task_type=None, **options): """Send task by name. Supports the same arguments as :meth:`@-Task.apply_async`. @@ -831,48 +734,13 @@ class Celery: ignore_result = options.pop('ignore_result', False) options = router.route( options, route_name or name, args, kwargs, task_type) - - driver_type = self.producer_pool.connections.connection.transport.driver_type - - if (eta or countdown) and detect_quorum_queues(self, driver_type)[0]: - - queue = options.get("queue") - exchange_type = queue.exchange.type if queue else options["exchange_type"] - routing_key = queue.routing_key if queue else options["routing_key"] - exchange_name = queue.exchange.name if queue else options["exchange"] - - if exchange_type != 'direct': - if eta: - if isinstance(eta, str): - eta = isoparse(eta) - countdown = (maybe_make_aware(eta) - self.now()).total_seconds() - - if countdown: - if countdown > 0: - routing_key = calculate_routing_key(int(countdown), routing_key) - exchange = Exchange( - 'celery_delayed_27', - type='topic', - ) - options.pop("queue", None) - options['routing_key'] = routing_key - options['exchange'] = exchange - - else: - logger.warning( - 'Direct exchanges are not supported with native delayed delivery.\n' - f'{exchange_name} is a direct exchange but should be a topic exchange or ' - 'a fanout exchange in order for native delayed delivery to work properly.\n' - 'If quorum queues are used, this task may block the worker process until the ETA arrives.' - ) - if expires is not None: if isinstance(expires, datetime): expires_s = (maybe_make_aware( expires) - self.now()).total_seconds() elif isinstance(expires, str): expires_s = (maybe_make_aware( - isoparse(expires)) - self.now()).total_seconds() + datetime.fromisoformat(expires)) - self.now()).total_seconds() else: expires_s = expires @@ -913,7 +781,7 @@ class Celery: self.conf.task_send_sent_event, root_id, parent_id, shadow, chain, ignore_result=ignore_result, - replaced_task_nesting=replaced_task_nesting, **options + **options ) stamped_headers = options.pop('stamped_headers', []) @@ -1026,7 +894,6 @@ class Celery: 'broker_connection_timeout', connect_timeout ), ) - broker_connection = connection def _acquire_connection(self, pool=True): @@ -1046,7 +913,6 @@ class Celery: will be acquired from the connection pool. 
""" return FallbackContext(connection, self._acquire_connection, pool=pool) - default_connection = connection_or_acquire # XXX compat def producer_or_acquire(self, producer=None): @@ -1062,7 +928,6 @@ class Celery: return FallbackContext( producer, self.producer_pool.acquire, block=True, ) - default_producer = producer_or_acquire # XXX compat def prepare_config(self, c): @@ -1071,7 +936,7 @@ class Celery: def now(self): """Return the current time and date as a datetime.""" - now_in_utc = to_utc(datetime.now(datetime_timezone.utc)) + now_in_utc = to_utc(datetime.utcnow()) return now_in_utc.astimezone(self.timezone) def select_queues(self, queues=None): @@ -1109,14 +974,7 @@ class Celery: This is used by PendingConfiguration: as soon as you access a key the configuration is read. """ - try: - conf = self._conf = self._load_config() - except AttributeError as err: - # AttributeError is not propagated, it is "handled" by - # PendingConfiguration parent class. This causes - # confusing RecursionError. - raise ModuleNotFoundError(*err.args) from err - + conf = self._conf = self._load_config() return conf def _load_config(self): diff --git a/venv/lib/python3.12/site-packages/celery/app/control.py b/venv/lib/python3.12/site-packages/celery/app/control.py index 603d930..52763e8 100644 --- a/venv/lib/python3.12/site-packages/celery/app/control.py +++ b/venv/lib/python3.12/site-packages/celery/app/control.py @@ -360,7 +360,7 @@ class Inspect: * ``routing_key`` - Routing key used when task was published * ``priority`` - Priority used when task was published * ``redelivered`` - True if the task was redelivered - * ``worker_pid`` - PID of worker processing the task + * ``worker_pid`` - PID of worker processin the task """ # signature used be unary: query_task(ids=[id1, id2]) @@ -527,8 +527,7 @@ class Control: if result: for host in result: for response in host.values(): - if isinstance(response['ok'], set): - task_ids.update(response['ok']) + task_ids.update(response['ok']) if task_ids: return self.revoke(list(task_ids), destination=destination, terminate=terminate, signal=signal, **kwargs) diff --git a/venv/lib/python3.12/site-packages/celery/app/defaults.py b/venv/lib/python3.12/site-packages/celery/app/defaults.py index f8e2511..a9f6868 100644 --- a/venv/lib/python3.12/site-packages/celery/app/defaults.py +++ b/venv/lib/python3.12/site-packages/celery/app/defaults.py @@ -95,7 +95,6 @@ NAMESPACES = Namespace( heartbeat=Option(120, type='int'), heartbeat_checkrate=Option(3.0, type='int'), login_method=Option(None, type='string'), - native_delayed_delivery_queue_type=Option(default='quorum', type='string'), pool_limit=Option(10, type='int'), use_ssl=Option(False, type='bool'), @@ -141,12 +140,6 @@ NAMESPACES = Namespace( connection_timeout=Option(20, type='int'), read_timeout=Option(120, type='int'), ), - gcs=Namespace( - bucket=Option(type='string'), - project=Option(type='string'), - base_path=Option('', type='string'), - ttl=Option(0, type='float'), - ), control=Namespace( queue_ttl=Option(300.0, type='float'), queue_expires=Option(10.0, type='float'), @@ -250,7 +243,6 @@ NAMESPACES = Namespace( ), table_schemas=Option(type='dict'), table_names=Option(type='dict', old={'celery_result_db_tablenames'}), - create_tables_at_setup=Option(True, type='bool'), ), task=Namespace( __old__=OLD_NS, @@ -263,7 +255,6 @@ NAMESPACES = Namespace( inherit_parent_priority=Option(False, type='bool'), default_delivery_mode=Option(2, type='string'), default_queue=Option('celery'), - default_queue_type=Option('classic', 
type='string'), default_exchange=Option(None, type='string'), # taken from queue default_exchange_type=Option('direct'), default_routing_key=Option(None, type='string'), # taken from queue @@ -311,8 +302,6 @@ NAMESPACES = Namespace( cancel_long_running_tasks_on_connection_loss=Option( False, type='bool' ), - soft_shutdown_timeout=Option(0.0, type='float'), - enable_soft_shutdown_on_idle=Option(False, type='bool'), concurrency=Option(None, type='int'), consumer=Option('celery.worker.consumer:Consumer', type='string'), direct=Option(False, type='bool', old={'celery_worker_direct'}), @@ -336,7 +325,6 @@ NAMESPACES = Namespace( pool_restarts=Option(False, type='bool'), proc_alive_timeout=Option(4.0, type='float'), prefetch_multiplier=Option(4, type='int'), - enable_prefetch_count_reduction=Option(True, type='bool'), redirect_stdouts=Option( True, type='bool', old={'celery_redirect_stdouts'}, ), @@ -350,7 +338,6 @@ NAMESPACES = Namespace( task_log_format=Option(DEFAULT_TASK_LOG_FMT), timer=Option(type='string'), timer_precision=Option(1.0, type='float'), - detect_quorum_queues=Option(True, type='bool'), ), ) diff --git a/venv/lib/python3.12/site-packages/celery/app/log.py b/venv/lib/python3.12/site-packages/celery/app/log.py index a4db105..4c807f4 100644 --- a/venv/lib/python3.12/site-packages/celery/app/log.py +++ b/venv/lib/python3.12/site-packages/celery/app/log.py @@ -18,7 +18,6 @@ from celery import signals from celery._state import get_current_task from celery.exceptions import CDeprecationWarning, CPendingDeprecationWarning from celery.local import class_property -from celery.platforms import isatty from celery.utils.log import (ColorFormatter, LoggingProxy, get_logger, get_multiprocessing_logger, mlevel, reset_multiprocessing_logger) from celery.utils.nodenames import node_format @@ -204,7 +203,7 @@ class Logging: if colorize or colorize is None: # Only use color if there's no active log file # and stderr is an actual terminal. - return logfile is None and isatty(sys.stderr) + return logfile is None and sys.stderr.isatty() return colorize def colored(self, logfile=None, enabled=None): diff --git a/venv/lib/python3.12/site-packages/celery/app/routes.py b/venv/lib/python3.12/site-packages/celery/app/routes.py index bed2c07..a56ce59 100644 --- a/venv/lib/python3.12/site-packages/celery/app/routes.py +++ b/venv/lib/python3.12/site-packages/celery/app/routes.py @@ -20,7 +20,7 @@ except AttributeError: # pragma: no cover # for support Python 3.7 Pattern = re.Pattern -__all__ = ('MapRoute', 'Router', 'expand_router_string', 'prepare') +__all__ = ('MapRoute', 'Router', 'prepare') class MapRoute: diff --git a/venv/lib/python3.12/site-packages/celery/app/task.py b/venv/lib/python3.12/site-packages/celery/app/task.py index 90ba855..7998d60 100644 --- a/venv/lib/python3.12/site-packages/celery/app/task.py +++ b/venv/lib/python3.12/site-packages/celery/app/task.py @@ -104,7 +104,7 @@ class Context: def _get_custom_headers(self, *args, **kwargs): headers = {} headers.update(*args, **kwargs) - celery_keys = {*Context.__dict__.keys(), 'lang', 'task', 'argsrepr', 'kwargsrepr', 'compression'} + celery_keys = {*Context.__dict__.keys(), 'lang', 'task', 'argsrepr', 'kwargsrepr'} for key in celery_keys: headers.pop(key, None) if not headers: @@ -466,7 +466,7 @@ class Task: shadow (str): Override task name used in logs/monitoring. Default is retrieved from :meth:`shadow_name`. 
- connection (kombu.Connection): Reuse existing broker connection + connection (kombu.Connection): Re-use existing broker connection instead of acquiring one from the connection pool. retry (bool): If enabled sending of the task message will be @@ -535,8 +535,6 @@ class Task: publisher (kombu.Producer): Deprecated alias to ``producer``. headers (Dict): Message headers to be included in the message. - The headers can be used as an overlay for custom labeling - using the :ref:`canvas-stamping` feature. Returns: celery.result.AsyncResult: Promise of future evaluation. @@ -545,8 +543,6 @@ class Task: TypeError: If not enough arguments are passed, or too many arguments are passed. Note that signature checks may be disabled by specifying ``@task(typing=False)``. - ValueError: If soft_time_limit and time_limit both are set - but soft_time_limit is greater than time_limit kombu.exceptions.OperationalError: If a connection to the transport cannot be made, or if the connection is lost. @@ -554,9 +550,6 @@ class Task: Also supports all keyword arguments supported by :meth:`kombu.Producer.publish`. """ - if self.soft_time_limit and self.time_limit and self.soft_time_limit > self.time_limit: - raise ValueError('soft_time_limit must be less than or equal to time_limit') - if self.typing: try: check_arguments = self.__header__ @@ -795,7 +788,6 @@ class Task: request = { 'id': task_id, - 'task': self.name, 'retries': retries, 'is_eager': True, 'logfile': logfile, @@ -832,7 +824,7 @@ class Task: if isinstance(retval, Retry) and retval.sig is not None: return retval.sig.apply(retries=retries + 1) state = states.SUCCESS if ret.info is None else ret.info.state - return EagerResult(task_id, retval, state, traceback=tb, name=self.name) + return EagerResult(task_id, retval, state, traceback=tb) def AsyncResult(self, task_id, **kwargs): """Get AsyncResult instance for the specified task. 
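# --- Illustrative sketch, not part of the patch above ---
# The task.py hunk above drops the 5.5.3 guard that made Task.apply_async() raise
# ValueError when soft_time_limit exceeds time_limit. After the downgrade that check
# is gone, so callers who rely on it could apply the same rule themselves; the helper
# name `validate_time_limits` is an assumption made for the example.
from typing import Optional


def validate_time_limits(soft_time_limit: Optional[float], time_limit: Optional[float]) -> None:
    """Mirror the removed guard: a soft limit must not exceed the hard limit."""
    if soft_time_limit and time_limit and soft_time_limit > time_limit:
        raise ValueError("soft_time_limit must be less than or equal to time_limit")


# Example: validate_time_limits(soft_time_limit=30, time_limit=60)  # OK
#          validate_time_limits(soft_time_limit=90, time_limit=60)  # raises ValueError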
@@ -962,20 +954,11 @@ class Task: root_id=self.request.root_id, replaced_task_nesting=replaced_task_nesting ) - - # If the replaced task is a chain, we want to set all of the chain tasks - # with the same replaced_task_nesting value to mark their replacement nesting level - if isinstance(sig, _chain): - for chain_task in maybe_list(sig.tasks) or []: - chain_task.set(replaced_task_nesting=replaced_task_nesting) - # If the task being replaced is part of a chain, we need to re-create # it with the replacement signature - these subsequent tasks will # retain their original task IDs as well for t in reversed(self.request.chain or []): - chain_task = signature(t, app=self.app) - chain_task.set(replaced_task_nesting=replaced_task_nesting) - sig |= chain_task + sig |= signature(t, app=self.app) return self.on_replace(sig) def add_to_chord(self, sig, lazy=False): @@ -1116,7 +1099,7 @@ class Task: return result def push_request(self, *args, **kwargs): - self.request_stack.push(Context(*args, **{**self.request.__dict__, **kwargs})) + self.request_stack.push(Context(*args, **kwargs)) def pop_request(self): self.request_stack.pop() diff --git a/venv/lib/python3.12/site-packages/celery/app/trace.py b/venv/lib/python3.12/site-packages/celery/app/trace.py index 2e8cf8a..3933d01 100644 --- a/venv/lib/python3.12/site-packages/celery/app/trace.py +++ b/venv/lib/python3.12/site-packages/celery/app/trace.py @@ -8,6 +8,7 @@ import os import sys import time from collections import namedtuple +from typing import Any, Callable, Dict, FrozenSet, Optional, Sequence, Tuple, Type, Union from warnings import warn from billiard.einfo import ExceptionInfo, ExceptionWithTraceback @@ -16,6 +17,8 @@ from kombu.serialization import loads as loads_message from kombu.serialization import prepare_accept_content from kombu.utils.encoding import safe_repr, safe_str +import celery +import celery.loaders.app from celery import current_app, group, signals, states from celery._state import _task_stack from celery.app.task import Context @@ -291,10 +294,20 @@ def traceback_clear(exc=None): tb = tb.tb_next -def build_tracer(name, task, loader=None, hostname=None, store_errors=True, - Info=TraceInfo, eager=False, propagate=False, app=None, - monotonic=time.monotonic, trace_ok_t=trace_ok_t, - IGNORE_STATES=IGNORE_STATES): +def build_tracer( + name: str, + task: Union[celery.Task, celery.local.PromiseProxy], + loader: Optional[celery.loaders.app.AppLoader] = None, + hostname: Optional[str] = None, + store_errors: bool = True, + Info: Type[TraceInfo] = TraceInfo, + eager: bool = False, + propagate: bool = False, + app: Optional[celery.Celery] = None, + monotonic: Callable[[], int] = time.monotonic, + trace_ok_t: Type[trace_ok_t] = trace_ok_t, + IGNORE_STATES: FrozenSet[str] = IGNORE_STATES) -> \ + Callable[[str, Tuple[Any, ...], Dict[str, Any], Any], trace_ok_t]: """Return a function that traces task execution. 
Catches all exceptions and updates result backend with the @@ -374,7 +387,12 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True, from celery import canvas signature = canvas.maybe_signature # maybe_ does not clone if already - def on_error(request, exc, state=FAILURE, call_errbacks=True): + def on_error( + request: celery.app.task.Context, + exc: Union[Exception, Type[Exception]], + state: str = FAILURE, + call_errbacks: bool = True) -> Tuple[Info, Any, Any, Any]: + """Handle any errors raised by a `Task`'s execution.""" if propagate: raise I = Info(state, exc) @@ -383,7 +401,13 @@ def build_tracer(name, task, loader=None, hostname=None, store_errors=True, ) return I, R, I.state, I.retval - def trace_task(uuid, args, kwargs, request=None): + def trace_task( + uuid: str, + args: Sequence[Any], + kwargs: Dict[str, Any], + request: Optional[Dict[str, Any]] = None) -> trace_ok_t: + """Execute and trace a `Task`.""" + # R - is the possibly prepared return value. # I - is the Info object. # T - runtime diff --git a/venv/lib/python3.12/site-packages/celery/app/utils.py b/venv/lib/python3.12/site-packages/celery/app/utils.py index da2ee66..0dd3409 100644 --- a/venv/lib/python3.12/site-packages/celery/app/utils.py +++ b/venv/lib/python3.12/site-packages/celery/app/utils.py @@ -35,7 +35,7 @@ settings -> transport:{transport} results:{results} """ HIDDEN_SETTINGS = re.compile( - 'API|TOKEN|KEY|SECRET|PASS|PROFANITIES_LIST|SIGNATURE|DATABASE|BEAT_DBURI', + 'API|TOKEN|KEY|SECRET|PASS|PROFANITIES_LIST|SIGNATURE|DATABASE', re.IGNORECASE, ) diff --git a/venv/lib/python3.12/site-packages/celery/apps/worker.py b/venv/lib/python3.12/site-packages/celery/apps/worker.py index 5558dab..dcc04da 100644 --- a/venv/lib/python3.12/site-packages/celery/apps/worker.py +++ b/venv/lib/python3.12/site-packages/celery/apps/worker.py @@ -20,7 +20,7 @@ from kombu.utils.encoding import safe_str from celery import VERSION_BANNER, platforms, signals from celery.app import trace from celery.loaders.app import AppLoader -from celery.platforms import EX_FAILURE, EX_OK, check_privileges, isatty +from celery.platforms import EX_FAILURE, EX_OK, check_privileges from celery.utils import static, term from celery.utils.debug import cry from celery.utils.imports import qualname @@ -77,9 +77,8 @@ def active_thread_count(): if not t.name.startswith('Dummy-')) -def safe_say(msg, f=sys.__stderr__): - if hasattr(f, 'fileno') and f.fileno() is not None: - os.write(f.fileno(), f'\n{msg}\n'.encode()) +def safe_say(msg): + print(f'\n{msg}', file=sys.__stderr__, flush=True) class Worker(WorkController): @@ -107,7 +106,7 @@ class Worker(WorkController): super().setup_defaults(**kwargs) self.purge = purge self.no_color = no_color - self._isatty = isatty(sys.stdout) + self._isatty = sys.stdout.isatty() self.colored = self.app.log.colored( self.logfile, enabled=not no_color if no_color is not None else no_color @@ -279,27 +278,15 @@ class Worker(WorkController): ) -def _shutdown_handler(worker: Worker, sig='SIGTERM', how='Warm', callback=None, exitcode=EX_OK, verbose=True): - """Install signal handler for warm/cold shutdown. - - The handler will run from the MainProcess. - - Args: - worker (Worker): The worker that received the signal. - sig (str, optional): The signal that was received. Defaults to 'TERM'. - how (str, optional): The type of shutdown to perform. Defaults to 'Warm'. - callback (Callable, optional): Signal handler. Defaults to None. - exitcode (int, optional): The exit code to use. Defaults to EX_OK. 
- verbose (bool, optional): Whether to print the type of shutdown. Defaults to True. - """ +def _shutdown_handler(worker, sig='TERM', how='Warm', + callback=None, exitcode=EX_OK): def _handle_request(*args): with in_sighandler(): from celery.worker import state if current_process()._name == 'MainProcess': if callback: callback(worker) - if verbose: - safe_say(f'worker: {how} shutdown (MainProcess)', sys.__stdout__) + safe_say(f'worker: {how} shutdown (MainProcess)') signals.worker_shutting_down.send( sender=worker.hostname, sig=sig, how=how, exitcode=exitcode, @@ -310,126 +297,19 @@ def _shutdown_handler(worker: Worker, sig='SIGTERM', how='Warm', callback=None, platforms.signals[sig] = _handle_request -def on_hard_shutdown(worker: Worker): - """Signal handler for hard shutdown. - - The handler will terminate the worker immediately by force using the exit code ``EX_FAILURE``. - - In practice, you should never get here, as the standard shutdown process should be enough. - This handler is only for the worst-case scenario, where the worker is stuck and cannot be - terminated gracefully (e.g., spamming the Ctrl+C in the terminal to force the worker to terminate). - - Args: - worker (Worker): The worker that received the signal. - - Raises: - WorkerTerminate: This exception will be raised in the MainProcess to terminate the worker immediately. - """ - from celery.exceptions import WorkerTerminate - raise WorkerTerminate(EX_FAILURE) - - -def during_soft_shutdown(worker: Worker): - """This signal handler is called when the worker is in the middle of the soft shutdown process. - - When the worker is in the soft shutdown process, it is waiting for tasks to finish. If the worker - receives a SIGINT (Ctrl+C) or SIGQUIT signal (or possibly SIGTERM if REMAP_SIGTERM is set to "SIGQUIT"), - the handler will cancels all unacked requests to allow the worker to terminate gracefully and replace the - signal handler for SIGINT and SIGQUIT with the hard shutdown handler ``on_hard_shutdown`` to terminate - the worker immediately by force next time the signal is received. - - It will give the worker once last chance to gracefully terminate (the cold shutdown), after canceling all - unacked requests, before using the hard shutdown handler to terminate the worker forcefully. - - Args: - worker (Worker): The worker that received the signal. - """ - # Replace the signal handler for SIGINT (Ctrl+C) and SIGQUIT (and possibly SIGTERM) - # with the hard shutdown handler to terminate the worker immediately by force - install_worker_term_hard_handler(worker, sig='SIGINT', callback=on_hard_shutdown, verbose=False) - install_worker_term_hard_handler(worker, sig='SIGQUIT', callback=on_hard_shutdown) - - # Cancel all unacked requests and allow the worker to terminate naturally - worker.consumer.cancel_all_unacked_requests() - - # We get here if the worker was in the middle of the soft (cold) shutdown process, - # and the matching signal was received. This can typically happen when the worker is - # waiting for tasks to finish, and the user decides to still cancel the running tasks. - # We give the worker the last chance to gracefully terminate by letting the soft shutdown - # waiting time to finish, which is running in the MainProcess from the previous signal handler call. - safe_say('Waiting gracefully for cold shutdown to complete...', sys.__stdout__) - - -def on_cold_shutdown(worker: Worker): - """Signal handler for cold shutdown. - - Registered for SIGQUIT and SIGINT (Ctrl+C) signals. 
If REMAP_SIGTERM is set to "SIGQUIT", this handler will also - be registered for SIGTERM. - - This handler will initiate the cold (and soft if enabled) shutdown procesdure for the worker. - - Worker running with N tasks: - - SIGTERM: - -The worker will initiate the warm shutdown process until all tasks are finished. Additional. - SIGTERM signals will be ignored. SIGQUIT will transition to the cold shutdown process described below. - - SIGQUIT: - - The worker will initiate the cold shutdown process. - - If the soft shutdown is enabled, the worker will wait for the tasks to finish up to the soft - shutdown timeout (practically having a limited warm shutdown just before the cold shutdown). - - Cancel all tasks (from the MainProcess) and allow the worker to complete the cold shutdown - process gracefully. - - Caveats: - - SIGINT (Ctrl+C) signal is defined to replace itself with the cold shutdown (SIGQUIT) after first use, - and to emit a message to the user to hit Ctrl+C again to initiate the cold shutdown process. But, most - important, it will also be caught in WorkController.start() to initiate the warm shutdown process. - - SIGTERM will also be handled in WorkController.start() to initiate the warm shutdown process (the same). - - If REMAP_SIGTERM is set to "SIGQUIT", the SIGTERM signal will be remapped to SIGQUIT, and the cold - shutdown process will be initiated instead of the warm shutdown process using SIGTERM. - - If SIGQUIT is received (also via SIGINT) during the cold/soft shutdown process, the handler will cancel all - unacked requests but still wait for the soft shutdown process to finish before terminating the worker - gracefully. The next time the signal is received though, the worker will terminate immediately by force. - - So, the purpose of this handler is to allow waiting for the soft shutdown timeout, then cancel all tasks from - the MainProcess and let the WorkController.terminate() to terminate the worker naturally. If the soft shutdown - is disabled, it will immediately cancel all tasks let the cold shutdown finish normally. - - Args: - worker (Worker): The worker that received the signal. - """ - safe_say('worker: Hitting Ctrl+C again will terminate all running tasks!', sys.__stdout__) - - # Replace the signal handler for SIGINT (Ctrl+C) and SIGQUIT (and possibly SIGTERM) - install_worker_term_hard_handler(worker, sig='SIGINT', callback=during_soft_shutdown) - install_worker_term_hard_handler(worker, sig='SIGQUIT', callback=during_soft_shutdown) - if REMAP_SIGTERM == "SIGQUIT": - install_worker_term_hard_handler(worker, sig='SIGTERM', callback=during_soft_shutdown) - # else, SIGTERM will print the _shutdown_handler's message and do nothing, every time it is received.. 
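# --- Illustrative sketch, not part of the patch above ---
# The removed 5.5.3 handlers escalate on repeated signals: the warm-shutdown handler
# re-registers a cold-shutdown handler, which in turn re-registers a hard-shutdown
# handler, so each additional Ctrl+C gets harsher. A generic, standalone version of
# that pattern (names and messages are made up; this is not Celery's code) looks like:
import signal
import sys


def _hard_stop(signum, frame):
    sys.exit(1)  # last resort: terminate immediately


def _cold_stop(signum, frame):
    print("cold shutdown: cancelling remaining work", file=sys.stderr)
    signal.signal(signal.SIGINT, _hard_stop)  # the next Ctrl+C forces an exit


def _warm_stop(signum, frame):
    print("warm shutdown: waiting for running work to finish", file=sys.stderr)
    signal.signal(signal.SIGINT, _cold_stop)  # the next Ctrl+C escalates


signal.signal(signal.SIGINT, _warm_stop)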
- - # Initiate soft shutdown process (if enabled and tasks are running) - worker.wait_for_soft_shutdown() - - # Cancel all unacked requests and allow the worker to terminate naturally - worker.consumer.cancel_all_unacked_requests() - - # Stop the pool to allow successful tasks call on_success() - worker.consumer.pool.stop() - - -# Allow SIGTERM to be remapped to SIGQUIT to initiate cold shutdown instead of warm shutdown using SIGTERM if REMAP_SIGTERM == "SIGQUIT": install_worker_term_handler = partial( - _shutdown_handler, sig='SIGTERM', how='Cold', callback=on_cold_shutdown, exitcode=EX_FAILURE, + _shutdown_handler, sig='SIGTERM', how='Cold', exitcode=EX_FAILURE, ) else: install_worker_term_handler = partial( _shutdown_handler, sig='SIGTERM', how='Warm', ) - if not is_jython: # pragma: no cover install_worker_term_hard_handler = partial( - _shutdown_handler, sig='SIGQUIT', how='Cold', callback=on_cold_shutdown, exitcode=EX_FAILURE, + _shutdown_handler, sig='SIGQUIT', how='Cold', + exitcode=EX_FAILURE, ) else: # pragma: no cover install_worker_term_handler = \ @@ -437,9 +317,8 @@ else: # pragma: no cover def on_SIGINT(worker): - safe_say('worker: Hitting Ctrl+C again will initiate cold shutdown, terminating all running tasks!', - sys.__stdout__) - install_worker_term_hard_handler(worker, sig='SIGINT', verbose=False) + safe_say('worker: Hitting Ctrl+C again will terminate all running tasks!') + install_worker_term_hard_handler(worker, sig='SIGINT') if not is_jython: # pragma: no cover @@ -464,8 +343,7 @@ def install_worker_restart_handler(worker, sig='SIGHUP'): def restart_worker_sig_handler(*args): """Signal handler restarting the current python program.""" set_in_sighandler(True) - safe_say(f"Restarting celery worker ({' '.join(sys.argv)})", - sys.__stdout__) + safe_say(f"Restarting celery worker ({' '.join(sys.argv)})") import atexit atexit.register(_reload_current_worker) from celery.worker import state diff --git a/venv/lib/python3.12/site-packages/celery/backends/azureblockblob.py b/venv/lib/python3.12/site-packages/celery/backends/azureblockblob.py index 3648cbe..862777b 100644 --- a/venv/lib/python3.12/site-packages/celery/backends/azureblockblob.py +++ b/venv/lib/python3.12/site-packages/celery/backends/azureblockblob.py @@ -1,5 +1,4 @@ """The Azure Storage Block Blob backend for Celery.""" -from kombu.transport.azurestoragequeues import Transport as AzureStorageQueuesTransport from kombu.utils import cached_property from kombu.utils.encoding import bytes_to_str @@ -29,13 +28,6 @@ class AzureBlockBlobBackend(KeyValueStoreBackend): container_name=None, *args, **kwargs): - """ - Supported URL formats: - - azureblockblob://CONNECTION_STRING - azureblockblob://DefaultAzureCredential@STORAGE_ACCOUNT_URL - azureblockblob://ManagedIdentityCredential@STORAGE_ACCOUNT_URL - """ super().__init__(*args, **kwargs) if azurestorage is None or azurestorage.__version__ < '12': @@ -73,26 +65,11 @@ class AzureBlockBlobBackend(KeyValueStoreBackend): the container is created if it doesn't yet exist. 
""" - if ( - "DefaultAzureCredential" in self._connection_string or - "ManagedIdentityCredential" in self._connection_string - ): - # Leveraging the work that Kombu already did for us - credential_, url = AzureStorageQueuesTransport.parse_uri( - self._connection_string - ) - client = BlobServiceClient( - account_url=url, - credential=credential_, - connection_timeout=self._connection_timeout, - read_timeout=self._read_timeout, - ) - else: - client = BlobServiceClient.from_connection_string( - self._connection_string, - connection_timeout=self._connection_timeout, - read_timeout=self._read_timeout, - ) + client = BlobServiceClient.from_connection_string( + self._connection_string, + connection_timeout=self._connection_timeout, + read_timeout=self._read_timeout + ) try: client.create_container(name=self._container_name) diff --git a/venv/lib/python3.12/site-packages/celery/backends/base.py b/venv/lib/python3.12/site-packages/celery/backends/base.py index dc79f4e..4216c3b 100644 --- a/venv/lib/python3.12/site-packages/celery/backends/base.py +++ b/venv/lib/python3.12/site-packages/celery/backends/base.py @@ -9,7 +9,7 @@ import sys import time import warnings from collections import namedtuple -from datetime import timedelta +from datetime import datetime, timedelta from functools import partial from weakref import WeakValueDictionary @@ -460,7 +460,7 @@ class Backend: state, traceback, request, format_date=True, encode=False): if state in self.READY_STATES: - date_done = self.app.now() + date_done = datetime.utcnow() if format_date: date_done = date_done.isoformat() else: @@ -833,11 +833,9 @@ class BaseKeyValueStoreBackend(Backend): """ global_keyprefix = self.app.conf.get('result_backend_transport_options', {}).get("global_keyprefix", None) if global_keyprefix: - if global_keyprefix[-1] not in ':_-.': - global_keyprefix += '_' - self.task_keyprefix = f"{global_keyprefix}{self.task_keyprefix}" - self.group_keyprefix = f"{global_keyprefix}{self.group_keyprefix}" - self.chord_keyprefix = f"{global_keyprefix}{self.chord_keyprefix}" + self.task_keyprefix = f"{global_keyprefix}_{self.task_keyprefix}" + self.group_keyprefix = f"{global_keyprefix}_{self.group_keyprefix}" + self.chord_keyprefix = f"{global_keyprefix}_{self.chord_keyprefix}" def _encode_prefixes(self): self.task_keyprefix = self.key_t(self.task_keyprefix) @@ -1082,7 +1080,7 @@ class BaseKeyValueStoreBackend(Backend): ) finally: deps.delete() - self.delete(key) + self.client.delete(key) else: self.expire(key, self.expires) diff --git a/venv/lib/python3.12/site-packages/celery/backends/cassandra.py b/venv/lib/python3.12/site-packages/celery/backends/cassandra.py index 4ca071d..0eb37f3 100644 --- a/venv/lib/python3.12/site-packages/celery/backends/cassandra.py +++ b/venv/lib/python3.12/site-packages/celery/backends/cassandra.py @@ -86,7 +86,7 @@ class CassandraBackend(BaseBackend): supports_autoexpire = True # autoexpire supported via entry_ttl def __init__(self, servers=None, keyspace=None, table=None, entry_ttl=None, - port=None, bundle_path=None, **kwargs): + port=9042, bundle_path=None, **kwargs): super().__init__(**kwargs) if not cassandra: @@ -96,7 +96,7 @@ class CassandraBackend(BaseBackend): self.servers = servers or conf.get('cassandra_servers', None) self.bundle_path = bundle_path or conf.get( 'cassandra_secure_bundle_path', None) - self.port = port or conf.get('cassandra_port', None) or 9042 + self.port = port or conf.get('cassandra_port', None) self.keyspace = keyspace or conf.get('cassandra_keyspace', None) self.table = 
table or conf.get('cassandra_table', None) self.cassandra_options = conf.get('cassandra_options', {}) diff --git a/venv/lib/python3.12/site-packages/celery/backends/database/__init__.py b/venv/lib/python3.12/site-packages/celery/backends/database/__init__.py index df03db5..91080ad 100644 --- a/venv/lib/python3.12/site-packages/celery/backends/database/__init__.py +++ b/venv/lib/python3.12/site-packages/celery/backends/database/__init__.py @@ -98,23 +98,11 @@ class DatabaseBackend(BaseBackend): 'Missing connection string! Do you have the' ' database_url setting set to a real value?') - self.session_manager = SessionManager() - - create_tables_at_setup = conf.database_create_tables_at_setup - if create_tables_at_setup is True: - self._create_tables() - @property def extended_result(self): return self.app.conf.find_value_for_key('extended', 'result') - def _create_tables(self): - """Create the task and taskset tables.""" - self.ResultSession() - - def ResultSession(self, session_manager=None): - if session_manager is None: - session_manager = self.session_manager + def ResultSession(self, session_manager=SessionManager()): return session_manager.session_factory( dburi=self.url, short_lived_sessions=self.short_lived_sessions, diff --git a/venv/lib/python3.12/site-packages/celery/backends/database/models.py b/venv/lib/python3.12/site-packages/celery/backends/database/models.py index a5df8f4..1c766b5 100644 --- a/venv/lib/python3.12/site-packages/celery/backends/database/models.py +++ b/venv/lib/python3.12/site-packages/celery/backends/database/models.py @@ -1,5 +1,5 @@ """Database models used by the SQLAlchemy result store backend.""" -from datetime import datetime, timezone +from datetime import datetime import sqlalchemy as sa from sqlalchemy.types import PickleType @@ -22,8 +22,8 @@ class Task(ResultModelBase): task_id = sa.Column(sa.String(155), unique=True) status = sa.Column(sa.String(50), default=states.PENDING) result = sa.Column(PickleType, nullable=True) - date_done = sa.Column(sa.DateTime, default=datetime.now(timezone.utc), - onupdate=datetime.now(timezone.utc), nullable=True) + date_done = sa.Column(sa.DateTime, default=datetime.utcnow, + onupdate=datetime.utcnow, nullable=True) traceback = sa.Column(sa.Text, nullable=True) def __init__(self, task_id): @@ -84,7 +84,7 @@ class TaskSet(ResultModelBase): autoincrement=True, primary_key=True) taskset_id = sa.Column(sa.String(155), unique=True) result = sa.Column(PickleType, nullable=True) - date_done = sa.Column(sa.DateTime, default=datetime.now(timezone.utc), + date_done = sa.Column(sa.DateTime, default=datetime.utcnow, nullable=True) def __init__(self, taskset_id, result): diff --git a/venv/lib/python3.12/site-packages/celery/backends/dynamodb.py b/venv/lib/python3.12/site-packages/celery/backends/dynamodb.py index 0423a46..90fbae0 100644 --- a/venv/lib/python3.12/site-packages/celery/backends/dynamodb.py +++ b/venv/lib/python3.12/site-packages/celery/backends/dynamodb.py @@ -1,8 +1,6 @@ """AWS DynamoDB result store backend.""" from collections import namedtuple -from ipaddress import ip_address from time import sleep, time -from typing import Any, Dict from kombu.utils.url import _parse_url as parse_url @@ -56,15 +54,11 @@ class DynamoDBBackend(KeyValueStoreBackend): supports_autoexpire = True _key_field = DynamoDBAttribute(name='id', data_type='S') - # Each record has either a value field or count field _value_field = DynamoDBAttribute(name='result', data_type='B') - _count_filed = DynamoDBAttribute(name="chord_count", 
data_type='N') _timestamp_field = DynamoDBAttribute(name='timestamp', data_type='N') _ttl_field = DynamoDBAttribute(name='ttl', data_type='N') _available_fields = None - implements_incr = True - def __init__(self, url=None, table_name=None, *args, **kwargs): super().__init__(*args, **kwargs) @@ -97,9 +91,9 @@ class DynamoDBBackend(KeyValueStoreBackend): aws_credentials_given = access_key_given - if region == 'localhost' or DynamoDBBackend._is_valid_ip(region): + if region == 'localhost': # We are using the downloadable, local version of DynamoDB - self.endpoint_url = f'http://{region}:{port}' + self.endpoint_url = f'http://localhost:{port}' self.aws_region = 'us-east-1' logger.warning( 'Using local-only DynamoDB endpoint URL: {}'.format( @@ -154,14 +148,6 @@ class DynamoDBBackend(KeyValueStoreBackend): secret_access_key=aws_secret_access_key ) - @staticmethod - def _is_valid_ip(ip): - try: - ip_address(ip) - return True - except ValueError: - return False - def _get_client(self, access_key_id=None, secret_access_key=None): """Get client connection.""" if self._client is None: @@ -473,40 +459,6 @@ class DynamoDBBackend(KeyValueStoreBackend): }) return put_request - def _prepare_init_count_request(self, key: str) -> Dict[str, Any]: - """Construct the counter initialization request parameters""" - timestamp = time() - return { - 'TableName': self.table_name, - 'Item': { - self._key_field.name: { - self._key_field.data_type: key - }, - self._count_filed.name: { - self._count_filed.data_type: "0" - }, - self._timestamp_field.name: { - self._timestamp_field.data_type: str(timestamp) - } - } - } - - def _prepare_inc_count_request(self, key: str) -> Dict[str, Any]: - """Construct the counter increment request parameters""" - return { - 'TableName': self.table_name, - 'Key': { - self._key_field.name: { - self._key_field.data_type: key - } - }, - 'UpdateExpression': f"set {self._count_filed.name} = {self._count_filed.name} + :num", - "ExpressionAttributeValues": { - ":num": {"N": "1"}, - }, - "ReturnValues": "UPDATED_NEW", - } - def _item_to_dict(self, raw_response): """Convert get_item() response to field-value pairs.""" if 'Item' not in raw_response: @@ -539,18 +491,3 @@ class DynamoDBBackend(KeyValueStoreBackend): key = str(key) request_parameters = self._prepare_get_request(key) self.client.delete_item(**request_parameters) - - def incr(self, key: bytes) -> int: - """Atomically increase the chord_count and return the new count""" - key = str(key) - request_parameters = self._prepare_inc_count_request(key) - item_response = self.client.update_item(**request_parameters) - new_count: str = item_response["Attributes"][self._count_filed.name][self._count_filed.data_type] - return int(new_count) - - def _apply_chord_incr(self, header_result_args, body, **kwargs): - chord_key = self.get_key_for_chord(header_result_args[0]) - init_count_request = self._prepare_init_count_request(str(chord_key)) - self.client.put_item(**init_count_request) - return super()._apply_chord_incr( - header_result_args, body, **kwargs) diff --git a/venv/lib/python3.12/site-packages/celery/backends/elasticsearch.py b/venv/lib/python3.12/site-packages/celery/backends/elasticsearch.py index 9e6f265..5448129 100644 --- a/venv/lib/python3.12/site-packages/celery/backends/elasticsearch.py +++ b/venv/lib/python3.12/site-packages/celery/backends/elasticsearch.py @@ -1,5 +1,5 @@ """Elasticsearch result store backend.""" -from datetime import datetime, timezone +from datetime import datetime from kombu.utils.encoding import bytes_to_str 
from kombu.utils.url import _parse_url @@ -14,11 +14,6 @@ try: except ImportError: elasticsearch = None -try: - import elastic_transport -except ImportError: - elastic_transport = None - __all__ = ('ElasticsearchBackend',) E_LIB_MISSING = """\ @@ -36,7 +31,7 @@ class ElasticsearchBackend(KeyValueStoreBackend): """ index = 'celery' - doc_type = None + doc_type = 'backend' scheme = 'http' host = 'localhost' port = 9200 @@ -88,17 +83,17 @@ class ElasticsearchBackend(KeyValueStoreBackend): self._server = None def exception_safe_to_retry(self, exc): - if isinstance(exc, elasticsearch.exceptions.ApiError): + if isinstance(exc, (elasticsearch.exceptions.TransportError)): # 401: Unauthorized # 409: Conflict + # 429: Too Many Requests # 500: Internal Server Error # 502: Bad Gateway + # 503: Service Unavailable # 504: Gateway Timeout # N/A: Low level exception (i.e. socket exception) - if exc.status_code in {401, 409, 500, 502, 504, 'N/A'}: + if exc.status_code in {401, 409, 429, 500, 502, 503, 504, 'N/A'}: return True - if isinstance(exc, elasticsearch.exceptions.TransportError): - return True return False def get(self, key): @@ -113,23 +108,17 @@ class ElasticsearchBackend(KeyValueStoreBackend): pass def _get(self, key): - if self.doc_type: - return self.server.get( - index=self.index, - id=key, - doc_type=self.doc_type, - ) - else: - return self.server.get( - index=self.index, - id=key, - ) + return self.server.get( + index=self.index, + doc_type=self.doc_type, + id=key, + ) def _set_with_state(self, key, value, state): body = { 'result': value, '@timestamp': '{}Z'.format( - datetime.now(timezone.utc).isoformat()[:-9] + datetime.utcnow().isoformat()[:-3] ), } try: @@ -146,23 +135,14 @@ class ElasticsearchBackend(KeyValueStoreBackend): def _index(self, id, body, **kwargs): body = {bytes_to_str(k): v for k, v in body.items()} - if self.doc_type: - return self.server.index( - id=bytes_to_str(id), - index=self.index, - doc_type=self.doc_type, - body=body, - params={'op_type': 'create'}, - **kwargs - ) - else: - return self.server.index( - id=bytes_to_str(id), - index=self.index, - body=body, - params={'op_type': 'create'}, - **kwargs - ) + return self.server.index( + id=bytes_to_str(id), + index=self.index, + doc_type=self.doc_type, + body=body, + params={'op_type': 'create'}, + **kwargs + ) def _update(self, id, body, state, **kwargs): """Update state in a conflict free manner. 
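A note on the _set_with_state hunk above: the timezone-aware and the naive '@timestamp' expressions aim at the same millisecond-precision UTC string with a trailing 'Z'; they differ only in how much of the suffix is sliced off. A minimal sketch of that equivalence (illustration only, not part of the patch; it assumes a nonzero microsecond component, since isoformat() omits the fraction entirely when microsecond == 0, which would make either slice cut into the time fields):

# Illustration of the two '@timestamp' slices shown in the hunk above.
from datetime import datetime, timezone

ts = datetime(2025, 1, 2, 3, 4, 5, 123456, tzinfo=timezone.utc)

aware = '{}Z'.format(ts.isoformat()[:-9])                       # strips '456+00:00'
naive = '{}Z'.format(ts.replace(tzinfo=None).isoformat()[:-3])  # strips '456'
assert aware == naive == '2025-01-02T03:04:05.123Z'
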
@@ -202,32 +182,19 @@ class ElasticsearchBackend(KeyValueStoreBackend): prim_term = res_get.get('_primary_term', 1) # try to update document with current seq_no and primary_term - if self.doc_type: - res = self.server.update( - id=bytes_to_str(id), - index=self.index, - doc_type=self.doc_type, - body={'doc': body}, - params={'if_primary_term': prim_term, 'if_seq_no': seq_no}, - **kwargs - ) - else: - res = self.server.update( - id=bytes_to_str(id), - index=self.index, - body={'doc': body}, - params={'if_primary_term': prim_term, 'if_seq_no': seq_no}, - **kwargs - ) + res = self.server.update( + id=bytes_to_str(id), + index=self.index, + doc_type=self.doc_type, + body={'doc': body}, + params={'if_primary_term': prim_term, 'if_seq_no': seq_no}, + **kwargs + ) # result is elastic search update query result # noop = query did not update any document # updated = at least one document got updated if res['result'] == 'noop': - raise elasticsearch.exceptions.ConflictError( - "conflicting update occurred concurrently", - elastic_transport.ApiResponseMeta(409, "HTTP/1.1", - elastic_transport.HttpHeaders(), 0, elastic_transport.NodeConfig( - self.scheme, self.host, self.port)), None) + raise elasticsearch.exceptions.ConflictError(409, 'conflicting update occurred concurrently', {}) return res def encode(self, data): @@ -258,10 +225,7 @@ class ElasticsearchBackend(KeyValueStoreBackend): return [self.get(key) for key in keys] def delete(self, key): - if self.doc_type: - self.server.delete(index=self.index, id=key, doc_type=self.doc_type) - else: - self.server.delete(index=self.index, id=key) + self.server.delete(index=self.index, doc_type=self.doc_type, id=key) def _get_server(self): """Connect to the Elasticsearch server.""" @@ -269,10 +233,11 @@ class ElasticsearchBackend(KeyValueStoreBackend): if self.username and self.password: http_auth = (self.username, self.password) return elasticsearch.Elasticsearch( - f'{self.scheme}://{self.host}:{self.port}', + f'{self.host}:{self.port}', retry_on_timeout=self.es_retry_on_timeout, max_retries=self.es_max_retries, timeout=self.es_timeout, + scheme=self.scheme, http_auth=http_auth, ) diff --git a/venv/lib/python3.12/site-packages/celery/backends/filesystem.py b/venv/lib/python3.12/site-packages/celery/backends/filesystem.py index 1a624f3..22fd5dc 100644 --- a/venv/lib/python3.12/site-packages/celery/backends/filesystem.py +++ b/venv/lib/python3.12/site-packages/celery/backends/filesystem.py @@ -50,7 +50,7 @@ class FilesystemBackend(KeyValueStoreBackend): self.open = open self.unlink = unlink - # Let's verify that we've everything setup right + # Lets verify that we've everything setup right self._do_directory_test(b'.fs-backend-' + uuid().encode(encoding)) def __reduce__(self, args=(), kwargs=None): diff --git a/venv/lib/python3.12/site-packages/celery/backends/gcs.py b/venv/lib/python3.12/site-packages/celery/backends/gcs.py deleted file mode 100644 index d667a9c..0000000 --- a/venv/lib/python3.12/site-packages/celery/backends/gcs.py +++ /dev/null @@ -1,352 +0,0 @@ -"""Google Cloud Storage result store backend for Celery.""" -from concurrent.futures import ThreadPoolExecutor -from datetime import datetime, timedelta -from os import getpid -from threading import RLock - -from kombu.utils.encoding import bytes_to_str -from kombu.utils.functional import dictfilter -from kombu.utils.url import url_to_parts - -from celery.canvas import maybe_signature -from celery.exceptions import ChordError, ImproperlyConfigured -from celery.result import GroupResult, 
allow_join_result -from celery.utils.log import get_logger - -from .base import KeyValueStoreBackend - -try: - import requests - from google.api_core import retry - from google.api_core.exceptions import Conflict - from google.api_core.retry import if_exception_type - from google.cloud import storage - from google.cloud.storage import Client - from google.cloud.storage.retry import DEFAULT_RETRY -except ImportError: - storage = None - -try: - from google.cloud import firestore, firestore_admin_v1 -except ImportError: - firestore = None - firestore_admin_v1 = None - - -__all__ = ('GCSBackend',) - - -logger = get_logger(__name__) - - -class GCSBackendBase(KeyValueStoreBackend): - """Google Cloud Storage task result backend.""" - - def __init__(self, **kwargs): - if not storage: - raise ImproperlyConfigured( - 'You must install google-cloud-storage to use gcs backend' - ) - super().__init__(**kwargs) - self._client_lock = RLock() - self._pid = getpid() - self._retry_policy = DEFAULT_RETRY - self._client = None - - conf = self.app.conf - if self.url: - url_params = self._params_from_url() - conf.update(**dictfilter(url_params)) - - self.bucket_name = conf.get('gcs_bucket') - if not self.bucket_name: - raise ImproperlyConfigured( - 'Missing bucket name: specify gcs_bucket to use gcs backend' - ) - self.project = conf.get('gcs_project') - if not self.project: - raise ImproperlyConfigured( - 'Missing project:specify gcs_project to use gcs backend' - ) - self.base_path = conf.get('gcs_base_path', '').strip('/') - self._threadpool_maxsize = int(conf.get('gcs_threadpool_maxsize', 10)) - self.ttl = float(conf.get('gcs_ttl') or 0) - if self.ttl < 0: - raise ImproperlyConfigured( - f'Invalid ttl: {self.ttl} must be greater than or equal to 0' - ) - elif self.ttl: - if not self._is_bucket_lifecycle_rule_exists(): - raise ImproperlyConfigured( - f'Missing lifecycle rule to use gcs backend with ttl on ' - f'bucket: {self.bucket_name}' - ) - - def get(self, key): - key = bytes_to_str(key) - blob = self._get_blob(key) - try: - return blob.download_as_bytes(retry=self._retry_policy) - except storage.blob.NotFound: - return None - - def set(self, key, value): - key = bytes_to_str(key) - blob = self._get_blob(key) - if self.ttl: - blob.custom_time = datetime.utcnow() + timedelta(seconds=self.ttl) - blob.upload_from_string(value, retry=self._retry_policy) - - def delete(self, key): - key = bytes_to_str(key) - blob = self._get_blob(key) - if blob.exists(): - blob.delete(retry=self._retry_policy) - - def mget(self, keys): - with ThreadPoolExecutor() as pool: - return list(pool.map(self.get, keys)) - - @property - def client(self): - """Returns a storage client.""" - - # make sure it's thread-safe, as creating a new client is expensive - with self._client_lock: - if self._client and self._pid == getpid(): - return self._client - # make sure each process gets its own connection after a fork - self._client = Client(project=self.project) - self._pid = getpid() - - # config the number of connections to the server - adapter = requests.adapters.HTTPAdapter( - pool_connections=self._threadpool_maxsize, - pool_maxsize=self._threadpool_maxsize, - max_retries=3, - ) - client_http = self._client._http - client_http.mount("https://", adapter) - client_http._auth_request.session.mount("https://", adapter) - - return self._client - - @property - def bucket(self): - return self.client.bucket(self.bucket_name) - - def _get_blob(self, key): - key_bucket_path = f'{self.base_path}/{key}' if self.base_path else key - return 
self.bucket.blob(key_bucket_path) - - def _is_bucket_lifecycle_rule_exists(self): - bucket = self.bucket - bucket.reload() - for rule in bucket.lifecycle_rules: - if rule['action']['type'] == 'Delete': - return True - return False - - def _params_from_url(self): - url_parts = url_to_parts(self.url) - - return { - 'gcs_bucket': url_parts.hostname, - 'gcs_base_path': url_parts.path, - **url_parts.query, - } - - -class GCSBackend(GCSBackendBase): - """Google Cloud Storage task result backend. - - Uses Firestore for chord ref count. - """ - - implements_incr = True - supports_native_join = True - - # Firestore parameters - _collection_name = 'celery' - _field_count = 'chord_count' - _field_expires = 'expires_at' - - def __init__(self, **kwargs): - if not (firestore and firestore_admin_v1): - raise ImproperlyConfigured( - 'You must install google-cloud-firestore to use gcs backend' - ) - super().__init__(**kwargs) - - self._firestore_lock = RLock() - self._firestore_client = None - - self.firestore_project = self.app.conf.get( - 'firestore_project', self.project - ) - if not self._is_firestore_ttl_policy_enabled(): - raise ImproperlyConfigured( - f'Missing TTL policy to use gcs backend with ttl on ' - f'Firestore collection: {self._collection_name} ' - f'project: {self.firestore_project}' - ) - - @property - def firestore_client(self): - """Returns a firestore client.""" - - # make sure it's thread-safe, as creating a new client is expensive - with self._firestore_lock: - if self._firestore_client and self._pid == getpid(): - return self._firestore_client - # make sure each process gets its own connection after a fork - self._firestore_client = firestore.Client( - project=self.firestore_project - ) - self._pid = getpid() - return self._firestore_client - - def _is_firestore_ttl_policy_enabled(self): - client = firestore_admin_v1.FirestoreAdminClient() - - name = ( - f"projects/{self.firestore_project}" - f"/databases/(default)/collectionGroups/{self._collection_name}" - f"/fields/{self._field_expires}" - ) - request = firestore_admin_v1.GetFieldRequest(name=name) - field = client.get_field(request=request) - - ttl_config = field.ttl_config - return ttl_config and ttl_config.state in { - firestore_admin_v1.Field.TtlConfig.State.ACTIVE, - firestore_admin_v1.Field.TtlConfig.State.CREATING, - } - - def _apply_chord_incr(self, header_result_args, body, **kwargs): - key = self.get_key_for_chord(header_result_args[0]).decode() - self._expire_chord_key(key, 86400) - return super()._apply_chord_incr(header_result_args, body, **kwargs) - - def incr(self, key: bytes) -> int: - doc = self._firestore_document(key) - resp = doc.set( - {self._field_count: firestore.Increment(1)}, - merge=True, - retry=retry.Retry( - predicate=if_exception_type(Conflict), - initial=1.0, - maximum=180.0, - multiplier=2.0, - timeout=180.0, - ), - ) - return resp.transform_results[0].integer_value - - def on_chord_part_return(self, request, state, result, **kwargs): - """Chord part return callback. - - Called for each task in the chord. - Increments the counter stored in Firestore. - If the counter reaches the number of tasks in the chord, the callback - is called. - If the callback raises an exception, the chord is marked as errored. - If the callback returns a value, the chord is marked as successful. 
- """ - app = self.app - gid = request.group - if not gid: - return - key = self.get_key_for_chord(gid) - val = self.incr(key) - size = request.chord.get("chord_size") - if size is None: - deps = self._restore_deps(gid, request) - if deps is None: - return - size = len(deps) - if val > size: # pragma: no cover - logger.warning( - 'Chord counter incremented too many times for %r', gid - ) - elif val == size: - # Read the deps once, to reduce the number of reads from GCS ($$) - deps = self._restore_deps(gid, request) - if deps is None: - return - callback = maybe_signature(request.chord, app=app) - j = deps.join_native - try: - with allow_join_result(): - ret = j( - timeout=app.conf.result_chord_join_timeout, - propagate=True, - ) - except Exception as exc: # pylint: disable=broad-except - try: - culprit = next(deps._failed_join_report()) - reason = 'Dependency {0.id} raised {1!r}'.format( - culprit, - exc, - ) - except StopIteration: - reason = repr(exc) - - logger.exception('Chord %r raised: %r', gid, reason) - self.chord_error_from_stack(callback, ChordError(reason)) - else: - try: - callback.delay(ret) - except Exception as exc: # pylint: disable=broad-except - logger.exception('Chord %r raised: %r', gid, exc) - self.chord_error_from_stack( - callback, - ChordError(f'Callback error: {exc!r}'), - ) - finally: - deps.delete() - # Firestore doesn't have an exact ttl policy, so delete the key. - self._delete_chord_key(key) - - def _restore_deps(self, gid, request): - app = self.app - try: - deps = GroupResult.restore(gid, backend=self) - except Exception as exc: # pylint: disable=broad-except - callback = maybe_signature(request.chord, app=app) - logger.exception('Chord %r raised: %r', gid, exc) - self.chord_error_from_stack( - callback, - ChordError(f'Cannot restore group: {exc!r}'), - ) - return - if deps is None: - try: - raise ValueError(gid) - except ValueError as exc: - callback = maybe_signature(request.chord, app=app) - logger.exception('Chord callback %r raised: %r', gid, exc) - self.chord_error_from_stack( - callback, - ChordError(f'GroupResult {gid} no longer exists'), - ) - return deps - - def _delete_chord_key(self, key): - doc = self._firestore_document(key) - doc.delete() - - def _expire_chord_key(self, key, expires): - """Set TTL policy for a Firestore document. - - Firestore ttl data is typically deleted within 24 hours after its - expiration date. 
- """ - val_expires = datetime.utcnow() + timedelta(seconds=expires) - doc = self._firestore_document(key) - doc.set({self._field_expires: val_expires}, merge=True) - - def _firestore_document(self, key): - return self.firestore_client.collection( - self._collection_name - ).document(bytes_to_str(key)) diff --git a/venv/lib/python3.12/site-packages/celery/backends/mongodb.py b/venv/lib/python3.12/site-packages/celery/backends/mongodb.py index 1789f6c..c64fe38 100644 --- a/venv/lib/python3.12/site-packages/celery/backends/mongodb.py +++ b/venv/lib/python3.12/site-packages/celery/backends/mongodb.py @@ -1,5 +1,5 @@ """MongoDB result store backend.""" -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta from kombu.exceptions import EncodeError from kombu.utils.objects import cached_property @@ -228,7 +228,7 @@ class MongoBackend(BaseBackend): meta = { '_id': group_id, 'result': self.encode([i.id for i in result]), - 'date_done': datetime.now(timezone.utc), + 'date_done': datetime.utcnow(), } self.group_collection.replace_one({'_id': group_id}, meta, upsert=True) return result diff --git a/venv/lib/python3.12/site-packages/celery/backends/redis.py b/venv/lib/python3.12/site-packages/celery/backends/redis.py index 3e3ef73..8acc608 100644 --- a/venv/lib/python3.12/site-packages/celery/backends/redis.py +++ b/venv/lib/python3.12/site-packages/celery/backends/redis.py @@ -359,11 +359,6 @@ class RedisBackend(BaseKeyValueStoreBackend, AsyncBackendMixin): connparams.update(query) return connparams - def exception_safe_to_retry(self, exc): - if isinstance(exc, self.connection_errors): - return True - return False - @cached_property def retry_policy(self): retry_policy = super().retry_policy diff --git a/venv/lib/python3.12/site-packages/celery/backends/rpc.py b/venv/lib/python3.12/site-packages/celery/backends/rpc.py index 927c7f5..399c1dc 100644 --- a/venv/lib/python3.12/site-packages/celery/backends/rpc.py +++ b/venv/lib/python3.12/site-packages/celery/backends/rpc.py @@ -222,7 +222,7 @@ class RPCBackend(base.Backend, AsyncBackendMixin): def on_out_of_band_result(self, task_id, message): # Callback called when a reply for a task is received, - # but we have no idea what to do with it. + # but we have no idea what do do with it. # Since the result is not pending, we put it in a separate # buffer: probably it will become pending later. 
if self.result_consumer: diff --git a/venv/lib/python3.12/site-packages/celery/beat.py b/venv/lib/python3.12/site-packages/celery/beat.py index 86ad837..76e4472 100644 --- a/venv/lib/python3.12/site-packages/celery/beat.py +++ b/venv/lib/python3.12/site-packages/celery/beat.py @@ -1,7 +1,6 @@ """The periodic task scheduler.""" import copy -import dbm import errno import heapq import os @@ -569,11 +568,11 @@ class PersistentScheduler(Scheduler): for _ in (1, 2): try: self._store['entries'] - except (KeyError, UnicodeDecodeError, TypeError): + except KeyError: # new schedule db try: self._store['entries'] = {} - except (KeyError, UnicodeDecodeError, TypeError) + dbm.error as exc: + except KeyError as exc: self._store = self._destroy_open_corrupted_schedule(exc) continue else: diff --git a/venv/lib/python3.12/site-packages/celery/bin/base.py b/venv/lib/python3.12/site-packages/celery/bin/base.py index 61cc37a..63a2895 100644 --- a/venv/lib/python3.12/site-packages/celery/bin/base.py +++ b/venv/lib/python3.12/site-packages/celery/bin/base.py @@ -4,10 +4,9 @@ import numbers from collections import OrderedDict from functools import update_wrapper from pprint import pformat -from typing import Any import click -from click import Context, ParamType +from click import ParamType from kombu.utils.objects import cached_property from celery._state import get_current_app @@ -171,37 +170,19 @@ class CeleryCommand(click.Command): formatter.write_dl(opts_group) -class DaemonOption(CeleryOption): - """Common daemonization option""" - - def __init__(self, *args, **kwargs): - super().__init__(args, - help_group=kwargs.pop("help_group", "Daemonization Options"), - callback=kwargs.pop("callback", self.daemon_setting), - **kwargs) - - def daemon_setting(self, ctx: Context, opt: CeleryOption, value: Any) -> Any: - """ - Try to fetch daemonization option from applications settings. - Use the daemon command name as prefix (eg. 
`worker` -> `worker_pidfile`) - """ - return value or getattr(ctx.obj.app.conf, f"{ctx.command.name}_{self.name}", None) - - class CeleryDaemonCommand(CeleryCommand): """Daemon commands.""" def __init__(self, *args, **kwargs): """Initialize a Celery command with common daemon options.""" super().__init__(*args, **kwargs) - self.params.extend(( - DaemonOption("--logfile", "-f", help="Log destination; defaults to stderr"), - DaemonOption("--pidfile", help="PID file path; defaults to no PID file"), - DaemonOption("--uid", help="Drops privileges to this user ID"), - DaemonOption("--gid", help="Drops privileges to this group ID"), - DaemonOption("--umask", help="Create files and directories with this umask"), - DaemonOption("--executable", help="Override path to the Python executable"), - )) + self.params.append(CeleryOption(('-f', '--logfile'), help_group="Daemonization Options", + help="Log destination; defaults to stderr")) + self.params.append(CeleryOption(('--pidfile',), help_group="Daemonization Options")) + self.params.append(CeleryOption(('--uid',), help_group="Daemonization Options")) + self.params.append(CeleryOption(('--gid',), help_group="Daemonization Options")) + self.params.append(CeleryOption(('--umask',), help_group="Daemonization Options")) + self.params.append(CeleryOption(('--executable',), help_group="Daemonization Options")) class CommaSeparatedList(ParamType): diff --git a/venv/lib/python3.12/site-packages/celery/bin/celery.py b/venv/lib/python3.12/site-packages/celery/bin/celery.py index 4ddf9c7..4aeed42 100644 --- a/venv/lib/python3.12/site-packages/celery/bin/celery.py +++ b/venv/lib/python3.12/site-packages/celery/bin/celery.py @@ -11,6 +11,7 @@ except ImportError: import click import click.exceptions +from click.types import ParamType from click_didyoumean import DYMGroup from click_plugins import with_plugins @@ -47,6 +48,34 @@ Unable to load celery application. {0}""") +class App(ParamType): + """Application option.""" + + name = "application" + + def convert(self, value, param, ctx): + try: + return find_app(value) + except ModuleNotFoundError as e: + if e.name != value: + exc = traceback.format_exc() + self.fail( + UNABLE_TO_LOAD_APP_ERROR_OCCURRED.format(value, exc) + ) + self.fail(UNABLE_TO_LOAD_APP_MODULE_NOT_FOUND.format(e.name)) + except AttributeError as e: + attribute_name = e.args[0].capitalize() + self.fail(UNABLE_TO_LOAD_APP_APP_MISSING.format(attribute_name)) + except Exception: + exc = traceback.format_exc() + self.fail( + UNABLE_TO_LOAD_APP_ERROR_OCCURRED.format(value, exc) + ) + + +APP = App() + + if sys.version_info >= (3, 10): _PLUGINS = entry_points(group='celery.commands') else: @@ -62,11 +91,7 @@ else: '--app', envvar='APP', cls=CeleryOption, - # May take either: a str when invoked from command line (Click), - # or a Celery object when invoked from inside Celery; hence the - # need to prevent Click from "processing" the Celery object and - # converting it into its str representation. 
- type=click.UNPROCESSED, + type=APP, help_group="Global Options") @click.option('-b', '--broker', @@ -135,26 +160,6 @@ def celery(ctx, app, broker, result_backend, loader, config, workdir, os.environ['CELERY_CONFIG_MODULE'] = config if skip_checks: os.environ['CELERY_SKIP_CHECKS'] = 'true' - - if isinstance(app, str): - try: - app = find_app(app) - except ModuleNotFoundError as e: - if e.name != app: - exc = traceback.format_exc() - ctx.fail( - UNABLE_TO_LOAD_APP_ERROR_OCCURRED.format(app, exc) - ) - ctx.fail(UNABLE_TO_LOAD_APP_MODULE_NOT_FOUND.format(e.name)) - except AttributeError as e: - attribute_name = e.args[0].capitalize() - ctx.fail(UNABLE_TO_LOAD_APP_APP_MISSING.format(attribute_name)) - except Exception: - exc = traceback.format_exc() - ctx.fail( - UNABLE_TO_LOAD_APP_ERROR_OCCURRED.format(app, exc) - ) - ctx.obj = CLIContext(app=app, no_color=no_color, workdir=workdir, quiet=quiet) diff --git a/venv/lib/python3.12/site-packages/celery/bin/control.py b/venv/lib/python3.12/site-packages/celery/bin/control.py index 38a917e..f7bba96 100644 --- a/venv/lib/python3.12/site-packages/celery/bin/control.py +++ b/venv/lib/python3.12/site-packages/celery/bin/control.py @@ -1,6 +1,5 @@ """The ``celery control``, ``. inspect`` and ``. status`` programs.""" from functools import partial -from typing import Literal import click from kombu.utils.json import dumps @@ -40,69 +39,18 @@ def _consume_arguments(meta, method, args): args[:] = args[i:] -def _compile_arguments(command, args): - meta = Panel.meta[command] +def _compile_arguments(action, args): + meta = Panel.meta[action] arguments = {} if meta.args: arguments.update({ - k: v for k, v in _consume_arguments(meta, command, args) + k: v for k, v in _consume_arguments(meta, action, args) }) if meta.variadic: arguments.update({meta.variadic: args}) return arguments -_RemoteControlType = Literal['inspect', 'control'] - - -def _verify_command_name(type_: _RemoteControlType, command: str) -> None: - choices = _get_commands_of_type(type_) - - if command not in choices: - command_listing = ", ".join(choices) - raise click.UsageError( - message=f'Command {command} not recognized. 
Available {type_} commands: {command_listing}', - ) - - -def _list_option(type_: _RemoteControlType): - def callback(ctx: click.Context, param, value) -> None: - if not value: - return - choices = _get_commands_of_type(type_) - - formatter = click.HelpFormatter() - - with formatter.section(f'{type_.capitalize()} Commands'): - command_list = [] - for command_name, info in choices.items(): - if info.signature: - command_preview = f'{command_name} {info.signature}' - else: - command_preview = command_name - command_list.append((command_preview, info.help)) - formatter.write_dl(command_list) - ctx.obj.echo(formatter.getvalue(), nl=False) - ctx.exit() - - return click.option( - '--list', - is_flag=True, - help=f'List available {type_} commands and exit.', - expose_value=False, - is_eager=True, - callback=callback, - ) - - -def _get_commands_of_type(type_: _RemoteControlType) -> dict: - command_name_info_pairs = [ - (name, info) for name, info in Panel.meta.items() - if info.type == type_ and info.visible - ] - return dict(sorted(command_name_info_pairs)) - - @click.command(cls=CeleryCommand) @click.option('-t', '--timeout', @@ -148,8 +96,10 @@ def status(ctx, timeout, destination, json, **kwargs): @click.command(cls=CeleryCommand, context_settings={'allow_extra_args': True}) -@click.argument('command') -@_list_option('inspect') +@click.argument("action", type=click.Choice([ + name for name, info in Panel.meta.items() + if info.type == 'inspect' and info.visible +])) @click.option('-t', '--timeout', cls=CeleryOption, @@ -171,19 +121,19 @@ def status(ctx, timeout, destination, json, **kwargs): help='Use json as output format.') @click.pass_context @handle_preload_options -def inspect(ctx, command, timeout, destination, json, **kwargs): - """Inspect the workers by sending them the COMMAND inspect command. +def inspect(ctx, action, timeout, destination, json, **kwargs): + """Inspect the worker at runtime. Availability: RabbitMQ (AMQP) and Redis transports. """ - _verify_command_name('inspect', command) callback = None if json else partial(_say_remote_command_reply, ctx, show_reply=True) - arguments = _compile_arguments(command, ctx.args) + arguments = _compile_arguments(action, ctx.args) inspect = ctx.obj.app.control.inspect(timeout=timeout, destination=destination, callback=callback) - replies = inspect._request(command, **arguments) + replies = inspect._request(action, + **arguments) if not replies: raise CeleryCommandException( @@ -203,8 +153,10 @@ def inspect(ctx, command, timeout, destination, json, **kwargs): @click.command(cls=CeleryCommand, context_settings={'allow_extra_args': True}) -@click.argument('command') -@_list_option('control') +@click.argument("action", type=click.Choice([ + name for name, info in Panel.meta.items() + if info.type == 'control' and info.visible +])) @click.option('-t', '--timeout', cls=CeleryOption, @@ -226,17 +178,16 @@ def inspect(ctx, command, timeout, destination, json, **kwargs): help='Use json as output format.') @click.pass_context @handle_preload_options -def control(ctx, command, timeout, destination, json): - """Send the COMMAND control command to the workers. +def control(ctx, action, timeout, destination, json): + """Workers remote control. Availability: RabbitMQ (AMQP), Redis, and MongoDB transports. 
""" - _verify_command_name('control', command) callback = None if json else partial(_say_remote_command_reply, ctx, show_reply=True) args = ctx.args - arguments = _compile_arguments(command, args) - replies = ctx.obj.app.control.broadcast(command, timeout=timeout, + arguments = _compile_arguments(action, args) + replies = ctx.obj.app.control.broadcast(action, timeout=timeout, destination=destination, callback=callback, reply=True, diff --git a/venv/lib/python3.12/site-packages/celery/canvas.py b/venv/lib/python3.12/site-packages/celery/canvas.py index da395c1..a4007f0 100644 --- a/venv/lib/python3.12/site-packages/celery/canvas.py +++ b/venv/lib/python3.12/site-packages/celery/canvas.py @@ -396,7 +396,7 @@ class Signature(dict): else: args, kwargs, options = self.args, self.kwargs, self.options # pylint: disable=too-many-function-args - # Works on this, as it's a property + # Borks on this, as it's a property return _apply(args, kwargs, **options) def _merge(self, args=None, kwargs=None, options=None, force=False): @@ -515,7 +515,7 @@ class Signature(dict): if group_index is not None: opts['group_index'] = group_index # pylint: disable=too-many-function-args - # Works on this, as it's a property. + # Borks on this, as it's a property. return self.AsyncResult(tid) _freeze = freeze @@ -958,8 +958,6 @@ class _chain(Signature): if isinstance(other, group): # unroll group with one member other = maybe_unroll_group(other) - if not isinstance(other, group): - return self.__or__(other) # chain | group() -> chain tasks = self.unchain_tasks() if not tasks: @@ -974,20 +972,15 @@ class _chain(Signature): tasks, other), app=self._app) elif isinstance(other, _chain): # chain | chain -> chain - return reduce(operator.or_, other.unchain_tasks(), self) + # use type(self) for _chain subclasses + return type(self)(seq_concat_seq( + self.unchain_tasks(), other.unchain_tasks()), app=self._app) elif isinstance(other, Signature): if self.tasks and isinstance(self.tasks[-1], group): # CHAIN [last item is group] | TASK -> chord sig = self.clone() sig.tasks[-1] = chord( sig.tasks[-1], other, app=self._app) - # In the scenario where the second-to-last item in a chain is a chord, - # it leads to a situation where two consecutive chords are formed. - # In such cases, a further upgrade can be considered. - # This would involve chaining the body of the second-to-last chord with the last chord." - if len(sig.tasks) > 1 and isinstance(sig.tasks[-2], chord): - sig.tasks[-2].body = sig.tasks[-2].body | sig.tasks[-1] - sig.tasks = sig.tasks[:-1] return sig elif self.tasks and isinstance(self.tasks[-1], chord): # CHAIN [last item is chord] -> chain with chord body. @@ -1223,12 +1216,6 @@ class _chain(Signature): task, body=prev_task, root_id=root_id, app=app, ) - if tasks: - prev_task = tasks[-1] - prev_res = results[-1] - else: - prev_task = None - prev_res = None if is_last_task: # chain(task_id=id) means task id is set for the last task @@ -1274,7 +1261,6 @@ class _chain(Signature): while node.parent: node = node.parent prev_res = node - self.id = last_task_id return tasks, results def apply(self, args=None, kwargs=None, **options): @@ -1686,8 +1672,6 @@ class group(Signature): # # We return a concretised tuple of the signatures actually applied to # each child task signature, of which there might be none! 
- sig = maybe_signature(sig) - return tuple(child_task.link_error(sig.clone(immutable=True)) for child_task in self.tasks) def _prepared(self, tasks, partial_args, group_id, root_id, app, @@ -2287,8 +2271,6 @@ class _chord(Signature): ``False`` (the current default), then the error callback will only be applied to the body. """ - errback = maybe_signature(errback) - if self.app.conf.task_allow_error_cb_on_chord_header: for task in maybe_list(self.tasks) or []: task.link_error(errback.clone(immutable=True)) @@ -2307,13 +2289,6 @@ class _chord(Signature): CPendingDeprecationWarning ) - # Edge case for nested chords in the header - for task in maybe_list(self.tasks) or []: - if isinstance(task, chord): - # Let the nested chord do the error linking itself on its - # header and body where needed, based on the current configuration - task.link_error(errback) - self.body.link_error(errback) return errback diff --git a/venv/lib/python3.12/site-packages/celery/concurrency/asynpool.py b/venv/lib/python3.12/site-packages/celery/concurrency/asynpool.py index dd2f068..c024e68 100644 --- a/venv/lib/python3.12/site-packages/celery/concurrency/asynpool.py +++ b/venv/lib/python3.12/site-packages/celery/concurrency/asynpool.py @@ -103,35 +103,26 @@ def _get_job_writer(job): return writer() # is a weakref -def _ensure_integral_fd(fd): - return fd if isinstance(fd, Integral) else fd.fileno() - - if hasattr(select, 'poll'): def _select_imp(readers=None, writers=None, err=None, timeout=0, poll=select.poll, POLLIN=select.POLLIN, POLLOUT=select.POLLOUT, POLLERR=select.POLLERR): poller = poll() register = poller.register - fd_to_mask = {} if readers: - for fd in map(_ensure_integral_fd, readers): - fd_to_mask[fd] = fd_to_mask.get(fd, 0) | POLLIN + [register(fd, POLLIN) for fd in readers] if writers: - for fd in map(_ensure_integral_fd, writers): - fd_to_mask[fd] = fd_to_mask.get(fd, 0) | POLLOUT + [register(fd, POLLOUT) for fd in writers] if err: - for fd in map(_ensure_integral_fd, err): - fd_to_mask[fd] = fd_to_mask.get(fd, 0) | POLLERR - - for fd, event_mask in fd_to_mask.items(): - register(fd, event_mask) + [register(fd, POLLERR) for fd in err] R, W = set(), set() timeout = 0 if timeout and timeout < 0 else round(timeout * 1e3) events = poller.poll(timeout) for fd, event in events: + if not isinstance(fd, Integral): + fd = fd.fileno() if event & POLLIN: R.add(fd) if event & POLLOUT: @@ -203,7 +194,7 @@ def iterate_file_descriptors_safely(fds_iter, source_data, or possibly other reasons, so safely manage our lists of FDs. :param fds_iter: the file descriptors to iterate and apply hub_method :param source_data: data source to remove FD if it renders OSError - :param hub_method: the method to call with each fd and kwargs + :param hub_method: the method to call with with each fd and kwargs :*args to pass through to the hub_method; with a special syntax string '*fd*' represents a substitution for the current fd object in the iteration (for some callers). 
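One detail behind the _select_imp change earlier in this file's diff: the fd_to_mask bookkeeping being removed exists because select.poll replaces a descriptor's event mask when the same fd is registered again, rather than combining the masks. A short Unix-only sketch of that behaviour (illustration only, not part of the patch; the pipe is just a convenient source of descriptors):

# Re-registering the same fd overwrites its event mask; a single combined mask keeps both interests.
import os
import select

rfd, wfd = os.pipe()
poller = select.poll()
poller.register(wfd, select.POLLIN)
poller.register(wfd, select.POLLOUT)                   # wfd is now watched for POLLOUT only
poller.register(wfd, select.POLLIN | select.POLLOUT)   # combined mask restores both interests
os.close(rfd)
os.close(wfd)
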
@@ -781,7 +772,7 @@ class AsynPool(_pool.Pool): None, WRITE | ERR, consolidate=True) else: iterate_file_descriptors_safely( - inactive, all_inqueues, hub.remove_writer) + inactive, all_inqueues, hub_remove) self.on_poll_start = on_poll_start def on_inqueue_close(fd, proc): @@ -827,7 +818,7 @@ class AsynPool(_pool.Pool): # worker is already busy with another task continue if ready_fd not in all_inqueues: - hub.remove_writer(ready_fd) + hub_remove(ready_fd) continue try: job = pop_message() @@ -838,7 +829,7 @@ class AsynPool(_pool.Pool): # this may create a spinloop where the event loop # always wakes up. for inqfd in diff(active_writes): - hub.remove_writer(inqfd) + hub_remove(inqfd) break else: @@ -936,7 +927,7 @@ class AsynPool(_pool.Pool): else: errors = 0 finally: - hub.remove_writer(fd) + hub_remove(fd) write_stats[proc.index] += 1 # message written, so this fd is now available active_writes.discard(fd) diff --git a/venv/lib/python3.12/site-packages/celery/concurrency/gevent.py b/venv/lib/python3.12/site-packages/celery/concurrency/gevent.py index fd58e91..b0ea7e6 100644 --- a/venv/lib/python3.12/site-packages/celery/concurrency/gevent.py +++ b/venv/lib/python3.12/site-packages/celery/concurrency/gevent.py @@ -1,6 +1,4 @@ """Gevent execution pool.""" -import functools -import types from time import monotonic from kombu.asynchronous import timer as _timer @@ -18,22 +16,15 @@ __all__ = ('TaskPool',) # We cache globals and attribute lookups, so disable this warning. -def apply_target(target, args=(), kwargs=None, callback=None, - accept_callback=None, getpid=None, **_): - kwargs = {} if not kwargs else kwargs - return base.apply_target(target, args, kwargs, callback, accept_callback, - pid=getpid(), **_) - - def apply_timeout(target, args=(), kwargs=None, callback=None, - accept_callback=None, getpid=None, timeout=None, + accept_callback=None, pid=None, timeout=None, timeout_callback=None, Timeout=Timeout, apply_target=base.apply_target, **rest): kwargs = {} if not kwargs else kwargs try: with Timeout(timeout): return apply_target(target, args, kwargs, callback, - accept_callback, getpid(), + accept_callback, pid, propagate=(Timeout,), **rest) except Timeout: return timeout_callback(False, timeout) @@ -91,22 +82,18 @@ class TaskPool(base.BasePool): is_green = True task_join_will_block = False _pool = None - _pool_map = None _quick_put = None def __init__(self, *args, **kwargs): - from gevent import getcurrent, spawn_raw + from gevent import spawn_raw from gevent.pool import Pool self.Pool = Pool - self.getcurrent = getcurrent - self.getpid = lambda: id(getcurrent()) self.spawn_n = spawn_raw self.timeout = kwargs.get('timeout') super().__init__(*args, **kwargs) def on_start(self): self._pool = self.Pool(self.limit) - self._pool_map = {} self._quick_put = self._pool.spawn def on_stop(self): @@ -115,15 +102,12 @@ class TaskPool(base.BasePool): def on_apply(self, target, args=None, kwargs=None, callback=None, accept_callback=None, timeout=None, - timeout_callback=None, apply_target=apply_target, **_): + timeout_callback=None, apply_target=base.apply_target, **_): timeout = self.timeout if timeout is None else timeout - target = self._make_killable_target(target) - greenlet = self._quick_put(apply_timeout if timeout else apply_target, - target, args, kwargs, callback, accept_callback, - self.getpid, timeout=timeout, timeout_callback=timeout_callback) - self._add_to_pool_map(id(greenlet), greenlet) - greenlet.terminate = types.MethodType(_terminate, greenlet) - return greenlet + return 
self._quick_put(apply_timeout if timeout else apply_target, + target, args, kwargs, callback, accept_callback, + timeout=timeout, + timeout_callback=timeout_callback) def grow(self, n=1): self._pool._semaphore.counter += n @@ -133,39 +117,6 @@ class TaskPool(base.BasePool): self._pool._semaphore.counter -= n self._pool.size -= n - def terminate_job(self, pid, signal=None): - import gevent - - if pid in self._pool_map: - greenlet = self._pool_map[pid] - gevent.kill(greenlet) - @property def num_processes(self): return len(self._pool) - - @staticmethod - def _make_killable_target(target): - def killable_target(*args, **kwargs): - from greenlet import GreenletExit - try: - return target(*args, **kwargs) - except GreenletExit: - return (False, None, None) - - return killable_target - - def _add_to_pool_map(self, pid, greenlet): - self._pool_map[pid] = greenlet - greenlet.link( - functools.partial(self._cleanup_after_job_finish, pid=pid, pool_map=self._pool_map), - ) - - @staticmethod - def _cleanup_after_job_finish(greenlet, pool_map, pid): - del pool_map[pid] - - -def _terminate(self, signal): - # Done in `TaskPool.terminate_job` - pass diff --git a/venv/lib/python3.12/site-packages/celery/contrib/django/task.py b/venv/lib/python3.12/site-packages/celery/contrib/django/task.py deleted file mode 100644 index b0dc667..0000000 --- a/venv/lib/python3.12/site-packages/celery/contrib/django/task.py +++ /dev/null @@ -1,21 +0,0 @@ -import functools - -from django.db import transaction - -from celery.app.task import Task - - -class DjangoTask(Task): - """ - Extend the base :class:`~celery.app.task.Task` for Django. - - Provide a nicer API to trigger tasks at the end of the DB transaction. - """ - - def delay_on_commit(self, *args, **kwargs) -> None: - """Call :meth:`~celery.app.task.Task.delay` with Django's ``on_commit()``.""" - transaction.on_commit(functools.partial(self.delay, *args, **kwargs)) - - def apply_async_on_commit(self, *args, **kwargs) -> None: - """Call :meth:`~celery.app.task.Task.apply_async` with Django's ``on_commit()``.""" - transaction.on_commit(functools.partial(self.apply_async, *args, **kwargs)) diff --git a/venv/lib/python3.12/site-packages/celery/contrib/testing/worker.py b/venv/lib/python3.12/site-packages/celery/contrib/testing/worker.py index 46eac75..fa8f688 100644 --- a/venv/lib/python3.12/site-packages/celery/contrib/testing/worker.py +++ b/venv/lib/python3.12/site-packages/celery/contrib/testing/worker.py @@ -3,10 +3,10 @@ import logging import os import threading from contextlib import contextmanager -from typing import Any, Iterable, Optional, Union +from typing import Any, Iterable, Union # noqa import celery.worker.consumer # noqa -from celery import Celery, worker +from celery import Celery, worker # noqa from celery.result import _set_task_join_will_block, allow_join_result from celery.utils.dispatch import Signal from celery.utils.nodenames import anon_nodename @@ -30,10 +30,6 @@ test_worker_stopped = Signal( class TestWorkController(worker.WorkController): """Worker that can synchronize on being fully started.""" - # When this class is imported in pytest files, prevent pytest from thinking - # this is a test class - __test__ = False - logger_queue = None def __init__(self, *args, **kwargs): @@ -135,15 +131,16 @@ def start_worker( @contextmanager -def _start_worker_thread(app: Celery, - concurrency: int = 1, - pool: str = 'solo', - loglevel: Union[str, int] = WORKER_LOGLEVEL, - logfile: Optional[str] = None, - WorkController: Any = TestWorkController, - 
perform_ping_check: bool = True, - shutdown_timeout: float = 10.0, - **kwargs) -> Iterable[worker.WorkController]: +def _start_worker_thread(app, + concurrency=1, + pool='solo', + loglevel=WORKER_LOGLEVEL, + logfile=None, + WorkController=TestWorkController, + perform_ping_check=True, + shutdown_timeout=10.0, + **kwargs): + # type: (Celery, int, str, Union[str, int], str, Any, **Any) -> Iterable """Start Celery worker in a thread. Yields: @@ -159,7 +156,7 @@ def _start_worker_thread(app: Celery, worker = WorkController( app=app, concurrency=concurrency, - hostname=kwargs.pop("hostname", anon_nodename()), + hostname=anon_nodename(), pool=pool, loglevel=loglevel, logfile=logfile, @@ -214,7 +211,8 @@ def _start_worker_process(app, cluster.stopwait() -def setup_app_for_worker(app: Celery, loglevel: Union[str, int], logfile: str) -> None: +def setup_app_for_worker(app, loglevel, logfile) -> None: + # type: (Celery, Union[str, int], str) -> None """Setup the app to be used for starting an embedded worker.""" app.finalize() app.set_current() diff --git a/venv/lib/python3.12/site-packages/celery/events/event.py b/venv/lib/python3.12/site-packages/celery/events/event.py index fd2ee1e..a05ed70 100644 --- a/venv/lib/python3.12/site-packages/celery/events/event.py +++ b/venv/lib/python3.12/site-packages/celery/events/event.py @@ -55,7 +55,7 @@ def get_exchange(conn, name=EVENT_EXCHANGE_NAME): (from topic -> fanout). """ ex = copy(event_exchange) - if conn.transport.driver_type in {'redis', 'gcpubsub'}: + if conn.transport.driver_type == 'redis': # quick hack for Issue #436 ex.type = 'fanout' if name != ex.name: diff --git a/venv/lib/python3.12/site-packages/celery/fixups/django.py b/venv/lib/python3.12/site-packages/celery/fixups/django.py index b354994..473c3b6 100644 --- a/venv/lib/python3.12/site-packages/celery/fixups/django.py +++ b/venv/lib/python3.12/site-packages/celery/fixups/django.py @@ -2,7 +2,7 @@ import os import sys import warnings -from datetime import datetime, timezone +from datetime import datetime from importlib import import_module from typing import IO, TYPE_CHECKING, Any, List, Optional, cast @@ -16,7 +16,6 @@ if TYPE_CHECKING: from types import ModuleType from typing import Protocol - from django.db.backends.base.base import BaseDatabaseWrapper from django.db.utils import ConnectionHandler from celery.app.base import Celery @@ -79,9 +78,6 @@ class DjangoFixup: self._settings = symbol_by_name('django.conf:settings') self.app.loader.now = self.now - if not self.app._custom_task_cls_used: - self.app.task_cls = 'celery.contrib.django.task:DjangoTask' - signals.import_modules.connect(self.on_import_modules) signals.worker_init.connect(self.on_worker_init) return self @@ -104,7 +100,7 @@ class DjangoFixup: self.worker_fixup.install() def now(self, utc: bool = False) -> datetime: - return datetime.now(timezone.utc) if utc else self._now() + return datetime.utcnow() if utc else self._now() def autodiscover_tasks(self) -> List[str]: from django.apps import apps @@ -165,16 +161,15 @@ class DjangoWorkerFixup: # network IO that close() might cause. 
for c in self._db.connections.all(): if c and c.connection: - self._maybe_close_db_fd(c) + self._maybe_close_db_fd(c.connection) # use the _ version to avoid DB_REUSE preventing the conn.close() call self._close_database(force=True) self.close_cache() - def _maybe_close_db_fd(self, c: "BaseDatabaseWrapper") -> None: + def _maybe_close_db_fd(self, fd: IO) -> None: try: - with c.wrap_database_errors: - _maybe_close_fd(c.connection) + _maybe_close_fd(fd) except self.interface_errors: pass diff --git a/venv/lib/python3.12/site-packages/celery/loaders/base.py b/venv/lib/python3.12/site-packages/celery/loaders/base.py index 01e8425..aa7139c 100644 --- a/venv/lib/python3.12/site-packages/celery/loaders/base.py +++ b/venv/lib/python3.12/site-packages/celery/loaders/base.py @@ -3,7 +3,7 @@ import importlib import os import re import sys -from datetime import datetime, timezone +from datetime import datetime from kombu.utils import json from kombu.utils.objects import cached_property @@ -62,7 +62,7 @@ class BaseLoader: def now(self, utc=True): if utc: - return datetime.now(timezone.utc) + return datetime.utcnow() return datetime.now() def on_task_init(self, task_id, task): @@ -253,12 +253,10 @@ def find_related_module(package, related_name): # Django 1.7 allows for specifying a class name in INSTALLED_APPS. # (Issue #2248). try: - # Return package itself when no related_name. module = importlib.import_module(package) if not related_name and module: return module - except ModuleNotFoundError: - # On import error, try to walk package up one level. + except ImportError: package, _, _ = package.rpartition('.') if not package: raise @@ -266,13 +264,9 @@ def find_related_module(package, related_name): module_name = f'{package}.{related_name}' try: - # Try to find related_name under package. return importlib.import_module(module_name) - except ModuleNotFoundError as e: - import_exc_name = getattr(e, 'name', None) - # If candidate does not exist, then return None. - if import_exc_name and module_name == import_exc_name: - return - - # Otherwise, raise because error probably originated from a nested import. 
- raise e + except ImportError as e: + import_exc_name = getattr(e, 'name', module_name) + if import_exc_name is not None and import_exc_name != module_name: + raise e + return diff --git a/venv/lib/python3.12/site-packages/celery/local.py b/venv/lib/python3.12/site-packages/celery/local.py index 34eafff..7bbe615 100644 --- a/venv/lib/python3.12/site-packages/celery/local.py +++ b/venv/lib/python3.12/site-packages/celery/local.py @@ -397,6 +397,7 @@ COMPAT_MODULES = { }, 'log': { 'get_default_logger': 'log.get_default_logger', + 'setup_logger': 'log.setup_logger', 'setup_logging_subsystem': 'log.setup_logging_subsystem', 'redirect_stdouts_to_logger': 'log.redirect_stdouts_to_logger', }, diff --git a/venv/lib/python3.12/site-packages/celery/platforms.py b/venv/lib/python3.12/site-packages/celery/platforms.py index c0d0438..f424ac3 100644 --- a/venv/lib/python3.12/site-packages/celery/platforms.py +++ b/venv/lib/python3.12/site-packages/celery/platforms.py @@ -42,7 +42,7 @@ __all__ = ( 'DaemonContext', 'detached', 'parse_uid', 'parse_gid', 'setgroups', 'initgroups', 'setgid', 'setuid', 'maybe_drop_privileges', 'signals', 'signal_name', 'set_process_title', 'set_mp_process_title', - 'get_errno_name', 'ignore_errno', 'fd_by_path', 'isatty', + 'get_errno_name', 'ignore_errno', 'fd_by_path', ) # exitcodes @@ -95,14 +95,6 @@ SIGNAMES = { SIGMAP = {getattr(_signal, name): name for name in SIGNAMES} -def isatty(fh): - """Return true if the process has a controlling terminal.""" - try: - return fh.isatty() - except AttributeError: - pass - - def pyimplementation(): """Return string identifying the current Python implementation.""" if hasattr(_platform, 'python_implementation'): @@ -194,14 +186,10 @@ class Pidfile: if not pid: self.remove() return True - if pid == os.getpid(): - # this can be common in k8s pod with PID of 1 - don't kill - self.remove() - return True try: os.kill(pid, 0) - except OSError as exc: + except os.error as exc: if exc.errno == errno.ESRCH or exc.errno == errno.EPERM: print('Stale pidfile exists - Removing it.', file=sys.stderr) self.remove() diff --git a/venv/lib/python3.12/site-packages/celery/result.py b/venv/lib/python3.12/site-packages/celery/result.py index 75512c5..0c9e0a3 100644 --- a/venv/lib/python3.12/site-packages/celery/result.py +++ b/venv/lib/python3.12/site-packages/celery/result.py @@ -6,7 +6,6 @@ from collections import deque from contextlib import contextmanager from weakref import proxy -from dateutil.parser import isoparse from kombu.utils.objects import cached_property from vine import Thenable, barrier, promise @@ -533,7 +532,7 @@ class AsyncResult(ResultBase): """UTC date and time.""" date_done = self._get_task_meta().get('date_done') if date_done and not isinstance(date_done, datetime.datetime): - return isoparse(date_done) + return datetime.datetime.fromisoformat(date_done) return date_done @property @@ -984,14 +983,13 @@ class GroupResult(ResultSet): class EagerResult(AsyncResult): """Result that we know has already been executed.""" - def __init__(self, id, ret_value, state, traceback=None, name=None): + def __init__(self, id, ret_value, state, traceback=None): # pylint: disable=super-init-not-called # XXX should really not be inheriting from AsyncResult self.id = id self._result = ret_value self._state = state self._traceback = traceback - self._name = name self.on_ready = promise() self.on_ready(self) @@ -1044,7 +1042,6 @@ class EagerResult(AsyncResult): 'result': self._result, 'status': self._state, 'traceback': self._traceback, - 'name': 
self._name, } @property diff --git a/venv/lib/python3.12/site-packages/celery/schedules.py b/venv/lib/python3.12/site-packages/celery/schedules.py index 010b339..b35436a 100644 --- a/venv/lib/python3.12/site-packages/celery/schedules.py +++ b/venv/lib/python3.12/site-packages/celery/schedules.py @@ -4,8 +4,9 @@ from __future__ import annotations import re from bisect import bisect, bisect_left from collections import namedtuple +from collections.abc import Iterable from datetime import datetime, timedelta, tzinfo -from typing import Any, Callable, Iterable, Mapping, Sequence, Union +from typing import Any, Callable, Mapping, Sequence from kombu.utils.objects import cached_property @@ -14,7 +15,7 @@ from celery import Celery from . import current_app from .utils.collections import AttributeDict from .utils.time import (ffwd, humanize_seconds, localize, maybe_make_aware, maybe_timedelta, remaining, timezone, - weekday, yearmonth) + weekday) __all__ = ( 'ParseException', 'schedule', 'crontab', 'crontab_parser', @@ -51,10 +52,7 @@ Argument event "{event}" is invalid, must be one of {all_events}.\ """ -Cronspec = Union[int, str, Iterable[int]] - - -def cronfield(s: Cronspec | None) -> Cronspec: +def cronfield(s: str) -> str: return '*' if s is None else s @@ -302,12 +300,9 @@ class crontab_parser: i = int(s) except ValueError: try: - i = yearmonth(s) + i = weekday(s) except KeyError: - try: - i = weekday(s) - except KeyError: - raise ValueError(f'Invalid weekday literal {s!r}.') + raise ValueError(f'Invalid weekday literal {s!r}.') max_val = self.min_ + self.max_ - 1 if i > max_val: @@ -398,8 +393,8 @@ class crontab(BaseSchedule): present in ``month_of_year``. """ - def __init__(self, minute: Cronspec = '*', hour: Cronspec = '*', day_of_week: Cronspec = '*', - day_of_month: Cronspec = '*', month_of_year: Cronspec = '*', **kwargs: Any) -> None: + def __init__(self, minute: str = '*', hour: str = '*', day_of_week: str = '*', + day_of_month: str = '*', month_of_year: str = '*', **kwargs: Any) -> None: self._orig_minute = cronfield(minute) self._orig_hour = cronfield(hour) self._orig_day_of_week = cronfield(day_of_week) @@ -413,26 +408,9 @@ class crontab(BaseSchedule): self.month_of_year = self._expand_cronspec(month_of_year, 12, 1) super().__init__(**kwargs) - @classmethod - def from_string(cls, crontab: str) -> crontab: - """ - Create a Crontab from a cron expression string. For example ``crontab.from_string('* * * * *')``. - - .. code-block:: text - - ┌───────────── minute (0–59) - │ ┌───────────── hour (0–23) - │ │ ┌───────────── day of the month (1–31) - │ │ │ ┌───────────── month (1–12) - │ │ │ │ ┌───────────── day of the week (0–6) (Sunday to Saturday) - * * * * * - """ - minute, hour, day_of_month, month_of_year, day_of_week = crontab.split(" ") - return cls(minute, hour, day_of_week, day_of_month, month_of_year) - @staticmethod def _expand_cronspec( - cronspec: Cronspec, + cronspec: int | str | Iterable, max_: int, min_: int = 0) -> set[Any]: """Expand cron specification. 
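For context on the classmethod removed in the schedules.py hunk above: crontab.from_string() only reorders the five standard cron fields (minute, hour, day of month, month, day of week) into crontab()'s keyword order. A sketch of that equivalence (illustration only, not part of the patch; from_string exists only on the newer Celery code that this patch reverts away from):

# Equivalent spellings of "30 2 * * 1-5" (02:30 on weekdays), assuming a Celery
# version that still provides crontab.from_string().
from celery.schedules import crontab

by_kwargs = crontab(minute='30', hour='2',
                    day_of_week='1-5', day_of_month='*', month_of_year='*')
# by_string = crontab.from_string('30 2 * * 1-5')
# assert by_string == by_kwargs
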
@@ -557,7 +535,7 @@ class crontab(BaseSchedule):
     def __repr__(self) -> str:
         return CRON_REPR.format(self)
 
-    def __reduce__(self) -> tuple[type, tuple[Cronspec, Cronspec, Cronspec, Cronspec, Cronspec], Any]:
+    def __reduce__(self) -> tuple[type, tuple[str, str, str, str, str], Any]:
         return (self.__class__, (self._orig_minute,
                                  self._orig_hour,
                                  self._orig_day_of_week,
diff --git a/venv/lib/python3.12/site-packages/celery/security/certificate.py b/venv/lib/python3.12/site-packages/celery/security/certificate.py
index edaa764..80398b3 100644
--- a/venv/lib/python3.12/site-packages/celery/security/certificate.py
+++ b/venv/lib/python3.12/site-packages/celery/security/certificate.py
@@ -43,7 +43,7 @@ class Certificate:
 
     def has_expired(self) -> bool:
         """Check if the certificate has expired."""
-        return datetime.datetime.now(datetime.timezone.utc) >= self._cert.not_valid_after_utc
+        return datetime.datetime.utcnow() >= self._cert.not_valid_after
 
     def get_pubkey(self) -> (
             DSAPublicKey | EllipticCurvePublicKey | Ed448PublicKey | Ed25519PublicKey | RSAPublicKey
diff --git a/venv/lib/python3.12/site-packages/celery/security/serialization.py b/venv/lib/python3.12/site-packages/celery/security/serialization.py
index 7b7dc12..c58ef90 100644
--- a/venv/lib/python3.12/site-packages/celery/security/serialization.py
+++ b/venv/lib/python3.12/site-packages/celery/security/serialization.py
@@ -11,11 +11,6 @@ from .utils import get_digest_algorithm, reraise_errors
 
 __all__ = ('SecureSerializer', 'register_auth')
 
-# Note: we guarantee that this value won't appear in the serialized data,
-# so we can use it as a separator.
-# If you change this value, make sure it's not present in the serialized data.
-DEFAULT_SEPARATOR = str_to_bytes("\x00\x01")
-
 
 class SecureSerializer:
     """Signed serializer."""
@@ -34,8 +29,7 @@ class SecureSerializer:
         assert self._cert is not None
         with reraise_errors('Unable to serialize: {0!r}', (Exception,)):
             content_type, content_encoding, body = dumps(
-                data, serializer=self._serializer)
-
+                bytes_to_str(data), serializer=self._serializer)
             # What we sign is the serialized body, not the body itself.
             # this way the receiver doesn't have to decode the contents
             # to verify the signature (and thus avoiding potential flaws
@@ -54,26 +48,43 @@ class SecureSerializer:
                                        payload['signer'],
                                        payload['body'])
             self._cert_store[signer].verify(body, signature, self._digest)
-        return loads(body, payload['content_type'],
+        return loads(bytes_to_str(body), payload['content_type'],
                      payload['content_encoding'], force=True)
 
     def _pack(self, body, content_type, content_encoding, signer, signature,
-              sep=DEFAULT_SEPARATOR):
+              sep=str_to_bytes('\x00\x01')):
         fields = sep.join(
-            ensure_bytes(s) for s in [b64encode(signer), b64encode(signature),
-                                      content_type, content_encoding, body]
+            ensure_bytes(s) for s in [signer, signature, content_type,
+                                      content_encoding, body]
         )
         return b64encode(fields)
 
-    def _unpack(self, payload, sep=DEFAULT_SEPARATOR):
+    def _unpack(self, payload, sep=str_to_bytes('\x00\x01')):
         raw_payload = b64decode(ensure_bytes(payload))
-        v = raw_payload.split(sep, maxsplit=4)
+        first_sep = raw_payload.find(sep)
+
+        signer = raw_payload[:first_sep]
+        signer_cert = self._cert_store[signer]
+
+        # shift 3 bits right to get signature length
+        # 2048bit rsa key has a signature length of 256
+        # 4096bit rsa key has a signature length of 512
+        sig_len = signer_cert.get_pubkey().key_size >> 3
+        sep_len = len(sep)
+        signature_start_position = first_sep + sep_len
+        signature_end_position = signature_start_position + sig_len
+        signature = raw_payload[
+            signature_start_position:signature_end_position
+        ]
+
+        v = raw_payload[signature_end_position + sep_len:].split(sep)
+
         return {
-            'signer': b64decode(v[0]),
-            'signature': b64decode(v[1]),
-            'content_type': bytes_to_str(v[2]),
-            'content_encoding': bytes_to_str(v[3]),
-            'body': v[4],
+            'signer': signer,
+            'signature': signature,
+            'content_type': bytes_to_str(v[0]),
+            'content_encoding': bytes_to_str(v[1]),
+            'body': bytes_to_str(v[2]),
         }
diff --git a/venv/lib/python3.12/site-packages/celery/utils/annotations.py b/venv/lib/python3.12/site-packages/celery/utils/annotations.py
deleted file mode 100644
index 38a549c..0000000
--- a/venv/lib/python3.12/site-packages/celery/utils/annotations.py
+++ /dev/null
@@ -1,49 +0,0 @@
-"""Code related to handling annotations."""
-
-import sys
-import types
-import typing
-from inspect import isclass
-
-
-def is_none_type(value: typing.Any) -> bool:
-    """Check if the given value is a NoneType."""
-    if sys.version_info < (3, 10):
-        # raise Exception('below 3.10', value, type(None))
-        return value is type(None)
-    return value == types.NoneType  # type: ignore[no-any-return]
-
-
-def get_optional_arg(annotation: typing.Any) -> typing.Any:
-    """Get the argument from an Optional[...] annotation, or None if it is no such annotation."""
-    origin = typing.get_origin(annotation)
-    if origin != typing.Union and (sys.version_info >= (3, 10) and origin != types.UnionType):
-        return None
-
-    union_args = typing.get_args(annotation)
-    if len(union_args) != 2:  # Union does _not_ have two members, so it's not an Optional
-        return None
-
-    has_none_arg = any(is_none_type(arg) for arg in union_args)
-    # There will always be at least one type arg, as we have already established that this is a Union with exactly
-    # two members, and both cannot be None (`Union[None, None]` does not work).
-    type_arg = next(arg for arg in union_args if not is_none_type(arg))  # pragma: no branch
-
-    if has_none_arg:
-        return type_arg
-    return None
-
-
-def annotation_is_class(annotation: typing.Any) -> bool:
-    """Test if a given annotation is a class that can be used in isinstance()/issubclass()."""
-    # isclass() returns True for generic type hints (e.g. `list[str]`) until Python 3.10.
-    # NOTE: The guard for Python 3.9 is because types.GenericAlias is only added in Python 3.9. This is not a problem
-    # as the syntax is added in the same version in the first place.
-    if (3, 9) <= sys.version_info < (3, 11) and isinstance(annotation, types.GenericAlias):
-        return False
-    return isclass(annotation)
-
-
-def annotation_issubclass(annotation: typing.Any, cls: type) -> bool:
-    """Test if a given annotation is of the given subclass."""
-    return annotation_is_class(annotation) and issubclass(annotation, cls)
diff --git a/venv/lib/python3.12/site-packages/celery/utils/collections.py b/venv/lib/python3.12/site-packages/celery/utils/collections.py
index 396ed81..6fb559a 100644
--- a/venv/lib/python3.12/site-packages/celery/utils/collections.py
+++ b/venv/lib/python3.12/site-packages/celery/utils/collections.py
@@ -595,7 +595,8 @@ class LimitedSet:
                 break  # oldest item hasn't expired yet
             self.pop()
 
-    def pop(self, default: Any = None) -> Any:
+    def pop(self, default=None) -> Any:
+        # type: (Any) -> Any
         """Remove and return the oldest item, or :const:`None` when empty."""
         while self._heap:
             _, item = heappop(self._heap)
diff --git a/venv/lib/python3.12/site-packages/celery/utils/dispatch/signal.py b/venv/lib/python3.12/site-packages/celery/utils/dispatch/signal.py
index ad8047e..0cfa612 100644
--- a/venv/lib/python3.12/site-packages/celery/utils/dispatch/signal.py
+++ b/venv/lib/python3.12/site-packages/celery/utils/dispatch/signal.py
@@ -54,9 +54,6 @@ def _boundmethod_safe_weakref(obj):
 def _make_lookup_key(receiver, sender, dispatch_uid):
     if dispatch_uid:
         return (dispatch_uid, _make_id(sender))
-    # Issue #9119 - retry-wrapped functions use the underlying function for dispatch_uid
-    elif hasattr(receiver, '_dispatch_uid'):
-        return (receiver._dispatch_uid, _make_id(sender))
     else:
         return (_make_id(receiver), _make_id(sender))
 
@@ -173,7 +170,6 @@ class Signal:  # pragma: no cover
                 # it up later with the original func id
                 options['dispatch_uid'] = _make_id(fun)
                 fun = _retry_receiver(fun)
-                fun._dispatch_uid = options['dispatch_uid']
 
             self._connect_signal(fun, sender, options['weak'],
                                  options['dispatch_uid'])
diff --git a/venv/lib/python3.12/site-packages/celery/utils/imports.py b/venv/lib/python3.12/site-packages/celery/utils/imports.py
index 676a451..390b22c 100644
--- a/venv/lib/python3.12/site-packages/celery/utils/imports.py
+++ b/venv/lib/python3.12/site-packages/celery/utils/imports.py
@@ -51,13 +51,8 @@ def instantiate(name, *args, **kwargs):
 @contextmanager
 def cwd_in_path():
     """Context adding the current working directory to sys.path."""
-    try:
-        cwd = os.getcwd()
-    except FileNotFoundError:
-        cwd = None
-    if not cwd:
-        yield
-    elif cwd in sys.path:
+    cwd = os.getcwd()
+    if cwd in sys.path:
         yield
     else:
         sys.path.insert(0, cwd)
diff --git a/venv/lib/python3.12/site-packages/celery/utils/iso8601.py b/venv/lib/python3.12/site-packages/celery/utils/iso8601.py
index f878bec..ffe342b 100644
--- a/venv/lib/python3.12/site-packages/celery/utils/iso8601.py
+++ b/venv/lib/python3.12/site-packages/celery/utils/iso8601.py
@@ -50,9 +50,9 @@ TIMEZONE_REGEX = re.compile(
 )
 
 
-def parse_iso8601(datestring: str) -> datetime:
+def parse_iso8601(datestring):
     """Parse and convert ISO-8601 string to datetime."""
-    warn("parse_iso8601", "v5.3", "v6", "datetime.datetime.fromisoformat or dateutil.parser.isoparse")
+    warn("parse_iso8601", "v5.3", "v6", "datetime.datetime.fromisoformat")
     m = ISO8601_REGEX.match(datestring)
     if not m:
         raise ValueError('unable to parse date string %r' % datestring)
diff --git a/venv/lib/python3.12/site-packages/celery/utils/log.py b/venv/lib/python3.12/site-packages/celery/utils/log.py
index f67a3dd..4e8fc11 100644
--- a/venv/lib/python3.12/site-packages/celery/utils/log.py
+++ b/venv/lib/python3.12/site-packages/celery/utils/log.py
@@ -37,7 +37,7 @@ base_logger = logger = _get_logger('celery')
 
 
 def set_in_sighandler(value):
-    """Set flag signifying that we're inside a signal handler."""
+    """Set flag signifiying that we're inside a signal handler."""
     global _in_sighandler
     _in_sighandler = value
 
diff --git a/venv/lib/python3.12/site-packages/celery/utils/nodenames.py b/venv/lib/python3.12/site-packages/celery/utils/nodenames.py
index 91509a4..b3d1a52 100644
--- a/venv/lib/python3.12/site-packages/celery/utils/nodenames.py
+++ b/venv/lib/python3.12/site-packages/celery/utils/nodenames.py
@@ -1,6 +1,4 @@
 """Worker name utilities."""
-from __future__ import annotations
-
 import os
 import socket
 from functools import partial
@@ -24,18 +22,13 @@ NODENAME_DEFAULT = 'celery'
 gethostname = memoize(1, Cache=dict)(socket.gethostname)
 
 __all__ = (
-    'worker_direct',
-    'gethostname',
-    'nodename',
-    'anon_nodename',
-    'nodesplit',
-    'default_nodename',
-    'node_format',
-    'host_format',
+    'worker_direct', 'gethostname', 'nodename',
+    'anon_nodename', 'nodesplit', 'default_nodename',
+    'node_format', 'host_format',
 )
 
 
-def worker_direct(hostname: str | Queue) -> Queue:
+def worker_direct(hostname):
     """Return the :class:`kombu.Queue` being a direct route to a worker.
 
     Arguments:
@@ -53,20 +46,21 @@ def worker_direct(hostname: str | Queue) -> Queue:
     )
 
 
-def nodename(name: str, hostname: str) -> str:
+def nodename(name, hostname):
     """Create node name from name/hostname pair."""
     return NODENAME_SEP.join((name, hostname))
 
 
-def anon_nodename(hostname: str | None = None, prefix: str = 'gen') -> str:
+def anon_nodename(hostname=None, prefix='gen'):
     """Return the nodename for this process (not a worker).
 
     This is used for e.g. the origin task message field.
""" - return nodename(''.join([prefix, str(os.getpid())]), hostname or gethostname()) + return nodename(''.join([prefix, str(os.getpid())]), + hostname or gethostname()) -def nodesplit(name: str) -> tuple[None, str] | list[str]: +def nodesplit(name): """Split node name into tuple of name/hostname.""" parts = name.split(NODENAME_SEP, 1) if len(parts) == 1: @@ -74,21 +68,21 @@ def nodesplit(name: str) -> tuple[None, str] | list[str]: return parts -def default_nodename(hostname: str) -> str: +def default_nodename(hostname): """Return the default nodename for this process.""" name, host = nodesplit(hostname or '') return nodename(name or NODENAME_DEFAULT, host or gethostname()) -def node_format(s: str, name: str, **extra: dict) -> str: +def node_format(s, name, **extra): """Format worker node name (name@host.com).""" shortname, host = nodesplit(name) - return host_format(s, host, shortname or NODENAME_DEFAULT, p=name, **extra) + return host_format( + s, host, shortname or NODENAME_DEFAULT, p=name, **extra) -def _fmt_process_index(prefix: str = '', default: str = '0') -> str: +def _fmt_process_index(prefix='', default='0'): from .log import current_process_index - index = current_process_index() return f'{prefix}{index}' if index else default @@ -96,19 +90,13 @@ def _fmt_process_index(prefix: str = '', default: str = '0') -> str: _fmt_process_index_with_prefix = partial(_fmt_process_index, '-', '') -def host_format(s: str, host: str | None = None, name: str | None = None, **extra: dict) -> str: +def host_format(s, host=None, name=None, **extra): """Format host %x abbreviations.""" host = host or gethostname() hname, _, domain = host.partition('.') name = name or hname - keys = dict( - { - 'h': host, - 'n': name, - 'd': domain, - 'i': _fmt_process_index, - 'I': _fmt_process_index_with_prefix, - }, - **extra, - ) + keys = dict({ + 'h': host, 'n': name, 'd': domain, + 'i': _fmt_process_index, 'I': _fmt_process_index_with_prefix, + }, **extra) return simple_format(s, keys) diff --git a/venv/lib/python3.12/site-packages/celery/utils/quorum_queues.py b/venv/lib/python3.12/site-packages/celery/utils/quorum_queues.py deleted file mode 100644 index 0eb058f..0000000 --- a/venv/lib/python3.12/site-packages/celery/utils/quorum_queues.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - - -def detect_quorum_queues(app, driver_type: str) -> tuple[bool, str]: - """Detect if any of the queues are quorum queues. - - Returns: - tuple[bool, str]: A tuple containing a boolean indicating if any of the queues are quorum queues - and the name of the first quorum queue found or an empty string if no quorum queues were found. 
- """ - is_rabbitmq_broker = driver_type == 'amqp' - - if is_rabbitmq_broker: - queues = app.amqp.queues - for qname in queues: - qarguments = queues[qname].queue_arguments or {} - if qarguments.get("x-queue-type") == "quorum": - return True, qname - - return False, "" diff --git a/venv/lib/python3.12/site-packages/celery/utils/saferepr.py b/venv/lib/python3.12/site-packages/celery/utils/saferepr.py index 9b37bc9..feddd41 100644 --- a/venv/lib/python3.12/site-packages/celery/utils/saferepr.py +++ b/venv/lib/python3.12/site-packages/celery/utils/saferepr.py @@ -15,7 +15,7 @@ from decimal import Decimal from itertools import chain from numbers import Number from pprint import _recursion -from typing import Any, AnyStr, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple # noqa +from typing import Any, AnyStr, Callable, Dict, Iterator, List, Sequence, Set, Tuple # noqa from .text import truncate @@ -41,7 +41,7 @@ _quoted = namedtuple('_quoted', ('value',)) #: Recursion protection. _dirty = namedtuple('_dirty', ('objid',)) -#: Types that are represented as chars. +#: Types that are repsented as chars. chars_t = (bytes, str) #: Types that are regarded as safe to call repr on. @@ -194,12 +194,9 @@ def _reprseq(val, lit_start, lit_end, builtin_type, chainer): ) -def reprstream(stack: deque, - seen: Optional[Set] = None, - maxlevels: int = 3, - level: int = 0, - isinstance: Callable = isinstance) -> Iterator[Any]: +def reprstream(stack, seen=None, maxlevels=3, level=0, isinstance=isinstance): """Streaming repr, yielding tokens.""" + # type: (deque, Set, int, int, Callable) -> Iterator[Any] seen = seen or set() append = stack.append popleft = stack.popleft diff --git a/venv/lib/python3.12/site-packages/celery/utils/sysinfo.py b/venv/lib/python3.12/site-packages/celery/utils/sysinfo.py index 52fc45e..57425dd 100644 --- a/venv/lib/python3.12/site-packages/celery/utils/sysinfo.py +++ b/venv/lib/python3.12/site-packages/celery/utils/sysinfo.py @@ -1,6 +1,4 @@ """System information utilities.""" -from __future__ import annotations - import os from math import ceil @@ -11,16 +9,16 @@ __all__ = ('load_average', 'df') if hasattr(os, 'getloadavg'): - def _load_average() -> tuple[float, ...]: + def _load_average(): return tuple(ceil(l * 1e2) / 1e2 for l in os.getloadavg()) else: # pragma: no cover # Windows doesn't have getloadavg - def _load_average() -> tuple[float, ...]: - return 0.0, 0.0, 0.0, + def _load_average(): + return (0.0, 0.0, 0.0) -def load_average() -> tuple[float, ...]: +def load_average(): """Return system load average as a triple.""" return _load_average() @@ -28,23 +26,23 @@ def load_average() -> tuple[float, ...]: class df: """Disk information.""" - def __init__(self, path: str | bytes | os.PathLike) -> None: + def __init__(self, path): self.path = path @property - def total_blocks(self) -> float: + def total_blocks(self): return self.stat.f_blocks * self.stat.f_frsize / 1024 @property - def available(self) -> float: + def available(self): return self.stat.f_bavail * self.stat.f_frsize / 1024 @property - def capacity(self) -> int: + def capacity(self): avail = self.stat.f_bavail used = self.stat.f_blocks - self.stat.f_bfree return int(ceil(used * 100.0 / (used + avail) + 0.5)) @cached_property - def stat(self) -> os.statvfs_result: + def stat(self): return os.statvfs(os.path.abspath(self.path)) diff --git a/venv/lib/python3.12/site-packages/celery/utils/term.py b/venv/lib/python3.12/site-packages/celery/utils/term.py index ba6a321..a2eff99 100644 --- 
a/venv/lib/python3.12/site-packages/celery/utils/term.py +++ b/venv/lib/python3.12/site-packages/celery/utils/term.py @@ -1,7 +1,6 @@ """Terminals and colors.""" -from __future__ import annotations - import base64 +import codecs import os import platform import sys @@ -9,8 +8,6 @@ from functools import reduce __all__ = ('colored',) -from typing import Any - BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8) OP_SEQ = '\033[%dm' RESET_SEQ = '\033[0m' @@ -29,7 +26,7 @@ _IMG_PRE = '\033Ptmux;\033\033]' if TERM_IS_SCREEN else '\033]' _IMG_POST = '\a\033\\' if TERM_IS_SCREEN else '\a' -def fg(s: int) -> str: +def fg(s): return COLOR_SEQ % s @@ -44,11 +41,11 @@ class colored: ... c.green('dog '))) """ - def __init__(self, *s: object, **kwargs: Any) -> None: - self.s: tuple[object, ...] = s - self.enabled: bool = not IS_WINDOWS and kwargs.get('enabled', True) - self.op: str = kwargs.get('op', '') - self.names: dict[str, Any] = { + def __init__(self, *s, **kwargs): + self.s = s + self.enabled = not IS_WINDOWS and kwargs.get('enabled', True) + self.op = kwargs.get('op', '') + self.names = { 'black': self.black, 'red': self.red, 'green': self.green, @@ -59,10 +56,10 @@ class colored: 'white': self.white, } - def _add(self, a: object, b: object) -> str: - return f"{a}{b}" + def _add(self, a, b): + return str(a) + str(b) - def _fold_no_color(self, a: Any, b: Any) -> str: + def _fold_no_color(self, a, b): try: A = a.no_color() except AttributeError: @@ -72,113 +69,109 @@ class colored: except AttributeError: B = str(b) - return f"{A}{B}" + return ''.join((str(A), str(B))) - def no_color(self) -> str: + def no_color(self): if self.s: return str(reduce(self._fold_no_color, self.s)) return '' - def embed(self) -> str: + def embed(self): prefix = '' if self.enabled: prefix = self.op - return f"{prefix}{reduce(self._add, self.s)}" + return ''.join((str(prefix), str(reduce(self._add, self.s)))) - def __str__(self) -> str: + def __str__(self): suffix = '' if self.enabled: suffix = RESET_SEQ - return f"{self.embed()}{suffix}" + return str(''.join((self.embed(), str(suffix)))) - def node(self, s: tuple[object, ...], op: str) -> colored: + def node(self, s, op): return self.__class__(enabled=self.enabled, op=op, *s) - def black(self, *s: object) -> colored: + def black(self, *s): return self.node(s, fg(30 + BLACK)) - def red(self, *s: object) -> colored: + def red(self, *s): return self.node(s, fg(30 + RED)) - def green(self, *s: object) -> colored: + def green(self, *s): return self.node(s, fg(30 + GREEN)) - def yellow(self, *s: object) -> colored: + def yellow(self, *s): return self.node(s, fg(30 + YELLOW)) - def blue(self, *s: object) -> colored: + def blue(self, *s): return self.node(s, fg(30 + BLUE)) - def magenta(self, *s: object) -> colored: + def magenta(self, *s): return self.node(s, fg(30 + MAGENTA)) - def cyan(self, *s: object) -> colored: + def cyan(self, *s): return self.node(s, fg(30 + CYAN)) - def white(self, *s: object) -> colored: + def white(self, *s): return self.node(s, fg(30 + WHITE)) - def __repr__(self) -> str: + def __repr__(self): return repr(self.no_color()) - def bold(self, *s: object) -> colored: + def bold(self, *s): return self.node(s, OP_SEQ % 1) - def underline(self, *s: object) -> colored: + def underline(self, *s): return self.node(s, OP_SEQ % 4) - def blink(self, *s: object) -> colored: + def blink(self, *s): return self.node(s, OP_SEQ % 5) - def reverse(self, *s: object) -> colored: + def reverse(self, *s): return self.node(s, OP_SEQ % 7) - def bright(self, *s: 
object) -> colored: + def bright(self, *s): return self.node(s, OP_SEQ % 8) - def ired(self, *s: object) -> colored: + def ired(self, *s): return self.node(s, fg(40 + RED)) - def igreen(self, *s: object) -> colored: + def igreen(self, *s): return self.node(s, fg(40 + GREEN)) - def iyellow(self, *s: object) -> colored: + def iyellow(self, *s): return self.node(s, fg(40 + YELLOW)) - def iblue(self, *s: colored) -> colored: + def iblue(self, *s): return self.node(s, fg(40 + BLUE)) - def imagenta(self, *s: object) -> colored: + def imagenta(self, *s): return self.node(s, fg(40 + MAGENTA)) - def icyan(self, *s: object) -> colored: + def icyan(self, *s): return self.node(s, fg(40 + CYAN)) - def iwhite(self, *s: object) -> colored: + def iwhite(self, *s): return self.node(s, fg(40 + WHITE)) - def reset(self, *s: object) -> colored: - return self.node(s or ('',), RESET_SEQ) + def reset(self, *s): + return self.node(s or [''], RESET_SEQ) - def __add__(self, other: object) -> str: - return f"{self}{other}" + def __add__(self, other): + return str(self) + str(other) -def supports_images() -> bool: - - try: - return sys.stdin.isatty() and bool(os.environ.get('ITERM_PROFILE')) - except AttributeError: - return False +def supports_images(): + return sys.stdin.isatty() and ITERM_PROFILE -def _read_as_base64(path: str) -> str: - with open(path, mode='rb') as fh: +def _read_as_base64(path): + with codecs.open(path, mode='rb') as fh: encoded = base64.b64encode(fh.read()) - return encoded.decode('ascii') + return encoded if isinstance(encoded, str) else encoded.decode('ascii') -def imgcat(path: str, inline: int = 1, preserve_aspect_ratio: int = 0, **kwargs: Any) -> str: +def imgcat(path, inline=1, preserve_aspect_ratio=0, **kwargs): return '\n%s1337;File=inline=%d;preserveAspectRatio=%d:%s%s' % ( _IMG_PRE, inline, preserve_aspect_ratio, _read_as_base64(path), _IMG_POST) diff --git a/venv/lib/python3.12/site-packages/celery/utils/time.py b/venv/lib/python3.12/site-packages/celery/utils/time.py index 2376bb3..f5329a5 100644 --- a/venv/lib/python3.12/site-packages/celery/utils/time.py +++ b/venv/lib/python3.12/site-packages/celery/utils/time.py @@ -14,7 +14,6 @@ from types import ModuleType from typing import Any, Callable from dateutil import tz as dateutil_tz -from dateutil.parser import isoparse from kombu.utils.functional import reprcall from kombu.utils.objects import cached_property @@ -41,9 +40,6 @@ C_REMDEBUG = os.environ.get('C_REMDEBUG', False) DAYNAMES = 'sun', 'mon', 'tue', 'wed', 'thu', 'fri', 'sat' WEEKDAYS = dict(zip(DAYNAMES, range(7))) -MONTHNAMES = 'jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul', 'aug', 'sep', 'oct', 'nov', 'dec' -YEARMONTHS = dict(zip(MONTHNAMES, range(1, 13))) - RATE_MODIFIER_MAP = { 's': lambda n: n, 'm': lambda n: n / 60.0, @@ -204,7 +200,7 @@ def delta_resolution(dt: datetime, delta: timedelta) -> datetime: def remaining( start: datetime, ends_in: timedelta, now: Callable | None = None, relative: bool = False) -> timedelta: - """Calculate the real remaining time for a start date and a timedelta. + """Calculate the remaining time for a start date and a timedelta. For example, "how many seconds left for 30 seconds after start?" @@ -215,28 +211,24 @@ def remaining( using :func:`delta_resolution` (i.e., rounded to the resolution of `ends_in`). now (Callable): Function returning the current time and date. - Defaults to :func:`datetime.now(timezone.utc)`. + Defaults to :func:`datetime.utcnow`. Returns: ~datetime.timedelta: Remaining time. 
""" - now = now or datetime.now(datetime_timezone.utc) + now = now or datetime.utcnow() + if str( + start.tzinfo) == str( + now.tzinfo) and now.utcoffset() != start.utcoffset(): + # DST started/ended + start = start.replace(tzinfo=now.tzinfo) end_date = start + ends_in if relative: end_date = delta_resolution(end_date, ends_in).replace(microsecond=0) - - # Using UTC to calculate real time difference. - # Python by default uses wall time in arithmetic between datetimes with - # equal non-UTC timezones. - now_utc = now.astimezone(timezone.utc) - end_date_utc = end_date.astimezone(timezone.utc) - ret = end_date_utc - now_utc + ret = end_date - now if C_REMDEBUG: # pragma: no cover - print( - 'rem: NOW:{!r} NOW_UTC:{!r} START:{!r} ENDS_IN:{!r} ' - 'END_DATE:{} END_DATE_UTC:{!r} REM:{}'.format( - now, now_utc, start, ends_in, end_date, end_date_utc, ret) - ) + print('rem: NOW:{!r} START:{!r} ENDS_IN:{!r} END_DATE:{} REM:{}'.format( + now, start, ends_in, end_date, ret)) return ret @@ -265,21 +257,6 @@ def weekday(name: str) -> int: raise KeyError(name) -def yearmonth(name: str) -> int: - """Return the position of a month: 1 - 12, where 1 is January. - - Example: - >>> yearmonth('january'), yearmonth('jan'), yearmonth('may') - (1, 1, 5) - """ - abbreviation = name[0:3].lower() - try: - return YEARMONTHS[abbreviation] - except KeyError: - # Show original day name in exception, instead of abbr. - raise KeyError(name) - - def humanize_seconds( secs: int, prefix: str = '', sep: str = '', now: str = 'now', microseconds: bool = False) -> str: @@ -311,7 +288,7 @@ def maybe_iso8601(dt: datetime | str | None) -> None | datetime: return if isinstance(dt, datetime): return dt - return isoparse(dt) + return datetime.fromisoformat(dt) def is_naive(dt: datetime) -> bool: @@ -325,7 +302,7 @@ def _can_detect_ambiguous(tz: tzinfo) -> bool: return isinstance(tz, ZoneInfo) or hasattr(tz, "is_ambiguous") -def _is_ambiguous(dt: datetime, tz: tzinfo) -> bool: +def _is_ambigious(dt: datetime, tz: tzinfo) -> bool: """Helper function to determine if a timezone is ambiguous using python's dateutil module. Returns False if the timezone cannot detect ambiguity, or if there is no ambiguity, otherwise True. 
@@ -342,7 +319,7 @@ def make_aware(dt: datetime, tz: tzinfo) -> datetime: """Set timezone for a :class:`~datetime.datetime` object.""" dt = dt.replace(tzinfo=tz) - if _is_ambiguous(dt, tz): + if _is_ambigious(dt, tz): dt = min(dt.replace(fold=0), dt.replace(fold=1)) return dt diff --git a/venv/lib/python3.12/site-packages/celery/utils/timer2.py b/venv/lib/python3.12/site-packages/celery/utils/timer2.py index adfdb40..88d8ffd 100644 --- a/venv/lib/python3.12/site-packages/celery/utils/timer2.py +++ b/venv/lib/python3.12/site-packages/celery/utils/timer2.py @@ -10,7 +10,6 @@ import threading from itertools import count from threading import TIMEOUT_MAX as THREAD_TIMEOUT_MAX from time import sleep -from typing import Any, Callable, Iterator, Optional, Tuple from kombu.asynchronous.timer import Entry from kombu.asynchronous.timer import Timer as Schedule @@ -31,23 +30,20 @@ class Timer(threading.Thread): Entry = Entry Schedule = Schedule - running: bool = False - on_tick: Optional[Callable[[float], None]] = None + running = False + on_tick = None - _timer_count: count = count(1) + _timer_count = count(1) if TIMER_DEBUG: # pragma: no cover - def start(self, *args: Any, **kwargs: Any) -> None: + def start(self, *args, **kwargs): import traceback print('- Timer starting') traceback.print_stack() super().start(*args, **kwargs) - def __init__(self, schedule: Optional[Schedule] = None, - on_error: Optional[Callable[[Exception], None]] = None, - on_tick: Optional[Callable[[float], None]] = None, - on_start: Optional[Callable[['Timer'], None]] = None, - max_interval: Optional[float] = None, **kwargs: Any) -> None: + def __init__(self, schedule=None, on_error=None, on_tick=None, + on_start=None, max_interval=None, **kwargs): self.schedule = schedule or self.Schedule(on_error=on_error, max_interval=max_interval) self.on_start = on_start @@ -64,10 +60,8 @@ class Timer(threading.Thread): self.daemon = True self.name = f'Timer-{next(self._timer_count)}' - def _next_entry(self) -> Optional[float]: + def _next_entry(self): with self.not_empty: - delay: Optional[float] - entry: Optional[Entry] delay, entry = next(self.scheduler) if entry is None: if delay is None: @@ -76,10 +70,10 @@ class Timer(threading.Thread): return self.schedule.apply_entry(entry) __next__ = next = _next_entry # for 2to3 - def run(self) -> None: + def run(self): try: self.running = True - self.scheduler: Iterator[Tuple[Optional[float], Optional[Entry]]] = iter(self.schedule) + self.scheduler = iter(self.schedule) while not self.__is_shutdown.is_set(): delay = self._next_entry() @@ -100,61 +94,61 @@ class Timer(threading.Thread): sys.stderr.flush() os._exit(1) - def stop(self) -> None: + def stop(self): self.__is_shutdown.set() if self.running: self.__is_stopped.wait() self.join(THREAD_TIMEOUT_MAX) self.running = False - def ensure_started(self) -> None: + def ensure_started(self): if not self.running and not self.is_alive(): if self.on_start: self.on_start(self) self.start() - def _do_enter(self, meth: str, *args: Any, **kwargs: Any) -> Entry: + def _do_enter(self, meth, *args, **kwargs): self.ensure_started() with self.mutex: entry = getattr(self.schedule, meth)(*args, **kwargs) self.not_empty.notify() return entry - def enter(self, entry: Entry, eta: float, priority: Optional[int] = None) -> Entry: + def enter(self, entry, eta, priority=None): return self._do_enter('enter_at', entry, eta, priority=priority) - def call_at(self, *args: Any, **kwargs: Any) -> Entry: + def call_at(self, *args, **kwargs): return self._do_enter('call_at', 
*args, **kwargs) - def enter_after(self, *args: Any, **kwargs: Any) -> Entry: + def enter_after(self, *args, **kwargs): return self._do_enter('enter_after', *args, **kwargs) - def call_after(self, *args: Any, **kwargs: Any) -> Entry: + def call_after(self, *args, **kwargs): return self._do_enter('call_after', *args, **kwargs) - def call_repeatedly(self, *args: Any, **kwargs: Any) -> Entry: + def call_repeatedly(self, *args, **kwargs): return self._do_enter('call_repeatedly', *args, **kwargs) - def exit_after(self, secs: float, priority: int = 10) -> None: + def exit_after(self, secs, priority=10): self.call_after(secs, sys.exit, priority) - def cancel(self, tref: Entry) -> None: + def cancel(self, tref): tref.cancel() - def clear(self) -> None: + def clear(self): self.schedule.clear() - def empty(self) -> bool: + def empty(self): return not len(self) - def __len__(self) -> int: + def __len__(self): return len(self.schedule) - def __bool__(self) -> bool: + def __bool__(self): """``bool(timer)``.""" return True __nonzero__ = __bool__ @property - def queue(self) -> list: + def queue(self): return self.schedule.queue diff --git a/venv/lib/python3.12/site-packages/celery/worker/consumer/consumer.py b/venv/lib/python3.12/site-packages/celery/worker/consumer/consumer.py index 3e6a66d..e072ef5 100644 --- a/venv/lib/python3.12/site-packages/celery/worker/consumer/consumer.py +++ b/venv/lib/python3.12/site-packages/celery/worker/consumer/consumer.py @@ -169,7 +169,6 @@ class Consumer: 'celery.worker.consumer.heart:Heart', 'celery.worker.consumer.control:Control', 'celery.worker.consumer.tasks:Tasks', - 'celery.worker.consumer.delayed_delivery:DelayedDelivery', 'celery.worker.consumer.consumer:Evloop', 'celery.worker.consumer.agent:Agent', ] @@ -391,20 +390,19 @@ class Consumer: else: warnings.warn(CANCEL_TASKS_BY_DEFAULT, CPendingDeprecationWarning) - if self.app.conf.worker_enable_prefetch_count_reduction: - self.initial_prefetch_count = max( - self.prefetch_multiplier, - self.max_prefetch_count - len(tuple(active_requests)) * self.prefetch_multiplier - ) + self.initial_prefetch_count = max( + self.prefetch_multiplier, + self.max_prefetch_count - len(tuple(active_requests)) * self.prefetch_multiplier + ) - self._maximum_prefetch_restored = self.initial_prefetch_count == self.max_prefetch_count - if not self._maximum_prefetch_restored: - logger.info( - f"Temporarily reducing the prefetch count to {self.initial_prefetch_count} to avoid " - f"over-fetching since {len(tuple(active_requests))} tasks are currently being processed.\n" - f"The prefetch count will be gradually restored to {self.max_prefetch_count} as the tasks " - "complete processing." - ) + self._maximum_prefetch_restored = self.initial_prefetch_count == self.max_prefetch_count + if not self._maximum_prefetch_restored: + logger.info( + f"Temporarily reducing the prefetch count to {self.initial_prefetch_count} to avoid over-fetching " + f"since {len(tuple(active_requests))} tasks are currently being processed.\n" + f"The prefetch count will be gradually restored to {self.max_prefetch_count} as the tasks " + "complete processing." 
+ ) def register_with_event_loop(self, hub): self.blueprint.send_all( @@ -413,7 +411,6 @@ class Consumer: ) def shutdown(self): - self.perform_pending_operations() self.blueprint.shutdown(self) def stop(self): @@ -478,9 +475,9 @@ class Consumer: return self.ensure_connected( self.app.connection_for_read(heartbeat=heartbeat)) - def connection_for_write(self, url=None, heartbeat=None): + def connection_for_write(self, heartbeat=None): return self.ensure_connected( - self.app.connection_for_write(url=url, heartbeat=heartbeat)) + self.app.connection_for_write(heartbeat=heartbeat)) def ensure_connected(self, conn): # Callback called for each retry while the connection @@ -507,14 +504,13 @@ class Consumer: # to determine whether connection retries are disabled. retry_disabled = not self.app.conf.broker_connection_retry - if retry_disabled: - warnings.warn( - CPendingDeprecationWarning( - "The broker_connection_retry configuration setting will no longer determine\n" - "whether broker connection retries are made during startup in Celery 6.0 and above.\n" - "If you wish to refrain from retrying connections on startup,\n" - "you should set broker_connection_retry_on_startup to False instead.") - ) + warnings.warn( + CPendingDeprecationWarning( + f"The broker_connection_retry configuration setting will no longer determine\n" + f"whether broker connection retries are made during startup in Celery 6.0 and above.\n" + f"If you wish to retain the existing behavior for retrying connections on startup,\n" + f"you should set broker_connection_retry_on_startup to {self.app.conf.broker_connection_retry}.") + ) else: if self.first_connection_attempt: retry_disabled = not self.app.conf.broker_connection_retry_on_startup @@ -700,10 +696,7 @@ class Consumer: def _restore_prefetch_count_after_connection_restart(self, p, *args): with self.qos._mutex: - if any(( - not self.app.conf.worker_enable_prefetch_count_reduction, - self._maximum_prefetch_restored, - )): + if self._maximum_prefetch_restored: return new_prefetch_count = min(self.max_prefetch_count, self._new_prefetch_count) @@ -733,29 +726,6 @@ class Consumer: self=self, state=self.blueprint.human_state(), ) - def cancel_all_unacked_requests(self): - """Cancel all active requests that either do not require late acknowledgments or, - if they do, have not been acknowledged yet. - """ - - def should_cancel(request): - if not request.task.acks_late: - # Task does not require late acknowledgment, cancel it. - return True - - if not request.acknowledged: - # Task is late acknowledged, but it has not been acknowledged yet, cancel it. - return True - - # Task is late acknowledged, but it has already been acknowledged. - return False # Do not cancel and allow it to gracefully finish as it has already been acknowledged. - - requests_to_cancel = tuple(filter(should_cancel, active_requests)) - - if requests_to_cancel: - for request in requests_to_cancel: - request.cancel(self.pool) - class Evloop(bootsteps.StartStopStep): """Event loop service. diff --git a/venv/lib/python3.12/site-packages/celery/worker/consumer/delayed_delivery.py b/venv/lib/python3.12/site-packages/celery/worker/consumer/delayed_delivery.py deleted file mode 100644 index 66a5501..0000000 --- a/venv/lib/python3.12/site-packages/celery/worker/consumer/delayed_delivery.py +++ /dev/null @@ -1,247 +0,0 @@ -"""Native delayed delivery functionality for Celery workers. 
- -This module provides the DelayedDelivery bootstep which handles setup and configuration -of native delayed delivery functionality when using quorum queues. -""" -from typing import Iterator, List, Optional, Set, Union, ValuesView - -from kombu import Connection, Queue -from kombu.transport.native_delayed_delivery import (bind_queue_to_native_delayed_delivery_exchange, - declare_native_delayed_delivery_exchanges_and_queues) -from kombu.utils.functional import retry_over_time - -from celery import Celery, bootsteps -from celery.utils.log import get_logger -from celery.utils.quorum_queues import detect_quorum_queues -from celery.worker.consumer import Consumer, Tasks - -__all__ = ('DelayedDelivery',) - -logger = get_logger(__name__) - - -# Default retry settings -RETRY_INTERVAL = 1.0 # seconds between retries -MAX_RETRIES = 3 # maximum number of retries - - -# Valid queue types for delayed delivery -VALID_QUEUE_TYPES = {'classic', 'quorum'} - - -class DelayedDelivery(bootsteps.StartStopStep): - """Bootstep that sets up native delayed delivery functionality. - - This component handles the setup and configuration of native delayed delivery - for Celery workers. It is automatically included when quorum queues are - detected in the application configuration. - - Responsibilities: - - Declaring native delayed delivery exchanges and queues - - Binding all application queues to the delayed delivery exchanges - - Handling connection failures gracefully with retries - - Validating configuration settings - """ - - requires = (Tasks,) - - def include_if(self, c: Consumer) -> bool: - """Determine if this bootstep should be included. - - Args: - c: The Celery consumer instance - - Returns: - bool: True if quorum queues are detected, False otherwise - """ - return detect_quorum_queues(c.app, c.app.connection_for_write().transport.driver_type)[0] - - def start(self, c: Consumer) -> None: - """Initialize delayed delivery for all broker URLs. - - Attempts to set up delayed delivery for each broker URL in the configuration. - Failures are logged but don't prevent attempting remaining URLs. - - Args: - c: The Celery consumer instance - - Raises: - ValueError: If configuration validation fails - """ - app: Celery = c.app - - try: - self._validate_configuration(app) - except ValueError as e: - logger.critical("Configuration validation failed: %s", str(e)) - raise - - broker_urls = self._validate_broker_urls(app.conf.broker_url) - setup_errors = [] - - for broker_url in broker_urls: - try: - retry_over_time( - self._setup_delayed_delivery, - args=(c, broker_url), - catch=(ConnectionRefusedError, OSError), - errback=self._on_retry, - interval_start=RETRY_INTERVAL, - max_retries=MAX_RETRIES, - ) - except Exception as e: - logger.warning( - "Failed to setup delayed delivery for %r: %s", - broker_url, str(e) - ) - setup_errors.append((broker_url, e)) - - if len(setup_errors) == len(broker_urls): - logger.critical( - "Failed to setup delayed delivery for all broker URLs. " - "Native delayed delivery will not be available." - ) - - def _setup_delayed_delivery(self, c: Consumer, broker_url: str) -> None: - """Set up delayed delivery for a specific broker URL. 
- - Args: - c: The Celery consumer instance - broker_url: The broker URL to configure - - Raises: - ConnectionRefusedError: If connection to the broker fails - OSError: If there are network-related issues - Exception: For other unexpected errors during setup - """ - connection: Connection = c.app.connection_for_write(url=broker_url) - queue_type = c.app.conf.broker_native_delayed_delivery_queue_type - logger.debug( - "Setting up delayed delivery for broker %r with queue type %r", - broker_url, queue_type - ) - - try: - declare_native_delayed_delivery_exchanges_and_queues( - connection, - queue_type - ) - except Exception as e: - logger.warning( - "Failed to declare exchanges and queues for %r: %s", - broker_url, str(e) - ) - raise - - try: - self._bind_queues(c.app, connection) - except Exception as e: - logger.warning( - "Failed to bind queues for %r: %s", - broker_url, str(e) - ) - raise - - def _bind_queues(self, app: Celery, connection: Connection) -> None: - """Bind all application queues to delayed delivery exchanges. - - Args: - app: The Celery application instance - connection: The broker connection to use - - Raises: - Exception: If queue binding fails - """ - queues: ValuesView[Queue] = app.amqp.queues.values() - if not queues: - logger.warning("No queues found to bind for delayed delivery") - return - - for queue in queues: - try: - logger.debug("Binding queue %r to delayed delivery exchange", queue.name) - bind_queue_to_native_delayed_delivery_exchange(connection, queue) - except Exception as e: - logger.error( - "Failed to bind queue %r: %s", - queue.name, str(e) - ) - raise - - def _on_retry(self, exc: Exception, interval_range: Iterator[float], intervals_count: int) -> None: - """Callback for retry attempts. - - Args: - exc: The exception that triggered the retry - interval_range: An iterator which returns the time in seconds to sleep next - intervals_count: Number of retry attempts so far - """ - logger.warning( - "Retrying delayed delivery setup (attempt %d/%d) after error: %s", - intervals_count + 1, MAX_RETRIES, str(exc) - ) - - def _validate_configuration(self, app: Celery) -> None: - """Validate all required configuration settings. - - Args: - app: The Celery application instance - - Raises: - ValueError: If any configuration is invalid - """ - # Validate broker URLs - self._validate_broker_urls(app.conf.broker_url) - - # Validate queue type - self._validate_queue_type(app.conf.broker_native_delayed_delivery_queue_type) - - def _validate_broker_urls(self, broker_urls: Union[str, List[str]]) -> Set[str]: - """Validate and split broker URLs. - - Args: - broker_urls: Broker URLs, either as a semicolon-separated string - or as a list of strings - - Returns: - Set of valid broker URLs - - Raises: - ValueError: If no valid broker URLs are found or if invalid URLs are provided - """ - if not broker_urls: - raise ValueError("broker_url configuration is empty") - - if isinstance(broker_urls, str): - brokers = broker_urls.split(";") - elif isinstance(broker_urls, list): - if not all(isinstance(url, str) for url in broker_urls): - raise ValueError("All broker URLs must be strings") - brokers = broker_urls - else: - raise ValueError(f"broker_url must be a string or list, got {broker_urls!r}") - - valid_urls = {url for url in brokers} - - if not valid_urls: - raise ValueError("No valid broker URLs found in configuration") - - return valid_urls - - def _validate_queue_type(self, queue_type: Optional[str]) -> None: - """Validate the queue type configuration. 
- - Args: - queue_type: The configured queue type - - Raises: - ValueError: If queue type is invalid - """ - if not queue_type: - raise ValueError("broker_native_delayed_delivery_queue_type is not configured") - - if queue_type not in VALID_QUEUE_TYPES: - sorted_types = sorted(VALID_QUEUE_TYPES) - raise ValueError( - f"Invalid queue type {queue_type!r}. Must be one of: {', '.join(sorted_types)}" - ) diff --git a/venv/lib/python3.12/site-packages/celery/worker/consumer/gossip.py b/venv/lib/python3.12/site-packages/celery/worker/consumer/gossip.py index 509471c..16e1c2e 100644 --- a/venv/lib/python3.12/site-packages/celery/worker/consumer/gossip.py +++ b/venv/lib/python3.12/site-packages/celery/worker/consumer/gossip.py @@ -176,7 +176,6 @@ class Gossip(bootsteps.ConsumerStep): channel, queues=[ev.queue], on_message=partial(self.on_message, ev.event_from_message), - accept=ev.accept, no_ack=True )] diff --git a/venv/lib/python3.12/site-packages/celery/worker/consumer/mingle.py b/venv/lib/python3.12/site-packages/celery/worker/consumer/mingle.py index d3f626e..532ab75 100644 --- a/venv/lib/python3.12/site-packages/celery/worker/consumer/mingle.py +++ b/venv/lib/python3.12/site-packages/celery/worker/consumer/mingle.py @@ -22,7 +22,7 @@ class Mingle(bootsteps.StartStopStep): label = 'Mingle' requires = (Events,) - compatible_transports = {'amqp', 'redis', 'gcpubsub'} + compatible_transports = {'amqp', 'redis'} def __init__(self, c, without_mingle=False, **kwargs): self.enabled = not without_mingle and self.compatible_transport(c.app) diff --git a/venv/lib/python3.12/site-packages/celery/worker/consumer/tasks.py b/venv/lib/python3.12/site-packages/celery/worker/consumer/tasks.py index 67cbfc1..b4e4aee 100644 --- a/venv/lib/python3.12/site-packages/celery/worker/consumer/tasks.py +++ b/venv/lib/python3.12/site-packages/celery/worker/consumer/tasks.py @@ -1,18 +1,13 @@ """Worker Task Consumer Bootstep.""" - -from __future__ import annotations - from kombu.common import QoS, ignore_errors from celery import bootsteps from celery.utils.log import get_logger -from celery.utils.quorum_queues import detect_quorum_queues from .mingle import Mingle __all__ = ('Tasks',) - logger = get_logger(__name__) debug = logger.debug @@ -30,7 +25,10 @@ class Tasks(bootsteps.StartStopStep): """Start task consumer.""" c.update_strategies() - qos_global = self.qos_global(c) + # - RabbitMQ 3.3 completely redefines how basic_qos works... + # This will detect if the new qos semantics is in effect, + # and if so make sure the 'apply_global' flag is set on qos updates. + qos_global = not c.connection.qos_semantics_matches_spec # set initial prefetch count c.connection.default_channel.basic_qos( @@ -65,24 +63,3 @@ class Tasks(bootsteps.StartStopStep): def info(self, c): """Return task consumer info.""" return {'prefetch_count': c.qos.value if c.qos else 'N/A'} - - def qos_global(self, c) -> bool: - """Determine if global QoS should be applied. - - Additional information: - https://www.rabbitmq.com/docs/consumer-prefetch - https://www.rabbitmq.com/docs/quorum-queues#global-qos - """ - # - RabbitMQ 3.3 completely redefines how basic_qos works... - # This will detect if the new qos semantics is in effect, - # and if so make sure the 'apply_global' flag is set on qos updates. 
- qos_global = not c.connection.qos_semantics_matches_spec - - if c.app.conf.worker_detect_quorum_queues: - using_quorum_queues, qname = detect_quorum_queues(c.app, c.connection.transport.driver_type) - - if using_quorum_queues: - qos_global = False - logger.info("Global QoS is disabled. Prefetch count in now static.") - - return qos_global diff --git a/venv/lib/python3.12/site-packages/celery/worker/control.py b/venv/lib/python3.12/site-packages/celery/worker/control.py index 8f9fc4f..41d059e 100644 --- a/venv/lib/python3.12/site-packages/celery/worker/control.py +++ b/venv/lib/python3.12/site-packages/celery/worker/control.py @@ -7,7 +7,6 @@ from billiard.common import TERM_SIGNAME from kombu.utils.encoding import safe_repr from celery.exceptions import WorkerShutdown -from celery.platforms import EX_OK from celery.platforms import signals as _signals from celery.utils.functional import maybe_list from celery.utils.log import get_logger @@ -581,7 +580,7 @@ def autoscale(state, max=None, min=None): def shutdown(state, msg='Got shutdown from remote', **kwargs): """Shutdown worker(s).""" logger.warning(msg) - raise WorkerShutdown(EX_OK) + raise WorkerShutdown(msg) # -- Queues diff --git a/venv/lib/python3.12/site-packages/celery/worker/loops.py b/venv/lib/python3.12/site-packages/celery/worker/loops.py index 1f9e589..0630e67 100644 --- a/venv/lib/python3.12/site-packages/celery/worker/loops.py +++ b/venv/lib/python3.12/site-packages/celery/worker/loops.py @@ -119,10 +119,8 @@ def synloop(obj, connection, consumer, blueprint, hub, qos, obj.on_ready() - def _loop_cycle(): - """ - Perform one iteration of the blocking event loop. - """ + while blueprint.state == RUN and obj.connection: + state.maybe_shutdown() if heartbeat_error[0] is not None: raise heartbeat_error[0] if qos.prev != qos.value: @@ -135,9 +133,3 @@ def synloop(obj, connection, consumer, blueprint, hub, qos, except OSError: if blueprint.state == RUN: raise - - while blueprint.state == RUN and obj.connection: - try: - state.maybe_shutdown() - finally: - _loop_cycle() diff --git a/venv/lib/python3.12/site-packages/celery/worker/request.py b/venv/lib/python3.12/site-packages/celery/worker/request.py index df99b54..5d7c93a 100644 --- a/venv/lib/python3.12/site-packages/celery/worker/request.py +++ b/venv/lib/python3.12/site-packages/celery/worker/request.py @@ -602,8 +602,8 @@ class Request: is_worker_lost = isinstance(exc, WorkerLostError) if self.task.acks_late: reject = ( - (self.task.reject_on_worker_lost and is_worker_lost) - or (isinstance(exc, TimeLimitExceeded) and not self.task.acks_on_failure_or_timeout) + self.task.reject_on_worker_lost and + is_worker_lost ) ack = self.task.acks_on_failure_or_timeout if reject: @@ -777,7 +777,7 @@ def create_request_cls(base, task, pool, hostname, eventer, if isinstance(exc, (SystemExit, KeyboardInterrupt)): raise exc return self.on_failure(retval, return_ok=True) - task_ready(self, successful=True) + task_ready(self) if acks_late: self.acknowledge() diff --git a/venv/lib/python3.12/site-packages/celery/worker/worker.py b/venv/lib/python3.12/site-packages/celery/worker/worker.py index 2444012..04f8c30 100644 --- a/venv/lib/python3.12/site-packages/celery/worker/worker.py +++ b/venv/lib/python3.12/site-packages/celery/worker/worker.py @@ -14,8 +14,7 @@ The worker consists of several components, all managed by bootsteps import os import sys -from datetime import datetime, timezone -from time import sleep +from datetime import datetime from billiard import cpu_count from kombu.utils.compat 
import detect_environment @@ -90,7 +89,7 @@ class WorkController: def __init__(self, app=None, hostname=None, **kwargs): self.app = app or self.app self.hostname = default_nodename(hostname) - self.startup_time = datetime.now(timezone.utc) + self.startup_time = datetime.utcnow() self.app.loader.init_worker() self.on_before_init(**kwargs) self.setup_defaults(**kwargs) @@ -242,7 +241,7 @@ class WorkController: not self.app.IS_WINDOWS) def stop(self, in_sighandler=False, exitcode=None): - """Graceful shutdown of the worker server (Warm shutdown).""" + """Graceful shutdown of the worker server.""" if exitcode is not None: self.exitcode = exitcode if self.blueprint.state == RUN: @@ -252,7 +251,7 @@ class WorkController: self._send_worker_shutdown() def terminate(self, in_sighandler=False): - """Not so graceful shutdown of the worker server (Cold shutdown).""" + """Not so graceful shutdown of the worker server.""" if self.blueprint.state != TERMINATE: self.signal_consumer_close() if not in_sighandler or self.pool.signal_safe: @@ -294,7 +293,7 @@ class WorkController: return reload_from_cwd(sys.modules[module], reloader) def info(self): - uptime = datetime.now(timezone.utc) - self.startup_time + uptime = datetime.utcnow() - self.startup_time return {'total': self.state.total_count, 'pid': os.getpid(), 'clock': str(self.app.clock), @@ -408,28 +407,3 @@ class WorkController: 'worker_disable_rate_limits', disable_rate_limits, ) self.worker_lost_wait = either('worker_lost_wait', worker_lost_wait) - - def wait_for_soft_shutdown(self): - """Wait :setting:`worker_soft_shutdown_timeout` if soft shutdown is enabled. - - To enable soft shutdown, set the :setting:`worker_soft_shutdown_timeout` in the - configuration. Soft shutdown can be used to allow the worker to finish processing - few more tasks before initiating a cold shutdown. This mechanism allows the worker - to finish short tasks that are already in progress and requeue long-running tasks - to be picked up by another worker. - - .. warning:: - If there are no tasks in the worker, the worker will not wait for the - soft shutdown timeout even if it is set as it makes no sense to wait for - the timeout when there are no tasks to process. - """ - app = self.app - requests = tuple(state.active_requests) - - if app.conf.worker_enable_soft_shutdown_on_idle: - requests = True - - if app.conf.worker_soft_shutdown_timeout > 0 and requests: - log = f"Initiating Soft Shutdown, terminating in {app.conf.worker_soft_shutdown_timeout} seconds" - logger.warning(log) - sleep(app.conf.worker_soft_shutdown_timeout) diff --git a/venv/lib/python3.12/site-packages/dotenv/cli.py b/venv/lib/python3.12/site-packages/dotenv/cli.py index 075a7af..65ead46 100644 --- a/venv/lib/python3.12/site-packages/dotenv/cli.py +++ b/venv/lib/python3.12/site-packages/dotenv/cli.py @@ -3,10 +3,8 @@ import os import shlex import sys from contextlib import contextmanager -from typing import Any, Dict, IO, Iterator, List, Optional - -if sys.platform == 'win32': - from subprocess import Popen +from subprocess import Popen +from typing import Any, Dict, IO, Iterator, List try: import click @@ -19,7 +17,7 @@ from .main import dotenv_values, set_key, unset_key from .version import __version__ -def enumerate_env() -> Optional[str]: +def enumerate_env(): """ Return a path for the ${pwd}/.env file. 
@@ -163,13 +161,14 @@ def run(ctx: click.Context, override: bool, commandline: List[str]) -> None: if not commandline: click.echo('No command given.') exit(1) - run_command(commandline, dotenv_as_dict) + ret = run_command(commandline, dotenv_as_dict) + exit(ret) -def run_command(command: List[str], env: Dict[str, str]) -> None: - """Replace the current process with the specified command. +def run_command(command: List[str], env: Dict[str, str]) -> int: + """Run command in sub process. - Replaces the current process with the specified command and the variables from `env` + Runs the command in a sub process with the variables from `env` added in the current environment variables. Parameters @@ -181,8 +180,8 @@ def run_command(command: List[str], env: Dict[str, str]) -> None: Returns ------- - None - This function does not return any value. It replaces the current process with the new one. + int + The return code of the command """ # copy the current environment variables and add the vales from @@ -190,16 +189,11 @@ def run_command(command: List[str], env: Dict[str, str]) -> None: cmd_env = os.environ.copy() cmd_env.update(env) - if sys.platform == 'win32': - # execvpe on Windows returns control immediately - # rather than once the command has finished. - p = Popen(command, - universal_newlines=True, - bufsize=0, - shell=False, - env=cmd_env) - _, _ = p.communicate() + p = Popen(command, + universal_newlines=True, + bufsize=0, + shell=False, + env=cmd_env) + _, _ = p.communicate() - exit(p.returncode) - else: - os.execvpe(command[0], args=command, env=cmd_env) + return p.returncode diff --git a/venv/lib/python3.12/site-packages/dotenv/main.py b/venv/lib/python3.12/site-packages/dotenv/main.py index 8e6a7cf..f40c20e 100644 --- a/venv/lib/python3.12/site-packages/dotenv/main.py +++ b/venv/lib/python3.12/site-packages/dotenv/main.py @@ -1,13 +1,13 @@ import io import logging import os -import pathlib import shutil import sys import tempfile from collections import OrderedDict from contextlib import contextmanager -from typing import IO, Dict, Iterable, Iterator, Mapping, Optional, Tuple, Union +from typing import (IO, Dict, Iterable, Iterator, Mapping, Optional, Tuple, + Union) from .parser import Binding, parse_stream from .variables import parse_variables @@ -16,7 +16,7 @@ from .variables import parse_variables # These paths may flow to `open()` and `shutil.move()`; `shutil.move()` # only accepts string paths, not byte paths or file descriptors. See # https://github.com/python/typeshed/pull/6832. 
-StrPath = Union[str, "os.PathLike[str]"] +StrPath = Union[str, 'os.PathLike[str]'] logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ def with_warn_for_invalid_lines(mappings: Iterator[Binding]) -> Iterator[Binding for mapping in mappings: if mapping.error: logger.warning( - "python-dotenv could not parse statement starting at line %s", + "Python-dotenv could not parse statement starting at line %s", mapping.original.line, ) yield mapping @@ -59,10 +59,10 @@ class DotEnv: else: if self.verbose: logger.info( - "python-dotenv could not find configuration file %s.", - self.dotenv_path or ".env", + "Python-dotenv could not find configuration file %s.", + self.dotenv_path or '.env', ) - yield io.StringIO("") + yield io.StringIO('') def dict(self) -> Dict[str, Optional[str]]: """Return dotenv as dict""" @@ -72,9 +72,7 @@ class DotEnv: raw_values = self.parse() if self.interpolate: - self._dict = OrderedDict( - resolve_variables(raw_values, override=self.override) - ) + self._dict = OrderedDict(resolve_variables(raw_values, override=self.override)) else: self._dict = OrderedDict(raw_values) @@ -102,7 +100,8 @@ class DotEnv: return True def get(self, key: str) -> Optional[str]: - """ """ + """ + """ data = self.dict() if key in data: @@ -132,21 +131,17 @@ def rewrite( path: StrPath, encoding: Optional[str], ) -> Iterator[Tuple[IO[str], IO[str]]]: - pathlib.Path(path).touch() - + if not os.path.isfile(path): + with open(path, mode="w", encoding=encoding) as source: + source.write("") with tempfile.NamedTemporaryFile(mode="w", encoding=encoding, delete=False) as dest: - error = None try: with open(path, encoding=encoding) as source: yield (source, dest) - except BaseException as err: - error = err - - if error is None: - shutil.move(dest.name, path) - else: - os.unlink(dest.name) - raise error from None + except BaseException: + os.unlink(dest.name) + raise + shutil.move(dest.name, path) def set_key( @@ -166,8 +161,9 @@ def set_key( if quote_mode not in ("always", "auto", "never"): raise ValueError(f"Unknown quote_mode: {quote_mode}") - quote = quote_mode == "always" or ( - quote_mode == "auto" and not value_to_set.isalnum() + quote = ( + quote_mode == "always" + or (quote_mode == "auto" and not value_to_set.isalnum()) ) if quote: @@ -175,7 +171,7 @@ def set_key( else: value_out = value_to_set if export: - line_out = f"export {key_to_set}={value_out}\n" + line_out = f'export {key_to_set}={value_out}\n' else: line_out = f"{key_to_set}={value_out}\n" @@ -222,9 +218,7 @@ def unset_key( dest.write(mapping.original.string) if not removed: - logger.warning( - "Key %s not removed from %s - key doesn't exist.", key_to_unset, dotenv_path - ) + logger.warning("Key %s not removed from %s - key doesn't exist.", key_to_unset, dotenv_path) return None, key_to_unset return removed, key_to_unset @@ -236,7 +230,7 @@ def resolve_variables( ) -> Mapping[str, Optional[str]]: new_values: Dict[str, Optional[str]] = {} - for name, value in values: + for (name, value) in values: if value is None: result = None else: @@ -260,7 +254,7 @@ def _walk_to_root(path: str) -> Iterator[str]: Yield directories starting from the given directory up to the root """ if not os.path.exists(path): - raise IOError("Starting path not found") + raise IOError('Starting path not found') if os.path.isfile(path): path = os.path.dirname(path) @@ -274,7 +268,7 @@ def _walk_to_root(path: str) -> Iterator[str]: def find_dotenv( - filename: str = ".env", + filename: str = '.env', raise_error_if_not_found: bool = False, usecwd: bool = False, ) -> 
str: @@ -285,19 +279,11 @@ def find_dotenv( """ def _is_interactive(): - """Decide whether this is running in a REPL or IPython notebook""" - if hasattr(sys, "ps1") or hasattr(sys, "ps2"): - return True - try: - main = __import__("__main__", None, None, fromlist=["__file__"]) - except ModuleNotFoundError: - return False - return not hasattr(main, "__file__") + """ Decide whether this is running in a REPL or IPython notebook """ + main = __import__('__main__', None, None, fromlist=['__file__']) + return not hasattr(main, '__file__') - def _is_debugger(): - return sys.gettrace() is not None - - if usecwd or _is_interactive() or _is_debugger() or getattr(sys, "frozen", False): + if usecwd or _is_interactive() or getattr(sys, 'frozen', False): # Should work without __file__, e.g. in REPL or IPython notebook. path = os.getcwd() else: @@ -305,9 +291,7 @@ def find_dotenv( frame = sys._getframe() current_file = __file__ - while frame.f_code.co_filename == current_file or not os.path.exists( - frame.f_code.co_filename - ): + while frame.f_code.co_filename == current_file: assert frame.f_back is not None frame = frame.f_back frame_filename = frame.f_code.co_filename @@ -319,9 +303,9 @@ def find_dotenv( return check_path if raise_error_if_not_found: - raise IOError("File not found") + raise IOError('File not found') - return "" + return '' def load_dotenv( @@ -346,9 +330,7 @@ def load_dotenv( Bool: True if at least one environment variable is set else False If both `dotenv_path` and `stream` are `None`, `find_dotenv()` is used to find the - .env file with it's default parameters. If you need to change the default parameters - of `find_dotenv()`, you can explicitly call `find_dotenv()` and pass the result - to this function as `dotenv_path`. + .env file. """ if dotenv_path is None and stream is None: dotenv_path = find_dotenv() diff --git a/venv/lib/python3.12/site-packages/dotenv/version.py b/venv/lib/python3.12/site-packages/dotenv/version.py index a82b376..5becc17 100644 --- a/venv/lib/python3.12/site-packages/dotenv/version.py +++ b/venv/lib/python3.12/site-packages/dotenv/version.py @@ -1 +1 @@ -__version__ = "1.1.1" +__version__ = "1.0.0" diff --git a/venv/lib/python3.12/site-packages/kafka_python-2.2.15.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/fastapi-0.104.1.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/kafka_python-2.2.15.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/fastapi-0.104.1.dist-info/INSTALLER diff --git a/venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/METADATA b/venv/lib/python3.12/site-packages/fastapi-0.104.1.dist-info/METADATA similarity index 63% rename from venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/METADATA rename to venv/lib/python3.12/site-packages/fastapi-0.104.1.dist-info/METADATA index ac8789a..2fa535f 100644 --- a/venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/METADATA +++ b/venv/lib/python3.12/site-packages/fastapi-0.104.1.dist-info/METADATA @@ -1,73 +1,56 @@ Metadata-Version: 2.1 Name: fastapi -Version: 0.117.1 +Version: 0.104.1 Summary: FastAPI framework, high performance, easy to learn, fast to code, ready for production -Author-Email: =?utf-8?q?Sebasti=C3=A1n_Ram=C3=ADrez?= -Classifier: Intended Audience :: Information Technology -Classifier: Intended Audience :: System Administrators -Classifier: Operating System :: OS Independent -Classifier: Programming Language :: Python :: 3 -Classifier: Programming Language :: Python -Classifier: Topic :: 
Internet -Classifier: Topic :: Software Development :: Libraries :: Application Frameworks -Classifier: Topic :: Software Development :: Libraries :: Python Modules -Classifier: Topic :: Software Development :: Libraries -Classifier: Topic :: Software Development -Classifier: Typing :: Typed +Project-URL: Homepage, https://github.com/tiangolo/fastapi +Project-URL: Documentation, https://fastapi.tiangolo.com/ +Project-URL: Repository, https://github.com/tiangolo/fastapi +Author-email: Sebastián Ramírez +License-Expression: MIT +License-File: LICENSE Classifier: Development Status :: 4 - Beta Classifier: Environment :: Web Environment Classifier: Framework :: AsyncIO Classifier: Framework :: FastAPI Classifier: Framework :: Pydantic Classifier: Framework :: Pydantic :: 1 -Classifier: Framework :: Pydantic :: 2 Classifier: Intended Audience :: Developers +Classifier: Intended Audience :: Information Technology +Classifier: Intended Audience :: System Administrators Classifier: License :: OSI Approved :: MIT License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 Classifier: Programming Language :: Python :: 3 :: Only Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Classifier: Programming Language :: Python :: 3.11 -Classifier: Programming Language :: Python :: 3.12 -Classifier: Programming Language :: Python :: 3.13 -Classifier: Topic :: Internet :: WWW/HTTP :: HTTP Servers +Classifier: Topic :: Internet Classifier: Topic :: Internet :: WWW/HTTP -Project-URL: Homepage, https://github.com/fastapi/fastapi -Project-URL: Documentation, https://fastapi.tiangolo.com/ -Project-URL: Repository, https://github.com/fastapi/fastapi -Project-URL: Issues, https://github.com/fastapi/fastapi/issues -Project-URL: Changelog, https://fastapi.tiangolo.com/release-notes/ +Classifier: Topic :: Internet :: WWW/HTTP :: HTTP Servers +Classifier: Topic :: Software Development +Classifier: Topic :: Software Development :: Libraries +Classifier: Topic :: Software Development :: Libraries :: Application Frameworks +Classifier: Topic :: Software Development :: Libraries :: Python Modules +Classifier: Typing :: Typed Requires-Python: >=3.8 -Requires-Dist: starlette<0.49.0,>=0.40.0 +Requires-Dist: anyio<4.0.0,>=3.7.1 Requires-Dist: pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,!=2.1.0,<3.0.0,>=1.7.4 +Requires-Dist: starlette<0.28.0,>=0.27.0 Requires-Dist: typing-extensions>=4.8.0 -Provides-Extra: standard -Requires-Dist: fastapi-cli[standard]>=0.0.8; extra == "standard" -Requires-Dist: httpx<1.0.0,>=0.23.0; extra == "standard" -Requires-Dist: jinja2>=3.1.5; extra == "standard" -Requires-Dist: python-multipart>=0.0.18; extra == "standard" -Requires-Dist: email-validator>=2.0.0; extra == "standard" -Requires-Dist: uvicorn[standard]>=0.12.0; extra == "standard" -Provides-Extra: standard-no-fastapi-cloud-cli -Requires-Dist: fastapi-cli[standard-no-fastapi-cloud-cli]>=0.0.8; extra == "standard-no-fastapi-cloud-cli" -Requires-Dist: httpx<1.0.0,>=0.23.0; extra == "standard-no-fastapi-cloud-cli" -Requires-Dist: jinja2>=3.1.5; extra == "standard-no-fastapi-cloud-cli" -Requires-Dist: python-multipart>=0.0.18; extra == "standard-no-fastapi-cloud-cli" -Requires-Dist: email-validator>=2.0.0; extra == "standard-no-fastapi-cloud-cli" -Requires-Dist: uvicorn[standard]>=0.12.0; extra == "standard-no-fastapi-cloud-cli" Provides-Extra: all -Requires-Dist: 
fastapi-cli[standard]>=0.0.8; extra == "all" -Requires-Dist: httpx<1.0.0,>=0.23.0; extra == "all" -Requires-Dist: jinja2>=3.1.5; extra == "all" -Requires-Dist: python-multipart>=0.0.18; extra == "all" -Requires-Dist: itsdangerous>=1.1.0; extra == "all" -Requires-Dist: pyyaml>=5.3.1; extra == "all" -Requires-Dist: ujson!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,>=4.0.1; extra == "all" -Requires-Dist: orjson>=3.2.1; extra == "all" -Requires-Dist: email-validator>=2.0.0; extra == "all" -Requires-Dist: uvicorn[standard]>=0.12.0; extra == "all" -Requires-Dist: pydantic-settings>=2.0.0; extra == "all" -Requires-Dist: pydantic-extra-types>=2.0.0; extra == "all" +Requires-Dist: email-validator>=2.0.0; extra == 'all' +Requires-Dist: httpx>=0.23.0; extra == 'all' +Requires-Dist: itsdangerous>=1.1.0; extra == 'all' +Requires-Dist: jinja2>=2.11.2; extra == 'all' +Requires-Dist: orjson>=3.2.1; extra == 'all' +Requires-Dist: pydantic-extra-types>=2.0.0; extra == 'all' +Requires-Dist: pydantic-settings>=2.0.0; extra == 'all' +Requires-Dist: python-multipart>=0.0.5; extra == 'all' +Requires-Dist: pyyaml>=5.3.1; extra == 'all' +Requires-Dist: ujson!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,>=4.0.1; extra == 'all' +Requires-Dist: uvicorn[standard]>=0.12.0; extra == 'all' Description-Content-Type: text/markdown
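Since this metadata hunk pins the vendored FastAPI at 0.104.1 (and an earlier hunk pins python-dotenv at 1.0.0), a quick way to confirm which distributions actually resolved in the virtualenv is the standard library's `importlib.metadata`. This is an illustrative check, not part of the patch.

```python
# Illustrative sanity check of the installed distributions after applying this change.
from importlib.metadata import version

print(version("fastapi"))        # expected: 0.104.1
print(version("python-dotenv"))  # expected: 1.0.0
```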

@@ -77,11 +60,11 @@ Description-Content-Type: text/markdown FastAPI framework, high performance, easy to learn, fast to code, ready for production

- - Test + + Test - - Coverage + + Coverage Package version @@ -95,11 +78,11 @@ Description-Content-Type: text/markdown **Documentation**: https://fastapi.tiangolo.com -**Source Code**: https://github.com/fastapi/fastapi +**Source Code**: https://github.com/tiangolo/fastapi --- -FastAPI is a modern, fast (high-performance), web framework for building APIs with Python based on standard Python type hints. +FastAPI is a modern, fast (high-performance), web framework for building APIs with Python 3.8+ based on standard Python type hints. The key features are: @@ -118,22 +101,19 @@ The key features are: - - - - - - - - - - - + + + + + + + + + + + + - - - - @@ -143,7 +123,7 @@ The key features are: "_[...] I'm using **FastAPI** a ton these days. [...] I'm actually planning to use it for all of my team's **ML services at Microsoft**. Some of them are getting integrated into the core **Windows** product and some **Office** products._" -

Kabir Khan - Microsoft (ref)
+
Kabir Khan - Microsoft (ref)
--- @@ -161,13 +141,13 @@ The key features are: "_I’m over the moon excited about **FastAPI**. It’s so fun!_" -
Brian Okken - Python Bytes podcast host (ref)
+
Brian Okken - Python Bytes podcast host (ref)
--- "_Honestly, what you've built looks super solid and polished. In many ways, it's what I wanted **Hug** to be - it's really inspiring to see someone build that._" -
Timothy Crosley - Hug creator (ref)
+
Timothy Crosley - Hug creator (ref)
--- @@ -175,7 +155,7 @@ The key features are: "_We've switched over to **FastAPI** for our **APIs** [...] I think you'll like it [...]_" -
Ines Montani - Matthew Honnibal - Explosion AI founders - spaCy creators (ref) - (ref)
+
Ines Montani - Matthew Honnibal - Explosion AI founders - spaCy creators (ref) - (ref)
--- @@ -195,32 +175,42 @@ If you are building a CLI app to be ## Requirements +Python 3.8+ + FastAPI stands on the shoulders of giants: * Starlette for the web parts. -* Pydantic for the data parts. +* Pydantic for the data parts. ## Installation -Create and activate a virtual environment and then install FastAPI: -
```console -$ pip install "fastapi[standard]" +$ pip install fastapi ---> 100% ```
-**Note**: Make sure you put `"fastapi[standard]"` in quotes to ensure it works in all terminals. +You will also need an ASGI server, for production such as Uvicorn or Hypercorn. + +
+ +```console +$ pip install "uvicorn[standard]" + +---> 100% +``` + +
## Example ### Create it -Create a file `main.py` with: +* Create a file `main.py` with: ```Python from typing import Union @@ -276,24 +266,11 @@ Run the server with:
```console -$ fastapi dev main.py +$ uvicorn main:app --reload - ╭────────── FastAPI CLI - Development mode ───────────╮ - │ │ - │ Serving at: http://127.0.0.1:8000 │ - │ │ - │ API docs: http://127.0.0.1:8000/docs │ - │ │ - │ Running in development mode, for production use: │ - │ │ - │ fastapi run │ - │ │ - ╰─────────────────────────────────────────────────────╯ - -INFO: Will watch for changes in these directories: ['/home/user/code/awesomeapp'] INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit) -INFO: Started reloader process [2248755] using WatchFiles -INFO: Started server process [2248757] +INFO: Started reloader process [28720] +INFO: Started server process [28722] INFO: Waiting for application startup. INFO: Application startup complete. ``` @@ -301,13 +278,13 @@ INFO: Application startup complete.
-About the command fastapi dev main.py... +About the command uvicorn main:app --reload... -The command `fastapi dev` reads your `main.py` file, detects the **FastAPI** app in it, and starts a server using Uvicorn. +The command `uvicorn main:app` refers to: -By default, `fastapi dev` will start with auto-reload enabled for local development. - -You can read more about it in the FastAPI CLI docs. +* `main`: the file `main.py` (the Python "module"). +* `app`: the object created inside of `main.py` with the line `app = FastAPI()`. +* `--reload`: make the server restart after code changes. Only do this for development.
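The list above explains the `uvicorn main:app --reload` invocation restored by this hunk. For completeness, the same module:variable reference can be passed to `uvicorn.run` from Python; this is an illustrative sketch, not part of the README being restored.

```python
# run_dev.py (hypothetical helper) - equivalent to `uvicorn main:app --reload`.
import uvicorn

if __name__ == "__main__":
    # "main:app" = module `main.py`, variable `app`; reload=True is for development only.
    uvicorn.run("main:app", host="127.0.0.1", port=8000, reload=True)
```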
@@ -380,7 +357,7 @@ def update_item(item_id: int, item: Item): return {"item_name": item.name, "item_id": item_id} ``` -The `fastapi dev` server should reload automatically. +The server should reload automatically (because you added `--reload` to the `uvicorn` command above). ### Interactive API docs upgrade @@ -414,7 +391,7 @@ You do that with standard modern Python types. You don't have to learn a new syntax, the methods or classes of a specific library, etc. -Just standard **Python**. +Just standard **Python 3.8+**. For example, for an `int`: @@ -464,7 +441,7 @@ Coming back to the previous code example, **FastAPI** will: * Check if there is an optional query parameter named `q` (as in `http://127.0.0.1:8000/items/foo?q=somequery`) for `GET` requests. * As the `q` parameter is declared with `= None`, it is optional. * Without the `None` it would be required (as is the body in the case with `PUT`). -* For `PUT` requests to `/items/{item_id}`, read the body as JSON: +* For `PUT` requests to `/items/{item_id}`, Read the body as JSON: * Check that it has a required attribute `name` that should be a `str`. * Check that it has a required attribute `price` that has to be a `float`. * Check that it has an optional attribute `is_offer`, that should be a `bool`, if present. @@ -524,52 +501,30 @@ Independent TechEmpower benchmarks show **FastAPI** applications running under U To understand more about it, see the section Benchmarks. -## Dependencies - -FastAPI depends on Pydantic and Starlette. - -### `standard` Dependencies - -When you install FastAPI with `pip install "fastapi[standard]"` it comes with the `standard` group of optional dependencies: +## Optional Dependencies Used by Pydantic: -* email-validator - for email validation. +* email_validator - for email validation. +* pydantic-settings - for settings management. +* pydantic-extra-types - for extra types to be used with Pydantic. Used by Starlette: * httpx - Required if you want to use the `TestClient`. * jinja2 - Required if you want to use the default template configuration. -* python-multipart - Required if you want to support form "parsing", with `request.form()`. - -Used by FastAPI: - -* uvicorn - for the server that loads and serves your application. This includes `uvicorn[standard]`, which includes some dependencies (e.g. `uvloop`) needed for high performance serving. -* `fastapi-cli[standard]` - to provide the `fastapi` command. - * This includes `fastapi-cloud-cli`, which allows you to deploy your FastAPI application to FastAPI Cloud. - -### Without `standard` Dependencies - -If you don't want to include the `standard` optional dependencies, you can install with `pip install fastapi` instead of `pip install "fastapi[standard]"`. - -### Without `fastapi-cloud-cli` - -If you want to install FastAPI with the standard dependencies but without the `fastapi-cloud-cli`, you can install with `pip install "fastapi[standard-no-fastapi-cloud-cli]"`. - -### Additional Optional Dependencies - -There are some additional dependencies you might want to install. - -Additional optional Pydantic dependencies: - -* pydantic-settings - for settings management. -* pydantic-extra-types - for extra types to be used with Pydantic. - -Additional optional FastAPI dependencies: - -* orjson - Required if you want to use `ORJSONResponse`. +* python-multipart - Required if you want to support form "parsing", with `request.form()`. +* itsdangerous - Required for `SessionMiddleware` support. 
+* pyyaml - Required for Starlette's `SchemaGenerator` support (you probably don't need it with FastAPI). * ujson - Required if you want to use `UJSONResponse`. +Used by FastAPI / Starlette: + +* uvicorn - for the server that loads and serves your application. +* orjson - Required if you want to use `ORJSONResponse`. + +You can install all of these with `pip install "fastapi[all]"`. + ## License This project is licensed under the terms of the MIT license. diff --git a/venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/RECORD b/venv/lib/python3.12/site-packages/fastapi-0.104.1.dist-info/RECORD similarity index 59% rename from venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/RECORD rename to venv/lib/python3.12/site-packages/fastapi-0.104.1.dist-info/RECORD index c6183f7..fa93a74 100644 --- a/venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/RECORD +++ b/venv/lib/python3.12/site-packages/fastapi-0.104.1.dist-info/RECORD @@ -1,19 +1,14 @@ -../../../bin/fastapi,sha256=OEUEr4c4P8T3MgT1IotIxN3V27TAonSeCWDRN70Vgfw,231 -fastapi-0.117.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -fastapi-0.117.1.dist-info/METADATA,sha256=CSMeNXJKTuCRib4fhSOBx_tdb_N2YOd8vfHzaDAx_X0,28135 -fastapi-0.117.1.dist-info/RECORD,, -fastapi-0.117.1.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -fastapi-0.117.1.dist-info/WHEEL,sha256=9P2ygRxDrTJz3gsagc0Z96ukrxjr-LFBGOgv3AuKlCA,90 -fastapi-0.117.1.dist-info/entry_points.txt,sha256=GCf-WbIZxyGT4MUmrPGj1cOHYZoGsNPHAvNkT6hnGeA,61 -fastapi-0.117.1.dist-info/licenses/LICENSE,sha256=Tsif_IFIW5f-xYSy1KlhAy7v_oNEU4lP2cEnSQbMdE4,1086 -fastapi/__init__.py,sha256=71oE4uLHKzglF1IOxDbfcd-YwW2Qeu5emeUz00x0S98,1081 -fastapi/__main__.py,sha256=bKePXLdO4SsVSM6r9SVoLickJDcR2c0cTOxZRKq26YQ,37 +fastapi-0.104.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +fastapi-0.104.1.dist-info/METADATA,sha256=Zgj7yzBMm50KgBZsq5R9A29zVk7LMUvkUC6oTWuR8J0,24298 +fastapi-0.104.1.dist-info/RECORD,, +fastapi-0.104.1.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +fastapi-0.104.1.dist-info/WHEEL,sha256=9QBuHhg6FNW7lppboF2vKVbCGTVzsFykgRQjjlajrhA,87 +fastapi-0.104.1.dist-info/licenses/LICENSE,sha256=Tsif_IFIW5f-xYSy1KlhAy7v_oNEU4lP2cEnSQbMdE4,1086 +fastapi/__init__.py,sha256=n8125d7_qIsNGVM_1QL7_LpYtGH8GYrkJjgSMjP31cE,1081 fastapi/__pycache__/__init__.cpython-312.pyc,, -fastapi/__pycache__/__main__.cpython-312.pyc,, fastapi/__pycache__/_compat.cpython-312.pyc,, fastapi/__pycache__/applications.cpython-312.pyc,, fastapi/__pycache__/background.cpython-312.pyc,, -fastapi/__pycache__/cli.cpython-312.pyc,, fastapi/__pycache__/concurrency.cpython-312.pyc,, fastapi/__pycache__/datastructures.cpython-312.pyc,, fastapi/__pycache__/encoders.cpython-312.pyc,, @@ -31,29 +26,30 @@ fastapi/__pycache__/testclient.cpython-312.pyc,, fastapi/__pycache__/types.cpython-312.pyc,, fastapi/__pycache__/utils.cpython-312.pyc,, fastapi/__pycache__/websockets.cpython-312.pyc,, -fastapi/_compat.py,sha256=EQyNY-qrN3cjwI1r69JVAROc2lQCvi6W1we6_7jx_gc,24274 -fastapi/applications.py,sha256=Sr6fkAYFmuyIT4b0Rm33NQzO8oz4-DEc3PLTxp4LJgU,177570 -fastapi/background.py,sha256=rouLirxUANrcYC824MSMypXL_Qb2HYg2YZqaiEqbEKI,1768 -fastapi/cli.py,sha256=OYhZb0NR_deuT5ofyPF2NoNBzZDNOP8Salef2nk-HqA,418 -fastapi/concurrency.py,sha256=MirfowoSpkMQZ8j_g0ZxaQKpV6eB3G-dB5TgcXCrgEA,1424 -fastapi/datastructures.py,sha256=b2PEz77XGq-u3Ur1Inwk0AGjOsQZO49yF9C7IPJ15cY,5766 
+fastapi/_compat.py,sha256=BlQp8ec0cFM6FLAEASdpYd7Ip9TY1FZr8PGiGRO4QLg,22798 +fastapi/applications.py,sha256=C7mT6eZh0XUO2HmLM43_gJMyqjoyy_SdgypDHRrLu34,179073 +fastapi/background.py,sha256=F1tsrJKfDZaRchNgF9ykB2PcRaPBJTbL4htN45TJAIc,1799 +fastapi/concurrency.py,sha256=NAK9SMlTCOALLjTAR6KzWUDEkVj7_EyNRz0-lDVW_W8,1467 +fastapi/datastructures.py,sha256=FF1s2g6cAQ5XxlNToB3scgV94Zf3DjdzcaI7ToaTrmg,5797 fastapi/dependencies/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 fastapi/dependencies/__pycache__/__init__.cpython-312.pyc,, fastapi/dependencies/__pycache__/models.cpython-312.pyc,, fastapi/dependencies/__pycache__/utils.cpython-312.pyc,, -fastapi/dependencies/models.py,sha256=Pjl6vx-4nZ5Tta9kJa3-RfQKkXtCpS09-FhMgs9eWNs,1507 -fastapi/dependencies/utils.py,sha256=WVgX-cF_H318wOlsZSiAP2mX6-puEsx_MAQ6AHSzITE,36814 -fastapi/encoders.py,sha256=r_fOgMylrlnCDTh3W9u2W0ZsHTJqIhLpU6QipHMy0m8,11119 -fastapi/exception_handlers.py,sha256=YVcT8Zy021VYYeecgdyh5YEUjEIHKcLspbkSf4OfbJI,1275 -fastapi/exceptions.py,sha256=taNixuFEXb67lI1bnX1ubq8y8TseJ4yoPlWjyP0fTzk,4969 +fastapi/dependencies/models.py,sha256=-n-YCxzgVBkurQi49qOTooT71v_oeAhHJ-qQFonxh5o,2494 +fastapi/dependencies/utils.py,sha256=DjRdd_NVdXh_jDYKTRjUIXkwkLD0WE4oFXQC4peMr2c,29915 +fastapi/encoders.py,sha256=90lbmIW8NZjpPVzbgKhpY49B7TFqa7hrdQDQa70SM9U,11024 +fastapi/exception_handlers.py,sha256=MBrIOA-ugjJDivIi4rSsUJBdTsjuzN76q4yh0q1COKw,1332 +fastapi/exceptions.py,sha256=SQsPxq-QYBZUhq6L4K3B3W7gaSD3Gub2f17erStRagY,5000 fastapi/logger.py,sha256=I9NNi3ov8AcqbsbC9wl1X-hdItKgYt2XTrx1f99Zpl4,54 fastapi/middleware/__init__.py,sha256=oQDxiFVcc1fYJUOIFvphnK7pTT5kktmfL32QXpBFvvo,58 fastapi/middleware/__pycache__/__init__.cpython-312.pyc,, +fastapi/middleware/__pycache__/asyncexitstack.cpython-312.pyc,, fastapi/middleware/__pycache__/cors.cpython-312.pyc,, fastapi/middleware/__pycache__/gzip.cpython-312.pyc,, fastapi/middleware/__pycache__/httpsredirect.cpython-312.pyc,, fastapi/middleware/__pycache__/trustedhost.cpython-312.pyc,, fastapi/middleware/__pycache__/wsgi.cpython-312.pyc,, +fastapi/middleware/asyncexitstack.py,sha256=LvMyVI1QdmWNWYPZqx295VFavssUfVpUsonPOsMWz1E,1035 fastapi/middleware/cors.py,sha256=ynwjWQZoc_vbhzZ3_ZXceoaSrslHFHPdoM52rXr0WUU,79 fastapi/middleware/gzip.py,sha256=xM5PcsH8QlAimZw4VDvcmTnqQamslThsfe3CVN2voa0,79 fastapi/middleware/httpsredirect.py,sha256=rL8eXMnmLijwVkH7_400zHri1AekfeBd6D6qs8ix950,115 @@ -66,15 +62,15 @@ fastapi/openapi/__pycache__/docs.cpython-312.pyc,, fastapi/openapi/__pycache__/models.cpython-312.pyc,, fastapi/openapi/__pycache__/utils.cpython-312.pyc,, fastapi/openapi/constants.py,sha256=adGzmis1L1HJRTE3kJ5fmHS_Noq6tIY6pWv_SFzoFDU,153 -fastapi/openapi/docs.py,sha256=zSDv4xY6XHcKsaG4zyk1HqSnrZtfZFBB0J7ZBk5YHPE,10345 -fastapi/openapi/models.py,sha256=m1BNHxf_RiDTK1uCfMre6XZN5y7krZNA62QEP_2EV9s,15625 -fastapi/openapi/utils.py,sha256=ZI-nwdT2PtX8kaRPJylZo4LJHjYAcoVGxkd181P75x4,23997 -fastapi/param_functions.py,sha256=JHNPLIYvoAwdnZZavIVsxOat8x23fX_Kl33reh7HKl8,64019 -fastapi/params.py,sha256=g450axUBQgQJODdtM7WBxZbQj9Z64inFvadrgHikBbU,28237 +fastapi/openapi/docs.py,sha256=Fo_SGB0eEfGvlNLqP-w_jgYifmHTe-3LbO_qC-ncFVY,10387 +fastapi/openapi/models.py,sha256=DEmsWA-9sNqv2H4YneZUW86r1nMwD920EiTvan5kndI,17763 +fastapi/openapi/utils.py,sha256=PUuz_ISarHVPBRyIgfyHz8uwH0eEsDY3rJUfW__I9GI,22303 +fastapi/param_functions.py,sha256=VWEsJbkH8lJZgcJ6fI6uzquui1kgHrDv1i_wXM7cW3M,63896 +fastapi/params.py,sha256=LzjihAvODd3w7-GddraUyVtH1xfwR9smIoQn-Z_g4mg,27807 
fastapi/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 fastapi/requests.py,sha256=zayepKFcienBllv3snmWI20Gk0oHNVLU4DDhqXBb4LU,142 fastapi/responses.py,sha256=QNQQlwpKhQoIPZTTWkpc9d_QGeGZ_aVQPaDV3nQ8m7c,1761 -fastapi/routing.py,sha256=4zaZIdeq8VtsBLCxmmEgnfHqDD6USTc_l8BxUe1ye4M,176533 +fastapi/routing.py,sha256=VADa3-b52ahpweFCcmAKXkVKldMrfF60N5gZWobI42M,172198 fastapi/security/__init__.py,sha256=bO8pNmxqVRXUjfl2mOKiVZLn0FpBQ61VUYVjmppnbJw,881 fastapi/security/__pycache__/__init__.cpython-312.pyc,, fastapi/security/__pycache__/api_key.cpython-312.pyc,, @@ -83,15 +79,15 @@ fastapi/security/__pycache__/http.cpython-312.pyc,, fastapi/security/__pycache__/oauth2.cpython-312.pyc,, fastapi/security/__pycache__/open_id_connect_url.cpython-312.pyc,, fastapi/security/__pycache__/utils.cpython-312.pyc,, -fastapi/security/api_key.py,sha256=di-0gQ8MKugi2YfmlMoDHk-QMF_vnLGJRFOA6tcZ7fA,9016 +fastapi/security/api_key.py,sha256=bcZbUzTqeR_CI_LXuJdDq1qL322kmhgy5ApOCqgGDi4,9399 fastapi/security/base.py,sha256=dl4pvbC-RxjfbWgPtCWd8MVU-7CB2SZ22rJDXVCXO6c,141 -fastapi/security/http.py,sha256=rWR2x-5CUsjWmRucYthwRig6MG1o-boyrr4Xo-PuuxU,13606 -fastapi/security/oauth2.py,sha256=M1AFIDT7G3oQChq83poI3eg8ZDeibcvnGmya2CTS7JY,22036 -fastapi/security/open_id_connect_url.py,sha256=8vizZ2tGqEp1ur8SwtVgyHJhGAJ5AqahgcvSpaIioDI,2722 +fastapi/security/http.py,sha256=_YdhSRRUCGydVDUILygWg0VlkPA28t_gjcy_axD3eOk,13537 +fastapi/security/oauth2.py,sha256=QAUOE2f6KXbXjkrJIIYCOugI6-R0g9EECZ5t8eN9nA4,21612 +fastapi/security/open_id_connect_url.py,sha256=Mb8wFxrRh4CrsFW0RcjBEQLASPHGDtZRP6c2dCrspAg,2753 fastapi/security/utils.py,sha256=bd8T0YM7UQD5ATKucr1bNtAvz_Y3__dVNAv5UebiPvc,293 fastapi/staticfiles.py,sha256=iirGIt3sdY2QZXd36ijs3Cj-T0FuGFda3cd90kM9Ikw,69 fastapi/templating.py,sha256=4zsuTWgcjcEainMJFAlW6-gnslm6AgOS1SiiDWfmQxk,76 fastapi/testclient.py,sha256=nBvaAmX66YldReJNZXPOk1sfuo2Q6hs8bOvIaCep6LQ,66 -fastapi/types.py,sha256=nFb36sK3DSoqoyo7Miwy3meKK5UdFBgkAgLSzQlUVyI,383 -fastapi/utils.py,sha256=S59stPvKPUJ7MSkke3FaegSyig_4Uwhd32jnLiMF1jE,8032 +fastapi/types.py,sha256=WZJ1jvm1MCwIrxxRYxKwtXS9HqcGk0RnCbLzrMZh-lI,428 +fastapi/utils.py,sha256=rpSasHpgooPIfe67yU3HzOMDv7PtxiG9x6K-bhu6Z18,8193 fastapi/websockets.py,sha256=419uncYObEKZG0YcrXscfQQYLSWoE10jqxVMetGdR98,222 diff --git a/venv/lib/python3.12/site-packages/kafka_python-2.2.15.dist-info/REQUESTED b/venv/lib/python3.12/site-packages/fastapi-0.104.1.dist-info/REQUESTED similarity index 100% rename from venv/lib/python3.12/site-packages/kafka_python-2.2.15.dist-info/REQUESTED rename to venv/lib/python3.12/site-packages/fastapi-0.104.1.dist-info/REQUESTED diff --git a/venv/lib/python3.12/site-packages/python_multipart-0.0.20.dist-info/WHEEL b/venv/lib/python3.12/site-packages/fastapi-0.104.1.dist-info/WHEEL similarity index 67% rename from venv/lib/python3.12/site-packages/python_multipart-0.0.20.dist-info/WHEEL rename to venv/lib/python3.12/site-packages/fastapi-0.104.1.dist-info/WHEEL index 12228d4..ba1a8af 100644 --- a/venv/lib/python3.12/site-packages/python_multipart-0.0.20.dist-info/WHEEL +++ b/venv/lib/python3.12/site-packages/fastapi-0.104.1.dist-info/WHEEL @@ -1,4 +1,4 @@ Wheel-Version: 1.0 -Generator: hatchling 1.27.0 +Generator: hatchling 1.18.0 Root-Is-Purelib: true Tag: py3-none-any diff --git a/venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/licenses/LICENSE b/venv/lib/python3.12/site-packages/fastapi-0.104.1.dist-info/licenses/LICENSE similarity index 100% rename from 
venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/licenses/LICENSE rename to venv/lib/python3.12/site-packages/fastapi-0.104.1.dist-info/licenses/LICENSE diff --git a/venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/WHEEL b/venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/WHEEL deleted file mode 100644 index 045c8ac..0000000 --- a/venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/WHEEL +++ /dev/null @@ -1,4 +0,0 @@ -Wheel-Version: 1.0 -Generator: pdm-backend (2.4.5) -Root-Is-Purelib: true -Tag: py3-none-any diff --git a/venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/entry_points.txt b/venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/entry_points.txt deleted file mode 100644 index b81849e..0000000 --- a/venv/lib/python3.12/site-packages/fastapi-0.117.1.dist-info/entry_points.txt +++ /dev/null @@ -1,5 +0,0 @@ -[console_scripts] -fastapi = fastapi.cli:main - -[gui_scripts] - diff --git a/venv/lib/python3.12/site-packages/fastapi/__init__.py b/venv/lib/python3.12/site-packages/fastapi/__init__.py index 986fd20..c81f09b 100644 --- a/venv/lib/python3.12/site-packages/fastapi/__init__.py +++ b/venv/lib/python3.12/site-packages/fastapi/__init__.py @@ -1,6 +1,6 @@ """FastAPI framework, high performance, easy to learn, fast to code, ready for production""" -__version__ = "0.117.1" +__version__ = "0.104.1" from starlette import status as status diff --git a/venv/lib/python3.12/site-packages/fastapi/__main__.py b/venv/lib/python3.12/site-packages/fastapi/__main__.py deleted file mode 100644 index fc36465..0000000 --- a/venv/lib/python3.12/site-packages/fastapi/__main__.py +++ /dev/null @@ -1,3 +0,0 @@ -from fastapi.cli import main - -main() diff --git a/venv/lib/python3.12/site-packages/fastapi/_compat.py b/venv/lib/python3.12/site-packages/fastapi/_compat.py index 26b6638..fc605d0 100644 --- a/venv/lib/python3.12/site-packages/fastapi/_compat.py +++ b/venv/lib/python3.12/site-packages/fastapi/_compat.py @@ -2,7 +2,6 @@ from collections import deque from copy import copy from dataclasses import dataclass, is_dataclass from enum import Enum -from functools import lru_cache from typing import ( Any, Callable, @@ -16,7 +15,6 @@ from typing import ( Tuple, Type, Union, - cast, ) from fastapi.exceptions import RequestErrorModel @@ -26,8 +24,7 @@ from pydantic.version import VERSION as PYDANTIC_VERSION from starlette.datastructures import UploadFile from typing_extensions import Annotated, Literal, get_args, get_origin -PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2]) -PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2 +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") sequence_annotation_to_type = { @@ -46,8 +43,6 @@ sequence_annotation_to_type = { sequence_types = tuple(sequence_annotation_to_type.keys()) -Url: Type[Any] - if PYDANTIC_V2: from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError from pydantic import TypeAdapter @@ -73,7 +68,7 @@ if PYDANTIC_V2: general_plain_validator_function as with_info_plain_validator_function, # noqa: F401 ) - RequiredParam = PydanticUndefined + Required = PydanticUndefined Undefined = PydanticUndefined UndefinedType = PydanticUndefinedType evaluate_forwardref = eval_type_lenient @@ -132,7 +127,7 @@ if PYDANTIC_V2: ) except ValidationError as exc: return None, _regenerate_error_with_loc( - errors=exc.errors(include_url=False), loc_prefix=loc + errors=exc.errors(), loc_prefix=loc ) def serialize( @@ -232,10 +227,6 @@ if PYDANTIC_V2: 
field_mapping, definitions = schema_generator.generate_definitions( inputs=inputs ) - for item_def in cast(Dict[str, Dict[str, Any]], definitions).values(): - if "description" in item_def: - item_description = cast(str, item_def["description"]).split("\f")[0] - item_def["description"] = item_description return field_mapping, definitions # type: ignore[return-value] def is_scalar_field(field: ModelField) -> bool: @@ -258,12 +249,7 @@ if PYDANTIC_V2: return is_bytes_sequence_annotation(field.type_) def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo: - cls = type(field_info) - merged_field_info = cls.from_annotation(annotation) - new_field_info = copy(field_info) - new_field_info.metadata = merged_field_info.metadata - new_field_info.annotation = merged_field_info.annotation - return new_field_info + return type(field_info).from_annotation(annotation) def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]: origin_type = ( @@ -275,7 +261,7 @@ if PYDANTIC_V2: def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]: error = ValidationError.from_exception_data( "Field required", [{"type": "missing", "loc": loc, "input": {}}] - ).errors(include_url=False)[0] + ).errors()[0] error["input"] = None return error # type: ignore[return-value] @@ -286,12 +272,6 @@ if PYDANTIC_V2: BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload] return BodyModel - def get_model_fields(model: Type[BaseModel]) -> List[ModelField]: - return [ - ModelField(field_info=field_info, name=name) - for name, field_info in model.model_fields.items() - ] - else: from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX from pydantic import AnyUrl as Url # noqa: F401 @@ -319,10 +299,9 @@ else: from pydantic.fields import ( # type: ignore[no-redef,attr-defined] ModelField as ModelField, # noqa: F401 ) - - # Keeping old "Required" functionality from Pydantic V1, without - # shadowing typing.Required. - RequiredParam: Any = Ellipsis # type: ignore[no-redef] + from pydantic.fields import ( # type: ignore[no-redef,attr-defined] + Required as Required, # noqa: F401 + ) from pydantic.fields import ( # type: ignore[no-redef,attr-defined] Undefined as Undefined, ) @@ -393,10 +372,9 @@ else: ) definitions.update(m_definitions) model_name = model_name_map[model] - definitions[model_name] = m_schema - for m_schema in definitions.values(): if "description" in m_schema: m_schema["description"] = m_schema["description"].split("\f")[0] + definitions[model_name] = m_schema return definitions def is_pv1_scalar_field(field: ModelField) -> bool: @@ -528,9 +506,6 @@ else: BodyModel.__fields__[f.name] = f # type: ignore[index] return BodyModel - def get_model_fields(model: Type[BaseModel]) -> List[ModelField]: - return list(model.__fields__.values()) # type: ignore[attr-defined] - def _regenerate_error_with_loc( *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...] 
@@ -550,12 +525,6 @@ def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool: - origin = get_origin(annotation) - if origin is Union or origin is UnionType: - for arg in get_args(annotation): - if field_annotation_is_sequence(arg): - return True - return False return _annotation_is_sequence(annotation) or _annotation_is_sequence( get_origin(annotation) ) @@ -658,8 +627,3 @@ def is_uploadfile_sequence_annotation(annotation: Any) -> bool: is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation) for sub_annotation in get_args(annotation) ) - - -@lru_cache -def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]: - return get_model_fields(model) diff --git a/venv/lib/python3.12/site-packages/fastapi/applications.py b/venv/lib/python3.12/site-packages/fastapi/applications.py index b3424ef..3021d75 100644 --- a/venv/lib/python3.12/site-packages/fastapi/applications.py +++ b/venv/lib/python3.12/site-packages/fastapi/applications.py @@ -22,6 +22,7 @@ from fastapi.exception_handlers import ( ) from fastapi.exceptions import RequestValidationError, WebSocketRequestValidationError from fastapi.logger import logger +from fastapi.middleware.asyncexitstack import AsyncExitStackMiddleware from fastapi.openapi.docs import ( get_redoc_html, get_swagger_ui_html, @@ -36,11 +37,13 @@ from starlette.datastructures import State from starlette.exceptions import HTTPException from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.errors import ServerErrorMiddleware +from starlette.middleware.exceptions import ExceptionMiddleware from starlette.requests import Request from starlette.responses import HTMLResponse, JSONResponse, Response from starlette.routing import BaseRoute from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send -from typing_extensions import Annotated, Doc, deprecated +from typing_extensions import Annotated, Doc, deprecated # type: ignore [attr-defined] AppType = TypeVar("AppType", bound="FastAPI") @@ -297,7 +300,7 @@ class FastAPI(Starlette): browser tabs open). Or if you want to leave fixed the possible URLs. If the servers `list` is not provided, or is an empty `list`, the - default value would be a `dict` with a `url` value of `/`. + default value would be a a `dict` with a `url` value of `/`. Each item in the `list` is a `dict` containing: @@ -748,7 +751,7 @@ class FastAPI(Starlette): This affects the generated OpenAPI (e.g. visible at `/docs`). Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). """ ), ] = True, @@ -810,32 +813,6 @@ class FastAPI(Starlette): """ ), ] = True, - openapi_external_docs: Annotated[ - Optional[Dict[str, Any]], - Doc( - """ - This field allows you to provide additional external documentation links. - If provided, it must be a dictionary containing: - - * `description`: A brief description of the external documentation. - * `url`: The URL pointing to the external documentation. The value **MUST** - be a valid URL format. 
- - **Example**: - - ```python - from fastapi import FastAPI - - external_docs = { - "description": "Detailed API Reference", - "url": "https://example.com/api-docs", - } - - app = FastAPI(openapi_external_docs=external_docs) - ``` - """ - ), - ] = None, **extra: Annotated[ Any, Doc( @@ -864,7 +841,6 @@ class FastAPI(Starlette): self.swagger_ui_parameters = swagger_ui_parameters self.servers = servers or [] self.separate_input_output_schemas = separate_input_output_schemas - self.openapi_external_docs = openapi_external_docs self.extra = extra self.openapi_version: Annotated[ str, @@ -929,7 +905,7 @@ class FastAPI(Starlette): A state object for the application. This is the same object for the entire application, it doesn't change from request to request. - You normally wouldn't use this in FastAPI, for most of the cases you + You normally woudln't use this in FastAPI, for most of the cases you would instead use FastAPI dependencies. This is simply inherited from Starlette. @@ -990,6 +966,55 @@ class FastAPI(Starlette): self.middleware_stack: Union[ASGIApp, None] = None self.setup() + def build_middleware_stack(self) -> ASGIApp: + # Duplicate/override from Starlette to add AsyncExitStackMiddleware + # inside of ExceptionMiddleware, inside of custom user middlewares + debug = self.debug + error_handler = None + exception_handlers = {} + + for key, value in self.exception_handlers.items(): + if key in (500, Exception): + error_handler = value + else: + exception_handlers[key] = value + + middleware = ( + [Middleware(ServerErrorMiddleware, handler=error_handler, debug=debug)] + + self.user_middleware + + [ + Middleware( + ExceptionMiddleware, handlers=exception_handlers, debug=debug + ), + # Add FastAPI-specific AsyncExitStackMiddleware for dependencies with + # contextvars. + # This needs to happen after user middlewares because those create a + # new contextvars context copy by using a new AnyIO task group. + # The initial part of dependencies with 'yield' is executed in the + # FastAPI code, inside all the middlewares. However, the teardown part + # (after 'yield') is executed in the AsyncExitStack in this middleware. + # If the AsyncExitStack lived outside of the custom middlewares and + # contextvars were set in a dependency with 'yield' in that internal + # contextvars context, the values would not be available in the + # outer context of the AsyncExitStack. + # By placing the middleware and the AsyncExitStack here, inside all + # user middlewares, the code before and after 'yield' in dependencies + # with 'yield' is executed in the same contextvars context. Thus, all values + # set in contextvars before 'yield' are still available after 'yield,' as + # expected. + # Additionally, by having this AsyncExitStack here, after the + # ExceptionMiddleware, dependencies can now catch handled exceptions, + # e.g. HTTPException, to customize the teardown code (e.g. DB session + # rollback). + Middleware(AsyncExitStackMiddleware), + ] + ) + + app = self.router + for cls, options in reversed(middleware): + app = cls(app=app, **options) + return app + def openapi(self) -> Dict[str, Any]: """ Generate the OpenAPI schema of the application. 
This is called by FastAPI @@ -1019,7 +1044,6 @@ class FastAPI(Starlette): tags=self.openapi_tags, servers=self.servers, separate_input_output_schemas=self.separate_input_output_schemas, - external_docs=self.openapi_external_docs, ) return self.openapi_schema @@ -1047,7 +1071,7 @@ class FastAPI(Starlette): oauth2_redirect_url = root_path + oauth2_redirect_url return get_swagger_ui_html( openapi_url=openapi_url, - title=f"{self.title} - Swagger UI", + title=self.title + " - Swagger UI", oauth2_redirect_url=oauth2_redirect_url, init_oauth=self.swagger_ui_init_oauth, swagger_ui_parameters=self.swagger_ui_parameters, @@ -1071,7 +1095,7 @@ class FastAPI(Starlette): root_path = req.scope.get("root_path", "").rstrip("/") openapi_url = root_path + self.openapi_url return get_redoc_html( - openapi_url=openapi_url, title=f"{self.title} - ReDoc" + openapi_url=openapi_url, title=self.title + " - ReDoc" ) self.add_route(self.redoc_url, redoc_html, include_in_schema=False) @@ -1084,7 +1108,7 @@ class FastAPI(Starlette): def add_api_route( self, path: str, - endpoint: Callable[..., Any], + endpoint: Callable[..., Coroutine[Any, Any, Response]], *, response_model: Any = Default(None), status_code: Optional[int] = None, @@ -1748,7 +1772,7 @@ class FastAPI(Starlette): This affects the generated OpenAPI (e.g. visible at `/docs`). Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). """ ), ] = True, @@ -2121,7 +2145,7 @@ class FastAPI(Starlette): This affects the generated OpenAPI (e.g. visible at `/docs`). Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). """ ), ] = True, @@ -2499,7 +2523,7 @@ class FastAPI(Starlette): This affects the generated OpenAPI (e.g. visible at `/docs`). Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). """ ), ] = True, @@ -2877,7 +2901,7 @@ class FastAPI(Starlette): This affects the generated OpenAPI (e.g. visible at `/docs`). Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). """ ), ] = True, @@ -3250,7 +3274,7 @@ class FastAPI(Starlette): This affects the generated OpenAPI (e.g. visible at `/docs`). Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). 
""" ), ] = True, @@ -3623,7 +3647,7 @@ class FastAPI(Starlette): This affects the generated OpenAPI (e.g. visible at `/docs`). Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). """ ), ] = True, @@ -3996,7 +4020,7 @@ class FastAPI(Starlette): This affects the generated OpenAPI (e.g. visible at `/docs`). Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). """ ), ] = True, @@ -4374,7 +4398,7 @@ class FastAPI(Starlette): This affects the generated OpenAPI (e.g. visible at `/docs`). Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). """ ), ] = True, @@ -4453,7 +4477,7 @@ class FastAPI(Starlette): app = FastAPI() - @app.trace("/items/{item_id}") + @app.put("/items/{item_id}") def trace_item(item_id: str): return None ``` @@ -4543,17 +4567,14 @@ class FastAPI(Starlette): ```python import time - from typing import Awaitable, Callable - from fastapi import FastAPI, Request, Response + from fastapi import FastAPI, Request app = FastAPI() @app.middleware("http") - async def add_process_time_header( - request: Request, call_next: Callable[[Request], Awaitable[Response]] - ) -> Response: + async def add_process_time_header(request: Request, call_next): start_time = time.time() response = await call_next(request) process_time = time.time() - start_time diff --git a/venv/lib/python3.12/site-packages/fastapi/background.py b/venv/lib/python3.12/site-packages/fastapi/background.py index 203578a..35ab1b2 100644 --- a/venv/lib/python3.12/site-packages/fastapi/background.py +++ b/venv/lib/python3.12/site-packages/fastapi/background.py @@ -1,7 +1,7 @@ from typing import Any, Callable from starlette.background import BackgroundTasks as StarletteBackgroundTasks -from typing_extensions import Annotated, Doc, ParamSpec +from typing_extensions import Annotated, Doc, ParamSpec # type: ignore [attr-defined] P = ParamSpec("P") diff --git a/venv/lib/python3.12/site-packages/fastapi/cli.py b/venv/lib/python3.12/site-packages/fastapi/cli.py deleted file mode 100644 index 8d3301e..0000000 --- a/venv/lib/python3.12/site-packages/fastapi/cli.py +++ /dev/null @@ -1,13 +0,0 @@ -try: - from fastapi_cli.cli import main as cli_main - -except ImportError: # pragma: no cover - cli_main = None # type: ignore - - -def main() -> None: - if not cli_main: # type: ignore[truthy-function] - message = 'To use the fastapi command, please install "fastapi[standard]":\n\n\tpip install "fastapi[standard]"\n' - print(message) - raise RuntimeError(message) # noqa: B904 - cli_main() diff --git a/venv/lib/python3.12/site-packages/fastapi/concurrency.py b/venv/lib/python3.12/site-packages/fastapi/concurrency.py index 3202c70..754061c 100644 --- 
a/venv/lib/python3.12/site-packages/fastapi/concurrency.py +++ b/venv/lib/python3.12/site-packages/fastapi/concurrency.py @@ -1,7 +1,8 @@ +from contextlib import AsyncExitStack as AsyncExitStack # noqa from contextlib import asynccontextmanager as asynccontextmanager from typing import AsyncGenerator, ContextManager, TypeVar -import anyio.to_thread +import anyio from anyio import CapacityLimiter from starlette.concurrency import iterate_in_threadpool as iterate_in_threadpool # noqa from starlette.concurrency import run_in_threadpool as run_in_threadpool # noqa @@ -28,7 +29,7 @@ async def contextmanager_in_threadpool( except Exception as e: ok = bool( await anyio.to_thread.run_sync( - cm.__exit__, type(e), e, e.__traceback__, limiter=exit_limiter + cm.__exit__, type(e), e, None, limiter=exit_limiter ) ) if not ok: diff --git a/venv/lib/python3.12/site-packages/fastapi/datastructures.py b/venv/lib/python3.12/site-packages/fastapi/datastructures.py index cf8406b..ce03e3c 100644 --- a/venv/lib/python3.12/site-packages/fastapi/datastructures.py +++ b/venv/lib/python3.12/site-packages/fastapi/datastructures.py @@ -24,7 +24,7 @@ from starlette.datastructures import Headers as Headers # noqa: F401 from starlette.datastructures import QueryParams as QueryParams # noqa: F401 from starlette.datastructures import State as State # noqa: F401 from starlette.datastructures import UploadFile as StarletteUploadFile -from typing_extensions import Annotated, Doc +from typing_extensions import Annotated, Doc # type: ignore [attr-defined] class UploadFile(StarletteUploadFile): diff --git a/venv/lib/python3.12/site-packages/fastapi/dependencies/models.py b/venv/lib/python3.12/site-packages/fastapi/dependencies/models.py index 418c117..61ef006 100644 --- a/venv/lib/python3.12/site-packages/fastapi/dependencies/models.py +++ b/venv/lib/python3.12/site-packages/fastapi/dependencies/models.py @@ -1,37 +1,58 @@ -from dataclasses import dataclass, field -from typing import Any, Callable, List, Optional, Sequence, Tuple +from typing import Any, Callable, List, Optional, Sequence from fastapi._compat import ModelField from fastapi.security.base import SecurityBase -@dataclass class SecurityRequirement: - security_scheme: SecurityBase - scopes: Optional[Sequence[str]] = None + def __init__( + self, security_scheme: SecurityBase, scopes: Optional[Sequence[str]] = None + ): + self.security_scheme = security_scheme + self.scopes = scopes -@dataclass class Dependant: - path_params: List[ModelField] = field(default_factory=list) - query_params: List[ModelField] = field(default_factory=list) - header_params: List[ModelField] = field(default_factory=list) - cookie_params: List[ModelField] = field(default_factory=list) - body_params: List[ModelField] = field(default_factory=list) - dependencies: List["Dependant"] = field(default_factory=list) - security_requirements: List[SecurityRequirement] = field(default_factory=list) - name: Optional[str] = None - call: Optional[Callable[..., Any]] = None - request_param_name: Optional[str] = None - websocket_param_name: Optional[str] = None - http_connection_param_name: Optional[str] = None - response_param_name: Optional[str] = None - background_tasks_param_name: Optional[str] = None - security_scopes_param_name: Optional[str] = None - security_scopes: Optional[List[str]] = None - use_cache: bool = True - path: Optional[str] = None - cache_key: Tuple[Optional[Callable[..., Any]], Tuple[str, ...]] = field(init=False) - - def __post_init__(self) -> None: + def __init__( + self, + *, + 
path_params: Optional[List[ModelField]] = None, + query_params: Optional[List[ModelField]] = None, + header_params: Optional[List[ModelField]] = None, + cookie_params: Optional[List[ModelField]] = None, + body_params: Optional[List[ModelField]] = None, + dependencies: Optional[List["Dependant"]] = None, + security_schemes: Optional[List[SecurityRequirement]] = None, + name: Optional[str] = None, + call: Optional[Callable[..., Any]] = None, + request_param_name: Optional[str] = None, + websocket_param_name: Optional[str] = None, + http_connection_param_name: Optional[str] = None, + response_param_name: Optional[str] = None, + background_tasks_param_name: Optional[str] = None, + security_scopes_param_name: Optional[str] = None, + security_scopes: Optional[List[str]] = None, + use_cache: bool = True, + path: Optional[str] = None, + ) -> None: + self.path_params = path_params or [] + self.query_params = query_params or [] + self.header_params = header_params or [] + self.cookie_params = cookie_params or [] + self.body_params = body_params or [] + self.dependencies = dependencies or [] + self.security_requirements = security_schemes or [] + self.request_param_name = request_param_name + self.websocket_param_name = websocket_param_name + self.http_connection_param_name = http_connection_param_name + self.response_param_name = response_param_name + self.background_tasks_param_name = background_tasks_param_name + self.security_scopes = security_scopes + self.security_scopes_param_name = security_scopes_param_name + self.name = name + self.call = call + self.use_cache = use_cache + # Store the path to be able to re-generate a dependable from it in overrides + self.path = path + # Save the cache key at creation to optimize performance self.cache_key = (self.call, tuple(sorted(set(self.security_scopes or [])))) diff --git a/venv/lib/python3.12/site-packages/fastapi/dependencies/utils.py b/venv/lib/python3.12/site-packages/fastapi/dependencies/utils.py index e49380c..96e07a4 100644 --- a/venv/lib/python3.12/site-packages/fastapi/dependencies/utils.py +++ b/venv/lib/python3.12/site-packages/fastapi/dependencies/utils.py @@ -1,8 +1,6 @@ import inspect -import sys -from contextlib import AsyncExitStack, contextmanager -from copy import copy, deepcopy -from dataclasses import dataclass +from contextlib import contextmanager +from copy import deepcopy from typing import ( Any, Callable, @@ -25,7 +23,7 @@ from fastapi._compat import ( PYDANTIC_V2, ErrorWrapper, ModelField, - RequiredParam, + Required, Undefined, _regenerate_error_with_loc, copy_field_info, @@ -33,7 +31,6 @@ from fastapi._compat import ( evaluate_forwardref, field_annotation_is_scalar, get_annotation_from_field_info, - get_cached_model_fields, get_missing_field_error, is_bytes_field, is_bytes_sequence_field, @@ -49,6 +46,7 @@ from fastapi._compat import ( ) from fastapi.background import BackgroundTasks from fastapi.concurrency import ( + AsyncExitStack, asynccontextmanager, contextmanager_in_threadpool, ) @@ -57,28 +55,16 @@ from fastapi.logger import logger from fastapi.security.base import SecurityBase from fastapi.security.oauth2 import OAuth2, SecurityScopes from fastapi.security.open_id_connect_url import OpenIdConnect -from fastapi.utils import create_model_field, get_path_param_names -from pydantic import BaseModel +from fastapi.utils import create_response_field, get_path_param_names from pydantic.fields import FieldInfo from starlette.background import BackgroundTasks as StarletteBackgroundTasks from starlette.concurrency import 
run_in_threadpool -from starlette.datastructures import ( - FormData, - Headers, - ImmutableMultiDict, - QueryParams, - UploadFile, -) +from starlette.datastructures import FormData, Headers, QueryParams, UploadFile from starlette.requests import HTTPConnection, Request from starlette.responses import Response from starlette.websockets import WebSocket from typing_extensions import Annotated, get_args, get_origin -if sys.version_info >= (3, 13): # pragma: no cover - from inspect import iscoroutinefunction -else: # pragma: no cover - from asyncio import iscoroutinefunction - multipart_not_installed_error = ( 'Form data requires "python-multipart" to be installed. \n' 'You can install "python-multipart" with: \n\n' @@ -94,23 +80,17 @@ multipart_incorrect_install_error = ( ) -def ensure_multipart_is_installed() -> None: - try: - from python_multipart import __version__ - - # Import an attribute that can be mocked/deleted in testing - assert __version__ > "0.0.12" - except (ImportError, AssertionError): +def check_file_field(field: ModelField) -> None: + field_info = field.field_info + if isinstance(field_info, params.Form): try: # __version__ is available in both multiparts, and can be mocked - from multipart import __version__ # type: ignore[no-redef,import-untyped] + from multipart import __version__ # type: ignore assert __version__ try: # parse_options_header is only available in the right multipart - from multipart.multipart import ( # type: ignore[import-untyped] - parse_options_header, - ) + from multipart.multipart import parse_options_header # type: ignore assert parse_options_header except ImportError: @@ -139,9 +119,9 @@ def get_param_sub_dependant( def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant: - assert callable(depends.dependency), ( - "A parameter-less dependency must have a callable dependency" - ) + assert callable( + depends.dependency + ), "A parameter-less dependency must have a callable dependency" return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path) @@ -196,7 +176,7 @@ def get_flat_dependant( header_params=dependant.header_params.copy(), cookie_params=dependant.cookie_params.copy(), body_params=dependant.body_params.copy(), - security_requirements=dependant.security_requirements.copy(), + security_schemes=dependant.security_requirements.copy(), use_cache=dependant.use_cache, path=dependant.path, ) @@ -215,23 +195,14 @@ def get_flat_dependant( return flat_dependant -def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]: - if not fields: - return fields - first_field = fields[0] - if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel): - fields_to_extract = get_cached_model_fields(first_field.type_) - return fields_to_extract - return fields - - def get_flat_params(dependant: Dependant) -> List[ModelField]: flat_dependant = get_flat_dependant(dependant, skip_repeats=True) - path_params = _get_flat_fields_from_params(flat_dependant.path_params) - query_params = _get_flat_fields_from_params(flat_dependant.query_params) - header_params = _get_flat_fields_from_params(flat_dependant.header_params) - cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params) - return path_params + query_params + header_params + cookie_params + return ( + flat_dependant.path_params + + flat_dependant.query_params + + flat_dependant.header_params + + flat_dependant.cookie_params + ) def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: @@ -254,8 
+225,6 @@ def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any: if isinstance(annotation, str): annotation = ForwardRef(annotation) annotation = evaluate_forwardref(annotation, globalns, globalns) - if annotation is type(None): - return None return annotation @@ -290,16 +259,16 @@ def get_dependant( ) for param_name, param in signature_params.items(): is_path_param = param_name in path_param_names - param_details = analyze_param( + type_annotation, depends, param_field = analyze_param( param_name=param_name, annotation=param.annotation, value=param.default, is_path_param=is_path_param, ) - if param_details.depends is not None: + if depends is not None: sub_dependant = get_param_sub_dependant( param_name=param_name, - depends=param_details.depends, + depends=depends, path=path, security_scopes=security_scopes, ) @@ -307,18 +276,18 @@ def get_dependant( continue if add_non_field_param_to_dependency( param_name=param_name, - type_annotation=param_details.type_annotation, + type_annotation=type_annotation, dependant=dependant, ): - assert param_details.field is None, ( - f"Cannot specify multiple FastAPI annotations for {param_name!r}" - ) + assert ( + param_field is None + ), f"Cannot specify multiple FastAPI annotations for {param_name!r}" continue - assert param_details.field is not None - if isinstance(param_details.field.field_info, params.Body): - dependant.body_params.append(param_details.field) + assert param_field is not None + if is_body_param(param_field=param_field, is_path_param=is_path_param): + dependant.body_params.append(param_field) else: - add_param_to_fields(field=param_details.field, dependant=dependant) + add_param_to_fields(field=param_field, dependant=dependant) return dependant @@ -346,29 +315,20 @@ def add_non_field_param_to_dependency( return None -@dataclass -class ParamDetails: - type_annotation: Any - depends: Optional[params.Depends] - field: Optional[ModelField] - - def analyze_param( *, param_name: str, annotation: Any, value: Any, is_path_param: bool, -) -> ParamDetails: +) -> Tuple[Any, Optional[params.Depends], Optional[ModelField]]: field_info = None depends = None type_annotation: Any = Any - use_annotation: Any = Any - if annotation is not inspect.Signature.empty: - use_annotation = annotation - type_annotation = annotation - # Extract Annotated info - if get_origin(use_annotation) is Annotated: + if ( + annotation is not inspect.Signature.empty + and get_origin(annotation) is Annotated + ): annotated_args = get_args(annotation) type_annotation = annotated_args[0] fastapi_annotations = [ @@ -376,26 +336,16 @@ def analyze_param( for arg in annotated_args[1:] if isinstance(arg, (FieldInfo, params.Depends)) ] - fastapi_specific_annotations = [ - arg - for arg in fastapi_annotations - if isinstance(arg, (params.Param, params.Body, params.Depends)) - ] - if fastapi_specific_annotations: - fastapi_annotation: Union[FieldInfo, params.Depends, None] = ( - fastapi_specific_annotations[-1] - ) - else: - fastapi_annotation = None - # Set default for Annotated FieldInfo + assert ( + len(fastapi_annotations) <= 1 + ), f"Cannot specify multiple `Annotated` FastAPI arguments for {param_name!r}" + fastapi_annotation = next(iter(fastapi_annotations), None) if isinstance(fastapi_annotation, FieldInfo): # Copy `field_info` because we mutate `field_info.default` below. 
field_info = copy_field_info( - field_info=fastapi_annotation, annotation=use_annotation + field_info=fastapi_annotation, annotation=annotation ) - assert ( - field_info.default is Undefined or field_info.default is RequiredParam - ), ( + assert field_info.default is Undefined or field_info.default is Required, ( f"`{field_info.__class__.__name__}` default value cannot be set in" f" `Annotated` for {param_name!r}. Set the default value with `=` instead." ) @@ -403,11 +353,12 @@ def analyze_param( assert not is_path_param, "Path parameters cannot have default values" field_info.default = value else: - field_info.default = RequiredParam - # Get Annotated Depends + field_info.default = Required elif isinstance(fastapi_annotation, params.Depends): depends = fastapi_annotation - # Get Depends from default value + elif annotation is not inspect.Signature.empty: + type_annotation = annotation + if isinstance(value, params.Depends): assert depends is None, ( "Cannot specify `Depends` in `Annotated` and default value" @@ -418,7 +369,6 @@ def analyze_param( f" default value together for {param_name!r}" ) depends = value - # Get FieldInfo from default value elif isinstance(value, FieldInfo): assert field_info is None, ( "Cannot specify FastAPI annotations in `Annotated` and default value" @@ -428,13 +378,9 @@ def analyze_param( if PYDANTIC_V2: field_info.annotation = type_annotation - # Get Depends from type annotation if depends is not None and depends.dependency is None: - # Copy `depends` before mutating it - depends = copy(depends) depends.dependency = type_annotation - # Handle non-param type annotations like Request if lenient_issubclass( type_annotation, ( @@ -447,30 +393,27 @@ def analyze_param( ), ): assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}" - assert field_info is None, ( - f"Cannot specify FastAPI annotation for type {type_annotation!r}" - ) - # Handle default assignations, neither field_info nor depends was not found in Annotated nor default value + assert ( + field_info is None + ), f"Cannot specify FastAPI annotation for type {type_annotation!r}" elif field_info is None and depends is None: - default_value = value if value is not inspect.Signature.empty else RequiredParam + default_value = value if value is not inspect.Signature.empty else Required if is_path_param: - # We might check here that `default_value is RequiredParam`, but the fact is that the same + # We might check here that `default_value is Required`, but the fact is that the same # parameter might sometimes be a path parameter and sometimes not. See # `tests/test_infer_param_optionality.py` for an example. 
- field_info = params.Path(annotation=use_annotation) + field_info = params.Path(annotation=type_annotation) elif is_uploadfile_or_nonable_uploadfile_annotation( type_annotation ) or is_uploadfile_sequence_annotation(type_annotation): - field_info = params.File(annotation=use_annotation, default=default_value) + field_info = params.File(annotation=type_annotation, default=default_value) elif not field_annotation_is_scalar(annotation=type_annotation): - field_info = params.Body(annotation=use_annotation, default=default_value) + field_info = params.Body(annotation=type_annotation, default=default_value) else: - field_info = params.Query(annotation=use_annotation, default=default_value) + field_info = params.Query(annotation=type_annotation, default=default_value) field = None - # It's a field_info, not a dependency if field_info is not None: - # Handle field_info.in_ if is_path_param: assert isinstance(field_info, params.Path), ( f"Cannot use `{field_info.__class__.__name__}` for path param" @@ -481,67 +424,69 @@ def analyze_param( and getattr(field_info, "in_", None) is None ): field_info.in_ = params.ParamTypes.query - use_annotation_from_field_info = get_annotation_from_field_info( - use_annotation, + use_annotation = get_annotation_from_field_info( + type_annotation, field_info, param_name, ) - if isinstance(field_info, params.Form): - ensure_multipart_is_installed() if not field_info.alias and getattr(field_info, "convert_underscores", None): alias = param_name.replace("_", "-") else: alias = field_info.alias or param_name field_info.alias = alias - field = create_model_field( + field = create_response_field( name=param_name, - type_=use_annotation_from_field_info, + type_=use_annotation, default=field_info.default, alias=alias, - required=field_info.default in (RequiredParam, Undefined), + required=field_info.default in (Required, Undefined), field_info=field_info, ) - if is_path_param: - assert is_scalar_field(field=field), ( - "Path params must be of one of the supported types" - ) - elif isinstance(field_info, params.Query): - assert ( - is_scalar_field(field) - or is_scalar_sequence_field(field) - or ( - lenient_issubclass(field.type_, BaseModel) - # For Pydantic v1 - and getattr(field, "shape", 1) == 1 - ) - ) - return ParamDetails(type_annotation=type_annotation, depends=depends, field=field) + return type_annotation, depends, field + + +def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool: + if is_path_param: + assert is_scalar_field( + field=param_field + ), "Path params must be of one of the supported types" + return False + elif is_scalar_field(field=param_field): + return False + elif isinstance( + param_field.field_info, (params.Query, params.Header) + ) and is_scalar_sequence_field(param_field): + return False + else: + assert isinstance( + param_field.field_info, params.Body + ), f"Param: {param_field.name} can only be a request body, using Body()" + return True def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None: - field_info = field.field_info - field_info_in = getattr(field_info, "in_", None) - if field_info_in == params.ParamTypes.path: + field_info = cast(params.Param, field.field_info) + if field_info.in_ == params.ParamTypes.path: dependant.path_params.append(field) - elif field_info_in == params.ParamTypes.query: + elif field_info.in_ == params.ParamTypes.query: dependant.query_params.append(field) - elif field_info_in == params.ParamTypes.header: + elif field_info.in_ == params.ParamTypes.header: 
dependant.header_params.append(field) else: - assert field_info_in == params.ParamTypes.cookie, ( - f"non-body parameters must be in path, query, header or cookie: {field.name}" - ) + assert ( + field_info.in_ == params.ParamTypes.cookie + ), f"non-body parameters must be in path, query, header or cookie: {field.name}" dependant.cookie_params.append(field) def is_coroutine_callable(call: Callable[..., Any]) -> bool: if inspect.isroutine(call): - return iscoroutinefunction(call) + return inspect.iscoroutinefunction(call) if inspect.isclass(call): return False dunder_call = getattr(call, "__call__", None) # noqa: B004 - return iscoroutinefunction(dunder_call) + return inspect.iscoroutinefunction(dunder_call) def is_async_gen_callable(call: Callable[..., Any]) -> bool: @@ -568,15 +513,6 @@ async def solve_generator( return await stack.enter_async_context(cm) -@dataclass -class SolvedDependency: - values: Dict[str, Any] - errors: List[Any] - background_tasks: Optional[StarletteBackgroundTasks] - response: Response - dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any] - - async def solve_dependencies( *, request: Union[Request, WebSocket], @@ -586,17 +522,20 @@ async def solve_dependencies( response: Optional[Response] = None, dependency_overrides_provider: Optional[Any] = None, dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None, - async_exit_stack: AsyncExitStack, - embed_body_fields: bool, -) -> SolvedDependency: +) -> Tuple[ + Dict[str, Any], + List[Any], + Optional[StarletteBackgroundTasks], + Response, + Dict[Tuple[Callable[..., Any], Tuple[str]], Any], +]: values: Dict[str, Any] = {} errors: List[Any] = [] if response is None: response = Response() del response.headers["content-length"] response.status_code = None # type: ignore - if dependency_cache is None: - dependency_cache = {} + dependency_cache = dependency_cache or {} sub_dependant: Dependant for sub_dependant in dependant.dependencies: sub_dependant.call = cast(Callable[..., Any], sub_dependant.call) @@ -629,23 +568,30 @@ async def solve_dependencies( response=response, dependency_overrides_provider=dependency_overrides_provider, dependency_cache=dependency_cache, - async_exit_stack=async_exit_stack, - embed_body_fields=embed_body_fields, ) - background_tasks = solved_result.background_tasks - if solved_result.errors: - errors.extend(solved_result.errors) + ( + sub_values, + sub_errors, + background_tasks, + _, # the subdependency returns the same response we have + sub_dependency_cache, + ) = solved_result + dependency_cache.update(sub_dependency_cache) + if sub_errors: + errors.extend(sub_errors) continue if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache: solved = dependency_cache[sub_dependant.cache_key] elif is_gen_callable(call) or is_async_gen_callable(call): + stack = request.scope.get("fastapi_astack") + assert isinstance(stack, AsyncExitStack) solved = await solve_generator( - call=call, stack=async_exit_stack, sub_values=solved_result.values + call=call, stack=stack, sub_values=sub_values ) elif is_coroutine_callable(call): - solved = await call(**solved_result.values) + solved = await call(**sub_values) else: - solved = await run_in_threadpool(call, **solved_result.values) + solved = await run_in_threadpool(call, **sub_values) if sub_dependant.name is not None: values[sub_dependant.name] = solved if sub_dependant.cache_key not in dependency_cache: @@ -672,9 +618,7 @@ async def solve_dependencies( body_values, body_errors, ) = await 
request_body_to_args( # body_params checked above - body_fields=dependant.body_params, - received_body=body, - embed_body_fields=embed_body_fields, + required_params=dependant.body_params, received_body=body ) values.update(body_values) errors.extend(body_errors) @@ -694,289 +638,142 @@ async def solve_dependencies( values[dependant.security_scopes_param_name] = SecurityScopes( scopes=dependant.security_scopes ) - return SolvedDependency( - values=values, - errors=errors, - background_tasks=background_tasks, - response=response, - dependency_cache=dependency_cache, - ) - - -def _validate_value_with_model_field( - *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...] -) -> Tuple[Any, List[Any]]: - if value is None: - if field.required: - return None, [get_missing_field_error(loc=loc)] - else: - return deepcopy(field.default), [] - v_, errors_ = field.validate(value, values, loc=loc) - if isinstance(errors_, ErrorWrapper): - return None, [errors_] - elif isinstance(errors_, list): - new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=()) - return None, new_errors - else: - return v_, [] - - -def _get_multidict_value( - field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None -) -> Any: - alias = alias or field.alias - if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)): - value = values.getlist(alias) - else: - value = values.get(alias, None) - if ( - value is None - or ( - isinstance(field.field_info, params.Form) - and isinstance(value, str) # For type checks - and value == "" - ) - or (is_sequence_field(field) and len(value) == 0) - ): - if field.required: - return - else: - return deepcopy(field.default) - return value + return values, errors, background_tasks, response, dependency_cache def request_params_to_args( - fields: Sequence[ModelField], + required_params: Sequence[ModelField], received_params: Union[Mapping[str, Any], QueryParams, Headers], ) -> Tuple[Dict[str, Any], List[Any]]: - values: Dict[str, Any] = {} - errors: List[Dict[str, Any]] = [] - - if not fields: - return values, errors - - first_field = fields[0] - fields_to_extract = fields - single_not_embedded_field = False - default_convert_underscores = True - if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel): - fields_to_extract = get_cached_model_fields(first_field.type_) - single_not_embedded_field = True - # If headers are in a Pydantic model, the way to disable convert_underscores - # would be with Header(convert_underscores=False) at the Pydantic model level - default_convert_underscores = getattr( - first_field.field_info, "convert_underscores", True - ) - - params_to_process: Dict[str, Any] = {} - - processed_keys = set() - - for field in fields_to_extract: - alias = None - if isinstance(received_params, Headers): - # Handle fields extracted from a Pydantic Model for a header, each field - # doesn't have a FieldInfo of type Header with the default convert_underscores=True - convert_underscores = getattr( - field.field_info, "convert_underscores", default_convert_underscores - ) - if convert_underscores: - alias = ( - field.alias - if field.alias != field.name - else field.name.replace("_", "-") - ) - value = _get_multidict_value(field, received_params, alias=alias) - if value is not None: - params_to_process[field.name] = value - processed_keys.add(alias or field.alias) - processed_keys.add(field.name) - - for key, value in received_params.items(): - if key not in processed_keys: - params_to_process[key] 
= value - - if single_not_embedded_field: - field_info = first_field.field_info - assert isinstance(field_info, params.Param), ( - "Params must be subclasses of Param" - ) - loc: Tuple[str, ...] = (field_info.in_.value,) - v_, errors_ = _validate_value_with_model_field( - field=first_field, value=params_to_process, values=values, loc=loc - ) - return {first_field.name: v_}, errors_ - - for field in fields: - value = _get_multidict_value(field, received_params) + values = {} + errors = [] + for field in required_params: + if is_scalar_sequence_field(field) and isinstance( + received_params, (QueryParams, Headers) + ): + value = received_params.getlist(field.alias) or field.default + else: + value = received_params.get(field.alias) field_info = field.field_info - assert isinstance(field_info, params.Param), ( - "Params must be subclasses of Param" - ) + assert isinstance( + field_info, params.Param + ), "Params must be subclasses of Param" loc = (field_info.in_.value, field.alias) - v_, errors_ = _validate_value_with_model_field( - field=field, value=value, values=values, loc=loc - ) - if errors_: - errors.extend(errors_) + if value is None: + if field.required: + errors.append(get_missing_field_error(loc=loc)) + else: + values[field.name] = deepcopy(field.default) + continue + v_, errors_ = field.validate(value, values, loc=loc) + if isinstance(errors_, ErrorWrapper): + errors.append(errors_) + elif isinstance(errors_, list): + new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=()) + errors.extend(new_errors) else: values[field.name] = v_ return values, errors -def is_union_of_base_models(field_type: Any) -> bool: - """Check if field type is a Union where all members are BaseModel subclasses.""" - from fastapi.types import UnionType - - origin = get_origin(field_type) - - # Check if it's a Union type (covers both typing.Union and types.UnionType in Python 3.10+) - if origin is not Union and origin is not UnionType: - return False - - union_args = get_args(field_type) - - for arg in union_args: - if not lenient_issubclass(arg, BaseModel): - return False - - return True - - -def _should_embed_body_fields(fields: List[ModelField]) -> bool: - if not fields: - return False - # More than one dependency could have the same field, it would show up as multiple - # fields but it's the same one, so count them by name - body_param_names_set = {field.name for field in fields} - # A top level field has to be a single field, not multiple - if len(body_param_names_set) > 1: - return True - first_field = fields[0] - # If it explicitly specifies it is embedded, it has to be embedded - if getattr(first_field.field_info, "embed", None): - return True - # If it's a Form (or File) field, it has to be a BaseModel (or a union of BaseModels) to be top level - # otherwise it has to be embedded, so that the key value pair can be extracted - if ( - isinstance(first_field.field_info, params.Form) - and not lenient_issubclass(first_field.type_, BaseModel) - and not is_union_of_base_models(first_field.type_) - ): - return True - return False - - -async def _extract_form_body( - body_fields: List[ModelField], - received_body: FormData, -) -> Dict[str, Any]: - values = {} - - for field in body_fields: - value = _get_multidict_value(field, received_body) - field_info = field.field_info - if ( - isinstance(field_info, params.File) - and is_bytes_field(field) - and isinstance(value, UploadFile) - ): - value = await value.read() - elif ( - is_bytes_sequence_field(field) - and isinstance(field_info, params.File) 
- and value_is_sequence(value) - ): - # For types - assert isinstance(value, sequence_types) # type: ignore[arg-type] - results: List[Union[bytes, str]] = [] - - async def process_fn( - fn: Callable[[], Coroutine[Any, Any, Any]], - ) -> None: - result = await fn() - results.append(result) # noqa: B023 - - async with anyio.create_task_group() as tg: - for sub_value in value: - tg.start_soon(process_fn, sub_value.read) - value = serialize_sequence_value(field=field, value=results) - if value is not None: - values[field.alias] = value - for key, value in received_body.items(): - if key not in values: - values[key] = value - return values - - async def request_body_to_args( - body_fields: List[ModelField], + required_params: List[ModelField], received_body: Optional[Union[Dict[str, Any], FormData]], - embed_body_fields: bool, ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: - values: Dict[str, Any] = {} + values = {} errors: List[Dict[str, Any]] = [] - assert body_fields, "request_body_to_args() should be called with fields" - single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields - first_field = body_fields[0] - body_to_process = received_body + if required_params: + field = required_params[0] + field_info = field.field_info + embed = getattr(field_info, "embed", None) + field_alias_omitted = len(required_params) == 1 and not embed + if field_alias_omitted: + received_body = {field.alias: received_body} - fields_to_extract: List[ModelField] = body_fields + for field in required_params: + loc: Tuple[str, ...] + if field_alias_omitted: + loc = ("body",) + else: + loc = ("body", field.alias) - if ( - single_not_embedded_field - and lenient_issubclass(first_field.type_, BaseModel) - and isinstance(received_body, FormData) - ): - fields_to_extract = get_cached_model_fields(first_field.type_) - - if isinstance(received_body, FormData): - body_to_process = await _extract_form_body(fields_to_extract, received_body) - - if single_not_embedded_field: - loc: Tuple[str, ...] 
= ("body",) - v_, errors_ = _validate_value_with_model_field( - field=first_field, value=body_to_process, values=values, loc=loc - ) - return {first_field.name: v_}, errors_ - for field in body_fields: - loc = ("body", field.alias) - value: Optional[Any] = None - if body_to_process is not None: - try: - value = body_to_process.get(field.alias) - # If the received body is a list, not a dict - except AttributeError: - errors.append(get_missing_field_error(loc)) + value: Optional[Any] = None + if received_body is not None: + if (is_sequence_field(field)) and isinstance(received_body, FormData): + value = received_body.getlist(field.alias) + else: + try: + value = received_body.get(field.alias) + except AttributeError: + errors.append(get_missing_field_error(loc)) + continue + if ( + value is None + or (isinstance(field_info, params.Form) and value == "") + or ( + isinstance(field_info, params.Form) + and is_sequence_field(field) + and len(value) == 0 + ) + ): + if field.required: + errors.append(get_missing_field_error(loc)) + else: + values[field.name] = deepcopy(field.default) continue - v_, errors_ = _validate_value_with_model_field( - field=field, value=value, values=values, loc=loc - ) - if errors_: - errors.extend(errors_) - else: - values[field.name] = v_ + if ( + isinstance(field_info, params.File) + and is_bytes_field(field) + and isinstance(value, UploadFile) + ): + value = await value.read() + elif ( + is_bytes_sequence_field(field) + and isinstance(field_info, params.File) + and value_is_sequence(value) + ): + # For types + assert isinstance(value, sequence_types) # type: ignore[arg-type] + results: List[Union[bytes, str]] = [] + + async def process_fn( + fn: Callable[[], Coroutine[Any, Any, Any]] + ) -> None: + result = await fn() + results.append(result) # noqa: B023 + + async with anyio.create_task_group() as tg: + for sub_value in value: + tg.start_soon(process_fn, sub_value.read) + value = serialize_sequence_value(field=field, value=results) + + v_, errors_ = field.validate(value, values, loc=loc) + + if isinstance(errors_, list): + errors.extend(errors_) + elif errors_: + errors.append(errors_) + else: + values[field.name] = v_ return values, errors -def get_body_field( - *, flat_dependant: Dependant, name: str, embed_body_fields: bool -) -> Optional[ModelField]: - """ - Get a ModelField representing the request body for a path operation, combining - all body parameters into a single field if necessary. - - Used to check if it's form data (with `isinstance(body_field, params.Form)`) - or JSON and to generate the JSON Schema for a request body. - - This is **not** used to validate/parse the request body, that's done with each - individual body parameter. 
- """ +def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]: + flat_dependant = get_flat_dependant(dependant) if not flat_dependant.body_params: return None first_param = flat_dependant.body_params[0] - if not embed_body_fields: + field_info = first_param.field_info + embed = getattr(field_info, "embed", None) + body_param_names_set = {param.name for param in flat_dependant.body_params} + if len(body_param_names_set) == 1 and not embed: + check_file_field(first_param) return first_param + # If one field requires to embed, all have to be embedded + # in case a sub-dependency is evaluated with a single unique body field + # That is combined (embedded) with other body fields + for param in flat_dependant.body_params: + setattr(param.field_info, "embed", True) # noqa: B010 model_name = "Body_" + name BodyModel = create_body_model( fields=flat_dependant.body_params, model_name=model_name @@ -1002,11 +799,12 @@ def get_body_field( ] if len(set(body_param_media_types)) == 1: BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0] - final_field = create_model_field( + final_field = create_response_field( name="body", type_=BodyModel, required=required, alias="body", field_info=BodyFieldInfo(**BodyFieldInfo_kwargs), ) + check_file_field(final_field) return final_field diff --git a/venv/lib/python3.12/site-packages/fastapi/encoders.py b/venv/lib/python3.12/site-packages/fastapi/encoders.py index b037f8b..e501713 100644 --- a/venv/lib/python3.12/site-packages/fastapi/encoders.py +++ b/venv/lib/python3.12/site-packages/fastapi/encoders.py @@ -22,9 +22,9 @@ from pydantic import BaseModel from pydantic.color import Color from pydantic.networks import AnyUrl, NameEmail from pydantic.types import SecretBytes, SecretStr -from typing_extensions import Annotated, Doc +from typing_extensions import Annotated, Doc # type: ignore [attr-defined] -from ._compat import PYDANTIC_V2, UndefinedType, Url, _model_dump +from ._compat import PYDANTIC_V2, Url, _model_dump # Taken from Pydantic v1 as is @@ -86,7 +86,7 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { def generate_encoders_by_class_tuples( - type_encoder_map: Dict[Any, Callable[[Any], Any]], + type_encoder_map: Dict[Any, Callable[[Any], Any]] ) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]: encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict( tuple @@ -219,7 +219,7 @@ def jsonable_encoder( if not PYDANTIC_V2: encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined] if custom_encoder: - encoders = {**encoders, **custom_encoder} + encoders.update(custom_encoder) obj_dict = _model_dump( obj, mode="json", @@ -241,7 +241,6 @@ def jsonable_encoder( sqlalchemy_safe=sqlalchemy_safe, ) if dataclasses.is_dataclass(obj): - assert not isinstance(obj, type) obj_dict = dataclasses.asdict(obj) return jsonable_encoder( obj_dict, @@ -260,8 +259,6 @@ def jsonable_encoder( return str(obj) if isinstance(obj, (str, int, float, type(None))): return obj - if isinstance(obj, UndefinedType): - return None if isinstance(obj, dict): encoded_dict = {} allowed_keys = set(obj.keys()) diff --git a/venv/lib/python3.12/site-packages/fastapi/exception_handlers.py b/venv/lib/python3.12/site-packages/fastapi/exception_handlers.py index 475dd7b..6c2ba7f 100644 --- a/venv/lib/python3.12/site-packages/fastapi/exception_handlers.py +++ b/venv/lib/python3.12/site-packages/fastapi/exception_handlers.py @@ -5,7 +5,7 @@ from fastapi.websockets import WebSocket from starlette.exceptions import 
HTTPException from starlette.requests import Request from starlette.responses import JSONResponse, Response -from starlette.status import WS_1008_POLICY_VIOLATION +from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, WS_1008_POLICY_VIOLATION async def http_exception_handler(request: Request, exc: HTTPException) -> Response: @@ -21,7 +21,7 @@ async def request_validation_exception_handler( request: Request, exc: RequestValidationError ) -> JSONResponse: return JSONResponse( - status_code=422, + status_code=HTTP_422_UNPROCESSABLE_ENTITY, content={"detail": jsonable_encoder(exc.errors())}, ) diff --git a/venv/lib/python3.12/site-packages/fastapi/exceptions.py b/venv/lib/python3.12/site-packages/fastapi/exceptions.py index 44d4ada..680d288 100644 --- a/venv/lib/python3.12/site-packages/fastapi/exceptions.py +++ b/venv/lib/python3.12/site-packages/fastapi/exceptions.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Sequence, Type, Union from pydantic import BaseModel, create_model from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import WebSocketException as StarletteWebSocketException -from typing_extensions import Annotated, Doc +from typing_extensions import Annotated, Doc # type: ignore [attr-defined] class HTTPException(StarletteHTTPException): diff --git a/venv/lib/python3.12/site-packages/fastapi/middleware/asyncexitstack.py b/venv/lib/python3.12/site-packages/fastapi/middleware/asyncexitstack.py new file mode 100644 index 0000000..30a0ae6 --- /dev/null +++ b/venv/lib/python3.12/site-packages/fastapi/middleware/asyncexitstack.py @@ -0,0 +1,25 @@ +from typing import Optional + +from fastapi.concurrency import AsyncExitStack +from starlette.types import ASGIApp, Receive, Scope, Send + + +class AsyncExitStackMiddleware: + def __init__(self, app: ASGIApp, context_name: str = "fastapi_astack") -> None: + self.app = app + self.context_name = context_name + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + dependency_exception: Optional[Exception] = None + async with AsyncExitStack() as stack: + scope[self.context_name] = stack + try: + await self.app(scope, receive, send) + except Exception as e: + dependency_exception = e + raise e + if dependency_exception: + # This exception was possibly handled by the dependency but it should + # still bubble up so that the ServerErrorMiddleware can return a 500 + # or the ExceptionMiddleware can catch and handle any other exceptions + raise dependency_exception diff --git a/venv/lib/python3.12/site-packages/fastapi/openapi/docs.py b/venv/lib/python3.12/site-packages/fastapi/openapi/docs.py index f181b43..69473d1 100644 --- a/venv/lib/python3.12/site-packages/fastapi/openapi/docs.py +++ b/venv/lib/python3.12/site-packages/fastapi/openapi/docs.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Optional from fastapi.encoders import jsonable_encoder from starlette.responses import HTMLResponse -from typing_extensions import Annotated, Doc +from typing_extensions import Annotated, Doc # type: ignore [attr-defined] swagger_ui_default_parameters: Annotated[ Dict[str, Any], @@ -53,7 +53,7 @@ def get_swagger_ui_html( It is normally set to a CDN URL. """ ), - ] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js", + ] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui-bundle.js", swagger_css_url: Annotated[ str, Doc( @@ -63,7 +63,7 @@ def get_swagger_ui_html( It is normally set to a CDN URL. 
""" ), - ] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css", + ] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5.9.0/swagger-ui.css", swagger_favicon_url: Annotated[ str, Doc( @@ -188,7 +188,7 @@ def get_redoc_html( It is normally set to a CDN URL. """ ), - ] = "https://cdn.jsdelivr.net/npm/redoc@2/bundles/redoc.standalone.js", + ] = "https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js", redoc_favicon_url: Annotated[ str, Doc( diff --git a/venv/lib/python3.12/site-packages/fastapi/openapi/models.py b/venv/lib/python3.12/site-packages/fastapi/openapi/models.py index 81d276a..5f3bdbb 100644 --- a/venv/lib/python3.12/site-packages/fastapi/openapi/models.py +++ b/venv/lib/python3.12/site-packages/fastapi/openapi/models.py @@ -55,7 +55,11 @@ except ImportError: # pragma: no cover return with_info_plain_validator_function(cls._validate) -class BaseModelWithConfig(BaseModel): +class Contact(BaseModel): + name: Optional[str] = None + url: Optional[AnyUrl] = None + email: Optional[EmailStr] = None + if PYDANTIC_V2: model_config = {"extra": "allow"} @@ -65,19 +69,21 @@ class BaseModelWithConfig(BaseModel): extra = "allow" -class Contact(BaseModelWithConfig): - name: Optional[str] = None - url: Optional[AnyUrl] = None - email: Optional[EmailStr] = None - - -class License(BaseModelWithConfig): +class License(BaseModel): name: str identifier: Optional[str] = None url: Optional[AnyUrl] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} -class Info(BaseModelWithConfig): + else: + + class Config: + extra = "allow" + + +class Info(BaseModel): title: str summary: Optional[str] = None description: Optional[str] = None @@ -86,18 +92,42 @@ class Info(BaseModelWithConfig): license: Optional[License] = None version: str + if PYDANTIC_V2: + model_config = {"extra": "allow"} -class ServerVariable(BaseModelWithConfig): + else: + + class Config: + extra = "allow" + + +class ServerVariable(BaseModel): enum: Annotated[Optional[List[str]], Field(min_length=1)] = None default: str description: Optional[str] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} -class Server(BaseModelWithConfig): + else: + + class Config: + extra = "allow" + + +class Server(BaseModel): url: Union[AnyUrl, str] description: Optional[str] = None variables: Optional[Dict[str, ServerVariable]] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + class Reference(BaseModel): ref: str = Field(alias="$ref") @@ -108,26 +138,36 @@ class Discriminator(BaseModel): mapping: Optional[Dict[str, str]] = None -class XML(BaseModelWithConfig): +class XML(BaseModel): name: Optional[str] = None namespace: Optional[str] = None prefix: Optional[str] = None attribute: Optional[bool] = None wrapped: Optional[bool] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} -class ExternalDocumentation(BaseModelWithConfig): + else: + + class Config: + extra = "allow" + + +class ExternalDocumentation(BaseModel): description: Optional[str] = None url: AnyUrl + if PYDANTIC_V2: + model_config = {"extra": "allow"} -# Ref JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation#name-type -SchemaType = Literal[ - "array", "boolean", "integer", "null", "number", "object", "string" -] + else: + + class Config: + extra = "allow" -class Schema(BaseModelWithConfig): +class Schema(BaseModel): # Ref: JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-core.html#name-the-json-schema-core-vocabu # Core Vocabulary schema_: 
Optional[str] = Field(default=None, alias="$schema") @@ -151,7 +191,7 @@ class Schema(BaseModelWithConfig): dependentSchemas: Optional[Dict[str, "SchemaOrBool"]] = None prefixItems: Optional[List["SchemaOrBool"]] = None # TODO: uncomment and remove below when deprecating Pydantic v1 - # It generates a list of schemas for tuples, before prefixItems was available + # It generales a list of schemas for tuples, before prefixItems was available # items: Optional["SchemaOrBool"] = None items: Optional[Union["SchemaOrBool", List["SchemaOrBool"]]] = None contains: Optional["SchemaOrBool"] = None @@ -163,7 +203,7 @@ class Schema(BaseModelWithConfig): unevaluatedProperties: Optional["SchemaOrBool"] = None # Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-a-vocabulary-for-structural # A Vocabulary for Structural Validation - type: Optional[Union[SchemaType, List[SchemaType]]] = None + type: Optional[str] = None enum: Optional[List[Any]] = None const: Optional[Any] = None multipleOf: Optional[float] = Field(default=None, gt=0) @@ -213,6 +253,14 @@ class Schema(BaseModelWithConfig): ), ] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + # Ref: https://json-schema.org/draft/2020-12/json-schema-core.html#name-json-schema-documents # A JSON Schema MUST be an object or a boolean. @@ -241,22 +289,38 @@ class ParameterInType(Enum): cookie = "cookie" -class Encoding(BaseModelWithConfig): +class Encoding(BaseModel): contentType: Optional[str] = None headers: Optional[Dict[str, Union["Header", Reference]]] = None style: Optional[str] = None explode: Optional[bool] = None allowReserved: Optional[bool] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} -class MediaType(BaseModelWithConfig): + else: + + class Config: + extra = "allow" + + +class MediaType(BaseModel): schema_: Optional[Union[Schema, Reference]] = Field(default=None, alias="schema") example: Optional[Any] = None examples: Optional[Dict[str, Union[Example, Reference]]] = None encoding: Optional[Dict[str, Encoding]] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} -class ParameterBase(BaseModelWithConfig): + else: + + class Config: + extra = "allow" + + +class ParameterBase(BaseModel): description: Optional[str] = None required: Optional[bool] = None deprecated: Optional[bool] = None @@ -270,6 +334,14 @@ class ParameterBase(BaseModelWithConfig): # Serialization rules for more complex scenarios content: Optional[Dict[str, MediaType]] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + class Parameter(ParameterBase): name: str @@ -280,13 +352,21 @@ class Header(ParameterBase): pass -class RequestBody(BaseModelWithConfig): +class RequestBody(BaseModel): description: Optional[str] = None content: Dict[str, MediaType] required: Optional[bool] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} -class Link(BaseModelWithConfig): + else: + + class Config: + extra = "allow" + + +class Link(BaseModel): operationRef: Optional[str] = None operationId: Optional[str] = None parameters: Optional[Dict[str, Union[Any, str]]] = None @@ -294,15 +374,31 @@ class Link(BaseModelWithConfig): description: Optional[str] = None server: Optional[Server] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} -class Response(BaseModelWithConfig): + else: + + class Config: + extra = "allow" + + +class Response(BaseModel): description: str headers: Optional[Dict[str, 
Union[Header, Reference]]] = None content: Optional[Dict[str, MediaType]] = None links: Optional[Dict[str, Union[Link, Reference]]] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} -class Operation(BaseModelWithConfig): + else: + + class Config: + extra = "allow" + + +class Operation(BaseModel): tags: Optional[List[str]] = None summary: Optional[str] = None description: Optional[str] = None @@ -317,8 +413,16 @@ class Operation(BaseModelWithConfig): security: Optional[List[Dict[str, List[str]]]] = None servers: Optional[List[Server]] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} -class PathItem(BaseModelWithConfig): + else: + + class Config: + extra = "allow" + + +class PathItem(BaseModel): ref: Optional[str] = Field(default=None, alias="$ref") summary: Optional[str] = None description: Optional[str] = None @@ -333,6 +437,14 @@ class PathItem(BaseModelWithConfig): servers: Optional[List[Server]] = None parameters: Optional[List[Union[Parameter, Reference]]] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + class SecuritySchemeType(Enum): apiKey = "apiKey" @@ -341,10 +453,18 @@ class SecuritySchemeType(Enum): openIdConnect = "openIdConnect" -class SecurityBase(BaseModelWithConfig): +class SecurityBase(BaseModel): type_: SecuritySchemeType = Field(alias="type") description: Optional[str] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + class APIKeyIn(Enum): query = "query" @@ -368,10 +488,18 @@ class HTTPBearer(HTTPBase): bearerFormat: Optional[str] = None -class OAuthFlow(BaseModelWithConfig): +class OAuthFlow(BaseModel): refreshUrl: Optional[str] = None scopes: Dict[str, str] = {} + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + class OAuthFlowImplicit(OAuthFlow): authorizationUrl: str @@ -390,12 +518,20 @@ class OAuthFlowAuthorizationCode(OAuthFlow): tokenUrl: str -class OAuthFlows(BaseModelWithConfig): +class OAuthFlows(BaseModel): implicit: Optional[OAuthFlowImplicit] = None password: Optional[OAuthFlowPassword] = None clientCredentials: Optional[OAuthFlowClientCredentials] = None authorizationCode: Optional[OAuthFlowAuthorizationCode] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + class OAuth2(SecurityBase): type_: SecuritySchemeType = Field(default=SecuritySchemeType.oauth2, alias="type") @@ -412,7 +548,7 @@ class OpenIdConnect(SecurityBase): SecurityScheme = Union[APIKey, HTTPBase, OAuth2, OpenIdConnect, HTTPBearer] -class Components(BaseModelWithConfig): +class Components(BaseModel): schemas: Optional[Dict[str, Union[Schema, Reference]]] = None responses: Optional[Dict[str, Union[Response, Reference]]] = None parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None @@ -425,14 +561,30 @@ class Components(BaseModelWithConfig): callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference, Any]]] = None pathItems: Optional[Dict[str, Union[PathItem, Reference]]] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} -class Tag(BaseModelWithConfig): + else: + + class Config: + extra = "allow" + + +class Tag(BaseModel): name: str description: Optional[str] = None externalDocs: Optional[ExternalDocumentation] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} -class OpenAPI(BaseModelWithConfig): + else: + + class Config: + extra = "allow" + + +class OpenAPI(BaseModel): openapi: str info: Info 
jsonSchemaDialect: Optional[str] = None @@ -445,6 +597,14 @@ class OpenAPI(BaseModelWithConfig): tags: Optional[List[Tag]] = None externalDocs: Optional[ExternalDocumentation] = None + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + _model_rebuild(Schema) _model_rebuild(Operation) diff --git a/venv/lib/python3.12/site-packages/fastapi/openapi/utils.py b/venv/lib/python3.12/site-packages/fastapi/openapi/utils.py index 21105cf..5bfb5ac 100644 --- a/venv/lib/python3.12/site-packages/fastapi/openapi/utils.py +++ b/venv/lib/python3.12/site-packages/fastapi/openapi/utils.py @@ -16,15 +16,11 @@ from fastapi._compat import ( ) from fastapi.datastructures import DefaultPlaceholder from fastapi.dependencies.models import Dependant -from fastapi.dependencies.utils import ( - _get_flat_fields_from_params, - get_flat_dependant, - get_flat_params, -) +from fastapi.dependencies.utils import get_flat_dependant, get_flat_params from fastapi.encoders import jsonable_encoder from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX, REF_TEMPLATE from fastapi.openapi.models import OpenAPI -from fastapi.params import Body, ParamTypes +from fastapi.params import Body, Param from fastapi.responses import Response from fastapi.types import ModelNameMap from fastapi.utils import ( @@ -32,9 +28,9 @@ from fastapi.utils import ( generate_operation_id_for_path, is_body_allowed_for_status_code, ) -from pydantic import BaseModel from starlette.responses import JSONResponse from starlette.routing import BaseRoute +from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY from typing_extensions import Literal validation_error_definition = { @@ -91,9 +87,9 @@ def get_openapi_security_definitions( return security_definitions, operation_security -def _get_openapi_operation_parameters( +def get_openapi_operation_parameters( *, - dependant: Dependant, + all_route_params: Sequence[ModelField], schema_generator: GenerateJsonSchema, model_name_map: ModelNameMap, field_mapping: Dict[ @@ -102,67 +98,33 @@ def _get_openapi_operation_parameters( separate_input_output_schemas: bool = True, ) -> List[Dict[str, Any]]: parameters = [] - flat_dependant = get_flat_dependant(dependant, skip_repeats=True) - path_params = _get_flat_fields_from_params(flat_dependant.path_params) - query_params = _get_flat_fields_from_params(flat_dependant.query_params) - header_params = _get_flat_fields_from_params(flat_dependant.header_params) - cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params) - parameter_groups = [ - (ParamTypes.path, path_params), - (ParamTypes.query, query_params), - (ParamTypes.header, header_params), - (ParamTypes.cookie, cookie_params), - ] - default_convert_underscores = True - if len(flat_dependant.header_params) == 1: - first_field = flat_dependant.header_params[0] - if lenient_issubclass(first_field.type_, BaseModel): - default_convert_underscores = getattr( - first_field.field_info, "convert_underscores", True - ) - for param_type, param_group in parameter_groups: - for param in param_group: - field_info = param.field_info - # field_info = cast(Param, field_info) - if not getattr(field_info, "include_in_schema", True): - continue - param_schema = get_schema_from_model_field( - field=param, - schema_generator=schema_generator, - model_name_map=model_name_map, - field_mapping=field_mapping, - separate_input_output_schemas=separate_input_output_schemas, - ) - name = param.alias - convert_underscores = getattr( - param.field_info, - 
"convert_underscores", - default_convert_underscores, - ) - if ( - param_type == ParamTypes.header - and param.alias == param.name - and convert_underscores - ): - name = param.name.replace("_", "-") - - parameter = { - "name": name, - "in": param_type.value, - "required": param.required, - "schema": param_schema, - } - if field_info.description: - parameter["description"] = field_info.description - openapi_examples = getattr(field_info, "openapi_examples", None) - example = getattr(field_info, "example", None) - if openapi_examples: - parameter["examples"] = jsonable_encoder(openapi_examples) - elif example != Undefined: - parameter["example"] = jsonable_encoder(example) - if getattr(field_info, "deprecated", None): - parameter["deprecated"] = True - parameters.append(parameter) + for param in all_route_params: + field_info = param.field_info + field_info = cast(Param, field_info) + if not field_info.include_in_schema: + continue + param_schema = get_schema_from_model_field( + field=param, + schema_generator=schema_generator, + model_name_map=model_name_map, + field_mapping=field_mapping, + separate_input_output_schemas=separate_input_output_schemas, + ) + parameter = { + "name": param.alias, + "in": field_info.in_.value, + "required": param.required, + "schema": param_schema, + } + if field_info.description: + parameter["description"] = field_info.description + if field_info.openapi_examples: + parameter["examples"] = jsonable_encoder(field_info.openapi_examples) + elif field_info.example != Undefined: + parameter["example"] = jsonable_encoder(field_info.example) + if field_info.deprecated: + parameter["deprecated"] = field_info.deprecated + parameters.append(parameter) return parameters @@ -285,8 +247,9 @@ def get_openapi_path( operation.setdefault("security", []).extend(operation_security) if security_definitions: security_schemes.update(security_definitions) - operation_parameters = _get_openapi_operation_parameters( - dependant=route.dependant, + all_route_params = get_flat_params(route.dependant) + operation_parameters = get_openapi_operation_parameters( + all_route_params=all_route_params, schema_generator=schema_generator, model_name_map=model_name_map, field_mapping=field_mapping, @@ -384,9 +347,9 @@ def get_openapi_path( openapi_response = operation_responses.setdefault( status_code_key, {} ) - assert isinstance(process_response, dict), ( - "An additional response must be a dict" - ) + assert isinstance( + process_response, dict + ), "An additional response must be a dict" field = route.response_fields.get(additional_status_code) additional_field_schema: Optional[Dict[str, Any]] = None if field: @@ -415,8 +378,7 @@ def get_openapi_path( ) deep_dict_update(openapi_response, process_response) openapi_response["description"] = description - http422 = "422" - all_route_params = get_flat_params(route.dependant) + http422 = str(HTTP_422_UNPROCESSABLE_ENTITY) if (all_route_params or route.body_field) and not any( status in operation["responses"] for status in [http422, "4XX", "default"] @@ -454,9 +416,9 @@ def get_fields_from_routes( route, routing.APIRoute ): if route.body_field: - assert isinstance(route.body_field, ModelField), ( - "A request body must be a Pydantic Field" - ) + assert isinstance( + route.body_field, ModelField + ), "A request body must be a Pydantic Field" body_fields_from_routes.append(route.body_field) if route.response_field: responses_from_routes.append(route.response_field) @@ -488,7 +450,6 @@ def get_openapi( contact: Optional[Dict[str, Union[str, Any]]] = 
None, license_info: Optional[Dict[str, Union[str, Any]]] = None, separate_input_output_schemas: bool = True, - external_docs: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: info: Dict[str, Any] = {"title": title, "version": version} if summary: @@ -566,6 +527,4 @@ def get_openapi( output["webhooks"] = webhook_paths if tags: output["tags"] = tags - if external_docs: - output["externalDocs"] = external_docs return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore diff --git a/venv/lib/python3.12/site-packages/fastapi/param_functions.py b/venv/lib/python3.12/site-packages/fastapi/param_functions.py index b362162..3f6dbc9 100644 --- a/venv/lib/python3.12/site-packages/fastapi/param_functions.py +++ b/venv/lib/python3.12/site-packages/fastapi/param_functions.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union from fastapi import params from fastapi._compat import Undefined from fastapi.openapi.models import Example -from typing_extensions import Annotated, Doc, deprecated +from typing_extensions import Annotated, Doc, deprecated # type: ignore [attr-defined] _Unset: Any = Undefined @@ -240,7 +240,7 @@ def Path( # noqa: N802 ), ] = None, deprecated: Annotated[ - Union[deprecated, str, bool, None], + Optional[bool], Doc( """ Mark this parameter field as deprecated. @@ -565,7 +565,7 @@ def Query( # noqa: N802 ), ] = None, deprecated: Annotated[ - Union[deprecated, str, bool, None], + Optional[bool], Doc( """ Mark this parameter field as deprecated. @@ -880,7 +880,7 @@ def Header( # noqa: N802 ), ] = None, deprecated: Annotated[ - Union[deprecated, str, bool, None], + Optional[bool], Doc( """ Mark this parameter field as deprecated. @@ -1185,7 +1185,7 @@ def Cookie( # noqa: N802 ), ] = None, deprecated: Annotated[ - Union[deprecated, str, bool, None], + Optional[bool], Doc( """ Mark this parameter field as deprecated. @@ -1282,7 +1282,7 @@ def Body( # noqa: N802 ), ] = _Unset, embed: Annotated[ - Union[bool, None], + bool, Doc( """ When `embed` is `True`, the parameter will be expected in a JSON body as a @@ -1294,7 +1294,7 @@ def Body( # noqa: N802 [FastAPI docs for Body - Multiple Parameters](https://fastapi.tiangolo.com/tutorial/body-multiple-params/#embed-a-single-body-parameter). """ ), - ] = None, + ] = False, media_type: Annotated[ str, Doc( @@ -1512,7 +1512,7 @@ def Body( # noqa: N802 ), ] = None, deprecated: Annotated[ - Union[deprecated, str, bool, None], + Optional[bool], Doc( """ Mark this parameter field as deprecated. @@ -1827,7 +1827,7 @@ def Form( # noqa: N802 ), ] = None, deprecated: Annotated[ - Union[deprecated, str, bool, None], + Optional[bool], Doc( """ Mark this parameter field as deprecated. @@ -2141,7 +2141,7 @@ def File( # noqa: N802 ), ] = None, deprecated: Annotated[ - Union[deprecated, str, bool, None], + Optional[bool], Doc( """ Mark this parameter field as deprecated. @@ -2298,7 +2298,7 @@ def Security( # noqa: N802 dependency. The term "scope" comes from the OAuth2 specification, it seems to be - intentionally vague and interpretable. It normally refers to permissions, + intentionaly vague and interpretable. It normally refers to permissions, in cases to roles. These scopes are integrated with OpenAPI (and the API docs at `/docs`). 
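Illustrative aside, not part of the diff itself: a minimal sketch of how the `scopes` argument described in the `Security()` docstring hunk above is declared and then consumed through `SecurityScopes`. The names `oauth2_scheme`, `get_current_user`, and `read_items` are assumptions for this example only, not identifiers from the patched files.

```python
from typing import Annotated

from fastapi import FastAPI, Security
from fastapi.security import OAuth2PasswordBearer, SecurityScopes

app = FastAPI()

# Hypothetical scheme for the sketch; the scope names are arbitrary.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="token",
    scopes={"items:read": "Read items"},
)


def get_current_user(
    security_scopes: SecurityScopes,
    token: Annotated[str, Security(oauth2_scheme)],
) -> str:
    # `security_scopes.scopes` aggregates the scopes declared by every
    # dependant in the chain; the same scope list also feeds the cache key
    # `(call, tuple(sorted(set(scopes))))` seen earlier in this diff.
    return f"user with scopes [{security_scopes.scope_str}] and token {token}"


@app.get("/items/")
def read_items(
    user: Annotated[str, Security(get_current_user, scopes=["items:read"])],
) -> dict:
    return {"user": user}
```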
@@ -2343,7 +2343,7 @@ def Security( # noqa: N802 ```python from typing import Annotated - from fastapi import Security, FastAPI + from fastapi import Depends, FastAPI from .db import User from .security import get_current_active_user diff --git a/venv/lib/python3.12/site-packages/fastapi/params.py b/venv/lib/python3.12/site-packages/fastapi/params.py index 8f5601d..b40944d 100644 --- a/venv/lib/python3.12/site-packages/fastapi/params.py +++ b/venv/lib/python3.12/site-packages/fastapi/params.py @@ -6,11 +6,7 @@ from fastapi.openapi.models import Example from pydantic.fields import FieldInfo from typing_extensions import Annotated, deprecated -from ._compat import ( - PYDANTIC_V2, - PYDANTIC_VERSION_MINOR_TUPLE, - Undefined, -) +from ._compat import PYDANTIC_V2, Undefined _Unset: Any = Undefined @@ -67,11 +63,12 @@ class Param(FieldInfo): ), ] = _Unset, openapi_examples: Optional[Dict[str, Example]] = None, - deprecated: Union[deprecated, str, bool, None] = None, + deprecated: Optional[bool] = None, include_in_schema: bool = True, json_schema_extra: Union[Dict[str, Any], None] = None, **extra: Any, ): + self.deprecated = deprecated if example is not _Unset: warnings.warn( "`example` has been deprecated, please use `examples` instead", @@ -95,7 +92,7 @@ class Param(FieldInfo): max_length=max_length, discriminator=discriminator, multiple_of=multiple_of, - allow_inf_nan=allow_inf_nan, + allow_nan=allow_inf_nan, max_digits=max_digits, decimal_places=decimal_places, **extra, @@ -109,10 +106,6 @@ class Param(FieldInfo): stacklevel=4, ) current_json_schema_extra = json_schema_extra or extra - if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7): - self.deprecated = deprecated - else: - kwargs["deprecated"] = deprecated if PYDANTIC_V2: kwargs.update( { @@ -181,7 +174,7 @@ class Path(Param): ), ] = _Unset, openapi_examples: Optional[Dict[str, Example]] = None, - deprecated: Union[deprecated, str, bool, None] = None, + deprecated: Optional[bool] = None, include_in_schema: bool = True, json_schema_extra: Union[Dict[str, Any], None] = None, **extra: Any, @@ -267,7 +260,7 @@ class Query(Param): ), ] = _Unset, openapi_examples: Optional[Dict[str, Example]] = None, - deprecated: Union[deprecated, str, bool, None] = None, + deprecated: Optional[bool] = None, include_in_schema: bool = True, json_schema_extra: Union[Dict[str, Any], None] = None, **extra: Any, @@ -352,7 +345,7 @@ class Header(Param): ), ] = _Unset, openapi_examples: Optional[Dict[str, Example]] = None, - deprecated: Union[deprecated, str, bool, None] = None, + deprecated: Optional[bool] = None, include_in_schema: bool = True, json_schema_extra: Union[Dict[str, Any], None] = None, **extra: Any, @@ -437,7 +430,7 @@ class Cookie(Param): ), ] = _Unset, openapi_examples: Optional[Dict[str, Example]] = None, - deprecated: Union[deprecated, str, bool, None] = None, + deprecated: Optional[bool] = None, include_in_schema: bool = True, json_schema_extra: Union[Dict[str, Any], None] = None, **extra: Any, @@ -483,7 +476,7 @@ class Body(FieldInfo): *, default_factory: Union[Callable[[], Any], None] = _Unset, annotation: Optional[Any] = None, - embed: Union[bool, None] = None, + embed: bool = False, media_type: str = "application/json", alias: Optional[str] = None, alias_priority: Union[int, None] = _Unset, @@ -521,13 +514,14 @@ class Body(FieldInfo): ), ] = _Unset, openapi_examples: Optional[Dict[str, Example]] = None, - deprecated: Union[deprecated, str, bool, None] = None, + deprecated: Optional[bool] = None, include_in_schema: bool = True, json_schema_extra: 
Union[Dict[str, Any], None] = None, **extra: Any, ): self.embed = embed self.media_type = media_type + self.deprecated = deprecated if example is not _Unset: warnings.warn( "`example` has been deprecated, please use `examples` instead", @@ -551,7 +545,7 @@ class Body(FieldInfo): max_length=max_length, discriminator=discriminator, multiple_of=multiple_of, - allow_inf_nan=allow_inf_nan, + allow_nan=allow_inf_nan, max_digits=max_digits, decimal_places=decimal_places, **extra, @@ -560,15 +554,11 @@ class Body(FieldInfo): kwargs["examples"] = examples if regex is not None: warnings.warn( - "`regex` has been deprecated, please use `pattern` instead", + "`regex` has been depreacated, please use `pattern` instead", category=DeprecationWarning, stacklevel=4, ) current_json_schema_extra = json_schema_extra or extra - if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7): - self.deprecated = deprecated - else: - kwargs["deprecated"] = deprecated if PYDANTIC_V2: kwargs.update( { @@ -637,7 +627,7 @@ class Form(Body): ), ] = _Unset, openapi_examples: Optional[Dict[str, Example]] = None, - deprecated: Union[deprecated, str, bool, None] = None, + deprecated: Optional[bool] = None, include_in_schema: bool = True, json_schema_extra: Union[Dict[str, Any], None] = None, **extra: Any, @@ -646,6 +636,7 @@ class Form(Body): default=default, default_factory=default_factory, annotation=annotation, + embed=True, media_type=media_type, alias=alias, alias_priority=alias_priority, @@ -721,7 +712,7 @@ class File(Form): ), ] = _Unset, openapi_examples: Optional[Dict[str, Example]] = None, - deprecated: Union[deprecated, str, bool, None] = None, + deprecated: Optional[bool] = None, include_in_schema: bool = True, json_schema_extra: Union[Dict[str, Any], None] = None, **extra: Any, diff --git a/venv/lib/python3.12/site-packages/fastapi/routing.py b/venv/lib/python3.12/site-packages/fastapi/routing.py index f620ced..54d53bb 100644 --- a/venv/lib/python3.12/site-packages/fastapi/routing.py +++ b/venv/lib/python3.12/site-packages/fastapi/routing.py @@ -1,19 +1,16 @@ +import asyncio import dataclasses import email.message import inspect import json -import sys -from contextlib import AsyncExitStack, asynccontextmanager +from contextlib import AsyncExitStack from enum import Enum, IntEnum from typing import ( Any, - AsyncIterator, Callable, - Collection, Coroutine, Dict, List, - Mapping, Optional, Sequence, Set, @@ -34,10 +31,8 @@ from fastapi._compat import ( from fastapi.datastructures import Default, DefaultPlaceholder from fastapi.dependencies.models import Dependant from fastapi.dependencies.utils import ( - _should_embed_body_fields, get_body_field, get_dependant, - get_flat_dependant, get_parameterless_sub_dependant, get_typed_return_annotation, solve_dependencies, @@ -52,7 +47,7 @@ from fastapi.exceptions import ( from fastapi.types import DecoratedCallable, IncEx from fastapi.utils import ( create_cloned_field, - create_model_field, + create_response_field, generate_unique_id, get_value_or_default, is_body_allowed_for_status_code, @@ -72,14 +67,9 @@ from starlette.routing import ( websocket_session, ) from starlette.routing import Mount as Mount # noqa -from starlette.types import AppType, ASGIApp, Lifespan, Scope +from starlette.types import ASGIApp, Lifespan, Scope from starlette.websockets import WebSocket -from typing_extensions import Annotated, Doc, deprecated - -if sys.version_info >= (3, 13): # pragma: no cover - from inspect import iscoroutinefunction -else: # pragma: no cover - from asyncio import iscoroutinefunction 
+from typing_extensions import Annotated, Doc, deprecated # type: ignore [attr-defined] def _prepare_response_content( @@ -125,28 +115,10 @@ def _prepare_response_content( for k, v in res.items() } elif dataclasses.is_dataclass(res): - assert not isinstance(res, type) return dataclasses.asdict(res) return res -def _merge_lifespan_context( - original_context: Lifespan[Any], nested_context: Lifespan[Any] -) -> Lifespan[Any]: - @asynccontextmanager - async def merged_lifespan( - app: AppType, - ) -> AsyncIterator[Optional[Mapping[str, Any]]]: - async with original_context(app) as maybe_original_state: - async with nested_context(app) as maybe_nested_state: - if maybe_nested_state is None and maybe_original_state is None: - yield None # old ASGI compatibility - else: - yield {**(maybe_nested_state or {}), **(maybe_original_state or {})} - - return merged_lifespan # type: ignore[return-value] - - async def serialize_response( *, field: Optional[ModelField] = None, @@ -234,10 +206,9 @@ def get_request_handler( response_model_exclude_defaults: bool = False, response_model_exclude_none: bool = False, dependency_overrides_provider: Optional[Any] = None, - embed_body_fields: bool = False, ) -> Callable[[Request], Coroutine[Any, Any, Response]]: assert dependant.call is not None, "dependant.call must be a function" - is_coroutine = iscoroutinefunction(dependant.call) + is_coroutine = asyncio.iscoroutinefunction(dependant.call) is_body_form = body_field and isinstance(body_field.field_info, params.Form) if isinstance(response_class, DefaultPlaceholder): actual_response_class: Type[Response] = response_class.value @@ -245,149 +216,113 @@ def get_request_handler( actual_response_class = response_class async def app(request: Request) -> Response: - response: Union[Response, None] = None - async with AsyncExitStack() as file_stack: - try: - body: Any = None - if body_field: - if is_body_form: - body = await request.form() - file_stack.push_async_callback(body.close) - else: - body_bytes = await request.body() - if body_bytes: - json_body: Any = Undefined - content_type_value = request.headers.get("content-type") - if not content_type_value: - json_body = await request.json() - else: - message = email.message.Message() - message["content-type"] = content_type_value - if message.get_content_maintype() == "application": - subtype = message.get_content_subtype() - if subtype == "json" or subtype.endswith("+json"): - json_body = await request.json() - if json_body != Undefined: - body = json_body - else: - body = body_bytes - except json.JSONDecodeError as e: - validation_error = RequestValidationError( - [ - { - "type": "json_invalid", - "loc": ("body", e.pos), - "msg": "JSON decode error", - "input": {}, - "ctx": {"error": e.msg}, - } - ], - body=e.doc, - ) - raise validation_error from e - except HTTPException: - # If a middleware raises an HTTPException, it should be raised again - raise - except Exception as e: - http_error = HTTPException( - status_code=400, detail="There was an error parsing the body" - ) - raise http_error from e - errors: List[Any] = [] - async with AsyncExitStack() as async_exit_stack: - solved_result = await solve_dependencies( - request=request, - dependant=dependant, - body=body, - dependency_overrides_provider=dependency_overrides_provider, - async_exit_stack=async_exit_stack, - embed_body_fields=embed_body_fields, - ) - errors = solved_result.errors - if not errors: - raw_response = await run_endpoint_function( - dependant=dependant, - values=solved_result.values, - 
is_coroutine=is_coroutine, - ) - if isinstance(raw_response, Response): - if raw_response.background is None: - raw_response.background = solved_result.background_tasks - response = raw_response - else: - response_args: Dict[str, Any] = { - "background": solved_result.background_tasks - } - # If status_code was set, use it, otherwise use the default from the - # response class, in the case of redirect it's 307 - current_status_code = ( - status_code - if status_code - else solved_result.response.status_code - ) - if current_status_code is not None: - response_args["status_code"] = current_status_code - if solved_result.response.status_code: - response_args["status_code"] = ( - solved_result.response.status_code - ) - content = await serialize_response( - field=response_field, - response_content=raw_response, - include=response_model_include, - exclude=response_model_exclude, - by_alias=response_model_by_alias, - exclude_unset=response_model_exclude_unset, - exclude_defaults=response_model_exclude_defaults, - exclude_none=response_model_exclude_none, - is_coroutine=is_coroutine, - ) - response = actual_response_class(content, **response_args) - if not is_body_allowed_for_status_code(response.status_code): - response.body = b"" - response.headers.raw.extend(solved_result.response.headers.raw) - if errors: - validation_error = RequestValidationError( - _normalize_errors(errors), body=body - ) - raise validation_error - if response is None: - raise FastAPIError( - "No response object was returned. There's a high chance that the " - "application code is raising an exception and a dependency with yield " - "has a block with a bare except, or a block with except Exception, " - "and is not raising the exception again. Read more about it in the " - "docs: https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-with-yield/#dependencies-with-yield-and-except" + try: + body: Any = None + if body_field: + if is_body_form: + body = await request.form() + stack = request.scope.get("fastapi_astack") + assert isinstance(stack, AsyncExitStack) + stack.push_async_callback(body.close) + else: + body_bytes = await request.body() + if body_bytes: + json_body: Any = Undefined + content_type_value = request.headers.get("content-type") + if not content_type_value: + json_body = await request.json() + else: + message = email.message.Message() + message["content-type"] = content_type_value + if message.get_content_maintype() == "application": + subtype = message.get_content_subtype() + if subtype == "json" or subtype.endswith("+json"): + json_body = await request.json() + if json_body != Undefined: + body = json_body + else: + body = body_bytes + except json.JSONDecodeError as e: + raise RequestValidationError( + [ + { + "type": "json_invalid", + "loc": ("body", e.pos), + "msg": "JSON decode error", + "input": {}, + "ctx": {"error": e.msg}, + } + ], + body=e.doc, + ) from e + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=400, detail="There was an error parsing the body" + ) from e + solved_result = await solve_dependencies( + request=request, + dependant=dependant, + body=body, + dependency_overrides_provider=dependency_overrides_provider, + ) + values, errors, background_tasks, sub_response, _ = solved_result + if errors: + raise RequestValidationError(_normalize_errors(errors), body=body) + else: + raw_response = await run_endpoint_function( + dependant=dependant, values=values, is_coroutine=is_coroutine ) - return response + + if isinstance(raw_response, 
Response): + if raw_response.background is None: + raw_response.background = background_tasks + return raw_response + response_args: Dict[str, Any] = {"background": background_tasks} + # If status_code was set, use it, otherwise use the default from the + # response class, in the case of redirect it's 307 + current_status_code = ( + status_code if status_code else sub_response.status_code + ) + if current_status_code is not None: + response_args["status_code"] = current_status_code + if sub_response.status_code: + response_args["status_code"] = sub_response.status_code + content = await serialize_response( + field=response_field, + response_content=raw_response, + include=response_model_include, + exclude=response_model_exclude, + by_alias=response_model_by_alias, + exclude_unset=response_model_exclude_unset, + exclude_defaults=response_model_exclude_defaults, + exclude_none=response_model_exclude_none, + is_coroutine=is_coroutine, + ) + response = actual_response_class(content, **response_args) + if not is_body_allowed_for_status_code(response.status_code): + response.body = b"" + response.headers.raw.extend(sub_response.headers.raw) + return response return app def get_websocket_app( - dependant: Dependant, - dependency_overrides_provider: Optional[Any] = None, - embed_body_fields: bool = False, + dependant: Dependant, dependency_overrides_provider: Optional[Any] = None ) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]: async def app(websocket: WebSocket) -> None: - async with AsyncExitStack() as async_exit_stack: - # TODO: remove this scope later, after a few releases - # This scope fastapi_astack is no longer used by FastAPI, kept for - # compatibility, just in case - websocket.scope["fastapi_astack"] = async_exit_stack - solved_result = await solve_dependencies( - request=websocket, - dependant=dependant, - dependency_overrides_provider=dependency_overrides_provider, - async_exit_stack=async_exit_stack, - embed_body_fields=embed_body_fields, - ) - if solved_result.errors: - raise WebSocketRequestValidationError( - _normalize_errors(solved_result.errors) - ) - assert dependant.call is not None, "dependant.call must be a function" - await dependant.call(**solved_result.values) + solved_result = await solve_dependencies( + request=websocket, + dependant=dependant, + dependency_overrides_provider=dependency_overrides_provider, + ) + values, errors, _, _2, _3 = solved_result + if errors: + raise WebSocketRequestValidationError(_normalize_errors(errors)) + assert dependant.call is not None, "dependant.call must be a function" + await dependant.call(**values) return app @@ -413,15 +348,11 @@ class APIWebSocketRoute(routing.WebSocketRoute): 0, get_parameterless_sub_dependant(depends=depends, path=self.path_format), ) - self._flat_dependant = get_flat_dependant(self.dependant) - self._embed_body_fields = _should_embed_body_fields( - self._flat_dependant.body_params - ) + self.app = websocket_session( get_websocket_app( dependant=self.dependant, dependency_overrides_provider=dependency_overrides_provider, - embed_body_fields=self._embed_body_fields, ) ) @@ -500,9 +431,9 @@ class APIRoute(routing.Route): methods = ["GET"] self.methods: Set[str] = {method.upper() for method in methods} if isinstance(generate_unique_id_function, DefaultPlaceholder): - current_generate_unique_id: Callable[[APIRoute], str] = ( - generate_unique_id_function.value - ) + current_generate_unique_id: Callable[ + ["APIRoute"], str + ] = generate_unique_id_function.value else: current_generate_unique_id = 
generate_unique_id_function self.unique_id = self.operation_id or current_generate_unique_id(self) @@ -511,11 +442,11 @@ class APIRoute(routing.Route): status_code = int(status_code) self.status_code = status_code if self.response_model: - assert is_body_allowed_for_status_code(status_code), ( - f"Status code {status_code} must not have a response body" - ) + assert is_body_allowed_for_status_code( + status_code + ), f"Status code {status_code} must not have a response body" response_name = "Response_" + self.unique_id - self.response_field = create_model_field( + self.response_field = create_response_field( name=response_name, type_=self.response_model, mode="serialization", @@ -528,9 +459,9 @@ class APIRoute(routing.Route): # By being a new field, no inheritance will be passed as is. A new model # will always be created. # TODO: remove when deprecating Pydantic v1 - self.secure_cloned_response_field: Optional[ModelField] = ( - create_cloned_field(self.response_field) - ) + self.secure_cloned_response_field: Optional[ + ModelField + ] = create_cloned_field(self.response_field) else: self.response_field = None # type: ignore self.secure_cloned_response_field = None @@ -544,13 +475,11 @@ class APIRoute(routing.Route): assert isinstance(response, dict), "An additional response must be a dict" model = response.get("model") if model: - assert is_body_allowed_for_status_code(additional_status_code), ( - f"Status code {additional_status_code} must not have a response body" - ) + assert is_body_allowed_for_status_code( + additional_status_code + ), f"Status code {additional_status_code} must not have a response body" response_name = f"Response_{additional_status_code}_{self.unique_id}" - response_field = create_model_field( - name=response_name, type_=model, mode="serialization" - ) + response_field = create_response_field(name=response_name, type_=model) response_fields[additional_status_code] = response_field if response_fields: self.response_fields: Dict[Union[int, str], ModelField] = response_fields @@ -564,15 +493,7 @@ class APIRoute(routing.Route): 0, get_parameterless_sub_dependant(depends=depends, path=self.path_format), ) - self._flat_dependant = get_flat_dependant(self.dependant) - self._embed_body_fields = _should_embed_body_fields( - self._flat_dependant.body_params - ) - self.body_field = get_body_field( - flat_dependant=self._flat_dependant, - name=self.unique_id, - embed_body_fields=self._embed_body_fields, - ) + self.body_field = get_body_field(dependant=self.dependant, name=self.unique_id) self.app = request_response(self.get_route_handler()) def get_route_handler(self) -> Callable[[Request], Coroutine[Any, Any, Response]]: @@ -589,7 +510,6 @@ class APIRoute(routing.Route): response_model_exclude_defaults=self.response_model_exclude_defaults, response_model_exclude_none=self.response_model_exclude_none, dependency_overrides_provider=self.dependency_overrides_provider, - embed_body_fields=self._embed_body_fields, ) def matches(self, scope: Scope) -> Tuple[Match, Scope]: @@ -821,7 +741,7 @@ class APIRouter(routing.Router): This affects the generated OpenAPI (e.g. visible at `/docs`). Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). 
""" ), ] = True, @@ -851,9 +771,9 @@ class APIRouter(routing.Router): ) if prefix: assert prefix.startswith("/"), "A path prefix must start with '/'" - assert not prefix.endswith("/"), ( - "A path prefix must not end with '/', as the routes will start with '/'" - ) + assert not prefix.endswith( + "/" + ), "A path prefix must not end with '/', as the routes will start with '/'" self.prefix = prefix self.tags: List[Union[str, Enum]] = tags or [] self.dependencies = list(dependencies or []) @@ -869,7 +789,7 @@ class APIRouter(routing.Router): def route( self, path: str, - methods: Optional[Collection[str]] = None, + methods: Optional[List[str]] = None, name: Optional[str] = None, include_in_schema: bool = True, ) -> Callable[[DecoratedCallable], DecoratedCallable]: @@ -1263,9 +1183,9 @@ class APIRouter(routing.Router): """ if prefix: assert prefix.startswith("/"), "A path prefix must start with '/'" - assert not prefix.endswith("/"), ( - "A path prefix must not end with '/', as the routes will start with '/'" - ) + assert not prefix.endswith( + "/" + ), "A path prefix must not end with '/', as the routes will start with '/'" else: for r in router.routes: path = getattr(r, "path") # noqa: B009 @@ -1365,10 +1285,6 @@ class APIRouter(routing.Router): self.add_event_handler("startup", handler) for handler in router.on_shutdown: self.add_event_handler("shutdown", handler) - self.lifespan_context = _merge_lifespan_context( - self.lifespan_context, - router.lifespan_context, - ) def get( self, @@ -1633,7 +1549,7 @@ class APIRouter(routing.Router): This affects the generated OpenAPI (e.g. visible at `/docs`). Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). """ ), ] = True, @@ -2010,7 +1926,7 @@ class APIRouter(routing.Router): This affects the generated OpenAPI (e.g. visible at `/docs`). Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). """ ), ] = True, @@ -2392,7 +2308,7 @@ class APIRouter(routing.Router): This affects the generated OpenAPI (e.g. visible at `/docs`). Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). """ ), ] = True, @@ -2774,7 +2690,7 @@ class APIRouter(routing.Router): This affects the generated OpenAPI (e.g. visible at `/docs`). Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). """ ), ] = True, @@ -3151,7 +3067,7 @@ class APIRouter(routing.Router): This affects the generated OpenAPI (e.g. visible at `/docs`). 
Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). """ ), ] = True, @@ -3528,7 +3444,7 @@ class APIRouter(routing.Router): This affects the generated OpenAPI (e.g. visible at `/docs`). Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). """ ), ] = True, @@ -3910,7 +3826,7 @@ class APIRouter(routing.Router): This affects the generated OpenAPI (e.g. visible at `/docs`). Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). """ ), ] = True, @@ -4292,7 +4208,7 @@ class APIRouter(routing.Router): This affects the generated OpenAPI (e.g. visible at `/docs`). Read more about it in the - [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-parameters-from-openapi). + [FastAPI docs for Query Parameters and String Validations](https://fastapi.tiangolo.com/tutorial/query-params-str-validations/#exclude-from-openapi). """ ), ] = True, @@ -4377,7 +4293,7 @@ class APIRouter(routing.Router): app = FastAPI() router = APIRouter() - @router.trace("/items/{item_id}") + @router.put("/items/{item_id}") def trace_item(item_id: str): return None diff --git a/venv/lib/python3.12/site-packages/fastapi/security/api_key.py b/venv/lib/python3.12/site-packages/fastapi/security/api_key.py index 6d6dd01..b1a6b4f 100644 --- a/venv/lib/python3.12/site-packages/fastapi/security/api_key.py +++ b/venv/lib/python3.12/site-packages/fastapi/security/api_key.py @@ -5,19 +5,11 @@ from fastapi.security.base import SecurityBase from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.status import HTTP_403_FORBIDDEN -from typing_extensions import Annotated, Doc +from typing_extensions import Annotated, Doc # type: ignore [attr-defined] class APIKeyBase(SecurityBase): - @staticmethod - def check_api_key(api_key: Optional[str], auto_error: bool) -> Optional[str]: - if not api_key: - if auto_error: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" - ) - return None - return api_key + pass class APIKeyQuery(APIKeyBase): @@ -84,7 +76,7 @@ class APIKeyQuery(APIKeyBase): Doc( """ By default, if the query parameter is not provided, `APIKeyQuery` will - automatically cancel the request and send the client an error. + automatically cancel the request and sebd the client an error. 
If `auto_error` is set to `False`, when the query parameter is not available, instead of erroring out, the dependency result will be @@ -100,7 +92,7 @@ class APIKeyQuery(APIKeyBase): ] = True, ): self.model: APIKey = APIKey( - **{"in": APIKeyIn.query}, + **{"in": APIKeyIn.query}, # type: ignore[arg-type] name=name, description=description, ) @@ -109,7 +101,14 @@ class APIKeyQuery(APIKeyBase): async def __call__(self, request: Request) -> Optional[str]: api_key = request.query_params.get(self.model.name) - return self.check_api_key(api_key, self.auto_error) + if not api_key: + if self.auto_error: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: + return None + return api_key class APIKeyHeader(APIKeyBase): @@ -188,7 +187,7 @@ class APIKeyHeader(APIKeyBase): ] = True, ): self.model: APIKey = APIKey( - **{"in": APIKeyIn.header}, + **{"in": APIKeyIn.header}, # type: ignore[arg-type] name=name, description=description, ) @@ -197,7 +196,14 @@ class APIKeyHeader(APIKeyBase): async def __call__(self, request: Request) -> Optional[str]: api_key = request.headers.get(self.model.name) - return self.check_api_key(api_key, self.auto_error) + if not api_key: + if self.auto_error: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: + return None + return api_key class APIKeyCookie(APIKeyBase): @@ -276,7 +282,7 @@ class APIKeyCookie(APIKeyBase): ] = True, ): self.model: APIKey = APIKey( - **{"in": APIKeyIn.cookie}, + **{"in": APIKeyIn.cookie}, # type: ignore[arg-type] name=name, description=description, ) @@ -285,4 +291,11 @@ class APIKeyCookie(APIKeyBase): async def __call__(self, request: Request) -> Optional[str]: api_key = request.cookies.get(self.model.name) - return self.check_api_key(api_key, self.auto_error) + if not api_key: + if self.auto_error: + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, detail="Not authenticated" + ) + else: + return None + return api_key diff --git a/venv/lib/python3.12/site-packages/fastapi/security/http.py b/venv/lib/python3.12/site-packages/fastapi/security/http.py index 9ab2df3..738455d 100644 --- a/venv/lib/python3.12/site-packages/fastapi/security/http.py +++ b/venv/lib/python3.12/site-packages/fastapi/security/http.py @@ -10,12 +10,12 @@ from fastapi.security.utils import get_authorization_scheme_param from pydantic import BaseModel from starlette.requests import Request from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN -from typing_extensions import Annotated, Doc +from typing_extensions import Annotated, Doc # type: ignore [attr-defined] class HTTPBasicCredentials(BaseModel): """ - The HTTP Basic credentials given as the result of using `HTTPBasic` in a + The HTTP Basic credendials given as the result of using `HTTPBasic` in a dependency. Read more about it in the @@ -277,7 +277,7 @@ class HTTPBearer(HTTPBase): bool, Doc( """ - By default, if the HTTP Bearer token is not provided (in an + By default, if the HTTP Bearer token not provided (in an `Authorization` header), `HTTPBearer` will automatically cancel the request and send the client an error. @@ -380,7 +380,7 @@ class HTTPDigest(HTTPBase): bool, Doc( """ - By default, if the HTTP Digest is not provided, `HTTPDigest` will + By default, if the HTTP Digest not provided, `HTTPDigest` will automatically cancel the request and send the client an error. 
If `auto_error` is set to `False`, when the HTTP Digest is not @@ -413,11 +413,8 @@ class HTTPDigest(HTTPBase): else: return None if scheme.lower() != "digest": - if self.auto_error: - raise HTTPException( - status_code=HTTP_403_FORBIDDEN, - detail="Invalid authentication credentials", - ) - else: - return None + raise HTTPException( + status_code=HTTP_403_FORBIDDEN, + detail="Invalid authentication credentials", + ) return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials) diff --git a/venv/lib/python3.12/site-packages/fastapi/security/oauth2.py b/venv/lib/python3.12/site-packages/fastapi/security/oauth2.py index 88e394d..9281dfb 100644 --- a/venv/lib/python3.12/site-packages/fastapi/security/oauth2.py +++ b/venv/lib/python3.12/site-packages/fastapi/security/oauth2.py @@ -10,7 +10,7 @@ from starlette.requests import Request from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN # TODO: import from typing when deprecating Python 3.9 -from typing_extensions import Annotated, Doc +from typing_extensions import Annotated, Doc # type: ignore [attr-defined] class OAuth2PasswordRequestForm: @@ -52,9 +52,9 @@ class OAuth2PasswordRequestForm: ``` Note that for OAuth2 the scope `items:read` is a single scope in an opaque string. - You could have custom internal logic to separate it by colon characters (`:`) or + You could have custom internal logic to separate it by colon caracters (`:`) or similar, and get the two parts `items` and `read`. Many applications do that to - group and organize permissions, you could do it as well in your application, just + group and organize permisions, you could do it as well in your application, just know that that it is application specific, it's not part of the specification. """ @@ -63,7 +63,7 @@ class OAuth2PasswordRequestForm: *, grant_type: Annotated[ Union[str, None], - Form(pattern="^password$"), + Form(pattern="password"), Doc( """ The OAuth2 spec says it is required and MUST be the fixed string @@ -85,7 +85,7 @@ class OAuth2PasswordRequestForm: ], password: Annotated[ str, - Form(json_schema_extra={"format": "password"}), + Form(), Doc( """ `password` string. The OAuth2 spec requires the exact field name @@ -130,7 +130,7 @@ class OAuth2PasswordRequestForm: ] = None, client_secret: Annotated[ Union[str, None], - Form(json_schema_extra={"format": "password"}), + Form(), Doc( """ If there's a `client_password` (and a `client_id`), they can be sent @@ -194,9 +194,9 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm): ``` Note that for OAuth2 the scope `items:read` is a single scope in an opaque string. - You could have custom internal logic to separate it by colon characters (`:`) or + You could have custom internal logic to separate it by colon caracters (`:`) or similar, and get the two parts `items` and `read`. Many applications do that to - group and organize permissions, you could do it as well in your application, just + group and organize permisions, you could do it as well in your application, just know that that it is application specific, it's not part of the specification. 
@@ -217,7 +217,7 @@ class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm): self, grant_type: Annotated[ str, - Form(pattern="^password$"), + Form(pattern="password"), Doc( """ The OAuth2 spec says it is required and MUST be the fixed string @@ -353,7 +353,7 @@ class OAuth2(SecurityBase): bool, Doc( """ - By default, if no HTTP Authorization header is provided, required for + By default, if no HTTP Auhtorization header is provided, required for OAuth2 authentication, it will automatically cancel the request and send the client an error. @@ -441,7 +441,7 @@ class OAuth2PasswordBearer(OAuth2): bool, Doc( """ - By default, if no HTTP Authorization header is provided, required for + By default, if no HTTP Auhtorization header is provided, required for OAuth2 authentication, it will automatically cancel the request and send the client an error. @@ -457,26 +457,11 @@ class OAuth2PasswordBearer(OAuth2): """ ), ] = True, - refreshUrl: Annotated[ - Optional[str], - Doc( - """ - The URL to refresh the token and obtain a new one. - """ - ), - ] = None, ): if not scopes: scopes = {} flows = OAuthFlowsModel( - password=cast( - Any, - { - "tokenUrl": tokenUrl, - "refreshUrl": refreshUrl, - "scopes": scopes, - }, - ) + password=cast(Any, {"tokenUrl": tokenUrl, "scopes": scopes}) ) super().__init__( flows=flows, @@ -558,7 +543,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2): bool, Doc( """ - By default, if no HTTP Authorization header is provided, required for + By default, if no HTTP Auhtorization header is provided, required for OAuth2 authentication, it will automatically cancel the request and send the client an error. diff --git a/venv/lib/python3.12/site-packages/fastapi/security/open_id_connect_url.py b/venv/lib/python3.12/site-packages/fastapi/security/open_id_connect_url.py index c8cceb9..c612b47 100644 --- a/venv/lib/python3.12/site-packages/fastapi/security/open_id_connect_url.py +++ b/venv/lib/python3.12/site-packages/fastapi/security/open_id_connect_url.py @@ -5,7 +5,7 @@ from fastapi.security.base import SecurityBase from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.status import HTTP_403_FORBIDDEN -from typing_extensions import Annotated, Doc +from typing_extensions import Annotated, Doc # type: ignore [attr-defined] class OpenIdConnect(SecurityBase): @@ -49,7 +49,7 @@ class OpenIdConnect(SecurityBase): bool, Doc( """ - By default, if no HTTP Authorization header is provided, required for + By default, if no HTTP Auhtorization header is provided, required for OpenID Connect authentication, it will automatically cancel the request and send the client an error. 
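For orientation only (this is not part of the patch): the two security-module diffs above mainly touch docstrings and remove the `refreshUrl` keyword from `OAuth2PasswordBearer`. Below is a minimal, assumed usage sketch of `OAuth2PasswordBearer` as a dependency — the `token` URL and the `/users/me` route are illustrative names, not taken from this diff. Against the downgraded `fastapi` shown in these hunks, passing `refreshUrl=` to the constructor would raise a `TypeError`, since that parameter is removed here.

```python
from fastapi import Depends, FastAPI
from fastapi.security import OAuth2PasswordBearer

app = FastAPI()

# tokenUrl points at the login route that issues tokens; "token" is an assumed route name.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


@app.get("/users/me")
async def read_users_me(token: str = Depends(oauth2_scheme)) -> dict:
    # `token` is the raw Bearer value extracted from the Authorization header;
    # with auto_error=True (the default) a missing/invalid header yields a 401 response.
    return {"token": token}
```

With `auto_error=False`, the dependency instead resolves to `None` when no Authorization header is present, as the docstrings in the hunks above describe.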
diff --git a/venv/lib/python3.12/site-packages/fastapi/types.py b/venv/lib/python3.12/site-packages/fastapi/types.py index 3205654..7adf565 100644 --- a/venv/lib/python3.12/site-packages/fastapi/types.py +++ b/venv/lib/python3.12/site-packages/fastapi/types.py @@ -6,5 +6,6 @@ from pydantic import BaseModel DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any]) UnionType = getattr(types, "UnionType", Union) +NoneType = getattr(types, "UnionType", None) ModelNameMap = Dict[Union[Type[BaseModel], Type[Enum]], str] IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]] diff --git a/venv/lib/python3.12/site-packages/fastapi/utils.py b/venv/lib/python3.12/site-packages/fastapi/utils.py index 98725ff..f8463dd 100644 --- a/venv/lib/python3.12/site-packages/fastapi/utils.py +++ b/venv/lib/python3.12/site-packages/fastapi/utils.py @@ -34,9 +34,9 @@ if TYPE_CHECKING: # pragma: nocover from .routing import APIRoute # Cache for `create_cloned_field` -_CLONED_TYPES_CACHE: MutableMapping[Type[BaseModel], Type[BaseModel]] = ( - WeakKeyDictionary() -) +_CLONED_TYPES_CACHE: MutableMapping[ + Type[BaseModel], Type[BaseModel] +] = WeakKeyDictionary() def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool: @@ -53,16 +53,16 @@ def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool: }: return True current_status_code = int(status_code) - return not (current_status_code < 200 or current_status_code in {204, 205, 304}) + return not (current_status_code < 200 or current_status_code in {204, 304}) def get_path_param_names(path: str) -> Set[str]: return set(re.findall("{(.*?)}", path)) -def create_model_field( +def create_response_field( name: str, - type_: Any, + type_: Type[Any], class_validators: Optional[Dict[str, Validator]] = None, default: Optional[Any] = Undefined, required: Union[bool, UndefinedType] = Undefined, @@ -71,6 +71,9 @@ def create_model_field( alias: Optional[str] = None, mode: Literal["validation", "serialization"] = "validation", ) -> ModelField: + """ + Create a new response field. Raises if type_ is invalid. 
+ """ class_validators = class_validators or {} if PYDANTIC_V2: field_info = field_info or FieldInfo( @@ -132,12 +135,11 @@ def create_cloned_field( use_type.__fields__[f.name] = create_cloned_field( f, cloned_types=cloned_types ) - new_field = create_model_field(name=field.name, type_=use_type) + new_field = create_response_field(name=field.name, type_=use_type) new_field.has_alias = field.has_alias # type: ignore[attr-defined] new_field.alias = field.alias # type: ignore[misc] new_field.class_validators = field.class_validators # type: ignore[attr-defined] new_field.default = field.default # type: ignore[misc] - new_field.default_factory = field.default_factory # type: ignore[attr-defined] new_field.required = field.required # type: ignore[misc] new_field.model_config = field.model_config # type: ignore[attr-defined] new_field.field_info = field.field_info @@ -171,17 +173,17 @@ def generate_operation_id_for_path( DeprecationWarning, stacklevel=2, ) - operation_id = f"{name}{path}" + operation_id = name + path operation_id = re.sub(r"\W", "_", operation_id) - operation_id = f"{operation_id}_{method.lower()}" + operation_id = operation_id + "_" + method.lower() return operation_id def generate_unique_id(route: "APIRoute") -> str: - operation_id = f"{route.name}{route.path_format}" + operation_id = route.name + route.path_format operation_id = re.sub(r"\W", "_", operation_id) assert route.methods - operation_id = f"{operation_id}_{list(route.methods)[0].lower()}" + operation_id = operation_id + "_" + list(route.methods)[0].lower() return operation_id @@ -219,3 +221,9 @@ def get_value_or_default( if not isinstance(item, DefaultPlaceholder): return item return first_item + + +def match_pydantic_error_url(error_type: str) -> Any: + from dirty_equals import IsStr + + return IsStr(regex=rf"^https://errors\.pydantic\.dev/.*/v/{error_type}") diff --git a/venv/lib/python3.12/site-packages/jose/__init__.py b/venv/lib/python3.12/site-packages/jose/__init__.py index 7e53b60..054baa7 100644 --- a/venv/lib/python3.12/site-packages/jose/__init__.py +++ b/venv/lib/python3.12/site-packages/jose/__init__.py @@ -1,4 +1,4 @@ -__version__ = "3.5.0" +__version__ = "3.3.0" __author__ = "Michael Davis" __license__ = "MIT" __copyright__ = "Copyright 2016 Michael Davis" diff --git a/venv/lib/python3.12/site-packages/jose/backends/__init__.py b/venv/lib/python3.12/site-packages/jose/backends/__init__.py index 9918969..e7bba69 100644 --- a/venv/lib/python3.12/site-packages/jose/backends/__init__.py +++ b/venv/lib/python3.12/site-packages/jose/backends/__init__.py @@ -1,4 +1,10 @@ -from jose.backends.native import get_random_bytes # noqa: F401 +try: + from jose.backends.cryptography_backend import get_random_bytes # noqa: F401 +except ImportError: + try: + from jose.backends.pycrypto_backend import get_random_bytes # noqa: F401 + except ImportError: + from jose.backends.native import get_random_bytes # noqa: F401 try: from jose.backends.cryptography_backend import CryptographyRSAKey as RSAKey # noqa: F401 diff --git a/venv/lib/python3.12/site-packages/jose/backends/_asn1.py b/venv/lib/python3.12/site-packages/jose/backends/_asn1.py index 87e3df1..af5fa8b 100644 --- a/venv/lib/python3.12/site-packages/jose/backends/_asn1.py +++ b/venv/lib/python3.12/site-packages/jose/backends/_asn1.py @@ -2,7 +2,6 @@ Required by rsa_backend but not cryptography_backend. 
""" - from pyasn1.codec.der import decoder, encoder from pyasn1.type import namedtype, univ diff --git a/venv/lib/python3.12/site-packages/jose/backends/cryptography_backend.py b/venv/lib/python3.12/site-packages/jose/backends/cryptography_backend.py index ec836b4..abd2426 100644 --- a/venv/lib/python3.12/site-packages/jose/backends/cryptography_backend.py +++ b/venv/lib/python3.12/site-packages/jose/backends/cryptography_backend.py @@ -3,6 +3,7 @@ import warnings from cryptography.exceptions import InvalidSignature, InvalidTag from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.bindings.openssl.binding import Binding from cryptography.hazmat.primitives import hashes, hmac, serialization from cryptography.hazmat.primitives.asymmetric import ec, padding, rsa from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature, encode_dss_signature @@ -15,21 +16,35 @@ from cryptography.x509 import load_pem_x509_certificate from ..constants import ALGORITHMS from ..exceptions import JWEError, JWKError -from ..utils import ( - base64_to_long, - base64url_decode, - base64url_encode, - ensure_binary, - is_pem_format, - is_ssh_key, - long_to_base64, -) -from . import get_random_bytes +from ..utils import base64_to_long, base64url_decode, base64url_encode, ensure_binary, long_to_base64 from .base import Key _binding = None +def get_random_bytes(num_bytes): + """ + Get random bytes + + Currently, Cryptography returns OS random bytes. If you want OpenSSL + generated random bytes, you'll have to switch the RAND engine after + initializing the OpenSSL backend + Args: + num_bytes (int): Number of random bytes to generate and return + Returns: + bytes: Random bytes + """ + global _binding + + if _binding is None: + _binding = Binding() + + buf = _binding.ffi.new("char[]", num_bytes) + _binding.lib.RAND_bytes(buf, num_bytes) + rand_bytes = _binding.ffi.buffer(buf, num_bytes)[:] + return rand_bytes + + class CryptographyECKey(Key): SHA256 = hashes.SHA256 SHA384 = hashes.SHA384 @@ -228,8 +243,8 @@ class CryptographyRSAKey(Key): self.cryptography_backend = cryptography_backend - # if it conforms to RSAPublicKey or RSAPrivateKey interface - if (hasattr(key, "public_bytes") and hasattr(key, "public_numbers")) or hasattr(key, "private_bytes"): + # if it conforms to RSAPublicKey interface + if hasattr(key, "public_bytes") and hasattr(key, "public_numbers"): self.prepared_key = key return @@ -424,8 +439,6 @@ class CryptographyAESKey(Key): ALGORITHMS.A256KW: None, } - IV_BYTE_LENGTH_MODE_MAP = {"CBC": algorithms.AES.block_size // 8, "GCM": 96 // 8} - def __init__(self, key, algorithm): if algorithm not in ALGORITHMS.AES: raise JWKError("%s is not a valid AES algorithm" % algorithm) @@ -455,8 +468,7 @@ class CryptographyAESKey(Key): def encrypt(self, plain_text, aad=None): plain_text = ensure_binary(plain_text) try: - iv_byte_length = self.IV_BYTE_LENGTH_MODE_MAP.get(self._mode.name, algorithms.AES.block_size) - iv = get_random_bytes(iv_byte_length) + iv = get_random_bytes(algorithms.AES.block_size // 8) mode = self._mode(iv) if mode.name == "GCM": cipher = aead.AESGCM(self._key) @@ -540,7 +552,14 @@ class CryptographyHMACKey(Key): if isinstance(key, str): key = key.encode("utf-8") - if is_pem_format(key) or is_ssh_key(key): + invalid_strings = [ + b"-----BEGIN PUBLIC KEY-----", + b"-----BEGIN RSA PUBLIC KEY-----", + b"-----BEGIN CERTIFICATE-----", + b"ssh-rsa", + ] + + if any(string_value in key for string_value in invalid_strings): raise JWKError( "The specified key is 
an asymmetric key or x509 certificate and" " should not be used as an HMAC secret." diff --git a/venv/lib/python3.12/site-packages/jose/backends/native.py b/venv/lib/python3.12/site-packages/jose/backends/native.py index 8cc77da..eb3a6ae 100644 --- a/venv/lib/python3.12/site-packages/jose/backends/native.py +++ b/venv/lib/python3.12/site-packages/jose/backends/native.py @@ -5,7 +5,7 @@ import os from jose.backends.base import Key from jose.constants import ALGORITHMS from jose.exceptions import JWKError -from jose.utils import base64url_decode, base64url_encode, is_pem_format, is_ssh_key +from jose.utils import base64url_decode, base64url_encode def get_random_bytes(num_bytes): @@ -36,7 +36,14 @@ class HMACKey(Key): if isinstance(key, str): key = key.encode("utf-8") - if is_pem_format(key) or is_ssh_key(key): + invalid_strings = [ + b"-----BEGIN PUBLIC KEY-----", + b"-----BEGIN RSA PUBLIC KEY-----", + b"-----BEGIN CERTIFICATE-----", + b"ssh-rsa", + ] + + if any(string_value in key for string_value in invalid_strings): raise JWKError( "The specified key is an asymmetric key or x509 certificate and" " should not be used as an HMAC secret." diff --git a/venv/lib/python3.12/site-packages/jose/backends/rsa_backend.py b/venv/lib/python3.12/site-packages/jose/backends/rsa_backend.py index 8139d69..4e8ccf1 100644 --- a/venv/lib/python3.12/site-packages/jose/backends/rsa_backend.py +++ b/venv/lib/python3.12/site-packages/jose/backends/rsa_backend.py @@ -221,6 +221,7 @@ class RSAKey(Key): return self.__class__(pyrsa.PublicKey(n=self._prepared_key.n, e=self._prepared_key.e), self._algorithm) def to_pem(self, pem_format="PKCS8"): + if isinstance(self._prepared_key, pyrsa.PrivateKey): der = self._prepared_key.save_pkcs1(format="DER") if pem_format == "PKCS8": diff --git a/venv/lib/python3.12/site-packages/jose/constants.py b/venv/lib/python3.12/site-packages/jose/constants.py index 58787d4..ab4d74d 100644 --- a/venv/lib/python3.12/site-packages/jose/constants.py +++ b/venv/lib/python3.12/site-packages/jose/constants.py @@ -96,5 +96,3 @@ class Zips: ZIPS = Zips() - -JWE_SIZE_LIMIT = 250 * 1024 diff --git a/venv/lib/python3.12/site-packages/jose/jwe.py b/venv/lib/python3.12/site-packages/jose/jwe.py index 09e5c32..2c387ff 100644 --- a/venv/lib/python3.12/site-packages/jose/jwe.py +++ b/venv/lib/python3.12/site-packages/jose/jwe.py @@ -6,13 +6,13 @@ from struct import pack from . import jwk from .backends import get_random_bytes -from .constants import ALGORITHMS, JWE_SIZE_LIMIT, ZIPS +from .constants import ALGORITHMS, ZIPS from .exceptions import JWEError, JWEParseError from .utils import base64url_decode, base64url_encode, ensure_binary def encrypt(plaintext, key, encryption=ALGORITHMS.A256GCM, algorithm=ALGORITHMS.DIR, zip=None, cty=None, kid=None): - """Encrypts plaintext and returns a JWE compact serialization string. + """Encrypts plaintext and returns a JWE cmpact serialization string. Args: plaintext (bytes): A bytes object to encrypt @@ -76,13 +76,6 @@ def decrypt(jwe_str, key): >>> jwe.decrypt(jwe_string, 'asecret128bitkey') 'Hello, World!' """ - - # Limit the token size - if the data is compressed then decompressing the - # data could lead to large memory usage. This helps address This addresses - # CVE-2024-33664. 
Also see _decompress() - if len(jwe_str) > JWE_SIZE_LIMIT: - raise JWEError(f"JWE string {len(jwe_str)} bytes exceeds {JWE_SIZE_LIMIT} bytes") - header, encoded_header, encrypted_key, iv, cipher_text, auth_tag = _jwe_compact_deserialize(jwe_str) # Verify that the implementation understands and can process all @@ -431,13 +424,13 @@ def _compress(zip, plaintext): (bytes): Compressed plaintext """ if zip not in ZIPS.SUPPORTED: - raise NotImplementedError(f"ZIP {zip} is not supported!") + raise NotImplementedError("ZIP {} is not supported!") if zip is None: compressed = plaintext elif zip == ZIPS.DEF: compressed = zlib.compress(plaintext) else: - raise NotImplementedError(f"ZIP {zip} is not implemented!") + raise NotImplementedError("ZIP {} is not implemented!") return compressed @@ -453,18 +446,13 @@ def _decompress(zip, compressed): (bytes): Compressed plaintext """ if zip not in ZIPS.SUPPORTED: - raise NotImplementedError(f"ZIP {zip} is not supported!") + raise NotImplementedError("ZIP {} is not supported!") if zip is None: decompressed = compressed elif zip == ZIPS.DEF: - # If, during decompression, there is more data than expected, the - # decompression halts and raise an error. This addresses CVE-2024-33664 - decompressor = zlib.decompressobj() - decompressed = decompressor.decompress(compressed, max_length=JWE_SIZE_LIMIT) - if decompressor.unconsumed_tail: - raise JWEError(f"Decompressed JWE string exceeds {JWE_SIZE_LIMIT} bytes") + decompressed = zlib.decompress(compressed) else: - raise NotImplementedError(f"ZIP {zip} is not implemented!") + raise NotImplementedError("ZIP {} is not implemented!") return decompressed @@ -542,7 +530,7 @@ def _get_key_wrap_cek(enc, key): def _get_random_cek_bytes_for_enc(enc): """ - Get the random cek bytes based on the encryption algorithm + Get the random cek bytes based on the encryptionn algorithm Args: enc (str): Encryption algorithm diff --git a/venv/lib/python3.12/site-packages/jose/jwk.py b/venv/lib/python3.12/site-packages/jose/jwk.py index 2a31847..7afc054 100644 --- a/venv/lib/python3.12/site-packages/jose/jwk.py +++ b/venv/lib/python3.12/site-packages/jose/jwk.py @@ -71,9 +71,9 @@ def construct(key_data, algorithm=None): algorithm = key_data.get("alg", None) if not algorithm: - raise JWKError("Unable to find an algorithm for key") + raise JWKError("Unable to find an algorithm for key: %s" % key_data) key_class = get_key(algorithm) if not key_class: - raise JWKError("Unable to find an algorithm for key") + raise JWKError("Unable to find an algorithm for key: %s" % key_data) return key_class(key_data, algorithm) diff --git a/venv/lib/python3.12/site-packages/jose/jws.py b/venv/lib/python3.12/site-packages/jose/jws.py index 27f6b79..bfaf6bd 100644 --- a/venv/lib/python3.12/site-packages/jose/jws.py +++ b/venv/lib/python3.12/site-packages/jose/jws.py @@ -1,10 +1,6 @@ import binascii import json - -try: - from collections.abc import Iterable, Mapping -except ImportError: - from collections import Mapping, Iterable +from collections.abc import Iterable, Mapping from jose import jwk from jose.backends.base import Key @@ -219,6 +215,7 @@ def _sig_matches_keys(keys, signing_input, signature, alg): def _get_keys(key): + if isinstance(key, Key): return (key,) @@ -251,6 +248,7 @@ def _get_keys(key): def _verify_signature(signing_input, header, signature, key="", algorithms=None): + alg = header.get("alg") if not alg: raise JWSError("No algorithm was specified in the JWS header.") diff --git a/venv/lib/python3.12/site-packages/jose/jwt.py 
b/venv/lib/python3.12/site-packages/jose/jwt.py index f47e4dd..3f2142e 100644 --- a/venv/lib/python3.12/site-packages/jose/jwt.py +++ b/venv/lib/python3.12/site-packages/jose/jwt.py @@ -1,19 +1,8 @@ import json from calendar import timegm +from collections.abc import Mapping from datetime import datetime, timedelta -try: - from collections.abc import Mapping -except ImportError: - from collections import Mapping - -try: - from datetime import UTC # Preferred in Python 3.13+ -except ImportError: - from datetime import timezone - - UTC = timezone.utc # Preferred in Python 3.12 and below - from jose import jws from .constants import ALGORITHMS @@ -53,6 +42,7 @@ def encode(claims, key, algorithm=ALGORITHMS.HS256, headers=None, access_token=N """ for time_claim in ["exp", "iat", "nbf"]: + # Convert datetime to a intDate value in known time-format claims if isinstance(claims.get(time_claim), datetime): claims[time_claim] = timegm(claims[time_claim].utctimetuple()) @@ -68,15 +58,8 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None Args: token (str): A signed JWS to be verified. - key (str or iterable): A key to attempt to verify the payload with. - This can be simple string with an individual key (e.g. "a1234"), - a tuple or list of keys (e.g. ("a1234...", "b3579"), - a JSON string, (e.g. '["a1234", "b3579"]'), - a dict with the 'keys' key that gives a tuple or list of keys (e.g {'keys': [...]} ) or - a dict or JSON string for a JWK set as defined by RFC 7517 (e.g. - {'keys': [{'kty': 'oct', 'k': 'YTEyMzQ'}, {'kty': 'oct', 'k':'YjM1Nzk'}]} or - '{"keys": [{"kty":"oct","k":"YTEyMzQ"},{"kty":"oct","k":"YjM1Nzk"}]}' - ) in which case the keys must be base64 url safe encoded (with optional padding). + key (str or dict): A key to attempt to verify the payload with. Can be + individual JWK or JWK set. algorithms (str or list): Valid algorithms that should be used to verify the JWS. audience (str): The intended audience of the token. If the "aud" claim is included in the claim set, then the audience must be included and must equal @@ -295,7 +278,7 @@ def _validate_nbf(claims, leeway=0): except ValueError: raise JWTClaimsError("Not Before claim (nbf) must be an integer.") - now = timegm(datetime.now(UTC).utctimetuple()) + now = timegm(datetime.utcnow().utctimetuple()) if nbf > (now + leeway): raise JWTClaimsError("The token is not yet valid (nbf)") @@ -325,7 +308,7 @@ def _validate_exp(claims, leeway=0): except ValueError: raise JWTClaimsError("Expiration Time claim (exp) must be an integer.") - now = timegm(datetime.now(UTC).utctimetuple()) + now = timegm(datetime.utcnow().utctimetuple()) if exp < (now - leeway): raise ExpiredSignatureError("Signature has expired.") @@ -399,7 +382,7 @@ def _validate_sub(claims, subject=None): "sub" value is a case-sensitive string containing a StringOrURI value. Use of this claim is OPTIONAL. - Arg + Args: claims (dict): The claims dictionary to validate. subject (str): The subject of the token. 
""" @@ -473,6 +456,7 @@ def _validate_at_hash(claims, access_token, algorithm): def _validate_claims(claims, audience=None, issuer=None, subject=None, algorithm=None, access_token=None, options=None): + leeway = options.get("leeway", 0) if isinstance(leeway, timedelta): diff --git a/venv/lib/python3.12/site-packages/jose/utils.py b/venv/lib/python3.12/site-packages/jose/utils.py index d62cafb..fcef885 100644 --- a/venv/lib/python3.12/site-packages/jose/utils.py +++ b/venv/lib/python3.12/site-packages/jose/utils.py @@ -1,5 +1,4 @@ import base64 -import re import struct # Piggyback of the backends implementation of the function that converts a long @@ -10,6 +9,7 @@ try: def long_to_bytes(n, blocksize=0): return _long_to_bytes(n, blocksize or None) + except ImportError: from ecdsa.ecdsa import int_to_string as _long_to_bytes @@ -67,7 +67,7 @@ def base64url_decode(input): """Helper method to base64url_decode a string. Args: - input (bytes): A base64url_encoded string (bytes) to decode. + input (str): A base64url_encoded string to decode. """ rem = len(input) % 4 @@ -82,7 +82,7 @@ def base64url_encode(input): """Helper method to base64url_encode a string. Args: - input (bytes): A base64url_encoded string (bytes) to encode. + input (str): A base64url_encoded string to encode. """ return base64.urlsafe_b64encode(input).replace(b"=", b"") @@ -106,60 +106,3 @@ def ensure_binary(s): if isinstance(s, str): return s.encode("utf-8", "strict") raise TypeError(f"not expecting type '{type(s)}'") - - -# The following was copied from PyJWT: -# https://github.com/jpadilla/pyjwt/commit/9c528670c455b8d948aff95ed50e22940d1ad3fc -# Based on: -# https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252 -_PEMS = { - b"CERTIFICATE", - b"TRUSTED CERTIFICATE", - b"PRIVATE KEY", - b"PUBLIC KEY", - b"ENCRYPTED PRIVATE KEY", - b"OPENSSH PRIVATE KEY", - b"DSA PRIVATE KEY", - b"RSA PRIVATE KEY", - b"RSA PUBLIC KEY", - b"EC PRIVATE KEY", - b"DH PARAMETERS", - b"NEW CERTIFICATE REQUEST", - b"CERTIFICATE REQUEST", - b"SSH2 PUBLIC KEY", - b"SSH2 ENCRYPTED PRIVATE KEY", - b"X509 CRL", -} -_PEM_RE = re.compile( - b"----[- ]BEGIN (" + b"|".join(re.escape(pem) for pem in _PEMS) + b")[- ]----", -) - - -def is_pem_format(key: bytes) -> bool: - return bool(_PEM_RE.search(key)) - - -# Based on -# https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b -# /src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46 -_CERT_SUFFIX = b"-cert-v01@openssh.com" -_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)") -_SSH_KEY_FORMATS = [ - b"ssh-ed25519", - b"ssh-rsa", - b"ssh-dss", - b"ecdsa-sha2-nistp256", - b"ecdsa-sha2-nistp384", - b"ecdsa-sha2-nistp521", -] - - -def is_ssh_key(key: bytes) -> bool: - if any(string_value in key for string_value in _SSH_KEY_FORMATS): - return True - ssh_pubkey_match = _SSH_PUBKEY_RC.match(key) - if ssh_pubkey_match: - key_type = ssh_pubkey_match.group(1) - if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]: - return True - return False diff --git a/venv/lib/python3.12/site-packages/jwt/__init__.py b/venv/lib/python3.12/site-packages/jwt/__init__.py index 457a4e3..68d09c1 100644 --- a/venv/lib/python3.12/site-packages/jwt/__init__.py +++ b/venv/lib/python3.12/site-packages/jwt/__init__.py @@ -6,7 +6,7 @@ from .api_jws import ( register_algorithm, unregister_algorithm, ) -from .api_jwt import PyJWT, decode, decode_complete, encode +from .api_jwt import PyJWT, decode, encode from .exceptions import ( DecodeError, ExpiredSignatureError, @@ 
-27,7 +27,7 @@ from .exceptions import ( ) from .jwks_client import PyJWKClient -__version__ = "2.10.1" +__version__ = "2.8.0" __title__ = "PyJWT" __description__ = "JSON Web Token implementation in Python" @@ -49,7 +49,6 @@ __all__ = [ "PyJWK", "PyJWKSet", "decode", - "decode_complete", "encode", "get_unverified_header", "register_algorithm", diff --git a/venv/lib/python3.12/site-packages/jwt/algorithms.py b/venv/lib/python3.12/site-packages/jwt/algorithms.py index ccb1500..ed18715 100644 --- a/venv/lib/python3.12/site-packages/jwt/algorithms.py +++ b/venv/lib/python3.12/site-packages/jwt/algorithms.py @@ -3,8 +3,9 @@ from __future__ import annotations import hashlib import hmac import json +import sys from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload +from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload from .exceptions import InvalidKeyError from .types import HashlibHash, JWKDict @@ -20,8 +21,14 @@ from .utils import ( to_base64url_uint, ) +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + + try: - from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm + from cryptography.exceptions import InvalidSignature from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import padding @@ -187,16 +194,18 @@ class Algorithm(ABC): @overload @staticmethod @abstractmethod - def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ... # pragma: no cover + def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: + ... # pragma: no cover @overload @staticmethod @abstractmethod - def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ... # pragma: no cover + def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: + ... # pragma: no cover @staticmethod @abstractmethod - def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str: + def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]: """ Serializes a given key into a JWK """ @@ -265,18 +274,16 @@ class HMACAlgorithm(Algorithm): @overload @staticmethod - def to_jwk( - key_obj: str | bytes, as_dict: Literal[True] - ) -> JWKDict: ... # pragma: no cover + def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict: + ... # pragma: no cover @overload @staticmethod - def to_jwk( - key_obj: str | bytes, as_dict: Literal[False] = False - ) -> str: ... # pragma: no cover + def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str: + ... # pragma: no cover @staticmethod - def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str: + def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]: jwk = { "k": base64url_encode(force_bytes(key_obj)).decode(), "kty": "oct", @@ -297,7 +304,7 @@ class HMACAlgorithm(Algorithm): else: raise ValueError except ValueError: - raise InvalidKeyError("Key is not valid JSON") from None + raise InvalidKeyError("Key is not valid JSON") if obj.get("kty") != "oct": raise InvalidKeyError("Not an HMAC key") @@ -343,27 +350,22 @@ if has_crypto: RSAPrivateKey, load_pem_private_key(key_bytes, password=None) ) except ValueError: - try: - return cast(RSAPublicKey, load_pem_public_key(key_bytes)) - except (ValueError, UnsupportedAlgorithm): - raise InvalidKeyError( - "Could not parse the provided public key." 
- ) from None + return cast(RSAPublicKey, load_pem_public_key(key_bytes)) @overload @staticmethod - def to_jwk( - key_obj: AllowedRSAKeys, as_dict: Literal[True] - ) -> JWKDict: ... # pragma: no cover + def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict: + ... # pragma: no cover @overload @staticmethod - def to_jwk( - key_obj: AllowedRSAKeys, as_dict: Literal[False] = False - ) -> str: ... # pragma: no cover + def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str: + ... # pragma: no cover @staticmethod - def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str: + def to_jwk( + key_obj: AllowedRSAKeys, as_dict: bool = False + ) -> Union[JWKDict, str]: obj: dict[str, Any] | None = None if hasattr(key_obj, "private_numbers"): @@ -411,10 +413,10 @@ if has_crypto: else: raise ValueError except ValueError: - raise InvalidKeyError("Key is not valid JSON") from None + raise InvalidKeyError("Key is not valid JSON") if obj.get("kty") != "RSA": - raise InvalidKeyError("Not an RSA key") from None + raise InvalidKeyError("Not an RSA key") if "d" in obj and "e" in obj and "n" in obj: # Private key @@ -430,7 +432,7 @@ if has_crypto: if any_props_found and not all(props_found): raise InvalidKeyError( "RSA key must include all parameters if any are present besides d" - ) from None + ) public_numbers = RSAPublicNumbers( from_base64url_uint(obj["e"]), @@ -522,7 +524,7 @@ if has_crypto: ): raise InvalidKeyError( "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms" - ) from None + ) return crypto_key @@ -531,7 +533,7 @@ if has_crypto: return der_to_raw_signature(der_sig, key.curve) - def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool: + def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool: try: der_sig = raw_to_der_signature(sig, key.curve) except ValueError: @@ -550,18 +552,18 @@ if has_crypto: @overload @staticmethod - def to_jwk( - key_obj: AllowedECKeys, as_dict: Literal[True] - ) -> JWKDict: ... # pragma: no cover + def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict: + ... # pragma: no cover @overload @staticmethod - def to_jwk( - key_obj: AllowedECKeys, as_dict: Literal[False] = False - ) -> str: ... # pragma: no cover + def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str: + ... 
# pragma: no cover @staticmethod - def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str: + def to_jwk( + key_obj: AllowedECKeys, as_dict: bool = False + ) -> Union[JWKDict, str]: if isinstance(key_obj, EllipticCurvePrivateKey): public_numbers = key_obj.public_key().public_numbers() elif isinstance(key_obj, EllipticCurvePublicKey): @@ -583,20 +585,13 @@ if has_crypto: obj: dict[str, Any] = { "kty": "EC", "crv": crv, - "x": to_base64url_uint( - public_numbers.x, - bit_length=key_obj.curve.key_size, - ).decode(), - "y": to_base64url_uint( - public_numbers.y, - bit_length=key_obj.curve.key_size, - ).decode(), + "x": to_base64url_uint(public_numbers.x).decode(), + "y": to_base64url_uint(public_numbers.y).decode(), } if isinstance(key_obj, EllipticCurvePrivateKey): obj["d"] = to_base64url_uint( - key_obj.private_numbers().private_value, - bit_length=key_obj.curve.key_size, + key_obj.private_numbers().private_value ).decode() if as_dict: @@ -614,13 +609,13 @@ if has_crypto: else: raise ValueError except ValueError: - raise InvalidKeyError("Key is not valid JSON") from None + raise InvalidKeyError("Key is not valid JSON") if obj.get("kty") != "EC": - raise InvalidKeyError("Not an Elliptic curve key") from None + raise InvalidKeyError("Not an Elliptic curve key") if "x" not in obj or "y" not in obj: - raise InvalidKeyError("Not an Elliptic curve key") from None + raise InvalidKeyError("Not an Elliptic curve key") x = base64url_decode(obj.get("x")) y = base64url_decode(obj.get("y")) @@ -632,23 +627,17 @@ if has_crypto: if len(x) == len(y) == 32: curve_obj = SECP256R1() else: - raise InvalidKeyError( - "Coords should be 32 bytes for curve P-256" - ) from None + raise InvalidKeyError("Coords should be 32 bytes for curve P-256") elif curve == "P-384": if len(x) == len(y) == 48: curve_obj = SECP384R1() else: - raise InvalidKeyError( - "Coords should be 48 bytes for curve P-384" - ) from None + raise InvalidKeyError("Coords should be 48 bytes for curve P-384") elif curve == "P-521": if len(x) == len(y) == 66: curve_obj = SECP521R1() else: - raise InvalidKeyError( - "Coords should be 66 bytes for curve P-521" - ) from None + raise InvalidKeyError("Coords should be 66 bytes for curve P-521") elif curve == "secp256k1": if len(x) == len(y) == 32: curve_obj = SECP256K1() @@ -782,18 +771,16 @@ if has_crypto: @overload @staticmethod - def to_jwk( - key: AllowedOKPKeys, as_dict: Literal[True] - ) -> JWKDict: ... # pragma: no cover + def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict: + ... # pragma: no cover @overload @staticmethod - def to_jwk( - key: AllowedOKPKeys, as_dict: Literal[False] = False - ) -> str: ... # pragma: no cover + def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str: + ... 
# pragma: no cover @staticmethod - def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str: + def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]: if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)): x = key.public_bytes( encoding=Encoding.Raw, @@ -849,7 +836,7 @@ if has_crypto: else: raise ValueError except ValueError: - raise InvalidKeyError("Key is not valid JSON") from None + raise InvalidKeyError("Key is not valid JSON") if obj.get("kty") != "OKP": raise InvalidKeyError("Not an Octet Key Pair") diff --git a/venv/lib/python3.12/site-packages/jwt/api_jwk.py b/venv/lib/python3.12/site-packages/jwt/api_jwk.py index 02f4679..456c7f4 100644 --- a/venv/lib/python3.12/site-packages/jwt/api_jwk.py +++ b/venv/lib/python3.12/site-packages/jwt/api_jwk.py @@ -5,13 +5,7 @@ import time from typing import Any from .algorithms import get_default_algorithms, has_crypto, requires_cryptography -from .exceptions import ( - InvalidKeyError, - MissingCryptographyError, - PyJWKError, - PyJWKSetError, - PyJWTError, -) +from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError, PyJWTError from .types import JWKDict @@ -56,25 +50,21 @@ class PyJWK: raise InvalidKeyError(f"Unsupported kty: {kty}") if not has_crypto and algorithm in requires_cryptography: - raise MissingCryptographyError( - f"{algorithm} requires 'cryptography' to be installed." - ) + raise PyJWKError(f"{algorithm} requires 'cryptography' to be installed.") - self.algorithm_name = algorithm + self.Algorithm = self._algorithms.get(algorithm) - if algorithm in self._algorithms: - self.Algorithm = self._algorithms[algorithm] - else: + if not self.Algorithm: raise PyJWKError(f"Unable to find an algorithm for key: {self._jwk_data}") self.key = self.Algorithm.from_jwk(self._jwk_data) @staticmethod - def from_dict(obj: JWKDict, algorithm: str | None = None) -> PyJWK: + def from_dict(obj: JWKDict, algorithm: str | None = None) -> "PyJWK": return PyJWK(obj, algorithm) @staticmethod - def from_json(data: str, algorithm: None = None) -> PyJWK: + def from_json(data: str, algorithm: None = None) -> "PyJWK": obj = json.loads(data) return PyJWK.from_dict(obj, algorithm) @@ -104,9 +94,7 @@ class PyJWKSet: for key in keys: try: self.keys.append(PyJWK(key)) - except PyJWTError as error: - if isinstance(error, MissingCryptographyError): - raise error + except PyJWTError: # skip unusable keys continue @@ -116,16 +104,16 @@ class PyJWKSet: ) @staticmethod - def from_dict(obj: dict[str, Any]) -> PyJWKSet: + def from_dict(obj: dict[str, Any]) -> "PyJWKSet": keys = obj.get("keys", []) return PyJWKSet(keys) @staticmethod - def from_json(data: str) -> PyJWKSet: + def from_json(data: str) -> "PyJWKSet": obj = json.loads(data) return PyJWKSet.from_dict(obj) - def __getitem__(self, kid: str) -> PyJWK: + def __getitem__(self, kid: str) -> "PyJWK": for key in self.keys: if key.key_id == kid: return key diff --git a/venv/lib/python3.12/site-packages/jwt/api_jws.py b/venv/lib/python3.12/site-packages/jwt/api_jws.py index 654ee0b..fa6708c 100644 --- a/venv/lib/python3.12/site-packages/jwt/api_jws.py +++ b/venv/lib/python3.12/site-packages/jwt/api_jws.py @@ -3,7 +3,6 @@ from __future__ import annotations import binascii import json import warnings -from collections.abc import Sequence from typing import TYPE_CHECKING, Any from .algorithms import ( @@ -12,7 +11,6 @@ from .algorithms import ( has_crypto, requires_cryptography, ) -from .api_jwk import PyJWK from .exceptions import ( DecodeError, InvalidAlgorithmError, @@ -31,7 
+29,7 @@ class PyJWS: def __init__( self, - algorithms: Sequence[str] | None = None, + algorithms: list[str] | None = None, options: dict[str, Any] | None = None, ) -> None: self._algorithms = get_default_algorithms() @@ -105,8 +103,8 @@ class PyJWS: def encode( self, payload: bytes, - key: AllowedPrivateKeys | PyJWK | str | bytes, - algorithm: str | None = None, + key: AllowedPrivateKeys | str | bytes, + algorithm: str | None = "HS256", headers: dict[str, Any] | None = None, json_encoder: type[json.JSONEncoder] | None = None, is_payload_detached: bool = False, @@ -115,13 +113,7 @@ class PyJWS: segments = [] # declare a new var to narrow the type for type checkers - if algorithm is None: - if isinstance(key, PyJWK): - algorithm_ = key.algorithm_name - else: - algorithm_ = "HS256" - else: - algorithm_ = algorithm + algorithm_: str = algorithm if algorithm is not None else "none" # Prefer headers values if present to function parameters. if headers: @@ -165,8 +157,6 @@ class PyJWS: signing_input = b".".join(segments) alg_obj = self.get_algorithm_by_name(algorithm_) - if isinstance(key, PyJWK): - key = key.key key = alg_obj.prepare_key(key) signature = alg_obj.sign(signing_input, key) @@ -182,8 +172,8 @@ class PyJWS: def decode_complete( self, jwt: str | bytes, - key: AllowedPublicKeys | PyJWK | str | bytes = "", - algorithms: Sequence[str] | None = None, + key: AllowedPublicKeys | str | bytes = "", + algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, **kwargs, @@ -194,14 +184,13 @@ class PyJWS: "and will be removed in pyjwt version 3. " f"Unsupported kwargs: {tuple(kwargs.keys())}", RemovedInPyjwt3Warning, - stacklevel=2, ) if options is None: options = {} merged_options = {**self.options, **options} verify_signature = merged_options["verify_signature"] - if verify_signature and not algorithms and not isinstance(key, PyJWK): + if verify_signature and not algorithms: raise DecodeError( 'It is required that you pass in a value for the "algorithms" argument when calling decode().' ) @@ -228,8 +217,8 @@ class PyJWS: def decode( self, jwt: str | bytes, - key: AllowedPublicKeys | PyJWK | str | bytes = "", - algorithms: Sequence[str] | None = None, + key: AllowedPublicKeys | str | bytes = "", + algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, **kwargs, @@ -240,7 +229,6 @@ class PyJWS: "and will be removed in pyjwt version 3. 
" f"Unsupported kwargs: {tuple(kwargs.keys())}", RemovedInPyjwt3Warning, - stacklevel=2, ) decoded = self.decode_complete( jwt, key, algorithms, options, detached_payload=detached_payload @@ -301,28 +289,22 @@ class PyJWS: signing_input: bytes, header: dict[str, Any], signature: bytes, - key: AllowedPublicKeys | PyJWK | str | bytes = "", - algorithms: Sequence[str] | None = None, + key: AllowedPublicKeys | str | bytes = "", + algorithms: list[str] | None = None, ) -> None: - if algorithms is None and isinstance(key, PyJWK): - algorithms = [key.algorithm_name] try: alg = header["alg"] except KeyError: - raise InvalidAlgorithmError("Algorithm not specified") from None + raise InvalidAlgorithmError("Algorithm not specified") if not alg or (algorithms is not None and alg not in algorithms): raise InvalidAlgorithmError("The specified alg value is not allowed") - if isinstance(key, PyJWK): - alg_obj = key.Algorithm - prepared_key = key.key - else: - try: - alg_obj = self.get_algorithm_by_name(alg) - except NotImplementedError as e: - raise InvalidAlgorithmError("Algorithm not supported") from e - prepared_key = alg_obj.prepare_key(key) + try: + alg_obj = self.get_algorithm_by_name(alg) + except NotImplementedError as e: + raise InvalidAlgorithmError("Algorithm not supported") from e + prepared_key = alg_obj.prepare_key(key) if not alg_obj.verify(signing_input, prepared_key, signature): raise InvalidSignatureError("Signature verification failed") diff --git a/venv/lib/python3.12/site-packages/jwt/api_jwt.py b/venv/lib/python3.12/site-packages/jwt/api_jwt.py index 3a20143..48d739a 100644 --- a/venv/lib/python3.12/site-packages/jwt/api_jwt.py +++ b/venv/lib/python3.12/site-packages/jwt/api_jwt.py @@ -3,7 +3,7 @@ from __future__ import annotations import json import warnings from calendar import timegm -from collections.abc import Iterable, Sequence +from collections.abc import Iterable from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any @@ -15,15 +15,12 @@ from .exceptions import ( InvalidAudienceError, InvalidIssuedAtError, InvalidIssuerError, - InvalidJTIError, - InvalidSubjectError, MissingRequiredClaimError, ) from .warnings import RemovedInPyjwt3Warning if TYPE_CHECKING: from .algorithms import AllowedPrivateKeys, AllowedPublicKeys - from .api_jwk import PyJWK class PyJWT: @@ -41,16 +38,14 @@ class PyJWT: "verify_iat": True, "verify_aud": True, "verify_iss": True, - "verify_sub": True, - "verify_jti": True, "require": [], } def encode( self, payload: dict[str, Any], - key: AllowedPrivateKeys | PyJWK | str | bytes, - algorithm: str | None = None, + key: AllowedPrivateKeys | str | bytes, + algorithm: str | None = "HS256", headers: dict[str, Any] | None = None, json_encoder: type[json.JSONEncoder] | None = None, sort_headers: bool = True, @@ -105,8 +100,8 @@ class PyJWT: def decode_complete( self, jwt: str | bytes, - key: AllowedPublicKeys | PyJWK | str | bytes = "", - algorithms: Sequence[str] | None = None, + key: AllowedPublicKeys | str | bytes = "", + algorithms: list[str] | None = None, options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 verify: bool | None = None, @@ -115,8 +110,7 @@ class PyJWT: # passthrough arguments to _validate_claims # consider putting in options audience: str | Iterable[str] | None = None, - issuer: str | Sequence[str] | None = None, - subject: str | None = None, + issuer: str | None = None, leeway: float | timedelta = 0, # kwargs **kwargs: Any, @@ -127,7 +121,6 @@ class PyJWT: "and will be removed in pyjwt 
version 3. " f"Unsupported kwargs: {tuple(kwargs.keys())}", RemovedInPyjwt3Warning, - stacklevel=2, ) options = dict(options or {}) # shallow-copy or initialize an empty dict options.setdefault("verify_signature", True) @@ -141,7 +134,6 @@ class PyJWT: "The equivalent is setting `verify_signature` to False in the `options` dictionary. " "This invocation has a mismatch between the kwarg and the option entry.", category=DeprecationWarning, - stacklevel=2, ) if not options["verify_signature"]: @@ -150,8 +142,11 @@ class PyJWT: options.setdefault("verify_iat", False) options.setdefault("verify_aud", False) options.setdefault("verify_iss", False) - options.setdefault("verify_sub", False) - options.setdefault("verify_jti", False) + + if options["verify_signature"] and not algorithms: + raise DecodeError( + 'It is required that you pass in a value for the "algorithms" argument when calling decode().' + ) decoded = api_jws.decode_complete( jwt, @@ -165,12 +160,7 @@ class PyJWT: merged_options = {**self.options, **options} self._validate_claims( - payload, - merged_options, - audience=audience, - issuer=issuer, - leeway=leeway, - subject=subject, + payload, merged_options, audience=audience, issuer=issuer, leeway=leeway ) decoded["payload"] = payload @@ -187,7 +177,7 @@ class PyJWT: try: payload = json.loads(decoded["payload"]) except ValueError as e: - raise DecodeError(f"Invalid payload string: {e}") from e + raise DecodeError(f"Invalid payload string: {e}") if not isinstance(payload, dict): raise DecodeError("Invalid payload string: must be a json object") return payload @@ -195,8 +185,8 @@ class PyJWT: def decode( self, jwt: str | bytes, - key: AllowedPublicKeys | PyJWK | str | bytes = "", - algorithms: Sequence[str] | None = None, + key: AllowedPublicKeys | str | bytes = "", + algorithms: list[str] | None = None, options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 verify: bool | None = None, @@ -205,8 +195,7 @@ class PyJWT: # passthrough arguments to _validate_claims # consider putting in options audience: str | Iterable[str] | None = None, - subject: str | None = None, - issuer: str | Sequence[str] | None = None, + issuer: str | None = None, leeway: float | timedelta = 0, # kwargs **kwargs: Any, @@ -217,7 +206,6 @@ class PyJWT: "and will be removed in pyjwt version 3. " f"Unsupported kwargs: {tuple(kwargs.keys())}", RemovedInPyjwt3Warning, - stacklevel=2, ) decoded = self.decode_complete( jwt, @@ -227,7 +215,6 @@ class PyJWT: verify=verify, detached_payload=detached_payload, audience=audience, - subject=subject, issuer=issuer, leeway=leeway, ) @@ -239,7 +226,6 @@ class PyJWT: options: dict[str, Any], audience=None, issuer=None, - subject: str | None = None, leeway: float | timedelta = 0, ) -> None: if isinstance(leeway, timedelta): @@ -269,12 +255,6 @@ class PyJWT: payload, audience, strict=options.get("strict_aud", False) ) - if options["verify_sub"]: - self._validate_sub(payload, subject) - - if options["verify_jti"]: - self._validate_jti(payload) - def _validate_required_claims( self, payload: dict[str, Any], @@ -284,39 +264,6 @@ class PyJWT: if payload.get(claim) is None: raise MissingRequiredClaimError(claim) - def _validate_sub(self, payload: dict[str, Any], subject=None) -> None: - """ - Checks whether "sub" if in the payload is valid ot not. 
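Since the PyJWT build pinned by this diff drops the `verify_sub` / `verify_jti` options and the `subject=` argument (the validation block being removed here), any subject check moves into application code. A minimal sketch of that pattern; the secret and expected subject are placeholders:

    import jwt  # PyJWT, as vendored above

    def decode_checking_subject(token: str, expected_sub: str) -> dict:
        # algorithms= stays mandatory on this API; the key is a raw secret value
        claims = jwt.decode(token, "example-secret", algorithms=["HS256"])
        if claims.get("sub") != expected_sub:
            raise jwt.InvalidTokenError("unexpected subject claim")
        return claims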
- This is an Optional claim - - :param payload(dict): The payload which needs to be validated - :param subject(str): The subject of the token - """ - - if "sub" not in payload: - return - - if not isinstance(payload["sub"], str): - raise InvalidSubjectError("Subject must be a string") - - if subject is not None: - if payload.get("sub") != subject: - raise InvalidSubjectError("Invalid subject") - - def _validate_jti(self, payload: dict[str, Any]) -> None: - """ - Checks whether "jti" if in the payload is valid ot not - This is an Optional claim - - :param payload(dict): The payload which needs to be validated - """ - - if "jti" not in payload: - return - - if not isinstance(payload.get("jti"), str): - raise InvalidJTIError("JWT ID must be a string") - def _validate_iat( self, payload: dict[str, Any], @@ -326,9 +273,7 @@ class PyJWT: try: iat = int(payload["iat"]) except ValueError: - raise InvalidIssuedAtError( - "Issued At claim (iat) must be an integer." - ) from None + raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.") if iat > (now + leeway): raise ImmatureSignatureError("The token is not yet valid (iat)") @@ -341,7 +286,7 @@ class PyJWT: try: nbf = int(payload["nbf"]) except ValueError: - raise DecodeError("Not Before claim (nbf) must be an integer.") from None + raise DecodeError("Not Before claim (nbf) must be an integer.") if nbf > (now + leeway): raise ImmatureSignatureError("The token is not yet valid (nbf)") @@ -355,9 +300,7 @@ class PyJWT: try: exp = int(payload["exp"]) except ValueError: - raise DecodeError( - "Expiration Time claim (exp) must be an integer." - ) from None + raise DecodeError("Expiration Time claim (exp) must be an" " integer.") if exp <= (now - leeway): raise ExpiredSignatureError("Signature has expired") @@ -419,12 +362,8 @@ class PyJWT: if "iss" not in payload: raise MissingRequiredClaimError("iss") - if isinstance(issuer, str): - if payload["iss"] != issuer: - raise InvalidIssuerError("Invalid issuer") - else: - if payload["iss"] not in issuer: - raise InvalidIssuerError("Invalid issuer") + if payload["iss"] != issuer: + raise InvalidIssuerError("Invalid issuer") _jwt_global_obj = PyJWT() diff --git a/venv/lib/python3.12/site-packages/jwt/exceptions.py b/venv/lib/python3.12/site-packages/jwt/exceptions.py index 9b45ae4..8ac6ecf 100644 --- a/venv/lib/python3.12/site-packages/jwt/exceptions.py +++ b/venv/lib/python3.12/site-packages/jwt/exceptions.py @@ -58,10 +58,6 @@ class PyJWKError(PyJWTError): pass -class MissingCryptographyError(PyJWKError): - pass - - class PyJWKSetError(PyJWTError): pass @@ -72,11 +68,3 @@ class PyJWKClientError(PyJWTError): class PyJWKClientConnectionError(PyJWKClientError): pass - - -class InvalidSubjectError(InvalidTokenError): - pass - - -class InvalidJTIError(InvalidTokenError): - pass diff --git a/venv/lib/python3.12/site-packages/jwt/help.py b/venv/lib/python3.12/site-packages/jwt/help.py index 8e1c228..80b0ca5 100644 --- a/venv/lib/python3.12/site-packages/jwt/help.py +++ b/venv/lib/python3.12/site-packages/jwt/help.py @@ -39,10 +39,7 @@ def info() -> Dict[str, Dict[str, str]]: ) if pypy_version_info.releaselevel != "final": implementation_version = "".join( - [ - implementation_version, - pypy_version_info.releaselevel, - ] + [implementation_version, pypy_version_info.releaselevel] ) else: implementation_version = "Unknown" diff --git a/venv/lib/python3.12/site-packages/jwt/jwks_client.py b/venv/lib/python3.12/site-packages/jwt/jwks_client.py index 9a8992c..f19b10a 100644 --- 
a/venv/lib/python3.12/site-packages/jwt/jwks_client.py +++ b/venv/lib/python3.12/site-packages/jwt/jwks_client.py @@ -45,9 +45,7 @@ class PyJWKClient: if cache_keys: # Cache signing keys # Ignore mypy (https://github.com/python/mypy/issues/2427) - self.get_signing_key = lru_cache(maxsize=max_cached_keys)( - self.get_signing_key - ) # type: ignore + self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore def fetch_data(self) -> Any: jwk_set: Any = None @@ -60,7 +58,7 @@ class PyJWKClient: except (URLError, TimeoutError) as e: raise PyJWKClientConnectionError( f'Fail to fetch data from the url, err: "{e}"' - ) from e + ) else: return jwk_set finally: diff --git a/venv/lib/python3.12/site-packages/jwt/utils.py b/venv/lib/python3.12/site-packages/jwt/utils.py index 56e89bb..81c5ee4 100644 --- a/venv/lib/python3.12/site-packages/jwt/utils.py +++ b/venv/lib/python3.12/site-packages/jwt/utils.py @@ -1,7 +1,7 @@ import base64 import binascii import re -from typing import Optional, Union +from typing import Union try: from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve @@ -37,11 +37,11 @@ def base64url_encode(input: bytes) -> bytes: return base64.urlsafe_b64encode(input).replace(b"=", b"") -def to_base64url_uint(val: int, *, bit_length: Optional[int] = None) -> bytes: +def to_base64url_uint(val: int) -> bytes: if val < 0: raise ValueError("Must be a positive integer") - int_bytes = bytes_from_int(val, bit_length=bit_length) + int_bytes = bytes_from_int(val) if len(int_bytes) == 0: int_bytes = b"\x00" @@ -63,10 +63,13 @@ def bytes_to_number(string: bytes) -> int: return int(binascii.b2a_hex(string), 16) -def bytes_from_int(val: int, *, bit_length: Optional[int] = None) -> bytes: - if bit_length is None: - bit_length = val.bit_length() - byte_length = (bit_length + 7) // 8 +def bytes_from_int(val: int) -> bytes: + remaining = val + byte_length = 0 + + while remaining != 0: + remaining >>= 8 + byte_length += 1 return val.to_bytes(byte_length, "big", signed=False) @@ -128,15 +131,26 @@ def is_pem_format(key: bytes) -> bool: # Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46 -_SSH_KEY_FORMATS = ( +_CERT_SUFFIX = b"-cert-v01@openssh.com" +_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)") +_SSH_KEY_FORMATS = [ b"ssh-ed25519", b"ssh-rsa", b"ssh-dss", b"ecdsa-sha2-nistp256", b"ecdsa-sha2-nistp384", b"ecdsa-sha2-nistp521", -) +] def is_ssh_key(key: bytes) -> bool: - return key.startswith(_SSH_KEY_FORMATS) + if any(string_value in key for string_value in _SSH_KEY_FORMATS): + return True + + ssh_pubkey_match = _SSH_PUBKEY_RC.match(key) + if ssh_pubkey_match: + key_type = ssh_pubkey_match.group(1) + if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]: + return True + + return False diff --git a/venv/lib/python3.12/site-packages/kafka/__init__.py b/venv/lib/python3.12/site-packages/kafka/__init__.py index 41a0140..d5e30af 100644 --- a/venv/lib/python3.12/site-packages/kafka/__init__.py +++ b/venv/lib/python3.12/site-packages/kafka/__init__.py @@ -4,7 +4,7 @@ __title__ = 'kafka' from kafka.version import __version__ __author__ = 'Dana Powers' __license__ = 'Apache License 2.0' -__copyright__ = 'Copyright 2025 Dana Powers, David Arthur, and Contributors' +__copyright__ = 'Copyright 2016 Dana Powers, David Arthur, and Contributors' # Set default logging handler to avoid "No handler found" warnings. 
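The JWKS flow still works on the pinned PyJWT version, but the resolved key's raw `.key` must be handed to `decode()`, because `PyJWK` objects are no longer accepted directly as keys (see the `api_jws.py` / `api_jwt.py` hunks above). A sketch, with a placeholder JWKS URL and audience:

    import jwt
    from jwt import PyJWKClient

    def verify_with_jwks(token: str) -> dict:
        jwks_client = PyJWKClient("https://auth.example.com/.well-known/jwks.json")
        signing_key = jwks_client.get_signing_key_from_jwt(token)
        return jwt.decode(
            token,
            signing_key.key,           # raw key object, not the PyJWK wrapper
            algorithms=["RS256"],
            audience="example-api",
        )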
import logging diff --git a/venv/lib/python3.12/site-packages/kafka/admin/client.py b/venv/lib/python3.12/site-packages/kafka/admin/client.py index 8490fdb..c58da0c 100644 --- a/venv/lib/python3.12/site-packages/kafka/admin/client.py +++ b/venv/lib/python3.12/site-packages/kafka/admin/client.py @@ -1,10 +1,9 @@ -from __future__ import absolute_import, division +from __future__ import absolute_import from collections import defaultdict import copy import logging import socket -import time from . import ConfigResourceType from kafka.vendor import six @@ -15,16 +14,15 @@ from kafka.client_async import KafkaClient, selectors from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment, ConsumerProtocol import kafka.errors as Errors from kafka.errors import ( - IncompatibleBrokerVersion, KafkaConfigurationError, UnknownTopicOrPartitionError, + IncompatibleBrokerVersion, KafkaConfigurationError, NotControllerError, UnrecognizedBrokerVersion, IllegalArgumentError) -from kafka.future import Future from kafka.metrics import MetricConfig, Metrics from kafka.protocol.admin import ( CreateTopicsRequest, DeleteTopicsRequest, DescribeConfigsRequest, AlterConfigsRequest, CreatePartitionsRequest, ListGroupsRequest, DescribeGroupsRequest, DescribeAclsRequest, CreateAclsRequest, DeleteAclsRequest, - DeleteGroupsRequest, DeleteRecordsRequest, DescribeLogDirsRequest, ElectLeadersRequest, ElectionType) -from kafka.protocol.commit import OffsetFetchRequest -from kafka.protocol.find_coordinator import FindCoordinatorRequest + DeleteGroupsRequest +) +from kafka.protocol.commit import GroupCoordinatorRequest, OffsetFetchRequest from kafka.protocol.metadata import MetadataRequest from kafka.protocol.types import Array from kafka.structs import TopicPartition, OffsetAndMetadata, MemberInformation, GroupInformation @@ -74,7 +72,7 @@ class KafkaAdminClient(object): reconnection attempts will continue periodically with this fixed rate. To avoid connection storms, a randomization factor of 0.2 will be applied to the backoff resulting in a random range between - 20% below and 20% above the computed value. Default: 30000. + 20% below and 20% above the computed value. Default: 1000. request_timeout_ms (int): Client request timeout in milliseconds. Default: 30000. connections_max_idle_ms: Close idle connections after the number of @@ -142,17 +140,13 @@ class KafkaAdminClient(object): Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. sasl_plain_password (str): password for sasl PLAIN and SCRAM authentication. Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. - sasl_kerberos_name (str or gssapi.Name): Constructed gssapi.Name for use with - sasl mechanism handshake. If provided, sasl_kerberos_service_name and - sasl_kerberos_domain name are ignored. Default: None. sasl_kerberos_service_name (str): Service name to include in GSSAPI sasl mechanism handshake. Default: 'kafka' sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI sasl mechanism handshake. Default: one of bootstrap servers - sasl_oauth_token_provider (kafka.sasl.oauth.AbstractTokenProvider): OAuthBearer - token provider instance. Default: None - socks5_proxy (str): Socks5 proxy url. Default: None - kafka_client (callable): Custom class / callable for creating KafkaClient instances + sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider + instance. (See kafka.oauth.abstract). 
Default: None + """ DEFAULT_CONFIG = { # client configs @@ -161,7 +155,7 @@ class KafkaAdminClient(object): 'request_timeout_ms': 30000, 'connections_max_idle_ms': 9 * 60 * 1000, 'reconnect_backoff_ms': 50, - 'reconnect_backoff_max_ms': 30000, + 'reconnect_backoff_max_ms': 1000, 'max_in_flight_requests_per_connection': 5, 'receive_buffer_bytes': None, 'send_buffer_bytes': None, @@ -184,17 +178,14 @@ class KafkaAdminClient(object): 'sasl_mechanism': None, 'sasl_plain_username': None, 'sasl_plain_password': None, - 'sasl_kerberos_name': None, 'sasl_kerberos_service_name': 'kafka', 'sasl_kerberos_domain_name': None, 'sasl_oauth_token_provider': None, - 'socks5_proxy': None, # metrics configs 'metric_reporters': [], 'metrics_num_samples': 2, 'metrics_sample_window_ms': 30000, - 'kafka_client': KafkaClient, } def __init__(self, **configs): @@ -214,14 +205,14 @@ class KafkaAdminClient(object): reporters = [reporter() for reporter in self.config['metric_reporters']] self._metrics = Metrics(metric_config, reporters) - self._client = self.config['kafka_client']( - metrics=self._metrics, - metric_group_prefix='admin', - **self.config - ) + self._client = KafkaClient(metrics=self._metrics, + metric_group_prefix='admin', + **self.config) + self._client.check_version(timeout=(self.config['api_version_auto_timeout_ms'] / 1000)) # Get auto-discovered version from client if necessary - self.config['api_version'] = self._client.config['api_version'] + if self.config['api_version'] is None: + self.config['api_version'] = self._client.config['api_version'] self._closed = False self._refresh_controller_id() @@ -238,44 +229,58 @@ class KafkaAdminClient(object): self._closed = True log.debug("KafkaAdminClient is now closed.") + def _matching_api_version(self, operation): + """Find the latest version of the protocol operation supported by both + this library and the broker. + + This resolves to the lesser of either the latest api version this + library supports, or the max version supported by the broker. + + :param operation: A list of protocol operation versions from kafka.protocol. + :return: The max matching version number between client and broker. + """ + broker_api_versions = self._client.get_api_versions() + api_key = operation[0].API_KEY + if broker_api_versions is None or api_key not in broker_api_versions: + raise IncompatibleBrokerVersion( + "Kafka broker does not support the '{}' Kafka protocol." + .format(operation[0].__name__)) + min_version, max_version = broker_api_versions[api_key] + version = min(len(operation) - 1, max_version) + if version < min_version: + # max library version is less than min broker version. Currently, + # no Kafka versions specify a min msg version. Maybe in the future? + raise IncompatibleBrokerVersion( + "No version of the '{}' Kafka protocol is supported by both the client and broker." + .format(operation[0].__name__)) + return version + def _validate_timeout(self, timeout_ms): """Validate the timeout is set or use the configuration default. - Arguments: - timeout_ms: The timeout provided by api call, in milliseconds. - - Returns: - The timeout to use for the operation. + :param timeout_ms: The timeout provided by api call, in milliseconds. + :return: The timeout to use for the operation. 
""" return timeout_ms or self.config['request_timeout_ms'] - def _refresh_controller_id(self, timeout_ms=30000): + def _refresh_controller_id(self): """Determine the Kafka cluster controller.""" - version = self._client.api_version(MetadataRequest, max_version=6) + version = self._matching_api_version(MetadataRequest) if 1 <= version <= 6: - timeout_at = time.time() + timeout_ms / 1000 - while time.time() < timeout_at: - request = MetadataRequest[version]() - future = self._send_request_to_node(self._client.least_loaded_node(), request) + request = MetadataRequest[version]() + future = self._send_request_to_node(self._client.least_loaded_node(), request) - self._wait_for_futures([future]) + self._wait_for_futures([future]) - response = future.value - controller_id = response.controller_id - if controller_id == -1: - log.warning("Controller ID not available, got -1") - time.sleep(1) - continue - # verify the controller is new enough to support our requests - controller_version = self._client.check_version(node_id=controller_id) - if controller_version < (0, 10, 0): - raise IncompatibleBrokerVersion( - "The controller appears to be running Kafka {}. KafkaAdminClient requires brokers >= 0.10.0.0." - .format(controller_version)) - self._controller_id = controller_id - return - else: - raise Errors.NodeNotReadyError('controller') + response = future.value + controller_id = response.controller_id + # verify the controller is new enough to support our requests + controller_version = self._client.check_version(controller_id, timeout=(self.config['api_version_auto_timeout_ms'] / 1000)) + if controller_version < (0, 10, 0): + raise IncompatibleBrokerVersion( + "The controller appears to be running Kafka {}. KafkaAdminClient requires brokers >= 0.10.0.0." + .format(controller_version)) + self._controller_id = controller_id else: raise UnrecognizedBrokerVersion( "Kafka Admin interface cannot determine the controller using MetadataRequest_v{}." @@ -284,40 +289,43 @@ class KafkaAdminClient(object): def _find_coordinator_id_send_request(self, group_id): """Send a FindCoordinatorRequest to a broker. - Arguments: - group_id: The consumer group ID. This is typically the group + :param group_id: The consumer group ID. This is typically the group name as a string. - - Returns: - A message future + :return: A message future """ - version = self._client.api_version(FindCoordinatorRequest, max_version=2) + # TODO add support for dynamically picking version of + # GroupCoordinatorRequest which was renamed to FindCoordinatorRequest. + # When I experimented with this, the coordinator value returned in + # GroupCoordinatorResponse_v1 didn't match the value returned by + # GroupCoordinatorResponse_v0 and I couldn't figure out why. + version = 0 + # version = self._matching_api_version(GroupCoordinatorRequest) if version <= 0: - request = FindCoordinatorRequest[version](group_id) - elif version <= 2: - request = FindCoordinatorRequest[version](group_id, 0) + request = GroupCoordinatorRequest[version](group_id) else: raise NotImplementedError( - "Support for FindCoordinatorRequest_v{} has not yet been added to KafkaAdminClient." + "Support for GroupCoordinatorRequest_v{} has not yet been added to KafkaAdminClient." .format(version)) return self._send_request_to_node(self._client.least_loaded_node(), request) def _find_coordinator_id_process_response(self, response): """Process a FindCoordinatorResponse. - Arguments: - response: a FindCoordinatorResponse. - - Returns: - The node_id of the broker that is the coordinator. 
+ :param response: a FindCoordinatorResponse. + :return: The node_id of the broker that is the coordinator. """ - error_type = Errors.for_code(response.error_code) - if error_type is not Errors.NoError: - # Note: When error_type.retriable, Java will retry... see - # KafkaAdminClient's handleFindCoordinatorError method - raise error_type( - "FindCoordinatorRequest failed with response '{}'." - .format(response)) + if response.API_VERSION <= 0: + error_type = Errors.for_code(response.error_code) + if error_type is not Errors.NoError: + # Note: When error_type.retriable, Java will retry... see + # KafkaAdminClient's handleFindCoordinatorError method + raise error_type( + "FindCoordinatorRequest failed with response '{}'." + .format(response)) + else: + raise NotImplementedError( + "Support for FindCoordinatorRequest_v{} has not yet been added to KafkaAdminClient." + .format(response.API_VERSION)) return response.coordinator_id def _find_coordinator_ids(self, group_ids): @@ -327,12 +335,9 @@ class KafkaAdminClient(object): Will block until the FindCoordinatorResponse is received for all groups. Any errors are immediately raised. - Arguments: - group_ids: A list of consumer group IDs. This is typically the group + :param group_ids: A list of consumer group IDs. This is typically the group name as a string. - - Returns: - A dict of {group_id: node_id} where node_id is the id of the + :return: A dict of {group_id: node_id} where node_id is the id of the broker that is the coordinator for the corresponding group. """ groups_futures = { @@ -346,36 +351,29 @@ class KafkaAdminClient(object): } return groups_coordinators - def _send_request_to_node(self, node_id, request, wakeup=True): + def _send_request_to_node(self, node_id, request): """Send a Kafka protocol message to a specific broker. - Arguments: - node_id: The broker id to which to send the message. - request: The message to send. + Returns a future that may be polled for status and results. - - Keyword Arguments: - wakeup (bool, optional): Optional flag to disable thread-wakeup. - - Returns: - A future object that may be polled for status and results. + :param node_id: The broker id to which to send the message. + :param request: The message to send. + :return: A future object that may be polled for status and results. + :exception: The exception if the message could not be sent. """ - try: - self._client.await_ready(node_id) - except Errors.KafkaConnectionError as e: - return Future().failure(e) - return self._client.send(node_id, request, wakeup) + while not self._client.ready(node_id): + # poll until the connection to broker is ready, otherwise send() + # will fail with NodeNotReadyError + self._client.poll() + return self._client.send(node_id, request) def _send_request_to_controller(self, request): """Send a Kafka protocol message to the cluster controller. Will block until the message result is received. - Arguments: - request: The message to send. - - Returns: - The Kafka protocol response for the message. + :param request: The message to send. + :return: The Kafka protocol response for the message. """ tries = 2 # in case our cached self._controller_id is outdated while tries: @@ -391,70 +389,30 @@ class KafkaAdminClient(object): # So this is a little brittle in that it assumes all responses have # one of these attributes and that they always unpack into # (topic, error_code) tuples. 
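The controller round-trip above backs the public topic-management calls. A minimal usage sketch against the older kafka-python API this diff reverts to; the broker address and topic name are placeholders:

    from kafka.admin import KafkaAdminClient, NewTopic

    admin = KafkaAdminClient(bootstrap_servers="localhost:9092")
    # create_topics() sends a CreateTopicsRequest to the controller resolved above
    admin.create_topics([NewTopic(name="nutrition-events", num_partitions=3, replication_factor=1)])
    admin.close()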
- topic_error_tuples = getattr(response, 'topic_errors', getattr(response, 'topic_error_codes', None)) - if topic_error_tuples is not None: - success = self._parse_topic_request_response(topic_error_tuples, request, response, tries) - else: - # Leader Election request has a two layer error response (topic and partition) - success = self._parse_topic_partition_request_response(request, response, tries) - - if success: - return response - raise RuntimeError("This should never happen, please file a bug with full stacktrace if encountered") - - def _parse_topic_request_response(self, topic_error_tuples, request, response, tries): - # Also small py2/py3 compatibility -- py3 can ignore extra values - # during unpack via: for x, y, *rest in list_of_values. py2 cannot. - # So for now we have to map across the list and explicitly drop any - # extra values (usually the error_message) - for topic, error_code in map(lambda e: e[:2], topic_error_tuples): - error_type = Errors.for_code(error_code) - if tries and error_type is Errors.NotControllerError: - # No need to inspect the rest of the errors for - # non-retriable errors because NotControllerError should - # either be thrown for all errors or no errors. - self._refresh_controller_id() - return False - elif error_type is not Errors.NoError: - raise error_type( - "Request '{}' failed with response '{}'." - .format(request, response)) - return True - - def _parse_topic_partition_request_response(self, request, response, tries): - # Also small py2/py3 compatibility -- py3 can ignore extra values - # during unpack via: for x, y, *rest in list_of_values. py2 cannot. - # So for now we have to map across the list and explicitly drop any - # extra values (usually the error_message) - for topic, partition_results in response.replication_election_results: - for partition_id, error_code in map(lambda e: e[:2], partition_results): + topic_error_tuples = (response.topic_errors if hasattr(response, 'topic_errors') + else response.topic_error_codes) + # Also small py2/py3 compatibility -- py3 can ignore extra values + # during unpack via: for x, y, *rest in list_of_values. py2 cannot. + # So for now we have to map across the list and explicitly drop any + # extra values (usually the error_message) + for topic, error_code in map(lambda e: e[:2], topic_error_tuples): error_type = Errors.for_code(error_code) - if tries and error_type is Errors.NotControllerError: + if tries and error_type is NotControllerError: # No need to inspect the rest of the errors for # non-retriable errors because NotControllerError should # either be thrown for all errors or no errors. self._refresh_controller_id() - return False - elif error_type not in (Errors.NoError, Errors.ElectionNotNeededError): + break + elif error_type is not Errors.NoError: raise error_type( "Request '{}' failed with response '{}'." .format(request, response)) - return True + else: + return response + raise RuntimeError("This should never happen, please file a bug with full stacktrace if encountered") @staticmethod def _convert_new_topic_request(new_topic): - """ - Build the tuple required by CreateTopicsRequest from a NewTopic object. - - Arguments: - new_topic: A NewTopic instance containing name, partition count, replication factor, - replica assignments, and config entries. 
- - Returns: - A tuple in the form: - (topic_name, num_partitions, replication_factor, [(partition_id, [replicas])...], - [(config_key, config_value)...]) - """ return ( new_topic.name, new_topic.num_partitions, @@ -470,19 +428,14 @@ class KafkaAdminClient(object): def create_topics(self, new_topics, timeout_ms=None, validate_only=False): """Create new topics in the cluster. - Arguments: - new_topics: A list of NewTopic objects. - - Keyword Arguments: - timeout_ms (numeric, optional): Milliseconds to wait for new topics to be created - before the broker returns. - validate_only (bool, optional): If True, don't actually create new topics. - Not supported by all versions. Default: False - - Returns: - Appropriate version of CreateTopicResponse class. + :param new_topics: A list of NewTopic objects. + :param timeout_ms: Milliseconds to wait for new topics to be created + before the broker returns. + :param validate_only: If True, don't actually create new topics. + Not supported by all versions. Default: False + :return: Appropriate version of CreateTopicResponse class. """ - version = self._client.api_version(CreateTopicsRequest, max_version=3) + version = self._matching_api_version(CreateTopicsRequest) timeout_ms = self._validate_timeout(timeout_ms) if version == 0: if validate_only: @@ -510,17 +463,12 @@ class KafkaAdminClient(object): def delete_topics(self, topics, timeout_ms=None): """Delete topics from the cluster. - Arguments: - topics ([str]): A list of topic name strings. - - Keyword Arguments: - timeout_ms (numeric, optional): Milliseconds to wait for topics to be deleted - before the broker returns. - - Returns: - Appropriate version of DeleteTopicsResponse class. + :param topics: A list of topic name strings. + :param timeout_ms: Milliseconds to wait for topics to be deleted + before the broker returns. + :return: Appropriate version of DeleteTopicsResponse class. """ - version = self._client.api_version(DeleteTopicsRequest, max_version=3) + version = self._matching_api_version(DeleteTopicsRequest) timeout_ms = self._validate_timeout(timeout_ms) if version <= 3: request = DeleteTopicsRequest[version]( @@ -539,7 +487,7 @@ class KafkaAdminClient(object): """ topics == None means "get all topics" """ - version = self._client.api_version(MetadataRequest, max_version=5) + version = self._matching_api_version(MetadataRequest) if version <= 3: if auto_topic_creation: raise IncompatibleBrokerVersion( @@ -562,38 +510,16 @@ class KafkaAdminClient(object): return future.value def list_topics(self): - """Retrieve a list of all topic names in the cluster. - - Returns: - A list of topic name strings. - """ metadata = self._get_cluster_metadata(topics=None) obj = metadata.to_object() return [t['topic'] for t in obj['topics']] def describe_topics(self, topics=None): - """Fetch metadata for the specified topics or all topics if None. - - Keyword Arguments: - topics ([str], optional) A list of topic names. If None, metadata for all - topics is retrieved. - - Returns: - A list of dicts describing each topic (including partition info). - """ metadata = self._get_cluster_metadata(topics=topics) obj = metadata.to_object() return obj['topics'] def describe_cluster(self): - """ - Fetch cluster-wide metadata such as the list of brokers, the controller ID, - and the cluster ID. - - - Returns: - A dict with cluster-wide metadata, excluding topic details. 
- """ metadata = self._get_cluster_metadata() obj = metadata.to_object() obj.pop('topics') # We have 'describe_topics' for this @@ -601,15 +527,6 @@ class KafkaAdminClient(object): @staticmethod def _convert_describe_acls_response_to_acls(describe_response): - """Convert a DescribeAclsResponse into a list of ACL objects and a KafkaError. - - Arguments: - describe_response: The response object from the DescribeAclsRequest. - - Returns: - A tuple of (list_of_acl_objects, error) where error is an instance - of KafkaError (NoError if successful). - """ version = describe_response.API_VERSION error = Errors.for_code(describe_response.error_code) @@ -649,14 +566,11 @@ class KafkaAdminClient(object): The cluster must be configured with an authorizer for this to work, or you will get a SecurityDisabledError - Arguments: - acl_filter: an ACLFilter object - - Returns: - tuple of a list of matching ACL objects and a KafkaError (NoError if successful) + :param acl_filter: an ACLFilter object + :return: tuple of a list of matching ACL objects and a KafkaError (NoError if successful) """ - version = self._client.api_version(DescribeAclsRequest, max_version=1) + version = self._matching_api_version(DescribeAclsRequest) if version == 0: request = DescribeAclsRequest[version]( resource_type=acl_filter.resource_pattern.resource_type, @@ -698,14 +612,6 @@ class KafkaAdminClient(object): @staticmethod def _convert_create_acls_resource_request_v0(acl): - """Convert an ACL object into the CreateAclsRequest v0 format. - - Arguments: - acl: An ACL object with resource pattern and permissions. - - Returns: - A tuple: (resource_type, resource_name, principal, host, operation, permission_type). - """ return ( acl.resource_pattern.resource_type, @@ -718,14 +624,7 @@ class KafkaAdminClient(object): @staticmethod def _convert_create_acls_resource_request_v1(acl): - """Convert an ACL object into the CreateAclsRequest v1 format. - Arguments: - acl: An ACL object with resource pattern and permissions. - - Returns: - A tuple: (resource_type, resource_name, pattern_type, principal, host, operation, permission_type). - """ return ( acl.resource_pattern.resource_type, acl.resource_pattern.resource_name, @@ -738,19 +637,6 @@ class KafkaAdminClient(object): @staticmethod def _convert_create_acls_response_to_acls(acls, create_response): - """Parse CreateAclsResponse and correlate success/failure with original ACL objects. - - Arguments: - acls: A list of ACL objects that were requested for creation. - create_response: The broker's CreateAclsResponse object. - - Returns: - A dict with: - { - 'succeeded': [list of ACL objects successfully created], - 'failed': [(acl_object, KafkaError), ...] - } - """ version = create_response.API_VERSION creations_error = [] @@ -779,18 +665,15 @@ class KafkaAdminClient(object): This endpoint only accepts a list of concrete ACL objects, no ACLFilters. Throws TopicAlreadyExistsError if topic is already present. 
- Arguments: - acls: a list of ACL objects - - Returns: - dict of successes and failures + :param acls: a list of ACL objects + :return: dict of successes and failures """ for acl in acls: if not isinstance(acl, ACL): raise IllegalArgumentError("acls must contain ACL objects") - version = self._client.api_version(CreateAclsRequest, max_version=1) + version = self._matching_api_version(CreateAclsRequest) if version == 0: request = CreateAclsRequest[version]( creations=[self._convert_create_acls_resource_request_v0(acl) for acl in acls] @@ -813,14 +696,6 @@ class KafkaAdminClient(object): @staticmethod def _convert_delete_acls_resource_request_v0(acl): - """Convert an ACLFilter object into the DeleteAclsRequest v0 format. - - Arguments: - acl: An ACLFilter object identifying the ACLs to be deleted. - - Returns: - A tuple: (resource_type, resource_name, principal, host, operation, permission_type). - """ return ( acl.resource_pattern.resource_type, acl.resource_pattern.resource_name, @@ -832,14 +707,6 @@ class KafkaAdminClient(object): @staticmethod def _convert_delete_acls_resource_request_v1(acl): - """Convert an ACLFilter object into the DeleteAclsRequest v1 format. - - Arguments: - acl: An ACLFilter object identifying the ACLs to be deleted. - - Returns: - A tuple: (resource_type, resource_name, pattern_type, principal, host, operation, permission_type). - """ return ( acl.resource_pattern.resource_type, acl.resource_pattern.resource_name, @@ -852,16 +719,6 @@ class KafkaAdminClient(object): @staticmethod def _convert_delete_acls_response_to_matching_acls(acl_filters, delete_response): - """Parse the DeleteAclsResponse and map the results back to each input ACLFilter. - - Arguments: - acl_filters: A list of ACLFilter objects that were provided in the request. - delete_response: The response from the DeleteAclsRequest. - - Returns: - A list of tuples of the form: - (acl_filter, [(matching_acl, KafkaError), ...], filter_level_error). - """ version = delete_response.API_VERSION filter_result_list = [] for i, filter_responses in enumerate(delete_response.filter_responses): @@ -900,11 +757,8 @@ class KafkaAdminClient(object): Deletes all ACLs matching the list of input ACLFilter - Arguments: - acl_filters: a list of ACLFilter - - Returns: - a list of 3-tuples corresponding to the list of input filters. + :param acl_filters: a list of ACLFilter + :return: a list of 3-tuples corresponding to the list of input filters. The tuples hold (the input ACLFilter, list of affected ACLs, KafkaError instance) """ @@ -912,7 +766,7 @@ class KafkaAdminClient(object): if not isinstance(acl, ACLFilter): raise IllegalArgumentError("acl_filters must contain ACLFilter type objects") - version = self._client.api_version(DeleteAclsRequest, max_version=1) + version = self._matching_api_version(DeleteAclsRequest) if version == 0: request = DeleteAclsRequest[version]( @@ -936,14 +790,6 @@ class KafkaAdminClient(object): @staticmethod def _convert_describe_config_resource_request(config_resource): - """Convert a ConfigResource into the format required by DescribeConfigsRequest. - - Arguments: - config_resource: A ConfigResource with resource_type, name, and optional config keys. - - Returns: - A tuple: (resource_type, resource_name, [list_of_config_keys] or None). - """ return ( config_resource.resource_type, config_resource.name, @@ -955,18 +801,13 @@ class KafkaAdminClient(object): def describe_configs(self, config_resources, include_synonyms=False): """Fetch configuration parameters for one or more Kafka resources. 
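A usage sketch for the `describe_configs` call documented above; the broker address and topic name are placeholders:

    from kafka.admin import KafkaAdminClient, ConfigResource, ConfigResourceType

    admin = KafkaAdminClient(bootstrap_servers="localhost:9092")
    # fetch every config value for one topic (configs left as None means "all keys")
    topic_configs = admin.describe_configs([ConfigResource(ConfigResourceType.TOPIC, "nutrition-events")])
    admin.close()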
- Arguments: - config_resources: An list of ConfigResource objects. - Any keys in ConfigResource.configs dict will be used to filter the - result. Setting the configs dict to None will get all values. An - empty dict will get zero values (as per Kafka protocol). - - Keyword Arguments: - include_synonyms (bool, optional): If True, return synonyms in response. Not - supported by all versions. Default: False. - - Returns: - Appropriate version of DescribeConfigsResponse class. + :param config_resources: An list of ConfigResource objects. + Any keys in ConfigResource.configs dict will be used to filter the + result. Setting the configs dict to None will get all values. An + empty dict will get zero values (as per Kafka protocol). + :param include_synonyms: If True, return synonyms in response. Not + supported by all versions. Default: False. + :return: Appropriate version of DescribeConfigsResponse class. """ # Break up requests by type - a broker config request must be sent to the specific broker. @@ -981,7 +822,7 @@ class KafkaAdminClient(object): topic_resources.append(self._convert_describe_config_resource_request(config_resource)) futures = [] - version = self._client.api_version(DescribeConfigsRequest, max_version=2) + version = self._matching_api_version(DescribeConfigsRequest) if version == 0: if include_synonyms: raise IncompatibleBrokerVersion( @@ -1035,14 +876,6 @@ class KafkaAdminClient(object): @staticmethod def _convert_alter_config_resource_request(config_resource): - """Convert a ConfigResource into the format required by AlterConfigsRequest. - - Arguments: - config_resource: A ConfigResource with resource_type, name, and config (key, value) pairs. - - Returns: - A tuple: (resource_type, resource_name, [(config_key, config_value), ...]). - """ return ( config_resource.resource_type, config_resource.name, @@ -1060,13 +893,10 @@ class KafkaAdminClient(object): least-loaded node. See the comment in the source code for details. We would happily accept a PR fixing this. - Arguments: - config_resources: A list of ConfigResource objects. - - Returns: - Appropriate version of AlterConfigsResponse class. + :param config_resources: A list of ConfigResource objects. + :return: Appropriate version of AlterConfigsResponse class. """ - version = self._client.api_version(AlterConfigsRequest, max_version=1) + version = self._matching_api_version(AlterConfigsRequest) if version <= 1: request = AlterConfigsRequest[version]( resources=[self._convert_alter_config_resource_request(config_resource) for config_resource in config_resources] @@ -1095,15 +925,6 @@ class KafkaAdminClient(object): @staticmethod def _convert_create_partitions_request(topic_name, new_partitions): - """Convert a NewPartitions object into the tuple format for CreatePartitionsRequest. - - Arguments: - topic_name: The name of the existing topic. - new_partitions: A NewPartitions instance with total_count and new_assignments. - - Returns: - A tuple: (topic_name, (total_count, [list_of_assignments])). - """ return ( topic_name, ( @@ -1115,19 +936,14 @@ class KafkaAdminClient(object): def create_partitions(self, topic_partitions, timeout_ms=None, validate_only=False): """Create additional partitions for an existing topic. - Arguments: - topic_partitions: A map of topic name strings to NewPartition objects. - - Keyword Arguments: - timeout_ms (numeric, optional): Milliseconds to wait for new partitions to be - created before the broker returns. - validate_only (bool, optional): If True, don't actually create new partitions. 
- Default: False - - Returns: - Appropriate version of CreatePartitionsResponse class. + :param topic_partitions: A map of topic name strings to NewPartition objects. + :param timeout_ms: Milliseconds to wait for new partitions to be + created before the broker returns. + :param validate_only: If True, don't actually create new partitions. + Default: False + :return: Appropriate version of CreatePartitionsResponse class. """ - version = self._client.api_version(CreatePartitionsRequest, max_version=1) + version = self._matching_api_version(CreatePartitionsRequest) timeout_ms = self._validate_timeout(timeout_ms) if version <= 1: request = CreatePartitionsRequest[version]( @@ -1141,118 +957,8 @@ class KafkaAdminClient(object): .format(version)) return self._send_request_to_controller(request) - def _get_leader_for_partitions(self, partitions, timeout_ms=None): - """Finds ID of the leader node for every given topic partition. - - Will raise UnknownTopicOrPartitionError if for some partition no leader can be found. - - :param partitions: ``[TopicPartition]``: partitions for which to find leaders. - :param timeout_ms: ``float``: Timeout in milliseconds, if None (default), will be read from - config. - - :return: Dictionary with ``{leader_id -> {partitions}}`` - """ - timeout_ms = self._validate_timeout(timeout_ms) - - partitions = set(partitions) - topics = set(tp.topic for tp in partitions) - - response = self._get_cluster_metadata(topics=topics).to_object() - - leader2partitions = defaultdict(list) - valid_partitions = set() - for topic in response.get("topics", ()): - for partition in topic.get("partitions", ()): - t2p = TopicPartition(topic=topic["topic"], partition=partition["partition"]) - if t2p in partitions: - leader2partitions[partition["leader"]].append(t2p) - valid_partitions.add(t2p) - - if len(partitions) != len(valid_partitions): - unknown = set(partitions) - valid_partitions - raise UnknownTopicOrPartitionError( - "The following partitions are not known: %s" - % ", ".join(str(x) for x in unknown) - ) - - return leader2partitions - - def delete_records(self, records_to_delete, timeout_ms=None, partition_leader_id=None): - """Delete records whose offset is smaller than the given offset of the corresponding partition. - - :param records_to_delete: ``{TopicPartition: int}``: The earliest available offsets for the - given partitions. - :param timeout_ms: ``float``: Timeout in milliseconds, if None (default), will be read from - config. - :param partition_leader_id: ``str``: If specified, all deletion requests will be sent to - this node. No check is performed verifying that this is indeed the leader for all - listed partitions: use with caution. - - :return: Dictionary {topicPartition -> metadata}, where metadata is returned by the broker. - See DeleteRecordsResponse for possible fields. error_code for all partitions is - guaranteed to be zero, otherwise an exception is raised. - """ - timeout_ms = self._validate_timeout(timeout_ms) - responses = [] - version = self._client.api_version(DeleteRecordsRequest, max_version=0) - if version is None: - raise IncompatibleBrokerVersion("Broker does not support DeleteGroupsRequest") - - # We want to make as few requests as possible - # If a single node serves as a partition leader for multiple partitions (and/or - # topics), we can send all of those in a single request. 
- # For that we store {leader -> {partitions for leader}}, and do 1 request per leader - if partition_leader_id is None: - leader2partitions = self._get_leader_for_partitions( - set(records_to_delete), timeout_ms - ) - else: - leader2partitions = {partition_leader_id: set(records_to_delete)} - - for leader, partitions in leader2partitions.items(): - topic2partitions = defaultdict(list) - for partition in partitions: - topic2partitions[partition.topic].append(partition) - - request = DeleteRecordsRequest[version]( - topics=[ - (topic, [(tp.partition, records_to_delete[tp]) for tp in partitions]) - for topic, partitions in topic2partitions.items() - ], - timeout_ms=timeout_ms - ) - future = self._send_request_to_node(leader, request) - self._wait_for_futures([future]) - - responses.append(future.value.to_object()) - - partition2result = {} - partition2error = {} - for response in responses: - for topic in response["topics"]: - for partition in topic["partitions"]: - tp = TopicPartition(topic["name"], partition["partition_index"]) - partition2result[tp] = partition - if partition["error_code"] != 0: - partition2error[tp] = partition["error_code"] - - if partition2error: - if len(partition2error) == 1: - key, error = next(iter(partition2error.items())) - raise Errors.for_code(error)( - "Error deleting records from topic %s partition %s" % (key.topic, key.partition) - ) - else: - raise Errors.BrokerResponseError( - "The following errors occured when trying to delete records: " + - ", ".join( - "%s(partition=%d): %s" % - (partition.topic, partition.partition, Errors.for_code(error).__name__) - for partition, error in partition2error.items() - ) - ) - - return partition2result + # delete records protocol not yet implemented + # Note: send the request to the partition leaders # create delegation token protocol not yet implemented # Note: send the request to the least_loaded_node() @@ -1269,14 +975,12 @@ class KafkaAdminClient(object): def _describe_consumer_groups_send_request(self, group_id, group_coordinator_id, include_authorized_operations=False): """Send a DescribeGroupsRequest to the group's coordinator. - Arguments: - group_id: The group name as a string - group_coordinator_id: The node_id of the groups' coordinator broker. - - Returns: - A message future. + :param group_id: The group name as a string + :param group_coordinator_id: The node_id of the groups' coordinator + broker. + :return: A message future. """ - version = self._client.api_version(DescribeGroupsRequest, max_version=3) + version = self._matching_api_version(DescribeGroupsRequest) if version <= 2: if include_authorized_operations: raise IncompatibleBrokerVersion( @@ -1357,23 +1061,18 @@ class KafkaAdminClient(object): Any errors are immediately raised. - Arguments: - group_ids: A list of consumer group IDs. These are typically the - group names as strings. - - Keyword Arguments: - group_coordinator_id (int, optional): The node_id of the groups' coordinator - broker. If set to None, it will query the cluster for each group to - find that group's coordinator. Explicitly specifying this can be - useful for avoiding extra network round trips if you already know - the group coordinator. This is only useful when all the group_ids - have the same coordinator, otherwise it will error. Default: None. - include_authorized_operations (bool, optional): Whether or not to include - information about the operations a group is allowed to perform. - Only supported on API version >= v3. Default: False. 
- - Returns: - A list of group descriptions. For now the group descriptions + :param group_ids: A list of consumer group IDs. These are typically the + group names as strings. + :param group_coordinator_id: The node_id of the groups' coordinator + broker. If set to None, it will query the cluster for each group to + find that group's coordinator. Explicitly specifying this can be + useful for avoiding extra network round trips if you already know + the group coordinator. This is only useful when all the group_ids + have the same coordinator, otherwise it will error. Default: None. + :param include_authorized_operations: Whether or not to include + information about the operations a group is allowed to perform. + Only supported on API version >= v3. Default: False. + :return: A list of group descriptions. For now the group descriptions are the raw results from the DescribeGroupsResponse. Long-term, we plan to change this to return namedtuples as well as decoding the partition assignments. @@ -1404,13 +1103,10 @@ class KafkaAdminClient(object): def _list_consumer_groups_send_request(self, broker_id): """Send a ListGroupsRequest to a broker. - Arguments: - broker_id (int): The broker's node_id. - - Returns: - A message future + :param broker_id: The broker's node_id. + :return: A message future """ - version = self._client.api_version(ListGroupsRequest, max_version=2) + version = self._matching_api_version(ListGroupsRequest) if version <= 2: request = ListGroupsRequest[version]() else: @@ -1448,20 +1144,15 @@ class KafkaAdminClient(object): As soon as any error is encountered, it is immediately raised. - Keyword Arguments: - broker_ids ([int], optional): A list of broker node_ids to query for consumer - groups. If set to None, will query all brokers in the cluster. - Explicitly specifying broker(s) can be useful for determining which - consumer groups are coordinated by those broker(s). Default: None - - Returns: - list: List of tuples of Consumer Groups. - - Raises: - CoordinatorNotAvailableError: The coordinator is not - available, so cannot process requests. - CoordinatorLoadInProgressError: The coordinator is loading and - hence can't process requests. + :param broker_ids: A list of broker node_ids to query for consumer + groups. If set to None, will query all brokers in the cluster. + Explicitly specifying broker(s) can be useful for determining which + consumer groups are coordinated by those broker(s). Default: None + :return list: List of tuples of Consumer Groups. + :exception GroupCoordinatorNotAvailableError: The coordinator is not + available, so cannot process requests. + :exception GroupLoadInProgressError: The coordinator is loading and + hence can't process requests. """ # While we return a list, internally use a set to prevent duplicates # because if a group coordinator fails after being queried, and its @@ -1481,20 +1172,13 @@ class KafkaAdminClient(object): group_coordinator_id, partitions=None): """Send an OffsetFetchRequest to a broker. - Arguments: - group_id (str): The consumer group id name for which to fetch offsets. - group_coordinator_id (int): The node_id of the group's coordinator broker. - - Keyword Arguments: - partitions: A list of TopicPartitions for which to fetch - offsets. On brokers >= 0.10.2, this can be set to None to fetch all - known offsets for the consumer group. Default: None. - - Returns: - A message future + :param group_id: The consumer group id name for which to fetch offsets. 
+ :param group_coordinator_id: The node_id of the group's coordinator + broker. + :return: A message future """ - version = self._client.api_version(OffsetFetchRequest, max_version=5) - if version <= 5: + version = self._matching_api_version(OffsetFetchRequest) + if version <= 3: if partitions is None: if version <= 1: raise ValueError( @@ -1519,14 +1203,11 @@ class KafkaAdminClient(object): def _list_consumer_group_offsets_process_response(self, response): """Process an OffsetFetchResponse. - Arguments: - response: an OffsetFetchResponse. - - Returns: - A dictionary composed of TopicPartition keys and - OffsetAndMetadata values. + :param response: an OffsetFetchResponse. + :return: A dictionary composed of TopicPartition keys and + OffsetAndMetada values. """ - if response.API_VERSION <= 5: + if response.API_VERSION <= 3: # OffsetFetchResponse_v1 lacks a top-level error_code if response.API_VERSION > 1: @@ -1538,21 +1219,16 @@ class KafkaAdminClient(object): .format(response)) # transform response into a dictionary with TopicPartition keys and - # OffsetAndMetadata values--this is what the Java AdminClient returns + # OffsetAndMetada values--this is what the Java AdminClient returns offsets = {} for topic, partitions in response.topics: - for partition_data in partitions: - if response.API_VERSION <= 4: - partition, offset, metadata, error_code = partition_data - leader_epoch = -1 - else: - partition, offset, leader_epoch, metadata, error_code = partition_data + for partition, offset, metadata, error_code in partitions: error_type = Errors.for_code(error_code) if error_type is not Errors.NoError: raise error_type( "Unable to fetch consumer group offsets for topic {}, partition {}" .format(topic, partition)) - offsets[TopicPartition(topic, partition)] = OffsetAndMetadata(offset, metadata, leader_epoch) + offsets[TopicPartition(topic, partition)] = OffsetAndMetadata(offset, metadata) else: raise NotImplementedError( "Support for OffsetFetchResponse_v{} has not yet been added to KafkaAdminClient." @@ -1569,22 +1245,17 @@ class KafkaAdminClient(object): As soon as any error is encountered, it is immediately raised. - Arguments: - group_id (str): The consumer group id name for which to fetch offsets. - - Keyword Arguments: - group_coordinator_id (int, optional): The node_id of the group's coordinator - broker. If set to None, will query the cluster to find the group - coordinator. Explicitly specifying this can be useful to prevent - that extra network round trip if you already know the group - coordinator. Default: None. - partitions: A list of TopicPartitions for which to fetch - offsets. On brokers >= 0.10.2, this can be set to None to fetch all - known offsets for the consumer group. Default: None. - - Returns: - dictionary: A dictionary with TopicPartition keys and - OffsetAndMetadata values. Partitions that are not specified and for + :param group_id: The consumer group id name for which to fetch offsets. + :param group_coordinator_id: The node_id of the group's coordinator + broker. If set to None, will query the cluster to find the group + coordinator. Explicitly specifying this can be useful to prevent + that extra network round trip if you already know the group + coordinator. Default: None. + :param partitions: A list of TopicPartitions for which to fetch + offsets. On brokers >= 0.10.2, this can be set to None to fetch all + known offsets for the consumer group. Default: None. + :return dictionary: A dictionary with TopicPartition keys and + OffsetAndMetada values. 
Partitions that are not specified and for which the group_id does not have a recorded offset are omitted. An offset value of `-1` indicates the group_id has no offset for that TopicPartition. A `-1` can only happen for partitions that are @@ -1607,19 +1278,14 @@ class KafkaAdminClient(object): The result needs checking for potential errors. - Arguments: - group_ids ([str]): The consumer group ids of the groups which are to be deleted. - - Keyword Arguments: - group_coordinator_id (int, optional): The node_id of the broker which is - the coordinator for all the groups. Use only if all groups are coordinated - by the same broker. If set to None, will query the cluster to find the coordinator - for every single group. Explicitly specifying this can be useful to prevent - that extra network round trips if you already know the group coordinator. - Default: None. - - Returns: - A list of tuples (group_id, KafkaError) + :param group_ids: The consumer group ids of the groups which are to be deleted. + :param group_coordinator_id: The node_id of the broker which is the coordinator for + all the groups. Use only if all groups are coordinated by the same broker. + If set to None, will query the cluster to find the coordinator for every single group. + Explicitly specifying this can be useful to prevent + that extra network round trips if you already know the group + coordinator. Default: None. + :return: A list of tuples (group_id, KafkaError) """ if group_coordinator_id is not None: futures = [self._delete_consumer_groups_send_request(group_ids, group_coordinator_id)] @@ -1640,14 +1306,6 @@ class KafkaAdminClient(object): return results def _convert_delete_groups_response(self, response): - """Parse the DeleteGroupsResponse, mapping group IDs to their respective errors. - - Arguments: - response: A DeleteGroupsResponse object from the broker. - - Returns: - A list of (group_id, KafkaError) for each deleted group. - """ if response.API_VERSION <= 1: results = [] for group_id, error_code in response.results: @@ -1659,16 +1317,14 @@ class KafkaAdminClient(object): .format(response.API_VERSION)) def _delete_consumer_groups_send_request(self, group_ids, group_coordinator_id): - """Send a DeleteGroupsRequest to the specified broker (the group coordinator). + """Send a DeleteGroups request to a broker. - Arguments: - group_ids ([str]): A list of consumer group IDs to be deleted. - group_coordinator_id (int): The node_id of the broker coordinating these groups. - - Returns: - A future representing the in-flight DeleteGroupsRequest. + :param group_ids: The consumer group ids of the groups which are to be deleted. + :param group_coordinator_id: The node_id of the broker which is the coordinator for + all the groups. 
+ :return: A message future """ - version = self._client.api_version(DeleteGroupsRequest, max_version=1) + version = self._matching_api_version(DeleteGroupsRequest) if version <= 1: request = DeleteGroupsRequest[version](group_ids) else: @@ -1677,80 +1333,10 @@ class KafkaAdminClient(object): .format(version)) return self._send_request_to_node(group_coordinator_id, request) - @staticmethod - def _convert_topic_partitions(topic_partitions): - return [ - ( - topic, - partition_ids - ) - for topic, partition_ids in topic_partitions.items() - ] - - def _get_all_topic_partitions(self): - return [ - ( - topic, - [partition_info.partition for partition_info in self._client.cluster._partitions[topic].values()] - ) - for topic in self._client.cluster.topics() - ] - - def _get_topic_partitions(self, topic_partitions): - if topic_partitions is None: - return self._get_all_topic_partitions() - return self._convert_topic_partitions(topic_partitions) - - def perform_leader_election(self, election_type, topic_partitions=None, timeout_ms=None): - """Perform leader election on the topic partitions. - - :param election_type: Type of election to attempt. 0 for Perferred, 1 for Unclean - :param topic_partitions: A map of topic name strings to partition ids list. - By default, will run on all topic partitions - :param timeout_ms: Milliseconds to wait for the leader election process to complete - before the broker returns. - - :return: Appropriate version of ElectLeadersResponse class. - """ - version = self._client.api_version(ElectLeadersRequest, max_version=1) - timeout_ms = self._validate_timeout(timeout_ms) - request = ElectLeadersRequest[version]( - election_type=ElectionType(election_type), - topic_partitions=self._get_topic_partitions(topic_partitions), - timeout=timeout_ms, - ) - # TODO convert structs to a more pythonic interface - return self._send_request_to_controller(request) - def _wait_for_futures(self, futures): - """Block until all futures complete. If any fail, raise the encountered exception. - - Arguments: - futures: A list of Future objects awaiting results. - - Raises: - The first encountered exception if a future fails. - """ while not all(future.succeeded() for future in futures): for future in futures: self._client.poll(future=future) if future.failed(): raise future.exception # pylint: disable-msg=raising-bad-type - - def describe_log_dirs(self): - """Send a DescribeLogDirsRequest request to a broker. - - Returns: - A message future - """ - version = self._client.api_version(DescribeLogDirsRequest, max_version=0) - if version <= 0: - request = DescribeLogDirsRequest[version]() - future = self._send_request_to_node(self._client.least_loaded_node(), request) - self._wait_for_futures([future]) - else: - raise NotImplementedError( - "Support for DescribeLogDirsRequest_v{} has not yet been added to KafkaAdminClient." 
- .format(version)) - return future.value diff --git a/venv/lib/python3.12/site-packages/kafka/benchmarks/consumer_performance.py b/venv/lib/python3.12/site-packages/kafka/benchmarks/consumer_performance.py deleted file mode 100644 index c35a164..0000000 --- a/venv/lib/python3.12/site-packages/kafka/benchmarks/consumer_performance.py +++ /dev/null @@ -1,142 +0,0 @@ -#!/usr/bin/env python -# Adapted from https://github.com/mrafayaleem/kafka-jython - -from __future__ import absolute_import, print_function - -import argparse -import pprint -import sys -import threading -import time -import traceback - -from kafka import KafkaConsumer - - -class ConsumerPerformance(object): - @staticmethod - def run(args): - try: - props = {} - for prop in args.consumer_config: - k, v = prop.split('=') - try: - v = int(v) - except ValueError: - pass - if v == 'None': - v = None - elif v == 'False': - v = False - elif v == 'True': - v = True - props[k] = v - - print('Initializing Consumer...') - props['bootstrap_servers'] = args.bootstrap_servers - props['auto_offset_reset'] = 'earliest' - if 'group_id' not in props: - props['group_id'] = 'kafka-consumer-benchmark' - if 'consumer_timeout_ms' not in props: - props['consumer_timeout_ms'] = 10000 - props['metrics_sample_window_ms'] = args.stats_interval * 1000 - for k, v in props.items(): - print('---> {0}={1}'.format(k, v)) - consumer = KafkaConsumer(args.topic, **props) - print('---> group_id={0}'.format(consumer.config['group_id'])) - print('---> report stats every {0} secs'.format(args.stats_interval)) - print('---> raw metrics? {0}'.format(args.raw_metrics)) - timer_stop = threading.Event() - timer = StatsReporter(args.stats_interval, consumer, - event=timer_stop, - raw_metrics=args.raw_metrics) - timer.start() - print('-> OK!') - print() - - start_time = time.time() - records = 0 - for msg in consumer: - records += 1 - if records >= args.num_records: - break - - end_time = time.time() - timer_stop.set() - timer.join() - print('Consumed {0} records'.format(records)) - print('Execution time:', end_time - start_time, 'secs') - - except Exception: - exc_info = sys.exc_info() - traceback.print_exception(*exc_info) - sys.exit(1) - - -class StatsReporter(threading.Thread): - def __init__(self, interval, consumer, event=None, raw_metrics=False): - super(StatsReporter, self).__init__() - self.interval = interval - self.consumer = consumer - self.event = event - self.raw_metrics = raw_metrics - - def print_stats(self): - metrics = self.consumer.metrics() - if self.raw_metrics: - pprint.pprint(metrics) - else: - print('{records-consumed-rate} records/sec ({bytes-consumed-rate} B/sec),' - ' {fetch-latency-avg} latency,' - ' {fetch-rate} fetch/s,' - ' {fetch-size-avg} fetch size,' - ' {records-lag-max} max record lag,' - ' {records-per-request-avg} records/req' - .format(**metrics['consumer-fetch-manager-metrics'])) - - - def print_final(self): - self.print_stats() - - def run(self): - while self.event and not self.event.wait(self.interval): - self.print_stats() - else: - self.print_final() - - -def get_args_parser(): - parser = argparse.ArgumentParser( - description='This tool is used to verify the consumer performance.') - - parser.add_argument( - '--bootstrap-servers', type=str, nargs='+', default=(), - help='host:port for cluster bootstrap servers') - parser.add_argument( - '--topic', type=str, - help='Topic for consumer test (default: kafka-python-benchmark-test)', - default='kafka-python-benchmark-test') - parser.add_argument( - '--num-records', type=int, - 
help='number of messages to consume (default: 1000000)', - default=1000000) - parser.add_argument( - '--consumer-config', type=str, nargs='+', default=(), - help='kafka consumer related configuration properties like ' - 'bootstrap_servers,client_id etc..') - parser.add_argument( - '--fixture-compression', type=str, - help='specify a compression type for use with broker fixtures / producer') - parser.add_argument( - '--stats-interval', type=int, - help='Interval in seconds for stats reporting to console (default: 5)', - default=5) - parser.add_argument( - '--raw-metrics', action='store_true', - help='Enable this flag to print full metrics dict on each interval') - return parser - - -if __name__ == '__main__': - args = get_args_parser().parse_args() - ConsumerPerformance.run(args) diff --git a/venv/lib/python3.12/site-packages/kafka/benchmarks/load_example.py b/venv/lib/python3.12/site-packages/kafka/benchmarks/load_example.py deleted file mode 100644 index 29796a7..0000000 --- a/venv/lib/python3.12/site-packages/kafka/benchmarks/load_example.py +++ /dev/null @@ -1,110 +0,0 @@ -#!/usr/bin/env python -from __future__ import print_function - -import argparse -import logging -import threading -import time - -from kafka import KafkaConsumer, KafkaProducer - - -class Producer(threading.Thread): - - def __init__(self, bootstrap_servers, topic, stop_event, msg_size): - super(Producer, self).__init__() - self.bootstrap_servers = bootstrap_servers - self.topic = topic - self.stop_event = stop_event - self.big_msg = b'1' * msg_size - - def run(self): - producer = KafkaProducer(bootstrap_servers=self.bootstrap_servers) - self.sent = 0 - - while not self.stop_event.is_set(): - producer.send(self.topic, self.big_msg) - self.sent += 1 - producer.flush() - producer.close() - - -class Consumer(threading.Thread): - def __init__(self, bootstrap_servers, topic, stop_event, msg_size): - super(Consumer, self).__init__() - self.bootstrap_servers = bootstrap_servers - self.topic = topic - self.stop_event = stop_event - self.msg_size = msg_size - - def run(self): - consumer = KafkaConsumer(bootstrap_servers=self.bootstrap_servers, - auto_offset_reset='earliest') - consumer.subscribe([self.topic]) - self.valid = 0 - self.invalid = 0 - - for message in consumer: - if len(message.value) == self.msg_size: - self.valid += 1 - else: - print('Invalid message:', len(message.value), self.msg_size) - self.invalid += 1 - - if self.stop_event.is_set(): - break - consumer.close() - - -def get_args_parser(): - parser = argparse.ArgumentParser( - description='This tool is used to demonstrate consumer and producer load.') - - parser.add_argument( - '--bootstrap-servers', type=str, nargs='+', default=('localhost:9092'), - help='host:port for cluster bootstrap servers (default: localhost:9092)') - parser.add_argument( - '--topic', type=str, - help='Topic for load test (default: kafka-python-benchmark-load-example)', - default='kafka-python-benchmark-load-example') - parser.add_argument( - '--msg-size', type=int, - help='Message size, in bytes, for load test (default: 524288)', - default=524288) - parser.add_argument( - '--load-time', type=int, - help='number of seconds to run load test (default: 10)', - default=10) - parser.add_argument( - '--log-level', type=str, - help='Optional logging level for load test: ERROR|INFO|DEBUG etc', - default=None) - return parser - - -def main(args): - if args.log_level: - logging.basicConfig( - format='%(asctime)s.%(msecs)s:%(name)s:%(thread)d:%(levelname)s:%(process)d:%(message)s', - 
level=getattr(logging, args.log_level)) - producer_stop = threading.Event() - consumer_stop = threading.Event() - threads = [ - Producer(args.bootstrap_servers, args.topic, producer_stop, args.msg_size), - Consumer(args.bootstrap_servers, args.topic, consumer_stop, args.msg_size) - ] - - for t in threads: - t.start() - - time.sleep(args.load_time) - producer_stop.set() - consumer_stop.set() - print('Messages sent: %d' % threads[0].sent) - print('Messages recvd: %d' % threads[1].valid) - print('Messages invalid: %d' % threads[1].invalid) - - -if __name__ == "__main__": - args = get_args_parser().parse_args() - main(args) diff --git a/venv/lib/python3.12/site-packages/kafka/benchmarks/producer_performance.py b/venv/lib/python3.12/site-packages/kafka/benchmarks/producer_performance.py deleted file mode 100644 index 1a10929..0000000 --- a/venv/lib/python3.12/site-packages/kafka/benchmarks/producer_performance.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python -# Adapted from https://github.com/mrafayaleem/kafka-jython - -from __future__ import absolute_import, print_function - -import argparse -import pprint -import sys -import threading -import time -import traceback - -from kafka.vendor.six.moves import range - -from kafka import KafkaProducer - - -class ProducerPerformance(object): - @staticmethod - def run(args): - try: - props = {} - for prop in args.producer_config: - k, v = prop.split('=') - try: - v = int(v) - except ValueError: - pass - if v == 'None': - v = None - elif v == 'False': - v = False - elif v == 'True': - v = True - props[k] = v - - print('Initializing producer...') - props['bootstrap_servers'] = args.bootstrap_servers - record = bytes(bytearray(args.record_size)) - props['metrics_sample_window_ms'] = args.stats_interval * 1000 - - producer = KafkaProducer(**props) - for k, v in props.items(): - print('---> {0}={1}'.format(k, v)) - print('---> send {0} byte records'.format(args.record_size)) - print('---> report stats every {0} secs'.format(args.stats_interval)) - print('---> raw metrics? 
{0}'.format(args.raw_metrics)) - timer_stop = threading.Event() - timer = StatsReporter(args.stats_interval, producer, - event=timer_stop, - raw_metrics=args.raw_metrics) - timer.start() - print('-> OK!') - print() - - def _benchmark(): - results = [] - for i in range(args.num_records): - results.append(producer.send(topic=args.topic, value=record)) - print("Send complete...") - producer.flush() - producer.close() - count_success, count_failure = 0, 0 - for r in results: - if r.succeeded(): - count_success += 1 - elif r.failed(): - count_failure += 1 - else: - raise ValueError(r) - print("%d suceeded, %d failed" % (count_success, count_failure)) - - start_time = time.time() - _benchmark() - end_time = time.time() - timer_stop.set() - timer.join() - print('Execution time:', end_time - start_time, 'secs') - - except Exception: - exc_info = sys.exc_info() - traceback.print_exception(*exc_info) - sys.exit(1) - - -class StatsReporter(threading.Thread): - def __init__(self, interval, producer, event=None, raw_metrics=False): - super(StatsReporter, self).__init__() - self.interval = interval - self.producer = producer - self.event = event - self.raw_metrics = raw_metrics - - def print_stats(self): - metrics = self.producer.metrics() - if not metrics: - return - if self.raw_metrics: - pprint.pprint(metrics) - else: - print('{record-send-rate} records/sec ({byte-rate} B/sec),' - ' {request-latency-avg} latency,' - ' {record-size-avg} record size,' - ' {batch-size-avg} batch size,' - ' {records-per-request-avg} records/req' - .format(**metrics['producer-metrics'])) - - def print_final(self): - self.print_stats() - - def run(self): - while self.event and not self.event.wait(self.interval): - self.print_stats() - else: - self.print_final() - - -def get_args_parser(): - parser = argparse.ArgumentParser( - description='This tool is used to verify the producer performance.') - - parser.add_argument( - '--bootstrap-servers', type=str, nargs='+', default=(), - help='host:port for cluster bootstrap server') - parser.add_argument( - '--topic', type=str, - help='Topic name for test (default: kafka-python-benchmark-test)', - default='kafka-python-benchmark-test') - parser.add_argument( - '--num-records', type=int, - help='number of messages to produce (default: 1000000)', - default=1000000) - parser.add_argument( - '--record-size', type=int, - help='message size in bytes (default: 100)', - default=100) - parser.add_argument( - '--producer-config', type=str, nargs='+', default=(), - help='kafka producer related configuaration properties like ' - 'bootstrap_servers,client_id etc..') - parser.add_argument( - '--stats-interval', type=int, - help='Interval in seconds for stats reporting to console (default: 5)', - default=5) - parser.add_argument( - '--raw-metrics', action='store_true', - help='Enable this flag to print full metrics dict on each interval') - return parser - - -if __name__ == '__main__': - args = get_args_parser().parse_args() - ProducerPerformance.run(args) diff --git a/venv/lib/python3.12/site-packages/kafka/benchmarks/record_batch_compose.py b/venv/lib/python3.12/site-packages/kafka/benchmarks/record_batch_compose.py deleted file mode 100644 index 5b07fd5..0000000 --- a/venv/lib/python3.12/site-packages/kafka/benchmarks/record_batch_compose.py +++ /dev/null @@ -1,78 +0,0 @@ -#!/usr/bin/env python3 -from __future__ import print_function -import hashlib -import itertools -import os -import random - -import pyperf - -from kafka.record.memory_records import MemoryRecordsBuilder - - 
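The benchmark files being deleted in this part of the diff drive kafka.record.memory_records directly. As a reference, here is a minimal sketch of that round trip (compose a batch, then read it back); the magic value, batch size, and payloads are illustrative only, and the builder/append/close/buffer and MemoryRecords/next_batch calls simply mirror the benchmark code being removed rather than a documented public API.

from kafka.record.memory_records import MemoryRecords, MemoryRecordsBuilder

# Compose one batch the way record_batch_compose.py does (values are illustrative).
builder = MemoryRecordsBuilder(2, batch_size=16 * 1024, compression_type=0)
builder.append(timestamp=1505824130000, key=b"key-01", value=b"x" * 60)
builder.close()                     # seal the batch before reading its buffer
data = bytes(builder.buffer())

# Read it back the way record_batch_read.py does.
records = MemoryRecords(data)
while records.has_next():
    batch = records.next_batch()
    batch.validate_crc()
    for record in batch:
        print(record.value)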
-DEFAULT_BATCH_SIZE = 1600 * 1024 -KEY_SIZE = 6 -VALUE_SIZE = 60 -TIMESTAMP_RANGE = [1505824130000, 1505824140000] - -# With values above v1 record is 100 bytes, so 10 000 bytes for 100 messages -MESSAGES_PER_BATCH = 100 - - -def random_bytes(length): - buffer = bytearray(length) - for i in range(length): - buffer[i] = random.randint(0, 255) - return bytes(buffer) - - -def prepare(): - return iter(itertools.cycle([ - (random_bytes(KEY_SIZE), - random_bytes(VALUE_SIZE), - random.randint(*TIMESTAMP_RANGE) - ) - for _ in range(int(MESSAGES_PER_BATCH * 1.94)) - ])) - - -def finalize(results): - # Just some strange code to make sure PyPy does execute the main code - # properly, without optimizing it away - hash_val = hashlib.md5() - for buf in results: - hash_val.update(buf) - print(hash_val, file=open(os.devnull, "w")) - - -def func(loops, magic): - # Jit can optimize out the whole function if the result is the same each - # time, so we need some randomized input data ) - precomputed_samples = prepare() - results = [] - - # Main benchmark code. - t0 = pyperf.perf_counter() - for _ in range(loops): - batch = MemoryRecordsBuilder( - magic, batch_size=DEFAULT_BATCH_SIZE, compression_type=0) - for _ in range(MESSAGES_PER_BATCH): - key, value, timestamp = next(precomputed_samples) - size = batch.append( - timestamp=timestamp, key=key, value=value) - assert size - batch.close() - results.append(batch.buffer()) - - res = pyperf.perf_counter() - t0 - - finalize(results) - - return res - - -if __name__ == '__main__': - runner = pyperf.Runner() - runner.bench_time_func('batch_append_v0', func, 0) - runner.bench_time_func('batch_append_v1', func, 1) - runner.bench_time_func('batch_append_v2', func, 2) diff --git a/venv/lib/python3.12/site-packages/kafka/benchmarks/record_batch_read.py b/venv/lib/python3.12/site-packages/kafka/benchmarks/record_batch_read.py deleted file mode 100644 index 2ef3229..0000000 --- a/venv/lib/python3.12/site-packages/kafka/benchmarks/record_batch_read.py +++ /dev/null @@ -1,83 +0,0 @@ -#!/usr/bin/env python -from __future__ import print_function -import hashlib -import itertools -import os -import random - -import pyperf - -from kafka.record.memory_records import MemoryRecords, MemoryRecordsBuilder - - -DEFAULT_BATCH_SIZE = 1600 * 1024 -KEY_SIZE = 6 -VALUE_SIZE = 60 -TIMESTAMP_RANGE = [1505824130000, 1505824140000] - -BATCH_SAMPLES = 5 -MESSAGES_PER_BATCH = 100 - - -def random_bytes(length): - buffer = bytearray(length) - for i in range(length): - buffer[i] = random.randint(0, 255) - return bytes(buffer) - - -def prepare(magic): - samples = [] - for _ in range(BATCH_SAMPLES): - batch = MemoryRecordsBuilder( - magic, batch_size=DEFAULT_BATCH_SIZE, compression_type=0) - for _ in range(MESSAGES_PER_BATCH): - size = batch.append( - random.randint(*TIMESTAMP_RANGE), - random_bytes(KEY_SIZE), - random_bytes(VALUE_SIZE), - headers=[]) - assert size - batch.close() - samples.append(bytes(batch.buffer())) - - return iter(itertools.cycle(samples)) - - -def finalize(results): - # Just some strange code to make sure PyPy does execute the code above - # properly - hash_val = hashlib.md5() - for buf in results: - hash_val.update(buf) - print(hash_val, file=open(os.devnull, "w")) - - -def func(loops, magic): - # Jit can optimize out the whole function if the result is the same each - # time, so we need some randomized input data ) - precomputed_samples = prepare(magic) - results = [] - - # Main benchmark code. 
- batch_data = next(precomputed_samples) - t0 = pyperf.perf_counter() - for _ in range(loops): - records = MemoryRecords(batch_data) - while records.has_next(): - batch = records.next_batch() - batch.validate_crc() - for record in batch: - results.append(record.value) - - res = pyperf.perf_counter() - t0 - finalize(results) - - return res - - -if __name__ == '__main__': - runner = pyperf.Runner() - runner.bench_time_func('batch_read_v0', func, 0) - runner.bench_time_func('batch_read_v1', func, 1) - runner.bench_time_func('batch_read_v2', func, 2) diff --git a/venv/lib/python3.12/site-packages/kafka/benchmarks/varint_speed.py b/venv/lib/python3.12/site-packages/kafka/benchmarks/varint_speed.py deleted file mode 100644 index b2628a1..0000000 --- a/venv/lib/python3.12/site-packages/kafka/benchmarks/varint_speed.py +++ /dev/null @@ -1,434 +0,0 @@ -#!/usr/bin/env python -from __future__ import print_function -import pyperf -from kafka.vendor import six - - -test_data = [ - (b"\x00", 0), - (b"\x01", -1), - (b"\x02", 1), - (b"\x7E", 63), - (b"\x7F", -64), - (b"\x80\x01", 64), - (b"\x81\x01", -65), - (b"\xFE\x7F", 8191), - (b"\xFF\x7F", -8192), - (b"\x80\x80\x01", 8192), - (b"\x81\x80\x01", -8193), - (b"\xFE\xFF\x7F", 1048575), - (b"\xFF\xFF\x7F", -1048576), - (b"\x80\x80\x80\x01", 1048576), - (b"\x81\x80\x80\x01", -1048577), - (b"\xFE\xFF\xFF\x7F", 134217727), - (b"\xFF\xFF\xFF\x7F", -134217728), - (b"\x80\x80\x80\x80\x01", 134217728), - (b"\x81\x80\x80\x80\x01", -134217729), - (b"\xFE\xFF\xFF\xFF\x7F", 17179869183), - (b"\xFF\xFF\xFF\xFF\x7F", -17179869184), - (b"\x80\x80\x80\x80\x80\x01", 17179869184), - (b"\x81\x80\x80\x80\x80\x01", -17179869185), - (b"\xFE\xFF\xFF\xFF\xFF\x7F", 2199023255551), - (b"\xFF\xFF\xFF\xFF\xFF\x7F", -2199023255552), - (b"\x80\x80\x80\x80\x80\x80\x01", 2199023255552), - (b"\x81\x80\x80\x80\x80\x80\x01", -2199023255553), - (b"\xFE\xFF\xFF\xFF\xFF\xFF\x7F", 281474976710655), - (b"\xFF\xFF\xFF\xFF\xFF\xFF\x7F", -281474976710656), - (b"\x80\x80\x80\x80\x80\x80\x80\x01", 281474976710656), - (b"\x81\x80\x80\x80\x80\x80\x80\x01", -281474976710657), - (b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\x7F", 36028797018963967), - (b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x7F", -36028797018963968), - (b"\x80\x80\x80\x80\x80\x80\x80\x80\x01", 36028797018963968), - (b"\x81\x80\x80\x80\x80\x80\x80\x80\x01", -36028797018963969), - (b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x7F", 4611686018427387903), - (b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x7F", -4611686018427387904), - (b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01", 4611686018427387904), - (b"\x81\x80\x80\x80\x80\x80\x80\x80\x80\x01", -4611686018427387905), -] - - -BENCH_VALUES_ENC = [ - 60, # 1 byte - -8192, # 2 bytes - 1048575, # 3 bytes - 134217727, # 4 bytes - -17179869184, # 5 bytes - 2199023255551, # 6 bytes -] - -BENCH_VALUES_DEC = [ - b"\x7E", # 1 byte - b"\xFF\x7F", # 2 bytes - b"\xFE\xFF\x7F", # 3 bytes - b"\xFF\xFF\xFF\x7F", # 4 bytes - b"\x80\x80\x80\x80\x01", # 5 bytes - b"\xFE\xFF\xFF\xFF\xFF\x7F", # 6 bytes -] -BENCH_VALUES_DEC = list(map(bytearray, BENCH_VALUES_DEC)) - - -def _assert_valid_enc(enc_func): - for encoded, decoded in test_data: - assert enc_func(decoded) == encoded, decoded - - -def _assert_valid_dec(dec_func): - for encoded, decoded in test_data: - res, pos = dec_func(bytearray(encoded)) - assert res == decoded, (decoded, res) - assert pos == len(encoded), (decoded, pos) - - -def _assert_valid_size(size_func): - for encoded, decoded in test_data: - assert size_func(decoded) == len(encoded), decoded - - -def encode_varint_1(num): - """ 
Encode an integer to a varint presentation. See - https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints - on how those can be produced. - - Arguments: - num (int): Value to encode - - Returns: - bytearray: Encoded presentation of integer with length from 1 to 10 - bytes - """ - # Shift sign to the end of number - num = (num << 1) ^ (num >> 63) - # Max 10 bytes. We assert those are allocated - buf = bytearray(10) - - for i in range(10): - # 7 lowest bits from the number and set 8th if we still have pending - # bits left to encode - buf[i] = num & 0x7f | (0x80 if num > 0x7f else 0) - num = num >> 7 - if num == 0: - break - else: - # Max size of endcoded double is 10 bytes for unsigned values - raise ValueError("Out of double range") - return buf[:i + 1] - - -def encode_varint_2(value, int2byte=six.int2byte): - value = (value << 1) ^ (value >> 63) - - bits = value & 0x7f - value >>= 7 - res = b"" - while value: - res += int2byte(0x80 | bits) - bits = value & 0x7f - value >>= 7 - return res + int2byte(bits) - - -def encode_varint_3(value, buf): - append = buf.append - value = (value << 1) ^ (value >> 63) - - bits = value & 0x7f - value >>= 7 - while value: - append(0x80 | bits) - bits = value & 0x7f - value >>= 7 - append(bits) - return value - - -def encode_varint_4(value, int2byte=six.int2byte): - value = (value << 1) ^ (value >> 63) - - if value <= 0x7f: # 1 byte - return int2byte(value) - if value <= 0x3fff: # 2 bytes - return int2byte(0x80 | (value & 0x7f)) + int2byte(value >> 7) - if value <= 0x1fffff: # 3 bytes - return int2byte(0x80 | (value & 0x7f)) + \ - int2byte(0x80 | ((value >> 7) & 0x7f)) + \ - int2byte(value >> 14) - if value <= 0xfffffff: # 4 bytes - return int2byte(0x80 | (value & 0x7f)) + \ - int2byte(0x80 | ((value >> 7) & 0x7f)) + \ - int2byte(0x80 | ((value >> 14) & 0x7f)) + \ - int2byte(value >> 21) - if value <= 0x7ffffffff: # 5 bytes - return int2byte(0x80 | (value & 0x7f)) + \ - int2byte(0x80 | ((value >> 7) & 0x7f)) + \ - int2byte(0x80 | ((value >> 14) & 0x7f)) + \ - int2byte(0x80 | ((value >> 21) & 0x7f)) + \ - int2byte(value >> 28) - else: - # Return to general algorithm - bits = value & 0x7f - value >>= 7 - res = b"" - while value: - res += int2byte(0x80 | bits) - bits = value & 0x7f - value >>= 7 - return res + int2byte(bits) - - -def encode_varint_5(value, buf, pos=0): - value = (value << 1) ^ (value >> 63) - - bits = value & 0x7f - value >>= 7 - while value: - buf[pos] = 0x80 | bits - bits = value & 0x7f - value >>= 7 - pos += 1 - buf[pos] = bits - return pos + 1 - -def encode_varint_6(value, buf): - append = buf.append - value = (value << 1) ^ (value >> 63) - - if value <= 0x7f: # 1 byte - append(value) - return 1 - if value <= 0x3fff: # 2 bytes - append(0x80 | (value & 0x7f)) - append(value >> 7) - return 2 - if value <= 0x1fffff: # 3 bytes - append(0x80 | (value & 0x7f)) - append(0x80 | ((value >> 7) & 0x7f)) - append(value >> 14) - return 3 - if value <= 0xfffffff: # 4 bytes - append(0x80 | (value & 0x7f)) - append(0x80 | ((value >> 7) & 0x7f)) - append(0x80 | ((value >> 14) & 0x7f)) - append(value >> 21) - return 4 - if value <= 0x7ffffffff: # 5 bytes - append(0x80 | (value & 0x7f)) - append(0x80 | ((value >> 7) & 0x7f)) - append(0x80 | ((value >> 14) & 0x7f)) - append(0x80 | ((value >> 21) & 0x7f)) - append(value >> 28) - return 5 - else: - # Return to general algorithm - bits = value & 0x7f - value >>= 7 - i = 0 - while value: - append(0x80 | bits) - bits = value & 0x7f - value >>= 7 - i += 1 - append(bits) - return i - - -def 
size_of_varint_1(value): - """ Number of bytes needed to encode an integer in variable-length format. - """ - value = (value << 1) ^ (value >> 63) - res = 0 - while True: - res += 1 - value = value >> 7 - if value == 0: - break - return res - - -def size_of_varint_2(value): - """ Number of bytes needed to encode an integer in variable-length format. - """ - value = (value << 1) ^ (value >> 63) - if value <= 0x7f: - return 1 - if value <= 0x3fff: - return 2 - if value <= 0x1fffff: - return 3 - if value <= 0xfffffff: - return 4 - if value <= 0x7ffffffff: - return 5 - if value <= 0x3ffffffffff: - return 6 - if value <= 0x1ffffffffffff: - return 7 - if value <= 0xffffffffffffff: - return 8 - if value <= 0x7fffffffffffffff: - return 9 - return 10 - - -if six.PY3: - def _read_byte(memview, pos): - """ Read a byte from memoryview as an integer - - Raises: - IndexError: if position is out of bounds - """ - return memview[pos] -else: - def _read_byte(memview, pos): - """ Read a byte from memoryview as an integer - - Raises: - IndexError: if position is out of bounds - """ - return ord(memview[pos]) - - -def decode_varint_1(buffer, pos=0): - """ Decode an integer from a varint presentation. See - https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints - on how those can be produced. - - Arguments: - buffer (bytes-like): any object acceptable by ``memoryview`` - pos (int): optional position to read from - - Returns: - (int, int): Decoded int value and next read position - """ - value = 0 - shift = 0 - memview = memoryview(buffer) - for i in range(pos, pos + 10): - try: - byte = _read_byte(memview, i) - except IndexError: - raise ValueError("End of byte stream") - if byte & 0x80 != 0: - value |= (byte & 0x7f) << shift - shift += 7 - else: - value |= byte << shift - break - else: - # Max size of endcoded double is 10 bytes for unsigned values - raise ValueError("Out of double range") - # Normalize sign - return (value >> 1) ^ -(value & 1), i + 1 - - -def decode_varint_2(buffer, pos=0): - result = 0 - shift = 0 - while 1: - b = buffer[pos] - result |= ((b & 0x7f) << shift) - pos += 1 - if not (b & 0x80): - # result = result_type(() & mask) - return ((result >> 1) ^ -(result & 1), pos) - shift += 7 - if shift >= 64: - raise ValueError("Out of int64 range") - - -def decode_varint_3(buffer, pos=0): - result = buffer[pos] - if not (result & 0x81): - return (result >> 1), pos + 1 - if not (result & 0x80): - return (result >> 1) ^ (~0), pos + 1 - - result &= 0x7f - pos += 1 - shift = 7 - while 1: - b = buffer[pos] - result |= ((b & 0x7f) << shift) - pos += 1 - if not (b & 0x80): - return ((result >> 1) ^ -(result & 1), pos) - shift += 7 - if shift >= 64: - raise ValueError("Out of int64 range") - - -if __name__ == '__main__': - _assert_valid_enc(encode_varint_1) - _assert_valid_enc(encode_varint_2) - - for encoded, decoded in test_data: - res = bytearray() - encode_varint_3(decoded, res) - assert res == encoded - - _assert_valid_enc(encode_varint_4) - - # import dis - # dis.dis(encode_varint_4) - - for encoded, decoded in test_data: - res = bytearray(10) - written = encode_varint_5(decoded, res) - assert res[:written] == encoded - - for encoded, decoded in test_data: - res = bytearray() - encode_varint_6(decoded, res) - assert res == encoded - - _assert_valid_size(size_of_varint_1) - _assert_valid_size(size_of_varint_2) - _assert_valid_dec(decode_varint_1) - _assert_valid_dec(decode_varint_2) - _assert_valid_dec(decode_varint_3) - - # import dis - # dis.dis(decode_varint_3) - - runner = 
pyperf.Runner() - # Encode algorithms returning a bytes result - for bench_func in [ - encode_varint_1, - encode_varint_2, - encode_varint_4]: - for i, value in enumerate(BENCH_VALUES_ENC): - runner.bench_func( - '{}_{}byte'.format(bench_func.__name__, i + 1), - bench_func, value) - - # Encode algorithms writing to the buffer - for bench_func in [ - encode_varint_3, - encode_varint_5, - encode_varint_6]: - for i, value in enumerate(BENCH_VALUES_ENC): - fname = bench_func.__name__ - runner.timeit( - '{}_{}byte'.format(fname, i + 1), - stmt="{}({}, buffer)".format(fname, value), - setup="from __main__ import {}; buffer = bytearray(10)".format( - fname) - ) - - # Size algorithms - for bench_func in [ - size_of_varint_1, - size_of_varint_2]: - for i, value in enumerate(BENCH_VALUES_ENC): - runner.bench_func( - '{}_{}byte'.format(bench_func.__name__, i + 1), - bench_func, value) - - # Decode algorithms - for bench_func in [ - decode_varint_1, - decode_varint_2, - decode_varint_3]: - for i, value in enumerate(BENCH_VALUES_DEC): - runner.bench_func( - '{}_{}byte'.format(bench_func.__name__, i + 1), - bench_func, value) diff --git a/venv/lib/python3.12/site-packages/kafka/client_async.py b/venv/lib/python3.12/site-packages/kafka/client_async.py index 7d46657..58f22d4 100644 --- a/venv/lib/python3.12/site-packages/kafka/client_async.py +++ b/venv/lib/python3.12/site-packages/kafka/client_async.py @@ -19,18 +19,17 @@ except ImportError: from kafka.vendor import six from kafka.cluster import ClusterMetadata -from kafka.conn import BrokerConnection, ConnectionStates, get_ip_port_afi +from kafka.conn import BrokerConnection, ConnectionStates, collect_hosts, get_ip_port_afi from kafka import errors as Errors from kafka.future import Future from kafka.metrics import AnonMeasurable from kafka.metrics.stats import Avg, Count, Rate from kafka.metrics.stats.rate import TimeUnit -from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS from kafka.protocol.metadata import MetadataRequest -from kafka.util import Dict, Timer, WeakMethod, ensure_valid_topic_name +from kafka.util import Dict, WeakMethod # Although this looks unused, it actually monkey-patches socket.socketpair() # and should be left in as long as we're using socket.socketpair() in this file -from kafka.vendor import socketpair # noqa: F401 +from kafka.vendor import socketpair from kafka.version import __version__ if six.PY2: @@ -76,7 +75,7 @@ class KafkaClient(object): reconnection attempts will continue periodically with this fixed rate. To avoid connection storms, a randomization factor of 0.2 will be applied to the backoff resulting in a random range between - 20% below and 20% above the computed value. Default: 30000. + 20% below and 20% above the computed value. Default: 1000. request_timeout_ms (int): Client request timeout in milliseconds. Default: 30000. connections_max_idle_ms: Close idle connections after the number of @@ -102,9 +101,6 @@ class KafkaClient(object): which we force a refresh of metadata even if we haven't seen any partition leadership changes to proactively discover any new brokers or partitions. Default: 300000 - allow_auto_create_topics (bool): Enable/disable auto topic creation - on metadata request. Only available with api_version >= (0, 11). - Default: True security_protocol (str): Protocol used to communicate with brokers. Valid values are: PLAINTEXT, SSL, SASL_PLAINTEXT, SASL_SSL. Default: PLAINTEXT. @@ -133,24 +129,12 @@ class KafkaClient(object): format. 
If no cipher can be selected (because compile-time options or other configuration forbids use of all the specified ciphers), an ssl.SSLError will be raised. See ssl.SSLContext.set_ciphers - api_version (tuple): Specify which Kafka API version to use. If set to - None, the client will attempt to determine the broker version via - ApiVersionsRequest API or, for brokers earlier than 0.10, probing - various known APIs. Dynamic version checking is performed eagerly - during __init__ and can raise NoBrokersAvailableError if no connection - was made before timeout (see api_version_auto_timeout_ms below). - Different versions enable different functionality. - - Examples: - (3, 9) most recent broker release, enable all supported features - (0, 10, 0) enables sasl authentication - (0, 8, 0) enables basic functionality only - - Default: None + api_version (tuple): Specify which Kafka API version to use. If set + to None, KafkaClient will attempt to infer the broker version by + probing various APIs. Example: (0, 10, 2). Default: None api_version_auto_timeout_ms (int): number of milliseconds to throw a timeout exception from the constructor when checking the broker - api version. Only applies if api_version set to None. - Default: 2000 + api version. Only applies if api_version is None selector (selectors.BaseSelector): Provide a specific selector implementation to use for I/O multiplexing. Default: selectors.DefaultSelector @@ -164,16 +148,12 @@ class KafkaClient(object): Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. sasl_plain_password (str): password for sasl PLAIN and SCRAM authentication. Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. - sasl_kerberos_name (str or gssapi.Name): Constructed gssapi.Name for use with - sasl mechanism handshake. If provided, sasl_kerberos_service_name and - sasl_kerberos_domain name are ignored. Default: None. sasl_kerberos_service_name (str): Service name to include in GSSAPI sasl mechanism handshake. Default: 'kafka' sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI sasl mechanism handshake. Default: one of bootstrap servers - sasl_oauth_token_provider (kafka.sasl.oauth.AbstractTokenProvider): OAuthBearer - token provider instance. Default: None - socks5_proxy (str): Socks5 proxy URL. Default: None + sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider + instance. (See kafka.oauth.abstract). 
Default: None """ DEFAULT_CONFIG = { @@ -184,7 +164,7 @@ class KafkaClient(object): 'wakeup_timeout_ms': 3000, 'connections_max_idle_ms': 9 * 60 * 1000, 'reconnect_backoff_ms': 50, - 'reconnect_backoff_max_ms': 30000, + 'reconnect_backoff_max_ms': 1000, 'max_in_flight_requests_per_connection': 5, 'receive_buffer_bytes': None, 'send_buffer_bytes': None, @@ -192,7 +172,6 @@ class KafkaClient(object): 'sock_chunk_bytes': 4096, # undocumented experimental option 'sock_chunk_buffer_count': 1000, # undocumented experimental option 'retry_backoff_ms': 100, - 'allow_auto_create_topics': True, 'metadata_max_age_ms': 300000, 'security_protocol': 'PLAINTEXT', 'ssl_context': None, @@ -211,11 +190,9 @@ class KafkaClient(object): 'sasl_mechanism': None, 'sasl_plain_username': None, 'sasl_plain_password': None, - 'sasl_kerberos_name': None, 'sasl_kerberos_service_name': 'kafka', 'sasl_kerberos_domain_name': None, - 'sasl_oauth_token_provider': None, - 'socks5_proxy': None, + 'sasl_oauth_token_provider': None } def __init__(self, **configs): @@ -227,9 +204,8 @@ class KafkaClient(object): # these properties need to be set on top of the initialization pipeline # because they are used when __del__ method is called self._closed = False + self._wake_r, self._wake_w = socket.socketpair() self._selector = self.config['selector']() - self._init_wakeup_socketpair() - self._wake_lock = threading.Lock() self.cluster = ClusterMetadata(**self.config) self._topics = set() # empty set will fetch all topic metadata @@ -238,10 +214,12 @@ class KafkaClient(object): self._api_versions = None self._connecting = set() self._sending = set() - - # Not currently used, but data is collected internally + self._refresh_on_disconnects = True self._last_bootstrap = 0 self._bootstrap_fails = 0 + self._wake_r.setblocking(False) + self._wake_w.settimeout(self.config['wakeup_timeout_ms'] / 1000.0) + self._wake_lock = threading.Lock() self._lock = threading.RLock() @@ -250,6 +228,7 @@ class KafkaClient(object): # lock above. 
self._pending_completion = collections.deque() + self._selector.register(self._wake_r, selectors.EVENT_READ) self._idle_expiry_manager = IdleConnectionManager(self.config['connections_max_idle_ms']) self._sensors = None if self.config['metrics']: @@ -257,48 +236,26 @@ class KafkaClient(object): self.config['metric_group_prefix'], weakref.proxy(self._conns)) + self._num_bootstrap_hosts = len(collect_hosts(self.config['bootstrap_servers'])) + # Check Broker Version if not set explicitly if self.config['api_version'] is None: - self.config['api_version'] = self.check_version() - elif self.config['api_version'] in BROKER_API_VERSIONS: - self._api_versions = BROKER_API_VERSIONS[self.config['api_version']] - elif (self.config['api_version'] + (0,)) in BROKER_API_VERSIONS: - log.warning('Configured api_version %s is ambiguous; using %s', - self.config['api_version'], self.config['api_version'] + (0,)) - self.config['api_version'] = self.config['api_version'] + (0,) - self._api_versions = BROKER_API_VERSIONS[self.config['api_version']] - else: - compatible_version = None - for v in sorted(BROKER_API_VERSIONS.keys(), reverse=True): - if v <= self.config['api_version']: - compatible_version = v - break - if compatible_version: - log.warning('Configured api_version %s not supported; using %s', - self.config['api_version'], compatible_version) - self.config['api_version'] = compatible_version - self._api_versions = BROKER_API_VERSIONS[compatible_version] - else: - raise Errors.UnrecognizedBrokerVersion(self.config['api_version']) + check_timeout = self.config['api_version_auto_timeout_ms'] / 1000 + self.config['api_version'] = self.check_version(timeout=check_timeout) - def _init_wakeup_socketpair(self): - self._wake_r, self._wake_w = socket.socketpair() - self._wake_r.setblocking(False) - self._wake_w.settimeout(self.config['wakeup_timeout_ms'] / 1000.0) - self._waking = False - self._selector.register(self._wake_r, selectors.EVENT_READ) + def _can_bootstrap(self): + effective_failures = self._bootstrap_fails // self._num_bootstrap_hosts + backoff_factor = 2 ** effective_failures + backoff_ms = min(self.config['reconnect_backoff_ms'] * backoff_factor, + self.config['reconnect_backoff_max_ms']) - def _close_wakeup_socketpair(self): - if self._wake_r is not None: - try: - self._selector.unregister(self._wake_r) - except (KeyError, ValueError, TypeError): - pass - self._wake_r.close() - if self._wake_w is not None: - self._wake_w.close() - self._wake_r = None - self._wake_w = None + backoff_ms *= random.uniform(0.8, 1.2) + + next_at = self._last_bootstrap + backoff_ms / 1000.0 + now = time.time() + if next_at > now: + return False + return True def _can_connect(self, node_id): if node_id not in self._conns: @@ -310,7 +267,7 @@ class KafkaClient(object): def _conn_state_change(self, node_id, sock, conn): with self._lock: - if conn.state is ConnectionStates.CONNECTING: + if conn.connecting(): # SSL connections can enter this state 2x (second during Handshake) if node_id not in self._connecting: self._connecting.add(node_id) @@ -322,19 +279,7 @@ class KafkaClient(object): if self.cluster.is_bootstrap(node_id): self._last_bootstrap = time.time() - elif conn.state is ConnectionStates.API_VERSIONS_SEND: - try: - self._selector.register(sock, selectors.EVENT_WRITE, conn) - except KeyError: - self._selector.modify(sock, selectors.EVENT_WRITE, conn) - - elif conn.state in (ConnectionStates.API_VERSIONS_RECV, ConnectionStates.AUTHENTICATING): - try: - self._selector.register(sock, selectors.EVENT_READ, conn) - 
except KeyError: - self._selector.modify(sock, selectors.EVENT_READ, conn) - - elif conn.state is ConnectionStates.CONNECTED: + elif conn.connected(): log.debug("Node %s connected", node_id) if node_id in self._connecting: self._connecting.remove(node_id) @@ -351,8 +296,6 @@ class KafkaClient(object): if self.cluster.is_bootstrap(node_id): self._bootstrap_fails = 0 - if self._api_versions is None: - self._api_versions = conn._api_versions else: for node_id in list(self._conns.keys()): @@ -365,7 +308,7 @@ class KafkaClient(object): self._connecting.remove(node_id) try: self._selector.unregister(sock) - except (KeyError, ValueError): + except KeyError: pass if self._sensors: @@ -384,7 +327,7 @@ class KafkaClient(object): elif self.cluster.is_bootstrap(node_id): self._bootstrap_fails += 1 - elif conn.connect_failed() and not self._closed and not idle_disconnect: + elif self._refresh_on_disconnects and not self._closed and not idle_disconnect: log.warning("Node %s connection failed -- refreshing metadata", node_id) self.cluster.request_update() @@ -400,11 +343,6 @@ class KafkaClient(object): return True return False - def connection_failed(self, node_id): - if node_id not in self._conns: - return False - return self._conns[node_id].connect_failed() - def _should_recycle_connection(self, conn): # Never recycle unless disconnected if not conn.disconnected(): @@ -415,7 +353,7 @@ class KafkaClient(object): if broker is None: return False - host, _, _ = get_ip_port_afi(broker.host) + host, _, afi = get_ip_port_afi(broker.host) if conn.host != host or conn.port != broker.port: log.info("Broker metadata change detected for node %s" " from %s:%s to %s:%s", conn.node_id, conn.host, conn.port, @@ -424,24 +362,14 @@ class KafkaClient(object): return False - def _init_connect(self, node_id): - """Idempotent non-blocking connection attempt to the given node id. - - Returns True if connection object exists and is connected / connecting - """ + def _maybe_connect(self, node_id): + """Idempotent non-blocking connection attempt to the given node id.""" with self._lock: conn = self._conns.get(node_id) - # Check if existing connection should be recreated because host/port changed - if conn is not None and self._should_recycle_connection(conn): - self._conns.pop(node_id).close() - conn = None - if conn is None: broker = self.cluster.broker_metadata(node_id) - if broker is None: - log.debug('Broker id %s not in current metadata', node_id) - return False + assert broker, 'Broker id %s not in current metadata' % (node_id,) log.debug("Initiating connection to node %s at %s:%s", node_id, broker.host, broker.port) @@ -453,9 +381,16 @@ class KafkaClient(object): **self.config) self._conns[node_id] = conn - if conn.disconnected(): - conn.connect() - return not conn.disconnected() + # Check if existing connection should be recreated because host/port changed + elif self._should_recycle_connection(conn): + self._conns.pop(node_id) + return False + + elif conn.connected(): + return True + + conn.connect() + return conn.connected() def ready(self, node_id, metadata_priority=True): """Check whether a node is connected and ok to send more requests. 
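The _can_bootstrap() helper added earlier in this file's diff gates bootstrap reconnects with exponential backoff plus jitter. The standalone sketch below reproduces just that arithmetic so the doubling and the +/-20% randomization are easy to see; the function name and the worked numbers are illustrative, while the formula and the 50 ms / 1000 ms defaults come from the configuration in this patch.

import random

def bootstrap_backoff_s(bootstrap_fails, num_bootstrap_hosts,
                        reconnect_backoff_ms=50, reconnect_backoff_max_ms=1000):
    # One "effective" failure per full pass over the bootstrap host list,
    # doubling the delay each time up to the configured ceiling.
    effective_failures = bootstrap_fails // num_bootstrap_hosts
    backoff_ms = min(reconnect_backoff_ms * (2 ** effective_failures),
                     reconnect_backoff_max_ms)
    # +/-20% jitter, as in _can_bootstrap(), to avoid connection storms.
    backoff_ms *= random.uniform(0.8, 1.2)
    return backoff_ms / 1000.0

# Example: 4 failures against 2 bootstrap hosts -> 50 ms * 2**2 = 200 ms,
# jittered to roughly 0.16-0.24 s before the next bootstrap attempt.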
@@ -481,7 +416,8 @@ class KafkaClient(object): def _close(self): if not self._closed: self._closed = True - self._close_wakeup_socketpair() + self._wake_r.close() + self._wake_w.close() self._selector.close() def close(self, node_id=None): @@ -528,8 +464,9 @@ class KafkaClient(object): def connection_delay(self, node_id): """ Return the number of milliseconds to wait, based on the connection - state, before attempting to send data. When connecting or disconnected, - this respects the reconnect backoff time. When connected, returns a very large + state, before attempting to send data. When disconnected, this respects + the reconnect backoff time. When connecting, returns 0 to allow + non-blocking connect to finish. When connected, returns a very large number to handle slow/stalled connections. Arguments: @@ -543,16 +480,6 @@ class KafkaClient(object): return 0 return conn.connection_delay() - def throttle_delay(self, node_id): - """ - Return the number of milliseconds to wait until a broker is no longer throttled. - When disconnected / connecting, returns 0. - """ - conn = self._conns.get(node_id) - if conn is None: - return 0 - return conn.throttle_delay() - def is_ready(self, node_id, metadata_priority=True): """Check whether a node is ready to send more requests. @@ -585,7 +512,7 @@ class KafkaClient(object): return False return conn.connected() and conn.can_send_more() - def send(self, node_id, request, wakeup=True, request_timeout_ms=None): + def send(self, node_id, request, wakeup=True): """Send a request to a specific node. Bytes are placed on an internal per-connection send-queue. Actual network I/O will be triggered in a subsequent call to .poll() @@ -593,13 +520,7 @@ class KafkaClient(object): Arguments: node_id (int): destination node request (Struct): request object (not-encoded) - - Keyword Arguments: - wakeup (bool, optional): optional flag to disable thread-wakeup. - request_timeout_ms (int, optional): Provide custom timeout in milliseconds. - If response is not processed before timeout, client will fail the - request and close the connection. 
- Default: None (uses value from client configuration) + wakeup (bool): optional flag to disable thread-wakeup Raises: AssertionError: if node_id is not in current cluster metadata @@ -615,9 +536,8 @@ class KafkaClient(object): # conn.send will queue the request internally # we will need to call send_pending_requests() # to trigger network I/O - future = conn.send(request, blocking=False, request_timeout_ms=request_timeout_ms) - if not future.is_done: - self._sending.add(conn) + future = conn.send(request, blocking=False) + self._sending.add(conn) # Wakeup signal is useful in case another thread is # blocked waiting for incoming network traffic while holding @@ -643,9 +563,12 @@ class KafkaClient(object): Returns: list: responses received (can be empty) """ - if not isinstance(timeout_ms, (int, float, type(None))): + if future is not None: + timeout_ms = 100 + elif timeout_ms is None: + timeout_ms = self.config['request_timeout_ms'] + elif not isinstance(timeout_ms, (int, float)): raise TypeError('Invalid type for timeout: %s' % type(timeout_ms)) - timer = Timer(timeout_ms) # Loop for futures, break after first loop if None responses = [] @@ -656,30 +579,24 @@ class KafkaClient(object): # Attempt to complete pending connections for node_id in list(self._connecting): - # False return means no more connection progress is possible - # Connected nodes will update _connecting via state_change callback - if not self._init_connect(node_id): - # It's possible that the connection attempt triggered a state change - # but if not, make sure to remove from _connecting list - if node_id in self._connecting: - self._connecting.remove(node_id) + self._maybe_connect(node_id) - # Send a metadata request if needed (or initiate new connection) + # Send a metadata request if needed metadata_timeout_ms = self._maybe_refresh_metadata() # If we got a future that is already done, don't block in _poll if future is not None and future.is_done: timeout = 0 else: - user_timeout_ms = timer.timeout_ms if timeout_ms is not None else self.config['request_timeout_ms'] idle_connection_timeout_ms = self._idle_expiry_manager.next_check_ms() - request_timeout_ms = self._next_ifr_request_timeout_ms() - log.debug("Timeouts: user %f, metadata %f, idle connection %f, request %f", user_timeout_ms, metadata_timeout_ms, idle_connection_timeout_ms, request_timeout_ms) timeout = min( - user_timeout_ms, + timeout_ms, metadata_timeout_ms, idle_connection_timeout_ms, - request_timeout_ms) + self.config['request_timeout_ms']) + # if there are no requests in flight, do not block longer than the retry backoff + if self.in_flight_request_count() == 0: + timeout = min(timeout, self.config['retry_backoff_ms']) timeout = max(0, timeout) # avoid negative timeouts self._poll(timeout / 1000) @@ -690,11 +607,7 @@ class KafkaClient(object): # If all we had was a timeout (future is None) - only do one poll # If we do have a future, we keep looping until it is done - if future is None: - break - elif future.is_done: - break - elif timeout_ms is not None and timer.expired: + if future is None or future.is_done: break return responses @@ -702,8 +615,6 @@ class KafkaClient(object): def _register_send_sockets(self): while self._sending: conn = self._sending.pop() - if conn._sock is None: - continue try: key = self._selector.get_key(conn._sock) events = key.events | selectors.EVENT_WRITE @@ -712,11 +623,6 @@ class KafkaClient(object): self._selector.register(conn._sock, selectors.EVENT_WRITE, conn) def _poll(self, timeout): - # Python throws OverflowError 
if timeout is > 2147483647 milliseconds - # (though the param to selector.select is in seconds) - # so convert any too-large timeout to blocking - if timeout > 2147483: - timeout = None # This needs to be locked, but since it is only called from within the # locked section of poll(), there is no additional lock acquisition here processed = set() @@ -789,13 +695,11 @@ class KafkaClient(object): for conn in six.itervalues(self._conns): if conn.requests_timed_out(): - timed_out = conn.timed_out_ifrs() - timeout_ms = (timed_out[0][2] - timed_out[0][1]) * 1000 log.warning('%s timed out after %s ms. Closing connection.', - conn, timeout_ms) + conn, conn.config['request_timeout_ms']) conn.close(error=Errors.RequestTimedOutError( 'Request timed out after %s ms' % - timeout_ms)) + conn.config['request_timeout_ms'])) if self._sensors: self._sensors.io_time.record((time.time() - end_select) * 1000000000) @@ -833,17 +737,16 @@ class KafkaClient(object): break future.success(response) responses.append(response) - return responses def least_loaded_node(self): """Choose the node with fewest outstanding requests, with fallbacks. - This method will prefer a node with an existing connection (not throttled) - with no in-flight-requests. If no such node is found, a node will be chosen - randomly from all nodes that are not throttled or "blacked out" (i.e., + This method will prefer a node with an existing connection and no + in-flight-requests. If no such node is found, a node will be chosen + randomly from disconnected nodes that are not "blacked out" (i.e., are not subject to a reconnect backoff). If no node metadata has been - obtained, will return a bootstrap node. + obtained, will return a bootstrap node (subject to exponential backoff). Returns: node_id or None if no suitable node was found @@ -855,11 +758,11 @@ class KafkaClient(object): found = None for node_id in nodes: conn = self._conns.get(node_id) - connected = conn is not None and conn.connected() and conn.can_send_more() - blacked_out = conn is not None and (conn.blacked_out() or conn.throttled()) + connected = conn is not None and conn.connected() + blacked_out = conn is not None and conn.blacked_out() curr_inflight = len(conn.in_flight_requests) if conn is not None else 0 if connected and curr_inflight == 0: - # if we find an established connection (not throttled) + # if we find an established connection # with no in-flight requests, we can stop right away return node_id elif not blacked_out and curr_inflight < inflight: @@ -869,24 +772,6 @@ class KafkaClient(object): return found - def _refresh_delay_ms(self, node_id): - conn = self._conns.get(node_id) - if conn is not None and conn.connected(): - return self.throttle_delay(node_id) - else: - return self.connection_delay(node_id) - - def least_loaded_node_refresh_ms(self): - """Return connection or throttle delay in milliseconds for next available node. - - This method is used primarily for retry/backoff during metadata refresh - during / after a cluster outage, in which there are no available nodes. - - Returns: - float: delay_ms - """ - return min([self._refresh_delay_ms(broker.nodeId) for broker in self.cluster.brokers()]) - def set_topics(self, topics): """Set specific topics to track for metadata. 
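# Illustrative sketch (not part of the patch): the least_loaded_node() rule
# shown above -- return a connected node with zero in-flight requests as soon
# as one is found, otherwise the non-blacked-out node with the fewest
# in-flight requests. `conns` maps node_id to objects exposing connected(),
# blacked_out() and in_flight_requests, mirroring (but not being) the client.
def pick_least_loaded(node_ids, conns):
    inflight = float('inf')
    found = None
    for node_id in node_ids:
        conn = conns.get(node_id)
        connected = conn is not None and conn.connected()
        blacked_out = conn is not None and conn.blacked_out()
        curr_inflight = len(conn.in_flight_requests) if conn is not None else 0
        if connected and curr_inflight == 0:
            return node_id      # established and idle: stop right away
        elif not blacked_out and curr_inflight < inflight:
            inflight = curr_inflight
            found = node_id     # best fallback candidate so far
    return found                # None if every node is blacked out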
@@ -911,31 +796,19 @@ class KafkaClient(object): Returns: Future: resolves after metadata request/response - - Raises: - TypeError: if topic is not a string - ValueError: if topic is invalid: must be chars (a-zA-Z0-9._-), and less than 250 length """ - ensure_valid_topic_name(topic) - if topic in self._topics: return Future().success(set(self._topics)) self._topics.add(topic) return self.cluster.request_update() - def _next_ifr_request_timeout_ms(self): - if self._conns: - return min([conn.next_ifr_request_timeout_ms() for conn in six.itervalues(self._conns)]) - else: - return float('inf') - # This method should be locked when running multi-threaded def _maybe_refresh_metadata(self, wakeup=False): """Send a metadata request if needed. Returns: - float: milliseconds until next refresh + int: milliseconds until next refresh """ ttl = self.cluster.ttl() wait_for_in_progress_ms = self.config['request_timeout_ms'] if self._metadata_refresh_in_progress else 0 @@ -949,44 +822,18 @@ class KafkaClient(object): # least_loaded_node() node_id = self.least_loaded_node() if node_id is None: - next_connect_ms = self.least_loaded_node_refresh_ms() - log.debug("Give up sending metadata request since no node is available. (reconnect delay %d ms)", next_connect_ms) - return next_connect_ms + log.debug("Give up sending metadata request since no node is available"); + return self.config['reconnect_backoff_ms'] - if not self._can_send_request(node_id): - # If there's any connection establishment underway, wait until it completes. This prevents - # the client from unnecessarily connecting to additional nodes while a previous connection - # attempt has not been completed. - if self._connecting: - return float('inf') - - elif self._can_connect(node_id): - log.debug("Initializing connection to node %s for metadata request", node_id) - self._connecting.add(node_id) - if not self._init_connect(node_id): - if node_id in self._connecting: - self._connecting.remove(node_id) - # Connection attempt failed immediately, need to retry with a different node - return self.config['reconnect_backoff_ms'] - else: - # Existing connection throttled or max in flight requests. - return self.throttle_delay(node_id) or self.config['request_timeout_ms'] - - # Recheck node_id in case we were able to connect immediately above if self._can_send_request(node_id): topics = list(self._topics) if not topics and self.cluster.is_bootstrap(node_id): topics = list(self.config['bootstrap_topics_filter']) - api_version = self.api_version(MetadataRequest, max_version=7) - if self.cluster.need_all_topic_metadata: - topics = MetadataRequest[api_version].ALL_TOPICS - elif not topics: - topics = MetadataRequest[api_version].NO_TOPICS - if api_version >= 4: - request = MetadataRequest[api_version](topics, self.config['allow_auto_create_topics']) - else: - request = MetadataRequest[api_version](topics) + if self.cluster.need_all_topic_metadata or not topics: + topics = [] if self.config['api_version'] < (0, 10) else None + api_version = 0 if self.config['api_version'] < (0, 10) else 1 + request = MetadataRequest[api_version](topics) log.debug("Sending metadata request %s to node %s", request, node_id) future = self.send(node_id, request, wakeup=wakeup) future.add_callback(self.cluster.update_metadata) @@ -999,146 +846,103 @@ class KafkaClient(object): future.add_errback(refresh_done) return self.config['request_timeout_ms'] - # Should only get here if still connecting + # If there's any connection establishment underway, wait until it completes. 
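# Illustrative sketch (not part of the patch): the topic-list / request-version
# selection in the metadata hunk above, post-change. Brokers older than 0.10
# get MetadataRequest v0, where an empty list means "all topics"; 0.10+ gets
# v1, where None means "all topics". Plain values stand in for request classes.
def metadata_request_args(config_api_version, topics, need_all_topic_metadata):
    if need_all_topic_metadata or not topics:
        topics = [] if config_api_version < (0, 10) else None
    api_version = 0 if config_api_version < (0, 10) else 1
    return api_version, topics

# metadata_request_args((0, 9), [], True)      -> (0, [])
# metadata_request_args((1, 0), [], True)      -> (1, None)
# metadata_request_args((1, 0), ['t1'], False) -> (1, ['t1'])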
This prevents + # the client from unnecessarily connecting to additional nodes while a previous connection + # attempt has not been completed. if self._connecting: - return float('inf') - else: return self.config['reconnect_backoff_ms'] + if self.maybe_connect(node_id, wakeup=wakeup): + log.debug("Initializing connection to node %s for metadata request", node_id) + return self.config['reconnect_backoff_ms'] + + # connected but can't send more, OR connecting + # In either case we just need to wait for a network event + # to let us know the selected connection might be usable again. + return float('inf') + def get_api_versions(self): """Return the ApiVersions map, if available. - Note: Only available after bootstrap; requires broker version 0.10.0 or later. + Note: A call to check_version must previously have succeeded and returned + version 0.10.0 or later Returns: a map of dict mapping {api_key : (min_version, max_version)}, or None if ApiVersion is not supported by the kafka cluster. """ return self._api_versions - def check_version(self, node_id=None, timeout=None, **kwargs): + def check_version(self, node_id=None, timeout=2, strict=False): """Attempt to guess the version of a Kafka broker. - Keyword Arguments: - node_id (str, optional): Broker node id from cluster metadata. If None, attempts - to connect to any available broker until version is identified. - Default: None - timeout (num, optional): Maximum time in seconds to try to check broker version. - If unable to identify version before timeout, raise error (see below). - Default: api_version_auto_timeout_ms / 1000 + Note: It is possible that this method blocks longer than the + specified timeout. This can happen if the entire cluster + is down and the client enters a bootstrap backoff sleep. + This is only possible if node_id is None. - Returns: version tuple, i.e. (3, 9), (2, 0), (0, 10, 2) etc + Returns: version tuple, i.e. (0, 10), (0, 9), (0, 8, 2), ... Raises: NodeNotReadyError (if node_id is provided) NoBrokersAvailable (if node_id is None) + UnrecognizedBrokerVersion: please file bug if seen! + AssertionError (if strict=True): please file bug if seen! 
""" - timeout = timeout or (self.config['api_version_auto_timeout_ms'] / 1000) - with self._lock: - end = time.time() + timeout - while time.time() < end: - time_remaining = max(end - time.time(), 0) - if node_id is not None and self.connection_delay(node_id) > 0: - sleep_time = min(time_remaining, self.connection_delay(node_id) / 1000.0) - if sleep_time > 0: - time.sleep(sleep_time) - continue - try_node = node_id or self.least_loaded_node() - if try_node is None: - sleep_time = min(time_remaining, self.least_loaded_node_refresh_ms() / 1000.0) - if sleep_time > 0: - log.warning('No node available during check_version; sleeping %.2f secs', sleep_time) - time.sleep(sleep_time) - continue - log.debug('Attempting to check version with node %s', try_node) - if not self._init_connect(try_node): - if try_node == node_id: - raise Errors.NodeNotReadyError("Connection failed to %s" % node_id) - else: - continue - conn = self._conns[try_node] + self._lock.acquire() + end = time.time() + timeout + while time.time() < end: - while conn.connecting() and time.time() < end: - timeout_ms = min((end - time.time()) * 1000, 200) - self.poll(timeout_ms=timeout_ms) + # It is possible that least_loaded_node falls back to bootstrap, + # which can block for an increasing backoff period + try_node = node_id or self.least_loaded_node() + if try_node is None: + self._lock.release() + raise Errors.NoBrokersAvailable() + self._maybe_connect(try_node) + conn = self._conns[try_node] - if conn._api_version is not None: - return conn._api_version - else: - log.debug('Failed to identify api_version after connection attempt to %s', conn) - - # Timeout - else: + # We will intentionally cause socket failures + # These should not trigger metadata refresh + self._refresh_on_disconnects = False + try: + remaining = end - time.time() + version = conn.check_version(timeout=remaining, strict=strict, topics=list(self.config['bootstrap_topics_filter'])) + if version >= (0, 10, 0): + # cache the api versions map if it's available (starting + # in 0.10 cluster version) + self._api_versions = conn.get_api_versions() + self._lock.release() + return version + except Errors.NodeNotReadyError: + # Only raise to user if this is a node-specific request if node_id is not None: - raise Errors.NodeNotReadyError(node_id) - else: - raise Errors.NoBrokersAvailable() + self._lock.release() + raise + finally: + self._refresh_on_disconnects = True - def api_version(self, operation, max_version=None): - """Find the latest version of the protocol operation supported by both - this library and the broker. - - This resolves to the lesser of either the latest api version this - library supports, or the max version supported by the broker. - - Arguments: - operation: A list of protocol operation versions from kafka.protocol. - - Keyword Arguments: - max_version (int, optional): Provide an alternate maximum api version - to reflect limitations in user code. - - Returns: - int: The highest api version number compatible between client and broker. - - Raises: IncompatibleBrokerVersion if no matching version is found - """ - # Cap max_version at the largest available version in operation list - max_version = min(len(operation) - 1, max_version if max_version is not None else float('inf')) - broker_api_versions = self._api_versions - api_key = operation[0].API_KEY - if broker_api_versions is None or api_key not in broker_api_versions: - raise Errors.IncompatibleBrokerVersion( - "Kafka broker does not support the '{}' Kafka protocol." 
- .format(operation[0].__name__)) - broker_min_version, broker_max_version = broker_api_versions[api_key] - version = min(max_version, broker_max_version) - if version < broker_min_version: - # max library version is less than min broker version. Currently, - # no Kafka versions specify a min msg version. Maybe in the future? - raise Errors.IncompatibleBrokerVersion( - "No version of the '{}' Kafka protocol is supported by both the client and broker." - .format(operation[0].__name__)) - return version + # Timeout + else: + self._lock.release() + raise Errors.NoBrokersAvailable() def wakeup(self): - if self._closed or self._waking or self._wake_w is None: - return with self._wake_lock: try: self._wake_w.sendall(b'x') - self._waking = True - except socket.timeout as e: + except socket.timeout: log.warning('Timeout to send to wakeup socket!') - raise Errors.KafkaTimeoutError(e) - except socket.error as e: - log.warning('Unable to send to wakeup socket! %s', e) - raise e + raise Errors.KafkaTimeoutError() + except socket.error: + log.warning('Unable to send to wakeup socket!') def _clear_wake_fd(self): # reading from wake socket should only happen in a single thread - with self._wake_lock: - self._waking = False - while True: - try: - if not self._wake_r.recv(1024): - # Non-blocking socket returns empty on error - log.warning("Error reading wakeup socket. Rebuilding socketpair.") - self._close_wakeup_socketpair() - self._init_wakeup_socketpair() - break - except socket.error: - # Non-blocking socket raises when socket is ok but no data available to read - break + while True: + try: + self._wake_r.recv(1024) + except socket.error: + break def _maybe_close_oldest_connection(self): expired_connection = self._idle_expiry_manager.poll_expired_connection() @@ -1158,39 +962,6 @@ class KafkaClient(object): else: return False - def await_ready(self, node_id, timeout_ms=30000): - """ - Invokes `poll` to discard pending disconnects, followed by `client.ready` and 0 or more `client.poll` - invocations until the connection to `node` is ready, the timeoutMs expires or the connection fails. - - It returns `true` if the call completes normally or `false` if the timeoutMs expires. If the connection fails, - an `IOException` is thrown instead. Note that if the `NetworkClient` has been configured with a positive - connection timeoutMs, it is possible for this method to raise an `IOException` for a previous connection which - has recently disconnected. - - This method is useful for implementing blocking behaviour on top of the non-blocking `NetworkClient`, use it with - care. - """ - timer = Timer(timeout_ms) - self.poll(timeout_ms=0) - if self.is_ready(node_id): - return True - - while not self.is_ready(node_id) and not timer.expired: - if self.connection_failed(node_id): - raise Errors.KafkaConnectionError("Connection to %s failed." 
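# Illustrative, self-contained sketch (not part of the patch) of the socketpair
# wake-up pattern behind wakeup()/_clear_wake_fd() above: a byte written to one
# end of a non-blocking socketpair makes a blocked selector return, and the
# reader then drains the socket. Generic demo, not the client's own code.
import selectors
import socket
import threading

wake_r, wake_w = socket.socketpair()
wake_r.setblocking(False)
wake_w.settimeout(0.1)

sel = selectors.DefaultSelector()
sel.register(wake_r, selectors.EVENT_READ)

threading.Timer(0.5, lambda: wake_w.sendall(b'x')).start()

events = sel.select(timeout=30)   # returns after ~0.5 s instead of 30 s
for key, _ in events:
    if key.fileobj is wake_r:
        while True:               # drain everything that was written
            try:
                if not wake_r.recv(1024):
                    break         # peer closed the socket
            except OSError:
                break             # no more data on the non-blocking socket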
% (node_id,)) - self.maybe_connect(node_id) - self.poll(timeout_ms=timer.timeout_ms) - return self.is_ready(node_id) - - def send_and_receive(self, node_id, request): - future = self.send(node_id, request) - self.poll(future=future) - assert future.is_done - if future.failed(): - raise future.exception - return future.value - # OrderedDict requires python2.7+ try: @@ -1227,7 +998,7 @@ class IdleConnectionManager(object): def next_check_ms(self): now = time.time() - if not self.lru_connections or self.next_idle_close_check_time == float('inf'): + if not self.lru_connections: return float('inf') elif self.next_idle_close_check_time <= now: return 0 diff --git a/venv/lib/python3.12/site-packages/kafka/cluster.py b/venv/lib/python3.12/site-packages/kafka/cluster.py index ded8c6f..438baf2 100644 --- a/venv/lib/python3.12/site-packages/kafka/cluster.py +++ b/venv/lib/python3.12/site-packages/kafka/cluster.py @@ -3,15 +3,13 @@ from __future__ import absolute_import import collections import copy import logging -import random -import re import threading import time from kafka.vendor import six from kafka import errors as Errors -from kafka.conn import get_ip_port_afi +from kafka.conn import collect_hosts from kafka.future import Future from kafka.structs import BrokerMetadata, PartitionMetadata, TopicPartition @@ -23,7 +21,7 @@ class ClusterMetadata(object): A class to manage kafka cluster metadata. This class does not perform any IO. It simply updates internal state - given API responses (MetadataResponse, FindCoordinatorResponse). + given API responses (MetadataResponse, GroupCoordinatorResponse). Keyword Arguments: retry_backoff_ms (int): Milliseconds to backoff when retrying on @@ -49,7 +47,7 @@ class ClusterMetadata(object): self._brokers = {} # node_id -> BrokerMetadata self._partitions = {} # topic -> partition -> PartitionMetadata self._broker_partitions = collections.defaultdict(set) # node_id -> {TopicPartition...} - self._coordinators = {} # (coord_type, coord_key) -> node_id + self._groups = {} # group_name -> node_id self._last_refresh_ms = 0 self._last_successful_refresh_ms = 0 self._need_update = True @@ -60,7 +58,6 @@ class ClusterMetadata(object): self.unauthorized_topics = set() self.internal_topics = set() self.controller = None - self.cluster_id = None self.config = copy.copy(self.DEFAULT_CONFIG) for key in self.config: @@ -95,7 +92,7 @@ class ClusterMetadata(object): """Get BrokerMetadata Arguments: - broker_id (int or str): node_id for a broker to check + broker_id (int): node_id for a broker to check Returns: BrokerMetadata or None if not found @@ -114,7 +111,6 @@ class ClusterMetadata(object): Returns: set: {partition (int), ...} - None if topic not found. """ if topic not in self._partitions: return None @@ -144,14 +140,11 @@ class ClusterMetadata(object): return None return self._partitions[partition.topic][partition.partition].leader - def leader_epoch_for_partition(self, partition): - return self._partitions[partition.topic][partition.partition].leader_epoch - def partitions_for_broker(self, broker_id): """Return TopicPartitions for which the broker is a leader. Arguments: - broker_id (int or str): node id for a broker + broker_id (int): node id for a broker Returns: set: {TopicPartition, ...} @@ -166,10 +159,10 @@ class ClusterMetadata(object): group (str): name of consumer group Returns: - node_id (int or str) for group coordinator, -1 if coordinator unknown + int: node_id for group coordinator None if the group does not exist. 
""" - return self._coordinators.get(('group', group)) + return self._groups.get(group) def ttl(self): """Milliseconds until metadata should be refreshed""" @@ -204,10 +197,6 @@ class ClusterMetadata(object): self._future = Future() return self._future - @property - def need_update(self): - return self._need_update - def topics(self, exclude_internal_topics=True): """Get set of known topics. @@ -245,6 +234,13 @@ class ClusterMetadata(object): Returns: None """ + # In the common case where we ask for a single topic and get back an + # error, we should fail the future + if len(metadata.topics) == 1 and metadata.topics[0][0] != 0: + error_code, topic = metadata.topics[0][:2] + error = Errors.for_code(error_code)(topic) + return self.failed_update(error) + if not metadata.brokers: log.warning("No broker metadata found in MetadataResponse -- ignoring.") return self.failed_update(Errors.MetadataEmptyBrokerList(metadata)) @@ -265,11 +261,6 @@ class ClusterMetadata(object): else: _new_controller = _new_brokers.get(metadata.controller_id) - if metadata.API_VERSION < 2: - _new_cluster_id = None - else: - _new_cluster_id = metadata.cluster_id - _new_partitions = {} _new_broker_partitions = collections.defaultdict(set) _new_unauthorized_topics = set() @@ -286,21 +277,10 @@ class ClusterMetadata(object): error_type = Errors.for_code(error_code) if error_type is Errors.NoError: _new_partitions[topic] = {} - for partition_data in partitions: - leader_epoch = -1 - offline_replicas = [] - if metadata.API_VERSION >= 7: - p_error, partition, leader, leader_epoch, replicas, isr, offline_replicas = partition_data - elif metadata.API_VERSION >= 5: - p_error, partition, leader, replicas, isr, offline_replicas = partition_data - else: - p_error, partition, leader, replicas, isr = partition_data - + for p_error, partition, leader, replicas, isr in partitions: _new_partitions[topic][partition] = PartitionMetadata( - topic=topic, partition=partition, - leader=leader, leader_epoch=leader_epoch, - replicas=replicas, isr=isr, offline_replicas=offline_replicas, - error=p_error) + topic=topic, partition=partition, leader=leader, + replicas=replicas, isr=isr, error=p_error) if leader != -1: _new_broker_partitions[leader].add( TopicPartition(topic, partition)) @@ -326,7 +306,6 @@ class ClusterMetadata(object): with self._lock: self._brokers = _new_brokers self.controller = _new_controller - self.cluster_id = _new_cluster_id self._partitions = _new_partitions self._broker_partitions = _new_broker_partitions self.unauthorized_topics = _new_unauthorized_topics @@ -342,15 +321,7 @@ class ClusterMetadata(object): self._last_successful_refresh_ms = now if f: - # In the common case where we ask for a single topic and get back an - # error, we should fail the future - if len(metadata.topics) == 1 and metadata.topics[0][0] != Errors.NoError.errno: - error_code, topic = metadata.topics[0][:2] - error = Errors.for_code(error_code)(topic) - f.failure(error) - else: - f.success(self) - + f.success(self) log.debug("Updated cluster metadata to %s", self) for listener in self._listeners: @@ -371,25 +342,24 @@ class ClusterMetadata(object): """Remove a previously added listener callback""" self._listeners.remove(listener) - def add_coordinator(self, response, coord_type, coord_key): - """Update with metadata for a group or txn coordinator + def add_group_coordinator(self, group, response): + """Update with metadata for a group coordinator Arguments: - response (FindCoordinatorResponse): broker response - coord_type (str): 'group' or 
'transaction' - coord_key (str): consumer_group or transactional_id + group (str): name of group from GroupCoordinatorRequest + response (GroupCoordinatorResponse): broker response Returns: string: coordinator node_id if metadata is updated, None on error """ - log.debug("Updating coordinator for %s/%s: %s", coord_type, coord_key, response) + log.debug("Updating coordinator for %s: %s", group, response) error_type = Errors.for_code(response.error_code) if error_type is not Errors.NoError: - log.error("FindCoordinatorResponse error: %s", error_type) - self._coordinators[(coord_type, coord_key)] = -1 + log.error("GroupCoordinatorResponse error: %s", error_type) + self._groups[group] = -1 return - # Use a coordinator-specific node id so that requests + # Use a coordinator-specific node id so that group requests # get a dedicated connection node_id = 'coordinator-{}'.format(response.coordinator_id) coordinator = BrokerMetadata( @@ -398,9 +368,9 @@ class ClusterMetadata(object): response.port, None) - log.info("Coordinator for %s/%s is %s", coord_type, coord_key, coordinator) + log.info("Group coordinator for %s is %s", group, coordinator) self._coordinator_brokers[node_id] = coordinator - self._coordinators[(coord_type, coord_key)] = node_id + self._groups[group] = node_id return node_id def with_partitions(self, partitions_to_add): @@ -409,7 +379,7 @@ class ClusterMetadata(object): new_metadata._brokers = copy.deepcopy(self._brokers) new_metadata._partitions = copy.deepcopy(self._partitions) new_metadata._broker_partitions = copy.deepcopy(self._broker_partitions) - new_metadata._coordinators = copy.deepcopy(self._coordinators) + new_metadata._groups = copy.deepcopy(self._groups) new_metadata.internal_topics = copy.deepcopy(self.internal_topics) new_metadata.unauthorized_topics = copy.deepcopy(self.unauthorized_topics) @@ -423,26 +393,5 @@ class ClusterMetadata(object): return new_metadata def __str__(self): - return 'ClusterMetadata(brokers: %d, topics: %d, coordinators: %d)' % \ - (len(self._brokers), len(self._partitions), len(self._coordinators)) - - -def collect_hosts(hosts, randomize=True): - """ - Collects a comma-separated set of hosts (host:port) and optionally - randomize the returned list. - """ - - if isinstance(hosts, six.string_types): - hosts = hosts.strip().split(',') - - result = [] - for host_port in hosts: - # ignore leading SECURITY_PROTOCOL:// to mimic java client - host_port = re.sub('^.*://', '', host_port) - host, port, afi = get_ip_port_afi(host_port) - result.append((host, port, afi)) - - if randomize: - random.shuffle(result) - return result + return 'ClusterMetadata(brokers: %d, topics: %d, groups: %d)' % \ + (len(self._brokers), len(self._partitions), len(self._groups)) diff --git a/venv/lib/python3.12/site-packages/kafka/codec.py b/venv/lib/python3.12/site-packages/kafka/codec.py index b73df06..917400e 100644 --- a/venv/lib/python3.12/site-packages/kafka/codec.py +++ b/venv/lib/python3.12/site-packages/kafka/codec.py @@ -187,21 +187,14 @@ def _detect_xerial_stream(payload): The version is the version of this format as written by xerial, in the wild this is currently 1 as such we only support v1. - Compat is there to claim the minimum supported version that + Compat is there to claim the miniumum supported version that can read a xerial block stream, presently in the wild this is 1. """ if len(payload) > 16: - magic = struct.unpack('!' + _XERIAL_V1_FORMAT[:8], bytes(payload)[:8]) - version, compat = struct.unpack('!' 
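# Illustrative sketch (not part of the patch) of what the collect_hosts()
# helper removed from cluster.py above does: accept a string or list of
# "host:port" entries, strip any leading "PROTOCOL://" prefix, and optionally
# shuffle the result. The real helper resolves host/port/afi via
# get_ip_port_afi(); the plain split below is a simplification (it does not
# handle bracketed IPv6 literals) and 9092 stands in for DEFAULT_KAFKA_PORT.
import random
import re

def parse_bootstrap_servers(hosts, randomize=True):
    if isinstance(hosts, str):
        hosts = hosts.strip().split(',')
    result = []
    for host_port in hosts:
        host_port = re.sub('^.*://', '', host_port)  # mimic the java client
        if ':' in host_port:
            host, port = host_port.rsplit(':', 1)
            result.append((host, int(port)))
        else:
            result.append((host_port, 9092))
    if randomize:
        random.shuffle(result)
    return result

# parse_bootstrap_servers('PLAINTEXT://b1:9092,b2:9093', randomize=False)
#   -> [('b1', 9092), ('b2', 9093)]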
+ _XERIAL_V1_FORMAT[8:], bytes(payload)[8:16]) - # Until there is more than one way to do xerial blocking, the version + compat - # fields can be ignored. Also some producers (i.e., redpanda) are known to - # incorrectly encode these as little-endian, and that causes us to fail decoding - # when we otherwise would have succeeded. - # See https://github.com/dpkp/kafka-python/issues/2414 - if magic == _XERIAL_V1_HEADER[:8]: - return True + header = struct.unpack('!' + _XERIAL_V1_FORMAT, bytes(payload)[:16]) + return header == _XERIAL_V1_HEADER return False diff --git a/venv/lib/python3.12/site-packages/kafka/conn.py b/venv/lib/python3.12/site-packages/kafka/conn.py index 64445fa..5c72875 100644 --- a/venv/lib/python3.12/site-packages/kafka/conn.py +++ b/venv/lib/python3.12/site-packages/kafka/conn.py @@ -4,7 +4,7 @@ import copy import errno import io import logging -from random import uniform +from random import shuffle, uniform # selectors in stdlib as of py3.4 try: @@ -14,6 +14,7 @@ except ImportError: from kafka.vendor import selectors34 as selectors import socket +import struct import threading import time @@ -22,21 +23,16 @@ from kafka.vendor import six import kafka.errors as Errors from kafka.future import Future from kafka.metrics.stats import Avg, Count, Max, Rate -from kafka.protocol.admin import DescribeAclsRequest, DescribeClientQuotasRequest, ListGroupsRequest -from kafka.protocol.api_versions import ApiVersionsRequest -from kafka.protocol.broker_api_versions import BROKER_API_VERSIONS +from kafka.oauth.abstract import AbstractTokenProvider +from kafka.protocol.admin import SaslHandShakeRequest, DescribeAclsRequest_v2 from kafka.protocol.commit import OffsetFetchRequest -from kafka.protocol.fetch import FetchRequest -from kafka.protocol.find_coordinator import FindCoordinatorRequest -from kafka.protocol.list_offsets import ListOffsetsRequest -from kafka.protocol.metadata import MetadataRequest -from kafka.protocol.parser import KafkaProtocol +from kafka.protocol.offset import OffsetRequest from kafka.protocol.produce import ProduceRequest -from kafka.protocol.sasl_authenticate import SaslAuthenticateRequest -from kafka.protocol.sasl_handshake import SaslHandshakeRequest -from kafka.protocol.types import Int32 -from kafka.sasl import get_sasl_mechanism -from kafka.socks5_wrapper import Socks5Wrapper +from kafka.protocol.metadata import MetadataRequest +from kafka.protocol.fetch import FetchRequest +from kafka.protocol.parser import KafkaProtocol +from kafka.protocol.types import Int32, Int8 +from kafka.scram import ScramClient from kafka.version import __version__ @@ -49,6 +45,10 @@ log = logging.getLogger(__name__) DEFAULT_KAFKA_PORT = 9092 +SASL_QOP_AUTH = 1 +SASL_QOP_AUTH_INT = 2 +SASL_QOP_AUTH_CONF = 4 + try: import ssl ssl_available = True @@ -74,6 +74,15 @@ except ImportError: class SSLWantWriteError(Exception): pass +# needed for SASL_GSSAPI authentication: +try: + import gssapi + from gssapi.raw.misc import GSSError +except ImportError: + #no gssapi available, will disable gssapi mechanism + gssapi = None + GSSError = None + AFI_NAMES = { socket.AF_UNSPEC: "unspecified", @@ -83,13 +92,12 @@ AFI_NAMES = { class ConnectionStates(object): + DISCONNECTING = '' DISCONNECTED = '' CONNECTING = '' HANDSHAKE = '' CONNECTED = '' AUTHENTICATING = '' - API_VERSIONS_SEND = '' - API_VERSIONS_RECV = '' class BrokerConnection(object): @@ -101,10 +109,6 @@ class BrokerConnection(object): server-side log entries that correspond to this client. 
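# Illustrative sketch (not part of the patch) of the xerial/snappy stream
# detection in the codec.py hunk above, post-change. The two constants are not
# shown in this diff; the values below are taken from kafka-python's codec
# module and should be read as assumptions rather than part of the patch.
import struct

_XERIAL_V1_FORMAT = 'bccccccBii'   # magic byte + 'SNAPPY', then version, compat
_XERIAL_V1_HEADER = (-126, b'S', b'N', b'A', b'P', b'P', b'Y', 0, 1, 1)

def looks_like_xerial_stream(payload):
    if len(payload) > 16:
        header = struct.unpack('!' + _XERIAL_V1_FORMAT, bytes(payload)[:16])
        return header == _XERIAL_V1_HEADER
    return False

# The "-" side above compares only the first 8 magic bytes so that producers
# known to write the version/compat fields little-endian still decode.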
Also submitted to GroupCoordinator for logging with respect to consumer group administration. Default: 'kafka-python-{version}' - client_software_name (str): Sent to kafka broker for KIP-511. - Default: 'kafka-python' - client_software_version (str): Sent to kafka broker for KIP-511. - Default: The kafka-python version (via kafka.version). reconnect_backoff_ms (int): The amount of time in milliseconds to wait before attempting to reconnect to a given host. Default: 50. @@ -116,7 +120,7 @@ class BrokerConnection(object): reconnection attempts will continue periodically with this fixed rate. To avoid connection storms, a randomization factor of 0.2 will be applied to the backoff resulting in a random range between - 20% below and 20% above the computed value. Default: 30000. + 20% below and 20% above the computed value. Default: 1000. request_timeout_ms (int): Client request timeout in milliseconds. Default: 30000. max_in_flight_requests_per_connection (int): Requests are pipelined @@ -161,11 +165,11 @@ class BrokerConnection(object): or other configuration forbids use of all the specified ciphers), an ssl.SSLError will be raised. See ssl.SSLContext.set_ciphers api_version (tuple): Specify which Kafka API version to use. - Must be None or >= (0, 10, 0) to enable SASL authentication. - Default: None + Accepted values are: (0, 8, 0), (0, 8, 1), (0, 8, 2), (0, 9), + (0, 10). Default: (0, 8, 2) api_version_auto_timeout_ms (int): number of milliseconds to throw a timeout exception from the constructor when checking the broker - api version. Only applies if api_version is None. Default: 2000. + api version. Only applies if api_version is None selector (selectors.BaseSelector): Provide a specific selector implementation to use for I/O multiplexing. Default: selectors.DefaultSelector @@ -181,26 +185,20 @@ class BrokerConnection(object): Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. sasl_plain_password (str): password for sasl PLAIN and SCRAM authentication. Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. - sasl_kerberos_name (str or gssapi.Name): Constructed gssapi.Name for use with - sasl mechanism handshake. If provided, sasl_kerberos_service_name and - sasl_kerberos_domain name are ignored. Default: None. sasl_kerberos_service_name (str): Service name to include in GSSAPI sasl mechanism handshake. Default: 'kafka' sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI sasl mechanism handshake. Default: one of bootstrap servers - sasl_oauth_token_provider (kafka.sasl.oauth.AbstractTokenProvider): OAuthBearer - token provider instance. Default: None - socks5_proxy (str): Socks5 proxy url. Default: None + sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider + instance. (See kafka.oauth.abstract). 
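# Illustrative sketch (not part of the patch) of the reconnect backoff
# described in the reconnect_backoff_max_ms entry above and implemented by
# _update_reconnect_backoff() further down: exponential growth per failure,
# capped at the maximum, with a +/-20% jitter. Returns seconds, as stored on
# the connection object.
from random import uniform

def next_reconnect_backoff_secs(failures, reconnect_backoff_ms,
                                reconnect_backoff_max_ms):
    backoff_ms = reconnect_backoff_ms * 2 ** (failures - 1)
    backoff_ms = min(backoff_ms, reconnect_backoff_max_ms)
    backoff_ms *= uniform(0.8, 1.2)   # 0.2 randomization factor
    return backoff_ms / 1000.0

# With the post-patch defaults (50 ms base, 1000 ms cap): failure 1 is ~0.05 s,
# failure 3 is ~0.2 s, and failure 6 onwards is capped near 1 s.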
Default: None """ DEFAULT_CONFIG = { 'client_id': 'kafka-python-' + __version__, - 'client_software_name': 'kafka-python', - 'client_software_version': __version__, 'node_id': 0, 'request_timeout_ms': 30000, 'reconnect_backoff_ms': 50, - 'reconnect_backoff_max_ms': 30000, + 'reconnect_backoff_max_ms': 1000, 'max_in_flight_requests_per_connection': 5, 'receive_buffer_bytes': None, 'send_buffer_bytes': None, @@ -216,8 +214,7 @@ class BrokerConnection(object): 'ssl_crlfile': None, 'ssl_password': None, 'ssl_ciphers': None, - 'api_version': None, - 'api_version_auto_timeout_ms': 2000, + 'api_version': (0, 8, 2), # default to most restrictive 'selector': selectors.DefaultSelector, 'state_change_callback': lambda node_id, sock, conn: True, 'metrics': None, @@ -225,19 +222,12 @@ class BrokerConnection(object): 'sasl_mechanism': None, 'sasl_plain_username': None, 'sasl_plain_password': None, - 'sasl_kerberos_name': None, 'sasl_kerberos_service_name': 'kafka', 'sasl_kerberos_domain_name': None, - 'sasl_oauth_token_provider': None, - 'socks5_proxy': None, + 'sasl_oauth_token_provider': None } SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL') - VERSION_CHECKS = ( - ((0, 9), ListGroupsRequest[0]()), - ((0, 8, 2), FindCoordinatorRequest[0]('kafka-python-default-group')), - ((0, 8, 1), OffsetFetchRequest[0]('kafka-python-default-group', [])), - ((0, 8, 0), MetadataRequest[0]([])), - ) + SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', "SCRAM-SHA-256", "SCRAM-SHA-512") def __init__(self, host, port, afi, **configs): self.host = host @@ -246,11 +236,6 @@ class BrokerConnection(object): self._sock_afi = afi self._sock_addr = None self._api_versions = None - self._api_version = None - self._check_version_idx = None - self._api_versions_idx = 4 # version of ApiVersionsRequest to try on first connect - self._throttle_time = None - self._socks5_proxy = None self.config = copy.copy(self.DEFAULT_CONFIG) for key in self.config: @@ -274,8 +259,23 @@ class BrokerConnection(object): if self.config['security_protocol'] in ('SSL', 'SASL_SSL'): assert ssl_available, "Python wasn't built with SSL support" - self._init_sasl_mechanism() - + if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'): + assert self.config['sasl_mechanism'] in self.SASL_MECHANISMS, ( + 'sasl_mechanism must be in ' + ', '.join(self.SASL_MECHANISMS)) + if self.config['sasl_mechanism'] in ('PLAIN', 'SCRAM-SHA-256', 'SCRAM-SHA-512'): + assert self.config['sasl_plain_username'] is not None, ( + 'sasl_plain_username required for PLAIN or SCRAM sasl' + ) + assert self.config['sasl_plain_password'] is not None, ( + 'sasl_plain_password required for PLAIN or SCRAM sasl' + ) + if self.config['sasl_mechanism'] == 'GSSAPI': + assert gssapi is not None, 'GSSAPI lib not available' + assert self.config['sasl_kerberos_service_name'] is not None, 'sasl_kerberos_service_name required for GSSAPI sasl' + if self.config['sasl_mechanism'] == 'OAUTHBEARER': + token_provider = self.config['sasl_oauth_token_provider'] + assert token_provider is not None, 'sasl_oauth_token_provider required for OAUTHBEARER sasl' + assert callable(getattr(token_provider, "token", None)), 'sasl_oauth_token_provider must implement method #token()' # This is not a general lock / this class is not generally thread-safe yet # However, to avoid pushing responsibility for maintaining # per-connection locks to the upstream client, we will use this lock to @@ -300,8 +300,6 @@ class BrokerConnection(object): self._ssl_context = None if 
self.config['ssl_context'] is not None: self._ssl_context = self.config['ssl_context'] - self._api_versions_future = None - self._api_versions_check_timeout = self.config['api_version_auto_timeout_ms'] self._sasl_auth_future = None self.last_attempt = 0 self._gai = [] @@ -311,17 +309,11 @@ class BrokerConnection(object): self.config['metric_group_prefix'], self.node_id) - def _init_sasl_mechanism(self): - if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'): - self._sasl_mechanism = get_sasl_mechanism(self.config['sasl_mechanism'])(host=self.host, **self.config) - else: - self._sasl_mechanism = None - def _dns_lookup(self): self._gai = dns_lookup(self.host, self.port, self.afi) if not self._gai: - log.error('%s: DNS lookup failed for %s:%i (%s)', - self, self.host, self.port, self.afi) + log.error('DNS lookup failed for %s:%i (%s)', + self.host, self.port, self.afi) return False return True @@ -367,7 +359,6 @@ class BrokerConnection(object): def connect(self): """Attempt to connect and return ConnectionState""" if self.state is ConnectionStates.DISCONNECTED and not self.blacked_out(): - self.state = ConnectionStates.CONNECTING self.last_attempt = time.time() next_lookup = self._next_afi_sockaddr() if not next_lookup: @@ -377,21 +368,14 @@ class BrokerConnection(object): log.debug('%s: creating new socket', self) assert self._sock is None self._sock_afi, self._sock_addr = next_lookup - try: - if self.config["socks5_proxy"] is not None: - self._socks5_proxy = Socks5Wrapper(self.config["socks5_proxy"], self.afi) - self._sock = self._socks5_proxy.socket(self._sock_afi, socket.SOCK_STREAM) - else: - self._sock = socket.socket(self._sock_afi, socket.SOCK_STREAM) - except (socket.error, OSError) as e: - self.close(e) - return self.state + self._sock = socket.socket(self._sock_afi, socket.SOCK_STREAM) for option in self.config['socket_options']: log.debug('%s: setting socket option %s', self, option) self._sock.setsockopt(*option) self._sock.setblocking(False) + self.state = ConnectionStates.CONNECTING self.config['state_change_callback'](self.node_id, self._sock, self) log.info('%s: connecting to %s:%d [%s %s]', self, self.host, self.port, self._sock_addr, AFI_NAMES[self._sock_afi]) @@ -401,10 +385,7 @@ class BrokerConnection(object): # to check connection status ret = None try: - if self._socks5_proxy: - ret = self._socks5_proxy.connect_ex(self._sock_addr) - else: - ret = self._sock.connect_ex(self._sock_addr) + ret = self._sock.connect_ex(self._sock_addr) except socket.error as err: ret = err.errno @@ -413,20 +394,28 @@ class BrokerConnection(object): log.debug('%s: established TCP connection', self) if self.config['security_protocol'] in ('SSL', 'SASL_SSL'): - self.state = ConnectionStates.HANDSHAKE log.debug('%s: initiating SSL handshake', self) + self.state = ConnectionStates.HANDSHAKE self.config['state_change_callback'](self.node_id, self._sock, self) # _wrap_ssl can alter the connection state -- disconnects on failure self._wrap_ssl() + + elif self.config['security_protocol'] == 'SASL_PLAINTEXT': + log.debug('%s: initiating SASL authentication', self) + self.state = ConnectionStates.AUTHENTICATING + self.config['state_change_callback'](self.node_id, self._sock, self) + else: - self.state = ConnectionStates.API_VERSIONS_SEND - log.debug('%s: checking broker Api Versions', self) + # security_protocol PLAINTEXT + log.info('%s: Connection complete.', self) + self.state = ConnectionStates.CONNECTED + self._reset_reconnect_backoff() 
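# Illustrative, generic sketch (not part of the patch) of the non-blocking
# connect pattern used by connect() above: connect_ex() is re-polled until it
# reports 0 (connected) or an error other than the "still in progress" codes.
import errno
import socket

def nonblocking_connect_step(sock, addr):
    """One poll-loop step: returns 'connected', 'connecting' or 'failed'."""
    try:
        ret = sock.connect_ex(addr)
    except socket.error as err:
        ret = err.errno
    if ret == 0:
        return 'connected'
    # WSAEINVAL == 10022; errno.WSAEINVAL is not available off Windows.
    if ret in (errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK, 10022):
        return 'connecting'
    return 'failed'

# sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# sock.setblocking(False)
# state = nonblocking_connect_step(sock, ('localhost', 9092))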
self.config['state_change_callback'](self.node_id, self._sock, self) # Connection failed # WSAEINVAL == 10022, but errno.WSAEINVAL is not available on non-win systems elif ret not in (errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK, 10022): - log.error('%s: Connect attempt returned error %s.' + log.error('Connect attempt to %s returned error %s.' ' Disconnecting.', self, ret) errstr = errno.errorcode.get(ret, 'UNKNOWN') self.close(Errors.KafkaConnectionError('{} {}'.format(ret, errstr))) @@ -439,32 +428,22 @@ class BrokerConnection(object): if self.state is ConnectionStates.HANDSHAKE: if self._try_handshake(): log.debug('%s: completed SSL handshake.', self) - self.state = ConnectionStates.API_VERSIONS_SEND - log.debug('%s: checking broker Api Versions', self) + if self.config['security_protocol'] == 'SASL_SSL': + log.debug('%s: initiating SASL authentication', self) + self.state = ConnectionStates.AUTHENTICATING + else: + log.info('%s: Connection complete.', self) + self.state = ConnectionStates.CONNECTED + self._reset_reconnect_backoff() self.config['state_change_callback'](self.node_id, self._sock, self) - if self.state in (ConnectionStates.API_VERSIONS_SEND, ConnectionStates.API_VERSIONS_RECV): - if self._try_api_versions_check(): - # _try_api_versions_check has side-effects: possibly disconnected on socket errors - if self.state in (ConnectionStates.API_VERSIONS_SEND, ConnectionStates.API_VERSIONS_RECV): - if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'): - self.state = ConnectionStates.AUTHENTICATING - log.debug('%s: initiating SASL authentication', self) - self.config['state_change_callback'](self.node_id, self._sock, self) - else: - # security_protocol PLAINTEXT - self.state = ConnectionStates.CONNECTED - log.info('%s: Connection complete.', self) - self._reset_reconnect_backoff() - self.config['state_change_callback'](self.node_id, self._sock, self) - if self.state is ConnectionStates.AUTHENTICATING: assert self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL') if self._try_authenticate(): # _try_authenticate has side-effects: possibly disconnected on socket errors if self.state is ConnectionStates.AUTHENTICATING: - self.state = ConnectionStates.CONNECTED log.info('%s: Connection complete.', self) + self.state = ConnectionStates.CONNECTED self._reset_reconnect_backoff() self.config['state_change_callback'](self.node_id, self._sock, self) @@ -473,7 +452,7 @@ class BrokerConnection(object): # Connection timed out request_timeout = self.config['request_timeout_ms'] / 1000.0 if time.time() > request_timeout + self.last_attempt: - log.error('%s: Connection attempt timed out', self) + log.error('Connection attempt to %s timed out', self) self.close(Errors.KafkaConnectionError('timeout')) return self.state @@ -517,7 +496,7 @@ class BrokerConnection(object): try: self._sock = self._ssl_context.wrap_socket( self._sock, - server_hostname=self.host.rstrip("."), + server_hostname=self.host, do_handshake_on_connect=False) except ssl.SSLError as e: log.exception('%s: Failed to wrap socket in SSLContext!', self) @@ -532,136 +511,20 @@ class BrokerConnection(object): except (SSLWantReadError, SSLWantWriteError): pass except (SSLZeroReturnError, ConnectionError, TimeoutError, SSLEOFError): - log.warning('%s: SSL connection closed by server during handshake.', self) + log.warning('SSL connection closed by server during handshake.') self.close(Errors.KafkaConnectionError('SSL connection closed by server during handshake')) # Other SSLErrors will be raised to user 
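# Illustrative sketch (not part of the patch) of the non-blocking TLS handshake
# step behind _try_handshake() above: do_handshake() is simply retried whenever
# OpenSSL wants more I/O, and a close from the peer ends the attempt. In the
# connection code the closed-by-peer case closes the connection; here it raises.
import ssl

def try_handshake_step(ssl_sock):
    """Return True when the handshake finished, False to try again later."""
    try:
        ssl_sock.do_handshake()
        return True
    except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
        return False   # socket not ready yet; call again after select()
    except (ssl.SSLZeroReturnError, ssl.SSLEOFError, ConnectionError):
        raise ConnectionError('SSL connection closed by server during handshake')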
return False - def _try_api_versions_check(self): - if self._api_versions_future is None: - if self.config['api_version'] is not None: - self._api_version = self.config['api_version'] - # api_version will be normalized by KafkaClient, so this should not happen - if self._api_version not in BROKER_API_VERSIONS: - raise Errors.UnrecognizedBrokerVersion('api_version %s not found in kafka.protocol.broker_api_versions' % (self._api_version,)) - self._api_versions = BROKER_API_VERSIONS[self._api_version] - log.debug('%s: Using pre-configured api_version %s for ApiVersions', self, self._api_version) - return True - elif self._check_version_idx is None: - version = self._api_versions_idx - if version >= 3: - request = ApiVersionsRequest[version]( - client_software_name=self.config['client_software_name'], - client_software_version=self.config['client_software_version'], - _tagged_fields={}) - else: - request = ApiVersionsRequest[version]() - future = Future() - self._api_versions_check_timeout /= 2 - response = self._send(request, blocking=True, request_timeout_ms=self._api_versions_check_timeout) - response.add_callback(self._handle_api_versions_response, future) - response.add_errback(self._handle_api_versions_failure, future) - self._api_versions_future = future - self.state = ConnectionStates.API_VERSIONS_RECV - self.config['state_change_callback'](self.node_id, self._sock, self) - elif self._check_version_idx < len(self.VERSION_CHECKS): - version, request = self.VERSION_CHECKS[self._check_version_idx] - future = Future() - self._api_versions_check_timeout /= 2 - response = self._send(request, blocking=True, request_timeout_ms=self._api_versions_check_timeout) - response.add_callback(self._handle_check_version_response, future, version) - response.add_errback(self._handle_check_version_failure, future) - self._api_versions_future = future - self.state = ConnectionStates.API_VERSIONS_RECV - self.config['state_change_callback'](self.node_id, self._sock, self) - else: - self.close(Errors.KafkaConnectionError('Unable to determine broker version.')) - return False - - for r, f in self.recv(): - f.success(r) - - # A connection error during blocking send could trigger close() which will reset the future - if self._api_versions_future is None: - return False - elif self._api_versions_future.failed(): - ex = self._api_versions_future.exception - if not isinstance(ex, Errors.KafkaConnectionError): - raise ex - return self._api_versions_future.succeeded() - - def _handle_api_versions_response(self, future, response): - error_type = Errors.for_code(response.error_code) - if error_type is not Errors.NoError: - future.failure(error_type()) - if error_type is Errors.UnsupportedVersionError: - self._api_versions_idx -= 1 - for api_version_data in response.api_versions: - api_key, min_version, max_version = api_version_data[:3] - # If broker provides a lower max_version, skip to that - if api_key == response.API_KEY: - self._api_versions_idx = min(self._api_versions_idx, max_version) - break - if self._api_versions_idx >= 0: - self._api_versions_future = None - self.state = ConnectionStates.API_VERSIONS_SEND - self.config['state_change_callback'](self.node_id, self._sock, self) - else: - self.close(error=error_type()) - return - self._api_versions = dict([ - (api_version_data[0], (api_version_data[1], api_version_data[2])) - for api_version_data in response.api_versions - ]) - self._api_version = self._infer_broker_version_from_api_versions(self._api_versions) - log.info('%s: Broker version identified as %s', 
self, '.'.join(map(str, self._api_version))) - future.success(self._api_version) - self.connect() - - def _handle_api_versions_failure(self, future, ex): - future.failure(ex) - # Modern brokers should not disconnect on unrecognized api-versions request, - # but in case they do we always want to try v0 as a fallback - # otherwise switch to check_version probe. - if self._api_versions_idx > 0: - self._api_versions_idx = 0 - else: - self._check_version_idx = 0 - # after failure connection is closed, so state should already be DISCONNECTED - - def _handle_check_version_response(self, future, version, _response): - log.info('%s: Broker version identified as %s', self, '.'.join(map(str, version))) - log.info('Set configuration api_version=%s to skip auto' - ' check_version requests on startup', version) - self._api_versions = BROKER_API_VERSIONS[version] - self._api_version = version - future.success(version) - self.connect() - - def _handle_check_version_failure(self, future, ex): - future.failure(ex) - self._check_version_idx += 1 - # after failure connection is closed, so state should already be DISCONNECTED - - def _sasl_handshake_version(self): - if self._api_versions is None: - raise RuntimeError('_api_versions not set') - if SaslHandshakeRequest[0].API_KEY not in self._api_versions: - raise Errors.UnsupportedVersionError('SaslHandshake') - - # Build a SaslHandshakeRequest message - min_version, max_version = self._api_versions[SaslHandshakeRequest[0].API_KEY] - if min_version > 1: - raise Errors.UnsupportedVersionError('SaslHandshake %s' % min_version) - return min(max_version, 1) - def _try_authenticate(self): + assert self.config['api_version'] is None or self.config['api_version'] >= (0, 10) + if self._sasl_auth_future is None: - version = self._sasl_handshake_version() - request = SaslHandshakeRequest[version](self.config['sasl_mechanism']) + # Build a SaslHandShakeRequest message + request = SaslHandShakeRequest[0](self.config['sasl_mechanism']) future = Future() - sasl_response = self._send(request, blocking=True) + sasl_response = self._send(request) sasl_response.add_callback(self._handle_sasl_handshake_response, future) sasl_response.add_errback(lambda f, e: f.failure(e), future) self._sasl_auth_future = future @@ -686,18 +549,23 @@ class BrokerConnection(object): return future.failure(error_type(self)) if self.config['sasl_mechanism'] not in response.enabled_mechanisms: - future.failure( + return future.failure( Errors.UnsupportedSaslMechanismError( 'Kafka broker does not support %s sasl mechanism. Enabled mechanisms are: %s' % (self.config['sasl_mechanism'], response.enabled_mechanisms))) + elif self.config['sasl_mechanism'] == 'PLAIN': + return self._try_authenticate_plain(future) + elif self.config['sasl_mechanism'] == 'GSSAPI': + return self._try_authenticate_gssapi(future) + elif self.config['sasl_mechanism'] == 'OAUTHBEARER': + return self._try_authenticate_oauth(future) + elif self.config['sasl_mechanism'].startswith("SCRAM-SHA-"): + return self._try_authenticate_scram(future) else: - self._sasl_authenticate(future) - - assert future.is_done, 'SASL future not complete after mechanism processing!' 
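# Illustrative sketch (not part of the patch): the removed
# _handle_api_versions_response() above reduces an ApiVersions response to the
# {api_key: (min_version, max_version)} map returned by get_api_versions().
# Plain tuples stand in for the response entries here.
def build_api_versions_map(api_versions):
    return dict((entry[0], (entry[1], entry[2])) for entry in api_versions)

# build_api_versions_map([(0, 0, 9), (3, 0, 12)])
#   -> {0: (0, 9), 3: (0, 12)}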
- if future.failed(): - self.close(error=future.exception) - else: - self.connect() + return future.failure( + Errors.UnsupportedSaslMechanismError( + 'kafka-python does not support SASL mechanism %s' % + self.config['sasl_mechanism'])) def _send_bytes(self, data): """Send some data via non-blocking IO @@ -726,7 +594,6 @@ class BrokerConnection(object): return total_sent def _send_bytes_blocking(self, data): - self._sock.setblocking(True) self._sock.settimeout(self.config['request_timeout_ms'] / 1000) total_sent = 0 try: @@ -738,10 +605,8 @@ class BrokerConnection(object): return total_sent finally: self._sock.settimeout(0.0) - self._sock.setblocking(False) def _recv_bytes_blocking(self, n): - self._sock.setblocking(True) self._sock.settimeout(self.config['request_timeout_ms'] / 1000) try: data = b'' @@ -753,76 +618,225 @@ class BrokerConnection(object): return data finally: self._sock.settimeout(0.0) - self._sock.setblocking(False) - def _send_sasl_authenticate(self, sasl_auth_bytes): - version = self._sasl_handshake_version() - if version == 1: - request = SaslAuthenticateRequest[0](sasl_auth_bytes) - self._send(request, blocking=True) - else: - log.debug('%s: Sending %d raw sasl auth bytes to server', self, len(sasl_auth_bytes)) - try: - self._send_bytes_blocking(Int32.encode(len(sasl_auth_bytes)) + sasl_auth_bytes) - except (ConnectionError, TimeoutError) as e: - log.exception("%s: Error sending sasl auth bytes to server", self) - err = Errors.KafkaConnectionError("%s: %s" % (self, e)) - self.close(error=err) + def _try_authenticate_plain(self, future): + if self.config['security_protocol'] == 'SASL_PLAINTEXT': + log.warning('%s: Sending username and password in the clear', self) - def _recv_sasl_authenticate(self): - version = self._sasl_handshake_version() - # GSSAPI mechanism does not get a final recv in old non-framed mode - if version == 0 and self._sasl_mechanism.is_done(): - return b'' + data = b'' + # Send PLAIN credentials per RFC-4616 + msg = bytes('\0'.join([self.config['sasl_plain_username'], + self.config['sasl_plain_username'], + self.config['sasl_plain_password']]).encode('utf-8')) + size = Int32.encode(len(msg)) - try: - data = self._recv_bytes_blocking(4) - nbytes = Int32.decode(io.BytesIO(data)) - data += self._recv_bytes_blocking(nbytes) - except (ConnectionError, TimeoutError) as e: - log.exception("%s: Error receiving sasl auth bytes from server", self) - err = Errors.KafkaConnectionError("%s: %s" % (self, e)) - self.close(error=err) - return - - if version == 1: - ((correlation_id, response),) = self._protocol.receive_bytes(data) - (future, timestamp, _timeout) = self.in_flight_requests.pop(correlation_id) - latency_ms = (time.time() - timestamp) * 1000 - if self._sensors: - self._sensors.request_time.record(latency_ms) - log.debug('%s: Response %d (%s ms): %s', self, correlation_id, latency_ms, response) - - error_type = Errors.for_code(response.error_code) - if error_type is not Errors.NoError: - log.error("%s: SaslAuthenticate error: %s (%s)", - self, error_type.__name__, response.error_message) - self.close(error=error_type(response.error_message)) - return - return response.auth_bytes - else: - # unframed bytes w/ SaslHandhake v0 - log.debug('%s: Received %d raw sasl auth bytes from server', self, nbytes) - return data[4:] - - def _sasl_authenticate(self, future): - while not self._sasl_mechanism.is_done(): - send_token = self._sasl_mechanism.auth_bytes() - self._send_sasl_authenticate(send_token) + err = None + close = False + with self._lock: if not 
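# Illustrative sketch (not part of the patch) of the 4-byte length-prefix
# framing used by the blocking SASL helpers above: send Int32(size) + payload,
# then read a 4-byte size and exactly that many payload bytes. struct stands in
# for kafka.protocol.types.Int32; `sock` is any connected blocking socket.
import struct

def send_framed(sock, payload):
    sock.sendall(struct.pack('>i', len(payload)) + payload)

def recv_exactly(sock, n):
    data = b''
    while len(data) < n:
        chunk = sock.recv(n - len(data))
        if not chunk:
            raise ConnectionError('socket closed while reading')
        data += chunk
    return data

def recv_framed(sock):
    (size,) = struct.unpack('>i', recv_exactly(sock, 4))
    return recv_exactly(sock, size)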
self._can_send_recv(): - return future.failure(Errors.KafkaConnectionError("%s: Connection failure during Sasl Authenticate" % self)) - - recv_token = self._recv_sasl_authenticate() - if recv_token is None: - return future.failure(Errors.KafkaConnectionError("%s: Connection failure during Sasl Authenticate" % self)) + err = Errors.NodeNotReadyError(str(self)) + close = False else: - self._sasl_mechanism.receive(recv_token) + try: + self._send_bytes_blocking(size + msg) - if self._sasl_mechanism.is_authenticated(): - log.info('%s: %s', self, self._sasl_mechanism.auth_details()) - return future.success(True) + # The server will send a zero sized message (that is Int32(0)) on success. + # The connection is closed on failure + data = self._recv_bytes_blocking(4) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", self) + err = Errors.KafkaConnectionError("%s: %s" % (self, e)) + close = True + + if err is not None: + if close: + self.close(error=err) + return future.failure(err) + + if data != b'\x00\x00\x00\x00': + error = Errors.AuthenticationFailedError('Unrecognized response during authentication') + return future.failure(error) + + log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username']) + return future.success(True) + + def _try_authenticate_scram(self, future): + if self.config['security_protocol'] == 'SASL_PLAINTEXT': + log.warning('%s: Exchanging credentials in the clear', self) + + scram_client = ScramClient( + self.config['sasl_plain_username'], self.config['sasl_plain_password'], self.config['sasl_mechanism'] + ) + + err = None + close = False + with self._lock: + if not self._can_send_recv(): + err = Errors.NodeNotReadyError(str(self)) + close = False + else: + try: + client_first = scram_client.first_message().encode('utf-8') + size = Int32.encode(len(client_first)) + self._send_bytes_blocking(size + client_first) + + (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4)) + server_first = self._recv_bytes_blocking(data_len).decode('utf-8') + scram_client.process_server_first_message(server_first) + + client_final = scram_client.final_message().encode('utf-8') + size = Int32.encode(len(client_final)) + self._send_bytes_blocking(size + client_final) + + (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4)) + server_final = self._recv_bytes_blocking(data_len).decode('utf-8') + scram_client.process_server_final_message(server_final) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", self) + err = Errors.KafkaConnectionError("%s: %s" % (self, e)) + close = True + + if err is not None: + if close: + self.close(error=err) + return future.failure(err) + + log.info( + '%s: Authenticated as %s via %s', self, self.config['sasl_plain_username'], self.config['sasl_mechanism'] + ) + return future.success(True) + + def _try_authenticate_gssapi(self, future): + kerberos_damin_name = self.config['sasl_kerberos_domain_name'] or self.host + auth_id = self.config['sasl_kerberos_service_name'] + '@' + kerberos_damin_name + gssapi_name = gssapi.Name( + auth_id, + name_type=gssapi.NameType.hostbased_service + ).canonicalize(gssapi.MechType.kerberos) + log.debug('%s: GSSAPI name: %s', self, gssapi_name) + + err = None + close = False + with self._lock: + if not self._can_send_recv(): + err = Errors.NodeNotReadyError(str(self)) + close = False + else: + # Establish security context and negotiate protection level + # For reference RFC 2222, section 
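# Illustrative sketch (not part of the patch) of the SASL PLAIN initial
# response built in _try_authenticate_plain() above (RFC 4616): NUL-separated
# authzid, authcid and password, sent with the usual 4-byte length prefix. The
# code above passes the configured username for both identity fields.
def plain_initial_response(username, password, authzid=None):
    authzid = username if authzid is None else authzid
    return '\0'.join([authzid, username, password]).encode('utf-8')

# plain_initial_response('alice', 's3cr3t')
#   -> b'alice\x00alice\x00s3cr3t'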
7.2.1 + try: + # Exchange tokens until authentication either succeeds or fails + client_ctx = gssapi.SecurityContext(name=gssapi_name, usage='initiate') + received_token = None + while not client_ctx.complete: + # calculate an output token from kafka token (or None if first iteration) + output_token = client_ctx.step(received_token) + + # pass output token to kafka, or send empty response if the security + # context is complete (output token is None in that case) + if output_token is None: + self._send_bytes_blocking(Int32.encode(0)) + else: + msg = output_token + size = Int32.encode(len(msg)) + self._send_bytes_blocking(size + msg) + + # The server will send a token back. Processing of this token either + # establishes a security context, or it needs further token exchange. + # The gssapi will be able to identify the needed next step. + # The connection is closed on failure. + header = self._recv_bytes_blocking(4) + (token_size,) = struct.unpack('>i', header) + received_token = self._recv_bytes_blocking(token_size) + + # Process the security layer negotiation token, sent by the server + # once the security context is established. + + # unwraps message containing supported protection levels and msg size + msg = client_ctx.unwrap(received_token).message + # Kafka currently doesn't support integrity or confidentiality security layers, so we + # simply set QoP to 'auth' only (first octet). We reuse the max message size proposed + # by the server + msg = Int8.encode(SASL_QOP_AUTH & Int8.decode(io.BytesIO(msg[0:1]))) + msg[1:] + # add authorization identity to the response, GSS-wrap and send it + msg = client_ctx.wrap(msg + auth_id.encode(), False).message + size = Int32.encode(len(msg)) + self._send_bytes_blocking(size + msg) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", self) + err = Errors.KafkaConnectionError("%s: %s" % (self, e)) + close = True + except Exception as e: + err = e + close = True + + if err is not None: + if close: + self.close(error=err) + return future.failure(err) + + log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name) + return future.success(True) + + def _try_authenticate_oauth(self, future): + data = b'' + + msg = bytes(self._build_oauth_client_request().encode("utf-8")) + size = Int32.encode(len(msg)) + + err = None + close = False + with self._lock: + if not self._can_send_recv(): + err = Errors.NodeNotReadyError(str(self)) + close = False + else: + try: + # Send SASL OAuthBearer request with OAuth token + self._send_bytes_blocking(size + msg) + + # The server will send a zero sized message (that is Int32(0)) on success. 
+ # The connection is closed on failure + data = self._recv_bytes_blocking(4) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", self) + err = Errors.KafkaConnectionError("%s: %s" % (self, e)) + close = True + + if err is not None: + if close: + self.close(error=err) + return future.failure(err) + + if data != b'\x00\x00\x00\x00': + error = Errors.AuthenticationFailedError('Unrecognized response during authentication') + return future.failure(error) + + log.info('%s: Authenticated via OAuth', self) + return future.success(True) + + def _build_oauth_client_request(self): + token_provider = self.config['sasl_oauth_token_provider'] + return "n,,\x01auth=Bearer {}{}\x01\x01".format(token_provider.token(), self._token_extensions()) + + def _token_extensions(self): + """ + Return a string representation of the OPTIONAL key-value pairs that can be sent with an OAUTHBEARER + initial request. + """ + token_provider = self.config['sasl_oauth_token_provider'] + + # Only run if the #extensions() method is implemented by the clients Token Provider class + # Builds up a string separated by \x01 via a dict of key value pairs + if callable(getattr(token_provider, "extensions", None)) and len(token_provider.extensions()) > 0: + msg = "\x01".join(["{}={}".format(k, v) for k, v in token_provider.extensions().items()]) + return "\x01" + msg else: - return future.failure(Errors.SaslAuthenticationFailedError('Failed to authenticate via SASL %s' % self.config['sasl_mechanism'])) + return "" def blacked_out(self): """ @@ -830,43 +844,20 @@ class BrokerConnection(object): re-establish a connection yet """ if self.state is ConnectionStates.DISCONNECTED: - return self.connection_delay() > 0 + if time.time() < self.last_attempt + self._reconnect_backoff: + return True return False - def throttled(self): - """ - Return True if we are connected but currently throttled. - """ - if self.state is not ConnectionStates.CONNECTED: - return False - return self.throttle_delay() > 0 - - def throttle_delay(self): - """ - Return the number of milliseconds to wait until connection is no longer throttled. - """ - if self._throttle_time is not None: - remaining_ms = (self._throttle_time - time.time()) * 1000 - if remaining_ms > 0: - return remaining_ms - else: - self._throttle_time = None - return 0 - return 0 - def connection_delay(self): """ Return the number of milliseconds to wait, based on the connection - state, before attempting to send data. When connecting or disconnected, - this respects the reconnect backoff time. When connected, returns a very + state, before attempting to send data. When disconnected, this respects + the reconnect backoff time. When connecting or connected, returns a very large number to handle slow/stalled connections. 
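For reference, the OAUTHBEARER initial response assembled by _build_oauth_client_request follows the SASL/GS2 layout: the GS2 header "n,," (no channel binding, no authorization identity), then \x01-separated key=value pairs beginning with auth=Bearer <token>, terminated by \x01\x01. A sketch assuming only a token provider object that exposes token() and an optional extensions() dict:

def build_oauthbearer_request(token_provider):
    extensions = ''
    ext_fn = getattr(token_provider, 'extensions', None)
    if callable(ext_fn) and ext_fn():
        # Optional extensions are appended as additional \x01-separated key=value pairs.
        extensions = '\x01' + '\x01'.join(
            '{}={}'.format(k, v) for k, v in ext_fn().items())
    return 'n,,\x01auth=Bearer {}{}\x01\x01'.format(token_provider.token(), extensions)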
""" - if self.disconnected() or self.connecting(): - if len(self._gai) > 0: - return 0 - else: - time_waited = time.time() - self.last_attempt - return max(self._reconnect_backoff - time_waited, 0) * 1000 + time_waited = time.time() - (self.last_attempt or 0) + if self.state is ConnectionStates.DISCONNECTED: + return max(self._reconnect_backoff - time_waited, 0) * 1000 else: # When connecting or connected, we should be able to delay # indefinitely since other events (connection or data acked) will @@ -882,33 +873,16 @@ class BrokerConnection(object): different states, such as SSL handshake, authorization, etc).""" return self.state in (ConnectionStates.CONNECTING, ConnectionStates.HANDSHAKE, - ConnectionStates.AUTHENTICATING, - ConnectionStates.API_VERSIONS_SEND, - ConnectionStates.API_VERSIONS_RECV) - - def initializing(self): - """Returns True if socket is connected but full connection is not complete. - During this time the connection may send api requests to the broker to - check api versions and perform SASL authentication.""" - return self.state in (ConnectionStates.AUTHENTICATING, - ConnectionStates.API_VERSIONS_SEND, - ConnectionStates.API_VERSIONS_RECV) + ConnectionStates.AUTHENTICATING) def disconnected(self): """Return True iff socket is closed""" return self.state is ConnectionStates.DISCONNECTED - def connect_failed(self): - """Return True iff connection attempt failed after attempting all dns records""" - return self.disconnected() and self.last_attempt >= 0 and len(self._gai) == 0 - def _reset_reconnect_backoff(self): self._failures = 0 self._reconnect_backoff = self.config['reconnect_backoff_ms'] / 1000.0 - def _reconnect_jitter_pct(self): - return uniform(0.8, 1.2) - def _update_reconnect_backoff(self): # Do not mark as failure if there are more dns entries available to try if len(self._gai) > 0: @@ -917,7 +891,7 @@ class BrokerConnection(object): self._failures += 1 self._reconnect_backoff = self.config['reconnect_backoff_ms'] * 2 ** (self._failures - 1) self._reconnect_backoff = min(self._reconnect_backoff, self.config['reconnect_backoff_max_ms']) - self._reconnect_backoff *= self._reconnect_jitter_pct() + self._reconnect_backoff *= uniform(0.8, 1.2) self._reconnect_backoff /= 1000.0 log.debug('%s: reconnect backoff %s after %s failures', self, self._reconnect_backoff, self._failures) @@ -942,12 +916,9 @@ class BrokerConnection(object): with self._lock: if self.state is ConnectionStates.DISCONNECTED: return - log.log(logging.ERROR if error else logging.INFO, '%s: Closing connection. %s', self, error or '') - if error: - self._update_reconnect_backoff() - self._api_versions_future = None + log.info('%s: Closing connection. 
%s', self, error or '') + self._update_reconnect_backoff() self._sasl_auth_future = None - self._init_sasl_mechanism() self._protocol = KafkaProtocol( client_id=self.config['client_id'], api_version=self.config['api_version']) @@ -967,43 +938,27 @@ class BrokerConnection(object): # drop lock before state change callback and processing futures self.config['state_change_callback'](self.node_id, sock, self) - if sock: - sock.close() - for (_correlation_id, (future, _timestamp, _timeout)) in ifrs: + sock.close() + for (_correlation_id, (future, _timestamp)) in ifrs: future.failure(error) def _can_send_recv(self): """Return True iff socket is ready for requests / responses""" - return self.connected() or self.initializing() + return self.state in (ConnectionStates.AUTHENTICATING, + ConnectionStates.CONNECTED) - def send(self, request, blocking=True, request_timeout_ms=None): - """Queue request for async network send, return Future() - - Arguments: - request (Request): kafka protocol request object to send. - - Keyword Arguments: - blocking (bool, optional): Whether to immediately send via - blocking socket I/O. Default: True. - request_timeout_ms: Custom timeout in milliseconds for request. - Default: None (uses value from connection configuration) - - Returns: future - """ + def send(self, request, blocking=True): + """Queue request for async network send, return Future()""" future = Future() if self.connecting(): return future.failure(Errors.NodeNotReadyError(str(self))) elif not self.connected(): return future.failure(Errors.KafkaConnectionError(str(self))) elif not self.can_send_more(): - # very small race here, but prefer it over breaking abstraction to check self._throttle_time - if self.throttled(): - return future.failure(Errors.ThrottlingQuotaExceededError(str(self))) return future.failure(Errors.TooManyInFlightRequests(str(self))) - return self._send(request, blocking=blocking, request_timeout_ms=request_timeout_ms) + return self._send(request, blocking=blocking) - def _send(self, request, blocking=True, request_timeout_ms=None): - request_timeout_ms = request_timeout_ms or self.config['request_timeout_ms'] + def _send(self, request, blocking=True): future = Future() with self._lock: if not self._can_send_recv(): @@ -1014,12 +969,11 @@ class BrokerConnection(object): correlation_id = self._protocol.send_request(request) - log.debug('%s: Request %d (timeout_ms %s): %s', self, correlation_id, request_timeout_ms, request) + log.debug('%s Request %d: %s', self, correlation_id, request) if request.expect_response(): - assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!' sent_time = time.time() - timeout_at = sent_time + (request_timeout_ms / 1000) - self.in_flight_requests[correlation_id] = (future, sent_time, timeout_at) + assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!' 
+ self.in_flight_requests[correlation_id] = (future, sent_time) else: future.success(None) @@ -1048,7 +1002,7 @@ class BrokerConnection(object): return True except (ConnectionError, TimeoutError) as e: - log.exception("%s: Error sending request data", self) + log.exception("Error sending request data to %s", self) error = Errors.KafkaConnectionError("%s: %s" % (self, e)) self.close(error=error) return False @@ -1081,31 +1035,13 @@ class BrokerConnection(object): return len(self._send_buffer) == 0 except (ConnectionError, TimeoutError, Exception) as e: - log.exception("%s: Error sending request data", self) + log.exception("Error sending request data to %s", self) error = Errors.KafkaConnectionError("%s: %s" % (self, e)) self.close(error=error) return False - def _maybe_throttle(self, response): - throttle_time_ms = getattr(response, 'throttle_time_ms', 0) - if self._sensors: - self._sensors.throttle_time.record(throttle_time_ms) - if not throttle_time_ms: - if self._throttle_time is not None: - self._throttle_time = None - return - # Client side throttling enabled in v2.0 brokers - # prior to that throttling (if present) was managed broker-side - if self.config['api_version'] is not None and self.config['api_version'] >= (2, 0): - throttle_time = time.time() + throttle_time_ms / 1000 - self._throttle_time = max(throttle_time, self._throttle_time or 0) - log.warning("%s: %s throttled by broker (%d ms)", self, - response.__class__.__name__, throttle_time_ms) - def can_send_more(self): - """Check for throttling / quota violations and max in-flight-requests""" - if self.throttle_delay() > 0: - return False + """Return True unless there are max_in_flight_requests_per_connection.""" max_ifrs = self.config['max_in_flight_requests_per_connection'] return len(self.in_flight_requests) < max_ifrs @@ -1116,20 +1052,18 @@ class BrokerConnection(object): """ responses = self._recv() if not responses and self.requests_timed_out(): - timed_out = self.timed_out_ifrs() - timeout_ms = (timed_out[0][2] - timed_out[0][1]) * 1000 - log.warning('%s: timed out after %s ms. Closing connection.', - self, timeout_ms) + log.warning('%s timed out after %s ms. 
Closing connection.', + self, self.config['request_timeout_ms']) self.close(error=Errors.RequestTimedOutError( 'Request timed out after %s ms' % - timeout_ms)) + self.config['request_timeout_ms'])) return () # augment responses w/ correlation_id, future, and timestamp for i, (correlation_id, response) in enumerate(responses): try: with self._lock: - (future, timestamp, _timeout) = self.in_flight_requests.pop(correlation_id) + (future, timestamp) = self.in_flight_requests.pop(correlation_id) except KeyError: self.close(Errors.KafkaConnectionError('Received unrecognized correlation id')) return () @@ -1137,8 +1071,7 @@ class BrokerConnection(object): if self._sensors: self._sensors.request_time.record(latency_ms) - log.debug('%s: Response %d (%s ms): %s', self, correlation_id, latency_ms, response) - self._maybe_throttle(response) + log.debug('%s Response %d (%s ms): %s', self, correlation_id, latency_ms, response) responses[i] = (response, future) return responses @@ -1149,7 +1082,7 @@ class BrokerConnection(object): err = None with self._lock: if not self._can_send_recv(): - log.warning('%s: cannot recv: socket not connected', self) + log.warning('%s cannot recv: socket not connected', self) return () while len(recvd) < self.config['sock_chunk_buffer_count']: @@ -1199,30 +1132,36 @@ class BrokerConnection(object): return () def requests_timed_out(self): - return self.next_ifr_request_timeout_ms() == 0 - - def timed_out_ifrs(self): - now = time.time() - ifrs = sorted(self.in_flight_requests.values(), reverse=True, key=lambda ifr: ifr[2]) - return list(filter(lambda ifr: ifr[2] <= now, ifrs)) - - def next_ifr_request_timeout_ms(self): with self._lock: if self.in_flight_requests: - def get_timeout(v): - return v[2] - next_timeout = min(map(get_timeout, - self.in_flight_requests.values())) - return max(0, (next_timeout - time.time()) * 1000) - else: - return float('inf') + get_timestamp = lambda v: v[1] + oldest_at = min(map(get_timestamp, + self.in_flight_requests.values())) + timeout = self.config['request_timeout_ms'] / 1000.0 + if time.time() >= oldest_at + timeout: + return True + return False + + def _handle_api_version_response(self, response): + error_type = Errors.for_code(response.error_code) + assert error_type is Errors.NoError, "API version check failed" + self._api_versions = dict([ + (api_key, (min_version, max_version)) + for api_key, min_version, max_version in response.api_versions + ]) + return self._api_versions def get_api_versions(self): - # _api_versions is set as a side effect of first connection - # which should typically be bootstrap, but call check_version - # if that hasn't happened yet - if self._api_versions is None: - self.check_version() + if self._api_versions is not None: + return self._api_versions + + version = self.check_version() + if version < (0, 10, 0): + raise Errors.UnsupportedVersionError( + "ApiVersion not supported by cluster version {} < 0.10.0" + .format(version)) + # _api_versions is set as a side effect of check_versions() on a cluster + # that supports 0.10.0 or later return self._api_versions def _infer_broker_version_from_api_versions(self, api_versions): @@ -1230,69 +1169,139 @@ class BrokerConnection(object): # in reverse order. As soon as we find one that works, return it test_cases = [ # format (, ) - # Make sure to update consumer_integration test check when adding newer versions. 
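The in-flight timeout check above reduces to one comparison: take the oldest sent timestamp and test it against request_timeout_ms. A sketch over a plain dict shaped like in_flight_requests (correlation_id -> (future, sent_time)):

import time

def requests_timed_out(in_flight, request_timeout_ms):
    if not in_flight:
        return False
    oldest_sent = min(sent for _future, sent in in_flight.values())
    return time.time() >= oldest_sent + request_timeout_ms / 1000.0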
- # ((3, 9), FetchRequest[17]), - # ((3, 8), ProduceRequest[11]), - # ((3, 7), FetchRequest[16]), - # ((3, 6), AddPartitionsToTxnRequest[4]), - # ((3, 5), FetchRequest[15]), - # ((3, 4), StopReplicaRequest[3]), # broker-internal api... - # ((3, 3), DescribeAclsRequest[3]), - # ((3, 2), JoinGroupRequest[9]), - # ((3, 1), FetchRequest[13]), - # ((3, 0), ListOffsetsRequest[7]), - # ((2, 8), ProduceRequest[9]), - # ((2, 7), FetchRequest[12]), - # ((2, 6), ListGroupsRequest[4]), - # ((2, 5), JoinGroupRequest[7]), - ((2, 6), DescribeClientQuotasRequest[0]), - ((2, 5), DescribeAclsRequest[2]), - ((2, 4), ProduceRequest[8]), - ((2, 3), FetchRequest[11]), - ((2, 2), ListOffsetsRequest[5]), - ((2, 1), FetchRequest[10]), - ((2, 0), FetchRequest[8]), - ((1, 1), FetchRequest[7]), - ((1, 0), MetadataRequest[5]), - ((0, 11), MetadataRequest[4]), + ((2, 5, 0), DescribeAclsRequest_v2), + ((2, 4, 0), ProduceRequest[8]), + ((2, 3, 0), FetchRequest[11]), + ((2, 2, 0), OffsetRequest[5]), + ((2, 1, 0), FetchRequest[10]), + ((2, 0, 0), FetchRequest[8]), + ((1, 1, 0), FetchRequest[7]), + ((1, 0, 0), MetadataRequest[5]), + ((0, 11, 0), MetadataRequest[4]), ((0, 10, 2), OffsetFetchRequest[2]), ((0, 10, 1), MetadataRequest[2]), ] # Get the best match of test cases - for broker_version, proto_struct in sorted(test_cases, reverse=True): - if proto_struct.API_KEY not in api_versions: + for broker_version, struct in sorted(test_cases, reverse=True): + if struct.API_KEY not in api_versions: continue - min_version, max_version = api_versions[proto_struct.API_KEY] - if min_version <= proto_struct.API_VERSION <= max_version: + min_version, max_version = api_versions[struct.API_KEY] + if min_version <= struct.API_VERSION <= max_version: return broker_version - # We know that ApiVersionsResponse is only supported in 0.10+ + # We know that ApiVersionResponse is only supported in 0.10+ # so if all else fails, choose that return (0, 10, 0) - def check_version(self, timeout=2, **kwargs): + def check_version(self, timeout=2, strict=False, topics=[]): """Attempt to guess the broker version. - Keyword Arguments: - timeout (numeric, optional): Maximum number of seconds to block attempting - to connect and check version. Default 2 - Note: This is a blocking call. - Returns: version tuple, i.e. (3, 9), (2, 4), etc ... - - Raises: NodeNotReadyError on timeout + Returns: version tuple, i.e. (0, 10), (0, 9), (0, 8, 2), ... """ timeout_at = time.time() + timeout - if not self.connect_blocking(timeout_at - time.time()): - raise Errors.NodeNotReadyError() + log.info('Probing node %s broker version', self.node_id) + # Monkeypatch some connection configurations to avoid timeouts + override_config = { + 'request_timeout_ms': timeout * 1000, + 'max_in_flight_requests_per_connection': 5 + } + stashed = {} + for key in override_config: + stashed[key] = self.config[key] + self.config[key] = override_config[key] + + def reset_override_configs(): + for key in stashed: + self.config[key] = stashed[key] + + # kafka kills the connection when it doesn't recognize an API request + # so we can send a test request and then follow immediately with a + # vanilla MetadataRequest. 
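The inference table is consulted newest-first: the first broker version whose probe request version falls inside the advertised (min, max) range for that API key wins, and ApiVersions support itself implies at least 0.10.0. A reduced sketch over plain tuples; the probe entries and API keys below are illustrative only:

def infer_broker_version(api_versions, probes):
    """api_versions: {api_key: (min_version, max_version)} from an ApiVersions response.
    probes: [(broker_version, api_key, api_version), ...]"""
    for broker_version, api_key, api_version in sorted(probes, reverse=True):
        if api_key not in api_versions:
            continue
        min_v, max_v = api_versions[api_key]
        if min_v <= api_version <= max_v:
            return broker_version
    # ApiVersions is only supported by 0.10+, so that is the floor.
    return (0, 10, 0)

# Example (API key 1 = Fetch, 3 = Metadata; values illustrative):
probes = [((2, 3, 0), 1, 11), ((1, 0, 0), 3, 5), ((0, 10, 1), 3, 2)]
print(infer_broker_version({1: (0, 11), 3: (0, 7)}, probes))  # -> (2, 3, 0)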
If the server did not recognize the first + # request, both will be failed with a ConnectionError that wraps + # socket.error (32, 54, or 104) + from kafka.protocol.admin import ApiVersionRequest, ListGroupsRequest + from kafka.protocol.commit import OffsetFetchRequest, GroupCoordinatorRequest + + test_cases = [ + # All cases starting from 0.10 will be based on ApiVersionResponse + ((0, 10), ApiVersionRequest[0]()), + ((0, 9), ListGroupsRequest[0]()), + ((0, 8, 2), GroupCoordinatorRequest[0]('kafka-python-default-group')), + ((0, 8, 1), OffsetFetchRequest[0]('kafka-python-default-group', [])), + ((0, 8, 0), MetadataRequest[0](topics)), + ] + + for version, request in test_cases: + if not self.connect_blocking(timeout_at - time.time()): + reset_override_configs() + raise Errors.NodeNotReadyError() + f = self.send(request) + # HACK: sleeping to wait for socket to send bytes + time.sleep(0.1) + # when broker receives an unrecognized request API + # it abruptly closes our socket. + # so we attempt to send a second request immediately + # that we believe it will definitely recognize (metadata) + # the attempt to write to a disconnected socket should + # immediately fail and allow us to infer that the prior + # request was unrecognized + mr = self.send(MetadataRequest[0](topics)) + + selector = self.config['selector']() + selector.register(self._sock, selectors.EVENT_READ) + while not (f.is_done and mr.is_done): + selector.select(1) + for response, future in self.recv(): + future.success(response) + selector.close() + + if f.succeeded(): + if isinstance(request, ApiVersionRequest[0]): + # Starting from 0.10 kafka broker we determine version + # by looking at ApiVersionResponse + api_versions = self._handle_api_version_response(f.value) + version = self._infer_broker_version_from_api_versions(api_versions) + log.info('Broker version identified as %s', '.'.join(map(str, version))) + log.info('Set configuration api_version=%s to skip auto' + ' check_version requests on startup', version) + break + + # Only enable strict checking to verify that we understand failure + # modes. For most users, the fact that the request failed should be + # enough to rule out a particular broker version. + if strict: + # If the socket flush hack did not work (which should force the + # connection to close and fail all pending requests), then we + # get a basic Request Timeout. This is not ideal, but we'll deal + if isinstance(f.exception, Errors.RequestTimedOutError): + pass + + # 0.9 brokers do not close the socket on unrecognized api + # requests (bug...). 
In this case we expect to see a correlation + # id mismatch + elif (isinstance(f.exception, Errors.CorrelationIdError) and + version == (0, 10)): + pass + elif six.PY2: + assert isinstance(f.exception.args[0], socket.error) + assert f.exception.args[0].errno in (32, 54, 104) + else: + assert isinstance(f.exception.args[0], ConnectionError) + log.info("Broker is not v%s -- it did not recognize %s", + version, request.__class__.__name__) else: - return self._api_version + reset_override_configs() + raise Errors.UnrecognizedBrokerVersion() + + reset_override_configs() + return version def __str__(self): - return "" % ( - self.config['client_id'], self.node_id, self.host, self.port, self.state, + return "" % ( + self.node_id, self.host, self.port, self.state, AFI_NAMES[self._sock_afi], self._sock_addr) @@ -1349,16 +1358,6 @@ class BrokerConnectionMetrics(object): 'The maximum request latency in ms.'), Max()) - throttle_time = metrics.sensor('throttle-time') - throttle_time.add(metrics.metric_name( - 'throttle-time-avg', metric_group_name, - 'The average throttle time in ms.'), - Avg()) - throttle_time.add(metrics.metric_name( - 'throttle-time-max', metric_group_name, - 'The maximum throttle time in ms.'), - Max()) - # if one sensor of the metrics has been registered for the connection, # then all other sensors should have been registered; and vice versa node_str = 'node-{0}'.format(node_id) @@ -1410,23 +1409,9 @@ class BrokerConnectionMetrics(object): 'The maximum request latency in ms.'), Max()) - throttle_time = metrics.sensor( - node_str + '.throttle', - parents=[metrics.get_sensor('throttle-time')]) - throttle_time.add(metrics.metric_name( - 'throttle-time-avg', metric_group_name, - 'The average throttle time in ms.'), - Avg()) - throttle_time.add(metrics.metric_name( - 'throttle-time-max', metric_group_name, - 'The maximum throttle time in ms.'), - Max()) - - self.bytes_sent = metrics.sensor(node_str + '.bytes-sent') self.bytes_received = metrics.sensor(node_str + '.bytes-received') self.request_time = metrics.sensor(node_str + '.latency') - self.throttle_time = metrics.sensor(node_str + '.throttle') def _address_family(address): @@ -1496,6 +1481,32 @@ def get_ip_port_afi(host_and_port_str): return host, port, af +def collect_hosts(hosts, randomize=True): + """ + Collects a comma-separated set of hosts (host:port) and optionally + randomize the returned list. 
+ """ + + if isinstance(hosts, six.string_types): + hosts = hosts.strip().split(',') + + result = [] + afi = socket.AF_INET + for host_port in hosts: + + host, port, afi = get_ip_port_afi(host_port) + + if port < 0: + port = DEFAULT_KAFKA_PORT + + result.append((host, port, afi)) + + if randomize: + shuffle(result) + + return result + + def is_inet_4_or_6(gai): """Given a getaddrinfo struct, return True iff ipv4 or ipv6""" return gai[0] in (socket.AF_INET, socket.AF_INET6) diff --git a/venv/lib/python3.12/site-packages/kafka/consumer/fetcher.py b/venv/lib/python3.12/site-packages/kafka/consumer/fetcher.py index 1689b23..e4f8c18 100644 --- a/venv/lib/python3.12/site-packages/kafka/consumer/fetcher.py +++ b/venv/lib/python3.12/site-packages/kafka/consumer/fetcher.py @@ -1,9 +1,9 @@ -from __future__ import absolute_import, division +from __future__ import absolute_import import collections import copy -import itertools import logging +import random import sys import time @@ -12,14 +12,13 @@ from kafka.vendor import six import kafka.errors as Errors from kafka.future import Future from kafka.metrics.stats import Avg, Count, Max, Rate -from kafka.protocol.fetch import FetchRequest, AbortedTransaction -from kafka.protocol.list_offsets import ( - ListOffsetsRequest, OffsetResetStrategy, UNKNOWN_OFFSET +from kafka.protocol.fetch import FetchRequest +from kafka.protocol.offset import ( + OffsetRequest, OffsetResetStrategy, UNKNOWN_OFFSET ) from kafka.record import MemoryRecords from kafka.serializer import Deserializer -from kafka.structs import TopicPartition, OffsetAndMetadata, OffsetAndTimestamp -from kafka.util import Timer +from kafka.structs import TopicPartition, OffsetAndTimestamp log = logging.getLogger(__name__) @@ -28,13 +27,8 @@ log = logging.getLogger(__name__) READ_UNCOMMITTED = 0 READ_COMMITTED = 1 -ISOLATION_LEVEL_CONFIG = { - 'read_uncommitted': READ_UNCOMMITTED, - 'read_committed': READ_COMMITTED, -} - ConsumerRecord = collections.namedtuple("ConsumerRecord", - ["topic", "partition", "leader_epoch", "offset", "timestamp", "timestamp_type", + ["topic", "partition", "offset", "timestamp", "timestamp_type", "key", "value", "headers", "checksum", "serialized_key_size", "serialized_value_size", "serialized_header_size"]) @@ -43,10 +37,6 @@ CompletedFetch = collections.namedtuple("CompletedFetch", "partition_data", "metric_aggregator"]) -ExceptionMetadata = collections.namedtuple("ExceptionMetadata", - ["partition", "fetched_offset", "exception"]) - - class NoOffsetForPartitionError(Errors.KafkaError): pass @@ -65,15 +55,13 @@ class Fetcher(six.Iterator): 'max_partition_fetch_bytes': 1048576, 'max_poll_records': sys.maxsize, 'check_crcs': True, - 'metrics': None, + 'iterator_refetch_records': 1, # undocumented -- interface may change 'metric_group_prefix': 'consumer', - 'request_timeout_ms': 30000, - 'retry_backoff_ms': 100, - 'enable_incremental_fetch_sessions': True, - 'isolation_level': 'read_uncommitted', + 'api_version': (0, 8, 0), + 'retry_backoff_ms': 100 } - def __init__(self, client, subscriptions, **configs): + def __init__(self, client, subscriptions, metrics, **configs): """Initialize a Kafka Message Fetcher. Keyword Arguments: @@ -81,8 +69,6 @@ class Fetcher(six.Iterator): raw message key and returns a deserialized key. value_deserializer (callable, optional): Any callable that takes a raw message value and returns a deserialized value. - enable_incremental_fetch_sessions: (bool): Use incremental fetch sessions - when available / supported by kafka broker. See KIP-227. 
Default: True. fetch_min_bytes (int): Minimum amount of data the server should return for a fetch request, otherwise wait up to fetch_max_wait_ms for more data to accumulate. Default: 1. @@ -111,33 +97,20 @@ class Fetcher(six.Iterator): consumed. This ensures no on-the-wire or on-disk corruption to the messages occurred. This check adds some overhead, so it may be disabled in cases seeking extreme performance. Default: True - isolation_level (str): Configure KIP-98 transactional consumer by - setting to 'read_committed'. This will cause the consumer to - skip records from aborted tranactions. Default: 'read_uncommitted' """ self.config = copy.copy(self.DEFAULT_CONFIG) for key in self.config: if key in configs: self.config[key] = configs[key] - if self.config['isolation_level'] not in ISOLATION_LEVEL_CONFIG: - raise Errors.KafkaConfigurationError('Unrecognized isolation_level') - self._client = client self._subscriptions = subscriptions self._completed_fetches = collections.deque() # Unparsed responses self._next_partition_records = None # Holds a single PartitionRecords until fully consumed self._iterator = None self._fetch_futures = collections.deque() - if self.config['metrics']: - self._sensors = FetchManagerMetrics(self.config['metrics'], self.config['metric_group_prefix']) - else: - self._sensors = None - self._isolation_level = ISOLATION_LEVEL_CONFIG[self.config['isolation_level']] - self._session_handlers = {} - self._nodes_with_pending_fetch_requests = set() - self._cached_list_offsets_exception = None - self._next_in_line_exception_metadata = None + self._sensors = FetchManagerMetrics(metrics, self.config['metric_group_prefix']) + self._isolation_level = READ_UNCOMMITTED def send_fetches(self): """Send FetchRequests for all assigned partitions that do not already have @@ -147,18 +120,29 @@ class Fetcher(six.Iterator): List of Futures: each future resolves to a FetchResponse """ futures = [] - for node_id, (request, fetch_offsets) in six.iteritems(self._create_fetch_requests()): - log.debug("Sending FetchRequest to node %s", node_id) - self._nodes_with_pending_fetch_requests.add(node_id) - future = self._client.send(node_id, request, wakeup=False) - future.add_callback(self._handle_fetch_response, node_id, fetch_offsets, time.time()) - future.add_errback(self._handle_fetch_error, node_id) - future.add_both(self._clear_pending_fetch_request, node_id) - futures.append(future) + for node_id, request in six.iteritems(self._create_fetch_requests()): + if self._client.ready(node_id): + log.debug("Sending FetchRequest to node %s", node_id) + future = self._client.send(node_id, request, wakeup=False) + future.add_callback(self._handle_fetch_response, request, time.time()) + future.add_errback(log.error, 'Fetch to node %s failed: %s', node_id) + futures.append(future) self._fetch_futures.extend(futures) self._clean_done_fetch_futures() return futures + def reset_offsets_if_needed(self, partitions): + """Lookup and set offsets for any partitions which are awaiting an + explicit reset. 
+ + Arguments: + partitions (set of TopicPartitions): the partitions to reset + """ + for tp in partitions: + # TODO: If there are several offsets to reset, we could submit offset requests in parallel + if self._subscriptions.is_assigned(tp) and self._subscriptions.is_offset_reset_needed(tp): + self._reset_offset(tp) + def _clean_done_fetch_futures(self): while True: if not self._fetch_futures: @@ -172,109 +156,49 @@ class Fetcher(six.Iterator): self._clean_done_fetch_futures() return bool(self._fetch_futures) - def reset_offsets_if_needed(self): - """Reset offsets for the given partitions using the offset reset strategy. + def update_fetch_positions(self, partitions): + """Update the fetch positions for the provided partitions. Arguments: - partitions ([TopicPartition]): the partitions that need offsets reset - - Returns: - bool: True if any partitions need reset; otherwise False (no reset pending) + partitions (list of TopicPartitions): partitions to update Raises: - NoOffsetForPartitionError: if no offset reset strategy is defined - KafkaTimeoutError if timeout_ms provided + NoOffsetForPartitionError: if no offset is stored for a given + partition and no reset policy is available """ - # Raise exception from previous offset fetch if there is one - exc, self._cached_list_offsets_exception = self._cached_list_offsets_exception, None - if exc: - raise exc - - partitions = self._subscriptions.partitions_needing_reset() - if not partitions: - return False - log.debug('Resetting offsets for %s', partitions) - - offset_resets = dict() + # reset the fetch position to the committed position for tp in partitions: - ts = self._subscriptions.assignment[tp].reset_strategy - if ts: - offset_resets[tp] = ts + if not self._subscriptions.is_assigned(tp): + log.warning("partition %s is not assigned - skipping offset" + " update", tp) + continue + elif self._subscriptions.is_fetchable(tp): + log.warning("partition %s is still fetchable -- skipping offset" + " update", tp) + continue - self._reset_offsets_async(offset_resets) - return True + if self._subscriptions.is_offset_reset_needed(tp): + self._reset_offset(tp) + elif self._subscriptions.assignment[tp].committed is None: + # there's no committed position, so we need to reset with the + # default strategy + self._subscriptions.need_offset_reset(tp) + self._reset_offset(tp) + else: + committed = self._subscriptions.assignment[tp].committed.offset + log.debug("Resetting offset for partition %s to the committed" + " offset %s", tp, committed) + self._subscriptions.seek(tp, committed) - def offsets_by_times(self, timestamps, timeout_ms=None): - """Fetch offset for each partition passed in ``timestamps`` map. - - Blocks until offsets are obtained, a non-retriable exception is raised - or ``timeout_ms`` passed. - - Arguments: - timestamps: {TopicPartition: int} dict with timestamps to fetch - offsets by. -1 for the latest available, -2 for the earliest - available. Otherwise timestamp is treated as epoch milliseconds. - timeout_ms (int, optional): The maximum time in milliseconds to block. - - Returns: - {TopicPartition: OffsetAndTimestamp}: Mapping of partition to - retrieved offset, timestamp, and leader_epoch. If offset does not exist for - the provided timestamp, that partition will be missing from - this mapping. 
- - Raises: - KafkaTimeoutError if timeout_ms provided - """ - offsets = self._fetch_offsets_by_times(timestamps, timeout_ms) + def get_offsets_by_times(self, timestamps, timeout_ms): + offsets = self._retrieve_offsets(timestamps, timeout_ms) for tp in timestamps: if tp not in offsets: offsets[tp] = None - return offsets - - def _fetch_offsets_by_times(self, timestamps, timeout_ms=None): - if not timestamps: - return {} - - timer = Timer(timeout_ms, "Failed to get offsets by timestamps in %s ms" % (timeout_ms,)) - timestamps = copy.copy(timestamps) - fetched_offsets = dict() - while True: - if not timestamps: - return {} - - future = self._send_list_offsets_requests(timestamps) - self._client.poll(future=future, timeout_ms=timer.timeout_ms) - - # Timeout w/o future completion - if not future.is_done: - break - - if future.succeeded(): - fetched_offsets.update(future.value[0]) - if not future.value[1]: - return fetched_offsets - - timestamps = {tp: timestamps[tp] for tp in future.value[1]} - - elif not future.retriable(): - raise future.exception # pylint: disable-msg=raising-bad-type - - if future.exception.invalid_metadata or self._client.cluster.need_update: - refresh_future = self._client.cluster.request_update() - self._client.poll(future=refresh_future, timeout_ms=timer.timeout_ms) - - if not future.is_done: - break else: - if timer.timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']: - time.sleep(self.config['retry_backoff_ms'] / 1000) - else: - time.sleep(timer.timeout_ms / 1000) - - timer.maybe_raise() - - raise Errors.KafkaTimeoutError( - "Failed to get offsets by timestamps in %s ms" % (timeout_ms,)) + offset, timestamp = offsets[tp] + offsets[tp] = OffsetAndTimestamp(offset, timestamp) + return offsets def beginning_offsets(self, partitions, timeout_ms): return self.beginning_or_end_offset( @@ -286,11 +210,103 @@ class Fetcher(six.Iterator): def beginning_or_end_offset(self, partitions, timestamp, timeout_ms): timestamps = dict([(tp, timestamp) for tp in partitions]) - offsets = self._fetch_offsets_by_times(timestamps, timeout_ms) + offsets = self._retrieve_offsets(timestamps, timeout_ms) for tp in timestamps: - offsets[tp] = offsets[tp].offset + offsets[tp] = offsets[tp][0] return offsets + def _reset_offset(self, partition): + """Reset offsets for the given partition using the offset reset strategy. + + Arguments: + partition (TopicPartition): the partition that needs reset offset + + Raises: + NoOffsetForPartitionError: if no offset reset strategy is defined + """ + timestamp = self._subscriptions.assignment[partition].reset_strategy + if timestamp is OffsetResetStrategy.EARLIEST: + strategy = 'earliest' + elif timestamp is OffsetResetStrategy.LATEST: + strategy = 'latest' + else: + raise NoOffsetForPartitionError(partition) + + log.debug("Resetting offset for partition %s to %s offset.", + partition, strategy) + offsets = self._retrieve_offsets({partition: timestamp}) + + if partition in offsets: + offset = offsets[partition][0] + + # we might lose the assignment while fetching the offset, + # so check it is still active + if self._subscriptions.is_assigned(partition): + self._subscriptions.seek(partition, offset) + else: + log.debug("Could not find offset for partition %s since it is probably deleted" % (partition,)) + + def _retrieve_offsets(self, timestamps, timeout_ms=float("inf")): + """Fetch offset for each partition passed in ``timestamps`` map. + + Blocks until offsets are obtained, a non-retriable exception is raised + or ``timeout_ms`` passed. 
+ + Arguments: + timestamps: {TopicPartition: int} dict with timestamps to fetch + offsets by. -1 for the latest available, -2 for the earliest + available. Otherwise timestamp is treated as epoch milliseconds. + + Returns: + {TopicPartition: (int, int)}: Mapping of partition to + retrieved offset and timestamp. If offset does not exist for + the provided timestamp, that partition will be missing from + this mapping. + """ + if not timestamps: + return {} + + start_time = time.time() + remaining_ms = timeout_ms + timestamps = copy.copy(timestamps) + while remaining_ms > 0: + if not timestamps: + return {} + + future = self._send_offset_requests(timestamps) + self._client.poll(future=future, timeout_ms=remaining_ms) + + if future.succeeded(): + return future.value + if not future.retriable(): + raise future.exception # pylint: disable-msg=raising-bad-type + + elapsed_ms = (time.time() - start_time) * 1000 + remaining_ms = timeout_ms - elapsed_ms + if remaining_ms < 0: + break + + if future.exception.invalid_metadata: + refresh_future = self._client.cluster.request_update() + self._client.poll(future=refresh_future, timeout_ms=remaining_ms) + + # Issue #1780 + # Recheck partition existence after after a successful metadata refresh + if refresh_future.succeeded() and isinstance(future.exception, Errors.StaleMetadata): + log.debug("Stale metadata was raised, and we now have an updated metadata. Rechecking partition existence") + unknown_partition = future.exception.args[0] # TopicPartition from StaleMetadata + if self._client.cluster.leader_for_partition(unknown_partition) is None: + log.debug("Removed partition %s from offsets retrieval" % (unknown_partition, )) + timestamps.pop(unknown_partition) + else: + time.sleep(self.config['retry_backoff_ms'] / 1000.0) + + elapsed_ms = (time.time() - start_time) * 1000 + remaining_ms = timeout_ms - elapsed_ms + + raise Errors.KafkaTimeoutError( + "Failed to get offsets by timestamps in %s ms" % (timeout_ms,)) + def fetched_records(self, max_records=None, update_offsets=True): """Returns previously fetched records and updates consumed offsets. 
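The retry loop in _retrieve_offsets is plain wall-clock bookkeeping: recompute the remaining budget before every attempt, back off between retries, and give up when the budget is spent. A stripped-down sketch of just that control flow, with the real request/poll cycle replaced by a caller-supplied attempt(remaining_ms) callable:

import time

def retry_with_deadline(attempt, timeout_ms, retry_backoff_ms=100):
    start = time.time()
    remaining_ms = timeout_ms
    while remaining_ms > 0:
        result = attempt(remaining_ms)
        if result is not None:
            return result
        time.sleep(retry_backoff_ms / 1000.0)  # back off before the next attempt
        remaining_ms = timeout_ms - (time.time() - start) * 1000
    raise TimeoutError('Failed to get offsets within %s ms' % timeout_ms)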
@@ -300,7 +316,7 @@ class Fetcher(six.Iterator): Raises: OffsetOutOfRangeError: if no subscription offset_reset_strategy - CorruptRecordError: if message crc validation fails (check_crcs + CorruptRecordException: if message crc validation fails (check_crcs must be set to True) RecordTooLargeError: if a message is larger than the currently configured max_partition_fetch_bytes @@ -317,40 +333,20 @@ class Fetcher(six.Iterator): max_records = self.config['max_poll_records'] assert max_records > 0 - if self._next_in_line_exception_metadata is not None: - exc_meta = self._next_in_line_exception_metadata - self._next_in_line_exception_metadata = None - tp = exc_meta.partition - if self._subscriptions.is_fetchable(tp) and self._subscriptions.position(tp).offset == exc_meta.fetched_offset: - raise exc_meta.exception - drained = collections.defaultdict(list) records_remaining = max_records - # Needed to construct ExceptionMetadata if any exception is found when processing completed_fetch - fetched_partition = None - fetched_offset = -1 - try: - while records_remaining > 0: - if not self._next_partition_records: - if not self._completed_fetches: - break - completion = self._completed_fetches.popleft() - fetched_partition = completion.topic_partition - fetched_offset = completion.fetched_offset - self._next_partition_records = self._parse_fetched_data(completion) - else: - fetched_partition = self._next_partition_records.topic_partition - fetched_offset = self._next_partition_records.next_fetch_offset - records_remaining -= self._append(drained, - self._next_partition_records, - records_remaining, - update_offsets) - except Exception as e: - if not drained: - raise e - # To be thrown in the next call of this method - self._next_in_line_exception_metadata = ExceptionMetadata(fetched_partition, fetched_offset, e) + while records_remaining > 0: + if not self._next_partition_records: + if not self._completed_fetches: + break + completion = self._completed_fetches.popleft() + self._next_partition_records = self._parse_fetched_data(completion) + else: + records_remaining -= self._append(drained, + self._next_partition_records, + records_remaining, + update_offsets) return dict(drained), bool(self._completed_fetches) def _append(self, drained, part, max_records, update_offsets): @@ -358,101 +354,163 @@ class Fetcher(six.Iterator): return 0 tp = part.topic_partition + fetch_offset = part.fetch_offset if not self._subscriptions.is_assigned(tp): # this can happen when a rebalance happened before # fetched records are returned to the consumer's poll call log.debug("Not returning fetched records for partition %s" " since it is no longer assigned", tp) - elif not self._subscriptions.is_fetchable(tp): - # this can happen when a partition is paused before - # fetched records are returned to the consumer's poll call - log.debug("Not returning fetched records for assigned partition" - " %s since it is no longer fetchable", tp) - else: # note that the position should always be available # as long as the partition is still assigned position = self._subscriptions.assignment[tp].position - if part.next_fetch_offset == position.offset: - log.debug("Returning fetched records at offset %d for assigned" - " partition %s", position.offset, tp) + if not self._subscriptions.is_fetchable(tp): + # this can happen when a partition is paused before + # fetched records are returned to the consumer's poll call + log.debug("Not returning fetched records for assigned partition" + " %s since it is no longer fetchable", tp) + + elif 
fetch_offset == position: + # we are ensured to have at least one record since we already checked for emptiness part_records = part.take(max_records) - # list.extend([]) is a noop, but because drained is a defaultdict - # we should avoid initializing the default list unless there are records - if part_records: - drained[tp].extend(part_records) - # We want to increment subscription position if (1) we're using consumer.poll(), - # or (2) we didn't return any records (consumer iterator will update position - # when each message is yielded). There may be edge cases where we re-fetch records - # that we'll end up skipping, but for now we'll live with that. - highwater = self._subscriptions.assignment[tp].highwater - if highwater is not None and self._sensors: - self._sensors.records_fetch_lag.record(highwater - part.next_fetch_offset) - if update_offsets or not part_records: - # TODO: save leader_epoch - log.debug("Updating fetch position for assigned partition %s to %s (leader epoch %s)", - tp, part.next_fetch_offset, part.leader_epoch) - self._subscriptions.assignment[tp].position = OffsetAndMetadata(part.next_fetch_offset, '', -1) + next_offset = part_records[-1].offset + 1 + + log.log(0, "Returning fetched records at offset %d for assigned" + " partition %s and update position to %s", position, + tp, next_offset) + + for record in part_records: + drained[tp].append(record) + + if update_offsets: + self._subscriptions.assignment[tp].position = next_offset return len(part_records) else: # these records aren't next in line based on the last consumed # position, ignore them they must be from an obsolete request log.debug("Ignoring fetched records for %s at offset %s since" - " the current position is %d", tp, part.next_fetch_offset, - position.offset) + " the current position is %d", tp, part.fetch_offset, + position) - part.drain() + part.discard() return 0 - def _reset_offset_if_needed(self, partition, timestamp, offset): - # we might lose the assignment while fetching the offset, or the user might seek to a different offset, - # so verify it is still assigned and still in need of the requested reset - if not self._subscriptions.is_assigned(partition): - log.debug("Skipping reset of partition %s since it is no longer assigned", partition) - elif not self._subscriptions.is_offset_reset_needed(partition): - log.debug("Skipping reset of partition %s since reset is no longer needed", partition) - elif timestamp and not timestamp == self._subscriptions.assignment[partition].reset_strategy: - log.debug("Skipping reset of partition %s since an alternative reset has been requested", partition) - else: - log.info("Resetting offset for partition %s to offset %s.", partition, offset) - self._subscriptions.seek(partition, offset) + def _message_generator(self): + """Iterate over fetched_records""" + while self._next_partition_records or self._completed_fetches: - def _reset_offsets_async(self, timestamps): - timestamps_by_node = self._group_list_offset_requests(timestamps) - - for node_id, timestamps_and_epochs in six.iteritems(timestamps_by_node): - if not self._client.ready(node_id): + if not self._next_partition_records: + completion = self._completed_fetches.popleft() + self._next_partition_records = self._parse_fetched_data(completion) continue - partitions = set(timestamps_and_epochs.keys()) - expire_at = time.time() + self.config['request_timeout_ms'] / 1000 - self._subscriptions.set_reset_pending(partitions, expire_at) - def on_success(timestamps_and_epochs, result): - fetched_offsets, 
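The guard in _append amounts to: deliver records only when the batch's fetch offset matches the position we last consumed, then advance the position one past the last delivered offset; anything else came from an obsolete fetch and is dropped. A toy illustration with integer positions and (offset, value) pairs standing in for ConsumerRecords:

def append_records(drained, tp, fetch_offset, records, positions, max_records):
    if fetch_offset != positions[tp]:
        # Stale data (e.g. the application seek()'d in the meantime) -- discard it.
        return 0
    batch = records[:max_records]
    if batch:
        drained.setdefault(tp, []).extend(batch)
        positions[tp] = batch[-1][0] + 1  # one past the last delivered offset
    return len(batch)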
partitions_to_retry = result - if partitions_to_retry: - self._subscriptions.reset_failed(partitions_to_retry, time.time() + self.config['retry_backoff_ms'] / 1000) - self._client.cluster.request_update() + # Send additional FetchRequests when the internal queue is low + # this should enable moderate pipelining + if len(self._completed_fetches) <= self.config['iterator_refetch_records']: + self.send_fetches() - for partition, offset in six.iteritems(fetched_offsets): - ts, _epoch = timestamps_and_epochs[partition] - self._reset_offset_if_needed(partition, ts, offset.offset) + tp = self._next_partition_records.topic_partition - def on_failure(partitions, error): - self._subscriptions.reset_failed(partitions, time.time() + self.config['retry_backoff_ms'] / 1000) - self._client.cluster.request_update() + # We can ignore any prior signal to drop pending message sets + # because we are starting from a fresh one where fetch_offset == position + # i.e., the user seek()'d to this position + self._subscriptions.assignment[tp].drop_pending_message_set = False - if not getattr(error, 'retriable', False): - if not self._cached_list_offsets_exception: - self._cached_list_offsets_exception = error - else: - log.error("Discarding error in ListOffsetResponse because another error is pending: %s", error) + for msg in self._next_partition_records.take(): - future = self._send_list_offsets_request(node_id, timestamps_and_epochs) - future.add_callback(on_success, timestamps_and_epochs) - future.add_errback(on_failure, partitions) + # Because we are in a generator, it is possible for + # subscription state to change between yield calls + # so we need to re-check on each loop + # this should catch assignment changes, pauses + # and resets via seek_to_beginning / seek_to_end + if not self._subscriptions.is_fetchable(tp): + log.debug("Not returning fetched records for partition %s" + " since it is no longer fetchable", tp) + self._next_partition_records = None + break - def _send_list_offsets_requests(self, timestamps): + # If there is a seek during message iteration, + # we should stop unpacking this message set and + # wait for a new fetch response that aligns with the + # new seek position + elif self._subscriptions.assignment[tp].drop_pending_message_set: + log.debug("Skipping remainder of message set for partition %s", tp) + self._subscriptions.assignment[tp].drop_pending_message_set = False + self._next_partition_records = None + break + + # Compressed messagesets may include earlier messages + elif msg.offset < self._subscriptions.assignment[tp].position: + log.debug("Skipping message offset: %s (expecting %s)", + msg.offset, + self._subscriptions.assignment[tp].position) + continue + + self._subscriptions.assignment[tp].position = msg.offset + 1 + yield msg + + self._next_partition_records = None + + def _unpack_message_set(self, tp, records): + try: + batch = records.next_batch() + while batch is not None: + + # LegacyRecordBatch cannot access either base_offset or last_offset_delta + try: + self._subscriptions.assignment[tp].last_offset_from_message_batch = batch.base_offset + \ + batch.last_offset_delta + except AttributeError: + pass + + for record in batch: + key_size = len(record.key) if record.key is not None else -1 + value_size = len(record.value) if record.value is not None else -1 + key = self._deserialize( + self.config['key_deserializer'], + tp.topic, record.key) + value = self._deserialize( + self.config['value_deserializer'], + tp.topic, record.value) + headers = record.headers + 
header_size = sum( + len(h_key.encode("utf-8")) + (len(h_val) if h_val is not None else 0) for h_key, h_val in + headers) if headers else -1 + yield ConsumerRecord( + tp.topic, tp.partition, record.offset, record.timestamp, + record.timestamp_type, key, value, headers, record.checksum, + key_size, value_size, header_size) + + batch = records.next_batch() + + # If unpacking raises StopIteration, it is erroneously + # caught by the generator. We want all exceptions to be raised + # back to the user. See Issue 545 + except StopIteration as e: + log.exception('StopIteration raised unpacking messageset') + raise RuntimeError('StopIteration raised unpacking messageset') + + def __iter__(self): # pylint: disable=non-iterator-returned + return self + + def __next__(self): + if not self._iterator: + self._iterator = self._message_generator() + try: + return next(self._iterator) + except StopIteration: + self._iterator = None + raise + + def _deserialize(self, f, topic, bytes_): + if not f: + return bytes_ + if isinstance(f, Deserializer): + return f.deserialize(topic, bytes_) + return f(bytes_) + + def _send_offset_requests(self, timestamps): """Fetch offsets for each partition in timestamps dict. This may send request to multiple nodes, based on who is Leader for partition. @@ -463,98 +521,80 @@ class Fetcher(six.Iterator): Returns: Future: resolves to a mapping of retrieved offsets """ - timestamps_by_node = self._group_list_offset_requests(timestamps) - if not timestamps_by_node: - return Future().failure(Errors.StaleMetadata()) + timestamps_by_node = collections.defaultdict(dict) + for partition, timestamp in six.iteritems(timestamps): + node_id = self._client.cluster.leader_for_partition(partition) + if node_id is None: + self._client.add_topic(partition.topic) + log.debug("Partition %s is unknown for fetching offset," + " wait for metadata refresh", partition) + return Future().failure(Errors.StaleMetadata(partition)) + elif node_id == -1: + log.debug("Leader for partition %s unavailable for fetching " + "offset, wait for metadata refresh", partition) + return Future().failure( + Errors.LeaderNotAvailableError(partition)) + else: + timestamps_by_node[node_id][partition] = timestamp - # Aggregate results until we have all responses + # Aggregate results until we have all list_offsets_future = Future() - fetched_offsets = dict() - partitions_to_retry = set() - remaining_responses = [len(timestamps_by_node)] # list for mutable / 2.7 hack + responses = [] + node_count = len(timestamps_by_node) - def on_success(remaining_responses, value): - remaining_responses[0] -= 1 # noqa: F823 - fetched_offsets.update(value[0]) - partitions_to_retry.update(value[1]) - if not remaining_responses[0] and not list_offsets_future.is_done: - list_offsets_future.success((fetched_offsets, partitions_to_retry)) + def on_success(value): + responses.append(value) + if len(responses) == node_count: + offsets = {} + for r in responses: + offsets.update(r) + list_offsets_future.success(offsets) def on_fail(err): if not list_offsets_future.is_done: list_offsets_future.failure(err) for node_id, timestamps in six.iteritems(timestamps_by_node): - _f = self._send_list_offsets_request(node_id, timestamps) - _f.add_callback(on_success, remaining_responses) + _f = self._send_offset_request(node_id, timestamps) + _f.add_callback(on_success) _f.add_errback(on_fail) return list_offsets_future - def _group_list_offset_requests(self, timestamps): - timestamps_by_node = collections.defaultdict(dict) - for partition, timestamp in 
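The _deserialize helper above dispatches on the configured deserializer; the same pattern works outside the Fetcher. A sketch that duck-types the class-based case (rather than isinstance-checking kafka.serializer.Deserializer) so it stays self-contained:

def deserialize(deserializer, topic, raw_bytes):
    if not deserializer:
        return raw_bytes                      # nothing configured: hand back raw bytes
    if hasattr(deserializer, 'deserialize'):
        return deserializer.deserialize(topic, raw_bytes)  # class-based deserializer
    return deserializer(raw_bytes)            # plain callable, e.g. a lambda or json.loads

print(deserialize(lambda b: b.decode('utf-8'), 'my-topic', b'hello'))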
six.iteritems(timestamps): - node_id = self._client.cluster.leader_for_partition(partition) - if node_id is None: - self._client.add_topic(partition.topic) - log.debug("Partition %s is unknown for fetching offset", partition) - self._client.cluster.request_update() - elif node_id == -1: - log.debug("Leader for partition %s unavailable for fetching " - "offset, wait for metadata refresh", partition) - self._client.cluster.request_update() - else: - leader_epoch = -1 - timestamps_by_node[node_id][partition] = (timestamp, leader_epoch) - return dict(timestamps_by_node) - - def _send_list_offsets_request(self, node_id, timestamps_and_epochs): - version = self._client.api_version(ListOffsetsRequest, max_version=4) - if self.config['isolation_level'] == 'read_committed' and version < 2: - raise Errors.UnsupportedVersionError('read_committed isolation level requires ListOffsetsRequest >= v2') + def _send_offset_request(self, node_id, timestamps): by_topic = collections.defaultdict(list) - for tp, (timestamp, leader_epoch) in six.iteritems(timestamps_and_epochs): - if version >= 4: - data = (tp.partition, leader_epoch, timestamp) - elif version >= 1: + for tp, timestamp in six.iteritems(timestamps): + if self.config['api_version'] >= (0, 10, 1): data = (tp.partition, timestamp) else: data = (tp.partition, timestamp, 1) by_topic[tp.topic].append(data) - if version <= 1: - request = ListOffsetsRequest[version]( - -1, - list(six.iteritems(by_topic))) + if self.config['api_version'] >= (0, 10, 1): + request = OffsetRequest[1](-1, list(six.iteritems(by_topic))) else: - request = ListOffsetsRequest[version]( - -1, - self._isolation_level, - list(six.iteritems(by_topic))) + request = OffsetRequest[0](-1, list(six.iteritems(by_topic))) # Client returns a future that only fails on network issues # so create a separate future and attach a callback to update it # based on response error codes future = Future() - log.debug("Sending ListOffsetRequest %s to broker %s", request, node_id) _f = self._client.send(node_id, request) - _f.add_callback(self._handle_list_offsets_response, future) + _f.add_callback(self._handle_offset_response, future) _f.add_errback(lambda e: future.failure(e)) return future - def _handle_list_offsets_response(self, future, response): - """Callback for the response of the ListOffsets api call + def _handle_offset_response(self, future, response): + """Callback for the response of the list offset call above. Arguments: future (Future): the future to update based on response - response (ListOffsetsResponse): response from the server + response (OffsetResponse): response from the server Raises: AssertionError: if response does not match partition """ - fetched_offsets = dict() - partitions_to_retry = set() - unauthorized_topics = set() + timestamp_offset_map = {} for topic, part_data in response.topics: for partition_info in part_data: partition, error_code = partition_info[:2] @@ -563,62 +603,58 @@ class Fetcher(six.Iterator): if error_type is Errors.NoError: if response.API_VERSION == 0: offsets = partition_info[2] - assert len(offsets) <= 1, 'Expected ListOffsetsResponse with one offset' + assert len(offsets) <= 1, 'Expected OffsetResponse with one offset' if not offsets: offset = UNKNOWN_OFFSET else: offset = offsets[0] - timestamp = None - leader_epoch = -1 - elif response.API_VERSION <= 3: - timestamp, offset = partition_info[2:] - leader_epoch = -1 + log.debug("Handling v0 ListOffsetResponse response for %s. 
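Building the offset requests is a two-level grouping: partition -> leader node first, then, per node, topic -> list of partition tuples (whose exact shape depends on the request version, as the code above shows). A sketch with a caller-supplied leader lookup; TopicPartition is assumed to be the usual (topic, partition) namedtuple:

import collections

def group_timestamps_by_node(timestamps, leader_for_partition):
    """timestamps: {TopicPartition: timestamp} -> {node_id: {TopicPartition: timestamp}}"""
    by_node = collections.defaultdict(dict)
    for tp, ts in timestamps.items():
        node_id = leader_for_partition(tp)
        if node_id is None or node_id == -1:
            # Unknown or unavailable leader: caller should refresh metadata and retry.
            continue
        by_node[node_id][tp] = ts
    return dict(by_node)

def v1_payload_for_node(node_timestamps):
    # Per-topic lists of (partition, timestamp) tuples, as used by the v1 request above.
    by_topic = collections.defaultdict(list)
    for tp, ts in node_timestamps.items():
        by_topic[tp.topic].append((tp.partition, ts))
    return list(by_topic.items())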
" + "Fetched offset %s", partition, offset) + if offset != UNKNOWN_OFFSET: + timestamp_offset_map[partition] = (offset, None) else: - timestamp, offset, leader_epoch = partition_info[2:] - log.debug("Handling ListOffsetsResponse response for %s. " - "Fetched offset %s, timestamp %s, leader_epoch %s", - partition, offset, timestamp, leader_epoch) - if offset != UNKNOWN_OFFSET: - fetched_offsets[partition] = OffsetAndTimestamp(offset, timestamp, leader_epoch) + timestamp, offset = partition_info[2:] + log.debug("Handling ListOffsetResponse response for %s. " + "Fetched offset %s, timestamp %s", + partition, offset, timestamp) + if offset != UNKNOWN_OFFSET: + timestamp_offset_map[partition] = (offset, timestamp) elif error_type is Errors.UnsupportedForMessageFormatError: - # The message format on the broker side is before 0.10.0, which means it does not - # support timestamps. We treat this case the same as if we weren't able to find an - # offset corresponding to the requested timestamp and leave it out of the result. + # The message format on the broker side is before 0.10.0, + # we simply put None in the response. log.debug("Cannot search by timestamp for partition %s because the" " message format version is before 0.10.0", partition) - elif error_type in (Errors.NotLeaderForPartitionError, - Errors.ReplicaNotAvailableError, - Errors.KafkaStorageError): + elif error_type is Errors.NotLeaderForPartitionError: log.debug("Attempt to fetch offsets for partition %s failed due" - " to %s, retrying.", error_type.__name__, partition) - partitions_to_retry.add(partition) + " to obsolete leadership information, retrying.", + partition) + future.failure(error_type(partition)) + return elif error_type is Errors.UnknownTopicOrPartitionError: - log.warning("Received unknown topic or partition error in ListOffsets " - "request for partition %s. The topic/partition " + - "may not exist or the user may not have Describe access " - "to it.", partition) - partitions_to_retry.add(partition) - elif error_type is Errors.TopicAuthorizationFailedError: - unauthorized_topics.add(topic) + log.warning("Received unknown topic or partition error in ListOffset " + "request for partition %s. 
The topic/partition " + + "may not exist or the user may not have Describe access " + "to it.", partition) + future.failure(error_type(partition)) + return else: log.warning("Attempt to fetch offsets for partition %s failed due to:" - " %s", partition, error_type.__name__) - partitions_to_retry.add(partition) - if unauthorized_topics: - future.failure(Errors.TopicAuthorizationFailedError(unauthorized_topics)) - else: - future.success((fetched_offsets, partitions_to_retry)) + " %s", partition, error_type) + future.failure(error_type(partition)) + return + if not future.is_done: + future.success(timestamp_offset_map) def _fetchable_partitions(self): fetchable = self._subscriptions.fetchable_partitions() # do not fetch a partition if we have a pending fetch response to process - # use copy.copy to avoid runtimeerror on mutation from different thread - # TODO: switch to deque.copy() with py3 - discard = {fetch.topic_partition for fetch in copy.copy(self._completed_fetches)} current = self._next_partition_records + pending = copy.copy(self._completed_fetches) if current: - discard.add(current.topic_partition) - return [tp for tp in fetchable if tp not in discard] + fetchable.discard(current.topic_partition) + for fetch in pending: + fetchable.discard(fetch.topic_partition) + return fetchable def _create_fetch_requests(self): """Create fetch requests for all assigned partitions, grouped by node. @@ -626,16 +662,25 @@ class Fetcher(six.Iterator): FetchRequests skipped if no leader, or node has requests in flight Returns: - dict: {node_id: (FetchRequest, {TopicPartition: fetch_offset}), ...} (version depends on client api_versions) + dict: {node_id: FetchRequest, ...} (version depends on api_version) """ # create the fetch info as a dict of lists of partition info tuples # which can be passed to FetchRequest() via .items() - version = self._client.api_version(FetchRequest, max_version=10) - fetchable = collections.defaultdict(collections.OrderedDict) + fetchable = collections.defaultdict(lambda: collections.defaultdict(list)) for partition in self._fetchable_partitions(): node_id = self._client.cluster.leader_for_partition(partition) + # advance position for any deleted compacted messages if required + if self._subscriptions.assignment[partition].last_offset_from_message_batch: + next_offset_from_batch_header = self._subscriptions.assignment[partition].last_offset_from_message_batch + 1 + if next_offset_from_batch_header > self._subscriptions.assignment[partition].position: + log.debug( + "Advance position for partition %s from %s to %s (last message batch location plus one)" + " to correct for deleted compacted messages", + partition, self._subscriptions.assignment[partition].position, next_offset_from_batch_header) + self._subscriptions.assignment[partition].position = next_offset_from_batch_header + position = self._subscriptions.assignment[partition].position # fetch if there is a leader and no in-flight requests @@ -644,161 +689,104 @@ class Fetcher(six.Iterator): " Requesting metadata update", partition) self._client.cluster.request_update() - elif not self._client.connected(node_id) and self._client.connection_delay(node_id) > 0: - # If we try to send during the reconnect backoff window, then the request is just - # going to be failed anyway before being sent, so skip the send for now - log.debug("Skipping fetch for partition %s because node %s is awaiting reconnect backoff", - partition, node_id) - - elif self._client.throttle_delay(node_id) > 0: - # If we try to send while throttled, then 
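
# A standalone sketch of the offset-request version selection used by the new
# _send_offset_request above: brokers at or above (0, 10, 1) speak
# OffsetRequest v1, whose per-partition entry is (partition, timestamp) and
# whose response carries a (timestamp, offset) pair, while older brokers only
# speak v0, whose entry additionally carries max_num_offsets (the trailing 1)
# and whose response returns a list of offsets. Helper and variable names here
# are illustrative, not part of kafka-python's API.
import collections

def group_offset_request_payload(api_version, timestamps):
    """timestamps: {(topic, partition): target_timestamp_ms} -> (version, payload)."""
    use_v1 = api_version >= (0, 10, 1)
    by_topic = collections.defaultdict(list)
    for (topic, partition), timestamp in timestamps.items():
        if use_v1:
            by_topic[topic].append((partition, timestamp))
        else:
            by_topic[topic].append((partition, timestamp, 1))  # 1 == max_num_offsets
    return (1 if use_v1 else 0), list(by_topic.items())

assert group_offset_request_payload((0, 9), {('t', 0): -1}) == (0, [('t', [(0, -1, 1)])])
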
the request is just - # going to be failed anyway before being sent, so skip the send for now - log.debug("Skipping fetch for partition %s because node %s is throttled", - partition, node_id) - - elif not self._client.ready(node_id): - # Until we support send request queues, any attempt to send to a not-ready node will be - # immediately failed with NodeNotReadyError. - log.debug("Skipping fetch for partition %s because connection to leader node is not ready yet") - - elif node_id in self._nodes_with_pending_fetch_requests: - log.debug("Skipping fetch for partition %s because there is a pending fetch request to node %s", - partition, node_id) - - else: - # Leader is connected and does not have a pending fetch request - if version < 5: - partition_info = ( - partition.partition, - position.offset, - self.config['max_partition_fetch_bytes'] - ) - elif version <= 8: - partition_info = ( - partition.partition, - position.offset, - -1, # log_start_offset is used internally by brokers / replicas only - self.config['max_partition_fetch_bytes'], - ) - else: - partition_info = ( - partition.partition, - position.leader_epoch, - position.offset, - -1, # log_start_offset is used internally by brokers / replicas only - self.config['max_partition_fetch_bytes'], - ) - - fetchable[node_id][partition] = partition_info + elif self._client.in_flight_request_count(node_id) == 0: + partition_info = ( + partition.partition, + position, + self.config['max_partition_fetch_bytes'] + ) + fetchable[node_id][partition.topic].append(partition_info) log.debug("Adding fetch request for partition %s at offset %d", - partition, position.offset) + partition, position) + else: + log.log(0, "Skipping fetch for partition %s because there is an inflight request to node %s", + partition, node_id) + if self.config['api_version'] >= (0, 11, 0): + version = 4 + elif self.config['api_version'] >= (0, 10, 1): + version = 3 + elif self.config['api_version'] >= (0, 10): + version = 2 + elif self.config['api_version'] == (0, 9): + version = 1 + else: + version = 0 requests = {} - for node_id, next_partitions in six.iteritems(fetchable): - if version >= 7 and self.config['enable_incremental_fetch_sessions']: - if node_id not in self._session_handlers: - self._session_handlers[node_id] = FetchSessionHandler(node_id) - session = self._session_handlers[node_id].build_next(next_partitions) + for node_id, partition_data in six.iteritems(fetchable): + if version < 3: + requests[node_id] = FetchRequest[version]( + -1, # replica_id + self.config['fetch_max_wait_ms'], + self.config['fetch_min_bytes'], + partition_data.items()) else: - # No incremental fetch support - session = FetchRequestData(next_partitions, None, FetchMetadata.LEGACY) - - if version <= 2: - request = FetchRequest[version]( - -1, # replica_id - self.config['fetch_max_wait_ms'], - self.config['fetch_min_bytes'], - session.to_send) - elif version == 3: - request = FetchRequest[version]( - -1, # replica_id - self.config['fetch_max_wait_ms'], - self.config['fetch_min_bytes'], - self.config['fetch_max_bytes'], - session.to_send) - elif version <= 6: - request = FetchRequest[version]( - -1, # replica_id - self.config['fetch_max_wait_ms'], - self.config['fetch_min_bytes'], - self.config['fetch_max_bytes'], - self._isolation_level, - session.to_send) - else: - # Through v8 - request = FetchRequest[version]( - -1, # replica_id - self.config['fetch_max_wait_ms'], - self.config['fetch_min_bytes'], - self.config['fetch_max_bytes'], - self._isolation_level, - session.id, - session.epoch, 
- session.to_send, - session.to_forget) - - fetch_offsets = {} - for tp, partition_data in six.iteritems(next_partitions): - if version <= 8: - offset = partition_data[1] + # As of version == 3 partitions will be returned in order as + # they are requested, so to avoid starvation with + # `fetch_max_bytes` option we need this shuffle + # NOTE: we do have partition_data in random order due to usage + # of unordered structures like dicts, but that does not + # guarantee equal distribution, and starting in Python3.6 + # dicts retain insert order. + partition_data = list(partition_data.items()) + random.shuffle(partition_data) + if version == 3: + requests[node_id] = FetchRequest[version]( + -1, # replica_id + self.config['fetch_max_wait_ms'], + self.config['fetch_min_bytes'], + self.config['fetch_max_bytes'], + partition_data) else: - offset = partition_data[2] - fetch_offsets[tp] = offset - - requests[node_id] = (request, fetch_offsets) - + requests[node_id] = FetchRequest[version]( + -1, # replica_id + self.config['fetch_max_wait_ms'], + self.config['fetch_min_bytes'], + self.config['fetch_max_bytes'], + self._isolation_level, + partition_data) return requests - def _handle_fetch_response(self, node_id, fetch_offsets, send_time, response): + def _handle_fetch_response(self, request, send_time, response): """The callback for fetch completion""" - if response.API_VERSION >= 7 and self.config['enable_incremental_fetch_sessions']: - if node_id not in self._session_handlers: - log.error("Unable to find fetch session handler for node %s. Ignoring fetch response", node_id) - return - if not self._session_handlers[node_id].handle_response(response): - return + fetch_offsets = {} + for topic, partitions in request.topics: + for partition_data in partitions: + partition, offset = partition_data[:2] + fetch_offsets[TopicPartition(topic, partition)] = offset partitions = set([TopicPartition(topic, partition_data[0]) for topic, partitions in response.topics for partition_data in partitions]) - if self._sensors: - metric_aggregator = FetchResponseMetricAggregator(self._sensors, partitions) - else: - metric_aggregator = None + metric_aggregator = FetchResponseMetricAggregator(self._sensors, partitions) + # randomized ordering should improve balance for short-lived consumers + random.shuffle(response.topics) for topic, partitions in response.topics: + random.shuffle(partitions) for partition_data in partitions: tp = TopicPartition(topic, partition_data[0]) - fetch_offset = fetch_offsets[tp] completed_fetch = CompletedFetch( - tp, fetch_offset, + tp, fetch_offsets[tp], response.API_VERSION, partition_data[1:], metric_aggregator ) self._completed_fetches.append(completed_fetch) - if self._sensors: - self._sensors.fetch_latency.record((time.time() - send_time) * 1000) - - def _handle_fetch_error(self, node_id, exception): - level = logging.INFO if isinstance(exception, Errors.Cancelled) else logging.ERROR - log.log(level, 'Fetch to node %s failed: %s', node_id, exception) - if node_id in self._session_handlers: - self._session_handlers[node_id].handle_error(exception) - - def _clear_pending_fetch_request(self, node_id, _): - try: - self._nodes_with_pending_fetch_requests.remove(node_id) - except KeyError: - pass + if response.API_VERSION >= 1: + self._sensors.fetch_throttle_time_sensor.record(response.throttle_time_ms) + self._sensors.fetch_latency.record((time.time() - send_time) * 1000) def _parse_fetched_data(self, completed_fetch): tp = completed_fetch.topic_partition fetch_offset = 
completed_fetch.fetched_offset + num_bytes = 0 + records_count = 0 + parsed_records = None + error_code, highwater = completed_fetch.partition_data[:2] error_type = Errors.for_code(error_code) - parsed_records = None try: if not self._subscriptions.is_fetchable(tp): @@ -808,498 +796,117 @@ class Fetcher(six.Iterator): " since it is no longer fetchable", tp) elif error_type is Errors.NoError: + self._subscriptions.assignment[tp].highwater = highwater + # we are interested in this fetch only if the beginning # offset (of the *request*) matches the current consumed position # Note that the *response* may return a messageset that starts # earlier (e.g., compressed messages) or later (e.g., compacted topic) position = self._subscriptions.assignment[tp].position - if position is None or position.offset != fetch_offset: + if position is None or position != fetch_offset: log.debug("Discarding fetch response for partition %s" " since its offset %d does not match the" " expected offset %d", tp, fetch_offset, - position.offset) + position) return None records = MemoryRecords(completed_fetch.partition_data[-1]) - aborted_transactions = None - if completed_fetch.response_version >= 11: - aborted_transactions = completed_fetch.partition_data[-3] - elif completed_fetch.response_version >= 4: - aborted_transactions = completed_fetch.partition_data[-2] - log.debug("Preparing to read %s bytes of data for partition %s with offset %d", - records.size_in_bytes(), tp, fetch_offset) - parsed_records = self.PartitionRecords(fetch_offset, tp, records, - key_deserializer=self.config['key_deserializer'], - value_deserializer=self.config['value_deserializer'], - check_crcs=self.config['check_crcs'], - isolation_level=self._isolation_level, - aborted_transactions=aborted_transactions, - metric_aggregator=completed_fetch.metric_aggregator, - on_drain=self._on_partition_records_drain) - if not records.has_next() and records.size_in_bytes() > 0: - if completed_fetch.response_version < 3: - # Implement the pre KIP-74 behavior of throwing a RecordTooLargeException. - record_too_large_partitions = {tp: fetch_offset} - raise RecordTooLargeError( - "There are some messages at [Partition=Offset]: %s " - " whose size is larger than the fetch size %s" - " and hence cannot be ever returned. Please condier upgrading your broker to 0.10.1.0 or" - " newer to avoid this issue. Alternatively, increase the fetch size on the client (using" - " max_partition_fetch_bytes)" % ( - record_too_large_partitions, - self.config['max_partition_fetch_bytes']), - record_too_large_partitions) - else: - # This should not happen with brokers that support FetchRequest/Response V3 or higher (i.e. KIP-74) - raise Errors.KafkaError("Failed to make progress reading messages at %s=%s." - " Received a non-empty fetch response from the server, but no" - " complete records were found." 
% (tp, fetch_offset)) - - if highwater >= 0: - self._subscriptions.assignment[tp].highwater = highwater + if records.has_next(): + log.debug("Adding fetched record for partition %s with" + " offset %d to buffered record list", tp, + position) + unpacked = list(self._unpack_message_set(tp, records)) + parsed_records = self.PartitionRecords(fetch_offset, tp, unpacked) + last_offset = unpacked[-1].offset + self._sensors.records_fetch_lag.record(highwater - last_offset) + num_bytes = records.valid_bytes() + records_count = len(unpacked) + elif records.size_in_bytes() > 0: + # we did not read a single message from a non-empty + # buffer because that message's size is larger than + # fetch size, in this case record this exception + record_too_large_partitions = {tp: fetch_offset} + raise RecordTooLargeError( + "There are some messages at [Partition=Offset]: %s " + " whose size is larger than the fetch size %s" + " and hence cannot be ever returned." + " Increase the fetch size, or decrease the maximum message" + " size the broker will allow." % ( + record_too_large_partitions, + self.config['max_partition_fetch_bytes']), + record_too_large_partitions) + self._sensors.record_topic_fetch_metrics(tp.topic, num_bytes, records_count) elif error_type in (Errors.NotLeaderForPartitionError, - Errors.ReplicaNotAvailableError, - Errors.UnknownTopicOrPartitionError, - Errors.KafkaStorageError): - log.debug("Error fetching partition %s: %s", tp, error_type.__name__) + Errors.UnknownTopicOrPartitionError): self._client.cluster.request_update() elif error_type is Errors.OffsetOutOfRangeError: position = self._subscriptions.assignment[tp].position - if position is None or position.offset != fetch_offset: + if position is None or position != fetch_offset: log.debug("Discarding stale fetch response for partition %s" " since the fetched offset %d does not match the" - " current offset %d", tp, fetch_offset, position.offset) + " current offset %d", tp, fetch_offset, position) elif self._subscriptions.has_default_offset_reset_policy(): log.info("Fetch offset %s is out of range for topic-partition %s", fetch_offset, tp) - self._subscriptions.request_offset_reset(tp) + self._subscriptions.need_offset_reset(tp) else: raise Errors.OffsetOutOfRangeError({tp: fetch_offset}) elif error_type is Errors.TopicAuthorizationFailedError: log.warning("Not authorized to read from topic %s.", tp.topic) - raise Errors.TopicAuthorizationFailedError(set([tp.topic])) - elif getattr(error_type, 'retriable', False): - log.debug("Retriable error fetching partition %s: %s", tp, error_type()) - if getattr(error_type, 'invalid_metadata', False): - self._client.cluster.request_update() + raise Errors.TopicAuthorizationFailedError(set(tp.topic)) + elif error_type is Errors.UnknownError: + log.warning("Unknown error fetching data for topic-partition %s", tp) else: raise error_type('Unexpected error while fetching data') finally: - if parsed_records is None and completed_fetch.metric_aggregator: - completed_fetch.metric_aggregator.record(tp, 0, 0) - - if error_type is not Errors.NoError: - # we move the partition to the end if there was an error. This way, it's more likely that partitions for - # the same topic can remain together (allowing for more efficient serialization). - self._subscriptions.move_partition_to_end(tp) + completed_fetch.metric_aggregator.record(tp, num_bytes, records_count) return parsed_records - def _on_partition_records_drain(self, partition_records): - # we move the partition to the end if we received some bytes. 
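
# A standalone sketch of the two guards _parse_fetched_data applies above
# before buffering records: a response is dropped as stale when the offset it
# was requested at no longer matches the consumer's current position (e.g.
# after a seek), and a non-empty buffer from which no complete message could
# be decoded means the first message exceeds max_partition_fetch_bytes and
# would never be returned. Plain values stand in for kafka-python's types.
def check_fetched_data(position, fetch_offset, buffer_bytes, decoded_messages,
                       max_partition_fetch_bytes):
    if position is None or position != fetch_offset:
        return 'discard-stale-response'
    if buffer_bytes > 0 and decoded_messages == 0:
        raise ValueError('message at offset %d is larger than '
                         'max_partition_fetch_bytes=%d'
                         % (fetch_offset, max_partition_fetch_bytes))
    return 'ok'

assert check_fetched_data(42, 17, 100, 3, 1048576) == 'discard-stale-response'
assert check_fetched_data(17, 17, 100, 3, 1048576) == 'ok'
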
This way, it's more likely that partitions - # for the same topic can remain together (allowing for more efficient serialization). - if partition_records.bytes_read > 0: - self._subscriptions.move_partition_to_end(partition_records.topic_partition) - - def close(self): - if self._next_partition_records is not None: - self._next_partition_records.drain() - self._next_in_line_exception_metadata = None - class PartitionRecords(object): - def __init__(self, fetch_offset, tp, records, - key_deserializer=None, value_deserializer=None, - check_crcs=True, isolation_level=READ_UNCOMMITTED, - aborted_transactions=None, # raw data from response / list of (producer_id, first_offset) tuples - metric_aggregator=None, on_drain=lambda x: None): + def __init__(self, fetch_offset, tp, messages): self.fetch_offset = fetch_offset self.topic_partition = tp - self.leader_epoch = -1 - self.next_fetch_offset = fetch_offset - self.bytes_read = 0 - self.records_read = 0 - self.isolation_level = isolation_level - self.aborted_producer_ids = set() - self.aborted_transactions = collections.deque( - sorted([AbortedTransaction(*data) for data in aborted_transactions] if aborted_transactions else [], - key=lambda txn: txn.first_offset) - ) - self.metric_aggregator = metric_aggregator - self.check_crcs = check_crcs - self.record_iterator = itertools.dropwhile( - self._maybe_skip_record, - self._unpack_records(tp, records, key_deserializer, value_deserializer)) - self.on_drain = on_drain - self._next_inline_exception = None - - def _maybe_skip_record(self, record): + self.messages = messages # When fetching an offset that is in the middle of a # compressed batch, we will get all messages in the batch. # But we want to start 'take' at the fetch_offset # (or the next highest offset in case the message was compacted) - if record.offset < self.fetch_offset: - log.debug("Skipping message offset: %s (expecting %s)", - record.offset, self.fetch_offset) - return True + for i, msg in enumerate(messages): + if msg.offset < fetch_offset: + log.debug("Skipping message offset: %s (expecting %s)", + msg.offset, fetch_offset) + else: + self.message_idx = i + break + else: - return False + self.message_idx = 0 + self.messages = None - # For truthiness evaluation - def __bool__(self): - return self.record_iterator is not None + # For truthiness evaluation we need to define __len__ or __nonzero__ + def __len__(self): + if self.messages is None or self.message_idx >= len(self.messages): + return 0 + return len(self.messages) - self.message_idx - # py2 - __nonzero__ = __bool__ - - def drain(self): - if self.record_iterator is not None: - self.record_iterator = None - self._next_inline_exception = None - if self.metric_aggregator: - self.metric_aggregator.record(self.topic_partition, self.bytes_read, self.records_read) - self.on_drain(self) - - def _maybe_raise_next_inline_exception(self): - if self._next_inline_exception: - exc, self._next_inline_exception = self._next_inline_exception, None - raise exc + def discard(self): + self.messages = None def take(self, n=None): - self._maybe_raise_next_inline_exception() - records = [] - try: - # Note that records.extend(iter) will extend partially when exception raised mid-stream - records.extend(itertools.islice(self.record_iterator, 0, n)) - except Exception as e: - if not records: - raise e - # To be thrown in the next call of this method - self._next_inline_exception = e - return records - - def _unpack_records(self, tp, records, key_deserializer, value_deserializer): - try: - batch = 
records.next_batch() - last_batch = None - while batch is not None: - last_batch = batch - - if self.check_crcs and not batch.validate_crc(): - raise Errors.CorruptRecordError( - "Record batch for partition %s at offset %s failed crc check" % ( - self.topic_partition, batch.base_offset)) - - - # Try DefaultsRecordBatch / message log format v2 - # base_offset, last_offset_delta, aborted transactions, and control batches - if batch.magic == 2: - self.leader_epoch = batch.leader_epoch - if self.isolation_level == READ_COMMITTED and batch.has_producer_id(): - # remove from the aborted transaction queue all aborted transactions which have begun - # before the current batch's last offset and add the associated producerIds to the - # aborted producer set - self._consume_aborted_transactions_up_to(batch.last_offset) - - producer_id = batch.producer_id - if self._contains_abort_marker(batch): - try: - self.aborted_producer_ids.remove(producer_id) - except KeyError: - pass - elif self._is_batch_aborted(batch): - log.debug("Skipping aborted record batch from partition %s with producer_id %s and" - " offsets %s to %s", - self.topic_partition, producer_id, batch.base_offset, batch.last_offset) - self.next_fetch_offset = batch.next_offset - batch = records.next_batch() - continue - - # Control batches have a single record indicating whether a transaction - # was aborted or committed. These are not returned to the consumer. - if batch.is_control_batch: - self.next_fetch_offset = batch.next_offset - batch = records.next_batch() - continue - - for record in batch: - if self.check_crcs and not record.validate_crc(): - raise Errors.CorruptRecordError( - "Record for partition %s at offset %s failed crc check" % ( - self.topic_partition, record.offset)) - key_size = len(record.key) if record.key is not None else -1 - value_size = len(record.value) if record.value is not None else -1 - key = self._deserialize(key_deserializer, tp.topic, record.key) - value = self._deserialize(value_deserializer, tp.topic, record.value) - headers = record.headers - header_size = sum( - len(h_key.encode("utf-8")) + (len(h_val) if h_val is not None else 0) for h_key, h_val in - headers) if headers else -1 - self.records_read += 1 - self.bytes_read += record.size_in_bytes - self.next_fetch_offset = record.offset + 1 - yield ConsumerRecord( - tp.topic, tp.partition, self.leader_epoch, record.offset, record.timestamp, - record.timestamp_type, key, value, headers, record.checksum, - key_size, value_size, header_size) - - batch = records.next_batch() - else: - # Message format v2 preserves the last offset in a batch even if the last record is removed - # through compaction. By using the next offset computed from the last offset in the batch, - # we ensure that the offset of the next fetch will point to the next batch, which avoids - # unnecessary re-fetching of the same batch (in the worst case, the consumer could get stuck - # fetching the same batch repeatedly). - if last_batch and last_batch.magic == 2: - self.next_fetch_offset = last_batch.next_offset - self.drain() - - # If unpacking raises StopIteration, it is erroneously - # caught by the generator. We want all exceptions to be raised - # back to the user. 
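
# A condensed standalone model of the new PartitionRecords behaviour above:
# a compressed or compacted batch may start before the requested offset, so
# consumption begins at the first message whose offset >= fetch_offset, and
# take(n) advances fetch_offset to one past the last record handed out, in
# parallel with the subscription position. Record stands in for ConsumerRecord.
import collections

Record = collections.namedtuple('Record', ['offset', 'value'])

class PartitionRecordsSketch(object):
    def __init__(self, fetch_offset, messages):
        self.fetch_offset = fetch_offset
        self.messages = messages
        # skip any leading messages the broker returned from before fetch_offset
        self.idx = next((i for i, m in enumerate(messages)
                         if m.offset >= fetch_offset), len(messages))

    def take(self, n=None):
        end = len(self.messages) if n is None else min(self.idx + n, len(self.messages))
        res = self.messages[self.idx:end]
        self.idx = end
        if res:
            self.fetch_offset = max(self.fetch_offset, res[-1].offset + 1)
        return res

batch = [Record(offset, b'') for offset in (5, 6, 7, 8)]
pr = PartitionRecordsSketch(fetch_offset=7, messages=batch)
assert [r.offset for r in pr.take()] == [7, 8] and pr.fetch_offset == 9
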
See Issue 545 - except StopIteration: - log.exception('StopIteration raised unpacking messageset') - raise RuntimeError('StopIteration raised unpacking messageset') - - def _deserialize(self, f, topic, bytes_): - if not f: - return bytes_ - if isinstance(f, Deserializer): - return f.deserialize(topic, bytes_) - return f(bytes_) - - def _consume_aborted_transactions_up_to(self, offset): - if not self.aborted_transactions: - return - - while self.aborted_transactions and self.aborted_transactions[0].first_offset <= offset: - self.aborted_producer_ids.add(self.aborted_transactions.popleft().producer_id) - - def _is_batch_aborted(self, batch): - return batch.is_transactional and batch.producer_id in self.aborted_producer_ids - - def _contains_abort_marker(self, batch): - if not batch.is_control_batch: - return False - record = next(batch) - if not record: - return False - return record.abort - - -class FetchSessionHandler(object): - """ - FetchSessionHandler maintains the fetch session state for connecting to a broker. - - Using the protocol outlined by KIP-227, clients can create incremental fetch sessions. - These sessions allow the client to fetch information about a set of partition over - and over, without explicitly enumerating all the partitions in the request and the - response. - - FetchSessionHandler tracks the partitions which are in the session. It also - determines which partitions need to be included in each fetch request, and what - the attached fetch session metadata should be for each request. - """ - - def __init__(self, node_id): - self.node_id = node_id - self.next_metadata = FetchMetadata.INITIAL - self.session_partitions = {} - - def build_next(self, next_partitions): - """ - Arguments: - next_partitions (dict): TopicPartition -> TopicPartitionState - - Returns: - FetchRequestData - """ - if self.next_metadata.is_full: - log.debug("Built full fetch %s for node %s with %s partition(s).", - self.next_metadata, self.node_id, len(next_partitions)) - self.session_partitions = next_partitions - return FetchRequestData(next_partitions, None, self.next_metadata) - - prev_tps = set(self.session_partitions.keys()) - next_tps = set(next_partitions.keys()) - log.debug("Building incremental partitions from next: %s, previous: %s", next_tps, prev_tps) - added = next_tps - prev_tps - for tp in added: - self.session_partitions[tp] = next_partitions[tp] - removed = prev_tps - next_tps - for tp in removed: - self.session_partitions.pop(tp) - altered = set() - for tp in next_tps & prev_tps: - if next_partitions[tp] != self.session_partitions[tp]: - self.session_partitions[tp] = next_partitions[tp] - altered.add(tp) - - log.debug("Built incremental fetch %s for node %s. 
Added %s, altered %s, removed %s out of %s", - self.next_metadata, self.node_id, added, altered, removed, self.session_partitions.keys()) - to_send = collections.OrderedDict({tp: next_partitions[tp] for tp in next_partitions if tp in (added | altered)}) - return FetchRequestData(to_send, removed, self.next_metadata) - - def handle_response(self, response): - if response.error_code != Errors.NoError.errno: - error_type = Errors.for_code(response.error_code) - log.info("Node %s was unable to process the fetch request with %s: %s.", - self.node_id, self.next_metadata, error_type()) - if error_type is Errors.FetchSessionIdNotFoundError: - self.next_metadata = FetchMetadata.INITIAL - else: - self.next_metadata = self.next_metadata.next_close_existing() - return False - - response_tps = self._response_partitions(response) - session_tps = set(self.session_partitions.keys()) - if self.next_metadata.is_full: - if response_tps != session_tps: - log.info("Node %s sent an invalid full fetch response with extra %s / omitted %s", - self.node_id, response_tps - session_tps, session_tps - response_tps) - self.next_metadata = FetchMetadata.INITIAL - return False - elif response.session_id == FetchMetadata.INVALID_SESSION_ID: - log.debug("Node %s sent a full fetch response with %s partitions", - self.node_id, len(response_tps)) - self.next_metadata = FetchMetadata.INITIAL - return True - elif response.session_id == FetchMetadata.THROTTLED_SESSION_ID: - log.debug("Node %s sent a empty full fetch response due to a quota violation (%s partitions)", - self.node_id, len(response_tps)) - # Keep current metadata - return True - else: - # The server created a new incremental fetch session. - log.debug("Node %s sent a full fetch response that created a new incremental fetch session %s" - " with %s response partitions", - self.node_id, response.session_id, - len(response_tps)) - self.next_metadata = FetchMetadata.new_incremental(response.session_id) - return True - else: - if response_tps - session_tps: - log.info("Node %s sent an invalid incremental fetch response with extra partitions %s", - self.node_id, response_tps - session_tps) - self.next_metadata = self.next_metadata.next_close_existing() - return False - elif response.session_id == FetchMetadata.INVALID_SESSION_ID: - # The incremental fetch session was closed by the server. - log.debug("Node %s sent an incremental fetch response closing session %s" - " with %s response partitions (%s implied)", - self.node_id, self.next_metadata.session_id, - len(response_tps), len(self.session_partitions) - len(response_tps)) - self.next_metadata = FetchMetadata.INITIAL - return True - elif response.session_id == FetchMetadata.THROTTLED_SESSION_ID: - log.debug("Node %s sent a empty incremental fetch response due to a quota violation (%s partitions)", - self.node_id, len(response_tps)) - # Keep current metadata - return True - else: - # The incremental fetch session was continued by the server. 
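
# The removed FetchSessionHandler.build_next above reduces to a set difference
# between the partitions already registered in the session and the partitions
# wanted for the next request: only added or altered entries are sent, and
# removed ones are listed for the broker to forget. A minimal standalone
# sketch of that computation (dicts map a topic-partition key to its
# per-partition fetch data):
def incremental_fetch_delta(session_partitions, next_partitions):
    prev_tps, next_tps = set(session_partitions), set(next_partitions)
    added = next_tps - prev_tps
    removed = prev_tps - next_tps
    altered = {tp for tp in next_tps & prev_tps
               if next_partitions[tp] != session_partitions[tp]}
    to_send = {tp: next_partitions[tp] for tp in added | altered}
    return to_send, removed

assert incremental_fetch_delta({('t', 0): 'a'}, {('t', 0): 'b', ('t', 1): 'c'}) == \
    ({('t', 0): 'b', ('t', 1): 'c'}, set())
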
- log.debug("Node %s sent an incremental fetch response for session %s" - " with %s response partitions (%s implied)", - self.node_id, response.session_id, - len(response_tps), len(self.session_partitions) - len(response_tps)) - self.next_metadata = self.next_metadata.next_incremental() - return True - - def handle_error(self, _exception): - self.next_metadata = self.next_metadata.next_close_existing() - - def _response_partitions(self, response): - return {TopicPartition(topic, partition_data[0]) - for topic, partitions in response.topics - for partition_data in partitions} - - -class FetchMetadata(object): - __slots__ = ('session_id', 'epoch') - - MAX_EPOCH = 2147483647 - INVALID_SESSION_ID = 0 # used by clients with no session. - THROTTLED_SESSION_ID = -1 # returned with empty response on quota violation - INITIAL_EPOCH = 0 # client wants to create or recreate a session. - FINAL_EPOCH = -1 # client wants to close any existing session, and not create a new one. - - def __init__(self, session_id, epoch): - self.session_id = session_id - self.epoch = epoch - - @property - def is_full(self): - return self.epoch == self.INITIAL_EPOCH or self.epoch == self.FINAL_EPOCH - - @classmethod - def next_epoch(cls, prev_epoch): - if prev_epoch < 0: - return cls.FINAL_EPOCH - elif prev_epoch == cls.MAX_EPOCH: - return 1 - else: - return prev_epoch + 1 - - def next_close_existing(self): - return self.__class__(self.session_id, self.INITIAL_EPOCH) - - @classmethod - def new_incremental(cls, session_id): - return cls(session_id, cls.next_epoch(cls.INITIAL_EPOCH)) - - def next_incremental(self): - return self.__class__(self.session_id, self.next_epoch(self.epoch)) - -FetchMetadata.INITIAL = FetchMetadata(FetchMetadata.INVALID_SESSION_ID, FetchMetadata.INITIAL_EPOCH) -FetchMetadata.LEGACY = FetchMetadata(FetchMetadata.INVALID_SESSION_ID, FetchMetadata.FINAL_EPOCH) - - -class FetchRequestData(object): - __slots__ = ('_to_send', '_to_forget', '_metadata') - - def __init__(self, to_send, to_forget, metadata): - self._to_send = to_send or dict() # {TopicPartition: (partition, ...)} - self._to_forget = to_forget or set() # {TopicPartition} - self._metadata = metadata - - @property - def metadata(self): - return self._metadata - - @property - def id(self): - return self._metadata.session_id - - @property - def epoch(self): - return self._metadata.epoch - - @property - def to_send(self): - # Return as list of [(topic, [(partition, ...), ...]), ...] - # so it can be passed directly to encoder - partition_data = collections.defaultdict(list) - for tp, partition_info in six.iteritems(self._to_send): - partition_data[tp.topic].append(partition_info) - return list(partition_data.items()) - - @property - def to_forget(self): - # Return as list of [(topic, (partiiton, ...)), ...] 
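
# The removed FetchMetadata epoch rules above, as a tiny standalone helper: a
# session is "full" (re-sends every partition) at epoch 0 or -1, and the epoch
# otherwise increments once per request, wrapping from 2**31 - 1 back to 1 so
# it never lands on the two reserved values.
MAX_EPOCH, INITIAL_EPOCH, FINAL_EPOCH = 2147483647, 0, -1

def next_epoch(prev_epoch):
    if prev_epoch < 0:
        return FINAL_EPOCH
    if prev_epoch == MAX_EPOCH:
        return 1
    return prev_epoch + 1

def is_full(epoch):
    return epoch in (INITIAL_EPOCH, FINAL_EPOCH)

assert next_epoch(INITIAL_EPOCH) == 1 and next_epoch(MAX_EPOCH) == 1 and is_full(FINAL_EPOCH)
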
- # so it an be passed directly to encoder - partition_data = collections.defaultdict(list) - for tp in self._to_forget: - partition_data[tp.topic].append(tp.partition) - return list(partition_data.items()) - - -class FetchMetrics(object): - __slots__ = ('total_bytes', 'total_records') - - def __init__(self): - self.total_bytes = 0 - self.total_records = 0 + if not len(self): + return [] + if n is None or n > len(self): + n = len(self) + next_idx = self.message_idx + n + res = self.messages[self.message_idx:next_idx] + self.message_idx = next_idx + # fetch_offset should be incremented by 1 to parallel the + # subscription position (also incremented by 1) + self.fetch_offset = max(self.fetch_offset, res[-1].offset + 1) + return res class FetchResponseMetricAggregator(object): @@ -1312,8 +919,8 @@ class FetchResponseMetricAggregator(object): def __init__(self, sensors, partitions): self.sensors = sensors self.unrecorded_partitions = partitions - self.fetch_metrics = FetchMetrics() - self.topic_fetch_metrics = collections.defaultdict(FetchMetrics) + self.total_bytes = 0 + self.total_records = 0 def record(self, partition, num_bytes, num_records): """ @@ -1322,17 +929,13 @@ class FetchResponseMetricAggregator(object): have reported, we write the metric. """ self.unrecorded_partitions.remove(partition) - self.fetch_metrics.total_bytes += num_bytes - self.fetch_metrics.total_records += num_records - self.topic_fetch_metrics[partition.topic].total_bytes += num_bytes - self.topic_fetch_metrics[partition.topic].total_records += num_records + self.total_bytes += num_bytes + self.total_records += num_records # once all expected partitions from the fetch have reported in, record the metrics if not self.unrecorded_partitions: - self.sensors.bytes_fetched.record(self.fetch_metrics.total_bytes) - self.sensors.records_fetched.record(self.fetch_metrics.total_records) - for topic, metrics in six.iteritems(self.topic_fetch_metrics): - self.sensors.record_topic_fetch_metrics(topic, metrics.total_bytes, metrics.total_records) + self.sensors.bytes_fetched.record(self.total_bytes) + self.sensors.records_fetched.record(self.total_records) class FetchManagerMetrics(object): @@ -1366,6 +969,12 @@ class FetchManagerMetrics(object): self.records_fetch_lag.add(metrics.metric_name('records-lag-max', self.group_name, 'The maximum lag in terms of number of records for any partition in self window'), Max()) + self.fetch_throttle_time_sensor = metrics.sensor('fetch-throttle-time') + self.fetch_throttle_time_sensor.add(metrics.metric_name('fetch-throttle-time-avg', self.group_name, + 'The average throttle time in ms'), Avg()) + self.fetch_throttle_time_sensor.add(metrics.metric_name('fetch-throttle-time-max', self.group_name, + 'The maximum throttle time in ms'), Max()) + def record_topic_fetch_metrics(self, topic, num_bytes, num_records): # record bytes fetched name = '.'.join(['topic', topic, 'bytes-fetched']) diff --git a/venv/lib/python3.12/site-packages/kafka/consumer/group.py b/venv/lib/python3.12/site-packages/kafka/consumer/group.py index bc974ee..26408c3 100644 --- a/venv/lib/python3.12/site-packages/kafka/consumer/group.py +++ b/venv/lib/python3.12/site-packages/kafka/consumer/group.py @@ -5,7 +5,7 @@ import logging import socket import time -from kafka.errors import KafkaConfigurationError, KafkaTimeoutError, UnsupportedVersionError +from kafka.errors import KafkaConfigurationError, UnsupportedVersionError from kafka.vendor import six @@ -16,9 +16,8 @@ from kafka.coordinator.consumer import ConsumerCoordinator 
from kafka.coordinator.assignors.range import RangePartitionAssignor from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor from kafka.metrics import MetricConfig, Metrics -from kafka.protocol.list_offsets import OffsetResetStrategy -from kafka.structs import OffsetAndMetadata, TopicPartition -from kafka.util import Timer +from kafka.protocol.offset import OffsetResetStrategy +from kafka.structs import TopicPartition from kafka.version import __version__ log = logging.getLogger(__name__) @@ -61,8 +60,6 @@ class KafkaConsumer(six.Iterator): raw message key and returns a deserialized key. value_deserializer (callable): Any callable that takes a raw message value and returns a deserialized value. - enable_incremental_fetch_sessions: (bool): Use incremental fetch sessions - when available / supported by kafka broker. See KIP-227. Default: True. fetch_min_bytes (int): Minimum amount of data the server should return for a fetch request, otherwise wait up to fetch_max_wait_ms for more data to accumulate. Default: 1. @@ -101,7 +98,7 @@ class KafkaConsumer(six.Iterator): reconnection attempts will continue periodically with this fixed rate. To avoid connection storms, a randomization factor of 0.2 will be applied to the backoff resulting in a random range between - 20% below and 20% above the computed value. Default: 30000. + 20% below and 20% above the computed value. Default: 1000. max_in_flight_requests_per_connection (int): Requests are pipelined to kafka brokers up to this number of maximum requests per broker connection. Default: 5. @@ -121,12 +118,6 @@ class KafkaConsumer(six.Iterator): consumed. This ensures no on-the-wire or on-disk corruption to the messages occurred. This check adds some overhead, so it may be disabled in cases seeking extreme performance. Default: True - isolation_level (str): Configure KIP-98 transactional consumer by - setting to 'read_committed'. This will cause the consumer to - skip records from aborted transactions. Default: 'read_uncommitted' - allow_auto_create_topics (bool): Enable/disable auto topic creation - on metadata request. Only available with api_version >= (0, 11). - Default: True metadata_max_age_ms (int): The period of time in milliseconds after which we force a refresh of metadata, even if we haven't seen any partition leadership changes to proactively discover any new @@ -204,17 +195,10 @@ class KafkaConsumer(six.Iterator): or other configuration forbids use of all the specified ciphers), an ssl.SSLError will be raised. See ssl.SSLContext.set_ciphers api_version (tuple): Specify which Kafka API version to use. If set to - None, the client will attempt to determine the broker version via - ApiVersionsRequest API or, for brokers earlier than 0.10, probing - various known APIs. Dynamic version checking is performed eagerly - during __init__ and can raise NoBrokersAvailableError if no connection - was made before timeout (see api_version_auto_timeout_ms below). - Different versions enable different functionality. + None, the client will attempt to infer the broker version by probing + various APIs. Different versions enable different functionality. 
Examples: - (3, 9) most recent broker release, enable all supported features - (0, 11) enables message format v2 (internal) - (0, 10, 0) enables sasl authentication and message format v1 (0, 9) enables full group coordination features with automatic partition assignment and rebalancing, (0, 8, 2) enables kafka-storage offset commits with manual @@ -228,7 +212,6 @@ class KafkaConsumer(six.Iterator): api_version_auto_timeout_ms (int): number of milliseconds to throw a timeout exception from the constructor when checking the broker api version. Only applies if api_version set to None. - Default: 2000 connections_max_idle_ms: Close idle connections after the number of milliseconds specified by this config. The broker closes idle connections after connections.max.idle.ms, so this avoids hitting @@ -237,7 +220,6 @@ class KafkaConsumer(six.Iterator): metric_reporters (list): A list of classes to use as metrics reporters. Implementing the AbstractMetricsReporter interface allows plugging in classes that will be notified of new metric creation. Default: [] - metrics_enabled (bool): Whether to track metrics on this instance. Default True. metrics_num_samples (int): The number of samples maintained to compute metrics. Default: 2 metrics_sample_window_ms (int): The maximum age in milliseconds of @@ -256,17 +238,12 @@ class KafkaConsumer(six.Iterator): Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. sasl_plain_password (str): password for sasl PLAIN and SCRAM authentication. Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. - sasl_kerberos_name (str or gssapi.Name): Constructed gssapi.Name for use with - sasl mechanism handshake. If provided, sasl_kerberos_service_name and - sasl_kerberos_domain name are ignored. Default: None. sasl_kerberos_service_name (str): Service name to include in GSSAPI sasl mechanism handshake. Default: 'kafka' sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI sasl mechanism handshake. Default: one of bootstrap servers - sasl_oauth_token_provider (kafka.sasl.oauth.AbstractTokenProvider): OAuthBearer - token provider instance. Default: None - socks5_proxy (str): Socks5 proxy URL. Default: None - kafka_client (callable): Custom class / callable for creating KafkaClient instances + sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider + instance. (See kafka.oauth.abstract). 
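
# The reconnect_backoff_ms / reconnect_backoff_max_ms description in this
# docstring implies the usual exponential-backoff-with-jitter scheme; a hedged
# standalone sketch, assuming the backoff doubles per consecutive failure (the
# doubling happens inside the client and is not shown in this hunk -- only the
# 0.2 jitter factor and the cap are stated here). Defaults mirror the 50ms /
# 1000ms values in DEFAULT_CONFIG below.
import random

def reconnect_backoff_ms(failures, base_ms=50, max_ms=1000, jitter=0.2):
    computed = min(base_ms * (2 ** failures), max_ms)
    return computed * random.uniform(1 - jitter, 1 + jitter)

# e.g. third consecutive failure: 50 * 2**3 = 400ms, jittered to 320..480ms
assert 320 <= reconnect_backoff_ms(3) <= 480
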
Default: None Note: Configuration parameters are described in more detail at @@ -278,7 +255,6 @@ class KafkaConsumer(six.Iterator): 'group_id': None, 'key_deserializer': None, 'value_deserializer': None, - 'enable_incremental_fetch_sessions': True, 'fetch_max_wait_ms': 500, 'fetch_min_bytes': 1, 'fetch_max_bytes': 52428800, @@ -286,15 +262,13 @@ class KafkaConsumer(six.Iterator): 'request_timeout_ms': 305000, # chosen to be higher than the default of max_poll_interval_ms 'retry_backoff_ms': 100, 'reconnect_backoff_ms': 50, - 'reconnect_backoff_max_ms': 30000, + 'reconnect_backoff_max_ms': 1000, 'max_in_flight_requests_per_connection': 5, 'auto_offset_reset': 'latest', 'enable_auto_commit': True, 'auto_commit_interval_ms': 5000, 'default_offset_commit_callback': lambda offsets, response: True, 'check_crcs': True, - 'isolation_level': 'read_uncommitted', - 'allow_auto_create_topics': True, 'metadata_max_age_ms': 5 * 60 * 1000, 'partition_assignment_strategy': (RangePartitionAssignor, RoundRobinPartitionAssignor), 'max_poll_records': 500, @@ -320,7 +294,6 @@ class KafkaConsumer(six.Iterator): 'api_version_auto_timeout_ms': 2000, 'connections_max_idle_ms': 9 * 60 * 1000, 'metric_reporters': [], - 'metrics_enabled': True, 'metrics_num_samples': 2, 'metrics_sample_window_ms': 30000, 'metric_group_prefix': 'consumer', @@ -329,12 +302,10 @@ class KafkaConsumer(six.Iterator): 'sasl_mechanism': None, 'sasl_plain_username': None, 'sasl_plain_password': None, - 'sasl_kerberos_name': None, 'sasl_kerberos_service_name': 'kafka', 'sasl_kerberos_domain_name': None, 'sasl_oauth_token_provider': None, - 'socks5_proxy': None, - 'kafka_client': KafkaClient, + 'legacy_iterator': False, # enable to revert to < 1.4.7 iterator } DEFAULT_SESSION_TIMEOUT_MS_0_9 = 30000 @@ -364,15 +335,13 @@ class KafkaConsumer(six.Iterator): "fetch_max_wait_ms ({})." .format(connections_max_idle_ms, request_timeout_ms, fetch_max_wait_ms)) - if self.config['metrics_enabled']: - metrics_tags = {'client-id': self.config['client_id']} - metric_config = MetricConfig(samples=self.config['metrics_num_samples'], - time_window_ms=self.config['metrics_sample_window_ms'], - tags=metrics_tags) - reporters = [reporter() for reporter in self.config['metric_reporters']] - self._metrics = Metrics(metric_config, reporters) - else: - self._metrics = None + metrics_tags = {'client-id': self.config['client_id']} + metric_config = MetricConfig(samples=self.config['metrics_num_samples'], + time_window_ms=self.config['metrics_sample_window_ms'], + tags=metrics_tags) + reporters = [reporter() for reporter in self.config['metric_reporters']] + self._metrics = Metrics(metric_config, reporters) + # TODO _metrics likely needs to be passed to KafkaClient, etc. # api_version was previously a str. 
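
# The "api_version was previously a str" note above refers to accepting e.g.
# '0.10.1' where a tuple such as (0, 10, 1) is now expected. A hedged
# standalone sketch of that normalization; the handling of the legacy 'auto'
# string is an assumption, since the exact parsing lines fall outside this hunk.
def normalize_api_version(api_version):
    if isinstance(api_version, str):
        if api_version == 'auto':
            return None  # None lets the client probe the broker version
        return tuple(int(part) for part in api_version.split('.'))
    return api_version

assert normalize_api_version('0.10.1') == (0, 10, 1)
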
Accept old format for now if isinstance(self.config['api_version'], str): @@ -384,10 +353,11 @@ class KafkaConsumer(six.Iterator): log.warning('use api_version=%s [tuple] -- "%s" as str is deprecated', str(self.config['api_version']), str_version) - self._client = self.config['kafka_client'](metrics=self._metrics, **self.config) + self._client = KafkaClient(metrics=self._metrics, **self.config) - # Get auto-discovered / normalized version from client - self.config['api_version'] = self._client.config['api_version'] + # Get auto-discovered version from client if necessary + if self.config['api_version'] is None: + self.config['api_version'] = self._client.config['api_version'] # Coordinator configurations are different for older brokers # max_poll_interval_ms is not supported directly -- it must the be @@ -410,9 +380,9 @@ class KafkaConsumer(six.Iterator): self._subscription = SubscriptionState(self.config['auto_offset_reset']) self._fetcher = Fetcher( - self._client, self._subscription, metrics=self._metrics, **self.config) + self._client, self._subscription, self._metrics, **self.config) self._coordinator = ConsumerCoordinator( - self._client, self._subscription, metrics=self._metrics, + self._client, self._subscription, self._metrics, assignors=self.config['partition_assignment_strategy'], **self.config) self._closed = False @@ -452,15 +422,8 @@ class KafkaConsumer(six.Iterator): no rebalance operation triggered when group membership or cluster and topic metadata change. """ - if not partitions: - self.unsubscribe() - else: - # make sure the offsets of topic partitions the consumer is unsubscribing from - # are committed since there will be no following rebalance - self._coordinator.maybe_auto_commit_offsets_now() - self._subscription.assign_from_user(partitions) - self._client.set_topics([tp.topic for tp in partitions]) - log.debug("Subscribed to partition(s): %s", partitions) + self._subscription.assign_from_user(partitions) + self._client.set_topics([tp.topic for tp in partitions]) def assignment(self): """Get the TopicPartitions currently assigned to this consumer. @@ -478,23 +441,20 @@ class KafkaConsumer(six.Iterator): """ return self._subscription.assigned_partitions() - def close(self, autocommit=True, timeout_ms=None): + def close(self, autocommit=True): """Close the consumer, waiting indefinitely for any needed cleanup. Keyword Arguments: autocommit (bool): If auto-commit is configured for this consumer, this optional flag causes the consumer to attempt to commit any pending consumed offsets prior to close. Default: True - timeout_ms (num, optional): Milliseconds to wait for auto-commit. - Default: None """ if self._closed: return log.debug("Closing the KafkaConsumer.") self._closed = True - self._coordinator.close(autocommit=autocommit, timeout_ms=timeout_ms) - if self._metrics: - self._metrics.close() + self._coordinator.close(autocommit=autocommit) + self._metrics.close() self._client.close() try: self.config['key_deserializer'].close() @@ -540,7 +500,7 @@ class KafkaConsumer(six.Iterator): offsets, callback=callback) return future - def commit(self, offsets=None, timeout_ms=None): + def commit(self, offsets=None): """Commit offsets to kafka, blocking until success or error. This commits offsets only to Kafka. 
The offsets committed using this API @@ -564,16 +524,17 @@ class KafkaConsumer(six.Iterator): assert self.config['group_id'] is not None, 'Requires group_id' if offsets is None: offsets = self._subscription.all_consumed_offsets() - self._coordinator.commit_offsets_sync(offsets, timeout_ms=timeout_ms) + self._coordinator.commit_offsets_sync(offsets) - def committed(self, partition, metadata=False, timeout_ms=None): + def committed(self, partition, metadata=False): """Get the last committed offset for the given partition. This offset will be used as the position for the consumer in the event of a failure. - This call will block to do a remote call to get the latest committed - offsets from the server. + This call may block to do a remote call if the partition in question + isn't assigned to this consumer or if the consumer hasn't yet + initialized its cache of committed offsets. Arguments: partition (TopicPartition): The partition to check. @@ -582,19 +543,28 @@ class KafkaConsumer(six.Iterator): Returns: The last committed offset (int or OffsetAndMetadata), or None if there was no prior commit. - - Raises: - KafkaTimeoutError if timeout_ms provided - BrokerResponseErrors if OffsetFetchRequest raises an error. """ assert self.config['api_version'] >= (0, 8, 1), 'Requires >= Kafka 0.8.1' assert self.config['group_id'] is not None, 'Requires group_id' if not isinstance(partition, TopicPartition): raise TypeError('partition must be a TopicPartition namedtuple') - committed = self._coordinator.fetch_committed_offsets([partition], timeout_ms=timeout_ms) - if partition not in committed: - return None - return committed[partition] if metadata else committed[partition].offset + if self._subscription.is_assigned(partition): + committed = self._subscription.assignment[partition].committed + if committed is None: + self._coordinator.refresh_committed_offsets_if_needed() + committed = self._subscription.assignment[partition].committed + else: + commit_map = self._coordinator.fetch_committed_offsets([partition]) + if partition in commit_map: + committed = commit_map[partition] + else: + committed = None + + if committed is not None: + if metadata: + return committed + else: + return committed.offset def _fetch_all_topic_metadata(self): """A blocking call that fetches topic metadata for all topics in the @@ -639,7 +609,7 @@ class KafkaConsumer(six.Iterator): if partitions is None: self._fetch_all_topic_metadata() partitions = cluster.partitions_for_topic(topic) - return partitions or set() + return partitions def poll(self, timeout_ms=0, max_records=None, update_offsets=True): """Fetch data from assigned topics / partitions. @@ -679,88 +649,82 @@ class KafkaConsumer(six.Iterator): assert not self._closed, 'KafkaConsumer is closed' # Poll for new data until the timeout expires - timer = Timer(timeout_ms) - while not self._closed: - records = self._poll_once(timer, max_records, update_offsets=update_offsets) + start = time.time() + remaining = timeout_ms + while True: + records = self._poll_once(remaining, max_records, update_offsets=update_offsets) if records: return records - elif timer.expired: - break - return {} - def _poll_once(self, timer, max_records, update_offsets=True): + elapsed_ms = (time.time() - start) * 1000 + remaining = timeout_ms - elapsed_ms + + if remaining <= 0: + return {} + + def _poll_once(self, timeout_ms, max_records, update_offsets=True): """Do one round of polling. In addition to checking for new data, this does any needed heart-beating, auto-commits, and offset updates. 
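
# The new poll() loop above re-derives its remaining budget from wall-clock
# time after every inner poll rather than relying on a Timer helper; a
# standalone sketch of that pattern, where poll_once is any callable taking a
# timeout in ms and returning a (possibly empty) dict of records:
import time

def poll_until(poll_once, timeout_ms):
    start = time.time()
    remaining = timeout_ms
    while True:
        records = poll_once(remaining)
        if records:
            return records
        elapsed_ms = (time.time() - start) * 1000
        remaining = timeout_ms - elapsed_ms
        if remaining <= 0:
            return {}

assert poll_until(lambda ms: {}, timeout_ms=1) == {}
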
Arguments: - timer (Timer): The maximum time in milliseconds to block. + timeout_ms (int): The maximum time in milliseconds to block. Returns: dict: Map of topic to list of records (may be empty). """ - if not self._coordinator.poll(timeout_ms=timer.timeout_ms): - log.debug('poll: timeout during coordinator.poll(); returning early') - return {} + self._coordinator.poll() - has_all_fetch_positions = self._update_fetch_positions(timeout_ms=timer.timeout_ms) + # Fetch positions if we have partitions we're subscribed to that we + # don't know the offset for + if not self._subscription.has_all_fetch_positions(): + self._update_fetch_positions(self._subscription.missing_fetch_positions()) # If data is available already, e.g. from a previous network client # poll() call to commit, then just return it immediately records, partial = self._fetcher.fetched_records(max_records, update_offsets=update_offsets) - log.debug('poll: fetched records: %s, %s', records, partial) - # Before returning the fetched records, we can send off the - # next round of fetches and avoid block waiting for their - # responses to enable pipelining while the user is handling the - # fetched records. - if not partial: - log.debug("poll: Sending fetches") - futures = self._fetcher.send_fetches() - if len(futures): - self._client.poll(timeout_ms=0) - if records: + # Before returning the fetched records, we can send off the + # next round of fetches and avoid block waiting for their + # responses to enable pipelining while the user is handling the + # fetched records. + if not partial: + futures = self._fetcher.send_fetches() + if len(futures): + self._client.poll(timeout_ms=0) return records - # We do not want to be stuck blocking in poll if we are missing some positions - # since the offset lookup may be backing off after a failure - poll_timeout_ms = min(timer.timeout_ms, self._coordinator.time_to_next_poll() * 1000) - if not has_all_fetch_positions: - log.debug('poll: do not have all fetch positions...') - poll_timeout_ms = min(poll_timeout_ms, self.config['retry_backoff_ms']) + # Send any new fetches (won't resend pending fetches) + futures = self._fetcher.send_fetches() + if len(futures): + self._client.poll(timeout_ms=0) - self._client.poll(timeout_ms=poll_timeout_ms) + timeout_ms = min(timeout_ms, self._coordinator.time_to_next_poll() * 1000) + self._client.poll(timeout_ms=timeout_ms) # after the long poll, we should check whether the group needs to rebalance # prior to returning data so that the group can stabilize faster if self._coordinator.need_rejoin(): - log.debug('poll: coordinator needs rejoin; returning early') return {} records, _ = self._fetcher.fetched_records(max_records, update_offsets=update_offsets) return records - def position(self, partition, timeout_ms=None): + def position(self, partition): """Get the offset of the next record that will be fetched Arguments: partition (TopicPartition): Partition to check Returns: - int: Offset or None + int: Offset """ if not isinstance(partition, TopicPartition): raise TypeError('partition must be a TopicPartition namedtuple') assert self._subscription.is_assigned(partition), 'Partition is not assigned' - - timer = Timer(timeout_ms) - position = self._subscription.assignment[partition].position - while position is None: - # batch update fetch positions for any partitions without a valid position - if self._update_fetch_positions(timeout_ms=timer.timeout_ms): - position = self._subscription.assignment[partition].position - elif timer.expired: - return None - else: - 
return position.offset + offset = self._subscription.assignment[partition].position + if offset is None: + self._update_fetch_positions([partition]) + offset = self._subscription.assignment[partition].position + return offset def highwater(self, partition): """Last known highwater offset for a partition. @@ -854,7 +818,8 @@ class KafkaConsumer(six.Iterator): assert partition in self._subscription.assigned_partitions(), 'Unassigned partition' log.debug("Seeking to offset %s for partition %s", offset, partition) self._subscription.assignment[partition].seek(offset) - self._iterator = None + if not self.config['legacy_iterator']: + self._iterator = None def seek_to_beginning(self, *partitions): """Seek to the oldest available offset for partitions. @@ -878,8 +843,9 @@ class KafkaConsumer(six.Iterator): for tp in partitions: log.debug("Seeking to beginning of partition %s", tp) - self._subscription.request_offset_reset(tp, OffsetResetStrategy.EARLIEST) - self._iterator = None + self._subscription.need_offset_reset(tp, OffsetResetStrategy.EARLIEST) + if not self.config['legacy_iterator']: + self._iterator = None def seek_to_end(self, *partitions): """Seek to the most recent available offset for partitions. @@ -903,8 +869,9 @@ class KafkaConsumer(six.Iterator): for tp in partitions: log.debug("Seeking to end of partition %s", tp) - self._subscription.request_offset_reset(tp, OffsetResetStrategy.LATEST) - self._iterator = None + self._subscription.need_offset_reset(tp, OffsetResetStrategy.LATEST) + if not self.config['legacy_iterator']: + self._iterator = None def subscribe(self, topics=(), pattern=None, listener=None): """Subscribe to a list of topics, or a topic regex pattern. @@ -975,16 +942,13 @@ class KafkaConsumer(six.Iterator): def unsubscribe(self): """Unsubscribe from all topics and clear all assigned partitions.""" - # make sure the offsets of topic partitions the consumer is unsubscribing from - # are committed since there will be no following rebalance - self._coordinator.maybe_auto_commit_offsets_now() self._subscription.unsubscribe() - if self.config['api_version'] >= (0, 9): - self._coordinator.maybe_leave_group() + self._coordinator.close() self._client.cluster.need_all_topic_metadata = False self._client.set_topics([]) log.debug("Unsubscribed all topics or patterns and assigned partitions") - self._iterator = None + if not self.config['legacy_iterator']: + self._iterator = None def metrics(self, raw=False): """Get metrics on consumer performance. @@ -996,8 +960,6 @@ class KafkaConsumer(six.Iterator): This is an unstable interface. It may change in future releases without warning. """ - if not self._metrics: - return if raw: return self._metrics.metrics.copy() @@ -1053,7 +1015,7 @@ class KafkaConsumer(six.Iterator): raise ValueError( "The target time for partition {} is {}. The target time " "cannot be negative.".format(tp, ts)) - return self._fetcher.offsets_by_times( + return self._fetcher.get_offsets_by_times( timestamps, self.config['request_timeout_ms']) def beginning_offsets(self, partitions): @@ -1119,7 +1081,7 @@ class KafkaConsumer(six.Iterator): return False return True - def _update_fetch_positions(self, timeout_ms=None): + def _update_fetch_positions(self, partitions): """Set the fetch position to the committed position (if there is one) or reset it using the offset reset policy the user has configured. @@ -1127,36 +1089,30 @@ class KafkaConsumer(six.Iterator): partitions (List[TopicPartition]): The partitions that need updating fetch positions. 
- Returns True if fetch positions updated, False if timeout or async reset is pending - Raises: NoOffsetForPartitionError: If no offset is stored for a given partition and no offset reset policy is defined. """ - if self._subscription.has_all_fetch_positions(): - return True + # Lookup any positions for partitions which are awaiting reset (which may be the + # case if the user called :meth:`seek_to_beginning` or :meth:`seek_to_end`. We do + # this check first to avoid an unnecessary lookup of committed offsets (which + # typically occurs when the user is manually assigning partitions and managing + # their own offsets). + self._fetcher.reset_offsets_if_needed(partitions) - if (self.config['api_version'] >= (0, 8, 1) and - self.config['group_id'] is not None): - # If there are any partitions which do not have a valid position and are not - # awaiting reset, then we need to fetch committed offsets. We will only do a - # coordinator lookup if there are partitions which have missing positions, so - # a consumer with manually assigned partitions can avoid a coordinator dependence - # by always ensuring that assigned partitions have an initial position. - if not self._coordinator.refresh_committed_offsets_if_needed(timeout_ms=timeout_ms): - return False + if not self._subscription.has_all_fetch_positions(): + # if we still don't have offsets for all partitions, then we should either seek + # to the last committed position or reset using the auto reset policy + if (self.config['api_version'] >= (0, 8, 1) and + self.config['group_id'] is not None): + # first refresh commits for all assigned partitions + self._coordinator.refresh_committed_offsets_if_needed() - # If there are partitions still needing a position and a reset policy is defined, - # request reset using the default policy. If no reset strategy is defined and there - # are partitions with a missing position, then we will raise an exception. - self._subscription.reset_missing_positions() - - # Finally send an asynchronous request to lookup and update the positions of any - # partitions which are awaiting reset. 
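
# Illustrative usage sketch of the poll()/fetch-position flow handled above,
# assuming a locally reachable broker (localhost:9092) and placeholder topic
# and group names; a sketch only, not the canonical usage.
from kafka import KafkaConsumer

consumer = KafkaConsumer(
    'example-topic',
    bootstrap_servers='localhost:9092',
    group_id='example-group',
    # Fallback reset policy applied while updating fetch positions when no
    # committed offset exists for an assigned partition.
    auto_offset_reset='earliest',
    enable_auto_commit=True,
)
try:
    while True:
        # poll() drives coordinator activity, resolves missing fetch positions,
        # returns any buffered records, and pipelines the next fetch requests.
        batch = consumer.poll(timeout_ms=1000)
        for tp, records in batch.items():
            for record in records:
                print(tp.topic, tp.partition, record.offset, record.value)
finally:
    consumer.close()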
- return not self._fetcher.reset_offsets_if_needed() + # Then, do any offset lookups in case some positions are not known + self._fetcher.update_fetch_positions(partitions) def _message_generator_v2(self): - timeout_ms = 1000 * max(0, self._consumer_timeout - time.time()) + timeout_ms = 1000 * (self._consumer_timeout - time.time()) record_map = self.poll(timeout_ms=timeout_ms, update_offsets=False) for tp, records in six.iteritems(record_map): # Generators are stateful, and it is possible that the tp / records @@ -1171,15 +1127,72 @@ class KafkaConsumer(six.Iterator): log.debug("Not returning fetched records for partition %s" " since it is no longer fetchable", tp) break - self._subscription.assignment[tp].position = OffsetAndMetadata(record.offset + 1, '', -1) + self._subscription.assignment[tp].position = record.offset + 1 yield record + def _message_generator(self): + assert self.assignment() or self.subscription() is not None, 'No topic subscription or manual partition assignment' + while time.time() < self._consumer_timeout: + + self._coordinator.poll() + + # Fetch offsets for any subscribed partitions that we arent tracking yet + if not self._subscription.has_all_fetch_positions(): + partitions = self._subscription.missing_fetch_positions() + self._update_fetch_positions(partitions) + + poll_ms = min((1000 * (self._consumer_timeout - time.time())), self.config['retry_backoff_ms']) + self._client.poll(timeout_ms=poll_ms) + + # after the long poll, we should check whether the group needs to rebalance + # prior to returning data so that the group can stabilize faster + if self._coordinator.need_rejoin(): + continue + + # We need to make sure we at least keep up with scheduled tasks, + # like heartbeats, auto-commits, and metadata refreshes + timeout_at = self._next_timeout() + + # Short-circuit the fetch iterator if we are already timed out + # to avoid any unintentional interaction with fetcher setup + if time.time() > timeout_at: + continue + + for msg in self._fetcher: + yield msg + if time.time() > timeout_at: + log.debug("internal iterator timeout - breaking for poll") + break + self._client.poll(timeout_ms=0) + + # An else block on a for loop only executes if there was no break + # so this should only be called on a StopIteration from the fetcher + # We assume that it is safe to init_fetches when fetcher is done + # i.e., there are no more records stored internally + else: + self._fetcher.send_fetches() + + def _next_timeout(self): + timeout = min(self._consumer_timeout, + self._client.cluster.ttl() / 1000.0 + time.time(), + self._coordinator.time_to_next_poll() + time.time()) + return timeout + def __iter__(self): # pylint: disable=non-iterator-returned return self def __next__(self): if self._closed: raise StopIteration('KafkaConsumer closed') + # Now that the heartbeat thread runs in the background + # there should be no reason to maintain a separate iterator + # but we'll keep it available for a few releases just in case + if self.config['legacy_iterator']: + return self.next_v1() + else: + return self.next_v2() + + def next_v2(self): self._set_consumer_timeout() while time.time() < self._consumer_timeout: if not self._iterator: @@ -1190,6 +1203,17 @@ class KafkaConsumer(six.Iterator): self._iterator = None raise StopIteration() + def next_v1(self): + if not self._iterator: + self._iterator = self._message_generator() + + self._set_consumer_timeout() + try: + return next(self._iterator) + except StopIteration: + self._iterator = None + raise + def _set_consumer_timeout(self): 
# consumer_timeout_ms can be used to stop iteration early if self.config['consumer_timeout_ms'] >= 0: diff --git a/venv/lib/python3.12/site-packages/kafka/consumer/subscription_state.py b/venv/lib/python3.12/site-packages/kafka/consumer/subscription_state.py index f99f016..08842d1 100644 --- a/venv/lib/python3.12/site-packages/kafka/consumer/subscription_state.py +++ b/venv/lib/python3.12/site-packages/kafka/consumer/subscription_state.py @@ -1,40 +1,18 @@ from __future__ import absolute_import import abc -from collections import OrderedDict -try: - from collections.abc import Sequence -except ImportError: - from collections import Sequence -try: - # enum in stdlib as of py3.4 - from enum import IntEnum # pylint: disable=import-error -except ImportError: - # vendored backport module - from kafka.vendor.enum34 import IntEnum import logging -import random import re -import threading -import time from kafka.vendor import six -import kafka.errors as Errors -from kafka.protocol.list_offsets import OffsetResetStrategy +from kafka.errors import IllegalStateError +from kafka.protocol.offset import OffsetResetStrategy from kafka.structs import OffsetAndMetadata -from kafka.util import ensure_valid_topic_name, synchronized log = logging.getLogger(__name__) -class SubscriptionType(IntEnum): - NONE = 0 - AUTO_TOPICS = 1 - AUTO_PATTERN = 2 - USER_ASSIGNED = 3 - - class SubscriptionState(object): """ A class for tracking the topics, partitions, and offsets for the consumer. @@ -54,6 +32,10 @@ class SubscriptionState(object): Note that pause state as well as fetch/consumed positions are not preserved when partition assignment is changed whether directly by the user or through a group rebalance. + + This class also maintains a cache of the latest commit position for each of + the assigned partitions. This is updated through committed() and can be used + to set the initial fetch position (e.g. Fetcher._reset_offset() ). 
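
# Usage sketch for the three mutually exclusive subscription modes that
# SubscriptionState enforces (topic list, regex pattern, manual assignment).
# The broker address and topic names below are placeholder assumptions.
from kafka import KafkaConsumer
from kafka.errors import IllegalStateError
from kafka.structs import TopicPartition

consumer = KafkaConsumer(bootstrap_servers='localhost:9092', group_id='demo-group')

# (1) subscribe to an explicit topic list
consumer.subscribe(topics=['orders'])

# (2) or, on a fresh consumer, subscribe to a regex pattern instead:
# consumer.subscribe(pattern='^metrics-.*')

# (3) manual assignment is rejected once subscribe() has been called
try:
    consumer.assign([TopicPartition('orders', 0)])
except IllegalStateError:
    print('subscribe() and assign() cannot be mixed on one consumer')

consumer.close()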
""" _SUBSCRIPTION_EXCEPTION_MESSAGE = ( "You must choose only one way to configure your consumer:" @@ -61,6 +43,10 @@ class SubscriptionState(object): " (2) subscribe to topics matching a regex pattern," " (3) assign itself specific topic-partitions.") + # Taken from: https://github.com/apache/kafka/blob/39eb31feaeebfb184d98cc5d94da9148c2319d81/clients/src/main/java/org/apache/kafka/common/internals/Topic.java#L29 + _MAX_NAME_LENGTH = 249 + _TOPIC_LEGAL_CHARS = re.compile('^[a-zA-Z0-9._-]+$') + def __init__(self, offset_reset_strategy='earliest'): """Initialize a SubscriptionState instance @@ -78,24 +64,15 @@ class SubscriptionState(object): self._default_offset_reset_strategy = offset_reset_strategy self.subscription = None # set() or None - self.subscription_type = SubscriptionType.NONE self.subscribed_pattern = None # regex str or None self._group_subscription = set() self._user_assignment = set() - self.assignment = OrderedDict() - self.rebalance_listener = None - self.listeners = [] - self._lock = threading.RLock() + self.assignment = dict() + self.listener = None - def _set_subscription_type(self, subscription_type): - if not isinstance(subscription_type, SubscriptionType): - raise ValueError('SubscriptionType enum required') - if self.subscription_type == SubscriptionType.NONE: - self.subscription_type = subscription_type - elif self.subscription_type != subscription_type: - raise Errors.IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) + # initialize to true for the consumers to fetch offset upon starting up + self.needs_fetch_committed_offsets = True - @synchronized def subscribe(self, topics=(), pattern=None, listener=None): """Subscribe to a list of topics, or a topic regex pattern. @@ -131,26 +108,39 @@ class SubscriptionState(object): guaranteed, however, that the partitions revoked/assigned through this interface are from topics subscribed in this call. """ + if self._user_assignment or (topics and pattern): + raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) assert topics or pattern, 'Must provide topics or pattern' - if (topics and pattern): - raise Errors.IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) - elif pattern: - self._set_subscription_type(SubscriptionType.AUTO_PATTERN) + if pattern: log.info('Subscribing to pattern: /%s/', pattern) self.subscription = set() self.subscribed_pattern = re.compile(pattern) else: - if isinstance(topics, str) or not isinstance(topics, Sequence): - raise TypeError('Topics must be a list (or non-str sequence)') - self._set_subscription_type(SubscriptionType.AUTO_TOPICS) self.change_subscription(topics) if listener and not isinstance(listener, ConsumerRebalanceListener): raise TypeError('listener must be a ConsumerRebalanceListener') - self.rebalance_listener = listener + self.listener = listener + + def _ensure_valid_topic_name(self, topic): + """ Ensures that the topic name is valid according to the kafka source. """ + + # See Kafka Source: + # https://github.com/apache/kafka/blob/39eb31feaeebfb184d98cc5d94da9148c2319d81/clients/src/main/java/org/apache/kafka/common/internals/Topic.java + if topic is None: + raise TypeError('All topics must not be None') + if not isinstance(topic, six.string_types): + raise TypeError('All topics must be strings') + if len(topic) == 0: + raise ValueError('All topics must be non-empty strings') + if topic == '.' or topic == '..': + raise ValueError('Topic name cannot be "." 
or ".."') + if len(topic) > self._MAX_NAME_LENGTH: + raise ValueError('Topic name is illegal, it can\'t be longer than {0} characters, topic: "{1}"'.format(self._MAX_NAME_LENGTH, topic)) + if not self._TOPIC_LEGAL_CHARS.match(topic): + raise ValueError('Topic name "{0}" is illegal, it contains a character other than ASCII alphanumerics, ".", "_" and "-"'.format(topic)) - @synchronized def change_subscription(self, topics): """Change the topic subscription. @@ -164,8 +154,8 @@ class SubscriptionState(object): - a topic name is '.' or '..' or - a topic name does not consist of ASCII-characters/'-'/'_'/'.' """ - if not self.partitions_auto_assigned(): - raise Errors.IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) + if self._user_assignment: + raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) if isinstance(topics, six.string_types): topics = [topics] @@ -176,13 +166,17 @@ class SubscriptionState(object): return for t in topics: - ensure_valid_topic_name(t) + self._ensure_valid_topic_name(t) log.info('Updating subscribed topics to: %s', topics) self.subscription = set(topics) self._group_subscription.update(topics) - @synchronized + # Remove any assigned partitions which are no longer subscribed to + for tp in set(self.assignment.keys()): + if tp.topic not in self.subscription: + del self.assignment[tp] + def group_subscribe(self, topics): """Add topics to the current group subscription. @@ -192,19 +186,17 @@ class SubscriptionState(object): Arguments: topics (list of str): topics to add to the group subscription """ - if not self.partitions_auto_assigned(): - raise Errors.IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) + if self._user_assignment: + raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) self._group_subscription.update(topics) - @synchronized def reset_group_subscription(self): """Reset the group's subscription to only contain topics subscribed by this consumer.""" - if not self.partitions_auto_assigned(): - raise Errors.IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) + if self._user_assignment: + raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) assert self.subscription is not None, 'Subscription required' self._group_subscription.intersection_update(self.subscription) - @synchronized def assign_from_user(self, partitions): """Manually assign a list of TopicPartitions to this consumer. @@ -223,13 +215,21 @@ class SubscriptionState(object): Raises: IllegalStateError: if consumer has already called subscribe() """ - self._set_subscription_type(SubscriptionType.USER_ASSIGNED) + if self.subscription is not None: + raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) + if self._user_assignment != set(partitions): self._user_assignment = set(partitions) - self._set_assignment({partition: self.assignment.get(partition, TopicPartitionState()) - for partition in partitions}) - @synchronized + for partition in partitions: + if partition not in self.assignment: + self._add_assigned_partition(partition) + + for tp in set(self.assignment.keys()) - self._user_assignment: + del self.assignment[tp] + + self.needs_fetch_committed_offsets = True + def assign_from_subscribed(self, assignments): """Update the assignment to the specified partitions @@ -243,39 +243,26 @@ class SubscriptionState(object): consumer instance. 
""" if not self.partitions_auto_assigned(): - raise Errors.IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) + raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) for tp in assignments: if tp.topic not in self.subscription: raise ValueError("Assigned partition %s for non-subscribed topic." % (tp,)) - # randomized ordering should improve balance for short-lived consumers - self._set_assignment({partition: TopicPartitionState() for partition in assignments}, randomize=True) + # after rebalancing, we always reinitialize the assignment state + self.assignment.clear() + for tp in assignments: + self._add_assigned_partition(tp) + self.needs_fetch_committed_offsets = True log.info("Updated partition assignment: %s", assignments) - def _set_assignment(self, partition_states, randomize=False): - """Batch partition assignment by topic (self.assignment is OrderedDict)""" - self.assignment.clear() - topics = [tp.topic for tp in six.iterkeys(partition_states)] - if randomize: - random.shuffle(topics) - topic_partitions = OrderedDict({topic: [] for topic in topics}) - for tp in six.iterkeys(partition_states): - topic_partitions[tp.topic].append(tp) - for topic in six.iterkeys(topic_partitions): - for tp in topic_partitions[topic]: - self.assignment[tp] = partition_states[tp] - - @synchronized def unsubscribe(self): """Clear all topic subscriptions and partition assignments""" self.subscription = None self._user_assignment.clear() self.assignment.clear() self.subscribed_pattern = None - self.subscription_type = SubscriptionType.NONE - @synchronized def group_subscription(self): """Get the topic subscription for the group. @@ -291,7 +278,6 @@ class SubscriptionState(object): """ return self._group_subscription - @synchronized def seek(self, partition, offset): """Manually specify the fetch offset for a TopicPartition. 
@@ -303,48 +289,40 @@ class SubscriptionState(object): Arguments: partition (TopicPartition): partition for seek operation - offset (int or OffsetAndMetadata): message offset in partition + offset (int): message offset in partition """ - if not isinstance(offset, (int, OffsetAndMetadata)): - raise TypeError("offset must be type in or OffsetAndMetadata") self.assignment[partition].seek(offset) - @synchronized def assigned_partitions(self): """Return set of TopicPartitions in current assignment.""" return set(self.assignment.keys()) - @synchronized def paused_partitions(self): """Return current set of paused TopicPartitions.""" return set(partition for partition in self.assignment if self.is_paused(partition)) - @synchronized def fetchable_partitions(self): - """Return ordered list of TopicPartitions that should be Fetched.""" - fetchable = list() + """Return set of TopicPartitions that should be Fetched.""" + fetchable = set() for partition, state in six.iteritems(self.assignment): if state.is_fetchable(): - fetchable.append(partition) + fetchable.add(partition) return fetchable - @synchronized def partitions_auto_assigned(self): """Return True unless user supplied partitions manually.""" - return self.subscription_type in (SubscriptionType.AUTO_TOPICS, SubscriptionType.AUTO_PATTERN) + return self.subscription is not None - @synchronized def all_consumed_offsets(self): """Returns consumed offsets as {TopicPartition: OffsetAndMetadata}""" all_consumed = {} for partition, state in six.iteritems(self.assignment): if state.has_valid_position: - all_consumed[partition] = state.position + all_consumed[partition] = OffsetAndMetadata(state.position, '') return all_consumed - @synchronized - def request_offset_reset(self, partition, offset_reset_strategy=None): + def need_offset_reset(self, partition, offset_reset_strategy=None): """Mark partition for offset reset using specified or default strategy. 
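
# Sketch of the user-facing calls behind the seek/position/pause state tracked
# here, assuming a reachable broker and a placeholder topic name.
from kafka import KafkaConsumer
from kafka.structs import TopicPartition

consumer = KafkaConsumer(bootstrap_servers='localhost:9092',
                         enable_auto_commit=False)
tp = TopicPartition('orders', 0)
consumer.assign([tp])           # manual assignment; no group coordinator needed

consumer.seek_to_beginning(tp)  # marks the partition for an EARLIEST offset reset
print('next offset:', consumer.position(tp))  # resolving the reset needs the broker

consumer.seek(tp, 42)           # set an explicit fetch position
consumer.pause(tp)              # partition is no longer fetchable...
print('paused:', consumer.paused())
consumer.resume(tp)             # ...until resumed
consumer.close()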
Arguments: @@ -353,113 +331,63 @@ class SubscriptionState(object): """ if offset_reset_strategy is None: offset_reset_strategy = self._default_offset_reset_strategy - self.assignment[partition].reset(offset_reset_strategy) + self.assignment[partition].await_reset(offset_reset_strategy) - @synchronized - def set_reset_pending(self, partitions, next_allowed_reset_time): - for partition in partitions: - self.assignment[partition].set_reset_pending(next_allowed_reset_time) - - @synchronized def has_default_offset_reset_policy(self): """Return True if default offset reset policy is Earliest or Latest""" return self._default_offset_reset_strategy != OffsetResetStrategy.NONE - @synchronized def is_offset_reset_needed(self, partition): return self.assignment[partition].awaiting_reset - @synchronized def has_all_fetch_positions(self): - for state in six.itervalues(self.assignment): + for state in self.assignment.values(): if not state.has_valid_position: return False return True - @synchronized def missing_fetch_positions(self): missing = set() for partition, state in six.iteritems(self.assignment): - if state.is_missing_position(): + if not state.has_valid_position: missing.add(partition) return missing - @synchronized - def has_valid_position(self, partition): - return partition in self.assignment and self.assignment[partition].has_valid_position - - @synchronized - def reset_missing_positions(self): - partitions_with_no_offsets = set() - for tp, state in six.iteritems(self.assignment): - if state.is_missing_position(): - if self._default_offset_reset_strategy == OffsetResetStrategy.NONE: - partitions_with_no_offsets.add(tp) - else: - state.reset(self._default_offset_reset_strategy) - - if partitions_with_no_offsets: - raise Errors.NoOffsetForPartitionError(partitions_with_no_offsets) - - @synchronized - def partitions_needing_reset(self): - partitions = set() - for tp, state in six.iteritems(self.assignment): - if state.awaiting_reset and state.is_reset_allowed(): - partitions.add(tp) - return partitions - - @synchronized def is_assigned(self, partition): return partition in self.assignment - @synchronized def is_paused(self, partition): return partition in self.assignment and self.assignment[partition].paused - @synchronized def is_fetchable(self, partition): return partition in self.assignment and self.assignment[partition].is_fetchable() - @synchronized def pause(self, partition): self.assignment[partition].pause() - @synchronized def resume(self, partition): self.assignment[partition].resume() - @synchronized - def reset_failed(self, partitions, next_retry_time): - for partition in partitions: - self.assignment[partition].reset_failed(next_retry_time) - - @synchronized - def move_partition_to_end(self, partition): - if partition in self.assignment: - try: - self.assignment.move_to_end(partition) - except AttributeError: - state = self.assignment.pop(partition) - self.assignment[partition] = state - - @synchronized - def position(self, partition): - return self.assignment[partition].position + def _add_assigned_partition(self, partition): + self.assignment[partition] = TopicPartitionState() class TopicPartitionState(object): def __init__(self): + self.committed = None # last committed OffsetAndMetadata + self.has_valid_position = False # whether we have valid position self.paused = False # whether this partition has been paused by the user - self.reset_strategy = None # the reset strategy if awaiting_reset is set - self._position = None # OffsetAndMetadata exposed to the user + 
self.awaiting_reset = False # whether we are awaiting reset + self.reset_strategy = None # the reset strategy if awaitingReset is set + self._position = None # offset exposed to the user self.highwater = None - self.drop_pending_record_batch = False - self.next_allowed_retry_time = None + self.drop_pending_message_set = False + # The last message offset hint available from a message batch with + # magic=2 which includes deleted compacted messages + self.last_offset_from_message_batch = None def _set_position(self, offset): assert self.has_valid_position, 'Valid position required' - assert isinstance(offset, OffsetAndMetadata) self._position = offset def _get_position(self): @@ -467,37 +395,20 @@ class TopicPartitionState(object): position = property(_get_position, _set_position, None, "last position") - def reset(self, strategy): - assert strategy is not None + def await_reset(self, strategy): + self.awaiting_reset = True self.reset_strategy = strategy self._position = None - self.next_allowed_retry_time = None - - def is_reset_allowed(self): - return self.next_allowed_retry_time is None or self.next_allowed_retry_time < time.time() - - @property - def awaiting_reset(self): - return self.reset_strategy is not None - - def set_reset_pending(self, next_allowed_retry_time): - self.next_allowed_retry_time = next_allowed_retry_time - - def reset_failed(self, next_allowed_retry_time): - self.next_allowed_retry_time = next_allowed_retry_time - - @property - def has_valid_position(self): - return self._position is not None - - def is_missing_position(self): - return not self.has_valid_position and not self.awaiting_reset + self.last_offset_from_message_batch = None + self.has_valid_position = False def seek(self, offset): - self._position = offset if isinstance(offset, OffsetAndMetadata) else OffsetAndMetadata(offset, '', -1) + self._position = offset + self.awaiting_reset = False self.reset_strategy = None - self.drop_pending_record_batch = True - self.next_allowed_retry_time = None + self.has_valid_position = True + self.drop_pending_message_set = True + self.last_offset_from_message_batch = None def pause(self): self.paused = True @@ -509,7 +420,6 @@ class TopicPartitionState(object): return not self.paused and self.has_valid_position -@six.add_metaclass(abc.ABCMeta) class ConsumerRebalanceListener(object): """ A callback interface that the user can implement to trigger custom actions @@ -551,6 +461,8 @@ class ConsumerRebalanceListener(object): taking over that partition has their on_partitions_assigned() callback called to load the state. 
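
# Minimal sketch of a ConsumerRebalanceListener along the lines described in
# the docstring above; the broker address, topic and group names are
# placeholder assumptions.
from kafka import KafkaConsumer
from kafka.consumer.subscription_state import ConsumerRebalanceListener

class CommitOnRevoke(ConsumerRebalanceListener):
    def __init__(self, consumer):
        self._consumer = consumer

    def on_partitions_revoked(self, revoked):
        # Flush consumed offsets for the partitions we are about to lose so
        # the next owner resumes from our last position.
        print('revoked:', revoked)
        self._consumer.commit()

    def on_partitions_assigned(self, assigned):
        print('assigned:', assigned)

consumer = KafkaConsumer(bootstrap_servers='localhost:9092',
                         group_id='demo-group',
                         enable_auto_commit=False)
consumer.subscribe(topics=['orders'], listener=CommitOnRevoke(consumer))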
""" + __metaclass__ = abc.ABCMeta + @abc.abstractmethod def on_partitions_revoked(self, revoked): """ diff --git a/venv/lib/python3.12/site-packages/kafka/coordinator/assignors/sticky/sticky_assignor.py b/venv/lib/python3.12/site-packages/kafka/coordinator/assignors/sticky/sticky_assignor.py index 69f68f5..7827086 100644 --- a/venv/lib/python3.12/site-packages/kafka/coordinator/assignors/sticky/sticky_assignor.py +++ b/venv/lib/python3.12/site-packages/kafka/coordinator/assignors/sticky/sticky_assignor.py @@ -2,6 +2,7 @@ import logging from collections import defaultdict, namedtuple from copy import deepcopy +from kafka.cluster import ClusterMetadata from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor from kafka.coordinator.assignors.sticky.partition_movements import PartitionMovements from kafka.coordinator.assignors.sticky.sorted_set import SortedSet @@ -647,19 +648,15 @@ class StickyPartitionAssignor(AbstractPartitionAssignor): @classmethod def metadata(cls, topics): - return cls._metadata(topics, cls.member_assignment, cls.generation) - - @classmethod - def _metadata(cls, topics, member_assignment_partitions, generation=-1): - if member_assignment_partitions is None: + if cls.member_assignment is None: log.debug("No member assignment available") user_data = b'' else: log.debug("Member assignment is available, generating the metadata: generation {}".format(cls.generation)) partitions_by_topic = defaultdict(list) - for topic_partition in member_assignment_partitions: + for topic_partition in cls.member_assignment: # pylint: disable=not-an-iterable partitions_by_topic[topic_partition.topic].append(topic_partition.partition) - data = StickyAssignorUserDataV1(list(partitions_by_topic.items()), generation) + data = StickyAssignorUserDataV1(six.iteritems(partitions_by_topic), cls.generation) user_data = data.encode() return ConsumerProtocolMemberMetadata(cls.version, list(topics), user_data) diff --git a/venv/lib/python3.12/site-packages/kafka/coordinator/base.py b/venv/lib/python3.12/site-packages/kafka/coordinator/base.py index 5e1f726..5e41309 100644 --- a/venv/lib/python3.12/site-packages/kafka/coordinator/base.py +++ b/venv/lib/python3.12/site-packages/kafka/coordinator/base.py @@ -5,7 +5,6 @@ import copy import logging import threading import time -import warnings import weakref from kafka.vendor import six @@ -15,12 +14,11 @@ from kafka import errors as Errors from kafka.future import Future from kafka.metrics import AnonMeasurable from kafka.metrics.stats import Avg, Count, Max, Rate -from kafka.protocol.find_coordinator import FindCoordinatorRequest -from kafka.protocol.group import HeartbeatRequest, JoinGroupRequest, LeaveGroupRequest, SyncGroupRequest, DEFAULT_GENERATION_ID, UNKNOWN_MEMBER_ID -from kafka.util import Timer +from kafka.protocol.commit import GroupCoordinatorRequest, OffsetCommitRequest +from kafka.protocol.group import (HeartbeatRequest, JoinGroupRequest, + LeaveGroupRequest, SyncGroupRequest) log = logging.getLogger('kafka.coordinator') -heartbeat_log = logging.getLogger('kafka.coordinator.heartbeat') class MemberState(object): @@ -35,20 +33,10 @@ class Generation(object): self.member_id = member_id self.protocol = protocol - @property - def is_valid(self): - return self.generation_id != DEFAULT_GENERATION_ID - - def __eq__(self, other): - return (self.generation_id == other.generation_id and - self.member_id == other.member_id and - self.protocol == other.protocol) - - def __str__(self): - return "" % (self.generation_id, self.member_id, 
self.protocol) - - -Generation.NO_GENERATION = Generation(DEFAULT_GENERATION_ID, UNKNOWN_MEMBER_ID, None) +Generation.NO_GENERATION = Generation( + OffsetCommitRequest[2].DEFAULT_GENERATION_ID, + JoinGroupRequest[0].UNKNOWN_MEMBER_ID, + None) class UnjoinedGroupException(Errors.KafkaError): @@ -99,11 +87,10 @@ class BaseCoordinator(object): 'max_poll_interval_ms': 300000, 'retry_backoff_ms': 100, 'api_version': (0, 10, 1), - 'metrics': None, 'metric_group_prefix': '', } - def __init__(self, client, **configs): + def __init__(self, client, metrics, **configs): """ Keyword Arguments: group_id (str): name of the consumer group to join for dynamic @@ -146,11 +133,8 @@ class BaseCoordinator(object): self.coordinator_id = None self._find_coordinator_future = None self._generation = Generation.NO_GENERATION - if self.config['metrics']: - self._sensors = GroupCoordinatorMetrics(self.heartbeat, self.config['metrics'], - self.config['metric_group_prefix']) - else: - self._sensors = None + self.sensors = GroupCoordinatorMetrics(self.heartbeat, metrics, + self.config['metric_group_prefix']) @abc.abstractmethod def protocol_type(self): @@ -182,7 +166,7 @@ class BaseCoordinator(object): pass @abc.abstractmethod - def _on_join_prepare(self, generation, member_id, timeout_ms=None): + def _on_join_prepare(self, generation, member_id): """Invoked prior to each group join or rejoin. This is typically used to perform any cleanup from the previous @@ -248,27 +232,16 @@ class BaseCoordinator(object): """ if self.coordinator_id is None: return None - elif self._client.is_disconnected(self.coordinator_id) and self._client.connection_delay(self.coordinator_id) > 0: + elif self._client.is_disconnected(self.coordinator_id): self.coordinator_dead('Node Disconnected') return None else: return self.coordinator_id - def connected(self): - """Return True iff the coordinator node is connected""" - with self._lock: - return self.coordinator_id is not None and self._client.connected(self.coordinator_id) - - def ensure_coordinator_ready(self, timeout_ms=None): - """Block until the coordinator for this group is known. - - Keyword Arguments: - timeout_ms (numeric, optional): Maximum number of milliseconds to - block waiting to find coordinator. Default: None. - - Returns: True is coordinator found before timeout_ms, else False + def ensure_coordinator_ready(self): + """Block until the coordinator for this group is known + (and we have an active connection -- java client uses unsent queue). 
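
# Configuration sketch for opting into the StickyPartitionAssignor touched
# above; the join/sync group machinery below then negotiates assignments
# transparently during poll(). Broker, topic and group names are placeholders.
from kafka import KafkaConsumer
from kafka.coordinator.assignors.sticky.sticky_assignor import StickyPartitionAssignor

consumer = KafkaConsumer(
    'orders',
    bootstrap_servers='localhost:9092',
    group_id='sticky-demo',
    # All members of a group should agree on the strategy; sticky assignment
    # minimizes partition movement across rebalances.
    partition_assignment_strategy=[StickyPartitionAssignor],
)
consumer.poll(timeout_ms=1000)   # first poll joins the group and receives an assignment
print(consumer.assignment())
consumer.close()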
""" - timer = Timer(timeout_ms) with self._client._lock, self._lock: while self.coordinator_unknown(): @@ -276,49 +249,30 @@ class BaseCoordinator(object): # so we will just pick a node at random and treat # it as the "coordinator" if self.config['api_version'] < (0, 8, 2): - maybe_coordinator_id = self._client.least_loaded_node() - if maybe_coordinator_id is None or self._client.cluster.is_bootstrap(maybe_coordinator_id): - future = Future().failure(Errors.NoBrokersAvailable()) - else: - self.coordinator_id = maybe_coordinator_id + self.coordinator_id = self._client.least_loaded_node() + if self.coordinator_id is not None: self._client.maybe_connect(self.coordinator_id) - if timer.expired: - return False - else: - continue - else: - future = self.lookup_coordinator() + continue - self._client.poll(future=future, timeout_ms=timer.timeout_ms) - - if not future.is_done: - return False + future = self.lookup_coordinator() + self._client.poll(future=future) if future.failed(): if future.retriable(): if getattr(future.exception, 'invalid_metadata', False): log.debug('Requesting metadata for group coordinator request: %s', future.exception) metadata_update = self._client.cluster.request_update() - self._client.poll(future=metadata_update, timeout_ms=timer.timeout_ms) - if not metadata_update.is_done: - return False + self._client.poll(future=metadata_update) else: - if timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']: - time.sleep(self.config['retry_backoff_ms'] / 1000) - else: - time.sleep(timer.timeout_ms / 1000) + time.sleep(self.config['retry_backoff_ms'] / 1000) else: raise future.exception # pylint: disable-msg=raising-bad-type - if timer.expired: - return False - else: - return True def _reset_find_coordinator_future(self, result): self._find_coordinator_future = None def lookup_coordinator(self): - with self._client._lock, self._lock: + with self._lock: if self._find_coordinator_future is not None: return self._find_coordinator_future @@ -376,139 +330,103 @@ class BaseCoordinator(object): return float('inf') return self.heartbeat.time_to_next_heartbeat() - def _reset_join_group_future(self): - with self._lock: - self.join_future = None - - def _initiate_join_group(self): - with self._lock: - # we store the join future in case we are woken up by the user - # after beginning the rebalance in the call to poll below. - # This ensures that we do not mistakenly attempt to rejoin - # before the pending rebalance has completed. - if self.join_future is None: - self.state = MemberState.REBALANCING - self.join_future = self._send_join_group_request() - - # handle join completion in the callback so that the - # callback will be invoked even if the consumer is woken up - # before finishing the rebalance - self.join_future.add_callback(self._handle_join_success) - - # we handle failures below after the request finishes. 
- # If the join completes after having been woken up, the - # exception is ignored and we will rejoin - self.join_future.add_errback(self._handle_join_failure) - - return self.join_future - def _handle_join_success(self, member_assignment_bytes): - # handle join completion in the callback so that the callback - # will be invoked even if the consumer is woken up before - # finishing the rebalance with self._lock: + log.info("Successfully joined group %s with generation %s", + self.group_id, self._generation.generation_id) self.state = MemberState.STABLE + self.rejoin_needed = False if self._heartbeat_thread: self._heartbeat_thread.enable() - def _handle_join_failure(self, exception): - # we handle failures below after the request finishes. - # if the join completes after having been woken up, - # the exception is ignored and we will rejoin + def _handle_join_failure(self, _): with self._lock: - log.info("Failed to join group %s: %s", self.group_id, exception) self.state = MemberState.UNJOINED - def ensure_active_group(self, timeout_ms=None): - """Ensure that the group is active (i.e. joined and synced) + def ensure_active_group(self): + """Ensure that the group is active (i.e. joined and synced)""" + with self._client._lock, self._lock: + if self._heartbeat_thread is None: + self._start_heartbeat_thread() - Keyword Arguments: - timeout_ms (numeric, optional): Maximum number of milliseconds to - block waiting to join group. Default: None. + while self.need_rejoin() or self._rejoin_incomplete(): + self.ensure_coordinator_ready() - Returns: True if group initialized before timeout_ms, else False - """ - if self.config['api_version'] < (0, 9): - raise Errors.UnsupportedVersionError('Group Coordinator APIs require 0.9+ broker') - timer = Timer(timeout_ms) - if not self.ensure_coordinator_ready(timeout_ms=timer.timeout_ms): - return False - self._start_heartbeat_thread() - return self.join_group(timeout_ms=timer.timeout_ms) + # call on_join_prepare if needed. We set a flag + # to make sure that we do not call it a second + # time if the client is woken up before a pending + # rebalance completes. This must be called on each + # iteration of the loop because an event requiring + # a rebalance (such as a metadata refresh which + # changes the matched subscription set) can occur + # while another rebalance is still in progress. + if not self.rejoining: + self._on_join_prepare(self._generation.generation_id, + self._generation.member_id) + self.rejoining = True - def join_group(self, timeout_ms=None): - if self.config['api_version'] < (0, 9): - raise Errors.UnsupportedVersionError('Group Coordinator APIs require 0.9+ broker') - timer = Timer(timeout_ms) - while self.need_rejoin(): - if not self.ensure_coordinator_ready(timeout_ms=timer.timeout_ms): - return False - - # call on_join_prepare if needed. We set a flag - # to make sure that we do not call it a second - # time if the client is woken up before a pending - # rebalance completes. This must be called on each - # iteration of the loop because an event requiring - # a rebalance (such as a metadata refresh which - # changes the matched subscription set) can occur - # while another rebalance is still in progress. - if not self.rejoining: - self._on_join_prepare(self._generation.generation_id, - self._generation.member_id, - timeout_ms=timer.timeout_ms) - self.rejoining = True - - # fence off the heartbeat thread explicitly so that it cannot - # interfere with the join group. 
# Note that this must come after - # the call to onJoinPrepare since we must be able to continue - # sending heartbeats if that callback takes some time. - log.debug("Disabling heartbeat thread during join-group") - self._disable_heartbeat_thread() - - # ensure that there are no pending requests to the coordinator. - # This is important in particular to avoid resending a pending - # JoinGroup request. - while not self.coordinator_unknown(): - if not self._client.in_flight_request_count(self.coordinator_id): - break - poll_timeout_ms = 200 if timer.timeout_ms is None or timer.timeout_ms > 200 else timer.timeout_ms - self._client.poll(timeout_ms=poll_timeout_ms) - if timer.expired: - return False - else: - continue - - future = self._initiate_join_group() - self._client.poll(future=future, timeout_ms=timer.timeout_ms) - if future.is_done: - self._reset_join_group_future() - else: - return False - - if future.succeeded(): - self.rejoining = False - self.rejoin_needed = False - self._on_join_complete(self._generation.generation_id, - self._generation.member_id, - self._generation.protocol, - future.value) - return True - else: - exception = future.exception - if isinstance(exception, (Errors.UnknownMemberIdError, - Errors.RebalanceInProgressError, - Errors.IllegalGenerationError, - Errors.MemberIdRequiredError)): - continue - elif not future.retriable(): - raise exception # pylint: disable-msg=raising-bad-type - elif timer.expired: - return False + # ensure that there are no pending requests to the coordinator. + # This is important in particular to avoid resending a pending + # JoinGroup request. + while not self.coordinator_unknown(): + if not self._client.in_flight_request_count(self.coordinator_id): + break + self._client.poll() else: - if timer.timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']: - time.sleep(self.config['retry_backoff_ms'] / 1000) - else: - time.sleep(timer.timeout_ms / 1000) + continue + + # we store the join future in case we are woken up by the user + # after beginning the rebalance in the call to poll below. + # This ensures that we do not mistakenly attempt to rejoin + # before the pending rebalance has completed. + if self.join_future is None: + # Fence off the heartbeat thread explicitly so that it cannot + # interfere with the join group. Note that this must come after + # the call to _on_join_prepare since we must be able to continue + # sending heartbeats if that callback takes some time. + self._heartbeat_thread.disable() + + self.state = MemberState.REBALANCING + future = self._send_join_group_request() + + self.join_future = future # this should happen before adding callbacks + + # handle join completion in the callback so that the + # callback will be invoked even if the consumer is woken up + # before finishing the rebalance + future.add_callback(self._handle_join_success) + + # we handle failures below after the request finishes. 
+ # If the join completes after having been woken up, the + # exception is ignored and we will rejoin + future.add_errback(self._handle_join_failure) + + else: + future = self.join_future + + self._client.poll(future=future) + + if future.succeeded(): + self._on_join_complete(self._generation.generation_id, + self._generation.member_id, + self._generation.protocol, + future.value) + self.join_future = None + self.rejoining = False + + else: + self.join_future = None + exception = future.exception + if isinstance(exception, (Errors.UnknownMemberIdError, + Errors.RebalanceInProgressError, + Errors.IllegalGenerationError)): + continue + elif not future.retriable(): + raise exception # pylint: disable-msg=raising-bad-type + time.sleep(self.config['retry_backoff_ms'] / 1000) + + def _rejoin_incomplete(self): + return self.join_future is not None def _send_join_group_request(self): """Join the group and return the assignment for the next generation. @@ -521,7 +439,7 @@ class BaseCoordinator(object): group leader """ if self.coordinator_unknown(): - e = Errors.CoordinatorNotAvailableError(self.coordinator_id) + e = Errors.GroupCoordinatorNotAvailableError(self.coordinator_id) return Future().failure(e) elif not self._client.ready(self.coordinator_id, metadata_priority=False): @@ -534,16 +452,25 @@ class BaseCoordinator(object): (protocol, metadata if isinstance(metadata, bytes) else metadata.encode()) for protocol, metadata in self.group_protocols() ] - version = self._client.api_version(JoinGroupRequest, max_version=4) - if version == 0: - request = JoinGroupRequest[version]( + if self.config['api_version'] < (0, 9): + raise Errors.KafkaError('JoinGroupRequest api requires 0.9+ brokers') + elif (0, 9) <= self.config['api_version'] < (0, 10, 1): + request = JoinGroupRequest[0]( self.group_id, self.config['session_timeout_ms'], self._generation.member_id, self.protocol_type(), member_metadata) + elif (0, 10, 1) <= self.config['api_version'] < (0, 11, 0): + request = JoinGroupRequest[1]( + self.group_id, + self.config['session_timeout_ms'], + self.config['max_poll_interval_ms'], + self._generation.member_id, + self.protocol_type(), + member_metadata) else: - request = JoinGroupRequest[version]( + request = JoinGroupRequest[2]( self.group_id, self.config['session_timeout_ms'], self.config['max_poll_interval_ms'], @@ -562,9 +489,8 @@ class BaseCoordinator(object): def _failed_request(self, node_id, request, future, error): # Marking coordinator dead - # unless the error is caused by internal client pipelining or throttling + # unless the error is caused by internal client pipelining if not isinstance(error, (Errors.NodeNotReadyError, - Errors.ThrottlingQuotaExceededError, Errors.TooManyInFlightRequests)): log.error('Error sending %s to node %s [%s]', request.__class__.__name__, node_id, error) @@ -575,11 +501,11 @@ class BaseCoordinator(object): future.failure(error) def _handle_join_group_response(self, future, send_time, response): - log.debug("Received JoinGroup response: %s", response) error_type = Errors.for_code(response.error_code) if error_type is Errors.NoError: - if self._sensors: - self._sensors.join_latency.record((time.time() - send_time) * 1000) + log.debug("Received successful JoinGroup response for group %s: %s", + self.group_id, response) + self.sensors.join_latency.record((time.time() - send_time) * 1000) with self._lock: if self.state is not MemberState.REBALANCING: # if the consumer was woken up before a rebalance completes, @@ -591,7 +517,6 @@ class BaseCoordinator(object): 
response.member_id, response.group_protocol) - log.info("Successfully joined group %s %s", self.group_id, self._generation) if response.leader_id == response.member_id: log.info("Elected group leader -- performing partition" " assignments using %s", self._generation.protocol) @@ -599,25 +524,25 @@ class BaseCoordinator(object): else: self._on_join_follower().chain(future) - elif error_type is Errors.CoordinatorLoadInProgressError: - log.info("Attempt to join group %s rejected since coordinator %s" - " is loading the group.", self.group_id, self.coordinator_id) + elif error_type is Errors.GroupLoadInProgressError: + log.debug("Attempt to join group %s rejected since coordinator %s" + " is loading the group.", self.group_id, self.coordinator_id) # backoff and retry future.failure(error_type(response)) elif error_type is Errors.UnknownMemberIdError: # reset the member id and retry immediately error = error_type(self._generation.member_id) self.reset_generation() - log.info("Attempt to join group %s failed due to unknown member id", - self.group_id) + log.debug("Attempt to join group %s failed due to unknown member id", + self.group_id) future.failure(error) - elif error_type in (Errors.CoordinatorNotAvailableError, - Errors.NotCoordinatorError): + elif error_type in (Errors.GroupCoordinatorNotAvailableError, + Errors.NotCoordinatorForGroupError): # re-discover the coordinator and retry with backoff self.coordinator_dead(error_type()) - log.info("Attempt to join group %s failed due to obsolete " - "coordinator information: %s", self.group_id, - error_type.__name__) + log.debug("Attempt to join group %s failed due to obsolete " + "coordinator information: %s", self.group_id, + error_type.__name__) future.failure(error_type()) elif error_type in (Errors.InconsistentGroupProtocolError, Errors.InvalidSessionTimeoutError, @@ -628,21 +553,7 @@ class BaseCoordinator(object): self.group_id, error) future.failure(error) elif error_type is Errors.GroupAuthorizationFailedError: - log.error("Attempt to join group %s failed due to group authorization error", - self.group_id) future.failure(error_type(self.group_id)) - elif error_type is Errors.MemberIdRequiredError: - # Broker requires a concrete member id to be allowed to join the group. Update member id - # and send another join group request in next cycle. - log.info("Received member id %s for group %s; will retry join-group", - response.member_id, self.group_id) - self.reset_generation(response.member_id) - future.failure(error_type()) - elif error_type is Errors.RebalanceInProgressError: - log.info("Attempt to join group %s failed due to RebalanceInProgressError," - " which could indicate a replication timeout on the broker. 
Will retry.", - self.group_id) - future.failure(error_type()) else: # unexpected error, throw the exception error = error_type() @@ -651,7 +562,7 @@ class BaseCoordinator(object): def _on_join_follower(self): # send follower's sync group with an empty assignment - version = self._client.api_version(SyncGroupRequest, max_version=2) + version = 0 if self.config['api_version'] < (0, 11, 0) else 1 request = SyncGroupRequest[version]( self.group_id, self._generation.generation_id, @@ -679,7 +590,7 @@ class BaseCoordinator(object): except Exception as e: return Future().failure(e) - version = self._client.api_version(SyncGroupRequest, max_version=2) + version = 0 if self.config['api_version'] < (0, 11, 0) else 1 request = SyncGroupRequest[version]( self.group_id, self._generation.generation_id, @@ -694,7 +605,7 @@ class BaseCoordinator(object): def _send_sync_group_request(self, request): if self.coordinator_unknown(): - e = Errors.CoordinatorNotAvailableError(self.coordinator_id) + e = Errors.GroupCoordinatorNotAvailableError(self.coordinator_id) return Future().failure(e) # We assume that coordinator is ready if we're sending SyncGroup @@ -711,11 +622,9 @@ class BaseCoordinator(object): return future def _handle_sync_group_response(self, future, send_time, response): - log.debug("Received SyncGroup response: %s", response) error_type = Errors.for_code(response.error_code) if error_type is Errors.NoError: - if self._sensors: - self._sensors.sync_latency.record((time.time() - send_time) * 1000) + self.sensors.sync_latency.record((time.time() - send_time) * 1000) future.success(response.member_assignment) return @@ -724,19 +633,19 @@ class BaseCoordinator(object): if error_type is Errors.GroupAuthorizationFailedError: future.failure(error_type(self.group_id)) elif error_type is Errors.RebalanceInProgressError: - log.info("SyncGroup for group %s failed due to coordinator" - " rebalance", self.group_id) + log.debug("SyncGroup for group %s failed due to coordinator" + " rebalance", self.group_id) future.failure(error_type(self.group_id)) elif error_type in (Errors.UnknownMemberIdError, Errors.IllegalGenerationError): error = error_type() - log.info("SyncGroup for group %s failed due to %s", self.group_id, error) + log.debug("SyncGroup for group %s failed due to %s", self.group_id, error) self.reset_generation() future.failure(error) - elif error_type in (Errors.CoordinatorNotAvailableError, - Errors.NotCoordinatorError): + elif error_type in (Errors.GroupCoordinatorNotAvailableError, + Errors.NotCoordinatorForGroupError): error = error_type() - log.info("SyncGroup for group %s failed due to %s", self.group_id, error) + log.debug("SyncGroup for group %s failed due to %s", self.group_id, error) self.coordinator_dead(error) future.failure(error) else: @@ -751,20 +660,16 @@ class BaseCoordinator(object): Future: resolves to the node id of the coordinator """ node_id = self._client.least_loaded_node() - if node_id is None or self._client.cluster.is_bootstrap(node_id): + if node_id is None: return Future().failure(Errors.NoBrokersAvailable()) elif not self._client.ready(node_id, metadata_priority=False): e = Errors.NodeNotReadyError(node_id) return Future().failure(e) - version = self._client.api_version(FindCoordinatorRequest, max_version=2) - if version == 0: - request = FindCoordinatorRequest[version](self.group_id) - else: - request = FindCoordinatorRequest[version](self.group_id, 0) - log.debug("Sending group coordinator request for group %s to broker %s: %s", - self.group_id, node_id, request) + 
log.debug("Sending group coordinator request for group %s to broker %s", + self.group_id, node_id) + request = GroupCoordinatorRequest[0](self.group_id) future = Future() _f = self._client.send(node_id, request) _f.add_callback(self._handle_group_coordinator_response, future) @@ -777,7 +682,7 @@ class BaseCoordinator(object): error_type = Errors.for_code(response.error_code) if error_type is Errors.NoError: with self._lock: - coordinator_id = self._client.cluster.add_coordinator(response, 'group', self.group_id) + coordinator_id = self._client.cluster.add_group_coordinator(self.group_id, response) if not coordinator_id: # This could happen if coordinator metadata is different # than broker metadata @@ -791,7 +696,7 @@ class BaseCoordinator(object): self.heartbeat.reset_timeouts() future.success(self.coordinator_id) - elif error_type is Errors.CoordinatorNotAvailableError: + elif error_type is Errors.GroupCoordinatorNotAvailableError: log.debug("Group Coordinator Not Available; retry") future.failure(error_type()) elif error_type is Errors.GroupAuthorizationFailedError: @@ -800,7 +705,7 @@ class BaseCoordinator(object): future.failure(error) else: error = error_type() - log.error("Group Coordinator lookup for group %s failed: %s", + log.error("Group coordinator lookup for group %s failed: %s", self.group_id, error) future.failure(error) @@ -811,7 +716,7 @@ class BaseCoordinator(object): self.coordinator_id, self.group_id, error) self.coordinator_id = None - def generation_if_stable(self): + def generation(self): """Get the current generation state if the group is stable. Returns: the current generation or None if the group is unjoined/rebalancing @@ -821,19 +726,10 @@ class BaseCoordinator(object): return None return self._generation - # deprecated - def generation(self): - warnings.warn("Function coordinator.generation() has been renamed to generation_if_stable()", - DeprecationWarning, stacklevel=2) - return self.generation_if_stable() - - def rebalance_in_progress(self): - return self.state is MemberState.REBALANCING - - def reset_generation(self, member_id=UNKNOWN_MEMBER_ID): - """Reset the generation and member_id because we have fallen out of the group.""" + def reset_generation(self): + """Reset the generation and memberId because we have fallen out of the group.""" with self._lock: - self._generation = Generation(DEFAULT_GENERATION_ID, member_id, None) + self._generation = Generation.NO_GENERATION self.rejoin_needed = True self.state = MemberState.UNJOINED @@ -841,90 +737,73 @@ class BaseCoordinator(object): self.rejoin_needed = True def _start_heartbeat_thread(self): - if self.config['api_version'] < (0, 9): - raise Errors.UnsupportedVersionError('Heartbeat APIs require 0.9+ broker') - with self._lock: - if self._heartbeat_thread is None: - heartbeat_log.info('Starting new heartbeat thread') - self._heartbeat_thread = HeartbeatThread(weakref.proxy(self)) - self._heartbeat_thread.daemon = True - self._heartbeat_thread.start() - heartbeat_log.debug("Started heartbeat thread %s", self._heartbeat_thread.ident) + if self._heartbeat_thread is None: + log.info('Starting new heartbeat thread') + self._heartbeat_thread = HeartbeatThread(weakref.proxy(self)) + self._heartbeat_thread.daemon = True + self._heartbeat_thread.start() - def _disable_heartbeat_thread(self): - with self._lock: - if self._heartbeat_thread is not None: - self._heartbeat_thread.disable() - - def _close_heartbeat_thread(self, timeout_ms=None): + def _close_heartbeat_thread(self): if self._heartbeat_thread is not None: + 
log.info('Stopping heartbeat thread') try: - self._heartbeat_thread.close(timeout_ms=timeout_ms) + self._heartbeat_thread.close() except ReferenceError: pass self._heartbeat_thread = None def __del__(self): - try: - self._close_heartbeat_thread() - except (TypeError, AttributeError): - pass + self._close_heartbeat_thread() - def close(self, timeout_ms=None): + def close(self): """Close the coordinator, leave the current group, and reset local generation / member_id""" - self._close_heartbeat_thread(timeout_ms=timeout_ms) - if self.config['api_version'] >= (0, 9): - self.maybe_leave_group(timeout_ms=timeout_ms) + self._close_heartbeat_thread() + self.maybe_leave_group() - def maybe_leave_group(self, timeout_ms=None): + def maybe_leave_group(self): """Leave the current group and reset local generation/memberId.""" - if self.config['api_version'] < (0, 9): - raise Errors.UnsupportedVersionError('Group Coordinator APIs require 0.9+ broker') with self._client._lock, self._lock: if (not self.coordinator_unknown() and self.state is not MemberState.UNJOINED - and self._generation.is_valid): + and self._generation is not Generation.NO_GENERATION): # this is a minimal effort attempt to leave the group. we do not # attempt any resending if the request fails or times out. log.info('Leaving consumer group (%s).', self.group_id) - version = self._client.api_version(LeaveGroupRequest, max_version=2) + version = 0 if self.config['api_version'] < (0, 11, 0) else 1 request = LeaveGroupRequest[version](self.group_id, self._generation.member_id) - log.debug('Sending LeaveGroupRequest to %s: %s', self.coordinator_id, request) future = self._client.send(self.coordinator_id, request) future.add_callback(self._handle_leave_group_response) future.add_errback(log.error, "LeaveGroup request failed: %s") - self._client.poll(future=future, timeout_ms=timeout_ms) + self._client.poll(future=future) self.reset_generation() def _handle_leave_group_response(self, response): - log.debug("Received LeaveGroupResponse: %s", response) error_type = Errors.for_code(response.error_code) if error_type is Errors.NoError: - log.info("LeaveGroup request for group %s returned successfully", - self.group_id) + log.debug("LeaveGroup request for group %s returned successfully", + self.group_id) else: log.error("LeaveGroup request for group %s failed with error: %s", self.group_id, error_type()) def _send_heartbeat_request(self): """Send a heartbeat request""" - # Note: acquire both client + coordinator lock before calling if self.coordinator_unknown(): - e = Errors.CoordinatorNotAvailableError(self.coordinator_id) + e = Errors.GroupCoordinatorNotAvailableError(self.coordinator_id) return Future().failure(e) elif not self._client.ready(self.coordinator_id, metadata_priority=False): e = Errors.NodeNotReadyError(self.coordinator_id) return Future().failure(e) - version = self._client.api_version(HeartbeatRequest, max_version=2) + version = 0 if self.config['api_version'] < (0, 11, 0) else 1 request = HeartbeatRequest[version](self.group_id, self._generation.generation_id, self._generation.member_id) - heartbeat_log.debug("Sending HeartbeatRequest to %s: %s", self.coordinator_id, request) + log.debug("Heartbeat: %s[%s] %s", request.group, request.generation_id, request.member_id) # pylint: disable-msg=no-member future = Future() _f = self._client.send(self.coordinator_id, request) _f.add_callback(self._handle_heartbeat_response, future, time.time()) @@ -933,42 +812,41 @@ class BaseCoordinator(object): return future def 
_handle_heartbeat_response(self, future, send_time, response): - if self._sensors: - self._sensors.heartbeat_latency.record((time.time() - send_time) * 1000) - heartbeat_log.debug("Received heartbeat response for group %s: %s", - self.group_id, response) + self.sensors.heartbeat_latency.record((time.time() - send_time) * 1000) error_type = Errors.for_code(response.error_code) if error_type is Errors.NoError: + log.debug("Received successful heartbeat response for group %s", + self.group_id) future.success(None) - elif error_type in (Errors.CoordinatorNotAvailableError, - Errors.NotCoordinatorError): - heartbeat_log.warning("Heartbeat failed for group %s: coordinator (node %s)" - " is either not started or not valid", self.group_id, + elif error_type in (Errors.GroupCoordinatorNotAvailableError, + Errors.NotCoordinatorForGroupError): + log.warning("Heartbeat failed for group %s: coordinator (node %s)" + " is either not started or not valid", self.group_id, self.coordinator()) self.coordinator_dead(error_type()) future.failure(error_type()) elif error_type is Errors.RebalanceInProgressError: - heartbeat_log.warning("Heartbeat failed for group %s because it is" - " rebalancing", self.group_id) + log.warning("Heartbeat failed for group %s because it is" + " rebalancing", self.group_id) self.request_rejoin() future.failure(error_type()) elif error_type is Errors.IllegalGenerationError: - heartbeat_log.warning("Heartbeat failed for group %s: generation id is not " - " current.", self.group_id) + log.warning("Heartbeat failed for group %s: generation id is not " + " current.", self.group_id) self.reset_generation() future.failure(error_type()) elif error_type is Errors.UnknownMemberIdError: - heartbeat_log.warning("Heartbeat: local member_id was not recognized;" - " this consumer needs to re-join") + log.warning("Heartbeat: local member_id was not recognized;" + " this consumer needs to re-join") self.reset_generation() future.failure(error_type) elif error_type is Errors.GroupAuthorizationFailedError: error = error_type(self.group_id) - heartbeat_log.error("Heartbeat failed: authorization error: %s", error) + log.error("Heartbeat failed: authorization error: %s", error) future.failure(error) else: error = error_type() - heartbeat_log.error("Heartbeat failed: Unhandled error: %s", error) + log.error("Heartbeat failed: Unhandled error: %s", error) future.failure(error) @@ -1034,139 +912,100 @@ class HeartbeatThread(threading.Thread): def enable(self): with self.coordinator._lock: - heartbeat_log.debug('Enabling heartbeat thread') self.enabled = True self.coordinator.heartbeat.reset_timeouts() self.coordinator._lock.notify() def disable(self): + self.enabled = False + + def close(self): + self.closed = True with self.coordinator._lock: - heartbeat_log.debug('Disabling heartbeat thread') - self.enabled = False - - def close(self, timeout_ms=None): - with self.coordinator._lock: - if self.closed: - return - - heartbeat_log.info('Stopping heartbeat thread') - self.closed = True - - # Generally this should not happen - close() is triggered - # by the coordinator. But in some cases GC may close the coordinator - # from within the heartbeat thread. 
- if threading.current_thread() == self: - return - - # Notify coordinator lock to wake thread from sleep/lock.wait self.coordinator._lock.notify() - if self.is_alive(): - if timeout_ms is None: - timeout_ms = self.coordinator.config['heartbeat_interval_ms'] - self.join(timeout_ms / 1000) + self.join(self.coordinator.config['heartbeat_interval_ms'] / 1000) if self.is_alive(): - heartbeat_log.warning("Heartbeat thread did not fully terminate during close") + log.warning("Heartbeat thread did not fully terminate during close") def run(self): try: - heartbeat_log.debug('Heartbeat thread started: %s', self.coordinator.heartbeat) + log.debug('Heartbeat thread started') while not self.closed: self._run_once() except ReferenceError: - heartbeat_log.debug('Heartbeat thread closed due to coordinator gc') + log.debug('Heartbeat thread closed due to coordinator gc') - except Exception as e: - heartbeat_log.exception("Heartbeat thread for group %s failed due to unexpected error: %s", - self.coordinator.group_id, e) + except RuntimeError as e: + log.error("Heartbeat thread for group %s failed due to unexpected error: %s", + self.coordinator.group_id, e) self.failed = e finally: - heartbeat_log.debug('Heartbeat thread closed') + log.debug('Heartbeat thread closed') def _run_once(self): - self.coordinator._client._lock.acquire() - self.coordinator._lock.acquire() - try: + with self.coordinator._client._lock, self.coordinator._lock: + if self.enabled and self.coordinator.state is MemberState.STABLE: + # TODO: When consumer.wakeup() is implemented, we need to + # disable here to prevent propagating an exception to this + # heartbeat thread + # must get client._lock, or maybe deadlock at heartbeat + # failure callbak in consumer poll + self.coordinator._client.poll(timeout_ms=0) + + with self.coordinator._lock: if not self.enabled: - heartbeat_log.debug('Heartbeat disabled. Waiting') - self.coordinator._client._lock.release() + log.debug('Heartbeat disabled. Waiting') self.coordinator._lock.wait() - if self.enabled: - heartbeat_log.debug('Heartbeat re-enabled.') + log.debug('Heartbeat re-enabled.') return if self.coordinator.state is not MemberState.STABLE: # the group is not stable (perhaps because we left the # group or because the coordinator kicked us out), so # disable heartbeats and wait for the main thread to rejoin. - heartbeat_log.debug('Group state is not stable, disabling heartbeats') + log.debug('Group state is not stable, disabling heartbeats') self.disable() return - # TODO: When consumer.wakeup() is implemented, we need to - # disable here to prevent propagating an exception to this - # heartbeat thread - self.coordinator._client.poll(timeout_ms=0) - if self.coordinator.coordinator_unknown(): future = self.coordinator.lookup_coordinator() if not future.is_done or future.failed(): # the immediate future check ensures that we backoff # properly in the case that no brokers are available # to connect to (and the future is automatically failed). - self.coordinator._client._lock.release() self.coordinator._lock.wait(self.coordinator.config['retry_backoff_ms'] / 1000) - elif not self.coordinator.connected(): - self.coordinator._client._lock.release() - self.coordinator._lock.wait(self.coordinator.config['retry_backoff_ms'] / 1000) - elif self.coordinator.heartbeat.session_timeout_expired(): # the session timeout has expired without seeing a # successful heartbeat, so we should probably make sure # the coordinator is still healthy. 
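# Sketch: the HeartbeatThread above is a daemon thread parked on the coordinator lock,
# woken by enable()/close() and re-armed on an interval. A minimal standalone sketch of
# that condition-variable pattern; the class and method names here are hypothetical,
# not the library's.

import threading

class PeriodicWorker(threading.Thread):
    def __init__(self, interval_s, work):
        super().__init__(daemon=True)
        self._cv = threading.Condition()
        self._interval = interval_s
        self._work = work
        self._enabled = False
        self._closed = False

    def enable(self):
        with self._cv:
            self._enabled = True
            self._cv.notify()

    def disable(self):
        with self._cv:
            self._enabled = False

    def close(self):
        with self._cv:
            self._closed = True
            self._cv.notify()
        if self.is_alive():
            self.join(timeout=self._interval)

    def run(self):
        with self._cv:
            while not self._closed:
                if not self._enabled:
                    self._cv.wait()                # parked until enable()/close()
                    continue
                self._work()
                self._cv.wait(self._interval)      # sleep, but wake early on notify()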
- heartbeat_log.warning('Heartbeat session expired, marking coordinator dead') + log.warning('Heartbeat session expired, marking coordinator dead') self.coordinator.coordinator_dead('Heartbeat session expired') elif self.coordinator.heartbeat.poll_timeout_expired(): # the poll timeout has expired, which means that the # foreground thread has stalled in between calls to # poll(), so we explicitly leave the group. - heartbeat_log.warning( - "Consumer poll timeout has expired. This means the time between subsequent calls to poll()" - " was longer than the configured max_poll_interval_ms, which typically implies that" - " the poll loop is spending too much time processing messages. You can address this" - " either by increasing max_poll_interval_ms or by reducing the maximum size of batches" - " returned in poll() with max_poll_records." - ) + log.warning('Heartbeat poll expired, leaving group') self.coordinator.maybe_leave_group() elif not self.coordinator.heartbeat.should_heartbeat(): - next_hb = self.coordinator.heartbeat.time_to_next_heartbeat() - heartbeat_log.debug('Waiting %0.1f secs to send next heartbeat', next_hb) - self.coordinator._client._lock.release() - self.coordinator._lock.wait(next_hb) + # poll again after waiting for the retry backoff in case + # the heartbeat failed or the coordinator disconnected + log.log(0, 'Not ready to heartbeat, waiting') + self.coordinator._lock.wait(self.coordinator.config['retry_backoff_ms'] / 1000) else: - heartbeat_log.debug('Sending heartbeat for group %s %s', self.coordinator.group_id, self.coordinator._generation) self.coordinator.heartbeat.sent_heartbeat() future = self.coordinator._send_heartbeat_request() future.add_callback(self._handle_heartbeat_success) future.add_errback(self._handle_heartbeat_failure) - finally: - self.coordinator._lock.release() - try: - # Possibly released in block above to allow coordinator lock wait() - self.coordinator._client._lock.release() - except RuntimeError: - pass - def _handle_heartbeat_success(self, result): with self.coordinator._lock: - heartbeat_log.debug('Heartbeat success') self.coordinator.heartbeat.received_heartbeat() def _handle_heartbeat_failure(self, exception): @@ -1177,10 +1016,8 @@ class HeartbeatThread(threading.Thread): # member in the group for as long as the duration of the # rebalance timeout. If we stop sending heartbeats, however, # then the session timeout may expire before we can rejoin. 
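# Sketch: the branches above are driven by three timestamps kept by the Heartbeat helper
# (last send, last receive, last poll). A simplified stand-in showing the same bookkeeping;
# the interval-below-session-timeout check reflects the usual Kafka guidance and is not
# copied from this patch.

import time

class HeartbeatTimers:
    """Simplified stand-in for the coordinator's heartbeat timing bookkeeping."""

    def __init__(self, interval_s, session_timeout_s, max_poll_interval_s):
        if interval_s >= session_timeout_s:
            raise ValueError('heartbeat interval must be below the session timeout')
        now = time.time()
        self.interval = interval_s
        self.session_timeout = session_timeout_s
        self.max_poll_interval = max_poll_interval_s
        self.last_send = self.last_receive = self.last_poll = now

    def sent_heartbeat(self):
        self.last_send = time.time()

    def received_heartbeat(self):
        self.last_receive = time.time()

    def poll(self):
        self.last_poll = time.time()

    def time_to_next_heartbeat(self):
        return max(0.0, self.last_send + self.interval - time.time())

    def should_heartbeat(self):
        return self.time_to_next_heartbeat() == 0.0

    def session_timeout_expired(self):
        return time.time() - self.last_receive > self.session_timeout

    def poll_timeout_expired(self):
        return time.time() - self.last_poll > self.max_poll_interval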
- heartbeat_log.debug('Treating RebalanceInProgressError as successful heartbeat') self.coordinator.heartbeat.received_heartbeat() else: - heartbeat_log.debug('Heartbeat failure: %s', exception) self.coordinator.heartbeat.fail_heartbeat() # wake up the thread if it's sleeping to reschedule the heartbeat self.coordinator._lock.notify() diff --git a/venv/lib/python3.12/site-packages/kafka/coordinator/consumer.py b/venv/lib/python3.12/site-packages/kafka/coordinator/consumer.py index dca10ae..971f5e8 100644 --- a/venv/lib/python3.12/site-packages/kafka/coordinator/consumer.py +++ b/venv/lib/python3.12/site-packages/kafka/coordinator/consumer.py @@ -19,7 +19,7 @@ from kafka.metrics import AnonMeasurable from kafka.metrics.stats import Avg, Count, Max, Rate from kafka.protocol.commit import OffsetCommitRequest, OffsetFetchRequest from kafka.structs import OffsetAndMetadata, TopicPartition -from kafka.util import Timer, WeakMethod +from kafka.util import WeakMethod log = logging.getLogger(__name__) @@ -39,11 +39,10 @@ class ConsumerCoordinator(BaseCoordinator): 'retry_backoff_ms': 100, 'api_version': (0, 10, 1), 'exclude_internal_topics': True, - 'metrics': None, 'metric_group_prefix': 'consumer' } - def __init__(self, client, subscription, **configs): + def __init__(self, client, subscription, metrics, **configs): """Initialize the coordination manager. Keyword Arguments: @@ -55,7 +54,7 @@ class ConsumerCoordinator(BaseCoordinator): auto_commit_interval_ms (int): milliseconds between automatic offset commits, if enable_auto_commit is True. Default: 5000. default_offset_commit_callback (callable): called as - callback(offsets, response) response will be either an Exception + callback(offsets, exception) response will be either an Exception or None. This callback can be used to trigger custom actions when a commit request completes. assignors (list): List of objects to use to distribute partition @@ -79,7 +78,7 @@ class ConsumerCoordinator(BaseCoordinator): True the only way to receive records from an internal topic is subscribing to it. Requires 0.10+. 
Default: True """ - super(ConsumerCoordinator, self).__init__(client, **configs) + super(ConsumerCoordinator, self).__init__(client, metrics, **configs) self.config = copy.copy(self.DEFAULT_CONFIG) for key in self.config: @@ -95,7 +94,6 @@ class ConsumerCoordinator(BaseCoordinator): self.auto_commit_interval = self.config['auto_commit_interval_ms'] / 1000 self.next_auto_commit_deadline = None self.completed_offset_commits = collections.deque() - self._offset_fetch_futures = dict() if self.config['default_offset_commit_callback'] is None: self.config['default_offset_commit_callback'] = self._default_offset_commit_callback @@ -122,21 +120,15 @@ class ConsumerCoordinator(BaseCoordinator): else: self.next_auto_commit_deadline = time.time() + self.auto_commit_interval - if self.config['metrics']: - self._consumer_sensors = ConsumerCoordinatorMetrics( - self.config['metrics'], self.config['metric_group_prefix'], self._subscription) - else: - self._consumer_sensors = None + self.consumer_sensors = ConsumerCoordinatorMetrics( + metrics, self.config['metric_group_prefix'], self._subscription) self._cluster.request_update() self._cluster.add_listener(WeakMethod(self._handle_metadata_update)) def __del__(self): if hasattr(self, '_cluster') and self._cluster: - try: - self._cluster.remove_listener(WeakMethod(self._handle_metadata_update)) - except TypeError: - pass + self._cluster.remove_listener(WeakMethod(self._handle_metadata_update)) super(ConsumerCoordinator, self).__del__() def protocol_type(self): @@ -208,8 +200,8 @@ class ConsumerCoordinator(BaseCoordinator): def _build_metadata_snapshot(self, subscription, cluster): metadata_snapshot = {} for topic in subscription.group_subscription(): - partitions = cluster.partitions_for_topic(topic) - metadata_snapshot[topic] = partitions or set() + partitions = cluster.partitions_for_topic(topic) or [] + metadata_snapshot[topic] = set(partitions) return metadata_snapshot def _lookup_assignor(self, name): @@ -230,6 +222,10 @@ class ConsumerCoordinator(BaseCoordinator): assignment = ConsumerProtocol.ASSIGNMENT.decode(member_assignment_bytes) + # set the flag to refresh last committed offsets + self._subscription.needs_fetch_committed_offsets = True + + # update partition assignment try: self._subscription.assign_from_subscribed(assignment.partitions()) except ValueError as e: @@ -250,16 +246,16 @@ class ConsumerCoordinator(BaseCoordinator): assigned, self.group_id) # execute the user's callback after rebalance - if self._subscription.rebalance_listener: + if self._subscription.listener: try: - self._subscription.rebalance_listener.on_partitions_assigned(assigned) + self._subscription.listener.on_partitions_assigned(assigned) except Exception: - log.exception("User provided rebalance listener %s for group %s" + log.exception("User provided listener %s for group %s" " failed on partition assignment: %s", - self._subscription.rebalance_listener, self.group_id, + self._subscription.listener, self.group_id, assigned) - def poll(self, timeout_ms=None): + def poll(self): """ Poll for coordinator events. Only applicable if group_id is set, and broker version supports GroupCoordinators. This ensures that the @@ -268,46 +264,33 @@ class ConsumerCoordinator(BaseCoordinator): periodic offset commits if they are enabled. 
""" if self.group_id is None: - return True + return - timer = Timer(timeout_ms) - try: - self._invoke_completed_offset_commit_callbacks() - if not self.ensure_coordinator_ready(timeout_ms=timer.timeout_ms): - log.debug('coordinator.poll: timeout in ensure_coordinator_ready; returning early') - return False + self._invoke_completed_offset_commit_callbacks() + self.ensure_coordinator_ready() - if self.config['api_version'] >= (0, 9) and self._subscription.partitions_auto_assigned(): - if self.need_rejoin(): - # due to a race condition between the initial metadata fetch and the - # initial rebalance, we need to ensure that the metadata is fresh - # before joining initially, and then request the metadata update. If - # metadata update arrives while the rebalance is still pending (for - # example, when the join group is still inflight), then we will lose - # track of the fact that we need to rebalance again to reflect the - # change to the topic subscription. Without ensuring that the - # metadata is fresh, any metadata update that changes the topic - # subscriptions and arrives while a rebalance is in progress will - # essentially be ignored. See KAFKA-3949 for the complete - # description of the problem. - if self._subscription.subscribed_pattern: - metadata_update = self._client.cluster.request_update() - self._client.poll(future=metadata_update, timeout_ms=timer.timeout_ms) - if not metadata_update.is_done: - log.debug('coordinator.poll: timeout updating metadata; returning early') - return False + if self.config['api_version'] >= (0, 9) and self._subscription.partitions_auto_assigned(): + if self.need_rejoin(): + # due to a race condition between the initial metadata fetch and the + # initial rebalance, we need to ensure that the metadata is fresh + # before joining initially, and then request the metadata update. If + # metadata update arrives while the rebalance is still pending (for + # example, when the join group is still inflight), then we will lose + # track of the fact that we need to rebalance again to reflect the + # change to the topic subscription. Without ensuring that the + # metadata is fresh, any metadata update that changes the topic + # subscriptions and arrives while a rebalance is in progress will + # essentially be ignored. See KAFKA-3949 for the complete + # description of the problem. 
+ if self._subscription.subscribed_pattern: + metadata_update = self._client.cluster.request_update() + self._client.poll(future=metadata_update) - if not self.ensure_active_group(timeout_ms=timer.timeout_ms): - log.debug('coordinator.poll: timeout in ensure_active_group; returning early') - return False + self.ensure_active_group() - self.poll_heartbeat() + self.poll_heartbeat() - self._maybe_auto_commit_offsets_async() - return True - - except Errors.KafkaTimeoutError: - return False + self._maybe_auto_commit_offsets_async() def time_to_next_poll(self): """Return seconds (float) remaining until :meth:`.poll` should be called again""" @@ -357,21 +340,21 @@ class ConsumerCoordinator(BaseCoordinator): group_assignment[member_id] = assignment return group_assignment - def _on_join_prepare(self, generation, member_id, timeout_ms=None): + def _on_join_prepare(self, generation, member_id): # commit offsets prior to rebalance if auto-commit enabled - self._maybe_auto_commit_offsets_sync(timeout_ms=timeout_ms) + self._maybe_auto_commit_offsets_sync() # execute the user's callback before rebalance log.info("Revoking previously assigned partitions %s for group %s", self._subscription.assigned_partitions(), self.group_id) - if self._subscription.rebalance_listener: + if self._subscription.listener: try: revoked = set(self._subscription.assigned_partitions()) - self._subscription.rebalance_listener.on_partitions_revoked(revoked) + self._subscription.listener.on_partitions_revoked(revoked) except Exception: - log.exception("User provided subscription rebalance listener %s" + log.exception("User provided subscription listener %s" " for group %s failed on_partitions_revoked", - self._subscription.rebalance_listener, self.group_id) + self._subscription.listener, self.group_id) self._is_leader = False self._subscription.reset_group_subscription() @@ -400,19 +383,17 @@ class ConsumerCoordinator(BaseCoordinator): return super(ConsumerCoordinator, self).need_rejoin() - def refresh_committed_offsets_if_needed(self, timeout_ms=None): + def refresh_committed_offsets_if_needed(self): """Fetch committed offsets for assigned partitions.""" - missing_fetch_positions = set(self._subscription.missing_fetch_positions()) - try: - offsets = self.fetch_committed_offsets(missing_fetch_positions, timeout_ms=timeout_ms) - except Errors.KafkaTimeoutError: - return False - for partition, offset in six.iteritems(offsets): - log.debug("Setting offset for partition %s to the committed offset %s", partition, offset.offset) - self._subscription.seek(partition, offset.offset) - return True + if self._subscription.needs_fetch_committed_offsets: + offsets = self.fetch_committed_offsets(self._subscription.assigned_partitions()) + for partition, offset in six.iteritems(offsets): + # verify assignment is still active + if self._subscription.is_assigned(partition): + self._subscription.assignment[partition].committed = offset + self._subscription.needs_fetch_committed_offsets = False - def fetch_committed_offsets(self, partitions, timeout_ms=None): + def fetch_committed_offsets(self, partitions): """Fetch the current committed offsets for specified partitions Arguments: @@ -420,45 +401,26 @@ class ConsumerCoordinator(BaseCoordinator): Returns: dict: {TopicPartition: OffsetAndMetadata} - - Raises: - KafkaTimeoutError if timeout_ms provided """ if not partitions: return {} - future_key = frozenset(partitions) - timer = Timer(timeout_ms) while True: - if not self.ensure_coordinator_ready(timeout_ms=timer.timeout_ms): - timer.maybe_raise() + 
self.ensure_coordinator_ready() # contact coordinator to fetch committed offsets - if future_key in self._offset_fetch_futures: - future = self._offset_fetch_futures[future_key] - else: - future = self._send_offset_fetch_request(partitions) - self._offset_fetch_futures[future_key] = future + future = self._send_offset_fetch_request(partitions) + self._client.poll(future=future) - self._client.poll(future=future, timeout_ms=timer.timeout_ms) + if future.succeeded(): + return future.value - if future.is_done: - del self._offset_fetch_futures[future_key] + if not future.retriable(): + raise future.exception # pylint: disable-msg=raising-bad-type - if future.succeeded(): - return future.value + time.sleep(self.config['retry_backoff_ms'] / 1000) - elif not future.retriable(): - raise future.exception # pylint: disable-msg=raising-bad-type - - # future failed but is retriable, or is not done yet - if timer.timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']: - time.sleep(self.config['retry_backoff_ms'] / 1000) - else: - time.sleep(timer.timeout_ms / 1000) - timer.maybe_raise() - - def close(self, autocommit=True, timeout_ms=None): + def close(self, autocommit=True): """Close the coordinator, leave the current group, and reset local generation / member_id. @@ -469,14 +431,14 @@ class ConsumerCoordinator(BaseCoordinator): """ try: if autocommit: - self._maybe_auto_commit_offsets_sync(timeout_ms=timeout_ms) + self._maybe_auto_commit_offsets_sync() finally: - super(ConsumerCoordinator, self).close(timeout_ms=timeout_ms) + super(ConsumerCoordinator, self).close() def _invoke_completed_offset_commit_callbacks(self): while self.completed_offset_commits: - callback, offsets, res_or_exc = self.completed_offset_commits.popleft() - callback(offsets, res_or_exc) + callback, offsets, exception = self.completed_offset_commits.popleft() + callback(offsets, exception) def commit_offsets_async(self, offsets, callback=None): """Commit specific offsets asynchronously. @@ -516,18 +478,18 @@ class ConsumerCoordinator(BaseCoordinator): return future def _do_commit_offsets_async(self, offsets, callback=None): - if self.config['api_version'] < (0, 8, 1): - raise Errors.UnsupportedVersionError('OffsetCommitRequest requires 0.8.1+ broker') + assert self.config['api_version'] >= (0, 8, 1), 'Unsupported Broker API' assert all(map(lambda k: isinstance(k, TopicPartition), offsets)) assert all(map(lambda v: isinstance(v, OffsetAndMetadata), offsets.values())) if callback is None: callback = self.config['default_offset_commit_callback'] + self._subscription.needs_fetch_committed_offsets = True future = self._send_offset_commit_request(offsets) future.add_both(lambda res: self.completed_offset_commits.appendleft((callback, offsets, res))) return future - def commit_offsets_sync(self, offsets, timeout_ms=None): + def commit_offsets_sync(self, offsets): """Commit specific offsets synchronously. 
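# Sketch: the async-commit path above queues (callback, offsets, result) tuples and drains
# them later on the caller's thread via _invoke_completed_offset_commit_callbacks(). At the
# public API level that is the callback handed to commit_async(). A hedged usage sketch
# assuming kafka-python's KafkaConsumer.commit_async(offsets=None, callback=None); broker,
# topic, and group names are placeholders.

from kafka import KafkaConsumer

def on_commit(offsets, response_or_exc):
    # run on a later poll()/commit call, mirroring the completed_offset_commits queue above
    if isinstance(response_or_exc, Exception):
        print('commit failed for %s: %s' % (offsets, response_or_exc))
    else:
        print('committed %s' % (offsets,))

consumer = KafkaConsumer('example-topic',                      # placeholder topic
                         bootstrap_servers='localhost:9092',   # placeholder broker
                         group_id='example-group',
                         enable_auto_commit=False)
for record in consumer:
    # ... process record ...
    consumer.commit_async(callback=on_commit)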
This method will retry until the commit completes successfully or an @@ -538,8 +500,7 @@ class ConsumerCoordinator(BaseCoordinator): Raises error on failure """ - if self.config['api_version'] < (0, 8, 1): - raise Errors.UnsupportedVersionError('OffsetCommitRequest requires 0.8.1+ broker') + assert self.config['api_version'] >= (0, 8, 1), 'Unsupported Broker API' assert all(map(lambda k: isinstance(k, TopicPartition), offsets)) assert all(map(lambda v: isinstance(v, OffsetAndMetadata), offsets.values())) @@ -547,31 +508,24 @@ class ConsumerCoordinator(BaseCoordinator): if not offsets: return - timer = Timer(timeout_ms) while True: - self.ensure_coordinator_ready(timeout_ms=timer.timeout_ms) + self.ensure_coordinator_ready() future = self._send_offset_commit_request(offsets) - self._client.poll(future=future, timeout_ms=timer.timeout_ms) + self._client.poll(future=future) - if future.is_done: - if future.succeeded(): - return future.value + if future.succeeded(): + return future.value - elif not future.retriable(): - raise future.exception # pylint: disable-msg=raising-bad-type + if not future.retriable(): + raise future.exception # pylint: disable-msg=raising-bad-type - # future failed but is retriable, or it is still pending - if timer.timeout_ms is None or timer.timeout_ms > self.config['retry_backoff_ms']: - time.sleep(self.config['retry_backoff_ms'] / 1000) - else: - time.sleep(timer.timeout_ms / 1000) - timer.maybe_raise() + time.sleep(self.config['retry_backoff_ms'] / 1000) - def _maybe_auto_commit_offsets_sync(self, timeout_ms=None): + def _maybe_auto_commit_offsets_sync(self): if self.config['enable_auto_commit']: try: - self.commit_offsets_sync(self._subscription.all_consumed_offsets(), timeout_ms=timeout_ms) + self.commit_offsets_sync(self._subscription.all_consumed_offsets()) # The three main group membership errors are known and should not # require a stacktrace -- just a warning @@ -599,8 +553,7 @@ class ConsumerCoordinator(BaseCoordinator): Returns: Future: indicating whether the commit was successful or not """ - if self.config['api_version'] < (0, 8, 1): - raise Errors.UnsupportedVersionError('OffsetCommitRequest requires 0.8.1+ broker') + assert self.config['api_version'] >= (0, 8, 1), 'Unsupported Broker API' assert all(map(lambda k: isinstance(k, TopicPartition), offsets)) assert all(map(lambda v: isinstance(v, OffsetAndMetadata), offsets.values())) @@ -610,46 +563,31 @@ class ConsumerCoordinator(BaseCoordinator): node_id = self.coordinator() if node_id is None: - return Future().failure(Errors.CoordinatorNotAvailableError) + return Future().failure(Errors.GroupCoordinatorNotAvailableError) - # Verify node is ready - if not self._client.ready(node_id, metadata_priority=False): - log.debug("Node %s not ready -- failing offset commit request", - node_id) - return Future().failure(Errors.NodeNotReadyError) # create the offset commit request offset_data = collections.defaultdict(dict) for tp, offset in six.iteritems(offsets): offset_data[tp.topic][tp.partition] = offset - version = self._client.api_version(OffsetCommitRequest, max_version=6) - if version > 1 and self._subscription.partitions_auto_assigned(): - generation = self.generation_if_stable() + if self._subscription.partitions_auto_assigned(): + generation = self.generation() else: generation = Generation.NO_GENERATION # if the generation is None, we are not part of an active group # (and we expect to be). 
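# Sketch: commit_offsets_sync() above loops in a fixed shape: block on the request future,
# return on success, re-raise non-retriable errors, otherwise back off retry_backoff_ms and
# try again. The same shape in isolation (generic illustration, not the library code):

import time

def run_sync_with_retries(attempt, retry_backoff_ms=100, timeout_ms=None):
    """Run attempt() until it succeeds, raising non-retriable errors immediately."""
    deadline = None if timeout_ms is None else time.time() + timeout_ms / 1000.0
    while True:
        try:
            return attempt()
        except Exception as exc:
            if not getattr(exc, 'retriable', False):
                raise                                   # permanent failure, surface it
            if deadline is not None and time.time() >= deadline:
                raise                                   # budget exhausted
            time.sleep(retry_backoff_ms / 1000.0)       # transient failure, back off

class TransientError(Exception):
    retriable = True

_attempts = iter([TransientError(), TransientError(), 'ok'])

def attempt():
    result = next(_attempts)
    if isinstance(result, Exception):
        raise result
    return result

assert run_sync_with_retries(attempt, retry_backoff_ms=1) == 'ok'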
The only thing we can do is fail the commit # and let the user rejoin the group in poll() - if generation is None: - log.info("Failing OffsetCommit request since the consumer is not part of an active group") - if self.rebalance_in_progress(): - # if the client knows it is already rebalancing, we can use RebalanceInProgressError instead of - # CommitFailedError to indicate this is not a fatal error - return Future().failure(Errors.RebalanceInProgressError( - "Offset commit cannot be completed since the" - " consumer is undergoing a rebalance for auto partition assignment. You can try completing the rebalance" - " by calling poll() and then retry the operation.")) - else: - return Future().failure(Errors.CommitFailedError( - "Offset commit cannot be completed since the" - " consumer is not part of an active group for auto partition assignment; it is likely that the consumer" - " was kicked out of the group.")) + if self.config['api_version'] >= (0, 9) and generation is None: + return Future().failure(Errors.CommitFailedError()) - if version == 0: - request = OffsetCommitRequest[version]( + if self.config['api_version'] >= (0, 9): + request = OffsetCommitRequest[2]( self.group_id, + generation.generation_id, + generation.member_id, + OffsetCommitRequest[2].DEFAULT_RETENTION_TIME, [( topic, [( partition, @@ -658,28 +596,21 @@ class ConsumerCoordinator(BaseCoordinator): ) for partition, offset in six.iteritems(partitions)] ) for topic, partitions in six.iteritems(offset_data)] ) - elif version == 1: - request = OffsetCommitRequest[version]( - self.group_id, - # This api version was only used in v0.8.2, prior to join group apis - # so this always ends up as NO_GENERATION - generation.generation_id, - generation.member_id, + elif self.config['api_version'] >= (0, 8, 2): + request = OffsetCommitRequest[1]( + self.group_id, -1, '', [( topic, [( partition, offset.offset, - -1, # timestamp, unused + -1, offset.metadata ) for partition, offset in six.iteritems(partitions)] ) for topic, partitions in six.iteritems(offset_data)] ) - elif version <= 4: - request = OffsetCommitRequest[version]( + elif self.config['api_version'] >= (0, 8, 1): + request = OffsetCommitRequest[0]( self.group_id, - generation.generation_id, - generation.member_id, - OffsetCommitRequest[version].DEFAULT_RETENTION_TIME, [( topic, [( partition, @@ -688,33 +619,6 @@ class ConsumerCoordinator(BaseCoordinator): ) for partition, offset in six.iteritems(partitions)] ) for topic, partitions in six.iteritems(offset_data)] ) - elif version <= 5: - request = OffsetCommitRequest[version]( - self.group_id, - generation.generation_id, - generation.member_id, - [( - topic, [( - partition, - offset.offset, - offset.metadata - ) for partition, offset in six.iteritems(partitions)] - ) for topic, partitions in six.iteritems(offset_data)] - ) - else: - request = OffsetCommitRequest[version]( - self.group_id, - generation.generation_id, - generation.member_id, - [( - topic, [( - partition, - offset.offset, - offset.leader_epoch, - offset.metadata - ) for partition, offset in six.iteritems(partitions)] - ) for topic, partitions in six.iteritems(offset_data)] - ) log.debug("Sending offset-commit request with %s for group %s to %s", offsets, self.group_id, node_id) @@ -726,10 +630,8 @@ class ConsumerCoordinator(BaseCoordinator): return future def _handle_offset_commit_response(self, offsets, future, send_time, response): - log.debug("Received OffsetCommitResponse: %s", response) # TODO look at adding request_latency_ms to response (like java kafka) - 
if self._consumer_sensors: - self._consumer_sensors.commit_latency.record((time.time() - send_time) * 1000) + self.consumer_sensors.commit_latency.record((time.time() - send_time) * 1000) unauthorized_topics = set() for topic, partitions in response.topics: @@ -741,6 +643,8 @@ class ConsumerCoordinator(BaseCoordinator): if error_type is Errors.NoError: log.debug("Group %s committed offset %s for partition %s", self.group_id, offset, tp) + if self._subscription.is_assigned(tp): + self._subscription.assignment[tp].committed = offset elif error_type is Errors.GroupAuthorizationFailedError: log.error("Not authorized to commit offsets for group %s", self.group_id) @@ -755,38 +659,29 @@ class ConsumerCoordinator(BaseCoordinator): " %s", self.group_id, tp, error_type.__name__) future.failure(error_type()) return - elif error_type is Errors.CoordinatorLoadInProgressError: + elif error_type is Errors.GroupLoadInProgressError: # just retry log.debug("OffsetCommit for group %s failed: %s", self.group_id, error_type.__name__) future.failure(error_type(self.group_id)) return - elif error_type in (Errors.CoordinatorNotAvailableError, - Errors.NotCoordinatorError, + elif error_type in (Errors.GroupCoordinatorNotAvailableError, + Errors.NotCoordinatorForGroupError, Errors.RequestTimedOutError): log.debug("OffsetCommit for group %s failed: %s", self.group_id, error_type.__name__) self.coordinator_dead(error_type()) future.failure(error_type(self.group_id)) return - elif error_type is Errors.RebalanceInProgressError: - # Consumer never tries to commit offset in between join-group and sync-group, - # and hence on broker-side it is not expected to see a commit offset request - # during CompletingRebalance phase; if it ever happens then broker would return - # this error. In this case we should just treat as a fatal CommitFailed exception. - # However, we do not need to reset generations and just request re-join, such that - # if the caller decides to proceed and poll, it would still try to proceed and re-join normally. 
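# Sketch: the handler above walks [(topic, [(partition, error_code), ...]), ...], resolves
# each code in the spirit of kafka.errors.for_code(), and picks an action per class (retry,
# mark coordinator dead, rejoin, fail). A compact stand-in for the errno-to-class lookup and
# the walk; these error classes are simplified examples, not the real registry.

class BrokerErrorStub(Exception):
    errno, retriable = None, False

class NoErrorStub(BrokerErrorStub):
    errno = 0

class CoordinatorLoadingStub(BrokerErrorStub):
    errno, retriable = 14, True

_BY_ERRNO = {cls.errno: cls for cls in BrokerErrorStub.__subclasses__()}

def for_code(error_code):
    # unknown codes fall back to the generic class so the errno is never dropped
    return _BY_ERRNO.get(error_code, BrokerErrorStub)

def first_failed_partition(response_topics):
    for topic, partitions in response_topics:
        for partition, error_code in partitions:
            error_cls = for_code(error_code)
            if error_cls is not NoErrorStub:
                return topic, partition, error_cls
    return None

resp = [('example-topic', [(0, 0), (1, 14)])]
assert first_failed_partition(resp) == ('example-topic', 1, CoordinatorLoadingStub)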
- self.request_rejoin() - future.failure(Errors.CommitFailedError(error_type())) - return elif error_type in (Errors.UnknownMemberIdError, - Errors.IllegalGenerationError): - # need reset generation and re-join group + Errors.IllegalGenerationError, + Errors.RebalanceInProgressError): + # need to re-join group error = error_type(self.group_id) - log.warning("OffsetCommit for group %s failed: %s", - self.group_id, error) + log.debug("OffsetCommit for group %s failed: %s", + self.group_id, error) self.reset_generation() - future.failure(Errors.CommitFailedError(error_type())) + future.failure(Errors.CommitFailedError()) return else: log.error("Group %s failed to commit partition %s at offset" @@ -814,18 +709,17 @@ class ConsumerCoordinator(BaseCoordinator): Returns: Future: resolves to dict of offsets: {TopicPartition: OffsetAndMetadata} """ - if self.config['api_version'] < (0, 8, 1): - raise Errors.UnsupportedVersionError('OffsetFetchRequest requires 0.8.1+ broker') + assert self.config['api_version'] >= (0, 8, 1), 'Unsupported Broker API' assert all(map(lambda k: isinstance(k, TopicPartition), partitions)) if not partitions: return Future().success({}) node_id = self.coordinator() if node_id is None: - return Future().failure(Errors.CoordinatorNotAvailableError) + return Future().failure(Errors.GroupCoordinatorNotAvailableError) # Verify node is ready - if not self._client.ready(node_id, metadata_priority=False): + if not self._client.ready(node_id): log.debug("Node %s not ready -- failing offset fetch request", node_id) return Future().failure(Errors.NodeNotReadyError) @@ -837,13 +731,16 @@ class ConsumerCoordinator(BaseCoordinator): for tp in partitions: topic_partitions[tp.topic].add(tp.partition) - version = self._client.api_version(OffsetFetchRequest, max_version=5) - # Starting in version 2, the request can contain a null topics array to indicate that offsets should be fetched - # TODO: support - request = OffsetFetchRequest[version]( - self.group_id, - list(topic_partitions.items()) - ) + if self.config['api_version'] >= (0, 8, 2): + request = OffsetFetchRequest[1]( + self.group_id, + list(topic_partitions.items()) + ) + else: + request = OffsetFetchRequest[0]( + self.group_id, + list(topic_partitions.items()) + ) # send the request with a callback future = Future() @@ -853,46 +750,21 @@ class ConsumerCoordinator(BaseCoordinator): return future def _handle_offset_fetch_response(self, future, response): - log.debug("Received OffsetFetchResponse: %s", response) - if response.API_VERSION >= 2 and response.error_code != Errors.NoError.errno: - error_type = Errors.for_code(response.error_code) - log.debug("Offset fetch failed: %s", error_type.__name__) - error = error_type() - if error_type is Errors.CoordinatorLoadInProgressError: - # Retry - future.failure(error) - elif error_type is Errors.NotCoordinatorError: - # re-discover the coordinator and retry - self.coordinator_dead(error) - future.failure(error) - elif error_type is Errors.GroupAuthorizationFailedError: - future.failure(error) - else: - log.error("Unknown error fetching offsets: %s", error) - future.failure(error) - return - offsets = {} for topic, partitions in response.topics: - for partition_data in partitions: - partition, offset = partition_data[:2] - if response.API_VERSION >= 5: - leader_epoch, metadata, error_code = partition_data[2:] - else: - metadata, error_code = partition_data[2:] - leader_epoch = -1 # noqa: F841 + for partition, offset, metadata, error_code in partitions: tp = TopicPartition(topic, partition) 
error_type = Errors.for_code(error_code) if error_type is not Errors.NoError: error = error_type() log.debug("Group %s failed to fetch offset for partition" " %s: %s", self.group_id, tp, error) - if error_type is Errors.CoordinatorLoadInProgressError: + if error_type is Errors.GroupLoadInProgressError: # just retry future.failure(error) - elif error_type is Errors.NotCoordinatorError: + elif error_type is Errors.NotCoordinatorForGroupError: # re-discover the coordinator and retry - self.coordinator_dead(error) + self.coordinator_dead(error_type()) future.failure(error) elif error_type is Errors.UnknownTopicOrPartitionError: log.warning("OffsetFetchRequest -- unknown topic %s" @@ -907,41 +779,34 @@ class ConsumerCoordinator(BaseCoordinator): elif offset >= 0: # record the position with the offset # (-1 indicates no committed offset to fetch) - # TODO: save leader_epoch - offsets[tp] = OffsetAndMetadata(offset, metadata, -1) + offsets[tp] = OffsetAndMetadata(offset, metadata) else: log.debug("Group %s has no committed offset for partition" " %s", self.group_id, tp) future.success(offsets) - def _default_offset_commit_callback(self, offsets, res_or_exc): - if isinstance(res_or_exc, Exception): + def _default_offset_commit_callback(self, offsets, exception): + if exception is not None: + log.error("Offset commit failed: %s", exception) + + def _commit_offsets_async_on_complete(self, offsets, exception): + if exception is not None: log.warning("Auto offset commit failed for group %s: %s", - self.group_id, res_or_exc) + self.group_id, exception) + if getattr(exception, 'retriable', False): + self.next_auto_commit_deadline = min(time.time() + self.config['retry_backoff_ms'] / 1000, self.next_auto_commit_deadline) else: log.debug("Completed autocommit of offsets %s for group %s", offsets, self.group_id) - def _commit_offsets_async_on_complete(self, offsets, res_or_exc): - if isinstance(res_or_exc, Exception) and getattr(res_or_exc, 'retriable', False): - self.next_auto_commit_deadline = min(time.time() + self.config['retry_backoff_ms'] / 1000, self.next_auto_commit_deadline) - self.config['default_offset_commit_callback'](offsets, res_or_exc) - def _maybe_auto_commit_offsets_async(self): if self.config['enable_auto_commit']: if self.coordinator_unknown(): self.next_auto_commit_deadline = time.time() + self.config['retry_backoff_ms'] / 1000 elif time.time() > self.next_auto_commit_deadline: self.next_auto_commit_deadline = time.time() + self.auto_commit_interval - self._do_auto_commit_offsets_async() - - def maybe_auto_commit_offsets_now(self): - if self.config['enable_auto_commit'] and not self.coordinator_unknown(): - self._do_auto_commit_offsets_async() - - def _do_auto_commit_offsets_async(self): - self.commit_offsets_async(self._subscription.all_consumed_offsets(), - self._commit_offsets_async_on_complete) + self.commit_offsets_async(self._subscription.all_consumed_offsets(), + self._commit_offsets_async_on_complete) class ConsumerCoordinatorMetrics(object): diff --git a/venv/lib/python3.12/site-packages/kafka/coordinator/heartbeat.py b/venv/lib/python3.12/site-packages/kafka/coordinator/heartbeat.py index edc9f4a..2f5930b 100644 --- a/venv/lib/python3.12/site-packages/kafka/coordinator/heartbeat.py +++ b/venv/lib/python3.12/site-packages/kafka/coordinator/heartbeat.py @@ -1,13 +1,8 @@ from __future__ import absolute_import, division import copy -import logging import time -from kafka.errors import KafkaConfigurationError - -log = logging.getLogger(__name__) - class Heartbeat(object): 
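# Sketch: the offset-fetch handler removed above copes with response tuples whose arity grew
# across protocol versions (a leader_epoch field appears before metadata from v5 on) by
# slicing instead of fixed unpacking. The same idea in isolation; the function name is
# illustrative, not the library's.

def parse_offset_fetch_partition(entry, api_version):
    """Return (partition, offset, leader_epoch, metadata, error_code) for either shape."""
    partition, offset = entry[:2]
    if api_version >= 5:
        leader_epoch, metadata, error_code = entry[2:]
    else:
        metadata, error_code = entry[2:]
        leader_epoch = -1            # field not present on older versions
    return partition, offset, leader_epoch, metadata, error_code

assert parse_offset_fetch_partition((0, 42, '', 0), api_version=1) == (0, 42, -1, '', 0)
assert parse_offset_fetch_partition((0, 42, 7, '', 0), api_version=5) == (0, 42, 7, '', 0)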
DEFAULT_CONFIG = { @@ -25,13 +20,9 @@ class Heartbeat(object): self.config[key] = configs[key] if self.config['group_id'] is not None: - if self.config['heartbeat_interval_ms'] >= self.config['session_timeout_ms']: - raise KafkaConfigurationError('Heartbeat interval must be lower than the session timeout (%s v %s)' % ( - self.config['heartbeat_interval_ms'], self.config['session_timeout_ms'])) - if self.config['heartbeat_interval_ms'] > (self.config['session_timeout_ms'] / 3): - log.warning('heartbeat_interval_ms is high relative to session_timeout_ms (%s v %s).' - ' Recommend heartbeat interval less than 1/3rd of session timeout', - self.config['heartbeat_interval_ms'], self.config['session_timeout_ms']) + assert (self.config['heartbeat_interval_ms'] + <= self.config['session_timeout_ms']), ( + 'Heartbeat interval must be lower than the session timeout') self.last_send = -1 * float('inf') self.last_receive = -1 * float('inf') @@ -75,10 +66,3 @@ class Heartbeat(object): def poll_timeout_expired(self): return (time.time() - self.last_poll) > (self.config['max_poll_interval_ms'] / 1000) - - def __str__(self): - return ("").format(**self.config) diff --git a/venv/lib/python3.12/site-packages/kafka/errors.py b/venv/lib/python3.12/site-packages/kafka/errors.py index ac4eadf..b33cf51 100644 --- a/venv/lib/python3.12/site-packages/kafka/errors.py +++ b/venv/lib/python3.12/site-packages/kafka/errors.py @@ -16,39 +16,23 @@ class KafkaError(RuntimeError): super(KafkaError, self).__str__()) -class Cancelled(KafkaError): - retriable = True - - -class CommitFailedError(KafkaError): - def __init__(self, *args): - if not args: - args = ("Commit cannot be completed since the group has already" - " rebalanced and assigned the partitions to another member.",) - super(CommitFailedError, self).__init__(*args) +class IllegalStateError(KafkaError): + pass class IllegalArgumentError(KafkaError): pass -class IllegalStateError(KafkaError): - pass - - -class IncompatibleBrokerVersion(KafkaError): - pass - - -class KafkaConfigurationError(KafkaError): - pass - - -class KafkaConnectionError(KafkaError): +class NoBrokersAvailable(KafkaError): retriable = True invalid_metadata = True +class NodeNotReadyError(KafkaError): + retriable = True + + class KafkaProtocolError(KafkaError): retriable = True @@ -57,37 +41,20 @@ class CorrelationIdError(KafkaProtocolError): retriable = True -class KafkaTimeoutError(KafkaError): +class Cancelled(KafkaError): retriable = True -class MetadataEmptyBrokerList(KafkaError): +class TooManyInFlightRequests(KafkaError): retriable = True -class NoBrokersAvailable(KafkaError): - retriable = True - invalid_metadata = True - - -class NoOffsetForPartitionError(KafkaError): - pass - - -class NodeNotReadyError(KafkaError): - retriable = True - - -class QuotaViolationError(KafkaError): - pass - - class StaleMetadata(KafkaError): retriable = True invalid_metadata = True -class TooManyInFlightRequests(KafkaError): +class MetadataEmptyBrokerList(KafkaError): retriable = True @@ -95,10 +62,33 @@ class UnrecognizedBrokerVersion(KafkaError): pass -class UnsupportedCodecError(KafkaError): +class IncompatibleBrokerVersion(KafkaError): pass +class CommitFailedError(KafkaError): + def __init__(self, *args, **kwargs): + super(CommitFailedError, self).__init__( + """Commit cannot be completed since the group has already + rebalanced and assigned the partitions to another member. 
+ This means that the time between subsequent calls to poll() + was longer than the configured max_poll_interval_ms, which + typically implies that the poll loop is spending too much + time message processing. You can address this either by + increasing the rebalance timeout with max_poll_interval_ms, + or by reducing the maximum size of batches returned in poll() + with max_poll_records. + """, *args, **kwargs) + + +class AuthenticationMethodNotSupported(KafkaError): + pass + + +class AuthenticationFailedError(KafkaError): + retriable = False + + class BrokerResponseError(KafkaError): errno = None message = None @@ -111,10 +101,6 @@ class BrokerResponseError(KafkaError): super(BrokerResponseError, self).__str__()) -class AuthorizationError(BrokerResponseError): - pass - - class NoError(BrokerResponseError): errno = 0 message = 'NO_ERROR' @@ -134,14 +120,14 @@ class OffsetOutOfRangeError(BrokerResponseError): ' maintained by the server for the given topic/partition.') -class CorruptRecordError(BrokerResponseError): +class CorruptRecordException(BrokerResponseError): errno = 2 message = 'CORRUPT_MESSAGE' description = ('This message has failed its CRC checksum, exceeds the' ' valid size, or is otherwise corrupt.') # Backward compatibility -CorruptRecordException = CorruptRecordError +InvalidMessageError = CorruptRecordException class UnknownTopicOrPartitionError(BrokerResponseError): @@ -200,8 +186,7 @@ class ReplicaNotAvailableError(BrokerResponseError): message = 'REPLICA_NOT_AVAILABLE' description = ('If replica is expected on a broker, but is not (this can be' ' safely ignored).') - retriable = True - invalid_metadata = True + class MessageSizeTooLargeError(BrokerResponseError): errno = 10 @@ -225,35 +210,39 @@ class OffsetMetadataTooLargeError(BrokerResponseError): ' offset metadata.') -class NetworkExceptionError(BrokerResponseError): +# TODO is this deprecated? 
https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-ErrorCodes +class StaleLeaderEpochCodeError(BrokerResponseError): errno = 13 - message = 'NETWORK_EXCEPTION' - retriable = True - invalid_metadata = True + message = 'STALE_LEADER_EPOCH_CODE' -class CoordinatorLoadInProgressError(BrokerResponseError): +class GroupLoadInProgressError(BrokerResponseError): errno = 14 - message = 'COORDINATOR_LOAD_IN_PROGRESS' - description = ('The broker returns this error code for txn or group requests,' - ' when the coordinator is loading and hence cant process requests') + message = 'OFFSETS_LOAD_IN_PROGRESS' + description = ('The broker returns this error code for an offset fetch' + ' request if it is still loading offsets (after a leader' + ' change for that offsets topic partition), or in response' + ' to group membership requests (such as heartbeats) when' + ' group metadata is being loaded by the coordinator.') retriable = True -class CoordinatorNotAvailableError(BrokerResponseError): +class GroupCoordinatorNotAvailableError(BrokerResponseError): errno = 15 - message = 'COORDINATOR_NOT_AVAILABLE' - description = ('The broker returns this error code for consumer and transaction' + message = 'CONSUMER_COORDINATOR_NOT_AVAILABLE' + description = ('The broker returns this error code for group coordinator' + ' requests, offset commits, and most group management' ' requests if the offsets topic has not yet been created, or' - ' if the group/txn coordinator is not active.') + ' if the group coordinator is not active.') retriable = True -class NotCoordinatorError(BrokerResponseError): +class NotCoordinatorForGroupError(BrokerResponseError): errno = 16 - message = 'NOT_COORDINATOR' - description = ('The broker returns this error code if it is not the correct' - ' coordinator for the specified consumer or transaction group') + message = 'NOT_COORDINATOR_FOR_CONSUMER' + description = ('The broker returns this error code if it receives an offset' + ' fetch or commit request for a group that it is not a' + ' coordinator for.') retriable = True @@ -350,21 +339,21 @@ class InvalidCommitOffsetSizeError(BrokerResponseError): ' because of oversize metadata.') -class TopicAuthorizationFailedError(AuthorizationError): +class TopicAuthorizationFailedError(BrokerResponseError): errno = 29 message = 'TOPIC_AUTHORIZATION_FAILED' description = ('Returned by the broker when the client is not authorized to' ' access the requested topic.') -class GroupAuthorizationFailedError(AuthorizationError): +class GroupAuthorizationFailedError(BrokerResponseError): errno = 30 message = 'GROUP_AUTHORIZATION_FAILED' description = ('Returned by the broker when the client is not authorized to' ' access a particular groupId.') -class ClusterAuthorizationFailedError(AuthorizationError): +class ClusterAuthorizationFailedError(BrokerResponseError): errno = 31 message = 'CLUSTER_AUTHORIZATION_FAILED' description = ('Returned by the broker when the client is not authorized to' @@ -452,597 +441,65 @@ class PolicyViolationError(BrokerResponseError): errno = 44 message = 'POLICY_VIOLATION' description = 'Request parameters do not satisfy the configured policy.' - retriable = False - - -class OutOfOrderSequenceNumberError(BrokerResponseError): - errno = 45 - message = 'OUT_OF_ORDER_SEQUENCE_NUMBER' - description = 'The broker received an out of order sequence number.' 
- retriable = False - - -class DuplicateSequenceNumberError(BrokerResponseError): - errno = 46 - message = 'DUPLICATE_SEQUENCE_NUMBER' - description = 'The broker received a duplicate sequence number.' - retriable = False - - -class InvalidProducerEpochError(BrokerResponseError): - errno = 47 - message = 'INVALID_PRODUCER_EPOCH' - description = 'Producer attempted to produce with an old epoch.' - retriable = False - - -class InvalidTxnStateError(BrokerResponseError): - errno = 48 - message = 'INVALID_TXN_STATE' - description = 'The producer attempted a transactional operation in an invalid state.' - retriable = False - - -class InvalidProducerIdMappingError(BrokerResponseError): - errno = 49 - message = 'INVALID_PRODUCER_ID_MAPPING' - description = 'The producer attempted to use a producer id which is not currently assigned to its transactional id.' - retriable = False - - -class InvalidTransactionTimeoutError(BrokerResponseError): - errno = 50 - message = 'INVALID_TRANSACTION_TIMEOUT' - description = 'The transaction timeout is larger than the maximum value allowed by the broker (as configured by transaction.max.timeout.ms).' - retriable = False - - -class ConcurrentTransactionsError(BrokerResponseError): - errno = 51 - message = 'CONCURRENT_TRANSACTIONS' - description = 'The producer attempted to update a transaction while another concurrent operation on the same transaction was ongoing.' - retriable = True - - -class TransactionCoordinatorFencedError(BrokerResponseError): - errno = 52 - message = 'TRANSACTION_COORDINATOR_FENCED' - description = 'Indicates that the transaction coordinator sending a WriteTxnMarker is no longer the current coordinator for a given producer.' - retriable = False - - -class TransactionalIdAuthorizationFailedError(AuthorizationError): - errno = 53 - message = 'TRANSACTIONAL_ID_AUTHORIZATION_FAILED' - description = 'Transactional Id authorization failed.' - retriable = False class SecurityDisabledError(BrokerResponseError): errno = 54 message = 'SECURITY_DISABLED' description = 'Security features are disabled.' - retriable = False - - -class OperationNotAttemptedError(BrokerResponseError): - errno = 55 - message = 'OPERATION_NOT_ATTEMPTED' - description = 'The broker did not attempt to execute this operation. This may happen for batched RPCs where some operations in the batch failed, causing the broker to respond without trying the rest.' - retriable = False - - -class KafkaStorageError(BrokerResponseError): - errno = 56 - message = 'KAFKA_STORAGE_ERROR' - description = 'Disk error when trying to access log file on the disk.' - retriable = True - invalid_metadata = True - - -class LogDirNotFoundError(BrokerResponseError): - errno = 57 - message = 'LOG_DIR_NOT_FOUND' - description = 'The user-specified log directory is not found in the broker config.' - retriable = False - - -class SaslAuthenticationFailedError(BrokerResponseError): - errno = 58 - message = 'SASL_AUTHENTICATION_FAILED' - description = 'SASL Authentication failed.' - retriable = False - - -class UnknownProducerIdError(BrokerResponseError): - errno = 59 - message = 'UNKNOWN_PRODUCER_ID' - description = 'This exception is raised by the broker if it could not locate the producer metadata associated with the producerId in question. This could happen if, for instance, the producer\'s records were deleted because their retention time had elapsed. 
Once the last records of the producerId are removed, the producer\'s metadata is removed from the broker, and future appends by the producer will return this exception.' - retriable = False - - -class ReassignmentInProgressError(BrokerResponseError): - errno = 60 - message = 'REASSIGNMENT_IN_PROGRESS' - description = 'A partition reassignment is in progress.' - retriable = False - - -class DelegationTokenAuthDisabledError(BrokerResponseError): - errno = 61 - message = 'DELEGATION_TOKEN_AUTH_DISABLED' - description = 'Delegation Token feature is not enabled.' - retriable = False - - -class DelegationTokenNotFoundError(BrokerResponseError): - errno = 62 - message = 'DELEGATION_TOKEN_NOT_FOUND' - description = 'Delegation Token is not found on server.' - retriable = False - - -class DelegationTokenOwnerMismatchError(BrokerResponseError): - errno = 63 - message = 'DELEGATION_TOKEN_OWNER_MISMATCH' - description = 'Specified Principal is not valid Owner/Renewer.' - retriable = False - - -class DelegationTokenRequestNotAllowedError(BrokerResponseError): - errno = 64 - message = 'DELEGATION_TOKEN_REQUEST_NOT_ALLOWED' - description = 'Delegation Token requests are not allowed on PLAINTEXT/1-way SSL channels and on delegation token authenticated channels.' - retriable = False - - -class DelegationTokenAuthorizationFailedError(AuthorizationError): - errno = 65 - message = 'DELEGATION_TOKEN_AUTHORIZATION_FAILED' - description = 'Delegation Token authorization failed.' - retriable = False - - -class DelegationTokenExpiredError(BrokerResponseError): - errno = 66 - message = 'DELEGATION_TOKEN_EXPIRED' - description = 'Delegation Token is expired.' - retriable = False - - -class InvalidPrincipalTypeError(BrokerResponseError): - errno = 67 - message = 'INVALID_PRINCIPAL_TYPE' - description = 'Supplied principalType is not supported.' - retriable = False class NonEmptyGroupError(BrokerResponseError): errno = 68 message = 'NON_EMPTY_GROUP' description = 'The group is not empty.' - retriable = False class GroupIdNotFoundError(BrokerResponseError): errno = 69 message = 'GROUP_ID_NOT_FOUND' description = 'The group id does not exist.' - retriable = False -class FetchSessionIdNotFoundError(BrokerResponseError): - errno = 70 - message = 'FETCH_SESSION_ID_NOT_FOUND' - description = 'The fetch session ID was not found.' - retriable = True +class KafkaUnavailableError(KafkaError): + pass -class InvalidFetchSessionEpochError(BrokerResponseError): - errno = 71 - message = 'INVALID_FETCH_SESSION_EPOCH' - description = 'The fetch session epoch is invalid.' - retriable = True +class KafkaTimeoutError(KafkaError): + pass -class ListenerNotFoundError(BrokerResponseError): - errno = 72 - message = 'LISTENER_NOT_FOUND' - description = 'There is no listener on the leader broker that matches the listener on which metadata request was processed.' +class FailedPayloadsError(KafkaError): + def __init__(self, payload, *args): + super(FailedPayloadsError, self).__init__(*args) + self.payload = payload + + +class KafkaConnectionError(KafkaError): retriable = True invalid_metadata = True -class TopicDeletionDisabledError(BrokerResponseError): - errno = 73 - message = 'TOPIC_DELETION_DISABLED' - description = 'Topic deletion is disabled.' - retriable = False +class ProtocolError(KafkaError): + pass -class FencedLeaderEpochError(BrokerResponseError): - errno = 74 - message = 'FENCED_LEADER_EPOCH' - description = 'The leader epoch in the request is older than the epoch on the broker.' 
- retriable = True - invalid_metadata = True +class UnsupportedCodecError(KafkaError): + pass -class UnknownLeaderEpochError(BrokerResponseError): - errno = 75 - message = 'UNKNOWN_LEADER_EPOCH' - description = 'The leader epoch in the request is newer than the epoch on the broker.' - retriable = True - invalid_metadata = True +class KafkaConfigurationError(KafkaError): + pass -class UnsupportedCompressionTypeError(BrokerResponseError): - errno = 76 - message = 'UNSUPPORTED_COMPRESSION_TYPE' - description = 'The requesting client does not support the compression type of given partition.' - retriable = False +class QuotaViolationError(KafkaError): + pass -class StaleBrokerEpochError(BrokerResponseError): - errno = 77 - message = 'STALE_BROKER_EPOCH' - description = 'Broker epoch has changed.' - retriable = False - - -class OffsetNotAvailableError(BrokerResponseError): - errno = 78 - message = 'OFFSET_NOT_AVAILABLE' - description = 'The leader high watermark has not caught up from a recent leader election so the offsets cannot be guaranteed to be monotonically increasing.' - retriable = True - - -class MemberIdRequiredError(BrokerResponseError): - errno = 79 - message = 'MEMBER_ID_REQUIRED' - description = 'The group member needs to have a valid member id before actually entering a consumer group.' - retriable = False - - -class PreferredLeaderNotAvailableError(BrokerResponseError): - errno = 80 - message = 'PREFERRED_LEADER_NOT_AVAILABLE' - description = 'The preferred leader was not available.' - retriable = True - invalid_metadata = True - - -class GroupMaxSizeReachedError(BrokerResponseError): - errno = 81 - message = 'GROUP_MAX_SIZE_REACHED' - description = 'The consumer group has reached its max size.' - retriable = False - - -class FencedInstanceIdError(BrokerResponseError): - errno = 82 - message = 'FENCED_INSTANCE_ID' - description = 'The broker rejected this static consumer since another consumer with the same group.instance.id has registered with a different member.id.' - retriable = False - - -class EligibleLeadersNotAvailableError(BrokerResponseError): - errno = 83 - message = 'ELIGIBLE_LEADERS_NOT_AVAILABLE' - description = 'Eligible topic partition leaders are not available.' - retriable = True - invalid_metadata = True - - -class ElectionNotNeededError(BrokerResponseError): - errno = 84 - message = 'ELECTION_NOT_NEEDED' - description = 'Leader election not needed for topic partition.' - retriable = True - invalid_metadata = True - - -class NoReassignmentInProgressError(BrokerResponseError): - errno = 85 - message = 'NO_REASSIGNMENT_IN_PROGRESS' - description = 'No partition reassignment is in progress.' - retriable = False - - -class GroupSubscribedToTopicError(BrokerResponseError): - errno = 86 - message = 'GROUP_SUBSCRIBED_TO_TOPIC' - description = 'Deleting offsets of a topic is forbidden while the consumer group is actively subscribed to it.' - retriable = False - - -class InvalidRecordError(BrokerResponseError): - errno = 87 - message = 'INVALID_RECORD' - description = 'This record has failed the validation on broker and hence will be rejected.' - retriable = False - - -class UnstableOffsetCommitError(BrokerResponseError): - errno = 88 - message = 'UNSTABLE_OFFSET_COMMIT' - description = 'There are unstable offsets that need to be cleared.' - retriable = True - - -class ThrottlingQuotaExceededError(BrokerResponseError): - errno = 89 - message = 'THROTTLING_QUOTA_EXCEEDED' - description = 'The throttling quota has been exceeded.' 
- retriable = True - - -class ProducerFencedError(BrokerResponseError): - errno = 90 - message = 'PRODUCER_FENCED' - description = 'There is a newer producer with the same transactionalId which fences the current one.' - retriable = False - - -class ResourceNotFoundError(BrokerResponseError): - errno = 91 - message = 'RESOURCE_NOT_FOUND' - description = 'A request illegally referred to a resource that does not exist.' - retriable = False - - -class DuplicateResourceError(BrokerResponseError): - errno = 92 - message = 'DUPLICATE_RESOURCE' - description = 'A request illegally referred to the same resource twice.' - retriable = False - - -class UnacceptableCredentialError(BrokerResponseError): - errno = 93 - message = 'UNACCEPTABLE_CREDENTIAL' - description = 'Requested credential would not meet criteria for acceptability.' - retriable = False - - -class InconsistentVoterSetError(BrokerResponseError): - errno = 94 - message = 'INCONSISTENT_VOTER_SET' - description = 'Indicates that the either the sender or recipient of a voter-only request is not one of the expected voters.' - retriable = False - - -class InvalidUpdateVersionError(BrokerResponseError): - errno = 95 - message = 'INVALID_UPDATE_VERSION' - description = 'The given update version was invalid.' - retriable = False - - -class FeatureUpdateFailedError(BrokerResponseError): - errno = 96 - message = 'FEATURE_UPDATE_FAILED' - description = 'Unable to update finalized features due to an unexpected server error.' - retriable = False - - -class PrincipalDeserializationFailureError(BrokerResponseError): - errno = 97 - message = 'PRINCIPAL_DESERIALIZATION_FAILURE' - description = 'Request principal deserialization failed during forwarding. This indicates an internal error on the broker cluster security setup.' - retriable = False - - -class SnapshotNotFoundError(BrokerResponseError): - errno = 98 - message = 'SNAPSHOT_NOT_FOUND' - description = 'Requested snapshot was not found.' - retriable = False - - -class PositionOutOfRangeError(BrokerResponseError): - errno = 99 - message = 'POSITION_OUT_OF_RANGE' - description = 'Requested position is not greater than or equal to zero, and less than the size of the snapshot.' - retriable = False - - -class UnknownTopicIdError(BrokerResponseError): - errno = 100 - message = 'UNKNOWN_TOPIC_ID' - description = 'This server does not host this topic ID.' - retriable = True - invalid_metadata = True - - -class DuplicateBrokerRegistrationError(BrokerResponseError): - errno = 101 - message = 'DUPLICATE_BROKER_REGISTRATION' - description = 'This broker ID is already in use.' - retriable = False - - -class BrokerIdNotRegisteredError(BrokerResponseError): - errno = 102 - message = 'BROKER_ID_NOT_REGISTERED' - description = 'The given broker ID was not registered.' - retriable = False - - -class InconsistentTopicIdError(BrokerResponseError): - errno = 103 - message = 'INCONSISTENT_TOPIC_ID' - description = 'The log\'s topic ID did not match the topic ID in the request.' - retriable = True - invalid_metadata = True - - -class InconsistentClusterIdError(BrokerResponseError): - errno = 104 - message = 'INCONSISTENT_CLUSTER_ID' - description = 'The clusterId in the request does not match that found on the server.' - retriable = False - - -class TransactionalIdNotFoundError(BrokerResponseError): - errno = 105 - message = 'TRANSACTIONAL_ID_NOT_FOUND' - description = 'The transactionalId could not be found.' 
- retriable = False - - -class FetchSessionTopicIdError(BrokerResponseError): - errno = 106 - message = 'FETCH_SESSION_TOPIC_ID_ERROR' - description = 'The fetch session encountered inconsistent topic ID usage.' - retriable = True - - -class IneligibleReplicaError(BrokerResponseError): - errno = 107 - message = 'INELIGIBLE_REPLICA' - description = 'The new ISR contains at least one ineligible replica.' - retriable = False - - -class NewLeaderElectedError(BrokerResponseError): - errno = 108 - message = 'NEW_LEADER_ELECTED' - description = 'The AlterPartition request successfully updated the partition state but the leader has changed.' - retriable = False - - -class OffsetMovedToTieredStorageError(BrokerResponseError): - errno = 109 - message = 'OFFSET_MOVED_TO_TIERED_STORAGE' - description = 'The requested offset is moved to tiered storage.' - retriable = False - - -class FencedMemberEpochError(BrokerResponseError): - errno = 110 - message = 'FENCED_MEMBER_EPOCH' - description = 'The member epoch is fenced by the group coordinator. The member must abandon all its partitions and rejoin.' - retriable = False - - -class UnreleasedInstanceIdError(BrokerResponseError): - errno = 111 - message = 'UNRELEASED_INSTANCE_ID' - description = 'The instance ID is still used by another member in the consumer group. That member must leave first.' - retriable = False - - -class UnsupportedAssignorError(BrokerResponseError): - errno = 112 - message = 'UNSUPPORTED_ASSIGNOR' - description = 'The assignor or its version range is not supported by the consumer group.' - retriable = False - - -class StaleMemberEpochError(BrokerResponseError): - errno = 113 - message = 'STALE_MEMBER_EPOCH' - description = 'The member epoch is stale. The member must retry after receiving its updated member epoch via the ConsumerGroupHeartbeat API.' - retriable = False - - -class MismatchedEndpointTypeError(BrokerResponseError): - errno = 114 - message = 'MISMATCHED_ENDPOINT_TYPE' - description = 'The request was sent to an endpoint of the wrong type.' - retriable = False - - -class UnsupportedEndpointTypeError(BrokerResponseError): - errno = 115 - message = 'UNSUPPORTED_ENDPOINT_TYPE' - description = 'This endpoint type is not supported yet.' - retriable = False - - -class UnknownControllerIdError(BrokerResponseError): - errno = 116 - message = 'UNKNOWN_CONTROLLER_ID' - description = 'This controller ID is not known.' - retriable = False - - -class UnknownSubscriptionIdError(BrokerResponseError): - errno = 117 - message = 'UNKNOWN_SUBSCRIPTION_ID' - description = 'Client sent a push telemetry request with an invalid or outdated subscription ID.' - retriable = False - - -class TelemetryTooLargeError(BrokerResponseError): - errno = 118 - message = 'TELEMETRY_TOO_LARGE' - description = 'Client sent a push telemetry request larger than the maximum size the broker will accept.' - retriable = False - - -class InvalidRegistrationError(BrokerResponseError): - errno = 119 - message = 'INVALID_REGISTRATION' - description = 'The controller has considered the broker registration to be invalid.' - retriable = False - - -class TransactionAbortableError(BrokerResponseError): - errno = 120 - message = 'TRANSACTION_ABORTABLE' - description = 'The server encountered an error with the transaction. The client can abort the transaction to continue using this transactional ID.' - retriable = False - - -class InvalidRecordStateError(BrokerResponseError): - errno = 121 - message = 'INVALID_RECORD_STATE' - description = 'The record state is invalid. 
The acknowledgement of delivery could not be completed.' - retriable = False - - -class ShareSessionNotFoundError(BrokerResponseError): - errno = 122 - message = 'SHARE_SESSION_NOT_FOUND' - description = 'The share session was not found.' - retriable = True - - -class InvalidShareSessionEpochError(BrokerResponseError): - errno = 123 - message = 'INVALID_SHARE_SESSION_EPOCH' - description = 'The share session epoch is invalid.' - retriable = True - - -class FencedStateEpochError(BrokerResponseError): - errno = 124 - message = 'FENCED_STATE_EPOCH' - description = 'The share coordinator rejected the request because the share-group state epoch did not match.' - retriable = False - - -class InvalidVoterKeyError(BrokerResponseError): - errno = 125 - message = 'INVALID_VOTER_KEY' - description = 'The voter key doesn\'t match the receiving replica\'s key.' - retriable = False - - -class DuplicateVoterError(BrokerResponseError): - errno = 126 - message = 'DUPLICATE_VOTER' - description = 'The voter is already part of the set of voters.' - retriable = False - - -class VoterNotFoundError(BrokerResponseError): - errno = 127 - message = 'VOTER_NOT_FOUND' - description = 'The voter is not part of the set of voters.' - retriable = False +class AsyncProducerQueueFull(KafkaError): + def __init__(self, failed_msgs, *args): + super(AsyncProducerQueueFull, self).__init__(*args) + self.failed_msgs = failed_msgs def _iter_broker_errors(): @@ -1055,12 +512,27 @@ kafka_errors = dict([(x.errno, x) for x in _iter_broker_errors()]) def for_code(error_code): - if error_code in kafka_errors: - return kafka_errors[error_code] - else: - # The broker error code was not found in our list. This can happen when connecting - # to a newer broker (with new error codes), or simply because our error list is - # not complete. - # - # To avoid dropping the error code, create a dynamic error class w/ errno override. 
- return type('UnrecognizedBrokerError', (UnknownError,), {'errno': error_code}) + return kafka_errors.get(error_code, UnknownError) + + +def check_error(response): + if isinstance(response, Exception): + raise response + if response.error: + error_class = kafka_errors.get(response.error, UnknownError) + raise error_class(response) + + +RETRY_BACKOFF_ERROR_TYPES = ( + KafkaUnavailableError, LeaderNotAvailableError, + KafkaConnectionError, FailedPayloadsError +) + + +RETRY_REFRESH_ERROR_TYPES = ( + NotLeaderForPartitionError, UnknownTopicOrPartitionError, + LeaderNotAvailableError, KafkaConnectionError +) + + +RETRY_ERROR_TYPES = RETRY_BACKOFF_ERROR_TYPES + RETRY_REFRESH_ERROR_TYPES diff --git a/venv/lib/python3.12/site-packages/kafka/future.py b/venv/lib/python3.12/site-packages/kafka/future.py index 2af061e..d0f3c66 100644 --- a/venv/lib/python3.12/site-packages/kafka/future.py +++ b/venv/lib/python3.12/site-packages/kafka/future.py @@ -2,7 +2,6 @@ from __future__ import absolute_import import functools import logging -import threading log = logging.getLogger(__name__) @@ -16,7 +15,6 @@ class Future(object): self.exception = None self._callbacks = [] self._errbacks = [] - self._lock = threading.Lock() def succeeded(self): return self.is_done and not bool(self.exception) @@ -32,46 +30,37 @@ class Future(object): def success(self, value): assert not self.is_done, 'Future is already complete' - with self._lock: - self.value = value - self.is_done = True + self.value = value + self.is_done = True if self._callbacks: self._call_backs('callback', self._callbacks, self.value) return self def failure(self, e): assert not self.is_done, 'Future is already complete' - exception = e if type(e) is not type else e() - assert isinstance(exception, BaseException), ( + self.exception = e if type(e) is not type else e() + assert isinstance(self.exception, BaseException), ( 'future failed without an exception') - with self._lock: - self.exception = exception - self.is_done = True + self.is_done = True self._call_backs('errback', self._errbacks, self.exception) return self def add_callback(self, f, *args, **kwargs): if args or kwargs: f = functools.partial(f, *args, **kwargs) - with self._lock: - if not self.is_done: - self._callbacks.append(f) - elif self.succeeded(): - self._lock.release() - self._call_backs('callback', [f], self.value) - self._lock.acquire() + if self.is_done and not self.exception: + self._call_backs('callback', [f], self.value) + else: + self._callbacks.append(f) return self def add_errback(self, f, *args, **kwargs): if args or kwargs: f = functools.partial(f, *args, **kwargs) - with self._lock: - if not self.is_done: - self._errbacks.append(f) - elif self.failed(): - self._lock.release() - self._call_backs('errback', [f], self.exception) - self._lock.acquire() + if self.is_done and self.exception: + self._call_backs('errback', [f], self.exception) + else: + self._errbacks.append(f) return self def add_both(self, f, *args, **kwargs): diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/compound_stat.py b/venv/lib/python3.12/site-packages/kafka/metrics/compound_stat.py index f5b482d..ac92480 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/compound_stat.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/compound_stat.py @@ -3,16 +3,16 @@ from __future__ import absolute_import import abc from kafka.metrics.stat import AbstractStat -from kafka.vendor.six import add_metaclass -@add_metaclass(abc.ABCMeta) class AbstractCompoundStat(AbstractStat): """ A compound stat 
is a stat where a single measurement and associated data structure feeds many metrics. This is the example for a histogram which has many associated percentiles. """ + __metaclass__ = abc.ABCMeta + def stats(self): """ Return list of NamedMeasurable @@ -21,8 +21,6 @@ class AbstractCompoundStat(AbstractStat): class NamedMeasurable(object): - __slots__ = ('_name', '_stat') - def __init__(self, metric_name, measurable_stat): self._name = metric_name self._stat = measurable_stat diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/kafka_metric.py b/venv/lib/python3.12/site-packages/kafka/metrics/kafka_metric.py index fef6848..9fb8d89 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/kafka_metric.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/kafka_metric.py @@ -4,8 +4,6 @@ import time class KafkaMetric(object): - __slots__ = ('_metric_name', '_measurable', '_config') - # NOTE java constructor takes a lock instance def __init__(self, metric_name, measurable, config): if not metric_name: @@ -35,4 +33,4 @@ class KafkaMetric(object): def value(self, time_ms=None): if time_ms is None: time_ms = time.time() * 1000 - return self._measurable.measure(self._config, time_ms) + return self.measurable.measure(self.config, time_ms) diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/measurable_stat.py b/venv/lib/python3.12/site-packages/kafka/metrics/measurable_stat.py index 08222b1..4487adf 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/measurable_stat.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/measurable_stat.py @@ -4,10 +4,8 @@ import abc from kafka.metrics.measurable import AbstractMeasurable from kafka.metrics.stat import AbstractStat -from kafka.vendor.six import add_metaclass -@add_metaclass(abc.ABCMeta) class AbstractMeasurableStat(AbstractStat, AbstractMeasurable): """ An AbstractMeasurableStat is an AbstractStat that is also @@ -15,3 +13,4 @@ class AbstractMeasurableStat(AbstractStat, AbstractMeasurable): This is the interface used for most of the simple statistics such as Avg, Max, Count, etc. 
""" + __metaclass__ = abc.ABCMeta diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/metric_config.py b/venv/lib/python3.12/site-packages/kafka/metrics/metric_config.py index 7e5ead1..2e55abf 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/metric_config.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/metric_config.py @@ -5,8 +5,6 @@ import sys class MetricConfig(object): """Configuration values for metrics""" - __slots__ = ('quota', '_samples', 'event_window', 'time_window_ms', 'tags') - def __init__(self, quota=None, samples=2, event_window=sys.maxsize, time_window_ms=30 * 1000, tags=None): """ diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/metric_name.py b/venv/lib/python3.12/site-packages/kafka/metrics/metric_name.py index b8ab2a3..b5acd16 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/metric_name.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/metric_name.py @@ -38,7 +38,6 @@ class MetricName(object): # as messages are sent we record the sizes sensor.record(message_size) """ - __slots__ = ('_name', '_group', '_description', '_tags', '_hash') def __init__(self, name, group, description=None, tags=None): """ @@ -94,7 +93,7 @@ class MetricName(object): return True if other is None: return False - return (isinstance(self, type(other)) and + return (type(self) == type(other) and self.group == other.group and self.name == other.name and self.tags == other.tags) diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/metrics.py b/venv/lib/python3.12/site-packages/kafka/metrics/metrics.py index 41a37db..2c53488 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/metrics.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/metrics.py @@ -55,11 +55,10 @@ class Metrics(object): self._reporters = reporters or [] for reporter in self._reporters: reporter.init([]) - self._closed = False if enable_expiration: def expire_loop(): - while not self._closed: + while True: # delay 30 seconds time.sleep(30) self.ExpireSensorTask.run(self) @@ -260,4 +259,3 @@ class Metrics(object): reporter.close() self._metrics.clear() - self._closed = True diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/metrics_reporter.py b/venv/lib/python3.12/site-packages/kafka/metrics/metrics_reporter.py index 8df2e9e..d8bd12b 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/metrics_reporter.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/metrics_reporter.py @@ -2,15 +2,14 @@ from __future__ import absolute_import import abc -from kafka.vendor.six import add_metaclass - -@add_metaclass(abc.ABCMeta) class AbstractMetricsReporter(object): """ An abstract class to allow things to listen as new metrics are created so they can be reported. 
""" + __metaclass__ = abc.ABCMeta + @abc.abstractmethod def init(self, metrics): """ diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/quota.py b/venv/lib/python3.12/site-packages/kafka/metrics/quota.py index 36a30c4..4d1b0d6 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/quota.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/quota.py @@ -3,8 +3,6 @@ from __future__ import absolute_import class Quota(object): """An upper or lower bound for metrics""" - __slots__ = ('_bound', '_upper') - def __init__(self, bound, is_upper): self._bound = bound self._upper = is_upper @@ -36,7 +34,7 @@ class Quota(object): def __eq__(self, other): if self is other: return True - return (isinstance(self, type(other)) and + return (type(self) == type(other) and self.bound == other.bound and self.is_upper_bound() == other.is_upper_bound()) diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/stat.py b/venv/lib/python3.12/site-packages/kafka/metrics/stat.py index 8825d27..9fd2f01 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/stat.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/stat.py @@ -2,15 +2,14 @@ from __future__ import absolute_import import abc -from kafka.vendor.six import add_metaclass - -@add_metaclass(abc.ABCMeta) class AbstractStat(object): """ An AbstractStat is a quantity such as average, max, etc that is computed off the stream of updates to a sensor """ + __metaclass__ = abc.ABCMeta + @abc.abstractmethod def record(self, config, value, time_ms): """ diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/stats/avg.py b/venv/lib/python3.12/site-packages/kafka/metrics/stats/avg.py index 906d955..cfbaec3 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/stats/avg.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/stats/avg.py @@ -7,8 +7,6 @@ class Avg(AbstractSampledStat): """ An AbstractSampledStat that maintains a simple average over its samples. """ - __slots__ = ('_initial_value', '_samples', '_current') - def __init__(self): super(Avg, self).__init__(0.0) diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/stats/count.py b/venv/lib/python3.12/site-packages/kafka/metrics/stats/count.py index 6cd6d2a..6e0a2d5 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/stats/count.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/stats/count.py @@ -7,8 +7,6 @@ class Count(AbstractSampledStat): """ An AbstractSampledStat that maintains a simple count of what it has seen. 
""" - __slots__ = ('_initial_value', '_samples', '_current') - def __init__(self): super(Count, self).__init__(0.0) diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/stats/histogram.py b/venv/lib/python3.12/site-packages/kafka/metrics/stats/histogram.py index 2c8afbf..ecc6c9d 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/stats/histogram.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/stats/histogram.py @@ -4,8 +4,6 @@ import math class Histogram(object): - __slots__ = ('_hist', '_count', '_bin_scheme') - def __init__(self, bin_scheme): self._hist = [0.0] * bin_scheme.bins self._count = 0.0 @@ -42,8 +40,6 @@ class Histogram(object): return '{%s}' % ','.join(values) class ConstantBinScheme(object): - __slots__ = ('_min', '_max', '_bins', '_bucket_width') - def __init__(self, bins, min_val, max_val): if bins < 2: raise ValueError('Must have at least 2 bins.') @@ -73,8 +69,6 @@ class Histogram(object): return int(((x - self._min) / self._bucket_width) + 1) class LinearBinScheme(object): - __slots__ = ('_bins', '_max', '_scale') - def __init__(self, num_bins, max_val): self._bins = num_bins self._max = max_val diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/stats/max_stat.py b/venv/lib/python3.12/site-packages/kafka/metrics/stats/max_stat.py index 9c5eeb6..08aebdd 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/stats/max_stat.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/stats/max_stat.py @@ -5,8 +5,6 @@ from kafka.metrics.stats.sampled_stat import AbstractSampledStat class Max(AbstractSampledStat): """An AbstractSampledStat that gives the max over its samples.""" - __slots__ = ('_initial_value', '_samples', '_current') - def __init__(self): super(Max, self).__init__(float('-inf')) diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/stats/min_stat.py b/venv/lib/python3.12/site-packages/kafka/metrics/stats/min_stat.py index 6bebe57..072106d 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/stats/min_stat.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/stats/min_stat.py @@ -7,8 +7,6 @@ from kafka.metrics.stats.sampled_stat import AbstractSampledStat class Min(AbstractSampledStat): """An AbstractSampledStat that gives the min over its samples.""" - __slots__ = ('_initial_value', '_samples', '_current') - def __init__(self): super(Min, self).__init__(float(sys.maxsize)) diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/stats/percentile.py b/venv/lib/python3.12/site-packages/kafka/metrics/stats/percentile.py index 75e64ce..3a86a84 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/stats/percentile.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/stats/percentile.py @@ -2,8 +2,6 @@ from __future__ import absolute_import class Percentile(object): - __slots__ = ('_metric_name', '_percentile') - def __init__(self, metric_name, percentile): self._metric_name = metric_name self._percentile = float(percentile) diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/stats/percentiles.py b/venv/lib/python3.12/site-packages/kafka/metrics/stats/percentiles.py index 2cb2d84..6d702e8 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/stats/percentiles.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/stats/percentiles.py @@ -13,9 +13,6 @@ class BucketSizing(object): class Percentiles(AbstractSampledStat, AbstractCompoundStat): """A compound stat that reports one or more percentiles""" - __slots__ = ('_initial_value', '_samples', '_current', - '_percentiles', 
'_buckets', '_bin_scheme') - def __init__(self, size_in_bytes, bucketing, max_val, min_val=0.0, percentiles=None): super(Percentiles, self).__init__(0.0) @@ -30,7 +27,7 @@ class Percentiles(AbstractSampledStat, AbstractCompoundStat): ' to be 0.0.') self.bin_scheme = Histogram.LinearBinScheme(self._buckets, max_val) else: - raise ValueError('Unknown bucket type: %s' % (bucketing,)) + raise ValueError('Unknown bucket type: %s' % (bucketing,)) def stats(self): measurables = [] diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/stats/rate.py b/venv/lib/python3.12/site-packages/kafka/metrics/stats/rate.py index 4d0ba0f..68393fb 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/stats/rate.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/stats/rate.py @@ -37,8 +37,6 @@ class Rate(AbstractMeasurableStat): occurrences (e.g. the count of values measured over the time interval) or other such values. """ - __slots__ = ('_stat', '_unit') - def __init__(self, time_unit=TimeUnit.SECONDS, sampled_stat=None): self._stat = sampled_stat or SampledTotal() self._unit = time_unit @@ -107,7 +105,6 @@ class Rate(AbstractMeasurableStat): class SampledTotal(AbstractSampledStat): - __slots__ = ('_initial_value', '_samples', '_current') def __init__(self, initial_value=None): if initial_value is not None: raise ValueError('initial_value cannot be set on SampledTotal') diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/stats/sampled_stat.py b/venv/lib/python3.12/site-packages/kafka/metrics/stats/sampled_stat.py index fe8970d..c41b14b 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/stats/sampled_stat.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/stats/sampled_stat.py @@ -3,10 +3,8 @@ from __future__ import absolute_import import abc from kafka.metrics.measurable_stat import AbstractMeasurableStat -from kafka.vendor.six import add_metaclass -@add_metaclass(abc.ABCMeta) class AbstractSampledStat(AbstractMeasurableStat): """ An AbstractSampledStat records a single scalar value measured over @@ -22,7 +20,7 @@ class AbstractSampledStat(AbstractMeasurableStat): Subclasses of this class define different statistics measured using this basic pattern. """ - __slots__ = ('_initial_value', '_samples', '_current') + __metaclass__ = abc.ABCMeta def __init__(self, initial_value): self._initial_value = initial_value diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/stats/sensor.py b/venv/lib/python3.12/site-packages/kafka/metrics/stats/sensor.py index 9f7ac45..571723f 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/stats/sensor.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/stats/sensor.py @@ -15,10 +15,6 @@ class Sensor(object): the `record(double)` api and would maintain a set of metrics about request sizes such as the average or max.
""" - __slots__ = ('_lock', '_registry', '_name', '_parents', '_metrics', - '_stats', '_config', '_inactive_sensor_expiration_time_ms', - '_last_record_time') - def __init__(self, registry, name, parents, config, inactive_sensor_expiration_time_seconds): if not name: diff --git a/venv/lib/python3.12/site-packages/kafka/metrics/stats/total.py b/venv/lib/python3.12/site-packages/kafka/metrics/stats/total.py index a78e997..5b3bb87 100644 --- a/venv/lib/python3.12/site-packages/kafka/metrics/stats/total.py +++ b/venv/lib/python3.12/site-packages/kafka/metrics/stats/total.py @@ -5,8 +5,6 @@ from kafka.metrics.measurable_stat import AbstractMeasurableStat class Total(AbstractMeasurableStat): """An un-windowed cumulative total maintained over all time.""" - __slots__ = ('_total') - def __init__(self, value=0.0): self._total = value diff --git a/venv/lib/python3.12/site-packages/kafka/oauth/__init__.py b/venv/lib/python3.12/site-packages/kafka/oauth/__init__.py new file mode 100644 index 0000000..8c83495 --- /dev/null +++ b/venv/lib/python3.12/site-packages/kafka/oauth/__init__.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import + +from kafka.oauth.abstract import AbstractTokenProvider diff --git a/venv/lib/python3.12/site-packages/kafka/oauth/abstract.py b/venv/lib/python3.12/site-packages/kafka/oauth/abstract.py new file mode 100644 index 0000000..8d89ff5 --- /dev/null +++ b/venv/lib/python3.12/site-packages/kafka/oauth/abstract.py @@ -0,0 +1,42 @@ +from __future__ import absolute_import + +import abc + +# This statement is compatible with both Python 2.7 & 3+ +ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()}) + +class AbstractTokenProvider(ABC): + """ + A Token Provider must be used for the SASL OAuthBearer protocol. + + The implementation should ensure token reuse so that multiple + calls at connect time do not create multiple tokens. The implementation + should also periodically refresh the token in order to guarantee + that each call returns an unexpired token. A timeout error should + be returned after a short period of inactivity so that the + broker can log debugging info and retry. + + Token Providers MUST implement the token() method + """ + + def __init__(self, **config): + pass + + @abc.abstractmethod + def token(self): + """ + Returns a (str) ID/Access Token to be sent to the Kafka + client. + """ + pass + + def extensions(self): + """ + This is an OPTIONAL method that may be implemented. + + Returns a map of key-value pairs that can + be sent with the SASL/OAUTHBEARER initial client request. If + not implemented, the values are ignored. This feature is only available + in Kafka >= 2.1.0. + """ + return {} diff --git a/venv/lib/python3.12/site-packages/kafka/producer/buffer.py b/venv/lib/python3.12/site-packages/kafka/producer/buffer.py new file mode 100644 index 0000000..1008017 --- /dev/null +++ b/venv/lib/python3.12/site-packages/kafka/producer/buffer.py @@ -0,0 +1,115 @@ +from __future__ import absolute_import, division + +import collections +import io +import threading +import time + +from kafka.metrics.stats import Rate + +import kafka.errors as Errors + + +class SimpleBufferPool(object): + """A simple pool of BytesIO objects with a weak memory ceiling.""" + def __init__(self, memory, poolable_size, metrics=None, metric_group_prefix='producer-metrics'): + """Create a new buffer pool. 
+ + Arguments: + memory (int): maximum memory that this buffer pool can allocate + poolable_size (int): memory size per buffer to cache in the free + list rather than deallocating + """ + self._poolable_size = poolable_size + self._lock = threading.RLock() + + buffers = int(memory / poolable_size) if poolable_size else 0 + self._free = collections.deque([io.BytesIO() for _ in range(buffers)]) + + self._waiters = collections.deque() + self.wait_time = None + if metrics: + self.wait_time = metrics.sensor('bufferpool-wait-time') + self.wait_time.add(metrics.metric_name( + 'bufferpool-wait-ratio', metric_group_prefix, + 'The fraction of time an appender waits for space allocation.'), + Rate()) + + def allocate(self, size, max_time_to_block_ms): + """ + Allocate a buffer of the given size. This method blocks if there is not + enough memory and the buffer pool is configured with blocking mode. + + Arguments: + size (int): The buffer size to allocate in bytes [ignored] + max_time_to_block_ms (int): The maximum time in milliseconds to + block for buffer memory to be available + + Returns: + io.BytesIO + """ + with self._lock: + # check if we have a free buffer of the right size pooled + if self._free: + return self._free.popleft() + + elif self._poolable_size == 0: + return io.BytesIO() + + else: + # we are out of buffers and will have to block + buf = None + more_memory = threading.Condition(self._lock) + self._waiters.append(more_memory) + # loop over and over until we have a buffer or have reserved + # enough memory to allocate one + while buf is None: + start_wait = time.time() + more_memory.wait(max_time_to_block_ms / 1000.0) + end_wait = time.time() + if self.wait_time: + self.wait_time.record(end_wait - start_wait) + + if self._free: + buf = self._free.popleft() + else: + self._waiters.remove(more_memory) + raise Errors.KafkaTimeoutError( + "Failed to allocate memory within the configured" + " max blocking time") + + # remove the condition for this thread to let the next thread + # in line start getting memory + removed = self._waiters.popleft() + assert removed is more_memory, 'Wrong condition' + + # signal any additional waiters if there is more memory left + # over for them + if self._free and self._waiters: + self._waiters[0].notify() + + # unlock and return the buffer + return buf + + def deallocate(self, buf): + """ + Return buffers to the pool. If they are of the poolable size add them + to the free list, otherwise just mark the memory as free. + + Arguments: + buffer_ (io.BytesIO): The buffer to return + """ + with self._lock: + # BytesIO.truncate here makes the pool somewhat pointless + # but we stick with the BufferPool API until migrating to + # bytesarray / memoryview. The buffer we return must not + # expose any prior data on read(). 
+ buf.truncate(0) + self._free.append(buf) + if self._waiters: + self._waiters[0].notify() + + def queued(self): + """The number of threads blocked waiting on memory.""" + with self._lock: + return len(self._waiters) diff --git a/venv/lib/python3.12/site-packages/kafka/producer/future.py b/venv/lib/python3.12/site-packages/kafka/producer/future.py index f67db09..07fa4ad 100644 --- a/venv/lib/python3.12/site-packages/kafka/producer/future.py +++ b/venv/lib/python3.12/site-packages/kafka/producer/future.py @@ -38,7 +38,7 @@ class FutureRecordMetadata(Future): produce_future.add_errback(self.failure) def _produce_success(self, offset_and_timestamp): - offset, produce_timestamp_ms = offset_and_timestamp + offset, produce_timestamp_ms, log_start_offset = offset_and_timestamp # Unpacking from args tuple is minor speed optimization (relative_offset, timestamp_ms, checksum, @@ -51,7 +51,7 @@ class FutureRecordMetadata(Future): if offset != -1 and relative_offset is not None: offset += relative_offset tp = self._produce_future.topic_partition - metadata = RecordMetadata(tp[0], tp[1], tp, offset, timestamp_ms, + metadata = RecordMetadata(tp[0], tp[1], tp, offset, timestamp_ms, log_start_offset, checksum, serialized_key_size, serialized_value_size, serialized_header_size) self.success(metadata) @@ -67,5 +67,5 @@ class FutureRecordMetadata(Future): RecordMetadata = collections.namedtuple( - 'RecordMetadata', ['topic', 'partition', 'topic_partition', 'offset', 'timestamp', + 'RecordMetadata', ['topic', 'partition', 'topic_partition', 'offset', 'timestamp', 'log_start_offset', 'checksum', 'serialized_key_size', 'serialized_value_size', 'serialized_header_size']) diff --git a/venv/lib/python3.12/site-packages/kafka/producer/kafka.py b/venv/lib/python3.12/site-packages/kafka/producer/kafka.py index 9401bd8..cde26b0 100644 --- a/venv/lib/python3.12/site-packages/kafka/producer/kafka.py +++ b/venv/lib/python3.12/site-packages/kafka/producer/kafka.py @@ -1,11 +1,11 @@ -from __future__ import absolute_import, division +from __future__ import absolute_import import atexit import copy import logging import socket import threading -import warnings +import time import weakref from kafka.vendor import six @@ -18,12 +18,10 @@ from kafka.partitioner.default import DefaultPartitioner from kafka.producer.future import FutureRecordMetadata, FutureProduceResult from kafka.producer.record_accumulator import AtomicInteger, RecordAccumulator from kafka.producer.sender import Sender -from kafka.producer.transaction_manager import TransactionManager from kafka.record.default_records import DefaultRecordBatchBuilder from kafka.record.legacy_records import LegacyRecordBatchBuilder from kafka.serializer import Serializer from kafka.structs import TopicPartition -from kafka.util import Timer, ensure_valid_topic_name log = logging.getLogger(__name__) @@ -36,8 +34,8 @@ class KafkaProducer(object): The producer is thread safe and sharing a single producer instance across threads will generally be faster than having multiple instances. - The producer consists of a RecordAccumulator which holds records that - haven't yet been transmitted to the server, and a Sender background I/O + The producer consists of a pool of buffer space that holds records that + haven't yet been transmitted to the server as well as a background I/O thread that is responsible for turning these records into requests and transmitting them to the cluster. 
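For reference, a minimal usage sketch of the producer API described in the docstring above, assuming kafka-python is importable and a broker is reachable; the broker address, topic name, serializer, and retry count are illustrative assumptions, not values taken from this patch.

# Hypothetical usage sketch (illustrative values only).
from kafka import KafkaProducer
from kafka.errors import KafkaError

producer = KafkaProducer(
    bootstrap_servers='localhost:9092',            # assumed local broker, not from this patch
    retries=3,                                     # this version defaults 'retries' to 0
    value_serializer=lambda v: v.encode('utf-8'),  # keys below are passed as raw bytes
)

future = producer.send('example-topic', value='hello', key=b'k1')
try:
    metadata = future.get(timeout=10)              # block until the broker acknowledges
    print(metadata.topic, metadata.partition, metadata.offset)
except KafkaError as exc:
    print('send failed: %s' % exc)
finally:
    producer.flush()
    producer.close()

send() returns a FutureRecordMetadata, so the .get() call above is only needed when the caller wants synchronous confirmation of delivery.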
@@ -73,50 +71,14 @@ class KafkaProducer(object): can lead to fewer, more efficient requests when not under maximal load at the cost of a small amount of latency. + The buffer_memory controls the total amount of memory available to the + producer for buffering. If records are sent faster than they can be + transmitted to the server then this buffer space will be exhausted. When + the buffer space is exhausted additional send calls will block. + The key_serializer and value_serializer instruct how to turn the key and value objects the user provides into bytes. - From Kafka 0.11, the KafkaProducer supports two additional modes: - the idempotent producer and the transactional producer. - The idempotent producer strengthens Kafka's delivery semantics from - at least once to exactly once delivery. In particular, producer retries - will no longer introduce duplicates. The transactional producer allows an - application to send messages to multiple partitions (and topics!) - atomically. - - To enable idempotence, the `enable_idempotence` configuration must be set - to True. If set, the `retries` config will default to `float('inf')` and - the `acks` config will default to 'all'. There are no API changes for the - idempotent producer, so existing applications will not need to be modified - to take advantage of this feature. - - To take advantage of the idempotent producer, it is imperative to avoid - application level re-sends since these cannot be de-duplicated. As such, if - an application enables idempotence, it is recommended to leave the - `retries` config unset, as it will be defaulted to `float('inf')`. - Additionally, if a :meth:`~kafka.KafkaProducer.send` returns an error even - with infinite retries (for instance if the message expires in the buffer - before being sent), then it is recommended to shut down the producer and - check the contents of the last produced message to ensure that it is not - duplicated. Finally, the producer can only guarantee idempotence for - messages sent within a single session. - - To use the transactional producer and the attendant APIs, you must set the - `transactional_id` configuration property. If the `transactional_id` is - set, idempotence is automatically enabled along with the producer configs - which idempotence depends on. Further, topics which are included in - transactions should be configured for durability. In particular, the - `replication.factor` should be at least `3`, and the `min.insync.replicas` - for these topics should be set to 2. Finally, in order for transactional - guarantees to be realized from end-to-end, the consumers must be - configured to read only committed messages as well. - - The purpose of the `transactional_id` is to enable transaction recovery - across multiple sessions of a single producer instance. It would typically - be derived from the shard identifier in a partitioned, stateful, - application. As such, it should be unique to each producer instance running - within a partitioned application. - Keyword Arguments: bootstrap_servers: 'host[:port]' string (or list of 'host[:port]' strings) that the producer should contact to bootstrap initial @@ -134,28 +96,6 @@ class KafkaProducer(object): value_serializer (callable): used to convert user-supplied message values to bytes. If not None, called as f(value), should return bytes. Default: None. - enable_idempotence (bool): When set to True, the producer will ensure - that exactly one copy of each message is written in the stream. 
- If False, producer retries due to broker failures, etc., may write - duplicates of the retried message in the stream. Default: False. - - Note that enabling idempotence requires - `max_in_flight_requests_per_connection` to be set to 1 and `retries` - cannot be zero. Additionally, `acks` must be set to 'all'. If these - values are left at their defaults, the producer will override the - defaults to be suitable. If the values are set to something - incompatible with the idempotent producer, a KafkaConfigurationError - will be raised. - delivery_timeout_ms (float): An upper bound on the time to report success - or failure after producer.send() returns. This limits the total time - that a record will be delayed prior to sending, the time to await - acknowledgement from the broker (if expected), and the time allowed - for retriable send failures. The producer may report failure to send - a record earlier than this config if either an unrecoverable error is - encountered, the retries have been exhausted, or the record is added - to a batch which reached an earlier delivery expiration deadline. - The value of this config should be greater than or equal to the - sum of (request_timeout_ms + linger_ms). Default: 120000. acks (0, 1, 'all'): The number of acknowledgments the producer requires the leader to have received before considering a request complete. This controls the durability of records that are sent. The @@ -183,7 +123,7 @@ class KafkaProducer(object): Compression is of full batches of data, so the efficacy of batching will also impact the compression ratio (more batching means better compression). Default: None. - retries (numeric): Setting a value greater than zero will cause the client + retries (int): Setting a value greater than zero will cause the client to resend any record whose send fails with a potentially transient error. Note that this retry is no different than if the client resent the record upon receiving the error. Allowing retries @@ -191,12 +131,8 @@ class KafkaProducer(object): potentially change the ordering of records because if two batches are sent to a single partition, and the first fails and is retried but the second succeeds, then the records in the second batch may - appear first. Note additionally that produce requests will be - failed before the number of retries has been exhausted if the timeout - configured by delivery_timeout_ms expires first before successful - acknowledgement. Users should generally prefer to leave this config - unset and instead use delivery_timeout_ms to control retry behavior. - Default: float('inf') (infinite) + appear first. + Default: 0. batch_size (int): Requests sent to brokers will contain multiple batches, one for each partition with data available to be sent. A small batch size will make batching less common and may reduce @@ -229,6 +165,12 @@ class KafkaProducer(object): messages with the same key are assigned to the same partition. When a key is None, the message is delivered to a random partition (filtered to partitions with available leaders only, if possible). + buffer_memory (int): The total bytes of memory the producer should use + to buffer records waiting to be sent to the server. If records are + sent faster than they can be delivered to the server the producer + will block up to max_block_ms, raising an exception on timeout. + In the current implementation, this setting is an approximation. 
+ Default: 33554432 (32MB) connections_max_idle_ms: Close idle connections after the number of milliseconds specified by this config. The broker closes idle connections after connections.max.idle.ms, so this avoids hitting @@ -246,9 +188,6 @@ class KafkaProducer(object): This setting will limit the number of record batches the producer will send in a single request to avoid sending huge requests. Default: 1048576. - allow_auto_create_topics (bool): Enable/disable auto topic creation - on metadata request. Only available with api_version >= (0, 11). - Default: True metadata_max_age_ms (int): The period of time in milliseconds after which we force a refresh of metadata even if we haven't seen any partition leadership changes to proactively discover any new @@ -277,7 +216,7 @@ class KafkaProducer(object): reconnection attempts will continue periodically with this fixed rate. To avoid connection storms, a randomization factor of 0.2 will be applied to the backoff resulting in a random range between - 20% below and 20% above the computed value. Default: 30000. + 20% below and 20% above the computed value. Default: 1000. max_in_flight_requests_per_connection (int): Requests are pipelined to kafka brokers up to this number of maximum requests per broker connection. Note that if this setting is set to be greater @@ -294,7 +233,7 @@ class KafkaProducer(object): should verify that the certificate matches the brokers hostname. default: true. ssl_cafile (str): optional filename of ca file to use in certificate - verification. default: none. + veriication. default: none. ssl_certfile (str): optional filename of file in pem format containing the client certificate, as well as any ca certificates needed to establish the certificate's authenticity. default: none. @@ -313,28 +252,14 @@ class KafkaProducer(object): or other configuration forbids use of all the specified ciphers), an ssl.SSLError will be raised. See ssl.SSLContext.set_ciphers api_version (tuple): Specify which Kafka API version to use. If set to - None, the client will attempt to determine the broker version via - ApiVersionsRequest API or, for brokers earlier than 0.10, probing - various known APIs. Dynamic version checking is performed eagerly - during __init__ and can raise NoBrokersAvailableError if no connection - was made before timeout (see api_version_auto_timeout_ms below). - Different versions enable different functionality. - - Examples: - (3, 9) most recent broker release, enable all supported features - (0, 11) enables message format v2 (internal) - (0, 10, 0) enables sasl authentication and message format v1 - (0, 8, 0) enables basic functionality only - - Default: None + None, the client will attempt to infer the broker version by probing + various APIs. Example: (0, 10, 2). Default: None api_version_auto_timeout_ms (int): number of milliseconds to throw a timeout exception from the constructor when checking the broker api version. Only applies if api_version set to None. - Default: 2000 metric_reporters (list): A list of classes to use as metrics reporters. Implementing the AbstractMetricsReporter interface allows plugging in classes that will be notified of new metric creation. Default: [] - metrics_enabled (bool): Whether to track metrics on this instance. Default True. metrics_num_samples (int): The number of samples maintained to compute metrics. 
Default: 2 metrics_sample_window_ms (int): The maximum age in milliseconds of @@ -349,42 +274,33 @@ class KafkaProducer(object): Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. sasl_plain_password (str): password for sasl PLAIN and SCRAM authentication. Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. - sasl_kerberos_name (str or gssapi.Name): Constructed gssapi.Name for use with - sasl mechanism handshake. If provided, sasl_kerberos_service_name and - sasl_kerberos_domain name are ignored. Default: None. sasl_kerberos_service_name (str): Service name to include in GSSAPI sasl mechanism handshake. Default: 'kafka' sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI sasl mechanism handshake. Default: one of bootstrap servers - sasl_oauth_token_provider (kafka.sasl.oauth.AbstractTokenProvider): OAuthBearer - token provider instance. Default: None - socks5_proxy (str): Socks5 proxy URL. Default: None - kafka_client (callable): Custom class / callable for creating KafkaClient instances + sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider + instance. (See kafka.oauth.abstract). Default: None Note: Configuration parameters are described in more detail at - https://kafka.apache.org/0100/documentation/#producerconfigs + https://kafka.apache.org/0100/configuration.html#producerconfigs """ DEFAULT_CONFIG = { 'bootstrap_servers': 'localhost', 'client_id': None, 'key_serializer': None, 'value_serializer': None, - 'enable_idempotence': False, - 'transactional_id': None, - 'transaction_timeout_ms': 60000, - 'delivery_timeout_ms': 120000, 'acks': 1, 'bootstrap_topics_filter': set(), 'compression_type': None, - 'retries': float('inf'), + 'retries': 0, 'batch_size': 16384, 'linger_ms': 0, 'partitioner': DefaultPartitioner(), + 'buffer_memory': 33554432, 'connections_max_idle_ms': 9 * 60 * 1000, 'max_block_ms': 60000, 'max_request_size': 1048576, - 'allow_auto_create_topics': True, 'metadata_max_age_ms': 300000, 'retry_backoff_ms': 100, 'request_timeout_ms': 30000, @@ -394,7 +310,7 @@ class KafkaProducer(object): 'sock_chunk_bytes': 4096, # undocumented experimental option 'sock_chunk_buffer_count': 1000, # undocumented experimental option 'reconnect_backoff_ms': 50, - 'reconnect_backoff_max_ms': 30000, + 'reconnect_backoff_max_ms': 1000, 'max_in_flight_requests_per_connection': 5, 'security_protocol': 'PLAINTEXT', 'ssl_context': None, @@ -408,23 +324,17 @@ class KafkaProducer(object): 'api_version': None, 'api_version_auto_timeout_ms': 2000, 'metric_reporters': [], - 'metrics_enabled': True, 'metrics_num_samples': 2, 'metrics_sample_window_ms': 30000, 'selector': selectors.DefaultSelector, 'sasl_mechanism': None, 'sasl_plain_username': None, 'sasl_plain_password': None, - 'sasl_kerberos_name': None, 'sasl_kerberos_service_name': 'kafka', 'sasl_kerberos_domain_name': None, - 'sasl_oauth_token_provider': None, - 'socks5_proxy': None, - 'kafka_client': KafkaClient, + 'sasl_oauth_token_provider': None } - DEPRECATED_CONFIGS = ('buffer_memory',) - _COMPRESSORS = { 'gzip': (has_gzip, LegacyRecordBatchBuilder.CODEC_GZIP), 'snappy': (has_snappy, LegacyRecordBatchBuilder.CODEC_SNAPPY), @@ -434,17 +344,12 @@ class KafkaProducer(object): } def __init__(self, **configs): + log.debug("Starting the Kafka producer") # trace self.config = copy.copy(self.DEFAULT_CONFIG) - user_provided_configs = set(configs.keys()) for key in self.config: if key in configs: self.config[key] = configs.pop(key) - for key in self.DEPRECATED_CONFIGS: - if key 
in configs: - configs.pop(key) - warnings.warn('Deprecated Producer config: %s' % (key,), DeprecationWarning) - # Only check for extra config keys in top-level class assert not configs, 'Unrecognized configs: %s' % (configs,) @@ -462,35 +367,30 @@ class KafkaProducer(object): self.config['api_version'] = None else: self.config['api_version'] = tuple(map(int, deprecated.split('.'))) - log.warning('%s: use api_version=%s [tuple] -- "%s" as str is deprecated', - str(self), str(self.config['api_version']), deprecated) - - log.debug("%s: Starting Kafka producer", str(self)) + log.warning('use api_version=%s [tuple] -- "%s" as str is deprecated', + str(self.config['api_version']), deprecated) # Configure metrics - if self.config['metrics_enabled']: - metrics_tags = {'client-id': self.config['client_id']} - metric_config = MetricConfig(samples=self.config['metrics_num_samples'], - time_window_ms=self.config['metrics_sample_window_ms'], - tags=metrics_tags) - reporters = [reporter() for reporter in self.config['metric_reporters']] - self._metrics = Metrics(metric_config, reporters) - else: - self._metrics = None + metrics_tags = {'client-id': self.config['client_id']} + metric_config = MetricConfig(samples=self.config['metrics_num_samples'], + time_window_ms=self.config['metrics_sample_window_ms'], + tags=metrics_tags) + reporters = [reporter() for reporter in self.config['metric_reporters']] + self._metrics = Metrics(metric_config, reporters) - client = self.config['kafka_client']( - metrics=self._metrics, metric_group_prefix='producer', - wakeup_timeout_ms=self.config['max_block_ms'], - **self.config) + client = KafkaClient(metrics=self._metrics, metric_group_prefix='producer', + wakeup_timeout_ms=self.config['max_block_ms'], + **self.config) - # Get auto-discovered / normalized version from client - self.config['api_version'] = client.config['api_version'] + # Get auto-discovered version from client if necessary + if self.config['api_version'] is None: + self.config['api_version'] = client.config['api_version'] if self.config['compression_type'] == 'lz4': assert self.config['api_version'] >= (0, 8, 2), 'LZ4 Requires >= Kafka 0.8.2 Brokers' if self.config['compression_type'] == 'zstd': - assert self.config['api_version'] >= (2, 1), 'Zstd Requires >= Kafka 2.1 Brokers' + assert self.config['api_version'] >= (2, 1, 0), 'Zstd Requires >= Kafka 2.1.0 Brokers' # Check compression_type for library support ct = self.config['compression_type'] @@ -501,58 +401,12 @@ class KafkaProducer(object): assert checker(), "Libraries for {} compression codec not found".format(ct) self.config['compression_attrs'] = compression_attrs + message_version = self._max_usable_produce_magic() + self._accumulator = RecordAccumulator(message_version=message_version, metrics=self._metrics, **self.config) self._metadata = client.cluster - self._transaction_manager = None - self._init_transactions_result = None - if 'enable_idempotence' in user_provided_configs and not self.config['enable_idempotence'] and self.config['transactional_id']: - raise Errors.KafkaConfigurationError("Cannot set transactional_id without enable_idempotence.") - - if self.config['transactional_id']: - self.config['enable_idempotence'] = True - - if self.config['enable_idempotence']: - assert self.config['api_version'] >= (0, 11), "Transactional/Idempotent producer requires >= Kafka 0.11 Brokers" - - self._transaction_manager = TransactionManager( - transactional_id=self.config['transactional_id'], - 
transaction_timeout_ms=self.config['transaction_timeout_ms'], - retry_backoff_ms=self.config['retry_backoff_ms'], - api_version=self.config['api_version'], - metadata=self._metadata, - ) - if self._transaction_manager.is_transactional(): - log.info("%s: Instantiated a transactional producer.", str(self)) - else: - log.info("%s: Instantiated an idempotent producer.", str(self)) - - if self.config['retries'] == 0: - raise Errors.KafkaConfigurationError("Must set 'retries' to non-zero when using the idempotent producer.") - - if 'max_in_flight_requests_per_connection' not in user_provided_configs: - log.info("%s: Overriding the default 'max_in_flight_requests_per_connection' to 1 since idempontence is enabled.", str(self)) - self.config['max_in_flight_requests_per_connection'] = 1 - elif self.config['max_in_flight_requests_per_connection'] != 1: - raise Errors.KafkaConfigurationError("Must set 'max_in_flight_requests_per_connection' to 1 in order" - " to use the idempotent producer." - " Otherwise we cannot guarantee idempotence.") - - if 'acks' not in user_provided_configs: - log.info("%s: Overriding the default 'acks' config to 'all' since idempotence is enabled", str(self)) - self.config['acks'] = -1 - elif self.config['acks'] != -1: - raise Errors.KafkaConfigurationError("Must set 'acks' config to 'all' in order to use the idempotent" - " producer. Otherwise we cannot guarantee idempotence") - - message_version = self.max_usable_produce_magic(self.config['api_version']) - self._accumulator = RecordAccumulator( - transaction_manager=self._transaction_manager, - message_version=message_version, - **self.config) guarantee_message_order = bool(self.config['max_in_flight_requests_per_connection'] == 1) self._sender = Sender(client, self._metadata, - self._accumulator, - metrics=self._metrics, - transaction_manager=self._transaction_manager, + self._accumulator, self._metrics, guarantee_message_order=guarantee_message_order, **self.config) self._sender.daemon = True @@ -561,7 +415,7 @@ class KafkaProducer(object): self._cleanup = self._cleanup_factory() atexit.register(self._cleanup) - log.debug("%s: Kafka producer started", str(self)) + log.debug("Kafka producer started") def bootstrap_connected(self): """Return True if the bootstrap is connected.""" @@ -572,7 +426,7 @@ class KafkaProducer(object): _self = weakref.proxy(self) def wrapper(): try: - _self.close(timeout=0, null_logger=True) + _self.close(timeout=0) except (ReferenceError, AttributeError): pass return wrapper @@ -595,28 +449,28 @@ class KafkaProducer(object): self._cleanup = None def __del__(self): - self.close(timeout=1, null_logger=True) + # Disable logger during destruction to avoid touching dangling references + class NullLogger(object): + def __getattr__(self, name): + return lambda *args: None - def close(self, timeout=None, null_logger=False): + global log + log = NullLogger() + + self.close() + + def close(self, timeout=None): """Close this producer. Arguments: timeout (float, optional): timeout in seconds to wait for completion. 
""" - if null_logger: - # Disable logger during destruction to avoid touching dangling references - class NullLogger(object): - def __getattr__(self, name): - return lambda *args: None - - global log - log = NullLogger() # drop our atexit handler now to avoid leaks self._unregister_cleanup() if not hasattr(self, '_closed') or self._closed: - log.info('%s: Kafka producer closed', str(self)) + log.info('Kafka producer closed') return if timeout is None: # threading.TIMEOUT_MAX is available in Python3.3+ @@ -626,16 +480,15 @@ class KafkaProducer(object): else: assert timeout >= 0 - log.info("%s: Closing the Kafka producer with %s secs timeout.", str(self), timeout) - self.flush(timeout) + log.info("Closing the Kafka producer with %s secs timeout.", timeout) invoked_from_callback = bool(threading.current_thread() is self._sender) if timeout > 0: if invoked_from_callback: - log.warning("%s: Overriding close timeout %s secs to 0 in order to" + log.warning("Overriding close timeout %s secs to 0 in order to" " prevent useless blocking due to self-join. This" " means you have incorrectly invoked close with a" " non-zero timeout from the producer call-back.", - str(self), timeout) + timeout) else: # Try to close gracefully. if self._sender is not None: @@ -643,13 +496,12 @@ class KafkaProducer(object): self._sender.join(timeout) if self._sender is not None and self._sender.is_alive(): - log.info("%s: Proceeding to force close the producer since pending" + log.info("Proceeding to force close the producer since pending" " requests could not be completed within timeout %s.", - str(self), timeout) + timeout) self._sender.force_close() - if self._metrics: - self._metrics.close() + self._metrics.close() try: self.config['key_serializer'].close() except AttributeError: @@ -659,23 +511,23 @@ class KafkaProducer(object): except AttributeError: pass self._closed = True - log.debug("%s: The Kafka producer has closed.", str(self)) + log.debug("The Kafka producer has closed.") def partitions_for(self, topic): """Returns set of all known partitions for the topic.""" - return self._wait_on_metadata(topic, self.config['max_block_ms']) + max_wait = self.config['max_block_ms'] / 1000.0 + return self._wait_on_metadata(topic, max_wait) - @classmethod - def max_usable_produce_magic(cls, api_version): - if api_version >= (0, 11): + def _max_usable_produce_magic(self): + if self.config['api_version'] >= (0, 11): return 2 - elif api_version >= (0, 10, 0): + elif self.config['api_version'] >= (0, 10): return 1 else: return 0 def _estimate_size_in_bytes(self, key, value, headers=[]): - magic = self.max_usable_produce_magic(self.config['api_version']) + magic = self._max_usable_produce_magic() if magic == 2: return DefaultRecordBatchBuilder.estimate_size_in_bytes( key, value, headers) @@ -683,114 +535,6 @@ class KafkaProducer(object): return LegacyRecordBatchBuilder.estimate_size_in_bytes( magic, self.config['compression_type'], key, value) - def init_transactions(self): - """ - Needs to be called before any other methods when the transactional.id is set in the configuration. - - This method does the following: - 1. Ensures any transactions initiated by previous instances of the producer with the same - transactional_id are completed. If the previous instance had failed with a transaction in - progress, it will be aborted. If the last transaction had begun completion, - but not yet finished, this method awaits its completion. - 2. 
Gets the internal producer id and epoch, used in all future transactional - messages issued by the producer. - - Note that this method will raise KafkaTimeoutError if the transactional state cannot - be initialized before expiration of `max_block_ms`. - - Retrying after a KafkaTimeoutError will continue to wait for the prior request to succeed or fail. - Retrying after any other exception will start a new initialization attempt. - Retrying after a successful initialization will do nothing. - - Raises: - IllegalStateError: if no transactional_id has been configured - AuthorizationError: fatal error indicating that the configured - transactional_id is not authorized. - KafkaError: if the producer has encountered a previous fatal error or for any other unexpected error - KafkaTimeoutError: if the time taken for initialize the transaction has surpassed `max.block.ms`. - """ - if not self._transaction_manager: - raise Errors.IllegalStateError("Cannot call init_transactions without setting a transactional_id.") - if self._init_transactions_result is None: - self._init_transactions_result = self._transaction_manager.initialize_transactions() - self._sender.wakeup() - - try: - if not self._init_transactions_result.wait(timeout_ms=self.config['max_block_ms']): - raise Errors.KafkaTimeoutError("Timeout expired while initializing transactional state in %s ms." % (self.config['max_block_ms'],)) - finally: - if self._init_transactions_result.failed: - self._init_transactions_result = None - - def begin_transaction(self): - """ Should be called before the start of each new transaction. - - Note that prior to the first invocation of this method, - you must invoke `init_transactions()` exactly one time. - - Raises: - ProducerFencedError if another producer is with the same - transactional_id is active. - """ - # Set the transactional bit in the producer. - if not self._transaction_manager: - raise Errors.IllegalStateError("Cannot use transactional methods without enabling transactions") - self._transaction_manager.begin_transaction() - - def send_offsets_to_transaction(self, offsets, consumer_group_id): - """ - Sends a list of consumed offsets to the consumer group coordinator, and also marks - those offsets as part of the current transaction. These offsets will be considered - consumed only if the transaction is committed successfully. - - This method should be used when you need to batch consumed and produced messages - together, typically in a consume-transform-produce pattern. - - Arguments: - offsets ({TopicPartition: OffsetAndMetadata}): map of topic-partition -> offsets to commit - as part of current transaction. - consumer_group_id (str): Name of consumer group for offsets commit. - - Raises: - IllegalStateError: if no transactional_id, or transaction has not been started. - ProducerFencedError: fatal error indicating another producer with the same transactional_id is active. - UnsupportedVersionError: fatal error indicating the broker does not support transactions (i.e. if < 0.11). - UnsupportedForMessageFormatError: fatal error indicating the message format used for the offsets - topic on the broker does not support transactions. - AuthorizationError: fatal error indicating that the configured transactional_id is not authorized. 
- KafkaErro:r if the producer has encountered a previous fatal or abortable error, or for any - other unexpected error - """ - if not self._transaction_manager: - raise Errors.IllegalStateError("Cannot use transactional methods without enabling transactions") - result = self._transaction_manager.send_offsets_to_transaction(offsets, consumer_group_id) - self._sender.wakeup() - result.wait() - - def commit_transaction(self): - """ Commits the ongoing transaction. - - Raises: ProducerFencedError if another producer with the same - transactional_id is active. - """ - if not self._transaction_manager: - raise Errors.IllegalStateError("Cannot commit transaction since transactions are not enabled") - result = self._transaction_manager.begin_commit() - self._sender.wakeup() - result.wait() - - def abort_transaction(self): - """ Aborts the ongoing transaction. - - Raises: ProducerFencedError if another producer with the same - transactional_id is active. - """ - if not self._transaction_manager: - raise Errors.IllegalStateError("Cannot abort transaction since transactions are not enabled.") - result = self._transaction_manager.begin_abort() - self._sender.wakeup() - result.wait() - def send(self, topic, value=None, key=None, headers=None, partition=None, timestamp_ms=None): """Publish a message to a topic. @@ -823,58 +567,44 @@ class KafkaProducer(object): Raises: KafkaTimeoutError: if unable to fetch topic metadata, or unable to obtain memory buffer prior to configured max_block_ms - TypeError: if topic is not a string - ValueError: if topic is invalid: must be chars (a-zA-Z0-9._-), and less than 250 length - AssertionError: if KafkaProducer is closed, or key and value are both None """ - assert not self._closed, 'KafkaProducer already closed!' assert value is not None or self.config['api_version'] >= (0, 8, 1), ( 'Null messages require kafka >= 0.8.1') assert not (value is None and key is None), 'Need at least one: key or value' - ensure_valid_topic_name(topic) key_bytes = value_bytes = None - timer = Timer(self.config['max_block_ms'], "Failed to assign partition for message in max_block_ms.") try: - assigned_partition = None - while assigned_partition is None and not timer.expired: - self._wait_on_metadata(topic, timer.timeout_ms) + self._wait_on_metadata(topic, self.config['max_block_ms'] / 1000.0) - key_bytes = self._serialize( - self.config['key_serializer'], - topic, key) - value_bytes = self._serialize( - self.config['value_serializer'], - topic, value) - assert type(key_bytes) in (bytes, bytearray, memoryview, type(None)) - assert type(value_bytes) in (bytes, bytearray, memoryview, type(None)) + key_bytes = self._serialize( + self.config['key_serializer'], + topic, key) + value_bytes = self._serialize( + self.config['value_serializer'], + topic, value) + assert type(key_bytes) in (bytes, bytearray, memoryview, type(None)) + assert type(value_bytes) in (bytes, bytearray, memoryview, type(None)) - assigned_partition = self._partition(topic, partition, key, value, - key_bytes, value_bytes) - if assigned_partition is None: - raise Errors.KafkaTimeoutError("Failed to assign partition for message after %s secs." 
% timer.elapsed_ms / 1000) - else: - partition = assigned_partition + partition = self._partition(topic, partition, key, value, + key_bytes, value_bytes) if headers is None: headers = [] - assert isinstance(headers, list) - assert all(isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], str) and isinstance(item[1], bytes) for item in headers) + assert type(headers) == list + assert all(type(item) == tuple and len(item) == 2 and type(item[0]) == str and type(item[1]) == bytes for item in headers) message_size = self._estimate_size_in_bytes(key_bytes, value_bytes, headers) self._ensure_valid_record_size(message_size) tp = TopicPartition(topic, partition) - log.debug("%s: Sending (key=%r value=%r headers=%r) to %s", str(self), key, value, headers, tp) - - if self._transaction_manager and self._transaction_manager.is_transactional(): - self._transaction_manager.maybe_add_partition_to_transaction(tp) - + log.debug("Sending (key=%r value=%r headers=%r) to %s", key, value, headers, tp) result = self._accumulator.append(tp, timestamp_ms, - key_bytes, value_bytes, headers) + key_bytes, value_bytes, headers, + self.config['max_block_ms'], + estimated_size=message_size) future, batch_is_full, new_batch_created = result if batch_is_full or new_batch_created: - log.debug("%s: Waking up the sender since %s is either full or" - " getting a new batch", str(self), tp) + log.debug("Waking up the sender since %s is either full or" + " getting a new batch", tp) self._sender.wakeup() return future @@ -882,7 +612,7 @@ class KafkaProducer(object): # for API exceptions return them in the future, # for other exceptions raise directly except Errors.BrokerResponseError as e: - log.error("%s: Exception occurred during message send: %s", str(self), e) + log.debug("Exception occurred during message send: %s", e) return FutureRecordMetadata( FutureProduceResult(TopicPartition(topic, partition)), -1, None, None, @@ -913,7 +643,7 @@ class KafkaProducer(object): KafkaTimeoutError: failure to flush buffered records within the provided timeout """ - log.debug("%s: Flushing accumulated records in producer.", str(self)) + log.debug("Flushing accumulated records in producer.") # trace self._accumulator.begin_flush() self._sender.wakeup() self._accumulator.await_flush_completion(timeout=timeout) @@ -925,8 +655,13 @@ class KafkaProducer(object): "The message is %d bytes when serialized which is larger than" " the maximum request size you have configured with the" " max_request_size configuration" % (size,)) + if size > self.config['buffer_memory']: + raise Errors.MessageSizeTooLargeError( + "The message is %d bytes when serialized which is larger than" + " the total memory buffer you have configured with the" + " buffer_memory configuration." % (size,)) - def _wait_on_metadata(self, topic, max_wait_ms): + def _wait_on_metadata(self, topic, max_wait): """ Wait for cluster metadata including partitions for the given topic to be available. @@ -944,31 +679,32 @@ class KafkaProducer(object): """ # add topic to metadata topic list if it is not there already. self._sender.add_topic(topic) - timer = Timer(max_wait_ms, "Failed to update metadata after %.1f secs." 
% (max_wait_ms / 1000,)) + begin = time.time() + elapsed = 0.0 metadata_event = None while True: partitions = self._metadata.partitions_for_topic(topic) if partitions is not None: return partitions - timer.maybe_raise() + if not metadata_event: metadata_event = threading.Event() - log.debug("%s: Requesting metadata update for topic %s", str(self), topic) + log.debug("Requesting metadata update for topic %s", topic) + metadata_event.clear() future = self._metadata.request_update() future.add_both(lambda e, *args: e.set(), metadata_event) self._sender.wakeup() - metadata_event.wait(timer.timeout_ms / 1000) - if not future.is_done: + metadata_event.wait(max_wait - elapsed) + elapsed = time.time() - begin + if not metadata_event.is_set(): raise Errors.KafkaTimeoutError( - "Failed to update metadata after %.1f secs." % (max_wait_ms / 1000,)) - elif future.failed() and not future.retriable(): - raise future.exception + "Failed to update metadata after %.1f secs." % (max_wait,)) elif topic in self._metadata.unauthorized_topics: - raise Errors.TopicAuthorizationFailedError(set([topic])) + raise Errors.TopicAuthorizationFailedError(topic) else: - log.debug("%s: _wait_on_metadata woke after %s secs.", str(self), timer.elapsed_ms / 1000) + log.debug("_wait_on_metadata woke after %s secs.", elapsed) def _serialize(self, f, topic, data): if not f: @@ -979,18 +715,16 @@ class KafkaProducer(object): def _partition(self, topic, partition, key, value, serialized_key, serialized_value): - all_partitions = self._metadata.partitions_for_topic(topic) - available = self._metadata.available_partitions_for_topic(topic) - if all_partitions is None or available is None: - return None if partition is not None: assert partition >= 0 - assert partition in all_partitions, 'Unrecognized partition' + assert partition in self._metadata.partitions_for_topic(topic), 'Unrecognized partition' return partition + all_partitions = sorted(self._metadata.partitions_for_topic(topic)) + available = list(self._metadata.available_partitions_for_topic(topic)) return self.config['partitioner'](serialized_key, - sorted(all_partitions), - list(available)) + all_partitions, + available) def metrics(self, raw=False): """Get metrics on producer performance. @@ -1002,8 +736,6 @@ class KafkaProducer(object): This is an unstable interface. It may change in future releases without warning. 
""" - if not self._metrics: - return if raw: return self._metrics.metrics.copy() @@ -1015,6 +747,3 @@ class KafkaProducer(object): metrics[k.group][k.name] = {} metrics[k.group][k.name] = v.value() return metrics - - def __str__(self): - return "" % (self.config['client_id'], self.config['transactional_id']) diff --git a/venv/lib/python3.12/site-packages/kafka/producer/record_accumulator.py b/venv/lib/python3.12/site-packages/kafka/producer/record_accumulator.py index 3a4e601..a2aa0e8 100644 --- a/venv/lib/python3.12/site-packages/kafka/producer/record_accumulator.py +++ b/venv/lib/python3.12/site-packages/kafka/producer/record_accumulator.py @@ -1,4 +1,4 @@ -from __future__ import absolute_import, division +from __future__ import absolute_import import collections import copy @@ -6,14 +6,8 @@ import logging import threading import time -try: - # enum in stdlib as of py3.4 - from enum import IntEnum # pylint: disable=import-error -except ImportError: - # vendored backport module - from kafka.vendor.enum34 import IntEnum - import kafka.errors as Errors +from kafka.producer.buffer import SimpleBufferPool from kafka.producer.future import FutureRecordMetadata, FutureProduceResult from kafka.record.memory_records import MemoryRecordsBuilder from kafka.structs import TopicPartition @@ -41,16 +35,10 @@ class AtomicInteger(object): return self._val -class FinalState(IntEnum): - ABORTED = 0 - FAILED = 1 - SUCCEEDED = 2 - - class ProducerBatch(object): - def __init__(self, tp, records, now=None): - now = time.time() if now is None else now + def __init__(self, tp, records, buffer): self.max_record_size = 0 + now = time.time() self.created = now self.drained = None self.attempts = 0 @@ -60,120 +48,81 @@ class ProducerBatch(object): self.topic_partition = tp self.produce_future = FutureProduceResult(tp) self._retry = False - self._final_state = None - - @property - def final_state(self): - return self._final_state + self._buffer = buffer # We only save it, we don't write to it @property def record_count(self): return self.records.next_offset() - @property - def producer_id(self): - return self.records.producer_id if self.records else None - - @property - def producer_epoch(self): - return self.records.producer_epoch if self.records else None - - @property - def has_sequence(self): - return self.records.has_sequence if self.records else False - - def try_append(self, timestamp_ms, key, value, headers, now=None): + def try_append(self, timestamp_ms, key, value, headers): metadata = self.records.append(timestamp_ms, key, value, headers) if metadata is None: return None - now = time.time() if now is None else now self.max_record_size = max(self.max_record_size, metadata.size) - self.last_append = now - future = FutureRecordMetadata( - self.produce_future, - metadata.offset, - metadata.timestamp, - metadata.crc, - len(key) if key is not None else -1, - len(value) if value is not None else -1, - sum(len(h_key.encode("utf-8")) + len(h_val) for h_key, h_val in headers) if headers else -1) + self.last_append = time.time() + future = FutureRecordMetadata(self.produce_future, metadata.offset, + metadata.timestamp, metadata.crc, + len(key) if key is not None else -1, + len(value) if value is not None else -1, + sum(len(h_key.encode("utf-8")) + len(h_val) for h_key, h_val in headers) if headers else -1) return future - def abort(self, exception): - """Abort the batch and complete the future and callbacks.""" - if self._final_state is not None: - raise Errors.IllegalStateError("Batch has already been completed in 
final state: %s" % self._final_state) - self._final_state = FinalState.ABORTED - - log.debug("Aborting batch for partition %s: %s", self.topic_partition, exception) - self._complete_future(-1, -1, exception) - - def done(self, base_offset=None, timestamp_ms=None, exception=None): - """ - Finalize the state of a batch. Final state, once set, is immutable. This function may be called - once or twice on a batch. It may be called twice if - 1. An inflight batch expires before a response from the broker is received. The batch's final - state is set to FAILED. But it could succeed on the broker and second time around batch.done() may - try to set SUCCEEDED final state. - - 2. If a transaction abortion happens or if the producer is closed forcefully, the final state is - ABORTED but again it could succeed if broker responds with a success. - - Attempted transitions from [FAILED | ABORTED] --> SUCCEEDED are logged. - Attempted transitions from one failure state to the same or a different failed state are ignored. - Attempted transitions from SUCCEEDED to the same or a failed state throw an exception. - """ - final_state = FinalState.SUCCEEDED if exception is None else FinalState.FAILED - if self._final_state is None: - self._final_state = final_state - if final_state is FinalState.SUCCEEDED: - log.debug("Successfully produced messages to %s with base offset %s", self.topic_partition, base_offset) - else: - log.warning("Failed to produce messages to topic-partition %s with base offset %s: %s", - self.topic_partition, base_offset, exception) - self._complete_future(base_offset, timestamp_ms, exception) - return True - - elif self._final_state is not FinalState.SUCCEEDED: - if final_state is FinalState.SUCCEEDED: - # Log if a previously unsuccessful batch succeeded later on. - log.debug("ProduceResponse returned %s for %s after batch with base offset %s had already been %s.", - final_state, self.topic_partition, base_offset, self._final_state) - else: - # FAILED --> FAILED and ABORTED --> FAILED transitions are ignored. - log.debug("Ignored state transition %s -> %s for %s batch with base offset %s", - self._final_state, final_state, self.topic_partition, base_offset) - else: - # A SUCCESSFUL batch must not attempt another state change. 
- raise Errors.IllegalStateError("A %s batch must not attempt another state change to %s" % (self._final_state, final_state)) - return False - - def _complete_future(self, base_offset, timestamp_ms, exception): + def done(self, base_offset=None, timestamp_ms=None, exception=None, log_start_offset=None, global_error=None): + level = logging.DEBUG if exception is None else logging.WARNING + log.log(level, "Produced messages to topic-partition %s with base offset" + " %s log start offset %s and error %s.", self.topic_partition, base_offset, + log_start_offset, global_error) # trace if self.produce_future.is_done: - raise Errors.IllegalStateError('Batch is already closed!') + log.warning('Batch is already closed -- ignoring batch.done()') + return elif exception is None: - self.produce_future.success((base_offset, timestamp_ms)) + self.produce_future.success((base_offset, timestamp_ms, log_start_offset)) else: self.produce_future.failure(exception) - def has_reached_delivery_timeout(self, delivery_timeout_ms, now=None): - now = time.time() if now is None else now - return delivery_timeout_ms / 1000 <= now - self.created + def maybe_expire(self, request_timeout_ms, retry_backoff_ms, linger_ms, is_full): + """Expire batches if metadata is not available + + A batch whose metadata is not available should be expired if one + of the following is true: + + * the batch is not in retry AND request timeout has elapsed after + it is ready (full or linger.ms has reached). + + * the batch is in retry AND request timeout has elapsed after the + backoff period ended. + """ + now = time.time() + since_append = now - self.last_append + since_ready = now - (self.created + linger_ms / 1000.0) + since_backoff = now - (self.last_attempt + retry_backoff_ms / 1000.0) + timeout = request_timeout_ms / 1000.0 + + error = None + if not self.in_retry() and is_full and timeout < since_append: + error = "%d seconds have passed since last append" % (since_append,) + elif not self.in_retry() and timeout < since_ready: + error = "%d seconds have passed since batch creation plus linger time" % (since_ready,) + elif self.in_retry() and timeout < since_backoff: + error = "%d seconds have passed since last attempt plus backoff time" % (since_backoff,) + + if error: + self.records.close() + self.done(-1, None, Errors.KafkaTimeoutError( + "Batch for %s containing %s record(s) expired: %s" % ( + self.topic_partition, self.records.next_offset(), error))) + return True + return False def in_retry(self): return self._retry - def retry(self, now=None): - now = time.time() if now is None else now + def set_retry(self): self._retry = True - self.attempts += 1 - self.last_attempt = now - self.last_append = now - @property - def is_done(self): - return self.produce_future.is_done + def buffer(self): + return self._buffer def __str__(self): return 'ProducerBatch(topic_partition=%s, record_count=%d)' % ( @@ -194,6 +143,12 @@ class RecordAccumulator(object): A small batch size will make batching less common and may reduce throughput (a batch size of zero will disable batching entirely). Default: 16384 + buffer_memory (int): The total bytes of memory the producer should use + to buffer records waiting to be sent to the server. If records are + sent faster than they can be delivered to the server the producer + will block up to max_block_ms, raising an exception on timeout. + In the current implementation, this setting is an approximation. 
+ Default: 33554432 (32MB) compression_attrs (int): The compression type for all data generated by the producer. Valid values are gzip(1), snappy(2), lz4(3), or none(0). @@ -201,7 +156,7 @@ class RecordAccumulator(object): will also impact the compression ratio (more batching means better compression). Default: None. linger_ms (int): An artificial delay time to add before declaring a - record batch (that isn't full) ready for sending. This allows + messageset (that isn't full) ready for sending. This allows time for more records to arrive. Setting a non-zero linger_ms will trade off some latency for potentially better throughput due to more batching (and hence fewer, larger requests). @@ -211,14 +166,14 @@ class RecordAccumulator(object): all retries in a short period of time. Default: 100 """ DEFAULT_CONFIG = { + 'buffer_memory': 33554432, 'batch_size': 16384, 'compression_attrs': 0, 'linger_ms': 0, - 'request_timeout_ms': 30000, - 'delivery_timeout_ms': 120000, 'retry_backoff_ms': 100, - 'transaction_manager': None, - 'message_version': 2, + 'message_version': 0, + 'metrics': None, + 'metric_group_prefix': 'producer-metrics', } def __init__(self, **configs): @@ -228,37 +183,22 @@ class RecordAccumulator(object): self.config[key] = configs.pop(key) self._closed = False - self._transaction_manager = self.config['transaction_manager'] self._flushes_in_progress = AtomicInteger() self._appends_in_progress = AtomicInteger() self._batches = collections.defaultdict(collections.deque) # TopicPartition: [ProducerBatch] self._tp_locks = {None: threading.Lock()} # TopicPartition: Lock, plus a lock to add entries + self._free = SimpleBufferPool(self.config['buffer_memory'], + self.config['batch_size'], + metrics=self.config['metrics'], + metric_group_prefix=self.config['metric_group_prefix']) self._incomplete = IncompleteProducerBatches() # The following variables should only be accessed by the sender thread, # so we don't need to protect them w/ locking. self.muted = set() self._drain_index = 0 - self._next_batch_expiry_time_ms = float('inf') - if self.config['delivery_timeout_ms'] < self.config['linger_ms'] + self.config['request_timeout_ms']: - raise Errors.KafkaConfigurationError("Must set delivery_timeout_ms higher than linger_ms + request_timeout_ms") - - @property - def delivery_timeout_ms(self): - return self.config['delivery_timeout_ms'] - - @property - def next_expiry_time_ms(self): - return self._next_batch_expiry_time_ms - - def _tp_lock(self, tp): - if tp not in self._tp_locks: - with self._tp_locks[None]: - if tp not in self._tp_locks: - self._tp_locks[tp] = threading.Lock() - return self._tp_locks[tp] - - def append(self, tp, timestamp_ms, key, value, headers, now=None): + def append(self, tp, timestamp_ms, key, value, headers, max_time_to_block_ms, + estimated_size=0): """Add a record to the accumulator, return the append result. 
The append result will contain the future metadata, and flag for @@ -271,53 +211,59 @@ class RecordAccumulator(object): key (bytes): The key for the record value (bytes): The value for the record headers (List[Tuple[str, bytes]]): The header fields for the record + max_time_to_block_ms (int): The maximum time in milliseconds to + block for buffer memory to be available Returns: tuple: (future, batch_is_full, new_batch_created) """ assert isinstance(tp, TopicPartition), 'not TopicPartition' assert not self._closed, 'RecordAccumulator is closed' - now = time.time() if now is None else now # We keep track of the number of appending thread to make sure we do # not miss batches in abortIncompleteBatches(). self._appends_in_progress.increment() try: - with self._tp_lock(tp): + if tp not in self._tp_locks: + with self._tp_locks[None]: + if tp not in self._tp_locks: + self._tp_locks[tp] = threading.Lock() + + with self._tp_locks[tp]: # check if we have an in-progress batch dq = self._batches[tp] if dq: last = dq[-1] - future = last.try_append(timestamp_ms, key, value, headers, now=now) + future = last.try_append(timestamp_ms, key, value, headers) if future is not None: batch_is_full = len(dq) > 1 or last.records.is_full() return future, batch_is_full, False - with self._tp_lock(tp): + size = max(self.config['batch_size'], estimated_size) + log.debug("Allocating a new %d byte message buffer for %s", size, tp) # trace + buf = self._free.allocate(size, max_time_to_block_ms) + with self._tp_locks[tp]: # Need to check if producer is closed again after grabbing the # dequeue lock. assert not self._closed, 'RecordAccumulator is closed' if dq: last = dq[-1] - future = last.try_append(timestamp_ms, key, value, headers, now=now) + future = last.try_append(timestamp_ms, key, value, headers) if future is not None: # Somebody else found us a batch, return the one we # waited for! Hopefully this doesn't happen often... + self._free.deallocate(buf) batch_is_full = len(dq) > 1 or last.records.is_full() return future, batch_is_full, False - if self._transaction_manager and self.config['message_version'] < 2: - raise Errors.UnsupportedVersionError("Attempting to use idempotence with a broker which" - " does not support the required message format (v2)." - " The broker must be version 0.11 or later.") records = MemoryRecordsBuilder( self.config['message_version'], self.config['compression_attrs'], self.config['batch_size'] ) - batch = ProducerBatch(tp, records, now=now) - future = batch.try_append(timestamp_ms, key, value, headers, now=now) + batch = ProducerBatch(tp, records, buf) + future = batch.try_append(timestamp_ms, key, value, headers) if not future: raise Exception() @@ -328,43 +274,79 @@ class RecordAccumulator(object): finally: self._appends_in_progress.decrement() - def reset_next_batch_expiry_time(self): - self._next_batch_expiry_time_ms = float('inf') + def abort_expired_batches(self, request_timeout_ms, cluster): + """Abort the batches that have been sitting in RecordAccumulator for + more than the configured request_timeout due to metadata being + unavailable. 
- def maybe_update_next_batch_expiry_time(self, batch): - self._next_batch_expiry_time_ms = min(self._next_batch_expiry_time_ms, batch.created * 1000 + self.delivery_timeout_ms) + Arguments: + request_timeout_ms (int): milliseconds to timeout + cluster (ClusterMetadata): current metadata for kafka cluster - def expired_batches(self, now=None): - """Get a list of batches which have been sitting in the accumulator too long and need to be expired.""" + Returns: + list of ProducerBatch that were expired + """ expired_batches = [] + to_remove = [] + count = 0 for tp in list(self._batches.keys()): - with self._tp_lock(tp): + assert tp in self._tp_locks, 'TopicPartition not in locks dict' + + # We only check if the batch should be expired if the partition + # does not have a batch in flight. This is to avoid the later + # batches get expired when an earlier batch is still in progress. + # This protection only takes effect when user sets + # max.in.flight.request.per.connection=1. Otherwise the expiration + # order is not guranteed. + if tp in self.muted: + continue + + with self._tp_locks[tp]: # iterate over the batches and expire them if they have stayed # in accumulator for more than request_timeout_ms dq = self._batches[tp] - while dq: - batch = dq[0] - if batch.has_reached_delivery_timeout(self.delivery_timeout_ms, now=now): - dq.popleft() - batch.records.close() + for batch in dq: + is_full = bool(bool(batch != dq[-1]) or batch.records.is_full()) + # check if the batch is expired + if batch.maybe_expire(request_timeout_ms, + self.config['retry_backoff_ms'], + self.config['linger_ms'], + is_full): expired_batches.append(batch) + to_remove.append(batch) + count += 1 + self.deallocate(batch) else: # Stop at the first batch that has not expired. - self.maybe_update_next_batch_expiry_time(batch) break + + # Python does not allow us to mutate the dq during iteration + # Assuming expired batches are infrequent, this is better than + # creating a new copy of the deque for iteration on every loop + if to_remove: + for batch in to_remove: + dq.remove(batch) + to_remove = [] + + if expired_batches: + log.warning("Expired %d batches in accumulator", count) # trace + return expired_batches - def reenqueue(self, batch, now=None): - """ - Re-enqueue the given record batch in the accumulator. In Sender._complete_batch method, we check - whether the batch has reached delivery_timeout_ms or not. Hence we do not do the delivery timeout check here. 
- """ - batch.retry(now=now) - with self._tp_lock(batch.topic_partition): - dq = self._batches[batch.topic_partition] + def reenqueue(self, batch): + """Re-enqueue the given record batch in the accumulator to retry.""" + now = time.time() + batch.attempts += 1 + batch.last_attempt = now + batch.last_append = now + batch.set_retry() + assert batch.topic_partition in self._tp_locks, 'TopicPartition not in locks dict' + assert batch.topic_partition in self._batches, 'TopicPartition not in batches' + dq = self._batches[batch.topic_partition] + with self._tp_locks[batch.topic_partition]: dq.appendleft(batch) - def ready(self, cluster, now=None): + def ready(self, cluster): """ Get a list of nodes whose partitions are ready to be sent, and the earliest time at which any non-sendable partition will be ready; @@ -398,8 +380,9 @@ class RecordAccumulator(object): ready_nodes = set() next_ready_check = 9999999.99 unknown_leaders_exist = False - now = time.time() if now is None else now + now = time.time() + exhausted = bool(self._free.queued() > 0) # several threads are accessing self._batches -- to simplify # concurrent access, we iterate over a snapshot of partitions # and lock each partition separately as needed @@ -414,23 +397,23 @@ class RecordAccumulator(object): elif tp in self.muted: continue - with self._tp_lock(tp): + with self._tp_locks[tp]: dq = self._batches[tp] if not dq: continue batch = dq[0] - retry_backoff = self.config['retry_backoff_ms'] / 1000 - linger = self.config['linger_ms'] / 1000 - backing_off = bool(batch.attempts > 0 - and (batch.last_attempt + retry_backoff) > now) + retry_backoff = self.config['retry_backoff_ms'] / 1000.0 + linger = self.config['linger_ms'] / 1000.0 + backing_off = bool(batch.attempts > 0 and + batch.last_attempt + retry_backoff > now) waited_time = now - batch.last_attempt time_to_wait = retry_backoff if backing_off else linger time_left = max(time_to_wait - waited_time, 0) full = bool(len(dq) > 1 or batch.records.is_full()) expired = bool(waited_time >= time_to_wait) - sendable = (full or expired or self._closed or - self.flush_in_progress()) + sendable = (full or expired or exhausted or self._closed or + self._flush_in_progress()) if sendable and not backing_off: ready_nodes.add(leader) @@ -444,98 +427,16 @@ class RecordAccumulator(object): return ready_nodes, next_ready_check, unknown_leaders_exist - def has_undrained(self): - """Check whether there are any batches which haven't been drained""" + def has_unsent(self): + """Return whether there is any unsent record in the accumulator.""" for tp in list(self._batches.keys()): - with self._tp_lock(tp): + with self._tp_locks[tp]: dq = self._batches[tp] if len(dq): return True return False - def _should_stop_drain_batches_for_partition(self, first, tp): - if self._transaction_manager: - if not self._transaction_manager.is_send_to_partition_allowed(tp): - return True - if not self._transaction_manager.producer_id_and_epoch.is_valid: - # we cannot send the batch until we have refreshed the PID - log.debug("Waiting to send ready batches because transaction producer id is not valid") - return True - return False - - def drain_batches_for_one_node(self, cluster, node_id, max_size, now=None): - now = time.time() if now is None else now - size = 0 - ready = [] - partitions = list(cluster.partitions_for_broker(node_id)) - if not partitions: - return ready - # to make starvation less likely this loop doesn't start at 0 - self._drain_index %= len(partitions) - start = None - while start != self._drain_index: - tp 
= partitions[self._drain_index] - if start is None: - start = self._drain_index - self._drain_index += 1 - self._drain_index %= len(partitions) - - # Only proceed if the partition has no in-flight batches. - if tp in self.muted: - continue - - if tp not in self._batches: - continue - - with self._tp_lock(tp): - dq = self._batches[tp] - if len(dq) == 0: - continue - first = dq[0] - backoff = bool(first.attempts > 0 and - first.last_attempt + self.config['retry_backoff_ms'] / 1000 > now) - # Only drain the batch if it is not during backoff - if backoff: - continue - - if (size + first.records.size_in_bytes() > max_size - and len(ready) > 0): - # there is a rare case that a single batch - # size is larger than the request size due - # to compression; in this case we will - # still eventually send this batch in a - # single request - break - else: - if self._should_stop_drain_batches_for_partition(first, tp): - break - - batch = dq.popleft() - if self._transaction_manager and not batch.in_retry(): - # If the batch is in retry, then we should not change the pid and - # sequence number, since this may introduce duplicates. In particular, - # the previous attempt may actually have been accepted, and if we change - # the pid and sequence here, this attempt will also be accepted, causing - # a duplicate. - sequence_number = self._transaction_manager.sequence_number(batch.topic_partition) - log.debug("Dest: %s: %s producer_id=%s epoch=%s sequence=%s", - node_id, batch.topic_partition, - self._transaction_manager.producer_id_and_epoch.producer_id, - self._transaction_manager.producer_id_and_epoch.epoch, - sequence_number) - batch.records.set_producer_state( - self._transaction_manager.producer_id_and_epoch.producer_id, - self._transaction_manager.producer_id_and_epoch.epoch, - sequence_number, - self._transaction_manager.is_transactional() - ) - batch.records.close() - size += batch.records.size_in_bytes() - ready.append(batch) - batch.drained = now - return ready - - def drain(self, cluster, nodes, max_size, now=None): + def drain(self, cluster, nodes, max_size): """ Drain all the data for the given nodes and collate them into a list of batches that will fit within the specified size on a per-node basis. 
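For context on the hunks around this point: both the removed drain_batches_for_one_node helper and the inlined loop in the new drain below implement the same idea, drain per-partition queues round-robin starting from a rotating index (so one partition is not always favored) and stop once the next batch would push the request past max_size, unless the request is still empty. The sketch below is a minimal, self-contained illustration of that strategy only; the function name, the (size, batch) tuples, and the plain deques are invented for this example and are not kafka-python APIs.

# Illustrative sketch only -- simplified stand-in for RecordAccumulator.drain:
# drain per-partition queues round-robin, capping the total bytes per request.
from collections import deque

def drain_round_robin(queues, start_index, max_size):
    """queues: list of (partition, deque of (size_in_bytes, batch)) pairs."""
    drained, total = [], 0
    n = len(queues)
    for step in range(n):
        idx = (start_index + step) % n
        partition, dq = queues[idx]
        if not dq:
            continue
        size, batch = dq[0]
        if drained and total + size > max_size:
            # Request is full; remaining batches wait for the next drain call.
            # (An oversized first batch is still sent alone, mirroring the
            # "single batch larger than the request size" comment above.)
            break
        dq.popleft()
        drained.append((partition, batch))
        total += size
    # Rotate the starting point so partition 0 is not always drained first.
    return drained, (start_index + 1) % n

# Example: with a 100-byte cap, only the first 60-byte batch fits this round.
queues = [("p0", deque([(60, "batch-a")])), ("p1", deque([(60, "batch-b")]))]
ready, next_start = drain_round_robin(queues, 0, 100)
assert [p for p, _ in ready] == ["p0"] and next_start == 1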
@@ -553,17 +454,59 @@ class RecordAccumulator(object): if not nodes: return {} - now = time.time() if now is None else now + now = time.time() batches = {} for node_id in nodes: - batches[node_id] = self.drain_batches_for_one_node(cluster, node_id, max_size, now=now) + size = 0 + partitions = list(cluster.partitions_for_broker(node_id)) + ready = [] + # to make starvation less likely this loop doesn't start at 0 + self._drain_index %= len(partitions) + start = self._drain_index + while True: + tp = partitions[self._drain_index] + if tp in self._batches and tp not in self.muted: + with self._tp_locks[tp]: + dq = self._batches[tp] + if dq: + first = dq[0] + backoff = ( + bool(first.attempts > 0) and + bool(first.last_attempt + + self.config['retry_backoff_ms'] / 1000.0 + > now) + ) + # Only drain the batch if it is not during backoff + if not backoff: + if (size + first.records.size_in_bytes() > max_size + and len(ready) > 0): + # there is a rare case that a single batch + # size is larger than the request size due + # to compression; in this case we will + # still eventually send this batch in a + # single request + break + else: + batch = dq.popleft() + batch.records.close() + size += batch.records.size_in_bytes() + ready.append(batch) + batch.drained = now + + self._drain_index += 1 + self._drain_index %= len(partitions) + if start == self._drain_index: + break + + batches[node_id] = ready return batches def deallocate(self, batch): """Deallocate the record batch.""" self._incomplete.remove(batch) + self._free.deallocate(batch.buffer()) - def flush_in_progress(self): + def _flush_in_progress(self): """Are there any threads currently waiting on a flush?""" return self._flushes_in_progress.get() > 0 @@ -592,10 +535,6 @@ class RecordAccumulator(object): finally: self._flushes_in_progress.decrement() - @property - def has_incomplete(self): - return bool(self._incomplete) - def abort_incomplete_batches(self): """ This function is only called when sender is closed forcefully. It will fail all the @@ -605,41 +544,27 @@ class RecordAccumulator(object): # 1. Avoid losing batches. # 2. Free up memory in case appending threads are blocked on buffer full. # This is a tight loop but should be able to get through very quickly. - error = Errors.IllegalStateError("Producer is closed forcefully.") while True: - self._abort_batches(error) + self._abort_batches() if not self._appends_in_progress.get(): break # After this point, no thread will append any messages because they will see the close # flag set. We need to do the last abort after no thread was appending in case the there was a new # batch appended by the last appending thread. 
- self._abort_batches(error) + self._abort_batches() self._batches.clear() - def _abort_batches(self, error): + def _abort_batches(self): """Go through incomplete batches and abort them.""" + error = Errors.IllegalStateError("Producer is closed forcefully.") for batch in self._incomplete.all(): tp = batch.topic_partition # Close the batch before aborting - with self._tp_lock(tp): + with self._tp_locks[tp]: batch.records.close() - self._batches[tp].remove(batch) - batch.abort(error) + batch.done(exception=error) self.deallocate(batch) - def abort_undrained_batches(self, error): - for batch in self._incomplete.all(): - tp = batch.topic_partition - with self._tp_lock(tp): - aborted = False - if not batch.is_done: - aborted = True - batch.records.close() - self._batches[tp].remove(batch) - if aborted: - batch.abort(error) - self.deallocate(batch) - def close(self): """Close this accumulator and force all the record buffers to be drained.""" self._closed = True @@ -654,21 +579,12 @@ class IncompleteProducerBatches(object): def add(self, batch): with self._lock: - self._incomplete.add(batch) + return self._incomplete.add(batch) def remove(self, batch): with self._lock: - try: - self._incomplete.remove(batch) - except KeyError: - pass + return self._incomplete.remove(batch) def all(self): with self._lock: return list(self._incomplete) - - def __bool__(self): - return bool(self._incomplete) - - - __nonzero__ = __bool__ diff --git a/venv/lib/python3.12/site-packages/kafka/producer/sender.py b/venv/lib/python3.12/site-packages/kafka/producer/sender.py index 7a4c557..35688d3 100644 --- a/venv/lib/python3.12/site-packages/kafka/producer/sender.py +++ b/venv/lib/python3.12/site-packages/kafka/producer/sender.py @@ -2,7 +2,6 @@ from __future__ import absolute_import, division import collections import copy -import heapq import logging import threading import time @@ -12,8 +11,6 @@ from kafka.vendor import six from kafka import errors as Errors from kafka.metrics.measurable import AnonMeasurable from kafka.metrics.stats import Avg, Max, Rate -from kafka.producer.transaction_manager import ProducerIdAndEpoch -from kafka.protocol.init_producer_id import InitProducerIdRequest from kafka.protocol.produce import ProduceRequest from kafka.structs import TopicPartition from kafka.version import __version__ @@ -30,18 +27,14 @@ class Sender(threading.Thread): DEFAULT_CONFIG = { 'max_request_size': 1048576, 'acks': 1, - 'retries': float('inf'), + 'retries': 0, 'request_timeout_ms': 30000, - 'retry_backoff_ms': 100, - 'metrics': None, 'guarantee_message_order': False, - 'transaction_manager': None, - 'transactional_id': None, - 'transaction_timeout_ms': 60000, 'client_id': 'kafka-python-' + __version__, + 'api_version': (0, 8, 0), } - def __init__(self, client, metadata, accumulator, **configs): + def __init__(self, client, metadata, accumulator, metrics, **configs): super(Sender, self).__init__() self.config = copy.copy(self.DEFAULT_CONFIG) for key in self.config: @@ -55,75 +48,32 @@ class Sender(threading.Thread): self._running = True self._force_close = False self._topics_to_add = set() - if self.config['metrics']: - self._sensors = SenderMetrics(self.config['metrics'], self._client, self._metadata) - else: - self._sensors = None - self._transaction_manager = self.config['transaction_manager'] - # A per-partition queue of batches ordered by creation time for tracking the in-flight batches - self._in_flight_batches = collections.defaultdict(list) - - def _maybe_remove_from_inflight_batches(self, batch): - try: - 
queue = self._in_flight_batches[batch.topic_partition] - except KeyError: - return - try: - idx = queue.index((batch.created, batch)) - except ValueError: - return - # https://stackoverflow.com/questions/10162679/python-delete-element-from-heap - queue[idx] = queue[-1] - queue.pop() - heapq.heapify(queue) - - def _get_expired_inflight_batches(self, now=None): - """Get the in-flight batches that has reached delivery timeout.""" - expired_batches = [] - to_remove = [] - for tp, queue in six.iteritems(self._in_flight_batches): - while queue: - _created_at, batch = queue[0] - if batch.has_reached_delivery_timeout(self._accumulator.delivery_timeout_ms): - heapq.heappop(queue) - if batch.final_state is None: - expired_batches.append(batch) - else: - raise Errors.IllegalStateError("%s batch created at %s gets unexpected final state %s" % (batch.topic_partition, batch.created, batch.final_state)) - else: - self._accumulator.maybe_update_next_batch_expiry_time(batch) - break - else: - # Avoid mutating in_flight_batches during iteration - to_remove.append(tp) - for tp in to_remove: - del self._in_flight_batches[tp] - return expired_batches + self._sensors = SenderMetrics(metrics, self._client, self._metadata) def run(self): """The main run loop for the sender thread.""" - log.debug("%s: Starting Kafka producer I/O thread.", str(self)) + log.debug("Starting Kafka producer I/O thread.") # main loop, runs until close is called while self._running: try: self.run_once() except Exception: - log.exception("%s: Uncaught error in kafka producer I/O thread", str(self)) + log.exception("Uncaught error in kafka producer I/O thread") - log.debug("%s: Beginning shutdown of Kafka producer I/O thread, sending" - " remaining records.", str(self)) + log.debug("Beginning shutdown of Kafka producer I/O thread, sending" + " remaining records.") # okay we stopped accepting requests but there may still be # requests in the accumulator or waiting for acknowledgment, # wait until these are completed. while (not self._force_close - and (self._accumulator.has_undrained() + and (self._accumulator.has_unsent() or self._client.in_flight_request_count() > 0)): try: self.run_once() except Exception: - log.exception("%s: Uncaught error in kafka producer I/O thread", str(self)) + log.exception("Uncaught error in kafka producer I/O thread") if self._force_close: # We need to fail all the incomplete batches and wake up the @@ -133,75 +83,38 @@ class Sender(threading.Thread): try: self._client.close() except Exception: - log.exception("%s: Failed to close network client", str(self)) + log.exception("Failed to close network client") - log.debug("%s: Shutdown of Kafka producer I/O thread has completed.", str(self)) + log.debug("Shutdown of Kafka producer I/O thread has completed.") def run_once(self): """Run a single iteration of sending.""" while self._topics_to_add: self._client.add_topic(self._topics_to_add.pop()) - if self._transaction_manager: - try: - if not self._transaction_manager.is_transactional(): - # this is an idempotent producer, so make sure we have a producer id - self._maybe_wait_for_producer_id() - elif self._transaction_manager.has_in_flight_transactional_request() or self._maybe_send_transactional_request(): - # as long as there are outstanding transactional requests, we simply wait for them to return - self._client.poll(timeout_ms=self.config['retry_backoff_ms']) - return - - # do not continue sending if the transaction manager is in a failed state or if there - # is no producer id (for the idempotent case). 
- if self._transaction_manager.has_fatal_error() or not self._transaction_manager.has_producer_id(): - last_error = self._transaction_manager.last_error - if last_error is not None: - self._maybe_abort_batches(last_error) - self._client.poll(timeout_ms=self.config['retry_backoff_ms']) - return - elif self._transaction_manager.has_abortable_error(): - self._accumulator.abort_undrained_batches(self._transaction_manager.last_error) - - except Errors.SaslAuthenticationFailedError as e: - # This is already logged as error, but propagated here to perform any clean ups. - log.debug("%s: Authentication exception while processing transactional request: %s", str(self), e) - self._transaction_manager.authentication_failed(e) - - poll_timeout_ms = self._send_producer_data() - self._client.poll(timeout_ms=poll_timeout_ms) - - def _send_producer_data(self, now=None): - now = time.time() if now is None else now # get the list of partitions with data ready to send - result = self._accumulator.ready(self._metadata, now=now) + result = self._accumulator.ready(self._metadata) ready_nodes, next_ready_check_delay, unknown_leaders_exist = result # if there are any partitions whose leaders are not known yet, force # metadata update if unknown_leaders_exist: - log.debug('%s: Unknown leaders exist, requesting metadata update', str(self)) + log.debug('Unknown leaders exist, requesting metadata update') self._metadata.request_update() # remove any nodes we aren't ready to send to - not_ready_timeout_ms = float('inf') + not_ready_timeout = float('inf') for node in list(ready_nodes): if not self._client.is_ready(node): - node_delay_ms = self._client.connection_delay(node) - log.debug('%s: Node %s not ready; delaying produce of accumulated batch (%f ms)', str(self), node, node_delay_ms) + log.debug('Node %s not ready; delaying produce of accumulated batch', node) self._client.maybe_connect(node, wakeup=False) ready_nodes.remove(node) - not_ready_timeout_ms = min(not_ready_timeout_ms, node_delay_ms) + not_ready_timeout = min(not_ready_timeout, + self._client.connection_delay(node)) # create produce requests batches_by_node = self._accumulator.drain( - self._metadata, ready_nodes, self.config['max_request_size'], now=now) - - for batch_list in six.itervalues(batches_by_node): - for batch in batch_list: - item = (batch.created, batch) - queue = self._in_flight_batches[batch.topic_partition] - heapq.heappush(queue, item) + self._metadata, ready_nodes, self.config['max_request_size']) if self.config['guarantee_message_order']: # Mute all the partitions drained @@ -209,130 +122,42 @@ class Sender(threading.Thread): for batch in batch_list: self._accumulator.muted.add(batch.topic_partition) - self._accumulator.reset_next_batch_expiry_time() - expired_batches = self._accumulator.expired_batches(now=now) - expired_batches.extend(self._get_expired_inflight_batches(now=now)) - - if expired_batches: - log.debug("%s: Expired %s batches in accumulator", str(self), len(expired_batches)) - - # Reset the producer_id if an expired batch has previously been sent to the broker. - # See the documentation of `TransactionState.reset_producer_id` to understand why - # we need to reset the producer id here. 
- if self._transaction_manager and any([batch.in_retry() for batch in expired_batches]): - needs_transaction_state_reset = True - else: - needs_transaction_state_reset = False - + expired_batches = self._accumulator.abort_expired_batches( + self.config['request_timeout_ms'], self._metadata) for expired_batch in expired_batches: - error = Errors.KafkaTimeoutError( - "Expiring %d record(s) for %s: %s ms has passed since batch creation" % ( - expired_batch.record_count, expired_batch.topic_partition, - int((time.time() - expired_batch.created) * 1000))) - self._fail_batch(expired_batch, error, base_offset=-1) - - if self._sensors: - self._sensors.update_produce_request_metrics(batches_by_node) - - if needs_transaction_state_reset: - self._transaction_manager.reset_producer_id() - return 0 + self._sensors.record_errors(expired_batch.topic_partition.topic, expired_batch.record_count) + self._sensors.update_produce_request_metrics(batches_by_node) requests = self._create_produce_requests(batches_by_node) # If we have any nodes that are ready to send + have sendable data, # poll with 0 timeout so this can immediately loop and try sending more - # data. Otherwise, the timeout will be the smaller value between next - # batch expiry time, and the delay time for checking data availability. - # Note that the nodes may have data that isn't yet sendable due to - # lingering, backing off, etc. This specifically does not include nodes with + # data. Otherwise, the timeout is determined by nodes that have + # partitions with data that isn't yet sendable (e.g. lingering, backing + # off). Note that this specifically does not include nodes with # sendable data that aren't ready to send since they would cause busy # looping. - poll_timeout_ms = min(next_ready_check_delay * 1000, - not_ready_timeout_ms, - self._accumulator.next_expiry_time_ms - now * 1000) - if poll_timeout_ms < 0: - poll_timeout_ms = 0 - + poll_timeout_ms = min(next_ready_check_delay * 1000, not_ready_timeout) if ready_nodes: - log.debug("%s: Nodes with data ready to send: %s", str(self), ready_nodes) # trace - log.debug("%s: Created %d produce requests: %s", str(self), len(requests), requests) # trace - # if some partitions are already ready to be sent, the select time - # would be 0; otherwise if some partition already has some data - # accumulated but not ready yet, the select time will be the time - # difference between now and its linger expiry time; otherwise the - # select time will be the time difference between now and the - # metadata expiry time + log.debug("Nodes with data ready to send: %s", ready_nodes) # trace + log.debug("Created %d produce requests: %s", len(requests), requests) # trace poll_timeout_ms = 0 for node_id, request in six.iteritems(requests): batches = batches_by_node[node_id] - log.debug('%s: Sending Produce Request: %r', str(self), request) + log.debug('Sending Produce Request: %r', request) (self._client.send(node_id, request, wakeup=False) .add_callback( self._handle_produce_response, node_id, time.time(), batches) .add_errback( self._failed_produce, batches, node_id)) - return poll_timeout_ms - def _maybe_send_transactional_request(self): - if self._transaction_manager.is_completing() and self._accumulator.has_incomplete: - if self._transaction_manager.is_aborting(): - self._accumulator.abort_undrained_batches(Errors.KafkaError("Failing batch since transaction was aborted")) - # There may still be requests left which are being retried. 
Since we do not know whether they had - # been successfully appended to the broker log, we must resend them until their final status is clear. - # If they had been appended and we did not receive the error, then our sequence number would no longer - # be correct which would lead to an OutOfSequenceNumberError. - if not self._accumulator.flush_in_progress(): - self._accumulator.begin_flush() - - next_request_handler = self._transaction_manager.next_request_handler(self._accumulator.has_incomplete) - if next_request_handler is None: - return False - - log.debug("%s: Sending transactional request %s", str(self), next_request_handler.request) - while not self._force_close: - target_node = None - try: - if next_request_handler.needs_coordinator(): - target_node = self._transaction_manager.coordinator(next_request_handler.coordinator_type) - if target_node is None: - self._transaction_manager.lookup_coordinator_for_request(next_request_handler) - break - elif not self._client.await_ready(target_node, timeout_ms=self.config['request_timeout_ms']): - self._transaction_manager.lookup_coordinator_for_request(next_request_handler) - target_node = None - break - else: - target_node = self._client.least_loaded_node() - if target_node is not None and not self._client.await_ready(target_node, timeout_ms=self.config['request_timeout_ms']): - target_node = None - - if target_node is not None: - if next_request_handler.is_retry: - time.sleep(self.config['retry_backoff_ms'] / 1000) - txn_correlation_id = self._transaction_manager.next_in_flight_request_correlation_id() - future = self._client.send(target_node, next_request_handler.request) - future.add_both(next_request_handler.on_complete, txn_correlation_id) - return True - - except Exception as e: - log.warn("%s: Got an exception when trying to find a node to send a transactional request to. Going to back off and retry: %s", str(self), e) - if next_request_handler.needs_coordinator(): - self._transaction_manager.lookup_coordinator_for_request(next_request_handler) - break - - time.sleep(self.config['retry_backoff_ms'] / 1000) - self._metadata.request_update() - - if target_node is None: - self._transaction_manager.retry(next_request_handler) - - return True - - def _maybe_abort_batches(self, exc): - if self._accumulator.has_incomplete: - log.error("%s: Aborting producer batches due to fatal error: %s", str(self), exc) - self._accumulator.abort_batches(exc) + # if some partitions are already ready to be sent, the select time + # would be 0; otherwise if some partition already has some data + # accumulated but not ready yet, the select time will be the time + # difference between now and its linger expiry time; otherwise the + # select time will be the time difference between now and the + # metadata expiry time + self._client.poll(timeout_ms=poll_timeout_ms) def initiate_close(self): """Start closing the sender (won't complete until all data is sent).""" @@ -355,164 +180,82 @@ class Sender(threading.Thread): self._topics_to_add.add(topic) self.wakeup() - def _maybe_wait_for_producer_id(self): - while not self._transaction_manager.has_producer_id(): - try: - node_id = self._client.least_loaded_node() - if node_id is None or not self._client.await_ready(node_id): - log.debug("%s, Could not find an available broker to send InitProducerIdRequest to." 
+ - " Will back off and try again.", str(self)) - time.sleep(self._client.least_loaded_node_refresh_ms() / 1000) - continue - version = self._client.api_version(InitProducerIdRequest, max_version=1) - request = InitProducerIdRequest[version]( - transactional_id=self.config['transactional_id'], - transaction_timeout_ms=self.config['transaction_timeout_ms'], - ) - response = self._client.send_and_receive(node_id, request) - error_type = Errors.for_code(response.error_code) - if error_type is Errors.NoError: - self._transaction_manager.set_producer_id_and_epoch(ProducerIdAndEpoch(response.producer_id, response.producer_epoch)) - break - elif getattr(error_type, 'retriable', False): - log.debug("%s: Retriable error from InitProducerId response: %s", str(self), error_type.__name__) - if getattr(error_type, 'invalid_metadata', False): - self._metadata.request_update() - else: - self._transaction_manager.transition_to_fatal_error(error_type()) - break - except Errors.KafkaConnectionError: - log.debug("%s: Broker %s disconnected while awaiting InitProducerId response", str(self), node_id) - except Errors.RequestTimedOutError: - log.debug("%s: InitProducerId request to node %s timed out", str(self), node_id) - log.debug("%s: Retry InitProducerIdRequest in %sms.", str(self), self.config['retry_backoff_ms']) - time.sleep(self.config['retry_backoff_ms'] / 1000) - def _failed_produce(self, batches, node_id, error): - log.error("%s: Error sending produce request to node %d: %s", str(self), node_id, error) # trace + log.debug("Error sending produce request to node %d: %s", node_id, error) # trace for batch in batches: - self._complete_batch(batch, error, -1) + self._complete_batch(batch, error, -1, None) def _handle_produce_response(self, node_id, send_time, batches, response): """Handle a produce response.""" # if we have a response, parse it - log.debug('%s: Parsing produce response: %r', str(self), response) + log.debug('Parsing produce response: %r', response) if response: batches_by_partition = dict([(batch.topic_partition, batch) for batch in batches]) for topic, partitions in response.topics: for partition_info in partitions: + global_error = None + log_start_offset = None if response.API_VERSION < 2: partition, error_code, offset = partition_info ts = None elif 2 <= response.API_VERSION <= 4: partition, error_code, offset, ts = partition_info elif 5 <= response.API_VERSION <= 7: - partition, error_code, offset, ts, _log_start_offset = partition_info + partition, error_code, offset, ts, log_start_offset = partition_info else: - # Currently unused / TODO: KIP-467 - partition, error_code, offset, ts, _log_start_offset, _record_errors, _global_error = partition_info + # the ignored parameter is record_error of type list[(batch_index: int, error_message: str)] + partition, error_code, offset, ts, log_start_offset, _, global_error = partition_info tp = TopicPartition(topic, partition) error = Errors.for_code(error_code) batch = batches_by_partition[tp] - self._complete_batch(batch, error, offset, timestamp_ms=ts) + self._complete_batch(batch, error, offset, ts, log_start_offset, global_error) + + if response.API_VERSION > 0: + self._sensors.record_throttle_time(response.throttle_time_ms, node=node_id) else: # this is the acks = 0 case, just complete all requests for batch in batches: - self._complete_batch(batch, None, -1) + self._complete_batch(batch, None, -1, None) - def _fail_batch(self, batch, exception, base_offset=None, timestamp_ms=None): - exception = exception if type(exception) is not type 
else exception() - if self._transaction_manager: - if isinstance(exception, Errors.OutOfOrderSequenceNumberError) and \ - not self._transaction_manager.is_transactional() and \ - self._transaction_manager.has_producer_id(batch.producer_id): - log.error("%s: The broker received an out of order sequence number for topic-partition %s" - " at offset %s. This indicates data loss on the broker, and should be investigated.", - str(self), batch.topic_partition, base_offset) - - # Reset the transaction state since we have hit an irrecoverable exception and cannot make any guarantees - # about the previously committed message. Note that this will discard the producer id and sequence - # numbers for all existing partitions. - self._transaction_manager.reset_producer_id() - elif isinstance(exception, (Errors.ClusterAuthorizationFailedError, - Errors.TransactionalIdAuthorizationFailedError, - Errors.ProducerFencedError, - Errors.InvalidTxnStateError)): - self._transaction_manager.transition_to_fatal_error(exception) - elif self._transaction_manager.is_transactional(): - self._transaction_manager.transition_to_abortable_error(exception) - - if self._sensors: - self._sensors.record_errors(batch.topic_partition.topic, batch.record_count) - - if batch.done(base_offset=base_offset, timestamp_ms=timestamp_ms, exception=exception): - self._maybe_remove_from_inflight_batches(batch) - self._accumulator.deallocate(batch) - - def _complete_batch(self, batch, error, base_offset, timestamp_ms=None): + def _complete_batch(self, batch, error, base_offset, timestamp_ms=None, log_start_offset=None, global_error=None): """Complete or retry the given batch of records. Arguments: - batch (ProducerBatch): The record batch + batch (RecordBatch): The record batch error (Exception): The error (or None if none) base_offset (int): The base offset assigned to the records if successful timestamp_ms (int, optional): The timestamp returned by the broker for this batch + log_start_offset (int): The start offset of the log at the time this produce response was created + global_error (str): The summarising error message """ # Standardize no-error to None if error is Errors.NoError: error = None - if error is not None: - if self._can_retry(batch, error): - # retry - log.warning("%s: Got error produce response on topic-partition %s," - " retrying (%s attempts left). Error: %s", - str(self), batch.topic_partition, - self.config['retries'] - batch.attempts - 1, - error) - - # If idempotence is enabled only retry the request if the batch matches our current producer id and epoch - if not self._transaction_manager or self._transaction_manager.producer_id_and_epoch.match(batch): - log.debug("%s: Retrying batch to topic-partition %s. Sequence number: %s", - str(self), batch.topic_partition, - self._transaction_manager.sequence_number(batch.topic_partition) if self._transaction_manager else None) - self._accumulator.reenqueue(batch) - self._maybe_remove_from_inflight_batches(batch) - if self._sensors: - self._sensors.record_retries(batch.topic_partition.topic, batch.record_count) - else: - log.warning("%s: Attempted to retry sending a batch but the producer id/epoch changed from %s/%s to %s/%s. 
This batch will be dropped", - str(self), batch.producer_id, batch.producer_epoch, - self._transaction_manager.producer_id_and_epoch.producer_id, - self._transaction_manager.producer_id_and_epoch.epoch) - self._fail_batch(batch, error, base_offset=base_offset, timestamp_ms=timestamp_ms) - else: - if error is Errors.TopicAuthorizationFailedError: - error = error(batch.topic_partition.topic) - - # tell the user the result of their request - self._fail_batch(batch, error, base_offset=base_offset, timestamp_ms=timestamp_ms) - - if error is Errors.UnknownTopicOrPartitionError: - log.warning("%s: Received unknown topic or partition error in produce request on partition %s." - " The topic/partition may not exist or the user may not have Describe access to it", - str(self), batch.topic_partition) - - if getattr(error, 'invalid_metadata', False): - self._metadata.request_update() - + if error is not None and self._can_retry(batch, error): + # retry + log.warning("Got error produce response on topic-partition %s," + " retrying (%d attempts left). Error: %s", + batch.topic_partition, + self.config['retries'] - batch.attempts - 1, + global_error or error) + self._accumulator.reenqueue(batch) + self._sensors.record_retries(batch.topic_partition.topic, batch.record_count) else: - if batch.done(base_offset=base_offset, timestamp_ms=timestamp_ms): - self._maybe_remove_from_inflight_batches(batch) - self._accumulator.deallocate(batch) + if error is Errors.TopicAuthorizationFailedError: + error = error(batch.topic_partition.topic) - if self._transaction_manager and self._transaction_manager.producer_id_and_epoch.match(batch): - self._transaction_manager.increment_sequence_number(batch.topic_partition, batch.record_count) - log.debug("%s: Incremented sequence number for topic-partition %s to %s", str(self), batch.topic_partition, - self._transaction_manager.sequence_number(batch.topic_partition)) + # tell the user the result of their request + batch.done(base_offset, timestamp_ms, error, log_start_offset, global_error) + self._accumulator.deallocate(batch) + if error is not None: + self._sensors.record_errors(batch.topic_partition.topic, batch.record_count) + + if getattr(error, 'invalid_metadata', False): + self._metadata.request_update() # Unmute the completed partition. if self.config['guarantee_message_order']: @@ -523,10 +266,8 @@ class Sender(threading.Thread): We can retry a send if the error is transient and the number of attempts taken is fewer than the maximum allowed """ - return (not batch.has_reached_delivery_timeout(self._accumulator.delivery_timeout_ms) and - batch.attempts < self.config['retries'] and - batch.final_state is None and - getattr(error, 'retriable', False)) + return (batch.attempts < self.config['retries'] + and getattr(error, 'retriable', False)) def _create_produce_requests(self, collated): """ @@ -534,24 +275,23 @@ class Sender(threading.Thread): per-node basis. 
Arguments: - collated: {node_id: [ProducerBatch]} + collated: {node_id: [RecordBatch]} Returns: - dict: {node_id: ProduceRequest} (version depends on client api_versions) + dict: {node_id: ProduceRequest} (version depends on api_version) """ requests = {} for node_id, batches in six.iteritems(collated): - if batches: - requests[node_id] = self._produce_request( - node_id, self.config['acks'], - self.config['request_timeout_ms'], batches) + requests[node_id] = self._produce_request( + node_id, self.config['acks'], + self.config['request_timeout_ms'], batches) return requests def _produce_request(self, node_id, acks, timeout, batches): """Create a produce request from the given record batches. Returns: - ProduceRequest (version depends on client api_versions) + ProduceRequest (version depends on api_version) """ produce_records_by_partition = collections.defaultdict(dict) for batch in batches: @@ -561,26 +301,32 @@ class Sender(threading.Thread): buf = batch.records.buffer() produce_records_by_partition[topic][partition] = buf - version = self._client.api_version(ProduceRequest, max_version=7) - topic_partition_data = [ - (topic, list(partition_info.items())) - for topic, partition_info in six.iteritems(produce_records_by_partition)] - transactional_id = self._transaction_manager.transactional_id if self._transaction_manager else None - if version >= 3: - return ProduceRequest[version]( - transactional_id=transactional_id, - required_acks=acks, - timeout=timeout, - topics=topic_partition_data, - ) + kwargs = {} + if self.config['api_version'] >= (2, 1): + version = 7 + elif self.config['api_version'] >= (2, 0): + version = 6 + elif self.config['api_version'] >= (1, 1): + version = 5 + elif self.config['api_version'] >= (1, 0): + version = 4 + elif self.config['api_version'] >= (0, 11): + version = 3 + kwargs = dict(transactional_id=None) + elif self.config['api_version'] >= (0, 10): + version = 2 + elif self.config['api_version'] == (0, 9): + version = 1 else: - if transactional_id is not None: - log.warning('%s: Broker does not support ProduceRequest v3+, required for transactional_id', str(self)) - return ProduceRequest[version]( - required_acks=acks, - timeout=timeout, - topics=topic_partition_data, - ) + version = 0 + return ProduceRequest[version]( + required_acks=acks, + timeout=timeout, + topics=[(topic, list(partition_info.items())) + for topic, partition_info + in six.iteritems(produce_records_by_partition)], + **kwargs + ) def wakeup(self): """Wake up the selector associated with this send thread.""" @@ -589,9 +335,6 @@ class Sender(threading.Thread): def bootstrap_connected(self): return self._client.bootstrap_connected() - def __str__(self): - return "" % (self.config['client_id'], self.config['transactional_id']) - class SenderMetrics(object): @@ -624,6 +367,15 @@ class SenderMetrics(object): sensor_name=sensor_name, description='The maximum time in ms record batches spent in the record accumulator.') + sensor_name = 'produce-throttle-time' + self.produce_throttle_time_sensor = self.metrics.sensor(sensor_name) + self.add_metric('produce-throttle-time-avg', Avg(), + sensor_name=sensor_name, + description='The average throttle time in ms') + self.add_metric('produce-throttle-time-max', Max(), + sensor_name=sensor_name, + description='The maximum throttle time in ms') + sensor_name = 'records-per-request' self.records_per_request_sensor = self.metrics.sensor(sensor_name) self.add_metric('record-send-rate', Rate(), @@ -746,9 +498,8 @@ class SenderMetrics(object): records += 
batch.record_count total_bytes += batch.records.size_in_bytes() - if node_batch: - self.records_per_request_sensor.record(records) - self.byte_rate_sensor.record(total_bytes) + self.records_per_request_sensor.record(records) + self.byte_rate_sensor.record(total_bytes) def record_retries(self, topic, count): self.retry_sensor.record(count) @@ -761,3 +512,6 @@ class SenderMetrics(object): sensor = self.metrics.get_sensor('topic.' + topic + '.record-errors') if sensor: sensor.record(count) + + def record_throttle_time(self, throttle_time_ms, node=None): + self.produce_throttle_time_sensor.record(throttle_time_ms) diff --git a/venv/lib/python3.12/site-packages/kafka/producer/transaction_manager.py b/venv/lib/python3.12/site-packages/kafka/producer/transaction_manager.py deleted file mode 100644 index 5d69ddc..0000000 --- a/venv/lib/python3.12/site-packages/kafka/producer/transaction_manager.py +++ /dev/null @@ -1,981 +0,0 @@ -from __future__ import absolute_import, division - -import abc -import collections -import heapq -import logging -import threading - -from kafka.vendor import six - -try: - # enum in stdlib as of py3.4 - from enum import IntEnum # pylint: disable=import-error -except ImportError: - # vendored backport module - from kafka.vendor.enum34 import IntEnum - -import kafka.errors as Errors -from kafka.protocol.add_offsets_to_txn import AddOffsetsToTxnRequest -from kafka.protocol.add_partitions_to_txn import AddPartitionsToTxnRequest -from kafka.protocol.end_txn import EndTxnRequest -from kafka.protocol.find_coordinator import FindCoordinatorRequest -from kafka.protocol.init_producer_id import InitProducerIdRequest -from kafka.protocol.txn_offset_commit import TxnOffsetCommitRequest -from kafka.structs import TopicPartition - - -log = logging.getLogger(__name__) - - -NO_PRODUCER_ID = -1 -NO_PRODUCER_EPOCH = -1 -NO_SEQUENCE = -1 - - -class ProducerIdAndEpoch(object): - __slots__ = ('producer_id', 'epoch') - - def __init__(self, producer_id, epoch): - self.producer_id = producer_id - self.epoch = epoch - - @property - def is_valid(self): - return NO_PRODUCER_ID < self.producer_id - - def match(self, batch): - return self.producer_id == batch.producer_id and self.epoch == batch.producer_epoch - - def __eq__(self, other): - return isinstance(other, ProducerIdAndEpoch) and self.producer_id == other.producer_id and self.epoch == other.epoch - - def __str__(self): - return "ProducerIdAndEpoch(producer_id={}, epoch={})".format(self.producer_id, self.epoch) - - -class TransactionState(IntEnum): - UNINITIALIZED = 0 - INITIALIZING = 1 - READY = 2 - IN_TRANSACTION = 3 - COMMITTING_TRANSACTION = 4 - ABORTING_TRANSACTION = 5 - ABORTABLE_ERROR = 6 - FATAL_ERROR = 7 - - @classmethod - def is_transition_valid(cls, source, target): - if target == cls.INITIALIZING: - return source == cls.UNINITIALIZED - elif target == cls.READY: - return source in (cls.INITIALIZING, cls.COMMITTING_TRANSACTION, cls.ABORTING_TRANSACTION) - elif target == cls.IN_TRANSACTION: - return source == cls.READY - elif target == cls.COMMITTING_TRANSACTION: - return source == cls.IN_TRANSACTION - elif target == cls.ABORTING_TRANSACTION: - return source in (cls.IN_TRANSACTION, cls.ABORTABLE_ERROR) - elif target == cls.ABORTABLE_ERROR: - return source in (cls.IN_TRANSACTION, cls.COMMITTING_TRANSACTION, cls.ABORTABLE_ERROR) - elif target == cls.UNINITIALIZED: - # Disallow transitions to UNITIALIZED - return False - elif target == cls.FATAL_ERROR: - # We can transition to FATAL_ERROR unconditionally. 
- # FATAL_ERROR is never a valid starting state for any transition. So the only option is to close the - # producer or do purely non transactional requests. - return True - - -class Priority(IntEnum): - # We use the priority to determine the order in which requests need to be sent out. For instance, if we have - # a pending FindCoordinator request, that must always go first. Next, If we need a producer id, that must go second. - # The endTxn request must always go last. - FIND_COORDINATOR = 0 - INIT_PRODUCER_ID = 1 - ADD_PARTITIONS_OR_OFFSETS = 2 - END_TXN = 3 - - -class TransactionManager(object): - """ - A class which maintains state for transactions. Also keeps the state necessary to ensure idempotent production. - """ - NO_INFLIGHT_REQUEST_CORRELATION_ID = -1 - # The retry_backoff_ms is overridden to the following value if the first AddPartitions receives a - # CONCURRENT_TRANSACTIONS error. - ADD_PARTITIONS_RETRY_BACKOFF_MS = 20 - - def __init__(self, transactional_id=None, transaction_timeout_ms=0, retry_backoff_ms=100, api_version=(0, 11), metadata=None): - self._api_version = api_version - self._metadata = metadata - - self._sequence_numbers = collections.defaultdict(lambda: 0) - - self.transactional_id = transactional_id - self.transaction_timeout_ms = transaction_timeout_ms - self._transaction_coordinator = None - self._consumer_group_coordinator = None - self._new_partitions_in_transaction = set() - self._pending_partitions_in_transaction = set() - self._partitions_in_transaction = set() - self._pending_txn_offset_commits = dict() - - self._current_state = TransactionState.UNINITIALIZED - self._last_error = None - self.producer_id_and_epoch = ProducerIdAndEpoch(NO_PRODUCER_ID, NO_PRODUCER_EPOCH) - - self._transaction_started = False - - self._pending_requests = [] # priority queue via heapq - self._pending_requests_sort_id = 0 - self._in_flight_request_correlation_id = self.NO_INFLIGHT_REQUEST_CORRELATION_ID - - # This is used by the TxnRequestHandlers to control how long to back off before a given request is retried. - # For instance, this value is lowered by the AddPartitionsToTxnHandler when it receives a CONCURRENT_TRANSACTIONS - # error for the first AddPartitionsRequest in a transaction. 
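[Editor's aside on the pending-request queue set up in this constructor: a small illustrative sketch, not the removed code itself, of why requests are pushed as (priority, sort id, handler) tuples. heapq then serves FindCoordinator before InitProducerId, keeps EndTxn last, and preserves FIFO order within a priority level.]

import heapq
import itertools
from enum import IntEnum

class Priority(IntEnum):          # same ordering as the Priority enum above
    FIND_COORDINATOR = 0
    INIT_PRODUCER_ID = 1
    ADD_PARTITIONS_OR_OFFSETS = 2
    END_TXN = 3

pending = []
sort_id = itertools.count()

def enqueue(priority, handler):
    # (priority, insertion id, handler): heapq orders by priority first and
    # falls back to insertion order, so equal-priority requests stay FIFO.
    heapq.heappush(pending, (priority, next(sort_id), handler))

enqueue(Priority.END_TXN, 'end_txn')
enqueue(Priority.FIND_COORDINATOR, 'find_coordinator')
assert heapq.heappop(pending)[2] == 'find_coordinator'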
- self.retry_backoff_ms = retry_backoff_ms - self._lock = threading.Condition() - - def initialize_transactions(self): - with self._lock: - self._ensure_transactional() - self._transition_to(TransactionState.INITIALIZING) - self.set_producer_id_and_epoch(ProducerIdAndEpoch(NO_PRODUCER_ID, NO_PRODUCER_EPOCH)) - self._sequence_numbers.clear() - handler = InitProducerIdHandler(self, self.transaction_timeout_ms) - self._enqueue_request(handler) - return handler.result - - def begin_transaction(self): - with self._lock: - self._ensure_transactional() - self._maybe_fail_with_error() - self._transition_to(TransactionState.IN_TRANSACTION) - - def begin_commit(self): - with self._lock: - self._ensure_transactional() - self._maybe_fail_with_error() - self._transition_to(TransactionState.COMMITTING_TRANSACTION) - return self._begin_completing_transaction(True) - - def begin_abort(self): - with self._lock: - self._ensure_transactional() - if self._current_state != TransactionState.ABORTABLE_ERROR: - self._maybe_fail_with_error() - self._transition_to(TransactionState.ABORTING_TRANSACTION) - - # We're aborting the transaction, so there should be no need to add new partitions - self._new_partitions_in_transaction.clear() - return self._begin_completing_transaction(False) - - def _begin_completing_transaction(self, committed): - if self._new_partitions_in_transaction: - self._enqueue_request(self._add_partitions_to_transaction_handler()) - handler = EndTxnHandler(self, committed) - self._enqueue_request(handler) - return handler.result - - def send_offsets_to_transaction(self, offsets, consumer_group_id): - with self._lock: - self._ensure_transactional() - self._maybe_fail_with_error() - if self._current_state != TransactionState.IN_TRANSACTION: - raise Errors.KafkaError("Cannot send offsets to transaction because the producer is not in an active transaction") - - log.debug("Begin adding offsets %s for consumer group %s to transaction", offsets, consumer_group_id) - handler = AddOffsetsToTxnHandler(self, consumer_group_id, offsets) - self._enqueue_request(handler) - return handler.result - - def maybe_add_partition_to_transaction(self, topic_partition): - with self._lock: - self._fail_if_not_ready_for_send() - - if self.is_partition_added(topic_partition) or self.is_partition_pending_add(topic_partition): - return - - log.debug("Begin adding new partition %s to transaction", topic_partition) - self._new_partitions_in_transaction.add(topic_partition) - - def _fail_if_not_ready_for_send(self): - with self._lock: - if self.has_error(): - raise Errors.KafkaError( - "Cannot perform send because at least one previous transactional or" - " idempotent request has failed with errors.", self._last_error) - - if self.is_transactional(): - if not self.has_producer_id(): - raise Errors.IllegalStateError( - "Cannot perform a 'send' before completing a call to init_transactions" - " when transactions are enabled.") - - if self._current_state != TransactionState.IN_TRANSACTION: - raise Errors.IllegalStateError("Cannot call send in state %s" % (self._current_state.name,)) - - def is_send_to_partition_allowed(self, tp): - with self._lock: - if self.has_fatal_error(): - return False - return not self.is_transactional() or tp in self._partitions_in_transaction - - def has_producer_id(self, producer_id=None): - if producer_id is None: - return self.producer_id_and_epoch.is_valid - else: - return self.producer_id_and_epoch.producer_id == producer_id - - def is_transactional(self): - return self.transactional_id is not None - 
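[For readers following the removal of the TransactionManager: a minimal, self-contained sketch, not library code, of the state machine that is_transition_valid enforces and that initialize_transactions / begin_transaction / begin_commit / begin_abort drive. The ABORTABLE_ERROR and FATAL_ERROR branches are left out for brevity.]

from enum import IntEnum

class _State(IntEnum):
    UNINITIALIZED = 0
    INITIALIZING = 1
    READY = 2
    IN_TRANSACTION = 3
    COMMITTING_TRANSACTION = 4
    ABORTING_TRANSACTION = 5

# Allowed source states per target, mirroring is_transition_valid above.
_VALID_SOURCES = {
    _State.INITIALIZING: {_State.UNINITIALIZED},
    _State.READY: {_State.INITIALIZING, _State.COMMITTING_TRANSACTION, _State.ABORTING_TRANSACTION},
    _State.IN_TRANSACTION: {_State.READY},
    _State.COMMITTING_TRANSACTION: {_State.IN_TRANSACTION},
    _State.ABORTING_TRANSACTION: {_State.IN_TRANSACTION},
}

def transition(current, target):
    if current not in _VALID_SOURCES.get(target, set()):
        raise RuntimeError("invalid transition %s -> %s" % (current.name, target.name))
    return target

# The happy path enforced on callers: init -> begin -> commit (or abort) -> ready again.
state = _State.UNINITIALIZED
for target in (_State.INITIALIZING, _State.READY, _State.IN_TRANSACTION,
               _State.COMMITTING_TRANSACTION, _State.READY):
    state = transition(state, target)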
- def has_partitions_to_add(self): - with self._lock: - return bool(self._new_partitions_in_transaction) or bool(self._pending_partitions_in_transaction) - - def is_completing(self): - with self._lock: - return self._current_state in ( - TransactionState.COMMITTING_TRANSACTION, - TransactionState.ABORTING_TRANSACTION) - - @property - def last_error(self): - return self._last_error - - def has_error(self): - with self._lock: - return self._current_state in ( - TransactionState.ABORTABLE_ERROR, - TransactionState.FATAL_ERROR) - - def is_aborting(self): - with self._lock: - return self._current_state == TransactionState.ABORTING_TRANSACTION - - def transition_to_abortable_error(self, exc): - with self._lock: - if self._current_state == TransactionState.ABORTING_TRANSACTION: - log.debug("Skipping transition to abortable error state since the transaction is already being " - " aborted. Underlying exception: %s", exc) - return - self._transition_to(TransactionState.ABORTABLE_ERROR, error=exc) - - def transition_to_fatal_error(self, exc): - with self._lock: - self._transition_to(TransactionState.FATAL_ERROR, error=exc) - - def is_partition_added(self, partition): - with self._lock: - return partition in self._partitions_in_transaction - - def is_partition_pending_add(self, partition): - return partition in self._new_partitions_in_transaction or partition in self._pending_partitions_in_transaction - - def has_producer_id_and_epoch(self, producer_id, producer_epoch): - return ( - self.producer_id_and_epoch.producer_id == producer_id and - self.producer_id_and_epoch.epoch == producer_epoch - ) - - def set_producer_id_and_epoch(self, producer_id_and_epoch): - if not isinstance(producer_id_and_epoch, ProducerIdAndEpoch): - raise TypeError("ProducerAndIdEpoch type required") - log.info("ProducerId set to %s with epoch %s", - producer_id_and_epoch.producer_id, producer_id_and_epoch.epoch) - self.producer_id_and_epoch = producer_id_and_epoch - - def reset_producer_id(self): - """ - This method is used when the producer needs to reset its internal state because of an irrecoverable exception - from the broker. - - We need to reset the producer id and associated state when we have sent a batch to the broker, but we either get - a non-retriable exception or we run out of retries, or the batch expired in the producer queue after it was already - sent to the broker. - - In all of these cases, we don't know whether batch was actually committed on the broker, and hence whether the - sequence number was actually updated. If we don't reset the producer state, we risk the chance that all future - messages will return an OutOfOrderSequenceNumberError. - - Note that we can't reset the producer state for the transactional producer as this would mean bumping the epoch - for the same producer id. This might involve aborting the ongoing transaction during the initProducerIdRequest, - and the user would not have any way of knowing this happened. So for the transactional producer, - it's best to return the produce error to the user and let them abort the transaction and close the producer explicitly. - """ - with self._lock: - if self.is_transactional(): - raise Errors.IllegalStateError( - "Cannot reset producer state for a transactional producer." 
- " You must either abort the ongoing transaction or" - " reinitialize the transactional producer instead") - self.set_producer_id_and_epoch(ProducerIdAndEpoch(NO_PRODUCER_ID, NO_PRODUCER_EPOCH)) - self._sequence_numbers.clear() - - def sequence_number(self, tp): - with self._lock: - return self._sequence_numbers[tp] - - def increment_sequence_number(self, tp, increment): - with self._lock: - if tp not in self._sequence_numbers: - raise Errors.IllegalStateError("Attempt to increment sequence number for a partition with no current sequence.") - # Sequence number wraps at java max int - base = self._sequence_numbers[tp] - if base > (2147483647 - increment): - self._sequence_numbers[tp] = increment - (2147483647 - base) - 1 - else: - self._sequence_numbers[tp] += increment - - def next_request_handler(self, has_incomplete_batches): - with self._lock: - if self._new_partitions_in_transaction: - self._enqueue_request(self._add_partitions_to_transaction_handler()) - - if not self._pending_requests: - return None - - _, _, next_request_handler = self._pending_requests[0] - # Do not send the EndTxn until all batches have been flushed - if isinstance(next_request_handler, EndTxnHandler) and has_incomplete_batches: - return None - - heapq.heappop(self._pending_requests) - if self._maybe_terminate_request_with_error(next_request_handler): - log.debug("Not sending transactional request %s because we are in an error state", - next_request_handler.request) - return None - - if isinstance(next_request_handler, EndTxnHandler) and not self._transaction_started: - next_request_handler.result.done() - if self._current_state != TransactionState.FATAL_ERROR: - log.debug("Not sending EndTxn for completed transaction since no partitions" - " or offsets were successfully added") - self._complete_transaction() - try: - _, _, next_request_handler = heapq.heappop(self._pending_requests) - except IndexError: - next_request_handler = None - - if next_request_handler: - log.debug("Request %s dequeued for sending", next_request_handler.request) - - return next_request_handler - - def retry(self, request): - with self._lock: - request.set_retry() - self._enqueue_request(request) - - def authentication_failed(self, exc): - with self._lock: - for _, _, request in self._pending_requests: - request.fatal_error(exc) - - def coordinator(self, coord_type): - if coord_type == 'group': - return self._consumer_group_coordinator - elif coord_type == 'transaction': - return self._transaction_coordinator - else: - raise Errors.IllegalStateError("Received an invalid coordinator type: %s" % (coord_type,)) - - def lookup_coordinator_for_request(self, request): - self._lookup_coordinator(request.coordinator_type, request.coordinator_key) - - def next_in_flight_request_correlation_id(self): - self._in_flight_request_correlation_id += 1 - return self._in_flight_request_correlation_id - - def clear_in_flight_transactional_request_correlation_id(self): - self._in_flight_request_correlation_id = self.NO_INFLIGHT_REQUEST_CORRELATION_ID - - def has_in_flight_transactional_request(self): - return self._in_flight_request_correlation_id != self.NO_INFLIGHT_REQUEST_CORRELATION_ID - - def has_fatal_error(self): - return self._current_state == TransactionState.FATAL_ERROR - - def has_abortable_error(self): - return self._current_state == TransactionState.ABORTABLE_ERROR - - # visible for testing - def _test_transaction_contains_partition(self, tp): - with self._lock: - return tp in self._partitions_in_transaction - - # visible for testing - def 
_test_has_pending_offset_commits(self): - return bool(self._pending_txn_offset_commits) - - # visible for testing - def _test_has_ongoing_transaction(self): - with self._lock: - # transactions are considered ongoing once started until completion or a fatal error - return self._current_state == TransactionState.IN_TRANSACTION or self.is_completing() or self.has_abortable_error() - - # visible for testing - def _test_is_ready(self): - with self._lock: - return self.is_transactional() and self._current_state == TransactionState.READY - - def _transition_to(self, target, error=None): - with self._lock: - if not self._current_state.is_transition_valid(self._current_state, target): - raise Errors.KafkaError("TransactionalId %s: Invalid transition attempted from state %s to state %s" % ( - self.transactional_id, self._current_state.name, target.name)) - - if target in (TransactionState.FATAL_ERROR, TransactionState.ABORTABLE_ERROR): - if error is None: - raise Errors.IllegalArgumentError("Cannot transition to %s with an None exception" % (target.name,)) - self._last_error = error - else: - self._last_error = None - - if self._last_error is not None: - log.debug("Transition from state %s to error state %s (%s)", self._current_state.name, target.name, self._last_error) - else: - log.debug("Transition from state %s to %s", self._current_state, target) - self._current_state = target - - def _ensure_transactional(self): - if not self.is_transactional(): - raise Errors.IllegalStateError("Transactional method invoked on a non-transactional producer.") - - def _maybe_fail_with_error(self): - if self.has_error(): - raise Errors.KafkaError("Cannot execute transactional method because we are in an error state: %s" % (self._last_error,)) - - def _maybe_terminate_request_with_error(self, request_handler): - if self.has_error(): - if self.has_abortable_error() and isinstance(request_handler, FindCoordinatorHandler): - # No harm letting the FindCoordinator request go through if we're expecting to abort - return False - request_handler.fail(self._last_error) - return True - return False - - def _next_pending_requests_sort_id(self): - self._pending_requests_sort_id += 1 - return self._pending_requests_sort_id - - def _enqueue_request(self, request_handler): - log.debug("Enqueuing transactional request %s", request_handler.request) - heapq.heappush( - self._pending_requests, - ( - request_handler.priority, # keep lowest priority at head of queue - self._next_pending_requests_sort_id(), # break ties - request_handler - ) - ) - - def _lookup_coordinator(self, coord_type, coord_key): - with self._lock: - if coord_type == 'group': - self._consumer_group_coordinator = None - elif coord_type == 'transaction': - self._transaction_coordinator = None - else: - raise Errors.IllegalStateError("Invalid coordinator type: %s" % (coord_type,)) - self._enqueue_request(FindCoordinatorHandler(self, coord_type, coord_key)) - - def _complete_transaction(self): - with self._lock: - self._transition_to(TransactionState.READY) - self._transaction_started = False - self._new_partitions_in_transaction.clear() - self._pending_partitions_in_transaction.clear() - self._partitions_in_transaction.clear() - - def _add_partitions_to_transaction_handler(self): - with self._lock: - self._pending_partitions_in_transaction.update(self._new_partitions_in_transaction) - self._new_partitions_in_transaction.clear() - return AddPartitionsToTxnHandler(self, self._pending_partitions_in_transaction) - - -class TransactionalRequestResult(object): - def 
__init__(self): - self._latch = threading.Event() - self._error = None - - def done(self, error=None): - self._error = error - self._latch.set() - - def wait(self, timeout_ms=None): - timeout = timeout_ms / 1000 if timeout_ms is not None else None - success = self._latch.wait(timeout) - if self._error: - raise self._error - return success - - @property - def is_done(self): - return self._latch.is_set() - - @property - def succeeded(self): - return self._latch.is_set() and self._error is None - - @property - def failed(self): - return self._latch.is_set() and self._error is not None - - @property - def exception(self): - return self._error - - -@six.add_metaclass(abc.ABCMeta) -class TxnRequestHandler(object): - def __init__(self, transaction_manager, result=None): - self.transaction_manager = transaction_manager - self.retry_backoff_ms = transaction_manager.retry_backoff_ms - self.request = None - self._result = result or TransactionalRequestResult() - self._is_retry = False - - @property - def transactional_id(self): - return self.transaction_manager.transactional_id - - @property - def producer_id(self): - return self.transaction_manager.producer_id_and_epoch.producer_id - - @property - def producer_epoch(self): - return self.transaction_manager.producer_id_and_epoch.epoch - - def fatal_error(self, exc): - self.transaction_manager.transition_to_fatal_error(exc) - self._result.done(error=exc) - - def abortable_error(self, exc): - self.transaction_manager.transition_to_abortable_error(exc) - self._result.done(error=exc) - - def fail(self, exc): - self._result.done(error=exc) - - def reenqueue(self): - with self.transaction_manager._lock: - self._is_retry = True - self.transaction_manager._enqueue_request(self) - - def on_complete(self, correlation_id, response_or_exc): - if correlation_id != self.transaction_manager._in_flight_request_correlation_id: - self.fatal_error(RuntimeError("Detected more than one in-flight transactional request.")) - else: - self.transaction_manager.clear_in_flight_transactional_request_correlation_id() - if isinstance(response_or_exc, Errors.KafkaConnectionError): - log.debug("Disconnected from node. 
Will retry.") - if self.needs_coordinator(): - self.transaction_manager._lookup_coordinator(self.coordinator_type, self.coordinator_key) - self.reenqueue() - elif isinstance(response_or_exc, Errors.UnsupportedVersionError): - self.fatal_error(response_or_exc) - elif not isinstance(response_or_exc, (Exception, type(None))): - log.debug("Received transactional response %s for request %s", response_or_exc, self.request) - with self.transaction_manager._lock: - self.handle_response(response_or_exc) - else: - self.fatal_error(Errors.KafkaError("Could not execute transactional request for unknown reasons: %s" % response_or_exc)) - - def needs_coordinator(self): - return self.coordinator_type is not None - - @property - def result(self): - return self._result - - @property - def coordinator_type(self): - return 'transaction' - - @property - def coordinator_key(self): - return self.transaction_manager.transactional_id - - def set_retry(self): - self._is_retry = True - - @property - def is_retry(self): - return self._is_retry - - @abc.abstractmethod - def handle_response(self, response): - pass - - @abc.abstractproperty - def priority(self): - pass - - -class InitProducerIdHandler(TxnRequestHandler): - def __init__(self, transaction_manager, transaction_timeout_ms): - super(InitProducerIdHandler, self).__init__(transaction_manager) - - if transaction_manager._api_version >= (2, 0): - version = 1 - else: - version = 0 - self.request = InitProducerIdRequest[version]( - transactional_id=self.transactional_id, - transaction_timeout_ms=transaction_timeout_ms) - - @property - def priority(self): - return Priority.INIT_PRODUCER_ID - - def handle_response(self, response): - error = Errors.for_code(response.error_code) - - if error is Errors.NoError: - self.transaction_manager.set_producer_id_and_epoch(ProducerIdAndEpoch(response.producer_id, response.producer_epoch)) - self.transaction_manager._transition_to(TransactionState.READY) - self._result.done() - elif error in (Errors.NotCoordinatorError, Errors.CoordinatorNotAvailableError): - self.transaction_manager._lookup_coordinator('transaction', self.transactional_id) - self.reenqueue() - elif error in (Errors.CoordinatorLoadInProgressError, Errors.ConcurrentTransactionsError): - self.reenqueue() - elif error is Errors.TransactionalIdAuthorizationFailedError: - self.fatal_error(error()) - else: - self.fatal_error(Errors.KafkaError("Unexpected error in InitProducerIdResponse: %s" % (error()))) - -class AddPartitionsToTxnHandler(TxnRequestHandler): - def __init__(self, transaction_manager, topic_partitions): - super(AddPartitionsToTxnHandler, self).__init__(transaction_manager) - - if transaction_manager._api_version >= (2, 7): - version = 2 - elif transaction_manager._api_version >= (2, 0): - version = 1 - else: - version = 0 - topic_data = collections.defaultdict(list) - for tp in topic_partitions: - topic_data[tp.topic].append(tp.partition) - self.request = AddPartitionsToTxnRequest[version]( - transactional_id=self.transactional_id, - producer_id=self.producer_id, - producer_epoch=self.producer_epoch, - topics=list(topic_data.items())) - - @property - def priority(self): - return Priority.ADD_PARTITIONS_OR_OFFSETS - - def handle_response(self, response): - has_partition_errors = False - unauthorized_topics = set() - self.retry_backoff_ms = self.transaction_manager.retry_backoff_ms - - results = {TopicPartition(topic, partition): Errors.for_code(error_code) - for topic, partition_data in response.results - for partition, error_code in partition_data} - - 
for tp, error in six.iteritems(results): - if error is Errors.NoError: - continue - elif error in (Errors.CoordinatorNotAvailableError, Errors.NotCoordinatorError): - self.transaction_manager._lookup_coordinator('transaction', self.transactional_id) - self.reenqueue() - return - elif error is Errors.ConcurrentTransactionsError: - self.maybe_override_retry_backoff_ms() - self.reenqueue() - return - elif error in (Errors.CoordinatorLoadInProgressError, Errors.UnknownTopicOrPartitionError): - self.reenqueue() - return - elif error is Errors.InvalidProducerEpochError: - self.fatal_error(error()) - return - elif error is Errors.TransactionalIdAuthorizationFailedError: - self.fatal_error(error()) - return - elif error in (Errors.InvalidProducerIdMappingError, Errors.InvalidTxnStateError): - self.fatal_error(Errors.KafkaError(error())) - return - elif error is Errors.TopicAuthorizationFailedError: - unauthorized_topics.add(tp.topic) - elif error is Errors.OperationNotAttemptedError: - log.debug("Did not attempt to add partition %s to transaction because other partitions in the" - " batch had errors.", tp) - has_partition_errors = True - else: - log.error("Could not add partition %s due to unexpected error %s", tp, error()) - has_partition_errors = True - - partitions = set(results) - - # Remove the partitions from the pending set regardless of the result. We use the presence - # of partitions in the pending set to know when it is not safe to send batches. However, if - # the partitions failed to be added and we enter an error state, we expect the batches to be - # aborted anyway. In this case, we must be able to continue sending the batches which are in - # retry for partitions that were successfully added. - self.transaction_manager._pending_partitions_in_transaction -= partitions - - if unauthorized_topics: - self.abortable_error(Errors.TopicAuthorizationFailedError(unauthorized_topics)) - elif has_partition_errors: - self.abortable_error(Errors.KafkaError("Could not add partitions to transaction due to errors: %s" % (results))) - else: - log.debug("Successfully added partitions %s to transaction", partitions) - self.transaction_manager._partitions_in_transaction.update(partitions) - self.transaction_manager._transaction_started = True - self._result.done() - - def maybe_override_retry_backoff_ms(self): - # We only want to reduce the backoff when retrying the first AddPartition which errored out due to a - # CONCURRENT_TRANSACTIONS error since this means that the previous transaction is still completing and - # we don't want to wait too long before trying to start the new one. 
- # - # This is only a temporary fix, the long term solution is being tracked in - # https://issues.apache.org/jira/browse/KAFKA-5482 - if not self.transaction_manager._partitions_in_transaction: - self.retry_backoff_ms = min(self.transaction_manager.ADD_PARTITIONS_RETRY_BACKOFF_MS, self.retry_backoff_ms) - - -class FindCoordinatorHandler(TxnRequestHandler): - def __init__(self, transaction_manager, coord_type, coord_key): - super(FindCoordinatorHandler, self).__init__(transaction_manager) - - self._coord_type = coord_type - self._coord_key = coord_key - if transaction_manager._api_version >= (2, 0): - version = 2 - else: - version = 1 - if coord_type == 'group': - coord_type_int8 = 0 - elif coord_type == 'transaction': - coord_type_int8 = 1 - else: - raise ValueError("Unrecognized coordinator type: %s" % (coord_type,)) - self.request = FindCoordinatorRequest[version]( - coordinator_key=coord_key, - coordinator_type=coord_type_int8, - ) - - @property - def priority(self): - return Priority.FIND_COORDINATOR - - @property - def coordinator_type(self): - return None - - @property - def coordinator_key(self): - return None - - def handle_response(self, response): - error = Errors.for_code(response.error_code) - - if error is Errors.NoError: - coordinator_id = self.transaction_manager._metadata.add_coordinator( - response, self._coord_type, self._coord_key) - if self._coord_type == 'group': - self.transaction_manager._consumer_group_coordinator = coordinator_id - elif self._coord_type == 'transaction': - self.transaction_manager._transaction_coordinator = coordinator_id - self._result.done() - elif error is Errors.CoordinatorNotAvailableError: - self.reenqueue() - elif error is Errors.TransactionalIdAuthorizationFailedError: - self.fatal_error(error()) - elif error is Errors.GroupAuthorizationFailedError: - self.abortable_error(error(self._coord_key)) - else: - self.fatal_error(Errors.KafkaError( - "Could not find a coordinator with type %s with key %s due to" - " unexpected error: %s" % (self._coord_type, self._coord_key, error()))) - - -class EndTxnHandler(TxnRequestHandler): - def __init__(self, transaction_manager, committed): - super(EndTxnHandler, self).__init__(transaction_manager) - - if self.transaction_manager._api_version >= (2, 7): - version = 2 - elif self.transaction_manager._api_version >= (2, 0): - version = 1 - else: - version = 0 - self.request = EndTxnRequest[version]( - transactional_id=self.transactional_id, - producer_id=self.producer_id, - producer_epoch=self.producer_epoch, - committed=committed) - - @property - def priority(self): - return Priority.END_TXN - - def handle_response(self, response): - error = Errors.for_code(response.error_code) - - if error is Errors.NoError: - self.transaction_manager._complete_transaction() - self._result.done() - elif error in (Errors.CoordinatorNotAvailableError, Errors.NotCoordinatorError): - self.transaction_manager._lookup_coordinator('transaction', self.transactional_id) - self.reenqueue() - elif error in (Errors.CoordinatorLoadInProgressError, Errors.ConcurrentTransactionsError): - self.reenqueue() - elif error is Errors.InvalidProducerEpochError: - self.fatal_error(error()) - elif error is Errors.TransactionalIdAuthorizationFailedError: - self.fatal_error(error()) - elif error is Errors.InvalidTxnStateError: - self.fatal_error(error()) - else: - self.fatal_error(Errors.KafkaError("Unhandled error in EndTxnResponse: %s" % (error()))) - - -class AddOffsetsToTxnHandler(TxnRequestHandler): - def __init__(self, transaction_manager, 
consumer_group_id, offsets): - super(AddOffsetsToTxnHandler, self).__init__(transaction_manager) - - self.consumer_group_id = consumer_group_id - self.offsets = offsets - if self.transaction_manager._api_version >= (2, 7): - version = 2 - elif self.transaction_manager._api_version >= (2, 0): - version = 1 - else: - version = 0 - self.request = AddOffsetsToTxnRequest[version]( - transactional_id=self.transactional_id, - producer_id=self.producer_id, - producer_epoch=self.producer_epoch, - group_id=consumer_group_id) - - @property - def priority(self): - return Priority.ADD_PARTITIONS_OR_OFFSETS - - def handle_response(self, response): - error = Errors.for_code(response.error_code) - - if error is Errors.NoError: - log.debug("Successfully added partition for consumer group %s to transaction", self.consumer_group_id) - - # note the result is not completed until the TxnOffsetCommit returns - for tp, offset in six.iteritems(self.offsets): - self.transaction_manager._pending_txn_offset_commits[tp] = offset - handler = TxnOffsetCommitHandler(self.transaction_manager, self.consumer_group_id, - self.transaction_manager._pending_txn_offset_commits, self._result) - self.transaction_manager._enqueue_request(handler) - self.transaction_manager._transaction_started = True - elif error in (Errors.CoordinatorNotAvailableError, Errors.NotCoordinatorError): - self.transaction_manager._lookup_coordinator('transaction', self.transactional_id) - self.reenqueue() - elif error in (Errors.CoordinatorLoadInProgressError, Errors.ConcurrentTransactionsError): - self.reenqueue() - elif error is Errors.InvalidProducerEpochError: - self.fatal_error(error()) - elif error is Errors.TransactionalIdAuthorizationFailedError: - self.fatal_error(error()) - elif error is Errors.GroupAuthorizationFailedError: - self.abortable_error(error(self.consumer_group_id)) - else: - self.fatal_error(Errors.KafkaError("Unexpected error in AddOffsetsToTxnResponse: %s" % (error()))) - - -class TxnOffsetCommitHandler(TxnRequestHandler): - def __init__(self, transaction_manager, consumer_group_id, offsets, result): - super(TxnOffsetCommitHandler, self).__init__(transaction_manager, result=result) - - self.consumer_group_id = consumer_group_id - self.offsets = offsets - self.request = self._build_request() - - def _build_request(self): - if self.transaction_manager._api_version >= (2, 1): - version = 2 - elif self.transaction_manager._api_version >= (2, 0): - version = 1 - else: - version = 0 - - topic_data = collections.defaultdict(list) - for tp, offset in six.iteritems(self.offsets): - if version >= 2: - partition_data = (tp.partition, offset.offset, offset.leader_epoch, offset.metadata) - else: - partition_data = (tp.partition, offset.offset, offset.metadata) - topic_data[tp.topic].append(partition_data) - - return TxnOffsetCommitRequest[version]( - transactional_id=self.transactional_id, - group_id=self.consumer_group_id, - producer_id=self.producer_id, - producer_epoch=self.producer_epoch, - topics=list(topic_data.items())) - - @property - def priority(self): - return Priority.ADD_PARTITIONS_OR_OFFSETS - - @property - def coordinator_type(self): - return 'group' - - @property - def coordinator_key(self): - return self.consumer_group_id - - def handle_response(self, response): - lookup_coordinator = False - retriable_failure = False - - errors = {TopicPartition(topic, partition): Errors.for_code(error_code) - for topic, partition_data in response.topics - for partition, error_code in partition_data} - - for tp, error in 
six.iteritems(errors): - if error is Errors.NoError: - log.debug("Successfully added offsets for %s from consumer group %s to transaction.", - tp, self.consumer_group_id) - del self.transaction_manager._pending_txn_offset_commits[tp] - elif error in (errors.CoordinatorNotAvailableError, Errors.NotCoordinatorError, Errors.RequestTimedOutError): - retriable_failure = True - lookup_coordinator = True - elif error is Errors.UnknownTopicOrPartitionError: - retriable_failure = True - elif error is Errors.GroupAuthorizationFailedError: - self.abortable_error(error(self.consumer_group_id)) - return - elif error in (Errors.TransactionalIdAuthorizationFailedError, - Errors.InvalidProducerEpochError, - Errors.UnsupportedForMessageFormatError): - self.fatal_error(error()) - return - else: - self.fatal_error(Errors.KafkaError("Unexpected error in TxnOffsetCommitResponse: %s" % (error()))) - return - - if lookup_coordinator: - self.transaction_manager._lookup_coordinator('group', self.consumer_group_id) - - if not retriable_failure: - # all attempted partitions were either successful, or there was a fatal failure. - # either way, we are not retrying, so complete the request. - self.result.done() - - # retry the commits which failed with a retriable error. - elif self.transaction_manager._pending_txn_offset_commits: - self.offsets = self.transaction_manager._pending_txn_offset_commits - self.request = self._build_request() - self.reenqueue() diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/__init__.py b/venv/lib/python3.12/site-packages/kafka/protocol/__init__.py index 025447f..26dcc78 100644 --- a/venv/lib/python3.12/site-packages/kafka/protocol/__init__.py +++ b/venv/lib/python3.12/site-packages/kafka/protocol/__init__.py @@ -43,7 +43,4 @@ API_KEYS = { 40: 'ExpireDelegationToken', 41: 'DescribeDelegationToken', 42: 'DeleteGroups', - 45: 'AlterPartitionReassignments', - 46: 'ListPartitionReassignments', - 48: 'DescribeClientQuotas', } diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/abstract.py b/venv/lib/python3.12/site-packages/kafka/protocol/abstract.py index 7ce5fc1..2de65c4 100644 --- a/venv/lib/python3.12/site-packages/kafka/protocol/abstract.py +++ b/venv/lib/python3.12/site-packages/kafka/protocol/abstract.py @@ -2,11 +2,10 @@ from __future__ import absolute_import import abc -from kafka.vendor.six import add_metaclass - -@add_metaclass(abc.ABCMeta) class AbstractType(object): + __metaclass__ = abc.ABCMeta + @abc.abstractmethod def encode(cls, value): # pylint: disable=no-self-argument pass diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/add_offsets_to_txn.py b/venv/lib/python3.12/site-packages/kafka/protocol/add_offsets_to_txn.py deleted file mode 100644 index fa25093..0000000 --- a/venv/lib/python3.12/site-packages/kafka/protocol/add_offsets_to_txn.py +++ /dev/null @@ -1,59 +0,0 @@ -from __future__ import absolute_import - -from kafka.protocol.api import Request, Response -from kafka.protocol.types import Int16, Int32, Int64, Schema, String - - -class AddOffsetsToTxnResponse_v0(Response): - API_KEY = 25 - API_VERSION = 0 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16), - ) - - -class AddOffsetsToTxnResponse_v1(Response): - API_KEY = 25 - API_VERSION = 1 - SCHEMA = AddOffsetsToTxnResponse_v0.SCHEMA - - -class AddOffsetsToTxnResponse_v2(Response): - API_KEY = 25 - API_VERSION = 2 - SCHEMA = AddOffsetsToTxnResponse_v1.SCHEMA - - -class AddOffsetsToTxnRequest_v0(Request): - API_KEY = 25 - API_VERSION = 0 - RESPONSE_TYPE = 
AddOffsetsToTxnResponse_v0 - SCHEMA = Schema( - ('transactional_id', String('utf-8')), - ('producer_id', Int64), - ('producer_epoch', Int16), - ('group_id', String('utf-8')), - ) - - -class AddOffsetsToTxnRequest_v1(Request): - API_KEY = 25 - API_VERSION = 1 - RESPONSE_TYPE = AddOffsetsToTxnResponse_v1 - SCHEMA = AddOffsetsToTxnRequest_v0.SCHEMA - - -class AddOffsetsToTxnRequest_v2(Request): - API_KEY = 25 - API_VERSION = 2 - RESPONSE_TYPE = AddOffsetsToTxnResponse_v2 - SCHEMA = AddOffsetsToTxnRequest_v1.SCHEMA - - -AddOffsetsToTxnRequest = [ - AddOffsetsToTxnRequest_v0, AddOffsetsToTxnRequest_v1, AddOffsetsToTxnRequest_v2, -] -AddOffsetsToTxnResponse = [ - AddOffsetsToTxnResponse_v0, AddOffsetsToTxnResponse_v1, AddOffsetsToTxnResponse_v2, -] diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/add_partitions_to_txn.py b/venv/lib/python3.12/site-packages/kafka/protocol/add_partitions_to_txn.py deleted file mode 100644 index fdf28f4..0000000 --- a/venv/lib/python3.12/site-packages/kafka/protocol/add_partitions_to_txn.py +++ /dev/null @@ -1,63 +0,0 @@ -from __future__ import absolute_import - -from kafka.protocol.api import Request, Response -from kafka.protocol.types import Array, Int16, Int32, Int64, Schema, String - - -class AddPartitionsToTxnResponse_v0(Response): - API_KEY = 24 - API_VERSION = 0 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('results', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16)))))) - - -class AddPartitionsToTxnResponse_v1(Response): - API_KEY = 24 - API_VERSION = 1 - SCHEMA = AddPartitionsToTxnResponse_v0.SCHEMA - - -class AddPartitionsToTxnResponse_v2(Response): - API_KEY = 24 - API_VERSION = 2 - SCHEMA = AddPartitionsToTxnResponse_v1.SCHEMA - - -class AddPartitionsToTxnRequest_v0(Request): - API_KEY = 24 - API_VERSION = 0 - RESPONSE_TYPE = AddPartitionsToTxnResponse_v0 - SCHEMA = Schema( - ('transactional_id', String('utf-8')), - ('producer_id', Int64), - ('producer_epoch', Int16), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array(Int32))))) - - -class AddPartitionsToTxnRequest_v1(Request): - API_KEY = 24 - API_VERSION = 1 - RESPONSE_TYPE = AddPartitionsToTxnResponse_v1 - SCHEMA = AddPartitionsToTxnRequest_v0.SCHEMA - - -class AddPartitionsToTxnRequest_v2(Request): - API_KEY = 24 - API_VERSION = 2 - RESPONSE_TYPE = AddPartitionsToTxnResponse_v2 - SCHEMA = AddPartitionsToTxnRequest_v1.SCHEMA - - -AddPartitionsToTxnRequest = [ - AddPartitionsToTxnRequest_v0, AddPartitionsToTxnRequest_v1, AddPartitionsToTxnRequest_v2, -] -AddPartitionsToTxnResponse = [ - AddPartitionsToTxnResponse_v0, AddPartitionsToTxnResponse_v1, AddPartitionsToTxnResponse_v2, -] diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/admin.py b/venv/lib/python3.12/site-packages/kafka/protocol/admin.py index 2551668..f3b691a 100644 --- a/venv/lib/python3.12/site-packages/kafka/protocol/admin.py +++ b/venv/lib/python3.12/site-packages/kafka/protocol/admin.py @@ -1,14 +1,67 @@ from __future__ import absolute_import -# enum in stdlib as of py3.4 -try: - from enum import IntEnum # pylint: disable=import-error -except ImportError: - # vendored backport module - from kafka.vendor.enum34 import IntEnum - from kafka.protocol.api import Request, Response -from kafka.protocol.types import Array, Boolean, Bytes, Int8, Int16, Int32, Int64, Schema, String, Float64, CompactString, CompactArray, TaggedFields +from kafka.protocol.types import Array, Boolean, Bytes, Int8, Int16, Int32, Int64, Schema, String + 
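[A usage note on the Schema-driven protocol classes this hunk restores to kafka.protocol.admin (see the ApiVersion classes that follow). Illustrative sketch only, assuming the post-change module layout: each message is a Struct whose field layout lives in SCHEMA, addressed by version through the module-level request/response lists.]

from kafka.protocol.admin import ApiVersionRequest   # as re-added in this hunk

request = ApiVersionRequest[0]()                      # the v0 schema is empty, so no arguments
raw_body = request.encode()                           # Struct.encode() serializes fields per SCHEMA
assert ApiVersionRequest[0].RESPONSE_TYPE.API_KEY == 18   # ApiVersions api key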
+ +class ApiVersionResponse_v0(Response): + API_KEY = 18 + API_VERSION = 0 + SCHEMA = Schema( + ('error_code', Int16), + ('api_versions', Array( + ('api_key', Int16), + ('min_version', Int16), + ('max_version', Int16))) + ) + + +class ApiVersionResponse_v1(Response): + API_KEY = 18 + API_VERSION = 1 + SCHEMA = Schema( + ('error_code', Int16), + ('api_versions', Array( + ('api_key', Int16), + ('min_version', Int16), + ('max_version', Int16))), + ('throttle_time_ms', Int32) + ) + + +class ApiVersionResponse_v2(Response): + API_KEY = 18 + API_VERSION = 2 + SCHEMA = ApiVersionResponse_v1.SCHEMA + + +class ApiVersionRequest_v0(Request): + API_KEY = 18 + API_VERSION = 0 + RESPONSE_TYPE = ApiVersionResponse_v0 + SCHEMA = Schema() + + +class ApiVersionRequest_v1(Request): + API_KEY = 18 + API_VERSION = 1 + RESPONSE_TYPE = ApiVersionResponse_v1 + SCHEMA = ApiVersionRequest_v0.SCHEMA + + +class ApiVersionRequest_v2(Request): + API_KEY = 18 + API_VERSION = 2 + RESPONSE_TYPE = ApiVersionResponse_v1 + SCHEMA = ApiVersionRequest_v0.SCHEMA + + +ApiVersionRequest = [ + ApiVersionRequest_v0, ApiVersionRequest_v1, ApiVersionRequest_v2, +] +ApiVersionResponse = [ + ApiVersionResponse_v0, ApiVersionResponse_v1, ApiVersionResponse_v2, +] class CreateTopicsResponse_v0(Response): @@ -186,38 +239,6 @@ DeleteTopicsResponse = [ ] -class DeleteRecordsResponse_v0(Response): - API_KEY = 21 - API_VERSION = 0 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('topics', Array( - ('name', String('utf-8')), - ('partitions', Array( - ('partition_index', Int32), - ('low_watermark', Int64), - ('error_code', Int16))))), - ) - - -class DeleteRecordsRequest_v0(Request): - API_KEY = 21 - API_VERSION = 0 - RESPONSE_TYPE = DeleteRecordsResponse_v0 - SCHEMA = Schema( - ('topics', Array( - ('name', String('utf-8')), - ('partitions', Array( - ('partition_index', Int32), - ('offset', Int64))))), - ('timeout_ms', Int32) - ) - - -DeleteRecordsResponse = [DeleteRecordsResponse_v0] -DeleteRecordsRequest = [DeleteRecordsRequest_v0] - - class ListGroupsResponse_v0(Response): API_KEY = 16 API_VERSION = 0 @@ -385,6 +406,41 @@ DescribeGroupsResponse = [ ] +class SaslHandShakeResponse_v0(Response): + API_KEY = 17 + API_VERSION = 0 + SCHEMA = Schema( + ('error_code', Int16), + ('enabled_mechanisms', Array(String('utf-8'))) + ) + + +class SaslHandShakeResponse_v1(Response): + API_KEY = 17 + API_VERSION = 1 + SCHEMA = SaslHandShakeResponse_v0.SCHEMA + + +class SaslHandShakeRequest_v0(Request): + API_KEY = 17 + API_VERSION = 0 + RESPONSE_TYPE = SaslHandShakeResponse_v0 + SCHEMA = Schema( + ('mechanism', String('utf-8')) + ) + + +class SaslHandShakeRequest_v1(Request): + API_KEY = 17 + API_VERSION = 1 + RESPONSE_TYPE = SaslHandShakeResponse_v1 + SCHEMA = SaslHandShakeRequest_v0.SCHEMA + + +SaslHandShakeRequest = [SaslHandShakeRequest_v0, SaslHandShakeRequest_v1] +SaslHandShakeResponse = [SaslHandShakeResponse_v0, SaslHandShakeResponse_v1] + + class DescribeAclsResponse_v0(Response): API_KEY = 29 API_VERSION = 0 @@ -467,8 +523,8 @@ class DescribeAclsRequest_v2(Request): SCHEMA = DescribeAclsRequest_v1.SCHEMA -DescribeAclsRequest = [DescribeAclsRequest_v0, DescribeAclsRequest_v1, DescribeAclsRequest_v2] -DescribeAclsResponse = [DescribeAclsResponse_v0, DescribeAclsResponse_v1, DescribeAclsResponse_v2] +DescribeAclsRequest = [DescribeAclsRequest_v0, DescribeAclsRequest_v1] +DescribeAclsResponse = [DescribeAclsResponse_v0, DescribeAclsResponse_v1] class CreateAclsResponse_v0(Response): API_KEY = 30 @@ -663,7 +719,7 @@ class 
DescribeConfigsResponse_v1(Response): ('config_names', String('utf-8')), ('config_value', String('utf-8')), ('read_only', Boolean), - ('config_source', Int8), + ('is_default', Boolean), ('is_sensitive', Boolean), ('config_synonyms', Array( ('config_name', String('utf-8')), @@ -734,47 +790,6 @@ DescribeConfigsResponse = [ ] -class DescribeLogDirsResponse_v0(Response): - API_KEY = 35 - API_VERSION = 0 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('log_dirs', Array( - ('error_code', Int16), - ('log_dir', String('utf-8')), - ('topics', Array( - ('name', String('utf-8')), - ('partitions', Array( - ('partition_index', Int32), - ('partition_size', Int64), - ('offset_lag', Int64), - ('is_future_key', Boolean) - )) - )) - )) - ) - - -class DescribeLogDirsRequest_v0(Request): - API_KEY = 35 - API_VERSION = 0 - RESPONSE_TYPE = DescribeLogDirsResponse_v0 - SCHEMA = Schema( - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Int32) - )) - ) - - -DescribeLogDirsResponse = [ - DescribeLogDirsResponse_v0, -] -DescribeLogDirsRequest = [ - DescribeLogDirsRequest_v0, -] - - class SaslAuthenticateResponse_v0(Response): API_KEY = 36 API_VERSION = 0 @@ -908,208 +923,3 @@ DeleteGroupsRequest = [ DeleteGroupsResponse = [ DeleteGroupsResponse_v0, DeleteGroupsResponse_v1 ] - - -class DescribeClientQuotasResponse_v0(Response): - API_KEY = 48 - API_VERSION = 0 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16), - ('error_message', String('utf-8')), - ('entries', Array( - ('entity', Array( - ('entity_type', String('utf-8')), - ('entity_name', String('utf-8')))), - ('values', Array( - ('name', String('utf-8')), - ('value', Float64))))), - ) - - -class DescribeClientQuotasRequest_v0(Request): - API_KEY = 48 - API_VERSION = 0 - RESPONSE_TYPE = DescribeClientQuotasResponse_v0 - SCHEMA = Schema( - ('components', Array( - ('entity_type', String('utf-8')), - ('match_type', Int8), - ('match', String('utf-8')), - )), - ('strict', Boolean) - ) - - -DescribeClientQuotasRequest = [ - DescribeClientQuotasRequest_v0, -] - -DescribeClientQuotasResponse = [ - DescribeClientQuotasResponse_v0, -] - - -class AlterPartitionReassignmentsResponse_v0(Response): - API_KEY = 45 - API_VERSION = 0 - SCHEMA = Schema( - ("throttle_time_ms", Int32), - ("error_code", Int16), - ("error_message", CompactString("utf-8")), - ("responses", CompactArray( - ("name", CompactString("utf-8")), - ("partitions", CompactArray( - ("partition_index", Int32), - ("error_code", Int16), - ("error_message", CompactString("utf-8")), - ("tags", TaggedFields) - )), - ("tags", TaggedFields) - )), - ("tags", TaggedFields) - ) - FLEXIBLE_VERSION = True - - -class AlterPartitionReassignmentsRequest_v0(Request): - FLEXIBLE_VERSION = True - API_KEY = 45 - API_VERSION = 0 - RESPONSE_TYPE = AlterPartitionReassignmentsResponse_v0 - SCHEMA = Schema( - ("timeout_ms", Int32), - ("topics", CompactArray( - ("name", CompactString("utf-8")), - ("partitions", CompactArray( - ("partition_index", Int32), - ("replicas", CompactArray(Int32)), - ("tags", TaggedFields) - )), - ("tags", TaggedFields) - )), - ("tags", TaggedFields) - ) - - -AlterPartitionReassignmentsRequest = [AlterPartitionReassignmentsRequest_v0] - -AlterPartitionReassignmentsResponse = [AlterPartitionReassignmentsResponse_v0] - - -class ListPartitionReassignmentsResponse_v0(Response): - API_KEY = 46 - API_VERSION = 0 - SCHEMA = Schema( - ("throttle_time_ms", Int32), - ("error_code", Int16), - ("error_message", CompactString("utf-8")), - ("topics", CompactArray( - ("name", 
CompactString("utf-8")), - ("partitions", CompactArray( - ("partition_index", Int32), - ("replicas", CompactArray(Int32)), - ("adding_replicas", CompactArray(Int32)), - ("removing_replicas", CompactArray(Int32)), - ("tags", TaggedFields) - )), - ("tags", TaggedFields) - )), - ("tags", TaggedFields) - ) - FLEXIBLE_VERSION = True - - -class ListPartitionReassignmentsRequest_v0(Request): - FLEXIBLE_VERSION = True - API_KEY = 46 - API_VERSION = 0 - RESPONSE_TYPE = ListPartitionReassignmentsResponse_v0 - SCHEMA = Schema( - ("timeout_ms", Int32), - ("topics", CompactArray( - ("name", CompactString("utf-8")), - ("partition_index", CompactArray(Int32)), - ("tags", TaggedFields) - )), - ("tags", TaggedFields) - ) - - -ListPartitionReassignmentsRequest = [ListPartitionReassignmentsRequest_v0] - -ListPartitionReassignmentsResponse = [ListPartitionReassignmentsResponse_v0] - - -class ElectLeadersResponse_v0(Response): - API_KEY = 43 - API_VERSION = 1 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16), - ('replication_election_results', Array( - ('topic', String('utf-8')), - ('partition_result', Array( - ('partition_id', Int32), - ('error_code', Int16), - ('error_message', String('utf-8')) - )) - )) - ) - - -class ElectLeadersRequest_v0(Request): - API_KEY = 43 - API_VERSION = 1 - RESPONSE_TYPE = ElectLeadersResponse_v0 - SCHEMA = Schema( - ('election_type', Int8), - ('topic_partitions', Array( - ('topic', String('utf-8')), - ('partition_ids', Array(Int32)) - )), - ('timeout', Int32), - ) - - -class ElectLeadersResponse_v1(Response): - API_KEY = 43 - API_VERSION = 1 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16), - ('replication_election_results', Array( - ('topic', String('utf-8')), - ('partition_result', Array( - ('partition_id', Int32), - ('error_code', Int16), - ('error_message', String('utf-8')) - )) - )) - ) - - -class ElectLeadersRequest_v1(Request): - API_KEY = 43 - API_VERSION = 1 - RESPONSE_TYPE = ElectLeadersResponse_v1 - SCHEMA = Schema( - ('election_type', Int8), - ('topic_partitions', Array( - ('topic', String('utf-8')), - ('partition_ids', Array(Int32)) - )), - ('timeout', Int32), - ) - - -class ElectionType(IntEnum): - """ Leader election type - """ - - PREFERRED = 0, - UNCLEAN = 1 - - -ElectLeadersRequest = [ElectLeadersRequest_v0, ElectLeadersRequest_v1] -ElectLeadersResponse = [ElectLeadersResponse_v0, ElectLeadersResponse_v1] diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/api.py b/venv/lib/python3.12/site-packages/kafka/protocol/api.py index 9cd5767..64276fc 100644 --- a/venv/lib/python3.12/site-packages/kafka/protocol/api.py +++ b/venv/lib/python3.12/site-packages/kafka/protocol/api.py @@ -3,9 +3,7 @@ from __future__ import absolute_import import abc from kafka.protocol.struct import Struct -from kafka.protocol.types import Int16, Int32, String, Schema, Array, TaggedFields - -from kafka.vendor.six import add_metaclass +from kafka.protocol.types import Int16, Int32, String, Schema, Array class RequestHeader(Struct): @@ -22,38 +20,8 @@ class RequestHeader(Struct): ) -class RequestHeaderV2(Struct): - # Flexible response / request headers end in field buffer - SCHEMA = Schema( - ('api_key', Int16), - ('api_version', Int16), - ('correlation_id', Int32), - ('client_id', String('utf-8')), - ('tags', TaggedFields), - ) - - def __init__(self, request, correlation_id=0, client_id='kafka-python', tags=None): - super(RequestHeaderV2, self).__init__( - request.API_KEY, request.API_VERSION, correlation_id, client_id, tags or {} - ) 
- - -class ResponseHeader(Struct): - SCHEMA = Schema( - ('correlation_id', Int32), - ) - - -class ResponseHeaderV2(Struct): - SCHEMA = Schema( - ('correlation_id', Int32), - ('tags', TaggedFields), - ) - - -@add_metaclass(abc.ABCMeta) class Request(Struct): - FLEXIBLE_VERSION = False + __metaclass__ = abc.ABCMeta @abc.abstractproperty def API_KEY(self): @@ -82,15 +50,9 @@ class Request(Struct): def to_object(self): return _to_object(self.SCHEMA, self) - def build_header(self, correlation_id, client_id): - if self.FLEXIBLE_VERSION: - return RequestHeaderV2(self, correlation_id=correlation_id, client_id=client_id) - return RequestHeader(self, correlation_id=correlation_id, client_id=client_id) - -@add_metaclass(abc.ABCMeta) class Response(Struct): - FLEXIBLE_VERSION = False + __metaclass__ = abc.ABCMeta @abc.abstractproperty def API_KEY(self): @@ -110,12 +72,6 @@ class Response(Struct): def to_object(self): return _to_object(self.SCHEMA, self) - @classmethod - def parse_header(cls, read_buffer): - if cls.FLEXIBLE_VERSION: - return ResponseHeaderV2.decode(read_buffer) - return ResponseHeader.decode(read_buffer) - def _to_object(schema, data): obj = {} diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/api_versions.py b/venv/lib/python3.12/site-packages/kafka/protocol/api_versions.py deleted file mode 100644 index e7cedd9..0000000 --- a/venv/lib/python3.12/site-packages/kafka/protocol/api_versions.py +++ /dev/null @@ -1,134 +0,0 @@ -from __future__ import absolute_import - -from io import BytesIO - -from kafka.protocol.api import Request, Response -from kafka.protocol.types import Array, CompactArray, CompactString, Int16, Int32, Schema, TaggedFields - - -class BaseApiVersionsResponse(Response): - API_KEY = 18 - API_VERSION = 0 - SCHEMA = Schema( - ('error_code', Int16), - ('api_versions', Array( - ('api_key', Int16), - ('min_version', Int16), - ('max_version', Int16))) - ) - - @classmethod - def decode(cls, data): - if isinstance(data, bytes): - data = BytesIO(data) - # Check error_code, decode as v0 if any error - curr = data.tell() - err = Int16.decode(data) - data.seek(curr) - if err != 0: - return ApiVersionsResponse_v0.decode(data) - return super(BaseApiVersionsResponse, cls).decode(data) - - -class ApiVersionsResponse_v0(Response): - API_KEY = 18 - API_VERSION = 0 - SCHEMA = Schema( - ('error_code', Int16), - ('api_versions', Array( - ('api_key', Int16), - ('min_version', Int16), - ('max_version', Int16))) - ) - - -class ApiVersionsResponse_v1(BaseApiVersionsResponse): - API_KEY = 18 - API_VERSION = 1 - SCHEMA = Schema( - ('error_code', Int16), - ('api_versions', Array( - ('api_key', Int16), - ('min_version', Int16), - ('max_version', Int16))), - ('throttle_time_ms', Int32) - ) - - -class ApiVersionsResponse_v2(BaseApiVersionsResponse): - API_KEY = 18 - API_VERSION = 2 - SCHEMA = ApiVersionsResponse_v1.SCHEMA - - -class ApiVersionsResponse_v3(BaseApiVersionsResponse): - API_KEY = 18 - API_VERSION = 3 - SCHEMA = Schema( - ('error_code', Int16), - ('api_versions', CompactArray( - ('api_key', Int16), - ('min_version', Int16), - ('max_version', Int16), - ('_tagged_fields', TaggedFields))), - ('throttle_time_ms', Int32), - ('_tagged_fields', TaggedFields) - ) - # Note: ApiVersions Response does not send FLEXIBLE_VERSION header! 
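The removed note above matters for bootstrapping: the ApiVersions response header stays non-flexible even for flexible request versions, and in practice a broker that does not support the requested version answers with an error encoded in the v0 layout. That is why the removed BaseApiVersionsResponse.decode peeks at the leading error_code before committing to a schema. A self-contained sketch of that peek, assuming the body starts with a big-endian INT16 error code (names and bytes here are illustrative):

import struct
from io import BytesIO

def peek_error_code(buf):
    # Read the leading INT16 without consuming it, mirroring the
    # tell()/seek() dance in the removed BaseApiVersionsResponse.decode.
    pos = buf.tell()
    (code,) = struct.unpack('>h', buf.read(2))
    buf.seek(pos)
    return code

body = BytesIO(b'\x00\x23')  # hypothetical payload; 0x0023 = 35 = UNSUPPORTED_VERSION
if peek_error_code(body) != 0:
    pass  # fall back to decoding the body as ApiVersionsResponse_v0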
- - -class ApiVersionsResponse_v4(BaseApiVersionsResponse): - API_KEY = 18 - API_VERSION = 4 - SCHEMA = ApiVersionsResponse_v3.SCHEMA - - -class ApiVersionsRequest_v0(Request): - API_KEY = 18 - API_VERSION = 0 - RESPONSE_TYPE = ApiVersionsResponse_v0 - SCHEMA = Schema() - - -class ApiVersionsRequest_v1(Request): - API_KEY = 18 - API_VERSION = 1 - RESPONSE_TYPE = ApiVersionsResponse_v1 - SCHEMA = ApiVersionsRequest_v0.SCHEMA - - -class ApiVersionsRequest_v2(Request): - API_KEY = 18 - API_VERSION = 2 - RESPONSE_TYPE = ApiVersionsResponse_v2 - SCHEMA = ApiVersionsRequest_v1.SCHEMA - - -class ApiVersionsRequest_v3(Request): - API_KEY = 18 - API_VERSION = 3 - RESPONSE_TYPE = ApiVersionsResponse_v3 - SCHEMA = Schema( - ('client_software_name', CompactString('utf-8')), - ('client_software_version', CompactString('utf-8')), - ('_tagged_fields', TaggedFields) - ) - FLEXIBLE_VERSION = True - - -class ApiVersionsRequest_v4(Request): - API_KEY = 18 - API_VERSION = 4 - RESPONSE_TYPE = ApiVersionsResponse_v4 - SCHEMA = ApiVersionsRequest_v3.SCHEMA - FLEXIBLE_VERSION = True - - -ApiVersionsRequest = [ - ApiVersionsRequest_v0, ApiVersionsRequest_v1, ApiVersionsRequest_v2, - ApiVersionsRequest_v3, ApiVersionsRequest_v4, -] -ApiVersionsResponse = [ - ApiVersionsResponse_v0, ApiVersionsResponse_v1, ApiVersionsResponse_v2, - ApiVersionsResponse_v3, ApiVersionsResponse_v4, -] diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/broker_api_versions.py b/venv/lib/python3.12/site-packages/kafka/protocol/broker_api_versions.py deleted file mode 100644 index af142d0..0000000 --- a/venv/lib/python3.12/site-packages/kafka/protocol/broker_api_versions.py +++ /dev/null @@ -1,68 +0,0 @@ -BROKER_API_VERSIONS = { - # api_versions responses prior to (0, 10) are synthesized for compatibility - (0, 8, 0): {0: (0, 0), 1: (0, 0), 2: (0, 0), 3: (0, 0)}, - # adds offset commit + fetch - (0, 8, 1): {0: (0, 0), 1: (0, 0), 2: (0, 0), 3: (0, 0), 8: (0, 0), 9: (0, 0)}, - # adds find coordinator - (0, 8, 2): {0: (0, 0), 1: (0, 0), 2: (0, 0), 3: (0, 0), 8: (0, 1), 9: (0, 1), 10: (0, 0)}, - # adds group management (join/sync/leave/heartbeat) - (0, 9): {0: (0, 1), 1: (0, 1), 2: (0, 0), 3: (0, 0), 8: (0, 2), 9: (0, 1), 10: (0, 0), 11: (0, 0), 12: (0, 0), 13: (0, 0), 14: (0, 0), 15: (0, 0), 16: (0, 0)}, - # adds message format v1, sasl, and api versions api - (0, 10, 0): {0: (0, 2), 1: (0, 2), 2: (0, 0), 3: (0, 1), 4: (0, 0), 5: (0, 0), 6: (0, 2), 7: (1, 1), 8: (0, 2), 9: (0, 1), 10: (0, 0), 11: (0, 0), 12: (0, 0), 13: (0, 0), 14: (0, 0), 15: (0, 0), 16: (0, 0), 17: (0, 0), 18: (0, 0)}, - - # All data below is copied from brokers via api_versions_response (see make servers/*/api_versions) - # adds admin apis create/delete topics, and bumps fetch/listoffsets/metadata/joingroup - (0, 10, 1): {0: (0, 2), 1: (0, 3), 2: (0, 1), 3: (0, 2), 4: (0, 0), 5: (0, 0), 6: (0, 2), 7: (1, 1), 8: (0, 2), 9: (0, 1), 10: (0, 0), 11: (0, 1), 12: (0, 0), 13: (0, 0), 14: (0, 0), 15: (0, 0), 16: (0, 0), 17: (0, 0), 18: (0, 0), 19: (0, 0), 20: (0, 0)}, - - # bumps offsetfetch/create-topics - (0, 10, 2): {0: (0, 2), 1: (0, 3), 2: (0, 1), 3: (0, 2), 4: (0, 0), 5: (0, 0), 6: (0, 3), 7: (1, 1), 8: (0, 2), 9: (0, 2), 10: (0, 0), 11: (0, 1), 12: (0, 0), 13: (0, 0), 14: (0, 0), 15: (0, 0), 16: (0, 0), 17: (0, 0), 18: (0, 0), 19: (0, 1), 20: (0, 0)}, - - # Adds message format v2, and more admin apis (describe/create/delete acls, describe/alter configs, etc) - (0, 11): {0: (0, 3), 1: (0, 5), 2: (0, 2), 3: (0, 4), 4: (0, 0), 5: (0, 0), 6: (0, 3), 7: (1, 1), 8: (0, 
3), 9: (0, 3), 10: (0, 1), 11: (0, 2), 12: (0, 1), 13: (0, 1), 14: (0, 1), 15: (0, 1), 16: (0, 1), 17: (0, 0), 18: (0, 1), 19: (0, 2), 20: (0, 1), 21: (0, 0), 22: (0, 0), 23: (0, 0), 24: (0, 0), 25: (0, 0), 26: (0, 0), 27: (0, 0), 28: (0, 0), 29: (0, 0), 30: (0, 0), 31: (0, 0), 32: (0, 0), 33: (0, 0)}, - - # Adds Sasl Authenticate, and additional admin apis (describe/alter log dirs, etc) - (1, 0): {0: (0, 5), 1: (0, 6), 2: (0, 2), 3: (0, 5), 4: (0, 1), 5: (0, 0), 6: (0, 4), 7: (0, 1), 8: (0, 3), 9: (0, 3), 10: (0, 1), 11: (0, 2), 12: (0, 1), 13: (0, 1), 14: (0, 1), 15: (0, 1), 16: (0, 1), 17: (0, 1), 18: (0, 1), 19: (0, 2), 20: (0, 1), 21: (0, 0), 22: (0, 0), 23: (0, 0), 24: (0, 0), 25: (0, 0), 26: (0, 0), 27: (0, 0), 28: (0, 0), 29: (0, 0), 30: (0, 0), 31: (0, 0), 32: (0, 0), 33: (0, 0), 34: (0, 0), 35: (0, 0), 36: (0, 0), 37: (0, 0)}, - - (1, 1): {0: (0, 5), 1: (0, 7), 2: (0, 2), 3: (0, 5), 4: (0, 1), 5: (0, 0), 6: (0, 4), 7: (0, 1), 8: (0, 3), 9: (0, 3), 10: (0, 1), 11: (0, 2), 12: (0, 1), 13: (0, 1), 14: (0, 1), 15: (0, 1), 16: (0, 1), 17: (0, 1), 18: (0, 1), 19: (0, 2), 20: (0, 1), 21: (0, 0), 22: (0, 0), 23: (0, 0), 24: (0, 0), 25: (0, 0), 26: (0, 0), 27: (0, 0), 28: (0, 0), 29: (0, 0), 30: (0, 0), 31: (0, 0), 32: (0, 1), 33: (0, 0), 34: (0, 0), 35: (0, 0), 36: (0, 0), 37: (0, 0), 38: (0, 0), 39: (0, 0), 40: (0, 0), 41: (0, 0), 42: (0, 0)}, - - (2, 0): {0: (0, 6), 1: (0, 8), 2: (0, 3), 3: (0, 6), 4: (0, 1), 5: (0, 0), 6: (0, 4), 7: (0, 1), 8: (0, 4), 9: (0, 4), 10: (0, 2), 11: (0, 3), 12: (0, 2), 13: (0, 2), 14: (0, 2), 15: (0, 2), 16: (0, 2), 17: (0, 1), 18: (0, 2), 19: (0, 3), 20: (0, 2), 21: (0, 1), 22: (0, 1), 23: (0, 1), 24: (0, 1), 25: (0, 1), 26: (0, 1), 27: (0, 0), 28: (0, 1), 29: (0, 1), 30: (0, 1), 31: (0, 1), 32: (0, 2), 33: (0, 1), 34: (0, 1), 35: (0, 1), 36: (0, 0), 37: (0, 1), 38: (0, 1), 39: (0, 1), 40: (0, 1), 41: (0, 1), 42: (0, 1)}, - - (2, 1): {0: (0, 7), 1: (0, 10), 2: (0, 4), 3: (0, 7), 4: (0, 1), 5: (0, 0), 6: (0, 4), 7: (0, 1), 8: (0, 6), 9: (0, 5), 10: (0, 2), 11: (0, 3), 12: (0, 2), 13: (0, 2), 14: (0, 2), 15: (0, 2), 16: (0, 2), 17: (0, 1), 18: (0, 2), 19: (0, 3), 20: (0, 3), 21: (0, 1), 22: (0, 1), 23: (0, 2), 24: (0, 1), 25: (0, 1), 26: (0, 1), 27: (0, 0), 28: (0, 2), 29: (0, 1), 30: (0, 1), 31: (0, 1), 32: (0, 2), 33: (0, 1), 34: (0, 1), 35: (0, 1), 36: (0, 0), 37: (0, 1), 38: (0, 1), 39: (0, 1), 40: (0, 1), 41: (0, 1), 42: (0, 1)}, - - (2, 2): {0: (0, 7), 1: (0, 10), 2: (0, 5), 3: (0, 7), 4: (0, 2), 5: (0, 1), 6: (0, 5), 7: (0, 2), 8: (0, 6), 9: (0, 5), 10: (0, 2), 11: (0, 4), 12: (0, 2), 13: (0, 2), 14: (0, 2), 15: (0, 2), 16: (0, 2), 17: (0, 1), 18: (0, 2), 19: (0, 3), 20: (0, 3), 21: (0, 1), 22: (0, 1), 23: (0, 2), 24: (0, 1), 25: (0, 1), 26: (0, 1), 27: (0, 0), 28: (0, 2), 29: (0, 1), 30: (0, 1), 31: (0, 1), 32: (0, 2), 33: (0, 1), 34: (0, 1), 35: (0, 1), 36: (0, 1), 37: (0, 1), 38: (0, 1), 39: (0, 1), 40: (0, 1), 41: (0, 1), 42: (0, 1), 43: (0, 0)}, - - (2, 3): {0: (0, 7), 1: (0, 11), 2: (0, 5), 3: (0, 8), 4: (0, 2), 5: (0, 1), 6: (0, 5), 7: (0, 2), 8: (0, 7), 9: (0, 5), 10: (0, 2), 11: (0, 5), 12: (0, 3), 13: (0, 2), 14: (0, 3), 15: (0, 3), 16: (0, 2), 17: (0, 1), 18: (0, 2), 19: (0, 3), 20: (0, 3), 21: (0, 1), 22: (0, 1), 23: (0, 3), 24: (0, 1), 25: (0, 1), 26: (0, 1), 27: (0, 0), 28: (0, 2), 29: (0, 1), 30: (0, 1), 31: (0, 1), 32: (0, 2), 33: (0, 1), 34: (0, 1), 35: (0, 1), 36: (0, 1), 37: (0, 1), 38: (0, 1), 39: (0, 1), 40: (0, 1), 41: (0, 1), 42: (0, 1), 43: (0, 0), 44: (0, 0)}, - - (2, 4): {0: (0, 8), 1: (0, 11), 2: (0, 5), 3: (0, 9), 4: (0, 
4), 5: (0, 2), 6: (0, 6), 7: (0, 3), 8: (0, 8), 9: (0, 6), 10: (0, 3), 11: (0, 6), 12: (0, 4), 13: (0, 4), 14: (0, 4), 15: (0, 5), 16: (0, 3), 17: (0, 1), 18: (0, 3), 19: (0, 5), 20: (0, 4), 21: (0, 1), 22: (0, 2), 23: (0, 3), 24: (0, 1), 25: (0, 1), 26: (0, 1), 27: (0, 0), 28: (0, 2), 29: (0, 1), 30: (0, 1), 31: (0, 1), 32: (0, 2), 33: (0, 1), 34: (0, 1), 35: (0, 1), 36: (0, 1), 37: (0, 1), 38: (0, 2), 39: (0, 1), 40: (0, 1), 41: (0, 1), 42: (0, 2), 43: (0, 2), 44: (0, 1), 45: (0, 0), 46: (0, 0), 47: (0, 0)}, - - (2, 5): {0: (0, 8), 1: (0, 11), 2: (0, 5), 3: (0, 9), 4: (0, 4), 5: (0, 2), 6: (0, 6), 7: (0, 3), 8: (0, 8), 9: (0, 7), 10: (0, 3), 11: (0, 7), 12: (0, 4), 13: (0, 4), 14: (0, 5), 15: (0, 5), 16: (0, 3), 17: (0, 1), 18: (0, 3), 19: (0, 5), 20: (0, 4), 21: (0, 1), 22: (0, 3), 23: (0, 3), 24: (0, 1), 25: (0, 1), 26: (0, 1), 27: (0, 0), 28: (0, 3), 29: (0, 2), 30: (0, 2), 31: (0, 2), 32: (0, 2), 33: (0, 1), 34: (0, 1), 35: (0, 1), 36: (0, 2), 37: (0, 2), 38: (0, 2), 39: (0, 2), 40: (0, 2), 41: (0, 2), 42: (0, 2), 43: (0, 2), 44: (0, 1), 45: (0, 0), 46: (0, 0), 47: (0, 0)}, - - (2, 6): {0: (0, 8), 1: (0, 11), 2: (0, 5), 3: (0, 9), 4: (0, 4), 5: (0, 3), 6: (0, 6), 7: (0, 3), 8: (0, 8), 9: (0, 7), 10: (0, 3), 11: (0, 7), 12: (0, 4), 13: (0, 4), 14: (0, 5), 15: (0, 5), 16: (0, 4), 17: (0, 1), 18: (0, 3), 19: (0, 5), 20: (0, 4), 21: (0, 2), 22: (0, 3), 23: (0, 3), 24: (0, 1), 25: (0, 1), 26: (0, 1), 27: (0, 0), 28: (0, 3), 29: (0, 2), 30: (0, 2), 31: (0, 2), 32: (0, 3), 33: (0, 1), 34: (0, 1), 35: (0, 2), 36: (0, 2), 37: (0, 2), 38: (0, 2), 39: (0, 2), 40: (0, 2), 41: (0, 2), 42: (0, 2), 43: (0, 2), 44: (0, 1), 45: (0, 0), 46: (0, 0), 47: (0, 0), 48: (0, 0), 49: (0, 0)}, - - (2, 7): {0: (0, 8), 1: (0, 12), 2: (0, 5), 3: (0, 9), 4: (0, 4), 5: (0, 3), 6: (0, 6), 7: (0, 3), 8: (0, 8), 9: (0, 7), 10: (0, 3), 11: (0, 7), 12: (0, 4), 13: (0, 4), 14: (0, 5), 15: (0, 5), 16: (0, 4), 17: (0, 1), 18: (0, 3), 19: (0, 6), 20: (0, 5), 21: (0, 2), 22: (0, 4), 23: (0, 3), 24: (0, 2), 25: (0, 2), 26: (0, 2), 27: (0, 0), 28: (0, 3), 29: (0, 2), 30: (0, 2), 31: (0, 2), 32: (0, 3), 33: (0, 1), 34: (0, 1), 35: (0, 2), 36: (0, 2), 37: (0, 3), 38: (0, 2), 39: (0, 2), 40: (0, 2), 41: (0, 2), 42: (0, 2), 43: (0, 2), 44: (0, 1), 45: (0, 0), 46: (0, 0), 47: (0, 0), 48: (0, 0), 49: (0, 0), 50: (0, 0), 51: (0, 0), 56: (0, 0), 57: (0, 0)}, - - (2, 8): {0: (0, 9), 1: (0, 12), 2: (0, 6), 3: (0, 11), 4: (0, 5), 5: (0, 3), 6: (0, 7), 7: (0, 3), 8: (0, 8), 9: (0, 7), 10: (0, 3), 11: (0, 7), 12: (0, 4), 13: (0, 4), 14: (0, 5), 15: (0, 5), 16: (0, 4), 17: (0, 1), 18: (0, 3), 19: (0, 7), 20: (0, 6), 21: (0, 2), 22: (0, 4), 23: (0, 4), 24: (0, 3), 25: (0, 3), 26: (0, 3), 27: (0, 1), 28: (0, 3), 29: (0, 2), 30: (0, 2), 31: (0, 2), 32: (0, 4), 33: (0, 2), 34: (0, 2), 35: (0, 2), 36: (0, 2), 37: (0, 3), 38: (0, 2), 39: (0, 2), 40: (0, 2), 41: (0, 2), 42: (0, 2), 43: (0, 2), 44: (0, 1), 45: (0, 0), 46: (0, 0), 47: (0, 0), 48: (0, 1), 49: (0, 1), 50: (0, 0), 51: (0, 0), 56: (0, 0), 57: (0, 0), 60: (0, 0), 61: (0, 0)}, - - (3, 0): {0: (0, 9), 1: (0, 12), 2: (0, 7), 3: (0, 11), 4: (0, 5), 5: (0, 3), 6: (0, 7), 7: (0, 3), 8: (0, 8), 9: (0, 8), 10: (0, 4), 11: (0, 7), 12: (0, 4), 13: (0, 4), 14: (0, 5), 15: (0, 5), 16: (0, 4), 17: (0, 1), 18: (0, 3), 19: (0, 7), 20: (0, 6), 21: (0, 2), 22: (0, 4), 23: (0, 4), 24: (0, 3), 25: (0, 3), 26: (0, 3), 27: (0, 1), 28: (0, 3), 29: (0, 2), 30: (0, 2), 31: (0, 2), 32: (0, 4), 33: (0, 2), 34: (0, 2), 35: (0, 2), 36: (0, 2), 37: (0, 3), 38: (0, 2), 39: (0, 2), 40: (0, 2), 41: (0, 2), 42: (0, 2), 
43: (0, 2), 44: (0, 1), 45: (0, 0), 46: (0, 0), 47: (0, 0), 48: (0, 1), 49: (0, 1), 50: (0, 0), 51: (0, 0), 56: (0, 0), 57: (0, 0), 60: (0, 0), 61: (0, 0), 65: (0, 0), 66: (0, 0), 67: (0, 0)}, - - (3, 1): {0: (0, 9), 1: (0, 13), 2: (0, 7), 3: (0, 12), 4: (0, 5), 5: (0, 3), 6: (0, 7), 7: (0, 3), 8: (0, 8), 9: (0, 8), 10: (0, 4), 11: (0, 7), 12: (0, 4), 13: (0, 4), 14: (0, 5), 15: (0, 5), 16: (0, 4), 17: (0, 1), 18: (0, 3), 19: (0, 7), 20: (0, 6), 21: (0, 2), 22: (0, 4), 23: (0, 4), 24: (0, 3), 25: (0, 3), 26: (0, 3), 27: (0, 1), 28: (0, 3), 29: (0, 2), 30: (0, 2), 31: (0, 2), 32: (0, 4), 33: (0, 2), 34: (0, 2), 35: (0, 2), 36: (0, 2), 37: (0, 3), 38: (0, 2), 39: (0, 2), 40: (0, 2), 41: (0, 2), 42: (0, 2), 43: (0, 2), 44: (0, 1), 45: (0, 0), 46: (0, 0), 47: (0, 0), 48: (0, 1), 49: (0, 1), 50: (0, 0), 51: (0, 0), 56: (0, 0), 57: (0, 0), 60: (0, 0), 61: (0, 0), 65: (0, 0), 66: (0, 0), 67: (0, 0)}, - - (3, 2): {0: (0, 9), 1: (0, 13), 2: (0, 7), 3: (0, 12), 4: (0, 6), 5: (0, 3), 6: (0, 7), 7: (0, 3), 8: (0, 8), 9: (0, 8), 10: (0, 4), 11: (0, 9), 12: (0, 4), 13: (0, 5), 14: (0, 5), 15: (0, 5), 16: (0, 4), 17: (0, 1), 18: (0, 3), 19: (0, 7), 20: (0, 6), 21: (0, 2), 22: (0, 4), 23: (0, 4), 24: (0, 3), 25: (0, 3), 26: (0, 3), 27: (0, 1), 28: (0, 3), 29: (0, 2), 30: (0, 2), 31: (0, 2), 32: (0, 4), 33: (0, 2), 34: (0, 2), 35: (0, 3), 36: (0, 2), 37: (0, 3), 38: (0, 2), 39: (0, 2), 40: (0, 2), 41: (0, 2), 42: (0, 2), 43: (0, 2), 44: (0, 1), 45: (0, 0), 46: (0, 0), 47: (0, 0), 48: (0, 1), 49: (0, 1), 50: (0, 0), 51: (0, 0), 56: (0, 1), 57: (0, 0), 60: (0, 0), 61: (0, 0), 65: (0, 0), 66: (0, 0), 67: (0, 0)}, - - (3, 3): {0: (0, 9), 1: (0, 13), 2: (0, 7), 3: (0, 12), 4: (0, 6), 5: (0, 3), 6: (0, 7), 7: (0, 3), 8: (0, 8), 9: (0, 8), 10: (0, 4), 11: (0, 9), 12: (0, 4), 13: (0, 5), 14: (0, 5), 15: (0, 5), 16: (0, 4), 17: (0, 1), 18: (0, 3), 19: (0, 7), 20: (0, 6), 21: (0, 2), 22: (0, 4), 23: (0, 4), 24: (0, 3), 25: (0, 3), 26: (0, 3), 27: (0, 1), 28: (0, 3), 29: (0, 3), 30: (0, 3), 31: (0, 3), 32: (0, 4), 33: (0, 2), 34: (0, 2), 35: (0, 4), 36: (0, 2), 37: (0, 3), 38: (0, 3), 39: (0, 2), 40: (0, 2), 41: (0, 3), 42: (0, 2), 43: (0, 2), 44: (0, 1), 45: (0, 0), 46: (0, 0), 47: (0, 0), 48: (0, 1), 49: (0, 1), 50: (0, 0), 51: (0, 0), 56: (0, 2), 57: (0, 1), 60: (0, 0), 61: (0, 0), 65: (0, 0), 66: (0, 0), 67: (0, 0)}, - - (3, 4): {0: (0, 9), 1: (0, 13), 2: (0, 7), 3: (0, 12), 4: (0, 7), 5: (0, 4), 6: (0, 8), 7: (0, 3), 8: (0, 8), 9: (0, 8), 10: (0, 4), 11: (0, 9), 12: (0, 4), 13: (0, 5), 14: (0, 5), 15: (0, 5), 16: (0, 4), 17: (0, 1), 18: (0, 3), 19: (0, 7), 20: (0, 6), 21: (0, 2), 22: (0, 4), 23: (0, 4), 24: (0, 3), 25: (0, 3), 26: (0, 3), 27: (0, 1), 28: (0, 3), 29: (0, 3), 30: (0, 3), 31: (0, 3), 32: (0, 4), 33: (0, 2), 34: (0, 2), 35: (0, 4), 36: (0, 2), 37: (0, 3), 38: (0, 3), 39: (0, 2), 40: (0, 2), 41: (0, 3), 42: (0, 2), 43: (0, 2), 44: (0, 1), 45: (0, 0), 46: (0, 0), 47: (0, 0), 48: (0, 1), 49: (0, 1), 50: (0, 0), 51: (0, 0), 56: (0, 2), 57: (0, 1), 58: (0, 0), 60: (0, 0), 61: (0, 0), 65: (0, 0), 66: (0, 0), 67: (0, 0)}, - - (3, 5): {0: (0, 9), 1: (0, 15), 2: (0, 8), 3: (0, 12), 4: (0, 7), 5: (0, 4), 6: (0, 8), 7: (0, 3), 8: (0, 8), 9: (0, 8), 10: (0, 4), 11: (0, 9), 12: (0, 4), 13: (0, 5), 14: (0, 5), 15: (0, 5), 16: (0, 4), 17: (0, 1), 18: (0, 3), 19: (0, 7), 20: (0, 6), 21: (0, 2), 22: (0, 4), 23: (0, 4), 24: (0, 3), 25: (0, 3), 26: (0, 3), 27: (0, 1), 28: (0, 3), 29: (0, 3), 30: (0, 3), 31: (0, 3), 32: (0, 4), 33: (0, 2), 34: (0, 2), 35: (0, 4), 36: (0, 2), 37: (0, 3), 38: (0, 3), 39: (0, 2), 40: (0, 
2), 41: (0, 3), 42: (0, 2), 43: (0, 2), 44: (0, 1), 45: (0, 0), 46: (0, 0), 47: (0, 0), 48: (0, 1), 49: (0, 1), 50: (0, 0), 51: (0, 0), 56: (0, 3), 57: (0, 1), 58: (0, 0), 60: (0, 0), 61: (0, 0), 65: (0, 0), 66: (0, 0), 67: (0, 0)}, - - (3, 6): {0: (0, 9), 1: (0, 15), 2: (0, 8), 3: (0, 12), 4: (0, 7), 5: (0, 4), 6: (0, 8), 7: (0, 3), 8: (0, 8), 9: (0, 8), 10: (0, 4), 11: (0, 9), 12: (0, 4), 13: (0, 5), 14: (0, 5), 15: (0, 5), 16: (0, 4), 17: (0, 1), 18: (0, 3), 19: (0, 7), 20: (0, 6), 21: (0, 2), 22: (0, 4), 23: (0, 4), 24: (0, 4), 25: (0, 3), 26: (0, 3), 27: (0, 1), 28: (0, 3), 29: (0, 3), 30: (0, 3), 31: (0, 3), 32: (0, 4), 33: (0, 2), 34: (0, 2), 35: (0, 4), 36: (0, 2), 37: (0, 3), 38: (0, 3), 39: (0, 2), 40: (0, 2), 41: (0, 3), 42: (0, 2), 43: (0, 2), 44: (0, 1), 45: (0, 0), 46: (0, 0), 47: (0, 0), 48: (0, 1), 49: (0, 1), 50: (0, 0), 51: (0, 0), 56: (0, 3), 57: (0, 1), 58: (0, 0), 60: (0, 0), 61: (0, 0), 65: (0, 0), 66: (0, 0), 67: (0, 0)}, - - (3, 7): {0: (0, 10), 1: (0, 16), 2: (0, 8), 3: (0, 12), 4: (0, 7), 5: (0, 4), 6: (0, 8), 7: (0, 3), 8: (0, 9), 9: (0, 9), 10: (0, 4), 11: (0, 9), 12: (0, 4), 13: (0, 5), 14: (0, 5), 15: (0, 5), 16: (0, 4), 17: (0, 1), 18: (0, 3), 19: (0, 7), 20: (0, 6), 21: (0, 2), 22: (0, 4), 23: (0, 4), 24: (0, 4), 25: (0, 3), 26: (0, 3), 27: (0, 1), 28: (0, 3), 29: (0, 3), 30: (0, 3), 31: (0, 3), 32: (0, 4), 33: (0, 2), 34: (0, 2), 35: (0, 4), 36: (0, 2), 37: (0, 3), 38: (0, 3), 39: (0, 2), 40: (0, 2), 41: (0, 3), 42: (0, 2), 43: (0, 2), 44: (0, 1), 45: (0, 0), 46: (0, 0), 47: (0, 0), 48: (0, 1), 49: (0, 1), 50: (0, 0), 51: (0, 0), 56: (0, 3), 57: (0, 1), 58: (0, 0), 60: (0, 1), 61: (0, 0), 65: (0, 0), 66: (0, 0), 67: (0, 0), 68: (0, 0)}, - - (3, 8): {0: (0, 11), 1: (0, 16), 2: (0, 8), 3: (0, 12), 4: (0, 7), 5: (0, 4), 6: (0, 8), 7: (0, 3), 8: (0, 9), 9: (0, 9), 10: (0, 5), 11: (0, 9), 12: (0, 4), 13: (0, 5), 14: (0, 5), 15: (0, 5), 16: (0, 5), 17: (0, 1), 18: (0, 3), 19: (0, 7), 20: (0, 6), 21: (0, 2), 22: (0, 5), 23: (0, 4), 24: (0, 5), 25: (0, 4), 26: (0, 4), 27: (0, 1), 28: (0, 4), 29: (0, 3), 30: (0, 3), 31: (0, 3), 32: (0, 4), 33: (0, 2), 34: (0, 2), 35: (0, 4), 36: (0, 2), 37: (0, 3), 38: (0, 3), 39: (0, 2), 40: (0, 2), 41: (0, 3), 42: (0, 2), 43: (0, 2), 44: (0, 1), 45: (0, 0), 46: (0, 0), 47: (0, 0), 48: (0, 1), 49: (0, 1), 50: (0, 0), 51: (0, 0), 56: (0, 3), 57: (0, 1), 58: (0, 0), 60: (0, 1), 61: (0, 0), 65: (0, 0), 66: (0, 1), 67: (0, 0), 68: (0, 0), 69: (0, 0)}, - - (3, 9): {0: (0, 11), 1: (0, 17), 2: (0, 9), 3: (0, 12), 4: (0, 7), 5: (0, 4), 6: (0, 8), 7: (0, 3), 8: (0, 9), 9: (0, 9), 10: (0, 6), 11: (0, 9), 12: (0, 4), 13: (0, 5), 14: (0, 5), 15: (0, 5), 16: (0, 5), 17: (0, 1), 18: (0, 4), 19: (0, 7), 20: (0, 6), 21: (0, 2), 22: (0, 5), 23: (0, 4), 24: (0, 5), 25: (0, 4), 26: (0, 4), 27: (0, 1), 28: (0, 4), 29: (0, 3), 30: (0, 3), 31: (0, 3), 32: (0, 4), 33: (0, 2), 34: (0, 2), 35: (0, 4), 36: (0, 2), 37: (0, 3), 38: (0, 3), 39: (0, 2), 40: (0, 2), 41: (0, 3), 42: (0, 2), 43: (0, 2), 44: (0, 1), 45: (0, 0), 46: (0, 0), 47: (0, 0), 48: (0, 1), 49: (0, 1), 50: (0, 0), 51: (0, 0), 56: (0, 3), 57: (0, 1), 58: (0, 0), 60: (0, 1), 61: (0, 0), 65: (0, 0), 66: (0, 1), 67: (0, 0), 68: (0, 0), 69: (0, 0)}, - - (4, 0): {0: (0, 12), 1: (4, 17), 2: (1, 10), 3: (0, 13), 8: (2, 9), 9: (1, 9), 10: (0, 6), 11: (2, 9), 12: (0, 4), 13: (0, 5), 14: (0, 5), 15: (0, 6), 16: (0, 5), 17: (0, 1), 18: (0, 4), 19: (2, 7), 20: (1, 6), 21: (0, 2), 22: (0, 5), 23: (2, 4), 24: (0, 5), 25: (0, 4), 26: (0, 5), 27: (1, 1), 28: (0, 5), 29: (1, 3), 30: (1, 3), 31: (1, 3), 32: (1, 
4), 33: (0, 2), 34: (1, 2), 35: (1, 4), 36: (0, 2), 37: (0, 3), 38: (1, 3), 39: (1, 2), 40: (1, 2), 41: (1, 3), 42: (0, 2), 43: (0, 2), 44: (0, 1), 45: (0, 0), 46: (0, 0), 47: (0, 0), 48: (0, 1), 49: (0, 1), 50: (0, 0), 51: (0, 0), 55: (0, 2), 57: (0, 2), 60: (0, 2), 61: (0, 0), 64: (0, 0), 65: (0, 0), 66: (0, 1), 68: (0, 1), 69: (0, 1), 74: (0, 0), 75: (0, 0), 80: (0, 0), 81: (0, 0)}, - -} diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/commit.py b/venv/lib/python3.12/site-packages/kafka/protocol/commit.py index a0439e7..31fc237 100644 --- a/venv/lib/python3.12/site-packages/kafka/protocol/commit.py +++ b/venv/lib/python3.12/site-packages/kafka/protocol/commit.py @@ -1,7 +1,7 @@ from __future__ import absolute_import from kafka.protocol.api import Request, Response -from kafka.protocol.types import Array, Int16, Int32, Int64, Schema, String +from kafka.protocol.types import Array, Int8, Int16, Int32, Int64, Schema, String class OffsetCommitResponse_v0(Response): @@ -41,24 +41,6 @@ class OffsetCommitResponse_v3(Response): ) -class OffsetCommitResponse_v4(Response): - API_KEY = 8 - API_VERSION = 4 - SCHEMA = OffsetCommitResponse_v3.SCHEMA - - -class OffsetCommitResponse_v5(Response): - API_KEY = 8 - API_VERSION = 5 - SCHEMA = OffsetCommitResponse_v4.SCHEMA - - -class OffsetCommitResponse_v6(Response): - API_KEY = 8 - API_VERSION = 6 - SCHEMA = OffsetCommitResponse_v5.SCHEMA - - class OffsetCommitRequest_v0(Request): API_KEY = 8 API_VERSION = 0 # Zookeeper-backed storage @@ -94,13 +76,13 @@ class OffsetCommitRequest_v1(Request): class OffsetCommitRequest_v2(Request): API_KEY = 8 - API_VERSION = 2 + API_VERSION = 2 # added retention_time, dropped timestamp RESPONSE_TYPE = OffsetCommitResponse_v2 SCHEMA = Schema( ('consumer_group', String('utf-8')), ('consumer_group_generation_id', Int32), ('consumer_id', String('utf-8')), - ('retention_time', Int64), # added retention_time, dropped timestamp + ('retention_time', Int64), ('topics', Array( ('topic', String('utf-8')), ('partitions', Array( @@ -108,6 +90,7 @@ class OffsetCommitRequest_v2(Request): ('offset', Int64), ('metadata', String('utf-8')))))) ) + DEFAULT_GENERATION_ID = -1 DEFAULT_RETENTION_TIME = -1 @@ -116,63 +99,15 @@ class OffsetCommitRequest_v3(Request): API_VERSION = 3 RESPONSE_TYPE = OffsetCommitResponse_v3 SCHEMA = OffsetCommitRequest_v2.SCHEMA - DEFAULT_RETENTION_TIME = -1 - - -class OffsetCommitRequest_v4(Request): - API_KEY = 8 - API_VERSION = 4 - RESPONSE_TYPE = OffsetCommitResponse_v4 - SCHEMA = OffsetCommitRequest_v3.SCHEMA - DEFAULT_RETENTION_TIME = -1 - - -class OffsetCommitRequest_v5(Request): - API_KEY = 8 - API_VERSION = 5 # drops retention_time - RESPONSE_TYPE = OffsetCommitResponse_v5 - SCHEMA = Schema( - ('consumer_group', String('utf-8')), - ('consumer_group_generation_id', Int32), - ('consumer_id', String('utf-8')), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('offset', Int64), - ('metadata', String('utf-8')))))) - ) - - -class OffsetCommitRequest_v6(Request): - API_KEY = 8 - API_VERSION = 6 - RESPONSE_TYPE = OffsetCommitResponse_v6 - SCHEMA = Schema( - ('consumer_group', String('utf-8')), - ('consumer_group_generation_id', Int32), - ('consumer_id', String('utf-8')), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('offset', Int64), - ('leader_epoch', Int32), # added for fencing / kip-320. 
default -1 - ('metadata', String('utf-8')))))) - ) OffsetCommitRequest = [ OffsetCommitRequest_v0, OffsetCommitRequest_v1, - OffsetCommitRequest_v2, OffsetCommitRequest_v3, - OffsetCommitRequest_v4, OffsetCommitRequest_v5, - OffsetCommitRequest_v6, + OffsetCommitRequest_v2, OffsetCommitRequest_v3 ] OffsetCommitResponse = [ OffsetCommitResponse_v0, OffsetCommitResponse_v1, - OffsetCommitResponse_v2, OffsetCommitResponse_v3, - OffsetCommitResponse_v4, OffsetCommitResponse_v5, - OffsetCommitResponse_v6, + OffsetCommitResponse_v2, OffsetCommitResponse_v3 ] @@ -228,29 +163,6 @@ class OffsetFetchResponse_v3(Response): ) -class OffsetFetchResponse_v4(Response): - API_KEY = 9 - API_VERSION = 4 - SCHEMA = OffsetFetchResponse_v3.SCHEMA - - -class OffsetFetchResponse_v5(Response): - API_KEY = 9 - API_VERSION = 5 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('offset', Int64), - ('leader_epoch', Int32), - ('metadata', String('utf-8')), - ('error_code', Int16))))), - ('error_code', Int16) - ) - - class OffsetFetchRequest_v0(Request): API_KEY = 9 API_VERSION = 0 # zookeeper-backed storage @@ -287,27 +199,57 @@ class OffsetFetchRequest_v3(Request): SCHEMA = OffsetFetchRequest_v2.SCHEMA -class OffsetFetchRequest_v4(Request): - API_KEY = 9 - API_VERSION = 4 - RESPONSE_TYPE = OffsetFetchResponse_v4 - SCHEMA = OffsetFetchRequest_v3.SCHEMA - - -class OffsetFetchRequest_v5(Request): - API_KEY = 9 - API_VERSION = 5 - RESPONSE_TYPE = OffsetFetchResponse_v5 - SCHEMA = OffsetFetchRequest_v4.SCHEMA - - OffsetFetchRequest = [ OffsetFetchRequest_v0, OffsetFetchRequest_v1, OffsetFetchRequest_v2, OffsetFetchRequest_v3, - OffsetFetchRequest_v4, OffsetFetchRequest_v5, ] OffsetFetchResponse = [ OffsetFetchResponse_v0, OffsetFetchResponse_v1, OffsetFetchResponse_v2, OffsetFetchResponse_v3, - OffsetFetchResponse_v4, OffsetFetchResponse_v5, ] + + +class GroupCoordinatorResponse_v0(Response): + API_KEY = 10 + API_VERSION = 0 + SCHEMA = Schema( + ('error_code', Int16), + ('coordinator_id', Int32), + ('host', String('utf-8')), + ('port', Int32) + ) + + +class GroupCoordinatorResponse_v1(Response): + API_KEY = 10 + API_VERSION = 1 + SCHEMA = Schema( + ('error_code', Int16), + ('error_message', String('utf-8')), + ('coordinator_id', Int32), + ('host', String('utf-8')), + ('port', Int32) + ) + + +class GroupCoordinatorRequest_v0(Request): + API_KEY = 10 + API_VERSION = 0 + RESPONSE_TYPE = GroupCoordinatorResponse_v0 + SCHEMA = Schema( + ('consumer_group', String('utf-8')) + ) + + +class GroupCoordinatorRequest_v1(Request): + API_KEY = 10 + API_VERSION = 1 + RESPONSE_TYPE = GroupCoordinatorResponse_v1 + SCHEMA = Schema( + ('coordinator_key', String('utf-8')), + ('coordinator_type', Int8) + ) + + +GroupCoordinatorRequest = [GroupCoordinatorRequest_v0, GroupCoordinatorRequest_v1] +GroupCoordinatorResponse = [GroupCoordinatorResponse_v0, GroupCoordinatorResponse_v1] diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/end_txn.py b/venv/lib/python3.12/site-packages/kafka/protocol/end_txn.py deleted file mode 100644 index 96d6cc5..0000000 --- a/venv/lib/python3.12/site-packages/kafka/protocol/end_txn.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import absolute_import - -from kafka.protocol.api import Request, Response -from kafka.protocol.types import Boolean, Int16, Int32, Int64, Schema, String - - -class EndTxnResponse_v0(Response): - API_KEY = 26 - API_VERSION = 0 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - 
('error_code', Int16), - ) - - -class EndTxnResponse_v1(Response): - API_KEY = 26 - API_VERSION = 1 - SCHEMA = EndTxnResponse_v0.SCHEMA - - -class EndTxnResponse_v2(Response): - API_KEY = 26 - API_VERSION = 2 - SCHEMA = EndTxnResponse_v1.SCHEMA - - -class EndTxnRequest_v0(Request): - API_KEY = 26 - API_VERSION = 0 - RESPONSE_TYPE = EndTxnResponse_v0 - SCHEMA = Schema( - ('transactional_id', String('utf-8')), - ('producer_id', Int64), - ('producer_epoch', Int16), - ('committed', Boolean)) - - -class EndTxnRequest_v1(Request): - API_KEY = 26 - API_VERSION = 1 - RESPONSE_TYPE = EndTxnResponse_v1 - SCHEMA = EndTxnRequest_v0.SCHEMA - - -class EndTxnRequest_v2(Request): - API_KEY = 26 - API_VERSION = 2 - RESPONSE_TYPE = EndTxnResponse_v2 - SCHEMA = EndTxnRequest_v1.SCHEMA - - -EndTxnRequest = [ - EndTxnRequest_v0, EndTxnRequest_v1, EndTxnRequest_v2, -] -EndTxnResponse = [ - EndTxnResponse_v0, EndTxnResponse_v1, EndTxnResponse_v2, -] diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/fetch.py b/venv/lib/python3.12/site-packages/kafka/protocol/fetch.py index 036a37e..f367848 100644 --- a/venv/lib/python3.12/site-packages/kafka/protocol/fetch.py +++ b/venv/lib/python3.12/site-packages/kafka/protocol/fetch.py @@ -1,15 +1,9 @@ from __future__ import absolute_import -import collections - from kafka.protocol.api import Request, Response from kafka.protocol.types import Array, Int8, Int16, Int32, Int64, Schema, String, Bytes -AbortedTransaction = collections.namedtuple("AbortedTransaction", - ["producer_id", "first_offset"]) - - class FetchResponse_v0(Response): API_KEY = 1 API_VERSION = 0 @@ -20,7 +14,7 @@ class FetchResponse_v0(Response): ('partition', Int32), ('error_code', Int16), ('highwater_offset', Int64), - ('records', Bytes))))) + ('message_set', Bytes))))) ) @@ -35,7 +29,7 @@ class FetchResponse_v1(Response): ('partition', Int32), ('error_code', Int16), ('highwater_offset', Int64), - ('records', Bytes))))) + ('message_set', Bytes))))) ) @@ -52,7 +46,6 @@ class FetchResponse_v3(Response): class FetchResponse_v4(Response): - # Adds message format v2 API_KEY = 1 API_VERSION = 4 SCHEMA = Schema( @@ -67,7 +60,7 @@ class FetchResponse_v4(Response): ('aborted_transactions', Array( ('producer_id', Int64), ('first_offset', Int64))), - ('records', Bytes))))) + ('message_set', Bytes))))) ) @@ -87,7 +80,7 @@ class FetchResponse_v5(Response): ('aborted_transactions', Array( ('producer_id', Int64), ('first_offset', Int64))), - ('records', Bytes))))) + ('message_set', Bytes))))) ) @@ -122,7 +115,7 @@ class FetchResponse_v7(Response): ('aborted_transactions', Array( ('producer_id', Int64), ('first_offset', Int64))), - ('records', Bytes))))) + ('message_set', Bytes))))) ) @@ -163,7 +156,7 @@ class FetchResponse_v11(Response): ('producer_id', Int64), ('first_offset', Int64))), ('preferred_read_replica', Int32), - ('records', Bytes))))) + ('message_set', Bytes))))) ) @@ -218,7 +211,6 @@ class FetchRequest_v3(Request): class FetchRequest_v4(Request): # Adds isolation_level field - # Adds message format v2 API_KEY = 1 API_VERSION = 4 RESPONSE_TYPE = FetchResponse_v4 @@ -272,7 +264,7 @@ class FetchRequest_v6(Request): class FetchRequest_v7(Request): """ - Add incremental fetch requests (see KIP-227) + Add incremental fetch requests """ API_KEY = 1 API_VERSION = 7 @@ -293,7 +285,7 @@ class FetchRequest_v7(Request): ('log_start_offset', Int64), ('max_bytes', Int32))))), ('forgotten_topics_data', Array( - ('topic', String('utf-8')), + ('topic', String), ('partitions', Array(Int32)) )), ) @@ -333,7 +325,7 @@ 
class FetchRequest_v9(Request): ('log_start_offset', Int64), ('max_bytes', Int32))))), ('forgotten_topics_data', Array( - ('topic', String('utf-8')), + ('topic', String), ('partitions', Array(Int32)), )), ) @@ -373,7 +365,7 @@ class FetchRequest_v11(Request): ('log_start_offset', Int64), ('max_bytes', Int32))))), ('forgotten_topics_data', Array( - ('topic', String('utf-8')), + ('topic', String), ('partitions', Array(Int32)) )), ('rack_id', String('utf-8')), diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/find_coordinator.py b/venv/lib/python3.12/site-packages/kafka/protocol/find_coordinator.py deleted file mode 100644 index be5b45d..0000000 --- a/venv/lib/python3.12/site-packages/kafka/protocol/find_coordinator.py +++ /dev/null @@ -1,64 +0,0 @@ -from __future__ import absolute_import - -from kafka.protocol.api import Request, Response -from kafka.protocol.types import Int8, Int16, Int32, Schema, String - - -class FindCoordinatorResponse_v0(Response): - API_KEY = 10 - API_VERSION = 0 - SCHEMA = Schema( - ('error_code', Int16), - ('coordinator_id', Int32), - ('host', String('utf-8')), - ('port', Int32) - ) - - -class FindCoordinatorResponse_v1(Response): - API_KEY = 10 - API_VERSION = 1 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16), - ('error_message', String('utf-8')), - ('coordinator_id', Int32), - ('host', String('utf-8')), - ('port', Int32) - ) - - -class FindCoordinatorResponse_v2(Response): - API_KEY = 10 - API_VERSION = 2 - SCHEMA = FindCoordinatorResponse_v1.SCHEMA - - -class FindCoordinatorRequest_v0(Request): - API_KEY = 10 - API_VERSION = 0 - RESPONSE_TYPE = FindCoordinatorResponse_v0 - SCHEMA = Schema( - ('consumer_group', String('utf-8')) - ) - - -class FindCoordinatorRequest_v1(Request): - API_KEY = 10 - API_VERSION = 1 - RESPONSE_TYPE = FindCoordinatorResponse_v1 - SCHEMA = Schema( - ('coordinator_key', String('utf-8')), - ('coordinator_type', Int8) # 0: consumer, 1: transaction - ) - - -class FindCoordinatorRequest_v2(Request): - API_KEY = 10 - API_VERSION = 2 - RESPONSE_TYPE = FindCoordinatorResponse_v2 - SCHEMA = FindCoordinatorRequest_v1.SCHEMA - - -FindCoordinatorRequest = [FindCoordinatorRequest_v0, FindCoordinatorRequest_v1, FindCoordinatorRequest_v2] -FindCoordinatorResponse = [FindCoordinatorResponse_v0, FindCoordinatorResponse_v1, FindCoordinatorResponse_v2] diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/group.py b/venv/lib/python3.12/site-packages/kafka/protocol/group.py index 74e19c9..bcb9655 100644 --- a/venv/lib/python3.12/site-packages/kafka/protocol/group.py +++ b/venv/lib/python3.12/site-packages/kafka/protocol/group.py @@ -5,10 +5,6 @@ from kafka.protocol.struct import Struct from kafka.protocol.types import Array, Bytes, Int16, Int32, Schema, String -DEFAULT_GENERATION_ID = -1 -UNKNOWN_MEMBER_ID = '' - - class JoinGroupResponse_v0(Response): API_KEY = 11 API_VERSION = 0 @@ -46,18 +42,6 @@ class JoinGroupResponse_v2(Response): ) -class JoinGroupResponse_v3(Response): - API_KEY = 11 - API_VERSION = 3 - SCHEMA = JoinGroupResponse_v2.SCHEMA - - -class JoinGroupResponse_v4(Response): - API_KEY = 11 - API_VERSION = 4 - SCHEMA = JoinGroupResponse_v3.SCHEMA - - class JoinGroupRequest_v0(Request): API_KEY = 11 API_VERSION = 0 @@ -71,6 +55,7 @@ class JoinGroupRequest_v0(Request): ('protocol_name', String('utf-8')), ('protocol_metadata', Bytes))) ) + UNKNOWN_MEMBER_ID = '' class JoinGroupRequest_v1(Request): @@ -87,6 +72,7 @@ class JoinGroupRequest_v1(Request): ('protocol_name', String('utf-8')), 
('protocol_metadata', Bytes))) ) + UNKNOWN_MEMBER_ID = '' class JoinGroupRequest_v2(Request): @@ -94,29 +80,14 @@ class JoinGroupRequest_v2(Request): API_VERSION = 2 RESPONSE_TYPE = JoinGroupResponse_v2 SCHEMA = JoinGroupRequest_v1.SCHEMA - - -class JoinGroupRequest_v3(Request): - API_KEY = 11 - API_VERSION = 3 - RESPONSE_TYPE = JoinGroupResponse_v3 - SCHEMA = JoinGroupRequest_v2.SCHEMA - - -class JoinGroupRequest_v4(Request): - API_KEY = 11 - API_VERSION = 4 - RESPONSE_TYPE = JoinGroupResponse_v4 - SCHEMA = JoinGroupRequest_v3.SCHEMA + UNKNOWN_MEMBER_ID = '' JoinGroupRequest = [ - JoinGroupRequest_v0, JoinGroupRequest_v1, JoinGroupRequest_v2, - JoinGroupRequest_v3, JoinGroupRequest_v4, + JoinGroupRequest_v0, JoinGroupRequest_v1, JoinGroupRequest_v2 ] JoinGroupResponse = [ - JoinGroupResponse_v0, JoinGroupResponse_v1, JoinGroupResponse_v2, - JoinGroupResponse_v3, JoinGroupResponse_v4, + JoinGroupResponse_v0, JoinGroupResponse_v1, JoinGroupResponse_v2 ] @@ -147,12 +118,6 @@ class SyncGroupResponse_v1(Response): ) -class SyncGroupResponse_v2(Response): - API_KEY = 14 - API_VERSION = 2 - SCHEMA = SyncGroupResponse_v1.SCHEMA - - class SyncGroupRequest_v0(Request): API_KEY = 14 API_VERSION = 0 @@ -174,15 +139,8 @@ class SyncGroupRequest_v1(Request): SCHEMA = SyncGroupRequest_v0.SCHEMA -class SyncGroupRequest_v2(Request): - API_KEY = 14 - API_VERSION = 2 - RESPONSE_TYPE = SyncGroupResponse_v2 - SCHEMA = SyncGroupRequest_v1.SCHEMA - - -SyncGroupRequest = [SyncGroupRequest_v0, SyncGroupRequest_v1, SyncGroupRequest_v2] -SyncGroupResponse = [SyncGroupResponse_v0, SyncGroupResponse_v1, SyncGroupResponse_v2] +SyncGroupRequest = [SyncGroupRequest_v0, SyncGroupRequest_v1] +SyncGroupResponse = [SyncGroupResponse_v0, SyncGroupResponse_v1] class MemberAssignment(Struct): @@ -212,12 +170,6 @@ class HeartbeatResponse_v1(Response): ) -class HeartbeatResponse_v2(Response): - API_KEY = 12 - API_VERSION = 2 - SCHEMA = HeartbeatResponse_v1.SCHEMA - - class HeartbeatRequest_v0(Request): API_KEY = 12 API_VERSION = 0 @@ -236,15 +188,8 @@ class HeartbeatRequest_v1(Request): SCHEMA = HeartbeatRequest_v0.SCHEMA -class HeartbeatRequest_v2(Request): - API_KEY = 12 - API_VERSION = 2 - RESPONSE_TYPE = HeartbeatResponse_v2 - SCHEMA = HeartbeatRequest_v1.SCHEMA - - -HeartbeatRequest = [HeartbeatRequest_v0, HeartbeatRequest_v1, HeartbeatRequest_v2] -HeartbeatResponse = [HeartbeatResponse_v0, HeartbeatResponse_v1, HeartbeatResponse_v2] +HeartbeatRequest = [HeartbeatRequest_v0, HeartbeatRequest_v1] +HeartbeatResponse = [HeartbeatResponse_v0, HeartbeatResponse_v1] class LeaveGroupResponse_v0(Response): @@ -264,12 +209,6 @@ class LeaveGroupResponse_v1(Response): ) -class LeaveGroupResponse_v2(Response): - API_KEY = 13 - API_VERSION = 2 - SCHEMA = LeaveGroupResponse_v1.SCHEMA - - class LeaveGroupRequest_v0(Request): API_KEY = 13 API_VERSION = 0 @@ -287,12 +226,5 @@ class LeaveGroupRequest_v1(Request): SCHEMA = LeaveGroupRequest_v0.SCHEMA -class LeaveGroupRequest_v2(Request): - API_KEY = 13 - API_VERSION = 2 - RESPONSE_TYPE = LeaveGroupResponse_v2 - SCHEMA = LeaveGroupRequest_v1.SCHEMA - - -LeaveGroupRequest = [LeaveGroupRequest_v0, LeaveGroupRequest_v1, LeaveGroupRequest_v2] -LeaveGroupResponse = [LeaveGroupResponse_v0, LeaveGroupResponse_v1, LeaveGroupResponse_v2] +LeaveGroupRequest = [LeaveGroupRequest_v0, LeaveGroupRequest_v1] +LeaveGroupResponse = [LeaveGroupResponse_v0, LeaveGroupResponse_v1] diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/init_producer_id.py 
b/venv/lib/python3.12/site-packages/kafka/protocol/init_producer_id.py deleted file mode 100644 index 8426fe0..0000000 --- a/venv/lib/python3.12/site-packages/kafka/protocol/init_producer_id.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import absolute_import - -from kafka.protocol.api import Request, Response -from kafka.protocol.types import Int16, Int32, Int64, Schema, String - - -class InitProducerIdResponse_v0(Response): - API_KEY = 22 - API_VERSION = 0 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16), - ('producer_id', Int64), - ('producer_epoch', Int16), - ) - - -class InitProducerIdResponse_v1(Response): - API_KEY = 22 - API_VERSION = 1 - SCHEMA = InitProducerIdResponse_v0.SCHEMA - - -class InitProducerIdRequest_v0(Request): - API_KEY = 22 - API_VERSION = 0 - RESPONSE_TYPE = InitProducerIdResponse_v0 - SCHEMA = Schema( - ('transactional_id', String('utf-8')), - ('transaction_timeout_ms', Int32), - ) - - -class InitProducerIdRequest_v1(Request): - API_KEY = 22 - API_VERSION = 1 - RESPONSE_TYPE = InitProducerIdResponse_v1 - SCHEMA = InitProducerIdRequest_v0.SCHEMA - - -InitProducerIdRequest = [ - InitProducerIdRequest_v0, InitProducerIdRequest_v1, -] -InitProducerIdResponse = [ - InitProducerIdResponse_v0, InitProducerIdResponse_v1, -] diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/metadata.py b/venv/lib/python3.12/site-packages/kafka/protocol/metadata.py index bb22ba9..414e5b8 100644 --- a/venv/lib/python3.12/site-packages/kafka/protocol/metadata.py +++ b/venv/lib/python3.12/site-packages/kafka/protocol/metadata.py @@ -128,42 +128,6 @@ class MetadataResponse_v5(Response): ) -class MetadataResponse_v6(Response): - """Metadata Request/Response v6 is the same as v5, - but on quota violation, brokers send out responses before throttling.""" - API_KEY = 3 - API_VERSION = 6 - SCHEMA = MetadataResponse_v5.SCHEMA - - -class MetadataResponse_v7(Response): - """v7 adds per-partition leader_epoch field""" - API_KEY = 3 - API_VERSION = 7 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('brokers', Array( - ('node_id', Int32), - ('host', String('utf-8')), - ('port', Int32), - ('rack', String('utf-8')))), - ('cluster_id', String('utf-8')), - ('controller_id', Int32), - ('topics', Array( - ('error_code', Int16), - ('topic', String('utf-8')), - ('is_internal', Boolean), - ('partitions', Array( - ('error_code', Int16), - ('partition', Int32), - ('leader', Int32), - ('leader_epoch', Int32), - ('replicas', Array(Int32)), - ('isr', Array(Int32)), - ('offline_replicas', Array(Int32)))))) - ) - - class MetadataRequest_v0(Request): API_KEY = 3 API_VERSION = 0 @@ -171,8 +135,7 @@ class MetadataRequest_v0(Request): SCHEMA = Schema( ('topics', Array(String('utf-8'))) ) - ALL_TOPICS = [] # Empty Array (len 0) for topics returns all topics - NO_TOPICS = [] # v0 does not support a 'no topics' request, so we'll just ask for ALL + ALL_TOPICS = None # Empty Array (len 0) for topics returns all topics class MetadataRequest_v1(Request): @@ -180,8 +143,8 @@ class MetadataRequest_v1(Request): API_VERSION = 1 RESPONSE_TYPE = MetadataResponse_v1 SCHEMA = MetadataRequest_v0.SCHEMA - ALL_TOPICS = None # Null Array (len -1) for topics returns all topics - NO_TOPICS = [] # Empty array (len 0) for topics returns no topics + ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS = None # Empty array (len 0) for topics returns no topics class MetadataRequest_v2(Request): @@ -189,8 +152,8 @@ class MetadataRequest_v2(Request): API_VERSION = 2 RESPONSE_TYPE 
= MetadataResponse_v2 SCHEMA = MetadataRequest_v1.SCHEMA - ALL_TOPICS = None - NO_TOPICS = [] + ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS = None # Empty array (len 0) for topics returns no topics class MetadataRequest_v3(Request): @@ -198,8 +161,8 @@ class MetadataRequest_v3(Request): API_VERSION = 3 RESPONSE_TYPE = MetadataResponse_v3 SCHEMA = MetadataRequest_v1.SCHEMA - ALL_TOPICS = None - NO_TOPICS = [] + ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS = None # Empty array (len 0) for topics returns no topics class MetadataRequest_v4(Request): @@ -210,8 +173,8 @@ class MetadataRequest_v4(Request): ('topics', Array(String('utf-8'))), ('allow_auto_topic_creation', Boolean) ) - ALL_TOPICS = None - NO_TOPICS = [] + ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS = None # Empty array (len 0) for topics returns no topics class MetadataRequest_v5(Request): @@ -223,35 +186,15 @@ class MetadataRequest_v5(Request): API_VERSION = 5 RESPONSE_TYPE = MetadataResponse_v5 SCHEMA = MetadataRequest_v4.SCHEMA - ALL_TOPICS = None - NO_TOPICS = [] - - -class MetadataRequest_v6(Request): - API_KEY = 3 - API_VERSION = 6 - RESPONSE_TYPE = MetadataResponse_v6 - SCHEMA = MetadataRequest_v5.SCHEMA - ALL_TOPICS = None - NO_TOPICS = [] - - -class MetadataRequest_v7(Request): - API_KEY = 3 - API_VERSION = 7 - RESPONSE_TYPE = MetadataResponse_v7 - SCHEMA = MetadataRequest_v6.SCHEMA - ALL_TOPICS = None - NO_TOPICS = [] + ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS = None # Empty array (len 0) for topics returns no topics MetadataRequest = [ MetadataRequest_v0, MetadataRequest_v1, MetadataRequest_v2, - MetadataRequest_v3, MetadataRequest_v4, MetadataRequest_v5, - MetadataRequest_v6, MetadataRequest_v7, + MetadataRequest_v3, MetadataRequest_v4, MetadataRequest_v5 ] MetadataResponse = [ MetadataResponse_v0, MetadataResponse_v1, MetadataResponse_v2, - MetadataResponse_v3, MetadataResponse_v4, MetadataResponse_v5, - MetadataResponse_v6, MetadataResponse_v7, + MetadataResponse_v3, MetadataResponse_v4, MetadataResponse_v5 ] diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/list_offsets.py b/venv/lib/python3.12/site-packages/kafka/protocol/offset.py similarity index 73% rename from venv/lib/python3.12/site-packages/kafka/protocol/list_offsets.py rename to venv/lib/python3.12/site-packages/kafka/protocol/offset.py index 2e36dd6..1ed382b 100644 --- a/venv/lib/python3.12/site-packages/kafka/protocol/list_offsets.py +++ b/venv/lib/python3.12/site-packages/kafka/protocol/offset.py @@ -12,7 +12,7 @@ class OffsetResetStrategy(object): NONE = 0 -class ListOffsetsResponse_v0(Response): +class OffsetResponse_v0(Response): API_KEY = 2 API_VERSION = 0 SCHEMA = Schema( @@ -24,7 +24,7 @@ class ListOffsetsResponse_v0(Response): ('offsets', Array(Int64)))))) ) -class ListOffsetsResponse_v1(Response): +class OffsetResponse_v1(Response): API_KEY = 2 API_VERSION = 1 SCHEMA = Schema( @@ -38,7 +38,7 @@ class ListOffsetsResponse_v1(Response): ) -class ListOffsetsResponse_v2(Response): +class OffsetResponse_v2(Response): API_KEY = 2 API_VERSION = 2 SCHEMA = Schema( @@ -53,16 +53,16 @@ class ListOffsetsResponse_v2(Response): ) -class ListOffsetsResponse_v3(Response): +class OffsetResponse_v3(Response): """ on quota violation, brokers send out responses before throttling """ API_KEY = 2 API_VERSION = 3 - SCHEMA = ListOffsetsResponse_v2.SCHEMA + SCHEMA = OffsetResponse_v2.SCHEMA -class 
ListOffsetsResponse_v4(Response): +class OffsetResponse_v4(Response): """ Add leader_epoch to response """ @@ -81,19 +81,19 @@ class ListOffsetsResponse_v4(Response): ) -class ListOffsetsResponse_v5(Response): +class OffsetResponse_v5(Response): """ adds a new error code, OFFSET_NOT_AVAILABLE """ API_KEY = 2 API_VERSION = 5 - SCHEMA = ListOffsetsResponse_v4.SCHEMA + SCHEMA = OffsetResponse_v4.SCHEMA -class ListOffsetsRequest_v0(Request): +class OffsetRequest_v0(Request): API_KEY = 2 API_VERSION = 0 - RESPONSE_TYPE = ListOffsetsResponse_v0 + RESPONSE_TYPE = OffsetResponse_v0 SCHEMA = Schema( ('replica_id', Int32), ('topics', Array( @@ -107,10 +107,10 @@ class ListOffsetsRequest_v0(Request): 'replica_id': -1 } -class ListOffsetsRequest_v1(Request): +class OffsetRequest_v1(Request): API_KEY = 2 API_VERSION = 1 - RESPONSE_TYPE = ListOffsetsResponse_v1 + RESPONSE_TYPE = OffsetResponse_v1 SCHEMA = Schema( ('replica_id', Int32), ('topics', Array( @@ -124,10 +124,10 @@ class ListOffsetsRequest_v1(Request): } -class ListOffsetsRequest_v2(Request): +class OffsetRequest_v2(Request): API_KEY = 2 API_VERSION = 2 - RESPONSE_TYPE = ListOffsetsResponse_v2 + RESPONSE_TYPE = OffsetResponse_v2 SCHEMA = Schema( ('replica_id', Int32), ('isolation_level', Int8), # <- added isolation_level @@ -142,23 +142,23 @@ class ListOffsetsRequest_v2(Request): } -class ListOffsetsRequest_v3(Request): +class OffsetRequest_v3(Request): API_KEY = 2 API_VERSION = 3 - RESPONSE_TYPE = ListOffsetsResponse_v3 - SCHEMA = ListOffsetsRequest_v2.SCHEMA + RESPONSE_TYPE = OffsetResponse_v3 + SCHEMA = OffsetRequest_v2.SCHEMA DEFAULTS = { 'replica_id': -1 } -class ListOffsetsRequest_v4(Request): +class OffsetRequest_v4(Request): """ Add current_leader_epoch to request """ API_KEY = 2 API_VERSION = 4 - RESPONSE_TYPE = ListOffsetsResponse_v4 + RESPONSE_TYPE = OffsetResponse_v4 SCHEMA = Schema( ('replica_id', Int32), ('isolation_level', Int8), # <- added isolation_level @@ -166,7 +166,7 @@ class ListOffsetsRequest_v4(Request): ('topic', String('utf-8')), ('partitions', Array( ('partition', Int32), - ('current_leader_epoch', Int32), + ('current_leader_epoch', Int64), ('timestamp', Int64))))) ) DEFAULTS = { @@ -174,21 +174,21 @@ class ListOffsetsRequest_v4(Request): } -class ListOffsetsRequest_v5(Request): +class OffsetRequest_v5(Request): API_KEY = 2 API_VERSION = 5 - RESPONSE_TYPE = ListOffsetsResponse_v5 - SCHEMA = ListOffsetsRequest_v4.SCHEMA + RESPONSE_TYPE = OffsetResponse_v5 + SCHEMA = OffsetRequest_v4.SCHEMA DEFAULTS = { 'replica_id': -1 } -ListOffsetsRequest = [ - ListOffsetsRequest_v0, ListOffsetsRequest_v1, ListOffsetsRequest_v2, - ListOffsetsRequest_v3, ListOffsetsRequest_v4, ListOffsetsRequest_v5, +OffsetRequest = [ + OffsetRequest_v0, OffsetRequest_v1, OffsetRequest_v2, + OffsetRequest_v3, OffsetRequest_v4, OffsetRequest_v5, ] -ListOffsetsResponse = [ - ListOffsetsResponse_v0, ListOffsetsResponse_v1, ListOffsetsResponse_v2, - ListOffsetsResponse_v3, ListOffsetsResponse_v4, ListOffsetsResponse_v5, +OffsetResponse = [ + OffsetResponse_v0, OffsetResponse_v1, OffsetResponse_v2, + OffsetResponse_v3, OffsetResponse_v4, OffsetResponse_v5, ] diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/offset_for_leader_epoch.py b/venv/lib/python3.12/site-packages/kafka/protocol/offset_for_leader_epoch.py deleted file mode 100644 index 8465588..0000000 --- a/venv/lib/python3.12/site-packages/kafka/protocol/offset_for_leader_epoch.py +++ /dev/null @@ -1,140 +0,0 @@ -from __future__ import absolute_import - -from kafka.protocol.api import 
Request, Response -from kafka.protocol.types import Array, CompactArray, CompactString, Int16, Int32, Int64, Schema, String, TaggedFields - - -class OffsetForLeaderEpochResponse_v0(Response): - API_KEY = 23 - API_VERSION = 0 - SCHEMA = Schema( - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('error_code', Int16), - ('partition', Int32), - ('end_offset', Int64)))))) - - -class OffsetForLeaderEpochResponse_v1(Response): - API_KEY = 23 - API_VERSION = 1 - SCHEMA = Schema( - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('error_code', Int16), - ('partition', Int32), - ('leader_epoch', Int32), - ('end_offset', Int64)))))) - - -class OffsetForLeaderEpochResponse_v2(Response): - API_KEY = 23 - API_VERSION = 2 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('error_code', Int16), - ('partition', Int32), - ('leader_epoch', Int32), - ('end_offset', Int64)))))) - - -class OffsetForLeaderEpochResponse_v3(Response): - API_KEY = 23 - API_VERSION = 3 - SCHEMA = OffsetForLeaderEpochResponse_v2.SCHEMA - - -class OffsetForLeaderEpochResponse_v4(Response): - API_KEY = 23 - API_VERSION = 4 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('topics', CompactArray( - ('topic', CompactString('utf-8')), - ('partitions', CompactArray( - ('error_code', Int16), - ('partition', Int32), - ('leader_epoch', Int32), - ('end_offset', Int64), - ('tags', TaggedFields))), - ('tags', TaggedFields))), - ('tags', TaggedFields)) - - -class OffsetForLeaderEpochRequest_v0(Request): - API_KEY = 23 - API_VERSION = 0 - RESPONSE_TYPE = OffsetForLeaderEpochResponse_v0 - SCHEMA = Schema( - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('leader_epoch', Int32)))))) - - -class OffsetForLeaderEpochRequest_v1(Request): - API_KEY = 23 - API_VERSION = 1 - RESPONSE_TYPE = OffsetForLeaderEpochResponse_v1 - SCHEMA = OffsetForLeaderEpochRequest_v0.SCHEMA - - -class OffsetForLeaderEpochRequest_v2(Request): - API_KEY = 23 - API_VERSION = 2 - RESPONSE_TYPE = OffsetForLeaderEpochResponse_v2 - SCHEMA = Schema( - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('current_leader_epoch', Int32), - ('leader_epoch', Int32)))))) - - -class OffsetForLeaderEpochRequest_v3(Request): - API_KEY = 23 - API_VERSION = 3 - RESPONSE_TYPE = OffsetForLeaderEpochResponse_v3 - SCHEMA = Schema( - ('replica_id', Int32), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('current_leader_epoch', Int32), - ('leader_epoch', Int32)))))) - - -class OffsetForLeaderEpochRequest_v4(Request): - API_KEY = 23 - API_VERSION = 4 - RESPONSE_TYPE = OffsetForLeaderEpochResponse_v4 - SCHEMA = Schema( - ('replica_id', Int32), - ('topics', CompactArray( - ('topic', CompactString('utf-8')), - ('partitions', CompactArray( - ('partition', Int32), - ('current_leader_epoch', Int32), - ('leader_epoch', Int32), - ('tags', TaggedFields))), - ('tags', TaggedFields))), - ('tags', TaggedFields)) - -OffsetForLeaderEpochRequest = [ - OffsetForLeaderEpochRequest_v0, OffsetForLeaderEpochRequest_v1, - OffsetForLeaderEpochRequest_v2, OffsetForLeaderEpochRequest_v3, - OffsetForLeaderEpochRequest_v4, -] -OffsetForLeaderEpochResponse = [ - OffsetForLeaderEpochResponse_v0, OffsetForLeaderEpochResponse_v1, - OffsetForLeaderEpochResponse_v2, OffsetForLeaderEpochResponse_v3, - OffsetForLeaderEpochResponse_v4, -] diff --git 
a/venv/lib/python3.12/site-packages/kafka/protocol/parser.py b/venv/lib/python3.12/site-packages/kafka/protocol/parser.py index 4bc4273..cfee046 100644 --- a/venv/lib/python3.12/site-packages/kafka/protocol/parser.py +++ b/venv/lib/python3.12/site-packages/kafka/protocol/parser.py @@ -4,9 +4,10 @@ import collections import logging import kafka.errors as Errors -from kafka.protocol.find_coordinator import FindCoordinatorResponse +from kafka.protocol.api import RequestHeader +from kafka.protocol.commit import GroupCoordinatorResponse from kafka.protocol.frame import KafkaBytes -from kafka.protocol.types import Int32, TaggedFields +from kafka.protocol.types import Int32 from kafka.version import __version__ log = logging.getLogger(__name__) @@ -58,8 +59,9 @@ class KafkaProtocol(object): log.debug('Sending request %s', request) if correlation_id is None: correlation_id = self._next_correlation_id() - - header = request.build_header(correlation_id=correlation_id, client_id=self._client_id) + header = RequestHeader(request, + correlation_id=correlation_id, + client_id=self._client_id) message = b''.join([header.encode(), request.encode()]) size = Int32.encode(len(message)) data = size + message @@ -133,17 +135,21 @@ class KafkaProtocol(object): return responses def _process_response(self, read_buffer): - if not self.in_flight_requests: - raise Errors.CorrelationIdError('No in-flight-request found for server response') - (correlation_id, request) = self.in_flight_requests.popleft() - response_type = request.RESPONSE_TYPE - response_header = response_type.parse_header(read_buffer) - recv_correlation_id = response_header.correlation_id + recv_correlation_id = Int32.decode(read_buffer) log.debug('Received correlation id: %d', recv_correlation_id) + + if not self.in_flight_requests: + raise Errors.CorrelationIdError( + 'No in-flight-request found for server response' + ' with correlation ID %d' + % (recv_correlation_id,)) + + (correlation_id, request) = self.in_flight_requests.popleft() + # 0.8.2 quirk if (recv_correlation_id == 0 and correlation_id != 0 and - response_type is FindCoordinatorResponse[0] and + request.RESPONSE_TYPE is GroupCoordinatorResponse[0] and (self._api_version == (0, 8, 2) or self._api_version is None)): log.warning('Kafka 0.8.2 quirk -- GroupCoordinatorResponse' ' Correlation ID does not match request. 
This' @@ -157,15 +163,15 @@ class KafkaProtocol(object): % (correlation_id, recv_correlation_id)) # decode response - log.debug('Processing response %s', response_type.__name__) + log.debug('Processing response %s', request.RESPONSE_TYPE.__name__) try: - response = response_type.decode(read_buffer) + response = request.RESPONSE_TYPE.decode(read_buffer) except ValueError: read_buffer.seek(0) buf = read_buffer.read() log.error('Response %d [ResponseType: %s Request: %s]:' ' Unable to decode %d-byte buffer: %r', - correlation_id, response_type, + correlation_id, request.RESPONSE_TYPE, request, len(buf), buf) raise Errors.KafkaProtocolError('Unable to decode response') diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/produce.py b/venv/lib/python3.12/site-packages/kafka/protocol/produce.py index 3076a28..9b3f6bf 100644 --- a/venv/lib/python3.12/site-packages/kafka/protocol/produce.py +++ b/venv/lib/python3.12/site-packages/kafka/protocol/produce.py @@ -47,7 +47,6 @@ class ProduceResponse_v2(Response): class ProduceResponse_v3(Response): - # Adds support for message format v2 API_KEY = 0 API_VERSION = 3 SCHEMA = ProduceResponse_v2.SCHEMA @@ -142,7 +141,7 @@ class ProduceRequest_v0(ProduceRequest): ('topic', String('utf-8')), ('partitions', Array( ('partition', Int32), - ('records', Bytes))))) + ('messages', Bytes))))) ) @@ -159,7 +158,6 @@ class ProduceRequest_v2(ProduceRequest): class ProduceRequest_v3(ProduceRequest): - # Adds support for message format v2 API_VERSION = 3 RESPONSE_TYPE = ProduceResponse_v3 SCHEMA = Schema( @@ -170,7 +168,7 @@ class ProduceRequest_v3(ProduceRequest): ('topic', String('utf-8')), ('partitions', Array( ('partition', Int32), - ('records', Bytes))))) + ('messages', Bytes))))) ) diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/sasl_authenticate.py b/venv/lib/python3.12/site-packages/kafka/protocol/sasl_authenticate.py deleted file mode 100644 index a2b9b19..0000000 --- a/venv/lib/python3.12/site-packages/kafka/protocol/sasl_authenticate.py +++ /dev/null @@ -1,42 +0,0 @@ -from __future__ import absolute_import - -from kafka.protocol.api import Request, Response -from kafka.protocol.types import Bytes, Int16, Int64, Schema, String - - -class SaslAuthenticateResponse_v0(Response): - API_KEY = 36 - API_VERSION = 0 - SCHEMA = Schema( - ('error_code', Int16), - ('error_message', String('utf-8')), - ('auth_bytes', Bytes)) - - -class SaslAuthenticateResponse_v1(Response): - API_KEY = 36 - API_VERSION = 1 - SCHEMA = Schema( - ('error_code', Int16), - ('error_message', String('utf-8')), - ('auth_bytes', Bytes), - ('session_lifetime_ms', Int64)) - - -class SaslAuthenticateRequest_v0(Request): - API_KEY = 36 - API_VERSION = 0 - RESPONSE_TYPE = SaslAuthenticateResponse_v0 - SCHEMA = Schema( - ('auth_bytes', Bytes)) - - -class SaslAuthenticateRequest_v1(Request): - API_KEY = 36 - API_VERSION = 1 - RESPONSE_TYPE = SaslAuthenticateResponse_v1 - SCHEMA = SaslAuthenticateRequest_v0.SCHEMA - - -SaslAuthenticateRequest = [SaslAuthenticateRequest_v0, SaslAuthenticateRequest_v1] -SaslAuthenticateResponse = [SaslAuthenticateResponse_v0, SaslAuthenticateResponse_v1] diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/sasl_handshake.py b/venv/lib/python3.12/site-packages/kafka/protocol/sasl_handshake.py deleted file mode 100644 index e91c856..0000000 --- a/venv/lib/python3.12/site-packages/kafka/protocol/sasl_handshake.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import absolute_import - -from kafka.protocol.api import Request, Response -from 
kafka.protocol.types import Array, Int16, Schema, String - - -class SaslHandshakeResponse_v0(Response): - API_KEY = 17 - API_VERSION = 0 - SCHEMA = Schema( - ('error_code', Int16), - ('enabled_mechanisms', Array(String('utf-8'))) - ) - - -class SaslHandshakeResponse_v1(Response): - API_KEY = 17 - API_VERSION = 1 - SCHEMA = SaslHandshakeResponse_v0.SCHEMA - - -class SaslHandshakeRequest_v0(Request): - API_KEY = 17 - API_VERSION = 0 - RESPONSE_TYPE = SaslHandshakeResponse_v0 - SCHEMA = Schema( - ('mechanism', String('utf-8')) - ) - - -class SaslHandshakeRequest_v1(Request): - API_KEY = 17 - API_VERSION = 1 - RESPONSE_TYPE = SaslHandshakeResponse_v1 - SCHEMA = SaslHandshakeRequest_v0.SCHEMA - - -SaslHandshakeRequest = [SaslHandshakeRequest_v0, SaslHandshakeRequest_v1] -SaslHandshakeResponse = [SaslHandshakeResponse_v0, SaslHandshakeResponse_v1] diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/txn_offset_commit.py b/venv/lib/python3.12/site-packages/kafka/protocol/txn_offset_commit.py deleted file mode 100644 index df1b1bd..0000000 --- a/venv/lib/python3.12/site-packages/kafka/protocol/txn_offset_commit.py +++ /dev/null @@ -1,78 +0,0 @@ -from __future__ import absolute_import - -from kafka.protocol.api import Request, Response -from kafka.protocol.types import Array, Int16, Int32, Int64, Schema, String - - -class TxnOffsetCommitResponse_v0(Response): - API_KEY = 28 - API_VERSION = 0 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16)))))) - - -class TxnOffsetCommitResponse_v1(Response): - API_KEY = 28 - API_VERSION = 1 - SCHEMA = TxnOffsetCommitResponse_v0.SCHEMA - - -class TxnOffsetCommitResponse_v2(Response): - API_KEY = 28 - API_VERSION = 2 - SCHEMA = TxnOffsetCommitResponse_v1.SCHEMA - - -class TxnOffsetCommitRequest_v0(Request): - API_KEY = 28 - API_VERSION = 0 - RESPONSE_TYPE = TxnOffsetCommitResponse_v0 - SCHEMA = Schema( - ('transactional_id', String('utf-8')), - ('group_id', String('utf-8')), - ('producer_id', Int64), - ('producer_epoch', Int16), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('offset', Int64), - ('metadata', String('utf-8'))))))) - - -class TxnOffsetCommitRequest_v1(Request): - API_KEY = 28 - API_VERSION = 1 - RESPONSE_TYPE = TxnOffsetCommitResponse_v1 - SCHEMA = TxnOffsetCommitRequest_v0.SCHEMA - - -class TxnOffsetCommitRequest_v2(Request): - API_KEY = 28 - API_VERSION = 2 - RESPONSE_TYPE = TxnOffsetCommitResponse_v2 - SCHEMA = Schema( - ('transactional_id', String('utf-8')), - ('group_id', String('utf-8')), - ('producer_id', Int64), - ('producer_epoch', Int16), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('offset', Int64), - ('leader_epoch', Int32), - ('metadata', String('utf-8'))))))) - - -TxnOffsetCommitRequest = [ - TxnOffsetCommitRequest_v0, TxnOffsetCommitRequest_v1, TxnOffsetCommitRequest_v2, -] -TxnOffsetCommitResponse = [ - TxnOffsetCommitResponse_v0, TxnOffsetCommitResponse_v1, TxnOffsetCommitResponse_v2, -] diff --git a/venv/lib/python3.12/site-packages/kafka/protocol/types.py b/venv/lib/python3.12/site-packages/kafka/protocol/types.py index 0e3685d..d508b26 100644 --- a/venv/lib/python3.12/site-packages/kafka/protocol/types.py +++ b/venv/lib/python3.12/site-packages/kafka/protocol/types.py @@ -77,19 +77,6 @@ class Int64(AbstractType): return _unpack(cls._unpack, data.read(8)) -class Float64(AbstractType): - _pack = 
struct.Struct('>d').pack - _unpack = struct.Struct('>d').unpack - - @classmethod - def encode(cls, value): - return _pack(cls._pack, value) - - @classmethod - def decode(cls, data): - return _unpack(cls._unpack, data.read(8)) - - class String(AbstractType): def __init__(self, encoding='utf-8'): self.encoding = encoding @@ -194,10 +181,9 @@ class Array(AbstractType): def encode(self, items): if items is None: return Int32.encode(-1) - encoded_items = [self.array_of.encode(item) for item in items] return b''.join( - [Int32.encode(len(encoded_items))] + - encoded_items + [Int32.encode(len(items))] + + [self.array_of.encode(item) for item in items] ) def decode(self, data): @@ -210,156 +196,3 @@ class Array(AbstractType): if list_of_items is None: return 'NULL' return '[' + ', '.join([self.array_of.repr(item) for item in list_of_items]) + ']' - - -class UnsignedVarInt32(AbstractType): - @classmethod - def decode(cls, data): - value, i = 0, 0 - while True: - b, = struct.unpack('B', data.read(1)) - if not (b & 0x80): - break - value |= (b & 0x7f) << i - i += 7 - if i > 28: - raise ValueError('Invalid value {}'.format(value)) - value |= b << i - return value - - @classmethod - def encode(cls, value): - value &= 0xffffffff - ret = b'' - while (value & 0xffffff80) != 0: - b = (value & 0x7f) | 0x80 - ret += struct.pack('B', b) - value >>= 7 - ret += struct.pack('B', value) - return ret - - -class VarInt32(AbstractType): - @classmethod - def decode(cls, data): - value = UnsignedVarInt32.decode(data) - return (value >> 1) ^ -(value & 1) - - @classmethod - def encode(cls, value): - # bring it in line with the java binary repr - value &= 0xffffffff - return UnsignedVarInt32.encode((value << 1) ^ (value >> 31)) - - -class VarInt64(AbstractType): - @classmethod - def decode(cls, data): - value, i = 0, 0 - while True: - b = data.read(1) - if not (b & 0x80): - break - value |= (b & 0x7f) << i - i += 7 - if i > 63: - raise ValueError('Invalid value {}'.format(value)) - value |= b << i - return (value >> 1) ^ -(value & 1) - - @classmethod - def encode(cls, value): - # bring it in line with the java binary repr - value &= 0xffffffffffffffff - v = (value << 1) ^ (value >> 63) - ret = b'' - while (v & 0xffffffffffffff80) != 0: - b = (value & 0x7f) | 0x80 - ret += struct.pack('B', b) - v >>= 7 - ret += struct.pack('B', v) - return ret - - -class CompactString(String): - def decode(self, data): - length = UnsignedVarInt32.decode(data) - 1 - if length < 0: - return None - value = data.read(length) - if len(value) != length: - raise ValueError('Buffer underrun decoding string') - return value.decode(self.encoding) - - def encode(self, value): - if value is None: - return UnsignedVarInt32.encode(0) - value = str(value).encode(self.encoding) - return UnsignedVarInt32.encode(len(value) + 1) + value - - -class TaggedFields(AbstractType): - @classmethod - def decode(cls, data): - num_fields = UnsignedVarInt32.decode(data) - ret = {} - if not num_fields: - return ret - prev_tag = -1 - for i in range(num_fields): - tag = UnsignedVarInt32.decode(data) - if tag <= prev_tag: - raise ValueError('Invalid or out-of-order tag {}'.format(tag)) - prev_tag = tag - size = UnsignedVarInt32.decode(data) - val = data.read(size) - ret[tag] = val - return ret - - @classmethod - def encode(cls, value): - ret = UnsignedVarInt32.encode(len(value)) - for k, v in value.items(): - # do we allow for other data types ?? 
It could get complicated really fast - assert isinstance(v, bytes), 'Value {} is not a byte array'.format(v) - assert isinstance(k, int) and k > 0, 'Key {} is not a positive integer'.format(k) - ret += UnsignedVarInt32.encode(k) - ret += v - return ret - - -class CompactBytes(AbstractType): - @classmethod - def decode(cls, data): - length = UnsignedVarInt32.decode(data) - 1 - if length < 0: - return None - value = data.read(length) - if len(value) != length: - raise ValueError('Buffer underrun decoding Bytes') - return value - - @classmethod - def encode(cls, value): - if value is None: - return UnsignedVarInt32.encode(0) - else: - return UnsignedVarInt32.encode(len(value) + 1) + value - - -class CompactArray(Array): - - def encode(self, items): - if items is None: - return UnsignedVarInt32.encode(0) - return b''.join( - [UnsignedVarInt32.encode(len(items) + 1)] + - [self.array_of.encode(item) for item in items] - ) - - def decode(self, data): - length = UnsignedVarInt32.decode(data) - 1 - if length == -1: - return None - return [self.array_of.decode(data) for _ in range(length)] - diff --git a/venv/lib/python3.12/site-packages/kafka/record/_crc32c.py b/venv/lib/python3.12/site-packages/kafka/record/_crc32c.py index 9b51ad8..ecff48f 100644 --- a/venv/lib/python3.12/site-packages/kafka/record/_crc32c.py +++ b/venv/lib/python3.12/site-packages/kafka/record/_crc32c.py @@ -105,7 +105,7 @@ def crc_update(crc, data): Returns: 32-bit updated CRC-32C as long. """ - if not isinstance(data, array.array) or data.itemsize != 1: + if type(data) != array.array or data.itemsize != 1: buf = array.array("B", data) else: buf = data diff --git a/venv/lib/python3.12/site-packages/kafka/record/abc.py b/venv/lib/python3.12/site-packages/kafka/record/abc.py index c78f0da..d5c172a 100644 --- a/venv/lib/python3.12/site-packages/kafka/record/abc.py +++ b/venv/lib/python3.12/site-packages/kafka/record/abc.py @@ -1,19 +1,11 @@ from __future__ import absolute_import - import abc -from kafka.vendor.six import add_metaclass - -@add_metaclass(abc.ABCMeta) class ABCRecord(object): + __metaclass__ = abc.ABCMeta __slots__ = () - @abc.abstractproperty - def size_in_bytes(self): - """ Number of total bytes in record - """ - @abc.abstractproperty def offset(self): """ Absolute offset of record @@ -45,11 +37,6 @@ class ABCRecord(object): be the checksum for v0 and v1 and None for v2 and above. """ - @abc.abstractmethod - def validate_crc(self): - """ Return True if v0/v1 record matches checksum. noop/True for v2 records - """ - @abc.abstractproperty def headers(self): """ If supported by version list of key-value tuples, or empty list if @@ -57,8 +44,8 @@ class ABCRecord(object): """ -@add_metaclass(abc.ABCMeta) class ABCRecordBatchBuilder(object): + __metaclass__ = abc.ABCMeta __slots__ = () @abc.abstractmethod @@ -97,11 +84,11 @@ class ABCRecordBatchBuilder(object): """ -@add_metaclass(abc.ABCMeta) class ABCRecordBatch(object): - """ For v2 encapsulates a RecordBatch, for v0/v1 a single (maybe + """ For v2 incapsulates a RecordBatch, for v0/v1 a single (maybe compressed) message. """ + __metaclass__ = abc.ABCMeta __slots__ = () @abc.abstractmethod @@ -110,24 +97,9 @@ class ABCRecordBatch(object): if needed. """ - @abc.abstractproperty - def base_offset(self): - """ Return base offset for batch - """ - @abc.abstractproperty - def size_in_bytes(self): - """ Return size of batch in bytes (includes header overhead) - """ - - @abc.abstractproperty - def magic(self): - """ Return magic value (0, 1, 2) for batch. 
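
The types.py hunk above strips out the varint helpers (UnsignedVarInt32, VarInt32, CompactString, TaggedFields, CompactArray) that newer "flexible" protocol versions depend on. For reference, a minimal standalone sketch of the zig-zag varint scheme those classes implement; the helper names below are illustrative, not kafka-python API:

import struct

def encode_unsigned_varint(value):
    out = b''
    while value & ~0x7f:
        out += struct.pack('B', (value & 0x7f) | 0x80)
        value >>= 7
    return out + struct.pack('B', value)

def decode_unsigned_varint(buf, pos=0):
    shift, result = 0, 0
    while True:
        b = buf[pos]
        pos += 1
        result |= (b & 0x7f) << shift
        if not b & 0x80:
            return result, pos
        shift += 7

def encode_varint32(n):
    # Zig-zag maps signed ints to unsigned so small negatives stay short:
    # 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...
    return encode_unsigned_varint((n << 1) ^ (n >> 31))

def decode_varint32(buf, pos=0):
    z, pos = decode_unsigned_varint(buf, pos)
    return (z >> 1) ^ -(z & 1), pos

assert encode_varint32(-1) == b'\x01' and encode_varint32(1) == b'\x02'
assert decode_varint32(encode_varint32(-150))[0] == -150
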
- """ - - -@add_metaclass(abc.ABCMeta) class ABCRecords(object): + __metaclass__ = abc.ABCMeta __slots__ = () @abc.abstractmethod diff --git a/venv/lib/python3.12/site-packages/kafka/record/default_records.py b/venv/lib/python3.12/site-packages/kafka/record/default_records.py index a3b9cd5..a098c42 100644 --- a/venv/lib/python3.12/site-packages/kafka/record/default_records.py +++ b/venv/lib/python3.12/site-packages/kafka/record/default_records.py @@ -60,7 +60,7 @@ from kafka.record.abc import ABCRecord, ABCRecordBatch, ABCRecordBatchBuilder from kafka.record.util import ( decode_varint, encode_varint, calc_crc32c, size_of_varint ) -from kafka.errors import CorruptRecordError, UnsupportedCodecError +from kafka.errors import CorruptRecordException, UnsupportedCodecError from kafka.codec import ( gzip_encode, snappy_encode, lz4_encode, zstd_encode, gzip_decode, snappy_decode, lz4_decode, zstd_decode @@ -104,9 +104,6 @@ class DefaultRecordBase(object): LOG_APPEND_TIME = 1 CREATE_TIME = 0 - NO_PRODUCER_ID = -1 - NO_SEQUENCE = -1 - MAX_INT = 2147483647 def _assert_has_codec(self, compression_type): if compression_type == self.CODEC_GZIP: @@ -117,8 +114,6 @@ class DefaultRecordBase(object): checker, name = codecs.has_lz4, "lz4" elif compression_type == self.CODEC_ZSTD: checker, name = codecs.has_zstd, "zstd" - else: - raise UnsupportedCodecError("Unrecognized compression type: %s" % (compression_type,)) if not checker(): raise UnsupportedCodecError( "Libraries for {} compression codec not found".format(name)) @@ -141,14 +136,6 @@ class DefaultRecordBatch(DefaultRecordBase, ABCRecordBatch): def base_offset(self): return self._header_data[0] - @property - def size_in_bytes(self): - return self._header_data[1] + self.AFTER_LEN_OFFSET - - @property - def leader_epoch(self): - return self._header_data[2] - @property def magic(self): return self._header_data[3] @@ -165,14 +152,6 @@ class DefaultRecordBatch(DefaultRecordBase, ABCRecordBatch): def last_offset_delta(self): return self._header_data[6] - @property - def last_offset(self): - return self.base_offset + self.last_offset_delta - - @property - def next_offset(self): - return self.last_offset + 1 - @property def compression_type(self): return self.attributes & self.CODEC_MASK @@ -197,40 +176,6 @@ class DefaultRecordBatch(DefaultRecordBase, ABCRecordBatch): def max_timestamp(self): return self._header_data[8] - @property - def producer_id(self): - return self._header_data[9] - - def has_producer_id(self): - return self.producer_id > self.NO_PRODUCER_ID - - @property - def producer_epoch(self): - return self._header_data[10] - - @property - def base_sequence(self): - return self._header_data[11] - - @property - def has_sequence(self): - return self._header_data[11] != -1 # NO_SEQUENCE - - @property - def last_sequence(self): - if self.base_sequence == self.NO_SEQUENCE: - return self.NO_SEQUENCE - return self._increment_sequence(self.base_sequence, self.last_offset_delta) - - def _increment_sequence(self, base, increment): - if base > (self.MAX_INT - increment): - return increment - (self.MAX_INT - base) - 1 - return base + increment - - @property - def records_count(self): - return self._header_data[12] - def _maybe_uncompress(self): if not self._decompressed: compression_type = self.compression_type @@ -294,14 +239,14 @@ class DefaultRecordBatch(DefaultRecordBase, ABCRecordBatch): header_count, pos = decode_varint(buffer, pos) if header_count < 0: - raise CorruptRecordError("Found invalid number of record " + raise CorruptRecordException("Found 
invalid number of record " "headers {}".format(header_count)) headers = [] while header_count: # Header key is of type String, that can't be None h_key_len, pos = decode_varint(buffer, pos) if h_key_len < 0: - raise CorruptRecordError( + raise CorruptRecordException( "Invalid negative header key size {}".format(h_key_len)) h_key = buffer[pos: pos + h_key_len].decode("utf-8") pos += h_key_len @@ -319,17 +264,13 @@ class DefaultRecordBatch(DefaultRecordBase, ABCRecordBatch): # validate whether we have read all header bytes in the current record if pos - start_pos != length: - raise CorruptRecordError( + raise CorruptRecordException( "Invalid record size: expected to read {} bytes in record " "payload, but instead read {}".format(length, pos - start_pos)) self._pos = pos - if self.is_control_batch: - return ControlRecord( - length, offset, timestamp, self.timestamp_type, key, value, headers) - else: - return DefaultRecord( - length, offset, timestamp, self.timestamp_type, key, value, headers) + return DefaultRecord( + offset, timestamp, self.timestamp_type, key, value, headers) def __iter__(self): self._maybe_uncompress() @@ -338,14 +279,14 @@ class DefaultRecordBatch(DefaultRecordBase, ABCRecordBatch): def __next__(self): if self._next_record_index >= self._num_records: if self._pos != len(self._buffer): - raise CorruptRecordError( + raise CorruptRecordException( "{} unconsumed bytes after all records consumed".format( len(self._buffer) - self._pos)) raise StopIteration try: msg = self._read_msg() except (ValueError, IndexError) as err: - raise CorruptRecordError( + raise CorruptRecordException( "Found invalid record structure: {!r}".format(err)) else: self._next_record_index += 1 @@ -362,25 +303,13 @@ class DefaultRecordBatch(DefaultRecordBase, ABCRecordBatch): verify_crc = calc_crc32c(data_view.tobytes()) return crc == verify_crc - def __str__(self): - return ( - "DefaultRecordBatch(magic={}, base_offset={}, last_offset_delta={}," - " first_timestamp={}, max_timestamp={}," - " is_transactional={}, producer_id={}, producer_epoch={}, base_sequence={}," - " records_count={})".format( - self.magic, self.base_offset, self.last_offset_delta, - self.first_timestamp, self.max_timestamp, - self.is_transactional, self.producer_id, self.producer_epoch, self.base_sequence, - self.records_count)) - class DefaultRecord(ABCRecord): - __slots__ = ("_size_in_bytes", "_offset", "_timestamp", "_timestamp_type", "_key", "_value", + __slots__ = ("_offset", "_timestamp", "_timestamp_type", "_key", "_value", "_headers") - def __init__(self, size_in_bytes, offset, timestamp, timestamp_type, key, value, headers): - self._size_in_bytes = size_in_bytes + def __init__(self, offset, timestamp, timestamp_type, key, value, headers): self._offset = offset self._timestamp = timestamp self._timestamp_type = timestamp_type @@ -388,10 +317,6 @@ class DefaultRecord(ABCRecord): self._value = value self._headers = headers - @property - def size_in_bytes(self): - return self._size_in_bytes - @property def offset(self): return self._offset @@ -428,9 +353,6 @@ class DefaultRecord(ABCRecord): def checksum(self): return None - def validate_crc(self): - return True - def __repr__(self): return ( "DefaultRecord(offset={!r}, timestamp={!r}, timestamp_type={!r}," @@ -440,45 +362,6 @@ class DefaultRecord(ABCRecord): ) -class ControlRecord(DefaultRecord): - __slots__ = ("_size_in_bytes", "_offset", "_timestamp", "_timestamp_type", "_key", "_value", - "_headers", "_version", "_type") - - KEY_STRUCT = struct.Struct( - ">h" # Current Version 
=> Int16 - "h" # Type => Int16 (0 indicates an abort marker, 1 indicates a commit) - ) - - def __init__(self, size_in_bytes, offset, timestamp, timestamp_type, key, value, headers): - super(ControlRecord, self).__init__(size_in_bytes, offset, timestamp, timestamp_type, key, value, headers) - (self._version, self._type) = self.KEY_STRUCT.unpack(self._key) - - # see https://kafka.apache.org/documentation/#controlbatch - @property - def version(self): - return self._version - - @property - def type(self): - return self._type - - @property - def abort(self): - return self._type == 0 - - @property - def commit(self): - return self._type == 1 - - def __repr__(self): - return ( - "ControlRecord(offset={!r}, timestamp={!r}, timestamp_type={!r}," - " version={!r}, type={!r} <{!s}>)".format( - self._offset, self._timestamp, self._timestamp_type, - self._version, self._type, "abort" if self.abort else "commit") - ) - - class DefaultRecordBatchBuilder(DefaultRecordBase, ABCRecordBatchBuilder): # excluding key, value and headers: @@ -510,23 +393,6 @@ class DefaultRecordBatchBuilder(DefaultRecordBase, ABCRecordBatchBuilder): self._buffer = bytearray(self.HEADER_STRUCT.size) - def set_producer_state(self, producer_id, producer_epoch, base_sequence, is_transactional): - assert not is_transactional or producer_id != -1, "Cannot write transactional messages without a valid producer ID" - assert producer_id == -1 or producer_epoch != -1, "Invalid negative producer epoch" - assert producer_id == -1 or base_sequence != -1, "Invalid negative sequence number" - self._producer_id = producer_id - self._producer_epoch = producer_epoch - self._base_sequence = base_sequence - self._is_transactional = is_transactional - - @property - def producer_id(self): - return self._producer_id - - @property - def producer_epoch(self): - return self._producer_epoch - def _get_attributes(self, include_compression_type=True): attrs = 0 if include_compression_type: @@ -635,8 +501,8 @@ class DefaultRecordBatchBuilder(DefaultRecordBase, ABCRecordBatchBuilder): 0, # CRC will be set below, as we need a filled buffer for it self._get_attributes(use_compression_type), self._last_offset, - self._first_timestamp or 0, - self._max_timestamp or 0, + self._first_timestamp, + self._max_timestamp, self._producer_id, self._producer_epoch, self._base_sequence, @@ -681,15 +547,14 @@ class DefaultRecordBatchBuilder(DefaultRecordBase, ABCRecordBatchBuilder): """ return len(self._buffer) - @classmethod - def header_size_in_bytes(self): - return self.HEADER_STRUCT.size - - @classmethod - def size_in_bytes(self, offset_delta, timestamp_delta, key, value, headers): + def size_in_bytes(self, offset, timestamp, key, value, headers): + if self._first_timestamp is not None: + timestamp_delta = timestamp - self._first_timestamp + else: + timestamp_delta = 0 size_of_body = ( 1 + # Attrs - size_of_varint(offset_delta) + + size_of_varint(offset) + size_of_varint(timestamp_delta) + self.size_of(key, value, headers) ) @@ -732,17 +597,6 @@ class DefaultRecordBatchBuilder(DefaultRecordBase, ABCRecordBatchBuilder): cls.size_of(key, value, headers) ) - def __str__(self): - return ( - "DefaultRecordBatchBuilder(magic={}, base_offset={}, last_offset_delta={}," - " first_timestamp={}, max_timestamp={}," - " is_transactional={}, producer_id={}, producer_epoch={}, base_sequence={}," - " records_count={})".format( - self._magic, 0, self._last_offset, - self._first_timestamp or 0, self._max_timestamp or 0, - self._is_transactional, self._producer_id, self._producer_epoch, 
self._base_sequence, - self._num_records)) - class DefaultRecordMetadata(object): diff --git a/venv/lib/python3.12/site-packages/kafka/record/legacy_records.py b/venv/lib/python3.12/site-packages/kafka/record/legacy_records.py index f085978..e2ee549 100644 --- a/venv/lib/python3.12/site-packages/kafka/record/legacy_records.py +++ b/venv/lib/python3.12/site-packages/kafka/record/legacy_records.py @@ -52,7 +52,7 @@ from kafka.codec import ( gzip_decode, snappy_decode, lz4_decode, lz4_decode_old_kafka, ) import kafka.codec as codecs -from kafka.errors import CorruptRecordError, UnsupportedCodecError +from kafka.errors import CorruptRecordException, UnsupportedCodecError class LegacyRecordBase(object): @@ -129,7 +129,7 @@ class LegacyRecordBase(object): class LegacyRecordBatch(ABCRecordBatch, LegacyRecordBase): - __slots__ = ("_buffer", "_magic", "_offset", "_length", "_crc", "_timestamp", + __slots__ = ("_buffer", "_magic", "_offset", "_crc", "_timestamp", "_attributes", "_decompressed") def __init__(self, buffer, magic): @@ -141,20 +141,11 @@ class LegacyRecordBatch(ABCRecordBatch, LegacyRecordBase): assert magic == magic_ self._offset = offset - self._length = length self._crc = crc self._timestamp = timestamp self._attributes = attrs self._decompressed = False - @property - def base_offset(self): - return self._offset - - @property - def size_in_bytes(self): - return self._length + self.LOG_OVERHEAD - @property def timestamp_type(self): """0 for CreateTime; 1 for LogAppendTime; None if unsupported. @@ -173,10 +164,6 @@ class LegacyRecordBatch(ABCRecordBatch, LegacyRecordBase): def compression_type(self): return self._attributes & self.CODEC_MASK - @property - def magic(self): - return self._magic - def validate_crc(self): crc = calc_crc32(self._buffer[self.MAGIC_OFFSET:]) return self._crc == crc @@ -191,7 +178,7 @@ class LegacyRecordBatch(ABCRecordBatch, LegacyRecordBase): value_size = struct.unpack_from(">i", self._buffer, pos)[0] pos += self.VALUE_LENGTH if value_size == -1: - raise CorruptRecordError("Value of compressed message is None") + raise CorruptRecordException("Value of compressed message is None") else: data = self._buffer[pos:pos + value_size] @@ -245,9 +232,6 @@ class LegacyRecordBatch(ABCRecordBatch, LegacyRecordBase): value = self._buffer[pos:pos + value_size].tobytes() return key, value - def _crc_bytes(self, msg_pos, length): - return self._buffer[msg_pos + self.MAGIC_OFFSET:msg_pos + self.LOG_OVERHEAD + length] - def __iter__(self): if self._magic == 1: key_offset = self.KEY_OFFSET_V1 @@ -271,7 +255,7 @@ class LegacyRecordBatch(ABCRecordBatch, LegacyRecordBase): absolute_base_offset = -1 for header, msg_pos in headers: - offset, length, crc, _, attrs, timestamp = header + offset, _, crc, _, attrs, timestamp = header # There should only ever be a single layer of compression assert not attrs & self.CODEC_MASK, ( 'MessageSet at offset %d appears double-compressed. 
This ' @@ -279,7 +263,7 @@ class LegacyRecordBatch(ABCRecordBatch, LegacyRecordBase): # When magic value is greater than 0, the timestamp # of a compressed message depends on the - # timestamp type of the wrapper message: + # typestamp type of the wrapper message: if timestamp_type == self.LOG_APPEND_TIME: timestamp = self._timestamp @@ -287,36 +271,28 @@ class LegacyRecordBatch(ABCRecordBatch, LegacyRecordBase): offset += absolute_base_offset key, value = self._read_key_value(msg_pos + key_offset) - crc_bytes = self._crc_bytes(msg_pos, length) yield LegacyRecord( - self._magic, offset, timestamp, timestamp_type, - key, value, crc, crc_bytes) + offset, timestamp, timestamp_type, + key, value, crc) else: key, value = self._read_key_value(key_offset) - crc_bytes = self._crc_bytes(0, len(self._buffer) - self.LOG_OVERHEAD) yield LegacyRecord( - self._magic, self._offset, self._timestamp, timestamp_type, - key, value, self._crc, crc_bytes) + self._offset, self._timestamp, timestamp_type, + key, value, self._crc) class LegacyRecord(ABCRecord): - __slots__ = ("_magic", "_offset", "_timestamp", "_timestamp_type", "_key", "_value", - "_crc", "_crc_bytes") + __slots__ = ("_offset", "_timestamp", "_timestamp_type", "_key", "_value", + "_crc") - def __init__(self, magic, offset, timestamp, timestamp_type, key, value, crc, crc_bytes): - self._magic = magic + def __init__(self, offset, timestamp, timestamp_type, key, value, crc): self._offset = offset self._timestamp = timestamp self._timestamp_type = timestamp_type self._key = key self._value = value self._crc = crc - self._crc_bytes = crc_bytes - - @property - def magic(self): - return self._magic @property def offset(self): @@ -354,19 +330,11 @@ class LegacyRecord(ABCRecord): def checksum(self): return self._crc - def validate_crc(self): - crc = calc_crc32(self._crc_bytes) - return self._crc == crc - - @property - def size_in_bytes(self): - return LegacyRecordBatchBuilder.estimate_size_in_bytes(self._magic, None, self._key, self._value) - def __repr__(self): return ( - "LegacyRecord(magic={!r} offset={!r}, timestamp={!r}, timestamp_type={!r}," + "LegacyRecord(offset={!r}, timestamp={!r}, timestamp_type={!r}," " key={!r}, value={!r}, crc={!r})".format( - self._magic, self._offset, self._timestamp, self._timestamp_type, + self._offset, self._timestamp, self._timestamp_type, self._key, self._value, self._crc) ) diff --git a/venv/lib/python3.12/site-packages/kafka/record/memory_records.py b/venv/lib/python3.12/site-packages/kafka/record/memory_records.py index 9df7330..fc2ef2d 100644 --- a/venv/lib/python3.12/site-packages/kafka/record/memory_records.py +++ b/venv/lib/python3.12/site-packages/kafka/record/memory_records.py @@ -22,7 +22,7 @@ from __future__ import division import struct -from kafka.errors import CorruptRecordError, IllegalStateError, UnsupportedVersionError +from kafka.errors import CorruptRecordException from kafka.record.abc import ABCRecords from kafka.record.legacy_records import LegacyRecordBatch, LegacyRecordBatchBuilder from kafka.record.default_records import DefaultRecordBatch, DefaultRecordBatchBuilder @@ -99,7 +99,7 @@ class MemoryRecords(ABCRecords): if next_slice is None: return None if len(next_slice) < _min_slice: - raise CorruptRecordError( + raise CorruptRecordException( "Record size is less than the minimum record overhead " "({})".format(_min_slice - self.LOG_OVERHEAD)) self._cache_next() @@ -109,56 +109,31 @@ class MemoryRecords(ABCRecords): else: return DefaultRecordBatch(next_slice) - def __iter__(self): - return 
self - - def __next__(self): - if not self.has_next(): - raise StopIteration - return self.next_batch() - - next = __next__ - class MemoryRecordsBuilder(object): __slots__ = ("_builder", "_batch_size", "_buffer", "_next_offset", "_closed", - "_magic", "_bytes_written", "_producer_id", "_producer_epoch") + "_bytes_written") - def __init__(self, magic, compression_type, batch_size, offset=0, - transactional=False, producer_id=-1, producer_epoch=-1, base_sequence=-1): + def __init__(self, magic, compression_type, batch_size): assert magic in [0, 1, 2], "Not supported magic" assert compression_type in [0, 1, 2, 3, 4], "Not valid compression type" if magic >= 2: - assert not transactional or producer_id != -1, "Cannot write transactional messages without a valid producer ID" - assert producer_id == -1 or producer_epoch != -1, "Invalid negative producer epoch" - assert producer_id == -1 or base_sequence != -1, "Invalid negative sequence number used" - self._builder = DefaultRecordBatchBuilder( magic=magic, compression_type=compression_type, - is_transactional=transactional, producer_id=producer_id, - producer_epoch=producer_epoch, base_sequence=base_sequence, - batch_size=batch_size) - self._producer_id = producer_id - self._producer_epoch = producer_epoch + is_transactional=False, producer_id=-1, producer_epoch=-1, + base_sequence=-1, batch_size=batch_size) else: - assert not transactional and producer_id == -1, "Idempotent messages are not supported for magic %s" % (magic,) self._builder = LegacyRecordBatchBuilder( magic=magic, compression_type=compression_type, batch_size=batch_size) - self._producer_id = None self._batch_size = batch_size self._buffer = None - self._next_offset = offset + self._next_offset = 0 self._closed = False - self._magic = magic self._bytes_written = 0 - def skip(self, offsets_to_skip): - # Exposed for testing compacted records - self._next_offset += offsets_to_skip - def append(self, timestamp, key, value, headers=[]): """ Append a message to the buffer. @@ -176,30 +151,6 @@ class MemoryRecordsBuilder(object): self._next_offset += 1 return metadata - def set_producer_state(self, producer_id, producer_epoch, base_sequence, is_transactional): - if self._magic < 2: - raise UnsupportedVersionError('Producer State requires Message format v2+') - elif self._closed: - # Sequence numbers are assigned when the batch is closed while the accumulator is being drained. - # If the resulting ProduceRequest to the partition leader failed for a retriable error, the batch will - # be re queued. In this case, we should not attempt to set the state again, since changing the pid and sequence - # once a batch has been sent to the broker risks introducing duplicates. - raise IllegalStateError("Trying to set producer state of an already closed batch. 
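
With __iter__/__next__ removed from MemoryRecords and the producer-state plumbing dropped from MemoryRecordsBuilder, the older code in this hunk expects callers to drain batches by hand. A hedged usage sketch, assuming the has_next()/next_batch() methods shown above are available; fetched_bytes is a placeholder for a raw record set returned by a fetch:

from kafka.record.memory_records import MemoryRecords

def iter_records(fetched_bytes):
    records = MemoryRecords(fetched_bytes)
    while records.has_next():
        batch = records.next_batch()   # LegacyRecordBatch or DefaultRecordBatch
        for record in batch:           # the batch objects themselves stay iterable
            yield record.offset, record.key, record.value
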
This indicates a bug on the client.") - self._builder.set_producer_state(producer_id, producer_epoch, base_sequence, is_transactional) - self._producer_id = producer_id - - @property - def producer_id(self): - return self._producer_id - - @property - def producer_epoch(self): - return self._producer_epoch - - def records(self): - assert self._closed - return MemoryRecords(self._buffer) - def close(self): # This method may be called multiple times on the same batch # i.e., on retries @@ -209,9 +160,6 @@ class MemoryRecordsBuilder(object): if not self._closed: self._bytes_written = self._builder.size() self._buffer = bytes(self._builder.build()) - if self._magic == 2: - self._producer_id = self._builder.producer_id - self._producer_epoch = self._builder.producer_epoch self._builder = None self._closed = True diff --git a/venv/lib/python3.12/site-packages/kafka/sasl/__init__.py b/venv/lib/python3.12/site-packages/kafka/sasl/__init__.py deleted file mode 100644 index 90f05e7..0000000 --- a/venv/lib/python3.12/site-packages/kafka/sasl/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -from __future__ import absolute_import - -import platform - -from kafka.sasl.gssapi import SaslMechanismGSSAPI -from kafka.sasl.msk import SaslMechanismAwsMskIam -from kafka.sasl.oauth import SaslMechanismOAuth -from kafka.sasl.plain import SaslMechanismPlain -from kafka.sasl.scram import SaslMechanismScram -from kafka.sasl.sspi import SaslMechanismSSPI - - -SASL_MECHANISMS = {} - - -def register_sasl_mechanism(name, klass, overwrite=False): - if not overwrite and name in SASL_MECHANISMS: - raise ValueError('Sasl mechanism %s already defined!' % name) - SASL_MECHANISMS[name] = klass - - -def get_sasl_mechanism(name): - return SASL_MECHANISMS[name] - - -register_sasl_mechanism('AWS_MSK_IAM', SaslMechanismAwsMskIam) -if platform.system() == 'Windows': - register_sasl_mechanism('GSSAPI', SaslMechanismSSPI) -else: - register_sasl_mechanism('GSSAPI', SaslMechanismGSSAPI) -register_sasl_mechanism('OAUTHBEARER', SaslMechanismOAuth) -register_sasl_mechanism('PLAIN', SaslMechanismPlain) -register_sasl_mechanism('SCRAM-SHA-256', SaslMechanismScram) -register_sasl_mechanism('SCRAM-SHA-512', SaslMechanismScram) diff --git a/venv/lib/python3.12/site-packages/kafka/sasl/abc.py b/venv/lib/python3.12/site-packages/kafka/sasl/abc.py deleted file mode 100644 index 0577888..0000000 --- a/venv/lib/python3.12/site-packages/kafka/sasl/abc.py +++ /dev/null @@ -1,33 +0,0 @@ -from __future__ import absolute_import - -import abc - -from kafka.vendor.six import add_metaclass - - -@add_metaclass(abc.ABCMeta) -class SaslMechanism(object): - @abc.abstractmethod - def __init__(self, **config): - pass - - @abc.abstractmethod - def auth_bytes(self): - pass - - @abc.abstractmethod - def receive(self, auth_bytes): - pass - - @abc.abstractmethod - def is_done(self): - pass - - @abc.abstractmethod - def is_authenticated(self): - pass - - def auth_details(self): - if not self.is_authenticated: - raise RuntimeError('Not authenticated yet!') - return 'Authenticated via SASL' diff --git a/venv/lib/python3.12/site-packages/kafka/sasl/gssapi.py b/venv/lib/python3.12/site-packages/kafka/sasl/gssapi.py deleted file mode 100644 index 4785b1b..0000000 --- a/venv/lib/python3.12/site-packages/kafka/sasl/gssapi.py +++ /dev/null @@ -1,96 +0,0 @@ -from __future__ import absolute_import - -import struct - -# needed for SASL_GSSAPI authentication: -try: - import gssapi - from gssapi.raw.misc import GSSError -except (ImportError, OSError): - #no gssapi available, will 
disable gssapi mechanism - gssapi = None - GSSError = None - -from kafka.sasl.abc import SaslMechanism - - -class SaslMechanismGSSAPI(SaslMechanism): - # Establish security context and negotiate protection level - # For reference RFC 2222, section 7.2.1 - - SASL_QOP_AUTH = 1 - SASL_QOP_AUTH_INT = 2 - SASL_QOP_AUTH_CONF = 4 - - def __init__(self, **config): - assert gssapi is not None, 'GSSAPI lib not available' - if 'sasl_kerberos_name' not in config and 'sasl_kerberos_service_name' not in config: - raise ValueError('sasl_kerberos_service_name or sasl_kerberos_name required for GSSAPI sasl configuration') - self._is_done = False - self._is_authenticated = False - self.gssapi_name = None - if config.get('sasl_kerberos_name', None) is not None: - self.auth_id = str(config['sasl_kerberos_name']) - if isinstance(config['sasl_kerberos_name'], gssapi.Name): - self.gssapi_name = config['sasl_kerberos_name'] - else: - kerberos_domain_name = config.get('sasl_kerberos_domain_name', '') or config.get('host', '') - self.auth_id = config['sasl_kerberos_service_name'] + '@' + kerberos_domain_name - if self.gssapi_name is None: - self.gssapi_name = gssapi.Name(self.auth_id, name_type=gssapi.NameType.hostbased_service).canonicalize(gssapi.MechType.kerberos) - self._client_ctx = gssapi.SecurityContext(name=self.gssapi_name, usage='initiate') - self._next_token = self._client_ctx.step(None) - - def auth_bytes(self): - # GSSAPI Auth does not have a final broker->client message - # so mark is_done after the final auth_bytes are provided - # in practice we'll still receive a response when using SaslAuthenticate - # but not when using the prior unframed approach. - if self._is_authenticated: - self._is_done = True - return self._next_token or b'' - - def receive(self, auth_bytes): - if not self._client_ctx.complete: - # The server will send a token back. Processing of this token either - # establishes a security context, or it needs further token exchange. - # The gssapi will be able to identify the needed next step. - self._next_token = self._client_ctx.step(auth_bytes) - elif self._is_done: - # The final step of gssapi is send, so we do not expect any additional bytes - # however, allow an empty message to support SaslAuthenticate response - if auth_bytes != b'': - raise ValueError("Unexpected receive auth_bytes after sasl/gssapi completion") - else: - # unwraps message containing supported protection levels and msg size - msg = self._client_ctx.unwrap(auth_bytes).message - # Kafka currently doesn't support integrity or confidentiality security layers, so we - # simply set QoP to 'auth' only (first octet). We reuse the max message size proposed - # by the server - client_flags = self.SASL_QOP_AUTH - server_flags = struct.Struct('>b').unpack(msg[0:1])[0] - message_parts = [ - struct.Struct('>b').pack(client_flags & server_flags), - msg[1:], # always agree to max message size from server - self.auth_id.encode('utf-8'), - ] - # add authorization identity to the response, and GSS-wrap - self._next_token = self._client_ctx.wrap(b''.join(message_parts), False).message - # We need to identify the last token in auth_bytes(); - # we can't rely on client_ctx.complete because it becomes True after generating - # the second-to-last token (after calling .step(auth_bytes) for the final time) - # We could introduce an additional state variable (i.e., self._final_token), - # but instead we just set _is_authenticated. 
Since the plugin interface does - # not read is_authenticated() until after is_done() is True, this should be fine. - self._is_authenticated = True - - def is_done(self): - return self._is_done - - def is_authenticated(self): - return self._is_authenticated - - def auth_details(self): - if not self.is_authenticated: - raise RuntimeError('Not authenticated yet!') - return 'Authenticated as %s to %s via SASL / GSSAPI' % (self._client_ctx.initiator_name, self._client_ctx.target_name) diff --git a/venv/lib/python3.12/site-packages/kafka/sasl/msk.py b/venv/lib/python3.12/site-packages/kafka/sasl/msk.py deleted file mode 100644 index 7ec0321..0000000 --- a/venv/lib/python3.12/site-packages/kafka/sasl/msk.py +++ /dev/null @@ -1,244 +0,0 @@ -from __future__ import absolute_import - -import datetime -import hashlib -import hmac -import json -import logging -import string - -# needed for AWS_MSK_IAM authentication: -try: - from botocore.session import Session as BotoSession -except ImportError: - # no botocore available, will disable AWS_MSK_IAM mechanism - BotoSession = None - -from kafka.errors import KafkaConfigurationError -from kafka.sasl.abc import SaslMechanism -from kafka.vendor.six.moves import urllib - - -log = logging.getLogger(__name__) - - -class SaslMechanismAwsMskIam(SaslMechanism): - def __init__(self, **config): - assert BotoSession is not None, 'AWS_MSK_IAM requires the "botocore" package' - assert config.get('security_protocol', '') == 'SASL_SSL', 'AWS_MSK_IAM requires SASL_SSL' - assert 'host' in config, 'AWS_MSK_IAM requires host configuration' - self.host = config['host'] - self._auth = None - self._is_done = False - self._is_authenticated = False - - def _build_client(self): - session = BotoSession() - credentials = session.get_credentials().get_frozen_credentials() - if not session.get_config_variable('region'): - raise KafkaConfigurationError('Unable to determine region for AWS MSK cluster. Is AWS_DEFAULT_REGION set?') - return AwsMskIamClient( - host=self.host, - access_key=credentials.access_key, - secret_key=credentials.secret_key, - region=session.get_config_variable('region'), - token=credentials.token, - ) - - def auth_bytes(self): - client = self._build_client() - log.debug("Generating auth token for MSK scope: %s", client._scope) - return client.first_message() - - def receive(self, auth_bytes): - self._is_done = True - self._is_authenticated = auth_bytes != b'' - self._auth = auth_bytes.decode('utf-8') - - def is_done(self): - return self._is_done - - def is_authenticated(self): - return self._is_authenticated - - def auth_details(self): - if not self.is_authenticated: - raise RuntimeError('Not authenticated yet!') - return 'Authenticated via SASL / AWS_MSK_IAM %s' % (self._auth,) - - -class AwsMskIamClient: - UNRESERVED_CHARS = string.ascii_letters + string.digits + '-._~' - - def __init__(self, host, access_key, secret_key, region, token=None): - """ - Arguments: - host (str): The hostname of the broker. - access_key (str): An AWS_ACCESS_KEY_ID. - secret_key (str): An AWS_SECRET_ACCESS_KEY. - region (str): An AWS_REGION. - token (Optional[str]): An AWS_SESSION_TOKEN if using temporary - credentials. 
- """ - self.algorithm = 'AWS4-HMAC-SHA256' - self.expires = '900' - self.hashfunc = hashlib.sha256 - self.headers = [ - ('host', host) - ] - self.version = '2020_10_22' - - self.service = 'kafka-cluster' - self.action = '{}:Connect'.format(self.service) - - now = datetime.datetime.utcnow() - self.datestamp = now.strftime('%Y%m%d') - self.timestamp = now.strftime('%Y%m%dT%H%M%SZ') - - self.host = host - self.access_key = access_key - self.secret_key = secret_key - self.region = region - self.token = token - - @property - def _credential(self): - return '{0.access_key}/{0._scope}'.format(self) - - @property - def _scope(self): - return '{0.datestamp}/{0.region}/{0.service}/aws4_request'.format(self) - - @property - def _signed_headers(self): - """ - Returns (str): - An alphabetically sorted, semicolon-delimited list of lowercase - request header names. - """ - return ';'.join(sorted(k.lower() for k, _ in self.headers)) - - @property - def _canonical_headers(self): - """ - Returns (str): - A newline-delited list of header names and values. - Header names are lowercased. - """ - return '\n'.join(map(':'.join, self.headers)) + '\n' - - @property - def _canonical_request(self): - """ - Returns (str): - An AWS Signature Version 4 canonical request in the format: - \n - \n - \n - \n - \n - - """ - # The hashed_payload is always an empty string for MSK. - hashed_payload = self.hashfunc(b'').hexdigest() - return '\n'.join(( - 'GET', - '/', - self._canonical_querystring, - self._canonical_headers, - self._signed_headers, - hashed_payload, - )) - - @property - def _canonical_querystring(self): - """ - Returns (str): - A '&'-separated list of URI-encoded key/value pairs. - """ - params = [] - params.append(('Action', self.action)) - params.append(('X-Amz-Algorithm', self.algorithm)) - params.append(('X-Amz-Credential', self._credential)) - params.append(('X-Amz-Date', self.timestamp)) - params.append(('X-Amz-Expires', self.expires)) - if self.token: - params.append(('X-Amz-Security-Token', self.token)) - params.append(('X-Amz-SignedHeaders', self._signed_headers)) - - return '&'.join(self._uriencode(k) + '=' + self._uriencode(v) for k, v in params) - - @property - def _signing_key(self): - """ - Returns (bytes): - An AWS Signature V4 signing key generated from the secret_key, date, - region, service, and request type. - """ - key = self._hmac(('AWS4' + self.secret_key).encode('utf-8'), self.datestamp) - key = self._hmac(key, self.region) - key = self._hmac(key, self.service) - key = self._hmac(key, 'aws4_request') - return key - - @property - def _signing_str(self): - """ - Returns (str): - A string used to sign the AWS Signature V4 payload in the format: - \n - \n - \n - - """ - canonical_request_hash = self.hashfunc(self._canonical_request.encode('utf-8')).hexdigest() - return '\n'.join((self.algorithm, self.timestamp, self._scope, canonical_request_hash)) - - def _uriencode(self, msg): - """ - Arguments: - msg (str): A string to URI-encode. - - Returns (str): - The URI-encoded version of the provided msg, following the encoding - rules specified: https://github.com/aws/aws-msk-iam-auth#uriencode - """ - return urllib.parse.quote(msg, safe=self.UNRESERVED_CHARS) - - def _hmac(self, key, msg): - """ - Arguments: - key (bytes): A key to use for the HMAC digest. - msg (str): A value to include in the HMAC digest. - Returns (bytes): - An HMAC digest of the given key and msg. 
- """ - return hmac.new(key, msg.encode('utf-8'), digestmod=self.hashfunc).digest() - - def first_message(self): - """ - Returns (bytes): - An encoded JSON authentication payload that can be sent to the - broker. - """ - signature = hmac.new( - self._signing_key, - self._signing_str.encode('utf-8'), - digestmod=self.hashfunc, - ).hexdigest() - msg = { - 'version': self.version, - 'host': self.host, - 'user-agent': 'kafka-python', - 'action': self.action, - 'x-amz-algorithm': self.algorithm, - 'x-amz-credential': self._credential, - 'x-amz-date': self.timestamp, - 'x-amz-signedheaders': self._signed_headers, - 'x-amz-expires': self.expires, - 'x-amz-signature': signature, - } - if self.token: - msg['x-amz-security-token'] = self.token - - return json.dumps(msg, separators=(',', ':')).encode('utf-8') diff --git a/venv/lib/python3.12/site-packages/kafka/sasl/oauth.py b/venv/lib/python3.12/site-packages/kafka/sasl/oauth.py deleted file mode 100644 index f1e959c..0000000 --- a/venv/lib/python3.12/site-packages/kafka/sasl/oauth.py +++ /dev/null @@ -1,100 +0,0 @@ -from __future__ import absolute_import - -import abc -import logging - -from kafka.sasl.abc import SaslMechanism - - -log = logging.getLogger(__name__) - - -class SaslMechanismOAuth(SaslMechanism): - - def __init__(self, **config): - assert 'sasl_oauth_token_provider' in config, 'sasl_oauth_token_provider required for OAUTHBEARER sasl' - assert isinstance(config['sasl_oauth_token_provider'], AbstractTokenProvider), \ - 'sasl_oauth_token_provider must implement kafka.sasl.oauth.AbstractTokenProvider' - self.token_provider = config['sasl_oauth_token_provider'] - self._error = None - self._is_done = False - self._is_authenticated = False - - def auth_bytes(self): - if self._error: - # Server should respond to this with SaslAuthenticate failure, which ends the auth process - return self._error - token = self.token_provider.token() - extensions = self._token_extensions() - return "n,,\x01auth=Bearer {}{}\x01\x01".format(token, extensions).encode('utf-8') - - def receive(self, auth_bytes): - if auth_bytes != b'': - error = auth_bytes.decode('utf-8') - log.debug("Sending x01 response to server after receiving SASL OAuth error: %s", error) - self._error = b'\x01' - else: - self._is_done = True - self._is_authenticated = True - - def is_done(self): - return self._is_done - - def is_authenticated(self): - return self._is_authenticated - - def _token_extensions(self): - """ - Return a string representation of the OPTIONAL key-value pairs that can be sent with an OAUTHBEARER - initial request. - """ - # Builds up a string separated by \x01 via a dict of key value pairs - extensions = self.token_provider.extensions() - msg = '\x01'.join(['{}={}'.format(k, v) for k, v in extensions.items()]) - return '\x01' + msg if msg else '' - - def auth_details(self): - if not self.is_authenticated: - raise RuntimeError('Not authenticated yet!') - return 'Authenticated via SASL / OAuth' - -# This statement is compatible with both Python 2.7 & 3+ -ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()}) - -class AbstractTokenProvider(ABC): - """ - A Token Provider must be used for the SASL OAuthBearer protocol. - - The implementation should ensure token reuse so that multiple - calls at connect time do not create multiple tokens. The implementation - should also periodically refresh the token in order to guarantee - that each call returns an unexpired token. 
A timeout error should - be returned after a short period of inactivity so that the - broker can log debugging info and retry. - - Token Providers MUST implement the token() method - """ - - def __init__(self, **config): - pass - - @abc.abstractmethod - def token(self): - """ - Returns a (str) ID/Access Token to be sent to the Kafka - client. - """ - pass - - def extensions(self): - """ - This is an OPTIONAL method that may be implemented. - - Returns a map of key-value pairs that can - be sent with the SASL/OAUTHBEARER initial client request. If - not implemented, the values are ignored. This feature is only available - in Kafka >= 2.1.0. - - All returned keys and values should be type str - """ - return {} diff --git a/venv/lib/python3.12/site-packages/kafka/sasl/plain.py b/venv/lib/python3.12/site-packages/kafka/sasl/plain.py deleted file mode 100644 index 81443f5..0000000 --- a/venv/lib/python3.12/site-packages/kafka/sasl/plain.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import absolute_import - -import logging - -from kafka.sasl.abc import SaslMechanism - - -log = logging.getLogger(__name__) - - -class SaslMechanismPlain(SaslMechanism): - - def __init__(self, **config): - if config.get('security_protocol', '') == 'SASL_PLAINTEXT': - log.warning('Sending username and password in the clear') - assert 'sasl_plain_username' in config, 'sasl_plain_username required for PLAIN sasl' - assert 'sasl_plain_password' in config, 'sasl_plain_password required for PLAIN sasl' - - self.username = config['sasl_plain_username'] - self.password = config['sasl_plain_password'] - self._is_done = False - self._is_authenticated = False - - def auth_bytes(self): - # Send PLAIN credentials per RFC-4616 - return bytes('\0'.join([self.username, self.username, self.password]).encode('utf-8')) - - def receive(self, auth_bytes): - self._is_done = True - self._is_authenticated = auth_bytes == b'' - - def is_done(self): - return self._is_done - - def is_authenticated(self): - return self._is_authenticated - - def auth_details(self): - if not self.is_authenticated: - raise RuntimeError('Not authenticated yet!') - return 'Authenticated as %s via SASL / Plain' % self.username diff --git a/venv/lib/python3.12/site-packages/kafka/sasl/scram.py b/venv/lib/python3.12/site-packages/kafka/sasl/scram.py deleted file mode 100644 index d8cd071..0000000 --- a/venv/lib/python3.12/site-packages/kafka/sasl/scram.py +++ /dev/null @@ -1,133 +0,0 @@ -from __future__ import absolute_import - -import base64 -import hashlib -import hmac -import logging -import uuid - - -from kafka.sasl.abc import SaslMechanism -from kafka.vendor import six - - -log = logging.getLogger(__name__) - - -if six.PY2: - def xor_bytes(left, right): - return bytearray(ord(lb) ^ ord(rb) for lb, rb in zip(left, right)) -else: - def xor_bytes(left, right): - return bytes(lb ^ rb for lb, rb in zip(left, right)) - - -class SaslMechanismScram(SaslMechanism): - def __init__(self, **config): - assert 'sasl_plain_username' in config, 'sasl_plain_username required for SCRAM sasl' - assert 'sasl_plain_password' in config, 'sasl_plain_password required for SCRAM sasl' - assert config.get('sasl_mechanism', '') in ScramClient.MECHANISMS, 'Unrecognized SCRAM mechanism' - if config.get('security_protocol', '') == 'SASL_PLAINTEXT': - log.warning('Exchanging credentials in the clear during Sasl Authentication') - - self.username = config['sasl_plain_username'] - self.mechanism = config['sasl_mechanism'] - self._scram_client = ScramClient( - config['sasl_plain_username'], - 
config['sasl_plain_password'], - config['sasl_mechanism'] - ) - self._state = 0 - - def auth_bytes(self): - if self._state == 0: - return self._scram_client.first_message() - elif self._state == 1: - return self._scram_client.final_message() - else: - raise ValueError('No auth_bytes for state: %s' % self._state) - - def receive(self, auth_bytes): - if self._state == 0: - self._scram_client.process_server_first_message(auth_bytes) - elif self._state == 1: - self._scram_client.process_server_final_message(auth_bytes) - else: - raise ValueError('Cannot receive bytes in state: %s' % self._state) - self._state += 1 - return self.is_done() - - def is_done(self): - return self._state == 2 - - def is_authenticated(self): - # receive raises if authentication fails...? - return self._state == 2 - - def auth_details(self): - if not self.is_authenticated: - raise RuntimeError('Not authenticated yet!') - return 'Authenticated as %s via SASL / %s' % (self.username, self.mechanism) - - -class ScramClient: - MECHANISMS = { - 'SCRAM-SHA-256': hashlib.sha256, - 'SCRAM-SHA-512': hashlib.sha512 - } - - def __init__(self, user, password, mechanism): - self.nonce = str(uuid.uuid4()).replace('-', '').encode('utf-8') - self.auth_message = b'' - self.salted_password = None - self.user = user.encode('utf-8') - self.password = password.encode('utf-8') - self.hashfunc = self.MECHANISMS[mechanism] - self.hashname = ''.join(mechanism.lower().split('-')[1:3]) - self.stored_key = None - self.client_key = None - self.client_signature = None - self.client_proof = None - self.server_key = None - self.server_signature = None - - def first_message(self): - client_first_bare = b'n=' + self.user + b',r=' + self.nonce - self.auth_message += client_first_bare - return b'n,,' + client_first_bare - - def process_server_first_message(self, server_first_message): - self.auth_message += b',' + server_first_message - params = dict(pair.split('=', 1) for pair in server_first_message.decode('utf-8').split(',')) - server_nonce = params['r'].encode('utf-8') - if not server_nonce.startswith(self.nonce): - raise ValueError("Server nonce, did not start with client nonce!") - self.nonce = server_nonce - self.auth_message += b',c=biws,r=' + self.nonce - - salt = base64.b64decode(params['s'].encode('utf-8')) - iterations = int(params['i']) - self.create_salted_password(salt, iterations) - - self.client_key = self.hmac(self.salted_password, b'Client Key') - self.stored_key = self.hashfunc(self.client_key).digest() - self.client_signature = self.hmac(self.stored_key, self.auth_message) - self.client_proof = xor_bytes(self.client_key, self.client_signature) - self.server_key = self.hmac(self.salted_password, b'Server Key') - self.server_signature = self.hmac(self.server_key, self.auth_message) - - def hmac(self, key, msg): - return hmac.new(key, msg, digestmod=self.hashfunc).digest() - - def create_salted_password(self, salt, iterations): - self.salted_password = hashlib.pbkdf2_hmac( - self.hashname, self.password, salt, iterations - ) - - def final_message(self): - return b'c=biws,r=' + self.nonce + b',p=' + base64.b64encode(self.client_proof) - - def process_server_final_message(self, server_final_message): - params = dict(pair.split('=', 1) for pair in server_final_message.decode('utf-8').split(',')) - if self.server_signature != base64.b64decode(params['v'].encode('utf-8')): - raise ValueError("Server sent wrong signature!") diff --git a/venv/lib/python3.12/site-packages/kafka/sasl/sspi.py b/venv/lib/python3.12/site-packages/kafka/sasl/sspi.py 
deleted file mode 100644 index f4c95d0..0000000 --- a/venv/lib/python3.12/site-packages/kafka/sasl/sspi.py +++ /dev/null @@ -1,111 +0,0 @@ -from __future__ import absolute_import - -import logging - -# Windows-only -try: - import sspi - import pywintypes - import sspicon - import win32security -except ImportError: - sspi = None - -from kafka.sasl.abc import SaslMechanism - - -log = logging.getLogger(__name__) - - -class SaslMechanismSSPI(SaslMechanism): - # Establish security context and negotiate protection level - # For reference see RFC 4752, section 3 - - SASL_QOP_AUTH = 1 - SASL_QOP_AUTH_INT = 2 - SASL_QOP_AUTH_CONF = 4 - - def __init__(self, **config): - assert sspi is not None, 'No GSSAPI lib available (gssapi or sspi)' - if 'sasl_kerberos_name' not in config and 'sasl_kerberos_service_name' not in config: - raise ValueError('sasl_kerberos_service_name or sasl_kerberos_name required for GSSAPI sasl configuration') - self._is_done = False - self._is_authenticated = False - if config.get('sasl_kerberos_name', None) is not None: - self.auth_id = str(config['sasl_kerberos_name']) - else: - kerberos_domain_name = config.get('sasl_kerberos_domain_name', '') or config.get('host', '') - self.auth_id = config['sasl_kerberos_service_name'] + '/' + kerberos_domain_name - scheme = "Kerberos" # Do not try with Negotiate for SASL authentication. Tokens are different. - # https://docs.microsoft.com/en-us/windows/win32/secauthn/context-requirements - flags = ( - sspicon.ISC_REQ_MUTUAL_AUTH | # mutual authentication - sspicon.ISC_REQ_INTEGRITY | # check for integrity - sspicon.ISC_REQ_SEQUENCE_DETECT | # enable out-of-order messages - sspicon.ISC_REQ_CONFIDENTIALITY # request confidentiality - ) - self._client_ctx = sspi.ClientAuth(scheme, targetspn=self.auth_id, scflags=flags) - self._next_token = self._client_ctx.step(None) - - def auth_bytes(self): - # GSSAPI Auth does not have a final broker->client message - # so mark is_done after the final auth_bytes are provided - # in practice we'll still receive a response when using SaslAuthenticate - # but not when using the prior unframed approach. - if self._client_ctx.authenticated: - self._is_done = True - self._is_authenticated = True - return self._next_token or b'' - - def receive(self, auth_bytes): - log.debug("Received token from server (size %s)", len(auth_bytes)) - if not self._client_ctx.authenticated: - # calculate an output token from kafka token (or None on first iteration) - # https://docs.microsoft.com/en-us/windows/win32/api/sspi/nf-sspi-initializesecuritycontexta - # https://docs.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--kerberos - # authorize method will wrap for us our token in sspi structures - error, auth = self._client_ctx.authorize(auth_bytes) - if len(auth) > 0 and len(auth[0].Buffer): - log.debug("Got token from context") - # this buffer must be sent to the server whatever the result is - self._next_token = auth[0].Buffer - else: - log.debug("Got no token, exchange finished") - # seems to be the end of the loop - self._next_token = b'' - elif self._is_done: - # The final step of gssapi is send, so we do not expect any additional bytes - # however, allow an empty message to support SaslAuthenticate response - if auth_bytes != b'': - raise ValueError("Unexpected receive auth_bytes after sasl/gssapi completion") - else: - # Process the security layer negotiation token, sent by the server - # once the security context is established. 
- - # The following part is required by SASL, but not by classic Kerberos. - # See RFC 4752 - - # unwraps message containing supported protection levels and msg size - msg, _was_encrypted = self._client_ctx.unwrap(auth_bytes) - - # Kafka currently doesn't support integrity or confidentiality security layers, so we - # simply set QoP to 'auth' only (first octet). We reuse the max message size proposed - # by the server - client_flags = self.SASL_QOP_AUTH - server_flags = msg[0] - message_parts = [ - bytes(client_flags & server_flags), - msg[:1], - self.auth_id.encode('utf-8'), - ] - # add authorization identity to the response, and GSS-wrap - self._next_token = self._client_ctx.wrap(b''.join(message_parts), False) - - def is_done(self): - return self._is_done - - def is_authenticated(self): - return self._is_authenticated - - def auth_details(self): - return 'Authenticated as %s to %s via SASL / SSPI/GSSAPI \\o/' % (self._client_ctx.initiator_name, self._client_ctx.service_name) diff --git a/venv/lib/python3.12/site-packages/kafka/scram.py b/venv/lib/python3.12/site-packages/kafka/scram.py new file mode 100644 index 0000000..7f00375 --- /dev/null +++ b/venv/lib/python3.12/site-packages/kafka/scram.py @@ -0,0 +1,81 @@ +from __future__ import absolute_import + +import base64 +import hashlib +import hmac +import uuid + +from kafka.vendor import six + + +if six.PY2: + def xor_bytes(left, right): + return bytearray(ord(lb) ^ ord(rb) for lb, rb in zip(left, right)) +else: + def xor_bytes(left, right): + return bytes(lb ^ rb for lb, rb in zip(left, right)) + + +class ScramClient: + MECHANISMS = { + 'SCRAM-SHA-256': hashlib.sha256, + 'SCRAM-SHA-512': hashlib.sha512 + } + + def __init__(self, user, password, mechanism): + self.nonce = str(uuid.uuid4()).replace('-', '') + self.auth_message = '' + self.salted_password = None + self.user = user + self.password = password.encode('utf-8') + self.hashfunc = self.MECHANISMS[mechanism] + self.hashname = ''.join(mechanism.lower().split('-')[1:3]) + self.stored_key = None + self.client_key = None + self.client_signature = None + self.client_proof = None + self.server_key = None + self.server_signature = None + + def first_message(self): + client_first_bare = 'n={},r={}'.format(self.user, self.nonce) + self.auth_message += client_first_bare + return 'n,,' + client_first_bare + + def process_server_first_message(self, server_first_message): + self.auth_message += ',' + server_first_message + params = dict(pair.split('=', 1) for pair in server_first_message.split(',')) + server_nonce = params['r'] + if not server_nonce.startswith(self.nonce): + raise ValueError("Server nonce, did not start with client nonce!") + self.nonce = server_nonce + self.auth_message += ',c=biws,r=' + self.nonce + + salt = base64.b64decode(params['s'].encode('utf-8')) + iterations = int(params['i']) + self.create_salted_password(salt, iterations) + + self.client_key = self.hmac(self.salted_password, b'Client Key') + self.stored_key = self.hashfunc(self.client_key).digest() + self.client_signature = self.hmac(self.stored_key, self.auth_message.encode('utf-8')) + self.client_proof = xor_bytes(self.client_key, self.client_signature) + self.server_key = self.hmac(self.salted_password, b'Server Key') + self.server_signature = self.hmac(self.server_key, self.auth_message.encode('utf-8')) + + def hmac(self, key, msg): + return hmac.new(key, msg, digestmod=self.hashfunc).digest() + + def create_salted_password(self, salt, iterations): + self.salted_password = hashlib.pbkdf2_hmac( + self.hashname, 
self.password, salt, iterations + ) + + def final_message(self): + return 'c=biws,r={},p={}'.format(self.nonce, base64.b64encode(self.client_proof).decode('utf-8')) + + def process_server_final_message(self, server_final_message): + params = dict(pair.split('=', 1) for pair in server_final_message.split(',')) + if self.server_signature != base64.b64decode(params['v'].encode('utf-8')): + raise ValueError("Server sent wrong signature!") + + diff --git a/venv/lib/python3.12/site-packages/kafka/socks5_wrapper.py b/venv/lib/python3.12/site-packages/kafka/socks5_wrapper.py deleted file mode 100644 index 18bea7c..0000000 --- a/venv/lib/python3.12/site-packages/kafka/socks5_wrapper.py +++ /dev/null @@ -1,248 +0,0 @@ -try: - from urllib.parse import urlparse -except ImportError: - from urlparse import urlparse - -import errno -import logging -import random -import socket -import struct - -log = logging.getLogger(__name__) - - -class ProxyConnectionStates: - DISCONNECTED = '' - CONNECTING = '' - NEGOTIATE_PROPOSE = '' - NEGOTIATING = '' - AUTHENTICATING = '' - REQUEST_SUBMIT = '' - REQUESTING = '' - READ_ADDRESS = '' - COMPLETE = '' - - -class Socks5Wrapper: - """Socks5 proxy wrapper - - Manages connection through socks5 proxy with support for username/password - authentication. - """ - - def __init__(self, proxy_url, afi): - self._buffer_in = b'' - self._buffer_out = b'' - self._proxy_url = urlparse(proxy_url) - self._sock = None - self._state = ProxyConnectionStates.DISCONNECTED - self._target_afi = socket.AF_UNSPEC - - proxy_addrs = self.dns_lookup(self._proxy_url.hostname, self._proxy_url.port, afi) - # TODO raise error on lookup failure - self._proxy_addr = random.choice(proxy_addrs) - - @classmethod - def is_inet_4_or_6(cls, gai): - """Given a getaddrinfo struct, return True iff ipv4 or ipv6""" - return gai[0] in (socket.AF_INET, socket.AF_INET6) - - @classmethod - def dns_lookup(cls, host, port, afi=socket.AF_UNSPEC): - """Returns a list of getaddrinfo structs, optionally filtered to an afi (ipv4 / ipv6)""" - # XXX: all DNS functions in Python are blocking. If we really - # want to be non-blocking here, we need to use a 3rd-party - # library like python-adns, or move resolution onto its - # own thread. This will be subject to the default libc - # name resolution timeout (5s on most Linux boxes) - try: - return list(filter(cls.is_inet_4_or_6, - socket.getaddrinfo(host, port, afi, - socket.SOCK_STREAM))) - except socket.gaierror as ex: - log.warning("DNS lookup failed for proxy %s:%d, %r", host, port, ex) - return [] - - def socket(self, family, sock_type): - """Open and record a socket. - - Returns the actual underlying socket - object to ensure e.g. selects and ssl wrapping works as expected. - """ - self._target_afi = family # Store the address family of the target - afi, _, _, _, _ = self._proxy_addr - self._sock = socket.socket(afi, sock_type) - return self._sock - - def _flush_buf(self): - """Send out all data that is stored in the outgoing buffer. - - It is expected that the caller handles error handling, including non-blocking - as well as connection failure exceptions. - """ - while self._buffer_out: - sent_bytes = self._sock.send(self._buffer_out) - self._buffer_out = self._buffer_out[sent_bytes:] - - def _peek_buf(self, datalen): - """Ensure local inbound buffer has enough data, and return that data without - consuming the local buffer - - It's expected that the caller handles e.g. 
blocking exceptions""" - while True: - bytes_remaining = datalen - len(self._buffer_in) - if bytes_remaining <= 0: - break - data = self._sock.recv(bytes_remaining) - if not data: - break - self._buffer_in = self._buffer_in + data - - return self._buffer_in[:datalen] - - def _read_buf(self, datalen): - """Read and consume bytes from socket connection - - It's expected that the caller handles e.g. blocking exceptions""" - buf = self._peek_buf(datalen) - if buf: - self._buffer_in = self._buffer_in[len(buf):] - return buf - - def connect_ex(self, addr): - """Runs a state machine through connection to authentication to - proxy connection request. - - The somewhat strange setup is to facilitate non-intrusive use from - BrokerConnection state machine. - - This function is called with a socket in non-blocking mode. Both - send and receive calls can return in EWOULDBLOCK/EAGAIN which we - specifically avoid handling here. These are handled in main - BrokerConnection connection loop, which then would retry calls - to this function.""" - - if self._state == ProxyConnectionStates.DISCONNECTED: - self._state = ProxyConnectionStates.CONNECTING - - if self._state == ProxyConnectionStates.CONNECTING: - _, _, _, _, sockaddr = self._proxy_addr - ret = self._sock.connect_ex(sockaddr) - if not ret or ret == errno.EISCONN: - self._state = ProxyConnectionStates.NEGOTIATE_PROPOSE - else: - return ret - - if self._state == ProxyConnectionStates.NEGOTIATE_PROPOSE: - if self._proxy_url.username and self._proxy_url.password: - # Propose username/password - self._buffer_out = b"\x05\x01\x02" - else: - # Propose no auth - self._buffer_out = b"\x05\x01\x00" - self._state = ProxyConnectionStates.NEGOTIATING - - if self._state == ProxyConnectionStates.NEGOTIATING: - self._flush_buf() - buf = self._read_buf(2) - if buf[0:1] != b"\x05": - log.error("Unrecognized SOCKS version") - self._state = ProxyConnectionStates.DISCONNECTED - self._sock.close() - return errno.ECONNREFUSED - - if buf[1:2] == b"\x00": - # No authentication required - self._state = ProxyConnectionStates.REQUEST_SUBMIT - elif buf[1:2] == b"\x02": - # Username/password authentication selected - userlen = len(self._proxy_url.username) - passlen = len(self._proxy_url.password) - self._buffer_out = struct.pack( - "!bb{}sb{}s".format(userlen, passlen), - 1, # version - userlen, - self._proxy_url.username.encode(), - passlen, - self._proxy_url.password.encode(), - ) - self._state = ProxyConnectionStates.AUTHENTICATING - else: - log.error("Unrecognized SOCKS authentication method") - self._state = ProxyConnectionStates.DISCONNECTED - self._sock.close() - return errno.ECONNREFUSED - - if self._state == ProxyConnectionStates.AUTHENTICATING: - self._flush_buf() - buf = self._read_buf(2) - if buf == b"\x01\x00": - # Authentication succesful - self._state = ProxyConnectionStates.REQUEST_SUBMIT - else: - log.error("Socks5 proxy authentication failure") - self._state = ProxyConnectionStates.DISCONNECTED - self._sock.close() - return errno.ECONNREFUSED - - if self._state == ProxyConnectionStates.REQUEST_SUBMIT: - if self._target_afi == socket.AF_INET: - addr_type = 1 - addr_len = 4 - elif self._target_afi == socket.AF_INET6: - addr_type = 4 - addr_len = 16 - else: - log.error("Unknown address family, %r", self._target_afi) - self._state = ProxyConnectionStates.DISCONNECTED - self._sock.close() - return errno.ECONNREFUSED - - self._buffer_out = struct.pack( - "!bbbb{}sh".format(addr_len), - 5, # version - 1, # command: connect - 0, # reserved - addr_type, # 1 for ipv4, 4 
for ipv6 address - socket.inet_pton(self._target_afi, addr[0]), # either 4 or 16 bytes of actual address - addr[1], # port - ) - self._state = ProxyConnectionStates.REQUESTING - - if self._state == ProxyConnectionStates.REQUESTING: - self._flush_buf() - buf = self._read_buf(2) - if buf[0:2] == b"\x05\x00": - self._state = ProxyConnectionStates.READ_ADDRESS - else: - log.error("Proxy request failed: %r", buf[1:2]) - self._state = ProxyConnectionStates.DISCONNECTED - self._sock.close() - return errno.ECONNREFUSED - - if self._state == ProxyConnectionStates.READ_ADDRESS: - # we don't really care about the remote endpoint address, but need to clear the stream - buf = self._peek_buf(2) - if buf[0:2] == b"\x00\x01": - _ = self._read_buf(2 + 4 + 2) # ipv4 address + port - elif buf[0:2] == b"\x00\x05": - _ = self._read_buf(2 + 16 + 2) # ipv6 address + port - else: - log.error("Unrecognized remote address type %r", buf[1:2]) - self._state = ProxyConnectionStates.DISCONNECTED - self._sock.close() - return errno.ECONNREFUSED - self._state = ProxyConnectionStates.COMPLETE - - if self._state == ProxyConnectionStates.COMPLETE: - return 0 - - # not reached; - # Send and recv will raise socket error on EWOULDBLOCK/EAGAIN that is assumed to be handled by - # the caller. The caller re-enters this state machine from retry logic with timer or via select & family - log.error("Internal error, state %r not handled correctly", self._state) - self._state = ProxyConnectionStates.DISCONNECTED - if self._sock: - self._sock.close() - return errno.ECONNREFUSED diff --git a/venv/lib/python3.12/site-packages/kafka/structs.py b/venv/lib/python3.12/site-packages/kafka/structs.py index 16ba0da..bcb0236 100644 --- a/venv/lib/python3.12/site-packages/kafka/structs.py +++ b/venv/lib/python3.12/site-packages/kafka/structs.py @@ -42,7 +42,7 @@ Keyword Arguments: this partition metadata. """ PartitionMetadata = namedtuple("PartitionMetadata", - ["topic", "partition", "leader", "leader_epoch", "replicas", "isr", "offline_replicas", "error"]) + ["topic", "partition", "leader", "replicas", "isr", "error"]) """The Kafka offset commit API @@ -55,10 +55,10 @@ what time the commit was made, etc. 
Keyword Arguments: offset (int): The offset to be committed metadata (str): Non-null metadata - leader_epoch (int): The last known epoch from the leader / broker """ OffsetAndMetadata = namedtuple("OffsetAndMetadata", - ["offset", "metadata", "leader_epoch"]) + # TODO add leaderEpoch: OffsetAndMetadata(offset, leaderEpoch, metadata) + ["offset", "metadata"]) """An offset and timestamp tuple @@ -66,10 +66,9 @@ OffsetAndMetadata = namedtuple("OffsetAndMetadata", Keyword Arguments: offset (int): An offset timestamp (int): The timestamp associated to the offset - leader_epoch (int): The last known epoch from the leader / broker """ OffsetAndTimestamp = namedtuple("OffsetAndTimestamp", - ["offset", "timestamp", "leader_epoch"]) + ["offset", "timestamp"]) MemberInformation = namedtuple("MemberInformation", ["member_id", "client_id", "client_host", "member_metadata", "member_assignment"]) diff --git a/venv/lib/python3.12/site-packages/kafka/util.py b/venv/lib/python3.12/site-packages/kafka/util.py index 658c17d..e31d993 100644 --- a/venv/lib/python3.12/site-packages/kafka/util.py +++ b/venv/lib/python3.12/site-packages/kafka/util.py @@ -1,12 +1,8 @@ -from __future__ import absolute_import, division +from __future__ import absolute_import import binascii -import functools -import re -import time import weakref -from kafka.errors import KafkaTimeoutError from kafka.vendor import six @@ -23,69 +19,7 @@ if six.PY3: crc -= TO_SIGNED return crc else: - from binascii import crc32 # noqa: F401 - - -class Timer: - __slots__ = ('_start_at', '_expire_at', '_timeout_ms', '_error_message') - - def __init__(self, timeout_ms, error_message=None, start_at=None): - self._timeout_ms = timeout_ms - self._start_at = start_at or time.time() - if timeout_ms is not None: - self._expire_at = self._start_at + timeout_ms / 1000 - else: - self._expire_at = float('inf') - self._error_message = error_message - - @property - def expired(self): - return time.time() >= self._expire_at - - @property - def timeout_ms(self): - if self._timeout_ms is None: - return None - elif self._expire_at == float('inf'): - return float('inf') - remaining = self._expire_at - time.time() - if remaining < 0: - return 0 - else: - return int(remaining * 1000) - - @property - def elapsed_ms(self): - return int(1000 * (time.time() - self._start_at)) - - def maybe_raise(self): - if self.expired: - raise KafkaTimeoutError(self._error_message) - - def __str__(self): - return "Timer(%s ms remaining)" % (self.timeout_ms) - -# Taken from: https://github.com/apache/kafka/blob/39eb31feaeebfb184d98cc5d94da9148c2319d81/clients/src/main/java/org/apache/kafka/common/internals/Topic.java#L29 -TOPIC_MAX_LENGTH = 249 -TOPIC_LEGAL_CHARS = re.compile('^[a-zA-Z0-9._-]+$') - -def ensure_valid_topic_name(topic): - """ Ensures that the topic name is valid according to the kafka source. """ - - # See Kafka Source: - # https://github.com/apache/kafka/blob/39eb31feaeebfb184d98cc5d94da9148c2319d81/clients/src/main/java/org/apache/kafka/common/internals/Topic.java - if topic is None: - raise TypeError('All topics must not be None') - if not isinstance(topic, six.string_types): - raise TypeError('All topics must be strings') - if len(topic) == 0: - raise ValueError('All topics must be non-empty strings') - if topic == '.' or topic == '..': - raise ValueError('Topic name cannot be "." 
or ".."') - if len(topic) > TOPIC_MAX_LENGTH: - raise ValueError('Topic name is illegal, it can\'t be longer than {0} characters, topic: "{1}"'.format(TOPIC_MAX_LENGTH, topic)) - if not TOPIC_LEGAL_CHARS.match(topic): - raise ValueError('Topic name "{0}" is illegal, it contains a character other than ASCII alphanumerics, ".", "_" and "-"'.format(topic)) + from binascii import crc32 class WeakMethod(object): @@ -130,11 +64,3 @@ class Dict(dict): See: https://docs.python.org/2/library/weakref.html """ pass - - -def synchronized(func): - def wrapper(self, *args, **kwargs): - with self._lock: - return func(self, *args, **kwargs) - functools.update_wrapper(wrapper, func) - return wrapper diff --git a/venv/lib/python3.12/site-packages/kafka/vendor/selectors34.py b/venv/lib/python3.12/site-packages/kafka/vendor/selectors34.py index 7874903..ebf5d51 100644 --- a/venv/lib/python3.12/site-packages/kafka/vendor/selectors34.py +++ b/venv/lib/python3.12/site-packages/kafka/vendor/selectors34.py @@ -15,11 +15,7 @@ The following code adapted from trollius.selectors. from __future__ import absolute_import from abc import ABCMeta, abstractmethod -from collections import namedtuple -try: - from collections.abc import Mapping -except ImportError: - from collections import Mapping +from collections import namedtuple, Mapping from errno import EINTR import math import select diff --git a/venv/lib/python3.12/site-packages/kafka/vendor/six.py b/venv/lib/python3.12/site-packages/kafka/vendor/six.py index 3198213..3621a0a 100644 --- a/venv/lib/python3.12/site-packages/kafka/vendor/six.py +++ b/venv/lib/python3.12/site-packages/kafka/vendor/six.py @@ -1,6 +1,6 @@ # pylint: skip-file -# Copyright (c) 2010-2020 Benjamin Peterson +# Copyright (c) 2010-2017 Benjamin Peterson # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -31,7 +31,7 @@ import sys import types __author__ = "Benjamin Peterson " -__version__ = "1.16.0" +__version__ = "1.11.0" # Useful for very coarse version differentiation. 
@@ -77,11 +77,6 @@ else: # https://github.com/dpkp/kafka-python/pull/979#discussion_r100403389 # del X -if PY34: - from importlib.util import spec_from_loader -else: - spec_from_loader = None - def _add_doc(func, doc): """Add documentation to a function.""" @@ -197,11 +192,6 @@ class _SixMetaPathImporter(object): return self return None - def find_spec(self, fullname, path, target=None): - if fullname in self.known_modules: - return spec_from_loader(fullname, self) - return None - def __get_module(self, fullname): try: return self.known_modules[fullname] @@ -239,12 +229,6 @@ class _SixMetaPathImporter(object): return None get_source = get_code # same as get_code - def create_module(self, spec): - return self.load_module(spec.name) - - def exec_module(self, module): - pass - _importer = _SixMetaPathImporter(__name__) @@ -269,7 +253,7 @@ _moved_attributes = [ MovedAttribute("reduce", "__builtin__", "functools"), MovedAttribute("shlex_quote", "pipes", "shlex", "quote"), MovedAttribute("StringIO", "StringIO", "io"), - MovedAttribute("UserDict", "UserDict", "collections", "IterableUserDict", "UserDict"), + MovedAttribute("UserDict", "UserDict", "collections"), MovedAttribute("UserList", "UserList", "collections"), MovedAttribute("UserString", "UserString", "collections"), MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), @@ -277,11 +261,9 @@ _moved_attributes = [ MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"), MovedModule("builtins", "__builtin__"), MovedModule("configparser", "ConfigParser"), - MovedModule("collections_abc", "collections", "collections.abc" if sys.version_info >= (3, 3) else "collections"), MovedModule("copyreg", "copy_reg"), MovedModule("dbm_gnu", "gdbm", "dbm.gnu"), - MovedModule("dbm_ndbm", "dbm", "dbm.ndbm"), - MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread" if sys.version_info < (3, 9) else "_thread"), + MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread"), MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), MovedModule("http_cookies", "Cookie", "http.cookies"), MovedModule("html_entities", "htmlentitydefs", "html.entities"), @@ -661,16 +643,13 @@ if PY3: import io StringIO = io.StringIO BytesIO = io.BytesIO - del io _assertCountEqual = "assertCountEqual" if sys.version_info[1] <= 1: _assertRaisesRegex = "assertRaisesRegexp" _assertRegex = "assertRegexpMatches" - _assertNotRegex = "assertNotRegexpMatches" else: _assertRaisesRegex = "assertRaisesRegex" _assertRegex = "assertRegex" - _assertNotRegex = "assertNotRegex" else: def b(s): return s @@ -692,7 +671,6 @@ else: _assertCountEqual = "assertItemsEqual" _assertRaisesRegex = "assertRaisesRegexp" _assertRegex = "assertRegexpMatches" - _assertNotRegex = "assertNotRegexpMatches" _add_doc(b, """Byte literal""") _add_doc(u, """Text literal""") @@ -709,10 +687,6 @@ def assertRegex(self, *args, **kwargs): return getattr(self, _assertRegex)(*args, **kwargs) -def assertNotRegex(self, *args, **kwargs): - return getattr(self, _assertNotRegex)(*args, **kwargs) - - if PY3: exec_ = getattr(moves.builtins, "exec") @@ -748,7 +722,16 @@ else: """) -if sys.version_info[:2] > (3,): +if sys.version_info[:2] == (3, 2): + exec_("""def raise_from(value, from_value): + try: + if from_value is None: + raise value + raise value from from_value + finally: + value = None +""") +elif sys.version_info[:2] > (3, 2): exec_("""def raise_from(value, from_value): try: raise value from from_value @@ -828,33 +811,13 @@ if sys.version_info[:2] < (3, 3): 
_add_doc(reraise, """Reraise an exception.""") if sys.version_info[0:2] < (3, 4): - # This does exactly the same what the :func:`py3:functools.update_wrapper` - # function does on Python versions after 3.2. It sets the ``__wrapped__`` - # attribute on ``wrapper`` object and it doesn't raise an error if any of - # the attributes mentioned in ``assigned`` and ``updated`` are missing on - # ``wrapped`` object. - def _update_wrapper(wrapper, wrapped, - assigned=functools.WRAPPER_ASSIGNMENTS, - updated=functools.WRAPPER_UPDATES): - for attr in assigned: - try: - value = getattr(wrapped, attr) - except AttributeError: - continue - else: - setattr(wrapper, attr, value) - for attr in updated: - getattr(wrapper, attr).update(getattr(wrapped, attr, {})) - wrapper.__wrapped__ = wrapped - return wrapper - _update_wrapper.__doc__ = functools.update_wrapper.__doc__ - def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, updated=functools.WRAPPER_UPDATES): - return functools.partial(_update_wrapper, wrapped=wrapped, - assigned=assigned, updated=updated) - wraps.__doc__ = functools.wraps.__doc__ - + def wrapper(f): + f = functools.wraps(wrapped, assigned, updated)(f) + f.__wrapped__ = wrapped + return f + return wrapper else: wraps = functools.wraps @@ -867,15 +830,7 @@ def with_metaclass(meta, *bases): class metaclass(type): def __new__(cls, name, this_bases, d): - if sys.version_info[:2] >= (3, 7): - # This version introduced PEP 560 that requires a bit - # of extra care (we mimic what is done by __build_class__). - resolved_bases = types.resolve_bases(bases) - if resolved_bases is not bases: - d['__orig_bases__'] = bases - else: - resolved_bases = bases - return meta(name, resolved_bases, d) + return meta(name, bases, d) @classmethod def __prepare__(cls, name, this_bases): @@ -895,75 +850,13 @@ def add_metaclass(metaclass): orig_vars.pop(slots_var) orig_vars.pop('__dict__', None) orig_vars.pop('__weakref__', None) - if hasattr(cls, '__qualname__'): - orig_vars['__qualname__'] = cls.__qualname__ return metaclass(cls.__name__, cls.__bases__, orig_vars) return wrapper -def ensure_binary(s, encoding='utf-8', errors='strict'): - """Coerce **s** to six.binary_type. - - For Python 2: - - `unicode` -> encoded to `str` - - `str` -> `str` - - For Python 3: - - `str` -> encoded to `bytes` - - `bytes` -> `bytes` - """ - if isinstance(s, binary_type): - return s - if isinstance(s, text_type): - return s.encode(encoding, errors) - raise TypeError("not expecting type '%s'" % type(s)) - - -def ensure_str(s, encoding='utf-8', errors='strict'): - """Coerce *s* to `str`. - - For Python 2: - - `unicode` -> encoded to `str` - - `str` -> `str` - - For Python 3: - - `str` -> `str` - - `bytes` -> decoded to `str` - """ - # Optimization: Fast return for the common case. - if type(s) is str: - return s - if PY2 and isinstance(s, text_type): - return s.encode(encoding, errors) - elif PY3 and isinstance(s, binary_type): - return s.decode(encoding, errors) - elif not isinstance(s, (text_type, binary_type)): - raise TypeError("not expecting type '%s'" % type(s)) - return s - - -def ensure_text(s, encoding='utf-8', errors='strict'): - """Coerce *s* to six.text_type. 
- - For Python 2: - - `unicode` -> `unicode` - - `str` -> `unicode` - - For Python 3: - - `str` -> `str` - - `bytes` -> decoded to `str` - """ - if isinstance(s, binary_type): - return s.decode(encoding, errors) - elif isinstance(s, text_type): - return s - else: - raise TypeError("not expecting type '%s'" % type(s)) - - def python_2_unicode_compatible(klass): """ - A class decorator that defines __unicode__ and __str__ methods under Python 2. + A decorator that defines __unicode__ and __str__ methods under Python 2. Under Python 3 it does nothing. To support Python 2 and 3 with a single code base, define a __str__ method diff --git a/venv/lib/python3.12/site-packages/kafka/vendor/socketpair.py b/venv/lib/python3.12/site-packages/kafka/vendor/socketpair.py index 54d9087..b55e629 100644 --- a/venv/lib/python3.12/site-packages/kafka/vendor/socketpair.py +++ b/venv/lib/python3.12/site-packages/kafka/vendor/socketpair.py @@ -53,23 +53,6 @@ if not hasattr(socket, "socketpair"): raise finally: lsock.close() - - # Authenticating avoids using a connection from something else - # able to connect to {host}:{port} instead of us. - # We expect only AF_INET and AF_INET6 families. - try: - if ( - ssock.getsockname() != csock.getpeername() - or csock.getsockname() != ssock.getpeername() - ): - raise ConnectionError("Unexpected peer connection") - except: - # getsockname() and getpeername() can fail - # if either socket isn't connected. - ssock.close() - csock.close() - raise - return (ssock, csock) socket.socketpair = socketpair diff --git a/venv/lib/python3.12/site-packages/kafka/version.py b/venv/lib/python3.12/site-packages/kafka/version.py index 9f4696f..668c344 100644 --- a/venv/lib/python3.12/site-packages/kafka/version.py +++ b/venv/lib/python3.12/site-packages/kafka/version.py @@ -1 +1 @@ -__version__ = '2.2.15' +__version__ = '2.0.2' diff --git a/venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/INSTALLER diff --git a/venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/LICENSE b/venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/LICENSE new file mode 100644 index 0000000..412a2b6 --- /dev/null +++ b/venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/LICENSE @@ -0,0 +1,202 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2015 David Arthur + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
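The structs.py and version.py hunks above pin the vendored client back to 2.0.2 and drop the leader_epoch field from PartitionMetadata, OffsetAndMetadata and OffsetAndTimestamp. A quick interactive check, in the same doctest style as the vendored README below, shows what application code can rely on after the downgrade; the literal results assume this vendored copy is the one on sys.path, and any caller still passing a third leader_epoch argument will get a TypeError from the two-field namedtuple:

>>> import kafka
>>> kafka.__version__
'2.0.2'
>>> from kafka.structs import OffsetAndMetadata
>>> OffsetAndMetadata(42, 'checkpoint')._fields   # no leader_epoch in 2.0.2
('offset', 'metadata')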
+ diff --git a/venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/METADATA b/venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/METADATA new file mode 100644 index 0000000..c141a98 --- /dev/null +++ b/venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/METADATA @@ -0,0 +1,190 @@ +Metadata-Version: 2.1 +Name: kafka-python +Version: 2.0.2 +Summary: Pure Python client for Apache Kafka +Home-page: https://github.com/dpkp/kafka-python +Author: Dana Powers +Author-email: dana.powers@gmail.com +License: Apache License 2.0 +Keywords: apache kafka +Platform: UNKNOWN +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 2 +Classifier: Programming Language :: Python :: 2.7 +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.4 +Classifier: Programming Language :: Python :: 3.5 +Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Topic :: Software Development :: Libraries :: Python Modules +Provides-Extra: crc32c +Requires-Dist: crc32c ; extra == 'crc32c' + +Kafka Python client +------------------------ + +.. image:: https://img.shields.io/badge/kafka-2.4%2C%202.3%2C%202.2%2C%202.1%2C%202.0%2C%201.1%2C%201.0%2C%200.11%2C%200.10%2C%200.9%2C%200.8-brightgreen.svg + :target: https://kafka-python.readthedocs.io/en/master/compatibility.html +.. image:: https://img.shields.io/pypi/pyversions/kafka-python.svg + :target: https://pypi.python.org/pypi/kafka-python +.. image:: https://coveralls.io/repos/dpkp/kafka-python/badge.svg?branch=master&service=github + :target: https://coveralls.io/github/dpkp/kafka-python?branch=master +.. image:: https://travis-ci.org/dpkp/kafka-python.svg?branch=master + :target: https://travis-ci.org/dpkp/kafka-python +.. image:: https://img.shields.io/badge/license-Apache%202-blue.svg + :target: https://github.com/dpkp/kafka-python/blob/master/LICENSE + +Python client for the Apache Kafka distributed stream processing system. +kafka-python is designed to function much like the official java client, with a +sprinkling of pythonic interfaces (e.g., consumer iterators). + +kafka-python is best used with newer brokers (0.9+), but is backwards-compatible with +older versions (to 0.8.0). Some features will only be enabled on newer brokers. +For example, fully coordinated consumer groups -- i.e., dynamic partition +assignment to multiple consumers in the same group -- requires use of 0.9+ kafka +brokers. Supporting this feature for earlier broker releases would require +writing and maintaining custom leadership election and membership / health +check code (perhaps using zookeeper or consul). For older brokers, you can +achieve something similar by manually assigning different partitions to each +consumer instance with config management tools like chef, ansible, etc. This +approach will work fine, though it does not support rebalancing on failures. +See +for more details. + +Please note that the master branch may contain unreleased features. For release +documentation, please see readthedocs and/or python's inline help. 
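The kafka/scram.py module added earlier in this diff implements the client side of SCRAM-SHA-256/512 and is used by the connection code when a SCRAM mechanism is configured, so applications normally enable it purely through configuration. The vendored README that follows does not include a SASL example, so here is a minimal sketch in the same doctest style; the broker address, topic and credentials are placeholders, only the sasl_* / security_protocol option names come from kafka-python itself:

>>> from kafka import KafkaProducer
>>> producer = KafkaProducer(
...     bootstrap_servers='broker:9093',        # placeholder address
...     security_protocol='SASL_SSL',
...     sasl_mechanism='SCRAM-SHA-256',
...     sasl_plain_username='app-user',         # placeholder credentials
...     sasl_plain_password='app-secret')
>>> producer.send('healthcheck', b'ping')
>>> producer.flush()

The same options apply to KafkaConsumer, which performs the identical SCRAM exchange on connect.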
+ +>>> pip install kafka-python + +KafkaConsumer +************* + +KafkaConsumer is a high-level message consumer, intended to operate as similarly +as possible to the official java client. Full support for coordinated +consumer groups requires use of kafka brokers that support the Group APIs: kafka v0.9+. + +See +for API and configuration details. + +The consumer iterator returns ConsumerRecords, which are simple namedtuples +that expose basic message attributes: topic, partition, offset, key, and value: + +>>> from kafka import KafkaConsumer +>>> consumer = KafkaConsumer('my_favorite_topic') +>>> for msg in consumer: +... print (msg) + +>>> # join a consumer group for dynamic partition assignment and offset commits +>>> from kafka import KafkaConsumer +>>> consumer = KafkaConsumer('my_favorite_topic', group_id='my_favorite_group') +>>> for msg in consumer: +... print (msg) + +>>> # manually assign the partition list for the consumer +>>> from kafka import TopicPartition +>>> consumer = KafkaConsumer(bootstrap_servers='localhost:1234') +>>> consumer.assign([TopicPartition('foobar', 2)]) +>>> msg = next(consumer) + +>>> # Deserialize msgpack-encoded values +>>> consumer = KafkaConsumer(value_deserializer=msgpack.loads) +>>> consumer.subscribe(['msgpackfoo']) +>>> for msg in consumer: +... assert isinstance(msg.value, dict) + +>>> # Access record headers. The returned value is a list of tuples +>>> # with str, bytes for key and value +>>> for msg in consumer: +... print (msg.headers) + +>>> # Get consumer metrics +>>> metrics = consumer.metrics() + +KafkaProducer +************* + +KafkaProducer is a high-level, asynchronous message producer. The class is +intended to operate as similarly as possible to the official java client. +See +for more details. + +>>> from kafka import KafkaProducer +>>> producer = KafkaProducer(bootstrap_servers='localhost:1234') +>>> for _ in range(100): +... producer.send('foobar', b'some_message_bytes') + +>>> # Block until a single message is sent (or timeout) +>>> future = producer.send('foobar', b'another_message') +>>> result = future.get(timeout=60) + +>>> # Block until all pending messages are at least put on the network +>>> # NOTE: This does not guarantee delivery or success! It is really +>>> # only useful if you configure internal batching using linger_ms +>>> producer.flush() + +>>> # Use a key for hashed-partitioning +>>> producer.send('foobar', key=b'foo', value=b'bar') + +>>> # Serialize json messages +>>> import json +>>> producer = KafkaProducer(value_serializer=lambda v: json.dumps(v).encode('utf-8')) +>>> producer.send('fizzbuzz', {'foo': 'bar'}) + +>>> # Serialize string keys +>>> producer = KafkaProducer(key_serializer=str.encode) +>>> producer.send('flipflap', key='ping', value=b'1234') + +>>> # Compress messages +>>> producer = KafkaProducer(compression_type='gzip') +>>> for i in range(1000): +... producer.send('foobar', b'msg %d' % i) + +>>> # Include record headers. The format is list of tuples with string key +>>> # and bytes value. +>>> producer.send('foobar', value=b'c29tZSB2YWx1ZQ==', headers=[('content-encoding', b'base64')]) + +>>> # Get producer performance metrics +>>> metrics = producer.metrics() + +Thread safety +************* + +The KafkaProducer can be used across threads without issue, unlike the +KafkaConsumer which cannot. + +While it is possible to use the KafkaConsumer in a thread-local manner, +multiprocessing is recommended. + +Compression +*********** + +kafka-python supports gzip compression/decompression natively. 
To produce or consume lz4 +compressed messages, you should install python-lz4 (pip install lz4). +To enable snappy compression/decompression install python-snappy (also requires snappy library). +See +for more information. + +Optimized CRC32 Validation +************************** + +Kafka uses CRC32 checksums to validate messages. kafka-python includes a pure +python implementation for compatibility. To improve performance for high-throughput +applications, kafka-python will use `crc32c` for optimized native code if installed. +See https://pypi.org/project/crc32c/ + +Protocol +******** + +A secondary goal of kafka-python is to provide an easy-to-use protocol layer +for interacting with kafka brokers via the python repl. This is useful for +testing, probing, and general experimentation. The protocol support is +leveraged to enable a KafkaClient.check_version() method that +probes a kafka broker and attempts to identify which version it is running +(0.8.0 to 2.4+). + + diff --git a/venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/RECORD b/venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/RECORD new file mode 100644 index 0000000..7df7344 --- /dev/null +++ b/venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/RECORD @@ -0,0 +1,203 @@ +kafka/__init__.py,sha256=5Phe46DuaS980Pnma3UBEoYxwVag2IBNsQLw3lGygVk,1077 +kafka/__pycache__/__init__.cpython-312.pyc,, +kafka/__pycache__/client_async.cpython-312.pyc,, +kafka/__pycache__/cluster.cpython-312.pyc,, +kafka/__pycache__/codec.cpython-312.pyc,, +kafka/__pycache__/conn.cpython-312.pyc,, +kafka/__pycache__/errors.cpython-312.pyc,, +kafka/__pycache__/future.cpython-312.pyc,, +kafka/__pycache__/scram.cpython-312.pyc,, +kafka/__pycache__/structs.cpython-312.pyc,, +kafka/__pycache__/util.cpython-312.pyc,, +kafka/__pycache__/version.cpython-312.pyc,, +kafka/admin/__init__.py,sha256=S_XxqyyV480_yXhttK79XZqNAmZyXRjspd3SoqYykE8,720 +kafka/admin/__pycache__/__init__.cpython-312.pyc,, +kafka/admin/__pycache__/acl_resource.cpython-312.pyc,, +kafka/admin/__pycache__/client.cpython-312.pyc,, +kafka/admin/__pycache__/config_resource.cpython-312.pyc,, +kafka/admin/__pycache__/new_partitions.cpython-312.pyc,, +kafka/admin/__pycache__/new_topic.cpython-312.pyc,, +kafka/admin/acl_resource.py,sha256=ak_dUsSni4SyP0ORbSKenZpwTy0Ykxq3FSt_9XgLR8k,8265 +kafka/admin/client.py,sha256=4qr9DuDoDjvkPN8jn7dCw1vZtE-O1JbaLWlhp-j5fP4,63518 +kafka/admin/config_resource.py,sha256=_JZWN_Q7jbuTtq2kdfHxWyTt_jI1LI-xnVGsf6oYGyY,1039 +kafka/admin/new_partitions.py,sha256=rYSb7S6VL706ZauSmiN5J9GDsep0HYRmkkAZUgT2JIg,757 +kafka/admin/new_topic.py,sha256=fvezLP9JXumqX-nU27Fgo0tj4d85ybcJgKluQImm3-0,1306 +kafka/client_async.py,sha256=Tu0-OMb5IWJYupibS480mPUBYhiwVK9wCra75RoYZfw,45265 +kafka/cluster.py,sha256=mQTwoOgLtDj57DrQp1pQikmZVFDFuy9KW4ZKJd7l0_o,14822 +kafka/codec.py,sha256=IrYqQMWJ39V3kB6IXlqTGUQ4bFWENxxbCxMoCdZ0wkg,9548 +kafka/conn.py,sha256=zsYS6Fh2CCM3JWqixfi3sCQ19nFOwtg-nuen84Ak0Wc,68402 +kafka/consumer/__init__.py,sha256=NDdvtyuJgFyQZahqL9i5sYXGP6rOMIXWwHQEaZ1fCcs,122 +kafka/consumer/__pycache__/__init__.cpython-312.pyc,, +kafka/consumer/__pycache__/fetcher.cpython-312.pyc,, +kafka/consumer/__pycache__/group.cpython-312.pyc,, +kafka/consumer/__pycache__/subscription_state.cpython-312.pyc,, +kafka/consumer/fetcher.py,sha256=2ETL3j5Fsq3dRZ4iJ4INGNfAurhTGYLKzo0lM_LrBjQ,47679 +kafka/consumer/group.py,sha256=V4qpj6LmAJJTF4TvXzszKnt0Rd9p4T5Vrn4elftbpZg,58768 +kafka/consumer/subscription_state.py,sha256=2SgH37QISlIZh-v0KnNJW0n1d_sMLOxxW7UxkhsC5R0,21665 
+kafka/coordinator/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +kafka/coordinator/__pycache__/__init__.cpython-312.pyc,, +kafka/coordinator/__pycache__/base.cpython-312.pyc,, +kafka/coordinator/__pycache__/consumer.cpython-312.pyc,, +kafka/coordinator/__pycache__/heartbeat.cpython-312.pyc,, +kafka/coordinator/__pycache__/protocol.cpython-312.pyc,, +kafka/coordinator/assignors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +kafka/coordinator/assignors/__pycache__/__init__.cpython-312.pyc,, +kafka/coordinator/assignors/__pycache__/abstract.cpython-312.pyc,, +kafka/coordinator/assignors/__pycache__/range.cpython-312.pyc,, +kafka/coordinator/assignors/__pycache__/roundrobin.cpython-312.pyc,, +kafka/coordinator/assignors/abstract.py,sha256=belUnCkuw70HJ8HTWYIgVrT6pJmIBBrTl1vkO-bN1C0,1507 +kafka/coordinator/assignors/range.py,sha256=PXFkkb505pL1uJEQMTvXCOp0Rckm-qkoKqTGyn082qM,2912 +kafka/coordinator/assignors/roundrobin.py,sha256=Xt_TOvCtcdozjZSg1cxixLAPyWz1aTpDL8v1vDhX960,3776 +kafka/coordinator/assignors/sticky/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +kafka/coordinator/assignors/sticky/__pycache__/__init__.cpython-312.pyc,, +kafka/coordinator/assignors/sticky/__pycache__/partition_movements.cpython-312.pyc,, +kafka/coordinator/assignors/sticky/__pycache__/sorted_set.cpython-312.pyc,, +kafka/coordinator/assignors/sticky/__pycache__/sticky_assignor.cpython-312.pyc,, +kafka/coordinator/assignors/sticky/partition_movements.py,sha256=npydNO-YCG_cv--U--9CPTLGTbTWahiw_Ek295ayBjQ,6476 +kafka/coordinator/assignors/sticky/sorted_set.py,sha256=lOckfQ7vcOMNnIx5WjfHhKC_MgToeOxbp9vc_4tPIzs,1904 +kafka/coordinator/assignors/sticky/sticky_assignor.py,sha256=xNnyNy28vYWmOE8wSlxqdxHAaUX_PX6hQz8EoA_p62k,34114 +kafka/coordinator/base.py,sha256=zWIN0FI8KMzPJwoqfcIG_r6ppv-531HqwUyQGEbn0JE,46140 +kafka/coordinator/consumer.py,sha256=BCZlUHCf343K8w3CTd9Osqj9u1eE7b73wMa22_BHOVk,38920 +kafka/coordinator/heartbeat.py,sha256=WJqZGnXHG7TTq1Is3D0mKDis-bBwWVZlSgQiUoZv1jU,2304 +kafka/coordinator/protocol.py,sha256=wTaIOnUVbj0CKXZ82FktZo-zMRvOCk3hdQAoHJ62e3I,1041 +kafka/errors.py,sha256=MeK0fOtHbstT-HqseH9bBnlW9ZKv9envcrUDpWn4BdA,16324 +kafka/future.py,sha256=uJJLfKMFsdEHgHSyvCzQe_AXNrToiZE-MynZVNhk9qc,2474 +kafka/metrics/__init__.py,sha256=b82LCjV5BgisjmIc3pn11CqFpme5grtIFHWiH8C_R0U,574 +kafka/metrics/__pycache__/__init__.cpython-312.pyc,, +kafka/metrics/__pycache__/compound_stat.cpython-312.pyc,, +kafka/metrics/__pycache__/dict_reporter.cpython-312.pyc,, +kafka/metrics/__pycache__/kafka_metric.cpython-312.pyc,, +kafka/metrics/__pycache__/measurable.cpython-312.pyc,, +kafka/metrics/__pycache__/measurable_stat.cpython-312.pyc,, +kafka/metrics/__pycache__/metric_config.cpython-312.pyc,, +kafka/metrics/__pycache__/metric_name.cpython-312.pyc,, +kafka/metrics/__pycache__/metrics.cpython-312.pyc,, +kafka/metrics/__pycache__/metrics_reporter.cpython-312.pyc,, +kafka/metrics/__pycache__/quota.cpython-312.pyc,, +kafka/metrics/__pycache__/stat.cpython-312.pyc,, +kafka/metrics/compound_stat.py,sha256=CNnP71sNnViUhCDFHimdlXBb8G-PXrbqg6FfSS-SkVc,776 +kafka/metrics/dict_reporter.py,sha256=OvZ6SUFp-Yk3tNaWbC0ul0WXncp42ymg8bHw3O6MITA,2567 +kafka/metrics/kafka_metric.py,sha256=fnkHEmooLjCHRoAtti6rOymQLLMN1D276ma1bYAFJDY,933 +kafka/metrics/measurable.py,sha256=g5mp1c9816SRgJdgHXklTNqDoDnbeYp-opjoV3DOr7Q,770 +kafka/metrics/measurable_stat.py,sha256=NcOQfOieQV8m6mMClDFJDY1ibE-RmIPrth15W5XPDdU,503 
+kafka/metrics/metric_config.py,sha256=SsibZd09icYgqLrMhXXW-pQVICPn0yYADrD8txdIMw0,1154 +kafka/metrics/metric_name.py,sha256=l25XFsjpOK6nv4Al_bRKFtt-UHeeqmkhBhaEfGMp9Qo,3419 +kafka/metrics/metrics.py,sha256=hEBGp8afj39FllLV021g6c8cZ2_KqwfWiUnRzc7SdlE,10314 +kafka/metrics/metrics_reporter.py,sha256=2qZRLiyOUzB-2ULBtOhXOtjU9phElIlundjPluYYXgE,1398 +kafka/metrics/quota.py,sha256=34psI-neVNQ-VeaD2KMvpx5bBIJp4rJcsJ6rvC91Wgk,1128 +kafka/metrics/stat.py,sha256=T_YGImowGnUnGgeNZ-r4buk1PdM_7NHG15PzTHieyZo,628 +kafka/metrics/stats/__init__.py,sha256=sHcT6pvQCt-s_aow5_QRy9Z5bRV4ShBCZlin51f--Ro,629 +kafka/metrics/stats/__pycache__/__init__.cpython-312.pyc,, +kafka/metrics/stats/__pycache__/avg.cpython-312.pyc,, +kafka/metrics/stats/__pycache__/count.cpython-312.pyc,, +kafka/metrics/stats/__pycache__/histogram.cpython-312.pyc,, +kafka/metrics/stats/__pycache__/max_stat.cpython-312.pyc,, +kafka/metrics/stats/__pycache__/min_stat.cpython-312.pyc,, +kafka/metrics/stats/__pycache__/percentile.cpython-312.pyc,, +kafka/metrics/stats/__pycache__/percentiles.cpython-312.pyc,, +kafka/metrics/stats/__pycache__/rate.cpython-312.pyc,, +kafka/metrics/stats/__pycache__/sampled_stat.cpython-312.pyc,, +kafka/metrics/stats/__pycache__/sensor.cpython-312.pyc,, +kafka/metrics/stats/__pycache__/total.cpython-312.pyc,, +kafka/metrics/stats/avg.py,sha256=WdyAFz37aQhvzIqkvbP4SGUDz9gZ-eua_Urhygjc2xU,678 +kafka/metrics/stats/count.py,sha256=dy5sdPVLOwsiVcfOawEx7EOyjSTXxUKqsJl84sjVZbQ,487 +kafka/metrics/stats/histogram.py,sha256=5-1V_juSsWUV-0e0F4egSYjJvXcBOJF-fzLyBXhywts,2874 +kafka/metrics/stats/max_stat.py,sha256=jBkG-ozWpsH1qmPNzhiICpJVKxwbep-VHwOEGRnCHJo,546 +kafka/metrics/stats/min_stat.py,sha256=gI0d7RUJB5En7PS_TT3WZ_gJl8tOZi81o-F0JD3S_oc,568 +kafka/metrics/stats/percentile.py,sha256=ZQoda6vpS9v5LopQJL64SyCWO9160SVELxQ9S1KKit8,342 +kafka/metrics/stats/percentiles.py,sha256=n4Uqt7qyRUkrkOWZvymfKx-7ANvopDXgLeXH1QRC_rk,2901 +kafka/metrics/stats/rate.py,sha256=-zkYp8kZrhy01hDaPCYcKqvRycY1esyNPRQrqK_JH5s,4533 +kafka/metrics/stats/sampled_stat.py,sha256=rb42q6MIkAm2LJh4H4QoC6OmP_zJ-3jbBrva0kf_0J0,3458 +kafka/metrics/stats/sensor.py,sha256=sxX2SxkTOuLA-VPIRu4LyJnJYsqpvEybd6oiVS2Lf5Y,5129 +kafka/metrics/stats/total.py,sha256=tUq8rPW96OVzjVz0aOsBkEe2Ljkv6JaBa5TMCocYydg,418 +kafka/oauth/__init__.py,sha256=nNNQI8KQjCXCEdSrgnvyQFLTQqGk0x7J1MNu82PA0i0,95 +kafka/oauth/__pycache__/__init__.cpython-312.pyc,, +kafka/oauth/__pycache__/abstract.cpython-312.pyc,, +kafka/oauth/abstract.py,sha256=g-6pEw5amtXfTfBcV6quD-BVuoNH4SrFPw_kvgWPmkY,1296 +kafka/partitioner/__init__.py,sha256=Fks3C5_kokVWYw1Ad5wv0sVVzaaBtOejL-2bIL1yRII,158 +kafka/partitioner/__pycache__/__init__.cpython-312.pyc,, +kafka/partitioner/__pycache__/default.cpython-312.pyc,, +kafka/partitioner/default.py,sha256=tW-RC1PWIPRDEbeEAaPTLn-00oiZnXoVouEk9AnYE4w,2879 +kafka/producer/__init__.py,sha256=i3Wxih0NHjmqCkRNE54ial8fBp9siqabUE6ZGyL6oX8,122 +kafka/producer/__pycache__/__init__.cpython-312.pyc,, +kafka/producer/__pycache__/buffer.cpython-312.pyc,, +kafka/producer/__pycache__/future.cpython-312.pyc,, +kafka/producer/__pycache__/kafka.cpython-312.pyc,, +kafka/producer/__pycache__/record_accumulator.cpython-312.pyc,, +kafka/producer/__pycache__/sender.cpython-312.pyc,, +kafka/producer/buffer.py,sha256=1ucTlZOQKBa37c_cKUNgFmHpO0P1WEQ9XDqTxmsOrG0,4370 +kafka/producer/future.py,sha256=CEUWEmYKeTMMPjP-SjSJY1RZ2QFn7ebcK0G0sSWx4xo,3039 +kafka/producer/kafka.py,sha256=Am_Tm2FtDcgqBDdqAq31bJnsRlqqxYS5_LX7MVObGok,37649 
+kafka/producer/record_accumulator.py,sha256=SyvYJVD7J1s4G2omjkaO8-Q6Yn2MukGTWfSEnFQlIfY,24994 +kafka/producer/sender.py,sha256=vqmozAfH6WIxE5LctZj16Cux--ASx87wIYCnU0nJEXI,22968 +kafka/protocol/__init__.py,sha256=6LgsMXp87XMcvBCRNIwaOauyBdUqsyWOdNAlsAA4zxY,1075 +kafka/protocol/__pycache__/__init__.cpython-312.pyc,, +kafka/protocol/__pycache__/abstract.cpython-312.pyc,, +kafka/protocol/__pycache__/admin.cpython-312.pyc,, +kafka/protocol/__pycache__/api.cpython-312.pyc,, +kafka/protocol/__pycache__/commit.cpython-312.pyc,, +kafka/protocol/__pycache__/fetch.cpython-312.pyc,, +kafka/protocol/__pycache__/frame.cpython-312.pyc,, +kafka/protocol/__pycache__/group.cpython-312.pyc,, +kafka/protocol/__pycache__/message.cpython-312.pyc,, +kafka/protocol/__pycache__/metadata.cpython-312.pyc,, +kafka/protocol/__pycache__/offset.cpython-312.pyc,, +kafka/protocol/__pycache__/parser.cpython-312.pyc,, +kafka/protocol/__pycache__/pickle.cpython-312.pyc,, +kafka/protocol/__pycache__/produce.cpython-312.pyc,, +kafka/protocol/__pycache__/struct.cpython-312.pyc,, +kafka/protocol/__pycache__/types.cpython-312.pyc,, +kafka/protocol/abstract.py,sha256=LUYVZkjlEnZzvklkgrsfz8iZKNSFhS8cP-Q-N0jqdQo,385 +kafka/protocol/admin.py,sha256=6ncxMhsX6pSJI5eiCih0RRYE-3bkeKbVuiMA8LptYzA,25122 +kafka/protocol/api.py,sha256=xCwwkFasFBnzG7ER4a-dN40NmgkxLlwpu__ADXfkOz4,2493 +kafka/protocol/commit.py,sha256=_aztH5jgEdkIwmp7HF4F96N90_0s3Cbb-O-r634m2HI,6888 +kafka/protocol/fetch.py,sha256=WEYHr2MINaKIqtN1tRH9Iui7xQiWot333TkgtjOpR9E,11014 +kafka/protocol/frame.py,sha256=SejRBK5urTD-2UzcVM2OxTgC73qDxfF3nlBPl9sjsfY,734 +kafka/protocol/group.py,sha256=sLQYQjPukVHK63UM1wt-YD2CniI7CO8rrG3tLi4zdIs,5599 +kafka/protocol/message.py,sha256=9wNwJvfl9bsrdk_YcxbmAFjgvwZ5R1EBLSif2KILg9s,7657 +kafka/protocol/metadata.py,sha256=MgCDeXcMipy2kLxOuwslk-7qivPzvC9EpyacTfaXRvE,6116 +kafka/protocol/offset.py,sha256=o3MXGbiezLNcEmnQRhlCPJsmUxYoiIgWXgrVoU6ilB4,4707 +kafka/protocol/parser.py,sha256=T6C_UWOSIbbyfRihvaqLtyCiI0QnUTKha-OMiXFMj1w,6963 +kafka/protocol/pickle.py,sha256=FGEv-1l1aXY3TogqzCwOS1gCNpEg6-xNLbrysqNdHcs,920 +kafka/protocol/produce.py,sha256=Bd8tgRly7mknYcfrqlRk07vl--daLwf-nTuvGcGB1k0,6460 +kafka/protocol/struct.py,sha256=DxktwrPp1pj4b7Vne2H5n-xWjgx9jpCmf0ydZkeIjoY,2380 +kafka/protocol/types.py,sha256=KghOyWIU5Qkj7ZkP4G3AW7ILMJ9s5QZMrRO_8rDkPFg,5427 +kafka/record/__init__.py,sha256=Q20hP_R5XX3AEnAlPkpoWzTLShESvxUT2OLXmI-JYEQ,129 +kafka/record/__pycache__/__init__.cpython-312.pyc,, +kafka/record/__pycache__/_crc32c.cpython-312.pyc,, +kafka/record/__pycache__/abc.cpython-312.pyc,, +kafka/record/__pycache__/default_records.cpython-312.pyc,, +kafka/record/__pycache__/legacy_records.cpython-312.pyc,, +kafka/record/__pycache__/memory_records.cpython-312.pyc,, +kafka/record/__pycache__/util.cpython-312.pyc,, +kafka/record/_crc32c.py,sha256=Nr0O4kpyPg379hDY-svjnM4CSKtlEBhanxgb5Y3PlEQ,5753 +kafka/record/abc.py,sha256=YoOlVaBtWn8gLcJusLbo3zZhd9BI8aZGc06NHaDIGxI,3465 +kafka/record/default_records.py,sha256=zl9dWeUap6pGZt2ixGxtjtnB5y9HKWi05qN2zAC0Cz0,21023 +kafka/record/legacy_records.py,sha256=4XaxdFoWTSHoFiBA5i-Dvscy2pj83CdteLWD8BcHYJQ,17820 +kafka/record/memory_records.py,sha256=VjUsbLtIU0y5HM9eYlUjjpCFMwTej67uhrIL5CgtYq8,6344 +kafka/record/util.py,sha256=LDajBWdYVetmXts_t9Q76CxEx7njgC9LnjMgz9yPEMM,3556 +kafka/scram.py,sha256=Ei9FPJ3ajfTQRRGaDs1RjDKTpplkoPQKvk-u6Dkbh_U,3034 +kafka/serializer/__init__.py,sha256=_I4utl_8nNhcRzLLezFtwYX5akk6QKYmxa1HanRlYPU,103 +kafka/serializer/__pycache__/__init__.cpython-312.pyc,, 
+kafka/serializer/__pycache__/abstract.cpython-312.pyc,, +kafka/serializer/abstract.py,sha256=doiXDkMYt2SEHRarBdd8xVZKvr5S1qPdNEtl4syWA6Q,486 +kafka/structs.py,sha256=m2o20GOJBDJIiP7YUj1Lhk2bAXKLt9H48NloBJ39Om8,2927 +kafka/util.py,sha256=nu0h9bXBv6Hl8v7MW07o8NFe4zoZNw6C6ehFBPazOpU,1856 +kafka/vendor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +kafka/vendor/__pycache__/__init__.cpython-312.pyc,, +kafka/vendor/__pycache__/enum34.cpython-312.pyc,, +kafka/vendor/__pycache__/selectors34.cpython-312.pyc,, +kafka/vendor/__pycache__/six.cpython-312.pyc,, +kafka/vendor/__pycache__/socketpair.cpython-312.pyc,, +kafka/vendor/enum34.py,sha256=-u-lxAiJMt6ru4Do7NUDY9OpeWkYJMksb2xengJawFE,31204 +kafka/vendor/selectors34.py,sha256=40NdCvzBONYxE_IEQlvLma7Zftl_pCnkUEQPGTT_JOk,20502 +kafka/vendor/six.py,sha256=rz93m7VnaruMHlKvgxBfTW8VjgTmLoeY42OunxMCxoY,31133 +kafka/vendor/socketpair.py,sha256=xz_yjMNpIY5cO4eh7oqyU9caK9kJRnbeJtV1lGb0Sv8,2127 +kafka/version.py,sha256=kumiGImhzOTlTrRM-6jDo2mNnVHGO_2vxtrhB0nzAiw,22 +kafka_python-2.0.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +kafka_python-2.0.2.dist-info/LICENSE,sha256=KrcY2_AbbyVIpWwikBS96hmNAMYJqbtevJ9ghAvdT-w,11343 +kafka_python-2.0.2.dist-info/METADATA,sha256=hHQUfDUVxFLTZjLQL64onhVht0WQuGh_gcBY5cK0XCI,7807 +kafka_python-2.0.2.dist-info/RECORD,, +kafka_python-2.0.2.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +kafka_python-2.0.2.dist-info/WHEEL,sha256=HX-v9-noUkyUoxyZ1PMSuS7auUxDAR4VBdoYLqD0xws,110 +kafka_python-2.0.2.dist-info/top_level.txt,sha256=IivJz7l5WHdLNDT6RIiVAlhjQzYRwGqBBmKHZ7WjPeM,6 diff --git a/venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/REQUESTED b/venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/REQUESTED similarity index 100% rename from venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/REQUESTED rename to venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/REQUESTED diff --git a/venv/lib/python3.12/site-packages/kafka_python-2.2.15.dist-info/WHEEL b/venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/WHEEL similarity index 70% rename from venv/lib/python3.12/site-packages/kafka_python-2.2.15.dist-info/WHEEL rename to venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/WHEEL index 5f133db..c8240f0 100644 --- a/venv/lib/python3.12/site-packages/kafka_python-2.2.15.dist-info/WHEEL +++ b/venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/WHEEL @@ -1,5 +1,5 @@ Wheel-Version: 1.0 -Generator: setuptools (80.9.0) +Generator: bdist_wheel (0.33.1) Root-Is-Purelib: true Tag: py2-none-any Tag: py3-none-any diff --git a/venv/lib/python3.12/site-packages/kafka_python-2.2.15.dist-info/top_level.txt b/venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/top_level.txt similarity index 100% rename from venv/lib/python3.12/site-packages/kafka_python-2.2.15.dist-info/top_level.txt rename to venv/lib/python3.12/site-packages/kafka_python-2.0.2.dist-info/top_level.txt diff --git a/venv/lib/python3.12/site-packages/kafka_python-2.2.15.dist-info/METADATA b/venv/lib/python3.12/site-packages/kafka_python-2.2.15.dist-info/METADATA deleted file mode 100644 index 7f86dd7..0000000 --- a/venv/lib/python3.12/site-packages/kafka_python-2.2.15.dist-info/METADATA +++ /dev/null @@ -1,277 +0,0 @@ -Metadata-Version: 2.4 -Name: kafka-python -Version: 2.2.15 -Summary: Pure Python client for Apache Kafka -Author-email: Dana Powers -Project-URL: Homepage, https://github.com/dpkp/kafka-python 
-Keywords: apache kafka,kafka -Classifier: Development Status :: 5 - Production/Stable -Classifier: Intended Audience :: Developers -Classifier: License :: OSI Approved :: Apache Software License -Classifier: Programming Language :: Python -Classifier: Programming Language :: Python :: 2 -Classifier: Programming Language :: Python :: 2.7 -Classifier: Programming Language :: Python :: 3 -Classifier: Programming Language :: Python :: 3.4 -Classifier: Programming Language :: Python :: 3.5 -Classifier: Programming Language :: Python :: 3.6 -Classifier: Programming Language :: Python :: 3.7 -Classifier: Programming Language :: Python :: 3.8 -Classifier: Programming Language :: Python :: 3.9 -Classifier: Programming Language :: Python :: 3.10 -Classifier: Programming Language :: Python :: 3.11 -Classifier: Programming Language :: Python :: 3.12 -Classifier: Programming Language :: Python :: 3.13 -Classifier: Programming Language :: Python :: Implementation :: CPython -Classifier: Programming Language :: Python :: Implementation :: PyPy -Classifier: Topic :: Software Development :: Libraries :: Python Modules -Description-Content-Type: text/x-rst -Provides-Extra: crc32c -Requires-Dist: crc32c; extra == "crc32c" -Provides-Extra: lz4 -Requires-Dist: lz4; extra == "lz4" -Provides-Extra: snappy -Requires-Dist: python-snappy; extra == "snappy" -Provides-Extra: zstd -Requires-Dist: zstandard; extra == "zstd" -Provides-Extra: testing -Requires-Dist: pytest; extra == "testing" -Requires-Dist: mock; python_version < "3.3" and extra == "testing" -Requires-Dist: pytest-mock; extra == "testing" -Requires-Dist: pytest-timeout; extra == "testing" -Provides-Extra: benchmarks -Requires-Dist: pyperf; extra == "benchmarks" - -Kafka Python client ------------------------- - -.. image:: https://img.shields.io/badge/kafka-4.0--0.8-brightgreen.svg - :target: https://kafka-python.readthedocs.io/en/master/compatibility.html -.. image:: https://img.shields.io/pypi/pyversions/kafka-python.svg - :target: https://pypi.python.org/pypi/kafka-python -.. image:: https://coveralls.io/repos/dpkp/kafka-python/badge.svg?branch=master&service=github - :target: https://coveralls.io/github/dpkp/kafka-python?branch=master -.. image:: https://img.shields.io/badge/license-Apache%202-blue.svg - :target: https://github.com/dpkp/kafka-python/blob/master/LICENSE -.. image:: https://img.shields.io/pypi/dw/kafka-python.svg - :target: https://pypistats.org/packages/kafka-python -.. image:: https://img.shields.io/pypi/v/kafka-python.svg - :target: https://pypi.org/project/kafka-python -.. image:: https://img.shields.io/pypi/implementation/kafka-python - :target: https://github.com/dpkp/kafka-python/blob/master/setup.py - - - -Python client for the Apache Kafka distributed stream processing system. -kafka-python is designed to function much like the official java client, with a -sprinkling of pythonic interfaces (e.g., consumer iterators). - -kafka-python is best used with newer brokers (0.9+), but is backwards-compatible with -older versions (to 0.8.0). Some features will only be enabled on newer brokers. -For example, fully coordinated consumer groups -- i.e., dynamic partition -assignment to multiple consumers in the same group -- requires use of 0.9+ kafka -brokers. Supporting this feature for earlier broker releases would require -writing and maintaining custom leadership election and membership / health -check code (perhaps using zookeeper or consul). 
For older brokers, you can -achieve something similar by manually assigning different partitions to each -consumer instance with config management tools like chef, ansible, etc. This -approach will work fine, though it does not support rebalancing on failures. -See https://kafka-python.readthedocs.io/en/master/compatibility.html -for more details. - -Please note that the master branch may contain unreleased features. For release -documentation, please see readthedocs and/or python's inline help. - -.. code-block:: bash - - $ pip install kafka-python - - -KafkaConsumer -************* - -KafkaConsumer is a high-level message consumer, intended to operate as similarly -as possible to the official java client. Full support for coordinated -consumer groups requires use of kafka brokers that support the Group APIs: kafka v0.9+. - -See https://kafka-python.readthedocs.io/en/master/apidoc/KafkaConsumer.html -for API and configuration details. - -The consumer iterator returns ConsumerRecords, which are simple namedtuples -that expose basic message attributes: topic, partition, offset, key, and value: - -.. code-block:: python - - from kafka import KafkaConsumer - consumer = KafkaConsumer('my_favorite_topic') - for msg in consumer: - print (msg) - -.. code-block:: python - - # join a consumer group for dynamic partition assignment and offset commits - from kafka import KafkaConsumer - consumer = KafkaConsumer('my_favorite_topic', group_id='my_favorite_group') - for msg in consumer: - print (msg) - -.. code-block:: python - - # manually assign the partition list for the consumer - from kafka import TopicPartition - consumer = KafkaConsumer(bootstrap_servers='localhost:1234') - consumer.assign([TopicPartition('foobar', 2)]) - msg = next(consumer) - -.. code-block:: python - - # Deserialize msgpack-encoded values - consumer = KafkaConsumer(value_deserializer=msgpack.loads) - consumer.subscribe(['msgpackfoo']) - for msg in consumer: - assert isinstance(msg.value, dict) - -.. code-block:: python - - # Access record headers. The returned value is a list of tuples - # with str, bytes for key and value - for msg in consumer: - print (msg.headers) - -.. code-block:: python - - # Read only committed messages from transactional topic - consumer = KafkaConsumer(isolation_level='read_committed') - consumer.subscribe(['txn_topic']) - for msg in consumer: - print(msg) - -.. code-block:: python - - # Get consumer metrics - metrics = consumer.metrics() - - -KafkaProducer -************* - -KafkaProducer is a high-level, asynchronous message producer. The class is -intended to operate as similarly as possible to the official java client. -See https://kafka-python.readthedocs.io/en/master/apidoc/KafkaProducer.html -for more details. - -.. code-block:: python - - from kafka import KafkaProducer - producer = KafkaProducer(bootstrap_servers='localhost:1234') - for _ in range(100): - producer.send('foobar', b'some_message_bytes') - -.. code-block:: python - - # Block until a single message is sent (or timeout) - future = producer.send('foobar', b'another_message') - result = future.get(timeout=60) - -.. code-block:: python - - # Block until all pending messages are at least put on the network - # NOTE: This does not guarantee delivery or success! It is really - # only useful if you configure internal batching using linger_ms - producer.flush() - -.. code-block:: python - - # Use a key for hashed-partitioning - producer.send('foobar', key=b'foo', value=b'bar') - -.. 
code-block:: python - - # Serialize json messages - import json - producer = KafkaProducer(value_serializer=lambda v: json.dumps(v).encode('utf-8')) - producer.send('fizzbuzz', {'foo': 'bar'}) - -.. code-block:: python - - # Serialize string keys - producer = KafkaProducer(key_serializer=str.encode) - producer.send('flipflap', key='ping', value=b'1234') - -.. code-block:: python - - # Compress messages - producer = KafkaProducer(compression_type='gzip') - for i in range(1000): - producer.send('foobar', b'msg %d' % i) - -.. code-block:: python - - # Use transactions - producer = KafkaProducer(transactional_id='fizzbuzz') - producer.init_transactions() - producer.begin_transaction() - future = producer.send('txn_topic', value=b'yes') - future.get() # wait for successful produce - producer.commit_transaction() # commit the transaction - - producer.begin_transaction() - future = producer.send('txn_topic', value=b'no') - future.get() # wait for successful produce - producer.abort_transaction() # abort the transaction - -.. code-block:: python - - # Include record headers. The format is list of tuples with string key - # and bytes value. - producer.send('foobar', value=b'c29tZSB2YWx1ZQ==', headers=[('content-encoding', b'base64')]) - -.. code-block:: python - - # Get producer performance metrics - metrics = producer.metrics() - - -Thread safety -************* - -The KafkaProducer can be used across threads without issue, unlike the -KafkaConsumer which cannot. - -While it is possible to use the KafkaConsumer in a thread-local manner, -multiprocessing is recommended. - - -Compression -*********** - -kafka-python supports the following compression formats: - -- gzip -- LZ4 -- Snappy -- Zstandard (zstd) - -gzip is supported natively, the others require installing additional libraries. -See https://kafka-python.readthedocs.io/en/master/install.html for more information. - - -Optimized CRC32 Validation -************************** - -Kafka uses CRC32 checksums to validate messages. kafka-python includes a pure -python implementation for compatibility. To improve performance for high-throughput -applications, kafka-python will use `crc32c` for optimized native code if installed. -See https://kafka-python.readthedocs.io/en/master/install.html for installation instructions. -See https://pypi.org/project/crc32c/ for details on the underlying crc32c lib. - - -Protocol -******** - -A secondary goal of kafka-python is to provide an easy-to-use protocol layer -for interacting with kafka brokers via the python repl. This is useful for -testing, probing, and general experimentation. The protocol support is -leveraged to enable a KafkaClient.check_version() method that -probes a kafka broker and attempts to identify which version it is running -(0.8.0 to 2.6+). 
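For context on the kafka-python usage described in the METADATA removed above (the producer's send()/flush() calls and the consumer iterator), here is a minimal round-trip sketch. It is an illustration only, not part of this change: the broker address localhost:9092 and the topic name demo_topic are hypothetical placeholders, and only calls documented in that README (KafkaProducer.send, future.get, flush, value_serializer/value_deserializer, and the KafkaConsumer iterator) are used.

    import json

    from kafka import KafkaConsumer, KafkaProducer

    BOOTSTRAP = "localhost:9092"   # hypothetical broker address, not from this repo
    TOPIC = "demo_topic"           # hypothetical topic name, not from this repo

    # Producer: JSON-serialize values, send one record, and confirm delivery.
    producer = KafkaProducer(
        bootstrap_servers=BOOTSTRAP,
        value_serializer=lambda v: json.dumps(v).encode("utf-8"),
    )
    future = producer.send(TOPIC, {"event": "ping"})
    record_metadata = future.get(timeout=10)   # raises on delivery failure
    producer.flush()                           # drain any remaining batched records

    # Consumer: read the topic from the beginning, stop after 5s of silence.
    consumer = KafkaConsumer(
        TOPIC,
        bootstrap_servers=BOOTSTRAP,
        auto_offset_reset="earliest",
        value_deserializer=lambda v: json.loads(v.decode("utf-8")),
        consumer_timeout_ms=5000,
    )
    for msg in consumer:
        print(msg.topic, msg.partition, msg.offset, msg.value)

As the README above notes, flush() only guarantees that buffered records were handed to the network, not that they were delivered; the future.get() call is what confirms the broker acknowledged the record.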
diff --git a/venv/lib/python3.12/site-packages/kafka_python-2.2.15.dist-info/RECORD b/venv/lib/python3.12/site-packages/kafka_python-2.2.15.dist-info/RECORD deleted file mode 100644 index 14b6ad4..0000000 --- a/venv/lib/python3.12/site-packages/kafka_python-2.2.15.dist-info/RECORD +++ /dev/null @@ -1,250 +0,0 @@ -kafka/__init__.py,sha256=4dvHKZAxmD_4tfJ5wGcRV2X78vPcm8vsUoqceULevjA,1077 -kafka/__pycache__/__init__.cpython-312.pyc,, -kafka/__pycache__/client_async.cpython-312.pyc,, -kafka/__pycache__/cluster.cpython-312.pyc,, -kafka/__pycache__/codec.cpython-312.pyc,, -kafka/__pycache__/conn.cpython-312.pyc,, -kafka/__pycache__/errors.cpython-312.pyc,, -kafka/__pycache__/future.cpython-312.pyc,, -kafka/__pycache__/socks5_wrapper.cpython-312.pyc,, -kafka/__pycache__/structs.cpython-312.pyc,, -kafka/__pycache__/util.cpython-312.pyc,, -kafka/__pycache__/version.cpython-312.pyc,, -kafka/admin/__init__.py,sha256=S_XxqyyV480_yXhttK79XZqNAmZyXRjspd3SoqYykE8,720 -kafka/admin/__pycache__/__init__.cpython-312.pyc,, -kafka/admin/__pycache__/acl_resource.cpython-312.pyc,, -kafka/admin/__pycache__/client.cpython-312.pyc,, -kafka/admin/__pycache__/config_resource.cpython-312.pyc,, -kafka/admin/__pycache__/new_partitions.cpython-312.pyc,, -kafka/admin/__pycache__/new_topic.cpython-312.pyc,, -kafka/admin/acl_resource.py,sha256=ak_dUsSni4SyP0ORbSKenZpwTy0Ykxq3FSt_9XgLR8k,8265 -kafka/admin/client.py,sha256=94UpHTsgzvhOoB6_1QLeKxvZKlStKfI96xuWyaY9_Sc,78814 -kafka/admin/config_resource.py,sha256=_JZWN_Q7jbuTtq2kdfHxWyTt_jI1LI-xnVGsf6oYGyY,1039 -kafka/admin/new_partitions.py,sha256=rYSb7S6VL706ZauSmiN5J9GDsep0HYRmkkAZUgT2JIg,757 -kafka/admin/new_topic.py,sha256=fvezLP9JXumqX-nU27Fgo0tj4d85ybcJgKluQImm3-0,1306 -kafka/benchmarks/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -kafka/benchmarks/__pycache__/__init__.cpython-312.pyc,, -kafka/benchmarks/__pycache__/consumer_performance.cpython-312.pyc,, -kafka/benchmarks/__pycache__/load_example.cpython-312.pyc,, -kafka/benchmarks/__pycache__/producer_performance.cpython-312.pyc,, -kafka/benchmarks/__pycache__/record_batch_compose.cpython-312.pyc,, -kafka/benchmarks/__pycache__/record_batch_read.cpython-312.pyc,, -kafka/benchmarks/__pycache__/varint_speed.cpython-312.pyc,, -kafka/benchmarks/consumer_performance.py,sha256=UFW2rVHX4rdwLRRQqsoUoMR7FbA9hwYsCNkQA1qNvuQ,4932 -kafka/benchmarks/load_example.py,sha256=feaU2Qic11hZfi3rKTI4Fezxmu-kvNz17m2wJmZMjmw,3491 -kafka/benchmarks/producer_performance.py,sha256=jy1Q4zyamPrluh3SUKxiH3z6wY-8sSFG3yJvJbnUFO0,5210 -kafka/benchmarks/record_batch_compose.py,sha256=CnUreNg1lUT0Qx9enmSr-THmBl9PjVMfaB0tsIFjFr8,2057 -kafka/benchmarks/record_batch_read.py,sha256=vlFaWU2YWI379n_2M8qieb_S2uHUWKV0NquEYy5b-Ho,2184 -kafka/benchmarks/varint_speed.py,sha256=s4CuvKgDZL-_zna5E3vM8RgHjhXuW6pcaO1z1WYZ_0Y,12585 -kafka/client_async.py,sha256=R8q_rRpG3RrYrRmcZo7XgO2oSdpLJATNcq8w-1vIJ_8,56878 -kafka/cluster.py,sha256=B4tOZYhZaYrcGsyAtdA8yejFm9ue7ElJxn_pd6Xhdfk,16775 -kafka/codec.py,sha256=8NZpnehzNrhSBIjzbPVSvyFbSeLAqEntE7BfVHu-_9I,10036 -kafka/conn.py,sha256=_yP-pGwEbkDmeutMOZjVilQXAnF4PWF_CDc60qC3DuE,69488 -kafka/consumer/__init__.py,sha256=NDdvtyuJgFyQZahqL9i5sYXGP6rOMIXWwHQEaZ1fCcs,122 -kafka/consumer/__pycache__/__init__.cpython-312.pyc,, -kafka/consumer/__pycache__/fetcher.cpython-312.pyc,, -kafka/consumer/__pycache__/group.cpython-312.pyc,, -kafka/consumer/__pycache__/subscription_state.cpython-312.pyc,, -kafka/consumer/fetcher.py,sha256=RlQLut54c5nOMl21neTJA2tmdsxIIPIX2Idu5Q-dYKY,69184 
-kafka/consumer/group.py,sha256=1_4qES7x3XyAHjVbFZ_E0ilAoueyaeHiGpNgggYLGiQ,58945 -kafka/consumer/subscription_state.py,sha256=bK-YTVbOzhy8OB206QAfXVuo7zPA9YqYXnrRRST369c,24289 -kafka/coordinator/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -kafka/coordinator/__pycache__/__init__.cpython-312.pyc,, -kafka/coordinator/__pycache__/base.cpython-312.pyc,, -kafka/coordinator/__pycache__/consumer.cpython-312.pyc,, -kafka/coordinator/__pycache__/heartbeat.cpython-312.pyc,, -kafka/coordinator/__pycache__/protocol.cpython-312.pyc,, -kafka/coordinator/assignors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -kafka/coordinator/assignors/__pycache__/__init__.cpython-312.pyc,, -kafka/coordinator/assignors/__pycache__/abstract.cpython-312.pyc,, -kafka/coordinator/assignors/__pycache__/range.cpython-312.pyc,, -kafka/coordinator/assignors/__pycache__/roundrobin.cpython-312.pyc,, -kafka/coordinator/assignors/abstract.py,sha256=belUnCkuw70HJ8HTWYIgVrT6pJmIBBrTl1vkO-bN1C0,1507 -kafka/coordinator/assignors/range.py,sha256=PXFkkb505pL1uJEQMTvXCOp0Rckm-qkoKqTGyn082qM,2912 -kafka/coordinator/assignors/roundrobin.py,sha256=Xt_TOvCtcdozjZSg1cxixLAPyWz1aTpDL8v1vDhX960,3776 -kafka/coordinator/assignors/sticky/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -kafka/coordinator/assignors/sticky/__pycache__/__init__.cpython-312.pyc,, -kafka/coordinator/assignors/sticky/__pycache__/partition_movements.cpython-312.pyc,, -kafka/coordinator/assignors/sticky/__pycache__/sorted_set.cpython-312.pyc,, -kafka/coordinator/assignors/sticky/__pycache__/sticky_assignor.cpython-312.pyc,, -kafka/coordinator/assignors/sticky/partition_movements.py,sha256=npydNO-YCG_cv--U--9CPTLGTbTWahiw_Ek295ayBjQ,6476 -kafka/coordinator/assignors/sticky/sorted_set.py,sha256=lOckfQ7vcOMNnIx5WjfHhKC_MgToeOxbp9vc_4tPIzs,1904 -kafka/coordinator/assignors/sticky/sticky_assignor.py,sha256=p5gDou3Gom7bUSLF5zpilihNPiT-fqJl1J8QxykqqMw,34216 -kafka/coordinator/base.py,sha256=hXfwtDkrHXHiNqjshCOa19js_2Y6ibLsdzDvJKGmcKc,54419 -kafka/coordinator/consumer.py,sha256=le4bGbHfrDK4pperYXekPKzuZW576uXL324IOwS4Kmw,46348 -kafka/coordinator/heartbeat.py,sha256=LeJJlwz1oUEOfEMIFT-R7ZOHBQ-b-luVKwmKyWxLfDo,3242 -kafka/coordinator/protocol.py,sha256=wTaIOnUVbj0CKXZ82FktZo-zMRvOCk3hdQAoHJ62e3I,1041 -kafka/errors.py,sha256=qX2Fp0qawU_HBNcZCwB7EDCmx3C2PehrETi6qSEJHmk,33290 -kafka/future.py,sha256=ZQStbfUYIPJRrgMfAWxxjrIRVxsw4WCtSR0J0bkyGno,2847 -kafka/metrics/__init__.py,sha256=b82LCjV5BgisjmIc3pn11CqFpme5grtIFHWiH8C_R0U,574 -kafka/metrics/__pycache__/__init__.cpython-312.pyc,, -kafka/metrics/__pycache__/compound_stat.cpython-312.pyc,, -kafka/metrics/__pycache__/dict_reporter.cpython-312.pyc,, -kafka/metrics/__pycache__/kafka_metric.cpython-312.pyc,, -kafka/metrics/__pycache__/measurable.cpython-312.pyc,, -kafka/metrics/__pycache__/measurable_stat.cpython-312.pyc,, -kafka/metrics/__pycache__/metric_config.cpython-312.pyc,, -kafka/metrics/__pycache__/metric_name.cpython-312.pyc,, -kafka/metrics/__pycache__/metrics.cpython-312.pyc,, -kafka/metrics/__pycache__/metrics_reporter.cpython-312.pyc,, -kafka/metrics/__pycache__/quota.cpython-312.pyc,, -kafka/metrics/__pycache__/stat.cpython-312.pyc,, -kafka/metrics/compound_stat.py,sha256=vHypFwcp4wWd-fC3jeMiMX8TwiVnnrn1vNfpZlBTZmg,850 -kafka/metrics/dict_reporter.py,sha256=OvZ6SUFp-Yk3tNaWbC0ul0WXncp42ymg8bHw3O6MITA,2567 -kafka/metrics/kafka_metric.py,sha256=vsLHShdhAjltL1vc51__B3M8lCUldMERub8cIdK3gFk,995 
-kafka/metrics/measurable.py,sha256=g5mp1c9816SRgJdgHXklTNqDoDnbeYp-opjoV3DOr7Q,770 -kafka/metrics/measurable_stat.py,sha256=Y4D7yrg07E9HqZlqh_EgeVnEEk4DRoNyKEoteEicssU,542 -kafka/metrics/metric_config.py,sha256=LcHTPumiRscwKvF2Da14oMbHAEZolk-gUKk1sxpDUoI,1235 -kafka/metrics/metric_name.py,sha256=eO9rBbd8sp1tWWu6O9YasbDxsS4QQzq8eD0fz1JRqJ8,3493 -kafka/metrics/metrics.py,sha256=EAuMd-OLeSX3IS16NvC3w2tpIEwvCPedPwQ1gyM0C7E,10383 -kafka/metrics/metrics_reporter.py,sha256=hxAs01C5Gj_orStdgHUOYSs4-kOI4xfu0MOkYyuX28s,1437 -kafka/metrics/quota.py,sha256=xzZH-nVdi4nWNo__LAkRWUyb84DKsYGvBBt_ZzRhpKc,1170 -kafka/metrics/stat.py,sha256=eos8xrmz7vgBnIk-8LyqpbEsBbyqEEdJ_CrDzEVGEaU,667 -kafka/metrics/stats/__init__.py,sha256=sHcT6pvQCt-s_aow5_QRy9Z5bRV4ShBCZlin51f--Ro,629 -kafka/metrics/stats/__pycache__/__init__.cpython-312.pyc,, -kafka/metrics/stats/__pycache__/avg.cpython-312.pyc,, -kafka/metrics/stats/__pycache__/count.cpython-312.pyc,, -kafka/metrics/stats/__pycache__/histogram.cpython-312.pyc,, -kafka/metrics/stats/__pycache__/max_stat.cpython-312.pyc,, -kafka/metrics/stats/__pycache__/min_stat.cpython-312.pyc,, -kafka/metrics/stats/__pycache__/percentile.cpython-312.pyc,, -kafka/metrics/stats/__pycache__/percentiles.cpython-312.pyc,, -kafka/metrics/stats/__pycache__/rate.cpython-312.pyc,, -kafka/metrics/stats/__pycache__/sampled_stat.cpython-312.pyc,, -kafka/metrics/stats/__pycache__/sensor.cpython-312.pyc,, -kafka/metrics/stats/__pycache__/total.cpython-312.pyc,, -kafka/metrics/stats/avg.py,sha256=6YDKXBfr7z0w4_yXBDhdycUcTiPvT8Rw3B_iD-c_Qi0,738 -kafka/metrics/stats/count.py,sha256=2of_mXwfzp9ZCLzEA2VOXr0PWkdLy4TSZL0uH5nF1Dw,547 -kafka/metrics/stats/histogram.py,sha256=5jNlZHOnHvGOpho-Zm0Rna6GcHy-CYjxPe612B5DHIk,3039 -kafka/metrics/stats/max_stat.py,sha256=n_90jTiHCgF193OCu2wtjUlJJxSkldW336OyEAexbv0,606 -kafka/metrics/stats/min_stat.py,sha256=xKzBc3tQjk4ieiGdvs9HqKn885mPV6yaDxCb2ANye8c,628 -kafka/metrics/stats/percentile.py,sha256=RkBL4L1AIBL5Mp74xIOt5lYJol4PSLNYmROcpD9bMb0,391 -kafka/metrics/stats/percentiles.py,sha256=9aYsUwZO6h-uqsYnx8ob9biWwWJ-ztRDwTZ8AXVRI3w,3027 -kafka/metrics/stats/rate.py,sha256=5vvGCUyqZF7QDeUtVu0g37UVRavkwqdRc7DldKlMGn0,4628 -kafka/metrics/stats/sampled_stat.py,sha256=zO9HwoJFZvuuDWj_OdckPeVpxUxhR5dhRXcLTL0-hUQ,3556 -kafka/metrics/stats/sensor.py,sha256=xQsbt3cqcBkJr9ccAkFabWgh9pdeMzggYSjhiStvAdo,5317 -kafka/metrics/stats/total.py,sha256=gS4F4bsSv4gp4R1Et_SQnx2KaoE8wZ5SY9X10xf6bic,446 -kafka/partitioner/__init__.py,sha256=Fks3C5_kokVWYw1Ad5wv0sVVzaaBtOejL-2bIL1yRII,158 -kafka/partitioner/__pycache__/__init__.cpython-312.pyc,, -kafka/partitioner/__pycache__/default.cpython-312.pyc,, -kafka/partitioner/default.py,sha256=tW-RC1PWIPRDEbeEAaPTLn-00oiZnXoVouEk9AnYE4w,2879 -kafka/producer/__init__.py,sha256=i3Wxih0NHjmqCkRNE54ial8fBp9siqabUE6ZGyL6oX8,122 -kafka/producer/__pycache__/__init__.cpython-312.pyc,, -kafka/producer/__pycache__/future.cpython-312.pyc,, -kafka/producer/__pycache__/kafka.cpython-312.pyc,, -kafka/producer/__pycache__/record_accumulator.cpython-312.pyc,, -kafka/producer/__pycache__/sender.cpython-312.pyc,, -kafka/producer/__pycache__/transaction_manager.cpython-312.pyc,, -kafka/producer/future.py,sha256=UC3-g9QlgVFmbitrtMXVpeP0Pbvr7xl2kcw6bAehKG8,2983 -kafka/producer/kafka.py,sha256=oGO-UxoVZEFdBLOQ7zEqeDJWXMxKyUdNV-pCRU3jZmg,53302 -kafka/producer/record_accumulator.py,sha256=xNkHOCmganxDfa3W_Y3iBLT4RaAOZi0Lix-mUzsp2aQ,28170 -kafka/producer/sender.py,sha256=8-TLTw6vQO7AheWSDPI33cQdWMyTDxi1k-pkXuUb9k0,37789 
-kafka/producer/transaction_manager.py,sha256=q3e9Lc9o-ofWvjT9FbHdTQH08XQBeRtoQEcQHGcnp7g,41535 -kafka/protocol/__init__.py,sha256=T1RBBlTH3zze0Cr1RqemPD4Z1b3IUDRmLOBfZTsPgLs,1184 -kafka/protocol/__pycache__/__init__.cpython-312.pyc,, -kafka/protocol/__pycache__/abstract.cpython-312.pyc,, -kafka/protocol/__pycache__/add_offsets_to_txn.cpython-312.pyc,, -kafka/protocol/__pycache__/add_partitions_to_txn.cpython-312.pyc,, -kafka/protocol/__pycache__/admin.cpython-312.pyc,, -kafka/protocol/__pycache__/api.cpython-312.pyc,, -kafka/protocol/__pycache__/api_versions.cpython-312.pyc,, -kafka/protocol/__pycache__/broker_api_versions.cpython-312.pyc,, -kafka/protocol/__pycache__/commit.cpython-312.pyc,, -kafka/protocol/__pycache__/end_txn.cpython-312.pyc,, -kafka/protocol/__pycache__/fetch.cpython-312.pyc,, -kafka/protocol/__pycache__/find_coordinator.cpython-312.pyc,, -kafka/protocol/__pycache__/frame.cpython-312.pyc,, -kafka/protocol/__pycache__/group.cpython-312.pyc,, -kafka/protocol/__pycache__/init_producer_id.cpython-312.pyc,, -kafka/protocol/__pycache__/list_offsets.cpython-312.pyc,, -kafka/protocol/__pycache__/message.cpython-312.pyc,, -kafka/protocol/__pycache__/metadata.cpython-312.pyc,, -kafka/protocol/__pycache__/offset_for_leader_epoch.cpython-312.pyc,, -kafka/protocol/__pycache__/parser.cpython-312.pyc,, -kafka/protocol/__pycache__/pickle.cpython-312.pyc,, -kafka/protocol/__pycache__/produce.cpython-312.pyc,, -kafka/protocol/__pycache__/sasl_authenticate.cpython-312.pyc,, -kafka/protocol/__pycache__/sasl_handshake.cpython-312.pyc,, -kafka/protocol/__pycache__/struct.cpython-312.pyc,, -kafka/protocol/__pycache__/txn_offset_commit.cpython-312.pyc,, -kafka/protocol/__pycache__/types.cpython-312.pyc,, -kafka/protocol/abstract.py,sha256=uOnuf6D8OTkL31Tp2QXG3VlzDPHVELGzM_bpSVa-_iw,424 -kafka/protocol/add_offsets_to_txn.py,sha256=Hya7vg6yqsV9XGLKWi8rES_VuN47-H4fdycg6mx8GLY,1486 -kafka/protocol/add_partitions_to_txn.py,sha256=mEz0DTrhY1ZN_GoITCQKRo-DO_HPc7A9r9eo_z1aF10,1766 -kafka/protocol/admin.py,sha256=11zE9sVrb34QY6AwYVvvWiwg4iycnq9aDSONCiuE9bo,30720 -kafka/protocol/api.py,sha256=ZI7DYb85UTL4BuhpwKGAyAKEv4Dl_y69AEW78M233lg,3813 -kafka/protocol/api_versions.py,sha256=VC9pvorLM--BE2uw0SvpeeMQPfWmcOvTgDFigLuGuVM,3546 -kafka/protocol/broker_api_versions.py,sha256=LA_pdbfsJClBxQPi01u5yVRLUIpZRUz6LiqhSsj8cgU,16523 -kafka/protocol/commit.py,sha256=-COlx8lTVCI6Zg4ZebDnsX4Wy_V69Kjw8V85FRd3Ics,8627 -kafka/protocol/end_txn.py,sha256=I0C1cxjkgLR0ri3QbEcmTkNoVT-lh7Bv_KaZO2wZUD0,1293 -kafka/protocol/fetch.py,sha256=G3Hh0AWGbEiWmiC83-b0t2jGlRLBovYz_ecfSp-vMEE,11214 -kafka/protocol/find_coordinator.py,sha256=sROaXxqAje2BSaNunh6QMTdVcR7uil5kz-woZqdg2BY,1697 -kafka/protocol/frame.py,sha256=SejRBK5urTD-2UzcVM2OxTgC73qDxfF3nlBPl9sjsfY,734 -kafka/protocol/group.py,sha256=SClv-Ntrj4IdEEL23L-S8XtCbELYojiue7BYwV8WjPc,7172 -kafka/protocol/init_producer_id.py,sha256=bFiPJTLTFXHNth2lg53mg9_N8znUBvpqR1PO31-RUlw,1117 -kafka/protocol/list_offsets.py,sha256=3kvif8X-B2LBSpR3qwbkGYyJ0GLKbQdENDGpxWV0scQ,4887 -kafka/protocol/message.py,sha256=9wNwJvfl9bsrdk_YcxbmAFjgvwZ5R1EBLSif2KILg9s,7657 -kafka/protocol/metadata.py,sha256=X99gdDTQJZWDrEa0sGWbwVED9cpKZ2zax6s6cMnN4xw,7403 -kafka/protocol/offset_for_leader_epoch.py,sha256=aunp-LMIuwcCsKwvgBZ8OcUhcgb0blaq5d3PAh22JOo,4304 -kafka/protocol/parser.py,sha256=OB3yebOp6JSQpl-5fEpV1_0SdAtYkiqIk6ffDIkHzu0,6859 -kafka/protocol/pickle.py,sha256=FGEv-1l1aXY3TogqzCwOS1gCNpEg6-xNLbrysqNdHcs,920 -kafka/protocol/produce.py,sha256=JDWCRY5B7eSL3vp0N977MIgYMrR2qxgrbUZrqQMlGWk,6540 
-kafka/protocol/sasl_authenticate.py,sha256=HaFAHPRhCKgyGEoJ5LwGffcpMUBNCphgBgXCsITLho8,1150 -kafka/protocol/sasl_handshake.py,sha256=WzQh9HBRemXvShrczkN4rd4SM-hNdes1khMzPRvcRQQ,982 -kafka/protocol/struct.py,sha256=DxktwrPp1pj4b7Vne2H5n-xWjgx9jpCmf0ydZkeIjoY,2380 -kafka/protocol/txn_offset_commit.py,sha256=_6Wr-SabUd9q09Tj9oG43AVZcqlW3LYbqXNW1Pvk9vs,2250 -kafka/protocol/types.py,sha256=f-lwfCqsJulYnBT1loek_KbMnZZqItN4YRIONjg3kbE,10244 -kafka/record/__init__.py,sha256=Q20hP_R5XX3AEnAlPkpoWzTLShESvxUT2OLXmI-JYEQ,129 -kafka/record/__pycache__/__init__.cpython-312.pyc,, -kafka/record/__pycache__/_crc32c.cpython-312.pyc,, -kafka/record/__pycache__/abc.cpython-312.pyc,, -kafka/record/__pycache__/default_records.cpython-312.pyc,, -kafka/record/__pycache__/legacy_records.cpython-312.pyc,, -kafka/record/__pycache__/memory_records.cpython-312.pyc,, -kafka/record/__pycache__/util.cpython-312.pyc,, -kafka/record/_crc32c.py,sha256=Ok-P62Yvg6D6rMGM9Z56OMjZWQlnps4xBbakg-sdxvI,5761 -kafka/record/abc.py,sha256=z1UYURHbD2RyyGRpVXKP598jck5eXU9p4M6iUo6ZSFo,4110 -kafka/record/default_records.py,sha256=IuICFp0soETihkp8bUyjjksqTlzU45o-UYmo8joLBmo,25992 -kafka/record/legacy_records.py,sha256=bm1Y24PLVgLKtWqamESKvMk9P01uw3aQ8Z8q2QHxJy8,18858 -kafka/record/memory_records.py,sha256=b7RFxvaQ93drXSk3o3_YB3FQlVoESoBlGj3Z5PD25n8,8874 -kafka/record/util.py,sha256=LDajBWdYVetmXts_t9Q76CxEx7njgC9LnjMgz9yPEMM,3556 -kafka/sasl/__init__.py,sha256=wUUGIKRe52J6Qekj7hSypg44vWTrkYsEdVafQC7cX5s,1106 -kafka/sasl/__pycache__/__init__.cpython-312.pyc,, -kafka/sasl/__pycache__/abc.cpython-312.pyc,, -kafka/sasl/__pycache__/gssapi.cpython-312.pyc,, -kafka/sasl/__pycache__/msk.cpython-312.pyc,, -kafka/sasl/__pycache__/oauth.cpython-312.pyc,, -kafka/sasl/__pycache__/plain.cpython-312.pyc,, -kafka/sasl/__pycache__/scram.cpython-312.pyc,, -kafka/sasl/__pycache__/sspi.cpython-312.pyc,, -kafka/sasl/abc.py,sha256=R0BZOk3AYEGyehiGbbg-LMRvFAlWZsh0fBiESgUpBYw,657 -kafka/sasl/gssapi.py,sha256=pwLxXqcmJJxkuFQUoEfX5PWgZxr-8TziuRCg9K7fO3E,4705 -kafka/sasl/msk.py,sha256=FCv0uUTQKjvR2gIGyiv-dlwIvkpvEtaHvhqhXtC2q8w,8101 -kafka/sasl/oauth.py,sha256=dh87tVi-dlS5lIzgYsC4m7IXUhlLdejaMb9Ua6oYaB0,3425 -kafka/sasl/plain.py,sha256=PMfoWT856wx6nF_LhpfPKEnD7BRNx5l6rDhAqxBnMWU,1317 -kafka/sasl/scram.py,sha256=77If2o9x-QZDBs2fqml17S-wGyR5YkOMr2nZxXrCW9c,5045 -kafka/sasl/sspi.py,sha256=RUIVyWCEdlJPV1oj7bdzG8gORvFyR_9Bt79TzIohwMM,5001 -kafka/serializer/__init__.py,sha256=_I4utl_8nNhcRzLLezFtwYX5akk6QKYmxa1HanRlYPU,103 -kafka/serializer/__pycache__/__init__.cpython-312.pyc,, -kafka/serializer/__pycache__/abstract.cpython-312.pyc,, -kafka/serializer/abstract.py,sha256=doiXDkMYt2SEHRarBdd8xVZKvr5S1qPdNEtl4syWA6Q,486 -kafka/socks5_wrapper.py,sha256=6woOaCTJXJ5e89_zdyW5BjOpyE4rCbYFH-kd-FeuPuk,9827 -kafka/structs.py,sha256=SJGzmLdV21jZyQ7247k0WFy16UiusgTHK3I-e4qzI-E,3058 -kafka/util.py,sha256=WGqI5yT1yWGgHqSuRF9Fi8ejpiB53SurMy7ABkYxJ2g,4584 -kafka/vendor/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -kafka/vendor/__pycache__/__init__.cpython-312.pyc,, -kafka/vendor/__pycache__/enum34.cpython-312.pyc,, -kafka/vendor/__pycache__/selectors34.cpython-312.pyc,, -kafka/vendor/__pycache__/six.cpython-312.pyc,, -kafka/vendor/__pycache__/socketpair.cpython-312.pyc,, -kafka/vendor/enum34.py,sha256=-u-lxAiJMt6ru4Do7NUDY9OpeWkYJMksb2xengJawFE,31204 -kafka/vendor/selectors34.py,sha256=gxejLO4eXf8mRSGXaQiknPig3GdX1rtsZiYOQJVuAy8,20594 -kafka/vendor/six.py,sha256=lLBa9_HrANP5BMZ7twEzg1M3wofwPmXyptuWmHX0brY,34826 
-kafka/vendor/socketpair.py,sha256=Fi3PoY1Okkppab720wFk1BhHXyjcw7hi5DwhqrYZH2Y,2737 -kafka/version.py,sha256=Vh0q00JWD6pn7UpRKd065A7-8g7Bv7yYCxnqmZMfsFY,23 -kafka_python-2.2.15.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -kafka_python-2.2.15.dist-info/METADATA,sha256=K9jQXj1ujRv2RCbdfjE07NblzS8mIlVycU1q_bMOtUc,9952 -kafka_python-2.2.15.dist-info/RECORD,, -kafka_python-2.2.15.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -kafka_python-2.2.15.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109 -kafka_python-2.2.15.dist-info/top_level.txt,sha256=IivJz7l5WHdLNDT6RIiVAlhjQzYRwGqBBmKHZ7WjPeM,6 diff --git a/venv/lib/python3.12/site-packages/multipart/__init__.py b/venv/lib/python3.12/site-packages/multipart/__init__.py index 67f0e5b..309d698 100644 --- a/venv/lib/python3.12/site-packages/multipart/__init__.py +++ b/venv/lib/python3.12/site-packages/multipart/__init__.py @@ -1,24 +1,15 @@ -# This only works if using a file system, other loaders not implemented. +# This is the canonical package information. +__author__ = 'Andrew Dunham' +__license__ = 'Apache' +__copyright__ = "Copyright (c) 2012-2013, Andrew Dunham" +__version__ = "0.0.6" -import importlib.util -import sys -import warnings -from pathlib import Path -for p in sys.path: - file_path = Path(p, "multipart.py") - try: - if file_path.is_file(): - spec = importlib.util.spec_from_file_location("multipart", file_path) - assert spec is not None, f"{file_path} found but not loadable!" - module = importlib.util.module_from_spec(spec) - sys.modules["multipart"] = module - assert spec.loader is not None, f"{file_path} must be loadable!" - spec.loader.exec_module(module) - break - except PermissionError: - pass -else: - warnings.warn("Please use `import python_multipart` instead.", PendingDeprecationWarning, stacklevel=2) - from python_multipart import * - from python_multipart import __all__, __author__, __copyright__, __license__, __version__ +from .multipart import ( + FormParser, + MultipartParser, + QuerystringParser, + OctetStreamParser, + create_form_parser, + parse_form, +) diff --git a/venv/lib/python3.12/site-packages/multipart/decoders.py b/venv/lib/python3.12/site-packages/multipart/decoders.py index 31acdfb..0d7ab32 100644 --- a/venv/lib/python3.12/site-packages/multipart/decoders.py +++ b/venv/lib/python3.12/site-packages/multipart/decoders.py @@ -1 +1,171 @@ -from python_multipart.decoders import * +import base64 +import binascii + +from .exceptions import DecodeError + + +class Base64Decoder: + """This object provides an interface to decode a stream of Base64 data. It + is instantiated with an "underlying object", and whenever a write() + operation is performed, it will decode the incoming data as Base64, and + call write() on the underlying object. This is primarily used for decoding + form data encoded as Base64, but can be used for other purposes:: + + from multipart.decoders import Base64Decoder + fd = open("notb64.txt", "wb") + decoder = Base64Decoder(fd) + try: + decoder.write("Zm9vYmFy") # "foobar" in Base64 + decoder.finalize() + finally: + decoder.close() + + # The contents of "notb64.txt" should be "foobar". + + This object will also pass all finalize() and close() calls to the + underlying object, if the underlying object supports them. + + Note that this class maintains a cache of base64 chunks, so that a write of + arbitrary size can be performed. 
You must call :meth:`finalize` on this + object after all writes are completed to ensure that all data is flushed + to the underlying object. + + :param underlying: the underlying object to pass writes to + """ + + def __init__(self, underlying): + self.cache = bytearray() + self.underlying = underlying + + def write(self, data): + """Takes any input data provided, decodes it as base64, and passes it + on to the underlying object. If the data provided is invalid base64 + data, then this method will raise + a :class:`multipart.exceptions.DecodeError` + + :param data: base64 data to decode + """ + + # Prepend any cache info to our data. + if len(self.cache) > 0: + data = self.cache + data + + # Slice off a string that's a multiple of 4. + decode_len = (len(data) // 4) * 4 + val = data[:decode_len] + + # Decode and write, if we have any. + if len(val) > 0: + try: + decoded = base64.b64decode(val) + except binascii.Error: + raise DecodeError('There was an error raised while decoding ' + 'base64-encoded data.') + + self.underlying.write(decoded) + + # Get the remaining bytes and save in our cache. + remaining_len = len(data) % 4 + if remaining_len > 0: + self.cache = data[-remaining_len:] + else: + self.cache = b'' + + # Return the length of the data to indicate no error. + return len(data) + + def close(self): + """Close this decoder. If the underlying object has a `close()` + method, this function will call it. + """ + if hasattr(self.underlying, 'close'): + self.underlying.close() + + def finalize(self): + """Finalize this object. This should be called when no more data + should be written to the stream. This function can raise a + :class:`multipart.exceptions.DecodeError` if there is some remaining + data in the cache. + + If the underlying object has a `finalize()` method, this function will + call it. + """ + if len(self.cache) > 0: + raise DecodeError('There are %d bytes remaining in the ' + 'Base64Decoder cache when finalize() is called' + % len(self.cache)) + + if hasattr(self.underlying, 'finalize'): + self.underlying.finalize() + + def __repr__(self): + return f"{self.__class__.__name__}(underlying={self.underlying!r})" + + +class QuotedPrintableDecoder: + """This object provides an interface to decode a stream of quoted-printable + data. It is instantiated with an "underlying object", in the same manner + as the :class:`multipart.decoders.Base64Decoder` class. This class behaves + in exactly the same way, including maintaining a cache of quoted-printable + chunks. + + :param underlying: the underlying object to pass writes to + """ + def __init__(self, underlying): + self.cache = b'' + self.underlying = underlying + + def write(self, data): + """Takes any input data provided, decodes it as quoted-printable, and + passes it on to the underlying object. + + :param data: quoted-printable data to decode + """ + # Prepend any cache info to our data. + if len(self.cache) > 0: + data = self.cache + data + + # If the last 2 characters have an '=' sign in it, then we won't be + # able to decode the encoded value and we'll need to save it for the + # next decoding step. + if data[-2:].find(b'=') != -1: + enc, rest = data[:-2], data[-2:] + else: + enc = data + rest = b'' + + # Encode and write, if we have data. + if len(enc) > 0: + self.underlying.write(binascii.a2b_qp(enc)) + + # Save remaining in cache. + self.cache = rest + return len(data) + + def close(self): + """Close this decoder. If the underlying object has a `close()` + method, this function will call it. 
+ """ + if hasattr(self.underlying, 'close'): + self.underlying.close() + + def finalize(self): + """Finalize this object. This should be called when no more data + should be written to the stream. This function will not raise any + exceptions, but it may write more data to the underlying object if + there is data remaining in the cache. + + If the underlying object has a `finalize()` method, this function will + call it. + """ + # If we have a cache, write and then remove it. + if len(self.cache) > 0: + self.underlying.write(binascii.a2b_qp(self.cache)) + self.cache = b'' + + # Finalize our underlying stream. + if hasattr(self.underlying, 'finalize'): + self.underlying.finalize() + + def __repr__(self): + return f"{self.__class__.__name__}(underlying={self.underlying!r})" diff --git a/venv/lib/python3.12/site-packages/multipart/exceptions.py b/venv/lib/python3.12/site-packages/multipart/exceptions.py index 36815d1..016e7f7 100644 --- a/venv/lib/python3.12/site-packages/multipart/exceptions.py +++ b/venv/lib/python3.12/site-packages/multipart/exceptions.py @@ -1 +1,46 @@ -from python_multipart.exceptions import * +class FormParserError(ValueError): + """Base error class for our form parser.""" + pass + + +class ParseError(FormParserError): + """This exception (or a subclass) is raised when there is an error while + parsing something. + """ + + #: This is the offset in the input data chunk (*NOT* the overall stream) in + #: which the parse error occurred. It will be -1 if not specified. + offset = -1 + + +class MultipartParseError(ParseError): + """This is a specific error that is raised when the MultipartParser detects + an error while parsing. + """ + pass + + +class QuerystringParseError(ParseError): + """This is a specific error that is raised when the QuerystringParser + detects an error while parsing. + """ + pass + + +class DecodeError(ParseError): + """This exception is raised when there is a decoding error - for example + with the Base64Decoder or QuotedPrintableDecoder. + """ + pass + + +# On Python 3.3, IOError is the same as OSError, so we don't want to inherit +# from both of them. We handle this case below. +if IOError is not OSError: # pragma: no cover + class FileError(FormParserError, IOError, OSError): + """Exception class for problems with the File class.""" + pass +else: # pragma: no cover + class FileError(FormParserError, OSError): + """Exception class for problems with the File class.""" + pass diff --git a/venv/lib/python3.12/site-packages/multipart/multipart.py b/venv/lib/python3.12/site-packages/multipart/multipart.py index 7bf567d..a9f1f9f 100644 --- a/venv/lib/python3.12/site-packages/multipart/multipart.py +++ b/venv/lib/python3.12/site-packages/multipart/multipart.py @@ -1 +1,1893 @@ -from python_multipart.multipart import * +from .decoders import * +from .exceptions import * + +import os +import re +import sys +import shutil +import logging +import tempfile +from io import BytesIO +from numbers import Number + +# Unique missing object. +_missing = object() + +# States for the querystring parser. 
+STATE_BEFORE_FIELD = 0 +STATE_FIELD_NAME = 1 +STATE_FIELD_DATA = 2 + +# States for the multipart parser +STATE_START = 0 +STATE_START_BOUNDARY = 1 +STATE_HEADER_FIELD_START = 2 +STATE_HEADER_FIELD = 3 +STATE_HEADER_VALUE_START = 4 +STATE_HEADER_VALUE = 5 +STATE_HEADER_VALUE_ALMOST_DONE = 6 +STATE_HEADERS_ALMOST_DONE = 7 +STATE_PART_DATA_START = 8 +STATE_PART_DATA = 9 +STATE_PART_DATA_END = 10 +STATE_END = 11 + +STATES = [ + "START", + "START_BOUNDARY", "HEADER_FIELD_START", "HEADER_FIELD", "HEADER_VALUE_START", "HEADER_VALUE", + "HEADER_VALUE_ALMOST_DONE", "HEADRES_ALMOST_DONE", "PART_DATA_START", "PART_DATA", "PART_DATA_END", "END" +] + + +# Flags for the multipart parser. +FLAG_PART_BOUNDARY = 1 +FLAG_LAST_BOUNDARY = 2 + +# Get constants. Since iterating over a str on Python 2 gives you a 1-length +# string, but iterating over a bytes object on Python 3 gives you an integer, +# we need to save these constants. +CR = b'\r'[0] +LF = b'\n'[0] +COLON = b':'[0] +SPACE = b' '[0] +HYPHEN = b'-'[0] +AMPERSAND = b'&'[0] +SEMICOLON = b';'[0] +LOWER_A = b'a'[0] +LOWER_Z = b'z'[0] +NULL = b'\x00'[0] + +# Lower-casing a character is different, because of the difference between +# str on Py2, and bytes on Py3. Same with getting the ordinal value of a byte, +# and joining a list of bytes together. +# These functions abstract that. +lower_char = lambda c: c | 0x20 +ord_char = lambda c: c +join_bytes = lambda b: bytes(list(b)) + +# These are regexes for parsing header values. +SPECIAL_CHARS = re.escape(b'()<>@,;:\\"/[]?={} \t') +QUOTED_STR = br'"(?:\\.|[^"])*"' +VALUE_STR = br'(?:[^' + SPECIAL_CHARS + br']+|' + QUOTED_STR + br')' +OPTION_RE_STR = ( + br'(?:;|^)\s*([^' + SPECIAL_CHARS + br']+)\s*=\s*(' + VALUE_STR + br')' +) +OPTION_RE = re.compile(OPTION_RE_STR) +QUOTE = b'"'[0] + + +def parse_options_header(value): + """ + Parses a Content-Type header into a value in the following format: + (content_type, {parameters}) + """ + if not value: + return (b'', {}) + + # If we are passed a string, we assume that it conforms to WSGI and does + # not contain any code point that's not in latin-1. + if isinstance(value, str): # pragma: no cover + value = value.encode('latin-1') + + # If we have no options, return the string as-is. + if b';' not in value: + return (value.lower().strip(), {}) + + # Split at the first semicolon, to get our value and then options. + ctype, rest = value.split(b';', 1) + options = {} + + # Parse the options. + for match in OPTION_RE.finditer(rest): + key = match.group(1).lower() + value = match.group(2) + if value[0] == QUOTE and value[-1] == QUOTE: + # Unquote the value. + value = value[1:-1] + value = value.replace(b'\\\\', b'\\').replace(b'\\"', b'"') + + # If the value is a filename, we need to fix a bug on IE6 that sends + # the full file path instead of the filename. + if key == b'filename': + if value[1:3] == b':\\' or value[:2] == b'\\\\': + value = value.split(b'\\')[-1] + + options[key] = value + + return ctype, options + + +class Field: + """A Field object represents a (parsed) form field. It represents a single + field with a corresponding name and value. + + The name that a :class:`Field` will be instantiated with is the same name + that would be found in the following HTML:: + + + + This class defines two methods, :meth:`on_data` and :meth:`on_end`, that + will be called when data is written to the Field, and when the Field is + finalized, respectively. 
+ + :param name: the name of the form field + """ + def __init__(self, name): + self._name = name + self._value = [] + + # We cache the joined version of _value for speed. + self._cache = _missing + + @classmethod + def from_value(klass, name, value): + """Create an instance of a :class:`Field`, and set the corresponding + value - either None or an actual value. This method will also + finalize the Field itself. + + :param name: the name of the form field + :param value: the value of the form field - either a bytestring or + None + """ + + f = klass(name) + if value is None: + f.set_none() + else: + f.write(value) + f.finalize() + return f + + def write(self, data): + """Write some data into the form field. + + :param data: a bytestring + """ + return self.on_data(data) + + def on_data(self, data): + """This method is a callback that will be called whenever data is + written to the Field. + + :param data: a bytestring + """ + self._value.append(data) + self._cache = _missing + return len(data) + + def on_end(self): + """This method is called whenever the Field is finalized. + """ + if self._cache is _missing: + self._cache = b''.join(self._value) + + def finalize(self): + """Finalize the form field. + """ + self.on_end() + + def close(self): + """Close the Field object. This will free any underlying cache. + """ + # Free our value array. + if self._cache is _missing: + self._cache = b''.join(self._value) + + del self._value + + def set_none(self): + """Some fields in a querystring can possibly have a value of None - for + example, the string "foo&bar=&baz=asdf" will have a field with the + name "foo" and value None, one with name "bar" and value "", and one + with name "baz" and value "asdf". Since the write() interface doesn't + support writing None, this function will set the field value to None. + """ + self._cache = None + + @property + def field_name(self): + """This property returns the name of the field.""" + return self._name + + @property + def value(self): + """This property returns the value of the form field.""" + if self._cache is _missing: + self._cache = b''.join(self._value) + + return self._cache + + def __eq__(self, other): + if isinstance(other, Field): + return ( + self.field_name == other.field_name and + self.value == other.value + ) + else: + return NotImplemented + + def __repr__(self): + if len(self.value) > 97: + # We get the repr, and then insert three dots before the final + # quote. + v = repr(self.value[:97])[:-1] + "...'" + else: + v = repr(self.value) + + return "{}(field_name={!r}, value={})".format( + self.__class__.__name__, + self.field_name, + v + ) + + +class File: + """This class represents an uploaded file. It handles writing file data to + either an in-memory file or a temporary file on-disk, if the optional + threshold is passed. + + There are some options that can be passed to the File to change behavior + of the class. Valid options are as follows: + + .. list-table:: + :widths: 15 5 5 30 + :header-rows: 1 + + * - Name + - Type + - Default + - Description + * - UPLOAD_DIR + - `str` + - None + - The directory to store uploaded files in. If this is None, a + temporary file will be created in the system's standard location. + * - UPLOAD_DELETE_TMP + - `bool` + - True + - Delete automatically created TMP file + * - UPLOAD_KEEP_FILENAME + - `bool` + - False + - Whether or not to keep the filename of the uploaded file. If True, + then the filename will be converted to a safe representation (e.g. 
+ by removing any invalid path segments), and then saved with the + same name). Otherwise, a temporary name will be used. + * - UPLOAD_KEEP_EXTENSIONS + - `bool` + - False + - Whether or not to keep the uploaded file's extension. If False, the + file will be saved with the default temporary extension (usually + ".tmp"). Otherwise, the file's extension will be maintained. Note + that this will properly combine with the UPLOAD_KEEP_FILENAME + setting. + * - MAX_MEMORY_FILE_SIZE + - `int` + - 1 MiB + - The maximum number of bytes of a File to keep in memory. By + default, the contents of a File are kept into memory until a certain + limit is reached, after which the contents of the File are written + to a temporary file. This behavior can be disabled by setting this + value to an appropriately large value (or, for example, infinity, + such as `float('inf')`. + + :param file_name: The name of the file that this :class:`File` represents + + :param field_name: The field name that uploaded this file. Note that this + can be None, if, for example, the file was uploaded + with Content-Type application/octet-stream + + :param config: The configuration for this File. See above for valid + configuration keys and their corresponding values. + """ + def __init__(self, file_name, field_name=None, config={}): + # Save configuration, set other variables default. + self.logger = logging.getLogger(__name__) + self._config = config + self._in_memory = True + self._bytes_written = 0 + self._fileobj = BytesIO() + + # Save the provided field/file name. + self._field_name = field_name + self._file_name = file_name + + # Our actual file name is None by default, since, depending on our + # config, we may not actually use the provided name. + self._actual_file_name = None + + # Split the extension from the filename. + if file_name is not None: + base, ext = os.path.splitext(file_name) + self._file_base = base + self._ext = ext + + @property + def field_name(self): + """The form field associated with this file. May be None if there isn't + one, for example when we have an application/octet-stream upload. + """ + return self._field_name + + @property + def file_name(self): + """The file name given in the upload request. + """ + return self._file_name + + @property + def actual_file_name(self): + """The file name that this file is saved as. Will be None if it's not + currently saved on disk. + """ + return self._actual_file_name + + @property + def file_object(self): + """The file object that we're currently writing to. Note that this + will either be an instance of a :class:`io.BytesIO`, or a regular file + object. + """ + return self._fileobj + + @property + def size(self): + """The total size of this file, counted as the number of bytes that + currently have been written to the file. + """ + return self._bytes_written + + @property + def in_memory(self): + """A boolean representing whether or not this file object is currently + stored in-memory or on-disk. + """ + return self._in_memory + + def flush_to_disk(self): + """If the file is already on-disk, do nothing. Otherwise, copy from + the in-memory buffer to a disk file, and then reassign our internal + file object to this new disk file. + + Note that if you attempt to flush a file that is already on-disk, a + warning will be logged to this module's logger. + """ + if not self._in_memory: + self.logger.warning( + "Trying to flush to disk when we're not in memory" + ) + return + + # Go back to the start of our file. + self._fileobj.seek(0) + + # Open a new file. 
+ new_file = self._get_disk_file() + + # Copy the file objects. + shutil.copyfileobj(self._fileobj, new_file) + + # Seek to the new position in our new file. + new_file.seek(self._bytes_written) + + # Reassign the fileobject. + old_fileobj = self._fileobj + self._fileobj = new_file + + # We're no longer in memory. + self._in_memory = False + + # Close the old file object. + old_fileobj.close() + + def _get_disk_file(self): + """This function is responsible for getting a file object on-disk for us. + """ + self.logger.info("Opening a file on disk") + + file_dir = self._config.get('UPLOAD_DIR') + keep_filename = self._config.get('UPLOAD_KEEP_FILENAME', False) + keep_extensions = self._config.get('UPLOAD_KEEP_EXTENSIONS', False) + delete_tmp = self._config.get('UPLOAD_DELETE_TMP', True) + + # If we have a directory and are to keep the filename... + if file_dir is not None and keep_filename: + self.logger.info("Saving with filename in: %r", file_dir) + + # Build our filename. + # TODO: what happens if we don't have a filename? + fname = self._file_base + if keep_extensions: + fname = fname + self._ext + + path = os.path.join(file_dir, fname) + try: + self.logger.info("Opening file: %r", path) + tmp_file = open(path, 'w+b') + except OSError as e: + tmp_file = None + + self.logger.exception("Error opening temporary file") + raise FileError("Error opening temporary file: %r" % path) + else: + # Build options array. + # Note that on Python 3, tempfile doesn't support byte names. We + # encode our paths using the default filesystem encoding. + options = {} + if keep_extensions: + ext = self._ext + if isinstance(ext, bytes): + ext = ext.decode(sys.getfilesystemencoding()) + + options['suffix'] = ext + if file_dir is not None: + d = file_dir + if isinstance(d, bytes): + d = d.decode(sys.getfilesystemencoding()) + + options['dir'] = d + options['delete'] = delete_tmp + + # Create a temporary (named) file with the appropriate settings. + self.logger.info("Creating a temporary file with options: %r", + options) + try: + tmp_file = tempfile.NamedTemporaryFile(**options) + except OSError: + self.logger.exception("Error creating named temporary file") + raise FileError("Error creating named temporary file") + + fname = tmp_file.name + + # Encode filename as bytes. + if isinstance(fname, str): + fname = fname.encode(sys.getfilesystemencoding()) + + self._actual_file_name = fname + return tmp_file + + def write(self, data): + """Write some data to the File. + + :param data: a bytestring + """ + return self.on_data(data) + + def on_data(self, data): + """This method is a callback that will be called whenever data is + written to the File. + + :param data: a bytestring + """ + pos = self._fileobj.tell() + bwritten = self._fileobj.write(data) + # true file objects write returns None + if bwritten is None: + bwritten = self._fileobj.tell() - pos + + # If the bytes written isn't the same as the length, just return. + if bwritten != len(data): + self.logger.warning("bwritten != len(data) (%d != %d)", bwritten, + len(data)) + return bwritten + + # Keep track of how many bytes we've written. + self._bytes_written += bwritten + + # If we're in-memory and are over our limit, we create a file. + if (self._in_memory and + self._config.get('MAX_MEMORY_FILE_SIZE') is not None and + (self._bytes_written > + self._config.get('MAX_MEMORY_FILE_SIZE'))): + self.logger.info("Flushing to disk") + self.flush_to_disk() + + # Return the number of bytes written. 
+ return bwritten + + def on_end(self): + """This method is called whenever the Field is finalized. + """ + # Flush the underlying file object + self._fileobj.flush() + + def finalize(self): + """Finalize the form file. This will not close the underlying file, + but simply signal that we are finished writing to the File. + """ + self.on_end() + + def close(self): + """Close the File object. This will actually close the underlying + file object (whether it's a :class:`io.BytesIO` or an actual file + object). + """ + self._fileobj.close() + + def __repr__(self): + return "{}(file_name={!r}, field_name={!r})".format( + self.__class__.__name__, + self.file_name, + self.field_name + ) + + +class BaseParser: + """This class is the base class for all parsers. It contains the logic for + calling and adding callbacks. + + A callback can be one of two different forms. "Notification callbacks" are + callbacks that are called when something happens - for example, when a new + part of a multipart message is encountered by the parser. "Data callbacks" + are called when we get some sort of data - for example, part of the body of + a multipart chunk. Notification callbacks are called with no parameters, + whereas data callbacks are called with three, as follows:: + + data_callback(data, start, end) + + The "data" parameter is a bytestring (i.e. "foo" on Python 2, or b"foo" on + Python 3). "start" and "end" are integer indexes into the "data" string + that represent the data of interest. Thus, in a data callback, the slice + `data[start:end]` represents the data that the callback is "interested in". + The callback is not passed a copy of the data, since copying severely hurts + performance. + """ + def __init__(self): + self.logger = logging.getLogger(__name__) + + def callback(self, name, data=None, start=None, end=None): + """This function calls a provided callback with some data. If the + callback is not set, will do nothing. + + :param name: The name of the callback to call (as a string). + + :param data: Data to pass to the callback. If None, then it is + assumed that the callback is a notification callback, + and no parameters are given. + + :param end: An integer that is passed to the data callback. + + :param start: An integer that is passed to the data callback. + """ + name = "on_" + name + func = self.callbacks.get(name) + if func is None: + return + + # Depending on whether we're given a buffer... + if data is not None: + # Don't do anything if we have start == end. + if start is not None and start == end: + return + + self.logger.debug("Calling %s with data[%d:%d]", name, start, end) + func(data, start, end) + else: + self.logger.debug("Calling %s with no data", name) + func() + + def set_callback(self, name, new_func): + """Update the function for a callback. Removes from the callbacks dict + if new_func is None. + + :param name: The name of the callback to call (as a string). + + :param new_func: The new function for the callback. If None, then the + callback will be removed (with no error if it does not + exist). + """ + if new_func is None: + self.callbacks.pop('on_' + name, None) + else: + self.callbacks['on_' + name] = new_func + + def close(self): + pass # pragma: no cover + + def finalize(self): + pass # pragma: no cover + + def __repr__(self): + return "%s()" % self.__class__.__name__ + + +class OctetStreamParser(BaseParser): + """This parser parses an octet-stream request body and calls callbacks when + incoming data is received. Callbacks are as follows: + + .. 
list-table:: + :widths: 15 10 30 + :header-rows: 1 + + * - Callback Name + - Parameters + - Description + * - on_start + - None + - Called when the first data is parsed. + * - on_data + - data, start, end + - Called for each data chunk that is parsed. + * - on_end + - None + - Called when the parser is finished parsing all data. + + :param callbacks: A dictionary of callbacks. See the documentation for + :class:`BaseParser`. + + :param max_size: The maximum size of body to parse. Defaults to infinity - + i.e. unbounded. + """ + def __init__(self, callbacks={}, max_size=float('inf')): + super().__init__() + self.callbacks = callbacks + self._started = False + + if not isinstance(max_size, Number) or max_size < 1: + raise ValueError("max_size must be a positive number, not %r" % + max_size) + self.max_size = max_size + self._current_size = 0 + + def write(self, data): + """Write some data to the parser, which will perform size verification, + and then pass the data to the underlying callback. + + :param data: a bytestring + """ + if not self._started: + self.callback('start') + self._started = True + + # Truncate data length. + data_len = len(data) + if (self._current_size + data_len) > self.max_size: + # We truncate the length of data that we are to process. + new_size = int(self.max_size - self._current_size) + self.logger.warning("Current size is %d (max %d), so truncating " + "data length from %d to %d", + self._current_size, self.max_size, data_len, + new_size) + data_len = new_size + + # Increment size, then callback, in case there's an exception. + self._current_size += data_len + self.callback('data', data, 0, data_len) + return data_len + + def finalize(self): + """Finalize this parser, which signals to that we are finished parsing, + and sends the on_end callback. + """ + self.callback('end') + + def __repr__(self): + return "%s()" % self.__class__.__name__ + + +class QuerystringParser(BaseParser): + """This is a streaming querystring parser. It will consume data, and call + the callbacks given when it has data. + + .. list-table:: + :widths: 15 10 30 + :header-rows: 1 + + * - Callback Name + - Parameters + - Description + * - on_field_start + - None + - Called when a new field is encountered. + * - on_field_name + - data, start, end + - Called when a portion of a field's name is encountered. + * - on_field_data + - data, start, end + - Called when a portion of a field's data is encountered. + * - on_field_end + - None + - Called when the end of a field is encountered. + * - on_end + - None + - Called when the parser is finished parsing all data. + + :param callbacks: A dictionary of callbacks. See the documentation for + :class:`BaseParser`. + + :param strict_parsing: Whether or not to parse the body strictly. Defaults + to False. If this is set to True, then the behavior + of the parser changes as the following: if a field + has a value with an equal sign (e.g. "foo=bar", or + "foo="), it is always included. If a field has no + equals sign (e.g. "...&name&..."), it will be + treated as an error if 'strict_parsing' is True, + otherwise included. If an error is encountered, + then a + :class:`multipart.exceptions.QuerystringParseError` + will be raised. + + :param max_size: The maximum size of body to parse. Defaults to infinity - + i.e. unbounded. 
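+
+    An illustrative usage sketch (a minimal example built only from the
+    callbacks and methods documented above)::
+
+        def on_field_name(data, start, end):
+            print("name:", data[start:end])
+
+        def on_field_data(data, start, end):
+            print("data:", data[start:end])
+
+        parser = QuerystringParser(callbacks={
+            'on_field_name': on_field_name,
+            'on_field_data': on_field_data,
+        })
+        parser.write(b'foo=bar&baz=qux')
+        parser.finalize()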
+ """ + def __init__(self, callbacks={}, strict_parsing=False, + max_size=float('inf')): + super().__init__() + self.state = STATE_BEFORE_FIELD + self._found_sep = False + + self.callbacks = callbacks + + # Max-size stuff + if not isinstance(max_size, Number) or max_size < 1: + raise ValueError("max_size must be a positive number, not %r" % + max_size) + self.max_size = max_size + self._current_size = 0 + + # Should parsing be strict? + self.strict_parsing = strict_parsing + + def write(self, data): + """Write some data to the parser, which will perform size verification, + parse into either a field name or value, and then pass the + corresponding data to the underlying callback. If an error is + encountered while parsing, a QuerystringParseError will be raised. The + "offset" attribute of the raised exception will be set to the offset in + the input data chunk (NOT the overall stream) that caused the error. + + :param data: a bytestring + """ + # Handle sizing. + data_len = len(data) + if (self._current_size + data_len) > self.max_size: + # We truncate the length of data that we are to process. + new_size = int(self.max_size - self._current_size) + self.logger.warning("Current size is %d (max %d), so truncating " + "data length from %d to %d", + self._current_size, self.max_size, data_len, + new_size) + data_len = new_size + + l = 0 + try: + l = self._internal_write(data, data_len) + finally: + self._current_size += l + + return l + + def _internal_write(self, data, length): + state = self.state + strict_parsing = self.strict_parsing + found_sep = self._found_sep + + i = 0 + while i < length: + ch = data[i] + + # Depending on our state... + if state == STATE_BEFORE_FIELD: + # If the 'found_sep' flag is set, we've already encountered + # and skipped a single separator. If so, we check our strict + # parsing flag and decide what to do. Otherwise, we haven't + # yet reached a separator, and thus, if we do, we need to skip + # it as it will be the boundary between fields that's supposed + # to be there. + if ch == AMPERSAND or ch == SEMICOLON: + if found_sep: + # If we're parsing strictly, we disallow blank chunks. + if strict_parsing: + e = QuerystringParseError( + "Skipping duplicate ampersand/semicolon at " + "%d" % i + ) + e.offset = i + raise e + else: + self.logger.debug("Skipping duplicate ampersand/" + "semicolon at %d", i) + else: + # This case is when we're skipping the (first) + # separator between fields, so we just set our flag + # and continue on. + found_sep = True + else: + # Emit a field-start event, and go to that state. Also, + # reset the "found_sep" flag, for the next time we get to + # this state. + self.callback('field_start') + i -= 1 + state = STATE_FIELD_NAME + found_sep = False + + elif state == STATE_FIELD_NAME: + # Try and find a separator - we ensure that, if we do, we only + # look for the equal sign before it. + sep_pos = data.find(b'&', i) + if sep_pos == -1: + sep_pos = data.find(b';', i) + + # See if we can find an equals sign in the remaining data. If + # so, we can immediately emit the field name and jump to the + # data state. + if sep_pos != -1: + equals_pos = data.find(b'=', i, sep_pos) + else: + equals_pos = data.find(b'=', i) + + if equals_pos != -1: + # Emit this name. + self.callback('field_name', data, i, equals_pos) + + # Jump i to this position. Note that it will then have 1 + # added to it below, which means the next iteration of this + # loop will inspect the character after the equals sign. 
+ i = equals_pos + state = STATE_FIELD_DATA + else: + # No equals sign found. + if not strict_parsing: + # See also comments in the STATE_FIELD_DATA case below. + # If we found the separator, we emit the name and just + # end - there's no data callback at all (not even with + # a blank value). + if sep_pos != -1: + self.callback('field_name', data, i, sep_pos) + self.callback('field_end') + + i = sep_pos - 1 + state = STATE_BEFORE_FIELD + else: + # Otherwise, no separator in this block, so the + # rest of this chunk must be a name. + self.callback('field_name', data, i, length) + i = length + + else: + # We're parsing strictly. If we find a separator, + # this is an error - we require an equals sign. + if sep_pos != -1: + e = QuerystringParseError( + "When strict_parsing is True, we require an " + "equals sign in all field chunks. Did not " + "find one in the chunk that starts at %d" % + (i,) + ) + e.offset = i + raise e + + # No separator in the rest of this chunk, so it's just + # a field name. + self.callback('field_name', data, i, length) + i = length + + elif state == STATE_FIELD_DATA: + # Try finding either an ampersand or a semicolon after this + # position. + sep_pos = data.find(b'&', i) + if sep_pos == -1: + sep_pos = data.find(b';', i) + + # If we found it, callback this bit as data and then go back + # to expecting to find a field. + if sep_pos != -1: + self.callback('field_data', data, i, sep_pos) + self.callback('field_end') + + # Note that we go to the separator, which brings us to the + # "before field" state. This allows us to properly emit + # "field_start" events only when we actually have data for + # a field of some sort. + i = sep_pos - 1 + state = STATE_BEFORE_FIELD + + # Otherwise, emit the rest as data and finish. + else: + self.callback('field_data', data, i, length) + i = length + + else: # pragma: no cover (error case) + msg = "Reached an unknown state %d at %d" % (state, i) + self.logger.warning(msg) + e = QuerystringParseError(msg) + e.offset = i + raise e + + i += 1 + + self.state = state + self._found_sep = found_sep + return len(data) + + def finalize(self): + """Finalize this parser, which signals to that we are finished parsing, + if we're still in the middle of a field, an on_field_end callback, and + then the on_end callback. + """ + # If we're currently in the middle of a field, we finish it. + if self.state == STATE_FIELD_DATA: + self.callback('field_end') + self.callback('end') + + def __repr__(self): + return "{}(strict_parsing={!r}, max_size={!r})".format( + self.__class__.__name__, + self.strict_parsing, self.max_size + ) + + +class MultipartParser(BaseParser): + """This class is a streaming multipart/form-data parser. + + .. list-table:: + :widths: 15 10 30 + :header-rows: 1 + + * - Callback Name + - Parameters + - Description + * - on_part_begin + - None + - Called when a new part of the multipart message is encountered. + * - on_part_data + - data, start, end + - Called when a portion of a part's data is encountered. + * - on_part_end + - None + - Called when the end of a part is reached. + * - on_header_begin + - None + - Called when we've found a new header in a part of a multipart + message + * - on_header_field + - data, start, end + - Called each time an additional portion of a header is read (i.e. the + part of the header that is before the colon; the "Foo" in + "Foo: Bar"). + * - on_header_value + - data, start, end + - Called when we get data for a header. + * - on_header_end + - None + - Called when the current header is finished - i.e. 
we've reached the + newline at the end of the header. + * - on_headers_finished + - None + - Called when all headers are finished, and before the part data + starts. + * - on_end + - None + - Called when the parser is finished parsing all data. + + + :param boundary: The multipart boundary. This is required, and must match + what is given in the HTTP request - usually in the + Content-Type header. + + :param callbacks: A dictionary of callbacks. See the documentation for + :class:`BaseParser`. + + :param max_size: The maximum size of body to parse. Defaults to infinity - + i.e. unbounded. + """ + + def __init__(self, boundary, callbacks={}, max_size=float('inf')): + # Initialize parser state. + super().__init__() + self.state = STATE_START + self.index = self.flags = 0 + + self.callbacks = callbacks + + if not isinstance(max_size, Number) or max_size < 1: + raise ValueError("max_size must be a positive number, not %r" % + max_size) + self.max_size = max_size + self._current_size = 0 + + # Setup marks. These are used to track the state of data received. + self.marks = {} + + # TODO: Actually use this rather than the dumb version we currently use + # # Precompute the skip table for the Boyer-Moore-Horspool algorithm. + # skip = [len(boundary) for x in range(256)] + # for i in range(len(boundary) - 1): + # skip[ord_char(boundary[i])] = len(boundary) - i - 1 + # + # # We use a tuple since it's a constant, and marginally faster. + # self.skip = tuple(skip) + + # Save our boundary. + if isinstance(boundary, str): # pragma: no cover + boundary = boundary.encode('latin-1') + self.boundary = b'\r\n--' + boundary + + # Get a set of characters that belong to our boundary. + self.boundary_chars = frozenset(self.boundary) + + # We also create a lookbehind list. + # Note: the +8 is since we can have, at maximum, "\r\n--" + boundary + + # "--\r\n" at the final boundary, and the length of '\r\n--' and + # '--\r\n' is 8 bytes. + self.lookbehind = [NULL for x in range(len(boundary) + 8)] + + def write(self, data): + """Write some data to the parser, which will perform size verification, + and then parse the data into the appropriate location (e.g. header, + data, etc.), and pass this on to the underlying callback. If an error + is encountered, a MultipartParseError will be raised. The "offset" + attribute on the raised exception will be set to the offset of the byte + in the input chunk that caused the error. + + :param data: a bytestring + """ + # Handle sizing. + data_len = len(data) + if (self._current_size + data_len) > self.max_size: + # We truncate the length of data that we are to process. + new_size = int(self.max_size - self._current_size) + self.logger.warning("Current size is %d (max %d), so truncating " + "data length from %d to %d", + self._current_size, self.max_size, data_len, + new_size) + data_len = new_size + + l = 0 + try: + l = self._internal_write(data, data_len) + finally: + self._current_size += l + + return l + + def _internal_write(self, data, length): + # Get values from locals. + boundary = self.boundary + + # Get our state, flags and index. These are persisted between calls to + # this function. + state = self.state + index = self.index + flags = self.flags + + # Our index defaults to 0. + i = 0 + + # Set a mark. + def set_mark(name): + self.marks[name] = i + + # Remove a mark. + def delete_mark(name, reset=False): + self.marks.pop(name, None) + + # Helper function that makes calling a callback with data easier. 
The + # 'remaining' parameter will callback from the marked value until the + # end of the buffer, and reset the mark, instead of deleting it. This + # is used at the end of the function to call our callbacks with any + # remaining data in this chunk. + def data_callback(name, remaining=False): + marked_index = self.marks.get(name) + if marked_index is None: + return + + # If we're getting remaining data, we ignore the current i value + # and just call with the remaining data. + if remaining: + self.callback(name, data, marked_index, length) + self.marks[name] = 0 + + # Otherwise, we call it from the mark to the current byte we're + # processing. + else: + self.callback(name, data, marked_index, i) + self.marks.pop(name, None) + + # For each byte... + while i < length: + c = data[i] + + if state == STATE_START: + # Skip leading newlines + if c == CR or c == LF: + i += 1 + self.logger.debug("Skipping leading CR/LF at %d", i) + continue + + # index is used as in index into our boundary. Set to 0. + index = 0 + + # Move to the next state, but decrement i so that we re-process + # this character. + state = STATE_START_BOUNDARY + i -= 1 + + elif state == STATE_START_BOUNDARY: + # Check to ensure that the last 2 characters in our boundary + # are CRLF. + if index == len(boundary) - 2: + if c != CR: + # Error! + msg = "Did not find CR at end of boundary (%d)" % (i,) + self.logger.warning(msg) + e = MultipartParseError(msg) + e.offset = i + raise e + + index += 1 + + elif index == len(boundary) - 2 + 1: + if c != LF: + msg = "Did not find LF at end of boundary (%d)" % (i,) + self.logger.warning(msg) + e = MultipartParseError(msg) + e.offset = i + raise e + + # The index is now used for indexing into our boundary. + index = 0 + + # Callback for the start of a part. + self.callback('part_begin') + + # Move to the next character and state. + state = STATE_HEADER_FIELD_START + + else: + # Check to ensure our boundary matches + if c != boundary[index + 2]: + msg = "Did not find boundary character %r at index " \ + "%d" % (c, index + 2) + self.logger.warning(msg) + e = MultipartParseError(msg) + e.offset = i + raise e + + # Increment index into boundary and continue. + index += 1 + + elif state == STATE_HEADER_FIELD_START: + # Mark the start of a header field here, reset the index, and + # continue parsing our header field. + index = 0 + + # Set a mark of our header field. + set_mark('header_field') + + # Move to parsing header fields. + state = STATE_HEADER_FIELD + i -= 1 + + elif state == STATE_HEADER_FIELD: + # If we've reached a CR at the beginning of a header, it means + # that we've reached the second of 2 newlines, and so there are + # no more headers to parse. + if c == CR: + delete_mark('header_field') + state = STATE_HEADERS_ALMOST_DONE + i += 1 + continue + + # Increment our index in the header. + index += 1 + + # Do nothing if we encounter a hyphen. + if c == HYPHEN: + pass + + # If we've reached a colon, we're done with this header. + elif c == COLON: + # A 0-length header is an error. + if index == 1: + msg = "Found 0-length header at %d" % (i,) + self.logger.warning(msg) + e = MultipartParseError(msg) + e.offset = i + raise e + + # Call our callback with the header field. + data_callback('header_field') + + # Move to parsing the header value. + state = STATE_HEADER_VALUE_START + + else: + # Lower-case this character, and ensure that it is in fact + # a valid letter. If not, it's an error. 
+ cl = lower_char(c) + if cl < LOWER_A or cl > LOWER_Z: + msg = "Found non-alphanumeric character %r in " \ + "header at %d" % (c, i) + self.logger.warning(msg) + e = MultipartParseError(msg) + e.offset = i + raise e + + elif state == STATE_HEADER_VALUE_START: + # Skip leading spaces. + if c == SPACE: + i += 1 + continue + + # Mark the start of the header value. + set_mark('header_value') + + # Move to the header-value state, reprocessing this character. + state = STATE_HEADER_VALUE + i -= 1 + + elif state == STATE_HEADER_VALUE: + # If we've got a CR, we're nearly done our headers. Otherwise, + # we do nothing and just move past this character. + if c == CR: + data_callback('header_value') + self.callback('header_end') + state = STATE_HEADER_VALUE_ALMOST_DONE + + elif state == STATE_HEADER_VALUE_ALMOST_DONE: + # The last character should be a LF. If not, it's an error. + if c != LF: + msg = "Did not find LF character at end of header " \ + "(found %r)" % (c,) + self.logger.warning(msg) + e = MultipartParseError(msg) + e.offset = i + raise e + + # Move back to the start of another header. Note that if that + # state detects ANOTHER newline, it'll trigger the end of our + # headers. + state = STATE_HEADER_FIELD_START + + elif state == STATE_HEADERS_ALMOST_DONE: + # We're almost done our headers. This is reached when we parse + # a CR at the beginning of a header, so our next character + # should be a LF, or it's an error. + if c != LF: + msg = f"Did not find LF at end of headers (found {c!r})" + self.logger.warning(msg) + e = MultipartParseError(msg) + e.offset = i + raise e + + self.callback('headers_finished') + state = STATE_PART_DATA_START + + elif state == STATE_PART_DATA_START: + # Mark the start of our part data. + set_mark('part_data') + + # Start processing part data, including this character. + state = STATE_PART_DATA + i -= 1 + + elif state == STATE_PART_DATA: + # We're processing our part data right now. During this, we + # need to efficiently search for our boundary, since any data + # on any number of lines can be a part of the current data. + # We use the Boyer-Moore-Horspool algorithm to efficiently + # search through the remainder of the buffer looking for our + # boundary. + + # Save the current value of our index. We use this in case we + # find part of a boundary, but it doesn't match fully. + prev_index = index + + # Set up variables. + boundary_length = len(boundary) + boundary_end = boundary_length - 1 + data_length = length + boundary_chars = self.boundary_chars + + # If our index is 0, we're starting a new part, so start our + # search. + if index == 0: + # Search forward until we either hit the end of our buffer, + # or reach a character that's in our boundary. + i += boundary_end + while i < data_length - 1 and data[i] not in boundary_chars: + i += boundary_length + + # Reset i back the length of our boundary, which is the + # earliest possible location that could be our match (i.e. + # if we've just broken out of our loop since we saw the + # last character in our boundary) + i -= boundary_end + c = data[i] + + # Now, we have a couple of cases here. If our index is before + # the end of the boundary... + if index < boundary_length: + # If the character matches... + if boundary[index] == c: + # If we found a match for our boundary, we send the + # existing data. + if index == 0: + data_callback('part_data') + + # The current character matches, so continue! + index += 1 + else: + index = 0 + + # Our index is equal to the length of our boundary! 
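+                # (Illustrative aside, inferred from the boundary handling in
+                # this parser: for a boundary of b"XYZ" the delimiter between
+                # parts arrives as b"\r\n--XYZ\r\n" and the closing delimiter
+                # as b"\r\n--XYZ--", which is why the byte that follows a full
+                # boundary match must be either a CR or a hyphen.)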
+ elif index == boundary_length: + # First we increment it. + index += 1 + + # Now, if we've reached a newline, we need to set this as + # the potential end of our boundary. + if c == CR: + flags |= FLAG_PART_BOUNDARY + + # Otherwise, if this is a hyphen, we might be at the last + # of all boundaries. + elif c == HYPHEN: + flags |= FLAG_LAST_BOUNDARY + + # Otherwise, we reset our index, since this isn't either a + # newline or a hyphen. + else: + index = 0 + + # Our index is right after the part boundary, which should be + # a LF. + elif index == boundary_length + 1: + # If we're at a part boundary (i.e. we've seen a CR + # character already)... + if flags & FLAG_PART_BOUNDARY: + # We need a LF character next. + if c == LF: + # Unset the part boundary flag. + flags &= (~FLAG_PART_BOUNDARY) + + # Callback indicating that we've reached the end of + # a part, and are starting a new one. + self.callback('part_end') + self.callback('part_begin') + + # Move to parsing new headers. + index = 0 + state = STATE_HEADER_FIELD_START + i += 1 + continue + + # We didn't find an LF character, so no match. Reset + # our index and clear our flag. + index = 0 + flags &= (~FLAG_PART_BOUNDARY) + + # Otherwise, if we're at the last boundary (i.e. we've + # seen a hyphen already)... + elif flags & FLAG_LAST_BOUNDARY: + # We need a second hyphen here. + if c == HYPHEN: + # Callback to end the current part, and then the + # message. + self.callback('part_end') + self.callback('end') + state = STATE_END + else: + # No match, so reset index. + index = 0 + + # If we have an index, we need to keep this byte for later, in + # case we can't match the full boundary. + if index > 0: + self.lookbehind[index - 1] = c + + # Otherwise, our index is 0. If the previous index is not, it + # means we reset something, and we need to take the data we + # thought was part of our boundary and send it along as actual + # data. + elif prev_index > 0: + # Callback to write the saved data. + lb_data = join_bytes(self.lookbehind) + self.callback('part_data', lb_data, 0, prev_index) + + # Overwrite our previous index. + prev_index = 0 + + # Re-set our mark for part data. + set_mark('part_data') + + # Re-consider the current character, since this could be + # the start of the boundary itself. + i -= 1 + + elif state == STATE_END: + # Do nothing and just consume a byte in the end state. + if c not in (CR, LF): + self.logger.warning("Consuming a byte '0x%x' in the end state", c) + + else: # pragma: no cover (error case) + # We got into a strange state somehow! Just stop processing. + msg = "Reached an unknown state %d at %d" % (state, i) + self.logger.warning(msg) + e = MultipartParseError(msg) + e.offset = i + raise e + + # Move to the next byte. + i += 1 + + # We call our callbacks with any remaining data. Note that we pass + # the 'remaining' flag, which sets the mark back to 0 instead of + # deleting it, if it's found. This is because, if the mark is found + # at this point, we assume that there's data for one of these things + # that has been parsed, but not yet emitted. And, as such, it implies + # that we haven't yet reached the end of this 'thing'. So, by setting + # the mark to 0, we cause any data callbacks that take place in future + # calls to this function to start from the beginning of that buffer. + data_callback('header_field', True) + data_callback('header_value', True) + data_callback('part_data', True) + + # Save values to locals. 
+ self.state = state + self.index = index + self.flags = flags + + # Return our data length to indicate no errors, and that we processed + # all of it. + return length + + def finalize(self): + """Finalize this parser, which signals to that we are finished parsing. + + Note: It does not currently, but in the future, it will verify that we + are in the final state of the parser (i.e. the end of the multipart + message is well-formed), and, if not, throw an error. + """ + # TODO: verify that we're in the state STATE_END, otherwise throw an + # error or otherwise state that we're not finished parsing. + pass + + def __repr__(self): + return f"{self.__class__.__name__}(boundary={self.boundary!r})" + + +class FormParser: + """This class is the all-in-one form parser. Given all the information + necessary to parse a form, it will instantiate the correct parser, create + the proper :class:`Field` and :class:`File` classes to store the data that + is parsed, and call the two given callbacks with each field and file as + they become available. + + :param content_type: The Content-Type of the incoming request. This is + used to select the appropriate parser. + + :param on_field: The callback to call when a field has been parsed and is + ready for usage. See above for parameters. + + :param on_file: The callback to call when a file has been parsed and is + ready for usage. See above for parameters. + + :param on_end: An optional callback to call when all fields and files in a + request has been parsed. Can be None. + + :param boundary: If the request is a multipart/form-data request, this + should be the boundary of the request, as given in the + Content-Type header, as a bytestring. + + :param file_name: If the request is of type application/octet-stream, then + the body of the request will not contain any information + about the uploaded file. In such cases, you can provide + the file name of the uploaded file manually. + + :param FileClass: The class to use for uploaded files. Defaults to + :class:`File`, but you can provide your own class if you + wish to customize behaviour. The class will be + instantiated as FileClass(file_name, field_name), and it + must provide the following functions:: + file_instance.write(data) + file_instance.finalize() + file_instance.close() + + :param FieldClass: The class to use for uploaded fields. Defaults to + :class:`Field`, but you can provide your own class if + you wish to customize behaviour. The class will be + instantiated as FieldClass(field_name), and it must + provide the following functions:: + field_instance.write(data) + field_instance.finalize() + field_instance.close() + + :param config: Configuration to use for this FormParser. The default + values are taken from the DEFAULT_CONFIG value, and then + any keys present in this dictionary will overwrite the + default values. + + """ + #: This is the default configuration for our form parser. + #: Note: all file sizes should be in bytes. + DEFAULT_CONFIG = { + 'MAX_BODY_SIZE': float('inf'), + 'MAX_MEMORY_FILE_SIZE': 1 * 1024 * 1024, + 'UPLOAD_DIR': None, + 'UPLOAD_KEEP_FILENAME': False, + 'UPLOAD_KEEP_EXTENSIONS': False, + + # Error on invalid Content-Transfer-Encoding? + 'UPLOAD_ERROR_ON_BAD_CTE': False, + } + + def __init__(self, content_type, on_field, on_file, on_end=None, + boundary=None, file_name=None, FileClass=File, + FieldClass=Field, config={}): + + self.logger = logging.getLogger(__name__) + + # Save variables. 
+ self.content_type = content_type + self.boundary = boundary + self.bytes_received = 0 + self.parser = None + + # Save callbacks. + self.on_field = on_field + self.on_file = on_file + self.on_end = on_end + + # Save classes. + self.FileClass = File + self.FieldClass = Field + + # Set configuration options. + self.config = self.DEFAULT_CONFIG.copy() + self.config.update(config) + + # Depending on the Content-Type, we instantiate the correct parser. + if content_type == 'application/octet-stream': + # Work around the lack of 'nonlocal' in Py2 + class vars: + f = None + + def on_start(): + vars.f = FileClass(file_name, None, config=self.config) + + def on_data(data, start, end): + vars.f.write(data[start:end]) + + def on_end(): + # Finalize the file itself. + vars.f.finalize() + + # Call our callback. + on_file(vars.f) + + # Call the on-end callback. + if self.on_end is not None: + self.on_end() + + callbacks = { + 'on_start': on_start, + 'on_data': on_data, + 'on_end': on_end, + } + + # Instantiate an octet-stream parser + parser = OctetStreamParser(callbacks, + max_size=self.config['MAX_BODY_SIZE']) + + elif (content_type == 'application/x-www-form-urlencoded' or + content_type == 'application/x-url-encoded'): + + name_buffer = [] + + class vars: + f = None + + def on_field_start(): + pass + + def on_field_name(data, start, end): + name_buffer.append(data[start:end]) + + def on_field_data(data, start, end): + if vars.f is None: + vars.f = FieldClass(b''.join(name_buffer)) + del name_buffer[:] + vars.f.write(data[start:end]) + + def on_field_end(): + # Finalize and call callback. + if vars.f is None: + # If we get here, it's because there was no field data. + # We create a field, set it to None, and then continue. + vars.f = FieldClass(b''.join(name_buffer)) + del name_buffer[:] + vars.f.set_none() + + vars.f.finalize() + on_field(vars.f) + vars.f = None + + def on_end(): + if self.on_end is not None: + self.on_end() + + # Setup callbacks. + callbacks = { + 'on_field_start': on_field_start, + 'on_field_name': on_field_name, + 'on_field_data': on_field_data, + 'on_field_end': on_field_end, + 'on_end': on_end, + } + + # Instantiate parser. + parser = QuerystringParser( + callbacks=callbacks, + max_size=self.config['MAX_BODY_SIZE'] + ) + + elif content_type == 'multipart/form-data': + if boundary is None: + self.logger.error("No boundary given") + raise FormParserError("No boundary given") + + header_name = [] + header_value = [] + headers = {} + + # No 'nonlocal' on Python 2 :-( + class vars: + f = None + writer = None + is_file = False + + def on_part_begin(): + pass + + def on_part_data(data, start, end): + bytes_processed = vars.writer.write(data[start:end]) + # TODO: check for error here. + return bytes_processed + + def on_part_end(): + vars.f.finalize() + if vars.is_file: + on_file(vars.f) + else: + on_field(vars.f) + + def on_header_field(data, start, end): + header_name.append(data[start:end]) + + def on_header_value(data, start, end): + header_value.append(data[start:end]) + + def on_header_end(): + headers[b''.join(header_name)] = b''.join(header_value) + del header_name[:] + del header_value[:] + + def on_headers_finished(): + # Reset the 'is file' flag. + vars.is_file = False + + # Parse the content-disposition header. + # TODO: handle mixed case + content_disp = headers.get(b'Content-Disposition') + disp, options = parse_options_header(content_disp) + + # Get the field and filename. 
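+                # Illustrative example (not from the original source): a header
+                # such as
+                #   Content-Disposition: form-data; name="file"; filename="t.txt"
+                # parses to disp == b'form-data' and
+                # options == {b'name': b'file', b'filename': b't.txt'}.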
+ field_name = options.get(b'name') + file_name = options.get(b'filename') + # TODO: check for errors + + # Create the proper class. + if file_name is None: + vars.f = FieldClass(field_name) + else: + vars.f = FileClass(file_name, field_name, config=self.config) + vars.is_file = True + + # Parse the given Content-Transfer-Encoding to determine what + # we need to do with the incoming data. + # TODO: check that we properly handle 8bit / 7bit encoding. + transfer_encoding = headers.get(b'Content-Transfer-Encoding', + b'7bit') + + if (transfer_encoding == b'binary' or + transfer_encoding == b'8bit' or + transfer_encoding == b'7bit'): + vars.writer = vars.f + + elif transfer_encoding == b'base64': + vars.writer = Base64Decoder(vars.f) + + elif transfer_encoding == b'quoted-printable': + vars.writer = QuotedPrintableDecoder(vars.f) + + else: + self.logger.warning("Unknown Content-Transfer-Encoding: " + "%r", transfer_encoding) + if self.config['UPLOAD_ERROR_ON_BAD_CTE']: + raise FormParserError( + 'Unknown Content-Transfer-Encoding "{}"'.format( + transfer_encoding + ) + ) + else: + # If we aren't erroring, then we just treat this as an + # unencoded Content-Transfer-Encoding. + vars.writer = vars.f + + def on_end(): + vars.writer.finalize() + if self.on_end is not None: + self.on_end() + + # These are our callbacks for the parser. + callbacks = { + 'on_part_begin': on_part_begin, + 'on_part_data': on_part_data, + 'on_part_end': on_part_end, + 'on_header_field': on_header_field, + 'on_header_value': on_header_value, + 'on_header_end': on_header_end, + 'on_headers_finished': on_headers_finished, + 'on_end': on_end, + } + + # Instantiate a multipart parser. + parser = MultipartParser(boundary, callbacks, + max_size=self.config['MAX_BODY_SIZE']) + + else: + self.logger.warning("Unknown Content-Type: %r", content_type) + raise FormParserError("Unknown Content-Type: {}".format( + content_type + )) + + self.parser = parser + + def write(self, data): + """Write some data. The parser will forward this to the appropriate + underlying parser. + + :param data: a bytestring + """ + self.bytes_received += len(data) + # TODO: check the parser's return value for errors? + return self.parser.write(data) + + def finalize(self): + """Finalize the parser.""" + if self.parser is not None and hasattr(self.parser, 'finalize'): + self.parser.finalize() + + def close(self): + """Close the parser.""" + if self.parser is not None and hasattr(self.parser, 'close'): + self.parser.close() + + def __repr__(self): + return "{}(content_type={!r}, parser={!r})".format( + self.__class__.__name__, + self.content_type, + self.parser, + ) + + +def create_form_parser(headers, on_field, on_file, trust_x_headers=False, + config={}): + """This function is a helper function to aid in creating a FormParser + instances. Given a dictionary-like headers object, it will determine + the correct information needed, instantiate a FormParser with the + appropriate values and given callbacks, and then return the corresponding + parser. + + :param headers: A dictionary-like object of HTTP headers. The only + required header is Content-Type. + + :param on_field: Callback to call with each parsed field. + + :param on_file: Callback to call with each parsed file. + + :param trust_x_headers: Whether or not to trust information received from + certain X-Headers - for example, the file name from + X-File-Name. + + :param config: Configuration variables to pass to the FormParser. 
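+
+    An illustrative usage sketch (hedged example; assumes a plain dict of
+    headers and the urlencoded content type shown)::
+
+        headers = {'Content-Type': 'application/x-www-form-urlencoded'}
+        parser = create_form_parser(
+            headers,
+            on_field=lambda field: print(field.field_name),
+            on_file=lambda file: None,
+        )
+        parser.write(b'name=value')
+        parser.finalize()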
+ """ + content_type = headers.get('Content-Type') + if content_type is None: + logging.getLogger(__name__).warning("No Content-Type header given") + raise ValueError("No Content-Type header given!") + + # Boundaries are optional (the FormParser will raise if one is needed + # but not given). + content_type, params = parse_options_header(content_type) + boundary = params.get(b'boundary') + + # We need content_type to be a string, not a bytes object. + content_type = content_type.decode('latin-1') + + # File names are optional. + file_name = headers.get('X-File-Name') + + # Instantiate a form parser. + form_parser = FormParser(content_type, + on_field, + on_file, + boundary=boundary, + file_name=file_name, + config=config) + + # Return our parser. + return form_parser + + +def parse_form(headers, input_stream, on_field, on_file, chunk_size=1048576, + **kwargs): + """This function is useful if you just want to parse a request body, + without too much work. Pass it a dictionary-like object of the request's + headers, and a file-like object for the input stream, along with two + callbacks that will get called whenever a field or file is parsed. + + :param headers: A dictionary-like object of HTTP headers. The only + required header is Content-Type. + + :param input_stream: A file-like object that represents the request body. + The read() method must return bytestrings. + + :param on_field: Callback to call with each parsed field. + + :param on_file: Callback to call with each parsed file. + + :param chunk_size: The maximum size to read from the input stream and write + to the parser at one time. Defaults to 1 MiB. + """ + + # Create our form parser. + parser = create_form_parser(headers, on_field, on_file) + + # Read chunks of 100KiB and write to the parser, but never read more than + # the given Content-Length, if any. + content_length = headers.get('Content-Length') + if content_length is not None: + content_length = int(content_length) + else: + content_length = float('inf') + bytes_read = 0 + + while True: + # Read only up to the Content-Length given. + max_readable = min(content_length - bytes_read, 1048576) + buff = input_stream.read(max_readable) + + # Write to the parser and update our length. + parser.write(buff) + bytes_read += len(buff) + + # If we get a buffer that's smaller than the size requested, or if we + # have read up to our content length, we're done. + if len(buff) != max_readable or bytes_read == content_length: + break + + # Tell our parser that we're done writing data. + parser.finalize() diff --git a/venv/lib/python3.12/site-packages/kafka/benchmarks/__init__.py b/venv/lib/python3.12/site-packages/multipart/tests/__init__.py similarity index 100% rename from venv/lib/python3.12/site-packages/kafka/benchmarks/__init__.py rename to venv/lib/python3.12/site-packages/multipart/tests/__init__.py diff --git a/venv/lib/python3.12/site-packages/multipart/tests/compat.py b/venv/lib/python3.12/site-packages/multipart/tests/compat.py new file mode 100644 index 0000000..897188d --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/compat.py @@ -0,0 +1,133 @@ +import os +import re +import sys +import types +import functools + + +def ensure_in_path(path): + """ + Ensure that a given path is in the sys.path array + """ + if not os.path.isdir(path): + raise RuntimeError('Tried to add nonexisting path') + + def _samefile(x, y): + try: + return os.path.samefile(x, y) + except OSError: + return False + except AttributeError: + # Probably on Windows. 
+ path1 = os.path.abspath(x).lower() + path2 = os.path.abspath(y).lower() + return path1 == path2 + + # Remove existing copies of it. + for pth in sys.path: + if _samefile(pth, path): + sys.path.remove(pth) + + # Add it at the beginning. + sys.path.insert(0, path) + + +# Check if pytest is imported. If so, we use it to create marking decorators. +# If not, we just create a function that does nothing. +try: + import pytest +except ImportError: + pytest = None + +if pytest is not None: + slow_test = pytest.mark.slow_test + xfail = pytest.mark.xfail + +else: + slow_test = lambda x: x + + def xfail(*args, **kwargs): + if len(args) > 0 and isinstance(args[0], types.FunctionType): + return args[0] + + return lambda x: x + + +# We don't use the pytest parametrizing function, since it seems to break +# with unittest.TestCase subclasses. +def parametrize(field_names, field_values): + # If we're not given a list of field names, we make it. + if not isinstance(field_names, (tuple, list)): + field_names = (field_names,) + field_values = [(val,) for val in field_values] + + # Create a decorator that saves this list of field names and values on the + # function for later parametrizing. + def decorator(func): + func.__dict__['param_names'] = field_names + func.__dict__['param_values'] = field_values + return func + + return decorator + + +# This is a metaclass that actually performs the parametrization. +class ParametrizingMetaclass(type): + IDENTIFIER_RE = re.compile('[^A-Za-z0-9]') + + def __new__(klass, name, bases, attrs): + new_attrs = attrs.copy() + for attr_name, attr in attrs.items(): + # We only care about functions + if not isinstance(attr, types.FunctionType): + continue + + param_names = attr.__dict__.pop('param_names', None) + param_values = attr.__dict__.pop('param_values', None) + if param_names is None or param_values is None: + continue + + # Create multiple copies of the function. + for i, values in enumerate(param_values): + assert len(param_names) == len(values) + + # Get a repr of the values, and fix it to be a valid identifier + human = '_'.join( + [klass.IDENTIFIER_RE.sub('', repr(x)) for x in values] + ) + + # Create a new name. + # new_name = attr.__name__ + "_%d" % i + new_name = attr.__name__ + "__" + human + + # Create a replacement function. + def create_new_func(func, names, values): + # Create a kwargs dictionary. + kwargs = dict(zip(names, values)) + + @functools.wraps(func) + def new_func(self): + return func(self, **kwargs) + + # Manually set the name and return the new function. + new_func.__name__ = new_name + return new_func + + # Actually create the new function. + new_func = create_new_func(attr, param_names, values) + + # Save this new function in our attrs dict. + new_attrs[new_name] = new_func + + # Remove the old attribute from our new dictionary. + del new_attrs[attr_name] + + # We create the class as normal, except we use our new attributes. + return type.__new__(klass, name, bases, new_attrs) + + +# This is a class decorator that actually applies the above metaclass. 
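+#
+# An illustrative sketch of how parametrize() and parametrize_class() combine
+# (derived from the code above, not from the original comments):
+#
+#     @parametrize_class
+#     class TestExample(unittest.TestCase):
+#         @parametrize('value', [1, 2])
+#         def test_double(self, value):
+#             self.assertEqual(value * 2, value + value)
+#
+# The metaclass replaces test_double with test_double__1 and test_double__2,
+# one generated test per parameter value.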
+def parametrize_class(klass): + return ParametrizingMetaclass(klass.__name__, + klass.__bases__, + klass.__dict__) diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/CR_in_header.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/CR_in_header.http new file mode 100644 index 0000000..0c81dae --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/CR_in_header.http @@ -0,0 +1,5 @@ +------WebKitFormBoundaryTkr3kCBQlBe1nrhc +Content- isposition: form-data; name="field" + +This is a test. +------WebKitFormBoundaryTkr3kCBQlBe1nrhc-- \ No newline at end of file diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/CR_in_header.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/CR_in_header.yaml new file mode 100644 index 0000000..c9b55f2 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/CR_in_header.yaml @@ -0,0 +1,3 @@ +boundary: ----WebKitFormBoundaryTkr3kCBQlBe1nrhc +expected: + error: 51 diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/CR_in_header_value.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/CR_in_header_value.http new file mode 100644 index 0000000..f3dc834 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/CR_in_header_value.http @@ -0,0 +1,5 @@ +------WebKitFormBoundaryTkr3kCBQlBe1nrhc +Content-Disposition: form-data; n me="field" + +This is a test. +------WebKitFormBoundaryTkr3kCBQlBe1nrhc-- \ No newline at end of file diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/CR_in_header_value.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/CR_in_header_value.yaml new file mode 100644 index 0000000..a6efa7d --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/CR_in_header_value.yaml @@ -0,0 +1,3 @@ +boundary: ----WebKitFormBoundaryTkr3kCBQlBe1nrhc +expected: + error: 76 diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary.http new file mode 100644 index 0000000..7d97e51 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary.http @@ -0,0 +1,13 @@ +----boundary +Content-Disposition: form-data; name="file"; filename="test.txt" +Content-Type: text/plain + +--boundari +--boundaryq--boundary q--boundarq +--bounaryd-- +--notbound-- +--mismatch +--mismatch-- +--boundary-Q +--boundary Q--boundaryQ +----boundary-- diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary.yaml new file mode 100644 index 0000000..235493e --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary.yaml @@ -0,0 +1,8 @@ +boundary: --boundary +expected: + - name: file + type: file + file_name: test.txt + data: !!binary | + LS1ib3VuZGFyaQ0KLS1ib3VuZGFyeXEtLWJvdW5kYXJ5DXEtLWJvdW5kYXJxDQotLWJvdW5hcnlkLS0NCi0tbm90Ym91bmQtLQ0KLS1taXNtYXRjaA0KLS1taXNtYXRjaC0tDQotLWJvdW5kYXJ5LVENCi0tYm91bmRhcnkNUS0tYm91bmRhcnlR + diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_CR.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_CR.http new 
file mode 100644 index 0000000..327cc9b --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_CR.http @@ -0,0 +1,6 @@ +----boundary +Content-Disposition: form-data; name="field" + +QQQQQQQQQQQQQQQQQQQQ +----boundaryQQQQQQQQQQQQQQQQQQQQ +----boundary-- \ No newline at end of file diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_CR.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_CR.yaml new file mode 100644 index 0000000..921637f --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_CR.yaml @@ -0,0 +1,8 @@ +boundary: --boundary +expected: + - name: field + type: field + data: !!binary | + UVFRUVFRUVFRUVFRUVFRUVFRUVENCi0tLS1ib3VuZGFyeVFRUVFRUVFRUVFRUVFRUVFRUVFR + + diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_LF.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_LF.http new file mode 100644 index 0000000..e9a5a6c --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_LF.http @@ -0,0 +1,6 @@ +----boundary +Content-Disposition: form-data; name="field" + +QQQQQQQQQQQQQQQQQQQQ +----boundary QQQQQQQQQQQQQQQQQQQQ +----boundary-- \ No newline at end of file diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_LF.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_LF.yaml new file mode 100644 index 0000000..7346e03 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_LF.yaml @@ -0,0 +1,8 @@ +boundary: --boundary +expected: + - name: field + type: field + data: !!binary | + UVFRUVFRUVFRUVFRUVFRUVFRUVENCi0tLS1ib3VuZGFyeQ1RUVFRUVFRUVFRUVFRUVFRUVFRUQ== + + diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_final_hyphen.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_final_hyphen.http new file mode 100644 index 0000000..9261f1b --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_final_hyphen.http @@ -0,0 +1,6 @@ +----boundary +Content-Disposition: form-data; name="field" + +QQQQQQQQQQQQQQQQQQQQ +----boundary-QQQQQQQQQQQQQQQQQQQQ +----boundary-- \ No newline at end of file diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_final_hyphen.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_final_hyphen.yaml new file mode 100644 index 0000000..17133c9 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/almost_match_boundary_without_final_hyphen.yaml @@ -0,0 +1,8 @@ +boundary: --boundary +expected: + - name: field + type: field + data: !!binary | + UVFRUVFRUVFRUVFRUVFRUVFRUVENCi0tLS1ib3VuZGFyeS1RUVFRUVFRUVFRUVFRUVFRUVFRUQ== + + diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_end_of_headers.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_end_of_headers.http new file mode 100644 index 0000000..de14ae1 --- /dev/null +++ 
b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_end_of_headers.http @@ -0,0 +1,4 @@ +------WebKitFormBoundaryTkr3kCBQlBe1nrhc +Content-Disposition: form-data; name="field" + QThis is a test. +------WebKitFormBoundaryTkr3kCBQlBe1nrhc-- \ No newline at end of file diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_end_of_headers.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_end_of_headers.yaml new file mode 100644 index 0000000..5fc1ec0 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_end_of_headers.yaml @@ -0,0 +1,3 @@ +boundary: ----WebKitFormBoundaryTkr3kCBQlBe1nrhc +expected: + error: 89 diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_header_char.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_header_char.http new file mode 100644 index 0000000..b90d00d --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_header_char.http @@ -0,0 +1,5 @@ +------WebKitFormBoundaryTkr3kCBQlBe1nrhc +Content-999position: form-data; name="field" + +This is a test. +------WebKitFormBoundaryTkr3kCBQlBe1nrhc-- \ No newline at end of file diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_header_char.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_header_char.yaml new file mode 100644 index 0000000..9d5f62a --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_header_char.yaml @@ -0,0 +1,3 @@ +boundary: ----WebKitFormBoundaryTkr3kCBQlBe1nrhc +expected: + error: 50 diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_initial_boundary.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_initial_boundary.http new file mode 100644 index 0000000..6aab9da --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_initial_boundary.http @@ -0,0 +1,5 @@ +------WebQitFormBoundaryTkr3kCBQlBe1nrhc +Content-Disposition: form-data; name="field" + +This is a test. 
+------WebKitFormBoundaryTkr3kCBQlBe1nrhc-- \ No newline at end of file diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_initial_boundary.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_initial_boundary.yaml new file mode 100644 index 0000000..ffa4eb7 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/bad_initial_boundary.yaml @@ -0,0 +1,3 @@ +boundary: ----WebKitFormBoundaryTkr3kCBQlBe1nrhc +expected: + error: 9 diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/base64_encoding.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/base64_encoding.http new file mode 100644 index 0000000..3d2980f --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/base64_encoding.http @@ -0,0 +1,7 @@ +----boundary +Content-Disposition: form-data; name="file"; filename="test.txt" +Content-Type: text/plain +Content-Transfer-Encoding: base64 + +VGVzdCAxMjM= +----boundary-- diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/base64_encoding.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/base64_encoding.yaml new file mode 100644 index 0000000..1033150 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/base64_encoding.yaml @@ -0,0 +1,7 @@ +boundary: --boundary +expected: + - name: file + type: file + file_name: test.txt + data: !!binary | + VGVzdCAxMjM= diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/empty_header.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/empty_header.http new file mode 100644 index 0000000..bd593f4 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/empty_header.http @@ -0,0 +1,5 @@ +------WebKitFormBoundaryTkr3kCBQlBe1nrhc +: form-data; name="field" + +This is a test. 
+------WebKitFormBoundaryTkr3kCBQlBe1nrhc-- \ No newline at end of file diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/empty_header.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/empty_header.yaml new file mode 100644 index 0000000..574ed4c --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/empty_header.yaml @@ -0,0 +1,3 @@ +boundary: ----WebKitFormBoundaryTkr3kCBQlBe1nrhc +expected: + error: 42 diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/multiple_fields.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/multiple_fields.http new file mode 100644 index 0000000..4f13037 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/multiple_fields.http @@ -0,0 +1,9 @@ +------WebKitFormBoundaryTkr3kCBQlBe1nrhc +Content-Disposition: form-data; name="field1" + +field1 +------WebKitFormBoundaryTkr3kCBQlBe1nrhc +Content-Disposition: form-data; name="field2" + +field2 +------WebKitFormBoundaryTkr3kCBQlBe1nrhc-- diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/multiple_fields.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/multiple_fields.yaml new file mode 100644 index 0000000..cb2c2d6 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/multiple_fields.yaml @@ -0,0 +1,10 @@ +boundary: ----WebKitFormBoundaryTkr3kCBQlBe1nrhc +expected: + - name: field1 + type: field + data: !!binary | + ZmllbGQx + - name: field2 + type: field + data: !!binary | + ZmllbGQy diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/multiple_files.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/multiple_files.http new file mode 100644 index 0000000..fd2e468 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/multiple_files.http @@ -0,0 +1,11 @@ +------WebKitFormBoundarygbACTUR58IyeurVf +Content-Disposition: form-data; name="file1"; filename="test1.txt" +Content-Type: text/plain + +Test file #1 +------WebKitFormBoundarygbACTUR58IyeurVf +Content-Disposition: form-data; name="file2"; filename="test2.txt" +Content-Type: text/plain + +Test file #2 +------WebKitFormBoundarygbACTUR58IyeurVf-- diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/multiple_files.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/multiple_files.yaml new file mode 100644 index 0000000..3bf70e2 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/multiple_files.yaml @@ -0,0 +1,13 @@ +boundary: ----WebKitFormBoundarygbACTUR58IyeurVf +expected: + - name: file1 + type: file + file_name: test1.txt + data: !!binary | + VGVzdCBmaWxlICMx + - name: file2 + type: file + file_name: test2.txt + data: !!binary | + VGVzdCBmaWxlICMy + diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/quoted_printable_encoding.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/quoted_printable_encoding.http new file mode 100644 index 0000000..09e555a --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/quoted_printable_encoding.http @@ -0,0 +1,7 @@ +----boundary +Content-Disposition: form-data; name="file"; filename="test.txt" +Content-Type: text/plain +Content-Transfer-Encoding: quoted-printable + +foo=3Dbar +----boundary-- diff --git 
a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/quoted_printable_encoding.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/quoted_printable_encoding.yaml new file mode 100644 index 0000000..2c6bbfb --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/quoted_printable_encoding.yaml @@ -0,0 +1,7 @@ +boundary: --boundary +expected: + - name: file + type: file + file_name: test.txt + data: !!binary | + Zm9vPWJhcg== diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field.http new file mode 100644 index 0000000..8b90b73 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field.http @@ -0,0 +1,5 @@ +------WebKitFormBoundaryTkr3kCBQlBe1nrhc +Content-Disposition: form-data; name="field" + +This is a test. +------WebKitFormBoundaryTkr3kCBQlBe1nrhc-- \ No newline at end of file diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field.yaml new file mode 100644 index 0000000..7690f08 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field.yaml @@ -0,0 +1,6 @@ +boundary: ----WebKitFormBoundaryTkr3kCBQlBe1nrhc +expected: + - name: field + type: field + data: !!binary | + VGhpcyBpcyBhIHRlc3Qu diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_blocks.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_blocks.http new file mode 100644 index 0000000..5a61d83 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_blocks.http @@ -0,0 +1,5 @@ +--boundary +Content-Disposition: form-data; name="field" + +0123456789ABCDEFGHIJ0123456789ABCDEFGHIJ +--boundary-- \ No newline at end of file diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_blocks.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_blocks.yaml new file mode 100644 index 0000000..efb1b32 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_blocks.yaml @@ -0,0 +1,6 @@ +boundary: --boundary +expected: + - name: field + type: field + data: !!binary | + MDEyMzQ1Njc4OUFCQ0RFRkdISUowMTIzNDU2Nzg5QUJDREVGR0hJSg== diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_longer.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_longer.http new file mode 100644 index 0000000..46bd7e1 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_longer.http @@ -0,0 +1,5 @@ +------WebKitFormBoundaryTkr3kCBQlBe1nrhc +Content-Disposition: form-data; name="field" + +qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq +------WebKitFormBoundaryTkr3kCBQlBe1nrhc-- \ No newline at end of file diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_longer.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_longer.yaml new file mode 100644 index 0000000..5a11840 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_longer.yaml @@ -0,0 +1,6 @@ +boundary: 
----WebKitFormBoundaryTkr3kCBQlBe1nrhc +expected: + - name: field + type: field + data: !!binary | + cXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXFxcXE= diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_single_file.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_single_file.http new file mode 100644 index 0000000..34a822b --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_single_file.http @@ -0,0 +1,10 @@ +--boundary +Content-Disposition: form-data; name="field" + +test1 +--boundary +Content-Disposition: form-data; name="file"; filename="file.txt" +Content-Type: text/plain + +test2 +--boundary-- \ No newline at end of file diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_single_file.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_single_file.yaml new file mode 100644 index 0000000..47c8d6e --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_single_file.yaml @@ -0,0 +1,13 @@ +boundary: boundary +expected: + - name: field + type: field + data: !!binary | + dGVzdDE= + - name: file + type: file + file_name: file.txt + data: !!binary | + dGVzdDI= + + diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_with_leading_newlines.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_with_leading_newlines.http new file mode 100644 index 0000000..10ebc2e --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_with_leading_newlines.http @@ -0,0 +1,7 @@ + + +------WebKitFormBoundaryTkr3kCBQlBe1nrhc +Content-Disposition: form-data; name="field" + +This is a test. +------WebKitFormBoundaryTkr3kCBQlBe1nrhc-- \ No newline at end of file diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_with_leading_newlines.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_with_leading_newlines.yaml new file mode 100644 index 0000000..7690f08 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_field_with_leading_newlines.yaml @@ -0,0 +1,6 @@ +boundary: ----WebKitFormBoundaryTkr3kCBQlBe1nrhc +expected: + - name: field + type: field + data: !!binary | + VGhpcyBpcyBhIHRlc3Qu diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_file.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_file.http new file mode 100644 index 0000000..104bfd0 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_file.http @@ -0,0 +1,6 @@ +------WebKitFormBoundary5BZGOJCWtXGYC9HW +Content-Disposition: form-data; name="file"; filename="test.txt" +Content-Type: text/plain + +This is a test file. 
+------WebKitFormBoundary5BZGOJCWtXGYC9HW-- diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_file.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_file.yaml new file mode 100644 index 0000000..2a8e005 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/single_file.yaml @@ -0,0 +1,8 @@ +boundary: ----WebKitFormBoundary5BZGOJCWtXGYC9HW +expected: + - name: file + type: file + file_name: test.txt + data: !!binary | + VGhpcyBpcyBhIHRlc3QgZmlsZS4= + diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/utf8_filename.http b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/utf8_filename.http new file mode 100644 index 0000000..c26df08 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/utf8_filename.http @@ -0,0 +1,6 @@ +------WebKitFormBoundaryI9SCEFp2lpx5DR2K +Content-Disposition: form-data; name="file"; filename="???.txt" +Content-Type: text/plain + +これはテストです。 +------WebKitFormBoundaryI9SCEFp2lpx5DR2K-- diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/utf8_filename.yaml b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/utf8_filename.yaml new file mode 100644 index 0000000..507ba2c --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_data/http/utf8_filename.yaml @@ -0,0 +1,8 @@ +boundary: ----WebKitFormBoundaryI9SCEFp2lpx5DR2K +expected: + - name: file + type: file + file_name: ???.txt + data: !!binary | + 44GT44KM44Gv44OG44K544OI44Gn44GZ44CC + diff --git a/venv/lib/python3.12/site-packages/multipart/tests/test_multipart.py b/venv/lib/python3.12/site-packages/multipart/tests/test_multipart.py new file mode 100644 index 0000000..089f451 --- /dev/null +++ b/venv/lib/python3.12/site-packages/multipart/tests/test_multipart.py @@ -0,0 +1,1305 @@ +import os +import sys +import glob +import yaml +import base64 +import random +import tempfile +import unittest +from .compat import ( + parametrize, + parametrize_class, + slow_test, +) +from io import BytesIO +from unittest.mock import MagicMock, Mock, patch + +from ..multipart import * + + +# Get the current directory for our later test cases. 
+curr_dir = os.path.abspath(os.path.dirname(__file__)) + + +def force_bytes(val): + if isinstance(val, str): + val = val.encode(sys.getfilesystemencoding()) + + return val + + +class TestField(unittest.TestCase): + def setUp(self): + self.f = Field('foo') + + def test_name(self): + self.assertEqual(self.f.field_name, 'foo') + + def test_data(self): + self.f.write(b'test123') + self.assertEqual(self.f.value, b'test123') + + def test_cache_expiration(self): + self.f.write(b'test') + self.assertEqual(self.f.value, b'test') + self.f.write(b'123') + self.assertEqual(self.f.value, b'test123') + + def test_finalize(self): + self.f.write(b'test123') + self.f.finalize() + self.assertEqual(self.f.value, b'test123') + + def test_close(self): + self.f.write(b'test123') + self.f.close() + self.assertEqual(self.f.value, b'test123') + + def test_from_value(self): + f = Field.from_value(b'name', b'value') + self.assertEqual(f.field_name, b'name') + self.assertEqual(f.value, b'value') + + f2 = Field.from_value(b'name', None) + self.assertEqual(f2.value, None) + + def test_equality(self): + f1 = Field.from_value(b'name', b'value') + f2 = Field.from_value(b'name', b'value') + + self.assertEqual(f1, f2) + + def test_equality_with_other(self): + f = Field.from_value(b'foo', b'bar') + self.assertFalse(f == b'foo') + self.assertFalse(b'foo' == f) + + def test_set_none(self): + f = Field(b'foo') + self.assertEqual(f.value, b'') + + f.set_none() + self.assertEqual(f.value, None) + + +class TestFile(unittest.TestCase): + def setUp(self): + self.c = {} + self.d = force_bytes(tempfile.mkdtemp()) + self.f = File(b'foo.txt', config=self.c) + + def assert_data(self, data): + f = self.f.file_object + f.seek(0) + self.assertEqual(f.read(), data) + f.seek(0) + f.truncate() + + def assert_exists(self): + full_path = os.path.join(self.d, self.f.actual_file_name) + self.assertTrue(os.path.exists(full_path)) + + def test_simple(self): + self.f.write(b'foobar') + self.assert_data(b'foobar') + + def test_invalid_write(self): + m = Mock() + m.write.return_value = 5 + self.f._fileobj = m + v = self.f.write(b'foobar') + self.assertEqual(v, 5) + + def test_file_fallback(self): + self.c['MAX_MEMORY_FILE_SIZE'] = 1 + + self.f.write(b'1') + self.assertTrue(self.f.in_memory) + self.assert_data(b'1') + + self.f.write(b'123') + self.assertFalse(self.f.in_memory) + self.assert_data(b'123') + + # Test flushing too. + old_obj = self.f.file_object + self.f.flush_to_disk() + self.assertFalse(self.f.in_memory) + self.assertIs(self.f.file_object, old_obj) + + def test_file_fallback_with_data(self): + self.c['MAX_MEMORY_FILE_SIZE'] = 10 + + self.f.write(b'1' * 10) + self.assertTrue(self.f.in_memory) + + self.f.write(b'2' * 10) + self.assertFalse(self.f.in_memory) + + self.assert_data(b'11111111112222222222') + + def test_file_name(self): + # Write to this dir. + self.c['UPLOAD_DIR'] = self.d + self.c['MAX_MEMORY_FILE_SIZE'] = 10 + + # Write. + self.f.write(b'12345678901') + self.assertFalse(self.f.in_memory) + + # Assert that the file exists + self.assertIsNotNone(self.f.actual_file_name) + self.assert_exists() + + def test_file_full_name(self): + # Write to this dir. + self.c['UPLOAD_DIR'] = self.d + self.c['UPLOAD_KEEP_FILENAME'] = True + self.c['MAX_MEMORY_FILE_SIZE'] = 10 + + # Write. 
+ self.f.write(b'12345678901') + self.assertFalse(self.f.in_memory) + + # Assert that the file exists + self.assertEqual(self.f.actual_file_name, b'foo') + self.assert_exists() + + def test_file_full_name_with_ext(self): + self.c['UPLOAD_DIR'] = self.d + self.c['UPLOAD_KEEP_FILENAME'] = True + self.c['UPLOAD_KEEP_EXTENSIONS'] = True + self.c['MAX_MEMORY_FILE_SIZE'] = 10 + + # Write. + self.f.write(b'12345678901') + self.assertFalse(self.f.in_memory) + + # Assert that the file exists + self.assertEqual(self.f.actual_file_name, b'foo.txt') + self.assert_exists() + + def test_no_dir_with_extension(self): + self.c['UPLOAD_KEEP_EXTENSIONS'] = True + self.c['MAX_MEMORY_FILE_SIZE'] = 10 + + # Write. + self.f.write(b'12345678901') + self.assertFalse(self.f.in_memory) + + # Assert that the file exists + ext = os.path.splitext(self.f.actual_file_name)[1] + self.assertEqual(ext, b'.txt') + self.assert_exists() + + def test_invalid_dir_with_name(self): + # Write to this dir. + self.c['UPLOAD_DIR'] = force_bytes(os.path.join('/', 'tmp', 'notexisting')) + self.c['UPLOAD_KEEP_FILENAME'] = True + self.c['MAX_MEMORY_FILE_SIZE'] = 5 + + # Write. + with self.assertRaises(FileError): + self.f.write(b'1234567890') + + def test_invalid_dir_no_name(self): + # Write to this dir. + self.c['UPLOAD_DIR'] = force_bytes(os.path.join('/', 'tmp', 'notexisting')) + self.c['UPLOAD_KEEP_FILENAME'] = False + self.c['MAX_MEMORY_FILE_SIZE'] = 5 + + # Write. + with self.assertRaises(FileError): + self.f.write(b'1234567890') + + # TODO: test uploading two files with the same name.
+ + +class TestParseOptionsHeader(unittest.TestCase): + def test_simple(self): + t, p = parse_options_header('application/json') + self.assertEqual(t, b'application/json') + self.assertEqual(p, {}) + + def test_blank(self): + t, p = parse_options_header('') + self.assertEqual(t, b'') + self.assertEqual(p, {}) + + def test_single_param(self): + t, p = parse_options_header('application/json;par=val') + self.assertEqual(t, b'application/json') + self.assertEqual(p, {b'par': b'val'}) + + def test_single_param_with_spaces(self): + t, p = parse_options_header(b'application/json; par=val') + self.assertEqual(t, b'application/json') + self.assertEqual(p, {b'par': b'val'}) + + def test_multiple_params(self): + t, p = parse_options_header(b'application/json;par=val;asdf=foo') + self.assertEqual(t, b'application/json') + self.assertEqual(p, {b'par': b'val', b'asdf': b'foo'}) + + def test_quoted_param(self): + t, p = parse_options_header(b'application/json;param="quoted"') + self.assertEqual(t, b'application/json') + self.assertEqual(p, {b'param': b'quoted'}) + + def test_quoted_param_with_semicolon(self): + t, p = parse_options_header(b'application/json;param="quoted;with;semicolons"') + self.assertEqual(p[b'param'], b'quoted;with;semicolons') + + def test_quoted_param_with_escapes(self): + t, p = parse_options_header(b'application/json;param="This \\" is \\" a \\" quote"') + self.assertEqual(p[b'param'], b'This " is " a " quote') + + def test_handles_ie6_bug(self): + t, p = parse_options_header(b'text/plain; filename="C:\\this\\is\\a\\path\\file.txt"') + + self.assertEqual(p[b'filename'], b'file.txt') + + +class TestBaseParser(unittest.TestCase): + def setUp(self): + self.b = BaseParser() + self.b.callbacks = {} + + def test_callbacks(self): + # The stupid list-ness is to get around lack of nonlocal on py2 + l = [0] + def on_foo(): + l[0] += 1 + + self.b.set_callback('foo', on_foo) + self.b.callback('foo') + self.assertEqual(l[0], 1) + + self.b.set_callback('foo', None) + self.b.callback('foo') + self.assertEqual(l[0], 1) + + +class TestQuerystringParser(unittest.TestCase): + def assert_fields(self, *args, **kwargs): + if kwargs.pop('finalize', True): + self.p.finalize() + + self.assertEqual(self.f, list(args)) + if kwargs.get('reset', True): + self.f = [] + + def setUp(self): + self.reset() + + def reset(self): + self.f = [] + + name_buffer = [] + data_buffer = [] + + def on_field_name(data, start, end): + name_buffer.append(data[start:end]) + + def on_field_data(data, start, end): + data_buffer.append(data[start:end]) + + def on_field_end(): + self.f.append(( + b''.join(name_buffer), + b''.join(data_buffer) + )) + + del name_buffer[:] + del data_buffer[:] + + callbacks = { + 'on_field_name': on_field_name, + 'on_field_data': on_field_data, + 'on_field_end': on_field_end + } + + self.p = QuerystringParser(callbacks) + + def test_simple_querystring(self): + self.p.write(b'foo=bar') + + self.assert_fields((b'foo', b'bar')) + + def test_querystring_blank_beginning(self): + self.p.write(b'&foo=bar') + + self.assert_fields((b'foo', b'bar')) + + def test_querystring_blank_end(self): + self.p.write(b'foo=bar&') + + self.assert_fields((b'foo', b'bar')) + + def test_multiple_querystring(self): + self.p.write(b'foo=bar&asdf=baz') + + self.assert_fields( + (b'foo', b'bar'), + (b'asdf', b'baz') + ) + + def test_streaming_simple(self): + self.p.write(b'foo=bar&') + self.assert_fields( + (b'foo', b'bar'), + finalize=False + ) + + self.p.write(b'asdf=baz') + self.assert_fields( + (b'asdf', b'baz') + ) + + def 
test_streaming_break(self): + self.p.write(b'foo=one') + self.assert_fields(finalize=False) + + self.p.write(b'two') + self.assert_fields(finalize=False) + + self.p.write(b'three') + self.assert_fields(finalize=False) + + self.p.write(b'&asd') + self.assert_fields( + (b'foo', b'onetwothree'), + finalize=False + ) + + self.p.write(b'f=baz') + self.assert_fields( + (b'asdf', b'baz') + ) + + def test_semicolon_separator(self): + self.p.write(b'foo=bar;asdf=baz') + + self.assert_fields( + (b'foo', b'bar'), + (b'asdf', b'baz') + ) + + def test_too_large_field(self): + self.p.max_size = 15 + + # Note: len = 8 + self.p.write(b"foo=bar&") + self.assert_fields((b'foo', b'bar'), finalize=False) + + # Note: len = 8, only 7 bytes processed + self.p.write(b'a=123456') + self.assert_fields((b'a', b'12345')) + + def test_invalid_max_size(self): + with self.assertRaises(ValueError): + p = QuerystringParser(max_size=-100) + + def test_strict_parsing_pass(self): + data = b'foo=bar&another=asdf' + for first, last in split_all(data): + self.reset() + self.p.strict_parsing = True + + print(f"{first!r} / {last!r}") + + self.p.write(first) + self.p.write(last) + self.assert_fields((b'foo', b'bar'), (b'another', b'asdf')) + + def test_strict_parsing_fail_double_sep(self): + data = b'foo=bar&&another=asdf' + for first, last in split_all(data): + self.reset() + self.p.strict_parsing = True + + cnt = 0 + with self.assertRaises(QuerystringParseError) as cm: + cnt += self.p.write(first) + cnt += self.p.write(last) + self.p.finalize() + + # The offset should occur at 8 bytes into the data (as a whole), + # so we calculate the offset into the chunk. + if cm is not None: + self.assertEqual(cm.exception.offset, 8 - cnt) + + def test_double_sep(self): + data = b'foo=bar&&another=asdf' + for first, last in split_all(data): + print(f" {first!r} / {last!r} ") + self.reset() + + cnt = 0 + cnt += self.p.write(first) + cnt += self.p.write(last) + + self.assert_fields((b'foo', b'bar'), (b'another', b'asdf')) + + def test_strict_parsing_fail_no_value(self): + self.p.strict_parsing = True + with self.assertRaises(QuerystringParseError) as cm: + self.p.write(b'foo=bar&blank&another=asdf') + + if cm is not None: + self.assertEqual(cm.exception.offset, 8) + + def test_success_no_value(self): + self.p.write(b'foo=bar&blank&another=asdf') + self.assert_fields( + (b'foo', b'bar'), + (b'blank', b''), + (b'another', b'asdf') + ) + + def test_repr(self): + # Issue #29; verify we don't assert on repr() + _ignored = repr(self.p) + + +class TestOctetStreamParser(unittest.TestCase): + def setUp(self): + self.d = [] + self.started = 0 + self.finished = 0 + + def on_start(): + self.started += 1 + + def on_data(data, start, end): + self.d.append(data[start:end]) + + def on_end(): + self.finished += 1 + + callbacks = { + 'on_start': on_start, + 'on_data': on_data, + 'on_end': on_end + } + + self.p = OctetStreamParser(callbacks) + + def assert_data(self, data, finalize=True): + self.assertEqual(b''.join(self.d), data) + self.d = [] + + def assert_started(self, val=True): + if val: + self.assertEqual(self.started, 1) + else: + self.assertEqual(self.started, 0) + + def assert_finished(self, val=True): + if val: + self.assertEqual(self.finished, 1) + else: + self.assertEqual(self.finished, 0) + + def test_simple(self): + # Assert is not started + self.assert_started(False) + + # Write something, it should then be started + have data + self.p.write(b'foobar') + self.assert_started() + self.assert_data(b'foobar') + + # Finalize, and check + 
self.assert_finished(False) + self.p.finalize() + self.assert_finished() + + def test_multiple_chunks(self): + self.p.write(b'foo') + self.p.write(b'bar') + self.p.write(b'baz') + self.p.finalize() + + self.assert_data(b'foobarbaz') + self.assert_finished() + + def test_max_size(self): + self.p.max_size = 5 + + self.p.write(b'0123456789') + self.p.finalize() + + self.assert_data(b'01234') + self.assert_finished() + + def test_invalid_max_size(self): + with self.assertRaises(ValueError): + q = OctetStreamParser(max_size='foo') + + +class TestBase64Decoder(unittest.TestCase): + # Note: base64('foobar') == 'Zm9vYmFy' + def setUp(self): + self.f = BytesIO() + self.d = Base64Decoder(self.f) + + def assert_data(self, data, finalize=True): + if finalize: + self.d.finalize() + + self.f.seek(0) + self.assertEqual(self.f.read(), data) + self.f.seek(0) + self.f.truncate() + + def test_simple(self): + self.d.write(b'Zm9vYmFy') + self.assert_data(b'foobar') + + def test_bad(self): + with self.assertRaises(DecodeError): + self.d.write(b'Zm9v!mFy') + + def test_split_properly(self): + self.d.write(b'Zm9v') + self.d.write(b'YmFy') + self.assert_data(b'foobar') + + def test_bad_split(self): + buff = b'Zm9v' + for i in range(1, 4): + first, second = buff[:i], buff[i:] + + self.setUp() + self.d.write(first) + self.d.write(second) + self.assert_data(b'foo') + + def test_long_bad_split(self): + buff = b'Zm9vYmFy' + for i in range(5, 8): + first, second = buff[:i], buff[i:] + + self.setUp() + self.d.write(first) + self.d.write(second) + self.assert_data(b'foobar') + + def test_close_and_finalize(self): + parser = Mock() + f = Base64Decoder(parser) + + f.finalize() + parser.finalize.assert_called_once_with() + + f.close() + parser.close.assert_called_once_with() + + def test_bad_length(self): + self.d.write(b'Zm9vYmF') # missing ending 'y' + + with self.assertRaises(DecodeError): + self.d.finalize() + + +class TestQuotedPrintableDecoder(unittest.TestCase): + def setUp(self): + self.f = BytesIO() + self.d = QuotedPrintableDecoder(self.f) + + def assert_data(self, data, finalize=True): + if finalize: + self.d.finalize() + + self.f.seek(0) + self.assertEqual(self.f.read(), data) + self.f.seek(0) + self.f.truncate() + + def test_simple(self): + self.d.write(b'foobar') + self.assert_data(b'foobar') + + def test_with_escape(self): + self.d.write(b'foo=3Dbar') + self.assert_data(b'foo=bar') + + def test_with_newline_escape(self): + self.d.write(b'foo=\r\nbar') + self.assert_data(b'foobar') + + def test_with_only_newline_escape(self): + self.d.write(b'foo=\nbar') + self.assert_data(b'foobar') + + def test_with_split_escape(self): + self.d.write(b'foo=3') + self.d.write(b'Dbar') + self.assert_data(b'foo=bar') + + def test_with_split_newline_escape_1(self): + self.d.write(b'foo=\r') + self.d.write(b'\nbar') + self.assert_data(b'foobar') + + def test_with_split_newline_escape_2(self): + self.d.write(b'foo=') + self.d.write(b'\r\nbar') + self.assert_data(b'foobar') + + def test_close_and_finalize(self): + parser = Mock() + f = QuotedPrintableDecoder(parser) + + f.finalize() + parser.finalize.assert_called_once_with() + + f.close() + parser.close.assert_called_once_with() + + def test_not_aligned(self): + """ + https://github.com/andrew-d/python-multipart/issues/6 + """ + self.d.write(b'=3AX') + self.assert_data(b':X') + + # Additional offset tests + self.d.write(b'=3') + self.d.write(b'AX') + self.assert_data(b':X') + + self.d.write(b'q=3AX') + self.assert_data(b'q:X') + + +# Load our list of HTTP test cases. 
+http_tests_dir = os.path.join(curr_dir, 'test_data', 'http') + +# Read in all test cases and load them. +NON_PARAMETRIZED_TESTS = {'single_field_blocks'} +http_tests = [] +for f in os.listdir(http_tests_dir): + # Only load the HTTP test cases. + fname, ext = os.path.splitext(f) + if fname in NON_PARAMETRIZED_TESTS: + continue + + if ext == '.http': + # Get the YAML file and load it too. + yaml_file = os.path.join(http_tests_dir, fname + '.yaml') + + # Load both. + with open(os.path.join(http_tests_dir, f), 'rb') as f: + test_data = f.read() + + with open(yaml_file, 'rb') as f: + yaml_data = yaml.safe_load(f) + + http_tests.append({ + 'name': fname, + 'test': test_data, + 'result': yaml_data + }) + + +def split_all(val): + """ + This function will split an array all possible ways. For example: + split_all([1,2,3,4]) + will give: + ([1], [2,3,4]), ([1,2], [3,4]), ([1,2,3], [4]) + """ + for i in range(1, len(val) - 1): + yield (val[:i], val[i:]) + + +@parametrize_class +class TestFormParser(unittest.TestCase): + def make(self, boundary, config={}): + self.ended = False + self.files = [] + self.fields = [] + + def on_field(f): + self.fields.append(f) + + def on_file(f): + self.files.append(f) + + def on_end(): + self.ended = True + + # Get a form-parser instance. + self.f = FormParser('multipart/form-data', on_field, on_file, on_end, + boundary=boundary, config=config) + + def assert_file_data(self, f, data): + o = f.file_object + o.seek(0) + file_data = o.read() + self.assertEqual(file_data, data) + + def assert_file(self, field_name, file_name, data): + # Find this file. + found = None + for f in self.files: + if f.field_name == field_name: + found = f + break + + # Assert that we found it. + self.assertIsNotNone(found) + + try: + # Assert about this file. + self.assert_file_data(found, data) + self.assertEqual(found.file_name, file_name) + + # Remove it from our list. + self.files.remove(found) + finally: + # Close our file + found.close() + + def assert_field(self, name, value): + # Find this field in our fields list. + found = None + for f in self.fields: + if f.field_name == name: + found = f + break + + # Assert that it exists and matches. + self.assertIsNotNone(found) + self.assertEqual(value, found.value) + + # Remove it for future iterations. + self.fields.remove(found) + + @parametrize('param', http_tests) + def test_http(self, param): + # Firstly, create our parser with the given boundary. + boundary = param['result']['boundary'] + if isinstance(boundary, str): + boundary = boundary.encode('latin-1') + self.make(boundary) + + # Now, we feed the parser with data. + exc = None + try: + processed = self.f.write(param['test']) + self.f.finalize() + except MultipartParseError as e: + processed = 0 + exc = e + + # print(repr(param)) + # print("") + # print(repr(self.fields)) + # print(repr(self.files)) + + # Do we expect an error? + if 'error' in param['result']['expected']: + self.assertIsNotNone(exc) + self.assertEqual(param['result']['expected']['error'], exc.offset) + return + + # No error! + self.assertEqual(processed, len(param['test'])) + + # Assert that the parser gave us the appropriate fields/files. + for e in param['result']['expected']: + # Get our type and name. 
+ type = e['type'] + name = e['name'].encode('latin-1') + + if type == 'field': + self.assert_field(name, e['data']) + + elif type == 'file': + self.assert_file( + name, + e['file_name'].encode('latin-1'), + e['data'] + ) + + else: + assert False + + def test_random_splitting(self): + """ + This test runs a simple multipart body with one field and one file + through every possible split. + """ + # Load test data. + test_file = 'single_field_single_file.http' + with open(os.path.join(http_tests_dir, test_file), 'rb') as f: + test_data = f.read() + + # We split the file through all cases. + for first, last in split_all(test_data): + # Create form parser. + self.make('boundary') + + # Feed with data in 2 chunks. + i = 0 + i += self.f.write(first) + i += self.f.write(last) + self.f.finalize() + + # Assert we processed everything. + self.assertEqual(i, len(test_data)) + + # Assert that our file and field are here. + self.assert_field(b'field', b'test1') + self.assert_file(b'file', b'file.txt', b'test2') + + def test_feed_single_bytes(self): + """ + This test parses a simple multipart body 1 byte at a time. + """ + # Load test data. + test_file = 'single_field_single_file.http' + with open(os.path.join(http_tests_dir, test_file), 'rb') as f: + test_data = f.read() + + # Create form parser. + self.make('boundary') + + # Write all bytes. + # NOTE: Can't simply do `for b in test_data`, since that gives + # an integer when iterating over a bytes object on Python 3. + i = 0 + for x in range(len(test_data)): + b = test_data[x:x + 1] + i += self.f.write(b) + + self.f.finalize() + + # Assert we processed everything. + self.assertEqual(i, len(test_data)) + + # Assert that our file and field are here. + self.assert_field(b'field', b'test1') + self.assert_file(b'file', b'file.txt', b'test2') + + def test_feed_blocks(self): + """ + This test parses a simple multipart body 1 byte at a time. + """ + # Load test data. + test_file = 'single_field_blocks.http' + with open(os.path.join(http_tests_dir, test_file), 'rb') as f: + test_data = f.read() + + for c in range(1, len(test_data) + 1): + # Skip first `d` bytes - not interesting + for d in range(c): + + # Create form parser. + self.make('boundary') + # Skip + i = 0 + self.f.write(test_data[:d]) + i += d + for x in range(d, len(test_data), c): + # Write a chunk to achieve condition + # `i == data_length - 1` + # in boundary search loop (multipatr.py:1302) + b = test_data[x:x + c] + i += self.f.write(b) + + self.f.finalize() + + # Assert we processed everything. + self.assertEqual(i, len(test_data)) + + # Assert that our field is here. + self.assert_field(b'field', + b'0123456789ABCDEFGHIJ0123456789ABCDEFGHIJ') + + @slow_test + def test_request_body_fuzz(self): + """ + This test randomly fuzzes the request body to ensure that no strange + exceptions are raised and we don't end up in a strange state. The + fuzzing consists of randomly doing one of the following: + - Adding a random byte at a random offset + - Randomly deleting a single byte + - Randomly swapping two bytes + """ + # Load test data. + test_file = 'single_field_single_file.http' + with open(os.path.join(http_tests_dir, test_file), 'rb') as f: + test_data = f.read() + + iterations = 1000 + successes = 0 + failures = 0 + exceptions = 0 + + print("Running %d iterations of fuzz testing:" % (iterations,)) + for i in range(iterations): + # Create a bytearray to mutate. + fuzz_data = bytearray(test_data) + + # Pick what we're supposed to do. 
+ choice = random.choice([1, 2, 3]) + if choice == 1: + # Add a random byte. + i = random.randrange(len(test_data)) + b = random.randrange(256) + + fuzz_data.insert(i, b) + msg = "Inserting byte %r at offset %d" % (b, i) + + elif choice == 2: + # Remove a random byte. + i = random.randrange(len(test_data)) + del fuzz_data[i] + + msg = "Deleting byte at offset %d" % (i,) + + elif choice == 3: + # Swap two bytes. + i = random.randrange(len(test_data) - 1) + fuzz_data[i], fuzz_data[i + 1] = fuzz_data[i + 1], fuzz_data[i] + + msg = "Swapping bytes %d and %d" % (i, i + 1) + + # Print message, so if this crashes, we can inspect the output. + print(" " + msg) + + # Create form parser. + self.make('boundary') + + # Feed with data, and ignore form parser exceptions. + i = 0 + try: + i = self.f.write(bytes(fuzz_data)) + self.f.finalize() + except FormParserError: + exceptions += 1 + else: + if i == len(fuzz_data): + successes += 1 + else: + failures += 1 + + print("--------------------------------------------------") + print("Successes: %d" % (successes,)) + print("Failures: %d" % (failures,)) + print("Exceptions: %d" % (exceptions,)) + + @slow_test + def test_request_body_fuzz_random_data(self): + """ + This test will fuzz the multipart parser with some number of iterations + of randomly-generated data. + """ + iterations = 1000 + successes = 0 + failures = 0 + exceptions = 0 + + print("Running %d iterations of fuzz testing:" % (iterations,)) + for i in range(iterations): + data_size = random.randrange(100, 4096) + data = os.urandom(data_size) + print(" Testing with %d random bytes..." % (data_size,)) + + # Create form parser. + self.make('boundary') + + # Feed with data, and ignore form parser exceptions. + i = 0 + try: + i = self.f.write(bytes(data)) + self.f.finalize() + except FormParserError: + exceptions += 1 + else: + if i == len(data): + successes += 1 + else: + failures += 1 + + print("--------------------------------------------------") + print("Successes: %d" % (successes,)) + print("Failures: %d" % (failures,)) + print("Exceptions: %d" % (exceptions,)) + + def test_bad_start_boundary(self): + self.make('boundary') + data = b'--boundary\rfoobar' + with self.assertRaises(MultipartParseError): + self.f.write(data) + + self.make('boundary') + data = b'--boundaryfoobar' + with self.assertRaises(MultipartParseError): + i = self.f.write(data) + + def test_octet_stream(self): + files = [] + def on_file(f): + files.append(f) + on_field = Mock() + on_end = Mock() + + f = FormParser('application/octet-stream', on_field, on_file, on_end=on_end, file_name=b'foo.txt') + self.assertTrue(isinstance(f.parser, OctetStreamParser)) + + f.write(b'test') + f.write(b'1234') + f.finalize() + + # Assert that we only received a single file, with the right data, and that we're done. + self.assertFalse(on_field.called) + self.assertEqual(len(files), 1) + self.assert_file_data(files[0], b'test1234') + self.assertTrue(on_end.called) + + def test_querystring(self): + fields = [] + def on_field(f): + fields.append(f) + on_file = Mock() + on_end = Mock() + + def simple_test(f): + # Reset tracking. + del fields[:] + on_file.reset_mock() + on_end.reset_mock() + + # Write test data. + f.write(b'foo=bar') + f.write(b'&test=asdf') + f.finalize() + + # Assert we only received 2 fields... + self.assertFalse(on_file.called) + self.assertEqual(len(fields), 2) + + # ...assert that we have the correct data... 
+ self.assertEqual(fields[0].field_name, b'foo') + self.assertEqual(fields[0].value, b'bar') + + self.assertEqual(fields[1].field_name, b'test') + self.assertEqual(fields[1].value, b'asdf') + + # ... and assert that we've finished. + self.assertTrue(on_end.called) + + f = FormParser('application/x-www-form-urlencoded', on_field, on_file, on_end=on_end) + self.assertTrue(isinstance(f.parser, QuerystringParser)) + simple_test(f) + + f = FormParser('application/x-url-encoded', on_field, on_file, on_end=on_end) + self.assertTrue(isinstance(f.parser, QuerystringParser)) + simple_test(f) + + def test_close_methods(self): + parser = Mock() + f = FormParser('application/x-url-encoded', None, None) + f.parser = parser + + f.finalize() + parser.finalize.assert_called_once_with() + + f.close() + parser.close.assert_called_once_with() + + def test_bad_content_type(self): + # We should raise a ValueError for a bad Content-Type + with self.assertRaises(ValueError): + f = FormParser('application/bad', None, None) + + def test_no_boundary_given(self): + # We should raise a FormParserError when parsing a multipart message + # without a boundary. + with self.assertRaises(FormParserError): + f = FormParser('multipart/form-data', None, None) + + def test_bad_content_transfer_encoding(self): + data = b'----boundary\r\nContent-Disposition: form-data; name="file"; filename="test.txt"\r\nContent-Type: text/plain\r\nContent-Transfer-Encoding: badstuff\r\n\r\nTest\r\n----boundary--\r\n' + + files = [] + def on_file(f): + files.append(f) + on_field = Mock() + on_end = Mock() + + # Test with erroring. + config = {'UPLOAD_ERROR_ON_BAD_CTE': True} + f = FormParser('multipart/form-data', on_field, on_file, + on_end=on_end, boundary='--boundary', config=config) + + with self.assertRaises(FormParserError): + f.write(data) + f.finalize() + + # Test without erroring. + config = {'UPLOAD_ERROR_ON_BAD_CTE': False} + f = FormParser('multipart/form-data', on_field, on_file, + on_end=on_end, boundary='--boundary', config=config) + + f.write(data) + f.finalize() + self.assert_file_data(files[0], b'Test') + + def test_handles_None_fields(self): + fields = [] + def on_field(f): + fields.append(f) + on_file = Mock() + on_end = Mock() + + f = FormParser('application/x-www-form-urlencoded', on_field, on_file, on_end=on_end) + f.write(b'foo=bar&another&baz=asdf') + f.finalize() + + self.assertEqual(fields[0].field_name, b'foo') + self.assertEqual(fields[0].value, b'bar') + + self.assertEqual(fields[1].field_name, b'another') + self.assertEqual(fields[1].value, None) + + self.assertEqual(fields[2].field_name, b'baz') + self.assertEqual(fields[2].value, b'asdf') + + def test_max_size_multipart(self): + # Load test data. + test_file = 'single_field_single_file.http' + with open(os.path.join(http_tests_dir, test_file), 'rb') as f: + test_data = f.read() + + # Create form parser. + self.make('boundary') + + # Set the maximum length that we can process to be halfway through the + # given data. + self.f.parser.max_size = len(test_data) / 2 + + i = self.f.write(test_data) + self.f.finalize() + + # Assert we processed the correct amount. + self.assertEqual(i, len(test_data) / 2) + + def test_max_size_form_parser(self): + # Load test data. + test_file = 'single_field_single_file.http' + with open(os.path.join(http_tests_dir, test_file), 'rb') as f: + test_data = f.read() + + # Create form parser setting the maximum length that we can process to + # be halfway through the given data. 
+ size = len(test_data) / 2 + self.make('boundary', config={'MAX_BODY_SIZE': size}) + + i = self.f.write(test_data) + self.f.finalize() + + # Assert we processed the correct amount. + self.assertEqual(i, len(test_data) / 2) + + def test_octet_stream_max_size(self): + files = [] + def on_file(f): + files.append(f) + on_field = Mock() + on_end = Mock() + + f = FormParser('application/octet-stream', on_field, on_file, + on_end=on_end, file_name=b'foo.txt', + config={'MAX_BODY_SIZE': 10}) + + f.write(b'0123456789012345689') + f.finalize() + + self.assert_file_data(files[0], b'0123456789') + + def test_invalid_max_size_multipart(self): + with self.assertRaises(ValueError): + q = MultipartParser(b'bound', max_size='foo') + + +class TestHelperFunctions(unittest.TestCase): + def test_create_form_parser(self): + r = create_form_parser({'Content-Type': 'application/octet-stream'}, + None, None) + self.assertTrue(isinstance(r, FormParser)) + + def test_create_form_parser_error(self): + headers = {} + with self.assertRaises(ValueError): + create_form_parser(headers, None, None) + + def test_parse_form(self): + on_field = Mock() + on_file = Mock() + + parse_form( + {'Content-Type': 'application/octet-stream', + }, + BytesIO(b'123456789012345'), + on_field, + on_file + ) + + assert on_file.call_count == 1 + + # Assert that the first argument of the call (a File object) has size + # 15 - i.e. all data is written. + self.assertEqual(on_file.call_args[0][0].size, 15) + + def test_parse_form_content_length(self): + files = [] + def on_file(file): + files.append(file) + + parse_form( + {'Content-Type': 'application/octet-stream', + 'Content-Length': '10' + }, + BytesIO(b'123456789012345'), + None, + on_file + ) + + self.assertEqual(len(files), 1) + self.assertEqual(files[0].size, 10) + + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestFile)) + suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestParseOptionsHeader)) + suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestBaseParser)) + suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestQuerystringParser)) + suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestOctetStreamParser)) + suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestBase64Decoder)) + suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestQuotedPrintableDecoder)) + suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestFormParser)) + suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestHelperFunctions)) + + return suite diff --git a/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.10.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/psycopg2_binary-2.9.10.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/INSTALLER diff --git a/venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/licenses/LICENSE b/venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/LICENSE similarity index 100% rename from venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/licenses/LICENSE rename to venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/LICENSE diff --git a/venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/METADATA 
b/venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/METADATA new file mode 100644 index 0000000..032a564 --- /dev/null +++ b/venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/METADATA @@ -0,0 +1,803 @@ +Metadata-Version: 2.1 +Name: prometheus-client +Version: 0.18.0 +Summary: Python client for the Prometheus monitoring system. +Home-page: https://github.com/prometheus/client_python +Author: Brian Brazil +Author-email: brian.brazil@robustperception.io +License: Apache Software License 2.0 +Keywords: prometheus monitoring instrumentation client +Classifier: Development Status :: 4 - Beta +Classifier: Intended Audience :: Developers +Classifier: Intended Audience :: Information Technology +Classifier: Intended Audience :: System Administrators +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Topic :: System :: Monitoring +Classifier: License :: OSI Approved :: Apache Software License +Requires-Python: >=3.8 +Description-Content-Type: text/markdown +License-File: LICENSE +License-File: NOTICE +Provides-Extra: twisted +Requires-Dist: twisted ; extra == 'twisted' + +# Prometheus Python Client + +The official Python client for [Prometheus](https://prometheus.io). + +## Three Step Demo + +**One**: Install the client: +``` +pip install prometheus-client +``` + +**Two**: Paste the following into a Python interpreter: +```python +from prometheus_client import start_http_server, Summary +import random +import time + +# Create a metric to track time spent and requests made. +REQUEST_TIME = Summary('request_processing_seconds', 'Time spent processing request') + +# Decorate function with metric. +@REQUEST_TIME.time() +def process_request(t): + """A dummy function that takes some time.""" + time.sleep(t) + +if __name__ == '__main__': + # Start up the server to expose the metrics. + start_http_server(8000) + # Generate some requests. + while True: + process_request(random.random()) +``` + +**Three**: Visit [http://localhost:8000/](http://localhost:8000/) to view the metrics. + +From one easy to use decorator you get: + * `request_processing_seconds_count`: Number of times this function was called. + * `request_processing_seconds_sum`: Total amount of time spent in this function. + +Prometheus's `rate` function allows calculation of both requests per second, +and latency over time from this data. + +In addition if you're on Linux the `process` metrics expose CPU, memory and +other information about the process for free! + +## Installation + +``` +pip install prometheus-client +``` + +This package can be found on +[PyPI](https://pypi.python.org/pypi/prometheus_client). + +## Instrumenting + +Four types of metric are offered: Counter, Gauge, Summary and Histogram. +See the documentation on [metric types](http://prometheus.io/docs/concepts/metric_types/) +and [instrumentation best practices](https://prometheus.io/docs/practices/instrumentation/#counter-vs-gauge-summary-vs-histogram) +on how to use them. + +### Counter + +Counters go up, and reset when the process restarts. 
+ + +```python +from prometheus_client import Counter +c = Counter('my_failures', 'Description of counter') +c.inc() # Increment by 1 +c.inc(1.6) # Increment by given value +``` + +If there is a suffix of `_total` on the metric name, it will be removed. When +exposing the time series for counter, a `_total` suffix will be added. This is +for compatibility between OpenMetrics and the Prometheus text format, as OpenMetrics +requires the `_total` suffix. + +There are utilities to count exceptions raised: + +```python +@c.count_exceptions() +def f(): + pass + +with c.count_exceptions(): + pass + +# Count only one type of exception +with c.count_exceptions(ValueError): + pass +``` + +### Gauge + +Gauges can go up and down. + +```python +from prometheus_client import Gauge +g = Gauge('my_inprogress_requests', 'Description of gauge') +g.inc() # Increment by 1 +g.dec(10) # Decrement by given value +g.set(4.2) # Set to a given value +``` + +There are utilities for common use cases: + +```python +g.set_to_current_time() # Set to current unixtime + +# Increment when entered, decrement when exited. +@g.track_inprogress() +def f(): + pass + +with g.track_inprogress(): + pass +``` + +A Gauge can also take its value from a callback: + +```python +d = Gauge('data_objects', 'Number of objects') +my_dict = {} +d.set_function(lambda: len(my_dict)) +``` + +### Summary + +Summaries track the size and number of events. + +```python +from prometheus_client import Summary +s = Summary('request_latency_seconds', 'Description of summary') +s.observe(4.7) # Observe 4.7 (seconds in this case) +``` + +There are utilities for timing code: + +```python +@s.time() +def f(): + pass + +with s.time(): + pass +``` + +The Python client doesn't store or expose quantile information at this time. + +### Histogram + +Histograms track the size and number of events in buckets. +This allows for aggregatable calculation of quantiles. + +```python +from prometheus_client import Histogram +h = Histogram('request_latency_seconds', 'Description of histogram') +h.observe(4.7) # Observe 4.7 (seconds in this case) +``` + +The default buckets are intended to cover a typical web/rpc request from milliseconds to seconds. +They can be overridden by passing `buckets` keyword argument to `Histogram`. + +There are utilities for timing code: + +```python +@h.time() +def f(): + pass + +with h.time(): + pass +``` + +### Info + +Info tracks key-value information, usually about a whole target. + +```python +from prometheus_client import Info +i = Info('my_build_version', 'Description of info') +i.info({'version': '1.2.3', 'buildhost': 'foo@bar'}) +``` + +### Enum + +Enum tracks which of a set of states something is currently in. + +```python +from prometheus_client import Enum +e = Enum('my_task_state', 'Description of enum', + states=['starting', 'running', 'stopped']) +e.state('running') +``` + +### Labels + +All metrics can have labels, allowing grouping of related time series. + +See the best practices on [naming](http://prometheus.io/docs/practices/naming/) +and [labels](http://prometheus.io/docs/practices/instrumentation/#use-labels). 
+ +Taking a counter as an example: + +```python +from prometheus_client import Counter +c = Counter('my_requests_total', 'HTTP Failures', ['method', 'endpoint']) +c.labels('get', '/').inc() +c.labels('post', '/submit').inc() +``` + +Labels can also be passed as keyword-arguments: + +```python +from prometheus_client import Counter +c = Counter('my_requests_total', 'HTTP Failures', ['method', 'endpoint']) +c.labels(method='get', endpoint='/').inc() +c.labels(method='post', endpoint='/submit').inc() +``` + +Metrics with labels are not initialized when declared, because the client can't +know what values the label can have. It is recommended to initialize the label +values by calling the `.labels()` method alone: + +```python +from prometheus_client import Counter +c = Counter('my_requests_total', 'HTTP Failures', ['method', 'endpoint']) +c.labels('get', '/') +c.labels('post', '/submit') +``` + +### Exemplars + +Exemplars can be added to counter and histogram metrics. Exemplars can be +specified by passing a dict of label value pairs to be exposed as the exemplar. +For example with a counter: + +```python +from prometheus_client import Counter +c = Counter('my_requests_total', 'HTTP Failures', ['method', 'endpoint']) +c.labels('get', '/').inc(exemplar={'trace_id': 'abc123'}) +c.labels('post', '/submit').inc(1.0, {'trace_id': 'def456'}) +``` + +And with a histogram: + +```python +from prometheus_client import Histogram +h = Histogram('request_latency_seconds', 'Description of histogram') +h.observe(4.7, {'trace_id': 'abc123'}) +``` + +Exemplars are only rendered in the OpenMetrics exposition format. If using the +HTTP server or apps in this library, content negotiation can be used to specify +OpenMetrics (which is done by default in Prometheus). Otherwise it will be +necessary to use `generate_latest` from +`prometheus_client.openmetrics.exposition` to view exemplars. + +To view exemplars in Prometheus it is also necessary to enable the the +exemplar-storage feature flag: +``` +--enable-feature=exemplar-storage +``` +Additional information is available in [the Prometheus +documentation](https://prometheus.io/docs/prometheus/latest/feature_flags/#exemplars-storage). + +### Disabling `_created` metrics + +By default counters, histograms, and summaries export an additional series +suffixed with `_created` and a value of the unix timestamp for when the metric +was created. If this information is not helpful, it can be disabled by setting +the environment variable `PROMETHEUS_DISABLE_CREATED_SERIES=True`. + +### Process Collector + +The Python client automatically exports metrics about process CPU usage, RAM, +file descriptors and start time. These all have the prefix `process`, and +are only currently available on Linux. + +The namespace and pid constructor arguments allows for exporting metrics about +other processes, for example: +``` +ProcessCollector(namespace='mydaemon', pid=lambda: open('/var/run/daemon.pid').read()) +``` + +### Platform Collector + +The client also automatically exports some metadata about Python. If using Jython, +metadata about the JVM in use is also included. This information is available as +labels on the `python_info` metric. The value of the metric is 1, since it is the +labels that carry information. + +### Disabling Default Collector metrics + +By default the collected `process`, `gc`, and `platform` collector metrics are exported. 
+If this information is not helpful, it can be disabled using the following: +```python +import prometheus_client + +prometheus_client.REGISTRY.unregister(prometheus_client.GC_COLLECTOR) +prometheus_client.REGISTRY.unregister(prometheus_client.PLATFORM_COLLECTOR) +prometheus_client.REGISTRY.unregister(prometheus_client.PROCESS_COLLECTOR) +``` + +## Exporting + +There are several options for exporting metrics. + +### HTTP + +Metrics are usually exposed over HTTP, to be read by the Prometheus server. + +The easiest way to do this is via `start_http_server`, which will start a HTTP +server in a daemon thread on the given port: + +```python +from prometheus_client import start_http_server + +start_http_server(8000) +``` + +Visit [http://localhost:8000/](http://localhost:8000/) to view the metrics. + +To add Prometheus exposition to an existing HTTP server, see the `MetricsHandler` class +which provides a `BaseHTTPRequestHandler`. It also serves as a simple example of how +to write a custom endpoint. + +#### Twisted + +To use prometheus with [twisted](https://twistedmatrix.com/), there is `MetricsResource` which exposes metrics as a twisted resource. + +```python +from prometheus_client.twisted import MetricsResource +from twisted.web.server import Site +from twisted.web.resource import Resource +from twisted.internet import reactor + +root = Resource() +root.putChild(b'metrics', MetricsResource()) + +factory = Site(root) +reactor.listenTCP(8000, factory) +reactor.run() +``` + +#### WSGI + +To use Prometheus with [WSGI](http://wsgi.readthedocs.org/en/latest/), there is +`make_wsgi_app` which creates a WSGI application. + +```python +from prometheus_client import make_wsgi_app +from wsgiref.simple_server import make_server + +app = make_wsgi_app() +httpd = make_server('', 8000, app) +httpd.serve_forever() +``` + +Such an application can be useful when integrating Prometheus metrics with WSGI +apps. + +The method `start_wsgi_server` can be used to serve the metrics through the +WSGI reference implementation in a new thread. + +```python +from prometheus_client import start_wsgi_server + +start_wsgi_server(8000) +``` + +By default, the WSGI application will respect `Accept-Encoding:gzip` headers used by Prometheus +and compress the response if such a header is present. This behaviour can be disabled by passing +`disable_compression=True` when creating the app, like this: + +```python +app = make_wsgi_app(disable_compression=True) +``` + +#### ASGI + +To use Prometheus with [ASGI](http://asgi.readthedocs.org/en/latest/), there is +`make_asgi_app` which creates an ASGI application. + +```python +from prometheus_client import make_asgi_app + +app = make_asgi_app() +``` +Such an application can be useful when integrating Prometheus metrics with ASGI +apps. + +By default, the WSGI application will respect `Accept-Encoding:gzip` headers used by Prometheus +and compress the response if such a header is present. This behaviour can be disabled by passing +`disable_compression=True` when creating the app, like this: + +```python +app = make_asgi_app(disable_compression=True) +``` + +#### Flask + +To use Prometheus with [Flask](http://flask.pocoo.org/) we need to serve metrics through a Prometheus WSGI application. This can be achieved using [Flask's application dispatching](http://flask.pocoo.org/docs/latest/patterns/appdispatch/). Below is a working example. 
+ +Save the snippet below in a `myapp.py` file + +```python +from flask import Flask +from werkzeug.middleware.dispatcher import DispatcherMiddleware +from prometheus_client import make_wsgi_app + +# Create my app +app = Flask(__name__) + +# Add prometheus wsgi middleware to route /metrics requests +app.wsgi_app = DispatcherMiddleware(app.wsgi_app, { + '/metrics': make_wsgi_app() +}) +``` + +Run the example web application like this + +```bash +# Install uwsgi if you do not have it +pip install uwsgi +uwsgi --http 127.0.0.1:8000 --wsgi-file myapp.py --callable app +``` + +Visit http://localhost:8000/metrics to see the metrics + +#### FastAPI + Gunicorn + +To use Prometheus with [FastAPI](https://fastapi.tiangolo.com/) and [Gunicorn](https://gunicorn.org/) we need to serve metrics through a Prometheus ASGI application. + +Save the snippet below in a `myapp.py` file + +```python +from fastapi import FastAPI +from prometheus_client import make_asgi_app + +# Create app +app = FastAPI(debug=False) + +# Add prometheus asgi middleware to route /metrics requests +metrics_app = make_asgi_app() +app.mount("/metrics", metrics_app) +``` + +For Multiprocessing support, use this modified code snippet. Full multiprocessing instructions are provided [here](https://github.com/prometheus/client_python#multiprocess-mode-eg-gunicorn). + +```python +from fastapi import FastAPI +from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess + +app = FastAPI(debug=False) + +# Using multiprocess collector for registry +def make_metrics_app(): + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + return make_asgi_app(registry=registry) + + +metrics_app = make_metrics_app() +app.mount("/metrics", metrics_app) +``` + +Run the example web application like this + +```bash +# Install gunicorn if you do not have it +pip install gunicorn +# If using multiple workers, add `--workers n` parameter to the line below +gunicorn -b 127.0.0.1:8000 myapp:app -k uvicorn.workers.UvicornWorker +``` + +Visit http://localhost:8000/metrics to see the metrics + + +### Node exporter textfile collector + +The [textfile collector](https://github.com/prometheus/node_exporter#textfile-collector) +allows machine-level statistics to be exported out via the Node exporter. + +This is useful for monitoring cronjobs, or for writing cronjobs to expose metrics +about a machine system that the Node exporter does not support or would not make sense +to perform at every scrape (for example, anything involving subprocesses). + +```python +from prometheus_client import CollectorRegistry, Gauge, write_to_textfile + +registry = CollectorRegistry() +g = Gauge('raid_status', '1 if raid array is okay', registry=registry) +g.set(1) +write_to_textfile('/configured/textfile/path/raid.prom', registry) +``` + +A separate registry is used, as the default registry may contain other metrics +such as those from the Process Collector. + +## Exporting to a Pushgateway + +The [Pushgateway](https://github.com/prometheus/pushgateway) +allows ephemeral and batch jobs to expose their metrics to Prometheus. + +```python +from prometheus_client import CollectorRegistry, Gauge, push_to_gateway + +registry = CollectorRegistry() +g = Gauge('job_last_success_unixtime', 'Last time a batch job successfully finished', registry=registry) +g.set_to_current_time() +push_to_gateway('localhost:9091', job='batchA', registry=registry) +``` + +A separate registry is used, as the default registry may contain other metrics +such as those from the Process Collector.
+ +Pushgateway functions take a grouping key. `push_to_gateway` replaces metrics +with the same grouping key, `pushadd_to_gateway` only replaces metrics with the +same name and grouping key and `delete_from_gateway` deletes metrics with the +given job and grouping key. See the +[Pushgateway documentation](https://github.com/prometheus/pushgateway/blob/master/README.md) +for more information. + +`instance_ip_grouping_key` returns a grouping key with the instance label set +to the host's IP address. + +### Handlers for authentication + +If the push gateway you are connecting to is protected with HTTP Basic Auth, +you can use a special handler to set the Authorization header. + +```python +from prometheus_client import CollectorRegistry, Gauge, push_to_gateway +from prometheus_client.exposition import basic_auth_handler + +def my_auth_handler(url, method, timeout, headers, data): + username = 'foobar' + password = 'secret123' + return basic_auth_handler(url, method, timeout, headers, data, username, password) +registry = CollectorRegistry() +g = Gauge('job_last_success_unixtime', 'Last time a batch job successfully finished', registry=registry) +g.set_to_current_time() +push_to_gateway('localhost:9091', job='batchA', registry=registry, handler=my_auth_handler) +``` + +TLS Auth is also supported when using the push gateway with a special handler. + +```python +from prometheus_client import CollectorRegistry, Gauge, push_to_gateway +from prometheus_client.exposition import tls_auth_handler + + +def my_auth_handler(url, method, timeout, headers, data): + certfile = 'client-crt.pem' + keyfile = 'client-key.pem' + return tls_auth_handler(url, method, timeout, headers, data, certfile, keyfile) + +registry = CollectorRegistry() +g = Gauge('job_last_success_unixtime', 'Last time a batch job successfully finished', registry=registry) +g.set_to_current_time() +push_to_gateway('localhost:9091', job='batchA', registry=registry, handler=my_auth_handler) +``` + +## Bridges + +It is also possible to expose metrics to systems other than Prometheus. +This allows you to take advantage of Prometheus instrumentation even +if you are not quite ready to fully transition to Prometheus yet. + +### Graphite + +Metrics are pushed over TCP in the Graphite plaintext format. + +```python +from prometheus_client.bridge.graphite import GraphiteBridge + +gb = GraphiteBridge(('graphite.your.org', 2003)) +# Push once. +gb.push() +# Push every 10 seconds in a daemon thread. +gb.start(10.0) +``` + +Graphite [tags](https://grafana.com/blog/2018/01/11/graphite-1.1-teaching-an-old-dog-new-tricks/) are also supported. + +```python +from prometheus_client.bridge.graphite import GraphiteBridge + +gb = GraphiteBridge(('graphite.your.org', 2003), tags=True) +c = Counter('my_requests_total', 'HTTP Failures', ['method', 'endpoint']) +c.labels('get', '/').inc() +gb.push() +``` + +## Custom Collectors + +Sometimes it is not possible to directly instrument code, as it is not +in your control. This requires you to proxy metrics from other systems. 
+
+To do so you need to create a custom collector, for example:
+
+```python
+from prometheus_client.core import GaugeMetricFamily, CounterMetricFamily, REGISTRY
+from prometheus_client.registry import Collector
+
+class CustomCollector(Collector):
+    def collect(self):
+        yield GaugeMetricFamily('my_gauge', 'Help text', value=7)
+        c = CounterMetricFamily('my_counter_total', 'Help text', labels=['foo'])
+        c.add_metric(['bar'], 1.7)
+        c.add_metric(['baz'], 3.8)
+        yield c
+
+REGISTRY.register(CustomCollector())
+```
+
+`SummaryMetricFamily`, `HistogramMetricFamily` and `InfoMetricFamily` work similarly.
+
+A collector may implement a `describe` method which returns metrics in the same
+format as `collect` (though you don't have to include the samples). This is
+used to predetermine the names of time series a `CollectorRegistry` exposes and
+thus to detect collisions and duplicate registrations.
+
+Usually custom collectors do not have to implement `describe`. If `describe` is
+not implemented and the CollectorRegistry was created with `auto_describe=True`
+(which is the case for the default registry) then `collect` will be called at
+registration time instead of `describe`. If this could cause problems, either
+implement a proper `describe`, or if that's not practical have `describe`
+return an empty list.
+
+
+## Multiprocess Mode (E.g. Gunicorn)
+
+Prometheus client libraries presume a threaded model, where metrics are shared
+across workers. This doesn't work so well for languages such as Python where
+it's common to have processes rather than threads to handle large workloads.
+
+To handle this the client library can be put in multiprocess mode.
+This comes with a number of limitations:
+
+- Registries cannot be used as normal; all instantiated metrics are exported
+  - Registering metrics to a registry later used by a `MultiProcessCollector`
+    may cause duplicate metrics to be exported
+- Custom collectors do not work (e.g. cpu and memory metrics)
+- Info and Enum metrics do not work
+- The pushgateway cannot be used
+- Gauges cannot use the `pid` label
+- Exemplars are not supported
+
+There are several steps to getting this working:
+
+**1. Deployment**:
+
+The `PROMETHEUS_MULTIPROC_DIR` environment variable must be set to a directory
+that the client library can use for metrics. This directory must be wiped
+between process/Gunicorn runs (before startup is recommended).
+
+This environment variable should be set from a start-up shell script,
+and not directly from Python (otherwise it may not propagate to child processes).
+
+**2. Metrics collector**:
+
+The application must initialize a new `CollectorRegistry`, and store the
+multi-process collector inside. It is a best practice to create this registry
+inside the context of a request to avoid metrics registering themselves to a
+collector used by a `MultiProcessCollector`. If a registry with metrics
+registered is used by a `MultiProcessCollector`, duplicate metrics may be
+exported: one for multiprocess mode and one for the process serving the request.
+
+```python
+from prometheus_client import multiprocess
+from prometheus_client import generate_latest, CollectorRegistry, CONTENT_TYPE_LATEST, Counter
+
+MY_COUNTER = Counter('my_counter', 'Description of my counter')
+
+# Expose metrics.
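+# A fresh registry is created for every scrape and handed to
+# MultiProcessCollector, so metrics already registered to the default
+# REGISTRY in this process (such as MY_COUNTER above) are not exported twice.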
+def app(environ, start_response):
+    registry = CollectorRegistry()
+    multiprocess.MultiProcessCollector(registry)
+    data = generate_latest(registry)
+    status = '200 OK'
+    response_headers = [
+        ('Content-type', CONTENT_TYPE_LATEST),
+        ('Content-Length', str(len(data)))
+    ]
+    start_response(status, response_headers)
+    return iter([data])
+```
+
+**3. Gunicorn configuration**:
+
+The `gunicorn` configuration file needs to include the following function:
+
+```python
+from prometheus_client import multiprocess
+
+def child_exit(server, worker):
+    multiprocess.mark_process_dead(worker.pid)
+```
+
+**4. Metrics tuning (Gauge)**:
+
+When `Gauge`s are used in multiprocess applications,
+you must decide how to handle the metrics reported by each process.
+Gauges have several modes they can run in, which can be selected with the `multiprocess_mode` parameter.
+
+- 'all': Default. Return a timeseries per process (alive or dead), labelled by the process's `pid` (the label is added internally).
+- 'min': Return a single timeseries that is the minimum of the values of all processes (alive or dead).
+- 'max': Return a single timeseries that is the maximum of the values of all processes (alive or dead).
+- 'sum': Return a single timeseries that is the sum of the values of all processes (alive or dead).
+- 'mostrecent': Return a single timeseries that is the most recent value among all processes (alive or dead).
+
+Prepend 'live' to the beginning of the mode to return the same result but only considering living processes
+(e.g., 'liveall', 'livesum', 'livemax', 'livemin', 'livemostrecent').
+
+```python
+from prometheus_client import Gauge
+
+# Example gauge
+IN_PROGRESS = Gauge("inprogress_requests", "Requests in progress", multiprocess_mode='livesum')
+```
+
+
+## Parser
+
+The Python client supports parsing the Prometheus text format.
+This is intended for advanced use cases where you have servers
+exposing Prometheus metrics and need to get them into some other
+system.
+
+```python
+from prometheus_client.parser import text_string_to_metric_families
+for family in text_string_to_metric_families(u"my_gauge 1.0\n"):
+    for sample in family.samples:
+        print("Name: {0} Labels: {1} Value: {2}".format(*sample))
+```
+
+## Restricted registry
+
+Registries support restriction to only return specific metrics.
+If you’re using the built-in HTTP server, you can use the GET parameter "name[]"; since it’s an array, it can be used multiple times.
+If you’re directly using `generate_latest`, you can use the function `restricted_registry()`.
+
+```bash
+curl --get --data-urlencode "name[]=python_gc_objects_collected_total" --data-urlencode "name[]=python_info" http://127.0.0.1:9200/metrics
+```
+
+```python
+from prometheus_client import REGISTRY, generate_latest
+
+generate_latest(REGISTRY.restricted_registry(['python_gc_objects_collected_total', 'python_info']))
+```
+
+```text
+# HELP python_info Python platform information
+# TYPE python_info gauge
+python_info{implementation="CPython",major="3",minor="9",patchlevel="3",version="3.9.3"} 1.0
+# HELP python_gc_objects_collected_total Objects collected during gc
+# TYPE python_gc_objects_collected_total counter
+python_gc_objects_collected_total{generation="0"} 73129.0
+python_gc_objects_collected_total{generation="1"} 8594.0
+python_gc_objects_collected_total{generation="2"} 296.0
+```
+
+
+## Links
+
+* [Releases](https://github.com/prometheus/client_python/releases): The releases page shows the history of the project and acts as a changelog.
+* [PyPI](https://pypi.python.org/pypi/prometheus_client) diff --git a/venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/licenses/NOTICE b/venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/NOTICE similarity index 100% rename from venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/licenses/NOTICE rename to venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/NOTICE diff --git a/venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/RECORD b/venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/RECORD similarity index 59% rename from venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/RECORD rename to venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/RECORD index 6aba3fe..67aefbb 100644 --- a/venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/RECORD +++ b/venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/RECORD @@ -1,12 +1,12 @@ -prometheus_client-0.23.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -prometheus_client-0.23.1.dist-info/METADATA,sha256=IsOo26jaObfdWtq0csEgr3c8J8LOD2l-ZJFfHk7TNo4,1907 -prometheus_client-0.23.1.dist-info/RECORD,, -prometheus_client-0.23.1.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -prometheus_client-0.23.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91 -prometheus_client-0.23.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357 -prometheus_client-0.23.1.dist-info/licenses/NOTICE,sha256=TvoYdK6qYPNl9Xl-YX8f-TPhXlCOr3UemEjtRBPXp64,236 -prometheus_client-0.23.1.dist-info/top_level.txt,sha256=AxLEvHEMhTW-Kvb9Ly1DPI3aapigQ2aeg8TXMt9WMRo,18 -prometheus_client/__init__.py,sha256=3KznwpxJxkWiKqn6lw62fOKRALWYx8NO743ln0f3drI,1935 +prometheus_client-0.18.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +prometheus_client-0.18.0.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357 +prometheus_client-0.18.0.dist-info/METADATA,sha256=eRWr9-kFrteOREkDNhp_LadPGX8uQtTfM_eSULwhADs,26065 +prometheus_client-0.18.0.dist-info/NOTICE,sha256=TvoYdK6qYPNl9Xl-YX8f-TPhXlCOr3UemEjtRBPXp64,236 +prometheus_client-0.18.0.dist-info/RECORD,, +prometheus_client-0.18.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +prometheus_client-0.18.0.dist-info/WHEEL,sha256=Xo9-1PvkuimrydujYJAjF7pCkriuXBpUPEjma1nZyJ0,92 +prometheus_client-0.18.0.dist-info/top_level.txt,sha256=AxLEvHEMhTW-Kvb9Ly1DPI3aapigQ2aeg8TXMt9WMRo,18 +prometheus_client/__init__.py,sha256=D-ptlQkWPXqZIJPi5TR0QNMdWr_Ejv-gMq6WAFik_9o,1815 prometheus_client/__pycache__/__init__.cpython-312.pyc,, prometheus_client/__pycache__/asgi.cpython-312.pyc,, prometheus_client/__pycache__/context_managers.cpython-312.pyc,, @@ -24,38 +24,36 @@ prometheus_client/__pycache__/process_collector.cpython-312.pyc,, prometheus_client/__pycache__/registry.cpython-312.pyc,, prometheus_client/__pycache__/samples.cpython-312.pyc,, prometheus_client/__pycache__/utils.cpython-312.pyc,, -prometheus_client/__pycache__/validation.cpython-312.pyc,, prometheus_client/__pycache__/values.cpython-312.pyc,, -prometheus_client/asgi.py,sha256=rfeeBIusQudy9hjsmRiMmRYSW7aSgEc4gmVPHZ-j5bM,1621 +prometheus_client/asgi.py,sha256=ivn-eV7ZU0BEa4E9oWBFbBRUklHPw9f5lcdGsyFuCLo,1606 prometheus_client/bridge/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 prometheus_client/bridge/__pycache__/__init__.cpython-312.pyc,, 
prometheus_client/bridge/__pycache__/graphite.cpython-312.pyc,, prometheus_client/bridge/graphite.py,sha256=m5-7IyVyGL8C6S9yLxeupS1pfj8KFNPNlazddamQT8s,2897 prometheus_client/context_managers.py,sha256=E7uksn4D7yBoZWDgjI1VRpR3l2tKivs9DHZ5UAcmPwE,2343 -prometheus_client/core.py,sha256=BkKsCowQEZmrpZR0mEtieA5zhUnV6RkN2stRabpcDMA,930 +prometheus_client/core.py,sha256=yyVvSxa8WQnBvAr4JhO3HqdTqClwhbzmVGvwRvWQMIo,860 prometheus_client/decorator.py,sha256=7MdUokWmzQ17foet2R5QcMubdZ1WDPGYo0_HqLxAw2k,15802 -prometheus_client/exposition.py,sha256=22I8xRcaoGnXYKiRiAofNLPclgpaLT6nCEk6ITGKlLU,30470 +prometheus_client/exposition.py,sha256=83jr9uKj-Xmo830KEbyJrk01CHm89lQvfPWOiE5KgyY,23680 prometheus_client/gc_collector.py,sha256=tBhXXktF9g9h7gvO-DmI2gxPol2_gXI1M6e9ZMazNfY,1514 -prometheus_client/metrics.py,sha256=KaDps8Ku6HmdXArSZbqcIeW9YIA0S8SMf0ADr-EWZcA,27488 -prometheus_client/metrics_core.py,sha256=lbyXIhnDYGcbtDd3k5NDw7SsO3u6V6aHc3RIitzZugw,15565 +prometheus_client/metrics.py,sha256=Mr5XqGO0q-13b_0qmk-8iob4WiupfV02ASjtTf-Aw7A,27116 +prometheus_client/metrics_core.py,sha256=Yz-yqS3pxNdpIRMShQv_IHaKlVS_Q53TaYcP9U8LDlE,15548 prometheus_client/mmap_dict.py,sha256=-t49kywZHFHk2D9IWtunqKFtr5eEgiN-RjFWg16JE-Q,5393 -prometheus_client/multiprocess.py,sha256=b_sgKYaId9ctLzKSsUxY6oSjvqQWmMaVYoKU9dPShLg,7563 +prometheus_client/multiprocess.py,sha256=VIvAR0vmjL0lknnTijKt9HS1DNz9rZrS09HqIIcaZLs,7539 prometheus_client/openmetrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 prometheus_client/openmetrics/__pycache__/__init__.cpython-312.pyc,, prometheus_client/openmetrics/__pycache__/exposition.cpython-312.pyc,, prometheus_client/openmetrics/__pycache__/parser.cpython-312.pyc,, -prometheus_client/openmetrics/exposition.py,sha256=kCTOS4XFilCg7ZfHE2RDm4KLkQOPWJEtUIhqVzusCkE,11226 -prometheus_client/openmetrics/parser.py,sha256=gVM33y__66qA5CgJZD_hcuqWL8I7fDQ5fo7Skz8jjNY,25002 -prometheus_client/parser.py,sha256=-6c-xuVKm8hJjrcjfw1nMxCAL5c2X4g1CW8n4t6wgk0,12592 +prometheus_client/openmetrics/exposition.py,sha256=VzG8zBijM5y6sGXOssdLpHwV6aa9wqJ5YY8iJcR955U,2993 +prometheus_client/openmetrics/parser.py,sha256=c6vQccyW93MXzc22QGdceETg0m_KMeMyEbKrfObG0R8,22125 +prometheus_client/parser.py,sha256=zuVhB8clFPvQ9wOEj1XikN7NoJe8J3pZcQkNgEUkuXg,7434 prometheus_client/platform_collector.py,sha256=t_GD2oCLN3Pql4TltbNqTap8a4HOtbvBm0OU5_gPn38,1879 prometheus_client/process_collector.py,sha256=B8y36L1iq0c3KFlvdNj1F5JEQLTec116h6y3m9Jhk90,3864 prometheus_client/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -prometheus_client/registry.py,sha256=9nL15Ygxa5CCjCR2YS1tnrfF48ADK2EHHISLw48hjhg,6206 -prometheus_client/samples.py,sha256=7gomVT0FHNScdi2IXu-eY2XNQmFADBZ-mx5l5SpF7SE,2348 +prometheus_client/registry.py,sha256=3R-yxiPitVs36cnIRnotqSJmOPwAQsLz-tl6kw3rcd4,6196 +prometheus_client/samples.py,sha256=Fco7izqcgRn6xYBsPlegIB2gol9fXidrhuCeo3g0V9Y,1520 prometheus_client/twisted/__init__.py,sha256=0RxJjYSOC5p6o2cu6JbfUzc8ReHYQGNv9pKP-U4u7OE,72 prometheus_client/twisted/__pycache__/__init__.cpython-312.pyc,, prometheus_client/twisted/__pycache__/_exposition.cpython-312.pyc,, prometheus_client/twisted/_exposition.py,sha256=2TL2BH5sW0i6H7dHkot9aBH9Ld-I60ax55DuaIWnElo,250 -prometheus_client/utils.py,sha256=19xp6HxU__ZU7h3QrDfKLlSvqO626bL9IkkcoGtwctM,907 -prometheus_client/validation.py,sha256=OzyhTmsP5FzAIc51cJ0JS2hAP_FQ6nggowZRyueHoh8,4170 +prometheus_client/utils.py,sha256=zKJZaW_hyZgQSmkaD-rgT5l-YsT3--le0BRQ7v_x8eE,594 prometheus_client/values.py,sha256=hzThQQd0x4mIPR3ddezQpjUoDVdSBnwem4Z48woxpa8,5002 diff 
--git a/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.10.dist-info/REQUESTED b/venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/REQUESTED similarity index 100% rename from venv/lib/python3.12/site-packages/psycopg2_binary-2.9.10.dist-info/REQUESTED rename to venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/REQUESTED diff --git a/venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/WHEEL b/venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/WHEEL similarity index 65% rename from venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/WHEEL rename to venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/WHEEL index ae527e7..ba48cbc 100644 --- a/venv/lib/python3.12/site-packages/PyJWT-2.10.1.dist-info/WHEEL +++ b/venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/WHEEL @@ -1,5 +1,5 @@ Wheel-Version: 1.0 -Generator: setuptools (75.6.0) +Generator: bdist_wheel (0.41.3) Root-Is-Purelib: true Tag: py3-none-any diff --git a/venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/top_level.txt b/venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/top_level.txt similarity index 100% rename from venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/top_level.txt rename to venv/lib/python3.12/site-packages/prometheus_client-0.18.0.dist-info/top_level.txt diff --git a/venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/METADATA b/venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/METADATA deleted file mode 100644 index 9d7e718..0000000 --- a/venv/lib/python3.12/site-packages/prometheus_client-0.23.1.dist-info/METADATA +++ /dev/null @@ -1,51 +0,0 @@ -Metadata-Version: 2.4 -Name: prometheus_client -Version: 0.23.1 -Summary: Python client for the Prometheus monitoring system. -Author-email: The Prometheus Authors -License-Expression: Apache-2.0 AND BSD-2-Clause -Project-URL: Homepage, https://github.com/prometheus/client_python -Project-URL: Documentation, https://prometheus.github.io/client_python/ -Keywords: prometheus,monitoring,instrumentation,client -Classifier: Development Status :: 4 - Beta -Classifier: Intended Audience :: Developers -Classifier: Intended Audience :: Information Technology -Classifier: Intended Audience :: System Administrators -Classifier: Programming Language :: Python -Classifier: Programming Language :: Python :: 3 -Classifier: Programming Language :: Python :: 3.9 -Classifier: Programming Language :: Python :: 3.10 -Classifier: Programming Language :: Python :: 3.11 -Classifier: Programming Language :: Python :: 3.12 -Classifier: Programming Language :: Python :: 3.13 -Classifier: Programming Language :: Python :: Implementation :: CPython -Classifier: Programming Language :: Python :: Implementation :: PyPy -Classifier: Topic :: System :: Monitoring -Requires-Python: >=3.9 -Description-Content-Type: text/markdown -License-File: LICENSE -License-File: NOTICE -Provides-Extra: twisted -Requires-Dist: twisted; extra == "twisted" -Dynamic: license-file - -# Prometheus Python Client - -The official Python client for [Prometheus](https://prometheus.io). - -## Installation - -``` -pip install prometheus-client -``` - -This package can be found on [PyPI](https://pypi.python.org/pypi/prometheus_client). 
- -## Documentation - -Documentation is available on https://prometheus.github.io/client_python - -## Links - -* [Releases](https://github.com/prometheus/client_python/releases): The releases page shows the history of the project and acts as a changelog. -* [PyPI](https://pypi.python.org/pypi/prometheus_client) diff --git a/venv/lib/python3.12/site-packages/prometheus_client/__init__.py b/venv/lib/python3.12/site-packages/prometheus_client/__init__.py index 221ad27..84a7ba8 100644 --- a/venv/lib/python3.12/site-packages/prometheus_client/__init__.py +++ b/venv/lib/python3.12/site-packages/prometheus_client/__init__.py @@ -5,10 +5,9 @@ from . import ( process_collector, registry, ) from .exposition import ( - CONTENT_TYPE_LATEST, CONTENT_TYPE_PLAIN_0_0_4, CONTENT_TYPE_PLAIN_1_0_0, - delete_from_gateway, generate_latest, instance_ip_grouping_key, - make_asgi_app, make_wsgi_app, MetricsHandler, push_to_gateway, - pushadd_to_gateway, start_http_server, start_wsgi_server, + CONTENT_TYPE_LATEST, delete_from_gateway, generate_latest, + instance_ip_grouping_key, make_asgi_app, make_wsgi_app, MetricsHandler, + push_to_gateway, pushadd_to_gateway, start_http_server, start_wsgi_server, write_to_textfile, ) from .gc_collector import GC_COLLECTOR, GCCollector @@ -34,8 +33,6 @@ __all__ = ( 'enable_created_metrics', 'disable_created_metrics', 'CONTENT_TYPE_LATEST', - 'CONTENT_TYPE_PLAIN_0_0_4', - 'CONTENT_TYPE_PLAIN_1_0_0', 'generate_latest', 'MetricsHandler', 'make_wsgi_app', diff --git a/venv/lib/python3.12/site-packages/prometheus_client/asgi.py b/venv/lib/python3.12/site-packages/prometheus_client/asgi.py index affd984..e1864b8 100644 --- a/venv/lib/python3.12/site-packages/prometheus_client/asgi.py +++ b/venv/lib/python3.12/site-packages/prometheus_client/asgi.py @@ -11,7 +11,7 @@ def make_asgi_app(registry: CollectorRegistry = REGISTRY, disable_compression: b async def prometheus_app(scope, receive, send): assert scope.get("type") == "http" # Prepare parameters - params = parse_qs(scope.get('query_string', b'').decode("utf8")) + params = parse_qs(scope.get('query_string', b'')) accept_header = ",".join([ value.decode("utf8") for (name, value) in scope.get('headers') if name.decode("utf8").lower() == 'accept' diff --git a/venv/lib/python3.12/site-packages/prometheus_client/core.py b/venv/lib/python3.12/site-packages/prometheus_client/core.py index 60f93ce..ad3a454 100644 --- a/venv/lib/python3.12/site-packages/prometheus_client/core.py +++ b/venv/lib/python3.12/site-packages/prometheus_client/core.py @@ -5,10 +5,9 @@ from .metrics_core import ( SummaryMetricFamily, UnknownMetricFamily, UntypedMetricFamily, ) from .registry import CollectorRegistry, REGISTRY -from .samples import BucketSpan, Exemplar, NativeHistogram, Sample, Timestamp +from .samples import Exemplar, Sample, Timestamp __all__ = ( - 'BucketSpan', 'CollectorRegistry', 'Counter', 'CounterMetricFamily', @@ -22,7 +21,6 @@ __all__ = ( 'Info', 'InfoMetricFamily', 'Metric', - 'NativeHistogram', 'REGISTRY', 'Sample', 'StateSetMetricFamily', diff --git a/venv/lib/python3.12/site-packages/prometheus_client/exposition.py b/venv/lib/python3.12/site-packages/prometheus_client/exposition.py index 0d47170..13af927 100644 --- a/venv/lib/python3.12/site-packages/prometheus_client/exposition.py +++ b/venv/lib/python3.12/site-packages/prometheus_client/exposition.py @@ -1,6 +1,5 @@ import base64 from contextlib import closing -from functools import partial import gzip from http.server import BaseHTTPRequestHandler import os @@ -20,12 +19,10 @@ from 
wsgiref.simple_server import make_server, WSGIRequestHandler, WSGIServer from .openmetrics import exposition as openmetrics from .registry import CollectorRegistry, REGISTRY -from .utils import floatToGoString, parse_version +from .utils import floatToGoString __all__ = ( 'CONTENT_TYPE_LATEST', - 'CONTENT_TYPE_PLAIN_0_0_4', - 'CONTENT_TYPE_PLAIN_1_0_0', 'delete_from_gateway', 'generate_latest', 'instance_ip_grouping_key', @@ -39,13 +36,8 @@ __all__ = ( 'write_to_textfile', ) -CONTENT_TYPE_PLAIN_0_0_4 = 'text/plain; version=0.0.4; charset=utf-8' -"""Content type of the compatibility format""" - -CONTENT_TYPE_PLAIN_1_0_0 = 'text/plain; version=1.0.0; charset=utf-8' -"""Content type of the latest format""" - -CONTENT_TYPE_LATEST = CONTENT_TYPE_PLAIN_1_0_0 +CONTENT_TYPE_LATEST = 'text/plain; version=0.0.4; charset=utf-8' +"""Content type of the latest text format""" class _PrometheusRedirectHandler(HTTPRedirectHandler): @@ -126,24 +118,12 @@ def make_wsgi_app(registry: CollectorRegistry = REGISTRY, disable_compression: b accept_header = environ.get('HTTP_ACCEPT') accept_encoding_header = environ.get('HTTP_ACCEPT_ENCODING') params = parse_qs(environ.get('QUERY_STRING', '')) - method = environ['REQUEST_METHOD'] - - if method == 'OPTIONS': - status = '200 OK' - headers = [('Allow', 'OPTIONS,GET')] - output = b'' - elif method != 'GET': - status = '405 Method Not Allowed' - headers = [('Allow', 'OPTIONS,GET')] - output = '# HTTP {}: {}; use OPTIONS or GET\n'.format(status, method).encode() - elif environ['PATH_INFO'] == '/favicon.ico': + if environ['PATH_INFO'] == '/favicon.ico': # Serve empty response for browsers status = '200 OK' - headers = [] + headers = [('', '')] output = b'' else: - # Note: For backwards compatibility, the URI path for GET is not - # constrained to the documented /metrics, but any path is allowed. # Bake output status, headers, output = _bake_output(registry, accept_header, accept_encoding_header, params, disable_compression) # Return output @@ -174,63 +154,12 @@ def _get_best_family(address, port): # binding an ipv6 address is requested. 
# This function is based on what upstream python did for http.server # in https://github.com/python/cpython/pull/11767 - infos = socket.getaddrinfo(address, port, type=socket.SOCK_STREAM, flags=socket.AI_PASSIVE) + infos = socket.getaddrinfo(address, port) family, _, _, _, sockaddr = next(iter(infos)) return family, sockaddr[0] -def _get_ssl_ctx( - certfile: str, - keyfile: str, - protocol: int, - cafile: Optional[str] = None, - capath: Optional[str] = None, - client_auth_required: bool = False, -) -> ssl.SSLContext: - """Load context supports SSL.""" - ssl_cxt = ssl.SSLContext(protocol=protocol) - - if cafile is not None or capath is not None: - try: - ssl_cxt.load_verify_locations(cafile, capath) - except IOError as exc: - exc_type = type(exc) - msg = str(exc) - raise exc_type(f"Cannot load CA certificate chain from file " - f"{cafile!r} or directory {capath!r}: {msg}") - else: - try: - ssl_cxt.load_default_certs(purpose=ssl.Purpose.CLIENT_AUTH) - except IOError as exc: - exc_type = type(exc) - msg = str(exc) - raise exc_type(f"Cannot load default CA certificate chain: {msg}") - - if client_auth_required: - ssl_cxt.verify_mode = ssl.CERT_REQUIRED - - try: - ssl_cxt.load_cert_chain(certfile=certfile, keyfile=keyfile) - except IOError as exc: - exc_type = type(exc) - msg = str(exc) - raise exc_type(f"Cannot load server certificate file {certfile!r} or " - f"its private key file {keyfile!r}: {msg}") - - return ssl_cxt - - -def start_wsgi_server( - port: int, - addr: str = '0.0.0.0', - registry: CollectorRegistry = REGISTRY, - certfile: Optional[str] = None, - keyfile: Optional[str] = None, - client_cafile: Optional[str] = None, - client_capath: Optional[str] = None, - protocol: int = ssl.PROTOCOL_TLS_SERVER, - client_auth_required: bool = False, -) -> Tuple[WSGIServer, threading.Thread]: +def start_wsgi_server(port: int, addr: str = '0.0.0.0', registry: CollectorRegistry = REGISTRY) -> None: """Starts a WSGI server for prometheus metrics as a daemon thread.""" class TmpServer(ThreadingWSGIServer): @@ -239,51 +168,30 @@ def start_wsgi_server( TmpServer.address_family, addr = _get_best_family(addr, port) app = make_wsgi_app(registry) httpd = make_server(addr, port, app, TmpServer, handler_class=_SilentHandler) - if certfile and keyfile: - context = _get_ssl_ctx(certfile, keyfile, protocol, client_cafile, client_capath, client_auth_required) - httpd.socket = context.wrap_socket(httpd.socket, server_side=True) t = threading.Thread(target=httpd.serve_forever) t.daemon = True t.start() - return httpd, t - start_http_server = start_wsgi_server -def generate_latest(registry: CollectorRegistry = REGISTRY, escaping: str = openmetrics.UNDERSCORES) -> bytes: - """ - Generates the exposition format using the basic Prometheus text format. +def generate_latest(registry: CollectorRegistry = REGISTRY) -> bytes: + """Returns the metrics from the registry in latest text format as a string.""" - Params: - registry: CollectorRegistry to export data from. - escaping: Escaping scheme used for metric and label names. - - Returns: UTF-8 encoded string containing the metrics in text format. 
- """ - - def sample_line(samples): - if samples.labels: - labelstr = '{0}'.format(','.join( - # Label values always support UTF-8 + def sample_line(line): + if line.labels: + labelstr = '{{{0}}}'.format(','.join( ['{}="{}"'.format( - openmetrics.escape_label_name(k, escaping), openmetrics._escape(v, openmetrics.ALLOWUTF8, False)) - for k, v in sorted(samples.labels.items())])) + k, v.replace('\\', r'\\').replace('\n', r'\n').replace('"', r'\"')) + for k, v in sorted(line.labels.items())])) else: labelstr = '' timestamp = '' - if samples.timestamp is not None: + if line.timestamp is not None: # Convert to milliseconds. - timestamp = f' {int(float(samples.timestamp) * 1000):d}' - if escaping != openmetrics.ALLOWUTF8 or openmetrics._is_valid_legacy_metric_name(samples.name): - if labelstr: - labelstr = '{{{0}}}'.format(labelstr) - return f'{openmetrics.escape_metric_name(samples.name, escaping)}{labelstr} {floatToGoString(samples.value)}{timestamp}\n' - maybe_comma = '' - if labelstr: - maybe_comma = ',' - return f'{{{openmetrics.escape_metric_name(samples.name, escaping)}{maybe_comma}{labelstr}}} {floatToGoString(samples.value)}{timestamp}\n' + timestamp = f' {int(float(line.timestamp) * 1000):d}' + return f'{line.name}{labelstr} {floatToGoString(line.value)}{timestamp}\n' output = [] for metric in registry.collect(): @@ -306,8 +214,8 @@ def generate_latest(registry: CollectorRegistry = REGISTRY, escaping: str = open mtype = 'untyped' output.append('# HELP {} {}\n'.format( - openmetrics.escape_metric_name(mname, escaping), metric.documentation.replace('\\', r'\\').replace('\n', r'\n'))) - output.append(f'# TYPE {openmetrics.escape_metric_name(mname, escaping)} {mtype}\n') + mname, metric.documentation.replace('\\', r'\\').replace('\n', r'\n'))) + output.append(f'# TYPE {mname} {mtype}\n') om_samples: Dict[str, List[str]] = {} for s in metric.samples: @@ -323,79 +231,20 @@ def generate_latest(registry: CollectorRegistry = REGISTRY, escaping: str = open raise for suffix, lines in sorted(om_samples.items()): - output.append('# HELP {} {}\n'.format(openmetrics.escape_metric_name(metric.name + suffix, escaping), - metric.documentation.replace('\\', r'\\').replace('\n', r'\n'))) - output.append(f'# TYPE {openmetrics.escape_metric_name(metric.name + suffix, escaping)} gauge\n') + output.append('# HELP {}{} {}\n'.format(metric.name, suffix, + metric.documentation.replace('\\', r'\\').replace('\n', r'\n'))) + output.append(f'# TYPE {metric.name}{suffix} gauge\n') output.extend(lines) return ''.join(output).encode('utf-8') def choose_encoder(accept_header: str) -> Tuple[Callable[[CollectorRegistry], bytes], str]: - # Python client library accepts a narrower range of content-types than - # Prometheus does. accept_header = accept_header or '' - escaping = openmetrics.UNDERSCORES for accepted in accept_header.split(','): if accepted.split(';')[0].strip() == 'application/openmetrics-text': - toks = accepted.split(';') - version = _get_version(toks) - escaping = _get_escaping(toks) - # Only return an escaping header if we have a good version and - # mimetype. 
- if not version: - return (partial(openmetrics.generate_latest, escaping=openmetrics.UNDERSCORES, version="1.0.0"), openmetrics.CONTENT_TYPE_LATEST) - if version and parse_version(version) >= (1, 0, 0): - return (partial(openmetrics.generate_latest, escaping=escaping, version=version), - f'application/openmetrics-text; version={version}; charset=utf-8; escaping=' + str(escaping)) - elif accepted.split(';')[0].strip() == 'text/plain': - toks = accepted.split(';') - version = _get_version(toks) - escaping = _get_escaping(toks) - # Only return an escaping header if we have a good version and - # mimetype. - if version and parse_version(version) >= (1, 0, 0): - return (partial(generate_latest, escaping=escaping), - CONTENT_TYPE_LATEST + '; escaping=' + str(escaping)) - return generate_latest, CONTENT_TYPE_PLAIN_0_0_4 - - -def _get_version(accept_header: List[str]) -> str: - """Return the version tag from the Accept header. - - If no version is specified, returns empty string.""" - - for tok in accept_header: - if '=' not in tok: - continue - key, value = tok.strip().split('=', 1) - if key == 'version': - return value - return "" - - -def _get_escaping(accept_header: List[str]) -> str: - """Return the escaping scheme from the Accept header. - - If no escaping scheme is specified or the scheme is not one of the allowed - strings, defaults to UNDERSCORES.""" - - for tok in accept_header: - if '=' not in tok: - continue - key, value = tok.strip().split('=', 1) - if key != 'escaping': - continue - if value == openmetrics.ALLOWUTF8: - return openmetrics.ALLOWUTF8 - elif value == openmetrics.UNDERSCORES: - return openmetrics.UNDERSCORES - elif value == openmetrics.DOTS: - return openmetrics.DOTS - elif value == openmetrics.VALUES: - return openmetrics.VALUES - else: - return openmetrics.UNDERSCORES - return openmetrics.UNDERSCORES + return (openmetrics.generate_latest, + openmetrics.CONTENT_TYPE_LATEST) + return generate_latest, CONTENT_TYPE_LATEST def gzip_accepted(accept_encoding_header: str) -> bool: @@ -444,34 +293,20 @@ class MetricsHandler(BaseHTTPRequestHandler): return MyMetricsHandler -def write_to_textfile(path: str, registry: CollectorRegistry, escaping: str = openmetrics.ALLOWUTF8, tmpdir: Optional[str] = None) -> None: +def write_to_textfile(path: str, registry: CollectorRegistry) -> None: """Write metrics to the given path. This is intended for use with the Node exporter textfile collector. - The path must end in .prom for the textfile collector to process it. + The path must end in .prom for the textfile collector to process it.""" + tmppath = f'{path}.{os.getpid()}.{threading.current_thread().ident}' + with open(tmppath, 'wb') as f: + f.write(generate_latest(registry)) - An optional tmpdir parameter can be set to determine where the - metrics will be temporarily written to. If not set, it will be in - the same directory as the .prom file. 
If provided, the path MUST be - on the same filesystem.""" - if tmpdir is not None: - filename = os.path.basename(path) - tmppath = f'{os.path.join(tmpdir, filename)}.{os.getpid()}.{threading.current_thread().ident}' + # rename(2) is atomic but fails on Windows if the destination file exists + if os.name == 'nt': + os.replace(tmppath, path) else: - tmppath = f'{path}.{os.getpid()}.{threading.current_thread().ident}' - try: - with open(tmppath, 'wb') as f: - f.write(generate_latest(registry, escaping)) - - # rename(2) is atomic but fails on Windows if the destination file exists - if os.name == 'nt': - os.replace(tmppath, path) - else: - os.rename(tmppath, path) - except Exception: - if os.path.exists(tmppath): - os.remove(tmppath) - raise + os.rename(tmppath, path) def _make_handler( @@ -572,7 +407,7 @@ def tls_auth_handler( The default protocol (ssl.PROTOCOL_TLS_CLIENT) will also enable ssl.CERT_REQUIRED and SSLContext.check_hostname by default. This can be disabled by setting insecure_skip_verify to True. - + Both this handler and the TLS feature on pushgateay are experimental.""" context = ssl.SSLContext(protocol=protocol) if cafile is not None: @@ -729,7 +564,7 @@ def _use_gateway( handler( url=url, method=method, timeout=timeout, - headers=[('Content-Type', CONTENT_TYPE_PLAIN_0_0_4)], data=data, + headers=[('Content-Type', CONTENT_TYPE_LATEST)], data=data, )() diff --git a/venv/lib/python3.12/site-packages/prometheus_client/metrics.py b/venv/lib/python3.12/site-packages/prometheus_client/metrics.py index b9f25ff..7e5b030 100644 --- a/venv/lib/python3.12/site-packages/prometheus_client/metrics.py +++ b/venv/lib/python3.12/site-packages/prometheus_client/metrics.py @@ -6,25 +6,22 @@ from typing import ( Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, Type, TypeVar, Union, ) -import warnings from . 
import values # retain this import style for testability from .context_managers import ExceptionCounter, InprogressTracker, Timer -from .metrics_core import Metric +from .metrics_core import ( + Metric, METRIC_LABEL_NAME_RE, METRIC_NAME_RE, + RESERVED_METRIC_LABEL_NAME_RE, +) from .registry import Collector, CollectorRegistry, REGISTRY from .samples import Exemplar, Sample from .utils import floatToGoString, INF -from .validation import ( - _validate_exemplar, _validate_labelnames, _validate_metric_name, -) T = TypeVar('T', bound='MetricWrapperBase') F = TypeVar("F", bound=Callable[..., Any]) def _build_full_name(metric_type, name, namespace, subsystem, unit): - if not name: - raise ValueError('Metric name should not be empty') full_name = '' if namespace: full_name += namespace + '_' @@ -40,6 +37,31 @@ def _build_full_name(metric_type, name, namespace, subsystem, unit): return full_name +def _validate_labelname(l): + if not METRIC_LABEL_NAME_RE.match(l): + raise ValueError('Invalid label metric name: ' + l) + if RESERVED_METRIC_LABEL_NAME_RE.match(l): + raise ValueError('Reserved label metric name: ' + l) + + +def _validate_labelnames(cls, labelnames): + labelnames = tuple(labelnames) + for l in labelnames: + _validate_labelname(l) + if l in cls._reserved_labelnames: + raise ValueError('Reserved label metric name: ' + l) + return labelnames + + +def _validate_exemplar(exemplar): + runes = 0 + for k, v in exemplar.items(): + _validate_labelname(k) + runes += len(k) + runes += len(v) + if runes > 128: + raise ValueError('Exemplar labels have %d UTF-8 characters, exceeding the limit of 128') + def _get_use_created() -> bool: return os.environ.get("PROMETHEUS_DISABLE_CREATED_SERIES", 'False').lower() not in ('true', '1', 't') @@ -88,8 +110,8 @@ class MetricWrapperBase(Collector): def collect(self) -> Iterable[Metric]: metric = self._get_metric() - for suffix, labels, value, timestamp, exemplar, native_histogram_value in self._samples(): - metric.add_sample(self._name + suffix, labels, value, timestamp, exemplar, native_histogram_value) + for suffix, labels, value, timestamp, exemplar in self._samples(): + metric.add_sample(self._name + suffix, labels, value, timestamp, exemplar) return [metric] def __str__(self) -> str: @@ -116,7 +138,8 @@ class MetricWrapperBase(Collector): self._documentation = documentation self._unit = unit - _validate_metric_name(self._name) + if not METRIC_NAME_RE.match(self._name): + raise ValueError('Invalid metric name: ' + self._name) if self._is_parent(): # Prepare the fields needed for child metrics. 
@@ -187,11 +210,6 @@ class MetricWrapperBase(Collector): return self._metrics[labelvalues] def remove(self, *labelvalues: Any) -> None: - if 'prometheus_multiproc_dir' in os.environ or 'PROMETHEUS_MULTIPROC_DIR' in os.environ: - warnings.warn( - "Removal of labels has not been implemented in multi-process mode yet.", - UserWarning) - if not self._labelnames: raise ValueError('No label names were set when constructing %s' % self) @@ -200,15 +218,10 @@ class MetricWrapperBase(Collector): raise ValueError('Incorrect label count (expected %d, got %s)' % (len(self._labelnames), labelvalues)) labelvalues = tuple(str(l) for l in labelvalues) with self._lock: - if labelvalues in self._metrics: - del self._metrics[labelvalues] + del self._metrics[labelvalues] def clear(self) -> None: """Remove all labelsets from the metric""" - if 'prometheus_multiproc_dir' in os.environ or 'PROMETHEUS_MULTIPROC_DIR' in os.environ: - warnings.warn( - "Clearing labels has not been implemented in multi-process mode yet", - UserWarning) with self._lock: self._metrics = {} @@ -223,8 +236,8 @@ class MetricWrapperBase(Collector): metrics = self._metrics.copy() for labels, metric in metrics.items(): series_labels = list(zip(self._labelnames, labels)) - for suffix, sample_labels, value, timestamp, exemplar, native_histogram_value in metric._samples(): - yield Sample(suffix, dict(series_labels + list(sample_labels.items())), value, timestamp, exemplar, native_histogram_value) + for suffix, sample_labels, value, timestamp, exemplar in metric._samples(): + yield Sample(suffix, dict(series_labels + list(sample_labels.items())), value, timestamp, exemplar) def _child_samples(self) -> Iterable[Sample]: # pragma: no cover raise NotImplementedError('_child_samples() must be implemented by %r' % self) @@ -269,12 +282,6 @@ class Counter(MetricWrapperBase): # Count only one type of exception with c.count_exceptions(ValueError): pass - - You can also reset the counter to zero in case your logical "process" restarts - without restarting the actual python process. - - c.reset() - """ _type = 'counter' @@ -293,11 +300,6 @@ class Counter(MetricWrapperBase): _validate_exemplar(exemplar) self._value.set_exemplar(Exemplar(exemplar, amount, time.time())) - def reset(self) -> None: - """Reset the counter to zero. Use this when a logical process restarts without restarting the actual python process.""" - self._value.set(0) - self._created = time.time() - def count_exceptions(self, exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]] = Exception) -> ExceptionCounter: """Count exceptions in a block of code or function. 
@@ -682,8 +684,6 @@ class Info(MetricWrapperBase): if self._labelname_set.intersection(val.keys()): raise ValueError('Overlapping labels for Info metric, metric: {} child: {}'.format( self._labelnames, val)) - if any(i is None for i in val.values()): - raise ValueError('Label value cannot be None') with self._lock: self._value = dict(val) diff --git a/venv/lib/python3.12/site-packages/prometheus_client/metrics_core.py b/venv/lib/python3.12/site-packages/prometheus_client/metrics_core.py index 27d1712..7226d92 100644 --- a/venv/lib/python3.12/site-packages/prometheus_client/metrics_core.py +++ b/venv/lib/python3.12/site-packages/prometheus_client/metrics_core.py @@ -1,12 +1,15 @@ +import re from typing import Dict, List, Optional, Sequence, Tuple, Union -from .samples import Exemplar, NativeHistogram, Sample, Timestamp -from .validation import _validate_metric_name +from .samples import Exemplar, Sample, Timestamp METRIC_TYPES = ( 'counter', 'gauge', 'summary', 'histogram', 'gaugehistogram', 'unknown', 'info', 'stateset', ) +METRIC_NAME_RE = re.compile(r'^[a-zA-Z_:][a-zA-Z0-9_:]*$') +METRIC_LABEL_NAME_RE = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') +RESERVED_METRIC_LABEL_NAME_RE = re.compile(r'^__.*$') class Metric: @@ -21,7 +24,8 @@ class Metric: def __init__(self, name: str, documentation: str, typ: str, unit: str = ''): if unit and not name.endswith("_" + unit): name += "_" + unit - _validate_metric_name(name) + if not METRIC_NAME_RE.match(name): + raise ValueError('Invalid metric name: ' + name) self.name: str = name self.documentation: str = documentation self.unit: str = unit @@ -32,11 +36,11 @@ class Metric: self.type: str = typ self.samples: List[Sample] = [] - def add_sample(self, name: str, labels: Dict[str, str], value: float, timestamp: Optional[Union[Timestamp, float]] = None, exemplar: Optional[Exemplar] = None, native_histogram: Optional[NativeHistogram] = None) -> None: + def add_sample(self, name: str, labels: Dict[str, str], value: float, timestamp: Optional[Union[Timestamp, float]] = None, exemplar: Optional[Exemplar] = None) -> None: """Add a sample to the metric. Internal-only, do not use.""" - self.samples.append(Sample(name, labels, value, timestamp, exemplar, native_histogram)) + self.samples.append(Sample(name, labels, value, timestamp, exemplar)) def __eq__(self, other: object) -> bool: return (isinstance(other, Metric) @@ -112,7 +116,6 @@ class CounterMetricFamily(Metric): labels: Optional[Sequence[str]] = None, created: Optional[float] = None, unit: str = '', - exemplar: Optional[Exemplar] = None, ): # Glue code for pre-OpenMetrics metrics. if name.endswith('_total'): @@ -124,14 +127,13 @@ class CounterMetricFamily(Metric): labels = [] self._labelnames = tuple(labels) if value is not None: - self.add_metric([], value, created, exemplar=exemplar) + self.add_metric([], value, created) def add_metric(self, labels: Sequence[str], value: float, created: Optional[float] = None, timestamp: Optional[Union[Timestamp, float]] = None, - exemplar: Optional[Exemplar] = None, ) -> None: """Add a metric to the metric family. @@ -140,7 +142,7 @@ class CounterMetricFamily(Metric): value: The value of the metric created: Optional unix timestamp the child was created at. 
""" - self.samples.append(Sample(self.name + '_total', dict(zip(self._labelnames, labels)), value, timestamp, exemplar)) + self.samples.append(Sample(self.name + '_total', dict(zip(self._labelnames, labels)), value, timestamp)) if created is not None: self.samples.append(Sample(self.name + '_created', dict(zip(self._labelnames, labels)), created, timestamp)) @@ -282,6 +284,7 @@ class HistogramMetricFamily(Metric): Sample(self.name + '_sum', dict(zip(self._labelnames, labels)), sum_value, timestamp)) + class GaugeHistogramMetricFamily(Metric): """A single gauge histogram and its samples. diff --git a/venv/lib/python3.12/site-packages/prometheus_client/multiprocess.py b/venv/lib/python3.12/site-packages/prometheus_client/multiprocess.py index 2682190..7021b49 100644 --- a/venv/lib/python3.12/site-packages/prometheus_client/multiprocess.py +++ b/venv/lib/python3.12/site-packages/prometheus_client/multiprocess.py @@ -93,7 +93,7 @@ class MultiProcessCollector: buckets = defaultdict(lambda: defaultdict(float)) samples_setdefault = samples.setdefault for s in metric.samples: - name, labels, value, timestamp, exemplar, native_histogram_value = s + name, labels, value, timestamp, exemplar = s if metric.type == 'gauge': without_pid_key = (name, tuple(l for l in labels if l[0] != 'pid')) if metric._multiprocess_mode in ('min', 'livemin'): diff --git a/venv/lib/python3.12/site-packages/prometheus_client/openmetrics/exposition.py b/venv/lib/python3.12/site-packages/prometheus_client/openmetrics/exposition.py index 5e69e46..b019030 100644 --- a/venv/lib/python3.12/site-packages/prometheus_client/openmetrics/exposition.py +++ b/venv/lib/python3.12/site-packages/prometheus_client/openmetrics/exposition.py @@ -1,287 +1,72 @@ #!/usr/bin/env python -from io import StringIO -from sys import maxunicode -from typing import Callable -from ..utils import floatToGoString, parse_version -from ..validation import ( - _is_valid_legacy_labelname, _is_valid_legacy_metric_name, -) +from ..utils import floatToGoString -CONTENT_TYPE_LATEST = 'application/openmetrics-text; version=1.0.0; charset=utf-8' -"""Content type of the latest OpenMetrics 1.0 text format""" -CONTENT_TYPE_LATEST_2_0 = 'application/openmetrics-text; version=2.0.0; charset=utf-8' -"""Content type of the OpenMetrics 2.0 text format""" -ESCAPING_HEADER_TAG = 'escaping' - - -ALLOWUTF8 = 'allow-utf-8' -UNDERSCORES = 'underscores' -DOTS = 'dots' -VALUES = 'values' +CONTENT_TYPE_LATEST = 'application/openmetrics-text; version=0.0.1; charset=utf-8' +"""Content type of the latest OpenMetrics text format""" def _is_valid_exemplar_metric(metric, sample): if metric.type == 'counter' and sample.name.endswith('_total'): return True - if metric.type in ('gaugehistogram') and sample.name.endswith('_bucket'): - return True - if metric.type in ('histogram') and sample.name.endswith('_bucket') or sample.name == metric.name: + if metric.type in ('histogram', 'gaugehistogram') and sample.name.endswith('_bucket'): return True return False -def _compose_exemplar_string(metric, sample, exemplar): - """Constructs an exemplar string.""" - if not _is_valid_exemplar_metric(metric, sample): - raise ValueError(f"Metric {metric.name} has exemplars, but is not a histogram bucket or counter") - labels = '{{{0}}}'.format(','.join( - ['{}="{}"'.format( - k, v.replace('\\', r'\\').replace('\n', r'\n').replace('"', r'\"')) - for k, v in sorted(exemplar.labels.items())])) - if exemplar.timestamp is not None: - exemplarstr = ' # {} {} {}'.format( - labels, - floatToGoString(exemplar.value), 
- exemplar.timestamp, - ) - else: - exemplarstr = ' # {} {}'.format( - labels, - floatToGoString(exemplar.value), - ) - - return exemplarstr - - -def generate_latest(registry, escaping=UNDERSCORES, version="1.0.0"): +def generate_latest(registry): '''Returns the metrics from the registry in latest text format as a string.''' output = [] for metric in registry.collect(): try: mname = metric.name output.append('# HELP {} {}\n'.format( - escape_metric_name(mname, escaping), _escape(metric.documentation, ALLOWUTF8, _is_legacy_labelname_rune))) - output.append(f'# TYPE {escape_metric_name(mname, escaping)} {metric.type}\n') + mname, metric.documentation.replace('\\', r'\\').replace('\n', r'\n').replace('"', r'\"'))) + output.append(f'# TYPE {mname} {metric.type}\n') if metric.unit: - output.append(f'# UNIT {escape_metric_name(mname, escaping)} {metric.unit}\n') + output.append(f'# UNIT {mname} {metric.unit}\n') for s in metric.samples: - if escaping == ALLOWUTF8 and not _is_valid_legacy_metric_name(s.name): - labelstr = escape_metric_name(s.name, escaping) - if s.labels: - labelstr += ',' + if s.labels: + labelstr = '{{{0}}}'.format(','.join( + ['{}="{}"'.format( + k, v.replace('\\', r'\\').replace('\n', r'\n').replace('"', r'\"')) + for k, v in sorted(s.labels.items())])) else: labelstr = '' - - if s.labels: - items = sorted(s.labels.items()) - # Label values always support UTF-8 - labelstr += ','.join( - ['{}="{}"'.format( - escape_label_name(k, escaping), _escape(v, ALLOWUTF8, _is_legacy_labelname_rune)) - for k, v in items]) - if labelstr: - labelstr = "{" + labelstr + "}" if s.exemplar: - exemplarstr = _compose_exemplar_string(metric, s, s.exemplar) + if not _is_valid_exemplar_metric(metric, s): + raise ValueError(f"Metric {metric.name} has exemplars, but is not a histogram bucket or counter") + labels = '{{{0}}}'.format(','.join( + ['{}="{}"'.format( + k, v.replace('\\', r'\\').replace('\n', r'\n').replace('"', r'\"')) + for k, v in sorted(s.exemplar.labels.items())])) + if s.exemplar.timestamp is not None: + exemplarstr = ' # {} {} {}'.format( + labels, + floatToGoString(s.exemplar.value), + s.exemplar.timestamp, + ) + else: + exemplarstr = ' # {} {}'.format( + labels, + floatToGoString(s.exemplar.value), + ) else: exemplarstr = '' timestamp = '' if s.timestamp is not None: timestamp = f' {s.timestamp}' - - # Skip native histogram samples entirely if version < 2.0.0 - if s.native_histogram and parse_version(version) < (2, 0, 0): - continue - - native_histogram = '' - negative_spans = '' - negative_deltas = '' - positive_spans = '' - positive_deltas = '' - - if s.native_histogram: - # Initialize basic nh template - nh_sample_template = '{{count:{},sum:{},schema:{},zero_threshold:{},zero_count:{}' - - args = [ - s.native_histogram.count_value, - s.native_histogram.sum_value, - s.native_histogram.schema, - s.native_histogram.zero_threshold, - s.native_histogram.zero_count, - ] - - # If there are neg spans, append them and the neg deltas to the template and args - if s.native_histogram.neg_spans: - negative_spans = ','.join([f'{ns[0]}:{ns[1]}' for ns in s.native_histogram.neg_spans]) - negative_deltas = ','.join(str(nd) for nd in s.native_histogram.neg_deltas) - nh_sample_template += ',negative_spans:[{}]' - args.append(negative_spans) - nh_sample_template += ',negative_deltas:[{}]' - args.append(negative_deltas) - - # If there are pos spans, append them and the pos spans to the template and args - if s.native_histogram.pos_spans: - positive_spans = ','.join([f'{ps[0]}:{ps[1]}' for ps in 
s.native_histogram.pos_spans]) - positive_deltas = ','.join(f'{pd}' for pd in s.native_histogram.pos_deltas) - nh_sample_template += ',positive_spans:[{}]' - args.append(positive_spans) - nh_sample_template += ',positive_deltas:[{}]' - args.append(positive_deltas) - - # Add closing brace - nh_sample_template += '}}' - - # Format the template with the args - native_histogram = nh_sample_template.format(*args) - - if s.native_histogram.nh_exemplars: - for nh_ex in s.native_histogram.nh_exemplars: - nh_exemplarstr = _compose_exemplar_string(metric, s, nh_ex) - exemplarstr += nh_exemplarstr - - value = '' - if s.native_histogram: - value = native_histogram - elif s.value is not None: - value = floatToGoString(s.value) - if (escaping != ALLOWUTF8) or _is_valid_legacy_metric_name(s.name): - output.append('{}{} {}{}{}\n'.format( - _escape(s.name, escaping, _is_legacy_labelname_rune), - labelstr, - value, - timestamp, - exemplarstr - )) - else: - output.append('{} {}{}{}\n'.format( - labelstr, - value, - timestamp, - exemplarstr - )) + output.append('{}{} {}{}{}\n'.format( + s.name, + labelstr, + floatToGoString(s.value), + timestamp, + exemplarstr, + )) except Exception as exception: exception.args = (exception.args or ('',)) + (metric,) raise output.append('# EOF\n') return ''.join(output).encode('utf-8') - - -def escape_metric_name(s: str, escaping: str = UNDERSCORES) -> str: - """Escapes the metric name and puts it in quotes iff the name does not - conform to the legacy Prometheus character set. - """ - if len(s) == 0: - return s - if escaping == ALLOWUTF8: - if not _is_valid_legacy_metric_name(s): - return '"{}"'.format(_escape(s, escaping, _is_legacy_metric_rune)) - return _escape(s, escaping, _is_legacy_metric_rune) - elif escaping == UNDERSCORES: - if _is_valid_legacy_metric_name(s): - return s - return _escape(s, escaping, _is_legacy_metric_rune) - elif escaping == DOTS: - return _escape(s, escaping, _is_legacy_metric_rune) - elif escaping == VALUES: - if _is_valid_legacy_metric_name(s): - return s - return _escape(s, escaping, _is_legacy_metric_rune) - return s - - -def escape_label_name(s: str, escaping: str = UNDERSCORES) -> str: - """Escapes the label name and puts it in quotes iff the name does not - conform to the legacy Prometheus character set. - """ - if len(s) == 0: - return s - if escaping == ALLOWUTF8: - if not _is_valid_legacy_labelname(s): - return '"{}"'.format(_escape(s, escaping, _is_legacy_labelname_rune)) - return _escape(s, escaping, _is_legacy_labelname_rune) - elif escaping == UNDERSCORES: - if _is_valid_legacy_labelname(s): - return s - return _escape(s, escaping, _is_legacy_labelname_rune) - elif escaping == DOTS: - return _escape(s, escaping, _is_legacy_labelname_rune) - elif escaping == VALUES: - if _is_valid_legacy_labelname(s): - return s - return _escape(s, escaping, _is_legacy_labelname_rune) - return s - - -def _escape(s: str, escaping: str, valid_rune_fn: Callable[[str, int], bool]) -> str: - """Performs backslash escaping on backslash, newline, and double-quote characters. 
- - valid_rune_fn takes the input character and its index in the containing string.""" - if escaping == ALLOWUTF8: - return s.replace('\\', r'\\').replace('\n', r'\n').replace('"', r'\"') - elif escaping == UNDERSCORES: - escaped = StringIO() - for i, b in enumerate(s): - if valid_rune_fn(b, i): - escaped.write(b) - else: - escaped.write('_') - return escaped.getvalue() - elif escaping == DOTS: - escaped = StringIO() - for i, b in enumerate(s): - if b == '_': - escaped.write('__') - elif b == '.': - escaped.write('_dot_') - elif valid_rune_fn(b, i): - escaped.write(b) - else: - escaped.write('__') - return escaped.getvalue() - elif escaping == VALUES: - escaped = StringIO() - escaped.write("U__") - for i, b in enumerate(s): - if b == '_': - escaped.write("__") - elif valid_rune_fn(b, i): - escaped.write(b) - elif not _is_valid_utf8(b): - escaped.write("_FFFD_") - else: - escaped.write('_') - escaped.write(format(ord(b), 'x')) - escaped.write('_') - return escaped.getvalue() - return s - - -def _is_legacy_metric_rune(b: str, i: int) -> bool: - return _is_legacy_labelname_rune(b, i) or b == ':' - - -def _is_legacy_labelname_rune(b: str, i: int) -> bool: - if len(b) != 1: - raise ValueError("Input 'b' must be a single character.") - return ( - ('a' <= b <= 'z') - or ('A' <= b <= 'Z') - or (b == '_') - or ('0' <= b <= '9' and i > 0) - ) - - -_SURROGATE_MIN = 0xD800 -_SURROGATE_MAX = 0xDFFF - - -def _is_valid_utf8(s: str) -> bool: - if 0 <= ord(s) < _SURROGATE_MIN: - return True - if _SURROGATE_MAX < ord(s) <= maxunicode: - return True - return False diff --git a/venv/lib/python3.12/site-packages/prometheus_client/openmetrics/parser.py b/venv/lib/python3.12/site-packages/prometheus_client/openmetrics/parser.py index d967e83..6128a0d 100644 --- a/venv/lib/python3.12/site-packages/prometheus_client/openmetrics/parser.py +++ b/venv/lib/python3.12/site-packages/prometheus_client/openmetrics/parser.py @@ -5,14 +5,9 @@ import io as StringIO import math import re -from ..metrics_core import Metric -from ..parser import ( - _last_unquoted_char, _next_unquoted_char, _parse_value, _split_quoted, - _unquote_unescape, parse_labels, -) -from ..samples import BucketSpan, Exemplar, NativeHistogram, Sample, Timestamp +from ..metrics_core import Metric, METRIC_LABEL_NAME_RE +from ..samples import Exemplar, Sample, Timestamp from ..utils import floatToGoString -from ..validation import _is_valid_legacy_metric_name, _validate_metric_name def text_string_to_metric_families(text): @@ -78,6 +73,16 @@ def _unescape_help(text): return ''.join(result) +def _parse_value(value): + value = ''.join(value) + if value != value.strip() or '_' in value: + raise ValueError(f"Invalid value: {value!r}") + try: + return int(value) + except ValueError: + return float(value) + + def _parse_timestamp(timestamp): timestamp = ''.join(timestamp) if not timestamp: @@ -108,31 +113,165 @@ def _is_character_escaped(s, charpos): return num_bslashes % 2 == 1 +def _parse_labels_with_state_machine(text): + # The { has already been parsed. 
+ state = 'startoflabelname' + labelname = [] + labelvalue = [] + labels = {} + labels_len = 0 + + for char in text: + if state == 'startoflabelname': + if char == '}': + state = 'endoflabels' + else: + state = 'labelname' + labelname.append(char) + elif state == 'labelname': + if char == '=': + state = 'labelvaluequote' + else: + labelname.append(char) + elif state == 'labelvaluequote': + if char == '"': + state = 'labelvalue' + else: + raise ValueError("Invalid line: " + text) + elif state == 'labelvalue': + if char == '\\': + state = 'labelvalueslash' + elif char == '"': + ln = ''.join(labelname) + if not METRIC_LABEL_NAME_RE.match(ln): + raise ValueError("Invalid line, bad label name: " + text) + if ln in labels: + raise ValueError("Invalid line, duplicate label name: " + text) + labels[ln] = ''.join(labelvalue) + labelname = [] + labelvalue = [] + state = 'endoflabelvalue' + else: + labelvalue.append(char) + elif state == 'endoflabelvalue': + if char == ',': + state = 'labelname' + elif char == '}': + state = 'endoflabels' + else: + raise ValueError("Invalid line: " + text) + elif state == 'labelvalueslash': + state = 'labelvalue' + if char == '\\': + labelvalue.append('\\') + elif char == 'n': + labelvalue.append('\n') + elif char == '"': + labelvalue.append('"') + else: + labelvalue.append('\\' + char) + elif state == 'endoflabels': + if char == ' ': + break + else: + raise ValueError("Invalid line: " + text) + labels_len += 1 + return labels, labels_len + + +def _parse_labels(text): + labels = {} + + # Raise error if we don't have valid labels + if text and "=" not in text: + raise ValueError + + # Copy original labels + sub_labels = text + try: + # Process one label at a time + while sub_labels: + # The label name is before the equal + value_start = sub_labels.index("=") + label_name = sub_labels[:value_start] + sub_labels = sub_labels[value_start + 1:] + + # Check for missing quotes + if not sub_labels or sub_labels[0] != '"': + raise ValueError + + # The first quote is guaranteed to be after the equal + value_substr = sub_labels[1:] + + # Check for extra commas + if not label_name or label_name[0] == ',': + raise ValueError + if not value_substr or value_substr[-1] == ',': + raise ValueError + + # Find the last unescaped quote + i = 0 + while i < len(value_substr): + i = value_substr.index('"', i) + if not _is_character_escaped(value_substr[:i], i): + break + i += 1 + + # The label value is between the first and last quote + quote_end = i + 1 + label_value = sub_labels[1:quote_end] + # Replace escaping if needed + if "\\" in label_value: + label_value = _replace_escaping(label_value) + if not METRIC_LABEL_NAME_RE.match(label_name): + raise ValueError("invalid line, bad label name: " + text) + if label_name in labels: + raise ValueError("invalid line, duplicate label name: " + text) + labels[label_name] = label_value + + # Remove the processed label from the sub-slice for next iteration + sub_labels = sub_labels[quote_end + 1:] + if sub_labels.startswith(","): + next_comma = 1 + else: + next_comma = 0 + sub_labels = sub_labels[next_comma:] + + # Check for missing commas + if sub_labels and next_comma == 0: + raise ValueError + + return labels + + except ValueError: + raise ValueError("Invalid labels: " + text) + + def _parse_sample(text): separator = " # " # Detect the labels in the text - label_start = _next_unquoted_char(text, '{') + label_start = text.find("{") if label_start == -1 or separator in text[:label_start]: # We don't have labels, but there could be an exemplar. 
- name_end = _next_unquoted_char(text, ' ') + name_end = text.index(" ") name = text[:name_end] - if not _is_valid_legacy_metric_name(name): - raise ValueError("invalid metric name:" + text) # Parse the remaining text after the name remaining_text = text[name_end + 1:] value, timestamp, exemplar = _parse_remaining_text(remaining_text) return Sample(name, {}, value, timestamp, exemplar) + # The name is before the labels name = text[:label_start] - label_end = _next_unquoted_char(text, '}') - labels = parse_labels(text[label_start + 1:label_end], True) - if not name: - # Name might be in the labels - if '__name__' not in labels: - raise ValueError - name = labels['__name__'] - del labels['__name__'] - elif '__name__' in labels: - raise ValueError("metric name specified more than once") + if separator not in text: + # Line doesn't contain an exemplar + # We can use `rindex` to find `label_end` + label_end = text.rindex("}") + label = text[label_start + 1:label_end] + labels = _parse_labels(label) + else: + # Line potentially contains an exemplar + # Fallback to parsing labels with a state machine + labels, labels_len = _parse_labels_with_state_machine(text[label_start + 1:]) + label_end = labels_len + len(name) # Parsing labels succeeded, continue parsing the remaining text remaining_text = text[label_end + 2:] value, timestamp, exemplar = _parse_remaining_text(remaining_text) @@ -155,12 +294,7 @@ def _parse_remaining_text(text): text = split_text[1] it = iter(text) - in_quotes = False for char in it: - if char == '"': - in_quotes = not in_quotes - if in_quotes: - continue if state == 'timestamp': if char == '#' and not timestamp: state = 'exemplarspace' @@ -180,9 +314,8 @@ def _parse_remaining_text(text): raise ValueError("Invalid line: " + text) elif state == 'exemplarstartoflabels': if char == '{': - label_start = _next_unquoted_char(text, '{') - label_end = _last_unquoted_char(text, '}') - exemplar_labels = parse_labels(text[label_start + 1:label_end], True) + label_start, label_end = text.index("{"), text.rindex("}") + exemplar_labels = _parse_labels(text[label_start + 1:label_end]) state = 'exemplarparsedlabels' else: raise ValueError("Invalid line: " + text) @@ -231,154 +364,6 @@ def _parse_remaining_text(text): return val, ts, exemplar -def _parse_nh_sample(text, suffixes): - """Determines if the line has a native histogram sample, and parses it if so.""" - labels_start = _next_unquoted_char(text, '{') - labels_end = -1 - - # Finding a native histogram sample requires careful parsing of - # possibly-quoted text, which can appear in metric names, label names, and - # values. - # - # First, we need to determine if there are metric labels. Find the space - # between the metric definition and the rest of the line. Look for unquoted - # space or {. - i = 0 - has_metric_labels = False - i = _next_unquoted_char(text, ' {') - if i == -1: - return - - # If the first unquoted char was a {, then that is the metric labels (which - # could contain a UTF-8 metric name). - if text[i] == '{': - has_metric_labels = True - # Consume the labels -- jump ahead to the close bracket. - labels_end = i = _next_unquoted_char(text, '}', i) - if labels_end == -1: - raise ValueError - - # If there is no subsequent unquoted {, then it's definitely not a nh. 
- nh_value_start = _next_unquoted_char(text, '{', i + 1) - if nh_value_start == -1: - return - - # Edge case: if there is an unquoted # between the metric definition and the {, - # then this is actually an exemplar - exemplar = _next_unquoted_char(text, '#', i + 1) - if exemplar != -1 and exemplar < nh_value_start: - return - - nh_value_end = _next_unquoted_char(text, '}', nh_value_start) - if nh_value_end == -1: - raise ValueError - - if has_metric_labels: - labelstext = text[labels_start + 1:labels_end] - labels = parse_labels(labelstext, True) - name_end = labels_start - name = text[:name_end] - if name.endswith(suffixes): - raise ValueError("the sample name of a native histogram with labels should have no suffixes", name) - if not name: - # Name might be in the labels - if '__name__' not in labels: - raise ValueError - name = labels['__name__'] - del labels['__name__'] - # Edge case: the only "label" is the name definition. - if not labels: - labels = None - - nh_value = text[nh_value_start:] - nat_hist_value = _parse_nh_struct(nh_value) - return Sample(name, labels, None, None, None, nat_hist_value) - # check if it's a native histogram - else: - nh_value = text[nh_value_start:] - name_end = nh_value_start - 1 - name = text[:name_end] - if name.endswith(suffixes): - raise ValueError("the sample name of a native histogram should have no suffixes", name) - # Not possible for UTF-8 name here, that would have been caught as having a labelset. - nat_hist_value = _parse_nh_struct(nh_value) - return Sample(name, None, None, None, None, nat_hist_value) - - -def _parse_nh_struct(text): - pattern = r'(\w+):\s*([^,}]+)' - re_spans = re.compile(r'(positive_spans|negative_spans):\[(\d+:\d+(,\d+:\d+)*)\]') - re_deltas = re.compile(r'(positive_deltas|negative_deltas):\[(-?\d+(?:,-?\d+)*)\]') - - items = dict(re.findall(pattern, text)) - span_matches = re_spans.findall(text) - deltas = dict(re_deltas.findall(text)) - - count_value = int(items['count']) - sum_value = int(items['sum']) - schema = int(items['schema']) - zero_threshold = float(items['zero_threshold']) - zero_count = int(items['zero_count']) - - pos_spans = _compose_spans(span_matches, 'positive_spans') - neg_spans = _compose_spans(span_matches, 'negative_spans') - pos_deltas = _compose_deltas(deltas, 'positive_deltas') - neg_deltas = _compose_deltas(deltas, 'negative_deltas') - - return NativeHistogram( - count_value=count_value, - sum_value=sum_value, - schema=schema, - zero_threshold=zero_threshold, - zero_count=zero_count, - pos_spans=pos_spans, - neg_spans=neg_spans, - pos_deltas=pos_deltas, - neg_deltas=neg_deltas - ) - - -def _compose_spans(span_matches, spans_name): - """Takes a list of span matches (expected to be a list of tuples) and a string - (the expected span list name) and processes the list so that the values extracted - from the span matches can be used to compose a tuple of BucketSpan objects""" - spans = {} - for match in span_matches: - # Extract the key from the match (first element of the tuple). - key = match[0] - # Extract the value from the match (second element of the tuple). - # Split the value string by commas to get individual pairs, - # split each pair by ':' to get start and end, and convert them to integers. - value = [tuple(map(int, pair.split(':'))) for pair in match[1].split(',')] - # Store the processed value in the spans dictionary with the key. 
- spans[key] = value - if spans_name not in spans: - return None - out_spans = [] - # Iterate over each start and end tuple in the list of tuples for the specified spans_name. - for start, end in spans[spans_name]: - # Compose a BucketSpan object with the start and end values - # and append it to the out_spans list. - out_spans.append(BucketSpan(start, end)) - # Convert to tuple - out_spans_tuple = tuple(out_spans) - return out_spans_tuple - - -def _compose_deltas(deltas, deltas_name): - """Takes a list of deltas matches (a dictionary) and a string (the expected delta list name), - and processes its elements to compose a tuple of integers representing the deltas""" - if deltas_name not in deltas: - return None - out_deltas = deltas.get(deltas_name) - if out_deltas is not None and out_deltas.strip(): - elems = out_deltas.split(',') - # Convert each element in the list elems to an integer - # after stripping whitespace and create a tuple from these integers. - out_deltas_tuple = tuple(int(x.strip()) for x in elems) - return out_deltas_tuple - - def _group_for_sample(sample, name, typ): if typ == 'info': # We can't distinguish between groups for info metrics. @@ -421,8 +406,6 @@ def _check_histogram(samples, name): for s in samples: suffix = s.name[len(name):] g = _group_for_sample(s, name, 'histogram') - if len(suffix) == 0: - continue if g != group or s.timestamp != timestamp: if group is not None: do_checks() @@ -498,14 +481,11 @@ def text_fd_to_metric_families(fd): raise ValueError("Units not allowed for this metric type: " + name) if typ in ['histogram', 'gaugehistogram']: _check_histogram(samples, name) - _validate_metric_name(name) metric = Metric(name, documentation, typ, unit) # TODO: check labelvalues are valid utf8 metric.samples = samples return metric - is_nh = False - typ = None for line in fd: if line[-1] == '\n': line = line[:-1] @@ -519,19 +499,16 @@ def text_fd_to_metric_families(fd): if line == '# EOF': eof = True elif line.startswith('#'): - parts = _split_quoted(line, ' ', 3) + parts = line.split(' ', 3) if len(parts) < 4: raise ValueError("Invalid line: " + line) - candidate_name, quoted = _unquote_unescape(parts[2]) - if not quoted and not _is_valid_legacy_metric_name(candidate_name): - raise ValueError - if candidate_name == name and samples: + if parts[2] == name and samples: raise ValueError("Received metadata after samples: " + line) - if candidate_name != name: + if parts[2] != name: if name is not None: yield build_metric(name, documentation, typ, unit, samples) # New metric - name = candidate_name + name = parts[2] unit = None typ = None documentation = None @@ -540,8 +517,8 @@ def text_fd_to_metric_families(fd): group_timestamp = None group_timestamp_samples = set() samples = [] - allowed_names = [candidate_name] - + allowed_names = [parts[2]] + if parts[1] == 'HELP': if documentation is not None: raise ValueError("More than one HELP for metric: " + line) @@ -560,25 +537,12 @@ def text_fd_to_metric_families(fd): else: raise ValueError("Invalid line: " + line) else: - if typ == 'histogram': - # set to true to account for native histograms naming exceptions/sanitizing differences - is_nh = True - sample = _parse_nh_sample(line, tuple(type_suffixes['histogram'])) - # It's not a native histogram - if sample is None: - is_nh = False - sample = _parse_sample(line) - else: - is_nh = False - sample = _parse_sample(line) - if sample.name not in allowed_names and not is_nh: + sample = _parse_sample(line) + if sample.name not in allowed_names: if name is not None: yield 
build_metric(name, documentation, typ, unit, samples) # Start an unknown metric. - candidate_name, quoted = _unquote_unescape(sample.name) - if not quoted and not _is_valid_legacy_metric_name(candidate_name): - raise ValueError - name = candidate_name + name = sample.name documentation = None unit = None typ = 'unknown' @@ -606,29 +570,26 @@ def text_fd_to_metric_families(fd): or _isUncanonicalNumber(sample.labels['quantile']))): raise ValueError("Invalid quantile label: " + line) - if not is_nh: - g = tuple(sorted(_group_for_sample(sample, name, typ).items())) - if group is not None and g != group and g in seen_groups: - raise ValueError("Invalid metric grouping: " + line) - if group is not None and g == group: - if (sample.timestamp is None) != (group_timestamp is None): - raise ValueError("Mix of timestamp presence within a group: " + line) - if group_timestamp is not None and group_timestamp > sample.timestamp and typ != 'info': - raise ValueError("Timestamps went backwards within a group: " + line) - else: - group_timestamp_samples = set() - - series_id = (sample.name, tuple(sorted(sample.labels.items()))) - if sample.timestamp != group_timestamp or series_id not in group_timestamp_samples: - # Not a duplicate due to timestamp truncation. - samples.append(sample) - group_timestamp_samples.add(series_id) - - group = g - group_timestamp = sample.timestamp - seen_groups.add(g) + g = tuple(sorted(_group_for_sample(sample, name, typ).items())) + if group is not None and g != group and g in seen_groups: + raise ValueError("Invalid metric grouping: " + line) + if group is not None and g == group: + if (sample.timestamp is None) != (group_timestamp is None): + raise ValueError("Mix of timestamp presence within a group: " + line) + if group_timestamp is not None and group_timestamp > sample.timestamp and typ != 'info': + raise ValueError("Timestamps went backwards within a group: " + line) else: + group_timestamp_samples = set() + + series_id = (sample.name, tuple(sorted(sample.labels.items()))) + if sample.timestamp != group_timestamp or series_id not in group_timestamp_samples: + # Not a duplicate due to timestamp truncation. 
samples.append(sample) + group_timestamp_samples.add(series_id) + + group = g + group_timestamp = sample.timestamp + seen_groups.add(g) if typ == 'stateset' and sample.value not in [0, 1]: raise ValueError("Stateset samples can only have values zero and one: " + line) @@ -645,7 +606,7 @@ def text_fd_to_metric_families(fd): (typ in ['histogram', 'gaugehistogram'] and sample.name.endswith('_bucket')) or (typ in ['counter'] and sample.name.endswith('_total'))): raise ValueError("Invalid line only histogram/gaugehistogram buckets and counters can have exemplars: " + line) - + if name is not None: yield build_metric(name, documentation, typ, unit, samples) diff --git a/venv/lib/python3.12/site-packages/prometheus_client/parser.py b/venv/lib/python3.12/site-packages/prometheus_client/parser.py index ceca273..dc8e30d 100644 --- a/venv/lib/python3.12/site-packages/prometheus_client/parser.py +++ b/venv/lib/python3.12/site-packages/prometheus_client/parser.py @@ -1,13 +1,9 @@ import io as StringIO import re -import string from typing import Dict, Iterable, List, Match, Optional, TextIO, Tuple from .metrics_core import Metric from .samples import Sample -from .validation import ( - _is_valid_legacy_metric_name, _validate_labelname, _validate_metric_name, -) def text_string_to_metric_families(text: str) -> Iterable[Metric]: @@ -49,172 +45,54 @@ def _is_character_escaped(s: str, charpos: int) -> bool: return num_bslashes % 2 == 1 -def parse_labels(labels_string: str, openmetrics: bool = False) -> Dict[str, str]: +def _parse_labels(labels_string: str) -> Dict[str, str]: labels: Dict[str, str] = {} + # Return if we don't have valid labels + if "=" not in labels_string: + return labels + + escaping = False + if "\\" in labels_string: + escaping = True # Copy original labels - sub_labels = labels_string.strip() - if openmetrics and sub_labels and sub_labels[0] == ',': - raise ValueError("leading comma: " + labels_string) + sub_labels = labels_string try: # Process one label at a time while sub_labels: - # The label name is before the equal, or if there's no equal, that's the - # metric name. - - name_term, value_term, sub_labels = _next_term(sub_labels, openmetrics) - if not value_term: - if openmetrics: - raise ValueError("empty term in line: " + labels_string) - continue - - label_name, quoted_name = _unquote_unescape(name_term) - - if not quoted_name and not _is_valid_legacy_metric_name(label_name): - raise ValueError("unquoted UTF-8 metric name") - - # Check for missing quotes - if not value_term or value_term[0] != '"': - raise ValueError + # The label name is before the equal + value_start = sub_labels.index("=") + label_name = sub_labels[:value_start] + sub_labels = sub_labels[value_start + 1:].lstrip() + # Find the first quote after the equal + quote_start = sub_labels.index('"') + 1 + value_substr = sub_labels[quote_start:] - # The first quote is guaranteed to be after the equal. - # Make sure that the next unescaped quote is the last character. 
- i = 1 - while i < len(value_term): - i = value_term.index('"', i) - if not _is_character_escaped(value_term[:i], i): + # Find the last unescaped quote + i = 0 + while i < len(value_substr): + i = value_substr.index('"', i) + if not _is_character_escaped(value_substr, i): break i += 1 + # The label value is between the first and last quote quote_end = i + 1 - if quote_end != len(value_term): - raise ValueError("unexpected text after quote: " + labels_string) + label_value = sub_labels[quote_start:quote_end] + # Replace escaping if needed + if escaping: + label_value = _replace_escaping(label_value) + labels[label_name.strip()] = label_value + + # Remove the processed label from the sub-slice for next iteration + sub_labels = sub_labels[quote_end + 1:] + next_comma = sub_labels.find(",") + 1 + sub_labels = sub_labels[next_comma:].lstrip() - label_value, _ = _unquote_unescape(value_term) - if label_name == '__name__': - _validate_metric_name(label_name) - else: - _validate_labelname(label_name) - if label_name in labels: - raise ValueError("invalid line, duplicate label name: " + labels_string) - labels[label_name] = label_value return labels + except ValueError: - raise ValueError("Invalid labels: " + labels_string) - - -def _next_term(text: str, openmetrics: bool) -> Tuple[str, str, str]: - """Extract the next comma-separated label term from the text. The results - are stripped terms for the label name, label value, and then the remainder - of the string including the final , or }. - - Raises ValueError if the term is empty and we're in openmetrics mode. - """ - - # There may be a leading comma, which is fine here. - if text[0] == ',': - text = text[1:] - if not text: - return "", "", "" - if text[0] == ',': - raise ValueError("multiple commas") - - splitpos = _next_unquoted_char(text, '=,}') - if splitpos >= 0 and text[splitpos] == "=": - labelname = text[:splitpos] - text = text[splitpos + 1:] - splitpos = _next_unquoted_char(text, ',}') - else: - labelname = "__name__" - - if splitpos == -1: - splitpos = len(text) - term = text[:splitpos] - if not term and openmetrics: - raise ValueError("empty term:", term) - - rest = text[splitpos:] - return labelname, term.strip(), rest.strip() - - -def _next_unquoted_char(text: str, chs: Optional[str], startidx: int = 0) -> int: - """Return position of next unquoted character in tuple, or -1 if not found. - - It is always assumed that the first character being checked is not already - inside quotes. - """ - in_quotes = False - if chs is None: - chs = string.whitespace - - for i, c in enumerate(text[startidx:]): - if c == '"' and not _is_character_escaped(text, startidx + i): - in_quotes = not in_quotes - if not in_quotes: - if c in chs: - return startidx + i - return -1 - - -def _last_unquoted_char(text: str, chs: Optional[str]) -> int: - """Return position of last unquoted character in list, or -1 if not found.""" - i = len(text) - 1 - in_quotes = False - if chs is None: - chs = string.whitespace - while i > 0: - if text[i] == '"' and not _is_character_escaped(text, i): - in_quotes = not in_quotes - - if not in_quotes: - if text[i] in chs: - return i - i -= 1 - return -1 - - -def _split_quoted(text, separator, maxsplit=0): - """Splits on split_ch similarly to strings.split, skipping separators if - they are inside quotes. 
- """ - - tokens = [''] - x = 0 - while x < len(text): - split_pos = _next_unquoted_char(text, separator, x) - if split_pos == -1: - tokens[-1] = text[x:] - x = len(text) - continue - # If the first character is the separator keep going. This happens when - # there are double whitespace characters separating symbols. - if split_pos == x: - x += 1 - continue - - if maxsplit > 0 and len(tokens) > maxsplit: - tokens[-1] = text[x:] - break - tokens[-1] = text[x:split_pos] - x = split_pos + 1 - tokens.append('') - return tokens - - -def _unquote_unescape(text): - """Returns the string, and true if it was quoted.""" - if not text: - return text, False - quoted = False - text = text.strip() - if text[0] == '"': - if len(text) == 1 or text[-1] != '"': - raise ValueError("missing close quote") - text = text[1:-1] - quoted = True - if "\\" in text: - text = _replace_escaping(text) - return text, quoted + raise ValueError("Invalid labels: %s" % labels_string) # If we have multiple values only consider the first @@ -226,50 +104,34 @@ def _parse_value_and_timestamp(s: str) -> Tuple[float, Optional[float]]: values = [value.strip() for value in s.split(separator) if value.strip()] if not values: return float(s), None - value = _parse_value(values[0]) - timestamp = (_parse_value(values[-1]) / 1000) if len(values) > 1 else None + value = float(values[0]) + timestamp = (float(values[-1]) / 1000) if len(values) > 1 else None return value, timestamp -def _parse_value(value): - value = ''.join(value) - if value != value.strip() or '_' in value: - raise ValueError(f"Invalid value: {value!r}") - try: - return int(value) - except ValueError: - return float(value) - - -def _parse_sample(text): - separator = " # " +def _parse_sample(text: str) -> Sample: # Detect the labels in the text - label_start = _next_unquoted_char(text, '{') - if label_start == -1 or separator in text[:label_start]: - # We don't have labels, but there could be an exemplar. 
- name_end = _next_unquoted_char(text, ' \t') - name = text[:name_end].strip() - if not _is_valid_legacy_metric_name(name): - raise ValueError("invalid metric name:" + text) - # Parse the remaining text after the name - remaining_text = text[name_end + 1:] - value, timestamp = _parse_value_and_timestamp(remaining_text) + try: + label_start, label_end = text.index("{"), text.rindex("}") + # The name is before the labels + name = text[:label_start].strip() + # We ignore the starting curly brace + label = text[label_start + 1:label_end] + # The value is after the label end (ignoring curly brace) + value, timestamp = _parse_value_and_timestamp(text[label_end + 1:]) + return Sample(name, _parse_labels(label), value, timestamp) + + # We don't have labels + except ValueError: + # Detect what separator is used + separator = " " + if separator not in text: + separator = "\t" + name_end = text.index(separator) + name = text[:name_end] + # The value is after the name + value, timestamp = _parse_value_and_timestamp(text[name_end:]) return Sample(name, {}, value, timestamp) - name = text[:label_start].strip() - label_end = _next_unquoted_char(text[label_start:], '}') + label_start - labels = parse_labels(text[label_start + 1:label_end], False) - if not name: - # Name might be in the labels - if '__name__' not in labels: - raise ValueError - name = labels['__name__'] - del labels['__name__'] - elif '__name__' in labels: - raise ValueError("metric name specified more than once") - # Parsing labels succeeded, continue parsing the remaining text - remaining_text = text[label_end + 1:] - value, timestamp = _parse_value_and_timestamp(remaining_text) - return Sample(name, labels, value, timestamp) def text_fd_to_metric_families(fd: TextIO) -> Iterable[Metric]: @@ -306,38 +168,28 @@ def text_fd_to_metric_families(fd: TextIO) -> Iterable[Metric]: line = line.strip() if line.startswith('#'): - parts = _split_quoted(line, None, 3) + parts = line.split(None, 3) if len(parts) < 2: continue - candidate_name, quoted = '', False - if len(parts) > 2: - # Ignore comment tokens - if parts[1] != 'TYPE' and parts[1] != 'HELP': - continue - candidate_name, quoted = _unquote_unescape(parts[2]) - if not quoted and not _is_valid_legacy_metric_name(candidate_name): - raise ValueError if parts[1] == 'HELP': - if candidate_name != name: + if parts[2] != name: if name != '': yield build_metric(name, documentation, typ, samples) # New metric - name = candidate_name + name = parts[2] typ = 'untyped' samples = [] - allowed_names = [candidate_name] + allowed_names = [parts[2]] if len(parts) == 4: documentation = _replace_help_escaping(parts[3]) else: documentation = '' elif parts[1] == 'TYPE': - if len(parts) < 4: - raise ValueError - if candidate_name != name: + if parts[2] != name: if name != '': yield build_metric(name, documentation, typ, samples) # New metric - name = candidate_name + name = parts[2] documentation = '' samples = [] typ = parts[3] @@ -348,6 +200,9 @@ def text_fd_to_metric_families(fd: TextIO) -> Iterable[Metric]: 'histogram': ['_count', '_sum', '_bucket'], }.get(typ, ['']) allowed_names = [name + n for n in allowed_names] + else: + # Ignore other comment tokens + pass elif line == '': # Ignore blank lines pass diff --git a/venv/lib/python3.12/site-packages/prometheus_client/registry.py b/venv/lib/python3.12/site-packages/prometheus_client/registry.py index 8de4ce9..694e4bd 100644 --- a/venv/lib/python3.12/site-packages/prometheus_client/registry.py +++ 
b/venv/lib/python3.12/site-packages/prometheus_client/registry.py @@ -103,7 +103,7 @@ class CollectorRegistry(Collector): only samples with the given names. Intended usage is: - generate_latest(REGISTRY.restricted_registry(['a_timeseries']), escaping) + generate_latest(REGISTRY.restricted_registry(['a_timeseries'])) Experimental.""" names = set(names) diff --git a/venv/lib/python3.12/site-packages/prometheus_client/samples.py b/venv/lib/python3.12/site-packages/prometheus_client/samples.py index 994d128..d3e351c 100644 --- a/venv/lib/python3.12/site-packages/prometheus_client/samples.py +++ b/venv/lib/python3.12/site-packages/prometheus_client/samples.py @@ -1,4 +1,4 @@ -from typing import Dict, NamedTuple, Optional, Sequence, Union +from typing import Dict, NamedTuple, Optional, Union class Timestamp: @@ -28,16 +28,7 @@ class Timestamp: return not self == other def __gt__(self, other: "Timestamp") -> bool: - return self.nsec > other.nsec if self.sec == other.sec else self.sec > other.sec - - def __lt__(self, other: "Timestamp") -> bool: - return self.nsec < other.nsec if self.sec == other.sec else self.sec < other.sec - - -# BucketSpan is experimental and subject to change at any time. -class BucketSpan(NamedTuple): - offset: int - length: int + return self.sec > other.sec or self.nsec > other.nsec # Timestamp and exemplar are optional. @@ -51,24 +42,9 @@ class Exemplar(NamedTuple): timestamp: Optional[Union[float, Timestamp]] = None -# NativeHistogram is experimental and subject to change at any time. -class NativeHistogram(NamedTuple): - count_value: float - sum_value: float - schema: int - zero_threshold: float - zero_count: float - pos_spans: Optional[Sequence[BucketSpan]] = None - neg_spans: Optional[Sequence[BucketSpan]] = None - pos_deltas: Optional[Sequence[int]] = None - neg_deltas: Optional[Sequence[int]] = None - nh_exemplars: Optional[Sequence[Exemplar]] = None - - class Sample(NamedTuple): name: str labels: Dict[str, str] value: float timestamp: Optional[Union[float, Timestamp]] = None exemplar: Optional[Exemplar] = None - native_histogram: Optional[NativeHistogram] = None diff --git a/venv/lib/python3.12/site-packages/prometheus_client/utils.py b/venv/lib/python3.12/site-packages/prometheus_client/utils.py index 87b75ca..0d2b094 100644 --- a/venv/lib/python3.12/site-packages/prometheus_client/utils.py +++ b/venv/lib/python3.12/site-packages/prometheus_client/utils.py @@ -1,5 +1,4 @@ import math -from typing import Union INF = float("inf") MINUS_INF = float("-inf") @@ -23,14 +22,3 @@ def floatToGoString(d): mantissa = f'{s[0]}.{s[1:dot]}{s[dot + 1:]}'.rstrip('0.') return f'{mantissa}e+0{dot - 1}' return s - - -def parse_version(version_str: str) -> tuple[Union[int, str], ...]: - version: list[Union[int, str]] = [] - for part in version_str.split('.'): - try: - version.append(int(part)) - except ValueError: - version.append(part) - - return tuple(version) diff --git a/venv/lib/python3.12/site-packages/prometheus_client/validation.py b/venv/lib/python3.12/site-packages/prometheus_client/validation.py deleted file mode 100644 index 6fcc801..0000000 --- a/venv/lib/python3.12/site-packages/prometheus_client/validation.py +++ /dev/null @@ -1,124 +0,0 @@ -import os -import re - -METRIC_NAME_RE = re.compile(r'^[a-zA-Z_:][a-zA-Z0-9_:]*$') -METRIC_LABEL_NAME_RE = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') -RESERVED_METRIC_LABEL_NAME_RE = re.compile(r'^__.*$') - - -def _init_legacy_validation() -> bool: - """Retrieve name validation setting from environment.""" - return 
os.environ.get("PROMETHEUS_LEGACY_NAME_VALIDATION", 'False').lower() in ('true', '1', 't') - - -_legacy_validation = _init_legacy_validation() - - -def get_legacy_validation() -> bool: - """Return the current status of the legacy validation setting.""" - return _legacy_validation - - -def disable_legacy_validation(): - """Disable legacy name validation, instead allowing all UTF8 characters.""" - global _legacy_validation - _legacy_validation = False - - -def enable_legacy_validation(): - """Enable legacy name validation instead of allowing all UTF8 characters.""" - global _legacy_validation - _legacy_validation = True - - -def _validate_metric_name(name: str) -> None: - """Raises ValueError if the provided name is not a valid metric name. - - This check uses the global legacy validation setting to determine the validation scheme. - """ - if not name: - raise ValueError("metric name cannot be empty") - if _legacy_validation: - if not METRIC_NAME_RE.match(name): - raise ValueError("invalid metric name " + name) - try: - name.encode('utf-8') - except UnicodeDecodeError: - raise ValueError("invalid metric name " + name) - - -def _is_valid_legacy_metric_name(name: str) -> bool: - """Returns true if the provided metric name conforms to the legacy validation scheme.""" - if len(name) == 0: - return False - return METRIC_NAME_RE.match(name) is not None - - -def _validate_metric_label_name_token(tok: str) -> None: - """Raises ValueError if a parsed label name token is invalid. - - UTF-8 names must be quoted. - """ - if not tok: - raise ValueError("invalid label name token " + tok) - quoted = tok[0] == '"' and tok[-1] == '"' - if not quoted or _legacy_validation: - if not METRIC_LABEL_NAME_RE.match(tok): - raise ValueError("invalid label name token " + tok) - return - try: - tok.encode('utf-8') - except UnicodeDecodeError: - raise ValueError("invalid label name token " + tok) - - -def _validate_labelname(l): - """Raises ValueError if the provided name is not a valid label name. - - This check uses the global legacy validation setting to determine the validation scheme. - """ - if get_legacy_validation(): - if not METRIC_LABEL_NAME_RE.match(l): - raise ValueError('Invalid label metric name: ' + l) - if RESERVED_METRIC_LABEL_NAME_RE.match(l): - raise ValueError('Reserved label metric name: ' + l) - else: - try: - l.encode('utf-8') - except UnicodeDecodeError: - raise ValueError('Invalid label metric name: ' + l) - if RESERVED_METRIC_LABEL_NAME_RE.match(l): - raise ValueError('Reserved label metric name: ' + l) - - -def _is_valid_legacy_labelname(l: str) -> bool: - """Returns true if the provided label name conforms to the legacy validation scheme.""" - if len(l) == 0: - return False - if METRIC_LABEL_NAME_RE.match(l) is None: - return False - return RESERVED_METRIC_LABEL_NAME_RE.match(l) is None - - -def _validate_labelnames(cls, labelnames): - """Raises ValueError if any of the provided names is not a valid label name. - - This check uses the global legacy validation setting to determine the validation scheme. 
- """ - labelnames = tuple(labelnames) - for l in labelnames: - _validate_labelname(l) - if l in cls._reserved_labelnames: - raise ValueError('Reserved label methe fric name: ' + l) - return labelnames - - -def _validate_exemplar(exemplar): - """Raises ValueError if the exemplar is invalid.""" - runes = 0 - for k, v in exemplar.items(): - _validate_labelname(k) - runes += len(k) - runes += len(v) - if runes > 128: - raise ValueError('Exemplar labels have %d UTF-8 characters, exceeding the limit of 128') diff --git a/venv/lib/python3.12/site-packages/psycopg2/_psycopg.cpython-312-x86_64-linux-gnu.so b/venv/lib/python3.12/site-packages/psycopg2/_psycopg.cpython-312-x86_64-linux-gnu.so index d1a7201..49b9ec0 100644 Binary files a/venv/lib/python3.12/site-packages/psycopg2/_psycopg.cpython-312-x86_64-linux-gnu.so and b/venv/lib/python3.12/site-packages/psycopg2/_psycopg.cpython-312-x86_64-linux-gnu.so differ diff --git a/venv/lib/python3.12/site-packages/psycopg2/errorcodes.py b/venv/lib/python3.12/site-packages/psycopg2/errorcodes.py index 0bc9625..aa646c4 100644 --- a/venv/lib/python3.12/site-packages/psycopg2/errorcodes.py +++ b/venv/lib/python3.12/site-packages/psycopg2/errorcodes.py @@ -256,7 +256,6 @@ HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL = '25008' NO_ACTIVE_SQL_TRANSACTION = '25P01' IN_FAILED_SQL_TRANSACTION = '25P02' IDLE_IN_TRANSACTION_SESSION_TIMEOUT = '25P03' -TRANSACTION_TIMEOUT = '25P04' # Class 26 - Invalid SQL Statement Name INVALID_SQL_STATEMENT_NAME = '26000' diff --git a/venv/lib/python3.12/site-packages/pydantic-2.11.9.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.9.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/pydantic-2.11.9.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/psycopg2_binary-2.9.9.dist-info/INSTALLER diff --git a/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.10.dist-info/LICENSE b/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.9.dist-info/LICENSE similarity index 100% rename from venv/lib/python3.12/site-packages/psycopg2_binary-2.9.10.dist-info/LICENSE rename to venv/lib/python3.12/site-packages/psycopg2_binary-2.9.9.dist-info/LICENSE diff --git a/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.10.dist-info/METADATA b/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.9.dist-info/METADATA similarity index 88% rename from venv/lib/python3.12/site-packages/psycopg2_binary-2.9.10.dist-info/METADATA rename to venv/lib/python3.12/site-packages/psycopg2_binary-2.9.9.dist-info/METADATA index 05674fb..724e6c1 100644 --- a/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.10.dist-info/METADATA +++ b/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.9.dist-info/METADATA @@ -1,6 +1,6 @@ Metadata-Version: 2.1 Name: psycopg2-binary -Version: 2.9.10 +Version: 2.9.9 Summary: psycopg2 - Python-PostgreSQL Database Adapter Home-page: https://psycopg.org/ Author: Federico Di Gregorio @@ -9,7 +9,6 @@ Maintainer: Daniele Varrazzo Maintainer-email: daniele.varrazzo@gmail.com License: LGPL with exceptions Project-URL: Homepage, https://psycopg.org/ -Project-URL: Changes, https://www.psycopg.org/docs/news.html Project-URL: Documentation, https://www.psycopg.org/docs/ Project-URL: Code, https://github.com/psycopg/psycopg2 Project-URL: Issue Tracker, https://github.com/psycopg/psycopg2/issues @@ -20,12 +19,12 @@ Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU Library or Lesser General Public License (LGPL) Classifier: 
Programming Language :: Python Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.7 Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Classifier: Programming Language :: Python :: 3.11 Classifier: Programming Language :: Python :: 3.12 -Classifier: Programming Language :: Python :: 3.13 Classifier: Programming Language :: Python :: 3 :: Only Classifier: Programming Language :: Python :: Implementation :: CPython Classifier: Programming Language :: C @@ -36,7 +35,7 @@ Classifier: Topic :: Software Development Classifier: Topic :: Software Development :: Libraries :: Python Modules Classifier: Operating System :: Microsoft :: Windows Classifier: Operating System :: Unix -Requires-Python: >=3.8 +Requires-Python: >=3.7 License-File: LICENSE Psycopg is the most popular PostgreSQL database adapter for the Python @@ -55,18 +54,6 @@ flexible objects adaptation system. Psycopg 2 is both Unicode and Python 3 friendly. -.. Note:: - - The psycopg2 package is still widely used and actively maintained, but it - is not expected to receive new features. - - `Psycopg 3`__ is the evolution of psycopg2 and is where `new features are - being developed`__: if you are starting a new project you should probably - start from 3! - - .. __: https://pypi.org/project/psycopg/ - .. __: https://www.psycopg.org/psycopg3/docs/index.html - Documentation ------------- diff --git a/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.10.dist-info/RECORD b/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.9.dist-info/RECORD similarity index 62% rename from venv/lib/python3.12/site-packages/psycopg2_binary-2.9.10.dist-info/RECORD rename to venv/lib/python3.12/site-packages/psycopg2_binary-2.9.9.dist-info/RECORD index 95be6d6..567ce11 100644 --- a/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.10.dist-info/RECORD +++ b/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.9.dist-info/RECORD @@ -12,33 +12,33 @@ psycopg2/__pycache__/sql.cpython-312.pyc,, psycopg2/__pycache__/tz.cpython-312.pyc,, psycopg2/_ipaddress.py,sha256=jkuyhLgqUGRBcLNWDM8QJysV6q1Npc_RYH4_kE7JZPU,2922 psycopg2/_json.py,sha256=XPn4PnzbTg1Dcqz7n1JMv5dKhB5VFV6834GEtxSawt0,7153 -psycopg2/_psycopg.cpython-312-x86_64-linux-gnu.so,sha256=F5RVch-J7Y0sT6p5xvWgNGtLtTS9MbQTDOtaPeeZaqE,339145 +psycopg2/_psycopg.cpython-312-x86_64-linux-gnu.so,sha256=Y_MtTA7BiSenx2ulSd3tYwfiMjdXdyK-brB_A7-kKD8,339145 psycopg2/_range.py,sha256=sXeenGraJEEw2I3mc8RlmNivy2jMg7zWoanDes2Ywp8,18494 -psycopg2/errorcodes.py,sha256=ko0m0I294B6tb60GAu_gqvoVykqf6cyrGM7MLj4p0Qg,14392 +psycopg2/errorcodes.py,sha256=jb1SkuGq5zJT7F99GFAUi3VQH8GbsB7zRHiLsAWAU0Q,14362 psycopg2/errors.py,sha256=aAS4dJyTg1bsDzJDCRQAMB_s7zv-Q4yB6Yvih26I-0M,1425 psycopg2/extensions.py,sha256=CG0kG5vL8Ot503UGlDXXJJFdFWLg4HE2_c1-lLOLc8M,6797 psycopg2/extras.py,sha256=oBfrdvtWn8ITxc3x-h2h6IwHUsWdVqCdf4Gphb0JqY8,44215 psycopg2/pool.py,sha256=UGEt8IdP3xNc2PGYNlG4sQvg8nhf4aeCnz39hTR0H8I,6316 psycopg2/sql.py,sha256=OcFEAmpe2aMfrx0MEk4Lx00XvXXJCmvllaOVbJY-yoE,14779 psycopg2/tz.py,sha256=r95kK7eGSpOYr_luCyYsznHMzjl52sLjsnSPXkXLzRI,4870 -psycopg2_binary-2.9.10.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -psycopg2_binary-2.9.10.dist-info/LICENSE,sha256=lhS4XfyacsWyyjMUTB1-HtOxwpdFnZ-yimpXYsLo1xs,2238 -psycopg2_binary-2.9.10.dist-info/METADATA,sha256=nKflg_fOjsZqIxaJEFDyyzl0I8YNcHuEN4WrM68RG5E,4924 -psycopg2_binary-2.9.10.dist-info/RECORD,, 
-psycopg2_binary-2.9.10.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -psycopg2_binary-2.9.10.dist-info/WHEEL,sha256=7B4nnId14TToQHuAKpxbDLCJbNciqBsV-mvXE2hVLJc,151 -psycopg2_binary-2.9.10.dist-info/top_level.txt,sha256=7dHGpLqQ3w-vGmGEVn-7uK90qU9fyrGdWWi7S-gTcnM,9 +psycopg2_binary-2.9.9.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +psycopg2_binary-2.9.9.dist-info/LICENSE,sha256=lhS4XfyacsWyyjMUTB1-HtOxwpdFnZ-yimpXYsLo1xs,2238 +psycopg2_binary-2.9.9.dist-info/METADATA,sha256=vkxMt-2J7iReUtyq2SN4AY4BrHDgiz8csUjacUUYWVk,4445 +psycopg2_binary-2.9.9.dist-info/RECORD,, +psycopg2_binary-2.9.9.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +psycopg2_binary-2.9.9.dist-info/WHEEL,sha256=4ZiCdXIWMxJyEClivrQv1QAHZpQh8kVYU92_ZAVwaok,152 +psycopg2_binary-2.9.9.dist-info/top_level.txt,sha256=7dHGpLqQ3w-vGmGEVn-7uK90qU9fyrGdWWi7S-gTcnM,9 psycopg2_binary.libs/libcom_err-2abe824b.so.2.1,sha256=VCbctU3QHJ7t2gXiF58ORxFOi0ilNP_p6UkW55Rxslc,17497 -psycopg2_binary.libs/libcrypto-ea28cefb.so.1.1,sha256=BaFxOxGMaqri29sHgDiXm9wGdC2M_5O7F734mZJBUJk,3133185 +psycopg2_binary.libs/libcrypto-0628e7d4.so.1.1,sha256=iNCZwhYYZg5Gc5zN14JOY0gUyelRkm3wD9A-0kbL6SA,3133185 psycopg2_binary.libs/libgssapi_krb5-497db0c6.so.2.2,sha256=KnSwMw7pcygbJvjr5KzvDr-e6ZxraEl8-RUf_2xMNOE,345209 psycopg2_binary.libs/libk5crypto-b1f99d5c.so.3.1,sha256=mETlAJ5wpq0vsitYcwaBD-Knsbn2uZItqhx4ujRm3ic,219953 psycopg2_binary.libs/libkeyutils-dfe70bd6.so.1.5,sha256=wp5BsDz0st_7-0lglG4rQvgsDKXVPSMdPw_Fl7onRIg,17913 psycopg2_binary.libs/libkrb5-fcafa220.so.3.3,sha256=sqq1KP9MqyFE5c4BskasCfV0oHKlP_Y-qB1rspsmuPE,1018953 psycopg2_binary.libs/libkrb5support-d0bcff84.so.0.1,sha256=anH1fXSP73m05zbVNIh1VF0KIk-okotdYqPPJkf8EJ8,76873 -psycopg2_binary.libs/liblber-e0f57070.so.2.0.200,sha256=LU9hjsIesgayrlVnZhk66rAcud1YdFLYsbpOvmNKhlI,60977 -psycopg2_binary.libs/libldap-c37ed727.so.2.0.200,sha256=Ss_bwD7xjuADElwjXmfqGWQxPxTC2DY0AsRCeJ9IWIk,447313 +psycopg2_binary.libs/liblber-5a1d5ae1.so.2.0.200,sha256=hfC4ohbSIRZ9kJRuaT4PlfOEogZXpgLlY_FgaMNaoYc,60977 +psycopg2_binary.libs/libldap-5d2ff197.so.2.0.200,sha256=ho65rEV6AhnLA0mo-TKB9TcUROR8-uymbfEAGkAcpwQ,447329 psycopg2_binary.libs/libpcre-9513aab5.so.1.2.0,sha256=Au2oUOBJMWVtivgfUXG_902L7BVT09hcPTLX_F7-iGQ,406817 -psycopg2_binary.libs/libpq-e8a033dd.so.5.16,sha256=EZJfnquq7cmvub9JhNH0A6W3g9PprFCQKYXVIQRcBOA,370761 +psycopg2_binary.libs/libpq-e8a033dd.so.5.16,sha256=io69ZDoOBgCMoVj2aGl1-aovIrAOzg2YxumgJeq1iQ8,370777 psycopg2_binary.libs/libsasl2-883649fd.so.3.0.0,sha256=GC8C1eR02yJ82oOrrHQT1DHUh8bAGv0M10HhQM7cDzo,119217 psycopg2_binary.libs/libselinux-0922c95c.so.1,sha256=1PqOf7Ot2WCmgyWlnJaUJErqMhP9c5pQgVywZ8SWVlQ,178337 -psycopg2_binary.libs/libssl-3e69114b.so.1.1,sha256=e6o34uLMd7XUKjB3w0Qu2cvHbfo6fjPRf_NLJKhv5qs,646065 +psycopg2_binary.libs/libssl-3e69114b.so.1.1,sha256=FJ2ccBmBNGXrf07x0GVrPwIORu0BPRHyt_tLogu5jjA,646065 diff --git a/venv/lib/python3.12/site-packages/pydantic-2.11.9.dist-info/REQUESTED b/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.9.dist-info/REQUESTED similarity index 100% rename from venv/lib/python3.12/site-packages/pydantic-2.11.9.dist-info/REQUESTED rename to venv/lib/python3.12/site-packages/psycopg2_binary-2.9.9.dist-info/REQUESTED diff --git a/venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/WHEEL b/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.9.dist-info/WHEEL similarity index 78% rename from venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/WHEEL rename to 
venv/lib/python3.12/site-packages/psycopg2_binary-2.9.9.dist-info/WHEEL index e21e9f2..d1b3f1d 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/WHEEL +++ b/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.9.dist-info/WHEEL @@ -1,5 +1,5 @@ Wheel-Version: 1.0 -Generator: setuptools (80.9.0) +Generator: bdist_wheel (0.41.2) Root-Is-Purelib: false Tag: cp312-cp312-manylinux_2_17_x86_64 Tag: cp312-cp312-manylinux2014_x86_64 diff --git a/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.10.dist-info/top_level.txt b/venv/lib/python3.12/site-packages/psycopg2_binary-2.9.9.dist-info/top_level.txt similarity index 100% rename from venv/lib/python3.12/site-packages/psycopg2_binary-2.9.10.dist-info/top_level.txt rename to venv/lib/python3.12/site-packages/psycopg2_binary-2.9.9.dist-info/top_level.txt diff --git a/venv/lib/python3.12/site-packages/psycopg2_binary.libs/libcrypto-ea28cefb.so.1.1 b/venv/lib/python3.12/site-packages/psycopg2_binary.libs/libcrypto-0628e7d4.so.1.1 similarity index 99% rename from venv/lib/python3.12/site-packages/psycopg2_binary.libs/libcrypto-ea28cefb.so.1.1 rename to venv/lib/python3.12/site-packages/psycopg2_binary.libs/libcrypto-0628e7d4.so.1.1 index 67db91d..34fea43 100755 Binary files a/venv/lib/python3.12/site-packages/psycopg2_binary.libs/libcrypto-ea28cefb.so.1.1 and b/venv/lib/python3.12/site-packages/psycopg2_binary.libs/libcrypto-0628e7d4.so.1.1 differ diff --git a/venv/lib/python3.12/site-packages/psycopg2_binary.libs/liblber-e0f57070.so.2.0.200 b/venv/lib/python3.12/site-packages/psycopg2_binary.libs/liblber-5a1d5ae1.so.2.0.200 similarity index 99% rename from venv/lib/python3.12/site-packages/psycopg2_binary.libs/liblber-e0f57070.so.2.0.200 rename to venv/lib/python3.12/site-packages/psycopg2_binary.libs/liblber-5a1d5ae1.so.2.0.200 index b4e6057..7884bd3 100755 Binary files a/venv/lib/python3.12/site-packages/psycopg2_binary.libs/liblber-e0f57070.so.2.0.200 and b/venv/lib/python3.12/site-packages/psycopg2_binary.libs/liblber-5a1d5ae1.so.2.0.200 differ diff --git a/venv/lib/python3.12/site-packages/psycopg2_binary.libs/libldap-c37ed727.so.2.0.200 b/venv/lib/python3.12/site-packages/psycopg2_binary.libs/libldap-5d2ff197.so.2.0.200 similarity index 99% rename from venv/lib/python3.12/site-packages/psycopg2_binary.libs/libldap-c37ed727.so.2.0.200 rename to venv/lib/python3.12/site-packages/psycopg2_binary.libs/libldap-5d2ff197.so.2.0.200 index dbc45da..3780b75 100755 Binary files a/venv/lib/python3.12/site-packages/psycopg2_binary.libs/libldap-c37ed727.so.2.0.200 and b/venv/lib/python3.12/site-packages/psycopg2_binary.libs/libldap-5d2ff197.so.2.0.200 differ diff --git a/venv/lib/python3.12/site-packages/psycopg2_binary.libs/libpq-e8a033dd.so.5.16 b/venv/lib/python3.12/site-packages/psycopg2_binary.libs/libpq-e8a033dd.so.5.16 index dbfbbec..0c52cfa 100755 Binary files a/venv/lib/python3.12/site-packages/psycopg2_binary.libs/libpq-e8a033dd.so.5.16 and b/venv/lib/python3.12/site-packages/psycopg2_binary.libs/libpq-e8a033dd.so.5.16 differ diff --git a/venv/lib/python3.12/site-packages/psycopg2_binary.libs/libssl-3e69114b.so.1.1 b/venv/lib/python3.12/site-packages/psycopg2_binary.libs/libssl-3e69114b.so.1.1 index 83e4917..b1fd77e 100755 Binary files a/venv/lib/python3.12/site-packages/psycopg2_binary.libs/libssl-3e69114b.so.1.1 and b/venv/lib/python3.12/site-packages/psycopg2_binary.libs/libssl-3e69114b.so.1.1 differ diff --git a/venv/lib/python3.12/site-packages/pydantic-2.11.9.dist-info/METADATA 
b/venv/lib/python3.12/site-packages/pydantic-2.11.9.dist-info/METADATA deleted file mode 100644 index 16e36c1..0000000 --- a/venv/lib/python3.12/site-packages/pydantic-2.11.9.dist-info/METADATA +++ /dev/null @@ -1,787 +0,0 @@ -Metadata-Version: 2.4 -Name: pydantic -Version: 2.11.9 -Summary: Data validation using Python type hints -Project-URL: Homepage, https://github.com/pydantic/pydantic -Project-URL: Documentation, https://docs.pydantic.dev -Project-URL: Funding, https://github.com/sponsors/samuelcolvin -Project-URL: Source, https://github.com/pydantic/pydantic -Project-URL: Changelog, https://docs.pydantic.dev/latest/changelog/ -Author-email: Samuel Colvin , Eric Jolibois , Hasan Ramezani , Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>, Terrence Dorsey , David Montague , Serge Matveenko , Marcelo Trylesinski , Sydney Runkle , David Hewitt , Alex Hall , Victorien Plot -License-Expression: MIT -License-File: LICENSE -Classifier: Development Status :: 5 - Production/Stable -Classifier: Framework :: Hypothesis -Classifier: Framework :: Pydantic -Classifier: Intended Audience :: Developers -Classifier: Intended Audience :: Information Technology -Classifier: License :: OSI Approved :: MIT License -Classifier: Operating System :: OS Independent -Classifier: Programming Language :: Python -Classifier: Programming Language :: Python :: 3 -Classifier: Programming Language :: Python :: 3 :: Only -Classifier: Programming Language :: Python :: 3.9 -Classifier: Programming Language :: Python :: 3.10 -Classifier: Programming Language :: Python :: 3.11 -Classifier: Programming Language :: Python :: 3.12 -Classifier: Programming Language :: Python :: 3.13 -Classifier: Programming Language :: Python :: Implementation :: CPython -Classifier: Programming Language :: Python :: Implementation :: PyPy -Classifier: Topic :: Internet -Classifier: Topic :: Software Development :: Libraries :: Python Modules -Requires-Python: >=3.9 -Requires-Dist: annotated-types>=0.6.0 -Requires-Dist: pydantic-core==2.33.2 -Requires-Dist: typing-extensions>=4.12.2 -Requires-Dist: typing-inspection>=0.4.0 -Provides-Extra: email -Requires-Dist: email-validator>=2.0.0; extra == 'email' -Provides-Extra: timezone -Requires-Dist: tzdata; (python_version >= '3.9' and platform_system == 'Windows') and extra == 'timezone' -Description-Content-Type: text/markdown - -# Pydantic -[![CI](https://img.shields.io/github/actions/workflow/status/pydantic/pydantic/ci.yml?branch=main&logo=github&label=CI)](https://github.com/pydantic/pydantic/actions?query=event%3Apush+branch%3Amain+workflow%3ACI) -[![Coverage](https://coverage-badge.samuelcolvin.workers.dev/pydantic/pydantic.svg)](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic) -[![pypi](https://img.shields.io/pypi/v/pydantic.svg)](https://pypi.python.org/pypi/pydantic) -[![CondaForge](https://img.shields.io/conda/v/conda-forge/pydantic.svg)](https://anaconda.org/conda-forge/pydantic) -[![downloads](https://static.pepy.tech/badge/pydantic/month)](https://pepy.tech/project/pydantic) -[![versions](https://img.shields.io/pypi/pyversions/pydantic.svg)](https://github.com/pydantic/pydantic) -[![license](https://img.shields.io/github/license/pydantic/pydantic.svg)](https://github.com/pydantic/pydantic/blob/main/LICENSE) -[![Pydantic v2](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/pydantic/pydantic/main/docs/badge/v2.json)](https://docs.pydantic.dev/latest/contributing/#badges) 
-[![llms.txt](https://img.shields.io/badge/llms.txt-green)](https://docs.pydantic.dev/latest/llms.txt) - - -Data validation using Python type hints. - -Fast and extensible, Pydantic plays nicely with your linters/IDE/brain. -Define how data should be in pure, canonical Python 3.9+; validate it with Pydantic. - -## Pydantic Logfire :fire: - -We've recently launched Pydantic Logfire to help you monitor your applications. -[Learn more](https://pydantic.dev/articles/logfire-announcement) - -## Pydantic V1.10 vs. V2 - -Pydantic V2 is a ground-up rewrite that offers many new features, performance improvements, and some breaking changes compared to Pydantic V1. - -If you're using Pydantic V1 you may want to look at the -[pydantic V1.10 Documentation](https://docs.pydantic.dev/) or, -[`1.10.X-fixes` git branch](https://github.com/pydantic/pydantic/tree/1.10.X-fixes). Pydantic V2 also ships with the latest version of Pydantic V1 built in so that you can incrementally upgrade your code base and projects: `from pydantic import v1 as pydantic_v1`. - -## Help - -See [documentation](https://docs.pydantic.dev/) for more details. - -## Installation - -Install using `pip install -U pydantic` or `conda install pydantic -c conda-forge`. -For more installation options to make Pydantic even faster, -see the [Install](https://docs.pydantic.dev/install/) section in the documentation. - -## A Simple Example - -```python -from datetime import datetime -from typing import Optional -from pydantic import BaseModel - -class User(BaseModel): - id: int - name: str = 'John Doe' - signup_ts: Optional[datetime] = None - friends: list[int] = [] - -external_data = {'id': '123', 'signup_ts': '2017-06-01 12:22', 'friends': [1, '2', b'3']} -user = User(**external_data) -print(user) -#> User id=123 name='John Doe' signup_ts=datetime.datetime(2017, 6, 1, 12, 22) friends=[1, 2, 3] -print(user.id) -#> 123 -``` - -## Contributing - -For guidance on setting up a development environment and how to make a -contribution to Pydantic, see -[Contributing to Pydantic](https://docs.pydantic.dev/contributing/). - -## Reporting a Security Vulnerability - -See our [security policy](https://github.com/pydantic/pydantic/security/policy). 
- -## Changelog - -## v2.11.9 (2025-09-13) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.11.9) - -### What's Changed - -#### Fixes - -* Backport v1.10.23 changes by [@Viicos](https://github.com/Viicos) - -## v2.11.8 (2025-09-13) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.11.8) - -### What's Changed - -#### Fixes - -* Fix mypy plugin for mypy 1.18 by [@cdce8p](https://github.com/cdce8p) in [#12209](https://github.com/pydantic/pydantic/pull/12209) - -## v2.11.7 (2025-06-14) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.11.7) - -### What's Changed - -#### Fixes - -* Copy `FieldInfo` instance if necessary during `FieldInfo` build by [@Viicos](https://github.com/Viicos) in [#11898](https://github.com/pydantic/pydantic/pull/11898) - -## v2.11.6 (2025-06-13) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.11.6) - -### What's Changed - -#### Fixes - -* Rebuild dataclass fields before schema generation by [@Viicos](https://github.com/Viicos) in [#11949](https://github.com/pydantic/pydantic/pull/11949) -* Always store the original field assignment on `FieldInfo` by [@Viicos](https://github.com/Viicos) in [#11946](https://github.com/pydantic/pydantic/pull/11946) - -## v2.11.5 (2025-05-22) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.11.5) - -### What's Changed - -#### Fixes - -* Check if `FieldInfo` is complete after applying type variable map by [@Viicos](https://github.com/Viicos) in [#11855](https://github.com/pydantic/pydantic/pull/11855) -* Do not delete mock validator/serializer in `model_rebuild()` by [@Viicos](https://github.com/Viicos) in [#11890](https://github.com/pydantic/pydantic/pull/11890) -* Do not duplicate metadata on model rebuild by [@Viicos](https://github.com/Viicos) in [#11902](https://github.com/pydantic/pydantic/pull/11902) - -## v2.11.4 (2025-04-29) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.11.4) - -### What's Changed - -#### Packaging - -* Bump `mkdocs-llmstxt` to v0.2.0 by [@Viicos](https://github.com/Viicos) in [#11725](https://github.com/pydantic/pydantic/pull/11725) - -#### Changes - -* Allow config and bases to be specified together in `create_model()` by [@Viicos](https://github.com/Viicos) in [#11714](https://github.com/pydantic/pydantic/pull/11714). - This change was backported as it was previously possible (although not meant to be supported) - to provide `model_config` as a field, which would make it possible to provide both configuration - and bases. 
- -#### Fixes - -* Remove generics cache workaround by [@Viicos](https://github.com/Viicos) in [#11755](https://github.com/pydantic/pydantic/pull/11755) -* Remove coercion of decimal constraints by [@Viicos](https://github.com/Viicos) in [#11772](https://github.com/pydantic/pydantic/pull/11772) -* Fix crash when expanding root type in the mypy plugin by [@Viicos](https://github.com/Viicos) in [#11735](https://github.com/pydantic/pydantic/pull/11735) -* Fix issue with recursive generic models by [@Viicos](https://github.com/Viicos) in [#11775](https://github.com/pydantic/pydantic/pull/11775) -* Traverse `function-before` schemas during schema gathering by [@Viicos](https://github.com/Viicos) in [#11801](https://github.com/pydantic/pydantic/pull/11801) - -## v2.11.3 (2025-04-08) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.11.3) - -### What's Changed - -#### Packaging - -* Update V1 copy to v1.10.21 by [@Viicos](https://github.com/Viicos) in [#11706](https://github.com/pydantic/pydantic/pull/11706) - -#### Fixes - -* Preserve field description when rebuilding model fields by [@Viicos](https://github.com/Viicos) in [#11698](https://github.com/pydantic/pydantic/pull/11698) - -## v2.11.2 (2025-04-03) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.11.2) - -### What's Changed - -#### Fixes - -* Bump `pydantic-core` to v2.33.1 by [@Viicos](https://github.com/Viicos) in [#11678](https://github.com/pydantic/pydantic/pull/11678) -* Make sure `__pydantic_private__` exists before setting private attributes by [@Viicos](https://github.com/Viicos) in [#11666](https://github.com/pydantic/pydantic/pull/11666) -* Do not override `FieldInfo._complete` when using field from parent class by [@Viicos](https://github.com/Viicos) in [#11668](https://github.com/pydantic/pydantic/pull/11668) -* Provide the available definitions when applying discriminated unions by [@Viicos](https://github.com/Viicos) in [#11670](https://github.com/pydantic/pydantic/pull/11670) -* Do not expand root type in the mypy plugin for variables by [@Viicos](https://github.com/Viicos) in [#11676](https://github.com/pydantic/pydantic/pull/11676) -* Mention the attribute name in model fields deprecation message by [@Viicos](https://github.com/Viicos) in [#11674](https://github.com/pydantic/pydantic/pull/11674) -* Properly validate parameterized mappings by [@Viicos](https://github.com/Viicos) in [#11658](https://github.com/pydantic/pydantic/pull/11658) - -## v2.11.1 (2025-03-28) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.11.1) - -### What's Changed - -#### Fixes - -* Do not override `'definitions-ref'` schemas containing serialization schemas or metadata by [@Viicos](https://github.com/Viicos) in [#11644](https://github.com/pydantic/pydantic/pull/11644) - -## v2.11.0 (2025-03-27) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.11.0) - -### What's Changed - -Pydantic v2.11 is a version strongly focused on build time performance of Pydantic models (and core schema generation in general). -See the [blog post](https://pydantic.dev/articles/pydantic-v2-11-release) for more details. 
- -#### Packaging - -* Bump `pydantic-core` to v2.33.0 by [@Viicos](https://github.com/Viicos) in [#11631](https://github.com/pydantic/pydantic/pull/11631) - -#### New Features - -* Add `encoded_string()` method to the URL types by [@YassinNouh21](https://github.com/YassinNouh21) in [#11580](https://github.com/pydantic/pydantic/pull/11580) -* Add support for `defer_build` with `@validate_call` decorator by [@Viicos](https://github.com/Viicos) in [#11584](https://github.com/pydantic/pydantic/pull/11584) -* Allow `@with_config` decorator to be used with keyword arguments by [@Viicos](https://github.com/Viicos) in [#11608](https://github.com/pydantic/pydantic/pull/11608) -* Simplify customization of default value inclusion in JSON Schema generation by [@Viicos](https://github.com/Viicos) in [#11634](https://github.com/pydantic/pydantic/pull/11634) -* Add `generate_arguments_schema()` function by [@Viicos](https://github.com/Viicos) in [#11572](https://github.com/pydantic/pydantic/pull/11572) - -#### Fixes - -* Allow generic typed dictionaries to be used for unpacked variadic keyword parameters by [@Viicos](https://github.com/Viicos) in [#11571](https://github.com/pydantic/pydantic/pull/11571) -* Fix runtime error when computing model string representation involving cached properties and self-referenced models by [@Viicos](https://github.com/Viicos) in [#11579](https://github.com/pydantic/pydantic/pull/11579) -* Preserve other steps when using the ellipsis in the pipeline API by [@Viicos](https://github.com/Viicos) in [#11626](https://github.com/pydantic/pydantic/pull/11626) -* Fix deferred discriminator application logic by [@Viicos](https://github.com/Viicos) in [#11591](https://github.com/pydantic/pydantic/pull/11591) - -### New Contributors - -* [@cmenon12](https://github.com/cmenon12) made their first contribution in [#11562](https://github.com/pydantic/pydantic/pull/11562) -* [@Jeukoh](https://github.com/Jeukoh) made their first contribution in [#11611](https://github.com/pydantic/pydantic/pull/11611) - -## v2.11.0b2 (2025-03-17) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.11.0b2) - -### What's Changed - -#### Packaging - -* Bump `pydantic-core` to v2.32.0 by [@Viicos](https://github.com/Viicos) in [#11567](https://github.com/pydantic/pydantic/pull/11567) - -#### New Features - -* Add experimental support for free threading by [@Viicos](https://github.com/Viicos) in [#11516](https://github.com/pydantic/pydantic/pull/11516) - -#### Fixes - -* Fix `NotRequired` qualifier not taken into account in stringified annotation by [@Viicos](https://github.com/Viicos) in [#11559](https://github.com/pydantic/pydantic/pull/11559) - -### New Contributors - -* [@joren485](https://github.com/joren485) made their first contribution in [#11547](https://github.com/pydantic/pydantic/pull/11547) - -## v2.11.0b1 (2025-03-06) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.11.0b1) - -### What's Changed - -#### Packaging - -* Add a `check_pydantic_core_version()` function by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11324 -* Remove `greenlet` development dependency by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11351 -* Use the `typing-inspection` library by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11479 -* Bump `pydantic-core` to `v2.31.1` by [@sydney-runkle](https://github.com/sydney-runkle) in 
https://github.com/pydantic/pydantic/pull/11526 - -#### New Features - -* Support unsubstituted type variables with both a default and a bound or constraints by [@FyZzyss](https://github.com/FyZzyss) in https://github.com/pydantic/pydantic/pull/10789 -* Add a `default_factory_takes_validated_data` property to `FieldInfo` by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11034 -* Raise a better error when a generic alias is used inside `type[]` by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11088 -* Properly support PEP 695 generics syntax by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11189 -* Properly support type variable defaults by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11332 -* Add support for validating v6, v7, v8 UUIDs by [@astei](https://github.com/astei) in https://github.com/pydantic/pydantic/pull/11436 -* Improve alias configuration APIs by [@sydney-runkle](https://github.com/sydney-runkle) in https://github.com/pydantic/pydantic/pull/11468 - -#### Changes - -* Rework `create_model` field definitions format by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11032 -* Raise a deprecation warning when a field is annotated as final with a default value by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11168 -* Deprecate accessing `model_fields` and `model_computed_fields` on instances by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11169 -* **Breaking Change:** Move core schema generation logic for path types inside the `GenerateSchema` class by [@sydney-runkle](https://github.com/sydney-runkle) in https://github.com/pydantic/pydantic/pull/10846 -* Remove Python 3.8 Support by [@sydney-runkle](https://github.com/sydney-runkle) in https://github.com/pydantic/pydantic/pull/11258 -* Optimize calls to `get_type_ref` by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/10863 -* Disable `pydantic-core` core schema validation by [@sydney-runkle](https://github.com/sydney-runkle) in https://github.com/pydantic/pydantic/pull/11271 - -#### Performance - -* Only evaluate `FieldInfo` annotations if required during schema building by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/10769 -* Improve `__setattr__` performance of Pydantic models by caching setter functions by [@MarkusSintonen](https://github.com/MarkusSintonen) in https://github.com/pydantic/pydantic/pull/10868 -* Improve annotation application performance by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11186 -* Improve performance of `_typing_extra` module by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11255 -* Refactor and optimize schema cleaning logic by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11244 -* Create a single dictionary when creating a `CoreConfig` instance by [@sydney-runkle](https://github.com/sydney-runkle) in https://github.com/pydantic/pydantic/pull/11384 -* Bump `pydantic-core` and thus use `SchemaValidator` and `SchemaSerializer` caching by [@sydney-runkle](https://github.com/sydney-runkle) in https://github.com/pydantic/pydantic/pull/11402 -* Reuse cached core schemas for parametrized generic Pydantic models by [@MarkusSintonen](https://github.com/MarkusSintonen) in 
https://github.com/pydantic/pydantic/pull/11434 - -#### Fixes - -* Improve `TypeAdapter` instance repr by [@sydney-runkle](https://github.com/sydney-runkle) in https://github.com/pydantic/pydantic/pull/10872 -* Use the correct frame when instantiating a parametrized `TypeAdapter` by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/10893 -* Infer final fields with a default value as class variables in the mypy plugin by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11121 -* Recursively unpack `Literal` values if using PEP 695 type aliases by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11114 -* Override `__subclasscheck__` on `ModelMetaclass` to avoid memory leak and performance issues by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11116 -* Remove unused `_extract_get_pydantic_json_schema()` parameter by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11155 -* Improve discriminated union error message for invalid union variants by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11161 -* Unpack PEP 695 type aliases if using the `Annotated` form by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11109 -* Add missing stacklevel in `deprecated_instance_property` warning by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11200 -* Copy `WithJsonSchema` schema to avoid sharing mutated data by [@thejcannon](https://github.com/thejcannon) in https://github.com/pydantic/pydantic/pull/11014 -* Do not cache parametrized models when in the process of parametrizing another model by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/10704 -* Add discriminated union related metadata entries to the `CoreMetadata` definition by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11216 -* Consolidate schema definitions logic in the `_Definitions` class by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11208 -* Support initializing root model fields with values of the `root` type in the mypy plugin by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11212 -* Fix various issues with dataclasses and `use_attribute_docstrings` by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11246 -* Only compute normalized decimal places if necessary in `decimal_places_validator` by [@misrasaurabh1](https://github.com/misrasaurabh1) in https://github.com/pydantic/pydantic/pull/11281 -* Add support for `validation_alias` in the mypy plugin by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11295 -* Fix JSON Schema reference collection with `"examples"` keys by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11305 -* Do not transform model serializer functions as class methods in the mypy plugin by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11298 -* Simplify `GenerateJsonSchema.literal_schema()` implementation by [@misrasaurabh1](https://github.com/misrasaurabh1) in https://github.com/pydantic/pydantic/pull/11321 -* Add additional allowed schemes for `ClickHouseDsn` by [@Maze21127](https://github.com/Maze21127) in https://github.com/pydantic/pydantic/pull/11319 -* Coerce decimal 
constraints to `Decimal` instances by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11350 -* Use the correct JSON Schema mode when handling function schemas by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11367 -* Improve exception message when encountering recursion errors during type evaluation by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11356 -* Always include `additionalProperties: True` for arbitrary dictionary schemas by [@austinyu](https://github.com/austinyu) in https://github.com/pydantic/pydantic/pull/11392 -* Expose `fallback` parameter in serialization methods by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11398 -* Fix path serialization behavior by [@sydney-runkle](https://github.com/sydney-runkle) in https://github.com/pydantic/pydantic/pull/11416 -* Do not reuse validators and serializers during model rebuild by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11429 -* Collect model fields when rebuilding a model by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11388 -* Allow cached properties to be altered on frozen models by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11432 -* Fix tuple serialization for `Sequence` types by [@sydney-runkle](https://github.com/sydney-runkle) in https://github.com/pydantic/pydantic/pull/11435 -* Fix: do not check for `__get_validators__` on classes where `__get_pydantic_core_schema__` is also defined by [@tlambert03](https://github.com/tlambert03) in https://github.com/pydantic/pydantic/pull/11444 -* Allow callable instances to be used as serializers by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11451 -* Improve error thrown when overriding field with a property by [@sydney-runkle](https://github.com/sydney-runkle) in https://github.com/pydantic/pydantic/pull/11459 -* Fix JSON Schema generation with referenceable core schemas holding JSON metadata by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11475 -* Support strict specification on union member types by [@sydney-runkle](https://github.com/sydney-runkle) in https://github.com/pydantic/pydantic/pull/11481 -* Implicitly set `validate_by_name` to `True` when `validate_by_alias` is `False` by [@sydney-runkle](https://github.com/sydney-runkle) in https://github.com/pydantic/pydantic/pull/11503 -* Change type of `Any` when synthesizing `BaseSettings.__init__` signature in the mypy plugin by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11497 -* Support type variable defaults referencing other type variables by [@Viicos](https://github.com/Viicos) in https://github.com/pydantic/pydantic/pull/11520 -* Fix `ValueError` on year zero by [@davidhewitt](https://github.com/davidhewitt) in https://github.com/pydantic/pydantic-core/pull/1583 -* `dataclass` `InitVar` shouldn't be required on serialization by [@sydney-runkle](https://github.com/sydney-runkle) in https://github.com/pydantic/pydantic-core/pull/1602 - -## New Contributors -* [@FyZzyss](https://github.com/FyZzyss) made their first contribution in https://github.com/pydantic/pydantic/pull/10789 -* [@tamird](https://github.com/tamird) made their first contribution in https://github.com/pydantic/pydantic/pull/10948 -* [@felixxm](https://github.com/felixxm) made their first 
contribution in https://github.com/pydantic/pydantic/pull/11077 -* [@alexprabhat99](https://github.com/alexprabhat99) made their first contribution in https://github.com/pydantic/pydantic/pull/11082 -* [@Kharianne](https://github.com/Kharianne) made their first contribution in https://github.com/pydantic/pydantic/pull/11111 -* [@mdaffad](https://github.com/mdaffad) made their first contribution in https://github.com/pydantic/pydantic/pull/11177 -* [@thejcannon](https://github.com/thejcannon) made their first contribution in https://github.com/pydantic/pydantic/pull/11014 -* [@thomasfrimannkoren](https://github.com/thomasfrimannkoren) made their first contribution in https://github.com/pydantic/pydantic/pull/11251 -* [@usernameMAI](https://github.com/usernameMAI) made their first contribution in https://github.com/pydantic/pydantic/pull/11275 -* [@ananiavito](https://github.com/ananiavito) made their first contribution in https://github.com/pydantic/pydantic/pull/11302 -* [@pawamoy](https://github.com/pawamoy) made their first contribution in https://github.com/pydantic/pydantic/pull/11311 -* [@Maze21127](https://github.com/Maze21127) made their first contribution in https://github.com/pydantic/pydantic/pull/11319 -* [@kauabh](https://github.com/kauabh) made their first contribution in https://github.com/pydantic/pydantic/pull/11369 -* [@jaceklaskowski](https://github.com/jaceklaskowski) made their first contribution in https://github.com/pydantic/pydantic/pull/11353 -* [@tmpbeing](https://github.com/tmpbeing) made their first contribution in https://github.com/pydantic/pydantic/pull/11375 -* [@petyosi](https://github.com/petyosi) made their first contribution in https://github.com/pydantic/pydantic/pull/11405 -* [@austinyu](https://github.com/austinyu) made their first contribution in https://github.com/pydantic/pydantic/pull/11392 -* [@mikeedjones](https://github.com/mikeedjones) made their first contribution in https://github.com/pydantic/pydantic/pull/11402 -* [@astei](https://github.com/astei) made their first contribution in https://github.com/pydantic/pydantic/pull/11436 -* [@dsayling](https://github.com/dsayling) made their first contribution in https://github.com/pydantic/pydantic/pull/11522 -* [@sobolevn](https://github.com/sobolevn) made their first contribution in https://github.com/pydantic/pydantic-core/pull/1645 - -## v2.11.0a2 (2025-02-10) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.11.0a2) - -### What's Changed - -Pydantic v2.11 is a version strongly focused on build time performance of Pydantic models (and core schema generation in general). -This is another early alpha release, meant to collect early feedback from users having issues with core schema builds. 
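As a rough way to gather the kind of feedback these alpha notes ask for, the sketch below times a single `TypeAdapter` core schema build; the timing approach and the chosen type are our own illustration, not something the release notes prescribe.

```python
import time

from pydantic import TypeAdapter

start = time.perf_counter()
adapter = TypeAdapter(dict[str, list[tuple[int, str]]])  # core schema is built here
elapsed_ms = (time.perf_counter() - start) * 1000

print(f"core schema build took {elapsed_ms:.3f} ms")
print(adapter.validate_python({"a": [(1, "x")]}))  # {'a': [(1, 'x')]}
```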
- -#### Packaging - -* Bump `ruff` from 0.9.2 to 0.9.5 by [@Viicos](https://github.com/Viicos) in [#11407](https://github.com/pydantic/pydantic/pull/11407) -* Bump `pydantic-core` to v2.29.0 by [@mikeedjones](https://github.com/mikeedjones) in [#11402](https://github.com/pydantic/pydantic/pull/11402) -* Use locally-built rust with symbols & pgo by [@davidhewitt](https://github.com/davidhewitt) in [#11403](https://github.com/pydantic/pydantic/pull/11403) - - -#### Performance - -* Create a single dictionary when creating a `CoreConfig` instance by [@sydney-runkle](https://github.com/sydney-runkle) in [#11384](https://github.com/pydantic/pydantic/pull/11384) - -#### Fixes - -* Use the correct JSON Schema mode when handling function schemas by [@Viicos](https://github.com/Viicos) in [#11367](https://github.com/pydantic/pydantic/pull/11367) -* Fix JSON Schema reference logic with `examples` keys by [@Viicos](https://github.com/Viicos) in [#11366](https://github.com/pydantic/pydantic/pull/11366) -* Improve exception message when encountering recursion errors during type evaluation by [@Viicos](https://github.com/Viicos) in [#11356](https://github.com/pydantic/pydantic/pull/11356) -* Always include `additionalProperties: True` for arbitrary dictionary schemas by [@austinyu](https://github.com/austinyu) in [#11392](https://github.com/pydantic/pydantic/pull/11392) -* Expose `fallback` parameter in serialization methods by [@Viicos](https://github.com/Viicos) in [#11398](https://github.com/pydantic/pydantic/pull/11398) -* Fix path serialization behavior by [@sydney-runkle](https://github.com/sydney-runkle) in [#11416](https://github.com/pydantic/pydantic/pull/11416) - -### New Contributors - -* [@kauabh](https://github.com/kauabh) made their first contribution in [#11369](https://github.com/pydantic/pydantic/pull/11369) -* [@jaceklaskowski](https://github.com/jaceklaskowski) made their first contribution in [#11353](https://github.com/pydantic/pydantic/pull/11353) -* [@tmpbeing](https://github.com/tmpbeing) made their first contribution in [#11375](https://github.com/pydantic/pydantic/pull/11375) -* [@petyosi](https://github.com/petyosi) made their first contribution in [#11405](https://github.com/pydantic/pydantic/pull/11405) -* [@austinyu](https://github.com/austinyu) made their first contribution in [#11392](https://github.com/pydantic/pydantic/pull/11392) -* [@mikeedjones](https://github.com/mikeedjones) made their first contribution in [#11402](https://github.com/pydantic/pydantic/pull/11402) - -## v2.11.0a1 (2025-01-30) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.11.0a1) - -### What's Changed - -Pydantic v2.11 is a version strongly focused on build time performance of Pydantic models (and core schema generation in general). -This is an early alpha release, meant to collect early feedback from users having issues with core schema builds. 
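The v2.11 pre-release Changes deprecate reading `model_fields` and `model_computed_fields` from instances; the short sketch below (with an invented `User` model) shows the class-level access that remains supported.

```python
from pydantic import BaseModel


class User(BaseModel):
    name: str
    age: int = 0


# Supported: read field metadata from the class (or via type(instance)).
print(list(User.model_fields))        # ['name', 'age']

user = User(name="Ada")
print(list(type(user).model_fields))  # ['name', 'age'], no deprecation warning
```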
- -#### Packaging - -* Bump dawidd6/action-download-artifact from 6 to 7 by [@dependabot](https://github.com/dependabot) in [#11018](https://github.com/pydantic/pydantic/pull/11018) -* Re-enable memray related tests on Python 3.12+ by [@Viicos](https://github.com/Viicos) in [#11191](https://github.com/pydantic/pydantic/pull/11191) -* Bump astral-sh/setup-uv to 5 by [@dependabot](https://github.com/dependabot) in [#11205](https://github.com/pydantic/pydantic/pull/11205) -* Bump `ruff` to v0.9.0 by [@sydney-runkle](https://github.com/sydney-runkle) in [#11254](https://github.com/pydantic/pydantic/pull/11254) -* Regular `uv.lock` deps update by [@sydney-runkle](https://github.com/sydney-runkle) in [#11333](https://github.com/pydantic/pydantic/pull/11333) -* Add a `check_pydantic_core_version()` function by [@Viicos](https://github.com/Viicos) in [#11324](https://github.com/pydantic/pydantic/pull/11324) -* Remove `greenlet` development dependency by [@Viicos](https://github.com/Viicos) in [#11351](https://github.com/pydantic/pydantic/pull/11351) -* Bump `pydantic-core` to v2.28.0 by [@Viicos](https://github.com/Viicos) in [#11364](https://github.com/pydantic/pydantic/pull/11364) - -#### New Features - -* Support unsubstituted type variables with both a default and a bound or constraints by [@FyZzyss](https://github.com/FyZzyss) in [#10789](https://github.com/pydantic/pydantic/pull/10789) -* Add a `default_factory_takes_validated_data` property to `FieldInfo` by [@Viicos](https://github.com/Viicos) in [#11034](https://github.com/pydantic/pydantic/pull/11034) -* Raise a better error when a generic alias is used inside `type[]` by [@Viicos](https://github.com/Viicos) in [#11088](https://github.com/pydantic/pydantic/pull/11088) -* Properly support PEP 695 generics syntax by [@Viicos](https://github.com/Viicos) in [#11189](https://github.com/pydantic/pydantic/pull/11189) -* Properly support type variable defaults by [@Viicos](https://github.com/Viicos) in [#11332](https://github.com/pydantic/pydantic/pull/11332) - -#### Changes - -* Rework `create_model` field definitions format by [@Viicos](https://github.com/Viicos) in [#11032](https://github.com/pydantic/pydantic/pull/11032) -* Raise a deprecation warning when a field is annotated as final with a default value by [@Viicos](https://github.com/Viicos) in [#11168](https://github.com/pydantic/pydantic/pull/11168) -* Deprecate accessing `model_fields` and `model_computed_fields` on instances by [@Viicos](https://github.com/Viicos) in [#11169](https://github.com/pydantic/pydantic/pull/11169) -* Move core schema generation logic for path types inside the `GenerateSchema` class by [@sydney-runkle](https://github.com/sydney-runkle) in [#10846](https://github.com/pydantic/pydantic/pull/10846) -* Move `deque` schema gen to `GenerateSchema` class by [@sydney-runkle](https://github.com/sydney-runkle) in [#11239](https://github.com/pydantic/pydantic/pull/11239) -* Move `Mapping` schema gen to `GenerateSchema` to complete removal of `prepare_annotations_for_known_type` workaround by [@sydney-runkle](https://github.com/sydney-runkle) in [#11247](https://github.com/pydantic/pydantic/pull/11247) -* Remove Python 3.8 Support by [@sydney-runkle](https://github.com/sydney-runkle) in [#11258](https://github.com/pydantic/pydantic/pull/11258) -* Disable `pydantic-core` core schema validation by [@sydney-runkle](https://github.com/sydney-runkle) in [#11271](https://github.com/pydantic/pydantic/pull/11271) - -#### Performance - -* Only evaluate `FieldInfo` annotations if 
required during schema building by [@Viicos](https://github.com/Viicos) in [#10769](https://github.com/pydantic/pydantic/pull/10769) -* Optimize calls to `get_type_ref` by [@Viicos](https://github.com/Viicos) in [#10863](https://github.com/pydantic/pydantic/pull/10863) -* Improve `__setattr__` performance of Pydantic models by caching setter functions by [@MarkusSintonen](https://github.com/MarkusSintonen) in [#10868](https://github.com/pydantic/pydantic/pull/10868) -* Improve annotation application performance by [@Viicos](https://github.com/Viicos) in [#11186](https://github.com/pydantic/pydantic/pull/11186) -* Improve performance of `_typing_extra` module by [@Viicos](https://github.com/Viicos) in [#11255](https://github.com/pydantic/pydantic/pull/11255) -* Refactor and optimize schema cleaning logic by [@Viicos](https://github.com/Viicos) and [@MarkusSintonen](https://github.com/MarkusSintonen) in [#11244](https://github.com/pydantic/pydantic/pull/11244) - -#### Fixes - -* Add validation tests for `_internal/_validators.py` by [@tkasuz](https://github.com/tkasuz) in [#10763](https://github.com/pydantic/pydantic/pull/10763) -* Improve `TypeAdapter` instance repr by [@sydney-runkle](https://github.com/sydney-runkle) in [#10872](https://github.com/pydantic/pydantic/pull/10872) -* Revert "ci: use locally built pydantic-core with debug symbols by [@sydney-runkle](https://github.com/sydney-runkle) in [#10942](https://github.com/pydantic/pydantic/pull/10942) -* Re-enable all FastAPI tests by [@tamird](https://github.com/tamird) in [#10948](https://github.com/pydantic/pydantic/pull/10948) -* Fix typo in HISTORY.md. by [@felixxm](https://github.com/felixxm) in [#11077](https://github.com/pydantic/pydantic/pull/11077) -* Infer final fields with a default value as class variables in the mypy plugin by [@Viicos](https://github.com/Viicos) in [#11121](https://github.com/pydantic/pydantic/pull/11121) -* Recursively unpack `Literal` values if using PEP 695 type aliases by [@Viicos](https://github.com/Viicos) in [#11114](https://github.com/pydantic/pydantic/pull/11114) -* Override `__subclasscheck__` on `ModelMetaclass` to avoid memory leak and performance issues by [@Viicos](https://github.com/Viicos) in [#11116](https://github.com/pydantic/pydantic/pull/11116) -* Remove unused `_extract_get_pydantic_json_schema()` parameter by [@Viicos](https://github.com/Viicos) in [#11155](https://github.com/pydantic/pydantic/pull/11155) -* Add FastAPI and SQLModel to third-party tests by [@sydney-runkle](https://github.com/sydney-runkle) in [#11044](https://github.com/pydantic/pydantic/pull/11044) -* Fix conditional expressions syntax for third-party tests by [@Viicos](https://github.com/Viicos) in [#11162](https://github.com/pydantic/pydantic/pull/11162) -* Move FastAPI tests to third-party workflow by [@Viicos](https://github.com/Viicos) in [#11164](https://github.com/pydantic/pydantic/pull/11164) -* Improve discriminated union error message for invalid union variants by [@Viicos](https://github.com/Viicos) in [#11161](https://github.com/pydantic/pydantic/pull/11161) -* Unpack PEP 695 type aliases if using the `Annotated` form by [@Viicos](https://github.com/Viicos) in [#11109](https://github.com/pydantic/pydantic/pull/11109) -* Include `openapi-python-client` check in issue creation for third-party failures, use `main` branch by [@sydney-runkle](https://github.com/sydney-runkle) in [#11182](https://github.com/pydantic/pydantic/pull/11182) -* Add pandera third-party tests by [@Viicos](https://github.com/Viicos) 
in [#11193](https://github.com/pydantic/pydantic/pull/11193) -* Add ODMantic third-party tests by [@sydney-runkle](https://github.com/sydney-runkle) in [#11197](https://github.com/pydantic/pydantic/pull/11197) -* Add missing stacklevel in `deprecated_instance_property` warning by [@Viicos](https://github.com/Viicos) in [#11200](https://github.com/pydantic/pydantic/pull/11200) -* Copy `WithJsonSchema` schema to avoid sharing mutated data by [@thejcannon](https://github.com/thejcannon) in [#11014](https://github.com/pydantic/pydantic/pull/11014) -* Do not cache parametrized models when in the process of parametrizing another model by [@Viicos](https://github.com/Viicos) in [#10704](https://github.com/pydantic/pydantic/pull/10704) -* Re-enable Beanie third-party tests by [@Viicos](https://github.com/Viicos) in [#11214](https://github.com/pydantic/pydantic/pull/11214) -* Add discriminated union related metadata entries to the `CoreMetadata` definition by [@Viicos](https://github.com/Viicos) in [#11216](https://github.com/pydantic/pydantic/pull/11216) -* Consolidate schema definitions logic in the `_Definitions` class by [@Viicos](https://github.com/Viicos) in [#11208](https://github.com/pydantic/pydantic/pull/11208) -* Support initializing root model fields with values of the `root` type in the mypy plugin by [@Viicos](https://github.com/Viicos) in [#11212](https://github.com/pydantic/pydantic/pull/11212) -* Fix various issues with dataclasses and `use_attribute_docstrings` by [@Viicos](https://github.com/Viicos) in [#11246](https://github.com/pydantic/pydantic/pull/11246) -* Only compute normalized decimal places if necessary in `decimal_places_validator` by [@misrasaurabh1](https://github.com/misrasaurabh1) in [#11281](https://github.com/pydantic/pydantic/pull/11281) -* Fix two misplaced sentences in validation errors documentation by [@ananiavito](https://github.com/ananiavito) in [#11302](https://github.com/pydantic/pydantic/pull/11302) -* Fix mkdocstrings inventory example in documentation by [@pawamoy](https://github.com/pawamoy) in [#11311](https://github.com/pydantic/pydantic/pull/11311) -* Add support for `validation_alias` in the mypy plugin by [@Viicos](https://github.com/Viicos) in [#11295](https://github.com/pydantic/pydantic/pull/11295) -* Do not transform model serializer functions as class methods in the mypy plugin by [@Viicos](https://github.com/Viicos) in [#11298](https://github.com/pydantic/pydantic/pull/11298) -* Simplify `GenerateJsonSchema.literal_schema()` implementation by [@misrasaurabh1](https://github.com/misrasaurabh1) in [#11321](https://github.com/pydantic/pydantic/pull/11321) -* Add additional allowed schemes for `ClickHouseDsn` by [@Maze21127](https://github.com/Maze21127) in [#11319](https://github.com/pydantic/pydantic/pull/11319) -* Coerce decimal constraints to `Decimal` instances by [@Viicos](https://github.com/Viicos) in [#11350](https://github.com/pydantic/pydantic/pull/11350) -* Fix `ValueError` on year zero by [@davidhewitt](https://github.com/davidhewitt) in [pydantic-core#1583](https://github.com/pydantic/pydantic-core/pull/1583) - -### New Contributors - -* [@FyZzyss](https://github.com/FyZzyss) made their first contribution in [#10789](https://github.com/pydantic/pydantic/pull/10789) -* [@tamird](https://github.com/tamird) made their first contribution in [#10948](https://github.com/pydantic/pydantic/pull/10948) -* [@felixxm](https://github.com/felixxm) made their first contribution in [#11077](https://github.com/pydantic/pydantic/pull/11077) -* 
[@alexprabhat99](https://github.com/alexprabhat99) made their first contribution in [#11082](https://github.com/pydantic/pydantic/pull/11082) -* [@Kharianne](https://github.com/Kharianne) made their first contribution in [#11111](https://github.com/pydantic/pydantic/pull/11111) -* [@mdaffad](https://github.com/mdaffad) made their first contribution in [#11177](https://github.com/pydantic/pydantic/pull/11177) -* [@thejcannon](https://github.com/thejcannon) made their first contribution in [#11014](https://github.com/pydantic/pydantic/pull/11014) -* [@thomasfrimannkoren](https://github.com/thomasfrimannkoren) made their first contribution in [#11251](https://github.com/pydantic/pydantic/pull/11251) -* [@usernameMAI](https://github.com/usernameMAI) made their first contribution in [#11275](https://github.com/pydantic/pydantic/pull/11275) -* [@ananiavito](https://github.com/ananiavito) made their first contribution in [#11302](https://github.com/pydantic/pydantic/pull/11302) -* [@pawamoy](https://github.com/pawamoy) made their first contribution in [#11311](https://github.com/pydantic/pydantic/pull/11311) -* [@Maze21127](https://github.com/Maze21127) made their first contribution in [#11319](https://github.com/pydantic/pydantic/pull/11319) - -## v2.10.6 (2025-01-23) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.10.6) - -### What's Changed - -#### Fixes - -* Fix JSON Schema reference collection with `'examples'` keys by [@Viicos](https://github.com/Viicos) in [#11325](https://github.com/pydantic/pydantic/pull/11325) -* Fix url python serialization by [@sydney-runkle](https://github.com/sydney-runkle) in [#11331](https://github.com/pydantic/pydantic/pull/11331) - -## v2.10.5 (2025-01-08) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.10.5) - -### What's Changed - -#### Fixes - -* Remove custom MRO implementation of Pydantic models by [@Viicos](https://github.com/Viicos) in [#11184](https://github.com/pydantic/pydantic/pull/11184) -* Fix URL serialization for unions by [@sydney-runkle](https://github.com/sydney-runkle) in [#11233](https://github.com/pydantic/pydantic/pull/11233) - -## v2.10.4 (2024-12-18) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.10.4) - -### What's Changed - -#### Packaging - -* Bump `pydantic-core` to v2.27.2 by [@davidhewitt](https://github.com/davidhewitt) in [#11138](https://github.com/pydantic/pydantic/pull/11138) - -#### Fixes - -* Fix for comparison of `AnyUrl` objects by [@alexprabhat99](https://github.com/alexprabhat99) in [#11082](https://github.com/pydantic/pydantic/pull/11082) -* Properly fetch PEP 695 type params for functions, do not fetch annotations from signature by [@Viicos](https://github.com/Viicos) in [#11093](https://github.com/pydantic/pydantic/pull/11093) -* Include JSON Schema input core schema in function schemas by [@Viicos](https://github.com/Viicos) in [#11085](https://github.com/pydantic/pydantic/pull/11085) -* Add `len` to `_BaseUrl` to avoid TypeError by [@Kharianne](https://github.com/Kharianne) in [#11111](https://github.com/pydantic/pydantic/pull/11111) -* Make sure the type reference is removed from the seen references by [@Viicos](https://github.com/Viicos) in [#11143](https://github.com/pydantic/pydantic/pull/11143) - -### New Contributors - -* [@FyZzyss](https://github.com/FyZzyss) made their first contribution in [#10789](https://github.com/pydantic/pydantic/pull/10789) -* [@tamird](https://github.com/tamird) made their first contribution in 
[#10948](https://github.com/pydantic/pydantic/pull/10948) -* [@felixxm](https://github.com/felixxm) made their first contribution in [#11077](https://github.com/pydantic/pydantic/pull/11077) -* [@alexprabhat99](https://github.com/alexprabhat99) made their first contribution in [#11082](https://github.com/pydantic/pydantic/pull/11082) -* [@Kharianne](https://github.com/Kharianne) made their first contribution in [#11111](https://github.com/pydantic/pydantic/pull/11111) - -## v2.10.3 (2024-12-03) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.10.3) - -### What's Changed - -#### Fixes - -* Set fields when `defer_build` is set on Pydantic dataclasses by [@Viicos](https://github.com/Viicos) in [#10984](https://github.com/pydantic/pydantic/pull/10984) -* Do not resolve the JSON Schema reference for `dict` core schema keys by [@Viicos](https://github.com/Viicos) in [#10989](https://github.com/pydantic/pydantic/pull/10989) -* Use the globals of the function when evaluating the return type for `PlainSerializer` and `WrapSerializer` functions by [@Viicos](https://github.com/Viicos) in [#11008](https://github.com/pydantic/pydantic/pull/11008) -* Fix host required enforcement for urls to be compatible with v2.9 behavior by [@sydney-runkle](https://github.com/sydney-runkle) in [#11027](https://github.com/pydantic/pydantic/pull/11027) -* Add a `default_factory_takes_validated_data` property to `FieldInfo` by [@Viicos](https://github.com/Viicos) in [#11034](https://github.com/pydantic/pydantic/pull/11034) -* Fix url json schema in `serialization` mode by [@sydney-runkle](https://github.com/sydney-runkle) in [#11035](https://github.com/pydantic/pydantic/pull/11035) - -## v2.10.2 (2024-11-25) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.10.2) - -### What's Changed - -#### Fixes - -* Only evaluate FieldInfo annotations if required during schema building by [@Viicos](https://github.com/Viicos) in [#10769](https://github.com/pydantic/pydantic/pull/10769) -* Do not evaluate annotations for private fields by [@Viicos](https://github.com/Viicos) in [#10962](https://github.com/pydantic/pydantic/pull/10962) -* Support serialization as any for `Secret` types and `Url` types by [@sydney-runkle](https://github.com/sydney-runkle) in [#10947](https://github.com/pydantic/pydantic/pull/10947) -* Fix type hint of `Field.default` to be compatible with Python 3.8 and 3.9 by [@Viicos](https://github.com/Viicos) in [#10972](https://github.com/pydantic/pydantic/pull/10972) -* Add hashing support for URL types by [@sydney-runkle](https://github.com/sydney-runkle) in [#10975](https://github.com/pydantic/pydantic/pull/10975) -* Hide `BaseModel.__replace__` definition from type checkers by [@Viicos](https://github.com/Viicos) in [#10979](https://github.com/pydantic/pydantic/pull/10979) - -## v2.10.1 (2024-11-21) - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.10.1) - -### What's Changed - -#### Packaging - -* Bump `pydantic-core` version to `v2.27.1` by [@sydney-runkle](https://github.com/sydney-runkle) in [#10938](https://github.com/pydantic/pydantic/pull/10938) - -#### Fixes - -* Use the correct frame when instantiating a parametrized `TypeAdapter` by [@Viicos](https://github.com/Viicos) in [#10893](https://github.com/pydantic/pydantic/pull/10893) -* Relax check for validated data in `default_factory` utils by [@sydney-runkle](https://github.com/sydney-runkle) in [#10909](https://github.com/pydantic/pydantic/pull/10909) -* Fix type checking issue with 
`model_fields` and `model_computed_fields` by [@sydney-runkle](https://github.com/sydney-runkle) in [#10911](https://github.com/pydantic/pydantic/pull/10911) -* Use the parent configuration during schema generation for stdlib `dataclass`es by [@sydney-runkle](https://github.com/sydney-runkle) in [#10928](https://github.com/pydantic/pydantic/pull/10928) -* Use the `globals` of the function when evaluating the return type of serializers and `computed_field`s by [@Viicos](https://github.com/Viicos) in [#10929](https://github.com/pydantic/pydantic/pull/10929) -* Fix URL constraint application by [@sydney-runkle](https://github.com/sydney-runkle) in [#10922](https://github.com/pydantic/pydantic/pull/10922) -* Fix URL equality with different validation methods by [@sydney-runkle](https://github.com/sydney-runkle) in [#10934](https://github.com/pydantic/pydantic/pull/10934) -* Fix JSON schema title when specified as `''` by [@sydney-runkle](https://github.com/sydney-runkle) in [#10936](https://github.com/pydantic/pydantic/pull/10936) -* Fix `python` mode serialization for `complex` inference by [@sydney-runkle](https://github.com/sydney-runkle) in [pydantic-core#1549](https://github.com/pydantic/pydantic-core/pull/1549) - -### New Contributors - -## v2.10.0 (2024-11-20) - -The code released in v2.10.0 is practically identical to that of v2.10.0b2. - -[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.10.0) - -See the [v2.10 release blog post](https://pydantic.dev/articles/pydantic-v2-10-release) for the highlights! - -### What's Changed - -#### Packaging - -* Bump `pydantic-core` to `v2.27.0` by [@sydney-runkle](https://github.com/sydney-runkle) in [#10825](https://github.com/pydantic/pydantic/pull/10825) -* Replaced pdm with uv by [@frfahim](https://github.com/frfahim) in [#10727](https://github.com/pydantic/pydantic/pull/10727) - -#### New Features - -* Support `fractions.Fraction` by [@sydney-runkle](https://github.com/sydney-runkle) in [#10318](https://github.com/pydantic/pydantic/pull/10318) -* Support `Hashable` for json validation by [@sydney-runkle](https://github.com/sydney-runkle) in [#10324](https://github.com/pydantic/pydantic/pull/10324) -* Add a `SocketPath` type for `linux` systems by [@theunkn0wn1](https://github.com/theunkn0wn1) in [#10378](https://github.com/pydantic/pydantic/pull/10378) -* Allow arbitrary refs in JSON schema `examples` by [@sydney-runkle](https://github.com/sydney-runkle) in [#10417](https://github.com/pydantic/pydantic/pull/10417) -* Support `defer_build` for Pydantic dataclasses by [@Viicos](https://github.com/Viicos) in [#10313](https://github.com/pydantic/pydantic/pull/10313) -* Adding v1 / v2 incompatibility warning for nested v1 model by [@sydney-runkle](https://github.com/sydney-runkle) in [#10431](https://github.com/pydantic/pydantic/pull/10431) -* Add support for unpacked `TypedDict` to type hint variadic keyword arguments with `@validate_call` by [@Viicos](https://github.com/Viicos) in [#10416](https://github.com/pydantic/pydantic/pull/10416) -* Support compiled patterns in `protected_namespaces` by [@sydney-runkle](https://github.com/sydney-runkle) in [#10522](https://github.com/pydantic/pydantic/pull/10522) -* Add support for `propertyNames` in JSON schema by [@FlorianSW](https://github.com/FlorianSW) in [#10478](https://github.com/pydantic/pydantic/pull/10478) -* Adding `__replace__` protocol for Python 3.13+ support by [@sydney-runkle](https://github.com/sydney-runkle) in [#10596](https://github.com/pydantic/pydantic/pull/10596) 
-* Expose public `sort` method for JSON schema generation by [@sydney-runkle](https://github.com/sydney-runkle) in [#10595](https://github.com/pydantic/pydantic/pull/10595) -* Add runtime validation of `@validate_call` callable argument by [@kc0506](https://github.com/kc0506) in [#10627](https://github.com/pydantic/pydantic/pull/10627) -* Add `experimental_allow_partial` support by [@samuelcolvin](https://github.com/samuelcolvin) in [#10748](https://github.com/pydantic/pydantic/pull/10748) -* Support default factories taking validated data as an argument by [@Viicos](https://github.com/Viicos) in [#10678](https://github.com/pydantic/pydantic/pull/10678) -* Allow subclassing `ValidationError` and `PydanticCustomError` by [@Youssefares](https://github.com/Youssefares) in [pydantic/pydantic-core#1413](https://github.com/pydantic/pydantic-core/pull/1413) -* Add `trailing-strings` support to `experimental_allow_partial` by [@sydney-runkle](https://github.com/sydney-runkle) in [#10825](https://github.com/pydantic/pydantic/pull/10825) -* Add `rebuild()` method for `TypeAdapter` and simplify `defer_build` patterns by [@sydney-runkle](https://github.com/sydney-runkle) in [#10537](https://github.com/pydantic/pydantic/pull/10537) -* Improve `TypeAdapter` instance repr by [@sydney-runkle](https://github.com/sydney-runkle) in [#10872](https://github.com/pydantic/pydantic/pull/10872) - -#### Changes - -* Don't allow customization of `SchemaGenerator` until interface is more stable by [@sydney-runkle](https://github.com/sydney-runkle) in [#10303](https://github.com/pydantic/pydantic/pull/10303) -* Cleanly `defer_build` on `TypeAdapters`, removing experimental flag by [@sydney-runkle](https://github.com/sydney-runkle) in [#10329](https://github.com/pydantic/pydantic/pull/10329) -* Fix `mro` of generic subclass by [@kc0506](https://github.com/kc0506) in [#10100](https://github.com/pydantic/pydantic/pull/10100) -* Strip whitespaces on JSON Schema title generation by [@sydney-runkle](https://github.com/sydney-runkle) in [#10404](https://github.com/pydantic/pydantic/pull/10404) -* Use `b64decode` and `b64encode` for `Base64Bytes` type by [@sydney-runkle](https://github.com/sydney-runkle) in [#10486](https://github.com/pydantic/pydantic/pull/10486) -* Relax protected namespace config default by [@sydney-runkle](https://github.com/sydney-runkle) in [#10441](https://github.com/pydantic/pydantic/pull/10441) -* Revalidate parametrized generics if instance's origin is subclass of OG class by [@sydney-runkle](https://github.com/sydney-runkle) in [#10666](https://github.com/pydantic/pydantic/pull/10666) -* Warn if configuration is specified on the `@dataclass` decorator and with the `__pydantic_config__` attribute by [@sydney-runkle](https://github.com/sydney-runkle) in [#10406](https://github.com/pydantic/pydantic/pull/10406) -* Recommend against using `Ellipsis` (...) 
with `Field` by [@Viicos](https://github.com/Viicos) in [#10661](https://github.com/pydantic/pydantic/pull/10661) -* Migrate to subclassing instead of annotated approach for pydantic url types by [@sydney-runkle](https://github.com/sydney-runkle) in [#10662](https://github.com/pydantic/pydantic/pull/10662) -* Change JSON schema generation of `Literal`s and `Enums` by [@Viicos](https://github.com/Viicos) in [#10692](https://github.com/pydantic/pydantic/pull/10692) -* Simplify unions involving `Any` or `Never` when replacing type variables by [@Viicos](https://github.com/Viicos) in [#10338](https://github.com/pydantic/pydantic/pull/10338) -* Do not require padding when decoding `base64` bytes by [@bschoenmaeckers](https://github.com/bschoenmaeckers) in [pydantic/pydantic-core#1448](https://github.com/pydantic/pydantic-core/pull/1448) -* Support dates all the way to 1BC by [@changhc](https://github.com/changhc) in [pydantic/speedate#77](https://github.com/pydantic/speedate/pull/77) - -#### Performance - -* Schema cleaning: skip unnecessary copies during schema walking by [@Viicos](https://github.com/Viicos) in [#10286](https://github.com/pydantic/pydantic/pull/10286) -* Refactor namespace logic for annotations evaluation by [@Viicos](https://github.com/Viicos) in [#10530](https://github.com/pydantic/pydantic/pull/10530) -* Improve email regexp on edge cases by [@AlekseyLobanov](https://github.com/AlekseyLobanov) in [#10601](https://github.com/pydantic/pydantic/pull/10601) -* `CoreMetadata` refactor with an emphasis on documentation, schema build time performance, and reducing complexity by [@sydney-runkle](https://github.com/sydney-runkle) in [#10675](https://github.com/pydantic/pydantic/pull/10675) - -#### Fixes - -* Remove guarding check on `computed_field` with `field_serializer` by [@nix010](https://github.com/nix010) in [#10390](https://github.com/pydantic/pydantic/pull/10390) -* Fix `Predicate` issue in `v2.9.0` by [@sydney-runkle](https://github.com/sydney-runkle) in [#10321](https://github.com/pydantic/pydantic/pull/10321) -* Fixing `annotated-types` bound by [@sydney-runkle](https://github.com/sydney-runkle) in [#10327](https://github.com/pydantic/pydantic/pull/10327) -* Turn `tzdata` install requirement into optional `timezone` dependency by [@jakob-keller](https://github.com/jakob-keller) in [#10331](https://github.com/pydantic/pydantic/pull/10331) -* Use correct types namespace when building `namedtuple` core schemas by [@Viicos](https://github.com/Viicos) in [#10337](https://github.com/pydantic/pydantic/pull/10337) -* Fix evaluation of stringified annotations during namespace inspection by [@Viicos](https://github.com/Viicos) in [#10347](https://github.com/pydantic/pydantic/pull/10347) -* Fix `IncEx` type alias definition by [@Viicos](https://github.com/Viicos) in [#10339](https://github.com/pydantic/pydantic/pull/10339) -* Do not error when trying to evaluate annotations of private attributes by [@Viicos](https://github.com/Viicos) in [#10358](https://github.com/pydantic/pydantic/pull/10358) -* Fix nested type statement by [@kc0506](https://github.com/kc0506) in [#10369](https://github.com/pydantic/pydantic/pull/10369) -* Improve typing of `ModelMetaclass.mro` by [@Viicos](https://github.com/Viicos) in [#10372](https://github.com/pydantic/pydantic/pull/10372) -* Fix class access of deprecated `computed_field`s by [@Viicos](https://github.com/Viicos) in [#10391](https://github.com/pydantic/pydantic/pull/10391) -* Make sure `inspect.iscoroutinefunction` works on coroutines 
decorated with `@validate_call` by [@MovisLi](https://github.com/MovisLi) in [#10374](https://github.com/pydantic/pydantic/pull/10374) -* Fix `NameError` when using `validate_call` with PEP 695 on a class by [@kc0506](https://github.com/kc0506) in [#10380](https://github.com/pydantic/pydantic/pull/10380) -* Fix `ZoneInfo` with various invalid types by [@sydney-runkle](https://github.com/sydney-runkle) in [#10408](https://github.com/pydantic/pydantic/pull/10408) -* Fix `PydanticUserError` on empty `model_config` with annotations by [@cdwilson](https://github.com/cdwilson) in [#10412](https://github.com/pydantic/pydantic/pull/10412) -* Fix variance issue in `_IncEx` type alias, only allow `True` by [@Viicos](https://github.com/Viicos) in [#10414](https://github.com/pydantic/pydantic/pull/10414) -* Fix serialization schema generation when using `PlainValidator` by [@Viicos](https://github.com/Viicos) in [#10427](https://github.com/pydantic/pydantic/pull/10427) -* Fix schema generation error when serialization schema holds references by [@Viicos](https://github.com/Viicos) in [#10444](https://github.com/pydantic/pydantic/pull/10444) -* Inline references if possible when generating schema for `json_schema_input_type` by [@Viicos](https://github.com/Viicos) in [#10439](https://github.com/pydantic/pydantic/pull/10439) -* Fix recursive arguments in `Representation` by [@Viicos](https://github.com/Viicos) in [#10480](https://github.com/pydantic/pydantic/pull/10480) -* Fix representation for builtin function types by [@kschwab](https://github.com/kschwab) in [#10479](https://github.com/pydantic/pydantic/pull/10479) -* Add python validators for decimal constraints (`max_digits` and `decimal_places`) by [@sydney-runkle](https://github.com/sydney-runkle) in [#10506](https://github.com/pydantic/pydantic/pull/10506) -* Only fetch `__pydantic_core_schema__` from the current class during schema generation by [@Viicos](https://github.com/Viicos) in [#10518](https://github.com/pydantic/pydantic/pull/10518) -* Fix `stacklevel` on deprecation warnings for `BaseModel` by [@sydney-runkle](https://github.com/sydney-runkle) in [#10520](https://github.com/pydantic/pydantic/pull/10520) -* Fix warning `stacklevel` in `BaseModel.__init__` by [@Viicos](https://github.com/Viicos) in [#10526](https://github.com/pydantic/pydantic/pull/10526) -* Improve error handling for in-evaluable refs for discriminator application by [@sydney-runkle](https://github.com/sydney-runkle) in [#10440](https://github.com/pydantic/pydantic/pull/10440) -* Change the signature of `ConfigWrapper.core_config` to take the title directly by [@Viicos](https://github.com/Viicos) in [#10562](https://github.com/pydantic/pydantic/pull/10562) -* Do not use the previous config from the stack for dataclasses without config by [@Viicos](https://github.com/Viicos) in [#10576](https://github.com/pydantic/pydantic/pull/10576) -* Fix serialization for IP types with `mode='python'` by [@sydney-runkle](https://github.com/sydney-runkle) in [#10594](https://github.com/pydantic/pydantic/pull/10594) -* Support constraint application for `Base64Etc` types by [@sydney-runkle](https://github.com/sydney-runkle) in [#10584](https://github.com/pydantic/pydantic/pull/10584) -* Fix `validate_call` ignoring `Field` in `Annotated` by [@kc0506](https://github.com/kc0506) in [#10610](https://github.com/pydantic/pydantic/pull/10610) -* Raise an error when `Self` is invalid by [@kc0506](https://github.com/kc0506) in [#10609](https://github.com/pydantic/pydantic/pull/10609) -* 
Using `core_schema.InvalidSchema` instead of metadata injection + checks by [@sydney-runkle](https://github.com/sydney-runkle) in [#10523](https://github.com/pydantic/pydantic/pull/10523) -* Tweak type alias logic by [@kc0506](https://github.com/kc0506) in [#10643](https://github.com/pydantic/pydantic/pull/10643) -* Support usage of `type` with `typing.Self` and type aliases by [@kc0506](https://github.com/kc0506) in [#10621](https://github.com/pydantic/pydantic/pull/10621) -* Use overloads for `Field` and `PrivateAttr` functions by [@Viicos](https://github.com/Viicos) in [#10651](https://github.com/pydantic/pydantic/pull/10651) -* Clean up the `mypy` plugin implementation by [@Viicos](https://github.com/Viicos) in [#10669](https://github.com/pydantic/pydantic/pull/10669) -* Properly check for `typing_extensions` variant of `TypeAliasType` by [@Daraan](https://github.com/Daraan) in [#10713](https://github.com/pydantic/pydantic/pull/10713) -* Allow any mapping in `BaseModel.model_copy()` by [@Viicos](https://github.com/Viicos) in [#10751](https://github.com/pydantic/pydantic/pull/10751) -* Fix `isinstance` behavior for urls by [@sydney-runkle](https://github.com/sydney-runkle) in [#10766](https://github.com/pydantic/pydantic/pull/10766) -* Ensure `cached_property` can be set on Pydantic models by [@Viicos](https://github.com/Viicos) in [#10774](https://github.com/pydantic/pydantic/pull/10774) -* Fix equality checks for primitives in literals by [@sydney-runkle](https://github.com/sydney-runkle) in [pydantic/pydantic-core#1459](https://github.com/pydantic/pydantic-core/pull/1459) -* Properly enforce `host_required` for URLs by [@Viicos](https://github.com/Viicos) in [pydantic/pydantic-core#1488](https://github.com/pydantic/pydantic-core/pull/1488) -* Fix when `coerce_numbers_to_str` enabled and string has invalid Unicode character by [@andrey-berenda](https://github.com/andrey-berenda) in [pydantic/pydantic-core#1515](https://github.com/pydantic/pydantic-core/pull/1515) -* Fix serializing `complex` values in `Enum`s by [@changhc](https://github.com/changhc) in [pydantic/pydantic-core#1524](https://github.com/pydantic/pydantic-core/pull/1524) -* Refactor `_typing_extra` module by [@Viicos](https://github.com/Viicos) in [#10725](https://github.com/pydantic/pydantic/pull/10725) -* Support intuitive equality for urls by [@sydney-runkle](https://github.com/sydney-runkle) in [#10798](https://github.com/pydantic/pydantic/pull/10798) -* Add `bytearray` to `TypeAdapter.validate_json` signature by [@samuelcolvin](https://github.com/samuelcolvin) in [#10802](https://github.com/pydantic/pydantic/pull/10802) -* Ensure class access of method descriptors is performed when used as a default with `Field` by [@Viicos](https://github.com/Viicos) in [#10816](https://github.com/pydantic/pydantic/pull/10816) -* Fix circular import with `validate_call` by [@sydney-runkle](https://github.com/sydney-runkle) in [#10807](https://github.com/pydantic/pydantic/pull/10807) -* Fix error when using type aliases referencing other type aliases by [@Viicos](https://github.com/Viicos) in [#10809](https://github.com/pydantic/pydantic/pull/10809) -* Fix `IncEx` type alias to be compatible with mypy by [@Viicos](https://github.com/Viicos) in [#10813](https://github.com/pydantic/pydantic/pull/10813) -* Make `__signature__` a lazy property, do not deepcopy defaults by [@Viicos](https://github.com/Viicos) in [#10818](https://github.com/pydantic/pydantic/pull/10818) -* Make `__signature__` lazy for dataclasses, too by 
[@sydney-runkle](https://github.com/sydney-runkle) in [#10832](https://github.com/pydantic/pydantic/pull/10832) -* Subclass all single host url classes from `AnyUrl` to preserve behavior from v2.9 by [@sydney-runkle](https://github.com/sydney-runkle) in [#10856](https://github.com/pydantic/pydantic/pull/10856) - -### New Contributors - -* [@jakob-keller](https://github.com/jakob-keller) made their first contribution in [#10331](https://github.com/pydantic/pydantic/pull/10331) -* [@MovisLi](https://github.com/MovisLi) made their first contribution in [#10374](https://github.com/pydantic/pydantic/pull/10374) -* [@joaopalmeiro](https://github.com/joaopalmeiro) made their first contribution in [#10405](https://github.com/pydantic/pydantic/pull/10405) -* [@theunkn0wn1](https://github.com/theunkn0wn1) made their first contribution in [#10378](https://github.com/pydantic/pydantic/pull/10378) -* [@cdwilson](https://github.com/cdwilson) made their first contribution in [#10412](https://github.com/pydantic/pydantic/pull/10412) -* [@dlax](https://github.com/dlax) made their first contribution in [#10421](https://github.com/pydantic/pydantic/pull/10421) -* [@kschwab](https://github.com/kschwab) made their first contribution in [#10479](https://github.com/pydantic/pydantic/pull/10479) -* [@santibreo](https://github.com/santibreo) made their first contribution in [#10453](https://github.com/pydantic/pydantic/pull/10453) -* [@FlorianSW](https://github.com/FlorianSW) made their first contribution in [#10478](https://github.com/pydantic/pydantic/pull/10478) -* [@tkasuz](https://github.com/tkasuz) made their first contribution in [#10555](https://github.com/pydantic/pydantic/pull/10555) -* [@AlekseyLobanov](https://github.com/AlekseyLobanov) made their first contribution in [#10601](https://github.com/pydantic/pydantic/pull/10601) -* [@NiclasvanEyk](https://github.com/NiclasvanEyk) made their first contribution in [#10667](https://github.com/pydantic/pydantic/pull/10667) -* [@mschoettle](https://github.com/mschoettle) made their first contribution in [#10677](https://github.com/pydantic/pydantic/pull/10677) -* [@Daraan](https://github.com/Daraan) made their first contribution in [#10713](https://github.com/pydantic/pydantic/pull/10713) -* [@k4nar](https://github.com/k4nar) made their first contribution in [#10736](https://github.com/pydantic/pydantic/pull/10736) -* [@UriyaHarpeness](https://github.com/UriyaHarpeness) made their first contribution in [#10740](https://github.com/pydantic/pydantic/pull/10740) -* [@frfahim](https://github.com/frfahim) made their first contribution in [#10727](https://github.com/pydantic/pydantic/pull/10727) - -## v2.10.0b2 (2024-11-13) - -Pre-release, see [the GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.10.0b2) for details. - -## v2.10.0b1 (2024-11-06) - -Pre-release, see [the GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.10.0b1) for details. - - -... see [here](https://docs.pydantic.dev/changelog/#v0322-2019-08-17) for earlier changes. 
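One v2.10 entry above, "Support default factories taking validated data as an argument", is easy to miss; the hedged sketch below shows the one-argument factory form, with an invented `Invoice` model and field names.

```python
from pydantic import BaseModel, Field


class Invoice(BaseModel):
    quantity: int
    unit_price_cents: int
    # A one-argument default_factory receives the data validated so far
    # (fields are validated in declaration order).
    total_cents: int = Field(
        default_factory=lambda data: data["quantity"] * data["unit_price_cents"]
    )


print(Invoice(quantity=3, unit_price_cents=250).total_cents)  # 750 (assuming v2.10+)
```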
diff --git a/venv/lib/python3.12/site-packages/pydantic-2.11.9.dist-info/RECORD b/venv/lib/python3.12/site-packages/pydantic-2.11.9.dist-info/RECORD deleted file mode 100644 index 4d97aa7..0000000 --- a/venv/lib/python3.12/site-packages/pydantic-2.11.9.dist-info/RECORD +++ /dev/null @@ -1,216 +0,0 @@ -pydantic-2.11.9.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -pydantic-2.11.9.dist-info/METADATA,sha256=lL7X6XeRbjebP34efO4-PGm-7pd2SSu2OoR9ZM6toso,68441 -pydantic-2.11.9.dist-info/RECORD,, -pydantic-2.11.9.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -pydantic-2.11.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87 -pydantic-2.11.9.dist-info/licenses/LICENSE,sha256=qeGG88oWte74QxjnpwFyE1GgDLe4rjpDlLZ7SeNSnvM,1129 -pydantic/__init__.py,sha256=D3_-0aRPoAF5EH4T4JPVOYLNEc-DeaCcDt6UzIjP_D0,15395 -pydantic/__pycache__/__init__.cpython-312.pyc,, -pydantic/__pycache__/_migration.cpython-312.pyc,, -pydantic/__pycache__/alias_generators.cpython-312.pyc,, -pydantic/__pycache__/aliases.cpython-312.pyc,, -pydantic/__pycache__/annotated_handlers.cpython-312.pyc,, -pydantic/__pycache__/class_validators.cpython-312.pyc,, -pydantic/__pycache__/color.cpython-312.pyc,, -pydantic/__pycache__/config.cpython-312.pyc,, -pydantic/__pycache__/dataclasses.cpython-312.pyc,, -pydantic/__pycache__/datetime_parse.cpython-312.pyc,, -pydantic/__pycache__/decorator.cpython-312.pyc,, -pydantic/__pycache__/env_settings.cpython-312.pyc,, -pydantic/__pycache__/error_wrappers.cpython-312.pyc,, -pydantic/__pycache__/errors.cpython-312.pyc,, -pydantic/__pycache__/fields.cpython-312.pyc,, -pydantic/__pycache__/functional_serializers.cpython-312.pyc,, -pydantic/__pycache__/functional_validators.cpython-312.pyc,, -pydantic/__pycache__/generics.cpython-312.pyc,, -pydantic/__pycache__/json.cpython-312.pyc,, -pydantic/__pycache__/json_schema.cpython-312.pyc,, -pydantic/__pycache__/main.cpython-312.pyc,, -pydantic/__pycache__/mypy.cpython-312.pyc,, -pydantic/__pycache__/networks.cpython-312.pyc,, -pydantic/__pycache__/parse.cpython-312.pyc,, -pydantic/__pycache__/root_model.cpython-312.pyc,, -pydantic/__pycache__/schema.cpython-312.pyc,, -pydantic/__pycache__/tools.cpython-312.pyc,, -pydantic/__pycache__/type_adapter.cpython-312.pyc,, -pydantic/__pycache__/types.cpython-312.pyc,, -pydantic/__pycache__/typing.cpython-312.pyc,, -pydantic/__pycache__/utils.cpython-312.pyc,, -pydantic/__pycache__/validate_call_decorator.cpython-312.pyc,, -pydantic/__pycache__/validators.cpython-312.pyc,, -pydantic/__pycache__/version.cpython-312.pyc,, -pydantic/__pycache__/warnings.cpython-312.pyc,, -pydantic/_internal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -pydantic/_internal/__pycache__/__init__.cpython-312.pyc,, -pydantic/_internal/__pycache__/_config.cpython-312.pyc,, -pydantic/_internal/__pycache__/_core_metadata.cpython-312.pyc,, -pydantic/_internal/__pycache__/_core_utils.cpython-312.pyc,, -pydantic/_internal/__pycache__/_dataclasses.cpython-312.pyc,, -pydantic/_internal/__pycache__/_decorators.cpython-312.pyc,, -pydantic/_internal/__pycache__/_decorators_v1.cpython-312.pyc,, -pydantic/_internal/__pycache__/_discriminated_union.cpython-312.pyc,, -pydantic/_internal/__pycache__/_docs_extraction.cpython-312.pyc,, -pydantic/_internal/__pycache__/_fields.cpython-312.pyc,, -pydantic/_internal/__pycache__/_forward_ref.cpython-312.pyc,, -pydantic/_internal/__pycache__/_generate_schema.cpython-312.pyc,, 
-pydantic/_internal/__pycache__/_generics.cpython-312.pyc,, -pydantic/_internal/__pycache__/_git.cpython-312.pyc,, -pydantic/_internal/__pycache__/_import_utils.cpython-312.pyc,, -pydantic/_internal/__pycache__/_internal_dataclass.cpython-312.pyc,, -pydantic/_internal/__pycache__/_known_annotated_metadata.cpython-312.pyc,, -pydantic/_internal/__pycache__/_mock_val_ser.cpython-312.pyc,, -pydantic/_internal/__pycache__/_model_construction.cpython-312.pyc,, -pydantic/_internal/__pycache__/_namespace_utils.cpython-312.pyc,, -pydantic/_internal/__pycache__/_repr.cpython-312.pyc,, -pydantic/_internal/__pycache__/_schema_gather.cpython-312.pyc,, -pydantic/_internal/__pycache__/_schema_generation_shared.cpython-312.pyc,, -pydantic/_internal/__pycache__/_serializers.cpython-312.pyc,, -pydantic/_internal/__pycache__/_signature.cpython-312.pyc,, -pydantic/_internal/__pycache__/_typing_extra.cpython-312.pyc,, -pydantic/_internal/__pycache__/_utils.cpython-312.pyc,, -pydantic/_internal/__pycache__/_validate_call.cpython-312.pyc,, -pydantic/_internal/__pycache__/_validators.cpython-312.pyc,, -pydantic/_internal/_config.py,sha256=WV07hp8xf0Q0yP9IwMvuGLQmu34AZl5sBs2JaOgCk9I,14253 -pydantic/_internal/_core_metadata.py,sha256=Y_g2t3i7uluK-wXCZvzJfRFMPUM23aBYLfae4FzBPy0,5162 -pydantic/_internal/_core_utils.py,sha256=_-ZuXhpi_0JDpZzz8jvGr82kgS3PEritWR22fjWpw48,6746 -pydantic/_internal/_dataclasses.py,sha256=GA-NO1cgYbce0UwZP-sfPe5AujHjhvgTKbPCyg9GGP8,8990 -pydantic/_internal/_decorators.py,sha256=NS7SKQvtDgnsAd37mjqtwPh19td57FJ69LsceO5SywI,32638 -pydantic/_internal/_decorators_v1.py,sha256=tfdfdpQKY4R2XCOwqHbZeoQMur6VNigRrfhudXBHx38,6185 -pydantic/_internal/_discriminated_union.py,sha256=aMl0SRSyQyHfW4-klnMTHNvwSRoqE3H3PRV_05vRsTg,25478 -pydantic/_internal/_docs_extraction.py,sha256=p-STFvLHUzxrj6bblpaAAYWmq4INxVCAdIupDgQYSIw,3831 -pydantic/_internal/_fields.py,sha256=tFmaX47Q2z8QCCPJ4K8MrPfgKDztx9clntzPxBv0OKo,23205 -pydantic/_internal/_forward_ref.py,sha256=5n3Y7-3AKLn8_FS3Yc7KutLiPUhyXmAtkEZOaFnonwM,611 -pydantic/_internal/_generate_schema.py,sha256=LWJsmvNdWDh1QxY4WelsFSw1_nScPwEfJdpwMZH5V4k,133821 -pydantic/_internal/_generics.py,sha256=D1_0xgqnL6TJQe_fFyaSk2Ug_F-kT_jRBfLjHFLCIqQ,23849 -pydantic/_internal/_git.py,sha256=IwPh3DPfa2Xq3rBuB9Nx8luR2A1i69QdeTfWWXIuCVg,809 -pydantic/_internal/_import_utils.py,sha256=TRhxD5OuY6CUosioBdBcJUs0om7IIONiZdYAV7zQ8jM,402 -pydantic/_internal/_internal_dataclass.py,sha256=_bedc1XbuuygRGiLZqkUkwwFpQaoR1hKLlR501nyySY,144 -pydantic/_internal/_known_annotated_metadata.py,sha256=lYAPiUhfSgfpY6qH9xJPJTEMoowv27QmcyOgQzys90U,16213 -pydantic/_internal/_mock_val_ser.py,sha256=wmRRFSBvqfcLbI41PsFliB4u2AZ3mJpZeiERbD3xKTo,8885 -pydantic/_internal/_model_construction.py,sha256=2Qa5Y4EgBojkhsVHu0OjpphUIlWYuVXMg1KC2opc00s,35228 -pydantic/_internal/_namespace_utils.py,sha256=CMG7nEAXVb-Idqyd3CgdulRrM-zEXOPe3kYEDBqnSKw,12878 -pydantic/_internal/_repr.py,sha256=t7GNyaUU8xvqwlDHxVE2IyDeaNZrK7p01ojQPP0UI_o,5081 -pydantic/_internal/_schema_gather.py,sha256=VLEv51TYEeeND2czsyrmJq1MVnJqTOmnLan7VG44c8A,9114 -pydantic/_internal/_schema_generation_shared.py,sha256=F_rbQbrkoomgxsskdHpP0jUJ7TCfe0BADAEkq6CJ4nM,4842 -pydantic/_internal/_serializers.py,sha256=qQ3Rak4J6bqbnjGCRjiAY4M8poLo0s5qH46sXZSQQuA,1474 -pydantic/_internal/_signature.py,sha256=8EljPJe4pSnapuirG5DkBAgD1hggHxEAyzFPH-9H0zE,6779 -pydantic/_internal/_typing_extra.py,sha256=PO3u2JmX3JKlTFy0Ew95iyjAgYHgJsqqskev4zooB2I,28216 -pydantic/_internal/_utils.py,sha256=iRmCSO0uoFhAL_ChHaYSCKrswpSrRHYoO_YQSFfCJxU,15344 
-pydantic/_internal/_validate_call.py,sha256=PfdVnSzhXOrENtaDoDw3PFWPVYD5W_gNYPe8p3Ug6Lg,5321 -pydantic/_internal/_validators.py,sha256=TJcR9bxcPXjzntN6Qgib8cyPRkFZQxHW32SoKGEcp0k,20610 -pydantic/_migration.py,sha256=_6VCCVWNYB7fDpbP2MqW4bXXqo17C5_J907u9zNJQbM,11907 -pydantic/alias_generators.py,sha256=KM1n3u4JfLSBl1UuYg3hoYHzXJD-yvgrnq8u1ccwh_A,2124 -pydantic/aliases.py,sha256=vhCHyoSWnX-EJ-wWb5qj4xyRssgGWnTQfzQp4GSZ9ug,4937 -pydantic/annotated_handlers.py,sha256=WfyFSqwoEIFXBh7T73PycKloI1DiX45GWi0-JOsCR4Y,4407 -pydantic/class_validators.py,sha256=i_V3j-PYdGLSLmj_IJZekTRjunO8SIVz8LMlquPyP7E,148 -pydantic/color.py,sha256=AzqGfVQHF92_ZctDcue0DM4yTp2P6tekkwRINTWrLIo,21481 -pydantic/config.py,sha256=roz_FbfFPoVpJVpB1G7dJ8A3swghQjdN-ozrBxbLShM,42048 -pydantic/dataclasses.py,sha256=K2e76b_Cj1yvwcwfJVR7nQnLoPdetVig5yHVMGuzkpE,16644 -pydantic/datetime_parse.py,sha256=QC-WgMxMr_wQ_mNXUS7AVf-2hLEhvvsPY1PQyhSGOdk,150 -pydantic/decorator.py,sha256=YX-jUApu5AKaVWKPoaV-n-4l7UbS69GEt9Ra3hszmKI,145 -pydantic/deprecated/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -pydantic/deprecated/__pycache__/__init__.cpython-312.pyc,, -pydantic/deprecated/__pycache__/class_validators.cpython-312.pyc,, -pydantic/deprecated/__pycache__/config.cpython-312.pyc,, -pydantic/deprecated/__pycache__/copy_internals.cpython-312.pyc,, -pydantic/deprecated/__pycache__/decorator.cpython-312.pyc,, -pydantic/deprecated/__pycache__/json.cpython-312.pyc,, -pydantic/deprecated/__pycache__/parse.cpython-312.pyc,, -pydantic/deprecated/__pycache__/tools.cpython-312.pyc,, -pydantic/deprecated/class_validators.py,sha256=rwfP165xity36foy1NNCg4Jf9Sul44sJLW-A5sseahI,10245 -pydantic/deprecated/config.py,sha256=k_lsVk57paxLJOcBueH07cu1OgEgWdVBxm6lfaC3CCU,2663 -pydantic/deprecated/copy_internals.py,sha256=Ku0LHLEU0WcoIInNHls7PjuBvpLFTQ4Uus77jQ3Yi08,7616 -pydantic/deprecated/decorator.py,sha256=TBm6bJ7wJsNih_8Wq5IzDcwP32m9_vfxs96desLuk00,10845 -pydantic/deprecated/json.py,sha256=HlWCG35RRrxyzuTS6LTQiZBwRhmDZWmeqQH8rLW6wA8,4657 -pydantic/deprecated/parse.py,sha256=Gzd6b_g8zJXcuE7QRq5adhx_EMJahXfcpXCF0RgrqqI,2511 -pydantic/deprecated/tools.py,sha256=Nrm9oFRZWp8-jlfvPgJILEsywp4YzZD52XIGPDLxHcI,3330 -pydantic/env_settings.py,sha256=6IHeeWEqlUPRUv3V-AXiF_W91fg2Jw_M3O0l34J_eyA,148 -pydantic/error_wrappers.py,sha256=RK6mqATc9yMD-KBD9IJS9HpKCprWHd8wo84Bnm-3fR8,150 -pydantic/errors.py,sha256=7ctBNCtt57kZFx71Ls2H86IufQARv4wPKf8DhdsVn5w,6002 -pydantic/experimental/__init__.py,sha256=j08eROfz-xW4k_X9W4m2AW26IVdyF3Eg1OzlIGA11vk,328 -pydantic/experimental/__pycache__/__init__.cpython-312.pyc,, -pydantic/experimental/__pycache__/arguments_schema.cpython-312.pyc,, -pydantic/experimental/__pycache__/pipeline.cpython-312.pyc,, -pydantic/experimental/arguments_schema.py,sha256=EFnjX_ulp-tPyUjQX5pmQtug1OFL_Acc8bcMbLd-fVY,1866 -pydantic/experimental/pipeline.py,sha256=znbMBvir3xvPA20Xj8Moco1oJMPf1VYVrIQ8KQNtDlM,23910 -pydantic/fields.py,sha256=9Ky1nTKaMhThaNkVEkJOFHQHGq2FCKSwA6-zwUB-KWo,64416 -pydantic/functional_serializers.py,sha256=3m81unH3lYovdMi00oZywlHhn1KDz9X2CO3iTtBya6A,17102 -pydantic/functional_validators.py,sha256=-yY6uj_9_GAI4aqqfZlzyGdzs06huzy6zNWD7TJp3_0,29560 -pydantic/generics.py,sha256=0ZqZ9O9annIj_3mGBRqps4htey3b5lV1-d2tUxPMMnA,144 -pydantic/json.py,sha256=ZH8RkI7h4Bz-zp8OdTAxbJUoVvcoU-jhMdRZ0B-k0xc,140 -pydantic/json_schema.py,sha256=KhsS_MWPox0PYqklnhJcb_3uiCVrEOgyhG53cUZv6QA,115430 -pydantic/main.py,sha256=v67a4-nFooC-GJ1oHgS__Vm399Ygp_NH-1WzHXwjFM0,81012 
-pydantic/mypy.py,sha256=OG7AqM_6vuTxRnBPU27eUkfX5wShU6aD0dJGmMhLaN8,59265 -pydantic/networks.py,sha256=_YpSnBR2kMfoWX76sdq34cfCH-MWr5or0ve0tow7OWo,41446 -pydantic/parse.py,sha256=wkd82dgtvWtD895U_I6E1htqMlGhBSYEV39cuBSeo3A,141 -pydantic/plugin/__init__.py,sha256=5cXMmu5xL4LVZhWPE1XD8ozHZ-qEC2-s4seLe8tbN_Y,6965 -pydantic/plugin/__pycache__/__init__.cpython-312.pyc,, -pydantic/plugin/__pycache__/_loader.cpython-312.pyc,, -pydantic/plugin/__pycache__/_schema_validator.cpython-312.pyc,, -pydantic/plugin/_loader.py,sha256=nI3SEKr0mlCB556kvbyBXjYQw9b_s8UTKE9Q6iESX6s,2167 -pydantic/plugin/_schema_validator.py,sha256=QbmqsG33MBmftNQ2nNiuN22LhbrexUA7ipDVv3J02BU,5267 -pydantic/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -pydantic/root_model.py,sha256=SCXhpRCgZgfqE9AGVJTC7kMAojKffL7PV4i0qcwOMm0,6279 -pydantic/schema.py,sha256=Vqqjvq_LnapVknebUd3Bp_J1p2gXZZnZRgL48bVEG7o,142 -pydantic/tools.py,sha256=iHQpd8SJ5DCTtPV5atAV06T89bjSaMFeZZ2LX9lasZY,141 -pydantic/type_adapter.py,sha256=Y3NE0YhFwxwoqrYU9caWymLWp1Avq4sRUdb5s01RoJk,31171 -pydantic/types.py,sha256=mWTvQH_Wt_CccQcEHYjcUWpyoj1U04WOnrMsMYod_64,104781 -pydantic/typing.py,sha256=P7feA35MwTcLsR1uL7db0S-oydBxobmXa55YDoBgajQ,138 -pydantic/utils.py,sha256=15nR2QpqTBFlQV4TNtTItMyTJx_fbyV-gPmIEY1Gooc,141 -pydantic/v1/__init__.py,sha256=SxQPklgBs4XHJwE6BZ9qoewYoGiNyYUnmHzEFCZbfnI,2946 -pydantic/v1/__pycache__/__init__.cpython-312.pyc,, -pydantic/v1/__pycache__/_hypothesis_plugin.cpython-312.pyc,, -pydantic/v1/__pycache__/annotated_types.cpython-312.pyc,, -pydantic/v1/__pycache__/class_validators.cpython-312.pyc,, -pydantic/v1/__pycache__/color.cpython-312.pyc,, -pydantic/v1/__pycache__/config.cpython-312.pyc,, -pydantic/v1/__pycache__/dataclasses.cpython-312.pyc,, -pydantic/v1/__pycache__/datetime_parse.cpython-312.pyc,, -pydantic/v1/__pycache__/decorator.cpython-312.pyc,, -pydantic/v1/__pycache__/env_settings.cpython-312.pyc,, -pydantic/v1/__pycache__/error_wrappers.cpython-312.pyc,, -pydantic/v1/__pycache__/errors.cpython-312.pyc,, -pydantic/v1/__pycache__/fields.cpython-312.pyc,, -pydantic/v1/__pycache__/generics.cpython-312.pyc,, -pydantic/v1/__pycache__/json.cpython-312.pyc,, -pydantic/v1/__pycache__/main.cpython-312.pyc,, -pydantic/v1/__pycache__/mypy.cpython-312.pyc,, -pydantic/v1/__pycache__/networks.cpython-312.pyc,, -pydantic/v1/__pycache__/parse.cpython-312.pyc,, -pydantic/v1/__pycache__/schema.cpython-312.pyc,, -pydantic/v1/__pycache__/tools.cpython-312.pyc,, -pydantic/v1/__pycache__/types.cpython-312.pyc,, -pydantic/v1/__pycache__/typing.cpython-312.pyc,, -pydantic/v1/__pycache__/utils.cpython-312.pyc,, -pydantic/v1/__pycache__/validators.cpython-312.pyc,, -pydantic/v1/__pycache__/version.cpython-312.pyc,, -pydantic/v1/_hypothesis_plugin.py,sha256=5ES5xWuw1FQAsymLezy8QgnVz0ZpVfU3jkmT74H27VQ,14847 -pydantic/v1/annotated_types.py,sha256=uk2NAAxqiNELKjiHhyhxKaIOh8F1lYW_LzrW3X7oZBc,3157 -pydantic/v1/class_validators.py,sha256=ULOaIUgYUDBsHL7EEVEarcM-UubKUggoN8hSbDonsFE,14672 -pydantic/v1/color.py,sha256=iZABLYp6OVoo2AFkP9Ipri_wSc6-Kklu8YuhSartd5g,16844 -pydantic/v1/config.py,sha256=a6P0Wer9x4cbwKW7Xv8poSUqM4WP-RLWwX6YMpYq9AA,6532 -pydantic/v1/dataclasses.py,sha256=784cqvInbwIPWr9usfpX3ch7z4t3J2tTK6N067_wk1o,18172 -pydantic/v1/datetime_parse.py,sha256=4Qy1kQpq3rNVZJeIHeSPDpuS2Bvhp1KPtzJG1xu-H00,7724 -pydantic/v1/decorator.py,sha256=zaaxxxoWPCm818D1bs0yhapRjXm32V8G0ZHWCdM1uXA,10339 -pydantic/v1/env_settings.py,sha256=A9VXwtRl02AY-jH0C0ouy5VNw3fi6F_pkzuHDjgAAOM,14105 
-pydantic/v1/error_wrappers.py,sha256=6625Mfw9qkC2NwitB_JFAWe8B-Xv6zBU7rL9k28tfyo,5196 -pydantic/v1/errors.py,sha256=mIwPED5vGM5Q5v4C4Z1JPldTRH-omvEylH6ksMhOmPw,17726 -pydantic/v1/fields.py,sha256=VqWJCriUNiEyptXroDVJ501JpVA0en2VANcksqXL2b8,50649 -pydantic/v1/generics.py,sha256=VzC9YUV-EbPpQ3aAfk1cNFej79_IzznkQ7WrmTTZS9E,17871 -pydantic/v1/json.py,sha256=WQ5Hy_hIpfdR3YS8k6N2E6KMJzsdbBi_ldWOPJaV81M,3390 -pydantic/v1/main.py,sha256=zuNpdN5Q0V0wG2UUTKt0HUy3XJ4OAvPSZDdiXY-FIzs,44824 -pydantic/v1/mypy.py,sha256=TsnGYsg0zR2CtzIGVHgAsBDz6VndFTDnGXG6cLoDNkY,38949 -pydantic/v1/networks.py,sha256=HYNtKAfOmOnKJpsDg1g6SIkj9WPhU_-i8l5e2JKBpG4,22124 -pydantic/v1/parse.py,sha256=BJtdqiZRtav9VRFCmOxoY-KImQmjPy-A_NoojiFUZxY,1821 -pydantic/v1/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -pydantic/v1/schema.py,sha256=aqBuA--cq8gAVkim5BJPFASHzOZ8dFtmFX_fNGr6ip4,47801 -pydantic/v1/tools.py,sha256=1lDdXHk0jL5uP3u5RCYAvUAlGClgAO-45lkq9j7fyBA,2881 -pydantic/v1/types.py,sha256=Fltx5GoP_qaUmAktlGz7nFeJa13yNy3FY1-RcMzEVt8,35455 -pydantic/v1/typing.py,sha256=HNtuKvgH4EHIeb2ytkd7VSyG6mxP9RKqEqEql-1ab14,19720 -pydantic/v1/utils.py,sha256=M5FRyfNUb1A2mk9laGgCVdfHHb3AtQgrjO5qfyBf4xA,25989 -pydantic/v1/validators.py,sha256=lyUkn1MWhHxlCX5ZfEgFj_CAHojoiPcaQeMdEM9XviU,22187 -pydantic/v1/version.py,sha256=HXnXW-1bMW5qKhlr5RgOEPohrZDCDSuyy8-gi8GCgZo,1039 -pydantic/validate_call_decorator.py,sha256=8jqLlgXTjWEj4dXDg0wI3EGQKkb0JnCsL_JSUjbU5Sg,4389 -pydantic/validators.py,sha256=pwbIJXVb1CV2mAE4w_EGfNj7DwzsKaWw_tTL6cviTus,146 -pydantic/version.py,sha256=BxN15sODEuCNfCzfyH-02nNIqpuYxgBeMVz85m2urTg,2710 -pydantic/warnings.py,sha256=gqDTQ2FX7wGLZJV3XboQSiRXKHknss3pfIOXL0BDXTk,3772 diff --git a/venv/lib/python3.12/site-packages/pydantic_core-2.33.2.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/pydantic-2.4.2.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/pydantic_core-2.33.2.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/pydantic-2.4.2.dist-info/INSTALLER diff --git a/venv/lib/python3.12/site-packages/pydantic-2.4.2.dist-info/METADATA b/venv/lib/python3.12/site-packages/pydantic-2.4.2.dist-info/METADATA new file mode 100644 index 0000000..3281200 --- /dev/null +++ b/venv/lib/python3.12/site-packages/pydantic-2.4.2.dist-info/METADATA @@ -0,0 +1,1337 @@ +Metadata-Version: 2.1 +Name: pydantic +Version: 2.4.2 +Summary: Data validation using Python type hints +Project-URL: Homepage, https://github.com/pydantic/pydantic +Project-URL: Documentation, https://docs.pydantic.dev +Project-URL: Funding, https://github.com/sponsors/samuelcolvin +Project-URL: Source, https://github.com/pydantic/pydantic +Project-URL: Changelog, https://docs.pydantic.dev/latest/changelog/ +Author-email: Samuel Colvin , Eric Jolibois , Hasan Ramezani , Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>, Terrence Dorsey , David Montague +License-Expression: MIT +License-File: LICENSE +Classifier: Development Status :: 5 - Production/Stable +Classifier: Environment :: Console +Classifier: Environment :: MacOS X +Classifier: Framework :: Hypothesis +Classifier: Framework :: Pydantic +Classifier: Intended Audience :: Developers +Classifier: Intended Audience :: Information Technology +Classifier: Intended Audience :: System Administrators +Classifier: License :: OSI Approved :: MIT License +Classifier: Operating System :: POSIX :: Linux +Classifier: Operating System :: Unix +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 
+Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: Implementation :: CPython +Classifier: Programming Language :: Python :: Implementation :: PyPy +Classifier: Topic :: Internet +Classifier: Topic :: Software Development :: Libraries :: Python Modules +Requires-Python: >=3.7 +Requires-Dist: annotated-types>=0.4.0 +Requires-Dist: pydantic-core==2.10.1 +Requires-Dist: typing-extensions>=4.6.1 +Provides-Extra: email +Requires-Dist: email-validator>=2.0.0; extra == 'email' +Description-Content-Type: text/markdown + +# Pydantic + +[![CI](https://github.com/pydantic/pydantic/workflows/CI/badge.svg?event=push)](https://github.com/pydantic/pydantic/actions?query=event%3Apush+branch%3Amain+workflow%3ACI) +[![Coverage](https://coverage-badge.samuelcolvin.workers.dev/pydantic/pydantic.svg)](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic) +[![pypi](https://img.shields.io/pypi/v/pydantic.svg)](https://pypi.python.org/pypi/pydantic) +[![CondaForge](https://img.shields.io/conda/v/conda-forge/pydantic.svg)](https://anaconda.org/conda-forge/pydantic) +[![downloads](https://static.pepy.tech/badge/pydantic/month)](https://pepy.tech/project/pydantic) +[![versions](https://img.shields.io/pypi/pyversions/pydantic.svg)](https://github.com/pydantic/pydantic) +[![license](https://img.shields.io/github/license/pydantic/pydantic.svg)](https://github.com/pydantic/pydantic/blob/main/LICENSE) +[![Pydantic v2](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/pydantic/pydantic/main/docs/badge/v2.json)](https://docs.pydantic.dev/latest/contributing/#badges) + +Data validation using Python type hints. + +Fast and extensible, Pydantic plays nicely with your linters/IDE/brain. +Define how data should be in pure, canonical Python 3.7+; validate it with Pydantic. + +## Pydantic Company :rocket: + +We've started a company based on the principles that I believe have led to Pydantic's success. +Learning more from the [Company Announcement](https://pydantic.dev/announcement/). + +## Pydantic V1.10 vs. V2 + +Pydantic V2 is a ground-up rewrite that offers many new features, performance improvements, and some breaking changes compared to Pydantic V1. + +If you're using Pydantic V1 you may want to look at the +[pydantic V1.10 Documentation](https://docs.pydantic.dev/) or, +[`1.10.X-fixes` git branch](https://github.com/pydantic/pydantic/tree/1.10.X-fixes). Pydantic V2 also ships with the latest version of Pydantic V1 built in so that you can incrementally upgrade your code base and projects: `from pydantic import v1 as pydantic_v1`. + +## Help + +See [documentation](https://docs.pydantic.dev/) for more details. + +## Installation + +Install using `pip install -U pydantic` or `conda install pydantic -c conda-forge`. +For more installation options to make Pydantic even faster, +see the [Install](https://docs.pydantic.dev/install/) section in the documentation. 
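The upgrade note above points out that Pydantic V2 bundles the latest V1 under `pydantic.v1` for incremental migration. The following short sketch shows how the two APIs can coexist during such an upgrade; it is an editorial illustration with made-up model names, not part of the packaged README.

```py
from pydantic import BaseModel           # V2 API
from pydantic import v1 as pydantic_v1   # V1 API bundled with V2

class LegacyUser(pydantic_v1.BaseModel):
    # Not yet migrated: still validated by the V1 engine.
    id: int
    name: str = 'John Doe'

class User(BaseModel):
    # Migrated: validated by the V2 (pydantic-core) engine.
    id: int
    name: str = 'John Doe'

# Both models coexist while the codebase is migrated module by module.
print(LegacyUser(id='1').id, User(id='1').id)
#> 1 1
```

Once every module is on the V2 API, the `pydantic.v1` import can simply be dropped.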
+ +## A Simple Example + +```py +from datetime import datetime +from typing import List, Optional +from pydantic import BaseModel + +class User(BaseModel): + id: int + name: str = 'John Doe' + signup_ts: Optional[datetime] = None + friends: List[int] = [] + +external_data = {'id': '123', 'signup_ts': '2017-06-01 12:22', 'friends': [1, '2', b'3']} +user = User(**external_data) +print(user) +#> User id=123 name='John Doe' signup_ts=datetime.datetime(2017, 6, 1, 12, 22) friends=[1, 2, 3] +print(user.id) +#> 123 +``` + +## Contributing + +For guidance on setting up a development environment and how to make a +contribution to Pydantic, see +[Contributing to Pydantic](https://docs.pydantic.dev/contributing/). + +## Reporting a Security Vulnerability + +See our [security policy](https://github.com/pydantic/pydantic/security/policy). + +## Changelog + +## v2.4.2 (2023-09-27) + +[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.4.2) + +### What's Changed + +#### Fixes + +* Fix bug with JSON schema for sequence of discriminated union by [@dmontagu](https://github.com/dmontagu) in [#7647](https://github.com/pydantic/pydantic/pull/7647) +* Fix schema references in discriminated unions by [@adriangb](https://github.com/adriangb) in [#7646](https://github.com/pydantic/pydantic/pull/7646) +* Fix json schema generation for recursive models by [@adriangb](https://github.com/adriangb) in [#7653](https://github.com/pydantic/pydantic/pull/7653) +* Fix `models_json_schema` for generic models by [@adriangb](https://github.com/adriangb) in [#7654](https://github.com/pydantic/pydantic/pull/7654) +* Fix xfailed test for generic model signatures by [@adriangb](https://github.com/adriangb) in [#7658](https://github.com/pydantic/pydantic/pull/7658) + +### New Contributors + +* [@austinorr](https://github.com/austinorr) made their first contribution in [#7657](https://github.com/pydantic/pydantic/pull/7657) +* [@peterHoburg](https://github.com/peterHoburg) made their first contribution in [#7670](https://github.com/pydantic/pydantic/pull/7670) + +## v2.4.1 (2023-09-26) + +[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.4.1) + +### What's Changed + +#### Packaging + +* Update pydantic-core to 2.10.1 by [@davidhewitt](https://github.com/davidhewitt) in [#7633](https://github.com/pydantic/pydantic/pull/7633) + +#### Fixes + +* Serialize unsubstituted type vars as `Any` by [@adriangb](https://github.com/adriangb) in [#7606](https://github.com/pydantic/pydantic/pull/7606) +* Remove schema building caches by [@adriangb](https://github.com/adriangb) in [#7624](https://github.com/pydantic/pydantic/pull/7624) +* Fix an issue where JSON schema extras weren't JSON encoded by [@dmontagu](https://github.com/dmontagu) in [#7625](https://github.com/pydantic/pydantic/pull/7625) + +## v2.4.0 (2023-09-22) + +[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.4.0) + +### What's Changed + +#### Packaging + +* Update pydantic-core to 2.10.0 by [@samuelcolvin](https://github.com/samuelcolvin) in [#7542](https://github.com/pydantic/pydantic/pull/7542) + +#### New Features + +* Add `Base64Url` types by [@dmontagu](https://github.com/dmontagu) in [#7286](https://github.com/pydantic/pydantic/pull/7286) +* Implement optional `number` to `str` coercion by [@lig](https://github.com/lig) in [#7508](https://github.com/pydantic/pydantic/pull/7508) +* Allow access to `field_name` and `data` in all validators if there is data and a field name by 
[@samuelcolvin](https://github.com/samuelcolvin) in [#7542](https://github.com/pydantic/pydantic/pull/7542) +* Add `BaseModel.model_validate_strings` and `TypeAdapter.validate_strings` by [@hramezani](https://github.com/hramezani) in [#7552](https://github.com/pydantic/pydantic/pull/7552) +* Add Pydantic `plugins` experimental implementation by [@lig](https://github.com/lig) [@samuelcolvin](https://github.com/samuelcolvin) and [@Kludex](https://github.com/Kludex) in [#6820](https://github.com/pydantic/pydantic/pull/6820) + +#### Changes + +* Do not override `model_post_init` in subclass with private attrs by [@Viicos](https://github.com/Viicos) in [#7302](https://github.com/pydantic/pydantic/pull/7302) +* Make fields with defaults not required in the serialization schema by default by [@dmontagu](https://github.com/dmontagu) in [#7275](https://github.com/pydantic/pydantic/pull/7275) +* Mark `Extra` as deprecated by [@disrupted](https://github.com/disrupted) in [#7299](https://github.com/pydantic/pydantic/pull/7299) +* Make `EncodedStr` a dataclass by [@Kludex](https://github.com/Kludex) in [#7396](https://github.com/pydantic/pydantic/pull/7396) +* Move `annotated_handlers` to be public by [@samuelcolvin](https://github.com/samuelcolvin) in [#7569](https://github.com/pydantic/pydantic/pull/7569) + +#### Performance + +* Simplify flattening and inlining of `CoreSchema` by [@adriangb](https://github.com/adriangb) in [#7523](https://github.com/pydantic/pydantic/pull/7523) +* Remove unused copies in `CoreSchema` walking by [@adriangb](https://github.com/adriangb) in [#7528](https://github.com/pydantic/pydantic/pull/7528) +* Add caches for collecting definitions and invalid schemas from a CoreSchema by [@adriangb](https://github.com/adriangb) in [#7527](https://github.com/pydantic/pydantic/pull/7527) +* Eagerly resolve discriminated unions and cache cases where we can't by [@adriangb](https://github.com/adriangb) in [#7529](https://github.com/pydantic/pydantic/pull/7529) +* Replace `dict.get` and `dict.setdefault` with more verbose versions in `CoreSchema` building hot paths by [@adriangb](https://github.com/adriangb) in [#7536](https://github.com/pydantic/pydantic/pull/7536) +* Cache invalid `CoreSchema` discovery by [@adriangb](https://github.com/adriangb) in [#7535](https://github.com/pydantic/pydantic/pull/7535) +* Allow disabling `CoreSchema` validation for faster startup times by [@adriangb](https://github.com/adriangb) in [#7565](https://github.com/pydantic/pydantic/pull/7565) + +#### Fixes + +* Fix config detection for `TypedDict` from grandparent classes by [@dmontagu](https://github.com/dmontagu) in [#7272](https://github.com/pydantic/pydantic/pull/7272) +* Fix hash function generation for frozen models with unusual MRO by [@dmontagu](https://github.com/dmontagu) in [#7274](https://github.com/pydantic/pydantic/pull/7274) +* Make `strict` config overridable in field for Path by [@hramezani](https://github.com/hramezani) in [#7281](https://github.com/pydantic/pydantic/pull/7281) +* Use `ser_json_` on default in `GenerateJsonSchema` by [@Kludex](https://github.com/Kludex) in [#7269](https://github.com/pydantic/pydantic/pull/7269) +* Adding a check that alias is validated as an identifier for Python by [@andree0](https://github.com/andree0) in [#7319](https://github.com/pydantic/pydantic/pull/7319) +* Raise an error when computed field overrides field by [@sydney-runkle](https://github.com/sydney-runkle) in [#7346](https://github.com/pydantic/pydantic/pull/7346) +* Fix applying 
`SkipValidation` to referenced schemas by [@adriangb](https://github.com/adriangb) in [#7381](https://github.com/pydantic/pydantic/pull/7381) +* Enforce behavior of private attributes having double leading underscore by [@lig](https://github.com/lig) in [#7265](https://github.com/pydantic/pydantic/pull/7265) +* Standardize `__get_pydantic_core_schema__` signature by [@hramezani](https://github.com/hramezani) in [#7415](https://github.com/pydantic/pydantic/pull/7415) +* Fix generic dataclass fields mutation bug (when using `TypeAdapter`) by [@sydney-runkle](https://github.com/sydney-runkle) in [#7435](https://github.com/pydantic/pydantic/pull/7435) +* Fix `TypeError` on `model_validator` in `wrap` mode by [@pmmmwh](https://github.com/pmmmwh) in [#7496](https://github.com/pydantic/pydantic/pull/7496) +* Improve enum error message by [@hramezani](https://github.com/hramezani) in [#7506](https://github.com/pydantic/pydantic/pull/7506) +* Make `repr` work for instances that failed initialization when handling `ValidationError`s by [@dmontagu](https://github.com/dmontagu) in [#7439](https://github.com/pydantic/pydantic/pull/7439) +* Fixed a regular expression denial of service issue by limiting whitespaces by [@prodigysml](https://github.com/prodigysml) in [#7360](https://github.com/pydantic/pydantic/pull/7360) +* Fix handling of `UUID` values having `UUID.version=None` by [@lig](https://github.com/lig) in [#7566](https://github.com/pydantic/pydantic/pull/7566) +* Fix `__iter__` returning private `cached_property` info by [@sydney-runkle](https://github.com/sydney-runkle) in [#7570](https://github.com/pydantic/pydantic/pull/7570) +* Improvements to version info message by [@samuelcolvin](https://github.com/samuelcolvin) in [#7594](https://github.com/pydantic/pydantic/pull/7594) + +### New Contributors +* [@15498th](https://github.com/15498th) made their first contribution in [#7238](https://github.com/pydantic/pydantic/pull/7238) +* [@GabrielCappelli](https://github.com/GabrielCappelli) made their first contribution in [#7213](https://github.com/pydantic/pydantic/pull/7213) +* [@tobni](https://github.com/tobni) made their first contribution in [#7184](https://github.com/pydantic/pydantic/pull/7184) +* [@redruin1](https://github.com/redruin1) made their first contribution in [#7282](https://github.com/pydantic/pydantic/pull/7282) +* [@FacerAin](https://github.com/FacerAin) made their first contribution in [#7288](https://github.com/pydantic/pydantic/pull/7288) +* [@acdha](https://github.com/acdha) made their first contribution in [#7297](https://github.com/pydantic/pydantic/pull/7297) +* [@andree0](https://github.com/andree0) made their first contribution in [#7319](https://github.com/pydantic/pydantic/pull/7319) +* [@gordonhart](https://github.com/gordonhart) made their first contribution in [#7375](https://github.com/pydantic/pydantic/pull/7375) +* [@pmmmwh](https://github.com/pmmmwh) made their first contribution in [#7496](https://github.com/pydantic/pydantic/pull/7496) +* [@disrupted](https://github.com/disrupted) made their first contribution in [#7299](https://github.com/pydantic/pydantic/pull/7299) +* [@prodigysml](https://github.com/prodigysml) made their first contribution in [#7360](https://github.com/pydantic/pydantic/pull/7360) + +## v2.3.0 (2023-08-23) + +[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.3.0) + +* 🔥 Remove orphaned changes file from repo by [@lig](https://github.com/lig) in [#7168](https://github.com/pydantic/pydantic/pull/7168) +* Add copy 
button on documentation by [@Kludex](https://github.com/Kludex) in [#7190](https://github.com/pydantic/pydantic/pull/7190) +* Fix docs on JSON type by [@Kludex](https://github.com/Kludex) in [#7189](https://github.com/pydantic/pydantic/pull/7189) +* Update mypy 1.5.0 to 1.5.1 in CI by [@hramezani](https://github.com/hramezani) in [#7191](https://github.com/pydantic/pydantic/pull/7191) +* fix download links badge by [@samuelcolvin](https://github.com/samuelcolvin) in [#7200](https://github.com/pydantic/pydantic/pull/7200) +* add 2.2.1 to changelog by [@samuelcolvin](https://github.com/samuelcolvin) in [#7212](https://github.com/pydantic/pydantic/pull/7212) +* Make ModelWrapValidator protocols generic by [@dmontagu](https://github.com/dmontagu) in [#7154](https://github.com/pydantic/pydantic/pull/7154) +* Correct `Field(..., exclude: bool)` docs by [@samuelcolvin](https://github.com/samuelcolvin) in [#7214](https://github.com/pydantic/pydantic/pull/7214) +* Make shadowing attributes a warning instead of an error by [@adriangb](https://github.com/adriangb) in [#7193](https://github.com/pydantic/pydantic/pull/7193) +* Document `Base64Str` and `Base64Bytes` by [@Kludex](https://github.com/Kludex) in [#7192](https://github.com/pydantic/pydantic/pull/7192) +* Fix `config.defer_build` for serialization first cases by [@samuelcolvin](https://github.com/samuelcolvin) in [#7024](https://github.com/pydantic/pydantic/pull/7024) +* clean Model docstrings in JSON Schema by [@samuelcolvin](https://github.com/samuelcolvin) in [#7210](https://github.com/pydantic/pydantic/pull/7210) +* fix [#7228](https://github.com/pydantic/pydantic/pull/7228) (typo): docs in `validators.md` to correct `validate_default` kwarg by [@lmmx](https://github.com/lmmx) in [#7229](https://github.com/pydantic/pydantic/pull/7229) +* ✅ Implement `tzinfo.fromutc` method for `TzInfo` in `pydantic-core` by [@lig](https://github.com/lig) in [#7019](https://github.com/pydantic/pydantic/pull/7019) +* Support `__get_validators__` by [@hramezani](https://github.com/hramezani) in [#7197](https://github.com/pydantic/pydantic/pull/7197) + +## v2.2.1 (2023-08-18) + +[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.2.1) + +* Make `xfail`ing test for root model extra stop `xfail`ing by [@dmontagu](https://github.com/dmontagu) in [#6937](https://github.com/pydantic/pydantic/pull/6937) +* Optimize recursion detection by stopping on the second visit for the same object by [@mciucu](https://github.com/mciucu) in [#7160](https://github.com/pydantic/pydantic/pull/7160) +* fix link in docs by [@tlambert03](https://github.com/tlambert03) in [#7166](https://github.com/pydantic/pydantic/pull/7166) +* Replace MiMalloc w/ default allocator by [@adriangb](https://github.com/adriangb) in [pydantic/pydantic-core#900](https://github.com/pydantic/pydantic-core/pull/900) +* Bump pydantic-core to 2.6.1 and prepare 2.2.1 release by [@adriangb](https://github.com/adriangb) in [#7176](https://github.com/pydantic/pydantic/pull/7176) + +## v2.2.0 (2023-08-17) + +[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.2.0) + +* Split "pipx install" setup command into two commands on the documentation site by [@nomadmtb](https://github.com/nomadmtb) in [#6869](https://github.com/pydantic/pydantic/pull/6869) +* Deprecate `Field.include` by [@hramezani](https://github.com/hramezani) in [#6852](https://github.com/pydantic/pydantic/pull/6852) +* Fix typo in default factory error msg by [@hramezani](https://github.com/hramezani) in 
[#6880](https://github.com/pydantic/pydantic/pull/6880) +* Simplify handling of typing.Annotated in GenerateSchema by [@dmontagu](https://github.com/dmontagu) in [#6887](https://github.com/pydantic/pydantic/pull/6887) +* Re-enable fastapi tests in CI by [@dmontagu](https://github.com/dmontagu) in [#6883](https://github.com/pydantic/pydantic/pull/6883) +* Make it harder to hit collisions with json schema defrefs by [@dmontagu](https://github.com/dmontagu) in [#6566](https://github.com/pydantic/pydantic/pull/6566) +* Cleaner error for invalid input to `Path` fields by [@samuelcolvin](https://github.com/samuelcolvin) in [#6903](https://github.com/pydantic/pydantic/pull/6903) +* :memo: support Coordinate Type by [@yezz123](https://github.com/yezz123) in [#6906](https://github.com/pydantic/pydantic/pull/6906) +* Fix `ForwardRef` wrapper for py 3.10.0 (shim until bpo-45166) by [@randomir](https://github.com/randomir) in [#6919](https://github.com/pydantic/pydantic/pull/6919) +* Fix misbehavior related to copying of RootModel by [@dmontagu](https://github.com/dmontagu) in [#6918](https://github.com/pydantic/pydantic/pull/6918) +* Fix issue with recursion error caused by ParamSpec by [@dmontagu](https://github.com/dmontagu) in [#6923](https://github.com/pydantic/pydantic/pull/6923) +* Add section about Constrained classes to the Migration Guide by [@Kludex](https://github.com/Kludex) in [#6924](https://github.com/pydantic/pydantic/pull/6924) +* Use `main` branch for badge links by [@Viicos](https://github.com/Viicos) in [#6925](https://github.com/pydantic/pydantic/pull/6925) +* Add test for v1/v2 Annotated discrepancy by [@carlbordum](https://github.com/carlbordum) in [#6926](https://github.com/pydantic/pydantic/pull/6926) +* Make the v1 mypy plugin work with both v1 and v2 by [@dmontagu](https://github.com/dmontagu) in [#6921](https://github.com/pydantic/pydantic/pull/6921) +* Fix issue where generic models couldn't be parametrized with BaseModel by [@dmontagu](https://github.com/dmontagu) in [#6933](https://github.com/pydantic/pydantic/pull/6933) +* Remove xfail for discriminated union with alias by [@dmontagu](https://github.com/dmontagu) in [#6938](https://github.com/pydantic/pydantic/pull/6938) +* add field_serializer to computed_field by [@andresliszt](https://github.com/andresliszt) in [#6965](https://github.com/pydantic/pydantic/pull/6965) +* Use union_schema with Type[Union[...]] by [@JeanArhancet](https://github.com/JeanArhancet) in [#6952](https://github.com/pydantic/pydantic/pull/6952) +* Fix inherited typeddict attributes / config by [@adriangb](https://github.com/adriangb) in [#6981](https://github.com/pydantic/pydantic/pull/6981) +* fix dataclass annotated before validator called twice by [@davidhewitt](https://github.com/davidhewitt) in [#6998](https://github.com/pydantic/pydantic/pull/6998) +* Update test-fastapi deselected tests by [@hramezani](https://github.com/hramezani) in [#7014](https://github.com/pydantic/pydantic/pull/7014) +* Fix validator doc format by [@hramezani](https://github.com/hramezani) in [#7015](https://github.com/pydantic/pydantic/pull/7015) +* Fix typo in docstring of model_json_schema by [@AdamVinch-Federated](https://github.com/AdamVinch-Federated) in [#7032](https://github.com/pydantic/pydantic/pull/7032) +* remove unused "type ignores" with pyright by [@samuelcolvin](https://github.com/samuelcolvin) in [#7026](https://github.com/pydantic/pydantic/pull/7026) +* Add benchmark representing FastAPI startup time by [@adriangb](https://github.com/adriangb) in 
[#7030](https://github.com/pydantic/pydantic/pull/7030) +* Fix json_encoders for Enum subclasses by [@adriangb](https://github.com/adriangb) in [#7029](https://github.com/pydantic/pydantic/pull/7029) +* Update docstring of `ser_json_bytes` regarding base64 encoding by [@Viicos](https://github.com/Viicos) in [#7052](https://github.com/pydantic/pydantic/pull/7052) +* Allow `@validate_call` to work on async methods by [@adriangb](https://github.com/adriangb) in [#7046](https://github.com/pydantic/pydantic/pull/7046) +* Fix: mypy error with `Settings` and `SettingsConfigDict` by [@JeanArhancet](https://github.com/JeanArhancet) in [#7002](https://github.com/pydantic/pydantic/pull/7002) +* Fix some typos (repeated words and it's/its) by [@eumiro](https://github.com/eumiro) in [#7063](https://github.com/pydantic/pydantic/pull/7063) +* Fix the typo in docstring by [@harunyasar](https://github.com/harunyasar) in [#7062](https://github.com/pydantic/pydantic/pull/7062) +* Docs: Fix broken URL in the pydantic-settings package recommendation by [@swetjen](https://github.com/swetjen) in [#6995](https://github.com/pydantic/pydantic/pull/6995) +* Handle constraints being applied to schemas that don't accept it by [@adriangb](https://github.com/adriangb) in [#6951](https://github.com/pydantic/pydantic/pull/6951) +* Replace almost_equal_floats with math.isclose by [@eumiro](https://github.com/eumiro) in [#7082](https://github.com/pydantic/pydantic/pull/7082) +* bump pydantic-core to 2.5.0 by [@davidhewitt](https://github.com/davidhewitt) in [#7077](https://github.com/pydantic/pydantic/pull/7077) +* Add `short_version` and use it in links by [@hramezani](https://github.com/hramezani) in [#7115](https://github.com/pydantic/pydantic/pull/7115) +* 📝 Add usage link to `RootModel` by [@Kludex](https://github.com/Kludex) in [#7113](https://github.com/pydantic/pydantic/pull/7113) +* Revert "Fix default port for mongosrv DSNs (#6827)" by [@Kludex](https://github.com/Kludex) in [#7116](https://github.com/pydantic/pydantic/pull/7116) +* Clarify validate_default and _Unset handling in usage docs and migration guide by [@benbenbang](https://github.com/benbenbang) in [#6950](https://github.com/pydantic/pydantic/pull/6950) +* Tweak documentation of `Field.exclude` by [@Viicos](https://github.com/Viicos) in [#7086](https://github.com/pydantic/pydantic/pull/7086) +* Do not require `validate_assignment` to use `Field.frozen` by [@Viicos](https://github.com/Viicos) in [#7103](https://github.com/pydantic/pydantic/pull/7103) +* tweaks to `_core_utils` by [@samuelcolvin](https://github.com/samuelcolvin) in [#7040](https://github.com/pydantic/pydantic/pull/7040) +* Make DefaultDict working with set by [@hramezani](https://github.com/hramezani) in [#7126](https://github.com/pydantic/pydantic/pull/7126) +* Don't always require typing.Generic as a base for partially parametrized models by [@dmontagu](https://github.com/dmontagu) in [#7119](https://github.com/pydantic/pydantic/pull/7119) +* Fix issue with JSON schema incorrectly using parent class core schema by [@dmontagu](https://github.com/dmontagu) in [#7020](https://github.com/pydantic/pydantic/pull/7020) +* Fix xfailed test related to TypedDict and alias_generator by [@dmontagu](https://github.com/dmontagu) in [#6940](https://github.com/pydantic/pydantic/pull/6940) +* Improve error message for NameEmail by [@dmontagu](https://github.com/dmontagu) in [#6939](https://github.com/pydantic/pydantic/pull/6939) +* Fix generic computed fields by [@dmontagu](https://github.com/dmontagu) 
in [#6988](https://github.com/pydantic/pydantic/pull/6988) +* Reflect namedtuple default values during validation by [@dmontagu](https://github.com/dmontagu) in [#7144](https://github.com/pydantic/pydantic/pull/7144) +* Update dependencies, fix pydantic-core usage, fix CI issues by [@dmontagu](https://github.com/dmontagu) in [#7150](https://github.com/pydantic/pydantic/pull/7150) +* Add mypy 1.5.0 by [@hramezani](https://github.com/hramezani) in [#7118](https://github.com/pydantic/pydantic/pull/7118) +* Handle non-json native enum values by [@adriangb](https://github.com/adriangb) in [#7056](https://github.com/pydantic/pydantic/pull/7056) +* document `round_trip` in Json type documentation by [@jc-louis](https://github.com/jc-louis) in [#7137](https://github.com/pydantic/pydantic/pull/7137) +* Relax signature checks to better support builtins and C extension functions as validators by [@adriangb](https://github.com/adriangb) in [#7101](https://github.com/pydantic/pydantic/pull/7101) +* add union_mode='left_to_right' by [@davidhewitt](https://github.com/davidhewitt) in [#7151](https://github.com/pydantic/pydantic/pull/7151) +* Include an error message hint for inherited ordering by [@yvalencia91](https://github.com/yvalencia91) in [#7124](https://github.com/pydantic/pydantic/pull/7124) +* Fix one docs link and resolve some warnings for two others by [@dmontagu](https://github.com/dmontagu) in [#7153](https://github.com/pydantic/pydantic/pull/7153) +* Include Field extra keys name in warning by [@hramezani](https://github.com/hramezani) in [#7136](https://github.com/pydantic/pydantic/pull/7136) + +## v2.1.1 (2023-07-25) + +[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.1.1) + +* Skip FieldInfo merging when unnecessary by [@dmontagu](https://github.com/dmontagu) in [#6862](https://github.com/pydantic/pydantic/pull/6862) + +## v2.1.0 (2023-07-25) + +[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.1.0) + +* Add `StringConstraints` for use as Annotated metadata by [@adriangb](https://github.com/adriangb) in [#6605](https://github.com/pydantic/pydantic/pull/6605) +* Try to fix intermittently failing CI by [@adriangb](https://github.com/adriangb) in [#6683](https://github.com/pydantic/pydantic/pull/6683) +* Remove redundant example of optional vs default. 
by [@ehiggs-deliverect](https://github.com/ehiggs-deliverect) in [#6676](https://github.com/pydantic/pydantic/pull/6676) +* Docs update by [@samuelcolvin](https://github.com/samuelcolvin) in [#6692](https://github.com/pydantic/pydantic/pull/6692) +* Remove the Validate always section in validator docs by [@adriangb](https://github.com/adriangb) in [#6679](https://github.com/pydantic/pydantic/pull/6679) +* Fix recursion error in json schema generation by [@adriangb](https://github.com/adriangb) in [#6720](https://github.com/pydantic/pydantic/pull/6720) +* Fix incorrect subclass check for secretstr by [@AlexVndnblcke](https://github.com/AlexVndnblcke) in [#6730](https://github.com/pydantic/pydantic/pull/6730) +* update pdm / pdm lockfile to 2.8.0 by [@davidhewitt](https://github.com/davidhewitt) in [#6714](https://github.com/pydantic/pydantic/pull/6714) +* unpin pdm on more CI jobs by [@davidhewitt](https://github.com/davidhewitt) in [#6755](https://github.com/pydantic/pydantic/pull/6755) +* improve source locations for auxiliary packages in docs by [@davidhewitt](https://github.com/davidhewitt) in [#6749](https://github.com/pydantic/pydantic/pull/6749) +* Assume builtins don't accept an info argument by [@adriangb](https://github.com/adriangb) in [#6754](https://github.com/pydantic/pydantic/pull/6754) +* Fix bug where calling `help(BaseModelSubclass)` raises errors by [@hramezani](https://github.com/hramezani) in [#6758](https://github.com/pydantic/pydantic/pull/6758) +* Fix mypy plugin handling of `@model_validator(mode="after")` by [@ljodal](https://github.com/ljodal) in [#6753](https://github.com/pydantic/pydantic/pull/6753) +* update pydantic-core to 2.3.1 by [@davidhewitt](https://github.com/davidhewitt) in [#6756](https://github.com/pydantic/pydantic/pull/6756) +* Mypy plugin for settings by [@hramezani](https://github.com/hramezani) in [#6760](https://github.com/pydantic/pydantic/pull/6760) +* Use `contentSchema` keyword for JSON schema by [@dmontagu](https://github.com/dmontagu) in [#6715](https://github.com/pydantic/pydantic/pull/6715) +* fast-path checking finite decimals by [@davidhewitt](https://github.com/davidhewitt) in [#6769](https://github.com/pydantic/pydantic/pull/6769) +* Docs update by [@samuelcolvin](https://github.com/samuelcolvin) in [#6771](https://github.com/pydantic/pydantic/pull/6771) +* Improve json schema doc by [@hramezani](https://github.com/hramezani) in [#6772](https://github.com/pydantic/pydantic/pull/6772) +* Update validator docs by [@adriangb](https://github.com/adriangb) in [#6695](https://github.com/pydantic/pydantic/pull/6695) +* Fix typehint for wrap validator by [@dmontagu](https://github.com/dmontagu) in [#6788](https://github.com/pydantic/pydantic/pull/6788) +* 🐛 Fix validation warning for unions of Literal and other type by [@lig](https://github.com/lig) in [#6628](https://github.com/pydantic/pydantic/pull/6628) +* Update documentation for generics support in V2 by [@tpdorsey](https://github.com/tpdorsey) in [#6685](https://github.com/pydantic/pydantic/pull/6685) +* add pydantic-core build info to `version_info()` by [@samuelcolvin](https://github.com/samuelcolvin) in [#6785](https://github.com/pydantic/pydantic/pull/6785) +* Fix pydantic dataclasses that use slots with default values by [@dmontagu](https://github.com/dmontagu) in [#6796](https://github.com/pydantic/pydantic/pull/6796) +* Fix inheritance of hash function for frozen models by [@dmontagu](https://github.com/dmontagu) in [#6789](https://github.com/pydantic/pydantic/pull/6789) +* ✨ 
Add `SkipJsonSchema` annotation by [@Kludex](https://github.com/Kludex) in [#6653](https://github.com/pydantic/pydantic/pull/6653) +* Error if an invalid field name is used with Field by [@dmontagu](https://github.com/dmontagu) in [#6797](https://github.com/pydantic/pydantic/pull/6797) +* Add `GenericModel` to `MOVED_IN_V2` by [@adriangb](https://github.com/adriangb) in [#6776](https://github.com/pydantic/pydantic/pull/6776) +* Remove unused code from `docs/usage/types/custom.md` by [@hramezani](https://github.com/hramezani) in [#6803](https://github.com/pydantic/pydantic/pull/6803) +* Fix `float` -> `Decimal` coercion precision loss by [@adriangb](https://github.com/adriangb) in [#6810](https://github.com/pydantic/pydantic/pull/6810) +* remove email validation from the north star benchmark by [@davidhewitt](https://github.com/davidhewitt) in [#6816](https://github.com/pydantic/pydantic/pull/6816) +* Fix link to mypy by [@progsmile](https://github.com/progsmile) in [#6824](https://github.com/pydantic/pydantic/pull/6824) +* Improve initialization hooks example by [@hramezani](https://github.com/hramezani) in [#6822](https://github.com/pydantic/pydantic/pull/6822) +* Fix default port for mongosrv DSNs by [@dmontagu](https://github.com/dmontagu) in [#6827](https://github.com/pydantic/pydantic/pull/6827) +* Improve API documentation, in particular more links between usage and API docs by [@samuelcolvin](https://github.com/samuelcolvin) in [#6780](https://github.com/pydantic/pydantic/pull/6780) +* update pydantic-core to 2.4.0 by [@davidhewitt](https://github.com/davidhewitt) in [#6831](https://github.com/pydantic/pydantic/pull/6831) +* Fix `annotated_types.MaxLen` validator for custom sequence types by [@ImogenBits](https://github.com/ImogenBits) in [#6809](https://github.com/pydantic/pydantic/pull/6809) +* Update V1 by [@hramezani](https://github.com/hramezani) in [#6833](https://github.com/pydantic/pydantic/pull/6833) +* Make it so callable JSON schema extra works by [@dmontagu](https://github.com/dmontagu) in [#6798](https://github.com/pydantic/pydantic/pull/6798) +* Fix serialization issue with `InstanceOf` by [@dmontagu](https://github.com/dmontagu) in [#6829](https://github.com/pydantic/pydantic/pull/6829) +* Add back support for `json_encoders` by [@adriangb](https://github.com/adriangb) in [#6811](https://github.com/pydantic/pydantic/pull/6811) +* Update field annotations when building the schema by [@dmontagu](https://github.com/dmontagu) in [#6838](https://github.com/pydantic/pydantic/pull/6838) +* Use `WeakValueDictionary` to fix generic memory leak by [@dmontagu](https://github.com/dmontagu) in [#6681](https://github.com/pydantic/pydantic/pull/6681) +* Add `config.defer_build` to optionally make model building lazy by [@samuelcolvin](https://github.com/samuelcolvin) in [#6823](https://github.com/pydantic/pydantic/pull/6823) +* delegate `UUID` serialization to pydantic-core by [@davidhewitt](https://github.com/davidhewitt) in [#6850](https://github.com/pydantic/pydantic/pull/6850) +* Update `json_encoders` docs by [@adriangb](https://github.com/adriangb) in [#6848](https://github.com/pydantic/pydantic/pull/6848) +* Fix error message for `staticmethod`/`classmethod` order with validate_call by [@dmontagu](https://github.com/dmontagu) in [#6686](https://github.com/pydantic/pydantic/pull/6686) +* Improve documentation for `Config` by [@samuelcolvin](https://github.com/samuelcolvin) in [#6847](https://github.com/pydantic/pydantic/pull/6847) +* Update serialization doc to mention 
`Field.exclude` takes priority over call-time `include/exclude` by [@hramezani](https://github.com/hramezani) in [#6851](https://github.com/pydantic/pydantic/pull/6851) +* Allow customizing core schema generation by making `GenerateSchema` public by [@adriangb](https://github.com/adriangb) in [#6737](https://github.com/pydantic/pydantic/pull/6737) + +## v2.0.3 (2023-07-05) + +[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.0.3) + +* Mention PyObject (v1) moving to ImportString (v2) in migration doc by [@slafs](https://github.com/slafs) in [#6456](https://github.com/pydantic/pydantic/pull/6456) +* Fix release-tweet CI by [@Kludex](https://github.com/Kludex) in [#6461](https://github.com/pydantic/pydantic/pull/6461) +* Revise the section on required / optional / nullable fields. by [@ybressler](https://github.com/ybressler) in [#6468](https://github.com/pydantic/pydantic/pull/6468) +* Warn if a type hint is not in fact a type by [@adriangb](https://github.com/adriangb) in [#6479](https://github.com/pydantic/pydantic/pull/6479) +* Replace TransformSchema with GetPydanticSchema by [@dmontagu](https://github.com/dmontagu) in [#6484](https://github.com/pydantic/pydantic/pull/6484) +* Fix the un-hashability of various annotation types, for use in caching generic containers by [@dmontagu](https://github.com/dmontagu) in [#6480](https://github.com/pydantic/pydantic/pull/6480) +* PYD-164: Rework custom types docs by [@adriangb](https://github.com/adriangb) in [#6490](https://github.com/pydantic/pydantic/pull/6490) +* Fix ci by [@adriangb](https://github.com/adriangb) in [#6507](https://github.com/pydantic/pydantic/pull/6507) +* Fix forward ref in generic by [@adriangb](https://github.com/adriangb) in [#6511](https://github.com/pydantic/pydantic/pull/6511) +* Fix generation of serialization JSON schemas for core_schema.ChainSchema by [@dmontagu](https://github.com/dmontagu) in [#6515](https://github.com/pydantic/pydantic/pull/6515) +* Document the change in `Field.alias` behavior in Pydantic V2 by [@hramezani](https://github.com/hramezani) in [#6508](https://github.com/pydantic/pydantic/pull/6508) +* Give better error message attempting to compute the json schema of a model with undefined fields by [@dmontagu](https://github.com/dmontagu) in [#6519](https://github.com/pydantic/pydantic/pull/6519) +* Document `alias_priority` by [@tpdorsey](https://github.com/tpdorsey) in [#6520](https://github.com/pydantic/pydantic/pull/6520) +* Add redirect for types documentation by [@tpdorsey](https://github.com/tpdorsey) in [#6513](https://github.com/pydantic/pydantic/pull/6513) +* Allow updating docs without release by [@samuelcolvin](https://github.com/samuelcolvin) in [#6551](https://github.com/pydantic/pydantic/pull/6551) +* Ensure docs tests always run in the right folder by [@dmontagu](https://github.com/dmontagu) in [#6487](https://github.com/pydantic/pydantic/pull/6487) +* Defer evaluation of return type hints for serializer functions by [@dmontagu](https://github.com/dmontagu) in [#6516](https://github.com/pydantic/pydantic/pull/6516) +* Disable E501 from Ruff and rely on just Black by [@adriangb](https://github.com/adriangb) in [#6552](https://github.com/pydantic/pydantic/pull/6552) +* Update JSON Schema documentation for V2 by [@tpdorsey](https://github.com/tpdorsey) in [#6492](https://github.com/pydantic/pydantic/pull/6492) +* Add documentation of cyclic reference handling by [@dmontagu](https://github.com/dmontagu) in [#6493](https://github.com/pydantic/pydantic/pull/6493) +* 
Remove the need for change files by [@samuelcolvin](https://github.com/samuelcolvin) in [#6556](https://github.com/pydantic/pydantic/pull/6556) +* add "north star" benchmark by [@davidhewitt](https://github.com/davidhewitt) in [#6547](https://github.com/pydantic/pydantic/pull/6547) +* Update Dataclasses docs by [@tpdorsey](https://github.com/tpdorsey) in [#6470](https://github.com/pydantic/pydantic/pull/6470) +* ♻️ Use different error message on v1 redirects by [@Kludex](https://github.com/Kludex) in [#6595](https://github.com/pydantic/pydantic/pull/6595) +* ⬆ Upgrade `pydantic-core` to v2.2.0 by [@lig](https://github.com/lig) in [#6589](https://github.com/pydantic/pydantic/pull/6589) +* Fix serialization for IPvAny by [@dmontagu](https://github.com/dmontagu) in [#6572](https://github.com/pydantic/pydantic/pull/6572) +* Improve CI by using PDM instead of pip to install typing-extensions by [@adriangb](https://github.com/adriangb) in [#6602](https://github.com/pydantic/pydantic/pull/6602) +* Add `enum` error type docs by [@lig](https://github.com/lig) in [#6603](https://github.com/pydantic/pydantic/pull/6603) +* 🐛 Fix `max_length` for unicode strings by [@lig](https://github.com/lig) in [#6559](https://github.com/pydantic/pydantic/pull/6559) +* Add documentation for accessing features via `pydantic.v1` by [@tpdorsey](https://github.com/tpdorsey) in [#6604](https://github.com/pydantic/pydantic/pull/6604) +* Include extra when iterating over a model by [@adriangb](https://github.com/adriangb) in [#6562](https://github.com/pydantic/pydantic/pull/6562) +* Fix typing of model_validator by [@adriangb](https://github.com/adriangb) in [#6514](https://github.com/pydantic/pydantic/pull/6514) +* Touch up Decimal validator by [@adriangb](https://github.com/adriangb) in [#6327](https://github.com/pydantic/pydantic/pull/6327) +* Fix various docstrings using fixed pytest-examples by [@dmontagu](https://github.com/dmontagu) in [#6607](https://github.com/pydantic/pydantic/pull/6607) +* Handle function validators in a discriminated union by [@dmontagu](https://github.com/dmontagu) in [#6570](https://github.com/pydantic/pydantic/pull/6570) +* Review json_schema.md by [@tpdorsey](https://github.com/tpdorsey) in [#6608](https://github.com/pydantic/pydantic/pull/6608) +* Make validate_call work on basemodel methods by [@dmontagu](https://github.com/dmontagu) in [#6569](https://github.com/pydantic/pydantic/pull/6569) +* add test for big int json serde by [@davidhewitt](https://github.com/davidhewitt) in [#6614](https://github.com/pydantic/pydantic/pull/6614) +* Fix pydantic dataclass problem with dataclasses.field default_factory by [@hramezani](https://github.com/hramezani) in [#6616](https://github.com/pydantic/pydantic/pull/6616) +* Fixed mypy type inference for TypeAdapter by [@zakstucke](https://github.com/zakstucke) in [#6617](https://github.com/pydantic/pydantic/pull/6617) +* Make it work to use None as a generic parameter by [@dmontagu](https://github.com/dmontagu) in [#6609](https://github.com/pydantic/pydantic/pull/6609) +* Make it work to use `$ref` as an alias by [@dmontagu](https://github.com/dmontagu) in [#6568](https://github.com/pydantic/pydantic/pull/6568) +* add note to migration guide about changes to `AnyUrl` etc by [@davidhewitt](https://github.com/davidhewitt) in [#6618](https://github.com/pydantic/pydantic/pull/6618) +* 🐛 Support defining `json_schema_extra` on `RootModel` using `Field` by [@lig](https://github.com/lig) in [#6622](https://github.com/pydantic/pydantic/pull/6622) +* Update 
pre-commit to prevent commits to main branch on accident by [@dmontagu](https://github.com/dmontagu) in [#6636](https://github.com/pydantic/pydantic/pull/6636) +* Fix PDM CI for python 3.7 on MacOS/windows by [@dmontagu](https://github.com/dmontagu) in [#6627](https://github.com/pydantic/pydantic/pull/6627) +* Produce more accurate signatures for pydantic dataclasses by [@dmontagu](https://github.com/dmontagu) in [#6633](https://github.com/pydantic/pydantic/pull/6633) +* Updates to Url types for Pydantic V2 by [@tpdorsey](https://github.com/tpdorsey) in [#6638](https://github.com/pydantic/pydantic/pull/6638) +* Fix list markdown in `transform` docstring by [@StefanBRas](https://github.com/StefanBRas) in [#6649](https://github.com/pydantic/pydantic/pull/6649) +* simplify slots_dataclass construction to appease mypy by [@davidhewitt](https://github.com/davidhewitt) in [#6639](https://github.com/pydantic/pydantic/pull/6639) +* Update TypedDict schema generation docstring by [@adriangb](https://github.com/adriangb) in [#6651](https://github.com/pydantic/pydantic/pull/6651) +* Detect and lint-error for prints by [@dmontagu](https://github.com/dmontagu) in [#6655](https://github.com/pydantic/pydantic/pull/6655) +* Add xfailing test for pydantic-core PR 766 by [@dmontagu](https://github.com/dmontagu) in [#6641](https://github.com/pydantic/pydantic/pull/6641) +* Ignore unrecognized fields from dataclasses metadata by [@dmontagu](https://github.com/dmontagu) in [#6634](https://github.com/pydantic/pydantic/pull/6634) +* Make non-existent class getattr a mypy error by [@dmontagu](https://github.com/dmontagu) in [#6658](https://github.com/pydantic/pydantic/pull/6658) +* Update pydantic-core to 2.3.0 by [@hramezani](https://github.com/hramezani) in [#6648](https://github.com/pydantic/pydantic/pull/6648) +* Use OrderedDict from typing_extensions by [@dmontagu](https://github.com/dmontagu) in [#6664](https://github.com/pydantic/pydantic/pull/6664) +* Fix typehint for JSON schema extra callable by [@dmontagu](https://github.com/dmontagu) in [#6659](https://github.com/pydantic/pydantic/pull/6659) + +## v2.0.2 (2023-07-05) + +[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.0.2) + +* Fix bug where round-trip pickling/unpickling a `RootModel` would change the value of `__dict__`, [#6457](https://github.com/pydantic/pydantic/pull/6457) by [@dmontagu](https://github.com/dmontagu) +* Allow single-item discriminated unions, [#6405](https://github.com/pydantic/pydantic/pull/6405) by [@dmontagu](https://github.com/dmontagu) +* Fix issue with union parsing of enums, [#6440](https://github.com/pydantic/pydantic/pull/6440) by [@dmontagu](https://github.com/dmontagu) +* Docs: Fixed `constr` documentation, renamed old `regex` to new `pattern`, [#6452](https://github.com/pydantic/pydantic/pull/6452) by [@miili](https://github.com/miili) +* Change `GenerateJsonSchema.generate_definitions` signature, [#6436](https://github.com/pydantic/pydantic/pull/6436) by [@dmontagu](https://github.com/dmontagu) + +See the full changelog [here](https://github.com/pydantic/pydantic/releases/tag/v2.0.2) + +## v2.0.1 (2023-07-04) + +[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.0.1) + +First patch release of Pydantic V2 + +* Extra fields added via `setattr` (i.e. `m.some_extra_field = 'extra_value'`) + are added to `.model_extra` if `model_config` `extra='allowed'`. 
Fixed [#6333](https://github.com/pydantic/pydantic/pull/6333), [#6365](https://github.com/pydantic/pydantic/pull/6365) by [@aaraney](https://github.com/aaraney) +* Automatically unpack JSON schema '$ref' for custom types, [#6343](https://github.com/pydantic/pydantic/pull/6343) by [@adriangb](https://github.com/adriangb) +* Fix tagged unions multiple processing in submodels, [#6340](https://github.com/pydantic/pydantic/pull/6340) by [@suharnikov](https://github.com/suharnikov) + +See the full changelog [here](https://github.com/pydantic/pydantic/releases/tag/v2.0.1) + +## v2.0 (2023-06-30) + +[GitHub release](https://github.com/pydantic/pydantic/releases/tag/v2.0) + +Pydantic V2 is here! :tada: + +See [this post](https://docs.pydantic.dev/2.0/blog/pydantic-v2-final/) for more details. + +## v2.0b3 (2023-06-16) + +Third beta pre-release of Pydantic V2 + +See the full changelog [here](https://github.com/pydantic/pydantic/releases/tag/v2.0b3) + +## v2.0b2 (2023-06-03) + +Add `from_attributes` runtime flag to `TypeAdapter.validate_python` and `BaseModel.model_validate`. + +See the full changelog [here](https://github.com/pydantic/pydantic/releases/tag/v2.0b2) + +## v2.0b1 (2023-06-01) + +First beta pre-release of Pydantic V2 + +See the full changelog [here](https://github.com/pydantic/pydantic/releases/tag/v2.0b1) + +## v2.0a4 (2023-05-05) + +Fourth pre-release of Pydantic V2 + +See the full changelog [here](https://github.com/pydantic/pydantic/releases/tag/v2.0a4) + +## v2.0a3 (2023-04-20) + +Third pre-release of Pydantic V2 + +See the full changelog [here](https://github.com/pydantic/pydantic/releases/tag/v2.0a3) + +## v2.0a2 (2023-04-12) + +Second pre-release of Pydantic V2 + +See the full changelog [here](https://github.com/pydantic/pydantic/releases/tag/v2.0a2) + +## v2.0a1 (2023-04-03) + +First pre-release of Pydantic V2! + +See [this post](https://docs.pydantic.dev/blog/pydantic-v2-alpha/) for more details. + +## v1.10.13 (2023-09-27) + +* Fix: Add max length check to `pydantic.validate_email`, [#7673](https://github.com/pydantic/pydantic/issues/7673) by [@hramezani](https://github.com/hramezani) +* Docs: Fix pip commands to install v1, [#6930](https://github.com/pydantic/pydantic/issues/6930) by [@chbndrhnns](https://github.com/chbndrhnns) + +## v1.10.12 (2023-07-24) + +* Fixes the `maxlen` property being dropped on `deque` validation. Happened only if the deque item has been typed. 
Changes the `_validate_sequence_like` func, [#6581](https://github.com/pydantic/pydantic/pull/6581) by [@maciekglowka](https://github.com/maciekglowka) + +## v1.10.11 (2023-07-04) + +* Importing create_model in tools.py through relative path instead of absolute path - so that it doesn't import V2 code when copied over to V2 branch, [#6361](https://github.com/pydantic/pydantic/pull/6361) by [@SharathHuddar](https://github.com/SharathHuddar) + +## v1.10.10 (2023-06-30) + +* Add Pydantic `Json` field support to settings management, [#6250](https://github.com/pydantic/pydantic/pull/6250) by [@hramezani](https://github.com/hramezani) +* Fixed literal validator errors for unhashable values, [#6188](https://github.com/pydantic/pydantic/pull/6188) by [@markus1978](https://github.com/markus1978) +* Fixed bug with generics receiving forward refs, [#6130](https://github.com/pydantic/pydantic/pull/6130) by [@mark-todd](https://github.com/mark-todd) +* Update install method of FastAPI for internal tests in CI, [#6117](https://github.com/pydantic/pydantic/pull/6117) by [@Kludex](https://github.com/Kludex) + +## v1.10.9 (2023-06-07) + +* Fix trailing zeros not ignored in Decimal validation, [#5968](https://github.com/pydantic/pydantic/pull/5968) by [@hramezani](https://github.com/hramezani) +* Fix mypy plugin for v1.4.0, [#5928](https://github.com/pydantic/pydantic/pull/5928) by [@cdce8p](https://github.com/cdce8p) +* Add future and past date hypothesis strategies, [#5850](https://github.com/pydantic/pydantic/pull/5850) by [@bschoenmaeckers](https://github.com/bschoenmaeckers) +* Discourage usage of Cython 3 with Pydantic 1.x, [#5845](https://github.com/pydantic/pydantic/pull/5845) by [@lig](https://github.com/lig) + +## v1.10.8 (2023-05-23) + +* Fix a bug in `Literal` usage with `typing-extension==4.6.0`, [#5826](https://github.com/pydantic/pydantic/pull/5826) by [@hramezani](https://github.com/hramezani) +* This solves the (closed) issue [#3849](https://github.com/pydantic/pydantic/pull/3849) where aliased fields that use discriminated union fail to validate when the data contains the non-aliased field name, [#5736](https://github.com/pydantic/pydantic/pull/5736) by [@benwah](https://github.com/benwah) +* Update email-validator dependency to >=2.0.0post2, [#5627](https://github.com/pydantic/pydantic/pull/5627) by [@adriangb](https://github.com/adriangb) +* update `AnyClassMethod` for changes in [python/typeshed#9771](https://github.com/python/typeshed/issues/9771), [#5505](https://github.com/pydantic/pydantic/pull/5505) by [@ITProKyle](https://github.com/ITProKyle) + +## v1.10.7 (2023-03-22) + +* Fix creating schema from model using `ConstrainedStr` with `regex` as dict key, [#5223](https://github.com/pydantic/pydantic/pull/5223) by [@matejetz](https://github.com/matejetz) +* Address bug in mypy plugin caused by explicit_package_bases=True, [#5191](https://github.com/pydantic/pydantic/pull/5191) by [@dmontagu](https://github.com/dmontagu) +* Add implicit defaults in the mypy plugin for Field with no default argument, [#5190](https://github.com/pydantic/pydantic/pull/5190) by [@dmontagu](https://github.com/dmontagu) +* Fix schema generated for Enum values used as Literals in discriminated unions, [#5188](https://github.com/pydantic/pydantic/pull/5188) by [@javibookline](https://github.com/javibookline) +* Fix mypy failures caused by the pydantic mypy plugin when users define `from_orm` in their own classes, [#5187](https://github.com/pydantic/pydantic/pull/5187) by 
[@dmontagu](https://github.com/dmontagu) +* Fix `InitVar` usage with pydantic dataclasses, mypy version `1.1.1` and the custom mypy plugin, [#5162](https://github.com/pydantic/pydantic/pull/5162) by [@cdce8p](https://github.com/cdce8p) + +## v1.10.6 (2023-03-08) + +* Implement logic to support creating validators from non standard callables by using defaults to identify them and unwrapping `functools.partial` and `functools.partialmethod` when checking the signature, [#5126](https://github.com/pydantic/pydantic/pull/5126) by [@JensHeinrich](https://github.com/JensHeinrich) +* Fix mypy plugin for v1.1.1, and fix `dataclass_transform` decorator for pydantic dataclasses, [#5111](https://github.com/pydantic/pydantic/pull/5111) by [@cdce8p](https://github.com/cdce8p) +* Raise `ValidationError`, not `ConfigError`, when a discriminator value is unhashable, [#4773](https://github.com/pydantic/pydantic/pull/4773) by [@kurtmckee](https://github.com/kurtmckee) + +## v1.10.5 (2023-02-15) + +* Fix broken parametrized bases handling with `GenericModel`s with complex sets of models, [#5052](https://github.com/pydantic/pydantic/pull/5052) by [@MarkusSintonen](https://github.com/MarkusSintonen) +* Invalidate mypy cache if plugin config changes, [#5007](https://github.com/pydantic/pydantic/pull/5007) by [@cdce8p](https://github.com/cdce8p) +* Fix `RecursionError` when deep-copying dataclass types wrapped by pydantic, [#4949](https://github.com/pydantic/pydantic/pull/4949) by [@mbillingr](https://github.com/mbillingr) +* Fix `X | Y` union syntax breaking `GenericModel`, [#4146](https://github.com/pydantic/pydantic/pull/4146) by [@thenx](https://github.com/thenx) +* Switch coverage badge to show coverage for this branch/release, [#5060](https://github.com/pydantic/pydantic/pull/5060) by [@samuelcolvin](https://github.com/samuelcolvin) + +## v1.10.4 (2022-12-30) + +* Change dependency to `typing-extensions>=4.2.0`, [#4885](https://github.com/pydantic/pydantic/pull/4885) by [@samuelcolvin](https://github.com/samuelcolvin) + +## v1.10.3 (2022-12-29) + +**NOTE: v1.10.3 was ["yanked"](https://pypi.org/help/#yanked) from PyPI due to [#4885](https://github.com/pydantic/pydantic/pull/4885) which is fixed in v1.10.4** + +* fix parsing of custom root models, [#4883](https://github.com/pydantic/pydantic/pull/4883) by [@gou177](https://github.com/gou177) +* fix: use dataclass proxy for frozen or empty dataclasses, [#4878](https://github.com/pydantic/pydantic/pull/4878) by [@PrettyWood](https://github.com/PrettyWood) +* Fix `schema` and `schema_json` on models where a model instance is a one of default values, [#4781](https://github.com/pydantic/pydantic/pull/4781) by [@Bobronium](https://github.com/Bobronium) +* Add Jina AI to sponsors on docs index page, [#4767](https://github.com/pydantic/pydantic/pull/4767) by [@samuelcolvin](https://github.com/samuelcolvin) +* fix: support assignment on `DataclassProxy`, [#4695](https://github.com/pydantic/pydantic/pull/4695) by [@PrettyWood](https://github.com/PrettyWood) +* Add `postgresql+psycopg` as allowed scheme for `PostgreDsn` to make it usable with SQLAlchemy 2, [#4689](https://github.com/pydantic/pydantic/pull/4689) by [@morian](https://github.com/morian) +* Allow dict schemas to have both `patternProperties` and `additionalProperties`, [#4641](https://github.com/pydantic/pydantic/pull/4641) by [@jparise](https://github.com/jparise) +* Fixes error passing None for optional lists with `unique_items`, [#4568](https://github.com/pydantic/pydantic/pull/4568) by 
[@mfulgo](https://github.com/mfulgo) +* Fix `GenericModel` with `Callable` param raising a `TypeError`, [#4551](https://github.com/pydantic/pydantic/pull/4551) by [@mfulgo](https://github.com/mfulgo) +* Fix field regex with `StrictStr` type annotation, [#4538](https://github.com/pydantic/pydantic/pull/4538) by [@sisp](https://github.com/sisp) +* Correct `dataclass_transform` keyword argument name from `field_descriptors` to `field_specifiers`, [#4500](https://github.com/pydantic/pydantic/pull/4500) by [@samuelcolvin](https://github.com/samuelcolvin) +* fix: avoid multiple calls of `__post_init__` when dataclasses are inherited, [#4487](https://github.com/pydantic/pydantic/pull/4487) by [@PrettyWood](https://github.com/PrettyWood) +* Reduce the size of binary wheels, [#2276](https://github.com/pydantic/pydantic/pull/2276) by [@samuelcolvin](https://github.com/samuelcolvin) + +## v1.10.2 (2022-09-05) + +* **Revert Change:** Revert percent encoding of URL parts which was originally added in [#4224](https://github.com/pydantic/pydantic/pull/4224), [#4470](https://github.com/pydantic/pydantic/pull/4470) by [@samuelcolvin](https://github.com/samuelcolvin) +* Prevent long (length > `4_300`) strings/bytes as input to int fields, see + [python/cpython#95778](https://github.com/python/cpython/issues/95778) and + [CVE-2020-10735](https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2020-10735), [#1477](https://github.com/pydantic/pydantic/pull/1477) by [@samuelcolvin](https://github.com/samuelcolvin) +* fix: dataclass wrapper was not always called, [#4477](https://github.com/pydantic/pydantic/pull/4477) by [@PrettyWood](https://github.com/PrettyWood) +* Use `tomllib` on Python 3.11 when parsing `mypy` configuration, [#4476](https://github.com/pydantic/pydantic/pull/4476) by [@hauntsaninja](https://github.com/hauntsaninja) +* Basic fix of `GenericModel` cache to detect order of arguments in `Union` models, [#4474](https://github.com/pydantic/pydantic/pull/4474) by [@sveinugu](https://github.com/sveinugu) +* Fix mypy plugin when using bare types like `list` and `dict` as `default_factory`, [#4457](https://github.com/pydantic/pydantic/pull/4457) by [@samuelcolvin](https://github.com/samuelcolvin) + +## v1.10.1 (2022-08-31) + +* Add `__hash__` method to `pydantic.color.Color` class, [#4454](https://github.com/pydantic/pydantic/pull/4454) by [@czaki](https://github.com/czaki) + +## v1.10.0 (2022-08-30) + +* Refactor the whole _pydantic_ `dataclass` decorator to really act like its standard lib equivalent. + It hence keeps `__eq__`, `__hash__`, ... and makes comparison with its non-validated version possible. + It also fixes usage of `frozen` dataclasses in fields and usage of `default_factory` in nested dataclasses. + The support of `Config.extra` has been added. + Finally, config customization directly via a `dict` is now possible, [#2557](https://github.com/pydantic/pydantic/pull/2557) by [@PrettyWood](https://github.com/PrettyWood) +

+ **BREAKING CHANGES:** + - The `compiled` boolean (whether _pydantic_ is compiled with cython) has been moved from `main.py` to `version.py` + - Now that `Config.extra` is supported, `dataclass` ignores by default extra arguments (like `BaseModel`) +* Fix PEP487 `__set_name__` protocol in `BaseModel` for PrivateAttrs, [#4407](https://github.com/pydantic/pydantic/pull/4407) by [@tlambert03](https://github.com/tlambert03) +* Allow for custom parsing of environment variables via `parse_env_var` in `Config`, [#4406](https://github.com/pydantic/pydantic/pull/4406) by [@acmiyaguchi](https://github.com/acmiyaguchi) +* Rename `master` to `main`, [#4405](https://github.com/pydantic/pydantic/pull/4405) by [@hramezani](https://github.com/hramezani) +* Fix `StrictStr` does not raise `ValidationError` when `max_length` is present in `Field`, [#4388](https://github.com/pydantic/pydantic/pull/4388) by [@hramezani](https://github.com/hramezani) +* Make `SecretStr` and `SecretBytes` hashable, [#4387](https://github.com/pydantic/pydantic/pull/4387) by [@chbndrhnns](https://github.com/chbndrhnns) +* Fix `StrictBytes` does not raise `ValidationError` when `max_length` is present in `Field`, [#4380](https://github.com/pydantic/pydantic/pull/4380) by [@JeanArhancet](https://github.com/JeanArhancet) +* Add support for bare `type`, [#4375](https://github.com/pydantic/pydantic/pull/4375) by [@hramezani](https://github.com/hramezani) +* Support Python 3.11, including binaries for 3.11 in PyPI, [#4374](https://github.com/pydantic/pydantic/pull/4374) by [@samuelcolvin](https://github.com/samuelcolvin) +* Add support for `re.Pattern`, [#4366](https://github.com/pydantic/pydantic/pull/4366) by [@hramezani](https://github.com/hramezani) +* Fix `__post_init_post_parse__` is incorrectly passed keyword arguments when no `__post_init__` is defined, [#4361](https://github.com/pydantic/pydantic/pull/4361) by [@hramezani](https://github.com/hramezani) +* Fix implicitly importing `ForwardRef` and `Callable` from `pydantic.typing` instead of `typing` and also expose `MappingIntStrAny`, [#4358](https://github.com/pydantic/pydantic/pull/4358) by [@aminalaee](https://github.com/aminalaee) +* remove `Any` types from the `dataclass` decorator so it can be used with the `disallow_any_expr` mypy option, [#4356](https://github.com/pydantic/pydantic/pull/4356) by [@DetachHead](https://github.com/DetachHead) +* moved repo to `pydantic/pydantic`, [#4348](https://github.com/pydantic/pydantic/pull/4348) by [@yezz123](https://github.com/yezz123) +* fix "extra fields not permitted" error when dataclass with `Extra.forbid` is validated multiple times, [#4343](https://github.com/pydantic/pydantic/pull/4343) by [@detachhead](https://github.com/detachhead) +* Add Python 3.9 and 3.10 examples to docs, [#4339](https://github.com/pydantic/pydantic/pull/4339) by [@Bobronium](https://github.com/Bobronium) +* Discriminated union models now use `oneOf` instead of `anyOf` when generating OpenAPI schema definitions, [#4335](https://github.com/pydantic/pydantic/pull/4335) by [@MaxwellPayne](https://github.com/MaxwellPayne) +* Allow type checkers to infer inner type of `Json` type. `Json[list[str]]` will be now inferred as `list[str]`, + `Json[Any]` should be used instead of plain `Json`. 
+ Runtime behaviour is not changed, [#4332](https://github.com/pydantic/pydantic/pull/4332) by [@Bobronium](https://github.com/Bobronium) +* Allow empty string aliases by using a `alias is not None` check, rather than `bool(alias)`, [#4253](https://github.com/pydantic/pydantic/pull/4253) by [@sergeytsaplin](https://github.com/sergeytsaplin) +* Update `ForwardRef`s in `Field.outer_type_`, [#4249](https://github.com/pydantic/pydantic/pull/4249) by [@JacobHayes](https://github.com/JacobHayes) +* The use of `__dataclass_transform__` has been replaced by `typing_extensions.dataclass_transform`, which is the preferred way to mark pydantic models as a dataclass under [PEP 681](https://peps.python.org/pep-0681/), [#4241](https://github.com/pydantic/pydantic/pull/4241) by [@multimeric](https://github.com/multimeric) +* Use parent model's `Config` when validating nested `NamedTuple` fields, [#4219](https://github.com/pydantic/pydantic/pull/4219) by [@synek](https://github.com/synek) +* Update `BaseModel.construct` to work with aliased Fields, [#4192](https://github.com/pydantic/pydantic/pull/4192) by [@kylebamos](https://github.com/kylebamos) +* Catch certain raised errors in `smart_deepcopy` and revert to `deepcopy` if so, [#4184](https://github.com/pydantic/pydantic/pull/4184) by [@coneybeare](https://github.com/coneybeare) +* Add `Config.anystr_upper` and `to_upper` kwarg to constr and conbytes, [#4165](https://github.com/pydantic/pydantic/pull/4165) by [@satheler](https://github.com/satheler) +* Fix JSON schema for `set` and `frozenset` when they include default values, [#4155](https://github.com/pydantic/pydantic/pull/4155) by [@aminalaee](https://github.com/aminalaee) +* Teach the mypy plugin that methods decorated by `@validator` are classmethods, [#4102](https://github.com/pydantic/pydantic/pull/4102) by [@DMRobertson](https://github.com/DMRobertson) +* Improve mypy plugin's ability to detect required fields, [#4086](https://github.com/pydantic/pydantic/pull/4086) by [@richardxia](https://github.com/richardxia) +* Support fields of type `Type[]` in schema, [#4051](https://github.com/pydantic/pydantic/pull/4051) by [@aminalaee](https://github.com/aminalaee) +* Add `default` value in JSON Schema when `const=True`, [#4031](https://github.com/pydantic/pydantic/pull/4031) by [@aminalaee](https://github.com/aminalaee) +* Adds reserved word check to signature generation logic, [#4011](https://github.com/pydantic/pydantic/pull/4011) by [@strue36](https://github.com/strue36) +* Fix Json strategy failure for the complex nested field, [#4005](https://github.com/pydantic/pydantic/pull/4005) by [@sergiosim](https://github.com/sergiosim) +* Add JSON-compatible float constraint `allow_inf_nan`, [#3994](https://github.com/pydantic/pydantic/pull/3994) by [@tiangolo](https://github.com/tiangolo) +* Remove undefined behaviour when `env_prefix` had characters in common with `env_nested_delimiter`, [#3975](https://github.com/pydantic/pydantic/pull/3975) by [@arsenron](https://github.com/arsenron) +* Support generics model with `create_model`, [#3945](https://github.com/pydantic/pydantic/pull/3945) by [@hot123s](https://github.com/hot123s) +* allow submodels to overwrite extra field info, [#3934](https://github.com/pydantic/pydantic/pull/3934) by [@PrettyWood](https://github.com/PrettyWood) +* Document and test structural pattern matching ([PEP 636](https://peps.python.org/pep-0636/)) on `BaseModel`, [#3920](https://github.com/pydantic/pydantic/pull/3920) by [@irgolic](https://github.com/irgolic) +* Fix incorrect 
deserialization of python timedelta object to ISO 8601 for negative time deltas. + Minus was serialized in incorrect place ("P-1DT23H59M59.888735S" instead of correct "-P1DT23H59M59.888735S"), [#3899](https://github.com/pydantic/pydantic/pull/3899) by [@07pepa](https://github.com/07pepa) +* Fix validation of discriminated union fields with an alias when passing a model instance, [#3846](https://github.com/pydantic/pydantic/pull/3846) by [@chornsby](https://github.com/chornsby) +* Add a CockroachDsn type to validate CockroachDB connection strings. The type + supports the following schemes: `cockroachdb`, `cockroachdb+psycopg2` and `cockroachdb+asyncpg`, [#3839](https://github.com/pydantic/pydantic/pull/3839) by [@blubber](https://github.com/blubber) +* Fix MyPy plugin to not override pre-existing `__init__` method in models, [#3824](https://github.com/pydantic/pydantic/pull/3824) by [@patrick91](https://github.com/patrick91) +* Fix mypy version checking, [#3783](https://github.com/pydantic/pydantic/pull/3783) by [@KotlinIsland](https://github.com/KotlinIsland) +* support overwriting dunder attributes of `BaseModel` instances, [#3777](https://github.com/pydantic/pydantic/pull/3777) by [@PrettyWood](https://github.com/PrettyWood) +* Added `ConstrainedDate` and `condate`, [#3740](https://github.com/pydantic/pydantic/pull/3740) by [@hottwaj](https://github.com/hottwaj) +* Support `kw_only` in dataclasses, [#3670](https://github.com/pydantic/pydantic/pull/3670) by [@detachhead](https://github.com/detachhead) +* Add comparison method for `Color` class, [#3646](https://github.com/pydantic/pydantic/pull/3646) by [@aminalaee](https://github.com/aminalaee) +* Drop support for python3.6, associated cleanup, [#3605](https://github.com/pydantic/pydantic/pull/3605) by [@samuelcolvin](https://github.com/samuelcolvin) +* created new function `to_lower_camel()` for "non pascal case" camel case, [#3463](https://github.com/pydantic/pydantic/pull/3463) by [@schlerp](https://github.com/schlerp) +* Add checks to `default` and `default_factory` arguments in Mypy plugin, [#3430](https://github.com/pydantic/pydantic/pull/3430) by [@klaa97](https://github.com/klaa97) +* fix mangling of `inspect.signature` for `BaseModel`, [#3413](https://github.com/pydantic/pydantic/pull/3413) by [@fix-inspect-signature](https://github.com/fix-inspect-signature) +* Adds the `SecretField` abstract class so that all the current and future secret fields like `SecretStr` and `SecretBytes` will derive from it, [#3409](https://github.com/pydantic/pydantic/pull/3409) by [@expobrain](https://github.com/expobrain) +* Support multi hosts validation in `PostgresDsn`, [#3337](https://github.com/pydantic/pydantic/pull/3337) by [@rglsk](https://github.com/rglsk) +* Fix parsing of very small numeric timedelta values, [#3315](https://github.com/pydantic/pydantic/pull/3315) by [@samuelcolvin](https://github.com/samuelcolvin) +* Update `SecretsSettingsSource` to respect `config.case_sensitive`, [#3273](https://github.com/pydantic/pydantic/pull/3273) by [@JeanArhancet](https://github.com/JeanArhancet) +* Add MongoDB network data source name (DSN) schema, [#3229](https://github.com/pydantic/pydantic/pull/3229) by [@snosratiershad](https://github.com/snosratiershad) +* Add support for multiple dotenv files, [#3222](https://github.com/pydantic/pydantic/pull/3222) by [@rekyungmin](https://github.com/rekyungmin) +* Raise an explicit `ConfigError` when multiple fields are incorrectly set for a single validator, 
[#3215](https://github.com/pydantic/pydantic/pull/3215) by [@SunsetOrange](https://github.com/SunsetOrange) +* Allow ellipsis on `Field`s inside `Annotated` for `TypedDicts` required, [#3133](https://github.com/pydantic/pydantic/pull/3133) by [@ezegomez](https://github.com/ezegomez) +* Catch overflow errors in `int_validator`, [#3112](https://github.com/pydantic/pydantic/pull/3112) by [@ojii](https://github.com/ojii) +* Adds a `__rich_repr__` method to `Representation` class which enables pretty printing with [Rich](https://github.com/willmcgugan/rich), [#3099](https://github.com/pydantic/pydantic/pull/3099) by [@willmcgugan](https://github.com/willmcgugan) +* Add percent encoding in `AnyUrl` and descendent types, [#3061](https://github.com/pydantic/pydantic/pull/3061) by [@FaresAhmedb](https://github.com/FaresAhmedb) +* `validate_arguments` decorator now supports `alias`, [#3019](https://github.com/pydantic/pydantic/pull/3019) by [@MAD-py](https://github.com/MAD-py) +* Avoid `__dict__` and `__weakref__` attributes in `AnyUrl` and IP address fields, [#2890](https://github.com/pydantic/pydantic/pull/2890) by [@nuno-andre](https://github.com/nuno-andre) +* Add ability to use `Final` in a field type annotation, [#2766](https://github.com/pydantic/pydantic/pull/2766) by [@uriyyo](https://github.com/uriyyo) +* Update requirement to `typing_extensions>=4.1.0` to guarantee `dataclass_transform` is available, [#4424](https://github.com/pydantic/pydantic/pull/4424) by [@commonism](https://github.com/commonism) +* Add Explosion and AWS to main sponsors, [#4413](https://github.com/pydantic/pydantic/pull/4413) by [@samuelcolvin](https://github.com/samuelcolvin) +* Update documentation for `copy_on_model_validation` to reflect recent changes, [#4369](https://github.com/pydantic/pydantic/pull/4369) by [@samuelcolvin](https://github.com/samuelcolvin) +* Runtime warning if `__slots__` is passed to `create_model`, `__slots__` is then ignored, [#4432](https://github.com/pydantic/pydantic/pull/4432) by [@samuelcolvin](https://github.com/samuelcolvin) +* Add type hints to `BaseSettings.Config` to avoid mypy errors, also correct mypy version compatibility notice in docs, [#4450](https://github.com/pydantic/pydantic/pull/4450) by [@samuelcolvin](https://github.com/samuelcolvin) + +## v1.10.0b1 (2022-08-24) + +Pre-release, see [the GitHub release](https://github.com/pydantic/pydantic/releases/tag/v1.10.0b1) for details. + +## v1.10.0a2 (2022-08-24) + +Pre-release, see [the GitHub release](https://github.com/pydantic/pydantic/releases/tag/v1.10.0a2) for details. + +## v1.10.0a1 (2022-08-22) + +Pre-release, see [the GitHub release](https://github.com/pydantic/pydantic/releases/tag/v1.10.0a1) for details. + +## v1.9.2 (2022-08-11) + +**Revert Breaking Change**: _v1.9.1_ introduced a breaking change where model fields were +deep copied by default, this release reverts the default behaviour to match _v1.9.0_ and before, +while also allow deep-copy behaviour via `copy_on_model_validation = 'deep'`. See [#4092](https://github.com/pydantic/pydantic/pull/4092) for more information. 
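+
+A minimal sketch of what this looks like on a v1 model (the `Inner` and `Outer` models below are purely illustrative, not part of the release itself):
+
+```py
+from pydantic import BaseModel
+
+
+class Inner(BaseModel):
+    value: int = 0
+
+
+class Outer(BaseModel):
+    inner: Inner
+
+    class Config:
+        # 'shallow' is the new default; 'deep' restores the v1.9.1 behaviour,
+        # and 'none' skips copying validated model fields entirely
+        copy_on_model_validation = 'deep'
+```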
+ +* Allow for shallow copies of model fields, `Config.copy_on_model_validation` is now a str which must be + `'none'`, `'deep'`, or `'shallow'` corresponding to not copying, deep copy & shallow copy; default `'shallow'`, + [#4093](https://github.com/pydantic/pydantic/pull/4093) by [@timkpaine](https://github.com/timkpaine) + +## v1.9.1 (2022-05-19) + +Thank you to pydantic's sponsors: +[@tiangolo](https://github.com/tiangolo), [@stellargraph](https://github.com/stellargraph), [@JonasKs](https://github.com/JonasKs), [@grillazz](https://github.com/grillazz), [@Mazyod](https://github.com/Mazyod), [@kevinalh](https://github.com/kevinalh), [@chdsbd](https://github.com/chdsbd), [@povilasb](https://github.com/povilasb), [@povilasb](https://github.com/povilasb), [@jina-ai](https://github.com/jina-ai), +[@mainframeindustries](https://github.com/mainframeindustries), [@robusta-dev](https://github.com/robusta-dev), [@SendCloud](https://github.com/SendCloud), [@rszamszur](https://github.com/rszamszur), [@jodal](https://github.com/jodal), [@hardbyte](https://github.com/hardbyte), [@corleyma](https://github.com/corleyma), [@daddycocoaman](https://github.com/daddycocoaman), +[@Rehket](https://github.com/Rehket), [@jokull](https://github.com/jokull), [@reillysiemens](https://github.com/reillysiemens), [@westonsteimel](https://github.com/westonsteimel), [@primer-io](https://github.com/primer-io), [@koxudaxi](https://github.com/koxudaxi), [@browniebroke](https://github.com/browniebroke), [@stradivari96](https://github.com/stradivari96), +[@adriangb](https://github.com/adriangb), [@kamalgill](https://github.com/kamalgill), [@jqueguiner](https://github.com/jqueguiner), [@dev-zero](https://github.com/dev-zero), [@datarootsio](https://github.com/datarootsio), [@RedCarpetUp](https://github.com/RedCarpetUp) +for their kind support. 
+ +* Limit the size of `generics._generic_types_cache` and `generics._assigned_parameters` + to avoid unlimited increase in memory usage, [#4083](https://github.com/pydantic/pydantic/pull/4083) by [@samuelcolvin](https://github.com/samuelcolvin) +* Add Jupyverse and FPS as Jupyter projects using pydantic, [#4082](https://github.com/pydantic/pydantic/pull/4082) by [@davidbrochart](https://github.com/davidbrochart) +* Speedup `__isinstancecheck__` on pydantic models when the type is not a model, may also avoid memory "leaks", [#4081](https://github.com/pydantic/pydantic/pull/4081) by [@samuelcolvin](https://github.com/samuelcolvin) +* Fix in-place modification of `FieldInfo` that caused problems with PEP 593 type aliases, [#4067](https://github.com/pydantic/pydantic/pull/4067) by [@adriangb](https://github.com/adriangb) +* Add support for autocomplete in VS Code via `__dataclass_transform__` when using `pydantic.dataclasses.dataclass`, [#4006](https://github.com/pydantic/pydantic/pull/4006) by [@giuliano-oliveira](https://github.com/giuliano-oliveira) +* Remove benchmarks from codebase and docs, [#3973](https://github.com/pydantic/pydantic/pull/3973) by [@samuelcolvin](https://github.com/samuelcolvin) +* Typing checking with pyright in CI, improve docs on vscode/pylance/pyright, [#3972](https://github.com/pydantic/pydantic/pull/3972) by [@samuelcolvin](https://github.com/samuelcolvin) +* Fix nested Python dataclass schema regression, [#3819](https://github.com/pydantic/pydantic/pull/3819) by [@himbeles](https://github.com/himbeles) +* Update documentation about lazy evaluation of sources for Settings, [#3806](https://github.com/pydantic/pydantic/pull/3806) by [@garyd203](https://github.com/garyd203) +* Prevent subclasses of bytes being converted to bytes, [#3706](https://github.com/pydantic/pydantic/pull/3706) by [@samuelcolvin](https://github.com/samuelcolvin) +* Fixed "error checking inheritance of" when using PEP585 and PEP604 type hints, [#3681](https://github.com/pydantic/pydantic/pull/3681) by [@aleksul](https://github.com/aleksul) +* Allow self referencing `ClassVar`s in models, [#3679](https://github.com/pydantic/pydantic/pull/3679) by [@samuelcolvin](https://github.com/samuelcolvin) +* **Breaking Change, see [#4106](https://github.com/pydantic/pydantic/pull/4106)**: Fix issue with self-referencing dataclass, [#3675](https://github.com/pydantic/pydantic/pull/3675) by [@uriyyo](https://github.com/uriyyo) +* Include non-standard port numbers in rendered URLs, [#3652](https://github.com/pydantic/pydantic/pull/3652) by [@dolfinus](https://github.com/dolfinus) +* `Config.copy_on_model_validation` does a deep copy and not a shallow one, [#3641](https://github.com/pydantic/pydantic/pull/3641) by [@PrettyWood](https://github.com/PrettyWood) +* fix: clarify that discriminated unions do not support singletons, [#3636](https://github.com/pydantic/pydantic/pull/3636) by [@tommilligan](https://github.com/tommilligan) +* Add `read_text(encoding='utf-8')` for `setup.py`, [#3625](https://github.com/pydantic/pydantic/pull/3625) by [@hswong3i](https://github.com/hswong3i) +* Fix JSON Schema generation for Discriminated Unions within lists, [#3608](https://github.com/pydantic/pydantic/pull/3608) by [@samuelcolvin](https://github.com/samuelcolvin) + +## v1.9.0 (2021-12-31) + +Thank you to pydantic's sponsors: +[@sthagen](https://github.com/sthagen), [@timdrijvers](https://github.com/timdrijvers), [@toinbis](https://github.com/toinbis), [@koxudaxi](https://github.com/koxudaxi), 
[@ginomempin](https://github.com/ginomempin), [@primer-io](https://github.com/primer-io), [@and-semakin](https://github.com/and-semakin), [@westonsteimel](https://github.com/westonsteimel), [@reillysiemens](https://github.com/reillysiemens), +[@es3n1n](https://github.com/es3n1n), [@jokull](https://github.com/jokull), [@JonasKs](https://github.com/JonasKs), [@Rehket](https://github.com/Rehket), [@corleyma](https://github.com/corleyma), [@daddycocoaman](https://github.com/daddycocoaman), [@hardbyte](https://github.com/hardbyte), [@datarootsio](https://github.com/datarootsio), [@jodal](https://github.com/jodal), [@aminalaee](https://github.com/aminalaee), [@rafsaf](https://github.com/rafsaf), +[@jqueguiner](https://github.com/jqueguiner), [@chdsbd](https://github.com/chdsbd), [@kevinalh](https://github.com/kevinalh), [@Mazyod](https://github.com/Mazyod), [@grillazz](https://github.com/grillazz), [@JonasKs](https://github.com/JonasKs), [@simw](https://github.com/simw), [@leynier](https://github.com/leynier), [@xfenix](https://github.com/xfenix) +for their kind support. + +### Highlights + +* add Python 3.10 support, [#2885](https://github.com/pydantic/pydantic/pull/2885) by [@PrettyWood](https://github.com/PrettyWood) +* [Discriminated unions](https://docs.pydantic.dev/usage/types/#discriminated-unions-aka-tagged-unions), [#619](https://github.com/pydantic/pydantic/pull/619) by [@PrettyWood](https://github.com/PrettyWood) +* [`Config.smart_union` for better union logic](https://docs.pydantic.dev/usage/model_config/#smart-union), [#2092](https://github.com/pydantic/pydantic/pull/2092) by [@PrettyWood](https://github.com/PrettyWood) +* Binaries for Macos M1 CPUs, [#3498](https://github.com/pydantic/pydantic/pull/3498) by [@samuelcolvin](https://github.com/samuelcolvin) +* Complex types can be set via [nested environment variables](https://docs.pydantic.dev/usage/settings/#parsing-environment-variable-values), e.g. 
`foo___bar`, [#3159](https://github.com/pydantic/pydantic/pull/3159) by [@Air-Mark](https://github.com/Air-Mark) +* add a dark mode to _pydantic_ documentation, [#2913](https://github.com/pydantic/pydantic/pull/2913) by [@gbdlin](https://github.com/gbdlin) +* Add support for autocomplete in VS Code via `__dataclass_transform__`, [#2721](https://github.com/pydantic/pydantic/pull/2721) by [@tiangolo](https://github.com/tiangolo) +* Add "exclude" as a field parameter so that it can be configured using model config, [#660](https://github.com/pydantic/pydantic/pull/660) by [@daviskirk](https://github.com/daviskirk) + +### v1.9.0 (2021-12-31) Changes + +* Apply `update_forward_refs` to `Config.json_encodes` prevent name clashes in types defined via strings, [#3583](https://github.com/pydantic/pydantic/pull/3583) by [@samuelcolvin](https://github.com/samuelcolvin) +* Extend pydantic's mypy plugin to support mypy versions `0.910`, `0.920`, `0.921` & `0.930`, [#3573](https://github.com/pydantic/pydantic/pull/3573) & [#3594](https://github.com/pydantic/pydantic/pull/3594) by [@PrettyWood](https://github.com/PrettyWood), [@christianbundy](https://github.com/christianbundy), [@samuelcolvin](https://github.com/samuelcolvin) + +### v1.9.0a2 (2021-12-24) Changes + +* support generic models with discriminated union, [#3551](https://github.com/pydantic/pydantic/pull/3551) by [@PrettyWood](https://github.com/PrettyWood) +* keep old behaviour of `json()` by default, [#3542](https://github.com/pydantic/pydantic/pull/3542) by [@PrettyWood](https://github.com/PrettyWood) +* Removed typing-only `__root__` attribute from `BaseModel`, [#3540](https://github.com/pydantic/pydantic/pull/3540) by [@layday](https://github.com/layday) +* Build Python 3.10 wheels, [#3539](https://github.com/pydantic/pydantic/pull/3539) by [@mbachry](https://github.com/mbachry) +* Fix display of `extra` fields with model `__repr__`, [#3234](https://github.com/pydantic/pydantic/pull/3234) by [@cocolman](https://github.com/cocolman) +* models copied via `Config.copy_on_model_validation` always have all fields, [#3201](https://github.com/pydantic/pydantic/pull/3201) by [@PrettyWood](https://github.com/PrettyWood) +* nested ORM from nested dictionaries, [#3182](https://github.com/pydantic/pydantic/pull/3182) by [@PrettyWood](https://github.com/PrettyWood) +* fix link to discriminated union section by [@PrettyWood](https://github.com/PrettyWood) + +### v1.9.0a1 (2021-12-18) Changes + +* Add support for `Decimal`-specific validation configurations in `Field()`, additionally to using `condecimal()`, + to allow better support from editors and tooling, [#3507](https://github.com/pydantic/pydantic/pull/3507) by [@tiangolo](https://github.com/tiangolo) +* Add `arm64` binaries suitable for MacOS with an M1 CPU to PyPI, [#3498](https://github.com/pydantic/pydantic/pull/3498) by [@samuelcolvin](https://github.com/samuelcolvin) +* Fix issue where `None` was considered invalid when using a `Union` type containing `Any` or `object`, [#3444](https://github.com/pydantic/pydantic/pull/3444) by [@tharradine](https://github.com/tharradine) +* When generating field schema, pass optional `field` argument (of type + `pydantic.fields.ModelField`) to `__modify_schema__()` if present, [#3434](https://github.com/pydantic/pydantic/pull/3434) by [@jasujm](https://github.com/jasujm) +* Fix issue when pydantic fail to parse `typing.ClassVar` string type annotation, [#3401](https://github.com/pydantic/pydantic/pull/3401) by [@uriyyo](https://github.com/uriyyo) +* Mention 
Python >= 3.9.2 as an alternative to `typing_extensions.TypedDict`, [#3374](https://github.com/pydantic/pydantic/pull/3374) by [@BvB93](https://github.com/BvB93) +* Changed the validator method name in the [Custom Errors example](https://docs.pydantic.dev/usage/models/#custom-errors) + to more accurately describe what the validator is doing; changed from `name_must_contain_space` to ` value_must_equal_bar`, [#3327](https://github.com/pydantic/pydantic/pull/3327) by [@michaelrios28](https://github.com/michaelrios28) +* Add `AmqpDsn` class, [#3254](https://github.com/pydantic/pydantic/pull/3254) by [@kludex](https://github.com/kludex) +* Always use `Enum` value as default in generated JSON schema, [#3190](https://github.com/pydantic/pydantic/pull/3190) by [@joaommartins](https://github.com/joaommartins) +* Add support for Mypy 0.920, [#3175](https://github.com/pydantic/pydantic/pull/3175) by [@christianbundy](https://github.com/christianbundy) +* `validate_arguments` now supports `extra` customization (used to always be `Extra.forbid`), [#3161](https://github.com/pydantic/pydantic/pull/3161) by [@PrettyWood](https://github.com/PrettyWood) +* Complex types can be set by nested environment variables, [#3159](https://github.com/pydantic/pydantic/pull/3159) by [@Air-Mark](https://github.com/Air-Mark) +* Fix mypy plugin to collect fields based on `pydantic.utils.is_valid_field` so that it ignores untyped private variables, [#3146](https://github.com/pydantic/pydantic/pull/3146) by [@hi-ogawa](https://github.com/hi-ogawa) +* fix `validate_arguments` issue with `Config.validate_all`, [#3135](https://github.com/pydantic/pydantic/pull/3135) by [@PrettyWood](https://github.com/PrettyWood) +* avoid dict coercion when using dict subclasses as field type, [#3122](https://github.com/pydantic/pydantic/pull/3122) by [@PrettyWood](https://github.com/PrettyWood) +* add support for `object` type, [#3062](https://github.com/pydantic/pydantic/pull/3062) by [@PrettyWood](https://github.com/PrettyWood) +* Updates pydantic dataclasses to keep `_special` properties on parent classes, [#3043](https://github.com/pydantic/pydantic/pull/3043) by [@zulrang](https://github.com/zulrang) +* Add a `TypedDict` class for error objects, [#3038](https://github.com/pydantic/pydantic/pull/3038) by [@matthewhughes934](https://github.com/matthewhughes934) +* Fix support for using a subclass of an annotation as a default, [#3018](https://github.com/pydantic/pydantic/pull/3018) by [@JacobHayes](https://github.com/JacobHayes) +* make `create_model_from_typeddict` mypy compliant, [#3008](https://github.com/pydantic/pydantic/pull/3008) by [@PrettyWood](https://github.com/PrettyWood) +* Make multiple inheritance work when using `PrivateAttr`, [#2989](https://github.com/pydantic/pydantic/pull/2989) by [@hmvp](https://github.com/hmvp) +* Parse environment variables as JSON, if they have a `Union` type with a complex subfield, [#2936](https://github.com/pydantic/pydantic/pull/2936) by [@cbartz](https://github.com/cbartz) +* Prevent `StrictStr` permitting `Enum` values where the enum inherits from `str`, [#2929](https://github.com/pydantic/pydantic/pull/2929) by [@samuelcolvin](https://github.com/samuelcolvin) +* Make `SecretsSettingsSource` parse values being assigned to fields of complex types when sourced from a secrets file, + just as when sourced from environment variables, [#2917](https://github.com/pydantic/pydantic/pull/2917) by [@davidmreed](https://github.com/davidmreed) +* add a dark mode to _pydantic_ documentation, 
[#2913](https://github.com/pydantic/pydantic/pull/2913) by [@gbdlin](https://github.com/gbdlin) +* Make `pydantic-mypy` plugin compatible with `pyproject.toml` configuration, consistent with `mypy` changes. + See the [doc](https://docs.pydantic.dev/mypy_plugin/#configuring-the-plugin) for more information, [#2908](https://github.com/pydantic/pydantic/pull/2908) by [@jrwalk](https://github.com/jrwalk) +* add Python 3.10 support, [#2885](https://github.com/pydantic/pydantic/pull/2885) by [@PrettyWood](https://github.com/PrettyWood) +* Correctly parse generic models with `Json[T]`, [#2860](https://github.com/pydantic/pydantic/pull/2860) by [@geekingfrog](https://github.com/geekingfrog) +* Update contrib docs re: Python version to use for building docs, [#2856](https://github.com/pydantic/pydantic/pull/2856) by [@paxcodes](https://github.com/paxcodes) +* Clarify documentation about _pydantic_'s support for custom validation and strict type checking, + despite _pydantic_ being primarily a parsing library, [#2855](https://github.com/pydantic/pydantic/pull/2855) by [@paxcodes](https://github.com/paxcodes) +* Fix schema generation for `Deque` fields, [#2810](https://github.com/pydantic/pydantic/pull/2810) by [@sergejkozin](https://github.com/sergejkozin) +* fix an edge case when mixing constraints and `Literal`, [#2794](https://github.com/pydantic/pydantic/pull/2794) by [@PrettyWood](https://github.com/PrettyWood) +* Fix postponed annotation resolution for `NamedTuple` and `TypedDict` when they're used directly as the type of fields + within Pydantic models, [#2760](https://github.com/pydantic/pydantic/pull/2760) by [@jameysharp](https://github.com/jameysharp) +* Fix bug when `mypy` plugin fails on `construct` method call for `BaseSettings` derived classes, [#2753](https://github.com/pydantic/pydantic/pull/2753) by [@uriyyo](https://github.com/uriyyo) +* Add function overloading for a `pydantic.create_model` function, [#2748](https://github.com/pydantic/pydantic/pull/2748) by [@uriyyo](https://github.com/uriyyo) +* Fix mypy plugin issue with self field declaration, [#2743](https://github.com/pydantic/pydantic/pull/2743) by [@uriyyo](https://github.com/uriyyo) +* The colon at the end of the line "The fields which were supplied when user was initialised:" suggests that the code following it is related. 
+ Changed it to a period, [#2733](https://github.com/pydantic/pydantic/pull/2733) by [@krisaoe](https://github.com/krisaoe) +* Renamed variable `schema` to `schema_` to avoid shadowing of global variable name, [#2724](https://github.com/pydantic/pydantic/pull/2724) by [@shahriyarr](https://github.com/shahriyarr) +* Add support for autocomplete in VS Code via `__dataclass_transform__`, [#2721](https://github.com/pydantic/pydantic/pull/2721) by [@tiangolo](https://github.com/tiangolo) +* add missing type annotations in `BaseConfig` and handle `max_length = 0`, [#2719](https://github.com/pydantic/pydantic/pull/2719) by [@PrettyWood](https://github.com/PrettyWood) +* Change `orm_mode` checking to allow recursive ORM mode parsing with dicts, [#2718](https://github.com/pydantic/pydantic/pull/2718) by [@nuno-andre](https://github.com/nuno-andre) +* Add episode 313 of the *Talk Python To Me* podcast, where Michael Kennedy and Samuel Colvin discuss Pydantic, to the docs, [#2712](https://github.com/pydantic/pydantic/pull/2712) by [@RatulMaharaj](https://github.com/RatulMaharaj) +* fix JSON schema generation when a field is of type `NamedTuple` and has a default value, [#2707](https://github.com/pydantic/pydantic/pull/2707) by [@PrettyWood](https://github.com/PrettyWood) +* `Enum` fields now properly support extra kwargs in schema generation, [#2697](https://github.com/pydantic/pydantic/pull/2697) by [@sammchardy](https://github.com/sammchardy) +* **Breaking Change, see [#3780](https://github.com/pydantic/pydantic/pull/3780)**: Make serialization of referenced pydantic models possible, [#2650](https://github.com/pydantic/pydantic/pull/2650) by [@PrettyWood](https://github.com/PrettyWood) +* Add `uniqueItems` option to `ConstrainedList`, [#2618](https://github.com/pydantic/pydantic/pull/2618) by [@nuno-andre](https://github.com/nuno-andre) +* Try to evaluate forward refs automatically at model creation, [#2588](https://github.com/pydantic/pydantic/pull/2588) by [@uriyyo](https://github.com/uriyyo) +* Switch docs preview and coverage display to use [smokeshow](https://smokeshow.helpmanual.io/), [#2580](https://github.com/pydantic/pydantic/pull/2580) by [@samuelcolvin](https://github.com/samuelcolvin) +* Add `__version__` attribute to pydantic module, [#2572](https://github.com/pydantic/pydantic/pull/2572) by [@paxcodes](https://github.com/paxcodes) +* Add `postgresql+asyncpg`, `postgresql+pg8000`, `postgresql+psycopg2`, `postgresql+psycopg2cffi`, `postgresql+py-postgresql` + and `postgresql+pygresql` schemes for `PostgresDsn`, [#2567](https://github.com/pydantic/pydantic/pull/2567) by [@postgres-asyncpg](https://github.com/postgres-asyncpg) +* Enable the Hypothesis plugin to generate a constrained decimal when the `decimal_places` argument is specified, [#2524](https://github.com/pydantic/pydantic/pull/2524) by [@cwe5590](https://github.com/cwe5590) +* Allow `collections.abc.Callable` to be used as type in Python 3.9, [#2519](https://github.com/pydantic/pydantic/pull/2519) by [@daviskirk](https://github.com/daviskirk) +* Documentation update how to custom compile pydantic when using pip install, small change in `setup.py` + to allow for custom CFLAGS when compiling, [#2517](https://github.com/pydantic/pydantic/pull/2517) by [@peterroelants](https://github.com/peterroelants) +* remove side effect of `default_factory` to run it only once even if `Config.validate_all` is set, [#2515](https://github.com/pydantic/pydantic/pull/2515) by [@PrettyWood](https://github.com/PrettyWood) +* Add lookahead to ip 
regexes for `AnyUrl` hosts. This allows urls with DNS labels + looking like IPs to validate as they are perfectly valid host names, [#2512](https://github.com/pydantic/pydantic/pull/2512) by [@sbv-csis](https://github.com/sbv-csis) +* Set `minItems` and `maxItems` in generated JSON schema for fixed-length tuples, [#2497](https://github.com/pydantic/pydantic/pull/2497) by [@PrettyWood](https://github.com/PrettyWood) +* Add `strict` argument to `conbytes`, [#2489](https://github.com/pydantic/pydantic/pull/2489) by [@koxudaxi](https://github.com/koxudaxi) +* Support user defined generic field types in generic models, [#2465](https://github.com/pydantic/pydantic/pull/2465) by [@daviskirk](https://github.com/daviskirk) +* Add an example and a short explanation of subclassing `GetterDict` to docs, [#2463](https://github.com/pydantic/pydantic/pull/2463) by [@nuno-andre](https://github.com/nuno-andre) +* add `KafkaDsn` type, `HttpUrl` now has default port 80 for http and 443 for https, [#2447](https://github.com/pydantic/pydantic/pull/2447) by [@MihanixA](https://github.com/MihanixA) +* Add `PastDate` and `FutureDate` types, [#2425](https://github.com/pydantic/pydantic/pull/2425) by [@Kludex](https://github.com/Kludex) +* Support generating schema for `Generic` fields with subtypes, [#2375](https://github.com/pydantic/pydantic/pull/2375) by [@maximberg](https://github.com/maximberg) +* fix(encoder): serialize `NameEmail` to str, [#2341](https://github.com/pydantic/pydantic/pull/2341) by [@alecgerona](https://github.com/alecgerona) +* add `Config.smart_union` to prevent coercion in `Union` if possible, see + [the doc](https://docs.pydantic.dev/usage/model_config/#smart-union) for more information, [#2092](https://github.com/pydantic/pydantic/pull/2092) by [@PrettyWood](https://github.com/PrettyWood) +* Add ability to use `typing.Counter` as a model field type, [#2060](https://github.com/pydantic/pydantic/pull/2060) by [@uriyyo](https://github.com/uriyyo) +* Add parameterised subclasses to `__bases__` when constructing new parameterised classes, so that `A <: B => A[int] <: B[int]`, [#2007](https://github.com/pydantic/pydantic/pull/2007) by [@diabolo-dan](https://github.com/diabolo-dan) +* Create `FileUrl` type that allows URLs that conform to [RFC 8089](https://tools.ietf.org/html/rfc8089#section-2). + Add `host_required` parameter, which is `True` by default (`AnyUrl` and subclasses), `False` in `RedisDsn`, `FileUrl`, [#1983](https://github.com/pydantic/pydantic/pull/1983) by [@vgerak](https://github.com/vgerak) +* add `confrozenset()`, analogous to `conset()` and `conlist()`, [#1897](https://github.com/pydantic/pydantic/pull/1897) by [@PrettyWood](https://github.com/PrettyWood) +* stop calling parent class `root_validator` if overridden, [#1895](https://github.com/pydantic/pydantic/pull/1895) by [@PrettyWood](https://github.com/PrettyWood) +* Add `repr` (defaults to `True`) parameter to `Field`, to hide it from the default representation of the `BaseModel`, [#1831](https://github.com/pydantic/pydantic/pull/1831) by [@fnep](https://github.com/fnep) +* Accept empty query/fragment URL parts, [#1807](https://github.com/pydantic/pydantic/pull/1807) by [@xavier](https://github.com/xavier) + +## v1.8.2 (2021-05-11) + +!!! warning + A security vulnerability, level "moderate" is fixed in v1.8.2. Please upgrade **ASAP**. 
+ See security advisory [CVE-2021-29510](https://github.com/pydantic/pydantic/security/advisories/GHSA-5jqp-qgf6-3pvh) + +* **Security fix:** Fix `date` and `datetime` parsing so passing either `'infinity'` or `float('inf')` + (or their negative values) does not cause an infinite loop, + see security advisory [CVE-2021-29510](https://github.com/pydantic/pydantic/security/advisories/GHSA-5jqp-qgf6-3pvh) +* fix schema generation with Enum by generating a valid name, [#2575](https://github.com/pydantic/pydantic/pull/2575) by [@PrettyWood](https://github.com/PrettyWood) +* fix JSON schema generation with a `Literal` of an enum member, [#2536](https://github.com/pydantic/pydantic/pull/2536) by [@PrettyWood](https://github.com/PrettyWood) +* Fix bug with configurations declarations that are passed as + keyword arguments during class creation, [#2532](https://github.com/pydantic/pydantic/pull/2532) by [@uriyyo](https://github.com/uriyyo) +* Allow passing `json_encoders` in class kwargs, [#2521](https://github.com/pydantic/pydantic/pull/2521) by [@layday](https://github.com/layday) +* support arbitrary types with custom `__eq__`, [#2483](https://github.com/pydantic/pydantic/pull/2483) by [@PrettyWood](https://github.com/PrettyWood) +* support `Annotated` in `validate_arguments` and in generic models with Python 3.9, [#2483](https://github.com/pydantic/pydantic/pull/2483) by [@PrettyWood](https://github.com/PrettyWood) + +## v1.8.1 (2021-03-03) + +Bug fixes for regressions and new features from `v1.8` + +* allow elements of `Config.field` to update elements of a `Field`, [#2461](https://github.com/pydantic/pydantic/pull/2461) by [@samuelcolvin](https://github.com/samuelcolvin) +* fix validation with a `BaseModel` field and a custom root type, [#2449](https://github.com/pydantic/pydantic/pull/2449) by [@PrettyWood](https://github.com/PrettyWood) +* expose `Pattern` encoder to `fastapi`, [#2444](https://github.com/pydantic/pydantic/pull/2444) by [@PrettyWood](https://github.com/PrettyWood) +* enable the Hypothesis plugin to generate a constrained float when the `multiple_of` argument is specified, [#2442](https://github.com/pydantic/pydantic/pull/2442) by [@tobi-lipede-oodle](https://github.com/tobi-lipede-oodle) +* Avoid `RecursionError` when using some types like `Enum` or `Literal` with generic models, [#2436](https://github.com/pydantic/pydantic/pull/2436) by [@PrettyWood](https://github.com/PrettyWood) +* do not overwrite declared `__hash__` in subclasses of a model, [#2422](https://github.com/pydantic/pydantic/pull/2422) by [@PrettyWood](https://github.com/PrettyWood) +* fix `mypy` complaints on `Path` and `UUID` related custom types, [#2418](https://github.com/pydantic/pydantic/pull/2418) by [@PrettyWood](https://github.com/PrettyWood) +* Support properly variable length tuples of compound types, [#2416](https://github.com/pydantic/pydantic/pull/2416) by [@PrettyWood](https://github.com/PrettyWood) + +## v1.8 (2021-02-26) + +Thank you to pydantic's sponsors: +[@jorgecarleitao](https://github.com/jorgecarleitao), [@BCarley](https://github.com/BCarley), [@chdsbd](https://github.com/chdsbd), [@tiangolo](https://github.com/tiangolo), [@matin](https://github.com/matin), [@linusg](https://github.com/linusg), [@kevinalh](https://github.com/kevinalh), [@koxudaxi](https://github.com/koxudaxi), [@timdrijvers](https://github.com/timdrijvers), [@mkeen](https://github.com/mkeen), [@meadsteve](https://github.com/meadsteve), +[@ginomempin](https://github.com/ginomempin), 
[@primer-io](https://github.com/primer-io), [@and-semakin](https://github.com/and-semakin), [@tomthorogood](https://github.com/tomthorogood), [@AjitZK](https://github.com/AjitZK), [@westonsteimel](https://github.com/westonsteimel), [@Mazyod](https://github.com/Mazyod), [@christippett](https://github.com/christippett), [@CarlosDomingues](https://github.com/CarlosDomingues), +[@Kludex](https://github.com/Kludex), [@r-m-n](https://github.com/r-m-n) +for their kind support. + +### Highlights + +* [Hypothesis plugin](https://docs.pydantic.dev/hypothesis_plugin/) for testing, [#2097](https://github.com/pydantic/pydantic/pull/2097) by [@Zac-HD](https://github.com/Zac-HD) +* support for [`NamedTuple` and `TypedDict`](https://docs.pydantic.dev/usage/types/#annotated-types), [#2216](https://github.com/pydantic/pydantic/pull/2216) by [@PrettyWood](https://github.com/PrettyWood) +* Support [`Annotated` hints on model fields](https://docs.pydantic.dev/usage/schema/#typingannotated-fields), [#2147](https://github.com/pydantic/pydantic/pull/2147) by [@JacobHayes](https://github.com/JacobHayes) +* [`frozen` parameter on `Config`](https://docs.pydantic.dev/usage/model_config/) to allow models to be hashed, [#1880](https://github.com/pydantic/pydantic/pull/1880) by [@rhuille](https://github.com/rhuille) + +### Changes + +* **Breaking Change**, remove old deprecation aliases from v1, [#2415](https://github.com/pydantic/pydantic/pull/2415) by [@samuelcolvin](https://github.com/samuelcolvin): + * remove notes on migrating to v1 in docs + * remove `Schema` which was replaced by `Field` + * remove `Config.case_insensitive` which was replaced by `Config.case_sensitive` (default `False`) + * remove `Config.allow_population_by_alias` which was replaced by `Config.allow_population_by_field_name` + * remove `model.fields` which was replaced by `model.__fields__` + * remove `model.to_string()` which was replaced by `str(model)` + * remove `model.__values__` which was replaced by `model.__dict__` +* **Breaking Change:** always validate only first sublevel items with `each_item`. 
+ There were indeed some edge cases with some compound types where the validated items were the last sublevel ones, [#1933](https://github.com/pydantic/pydantic/pull/1933) by [@PrettyWood](https://github.com/PrettyWood) +* Update docs extensions to fix local syntax highlighting, [#2400](https://github.com/pydantic/pydantic/pull/2400) by [@daviskirk](https://github.com/daviskirk) +* fix: allow `utils.lenient_issubclass` to handle `typing.GenericAlias` objects like `list[str]` in Python >= 3.9, [#2399](https://github.com/pydantic/pydantic/pull/2399) by [@daviskirk](https://github.com/daviskirk) +* Improve field declaration for _pydantic_ `dataclass` by allowing the usage of _pydantic_ `Field` or `'metadata'` kwarg of `dataclasses.field`, [#2384](https://github.com/pydantic/pydantic/pull/2384) by [@PrettyWood](https://github.com/PrettyWood) +* Making `typing-extensions` a required dependency, [#2368](https://github.com/pydantic/pydantic/pull/2368) by [@samuelcolvin](https://github.com/samuelcolvin) +* Make `resolve_annotations` more lenient, allowing for missing modules, [#2363](https://github.com/pydantic/pydantic/pull/2363) by [@samuelcolvin](https://github.com/samuelcolvin) +* Allow configuring models through class kwargs, [#2356](https://github.com/pydantic/pydantic/pull/2356) by [@Bobronium](https://github.com/Bobronium) +* Prevent `Mapping` subclasses from always being coerced to `dict`, [#2325](https://github.com/pydantic/pydantic/pull/2325) by [@ofek](https://github.com/ofek) +* fix: allow `None` for type `Optional[conset / conlist]`, [#2320](https://github.com/pydantic/pydantic/pull/2320) by [@PrettyWood](https://github.com/PrettyWood) +* Support empty tuple type, [#2318](https://github.com/pydantic/pydantic/pull/2318) by [@PrettyWood](https://github.com/PrettyWood) +* fix: `python_requires` metadata to require >=3.6.1, [#2306](https://github.com/pydantic/pydantic/pull/2306) by [@hukkinj1](https://github.com/hukkinj1) +* Properly encode `Decimal` with, or without any decimal places, [#2293](https://github.com/pydantic/pydantic/pull/2293) by [@hultner](https://github.com/hultner) +* fix: update `__fields_set__` in `BaseModel.copy(update=…)`, [#2290](https://github.com/pydantic/pydantic/pull/2290) by [@PrettyWood](https://github.com/PrettyWood) +* fix: keep order of fields with `BaseModel.construct()`, [#2281](https://github.com/pydantic/pydantic/pull/2281) by [@PrettyWood](https://github.com/PrettyWood) +* Support generating schema for Generic fields, [#2262](https://github.com/pydantic/pydantic/pull/2262) by [@maximberg](https://github.com/maximberg) +* Fix `validate_decorator` so `**kwargs` doesn't exclude values when the keyword + has the same name as the `*args` or `**kwargs` names, [#2251](https://github.com/pydantic/pydantic/pull/2251) by [@cybojenix](https://github.com/cybojenix) +* Prevent overriding positional arguments with keyword arguments in + `validate_arguments`, as per behaviour with native functions, [#2249](https://github.com/pydantic/pydantic/pull/2249) by [@cybojenix](https://github.com/cybojenix) +* add documentation for `con*` type functions, [#2242](https://github.com/pydantic/pydantic/pull/2242) by [@tayoogunbiyi](https://github.com/tayoogunbiyi) +* Support custom root type (aka `__root__`) when using `parse_obj()` with nested models, [#2238](https://github.com/pydantic/pydantic/pull/2238) by [@PrettyWood](https://github.com/PrettyWood) +* Support custom root type (aka `__root__`) with `from_orm()`, [#2237](https://github.com/pydantic/pydantic/pull/2237) by 
[@PrettyWood](https://github.com/PrettyWood) +* ensure cythonized functions are left untouched when creating models, based on [#1944](https://github.com/pydantic/pydantic/pull/1944) by [@kollmats](https://github.com/kollmats), [#2228](https://github.com/pydantic/pydantic/pull/2228) by [@samuelcolvin](https://github.com/samuelcolvin) +* Resolve forward refs for stdlib dataclasses converted into _pydantic_ ones, [#2220](https://github.com/pydantic/pydantic/pull/2220) by [@PrettyWood](https://github.com/PrettyWood) +* Add support for `NamedTuple` and `TypedDict` types. + Those two types are now handled and validated when used inside `BaseModel` or _pydantic_ `dataclass`. + Two utils are also added `create_model_from_namedtuple` and `create_model_from_typeddict`, [#2216](https://github.com/pydantic/pydantic/pull/2216) by [@PrettyWood](https://github.com/PrettyWood) +* Do not ignore annotated fields when type is `Union[Type[...], ...]`, [#2213](https://github.com/pydantic/pydantic/pull/2213) by [@PrettyWood](https://github.com/PrettyWood) +* Raise a user-friendly `TypeError` when a `root_validator` does not return a `dict` (e.g. `None`), [#2209](https://github.com/pydantic/pydantic/pull/2209) by [@masalim2](https://github.com/masalim2) +* Add a `FrozenSet[str]` type annotation to the `allowed_schemes` argument on the `strict_url` field type, [#2198](https://github.com/pydantic/pydantic/pull/2198) by [@Midnighter](https://github.com/Midnighter) +* add `allow_mutation` constraint to `Field`, [#2195](https://github.com/pydantic/pydantic/pull/2195) by [@sblack-usu](https://github.com/sblack-usu) +* Allow `Field` with a `default_factory` to be used as an argument to a function + decorated with `validate_arguments`, [#2176](https://github.com/pydantic/pydantic/pull/2176) by [@thomascobb](https://github.com/thomascobb) +* Allow non-existent secrets directory by only issuing a warning, [#2175](https://github.com/pydantic/pydantic/pull/2175) by [@davidolrik](https://github.com/davidolrik) +* fix URL regex to parse fragment without query string, [#2168](https://github.com/pydantic/pydantic/pull/2168) by [@andrewmwhite](https://github.com/andrewmwhite) +* fix: ensure to always return one of the values in `Literal` field type, [#2166](https://github.com/pydantic/pydantic/pull/2166) by [@PrettyWood](https://github.com/PrettyWood) +* Support `typing.Annotated` hints on model fields. 
A `Field` may now be set in the type hint with `Annotated[..., Field(...)]`; all other annotations are ignored but still visible with `get_type_hints(..., include_extras=True)`, [#2147](https://github.com/pydantic/pydantic/pull/2147) by [@JacobHayes](https://github.com/JacobHayes) +* Added `StrictBytes` type as well as `strict=False` option to `ConstrainedBytes`, [#2136](https://github.com/pydantic/pydantic/pull/2136) by [@rlizzo](https://github.com/rlizzo) +* added `Config.anystr_lower` and `to_lower` kwarg to `constr` and `conbytes`, [#2134](https://github.com/pydantic/pydantic/pull/2134) by [@tayoogunbiyi](https://github.com/tayoogunbiyi) +* Support plain `typing.Tuple` type, [#2132](https://github.com/pydantic/pydantic/pull/2132) by [@PrettyWood](https://github.com/PrettyWood) +* Add a bound method `validate` to functions decorated with `validate_arguments` + to validate parameters without actually calling the function, [#2127](https://github.com/pydantic/pydantic/pull/2127) by [@PrettyWood](https://github.com/PrettyWood) +* Add the ability to customize settings sources (add / disable / change priority order), [#2107](https://github.com/pydantic/pydantic/pull/2107) by [@kozlek](https://github.com/kozlek) +* Fix mypy complaints about most custom _pydantic_ types, [#2098](https://github.com/pydantic/pydantic/pull/2098) by [@PrettyWood](https://github.com/PrettyWood) +* Add a [Hypothesis](https://hypothesis.readthedocs.io/) plugin for easier [property-based testing](https://increment.com/testing/in-praise-of-property-based-testing/) with Pydantic's custom types - [usage details here](https://docs.pydantic.dev/hypothesis_plugin/), [#2097](https://github.com/pydantic/pydantic/pull/2097) by [@Zac-HD](https://github.com/Zac-HD) +* add validator for `None`, `NoneType` or `Literal[None]`, [#2095](https://github.com/pydantic/pydantic/pull/2095) by [@PrettyWood](https://github.com/PrettyWood) +* Properly handle fields of type `Callable` with a default value, [#2094](https://github.com/pydantic/pydantic/pull/2094) by [@PrettyWood](https://github.com/PrettyWood) +* Updated `create_model` return type annotation to return type which inherits from `__base__` argument, [#2071](https://github.com/pydantic/pydantic/pull/2071) by [@uriyyo](https://github.com/uriyyo) +* Add merged `json_encoders` inheritance, [#2064](https://github.com/pydantic/pydantic/pull/2064) by [@art049](https://github.com/art049) +* allow overwriting `ClassVar`s in sub-models without having to re-annotate them, [#2061](https://github.com/pydantic/pydantic/pull/2061) by [@layday](https://github.com/layday) +* add default encoder for `Pattern` type, [#2045](https://github.com/pydantic/pydantic/pull/2045) by [@PrettyWood](https://github.com/PrettyWood) +* Add `NonNegativeInt`, `NonPositiveInt`, `NonNegativeFloat`, `NonPositiveFloat`, [#1975](https://github.com/pydantic/pydantic/pull/1975) by [@mdavis-xyz](https://github.com/mdavis-xyz) +* Use % for percentage in string format of colors, [#1960](https://github.com/pydantic/pydantic/pull/1960) by [@EdwardBetts](https://github.com/EdwardBetts) +* Fixed issue causing `KeyError` to be raised when building schema from multiple `BaseModel` with the same names declared in separate classes, [#1912](https://github.com/pydantic/pydantic/pull/1912) by [@JSextonn](https://github.com/JSextonn) +* Add `rediss` (Redis over SSL) protocol to `RedisDsn` + Allow URLs without `user` part (e.g., `rediss://:pass@localhost`), [#1911](https://github.com/pydantic/pydantic/pull/1911) by
[@TrDex](https://github.com/TrDex) +* Add a new `frozen` boolean parameter to `Config` (default: `False`). + Setting `frozen=True` does everything that `allow_mutation=False` does, and also generates a `__hash__()` method for the model. This makes instances of the model potentially hashable if all the attributes are hashable, [#1880](https://github.com/pydantic/pydantic/pull/1880) by [@rhuille](https://github.com/rhuille) +* fix schema generation with multiple Enums having the same name, [#1857](https://github.com/pydantic/pydantic/pull/1857) by [@PrettyWood](https://github.com/PrettyWood) +* Added support for 13/19 digits VISA credit cards in `PaymentCardNumber` type, [#1416](https://github.com/pydantic/pydantic/pull/1416) by [@AlexanderSov](https://github.com/AlexanderSov) +* fix: prevent `RecursionError` while using recursive `GenericModel`s, [#1370](https://github.com/pydantic/pydantic/pull/1370) by [@xppt](https://github.com/xppt) +* use `enum` for `typing.Literal` in JSON schema, [#1350](https://github.com/pydantic/pydantic/pull/1350) by [@PrettyWood](https://github.com/PrettyWood) +* Fix: some recursive models did not require `update_forward_refs` and silently behaved incorrectly, [#1201](https://github.com/pydantic/pydantic/pull/1201) by [@PrettyWood](https://github.com/PrettyWood) +* Fix bug where generic models with fields where the typevar is nested in another type `a: List[T]` are considered to be concrete. This allows these models to be subclassed and composed as expected, [#947](https://github.com/pydantic/pydantic/pull/947) by [@daviskirk](https://github.com/daviskirk) +* Add `Config.copy_on_model_validation` flag. When set to `False`, _pydantic_ will keep models used as fields + untouched on validation instead of reconstructing (copying) them, [#265](https://github.com/pydantic/pydantic/pull/265) by [@PrettyWood](https://github.com/PrettyWood) + +## v1.7.4 (2021-05-11) + +* **Security fix:** Fix `date` and `datetime` parsing so passing either `'infinity'` or `float('inf')` + (or their negative values) does not cause an infinite loop, + See security advisory [CVE-2021-29510](https://github.com/pydantic/pydantic/security/advisories/GHSA-5jqp-qgf6-3pvh) + +## v1.7.3 (2020-11-30) + +Thank you to pydantic's sponsors: +[@timdrijvers](https://github.com/timdrijvers), [@BCarley](https://github.com/BCarley), [@chdsbd](https://github.com/chdsbd), [@tiangolo](https://github.com/tiangolo), [@matin](https://github.com/matin), [@linusg](https://github.com/linusg), [@kevinalh](https://github.com/kevinalh), [@jorgecarleitao](https://github.com/jorgecarleitao), [@koxudaxi](https://github.com/koxudaxi), [@primer-api](https://github.com/primer-api), +[@mkeen](https://github.com/mkeen), [@meadsteve](https://github.com/meadsteve) for their kind support. 
+ +* fix: set right default value for required (optional) fields, [#2142](https://github.com/pydantic/pydantic/pull/2142) by [@PrettyWood](https://github.com/PrettyWood) +* fix: support `underscore_attrs_are_private` with generic models, [#2138](https://github.com/pydantic/pydantic/pull/2138) by [@PrettyWood](https://github.com/PrettyWood) +* fix: update all modified field values in `root_validator` when `validate_assignment` is on, [#2116](https://github.com/pydantic/pydantic/pull/2116) by [@PrettyWood](https://github.com/PrettyWood) +* Allow pickling of `pydantic.dataclasses.dataclass` dynamically created from a built-in `dataclasses.dataclass`, [#2111](https://github.com/pydantic/pydantic/pull/2111) by [@aimestereo](https://github.com/aimestereo) +* Fix a regression where Enum fields would not propagate keyword arguments to the schema, [#2109](https://github.com/pydantic/pydantic/pull/2109) by [@bm424](https://github.com/bm424) +* Ignore `__doc__` as private attribute when `Config.underscore_attrs_are_private` is set, [#2090](https://github.com/pydantic/pydantic/pull/2090) by [@PrettyWood](https://github.com/PrettyWood) + +## v1.7.2 (2020-11-01) + +* fix slow `GenericModel` concrete model creation, allow `GenericModel` concrete name reusing in module, [#2078](https://github.com/pydantic/pydantic/pull/2078) by [@Bobronium](https://github.com/Bobronium) +* keep the order of the fields when `validate_assignment` is set, [#2073](https://github.com/pydantic/pydantic/pull/2073) by [@PrettyWood](https://github.com/PrettyWood) +* forward all the params of the stdlib `dataclass` when converted into _pydantic_ `dataclass`, [#2065](https://github.com/pydantic/pydantic/pull/2065) by [@PrettyWood](https://github.com/PrettyWood) + +## v1.7.1 (2020-10-28) + +Thank you to pydantic's sponsors: +[@timdrijvers](https://github.com/timdrijvers), [@BCarley](https://github.com/BCarley), [@chdsbd](https://github.com/chdsbd), [@tiangolo](https://github.com/tiangolo), [@matin](https://github.com/matin), [@linusg](https://github.com/linusg), [@kevinalh](https://github.com/kevinalh), [@jorgecarleitao](https://github.com/jorgecarleitao), [@koxudaxi](https://github.com/koxudaxi), [@primer-api](https://github.com/primer-api), [@mkeen](https://github.com/mkeen) +for their kind support. 
+ +* fix annotation of `validate_arguments` when passing configuration as argument, [#2055](https://github.com/pydantic/pydantic/pull/2055) by [@layday](https://github.com/layday) +* Fix mypy assignment error when using `PrivateAttr`, [#2048](https://github.com/pydantic/pydantic/pull/2048) by [@aphedges](https://github.com/aphedges) +* fix `underscore_attrs_are_private` causing `TypeError` when overriding `__init__`, [#2047](https://github.com/pydantic/pydantic/pull/2047) by [@samuelcolvin](https://github.com/samuelcolvin) +* Fixed regression introduced in v1.7 involving exception handling in field validators when `validate_assignment=True`, [#2044](https://github.com/pydantic/pydantic/pull/2044) by [@johnsabath](https://github.com/johnsabath) +* fix: _pydantic_ `dataclass` can inherit from stdlib `dataclass` + and `Config.arbitrary_types_allowed` is supported, [#2042](https://github.com/pydantic/pydantic/pull/2042) by [@PrettyWood](https://github.com/PrettyWood) + +## v1.7 (2020-10-26) + +Thank you to pydantic's sponsors: +[@timdrijvers](https://github.com/timdrijvers), [@BCarley](https://github.com/BCarley), [@chdsbd](https://github.com/chdsbd), [@tiangolo](https://github.com/tiangolo), [@matin](https://github.com/matin), [@linusg](https://github.com/linusg), [@kevinalh](https://github.com/kevinalh), [@jorgecarleitao](https://github.com/jorgecarleitao), [@koxudaxi](https://github.com/koxudaxi), [@primer-api](https://github.com/primer-api) +for their kind support. + +### Highlights + +* Python 3.9 support, thanks [@PrettyWood](https://github.com/PrettyWood) +* [Private model attributes](https://docs.pydantic.dev/usage/models/#private-model-attributes), thanks [@Bobronium](https://github.com/Bobronium) +* ["secrets files" support in `BaseSettings`](https://docs.pydantic.dev/usage/settings/#secret-support), thanks [@mdgilene](https://github.com/mdgilene) +* [convert stdlib dataclasses to pydantic dataclasses and use stdlib dataclasses in models](https://docs.pydantic.dev/usage/dataclasses/#stdlib-dataclasses-and-pydantic-dataclasses), thanks [@PrettyWood](https://github.com/PrettyWood) + +### Changes + +* **Breaking Change:** remove `__field_defaults__`, add `default_factory` support with `BaseModel.construct`. 
+ Use `.get_default()` method on fields in `__fields__` attribute instead, [#1732](https://github.com/pydantic/pydantic/pull/1732) by [@PrettyWood](https://github.com/PrettyWood) +* Rearrange CI to run linting as a separate job, split install recipes for different tasks, [#2020](https://github.com/pydantic/pydantic/pull/2020) by [@samuelcolvin](https://github.com/samuelcolvin) +* Allows subclasses of generic models to make some, or all, of the superclass's type parameters concrete, while + also defining new type parameters in the subclass, [#2005](https://github.com/pydantic/pydantic/pull/2005) by [@choogeboom](https://github.com/choogeboom) +* Call validator with the correct `values` parameter type in `BaseModel.__setattr__`, + when `validate_assignment = True` in model config, [#1999](https://github.com/pydantic/pydantic/pull/1999) by [@me-ransh](https://github.com/me-ransh) +* Force `fields.Undefined` to be a singleton object, fixing inherited generic model schemas, [#1981](https://github.com/pydantic/pydantic/pull/1981) by [@daviskirk](https://github.com/daviskirk) +* Include tests in source distributions, [#1976](https://github.com/pydantic/pydantic/pull/1976) by [@sbraz](https://github.com/sbraz) +* Add ability to use `min_length/max_length` constraints with secret types, [#1974](https://github.com/pydantic/pydantic/pull/1974) by [@uriyyo](https://github.com/uriyyo) +* Also check `root_validators` when `validate_assignment` is on, [#1971](https://github.com/pydantic/pydantic/pull/1971) by [@PrettyWood](https://github.com/PrettyWood) +* Fix const validators not running when custom validators are present, [#1957](https://github.com/pydantic/pydantic/pull/1957) by [@hmvp](https://github.com/hmvp) +* add `deque` to field types, [#1935](https://github.com/pydantic/pydantic/pull/1935) by [@wozniakty](https://github.com/wozniakty) +* add basic support for Python 3.9, [#1832](https://github.com/pydantic/pydantic/pull/1832) by [@PrettyWood](https://github.com/PrettyWood) +* Fix typo in the anchor of exporting_models.md#modelcopy and incorrect description, [#1821](https://github.com/pydantic/pydantic/pull/1821) by [@KimMachineGun](https://github.com/KimMachineGun) +* Added ability for `BaseSettings` to read "secret files", [#1820](https://github.com/pydantic/pydantic/pull/1820) by [@mdgilene](https://github.com/mdgilene) +* add `parse_raw_as` utility function, [#1812](https://github.com/pydantic/pydantic/pull/1812) by [@PrettyWood](https://github.com/PrettyWood) +* Support home directory relative paths for `dotenv` files (e.g. 
`~/.env`), [#1803](https://github.com/pydantic/pydantic/pull/1803) by [@PrettyWood](https://github.com/PrettyWood) +* Clarify documentation for `parse_file` to show that the argument + should be a file *path* not a file-like object, [#1794](https://github.com/pydantic/pydantic/pull/1794) by [@mdavis-xyz](https://github.com/mdavis-xyz) +* Fix false positive from mypy plugin when a class nested within a `BaseModel` is named `Model`, [#1770](https://github.com/pydantic/pydantic/pull/1770) by [@selimb](https://github.com/selimb) +* add basic support of Pattern type in schema generation, [#1767](https://github.com/pydantic/pydantic/pull/1767) by [@PrettyWood](https://github.com/PrettyWood) +* Support custom title, description and default in schema of enums, [#1748](https://github.com/pydantic/pydantic/pull/1748) by [@PrettyWood](https://github.com/PrettyWood) +* Properly represent `Literal` Enums when `use_enum_values` is True, [#1747](https://github.com/pydantic/pydantic/pull/1747) by [@noelevans](https://github.com/noelevans) +* Allows timezone information to be added to strings to be formatted as time objects. Permitted formats are `Z` for UTC + or an offset for absolute positive or negative time shifts. Or the timezone data can be omitted, [#1744](https://github.com/pydantic/pydantic/pull/1744) by [@noelevans](https://github.com/noelevans) +* Add stub `__init__` with Python 3.6 signature for `ForwardRef`, [#1738](https://github.com/pydantic/pydantic/pull/1738) by [@sirtelemak](https://github.com/sirtelemak) +* Fix behaviour with forward refs and optional fields in nested models, [#1736](https://github.com/pydantic/pydantic/pull/1736) by [@PrettyWood](https://github.com/PrettyWood) +* add `Enum` and `IntEnum` as valid types for fields, [#1735](https://github.com/pydantic/pydantic/pull/1735) by [@PrettyWood](https://github.com/PrettyWood) +* Change default value of `__module__` argument of `create_model` from `None` to `'pydantic.main'`. + Set reference of created concrete model to its module to allow pickling (not applied to models created in + functions), [#1686](https://github.com/pydantic/pydantic/pull/1686) by [@Bobronium](https://github.com/Bobronium) +* Add private attributes support, [#1679](https://github.com/pydantic/pydantic/pull/1679) by [@Bobronium](https://github.com/Bobronium) +* add `config` to `@validate_arguments`, [#1663](https://github.com/pydantic/pydantic/pull/1663) by [@samuelcolvin](https://github.com/samuelcolvin) +* Allow descendant Settings models to override env variable names for the fields defined in parent Settings models with + `env` in their `Config`. Previously only `env_prefix` configuration option was applicable, [#1561](https://github.com/pydantic/pydantic/pull/1561) by [@ojomio](https://github.com/ojomio) +* Support `ref_template` when creating schema `$ref`s, [#1479](https://github.com/pydantic/pydantic/pull/1479) by [@kilo59](https://github.com/kilo59) +* Add a `__call__` stub to `PyObject` so that mypy will know that it is callable, [#1352](https://github.com/pydantic/pydantic/pull/1352) by [@brianmaissy](https://github.com/brianmaissy) +* `pydantic.dataclasses.dataclass` decorator now supports built-in `dataclasses.dataclass`. + It is hence possible to convert an existing `dataclass` easily to add Pydantic validation.
+ Moreover nested dataclasses are also supported, [#744](https://github.com/pydantic/pydantic/pull/744) by [@PrettyWood](https://github.com/PrettyWood) + +## v1.6.2 (2021-05-11) + +* **Security fix:** Fix `date` and `datetime` parsing so passing either `'infinity'` or `float('inf')` + (or their negative values) does not cause an infinite loop, + See security advisory [CVE-2021-29510](https://github.com/pydantic/pydantic/security/advisories/GHSA-5jqp-qgf6-3pvh) + +## v1.6.1 (2020-07-15) + +* fix validation and parsing of nested models with `default_factory`, [#1710](https://github.com/pydantic/pydantic/pull/1710) by [@PrettyWood](https://github.com/PrettyWood) + +## v1.6 (2020-07-11) + +Thank you to pydantic's sponsors: [@matin](https://github.com/matin), [@tiangolo](https://github.com/tiangolo), [@chdsbd](https://github.com/chdsbd), [@jorgecarleitao](https://github.com/jorgecarleitao), and 1 anonymous sponsor for their kind support. + +* Modify validators for `conlist` and `conset` to not have `always=True`, [#1682](https://github.com/pydantic/pydantic/pull/1682) by [@samuelcolvin](https://github.com/samuelcolvin) +* add port check to `AnyUrl` (can't exceed 65535; ports are 16-bit unsigned integers: `0 <= port <= 2**16-1`, src: [rfc793 header format](https://tools.ietf.org/html/rfc793#section-3.1)), [#1654](https://github.com/pydantic/pydantic/pull/1654) by [@flapili](https://github.com/flapili) +* Document default `regex` anchoring semantics, [#1648](https://github.com/pydantic/pydantic/pull/1648) by [@yurikhan](https://github.com/yurikhan) +* Use `chain.from_iterable` in class_validators.py. This is a faster and more idiomatic way of using `itertools.chain`. + Instead of computing all the items in the iterable and storing them in memory, they are computed one-by-one and never + stored as a huge list.
This can save on both runtime and memory space, [#1642](https://github.com/pydantic/pydantic/pull/1642) by [@cool-RR](https://github.com/cool-RR) +* Add `conset()`, analogous to `conlist()`, [#1623](https://github.com/pydantic/pydantic/pull/1623) by [@patrickkwang](https://github.com/patrickkwang) +* make Pydantic errors (un)picklable, [#1616](https://github.com/pydantic/pydantic/pull/1616) by [@PrettyWood](https://github.com/PrettyWood) +* Allow custom encoding for `dotenv` files, [#1615](https://github.com/pydantic/pydantic/pull/1615) by [@PrettyWood](https://github.com/PrettyWood) +* Ensure `SchemaExtraCallable` is always defined to get type hints on BaseConfig, [#1614](https://github.com/pydantic/pydantic/pull/1614) by [@PrettyWood](https://github.com/PrettyWood) +* Update datetime parser to support negative timestamps, [#1600](https://github.com/pydantic/pydantic/pull/1600) by [@mlbiche](https://github.com/mlbiche) +* Update mypy, remove `AnyType` alias for `Type[Any]`, [#1598](https://github.com/pydantic/pydantic/pull/1598) by [@samuelcolvin](https://github.com/samuelcolvin) +* Adjust handling of root validators so that errors are aggregated from _all_ failing root validators, instead of reporting on only the first root validator to fail, [#1586](https://github.com/pydantic/pydantic/pull/1586) by [@beezee](https://github.com/beezee) +* Make `__modify_schema__` on Enums apply to the enum schema rather than fields that use the enum, [#1581](https://github.com/pydantic/pydantic/pull/1581) by [@therefromhere](https://github.com/therefromhere) +* Fix behavior of `__all__` key when used in conjunction with index keys in advanced include/exclude of fields that are sequences, [#1579](https://github.com/pydantic/pydantic/pull/1579) by [@xspirus](https://github.com/xspirus) +* Subclass validators do not run when referencing a `List` field defined in a parent class when `each_item=True`. Added an example to the docs illustrating this, [#1566](https://github.com/pydantic/pydantic/pull/1566) by [@samueldeklund](https://github.com/samueldeklund) +* change `schema.field_class_to_schema` to support `frozenset` in schema, [#1557](https://github.com/pydantic/pydantic/pull/1557) by [@wangpeibao](https://github.com/wangpeibao) +* Call `__modify_schema__` only for the field schema, [#1552](https://github.com/pydantic/pydantic/pull/1552) by [@PrettyWood](https://github.com/PrettyWood) +* Move the assignment of `field.validate_always` in `fields.py` so the `always` parameter of validators works on inheritance, [#1545](https://github.com/pydantic/pydantic/pull/1545) by [@dcHHH](https://github.com/dcHHH) +* Added support for UUID instantiation through 16 byte strings such as `b'\x12\x34\x56\x78' * 4`.
This was done to support `BINARY(16)` columns in sqlalchemy, [#1541](https://github.com/pydantic/pydantic/pull/1541) by [@shawnwall](https://github.com/shawnwall) +* Add a test assertion that `default_factory` can return a singleton, [#1523](https://github.com/pydantic/pydantic/pull/1523) by [@therefromhere](https://github.com/therefromhere) +* Add `NameEmail.__eq__` so duplicate `NameEmail` instances are evaluated as equal, [#1514](https://github.com/pydantic/pydantic/pull/1514) by [@stephen-bunn](https://github.com/stephen-bunn) +* Add datamodel-code-generator link in pydantic document site, [#1500](https://github.com/pydantic/pydantic/pull/1500) by [@koxudaxi](https://github.com/koxudaxi) +* Added a "Discussion of Pydantic" section to the documentation, with a link to "Pydantic Introduction" video by Alexander Hultnér, [#1499](https://github.com/pydantic/pydantic/pull/1499) by [@hultner](https://github.com/hultner) +* Avoid some side effects of `default_factory` by calling it only once + if possible and by not setting a default value in the schema, [#1491](https://github.com/pydantic/pydantic/pull/1491) by [@PrettyWood](https://github.com/PrettyWood) +* Added docs about dumping dataclasses to JSON, [#1487](https://github.com/pydantic/pydantic/pull/1487) by [@mikegrima](https://github.com/mikegrima) +* Make `BaseModel.__signature__` class-only, so getting `__signature__` from model instance will raise `AttributeError`, [#1466](https://github.com/pydantic/pydantic/pull/1466) by [@Bobronium](https://github.com/Bobronium) +* include `'format': 'password'` in the schema for secret types, [#1424](https://github.com/pydantic/pydantic/pull/1424) by [@atheuz](https://github.com/atheuz) +* Modify schema constraints on `ConstrainedFloat` so that `exclusiveMinimum` and + minimum are not included in the schema if they are equal to `-math.inf` and + `exclusiveMaximum` and `maximum` are not included if they are equal to `math.inf`, [#1417](https://github.com/pydantic/pydantic/pull/1417) by [@vdwees](https://github.com/vdwees) +* Squash internal `__root__` dicts in `.dict()` (and, by extension, in `.json()`), [#1414](https://github.com/pydantic/pydantic/pull/1414) by [@patrickkwang](https://github.com/patrickkwang) +* Move `const` validator to post-validators so it validates the parsed value, [#1410](https://github.com/pydantic/pydantic/pull/1410) by [@selimb](https://github.com/selimb) +* Fix model validation to handle nested literals, e.g. `Literal['foo', Literal['bar']]`, [#1364](https://github.com/pydantic/pydantic/pull/1364) by [@DBCerigo](https://github.com/DBCerigo) +* Remove `user_required = True` from `RedisDsn`, neither user nor password are required, [#1275](https://github.com/pydantic/pydantic/pull/1275) by [@samuelcolvin](https://github.com/samuelcolvin) +* Remove extra `allOf` from schema for fields with `Union` and custom `Field`, [#1209](https://github.com/pydantic/pydantic/pull/1209) by [@mostaphaRoudsari](https://github.com/mostaphaRoudsari) +* Updates OpenAPI schema generation to output all enums as separate models. 
+ Instead of inlining the enum values in the model schema, models now use a `$ref` + property to point to the enum definition, [#1173](https://github.com/pydantic/pydantic/pull/1173) by [@calvinwyoung](https://github.com/calvinwyoung) + +## v1.5.1 (2020-04-23) + +* Signature generation with `extra: allow` never uses a field name, [#1418](https://github.com/pydantic/pydantic/pull/1418) by [@prettywood](https://github.com/prettywood) +* Avoid mutating `Field` default value, [#1412](https://github.com/pydantic/pydantic/pull/1412) by [@prettywood](https://github.com/prettywood) + +## v1.5 (2020-04-18) + +* Make includes/excludes arguments for `.dict()`, `._iter()`, ..., immutable, [#1404](https://github.com/pydantic/pydantic/pull/1404) by [@AlexECX](https://github.com/AlexECX) +* Always use a field's real name with includes/excludes in `model._iter()`, regardless of `by_alias`, [#1397](https://github.com/pydantic/pydantic/pull/1397) by [@AlexECX](https://github.com/AlexECX) +* Update constr regex example to include start and end lines, [#1396](https://github.com/pydantic/pydantic/pull/1396) by [@lmcnearney](https://github.com/lmcnearney) +* Confirm that shallow `model.copy()` does make a shallow copy of attributes, [#1383](https://github.com/pydantic/pydantic/pull/1383) by [@samuelcolvin](https://github.com/samuelcolvin) +* Renaming `model_name` argument of `main.create_model()` to `__model_name` to allow using `model_name` as a field name, [#1367](https://github.com/pydantic/pydantic/pull/1367) by [@kittipatv](https://github.com/kittipatv) +* Replace raising of exception to silent passing for non-Var attributes in mypy plugin, [#1345](https://github.com/pydantic/pydantic/pull/1345) by [@b0g3r](https://github.com/b0g3r) +* Remove `typing_extensions` dependency for Python 3.8, [#1342](https://github.com/pydantic/pydantic/pull/1342) by [@prettywood](https://github.com/prettywood) +* Make `SecretStr` and `SecretBytes` initialization idempotent, [#1330](https://github.com/pydantic/pydantic/pull/1330) by [@atheuz](https://github.com/atheuz) +* document making secret types dumpable using the json method, [#1328](https://github.com/pydantic/pydantic/pull/1328) by [@atheuz](https://github.com/atheuz) +* Move all testing and build to github actions, add windows and macos binaries, + thank you [@StephenBrown2](https://github.com/StephenBrown2) for much help, [#1326](https://github.com/pydantic/pydantic/pull/1326) by [@samuelcolvin](https://github.com/samuelcolvin) +* fix card number length check in `PaymentCardNumber`, `PaymentCardBrand` now inherits from `str`, [#1317](https://github.com/pydantic/pydantic/pull/1317) by [@samuelcolvin](https://github.com/samuelcolvin) +* Have `BaseModel` inherit from `Representation` to make mypy happy when overriding `__str__`, [#1310](https://github.com/pydantic/pydantic/pull/1310) by [@FuegoFro](https://github.com/FuegoFro) +* Allow `None` as input to all optional list fields, [#1307](https://github.com/pydantic/pydantic/pull/1307) by [@prettywood](https://github.com/prettywood) +* Add `datetime` field to `default_factory` example, [#1301](https://github.com/pydantic/pydantic/pull/1301) by [@StephenBrown2](https://github.com/StephenBrown2) +* Allow subclasses of known types to be encoded with superclass encoder, [#1291](https://github.com/pydantic/pydantic/pull/1291) by [@StephenBrown2](https://github.com/StephenBrown2) +* Exclude exported fields from all elements of a list/tuple of submodels/dicts with `'__all__'`, 
[#1286](https://github.com/pydantic/pydantic/pull/1286) by [@masalim2](https://github.com/masalim2) +* Add pydantic.color.Color objects as available input for Color fields, [#1258](https://github.com/pydantic/pydantic/pull/1258) by [@leosussan](https://github.com/leosussan) +* In examples, type nullable fields as `Optional`, so that these are valid mypy annotations, [#1248](https://github.com/pydantic/pydantic/pull/1248) by [@kokes](https://github.com/kokes) +* Make `pattern_validator()` accept pre-compiled `Pattern` objects. Fix `str_validator()` return type to `str`, [#1237](https://github.com/pydantic/pydantic/pull/1237) by [@adamgreg](https://github.com/adamgreg) +* Document how to manage Generics and inheritance, [#1229](https://github.com/pydantic/pydantic/pull/1229) by [@esadruhn](https://github.com/esadruhn) +* `update_forward_refs()` method of BaseModel now copies `__dict__` of class module instead of modifying it, [#1228](https://github.com/pydantic/pydantic/pull/1228) by [@paul-ilyin](https://github.com/paul-ilyin) +* Support instance methods and class methods with `@validate_arguments`, [#1222](https://github.com/pydantic/pydantic/pull/1222) by [@samuelcolvin](https://github.com/samuelcolvin) +* Add `default_factory` argument to `Field` to create a dynamic default value by passing a zero-argument callable, [#1210](https://github.com/pydantic/pydantic/pull/1210) by [@prettywood](https://github.com/prettywood) +* add support for `NewType` of `List`, `Optional`, etc, [#1207](https://github.com/pydantic/pydantic/pull/1207) by [@Kazy](https://github.com/Kazy) +* fix mypy signature for `root_validator`, [#1192](https://github.com/pydantic/pydantic/pull/1192) by [@samuelcolvin](https://github.com/samuelcolvin) +* Fixed parsing of nested 'custom root type' models, [#1190](https://github.com/pydantic/pydantic/pull/1190) by [@Shados](https://github.com/Shados) +* Add `validate_arguments` function decorator which checks that the arguments to a function match type annotations, [#1179](https://github.com/pydantic/pydantic/pull/1179) by [@samuelcolvin](https://github.com/samuelcolvin) +* Add `__signature__` to models, [#1034](https://github.com/pydantic/pydantic/pull/1034) by [@Bobronium](https://github.com/Bobronium) +* Refactor `._iter()` method, 10x speed boost for `dict(model)`, [#1017](https://github.com/pydantic/pydantic/pull/1017) by [@Bobronium](https://github.com/Bobronium) + +## v1.4 (2020-01-24) + +* **Breaking Change:** alias precedence logic changed so aliases on a field always take priority over + an alias from `alias_generator` to avoid buggy/unexpected behaviour, + see [here](https://docs.pydantic.dev/usage/model_config/#alias-precedence) for details, [#1178](https://github.com/pydantic/pydantic/pull/1178) by [@samuelcolvin](https://github.com/samuelcolvin) +* Add support for unicode and punycode in TLDs, [#1182](https://github.com/pydantic/pydantic/pull/1182) by [@jamescurtin](https://github.com/jamescurtin) +* Fix `cls` argument in validators during assignment, [#1172](https://github.com/pydantic/pydantic/pull/1172) by [@samuelcolvin](https://github.com/samuelcolvin) +* completing Luhn algorithm for `PaymentCardNumber`, [#1166](https://github.com/pydantic/pydantic/pull/1166) by [@cuencandres](https://github.com/cuencandres) +* add support for generics that implement `__get_validators__` like a custom data type, [#1159](https://github.com/pydantic/pydantic/pull/1159) by [@tiangolo](https://github.com/tiangolo) +* add support for infinite generators with `Iterable`,
[#1152](https://github.com/pydantic/pydantic/pull/1152) by [@tiangolo](https://github.com/tiangolo) +* fix `url_regex` to accept schemas with `+`, `-` and `.` after the first character, [#1142](https://github.com/pydantic/pydantic/pull/1142) by [@samuelcolvin](https://github.com/samuelcolvin) +* move `version_info()` to `version.py`, suggest its use in issues, [#1138](https://github.com/pydantic/pydantic/pull/1138) by [@samuelcolvin](https://github.com/samuelcolvin) +* Improve pydantic import time by roughly 50% by deferring some module loading and regex compilation, [#1127](https://github.com/pydantic/pydantic/pull/1127) by [@samuelcolvin](https://github.com/samuelcolvin) +* Fix `EmailStr` and `NameEmail` to accept instances of themselves in cython, [#1126](https://github.com/pydantic/pydantic/pull/1126) by [@koxudaxi](https://github.com/koxudaxi) +* Pass model class to the `Config.schema_extra` callable, [#1125](https://github.com/pydantic/pydantic/pull/1125) by [@therefromhere](https://github.com/therefromhere) +* Fix regex for username and password in URLs, [#1115](https://github.com/pydantic/pydantic/pull/1115) by [@samuelcolvin](https://github.com/samuelcolvin) +* Add support for nested generic models, [#1104](https://github.com/pydantic/pydantic/pull/1104) by [@dmontagu](https://github.com/dmontagu) +* add `__all__` to `__init__.py` to prevent "implicit reexport" errors from mypy, [#1072](https://github.com/pydantic/pydantic/pull/1072) by [@samuelcolvin](https://github.com/samuelcolvin) +* Add support for using "dotenv" files with `BaseSettings`, [#1011](https://github.com/pydantic/pydantic/pull/1011) by [@acnebs](https://github.com/acnebs) + +## v1.3 (2019-12-21) + +* Change `schema` and `schema_model` to handle dataclasses by using their `__pydantic_model__` feature, [#792](https://github.com/pydantic/pydantic/pull/792) by [@aviramha](https://github.com/aviramha) +* Added option for `root_validator` to be skipped if values validation fails using keyword `skip_on_failure=True`, [#1049](https://github.com/pydantic/pydantic/pull/1049) by [@aviramha](https://github.com/aviramha) +* Allow `Config.schema_extra` to be a callable so that the generated schema can be post-processed, [#1054](https://github.com/pydantic/pydantic/pull/1054) by [@selimb](https://github.com/selimb) +* Update mypy to version 0.750, [#1057](https://github.com/pydantic/pydantic/pull/1057) by [@dmontagu](https://github.com/dmontagu) +* Trick Cython into allowing str subclassing, [#1061](https://github.com/pydantic/pydantic/pull/1061) by [@skewty](https://github.com/skewty) +* Prevent type attributes being added to schema unless the attribute `__schema_attributes__` is `True`, [#1064](https://github.com/pydantic/pydantic/pull/1064) by [@samuelcolvin](https://github.com/samuelcolvin) +* Change `BaseModel.parse_file` to use `Config.json_loads`, [#1067](https://github.com/pydantic/pydantic/pull/1067) by [@kierandarcy](https://github.com/kierandarcy) +* Fix for optional `Json` fields, [#1073](https://github.com/pydantic/pydantic/pull/1073) by [@volker48](https://github.com/volker48) +* Change the default number of threads used when compiling with cython to one, + allow override via the `CYTHON_NTHREADS` environment variable, [#1074](https://github.com/pydantic/pydantic/pull/1074) by [@samuelcolvin](https://github.com/samuelcolvin) +* Run FastAPI tests during Pydantic's CI tests, [#1075](https://github.com/pydantic/pydantic/pull/1075) by [@tiangolo](https://github.com/tiangolo) +* My mypy strictness constraints, and 
associated tweaks to type annotations, [#1077](https://github.com/pydantic/pydantic/pull/1077) by [@samuelcolvin](https://github.com/samuelcolvin) +* Add `__eq__` to SecretStr and SecretBytes to allow "value equals", [#1079](https://github.com/pydantic/pydantic/pull/1079) by [@sbv-trueenergy](https://github.com/sbv-trueenergy) +* Fix schema generation for nested None case, [#1088](https://github.com/pydantic/pydantic/pull/1088) by [@lutostag](https://github.com/lutostag) +* Consistent checks for sequence like objects, [#1090](https://github.com/pydantic/pydantic/pull/1090) by [@samuelcolvin](https://github.com/samuelcolvin) +* Fix `Config` inheritance on `BaseSettings` when used with `env_prefix`, [#1091](https://github.com/pydantic/pydantic/pull/1091) by [@samuelcolvin](https://github.com/samuelcolvin) +* Fix for `__modify_schema__` when it conflicted with `field_class_to_schema*`, [#1102](https://github.com/pydantic/pydantic/pull/1102) by [@samuelcolvin](https://github.com/samuelcolvin) +* docs: Fix explanation of case sensitive environment variable names when populating `BaseSettings` subclass attributes, [#1105](https://github.com/pydantic/pydantic/pull/1105) by [@tribals](https://github.com/tribals) +* Rename django-rest-framework benchmark in documentation, [#1119](https://github.com/pydantic/pydantic/pull/1119) by [@frankie567](https://github.com/frankie567) + +## v1.2 (2019-11-28) + +* **Possible Breaking Change:** Add support for required `Optional` with `name: Optional[AnyType] = Field(...)` + and refactor `ModelField` creation to preserve `required` parameter value, [#1031](https://github.com/pydantic/pydantic/pull/1031) by [@tiangolo](https://github.com/tiangolo); + see [here](https://docs.pydantic.dev/usage/models/#required-optional-fields) for details +* Add benchmarks for `cattrs`, [#513](https://github.com/pydantic/pydantic/pull/513) by [@sebastianmika](https://github.com/sebastianmika) +* Add `exclude_none` option to `dict()` and friends, [#587](https://github.com/pydantic/pydantic/pull/587) by [@niknetniko](https://github.com/niknetniko) +* Add benchmarks for `valideer`, [#670](https://github.com/pydantic/pydantic/pull/670) by [@gsakkis](https://github.com/gsakkis) +* Add `parse_obj_as` and `parse_file_as` functions for ad-hoc parsing of data into arbitrary pydantic-compatible types, [#934](https://github.com/pydantic/pydantic/pull/934) by [@dmontagu](https://github.com/dmontagu) +* Add `allow_reuse` argument to validators, thus allowing validator reuse, [#940](https://github.com/pydantic/pydantic/pull/940) by [@dmontagu](https://github.com/dmontagu) +* Add support for mapping types for custom root models, [#958](https://github.com/pydantic/pydantic/pull/958) by [@dmontagu](https://github.com/dmontagu) +* Mypy plugin support for dataclasses, [#966](https://github.com/pydantic/pydantic/pull/966) by [@koxudaxi](https://github.com/koxudaxi) +* Add support for dataclasses default factory, [#968](https://github.com/pydantic/pydantic/pull/968) by [@ahirner](https://github.com/ahirner) +* Add a `ByteSize` type for converting byte string (`1GB`) to plain bytes, [#977](https://github.com/pydantic/pydantic/pull/977) by [@dgasmith](https://github.com/dgasmith) +* Fix mypy complaint about `@root_validator(pre=True)`, [#984](https://github.com/pydantic/pydantic/pull/984) by [@samuelcolvin](https://github.com/samuelcolvin) +* Add manylinux binaries for Python 3.8 to pypi, also support manylinux2010, [#994](https://github.com/pydantic/pydantic/pull/994) by 
[@samuelcolvin](https://github.com/samuelcolvin) +* Adds ByteSize conversion to another unit, [#995](https://github.com/pydantic/pydantic/pull/995) by [@dgasmith](https://github.com/dgasmith) +* Fix `__str__` and `__repr__` inheritance for models, [#1022](https://github.com/pydantic/pydantic/pull/1022) by [@samuelcolvin](https://github.com/samuelcolvin) +* add testimonials section to docs, [#1025](https://github.com/pydantic/pydantic/pull/1025) by [@sullivancolin](https://github.com/sullivancolin) +* Add support for `typing.Literal` for Python 3.8, [#1026](https://github.com/pydantic/pydantic/pull/1026) by [@dmontagu](https://github.com/dmontagu) + +## v1.1.1 (2019-11-20) + +* Fix bug where use of complex fields on sub-models could cause fields to be incorrectly configured, [#1015](https://github.com/pydantic/pydantic/pull/1015) by [@samuelcolvin](https://github.com/samuelcolvin) + +## v1.1 (2019-11-07) + +* Add a mypy plugin for type checking `BaseModel.__init__` and more, [#722](https://github.com/pydantic/pydantic/pull/722) by [@dmontagu](https://github.com/dmontagu) +* Change return type typehint for `GenericModel.__class_getitem__` to prevent PyCharm warnings, [#936](https://github.com/pydantic/pydantic/pull/936) by [@dmontagu](https://github.com/dmontagu) +* Fix usage of `Any` to allow `None`, also support `TypeVar` thus allowing use of un-parameterised collection types + e.g. `Dict` and `List`, [#962](https://github.com/pydantic/pydantic/pull/962) by [@samuelcolvin](https://github.com/samuelcolvin) +* Set `FieldInfo` on subfields to fix schema generation for complex nested types, [#965](https://github.com/pydantic/pydantic/pull/965) by [@samuelcolvin](https://github.com/samuelcolvin) + +## v1.0 (2019-10-23) + +* **Breaking Change:** deprecate the `Model.fields` property, use `Model.__fields__` instead, [#883](https://github.com/pydantic/pydantic/pull/883) by [@samuelcolvin](https://github.com/samuelcolvin) +* **Breaking Change:** Change the precedence of aliases so child model aliases override parent aliases, + including using `alias_generator`, [#904](https://github.com/pydantic/pydantic/pull/904) by [@samuelcolvin](https://github.com/samuelcolvin) +* **Breaking change:** Rename `skip_defaults` to `exclude_unset`, and add ability to exclude actual defaults, [#915](https://github.com/pydantic/pydantic/pull/915) by [@dmontagu](https://github.com/dmontagu) +* Add `**kwargs` to `pydantic.main.ModelMetaclass.__new__` so `__init_subclass__` can take custom parameters on extended + `BaseModel` classes, [#867](https://github.com/pydantic/pydantic/pull/867) by [@retnikt](https://github.com/retnikt) +* Fix field of a type that has a default value, [#880](https://github.com/pydantic/pydantic/pull/880) by [@koxudaxi](https://github.com/koxudaxi) +* Use `FutureWarning` instead of `DeprecationWarning` when `alias` instead of `env` is used for settings models, [#881](https://github.com/pydantic/pydantic/pull/881) by [@samuelcolvin](https://github.com/samuelcolvin) +* Fix issue with `BaseSettings` inheritance and `alias` getting set to `None`, [#882](https://github.com/pydantic/pydantic/pull/882) by [@samuelcolvin](https://github.com/samuelcolvin) +* Modify `__repr__` and `__str__` methods to be consistent across all public classes, add `__pretty__` to support + python-devtools, [#884](https://github.com/pydantic/pydantic/pull/884) by [@samuelcolvin](https://github.com/samuelcolvin) +* deprecation warning for `case_insensitive` on `BaseSettings` config, 
[#885](https://github.com/pydantic/pydantic/pull/885) by [@samuelcolvin](https://github.com/samuelcolvin) +* For `BaseSettings` merge environment variables and in-code values recursively, as long as they create a valid object + when merged together, to allow splitting init arguments, [#888](https://github.com/pydantic/pydantic/pull/888) by [@idmitrievsky](https://github.com/idmitrievsky) +* change secret types example, [#890](https://github.com/pydantic/pydantic/pull/890) by [@ashears](https://github.com/ashears) +* Change the signature of `Model.construct()` to be more user-friendly, document `construct()` usage, [#898](https://github.com/pydantic/pydantic/pull/898) by [@samuelcolvin](https://github.com/samuelcolvin) +* Add example for the `construct()` method, [#907](https://github.com/pydantic/pydantic/pull/907) by [@ashears](https://github.com/ashears) +* Improve use of `Field` constraints on complex types, raise an error if constraints are not enforceable, + also support tuples with an ellipsis `Tuple[X, ...]`, `Sequence` and `FrozenSet` in schema, [#909](https://github.com/pydantic/pydantic/pull/909) by [@samuelcolvin](https://github.com/samuelcolvin) +* update docs for bool missing valid value, [#911](https://github.com/pydantic/pydantic/pull/911) by [@trim21](https://github.com/trim21) +* Better `str`/`repr` logic for `ModelField`, [#912](https://github.com/pydantic/pydantic/pull/912) by [@samuelcolvin](https://github.com/samuelcolvin) +* Fix `ConstrainedList`, update schema generation to reflect `min_items` and `max_items` `Field()` arguments, [#917](https://github.com/pydantic/pydantic/pull/917) by [@samuelcolvin](https://github.com/samuelcolvin) +* Allow abstract sets (e.g. dict keys) in the `include` and `exclude` arguments of `dict()`, [#921](https://github.com/pydantic/pydantic/pull/921) by [@samuelcolvin](https://github.com/samuelcolvin) +* Fix JSON serialization errors on `ValidationError.json()` by using `pydantic_encoder`, [#922](https://github.com/pydantic/pydantic/pull/922) by [@samuelcolvin](https://github.com/samuelcolvin) +* Clarify usage of `remove_untouched`, improve error message for types with no validators, [#926](https://github.com/pydantic/pydantic/pull/926) by [@retnikt](https://github.com/retnikt) + +## v1.0b2 (2019-10-07) + +* Mark `StrictBool` typecheck as `bool` to allow for default values without mypy errors, [#690](https://github.com/pydantic/pydantic/pull/690) by [@dmontagu](https://github.com/dmontagu) +* Transfer the documentation build from sphinx to mkdocs, re-write much of the documentation, [#856](https://github.com/pydantic/pydantic/pull/856) by [@samuelcolvin](https://github.com/samuelcolvin) +* Add support for custom naming schemes for `GenericModel` subclasses, [#859](https://github.com/pydantic/pydantic/pull/859) by [@dmontagu](https://github.com/dmontagu) +* Add `if TYPE_CHECKING:` to the excluded lines for test coverage, [#874](https://github.com/pydantic/pydantic/pull/874) by [@dmontagu](https://github.com/dmontagu) +* Rename `allow_population_by_alias` to `allow_population_by_field_name`, remove unnecessary warning about it, [#875](https://github.com/pydantic/pydantic/pull/875) by [@samuelcolvin](https://github.com/samuelcolvin) + +## v1.0b1 (2019-10-01) + +* **Breaking Change:** rename `Schema` to `Field`, make it a function to placate mypy, [#577](https://github.com/pydantic/pydantic/pull/577) by [@samuelcolvin](https://github.com/samuelcolvin) +* **Breaking Change:** modify parsing behavior for `bool`,
[#617](https://github.com/pydantic/pydantic/pull/617) by [@dmontagu](https://github.com/dmontagu) +* **Breaking Change:** `get_validators` is no longer recognised, use `__get_validators__`. + `Config.ignore_extra` and `Config.allow_extra` are no longer recognised, use `Config.extra`, [#720](https://github.com/pydantic/pydantic/pull/720) by [@samuelcolvin](https://github.com/samuelcolvin) +* **Breaking Change:** modify default config settings for `BaseSettings`; `case_insensitive` renamed to `case_sensitive`, + default changed to `case_sensitive = False`, `env_prefix` default changed to `''` - e.g. no prefix, [#721](https://github.com/pydantic/pydantic/pull/721) by [@dmontagu](https://github.com/dmontagu) +* **Breaking change:** Implement `root_validator` and rename root errors from `__obj__` to `__root__`, [#729](https://github.com/pydantic/pydantic/pull/729) by [@samuelcolvin](https://github.com/samuelcolvin) +* **Breaking Change:** alter the behaviour of `dict(model)` so that sub-models are no longer + converted to dictionaries, [#733](https://github.com/pydantic/pydantic/pull/733) by [@samuelcolvin](https://github.com/samuelcolvin) +* **Breaking change:** Added `initvars` support to `post_init_post_parse`, [#748](https://github.com/pydantic/pydantic/pull/748) by [@Raphael-C-Almeida](https://github.com/Raphael-C-Almeida) +* **Breaking Change:** Make `BaseModel.json()` only serialize the `__root__` key for models with custom root, [#752](https://github.com/pydantic/pydantic/pull/752) by [@dmontagu](https://github.com/dmontagu) +* **Breaking Change:** complete rewrite of `URL` parsing logic, [#755](https://github.com/pydantic/pydantic/pull/755) by [@samuelcolvin](https://github.com/samuelcolvin) +* **Breaking Change:** preserve superclass annotations for field-determination when not provided in subclass, [#757](https://github.com/pydantic/pydantic/pull/757) by [@dmontagu](https://github.com/dmontagu) +* **Breaking Change:** `BaseSettings` now uses the special `env` settings to define which environment variables to + read, not aliases, [#847](https://github.com/pydantic/pydantic/pull/847) by [@samuelcolvin](https://github.com/samuelcolvin) +* add support for `assert` statements inside validators, [#653](https://github.com/pydantic/pydantic/pull/653) by [@abdusco](https://github.com/abdusco) +* Update documentation to specify the use of `pydantic.dataclasses.dataclass` and subclassing `pydantic.BaseModel`, [#710](https://github.com/pydantic/pydantic/pull/710) by [@maddosaurus](https://github.com/maddosaurus) +* Allow custom JSON decoding and encoding via `json_loads` and `json_dumps` `Config` properties, [#714](https://github.com/pydantic/pydantic/pull/714) by [@samuelcolvin](https://github.com/samuelcolvin) +* make all annotated fields occur in the order declared, [#715](https://github.com/pydantic/pydantic/pull/715) by [@dmontagu](https://github.com/dmontagu) +* use pytest to test `mypy` integration, [#735](https://github.com/pydantic/pydantic/pull/735) by [@dmontagu](https://github.com/dmontagu) +* add `__repr__` method to `ErrorWrapper`, [#738](https://github.com/pydantic/pydantic/pull/738) by [@samuelcolvin](https://github.com/samuelcolvin) +* Added support for `FrozenSet` members in dataclasses, and a better error when attempting to use types from the `typing` module that are not supported by Pydantic, [#745](https://github.com/pydantic/pydantic/pull/745) by [@djpetti](https://github.com/djpetti) +* add documentation for Pycharm Plugin,
[#750](https://github.com/pydantic/pydantic/pull/750) by [@koxudaxi](https://github.com/koxudaxi) +* fix broken examples in the docs, [#753](https://github.com/pydantic/pydantic/pull/753) by [@dmontagu](https://github.com/dmontagu) +* moving typing related objects into `pydantic.typing`, [#761](https://github.com/pydantic/pydantic/pull/761) by [@samuelcolvin](https://github.com/samuelcolvin) +* Minor performance improvements to `ErrorWrapper`, `ValidationError` and datetime parsing, [#763](https://github.com/pydantic/pydantic/pull/763) by [@samuelcolvin](https://github.com/samuelcolvin) +* Improvements to `datetime`/`date`/`time`/`timedelta` types: more descriptive errors, + change errors to `value_error` not `type_error`, support bytes, [#766](https://github.com/pydantic/pydantic/pull/766) by [@samuelcolvin](https://github.com/samuelcolvin) +* fix error messages for `Literal` types with multiple allowed values, [#770](https://github.com/pydantic/pydantic/pull/770) by [@dmontagu](https://github.com/dmontagu) +* Improved auto-generated `title` field in JSON schema by converting underscore to space, [#772](https://github.com/pydantic/pydantic/pull/772) by [@skewty](https://github.com/skewty) +* support `mypy --no-implicit-reexport` for dataclasses, also respect `--no-implicit-reexport` in pydantic itself, [#783](https://github.com/pydantic/pydantic/pull/783) by [@samuelcolvin](https://github.com/samuelcolvin) +* add the `PaymentCardNumber` type, [#790](https://github.com/pydantic/pydantic/pull/790) by [@matin](https://github.com/matin) +* Fix const validations for lists, [#794](https://github.com/pydantic/pydantic/pull/794) by [@hmvp](https://github.com/hmvp) +* Set `additionalProperties` to false in schema for models with extra fields disallowed, [#796](https://github.com/pydantic/pydantic/pull/796) by [@Code0x58](https://github.com/Code0x58) +* `EmailStr` validation method now returns local part case-sensitive per RFC 5321, [#798](https://github.com/pydantic/pydantic/pull/798) by [@henriklindgren](https://github.com/henriklindgren) +* Added ability to validate strictness to `ConstrainedFloat`, `ConstrainedInt` and `ConstrainedStr` and added + `StrictFloat` and `StrictInt` classes, [#799](https://github.com/pydantic/pydantic/pull/799) by [@DerRidda](https://github.com/DerRidda) +* Improve handling of `None` and `Optional`, replace `whole` with `each_item` (inverse meaning, default `False`) + on validators, [#803](https://github.com/pydantic/pydantic/pull/803) by [@samuelcolvin](https://github.com/samuelcolvin) +* add support for `Type[T]` type hints, [#807](https://github.com/pydantic/pydantic/pull/807) by [@timonbimon](https://github.com/timonbimon) +* Performance improvements from removing `change_exceptions`, change how pydantic error are constructed, [#819](https://github.com/pydantic/pydantic/pull/819) by [@samuelcolvin](https://github.com/samuelcolvin) +* Fix the error message arising when a `BaseModel`-type model field causes a `ValidationError` during parsing, [#820](https://github.com/pydantic/pydantic/pull/820) by [@dmontagu](https://github.com/dmontagu) +* allow `getter_dict` on `Config`, modify `GetterDict` to be more like a `Mapping` object and thus easier to work with, [#821](https://github.com/pydantic/pydantic/pull/821) by [@samuelcolvin](https://github.com/samuelcolvin) +* Only check `TypeVar` param on base `GenericModel` class, [#842](https://github.com/pydantic/pydantic/pull/842) by [@zpencerq](https://github.com/zpencerq) +* rename `Model._schema_cache` -> 
`Model.__schema_cache__`, `Model._json_encoder` -> `Model.__json_encoder__`, + `Model._custom_root_type` -> `Model.__custom_root_type__`, [#851](https://github.com/pydantic/pydantic/pull/851) by [@samuelcolvin](https://github.com/samuelcolvin) + + +... see [here](https://docs.pydantic.dev/changelog/#v0322-2019-08-17) for earlier changes. diff --git a/venv/lib/python3.12/site-packages/pydantic-2.4.2.dist-info/RECORD b/venv/lib/python3.12/site-packages/pydantic-2.4.2.dist-info/RECORD new file mode 100644 index 0000000..57add57 --- /dev/null +++ b/venv/lib/python3.12/site-packages/pydantic-2.4.2.dist-info/RECORD @@ -0,0 +1,196 @@ +pydantic-2.4.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +pydantic-2.4.2.dist-info/METADATA,sha256=_VdGPMWObOl2UKc_NhU_M7fBJ_jKxtJZTRBShsLRDMI,158640 +pydantic-2.4.2.dist-info/RECORD,, +pydantic-2.4.2.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +pydantic-2.4.2.dist-info/WHEEL,sha256=9QBuHhg6FNW7lppboF2vKVbCGTVzsFykgRQjjlajrhA,87 +pydantic-2.4.2.dist-info/licenses/LICENSE,sha256=qeGG88oWte74QxjnpwFyE1GgDLe4rjpDlLZ7SeNSnvM,1129 +pydantic/__init__.py,sha256=GW5aPaYvdifnO_m4FONO3DN-T7ArOb4yXHiE0k8CZak,5814 +pydantic/__pycache__/__init__.cpython-312.pyc,, +pydantic/__pycache__/_migration.cpython-312.pyc,, +pydantic/__pycache__/alias_generators.cpython-312.pyc,, +pydantic/__pycache__/annotated_handlers.cpython-312.pyc,, +pydantic/__pycache__/class_validators.cpython-312.pyc,, +pydantic/__pycache__/color.cpython-312.pyc,, +pydantic/__pycache__/config.cpython-312.pyc,, +pydantic/__pycache__/dataclasses.cpython-312.pyc,, +pydantic/__pycache__/datetime_parse.cpython-312.pyc,, +pydantic/__pycache__/decorator.cpython-312.pyc,, +pydantic/__pycache__/env_settings.cpython-312.pyc,, +pydantic/__pycache__/error_wrappers.cpython-312.pyc,, +pydantic/__pycache__/errors.cpython-312.pyc,, +pydantic/__pycache__/fields.cpython-312.pyc,, +pydantic/__pycache__/functional_serializers.cpython-312.pyc,, +pydantic/__pycache__/functional_validators.cpython-312.pyc,, +pydantic/__pycache__/generics.cpython-312.pyc,, +pydantic/__pycache__/json.cpython-312.pyc,, +pydantic/__pycache__/json_schema.cpython-312.pyc,, +pydantic/__pycache__/main.cpython-312.pyc,, +pydantic/__pycache__/mypy.cpython-312.pyc,, +pydantic/__pycache__/networks.cpython-312.pyc,, +pydantic/__pycache__/parse.cpython-312.pyc,, +pydantic/__pycache__/root_model.cpython-312.pyc,, +pydantic/__pycache__/schema.cpython-312.pyc,, +pydantic/__pycache__/tools.cpython-312.pyc,, +pydantic/__pycache__/type_adapter.cpython-312.pyc,, +pydantic/__pycache__/types.cpython-312.pyc,, +pydantic/__pycache__/typing.cpython-312.pyc,, +pydantic/__pycache__/utils.cpython-312.pyc,, +pydantic/__pycache__/validate_call.cpython-312.pyc,, +pydantic/__pycache__/validators.cpython-312.pyc,, +pydantic/__pycache__/version.cpython-312.pyc,, +pydantic/__pycache__/warnings.cpython-312.pyc,, +pydantic/_internal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +pydantic/_internal/__pycache__/__init__.cpython-312.pyc,, +pydantic/_internal/__pycache__/_config.cpython-312.pyc,, +pydantic/_internal/__pycache__/_core_metadata.cpython-312.pyc,, +pydantic/_internal/__pycache__/_core_utils.cpython-312.pyc,, +pydantic/_internal/__pycache__/_dataclasses.cpython-312.pyc,, +pydantic/_internal/__pycache__/_decorators.cpython-312.pyc,, +pydantic/_internal/__pycache__/_decorators_v1.cpython-312.pyc,, +pydantic/_internal/__pycache__/_discriminated_union.cpython-312.pyc,, 
+pydantic/_internal/__pycache__/_fields.cpython-312.pyc,, +pydantic/_internal/__pycache__/_forward_ref.cpython-312.pyc,, +pydantic/_internal/__pycache__/_generate_schema.cpython-312.pyc,, +pydantic/_internal/__pycache__/_generics.cpython-312.pyc,, +pydantic/_internal/__pycache__/_internal_dataclass.cpython-312.pyc,, +pydantic/_internal/__pycache__/_known_annotated_metadata.cpython-312.pyc,, +pydantic/_internal/__pycache__/_mock_val_ser.cpython-312.pyc,, +pydantic/_internal/__pycache__/_model_construction.cpython-312.pyc,, +pydantic/_internal/__pycache__/_repr.cpython-312.pyc,, +pydantic/_internal/__pycache__/_schema_generation_shared.cpython-312.pyc,, +pydantic/_internal/__pycache__/_std_types_schema.cpython-312.pyc,, +pydantic/_internal/__pycache__/_typing_extra.cpython-312.pyc,, +pydantic/_internal/__pycache__/_utils.cpython-312.pyc,, +pydantic/_internal/__pycache__/_validate_call.cpython-312.pyc,, +pydantic/_internal/__pycache__/_validators.cpython-312.pyc,, +pydantic/_internal/_config.py,sha256=2FscGY4-mMRdt2yKgI103kwEtOufbrPPlrF18NAXI3w,11307 +pydantic/_internal/_core_metadata.py,sha256=Da-e0-DXK__dJvog0e8CZLQ4r_k9RpldG6KQTGrYlHg,3521 +pydantic/_internal/_core_utils.py,sha256=1eJmY3fjg434Hf0K3Anv42XpUUhr0ct1yk_46LqV3-8,24820 +pydantic/_internal/_dataclasses.py,sha256=EvoJILb1yaee3cVEn6XN-aCJGWOBeiBSrRKxJvBsj8w,10707 +pydantic/_internal/_decorators.py,sha256=7zUASoVitYtcIwKinl2jgBvCEvJPetQeTzRFnGej08A,30775 +pydantic/_internal/_decorators_v1.py,sha256=_m9TskhZh9yPUn7Jmy3KbKa3UDREQWyMm5NXyOJM3R8,6266 +pydantic/_internal/_discriminated_union.py,sha256=clzts7UmTAaD6etCd_qjbz1hE9q-WNSS7vqN5pF82vQ,26228 +pydantic/_internal/_fields.py,sha256=JMR0r6aB2TRXTMFMNUUrTx0ZpaJl95vIv4KJWyXRNGo,11903 +pydantic/_internal/_forward_ref.py,sha256=JBimF5v9vkOthrwLQcl0hsLC_HJ11ICAS1d9gImXLr0,425 +pydantic/_internal/_generate_schema.py,sha256=-YYRfpUbQuV9oXb6g05oqFO1dNcJL1XagP-rDB277YM,89996 +pydantic/_internal/_generics.py,sha256=jPhM2BvcLElMO-lhkGk04O7KRZDvKXmf_-S5khrkPms,22173 +pydantic/_internal/_internal_dataclass.py,sha256=NswLpapJY_61NFHBAXYpgFdxMmIX_yE9ttx_pQt_Vp8,207 +pydantic/_internal/_known_annotated_metadata.py,sha256=MK8PFoqhgd9_NtuWkNS1WhjCIYncoZdT3c1u8_UN9nU,16275 +pydantic/_internal/_mock_val_ser.py,sha256=-TYaUZyEDZVL4qwvP4FDsHlOog8YskMBEdv6JhHjEis,4305 +pydantic/_internal/_model_construction.py,sha256=CMn15644KnTXb4WqsLsEdEqbBAvtPHnNzfimY0kHoCo,27043 +pydantic/_internal/_repr.py,sha256=fbIu0pJzS8LTO7twA5eR5wVarZdd38ioceW8lQw1PdQ,4376 +pydantic/_internal/_schema_generation_shared.py,sha256=eRwZ85Gj0FfabYlvM97I5997vhY4Mk3AYQJljK5B3to,4855 +pydantic/_internal/_std_types_schema.py,sha256=jhgYS7W1f5tb-vVmndF4HHYDU9LCE8dSmNU78m4H104,28949 +pydantic/_internal/_typing_extra.py,sha256=_HmXu6PaPDicxyxEr-UGnamRNEdW42Ru81aVUKPZ5Ok,16538 +pydantic/_internal/_utils.py,sha256=xfTCcIQ2yToh-_q3Gn5RrNktHNSoMH8jdEGplQf4WYE,11698 +pydantic/_internal/_validate_call.py,sha256=bgWQ8BYvpa9YcXJkvceUaRD4MUPxImNhQs-7H-T_Y-8,5491 +pydantic/_internal/_validators.py,sha256=6vHPe403edmmxSM8qaHilC4kzjRxYaCyWARq3vexkZ4,10047 +pydantic/_migration.py,sha256=to7sVaLhl003Mv9r5nRxxD4ws1hajPofmkfiewBrFmw,11899 +pydantic/alias_generators.py,sha256=95F9x9P1bzzL7Z3y5F2BvEF9SMUEiT-r69SWlJao_3E,1141 +pydantic/annotated_handlers.py,sha256=iyOdMvz2-G-pe6HJ1a1EpRYn3EnktNyppmlI0YeM-Ss,4346 +pydantic/class_validators.py,sha256=iQz1Tw8FBliqEapmzB7iLkbwkJAeAx5314Vksb_Kj0g,147 +pydantic/color.py,sha256=iaNP0rz9iDuiDtL9KDa382r8Z50T3I1swYdDrh5sBZ4,21493 +pydantic/config.py,sha256=4UBX7VgZu9YjEuH-86tiO66cofOFEmtFBdneX_VTO3Y,24737 
+pydantic/dataclasses.py,sha256=6PLxjWUMpYPBZLWEEN3NY43F-Y1Z5SPZWUW4FIDPQ2Y,11490 +pydantic/datetime_parse.py,sha256=5lJpo3-iBTAA9YmMuDLglP-5f2k8etayAXjEi6rfEN0,149 +pydantic/decorator.py,sha256=Qqx1UU19tpRVp05a2NIK5OdpLXN_a84HZPMjt_5BxdE,144 +pydantic/deprecated/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +pydantic/deprecated/__pycache__/__init__.cpython-312.pyc,, +pydantic/deprecated/__pycache__/class_validators.cpython-312.pyc,, +pydantic/deprecated/__pycache__/config.cpython-312.pyc,, +pydantic/deprecated/__pycache__/copy_internals.cpython-312.pyc,, +pydantic/deprecated/__pycache__/decorator.cpython-312.pyc,, +pydantic/deprecated/__pycache__/json.cpython-312.pyc,, +pydantic/deprecated/__pycache__/parse.cpython-312.pyc,, +pydantic/deprecated/__pycache__/tools.cpython-312.pyc,, +pydantic/deprecated/class_validators.py,sha256=bm2JLyjA3O6BxTvDQFScvki39HkDysFyswVEA55vGTk,9848 +pydantic/deprecated/config.py,sha256=zgaFWxmg5k6cWUs7ir_OGYS26MQJxRiblp6HPmCy0u4,2612 +pydantic/deprecated/copy_internals.py,sha256=SoUj1MevXt3fnloqNg5wivSUHSDPnuSj_YydzkEMzu0,7595 +pydantic/deprecated/decorator.py,sha256=rYviEY5ZM77OrpdBPaaitrnoFjh4ENCT_oBzvQASWjs,10903 +pydantic/deprecated/json.py,sha256=1hcwvq33cxrwIvUA6vm_rpb0qMdzxMQGiroo0jJHYtU,4465 +pydantic/deprecated/parse.py,sha256=ZJpE4ukxCw-hUUd_PZRYGwkviZopQj6vX6WGUkbBGyY,2481 +pydantic/deprecated/tools.py,sha256=2VRvcQIaJbFywkRvhFITjdkeujfunmMHgjjlioUNJp0,3278 +pydantic/env_settings.py,sha256=quxt8c9TioRg-u74gTW-GrK6r5mFXmn-J5H8FAC9Prc,147 +pydantic/error_wrappers.py,sha256=u9Dz8RgawIw8-rx7G7WGZoRtGptHXyXhHxiN9PbQ58g,149 +pydantic/errors.py,sha256=rUi9iOo26RYJOwbBoGAvv7nlK0li3fnPIfn0OnRn5aA,4595 +pydantic/fields.py,sha256=a-5oxArunmeOWQTx97Vyd7p-ag1Zeu6wHdxv7X6UR4c,45513 +pydantic/functional_serializers.py,sha256=ubcOeapLyEmvq4ZyZe0pWfHNji39Wm1BRXWXJTr177c,10780 +pydantic/functional_validators.py,sha256=2rfnFlsDaEbGjzBz7ATHifN8kaisy9n80Bzi7vpqFxo,20471 +pydantic/generics.py,sha256=T1UIBvpgur_28EIcR9Dc_Wo2r9yntzqdcR-NbnOLXB8,143 +pydantic/json.py,sha256=qk9fHVGWKNrvE-v2WxWLEm66t81JKttbySd9zjy0dnc,139 +pydantic/json_schema.py,sha256=qYckzVoCE1xDLAligNKCTeu9pnYt1hV0mK_t-kyd7KA,100686 +pydantic/main.py,sha256=LcAcw3r5NxkF8JAqkw_Bdy0KNW5XCp3JEttZEWNgHYA,62260 +pydantic/mypy.py,sha256=Jv13Kk4LsamISbbIBfGmdz6t3p7BV_qf_GVmulI18e0,50721 +pydantic/networks.py,sha256=-xGfwCIzWcmeQgN5o_Wu6EeHwa2q7eCmZZ72lOp21PI,20543 +pydantic/parse.py,sha256=BNo_W_gp1xR7kohYdHjF2m_5UNYFQxUt487-NR0RiK8,140 +pydantic/plugin/__init__.py,sha256=6fSBBTPAvLh4V3Y4Su19BlDhXDQF9K3l8LnLMFHXvLI,5184 +pydantic/plugin/__pycache__/__init__.cpython-312.pyc,, +pydantic/plugin/__pycache__/_loader.cpython-312.pyc,, +pydantic/plugin/__pycache__/_schema_validator.cpython-312.pyc,, +pydantic/plugin/_loader.py,sha256=wW8GWTi1m14yNKg4XG9lf_BktsoBTyjO3w-andi7Hig,1972 +pydantic/plugin/_schema_validator.py,sha256=dm4a4ULQ2lrLaXW1JBylv7Dy-qcdRsLHg9-bUYGUG2M,4216 +pydantic/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +pydantic/root_model.py,sha256=bl6vvc4gciJitia6jkYeOhKN17t0KI1QX8Ar1PA7MO8,4949 +pydantic/schema.py,sha256=EkbomWuaAdv7C3V8h6xxoT4uJKy3Mwvkg064tOUbvxg,141 +pydantic/tools.py,sha256=YB4vzOx4g7reKUM_s5oTXIGxC5LGBnGsXdVICSRuh7g,140 +pydantic/type_adapter.py,sha256=iPjao2xGID7l1UyQya3ShVHwSC7N5KzpPT0ant7hUXI,18818 +pydantic/types.py,sha256=f_3sNoudJTUoVoXUI4T3JZ80pZA9eYEiISqNWsiqQcU,72231 +pydantic/typing.py,sha256=sPkx0hi_RX7qSV3BB0zzHd8ZuAKbRRI37XJI4av_HzQ,137 +pydantic/utils.py,sha256=twRV5SqiguiCrOA9GvrKvOG-TThfWYb7mEXDVXFZp2s,140 
+pydantic/v1/__init__.py,sha256=iTu8CwWWvn6zM_zYJtqhie24PImW25zokitz_06kDYw,2771 +pydantic/v1/__pycache__/__init__.cpython-312.pyc,, +pydantic/v1/__pycache__/_hypothesis_plugin.cpython-312.pyc,, +pydantic/v1/__pycache__/annotated_types.cpython-312.pyc,, +pydantic/v1/__pycache__/class_validators.cpython-312.pyc,, +pydantic/v1/__pycache__/color.cpython-312.pyc,, +pydantic/v1/__pycache__/config.cpython-312.pyc,, +pydantic/v1/__pycache__/dataclasses.cpython-312.pyc,, +pydantic/v1/__pycache__/datetime_parse.cpython-312.pyc,, +pydantic/v1/__pycache__/decorator.cpython-312.pyc,, +pydantic/v1/__pycache__/env_settings.cpython-312.pyc,, +pydantic/v1/__pycache__/error_wrappers.cpython-312.pyc,, +pydantic/v1/__pycache__/errors.cpython-312.pyc,, +pydantic/v1/__pycache__/fields.cpython-312.pyc,, +pydantic/v1/__pycache__/generics.cpython-312.pyc,, +pydantic/v1/__pycache__/json.cpython-312.pyc,, +pydantic/v1/__pycache__/main.cpython-312.pyc,, +pydantic/v1/__pycache__/mypy.cpython-312.pyc,, +pydantic/v1/__pycache__/networks.cpython-312.pyc,, +pydantic/v1/__pycache__/parse.cpython-312.pyc,, +pydantic/v1/__pycache__/schema.cpython-312.pyc,, +pydantic/v1/__pycache__/tools.cpython-312.pyc,, +pydantic/v1/__pycache__/types.cpython-312.pyc,, +pydantic/v1/__pycache__/typing.cpython-312.pyc,, +pydantic/v1/__pycache__/utils.cpython-312.pyc,, +pydantic/v1/__pycache__/validators.cpython-312.pyc,, +pydantic/v1/__pycache__/version.cpython-312.pyc,, +pydantic/v1/_hypothesis_plugin.py,sha256=gILcyAEfZ3u9YfKxtDxkReLpakjMou1VWC3FEcXmJgQ,14844 +pydantic/v1/annotated_types.py,sha256=dJTDUyPj4QJj4rDcNkt9xDUMGEkAnuWzDeGE2q7Wxrc,3124 +pydantic/v1/class_validators.py,sha256=0BZx0Ft19cREVHEOaA6wf_E3A0bTL4wQIGzeOinVatg,14595 +pydantic/v1/color.py,sha256=cGzck7kSD5beBkOMhda4bfTICput6dMx8GGpEU5SK5Y,16811 +pydantic/v1/config.py,sha256=h5ceeZ9HzDjUv0IZNYQoza0aNGFVo22iszY-6s0a3eM,6477 +pydantic/v1/dataclasses.py,sha256=roiVI64yCN68aMRxHEw615qgrcdEwpHAHfTEz_HlAtQ,17515 +pydantic/v1/datetime_parse.py,sha256=DhGfkbG4Vs5Oyxq3u8jM-7gFrbuUKsn-4aG2DJDJbHw,7714 +pydantic/v1/decorator.py,sha256=wzuIuKKHVjaiE97YBctCU0Vho0VRlUO-aVu1IUEczFE,10263 +pydantic/v1/env_settings.py,sha256=4PWxPYeK5jt59JJ4QGb90qU8pfC7qgGX44UESTmXdpE,14039 +pydantic/v1/error_wrappers.py,sha256=NvfemFFYx9EFLXBGeJ07MKT2MJQAJFFlx_bIoVpqgVI,5142 +pydantic/v1/errors.py,sha256=f93z30S4s5bJEl8JXh-zFCAtLDCko9ze2hKTkOimaa8,17693 +pydantic/v1/fields.py,sha256=fxTn7A17AXAHuDdz8HzFSjb8qfWhRoruwc2VOzRpUdM,50488 +pydantic/v1/generics.py,sha256=n5TTgh3EHkG1Xw3eY9A143bUN11_4m57Db5u49hkGJ8,17805 +pydantic/v1/json.py,sha256=B0gJ2WmPqw-6fsvPmgu-rwhhOy4E0JpbbYjC8HR01Ho,3346 +pydantic/v1/main.py,sha256=kC5_bcJc4zoLhRUVvNq67ACmGmRtQFvyRHDub6cw5ik,44378 +pydantic/v1/mypy.py,sha256=G8yQLLt6CodoTvGl84MP3ZpdInBtc0QoaLJ7iArHXNU,38745 +pydantic/v1/networks.py,sha256=TeV9FvCYg4ALk8j7dU1q6Ntze7yaUrCHQFEDJDnq1NI,22059 +pydantic/v1/parse.py,sha256=rrVhaWLK8t03rT3oxvC6uRLuTF5iZ2NKGvGqs4iQEM0,1810 +pydantic/v1/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +pydantic/v1/schema.py,sha256=ZqIQQpjxohG0hP7Zz5W401fpm4mYNu_Crmvr5HlgvMA,47615 +pydantic/v1/tools.py,sha256=ELC66w6UaU_HzAGfJBSIP47Aq9ZGkGiWPMLkkTs6VrI,2826 +pydantic/v1/types.py,sha256=S1doibLP6gg6TVZU9TwNfL2E10mFhZwCzd9WZK8Kilo,35380 +pydantic/v1/typing.py,sha256=5_C_fiUvWiAzW3MBJaHeuy2s3Hi52rFMxTfNPHv9_os,18996 +pydantic/v1/utils.py,sha256=5w7Q3N_Fqg5H9__JQDaumw9N3EFdlc7galEsCGxEDN0,25809 +pydantic/v1/validators.py,sha256=T-t9y9L_68El9p4PYkEVGEjpetNV6luav8Iwu9iTLkM,21887 
+pydantic/v1/version.py,sha256=yUT25-EekWoBCsQwsA0kQTvIKOBUST7feqZT-TrbyX4,1039 +pydantic/validate_call.py,sha256=N3R7_GEjjvO6-M4ev6KHXzkhXu4gvF8bDEMBzmQlM3Q,1780 +pydantic/validators.py,sha256=3oPhHojp9UD3PdEZpMYMkxeLGUAabRm__zera8_T92w,145 +pydantic/version.py,sha256=0W9ccbJkMBsxIUoZyZxP00c4KGPagAjYMOnu96DVedM,2307 +pydantic/warnings.py,sha256=giN1ynj2Jh4yUrPFaweJgFoxtDY1vC9l3gpbdb5mFu0,1947 diff --git a/venv/lib/python3.12/site-packages/pydantic_settings-2.11.0.dist-info/REQUESTED b/venv/lib/python3.12/site-packages/pydantic-2.4.2.dist-info/REQUESTED similarity index 100% rename from venv/lib/python3.12/site-packages/pydantic_settings-2.11.0.dist-info/REQUESTED rename to venv/lib/python3.12/site-packages/pydantic-2.4.2.dist-info/REQUESTED diff --git a/venv/lib/python3.12/site-packages/pydantic-2.11.9.dist-info/WHEEL b/venv/lib/python3.12/site-packages/pydantic-2.4.2.dist-info/WHEEL similarity index 67% rename from venv/lib/python3.12/site-packages/pydantic-2.11.9.dist-info/WHEEL rename to venv/lib/python3.12/site-packages/pydantic-2.4.2.dist-info/WHEEL index 12228d4..ba1a8af 100644 --- a/venv/lib/python3.12/site-packages/pydantic-2.11.9.dist-info/WHEEL +++ b/venv/lib/python3.12/site-packages/pydantic-2.4.2.dist-info/WHEEL @@ -1,4 +1,4 @@ Wheel-Version: 1.0 -Generator: hatchling 1.27.0 +Generator: hatchling 1.18.0 Root-Is-Purelib: true Tag: py3-none-any diff --git a/venv/lib/python3.12/site-packages/pydantic-2.11.9.dist-info/licenses/LICENSE b/venv/lib/python3.12/site-packages/pydantic-2.4.2.dist-info/licenses/LICENSE similarity index 100% rename from venv/lib/python3.12/site-packages/pydantic-2.11.9.dist-info/licenses/LICENSE rename to venv/lib/python3.12/site-packages/pydantic-2.4.2.dist-info/licenses/LICENSE diff --git a/venv/lib/python3.12/site-packages/pydantic/__init__.py b/venv/lib/python3.12/site-packages/pydantic/__init__.py index 716ca40..ec7b0b0 100644 --- a/venv/lib/python3.12/site-packages/pydantic/__init__.py +++ b/venv/lib/python3.12/site-packages/pydantic/__init__.py @@ -1,73 +1,59 @@ import typing -from importlib import import_module -from warnings import warn +import pydantic_core +from pydantic_core.core_schema import ( + FieldSerializationInfo, + SerializationInfo, + SerializerFunctionWrapHandler, + ValidationInfo, + ValidatorFunctionWrapHandler, +) + +from . 
import dataclasses +from ._internal._generate_schema import GenerateSchema as GenerateSchema from ._migration import getattr_migration +from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler +from .config import ConfigDict +from .errors import * +from .fields import AliasChoices, AliasPath, Field, PrivateAttr, computed_field +from .functional_serializers import PlainSerializer, SerializeAsAny, WrapSerializer, field_serializer, model_serializer +from .functional_validators import ( + AfterValidator, + BeforeValidator, + InstanceOf, + PlainValidator, + SkipValidation, + WrapValidator, + field_validator, + model_validator, +) +from .json_schema import WithJsonSchema +from .main import * +from .networks import * +from .type_adapter import TypeAdapter +from .types import * +from .validate_call import validate_call from .version import VERSION +from .warnings import * + +__version__ = VERSION + +# this encourages pycharm to import `ValidationError` from here, not pydantic_core +ValidationError = pydantic_core.ValidationError if typing.TYPE_CHECKING: - # import of virtually everything is supported via `__getattr__` below, - # but we need them here for type checking and IDE support - import pydantic_core - from pydantic_core.core_schema import ( - FieldSerializationInfo, - SerializationInfo, - SerializerFunctionWrapHandler, - ValidationInfo, - ValidatorFunctionWrapHandler, - ) - - from . import dataclasses - from .aliases import AliasChoices, AliasGenerator, AliasPath - from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler - from .config import ConfigDict, with_config - from .errors import * - from .fields import Field, PrivateAttr, computed_field - from .functional_serializers import ( - PlainSerializer, - SerializeAsAny, - WrapSerializer, - field_serializer, - model_serializer, - ) - from .functional_validators import ( - AfterValidator, - BeforeValidator, - InstanceOf, - ModelWrapValidatorHandler, - PlainValidator, - SkipValidation, - WrapValidator, - field_validator, - model_validator, - ) - from .json_schema import WithJsonSchema - from .main import * - from .networks import * - from .type_adapter import TypeAdapter - from .types import * - from .validate_call_decorator import validate_call - from .warnings import ( - PydanticDeprecatedSince20, - PydanticDeprecatedSince26, - PydanticDeprecatedSince29, - PydanticDeprecatedSince210, - PydanticDeprecatedSince211, - PydanticDeprecationWarning, - PydanticExperimentalWarning, - ) - - # this encourages pycharm to import `ValidationError` from here, not pydantic_core - ValidationError = pydantic_core.ValidationError + # these are imported via `__getattr__` below, but we need them here for type checking and IDE support from .deprecated.class_validators import root_validator, validator from .deprecated.config import BaseConfig, Extra from .deprecated.tools import * from .root_model import RootModel -__version__ = VERSION -__all__ = ( +__all__ = [ # dataclasses 'dataclasses', + # pydantic_core.core_schema + 'ValidationInfo', + 'ValidatorFunctionWrapHandler', # functional validators 'field_validator', 'model_validator', @@ -77,8 +63,6 @@ __all__ = ( 'WrapValidator', 'SkipValidation', 'InstanceOf', - 'ModelWrapValidatorHandler', - # JSON Schema 'WithJsonSchema', # deprecated V1 functional validators, these are imported via `__getattr__` below 'root_validator', @@ -89,14 +73,18 @@ __all__ = ( 'PlainSerializer', 'SerializeAsAny', 'WrapSerializer', + 'FieldSerializationInfo', + 'SerializationInfo', + 
'SerializerFunctionWrapHandler', # config 'ConfigDict', - 'with_config', # deprecated V1 config, these are imported via `__getattr__` below 'BaseConfig', 'Extra', # validate_call 'validate_call', + # pydantic_core errors + 'ValidationError', # errors 'PydanticErrorCodes', 'PydanticUserError', @@ -104,15 +92,11 @@ __all__ = ( 'PydanticImportError', 'PydanticUndefinedAnnotation', 'PydanticInvalidForJsonSchema', - 'PydanticForbiddenQualifier', # fields + 'AliasPath', + 'AliasChoices', 'Field', 'computed_field', - 'PrivateAttr', - # alias - 'AliasChoices', - 'AliasGenerator', - 'AliasPath', # main 'BaseModel', 'create_model', @@ -121,9 +105,6 @@ __all__ = ( 'AnyHttpUrl', 'FileUrl', 'HttpUrl', - 'FtpUrl', - 'WebsocketUrl', - 'AnyWebsocketUrl', 'UrlConstraints', 'EmailStr', 'NameEmail', @@ -136,11 +117,8 @@ __all__ = ( 'RedisDsn', 'MongoDsn', 'KafkaDsn', - 'NatsDsn', 'MySQLDsn', 'MariaDBDsn', - 'ClickHouseDsn', - 'SnowflakeDsn', 'validate_email', # root_model 'RootModel', @@ -175,22 +153,18 @@ __all__ = ( 'UUID3', 'UUID4', 'UUID5', - 'UUID6', - 'UUID7', - 'UUID8', 'FilePath', 'DirectoryPath', 'NewPath', 'Json', - 'Secret', 'SecretStr', 'SecretBytes', - 'SocketPath', 'StrictBool', 'StrictBytes', 'StrictInt', 'StrictFloat', 'PaymentCardNumber', + 'PrivateAttr', 'ByteSize', 'PastDate', 'FutureDate', @@ -208,238 +182,44 @@ __all__ = ( 'Base64UrlBytes', 'Base64UrlStr', 'GetPydanticSchema', - 'Tag', - 'Discriminator', - 'JsonValue', - 'FailFast', # type_adapter 'TypeAdapter', # version - '__version__', 'VERSION', # warnings 'PydanticDeprecatedSince20', - 'PydanticDeprecatedSince26', - 'PydanticDeprecatedSince29', - 'PydanticDeprecatedSince210', - 'PydanticDeprecatedSince211', 'PydanticDeprecationWarning', - 'PydanticExperimentalWarning', # annotated handlers 'GetCoreSchemaHandler', 'GetJsonSchemaHandler', - # pydantic_core - 'ValidationError', - 'ValidationInfo', - 'SerializationInfo', - 'ValidatorFunctionWrapHandler', - 'FieldSerializationInfo', - 'SerializerFunctionWrapHandler', - 'OnErrorOmit', -) + 'GenerateSchema', +] # A mapping of {: (package, )} defining dynamic imports _dynamic_imports: 'dict[str, tuple[str, str]]' = { - 'dataclasses': (__spec__.parent, '__module__'), - # functional validators - 'field_validator': (__spec__.parent, '.functional_validators'), - 'model_validator': (__spec__.parent, '.functional_validators'), - 'AfterValidator': (__spec__.parent, '.functional_validators'), - 'BeforeValidator': (__spec__.parent, '.functional_validators'), - 'PlainValidator': (__spec__.parent, '.functional_validators'), - 'WrapValidator': (__spec__.parent, '.functional_validators'), - 'SkipValidation': (__spec__.parent, '.functional_validators'), - 'InstanceOf': (__spec__.parent, '.functional_validators'), - 'ModelWrapValidatorHandler': (__spec__.parent, '.functional_validators'), - # JSON Schema - 'WithJsonSchema': (__spec__.parent, '.json_schema'), - # functional serializers - 'field_serializer': (__spec__.parent, '.functional_serializers'), - 'model_serializer': (__spec__.parent, '.functional_serializers'), - 'PlainSerializer': (__spec__.parent, '.functional_serializers'), - 'SerializeAsAny': (__spec__.parent, '.functional_serializers'), - 'WrapSerializer': (__spec__.parent, '.functional_serializers'), - # config - 'ConfigDict': (__spec__.parent, '.config'), - 'with_config': (__spec__.parent, '.config'), - # validate call - 'validate_call': (__spec__.parent, '.validate_call_decorator'), - # errors - 'PydanticErrorCodes': (__spec__.parent, '.errors'), - 'PydanticUserError': (__spec__.parent, 
'.errors'), - 'PydanticSchemaGenerationError': (__spec__.parent, '.errors'), - 'PydanticImportError': (__spec__.parent, '.errors'), - 'PydanticUndefinedAnnotation': (__spec__.parent, '.errors'), - 'PydanticInvalidForJsonSchema': (__spec__.parent, '.errors'), - 'PydanticForbiddenQualifier': (__spec__.parent, '.errors'), - # fields - 'Field': (__spec__.parent, '.fields'), - 'computed_field': (__spec__.parent, '.fields'), - 'PrivateAttr': (__spec__.parent, '.fields'), - # alias - 'AliasChoices': (__spec__.parent, '.aliases'), - 'AliasGenerator': (__spec__.parent, '.aliases'), - 'AliasPath': (__spec__.parent, '.aliases'), - # main - 'BaseModel': (__spec__.parent, '.main'), - 'create_model': (__spec__.parent, '.main'), - # network - 'AnyUrl': (__spec__.parent, '.networks'), - 'AnyHttpUrl': (__spec__.parent, '.networks'), - 'FileUrl': (__spec__.parent, '.networks'), - 'HttpUrl': (__spec__.parent, '.networks'), - 'FtpUrl': (__spec__.parent, '.networks'), - 'WebsocketUrl': (__spec__.parent, '.networks'), - 'AnyWebsocketUrl': (__spec__.parent, '.networks'), - 'UrlConstraints': (__spec__.parent, '.networks'), - 'EmailStr': (__spec__.parent, '.networks'), - 'NameEmail': (__spec__.parent, '.networks'), - 'IPvAnyAddress': (__spec__.parent, '.networks'), - 'IPvAnyInterface': (__spec__.parent, '.networks'), - 'IPvAnyNetwork': (__spec__.parent, '.networks'), - 'PostgresDsn': (__spec__.parent, '.networks'), - 'CockroachDsn': (__spec__.parent, '.networks'), - 'AmqpDsn': (__spec__.parent, '.networks'), - 'RedisDsn': (__spec__.parent, '.networks'), - 'MongoDsn': (__spec__.parent, '.networks'), - 'KafkaDsn': (__spec__.parent, '.networks'), - 'NatsDsn': (__spec__.parent, '.networks'), - 'MySQLDsn': (__spec__.parent, '.networks'), - 'MariaDBDsn': (__spec__.parent, '.networks'), - 'ClickHouseDsn': (__spec__.parent, '.networks'), - 'SnowflakeDsn': (__spec__.parent, '.networks'), - 'validate_email': (__spec__.parent, '.networks'), - # root_model - 'RootModel': (__spec__.parent, '.root_model'), - # types - 'Strict': (__spec__.parent, '.types'), - 'StrictStr': (__spec__.parent, '.types'), - 'conbytes': (__spec__.parent, '.types'), - 'conlist': (__spec__.parent, '.types'), - 'conset': (__spec__.parent, '.types'), - 'confrozenset': (__spec__.parent, '.types'), - 'constr': (__spec__.parent, '.types'), - 'StringConstraints': (__spec__.parent, '.types'), - 'ImportString': (__spec__.parent, '.types'), - 'conint': (__spec__.parent, '.types'), - 'PositiveInt': (__spec__.parent, '.types'), - 'NegativeInt': (__spec__.parent, '.types'), - 'NonNegativeInt': (__spec__.parent, '.types'), - 'NonPositiveInt': (__spec__.parent, '.types'), - 'confloat': (__spec__.parent, '.types'), - 'PositiveFloat': (__spec__.parent, '.types'), - 'NegativeFloat': (__spec__.parent, '.types'), - 'NonNegativeFloat': (__spec__.parent, '.types'), - 'NonPositiveFloat': (__spec__.parent, '.types'), - 'FiniteFloat': (__spec__.parent, '.types'), - 'condecimal': (__spec__.parent, '.types'), - 'condate': (__spec__.parent, '.types'), - 'UUID1': (__spec__.parent, '.types'), - 'UUID3': (__spec__.parent, '.types'), - 'UUID4': (__spec__.parent, '.types'), - 'UUID5': (__spec__.parent, '.types'), - 'UUID6': (__spec__.parent, '.types'), - 'UUID7': (__spec__.parent, '.types'), - 'UUID8': (__spec__.parent, '.types'), - 'FilePath': (__spec__.parent, '.types'), - 'DirectoryPath': (__spec__.parent, '.types'), - 'NewPath': (__spec__.parent, '.types'), - 'Json': (__spec__.parent, '.types'), - 'Secret': (__spec__.parent, '.types'), - 'SecretStr': (__spec__.parent, '.types'), - 
'SecretBytes': (__spec__.parent, '.types'), - 'StrictBool': (__spec__.parent, '.types'), - 'StrictBytes': (__spec__.parent, '.types'), - 'StrictInt': (__spec__.parent, '.types'), - 'StrictFloat': (__spec__.parent, '.types'), - 'PaymentCardNumber': (__spec__.parent, '.types'), - 'ByteSize': (__spec__.parent, '.types'), - 'PastDate': (__spec__.parent, '.types'), - 'SocketPath': (__spec__.parent, '.types'), - 'FutureDate': (__spec__.parent, '.types'), - 'PastDatetime': (__spec__.parent, '.types'), - 'FutureDatetime': (__spec__.parent, '.types'), - 'AwareDatetime': (__spec__.parent, '.types'), - 'NaiveDatetime': (__spec__.parent, '.types'), - 'AllowInfNan': (__spec__.parent, '.types'), - 'EncoderProtocol': (__spec__.parent, '.types'), - 'EncodedBytes': (__spec__.parent, '.types'), - 'EncodedStr': (__spec__.parent, '.types'), - 'Base64Encoder': (__spec__.parent, '.types'), - 'Base64Bytes': (__spec__.parent, '.types'), - 'Base64Str': (__spec__.parent, '.types'), - 'Base64UrlBytes': (__spec__.parent, '.types'), - 'Base64UrlStr': (__spec__.parent, '.types'), - 'GetPydanticSchema': (__spec__.parent, '.types'), - 'Tag': (__spec__.parent, '.types'), - 'Discriminator': (__spec__.parent, '.types'), - 'JsonValue': (__spec__.parent, '.types'), - 'OnErrorOmit': (__spec__.parent, '.types'), - 'FailFast': (__spec__.parent, '.types'), - # type_adapter - 'TypeAdapter': (__spec__.parent, '.type_adapter'), - # warnings - 'PydanticDeprecatedSince20': (__spec__.parent, '.warnings'), - 'PydanticDeprecatedSince26': (__spec__.parent, '.warnings'), - 'PydanticDeprecatedSince29': (__spec__.parent, '.warnings'), - 'PydanticDeprecatedSince210': (__spec__.parent, '.warnings'), - 'PydanticDeprecatedSince211': (__spec__.parent, '.warnings'), - 'PydanticDeprecationWarning': (__spec__.parent, '.warnings'), - 'PydanticExperimentalWarning': (__spec__.parent, '.warnings'), - # annotated handlers - 'GetCoreSchemaHandler': (__spec__.parent, '.annotated_handlers'), - 'GetJsonSchemaHandler': (__spec__.parent, '.annotated_handlers'), - # pydantic_core stuff - 'ValidationError': ('pydantic_core', '.'), - 'ValidationInfo': ('pydantic_core', '.core_schema'), - 'SerializationInfo': ('pydantic_core', '.core_schema'), - 'ValidatorFunctionWrapHandler': ('pydantic_core', '.core_schema'), - 'FieldSerializationInfo': ('pydantic_core', '.core_schema'), - 'SerializerFunctionWrapHandler': ('pydantic_core', '.core_schema'), - # deprecated, mostly not included in __all__ - 'root_validator': (__spec__.parent, '.deprecated.class_validators'), - 'validator': (__spec__.parent, '.deprecated.class_validators'), - 'BaseConfig': (__spec__.parent, '.deprecated.config'), - 'Extra': (__spec__.parent, '.deprecated.config'), - 'parse_obj_as': (__spec__.parent, '.deprecated.tools'), - 'schema_of': (__spec__.parent, '.deprecated.tools'), - 'schema_json_of': (__spec__.parent, '.deprecated.tools'), - # deprecated dynamic imports + 'RootModel': (__package__, '.root_model'), + 'root_validator': (__package__, '.deprecated.class_validators'), + 'validator': (__package__, '.deprecated.class_validators'), + 'BaseConfig': (__package__, '.deprecated.config'), + 'Extra': (__package__, '.deprecated.config'), + 'parse_obj_as': (__package__, '.deprecated.tools'), + 'schema_of': (__package__, '.deprecated.tools'), + 'schema_json_of': (__package__, '.deprecated.tools'), + # FieldValidationInfo is deprecated, and hidden behind module a `__getattr__` 'FieldValidationInfo': ('pydantic_core', '.core_schema'), - 'GenerateSchema': (__spec__.parent, '._internal._generate_schema'), } 
-_deprecated_dynamic_imports = {'FieldValidationInfo', 'GenerateSchema'} _getattr_migration = getattr_migration(__name__) def __getattr__(attr_name: str) -> object: - if attr_name in _deprecated_dynamic_imports: - warn( - f'Importing {attr_name} from `pydantic` is deprecated. This feature is either no longer supported, or is not public.', - DeprecationWarning, - stacklevel=2, - ) - dynamic_attr = _dynamic_imports.get(attr_name) if dynamic_attr is None: return _getattr_migration(attr_name) package, module_name = dynamic_attr - if module_name == '__module__': - result = import_module(f'.{attr_name}', package=package) - globals()[attr_name] = result - return result - else: - module = import_module(module_name, package=package) - result = getattr(module, attr_name) - g = globals() - for k, (_, v_module_name) in _dynamic_imports.items(): - if v_module_name == module_name and k not in _deprecated_dynamic_imports: - g[k] = getattr(module, k) - return result + from importlib import import_module - -def __dir__() -> 'list[str]': - return list(__all__) + module = import_module(module_name, package=package) + return getattr(module, attr_name) diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_config.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_config.py index fe71264..61d3c30 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_config.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_config.py @@ -1,23 +1,25 @@ from __future__ import annotations as _annotations import warnings -from contextlib import contextmanager -from re import Pattern +from contextlib import contextmanager, nullcontext from typing import ( TYPE_CHECKING, Any, Callable, - Literal, + ContextManager, + Iterator, cast, ) from pydantic_core import core_schema -from typing_extensions import Self +from typing_extensions import ( + Literal, + Self, +) -from ..aliases import AliasGenerator -from ..config import ConfigDict, ExtraValues, JsonDict, JsonEncoder, JsonSchemaExtraCallable +from ..config import ConfigDict, ExtraValues, JsonEncoder, JsonSchemaExtraCallable from ..errors import PydanticUserError -from ..warnings import PydanticDeprecatedSince20, PydanticDeprecatedSince210 +from ..warnings import PydanticDeprecatedSince20 if not TYPE_CHECKING: # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915 @@ -26,7 +28,6 @@ if not TYPE_CHECKING: if TYPE_CHECKING: from .._internal._schema_generation_shared import GenerateSchema - from ..fields import ComputedFieldInfo, FieldInfo DEPRECATION_MESSAGE = 'Support for class-based `config` is deprecated, use ConfigDict instead.' @@ -56,12 +57,10 @@ class ConfigWrapper: # whether to use the actual key provided in the data (e.g. alias or first alias for "field required" errors) instead of field_names # to construct error `loc`s, default `True` loc_by_alias: bool - alias_generator: Callable[[str], str] | AliasGenerator | None - model_title_generator: Callable[[type], str] | None - field_title_generator: Callable[[str, FieldInfo | ComputedFieldInfo], str] | None + alias_generator: Callable[[str], str] | None ignored_types: tuple[type, ...] 
allow_inf_nan: bool - json_schema_extra: JsonDict | JsonSchemaExtraCallable | None + json_schema_extra: dict[str, object] | JsonSchemaExtraCallable | None json_encoders: dict[type[object], JsonEncoder] | None # new in V2 @@ -69,13 +68,11 @@ class ConfigWrapper: # whether instances of models and dataclasses (including subclass instances) should re-validate, default 'never' revalidate_instances: Literal['always', 'never', 'subclass-instances'] ser_json_timedelta: Literal['iso8601', 'float'] - ser_json_bytes: Literal['utf8', 'base64', 'hex'] - val_json_bytes: Literal['utf8', 'base64', 'hex'] - ser_json_inf_nan: Literal['null', 'constants', 'strings'] + ser_json_bytes: Literal['utf8', 'base64'] # whether to validate default values during validation, default False validate_default: bool validate_return: bool - protected_namespaces: tuple[str | Pattern[str], ...] + protected_namespaces: tuple[str, ...] hide_input_in_errors: bool defer_build: bool plugin_settings: dict[str, object] | None @@ -83,13 +80,6 @@ class ConfigWrapper: json_schema_serialization_defaults_required: bool json_schema_mode_override: Literal['validation', 'serialization', None] coerce_numbers_to_str: bool - regex_engine: Literal['rust-regex', 'python-re'] - validation_error_cause: bool - use_attribute_docstrings: bool - cache_strings: bool | Literal['all', 'keys', 'none'] - validate_by_alias: bool - validate_by_name: bool - serialize_by_alias: bool def __init__(self, config: ConfigDict | dict[str, Any] | type[Any] | None, *, check: bool = True): if check: @@ -123,19 +113,13 @@ class ConfigWrapper: config_class_from_namespace = namespace.get('Config') config_dict_from_namespace = namespace.get('model_config') - raw_annotations = namespace.get('__annotations__', {}) - if raw_annotations.get('model_config') and config_dict_from_namespace is None: - raise PydanticUserError( - '`model_config` cannot be used as a model field name. Use `model_config` for model configuration.', - code='model-config-invalid-field-name', - ) - if config_class_from_namespace and config_dict_from_namespace: raise PydanticUserError('"Config" and "model_config" cannot be used together', code='config-both') config_from_namespace = config_dict_from_namespace or prepare_config(config_class_from_namespace) - config_new.update(config_from_namespace) + if config_from_namespace is not None: + config_new.update(config_from_namespace) for k in list(kwargs.keys()): if k in config_keys: @@ -144,7 +128,7 @@ class ConfigWrapper: return cls(config_new) # we don't show `__getattr__` to type checkers so missing attributes cause errors - if not TYPE_CHECKING: # pragma: no branch + if not TYPE_CHECKING: def __getattr__(self, name: str) -> Any: try: @@ -155,77 +139,46 @@ class ConfigWrapper: except KeyError: raise AttributeError(f'Config has no attribute {name!r}') from None - def core_config(self, title: str | None) -> core_schema.CoreConfig: - """Create a pydantic-core config. + def core_config(self, obj: Any) -> core_schema.CoreConfig: + """Create a pydantic-core config, `obj` is just used to populate `title` if not set in config. + + Pass `obj=None` if you do not want to attempt to infer the `title`. We don't use getattr here since we don't want to populate with defaults. Args: - title: The title to use if not set in config. + obj: An object used to populate `title` if not set in config. Returns: A `CoreConfig` object created from config. 
""" - config = self.config_dict - if config.get('schema_generator') is not None: - warnings.warn( - 'The `schema_generator` setting has been deprecated since v2.10. This setting no longer has any effect.', - PydanticDeprecatedSince210, - stacklevel=2, + def dict_not_none(**kwargs: Any) -> Any: + return {k: v for k, v in kwargs.items() if v is not None} + + core_config = core_schema.CoreConfig( + **dict_not_none( + title=self.config_dict.get('title') or (obj and obj.__name__), + extra_fields_behavior=self.config_dict.get('extra'), + allow_inf_nan=self.config_dict.get('allow_inf_nan'), + populate_by_name=self.config_dict.get('populate_by_name'), + str_strip_whitespace=self.config_dict.get('str_strip_whitespace'), + str_to_lower=self.config_dict.get('str_to_lower'), + str_to_upper=self.config_dict.get('str_to_upper'), + strict=self.config_dict.get('strict'), + ser_json_timedelta=self.config_dict.get('ser_json_timedelta'), + ser_json_bytes=self.config_dict.get('ser_json_bytes'), + from_attributes=self.config_dict.get('from_attributes'), + loc_by_alias=self.config_dict.get('loc_by_alias'), + revalidate_instances=self.config_dict.get('revalidate_instances'), + validate_default=self.config_dict.get('validate_default'), + str_max_length=self.config_dict.get('str_max_length'), + str_min_length=self.config_dict.get('str_min_length'), + hide_input_in_errors=self.config_dict.get('hide_input_in_errors'), + coerce_numbers_to_str=self.config_dict.get('coerce_numbers_to_str'), ) - - if (populate_by_name := config.get('populate_by_name')) is not None: - # We include this patch for backwards compatibility purposes, but this config setting will be deprecated in v3.0, and likely removed in v4.0. - # Thus, the above warning and this patch can be removed then as well. - if config.get('validate_by_name') is None: - config['validate_by_alias'] = True - config['validate_by_name'] = populate_by_name - - # We dynamically patch validate_by_name to be True if validate_by_alias is set to False - # and validate_by_name is not explicitly set. 
- if config.get('validate_by_alias') is False and config.get('validate_by_name') is None: - config['validate_by_name'] = True - - if (not config.get('validate_by_alias', True)) and (not config.get('validate_by_name', False)): - raise PydanticUserError( - 'At least one of `validate_by_alias` or `validate_by_name` must be set to True.', - code='validate-by-alias-and-name-false', - ) - - return core_schema.CoreConfig( - **{ # pyright: ignore[reportArgumentType] - k: v - for k, v in ( - ('title', config.get('title') or title or None), - ('extra_fields_behavior', config.get('extra')), - ('allow_inf_nan', config.get('allow_inf_nan')), - ('str_strip_whitespace', config.get('str_strip_whitespace')), - ('str_to_lower', config.get('str_to_lower')), - ('str_to_upper', config.get('str_to_upper')), - ('strict', config.get('strict')), - ('ser_json_timedelta', config.get('ser_json_timedelta')), - ('ser_json_bytes', config.get('ser_json_bytes')), - ('val_json_bytes', config.get('val_json_bytes')), - ('ser_json_inf_nan', config.get('ser_json_inf_nan')), - ('from_attributes', config.get('from_attributes')), - ('loc_by_alias', config.get('loc_by_alias')), - ('revalidate_instances', config.get('revalidate_instances')), - ('validate_default', config.get('validate_default')), - ('str_max_length', config.get('str_max_length')), - ('str_min_length', config.get('str_min_length')), - ('hide_input_in_errors', config.get('hide_input_in_errors')), - ('coerce_numbers_to_str', config.get('coerce_numbers_to_str')), - ('regex_engine', config.get('regex_engine')), - ('validation_error_cause', config.get('validation_error_cause')), - ('cache_strings', config.get('cache_strings')), - ('validate_by_alias', config.get('validate_by_alias')), - ('validate_by_name', config.get('validate_by_name')), - ('serialize_by_alias', config.get('serialize_by_alias')), - ) - if v is not None - } ) + return core_config def __repr__(self): c = ', '.join(f'{k}={v!r}' for k, v in self.config_dict.items()) @@ -242,20 +195,22 @@ class ConfigWrapperStack: def tail(self) -> ConfigWrapper: return self._config_wrapper_stack[-1] - @contextmanager - def push(self, config_wrapper: ConfigWrapper | ConfigDict | None): + def push(self, config_wrapper: ConfigWrapper | ConfigDict | None) -> ContextManager[None]: if config_wrapper is None: - yield - return + return nullcontext() if not isinstance(config_wrapper, ConfigWrapper): config_wrapper = ConfigWrapper(config_wrapper, check=False) - self._config_wrapper_stack.append(config_wrapper) - try: - yield - finally: - self._config_wrapper_stack.pop() + @contextmanager + def _context_manager() -> Iterator[None]: + self._config_wrapper_stack.append(config_wrapper) + try: + yield + finally: + self._config_wrapper_stack.pop() + + return _context_manager() config_defaults = ConfigDict( @@ -275,8 +230,6 @@ config_defaults = ConfigDict( from_attributes=False, loc_by_alias=True, alias_generator=None, - model_title_generator=None, - field_title_generator=None, ignored_types=(), allow_inf_nan=True, json_schema_extra=None, @@ -284,26 +237,17 @@ config_defaults = ConfigDict( revalidate_instances='never', ser_json_timedelta='iso8601', ser_json_bytes='utf8', - val_json_bytes='utf8', - ser_json_inf_nan='null', validate_default=False, validate_return=False, - protected_namespaces=('model_validate', 'model_dump'), + protected_namespaces=('model_',), hide_input_in_errors=False, json_encoders=None, defer_build=False, - schema_generator=None, plugin_settings=None, + schema_generator=None, 
json_schema_serialization_defaults_required=False, json_schema_mode_override=None, coerce_numbers_to_str=False, - regex_engine='rust-regex', - validation_error_cause=False, - use_attribute_docstrings=False, - cache_strings=True, - validate_by_alias=True, - validate_by_name=False, - serialize_by_alias=False, ) @@ -344,7 +288,7 @@ V2_REMOVED_KEYS = { 'post_init_call', } V2_RENAMED_KEYS = { - 'allow_population_by_field_name': 'validate_by_name', + 'allow_population_by_field_name': 'populate_by_name', 'anystr_lower': 'str_to_lower', 'anystr_strip_whitespace': 'str_strip_whitespace', 'anystr_upper': 'str_to_upper', diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_core_metadata.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_core_metadata.py index 9f2510c..296d49f 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_core_metadata.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_core_metadata.py @@ -1,97 +1,92 @@ from __future__ import annotations as _annotations -from typing import TYPE_CHECKING, Any, TypedDict, cast -from warnings import warn +import typing +from typing import Any -if TYPE_CHECKING: - from ..config import JsonDict, JsonSchemaExtraCallable +import typing_extensions + +if typing.TYPE_CHECKING: + from ._schema_generation_shared import ( + CoreSchemaOrField as CoreSchemaOrField, + ) from ._schema_generation_shared import ( GetJsonSchemaFunction, ) -class CoreMetadata(TypedDict, total=False): +class CoreMetadata(typing_extensions.TypedDict, total=False): """A `TypedDict` for holding the metadata dict of the schema. Attributes: - pydantic_js_functions: List of JSON schema functions that resolve refs during application. - pydantic_js_annotation_functions: List of JSON schema functions that don't resolve refs during application. + pydantic_js_functions: List of JSON schema functions. pydantic_js_prefer_positional_arguments: Whether JSON schema generator will prefer positional over keyword arguments for an 'arguments' schema. - custom validation function. Only applies to before, plain, and wrap validators. - pydantic_js_updates: key / value pair updates to apply to the JSON schema for a type. - pydantic_js_extra: WIP, either key/value pair updates to apply to the JSON schema, or a custom callable. - pydantic_internal_union_tag_key: Used internally by the `Tag` metadata to specify the tag used for a discriminated union. - pydantic_internal_union_discriminator: Used internally to specify the discriminator value for a discriminated union - when the discriminator was applied to a `'definition-ref'` schema, and that reference was missing at the time - of the annotation application. - - TODO: Perhaps we should move this structure to pydantic-core. At the moment, though, - it's easier to iterate on if we leave it in pydantic until we feel there is a semi-stable API. - - TODO: It's unfortunate how functionally oriented JSON schema generation is, especially that which occurs during - the core schema generation process. It's inevitable that we need to store some json schema related information - on core schemas, given that we generate JSON schemas directly from core schemas. That being said, debugging related - issues is quite difficult when JSON schema information is disguised via dynamically defined functions. 
""" pydantic_js_functions: list[GetJsonSchemaFunction] pydantic_js_annotation_functions: list[GetJsonSchemaFunction] - pydantic_js_prefer_positional_arguments: bool - pydantic_js_updates: JsonDict - pydantic_js_extra: JsonDict | JsonSchemaExtraCallable - pydantic_internal_union_tag_key: str - pydantic_internal_union_discriminator: str + + # If `pydantic_js_prefer_positional_arguments` is True, the JSON schema generator will + # prefer positional over keyword arguments for an 'arguments' schema. + pydantic_js_prefer_positional_arguments: bool | None + + pydantic_typed_dict_cls: type[Any] | None # TODO: Consider moving this into the pydantic-core TypedDictSchema -def update_core_metadata( - core_metadata: Any, - /, - *, - pydantic_js_functions: list[GetJsonSchemaFunction] | None = None, - pydantic_js_annotation_functions: list[GetJsonSchemaFunction] | None = None, - pydantic_js_updates: JsonDict | None = None, - pydantic_js_extra: JsonDict | JsonSchemaExtraCallable | None = None, -) -> None: - from ..json_schema import PydanticJsonSchemaWarning +class CoreMetadataHandler: + """Because the metadata field in pydantic_core is of type `Any`, we can't assume much about its contents. - """Update CoreMetadata instance in place. When we make modifications in this function, they - take effect on the `core_metadata` reference passed in as the first (and only) positional argument. - - First, cast to `CoreMetadata`, then finish with a cast to `dict[str, Any]` for core schema compatibility. - We do this here, instead of before / after each call to this function so that this typing hack - can be easily removed if/when we move `CoreMetadata` to `pydantic-core`. - - For parameter descriptions, see `CoreMetadata` above. + This class is used to interact with the metadata field on a CoreSchema object in a consistent + way throughout pydantic. """ - core_metadata = cast(CoreMetadata, core_metadata) - if pydantic_js_functions: - core_metadata.setdefault('pydantic_js_functions', []).extend(pydantic_js_functions) + __slots__ = ('_schema',) - if pydantic_js_annotation_functions: - core_metadata.setdefault('pydantic_js_annotation_functions', []).extend(pydantic_js_annotation_functions) + def __init__(self, schema: CoreSchemaOrField): + self._schema = schema - if pydantic_js_updates: - if (existing_updates := core_metadata.get('pydantic_js_updates')) is not None: - core_metadata['pydantic_js_updates'] = {**existing_updates, **pydantic_js_updates} - else: - core_metadata['pydantic_js_updates'] = pydantic_js_updates + metadata = schema.get('metadata') + if metadata is None: + schema['metadata'] = CoreMetadata() + elif not isinstance(metadata, dict): + raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.') - if pydantic_js_extra is not None: - existing_pydantic_js_extra = core_metadata.get('pydantic_js_extra') - if existing_pydantic_js_extra is None: - core_metadata['pydantic_js_extra'] = pydantic_js_extra - if isinstance(existing_pydantic_js_extra, dict): - if isinstance(pydantic_js_extra, dict): - core_metadata['pydantic_js_extra'] = {**existing_pydantic_js_extra, **pydantic_js_extra} - if callable(pydantic_js_extra): - warn( - 'Composing `dict` and `callable` type `json_schema_extra` is not supported.' - 'The `callable` type is being ignored.' 
- "If you'd like support for this behavior, please open an issue on pydantic.", - PydanticJsonSchemaWarning, - ) - if callable(existing_pydantic_js_extra): - # if ever there's a case of a callable, we'll just keep the last json schema extra spec - core_metadata['pydantic_js_extra'] = pydantic_js_extra + @property + def metadata(self) -> CoreMetadata: + """Retrieves the metadata dict from the schema, initializing it to a dict if it is None + and raises an error if it is not a dict. + """ + metadata = self._schema.get('metadata') + if metadata is None: + self._schema['metadata'] = metadata = CoreMetadata() + if not isinstance(metadata, dict): + raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.') + return metadata + + +def build_metadata_dict( + *, # force keyword arguments to make it easier to modify this signature in a backwards-compatible way + js_functions: list[GetJsonSchemaFunction] | None = None, + js_annotation_functions: list[GetJsonSchemaFunction] | None = None, + js_prefer_positional_arguments: bool | None = None, + typed_dict_cls: type[Any] | None = None, + initial_metadata: Any | None = None, +) -> Any: + """Builds a dict to use as the metadata field of a CoreSchema object in a manner that is consistent + with the CoreMetadataHandler class. + """ + if initial_metadata is not None and not isinstance(initial_metadata, dict): + raise TypeError(f'CoreSchema metadata should be a dict; got {initial_metadata!r}.') + + metadata = CoreMetadata( + pydantic_js_functions=js_functions or [], + pydantic_js_annotation_functions=js_annotation_functions or [], + pydantic_js_prefer_positional_arguments=js_prefer_positional_arguments, + pydantic_typed_dict_cls=typed_dict_cls, + ) + metadata = {k: v for k, v in metadata.items() if v is not None} + + if initial_metadata is not None: + metadata = {**initial_metadata, **metadata} + + return metadata diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_core_utils.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_core_utils.py index cf8cf7c..ebf12ec 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_core_utils.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_core_utils.py @@ -1,20 +1,22 @@ from __future__ import annotations -import inspect import os -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, Union +from collections import defaultdict +from typing import ( + Any, + Callable, + Hashable, + TypeVar, + Union, + _GenericAlias, # type: ignore + cast, +) from pydantic_core import CoreSchema, core_schema from pydantic_core import validate_core_schema as _validate_core_schema -from typing_extensions import TypeGuard, get_args, get_origin -from typing_inspection import typing_objects +from typing_extensions import TypeAliasType, TypeGuard, get_args from . 
import _repr -from ._typing_extra import is_generic_alias - -if TYPE_CHECKING: - from rich.console import Console AnyFunctionSchema = Union[ core_schema.AfterValidatorFunctionSchema, @@ -37,7 +39,19 @@ CoreSchemaOrField = Union[core_schema.CoreSchema, CoreSchemaField] _CORE_SCHEMA_FIELD_TYPES = {'typed-dict-field', 'dataclass-field', 'model-field', 'computed-field'} _FUNCTION_WITH_INNER_SCHEMA_TYPES = {'function-before', 'function-after', 'function-wrap'} -_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'set', 'frozenset'} +_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'tuple-variable', 'set', 'frozenset'} + +_DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache' + +NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY = 'pydantic.internal.needs_apply_discriminated_union' +"""Used to mark a schema that has a discriminated union that needs to be checked for validity at the end of +schema building because one of it's members refers to a definition that was not yet defined when the union +was first encountered. +""" +HAS_INVALID_SCHEMAS_METADATA_KEY = 'pydantic.internal.invalid' +"""Used to mark a schema that is invalid because it refers to a definition that was not yet defined when the +schema was first encountered. +""" def is_core_schema( @@ -60,27 +74,28 @@ def is_function_with_inner_schema( def is_list_like_schema_with_items_schema( schema: CoreSchema, -) -> TypeGuard[core_schema.ListSchema | core_schema.SetSchema | core_schema.FrozenSetSchema]: +) -> TypeGuard[ + core_schema.ListSchema | core_schema.TupleVariableSchema | core_schema.SetSchema | core_schema.FrozenSetSchema +]: return schema['type'] in _LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES -def get_type_ref(type_: Any, args_override: tuple[type[Any], ...] | None = None) -> str: +def get_type_ref(type_: type[Any], args_override: tuple[type[Any], ...] | None = None) -> str: """Produces the ref to be used for this type by pydantic_core's core schemas. This `args_override` argument was added for the purpose of creating valid recursive references when creating generic models without needing to create a concrete class. 
""" - origin = get_origin(type_) or type_ - - args = get_args(type_) if is_generic_alias(type_) else (args_override or ()) + origin = type_ + args = get_args(type_) if isinstance(type_, _GenericAlias) else (args_override or ()) generic_metadata = getattr(type_, '__pydantic_generic_metadata__', None) if generic_metadata: origin = generic_metadata['origin'] or origin args = generic_metadata['args'] or args module_name = getattr(origin, '__module__', '') - if typing_objects.is_typealiastype(origin): - type_ref = f'{module_name}.{origin.__name__}:{id(origin)}' + if isinstance(origin, TypeAliasType): + type_ref = f'{module_name}.{origin.__name__}' else: try: qualname = getattr(origin, '__qualname__', f'') @@ -109,74 +124,457 @@ def get_ref(s: core_schema.CoreSchema) -> None | str: return s.get('ref', None) -def validate_core_schema(schema: CoreSchema) -> CoreSchema: - if os.getenv('PYDANTIC_VALIDATE_CORE_SCHEMAS'): - return _validate_core_schema(schema) +def collect_definitions(schema: core_schema.CoreSchema) -> dict[str, core_schema.CoreSchema]: + defs: dict[str, CoreSchema] = {} + + def _record_valid_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema: + ref = get_ref(s) + if ref: + defs[ref] = s + return recurse(s, _record_valid_refs) + + walk_core_schema(schema, _record_valid_refs) + + return defs + + +def define_expected_missing_refs( + schema: core_schema.CoreSchema, allowed_missing_refs: set[str] +) -> core_schema.CoreSchema | None: + if not allowed_missing_refs: + # in this case, there are no missing refs to potentially substitute, so there's no need to walk the schema + # this is a common case (will be hit for all non-generic models), so it's worth optimizing for + return None + + refs = collect_definitions(schema).keys() + + expected_missing_refs = allowed_missing_refs.difference(refs) + if expected_missing_refs: + definitions: list[core_schema.CoreSchema] = [ + # TODO: Replace this with a (new) CoreSchema that, if present at any level, makes validation fail + # Issue: https://github.com/pydantic/pydantic-core/issues/619 + core_schema.none_schema(ref=ref, metadata={HAS_INVALID_SCHEMAS_METADATA_KEY: True}) + for ref in expected_missing_refs + ] + return core_schema.definitions_schema(schema, definitions) + return None + + +def collect_invalid_schemas(schema: core_schema.CoreSchema) -> bool: + invalid = False + + def _is_schema_valid(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema: + nonlocal invalid + if 'metadata' in s: + metadata = s['metadata'] + if HAS_INVALID_SCHEMAS_METADATA_KEY in metadata: + invalid = metadata[HAS_INVALID_SCHEMAS_METADATA_KEY] + return s + return recurse(s, _is_schema_valid) + + walk_core_schema(schema, _is_schema_valid) + return invalid + + +T = TypeVar('T') + + +Recurse = Callable[[core_schema.CoreSchema, 'Walk'], core_schema.CoreSchema] +Walk = Callable[[core_schema.CoreSchema, Recurse], core_schema.CoreSchema] + +# TODO: Should we move _WalkCoreSchema into pydantic_core proper? 
+# Issue: https://github.com/pydantic/pydantic-core/issues/615 + + +class _WalkCoreSchema: + def __init__(self): + self._schema_type_to_method = self._build_schema_type_to_method() + + def _build_schema_type_to_method(self) -> dict[core_schema.CoreSchemaType, Recurse]: + mapping: dict[core_schema.CoreSchemaType, Recurse] = {} + key: core_schema.CoreSchemaType + for key in get_args(core_schema.CoreSchemaType): + method_name = f"handle_{key.replace('-', '_')}_schema" + mapping[key] = getattr(self, method_name, self._handle_other_schemas) + return mapping + + def walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema: + return f(schema, self._walk) + + def _walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema: + schema = self._schema_type_to_method[schema['type']](schema.copy(), f) + ser_schema: core_schema.SerSchema | None = schema.get('serialization') # type: ignore + if ser_schema: + schema['serialization'] = self._handle_ser_schemas(ser_schema, f) + return schema + + def _handle_other_schemas(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema: + sub_schema = schema.get('schema', None) + if sub_schema is not None: + schema['schema'] = self.walk(sub_schema, f) # type: ignore + return schema + + def _handle_ser_schemas(self, ser_schema: core_schema.SerSchema, f: Walk) -> core_schema.SerSchema: + schema: core_schema.CoreSchema | None = ser_schema.get('schema', None) + if schema is not None: + ser_schema['schema'] = self.walk(schema, f) # type: ignore + return_schema: core_schema.CoreSchema | None = ser_schema.get('return_schema', None) + if return_schema is not None: + ser_schema['return_schema'] = self.walk(return_schema, f) # type: ignore + return ser_schema + + def handle_definitions_schema(self, schema: core_schema.DefinitionsSchema, f: Walk) -> core_schema.CoreSchema: + new_definitions: list[core_schema.CoreSchema] = [] + for definition in schema['definitions']: + updated_definition = self.walk(definition, f) + if 'ref' in updated_definition: + # If the updated definition schema doesn't have a 'ref', it shouldn't go in the definitions + # This is most likely to happen due to replacing something with a definition reference, in + # which case it should certainly not go in the definitions list + new_definitions.append(updated_definition) + new_inner_schema = self.walk(schema['schema'], f) + + if not new_definitions and len(schema) == 3: + # This means we'd be returning a "trivial" definitions schema that just wrapped the inner schema + return new_inner_schema + + new_schema = schema.copy() + new_schema['schema'] = new_inner_schema + new_schema['definitions'] = new_definitions + return new_schema + + def handle_list_schema(self, schema: core_schema.ListSchema, f: Walk) -> core_schema.CoreSchema: + items_schema = schema.get('items_schema') + if items_schema is not None: + schema['items_schema'] = self.walk(items_schema, f) + return schema + + def handle_set_schema(self, schema: core_schema.SetSchema, f: Walk) -> core_schema.CoreSchema: + items_schema = schema.get('items_schema') + if items_schema is not None: + schema['items_schema'] = self.walk(items_schema, f) + return schema + + def handle_frozenset_schema(self, schema: core_schema.FrozenSetSchema, f: Walk) -> core_schema.CoreSchema: + items_schema = schema.get('items_schema') + if items_schema is not None: + schema['items_schema'] = self.walk(items_schema, f) + return schema + + def handle_generator_schema(self, schema: core_schema.GeneratorSchema, f: Walk) -> 
core_schema.CoreSchema: + items_schema = schema.get('items_schema') + if items_schema is not None: + schema['items_schema'] = self.walk(items_schema, f) + return schema + + def handle_tuple_variable_schema( + self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema, f: Walk + ) -> core_schema.CoreSchema: + schema = cast(core_schema.TupleVariableSchema, schema) + items_schema = schema.get('items_schema') + if items_schema is not None: + schema['items_schema'] = self.walk(items_schema, f) + return schema + + def handle_tuple_positional_schema( + self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema, f: Walk + ) -> core_schema.CoreSchema: + schema = cast(core_schema.TuplePositionalSchema, schema) + schema['items_schema'] = [self.walk(v, f) for v in schema['items_schema']] + extras_schema = schema.get('extras_schema') + if extras_schema is not None: + schema['extras_schema'] = self.walk(extras_schema, f) + return schema + + def handle_dict_schema(self, schema: core_schema.DictSchema, f: Walk) -> core_schema.CoreSchema: + keys_schema = schema.get('keys_schema') + if keys_schema is not None: + schema['keys_schema'] = self.walk(keys_schema, f) + values_schema = schema.get('values_schema') + if values_schema: + schema['values_schema'] = self.walk(values_schema, f) + return schema + + def handle_function_schema(self, schema: AnyFunctionSchema, f: Walk) -> core_schema.CoreSchema: + if not is_function_with_inner_schema(schema): + return schema + schema['schema'] = self.walk(schema['schema'], f) + return schema + + def handle_union_schema(self, schema: core_schema.UnionSchema, f: Walk) -> core_schema.CoreSchema: + new_choices: list[CoreSchema | tuple[CoreSchema, str]] = [] + for v in schema['choices']: + if isinstance(v, tuple): + new_choices.append((self.walk(v[0], f), v[1])) + else: + new_choices.append(self.walk(v, f)) + schema['choices'] = new_choices + return schema + + def handle_tagged_union_schema(self, schema: core_schema.TaggedUnionSchema, f: Walk) -> core_schema.CoreSchema: + new_choices: dict[Hashable, core_schema.CoreSchema] = {} + for k, v in schema['choices'].items(): + new_choices[k] = v if isinstance(v, (str, int)) else self.walk(v, f) + schema['choices'] = new_choices + return schema + + def handle_chain_schema(self, schema: core_schema.ChainSchema, f: Walk) -> core_schema.CoreSchema: + schema['steps'] = [self.walk(v, f) for v in schema['steps']] + return schema + + def handle_lax_or_strict_schema(self, schema: core_schema.LaxOrStrictSchema, f: Walk) -> core_schema.CoreSchema: + schema['lax_schema'] = self.walk(schema['lax_schema'], f) + schema['strict_schema'] = self.walk(schema['strict_schema'], f) + return schema + + def handle_json_or_python_schema(self, schema: core_schema.JsonOrPythonSchema, f: Walk) -> core_schema.CoreSchema: + schema['json_schema'] = self.walk(schema['json_schema'], f) + schema['python_schema'] = self.walk(schema['python_schema'], f) + return schema + + def handle_model_fields_schema(self, schema: core_schema.ModelFieldsSchema, f: Walk) -> core_schema.CoreSchema: + extras_schema = schema.get('extras_schema') + if extras_schema is not None: + schema['extras_schema'] = self.walk(extras_schema, f) + replaced_fields: dict[str, core_schema.ModelField] = {} + replaced_computed_fields: list[core_schema.ComputedField] = [] + for computed_field in schema.get('computed_fields', ()): + replaced_field = computed_field.copy() + replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f) + 
replaced_computed_fields.append(replaced_field) + if replaced_computed_fields: + schema['computed_fields'] = replaced_computed_fields + for k, v in schema['fields'].items(): + replaced_field = v.copy() + replaced_field['schema'] = self.walk(v['schema'], f) + replaced_fields[k] = replaced_field + schema['fields'] = replaced_fields + return schema + + def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema, f: Walk) -> core_schema.CoreSchema: + extras_schema = schema.get('extras_schema') + if extras_schema is not None: + schema['extras_schema'] = self.walk(extras_schema, f) + replaced_computed_fields: list[core_schema.ComputedField] = [] + for computed_field in schema.get('computed_fields', ()): + replaced_field = computed_field.copy() + replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f) + replaced_computed_fields.append(replaced_field) + if replaced_computed_fields: + schema['computed_fields'] = replaced_computed_fields + replaced_fields: dict[str, core_schema.TypedDictField] = {} + for k, v in schema['fields'].items(): + replaced_field = v.copy() + replaced_field['schema'] = self.walk(v['schema'], f) + replaced_fields[k] = replaced_field + schema['fields'] = replaced_fields + return schema + + def handle_dataclass_args_schema(self, schema: core_schema.DataclassArgsSchema, f: Walk) -> core_schema.CoreSchema: + replaced_fields: list[core_schema.DataclassField] = [] + replaced_computed_fields: list[core_schema.ComputedField] = [] + for computed_field in schema.get('computed_fields', ()): + replaced_field = computed_field.copy() + replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f) + replaced_computed_fields.append(replaced_field) + if replaced_computed_fields: + schema['computed_fields'] = replaced_computed_fields + for field in schema['fields']: + replaced_field = field.copy() + replaced_field['schema'] = self.walk(field['schema'], f) + replaced_fields.append(replaced_field) + schema['fields'] = replaced_fields + return schema + + def handle_arguments_schema(self, schema: core_schema.ArgumentsSchema, f: Walk) -> core_schema.CoreSchema: + replaced_arguments_schema: list[core_schema.ArgumentsParameter] = [] + for param in schema['arguments_schema']: + replaced_param = param.copy() + replaced_param['schema'] = self.walk(param['schema'], f) + replaced_arguments_schema.append(replaced_param) + schema['arguments_schema'] = replaced_arguments_schema + if 'var_args_schema' in schema: + schema['var_args_schema'] = self.walk(schema['var_args_schema'], f) + if 'var_kwargs_schema' in schema: + schema['var_kwargs_schema'] = self.walk(schema['var_kwargs_schema'], f) + return schema + + def handle_call_schema(self, schema: core_schema.CallSchema, f: Walk) -> core_schema.CoreSchema: + schema['arguments_schema'] = self.walk(schema['arguments_schema'], f) + if 'return_schema' in schema: + schema['return_schema'] = self.walk(schema['return_schema'], f) + return schema + + +_dispatch = _WalkCoreSchema().walk + + +def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema: + """Recursively traverse a CoreSchema. + + Args: + schema (core_schema.CoreSchema): The CoreSchema to process, it will not be modified. + f (Walk): A function to apply. This function takes two arguments: + 1. The current CoreSchema that is being processed + (not the same one you passed into this function, one level down). + 2. The "next" `f` to call. 
This lets you for example use `f=functools.partial(some_method, some_context)` + to pass data down the recursive calls without using globals or other mutable state. + + Returns: + core_schema.CoreSchema: A processed CoreSchema. + """ + return f(schema.copy(), _dispatch) + + +def simplify_schema_references(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: # noqa: C901 + definitions: dict[str, core_schema.CoreSchema] = {} + ref_counts: dict[str, int] = defaultdict(int) + involved_in_recursion: dict[str, bool] = {} + current_recursion_ref_count: dict[str, int] = defaultdict(int) + + def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema: + if s['type'] == 'definitions': + for definition in s['definitions']: + ref = get_ref(definition) + assert ref is not None + if ref not in definitions: + definitions[ref] = definition + recurse(definition, collect_refs) + return recurse(s['schema'], collect_refs) + else: + ref = get_ref(s) + if ref is not None: + new = recurse(s, collect_refs) + new_ref = get_ref(new) + if new_ref: + definitions[new_ref] = new + return core_schema.definition_reference_schema(schema_ref=ref) + else: + return recurse(s, collect_refs) + + schema = walk_core_schema(schema, collect_refs) + + def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema: + if s['type'] != 'definition-ref': + return recurse(s, count_refs) + ref = s['schema_ref'] + ref_counts[ref] += 1 + + if ref_counts[ref] >= 2: + # If this model is involved in a recursion this should be detected + # on its second encounter, we can safely stop the walk here. + if current_recursion_ref_count[ref] != 0: + involved_in_recursion[ref] = True + return s + + current_recursion_ref_count[ref] += 1 + recurse(definitions[ref], count_refs) + current_recursion_ref_count[ref] -= 1 + return s + + schema = walk_core_schema(schema, count_refs) + + assert all(c == 0 for c in current_recursion_ref_count.values()), 'this is a bug! please report it' + + def can_be_inlined(s: core_schema.DefinitionReferenceSchema, ref: str) -> bool: + if ref_counts[ref] > 1: + return False + if involved_in_recursion.get(ref, False): + return False + if 'serialization' in s: + return False + if 'metadata' in s: + metadata = s['metadata'] + for k in ( + 'pydantic_js_functions', + 'pydantic_js_annotation_functions', + 'pydantic.internal.union_discriminator', + ): + if k in metadata: + # we need to keep this as a ref + return False + return True + + def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema: + if s['type'] == 'definition-ref': + ref = s['schema_ref'] + # Check if the reference is only used once, not involved in recursion and does not have + # any extra keys (like 'serialization') + if can_be_inlined(s, ref): + # Inline the reference by replacing the reference with the actual schema + new = definitions.pop(ref) + ref_counts[ref] -= 1 # because we just replaced it! 
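# (Illustrative sketch, hypothetical ref name 'Thing:1', not taken from the diff: the net
# effect of this inlining pass on a single-use, non-recursive ref. A schema of the shape
#   definitions(schema=definition-ref('Thing:1'), definitions=[int_schema(ref='Thing:1')])
# collapses to just
#   int_schema(ref='Thing:1')
# because the lone definition-ref is replaced by its definition body, and the now-empty
# definitions wrapper is dropped at the end of simplify_schema_references.)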
+ # put all other keys that were on the def-ref schema into the inlined version + # in particular this is needed for `serialization` + if 'serialization' in s: + new['serialization'] = s['serialization'] + s = recurse(new, inline_refs) + return s + else: + return recurse(s, inline_refs) + else: + return recurse(s, inline_refs) + + schema = walk_core_schema(schema, inline_refs) + + def_values = [v for v in definitions.values() if ref_counts[v['ref']] > 0] # type: ignore + + if def_values: + schema = core_schema.definitions_schema(schema=schema, definitions=def_values) return schema -def _clean_schema_for_pretty_print(obj: Any, strip_metadata: bool = True) -> Any: # pragma: no cover - """A utility function to remove irrelevant information from a core schema.""" - if isinstance(obj, Mapping): - new_dct = {} - for k, v in obj.items(): - if k == 'metadata' and strip_metadata: - new_metadata = {} - - for meta_k, meta_v in v.items(): - if meta_k in ('pydantic_js_functions', 'pydantic_js_annotation_functions'): - new_metadata['js_metadata'] = '' - else: - new_metadata[meta_k] = _clean_schema_for_pretty_print(meta_v, strip_metadata=strip_metadata) - - if list(new_metadata.keys()) == ['js_metadata']: - new_metadata = {''} - - new_dct[k] = new_metadata - # Remove some defaults: - elif k in ('custom_init', 'root_model') and not v: - continue +def _strip_metadata(schema: CoreSchema) -> CoreSchema: + def strip_metadata(s: CoreSchema, recurse: Recurse) -> CoreSchema: + s = s.copy() + s.pop('metadata', None) + if s['type'] == 'model-fields': + s = s.copy() + s['fields'] = {k: v.copy() for k, v in s['fields'].items()} + for field_name, field_schema in s['fields'].items(): + field_schema.pop('metadata', None) + s['fields'][field_name] = field_schema + computed_fields = s.get('computed_fields', None) + if computed_fields: + s['computed_fields'] = [cf.copy() for cf in computed_fields] + for cf in computed_fields: + cf.pop('metadata', None) else: - new_dct[k] = _clean_schema_for_pretty_print(v, strip_metadata=strip_metadata) + s.pop('computed_fields', None) + elif s['type'] == 'model': + # remove some defaults + if s.get('custom_init', True) is False: + s.pop('custom_init') + if s.get('root_model', True) is False: + s.pop('root_model') + if {'title'}.issuperset(s.get('config', {}).keys()): + s.pop('config', None) - return new_dct - elif isinstance(obj, Sequence) and not isinstance(obj, str): - return [_clean_schema_for_pretty_print(v, strip_metadata=strip_metadata) for v in obj] - else: - return obj + return recurse(s, strip_metadata) + + return walk_core_schema(schema, strip_metadata) def pretty_print_core_schema( - val: Any, - *, - console: Console | None = None, - max_depth: int | None = None, - strip_metadata: bool = True, -) -> None: # pragma: no cover - """Pretty-print a core schema using the `rich` library. + schema: CoreSchema, + include_metadata: bool = False, +) -> None: + """Pretty print a CoreSchema using rich. + This is intended for debugging purposes. Args: - val: The core schema to print, or a Pydantic model/dataclass/type adapter - (in which case the cached core schema is fetched and printed). - console: A rich console to use when printing. Defaults to the global rich console instance. - max_depth: The number of nesting levels which may be printed. - strip_metadata: Whether to strip metadata in the output. If `True` any known core metadata - attributes will be stripped (but custom attributes are kept). Defaults to `True`. + schema: The CoreSchema to print. 
+ include_metadata: Whether to include metadata in the output. Defaults to `False`. """ - # lazy import: - from rich.pretty import pprint + from rich import print # type: ignore # install it manually in your dev env - # circ. imports: - from pydantic import BaseModel, TypeAdapter - from pydantic.dataclasses import is_pydantic_dataclass + if not include_metadata: + schema = _strip_metadata(schema) - if (inspect.isclass(val) and issubclass(val, BaseModel)) or is_pydantic_dataclass(val): - val = val.__pydantic_core_schema__ - if isinstance(val, TypeAdapter): - val = val.core_schema - cleaned_schema = _clean_schema_for_pretty_print(val, strip_metadata=strip_metadata) - - pprint(cleaned_schema, console=console, max_depth=max_depth) + return print(schema) -pps = pretty_print_core_schema +def validate_core_schema(schema: CoreSchema) -> CoreSchema: + if 'PYDANTIC_SKIP_VALIDATING_CORE_SCHEMAS' in os.environ: + return schema + return _validate_core_schema(schema) diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_dataclasses.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_dataclasses.py index 03f156f..430e3a9 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_dataclasses.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_dataclasses.py @@ -1,15 +1,17 @@ """Private logic for creating pydantic dataclasses.""" - from __future__ import annotations as _annotations import dataclasses +import inspect import typing import warnings from functools import partial, wraps -from typing import Any, ClassVar +from inspect import Parameter, Signature, signature +from typing import Any, Callable, ClassVar from pydantic_core import ( ArgsKwargs, + PydanticUndefined, SchemaSerializer, SchemaValidator, core_schema, @@ -17,22 +19,28 @@ from pydantic_core import ( from typing_extensions import TypeGuard from ..errors import PydanticUndefinedAnnotation -from ..plugin._schema_validator import PluggableSchemaValidator, create_schema_validator +from ..fields import FieldInfo +from ..plugin._schema_validator import create_schema_validator from ..warnings import PydanticDeprecatedSince20 -from . import _config, _decorators +from . import _config, _decorators, _discriminated_union, _typing_extra +from ._core_utils import collect_invalid_schemas, simplify_schema_references, validate_core_schema from ._fields import collect_dataclass_fields -from ._generate_schema import GenerateSchema, InvalidSchemaError +from ._generate_schema import GenerateSchema from ._generics import get_standard_typevars_map -from ._mock_val_ser import set_dataclass_mocks -from ._namespace_utils import NsResolver -from ._signature import generate_pydantic_signature -from ._utils import LazyClassAttribute +from ._mock_val_ser import set_dataclass_mock_validator +from ._schema_generation_shared import CallbackGetCoreSchemaHandler +from ._utils import is_valid_identifier if typing.TYPE_CHECKING: - from _typeshed import DataclassInstance as StandardDataclass - from ..config import ConfigDict - from ..fields import FieldInfo + + class StandardDataclass(typing.Protocol): + __dataclass_fields__: ClassVar[dict[str, Any]] + __dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams` + __post_init__: ClassVar[Callable[..., None]] + + def __init__(self, *args: object, **kwargs: object) -> None: + pass class PydanticDataclass(StandardDataclass, typing.Protocol): """A protocol containing attributes only available once a class has been decorated as a Pydantic dataclass. 
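# A small illustrative sketch (hypothetical `Point` class, not from the diff) of what the
# StandardDataclass protocol above matches structurally: any class processed by the stdlib
# @dataclasses.dataclass decorator gains the attributes the protocol declares.
import dataclasses


@dataclasses.dataclass
class Point:
    x: int
    y: int = 0


# Both attributes the protocol relies on are present on a plain stdlib dataclass.
assert hasattr(Point, '__dataclass_fields__')
assert hasattr(Point, '__dataclass_params__')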
@@ -53,10 +61,7 @@ if typing.TYPE_CHECKING: __pydantic_decorators__: ClassVar[_decorators.DecoratorInfos] __pydantic_fields__: ClassVar[dict[str, FieldInfo]] __pydantic_serializer__: ClassVar[SchemaSerializer] - __pydantic_validator__: ClassVar[SchemaValidator | PluggableSchemaValidator] - - @classmethod - def __pydantic_fields_complete__(cls) -> bool: ... + __pydantic_validator__: ClassVar[SchemaValidator] else: # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915 @@ -64,22 +69,15 @@ else: DeprecationWarning = PydanticDeprecatedSince20 -def set_dataclass_fields( - cls: type[StandardDataclass], - ns_resolver: NsResolver | None = None, - config_wrapper: _config.ConfigWrapper | None = None, -) -> None: +def set_dataclass_fields(cls: type[StandardDataclass], types_namespace: dict[str, Any] | None = None) -> None: """Collect and set `cls.__pydantic_fields__`. Args: cls: The class. - ns_resolver: Namespace resolver to use when getting dataclass annotations. - config_wrapper: The config wrapper instance, defaults to `None`. + types_namespace: The types namespace, defaults to `None`. """ typevars_map = get_standard_typevars_map(cls) - fields = collect_dataclass_fields( - cls, ns_resolver=ns_resolver, typevars_map=typevars_map, config_wrapper=config_wrapper - ) + fields = collect_dataclass_fields(cls, types_namespace, typevars_map=typevars_map) cls.__pydantic_fields__ = fields # type: ignore @@ -89,8 +87,7 @@ def complete_dataclass( config_wrapper: _config.ConfigWrapper, *, raise_errors: bool = True, - ns_resolver: NsResolver | None = None, - _force_build: bool = False, + types_namespace: dict[str, Any] | None, ) -> bool: """Finish building a pydantic dataclass. @@ -102,10 +99,7 @@ def complete_dataclass( cls: The class. config_wrapper: The config wrapper instance. raise_errors: Whether to raise errors, defaults to `True`. - ns_resolver: The namespace resolver instance to use when collecting dataclass fields - and during schema building. - _force_build: Whether to force building the dataclass, no matter if - [`defer_build`][pydantic.config.ConfigDict.defer_build] is set. + types_namespace: The types namespace. Returns: `True` if building a pydantic dataclass is successfully completed, `False` otherwise. @@ -113,94 +107,136 @@ def complete_dataclass( Raises: PydanticUndefinedAnnotation: If `raise_error` is `True` and there is an undefined annotations. """ - original_init = cls.__init__ + if hasattr(cls, '__post_init_post_parse__'): + warnings.warn( + 'Support for `__post_init_post_parse__` has been dropped, the method will not be called', DeprecationWarning + ) + + if types_namespace is None: + types_namespace = _typing_extra.get_cls_types_namespace(cls) + + set_dataclass_fields(cls, types_namespace) + + typevars_map = get_standard_typevars_map(cls) + gen_schema = GenerateSchema( + config_wrapper, + types_namespace, + typevars_map, + ) + + # dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied. 
- # dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied, - # and so that the mock validator is used if building was deferred: def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None: __tracebackhide__ = True s = __dataclass_self__ s.__pydantic_validator__.validate_python(ArgsKwargs(args, kwargs), self_instance=s) __init__.__qualname__ = f'{cls.__qualname__}.__init__' - + sig = generate_dataclass_signature(cls) cls.__init__ = __init__ # type: ignore + cls.__signature__ = sig # type: ignore cls.__pydantic_config__ = config_wrapper.config_dict # type: ignore - set_dataclass_fields(cls, ns_resolver, config_wrapper=config_wrapper) - - if not _force_build and config_wrapper.defer_build: - set_dataclass_mocks(cls) - return False - - if hasattr(cls, '__post_init_post_parse__'): - warnings.warn( - 'Support for `__post_init_post_parse__` has been dropped, the method will not be called', DeprecationWarning - ) - - typevars_map = get_standard_typevars_map(cls) - gen_schema = GenerateSchema( - config_wrapper, - ns_resolver=ns_resolver, - typevars_map=typevars_map, - ) - - # set __signature__ attr only for the class, but not for its instances - # (because instances can define `__call__`, and `inspect.signature` shouldn't - # use the `__signature__` attribute and instead generate from `__call__`). - cls.__signature__ = LazyClassAttribute( - '__signature__', - partial( - generate_pydantic_signature, - # It's important that we reference the `original_init` here, - # as it is the one synthesized by the stdlib `dataclass` module: - init=original_init, - fields=cls.__pydantic_fields__, # type: ignore - validate_by_name=config_wrapper.validate_by_name, - extra=config_wrapper.extra, - is_dataclass=True, - ), - ) - + get_core_schema = getattr(cls, '__get_pydantic_core_schema__', None) try: - schema = gen_schema.generate_schema(cls) + if get_core_schema: + schema = get_core_schema( + cls, + CallbackGetCoreSchemaHandler( + partial(gen_schema.generate_schema, from_dunder_get_core_schema=False), + gen_schema, + ref_mode='unpack', + ), + ) + else: + schema = gen_schema.generate_schema(cls, from_dunder_get_core_schema=False) except PydanticUndefinedAnnotation as e: if raise_errors: raise - set_dataclass_mocks(cls, f'`{e.name}`') + set_dataclass_mock_validator(cls, cls.__name__, f'`{e.name}`') return False - core_config = config_wrapper.core_config(title=cls.__name__) + core_config = config_wrapper.core_config(cls) - try: - schema = gen_schema.clean_schema(schema) - except InvalidSchemaError: - set_dataclass_mocks(cls) + schema = gen_schema.collect_definitions(schema) + if collect_invalid_schemas(schema): + set_dataclass_mock_validator(cls, cls.__name__, 'all referenced types') return False + schema = _discriminated_union.apply_discriminators(simplify_schema_references(schema)) + # We are about to set all the remaining required properties expected for this cast; # __pydantic_decorators__ and __pydantic_fields__ should already be set cls = typing.cast('type[PydanticDataclass]', cls) # debug(schema) - cls.__pydantic_core_schema__ = schema + cls.__pydantic_core_schema__ = schema = validate_core_schema(schema) cls.__pydantic_validator__ = validator = create_schema_validator( - schema, cls, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings + schema, core_config, config_wrapper.plugin_settings ) cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config) if 
config_wrapper.validate_assignment: @wraps(cls.__setattr__) - def validated_setattr(instance: Any, field: str, value: str, /) -> None: - validator.validate_assignment(instance, field, value) + def validated_setattr(instance: Any, __field: str, __value: str) -> None: + validator.validate_assignment(instance, __field, __value) cls.__setattr__ = validated_setattr.__get__(None, cls) # type: ignore - cls.__pydantic_complete__ = True return True +def generate_dataclass_signature(cls: type[StandardDataclass]) -> Signature: + """Generate signature for a pydantic dataclass. + + This implementation assumes we do not support custom `__init__`, which is currently true for pydantic dataclasses. + If we change this eventually, we should make this function's logic more closely mirror that from + `pydantic._internal._model_construction.generate_model_signature`. + + Args: + cls: The dataclass. + + Returns: + The signature. + """ + sig = signature(cls) + final_params: dict[str, Parameter] = {} + + for param in sig.parameters.values(): + param_default = param.default + if isinstance(param_default, FieldInfo): + annotation = param.annotation + # Replace the annotation if appropriate + # inspect does "clever" things to show annotations as strings because we have + # `from __future__ import annotations` in main, we don't want that + if annotation == 'Any': + annotation = Any + + # Replace the field name with the alias if present + name = param.name + alias = param_default.alias + validation_alias = param_default.validation_alias + if validation_alias is None and isinstance(alias, str) and is_valid_identifier(alias): + name = alias + elif isinstance(validation_alias, str) and is_valid_identifier(validation_alias): + name = validation_alias + + # Replace the field default + default = param_default.default + if default is PydanticUndefined: + if param_default.default_factory is PydanticUndefined: + default = inspect.Signature.empty + else: + # this is used by dataclasses to indicate a factory exists: + default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore + + param = param.replace(annotation=annotation, name=name, default=default) + final_params[param.name] = param + + return Signature(parameters=list(final_params.values()), return_annotation=None) + + def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]: """Returns True if a class is a stdlib dataclass and *not* a pydantic dataclass. @@ -209,7 +245,7 @@ def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]: - `_cls` does not inherit from a processed pydantic dataclass (and thus have a `__pydantic_validator__`) - `_cls` does not have any annotations that are not dataclass fields e.g. - ```python + ```py import dataclasses import pydantic.dataclasses diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_decorators.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_decorators.py index 92880a4..0ac3248 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_decorators.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_decorators.py @@ -1,30 +1,31 @@ """Logic related to validators applied to models etc. 
via the `@field_validator` and `@model_validator` decorators.""" - from __future__ import annotations as _annotations -import types from collections import deque -from collections.abc import Iterable from dataclasses import dataclass, field -from functools import cached_property, partial, partialmethod +from functools import partial, partialmethod from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor, signature from itertools import islice -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Literal, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Iterable, TypeVar, Union -from pydantic_core import PydanticUndefined, PydanticUndefinedType, core_schema -from typing_extensions import TypeAlias, is_typeddict +from pydantic_core import PydanticUndefined, core_schema +from typing_extensions import Literal, TypeAlias, is_typeddict from ..errors import PydanticUserError +from ..fields import ComputedFieldInfo from ._core_utils import get_type_ref from ._internal_dataclass import slots_true -from ._namespace_utils import GlobalsNamespace, MappingNamespace from ._typing_extra import get_function_type_hints -from ._utils import can_be_positional if TYPE_CHECKING: - from ..fields import ComputedFieldInfo from ..functional_validators import FieldValidatorModes +try: + from functools import cached_property # type: ignore +except ImportError: + # python 3.7 + cached_property = None + @dataclass(**slots_true) class ValidatorDecoratorInfo: @@ -60,9 +61,6 @@ class FieldValidatorDecoratorInfo: fields: A tuple of field names the validator should be called on. mode: The proposed validator mode. check_fields: Whether to check that the fields actually exist on the model. - json_schema_input_type: The input type of the function. This is only used to generate - the appropriate JSON Schema (in validation mode) and can only specified - when `mode` is either `'before'`, `'plain'` or `'wrap'`. """ decorator_repr: ClassVar[str] = '@field_validator' @@ -70,7 +68,6 @@ class FieldValidatorDecoratorInfo: fields: tuple[str, ...] mode: FieldValidatorModes check_fields: bool | None - json_schema_input_type: Any @dataclass(**slots_true) @@ -135,7 +132,7 @@ class ModelValidatorDecoratorInfo: while building the pydantic-core schema. Attributes: - decorator_repr: A class variable representing the decorator string, '@model_validator'. + decorator_repr: A class variable representing the decorator string, '@model_serializer'. mode: The proposed serializer mode. 
""" @@ -143,7 +140,7 @@ class ModelValidatorDecoratorInfo: mode: Literal['wrap', 'before', 'after'] -DecoratorInfo: TypeAlias = """Union[ +DecoratorInfo = Union[ ValidatorDecoratorInfo, FieldValidatorDecoratorInfo, RootValidatorDecoratorInfo, @@ -151,7 +148,7 @@ DecoratorInfo: TypeAlias = """Union[ ModelSerializerDecoratorInfo, ModelValidatorDecoratorInfo, ComputedFieldInfo, -]""" +] ReturnType = TypeVar('ReturnType') DecoratedType: TypeAlias = ( @@ -186,12 +183,6 @@ class PydanticDescriptorProxy(Generic[ReturnType]): def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]: self.wrapped = getattr(self.wrapped, name)(func) - if isinstance(self.wrapped, property): - # update ComputedFieldInfo.wrapped_property - from ..fields import ComputedFieldInfo - - if isinstance(self.decorator_info, ComputedFieldInfo): - self.decorator_info.wrapped_property = self.wrapped return self def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]: @@ -203,11 +194,11 @@ class PydanticDescriptorProxy(Generic[ReturnType]): def __set_name__(self, instance: Any, name: str) -> None: if hasattr(self.wrapped, '__set_name__'): - self.wrapped.__set_name__(instance, name) # pyright: ignore[reportFunctionMemberAccess] + self.wrapped.__set_name__(instance, name) - def __getattr__(self, name: str, /) -> Any: + def __getattr__(self, __name: str) -> Any: """Forward checks for __isabstractmethod__ and such.""" - return getattr(self.wrapped, name) + return getattr(self.wrapped, __name) DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo) @@ -497,8 +488,6 @@ class DecoratorInfos: model_dc, cls_var_name=var_name, shim=var_value.shim, info=info ) else: - from ..fields import ComputedFieldInfo - isinstance(var_value, ComputedFieldInfo) res.computed_fields[var_name] = Decorator.build( model_dc, cls_var_name=var_name, shim=None, info=info @@ -509,7 +498,7 @@ class DecoratorInfos: # so then we don't need to re-process the type, which means we can discard our descriptor wrappers # and replace them with the thing they are wrapping (see the other setattr call below) # which allows validator class methods to also function as regular class methods - model_dc.__pydantic_decorators__ = res + setattr(model_dc, '__pydantic_decorators__', res) for name, value in to_replace: setattr(model_dc, name, value) return res @@ -529,11 +518,12 @@ def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes) """ try: sig = signature(validator) - except (ValueError, TypeError): - # `inspect.signature` might not be able to infer a signature, e.g. with C objects. - # In this case, we assume no info argument is present: + except ValueError: + # builtins and some C extensions don't have signatures + # assume that they don't take an info argument and only take a single argument + # e.g. 
`str.strip` or `datetime.datetime` return False - n_positional = count_positional_required_params(sig) + n_positional = count_positional_params(sig) if mode == 'wrap': if n_positional == 3: return True @@ -552,7 +542,9 @@ def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes) ) -def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> tuple[bool, bool]: +def inspect_field_serializer( + serializer: Callable[..., Any], mode: Literal['plain', 'wrap'], computed_field: bool = False +) -> tuple[bool, bool]: """Look at a field serializer function and determine if it is a field serializer, and whether it takes an info argument. @@ -561,21 +553,18 @@ def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plai Args: serializer: The serializer function to inspect. mode: The serializer mode, either 'plain' or 'wrap'. + computed_field: When serializer is applied on computed_field. It doesn't require + info signature. Returns: Tuple of (is_field_serializer, info_arg). """ - try: - sig = signature(serializer) - except (ValueError, TypeError): - # `inspect.signature` might not be able to infer a signature, e.g. with C objects. - # In this case, we assume no info argument is present and this is not a method: - return (False, False) + sig = signature(serializer) first = next(iter(sig.parameters.values()), None) is_field_serializer = first is not None and first.name == 'self' - n_positional = count_positional_required_params(sig) + n_positional = count_positional_params(sig) if is_field_serializer: # -1 to correct for self parameter info_arg = _serializer_info_arg(mode, n_positional - 1) @@ -587,8 +576,13 @@ def inspect_field_serializer(serializer: Callable[..., Any], mode: Literal['plai f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}', code='field-serializer-signature', ) + if info_arg and computed_field: + raise PydanticUserError( + 'field_serializer on computed_field does not use info signature', code='field-serializer-signature' + ) - return is_field_serializer, info_arg + else: + return is_field_serializer, info_arg def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool: @@ -603,13 +597,8 @@ def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal[' Returns: info_arg """ - try: - sig = signature(serializer) - except (ValueError, TypeError): - # `inspect.signature` might not be able to infer a signature, e.g. with C objects. 
- # In this case, we assume no info argument is present: - return False - info_arg = _serializer_info_arg(mode, count_positional_required_params(sig)) + sig = signature(serializer) + info_arg = _serializer_info_arg(mode, count_positional_params(sig)) if info_arg is None: raise PydanticUserError( f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}', @@ -637,7 +626,7 @@ def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plai ) sig = signature(serializer) - info_arg = _serializer_info_arg(mode, count_positional_required_params(sig)) + info_arg = _serializer_info_arg(mode, count_positional_params(sig)) if info_arg is None: raise PydanticUserError( f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}', @@ -650,18 +639,18 @@ def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plai def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None: if mode == 'plain': if n_positional == 1: - # (input_value: Any, /) -> Any + # (__input_value: Any) -> Any return False elif n_positional == 2: - # (model: Any, input_value: Any, /) -> Any + # (__model: Any, __input_value: Any) -> Any return True else: assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'" if n_positional == 2: - # (input_value: Any, serializer: SerializerFunctionWrapHandler, /) -> Any + # (__input_value: Any, __serializer: SerializerFunctionWrapHandler) -> Any return False elif n_positional == 3: - # (input_value: Any, serializer: SerializerFunctionWrapHandler, info: SerializationInfo, /) -> Any + # (__input_value: Any, __serializer: SerializerFunctionWrapHandler, __info: SerializationInfo) -> Any return True return None @@ -722,25 +711,34 @@ def unwrap_wrapped_function( unwrap_class_static_method: bool = True, ) -> Any: """Recursively unwraps a wrapped function until the underlying function is reached. - This handles property, functools.partial, functools.partialmethod, staticmethod, and classmethod. + This handles property, functools.partial, functools.partialmethod, staticmethod and classmethod. Args: func: The function to unwrap. - unwrap_partial: If True (default), unwrap partial and partialmethod decorators. + unwrap_partial: If True (default), unwrap partial and partialmethod decorators, otherwise don't. + decorators. unwrap_class_static_method: If True (default), also unwrap classmethod and staticmethod decorators. If False, only unwrap partial and partialmethod decorators. Returns: The underlying function of the wrapped function. """ - # Define the types we want to check against as a single tuple. 
- unwrap_types = ( - (property, cached_property) - + ((partial, partialmethod) if unwrap_partial else ()) - + ((staticmethod, classmethod) if unwrap_class_static_method else ()) - ) + all: set[Any] = {property} - while isinstance(func, unwrap_types): + if unwrap_partial: + all.update({partial, partialmethod}) + + try: + from functools import cached_property # type: ignore + except ImportError: + cached_property = type('', (), {}) + else: + all.add(cached_property) + + if unwrap_class_static_method: + all.update({staticmethod, classmethod}) + + while isinstance(func, tuple(all)): if unwrap_class_static_method and isinstance(func, (classmethod, staticmethod)): func = func.__func__ elif isinstance(func, (partial, partialmethod)): @@ -755,72 +753,38 @@ def unwrap_wrapped_function( return func -_function_like = ( - partial, - partialmethod, - types.FunctionType, - types.BuiltinFunctionType, - types.MethodType, - types.WrapperDescriptorType, - types.MethodWrapperType, - types.MemberDescriptorType, -) +def get_function_return_type( + func: Any, explicit_return_type: Any, types_namespace: dict[str, Any] | None = None +) -> Any: + """Get the function return type. - -def get_callable_return_type( - callable_obj: Any, - globalns: GlobalsNamespace | None = None, - localns: MappingNamespace | None = None, -) -> Any | PydanticUndefinedType: - """Get the callable return type. + It gets the return type from the type annotation if `explicit_return_type` is `None`. + Otherwise, it returns `explicit_return_type`. Args: - callable_obj: The callable to analyze. - globalns: The globals namespace to use during type annotation evaluation. - localns: The locals namespace to use during type annotation evaluation. + func: The function to get its return type. + explicit_return_type: The explicit return type. + types_namespace: The types namespace, defaults to `None`. Returns: The function return type. """ - if isinstance(callable_obj, type): - # types are callables, and we assume the return type - # is the type itself (e.g. `int()` results in an instance of `int`). - return callable_obj - - if not isinstance(callable_obj, _function_like): - call_func = getattr(type(callable_obj), '__call__', None) # noqa: B004 - if call_func is not None: - callable_obj = call_func - - hints = get_function_type_hints( - unwrap_wrapped_function(callable_obj), - include_keys={'return'}, - globalns=globalns, - localns=localns, - ) - return hints.get('return', PydanticUndefined) + if explicit_return_type is PydanticUndefined: + # try to get it from the type annotation + hints = get_function_type_hints( + unwrap_wrapped_function(func), include_keys={'return'}, types_namespace=types_namespace + ) + return hints.get('return', PydanticUndefined) + else: + return explicit_return_type -def count_positional_required_params(sig: Signature) -> int: - """Get the number of positional (required) arguments of a signature. +def count_positional_params(sig: Signature) -> int: + return sum(1 for param in sig.parameters.values() if can_be_positional(param)) - This function should only be used to inspect signatures of validation and serialization functions. - The first argument (the value being serialized or validated) is counted as a required argument - even if a default value exists. - Returns: - The number of positional arguments of a signature. 
- """ - parameters = list(sig.parameters.values()) - return sum( - 1 - for param in parameters - if can_be_positional(param) - # First argument is the value being validated/serialized, and can have a default value - # (e.g. `float`, which has signature `(x=0, /)`). We assume other parameters (the info arg - # for instance) should be required, and thus without any default value. - and (param.default is Parameter.empty or param is parameters[0]) - ) +def can_be_positional(param: Parameter) -> bool: + return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD) def ensure_property(f: Any) -> Any: diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_decorators_v1.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_decorators_v1.py index 3427377..4f81e6d 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_decorators_v1.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_decorators_v1.py @@ -1,45 +1,49 @@ """Logic for V1 validators, e.g. `@validator` and `@root_validator`.""" - from __future__ import annotations as _annotations from inspect import Parameter, signature -from typing import Any, Union, cast +from typing import Any, Dict, Tuple, Union, cast from pydantic_core import core_schema from typing_extensions import Protocol from ..errors import PydanticUserError -from ._utils import can_be_positional +from ._decorators import can_be_positional class V1OnlyValueValidator(Protocol): """A simple validator, supported for V1 validators and V2 validators.""" - def __call__(self, __value: Any) -> Any: ... + def __call__(self, __value: Any) -> Any: + ... class V1ValidatorWithValues(Protocol): """A validator with `values` argument, supported for V1 validators and V2 validators.""" - def __call__(self, __value: Any, values: dict[str, Any]) -> Any: ... + def __call__(self, __value: Any, values: dict[str, Any]) -> Any: + ... class V1ValidatorWithValuesKwOnly(Protocol): """A validator with keyword only `values` argument, supported for V1 validators and V2 validators.""" - def __call__(self, __value: Any, *, values: dict[str, Any]) -> Any: ... + def __call__(self, __value: Any, *, values: dict[str, Any]) -> Any: + ... class V1ValidatorWithKwargs(Protocol): """A validator with `kwargs` argument, supported for V1 validators and V2 validators.""" - def __call__(self, __value: Any, **kwargs: Any) -> Any: ... + def __call__(self, __value: Any, **kwargs: Any) -> Any: + ... class V1ValidatorWithValuesAndKwargs(Protocol): """A validator with `values` and `kwargs` arguments, supported for V1 validators and V2 validators.""" - def __call__(self, __value: Any, values: dict[str, Any], **kwargs: Any) -> Any: ... + def __call__(self, __value: Any, values: dict[str, Any], **kwargs: Any) -> Any: + ... V1Validator = Union[ @@ -105,21 +109,23 @@ def make_generic_v1_field_validator(validator: V1Validator) -> core_schema.WithI return wrapper2 -RootValidatorValues = dict[str, Any] +RootValidatorValues = Dict[str, Any] # technically tuple[model_dict, model_extra, fields_set] | tuple[dataclass_dict, init_vars] -RootValidatorFieldsTuple = tuple[Any, ...] +RootValidatorFieldsTuple = Tuple[Any, ...] class V1RootValidatorFunction(Protocol): """A simple root validator, supported for V1 validators and V2 validators.""" - def __call__(self, __values: RootValidatorValues) -> RootValidatorValues: ... + def __call__(self, __values: RootValidatorValues) -> RootValidatorValues: + ... 
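# A rough sketch (hypothetical functions, not from the diff) of the adaptation this module
# performs for root validators: a V1 root validator only receives the values dict, so the
# V2-style core callable simply drops the extra ValidationInfo argument before delegating.
def v1_root_validator(values: dict) -> dict:
    # V1 shape: dict of field values in, (possibly modified) dict out
    values['name'] = values['name'].strip()
    return values


def v2_before_wrapper(values: dict, info: object) -> dict:
    # V2 core 'before' shape: values plus a ValidationInfo that the V1 code never sees
    return v1_root_validator(values)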
class V2CoreBeforeRootValidator(Protocol): """V2 validator with mode='before'.""" - def __call__(self, __values: RootValidatorValues, __info: core_schema.ValidationInfo) -> RootValidatorValues: ... + def __call__(self, __values: RootValidatorValues, __info: core_schema.ValidationInfo) -> RootValidatorValues: + ... class V2CoreAfterRootValidator(Protocol): @@ -127,7 +133,8 @@ class V2CoreAfterRootValidator(Protocol): def __call__( self, __fields_tuple: RootValidatorFieldsTuple, __info: core_schema.ValidationInfo - ) -> RootValidatorFieldsTuple: ... + ) -> RootValidatorFieldsTuple: + ... def make_v1_generic_root_validator( diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_discriminated_union.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_discriminated_union.py index 5dd6fda..4cb9c3d 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_discriminated_union.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_discriminated_union.py @@ -1,19 +1,19 @@ from __future__ import annotations as _annotations -from collections.abc import Hashable, Sequence -from typing import TYPE_CHECKING, Any, cast +from typing import Any, Hashable, Sequence from pydantic_core import CoreSchema, core_schema from ..errors import PydanticUserError from . import _core_utils from ._core_utils import ( + NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, CoreSchemaField, + collect_definitions, + simplify_schema_references, ) -if TYPE_CHECKING: - from ..types import Discriminator - from ._core_metadata import CoreMetadata +CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator' class MissingDefinitionForUnionRef(Exception): @@ -26,15 +26,39 @@ class MissingDefinitionForUnionRef(Exception): super().__init__(f'Missing definition for ref {self.ref!r}') -def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None: - metadata = cast('CoreMetadata', schema.setdefault('metadata', {})) - metadata['pydantic_internal_union_discriminator'] = discriminator +def set_discriminator(schema: CoreSchema, discriminator: Any) -> None: + schema.setdefault('metadata', {}) + metadata = schema.get('metadata') + assert metadata is not None + metadata[CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY] = discriminator + + +def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: + definitions: dict[str, CoreSchema] | None = None + + def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema: + nonlocal definitions + if 'metadata' in s: + if s['metadata'].get(NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, True) is False: + return s + + s = recurse(s, inner) + if s['type'] == 'tagged-union': + return s + + metadata = s.get('metadata', {}) + discriminator = metadata.get(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None) + if discriminator is not None: + if definitions is None: + definitions = collect_definitions(schema) + s = apply_discriminator(s, discriminator, definitions) + return s + + return simplify_schema_references(_core_utils.walk_core_schema(schema, inner)) def apply_discriminator( - schema: core_schema.CoreSchema, - discriminator: str | Discriminator, - definitions: dict[str, core_schema.CoreSchema] | None = None, + schema: core_schema.CoreSchema, discriminator: str, definitions: dict[str, core_schema.CoreSchema] | None = None ) -> core_schema.CoreSchema: """Applies the discriminator and returns a new core schema. 
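# A minimal sketch (hypothetical schemas) of the two-step flow implemented above: the
# discriminator is first stashed in the union schema's metadata via set_discriminator(),
# and a later apply_discriminators() pass rewrites that union into a tagged-union keyed
# on it. The exact output shape may differ; this only illustrates the intended usage.
from pydantic_core import core_schema as cs

cat = cs.typed_dict_schema({'pet_type': cs.typed_dict_field(cs.literal_schema(['cat']))})
dog = cs.typed_dict_schema({'pet_type': cs.typed_dict_field(cs.literal_schema(['dog']))})
union = cs.union_schema([cat, dog])

set_discriminator(union, 'pet_type')   # stored under the metadata placeholder key
tagged = apply_discriminators(union)   # expected to become a 'tagged-union' on 'pet_type'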
@@ -59,14 +83,6 @@ def apply_discriminator( - If discriminator fields have different aliases. - If discriminator field not of type `Literal`. """ - from ..types import Discriminator - - if isinstance(discriminator, Discriminator): - if isinstance(discriminator.discriminator, str): - discriminator = discriminator.discriminator - else: - return discriminator._convert_schema(schema) - return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema) @@ -134,7 +150,7 @@ class _ApplyInferredDiscriminator: # in the output TaggedUnionSchema that will replace the union from the input schema self._tagged_union_choices: dict[Hashable, core_schema.CoreSchema] = {} - # `_used` is changed to True after applying the discriminator to prevent accidental reuse + # `_used` is changed to True after applying the discriminator to prevent accidental re-use self._used = False def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema: @@ -160,11 +176,16 @@ class _ApplyInferredDiscriminator: - If discriminator fields have different aliases. - If discriminator field not of type `Literal`. """ + self.definitions.update(collect_definitions(schema)) assert not self._used schema = self._apply_to_root(schema) if self._should_be_nullable and not self._is_nullable: schema = core_schema.nullable_schema(schema) self._used = True + new_defs = collect_definitions(schema) + missing_defs = self.definitions.keys() - new_defs.keys() + if missing_defs: + schema = core_schema.definitions_schema(schema, [self.definitions[ref] for ref in missing_defs]) return schema def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema: @@ -234,10 +255,6 @@ class _ApplyInferredDiscriminator: * Validating that each allowed discriminator value maps to a unique choice * Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema. """ - if choice['type'] == 'definition-ref': - if choice['schema_ref'] not in self.definitions: - raise MissingDefinitionForUnionRef(choice['schema_ref']) - if choice['type'] == 'none': self._should_be_nullable = True elif choice['type'] == 'definitions': @@ -249,6 +266,10 @@ class _ApplyInferredDiscriminator: # Reverse the choices list before extending the stack so that they get handled in the order they occur choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]] self._choices_to_handle.extend(choices_schemas) + elif choice['type'] == 'definition-ref': + if choice['schema_ref'] not in self.definitions: + raise MissingDefinitionForUnionRef(choice['schema_ref']) + self._handle_choice(self.definitions[choice['schema_ref']]) elif choice['type'] not in { 'model', 'typed-dict', @@ -256,16 +277,12 @@ class _ApplyInferredDiscriminator: 'lax-or-strict', 'dataclass', 'dataclass-args', - 'definition-ref', } and not _core_utils.is_function_with_inner_schema(choice): # We should eventually handle 'definition-ref' as well - err_str = f'The core schema type {choice["type"]!r} is not a valid discriminated union variant.' - if choice['type'] == 'list': - err_str += ( - ' If you are making use of a list of union types, make sure the discriminator is applied to the ' - 'union type and not the list (e.g. `list[Annotated[ | , Field(discriminator=...)]]`).' 
- ) - raise TypeError(err_str) + raise TypeError( + f'{choice["type"]!r} is not a valid discriminated union variant;' + ' should be a `BaseModel` or `dataclass`' + ) else: if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice): # In this case, this inner tagged-union is compatible with the outer tagged-union, @@ -299,10 +316,13 @@ class _ApplyInferredDiscriminator: """ if choice['type'] == 'definitions': return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name) - + elif choice['type'] == 'function-plain': + raise TypeError( + f'{choice["type"]!r} is not a valid discriminated union variant;' + ' should be a `BaseModel` or `dataclass`' + ) elif _core_utils.is_function_with_inner_schema(choice): return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name) - elif choice['type'] == 'lax-or-strict': return sorted( set( @@ -353,13 +373,10 @@ class _ApplyInferredDiscriminator: raise MissingDefinitionForUnionRef(schema_ref) return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name) else: - err_str = f'The core schema type {choice["type"]!r} is not a valid discriminated union variant.' - if choice['type'] == 'list': - err_str += ( - ' If you are making use of a list of union types, make sure the discriminator is applied to the ' - 'union type and not the list (e.g. `list[Annotated[ | , Field(discriminator=...)]]`).' - ) - raise TypeError(err_str) + raise TypeError( + f'{choice["type"]!r} is not a valid discriminated union variant;' + ' should be a `BaseModel` or `dataclass`' + ) def _infer_discriminator_values_for_typed_dict_choice( self, choice: core_schema.TypedDictSchema, source_name: str | None = None diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_docs_extraction.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_docs_extraction.py deleted file mode 100644 index 7b5f310..0000000 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_docs_extraction.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Utilities related to attribute docstring extraction.""" - -from __future__ import annotations - -import ast -import inspect -import textwrap -from typing import Any - - -class DocstringVisitor(ast.NodeVisitor): - def __init__(self) -> None: - super().__init__() - - self.target: str | None = None - self.attrs: dict[str, str] = {} - self.previous_node_type: type[ast.AST] | None = None - - def visit(self, node: ast.AST) -> Any: - node_result = super().visit(node) - self.previous_node_type = type(node) - return node_result - - def visit_AnnAssign(self, node: ast.AnnAssign) -> Any: - if isinstance(node.target, ast.Name): - self.target = node.target.id - - def visit_Expr(self, node: ast.Expr) -> Any: - if ( - isinstance(node.value, ast.Constant) - and isinstance(node.value.value, str) - and self.previous_node_type is ast.AnnAssign - ): - docstring = inspect.cleandoc(node.value.value) - if self.target: - self.attrs[self.target] = docstring - self.target = None - - -def _dedent_source_lines(source: list[str]) -> str: - # Required for nested class definitions, e.g. in a function block - dedent_source = textwrap.dedent(''.join(source)) - if dedent_source.startswith((' ', '\t')): - # We are in the case where there's a dedented (usually multiline) string - # at a lower indentation level than the class itself. We wrap our class - # in a function as a workaround. 
- dedent_source = f'def dedent_workaround():\n{dedent_source}' - return dedent_source - - -def _extract_source_from_frame(cls: type[Any]) -> list[str] | None: - frame = inspect.currentframe() - - while frame: - if inspect.getmodule(frame) is inspect.getmodule(cls): - lnum = frame.f_lineno - try: - lines, _ = inspect.findsource(frame) - except OSError: # pragma: no cover - # Source can't be retrieved (maybe because running in an interactive terminal), - # we don't want to error here. - pass - else: - block_lines = inspect.getblock(lines[lnum - 1 :]) - dedent_source = _dedent_source_lines(block_lines) - try: - block_tree = ast.parse(dedent_source) - except SyntaxError: - pass - else: - stmt = block_tree.body[0] - if isinstance(stmt, ast.FunctionDef) and stmt.name == 'dedent_workaround': - # `_dedent_source_lines` wrapped the class around the workaround function - stmt = stmt.body[0] - if isinstance(stmt, ast.ClassDef) and stmt.name == cls.__name__: - return block_lines - - frame = frame.f_back - - -def extract_docstrings_from_cls(cls: type[Any], use_inspect: bool = False) -> dict[str, str]: - """Map model attributes and their corresponding docstring. - - Args: - cls: The class of the Pydantic model to inspect. - use_inspect: Whether to skip usage of frames to find the object and use - the `inspect` module instead. - - Returns: - A mapping containing attribute names and their corresponding docstring. - """ - if use_inspect: - # Might not work as expected if two classes have the same name in the same source file. - try: - source, _ = inspect.getsourcelines(cls) - except OSError: # pragma: no cover - return {} - else: - source = _extract_source_from_frame(cls) - - if not source: - return {} - - dedent_source = _dedent_source_lines(source) - - visitor = DocstringVisitor() - visitor.visit(ast.parse(dedent_source)) - return visitor.attrs diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_fields.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_fields.py index 658be3b..e42c7a4 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_fields.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_fields.py @@ -1,104 +1,92 @@ """Private logic related to fields (the `Field()` function and `FieldInfo` class), and arguments to `Annotated`.""" - from __future__ import annotations as _annotations import dataclasses +import sys import warnings -from collections.abc import Mapping from copy import copy -from functools import cache -from inspect import Parameter, ismethoddescriptor, signature -from re import Pattern -from typing import TYPE_CHECKING, Any, Callable, TypeVar +from typing import TYPE_CHECKING, Any +from annotated_types import BaseMetadata from pydantic_core import PydanticUndefined -from typing_extensions import TypeIs, get_origin -from typing_inspection import typing_objects -from typing_inspection.introspection import AnnotationSource -from pydantic import PydanticDeprecatedSince211 -from pydantic.errors import PydanticUserError - -from . import _generics, _typing_extra +from . 
import _typing_extra from ._config import ConfigWrapper -from ._docs_extraction import extract_docstrings_from_cls -from ._import_utils import import_cached_base_model, import_cached_field_info -from ._namespace_utils import NsResolver from ._repr import Representation -from ._utils import can_be_positional +from ._typing_extra import get_cls_type_hints_lenient, get_type_hints, is_classvar, is_finalvar if TYPE_CHECKING: - from annotated_types import BaseMetadata - from ..fields import FieldInfo from ..main import BaseModel - from ._dataclasses import PydanticDataclass, StandardDataclass + from ._dataclasses import StandardDataclass from ._decorators import DecoratorInfos +def get_type_hints_infer_globalns( + obj: Any, + localns: dict[str, Any] | None = None, + include_extras: bool = False, +) -> dict[str, Any]: + """Gets type hints for an object by inferring the global namespace. + + It uses the `typing.get_type_hints`, The only thing that we do here is fetching + global namespace from `obj.__module__` if it is not `None`. + + Args: + obj: The object to get its type hints. + localns: The local namespaces. + include_extras: Whether to recursively include annotation metadata. + + Returns: + The object type hints. + """ + module_name = getattr(obj, '__module__', None) + globalns: dict[str, Any] | None = None + if module_name: + try: + globalns = sys.modules[module_name].__dict__ + except KeyError: + # happens occasionally, see https://github.com/pydantic/pydantic/issues/2363 + pass + return get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras) + + class PydanticMetadata(Representation): """Base class for annotation markers like `Strict`.""" __slots__ = () -def pydantic_general_metadata(**metadata: Any) -> BaseMetadata: - """Create a new `_PydanticGeneralMetadata` class with the given metadata. +class PydanticGeneralMetadata(PydanticMetadata, BaseMetadata): + """Pydantic general metada like `max_digits`.""" - Args: - **metadata: The metadata to add. - - Returns: - The new `_PydanticGeneralMetadata` class. - """ - return _general_metadata_cls()(metadata) # type: ignore - - -@cache -def _general_metadata_cls() -> type[BaseMetadata]: - """Do it this way to avoid importing `annotated_types` at import time.""" - from annotated_types import BaseMetadata - - class _PydanticGeneralMetadata(PydanticMetadata, BaseMetadata): - """Pydantic general metadata like `max_digits`.""" - - def __init__(self, metadata: Any): - self.__dict__ = metadata - - return _PydanticGeneralMetadata # type: ignore - - -def _update_fields_from_docstrings(cls: type[Any], fields: dict[str, FieldInfo], use_inspect: bool = False) -> None: - fields_docs = extract_docstrings_from_cls(cls, use_inspect=use_inspect) - for ann_name, field_info in fields.items(): - if field_info.description is None and ann_name in fields_docs: - field_info.description = fields_docs[ann_name] + def __init__(self, **metadata: Any): + self.__dict__ = metadata def collect_model_fields( # noqa: C901 cls: type[BaseModel], + bases: tuple[type[Any], ...], config_wrapper: ConfigWrapper, - ns_resolver: NsResolver | None, + types_namespace: dict[str, Any] | None, *, - typevars_map: Mapping[TypeVar, Any] | None = None, + typevars_map: dict[Any, Any] | None = None, ) -> tuple[dict[str, FieldInfo], set[str]]: - """Collect the fields and class variables names of a nascent Pydantic model. + """Collect the fields of a nascent pydantic model. 
- The fields collection process is *lenient*, meaning it won't error if string annotations - fail to evaluate. If this happens, the original annotation (and assigned value, if any) - is stored on the created `FieldInfo` instance. + Also collect the names of any ClassVars present in the type hints. - The `rebuild_model_fields()` should be called at a later point (e.g. when rebuilding the model), - and will make use of these stored attributes. + The returned value is a tuple of two items: the fields dict, and the set of ClassVar names. Args: cls: BaseModel or dataclass. + bases: Parents of the class, generally `cls.__bases__`. config_wrapper: The config wrapper instance. - ns_resolver: Namespace resolver to use when getting model annotations. + types_namespace: Optional extra namespace to look for types in. typevars_map: A dictionary mapping type variables to their concrete types. Returns: - A two-tuple containing model fields and class variables names. + A tuple contains fields and class variables. Raises: NameError: @@ -106,16 +94,9 @@ def collect_model_fields( # noqa: C901 - If there is a field other than `root` in `RootModel`. - If a field shadows an attribute in the parent model. """ - BaseModel = import_cached_base_model() - FieldInfo_ = import_cached_field_info() + from ..fields import FieldInfo - bases = cls.__bases__ - parent_fields_lookup: dict[str, FieldInfo] = {} - for base in reversed(bases): - if model_fields := getattr(base, '__pydantic_fields__', None): - parent_fields_lookup.update(model_fields) - - type_hints = _typing_extra.get_model_type_hints(cls, ns_resolver=ns_resolver) + type_hints = get_cls_type_hints_lenient(cls, types_namespace) # https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older # annotations is only used for finding fields in parent classes @@ -123,50 +104,39 @@ def collect_model_fields( # noqa: C901 fields: dict[str, FieldInfo] = {} class_vars: set[str] = set() - for ann_name, (ann_type, evaluated) in type_hints.items(): + for ann_name, ann_type in type_hints.items(): if ann_name == 'model_config': # We never want to treat `model_config` as a field # Note: we may need to change this logic if/when we introduce a `BareModel` class with no # protected namespaces (where `model_config` might be allowed as a field name) continue - for protected_namespace in config_wrapper.protected_namespaces: - ns_violation: bool = False - if isinstance(protected_namespace, Pattern): - ns_violation = protected_namespace.match(ann_name) is not None - elif isinstance(protected_namespace, str): - ns_violation = ann_name.startswith(protected_namespace) - - if ns_violation: + if ann_name.startswith(protected_namespace): for b in bases: if hasattr(b, ann_name): - if not (issubclass(b, BaseModel) and ann_name in getattr(b, '__pydantic_fields__', {})): + from ..main import BaseModel + + if not (issubclass(b, BaseModel) and ann_name in b.model_fields): raise NameError( f'Field "{ann_name}" conflicts with member {getattr(b, ann_name)}' f' of protected namespace "{protected_namespace}".' 
) else: - valid_namespaces = () - for pn in config_wrapper.protected_namespaces: - if isinstance(pn, Pattern): - if not pn.match(ann_name): - valid_namespaces += (f're.compile({pn.pattern})',) - else: - if not ann_name.startswith(pn): - valid_namespaces += (pn,) - + valid_namespaces = tuple( + x for x in config_wrapper.protected_namespaces if not ann_name.startswith(x) + ) warnings.warn( - f'Field "{ann_name}" in {cls.__name__} has conflict with protected namespace "{protected_namespace}".' + f'Field "{ann_name}" has conflict with protected namespace "{protected_namespace}".' '\n\nYou may be able to resolve this warning by setting' f" `model_config['protected_namespaces'] = {valid_namespaces}`.", UserWarning, ) - if _typing_extra.is_classvar_annotation(ann_type): + if is_classvar(ann_type): + class_vars.add(ann_name) + continue + if _is_finalvar_with_default_val(ann_type, getattr(cls, ann_name, PydanticUndefined)): class_vars.add(ann_name) continue - - assigned_value = getattr(cls, ann_name, PydanticUndefined) - if not is_valid_field_name(ann_name): continue if cls.__pydantic_root_model__ and ann_name != 'root': @@ -175,7 +145,7 @@ def collect_model_fields( # noqa: C901 ) # when building a generic model with `MyModel[int]`, the generic_origin check makes sure we don't get - # "... shadows an attribute" warnings + # "... shadows an attribute" errors generic_origin = getattr(cls, '__pydantic_generic_metadata__', {}).get('origin') for base in bases: dataclass_fields = { @@ -183,77 +153,42 @@ def collect_model_fields( # noqa: C901 } if hasattr(base, ann_name): if base is generic_origin: - # Don't warn when "shadowing" of attributes in parametrized generics + # Don't error when "shadowing" of attributes in parametrized generics continue if ann_name in dataclass_fields: - # Don't warn when inheriting stdlib dataclasses whose fields are "shadowed" by defaults being set + # Don't error when inheriting stdlib dataclasses whose fields are "shadowed" by defaults being set # on the class instance. continue - - if ann_name not in annotations: - # Don't warn when a field exists in a parent class but has not been defined in the current class - continue - warnings.warn( - f'Field name "{ann_name}" in "{cls.__qualname__}" shadows an attribute in parent ' - f'"{base.__qualname__}"', + f'Field name "{ann_name}" shadows an attribute in parent "{base.__qualname__}"; ', UserWarning, ) - if assigned_value is PydanticUndefined: # no assignment, just a plain annotation - if ann_name in annotations or ann_name not in parent_fields_lookup: - # field is either: - # - present in the current model's annotations (and *not* from parent classes) - # - not found on any base classes; this seems to be caused by fields bot getting - # generated due to models not being fully defined while initializing recursive models. - # Nothing stops us from just creating a `FieldInfo` for this type hint, so we do this. 
- field_info = FieldInfo_.from_annotation(ann_type, _source=AnnotationSource.CLASS) - if not evaluated: - field_info._complete = False - # Store the original annotation that should be used to rebuild - # the field info later: - field_info._original_annotation = ann_type + try: + default = getattr(cls, ann_name, PydanticUndefined) + if default is PydanticUndefined: + raise AttributeError + except AttributeError: + if ann_name in annotations: + field_info = FieldInfo.from_annotation(ann_type) else: - # The field was present on one of the (possibly multiple) base classes - # copy the field to make sure typevar substitutions don't cause issues with the base classes - field_info = copy(parent_fields_lookup[ann_name]) - - else: # An assigned value is present (either the default value, or a `Field()` function) - _warn_on_nested_alias_in_annotation(ann_type, ann_name) - if isinstance(assigned_value, FieldInfo_) and ismethoddescriptor(assigned_value.default): - # `assigned_value` was fetched using `getattr`, which triggers a call to `__get__` - # for descriptors, so we do the same if the `= field(default=...)` form is used. - # Note that we only do this for method descriptors for now, we might want to - # extend this to any descriptor in the future (by simply checking for - # `hasattr(assigned_value.default, '__get__')`). - assigned_value.default = assigned_value.default.__get__(None, cls) - - # The `from_annotated_attribute()` call below mutates the assigned `Field()`, so make a copy: - original_assignment = ( - assigned_value._copy() if not evaluated and isinstance(assigned_value, FieldInfo_) else assigned_value - ) - - field_info = FieldInfo_.from_annotated_attribute(ann_type, assigned_value, _source=AnnotationSource.CLASS) - # Store the original annotation and assignment value that should be used to rebuild the field info later. - # Note that the assignment is always stored as the annotation might contain a type var that is later - # parameterized with an unknown forward reference (and we'll need it to rebuild the field info): - field_info._original_assignment = original_assignment - if not evaluated: - field_info._complete = False - field_info._original_annotation = ann_type - elif 'final' in field_info._qualifiers and not field_info.is_required(): - warnings.warn( - f'Annotation {ann_name!r} is marked as final and has a default value. Pydantic treats {ann_name!r} as a ' - 'class variable, but it will be considered as a normal field in V3 to be aligned with dataclasses. 
If you ' - f'still want {ann_name!r} to be considered as a class variable, annotate it as: `ClassVar[] = .`', - category=PydanticDeprecatedSince211, - # Incorrect when `create_model` is used, but the chance that final with a default is used is low in that case: - stacklevel=4, - ) - class_vars.add(ann_name) - continue - + # if field has no default value and is not in __annotations__ this means that it is + # defined in a base class and we can take it from there + model_fields_lookup: dict[str, FieldInfo] = {} + for x in cls.__bases__[::-1]: + model_fields_lookup.update(getattr(x, 'model_fields', {})) + if ann_name in model_fields_lookup: + # The field was present on one of the (possibly multiple) base classes + # copy the field to make sure typevar substitutions don't cause issues with the base classes + field_info = copy(model_fields_lookup[ann_name]) + else: + # The field was not found on any base classes; this seems to be caused by fields not getting + # generated thanks to models not being fully defined while initializing recursive models. + # Nothing stops us from just creating a new FieldInfo for this type hint, so we do this. + field_info = FieldInfo.from_annotation(ann_type) + else: + field_info = FieldInfo.from_annotated_attribute(ann_type, default) # attributes which are fields are removed from the class namespace: # 1. To match the behaviour of annotation-only fields # 2. To avoid false positives in the NameError check above @@ -266,250 +201,81 @@ def collect_model_fields( # noqa: C901 # to make sure the decorators have already been built for this exact class decorators: DecoratorInfos = cls.__dict__['__pydantic_decorators__'] if ann_name in decorators.computed_fields: - raise TypeError( - f'Field {ann_name!r} of class {cls.__name__!r} overrides symbol of same name in a parent class. ' - 'This override with a computed_field is incompatible.' - ) + raise ValueError("you can't override a field with a computed field") fields[ann_name] = field_info if typevars_map: for field in fields.values(): - if field._complete: - field.apply_typevars_map(typevars_map) + field.apply_typevars_map(typevars_map, types_namespace) - if config_wrapper.use_attribute_docstrings: - _update_fields_from_docstrings(cls, fields) return fields, class_vars -def _warn_on_nested_alias_in_annotation(ann_type: type[Any], ann_name: str) -> None: - FieldInfo = import_cached_field_info() +def _is_finalvar_with_default_val(type_: type[Any], val: Any) -> bool: + from ..fields import FieldInfo - args = getattr(ann_type, '__args__', None) - if args: - for anno_arg in args: - if typing_objects.is_annotated(get_origin(anno_arg)): - for anno_type_arg in _typing_extra.get_args(anno_arg): - if isinstance(anno_type_arg, FieldInfo) and anno_type_arg.alias is not None: - warnings.warn( - f'`alias` specification on field "{ann_name}" must be set on outermost annotation to take effect.', - UserWarning, - ) - return - - -def rebuild_model_fields( - cls: type[BaseModel], - *, - ns_resolver: NsResolver, - typevars_map: Mapping[TypeVar, Any], -) -> dict[str, FieldInfo]: - """Rebuild the (already present) model fields by trying to reevaluate annotations. - - This function should be called whenever a model with incomplete fields is encountered. - - Raises: - NameError: If one of the annotations failed to evaluate. - - Note: - This function *doesn't* mutate the model fields in place, as it can be called during - schema generation, where you don't want to mutate other model's fields. 
- """ - FieldInfo_ = import_cached_field_info() - - rebuilt_fields: dict[str, FieldInfo] = {} - with ns_resolver.push(cls): - for f_name, field_info in cls.__pydantic_fields__.items(): - if field_info._complete: - rebuilt_fields[f_name] = field_info - else: - existing_desc = field_info.description - ann = _typing_extra.eval_type( - field_info._original_annotation, - *ns_resolver.types_namespace, - ) - ann = _generics.replace_types(ann, typevars_map) - - if (assign := field_info._original_assignment) is PydanticUndefined: - new_field = FieldInfo_.from_annotation(ann, _source=AnnotationSource.CLASS) - else: - new_field = FieldInfo_.from_annotated_attribute(ann, assign, _source=AnnotationSource.CLASS) - # The description might come from the docstring if `use_attribute_docstrings` was `True`: - new_field.description = new_field.description if new_field.description is not None else existing_desc - rebuilt_fields[f_name] = new_field - - return rebuilt_fields + if not is_finalvar(type_): + return False + elif val is PydanticUndefined: + return False + elif isinstance(val, FieldInfo) and (val.default is PydanticUndefined and val.default_factory is None): + return False + else: + return True def collect_dataclass_fields( - cls: type[StandardDataclass], - *, - ns_resolver: NsResolver | None = None, - typevars_map: dict[Any, Any] | None = None, - config_wrapper: ConfigWrapper | None = None, + cls: type[StandardDataclass], types_namespace: dict[str, Any] | None, *, typevars_map: dict[Any, Any] | None = None ) -> dict[str, FieldInfo]: """Collect the fields of a dataclass. Args: cls: dataclass. - ns_resolver: Namespace resolver to use when getting dataclass annotations. - Defaults to an empty instance. + types_namespace: Optional extra namespace to look for types in. typevars_map: A dictionary mapping type variables to their concrete types. - config_wrapper: The config wrapper instance. Returns: The dataclass fields. """ - FieldInfo_ = import_cached_field_info() + from ..fields import FieldInfo fields: dict[str, FieldInfo] = {} - ns_resolver = ns_resolver or NsResolver() - dataclass_fields = cls.__dataclass_fields__ + dataclass_fields: dict[str, dataclasses.Field] = cls.__dataclass_fields__ + cls_localns = dict(vars(cls)) # this matches get_cls_type_hints_lenient, but all tests pass with `= None` instead - # The logic here is similar to `_typing_extra.get_cls_type_hints`, - # although we do it manually as stdlib dataclasses already have annotations - # collected in each class: - for base in reversed(cls.__mro__): - if not dataclasses.is_dataclass(base): + for ann_name, dataclass_field in dataclass_fields.items(): + ann_type = _typing_extra.eval_type_lenient(dataclass_field.type, types_namespace, cls_localns) + if is_classvar(ann_type): continue - with ns_resolver.push(base): - for ann_name, dataclass_field in dataclass_fields.items(): - if ann_name not in base.__dict__.get('__annotations__', {}): - # `__dataclass_fields__`contains every field, even the ones from base classes. - # Only collect the ones defined on `base`. 
- continue + if not dataclass_field.init and dataclass_field.default_factory == dataclasses.MISSING: + # TODO: We should probably do something with this so that validate_assignment behaves properly + # Issue: https://github.com/pydantic/pydantic/issues/5470 + continue - globalns, localns = ns_resolver.types_namespace - ann_type, evaluated = _typing_extra.try_eval_type(dataclass_field.type, globalns, localns) + if isinstance(dataclass_field.default, FieldInfo): + if dataclass_field.default.init_var: + # TODO: same note as above + continue + field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field.default) + else: + field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field) + fields[ann_name] = field_info - if _typing_extra.is_classvar_annotation(ann_type): - continue - - if ( - not dataclass_field.init - and dataclass_field.default is dataclasses.MISSING - and dataclass_field.default_factory is dataclasses.MISSING - ): - # TODO: We should probably do something with this so that validate_assignment behaves properly - # Issue: https://github.com/pydantic/pydantic/issues/5470 - continue - - if isinstance(dataclass_field.default, FieldInfo_): - if dataclass_field.default.init_var: - if dataclass_field.default.init is False: - raise PydanticUserError( - f'Dataclass field {ann_name} has init=False and init_var=True, but these are mutually exclusive.', - code='clashing-init-and-init-var', - ) - - # TODO: same note as above re validate_assignment - continue - field_info = FieldInfo_.from_annotated_attribute( - ann_type, dataclass_field.default, _source=AnnotationSource.DATACLASS - ) - field_info._original_assignment = dataclass_field.default - else: - field_info = FieldInfo_.from_annotated_attribute( - ann_type, dataclass_field, _source=AnnotationSource.DATACLASS - ) - field_info._original_assignment = dataclass_field - - if not evaluated: - field_info._complete = False - field_info._original_annotation = ann_type - - fields[ann_name] = field_info - - if field_info.default is not PydanticUndefined and isinstance( - getattr(cls, ann_name, field_info), FieldInfo_ - ): - # We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo - setattr(cls, ann_name, field_info.default) + if field_info.default is not PydanticUndefined and isinstance(getattr(cls, ann_name, field_info), FieldInfo): + # We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo + setattr(cls, ann_name, field_info.default) if typevars_map: for field in fields.values(): - # We don't pass any ns, as `field.annotation` - # was already evaluated. TODO: is this method relevant? - # Can't we juste use `_generics.replace_types`? - field.apply_typevars_map(typevars_map) - - if config_wrapper is not None and config_wrapper.use_attribute_docstrings: - _update_fields_from_docstrings( - cls, - fields, - # We can't rely on the (more reliable) frame inspection method - # for stdlib dataclasses: - use_inspect=not hasattr(cls, '__is_pydantic_dataclass__'), - ) + field.apply_typevars_map(typevars_map, types_namespace) return fields -def rebuild_dataclass_fields( - cls: type[PydanticDataclass], - *, - config_wrapper: ConfigWrapper, - ns_resolver: NsResolver, - typevars_map: Mapping[TypeVar, Any], -) -> dict[str, FieldInfo]: - """Rebuild the (already present) dataclass fields by trying to reevaluate annotations. - - This function should be called whenever a dataclass with incomplete fields is encountered. 
- - Raises: - NameError: If one of the annotations failed to evaluate. - - Note: - This function *doesn't* mutate the dataclass fields in place, as it can be called during - schema generation, where you don't want to mutate other dataclass's fields. - """ - FieldInfo_ = import_cached_field_info() - - rebuilt_fields: dict[str, FieldInfo] = {} - with ns_resolver.push(cls): - for f_name, field_info in cls.__pydantic_fields__.items(): - if field_info._complete: - rebuilt_fields[f_name] = field_info - else: - existing_desc = field_info.description - ann = _typing_extra.eval_type( - field_info._original_annotation, - *ns_resolver.types_namespace, - ) - ann = _generics.replace_types(ann, typevars_map) - new_field = FieldInfo_.from_annotated_attribute( - ann, - field_info._original_assignment, - _source=AnnotationSource.DATACLASS, - ) - - # The description might come from the docstring if `use_attribute_docstrings` was `True`: - new_field.description = new_field.description if new_field.description is not None else existing_desc - rebuilt_fields[f_name] = new_field - - return rebuilt_fields - - def is_valid_field_name(name: str) -> bool: return not name.startswith('_') def is_valid_privateattr_name(name: str) -> bool: return name.startswith('_') and not name.startswith('__') - - -def takes_validated_data_argument( - default_factory: Callable[[], Any] | Callable[[dict[str, Any]], Any], -) -> TypeIs[Callable[[dict[str, Any]], Any]]: - """Whether the provided default factory callable has a validated data parameter.""" - try: - sig = signature(default_factory) - except (ValueError, TypeError): - # `inspect.signature` might not be able to infer a signature, e.g. with C objects. - # In this case, we assume no data argument is present: - return False - - parameters = list(sig.parameters.values()) - - return len(parameters) == 1 and can_be_positional(parameters[0]) and parameters[0].default is Parameter.empty diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_forward_ref.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_forward_ref.py index 231f81d..edf4baa 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_forward_ref.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_forward_ref.py @@ -1,7 +1,6 @@ from __future__ import annotations as _annotations from dataclasses import dataclass -from typing import Union @dataclass @@ -15,9 +14,3 @@ class PydanticRecursiveRef: """Defining __call__ is necessary for the `typing` module to let you use an instance of this class as the result of resolving a standard ForwardRef. 
""" - - def __or__(self, other): - return Union[self, other] # type: ignore - - def __ror__(self, other): - return Union[other, self] # type: ignore diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_generate_schema.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_generate_schema.py index 0451228..e48d21d 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_generate_schema.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_generate_schema.py @@ -1,78 +1,64 @@ """Convert python types to pydantic-core schema.""" - from __future__ import annotations as _annotations import collections.abc import dataclasses -import datetime import inspect -import os -import pathlib import re import sys import typing import warnings -from collections.abc import Generator, Iterable, Iterator, Mapping from contextlib import contextmanager -from copy import copy -from decimal import Decimal +from copy import copy, deepcopy from enum import Enum -from fractions import Fraction from functools import partial from inspect import Parameter, _ParameterKind, signature -from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network from itertools import chain from operator import attrgetter -from types import FunctionType, GenericAlias, LambdaType, MethodType +from types import FunctionType, LambdaType, MethodType from typing import ( TYPE_CHECKING, Any, Callable, - Final, + Dict, ForwardRef, - Literal, + Iterable, + Iterator, + Mapping, + Type, TypeVar, Union, cast, overload, ) -from uuid import UUID from warnings import warn -from zoneinfo import ZoneInfo -import typing_extensions -from pydantic_core import ( - CoreSchema, - MultiHostUrl, - PydanticCustomError, - PydanticSerializationUnexpectedValue, - PydanticUndefined, - Url, - core_schema, - to_jsonable_python, -) -from typing_extensions import TypeAlias, TypeAliasType, TypedDict, get_args, get_origin, is_typeddict -from typing_inspection import typing_objects -from typing_inspection.introspection import AnnotationSource, get_literal_values, is_union_origin +from pydantic_core import CoreSchema, PydanticUndefined, core_schema, to_jsonable_python +from typing_extensions import Annotated, Final, Literal, TypeAliasType, TypedDict, get_args, get_origin, is_typeddict -from ..aliases import AliasChoices, AliasGenerator, AliasPath from ..annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler -from ..config import ConfigDict, JsonDict, JsonEncoder, JsonSchemaExtraCallable +from ..config import ConfigDict, JsonEncoder from ..errors import PydanticSchemaGenerationError, PydanticUndefinedAnnotation, PydanticUserError -from ..functional_validators import AfterValidator, BeforeValidator, FieldValidatorModes, PlainValidator, WrapValidator +from ..fields import AliasChoices, AliasPath, FieldInfo from ..json_schema import JsonSchemaValue from ..version import version_short from ..warnings import PydanticDeprecatedSince20 -from . import _decorators, _discriminated_union, _known_annotated_metadata, _repr, _typing_extra +from . 
import _decorators, _discriminated_union, _known_annotated_metadata, _typing_extra from ._config import ConfigWrapper, ConfigWrapperStack -from ._core_metadata import CoreMetadata, update_core_metadata +from ._core_metadata import ( + CoreMetadataHandler, + build_metadata_dict, +) from ._core_utils import ( + NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, + CoreSchemaOrField, + define_expected_missing_refs, get_ref, get_type_ref, is_list_like_schema_with_items_schema, - validate_core_schema, ) from ._decorators import ( + ComputedFieldInfo, Decorator, DecoratorInfos, FieldSerializerDecoratorInfo, @@ -86,30 +72,23 @@ from ._decorators import ( inspect_model_serializer, inspect_validator, ) -from ._docs_extraction import extract_docstrings_from_cls -from ._fields import ( - collect_dataclass_fields, - rebuild_dataclass_fields, - rebuild_model_fields, - takes_validated_data_argument, -) +from ._fields import collect_dataclass_fields, get_type_hints_infer_globalns from ._forward_ref import PydanticRecursiveRef -from ._generics import get_standard_typevars_map, replace_types -from ._import_utils import import_cached_base_model, import_cached_field_info -from ._mock_val_ser import MockCoreSchema -from ._namespace_utils import NamespacesTuple, NsResolver -from ._schema_gather import MissingDefinitionError, gather_schemas_for_cleaning -from ._schema_generation_shared import CallbackGetCoreSchemaHandler -from ._utils import lenient_issubclass, smart_deepcopy +from ._generics import get_standard_typevars_map, has_instance_in_type, recursively_defined_type_refs, replace_types +from ._schema_generation_shared import ( + CallbackGetCoreSchemaHandler, +) +from ._typing_extra import is_finalvar +from ._utils import lenient_issubclass if TYPE_CHECKING: - from ..fields import ComputedFieldInfo, FieldInfo from ..main import BaseModel - from ..types import Discriminator + from ..validators import FieldValidatorModes from ._dataclasses import StandardDataclass from ._schema_generation_shared import GetJsonSchemaFunction _SUPPORTS_TYPEDDICT = sys.version_info >= (3, 12) +_AnnotatedType = type(Annotated[int, 123]) FieldDecoratorInfo = Union[ValidatorDecoratorInfo, FieldValidatorDecoratorInfo, FieldSerializerDecoratorInfo] FieldDecoratorInfoType = TypeVar('FieldDecoratorInfoType', bound=FieldDecoratorInfo) @@ -119,55 +98,15 @@ AnyFieldDecorator = Union[ Decorator[FieldSerializerDecoratorInfo], ] -ModifyCoreSchemaWrapHandler: TypeAlias = GetCoreSchemaHandler -GetCoreSchemaFunction: TypeAlias = Callable[[Any, ModifyCoreSchemaWrapHandler], core_schema.CoreSchema] -ParametersCallback: TypeAlias = "Callable[[int, str, Any], Literal['skip'] | None]" +ModifyCoreSchemaWrapHandler = GetCoreSchemaHandler +GetCoreSchemaFunction = Callable[[Any, ModifyCoreSchemaWrapHandler], core_schema.CoreSchema] -TUPLE_TYPES: list[type] = [typing.Tuple, tuple] # noqa: UP006 -LIST_TYPES: list[type] = [typing.List, list, collections.abc.MutableSequence] # noqa: UP006 -SET_TYPES: list[type] = [typing.Set, set, collections.abc.MutableSet] # noqa: UP006 -FROZEN_SET_TYPES: list[type] = [typing.FrozenSet, frozenset, collections.abc.Set] # noqa: UP006 -DICT_TYPES: list[type] = [typing.Dict, dict] # noqa: UP006 -IP_TYPES: list[type] = [IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network] -SEQUENCE_TYPES: list[type] = [typing.Sequence, collections.abc.Sequence] -ITERABLE_TYPES: list[type] = [typing.Iterable, collections.abc.Iterable, typing.Generator, collections.abc.Generator] -TYPE_TYPES: list[type] = [typing.Type, 
type] # noqa: UP006 -PATTERN_TYPES: list[type] = [typing.Pattern, re.Pattern] -PATH_TYPES: list[type] = [ - os.PathLike, - pathlib.Path, - pathlib.PurePath, - pathlib.PosixPath, - pathlib.PurePosixPath, - pathlib.PureWindowsPath, -] -MAPPING_TYPES = [ - typing.Mapping, - typing.MutableMapping, - collections.abc.Mapping, - collections.abc.MutableMapping, - collections.OrderedDict, - typing_extensions.OrderedDict, - typing.DefaultDict, # noqa: UP006 - collections.defaultdict, -] -COUNTER_TYPES = [collections.Counter, typing.Counter] -DEQUE_TYPES: list[type] = [collections.deque, typing.Deque] # noqa: UP006 -# Note: This does not play very well with type checkers. For example, -# `a: LambdaType = lambda x: x` will raise a type error by Pyright. -ValidateCallSupportedTypes = Union[ - LambdaType, - FunctionType, - MethodType, - partial, -] - -VALIDATE_CALL_SUPPORTED_TYPES = get_args(ValidateCallSupportedTypes) - -_mode_to_validator: dict[ - FieldValidatorModes, type[BeforeValidator | AfterValidator | PlainValidator | WrapValidator] -] = {'before': BeforeValidator, 'after': AfterValidator, 'plain': PlainValidator, 'wrap': WrapValidator} +TUPLE_TYPES: list[type] = [tuple, typing.Tuple] +LIST_TYPES: list[type] = [list, typing.List, collections.abc.MutableSequence] +SET_TYPES: list[type] = [set, typing.Set, collections.abc.MutableSet] +FROZEN_SET_TYPES: list[type] = [frozenset, typing.FrozenSet, collections.abc.Set] +DICT_TYPES: list[type] = [dict, typing.Dict, collections.abc.MutableMapping, collections.abc.Mapping] def check_validator_fields_against_field_name( @@ -183,8 +122,13 @@ def check_validator_fields_against_field_name( Returns: `True` if field name is in validator fields, `False` otherwise. """ - fields = info.fields - return '*' in fields or field in fields + if isinstance(info, (ValidatorDecoratorInfo, FieldValidatorDecoratorInfo)): + if '*' in info.fields: + return True + for v_field_name in info.fields: + if v_field_name == field: + return True + return False def check_decorator_fields_exist(decorators: Iterable[AnyFieldDecorator], fields: Iterable[str]) -> None: @@ -201,7 +145,7 @@ def check_decorator_fields_exist(decorators: Iterable[AnyFieldDecorator], fields """ fields = set(fields) for dec in decorators: - if '*' in dec.info.fields: + if isinstance(dec.info, (ValidatorDecoratorInfo, FieldValidatorDecoratorInfo)) and '*' in dec.info.fields: continue if dec.info.check_fields is False: continue @@ -227,50 +171,60 @@ def apply_each_item_validators( ) -> core_schema.CoreSchema: # This V1 compatibility shim should eventually be removed - # fail early if each_item_validators is empty - if not each_item_validators: - return schema - # push down any `each_item=True` validators # note that this won't work for any Annotated types that get wrapped by a function validator # but that's okay because that didn't exist in V1 if schema['type'] == 'nullable': schema['schema'] = apply_each_item_validators(schema['schema'], each_item_validators, field_name) return schema - elif schema['type'] == 'tuple': - if (variadic_item_index := schema.get('variadic_item_index')) is not None: - schema['items_schema'][variadic_item_index] = apply_validators( - schema['items_schema'][variadic_item_index], - each_item_validators, - field_name, - ) elif is_list_like_schema_with_items_schema(schema): - inner_schema = schema.get('items_schema', core_schema.any_schema()) + inner_schema = schema.get('items_schema', None) + if inner_schema is None: + inner_schema = core_schema.any_schema() schema['items_schema'] = 
apply_validators(inner_schema, each_item_validators, field_name) elif schema['type'] == 'dict': - inner_schema = schema.get('values_schema', core_schema.any_schema()) + # push down any `each_item=True` validators onto dict _values_ + # this is super arbitrary but it's the V1 behavior + inner_schema = schema.get('values_schema', None) + if inner_schema is None: + inner_schema = core_schema.any_schema() schema['values_schema'] = apply_validators(inner_schema, each_item_validators, field_name) - else: + elif each_item_validators: raise TypeError( - f'`@validator(..., each_item=True)` cannot be applied to fields with a schema of {schema["type"]}' + f"`@validator(..., each_item=True)` cannot be applied to fields with a schema of {schema['type']}" ) return schema -def _extract_json_schema_info_from_field_info( - info: FieldInfo | ComputedFieldInfo, -) -> tuple[JsonDict | None, JsonDict | JsonSchemaExtraCallable | None]: - json_schema_updates = { - 'title': info.title, - 'description': info.description, - 'deprecated': bool(info.deprecated) or info.deprecated == '' or None, - 'examples': to_jsonable_python(info.examples), - } - json_schema_updates = {k: v for k, v in json_schema_updates.items() if v is not None} - return (json_schema_updates or None, info.json_schema_extra) +def modify_model_json_schema( + schema_or_field: CoreSchemaOrField, handler: GetJsonSchemaHandler, *, cls: Any +) -> JsonSchemaValue: + """Add title and description for model-like classes' JSON schema. + + Args: + schema_or_field: The schema data to generate a JSON schema from. + handler: The `GetCoreSchemaHandler` instance. + cls: The model-like class. + + Returns: + JsonSchemaValue: The updated JSON schema. + """ + json_schema = handler(schema_or_field) + original_schema = handler.resolve_ref_schema(json_schema) + # Preserve the fact that definitions schemas should never have sibling keys: + if '$ref' in original_schema: + ref = original_schema['$ref'] + original_schema.clear() + original_schema['allOf'] = [{'$ref': ref}] + if 'title' not in original_schema: + original_schema['title'] = cls.__name__ + docstring = cls.__doc__ + if docstring and 'description' not in original_schema: + original_schema['description'] = inspect.cleandoc(docstring) + return json_schema -JsonEncoders = dict[type[Any], JsonEncoder] +JsonEncoders = Dict[Type[Any], JsonEncoder] def _add_custom_serialization_from_json_encoders( @@ -307,321 +261,94 @@ def _add_custom_serialization_from_json_encoders( return schema -def _get_first_non_null(a: Any, b: Any) -> Any: - """Return the first argument if it is not None, otherwise return the second argument. - - Use case: serialization_alias (argument a) and alias (argument b) are both defined, and serialization_alias is ''. - This function will return serialization_alias, which is the first argument, even though it is an empty string. - """ - return a if a is not None else b - - -class InvalidSchemaError(Exception): - """The core schema is invalid.""" - - class GenerateSchema: """Generate core schema for a Pydantic model, dataclass and types like `str`, `datetime`, ... 
.""" __slots__ = ( '_config_wrapper_stack', - '_ns_resolver', + '_types_namespace', '_typevars_map', + '_needs_apply_discriminated_union', + '_has_invalid_schema', 'field_name_stack', - 'model_type_stack', 'defs', ) def __init__( self, config_wrapper: ConfigWrapper, - ns_resolver: NsResolver | None = None, - typevars_map: Mapping[TypeVar, Any] | None = None, + types_namespace: dict[str, Any] | None, + typevars_map: dict[Any, Any] | None = None, ) -> None: - # we need a stack for recursing into nested models + # we need a stack for recursing into child models self._config_wrapper_stack = ConfigWrapperStack(config_wrapper) - self._ns_resolver = ns_resolver or NsResolver() + self._types_namespace = types_namespace self._typevars_map = typevars_map + self._needs_apply_discriminated_union = False + self._has_invalid_schema = False self.field_name_stack = _FieldNameStack() - self.model_type_stack = _ModelTypeStack() self.defs = _Definitions() - def __init_subclass__(cls) -> None: - super().__init_subclass__() - warnings.warn( - 'Subclassing `GenerateSchema` is not supported. The API is highly subject to change in minor versions.', - UserWarning, - stacklevel=2, - ) + @classmethod + def __from_parent( + cls, + config_wrapper_stack: ConfigWrapperStack, + types_namespace: dict[str, Any] | None, + typevars_map: dict[Any, Any] | None, + defs: _Definitions, + ) -> GenerateSchema: + obj = cls.__new__(cls) + obj._config_wrapper_stack = config_wrapper_stack + obj._types_namespace = types_namespace + obj._typevars_map = typevars_map + obj._needs_apply_discriminated_union = False + obj._has_invalid_schema = False + obj.field_name_stack = _FieldNameStack() + obj.defs = defs + return obj @property def _config_wrapper(self) -> ConfigWrapper: return self._config_wrapper_stack.tail @property - def _types_namespace(self) -> NamespacesTuple: - return self._ns_resolver.types_namespace + def _current_generate_schema(self) -> GenerateSchema: + cls = self._config_wrapper.schema_generator or GenerateSchema + return cls.__from_parent( + self._config_wrapper_stack, + self._types_namespace, + self._typevars_map, + self.defs, + ) @property def _arbitrary_types(self) -> bool: return self._config_wrapper.arbitrary_types_allowed + def str_schema(self) -> CoreSchema: + """Generate a CoreSchema for `str`""" + return core_schema.str_schema() + # the following methods can be overridden but should be considered # unstable / private APIs - def _list_schema(self, items_type: Any) -> CoreSchema: + def _list_schema(self, tp: Any, items_type: Any) -> CoreSchema: return core_schema.list_schema(self.generate_schema(items_type)) - def _dict_schema(self, keys_type: Any, values_type: Any) -> CoreSchema: + def _dict_schema(self, tp: Any, keys_type: Any, values_type: Any) -> CoreSchema: return core_schema.dict_schema(self.generate_schema(keys_type), self.generate_schema(values_type)) - def _set_schema(self, items_type: Any) -> CoreSchema: + def _set_schema(self, tp: Any, items_type: Any) -> CoreSchema: return core_schema.set_schema(self.generate_schema(items_type)) - def _frozenset_schema(self, items_type: Any) -> CoreSchema: + def _frozenset_schema(self, tp: Any, items_type: Any) -> CoreSchema: return core_schema.frozenset_schema(self.generate_schema(items_type)) - def _enum_schema(self, enum_type: type[Enum]) -> CoreSchema: - cases: list[Any] = list(enum_type.__members__.values()) + def _tuple_variable_schema(self, tp: Any, items_type: Any) -> CoreSchema: + return core_schema.tuple_variable_schema(self.generate_schema(items_type)) - 
enum_ref = get_type_ref(enum_type) - description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__) - if ( - description == 'An enumeration.' - ): # This is the default value provided by enum.EnumMeta.__new__; don't use it - description = None - js_updates = {'title': enum_type.__name__, 'description': description} - js_updates = {k: v for k, v in js_updates.items() if v is not None} - - sub_type: Literal['str', 'int', 'float'] | None = None - if issubclass(enum_type, int): - sub_type = 'int' - value_ser_type: core_schema.SerSchema = core_schema.simple_ser_schema('int') - elif issubclass(enum_type, str): - # this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)` - sub_type = 'str' - value_ser_type = core_schema.simple_ser_schema('str') - elif issubclass(enum_type, float): - sub_type = 'float' - value_ser_type = core_schema.simple_ser_schema('float') - else: - # TODO this is an ugly hack, how do we trigger an Any schema for serialization? - value_ser_type = core_schema.plain_serializer_function_ser_schema(lambda x: x) - - if cases: - - def get_json_schema(schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: - json_schema = handler(schema) - original_schema = handler.resolve_ref_schema(json_schema) - original_schema.update(js_updates) - return json_schema - - # we don't want to add the missing to the schema if it's the default one - default_missing = getattr(enum_type._missing_, '__func__', None) is Enum._missing_.__func__ # pyright: ignore[reportFunctionMemberAccess] - enum_schema = core_schema.enum_schema( - enum_type, - cases, - sub_type=sub_type, - missing=None if default_missing else enum_type._missing_, - ref=enum_ref, - metadata={'pydantic_js_functions': [get_json_schema]}, - ) - - if self._config_wrapper.use_enum_values: - enum_schema = core_schema.no_info_after_validator_function( - attrgetter('value'), enum_schema, serialization=value_ser_type - ) - - return enum_schema - - else: - - def get_json_schema_no_cases(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue: - json_schema = handler(core_schema.enum_schema(enum_type, cases, sub_type=sub_type, ref=enum_ref)) - original_schema = handler.resolve_ref_schema(json_schema) - original_schema.update(js_updates) - return json_schema - - # Use an isinstance check for enums with no cases. - # The most important use case for this is creating TypeVar bounds for generics that should - # be restricted to enums. This is more consistent than it might seem at first, since you can only - # subclass enum.Enum (or subclasses of enum.Enum) if all parent classes have no cases. - # We use the get_json_schema function when an Enum subclass has been declared with no cases - # so that we can still generate a valid json schema. - return core_schema.is_instance_schema( - enum_type, - metadata={'pydantic_js_functions': [get_json_schema_no_cases]}, - ) - - def _ip_schema(self, tp: Any) -> CoreSchema: - from ._validators import IP_VALIDATOR_LOOKUP, IpType - - ip_type_json_schema_format: dict[type[IpType], str] = { - IPv4Address: 'ipv4', - IPv4Network: 'ipv4network', - IPv4Interface: 'ipv4interface', - IPv6Address: 'ipv6', - IPv6Network: 'ipv6network', - IPv6Interface: 'ipv6interface', - } - - def ser_ip(ip: Any, info: core_schema.SerializationInfo) -> str | IpType: - if not isinstance(ip, (tp, str)): - raise PydanticSerializationUnexpectedValue( - f"Expected `{tp}` but got `{type(ip)}` with value `'{ip}'` - serialized value may not be as expected." 
- ) - if info.mode == 'python': - return ip - return str(ip) - - return core_schema.lax_or_strict_schema( - lax_schema=core_schema.no_info_plain_validator_function(IP_VALIDATOR_LOOKUP[tp]), - strict_schema=core_schema.json_or_python_schema( - json_schema=core_schema.no_info_after_validator_function(tp, core_schema.str_schema()), - python_schema=core_schema.is_instance_schema(tp), - ), - serialization=core_schema.plain_serializer_function_ser_schema(ser_ip, info_arg=True, when_used='always'), - metadata={ - 'pydantic_js_functions': [lambda _1, _2: {'type': 'string', 'format': ip_type_json_schema_format[tp]}] - }, - ) - - def _path_schema(self, tp: Any, path_type: Any) -> CoreSchema: - if tp is os.PathLike and (path_type not in {str, bytes} and not typing_objects.is_any(path_type)): - raise PydanticUserError( - '`os.PathLike` can only be used with `str`, `bytes` or `Any`', code='schema-for-unknown-type' - ) - - path_constructor = pathlib.PurePath if tp is os.PathLike else tp - strict_inner_schema = ( - core_schema.bytes_schema(strict=True) if (path_type is bytes) else core_schema.str_schema(strict=True) - ) - lax_inner_schema = core_schema.bytes_schema() if (path_type is bytes) else core_schema.str_schema() - - def path_validator(input_value: str | bytes) -> os.PathLike[Any]: # type: ignore - try: - if path_type is bytes: - if isinstance(input_value, bytes): - try: - input_value = input_value.decode() - except UnicodeDecodeError as e: - raise PydanticCustomError('bytes_type', 'Input must be valid bytes') from e - else: - raise PydanticCustomError('bytes_type', 'Input must be bytes') - elif not isinstance(input_value, str): - raise PydanticCustomError('path_type', 'Input is not a valid path') - - return path_constructor(input_value) # type: ignore - except TypeError as e: - raise PydanticCustomError('path_type', 'Input is not a valid path') from e - - def ser_path(path: Any, info: core_schema.SerializationInfo) -> str | os.PathLike[Any]: - if not isinstance(path, (tp, str)): - raise PydanticSerializationUnexpectedValue( - f"Expected `{tp}` but got `{type(path)}` with value `'{path}'` - serialized value may not be as expected." 
- ) - if info.mode == 'python': - return path - return str(path) - - instance_schema = core_schema.json_or_python_schema( - json_schema=core_schema.no_info_after_validator_function(path_validator, lax_inner_schema), - python_schema=core_schema.is_instance_schema(tp), - ) - - schema = core_schema.lax_or_strict_schema( - lax_schema=core_schema.union_schema( - [ - instance_schema, - core_schema.no_info_after_validator_function(path_validator, strict_inner_schema), - ], - custom_error_type='path_type', - custom_error_message=f'Input is not a valid path for {tp}', - ), - strict_schema=instance_schema, - serialization=core_schema.plain_serializer_function_ser_schema(ser_path, info_arg=True, when_used='always'), - metadata={'pydantic_js_functions': [lambda source, handler: {**handler(source), 'format': 'path'}]}, - ) - return schema - - def _deque_schema(self, items_type: Any) -> CoreSchema: - from ._serializers import serialize_sequence_via_list - from ._validators import deque_validator - - item_type_schema = self.generate_schema(items_type) - - # we have to use a lax list schema here, because we need to validate the deque's - # items via a list schema, but it's ok if the deque itself is not a list - list_schema = core_schema.list_schema(item_type_schema, strict=False) - - check_instance = core_schema.json_or_python_schema( - json_schema=list_schema, - python_schema=core_schema.is_instance_schema(collections.deque, cls_repr='Deque'), - ) - - lax_schema = core_schema.no_info_wrap_validator_function(deque_validator, list_schema) - - return core_schema.lax_or_strict_schema( - lax_schema=lax_schema, - strict_schema=core_schema.chain_schema([check_instance, lax_schema]), - serialization=core_schema.wrap_serializer_function_ser_schema( - serialize_sequence_via_list, schema=item_type_schema, info_arg=True - ), - ) - - def _mapping_schema(self, tp: Any, keys_type: Any, values_type: Any) -> CoreSchema: - from ._validators import MAPPING_ORIGIN_MAP, defaultdict_validator, get_defaultdict_default_default_factory - - mapped_origin = MAPPING_ORIGIN_MAP[tp] - keys_schema = self.generate_schema(keys_type) - values_schema = self.generate_schema(values_type) - dict_schema = core_schema.dict_schema(keys_schema, values_schema, strict=False) - - if mapped_origin is dict: - schema = dict_schema - else: - check_instance = core_schema.json_or_python_schema( - json_schema=dict_schema, - python_schema=core_schema.is_instance_schema(mapped_origin), - ) - - if tp is collections.defaultdict: - default_default_factory = get_defaultdict_default_default_factory(values_type) - coerce_instance_wrap = partial( - core_schema.no_info_wrap_validator_function, - partial(defaultdict_validator, default_default_factory=default_default_factory), - ) - else: - coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, mapped_origin) - - lax_schema = coerce_instance_wrap(dict_schema) - strict_schema = core_schema.chain_schema([check_instance, lax_schema]) - - schema = core_schema.lax_or_strict_schema( - lax_schema=lax_schema, - strict_schema=strict_schema, - serialization=core_schema.wrap_serializer_function_ser_schema( - lambda v, h: h(v), schema=dict_schema, info_arg=False - ), - ) - - return schema - - def _fraction_schema(self) -> CoreSchema: - """Support for [`fractions.Fraction`][fractions.Fraction].""" - from ._validators import fraction_validator - - # TODO: note, this is a fairly common pattern, re lax / strict for attempted type coercion, - # can we use a helper function to reduce boilerplate? 
- return core_schema.lax_or_strict_schema( - lax_schema=core_schema.no_info_plain_validator_function(fraction_validator), - strict_schema=core_schema.json_or_python_schema( - json_schema=core_schema.no_info_plain_validator_function(fraction_validator), - python_schema=core_schema.is_instance_schema(Fraction), - ), - # use str serialization to guarantee round trip behavior - serialization=core_schema.to_string_ser_schema(when_used='always'), - metadata={'pydantic_js_functions': [lambda _1, _2: {'type': 'string', 'format': 'fraction'}]}, - ) + def _tuple_positional_schema(self, tp: Any, items_types: list[Any]) -> CoreSchema: + items_schemas = [self.generate_schema(items_type) for items_type in items_types] + return core_schema.tuple_positional_schema(items_schemas) def _arbitrary_type_schema(self, tp: Any) -> CoreSchema: if not isinstance(tp, type): @@ -646,49 +373,58 @@ class GenerateSchema: ' `__get_pydantic_core_schema__` on `` otherwise to avoid infinite recursion.' ) - def _apply_discriminator_to_union( - self, schema: CoreSchema, discriminator: str | Discriminator | None - ) -> CoreSchema: - if discriminator is None: - return schema + def _apply_discriminator_to_union(self, schema: CoreSchema, discriminator: Any) -> CoreSchema: try: return _discriminated_union.apply_discriminator( schema, discriminator, - self.defs._definitions, ) except _discriminated_union.MissingDefinitionForUnionRef: # defer until defs are resolved - _discriminated_union.set_discriminator_in_metadata( + _discriminated_union.set_discriminator( schema, discriminator, ) + if 'metadata' in schema: + schema['metadata'][NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY] = True + else: + schema['metadata'] = {NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY: True} + self._needs_apply_discriminated_union = True return schema - def clean_schema(self, schema: CoreSchema) -> CoreSchema: - schema = self.defs.finalize_schema(schema) - schema = validate_core_schema(schema) - return schema + def collect_definitions(self, schema: CoreSchema) -> CoreSchema: + ref = cast('str | None', schema.get('ref', None)) + if ref: + self.defs.definitions[ref] = schema + if 'ref' in schema: + schema = core_schema.definition_reference_schema(schema['ref']) + return core_schema.definitions_schema( + schema, + list(self.defs.definitions.values()), + ) def _add_js_function(self, metadata_schema: CoreSchema, js_function: Callable[..., Any]) -> None: - metadata = metadata_schema.get('metadata', {}) + metadata = CoreMetadataHandler(metadata_schema).metadata pydantic_js_functions = metadata.setdefault('pydantic_js_functions', []) # because of how we generate core schemas for nested generic models # we can end up adding `BaseModel.__get_pydantic_json_schema__` multiple times # this check may fail to catch duplicates if the function is a `functools.partial` - # or something like that, but if it does it'll fail by inserting the duplicate + # or something like that + # but if it does it'll fail by inserting the duplicate if js_function not in pydantic_js_functions: pydantic_js_functions.append(js_function) - metadata_schema['metadata'] = metadata def generate_schema( self, obj: Any, + from_dunder_get_core_schema: bool = True, ) -> core_schema.CoreSchema: """Generate core schema. Args: obj: The object to generate core schema for. + from_dunder_get_core_schema: Whether to generate schema from either the + `__get_pydantic_core_schema__` function or `__pydantic_core_schema__` property. Returns: The generated core schema. 
@@ -699,125 +435,77 @@ class GenerateSchema: PydanticSchemaGenerationError: If it is not possible to generate pydantic-core schema. TypeError: - - If `alias_generator` returns a disallowed type (must be str, AliasPath or AliasChoices). + - If `alias_generator` returns a non-string value. - If V1 style validator with `each_item=True` applied on a wrong field. PydanticUserError: - If `typing.TypedDict` is used instead of `typing_extensions.TypedDict` on Python < 3.12. - If `__modify_schema__` method is used instead of `__get_pydantic_json_schema__`. """ - schema = self._generate_schema_from_get_schema_method(obj, obj) + schema: CoreSchema | None = None + + if from_dunder_get_core_schema: + from_property = self._generate_schema_from_property(obj, obj) + if from_property is not None: + schema = from_property if schema is None: - schema = self._generate_schema_inner(obj) + schema = self._generate_schema(obj) - metadata_js_function = _extract_get_pydantic_json_schema(obj) + metadata_js_function = _extract_get_pydantic_json_schema(obj, schema) if metadata_js_function is not None: - metadata_schema = resolve_original_schema(schema, self.defs) + metadata_schema = resolve_original_schema(schema, self.defs.definitions) if metadata_schema: self._add_js_function(metadata_schema, metadata_js_function) schema = _add_custom_serialization_from_json_encoders(self._config_wrapper.json_encoders, obj, schema) + schema = self._post_process_generated_schema(schema) + return schema def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema: """Generate schema for a Pydantic model.""" - BaseModel_ = import_cached_base_model() - with self.defs.get_schema_or_ref(cls) as (model_ref, maybe_schema): if maybe_schema is not None: return maybe_schema - schema = cls.__dict__.get('__pydantic_core_schema__') - if schema is not None and not isinstance(schema, MockCoreSchema): - if schema['type'] == 'definitions': - schema = self.defs.unpack_definitions(schema) - ref = get_ref(schema) - if ref: - return self.defs.create_definition_reference_schema(schema) - else: - return schema - + fields = cls.model_fields + decorators = cls.__pydantic_decorators__ + computed_fields = decorators.computed_fields + check_decorator_fields_exist( + chain( + decorators.field_validators.values(), + decorators.field_serializers.values(), + decorators.validators.values(), + ), + {*fields.keys(), *computed_fields.keys()}, + ) config_wrapper = ConfigWrapper(cls.model_config, check=False) + core_config = config_wrapper.core_config(cls) + metadata = build_metadata_dict(js_functions=[partial(modify_model_json_schema, cls=cls)]) - with self._config_wrapper_stack.push(config_wrapper), self._ns_resolver.push(cls): - core_config = self._config_wrapper.core_config(title=cls.__name__) + model_validators = decorators.model_validators.values() - if cls.__pydantic_fields_complete__ or cls is BaseModel_: - fields = getattr(cls, '__pydantic_fields__', {}) - else: - if not hasattr(cls, '__pydantic_fields__'): - # This happens when we have a loop in the schema generation: - # class Base[T](BaseModel): - # t: T - # - # class Other(BaseModel): - # b: 'Base[Other]' - # When we build fields for `Other`, we evaluate the forward annotation. - # At this point, `Other` doesn't have the model fields set. We create - # `Base[Other]`; model fields are successfully built, and we try to generate - # a schema for `t: Other`. As `Other.__pydantic_fields__` aren't set, we abort. 
- raise PydanticUndefinedAnnotation( - name=cls.__name__, - message=f'Class {cls.__name__!r} is not defined', - ) - try: - fields = rebuild_model_fields( - cls, - ns_resolver=self._ns_resolver, - typevars_map=self._typevars_map or {}, - ) - except NameError as e: - raise PydanticUndefinedAnnotation.from_name_error(e) from e - - decorators = cls.__pydantic_decorators__ - computed_fields = decorators.computed_fields - check_decorator_fields_exist( - chain( - decorators.field_validators.values(), - decorators.field_serializers.values(), - decorators.validators.values(), - ), - {*fields.keys(), *computed_fields.keys()}, - ) - - model_validators = decorators.model_validators.values() - - extras_schema = None - extras_keys_schema = None - if core_config.get('extra_fields_behavior') == 'allow': - assert cls.__mro__[0] is cls - assert cls.__mro__[-1] is object - for candidate_cls in cls.__mro__[:-1]: - extras_annotation = getattr(candidate_cls, '__annotations__', {}).get( - '__pydantic_extra__', None - ) - if extras_annotation is not None: - if isinstance(extras_annotation, str): - extras_annotation = _typing_extra.eval_type_backport( - _typing_extra._make_forward_ref( - extras_annotation, is_argument=False, is_class=True - ), - *self._types_namespace, - ) - tp = get_origin(extras_annotation) - if tp not in DICT_TYPES: - raise PydanticSchemaGenerationError( - 'The type annotation for `__pydantic_extra__` must be `dict[str, ...]`' - ) - extra_keys_type, extra_items_type = self._get_args_resolving_forward_refs( - extras_annotation, - required=True, + extras_schema = None + if core_config.get('extra_fields_behavior') == 'allow': + for tp in (cls, *cls.__mro__): + extras_annotation = cls.__annotations__.get('__pydantic_extra__', None) + if extras_annotation is not None: + tp = get_origin(extras_annotation) + if tp not in (Dict, dict): + raise PydanticSchemaGenerationError( + 'The type annotation for `__pydantic_extra__` must be `Dict[str, ...]`' ) - if extra_keys_type is not str: - extras_keys_schema = self.generate_schema(extra_keys_type) - if not typing_objects.is_any(extra_items_type): - extras_schema = self.generate_schema(extra_items_type) - if extras_keys_schema is not None or extras_schema is not None: - break - - generic_origin: type[BaseModel] | None = getattr(cls, '__pydantic_generic_metadata__', {}).get('origin') + extra_items_type = self._get_args_resolving_forward_refs( + cls.__annotations__['__pydantic_extra__'], + required=True, + )[1] + if extra_items_type is not Any: + extras_schema = self.generate_schema(extra_items_type) + break + with self._config_wrapper_stack.push(config_wrapper): + self = self._current_generate_schema if cls.__pydantic_root_model__: root_field = self._common_field_schema('root', fields['root'], decorators) inner_schema = root_field['schema'] @@ -825,12 +513,12 @@ class GenerateSchema: model_schema = core_schema.model_schema( cls, inner_schema, - generic_origin=generic_origin, custom_init=getattr(cls, '__pydantic_custom_init__', None), root_model=True, post_init=getattr(cls, '__pydantic_post_init__', None), config=core_config, ref=model_ref, + metadata=metadata, ) else: fields_schema: core_schema.CoreSchema = core_schema.model_fields_schema( @@ -840,91 +528,89 @@ class GenerateSchema: for d in computed_fields.values() ], extras_schema=extras_schema, - extras_keys_schema=extras_keys_schema, model_name=cls.__name__, ) inner_schema = apply_validators(fields_schema, decorators.root_validators.values(), None) + new_inner_schema = 
define_expected_missing_refs(inner_schema, recursively_defined_type_refs()) + if new_inner_schema is not None: + inner_schema = new_inner_schema inner_schema = apply_model_validators(inner_schema, model_validators, 'inner') model_schema = core_schema.model_schema( cls, inner_schema, - generic_origin=generic_origin, custom_init=getattr(cls, '__pydantic_custom_init__', None), root_model=False, post_init=getattr(cls, '__pydantic_post_init__', None), config=core_config, ref=model_ref, + metadata=metadata, ) schema = self._apply_model_serializers(model_schema, decorators.model_serializers.values()) schema = apply_model_validators(schema, model_validators, 'outer') - return self.defs.create_definition_reference_schema(schema) + self.defs.definitions[model_ref] = self._post_process_generated_schema(schema) + return core_schema.definition_reference_schema(model_ref) - def _resolve_self_type(self, obj: Any) -> Any: - obj = self.model_type_stack.get() - if obj is None: - raise PydanticUserError('`typing.Self` is invalid in this context', code='invalid-self-type') - return obj + def _unpack_refs_defs(self, schema: CoreSchema) -> CoreSchema: + """Unpack all 'definitions' schemas into `GenerateSchema.defs.definitions` + and return the inner schema. + """ - def _generate_schema_from_get_schema_method(self, obj: Any, source: Any) -> core_schema.CoreSchema | None: - BaseModel_ = import_cached_base_model() + def get_ref(s: CoreSchema) -> str: + return s['ref'] # type: ignore + if schema['type'] == 'definitions': + self.defs.definitions.update({get_ref(s): s for s in schema['definitions']}) + schema = schema['schema'] + return schema + + def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.CoreSchema | None: + """Try to generate schema from either the `__get_pydantic_core_schema__` function or + `__pydantic_core_schema__` property. + + Note: `__get_pydantic_core_schema__` takes priority so it can + decide whether to use a `__pydantic_core_schema__` attribute, or generate a fresh schema. + """ + # avoid calling `__get_pydantic_core_schema__` if we've already visited this object + with self.defs.get_schema_or_ref(obj) as (_, maybe_schema): + if maybe_schema is not None: + return maybe_schema + if obj is source: + ref_mode = 'unpack' + else: + ref_mode = 'to-def' + + schema: CoreSchema get_schema = getattr(obj, '__get_pydantic_core_schema__', None) - is_base_model_get_schema = ( - getattr(get_schema, '__func__', None) is BaseModel_.__get_pydantic_core_schema__.__func__ # pyright: ignore[reportFunctionMemberAccess] - ) - - if ( - get_schema is not None - # BaseModel.__get_pydantic_core_schema__ is defined for backwards compatibility, - # to allow existing code to call `super().__get_pydantic_core_schema__` in Pydantic - # model that overrides `__get_pydantic_core_schema__`. However, it raises a deprecation - # warning stating that the method will be removed, and during the core schema gen we actually - # don't call the method: - and not is_base_model_get_schema - ): - # Some referenceable types might have a `__get_pydantic_core_schema__` method - # defined on it by users (e.g. on a dataclass). 
This generally doesn't play well - # as these types are already recognized by the `GenerateSchema` class and isn't ideal - # as we might end up calling `get_schema_or_ref` (expensive) on types that are actually - # not referenceable: - with self.defs.get_schema_or_ref(obj) as (_, maybe_schema): - if maybe_schema is not None: - return maybe_schema - - if obj is source: - ref_mode = 'unpack' - else: - ref_mode = 'to-def' - schema = get_schema( - source, CallbackGetCoreSchemaHandler(self._generate_schema_inner, self, ref_mode=ref_mode) + if get_schema is None: + validators = getattr(obj, '__get_validators__', None) + if validators is None: + return None + warn( + '`__get_validators__` is deprecated and will be removed, use `__get_pydantic_core_schema__` instead.', + PydanticDeprecatedSince20, ) - if schema['type'] == 'definitions': - schema = self.defs.unpack_definitions(schema) - - ref = get_ref(schema) - if ref: - return self.defs.create_definition_reference_schema(schema) - - # Note: if schema is of type `'definition-ref'`, we might want to copy it as a - # safety measure (because these are inlined in place -- i.e. mutated directly) - return schema - - if get_schema is None and (validators := getattr(obj, '__get_validators__', None)) is not None: - from pydantic.v1 import BaseModel as BaseModelV1 - - if issubclass(obj, BaseModelV1): - warn( - f'Mixing V1 models and V2 models (or constructs, like `TypeAdapter`) is not supported. Please upgrade `{obj.__name__}` to V2.', - UserWarning, - ) + schema = core_schema.chain_schema([core_schema.with_info_plain_validator_function(v) for v in validators()]) + else: + if len(inspect.signature(get_schema).parameters) == 1: + # (source) -> CoreSchema + schema = get_schema(source) else: - warn( - '`__get_validators__` is deprecated and will be removed, use `__get_pydantic_core_schema__` instead.', - PydanticDeprecatedSince20, + schema = get_schema( + source, CallbackGetCoreSchemaHandler(self._generate_schema, self, ref_mode=ref_mode) ) - return core_schema.chain_schema([core_schema.with_info_plain_validator_function(v) for v in validators()]) + + schema = self._unpack_refs_defs(schema) + + ref = get_ref(schema) + if ref: + self.defs.definitions[ref] = self._post_process_generated_schema(schema) + return core_schema.definition_reference_schema(ref) + + schema = self._post_process_generated_schema(schema) + + return schema def _resolve_forward_ref(self, obj: Any) -> Any: # we assume that types_namespace has the target of forward references in its scope, @@ -935,7 +621,7 @@ class GenerateSchema: # class Model(BaseModel): # x: SomeImportedTypeAliasWithAForwardReference try: - obj = _typing_extra.eval_type_backport(obj, *self._types_namespace) + obj = _typing_extra.evaluate_fwd_ref(obj, globalns=self._types_namespace) except NameError as e: raise PydanticUndefinedAnnotation.from_name_error(e) from e @@ -949,18 +635,17 @@ class GenerateSchema: return obj @overload - def _get_args_resolving_forward_refs(self, obj: Any, required: Literal[True]) -> tuple[Any, ...]: ... + def _get_args_resolving_forward_refs(self, obj: Any, required: Literal[True]) -> tuple[Any, ...]: + ... @overload - def _get_args_resolving_forward_refs(self, obj: Any) -> tuple[Any, ...] | None: ... + def _get_args_resolving_forward_refs(self, obj: Any) -> tuple[Any, ...] | None: + ... def _get_args_resolving_forward_refs(self, obj: Any, required: bool = False) -> tuple[Any, ...] 
| None: args = get_args(obj) if args: - if isinstance(obj, GenericAlias): - # PEP 585 generic aliases don't convert args to ForwardRefs, unlike `typing.List/Dict` etc. - args = (_typing_extra._make_forward_ref(a) if isinstance(a, str) else a for a in args) - args = tuple(self._resolve_forward_ref(a) if isinstance(a, ForwardRef) else a for a in args) + args = tuple([self._resolve_forward_ref(a) if isinstance(a, ForwardRef) else a for a in args]) elif required: # pragma: no cover raise TypeError(f'Expected {obj} to have generic parameters but it had none') return args @@ -980,11 +665,29 @@ class GenerateSchema: raise TypeError(f'Expected two type arguments for {origin}, got 1') return args[0], args[1] - def _generate_schema_inner(self, obj: Any) -> core_schema.CoreSchema: - if typing_objects.is_self(obj): - obj = self._resolve_self_type(obj) + def _post_process_generated_schema(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema: + if 'metadata' in schema: + metadata = schema['metadata'] + metadata[NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY] = self._needs_apply_discriminated_union + else: + schema['metadata'] = { + NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY: self._needs_apply_discriminated_union, + } + return schema - if typing_objects.is_annotated(get_origin(obj)): + def _generate_schema(self, obj: Any) -> core_schema.CoreSchema: + """Recursively generate a pydantic-core schema for any supported python type.""" + has_invalid_schema = self._has_invalid_schema + self._has_invalid_schema = False + needs_apply_discriminated_union = self._needs_apply_discriminated_union + self._needs_apply_discriminated_union = False + schema = self._post_process_generated_schema(self._generate_schema_inner(obj)) + self._has_invalid_schema = self._has_invalid_schema or has_invalid_schema + self._needs_apply_discriminated_union = self._needs_apply_discriminated_union or needs_apply_discriminated_union + return schema + + def _generate_schema_inner(self, obj: Any) -> core_schema.CoreSchema: + if isinstance(obj, _AnnotatedType): return self._annotated_schema(obj) if isinstance(obj, dict): @@ -997,11 +700,10 @@ class GenerateSchema: if isinstance(obj, ForwardRef): return self.generate_schema(self._resolve_forward_ref(obj)) - BaseModel = import_cached_base_model() + from ..main import BaseModel if lenient_issubclass(obj, BaseModel): - with self.model_type_stack.push(obj): - return self._model_schema(obj) + return self._model_schema(obj) if isinstance(obj, PydanticRecursiveRef): return core_schema.definition_reference_schema(schema_ref=obj.type_ref) @@ -1022,7 +724,7 @@ class GenerateSchema: as they get requested and we figure out what the right API for them is. 
""" if obj is str: - return core_schema.str_schema() + return self.str_schema() elif obj is bytes: return core_schema.bytes_schema() elif obj is int: @@ -1031,92 +733,61 @@ class GenerateSchema: return core_schema.float_schema() elif obj is bool: return core_schema.bool_schema() - elif obj is complex: - return core_schema.complex_schema() - elif typing_objects.is_any(obj) or obj is object: + elif obj is Any or obj is object: return core_schema.any_schema() - elif obj is datetime.date: - return core_schema.date_schema() - elif obj is datetime.datetime: - return core_schema.datetime_schema() - elif obj is datetime.time: - return core_schema.time_schema() - elif obj is datetime.timedelta: - return core_schema.timedelta_schema() - elif obj is Decimal: - return core_schema.decimal_schema() - elif obj is UUID: - return core_schema.uuid_schema() - elif obj is Url: - return core_schema.url_schema() - elif obj is Fraction: - return self._fraction_schema() - elif obj is MultiHostUrl: - return core_schema.multi_host_url_schema() elif obj is None or obj is _typing_extra.NoneType: return core_schema.none_schema() - elif obj in IP_TYPES: - return self._ip_schema(obj) elif obj in TUPLE_TYPES: return self._tuple_schema(obj) elif obj in LIST_TYPES: - return self._list_schema(Any) + return self._list_schema(obj, self._get_first_arg_or_any(obj)) elif obj in SET_TYPES: - return self._set_schema(Any) + return self._set_schema(obj, self._get_first_arg_or_any(obj)) elif obj in FROZEN_SET_TYPES: - return self._frozenset_schema(Any) - elif obj in SEQUENCE_TYPES: - return self._sequence_schema(Any) - elif obj in ITERABLE_TYPES: - return self._iterable_schema(obj) + return self._frozenset_schema(obj, self._get_first_arg_or_any(obj)) elif obj in DICT_TYPES: - return self._dict_schema(Any, Any) - elif obj in PATH_TYPES: - return self._path_schema(obj, Any) - elif obj in DEQUE_TYPES: - return self._deque_schema(Any) - elif obj in MAPPING_TYPES: - return self._mapping_schema(obj, Any, Any) - elif obj in COUNTER_TYPES: - return self._mapping_schema(obj, Any, int) - elif typing_objects.is_typealiastype(obj): + return self._dict_schema(obj, *self._get_first_two_args_or_any(obj)) + elif isinstance(obj, TypeAliasType): return self._type_alias_type_schema(obj) - elif obj is type: + elif obj == type: return self._type_schema() - elif _typing_extra.is_callable(obj): + elif _typing_extra.is_callable_type(obj): return core_schema.callable_schema() - elif typing_objects.is_literal(get_origin(obj)): + elif _typing_extra.is_literal_type(obj): return self._literal_schema(obj) elif is_typeddict(obj): return self._typed_dict_schema(obj, None) elif _typing_extra.is_namedtuple(obj): return self._namedtuple_schema(obj, None) - elif typing_objects.is_newtype(obj): - # NewType, can't use isinstance because it fails <3.10 + elif _typing_extra.is_new_type(obj): + # NewType, can't use isinstance because it fails <3.7 return self.generate_schema(obj.__supertype__) - elif obj in PATTERN_TYPES: + elif obj == re.Pattern: return self._pattern_schema(obj) - elif _typing_extra.is_hashable(obj): + elif obj is collections.abc.Hashable or obj is typing.Hashable: return self._hashable_schema() elif isinstance(obj, typing.TypeVar): return self._unsubstituted_typevar_schema(obj) - elif _typing_extra.is_finalvar(obj): + elif is_finalvar(obj): if obj is Final: return core_schema.any_schema() return self.generate_schema( self._get_first_arg_or_any(obj), ) - elif isinstance(obj, VALIDATE_CALL_SUPPORTED_TYPES): - return self._call_schema(obj) + elif 
isinstance(obj, (FunctionType, LambdaType, MethodType, partial)): + return self._callable_schema(obj) elif inspect.isclass(obj) and issubclass(obj, Enum): - return self._enum_schema(obj) - elif obj is ZoneInfo: - return self._zoneinfo_schema() + from ._std_types_schema import get_enum_core_schema - # dataclasses.is_dataclass coerces dc instances to types, but we only handle - # the case of a dc type here - if dataclasses.is_dataclass(obj): - return self._dataclass_schema(obj, None) # pyright: ignore[reportArgumentType] + return get_enum_core_schema(obj, self._config_wrapper.config_dict) + + if _typing_extra.is_dataclass(obj): + return self._dataclass_schema(obj, None) + + res = self._get_prepare_pydantic_annotations_for_known_type(obj, ()) + if res is not None: + source_type, annotations = res + return self._apply_annotations(source_type, annotations) origin = get_origin(obj) if origin is not None: @@ -1127,50 +798,43 @@ class GenerateSchema: return self._unknown_type_schema(obj) def _match_generic_type(self, obj: Any, origin: Any) -> CoreSchema: # noqa: C901 + if isinstance(origin, TypeAliasType): + return self._type_alias_type_schema(obj) + # Need to handle generic dataclasses before looking for the schema properties because attribute accesses # on _GenericAlias delegate to the origin type, so lose the information about the concrete parametrization # As a result, currently, there is no way to cache the schema for generic dataclasses. This may be possible # to resolve by modifying the value returned by `Generic.__class_getitem__`, but that is a dangerous game. - if dataclasses.is_dataclass(origin): - return self._dataclass_schema(obj, origin) # pyright: ignore[reportArgumentType] + if _typing_extra.is_dataclass(origin): + return self._dataclass_schema(obj, origin) if _typing_extra.is_namedtuple(origin): return self._namedtuple_schema(obj, origin) - schema = self._generate_schema_from_get_schema_method(origin, obj) - if schema is not None: - return schema + from_property = self._generate_schema_from_property(origin, obj) + if from_property is not None: + return from_property - if typing_objects.is_typealiastype(origin): - return self._type_alias_type_schema(obj) - elif is_union_origin(origin): + if _typing_extra.origin_is_union(origin): return self._union_schema(obj) elif origin in TUPLE_TYPES: return self._tuple_schema(obj) elif origin in LIST_TYPES: - return self._list_schema(self._get_first_arg_or_any(obj)) + return self._list_schema(obj, self._get_first_arg_or_any(obj)) elif origin in SET_TYPES: - return self._set_schema(self._get_first_arg_or_any(obj)) + return self._set_schema(obj, self._get_first_arg_or_any(obj)) elif origin in FROZEN_SET_TYPES: - return self._frozenset_schema(self._get_first_arg_or_any(obj)) + return self._frozenset_schema(obj, self._get_first_arg_or_any(obj)) elif origin in DICT_TYPES: - return self._dict_schema(*self._get_first_two_args_or_any(obj)) - elif origin in PATH_TYPES: - return self._path_schema(origin, self._get_first_arg_or_any(obj)) - elif origin in DEQUE_TYPES: - return self._deque_schema(self._get_first_arg_or_any(obj)) - elif origin in MAPPING_TYPES: - return self._mapping_schema(origin, *self._get_first_two_args_or_any(obj)) - elif origin in COUNTER_TYPES: - return self._mapping_schema(origin, self._get_first_arg_or_any(obj), int) + return self._dict_schema(obj, *self._get_first_two_args_or_any(obj)) elif is_typeddict(origin): return self._typed_dict_schema(obj, origin) - elif origin in TYPE_TYPES: + elif origin in (typing.Type, type): return 
self._subclass_schema(obj) - elif origin in SEQUENCE_TYPES: - return self._sequence_schema(self._get_first_arg_or_any(obj)) - elif origin in ITERABLE_TYPES: + elif origin in {typing.Sequence, collections.abc.Sequence}: + return self._sequence_schema(obj) + elif origin in {typing.Iterable, collections.abc.Iterable, typing.Generator, collections.abc.Generator}: return self._iterable_schema(obj) - elif origin in PATTERN_TYPES: + elif origin in (re.Pattern, typing.Pattern): return self._pattern_schema(obj) if self._arbitrary_types: @@ -1224,7 +888,6 @@ class GenerateSchema: return core_schema.dataclass_field( name, common_field['schema'], - init=field_info.init, init_only=field_info.init_var or None, kw_only=None if field_info.kw_only else False, serialization_exclude=common_field['serialization_exclude'], @@ -1234,144 +897,32 @@ class GenerateSchema: metadata=common_field['metadata'], ) - @staticmethod - def _apply_alias_generator_to_field_info( - alias_generator: Callable[[str], str] | AliasGenerator, field_info: FieldInfo, field_name: str - ) -> None: - """Apply an alias_generator to aliases on a FieldInfo instance if appropriate. + def _common_field_schema(self, name: str, field_info: FieldInfo, decorators: DecoratorInfos) -> _CommonField: + # Update FieldInfo annotation if appropriate: + if has_instance_in_type(field_info.annotation, (ForwardRef, str)): + types_namespace = self._types_namespace + if self._typevars_map: + types_namespace = (types_namespace or {}).copy() + # Ensure that typevars get mapped to their concrete types: + types_namespace.update({k.__name__: v for k, v in self._typevars_map.items()}) - Args: - alias_generator: A callable that takes a string and returns a string, or an AliasGenerator instance. - field_info: The FieldInfo instance to which the alias_generator is (maybe) applied. - field_name: The name of the field from which to generate the alias. - """ - # Apply an alias_generator if - # 1. An alias is not specified - # 2. 
An alias is specified, but the priority is <= 1 - if ( - field_info.alias_priority is None - or field_info.alias_priority <= 1 - or field_info.alias is None - or field_info.validation_alias is None - or field_info.serialization_alias is None - ): - alias, validation_alias, serialization_alias = None, None, None + evaluated = _typing_extra.eval_type_lenient(field_info.annotation, types_namespace, None) + if evaluated is not field_info.annotation and not has_instance_in_type(evaluated, PydanticRecursiveRef): + field_info.annotation = evaluated - if isinstance(alias_generator, AliasGenerator): - alias, validation_alias, serialization_alias = alias_generator.generate_aliases(field_name) - elif isinstance(alias_generator, Callable): - alias = alias_generator(field_name) - if not isinstance(alias, str): - raise TypeError(f'alias_generator {alias_generator} must return str, not {alias.__class__}') - - # if priority is not set, we set to 1 - # which supports the case where the alias_generator from a child class is used - # to generate an alias for a field in a parent class - if field_info.alias_priority is None or field_info.alias_priority <= 1: - field_info.alias_priority = 1 - - # if the priority is 1, then we set the aliases to the generated alias - if field_info.alias_priority == 1: - field_info.serialization_alias = _get_first_non_null(serialization_alias, alias) - field_info.validation_alias = _get_first_non_null(validation_alias, alias) - field_info.alias = alias - - # if any of the aliases are not set, then we set them to the corresponding generated alias - if field_info.alias is None: - field_info.alias = alias - if field_info.serialization_alias is None: - field_info.serialization_alias = _get_first_non_null(serialization_alias, alias) - if field_info.validation_alias is None: - field_info.validation_alias = _get_first_non_null(validation_alias, alias) - - @staticmethod - def _apply_alias_generator_to_computed_field_info( - alias_generator: Callable[[str], str] | AliasGenerator, - computed_field_info: ComputedFieldInfo, - computed_field_name: str, - ): - """Apply an alias_generator to alias on a ComputedFieldInfo instance if appropriate. - - Args: - alias_generator: A callable that takes a string and returns a string, or an AliasGenerator instance. - computed_field_info: The ComputedFieldInfo instance to which the alias_generator is (maybe) applied. - computed_field_name: The name of the computed field from which to generate the alias. - """ - # Apply an alias_generator if - # 1. An alias is not specified - # 2. 
An alias is specified, but the priority is <= 1 - - if ( - computed_field_info.alias_priority is None - or computed_field_info.alias_priority <= 1 - or computed_field_info.alias is None - ): - alias, validation_alias, serialization_alias = None, None, None - - if isinstance(alias_generator, AliasGenerator): - alias, validation_alias, serialization_alias = alias_generator.generate_aliases(computed_field_name) - elif isinstance(alias_generator, Callable): - alias = alias_generator(computed_field_name) - if not isinstance(alias, str): - raise TypeError(f'alias_generator {alias_generator} must return str, not {alias.__class__}') - - # if priority is not set, we set to 1 - # which supports the case where the alias_generator from a child class is used - # to generate an alias for a field in a parent class - if computed_field_info.alias_priority is None or computed_field_info.alias_priority <= 1: - computed_field_info.alias_priority = 1 - - # if the priority is 1, then we set the aliases to the generated alias - # note that we use the serialization_alias with priority over alias, as computed_field - # aliases are used for serialization only (not validation) - if computed_field_info.alias_priority == 1: - computed_field_info.alias = _get_first_non_null(serialization_alias, alias) - - @staticmethod - def _apply_field_title_generator_to_field_info( - config_wrapper: ConfigWrapper, field_info: FieldInfo | ComputedFieldInfo, field_name: str - ) -> None: - """Apply a field_title_generator on a FieldInfo or ComputedFieldInfo instance if appropriate - Args: - config_wrapper: The config of the model - field_info: The FieldInfo or ComputedField instance to which the title_generator is (maybe) applied. - field_name: The name of the field from which to generate the title. 
- """ - field_title_generator = field_info.field_title_generator or config_wrapper.field_title_generator - - if field_title_generator is None: - return - - if field_info.title is None: - title = field_title_generator(field_name, field_info) # type: ignore - if not isinstance(title, str): - raise TypeError(f'field_title_generator {field_title_generator} must return str, not {title.__class__}') - - field_info.title = title - - def _common_field_schema( # C901 - self, name: str, field_info: FieldInfo, decorators: DecoratorInfos - ) -> _CommonField: source_type, annotations = field_info.annotation, field_info.metadata def set_discriminator(schema: CoreSchema) -> CoreSchema: schema = self._apply_discriminator_to_union(schema, field_info.discriminator) return schema - # Convert `@field_validator` decorators to `Before/After/Plain/WrapValidator` instances: - validators_from_decorators = [] - for decorator in filter_field_decorator_info_by_field(decorators.field_validators.values(), name): - validators_from_decorators.append(_mode_to_validator[decorator.info.mode]._from_decorator(decorator)) - with self.field_name_stack.push(name): if field_info.discriminator is not None: - schema = self._apply_annotations( - source_type, annotations + validators_from_decorators, transform_inner_schema=set_discriminator - ) + schema = self._apply_annotations(source_type, annotations, transform_inner_schema=set_discriminator) else: schema = self._apply_annotations( source_type, - annotations + validators_from_decorators, + annotations, ) # This V1 compatibility shim should eventually be removed @@ -1385,7 +936,10 @@ class GenerateSchema: this_field_validators = [v for v in this_field_validators if v not in each_item_validators] schema = apply_each_item_validators(schema, each_item_validators, name) - schema = apply_validators(schema, this_field_validators, name) + schema = apply_validators(schema, filter_field_decorator_info_by_field(this_field_validators, name), name) + schema = apply_validators( + schema, filter_field_decorator_info_by_field(decorators.field_validators.values(), name), name + ) # the default validator needs to go outside of any other validators # so that it is the topmost validator for the field validator @@ -1396,17 +950,35 @@ class GenerateSchema: schema = self._apply_field_serializers( schema, filter_field_decorator_info_by_field(decorators.field_serializers.values(), name) ) - self._apply_field_title_generator_to_field_info(self._config_wrapper, field_info, name) + json_schema_updates = { + 'title': field_info.title, + 'description': field_info.description, + 'examples': to_jsonable_python(field_info.examples), + } + json_schema_updates = {k: v for k, v in json_schema_updates.items() if v is not None} - pydantic_js_updates, pydantic_js_extra = _extract_json_schema_info_from_field_info(field_info) - core_metadata: dict[str, Any] = {} - update_core_metadata( - core_metadata, pydantic_js_updates=pydantic_js_updates, pydantic_js_extra=pydantic_js_extra - ) + json_schema_extra = field_info.json_schema_extra + def json_schema_update_func(schema: CoreSchemaOrField, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + json_schema = {**handler(schema), **json_schema_updates} + if isinstance(json_schema_extra, dict): + json_schema.update(to_jsonable_python(json_schema_extra)) + elif callable(json_schema_extra): + json_schema_extra(json_schema) + return json_schema + + metadata = build_metadata_dict(js_annotation_functions=[json_schema_update_func]) + + # apply alias generator alias_generator = 
self._config_wrapper.alias_generator - if alias_generator is not None: - self._apply_alias_generator_to_field_info(alias_generator, field_info, name) + if alias_generator and (field_info.alias_priority is None or field_info.alias_priority <= 1): + alias = alias_generator(name) + if not isinstance(alias, str): + raise TypeError(f'alias_generator {alias_generator} must return str, not {alias.__class__}') + field_info.alias = alias + field_info.validation_alias = alias + field_info.serialization_alias = alias + field_info.alias_priority = 1 if isinstance(field_info.validation_alias, (AliasChoices, AliasPath)): validation_alias = field_info.validation_alias.convert_to_aliases() @@ -1419,13 +991,13 @@ class GenerateSchema: validation_alias=validation_alias, serialization_alias=field_info.serialization_alias, frozen=field_info.frozen, - metadata=core_metadata, + metadata=metadata, ) def _union_schema(self, union_type: Any) -> core_schema.CoreSchema: """Generate schema for a Union.""" args = self._get_args_resolving_forward_refs(union_type, required=True) - choices: list[CoreSchema] = [] + choices: list[CoreSchema | tuple[CoreSchema, str]] = [] nullable = False for arg in args: if arg is None or arg is _typing_extra.NoneType: @@ -1434,72 +1006,62 @@ class GenerateSchema: choices.append(self.generate_schema(arg)) if len(choices) == 1: - s = choices[0] + first_choice = choices[0] + s = first_choice[0] if isinstance(first_choice, tuple) else first_choice else: - choices_with_tags: list[CoreSchema | tuple[CoreSchema, str]] = [] - for choice in choices: - tag = cast(CoreMetadata, choice.get('metadata', {})).get('pydantic_internal_union_tag_key') - if tag is not None: - choices_with_tags.append((choice, tag)) - else: - choices_with_tags.append(choice) - s = core_schema.union_schema(choices_with_tags) + s = core_schema.union_schema(choices) if nullable: s = core_schema.nullable_schema(s) return s - def _type_alias_type_schema(self, obj: TypeAliasType) -> CoreSchema: - with self.defs.get_schema_or_ref(obj) as (ref, maybe_schema): + def _type_alias_type_schema( + self, + obj: Any, # TypeAliasType + ) -> CoreSchema: + origin = get_origin(obj) + origin = origin or obj + with self.defs.get_schema_or_ref(origin) as (ref, maybe_schema): if maybe_schema is not None: return maybe_schema - origin: TypeAliasType = get_origin(obj) or obj - typevars_map = get_standard_typevars_map(obj) + namespace = (self._types_namespace or {}).copy() + new_namespace = {**_typing_extra.get_cls_types_namespace(origin), **namespace} + annotation = origin.__value__ - with self._ns_resolver.push(origin): - try: - annotation = _typing_extra.eval_type(origin.__value__, *self._types_namespace) - except NameError as e: - raise PydanticUndefinedAnnotation.from_name_error(e) from e - annotation = replace_types(annotation, typevars_map) - schema = self.generate_schema(annotation) - assert schema['type'] != 'definitions' - schema['ref'] = ref # type: ignore - return self.defs.create_definition_reference_schema(schema) + self._types_namespace = new_namespace + typevars_map = get_standard_typevars_map(obj) + annotation = replace_types(annotation, typevars_map) + schema = self.generate_schema(annotation) + assert schema['type'] != 'definitions' + schema['ref'] = ref # type: ignore + self._types_namespace = namespace or None + self.defs.definitions[ref] = schema + return core_schema.definition_reference_schema(ref) def _literal_schema(self, literal_type: Any) -> CoreSchema: """Generate schema for a Literal.""" - expected = 
list(get_literal_values(literal_type, type_check=False, unpack_type_aliases='eager')) + expected = _typing_extra.all_literal_values(literal_type) assert expected, f'literal "expected" cannot be empty, obj={literal_type}' - schema = core_schema.literal_schema(expected) - - if self._config_wrapper.use_enum_values and any(isinstance(v, Enum) for v in expected): - schema = core_schema.no_info_after_validator_function( - lambda v: v.value if isinstance(v, Enum) else v, schema - ) - - return schema + return core_schema.literal_schema(expected) def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.CoreSchema: - """Generate a core schema for a `TypedDict` class. + """Generate schema for a TypedDict. - To be able to build a `DecoratorInfos` instance for the `TypedDict` class (which will include - validators, serializers, etc.), we need to have access to the original bases of the class - (see https://docs.python.org/3/library/types.html#types.get_original_bases). - However, the `__orig_bases__` attribute was only added in 3.12 (https://github.com/python/cpython/pull/103698). + It is not possible to track required/optional keys in TypedDict without __required_keys__ + since TypedDict.__new__ erases the base classes (it replaces them with just `dict`) + and thus we can track usage of total=True/False + __required_keys__ was added in Python 3.9 + (https://github.com/miss-islington/cpython/blob/1e9939657dd1f8eb9f596f77c1084d2d351172fc/Doc/library/typing.rst?plain=1#L1546-L1548) + however it is buggy + (https://github.com/python/typing_extensions/blob/ac52ac5f2cb0e00e7988bae1e2a1b8257ac88d6d/src/typing_extensions.py#L657-L666). - For this reason, we require Python 3.12 (or using the `typing_extensions` backport). + On 3.11 but < 3.12 TypedDict does not preserve inheritance information. 
+ + Hence to avoid creating validators that do not do what users expect we only + support typing.TypedDict on Python >= 3.12 or typing_extension.TypedDict on all versions """ - FieldInfo = import_cached_field_info() - - with ( - self.model_type_stack.push(typed_dict_cls), - self.defs.get_schema_or_ref(typed_dict_cls) as ( - typed_dict_ref, - maybe_schema, - ), - ): + with self.defs.get_schema_or_ref(typed_dict_cls) as (typed_dict_ref, maybe_schema): if maybe_schema is not None: return maybe_schema @@ -1514,14 +1076,14 @@ class GenerateSchema: ) try: - # if a typed dictionary class doesn't have config, we use the parent's config, hence a default of `None` - # see https://github.com/pydantic/pydantic/issues/10917 config: ConfigDict | None = get_attribute_from_bases(typed_dict_cls, '__pydantic_config__') except AttributeError: config = None with self._config_wrapper_stack.push(config): - core_config = self._config_wrapper.core_config(title=typed_dict_cls.__name__) + core_config = self._config_wrapper.core_config(typed_dict_cls) + + self = self._current_generate_schema required_keys: frozenset[str] = typed_dict_cls.__required_keys__ @@ -1529,86 +1091,65 @@ class GenerateSchema: decorators = DecoratorInfos.build(typed_dict_cls) - if self._config_wrapper.use_attribute_docstrings: - field_docstrings = extract_docstrings_from_cls(typed_dict_cls, use_inspect=True) - else: - field_docstrings = None + for field_name, annotation in get_type_hints_infer_globalns( + typed_dict_cls, localns=self._types_namespace, include_extras=True + ).items(): + annotation = replace_types(annotation, typevars_map) + required = field_name in required_keys - try: - annotations = _typing_extra.get_cls_type_hints(typed_dict_cls, ns_resolver=self._ns_resolver) - except NameError as e: - raise PydanticUndefinedAnnotation.from_name_error(e) from e + if get_origin(annotation) == _typing_extra.Required: + required = True + annotation = self._get_args_resolving_forward_refs( + annotation, + required=True, + )[0] + elif get_origin(annotation) == _typing_extra.NotRequired: + required = False + annotation = self._get_args_resolving_forward_refs( + annotation, + required=True, + )[0] - readonly_fields: list[str] = [] - - for field_name, annotation in annotations.items(): - field_info = FieldInfo.from_annotation(annotation, _source=AnnotationSource.TYPED_DICT) - field_info.annotation = replace_types(field_info.annotation, typevars_map) - - required = ( - field_name in required_keys or 'required' in field_info._qualifiers - ) and 'not_required' not in field_info._qualifiers - if 'read_only' in field_info._qualifiers: - readonly_fields.append(field_name) - - if ( - field_docstrings is not None - and field_info.description is None - and field_name in field_docstrings - ): - field_info.description = field_docstrings[field_name] - self._apply_field_title_generator_to_field_info(self._config_wrapper, field_info, field_name) + field_info = FieldInfo.from_annotation(annotation) fields[field_name] = self._generate_td_field_schema( field_name, field_info, decorators, required=required ) - if readonly_fields: - fields_repr = ', '.join(repr(f) for f in readonly_fields) - plural = len(readonly_fields) >= 2 - warnings.warn( - f'Item{"s" if plural else ""} {fields_repr} on TypedDict class {typed_dict_cls.__name__!r} ' - f'{"are" if plural else "is"} using the `ReadOnly` qualifier. 
Pydantic will not protect items ' - 'from any mutation on dictionary instances.', - UserWarning, - ) + metadata = build_metadata_dict( + js_functions=[partial(modify_model_json_schema, cls=typed_dict_cls)], typed_dict_cls=typed_dict_cls + ) td_schema = core_schema.typed_dict_schema( fields, - cls=typed_dict_cls, computed_fields=[ self._computed_field_schema(d, decorators.field_serializers) for d in decorators.computed_fields.values() ], ref=typed_dict_ref, + metadata=metadata, config=core_config, ) schema = self._apply_model_serializers(td_schema, decorators.model_serializers.values()) schema = apply_model_validators(schema, decorators.model_validators.values(), 'all') - return self.defs.create_definition_reference_schema(schema) + self.defs.definitions[typed_dict_ref] = self._post_process_generated_schema(schema) + return core_schema.definition_reference_schema(typed_dict_ref) def _namedtuple_schema(self, namedtuple_cls: Any, origin: Any) -> core_schema.CoreSchema: """Generate schema for a NamedTuple.""" - with ( - self.model_type_stack.push(namedtuple_cls), - self.defs.get_schema_or_ref(namedtuple_cls) as ( - namedtuple_ref, - maybe_schema, - ), - ): + with self.defs.get_schema_or_ref(namedtuple_cls) as (namedtuple_ref, maybe_schema): if maybe_schema is not None: return maybe_schema typevars_map = get_standard_typevars_map(namedtuple_cls) if origin is not None: namedtuple_cls = origin - try: - annotations = _typing_extra.get_cls_type_hints(namedtuple_cls, ns_resolver=self._ns_resolver) - except NameError as e: - raise PydanticUndefinedAnnotation.from_name_error(e) from e + annotations: dict[str, Any] = get_type_hints_infer_globalns( + namedtuple_cls, include_extras=True, localns=self._types_namespace + ) if not annotations: # annotations is empty, happens if namedtuple_cls defined via collections.namedtuple(...) - annotations: dict[str, Any] = {k: Any for k in namedtuple_cls._fields} + annotations = {k: Any for k in namedtuple_cls._fields} if typevars_map: annotations = { @@ -1619,40 +1160,30 @@ class GenerateSchema: arguments_schema = core_schema.arguments_schema( [ self._generate_parameter_schema( - field_name, - annotation, - source=AnnotationSource.NAMED_TUPLE, - default=namedtuple_cls._field_defaults.get(field_name, Parameter.empty), + field_name, annotation, default=namedtuple_cls._field_defaults.get(field_name, Parameter.empty) ) for field_name, annotation in annotations.items() ], - metadata={'pydantic_js_prefer_positional_arguments': True}, + metadata=build_metadata_dict(js_prefer_positional_arguments=True), ) - schema = core_schema.call_schema(arguments_schema, namedtuple_cls, ref=namedtuple_ref) - return self.defs.create_definition_reference_schema(schema) + return core_schema.call_schema(arguments_schema, namedtuple_cls, ref=namedtuple_ref) def _generate_parameter_schema( self, name: str, annotation: type[Any], - source: AnnotationSource, default: Any = Parameter.empty, mode: Literal['positional_only', 'positional_or_keyword', 'keyword_only'] | None = None, ) -> core_schema.ArgumentsParameter: - """Generate the definition of a field in a namedtuple or a parameter in a function signature. - - This definition is meant to be used for the `'arguments'` core schema, which will be replaced - in V3 by the `'arguments-v3`'. 
- """ - FieldInfo = import_cached_field_info() - + """Prepare a ArgumentsParameter to represent a field in a namedtuple or function signature.""" if default is Parameter.empty: - field = FieldInfo.from_annotation(annotation, _source=source) + field = FieldInfo.from_annotation(annotation) else: - field = FieldInfo.from_annotated_attribute(annotation, default, _source=source) + field = FieldInfo.from_annotated_attribute(annotation, default) assert field.annotation is not None, 'field.annotation should not be None when generating a schema' + source_type, annotations = field.annotation, field.metadata with self.field_name_stack.push(name): - schema = self._apply_annotations(field.annotation, [field]) + schema = self._apply_annotations(source_type, annotations) if not field.is_required(): schema = wrap_default(field, schema) @@ -1664,61 +1195,10 @@ class GenerateSchema: parameter_schema['alias'] = field.alias else: alias_generator = self._config_wrapper.alias_generator - if isinstance(alias_generator, AliasGenerator) and alias_generator.alias is not None: - parameter_schema['alias'] = alias_generator.alias(name) - elif callable(alias_generator): + if alias_generator: parameter_schema['alias'] = alias_generator(name) return parameter_schema - def _generate_parameter_v3_schema( - self, - name: str, - annotation: Any, - source: AnnotationSource, - mode: Literal[ - 'positional_only', - 'positional_or_keyword', - 'keyword_only', - 'var_args', - 'var_kwargs_uniform', - 'var_kwargs_unpacked_typed_dict', - ], - default: Any = Parameter.empty, - ) -> core_schema.ArgumentsV3Parameter: - """Generate the definition of a parameter in a function signature. - - This definition is meant to be used for the `'arguments-v3'` core schema, which will replace - the `'arguments`' schema in V3. - """ - FieldInfo = import_cached_field_info() - - if default is Parameter.empty: - field = FieldInfo.from_annotation(annotation, _source=source) - else: - field = FieldInfo.from_annotated_attribute(annotation, default, _source=source) - - with self.field_name_stack.push(name): - schema = self._apply_annotations(field.annotation, [field]) - - if not field.is_required(): - schema = wrap_default(field, schema) - - parameter_schema = core_schema.arguments_v3_parameter( - name=name, - schema=schema, - mode=mode, - ) - if field.alias is not None: - parameter_schema['alias'] = field.alias - else: - alias_generator = self._config_wrapper.alias_generator - if isinstance(alias_generator, AliasGenerator) and alias_generator.alias is not None: - parameter_schema['alias'] = alias_generator.alias(name) - elif callable(alias_generator): - parameter_schema['alias'] = alias_generator(name) - - return parameter_schema - def _tuple_schema(self, tuple_type: Any) -> core_schema.CoreSchema: """Generate schema for a Tuple, e.g. `tuple[int, str]` or `tuple[int, ...]`.""" # TODO: do we really need to resolve type vars here? 
@@ -1732,22 +1212,22 @@ class GenerateSchema: # This is only true for <3.11, on Python 3.11+ `typing.Tuple[()]` gives `params=()` if not params: if tuple_type in TUPLE_TYPES: - return core_schema.tuple_schema([core_schema.any_schema()], variadic_item_index=0) + return core_schema.tuple_variable_schema() else: # special case for `tuple[()]` which means `tuple[]` - an empty tuple - return core_schema.tuple_schema([]) + return core_schema.tuple_positional_schema([]) elif params[-1] is Ellipsis: if len(params) == 2: - return core_schema.tuple_schema([self.generate_schema(params[0])], variadic_item_index=0) + return self._tuple_variable_schema(tuple_type, params[0]) else: # TODO: something like https://github.com/pydantic/pydantic/issues/5952 raise ValueError('Variable tuples can only have one type') elif len(params) == 1 and params[0] == (): - # special case for `tuple[()]` which means `tuple[]` - an empty tuple + # special case for `Tuple[()]` which means `Tuple[]` - an empty tuple # NOTE: This conditional can be removed when we drop support for Python 3.10. - return core_schema.tuple_schema([]) + return self._tuple_positional_schema(tuple_type, []) else: - return core_schema.tuple_schema([self.generate_schema(param) for param in params]) + return self._tuple_positional_schema(tuple_type, list(params)) def _type_schema(self) -> core_schema.CoreSchema: return core_schema.custom_error_schema( @@ -1756,83 +1236,39 @@ class GenerateSchema: custom_error_message='Input should be a type', ) - def _zoneinfo_schema(self) -> core_schema.CoreSchema: - """Generate schema for a zone_info.ZoneInfo object""" - from ._validators import validate_str_is_valid_iana_tz - - metadata = {'pydantic_js_functions': [lambda _1, _2: {'type': 'string', 'format': 'zoneinfo'}]} - return core_schema.no_info_plain_validator_function( - validate_str_is_valid_iana_tz, - serialization=core_schema.to_string_ser_schema(), - metadata=metadata, - ) - - def _union_is_subclass_schema(self, union_type: Any) -> core_schema.CoreSchema: - """Generate schema for `type[Union[X, ...]]`.""" - args = self._get_args_resolving_forward_refs(union_type, required=True) - return core_schema.union_schema([self.generate_schema(type[args]) for args in args]) - def _subclass_schema(self, type_: Any) -> core_schema.CoreSchema: - """Generate schema for a type, e.g. `type[int]`.""" + """Generate schema for a Type, e.g. 
`Type[int]`.""" type_param = self._get_first_arg_or_any(type_) - - # Assume `type[Annotated[, ...]]` is equivalent to `type[]`: - type_param = _typing_extra.annotated_type(type_param) or type_param - - if typing_objects.is_any(type_param): + if type_param == Any: return self._type_schema() - elif typing_objects.is_typealiastype(type_param): - return self.generate_schema(type[type_param.__value__]) - elif typing_objects.is_typevar(type_param): + elif isinstance(type_param, typing.TypeVar): if type_param.__bound__: - if is_union_origin(get_origin(type_param.__bound__)): - return self._union_is_subclass_schema(type_param.__bound__) return core_schema.is_subclass_schema(type_param.__bound__) elif type_param.__constraints__: - return core_schema.union_schema([self.generate_schema(type[c]) for c in type_param.__constraints__]) + return core_schema.union_schema( + [self.generate_schema(typing.Type[c]) for c in type_param.__constraints__] + ) else: return self._type_schema() - elif is_union_origin(get_origin(type_param)): - return self._union_is_subclass_schema(type_param) + elif _typing_extra.origin_is_union(get_origin(type_param)): + args = self._get_args_resolving_forward_refs(type_param, required=True) + return core_schema.union_schema([self.generate_schema(typing.Type[args]) for args in args]) else: - if typing_objects.is_self(type_param): - type_param = self._resolve_self_type(type_param) - if _typing_extra.is_generic_alias(type_param): - raise PydanticUserError( - 'Subscripting `type[]` with an already parametrized type is not supported. ' - f'Instead of using type[{type_param!r}], use type[{_repr.display_as_type(get_origin(type_param))}].', - code=None, - ) - if not inspect.isclass(type_param): - # when using type[None], this doesn't type convert to type[NoneType], and None isn't a class - # so we handle it manually here - if type_param is None: - return core_schema.is_subclass_schema(_typing_extra.NoneType) - raise TypeError(f'Expected a class, got {type_param!r}') return core_schema.is_subclass_schema(type_param) - def _sequence_schema(self, items_type: Any) -> core_schema.CoreSchema: + def _sequence_schema(self, sequence_type: Any) -> core_schema.CoreSchema: """Generate schema for a Sequence, e.g. `Sequence[int]`.""" - from ._serializers import serialize_sequence_via_list + item_type = self._get_first_arg_or_any(sequence_type) - item_type_schema = self.generate_schema(items_type) - list_schema = core_schema.list_schema(item_type_schema) - - json_schema = smart_deepcopy(list_schema) + list_schema = core_schema.list_schema(self.generate_schema(item_type)) python_schema = core_schema.is_instance_schema(typing.Sequence, cls_repr='Sequence') - if not typing_objects.is_any(items_type): + if item_type != Any: from ._validators import sequence_validator python_schema = core_schema.chain_schema( [python_schema, core_schema.no_info_wrap_validator_function(sequence_validator, list_schema)], ) - - serialization = core_schema.wrap_serializer_function_ser_schema( - serialize_sequence_via_list, schema=item_type_schema, info_arg=True - ) - return core_schema.json_or_python_schema( - json_schema=json_schema, python_schema=python_schema, serialization=serialization - ) + return core_schema.json_or_python_schema(json_schema=list_schema, python_schema=python_schema) def _iterable_schema(self, type_: Any) -> core_schema.GeneratorSchema: """Generate a schema for an `Iterable`.""" @@ -1843,11 +1279,11 @@ class GenerateSchema: def _pattern_schema(self, pattern_type: Any) -> core_schema.CoreSchema: from . 
import _validators - metadata = {'pydantic_js_functions': [lambda _1, _2: {'type': 'string', 'format': 'regex'}]} + metadata = build_metadata_dict(js_functions=[lambda _1, _2: {'type': 'string', 'format': 'regex'}]) ser = core_schema.plain_serializer_function_ser_schema( attrgetter('pattern'), when_used='json', return_schema=core_schema.str_schema() ) - if pattern_type is typing.Pattern or pattern_type is re.Pattern: + if pattern_type == typing.Pattern or pattern_type == re.Pattern: # bare type return core_schema.no_info_plain_validator_function( _validators.pattern_either_validator, serialization=ser, metadata=metadata @@ -1857,11 +1293,11 @@ class GenerateSchema: pattern_type, required=True, )[0] - if param is str: + if param == str: return core_schema.no_info_plain_validator_function( _validators.pattern_str_validator, serialization=ser, metadata=metadata ) - elif param is bytes: + elif param == bytes: return core_schema.no_info_plain_validator_function( _validators.pattern_bytes_validator, serialization=ser, metadata=metadata ) @@ -1870,12 +1306,7 @@ class GenerateSchema: def _hashable_schema(self) -> core_schema.CoreSchema: return core_schema.custom_error_schema( - schema=core_schema.json_or_python_schema( - json_schema=core_schema.chain_schema( - [core_schema.any_schema(), core_schema.is_instance_schema(collections.abc.Hashable)] - ), - python_schema=core_schema.is_instance_schema(collections.abc.Hashable), - ), + core_schema.is_instance_schema(collections.abc.Hashable), custom_error_type='is_hashable', custom_error_message='Input should be hashable', ) @@ -1884,77 +1315,33 @@ class GenerateSchema: self, dataclass: type[StandardDataclass], origin: type[StandardDataclass] | None ) -> core_schema.CoreSchema: """Generate schema for a dataclass.""" - with ( - self.model_type_stack.push(dataclass), - self.defs.get_schema_or_ref(dataclass) as ( - dataclass_ref, - maybe_schema, - ), - ): + with self.defs.get_schema_or_ref(dataclass) as (dataclass_ref, maybe_schema): if maybe_schema is not None: return maybe_schema - schema = dataclass.__dict__.get('__pydantic_core_schema__') - if schema is not None and not isinstance(schema, MockCoreSchema): - if schema['type'] == 'definitions': - schema = self.defs.unpack_definitions(schema) - ref = get_ref(schema) - if ref: - return self.defs.create_definition_reference_schema(schema) - else: - return schema - typevars_map = get_standard_typevars_map(dataclass) if origin is not None: dataclass = origin - # if (plain) dataclass doesn't have config, we use the parent's config, hence a default of `None` - # (Pydantic dataclasses have an empty dict config by default). - # see https://github.com/pydantic/pydantic/issues/10917 config = getattr(dataclass, '__pydantic_config__', None) + with self._config_wrapper_stack.push(config): + core_config = self._config_wrapper.core_config(dataclass) - from ..dataclasses import is_pydantic_dataclass + self = self._current_generate_schema + + from ..dataclasses import is_pydantic_dataclass - with self._ns_resolver.push(dataclass), self._config_wrapper_stack.push(config): if is_pydantic_dataclass(dataclass): - if dataclass.__pydantic_fields_complete__(): - # Copy the field info instances to avoid mutating the `FieldInfo` instances - # of the generic dataclass generic origin (e.g. `apply_typevars_map` below). 
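The typevar map applied to the copied `FieldInfo` instances here is what lets parametrized pydantic dataclasses validate correctly without mutating the unparametrized origin; a small sketch of that case (hypothetical `Box` dataclass, assuming pydantic v2):

from typing import Generic, TypeVar

from pydantic import TypeAdapter
from pydantic.dataclasses import dataclass

T = TypeVar('T')


@dataclass
class Box(Generic[T]):
    item: T


# Box[int] and Box[str] must not share mutated FieldInfo state from the Box origin.
print(TypeAdapter(Box[int]).validate_python({'item': '3'}))  # -> Box(item=3)
print(TypeAdapter(Box[str]).validate_python({'item': 'x'}))  # -> Box(item='x')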
- # Note that we don't apply `deepcopy` on `__pydantic_fields__` because we - # don't want to copy the `FieldInfo` attributes: - fields = { - f_name: copy(field_info) for f_name, field_info in dataclass.__pydantic_fields__.items() - } - if typevars_map: - for field in fields.values(): - field.apply_typevars_map(typevars_map, *self._types_namespace) - else: - try: - fields = rebuild_dataclass_fields( - dataclass, - config_wrapper=self._config_wrapper, - ns_resolver=self._ns_resolver, - typevars_map=typevars_map or {}, - ) - except NameError as e: - raise PydanticUndefinedAnnotation.from_name_error(e) from e + fields = deepcopy(dataclass.__pydantic_fields__) + if typevars_map: + for field in fields.values(): + field.apply_typevars_map(typevars_map, self._types_namespace) else: fields = collect_dataclass_fields( dataclass, + self._types_namespace, typevars_map=typevars_map, - config_wrapper=self._config_wrapper, ) - - if self._config_wrapper.extra == 'allow': - # disallow combination of init=False on a dataclass field and extra='allow' on a dataclass - for field_name, field in fields.items(): - if field.init is False: - raise PydanticUserError( - f'Field {field_name} has `init=False` and dataclass has config setting `extra="allow"`. ' - f'This combination is not allowed.', - code='dataclass-init-false-extra-allow', - ) - decorators = dataclass.__dict__.get('__pydantic_decorators__') or DecoratorInfos.build(dataclass) # Move kw_only=False args to the start of the list, as this is how vanilla dataclasses work. # Note that when kw_only is missing or None, it is treated as equivalent to kw_only=True @@ -1980,270 +1367,156 @@ class GenerateSchema: model_validators = decorators.model_validators.values() inner_schema = apply_model_validators(inner_schema, model_validators, 'inner') - core_config = self._config_wrapper.core_config(title=dataclass.__name__) - dc_schema = core_schema.dataclass_schema( dataclass, inner_schema, - generic_origin=origin, post_init=has_post_init, ref=dataclass_ref, fields=[field.name for field in dataclasses.fields(dataclass)], slots=has_slots, config=core_config, - # we don't use a custom __setattr__ for dataclasses, so we must - # pass along the frozen config setting to the pydantic-core schema - frozen=self._config_wrapper_stack.tail.frozen, ) schema = self._apply_model_serializers(dc_schema, decorators.model_serializers.values()) schema = apply_model_validators(schema, model_validators, 'outer') - return self.defs.create_definition_reference_schema(schema) + self.defs.definitions[dataclass_ref] = self._post_process_generated_schema(schema) + return core_schema.definition_reference_schema(dataclass_ref) - def _call_schema(self, function: ValidateCallSupportedTypes) -> core_schema.CallSchema: + def _callable_schema(self, function: Callable[..., Any]) -> core_schema.CallSchema: """Generate schema for a Callable. 
TODO support functional validators once we support them in Config """ - arguments_schema = self._arguments_schema(function) + sig = signature(function) - return_schema: core_schema.CoreSchema | None = None - config_wrapper = self._config_wrapper - if config_wrapper.validate_return: - sig = signature(function) - return_hint = sig.return_annotation - if return_hint is not sig.empty: - globalns, localns = self._types_namespace - type_hints = _typing_extra.get_function_type_hints( - function, globalns=globalns, localns=localns, include_keys={'return'} - ) - return_schema = self.generate_schema(type_hints['return']) + type_hints = _typing_extra.get_function_type_hints(function) - return core_schema.call_schema( - arguments_schema, - function, - return_schema=return_schema, - ) - - def _arguments_schema( - self, function: ValidateCallSupportedTypes, parameters_callback: ParametersCallback | None = None - ) -> core_schema.ArgumentsSchema: - """Generate schema for a Signature.""" mode_lookup: dict[_ParameterKind, Literal['positional_only', 'positional_or_keyword', 'keyword_only']] = { Parameter.POSITIONAL_ONLY: 'positional_only', Parameter.POSITIONAL_OR_KEYWORD: 'positional_or_keyword', Parameter.KEYWORD_ONLY: 'keyword_only', } - sig = signature(function) - globalns, localns = self._types_namespace - type_hints = _typing_extra.get_function_type_hints(function, globalns=globalns, localns=localns) - arguments_list: list[core_schema.ArgumentsParameter] = [] var_args_schema: core_schema.CoreSchema | None = None var_kwargs_schema: core_schema.CoreSchema | None = None - var_kwargs_mode: core_schema.VarKwargsMode | None = None - for i, (name, p) in enumerate(sig.parameters.items()): + for name, p in sig.parameters.items(): if p.annotation is sig.empty: - annotation = typing.cast(Any, Any) + annotation = Any else: annotation = type_hints[name] - if parameters_callback is not None: - result = parameters_callback(i, name, annotation) - if result == 'skip': - continue - parameter_mode = mode_lookup.get(p.kind) if parameter_mode is not None: - arg_schema = self._generate_parameter_schema( - name, annotation, AnnotationSource.FUNCTION, p.default, parameter_mode - ) + arg_schema = self._generate_parameter_schema(name, annotation, p.default, parameter_mode) arguments_list.append(arg_schema) elif p.kind == Parameter.VAR_POSITIONAL: var_args_schema = self.generate_schema(annotation) else: assert p.kind == Parameter.VAR_KEYWORD, p.kind + var_kwargs_schema = self.generate_schema(annotation) - unpack_type = _typing_extra.unpack_type(annotation) - if unpack_type is not None: - origin = get_origin(unpack_type) or unpack_type - if not is_typeddict(origin): - raise PydanticUserError( - f'Expected a `TypedDict` class inside `Unpack[...]`, got {unpack_type!r}', - code='unpack-typed-dict', - ) - non_pos_only_param_names = { - name for name, p in sig.parameters.items() if p.kind != Parameter.POSITIONAL_ONLY - } - overlapping_params = non_pos_only_param_names.intersection(origin.__annotations__) - if overlapping_params: - raise PydanticUserError( - f'Typed dictionary {origin.__name__!r} overlaps with parameter' - f'{"s" if len(overlapping_params) >= 2 else ""} ' - f'{", ".join(repr(p) for p in sorted(overlapping_params))}', - code='overlapping-unpack-typed-dict', - ) + return_schema: core_schema.CoreSchema | None = None + config_wrapper = self._config_wrapper + if config_wrapper.validate_return: + return_hint = type_hints.get('return') + if return_hint is not None: + return_schema = self.generate_schema(return_hint) - 
var_kwargs_mode = 'unpacked-typed-dict' - var_kwargs_schema = self._typed_dict_schema(unpack_type, get_origin(unpack_type)) - else: - var_kwargs_mode = 'uniform' - var_kwargs_schema = self.generate_schema(annotation) - - return core_schema.arguments_schema( - arguments_list, - var_args_schema=var_args_schema, - var_kwargs_mode=var_kwargs_mode, - var_kwargs_schema=var_kwargs_schema, - validate_by_name=self._config_wrapper.validate_by_name, - ) - - def _arguments_v3_schema( - self, function: ValidateCallSupportedTypes, parameters_callback: ParametersCallback | None = None - ) -> core_schema.ArgumentsV3Schema: - mode_lookup: dict[ - _ParameterKind, Literal['positional_only', 'positional_or_keyword', 'var_args', 'keyword_only'] - ] = { - Parameter.POSITIONAL_ONLY: 'positional_only', - Parameter.POSITIONAL_OR_KEYWORD: 'positional_or_keyword', - Parameter.VAR_POSITIONAL: 'var_args', - Parameter.KEYWORD_ONLY: 'keyword_only', - } - - sig = signature(function) - globalns, localns = self._types_namespace - type_hints = _typing_extra.get_function_type_hints(function, globalns=globalns, localns=localns) - - parameters_list: list[core_schema.ArgumentsV3Parameter] = [] - - for i, (name, p) in enumerate(sig.parameters.items()): - if parameters_callback is not None: - result = parameters_callback(i, name, p.annotation) - if result == 'skip': - continue - - if p.annotation is Parameter.empty: - annotation = typing.cast(Any, Any) - else: - annotation = type_hints[name] - - parameter_mode = mode_lookup.get(p.kind) - if parameter_mode is None: - assert p.kind == Parameter.VAR_KEYWORD, p.kind - - unpack_type = _typing_extra.unpack_type(annotation) - if unpack_type is not None: - origin = get_origin(unpack_type) or unpack_type - if not is_typeddict(origin): - raise PydanticUserError( - f'Expected a `TypedDict` class inside `Unpack[...]`, got {unpack_type!r}', - code='unpack-typed-dict', - ) - non_pos_only_param_names = { - name for name, p in sig.parameters.items() if p.kind != Parameter.POSITIONAL_ONLY - } - overlapping_params = non_pos_only_param_names.intersection(origin.__annotations__) - if overlapping_params: - raise PydanticUserError( - f'Typed dictionary {origin.__name__!r} overlaps with parameter' - f'{"s" if len(overlapping_params) >= 2 else ""} ' - f'{", ".join(repr(p) for p in sorted(overlapping_params))}', - code='overlapping-unpack-typed-dict', - ) - parameter_mode = 'var_kwargs_unpacked_typed_dict' - annotation = unpack_type - else: - parameter_mode = 'var_kwargs_uniform' - - parameters_list.append( - self._generate_parameter_v3_schema( - name, annotation, AnnotationSource.FUNCTION, parameter_mode, default=p.default - ) - ) - - return core_schema.arguments_v3_schema( - parameters_list, - validate_by_name=self._config_wrapper.validate_by_name, + return core_schema.call_schema( + core_schema.arguments_schema( + arguments_list, + var_args_schema=var_args_schema, + var_kwargs_schema=var_kwargs_schema, + populate_by_name=config_wrapper.populate_by_name, + ), + function, + return_schema=return_schema, ) def _unsubstituted_typevar_schema(self, typevar: typing.TypeVar) -> core_schema.CoreSchema: - try: - has_default = typevar.has_default() - except AttributeError: - # Happens if using `typing.TypeVar` (and not `typing_extensions`) on Python < 3.13 - pass - else: - if has_default: - return self.generate_schema(typevar.__default__) + assert isinstance(typevar, typing.TypeVar) - if constraints := typevar.__constraints__: - return self._union_schema(typing.Union[constraints]) + bound = typevar.__bound__ + 
constraints = typevar.__constraints__ + not_set = object() + default = getattr(typevar, '__default__', not_set) - if bound := typevar.__bound__: + if (bound is not None) + (len(constraints) != 0) + (default is not not_set) > 1: + raise NotImplementedError( + 'Pydantic does not support mixing more than one of TypeVar bounds, constraints and defaults' + ) + + if default is not not_set: + return self.generate_schema(default) + elif constraints: + return self._union_schema(typing.Union[constraints]) # type: ignore + elif bound: schema = self.generate_schema(bound) schema['serialization'] = core_schema.wrap_serializer_function_ser_schema( - lambda x, h: h(x), - schema=core_schema.any_schema(), + lambda x, h: h(x), schema=core_schema.any_schema() ) return schema - - return core_schema.any_schema() + else: + return core_schema.any_schema() def _computed_field_schema( self, d: Decorator[ComputedFieldInfo], field_serializers: dict[str, Decorator[FieldSerializerDecoratorInfo]], ) -> core_schema.ComputedField: - if d.info.return_type is not PydanticUndefined: - return_type = d.info.return_type - else: - try: - # Do not pass in globals as the function could be defined in a different module. - # Instead, let `get_callable_return_type` infer the globals to use, but still pass - # in locals that may contain a parent/rebuild namespace: - return_type = _decorators.get_callable_return_type(d.func, localns=self._types_namespace.locals) - except NameError as e: - raise PydanticUndefinedAnnotation.from_name_error(e) from e + try: + return_type = _decorators.get_function_return_type(d.func, d.info.return_type, self._types_namespace) + except NameError as e: + raise PydanticUndefinedAnnotation.from_name_error(e) from e if return_type is PydanticUndefined: raise PydanticUserError( 'Computed field is missing return type annotation or specifying `return_type`' - ' to the `@computed_field` decorator (e.g. `@computed_field(return_type=int | str)`)', + ' to the `@computed_field` decorator (e.g. `@computed_field(return_type=int|str)`)', code='model-field-missing-annotation', ) return_type = replace_types(return_type, self._typevars_map) - # Create a new ComputedFieldInfo so that different type parametrizations of the same - # generic model's computed field can have different return types. 
- d.info = dataclasses.replace(d.info, return_type=return_type) return_type_schema = self.generate_schema(return_type) # Apply serializers to computed field if there exist return_type_schema = self._apply_field_serializers( return_type_schema, filter_field_decorator_info_by_field(field_serializers.values(), d.cls_var_name), + computed_field=True, ) - + # Handle alias_generator using similar logic to that from + # pydantic._internal._generate_schema.GenerateSchema._common_field_schema, + # with field_info -> d.info and name -> d.cls_var_name alias_generator = self._config_wrapper.alias_generator - if alias_generator is not None: - self._apply_alias_generator_to_computed_field_info( - alias_generator=alias_generator, computed_field_info=d.info, computed_field_name=d.cls_var_name - ) - self._apply_field_title_generator_to_field_info(self._config_wrapper, d.info, d.cls_var_name) + if alias_generator and (d.info.alias_priority is None or d.info.alias_priority <= 1): + alias = alias_generator(d.cls_var_name) + if not isinstance(alias, str): + raise TypeError(f'alias_generator {alias_generator} must return str, not {alias.__class__}') + d.info.alias = alias + d.info.alias_priority = 1 - pydantic_js_updates, pydantic_js_extra = _extract_json_schema_info_from_field_info(d.info) - core_metadata: dict[str, Any] = {} - update_core_metadata( - core_metadata, - pydantic_js_updates={'readOnly': True, **(pydantic_js_updates if pydantic_js_updates else {})}, - pydantic_js_extra=pydantic_js_extra, - ) + def set_computed_field_metadata(schema: CoreSchemaOrField, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + json_schema = handler(schema) + + json_schema['readOnly'] = True + + title = d.info.title + if title is not None: + json_schema['title'] = title + + description = d.info.description + if description is not None: + json_schema['description'] = description + + return json_schema + + metadata = build_metadata_dict(js_annotation_functions=[set_computed_field_metadata]) return core_schema.computed_field( - d.cls_var_name, return_schema=return_type_schema, alias=d.info.alias, metadata=core_metadata + d.cls_var_name, return_schema=return_type_schema, alias=d.info.alias, metadata=metadata ) def _annotated_schema(self, annotated_type: Any) -> core_schema.CoreSchema: """Generate schema for an Annotated type, e.g. `Annotated[int, Field(...)]` or `Annotated[int, Gt(0)]`.""" - FieldInfo = import_cached_field_info() source_type, *annotations = self._get_args_resolving_forward_refs( annotated_type, required=True, @@ -2256,6 +1529,25 @@ class GenerateSchema: schema = wrap_default(annotation, schema) return schema + def _get_prepare_pydantic_annotations_for_known_type( + self, obj: Any, annotations: tuple[Any, ...] + ) -> tuple[Any, list[Any]] | None: + from ._std_types_schema import PREPARE_METHODS + + # This check for hashability is only necessary for python 3.7 + try: + hash(obj) + except TypeError: + # obj is definitely not a known type if this fails + return None + + for gen in PREPARE_METHODS: + res = gen(obj, annotations, self._config_wrapper.config_dict) + if res is not None: + return res + + return None + def _apply_annotations( self, source_type: Any, @@ -2269,18 +1561,21 @@ class GenerateSchema: (in other words, `GenerateSchema._annotated_schema` just unpacks `Annotated`, this process it). 
""" annotations = list(_known_annotated_metadata.expand_grouped_metadata(annotations)) + res = self._get_prepare_pydantic_annotations_for_known_type(source_type, tuple(annotations)) + if res is not None: + source_type, annotations = res pydantic_js_annotation_functions: list[GetJsonSchemaFunction] = [] def inner_handler(obj: Any) -> CoreSchema: - schema = self._generate_schema_from_get_schema_method(obj, source_type) - - if schema is None: - schema = self._generate_schema_inner(obj) - - metadata_js_function = _extract_get_pydantic_json_schema(obj) + from_property = self._generate_schema_from_property(obj, obj) + if from_property is None: + schema = self._generate_schema(obj) + else: + schema = from_property + metadata_js_function = _extract_get_pydantic_json_schema(obj, schema) if metadata_js_function is not None: - metadata_schema = resolve_original_schema(schema, self.defs) + metadata_schema = resolve_original_schema(schema, self.defs.definitions) if metadata_schema is not None: self._add_js_function(metadata_schema, metadata_js_function) return transform_inner_schema(schema) @@ -2296,13 +1591,11 @@ class GenerateSchema: schema = get_inner_schema(source_type) if pydantic_js_annotation_functions: - core_metadata = schema.setdefault('metadata', {}) - update_core_metadata(core_metadata, pydantic_js_annotation_functions=pydantic_js_annotation_functions) + metadata = CoreMetadataHandler(schema).metadata + metadata.setdefault('pydantic_js_annotation_functions', []).extend(pydantic_js_annotation_functions) return _add_custom_serialization_from_json_encoders(self._config_wrapper.json_encoders, source_type, schema) def _apply_single_annotation(self, schema: core_schema.CoreSchema, metadata: Any) -> core_schema.CoreSchema: - FieldInfo = import_cached_field_info() - if isinstance(metadata, FieldInfo): for field_metadata in metadata.metadata: schema = self._apply_single_annotation(schema, field_metadata) @@ -2320,23 +1613,23 @@ class GenerateSchema: return schema original_schema = schema - ref = schema.get('ref') + ref = schema.get('ref', None) if ref is not None: schema = schema.copy() new_ref = ref + f'_{repr(metadata)}' - if (existing := self.defs.get_schema_from_ref(new_ref)) is not None: - return existing - schema['ref'] = new_ref # pyright: ignore[reportGeneralTypeIssues] + if new_ref in self.defs.definitions: + return self.defs.definitions[new_ref] + schema['ref'] = new_ref # type: ignore elif schema['type'] == 'definition-ref': ref = schema['schema_ref'] - if (referenced_schema := self.defs.get_schema_from_ref(ref)) is not None: - schema = referenced_schema.copy() + if ref in self.defs.definitions: + schema = self.defs.definitions[ref].copy() new_ref = ref + f'_{repr(metadata)}' - if (existing := self.defs.get_schema_from_ref(new_ref)) is not None: - return existing - schema['ref'] = new_ref # pyright: ignore[reportGeneralTypeIssues] + if new_ref in self.defs.definitions: + return self.defs.definitions[new_ref] + schema['ref'] = new_ref # type: ignore - maybe_updated_schema = _known_annotated_metadata.apply_known_metadata(metadata, schema) + maybe_updated_schema = _known_annotated_metadata.apply_known_metadata(metadata, schema.copy()) if maybe_updated_schema is not None: return maybe_updated_schema @@ -2345,17 +1638,34 @@ class GenerateSchema: def _apply_single_annotation_json_schema( self, schema: core_schema.CoreSchema, metadata: Any ) -> core_schema.CoreSchema: - FieldInfo = import_cached_field_info() - if isinstance(metadata, FieldInfo): for field_metadata in metadata.metadata: schema = 
self._apply_single_annotation_json_schema(schema, field_metadata) + json_schema_update: JsonSchemaValue = {} + if metadata.title: + json_schema_update['title'] = metadata.title + if metadata.description: + json_schema_update['description'] = metadata.description + if metadata.examples: + json_schema_update['examples'] = to_jsonable_python(metadata.examples) - pydantic_js_updates, pydantic_js_extra = _extract_json_schema_info_from_field_info(metadata) - core_metadata = schema.setdefault('metadata', {}) - update_core_metadata( - core_metadata, pydantic_js_updates=pydantic_js_updates, pydantic_js_extra=pydantic_js_extra - ) + json_schema_extra = metadata.json_schema_extra + if json_schema_update or json_schema_extra: + + def json_schema_update_func( + core_schema: CoreSchemaOrField, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + json_schema = handler(core_schema) + json_schema.update(json_schema_update) + if isinstance(json_schema_extra, dict): + json_schema.update(to_jsonable_python(json_schema_extra)) + elif callable(json_schema_extra): + json_schema_extra(json_schema) + return json_schema + + CoreMetadataHandler(schema).metadata.setdefault('pydantic_js_annotation_functions', []).append( + json_schema_update_func + ) return schema def _get_wrapped_inner_schema( @@ -2364,17 +1674,16 @@ class GenerateSchema: annotation: Any, pydantic_js_annotation_functions: list[GetJsonSchemaFunction], ) -> CallbackGetCoreSchemaHandler: - annotation_get_schema: GetCoreSchemaFunction | None = getattr(annotation, '__get_pydantic_core_schema__', None) + metadata_get_schema: GetCoreSchemaFunction = getattr(annotation, '__get_pydantic_core_schema__', None) or ( + lambda source, handler: handler(source) + ) def new_handler(source: Any) -> core_schema.CoreSchema: - if annotation_get_schema is not None: - schema = annotation_get_schema(source, get_inner_schema) - else: - schema = get_inner_schema(source) - schema = self._apply_single_annotation(schema, annotation) - schema = self._apply_single_annotation_json_schema(schema, annotation) + schema = metadata_get_schema(source, get_inner_schema) + schema = self._apply_single_annotation(schema, annotation) + schema = self._apply_single_annotation_json_schema(schema, annotation) - metadata_js_function = _extract_get_pydantic_json_schema(annotation) + metadata_js_function = _extract_get_pydantic_json_schema(annotation, schema) if metadata_js_function is not None: pydantic_js_annotation_functions.append(metadata_js_function) return schema @@ -2385,6 +1694,7 @@ class GenerateSchema: self, schema: core_schema.CoreSchema, serializers: list[Decorator[FieldSerializerDecoratorInfo]], + computed_field: bool = False, ) -> core_schema.CoreSchema: """Apply field serializers to a schema.""" if serializers: @@ -2393,25 +1703,23 @@ class GenerateSchema: inner_schema = schema['schema'] schema['schema'] = self._apply_field_serializers(inner_schema, serializers) return schema - elif 'ref' in schema: - schema = self.defs.create_definition_reference_schema(schema) + else: + ref = typing.cast('str|None', schema.get('ref', None)) + if ref is not None: + schema = core_schema.definition_reference_schema(ref) # use the last serializer to make it easy to override a serializer set on a parent model serializer = serializers[-1] - is_field_serializer, info_arg = inspect_field_serializer(serializer.func, serializer.info.mode) + is_field_serializer, info_arg = inspect_field_serializer( + serializer.func, serializer.info.mode, computed_field=computed_field + ) - if serializer.info.return_type 
is not PydanticUndefined: - return_type = serializer.info.return_type - else: - try: - # Do not pass in globals as the function could be defined in a different module. - # Instead, let `get_callable_return_type` infer the globals to use, but still pass - # in locals that may contain a parent/rebuild namespace: - return_type = _decorators.get_callable_return_type( - serializer.func, localns=self._types_namespace.locals - ) - except NameError as e: - raise PydanticUndefinedAnnotation.from_name_error(e) from e + try: + return_type = _decorators.get_function_return_type( + serializer.func, serializer.info.return_type, self._types_namespace + ) + except NameError as e: + raise PydanticUndefinedAnnotation.from_name_error(e) from e if return_type is PydanticUndefined: return_schema = None @@ -2446,19 +1754,12 @@ class GenerateSchema: serializer = list(serializers)[-1] info_arg = inspect_model_serializer(serializer.func, serializer.info.mode) - if serializer.info.return_type is not PydanticUndefined: - return_type = serializer.info.return_type - else: - try: - # Do not pass in globals as the function could be defined in a different module. - # Instead, let `get_callable_return_type` infer the globals to use, but still pass - # in locals that may contain a parent/rebuild namespace: - return_type = _decorators.get_callable_return_type( - serializer.func, localns=self._types_namespace.locals - ) - except NameError as e: - raise PydanticUndefinedAnnotation.from_name_error(e) from e - + try: + return_type = _decorators.get_function_return_type( + serializer.func, serializer.info.return_type, self._types_namespace + ) + except NameError as e: + raise PydanticUndefinedAnnotation.from_name_error(e) from e if return_type is PydanticUndefined: return_schema = None else: @@ -2508,8 +1809,6 @@ _VALIDATOR_F_MATCH: Mapping[ } -# TODO V3: this function is only used for deprecated decorators. It should -# be removed once we drop support for those. 
def apply_validators( schema: core_schema.CoreSchema, validators: Iterable[Decorator[RootValidatorDecoratorInfo]] @@ -2610,10 +1909,7 @@ def wrap_default(field_info: FieldInfo, schema: core_schema.CoreSchema) -> core_ """ if field_info.default_factory: return core_schema.with_default_schema( - schema, - default_factory=field_info.default_factory, - default_factory_takes_data=takes_validated_data_argument(field_info.default_factory), - validate_default=field_info.validate_default, + schema, default_factory=field_info.default_factory, validate_default=field_info.validate_default ) elif field_info.default is not PydanticUndefined: return core_schema.with_default_schema( @@ -2623,31 +1919,29 @@ def wrap_default(field_info: FieldInfo, schema: core_schema.CoreSchema) -> core_ return schema -def _extract_get_pydantic_json_schema(tp: Any) -> GetJsonSchemaFunction | None: +def _extract_get_pydantic_json_schema(tp: Any, schema: CoreSchema) -> GetJsonSchemaFunction | None: """Extract `__get_pydantic_json_schema__` from a type, handling the deprecated `__modify_schema__`.""" js_modify_function = getattr(tp, '__get_pydantic_json_schema__', None) if hasattr(tp, '__modify_schema__'): - BaseModel = import_cached_base_model() + from pydantic import BaseModel # circular reference has_custom_v2_modify_js_func = ( js_modify_function is not None - and BaseModel.__get_pydantic_json_schema__.__func__ # type: ignore + and BaseModel.__get_pydantic_json_schema__.__func__ not in (js_modify_function, getattr(js_modify_function, '__func__', None)) ) if not has_custom_v2_modify_js_func: - cls_name = getattr(tp, '__name__', None) raise PydanticUserError( - f'The `__modify_schema__` method is not supported in Pydantic v2. ' - f'Use `__get_pydantic_json_schema__` instead{f" in class `{cls_name}`" if cls_name else ""}.', + 'The `__modify_schema__` method is not supported in Pydantic v2. ' + 'Use `__get_pydantic_json_schema__` instead.', code='custom-json-schema', ) - if (origin := get_origin(tp)) is not None: - # Generic aliases proxy attribute access to the origin, *except* dunder attributes, - # such as `__get_pydantic_json_schema__`, hence the explicit check. - return _extract_get_pydantic_json_schema(origin) + # handle GenericAlias' but ignore Annotated which "lies" about its origin (in this case it would be `int`) + if hasattr(tp, '__origin__') and not isinstance(tp, type(Annotated[int, 'placeholder'])): + return _extract_get_pydantic_json_schema(tp.__origin__, schema) if js_modify_function is None: return None @@ -2683,62 +1977,15 @@ def _common_field( } -def resolve_original_schema(schema: CoreSchema, definitions: _Definitions) -> CoreSchema | None: - if schema['type'] == 'definition-ref': - return definitions.get_schema_from_ref(schema['schema_ref']) - elif schema['type'] == 'definitions': - return schema['schema'] - else: - return schema - - -def _inlining_behavior( - def_ref: core_schema.DefinitionReferenceSchema, -) -> Literal['inline', 'keep', 'preserve_metadata']: - """Determine the inlining behavior of the `'definition-ref'` schema. - - - If no `'serialization'` schema and no metadata is attached, the schema can safely be inlined. - - If it has metadata but only related to the deferred discriminator application, it can be inlined - provided that such metadata is kept. - - Otherwise, the schema should not be inlined. Doing so would remove the `'serialization'` schema or metadata. 
- """ - if 'serialization' in def_ref: - return 'keep' - metadata = def_ref.get('metadata') - if not metadata: - return 'inline' - if len(metadata) == 1 and 'pydantic_internal_union_discriminator' in metadata: - return 'preserve_metadata' - return 'keep' - - class _Definitions: """Keeps track of references and definitions.""" - _recursively_seen: set[str] - """A set of recursively seen references. - - When a referenceable type is encountered, the `get_schema_or_ref` context manager is - entered to compute the reference. If the type references itself by some way (e.g. for - a dataclass a Pydantic model, the class can be referenced as a field annotation), - entering the context manager again will yield a `'definition-ref'` schema that should - short-circuit the normal generation process, as the reference was already in this set. - """ - - _definitions: dict[str, core_schema.CoreSchema] - """A mapping of references to their corresponding schema. - - When a schema for a referenceable type is generated, it is stored in this mapping. If the - same type is encountered again, the reference is yielded by the `get_schema_or_ref` context - manager. - """ - def __init__(self) -> None: - self._recursively_seen = set() - self._definitions = {} + self.seen: set[str] = set() + self.definitions: dict[str, core_schema.CoreSchema] = {} @contextmanager - def get_schema_or_ref(self, tp: Any, /) -> Generator[tuple[str, core_schema.DefinitionReferenceSchema | None]]: + def get_schema_or_ref(self, tp: Any) -> Iterator[tuple[str, None] | tuple[str, CoreSchema]]: """Get a definition for `tp` if one exists. If a definition exists, a tuple of `(ref_string, CoreSchema)` is returned. @@ -2752,119 +1999,31 @@ class _Definitions: At present the following types can be named/recursive: - - Pydantic model - - Pydantic and stdlib dataclasses - - Typed dictionaries - - Named tuples - - `TypeAliasType` instances - - Enums + - BaseModel + - Dataclasses + - TypedDict + - TypeAliasType """ ref = get_type_ref(tp) - # return the reference if we're either (1) in a cycle or (2) it the reference was already encountered: - if ref in self._recursively_seen or ref in self._definitions: + # return the reference if we're either (1) in a cycle or (2) it was already defined + if ref in self.seen or ref in self.definitions: yield (ref, core_schema.definition_reference_schema(ref)) else: - self._recursively_seen.add(ref) + self.seen.add(ref) try: yield (ref, None) finally: - self._recursively_seen.discard(ref) + self.seen.discard(ref) - def get_schema_from_ref(self, ref: str) -> CoreSchema | None: - """Resolve the schema from the given reference.""" - return self._definitions.get(ref) - def create_definition_reference_schema(self, schema: CoreSchema) -> core_schema.DefinitionReferenceSchema: - """Store the schema as a definition and return a `'definition-reference'` schema pointing to it. - - The schema must have a reference attached to it. 
- """ - ref = schema['ref'] # pyright: ignore - self._definitions[ref] = schema - return core_schema.definition_reference_schema(ref) - - def unpack_definitions(self, schema: core_schema.DefinitionsSchema) -> CoreSchema: - """Store the definitions of the `'definitions'` core schema and return the inner core schema.""" - for def_schema in schema['definitions']: - self._definitions[def_schema['ref']] = def_schema # pyright: ignore +def resolve_original_schema(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> CoreSchema | None: + if schema['type'] == 'definition-ref': + return definitions.get(schema['schema_ref'], None) + elif schema['type'] == 'definitions': return schema['schema'] - - def finalize_schema(self, schema: CoreSchema) -> CoreSchema: - """Finalize the core schema. - - This traverses the core schema and referenced definitions, replaces `'definition-ref'` schemas - by the referenced definition if possible, and applies deferred discriminators. - """ - definitions = self._definitions - try: - gather_result = gather_schemas_for_cleaning( - schema, - definitions=definitions, - ) - except MissingDefinitionError as e: - raise InvalidSchemaError from e - - remaining_defs: dict[str, CoreSchema] = {} - - # Note: this logic doesn't play well when core schemas with deferred discriminator metadata - # and references are encountered. See the `test_deferred_discriminated_union_and_references()` test. - for ref, inlinable_def_ref in gather_result['collected_references'].items(): - if inlinable_def_ref is not None and (inlining_behavior := _inlining_behavior(inlinable_def_ref)) != 'keep': - if inlining_behavior == 'inline': - # `ref` was encountered, and only once: - # - `inlinable_def_ref` is a `'definition-ref'` schema and is guaranteed to be - # the only one. Transform it into the definition it points to. - # - Do not store the definition in the `remaining_defs`. - inlinable_def_ref.clear() # pyright: ignore[reportAttributeAccessIssue] - inlinable_def_ref.update(self._resolve_definition(ref, definitions)) # pyright: ignore - elif inlining_behavior == 'preserve_metadata': - # `ref` was encountered, and only once, but contains discriminator metadata. - # We will do the same thing as if `inlining_behavior` was `'inline'`, but make - # sure to keep the metadata for the deferred discriminator application logic below. - meta = inlinable_def_ref.pop('metadata') - inlinable_def_ref.clear() # pyright: ignore[reportAttributeAccessIssue] - inlinable_def_ref.update(self._resolve_definition(ref, definitions)) # pyright: ignore - inlinable_def_ref['metadata'] = meta - else: - # `ref` was encountered, at least two times (or only once, but with metadata or a serialization schema): - # - Do not inline the `'definition-ref'` schemas (they are not provided in the gather result anyway). - # - Store the the definition in the `remaining_defs` - remaining_defs[ref] = self._resolve_definition(ref, definitions) - - for cs in gather_result['deferred_discriminator_schemas']: - discriminator: str | None = cs['metadata'].pop('pydantic_internal_union_discriminator', None) # pyright: ignore[reportTypedDictNotRequiredAccess] - if discriminator is None: - # This can happen in rare scenarios, when a deferred schema is present multiple times in the - # gather result (e.g. when using the `Sequence` type -- see `test_sequence_discriminated_union()`). - # In this case, a previous loop iteration applied the discriminator and so we can just skip it here. 
- continue - applied = _discriminated_union.apply_discriminator(cs.copy(), discriminator, remaining_defs) - # Mutate the schema directly to have the discriminator applied - cs.clear() # pyright: ignore[reportAttributeAccessIssue] - cs.update(applied) # pyright: ignore - - if remaining_defs: - schema = core_schema.definitions_schema(schema=schema, definitions=[*remaining_defs.values()]) + else: return schema - def _resolve_definition(self, ref: str, definitions: dict[str, CoreSchema]) -> CoreSchema: - definition = definitions[ref] - if definition['type'] != 'definition-ref': - return definition - - # Some `'definition-ref'` schemas might act as "intermediate" references (e.g. when using - # a PEP 695 type alias (which is referenceable) that references another PEP 695 type alias): - visited: set[str] = set() - while definition['type'] == 'definition-ref' and _inlining_behavior(definition) == 'inline': - schema_ref = definition['schema_ref'] - if schema_ref in visited: - raise PydanticUserError( - f'{ref} contains a circular reference to itself.', code='circular-reference-schema' - ) - visited.add(schema_ref) - definition = definitions[schema_ref] - return {**definition, 'ref': ref} # pyright: ignore[reportReturnType] - class _FieldNameStack: __slots__ = ('_stack',) @@ -2883,22 +2042,3 @@ class _FieldNameStack: return self._stack[-1] else: return None - - -class _ModelTypeStack: - __slots__ = ('_stack',) - - def __init__(self) -> None: - self._stack: list[type] = [] - - @contextmanager - def push(self, type_obj: type) -> Iterator[None]: - self._stack.append(type_obj) - yield - self._stack.pop() - - def get(self) -> type | None: - if self._stack: - return self._stack[-1] - else: - return None diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_generics.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_generics.py index 8013676..7c3d5f4 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_generics.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_generics.py @@ -4,21 +4,17 @@ import sys import types import typing from collections import ChainMap -from collections.abc import Iterator, Mapping from contextlib import contextmanager from contextvars import ContextVar -from itertools import zip_longest from types import prepare_class -from typing import TYPE_CHECKING, Annotated, Any, TypeVar +from typing import TYPE_CHECKING, Any, Iterator, List, Mapping, MutableMapping, Tuple, TypeVar from weakref import WeakValueDictionary import typing_extensions -from typing_inspection import typing_objects -from typing_inspection.introspection import is_union_origin -from . import _typing_extra from ._core_utils import get_type_ref from ._forward_ref import PydanticRecursiveRef +from ._typing_extra import TypeVarType, typing_base from ._utils import all_identical, is_model_class if sys.version_info >= (3, 10): @@ -27,7 +23,7 @@ if sys.version_info >= (3, 10): if TYPE_CHECKING: from ..main import BaseModel -GenericTypesCacheKey = tuple[Any, Any, tuple[Any, ...]] +GenericTypesCacheKey = Tuple[Any, Any, Tuple[Any, ...]] # Note: We want to remove LimitedDict, but to do this, we'd need to improve the handling of generics caching. # Right now, to handle recursive generics, we some types must remain cached for brief periods without references. 
@@ -38,25 +34,43 @@ GenericTypesCacheKey = tuple[Any, Any, tuple[Any, ...]] KT = TypeVar('KT') VT = TypeVar('VT') _LIMITED_DICT_SIZE = 100 +if TYPE_CHECKING: + class LimitedDict(dict, MutableMapping[KT, VT]): + def __init__(self, size_limit: int = _LIMITED_DICT_SIZE): + ... -class LimitedDict(dict[KT, VT]): - def __init__(self, size_limit: int = _LIMITED_DICT_SIZE) -> None: - self.size_limit = size_limit - super().__init__() +else: - def __setitem__(self, key: KT, value: VT, /) -> None: - super().__setitem__(key, value) - if len(self) > self.size_limit: - excess = len(self) - self.size_limit + self.size_limit // 10 - to_remove = list(self.keys())[:excess] - for k in to_remove: - del self[k] + class LimitedDict(dict): + """Limit the size/length of a dict used for caching to avoid unlimited increase in memory usage. + + Since the dict is ordered, and we always remove elements from the beginning, this is effectively a FIFO cache. + """ + + def __init__(self, size_limit: int = _LIMITED_DICT_SIZE): + self.size_limit = size_limit + super().__init__() + + def __setitem__(self, __key: Any, __value: Any) -> None: + super().__setitem__(__key, __value) + if len(self) > self.size_limit: + excess = len(self) - self.size_limit + self.size_limit // 10 + to_remove = list(self.keys())[:excess] + for key in to_remove: + del self[key] + + def __class_getitem__(cls, *args: Any) -> Any: + # to avoid errors with 3.7 + return cls # weak dictionaries allow the dynamically created parametrized versions of generic models to get collected # once they are no longer referenced by the caller. -GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]'] +if sys.version_info >= (3, 9): # Typing for weak dictionaries available at 3.9 + GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]'] +else: + GenericTypesCache = WeakValueDictionary if TYPE_CHECKING: @@ -94,13 +108,13 @@ else: # and discover later on that we need to re-add all this infrastructure... # _GENERIC_TYPES_CACHE = DeepChainMap(GenericTypesCache(), LimitedDict()) -_GENERIC_TYPES_CACHE: ContextVar[GenericTypesCache | None] = ContextVar('_GENERIC_TYPES_CACHE', default=None) +_GENERIC_TYPES_CACHE = GenericTypesCache() class PydanticGenericMetadata(typing_extensions.TypedDict): origin: type[BaseModel] | None # analogous to typing._GenericAlias.__origin__ args: tuple[Any, ...] # analogous to typing._GenericAlias.__args__ - parameters: tuple[TypeVar, ...] # analogous to typing.Generic.__parameters__ + parameters: tuple[type[Any], ...] # analogous to typing.Generic.__parameters__ def create_generic_submodel( @@ -157,7 +171,7 @@ def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]: depth: The depth to get the frame. Returns: - A tuple contains `module_name` and `called_globally`. + A tuple contains `module_nam` and `called_globally`. Raises: RuntimeError: If the function is not called inside a function. @@ -175,7 +189,7 @@ def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]: DictValues: type[Any] = {}.values().__class__ -def iter_contained_typevars(v: Any) -> Iterator[TypeVar]: +def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]: """Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found. 
This is inspired as an alternative to directly accessing the `__parameters__` attribute of a GenericAlias, @@ -208,7 +222,7 @@ def get_origin(v: Any) -> Any: return typing_extensions.get_origin(v) -def get_standard_typevars_map(cls: Any) -> dict[TypeVar, Any] | None: +def get_standard_typevars_map(cls: type[Any]) -> dict[TypeVarType, Any] | None: """Package a generic type's typevars and parametrization (if present) into a dictionary compatible with the `replace_types` function. Specifically, this works with standard typing generics and typing._GenericAlias. """ @@ -221,11 +235,11 @@ def get_standard_typevars_map(cls: Any) -> dict[TypeVar, Any] | None: # In this case, we know that cls is a _GenericAlias, and origin is the generic type # So it is safe to access cls.__args__ and origin.__parameters__ args: tuple[Any, ...] = cls.__args__ # type: ignore - parameters: tuple[TypeVar, ...] = origin.__parameters__ + parameters: tuple[TypeVarType, ...] = origin.__parameters__ return dict(zip(parameters, args)) -def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVar, Any]: +def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVarType, Any] | None: """Package a generic BaseModel's typevars and concrete parametrization (if present) into a dictionary compatible with the `replace_types` function. @@ -237,13 +251,10 @@ def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVar, Any]: generic_metadata = cls.__pydantic_generic_metadata__ origin = generic_metadata['origin'] args = generic_metadata['args'] - if not args: - # No need to go into `iter_contained_typevars`: - return {} return dict(zip(iter_contained_typevars(origin), args)) -def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any: +def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any: """Return type with all occurrences of `type_map` keys recursively replaced with their values. Args: @@ -255,13 +266,13 @@ def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any: `typevar_map` keys recursively replaced. Example: - ```python - from typing import List, Union + ```py + from typing import List, Tuple, Union from pydantic._internal._generics import replace_types - replace_types(tuple[str, Union[List[str], float]], {str: int}) - #> tuple[int, Union[List[int], float]] + replace_types(Tuple[str, Union[List[str], float]], {str: int}) + #> Tuple[int, Union[List[int], float]] ``` """ if not type_map: @@ -270,25 +281,25 @@ def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any: type_args = get_args(type_) origin_type = get_origin(type_) - if typing_objects.is_annotated(origin_type): + if origin_type is typing_extensions.Annotated: annotated_type, *annotations = type_args - annotated_type = replace_types(annotated_type, type_map) - # TODO remove parentheses when we drop support for Python 3.10: - return Annotated[(annotated_type, *annotations)] + annotated = replace_types(annotated_type, type_map) + for annotation in annotations: + annotated = typing_extensions.Annotated[annotated, annotation] + return annotated - # Having type args is a good indicator that this is a typing special form - # instance or a generic alias of some sort. + # Having type args is a good indicator that this is a typing module + # class instantiation or a generic alias of some sort. 
if type_args: resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args) if all_identical(type_args, resolved_type_args): # If all arguments are the same, there is no need to modify the # type or create a new object at all return type_ - if ( origin_type is not None - and isinstance(type_, _typing_extra.typing_base) - and not isinstance(origin_type, _typing_extra.typing_base) + and isinstance(type_, typing_base) + and not isinstance(origin_type, typing_base) and getattr(type_, '_name', None) is not None ): # In python < 3.9 generic aliases don't exist so any of these like `list`, @@ -296,24 +307,11 @@ def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any: # See: https://www.python.org/dev/peps/pep-0585 origin_type = getattr(typing, type_._name) assert origin_type is not None - - if is_union_origin(origin_type): - if any(typing_objects.is_any(arg) for arg in resolved_type_args): - # `Any | T` ~ `Any`: - resolved_type_args = (Any,) - # `Never | T` ~ `T`: - resolved_type_args = tuple( - arg - for arg in resolved_type_args - if not (typing_objects.is_noreturn(arg) or typing_objects.is_never(arg)) - ) - # PEP-604 syntax (Ex.: list | str) is represented with a types.UnionType object that does not have __getitem__. # We also cannot use isinstance() since we have to compare types. if sys.version_info >= (3, 10) and origin_type is types.UnionType: return _UnionGenericAlias(origin_type, resolved_type_args) - # NotRequired[T] and Required[T] don't support tuple type resolved_type_args, hence the condition below - return origin_type[resolved_type_args[0] if len(resolved_type_args) == 1 else resolved_type_args] + return origin_type[resolved_type_args] # We handle pydantic generic models separately as they don't have the same # semantics as "typing" classes or generic aliases @@ -329,8 +327,8 @@ def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any: # Handle special case for typehints that can have lists as arguments. # `typing.Callable[[int, str], int]` is an example for this. - if isinstance(type_, list): - resolved_list = [replace_types(element, type_map) for element in type_] + if isinstance(type_, (List, list)): + resolved_list = list(replace_types(element, type_map) for element in type_) if all_identical(type_, resolved_list): return type_ return resolved_list @@ -340,57 +338,49 @@ def replace_types(type_: Any, type_map: Mapping[TypeVar, Any] | None) -> Any: return type_map.get(type_, type_) -def map_generic_model_arguments(cls: type[BaseModel], args: tuple[Any, ...]) -> dict[TypeVar, Any]: - """Return a mapping between the parameters of a generic model and the provided arguments during parameterization. +def has_instance_in_type(type_: Any, isinstance_target: Any) -> bool: + """Checks if the type, or any of its arbitrary nested args, satisfy + `isinstance(, isinstance_target)`. + """ + if isinstance(type_, isinstance_target): + return True + + type_args = get_args(type_) + origin_type = get_origin(type_) + + if origin_type is typing_extensions.Annotated: + annotated_type, *annotations = type_args + return has_instance_in_type(annotated_type, isinstance_target) + + # Having type args is a good indicator that this is a typing module + # class instantiation or a generic alias of some sort. + if any(has_instance_in_type(a, isinstance_target) for a in type_args): + return True + + # Handle special case for typehints that can have lists as arguments. + # `typing.Callable[[int, str], int]` is an example for this. 
+ if isinstance(type_, (List, list)) and not isinstance(type_, typing_extensions.ParamSpec): + if any(has_instance_in_type(element, isinstance_target) for element in type_): + return True + + return False + + +def check_parameters_count(cls: type[BaseModel], parameters: tuple[Any, ...]) -> None: + """Check the generic model parameters count is equal. + + Args: + cls: The generic model. + parameters: A tuple of passed parameters to the generic model. Raises: - TypeError: If the number of arguments does not match the parameters (i.e. if providing too few or too many arguments). - - Example: - ```python {test="skip" lint="skip"} - class Model[T, U, V = int](BaseModel): ... - - map_generic_model_arguments(Model, (str, bytes)) - #> {T: str, U: bytes, V: int} - - map_generic_model_arguments(Model, (str,)) - #> TypeError: Too few arguments for ; actual 1, expected at least 2 - - map_generic_model_arguments(Model, (str, bytes, int, complex)) - #> TypeError: Too many arguments for ; actual 4, expected 3 - ``` - - Note: - This function is analogous to the private `typing._check_generic_specialization` function. + TypeError: If the passed parameters count is not equal to generic model parameters count. """ - parameters = cls.__pydantic_generic_metadata__['parameters'] - expected_len = len(parameters) - typevars_map: dict[TypeVar, Any] = {} - - _missing = object() - for parameter, argument in zip_longest(parameters, args, fillvalue=_missing): - if parameter is _missing: - raise TypeError(f'Too many arguments for {cls}; actual {len(args)}, expected {expected_len}') - - if argument is _missing: - param = typing.cast(TypeVar, parameter) - try: - has_default = param.has_default() - except AttributeError: - # Happens if using `typing.TypeVar` (and not `typing_extensions`) on Python < 3.13. - has_default = False - if has_default: - # The default might refer to other type parameters. For an example, see: - # https://typing.readthedocs.io/en/latest/spec/generics.html#type-parameters-as-parameters-to-generics - typevars_map[param] = replace_types(param.__default__, typevars_map) - else: - expected_len -= sum(hasattr(p, 'has_default') and p.has_default() for p in parameters) - raise TypeError(f'Too few arguments for {cls}; actual {len(args)}, expected at least {expected_len}') - else: - param = typing.cast(TypeVar, parameter) - typevars_map[param] = argument - - return typevars_map + actual = len(parameters) + expected = len(cls.__pydantic_generic_metadata__['parameters']) + if actual != expected: + description = 'many' if actual > expected else 'few' + raise TypeError(f'Too {description} parameters for {cls}; actual {actual}, expected {expected}') _generic_recursion_cache: ContextVar[set[str] | None] = ContextVar('_generic_recursion_cache', default=None) @@ -421,8 +411,7 @@ def generic_recursion_self_type( yield self_type else: previously_seen_type_refs.add(type_ref) - yield - previously_seen_type_refs.remove(type_ref) + yield None finally: if token: _generic_recursion_cache.reset(token) @@ -453,24 +442,14 @@ def get_cached_generic_type_early(parent: type[BaseModel], typevar_values: Any) during validation, I think it is worthwhile to ensure that types that are functionally equivalent are actually equal. 
""" - generic_types_cache = _GENERIC_TYPES_CACHE.get() - if generic_types_cache is None: - generic_types_cache = GenericTypesCache() - _GENERIC_TYPES_CACHE.set(generic_types_cache) - return generic_types_cache.get(_early_cache_key(parent, typevar_values)) + return _GENERIC_TYPES_CACHE.get(_early_cache_key(parent, typevar_values)) def get_cached_generic_type_late( parent: type[BaseModel], typevar_values: Any, origin: type[BaseModel], args: tuple[Any, ...] ) -> type[BaseModel] | None: """See the docstring of `get_cached_generic_type_early` for more information about the two-stage cache lookup.""" - generic_types_cache = _GENERIC_TYPES_CACHE.get() - if ( - generic_types_cache is None - ): # pragma: no cover (early cache is guaranteed to run first and initialize the cache) - generic_types_cache = GenericTypesCache() - _GENERIC_TYPES_CACHE.set(generic_types_cache) - cached = generic_types_cache.get(_late_cache_key(origin, args, typevar_values)) + cached = _GENERIC_TYPES_CACHE.get(_late_cache_key(origin, args, typevar_values)) if cached is not None: set_cached_generic_type(parent, typevar_values, cached, origin, args) return cached @@ -486,17 +465,11 @@ def set_cached_generic_type( """See the docstring of `get_cached_generic_type_early` for more information about why items are cached with two different keys. """ - generic_types_cache = _GENERIC_TYPES_CACHE.get() - if ( - generic_types_cache is None - ): # pragma: no cover (cache lookup is guaranteed to run first and initialize the cache) - generic_types_cache = GenericTypesCache() - _GENERIC_TYPES_CACHE.set(generic_types_cache) - generic_types_cache[_early_cache_key(parent, typevar_values)] = type_ + _GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values)] = type_ if len(typevar_values) == 1: - generic_types_cache[_early_cache_key(parent, typevar_values[0])] = type_ + _GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values[0])] = type_ if origin and args: - generic_types_cache[_late_cache_key(origin, args, typevar_values)] = type_ + _GENERIC_TYPES_CACHE[_late_cache_key(origin, args, typevar_values)] = type_ def _union_orderings_key(typevar_values: Any) -> Any: @@ -517,7 +490,7 @@ def _union_orderings_key(typevar_values: Any) -> Any: for value in typevar_values: args_data.append(_union_orderings_key(value)) return tuple(args_data) - elif typing_objects.is_union(typing_extensions.get_origin(typevar_values)): + elif typing_extensions.get_origin(typevar_values) is typing.Union: return get_args(typevar_values) else: return () diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_git.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_git.py deleted file mode 100644 index 96dcda2..0000000 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_git.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Git utilities, adopted from mypy's git utilities (https://github.com/python/mypy/blob/master/mypy/git.py).""" - -from __future__ import annotations - -import subprocess -from pathlib import Path - - -def is_git_repo(dir: Path) -> bool: - """Is the given directory version-controlled with git?""" - return dir.joinpath('.git').exists() - - -def have_git() -> bool: # pragma: no cover - """Can we run the git executable?""" - try: - subprocess.check_output(['git', '--help']) - return True - except subprocess.CalledProcessError: - return False - except OSError: - return False - - -def git_revision(dir: Path) -> str: - """Get the SHA-1 of the HEAD of a git repository.""" - return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], 
cwd=dir).decode('utf-8').strip() diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_import_utils.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_import_utils.py deleted file mode 100644 index 638102f..0000000 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_import_utils.py +++ /dev/null @@ -1,20 +0,0 @@ -from functools import cache -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from pydantic import BaseModel - from pydantic.fields import FieldInfo - - -@cache -def import_cached_base_model() -> type['BaseModel']: - from pydantic import BaseModel - - return BaseModel - - -@cache -def import_cached_field_info() -> type['FieldInfo']: - from pydantic.fields import FieldInfo - - return FieldInfo diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_internal_dataclass.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_internal_dataclass.py index 33e152c..317a3d9 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_internal_dataclass.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_internal_dataclass.py @@ -1,4 +1,7 @@ import sys +from typing import Any, Dict + +dataclass_kwargs: Dict[str, Any] # `slots` is available on Python >= 3.10 if sys.version_info >= (3, 10): diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_known_annotated_metadata.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_known_annotated_metadata.py index c127e27..307adfa 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_known_annotated_metadata.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_known_annotated_metadata.py @@ -1,57 +1,42 @@ from __future__ import annotations from collections import defaultdict -from collections.abc import Iterable from copy import copy -from functools import lru_cache, partial -from typing import TYPE_CHECKING, Any +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, Iterable -from pydantic_core import CoreSchema, PydanticCustomError, ValidationError, to_jsonable_python +import annotated_types as at +from pydantic_core import CoreSchema, PydanticCustomError, to_jsonable_python from pydantic_core import core_schema as cs -from ._fields import PydanticMetadata -from ._import_utils import import_cached_field_info +from . 
import _validators +from ._fields import PydanticGeneralMetadata, PydanticMetadata if TYPE_CHECKING: - pass + from ..annotated_handlers import GetJsonSchemaHandler + STRICT = {'strict'} -FAIL_FAST = {'fail_fast'} -LENGTH_CONSTRAINTS = {'min_length', 'max_length'} +SEQUENCE_CONSTRAINTS = {'min_length', 'max_length'} INEQUALITY = {'le', 'ge', 'lt', 'gt'} -NUMERIC_CONSTRAINTS = {'multiple_of', *INEQUALITY} -ALLOW_INF_NAN = {'allow_inf_nan'} +NUMERIC_CONSTRAINTS = {'multiple_of', 'allow_inf_nan', *INEQUALITY} -STR_CONSTRAINTS = { - *LENGTH_CONSTRAINTS, - *STRICT, - 'strip_whitespace', - 'to_lower', - 'to_upper', - 'pattern', - 'coerce_numbers_to_str', -} -BYTES_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT} +STR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT, 'strip_whitespace', 'to_lower', 'to_upper', 'pattern'} +BYTES_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT} -LIST_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST} -TUPLE_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST} -SET_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT, *FAIL_FAST} -DICT_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT} -GENERATOR_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *STRICT} -SEQUENCE_CONSTRAINTS = {*LENGTH_CONSTRAINTS, *FAIL_FAST} +LIST_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT} +TUPLE_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT} +SET_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT} +DICT_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT} +GENERATOR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT} -FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *ALLOW_INF_NAN, *STRICT} -DECIMAL_CONSTRAINTS = {'max_digits', 'decimal_places', *FLOAT_CONSTRAINTS} -INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *ALLOW_INF_NAN, *STRICT} +FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT} +INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT} BOOL_CONSTRAINTS = STRICT -UUID_CONSTRAINTS = STRICT DATE_TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT} TIMEDELTA_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT} TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT} -LAX_OR_STRICT_CONSTRAINTS = STRICT -ENUM_CONSTRAINTS = STRICT -COMPLEX_CONSTRAINTS = STRICT UNION_CONSTRAINTS = {'union_mode'} URL_CONSTRAINTS = { @@ -68,33 +53,54 @@ SEQUENCE_SCHEMA_TYPES = ('list', 'tuple', 'set', 'frozenset', 'generator', *TEXT NUMERIC_SCHEMA_TYPES = ('float', 'int', 'date', 'time', 'timedelta', 'datetime') CONSTRAINTS_TO_ALLOWED_SCHEMAS: dict[str, set[str]] = defaultdict(set) +for constraint in STR_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(TEXT_SCHEMA_TYPES) +for constraint in BYTES_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bytes',)) +for constraint in LIST_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('list',)) +for constraint in TUPLE_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('tuple',)) +for constraint in SET_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('set', 'frozenset')) +for constraint in DICT_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('dict',)) +for constraint in GENERATOR_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('generator',)) +for constraint in FLOAT_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('float',)) +for constraint in INT_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('int',)) +for constraint in DATE_TIME_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('date', 'time', 'datetime')) +for constraint in TIMEDELTA_CONSTRAINTS: + 
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('timedelta',)) +for constraint in TIME_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('time',)) +for schema_type in (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model'): + CONSTRAINTS_TO_ALLOWED_SCHEMAS['strict'].add(schema_type) +for constraint in UNION_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('union',)) +for constraint in URL_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('url', 'multi-host-url')) +for constraint in BOOL_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bool',)) -constraint_schema_pairings: list[tuple[set[str], tuple[str, ...]]] = [ - (STR_CONSTRAINTS, TEXT_SCHEMA_TYPES), - (BYTES_CONSTRAINTS, ('bytes',)), - (LIST_CONSTRAINTS, ('list',)), - (TUPLE_CONSTRAINTS, ('tuple',)), - (SET_CONSTRAINTS, ('set', 'frozenset')), - (DICT_CONSTRAINTS, ('dict',)), - (GENERATOR_CONSTRAINTS, ('generator',)), - (FLOAT_CONSTRAINTS, ('float',)), - (INT_CONSTRAINTS, ('int',)), - (DATE_TIME_CONSTRAINTS, ('date', 'time', 'datetime', 'timedelta')), - # TODO: this is a bit redundant, we could probably avoid some of these - (STRICT, (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model')), - (UNION_CONSTRAINTS, ('union',)), - (URL_CONSTRAINTS, ('url', 'multi-host-url')), - (BOOL_CONSTRAINTS, ('bool',)), - (UUID_CONSTRAINTS, ('uuid',)), - (LAX_OR_STRICT_CONSTRAINTS, ('lax-or-strict',)), - (ENUM_CONSTRAINTS, ('enum',)), - (DECIMAL_CONSTRAINTS, ('decimal',)), - (COMPLEX_CONSTRAINTS, ('complex',)), -] -for constraints, schemas in constraint_schema_pairings: - for c in constraints: - CONSTRAINTS_TO_ALLOWED_SCHEMAS[c].update(schemas) +def add_js_update_schema(s: cs.CoreSchema, f: Callable[[], dict[str, Any]]) -> None: + def update_js_schema(s: cs.CoreSchema, handler: GetJsonSchemaHandler) -> dict[str, Any]: + js_schema = handler(s) + js_schema.update(f()) + return js_schema + + if 'metadata' in s: + metadata = s['metadata'] + if 'pydantic_js_functions' in s: + metadata['pydantic_js_functions'].append(update_js_schema) + else: + metadata['pydantic_js_functions'] = [update_js_schema] + else: + s['metadata'] = {'pydantic_js_functions': [update_js_schema]} def as_jsonable_value(v: Any) -> Any: @@ -113,7 +119,7 @@ def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]: An iterable of expanded annotations. Example: - ```python + ```py from annotated_types import Ge, Len from pydantic._internal._known_annotated_metadata import expand_grouped_metadata @@ -122,9 +128,7 @@ def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]: #> [Ge(ge=4), MinLen(min_length=5)] ``` """ - import annotated_types as at - - FieldInfo = import_cached_field_info() + from pydantic.fields import FieldInfo # circular import for annotation in annotations: if isinstance(annotation, at.GroupedMetadata): @@ -143,28 +147,6 @@ def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]: yield annotation -@lru_cache -def _get_at_to_constraint_map() -> dict[type, str]: - """Return a mapping of annotated types to constraints. - - Normally, we would define a mapping like this in the module scope, but we can't do that - because we don't permit module level imports of `annotated_types`, in an attempt to speed up - the import time of `pydantic`. We still only want to have this dictionary defined in one place, - so we use this function to cache the result. 
- """ - import annotated_types as at - - return { - at.Gt: 'gt', - at.Ge: 'ge', - at.Lt: 'lt', - at.Le: 'le', - at.MultipleOf: 'multiple_of', - at.MinLen: 'min_length', - at.MaxLen: 'max_length', - } - - def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | None: # noqa: C901 """Apply `annotation` to `schema` if it is an annotation we know about (Gt, Le, etc.). Otherwise return `None`. @@ -184,37 +166,14 @@ def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | No Raises: PydanticCustomError: If `Predicate` fails. """ - import annotated_types as at - - from ._validators import NUMERIC_VALIDATOR_LOOKUP, forbid_inf_nan_check - schema = schema.copy() schema_update, other_metadata = collect_known_metadata([annotation]) schema_type = schema['type'] - - chain_schema_constraints: set[str] = { - 'pattern', - 'strip_whitespace', - 'to_lower', - 'to_upper', - 'coerce_numbers_to_str', - } - chain_schema_steps: list[CoreSchema] = [] - for constraint, value in schema_update.items(): if constraint not in CONSTRAINTS_TO_ALLOWED_SCHEMAS: raise ValueError(f'Unknown constraint {constraint}') allowed_schemas = CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint] - # if it becomes necessary to handle more than one constraint - # in this recursive case with function-after or function-wrap, we should refactor - # this is a bit challenging because we sometimes want to apply constraints to the inner schema, - # whereas other times we want to wrap the existing schema with a new one that enforces a new constraint. - if schema_type in {'function-before', 'function-wrap', 'function-after'} and constraint == 'strict': - schema['schema'] = apply_known_metadata(annotation, schema['schema']) # type: ignore # schema is function schema - return schema - - # if we're allowed to apply constraint directly to the schema, like le to int, do that if schema_type in allowed_schemas: if constraint == 'union_mode' and schema_type == 'union': schema['mode'] = value # type: ignore # schema is UnionSchema @@ -222,109 +181,145 @@ def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | No schema[constraint] = value continue - # else, apply a function after validator to the schema to enforce the corresponding constraint - if constraint in chain_schema_constraints: - - def _apply_constraint_with_incompatibility_info( - value: Any, handler: cs.ValidatorFunctionWrapHandler - ) -> Any: - try: - x = handler(value) - except ValidationError as ve: - # if the error is about the type, it's likely that the constraint is incompatible the type of the field - # for example, the following invalid schema wouldn't be caught during schema build, but rather at this point - # with a cryptic 'string_type' error coming from the string validator, - # that we'd rather express as a constraint incompatibility error (TypeError) - # Annotated[list[int], Field(pattern='abc')] - if 'type' in ve.errors()[0]['type']: - raise TypeError( - f"Unable to apply constraint '{constraint}' to supplied value {value} for schema of type '{schema_type}'" # noqa: B023 - ) - raise ve - return x - - chain_schema_steps.append( - cs.no_info_wrap_validator_function( - _apply_constraint_with_incompatibility_info, cs.str_schema(**{constraint: value}) - ) + if constraint == 'allow_inf_nan' and value is False: + return cs.no_info_after_validator_function( + _validators.forbid_inf_nan_check, + schema, ) - elif constraint in NUMERIC_VALIDATOR_LOOKUP: - if constraint in LENGTH_CONSTRAINTS: - inner_schema = schema - while 
inner_schema['type'] in {'function-before', 'function-wrap', 'function-after'}: - inner_schema = inner_schema['schema'] # type: ignore - inner_schema_type = inner_schema['type'] - if inner_schema_type == 'list' or ( - inner_schema_type == 'json-or-python' and inner_schema['json_schema']['type'] == 'list' # type: ignore - ): - js_constraint_key = 'minItems' if constraint == 'min_length' else 'maxItems' - else: - js_constraint_key = 'minLength' if constraint == 'min_length' else 'maxLength' - else: - js_constraint_key = constraint - - schema = cs.no_info_after_validator_function( - partial(NUMERIC_VALIDATOR_LOOKUP[constraint], **{constraint: value}), schema + elif constraint == 'pattern': + # insert a str schema to make sure the regex engine matches + return cs.chain_schema( + [ + schema, + cs.str_schema(pattern=value), + ] ) - metadata = schema.get('metadata', {}) - if (existing_json_schema_updates := metadata.get('pydantic_js_updates')) is not None: - metadata['pydantic_js_updates'] = { - **existing_json_schema_updates, - **{js_constraint_key: as_jsonable_value(value)}, - } - else: - metadata['pydantic_js_updates'] = {js_constraint_key: as_jsonable_value(value)} - schema['metadata'] = metadata - elif constraint == 'allow_inf_nan' and value is False: - schema = cs.no_info_after_validator_function( - forbid_inf_nan_check, + elif constraint == 'gt': + s = cs.no_info_after_validator_function( + partial(_validators.greater_than_validator, gt=value), + schema, + ) + add_js_update_schema(s, lambda: {'gt': as_jsonable_value(value)}) + return s + elif constraint == 'ge': + return cs.no_info_after_validator_function( + partial(_validators.greater_than_or_equal_validator, ge=value), + schema, + ) + elif constraint == 'lt': + return cs.no_info_after_validator_function( + partial(_validators.less_than_validator, lt=value), + schema, + ) + elif constraint == 'le': + return cs.no_info_after_validator_function( + partial(_validators.less_than_or_equal_validator, le=value), + schema, + ) + elif constraint == 'multiple_of': + return cs.no_info_after_validator_function( + partial(_validators.multiple_of_validator, multiple_of=value), + schema, + ) + elif constraint == 'min_length': + s = cs.no_info_after_validator_function( + partial(_validators.min_length_validator, min_length=value), + schema, + ) + add_js_update_schema(s, lambda: {'minLength': (as_jsonable_value(value))}) + return s + elif constraint == 'max_length': + s = cs.no_info_after_validator_function( + partial(_validators.max_length_validator, max_length=value), + schema, + ) + add_js_update_schema(s, lambda: {'maxLength': (as_jsonable_value(value))}) + return s + elif constraint == 'strip_whitespace': + return cs.chain_schema( + [ + schema, + cs.str_schema(strip_whitespace=True), + ] + ) + elif constraint == 'to_lower': + return cs.chain_schema( + [ + schema, + cs.str_schema(to_lower=True), + ] + ) + elif constraint == 'to_upper': + return cs.chain_schema( + [ + schema, + cs.str_schema(to_upper=True), + ] + ) + elif constraint == 'min_length': + return cs.no_info_after_validator_function( + partial(_validators.min_length_validator, min_length=annotation.min_length), + schema, + ) + elif constraint == 'max_length': + return cs.no_info_after_validator_function( + partial(_validators.max_length_validator, max_length=annotation.max_length), schema, ) else: - # It's rare that we'd get here, but it's possible if we add a new constraint and forget to handle it - # Most constraint errors are caught at runtime during attempted application - raise 
RuntimeError(f"Unable to apply constraint '{constraint}' to schema of type '{schema_type}'") + raise RuntimeError(f'Unable to apply constraint {constraint} to schema {schema_type}') for annotation in other_metadata: - if (annotation_type := type(annotation)) in (at_to_constraint_map := _get_at_to_constraint_map()): - constraint = at_to_constraint_map[annotation_type] - validator = NUMERIC_VALIDATOR_LOOKUP.get(constraint) - if validator is None: - raise ValueError(f'Unknown constraint {constraint}') - schema = cs.no_info_after_validator_function( - partial(validator, {constraint: getattr(annotation, constraint)}), schema + if isinstance(annotation, at.Gt): + return cs.no_info_after_validator_function( + partial(_validators.greater_than_validator, gt=annotation.gt), + schema, ) - continue - elif isinstance(annotation, (at.Predicate, at.Not)): - predicate_name = f'{annotation.func.__qualname__}' if hasattr(annotation.func, '__qualname__') else '' + elif isinstance(annotation, at.Ge): + return cs.no_info_after_validator_function( + partial(_validators.greater_than_or_equal_validator, ge=annotation.ge), + schema, + ) + elif isinstance(annotation, at.Lt): + return cs.no_info_after_validator_function( + partial(_validators.less_than_validator, lt=annotation.lt), + schema, + ) + elif isinstance(annotation, at.Le): + return cs.no_info_after_validator_function( + partial(_validators.less_than_or_equal_validator, le=annotation.le), + schema, + ) + elif isinstance(annotation, at.MultipleOf): + return cs.no_info_after_validator_function( + partial(_validators.multiple_of_validator, multiple_of=annotation.multiple_of), + schema, + ) + elif isinstance(annotation, at.MinLen): + return cs.no_info_after_validator_function( + partial(_validators.min_length_validator, min_length=annotation.min_length), + schema, + ) + elif isinstance(annotation, at.MaxLen): + return cs.no_info_after_validator_function( + partial(_validators.max_length_validator, max_length=annotation.max_length), + schema, + ) + elif isinstance(annotation, at.Predicate): + predicate_name = f'{annotation.func.__qualname__} ' if hasattr(annotation.func, '__qualname__') else '' def val_func(v: Any) -> Any: - predicate_satisfied = annotation.func(v) # noqa: B023 - # annotation.func may also raise an exception, let it pass through - if isinstance(annotation, at.Predicate): # noqa: B023 - if not predicate_satisfied: - raise PydanticCustomError( - 'predicate_failed', - f'Predicate {predicate_name} failed', # type: ignore # noqa: B023 - ) - else: - if predicate_satisfied: - raise PydanticCustomError( - 'not_operation_failed', - f'Not of {predicate_name} failed', # type: ignore # noqa: B023 - ) - + if not annotation.func(v): + raise PydanticCustomError( + 'predicate_failed', + f'Predicate {predicate_name}failed', # type: ignore + ) return v - schema = cs.no_info_after_validator_function(val_func, schema) - else: - # ignore any other unknown metadata - return None - - if chain_schema_steps: - chain_schema_steps = [schema] + chain_schema_steps - return cs.chain_schema(chain_schema_steps) + return cs.no_info_after_validator_function(val_func, schema) + # ignore any other unknown metadata + return None return schema @@ -339,7 +334,7 @@ def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any], A tuple contains a dict of known metadata and a list of unknown annotations. 
Example: - ```python + ```py from annotated_types import Gt, Len from pydantic._internal._known_annotated_metadata import collect_known_metadata @@ -352,15 +347,29 @@ def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any], res: dict[str, Any] = {} remaining: list[Any] = [] - for annotation in annotations: - # isinstance(annotation, PydanticMetadata) also covers ._fields:_PydanticGeneralMetadata - if isinstance(annotation, PydanticMetadata): + # Do we really want to consume any `BaseMetadata`? + # It does let us give a better error when there is an annotation that doesn't apply + # But it seems dangerous! + if isinstance(annotation, PydanticGeneralMetadata): + res.update(annotation.__dict__) + elif isinstance(annotation, PydanticMetadata): res.update(annotation.__dict__) # we don't use dataclasses.asdict because that recursively calls asdict on the field values - elif (annotation_type := type(annotation)) in (at_to_constraint_map := _get_at_to_constraint_map()): - constraint = at_to_constraint_map[annotation_type] - res[constraint] = getattr(annotation, constraint) + elif isinstance(annotation, at.MinLen): + res.update({'min_length': annotation.min_length}) + elif isinstance(annotation, at.MaxLen): + res.update({'max_length': annotation.max_length}) + elif isinstance(annotation, at.Gt): + res.update({'gt': annotation.gt}) + elif isinstance(annotation, at.Ge): + res.update({'ge': annotation.ge}) + elif isinstance(annotation, at.Lt): + res.update({'lt': annotation.lt}) + elif isinstance(annotation, at.Le): + res.update({'le': annotation.le}) + elif isinstance(annotation, at.MultipleOf): + res.update({'multiple_of': annotation.multiple_of}) elif isinstance(annotation, type) and issubclass(annotation, PydanticMetadata): # also support PydanticMetadata classes being used without initialisation, # e.g. `Annotated[int, Strict]` as well as `Annotated[int, Strict()]` diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_mock_val_ser.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_mock_val_ser.py index 9125ab3..ea03a68 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_mock_val_ser.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_mock_val_ser.py @@ -1,71 +1,18 @@ from __future__ import annotations -from collections.abc import Iterator, Mapping -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, Union +from typing import TYPE_CHECKING, Callable, Generic, TypeVar -from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator +from pydantic_core import SchemaSerializer, SchemaValidator +from typing_extensions import Literal from ..errors import PydanticErrorCodes, PydanticUserError -from ..plugin._schema_validator import PluggableSchemaValidator if TYPE_CHECKING: from ..dataclasses import PydanticDataclass from ..main import BaseModel - from ..type_adapter import TypeAdapter -ValSer = TypeVar('ValSer', bound=Union[SchemaValidator, PluggableSchemaValidator, SchemaSerializer]) -T = TypeVar('T') - - -class MockCoreSchema(Mapping[str, Any]): - """Mocker for `pydantic_core.CoreSchema` which optionally attempts to - rebuild the thing it's mocking when one of its methods is accessed and raises an error if that fails. 
- """ - - __slots__ = '_error_message', '_code', '_attempt_rebuild', '_built_memo' - - def __init__( - self, - error_message: str, - *, - code: PydanticErrorCodes, - attempt_rebuild: Callable[[], CoreSchema | None] | None = None, - ) -> None: - self._error_message = error_message - self._code: PydanticErrorCodes = code - self._attempt_rebuild = attempt_rebuild - self._built_memo: CoreSchema | None = None - - def __getitem__(self, key: str) -> Any: - return self._get_built().__getitem__(key) - - def __len__(self) -> int: - return self._get_built().__len__() - - def __iter__(self) -> Iterator[str]: - return self._get_built().__iter__() - - def _get_built(self) -> CoreSchema: - if self._built_memo is not None: - return self._built_memo - - if self._attempt_rebuild: - schema = self._attempt_rebuild() - if schema is not None: - self._built_memo = schema - return schema - raise PydanticUserError(self._error_message, code=self._code) - - def rebuild(self) -> CoreSchema | None: - self._built_memo = None - if self._attempt_rebuild: - schema = self._attempt_rebuild() - if schema is not None: - return schema - else: - raise PydanticUserError(self._error_message, code=self._code) - return None +ValSer = TypeVar('ValSer', SchemaValidator, SchemaSerializer) class MockValSer(Generic[ValSer]): @@ -109,120 +56,63 @@ class MockValSer(Generic[ValSer]): return None -def set_type_adapter_mocks(adapter: TypeAdapter) -> None: - """Set `core_schema`, `validator` and `serializer` to mock core types on a type adapter instance. - - Args: - adapter: The type adapter instance to set the mocks on - """ - type_repr = str(adapter._type) - undefined_type_error_message = ( - f'`TypeAdapter[{type_repr}]` is not fully defined; you should define `{type_repr}` and all referenced types,' - f' then call `.rebuild()` on the instance.' - ) - - def attempt_rebuild_fn(attr_fn: Callable[[TypeAdapter], T]) -> Callable[[], T | None]: - def handler() -> T | None: - if adapter.rebuild(raise_errors=False, _parent_namespace_depth=5) is not False: - return attr_fn(adapter) - return None - - return handler - - adapter.core_schema = MockCoreSchema( # pyright: ignore[reportAttributeAccessIssue] - undefined_type_error_message, - code='class-not-fully-defined', - attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.core_schema), - ) - adapter.validator = MockValSer( # pyright: ignore[reportAttributeAccessIssue] - undefined_type_error_message, - code='class-not-fully-defined', - val_or_ser='validator', - attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.validator), - ) - adapter.serializer = MockValSer( # pyright: ignore[reportAttributeAccessIssue] - undefined_type_error_message, - code='class-not-fully-defined', - val_or_ser='serializer', - attempt_rebuild=attempt_rebuild_fn(lambda ta: ta.serializer), - ) - - -def set_model_mocks(cls: type[BaseModel], undefined_name: str = 'all referenced types') -> None: - """Set `__pydantic_core_schema__`, `__pydantic_validator__` and `__pydantic_serializer__` to mock core types on a model. +def set_model_mocks(cls: type[BaseModel], cls_name: str, undefined_name: str = 'all referenced types') -> None: + """Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a model. 
Args: cls: The model class to set the mocks on + cls_name: Name of the model class, used in error messages undefined_name: Name of the undefined thing, used in error messages """ undefined_type_error_message = ( - f'`{cls.__name__}` is not fully defined; you should define {undefined_name},' - f' then call `{cls.__name__}.model_rebuild()`.' + f'`{cls_name}` is not fully defined; you should define {undefined_name},' + f' then call `{cls_name}.model_rebuild()`.' ) - def attempt_rebuild_fn(attr_fn: Callable[[type[BaseModel]], T]) -> Callable[[], T | None]: - def handler() -> T | None: - if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False: - return attr_fn(cls) + def attempt_rebuild_validator() -> SchemaValidator | None: + if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5): + return cls.__pydantic_validator__ + else: return None - return handler - - cls.__pydantic_core_schema__ = MockCoreSchema( # pyright: ignore[reportAttributeAccessIssue] - undefined_type_error_message, - code='class-not-fully-defined', - attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_core_schema__), - ) - cls.__pydantic_validator__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue] + cls.__pydantic_validator__ = MockValSer( # type: ignore[assignment] undefined_type_error_message, code='class-not-fully-defined', val_or_ser='validator', - attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_validator__), + attempt_rebuild=attempt_rebuild_validator, ) - cls.__pydantic_serializer__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue] + + def attempt_rebuild_serializer() -> SchemaSerializer | None: + if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5): + return cls.__pydantic_serializer__ + else: + return None + + cls.__pydantic_serializer__ = MockValSer( # type: ignore[assignment] undefined_type_error_message, code='class-not-fully-defined', val_or_ser='serializer', - attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_serializer__), + attempt_rebuild=attempt_rebuild_serializer, ) -def set_dataclass_mocks(cls: type[PydanticDataclass], undefined_name: str = 'all referenced types') -> None: - """Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a dataclass. - - Args: - cls: The model class to set the mocks on - undefined_name: Name of the undefined thing, used in error messages - """ - from ..dataclasses import rebuild_dataclass - +def set_dataclass_mock_validator(cls: type[PydanticDataclass], cls_name: str, undefined_name: str) -> None: undefined_type_error_message = ( - f'`{cls.__name__}` is not fully defined; you should define {undefined_name},' - f' then call `pydantic.dataclasses.rebuild_dataclass({cls.__name__})`.' + f'`{cls_name}` is not fully defined; you should define {undefined_name},' + f' then call `pydantic.dataclasses.rebuild_dataclass({cls_name})`.' 
) - def attempt_rebuild_fn(attr_fn: Callable[[type[PydanticDataclass]], T]) -> Callable[[], T | None]: - def handler() -> T | None: - if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False: - return attr_fn(cls) + def attempt_rebuild() -> SchemaValidator | None: + from ..dataclasses import rebuild_dataclass + + if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5): + return cls.__pydantic_validator__ + else: return None - return handler - - cls.__pydantic_core_schema__ = MockCoreSchema( # pyright: ignore[reportAttributeAccessIssue] - undefined_type_error_message, - code='class-not-fully-defined', - attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_core_schema__), - ) - cls.__pydantic_validator__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue] + cls.__pydantic_validator__ = MockValSer( # type: ignore[assignment] undefined_type_error_message, code='class-not-fully-defined', val_or_ser='validator', - attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_validator__), - ) - cls.__pydantic_serializer__ = MockValSer( # pyright: ignore[reportAttributeAccessIssue] - undefined_type_error_message, - code='class-not-fully-defined', - val_or_ser='serializer', - attempt_rebuild=attempt_rebuild_fn(lambda c: c.__pydantic_serializer__), + attempt_rebuild=attempt_rebuild, ) diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_model_construction.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_model_construction.py index fd5d68b..15dbb3e 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_model_construction.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_model_construction.py @@ -1,54 +1,58 @@ """Private logic for creating models.""" - from __future__ import annotations as _annotations -import builtins -import operator -import sys import typing import warnings import weakref from abc import ABCMeta -from functools import cache, partial, wraps +from functools import partial from types import FunctionType -from typing import Any, Callable, Generic, Literal, NoReturn, cast +from typing import Any, Callable, Generic, Mapping from pydantic_core import PydanticUndefined, SchemaSerializer -from typing_extensions import TypeAliasType, dataclass_transform, deprecated, get_args, get_origin -from typing_inspection import typing_objects +from typing_extensions import dataclass_transform, deprecated from ..errors import PydanticUndefinedAnnotation, PydanticUserError +from ..fields import Field, FieldInfo, ModelPrivateAttr, PrivateAttr from ..plugin._schema_validator import create_schema_validator -from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20 +from ..warnings import PydanticDeprecatedSince20 from ._config import ConfigWrapper -from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases, unwrap_wrapped_function -from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name -from ._generate_schema import GenerateSchema, InvalidSchemaError -from ._generics import PydanticGenericMetadata, get_model_typevars_map -from ._import_utils import import_cached_base_model, import_cached_field_info -from ._mock_val_ser import set_model_mocks -from ._namespace_utils import NsResolver -from ._signature import generate_pydantic_signature -from ._typing_extra import ( - _make_forward_ref, - eval_type_backport, - is_classvar_annotation, - parent_frame_namespace, +from ._core_utils import collect_invalid_schemas, 
simplify_schema_references, validate_core_schema +from ._decorators import ( + ComputedFieldInfo, + DecoratorInfos, + PydanticDescriptorProxy, + get_attribute_from_bases, ) -from ._utils import LazyClassAttribute, SafeGetItemProxy +from ._discriminated_union import apply_discriminators +from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name +from ._generate_schema import GenerateSchema +from ._generics import PydanticGenericMetadata, get_model_typevars_map +from ._mock_val_ser import MockValSer, set_model_mocks +from ._schema_generation_shared import CallbackGetCoreSchemaHandler +from ._typing_extra import get_cls_types_namespace, is_classvar, parent_frame_namespace +from ._utils import ClassAttribute, is_valid_identifier +from ._validate_call import ValidateCallWrapper if typing.TYPE_CHECKING: - from ..fields import Field as PydanticModelField - from ..fields import FieldInfo, ModelPrivateAttr - from ..fields import PrivateAttr as PydanticModelPrivateAttr + from inspect import Signature + from ..main import BaseModel else: # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915 # and https://youtrack.jetbrains.com/issue/PY-51428 DeprecationWarning = PydanticDeprecatedSince20 - PydanticModelField = object() - PydanticModelPrivateAttr = object() + +IGNORED_TYPES: tuple[Any, ...] = ( + FunctionType, + property, + classmethod, + staticmethod, + PydanticDescriptorProxy, + ComputedFieldInfo, + ValidateCallWrapper, +) object_setattr = object.__setattr__ @@ -65,17 +69,7 @@ class _ModelNamespaceDict(dict): return super().__setitem__(k, v) -def NoInitField( - *, - init: Literal[False] = False, -) -> Any: - """Only for typing purposes. Used as default value of `__pydantic_fields_set__`, - `__pydantic_extra__`, `__pydantic_private__`, so they could be ignored when - synthesizing the `__init__` signature. - """ - - -@dataclass_transform(kw_only_default=True, field_specifiers=(PydanticModelField, PydanticModelPrivateAttr, NoInitField)) +@dataclass_transform(kw_only_default=True, field_specifiers=(Field,)) class ModelMetaclass(ABCMeta): def __new__( mcs, @@ -84,7 +78,6 @@ class ModelMetaclass(ABCMeta): namespace: dict[str, Any], __pydantic_generic_metadata__: PydanticGenericMetadata | None = None, __pydantic_reset_parent_namespace__: bool = True, - _create_model_module: str | None = None, **kwargs: Any, ) -> type: """Metaclass for creating Pydantic models. @@ -95,7 +88,6 @@ class ModelMetaclass(ABCMeta): namespace: The attribute dictionary of the class to be created. __pydantic_generic_metadata__: Metadata for generic models. __pydantic_reset_parent_namespace__: Reset parent namespace. - _create_model_module: The module of the class to be created, if created by `create_model`. **kwargs: Catch-all for any other keyword arguments. Returns: @@ -112,18 +104,17 @@ class ModelMetaclass(ABCMeta): private_attributes = inspect_namespace( namespace, config_wrapper.ignored_types, class_vars, base_field_names ) - if private_attributes or base_private_attributes: + if private_attributes: original_model_post_init = get_model_post_init(namespace, bases) if original_model_post_init is not None: # if there are private_attributes and a model_post_init function, we handle both - @wraps(original_model_post_init) - def wrapped_model_post_init(self: BaseModel, context: Any, /) -> None: + def wrapped_model_post_init(self: BaseModel, __context: Any) -> None: """We need to both initialize private attributes and call the user-defined model_post_init method. 
""" - init_private_attributes(self, context) - original_model_post_init(self, context) + init_private_attributes(self, __context) + original_model_post_init(self, __context) namespace['model_post_init'] = wrapped_model_post_init else: @@ -132,25 +123,15 @@ class ModelMetaclass(ABCMeta): namespace['__class_vars__'] = class_vars namespace['__private_attributes__'] = {**base_private_attributes, **private_attributes} - cls = cast('type[BaseModel]', super().__new__(mcs, cls_name, bases, namespace, **kwargs)) - BaseModel_ = import_cached_base_model() + if config_wrapper.frozen: + set_default_hash_func(namespace, bases) - mro = cls.__mro__ - if Generic in mro and mro.index(Generic) < mro.index(BaseModel_): - warnings.warn( - GenericBeforeBaseModelWarning( - 'Classes should inherit from `BaseModel` before generic classes (e.g. `typing.Generic[T]`) ' - 'for pydantic generics to work properly.' - ), - stacklevel=2, - ) + cls: type[BaseModel] = super().__new__(mcs, cls_name, bases, namespace, **kwargs) # type: ignore + + from ..main import BaseModel cls.__pydantic_custom_init__ = not getattr(cls.__init__, '__pydantic_base_init__', False) - cls.__pydantic_post_init__ = ( - None if cls.model_post_init is BaseModel_.model_post_init else 'model_post_init' - ) - - cls.__pydantic_setattr_handlers__ = {} + cls.__pydantic_post_init__ = None if cls.model_post_init is BaseModel.model_post_init else 'model_post_init' cls.__pydantic_decorators__ = DecoratorInfos.build(cls) @@ -161,40 +142,22 @@ class ModelMetaclass(ABCMeta): parent_parameters = getattr(cls, '__pydantic_generic_metadata__', {}).get('parameters', ()) parameters = getattr(cls, '__parameters__', None) or parent_parameters if parameters and parent_parameters and not all(x in parameters for x in parent_parameters): - from ..root_model import RootModelRootType - - missing_parameters = tuple(x for x in parameters if x not in parent_parameters) - if RootModelRootType in parent_parameters and RootModelRootType not in parameters: - # This is a special case where the user has subclassed `RootModel`, but has not parametrized - # RootModel with the generic type identifiers being used. Ex: - # class MyModel(RootModel, Generic[T]): - # root: T - # Should instead just be: - # class MyModel(RootModel[T]): - # root: T - parameters_str = ', '.join([x.__name__ for x in missing_parameters]) - error_message = ( - f'{cls.__name__} is a subclass of `RootModel`, but does not include the generic type identifier(s) ' - f'{parameters_str} in its parameters. ' - f'You should parametrize RootModel directly, e.g., `class {cls.__name__}(RootModel[{parameters_str}]): ...`.' + combined_parameters = parent_parameters + tuple(x for x in parameters if x not in parent_parameters) + parameters_str = ', '.join([str(x) for x in combined_parameters]) + generic_type_label = f'typing.Generic[{parameters_str}]' + error_message = ( + f'All parameters must be present on typing.Generic;' + f' you should inherit from {generic_type_label}.' + ) + if Generic not in bases: # pragma: no cover + # We raise an error here not because it is desirable, but because some cases are mishandled. + # It would be nice to remove this error and still have things behave as expected, it's just + # challenging because we are using a custom `__class_getitem__` to parametrize generic models, + # and not returning a typing._GenericAlias from it. 
+ bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label]) + error_message += ( + f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)' ) - else: - combined_parameters = parent_parameters + missing_parameters - parameters_str = ', '.join([str(x) for x in combined_parameters]) - generic_type_label = f'typing.Generic[{parameters_str}]' - error_message = ( - f'All parameters must be present on typing.Generic;' - f' you should inherit from {generic_type_label}.' - ) - if Generic not in bases: # pragma: no cover - # We raise an error here not because it is desirable, but because some cases are mishandled. - # It would be nice to remove this error and still have things behave as expected, it's just - # challenging because we are using a custom `__class_getitem__` to parametrize generic models, - # and not returning a typing._GenericAlias from it. - bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label]) - error_message += ( - f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)' - ) raise TypeError(error_message) cls.__pydantic_generic_metadata__ = { @@ -212,55 +175,29 @@ class ModelMetaclass(ABCMeta): if __pydantic_reset_parent_namespace__: cls.__pydantic_parent_namespace__ = build_lenient_weakvaluedict(parent_frame_namespace()) - parent_namespace: dict[str, Any] | None = getattr(cls, '__pydantic_parent_namespace__', None) + parent_namespace = getattr(cls, '__pydantic_parent_namespace__', None) if isinstance(parent_namespace, dict): parent_namespace = unpack_lenient_weakvaluedict(parent_namespace) - ns_resolver = NsResolver(parent_namespace=parent_namespace) - - set_model_fields(cls, config_wrapper=config_wrapper, ns_resolver=ns_resolver) - - # This is also set in `complete_model_class()`, after schema gen because they are recreated. - # We set them here as well for backwards compatibility: - cls.__pydantic_computed_fields__ = { - k: v.info for k, v in cls.__pydantic_decorators__.computed_fields.items() - } - - if config_wrapper.defer_build: - # TODO we can also stop there if `__pydantic_fields_complete__` is False. - # However, `set_model_fields()` is currently lenient and we don't have access to the `NameError`. - # (which is useful as we can provide the name in the error message: `set_model_mock(cls, e.name)`) - set_model_mocks(cls) - else: - # Any operation that requires accessing the field infos instances should be put inside - # `complete_model_class()`: - complete_model_class( - cls, - config_wrapper, - raise_errors=False, - ns_resolver=ns_resolver, - create_model_module=_create_model_module, - ) - - if config_wrapper.frozen and '__hash__' not in namespace: - set_default_hash_func(cls, bases) - + types_namespace = get_cls_types_namespace(cls, parent_namespace) + set_model_fields(cls, bases, config_wrapper, types_namespace) + complete_model_class( + cls, + cls_name, + config_wrapper, + raise_errors=False, + types_namespace=types_namespace, + ) # using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__ # I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is # only hit for _proper_ subclasses of BaseModel super(cls, cls).__pydantic_init_subclass__(**kwargs) # type: ignore[misc] return cls else: - # These are instance variables, but have been assigned to `NoInitField` to trick the type checker. 
- for instance_slot in '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__': - namespace.pop( - instance_slot, - None, # In case the metaclass is used with a class other than `BaseModel`. - ) - namespace.get('__annotations__', {}).clear() + # this is the BaseModel class itself being created, no logic required return super().__new__(mcs, cls_name, bases, namespace, **kwargs) - if not typing.TYPE_CHECKING: # pragma: no branch + if not typing.TYPE_CHECKING: # We put `__getattr__` in a non-TYPE_CHECKING block because otherwise, mypy allows arbitrary attribute access def __getattr__(self, item: str) -> Any: @@ -268,29 +205,30 @@ class ModelMetaclass(ABCMeta): private_attributes = self.__dict__.get('__private_attributes__') if private_attributes and item in private_attributes: return private_attributes[item] + if item == '__pydantic_core_schema__': + # This means the class didn't get a schema generated for it, likely because there was an undefined reference + maybe_mock_validator = getattr(self, '__pydantic_validator__', None) + if isinstance(maybe_mock_validator, MockValSer): + rebuilt_validator = maybe_mock_validator.rebuild() + if rebuilt_validator is not None: + # In this case, a validator was built, and so `__pydantic_core_schema__` should now be set + return getattr(self, '__pydantic_core_schema__') raise AttributeError(item) @classmethod - def __prepare__(cls, *args: Any, **kwargs: Any) -> dict[str, object]: + def __prepare__(cls, *args: Any, **kwargs: Any) -> Mapping[str, object]: return _ModelNamespaceDict() def __instancecheck__(self, instance: Any) -> bool: - """Avoid calling ABC _abc_instancecheck unless we're pretty sure. - - See #3829 and python/cpython#92810 - """ - return hasattr(instance, '__pydantic_decorators__') and super().__instancecheck__(instance) - - def __subclasscheck__(self, subclass: type[Any]) -> bool: """Avoid calling ABC _abc_subclasscheck unless we're pretty sure. See #3829 and python/cpython#92810 """ - return hasattr(subclass, '__pydantic_decorators__') and super().__subclasscheck__(subclass) + return hasattr(instance, '__pydantic_validator__') and super().__instancecheck__(instance) @staticmethod def _collect_bases_data(bases: tuple[type[Any], ...]) -> tuple[set[str], set[str], dict[str, ModelPrivateAttr]]: - BaseModel = import_cached_base_model() + from ..main import BaseModel field_names: set[str] = set() class_vars: set[str] = set() @@ -298,57 +236,35 @@ class ModelMetaclass(ABCMeta): for base in bases: if issubclass(base, BaseModel) and base is not BaseModel: # model_fields might not be defined yet in the case of generics, so we use getattr here: - field_names.update(getattr(base, '__pydantic_fields__', {}).keys()) + field_names.update(getattr(base, 'model_fields', {}).keys()) class_vars.update(base.__class_vars__) private_attributes.update(base.__private_attributes__) return field_names, class_vars, private_attributes @property - @deprecated('The `__fields__` attribute is deprecated, use `model_fields` instead.', category=None) + @deprecated( + 'The `__fields__` attribute is deprecated, use `model_fields` instead.', category=PydanticDeprecatedSince20 + ) def __fields__(self) -> dict[str, FieldInfo]: - warnings.warn( - 'The `__fields__` attribute is deprecated, use `model_fields` instead.', - PydanticDeprecatedSince20, - stacklevel=2, - ) - return getattr(self, '__pydantic_fields__', {}) - - @property - def __pydantic_fields_complete__(self) -> bool: - """Whether the fields where successfully collected (i.e. 
type hints were successfully resolves). - - This is a private attribute, not meant to be used outside Pydantic. - """ - if not hasattr(self, '__pydantic_fields__'): - return False - - field_infos = cast('dict[str, FieldInfo]', self.__pydantic_fields__) # pyright: ignore[reportAttributeAccessIssue] - - return all(field_info._complete for field_info in field_infos.values()) - - def __dir__(self) -> list[str]: - attributes = list(super().__dir__()) - if '__fields__' in attributes: - attributes.remove('__fields__') - return attributes + warnings.warn('The `__fields__` attribute is deprecated, use `model_fields` instead.', DeprecationWarning) + return self.model_fields # type: ignore -def init_private_attributes(self: BaseModel, context: Any, /) -> None: +def init_private_attributes(self: BaseModel, __context: Any) -> None: """This function is meant to behave like a BaseModel method to initialise private attributes. It takes context as an argument since that's what pydantic-core passes when calling it. Args: self: The BaseModel instance. - context: The context. + __context: The context. """ - if getattr(self, '__pydantic_private__', None) is None: - pydantic_private = {} - for name, private_attr in self.__private_attributes__.items(): - default = private_attr.get_default() - if default is not PydanticUndefined: - pydantic_private[name] = default - object_setattr(self, '__pydantic_private__', pydantic_private) + pydantic_private = {} + for name, private_attr in self.__private_attributes__.items(): + default = private_attr.get_default() + if default is not PydanticUndefined: + pydantic_private[name] = default + object_setattr(self, '__pydantic_private__', pydantic_private) def get_model_post_init(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> Callable[..., Any] | None: @@ -356,7 +272,7 @@ def get_model_post_init(namespace: dict[str, Any], bases: tuple[type[Any], ...]) if 'model_post_init' in namespace: return namespace['model_post_init'] - BaseModel = import_cached_base_model() + from ..main import BaseModel model_post_init = get_attribute_from_bases(bases, 'model_post_init') if model_post_init is not BaseModel.model_post_init: @@ -389,11 +305,7 @@ def inspect_namespace( # noqa C901 - If a field does not have a type annotation. - If a field on base class was overridden by a non-annotated attribute. 
""" - from ..fields import ModelPrivateAttr, PrivateAttr - - FieldInfo = import_cached_field_info() - - all_ignored_types = ignored_types + default_ignored_types() + all_ignored_types = ignored_types + IGNORED_TYPES private_attributes: dict[str, ModelPrivateAttr] = {} raw_annotations = namespace.get('__annotations__', {}) @@ -403,12 +315,11 @@ def inspect_namespace( # noqa C901 ignored_names: set[str] = set() for var_name, value in list(namespace.items()): - if var_name == 'model_config' or var_name == '__pydantic_extra__': + if var_name == 'model_config': continue elif ( isinstance(value, type) and value.__module__ == namespace['__module__'] - and '__qualname__' in namespace and value.__qualname__.startswith(namespace['__qualname__']) ): # `value` is a nested type defined in this namespace; don't error @@ -439,8 +350,8 @@ def inspect_namespace( # noqa C901 elif var_name.startswith('__'): continue elif is_valid_privateattr_name(var_name): - if var_name not in raw_annotations or not is_classvar_annotation(raw_annotations[var_name]): - private_attributes[var_name] = cast(ModelPrivateAttr, PrivateAttr(default=value)) + if var_name not in raw_annotations or not is_classvar(raw_annotations[var_name]): + private_attributes[var_name] = PrivateAttr(default=value) del namespace[var_name] elif var_name in base_class_vars: continue @@ -457,8 +368,8 @@ def inspect_namespace( # noqa C901 ) else: raise PydanticUserError( - f'A non-annotated attribute was detected: `{var_name} = {value!r}`. All model fields require a ' - f'type annotation; if `{var_name}` is not meant to be a field, you may be able to resolve this ' + f"A non-annotated attribute was detected: `{var_name} = {value!r}`. All model fields require a " + f"type annotation; if `{var_name}` is not meant to be a field, you may be able to resolve this " f"error by annotating it as a `ClassVar` or updating `model_config['ignored_types']`.", code='model-field-missing-annotation', ) @@ -468,82 +379,45 @@ def inspect_namespace( # noqa C901 is_valid_privateattr_name(ann_name) and ann_name not in private_attributes and ann_name not in ignored_names - # This condition can be a false negative when `ann_type` is stringified, - # but it is handled in most cases in `set_model_fields`: - and not is_classvar_annotation(ann_type) + and not is_classvar(ann_type) and ann_type not in all_ignored_types and getattr(ann_type, '__module__', None) != 'functools' ): - if isinstance(ann_type, str): - # Walking up the frames to get the module namespace where the model is defined - # (as the model class wasn't created yet, we unfortunately can't use `cls.__module__`): - frame = sys._getframe(2) - if frame is not None: - try: - ann_type = eval_type_backport( - _make_forward_ref(ann_type, is_argument=False, is_class=True), - globalns=frame.f_globals, - localns=frame.f_locals, - ) - except (NameError, TypeError): - pass - - if typing_objects.is_annotated(get_origin(ann_type)): - _, *metadata = get_args(ann_type) - private_attr = next((v for v in metadata if isinstance(v, ModelPrivateAttr)), None) - if private_attr is not None: - private_attributes[ann_name] = private_attr - continue private_attributes[ann_name] = PrivateAttr() return private_attributes -def set_default_hash_func(cls: type[BaseModel], bases: tuple[type[Any], ...]) -> None: +def set_default_hash_func(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> None: + if '__hash__' in namespace: + return + base_hash_func = get_attribute_from_bases(bases, '__hash__') - new_hash_func = make_hash_func(cls) - if 
base_hash_func in {None, object.__hash__} or getattr(base_hash_func, '__code__', None) == new_hash_func.__code__: - # If `__hash__` is some default, we generate a hash function. - # It will be `None` if not overridden from BaseModel. - # It may be `object.__hash__` if there is another + if base_hash_func in {None, object.__hash__}: + # If `__hash__` is None _or_ `object.__hash__`, we generate a hash function. + # It will be `None` if not overridden from BaseModel, but may be `object.__hash__` if there is another # parent class earlier in the bases which doesn't override `__hash__` (e.g. `typing.Generic`). - # It may be a value set by `set_default_hash_func` if `cls` is a subclass of another frozen model. - # In the last case we still need a new hash function to account for new `model_fields`. - cls.__hash__ = new_hash_func + def hash_func(self: Any) -> int: + return hash(self.__class__) + hash(tuple(self.__dict__.values())) - -def make_hash_func(cls: type[BaseModel]) -> Any: - getter = operator.itemgetter(*cls.__pydantic_fields__.keys()) if cls.__pydantic_fields__ else lambda _: 0 - - def hash_func(self: Any) -> int: - try: - return hash(getter(self.__dict__)) - except KeyError: - # In rare cases (such as when using the deprecated copy method), the __dict__ may not contain - # all model fields, which is how we can get here. - # getter(self.__dict__) is much faster than any 'safe' method that accounts for missing keys, - # and wrapping it in a `try` doesn't slow things down much in the common case. - return hash(getter(SafeGetItemProxy(self.__dict__))) - - return hash_func + namespace['__hash__'] = hash_func def set_model_fields( - cls: type[BaseModel], - config_wrapper: ConfigWrapper, - ns_resolver: NsResolver | None, + cls: type[BaseModel], bases: tuple[type[Any], ...], config_wrapper: ConfigWrapper, types_namespace: dict[str, Any] ) -> None: - """Collect and set `cls.__pydantic_fields__` and `cls.__class_vars__`. + """Collect and set `cls.model_fields` and `cls.__class_vars__`. Args: cls: BaseModel or dataclass. + bases: Parents of the class, generally `cls.__bases__`. config_wrapper: The config wrapper instance. - ns_resolver: Namespace resolver to use when getting model annotations. + types_namespace: Optional extra namespace to look for types in. """ typevars_map = get_model_typevars_map(cls) - fields, class_vars = collect_model_fields(cls, config_wrapper, ns_resolver, typevars_map=typevars_map) + fields, class_vars = collect_model_fields(cls, bases, config_wrapper, types_namespace, typevars_map=typevars_map) - cls.__pydantic_fields__ = fields + cls.model_fields = fields cls.__class_vars__.update(class_vars) for k in class_vars: @@ -561,11 +435,11 @@ def set_model_fields( def complete_model_class( cls: type[BaseModel], + cls_name: str, config_wrapper: ConfigWrapper, *, raise_errors: bool = True, - ns_resolver: NsResolver | None = None, - create_model_module: str | None = None, + types_namespace: dict[str, Any] | None, ) -> bool: """Finish building a model class. @@ -574,10 +448,10 @@ def complete_model_class( Args: cls: BaseModel or dataclass. + cls_name: The model or dataclass name. config_wrapper: The config wrapper instance. raise_errors: Whether to raise errors. - ns_resolver: The namespace resolver instance to use during schema building. - create_model_module: The module of the class to be created, if created by `create_model`. + types_namespace: Optional extra namespace to look for types in. Returns: `True` if the model is successfully completed, else `False`. 
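Both sides of the `_model_construction.py` hunks above route `complete_model_class()` through `set_model_mocks()` when a schema cannot be built yet, so the class gets mock validator/serializer attributes instead of failing at definition time. The following standalone sketch (not part of the diff; `Box` and `Thing` are illustrative names) shows the user-facing behaviour those internals implement on either version: an unresolved forward reference leaves the mocks in place until `model_rebuild()` succeeds.

```python
# Standalone sketch, not part of the diff. `Box` and `Thing` are illustrative names.
from pydantic import BaseModel, PydanticUserError


class Box(BaseModel):
    content: 'Thing'  # 'Thing' is not defined yet, so schema building is deferred


try:
    # __pydantic_validator__ is still a mock here; using it triggers a rebuild
    # attempt, and since 'Thing' is still missing, a PydanticUserError is raised.
    Box(content={'name': 'apple'})
except PydanticUserError as exc:
    print(exc.code)
    #> class-not-fully-defined


class Thing(BaseModel):
    name: str


Box.model_rebuild()  # resolves 'Thing' and installs the real validator/serializer
print(Box(content={'name': 'apple'}).content.name)
#> apple
```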
@@ -589,151 +463,132 @@ def complete_model_class( typevars_map = get_model_typevars_map(cls) gen_schema = GenerateSchema( config_wrapper, - ns_resolver, + types_namespace, typevars_map, ) + handler = CallbackGetCoreSchemaHandler( + partial(gen_schema.generate_schema, from_dunder_get_core_schema=False), + gen_schema, + ref_mode='unpack', + ) + + if config_wrapper.defer_build: + set_model_mocks(cls, cls_name) + return False + try: - schema = gen_schema.generate_schema(cls) + schema = cls.__get_pydantic_core_schema__(cls, handler) except PydanticUndefinedAnnotation as e: if raise_errors: raise - set_model_mocks(cls, f'`{e.name}`') + set_model_mocks(cls, cls_name, f'`{e.name}`') return False - core_config = config_wrapper.core_config(title=cls.__name__) + core_config = config_wrapper.core_config(cls) - try: - schema = gen_schema.clean_schema(schema) - except InvalidSchemaError: - set_model_mocks(cls) + schema = gen_schema.collect_definitions(schema) + + schema = apply_discriminators(simplify_schema_references(schema)) + if collect_invalid_schemas(schema): + set_model_mocks(cls, cls_name) return False - # This needs to happen *after* model schema generation, as the return type - # of the properties are evaluated and the `ComputedFieldInfo` are recreated: - cls.__pydantic_computed_fields__ = {k: v.info for k, v in cls.__pydantic_decorators__.computed_fields.items()} - - set_deprecated_descriptors(cls) - - cls.__pydantic_core_schema__ = schema - - cls.__pydantic_validator__ = create_schema_validator( - schema, - cls, - create_model_module or cls.__module__, - cls.__qualname__, - 'create_model' if create_model_module else 'BaseModel', - core_config, - config_wrapper.plugin_settings, - ) + # debug(schema) + cls.__pydantic_core_schema__ = schema = validate_core_schema(schema) + cls.__pydantic_validator__ = create_schema_validator(schema, core_config, config_wrapper.plugin_settings) cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config) cls.__pydantic_complete__ = True # set __signature__ attr only for model class, but not for its instances - # (because instances can define `__call__`, and `inspect.signature` shouldn't - # use the `__signature__` attribute and instead generate from `__call__`). - cls.__signature__ = LazyClassAttribute( - '__signature__', - partial( - generate_pydantic_signature, - init=cls.__init__, - fields=cls.__pydantic_fields__, - validate_by_name=config_wrapper.validate_by_name, - extra=config_wrapper.extra, - ), + cls.__signature__ = ClassAttribute( + '__signature__', generate_model_signature(cls.__init__, cls.model_fields, config_wrapper) ) return True -def set_deprecated_descriptors(cls: type[BaseModel]) -> None: - """Set data descriptors on the class for deprecated fields.""" - for field, field_info in cls.__pydantic_fields__.items(): - if (msg := field_info.deprecation_message) is not None: - desc = _DeprecatedFieldDescriptor(msg) - desc.__set_name__(cls, field) - setattr(cls, field, desc) +def generate_model_signature( + init: Callable[..., None], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper +) -> Signature: + """Generate signature for model based on its fields. 
- for field, computed_field_info in cls.__pydantic_computed_fields__.items(): - if ( - (msg := computed_field_info.deprecation_message) is not None - # Avoid having two warnings emitted: - and not hasattr(unwrap_wrapped_function(computed_field_info.wrapped_property), '__deprecated__') - ): - desc = _DeprecatedFieldDescriptor(msg, computed_field_info.wrapped_property) - desc.__set_name__(cls, field) - setattr(cls, field, desc) + Args: + init: The class init. + fields: The model fields. + config_wrapper: The config wrapper instance. - -class _DeprecatedFieldDescriptor: - """Read-only data descriptor used to emit a runtime deprecation warning before accessing a deprecated field. - - Attributes: - msg: The deprecation message to be emitted. - wrapped_property: The property instance if the deprecated field is a computed field, or `None`. - field_name: The name of the field being deprecated. + Returns: + The model signature. """ + from inspect import Parameter, Signature, signature + from itertools import islice - field_name: str + present_params = signature(init).parameters.values() + merged_params: dict[str, Parameter] = {} + var_kw = None + use_var_kw = False - def __init__(self, msg: str, wrapped_property: property | None = None) -> None: - self.msg = msg - self.wrapped_property = wrapped_property + for param in islice(present_params, 1, None): # skip self arg + # inspect does "clever" things to show annotations as strings because we have + # `from __future__ import annotations` in main, we don't want that + if param.annotation == 'Any': + param = param.replace(annotation=Any) + if param.kind is param.VAR_KEYWORD: + var_kw = param + continue + merged_params[param.name] = param - def __set_name__(self, cls: type[BaseModel], name: str) -> None: - self.field_name = name + if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through + allow_names = config_wrapper.populate_by_name + for field_name, field in fields.items(): + # when alias is a str it should be used for signature generation + if isinstance(field.alias, str): + param_name = field.alias + else: + param_name = field_name - def __get__(self, obj: BaseModel | None, obj_type: type[BaseModel] | None = None) -> Any: - if obj is None: - if self.wrapped_property is not None: - return self.wrapped_property.__get__(None, obj_type) - raise AttributeError(self.field_name) + if field_name in merged_params or param_name in merged_params: + continue - warnings.warn(self.msg, builtins.DeprecationWarning, stacklevel=2) + if not is_valid_identifier(param_name): + if allow_names and is_valid_identifier(field_name): + param_name = field_name + else: + use_var_kw = True + continue - if self.wrapped_property is not None: - return self.wrapped_property.__get__(obj, obj_type) - return obj.__dict__[self.field_name] + kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)} + merged_params[param_name] = Parameter( + param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs + ) - # Defined to make it a data descriptor and take precedence over the instance's dictionary. - # Note that it will not be called when setting a value on a model instance - # as `BaseModel.__setattr__` is defined and takes priority. - def __set__(self, obj: Any, value: Any) -> NoReturn: - raise AttributeError(self.field_name) + if config_wrapper.extra == 'allow': + use_var_kw = True - -class _PydanticWeakRef: - """Wrapper for `weakref.ref` that enables `pickle` serialization. 
- - Cloudpickle fails to serialize `weakref.ref` objects due to an arcane error related - to abstract base classes (`abc.ABC`). This class works around the issue by wrapping - `weakref.ref` instead of subclassing it. - - See https://github.com/pydantic/pydantic/issues/6763 for context. - - Semantics: - - If not pickled, behaves the same as a `weakref.ref`. - - If pickled along with the referenced object, the same `weakref.ref` behavior - will be maintained between them after unpickling. - - If pickled without the referenced object, after unpickling the underlying - reference will be cleared (`__call__` will always return `None`). - """ - - def __init__(self, obj: Any): - if obj is None: - # The object will be `None` upon deserialization if the serialized weakref - # had lost its underlying object. - self._wr = None + if var_kw and use_var_kw: + # Make sure the parameter for extra kwargs + # does not have the same name as a field + default_model_signature = [ + ('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD), + ('data', Parameter.VAR_KEYWORD), + ] + if [(p.name, p.kind) for p in present_params] == default_model_signature: + # if this is the standard model signature, use extra_data as the extra args name + var_kw_name = 'extra_data' else: - self._wr = weakref.ref(obj) + # else start from var_kw + var_kw_name = var_kw.name - def __call__(self) -> Any: - if self._wr is None: - return None - else: - return self._wr() + # generate a name that's definitely unique + while var_kw_name in fields: + var_kw_name += '_' + merged_params[var_kw_name] = var_kw.replace(name=var_kw_name) - def __reduce__(self) -> tuple[Callable, tuple[weakref.ReferenceType | None]]: - return _PydanticWeakRef, (self(),) + return Signature(parameters=list(merged_params.values()), return_annotation=None) + + +class _PydanticWeakRef(weakref.ReferenceType): + pass def build_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None: @@ -770,23 +625,3 @@ def unpack_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | N else: result[k] = v return result - - -@cache -def default_ignored_types() -> tuple[type[Any], ...]: - from ..fields import ComputedFieldInfo - - ignored_types = [ - FunctionType, - property, - classmethod, - staticmethod, - PydanticDescriptorProxy, - ComputedFieldInfo, - TypeAliasType, # from `typing_extensions` - ] - - if sys.version_info >= (3, 12): - ignored_types.append(typing.TypeAliasType) - - return tuple(ignored_types) diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_namespace_utils.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_namespace_utils.py deleted file mode 100644 index 781dfa2..0000000 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_namespace_utils.py +++ /dev/null @@ -1,293 +0,0 @@ -from __future__ import annotations - -import sys -from collections.abc import Generator, Iterator, Mapping -from contextlib import contextmanager -from functools import cached_property -from typing import Any, Callable, NamedTuple, TypeVar - -from typing_extensions import ParamSpec, TypeAlias, TypeAliasType, TypeVarTuple - -GlobalsNamespace: TypeAlias = 'dict[str, Any]' -"""A global namespace. - -In most cases, this is a reference to the `__dict__` attribute of a module. -This namespace type is expected as the `globals` argument during annotations evaluation. -""" - -MappingNamespace: TypeAlias = Mapping[str, Any] -"""Any kind of namespace. - -In most cases, this is a local namespace (e.g. 
the `__dict__` attribute of a class, -the [`f_locals`][frame.f_locals] attribute of a frame object, when dealing with types -defined inside functions). -This namespace type is expected as the `locals` argument during annotations evaluation. -""" - -_TypeVarLike: TypeAlias = 'TypeVar | ParamSpec | TypeVarTuple' - - -class NamespacesTuple(NamedTuple): - """A tuple of globals and locals to be used during annotations evaluation. - - This datastructure is defined as a named tuple so that it can easily be unpacked: - - ```python {lint="skip" test="skip"} - def eval_type(typ: type[Any], ns: NamespacesTuple) -> None: - return eval(typ, *ns) - ``` - """ - - globals: GlobalsNamespace - """The namespace to be used as the `globals` argument during annotations evaluation.""" - - locals: MappingNamespace - """The namespace to be used as the `locals` argument during annotations evaluation.""" - - -def get_module_ns_of(obj: Any) -> dict[str, Any]: - """Get the namespace of the module where the object is defined. - - Caution: this function does not return a copy of the module namespace, so the result - should not be mutated. The burden of enforcing this is on the caller. - """ - module_name = getattr(obj, '__module__', None) - if module_name: - try: - return sys.modules[module_name].__dict__ - except KeyError: - # happens occasionally, see https://github.com/pydantic/pydantic/issues/2363 - return {} - return {} - - -# Note that this class is almost identical to `collections.ChainMap`, but need to enforce -# immutable mappings here: -class LazyLocalNamespace(Mapping[str, Any]): - """A lazily evaluated mapping, to be used as the `locals` argument during annotations evaluation. - - While the [`eval`][eval] function expects a mapping as the `locals` argument, it only - performs `__getitem__` calls. The [`Mapping`][collections.abc.Mapping] abstract base class - is fully implemented only for type checking purposes. - - Args: - *namespaces: The namespaces to consider, in ascending order of priority. - - Example: - ```python {lint="skip" test="skip"} - ns = LazyLocalNamespace({'a': 1, 'b': 2}, {'a': 3}) - ns['a'] - #> 3 - ns['b'] - #> 2 - ``` - """ - - def __init__(self, *namespaces: MappingNamespace) -> None: - self._namespaces = namespaces - - @cached_property - def data(self) -> dict[str, Any]: - return {k: v for ns in self._namespaces for k, v in ns.items()} - - def __len__(self) -> int: - return len(self.data) - - def __getitem__(self, key: str) -> Any: - return self.data[key] - - def __contains__(self, key: object) -> bool: - return key in self.data - - def __iter__(self) -> Iterator[str]: - return iter(self.data) - - -def ns_for_function(obj: Callable[..., Any], parent_namespace: MappingNamespace | None = None) -> NamespacesTuple: - """Return the global and local namespaces to be used when evaluating annotations for the provided function. - - The global namespace will be the `__dict__` attribute of the module the function was defined in. - The local namespace will contain the `__type_params__` introduced by PEP 695. - - Args: - obj: The object to use when building namespaces. - parent_namespace: Optional namespace to be added with the lowest priority in the local namespace. - If the passed function is a method, the `parent_namespace` will be the namespace of the class - the method is defined in. Thus, we also fetch type `__type_params__` from there (i.e. the - class-scoped type variables). 
- """ - locals_list: list[MappingNamespace] = [] - if parent_namespace is not None: - locals_list.append(parent_namespace) - - # Get the `__type_params__` attribute introduced by PEP 695. - # Note that the `typing._eval_type` function expects type params to be - # passed as a separate argument. However, internally, `_eval_type` calls - # `ForwardRef._evaluate` which will merge type params with the localns, - # essentially mimicking what we do here. - type_params: tuple[_TypeVarLike, ...] = getattr(obj, '__type_params__', ()) - if parent_namespace is not None: - # We also fetch type params from the parent namespace. If present, it probably - # means the function was defined in a class. This is to support the following: - # https://github.com/python/cpython/issues/124089. - type_params += parent_namespace.get('__type_params__', ()) - - locals_list.append({t.__name__: t for t in type_params}) - - # What about short-cirtuiting to `obj.__globals__`? - globalns = get_module_ns_of(obj) - - return NamespacesTuple(globalns, LazyLocalNamespace(*locals_list)) - - -class NsResolver: - """A class responsible for the namespaces resolving logic for annotations evaluation. - - This class handles the namespace logic when evaluating annotations mainly for class objects. - - It holds a stack of classes that are being inspected during the core schema building, - and the `types_namespace` property exposes the globals and locals to be used for - type annotation evaluation. Additionally -- if no class is present in the stack -- a - fallback globals and locals can be provided using the `namespaces_tuple` argument - (this is useful when generating a schema for a simple annotation, e.g. when using - `TypeAdapter`). - - The namespace creation logic is unfortunately flawed in some cases, for backwards - compatibility reasons and to better support valid edge cases. See the description - for the `parent_namespace` argument and the example for more details. - - Args: - namespaces_tuple: The default globals and locals to use if no class is present - on the stack. This can be useful when using the `GenerateSchema` class - with `TypeAdapter`, where the "type" being analyzed is a simple annotation. - parent_namespace: An optional parent namespace that will be added to the locals - with the lowest priority. For a given class defined in a function, the locals - of this function are usually used as the parent namespace: - - ```python {lint="skip" test="skip"} - from pydantic import BaseModel - - def func() -> None: - SomeType = int - - class Model(BaseModel): - f: 'SomeType' - - # when collecting fields, an namespace resolver instance will be created - # this way: - # ns_resolver = NsResolver(parent_namespace={'SomeType': SomeType}) - ``` - - For backwards compatibility reasons and to support valid edge cases, this parent - namespace will be used for *every* type being pushed to the stack. In the future, - we might want to be smarter by only doing so when the type being pushed is defined - in the same module as the parent namespace. 
- - Example: - ```python {lint="skip" test="skip"} - ns_resolver = NsResolver( - parent_namespace={'fallback': 1}, - ) - - class Sub: - m: 'Model' - - class Model: - some_local = 1 - sub: Sub - - ns_resolver = NsResolver() - - # This is roughly what happens when we build a core schema for `Model`: - with ns_resolver.push(Model): - ns_resolver.types_namespace - #> NamespacesTuple({'Sub': Sub}, {'Model': Model, 'some_local': 1}) - # First thing to notice here, the model being pushed is added to the locals. - # Because `NsResolver` is being used during the model definition, it is not - # yet added to the globals. This is useful when resolving self-referencing annotations. - - with ns_resolver.push(Sub): - ns_resolver.types_namespace - #> NamespacesTuple({'Sub': Sub}, {'Sub': Sub, 'Model': Model}) - # Second thing to notice: `Sub` is present in both the globals and locals. - # This is not an issue, just that as described above, the model being pushed - # is added to the locals, but it happens to be present in the globals as well - # because it is already defined. - # Third thing to notice: `Model` is also added in locals. This is a backwards - # compatibility workaround that allows for `Sub` to be able to resolve `'Model'` - # correctly (as otherwise models would have to be rebuilt even though this - # doesn't look necessary). - ``` - """ - - def __init__( - self, - namespaces_tuple: NamespacesTuple | None = None, - parent_namespace: MappingNamespace | None = None, - ) -> None: - self._base_ns_tuple = namespaces_tuple or NamespacesTuple({}, {}) - self._parent_ns = parent_namespace - self._types_stack: list[type[Any] | TypeAliasType] = [] - - @cached_property - def types_namespace(self) -> NamespacesTuple: - """The current global and local namespaces to be used for annotations evaluation.""" - if not self._types_stack: - # TODO: should we merge the parent namespace here? - # This is relevant for TypeAdapter, where there are no types on the stack, and we might - # need access to the parent_ns. Right now, we sidestep this in `type_adapter.py` by passing - # locals to both parent_ns and the base_ns_tuple, but this is a bit hacky. - # we might consider something like: - # if self._parent_ns is not None: - # # Hacky workarounds, see class docstring: - # # An optional parent namespace that will be added to the locals with the lowest priority - # locals_list: list[MappingNamespace] = [self._parent_ns, self._base_ns_tuple.locals] - # return NamespacesTuple(self._base_ns_tuple.globals, LazyLocalNamespace(*locals_list)) - return self._base_ns_tuple - - typ = self._types_stack[-1] - - globalns = get_module_ns_of(typ) - - locals_list: list[MappingNamespace] = [] - # Hacky workarounds, see class docstring: - # An optional parent namespace that will be added to the locals with the lowest priority - if self._parent_ns is not None: - locals_list.append(self._parent_ns) - if len(self._types_stack) > 1: - first_type = self._types_stack[0] - locals_list.append({first_type.__name__: first_type}) - - # Adding `__type_params__` *before* `vars(typ)`, as the latter takes priority - # (see https://github.com/python/cpython/pull/120272). - # TODO `typ.__type_params__` when we drop support for Python 3.11: - type_params: tuple[_TypeVarLike, ...] = getattr(typ, '__type_params__', ()) - if type_params: - # Adding `__type_params__` is mostly useful for generic classes defined using - # PEP 695 syntax *and* using forward annotations (see the example in - # https://github.com/python/cpython/issues/114053). 
For TypeAliasType instances, - # it is way less common, but still required if using a string annotation in the alias - # value, e.g. `type A[T] = 'T'` (which is not necessary in most cases). - locals_list.append({t.__name__: t for t in type_params}) - - # TypeAliasType instances don't have a `__dict__` attribute, so the check - # is necessary: - if hasattr(typ, '__dict__'): - locals_list.append(vars(typ)) - - # The `len(self._types_stack) > 1` check above prevents this from being added twice: - locals_list.append({typ.__name__: typ}) - - return NamespacesTuple(globalns, LazyLocalNamespace(*locals_list)) - - @contextmanager - def push(self, typ: type[Any] | TypeAliasType, /) -> Generator[None]: - """Push a type to the stack.""" - self._types_stack.append(typ) - # Reset the cached property: - self.__dict__.pop('types_namespace', None) - try: - yield - finally: - self._types_stack.pop() - self.__dict__.pop('types_namespace', None) diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_repr.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_repr.py index bf3cae5..6250722 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_repr.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_repr.py @@ -1,5 +1,4 @@ """Tools to provide pretty/human-readable display of objects.""" - from __future__ import annotations as _annotations import types @@ -7,8 +6,6 @@ import typing from typing import Any import typing_extensions -from typing_inspection import typing_objects -from typing_inspection.introspection import is_union_origin from . import _typing_extra @@ -35,7 +32,7 @@ class Representation: # (this is not a docstring to avoid adding a docstring to classes which inherit from Representation) # we don't want to use a type annotation here as it can break get_type_hints - __slots__ = () # type: typing.Collection[str] + __slots__ = tuple() # type: typing.Collection[str] def __repr_args__(self) -> ReprArgs: """Returns the attributes to show in __str__, __repr__, and __pretty__ this is generally overridden. @@ -48,17 +45,12 @@ class Representation: if not attrs_names and hasattr(self, '__dict__'): attrs_names = self.__dict__.keys() attrs = ((s, getattr(self, s)) for s in attrs_names) - return [(a, v if v is not self else self.__repr_recursion__(v)) for a, v in attrs if v is not None] + return [(a, v) for a, v in attrs if v is not None] def __repr_name__(self) -> str: """Name of the instance's class, used in __repr__.""" return self.__class__.__name__ - def __repr_recursion__(self, object: Any) -> str: - """Returns the string representation of a recursive object.""" - # This is copied over from the stdlib `pprint` module: - return f'' - def __repr_str__(self, join_str: str) -> str: return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__()) @@ -95,30 +87,25 @@ def display_as_type(obj: Any) -> str: Takes some logic from `typing._type_repr`. """ - if isinstance(obj, (types.FunctionType, types.BuiltinFunctionType)): + if isinstance(obj, types.FunctionType): return obj.__name__ elif obj is ...: return '...' 
elif isinstance(obj, Representation): return repr(obj) - elif isinstance(obj, typing.ForwardRef) or typing_objects.is_typealiastype(obj): - return str(obj) if not isinstance(obj, (_typing_extra.typing_base, _typing_extra.WithArgsTypes, type)): obj = obj.__class__ - if is_union_origin(typing_extensions.get_origin(obj)): + if _typing_extra.origin_is_union(typing_extensions.get_origin(obj)): args = ', '.join(map(display_as_type, typing_extensions.get_args(obj))) return f'Union[{args}]' elif isinstance(obj, _typing_extra.WithArgsTypes): - if typing_objects.is_literal(typing_extensions.get_origin(obj)): + if typing_extensions.get_origin(obj) == typing_extensions.Literal: args = ', '.join(map(repr, typing_extensions.get_args(obj))) else: args = ', '.join(map(display_as_type, typing_extensions.get_args(obj))) - try: - return f'{obj.__qualname__}[{args}]' - except AttributeError: - return str(obj).replace('typing.', '').replace('typing_extensions.', '') # handles TypeAliasType in 3.12 + return f'{obj.__qualname__}[{args}]' elif isinstance(obj, type): return obj.__qualname__ else: diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_schema_gather.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_schema_gather.py deleted file mode 100644 index fc2d806..0000000 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_schema_gather.py +++ /dev/null @@ -1,209 +0,0 @@ -# pyright: reportTypedDictNotRequiredAccess=false, reportGeneralTypeIssues=false, reportArgumentType=false, reportAttributeAccessIssue=false -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import TypedDict - -from pydantic_core.core_schema import ComputedField, CoreSchema, DefinitionReferenceSchema, SerSchema -from typing_extensions import TypeAlias - -AllSchemas: TypeAlias = 'CoreSchema | SerSchema | ComputedField' - - -class GatherResult(TypedDict): - """Schema traversing result.""" - - collected_references: dict[str, DefinitionReferenceSchema | None] - """The collected definition references. - - If a definition reference schema can be inlined, it means that there is - only one in the whole core schema. As such, it is stored as the value. - Otherwise, the value is set to `None`. - """ - - deferred_discriminator_schemas: list[CoreSchema] - """The list of core schemas having the discriminator application deferred.""" - - -class MissingDefinitionError(LookupError): - """A reference was pointing to a non-existing core schema.""" - - def __init__(self, schema_reference: str, /) -> None: - self.schema_reference = schema_reference - - -@dataclass -class GatherContext: - """The current context used during core schema traversing. - - Context instances should only be used during schema traversing. - """ - - definitions: dict[str, CoreSchema] - """The available definitions.""" - - deferred_discriminator_schemas: list[CoreSchema] = field(init=False, default_factory=list) - """The list of core schemas having the discriminator application deferred. - - Internally, these core schemas have a specific key set in the core metadata dict. - """ - - collected_references: dict[str, DefinitionReferenceSchema | None] = field(init=False, default_factory=dict) - """The collected definition references. - - If a definition reference schema can be inlined, it means that there is - only one in the whole core schema. As such, it is stored as the value. - Otherwise, the value is set to `None`. 
- - During schema traversing, definition reference schemas can be added as candidates, or removed - (by setting the value to `None`). - """ - - -def traverse_metadata(schema: AllSchemas, ctx: GatherContext) -> None: - meta = schema.get('metadata') - if meta is not None and 'pydantic_internal_union_discriminator' in meta: - ctx.deferred_discriminator_schemas.append(schema) # pyright: ignore[reportArgumentType] - - -def traverse_definition_ref(def_ref_schema: DefinitionReferenceSchema, ctx: GatherContext) -> None: - schema_ref = def_ref_schema['schema_ref'] - - if schema_ref not in ctx.collected_references: - definition = ctx.definitions.get(schema_ref) - if definition is None: - raise MissingDefinitionError(schema_ref) - - # The `'definition-ref'` schema was only encountered once, make it - # a candidate to be inlined: - ctx.collected_references[schema_ref] = def_ref_schema - traverse_schema(definition, ctx) - if 'serialization' in def_ref_schema: - traverse_schema(def_ref_schema['serialization'], ctx) - traverse_metadata(def_ref_schema, ctx) - else: - # The `'definition-ref'` schema was already encountered, meaning - # the previously encountered schema (and this one) can't be inlined: - ctx.collected_references[schema_ref] = None - - -def traverse_schema(schema: AllSchemas, context: GatherContext) -> None: - # TODO When we drop 3.9, use a match statement to get better type checking and remove - # file-level type ignore. - # (the `'type'` could also be fetched in every `if/elif` statement, but this alters performance). - schema_type = schema['type'] - - if schema_type == 'definition-ref': - traverse_definition_ref(schema, context) - # `traverse_definition_ref` handles the possible serialization and metadata schemas: - return - elif schema_type == 'definitions': - traverse_schema(schema['schema'], context) - for definition in schema['definitions']: - traverse_schema(definition, context) - elif schema_type in {'list', 'set', 'frozenset', 'generator'}: - if 'items_schema' in schema: - traverse_schema(schema['items_schema'], context) - elif schema_type == 'tuple': - if 'items_schema' in schema: - for s in schema['items_schema']: - traverse_schema(s, context) - elif schema_type == 'dict': - if 'keys_schema' in schema: - traverse_schema(schema['keys_schema'], context) - if 'values_schema' in schema: - traverse_schema(schema['values_schema'], context) - elif schema_type == 'union': - for choice in schema['choices']: - if isinstance(choice, tuple): - traverse_schema(choice[0], context) - else: - traverse_schema(choice, context) - elif schema_type == 'tagged-union': - for v in schema['choices'].values(): - traverse_schema(v, context) - elif schema_type == 'chain': - for step in schema['steps']: - traverse_schema(step, context) - elif schema_type == 'lax-or-strict': - traverse_schema(schema['lax_schema'], context) - traverse_schema(schema['strict_schema'], context) - elif schema_type == 'json-or-python': - traverse_schema(schema['json_schema'], context) - traverse_schema(schema['python_schema'], context) - elif schema_type in {'model-fields', 'typed-dict'}: - if 'extras_schema' in schema: - traverse_schema(schema['extras_schema'], context) - if 'computed_fields' in schema: - for s in schema['computed_fields']: - traverse_schema(s, context) - for s in schema['fields'].values(): - traverse_schema(s, context) - elif schema_type == 'dataclass-args': - if 'computed_fields' in schema: - for s in schema['computed_fields']: - traverse_schema(s, context) - for s in schema['fields']: - traverse_schema(s, 
context) - elif schema_type == 'arguments': - for s in schema['arguments_schema']: - traverse_schema(s['schema'], context) - if 'var_args_schema' in schema: - traverse_schema(schema['var_args_schema'], context) - if 'var_kwargs_schema' in schema: - traverse_schema(schema['var_kwargs_schema'], context) - elif schema_type == 'arguments-v3': - for s in schema['arguments_schema']: - traverse_schema(s['schema'], context) - elif schema_type == 'call': - traverse_schema(schema['arguments_schema'], context) - if 'return_schema' in schema: - traverse_schema(schema['return_schema'], context) - elif schema_type == 'computed-field': - traverse_schema(schema['return_schema'], context) - elif schema_type == 'function-before': - if 'schema' in schema: - traverse_schema(schema['schema'], context) - if 'json_schema_input_schema' in schema: - traverse_schema(schema['json_schema_input_schema'], context) - elif schema_type == 'function-plain': - # TODO duplicate schema types for serializers and validators, needs to be deduplicated. - if 'return_schema' in schema: - traverse_schema(schema['return_schema'], context) - if 'json_schema_input_schema' in schema: - traverse_schema(schema['json_schema_input_schema'], context) - elif schema_type == 'function-wrap': - # TODO duplicate schema types for serializers and validators, needs to be deduplicated. - if 'return_schema' in schema: - traverse_schema(schema['return_schema'], context) - if 'schema' in schema: - traverse_schema(schema['schema'], context) - if 'json_schema_input_schema' in schema: - traverse_schema(schema['json_schema_input_schema'], context) - else: - if 'schema' in schema: - traverse_schema(schema['schema'], context) - - if 'serialization' in schema: - traverse_schema(schema['serialization'], context) - traverse_metadata(schema, context) - - -def gather_schemas_for_cleaning(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> GatherResult: - """Traverse the core schema and definitions and return the necessary information for schema cleaning. - - During the core schema traversing, any `'definition-ref'` schema is: - - - Validated: the reference must point to an existing definition. If this is not the case, a - `MissingDefinitionError` exception is raised. - - Stored in the context: the actual reference is stored in the context. Depending on whether - the `'definition-ref'` schema is encountered more that once, the schema itself is also - saved in the context to be inlined (i.e. replaced by the definition it points to). 
- """ - context = GatherContext(definitions) - traverse_schema(schema, context) - - return { - 'collected_references': context.collected_references, - 'deferred_discriminator_schemas': context.deferred_discriminator_schemas, - } diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_schema_generation_shared.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_schema_generation_shared.py index b231a82..1a9aa85 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_schema_generation_shared.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_schema_generation_shared.py @@ -1,10 +1,10 @@ """Types and utility functions used by various other internal tools.""" - from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Literal +from typing import TYPE_CHECKING, Any, Callable from pydantic_core import core_schema +from typing_extensions import Literal from ..annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler @@ -12,7 +12,6 @@ if TYPE_CHECKING: from ..json_schema import GenerateJsonSchema, JsonSchemaValue from ._core_utils import CoreSchemaOrField from ._generate_schema import GenerateSchema - from ._namespace_utils import NamespacesTuple GetJsonSchemaFunction = Callable[[CoreSchemaOrField, GetJsonSchemaHandler], JsonSchemaValue] HandlerOverride = Callable[[CoreSchemaOrField], JsonSchemaValue] @@ -33,8 +32,8 @@ class GenerateJsonSchemaHandler(GetJsonSchemaHandler): self.handler = handler_override or generate_json_schema.generate_inner self.mode = generate_json_schema.mode - def __call__(self, core_schema: CoreSchemaOrField, /) -> JsonSchemaValue: - return self.handler(core_schema) + def __call__(self, __core_schema: CoreSchemaOrField) -> JsonSchemaValue: + return self.handler(__core_schema) def resolve_ref_schema(self, maybe_ref_json_schema: JsonSchemaValue) -> JsonSchemaValue: """Resolves `$ref` in the json schema. @@ -79,21 +78,22 @@ class CallbackGetCoreSchemaHandler(GetCoreSchemaHandler): self._generate_schema = generate_schema self._ref_mode = ref_mode - def __call__(self, source_type: Any, /) -> core_schema.CoreSchema: - schema = self._handler(source_type) + def __call__(self, __source_type: Any) -> core_schema.CoreSchema: + schema = self._handler(__source_type) + ref = schema.get('ref') if self._ref_mode == 'to-def': - ref = schema.get('ref') if ref is not None: - return self._generate_schema.defs.create_definition_reference_schema(schema) + self._generate_schema.defs.definitions[ref] = schema + return core_schema.definition_reference_schema(ref) return schema - else: # ref_mode = 'unpack' + else: # ref_mode = 'unpack return self.resolve_ref_schema(schema) - def _get_types_namespace(self) -> NamespacesTuple: + def _get_types_namespace(self) -> dict[str, Any] | None: return self._generate_schema._types_namespace - def generate_schema(self, source_type: Any, /) -> core_schema.CoreSchema: - return self._generate_schema.generate_schema(source_type) + def generate_schema(self, __source_type: Any) -> core_schema.CoreSchema: + return self._generate_schema.generate_schema(__source_type) @property def field_name(self) -> str | None: @@ -113,13 +113,12 @@ class CallbackGetCoreSchemaHandler(GetCoreSchemaHandler): """ if maybe_ref_schema['type'] == 'definition-ref': ref = maybe_ref_schema['schema_ref'] - definition = self._generate_schema.defs.get_schema_from_ref(ref) - if definition is None: + if ref not in self._generate_schema.defs.definitions: raise LookupError( f'Could not find a ref for {ref}.' 
' Maybe you tried to call resolve_ref_schema from within a recursive model?' ) - return definition + return self._generate_schema.defs.definitions[ref] elif maybe_ref_schema['type'] == 'definitions': return self.resolve_ref_schema(maybe_ref_schema['schema']) return maybe_ref_schema diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_serializers.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_serializers.py deleted file mode 100644 index d059321..0000000 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_serializers.py +++ /dev/null @@ -1,53 +0,0 @@ -from __future__ import annotations - -import collections -import collections.abc -import typing -from typing import Any - -from pydantic_core import PydanticOmit, core_schema - -SEQUENCE_ORIGIN_MAP: dict[Any, Any] = { - typing.Deque: collections.deque, # noqa: UP006 - collections.deque: collections.deque, - list: list, - typing.List: list, # noqa: UP006 - tuple: tuple, - typing.Tuple: tuple, # noqa: UP006 - set: set, - typing.AbstractSet: set, - typing.Set: set, # noqa: UP006 - frozenset: frozenset, - typing.FrozenSet: frozenset, # noqa: UP006 - typing.Sequence: list, - typing.MutableSequence: list, - typing.MutableSet: set, - # this doesn't handle subclasses of these - # parametrized typing.Set creates one of these - collections.abc.MutableSet: set, - collections.abc.Set: frozenset, -} - - -def serialize_sequence_via_list( - v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo -) -> Any: - items: list[Any] = [] - - mapped_origin = SEQUENCE_ORIGIN_MAP.get(type(v), None) - if mapped_origin is None: - # we shouldn't hit this branch, should probably add a serialization error or something - return v - - for index, item in enumerate(v): - try: - v = handler(item, index) - except PydanticOmit: - pass - else: - items.append(v) - - if info.mode_is_json(): - return items - else: - return mapped_origin(items) diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_signature.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_signature.py deleted file mode 100644 index 977e5d2..0000000 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_signature.py +++ /dev/null @@ -1,188 +0,0 @@ -from __future__ import annotations - -import dataclasses -from inspect import Parameter, Signature, signature -from typing import TYPE_CHECKING, Any, Callable - -from pydantic_core import PydanticUndefined - -from ._utils import is_valid_identifier - -if TYPE_CHECKING: - from ..config import ExtraValues - from ..fields import FieldInfo - - -# Copied over from stdlib dataclasses -class _HAS_DEFAULT_FACTORY_CLASS: - def __repr__(self): - return '' - - -_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS() - - -def _field_name_for_signature(field_name: str, field_info: FieldInfo) -> str: - """Extract the correct name to use for the field when generating a signature. - - Assuming the field has a valid alias, this will return the alias. Otherwise, it will return the field name. - First priority is given to the alias, then the validation_alias, then the field name. - - Args: - field_name: The name of the field - field_info: The corresponding FieldInfo object. - - Returns: - The correct name to use when generating a signature. 
- """ - if isinstance(field_info.alias, str) and is_valid_identifier(field_info.alias): - return field_info.alias - if isinstance(field_info.validation_alias, str) and is_valid_identifier(field_info.validation_alias): - return field_info.validation_alias - - return field_name - - -def _process_param_defaults(param: Parameter) -> Parameter: - """Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance. - - Args: - param (Parameter): The parameter - - Returns: - Parameter: The custom processed parameter - """ - from ..fields import FieldInfo - - param_default = param.default - if isinstance(param_default, FieldInfo): - annotation = param.annotation - # Replace the annotation if appropriate - # inspect does "clever" things to show annotations as strings because we have - # `from __future__ import annotations` in main, we don't want that - if annotation == 'Any': - annotation = Any - - # Replace the field default - default = param_default.default - if default is PydanticUndefined: - if param_default.default_factory is PydanticUndefined: - default = Signature.empty - else: - # this is used by dataclasses to indicate a factory exists: - default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore - return param.replace( - annotation=annotation, name=_field_name_for_signature(param.name, param_default), default=default - ) - return param - - -def _generate_signature_parameters( # noqa: C901 (ignore complexity, could use a refactor) - init: Callable[..., None], - fields: dict[str, FieldInfo], - validate_by_name: bool, - extra: ExtraValues | None, -) -> dict[str, Parameter]: - """Generate a mapping of parameter names to Parameter objects for a pydantic BaseModel or dataclass.""" - from itertools import islice - - present_params = signature(init).parameters.values() - merged_params: dict[str, Parameter] = {} - var_kw = None - use_var_kw = False - - for param in islice(present_params, 1, None): # skip self arg - # inspect does "clever" things to show annotations as strings because we have - # `from __future__ import annotations` in main, we don't want that - if fields.get(param.name): - # exclude params with init=False - if getattr(fields[param.name], 'init', True) is False: - continue - param = param.replace(name=_field_name_for_signature(param.name, fields[param.name])) - if param.annotation == 'Any': - param = param.replace(annotation=Any) - if param.kind is param.VAR_KEYWORD: - var_kw = param - continue - merged_params[param.name] = param - - if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through - allow_names = validate_by_name - for field_name, field in fields.items(): - # when alias is a str it should be used for signature generation - param_name = _field_name_for_signature(field_name, field) - - if field_name in merged_params or param_name in merged_params: - continue - - if not is_valid_identifier(param_name): - if allow_names: - param_name = field_name - else: - use_var_kw = True - continue - - if field.is_required(): - default = Parameter.empty - elif field.default_factory is not None: - # Mimics stdlib dataclasses: - default = _HAS_DEFAULT_FACTORY - else: - default = field.default - merged_params[param_name] = Parameter( - param_name, - Parameter.KEYWORD_ONLY, - annotation=field.rebuild_annotation(), - default=default, - ) - - if extra == 'allow': - use_var_kw = True - - if var_kw and use_var_kw: - # Make sure the parameter for extra kwargs - # does not have the same name as a field - 
default_model_signature = [ - ('self', Parameter.POSITIONAL_ONLY), - ('data', Parameter.VAR_KEYWORD), - ] - if [(p.name, p.kind) for p in present_params] == default_model_signature: - # if this is the standard model signature, use extra_data as the extra args name - var_kw_name = 'extra_data' - else: - # else start from var_kw - var_kw_name = var_kw.name - - # generate a name that's definitely unique - while var_kw_name in fields: - var_kw_name += '_' - merged_params[var_kw_name] = var_kw.replace(name=var_kw_name) - - return merged_params - - -def generate_pydantic_signature( - init: Callable[..., None], - fields: dict[str, FieldInfo], - validate_by_name: bool, - extra: ExtraValues | None, - is_dataclass: bool = False, -) -> Signature: - """Generate signature for a pydantic BaseModel or dataclass. - - Args: - init: The class init. - fields: The model fields. - validate_by_name: The `validate_by_name` value of the config. - extra: The `extra` value of the config. - is_dataclass: Whether the model is a dataclass. - - Returns: - The dataclass/BaseModel subclass signature. - """ - merged_params = _generate_signature_parameters(init, fields, validate_by_name, extra) - - if is_dataclass: - merged_params = {k: _process_param_defaults(v) for k, v in merged_params.items()} - - return Signature(parameters=list(merged_params.values()), return_annotation=None) diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_std_types_schema.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_std_types_schema.py new file mode 100644 index 0000000..2c1cef2 --- /dev/null +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_std_types_schema.py @@ -0,0 +1,713 @@ +"""Logic for generating pydantic-core schemas for standard library types. + +Import of this module is deferred since it contains imports of many standard library modules. +""" +from __future__ import annotations as _annotations + +import collections +import collections.abc +import dataclasses +import decimal +import inspect +import os +import typing +from enum import Enum +from functools import partial +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from typing import Any, Callable, Iterable, TypeVar + +import typing_extensions +from pydantic_core import ( + CoreSchema, + MultiHostUrl, + PydanticCustomError, + PydanticOmit, + Url, + core_schema, +) +from typing_extensions import get_args, get_origin + +from pydantic.errors import PydanticSchemaGenerationError +from pydantic.fields import FieldInfo +from pydantic.types import Strict + +from ..config import ConfigDict +from ..json_schema import JsonSchemaValue, update_json_schema +from . 
import _known_annotated_metadata, _typing_extra, _validators +from ._core_utils import get_type_ref +from ._internal_dataclass import slots_true +from ._schema_generation_shared import GetCoreSchemaHandler, GetJsonSchemaHandler + +if typing.TYPE_CHECKING: + from ._generate_schema import GenerateSchema + + StdSchemaFunction = Callable[[GenerateSchema, type[Any]], core_schema.CoreSchema] + + +@dataclasses.dataclass(**slots_true) +class SchemaTransformer: + get_core_schema: Callable[[Any, GetCoreSchemaHandler], CoreSchema] + get_json_schema: Callable[[CoreSchema, GetJsonSchemaHandler], JsonSchemaValue] + + def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: + return self.get_core_schema(source_type, handler) + + def __get_pydantic_json_schema__(self, schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + return self.get_json_schema(schema, handler) + + +def get_enum_core_schema(enum_type: type[Enum], config: ConfigDict) -> CoreSchema: + cases: list[Any] = list(enum_type.__members__.values()) + + if not cases: + # Use an isinstance check for enums with no cases. + # This won't work with serialization or JSON schema, but that's okay -- the most important + # use case for this is creating typevar bounds for generics that should be restricted to enums. + # This is more consistent than it might seem at first, since you can only subclass enum.Enum + # (or subclasses of enum.Enum) if all parent classes have no cases. + return core_schema.is_instance_schema(enum_type) + + use_enum_values = config.get('use_enum_values', False) + + if len(cases) == 1: + expected = repr(cases[0].value) + else: + expected = ', '.join([repr(case.value) for case in cases[:-1]]) + f' or {cases[-1].value!r}' + + def to_enum(__input_value: Any) -> Enum: + try: + enum_field = enum_type(__input_value) + if use_enum_values: + return enum_field.value + return enum_field + except ValueError: + # The type: ignore on the next line is to ignore the requirement of LiteralString + raise PydanticCustomError('enum', f'Input should be {expected}', {'expected': expected}) # type: ignore + + enum_ref = get_type_ref(enum_type) + description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__) + if description == 'An enumeration.': # This is the default value provided by enum.EnumMeta.__new__; don't use it + description = None + updates = {'title': enum_type.__name__, 'description': description} + updates = {k: v for k, v in updates.items() if v is not None} + + def get_json_schema(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + json_schema = handler(core_schema.literal_schema([x.value for x in cases], ref=enum_ref)) + original_schema = handler.resolve_ref_schema(json_schema) + update_json_schema(original_schema, updates) + return json_schema + + strict_python_schema = core_schema.is_instance_schema(enum_type) + if use_enum_values: + strict_python_schema = core_schema.chain_schema( + [strict_python_schema, core_schema.no_info_plain_validator_function(lambda x: x.value)] + ) + + to_enum_validator = core_schema.no_info_plain_validator_function(to_enum) + if issubclass(enum_type, int): + # this handles `IntEnum`, and also `Foobar(int, Enum)` + updates['type'] = 'integer' + lax = core_schema.chain_schema([core_schema.int_schema(), to_enum_validator]) + # Disallow float from JSON due to strict mode + strict = core_schema.json_or_python_schema( + json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.int_schema()), + 
python_schema=strict_python_schema, + ) + elif issubclass(enum_type, str): + # this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)` + updates['type'] = 'string' + lax = core_schema.chain_schema([core_schema.str_schema(), to_enum_validator]) + strict = core_schema.json_or_python_schema( + json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.str_schema()), + python_schema=strict_python_schema, + ) + elif issubclass(enum_type, float): + updates['type'] = 'numeric' + lax = core_schema.chain_schema([core_schema.float_schema(), to_enum_validator]) + strict = core_schema.json_or_python_schema( + json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.float_schema()), + python_schema=strict_python_schema, + ) + else: + lax = to_enum_validator + strict = core_schema.json_or_python_schema(json_schema=to_enum_validator, python_schema=strict_python_schema) + return core_schema.lax_or_strict_schema( + lax_schema=lax, strict_schema=strict, ref=enum_ref, metadata={'pydantic_js_functions': [get_json_schema]} + ) + + +@dataclasses.dataclass(**slots_true) +class InnerSchemaValidator: + """Use a fixed CoreSchema, avoiding interference from outward annotations.""" + + core_schema: CoreSchema + js_schema: JsonSchemaValue | None = None + js_core_schema: CoreSchema | None = None + js_schema_update: JsonSchemaValue | None = None + + def __get_pydantic_json_schema__(self, _schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + if self.js_schema is not None: + return self.js_schema + js_schema = handler(self.js_core_schema or self.core_schema) + if self.js_schema_update is not None: + js_schema.update(self.js_schema_update) + return js_schema + + def __get_pydantic_core_schema__(self, _source_type: Any, _handler: GetCoreSchemaHandler) -> CoreSchema: + return self.core_schema + + +def decimal_prepare_pydantic_annotations( + source: Any, annotations: Iterable[Any], config: ConfigDict +) -> tuple[Any, list[Any]] | None: + if source is not decimal.Decimal: + return None + + metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations) + + config_allow_inf_nan = config.get('allow_inf_nan') + if config_allow_inf_nan is not None: + metadata.setdefault('allow_inf_nan', config_allow_inf_nan) + + _known_annotated_metadata.check_metadata( + metadata, {*_known_annotated_metadata.FLOAT_CONSTRAINTS, 'max_digits', 'decimal_places'}, decimal.Decimal + ) + return source, [InnerSchemaValidator(core_schema.decimal_schema(**metadata)), *remaining_annotations] + + +def datetime_prepare_pydantic_annotations( + source_type: Any, annotations: Iterable[Any], _config: ConfigDict +) -> tuple[Any, list[Any]] | None: + import datetime + + metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations) + if source_type is datetime.date: + sv = InnerSchemaValidator(core_schema.date_schema(**metadata)) + elif source_type is datetime.datetime: + sv = InnerSchemaValidator(core_schema.datetime_schema(**metadata)) + elif source_type is datetime.time: + sv = InnerSchemaValidator(core_schema.time_schema(**metadata)) + elif source_type is datetime.timedelta: + sv = InnerSchemaValidator(core_schema.timedelta_schema(**metadata)) + else: + return None + # check now that we know the source type is correct + _known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.DATE_TIME_CONSTRAINTS, source_type) + return (source_type, [sv, *remaining_annotations]) + + +def uuid_prepare_pydantic_annotations( + 
source_type: Any, annotations: Iterable[Any], _config: ConfigDict +) -> tuple[Any, list[Any]] | None: + # UUIDs have no constraints - they are fixed length, constructing a UUID instance checks the length + + from uuid import UUID + + if source_type is not UUID: + return None + + return (source_type, [InnerSchemaValidator(core_schema.uuid_schema()), *annotations]) + + +def path_schema_prepare_pydantic_annotations( + source_type: Any, annotations: Iterable[Any], _config: ConfigDict +) -> tuple[Any, list[Any]] | None: + import pathlib + + if source_type not in { + os.PathLike, + pathlib.Path, + pathlib.PurePath, + pathlib.PosixPath, + pathlib.PurePosixPath, + pathlib.PureWindowsPath, + }: + return None + + metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations) + _known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.STR_CONSTRAINTS, source_type) + + construct_path = pathlib.PurePath if source_type is os.PathLike else source_type + + def path_validator(input_value: str) -> os.PathLike[Any]: + try: + return construct_path(input_value) + except TypeError as e: + raise PydanticCustomError('path_type', 'Input is not a valid path') from e + + constrained_str_schema = core_schema.str_schema(**metadata) + + instance_schema = core_schema.json_or_python_schema( + json_schema=core_schema.no_info_after_validator_function(path_validator, constrained_str_schema), + python_schema=core_schema.is_instance_schema(source_type), + ) + + strict: bool | None = None + for annotation in annotations: + if isinstance(annotation, Strict): + strict = annotation.strict + + schema = core_schema.lax_or_strict_schema( + lax_schema=core_schema.union_schema( + [ + instance_schema, + core_schema.no_info_after_validator_function(path_validator, constrained_str_schema), + ], + custom_error_type='path_type', + custom_error_message='Input is not a valid path', + strict=True, + ), + strict_schema=instance_schema, + serialization=core_schema.to_string_ser_schema(), + strict=strict, + ) + + return ( + source_type, + [ + InnerSchemaValidator(schema, js_core_schema=constrained_str_schema, js_schema_update={'format': 'path'}), + *remaining_annotations, + ], + ) + + +def dequeue_validator( + input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, maxlen: None | int +) -> collections.deque[Any]: + if isinstance(input_value, collections.deque): + maxlens = [v for v in (input_value.maxlen, maxlen) if v is not None] + if maxlens: + maxlen = min(maxlens) + return collections.deque(handler(input_value), maxlen=maxlen) + else: + return collections.deque(handler(input_value), maxlen=maxlen) + + +@dataclasses.dataclass(**slots_true) +class SequenceValidator: + mapped_origin: type[Any] + item_source_type: type[Any] + min_length: int | None = None + max_length: int | None = None + strict: bool = False + + def serialize_sequence_via_list( + self, v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo + ) -> Any: + items: list[Any] = [] + for index, item in enumerate(v): + try: + v = handler(item, index) + except PydanticOmit: + pass + else: + items.append(v) + + if info.mode_is_json(): + return items + else: + return self.mapped_origin(items) + + def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: + if self.item_source_type is Any: + items_schema = None + else: + items_schema = handler.generate_schema(self.item_source_type) + + metadata = {'min_length': self.min_length, 'max_length': 
self.max_length, 'strict': self.strict} + + if self.mapped_origin in (list, set, frozenset): + if self.mapped_origin is list: + constrained_schema = core_schema.list_schema(items_schema, **metadata) + elif self.mapped_origin is set: + constrained_schema = core_schema.set_schema(items_schema, **metadata) + else: + assert self.mapped_origin is frozenset # safety check in case we forget to add a case + constrained_schema = core_schema.frozenset_schema(items_schema, **metadata) + + schema = constrained_schema + else: + # safety check in case we forget to add a case + assert self.mapped_origin in (collections.deque, collections.Counter) + + if self.mapped_origin is collections.deque: + # if we have a MaxLen annotation might as well set that as the default maxlen on the deque + # this lets us re-use existing metadata annotations to let users set the maxlen on a dequeue + # that e.g. comes from JSON + coerce_instance_wrap = partial( + core_schema.no_info_wrap_validator_function, + partial(dequeue_validator, maxlen=metadata.get('max_length', None)), + ) + else: + coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin) + + constrained_schema = core_schema.list_schema(items_schema, **metadata) + + check_instance = core_schema.json_or_python_schema( + json_schema=core_schema.list_schema(), + python_schema=core_schema.is_instance_schema(self.mapped_origin), + ) + + serialization = core_schema.wrap_serializer_function_ser_schema( + self.serialize_sequence_via_list, schema=items_schema or core_schema.any_schema(), info_arg=True + ) + + strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)]) + + if metadata.get('strict', False): + schema = strict + else: + lax = coerce_instance_wrap(constrained_schema) + schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict) + schema['serialization'] = serialization + + return schema + + +SEQUENCE_ORIGIN_MAP: dict[Any, Any] = { + typing.Deque: collections.deque, + collections.deque: collections.deque, + list: list, + typing.List: list, + set: set, + typing.AbstractSet: set, + typing.Set: set, + frozenset: frozenset, + typing.FrozenSet: frozenset, + typing.Sequence: list, + typing.MutableSequence: list, + typing.MutableSet: set, + # this doesn't handle subclasses of these + # parametrized typing.Set creates one of these + collections.abc.MutableSet: set, + collections.abc.Set: frozenset, +} + + +def identity(s: CoreSchema) -> CoreSchema: + return s + + +def sequence_like_prepare_pydantic_annotations( + source_type: Any, annotations: Iterable[Any], _config: ConfigDict +) -> tuple[Any, list[Any]] | None: + origin: Any = get_origin(source_type) + + mapped_origin = SEQUENCE_ORIGIN_MAP.get(origin, None) if origin else SEQUENCE_ORIGIN_MAP.get(source_type, None) + if mapped_origin is None: + return None + + args = get_args(source_type) + + if not args: + args = (Any,) + elif len(args) != 1: + raise ValueError('Expected sequence to have exactly 1 generic parameter') + + item_source_type = args[0] + + metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations) + _known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type) + + return (source_type, [SequenceValidator(mapped_origin, item_source_type, **metadata), *remaining_annotations]) + + +MAPPING_ORIGIN_MAP: dict[Any, Any] = { + typing.DefaultDict: collections.defaultdict, + collections.defaultdict: collections.defaultdict, + 
collections.OrderedDict: collections.OrderedDict, + typing_extensions.OrderedDict: collections.OrderedDict, + dict: dict, + typing.Dict: dict, + collections.Counter: collections.Counter, + typing.Counter: collections.Counter, + # this doesn't handle subclasses of these + typing.Mapping: dict, + typing.MutableMapping: dict, + # parametrized typing.{Mutable}Mapping creates one of these + collections.abc.MutableMapping: dict, + collections.abc.Mapping: dict, +} + + +def defaultdict_validator( + input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any] +) -> collections.defaultdict[Any, Any]: + if isinstance(input_value, collections.defaultdict): + default_factory = input_value.default_factory + return collections.defaultdict(default_factory, handler(input_value)) + else: + return collections.defaultdict(default_default_factory, handler(input_value)) + + +def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]: + def infer_default() -> Callable[[], Any]: + allowed_default_types: dict[Any, Any] = { + typing.Tuple: tuple, + tuple: tuple, + collections.abc.Sequence: tuple, + collections.abc.MutableSequence: list, + typing.List: list, + list: list, + typing.Sequence: list, + typing.Set: set, + set: set, + typing.MutableSet: set, + collections.abc.MutableSet: set, + collections.abc.Set: frozenset, + typing.MutableMapping: dict, + typing.Mapping: dict, + collections.abc.Mapping: dict, + collections.abc.MutableMapping: dict, + float: float, + int: int, + str: str, + bool: bool, + } + values_type_origin = get_origin(values_source_type) or values_source_type + instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`' + if isinstance(values_type_origin, TypeVar): + + def type_var_default_factory() -> None: + raise RuntimeError( + 'Generic defaultdict cannot be used without a concrete value type or an' + ' explicit default factory, ' + instructions + ) + + return type_var_default_factory + elif values_type_origin not in allowed_default_types: + # a somewhat subjective set of types that have reasonable default values + allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())]) + raise PydanticSchemaGenerationError( + f'Unable to infer a default factory for keys of type {values_source_type}.' 
+ f' Only {allowed_msg} are supported, other types require an explicit default factory' + ' ' + instructions + ) + return allowed_default_types[values_type_origin] + + # Assume Annotated[..., Field(...)] + if _typing_extra.is_annotated(values_source_type): + field_info = next((v for v in get_args(values_source_type) if isinstance(v, FieldInfo)), None) + else: + field_info = None + if field_info and field_info.default_factory: + default_default_factory = field_info.default_factory + else: + default_default_factory = infer_default() + return default_default_factory + + +@dataclasses.dataclass(**slots_true) +class MappingValidator: + mapped_origin: type[Any] + keys_source_type: type[Any] + values_source_type: type[Any] + min_length: int | None = None + max_length: int | None = None + strict: bool = False + + def serialize_mapping_via_dict(self, v: Any, handler: core_schema.SerializerFunctionWrapHandler) -> Any: + return handler(v) + + def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: + if self.keys_source_type is Any: + keys_schema = None + else: + keys_schema = handler.generate_schema(self.keys_source_type) + if self.values_source_type is Any: + values_schema = None + else: + values_schema = handler.generate_schema(self.values_source_type) + + metadata = {'min_length': self.min_length, 'max_length': self.max_length, 'strict': self.strict} + + if self.mapped_origin is dict: + schema = core_schema.dict_schema(keys_schema, values_schema, **metadata) + else: + constrained_schema = core_schema.dict_schema(keys_schema, values_schema, **metadata) + check_instance = core_schema.json_or_python_schema( + json_schema=core_schema.dict_schema(), + python_schema=core_schema.is_instance_schema(self.mapped_origin), + ) + + if self.mapped_origin is collections.defaultdict: + default_default_factory = get_defaultdict_default_default_factory(self.values_source_type) + coerce_instance_wrap = partial( + core_schema.no_info_wrap_validator_function, + partial(defaultdict_validator, default_default_factory=default_default_factory), + ) + else: + coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin) + + serialization = core_schema.wrap_serializer_function_ser_schema( + self.serialize_mapping_via_dict, + schema=core_schema.dict_schema( + keys_schema or core_schema.any_schema(), values_schema or core_schema.any_schema() + ), + info_arg=False, + ) + + strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)]) + + if metadata.get('strict', False): + schema = strict + else: + lax = coerce_instance_wrap(constrained_schema) + schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict) + schema['serialization'] = serialization + + return schema + + +def mapping_like_prepare_pydantic_annotations( + source_type: Any, annotations: Iterable[Any], _config: ConfigDict +) -> tuple[Any, list[Any]] | None: + origin: Any = get_origin(source_type) + + mapped_origin = MAPPING_ORIGIN_MAP.get(origin, None) if origin else MAPPING_ORIGIN_MAP.get(source_type, None) + if mapped_origin is None: + return None + + args = get_args(source_type) + + if not args: + args = (Any, Any) + elif mapped_origin is collections.Counter: + # a single generic + if len(args) != 1: + raise ValueError('Expected Counter to have exactly 1 generic parameter') + args = (args[0], int) # keys are always an int + elif len(args) != 2: + raise ValueError('Expected mapping to have exactly 2 generic parameters') + + 
keys_source_type, values_source_type = args + + metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations) + _known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type) + + return ( + source_type, + [ + MappingValidator(mapped_origin, keys_source_type, values_source_type, **metadata), + *remaining_annotations, + ], + ) + + +def ip_prepare_pydantic_annotations( + source_type: Any, annotations: Iterable[Any], _config: ConfigDict +) -> tuple[Any, list[Any]] | None: + def make_strict_ip_schema(tp: type[Any]) -> CoreSchema: + return core_schema.json_or_python_schema( + json_schema=core_schema.no_info_after_validator_function(tp, core_schema.str_schema()), + python_schema=core_schema.is_instance_schema(tp), + ) + + if source_type is IPv4Address: + return source_type, [ + SchemaTransformer( + lambda _1, _2: core_schema.lax_or_strict_schema( + lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_address_validator), + strict_schema=make_strict_ip_schema(IPv4Address), + serialization=core_schema.to_string_ser_schema(), + ), + lambda _1, _2: {'type': 'string', 'format': 'ipv4'}, + ), + *annotations, + ] + if source_type is IPv4Network: + return source_type, [ + SchemaTransformer( + lambda _1, _2: core_schema.lax_or_strict_schema( + lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_network_validator), + strict_schema=make_strict_ip_schema(IPv4Network), + serialization=core_schema.to_string_ser_schema(), + ), + lambda _1, _2: {'type': 'string', 'format': 'ipv4network'}, + ), + *annotations, + ] + if source_type is IPv4Interface: + return source_type, [ + SchemaTransformer( + lambda _1, _2: core_schema.lax_or_strict_schema( + lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_interface_validator), + strict_schema=make_strict_ip_schema(IPv4Interface), + serialization=core_schema.to_string_ser_schema(), + ), + lambda _1, _2: {'type': 'string', 'format': 'ipv4interface'}, + ), + *annotations, + ] + + if source_type is IPv6Address: + return source_type, [ + SchemaTransformer( + lambda _1, _2: core_schema.lax_or_strict_schema( + lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_address_validator), + strict_schema=make_strict_ip_schema(IPv6Address), + serialization=core_schema.to_string_ser_schema(), + ), + lambda _1, _2: {'type': 'string', 'format': 'ipv6'}, + ), + *annotations, + ] + if source_type is IPv6Network: + return source_type, [ + SchemaTransformer( + lambda _1, _2: core_schema.lax_or_strict_schema( + lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_network_validator), + strict_schema=make_strict_ip_schema(IPv6Network), + serialization=core_schema.to_string_ser_schema(), + ), + lambda _1, _2: {'type': 'string', 'format': 'ipv6network'}, + ), + *annotations, + ] + if source_type is IPv6Interface: + return source_type, [ + SchemaTransformer( + lambda _1, _2: core_schema.lax_or_strict_schema( + lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_interface_validator), + strict_schema=make_strict_ip_schema(IPv6Interface), + serialization=core_schema.to_string_ser_schema(), + ), + lambda _1, _2: {'type': 'string', 'format': 'ipv6interface'}, + ), + *annotations, + ] + + return None + + +def url_prepare_pydantic_annotations( + source_type: Any, annotations: Iterable[Any], _config: ConfigDict +) -> tuple[Any, list[Any]] | None: + if source_type is Url: + return source_type, [ + 
SchemaTransformer( + lambda _1, _2: core_schema.url_schema(), + lambda cs, handler: handler(cs), + ), + *annotations, + ] + if source_type is MultiHostUrl: + return source_type, [ + SchemaTransformer( + lambda _1, _2: core_schema.multi_host_url_schema(), + lambda cs, handler: handler(cs), + ), + *annotations, + ] + + +PREPARE_METHODS: tuple[Callable[[Any, Iterable[Any], ConfigDict], tuple[Any, list[Any]] | None], ...] = ( + decimal_prepare_pydantic_annotations, + sequence_like_prepare_pydantic_annotations, + datetime_prepare_pydantic_annotations, + uuid_prepare_pydantic_annotations, + path_schema_prepare_pydantic_annotations, + mapping_like_prepare_pydantic_annotations, + ip_prepare_pydantic_annotations, + url_prepare_pydantic_annotations, +) diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_typing_extra.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_typing_extra.py index 4be1a09..e83e03d 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_typing_extra.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_typing_extra.py @@ -1,544 +1,244 @@ -"""Logic for interacting with type annotations, mostly extensions, shims and hacks to wrap Python's typing module.""" +"""Logic for interacting with type annotations, mostly extensions, shims and hacks to wrap python's typing module.""" +from __future__ import annotations as _annotations -from __future__ import annotations - -import collections.abc -import re +import dataclasses import sys import types import typing +from collections.abc import Callable from functools import partial -from typing import TYPE_CHECKING, Any, Callable, cast +from types import GetSetDescriptorType +from typing import TYPE_CHECKING, Any, ForwardRef -import typing_extensions -from typing_extensions import deprecated, get_args, get_origin -from typing_inspection import typing_objects -from typing_inspection.introspection import is_union_origin +from typing_extensions import Annotated, Final, Literal, TypeAliasType, TypeGuard, get_args, get_origin -from pydantic.version import version_short +if TYPE_CHECKING: + from ._dataclasses import StandardDataclass + +try: + from typing import _TypingBase # type: ignore[attr-defined] +except ImportError: + from typing import _Final as _TypingBase # type: ignore[attr-defined] + +typing_base = _TypingBase + + +if sys.version_info < (3, 9): + # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] 
and so on) + TypingGenericAlias = () +else: + from typing import GenericAlias as TypingGenericAlias # type: ignore + + +if sys.version_info < (3, 11): + from typing_extensions import NotRequired, Required +else: + from typing import NotRequired, Required # noqa: F401 + + +if sys.version_info < (3, 10): + + def origin_is_union(tp: type[Any] | None) -> bool: + return tp is typing.Union + + WithArgsTypes = (TypingGenericAlias,) + +else: + + def origin_is_union(tp: type[Any] | None) -> bool: + return tp is typing.Union or tp is types.UnionType + + WithArgsTypes = typing._GenericAlias, types.GenericAlias, types.UnionType # type: ignore[attr-defined] -from ._namespace_utils import GlobalsNamespace, MappingNamespace, NsResolver, get_module_ns_of if sys.version_info < (3, 10): NoneType = type(None) EllipsisType = type(Ellipsis) else: - from types import EllipsisType as EllipsisType from types import NoneType as NoneType -if TYPE_CHECKING: - from pydantic import BaseModel -# As per https://typing-extensions.readthedocs.io/en/latest/#runtime-use-of-types, -# always check for both `typing` and `typing_extensions` variants of a typing construct. -# (this is implemented differently than the suggested approach in the `typing_extensions` -# docs for performance). +LITERAL_TYPES: set[Any] = {Literal} +if hasattr(typing, 'Literal'): + LITERAL_TYPES.add(typing.Literal) # type: ignore + +NONE_TYPES: tuple[Any, ...] = (None, NoneType, *(tp[None] for tp in LITERAL_TYPES)) -_t_annotated = typing.Annotated -_te_annotated = typing_extensions.Annotated +TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type -def is_annotated(tp: Any, /) -> bool: - """Return whether the provided argument is a `Annotated` special form. +def is_none_type(type_: Any) -> bool: + return type_ in NONE_TYPES - ```python {test="skip" lint="skip"} - is_annotated(Annotated[int, ...]) - #> True - ``` + +def is_callable_type(type_: type[Any]) -> bool: + return type_ is Callable or get_origin(type_) is Callable + + +def is_literal_type(type_: type[Any]) -> bool: + return Literal is not None and get_origin(type_) in LITERAL_TYPES + + +def literal_values(type_: type[Any]) -> tuple[Any, ...]: + return get_args(type_) + + +def all_literal_values(type_: type[Any]) -> list[Any]: + """This method is used to retrieve all Literal values as + Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586) + e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]`. """ - origin = get_origin(tp) - return origin is _t_annotated or origin is _te_annotated + if not is_literal_type(type_): + return [type_] + + values = literal_values(type_) + return list(x for value in values for x in all_literal_values(value)) -def annotated_type(tp: Any, /) -> Any | None: - """Return the type of the `Annotated` special form, or `None`.""" - return tp.__origin__ if typing_objects.is_annotated(get_origin(tp)) else None +def is_annotated(ann_type: Any) -> bool: + from ._utils import lenient_issubclass + + origin = get_origin(ann_type) + return origin is not None and lenient_issubclass(origin, Annotated) -def unpack_type(tp: Any, /) -> Any | None: - """Return the type wrapped by the `Unpack` special form, or `None`.""" - return get_args(tp)[0] if typing_objects.is_unpack(get_origin(tp)) else None - - -def is_hashable(tp: Any, /) -> bool: - """Return whether the provided argument is the `Hashable` class. 
- - ```python {test="skip" lint="skip"} - is_hashable(Hashable) - #> True - ``` +def is_namedtuple(type_: type[Any]) -> bool: + """Check if a given class is a named tuple. + It can be either a `typing.NamedTuple` or `collections.namedtuple`. """ - # `get_origin` is documented as normalizing any typing-module aliases to `collections` classes, - # hence the second check: - return tp is collections.abc.Hashable or get_origin(tp) is collections.abc.Hashable + from ._utils import lenient_issubclass + + return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields') -def is_callable(tp: Any, /) -> bool: - """Return whether the provided argument is a `Callable`, parametrized or not. +test_new_type = typing.NewType('test_new_type', str) - ```python {test="skip" lint="skip"} - is_callable(Callable[[int], str]) - #> True - is_callable(typing.Callable) - #> True - is_callable(collections.abc.Callable) - #> True - ``` + +def is_new_type(type_: type[Any]) -> bool: + """Check whether type_ was created using typing.NewType. + + Can't use isinstance because it fails <3.10. """ - # `get_origin` is documented as normalizing any typing-module aliases to `collections` classes, - # hence the second check: - return tp is collections.abc.Callable or get_origin(tp) is collections.abc.Callable + return isinstance(type_, test_new_type.__class__) and hasattr(type_, '__supertype__') # type: ignore[arg-type] -_classvar_re = re.compile(r'((\w+\.)?Annotated\[)?(\w+\.)?ClassVar\[') +def _check_classvar(v: type[Any] | None) -> bool: + if v is None: + return False + + return v.__class__ == typing.ClassVar.__class__ and getattr(v, '_name', None) == 'ClassVar' -def is_classvar_annotation(tp: Any, /) -> bool: - """Return whether the provided argument represents a class variable annotation. - - Although not explicitly stated by the typing specification, `ClassVar` can be used - inside `Annotated` and as such, this function checks for this specific scenario. - - Because this function is used to detect class variables before evaluating forward references - (or because evaluation failed), we also implement a naive regex match implementation. This is - required because class variables are inspected before fields are collected, so we try to be - as accurate as possible. - """ - if typing_objects.is_classvar(tp): +def is_classvar(ann_type: type[Any]) -> bool: + if _check_classvar(ann_type) or _check_classvar(get_origin(ann_type)): return True - origin = get_origin(tp) - - if typing_objects.is_classvar(origin): - return True - - if typing_objects.is_annotated(origin): - annotated_type = tp.__origin__ - if typing_objects.is_classvar(annotated_type) or typing_objects.is_classvar(get_origin(annotated_type)): - return True - - str_ann: str | None = None - if isinstance(tp, typing.ForwardRef): - str_ann = tp.__forward_arg__ - if isinstance(tp, str): - str_ann = tp - - if str_ann is not None and _classvar_re.match(str_ann): - # stdlib dataclasses do something similar, although a bit more advanced - # (see `dataclass._is_type`). 
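To make the string-matching branch concrete: the `_classvar_re` pattern appearing in this hunk treats an annotation string as a class variable when an optional (possibly qualified) `Annotated[` wrapper is followed by a possibly qualified `ClassVar[`. A small self-contained sketch, copying that pattern verbatim; the surrounding loop and the sample strings are illustrative only:

```python
import re

# Pattern copied from the is_classvar_annotation() helper in the hunk above.
_classvar_re = re.compile(r'((\w+\.)?Annotated\[)?(\w+\.)?ClassVar\[')

for ann in ('ClassVar[int]', 'typing.ClassVar[int]', 'Annotated[ClassVar[int], "meta"]', 'list[int]'):
    # match() only needs the prefix to look like a (wrapped) ClassVar annotation
    print(ann, '->', bool(_classvar_re.match(ann)))
# ClassVar[int] -> True
# typing.ClassVar[int] -> True
# Annotated[ClassVar[int], "meta"] -> True
# list[int] -> False
```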
+ # this is an ugly workaround for class vars that contain forward references and are therefore themselves + # forward references, see #3679 + if ann_type.__class__ == typing.ForwardRef and ann_type.__forward_arg__.startswith('ClassVar['): # type: ignore return True return False -_t_final = typing.Final -_te_final = typing_extensions.Final +def _check_finalvar(v: type[Any] | None) -> bool: + """Check if a given type is a `typing.Final` type.""" + if v is None: + return False + + return v.__class__ == Final.__class__ and (sys.version_info < (3, 8) or getattr(v, '_name', None) == 'Final') -# TODO implement `is_finalvar_annotation` as Final can be wrapped with other special forms: -def is_finalvar(tp: Any, /) -> bool: - """Return whether the provided argument is a `Final` special form, parametrized or not. - - ```python {test="skip" lint="skip"} - is_finalvar(Final[int]) - #> True - is_finalvar(Final) - #> True - """ - # Final is not necessarily parametrized: - if tp is _t_final or tp is _te_final: - return True - origin = get_origin(tp) - return origin is _t_final or origin is _te_final +def is_finalvar(ann_type: Any) -> bool: + return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type)) -_NONE_TYPES: tuple[Any, ...] = (None, NoneType, typing.Literal[None], typing_extensions.Literal[None]) +def parent_frame_namespace(*, parent_depth: int = 2) -> dict[str, Any] | None: + """We allow use of items in parent namespace to get around the issue with `get_type_hints` only looking in the + global module namespace. See https://github.com/pydantic/pydantic/issues/2678#issuecomment-1008139014 -> Scope + and suggestion at the end of the next comment by @gvanrossum. + WARNING 1: it matters exactly where this is called. By default, this function will build a namespace from the + parent of where it is called. -def is_none_type(tp: Any, /) -> bool: - """Return whether the argument represents the `None` type as part of an annotation. - - ```python {test="skip" lint="skip"} - is_none_type(None) - #> True - is_none_type(NoneType) - #> True - is_none_type(Literal[None]) - #> True - is_none_type(type[None]) - #> False - """ - return tp in _NONE_TYPES - - -def is_namedtuple(tp: Any, /) -> bool: - """Return whether the provided argument is a named tuple class. - - The class can be created using `typing.NamedTuple` or `collections.namedtuple`. - Parametrized generic classes are *not* assumed to be named tuples. - """ - from ._utils import lenient_issubclass # circ. import - - return lenient_issubclass(tp, tuple) and hasattr(tp, '_fields') - - -# TODO In 2.12, delete this export. It is currently defined only to not break -# pydantic-settings which relies on it: -origin_is_union = is_union_origin - - -def is_generic_alias(tp: Any, /) -> bool: - return isinstance(tp, (types.GenericAlias, typing._GenericAlias)) # pyright: ignore[reportAttributeAccessIssue] - - -# TODO: Ideally, we should avoid relying on the private `typing` constructs: - -if sys.version_info < (3, 10): - WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias) # pyright: ignore[reportAttributeAccessIssue] -else: - WithArgsTypes: tuple[Any, ...] 
= (typing._GenericAlias, types.GenericAlias, types.UnionType) # pyright: ignore[reportAttributeAccessIssue] - - -# Similarly, we shouldn't rely on this `_Final` class, which is even more private than `_GenericAlias`: -typing_base: Any = typing._Final # pyright: ignore[reportAttributeAccessIssue] - - -### Annotation evaluations functions: - - -def parent_frame_namespace(*, parent_depth: int = 2, force: bool = False) -> dict[str, Any] | None: - """Fetch the local namespace of the parent frame where this function is called. - - Using this function is mostly useful to resolve forward annotations pointing to members defined in a local namespace, - such as assignments inside a function. Using the standard library tools, it is currently not possible to resolve - such annotations: - - ```python {lint="skip" test="skip"} - from typing import get_type_hints - - def func() -> None: - Alias = int - - class C: - a: 'Alias' - - # Raises a `NameError: 'Alias' is not defined` - get_type_hints(C) - ``` - - Pydantic uses this function when a Pydantic model is being defined to fetch the parent frame locals. However, - this only allows us to fetch the parent frame namespace and not other parents (e.g. a model defined in a function, - itself defined in another function). Inspecting the next outer frames (using `f_back`) is not reliable enough - (see https://discuss.python.org/t/20659). - - Because this function is mostly used to better resolve forward annotations, nothing is returned if the parent frame's - code object is defined at the module level. In this case, the locals of the frame will be the same as the module - globals where the class is defined (see `_namespace_utils.get_module_ns_of`). However, if you still want to fetch - the module globals (e.g. when rebuilding a model, where the frame where the rebuild call is performed might contain - members that you want to use for forward annotations evaluation), you can use the `force` parameter. - - Args: - parent_depth: The depth at which to get the frame. Defaults to 2, meaning the parent frame where this function - is called will be used. - force: Whether to always return the frame locals, even if the frame's code object is defined at the module level. - - Returns: - The locals of the namespace, or `None` if it was skipped as per the described logic. + WARNING 2: this only looks in the parent namespace, not other parents since (AFAIK) there's no way to collect a + dict of exactly what's in scope. Using `f_back` would work sometimes but would be very wrong and confusing in many + other cases. See https://discuss.python.org/t/is-there-a-way-to-access-parent-nested-namespaces/20659. """ frame = sys._getframe(parent_depth) - - if frame.f_code.co_name.startswith('`, - # and we need to skip this frame as it is irrelevant. - frame = cast(types.FrameType, frame.f_back) # guaranteed to not be `None` - - # note, we don't copy frame.f_locals here (or during the last return call), because we don't expect the namespace to be - # modified down the line if this becomes a problem, we could implement some sort of frozen mapping structure to enforce this. - if force: + # if f_back is None, it's the global module namespace and we don't need to include it here + if frame.f_back is None: + return None + else: return frame.f_locals - # If either of the following conditions are true, the class is defined at the top module level. - # To better understand why we need both of these checks, see - # https://github.com/pydantic/pydantic/pull/10113#discussion_r1714981531. 
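As a rough sketch of the frame-inspection idea that `parent_frame_namespace` relies on (the helper and variable names below are illustrative, not pydantic APIs): reading the locals of the frame two levels up makes a name that only exists inside a function body visible to the code that later needs to resolve a forward reference to it.

```python
from __future__ import annotations

import sys


def capture_parent_locals(parent_depth: int = 2) -> dict | None:
    # Take the locals of the frame `parent_depth` levels up from here
    # (2 == the caller of our caller), mirroring sys._getframe(parent_depth).
    frame = sys._getframe(parent_depth)
    # If that frame has no caller we are effectively at module level, where
    # the locals are just the module globals and nothing extra is gained.
    return None if frame.f_back is None else frame.f_locals


def resolve_alias() -> type:
    # Sees define()'s locals, including the purely local name `Alias`.
    namespace = capture_parent_locals() or {}
    return namespace['Alias']


def define() -> type:
    Alias = int  # the kind of name a string annotation like 'Alias' could point to
    return resolve_alias()


assert define() is int
```

Both versions of the helper in the hunk above add further checks on the parent frame (for example, skipping the case where the parent frame is the module itself), but the underlying mechanism is the same.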
- if frame.f_back is None or frame.f_code.co_name == '': - return None - return frame.f_locals +def add_module_globals(obj: Any, globalns: dict[str, Any] | None = None) -> dict[str, Any]: + module_name = getattr(obj, '__module__', None) + if module_name: + try: + module_globalns = sys.modules[module_name].__dict__ + except KeyError: + # happens occasionally, see https://github.com/pydantic/pydantic/issues/2363 + pass + else: + if globalns: + return {**module_globalns, **globalns} + else: + # copy module globals to make sure it can't be updated later + return module_globalns.copy() + + return globalns or {} -def _type_convert(arg: Any) -> Any: - """Convert `None` to `NoneType` and strings to `ForwardRef` instances. - - This is a backport of the private `typing._type_convert` function. When - evaluating a type, `ForwardRef._evaluate` ends up being called, and is - responsible for making this conversion. However, we still have to apply - it for the first argument passed to our type evaluation functions, similarly - to the `typing.get_type_hints` function. - """ - if arg is None: - return NoneType - if isinstance(arg, str): - # Like `typing.get_type_hints`, assume the arg can be in any context, - # hence the proper `is_argument` and `is_class` args: - return _make_forward_ref(arg, is_argument=False, is_class=True) - return arg +def get_cls_types_namespace(cls: type[Any], parent_namespace: dict[str, Any] | None = None) -> dict[str, Any]: + ns = add_module_globals(cls, parent_namespace) + ns[cls.__name__] = cls + return ns -def get_model_type_hints( - obj: type[BaseModel], - *, - ns_resolver: NsResolver | None = None, -) -> dict[str, tuple[Any, bool]]: - """Collect annotations from a Pydantic model class, including those from parent classes. - - Args: - obj: The Pydantic model to inspect. - ns_resolver: A namespace resolver instance to use. Defaults to an empty instance. - - Returns: - A dictionary mapping annotation names to a two-tuple: the first element is the evaluated - type or the original annotation if a `NameError` occurred, the second element is a boolean - indicating if whether the evaluation succeeded. - """ - hints: dict[str, Any] | dict[str, tuple[Any, bool]] = {} - ns_resolver = ns_resolver or NsResolver() - - for base in reversed(obj.__mro__): - ann: dict[str, Any] | None = base.__dict__.get('__annotations__') - if not ann or isinstance(ann, types.GetSetDescriptorType): - continue - with ns_resolver.push(base): - globalns, localns = ns_resolver.types_namespace - for name, value in ann.items(): - if name.startswith('_'): - # For private attributes, we only need the annotation to detect the `ClassVar` special form. - # For this reason, we still try to evaluate it, but we also catch any possible exception (on - # top of the `NameError`s caught in `try_eval_type`) that could happen so that users are free - # to use any kind of forward annotation for private fields (e.g. circular imports, new typing - # syntax, etc). - try: - hints[name] = try_eval_type(value, globalns, localns) - except Exception: - hints[name] = (value, False) - else: - hints[name] = try_eval_type(value, globalns, localns) - return hints - - -def get_cls_type_hints( - obj: type[Any], - *, - ns_resolver: NsResolver | None = None, -) -> dict[str, Any]: +def get_cls_type_hints_lenient(obj: Any, globalns: dict[str, Any] | None = None) -> dict[str, Any]: """Collect annotations from a class, including those from parent classes. - Args: - obj: The class to inspect. - ns_resolver: A namespace resolver instance to use. 
Defaults to an empty instance. + Unlike `typing.get_type_hints`, this function will not error if a forward reference is not resolvable. """ - hints: dict[str, Any] | dict[str, tuple[Any, bool]] = {} - ns_resolver = ns_resolver or NsResolver() - + hints = {} for base in reversed(obj.__mro__): - ann: dict[str, Any] | None = base.__dict__.get('__annotations__') - if not ann or isinstance(ann, types.GetSetDescriptorType): - continue - with ns_resolver.push(base): - globalns, localns = ns_resolver.types_namespace + ann = base.__dict__.get('__annotations__') + localns = dict(vars(base)) + if ann is not None and ann is not GetSetDescriptorType: for name, value in ann.items(): - hints[name] = eval_type(value, globalns, localns) + hints[name] = eval_type_lenient(value, globalns, localns) return hints -def try_eval_type( - value: Any, - globalns: GlobalsNamespace | None = None, - localns: MappingNamespace | None = None, -) -> tuple[Any, bool]: - """Try evaluating the annotation using the provided namespaces. - - Args: - value: The value to evaluate. If `None`, it will be replaced by `type[None]`. If an instance - of `str`, it will be converted to a `ForwardRef`. - localns: The global namespace to use during annotation evaluation. - globalns: The local namespace to use during annotation evaluation. - - Returns: - A two-tuple containing the possibly evaluated type and a boolean indicating - whether the evaluation succeeded or not. - """ - value = _type_convert(value) +def eval_type_lenient(value: Any, globalns: dict[str, Any] | None, localns: dict[str, Any] | None) -> Any: + """Behaves like typing._eval_type, except it won't raise an error if a forward reference can't be resolved.""" + if value is None: + value = NoneType + elif isinstance(value, str): + value = _make_forward_ref(value, is_argument=False, is_class=True) try: - return eval_type_backport(value, globalns, localns), True + return typing._eval_type(value, globalns, localns) # type: ignore except NameError: - return value, False - - -def eval_type( - value: Any, - globalns: GlobalsNamespace | None = None, - localns: MappingNamespace | None = None, -) -> Any: - """Evaluate the annotation using the provided namespaces. - - Args: - value: The value to evaluate. If `None`, it will be replaced by `type[None]`. If an instance - of `str`, it will be converted to a `ForwardRef`. - localns: The global namespace to use during annotation evaluation. - globalns: The local namespace to use during annotation evaluation. - """ - value = _type_convert(value) - return eval_type_backport(value, globalns, localns) - - -@deprecated( - '`eval_type_lenient` is deprecated, use `try_eval_type` instead.', - category=None, -) -def eval_type_lenient( - value: Any, - globalns: GlobalsNamespace | None = None, - localns: MappingNamespace | None = None, -) -> Any: - ev, _ = try_eval_type(value, globalns, localns) - return ev - - -def eval_type_backport( - value: Any, - globalns: GlobalsNamespace | None = None, - localns: MappingNamespace | None = None, - type_params: tuple[Any, ...] | None = None, -) -> Any: - """An enhanced version of `typing._eval_type` which will fall back to using the `eval_type_backport` - package if it's installed to let older Python versions use newer typing constructs. - - Specifically, this transforms `X | Y` into `typing.Union[X, Y]` and `list[X]` into `typing.List[X]` - (as well as all the types made generic in PEP 585) if the original syntax is not supported in the - current Python version. 
- - This function will also display a helpful error if the value passed fails to evaluate. - """ - try: - return _eval_type_backport(value, globalns, localns, type_params) - except TypeError as e: - if 'Unable to evaluate type annotation' in str(e): - raise - - # If it is a `TypeError` and value isn't a `ForwardRef`, it would have failed during annotation definition. - # Thus we assert here for type checking purposes: - assert isinstance(value, typing.ForwardRef) - - message = f'Unable to evaluate type annotation {value.__forward_arg__!r}.' - if sys.version_info >= (3, 11): - e.add_note(message) - raise - else: - raise TypeError(message) from e - except RecursionError as e: - # TODO ideally recursion errors should be checked in `eval_type` above, but `eval_type_backport` - # is used directly in some places. - message = ( - "If you made use of an implicit recursive type alias (e.g. `MyType = list['MyType']), " - 'consider using PEP 695 type aliases instead. For more details, refer to the documentation: ' - f'https://docs.pydantic.dev/{version_short()}/concepts/types/#named-recursive-types' - ) - if sys.version_info >= (3, 11): - e.add_note(message) - raise - else: - raise RecursionError(f'{e.args[0]}\n{message}') - - -def _eval_type_backport( - value: Any, - globalns: GlobalsNamespace | None = None, - localns: MappingNamespace | None = None, - type_params: tuple[Any, ...] | None = None, -) -> Any: - try: - return _eval_type(value, globalns, localns, type_params) - except TypeError as e: - if not (isinstance(value, typing.ForwardRef) and is_backport_fixable_error(e)): - raise - - try: - from eval_type_backport import eval_type_backport - except ImportError: - raise TypeError( - f'Unable to evaluate type annotation {value.__forward_arg__!r}. If you are making use ' - 'of the new typing syntax (unions using `|` since Python 3.10 or builtins subscripting ' - 'since Python 3.9), you should either replace the use of new syntax with the existing ' - '`typing` constructs or install the `eval_type_backport` package.' - ) from e - - return eval_type_backport( - value, - globalns, - localns, # pyright: ignore[reportArgumentType], waiting on a new `eval_type_backport` release. - try_default=False, - ) - - -def _eval_type( - value: Any, - globalns: GlobalsNamespace | None = None, - localns: MappingNamespace | None = None, - type_params: tuple[Any, ...] | None = None, -) -> Any: - if sys.version_info >= (3, 13): - return typing._eval_type( # type: ignore - value, globalns, localns, type_params=type_params - ) - else: - return typing._eval_type( # type: ignore - value, globalns, localns - ) - - -def is_backport_fixable_error(e: TypeError) -> bool: - msg = str(e) - - return sys.version_info < (3, 10) and msg.startswith('unsupported operand type(s) for |: ') + # the point of this function is to be tolerant to this case + return value def get_function_type_hints( - function: Callable[..., Any], - *, - include_keys: set[str] | None = None, - globalns: GlobalsNamespace | None = None, - localns: MappingNamespace | None = None, + function: Callable[..., Any], *, include_keys: set[str] | None = None, types_namespace: dict[str, Any] | None = None ) -> dict[str, Any]: - """Return type hints for a function. - - This is similar to the `typing.get_type_hints` function, with a few differences: - - Support `functools.partial` by using the underlying `func` attribute. 
- - Do not wrap type annotation of a parameter with `Optional` if it has a default value of `None` - (related bug: https://github.com/python/cpython/issues/90353, only fixed in 3.11+). + """Like `typing.get_type_hints`, but doesn't convert `X` to `Optional[X]` if the default value is `None`, also + copes with `partial`. """ - try: - if isinstance(function, partial): - annotations = function.func.__annotations__ - else: - annotations = function.__annotations__ - except AttributeError: - # Some functions (e.g. builtins) don't have annotations: - return {} - - if globalns is None: - globalns = get_module_ns_of(function) - type_params: tuple[Any, ...] | None = None - if localns is None: - # If localns was specified, it is assumed to already contain type params. This is because - # Pydantic has more advanced logic to do so (see `_namespace_utils.ns_for_function`). - type_params = getattr(function, '__type_params__', ()) + if isinstance(function, partial): + annotations = function.func.__annotations__ + else: + annotations = function.__annotations__ + globalns = add_module_globals(function) type_hints = {} for name, value in annotations.items(): if include_keys is not None and name not in include_keys: @@ -548,7 +248,7 @@ def get_function_type_hints( elif isinstance(value, str): value = _make_forward_ref(value) - type_hints[name] = eval_type_backport(value, globalns, localns, type_params) + type_hints[name] = typing._eval_type(value, globalns, types_namespace) # type: ignore return type_hints @@ -663,15 +363,11 @@ else: if isinstance(value, str): value = _make_forward_ref(value, is_argument=False, is_class=True) - value = eval_type_backport(value, base_globals, base_locals) + value = typing._eval_type(value, base_globals, base_locals) # type: ignore hints[name] = value - if not include_extras and hasattr(typing, '_strip_annotations'): - return { - k: typing._strip_annotations(t) # type: ignore - for k, t in hints.items() - } - else: - return hints + return ( + hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()} # type: ignore + ) if globalns is None: if isinstance(obj, types.ModuleType): @@ -692,7 +388,7 @@ else: if isinstance(obj, typing._allowed_types): # type: ignore return {} else: - raise TypeError(f'{obj!r} is not a module, class, method, or function.') + raise TypeError(f'{obj!r} is not a module, class, method, ' 'or function.') defaults = typing._get_defaults(obj) # type: ignore hints = dict(hints) for name, value in hints.items(): @@ -707,8 +403,33 @@ else: is_argument=not isinstance(obj, types.ModuleType), is_class=False, ) - value = eval_type_backport(value, globalns, localns) + value = typing._eval_type(value, globalns, localns) # type: ignore if name in defaults and defaults[name] is None: value = typing.Optional[value] hints[name] = value return hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()} # type: ignore + + +if sys.version_info < (3, 9): + + def evaluate_fwd_ref( + ref: ForwardRef, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None + ) -> Any: + return ref._evaluate(globalns=globalns, localns=localns) + +else: + + def evaluate_fwd_ref( + ref: ForwardRef, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None + ) -> Any: + return ref._evaluate(globalns=globalns, localns=localns, recursive_guard=frozenset()) + + +def is_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]: + # The dataclasses.is_dataclass function doesn't seem to provide 
TypeGuard functionality, + # so I created this convenience function + return dataclasses.is_dataclass(_cls) + + +def origin_is_type_alias_type(origin: Any) -> TypeGuard[TypeAliasType]: + return isinstance(origin, TypeAliasType) diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_utils.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_utils.py index f334649..69be19f 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_utils.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_utils.py @@ -2,30 +2,20 @@ This should be reduced as much as possible with functions only used in one place, moved to that place. """ - from __future__ import annotations as _annotations -import dataclasses import keyword -import sys import typing -import warnings import weakref from collections import OrderedDict, defaultdict, deque -from collections.abc import Mapping from copy import deepcopy -from functools import cached_property -from inspect import Parameter from itertools import zip_longest from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType -from typing import Any, Callable, Generic, TypeVar, overload +from typing import Any, TypeVar -from typing_extensions import TypeAlias, TypeGuard, deprecated - -from pydantic import PydanticDeprecatedSince211 +from typing_extensions import TypeAlias, TypeGuard from . import _repr, _typing_extra -from ._import_utils import import_cached_base_model if typing.TYPE_CHECKING: MappingIntStrAny: TypeAlias = 'typing.Mapping[int, Any] | typing.Mapping[str, Any]' @@ -69,25 +59,6 @@ BUILTIN_COLLECTIONS: set[type[Any]] = { } -def can_be_positional(param: Parameter) -> bool: - """Return whether the parameter accepts a positional argument. - - ```python {test="skip" lint="skip"} - def func(a, /, b, *, c): - pass - - params = inspect.signature(func).parameters - can_be_positional(params['a']) - #> True - can_be_positional(params['b']) - #> True - can_be_positional(params['c']) - #> False - ``` - """ - return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD) - - def sequence_like(v: Any) -> bool: return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque)) @@ -112,7 +83,7 @@ def is_model_class(cls: Any) -> TypeGuard[type[BaseModel]]: """Returns true if cls is a _proper_ subclass of BaseModel, and provides proper type-checking, unlike raw calls to lenient_issubclass. """ - BaseModel = import_cached_base_model() + from ..main import BaseModel return lenient_issubclass(cls, BaseModel) and cls is not BaseModel @@ -304,23 +275,19 @@ class ValueItems(_repr.Representation): if typing.TYPE_CHECKING: - def LazyClassAttribute(name: str, get_value: Callable[[], T]) -> T: ... + def ClassAttribute(name: str, value: T) -> T: + ... else: - class LazyClassAttribute: - """A descriptor exposing an attribute only accessible on a class (hidden from instances). + class ClassAttribute: + """Hide class attribute from its instances.""" - The attribute is lazily computed and cached during the first access. 
- """ + __slots__ = 'name', 'value' - def __init__(self, name: str, get_value: Callable[[], Any]) -> None: + def __init__(self, name: str, value: Any) -> None: self.name = name - self.get_value = get_value - - @cached_property - def value(self) -> Any: - return self.get_value() + self.value = value def __get__(self, instance: Any, owner: type[Any]) -> None: if instance is None: @@ -342,7 +309,7 @@ def smart_deepcopy(obj: Obj) -> Obj: try: if not obj and obj_type in BUILTIN_COLLECTIONS: # faster way for empty collections, no need to copy its members - return obj if obj_type is tuple else obj.copy() # tuple doesn't have copy method # type: ignore + return obj if obj_type is tuple else obj.copy() # tuple doesn't have copy method except (TypeError, ValueError, RuntimeError): # do we really dare to catch ALL errors? Seems a bit risky pass @@ -350,7 +317,7 @@ def smart_deepcopy(obj: Obj) -> Obj: return deepcopy(obj) # slowest way when we actually might need a deepcopy -_SENTINEL = object() +_EMPTY = object() def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bool: @@ -362,70 +329,7 @@ def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bo >>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical" False """ - for left_item, right_item in zip_longest(left, right, fillvalue=_SENTINEL): + for left_item, right_item in zip_longest(left, right, fillvalue=_EMPTY): if left_item is not right_item: return False return True - - -@dataclasses.dataclass(frozen=True) -class SafeGetItemProxy: - """Wrapper redirecting `__getitem__` to `get` with a sentinel value as default - - This makes is safe to use in `operator.itemgetter` when some keys may be missing - """ - - # Define __slots__manually for performances - # @dataclasses.dataclass() only support slots=True in python>=3.10 - __slots__ = ('wrapped',) - - wrapped: Mapping[str, Any] - - def __getitem__(self, key: str, /) -> Any: - return self.wrapped.get(key, _SENTINEL) - - # required to pass the object to operator.itemgetter() instances due to a quirk of typeshed - # https://github.com/python/mypy/issues/13713 - # https://github.com/python/typeshed/pull/8785 - # Since this is typing-only, hide it in a typing.TYPE_CHECKING block - if typing.TYPE_CHECKING: - - def __contains__(self, key: str, /) -> bool: - return self.wrapped.__contains__(key) - - -_ModelT = TypeVar('_ModelT', bound='BaseModel') -_RT = TypeVar('_RT') - - -class deprecated_instance_property(Generic[_ModelT, _RT]): - """A decorator exposing the decorated class method as a property, with a warning on instance access. - - This decorator takes a class method defined on the `BaseModel` class and transforms it into - an attribute. The attribute can be accessed on both the class and instances of the class. If accessed - via an instance, a deprecation warning is emitted stating that instance access will be removed in V3. - """ - - def __init__(self, fget: Callable[[type[_ModelT]], _RT], /) -> None: - # Note: fget should be a classmethod: - self.fget = fget - - @overload - def __get__(self, instance: None, objtype: type[_ModelT]) -> _RT: ... - @overload - @deprecated( - 'Accessing this attribute on the instance is deprecated, and will be removed in Pydantic V3. ' - 'Instead, you should access this attribute from the model class.', - category=None, - ) - def __get__(self, instance: _ModelT, objtype: type[_ModelT]) -> _RT: ... 
- def __get__(self, instance: _ModelT | None, objtype: type[_ModelT]) -> _RT: - if instance is not None: - attr_name = self.fget.__name__ if sys.version_info >= (3, 10) else self.fget.__func__.__name__ - warnings.warn( - f'Accessing the {attr_name!r} attribute on the instance is deprecated. ' - 'Instead, you should access this attribute from the model class.', - category=PydanticDeprecatedSince211, - stacklevel=2, - ) - return self.fget.__get__(instance, objtype)() diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_validate_call.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_validate_call.py index ab82832..a58e240 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_validate_call.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_validate_call.py @@ -1,122 +1,88 @@ from __future__ import annotations as _annotations -import functools import inspect -from collections.abc import Awaitable +from dataclasses import dataclass from functools import partial -from typing import Any, Callable +from typing import Any, Awaitable, Callable import pydantic_core from ..config import ConfigDict from ..plugin._schema_validator import create_schema_validator +from . import _discriminated_union, _generate_schema, _typing_extra from ._config import ConfigWrapper -from ._generate_schema import GenerateSchema, ValidateCallSupportedTypes -from ._namespace_utils import MappingNamespace, NsResolver, ns_for_function +from ._core_utils import simplify_schema_references, validate_core_schema -def extract_function_name(func: ValidateCallSupportedTypes) -> str: - """Extract the name of a `ValidateCallSupportedTypes` object.""" - return f'partial({func.func.__name__})' if isinstance(func, functools.partial) else func.__name__ - - -def extract_function_qualname(func: ValidateCallSupportedTypes) -> str: - """Extract the qualname of a `ValidateCallSupportedTypes` object.""" - return f'partial({func.func.__qualname__})' if isinstance(func, functools.partial) else func.__qualname__ - - -def update_wrapper_attributes(wrapped: ValidateCallSupportedTypes, wrapper: Callable[..., Any]): - """Update the `wrapper` function with the attributes of the `wrapped` function. Return the updated function.""" - if inspect.iscoroutinefunction(wrapped): - - @functools.wraps(wrapped) - async def wrapper_function(*args, **kwargs): # type: ignore - return await wrapper(*args, **kwargs) - else: - - @functools.wraps(wrapped) - def wrapper_function(*args, **kwargs): - return wrapper(*args, **kwargs) - - # We need to manually update this because `partial` object has no `__name__` and `__qualname__`. - wrapper_function.__name__ = extract_function_name(wrapped) - wrapper_function.__qualname__ = extract_function_qualname(wrapped) - wrapper_function.raw_function = wrapped # type: ignore - - return wrapper_function +@dataclass +class CallMarker: + function: Callable[..., Any] + validate_return: bool class ValidateCallWrapper: - """This is a wrapper around a function that validates the arguments passed to it, and optionally the return value.""" + """This is a wrapper around a function that validates the arguments passed to it, and optionally the return value. + + It's partially inspired by `wraps` which in turn uses `partial`, but extended to be a descriptor so + these functions can be applied to instance methods, class methods, static methods, as well as normal functions. 
+ """ __slots__ = ( - 'function', - 'validate_return', - 'schema_type', - 'module', - 'qualname', - 'ns_resolver', - 'config_wrapper', - '__pydantic_complete__', + 'raw_function', + '_config', + '_validate_return', + '__pydantic_core_schema__', '__pydantic_validator__', - '__return_pydantic_validator__', + '__signature__', + '__name__', + '__qualname__', + '__annotations__', + '__dict__', # required for __module__ ) - def __init__( - self, - function: ValidateCallSupportedTypes, - config: ConfigDict | None, - validate_return: bool, - parent_namespace: MappingNamespace | None, - ) -> None: - self.function = function - self.validate_return = validate_return + def __init__(self, function: Callable[..., Any], config: ConfigDict | None, validate_return: bool): + self.raw_function = function + self._config = config + self._validate_return = validate_return + self.__signature__ = inspect.signature(function) if isinstance(function, partial): - self.schema_type = function.func - self.module = function.func.__module__ + func = function.func + self.__name__ = f'partial({func.__name__})' + self.__qualname__ = f'partial({func.__qualname__})' + self.__annotations__ = func.__annotations__ + self.__module__ = func.__module__ + self.__doc__ = func.__doc__ else: - self.schema_type = function - self.module = function.__module__ - self.qualname = extract_function_qualname(function) + self.__name__ = function.__name__ + self.__qualname__ = function.__qualname__ + self.__annotations__ = function.__annotations__ + self.__module__ = function.__module__ + self.__doc__ = function.__doc__ - self.ns_resolver = NsResolver( - namespaces_tuple=ns_for_function(self.schema_type, parent_namespace=parent_namespace) - ) - self.config_wrapper = ConfigWrapper(config) - if not self.config_wrapper.defer_build: - self._create_validators() - else: - self.__pydantic_complete__ = False + namespace = _typing_extra.add_module_globals(function, None) + config_wrapper = ConfigWrapper(config) + gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace) + schema = gen_schema.collect_definitions(gen_schema.generate_schema(function)) + schema = simplify_schema_references(schema) + self.__pydantic_core_schema__ = schema = schema + core_config = config_wrapper.core_config(self) + schema = _discriminated_union.apply_discriminators(schema) + self.__pydantic_validator__ = create_schema_validator(schema, core_config, config_wrapper.plugin_settings) - def _create_validators(self) -> None: - gen_schema = GenerateSchema(self.config_wrapper, self.ns_resolver) - schema = gen_schema.clean_schema(gen_schema.generate_schema(self.function)) - core_config = self.config_wrapper.core_config(title=self.qualname) - - self.__pydantic_validator__ = create_schema_validator( - schema, - self.schema_type, - self.module, - self.qualname, - 'validate_call', - core_config, - self.config_wrapper.plugin_settings, - ) - if self.validate_return: - signature = inspect.signature(self.function) - return_type = signature.return_annotation if signature.return_annotation is not signature.empty else Any - gen_schema = GenerateSchema(self.config_wrapper, self.ns_resolver) - schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type)) - validator = create_schema_validator( - schema, - self.schema_type, - self.module, - self.qualname, - 'validate_call', - core_config, - self.config_wrapper.plugin_settings, + if self._validate_return: + return_type = ( + self.__signature__.return_annotation + if self.__signature__.return_annotation is not 
self.__signature__.empty + else Any ) - if inspect.iscoroutinefunction(self.function): + gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace) + schema = gen_schema.collect_definitions(gen_schema.generate_schema(return_type)) + schema = _discriminated_union.apply_discriminators(simplify_schema_references(schema)) + self.__return_pydantic_core_schema__ = schema + core_config = config_wrapper.core_config(self) + schema = validate_core_schema(schema) + validator = pydantic_core.SchemaValidator(schema, core_config) + if inspect.iscoroutinefunction(self.raw_function): async def return_val_wrapper(aw: Awaitable[Any]) -> None: return validator.validate_python(await aw) @@ -125,16 +91,38 @@ class ValidateCallWrapper: else: self.__return_pydantic_validator__ = validator.validate_python else: + self.__return_pydantic_core_schema__ = None self.__return_pydantic_validator__ = None - self.__pydantic_complete__ = True + self._name: str | None = None # set by __get__, used to set the instance attribute when decorating methods def __call__(self, *args: Any, **kwargs: Any) -> Any: - if not self.__pydantic_complete__: - self._create_validators() - res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs)) if self.__return_pydantic_validator__: return self.__return_pydantic_validator__(res) - else: - return res + return res + + def __get__(self, obj: Any, objtype: type[Any] | None = None) -> ValidateCallWrapper: + """Bind the raw function and return another ValidateCallWrapper wrapping that.""" + if obj is None: + try: + # Handle the case where a method is accessed as a class attribute + return objtype.__getattribute__(objtype, self._name) # type: ignore + except AttributeError: + # This will happen the first time the attribute is accessed + pass + + bound_function = self.raw_function.__get__(obj, objtype) + result = self.__class__(bound_function, self._config, self._validate_return) + if self._name is not None: + if obj is not None: + object.__setattr__(obj, self._name, result) + else: + object.__setattr__(objtype, self._name, result) + return result + + def __set_name__(self, owner: Any, name: str) -> None: + self._name = name + + def __repr__(self) -> str: + return f'ValidateCallWrapper({self.raw_function})' diff --git a/venv/lib/python3.12/site-packages/pydantic/_internal/_validators.py b/venv/lib/python3.12/site-packages/pydantic/_internal/_validators.py index 803363c..e3a7e50 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_internal/_validators.py +++ b/venv/lib/python3.12/site-packages/pydantic/_internal/_validators.py @@ -5,32 +5,22 @@ Import of this module is deferred since it contains imports of many standard lib from __future__ import annotations as _annotations -import collections.abc import math import re import typing -from decimal import Decimal -from fractions import Fraction from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network -from typing import Any, Callable, Union, cast, get_origin -from zoneinfo import ZoneInfo, ZoneInfoNotFoundError +from typing import Any -import typing_extensions from pydantic_core import PydanticCustomError, core_schema from pydantic_core._pydantic_core import PydanticKnownError -from typing_inspection import typing_objects - -from pydantic._internal._import_utils import import_cached_field_info -from pydantic.errors import PydanticSchemaGenerationError def sequence_validator( - input_value: typing.Sequence[Any], - /, + __input_value: typing.Sequence[Any], 
validator: core_schema.ValidatorFunctionWrapHandler, ) -> typing.Sequence[Any]: """Validator for `Sequence` types, isinstance(v, Sequence) has already been called.""" - value_type = type(input_value) + value_type = type(__input_value) # We don't accept any plain string as a sequence # Relevant issue: https://github.com/pydantic/pydantic/issues/5595 @@ -41,24 +31,14 @@ def sequence_validator( {'type_name': value_type.__name__}, ) - # TODO: refactor sequence validation to validate with either a list or a tuple - # schema, depending on the type of the value. - # Additionally, we should be able to remove one of either this validator or the - # SequenceValidator in _std_types_schema.py (preferably this one, while porting over some logic). - # Effectively, a refactor for sequence validation is needed. - if value_type is tuple: - input_value = list(input_value) - - v_list = validator(input_value) + v_list = validator(__input_value) # the rest of the logic is just re-creating the original type from `v_list` - if value_type is list: + if value_type == list: return v_list elif issubclass(value_type, range): # return the list as we probably can't re-create the range return v_list - elif value_type is tuple: - return tuple(v_list) else: # best guess at how to re-create the original type, more custom construction logic might be required return value_type(v_list) # type: ignore[call-arg] @@ -69,7 +49,7 @@ def import_string(value: Any) -> Any: try: return _import_string_logic(value) except ImportError as e: - raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)}) from e + raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)}) else: # otherwise we just return the value and let the next validator do the rest of the work return value @@ -126,39 +106,39 @@ def _import_string_logic(dotted_path: str) -> Any: return module -def pattern_either_validator(input_value: Any, /) -> typing.Pattern[Any]: - if isinstance(input_value, typing.Pattern): - return input_value - elif isinstance(input_value, (str, bytes)): +def pattern_either_validator(__input_value: Any) -> typing.Pattern[Any]: + if isinstance(__input_value, typing.Pattern): + return __input_value + elif isinstance(__input_value, (str, bytes)): # todo strict mode - return compile_pattern(input_value) # type: ignore + return compile_pattern(__input_value) # type: ignore else: raise PydanticCustomError('pattern_type', 'Input should be a valid pattern') -def pattern_str_validator(input_value: Any, /) -> typing.Pattern[str]: - if isinstance(input_value, typing.Pattern): - if isinstance(input_value.pattern, str): - return input_value +def pattern_str_validator(__input_value: Any) -> typing.Pattern[str]: + if isinstance(__input_value, typing.Pattern): + if isinstance(__input_value.pattern, str): + return __input_value else: raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern') - elif isinstance(input_value, str): - return compile_pattern(input_value) - elif isinstance(input_value, bytes): + elif isinstance(__input_value, str): + return compile_pattern(__input_value) + elif isinstance(__input_value, bytes): raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern') else: raise PydanticCustomError('pattern_type', 'Input should be a valid pattern') -def pattern_bytes_validator(input_value: Any, /) -> typing.Pattern[bytes]: - if isinstance(input_value, typing.Pattern): - if isinstance(input_value.pattern, bytes): - return input_value +def 
pattern_bytes_validator(__input_value: Any) -> typing.Pattern[bytes]: + if isinstance(__input_value, typing.Pattern): + if isinstance(__input_value.pattern, bytes): + return __input_value else: raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern') - elif isinstance(input_value, bytes): - return compile_pattern(input_value) - elif isinstance(input_value, str): + elif isinstance(__input_value, bytes): + return compile_pattern(__input_value) + elif isinstance(__input_value, str): raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern') else: raise PydanticCustomError('pattern_type', 'Input should be a valid pattern') @@ -174,359 +154,125 @@ def compile_pattern(pattern: PatternType) -> typing.Pattern[PatternType]: raise PydanticCustomError('pattern_regex', 'Input should be a valid regular expression') -def ip_v4_address_validator(input_value: Any, /) -> IPv4Address: - if isinstance(input_value, IPv4Address): - return input_value +def ip_v4_address_validator(__input_value: Any) -> IPv4Address: + if isinstance(__input_value, IPv4Address): + return __input_value try: - return IPv4Address(input_value) + return IPv4Address(__input_value) except ValueError: raise PydanticCustomError('ip_v4_address', 'Input is not a valid IPv4 address') -def ip_v6_address_validator(input_value: Any, /) -> IPv6Address: - if isinstance(input_value, IPv6Address): - return input_value +def ip_v6_address_validator(__input_value: Any) -> IPv6Address: + if isinstance(__input_value, IPv6Address): + return __input_value try: - return IPv6Address(input_value) + return IPv6Address(__input_value) except ValueError: raise PydanticCustomError('ip_v6_address', 'Input is not a valid IPv6 address') -def ip_v4_network_validator(input_value: Any, /) -> IPv4Network: +def ip_v4_network_validator(__input_value: Any) -> IPv4Network: """Assume IPv4Network initialised with a default `strict` argument. See more: https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network """ - if isinstance(input_value, IPv4Network): - return input_value + if isinstance(__input_value, IPv4Network): + return __input_value try: - return IPv4Network(input_value) + return IPv4Network(__input_value) except ValueError: raise PydanticCustomError('ip_v4_network', 'Input is not a valid IPv4 network') -def ip_v6_network_validator(input_value: Any, /) -> IPv6Network: +def ip_v6_network_validator(__input_value: Any) -> IPv6Network: """Assume IPv6Network initialised with a default `strict` argument. 
See more: https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network """ - if isinstance(input_value, IPv6Network): - return input_value + if isinstance(__input_value, IPv6Network): + return __input_value try: - return IPv6Network(input_value) + return IPv6Network(__input_value) except ValueError: raise PydanticCustomError('ip_v6_network', 'Input is not a valid IPv6 network') -def ip_v4_interface_validator(input_value: Any, /) -> IPv4Interface: - if isinstance(input_value, IPv4Interface): - return input_value +def ip_v4_interface_validator(__input_value: Any) -> IPv4Interface: + if isinstance(__input_value, IPv4Interface): + return __input_value try: - return IPv4Interface(input_value) + return IPv4Interface(__input_value) except ValueError: raise PydanticCustomError('ip_v4_interface', 'Input is not a valid IPv4 interface') -def ip_v6_interface_validator(input_value: Any, /) -> IPv6Interface: - if isinstance(input_value, IPv6Interface): - return input_value +def ip_v6_interface_validator(__input_value: Any) -> IPv6Interface: + if isinstance(__input_value, IPv6Interface): + return __input_value try: - return IPv6Interface(input_value) + return IPv6Interface(__input_value) except ValueError: raise PydanticCustomError('ip_v6_interface', 'Input is not a valid IPv6 interface') -def fraction_validator(input_value: Any, /) -> Fraction: - if isinstance(input_value, Fraction): - return input_value +def greater_than_validator(x: Any, gt: Any) -> Any: + if not (x > gt): + raise PydanticKnownError('greater_than', {'gt': gt}) + return x - try: - return Fraction(input_value) - except ValueError: - raise PydanticCustomError('fraction_parsing', 'Input is not a valid fraction') + +def greater_than_or_equal_validator(x: Any, ge: Any) -> Any: + if not (x >= ge): + raise PydanticKnownError('greater_than_equal', {'ge': ge}) + return x + + +def less_than_validator(x: Any, lt: Any) -> Any: + if not (x < lt): + raise PydanticKnownError('less_than', {'lt': lt}) + return x + + +def less_than_or_equal_validator(x: Any, le: Any) -> Any: + if not (x <= le): + raise PydanticKnownError('less_than_equal', {'le': le}) + return x + + +def multiple_of_validator(x: Any, multiple_of: Any) -> Any: + if not (x % multiple_of == 0): + raise PydanticKnownError('multiple_of', {'multiple_of': multiple_of}) + return x + + +def min_length_validator(x: Any, min_length: Any) -> Any: + if not (len(x) >= min_length): + raise PydanticKnownError( + 'too_short', + {'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)}, + ) + return x + + +def max_length_validator(x: Any, max_length: Any) -> Any: + if len(x) > max_length: + raise PydanticKnownError( + 'too_long', + {'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)}, + ) + return x def forbid_inf_nan_check(x: Any) -> Any: if not math.isfinite(x): raise PydanticKnownError('finite_number') return x - - -def _safe_repr(v: Any) -> int | float | str: - """The context argument for `PydanticKnownError` requires a number or str type, so we do a simple repr() coercion for types like timedelta. - - See tests/test_types.py::test_annotated_metadata_any_order for some context. 
- """ - if isinstance(v, (int, float, str)): - return v - return repr(v) - - -def greater_than_validator(x: Any, gt: Any) -> Any: - try: - if not (x > gt): - raise PydanticKnownError('greater_than', {'gt': _safe_repr(gt)}) - return x - except TypeError: - raise TypeError(f"Unable to apply constraint 'gt' to supplied value {x}") - - -def greater_than_or_equal_validator(x: Any, ge: Any) -> Any: - try: - if not (x >= ge): - raise PydanticKnownError('greater_than_equal', {'ge': _safe_repr(ge)}) - return x - except TypeError: - raise TypeError(f"Unable to apply constraint 'ge' to supplied value {x}") - - -def less_than_validator(x: Any, lt: Any) -> Any: - try: - if not (x < lt): - raise PydanticKnownError('less_than', {'lt': _safe_repr(lt)}) - return x - except TypeError: - raise TypeError(f"Unable to apply constraint 'lt' to supplied value {x}") - - -def less_than_or_equal_validator(x: Any, le: Any) -> Any: - try: - if not (x <= le): - raise PydanticKnownError('less_than_equal', {'le': _safe_repr(le)}) - return x - except TypeError: - raise TypeError(f"Unable to apply constraint 'le' to supplied value {x}") - - -def multiple_of_validator(x: Any, multiple_of: Any) -> Any: - try: - if x % multiple_of: - raise PydanticKnownError('multiple_of', {'multiple_of': _safe_repr(multiple_of)}) - return x - except TypeError: - raise TypeError(f"Unable to apply constraint 'multiple_of' to supplied value {x}") - - -def min_length_validator(x: Any, min_length: Any) -> Any: - try: - if not (len(x) >= min_length): - raise PydanticKnownError( - 'too_short', {'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)} - ) - return x - except TypeError: - raise TypeError(f"Unable to apply constraint 'min_length' to supplied value {x}") - - -def max_length_validator(x: Any, max_length: Any) -> Any: - try: - if len(x) > max_length: - raise PydanticKnownError( - 'too_long', - {'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)}, - ) - return x - except TypeError: - raise TypeError(f"Unable to apply constraint 'max_length' to supplied value {x}") - - -def _extract_decimal_digits_info(decimal: Decimal) -> tuple[int, int]: - """Compute the total number of digits and decimal places for a given [`Decimal`][decimal.Decimal] instance. - - This function handles both normalized and non-normalized Decimal instances. - Example: Decimal('1.230') -> 4 digits, 3 decimal places - - Args: - decimal (Decimal): The decimal number to analyze. - - Returns: - tuple[int, int]: A tuple containing the number of decimal places and total digits. - - Though this could be divided into two separate functions, the logic is easier to follow if we couple the computation - of the number of decimals and digits together. - """ - try: - decimal_tuple = decimal.as_tuple() - - assert isinstance(decimal_tuple.exponent, int) - - exponent = decimal_tuple.exponent - num_digits = len(decimal_tuple.digits) - - if exponent >= 0: - # A positive exponent adds that many trailing zeros - # Ex: digit_tuple=(1, 2, 3), exponent=2 -> 12300 -> 0 decimal places, 5 digits - num_digits += exponent - decimal_places = 0 - else: - # If the absolute value of the negative exponent is larger than the - # number of digits, then it's the same as the number of digits, - # because it'll consume all the digits in digit_tuple and then - # add abs(exponent) - len(digit_tuple) leading zeros after the decimal point. 
- # Ex: digit_tuple=(1, 2, 3), exponent=-2 -> 1.23 -> 2 decimal places, 3 digits - # Ex: digit_tuple=(1, 2, 3), exponent=-4 -> 0.0123 -> 4 decimal places, 4 digits - decimal_places = abs(exponent) - num_digits = max(num_digits, decimal_places) - - return decimal_places, num_digits - except (AssertionError, AttributeError): - raise TypeError(f'Unable to extract decimal digits info from supplied value {decimal}') - - -def max_digits_validator(x: Any, max_digits: Any) -> Any: - try: - _, num_digits = _extract_decimal_digits_info(x) - _, normalized_num_digits = _extract_decimal_digits_info(x.normalize()) - if (num_digits > max_digits) and (normalized_num_digits > max_digits): - raise PydanticKnownError( - 'decimal_max_digits', - {'max_digits': max_digits}, - ) - return x - except TypeError: - raise TypeError(f"Unable to apply constraint 'max_digits' to supplied value {x}") - - -def decimal_places_validator(x: Any, decimal_places: Any) -> Any: - try: - decimal_places_, _ = _extract_decimal_digits_info(x) - if decimal_places_ > decimal_places: - normalized_decimal_places, _ = _extract_decimal_digits_info(x.normalize()) - if normalized_decimal_places > decimal_places: - raise PydanticKnownError( - 'decimal_max_places', - {'decimal_places': decimal_places}, - ) - return x - except TypeError: - raise TypeError(f"Unable to apply constraint 'decimal_places' to supplied value {x}") - - -def deque_validator(input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> collections.deque[Any]: - return collections.deque(handler(input_value), maxlen=getattr(input_value, 'maxlen', None)) - - -def defaultdict_validator( - input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any] -) -> collections.defaultdict[Any, Any]: - if isinstance(input_value, collections.defaultdict): - default_factory = input_value.default_factory - return collections.defaultdict(default_factory, handler(input_value)) - else: - return collections.defaultdict(default_default_factory, handler(input_value)) - - -def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]: - FieldInfo = import_cached_field_info() - - values_type_origin = get_origin(values_source_type) - - def infer_default() -> Callable[[], Any]: - allowed_default_types: dict[Any, Any] = { - tuple: tuple, - collections.abc.Sequence: tuple, - collections.abc.MutableSequence: list, - list: list, - typing.Sequence: list, - set: set, - typing.MutableSet: set, - collections.abc.MutableSet: set, - collections.abc.Set: frozenset, - typing.MutableMapping: dict, - typing.Mapping: dict, - collections.abc.Mapping: dict, - collections.abc.MutableMapping: dict, - float: float, - int: int, - str: str, - bool: bool, - } - values_type = values_type_origin or values_source_type - instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`' - if typing_objects.is_typevar(values_type): - - def type_var_default_factory() -> None: - raise RuntimeError( - 'Generic defaultdict cannot be used without a concrete value type or an' - ' explicit default factory, ' + instructions - ) - - return type_var_default_factory - elif values_type not in allowed_default_types: - # a somewhat subjective set of types that have reasonable default values - allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())]) - raise PydanticSchemaGenerationError( - f'Unable to infer a default factory for keys of type {values_source_type}.' 
- f' Only {allowed_msg} are supported, other types require an explicit default factory' - ' ' + instructions - ) - return allowed_default_types[values_type] - - # Assume Annotated[..., Field(...)] - if typing_objects.is_annotated(values_type_origin): - field_info = next((v for v in typing_extensions.get_args(values_source_type) if isinstance(v, FieldInfo)), None) - else: - field_info = None - if field_info and field_info.default_factory: - # Assume the default factory does not take any argument: - default_default_factory = cast(Callable[[], Any], field_info.default_factory) - else: - default_default_factory = infer_default() - return default_default_factory - - -def validate_str_is_valid_iana_tz(value: Any, /) -> ZoneInfo: - if isinstance(value, ZoneInfo): - return value - try: - return ZoneInfo(value) - except (ZoneInfoNotFoundError, ValueError, TypeError): - raise PydanticCustomError('zoneinfo_str', 'invalid timezone: {value}', {'value': value}) - - -NUMERIC_VALIDATOR_LOOKUP: dict[str, Callable] = { - 'gt': greater_than_validator, - 'ge': greater_than_or_equal_validator, - 'lt': less_than_validator, - 'le': less_than_or_equal_validator, - 'multiple_of': multiple_of_validator, - 'min_length': min_length_validator, - 'max_length': max_length_validator, - 'max_digits': max_digits_validator, - 'decimal_places': decimal_places_validator, -} - -IpType = Union[IPv4Address, IPv6Address, IPv4Network, IPv6Network, IPv4Interface, IPv6Interface] - -IP_VALIDATOR_LOOKUP: dict[type[IpType], Callable] = { - IPv4Address: ip_v4_address_validator, - IPv6Address: ip_v6_address_validator, - IPv4Network: ip_v4_network_validator, - IPv6Network: ip_v6_network_validator, - IPv4Interface: ip_v4_interface_validator, - IPv6Interface: ip_v6_interface_validator, -} - -MAPPING_ORIGIN_MAP: dict[Any, Any] = { - typing.DefaultDict: collections.defaultdict, # noqa: UP006 - collections.defaultdict: collections.defaultdict, - typing.OrderedDict: collections.OrderedDict, # noqa: UP006 - collections.OrderedDict: collections.OrderedDict, - typing_extensions.OrderedDict: collections.OrderedDict, - typing.Counter: collections.Counter, - collections.Counter: collections.Counter, - # this doesn't handle subclasses of these - typing.Mapping: dict, - typing.MutableMapping: dict, - # parametrized typing.{Mutable}Mapping creates one of these - collections.abc.Mapping: dict, - collections.abc.MutableMapping: dict, -} diff --git a/venv/lib/python3.12/site-packages/pydantic/_migration.py b/venv/lib/python3.12/site-packages/pydantic/_migration.py index 980dfd2..cc9806e 100644 --- a/venv/lib/python3.12/site-packages/pydantic/_migration.py +++ b/venv/lib/python3.12/site-packages/pydantic/_migration.py @@ -1,6 +1,8 @@ import sys -from typing import Any, Callable +import warnings +from typing import Any, Callable, Dict +from ._internal._validators import import_string from .version import version_short MOVED_IN_V2 = { @@ -271,11 +273,7 @@ def getattr_migration(module: str) -> Callable[[str], Any]: The object. 
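A brief usage sketch (not part of the patch) of the behavior implemented by `defaultdict_validator` and `get_defaultdict_default_default_factory` above: when a plain `dict` is validated into a `DefaultDict`, the default factory is inferred from the value type unless an explicit `Field(default_factory=...)` is given. This assumes the public `TypeAdapter` API:

```python
# Sketch: value type list[int] maps to `list` in the allowed-default table above.
from collections import defaultdict
from typing import DefaultDict, List

from pydantic import TypeAdapter

ta = TypeAdapter(DefaultDict[str, List[int]])
d = ta.validate_python({'a': [1, 2]})

assert isinstance(d, defaultdict)
assert d['missing'] == []  # default factory inferred from the value type
```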
""" if name == '__path__': - raise AttributeError(f'module {module!r} has no attribute {name!r}') - - import warnings - - from ._internal._validators import import_string + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') import_path = f'{module}:{name}' if import_path in MOVED_IN_V2.keys(): @@ -300,9 +298,9 @@ def getattr_migration(module: str) -> Callable[[str], Any]: ) if import_path in REMOVED_IN_V2: raise PydanticImportError(f'`{import_path}` has been removed in V2.') - globals: dict[str, Any] = sys.modules[module].__dict__ + globals: Dict[str, Any] = sys.modules[module].__dict__ if name in globals: return globals[name] - raise AttributeError(f'module {module!r} has no attribute {name!r}') + raise AttributeError(f'module {__name__!r} has no attribute {name!r}') return wrapper diff --git a/venv/lib/python3.12/site-packages/pydantic/alias_generators.py b/venv/lib/python3.12/site-packages/pydantic/alias_generators.py index 0b7653f..bbdaaaf 100644 --- a/venv/lib/python3.12/site-packages/pydantic/alias_generators.py +++ b/venv/lib/python3.12/site-packages/pydantic/alias_generators.py @@ -1,13 +1,8 @@ """Alias generators for converting between different capitalization conventions.""" - import re __all__ = ('to_pascal', 'to_camel', 'to_snake') -# TODO: in V3, change the argument names to be more descriptive -# Generally, don't only convert from snake_case, or name the functions -# more specifically like snake_to_camel. - def to_pascal(snake: str) -> str: """Convert a snake_case string to PascalCase. @@ -31,17 +26,12 @@ def to_camel(snake: str) -> str: Returns: The converted camelCase string. """ - # If the string is already in camelCase and does not contain a digit followed - # by a lowercase letter, return it as it is - if re.match('^[a-z]+[A-Za-z0-9]*$', snake) and not re.search(r'\d[a-z]', snake): - return snake - camel = to_pascal(snake) return re.sub('(^_*[A-Z])', lambda m: m.group(1).lower(), camel) def to_snake(camel: str) -> str: - """Convert a PascalCase, camelCase, or kebab-case string to snake_case. + """Convert a PascalCase or camelCase string to snake_case. Args: camel: The string to convert. @@ -49,14 +39,6 @@ def to_snake(camel: str) -> str: Returns: The converted string in snake_case. 
""" - # Handle the sequence of uppercase letters followed by a lowercase letter - snake = re.sub(r'([A-Z]+)([A-Z][a-z])', lambda m: f'{m.group(1)}_{m.group(2)}', camel) - # Insert an underscore between a lowercase letter and an uppercase letter - snake = re.sub(r'([a-z])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake) - # Insert an underscore between a digit and an uppercase letter - snake = re.sub(r'([0-9])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake) - # Insert an underscore between a lowercase letter and a digit - snake = re.sub(r'([a-z])([0-9])', lambda m: f'{m.group(1)}_{m.group(2)}', snake) - # Replace hyphens with underscores to handle kebab-case - snake = snake.replace('-', '_') + snake = re.sub(r'([a-zA-Z])([0-9])', lambda m: f'{m.group(1)}_{m.group(2)}', camel) + snake = re.sub(r'([a-z0-9])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake) return snake.lower() diff --git a/venv/lib/python3.12/site-packages/pydantic/aliases.py b/venv/lib/python3.12/site-packages/pydantic/aliases.py deleted file mode 100644 index ac22737..0000000 --- a/venv/lib/python3.12/site-packages/pydantic/aliases.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Support for alias configurations.""" - -from __future__ import annotations - -import dataclasses -from typing import Any, Callable, Literal - -from pydantic_core import PydanticUndefined - -from ._internal import _internal_dataclass - -__all__ = ('AliasGenerator', 'AliasPath', 'AliasChoices') - - -@dataclasses.dataclass(**_internal_dataclass.slots_true) -class AliasPath: - """!!! abstract "Usage Documentation" - [`AliasPath` and `AliasChoices`](../concepts/alias.md#aliaspath-and-aliaschoices) - - A data class used by `validation_alias` as a convenience to create aliases. - - Attributes: - path: A list of string or integer aliases. - """ - - path: list[int | str] - - def __init__(self, first_arg: str, *args: str | int) -> None: - self.path = [first_arg] + list(args) - - def convert_to_aliases(self) -> list[str | int]: - """Converts arguments to a list of string or integer aliases. - - Returns: - The list of aliases. - """ - return self.path - - def search_dict_for_path(self, d: dict) -> Any: - """Searches a dictionary for the path specified by the alias. - - Returns: - The value at the specified path, or `PydanticUndefined` if the path is not found. - """ - v = d - for k in self.path: - if isinstance(v, str): - # disallow indexing into a str, like for AliasPath('x', 0) and x='abc' - return PydanticUndefined - try: - v = v[k] - except (KeyError, IndexError, TypeError): - return PydanticUndefined - return v - - -@dataclasses.dataclass(**_internal_dataclass.slots_true) -class AliasChoices: - """!!! abstract "Usage Documentation" - [`AliasPath` and `AliasChoices`](../concepts/alias.md#aliaspath-and-aliaschoices) - - A data class used by `validation_alias` as a convenience to create aliases. - - Attributes: - choices: A list containing a string or `AliasPath`. - """ - - choices: list[str | AliasPath] - - def __init__(self, first_choice: str | AliasPath, *choices: str | AliasPath) -> None: - self.choices = [first_choice] + list(choices) - - def convert_to_aliases(self) -> list[list[str | int]]: - """Converts arguments to a list of lists containing string or integer aliases. - - Returns: - The list of aliases. 
- """ - aliases: list[list[str | int]] = [] - for c in self.choices: - if isinstance(c, AliasPath): - aliases.append(c.convert_to_aliases()) - else: - aliases.append([c]) - return aliases - - -@dataclasses.dataclass(**_internal_dataclass.slots_true) -class AliasGenerator: - """!!! abstract "Usage Documentation" - [Using an `AliasGenerator`](../concepts/alias.md#using-an-aliasgenerator) - - A data class used by `alias_generator` as a convenience to create various aliases. - - Attributes: - alias: A callable that takes a field name and returns an alias for it. - validation_alias: A callable that takes a field name and returns a validation alias for it. - serialization_alias: A callable that takes a field name and returns a serialization alias for it. - """ - - alias: Callable[[str], str] | None = None - validation_alias: Callable[[str], str | AliasPath | AliasChoices] | None = None - serialization_alias: Callable[[str], str] | None = None - - def _generate_alias( - self, - alias_kind: Literal['alias', 'validation_alias', 'serialization_alias'], - allowed_types: tuple[type[str] | type[AliasPath] | type[AliasChoices], ...], - field_name: str, - ) -> str | AliasPath | AliasChoices | None: - """Generate an alias of the specified kind. Returns None if the alias generator is None. - - Raises: - TypeError: If the alias generator produces an invalid type. - """ - alias = None - if alias_generator := getattr(self, alias_kind): - alias = alias_generator(field_name) - if alias and not isinstance(alias, allowed_types): - raise TypeError( - f'Invalid `{alias_kind}` type. `{alias_kind}` generator must produce one of `{allowed_types}`' - ) - return alias - - def generate_aliases(self, field_name: str) -> tuple[str | None, str | AliasPath | AliasChoices | None, str | None]: - """Generate `alias`, `validation_alias`, and `serialization_alias` for a field. - - Returns: - A tuple of three aliases - validation, alias, and serialization. - """ - alias = self._generate_alias('alias', (str,), field_name) - validation_alias = self._generate_alias('validation_alias', (str, AliasChoices, AliasPath), field_name) - serialization_alias = self._generate_alias('serialization_alias', (str,), field_name) - - return alias, validation_alias, serialization_alias # type: ignore diff --git a/venv/lib/python3.12/site-packages/pydantic/annotated_handlers.py b/venv/lib/python3.12/site-packages/pydantic/annotated_handlers.py index d0cb5d3..59adabf 100644 --- a/venv/lib/python3.12/site-packages/pydantic/annotated_handlers.py +++ b/venv/lib/python3.12/site-packages/pydantic/annotated_handlers.py @@ -1,5 +1,4 @@ """Type annotations to use with `__get_pydantic_core_schema__` and `__get_pydantic_json_schema__`.""" - from __future__ import annotations as _annotations from typing import TYPE_CHECKING, Any, Union @@ -7,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Union from pydantic_core import core_schema if TYPE_CHECKING: - from ._internal._namespace_utils import NamespacesTuple from .json_schema import JsonSchemaMode, JsonSchemaValue CoreSchemaOrField = Union[ @@ -30,7 +28,7 @@ class GetJsonSchemaHandler: mode: JsonSchemaMode - def __call__(self, core_schema: CoreSchemaOrField, /) -> JsonSchemaValue: + def __call__(self, __core_schema: CoreSchemaOrField) -> JsonSchemaValue: """Call the inner handler and get the JsonSchemaValue it returns. 
This will call the next JSON schema modifying function up until it calls into `pydantic.json_schema.GenerateJsonSchema`, which will raise a @@ -38,7 +36,7 @@ class GetJsonSchemaHandler: a JSON schema. Args: - core_schema: A `pydantic_core.core_schema.CoreSchema`. + __core_schema: A `pydantic_core.core_schema.CoreSchema`. Returns: JsonSchemaValue: The JSON schema generated by the inner JSON schema modify @@ -46,13 +44,13 @@ class GetJsonSchemaHandler: """ raise NotImplementedError - def resolve_ref_schema(self, maybe_ref_json_schema: JsonSchemaValue, /) -> JsonSchemaValue: + def resolve_ref_schema(self, __maybe_ref_json_schema: JsonSchemaValue) -> JsonSchemaValue: """Get the real schema for a `{"$ref": ...}` schema. If the schema given is not a `$ref` schema, it will be returned as is. This means you don't have to check before calling this function. Args: - maybe_ref_json_schema: A JsonSchemaValue which may be a `$ref` schema. + __maybe_ref_json_schema: A JsonSchemaValue, ref based or not. Raises: LookupError: If the ref is not found. @@ -66,7 +64,7 @@ class GetJsonSchemaHandler: class GetCoreSchemaHandler: """Handler to call into the next CoreSchema schema generation function.""" - def __call__(self, source_type: Any, /) -> core_schema.CoreSchema: + def __call__(self, __source_type: Any) -> core_schema.CoreSchema: """Call the inner handler and get the CoreSchema it returns. This will call the next CoreSchema modifying function up until it calls into Pydantic's internal schema generation machinery, which will raise a @@ -74,14 +72,14 @@ class GetCoreSchemaHandler: a CoreSchema for the given source type. Args: - source_type: The input type. + __source_type: The input type. Returns: CoreSchema: The `pydantic-core` CoreSchema generated. """ raise NotImplementedError - def generate_schema(self, source_type: Any, /) -> core_schema.CoreSchema: + def generate_schema(self, __source_type: Any) -> core_schema.CoreSchema: """Generate a schema unrelated to the current context. Use this function if e.g. you are handling schema generation for a sequence and want to generate a schema for its items. @@ -89,20 +87,20 @@ class GetCoreSchemaHandler: that was intended for the sequence itself to its items! Args: - source_type: The input type. + __source_type: The input type. Returns: CoreSchema: The `pydantic-core` CoreSchema generated. """ raise NotImplementedError - def resolve_ref_schema(self, maybe_ref_schema: core_schema.CoreSchema, /) -> core_schema.CoreSchema: + def resolve_ref_schema(self, __maybe_ref_schema: core_schema.CoreSchema) -> core_schema.CoreSchema: """Get the real schema for a `definition-ref` schema. If the schema given is not a `definition-ref` schema, it will be returned as is. This means you don't have to check before calling this function. Args: - maybe_ref_schema: A `CoreSchema`, `ref`-based or not. + __maybe_ref_schema: A `CoreSchema`, `ref`-based or not. Raises: LookupError: If the `ref` is not found. 
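A sketch of the usual call pattern for the handler protocol documented above (not part of the patch): a custom type delegates to `handler(...)` for the underlying schema and wraps the result. The `TrimmedStr` type and the strip step are illustrative assumptions:

```python
# Sketch: delegate to the next handler for `str`, then post-process values.
from typing import Any

from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import core_schema


class TrimmedStr(str):
    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # Ask the next handler for the plain `str` schema, then strip whitespace
        # after validation.
        return core_schema.no_info_after_validator_function(str.strip, handler(str))


class Model(BaseModel):
    name: TrimmedStr


print(Model(name='  spam  ').name)
#> spam
```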
@@ -117,6 +115,6 @@ class GetCoreSchemaHandler: """Get the name of the closest field to this validator.""" raise NotImplementedError - def _get_types_namespace(self) -> NamespacesTuple: + def _get_types_namespace(self) -> dict[str, Any] | None: """Internal method used during type resolution for serializer annotations.""" raise NotImplementedError diff --git a/venv/lib/python3.12/site-packages/pydantic/class_validators.py b/venv/lib/python3.12/site-packages/pydantic/class_validators.py index 9977150..2ff72ae 100644 --- a/venv/lib/python3.12/site-packages/pydantic/class_validators.py +++ b/venv/lib/python3.12/site-packages/pydantic/class_validators.py @@ -1,5 +1,4 @@ """`class_validators` module is a backport module from V1.""" - from ._migration import getattr_migration __getattr__ = getattr_migration(__name__) diff --git a/venv/lib/python3.12/site-packages/pydantic/color.py b/venv/lib/python3.12/site-packages/pydantic/color.py index 9a42d58..5aabec4 100644 --- a/venv/lib/python3.12/site-packages/pydantic/color.py +++ b/venv/lib/python3.12/site-packages/pydantic/color.py @@ -11,11 +11,10 @@ Warning: Deprecated See [`pydantic-extra-types.Color`](../usage/types/extra_types/color_types.md) for more information. """ - import math import re from colorsys import hls_to_rgb, rgb_to_hls -from typing import Any, Callable, Optional, Union, cast +from typing import Any, Callable, Optional, Tuple, Type, Union, cast from pydantic_core import CoreSchema, PydanticCustomError, core_schema from typing_extensions import deprecated @@ -25,9 +24,9 @@ from ._internal._schema_generation_shared import GetJsonSchemaHandler as _GetJso from .json_schema import JsonSchemaValue from .warnings import PydanticDeprecatedSince20 -ColorTuple = Union[tuple[int, int, int], tuple[int, int, int, float]] +ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]] ColorType = Union[ColorTuple, str] -HslColorTuple = Union[tuple[float, float, float], tuple[float, float, float, float]] +HslColorTuple = Union[Tuple[float, float, float], Tuple[float, float, float, float]] class RGBA: @@ -41,7 +40,7 @@ class RGBA: self.b = b self.alpha = alpha - self._tuple: tuple[float, float, float, Optional[float]] = (r, g, b, alpha) + self._tuple: Tuple[float, float, float, Optional[float]] = (r, g, b, alpha) def __getitem__(self, item: Any) -> Any: return self._tuple[item] @@ -56,13 +55,13 @@ _r_sl = r'(\d{1,3}(?:\.\d+)?)%' r_hex_short = r'\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*' r_hex_long = r'\s*(?:#|0x)?([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})?\s*' # CSS3 RGB examples: rgb(0, 0, 0), rgba(0, 0, 0, 0.5), rgba(0, 0, 0, 50%) -r_rgb = rf'\s*rgba?\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}(?:{_r_comma}{_r_alpha})?\s*\)\s*' +r_rgb = fr'\s*rgba?\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}(?:{_r_comma}{_r_alpha})?\s*\)\s*' # CSS3 HSL examples: hsl(270, 60%, 50%), hsla(270, 60%, 50%, 0.5), hsla(270, 60%, 50%, 50%) -r_hsl = rf'\s*hsla?\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}(?:{_r_comma}{_r_alpha})?\s*\)\s*' +r_hsl = fr'\s*hsla?\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}(?:{_r_comma}{_r_alpha})?\s*\)\s*' # CSS4 RGB examples: rgb(0 0 0), rgb(0 0 0 / 0.5), rgb(0 0 0 / 50%), rgba(0 0 0 / 50%) -r_rgb_v4_style = rf'\s*rgba?\(\s*{_r_255}\s+{_r_255}\s+{_r_255}(?:\s*/\s*{_r_alpha})?\s*\)\s*' +r_rgb_v4_style = fr'\s*rgba?\(\s*{_r_255}\s+{_r_255}\s+{_r_255}(?:\s*/\s*{_r_alpha})?\s*\)\s*' # CSS4 HSL examples: hsl(270 60% 50%), hsl(270 60% 50% / 0.5), hsl(270 60% 50% / 50%), hsla(270 60% 50% / 50%) 
-r_hsl_v4_style = rf'\s*hsla?\(\s*{_r_h}\s+{_r_sl}\s+{_r_sl}(?:\s*/\s*{_r_alpha})?\s*\)\s*' +r_hsl_v4_style = fr'\s*hsla?\(\s*{_r_h}\s+{_r_sl}\s+{_r_sl}(?:\s*/\s*{_r_alpha})?\s*\)\s*' # colors where the two hex characters are the same, if all colors match this the short version of hex colors can be used repeat_colors = {int(c * 2, 16) for c in '0123456789abcdef'} @@ -124,7 +123,7 @@ class Color(_repr.Representation): ValueError: When no named color is found and fallback is `False`. """ if self._rgba.alpha is None: - rgb = cast(tuple[int, int, int], self.as_rgb_tuple()) + rgb = cast(Tuple[int, int, int], self.as_rgb_tuple()) try: return COLORS_BY_VALUE[rgb] except KeyError as e: @@ -232,7 +231,7 @@ class Color(_repr.Representation): @classmethod def __get_pydantic_core_schema__( - cls, source: type[Any], handler: Callable[[Any], CoreSchema] + cls, source: Type[Any], handler: Callable[[Any], CoreSchema] ) -> core_schema.CoreSchema: return core_schema.with_info_plain_validator_function( cls._validate, serialization=core_schema.to_string_ser_schema() @@ -255,7 +254,7 @@ class Color(_repr.Representation): return hash(self.as_rgb_tuple()) -def parse_tuple(value: tuple[Any, ...]) -> RGBA: +def parse_tuple(value: Tuple[Any, ...]) -> RGBA: """Parse a tuple or list to get RGBA values. Args: diff --git a/venv/lib/python3.12/site-packages/pydantic/config.py b/venv/lib/python3.12/site-packages/pydantic/config.py index 12fef10..ccdcd7f 100644 --- a/venv/lib/python3.12/site-packages/pydantic/config.py +++ b/venv/lib/python3.12/site-packages/pydantic/config.py @@ -1,33 +1,23 @@ """Configuration for Pydantic models.""" - from __future__ import annotations as _annotations -import warnings -from re import Pattern -from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, cast, overload +from typing import TYPE_CHECKING, Any, Callable, Dict, Type, Union -from typing_extensions import TypeAlias, TypedDict, Unpack, deprecated +from typing_extensions import Literal, TypeAlias, TypedDict from ._migration import getattr_migration -from .aliases import AliasGenerator -from .errors import PydanticUserError -from .warnings import PydanticDeprecatedSince211 if TYPE_CHECKING: from ._internal._generate_schema import GenerateSchema as _GenerateSchema - from .fields import ComputedFieldInfo, FieldInfo -__all__ = ('ConfigDict', 'with_config') +__all__ = ('ConfigDict',) -JsonValue: TypeAlias = Union[int, float, str, bool, None, list['JsonValue'], 'JsonDict'] -JsonDict: TypeAlias = dict[str, JsonValue] - JsonEncoder = Callable[[Any], Any] JsonSchemaExtraCallable: TypeAlias = Union[ - Callable[[JsonDict], None], - Callable[[JsonDict, type[Any]], None], + Callable[[Dict[str, Any]], None], + Callable[[Dict[str, Any], Type[Any]], None], ] ExtraValues = Literal['allow', 'ignore', 'forbid'] @@ -39,18 +29,11 @@ class ConfigDict(TypedDict, total=False): title: str | None """The title for the generated JSON schema, defaults to the model's name""" - model_title_generator: Callable[[type], str] | None - """A callable that takes a model class and returns the title for it. Defaults to `None`.""" - - field_title_generator: Callable[[str, FieldInfo | ComputedFieldInfo], str] | None - """A callable that takes a field's name and info and returns title for it. Defaults to `None`.""" - str_to_lower: bool """Whether to convert all characters to lowercase for str types. Defaults to `False`.""" str_to_upper: bool """Whether to convert all characters to uppercase for str types. 
Defaults to `False`.""" - str_strip_whitespace: bool """Whether to strip leading and trailing whitespace for str types.""" @@ -61,108 +44,84 @@ class ConfigDict(TypedDict, total=False): """The maximum length for str types. Defaults to `None`.""" extra: ExtraValues | None - ''' - Whether to ignore, allow, or forbid extra data during model initialization. Defaults to `'ignore'`. + """ + Whether to ignore, allow, or forbid extra attributes during model initialization. Defaults to `'ignore'`. - Three configuration values are available: + You can configure how pydantic handles the attributes that are not defined in the model: - - `'ignore'`: Providing extra data is ignored (the default): - ```python - from pydantic import BaseModel, ConfigDict + * `allow` - Allow any extra attributes. + * `forbid` - Forbid any extra attributes. + * `ignore` - Ignore any extra attributes. - class User(BaseModel): - model_config = ConfigDict(extra='ignore') # (1)! - - name: str - - user = User(name='John Doe', age=20) # (2)! - print(user) - #> name='John Doe' - ``` - - 1. This is the default behaviour. - 2. The `age` argument is ignored. - - - `'forbid'`: Providing extra data is not permitted, and a [`ValidationError`][pydantic_core.ValidationError] - will be raised if this is the case: - ```python - from pydantic import BaseModel, ConfigDict, ValidationError + ```py + from pydantic import BaseModel, ConfigDict - class Model(BaseModel): - x: int + class User(BaseModel): + model_config = ConfigDict(extra='ignore') # (1)! - model_config = ConfigDict(extra='forbid') + name: str - try: - Model(x=1, y='a') - except ValidationError as exc: - print(exc) - """ - 1 validation error for Model - y - Extra inputs are not permitted [type=extra_forbidden, input_value='a', input_type=str] - """ - ``` + user = User(name='John Doe', age=20) # (2)! + print(user) + #> name='John Doe' + ``` - - `'allow'`: Providing extra data is allowed and stored in the `__pydantic_extra__` dictionary attribute: - ```python - from pydantic import BaseModel, ConfigDict + 1. This is the default behaviour. + 2. The `age` argument is ignored. + + Instead, with `extra='allow'`, the `age` argument is included: + + ```py + from pydantic import BaseModel, ConfigDict - class Model(BaseModel): - x: int + class User(BaseModel): + model_config = ConfigDict(extra='allow') - model_config = ConfigDict(extra='allow') + name: str - m = Model(x=1, y='a') - assert m.__pydantic_extra__ == {'y': 'a'} - ``` - By default, no validation will be applied to these extra items, but you can set a type for the values by overriding - the type annotation for `__pydantic_extra__`: - ```python - from pydantic import BaseModel, ConfigDict, Field, ValidationError + user = User(name='John Doe', age=20) # (1)! + print(user) + #> name='John Doe' age=20 + ``` + + 1. The `age` argument is included. + + With `extra='forbid'`, an error is raised: + + ```py + from pydantic import BaseModel, ConfigDict, ValidationError - class Model(BaseModel): - __pydantic_extra__: dict[str, int] = Field(init=False) # (1)! 
+ class User(BaseModel): + model_config = ConfigDict(extra='forbid') - x: int - - model_config = ConfigDict(extra='allow') + name: str - try: - Model(x=1, y='a') - except ValidationError as exc: - print(exc) - """ - 1 validation error for Model - y - Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='a', input_type=str] - """ - - m = Model(x=1, y='2') - assert m.x == 1 - assert m.y == 2 - assert m.model_dump() == {'x': 1, 'y': 2} - assert m.__pydantic_extra__ == {'y': 2} - ``` - - 1. The `= Field(init=False)` does not have any effect at runtime, but prevents the `__pydantic_extra__` field from - being included as a parameter to the model's `__init__` method by type checkers. - ''' + try: + User(name='John Doe', age=20) + except ValidationError as e: + print(e) + ''' + 1 validation error for User + age + Extra inputs are not permitted [type=extra_forbidden, input_value=20, input_type=int] + ''' + ``` + """ frozen: bool """ - Whether models are faux-immutable, i.e. whether `__setattr__` is allowed, and also generates + Whether or not models are faux-immutable, i.e. whether `__setattr__` is allowed, and also generates a `__hash__()` method for the model. This makes instances of the model potentially hashable if all the attributes are hashable. Defaults to `False`. Note: - On V1, the inverse of this setting was called `allow_mutation`, and was `True` by default. + On V1, this setting was called `allow_mutation`, and was `True` by default. """ populate_by_name: bool @@ -170,77 +129,38 @@ class ConfigDict(TypedDict, total=False): Whether an aliased field may be populated by its name as given by the model attribute, as well as the alias. Defaults to `False`. - !!! warning - `populate_by_name` usage is not recommended in v2.11+ and will be deprecated in v3. - Instead, you should use the [`validate_by_name`][pydantic.config.ConfigDict.validate_by_name] configuration setting. + Note: + The name of this configuration setting was changed in **v2.0** from + `allow_population_by_alias` to `populate_by_name`. - When `validate_by_name=True` and `validate_by_alias=True`, this is strictly equivalent to the - previous behavior of `populate_by_name=True`. + ```py + from pydantic import BaseModel, ConfigDict, Field - In v2.11, we also introduced a [`validate_by_alias`][pydantic.config.ConfigDict.validate_by_alias] setting that introduces more fine grained - control for validation behavior. - Here's how you might go about using the new settings to achieve the same behavior: + class User(BaseModel): + model_config = ConfigDict(populate_by_name=True) - ```python - from pydantic import BaseModel, ConfigDict, Field + name: str = Field(alias='full_name') # (1)! + age: int - class Model(BaseModel): - model_config = ConfigDict(validate_by_name=True, validate_by_alias=True) - my_field: str = Field(alias='my_alias') # (1)! + user = User(full_name='John Doe', age=20) # (2)! + print(user) + #> name='John Doe' age=20 + user = User(name='John Doe', age=20) # (3)! + print(user) + #> name='John Doe' age=20 + ``` - m = Model(my_alias='foo') # (2)! - print(m) - #> my_field='foo' - - m = Model(my_alias='foo') # (3)! - print(m) - #> my_field='foo' - ``` - - 1. The field `'my_field'` has an alias `'my_alias'`. - 2. The model is populated by the alias `'my_alias'`. - 3. The model is populated by the attribute name `'my_field'`. + 1. The field `'name'` has an alias `'full_name'`. + 2. The model is populated by the alias `'full_name'`. + 3. 
The model is populated by the field name `'name'`. """ use_enum_values: bool """ Whether to populate models with the `value` property of enums, rather than the raw enum. This may be useful if you want to serialize `model.model_dump()` later. Defaults to `False`. - - !!! note - If you have an `Optional[Enum]` value that you set a default for, you need to use `validate_default=True` - for said Field to ensure that the `use_enum_values` flag takes effect on the default, as extracting an - enum's value occurs during validation, not serialization. - - ```python - from enum import Enum - from typing import Optional - - from pydantic import BaseModel, ConfigDict, Field - - class SomeEnum(Enum): - FOO = 'foo' - BAR = 'bar' - BAZ = 'baz' - - class SomeModel(BaseModel): - model_config = ConfigDict(use_enum_values=True) - - some_enum: SomeEnum - another_enum: Optional[SomeEnum] = Field( - default=SomeEnum.FOO, validate_default=True - ) - - model1 = SomeModel(some_enum=SomeEnum.BAR) - print(model1.model_dump()) - #> {'some_enum': 'bar', 'another_enum': 'foo'} - - model2 = SomeModel(some_enum=SomeEnum.BAR, another_enum=SomeEnum.BAZ) - print(model2.model_dump()) - #> {'some_enum': 'bar', 'another_enum': 'baz'} - ``` """ validate_assignment: bool @@ -251,7 +171,7 @@ class ConfigDict(TypedDict, total=False): In case the user changes the data after the model is created, the model is _not_ revalidated. - ```python + ```py from pydantic import BaseModel class User(BaseModel): @@ -270,7 +190,7 @@ class ConfigDict(TypedDict, total=False): In case you want to revalidate the model when the data is changed, you can use `validate_assignment=True`: - ```python + ```py from pydantic import BaseModel, ValidationError class User(BaseModel, validate_assignment=True): # (1)! @@ -299,7 +219,7 @@ class ConfigDict(TypedDict, total=False): """ Whether arbitrary types are allowed for field types. Defaults to `False`. - ```python + ```py from pydantic import BaseModel, ConfigDict, ValidationError # This is not a pydantic model, it's an arbitrary class @@ -358,20 +278,14 @@ class ConfigDict(TypedDict, total=False): loc_by_alias: bool """Whether to use the actual key provided in the data (e.g. alias) for error `loc`s rather than the field's name. Defaults to `True`.""" - alias_generator: Callable[[str], str] | AliasGenerator | None + alias_generator: Callable[[str], str] | None """ - A callable that takes a field name and returns an alias for it - or an instance of [`AliasGenerator`][pydantic.aliases.AliasGenerator]. Defaults to `None`. - - When using a callable, the alias generator is used for both validation and serialization. - If you want to use different alias generators for validation and serialization, you can use - [`AliasGenerator`][pydantic.aliases.AliasGenerator] instead. + A callable that takes a field name and returns an alias for it. If data source field names do not match your code style (e. g. CamelCase fields), - you can automatically generate aliases using `alias_generator`. Here's an example with - a basic callable: + you can automatically generate aliases using `alias_generator`: - ```python + ```py from pydantic import BaseModel, ConfigDict from pydantic.alias_generators import to_pascal @@ -388,30 +302,6 @@ class ConfigDict(TypedDict, total=False): #> {'Name': 'Filiz', 'LanguageCode': 'tr-TR'} ``` - If you want to use different alias generators for validation and serialization, you can use - [`AliasGenerator`][pydantic.aliases.AliasGenerator]. 
- - ```python - from pydantic import AliasGenerator, BaseModel, ConfigDict - from pydantic.alias_generators import to_camel, to_pascal - - class Athlete(BaseModel): - first_name: str - last_name: str - sport: str - - model_config = ConfigDict( - alias_generator=AliasGenerator( - validation_alias=to_camel, - serialization_alias=to_pascal, - ) - ) - - athlete = Athlete(firstName='John', lastName='Doe', sport='track') - print(athlete.model_dump(by_alias=True)) - #> {'FirstName': 'John', 'LastName': 'Doe', 'Sport': 'track'} - ``` - Note: Pydantic offers three built-in alias generators: [`to_pascal`][pydantic.alias_generators.to_pascal], [`to_camel`][pydantic.alias_generators.to_camel], and [`to_snake`][pydantic.alias_generators.to_snake]. @@ -425,9 +315,9 @@ class ConfigDict(TypedDict, total=False): """ allow_inf_nan: bool - """Whether to allow infinity (`+inf` an `-inf`) and NaN values to float and decimal fields. Defaults to `True`.""" + """Whether to allow infinity (`+inf` an `-inf`) and NaN values to float fields. Defaults to `True`.""" - json_schema_extra: JsonDict | JsonSchemaExtraCallable | None + json_schema_extra: dict[str, object] | JsonSchemaExtraCallable | None """A dict or callable to provide extra JSON schema properties. Defaults to `None`.""" json_encoders: dict[type[object], JsonEncoder] | None @@ -452,7 +342,7 @@ class ConfigDict(TypedDict, total=False): To configure strict mode for all fields on a model, you can set `strict=True` on the model. - ```python + ```py from pydantic import BaseModel, ConfigDict class Model(BaseModel): @@ -480,14 +370,16 @@ class ConfigDict(TypedDict, total=False): By default, model and dataclass instances are not revalidated during validation. - ```python + ```py + from typing import List + from pydantic import BaseModel class User(BaseModel, revalidate_instances='never'): # (1)! - hobbies: list[str] + hobbies: List[str] class SubUser(User): - sins: list[str] + sins: List[str] class Transaction(BaseModel): user: User @@ -515,14 +407,16 @@ class ConfigDict(TypedDict, total=False): If you want to revalidate instances during validation, you can set `revalidate_instances` to `'always'` in the model's config. - ```python + ```py + from typing import List + from pydantic import BaseModel, ValidationError class User(BaseModel, revalidate_instances='always'): # (1)! - hobbies: list[str] + hobbies: List[str] class SubUser(User): - sins: list[str] + sins: List[str] class Transaction(BaseModel): user: User @@ -556,14 +450,16 @@ class ConfigDict(TypedDict, total=False): It's also possible to set `revalidate_instances` to `'subclass-instances'` to only revalidate instances of subclasses of the model. - ```python + ```py + from typing import List + from pydantic import BaseModel class User(BaseModel, revalidate_instances='subclass-instances'): # (1)! - hobbies: list[str] + hobbies: List[str] class SubUser(User): - sins: list[str] + sins: List[str] class Transaction(BaseModel): user: User @@ -598,33 +494,13 @@ class ConfigDict(TypedDict, total=False): - `'float'` will serialize timedeltas to the total number of seconds. """ - ser_json_bytes: Literal['utf8', 'base64', 'hex'] + ser_json_bytes: Literal['utf8', 'base64'] """ - The encoding of JSON serialized bytes. Defaults to `'utf8'`. - Set equal to `val_json_bytes` to get back an equal value after serialization round trip. + The encoding of JSON serialized bytes. Accepts the string values of `'utf8'` and `'base64'`. + Defaults to `'utf8'`. - `'utf8'` will serialize bytes to UTF-8 strings. 
- `'base64'` will serialize bytes to URL safe base64 strings. - - `'hex'` will serialize bytes to hexadecimal strings. - """ - - val_json_bytes: Literal['utf8', 'base64', 'hex'] - """ - The encoding of JSON serialized bytes to decode. Defaults to `'utf8'`. - Set equal to `ser_json_bytes` to get back an equal value after serialization round trip. - - - `'utf8'` will deserialize UTF-8 strings to bytes. - - `'base64'` will deserialize URL safe base64 strings to bytes. - - `'hex'` will deserialize hexadecimal strings to bytes. - """ - - ser_json_inf_nan: Literal['null', 'constants', 'strings'] - """ - The encoding of JSON serialized infinity and NaN float values. Defaults to `'null'`. - - - `'null'` will serialize infinity and NaN values as `null`. - - `'constants'` will serialize infinity and NaN values as `Infinity` and `NaN`. - - `'strings'` will serialize infinity as string `"Infinity"` and NaN as string `"NaN"`. """ # whether to validate default values during validation, default False @@ -632,26 +508,17 @@ class ConfigDict(TypedDict, total=False): """Whether to validate default values during validation. Defaults to `False`.""" validate_return: bool - """Whether to validate the return value from call validators. Defaults to `False`.""" + """whether to validate the return value from call validators. Defaults to `False`.""" - protected_namespaces: tuple[str | Pattern[str], ...] + protected_namespaces: tuple[str, ...] """ - A `tuple` of strings and/or patterns that prevent models from having fields with names that conflict with them. - For strings, we match on a prefix basis. Ex, if 'dog' is in the protected namespace, 'dog_name' will be protected. - For patterns, we match on the entire field name. Ex, if `re.compile(r'^dog$')` is in the protected namespace, 'dog' will be protected, but 'dog_name' will not be. - Defaults to `('model_validate', 'model_dump',)`. + A `tuple` of strings that prevent model to have field which conflict with them. + Defaults to `('model_', )`). - The reason we've selected these is to prevent collisions with other validation / dumping formats - in the future - ex, `model_validate_{some_newly_supported_format}`. + Pydantic prevents collisions between model attributes and `BaseModel`'s own methods by + namespacing them with the prefix `model_`. - Before v2.10, Pydantic used `('model_',)` as the default value for this setting to - prevent collisions between model attributes and `BaseModel`'s own methods. This was changed - in v2.10 given feedback that this restriction was limiting in AI and data science contexts, - where it is common to have fields with names like `model_id`, `model_input`, `model_output`, etc. - - For more details, see https://github.com/pydantic/pydantic/issues/10315. - - ```python + ```py import warnings from pydantic import BaseModel @@ -661,65 +528,56 @@ class ConfigDict(TypedDict, total=False): try: class Model(BaseModel): - model_dump_something: str + model_prefixed_field: str except UserWarning as e: print(e) ''' - Field "model_dump_something" in Model has conflict with protected namespace "model_dump". + Field "model_prefixed_field" has conflict with protected namespace "model_". - You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ('model_validate',)`. + You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ()`. 
''' ``` You can customize this behavior using the `protected_namespaces` setting: - ```python {test="skip"} - import re + ```py import warnings from pydantic import BaseModel, ConfigDict - with warnings.catch_warnings(record=True) as caught_warnings: - warnings.simplefilter('always') # Catch all warnings + warnings.filterwarnings('error') # Raise warnings as errors + + try: class Model(BaseModel): - safe_field: str + model_prefixed_field: str also_protect_field: str - protect_this: str model_config = ConfigDict( - protected_namespaces=( - 'protect_me_', - 'also_protect_', - re.compile('^protect_this$'), - ) + protected_namespaces=('protect_me_', 'also_protect_') ) - for warning in caught_warnings: - print(f'{warning.message}') + except UserWarning as e: + print(e) ''' - Field "also_protect_field" in Model has conflict with protected namespace "also_protect_". - You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ('protect_me_', re.compile('^protect_this$'))`. + Field "also_protect_field" has conflict with protected namespace "also_protect_". - Field "protect_this" in Model has conflict with protected namespace "re.compile('^protect_this$')". - You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ('protect_me_', 'also_protect_')`. + You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ('protect_me_',)`. ''' ``` While Pydantic will only emit a warning when an item is in a protected namespace but does not actually have a collision, an error _is_ raised if there is an actual collision with an existing attribute: - ```python - from pydantic import BaseModel, ConfigDict + ```py + from pydantic import BaseModel try: class Model(BaseModel): model_validate: str - model_config = ConfigDict(protected_namespaces=('model_',)) - except NameError as e: print(e) ''' @@ -734,7 +592,7 @@ class ConfigDict(TypedDict, total=False): Pydantic shows the input value and type when it raises `ValidationError` during the validation. - ```python + ```py from pydantic import BaseModel, ValidationError class Model(BaseModel): @@ -753,7 +611,7 @@ class ConfigDict(TypedDict, total=False): You can hide the input value and type by setting the `hide_input_in_errors` config to `True`. - ```python + ```py from pydantic import BaseModel, ConfigDict, ValidationError class Model(BaseModel): @@ -774,26 +632,27 @@ class ConfigDict(TypedDict, total=False): defer_build: bool """ - Whether to defer model validator and serializer construction until the first model validation. Defaults to False. + Whether to defer model validator and serializer construction until the first model validation. This can be useful to avoid the overhead of building models which are only used nested within other models, or when you want to manually define type namespace via - [`Model.model_rebuild(_types_namespace=...)`][pydantic.BaseModel.model_rebuild]. - - Since v2.10, this setting also applies to pydantic dataclasses and TypeAdapter instances. + [`Model.model_rebuild(_types_namespace=...)`][pydantic.BaseModel.model_rebuild]. Defaults to False. """ plugin_settings: dict[str, object] | None - """A `dict` of settings for plugins. Defaults to `None`.""" + """A `dict` of settings for plugins. Defaults to `None`. + + See [Pydantic Plugins](../concepts/plugins.md) for details. + """ schema_generator: type[_GenerateSchema] | None """ - !!! warning - `schema_generator` is deprecated in v2.10. 
+ A custom core schema generator class to use when generating JSON schemas. + Useful if you want to change the way types are validated across an entire model/schema. Defaults to `None`. - Prior to v2.10, this setting was advertised as highly subject to change. - It's possible that this interface may once again become public once the internal core schema generation - API is more stable, but that will likely come after significant performance improvements have been made. + The `GenerateSchema` interface is subject to change, currently only the `string_schema` method is public. + + See [#6737](https://github.com/pydantic/pydantic/pull/6737) for details. """ json_schema_serialization_defaults_required: bool @@ -807,7 +666,7 @@ class ConfigDict(TypedDict, total=False): between validation and serialization, and don't mind fields with defaults being marked as not required during serialization. See [#7209](https://github.com/pydantic/pydantic/issues/7209) for more details. - ```python + ```py from pydantic import BaseModel, ConfigDict class Model(BaseModel): @@ -850,7 +709,7 @@ class ConfigDict(TypedDict, total=False): the validation and serialization schemas (since both will use the specified schema), and so prevents the suffixes from being added to the definition references. - ```python + ```py from pydantic import BaseModel, ConfigDict, Json class Model(BaseModel): @@ -896,7 +755,7 @@ class ConfigDict(TypedDict, total=False): Pydantic doesn't allow number types (`int`, `float`, `Decimal`) to be coerced as type `str` by default. - ```python + ```py from decimal import Decimal from pydantic import BaseModel, ConfigDict, ValidationError @@ -928,286 +787,5 @@ class ConfigDict(TypedDict, total=False): ``` """ - regex_engine: Literal['rust-regex', 'python-re'] - """ - The regex engine to be used for pattern validation. - Defaults to `'rust-regex'`. - - - `rust-regex` uses the [`regex`](https://docs.rs/regex) Rust crate, - which is non-backtracking and therefore more DDoS resistant, but does not support all regex features. - - `python-re` use the [`re`](https://docs.python.org/3/library/re.html) module, - which supports all regex features, but may be slower. - - !!! note - If you use a compiled regex pattern, the python-re engine will be used regardless of this setting. - This is so that flags such as `re.IGNORECASE` are respected. - - ```python - from pydantic import BaseModel, ConfigDict, Field, ValidationError - - class Model(BaseModel): - model_config = ConfigDict(regex_engine='python-re') - - value: str = Field(pattern=r'^abc(?=def)') - - print(Model(value='abcdef').value) - #> abcdef - - try: - print(Model(value='abxyzcdef')) - except ValidationError as e: - print(e) - ''' - 1 validation error for Model - value - String should match pattern '^abc(?=def)' [type=string_pattern_mismatch, input_value='abxyzcdef', input_type=str] - ''' - ``` - """ - - validation_error_cause: bool - """ - If `True`, Python exceptions that were part of a validation failure will be shown as an exception group as a cause. Can be useful for debugging. Defaults to `False`. - - Note: - Python 3.10 and older don't support exception groups natively. <=3.10, backport must be installed: `pip install exceptiongroup`. - - Note: - The structure of validation errors are likely to change in future Pydantic versions. Pydantic offers no guarantees about their structure. Should be used for visual traceback debugging only. 
- """ - - use_attribute_docstrings: bool - ''' - Whether docstrings of attributes (bare string literals immediately following the attribute declaration) - should be used for field descriptions. Defaults to `False`. - - Available in Pydantic v2.7+. - - ```python - from pydantic import BaseModel, ConfigDict, Field - - - class Model(BaseModel): - model_config = ConfigDict(use_attribute_docstrings=True) - - x: str - """ - Example of an attribute docstring - """ - - y: int = Field(description="Description in Field") - """ - Description in Field overrides attribute docstring - """ - - - print(Model.model_fields["x"].description) - # > Example of an attribute docstring - print(Model.model_fields["y"].description) - # > Description in Field - ``` - This requires the source code of the class to be available at runtime. - - !!! warning "Usage with `TypedDict` and stdlib dataclasses" - Due to current limitations, attribute docstrings detection may not work as expected when using - [`TypedDict`][typing.TypedDict] and stdlib dataclasses, in particular when: - - - inheritance is being used. - - multiple classes have the same name in the same source file. - ''' - - cache_strings: bool | Literal['all', 'keys', 'none'] - """ - Whether to cache strings to avoid constructing new Python objects. Defaults to True. - - Enabling this setting should significantly improve validation performance while increasing memory usage slightly. - - - `True` or `'all'` (the default): cache all strings - - `'keys'`: cache only dictionary keys - - `False` or `'none'`: no caching - - !!! note - `True` or `'all'` is required to cache strings during general validation because - validators don't know if they're in a key or a value. - - !!! tip - If repeated strings are rare, it's recommended to use `'keys'` or `'none'` to reduce memory usage, - as the performance difference is minimal if repeated strings are rare. - """ - - validate_by_alias: bool - """ - Whether an aliased field may be populated by its alias. Defaults to `True`. - - !!! note - In v2.11, `validate_by_alias` was introduced in conjunction with [`validate_by_name`][pydantic.ConfigDict.validate_by_name] - to empower users with more fine grained validation control. In my_field='foo' - ``` - - 1. The field `'my_field'` has an alias `'my_alias'`. - 2. The model can only be populated by the attribute name `'my_field'`. - - !!! warning - You cannot set both `validate_by_alias` and `validate_by_name` to `False`. - This would make it impossible to populate an attribute. - - See [usage errors](../errors/usage_errors.md#validate-by-alias-and-name-false) for an example. - - If you set `validate_by_alias` to `False`, under the hood, Pydantic dynamically sets - `validate_by_name` to `True` to ensure that validation can still occur. - """ - - validate_by_name: bool - """ - Whether an aliased field may be populated by its name as given by the model - attribute. Defaults to `False`. - - !!! note - In v2.0-v2.10, the `populate_by_name` configuration setting was used to specify - whether or not a field could be populated by its name **and** alias. - - In v2.11, `validate_by_name` was introduced in conjunction with [`validate_by_alias`][pydantic.ConfigDict.validate_by_alias] - to empower users with more fine grained validation behavior control. - - ```python - from pydantic import BaseModel, ConfigDict, Field - - class Model(BaseModel): - model_config = ConfigDict(validate_by_name=True, validate_by_alias=True) - - my_field: str = Field(validation_alias='my_alias') # (1)! 
- - m = Model(my_alias='foo') # (2)! - print(m) - #> my_field='foo' - - m = Model(my_field='foo') # (3)! - print(m) - #> my_field='foo' - ``` - - 1. The field `'my_field'` has an alias `'my_alias'`. - 2. The model is populated by the alias `'my_alias'`. - 3. The model is populated by the attribute name `'my_field'`. - - !!! warning - You cannot set both `validate_by_alias` and `validate_by_name` to `False`. - This would make it impossible to populate an attribute. - - See [usage errors](../errors/usage_errors.md#validate-by-alias-and-name-false) for an example. - """ - - serialize_by_alias: bool - """ - Whether an aliased field should be serialized by its alias. Defaults to `False`. - - Note: In v2.11, `serialize_by_alias` was introduced to address the - [popular request](https://github.com/pydantic/pydantic/issues/8379) - for consistency with alias behavior for validation and serialization settings. - In v3, the default value is expected to change to `True` for consistency with the validation default. - - ```python - from pydantic import BaseModel, ConfigDict, Field - - class Model(BaseModel): - model_config = ConfigDict(serialize_by_alias=True) - - my_field: str = Field(serialization_alias='my_alias') # (1)! - - m = Model(my_field='foo') - print(m.model_dump()) # (2)! - #> {'my_alias': 'foo'} - ``` - - 1. The field `'my_field'` has an alias `'my_alias'`. - 2. The model is serialized using the alias `'my_alias'` for the `'my_field'` attribute. - """ - - -_TypeT = TypeVar('_TypeT', bound=type) - - -@overload -@deprecated('Passing `config` as a keyword argument is deprecated. Pass `config` as a positional argument instead.') -def with_config(*, config: ConfigDict) -> Callable[[_TypeT], _TypeT]: ... - - -@overload -def with_config(config: ConfigDict, /) -> Callable[[_TypeT], _TypeT]: ... - - -@overload -def with_config(**config: Unpack[ConfigDict]) -> Callable[[_TypeT], _TypeT]: ... - - -def with_config(config: ConfigDict | None = None, /, **kwargs: Any) -> Callable[[_TypeT], _TypeT]: - """!!! abstract "Usage Documentation" - [Configuration with other types](../concepts/config.md#configuration-on-other-supported-types) - - A convenience decorator to set a [Pydantic configuration](config.md) on a `TypedDict` or a `dataclass` from the standard library. - - Although the configuration can be set using the `__pydantic_config__` attribute, it does not play well with type checkers, - especially with `TypedDict`. - - !!! example "Usage" - - ```python - from typing_extensions import TypedDict - - from pydantic import ConfigDict, TypeAdapter, with_config - - @with_config(ConfigDict(str_to_lower=True)) - class TD(TypedDict): - x: str - - ta = TypeAdapter(TD) - - print(ta.validate_python({'x': 'ABC'})) - #> {'x': 'abc'} - ``` - """ - if config is not None and kwargs: - raise ValueError('Cannot specify both `config` and keyword arguments') - - if len(kwargs) == 1 and (kwargs_conf := kwargs.get('config')) is not None: - warnings.warn( - 'Passing `config` as a keyword argument is deprecated. Pass `config` as a positional argument instead', - category=PydanticDeprecatedSince211, - stacklevel=2, - ) - final_config = cast(ConfigDict, kwargs_conf) - else: - final_config = config if config is not None else cast(ConfigDict, kwargs) - - def inner(class_: _TypeT, /) -> _TypeT: - # Ideally, we would check for `class_` to either be a `TypedDict` or a stdlib dataclass. - # However, the `@with_config` decorator can be applied *after* `@dataclass`. 
To avoid - # common mistakes, we at least check for `class_` to not be a Pydantic model. - from ._internal._utils import is_model_class - - if is_model_class(class_): - raise PydanticUserError( - f'Cannot use `with_config` on {class_.__name__} as it is a Pydantic model', - code='with-config-on-model', - ) - class_.__pydantic_config__ = final_config - return class_ - - return inner - __getattr__ = getattr_migration(__name__) diff --git a/venv/lib/python3.12/site-packages/pydantic/dataclasses.py b/venv/lib/python3.12/site-packages/pydantic/dataclasses.py index 4e42d65..11e1614 100644 --- a/venv/lib/python3.12/site-packages/pydantic/dataclasses.py +++ b/venv/lib/python3.12/site-packages/pydantic/dataclasses.py @@ -1,25 +1,21 @@ """Provide an enhanced dataclass that performs validation.""" - from __future__ import annotations as _annotations import dataclasses import sys import types -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, NoReturn, TypeVar, overload -from warnings import warn +from typing import TYPE_CHECKING, Any, Callable, Generic, NoReturn, TypeVar, overload -from typing_extensions import TypeGuard, dataclass_transform +from typing_extensions import Literal, TypeGuard, dataclass_transform -from ._internal import _config, _decorators, _namespace_utils, _typing_extra +from ._internal import _config, _decorators, _typing_extra from ._internal import _dataclasses as _pydantic_dataclasses from ._migration import getattr_migration from .config import ConfigDict -from .errors import PydanticUserError -from .fields import Field, FieldInfo, PrivateAttr +from .fields import Field if TYPE_CHECKING: from ._internal._dataclasses import PydanticDataclass - from ._internal._namespace_utils import MappingNamespace __all__ = 'dataclass', 'rebuild_dataclass' @@ -27,7 +23,7 @@ _T = TypeVar('_T') if sys.version_info >= (3, 10): - @dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr)) + @dataclass_transform(field_specifiers=(dataclasses.field, Field)) @overload def dataclass( *, @@ -44,7 +40,7 @@ if sys.version_info >= (3, 10): ) -> Callable[[type[_T]], type[PydanticDataclass]]: # type: ignore ... - @dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr)) + @dataclass_transform(field_specifiers=(dataclasses.field, Field)) @overload def dataclass( _cls: type[_T], # type: ignore @@ -54,16 +50,17 @@ if sys.version_info >= (3, 10): eq: bool = True, order: bool = False, unsafe_hash: bool = False, - frozen: bool | None = None, + frozen: bool = False, config: ConfigDict | type[object] | None = None, validate_on_init: bool | None = None, kw_only: bool = ..., slots: bool = ..., - ) -> type[PydanticDataclass]: ... + ) -> type[PydanticDataclass]: + ... else: - @dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr)) + @dataclass_transform(field_specifiers=(dataclasses.field, Field)) @overload def dataclass( *, @@ -72,13 +69,13 @@ else: eq: bool = True, order: bool = False, unsafe_hash: bool = False, - frozen: bool | None = None, + frozen: bool = False, config: ConfigDict | type[object] | None = None, validate_on_init: bool | None = None, ) -> Callable[[type[_T]], type[PydanticDataclass]]: # type: ignore ... 
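For context on the `with_config` helper changed above, a minimal sketch of applying it to a stdlib dataclass (the docstring itself shows the `TypedDict` case); this assumes a pydantic v2 release that exposes `with_config`, and the class name is illustrative.

```python
# Sketch: with_config sets __pydantic_config__ on a stdlib dataclass,
# which TypeAdapter then honors during validation.
import dataclasses

from pydantic import ConfigDict, TypeAdapter, with_config


@with_config(ConfigDict(str_to_lower=True))
@dataclasses.dataclass
class Item:
    code: str


ta = TypeAdapter(Item)
print(ta.validate_python({'code': 'ABC'}))  # Item(code='abc')
```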
- @dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr)) + @dataclass_transform(field_specifiers=(dataclasses.field, Field)) @overload def dataclass( _cls: type[_T], # type: ignore @@ -88,13 +85,14 @@ else: eq: bool = True, order: bool = False, unsafe_hash: bool = False, - frozen: bool | None = None, + frozen: bool = False, config: ConfigDict | type[object] | None = None, validate_on_init: bool | None = None, - ) -> type[PydanticDataclass]: ... + ) -> type[PydanticDataclass]: + ... -@dataclass_transform(field_specifiers=(dataclasses.field, Field, PrivateAttr)) +@dataclass_transform(field_specifiers=(dataclasses.field, Field)) def dataclass( _cls: type[_T] | None = None, *, @@ -103,14 +101,13 @@ def dataclass( eq: bool = True, order: bool = False, unsafe_hash: bool = False, - frozen: bool | None = None, + frozen: bool = False, config: ConfigDict | type[object] | None = None, validate_on_init: bool | None = None, kw_only: bool = False, slots: bool = False, ) -> Callable[[type[_T]], type[PydanticDataclass]] | type[PydanticDataclass]: - """!!! abstract "Usage Documentation" - [`dataclasses`](../concepts/dataclasses.md) + """Usage docs: https://docs.pydantic.dev/2.4/concepts/dataclasses/ A decorator used to create a Pydantic-enhanced dataclass, similar to the standard Python `dataclass`, but with added validation. @@ -122,13 +119,13 @@ def dataclass( init: Included for signature compatibility with `dataclasses.dataclass`, and is passed through to `dataclasses.dataclass` when appropriate. If specified, must be set to `False`, as pydantic inserts its own `__init__` function. - repr: A boolean indicating whether to include the field in the `__repr__` output. - eq: Determines if a `__eq__` method should be generated for the class. + repr: A boolean indicating whether or not to include the field in the `__repr__` output. + eq: Determines if a `__eq__` should be generated for the class. order: Determines if comparison magic methods should be generated, such as `__lt__`, but not `__eq__`. - unsafe_hash: Determines if a `__hash__` method should be included in the class, as in `dataclasses.dataclass`. + unsafe_hash: Determines if an unsafe hashing function should be included in the class. frozen: Determines if the generated class should be a 'frozen' `dataclass`, which does not allow its - attributes to be modified after it has been initialized. If not set, the value from the provided `config` argument will be used (and will default to `False` otherwise). - config: The Pydantic config to use for the `dataclass`. + attributes to be modified from its constructor. + config: A configuration for the `dataclass` generation. validate_on_init: A deprecated parameter included for backwards compatibility; in V2, all Pydantic dataclasses are validated on init. kw_only: Determines if `__init__` method parameters must be specified by keyword only. Defaults to `False`. 
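As a reminder of what the decorator diffed above provides, here is a minimal sketch of a Pydantic dataclass validating and coercing its fields on construction; names are illustrative.

```python
# Unlike dataclasses.dataclass, the Pydantic variant validates input.
import pydantic
from pydantic.dataclasses import dataclass


@dataclass
class Point:
    x: int
    y: int = 0


print(Point(x='1'))  # Point(x=1, y=0) -- '1' is coerced to int

try:
    Point(x='not a number')
except pydantic.ValidationError as exc:
    print(exc.errors()[0]['type'])  # 'int_parsing'
```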
@@ -145,43 +142,10 @@ def dataclass( assert validate_on_init is not False, 'validate_on_init=False is no longer supported' if sys.version_info >= (3, 10): - kwargs = {'kw_only': kw_only, 'slots': slots} + kwargs = dict(kw_only=kw_only, slots=slots) else: kwargs = {} - def make_pydantic_fields_compatible(cls: type[Any]) -> None: - """Make sure that stdlib `dataclasses` understands `Field` kwargs like `kw_only` - To do that, we simply change - `x: int = pydantic.Field(..., kw_only=True)` - into - `x: int = dataclasses.field(default=pydantic.Field(..., kw_only=True), kw_only=True)` - """ - for annotation_cls in cls.__mro__: - annotations: dict[str, Any] = getattr(annotation_cls, '__annotations__', {}) - for field_name in annotations: - field_value = getattr(cls, field_name, None) - # Process only if this is an instance of `FieldInfo`. - if not isinstance(field_value, FieldInfo): - continue - - # Initialize arguments for the standard `dataclasses.field`. - field_args: dict = {'default': field_value} - - # Handle `kw_only` for Python 3.10+ - if sys.version_info >= (3, 10) and field_value.kw_only: - field_args['kw_only'] = True - - # Set `repr` attribute if it's explicitly specified to be not `True`. - if field_value.repr is not True: - field_args['repr'] = field_value.repr - - setattr(cls, field_name, dataclasses.field(**field_args)) - # In Python 3.9, when subclassing, information is pulled from cls.__dict__['__annotations__'] - # for annotations, so we must make sure it's initialized before we add to it. - if cls.__dict__.get('__annotations__') is None: - cls.__annotations__ = {} - cls.__annotations__[field_name] = annotations[field_name] - def create_dataclass(cls: type[Any]) -> type[PydanticDataclass]: """Create a Pydantic dataclass from a regular dataclass. @@ -191,29 +155,14 @@ def dataclass( Returns: A Pydantic dataclass. """ - from ._internal._utils import is_model_class - - if is_model_class(cls): - raise PydanticUserError( - f'Cannot create a Pydantic dataclass from {cls.__name__} as it is already a Pydantic model', - code='dataclass-on-model', - ) - original_cls = cls - # we warn on conflicting config specifications, but only if the class doesn't have a dataclass base - # because a dataclass base might provide a __pydantic_config__ attribute that we don't want to warn about - has_dataclass_base = any(dataclasses.is_dataclass(base) for base in cls.__bases__) - if not has_dataclass_base and config is not None and hasattr(cls, '__pydantic_config__'): - warn( - f'`config` is set via both the `dataclass` decorator and `__pydantic_config__` for dataclass {cls.__name__}. 
' - f'The `config` specification from `dataclass` decorator will take priority.', - category=UserWarning, - stacklevel=2, - ) - - # if config is not explicitly provided, try to read it from the type - config_dict = config if config is not None else getattr(cls, '__pydantic_config__', None) + config_dict = config + if config_dict is None: + # if not explicitly provided, read from the type + cls_config = getattr(cls, '__pydantic_config__', None) + if cls_config is not None: + config_dict = cls_config config_wrapper = _config.ConfigWrapper(config_dict) decorators = _decorators.DecoratorInfos.build(cls) @@ -236,22 +185,6 @@ def dataclass( bases = bases + (generic_base,) cls = types.new_class(cls.__name__, bases) - make_pydantic_fields_compatible(cls) - - # Respect frozen setting from dataclass constructor and fallback to config setting if not provided - if frozen is not None: - frozen_ = frozen - if config_wrapper.frozen: - # It's not recommended to define both, as the setting from the dataclass decorator will take priority. - warn( - f'`frozen` is set via both the `dataclass` decorator and `config` for dataclass {cls.__name__!r}.' - 'This is not recommended. The `frozen` specification on `dataclass` will take priority.', - category=UserWarning, - stacklevel=2, - ) - else: - frozen_ = config_wrapper.frozen or False - cls = dataclasses.dataclass( # type: ignore[call-overload] cls, # the value of init here doesn't affect anything except that it makes it easier to generate a signature @@ -260,40 +193,29 @@ def dataclass( eq=eq, order=order, unsafe_hash=unsafe_hash, - frozen=frozen_, + frozen=frozen, **kwargs, ) - # This is an undocumented attribute to distinguish stdlib/Pydantic dataclasses. - # It should be set as early as possible: - cls.__is_pydantic_dataclass__ = True - cls.__pydantic_decorators__ = decorators # type: ignore cls.__doc__ = original_doc cls.__module__ = original_cls.__module__ cls.__qualname__ = original_cls.__qualname__ - cls.__pydantic_fields_complete__ = classmethod(_pydantic_fields_complete) - cls.__pydantic_complete__ = False # `complete_dataclass` will set it to `True` if successful. - # TODO `parent_namespace` is currently None, but we could do the same thing as Pydantic models: - # fetch the parent ns using `parent_frame_namespace` (if the dataclass was defined in a function), - # and possibly cache it (see the `__pydantic_parent_namespace__` logic for models). - _pydantic_dataclasses.complete_dataclass(cls, config_wrapper, raise_errors=False) + pydantic_complete = _pydantic_dataclasses.complete_dataclass( + cls, config_wrapper, raise_errors=False, types_namespace=None + ) + cls.__pydantic_complete__ = pydantic_complete # type: ignore return cls - return create_dataclass if _cls is None else create_dataclass(_cls) + if _cls is None: + return create_dataclass - -def _pydantic_fields_complete(cls: type[PydanticDataclass]) -> bool: - """Return whether the fields where successfully collected (i.e. type hints were successfully resolves). - - This is a private property, not meant to be used outside Pydantic. - """ - return all(field_info._complete for field_info in cls.__pydantic_fields__.values()) + return create_dataclass(_cls) __getattr__ = getattr_migration(__name__) -if sys.version_info < (3, 11): +if (3, 8) <= sys.version_info < (3, 11): # Monkeypatch dataclasses.InitVar so that typing doesn't error if it occurs as a type when evaluating type hints # Starting in 3.11, typing.get_type_hints will not raise an error if the retrieved type hints are not callable. 
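The hunk above concerns how `frozen` is resolved from the decorator versus the config; a minimal sketch of the observable behavior follows. The exact exception raised on assignment can vary by pydantic version, so it is caught broadly here.

```python
# Sketch: a frozen Pydantic dataclass rejects attribute assignment.
from pydantic.dataclasses import dataclass


@dataclass(frozen=True)
class Settings:
    name: str


s = Settings(name='prod')
try:
    s.name = 'dev'
except Exception as exc:  # FrozenInstanceError or a ValidationError, depending on version
    print(type(exc).__name__)
```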
@@ -313,7 +235,7 @@ def rebuild_dataclass( force: bool = False, raise_errors: bool = True, _parent_namespace_depth: int = 2, - _types_namespace: MappingNamespace | None = None, + _types_namespace: dict[str, Any] | None = None, ) -> bool | None: """Try to rebuild the pydantic-core schema for the dataclass. @@ -323,8 +245,8 @@ def rebuild_dataclass( This is analogous to `BaseModel.model_rebuild`. Args: - cls: The class to rebuild the pydantic-core schema for. - force: Whether to force the rebuilding of the schema, defaults to `False`. + cls: The class to build the dataclass core schema for. + force: Whether to force the rebuilding of the model schema, defaults to `False`. raise_errors: Whether to raise errors, defaults to `True`. _parent_namespace_depth: The depth level of the parent namespace, defaults to 2. _types_namespace: The types namespace, defaults to `None`. @@ -335,49 +257,34 @@ def rebuild_dataclass( """ if not force and cls.__pydantic_complete__: return None - - for attr in ('__pydantic_core_schema__', '__pydantic_validator__', '__pydantic_serializer__'): - if attr in cls.__dict__: - # Deleting the validator/serializer is necessary as otherwise they can get reused in - # pydantic-core. Same applies for the core schema that can be reused in schema generation. - delattr(cls, attr) - - cls.__pydantic_complete__ = False - - if _types_namespace is not None: - rebuild_ns = _types_namespace - elif _parent_namespace_depth > 0: - rebuild_ns = _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth, force=True) or {} else: - rebuild_ns = {} + if _types_namespace is not None: + types_namespace: dict[str, Any] | None = _types_namespace.copy() + else: + if _parent_namespace_depth > 0: + frame_parent_ns = _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth) or {} + # Note: we may need to add something similar to cls.__pydantic_parent_namespace__ from BaseModel + # here when implementing handling of recursive generics. See BaseModel.model_rebuild for reference. + types_namespace = frame_parent_ns + else: + types_namespace = {} - ns_resolver = _namespace_utils.NsResolver( - parent_namespace=rebuild_ns, - ) - - return _pydantic_dataclasses.complete_dataclass( - cls, - _config.ConfigWrapper(cls.__pydantic_config__, check=False), - raise_errors=raise_errors, - ns_resolver=ns_resolver, - # We could provide a different config instead (with `'defer_build'` set to `True`) - # of this explicit `_force_build` argument, but because config can come from the - # decorator parameter or the `__pydantic_config__` attribute, `complete_dataclass` - # will overwrite `__pydantic_config__` with the provided config above: - _force_build=True, - ) + types_namespace = _typing_extra.get_cls_types_namespace(cls, types_namespace) + return _pydantic_dataclasses.complete_dataclass( + cls, + _config.ConfigWrapper(cls.__pydantic_config__, check=False), + raise_errors=raise_errors, + types_namespace=types_namespace, + ) -def is_pydantic_dataclass(class_: type[Any], /) -> TypeGuard[type[PydanticDataclass]]: +def is_pydantic_dataclass(__cls: type[Any]) -> TypeGuard[type[PydanticDataclass]]: """Whether a class is a pydantic dataclass. Args: - class_: The class. + __cls: The class. Returns: `True` if the class is a pydantic dataclass, `False` otherwise. 
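A minimal sketch of `rebuild_dataclass`, whose internals are diffed above: it rebuilds the core schema once a forward reference can be resolved, analogous to `BaseModel.model_rebuild`. Class names are illustrative.

```python
# Sketch: resolve a forward reference after the referenced class exists.
from pydantic.dataclasses import dataclass, rebuild_dataclass


@dataclass
class Outer:
    inner: 'Inner'  # not defined yet, so the core schema stays incomplete


@dataclass
class Inner:
    x: int


rebuild_dataclass(Outer)          # now the schema can be built
print(Outer(inner={'x': 1}))      # Outer(inner=Inner(x=1))
```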
""" - try: - return '__is_pydantic_dataclass__' in class_.__dict__ and dataclasses.is_dataclass(class_) - except AttributeError: - return False + return dataclasses.is_dataclass(__cls) and '__pydantic_validator__' in __cls.__dict__ diff --git a/venv/lib/python3.12/site-packages/pydantic/datetime_parse.py b/venv/lib/python3.12/site-packages/pydantic/datetime_parse.py index 53d5264..902219d 100644 --- a/venv/lib/python3.12/site-packages/pydantic/datetime_parse.py +++ b/venv/lib/python3.12/site-packages/pydantic/datetime_parse.py @@ -1,5 +1,4 @@ """The `datetime_parse` module is a backport module from V1.""" - from ._migration import getattr_migration __getattr__ = getattr_migration(__name__) diff --git a/venv/lib/python3.12/site-packages/pydantic/decorator.py b/venv/lib/python3.12/site-packages/pydantic/decorator.py index 0d97560..c364346 100644 --- a/venv/lib/python3.12/site-packages/pydantic/decorator.py +++ b/venv/lib/python3.12/site-packages/pydantic/decorator.py @@ -1,5 +1,4 @@ """The `decorator` module is a backport module from V1.""" - from ._migration import getattr_migration __getattr__ = getattr_migration(__name__) diff --git a/venv/lib/python3.12/site-packages/pydantic/deprecated/class_validators.py b/venv/lib/python3.12/site-packages/pydantic/deprecated/class_validators.py index ad92864..dc65e75 100644 --- a/venv/lib/python3.12/site-packages/pydantic/deprecated/class_validators.py +++ b/venv/lib/python3.12/site-packages/pydantic/deprecated/class_validators.py @@ -4,10 +4,10 @@ from __future__ import annotations as _annotations from functools import partial, partialmethod from types import FunctionType -from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, overload from warnings import warn -from typing_extensions import Protocol, TypeAlias, deprecated +from typing_extensions import Literal, Protocol, TypeAlias from .._internal import _decorators, _decorators_v1 from ..errors import PydanticUserError @@ -19,24 +19,30 @@ _ALLOW_REUSE_WARNING_MESSAGE = '`allow_reuse` is deprecated and will be ignored; if TYPE_CHECKING: class _OnlyValueValidatorClsMethod(Protocol): - def __call__(self, __cls: Any, __value: Any) -> Any: ... + def __call__(self, __cls: Any, __value: Any) -> Any: + ... class _V1ValidatorWithValuesClsMethod(Protocol): - def __call__(self, __cls: Any, __value: Any, values: dict[str, Any]) -> Any: ... + def __call__(self, __cls: Any, __value: Any, values: dict[str, Any]) -> Any: + ... class _V1ValidatorWithValuesKwOnlyClsMethod(Protocol): - def __call__(self, __cls: Any, __value: Any, *, values: dict[str, Any]) -> Any: ... + def __call__(self, __cls: Any, __value: Any, *, values: dict[str, Any]) -> Any: + ... class _V1ValidatorWithKwargsClsMethod(Protocol): - def __call__(self, __cls: Any, **kwargs: Any) -> Any: ... + def __call__(self, __cls: Any, **kwargs: Any) -> Any: + ... class _V1ValidatorWithValuesAndKwargsClsMethod(Protocol): - def __call__(self, __cls: Any, values: dict[str, Any], **kwargs: Any) -> Any: ... + def __call__(self, __cls: Any, values: dict[str, Any], **kwargs: Any) -> Any: + ... class _V1RootValidatorClsMethod(Protocol): def __call__( self, __cls: Any, __values: _decorators_v1.RootValidatorValues - ) -> _decorators_v1.RootValidatorValues: ... + ) -> _decorators_v1.RootValidatorValues: + ... 
V1Validator = Union[ _OnlyValueValidatorClsMethod, @@ -73,12 +79,6 @@ else: DeprecationWarning = PydanticDeprecatedSince20 -@deprecated( - 'Pydantic V1 style `@validator` validators are deprecated.' - ' You should migrate to Pydantic V2 style `@field_validator` validators,' - ' see the migration guide for more details', - category=None, -) def validator( __field: str, *fields: str, @@ -94,7 +94,7 @@ def validator( __field (str): The first field the validator should be called on; this is separate from `fields` to ensure an error is raised if you don't pass at least one. *fields (str): Additional field(s) the validator should be called on. - pre (bool, optional): Whether this validator should be called before the standard + pre (bool, optional): Whether or not this validator should be called before the standard validators (else after). Defaults to False. each_item (bool, optional): For complex objects (sets, lists etc.) whether to validate individual elements rather than the whole object. Defaults to False. @@ -109,6 +109,22 @@ def validator( Callable: A decorator that can be used to decorate a function to be used as a validator. """ + if allow_reuse is True: # pragma: no cover + warn(_ALLOW_REUSE_WARNING_MESSAGE, DeprecationWarning) + fields = tuple((__field, *fields)) + if isinstance(fields[0], FunctionType): + raise PydanticUserError( + "`@validator` should be used with fields and keyword arguments, not bare. " + "E.g. usage should be `@validator('', ...)`", + code='validator-no-fields', + ) + elif not all(isinstance(field, str) for field in fields): + raise PydanticUserError( + "`@validator` fields should be passed as separate string args. " + "E.g. usage should be `@validator('', '', ...)`", + code='validator-invalid-fields', + ) + warn( 'Pydantic V1 style `@validator` validators are deprecated.' ' You should migrate to Pydantic V2 style `@field_validator` validators,' @@ -117,22 +133,6 @@ def validator( stacklevel=2, ) - if allow_reuse is True: # pragma: no cover - warn(_ALLOW_REUSE_WARNING_MESSAGE, DeprecationWarning) - fields = __field, *fields - if isinstance(fields[0], FunctionType): - raise PydanticUserError( - '`@validator` should be used with fields and keyword arguments, not bare. ' - "E.g. usage should be `@validator('', ...)`", - code='validator-no-fields', - ) - elif not all(isinstance(field, str) for field in fields): - raise PydanticUserError( - '`@validator` fields should be passed as separate string args. ' - "E.g. usage should be `@validator('', '', ...)`", - code='validator-invalid-fields', - ) - mode: Literal['before', 'after'] = 'before' if pre is True else 'after' def dec(f: Any) -> _decorators.PydanticDescriptorProxy[Any]: @@ -162,10 +162,8 @@ def root_validator( # which means you need to specify `skip_on_failure=True` skip_on_failure: Literal[True], allow_reuse: bool = ..., -) -> Callable[ - [_V1RootValidatorFunctionType], - _V1RootValidatorFunctionType, -]: ... +) -> Callable[[_V1RootValidatorFunctionType], _V1RootValidatorFunctionType,]: + ... @overload @@ -175,10 +173,8 @@ def root_validator( # `skip_on_failure`, in fact it is not allowed as an argument! pre: Literal[True], allow_reuse: bool = ..., -) -> Callable[ - [_V1RootValidatorFunctionType], - _V1RootValidatorFunctionType, -]: ... +) -> Callable[[_V1RootValidatorFunctionType], _V1RootValidatorFunctionType,]: + ... 
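The deprecation messages above point from the V1-style `@validator` to the V2-style `@field_validator`; a minimal migration sketch follows, with illustrative names.

```python
# V2 replacement for a V1 @validator('name', pre=True) hook.
from pydantic import BaseModel, field_validator


class Account(BaseModel):
    name: str

    @field_validator('name', mode='before')
    @classmethod
    def strip_name(cls, value: str) -> str:
        return value.strip()


print(Account(name='  alice  '))  # name='alice'
```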
@overload @@ -189,18 +185,10 @@ def root_validator( pre: Literal[False], skip_on_failure: Literal[True], allow_reuse: bool = ..., -) -> Callable[ - [_V1RootValidatorFunctionType], - _V1RootValidatorFunctionType, -]: ... +) -> Callable[[_V1RootValidatorFunctionType], _V1RootValidatorFunctionType,]: + ... -@deprecated( - 'Pydantic V1 style `@root_validator` validators are deprecated.' - ' You should migrate to Pydantic V2 style `@model_validator` validators,' - ' see the migration guide for more details', - category=None, -) def root_validator( *__args, pre: bool = False, diff --git a/venv/lib/python3.12/site-packages/pydantic/deprecated/config.py b/venv/lib/python3.12/site-packages/pydantic/deprecated/config.py index bd4692a..7409847 100644 --- a/venv/lib/python3.12/site-packages/pydantic/deprecated/config.py +++ b/venv/lib/python3.12/site-packages/pydantic/deprecated/config.py @@ -1,9 +1,9 @@ from __future__ import annotations as _annotations import warnings -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any -from typing_extensions import deprecated +from typing_extensions import Literal, deprecated from .._internal import _config from ..warnings import PydanticDeprecatedSince20 @@ -18,10 +18,10 @@ __all__ = 'BaseConfig', 'Extra' class _ConfigMetaclass(type): def __getattr__(self, item: str) -> Any: + warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning) + try: - obj = _config.config_defaults[item] - warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning) - return obj + return _config.config_defaults[item] except KeyError as exc: raise AttributeError(f"type object '{self.__name__}' has no attribute {exc}") from exc @@ -35,10 +35,9 @@ class BaseConfig(metaclass=_ConfigMetaclass): """ def __getattr__(self, item: str) -> Any: + warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning) try: - obj = super().__getattribute__(item) - warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning) - return obj + return super().__getattribute__(item) except AttributeError as exc: try: return getattr(type(self), item) diff --git a/venv/lib/python3.12/site-packages/pydantic/deprecated/copy_internals.py b/venv/lib/python3.12/site-packages/pydantic/deprecated/copy_internals.py index 1b0dc12..efe5de2 100644 --- a/venv/lib/python3.12/site-packages/pydantic/deprecated/copy_internals.py +++ b/venv/lib/python3.12/site-packages/pydantic/deprecated/copy_internals.py @@ -3,7 +3,7 @@ from __future__ import annotations as _annotations import typing from copy import deepcopy from enum import Enum -from typing import Any +from typing import Any, Tuple import typing_extensions @@ -18,7 +18,7 @@ if typing.TYPE_CHECKING: from .._internal._utils import AbstractSetIntStr, MappingIntStrAny AnyClassMethod = classmethod[Any, Any, Any] - TupleGenerator = typing.Generator[tuple[str, Any], None, None] + TupleGenerator = typing.Generator[Tuple[str, Any], None, None] Model = typing.TypeVar('Model', bound='BaseModel') # should be `set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None`, but mypy can't cope IncEx: typing_extensions.TypeAlias = 'set[int] | set[str] | dict[int, Any] | dict[str, Any] | None' @@ -40,11 +40,11 @@ def _iter( # The extra "is not None" guards are not logically necessary but optimizes performance for the simple case. 
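Similarly, the V1 `@root_validator` handled above is superseded by `@model_validator` in V2; a minimal sketch, with illustrative names.

```python
# V2 replacement for a V1 @root_validator(skip_on_failure=True) hook.
from pydantic import BaseModel, model_validator


class Range(BaseModel):
    low: int
    high: int

    @model_validator(mode='after')
    def check_order(self) -> 'Range':
        if self.low > self.high:
            raise ValueError('low must not exceed high')
        return self


print(Range(low=1, high=2))  # low=1 high=2
```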
if exclude is not None: exclude = _utils.ValueItems.merge( - {k: v.exclude for k, v in self.__pydantic_fields__.items() if v.exclude is not None}, exclude + {k: v.exclude for k, v in self.model_fields.items() if v.exclude is not None}, exclude ) if include is not None: - include = _utils.ValueItems.merge({k: True for k in self.__pydantic_fields__}, include, intersect=True) + include = _utils.ValueItems.merge({k: True for k in self.model_fields}, include, intersect=True) allowed_keys = _calculate_keys(self, include=include, exclude=exclude, exclude_unset=exclude_unset) # type: ignore if allowed_keys is None and not (to_dict or by_alias or exclude_unset or exclude_defaults or exclude_none): @@ -68,15 +68,15 @@ def _iter( if exclude_defaults: try: - field = self.__pydantic_fields__[field_key] + field = self.model_fields[field_key] except KeyError: pass else: if not field.is_required() and field.default == v: continue - if by_alias and field_key in self.__pydantic_fields__: - dict_key = self.__pydantic_fields__[field_key].alias or field_key + if by_alias and field_key in self.model_fields: + dict_key = self.model_fields[field_key].alias or field_key else: dict_key = field_key @@ -200,7 +200,7 @@ def _calculate_keys( include: MappingIntStrAny | None, exclude: MappingIntStrAny | None, exclude_unset: bool, - update: dict[str, Any] | None = None, # noqa UP006 + update: typing.Dict[str, Any] | None = None, # noqa UP006 ) -> typing.AbstractSet[str] | None: if include is None and exclude is None and exclude_unset is False: return None diff --git a/venv/lib/python3.12/site-packages/pydantic/deprecated/decorator.py b/venv/lib/python3.12/site-packages/pydantic/deprecated/decorator.py index e73ad20..11244ba 100644 --- a/venv/lib/python3.12/site-packages/pydantic/deprecated/decorator.py +++ b/venv/lib/python3.12/site-packages/pydantic/deprecated/decorator.py @@ -1,7 +1,6 @@ import warnings -from collections.abc import Mapping from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload from typing_extensions import deprecated @@ -23,29 +22,29 @@ if TYPE_CHECKING: AnyCallable = Callable[..., Any] AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable) - ConfigType = Union[None, type[Any], dict[str, Any]] + ConfigType = Union[None, Type[Any], Dict[str, Any]] @overload -def validate_arguments( - func: None = None, *, config: 'ConfigType' = None -) -> Callable[['AnyCallableT'], 'AnyCallableT']: ... - - -@overload -def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT': ... - - @deprecated( - 'The `validate_arguments` method is deprecated; use `validate_call` instead.', - category=None, + 'The `validate_arguments` method is deprecated; use `validate_call` instead.', category=PydanticDeprecatedSince20 ) +def validate_arguments(func: None = None, *, config: 'ConfigType' = None) -> Callable[['AnyCallableT'], 'AnyCallableT']: + ... + + +@overload +@deprecated( + 'The `validate_arguments` method is deprecated; use `validate_call` instead.', category=PydanticDeprecatedSince20 +) +def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT': + ... 
+ + def validate_arguments(func: Optional['AnyCallableT'] = None, *, config: 'ConfigType' = None) -> Any: """Decorator to validate the arguments passed to a function.""" warnings.warn( - 'The `validate_arguments` method is deprecated; use `validate_call` instead.', - PydanticDeprecatedSince20, - stacklevel=2, + 'The `validate_arguments` method is deprecated; use `validate_call` instead.', DeprecationWarning, stacklevel=2 ) def validate(_func: 'AnyCallable') -> 'AnyCallable': @@ -87,7 +86,7 @@ class ValidatedFunction: ) self.raw_function = function - self.arg_mapping: dict[int, str] = {} + self.arg_mapping: Dict[int, str] = {} self.positional_only_args: set[str] = set() self.v_args_name = 'args' self.v_kwargs_name = 'kwargs' @@ -95,7 +94,7 @@ class ValidatedFunction: type_hints = _typing_extra.get_type_hints(function, include_extras=True) takes_args = False takes_kwargs = False - fields: dict[str, tuple[Any, Any]] = {} + fields: Dict[str, Tuple[Any, Any]] = {} for i, (name, p) in enumerate(parameters.items()): if p.annotation is p.empty: annotation = Any @@ -106,22 +105,22 @@ class ValidatedFunction: if p.kind == Parameter.POSITIONAL_ONLY: self.arg_mapping[i] = name fields[name] = annotation, default - fields[V_POSITIONAL_ONLY_NAME] = list[str], None + fields[V_POSITIONAL_ONLY_NAME] = List[str], None self.positional_only_args.add(name) elif p.kind == Parameter.POSITIONAL_OR_KEYWORD: self.arg_mapping[i] = name fields[name] = annotation, default - fields[V_DUPLICATE_KWARGS] = list[str], None + fields[V_DUPLICATE_KWARGS] = List[str], None elif p.kind == Parameter.KEYWORD_ONLY: fields[name] = annotation, default elif p.kind == Parameter.VAR_POSITIONAL: self.v_args_name = name - fields[name] = tuple[annotation, ...], None + fields[name] = Tuple[annotation, ...], None takes_args = True else: assert p.kind == Parameter.VAR_KEYWORD, p.kind self.v_kwargs_name = name - fields[name] = dict[str, annotation], None + fields[name] = Dict[str, annotation], None takes_kwargs = True # these checks avoid a clash between "args" and a field with that name @@ -134,11 +133,11 @@ class ValidatedFunction: if not takes_args: # we add the field so validation below can raise the correct exception - fields[self.v_args_name] = list[Any], None + fields[self.v_args_name] = List[Any], None if not takes_kwargs: # same with kwargs - fields[self.v_kwargs_name] = dict[Any, Any], None + fields[self.v_kwargs_name] = Dict[Any, Any], None self.create_model(fields, takes_args, takes_kwargs, config) @@ -150,8 +149,8 @@ class ValidatedFunction: m = self.init_model_instance(*args, **kwargs) return self.execute(m) - def build_values(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> dict[str, Any]: - values: dict[str, Any] = {} + def build_values(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]: + values: Dict[str, Any] = {} if args: arg_iter = enumerate(args) while True: @@ -166,15 +165,15 @@ class ValidatedFunction: values[self.v_args_name] = [a] + [a for _, a in arg_iter] break - var_kwargs: dict[str, Any] = {} + var_kwargs: Dict[str, Any] = {} wrong_positional_args = [] duplicate_kwargs = [] fields_alias = [ field.alias - for name, field in self.model.__pydantic_fields__.items() + for name, field in self.model.model_fields.items() if name not in (self.v_args_name, self.v_kwargs_name) ] - non_var_fields = set(self.model.__pydantic_fields__) - {self.v_args_name, self.v_kwargs_name} + non_var_fields = set(self.model.model_fields) - {self.v_args_name, self.v_kwargs_name} for k, v in kwargs.items(): if k in 
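The deprecation above directs users from `validate_arguments` to `validate_call`; a minimal sketch of the replacement decorator, with an illustrative function.

```python
# validate_call validates arguments against the annotations at call time.
from pydantic import ValidationError, validate_call


@validate_call
def repeat(text: str, times: int = 1) -> str:
    return text * times


print(repeat('ab', times=2))  # abab

try:
    repeat('ab', times='not an int')
except ValidationError as exc:
    print(exc.error_count())  # 1
```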
non_var_fields or k in fields_alias: if k in self.positional_only_args: @@ -194,15 +193,11 @@ class ValidatedFunction: return values def execute(self, m: BaseModel) -> Any: - d = { - k: v - for k, v in m.__dict__.items() - if k in m.__pydantic_fields_set__ or m.__pydantic_fields__[k].default_factory - } + d = {k: v for k, v in m.__dict__.items() if k in m.__pydantic_fields_set__ or m.model_fields[k].default_factory} var_kwargs = d.pop(self.v_kwargs_name, {}) if self.v_args_name in d: - args_: list[Any] = [] + args_: List[Any] = [] in_kwargs = False kwargs = {} for name, value in d.items(): @@ -226,7 +221,7 @@ class ValidatedFunction: else: return self.raw_function(**d, **var_kwargs) - def create_model(self, fields: dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None: + def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None: pos_args = len(self.arg_mapping) config_wrapper = _config.ConfigWrapper(config) @@ -243,7 +238,7 @@ class ValidatedFunction: class DecoratorBaseModel(BaseModel): @field_validator(self.v_args_name, check_fields=False) @classmethod - def check_args(cls, v: Optional[list[Any]]) -> Optional[list[Any]]: + def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]: if takes_args or v is None: return v @@ -251,7 +246,7 @@ class ValidatedFunction: @field_validator(self.v_kwargs_name, check_fields=False) @classmethod - def check_kwargs(cls, v: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]: + def check_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: if takes_kwargs or v is None: return v @@ -261,7 +256,7 @@ class ValidatedFunction: @field_validator(V_POSITIONAL_ONLY_NAME, check_fields=False) @classmethod - def check_positional_only(cls, v: Optional[list[str]]) -> None: + def check_positional_only(cls, v: Optional[List[str]]) -> None: if v is None: return @@ -271,7 +266,7 @@ class ValidatedFunction: @field_validator(V_DUPLICATE_KWARGS, check_fields=False) @classmethod - def check_duplicate_kwargs(cls, v: Optional[list[str]]) -> None: + def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None: if v is None: return diff --git a/venv/lib/python3.12/site-packages/pydantic/deprecated/json.py b/venv/lib/python3.12/site-packages/pydantic/deprecated/json.py index 1e216a7..d067353 100644 --- a/venv/lib/python3.12/site-packages/pydantic/deprecated/json.py +++ b/venv/lib/python3.12/site-packages/pydantic/deprecated/json.py @@ -7,12 +7,11 @@ from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6 from pathlib import Path from re import Pattern from types import GeneratorType -from typing import TYPE_CHECKING, Any, Callable, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Type, Union from uuid import UUID from typing_extensions import deprecated -from .._internal._import_utils import import_cached_base_model from ..color import Color from ..networks import NameEmail from ..types import SecretBytes, SecretStr @@ -51,7 +50,7 @@ def decimal_encoder(dec_value: Decimal) -> Union[int, float]: return float(dec_value) -ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = { +ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { bytes: lambda o: o.decode(), Color: str, datetime.date: isoformat, @@ -80,23 +79,18 @@ ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = { @deprecated( - '`pydantic_encoder` is deprecated, use `pydantic_core.to_jsonable_python` instead.', - category=None, + 'pydantic_encoder is 
deprecated, use pydantic_core.to_jsonable_python instead.', category=PydanticDeprecatedSince20 ) def pydantic_encoder(obj: Any) -> Any: - warnings.warn( - '`pydantic_encoder` is deprecated, use `pydantic_core.to_jsonable_python` instead.', - category=PydanticDeprecatedSince20, - stacklevel=2, - ) from dataclasses import asdict, is_dataclass - BaseModel = import_cached_base_model() + from ..main import BaseModel + warnings.warn('pydantic_encoder is deprecated, use BaseModel.model_dump instead.', DeprecationWarning, stacklevel=2) if isinstance(obj, BaseModel): return obj.model_dump() elif is_dataclass(obj): - return asdict(obj) # type: ignore + return asdict(obj) # Check the class type and its superclasses for a matching encoder for base in obj.__class__.__mro__[:-1]: @@ -110,17 +104,12 @@ def pydantic_encoder(obj: Any) -> Any: # TODO: Add a suggested migration path once there is a way to use custom encoders -@deprecated( - '`custom_pydantic_encoder` is deprecated, use `BaseModel.model_dump` instead.', - category=None, -) -def custom_pydantic_encoder(type_encoders: dict[Any, Callable[[type[Any]], Any]], obj: Any) -> Any: - warnings.warn( - '`custom_pydantic_encoder` is deprecated, use `BaseModel.model_dump` instead.', - category=PydanticDeprecatedSince20, - stacklevel=2, - ) +@deprecated('custom_pydantic_encoder is deprecated.', category=PydanticDeprecatedSince20) +def custom_pydantic_encoder(type_encoders: Dict[Any, Callable[[Type[Any]], Any]], obj: Any) -> Any: # Check the class type and its superclasses for a matching encoder + warnings.warn( + 'custom_pydantic_encoder is deprecated, use BaseModel.model_dump instead.', DeprecationWarning, stacklevel=2 + ) for base in obj.__class__.__mro__[:-1]: try: encoder = type_encoders[base] @@ -132,10 +121,10 @@ def custom_pydantic_encoder(type_encoders: dict[Any, Callable[[type[Any]], Any]] return pydantic_encoder(obj) -@deprecated('`timedelta_isoformat` is deprecated.', category=None) +@deprecated('timedelta_isoformat is deprecated.', category=PydanticDeprecatedSince20) def timedelta_isoformat(td: datetime.timedelta) -> str: """ISO 8601 encoding for Python timedelta object.""" - warnings.warn('`timedelta_isoformat` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2) + warnings.warn('timedelta_isoformat is deprecated.', DeprecationWarning, stacklevel=2) minutes, seconds = divmod(td.seconds, 60) hours, minutes = divmod(minutes, 60) return f'{"-" if td.days < 0 else ""}P{abs(td.days)}DT{hours:d}H{minutes:d}M{seconds:d}.{td.microseconds:06d}S' diff --git a/venv/lib/python3.12/site-packages/pydantic/deprecated/parse.py b/venv/lib/python3.12/site-packages/pydantic/deprecated/parse.py index 2a92e62..12d0d06 100644 --- a/venv/lib/python3.12/site-packages/pydantic/deprecated/parse.py +++ b/venv/lib/python3.12/site-packages/pydantic/deprecated/parse.py @@ -22,7 +22,7 @@ class Protocol(str, Enum): pickle = 'pickle' -@deprecated('`load_str_bytes` is deprecated.', category=None) +@deprecated('load_str_bytes is deprecated.', category=PydanticDeprecatedSince20) def load_str_bytes( b: str | bytes, *, @@ -32,7 +32,7 @@ def load_str_bytes( allow_pickle: bool = False, json_loads: Callable[[str], Any] = json.loads, ) -> Any: - warnings.warn('`load_str_bytes` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2) + warnings.warn('load_str_bytes is deprecated.', DeprecationWarning, stacklevel=2) if proto is None and content_type: if content_type.endswith(('json', 'javascript')): pass @@ -46,17 +46,17 @@ def load_str_bytes( if proto == 
Protocol.json: if isinstance(b, bytes): b = b.decode(encoding) - return json_loads(b) # type: ignore + return json_loads(b) elif proto == Protocol.pickle: if not allow_pickle: raise RuntimeError('Trying to decode with pickle with allow_pickle=False') - bb = b if isinstance(b, bytes) else b.encode() # type: ignore + bb = b if isinstance(b, bytes) else b.encode() return pickle.loads(bb) else: raise TypeError(f'Unknown protocol: {proto}') -@deprecated('`load_file` is deprecated.', category=None) +@deprecated('load_file is deprecated.', category=PydanticDeprecatedSince20) def load_file( path: str | Path, *, @@ -66,7 +66,7 @@ def load_file( allow_pickle: bool = False, json_loads: Callable[[str], Any] = json.loads, ) -> Any: - warnings.warn('`load_file` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2) + warnings.warn('load_file is deprecated.', DeprecationWarning, stacklevel=2) path = Path(path) b = path.read_bytes() if content_type is None: diff --git a/venv/lib/python3.12/site-packages/pydantic/deprecated/tools.py b/venv/lib/python3.12/site-packages/pydantic/deprecated/tools.py index 5ad7fae..2b05d38 100644 --- a/venv/lib/python3.12/site-packages/pydantic/deprecated/tools.py +++ b/venv/lib/python3.12/site-packages/pydantic/deprecated/tools.py @@ -2,7 +2,7 @@ from __future__ import annotations import json import warnings -from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union +from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar, Union from typing_extensions import deprecated @@ -17,20 +17,19 @@ if not TYPE_CHECKING: __all__ = 'parse_obj_as', 'schema_of', 'schema_json_of' -NameFactory = Union[str, Callable[[type[Any]], str]] +NameFactory = Union[str, Callable[[Type[Any]], str]] T = TypeVar('T') @deprecated( - '`parse_obj_as` is deprecated. Use `pydantic.TypeAdapter.validate_python` instead.', - category=None, + 'parse_obj_as is deprecated. Use pydantic.TypeAdapter.validate_python instead.', category=PydanticDeprecatedSince20 ) def parse_obj_as(type_: type[T], obj: Any, type_name: NameFactory | None = None) -> T: warnings.warn( - '`parse_obj_as` is deprecated. Use `pydantic.TypeAdapter.validate_python` instead.', - category=PydanticDeprecatedSince20, + 'parse_obj_as is deprecated. Use pydantic.TypeAdapter.validate_python instead.', + DeprecationWarning, stacklevel=2, ) if type_name is not None: # pragma: no cover @@ -43,8 +42,7 @@ def parse_obj_as(type_: type[T], obj: Any, type_name: NameFactory | None = None) @deprecated( - '`schema_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.', - category=None, + 'schema_of is deprecated. Use pydantic.TypeAdapter.json_schema instead.', category=PydanticDeprecatedSince20 ) def schema_of( type_: Any, @@ -56,9 +54,7 @@ def schema_of( ) -> dict[str, Any]: """Generate a JSON schema (as dict) for the passed model or dynamically generated one.""" warnings.warn( - '`schema_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.', - category=PydanticDeprecatedSince20, - stacklevel=2, + 'schema_of is deprecated. Use pydantic.TypeAdapter.json_schema instead.', DeprecationWarning, stacklevel=2 ) res = TypeAdapter(type_).json_schema( by_alias=by_alias, @@ -79,8 +75,7 @@ def schema_of( @deprecated( - '`schema_json_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.', - category=None, + 'schema_json_of is deprecated. 
Use pydantic.TypeAdapter.json_schema instead.', category=PydanticDeprecatedSince20 ) def schema_json_of( type_: Any, @@ -93,9 +88,7 @@ def schema_json_of( ) -> str: """Generate a JSON schema (as JSON) for the passed model or dynamically generated one.""" warnings.warn( - '`schema_json_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.', - category=PydanticDeprecatedSince20, - stacklevel=2, + 'schema_json_of is deprecated. Use pydantic.TypeAdapter.json_schema instead.', DeprecationWarning, stacklevel=2 ) return json.dumps( schema_of(type_, title=title, by_alias=by_alias, ref_template=ref_template, schema_generator=schema_generator), diff --git a/venv/lib/python3.12/site-packages/pydantic/env_settings.py b/venv/lib/python3.12/site-packages/pydantic/env_settings.py index cd0b04e..662f590 100644 --- a/venv/lib/python3.12/site-packages/pydantic/env_settings.py +++ b/venv/lib/python3.12/site-packages/pydantic/env_settings.py @@ -1,5 +1,4 @@ """The `env_settings` module is a backport module from V1.""" - from ._migration import getattr_migration __getattr__ = getattr_migration(__name__) diff --git a/venv/lib/python3.12/site-packages/pydantic/error_wrappers.py b/venv/lib/python3.12/site-packages/pydantic/error_wrappers.py index 2985419..5144eee 100644 --- a/venv/lib/python3.12/site-packages/pydantic/error_wrappers.py +++ b/venv/lib/python3.12/site-packages/pydantic/error_wrappers.py @@ -1,5 +1,4 @@ """The `error_wrappers` module is a backport module from V1.""" - from ._migration import getattr_migration __getattr__ = getattr_migration(__name__) diff --git a/venv/lib/python3.12/site-packages/pydantic/errors.py b/venv/lib/python3.12/site-packages/pydantic/errors.py index f227068..ddca0e2 100644 --- a/venv/lib/python3.12/site-packages/pydantic/errors.py +++ b/venv/lib/python3.12/site-packages/pydantic/errors.py @@ -1,14 +1,9 @@ """Pydantic-specific errors.""" - from __future__ import annotations as _annotations import re -from typing import Any, ClassVar, Literal -from typing_extensions import Self -from typing_inspection.introspection import Qualifier - -from pydantic._internal import _repr +from typing_extensions import Literal, Self from ._migration import getattr_migration from .version import version_short @@ -19,7 +14,6 @@ __all__ = ( 'PydanticImportError', 'PydanticSchemaGenerationError', 'PydanticInvalidForJsonSchema', - 'PydanticForbiddenQualifier', 'PydanticErrorCodes', ) @@ -36,13 +30,11 @@ PydanticErrorCodes = Literal[ 'discriminator-needs-literal', 'discriminator-alias', 'discriminator-validator', - 'callable-discriminator-no-tag', 'typed-dict-version', 'model-field-overridden', 'model-field-missing-annotation', 'config-both', 'removed-kwargs', - 'circular-reference-schema', 'invalid-for-json-schema', 'json-schema-already-used', 'base-model-instantiated', @@ -50,10 +42,10 @@ PydanticErrorCodes = Literal[ 'schema-for-unknown-type', 'import-error', 'create-model-field-definitions', + 'create-model-config-base', 'validator-no-fields', 'validator-invalid-fields', 'validator-instance-method', - 'validator-input-type', 'root-validator-pre-skip', 'model-serializer-instance-method', 'validator-field-config-info', @@ -62,20 +54,9 @@ PydanticErrorCodes = Literal[ 'field-serializer-signature', 'model-serializer-signature', 'multiple-field-serializers', - 'invalid-annotated-type', + 'invalid_annotated_type', 'type-adapter-config-unused', 'root-model-extra', - 'unevaluable-type-annotation', - 'dataclass-init-false-extra-allow', - 'clashing-init-and-init-var', - 
'model-config-invalid-field-name', - 'with-config-on-model', - 'dataclass-on-model', - 'validate-call-type', - 'unpack-typed-dict', - 'overlapping-unpack-typed-dict', - 'invalid-self-type', - 'validate-by-alias-and-name-false', ] @@ -164,26 +145,4 @@ class PydanticInvalidForJsonSchema(PydanticUserError): super().__init__(message, code='invalid-for-json-schema') -class PydanticForbiddenQualifier(PydanticUserError): - """An error raised if a forbidden type qualifier is found in a type annotation.""" - - _qualifier_repr_map: ClassVar[dict[Qualifier, str]] = { - 'required': 'typing.Required', - 'not_required': 'typing.NotRequired', - 'read_only': 'typing.ReadOnly', - 'class_var': 'typing.ClassVar', - 'init_var': 'dataclasses.InitVar', - 'final': 'typing.Final', - } - - def __init__(self, qualifier: Qualifier, annotation: Any) -> None: - super().__init__( - message=( - f'The annotation {_repr.display_as_type(annotation)!r} contains the {self._qualifier_repr_map[qualifier]!r} ' - f'type qualifier, which is invalid in the context it is defined.' - ), - code=None, - ) - - __getattr__ = getattr_migration(__name__) diff --git a/venv/lib/python3.12/site-packages/pydantic/experimental/__init__.py b/venv/lib/python3.12/site-packages/pydantic/experimental/__init__.py deleted file mode 100644 index 4aa58c6..0000000 --- a/venv/lib/python3.12/site-packages/pydantic/experimental/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""The "experimental" module of pydantic contains potential new features that are subject to change.""" - -import warnings - -from pydantic.warnings import PydanticExperimentalWarning - -warnings.warn( - 'This module is experimental, its contents are subject to change and deprecation.', - category=PydanticExperimentalWarning, -) diff --git a/venv/lib/python3.12/site-packages/pydantic/experimental/arguments_schema.py b/venv/lib/python3.12/site-packages/pydantic/experimental/arguments_schema.py deleted file mode 100644 index af4a8f3..0000000 --- a/venv/lib/python3.12/site-packages/pydantic/experimental/arguments_schema.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Experimental module exposing a function to generate a core schema that validates callable arguments.""" - -from __future__ import annotations - -from collections.abc import Callable -from typing import Any, Literal - -from pydantic_core import CoreSchema - -from pydantic import ConfigDict -from pydantic._internal import _config, _generate_schema, _namespace_utils - - -def generate_arguments_schema( - func: Callable[..., Any], - schema_type: Literal['arguments', 'arguments-v3'] = 'arguments-v3', - parameters_callback: Callable[[int, str, Any], Literal['skip'] | None] | None = None, - config: ConfigDict | None = None, -) -> CoreSchema: - """Generate the schema for the arguments of a function. - - Args: - func: The function to generate the schema for. - schema_type: The type of schema to generate. - parameters_callback: A callable that will be invoked for each parameter. The callback - should take three required arguments: the index, the name and the type annotation - (or [`Parameter.empty`][inspect.Parameter.empty] if not annotated) of the parameter. - The callback can optionally return `'skip'`, so that the parameter gets excluded - from the resulting schema. - config: The configuration to use. - - Returns: - The generated schema. 
- """ - generate_schema = _generate_schema.GenerateSchema( - _config.ConfigWrapper(config), - ns_resolver=_namespace_utils.NsResolver(namespaces_tuple=_namespace_utils.ns_for_function(func)), - ) - - if schema_type == 'arguments': - schema = generate_schema._arguments_schema(func, parameters_callback) # pyright: ignore[reportArgumentType] - else: - schema = generate_schema._arguments_v3_schema(func, parameters_callback) # pyright: ignore[reportArgumentType] - return generate_schema.clean_schema(schema) diff --git a/venv/lib/python3.12/site-packages/pydantic/experimental/pipeline.py b/venv/lib/python3.12/site-packages/pydantic/experimental/pipeline.py deleted file mode 100644 index bd63d98..0000000 --- a/venv/lib/python3.12/site-packages/pydantic/experimental/pipeline.py +++ /dev/null @@ -1,667 +0,0 @@ -"""Experimental pipeline API functionality. Be careful with this API, it's subject to change.""" - -from __future__ import annotations - -import datetime -import operator -import re -import sys -from collections import deque -from collections.abc import Container -from dataclasses import dataclass -from decimal import Decimal -from functools import cached_property, partial -from re import Pattern -from typing import TYPE_CHECKING, Annotated, Any, Callable, Generic, Protocol, TypeVar, Union, overload - -import annotated_types - -if TYPE_CHECKING: - from pydantic_core import core_schema as cs - - from pydantic import GetCoreSchemaHandler - -from pydantic._internal._internal_dataclass import slots_true as _slots_true - -if sys.version_info < (3, 10): - EllipsisType = type(Ellipsis) -else: - from types import EllipsisType - -__all__ = ['validate_as', 'validate_as_deferred', 'transform'] - -_slots_frozen = {**_slots_true, 'frozen': True} - - -@dataclass(**_slots_frozen) -class _ValidateAs: - tp: type[Any] - strict: bool = False - - -@dataclass -class _ValidateAsDefer: - func: Callable[[], type[Any]] - - @cached_property - def tp(self) -> type[Any]: - return self.func() - - -@dataclass(**_slots_frozen) -class _Transform: - func: Callable[[Any], Any] - - -@dataclass(**_slots_frozen) -class _PipelineOr: - left: _Pipeline[Any, Any] - right: _Pipeline[Any, Any] - - -@dataclass(**_slots_frozen) -class _PipelineAnd: - left: _Pipeline[Any, Any] - right: _Pipeline[Any, Any] - - -@dataclass(**_slots_frozen) -class _Eq: - value: Any - - -@dataclass(**_slots_frozen) -class _NotEq: - value: Any - - -@dataclass(**_slots_frozen) -class _In: - values: Container[Any] - - -@dataclass(**_slots_frozen) -class _NotIn: - values: Container[Any] - - -_ConstraintAnnotation = Union[ - annotated_types.Le, - annotated_types.Ge, - annotated_types.Lt, - annotated_types.Gt, - annotated_types.Len, - annotated_types.MultipleOf, - annotated_types.Timezone, - annotated_types.Interval, - annotated_types.Predicate, - # common predicates not included in annotated_types - _Eq, - _NotEq, - _In, - _NotIn, - # regular expressions - Pattern[str], -] - - -@dataclass(**_slots_frozen) -class _Constraint: - constraint: _ConstraintAnnotation - - -_Step = Union[_ValidateAs, _ValidateAsDefer, _Transform, _PipelineOr, _PipelineAnd, _Constraint] - -_InT = TypeVar('_InT') -_OutT = TypeVar('_OutT') -_NewOutT = TypeVar('_NewOutT') - - -class _FieldTypeMarker: - pass - - -# TODO: ultimately, make this public, see https://github.com/pydantic/pydantic/pull/9459#discussion_r1628197626 -# Also, make this frozen eventually, but that doesn't work right now because of the generic base -# Which attempts to modify __orig_base__ and such. 
-# We could go with a manual freeze, but that seems overkill for now. -@dataclass(**_slots_true) -class _Pipeline(Generic[_InT, _OutT]): - """Abstract representation of a chain of validation, transformation, and parsing steps.""" - - _steps: tuple[_Step, ...] - - def transform( - self, - func: Callable[[_OutT], _NewOutT], - ) -> _Pipeline[_InT, _NewOutT]: - """Transform the output of the previous step. - - If used as the first step in a pipeline, the type of the field is used. - That is, the transformation is applied to after the value is parsed to the field's type. - """ - return _Pipeline[_InT, _NewOutT](self._steps + (_Transform(func),)) - - @overload - def validate_as(self, tp: type[_NewOutT], *, strict: bool = ...) -> _Pipeline[_InT, _NewOutT]: ... - - @overload - def validate_as(self, tp: EllipsisType, *, strict: bool = ...) -> _Pipeline[_InT, Any]: # type: ignore - ... - - def validate_as(self, tp: type[_NewOutT] | EllipsisType, *, strict: bool = False) -> _Pipeline[_InT, Any]: # type: ignore - """Validate / parse the input into a new type. - - If no type is provided, the type of the field is used. - - Types are parsed in Pydantic's `lax` mode by default, - but you can enable `strict` mode by passing `strict=True`. - """ - if isinstance(tp, EllipsisType): - return _Pipeline[_InT, Any](self._steps + (_ValidateAs(_FieldTypeMarker, strict=strict),)) - return _Pipeline[_InT, _NewOutT](self._steps + (_ValidateAs(tp, strict=strict),)) - - def validate_as_deferred(self, func: Callable[[], type[_NewOutT]]) -> _Pipeline[_InT, _NewOutT]: - """Parse the input into a new type, deferring resolution of the type until the current class - is fully defined. - - This is useful when you need to reference the class in it's own type annotations. - """ - return _Pipeline[_InT, _NewOutT](self._steps + (_ValidateAsDefer(func),)) - - # constraints - @overload - def constrain(self: _Pipeline[_InT, _NewOutGe], constraint: annotated_types.Ge) -> _Pipeline[_InT, _NewOutGe]: ... - - @overload - def constrain(self: _Pipeline[_InT, _NewOutGt], constraint: annotated_types.Gt) -> _Pipeline[_InT, _NewOutGt]: ... - - @overload - def constrain(self: _Pipeline[_InT, _NewOutLe], constraint: annotated_types.Le) -> _Pipeline[_InT, _NewOutLe]: ... - - @overload - def constrain(self: _Pipeline[_InT, _NewOutLt], constraint: annotated_types.Lt) -> _Pipeline[_InT, _NewOutLt]: ... - - @overload - def constrain( - self: _Pipeline[_InT, _NewOutLen], constraint: annotated_types.Len - ) -> _Pipeline[_InT, _NewOutLen]: ... - - @overload - def constrain( - self: _Pipeline[_InT, _NewOutT], constraint: annotated_types.MultipleOf - ) -> _Pipeline[_InT, _NewOutT]: ... - - @overload - def constrain( - self: _Pipeline[_InT, _NewOutDatetime], constraint: annotated_types.Timezone - ) -> _Pipeline[_InT, _NewOutDatetime]: ... - - @overload - def constrain(self: _Pipeline[_InT, _OutT], constraint: annotated_types.Predicate) -> _Pipeline[_InT, _OutT]: ... - - @overload - def constrain( - self: _Pipeline[_InT, _NewOutInterval], constraint: annotated_types.Interval - ) -> _Pipeline[_InT, _NewOutInterval]: ... - - @overload - def constrain(self: _Pipeline[_InT, _OutT], constraint: _Eq) -> _Pipeline[_InT, _OutT]: ... - - @overload - def constrain(self: _Pipeline[_InT, _OutT], constraint: _NotEq) -> _Pipeline[_InT, _OutT]: ... - - @overload - def constrain(self: _Pipeline[_InT, _OutT], constraint: _In) -> _Pipeline[_InT, _OutT]: ... - - @overload - def constrain(self: _Pipeline[_InT, _OutT], constraint: _NotIn) -> _Pipeline[_InT, _OutT]: ... 
- - @overload - def constrain(self: _Pipeline[_InT, _NewOutT], constraint: Pattern[str]) -> _Pipeline[_InT, _NewOutT]: ... - - def constrain(self, constraint: _ConstraintAnnotation) -> Any: - """Constrain a value to meet a certain condition. - - We support most conditions from `annotated_types`, as well as regular expressions. - - Most of the time you'll be calling a shortcut method like `gt`, `lt`, `len`, etc - so you don't need to call this directly. - """ - return _Pipeline[_InT, _OutT](self._steps + (_Constraint(constraint),)) - - def predicate(self: _Pipeline[_InT, _NewOutT], func: Callable[[_NewOutT], bool]) -> _Pipeline[_InT, _NewOutT]: - """Constrain a value to meet a certain predicate.""" - return self.constrain(annotated_types.Predicate(func)) - - def gt(self: _Pipeline[_InT, _NewOutGt], gt: _NewOutGt) -> _Pipeline[_InT, _NewOutGt]: - """Constrain a value to be greater than a certain value.""" - return self.constrain(annotated_types.Gt(gt)) - - def lt(self: _Pipeline[_InT, _NewOutLt], lt: _NewOutLt) -> _Pipeline[_InT, _NewOutLt]: - """Constrain a value to be less than a certain value.""" - return self.constrain(annotated_types.Lt(lt)) - - def ge(self: _Pipeline[_InT, _NewOutGe], ge: _NewOutGe) -> _Pipeline[_InT, _NewOutGe]: - """Constrain a value to be greater than or equal to a certain value.""" - return self.constrain(annotated_types.Ge(ge)) - - def le(self: _Pipeline[_InT, _NewOutLe], le: _NewOutLe) -> _Pipeline[_InT, _NewOutLe]: - """Constrain a value to be less than or equal to a certain value.""" - return self.constrain(annotated_types.Le(le)) - - def len(self: _Pipeline[_InT, _NewOutLen], min_len: int, max_len: int | None = None) -> _Pipeline[_InT, _NewOutLen]: - """Constrain a value to have a certain length.""" - return self.constrain(annotated_types.Len(min_len, max_len)) - - @overload - def multiple_of(self: _Pipeline[_InT, _NewOutDiv], multiple_of: _NewOutDiv) -> _Pipeline[_InT, _NewOutDiv]: ... - - @overload - def multiple_of(self: _Pipeline[_InT, _NewOutMod], multiple_of: _NewOutMod) -> _Pipeline[_InT, _NewOutMod]: ... 
- - def multiple_of(self: _Pipeline[_InT, Any], multiple_of: Any) -> _Pipeline[_InT, Any]: - """Constrain a value to be a multiple of a certain number.""" - return self.constrain(annotated_types.MultipleOf(multiple_of)) - - def eq(self: _Pipeline[_InT, _OutT], value: _OutT) -> _Pipeline[_InT, _OutT]: - """Constrain a value to be equal to a certain value.""" - return self.constrain(_Eq(value)) - - def not_eq(self: _Pipeline[_InT, _OutT], value: _OutT) -> _Pipeline[_InT, _OutT]: - """Constrain a value to not be equal to a certain value.""" - return self.constrain(_NotEq(value)) - - def in_(self: _Pipeline[_InT, _OutT], values: Container[_OutT]) -> _Pipeline[_InT, _OutT]: - """Constrain a value to be in a certain set.""" - return self.constrain(_In(values)) - - def not_in(self: _Pipeline[_InT, _OutT], values: Container[_OutT]) -> _Pipeline[_InT, _OutT]: - """Constrain a value to not be in a certain set.""" - return self.constrain(_NotIn(values)) - - # timezone methods - def datetime_tz_naive(self: _Pipeline[_InT, datetime.datetime]) -> _Pipeline[_InT, datetime.datetime]: - return self.constrain(annotated_types.Timezone(None)) - - def datetime_tz_aware(self: _Pipeline[_InT, datetime.datetime]) -> _Pipeline[_InT, datetime.datetime]: - return self.constrain(annotated_types.Timezone(...)) - - def datetime_tz( - self: _Pipeline[_InT, datetime.datetime], tz: datetime.tzinfo - ) -> _Pipeline[_InT, datetime.datetime]: - return self.constrain(annotated_types.Timezone(tz)) # type: ignore - - def datetime_with_tz( - self: _Pipeline[_InT, datetime.datetime], tz: datetime.tzinfo | None - ) -> _Pipeline[_InT, datetime.datetime]: - return self.transform(partial(datetime.datetime.replace, tzinfo=tz)) - - # string methods - def str_lower(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]: - return self.transform(str.lower) - - def str_upper(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]: - return self.transform(str.upper) - - def str_title(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]: - return self.transform(str.title) - - def str_strip(self: _Pipeline[_InT, str]) -> _Pipeline[_InT, str]: - return self.transform(str.strip) - - def str_pattern(self: _Pipeline[_InT, str], pattern: str) -> _Pipeline[_InT, str]: - return self.constrain(re.compile(pattern)) - - def str_contains(self: _Pipeline[_InT, str], substring: str) -> _Pipeline[_InT, str]: - return self.predicate(lambda v: substring in v) - - def str_starts_with(self: _Pipeline[_InT, str], prefix: str) -> _Pipeline[_InT, str]: - return self.predicate(lambda v: v.startswith(prefix)) - - def str_ends_with(self: _Pipeline[_InT, str], suffix: str) -> _Pipeline[_InT, str]: - return self.predicate(lambda v: v.endswith(suffix)) - - # operators - def otherwise(self, other: _Pipeline[_OtherIn, _OtherOut]) -> _Pipeline[_InT | _OtherIn, _OutT | _OtherOut]: - """Combine two validation chains, returning the result of the first chain if it succeeds, and the second chain if it fails.""" - return _Pipeline((_PipelineOr(self, other),)) - - __or__ = otherwise - - def then(self, other: _Pipeline[_OutT, _OtherOut]) -> _Pipeline[_InT, _OtherOut]: - """Pipe the result of one validation chain into another.""" - return _Pipeline((_PipelineAnd(self, other),)) - - __and__ = then - - def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> cs.CoreSchema: - from pydantic_core import core_schema as cs - - queue = deque(self._steps) - - s = None - - while queue: - step = queue.popleft() - s = _apply_step(step, s, handler, source_type) - - s 
= s or cs.any_schema() - return s - - def __supports_type__(self, _: _OutT) -> bool: - raise NotImplementedError - - -validate_as = _Pipeline[Any, Any](()).validate_as -validate_as_deferred = _Pipeline[Any, Any](()).validate_as_deferred -transform = _Pipeline[Any, Any]((_ValidateAs(_FieldTypeMarker),)).transform - - -def _check_func( - func: Callable[[Any], bool], predicate_err: str | Callable[[], str], s: cs.CoreSchema | None -) -> cs.CoreSchema: - from pydantic_core import core_schema as cs - - def handler(v: Any) -> Any: - if func(v): - return v - raise ValueError(f'Expected {predicate_err if isinstance(predicate_err, str) else predicate_err()}') - - if s is None: - return cs.no_info_plain_validator_function(handler) - else: - return cs.no_info_after_validator_function(handler, s) - - -def _apply_step(step: _Step, s: cs.CoreSchema | None, handler: GetCoreSchemaHandler, source_type: Any) -> cs.CoreSchema: - from pydantic_core import core_schema as cs - - if isinstance(step, _ValidateAs): - s = _apply_parse(s, step.tp, step.strict, handler, source_type) - elif isinstance(step, _ValidateAsDefer): - s = _apply_parse(s, step.tp, False, handler, source_type) - elif isinstance(step, _Transform): - s = _apply_transform(s, step.func, handler) - elif isinstance(step, _Constraint): - s = _apply_constraint(s, step.constraint) - elif isinstance(step, _PipelineOr): - s = cs.union_schema([handler(step.left), handler(step.right)]) - else: - assert isinstance(step, _PipelineAnd) - s = cs.chain_schema([handler(step.left), handler(step.right)]) - return s - - -def _apply_parse( - s: cs.CoreSchema | None, - tp: type[Any], - strict: bool, - handler: GetCoreSchemaHandler, - source_type: Any, -) -> cs.CoreSchema: - from pydantic_core import core_schema as cs - - from pydantic import Strict - - if tp is _FieldTypeMarker: - return cs.chain_schema([s, handler(source_type)]) if s else handler(source_type) - - if strict: - tp = Annotated[tp, Strict()] # type: ignore - - if s and s['type'] == 'any': - return handler(tp) - else: - return cs.chain_schema([s, handler(tp)]) if s else handler(tp) - - -def _apply_transform( - s: cs.CoreSchema | None, func: Callable[[Any], Any], handler: GetCoreSchemaHandler -) -> cs.CoreSchema: - from pydantic_core import core_schema as cs - - if s is None: - return cs.no_info_plain_validator_function(func) - - if s['type'] == 'str': - if func is str.strip: - s = s.copy() - s['strip_whitespace'] = True - return s - elif func is str.lower: - s = s.copy() - s['to_lower'] = True - return s - elif func is str.upper: - s = s.copy() - s['to_upper'] = True - return s - - return cs.no_info_after_validator_function(func, s) - - -def _apply_constraint( # noqa: C901 - s: cs.CoreSchema | None, constraint: _ConstraintAnnotation -) -> cs.CoreSchema: - """Apply a single constraint to a schema.""" - if isinstance(constraint, annotated_types.Gt): - gt = constraint.gt - if s and s['type'] in {'int', 'float', 'decimal'}: - s = s.copy() - if s['type'] == 'int' and isinstance(gt, int): - s['gt'] = gt - elif s['type'] == 'float' and isinstance(gt, float): - s['gt'] = gt - elif s['type'] == 'decimal' and isinstance(gt, Decimal): - s['gt'] = gt - else: - - def check_gt(v: Any) -> bool: - return v > gt - - s = _check_func(check_gt, f'> {gt}', s) - elif isinstance(constraint, annotated_types.Ge): - ge = constraint.ge - if s and s['type'] in {'int', 'float', 'decimal'}: - s = s.copy() - if s['type'] == 'int' and isinstance(ge, int): - s['ge'] = ge - elif s['type'] == 'float' and isinstance(ge, float): - s['ge'] = 
ge - elif s['type'] == 'decimal' and isinstance(ge, Decimal): - s['ge'] = ge - - def check_ge(v: Any) -> bool: - return v >= ge - - s = _check_func(check_ge, f'>= {ge}', s)
- elif isinstance(constraint, annotated_types.Lt): - lt = constraint.lt - if s and s['type'] in {'int', 'float', 'decimal'}: - s = s.copy() - if s['type'] == 'int' and isinstance(lt, int): - s['lt'] = lt - elif s['type'] == 'float' and isinstance(lt, float): - s['lt'] = lt - elif s['type'] == 'decimal' and isinstance(lt, Decimal): - s['lt'] = lt - - def check_lt(v: Any) -> bool: - return v < lt - - s = _check_func(check_lt, f'< {lt}', s)
- elif isinstance(constraint, annotated_types.Le): - le = constraint.le - if s and s['type'] in {'int', 'float', 'decimal'}: - s = s.copy() - if s['type'] == 'int' and isinstance(le, int): - s['le'] = le - elif s['type'] == 'float' and isinstance(le, float): - s['le'] = le - elif s['type'] == 'decimal' and isinstance(le, Decimal): - s['le'] = le - - def check_le(v: Any) -> bool: - return v <= le - - s = _check_func(check_le, f'<= {le}', s)
- elif isinstance(constraint, annotated_types.Len): - min_len = constraint.min_length - max_len = constraint.max_length - - if s and s['type'] in {'str', 'list', 'tuple', 'set', 'frozenset', 'dict'}: - assert ( - s['type'] == 'str' - or s['type'] == 'list' - or s['type'] == 'tuple' - or s['type'] == 'set' - or s['type'] == 'dict' - or s['type'] == 'frozenset' - ) - s = s.copy() - if min_len != 0: - s['min_length'] = min_len - if max_len is not None: - s['max_length'] = max_len - - def check_len(v: Any) -> bool: - if max_len is not None: - return (min_len <= len(v)) and (len(v) <= max_len) - return min_len <= len(v) - - s = _check_func(check_len, f'length >= {min_len} and length <= {max_len}', s)
- elif isinstance(constraint, annotated_types.MultipleOf): - multiple_of = constraint.multiple_of - if s and s['type'] in {'int', 'float', 'decimal'}: - s = s.copy() - if s['type'] == 'int' and isinstance(multiple_of, int): - s['multiple_of'] = multiple_of - elif s['type'] == 'float' and isinstance(multiple_of, float): - s['multiple_of'] = multiple_of - elif s['type'] == 'decimal' and isinstance(multiple_of, Decimal): - s['multiple_of'] = multiple_of - - def check_multiple_of(v: Any) -> bool: - return v % multiple_of == 0 - - s = _check_func(check_multiple_of, f'% {multiple_of} == 0', s)
- elif isinstance(constraint, annotated_types.Timezone): - tz = constraint.tz - - if tz is ...: - if s and s['type'] == 'datetime': - s = s.copy() - s['tz_constraint'] = 'aware' - else: - - def check_tz_aware(v: object) -> bool: - assert isinstance(v, datetime.datetime) - return v.tzinfo is not None - - s = _check_func(check_tz_aware, 'timezone aware', s) - elif tz is None: - if s and s['type'] == 'datetime': - s = s.copy() - s['tz_constraint'] = 'naive' - else: - - def check_tz_naive(v: object) -> bool: - assert isinstance(v, datetime.datetime) - return v.tzinfo is None - - s = _check_func(check_tz_naive, 'timezone naive', s) - else: - raise NotImplementedError('Constraining to a specific timezone is not yet supported')
- elif isinstance(constraint, annotated_types.Interval): - if constraint.ge: - s = _apply_constraint(s, annotated_types.Ge(constraint.ge)) - if constraint.gt: - s = _apply_constraint(s, annotated_types.Gt(constraint.gt)) - if constraint.le: - s = _apply_constraint(s, annotated_types.Le(constraint.le)) - if constraint.lt: - s = _apply_constraint(s, annotated_types.Lt(constraint.lt)) - assert s is not None
- elif isinstance(constraint, annotated_types.Predicate): - func = constraint.func - - if func.__name__ == '<lambda>': - # attempt to extract the source code for a lambda function - # to use as the function name in error messages - # TODO: is there a better way? should we just not do this? - import inspect - - try: - source = inspect.getsource(func).strip() - source = source.removesuffix(')') - lambda_source_code = '`' + ''.join(''.join(source.split('lambda ')[1:]).split(':')[1:]).strip() + '`' - except OSError: - # stringified annotations - lambda_source_code = 'lambda' - - s = _check_func(func, lambda_source_code, s) - else: - s = _check_func(func, func.__name__, s)
- elif isinstance(constraint, _NotEq): - value = constraint.value - - def check_not_eq(v: Any) -> bool: - return operator.__ne__(v, value) - - s = _check_func(check_not_eq, f'!= {value}', s) - elif isinstance(constraint, _Eq): - value = constraint.value - - def check_eq(v: Any) -> bool: - return operator.__eq__(v, value) - - s = _check_func(check_eq, f'== {value}', s) - elif isinstance(constraint, _In): - values = constraint.values - - def check_in(v: Any) -> bool: - return operator.__contains__(values, v) - - s = _check_func(check_in, f'in {values}', s) - elif isinstance(constraint, _NotIn): - values = constraint.values - - def check_not_in(v: Any) -> bool: - return operator.__not__(operator.__contains__(values, v)) - - s = _check_func(check_not_in, f'not in {values}', s)
- else: - assert isinstance(constraint, Pattern) - if s and s['type'] == 'str': - s = s.copy() - s['pattern'] = constraint.pattern - else: - - def check_pattern(v: object) -> bool: - assert isinstance(v, str) - return constraint.match(v) is not None - - s = _check_func(check_pattern, f'~ {constraint.pattern}', s) - return s - - -class _SupportsRange(annotated_types.SupportsLe, annotated_types.SupportsGe, Protocol): - pass - - -class _SupportsLen(Protocol): - def __len__(self) -> int: ...
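For context on the module deleted above: `pydantic/experimental/pipeline.py` provides a fluent validation pipeline that is attached to fields via `Annotated`. A minimal, illustrative sketch of its use, assuming a pydantic release that still ships `pydantic.experimental` (roughly 2.8 or later); this is not code from this repository:

```python
from typing import Annotated

from pydantic import BaseModel
from pydantic.experimental.pipeline import validate_as


class Signup(BaseModel):
    # strip whitespace, lower-case, then require a suffix
    email: Annotated[str, validate_as(str).str_strip().str_lower().str_ends_with('@example.com')]
    # parse to int, then constrain the range
    age: Annotated[int, validate_as(int).gt(0).le(130)]


Signup(email='  USER@EXAMPLE.COM ', age=42)  # validates; age=0 would raise a ValidationError
```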
- - -_NewOutGt = TypeVar('_NewOutGt', bound=annotated_types.SupportsGt) -_NewOutGe = TypeVar('_NewOutGe', bound=annotated_types.SupportsGe) -_NewOutLt = TypeVar('_NewOutLt', bound=annotated_types.SupportsLt) -_NewOutLe = TypeVar('_NewOutLe', bound=annotated_types.SupportsLe) -_NewOutLen = TypeVar('_NewOutLen', bound=_SupportsLen) -_NewOutDiv = TypeVar('_NewOutDiv', bound=annotated_types.SupportsDiv) -_NewOutMod = TypeVar('_NewOutMod', bound=annotated_types.SupportsMod) -_NewOutDatetime = TypeVar('_NewOutDatetime', bound=datetime.datetime) -_NewOutInterval = TypeVar('_NewOutInterval', bound=_SupportsRange) -_OtherIn = TypeVar('_OtherIn') -_OtherOut = TypeVar('_OtherOut') diff --git a/venv/lib/python3.12/site-packages/pydantic/fields.py b/venv/lib/python3.12/site-packages/pydantic/fields.py index 408fc2b..2a8f768 100644 --- a/venv/lib/python3.12/site-packages/pydantic/fields.py +++ b/venv/lib/python3.12/site-packages/pydantic/fields.py @@ -1,32 +1,29 @@ """Defining fields on models.""" - from __future__ import annotations as _annotations import dataclasses import inspect import sys import typing -from collections.abc import Mapping from copy import copy from dataclasses import Field as DataclassField -from functools import cached_property -from typing import Annotated, Any, Callable, ClassVar, Literal, TypeVar, cast, overload + +try: + from functools import cached_property # type: ignore +except ImportError: + # python 3.7 + cached_property = None +from typing import Any, ClassVar from warnings import warn import annotated_types import typing_extensions from pydantic_core import PydanticUndefined -from typing_extensions import Self, TypeAlias, Unpack, deprecated -from typing_inspection import typing_objects -from typing_inspection.introspection import UNKNOWN, AnnotationSource, ForbiddenQualifier, Qualifier, inspect_annotation +from typing_extensions import Literal, Unpack from . 
import types from ._internal import _decorators, _fields, _generics, _internal_dataclass, _repr, _typing_extra, _utils -from ._internal._namespace_utils import GlobalsNamespace, MappingNamespace -from .aliases import AliasChoices, AliasPath -from .config import JsonDict -from .errors import PydanticForbiddenQualifier, PydanticUserError -from .json_schema import PydanticJsonSchemaWarning +from .errors import PydanticUserError from .warnings import PydanticDeprecatedSince20 if typing.TYPE_CHECKING: @@ -36,58 +33,43 @@ else: # and https://youtrack.jetbrains.com/issue/PY-51428 DeprecationWarning = PydanticDeprecatedSince20 -__all__ = 'Field', 'PrivateAttr', 'computed_field' - _Unset: Any = PydanticUndefined -if sys.version_info >= (3, 13): - import warnings - - Deprecated: TypeAlias = warnings.deprecated | deprecated -else: - Deprecated: TypeAlias = deprecated - class _FromFieldInfoInputs(typing_extensions.TypedDict, total=False): """This class exists solely to add type checking for the `**kwargs` in `FieldInfo.from_field`.""" - # TODO PEP 747: use TypeForm: annotation: type[Any] | None - default_factory: Callable[[], Any] | Callable[[dict[str, Any]], Any] | None + default_factory: typing.Callable[[], Any] | None alias: str | None alias_priority: int | None validation_alias: str | AliasPath | AliasChoices | None serialization_alias: str | None title: str | None - field_title_generator: Callable[[str, FieldInfo], str] | None description: str | None examples: list[Any] | None exclude: bool | None - gt: annotated_types.SupportsGt | None - ge: annotated_types.SupportsGe | None - lt: annotated_types.SupportsLt | None - le: annotated_types.SupportsLe | None + gt: float | None + ge: float | None + lt: float | None + le: float | None multiple_of: float | None strict: bool | None min_length: int | None max_length: int | None - pattern: str | typing.Pattern[str] | None + pattern: str | None allow_inf_nan: bool | None max_digits: int | None decimal_places: int | None union_mode: Literal['smart', 'left_to_right'] | None - discriminator: str | types.Discriminator | None - deprecated: Deprecated | str | bool | None - json_schema_extra: JsonDict | Callable[[JsonDict], None] | None + discriminator: str | None + json_schema_extra: dict[str, Any] | typing.Callable[[dict[str, Any]], None] | None frozen: bool | None validate_default: bool | None repr: bool - init: bool | None init_var: bool | None kw_only: bool | None - coerce_numbers_to_str: bool | None - fail_fast: bool | None class _FieldInfoInputs(_FromFieldInfoInputs, total=False): @@ -109,49 +91,41 @@ class FieldInfo(_repr.Representation): Attributes: annotation: The type annotation of the field. default: The default value of the field. - default_factory: A callable to generate the default value. The callable can either take 0 arguments - (in which case it is called as is) or a single argument containing the already validated data. + default_factory: The factory function used to construct the default for the field. alias: The alias name of the field. alias_priority: The priority of the field's alias. - validation_alias: The validation alias of the field. - serialization_alias: The serialization alias of the field. + validation_alias: The validation alias name of the field. + serialization_alias: The serialization alias name of the field. title: The title of the field. - field_title_generator: A callable that takes a field name and returns title for it. description: The description of the field. examples: List of examples of the field. 
exclude: Whether to exclude the field from the model serialization. - discriminator: Field name or Discriminator for discriminating the type in a tagged union. - deprecated: A deprecation message, an instance of `warnings.deprecated` or the `typing_extensions.deprecated` backport, - or a boolean. If `True`, a default deprecation message will be emitted when accessing the field. - json_schema_extra: A dict or callable to provide extra JSON schema properties. + discriminator: Field name for discriminating the type in a tagged union. + json_schema_extra: Dictionary of extra JSON schema properties. frozen: Whether the field is frozen. validate_default: Whether to validate the default value of the field. repr: Whether to include the field in representation of the model. - init: Whether the field should be included in the constructor of the dataclass. - init_var: Whether the field should _only_ be included in the constructor of the dataclass, and not stored. + init_var: Whether the field should be included in the constructor of the dataclass. kw_only: Whether the field should be a keyword-only argument in the constructor of the dataclass. metadata: List of metadata constraints. """ annotation: type[Any] | None default: Any - default_factory: Callable[[], Any] | Callable[[dict[str, Any]], Any] | None + default_factory: typing.Callable[[], Any] | None alias: str | None alias_priority: int | None validation_alias: str | AliasPath | AliasChoices | None serialization_alias: str | None title: str | None - field_title_generator: Callable[[str, FieldInfo], str] | None description: str | None examples: list[Any] | None exclude: bool | None - discriminator: str | types.Discriminator | None - deprecated: Deprecated | str | bool | None - json_schema_extra: JsonDict | Callable[[JsonDict], None] | None + discriminator: str | None + json_schema_extra: dict[str, Any] | typing.Callable[[dict[str, Any]], None] | None frozen: bool | None validate_default: bool | None repr: bool - init: bool | None init_var: bool | None kw_only: bool | None metadata: list[Any] @@ -165,25 +139,18 @@ class FieldInfo(_repr.Representation): 'validation_alias', 'serialization_alias', 'title', - 'field_title_generator', 'description', 'examples', 'exclude', 'discriminator', - 'deprecated', 'json_schema_extra', 'frozen', 'validate_default', 'repr', - 'init', 'init_var', 'kw_only', 'metadata', '_attributes_set', - '_qualifiers', - '_complete', - '_original_assignment', - '_original_annotation', ) # used to convert kwargs to metadata/constraints, @@ -202,8 +169,6 @@ class FieldInfo(_repr.Representation): 'max_digits': None, 'decimal_places': None, 'union_mode': None, - 'coerce_numbers_to_str': None, - 'fail_fast': types.FailFast, } def __init__(self, **kwargs: Unpack[_FieldInfoInputs]) -> None: @@ -214,12 +179,11 @@ class FieldInfo(_repr.Representation): """ self._attributes_set = {k: v for k, v in kwargs.items() if v is not _Unset} kwargs = {k: _DefaultValues.get(k) if v is _Unset else v for k, v in kwargs.items()} # type: ignore - self.annotation = kwargs.get('annotation') + self.annotation, annotation_metadata = self._extract_metadata(kwargs.get('annotation')) default = kwargs.pop('default', PydanticUndefined) if default is Ellipsis: self.default = PydanticUndefined - self._attributes_set.pop('default', None) else: self.default = default @@ -228,39 +192,30 @@ class FieldInfo(_repr.Representation): if self.default is not PydanticUndefined and self.default_factory is not None: raise TypeError('cannot specify both default and 
default_factory') + self.title = kwargs.pop('title', None) self.alias = kwargs.pop('alias', None) self.validation_alias = kwargs.pop('validation_alias', None) self.serialization_alias = kwargs.pop('serialization_alias', None) alias_is_set = any(alias is not None for alias in (self.alias, self.validation_alias, self.serialization_alias)) self.alias_priority = kwargs.pop('alias_priority', None) or 2 if alias_is_set else None - self.title = kwargs.pop('title', None) - self.field_title_generator = kwargs.pop('field_title_generator', None) self.description = kwargs.pop('description', None) self.examples = kwargs.pop('examples', None) self.exclude = kwargs.pop('exclude', None) self.discriminator = kwargs.pop('discriminator', None) - # For compatibility with FastAPI<=0.110.0, we preserve the existing value if it is not overridden - self.deprecated = kwargs.pop('deprecated', getattr(self, 'deprecated', None)) self.repr = kwargs.pop('repr', True) self.json_schema_extra = kwargs.pop('json_schema_extra', None) self.validate_default = kwargs.pop('validate_default', None) self.frozen = kwargs.pop('frozen', None) # currently only used on dataclasses - self.init = kwargs.pop('init', None) self.init_var = kwargs.pop('init_var', None) self.kw_only = kwargs.pop('kw_only', None) - self.metadata = self._collect_metadata(kwargs) # type: ignore + self.metadata = self._collect_metadata(kwargs) + annotation_metadata # type: ignore - # Private attributes: - self._qualifiers: set[Qualifier] = set() - # Used to rebuild FieldInfo instances: - self._complete = True - self._original_annotation: Any = PydanticUndefined - self._original_assignment: Any = PydanticUndefined - - @staticmethod - def from_field(default: Any = PydanticUndefined, **kwargs: Unpack[_FromFieldInfoInputs]) -> FieldInfo: + @classmethod + def from_field( + cls, default: Any = PydanticUndefined, **kwargs: Unpack[_FromFieldInfoInputs] + ) -> typing_extensions.Self: """Create a new `FieldInfo` object with the `Field` function. Args: @@ -285,191 +240,146 @@ class FieldInfo(_repr.Representation): """ if 'annotation' in kwargs: raise TypeError('"annotation" is not permitted as a Field keyword argument') - return FieldInfo(default=default, **kwargs) + return cls(default=default, **kwargs) - @staticmethod - def from_annotation(annotation: type[Any], *, _source: AnnotationSource = AnnotationSource.ANY) -> FieldInfo: + @classmethod + def from_annotation(cls, annotation: type[Any]) -> typing_extensions.Self: """Creates a `FieldInfo` instance from a bare annotation. - This function is used internally to create a `FieldInfo` from a bare annotation like this: - - ```python - import pydantic - - class MyModel(pydantic.BaseModel): - foo: int # <-- like this - ``` - - We also account for the case where the annotation can be an instance of `Annotated` and where - one of the (not first) arguments in `Annotated` is an instance of `FieldInfo`, e.g.: - - ```python - from typing import Annotated - - import annotated_types - - import pydantic - - class MyModel(pydantic.BaseModel): - foo: Annotated[int, annotated_types.Gt(42)] - bar: Annotated[int, pydantic.Field(gt=42)] - ``` - Args: annotation: An annotation object. Returns: An instance of the field metadata. 
+ + Example: + This is how you can create a field from a bare annotation like this: + + ```python + import pydantic + + class MyModel(pydantic.BaseModel): + foo: int # <-- like this + ``` + + We also account for the case where the annotation can be an instance of `Annotated` and where + one of the (not first) arguments in `Annotated` are an instance of `FieldInfo`, e.g.: + + ```python + import annotated_types + from typing_extensions import Annotated + + import pydantic + + class MyModel(pydantic.BaseModel): + foo: Annotated[int, annotated_types.Gt(42)] + bar: Annotated[int, pydantic.Field(gt=42)] + ``` + """ - try: - inspected_ann = inspect_annotation( - annotation, - annotation_source=_source, - unpack_type_aliases='skip', - ) - except ForbiddenQualifier as e: - raise PydanticForbiddenQualifier(e.qualifier, annotation) + final = False + if _typing_extra.is_finalvar(annotation): + final = True + if annotation is not typing_extensions.Final: + annotation = typing_extensions.get_args(annotation)[0] - # TODO check for classvar and error? + if _typing_extra.is_annotated(annotation): + first_arg, *extra_args = typing_extensions.get_args(annotation) + if _typing_extra.is_finalvar(first_arg): + final = True + field_info_annotations = [a for a in extra_args if isinstance(a, FieldInfo)] + field_info = cls.merge_field_infos(*field_info_annotations, annotation=first_arg) + if field_info: + new_field_info = copy(field_info) + new_field_info.annotation = first_arg + new_field_info.frozen = final or field_info.frozen + metadata: list[Any] = [] + for a in extra_args: + if not isinstance(a, FieldInfo): + metadata.append(a) + else: + metadata.extend(a.metadata) + new_field_info.metadata = metadata + return new_field_info - # No assigned value, this happens when using a bare `Final` qualifier (also for other - # qualifiers, but they shouldn't appear here). In this case we infer the type as `Any` - # because we don't have any assigned value. - type_expr: Any = Any if inspected_ann.type is UNKNOWN else inspected_ann.type - final = 'final' in inspected_ann.qualifiers - metadata = inspected_ann.metadata + return cls(annotation=annotation, frozen=final or None) - if not metadata: - # No metadata, e.g. `field: int`, or `field: Final[str]`: - field_info = FieldInfo(annotation=type_expr, frozen=final or None) - field_info._qualifiers = inspected_ann.qualifiers - return field_info - - # With metadata, e.g. `field: Annotated[int, Field(...), Gt(1)]`: - field_info_annotations = [a for a in metadata if isinstance(a, FieldInfo)] - field_info = FieldInfo.merge_field_infos(*field_info_annotations, annotation=type_expr) - - new_field_info = field_info._copy() - new_field_info.annotation = type_expr - new_field_info.frozen = final or field_info.frozen - field_metadata: list[Any] = [] - for a in metadata: - if typing_objects.is_deprecated(a): - new_field_info.deprecated = a.message - elif not isinstance(a, FieldInfo): - field_metadata.append(a) - else: - field_metadata.extend(a.metadata) - new_field_info.metadata = field_metadata - new_field_info._qualifiers = inspected_ann.qualifiers - return new_field_info - - @staticmethod - def from_annotated_attribute( - annotation: type[Any], default: Any, *, _source: AnnotationSource = AnnotationSource.ANY - ) -> FieldInfo: + @classmethod + def from_annotated_attribute(cls, annotation: type[Any], default: Any) -> typing_extensions.Self: """Create `FieldInfo` from an annotation with a default value. 
- This is used in cases like the following: - - ```python - from typing import Annotated - - import annotated_types - - import pydantic - - class MyModel(pydantic.BaseModel): - foo: int = 4 # <-- like this - bar: Annotated[int, annotated_types.Gt(4)] = 4 # <-- or this - spam: Annotated[int, pydantic.Field(gt=4)] = 4 # <-- or this - ``` - Args: annotation: The type annotation of the field. default: The default value of the field. Returns: A field object with the passed values. + + Example: + ```python + import annotated_types + from typing_extensions import Annotated + + import pydantic + + class MyModel(pydantic.BaseModel): + foo: int = 4 # <-- like this + bar: Annotated[int, annotated_types.Gt(4)] = 4 # <-- or this + spam: Annotated[int, pydantic.Field(gt=4)] = 4 # <-- or this + ``` """ - if annotation is default: - raise PydanticUserError( - 'Error when building FieldInfo from annotated attribute. ' - "Make sure you don't have any field name clashing with a type annotation.", - code='unevaluable-type-annotation', + final = False + if _typing_extra.is_finalvar(annotation): + final = True + if annotation is not typing_extensions.Final: + annotation = typing_extensions.get_args(annotation)[0] + + if isinstance(default, cls): + default.annotation, annotation_metadata = cls._extract_metadata(annotation) + default.metadata += annotation_metadata + default = default.merge_field_infos( + *[x for x in annotation_metadata if isinstance(x, cls)], default, annotation=default.annotation ) + default.frozen = final or default.frozen + return default + elif isinstance(default, dataclasses.Field): + init_var = False + if annotation is dataclasses.InitVar: + if sys.version_info < (3, 8): + raise RuntimeError('InitVar is not supported in Python 3.7 as type information is lost') - try: - inspected_ann = inspect_annotation( - annotation, - annotation_source=_source, - unpack_type_aliases='skip', - ) - except ForbiddenQualifier as e: - raise PydanticForbiddenQualifier(e.qualifier, annotation) - - # TODO check for classvar and error? - - # TODO infer from the default, this can be done in v3 once we treat final fields with - # a default as proper fields and not class variables: - type_expr: Any = Any if inspected_ann.type is UNKNOWN else inspected_ann.type - final = 'final' in inspected_ann.qualifiers - metadata = inspected_ann.metadata - - if isinstance(default, FieldInfo): - # e.g. `field: int = Field(...)` - default_metadata = default.metadata.copy() - default = copy(default) - default.metadata = default_metadata - - default.annotation = type_expr - default.metadata += metadata - merged_default = FieldInfo.merge_field_infos( - *[x for x in metadata if isinstance(x, FieldInfo)], - default, - annotation=default.annotation, - ) - merged_default.frozen = final or merged_default.frozen - merged_default._qualifiers = inspected_ann.qualifiers - return merged_default - - if isinstance(default, dataclasses.Field): - # `collect_dataclass_fields()` passes the dataclass Field as a default. 
- pydantic_field = FieldInfo._from_dataclass_field(default) - pydantic_field.annotation = type_expr - pydantic_field.metadata += metadata - pydantic_field = FieldInfo.merge_field_infos( - *[x for x in metadata if isinstance(x, FieldInfo)], + init_var = True + annotation = Any + elif isinstance(annotation, dataclasses.InitVar): + init_var = True + annotation = annotation.type + pydantic_field = cls._from_dataclass_field(default) + pydantic_field.annotation, annotation_metadata = cls._extract_metadata(annotation) + pydantic_field.metadata += annotation_metadata + pydantic_field = pydantic_field.merge_field_infos( + *[x for x in annotation_metadata if isinstance(x, cls)], pydantic_field, annotation=pydantic_field.annotation, ) pydantic_field.frozen = final or pydantic_field.frozen - pydantic_field.init_var = 'init_var' in inspected_ann.qualifiers - pydantic_field.init = getattr(default, 'init', None) + pydantic_field.init_var = init_var pydantic_field.kw_only = getattr(default, 'kw_only', None) - pydantic_field._qualifiers = inspected_ann.qualifiers return pydantic_field + else: + if _typing_extra.is_annotated(annotation): + first_arg, *extra_args = typing_extensions.get_args(annotation) + field_infos = [a for a in extra_args if isinstance(a, FieldInfo)] + field_info = cls.merge_field_infos(*field_infos, annotation=first_arg, default=default) + metadata: list[Any] = [] + for a in extra_args: + if not isinstance(a, FieldInfo): + metadata.append(a) + else: + metadata.extend(a.metadata) + field_info.metadata = metadata + return field_info - if not metadata: - # No metadata, e.g. `field: int = ...`, or `field: Final[str] = ...`: - field_info = FieldInfo(annotation=type_expr, default=default, frozen=final or None) - field_info._qualifiers = inspected_ann.qualifiers - return field_info - - # With metadata, e.g. `field: Annotated[int, Field(...), Gt(1)] = ...`: - field_infos = [a for a in metadata if isinstance(a, FieldInfo)] - field_info = FieldInfo.merge_field_infos(*field_infos, annotation=type_expr, default=default) - field_metadata: list[Any] = [] - for a in metadata: - if typing_objects.is_deprecated(a): - field_info.deprecated = a.message - elif not isinstance(a, FieldInfo): - field_metadata.append(a) - else: - field_metadata.extend(a.metadata) - field_info.metadata = field_metadata - field_info._qualifiers = inspected_ann.qualifiers - return field_info + return cls(annotation=annotation, default=default, frozen=final or None) @staticmethod def merge_field_infos(*field_infos: FieldInfo, **overrides: Any) -> FieldInfo: @@ -480,65 +390,33 @@ class FieldInfo(_repr.Representation): Returns: FieldInfo: A merged FieldInfo instance. 
""" + flattened_field_infos: list[FieldInfo] = [] + for field_info in field_infos: + flattened_field_infos.extend(x for x in field_info.metadata if isinstance(x, FieldInfo)) + flattened_field_infos.append(field_info) + field_infos = tuple(flattened_field_infos) if len(field_infos) == 1: # No merging necessary, but we still need to make a copy and apply the overrides - field_info = field_infos[0]._copy() + field_info = copy(field_infos[0]) field_info._attributes_set.update(overrides) - - default_override = overrides.pop('default', PydanticUndefined) - if default_override is Ellipsis: - default_override = PydanticUndefined - if default_override is not PydanticUndefined: - field_info.default = default_override - for k, v in overrides.items(): setattr(field_info, k, v) - return field_info # type: ignore + return field_info - merged_field_info_kwargs: dict[str, Any] = {} + new_kwargs: dict[str, Any] = {} metadata = {} for field_info in field_infos: - attributes_set = field_info._attributes_set.copy() - - try: - json_schema_extra = attributes_set.pop('json_schema_extra') - existing_json_schema_extra = merged_field_info_kwargs.get('json_schema_extra') - - if existing_json_schema_extra is None: - merged_field_info_kwargs['json_schema_extra'] = json_schema_extra - if isinstance(existing_json_schema_extra, dict): - if isinstance(json_schema_extra, dict): - merged_field_info_kwargs['json_schema_extra'] = { - **existing_json_schema_extra, - **json_schema_extra, - } - if callable(json_schema_extra): - warn( - 'Composing `dict` and `callable` type `json_schema_extra` is not supported.' - 'The `callable` type is being ignored.' - "If you'd like support for this behavior, please open an issue on pydantic.", - PydanticJsonSchemaWarning, - ) - elif callable(json_schema_extra): - # if ever there's a case of a callable, we'll just keep the last json schema extra spec - merged_field_info_kwargs['json_schema_extra'] = json_schema_extra - except KeyError: - pass - - # later FieldInfo instances override everything except json_schema_extra from earlier FieldInfo instances - merged_field_info_kwargs.update(attributes_set) - + new_kwargs.update(field_info._attributes_set) for x in field_info.metadata: if not isinstance(x, FieldInfo): metadata[type(x)] = x - - merged_field_info_kwargs.update(overrides) - field_info = FieldInfo(**merged_field_info_kwargs) + new_kwargs.update(overrides) + field_info = FieldInfo(**new_kwargs) field_info.metadata = list(metadata.values()) return field_info - @staticmethod - def _from_dataclass_field(dc_field: DataclassField[Any]) -> FieldInfo: + @classmethod + def _from_dataclass_field(cls, dc_field: DataclassField[Any]) -> typing_extensions.Self: """Return a new `FieldInfo` instance from a `dataclasses.Field` instance. 
Args: @@ -552,21 +430,41 @@ class FieldInfo(_repr.Representation): """ default = dc_field.default if default is dataclasses.MISSING: - default = _Unset + default = PydanticUndefined if dc_field.default_factory is dataclasses.MISSING: - default_factory = _Unset + default_factory: typing.Callable[[], Any] | None = None else: default_factory = dc_field.default_factory # use the `Field` function so in correct kwargs raise the correct `TypeError` dc_field_metadata = {k: v for k, v in dc_field.metadata.items() if k in _FIELD_ARG_NAMES} - return Field(default=default, default_factory=default_factory, repr=dc_field.repr, **dc_field_metadata) # pyright: ignore[reportCallIssue] + return Field(default=default, default_factory=default_factory, repr=dc_field.repr, **dc_field_metadata) - @staticmethod - def _collect_metadata(kwargs: dict[str, Any]) -> list[Any]: + @classmethod + def _extract_metadata(cls, annotation: type[Any] | None) -> tuple[type[Any] | None, list[Any]]: + """Tries to extract metadata/constraints from an annotation if it uses `Annotated`. + + Args: + annotation: The type hint annotation for which metadata has to be extracted. + + Returns: + A tuple containing the extracted metadata type and the list of extra arguments. + """ + if annotation is not None: + if _typing_extra.is_annotated(annotation): + first_arg, *extra_args = typing_extensions.get_args(annotation) + return first_arg, list(extra_args) + + return annotation, [] + + @classmethod + def _collect_metadata(cls, kwargs: dict[str, Any]) -> list[Any]: """Collect annotations from kwargs. + The return type is actually `annotated_types.BaseMetadata | PydanticMetadata`, + but it gets combined with `list[Any]` from `Annotated[T, ...]`, hence types. + Args: kwargs: Keyword arguments passed to the function. @@ -578,7 +476,7 @@ class FieldInfo(_repr.Representation): general_metadata = {} for key, value in list(kwargs.items()): try: - marker = FieldInfo.metadata_lookup[key] + marker = cls.metadata_lookup[key] except KeyError: continue @@ -589,45 +487,10 @@ class FieldInfo(_repr.Representation): else: metadata.append(marker(value)) if general_metadata: - metadata.append(_fields.pydantic_general_metadata(**general_metadata)) + metadata.append(_fields.PydanticGeneralMetadata(**general_metadata)) return metadata - def _copy(self) -> Self: - copied = copy(self) - for attr_name in ('metadata', '_attributes_set', '_qualifiers'): - # Apply "deep-copy" behavior on collections attributes: - value = getattr(copied, attr_name).copy() - setattr(copied, attr_name, value) - - return copied - - @property - def deprecation_message(self) -> str | None: - """The deprecation message to be emitted, or `None` if not set.""" - if self.deprecated is None: - return None - if isinstance(self.deprecated, bool): - return 'deprecated' if self.deprecated else None - return self.deprecated if isinstance(self.deprecated, str) else self.deprecated.message - - @property - def default_factory_takes_validated_data(self) -> bool | None: - """Whether the provided default factory callable has a validated data parameter. - - Returns `None` if no default factory is set. - """ - if self.default_factory is not None: - return _fields.takes_validated_data_argument(self.default_factory) - - @overload - def get_default( - self, *, call_default_factory: Literal[True], validated_data: dict[str, Any] | None = None - ) -> Any: ... - - @overload - def get_default(self, *, call_default_factory: Literal[False] = ...) -> Any: ... 
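As context for the `get_default` changes below: the `validated_data` plumbing removed in this hunk belongs to newer pydantic (2.10 or later), where a one-argument `default_factory` receives the already-validated data. A hypothetical sketch of that behaviour, not project code and not available on the 2.4-era version being added here:

```python
from pydantic import BaseModel, Field


class NutritionEntry(BaseModel):
    quantity: float
    calories_per_unit: float = 100.0
    # with pydantic >= 2.10, a one-argument default_factory receives the data
    # validated so far, so the default can depend on earlier fields
    calories: float = Field(default_factory=lambda data: data['quantity'] * data['calories_per_unit'])


NutritionEntry(quantity=2.5).calories  # 250.0
```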
- - def get_default(self, *, call_default_factory: bool = False, validated_data: dict[str, Any] | None = None) -> Any: + def get_default(self, *, call_default_factory: bool = False) -> Any: """Get the default value. We expose an option for whether to call the default_factory (if present), as calling it may @@ -635,8 +498,7 @@ class FieldInfo(_repr.Representation): be called (namely, when instantiating a model via `model_construct`). Args: - call_default_factory: Whether to call the default factory or not. - validated_data: The already validated data to be passed to the default factory. + call_default_factory: Whether to call the default_factory or not. Defaults to `False`. Returns: The default value, calling the default factory if requested or `None` if not set. @@ -644,36 +506,23 @@ class FieldInfo(_repr.Representation): if self.default_factory is None: return _utils.smart_deepcopy(self.default) elif call_default_factory: - if self.default_factory_takes_validated_data: - fac = cast('Callable[[dict[str, Any]], Any]', self.default_factory) - if validated_data is None: - raise ValueError( - "The default factory requires the 'validated_data' argument, which was not provided when calling 'get_default'." - ) - return fac(validated_data) - else: - fac = cast('Callable[[], Any]', self.default_factory) - return fac() + return self.default_factory() else: return None def is_required(self) -> bool: - """Check if the field is required (i.e., does not have a default value or factory). + """Check if the argument is required. Returns: - `True` if the field is required, `False` otherwise. + `True` if the argument is required, `False` otherwise. """ return self.default is PydanticUndefined and self.default_factory is None def rebuild_annotation(self) -> Any: - """Attempts to rebuild the original annotation for use in function signatures. + """Rebuilds the original annotation for use in function signatures. - If metadata is present, it adds it to the original annotation using - `Annotated`. Otherwise, it returns the original annotation as-is. - - Note that because the metadata has been flattened, the original annotation - may not be reconstructed exactly as originally provided, e.g. if the original - type had unrecognized annotations, or was annotated with a call to `pydantic.Field`. + If metadata is present, it adds it to the original annotation using an + `AnnotatedAlias`. Otherwise, it returns the original annotation as is. Returns: The rebuilt annotation. @@ -682,14 +531,9 @@ class FieldInfo(_repr.Representation): return self.annotation else: # Annotated arguments must be a tuple - return Annotated[(self.annotation, *self.metadata)] # type: ignore + return typing_extensions.Annotated[(self.annotation, *self.metadata)] # type: ignore - def apply_typevars_map( - self, - typevars_map: Mapping[TypeVar, Any] | None, - globalns: GlobalsNamespace | None = None, - localns: MappingNamespace | None = None, - ) -> None: + def apply_typevars_map(self, typevars_map: dict[Any, Any] | None, types_namespace: dict[str, Any] | None) -> None: """Apply a `typevars_map` to the annotation. This method is used when analyzing parametrized generic types to replace typevars with their concrete types. @@ -698,35 +542,23 @@ class FieldInfo(_repr.Representation): Args: typevars_map: A dictionary mapping type variables to their concrete types. - globalns: The globals namespace to use during type annotation evaluation. - localns: The locals namespace to use during type annotation evaluation. 
+ types_namespace (dict | None): A dictionary containing related types to the annotated type. See Also: pydantic._internal._generics.replace_types is used for replacing the typevars with their concrete types. """ - annotation = _generics.replace_types(self.annotation, typevars_map) - annotation, evaluated = _typing_extra.try_eval_type(annotation, globalns, localns) - self.annotation = annotation - if not evaluated: - self._complete = False - self._original_annotation = self.annotation + annotation = _typing_extra.eval_type_lenient(self.annotation, types_namespace, None) + self.annotation = _generics.replace_types(annotation, typevars_map) def __repr_args__(self) -> ReprArgs: yield 'annotation', _repr.PlainRepr(_repr.display_as_type(self.annotation)) yield 'required', self.is_required() for s in self.__slots__: - # TODO: properly make use of the protocol (https://rich.readthedocs.io/en/stable/pretty.html#rich-repr-protocol) - # By yielding a three-tuple: - if s in ( - 'annotation', - '_attributes_set', - '_qualifiers', - '_complete', - '_original_assignment', - '_original_annotation', - ): + if s == '_attributes_set': + continue + if s == 'annotation': continue elif s == 'metadata' and not self.metadata: continue @@ -738,9 +570,7 @@ class FieldInfo(_repr.Representation): continue if s == 'serialization_alias' and self.serialization_alias == self.alias: continue - if s == 'default' and self.default is not PydanticUndefined: - yield 'default', self.default - elif s == 'default_factory' and self.default_factory is not None: + if s == 'default_factory' and self.default_factory is not None: yield 'default_factory', _repr.PlainRepr(_repr.display_as_type(self.default_factory)) else: value = getattr(self, s) @@ -748,234 +578,122 @@ class FieldInfo(_repr.Representation): yield s, value +@dataclasses.dataclass(**_internal_dataclass.slots_true) +class AliasPath: + """Usage docs: https://docs.pydantic.dev/2.4/concepts/fields#aliaspath-and-aliaschoices + + A data class used by `validation_alias` as a convenience to create aliases. + + Attributes: + path: A list of string or integer aliases. + """ + + path: list[int | str] + + def __init__(self, first_arg: str, *args: str | int) -> None: + self.path = [first_arg] + list(args) + + def convert_to_aliases(self) -> list[str | int]: + """Converts arguments to a list of string or integer aliases. + + Returns: + The list of aliases. + """ + return self.path + + +@dataclasses.dataclass(**_internal_dataclass.slots_true) +class AliasChoices: + """Usage docs: https://docs.pydantic.dev/2.4/concepts/fields#aliaspath-and-aliaschoices + + A data class used by `validation_alias` as a convenience to create aliases. + + Attributes: + choices: A list containing a string or `AliasPath`. + """ + + choices: list[str | AliasPath] + + def __init__(self, first_choice: str | AliasPath, *choices: str | AliasPath) -> None: + self.choices = [first_choice] + list(choices) + + def convert_to_aliases(self) -> list[list[str | int]]: + """Converts arguments to a list of lists containing string or integer aliases. + + Returns: + The list of aliases. 
+ """ + aliases: list[list[str | int]] = [] + for c in self.choices: + if isinstance(c, AliasPath): + aliases.append(c.convert_to_aliases()) + else: + aliases.append([c]) + return aliases + + class _EmptyKwargs(typing_extensions.TypedDict): """This class exists solely to ensure that type checking warns about passing `**extra` in `Field`.""" -_DefaultValues = { - 'default': ..., - 'default_factory': None, - 'alias': None, - 'alias_priority': None, - 'validation_alias': None, - 'serialization_alias': None, - 'title': None, - 'description': None, - 'examples': None, - 'exclude': None, - 'discriminator': None, - 'json_schema_extra': None, - 'frozen': None, - 'validate_default': None, - 'repr': True, - 'init': None, - 'init_var': None, - 'kw_only': None, - 'pattern': None, - 'strict': None, - 'gt': None, - 'ge': None, - 'lt': None, - 'le': None, - 'multiple_of': None, - 'allow_inf_nan': None, - 'max_digits': None, - 'decimal_places': None, - 'min_length': None, - 'max_length': None, - 'coerce_numbers_to_str': None, -} +_DefaultValues = dict( + default=..., + default_factory=None, + alias=None, + alias_priority=None, + validation_alias=None, + serialization_alias=None, + title=None, + description=None, + examples=None, + exclude=None, + discriminator=None, + json_schema_extra=None, + frozen=None, + validate_default=None, + repr=True, + init_var=None, + kw_only=None, + pattern=None, + strict=None, + gt=None, + ge=None, + lt=None, + le=None, + multiple_of=None, + allow_inf_nan=None, + max_digits=None, + decimal_places=None, + min_length=None, + max_length=None, +) -_T = TypeVar('_T') - - -# NOTE: Actual return type is 'FieldInfo', but we want to help type checkers -# to understand the magic that happens at runtime with the following overloads: -@overload # type hint the return value as `Any` to avoid type checking regressions when using `...`. -def Field( - default: ellipsis, # noqa: F821 # TODO: use `_typing_extra.EllipsisType` when we drop Py3.9 - *, - alias: str | None = _Unset, - alias_priority: int | None = _Unset, - validation_alias: str | AliasPath | AliasChoices | None = _Unset, - serialization_alias: str | None = _Unset, - title: str | None = _Unset, - field_title_generator: Callable[[str, FieldInfo], str] | None = _Unset, - description: str | None = _Unset, - examples: list[Any] | None = _Unset, - exclude: bool | None = _Unset, - discriminator: str | types.Discriminator | None = _Unset, - deprecated: Deprecated | str | bool | None = _Unset, - json_schema_extra: JsonDict | Callable[[JsonDict], None] | None = _Unset, - frozen: bool | None = _Unset, - validate_default: bool | None = _Unset, - repr: bool = _Unset, - init: bool | None = _Unset, - init_var: bool | None = _Unset, - kw_only: bool | None = _Unset, - pattern: str | typing.Pattern[str] | None = _Unset, - strict: bool | None = _Unset, - coerce_numbers_to_str: bool | None = _Unset, - gt: annotated_types.SupportsGt | None = _Unset, - ge: annotated_types.SupportsGe | None = _Unset, - lt: annotated_types.SupportsLt | None = _Unset, - le: annotated_types.SupportsLe | None = _Unset, - multiple_of: float | None = _Unset, - allow_inf_nan: bool | None = _Unset, - max_digits: int | None = _Unset, - decimal_places: int | None = _Unset, - min_length: int | None = _Unset, - max_length: int | None = _Unset, - union_mode: Literal['smart', 'left_to_right'] = _Unset, - fail_fast: bool | None = _Unset, - **extra: Unpack[_EmptyKwargs], -) -> Any: ... 
-@overload # `default` argument set -def Field( - default: _T, - *, - alias: str | None = _Unset, - alias_priority: int | None = _Unset, - validation_alias: str | AliasPath | AliasChoices | None = _Unset, - serialization_alias: str | None = _Unset, - title: str | None = _Unset, - field_title_generator: Callable[[str, FieldInfo], str] | None = _Unset, - description: str | None = _Unset, - examples: list[Any] | None = _Unset, - exclude: bool | None = _Unset, - discriminator: str | types.Discriminator | None = _Unset, - deprecated: Deprecated | str | bool | None = _Unset, - json_schema_extra: JsonDict | Callable[[JsonDict], None] | None = _Unset, - frozen: bool | None = _Unset, - validate_default: bool | None = _Unset, - repr: bool = _Unset, - init: bool | None = _Unset, - init_var: bool | None = _Unset, - kw_only: bool | None = _Unset, - pattern: str | typing.Pattern[str] | None = _Unset, - strict: bool | None = _Unset, - coerce_numbers_to_str: bool | None = _Unset, - gt: annotated_types.SupportsGt | None = _Unset, - ge: annotated_types.SupportsGe | None = _Unset, - lt: annotated_types.SupportsLt | None = _Unset, - le: annotated_types.SupportsLe | None = _Unset, - multiple_of: float | None = _Unset, - allow_inf_nan: bool | None = _Unset, - max_digits: int | None = _Unset, - decimal_places: int | None = _Unset, - min_length: int | None = _Unset, - max_length: int | None = _Unset, - union_mode: Literal['smart', 'left_to_right'] = _Unset, - fail_fast: bool | None = _Unset, - **extra: Unpack[_EmptyKwargs], -) -> _T: ... -@overload # `default_factory` argument set -def Field( - *, - default_factory: Callable[[], _T] | Callable[[dict[str, Any]], _T], - alias: str | None = _Unset, - alias_priority: int | None = _Unset, - validation_alias: str | AliasPath | AliasChoices | None = _Unset, - serialization_alias: str | None = _Unset, - title: str | None = _Unset, - field_title_generator: Callable[[str, FieldInfo], str] | None = _Unset, - description: str | None = _Unset, - examples: list[Any] | None = _Unset, - exclude: bool | None = _Unset, - discriminator: str | types.Discriminator | None = _Unset, - deprecated: Deprecated | str | bool | None = _Unset, - json_schema_extra: JsonDict | Callable[[JsonDict], None] | None = _Unset, - frozen: bool | None = _Unset, - validate_default: bool | None = _Unset, - repr: bool = _Unset, - init: bool | None = _Unset, - init_var: bool | None = _Unset, - kw_only: bool | None = _Unset, - pattern: str | typing.Pattern[str] | None = _Unset, - strict: bool | None = _Unset, - coerce_numbers_to_str: bool | None = _Unset, - gt: annotated_types.SupportsGt | None = _Unset, - ge: annotated_types.SupportsGe | None = _Unset, - lt: annotated_types.SupportsLt | None = _Unset, - le: annotated_types.SupportsLe | None = _Unset, - multiple_of: float | None = _Unset, - allow_inf_nan: bool | None = _Unset, - max_digits: int | None = _Unset, - decimal_places: int | None = _Unset, - min_length: int | None = _Unset, - max_length: int | None = _Unset, - union_mode: Literal['smart', 'left_to_right'] = _Unset, - fail_fast: bool | None = _Unset, - **extra: Unpack[_EmptyKwargs], -) -> _T: ... 
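The `Field()` overloads deleted in this hunk are typing-only; at runtime every keyword is forwarded to `FieldInfo`. A short, hypothetical usage example restricted to parameters accepted by both the old and the new signature (model and field names are illustrative only):

```python
from pydantic import BaseModel, Field


class FoodItemIn(BaseModel):
    name: str = Field(min_length=1, max_length=255)
    calories: float = Field(default=0.0, ge=0)
    brand: str | None = None
    tags: list[str] = Field(default_factory=list)
```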
-@overload -def Field( # No default set - *, - alias: str | None = _Unset, - alias_priority: int | None = _Unset, - validation_alias: str | AliasPath | AliasChoices | None = _Unset, - serialization_alias: str | None = _Unset, - title: str | None = _Unset, - field_title_generator: Callable[[str, FieldInfo], str] | None = _Unset, - description: str | None = _Unset, - examples: list[Any] | None = _Unset, - exclude: bool | None = _Unset, - discriminator: str | types.Discriminator | None = _Unset, - deprecated: Deprecated | str | bool | None = _Unset, - json_schema_extra: JsonDict | Callable[[JsonDict], None] | None = _Unset, - frozen: bool | None = _Unset, - validate_default: bool | None = _Unset, - repr: bool = _Unset, - init: bool | None = _Unset, - init_var: bool | None = _Unset, - kw_only: bool | None = _Unset, - pattern: str | typing.Pattern[str] | None = _Unset, - strict: bool | None = _Unset, - coerce_numbers_to_str: bool | None = _Unset, - gt: annotated_types.SupportsGt | None = _Unset, - ge: annotated_types.SupportsGe | None = _Unset, - lt: annotated_types.SupportsLt | None = _Unset, - le: annotated_types.SupportsLe | None = _Unset, - multiple_of: float | None = _Unset, - allow_inf_nan: bool | None = _Unset, - max_digits: int | None = _Unset, - decimal_places: int | None = _Unset, - min_length: int | None = _Unset, - max_length: int | None = _Unset, - union_mode: Literal['smart', 'left_to_right'] = _Unset, - fail_fast: bool | None = _Unset, - **extra: Unpack[_EmptyKwargs], -) -> Any: ... def Field( # noqa: C901 default: Any = PydanticUndefined, *, - default_factory: Callable[[], Any] | Callable[[dict[str, Any]], Any] | None = _Unset, + default_factory: typing.Callable[[], Any] | None = _Unset, alias: str | None = _Unset, alias_priority: int | None = _Unset, validation_alias: str | AliasPath | AliasChoices | None = _Unset, serialization_alias: str | None = _Unset, title: str | None = _Unset, - field_title_generator: Callable[[str, FieldInfo], str] | None = _Unset, description: str | None = _Unset, examples: list[Any] | None = _Unset, exclude: bool | None = _Unset, - discriminator: str | types.Discriminator | None = _Unset, - deprecated: Deprecated | str | bool | None = _Unset, - json_schema_extra: JsonDict | Callable[[JsonDict], None] | None = _Unset, + discriminator: str | None = _Unset, + json_schema_extra: dict[str, Any] | typing.Callable[[dict[str, Any]], None] | None = _Unset, frozen: bool | None = _Unset, validate_default: bool | None = _Unset, repr: bool = _Unset, - init: bool | None = _Unset, init_var: bool | None = _Unset, kw_only: bool | None = _Unset, - pattern: str | typing.Pattern[str] | None = _Unset, + pattern: str | None = _Unset, strict: bool | None = _Unset, - coerce_numbers_to_str: bool | None = _Unset, - gt: annotated_types.SupportsGt | None = _Unset, - ge: annotated_types.SupportsGe | None = _Unset, - lt: annotated_types.SupportsLt | None = _Unset, - le: annotated_types.SupportsLe | None = _Unset, + gt: float | None = _Unset, + ge: float | None = _Unset, + lt: float | None = _Unset, + le: float | None = _Unset, multiple_of: float | None = _Unset, allow_inf_nan: bool | None = _Unset, max_digits: int | None = _Unset, @@ -983,11 +701,9 @@ def Field( # noqa: C901 min_length: int | None = _Unset, max_length: int | None = _Unset, union_mode: Literal['smart', 'left_to_right'] = _Unset, - fail_fast: bool | None = _Unset, **extra: Unpack[_EmptyKwargs], ) -> Any: - """!!! 
abstract "Usage Documentation" - [Fields](../concepts/fields.md) + """Usage docs: https://docs.pydantic.dev/2.4/concepts/fields Create a field for objects that can be configured. @@ -999,33 +715,24 @@ def Field( # noqa: C901 Args: default: Default value if the field is not set. - default_factory: A callable to generate the default value. The callable can either take 0 arguments - (in which case it is called as is) or a single argument containing the already validated data. - alias: The name to use for the attribute when validating or serializing by alias. - This is often used for things like converting between snake and camel case. + default_factory: A callable to generate the default value, such as :func:`~datetime.utcnow`. + alias: An alternative name for the attribute. alias_priority: Priority of the alias. This affects whether an alias generator is used. - validation_alias: Like `alias`, but only affects validation, not serialization. - serialization_alias: Like `alias`, but only affects serialization, not validation. + validation_alias: 'Whitelist' validation step. The field will be the single one allowed by the alias or set of + aliases defined. + serialization_alias: 'Blacklist' validation step. The vanilla field will be the single one of the alias' or set + of aliases' fields and all the other fields will be ignored at serialization time. title: Human-readable title. - field_title_generator: A callable that takes a field name and returns title for it. description: Human-readable description. examples: Example values for this field. exclude: Whether to exclude the field from the model serialization. - discriminator: Field name or Discriminator for discriminating the type in a tagged union. - deprecated: A deprecation message, an instance of `warnings.deprecated` or the `typing_extensions.deprecated` backport, - or a boolean. If `True`, a default deprecation message will be emitted when accessing the field. - json_schema_extra: A dict or callable to provide extra JSON schema properties. - frozen: Whether the field is frozen. If true, attempts to change the value on an instance will raise an error. - validate_default: If `True`, apply validation to the default value every time you create an instance. - Otherwise, for performance reasons, the default value of the field is trusted and not validated. + discriminator: Field name for discriminating the type in a tagged union. + json_schema_extra: Any additional JSON schema data for the schema property. + frozen: Whether the field is frozen. + validate_default: Run validation that isn't only checking existence of defaults. This can be set to `True` or `False`. If not set, it defaults to `None`. repr: A boolean indicating whether to include the field in the `__repr__` output. - init: Whether the field should be included in the constructor of the dataclass. - (Only applies to dataclasses.) - init_var: Whether the field should _only_ be included in the constructor of the dataclass. - (Only applies to dataclasses.) + init_var: Whether the field should be included in the constructor of the dataclass. kw_only: Whether the field should be a keyword-only argument in the constructor of the dataclass. - (Only applies to dataclasses.) - coerce_numbers_to_str: Whether to enable coercion of any `Number` type to `str` (not applicable in `strict` mode). strict: If `True`, strict validation is applied to the field. See [Strict Mode](../concepts/strict_mode.md) for details. gt: Greater than. If set, value must be greater than this. 
Only applicable to numbers. @@ -1033,24 +740,22 @@ def Field( # noqa: C901 lt: Less than. If set, value must be less than this. Only applicable to numbers. le: Less than or equal. If set, value must be less than or equal to this. Only applicable to numbers. multiple_of: Value must be a multiple of this. Only applicable to numbers. - min_length: Minimum length for iterables. - max_length: Maximum length for iterables. - pattern: Pattern for strings (a regular expression). - allow_inf_nan: Allow `inf`, `-inf`, `nan`. Only applicable to float and [`Decimal`][decimal.Decimal] numbers. + min_length: Minimum length for strings. + max_length: Maximum length for strings. + pattern: Pattern for strings. + allow_inf_nan: Allow `inf`, `-inf`, `nan`. Only applicable to numbers. max_digits: Maximum number of allow digits for strings. decimal_places: Maximum number of decimal places allowed for numbers. union_mode: The strategy to apply when validating a union. Can be `smart` (the default), or `left_to_right`. - See [Union Mode](../concepts/unions.md#union-modes) for details. - fail_fast: If `True`, validation will stop on the first error. If `False`, all validation errors will be collected. - This option can be applied only to iterable types (list, tuple, set, and frozenset). - extra: (Deprecated) Extra fields that will be included in the JSON schema. + See [Union Mode](standard_library_types.md#union-mode) for details. + extra: Include extra fields used by the JSON schema. !!! warning Deprecated The `extra` kwargs is deprecated. Use `json_schema_extra` instead. Returns: - A new [`FieldInfo`][pydantic.fields.FieldInfo]. The return annotation is `Any` so `Field` can be used on - type-annotated fields without causing a type error. + A new [`FieldInfo`][pydantic.fields.FieldInfo], the return annotation is `Any` so `Field` can be used on + type annotated fields without causing a typing error. """ # Check deprecated and removed params from V1. This logic should eventually be removed. const = extra.pop('const', None) # type: ignore @@ -1124,21 +829,17 @@ def Field( # noqa: C901 validation_alias=validation_alias, serialization_alias=serialization_alias, title=title, - field_title_generator=field_title_generator, description=description, examples=examples, exclude=exclude, discriminator=discriminator, - deprecated=deprecated, json_schema_extra=json_schema_extra, frozen=frozen, pattern=pattern, validate_default=validate_default, repr=repr, - init=init, init_var=init_var, kw_only=kw_only, - coerce_numbers_to_str=coerce_numbers_to_str, strict=strict, gt=gt, ge=ge, @@ -1151,7 +852,6 @@ def Field( # noqa: C901 max_digits=max_digits, decimal_places=decimal_places, union_mode=union_mode, - fail_fast=fail_fast, ) @@ -1162,25 +862,18 @@ _FIELD_ARG_NAMES.remove('extra') # do not include the varkwargs parameter class ModelPrivateAttr(_repr.Representation): """A descriptor for private attributes in class models. - !!! warning - You generally shouldn't be creating `ModelPrivateAttr` instances directly, instead use - `pydantic.fields.PrivateAttr`. (This is similar to `FieldInfo` vs. `Field`.) - Attributes: default: The default value of the attribute if not provided. default_factory: A callable function that generates the default value of the attribute if not provided. 
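A short sketch of the constraint keywords documented in the `Field` docstring above (`min_length`, `max_length`, `gt`, `pattern`); the model is invented purely for illustration:

```python
from pydantic import BaseModel, Field, ValidationError


class Product(BaseModel):
    name: str = Field(min_length=1, max_length=255)
    price: float = Field(gt=0)
    sku: str = Field(default='ABC-0000', pattern=r'^[A-Z]{3}-\d{4}$')


print(Product(name='Apple', price=1.5, sku='FRU-0001'))
#> name='Apple' price=1.5 sku='FRU-0001'

try:
    Product(name='', price=-1)
except ValidationError as e:
    # one error for the empty name, one for the non-positive price
    print(e.error_count())
    #> 2
```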
""" - __slots__ = ('default', 'default_factory') + __slots__ = 'default', 'default_factory' def __init__( self, default: Any = PydanticUndefined, *, default_factory: typing.Callable[[], Any] | None = None ) -> None: - if default is Ellipsis: - self.default = PydanticUndefined - else: - self.default = default + self.default = default self.default_factory = default_factory if not typing.TYPE_CHECKING: @@ -1197,10 +890,11 @@ class ModelPrivateAttr(_repr.Representation): def __set_name__(self, cls: type[Any], name: str) -> None: """Preserve `__set_name__` protocol defined in https://peps.python.org/pep-0487.""" - default = self.default - if default is PydanticUndefined: + if self.default is PydanticUndefined: return - set_name = getattr(default, '__set_name__', None) + if not hasattr(self.default, '__set_name__'): + return + set_name = self.default.__set_name__ if callable(set_name): set_name(cls, name) @@ -1223,37 +917,14 @@ class ModelPrivateAttr(_repr.Representation): ) -# NOTE: Actual return type is 'ModelPrivateAttr', but we want to help type checkers -# to understand the magic that happens at runtime. -@overload # `default` argument set -def PrivateAttr( - default: _T, - *, - init: Literal[False] = False, -) -> _T: ... -@overload # `default_factory` argument set -def PrivateAttr( - *, - default_factory: Callable[[], _T], - init: Literal[False] = False, -) -> _T: ... -@overload # No default set -def PrivateAttr( - *, - init: Literal[False] = False, -) -> Any: ... def PrivateAttr( default: Any = PydanticUndefined, *, - default_factory: Callable[[], Any] | None = None, - init: Literal[False] = False, + default_factory: typing.Callable[[], Any] | None = None, ) -> Any: - """!!! abstract "Usage Documentation" - [Private Model Attributes](../concepts/models.md#private-model-attributes) + """Indicates that attribute is only used internally and never mixed with regular fields. - Indicates that an attribute is intended for private use and not handled during normal validation/serialization. - - Private attributes are not validated by Pydantic, so it's up to you to ensure they are used in a type-safe manner. + Private attributes are not checked by Pydantic, so it's up to you to maintain their accuracy. Private attributes are stored in `__private_attributes__` on the model. @@ -1262,7 +933,6 @@ def PrivateAttr( default_factory: Callable that will be called when a default value is needed for this attribute. If both `default` and `default_factory` are set, an error will be raised. - init: Whether the attribute should be included in the constructor of the dataclass. Always `False`. Returns: An instance of [`ModelPrivateAttr`][pydantic.fields.ModelPrivateAttr] class. @@ -1287,16 +957,11 @@ class ComputedFieldInfo: decorator_repr: A class variable representing the decorator string, '@computed_field'. wrapped_property: The wrapped computed field property. return_type: The type of the computed field property's return value. - alias: The alias of the property to be used during serialization. - alias_priority: The priority of the alias. This affects whether an alias generator is used. - title: Title of the computed field to include in the serialization JSON schema. - field_title_generator: A callable that takes a field name and returns title for it. - description: Description of the computed field to include in the serialization JSON schema. - deprecated: A deprecation message, an instance of `warnings.deprecated` or the `typing_extensions.deprecated` backport, - or a boolean. 
If `True`, a default deprecation message will be emitted when accessing the field. - examples: Example values of the computed field to include in the serialization JSON schema. - json_schema_extra: A dict or callable to provide extra JSON schema properties. - repr: A boolean indicating whether to include the field in the __repr__ output. + alias: The alias of the property to be used during encoding and decoding. + alias_priority: priority of the alias. This affects whether an alias generator is used + title: Title of the computed field as in OpenAPI document, should be a short summary. + description: Description of the computed field as in OpenAPI document. + repr: A boolean indicating whether or not to include the field in the __repr__ output. """ decorator_repr: ClassVar[str] = '@computed_field' @@ -1305,21 +970,31 @@ class ComputedFieldInfo: alias: str | None alias_priority: int | None title: str | None - field_title_generator: typing.Callable[[str, ComputedFieldInfo], str] | None description: str | None - deprecated: Deprecated | str | bool | None - examples: list[Any] | None - json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None repr: bool - @property - def deprecation_message(self) -> str | None: - """The deprecation message to be emitted, or `None` if not set.""" - if self.deprecated is None: - return None - if isinstance(self.deprecated, bool): - return 'deprecated' if self.deprecated else None - return self.deprecated if isinstance(self.deprecated, str) else self.deprecated.message + +# this should really be `property[T], cached_proprety[T]` but property is not generic unlike cached_property +# See https://github.com/python/typing/issues/985 and linked issues +PropertyT = typing.TypeVar('PropertyT') + + +@typing.overload +def computed_field( + *, + return_type: Any = PydanticUndefined, + alias: str | None = None, + alias_priority: int | None = None, + title: str | None = None, + description: str | None = None, + repr: bool = True, +) -> typing.Callable[[PropertyT], PropertyT]: + ... + + +@typing.overload +def computed_field(__func: PropertyT) -> PropertyT: + ... def _wrapped_property_is_private(property_: cached_property | property) -> bool: # type: ignore @@ -1334,54 +1009,21 @@ def _wrapped_property_is_private(property_: cached_property | property) -> bool: return wrapped_name.startswith('_') and not wrapped_name.startswith('__') -# this should really be `property[T], cached_property[T]` but property is not generic unlike cached_property -# See https://github.com/python/typing/issues/985 and linked issues -PropertyT = typing.TypeVar('PropertyT') - - -@typing.overload -def computed_field(func: PropertyT, /) -> PropertyT: ... - - -@typing.overload def computed_field( + __f: PropertyT | None = None, *, alias: str | None = None, alias_priority: int | None = None, title: str | None = None, - field_title_generator: typing.Callable[[str, ComputedFieldInfo], str] | None = None, description: str | None = None, - deprecated: Deprecated | str | bool | None = None, - examples: list[Any] | None = None, - json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None = None, - repr: bool = True, - return_type: Any = PydanticUndefined, -) -> typing.Callable[[PropertyT], PropertyT]: ... 
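For the `PrivateAttr` helper documented earlier in this hunk, a minimal usage sketch (names invented): private attributes must start with an underscore and take no part in validation or serialization.

```python
from datetime import datetime, timezone

from pydantic import BaseModel, PrivateAttr


class Session(BaseModel):
    user_id: int
    # not validated, never serialized; populated from the factory on instantiation
    _started_at: datetime = PrivateAttr(default_factory=lambda: datetime.now(timezone.utc))


s = Session(user_id=1)
print(s.model_dump())
#> {'user_id': 1}
print(isinstance(s._started_at, datetime))
#> True
```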
- - -def computed_field( - func: PropertyT | None = None, - /, - *, - alias: str | None = None, - alias_priority: int | None = None, - title: str | None = None, - field_title_generator: typing.Callable[[str, ComputedFieldInfo], str] | None = None, - description: str | None = None, - deprecated: Deprecated | str | bool | None = None, - examples: list[Any] | None = None, - json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None = None, repr: bool | None = None, return_type: Any = PydanticUndefined, ) -> PropertyT | typing.Callable[[PropertyT], PropertyT]: - """!!! abstract "Usage Documentation" - [The `computed_field` decorator](../concepts/fields.md#the-computed_field-decorator) - - Decorator to include `property` and `cached_property` when serializing models or dataclasses. + """Decorator to include `property` and `cached_property` when serializing models or dataclasses. This is useful for fields that are computed from other fields, or for fields that are expensive to compute and should be cached. - ```python + ```py from pydantic import BaseModel, computed_field class Rectangle(BaseModel): @@ -1405,11 +1047,11 @@ def computed_field( Even with the `@property` or `@cached_property` applied to your function before `@computed_field`, mypy may throw a `Decorated property not supported` error. See [mypy issue #1362](https://github.com/python/mypy/issues/1362), for more information. - To avoid this error message, add `# type: ignore[prop-decorator]` to the `@computed_field` line. + To avoid this error message, add `# type: ignore[misc]` to the `@computed_field` line. [pyright](https://github.com/microsoft/pyright) supports `@computed_field` without error. - ```python + ```py import random from pydantic import BaseModel, computed_field @@ -1449,7 +1091,7 @@ def computed_field( `mypy` complains about this behavior if allowed, and `dataclasses` doesn't allow this pattern either. See the example below: - ```python + ```py from pydantic import BaseModel, computed_field class Parent(BaseModel): @@ -1463,16 +1105,14 @@ def computed_field( def a(self) -> str: return 'new a' - except TypeError as e: - print(e) - ''' - Field 'a' of class 'Child' overrides symbol of same name in a parent class. This override with a computed_field is incompatible. - ''' + except ValueError as e: + print(repr(e)) + #> ValueError("you can't override a field with a computed field") ``` Private properties decorated with `@computed_field` have `repr=False` by default. - ```python + ```py from functools import cached_property from pydantic import BaseModel, computed_field @@ -1492,22 +1132,16 @@ def computed_field( m = Model(foo=1) print(repr(m)) - #> Model(foo=1) + #> M(foo=1) ``` Args: - func: the function to wrap. + __f: the function to wrap. alias: alias to use when serializing this computed field, only used when `by_alias=True` alias_priority: priority of the alias. This affects whether an alias generator is used - title: Title to use when including this computed field in JSON Schema - field_title_generator: A callable that takes a field name and returns title for it. - description: Description to use when including this computed field in JSON Schema, defaults to the function's - docstring - deprecated: A deprecation message (or an instance of `warnings.deprecated` or the `typing_extensions.deprecated` backport). - to be emitted when accessing the field. Or a boolean. This will automatically be set if the property is decorated with the - `deprecated` decorator. 
- examples: Example values to use when including this computed field in JSON Schema - json_schema_extra: A dict or callable to provide extra JSON schema properties. + title: Title to used when including this computed field in JSON Schema, currently unused waiting for #4697 + description: Description to used when including this computed field in JSON Schema, defaults to the functions + docstring, currently unused waiting for #4697 repr: whether to include this computed field in model repr. Default is `False` for private properties and `True` for public properties. return_type: optional return for serialization logic to expect when serializing to JSON, if included @@ -1520,40 +1154,24 @@ def computed_field( """ def dec(f: Any) -> Any: - nonlocal description, deprecated, return_type, alias_priority + nonlocal description, return_type, alias_priority unwrapped = _decorators.unwrap_wrapped_function(f) - if description is None and unwrapped.__doc__: description = inspect.cleandoc(unwrapped.__doc__) - if deprecated is None and hasattr(unwrapped, '__deprecated__'): - deprecated = unwrapped.__deprecated__ - # if the function isn't already decorated with `@property` (or another descriptor), then we wrap it now f = _decorators.ensure_property(f) alias_priority = (alias_priority or 2) if alias is not None else None if repr is None: - repr_: bool = not _wrapped_property_is_private(property_=f) + repr_: bool = False if _wrapped_property_is_private(property_=f) else True else: repr_ = repr - dec_info = ComputedFieldInfo( - f, - return_type, - alias, - alias_priority, - title, - field_title_generator, - description, - deprecated, - examples, - json_schema_extra, - repr_, - ) + dec_info = ComputedFieldInfo(f, return_type, alias, alias_priority, title, description, repr_) return _decorators.PydanticDescriptorProxy(f, dec_info) - if func is None: + if __f is None: return dec else: - return dec(func) + return dec(__f) diff --git a/venv/lib/python3.12/site-packages/pydantic/functional_serializers.py b/venv/lib/python3.12/site-packages/pydantic/functional_serializers.py index 4b065e4..849dfe5 100644 --- a/venv/lib/python3.12/site-packages/pydantic/functional_serializers.py +++ b/venv/lib/python3.12/site-packages/pydantic/functional_serializers.py @@ -1,14 +1,13 @@ """This module contains related classes and functions for serialization.""" - from __future__ import annotations import dataclasses -from functools import partial, partialmethod -from typing import TYPE_CHECKING, Annotated, Any, Callable, Literal, TypeVar, overload +from functools import partialmethod +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, overload from pydantic_core import PydanticUndefined, core_schema -from pydantic_core.core_schema import SerializationInfo, SerializerFunctionWrapHandler, WhenUsed -from typing_extensions import TypeAlias +from pydantic_core import core_schema as _core_schema +from typing_extensions import Annotated, Literal, TypeAlias from . import PydanticUndefinedAnnotation from ._internal import _decorators, _internal_dataclass @@ -19,26 +18,6 @@ from .annotated_handlers import GetCoreSchemaHandler class PlainSerializer: """Plain serializers use a function to modify the output of serialization. - This is particularly helpful when you want to customize the serialization for annotated types. - Consider an input of `list`, which will be serialized into a space-delimited string. 
- - ```python - from typing import Annotated - - from pydantic import BaseModel, PlainSerializer - - CustomStr = Annotated[ - list, PlainSerializer(lambda x: ' '.join(x), return_type=str) - ] - - class StudentModel(BaseModel): - courses: CustomStr - - student = StudentModel(courses=['Math', 'Chemistry', 'English']) - print(student.model_dump()) - #> {'courses': 'Math Chemistry English'} - ``` - Attributes: func: The serializer function. return_type: The return type for the function. If omitted it will be inferred from the type annotation. @@ -48,7 +27,7 @@ class PlainSerializer: func: core_schema.SerializerFunction return_type: Any = PydanticUndefined - when_used: WhenUsed = 'always' + when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always' def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: """Gets the Pydantic core schema. @@ -61,20 +40,12 @@ class PlainSerializer: The Pydantic core schema. """ schema = handler(source_type) - if self.return_type is not PydanticUndefined: - return_type = self.return_type - else: - try: - # Do not pass in globals as the function could be defined in a different module. - # Instead, let `get_callable_return_type` infer the globals to use, but still pass - # in locals that may contain a parent/rebuild namespace: - return_type = _decorators.get_callable_return_type( - self.func, - localns=handler._get_types_namespace().locals, - ) - except NameError as e: - raise PydanticUndefinedAnnotation.from_name_error(e) from e - + try: + return_type = _decorators.get_function_return_type( + self.func, self.return_type, handler._get_types_namespace() + ) + except NameError as e: + raise PydanticUndefinedAnnotation.from_name_error(e) from e return_schema = None if return_type is PydanticUndefined else handler.generate_schema(return_type) schema['serialization'] = core_schema.plain_serializer_function_ser_schema( function=self.func, @@ -90,58 +61,6 @@ class WrapSerializer: """Wrap serializers receive the raw inputs along with a handler function that applies the standard serialization logic, and can modify the resulting value before returning it as the final output of serialization. - For example, here's a scenario in which a wrap serializer transforms timezones to UTC **and** utilizes the existing `datetime` serialization logic. - - ```python - from datetime import datetime, timezone - from typing import Annotated, Any - - from pydantic import BaseModel, WrapSerializer - - class EventDatetime(BaseModel): - start: datetime - end: datetime - - def convert_to_utc(value: Any, handler, info) -> dict[str, datetime]: - # Note that `handler` can actually help serialize the `value` for - # further custom serialization in case it's a subclass. 
- partial_result = handler(value, info) - if info.mode == 'json': - return { - k: datetime.fromisoformat(v).astimezone(timezone.utc) - for k, v in partial_result.items() - } - return {k: v.astimezone(timezone.utc) for k, v in partial_result.items()} - - UTCEventDatetime = Annotated[EventDatetime, WrapSerializer(convert_to_utc)] - - class EventModel(BaseModel): - event_datetime: UTCEventDatetime - - dt = EventDatetime( - start='2024-01-01T07:00:00-08:00', end='2024-01-03T20:00:00+06:00' - ) - event = EventModel(event_datetime=dt) - print(event.model_dump()) - ''' - { - 'event_datetime': { - 'start': datetime.datetime( - 2024, 1, 1, 15, 0, tzinfo=datetime.timezone.utc - ), - 'end': datetime.datetime( - 2024, 1, 3, 14, 0, tzinfo=datetime.timezone.utc - ), - } - } - ''' - - print(event.model_dump_json()) - ''' - {"event_datetime":{"start":"2024-01-01T15:00:00Z","end":"2024-01-03T14:00:00Z"}} - ''' - ``` - Attributes: func: The serializer function to be wrapped. return_type: The return type for the function. If omitted it will be inferred from the type annotation. @@ -151,7 +70,7 @@ class WrapSerializer: func: core_schema.WrapSerializerFunction return_type: Any = PydanticUndefined - when_used: WhenUsed = 'always' + when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always' def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: """This method is used to get the Pydantic core schema of the class. @@ -164,20 +83,12 @@ class WrapSerializer: The generated core schema of the class. """ schema = handler(source_type) - if self.return_type is not PydanticUndefined: - return_type = self.return_type - else: - try: - # Do not pass in globals as the function could be defined in a different module. 
- # Instead, let `get_callable_return_type` infer the globals to use, but still pass - # in locals that may contain a parent/rebuild namespace: - return_type = _decorators.get_callable_return_type( - self.func, - localns=handler._get_types_namespace().locals, - ) - except NameError as e: - raise PydanticUndefinedAnnotation.from_name_error(e) from e - + try: + return_type = _decorators.get_function_return_type( + self.func, self.return_type, handler._get_types_namespace() + ) + except NameError as e: + raise PydanticUndefinedAnnotation.from_name_error(e) from e return_schema = None if return_type is PydanticUndefined else handler.generate_schema(return_type) schema['serialization'] = core_schema.wrap_serializer_function_ser_schema( function=self.func, @@ -189,77 +100,57 @@ class WrapSerializer: if TYPE_CHECKING: - _Partial: TypeAlias = 'partial[Any] | partialmethod[Any]' - - FieldPlainSerializer: TypeAlias = 'core_schema.SerializerFunction | _Partial' - """A field serializer method or function in `plain` mode.""" - - FieldWrapSerializer: TypeAlias = 'core_schema.WrapSerializerFunction | _Partial' - """A field serializer method or function in `wrap` mode.""" - - FieldSerializer: TypeAlias = 'FieldPlainSerializer | FieldWrapSerializer' - """A field serializer method or function.""" - - _FieldPlainSerializerT = TypeVar('_FieldPlainSerializerT', bound=FieldPlainSerializer) - _FieldWrapSerializerT = TypeVar('_FieldWrapSerializerT', bound=FieldWrapSerializer) + _PartialClsOrStaticMethod: TypeAlias = Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any]] + _PlainSerializationFunction = Union[_core_schema.SerializerFunction, _PartialClsOrStaticMethod] + _WrapSerializationFunction = Union[_core_schema.WrapSerializerFunction, _PartialClsOrStaticMethod] + _PlainSerializeMethodType = TypeVar('_PlainSerializeMethodType', bound=_PlainSerializationFunction) + _WrapSerializeMethodType = TypeVar('_WrapSerializeMethodType', bound=_WrapSerializationFunction) @overload def field_serializer( - field: str, - /, + __field: str, + *fields: str, + return_type: Any = ..., + when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = ..., + check_fields: bool | None = ..., +) -> Callable[[_PlainSerializeMethodType], _PlainSerializeMethodType]: + ... + + +@overload +def field_serializer( + __field: str, + *fields: str, + mode: Literal['plain'], + return_type: Any = ..., + when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = ..., + check_fields: bool | None = ..., +) -> Callable[[_PlainSerializeMethodType], _PlainSerializeMethodType]: + ... + + +@overload +def field_serializer( + __field: str, *fields: str, mode: Literal['wrap'], return_type: Any = ..., - when_used: WhenUsed = ..., + when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = ..., check_fields: bool | None = ..., -) -> Callable[[_FieldWrapSerializerT], _FieldWrapSerializerT]: ... - - -@overload -def field_serializer( - field: str, - /, - *fields: str, - mode: Literal['plain'] = ..., - return_type: Any = ..., - when_used: WhenUsed = ..., - check_fields: bool | None = ..., -) -> Callable[[_FieldPlainSerializerT], _FieldPlainSerializerT]: ... +) -> Callable[[_WrapSerializeMethodType], _WrapSerializeMethodType]: + ... 
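The overloads above distinguish plain- and wrap-mode field serializers, and the docstring that follows only shows a plain-mode example. A hedged sketch of the wrap form (invented model), where `handler` runs the default serialization before the custom post-processing:

```python
from pydantic import BaseModel, field_serializer


class Account(BaseModel):
    balance: float

    @field_serializer('balance', mode='wrap')
    def round_balance(self, value, handler):
        # let pydantic serialize the float normally, then round the result
        return round(handler(value), 2)


print(Account(balance=10.126).model_dump())
#> {'balance': 10.13}
```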
def field_serializer( *fields: str, mode: Literal['plain', 'wrap'] = 'plain', return_type: Any = PydanticUndefined, - when_used: WhenUsed = 'always', + when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always', check_fields: bool | None = None, -) -> ( - Callable[[_FieldWrapSerializerT], _FieldWrapSerializerT] - | Callable[[_FieldPlainSerializerT], _FieldPlainSerializerT] -): +) -> Callable[[Any], Any]: """Decorator that enables custom field serialization. - In the below example, a field of type `set` is used to mitigate duplication. A `field_serializer` is used to serialize the data as a sorted list. - - ```python - from typing import Set - - from pydantic import BaseModel, field_serializer - - class StudentModel(BaseModel): - name: str = 'Jane' - courses: Set[str] - - @field_serializer('courses', when_used='json') - def serialize_courses_in_order(self, courses: Set[str]): - return sorted(courses) - - student = StudentModel(courses={'Math', 'Chemistry', 'English'}) - print(student.model_dump_json()) - #> {"name":"Jane","courses":["Chemistry","English","Math"]} - ``` - See [Custom serializers](../concepts/serialization.md#custom-serializers) for more information. Four signatures are supported: @@ -284,7 +175,9 @@ def field_serializer( The decorator function. """ - def dec(f: FieldSerializer) -> _decorators.PydanticDescriptorProxy[Any]: + def dec( + f: Callable[..., Any] | staticmethod[Any, Any] | classmethod[Any, Any, Any] + ) -> _decorators.PydanticDescriptorProxy[Any]: dec_info = _decorators.FieldSerializerDecoratorInfo( fields=fields, mode=mode, @@ -292,109 +185,42 @@ def field_serializer( when_used=when_used, check_fields=check_fields, ) - return _decorators.PydanticDescriptorProxy(f, dec_info) # pyright: ignore[reportArgumentType] + return _decorators.PydanticDescriptorProxy(f, dec_info) - return dec # pyright: ignore[reportReturnType] + return dec -if TYPE_CHECKING: - # The first argument in the following callables represent the `self` type: - - ModelPlainSerializerWithInfo: TypeAlias = Callable[[Any, SerializationInfo], Any] - """A model serializer method with the `info` argument, in `plain` mode.""" - - ModelPlainSerializerWithoutInfo: TypeAlias = Callable[[Any], Any] - """A model serializer method without the `info` argument, in `plain` mode.""" - - ModelPlainSerializer: TypeAlias = 'ModelPlainSerializerWithInfo | ModelPlainSerializerWithoutInfo' - """A model serializer method in `plain` mode.""" - - ModelWrapSerializerWithInfo: TypeAlias = Callable[[Any, SerializerFunctionWrapHandler, SerializationInfo], Any] - """A model serializer method with the `info` argument, in `wrap` mode.""" - - ModelWrapSerializerWithoutInfo: TypeAlias = Callable[[Any, SerializerFunctionWrapHandler], Any] - """A model serializer method without the `info` argument, in `wrap` mode.""" - - ModelWrapSerializer: TypeAlias = 'ModelWrapSerializerWithInfo | ModelWrapSerializerWithoutInfo' - """A model serializer method in `wrap` mode.""" - - ModelSerializer: TypeAlias = 'ModelPlainSerializer | ModelWrapSerializer' - - _ModelPlainSerializerT = TypeVar('_ModelPlainSerializerT', bound=ModelPlainSerializer) - _ModelWrapSerializerT = TypeVar('_ModelWrapSerializerT', bound=ModelWrapSerializer) +FuncType = TypeVar('FuncType', bound=Callable[..., Any]) @overload -def model_serializer(f: _ModelPlainSerializerT, /) -> _ModelPlainSerializerT: ... - - -@overload -def model_serializer( - *, mode: Literal['wrap'], when_used: WhenUsed = 'always', return_type: Any = ... 
-) -> Callable[[_ModelWrapSerializerT], _ModelWrapSerializerT]: ... +def model_serializer(__f: FuncType) -> FuncType: + ... @overload def model_serializer( *, - mode: Literal['plain'] = ..., - when_used: WhenUsed = 'always', + mode: Literal['plain', 'wrap'] = ..., + when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always', return_type: Any = ..., -) -> Callable[[_ModelPlainSerializerT], _ModelPlainSerializerT]: ... +) -> Callable[[FuncType], FuncType]: + ... def model_serializer( - f: _ModelPlainSerializerT | _ModelWrapSerializerT | None = None, - /, + __f: Callable[..., Any] | None = None, *, mode: Literal['plain', 'wrap'] = 'plain', - when_used: WhenUsed = 'always', + when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always', return_type: Any = PydanticUndefined, -) -> ( - _ModelPlainSerializerT - | Callable[[_ModelWrapSerializerT], _ModelWrapSerializerT] - | Callable[[_ModelPlainSerializerT], _ModelPlainSerializerT] -): +) -> Callable[[Any], Any]: """Decorator that enables custom model serialization. - This is useful when a model need to be serialized in a customized manner, allowing for flexibility beyond just specific fields. - - An example would be to serialize temperature to the same temperature scale, such as degrees Celsius. - - ```python - from typing import Literal - - from pydantic import BaseModel, model_serializer - - class TemperatureModel(BaseModel): - unit: Literal['C', 'F'] - value: int - - @model_serializer() - def serialize_model(self): - if self.unit == 'F': - return {'unit': 'C', 'value': int((self.value - 32) / 1.8)} - return {'unit': self.unit, 'value': self.value} - - temperature = TemperatureModel(unit='F', value=212) - print(temperature.model_dump()) - #> {'unit': 'C', 'value': 100} - ``` - - Two signatures are supported for `mode='plain'`, which is the default: - - - `(self)` - - `(self, info: SerializationInfo)` - - And two other signatures for `mode='wrap'`: - - - `(self, nxt: SerializerFunctionWrapHandler)` - - `(self, nxt: SerializerFunctionWrapHandler, info: SerializationInfo)` - - See [Custom serializers](../concepts/serialization.md#custom-serializers) for more information. + See [Custom serializers](../concepts/serialization.md#custom-serializers) for more information. Args: - f: The function to be decorated. + __f: The function to be decorated. mode: The serialization mode. - `'plain'` means the function will be called instead of the default serialization logic @@ -407,14 +233,14 @@ def model_serializer( The decorator function. 
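The `model_serializer` docstring above illustrates `mode='plain'`; a sketch of `mode='wrap'`, where the handler produces the standard serialization that the method can then post-process (model and field names invented):

```python
from pydantic import BaseModel, model_serializer


class User(BaseModel):
    name: str
    age: int

    @model_serializer(mode='wrap')
    def wrap_in_envelope(self, handler):
        # run the standard serialization, then nest it under an envelope
        return {'type': 'user', 'data': handler(self)}


print(User(name='Jane', age=30).model_dump())
#> {'type': 'user', 'data': {'name': 'Jane', 'age': 30}}
```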
""" - def dec(f: ModelSerializer) -> _decorators.PydanticDescriptorProxy[Any]: + def dec(f: Callable[..., Any]) -> _decorators.PydanticDescriptorProxy[Any]: dec_info = _decorators.ModelSerializerDecoratorInfo(mode=mode, return_type=return_type, when_used=when_used) return _decorators.PydanticDescriptorProxy(f, dec_info) - if f is None: - return dec # pyright: ignore[reportReturnType] + if __f is None: + return dec else: - return dec(f) # pyright: ignore[reportReturnType] + return dec(__f) # type: ignore AnyType = TypeVar('AnyType') diff --git a/venv/lib/python3.12/site-packages/pydantic/functional_validators.py b/venv/lib/python3.12/site-packages/pydantic/functional_validators.py index 2eed4ef..5808cc5 100644 --- a/venv/lib/python3.12/site-packages/pydantic/functional_validators.py +++ b/venv/lib/python3.12/site-packages/pydantic/functional_validators.py @@ -6,13 +6,14 @@ import dataclasses import sys from functools import partialmethod from types import FunctionType -from typing import TYPE_CHECKING, Annotated, Any, Callable, Literal, TypeVar, Union, cast, overload +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, cast, overload -from pydantic_core import PydanticUndefined, core_schema +from pydantic_core import core_schema from pydantic_core import core_schema as _core_schema -from typing_extensions import Self, TypeAlias +from typing_extensions import Annotated, Literal, TypeAlias -from ._internal import _decorators, _generics, _internal_dataclass +from . import GetCoreSchemaHandler as _GetCoreSchemaHandler +from ._internal import _core_metadata, _decorators, _generics, _internal_dataclass from .annotated_handlers import GetCoreSchemaHandler from .errors import PydanticUserError @@ -26,8 +27,7 @@ _inspect_validator = _decorators.inspect_validator @dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true) class AfterValidator: - """!!! abstract "Usage Documentation" - [field *after* validators](../concepts/validators.md#field-after-validator) + '''Usage docs: https://docs.pydantic.dev/2.2/concepts/validators/#annotated-validators A metadata class that indicates that a validation should be applied **after** the inner validation logic. @@ -35,10 +35,11 @@ class AfterValidator: func: The validator function. 
Example: - ```python + ```py from typing import Annotated - from pydantic import AfterValidator, BaseModel, ValidationError + from pydantic import BaseModel, AfterValidator, ValidationError + MyInt = Annotated[int, AfterValidator(lambda v: v + 1)] @@ -46,31 +47,31 @@ class AfterValidator: a: MyInt print(Model(a=1).a) - #> 2 + # > 2 try: Model(a='a') except ValidationError as e: print(e.json(indent=2)) - ''' - [ - { + """ + [ + { "type": "int_parsing", "loc": [ - "a" + "a" ], "msg": "Input should be a valid integer, unable to parse string as an integer", "input": "a", - "url": "https://errors.pydantic.dev/2/v/int_parsing" - } - ] - ''' + "url": "https://errors.pydantic.dev/0.38.0/v/int_parsing" + } + ] + """ ``` - """ + ''' func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction - def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema: schema = handler(source_type) info_arg = _inspect_validator(self.func, 'after') if info_arg: @@ -80,26 +81,19 @@ class AfterValidator: func = cast(core_schema.NoInfoValidatorFunction, self.func) return core_schema.no_info_after_validator_function(func, schema=schema) - @classmethod - def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self: - return cls(func=decorator.func) - @dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true) class BeforeValidator: - """!!! abstract "Usage Documentation" - [field *before* validators](../concepts/validators.md#field-before-validator) + """Usage docs: https://docs.pydantic.dev/2.4/concepts/validators/#annotated-validators A metadata class that indicates that a validation should be applied **before** the inner validation logic. Attributes: func: The validator function. - json_schema_input_type: The input type of the function. This is only used to generate the appropriate - JSON Schema (in validation mode). 
Example: - ```python - from typing import Annotated + ```py + from typing_extensions import Annotated from pydantic import BaseModel, BeforeValidator @@ -120,151 +114,68 @@ class BeforeValidator: """ func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction - json_schema_input_type: Any = PydanticUndefined - def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema: schema = handler(source_type) - input_schema = ( - None - if self.json_schema_input_type is PydanticUndefined - else handler.generate_schema(self.json_schema_input_type) - ) - info_arg = _inspect_validator(self.func, 'before') if info_arg: func = cast(core_schema.WithInfoValidatorFunction, self.func) - return core_schema.with_info_before_validator_function( - func, - schema=schema, - field_name=handler.field_name, - json_schema_input_schema=input_schema, - ) + return core_schema.with_info_before_validator_function(func, schema=schema, field_name=handler.field_name) else: func = cast(core_schema.NoInfoValidatorFunction, self.func) - return core_schema.no_info_before_validator_function( - func, schema=schema, json_schema_input_schema=input_schema - ) - - @classmethod - def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self: - return cls( - func=decorator.func, - json_schema_input_type=decorator.info.json_schema_input_type, - ) + return core_schema.no_info_before_validator_function(func, schema=schema) @dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true) class PlainValidator: - """!!! abstract "Usage Documentation" - [field *plain* validators](../concepts/validators.md#field-plain-validator) + """Usage docs: https://docs.pydantic.dev/2.4/concepts/validators/#annotated-validators A metadata class that indicates that a validation should be applied **instead** of the inner validation logic. - !!! note - Before v2.9, `PlainValidator` wasn't always compatible with JSON Schema generation for `mode='validation'`. - You can now use the `json_schema_input_type` argument to specify the input type of the function - to be used in the JSON schema when `mode='validation'` (the default). See the example below for more details. - Attributes: func: The validator function. - json_schema_input_type: The input type of the function. This is only used to generate the appropriate - JSON Schema (in validation mode). If not provided, will default to `Any`. Example: - ```python - from typing import Annotated, Union + ```py + from typing_extensions import Annotated from pydantic import BaseModel, PlainValidator - MyInt = Annotated[ - int, - PlainValidator( - lambda v: int(v) + 1, json_schema_input_type=Union[str, int] # (1)! - ), - ] + MyInt = Annotated[int, PlainValidator(lambda v: int(v) + 1)] class Model(BaseModel): a: MyInt print(Model(a='1').a) #> 2 - - print(Model(a=1).a) - #> 2 ``` - - 1. In this example, we've specified the `json_schema_input_type` as `Union[str, int]` which indicates to the JSON schema - generator that in validation mode, the input type for the `a` field can be either a `str` or an `int`. 
""" func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction - json_schema_input_type: Any = Any - - def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: - # Note that for some valid uses of PlainValidator, it is not possible to generate a core schema for the - # source_type, so calling `handler(source_type)` will error, which prevents us from generating a proper - # serialization schema. To work around this for use cases that will not involve serialization, we simply - # catch any PydanticSchemaGenerationError that may be raised while attempting to build the serialization schema - # and abort any attempts to handle special serialization. - from pydantic import PydanticSchemaGenerationError - - try: - schema = handler(source_type) - # TODO if `schema['serialization']` is one of `'include-exclude-dict/sequence', - # schema validation will fail. That's why we use 'type ignore' comments below. - serialization = schema.get( - 'serialization', - core_schema.wrap_serializer_function_ser_schema( - function=lambda v, h: h(v), - schema=schema, - return_schema=handler.generate_schema(source_type), - ), - ) - except PydanticSchemaGenerationError: - serialization = None - - input_schema = handler.generate_schema(self.json_schema_input_type) + def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema: info_arg = _inspect_validator(self.func, 'plain') if info_arg: func = cast(core_schema.WithInfoValidatorFunction, self.func) - return core_schema.with_info_plain_validator_function( - func, - field_name=handler.field_name, - serialization=serialization, # pyright: ignore[reportArgumentType] - json_schema_input_schema=input_schema, - ) + return core_schema.with_info_plain_validator_function(func, field_name=handler.field_name) else: func = cast(core_schema.NoInfoValidatorFunction, self.func) - return core_schema.no_info_plain_validator_function( - func, - serialization=serialization, # pyright: ignore[reportArgumentType] - json_schema_input_schema=input_schema, - ) - - @classmethod - def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self: - return cls( - func=decorator.func, - json_schema_input_type=decorator.info.json_schema_input_type, - ) + return core_schema.no_info_plain_validator_function(func) @dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true) class WrapValidator: - """!!! abstract "Usage Documentation" - [field *wrap* validators](../concepts/validators.md#field-wrap-validator) + """Usage docs: https://docs.pydantic.dev/2.4/concepts/validators/#annotated-validators A metadata class that indicates that a validation should be applied **around** the inner validation logic. Attributes: func: The validator function. - json_schema_input_type: The input type of the function. This is only used to generate the appropriate - JSON Schema (in validation mode). 
- ```python + ```py from datetime import datetime - from typing import Annotated + + from typing_extensions import Annotated from pydantic import BaseModel, ValidationError, WrapValidator @@ -291,61 +202,37 @@ class WrapValidator: """ func: core_schema.NoInfoWrapValidatorFunction | core_schema.WithInfoWrapValidatorFunction - json_schema_input_type: Any = PydanticUndefined - def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema: schema = handler(source_type) - input_schema = ( - None - if self.json_schema_input_type is PydanticUndefined - else handler.generate_schema(self.json_schema_input_type) - ) - info_arg = _inspect_validator(self.func, 'wrap') if info_arg: func = cast(core_schema.WithInfoWrapValidatorFunction, self.func) - return core_schema.with_info_wrap_validator_function( - func, - schema=schema, - field_name=handler.field_name, - json_schema_input_schema=input_schema, - ) + return core_schema.with_info_wrap_validator_function(func, schema=schema, field_name=handler.field_name) else: func = cast(core_schema.NoInfoWrapValidatorFunction, self.func) - return core_schema.no_info_wrap_validator_function( - func, - schema=schema, - json_schema_input_schema=input_schema, - ) - - @classmethod - def _from_decorator(cls, decorator: _decorators.Decorator[_decorators.FieldValidatorDecoratorInfo]) -> Self: - return cls( - func=decorator.func, - json_schema_input_type=decorator.info.json_schema_input_type, - ) + return core_schema.no_info_wrap_validator_function(func, schema=schema) if TYPE_CHECKING: class _OnlyValueValidatorClsMethod(Protocol): - def __call__(self, cls: Any, value: Any, /) -> Any: ... + def __call__(self, __cls: Any, __value: Any) -> Any: + ... class _V2ValidatorClsMethod(Protocol): - def __call__(self, cls: Any, value: Any, info: _core_schema.ValidationInfo, /) -> Any: ... - - class _OnlyValueWrapValidatorClsMethod(Protocol): - def __call__(self, cls: Any, value: Any, handler: _core_schema.ValidatorFunctionWrapHandler, /) -> Any: ... + def __call__(self, __cls: Any, __input_value: Any, __info: _core_schema.ValidationInfo) -> Any: + ... class _V2WrapValidatorClsMethod(Protocol): def __call__( self, - cls: Any, - value: Any, - handler: _core_schema.ValidatorFunctionWrapHandler, - info: _core_schema.ValidationInfo, - /, - ) -> Any: ... + __cls: Any, + __input_value: Any, + __validator: _core_schema.ValidatorFunctionWrapHandler, + __info: _core_schema.ValidationInfo, + ) -> Any: + ... 
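The body of the `WrapValidator` example above is elided by the hunk context, so here is a hedged sketch of the pattern with an invented validator: the wrapped `handler` performs the normal inner validation, and the outer function decides what to do with failures.

```python
from datetime import datetime, timezone
from typing import Annotated

from pydantic import BaseModel, ValidationError, WrapValidator


def fallback_to_epoch(value, handler):
    # run the regular datetime validation; substitute a default on failure
    try:
        return handler(value)
    except ValidationError:
        return datetime(1970, 1, 1, tzinfo=timezone.utc)


LenientDatetime = Annotated[datetime, WrapValidator(fallback_to_epoch)]


class Event(BaseModel):
    at: LenientDatetime


print(Event(at='not-a-date').at.year)
#> 1970
```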
_V2Validator = Union[ _V2ValidatorClsMethod, @@ -357,111 +244,57 @@ if TYPE_CHECKING: _V2WrapValidator = Union[ _V2WrapValidatorClsMethod, _core_schema.WithInfoWrapValidatorFunction, - _OnlyValueWrapValidatorClsMethod, - _core_schema.NoInfoWrapValidatorFunction, ] _PartialClsOrStaticMethod: TypeAlias = Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any]] _V2BeforeAfterOrPlainValidatorType = TypeVar( '_V2BeforeAfterOrPlainValidatorType', - bound=Union[_V2Validator, _PartialClsOrStaticMethod], + _V2Validator, + _PartialClsOrStaticMethod, ) - _V2WrapValidatorType = TypeVar('_V2WrapValidatorType', bound=Union[_V2WrapValidator, _PartialClsOrStaticMethod]) + _V2WrapValidatorType = TypeVar('_V2WrapValidatorType', _V2WrapValidator, _PartialClsOrStaticMethod) + + +@overload +def field_validator( + __field: str, + *fields: str, + mode: Literal['before', 'after', 'plain'] = ..., + check_fields: bool | None = ..., +) -> Callable[[_V2BeforeAfterOrPlainValidatorType], _V2BeforeAfterOrPlainValidatorType]: + ... + + +@overload +def field_validator( + __field: str, + *fields: str, + mode: Literal['wrap'], + check_fields: bool | None = ..., +) -> Callable[[_V2WrapValidatorType], _V2WrapValidatorType]: + ... + FieldValidatorModes: TypeAlias = Literal['before', 'after', 'wrap', 'plain'] -@overload def field_validator( - field: str, - /, - *fields: str, - mode: Literal['wrap'], - check_fields: bool | None = ..., - json_schema_input_type: Any = ..., -) -> Callable[[_V2WrapValidatorType], _V2WrapValidatorType]: ... - - -@overload -def field_validator( - field: str, - /, - *fields: str, - mode: Literal['before', 'plain'], - check_fields: bool | None = ..., - json_schema_input_type: Any = ..., -) -> Callable[[_V2BeforeAfterOrPlainValidatorType], _V2BeforeAfterOrPlainValidatorType]: ... - - -@overload -def field_validator( - field: str, - /, - *fields: str, - mode: Literal['after'] = ..., - check_fields: bool | None = ..., -) -> Callable[[_V2BeforeAfterOrPlainValidatorType], _V2BeforeAfterOrPlainValidatorType]: ... - - -def field_validator( - field: str, - /, + __field: str, *fields: str, mode: FieldValidatorModes = 'after', check_fields: bool | None = None, - json_schema_input_type: Any = PydanticUndefined, ) -> Callable[[Any], Any]: - """!!! abstract "Usage Documentation" - [field validators](../concepts/validators.md#field-validators) + """Usage docs: https://docs.pydantic.dev/2.4/concepts/validators/#field-validators Decorate methods on the class indicating that they should be used to validate fields. - Example usage: - ```python - from typing import Any - - from pydantic import ( - BaseModel, - ValidationError, - field_validator, - ) - - class Model(BaseModel): - a: str - - @field_validator('a') - @classmethod - def ensure_foobar(cls, v: Any): - if 'foobar' not in v: - raise ValueError('"foobar" not found in a') - return v - - print(repr(Model(a='this is foobar good'))) - #> Model(a='this is foobar good') - - try: - Model(a='snap') - except ValidationError as exc_info: - print(exc_info) - ''' - 1 validation error for Model - a - Value error, "foobar" not found in a [type=value_error, input_value='snap', input_type=str] - ''' - ``` - - For more in depth examples, see [Field Validators](../concepts/validators.md#field-validators). - Args: - field: The first field the `field_validator` should be called on; this is separate + __field: The first field the `field_validator` should be called on; this is separate from `fields` to ensure an error is raised if you don't pass at least one. 
*fields: Additional field(s) the `field_validator` should be called on. mode: Specifies whether to validate the fields before or after validation. check_fields: Whether to check that the fields actually exist on the model. - json_schema_input_type: The input type of the function. This is only used to generate - the appropriate JSON Schema (in validation mode) and can only specified - when `mode` is either `'before'`, `'plain'` or `'wrap'`. Returns: A decorator that can be used to decorate a function to be used as a field_validator. @@ -472,23 +305,13 @@ def field_validator( - If the args passed to `@field_validator` as fields are not strings. - If `@field_validator` applied to instance methods. """ - if isinstance(field, FunctionType): + if isinstance(__field, FunctionType): raise PydanticUserError( '`@field_validator` should be used with fields and keyword arguments, not bare. ' "E.g. usage should be `@validator('', ...)`", code='validator-no-fields', ) - - if mode not in ('before', 'plain', 'wrap') and json_schema_input_type is not PydanticUndefined: - raise PydanticUserError( - f"`json_schema_input_type` can't be used when mode is set to {mode!r}", - code='validator-input-type', - ) - - if json_schema_input_type is PydanticUndefined and mode == 'plain': - json_schema_input_type = Any - - fields = field, *fields + fields = __field, *fields if not all(isinstance(field, str) for field in fields): raise PydanticUserError( '`@field_validator` fields should be passed as separate string args. ' @@ -497,7 +320,7 @@ def field_validator( ) def dec( - f: Callable[..., Any] | staticmethod[Any, Any] | classmethod[Any, Any, Any], + f: Callable[..., Any] | staticmethod[Any, Any] | classmethod[Any, Any, Any] ) -> _decorators.PydanticDescriptorProxy[Any]: if _decorators.is_instance_method_from_sig(f): raise PydanticUserError( @@ -507,9 +330,7 @@ def field_validator( # auto apply the @classmethod decorator f = _decorators.ensure_classmethod_based_on_signature(f) - dec_info = _decorators.FieldValidatorDecoratorInfo( - fields=fields, mode=mode, check_fields=check_fields, json_schema_input_type=json_schema_input_type - ) + dec_info = _decorators.FieldValidatorDecoratorInfo(fields=fields, mode=mode, check_fields=check_fields) return _decorators.PydanticDescriptorProxy(f, dec_info) return dec @@ -520,19 +341,16 @@ _ModelTypeCo = TypeVar('_ModelTypeCo', covariant=True) class ModelWrapValidatorHandler(_core_schema.ValidatorFunctionWrapHandler, Protocol[_ModelTypeCo]): - """`@model_validator` decorated function handler argument type. This is used when `mode='wrap'`.""" + """@model_validator decorated function handler argument type. This is used when `mode='wrap'`.""" def __call__( # noqa: D102 - self, - value: Any, - outer_location: str | int | None = None, - /, + self, input_value: Any, outer_location: str | int | None = None ) -> _ModelTypeCo: # pragma: no cover ... class ModelWrapValidatorWithoutInfo(Protocol[_ModelType]): - """A `@model_validator` decorated function signature. + """A @model_validator decorated function signature. This is used when `mode='wrap'` and the function does not have info argument. """ @@ -542,14 +360,14 @@ class ModelWrapValidatorWithoutInfo(Protocol[_ModelType]): # this can be a dict, a model instance # or anything else that gets passed to validate_python # thus validators _must_ handle all cases - value: Any, - handler: ModelWrapValidatorHandler[_ModelType], - /, - ) -> _ModelType: ... + __value: Any, + __handler: ModelWrapValidatorHandler[_ModelType], + ) -> _ModelType: + ... 
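Tying together the `field_validator` arguments documented above (several field names at once plus `mode='before'`), a small invented sketch:

```python
from pydantic import BaseModel, field_validator


class Measurements(BaseModel):
    height_cm: float
    weight_kg: float

    @field_validator('height_cm', 'weight_kg', mode='before')
    @classmethod
    def strip_units(cls, value):
        # accept raw numbers or strings such as '182 cm' / '75.5 kg'
        if isinstance(value, str):
            return value.split()[0]
        return value


print(Measurements(height_cm='182 cm', weight_kg=75.5))
#> height_cm=182.0 weight_kg=75.5
```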
class ModelWrapValidator(Protocol[_ModelType]): - """A `@model_validator` decorated function signature. This is used when `mode='wrap'`.""" + """A @model_validator decorated function signature. This is used when `mode='wrap'`.""" def __call__( # noqa: D102 self, @@ -557,30 +375,15 @@ class ModelWrapValidator(Protocol[_ModelType]): # this can be a dict, a model instance # or anything else that gets passed to validate_python # thus validators _must_ handle all cases - value: Any, - handler: ModelWrapValidatorHandler[_ModelType], - info: _core_schema.ValidationInfo, - /, - ) -> _ModelType: ... - - -class FreeModelBeforeValidatorWithoutInfo(Protocol): - """A `@model_validator` decorated function signature. - This is used when `mode='before'` and the function does not have info argument. - """ - - def __call__( # noqa: D102 - self, - # this can be a dict, a model instance - # or anything else that gets passed to validate_python - # thus validators _must_ handle all cases - value: Any, - /, - ) -> Any: ... + __value: Any, + __handler: ModelWrapValidatorHandler[_ModelType], + __info: _core_schema.ValidationInfo, + ) -> _ModelType: + ... class ModelBeforeValidatorWithoutInfo(Protocol): - """A `@model_validator` decorated function signature. + """A @model_validator decorated function signature. This is used when `mode='before'` and the function does not have info argument. """ @@ -590,23 +393,9 @@ class ModelBeforeValidatorWithoutInfo(Protocol): # this can be a dict, a model instance # or anything else that gets passed to validate_python # thus validators _must_ handle all cases - value: Any, - /, - ) -> Any: ... - - -class FreeModelBeforeValidator(Protocol): - """A `@model_validator` decorated function signature. This is used when `mode='before'`.""" - - def __call__( # noqa: D102 - self, - # this can be a dict, a model instance - # or anything else that gets passed to validate_python - # thus validators _must_ handle all cases - value: Any, - info: _core_schema.ValidationInfo, - /, - ) -> Any: ... + __value: Any, + ) -> Any: + ... class ModelBeforeValidator(Protocol): @@ -618,10 +407,10 @@ class ModelBeforeValidator(Protocol): # this can be a dict, a model instance # or anything else that gets passed to validate_python # thus validators _must_ handle all cases - value: Any, - info: _core_schema.ValidationInfo, - /, - ) -> Any: ... + __value: Any, + __info: _core_schema.ValidationInfo, + ) -> Any: + ... ModelAfterValidatorWithoutInfo = Callable[[_ModelType], _ModelType] @@ -633,9 +422,7 @@ ModelAfterValidator = Callable[[_ModelType, _core_schema.ValidationInfo], _Model """A `@model_validator` decorated function signature. This is used when `mode='after'`.""" _AnyModelWrapValidator = Union[ModelWrapValidator[_ModelType], ModelWrapValidatorWithoutInfo[_ModelType]] -_AnyModelBeforeValidator = Union[ - FreeModelBeforeValidator, ModelBeforeValidator, FreeModelBeforeValidatorWithoutInfo, ModelBeforeValidatorWithoutInfo -] +_AnyModeBeforeValidator = Union[ModelBeforeValidator, ModelBeforeValidatorWithoutInfo] _AnyModelAfterValidator = Union[ModelAfterValidator[_ModelType], ModelAfterValidatorWithoutInfo[_ModelType]] @@ -645,16 +432,16 @@ def model_validator( mode: Literal['wrap'], ) -> Callable[ [_AnyModelWrapValidator[_ModelType]], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo] -]: ... +]: + ... 
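The wrap-mode `model_validator` overload above pairs with a validator that receives the raw input plus a handler running the standard model validation; a hedged sketch with an invented class:

```python
from typing import Any

from pydantic import BaseModel, model_validator


class Payload(BaseModel):
    value: int

    @model_validator(mode='wrap')
    @classmethod
    def unwrap_envelope(cls, data: Any, handler):
        # accept either {'value': 1} or {'data': {'value': 1}}
        if isinstance(data, dict) and 'data' in data:
            data = data['data']
        return handler(data)


print(Payload.model_validate({'data': {'value': 1}}))
#> value=1
```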
@overload def model_validator( *, mode: Literal['before'], -) -> Callable[ - [_AnyModelBeforeValidator], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo] -]: ... +) -> Callable[[_AnyModeBeforeValidator], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]]: + ... @overload @@ -663,49 +450,15 @@ def model_validator( mode: Literal['after'], ) -> Callable[ [_AnyModelAfterValidator[_ModelType]], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo] -]: ... +]: + ... def model_validator( *, mode: Literal['wrap', 'before', 'after'], ) -> Any: - """!!! abstract "Usage Documentation" - [Model Validators](../concepts/validators.md#model-validators) - - Decorate model methods for validation purposes. - - Example usage: - ```python - from typing_extensions import Self - - from pydantic import BaseModel, ValidationError, model_validator - - class Square(BaseModel): - width: float - height: float - - @model_validator(mode='after') - def verify_square(self) -> Self: - if self.width != self.height: - raise ValueError('width and height do not match') - return self - - s = Square(width=1, height=1) - print(repr(s)) - #> Square(width=1.0, height=1.0) - - try: - Square(width=1, height=2) - except ValidationError as e: - print(e) - ''' - 1 validation error for Square - Value error, width and height do not match [type=value_error, input_value={'width': 1, 'height': 2}, input_type=dict] - ''' - ``` - - For more in depth examples, see [Model Validators](../concepts/validators.md#model-validators). + """Decorate model methods for validation purposes. Args: mode: A required string literal that specifies the validation mode. @@ -738,7 +491,7 @@ else: '''Generic type for annotating a type that is an instance of a given class. Example: - ```python + ```py from pydantic import BaseModel, InstanceOf class Foo: @@ -817,7 +570,7 @@ else: @classmethod def __get_pydantic_core_schema__(cls, source: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: original_schema = handler(source) - metadata = {'pydantic_js_annotation_functions': [lambda _c, h: h(original_schema)]} + metadata = _core_metadata.build_metadata_dict(js_annotation_functions=[lambda _c, h: h(original_schema)]) return core_schema.any_schema( metadata=metadata, serialization=core_schema.wrap_serializer_function_ser_schema( diff --git a/venv/lib/python3.12/site-packages/pydantic/generics.py b/venv/lib/python3.12/site-packages/pydantic/generics.py index 3f1070d..5f6f7f7 100644 --- a/venv/lib/python3.12/site-packages/pydantic/generics.py +++ b/venv/lib/python3.12/site-packages/pydantic/generics.py @@ -1,5 +1,4 @@ """The `generics` module is a backport module from V1.""" - from ._migration import getattr_migration __getattr__ = getattr_migration(__name__) diff --git a/venv/lib/python3.12/site-packages/pydantic/json.py b/venv/lib/python3.12/site-packages/pydantic/json.py index bcaff9f..020fb6d 100644 --- a/venv/lib/python3.12/site-packages/pydantic/json.py +++ b/venv/lib/python3.12/site-packages/pydantic/json.py @@ -1,5 +1,4 @@ """The `json` module is a backport module from V1.""" - from ._migration import getattr_migration __getattr__ = getattr_migration(__name__) diff --git a/venv/lib/python3.12/site-packages/pydantic/json_schema.py b/venv/lib/python3.12/site-packages/pydantic/json_schema.py index be9595c..bf327de 100644 --- a/venv/lib/python3.12/site-packages/pydantic/json_schema.py +++ b/venv/lib/python3.12/site-packages/pydantic/json_schema.py @@ -1,47 +1,43 @@ -"""!!! 
abstract "Usage Documentation" - [JSON Schema](../concepts/json_schema.md) - +""" The `json_schema` module contains classes and functions to allow the way [JSON Schema](https://json-schema.org/) is generated to be customized. -In general you shouldn't need to use this module directly; instead, you can use +In general you shouldn't need to use this module directly; instead, you can [`BaseModel.model_json_schema`][pydantic.BaseModel.model_json_schema] and [`TypeAdapter.json_schema`][pydantic.TypeAdapter.json_schema]. """ - from __future__ import annotations as _annotations import dataclasses import inspect import math -import os import re import warnings -from collections import Counter, defaultdict -from collections.abc import Hashable, Iterable, Sequence +from collections import defaultdict from copy import deepcopy +from dataclasses import is_dataclass from enum import Enum -from re import Pattern from typing import ( TYPE_CHECKING, - Annotated, Any, Callable, - Literal, + Counter, + Dict, + Hashable, + Iterable, + List, NewType, + Sequence, + Tuple, TypeVar, Union, cast, - overload, ) import pydantic_core from pydantic_core import CoreSchema, PydanticOmit, core_schema, to_jsonable_python from pydantic_core.core_schema import ComputedField -from typing_extensions import TypeAlias, assert_never, deprecated, final -from typing_inspection.introspection import get_literal_values - -from pydantic.warnings import PydanticDeprecatedSince26, PydanticDeprecatedSince29 +from typing_extensions import Annotated, Literal, assert_never from ._internal import ( _config, @@ -51,10 +47,11 @@ from ._internal import ( _internal_dataclass, _mock_val_ser, _schema_generation_shared, + _typing_extra, ) from .annotated_handlers import GetJsonSchemaHandler -from .config import JsonDict, JsonValue -from .errors import PydanticInvalidForJsonSchema, PydanticSchemaGenerationError, PydanticUserError +from .config import JsonSchemaExtraCallable +from .errors import PydanticInvalidForJsonSchema, PydanticUserError if TYPE_CHECKING: from . import ConfigDict @@ -71,9 +68,9 @@ A type alias for defined schema types that represents a union of `core_schema.CoreSchemaFieldType`. """ -JsonSchemaValue = dict[str, Any] +JsonSchemaValue = Dict[str, Any] """ -A type alias for a JSON schema value. This is a dictionary of string keys to arbitrary JSON values. +A type alias for a JSON schema value. This is a dictionary of string keys to arbitrary values. """ JsonSchemaMode = Literal['validation', 'serialization'] @@ -89,7 +86,23 @@ for validation inputs, or that will be matched by serialization outputs. _MODE_TITLE_MAPPING: dict[JsonSchemaMode, str] = {'validation': 'Input', 'serialization': 'Output'} -JsonSchemaWarningKind = Literal['skipped-choice', 'non-serializable-default', 'skipped-discriminator'] +def update_json_schema(schema: JsonSchemaValue, updates: dict[str, Any]) -> JsonSchemaValue: + """Update a JSON schema by providing a dictionary of updates. + + This function sets the provided key-value pairs in the schema and returns the updated schema. + + Args: + schema: The JSON schema to update. + updates: A dictionary of key-value pairs to set in the schema. + + Returns: + The updated JSON schema. + """ + schema.update(updates) + return schema + + +JsonSchemaWarningKind = Literal['skipped-choice', 'non-serializable-default'] """ A type alias representing the kinds of warnings that can be emitted during JSON schema generation. 
@@ -106,12 +119,6 @@ class PydanticJsonSchemaWarning(UserWarning): """ -NoDefault = object() -"""A sentinel value used to indicate that no default value should be used when generating a JSON Schema -for a core schema with a default value. -""" - - # ##### JSON Schema Generation ##### DEFAULT_REF_TEMPLATE = '#/$defs/{model}' """The default format string used to generate reference names.""" @@ -128,7 +135,7 @@ DefsRef = NewType('DefsRef', str) # * By default, these look like "#/$defs/MyModel", as in {"$ref": "#/$defs/MyModel"} JsonRef = NewType('JsonRef', str) -CoreModeRef = tuple[CoreRef, JsonSchemaMode] +CoreModeRef = Tuple[CoreRef, JsonSchemaMode] JsonSchemaKeyT = TypeVar('JsonSchemaKeyT', bound=Hashable) @@ -163,7 +170,7 @@ class _DefinitionsRemapping: # Deduplicate the schemas for each alternative; the idea is that we only want to remap to a new DefsRef # if it introduces no ambiguity, i.e., there is only one distinct schema for that DefsRef. - for defs_ref in schemas_for_alternatives: + for defs_ref, schemas in schemas_for_alternatives.items(): schemas_for_alternatives[defs_ref] = _deduplicate_schemas(schemas_for_alternatives[defs_ref]) # Build the remapping @@ -214,10 +221,7 @@ class _DefinitionsRemapping: class GenerateJsonSchema: - """!!! abstract "Usage Documentation" - [Customizing the JSON Schema Generation Process](../concepts/json_schema.md#customizing-the-json-schema-generation-process) - - A class for generating JSON schemas. + """A class for generating JSON schemas. This class generates JSON schemas based on configured parameters. The default schema dialect is [https://json-schema.org/draft/2020-12/schema](https://json-schema.org/draft/2020-12/schema). @@ -231,20 +235,27 @@ class GenerateJsonSchema: ignored_warning_kinds: Warnings to ignore when generating the schema. `self.render_warning_message` will do nothing if its argument `kind` is in `ignored_warning_kinds`; this value can be modified on subclasses to easily control which warnings are emitted. - by_alias: Whether to use field aliases when generating the schema. + by_alias: Whether or not to use field names when generating the schema. ref_template: The format string used when generating reference names. core_to_json_refs: A mapping of core refs to JSON refs. core_to_defs_refs: A mapping of core refs to definition refs. defs_to_core_refs: A mapping of definition refs to core refs. json_to_defs_refs: A mapping of JSON refs to definition refs. definitions: Definitions in the schema. + collisions: Definitions with colliding names. When collisions are detected, we choose a non-colliding + name during generation, but we also track the colliding tag so that it can be remapped for the first + occurrence at the end of the process. + defs_ref_fallbacks: Core refs to fallback definitions refs. + _schema_type_to_method: A mapping of schema types to generator methods. + _used: Set to `True` after generating a schema to avoid re-use issues. + mode: The schema mode. Args: - by_alias: Whether to use field aliases in the generated schemas. + by_alias: Whether or not to include field names. ref_template: The format string to use when generating reference names. Raises: - JsonSchemaError: If the instance of the class is inadvertently reused after generating a schema. + JsonSchemaError: If the instance of the class is inadvertently re-used after generating a schema. 
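`by_alias` and `ref_template` are the two constructor knobs documented above; they are normally passed through `model_json_schema`. A small sketch with an invented model and alias:

```python
from pydantic import BaseModel, Field

class User(BaseModel):
    user_name: str = Field(alias='userName')

# by_alias controls whether alias names are used as property keys;
# ref_template controls how '$ref' strings are rendered.
print(User.model_json_schema(by_alias=True, ref_template='#/components/schemas/{model}'))
```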
""" schema_dialect = 'https://json-schema.org/draft/2020-12/schema' @@ -285,7 +296,7 @@ class GenerateJsonSchema: # store the error raised and re-throw it if we end up needing that def self._core_defs_invalid_for_json_schema: dict[DefsRef, PydanticInvalidForJsonSchema] = {} - # This changes to True after generating a schema, to prevent issues caused by accidental reuse + # This changes to True after generating a schema, to prevent issues caused by accidental re-use # of a single instance of a schema generator self._used = False @@ -312,14 +323,14 @@ class GenerateJsonSchema: TypeError: If no method has been defined for generating a JSON schema for a given pydantic core schema type. """ mapping: dict[CoreSchemaOrFieldType, Callable[[CoreSchemaOrField], JsonSchemaValue]] = {} - core_schema_types: list[CoreSchemaOrFieldType] = list(get_literal_values(CoreSchemaOrFieldType)) + core_schema_types: list[CoreSchemaOrFieldType] = _typing_extra.all_literal_values( + CoreSchemaOrFieldType # type: ignore + ) for key in core_schema_types: - method_name = f'{key.replace("-", "_")}_schema' + method_name = f"{key.replace('-', '_')}_schema" try: mapping[key] = getattr(self, method_name) except AttributeError as e: # pragma: no cover - if os.getenv('PYDANTIC_PRIVATE_ALLOW_UNHANDLED_SCHEMA_TYPES'): - continue raise TypeError( f'No method for generating JsonSchema for core_schema.type={key!r} ' f'(expected: {type(self).__name__}.{method_name})' @@ -358,7 +369,7 @@ class GenerateJsonSchema: code='json-schema-already-used', ) - for _, mode, schema in inputs: + for key, mode, schema in inputs: self._mode = mode self.generate_inner(schema) @@ -373,7 +384,7 @@ class GenerateJsonSchema: json_schema = {'$defs': self.definitions} json_schema = definitions_remapping.remap_json_schema(json_schema) self._used = True - return json_schemas_map, self.sort(json_schema['$defs']) # type: ignore + return json_schemas_map, _sort_json_schema(json_schema['$defs']) # type: ignore def generate(self, schema: CoreSchema, mode: JsonSchemaMode = 'validation') -> JsonSchemaValue: """Generates a JSON schema for a specified schema in a specified mode. 
@@ -399,15 +410,18 @@ class GenerateJsonSchema: json_schema: JsonSchemaValue = self.generate_inner(schema) json_ref_counts = self.get_json_ref_counts(json_schema) + # Remove the top-level $ref if present; note that the _generate method already ensures there are no sibling keys ref = cast(JsonRef, json_schema.get('$ref')) while ref is not None: # may need to unpack multiple levels ref_json_schema = self.get_schema_from_definitions(ref) - if json_ref_counts[ref] == 1 and ref_json_schema is not None and len(json_schema) == 1: - # "Unpack" the ref since this is the only reference and there are no sibling keys + if json_ref_counts[ref] > 1 or ref_json_schema is None: + # Keep the ref, but use an allOf to remove the top level $ref + json_schema = {'allOf': [{'$ref': ref}]} + else: + # "Unpack" the ref since this is the only reference json_schema = ref_json_schema.copy() # copy to prevent recursive dict reference json_ref_counts[ref] -= 1 - ref = cast(JsonRef, json_schema.get('$ref')) - ref = None + ref = cast(JsonRef, json_schema.get('$ref')) self._garbage_collect_definitions(json_schema) definitions_remapping = self._build_definitions_remapping() @@ -422,7 +436,7 @@ class GenerateJsonSchema: # json_schema['$schema'] = self.schema_dialect self._used = True - return self.sort(json_schema) + return _sort_json_schema(json_schema) def generate_inner(self, schema: CoreSchemaOrField) -> JsonSchemaValue: # noqa: C901 """Generates a JSON schema for a given core schema. @@ -432,10 +446,6 @@ class GenerateJsonSchema: Returns: The generated JSON schema. - - TODO: the nested function definitions here seem like bad practice, I'd like to unpack these - in a future PR. It'd be great if we could shorten the call stack a bit for JSON schema generation, - and I think there's potential for that here. """ # If a schema with the same CoreRef has been handled, just return a reference to it # Note that this assumes that it will _never_ be the case that the same CoreRef is used @@ -446,11 +456,15 @@ class GenerateJsonSchema: if core_mode_ref in self.core_to_defs_refs and self.core_to_defs_refs[core_mode_ref] in self.definitions: return {'$ref': self.core_to_json_refs[core_mode_ref]} + # Generate the JSON schema, accounting for the json_schema_override and core_schema_override + metadata_handler = _core_metadata.CoreMetadataHandler(schema) + def populate_defs(core_schema: CoreSchema, json_schema: JsonSchemaValue) -> JsonSchemaValue: if 'ref' in core_schema: core_ref = CoreRef(core_schema['ref']) # type: ignore[typeddict-item] defs_ref, ref_json_schema = self.get_cache_defs_ref_schema(core_ref) json_ref = JsonRef(ref_json_schema['$ref']) + self.json_to_defs_refs[json_ref] = defs_ref # Replace the schema if it's not a reference to itself # What we want to avoid is having the def be just a ref to itself # which is what would happen if we blindly assigned any @@ -460,6 +474,15 @@ class GenerateJsonSchema: json_schema = ref_json_schema return json_schema + def convert_to_all_of(json_schema: JsonSchemaValue) -> JsonSchemaValue: + if '$ref' in json_schema and len(json_schema.keys()) > 1: + # technically you can't have any other keys next to a "$ref" + # but it's an easy mistake to make and not hard to correct automatically here + json_schema = json_schema.copy() + ref = json_schema.pop('$ref') + json_schema = {'allOf': [{'$ref': ref}], **json_schema} + return json_schema + def handler_func(schema_or_field: CoreSchemaOrField) -> JsonSchemaValue: """Generate a JSON schema based on the input schema. 
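The `$defs`/`$ref` bookkeeping handled by `populate_defs` above is easiest to see with a shared sub-model. A small sketch with invented models; older generators wrap the reference in `allOf`, newer ones emit a bare `$ref`:

```python
from pydantic import BaseModel

class Inner(BaseModel):
    x: int

class Outer(BaseModel):
    first: Inner
    second: Inner

schema = Outer.model_json_schema()
print(schema['$defs']['Inner'])       # the shared definition
print(schema['properties']['first'])  # a '$ref' (or an allOf wrapper, depending on version)
```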
@@ -475,63 +498,22 @@ class GenerateJsonSchema: # Generate the core-schema-type-specific bits of the schema generation: json_schema: JsonSchemaValue | None = None if self.mode == 'serialization' and 'serialization' in schema_or_field: - # In this case, we skip the JSON Schema generation of the schema - # and use the `'serialization'` schema instead (canonical example: - # `Annotated[int, PlainSerializer(str)]`). ser_schema = schema_or_field['serialization'] # type: ignore json_schema = self.ser_schema(ser_schema) - - # It might be that the 'serialization'` is skipped depending on `when_used`. - # This is only relevant for `nullable` schemas though, so we special case here. - if ( - json_schema is not None - and ser_schema.get('when_used') in ('unless-none', 'json-unless-none') - and schema_or_field['type'] == 'nullable' - ): - json_schema = self.get_flattened_anyof([{'type': 'null'}, json_schema]) if json_schema is None: if _core_utils.is_core_schema(schema_or_field) or _core_utils.is_core_schema_field(schema_or_field): generate_for_schema_type = self._schema_type_to_method[schema_or_field['type']] json_schema = generate_for_schema_type(schema_or_field) else: raise TypeError(f'Unexpected schema type: schema={schema_or_field}') - + if _core_utils.is_core_schema(schema_or_field): + json_schema = populate_defs(schema_or_field, json_schema) + json_schema = convert_to_all_of(json_schema) return json_schema current_handler = _schema_generation_shared.GenerateJsonSchemaHandler(self, handler_func) - metadata = cast(_core_metadata.CoreMetadata, schema.get('metadata', {})) - - # TODO: I dislike that we have to wrap these basic dict updates in callables, is there any way around this? - - if js_updates := metadata.get('pydantic_js_updates'): - - def js_updates_handler_func( - schema_or_field: CoreSchemaOrField, - current_handler: GetJsonSchemaHandler = current_handler, - ) -> JsonSchemaValue: - json_schema = {**current_handler(schema_or_field), **js_updates} - return json_schema - - current_handler = _schema_generation_shared.GenerateJsonSchemaHandler(self, js_updates_handler_func) - - if js_extra := metadata.get('pydantic_js_extra'): - - def js_extra_handler_func( - schema_or_field: CoreSchemaOrField, - current_handler: GetJsonSchemaHandler = current_handler, - ) -> JsonSchemaValue: - json_schema = current_handler(schema_or_field) - if isinstance(js_extra, dict): - json_schema.update(to_jsonable_python(js_extra)) - elif callable(js_extra): - # similar to typing issue in _update_class_schema when we're working with callable js extra - js_extra(json_schema) # type: ignore - return json_schema - - current_handler = _schema_generation_shared.GenerateJsonSchemaHandler(self, js_extra_handler_func) - - for js_modify_function in metadata.get('pydantic_js_functions', ()): + for js_modify_function in metadata_handler.metadata.get('pydantic_js_functions', ()): def new_handler_func( schema_or_field: CoreSchemaOrField, @@ -549,61 +531,28 @@ class GenerateJsonSchema: current_handler = _schema_generation_shared.GenerateJsonSchemaHandler(self, new_handler_func) - for js_modify_function in metadata.get('pydantic_js_annotation_functions', ()): + for js_modify_function in metadata_handler.metadata.get('pydantic_js_annotation_functions', ()): def new_handler_func( schema_or_field: CoreSchemaOrField, current_handler: GetJsonSchemaHandler = current_handler, js_modify_function: GetJsonSchemaFunction = js_modify_function, ) -> JsonSchemaValue: - return js_modify_function(schema_or_field, current_handler) + json_schema = 
js_modify_function(schema_or_field, current_handler) + if _core_utils.is_core_schema(schema_or_field): + json_schema = populate_defs(schema_or_field, json_schema) + json_schema = convert_to_all_of(json_schema) + return json_schema current_handler = _schema_generation_shared.GenerateJsonSchemaHandler(self, new_handler_func) json_schema = current_handler(schema) if _core_utils.is_core_schema(schema): json_schema = populate_defs(schema, json_schema) + json_schema = convert_to_all_of(json_schema) return json_schema - def sort(self, value: JsonSchemaValue, parent_key: str | None = None) -> JsonSchemaValue: - """Override this method to customize the sorting of the JSON schema (e.g., don't sort at all, sort all keys unconditionally, etc.) - - By default, alphabetically sort the keys in the JSON schema, skipping the 'properties' and 'default' keys to preserve field definition order. - This sort is recursive, so it will sort all nested dictionaries as well. - """ - sorted_dict: dict[str, JsonSchemaValue] = {} - keys = value.keys() - if parent_key not in ('properties', 'default'): - keys = sorted(keys) - for key in keys: - sorted_dict[key] = self._sort_recursive(value[key], parent_key=key) - return sorted_dict - - def _sort_recursive(self, value: Any, parent_key: str | None = None) -> Any: - """Recursively sort a JSON schema value.""" - if isinstance(value, dict): - sorted_dict: dict[str, JsonSchemaValue] = {} - keys = value.keys() - if parent_key not in ('properties', 'default'): - keys = sorted(keys) - for key in keys: - sorted_dict[key] = self._sort_recursive(value[key], parent_key=key) - return sorted_dict - elif isinstance(value, list): - sorted_list: list[JsonSchemaValue] = [] - for item in value: - sorted_list.append(self._sort_recursive(item, parent_key)) - return sorted_list - else: - return value - # ### Schema generation methods - - def invalid_schema(self, schema: core_schema.InvalidSchema) -> JsonSchemaValue: - """Placeholder - should never be called.""" - - raise RuntimeError('Cannot generate schema for invalid_schema. This is a bug! Please report it.') - def any_schema(self, schema: core_schema.AnySchema) -> JsonSchemaValue: """Generates a JSON schema that matches any value. @@ -616,7 +565,7 @@ class GenerateJsonSchema: return {} def none_schema(self, schema: core_schema.NoneSchema) -> JsonSchemaValue: - """Generates a JSON schema that matches `None`. + """Generates a JSON schema that matches a None value. Args: schema: The core schema. @@ -638,7 +587,7 @@ class GenerateJsonSchema: return {'type': 'boolean'} def int_schema(self, schema: core_schema.IntSchema) -> JsonSchemaValue: - """Generates a JSON schema that matches an int value. + """Generates a JSON schema that matches an Int value. Args: schema: The core schema. @@ -709,9 +658,6 @@ class GenerateJsonSchema: """ json_schema = {'type': 'string'} self.update_with_validations(json_schema, schema, self.ValidationsMapping.string) - if isinstance(json_schema.get('pattern'), Pattern): - # TODO: should we add regex flags to the pattern? - json_schema['pattern'] = json_schema.get('pattern').pattern # type: ignore return json_schema def bytes_schema(self, schema: core_schema.BytesSchema) -> JsonSchemaValue: @@ -736,7 +682,9 @@ class GenerateJsonSchema: Returns: The generated JSON schema. 
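The string validations mapped above surface as standard JSON Schema keywords. A minimal sketch using an invented model with length and pattern constraints:

```python
from typing import Annotated

from pydantic import BaseModel, Field

class Tag(BaseModel):
    slug: Annotated[str, Field(min_length=1, max_length=32, pattern=r'^[a-z0-9-]+$')]

print(Tag.model_json_schema()['properties']['slug'])
# roughly: {'maxLength': 32, 'minLength': 1, 'pattern': '^[a-z0-9-]+$', 'title': 'Slug', 'type': 'string'}
```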
""" - return {'type': 'string', 'format': 'date'} + json_schema = {'type': 'string', 'format': 'date'} + self.update_with_validations(json_schema, schema, self.ValidationsMapping.date) + return json_schema def time_schema(self, schema: core_schema.TimeSchema) -> JsonSchemaValue: """Generates a JSON schema that matches a time value. @@ -782,69 +730,32 @@ class GenerateJsonSchema: Returns: The generated JSON schema. """ - expected = [to_jsonable_python(v.value if isinstance(v, Enum) else v) for v in schema['expected']] + expected = [v.value if isinstance(v, Enum) else v for v in schema['expected']] + # jsonify the expected values + expected = [to_jsonable_python(v) for v in expected] - result: dict[str, Any] = {} if len(expected) == 1: - result['const'] = expected[0] - else: - result['enum'] = expected + return {'const': expected[0]} types = {type(e) for e in expected} if types == {str}: - result['type'] = 'string' + return {'enum': expected, 'type': 'string'} elif types == {int}: - result['type'] = 'integer' + return {'enum': expected, 'type': 'integer'} elif types == {float}: - result['type'] = 'number' + return {'enum': expected, 'type': 'number'} elif types == {bool}: - result['type'] = 'boolean' + return {'enum': expected, 'type': 'boolean'} elif types == {list}: - result['type'] = 'array' - elif types == {type(None)}: - result['type'] = 'null' - return result - - def enum_schema(self, schema: core_schema.EnumSchema) -> JsonSchemaValue: - """Generates a JSON schema that matches an Enum value. - - Args: - schema: The core schema. - - Returns: - The generated JSON schema. - """ - enum_type = schema['cls'] - description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__) - if ( - description == 'An enumeration.' - ): # This is the default value provided by enum.EnumMeta.__new__; don't use it - description = None - result: dict[str, Any] = {'title': enum_type.__name__, 'description': description} - result = {k: v for k, v in result.items() if v is not None} - - expected = [to_jsonable_python(v.value) for v in schema['members']] - - result['enum'] = expected - - types = {type(e) for e in expected} - if isinstance(enum_type, str) or types == {str}: - result['type'] = 'string' - elif isinstance(enum_type, int) or types == {int}: - result['type'] = 'integer' - elif isinstance(enum_type, float) or types == {float}: - result['type'] = 'number' - elif types == {bool}: - result['type'] = 'boolean' - elif types == {list}: - result['type'] = 'array' - - return result + return {'enum': expected, 'type': 'array'} + # there is not None case because if it's mixed it hits the final `else` + # if it's a single Literal[None] then it becomes a `const` schema above + else: + return {'enum': expected} def is_instance_schema(self, schema: core_schema.IsInstanceSchema) -> JsonSchemaValue: - """Handles JSON schema generation for a core schema that checks if a value is an instance of a class. - - Unless overridden in a subclass, this raises an error. + """Generates a JSON schema that checks if a value is an instance of a class, equivalent to Python's + `isinstance` method. Args: schema: The core schema. @@ -855,9 +766,8 @@ class GenerateJsonSchema: return self.handle_invalid_for_json_schema(schema, f'core_schema.IsInstanceSchema ({schema["cls"]})') def is_subclass_schema(self, schema: core_schema.IsSubclassSchema) -> JsonSchemaValue: - """Handles JSON schema generation for a core schema that checks if a value is a subclass of a class. 
- - For backwards compatibility with v1, this does not raise an error, but can be overridden to change this. + """Generates a JSON schema that checks if a value is a subclass of a class, equivalent to Python's `issubclass` + method. Args: schema: The core schema. @@ -871,8 +781,6 @@ class GenerateJsonSchema: def callable_schema(self, schema: core_schema.CallableSchema) -> JsonSchemaValue: """Generates a JSON schema that matches a callable value. - Unless overridden in a subclass, this raises an error. - Args: schema: The core schema. @@ -895,31 +803,8 @@ class GenerateJsonSchema: self.update_with_validations(json_schema, schema, self.ValidationsMapping.array) return json_schema - @deprecated('`tuple_positional_schema` is deprecated. Use `tuple_schema` instead.', category=None) - @final - def tuple_positional_schema(self, schema: core_schema.TupleSchema) -> JsonSchemaValue: - """Replaced by `tuple_schema`.""" - warnings.warn( - '`tuple_positional_schema` is deprecated. Use `tuple_schema` instead.', - PydanticDeprecatedSince26, - stacklevel=2, - ) - return self.tuple_schema(schema) - - @deprecated('`tuple_variable_schema` is deprecated. Use `tuple_schema` instead.', category=None) - @final - def tuple_variable_schema(self, schema: core_schema.TupleSchema) -> JsonSchemaValue: - """Replaced by `tuple_schema`.""" - warnings.warn( - '`tuple_variable_schema` is deprecated. Use `tuple_schema` instead.', - PydanticDeprecatedSince26, - stacklevel=2, - ) - return self.tuple_schema(schema) - - def tuple_schema(self, schema: core_schema.TupleSchema) -> JsonSchemaValue: - """Generates a JSON schema that matches a tuple schema e.g. `tuple[int, - str, bool]` or `tuple[int, ...]`. + def tuple_positional_schema(self, schema: core_schema.TuplePositionalSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a positional tuple schema e.g. `Tuple[int, str, bool]`. Args: schema: The core schema. @@ -928,27 +813,28 @@ class GenerateJsonSchema: The generated JSON schema. 
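The `literal_schema` (and, on the newer side of this hunk, `enum_schema`) behaviour can be observed directly through `TypeAdapter`; the exact output depends on which pydantic version is installed. A small sketch:

```python
from enum import Enum
from typing import Literal

from pydantic import TypeAdapter

class Color(str, Enum):
    """Supported colours."""
    red = 'red'
    blue = 'blue'

print(TypeAdapter(Literal['a', 'b']).json_schema())  # homogeneous strings also get "type": "string"
print(TypeAdapter(Literal['only']).json_schema())    # a single value collapses to "const"
print(TypeAdapter(Color).json_schema())              # enum members, plus title/description on newer versions
```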
""" json_schema: JsonSchemaValue = {'type': 'array'} - if 'variadic_item_index' in schema: - variadic_item_index = schema['variadic_item_index'] - if variadic_item_index > 0: - json_schema['minItems'] = variadic_item_index - json_schema['prefixItems'] = [ - self.generate_inner(item) for item in schema['items_schema'][:variadic_item_index] - ] - if variadic_item_index + 1 == len(schema['items_schema']): - # if the variadic item is the last item, then represent it faithfully - json_schema['items'] = self.generate_inner(schema['items_schema'][variadic_item_index]) - else: - # otherwise, 'items' represents the schema for the variadic - # item plus the suffix, so just allow anything for simplicity - # for now - json_schema['items'] = True + json_schema['minItems'] = len(schema['items_schema']) + prefixItems = [self.generate_inner(item) for item in schema['items_schema']] + if prefixItems: + json_schema['prefixItems'] = prefixItems + if 'extras_schema' in schema: + json_schema['items'] = self.generate_inner(schema['extras_schema']) else: - prefixItems = [self.generate_inner(item) for item in schema['items_schema']] - if prefixItems: - json_schema['prefixItems'] = prefixItems - json_schema['minItems'] = len(prefixItems) - json_schema['maxItems'] = len(prefixItems) + json_schema['maxItems'] = len(schema['items_schema']) + self.update_with_validations(json_schema, schema, self.ValidationsMapping.array) + return json_schema + + def tuple_variable_schema(self, schema: core_schema.TupleVariableSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a variable tuple schema e.g. `Tuple[int, ...]`. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + items_schema = {} if 'items_schema' not in schema else self.generate_inner(schema['items_schema']) + json_schema = {'type': 'array', 'items': items_schema} self.update_with_validations(json_schema, schema, self.ValidationsMapping.array) return json_schema @@ -1006,42 +892,33 @@ class GenerateJsonSchema: json_schema: JsonSchemaValue = {'type': 'object'} keys_schema = self.generate_inner(schema['keys_schema']).copy() if 'keys_schema' in schema else {} - if '$ref' not in keys_schema: - keys_pattern = keys_schema.pop('pattern', None) - # Don't give a title to patternProperties/propertyNames: - keys_schema.pop('title', None) - else: - # Here, we assume that if the keys schema is a definition reference, - # it can't be a simple string core schema (and thus no pattern can exist). - # However, this is only in practice (in theory, a definition reference core - # schema could be generated for a simple string schema). - # Note that we avoid calling `self.resolve_ref_schema`, as it might not exist yet. 
- keys_pattern = None + keys_pattern = keys_schema.pop('pattern', None) values_schema = self.generate_inner(schema['values_schema']).copy() if 'values_schema' in schema else {} - # don't give a title to additionalProperties: - values_schema.pop('title', None) - - if values_schema or keys_pattern is not None: + values_schema.pop('title', None) # don't give a title to the additionalProperties + if values_schema or keys_pattern is not None: # don't add additionalProperties if it's empty if keys_pattern is None: json_schema['additionalProperties'] = values_schema else: json_schema['patternProperties'] = {keys_pattern: values_schema} - else: # for `dict[str, Any]`, we allow any key and any value, since `str` is the default key type - json_schema['additionalProperties'] = True - - if ( - # The len check indicates that constraints are probably present: - (keys_schema.get('type') == 'string' and len(keys_schema) > 1) - # If this is a definition reference schema, it most likely has constraints: - or '$ref' in keys_schema - ): - keys_schema.pop('type', None) - json_schema['propertyNames'] = keys_schema self.update_with_validations(json_schema, schema, self.ValidationsMapping.object) return json_schema + def _function_schema( + self, + schema: _core_utils.AnyFunctionSchema, + ) -> JsonSchemaValue: + if _core_utils.is_function_with_inner_schema(schema): + # This could be wrong if the function's mode is 'before', but in practice will often be right, and when it + # isn't, I think it would be hard to automatically infer what the desired schema should be. + return self.generate_inner(schema['schema']) + + # function-plain + return self.handle_invalid_for_json_schema( + schema, f'core_schema.PlainValidatorFunctionSchema ({schema["function"]})' + ) + def function_before_schema(self, schema: core_schema.BeforeValidatorFunctionSchema) -> JsonSchemaValue: """Generates a JSON schema that matches a function-before schema. @@ -1051,10 +928,7 @@ class GenerateJsonSchema: Returns: The generated JSON schema. """ - if self.mode == 'validation' and (input_schema := schema.get('json_schema_input_schema')): - return self.generate_inner(input_schema) - - return self.generate_inner(schema['schema']) + return self._function_schema(schema) def function_after_schema(self, schema: core_schema.AfterValidatorFunctionSchema) -> JsonSchemaValue: """Generates a JSON schema that matches a function-after schema. @@ -1065,7 +939,7 @@ class GenerateJsonSchema: Returns: The generated JSON schema. """ - return self.generate_inner(schema['schema']) + return self._function_schema(schema) def function_plain_schema(self, schema: core_schema.PlainValidatorFunctionSchema) -> JsonSchemaValue: """Generates a JSON schema that matches a function-plain schema. @@ -1076,12 +950,7 @@ class GenerateJsonSchema: Returns: The generated JSON schema. """ - if self.mode == 'validation' and (input_schema := schema.get('json_schema_input_schema')): - return self.generate_inner(input_schema) - - return self.handle_invalid_for_json_schema( - schema, f'core_schema.PlainValidatorFunctionSchema ({schema["function"]})' - ) + return self._function_schema(schema) def function_wrap_schema(self, schema: core_schema.WrapValidatorFunctionSchema) -> JsonSchemaValue: """Generates a JSON schema that matches a function-wrap schema. @@ -1092,10 +961,7 @@ class GenerateJsonSchema: Returns: The generated JSON schema. 
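The `dict_schema` handling above places the value schema under `additionalProperties`, or under `patternProperties` when the key schema carries a pattern. A small sketch; `StringConstraints` is pydantic's public way to attach a key pattern and is not part of this hunk:

```python
from typing import Annotated

from pydantic import StringConstraints, TypeAdapter

# Plain dict[str, int]: the value schema lands in "additionalProperties".
print(TypeAdapter(dict[str, int]).json_schema())

# A key pattern moves the value schema under "patternProperties" instead.
LowerKey = Annotated[str, StringConstraints(pattern=r'^[a-z]+$')]
print(TypeAdapter(dict[LowerKey, int]).json_schema())
```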
""" - if self.mode == 'validation' and (input_schema := schema.get('json_schema_input_schema')): - return self.generate_inner(input_schema) - - return self.generate_inner(schema['schema']) + return self._function_schema(schema) def default_schema(self, schema: core_schema.WithDefaultSchema) -> JsonSchemaValue: """Generates a JSON schema that matches a schema with a default value. @@ -1108,35 +974,17 @@ class GenerateJsonSchema: """ json_schema = self.generate_inner(schema['schema']) - default = self.get_default_value(schema) - if default is NoDefault: + if 'default' not in schema: return json_schema - - # we reflect the application of custom plain, no-info serializers to defaults for - # JSON Schemas viewed in serialization mode: - # TODO: improvements along with https://github.com/pydantic/pydantic/issues/8208 - if ( - self.mode == 'serialization' - and (ser_schema := schema['schema'].get('serialization')) - and (ser_func := ser_schema.get('function')) - and ser_schema.get('type') == 'function-plain' - and not ser_schema.get('info_arg') - and not (default is None and ser_schema.get('when_used') in ('unless-none', 'json-unless-none')) - ): - try: - default = ser_func(default) # type: ignore - except Exception: - # It might be that the provided default needs to be validated (read: parsed) first - # (assuming `validate_default` is enabled). However, we can't perform - # such validation during JSON Schema generation so we don't support - # this pattern for now. - # (One example is when using `foo: ByteSize = '1MB'`, which validates and - # serializes as an int. In this case, `ser_func` is `int` and `int('1MB')` fails). - self.emit_warning( - 'non-serializable-default', - f'Unable to serialize value {default!r} with the plain serializer; excluding default from JSON schema', - ) - return json_schema + default = schema['default'] + # Note: if you want to include the value returned by the default_factory, + # override this method and replace the code above with: + # if 'default' in schema: + # default = schema['default'] + # elif 'default_factory' in schema: + # default = schema['default_factory']() + # else: + # return json_schema try: encoded_default = self.encode_default(default) @@ -1148,23 +996,12 @@ class GenerateJsonSchema: # Return the inner schema, as though there was no default return json_schema - json_schema['default'] = encoded_default - return json_schema - - def get_default_value(self, schema: core_schema.WithDefaultSchema) -> Any: - """Get the default value to be used when generating a JSON Schema for a core schema with a default. - - The default implementation is to use the statically defined default value. This method can be overridden - if you want to make use of the default factory. - - Args: - schema: The `'with-default'` core schema. - - Returns: - The default value to use, or [`NoDefault`][pydantic.json_schema.NoDefault] if no default - value is available. - """ - return schema.get('default', NoDefault) + if '$ref' in json_schema: + # Since reference schemas do not support child keys, we wrap the reference schema in a single-case allOf: + return {'allOf': [json_schema], 'default': encoded_default} + else: + json_schema['default'] = encoded_default + return json_schema def nullable_schema(self, schema: core_schema.NullableSchema) -> JsonSchemaValue: """Generates a JSON schema that matches a schema that allows null values. 
@@ -1248,7 +1085,7 @@ class GenerateJsonSchema: return json_schema def _extract_discriminator( - self, schema: core_schema.TaggedUnionSchema, one_of_choices: list[JsonDict] + self, schema: core_schema.TaggedUnionSchema, one_of_choices: list[_JsonDict] ) -> str | None: """Extract a compatible OpenAPI discriminator from the schema and one_of choices that end up in the final schema.""" @@ -1275,14 +1112,9 @@ class GenerateJsonSchema: continue # this means that the "alias" does not represent a field alias_is_present_on_all_choices = True for choice in one_of_choices: - try: - choice = self.resolve_ref_schema(choice) - except RuntimeError as exc: - # TODO: fixme - this is a workaround for the fact that we can't always resolve refs - # for tagged union choices at this point in the schema gen process, we might need to do - # another pass at the end like we do for core schemas - self.emit_warning('skipped-discriminator', str(exc)) - choice = {} + while '$ref' in choice: + assert isinstance(choice['$ref'], str) + choice = self.get_schema_from_definitions(JsonRef(choice['$ref'])) or {} properties = choice.get('properties', {}) if not isinstance(properties, dict) or alias not in properties: alias_is_present_on_all_choices = False @@ -1358,19 +1190,16 @@ class GenerateJsonSchema: ] if self.mode == 'serialization': named_required_fields.extend(self._name_required_computed_fields(schema.get('computed_fields', []))) - cls = schema.get('cls') - config = _get_typed_dict_config(cls) + + config = _get_typed_dict_config(schema) with self._config_wrapper_stack.push(config): json_schema = self._named_required_fields_schema(named_required_fields) - if cls is not None: - self._update_class_schema(json_schema, cls, config) - else: - extra = config.get('extra') - if extra == 'forbid': - json_schema['additionalProperties'] = False - elif extra == 'allow': - json_schema['additionalProperties'] = True + extra = config.get('extra', 'ignore') + if extra == 'forbid': + json_schema['additionalProperties'] = False + elif extra == 'allow': + json_schema['additionalProperties'] = True return json_schema @@ -1483,56 +1312,13 @@ class GenerateJsonSchema: # because it could lead to inconsistent refs handling, etc. cls = cast('type[BaseModel]', schema['cls']) config = cls.model_config + title = config.get('title') with self._config_wrapper_stack.push(config): json_schema = self.generate_inner(schema['schema']) - self._update_class_schema(json_schema, cls, config) - - return json_schema - - def _update_class_schema(self, json_schema: JsonSchemaValue, cls: type[Any], config: ConfigDict) -> None: - """Update json_schema with the following, extracted from `config` and `cls`: - - * title - * description - * additional properties - * json_schema_extra - * deprecated - - Done in place, hence there's no return value as the original json_schema is mutated. - No ref resolving is involved here, as that's not appropriate for simple updates. 
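The class-level updates above (`title`, `extra`, `json_schema_extra`) are driven by `model_config`. A minimal sketch with an invented model:

```python
from pydantic import BaseModel, ConfigDict

class Payload(BaseModel):
    model_config = ConfigDict(
        title='Inbound payload',
        extra='forbid',                                    # -> "additionalProperties": false
        json_schema_extra={'examples': [{'kind': 'demo'}]},
    )

    kind: str

print(Payload.model_json_schema())
```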
- """ - from .main import BaseModel - from .root_model import RootModel - - if (config_title := config.get('title')) is not None: - json_schema.setdefault('title', config_title) - elif model_title_generator := config.get('model_title_generator'): - title = model_title_generator(cls) - if not isinstance(title, str): - raise TypeError(f'model_title_generator {model_title_generator} must return str, not {title.__class__}') - json_schema.setdefault('title', title) - if 'title' not in json_schema: - json_schema['title'] = cls.__name__ - - # BaseModel and dataclasses; don't use cls.__doc__ as it will contain the verbose class signature by default - docstring = None if cls is BaseModel or dataclasses.is_dataclass(cls) else cls.__doc__ - - if docstring: - json_schema.setdefault('description', inspect.cleandoc(docstring)) - elif issubclass(cls, RootModel) and (root_description := cls.__pydantic_fields__['root'].description): - json_schema.setdefault('description', root_description) - - extra = config.get('extra') - if 'additionalProperties' not in json_schema: - if extra == 'allow': - json_schema['additionalProperties'] = True - elif extra == 'forbid': - json_schema['additionalProperties'] = False - json_schema_extra = config.get('json_schema_extra') - if issubclass(cls, BaseModel) and cls.__pydantic_root_model__: + if cls.__pydantic_root_model__: root_json_schema_extra = cls.model_fields['root'].json_schema_extra if json_schema_extra and root_json_schema_extra: raise ValueError( @@ -1542,27 +1328,52 @@ class GenerateJsonSchema: if root_json_schema_extra: json_schema_extra = root_json_schema_extra + json_schema = self._update_class_schema(json_schema, title, config.get('extra', None), cls, json_schema_extra) + + return json_schema + + def _update_class_schema( + self, + json_schema: JsonSchemaValue, + title: str | None, + extra: Literal['allow', 'ignore', 'forbid'] | None, + cls: type[Any], + json_schema_extra: dict[str, Any] | JsonSchemaExtraCallable | None, + ) -> JsonSchemaValue: + if '$ref' in json_schema: + schema_to_update = self.get_schema_from_definitions(JsonRef(json_schema['$ref'])) or json_schema + else: + schema_to_update = json_schema + + if title is not None: + # referenced_schema['title'] = title + schema_to_update.setdefault('title', title) + + if 'additionalProperties' not in schema_to_update: + if extra == 'allow': + schema_to_update['additionalProperties'] = True + elif extra == 'forbid': + schema_to_update['additionalProperties'] = False + if isinstance(json_schema_extra, (staticmethod, classmethod)): # In older versions of python, this is necessary to ensure staticmethod/classmethods are callable json_schema_extra = json_schema_extra.__get__(cls) if isinstance(json_schema_extra, dict): - json_schema.update(json_schema_extra) + schema_to_update.update(json_schema_extra) elif callable(json_schema_extra): - # FIXME: why are there type ignores here? We support two signatures for json_schema_extra callables... 
if len(inspect.signature(json_schema_extra).parameters) > 1: - json_schema_extra(json_schema, cls) # type: ignore + json_schema_extra(schema_to_update, cls) # type: ignore else: - json_schema_extra(json_schema) # type: ignore + json_schema_extra(schema_to_update) # type: ignore elif json_schema_extra is not None: raise ValueError( f"model_config['json_schema_extra']={json_schema_extra} should be a dict, callable, or None" ) - if hasattr(cls, '__deprecated__'): - json_schema['deprecated'] = True + return json_schema - def resolve_ref_schema(self, json_schema: JsonSchemaValue) -> JsonSchemaValue: + def resolve_schema_to_update(self, json_schema: JsonSchemaValue) -> JsonSchemaValue: """Resolve a JsonSchemaValue to the non-ref schema if it is a $ref schema. Args: @@ -1570,17 +1381,15 @@ class GenerateJsonSchema: Returns: The resolved schema. - - Raises: - RuntimeError: If the schema reference can't be found in definitions. """ - while '$ref' in json_schema: - ref = json_schema['$ref'] - schema_to_update = self.get_schema_from_definitions(JsonRef(ref)) + if '$ref' in json_schema: + schema_to_update = self.get_schema_from_definitions(JsonRef(json_schema['$ref'])) if schema_to_update is None: - raise RuntimeError(f'Cannot update undefined schema for $ref={ref}') - json_schema = schema_to_update - return json_schema + raise RuntimeError(f'Cannot update undefined schema for $ref={json_schema["$ref"]}') + return self.resolve_schema_to_update(schema_to_update) + else: + schema_to_update = json_schema + return schema_to_update def model_fields_schema(self, schema: core_schema.ModelFieldsSchema) -> JsonSchemaValue: """Generates a JSON schema that matches a schema that defines a model's fields. @@ -1601,7 +1410,7 @@ class GenerateJsonSchema: json_schema = self._named_required_fields_schema(named_required_fields) extras_schema = schema.get('extras_schema', None) if extras_schema is not None: - schema_to_update = self.resolve_ref_schema(json_schema) + schema_to_update = self.resolve_schema_to_update(json_schema) schema_to_update['additionalProperties'] = self.generate_inner(extras_schema) return json_schema @@ -1675,18 +1484,18 @@ class GenerateJsonSchema: Returns: The generated JSON schema. """ - from ._internal._dataclasses import is_builtin_dataclass - cls = schema['cls'] config: ConfigDict = getattr(cls, '__pydantic_config__', cast('ConfigDict', {})) + title = config.get('title') or cls.__name__ with self._config_wrapper_stack.push(config): json_schema = self.generate_inner(schema['schema']).copy() - self._update_class_schema(json_schema, cls, config) + json_schema_extra = config.get('json_schema_extra') + json_schema = self._update_class_schema(json_schema, title, config.get('extra', None), cls, json_schema_extra) # Dataclass-specific handling of description - if is_builtin_dataclass(cls): + if is_dataclass(cls) and not hasattr(cls, '__pydantic_validator__'): # vanilla dataclass; don't use cls.__doc__ as it will contain the class signature by default description = None else: @@ -1705,7 +1514,8 @@ class GenerateJsonSchema: Returns: The generated JSON schema. 
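The dataclass handling above distinguishes vanilla dataclasses from pydantic ones when choosing a description. A small sketch using a pydantic dataclass (invented names), where the class docstring becomes the schema description and the config supplies the title:

```python
import pydantic

@pydantic.dataclasses.dataclass(config=pydantic.ConfigDict(title='2D point'))
class Point:
    """A point on the plane."""
    x: int
    y: int

print(pydantic.TypeAdapter(Point).json_schema())
```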
""" - prefer_positional = schema.get('metadata', {}).get('pydantic_js_prefer_positional_arguments') + metadata = _core_metadata.CoreMetadataHandler(schema).metadata + prefer_positional = metadata.get('pydantic_js_prefer_positional_arguments') arguments = schema['arguments_schema'] kw_only_arguments = [a for a in arguments if a.get('mode') == 'keyword_only'] @@ -1728,7 +1538,9 @@ class GenerateJsonSchema: if positional_possible: return self.p_arguments_schema(p_only_arguments + kw_or_p_arguments, var_args_schema) - raise PydanticInvalidForJsonSchema( + # TODO: When support for Python 3.7 is dropped, uncomment the block on `test_json_schema` + # to cover this test case. + raise PydanticInvalidForJsonSchema( # pragma: no cover 'Unable to generate JSON schema for arguments validator with positional-only and keyword-only arguments' ) @@ -1796,9 +1608,7 @@ class GenerateJsonSchema: # I believe this is true, but I am not 100% sure min_items += 1 - json_schema: JsonSchemaValue = {'type': 'array'} - if prefix_items: - json_schema['prefixItems'] = prefix_items + json_schema: JsonSchemaValue = {'type': 'array', 'prefixItems': prefix_items} if min_items: json_schema['minItems'] = min_items @@ -1811,7 +1621,7 @@ class GenerateJsonSchema: return json_schema - def get_argument_name(self, argument: core_schema.ArgumentsParameter | core_schema.ArgumentsV3Parameter) -> str: + def get_argument_name(self, argument: core_schema.ArgumentsParameter) -> str: """Retrieves the name of an argument. Args: @@ -1829,45 +1639,6 @@ class GenerateJsonSchema: pass # might want to do something else? return name - def arguments_v3_schema(self, schema: core_schema.ArgumentsV3Schema) -> JsonSchemaValue: - """Generates a JSON schema that matches a schema that defines a function's arguments. - - Args: - schema: The core schema. - - Returns: - The generated JSON schema. - """ - arguments = schema['arguments_schema'] - properties: dict[str, JsonSchemaValue] = {} - required: list[str] = [] - for argument in arguments: - mode = argument.get('mode', 'positional_or_keyword') - name = self.get_argument_name(argument) - argument_schema = self.generate_inner(argument['schema']).copy() - if mode == 'var_args': - argument_schema = {'type': 'array', 'items': argument_schema} - elif mode == 'var_kwargs_uniform': - argument_schema = {'type': 'object', 'additionalProperties': argument_schema} - - argument_schema.setdefault('title', self.get_title_from_name(name)) - properties[name] = argument_schema - - if ( - (mode == 'var_kwargs_unpacked_typed_dict' and 'required' in argument_schema) - or mode not in {'var_args', 'var_kwargs_uniform', 'var_kwargs_unpacked_typed_dict'} - and argument['schema']['type'] != 'default' - ): - # This assumes that if the argument has a default value, - # the inner schema must be of type WithDefaultSchema. - # I believe this is true, but I am not 100% sure - required.append(name) - - json_schema: JsonSchemaValue = {'type': 'object', 'properties': properties} - if required: - json_schema['required'] = required - return json_schema - def call_schema(self, schema: core_schema.CallSchema) -> JsonSchemaValue: """Generates a JSON schema that matches a schema that defines a function call. @@ -2001,22 +1772,6 @@ class GenerateJsonSchema: return self.generate_inner(schema['schema']) return None - def complex_schema(self, schema: core_schema.ComplexSchema) -> JsonSchemaValue: - """Generates a JSON schema that matches a complex number. - - JSON has no standard way to represent complex numbers. 
Complex number is not a numeric - type. Here we represent complex number as strings following the rule defined by Python. - For instance, '1+2j' is an accepted complex string. Details can be found in - [Python's `complex` documentation][complex]. - - Args: - schema: The core schema. - - Returns: - The generated JSON schema. - """ - return {'type': 'string'} - # ### Utility methods def get_title_from_name(self, name: str) -> str: @@ -2028,7 +1783,7 @@ class GenerateJsonSchema: Returns: The title. """ - return name.title().replace('_', ' ').strip() + return name.title().replace('_', ' ') def field_title_should_be_set(self, schema: CoreSchemaOrField) -> bool: """Returns true if a field with the given schema should have a title set based on the field name. @@ -2089,7 +1844,7 @@ class GenerateJsonSchema: core_ref, mode = core_mode_ref components = re.split(r'([\][,])', core_ref) # Remove IDs from each component - components = [x.rsplit(':', 1)[0] for x in components] + components = [x.split(':')[0] for x in components] core_ref_no_id = ''.join(components) # Remove everything before the last period from each "component" components = [re.sub(r'(?:[^.[\]]+\.)+((?:[^.[\]]+))', r'\1', x) for x in components] @@ -2153,13 +1908,14 @@ class GenerateJsonSchema: return defs_ref, ref_json_schema def handle_ref_overrides(self, json_schema: JsonSchemaValue) -> JsonSchemaValue: - """Remove any sibling keys that are redundant with the referenced schema. + """It is not valid for a schema with a top-level $ref to have sibling keys. - Args: - json_schema: The schema to remove redundant sibling keys from. + During our own schema generation, we treat sibling keys as overrides to the referenced schema, + but this is not how the official JSON schema spec works. - Returns: - The schema with redundant sibling keys removed. + Because of this, we first remove any sibling keys that are redundant with the referenced schema, then if + any remain, we transform the schema from a top-level '$ref' to use allOf to move the $ref out of the top level. + (See bottom of https://swagger.io/docs/specification/using-ref/ for a reference about this behavior) """ if '$ref' in json_schema: # prevent modifications to the input; this copy may be safe to drop if there is significant overhead @@ -2170,25 +1926,33 @@ class GenerateJsonSchema: # This can happen when building schemas for models with not-yet-defined references. # It may be a good idea to do a recursive pass at the end of the generation to remove # any redundant override keys. 
+ if len(json_schema) > 1: + # Make it an allOf to at least resolve the sibling keys issue + json_schema = json_schema.copy() + json_schema.setdefault('allOf', []) + json_schema['allOf'].append({'$ref': json_schema['$ref']}) + del json_schema['$ref'] + return json_schema for k, v in list(json_schema.items()): if k == '$ref': continue if k in referenced_json_schema and referenced_json_schema[k] == v: del json_schema[k] # redundant key + if len(json_schema) > 1: + # There is a remaining "override" key, so we need to move $ref out of the top level + json_ref = JsonRef(json_schema['$ref']) + del json_schema['$ref'] + assert 'allOf' not in json_schema # this should never happen, but just in case + json_schema['allOf'] = [{'$ref': json_ref}] return json_schema def get_schema_from_definitions(self, json_ref: JsonRef) -> JsonSchemaValue | None: - try: - def_ref = self.json_to_defs_refs[json_ref] - if def_ref in self._core_defs_invalid_for_json_schema: - raise self._core_defs_invalid_for_json_schema[def_ref] - return self.definitions.get(def_ref, None) - except KeyError: - if json_ref.startswith(('http://', 'https://')): - return None - raise + def_ref = self.json_to_defs_refs[json_ref] + if def_ref in self._core_defs_invalid_for_json_schema: + raise self._core_defs_invalid_for_json_schema[def_ref] + return self.definitions.get(def_ref, None) def encode_default(self, dft: Any) -> Any: """Encode a default value to a JSON-serializable value. @@ -2201,22 +1965,11 @@ class GenerateJsonSchema: Returns: The encoded default value. """ - from .type_adapter import TypeAdapter, _type_has_config - config = self._config - try: - default = ( - dft - if _type_has_config(type(dft)) - else TypeAdapter(type(dft), config=config.config_dict).dump_python( - dft, by_alias=self.by_alias, mode='json' - ) - ) - except PydanticSchemaGenerationError: - raise pydantic_core.PydanticSerializationError(f'Unable to encode default value {dft}') - return pydantic_core.to_jsonable_python( - default, timedelta_mode=config.ser_json_timedelta, bytes_mode=config.ser_json_bytes, by_alias=self.by_alias + dft, + timedelta_mode=config.ser_json_timedelta, + bytes_mode=config.ser_json_bytes, ) def update_with_validations( @@ -2265,6 +2018,12 @@ class GenerateJsonSchema: 'min_length': 'minProperties', 'max_length': 'maxProperties', } + date = { + 'le': 'maximum', + 'ge': 'minimum', + 'lt': 'exclusiveMaximum', + 'gt': 'exclusiveMinimum', + } def get_flattened_anyof(self, schemas: list[JsonSchemaValue]) -> JsonSchemaValue: members = [] @@ -2292,20 +2051,12 @@ class GenerateJsonSchema: json_refs[json_ref] += 1 if already_visited: return # prevent recursion on a definition that was already visited - try: - defs_ref = self.json_to_defs_refs[json_ref] - if defs_ref in self._core_defs_invalid_for_json_schema: - raise self._core_defs_invalid_for_json_schema[defs_ref] - _add_json_refs(self.definitions[defs_ref]) - except KeyError: - if not json_ref.startswith(('http://', 'https://')): - raise + defs_ref = self.json_to_defs_refs[json_ref] + if defs_ref in self._core_defs_invalid_for_json_schema: + raise self._core_defs_invalid_for_json_schema[defs_ref] + _add_json_refs(self.definitions[defs_ref]) - for k, v in schema.items(): - if k == 'examples' and isinstance(v, list): - # Skip examples that may contain arbitrary values and references - # (see the comment in `_get_all_json_refs` for more details). 
- continue + for v in schema.values(): _add_json_refs(v) elif isinstance(schema, list): for v in schema: @@ -2360,15 +2111,11 @@ class GenerateJsonSchema: unvisited_json_refs = _get_all_json_refs(schema) while unvisited_json_refs: next_json_ref = unvisited_json_refs.pop() - try: - next_defs_ref = self.json_to_defs_refs[next_json_ref] - if next_defs_ref in visited_defs_refs: - continue - visited_defs_refs.add(next_defs_ref) - unvisited_json_refs.update(_get_all_json_refs(self.definitions[next_defs_ref])) - except KeyError: - if not next_json_ref.startswith(('http://', 'https://')): - raise + next_defs_ref = self.json_to_defs_refs[next_json_ref] + if next_defs_ref in visited_defs_refs: + continue + visited_defs_refs.add(next_defs_ref) + unvisited_json_refs.update(_get_all_json_refs(self.definitions[next_defs_ref])) self.definitions = {k: v for k, v in self.definitions.items() if k in visited_defs_refs} @@ -2399,17 +2146,10 @@ def model_json_schema( Returns: The generated JSON Schema. """ - from .main import BaseModel - schema_generator_instance = schema_generator(by_alias=by_alias, ref_template=ref_template) - - if isinstance(cls.__pydantic_core_schema__, _mock_val_ser.MockCoreSchema): - cls.__pydantic_core_schema__.rebuild() - - if cls is BaseModel: - raise AttributeError('model_json_schema() must be called on a subclass of BaseModel, not BaseModel itself.') - - assert not isinstance(cls.__pydantic_core_schema__, _mock_val_ser.MockCoreSchema), 'this is a bug! please report it' + if isinstance(cls.__pydantic_validator__, _mock_val_ser.MockValSer): + cls.__pydantic_validator__.rebuild() + assert '__pydantic_core_schema__' in cls.__dict__, 'this is a bug! please report it' return schema_generator_instance.generate(cls.__pydantic_core_schema__, mode=mode) @@ -2441,13 +2181,11 @@ def models_json_schema( element, along with the optional title and description keys. """ for cls, _ in models: - if isinstance(cls.__pydantic_core_schema__, _mock_val_ser.MockCoreSchema): - cls.__pydantic_core_schema__.rebuild() + if isinstance(cls.__pydantic_validator__, _mock_val_ser.MockValSer): + cls.__pydantic_validator__.rebuild() instance = schema_generator(by_alias=by_alias, ref_template=ref_template) - inputs: list[tuple[type[BaseModel] | type[PydanticDataclass], JsonSchemaMode, CoreSchema]] = [ - (m, mode, m.__pydantic_core_schema__) for m, mode in models - ] + inputs = [(m, mode, m.__pydantic_core_schema__) for m, mode in models] json_schemas_map, definitions = instance.generate_definitions(inputs) json_schema: dict[str, Any] = {} @@ -2464,16 +2202,16 @@ def models_json_schema( # ##### End JSON Schema Generation Functions ##### -_HashableJsonValue: TypeAlias = Union[ - int, float, str, bool, None, tuple['_HashableJsonValue', ...], tuple[tuple[str, '_HashableJsonValue'], ...] 
-] +_Json = Union[Dict[str, Any], List[Any], str, int, float, bool, None] +_JsonDict = Dict[str, _Json] +_HashableJson = Union[Tuple[Tuple[str, Any], ...], Tuple[Any, ...], str, int, float, bool, None] -def _deduplicate_schemas(schemas: Iterable[JsonDict]) -> list[JsonDict]: +def _deduplicate_schemas(schemas: Iterable[_JsonDict]) -> list[_JsonDict]: return list({_make_json_hashable(schema): schema for schema in schemas}.values()) -def _make_json_hashable(value: JsonValue) -> _HashableJsonValue: +def _make_json_hashable(value: _Json) -> _HashableJson: if isinstance(value, dict): return tuple(sorted((k, _make_json_hashable(v)) for k, v in value.items())) elif isinstance(value, list): @@ -2482,12 +2220,27 @@ def _make_json_hashable(value: JsonValue) -> _HashableJsonValue: return value +def _sort_json_schema(value: JsonSchemaValue, parent_key: str | None = None) -> JsonSchemaValue: + if isinstance(value, dict): + sorted_dict: dict[str, JsonSchemaValue] = {} + keys = value.keys() + if parent_key != 'properties': + keys = sorted(keys) + for key in keys: + sorted_dict[key] = _sort_json_schema(value[key], parent_key=key) + return sorted_dict + elif isinstance(value, list): + sorted_list: list[JsonSchemaValue] = [] + for item in value: # type: ignore + sorted_list.append(_sort_json_schema(item)) + return sorted_list # type: ignore + else: + return value + + @dataclasses.dataclass(**_internal_dataclass.slots_true) class WithJsonSchema: - """!!! abstract "Usage Documentation" - [`WithJsonSchema` Annotation](../concepts/json_schema.md#withjsonschema-annotation) - - Add this as an annotation on a field to override the (base) JSON schema that would be generated for that field. + """Add this as an annotation on a field to override the (base) JSON schema that would be generated for that field. This provides a way to set a JSON schema for types that would otherwise raise errors when producing a JSON schema, such as Callable, or types that have an is-instance core schema, without needing to go so far as creating a custom subclass of pydantic.json_schema.GenerateJsonSchema. @@ -2511,42 +2264,25 @@ class WithJsonSchema: # This exception is handled in pydantic.json_schema.GenerateJsonSchema._named_required_fields_schema raise PydanticOmit else: - return self.json_schema.copy() + return self.json_schema def __hash__(self) -> int: return hash(type(self.mode)) +@dataclasses.dataclass(**_internal_dataclass.slots_true) class Examples: """Add examples to a JSON schema. - If the JSON Schema already contains examples, the provided examples - will be appended. + Examples should be a map of example names (strings) + to example values (any valid JSON). If `mode` is set this will only apply to that schema generation mode, allowing you to add different examples for validation and serialization. """ - @overload - @deprecated('Using a dict for `examples` is deprecated since v2.9 and will be removed in v3.0. Use a list instead.') - def __init__( - self, examples: dict[str, Any], mode: Literal['validation', 'serialization'] | None = None - ) -> None: ... - - @overload - def __init__(self, examples: list[Any], mode: Literal['validation', 'serialization'] | None = None) -> None: ... 
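# --- Illustrative sketch (editor's note, not part of the diff) -----------------------
# Using the WithJsonSchema annotation documented above to supply a schema for a type
# that cannot generate one itself (Callable). A minimal, hedged example; extra keys in
# the output (e.g. a field title) can differ between pydantic versions.
from typing import Annotated, Callable
from pydantic import BaseModel
from pydantic.json_schema import WithJsonSchema

class Job(BaseModel):
    # Without the annotation, JSON schema generation for a Callable field raises.
    handler: Annotated[Callable[[int], int], WithJsonSchema({'type': 'string'})]

print(Job.model_json_schema()['properties']['handler'])
# --------------------------------------------------------------------------------------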
- - def __init__( - self, examples: dict[str, Any] | list[Any], mode: Literal['validation', 'serialization'] | None = None - ) -> None: - if isinstance(examples, dict): - warnings.warn( - 'Using a dict for `examples` is deprecated, use a list instead.', - PydanticDeprecatedSince29, - stacklevel=2, - ) - self.examples = examples - self.mode = mode + examples: dict[str, Any] + mode: Literal['validation', 'serialization'] | None = None def __get_pydantic_json_schema__( self, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler @@ -2555,36 +2291,9 @@ class Examples: json_schema = handler(core_schema) if mode != handler.mode: return json_schema - examples = json_schema.get('examples') - if examples is None: - json_schema['examples'] = to_jsonable_python(self.examples) - if isinstance(examples, dict): - if isinstance(self.examples, list): - warnings.warn( - 'Updating existing JSON Schema examples of type dict with examples of type list. ' - 'Only the existing examples values will be retained. Note that dict support for ' - 'examples is deprecated and will be removed in v3.0.', - UserWarning, - ) - json_schema['examples'] = to_jsonable_python( - [ex for value in examples.values() for ex in value] + self.examples - ) - else: - json_schema['examples'] = to_jsonable_python({**examples, **self.examples}) - if isinstance(examples, list): - if isinstance(self.examples, list): - json_schema['examples'] = to_jsonable_python(examples + self.examples) - elif isinstance(self.examples, dict): - warnings.warn( - 'Updating existing JSON Schema examples of type list with examples of type dict. ' - 'Only the examples values will be retained. Note that dict support for ' - 'examples is deprecated and will be removed in v3.0.', - UserWarning, - ) - json_schema['examples'] = to_jsonable_python( - examples + [ex for value in self.examples.values() for ex in value] - ) - + examples = json_schema.get('examples', {}) + examples.update(to_jsonable_python(self.examples)) + json_schema['examples'] = examples return json_schema def __hash__(self) -> int: @@ -2594,28 +2303,19 @@ class Examples: def _get_all_json_refs(item: Any) -> set[JsonRef]: """Get all the definitions references from a JSON schema.""" refs: set[JsonRef] = set() - stack = [item] - - while stack: - current = stack.pop() - if isinstance(current, dict): - for key, value in current.items(): - if key == 'examples' and isinstance(value, list): - # Skip examples that may contain arbitrary values and references - # (e.g. `{"examples": [{"$ref": "..."}]}`). Note: checking for value - # of type list is necessary to avoid skipping valid portions of the schema, - # for instance when "examples" is used as a property key. A more robust solution - # could be found, but would require more advanced JSON Schema parsing logic. - continue - if key == '$ref' and isinstance(value, str): - refs.add(JsonRef(value)) - elif isinstance(value, dict): - stack.append(value) - elif isinstance(value, list): - stack.extend(value) - elif isinstance(current, list): - stack.extend(current) - + if isinstance(item, dict): + for key, value in item.items(): + if key == '$ref' and isinstance(value, str): + # the isinstance check ensures that '$ref' isn't the name of a property, etc. 
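# --- Illustrative sketch (editor's note, not part of the diff) -----------------------
# Attaching examples to a field with the Examples annotation defined above. The version
# added by this diff takes a mapping of example names to values; newer pydantic releases
# deprecate the dict form in favour of a plain list, so this runs on both but may emit a
# deprecation warning on newer installs.
from typing import Annotated
from pydantic import BaseModel
from pydantic.json_schema import Examples

class Payment(BaseModel):
    amount: Annotated[float, Examples({'small': 9.99, 'large': 1000.0})]

# The field schema gains an 'examples' entry merged in by __get_pydantic_json_schema__.
print(Payment.model_json_schema()['properties']['amount'])
# --------------------------------------------------------------------------------------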
+ refs.add(JsonRef(value)) + elif isinstance(value, dict): + refs.update(_get_all_json_refs(value)) + elif isinstance(value, list): + for item in value: + refs.update(_get_all_json_refs(item)) + elif isinstance(item, list): + for item in item: + refs.update(_get_all_json_refs(item)) return refs @@ -2627,51 +2327,20 @@ else: @dataclasses.dataclass(**_internal_dataclass.slots_true) class SkipJsonSchema: - """!!! abstract "Usage Documentation" - [`SkipJsonSchema` Annotation](../concepts/json_schema.md#skipjsonschema-annotation) - - Add this as an annotation on a field to skip generating a JSON schema for that field. + """Add this as an annotation on a field to skip generating a JSON schema for that field. Example: - ```python - from pprint import pprint - from typing import Union - + ```py from pydantic import BaseModel from pydantic.json_schema import SkipJsonSchema class Model(BaseModel): - a: Union[int, None] = None # (1)! - b: Union[int, SkipJsonSchema[None]] = None # (2)! - c: SkipJsonSchema[Union[int, None]] = None # (3)! + a: int | SkipJsonSchema[None] = None - pprint(Model.model_json_schema()) - ''' - { - 'properties': { - 'a': { - 'anyOf': [ - {'type': 'integer'}, - {'type': 'null'} - ], - 'default': None, - 'title': 'A' - }, - 'b': { - 'default': None, - 'title': 'B', - 'type': 'integer' - } - }, - 'title': 'Model', - 'type': 'object' - } - ''' + + print(Model.model_json_schema()) + #> {'properties': {'a': {'default': None, 'title': 'A', 'type': 'integer'}}, 'title': 'Model', 'type': 'object'} ``` - - 1. The integer and null types are both included in the schema for `a`. - 2. The integer type is the only type included in the schema for `b`. - 3. The entirety of the `c` field is omitted from the schema. """ def __class_getitem__(cls, item: AnyType) -> AnyType: @@ -2686,7 +2355,9 @@ else: return hash(type(self)) -def _get_typed_dict_config(cls: type[Any] | None) -> ConfigDict: +def _get_typed_dict_config(schema: core_schema.TypedDictSchema) -> ConfigDict: + metadata = _core_metadata.CoreMetadataHandler(schema).metadata + cls = metadata.get('pydantic_typed_dict_cls') if cls is not None: try: return _decorators.get_attribute_from_bases(cls, '__pydantic_config__') diff --git a/venv/lib/python3.12/site-packages/pydantic/main.py b/venv/lib/python3.12/site-packages/pydantic/main.py index 272c10a..355a912 100644 --- a/venv/lib/python3.12/site-packages/pydantic/main.py +++ b/venv/lib/python3.12/site-packages/pydantic/main.py @@ -1,38 +1,16 @@ """Logic for creating models.""" - -# Because `dict` is in the local namespace of the `BaseModel` class, we use `Dict` for annotations. -# TODO v3 fallback to `dict` when the deprecated `dict` method gets removed. -# ruff: noqa: UP035 - from __future__ import annotations as _annotations -import operator -import sys import types import typing import warnings -from collections.abc import Generator, Mapping from copy import copy, deepcopy -from functools import cached_property -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Dict, - Literal, - TypeVar, - Union, - cast, - overload, -) +from typing import Any, ClassVar import pydantic_core import typing_extensions -from pydantic_core import PydanticUndefined, ValidationError -from typing_extensions import Self, TypeAlias, Unpack +from pydantic_core import PydanticUndefined -from . 
import PydanticDeprecatedSince20, PydanticDeprecatedSince211 from ._internal import ( _config, _decorators, @@ -41,29 +19,34 @@ from ._internal import ( _generics, _mock_val_ser, _model_construction, - _namespace_utils, _repr, _typing_extra, _utils, ) from ._migration import getattr_migration -from .aliases import AliasChoices, AliasPath from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler from .config import ConfigDict from .errors import PydanticUndefinedAnnotation, PydanticUserError +from .fields import ComputedFieldInfo, FieldInfo, ModelPrivateAttr from .json_schema import DEFAULT_REF_TEMPLATE, GenerateJsonSchema, JsonSchemaMode, JsonSchemaValue, model_json_schema -from .plugin._schema_validator import PluggableSchemaValidator +from .warnings import PydanticDeprecatedSince20 -if TYPE_CHECKING: +if typing.TYPE_CHECKING: from inspect import Signature from pathlib import Path from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator + from typing_extensions import Literal, Unpack - from ._internal._namespace_utils import MappingNamespace from ._internal._utils import AbstractSetIntStr, MappingIntStrAny from .deprecated.parse import Protocol as DeprecatedParseProtocol - from .fields import ComputedFieldInfo, FieldInfo, ModelPrivateAttr + from .fields import Field as _Field + + AnyClassMethod = classmethod[Any, Any, Any] + TupleGenerator = typing.Generator[typing.Tuple[str, Any], None, None] + Model = typing.TypeVar('Model', bound='BaseModel') + # should be `set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None`, but mypy can't cope + IncEx: typing_extensions.TypeAlias = 'set[int] | set[str] | dict[int, Any] | dict[str, Any] | None' else: # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915 # and https://youtrack.jetbrains.com/issue/PY-51428 @@ -71,66 +54,21 @@ else: __all__ = 'BaseModel', 'create_model' -# Keep these type aliases available at runtime: -TupleGenerator: TypeAlias = Generator[tuple[str, Any], None, None] -# NOTE: In reality, `bool` should be replaced by `Literal[True]` but mypy fails to correctly apply bidirectional -# type inference (e.g. when using `{'a': {'b': True}}`): -# NOTE: Keep this type alias in sync with the stub definition in `pydantic-core`: -IncEx: TypeAlias = Union[set[int], set[str], Mapping[int, Union['IncEx', bool]], Mapping[str, Union['IncEx', bool]]] - _object_setattr = _model_construction.object_setattr -def _check_frozen(model_cls: type[BaseModel], name: str, value: Any) -> None: - if model_cls.model_config.get('frozen'): - error_type = 'frozen_instance' - elif getattr(model_cls.__pydantic_fields__.get(name), 'frozen', False): - error_type = 'frozen_field' - else: - return - - raise ValidationError.from_exception_data( - model_cls.__name__, [{'type': error_type, 'loc': (name,), 'input': value}] - ) - - -def _model_field_setattr_handler(model: BaseModel, name: str, val: Any) -> None: - model.__dict__[name] = val - model.__pydantic_fields_set__.add(name) - - -def _private_setattr_handler(model: BaseModel, name: str, val: Any) -> None: - if getattr(model, '__pydantic_private__', None) is None: - # While the attribute should be present at this point, this may not be the case if - # users do unusual stuff with `model_post_init()` (which is where the `__pydantic_private__` - # is initialized, by wrapping the user-defined `model_post_init()`), e.g. if they mock - # the `model_post_init()` call. Ideally we should find a better way to init private attrs. 
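# --- Illustrative sketch (editor's note, not part of the diff) -----------------------
# The frozen checks in this hunk (whether via the removed _check_frozen helper or the
# inline logic in __setattr__): assigning to a field of a model configured with
# frozen=True raises a ValidationError of type 'frozen_instance', while a per-field
# frozen=True yields 'frozen_field'.
from pydantic import BaseModel, ConfigDict, Field, ValidationError

class Settings(BaseModel):
    model_config = ConfigDict(frozen=True)
    host: str = 'localhost'

class User(BaseModel):
    id: int = Field(frozen=True)
    name: str = 'anon'

try:
    Settings().host = 'example.com'
except ValidationError as err:
    print(err.errors()[0]['type'])  # frozen_instance

try:
    User(id=1).id = 2
except ValidationError as err:
    print(err.errors()[0]['type'])  # frozen_field
# --------------------------------------------------------------------------------------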
- object.__setattr__(model, '__pydantic_private__', {}) - model.__pydantic_private__[name] = val # pyright: ignore[reportOptionalSubscript] - - -_SIMPLE_SETATTR_HANDLERS: Mapping[str, Callable[[BaseModel, str, Any], None]] = { - 'model_field': _model_field_setattr_handler, - 'validate_assignment': lambda model, name, val: model.__pydantic_validator__.validate_assignment(model, name, val), # pyright: ignore[reportAssignmentType] - 'private': _private_setattr_handler, - 'cached_property': lambda model, name, val: model.__dict__.__setitem__(name, val), - 'extra_known': lambda model, name, val: _object_setattr(model, name, val), -} - - class BaseModel(metaclass=_model_construction.ModelMetaclass): - """!!! abstract "Usage Documentation" - [Models](../concepts/models.md) + """Usage docs: https://docs.pydantic.dev/2.4/concepts/models/ A base class for creating Pydantic models. Attributes: - __class_vars__: The names of the class variables defined on the model. + __class_vars__: The names of classvars defined on the model. __private_attributes__: Metadata about the private attributes of the model. - __signature__: The synthesized `__init__` [`Signature`][inspect.Signature] of the model. + __signature__: The signature for instantiating the model. __pydantic_complete__: Whether model building is completed, or if there are still undefined fields. - __pydantic_core_schema__: The core schema of the model. + __pydantic_core_schema__: The pydantic-core schema used to build the SchemaValidator and SchemaSerializer. __pydantic_custom_init__: Whether the model has a custom `__init__` function. __pydantic_decorators__: Metadata containing the decorators defined on the model. This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1. @@ -138,95 +76,63 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these. __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models. __pydantic_post_init__: The name of the post-init method for the model, if defined. - __pydantic_root_model__: Whether the model is a [`RootModel`][pydantic.root_model.RootModel]. - __pydantic_serializer__: The `pydantic-core` `SchemaSerializer` used to dump instances of the model. - __pydantic_validator__: The `pydantic-core` `SchemaValidator` used to validate instances of the model. + __pydantic_root_model__: Whether the model is a `RootModel`. + __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model. + __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model. - __pydantic_fields__: A dictionary of field names and their corresponding [`FieldInfo`][pydantic.fields.FieldInfo] objects. - __pydantic_computed_fields__: A dictionary of computed field names and their corresponding [`ComputedFieldInfo`][pydantic.fields.ComputedFieldInfo] objects. - - __pydantic_extra__: A dictionary containing extra values, if [`extra`][pydantic.config.ConfigDict.extra] - is set to `'allow'`. - __pydantic_fields_set__: The names of fields explicitly set during instantiation. - __pydantic_private__: Values of private attributes set on the model instance. + __pydantic_extra__: An instance attribute with the values of extra fields from validation when + `model_config['extra'] == 'allow'`. + __pydantic_fields_set__: An instance attribute with the names of fields explicitly specified during validation. 
+ __pydantic_private__: Instance attribute with the values of private attributes set on the model instance. """ - # Note: Many of the below class vars are defined in the metaclass, but we define them here for type checking purposes. + if typing.TYPE_CHECKING: + # Here we provide annotations for the attributes of BaseModel. + # Many of these are populated by the metaclass, which is why this section is in a `TYPE_CHECKING` block. + # However, for the sake of easy review, we have included type annotations of all class and instance attributes + # of `BaseModel` here: - model_config: ClassVar[ConfigDict] = ConfigDict() - """ - Configuration for the model, should be a dictionary conforming to [`ConfigDict`][pydantic.config.ConfigDict]. - """ + # Class attributes + model_config: ClassVar[ConfigDict] + """ + Configuration for the model, should be a dictionary conforming to [`ConfigDict`][pydantic.config.ConfigDict]. + """ - __class_vars__: ClassVar[set[str]] - """The names of the class variables defined on the model.""" + model_fields: ClassVar[dict[str, FieldInfo]] + """ + Metadata about the fields defined on the model, + mapping of field names to [`FieldInfo`][pydantic.fields.FieldInfo]. - __private_attributes__: ClassVar[Dict[str, ModelPrivateAttr]] # noqa: UP006 - """Metadata about the private attributes of the model.""" + This replaces `Model.__fields__` from Pydantic V1. + """ - __signature__: ClassVar[Signature] - """The synthesized `__init__` [`Signature`][inspect.Signature] of the model.""" + __class_vars__: ClassVar[set[str]] + __private_attributes__: ClassVar[dict[str, ModelPrivateAttr]] + __signature__: ClassVar[Signature] - __pydantic_complete__: ClassVar[bool] = False - """Whether model building is completed, or if there are still undefined fields.""" + __pydantic_complete__: ClassVar[bool] + __pydantic_core_schema__: ClassVar[CoreSchema] + __pydantic_custom_init__: ClassVar[bool] + __pydantic_decorators__: ClassVar[_decorators.DecoratorInfos] + __pydantic_generic_metadata__: ClassVar[_generics.PydanticGenericMetadata] + __pydantic_parent_namespace__: ClassVar[dict[str, Any] | None] + __pydantic_post_init__: ClassVar[None | Literal['model_post_init']] + __pydantic_root_model__: ClassVar[bool] + __pydantic_serializer__: ClassVar[SchemaSerializer] + __pydantic_validator__: ClassVar[SchemaValidator] - __pydantic_core_schema__: ClassVar[CoreSchema] - """The core schema of the model.""" - - __pydantic_custom_init__: ClassVar[bool] - """Whether the model has a custom `__init__` method.""" - - # Must be set for `GenerateSchema.model_schema` to work for a plain `BaseModel` annotation. - __pydantic_decorators__: ClassVar[_decorators.DecoratorInfos] = _decorators.DecoratorInfos() - """Metadata containing the decorators defined on the model. - This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.""" - - __pydantic_generic_metadata__: ClassVar[_generics.PydanticGenericMetadata] - """Metadata for generic models; contains data used for a similar purpose to - __args__, __origin__, __parameters__ in typing-module generics. 
May eventually be replaced by these.""" - - __pydantic_parent_namespace__: ClassVar[Dict[str, Any] | None] = None # noqa: UP006 - """Parent namespace of the model, used for automatic rebuilding of models.""" - - __pydantic_post_init__: ClassVar[None | Literal['model_post_init']] - """The name of the post-init method for the model, if defined.""" - - __pydantic_root_model__: ClassVar[bool] = False - """Whether the model is a [`RootModel`][pydantic.root_model.RootModel].""" - - __pydantic_serializer__: ClassVar[SchemaSerializer] - """The `pydantic-core` `SchemaSerializer` used to dump instances of the model.""" - - __pydantic_validator__: ClassVar[SchemaValidator | PluggableSchemaValidator] - """The `pydantic-core` `SchemaValidator` used to validate instances of the model.""" - - __pydantic_fields__: ClassVar[Dict[str, FieldInfo]] # noqa: UP006 - """A dictionary of field names and their corresponding [`FieldInfo`][pydantic.fields.FieldInfo] objects. - This replaces `Model.__fields__` from Pydantic V1. - """ - - __pydantic_setattr_handlers__: ClassVar[Dict[str, Callable[[BaseModel, str, Any], None]]] # noqa: UP006 - """`__setattr__` handlers. Memoizing the handlers leads to a dramatic performance improvement in `__setattr__`""" - - __pydantic_computed_fields__: ClassVar[Dict[str, ComputedFieldInfo]] # noqa: UP006 - """A dictionary of computed field names and their corresponding [`ComputedFieldInfo`][pydantic.fields.ComputedFieldInfo] objects.""" - - __pydantic_extra__: dict[str, Any] | None = _model_construction.NoInitField(init=False) - """A dictionary containing extra values, if [`extra`][pydantic.config.ConfigDict.extra] is set to `'allow'`.""" - - __pydantic_fields_set__: set[str] = _model_construction.NoInitField(init=False) - """The names of fields explicitly set during instantiation.""" - - __pydantic_private__: dict[str, Any] | None = _model_construction.NoInitField(init=False) - """Values of private attributes set on the model instance.""" - - if not TYPE_CHECKING: - # Prevent `BaseModel` from being instantiated directly - # (defined in an `if not TYPE_CHECKING` block for clarity and to avoid type checking errors): - __pydantic_core_schema__ = _mock_val_ser.MockCoreSchema( - 'Pydantic models should inherit from BaseModel, BaseModel cannot be instantiated directly', - code='base-model-instantiated', - ) + # Instance attributes + # Note: we use the non-existent kwarg `init=False` in pydantic.fields.Field below so that @dataclass_transform + # doesn't think these are valid as keyword arguments to the class initializer. 
+ __pydantic_extra__: dict[str, Any] | None = _Field(init=False) # type: ignore + __pydantic_fields_set__: set[str] = _Field(init=False) # type: ignore + __pydantic_private__: dict[str, Any] | None = _Field(init=False) # type: ignore + else: + # `model_fields` and `__pydantic_decorators__` must be set for + # pydantic._internal._generate_schema.GenerateSchema.model_schema to work for a plain BaseModel annotation + model_fields = {} + __pydantic_decorators__ = _decorators.DecoratorInfos() + # Prevent `BaseModel` from being instantiated directly: __pydantic_validator__ = _mock_val_ser.MockValSer( 'Pydantic models should inherit from BaseModel, BaseModel cannot be instantiated directly', val_or_ser='validator', @@ -240,49 +146,34 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): __slots__ = '__dict__', '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__' - def __init__(self, /, **data: Any) -> None: + model_config = ConfigDict() + __pydantic_complete__ = False + __pydantic_root_model__ = False + + def __init__(__pydantic_self__, **data: Any) -> None: # type: ignore """Create a new model by parsing and validating input data from keyword arguments. Raises [`ValidationError`][pydantic_core.ValidationError] if the input data cannot be validated to form a valid model. - `self` is explicitly positional-only to allow `self` as a field name. + `__init__` uses `__pydantic_self__` instead of the more common `self` for the first arg to + allow `self` as a field name. """ # `__tracebackhide__` tells pytest and some other tools to omit this function from tracebacks __tracebackhide__ = True - validated_self = self.__pydantic_validator__.validate_python(data, self_instance=self) - if self is not validated_self: - warnings.warn( - 'A custom validator is returning a value other than `self`.\n' - "Returning anything other than `self` from a top level model validator isn't supported when validating via `__init__`.\n" - 'See the `model_validator` docs (https://docs.pydantic.dev/latest/concepts/validators/#model-validators) for more details.', - stacklevel=2, - ) + __pydantic_self__.__pydantic_validator__.validate_python(data, self_instance=__pydantic_self__) # The following line sets a flag that we use to determine when `__init__` gets overridden by the user - __init__.__pydantic_base_init__ = True # pyright: ignore[reportFunctionMemberAccess] + __init__.__pydantic_base_init__ = True - @_utils.deprecated_instance_property - @classmethod - def model_fields(cls) -> dict[str, FieldInfo]: - """A mapping of field names to their respective [`FieldInfo`][pydantic.fields.FieldInfo] instances. + @property + def model_computed_fields(self) -> dict[str, ComputedFieldInfo]: + """Get the computed fields of this model instance. - !!! warning - Accessing this attribute from a model instance is deprecated, and will not work in Pydantic V3. - Instead, you should access this attribute from the model class. + Returns: + A dictionary of computed field names and their corresponding `ComputedFieldInfo` objects. """ - return getattr(cls, '__pydantic_fields__', {}) - - @_utils.deprecated_instance_property - @classmethod - def model_computed_fields(cls) -> dict[str, ComputedFieldInfo]: - """A mapping of computed field names to their respective [`ComputedFieldInfo`][pydantic.fields.ComputedFieldInfo] instances. - - !!! warning - Accessing this attribute from a model instance is deprecated, and will not work in Pydantic V3. - Instead, you should access this attribute from the model class. 
- """ - return getattr(cls, '__pydantic_computed_fields__', {}) + return {k: v.info for k, v in self.__pydantic_decorators__.computed_fields.items()} @property def model_extra(self) -> dict[str, Any] | None: @@ -295,7 +186,7 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): @property def model_fields_set(self) -> set[str]: - """Returns the set of fields that have been explicitly set on this model instance. + """Returns the set of fields that have been set on this model instance. Returns: A set of strings representing the fields that have been set, @@ -304,23 +195,15 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): return self.__pydantic_fields_set__ @classmethod - def model_construct(cls, _fields_set: set[str] | None = None, **values: Any) -> Self: # noqa: C901 + def model_construct(cls: type[Model], _fields_set: set[str] | None = None, **values: Any) -> Model: """Creates a new instance of the `Model` class with validated data. Creates a new model setting `__dict__` and `__pydantic_fields_set__` from trusted or pre-validated data. Default values are respected, but no other validation is performed. - - !!! note - `model_construct()` generally respects the `model_config.extra` setting on the provided model. - That is, if `model_config.extra == 'allow'`, then all extra passed values are added to the model instance's `__dict__` - and `__pydantic_extra__` fields. If `model_config.extra == 'ignore'` (the default), then all extra passed values are ignored. - Because no validation is performed with a call to `model_construct()`, having `model_config.extra == 'forbid'` does not result in - an error if extra values are passed, but they will be ignored. + Behaves as if `Config.extra = 'allow'` was set since it adds all passed values Args: - _fields_set: A set of field names that were originally explicitly set during instantiation. If provided, - this is directly used for the [`model_fields_set`][pydantic.BaseModel.model_fields_set] attribute. - Otherwise, the field names from the `values` argument will be used. + _fields_set: The set of field names accepted for the Model instance. values: Trusted or pre-validated data dictionary. 
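# --- Illustrative sketch (editor's note, not part of the diff) -----------------------
# Inspecting field metadata via the model_fields / model_computed_fields accessors
# reshuffled earlier in this hunk. A hedged sketch: in the version added by this diff,
# model_computed_fields is an instance property, while newer releases prefer class-level
# access, so the instance access below works on both (possibly with a warning).
from pydantic import BaseModel, computed_field

class Circle(BaseModel):
    radius: float

    @computed_field
    @property
    def diameter(self) -> float:
        return 2 * self.radius

print(list(Circle.model_fields))                       # ['radius']
print(list(Circle(radius=1.0).model_computed_fields))  # ['diameter']
# --------------------------------------------------------------------------------------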
Returns: @@ -328,42 +211,25 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): """ m = cls.__new__(cls) fields_values: dict[str, Any] = {} - fields_set = set() - - for name, field in cls.__pydantic_fields__.items(): - if field.alias is not None and field.alias in values: + defaults: dict[str, Any] = {} # keeping this separate from `fields_values` helps us compute `_fields_set` + for name, field in cls.model_fields.items(): + if field.alias and field.alias in values: fields_values[name] = values.pop(field.alias) - fields_set.add(name) - - if (name not in fields_set) and (field.validation_alias is not None): - validation_aliases: list[str | AliasPath] = ( - field.validation_alias.choices - if isinstance(field.validation_alias, AliasChoices) - else [field.validation_alias] - ) - - for alias in validation_aliases: - if isinstance(alias, str) and alias in values: - fields_values[name] = values.pop(alias) - fields_set.add(name) - break - elif isinstance(alias, AliasPath): - value = alias.search_dict_for_path(values) - if value is not PydanticUndefined: - fields_values[name] = value - fields_set.add(name) - break - - if name not in fields_set: - if name in values: - fields_values[name] = values.pop(name) - fields_set.add(name) - elif not field.is_required(): - fields_values[name] = field.get_default(call_default_factory=True, validated_data=fields_values) + elif name in values: + fields_values[name] = values.pop(name) + elif not field.is_required(): + defaults[name] = field.get_default(call_default_factory=True) if _fields_set is None: - _fields_set = fields_set + _fields_set = set(fields_values.keys()) + fields_values.update(defaults) - _extra: dict[str, Any] | None = values if cls.model_config.get('extra') == 'allow' else None + _extra: dict[str, Any] | None = None + if cls.model_config.get('extra') == 'allow': + _extra = {} + for k, v in values.items(): + _extra[k] = v + else: + fields_values.update(values) _object_setattr(m, '__dict__', fields_values) _object_setattr(m, '__pydantic_fields_set__', _fields_set) if not cls.__pydantic_root_model__: @@ -371,12 +237,6 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): if cls.__pydantic_post_init__: m.model_post_init(None) - # update private attributes with values set - if hasattr(m, '__pydantic_private__') and m.__pydantic_private__ is not None: - for k, v in values.items(): - if k in m.__private_attributes__: - m.__pydantic_private__[k] = v - elif not cls.__pydantic_root_model__: # Note: if there are any private attributes, cls.__pydantic_post_init__ would exist # Since it doesn't, that means that `__pydantic_private__` should be set to None @@ -384,17 +244,11 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): return m - def model_copy(self, *, update: Mapping[str, Any] | None = None, deep: bool = False) -> Self: - """!!! abstract "Usage Documentation" - [`model_copy`](../concepts/serialization.md#model_copy) + def model_copy(self: Model, *, update: dict[str, Any] | None = None, deep: bool = False) -> Model: + """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#model_copy Returns a copy of the model. - !!! note - The underlying instance's [`__dict__`][object.__dict__] attribute is copied. This - might have unexpected side effects if you store anything in it, on top of the model - fields (e.g. the value of [cached properties][functools.cached_property]). - Args: update: Values to change/add in the new model. Note: the data is not validated before creating the new model. 
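# --- Illustrative sketch (editor's note, not part of the diff) -----------------------
# model_construct, shown above, builds an instance from trusted data without running
# validation: defaults are filled in, __pydantic_fields_set__ records what was passed,
# and no coercion or constraint checking happens.
from pydantic import BaseModel

class Event(BaseModel):
    name: str
    priority: int = 0

evt = Event.model_construct(name='deploy')
print(evt.priority)          # 0 (default applied)
print(evt.model_fields_set)  # {'name'}

# Nothing is validated, so a wrong type passes straight through.
bad = Event.model_construct(name='deploy', priority='not-an-int')
print(bad.priority)          # 'not-an-int'
# --------------------------------------------------------------------------------------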
You should trust this data. @@ -407,7 +261,7 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): if update: if self.model_config.get('extra') == 'allow': for k, v in update.items(): - if k in self.__pydantic_fields__: + if k in self.model_fields: copied.__dict__[k] = v else: if copied.__pydantic_extra__ is None: @@ -422,40 +276,31 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): self, *, mode: Literal['json', 'python'] | str = 'python', - include: IncEx | None = None, - exclude: IncEx | None = None, - context: Any | None = None, - by_alias: bool | None = None, + include: IncEx = None, + exclude: IncEx = None, + by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, - warnings: bool | Literal['none', 'warn', 'error'] = True, - fallback: Callable[[Any], Any] | None = None, - serialize_as_any: bool = False, + warnings: bool = True, ) -> dict[str, Any]: - """!!! abstract "Usage Documentation" - [`model_dump`](../concepts/serialization.md#modelmodel_dump) + """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump Generate a dictionary representation of the model, optionally specifying which fields to include or exclude. Args: mode: The mode in which `to_python` should run. - If mode is 'json', the output will only contain JSON serializable types. - If mode is 'python', the output may contain non-JSON-serializable Python objects. - include: A set of fields to include in the output. - exclude: A set of fields to exclude from the output. - context: Additional context to pass to the serializer. + If mode is 'json', the dictionary will only contain JSON serializable types. + If mode is 'python', the dictionary may contain any Python objects. + include: A list of fields to include in the output. + exclude: A list of fields to exclude from the output. by_alias: Whether to use the field's alias in the dictionary key if defined. - exclude_unset: Whether to exclude fields that have not been explicitly set. - exclude_defaults: Whether to exclude fields that are set to their default value. - exclude_none: Whether to exclude fields that have a value of `None`. - round_trip: If True, dumped values should be valid as input for non-idempotent types such as Json[T]. - warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, - "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. - fallback: A function to call when an unknown value is encountered. If not provided, - a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. - serialize_as_any: Whether to serialize fields with duck-typing serialization behavior. + exclude_unset: Whether to exclude fields that are unset or None from the output. + exclude_defaults: Whether to exclude fields that are set to their default value from the output. + exclude_none: Whether to exclude fields that have a value of `None` from the output. + round_trip: Whether to enable serialization and deserialization round-trip support. + warnings: Whether to log warnings when invalid fields are encountered. Returns: A dictionary representation of the model. 
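# --- Illustrative sketch (editor's note, not part of the diff) -----------------------
# model_dump with the include/exclude and mode switches documented above: 'python' mode
# keeps native objects, 'json' mode restricts the output to JSON-serializable types.
from datetime import date
from pydantic import BaseModel

class Invoice(BaseModel):
    number: str
    issued: date
    internal_note: str | None = None

inv = Invoice(number='INV-1', issued=date(2025, 1, 31))
print(inv.model_dump(exclude={'internal_note'}))
# {'number': 'INV-1', 'issued': datetime.date(2025, 1, 31)}
print(inv.model_dump(mode='json', exclude_none=True))
# {'number': 'INV-1', 'issued': '2025-01-31'}
# --------------------------------------------------------------------------------------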
@@ -466,52 +311,40 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): by_alias=by_alias, include=include, exclude=exclude, - context=context, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, round_trip=round_trip, warnings=warnings, - fallback=fallback, - serialize_as_any=serialize_as_any, ) def model_dump_json( self, *, indent: int | None = None, - include: IncEx | None = None, - exclude: IncEx | None = None, - context: Any | None = None, - by_alias: bool | None = None, + include: IncEx = None, + exclude: IncEx = None, + by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, - warnings: bool | Literal['none', 'warn', 'error'] = True, - fallback: Callable[[Any], Any] | None = None, - serialize_as_any: bool = False, + warnings: bool = True, ) -> str: - """!!! abstract "Usage Documentation" - [`model_dump_json`](../concepts/serialization.md#modelmodel_dump_json) + """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json Generates a JSON representation of the model using Pydantic's `to_json` method. Args: indent: Indentation to use in the JSON output. If None is passed, the output will be compact. - include: Field(s) to include in the JSON output. - exclude: Field(s) to exclude from the JSON output. - context: Additional context to pass to the serializer. + include: Field(s) to include in the JSON output. Can take either a string or set of strings. + exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings. by_alias: Whether to serialize using field aliases. exclude_unset: Whether to exclude fields that have not been explicitly set. - exclude_defaults: Whether to exclude fields that are set to their default value. + exclude_defaults: Whether to exclude fields that have the default value. exclude_none: Whether to exclude fields that have a value of `None`. - round_trip: If True, dumped values should be valid as input for non-idempotent types such as Json[T]. - warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, - "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. - fallback: A function to call when an unknown value is encountered. If not provided, - a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. - serialize_as_any: Whether to serialize fields with duck-typing serialization behavior. + round_trip: Whether to use serialization/deserialization between JSON and class instance. + warnings: Whether to show any warnings that occurred during serialization. Returns: A JSON string representation of the model. @@ -521,15 +354,12 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): indent=indent, include=include, exclude=exclude, - context=context, by_alias=by_alias, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, round_trip=round_trip, warnings=warnings, - fallback=fallback, - serialize_as_any=serialize_as_any, ).decode() @classmethod @@ -583,7 +413,7 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): params_component = ', '.join(param_names) return f'{cls.__name__}[{params_component}]' - def model_post_init(self, context: Any, /) -> None: + def model_post_init(self, __context: Any) -> None: """Override this method to perform additional initialization after `__init__` and `model_construct`. 
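# --- Illustrative sketch (editor's note, not part of the diff) -----------------------
# Overriding model_post_init, whose first parameter is renamed between the two versions
# in this hunk (positional-only `context` vs `__context`). Declaring it as `__context`
# keeps the override compatible with both, since the hook is always called positionally.
from typing import Any
from pydantic import BaseModel

class Rectangle(BaseModel):
    width: float
    height: float
    area: float = 0.0

    def model_post_init(self, __context: Any) -> None:
        # Runs after validation (and after model_construct); derived values go here.
        self.area = self.width * self.height

print(Rectangle(width=2, height=3).area)  # 6.0
# --------------------------------------------------------------------------------------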
This is useful if you want to do some validation that requires the entire model to be initialized. """ @@ -596,7 +426,7 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): force: bool = False, raise_errors: bool = True, _parent_namespace_depth: int = 2, - _types_namespace: MappingNamespace | None = None, + _types_namespace: dict[str, Any] | None = None, ) -> bool | None: """Try to rebuild the pydantic-core schema for the model. @@ -615,77 +445,52 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): """ if not force and cls.__pydantic_complete__: return None - - for attr in ('__pydantic_core_schema__', '__pydantic_validator__', '__pydantic_serializer__'): - if attr in cls.__dict__ and not isinstance(getattr(cls, attr), _mock_val_ser.MockValSer): - # Deleting the validator/serializer is necessary as otherwise they can get reused in - # pydantic-core. We do so only if they aren't mock instances, otherwise — as `model_rebuild()` - # isn't thread-safe — concurrent model instantiations can lead to the parent validator being used. - # Same applies for the core schema that can be reused in schema generation. - delattr(cls, attr) - - cls.__pydantic_complete__ = False - - if _types_namespace is not None: - rebuild_ns = _types_namespace - elif _parent_namespace_depth > 0: - rebuild_ns = _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth, force=True) or {} else: - rebuild_ns = {} + if '__pydantic_core_schema__' in cls.__dict__: + delattr(cls, '__pydantic_core_schema__') # delete cached value to ensure full rebuild happens + if _types_namespace is not None: + types_namespace: dict[str, Any] | None = _types_namespace.copy() + else: + if _parent_namespace_depth > 0: + frame_parent_ns = _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth) or {} + cls_parent_ns = ( + _model_construction.unpack_lenient_weakvaluedict(cls.__pydantic_parent_namespace__) or {} + ) + types_namespace = {**cls_parent_ns, **frame_parent_ns} + cls.__pydantic_parent_namespace__ = _model_construction.build_lenient_weakvaluedict(types_namespace) + else: + types_namespace = _model_construction.unpack_lenient_weakvaluedict( + cls.__pydantic_parent_namespace__ + ) - parent_ns = _model_construction.unpack_lenient_weakvaluedict(cls.__pydantic_parent_namespace__) or {} + types_namespace = _typing_extra.get_cls_types_namespace(cls, types_namespace) - ns_resolver = _namespace_utils.NsResolver( - parent_namespace={**rebuild_ns, **parent_ns}, - ) - - if not cls.__pydantic_fields_complete__: - typevars_map = _generics.get_model_typevars_map(cls) - try: - cls.__pydantic_fields__ = _fields.rebuild_model_fields( - cls, - ns_resolver=ns_resolver, - typevars_map=typevars_map, - ) - except NameError as e: - exc = PydanticUndefinedAnnotation.from_name_error(e) - _mock_val_ser.set_model_mocks(cls, f'`{exc.name}`') - if raise_errors: - raise exc from e - - if not raise_errors and not cls.__pydantic_fields_complete__: - # No need to continue with schema gen, it is guaranteed to fail - return False - - assert cls.__pydantic_fields_complete__ - - return _model_construction.complete_model_class( - cls, - _config.ConfigWrapper(cls.model_config, check=False), - raise_errors=raise_errors, - ns_resolver=ns_resolver, - ) + # manually override defer_build so complete_model_class doesn't skip building the model again + config = {**cls.model_config, 'defer_build': False} + return _model_construction.complete_model_class( + cls, + cls.__name__, + _config.ConfigWrapper(config, check=False), + 
raise_errors=raise_errors, + types_namespace=types_namespace, + ) @classmethod def model_validate( - cls, + cls: type[Model], obj: Any, *, strict: bool | None = None, from_attributes: bool | None = None, - context: Any | None = None, - by_alias: bool | None = None, - by_name: bool | None = None, - ) -> Self: + context: dict[str, Any] | None = None, + ) -> Model: """Validate a pydantic model instance. Args: obj: The object to validate. - strict: Whether to enforce types strictly. + strict: Whether to raise an exception on invalid fields. from_attributes: Whether to extract data from object attributes. context: Additional context to pass to the validator. - by_alias: Whether to use the field's alias when validating against the provided input data. - by_name: Whether to use the field's name when validating against the provided input data. Raises: ValidationError: If the object could not be validated. @@ -695,128 +500,95 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): """ # `__tracebackhide__` tells pytest and some other tools to omit this function from tracebacks __tracebackhide__ = True - - if by_alias is False and by_name is not True: - raise PydanticUserError( - 'At least one of `by_alias` or `by_name` must be set to True.', - code='validate-by-alias-and-name-false', - ) - return cls.__pydantic_validator__.validate_python( - obj, strict=strict, from_attributes=from_attributes, context=context, by_alias=by_alias, by_name=by_name + obj, strict=strict, from_attributes=from_attributes, context=context ) @classmethod def model_validate_json( - cls, + cls: type[Model], json_data: str | bytes | bytearray, *, strict: bool | None = None, - context: Any | None = None, - by_alias: bool | None = None, - by_name: bool | None = None, - ) -> Self: - """!!! abstract "Usage Documentation" - [JSON Parsing](../concepts/json.md#json-parsing) - - Validate the given JSON data against the Pydantic model. + context: dict[str, Any] | None = None, + ) -> Model: + """Validate the given JSON data against the Pydantic model. Args: json_data: The JSON data to validate. strict: Whether to enforce types strictly. context: Extra variables to pass to the validator. - by_alias: Whether to use the field's alias when validating against the provided input data. - by_name: Whether to use the field's name when validating against the provided input data. Returns: The validated Pydantic model. Raises: - ValidationError: If `json_data` is not a JSON string or the object could not be validated. + ValueError: If `json_data` is not a JSON string. """ # `__tracebackhide__` tells pytest and some other tools to omit this function from tracebacks __tracebackhide__ = True - - if by_alias is False and by_name is not True: - raise PydanticUserError( - 'At least one of `by_alias` or `by_name` must be set to True.', - code='validate-by-alias-and-name-false', - ) - - return cls.__pydantic_validator__.validate_json( - json_data, strict=strict, context=context, by_alias=by_alias, by_name=by_name - ) + return cls.__pydantic_validator__.validate_json(json_data, strict=strict, context=context) @classmethod def model_validate_strings( - cls, + cls: type[Model], obj: Any, *, strict: bool | None = None, - context: Any | None = None, - by_alias: bool | None = None, - by_name: bool | None = None, - ) -> Self: - """Validate the given object with string data against the Pydantic model. + context: dict[str, Any] | None = None, + ) -> Model: + """Validate the given object contains string data against the Pydantic model. 
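# --- Illustrative sketch (editor's note, not part of the diff) -----------------------
# The model_validate / model_validate_json entry points touched in this hunk.
from pydantic import BaseModel, ValidationError

class Point(BaseModel):
    x: int
    y: int

print(Point.model_validate({'x': 1, 'y': '2'}))       # x=1 y=2 (the '2' is coerced)
print(Point.model_validate_json('{"x": 3, "y": 4}'))  # x=3 y=4

try:
    Point.model_validate_json('not json')
except ValidationError as err:
    print(err.errors()[0]['type'])  # json_invalid
# --------------------------------------------------------------------------------------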
Args: - obj: The object containing string data to validate. + obj: The object contains string data to validate. strict: Whether to enforce types strictly. context: Extra variables to pass to the validator. - by_alias: Whether to use the field's alias when validating against the provided input data. - by_name: Whether to use the field's name when validating against the provided input data. Returns: The validated Pydantic model. """ # `__tracebackhide__` tells pytest and some other tools to omit this function from tracebacks __tracebackhide__ = True - - if by_alias is False and by_name is not True: - raise PydanticUserError( - 'At least one of `by_alias` or `by_name` must be set to True.', - code='validate-by-alias-and-name-false', - ) - - return cls.__pydantic_validator__.validate_strings( - obj, strict=strict, context=context, by_alias=by_alias, by_name=by_name - ) + return cls.__pydantic_validator__.validate_strings(obj, strict=strict, context=context) @classmethod - def __get_pydantic_core_schema__(cls, source: type[BaseModel], handler: GetCoreSchemaHandler, /) -> CoreSchema: - # This warning is only emitted when calling `super().__get_pydantic_core_schema__` from a model subclass. - # In the generate schema logic, this method (`BaseModel.__get_pydantic_core_schema__`) is special cased to - # *not* be called if not overridden. - warnings.warn( - 'The `__get_pydantic_core_schema__` method of the `BaseModel` class is deprecated. If you are calling ' - '`super().__get_pydantic_core_schema__` when overriding the method on a Pydantic model, consider using ' - '`handler(source)` instead. However, note that overriding this method on models can lead to unexpected ' - 'side effects.', - PydanticDeprecatedSince211, - stacklevel=2, - ) - # Logic copied over from `GenerateSchema._model_schema`: - schema = cls.__dict__.get('__pydantic_core_schema__') - if schema is not None and not isinstance(schema, _mock_val_ser.MockCoreSchema): - return cls.__pydantic_core_schema__ + def __get_pydantic_core_schema__(cls, __source: type[BaseModel], __handler: GetCoreSchemaHandler) -> CoreSchema: + """Hook into generating the model's CoreSchema. - return handler(source) + Args: + __source: The class we are generating a schema for. + This will generally be the same as the `cls` argument if this is a classmethod. + __handler: Call into Pydantic's internal JSON schema generation. + A callable that calls into Pydantic's internal CoreSchema generation logic. + + Returns: + A `pydantic-core` `CoreSchema`. + """ + # Only use the cached value from this _exact_ class; we don't want one from a parent class + # This is why we check `cls.__dict__` and don't use `cls.__pydantic_core_schema__` or similar. + if '__pydantic_core_schema__' in cls.__dict__: + # Due to the way generic classes are built, it's possible that an invalid schema may be temporarily + # set on generic classes. I think we could resolve this to ensure that we get proper schema caching + # for generics, but for simplicity for now, we just always rebuild if the class has a generic origin. + if not cls.__pydantic_generic_metadata__['origin']: + return cls.__pydantic_core_schema__ + + return __handler(__source) @classmethod def __get_pydantic_json_schema__( cls, - core_schema: CoreSchema, - handler: GetJsonSchemaHandler, - /, + __core_schema: CoreSchema, + __handler: GetJsonSchemaHandler, ) -> JsonSchemaValue: """Hook into generating the model's JSON schema. Args: - core_schema: A `pydantic-core` CoreSchema. + __core_schema: A `pydantic-core` CoreSchema. 
You can ignore this argument and call the handler with a new CoreSchema, wrap this CoreSchema (`{'type': 'nullable', 'schema': current_schema}`), or just call the handler with the original schema. - handler: Call into Pydantic's internal JSON schema generation. + __handler: Call into Pydantic's internal JSON schema generation. This will raise a `pydantic.errors.PydanticInvalidForJsonSchema` if JSON schema generation fails. Since this gets called by `BaseModel.model_json_schema` you can override the @@ -826,7 +598,7 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): Returns: A JSON schema, as a Python object. """ - return handler(core_schema) + return __handler(__core_schema) @classmethod def __pydantic_init_subclass__(cls, **kwargs: Any) -> None: @@ -863,12 +635,12 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): if not isinstance(typevar_values, tuple): typevar_values = (typevar_values,) + _generics.check_parameters_count(cls, typevar_values) - # For a model `class Model[T, U, V = int](BaseModel): ...` parametrized with `(str, bool)`, - # this gives us `{T: str, U: bool, V: int}`: - typevars_map = _generics.map_generic_model_arguments(cls, typevar_values) - # We also update the provided args to use defaults values (`(str, bool)` becomes `(str, bool, int)`): - typevar_values = tuple(v for v in typevars_map.values()) + # Build map from generic typevars to passed params + typevars_map: dict[_typing_extra.TypeVarType, type[Any]] = dict( + zip(cls.__pydantic_generic_metadata__['parameters'], typevar_values) + ) if _utils.all_identical(typevars_map.keys(), typevars_map.values()) and typevars_map: submodel = cls # if arguments are equal to parameters it's the same object @@ -887,34 +659,31 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): ) # use dict as ordered set with _generics.generic_recursion_self_type(origin, args) as maybe_self_type: + if maybe_self_type is not None: + return maybe_self_type + cached = _generics.get_cached_generic_type_late(cls, typevar_values, origin, args) if cached is not None: return cached - if maybe_self_type is not None: - return maybe_self_type - # Attempt to rebuild the origin in case new types have been defined try: - # depth 2 gets you above this __class_getitem__ call. - # Note that we explicitly provide the parent ns, otherwise - # `model_rebuild` will use the parent ns no matter if it is the ns of a module. - # We don't want this here, as this has unexpected effects when a model - # is being parametrized during a forward annotation evaluation. - parent_ns = _typing_extra.parent_frame_namespace(parent_depth=2) or {} - origin.model_rebuild(_types_namespace=parent_ns) + # depth 3 gets you above this __class_getitem__ call + origin.model_rebuild(_parent_namespace_depth=3) except PydanticUndefinedAnnotation: # It's okay if it fails, it just means there are still undefined types # that could be evaluated later. 
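# --- Illustrative sketch (editor's note, not part of the diff) -----------------------
# Overriding the __get_pydantic_json_schema__ hook shown earlier in this hunk to
# post-process one model's JSON schema. resolve_ref_schema is used so the edit lands on
# the actual definition rather than a bare {'$ref': ...} wrapper.
from typing import Any
from pydantic import BaseModel
from pydantic.annotated_handlers import GetJsonSchemaHandler
from pydantic_core import CoreSchema

class Token(BaseModel):
    value: str

    @classmethod
    def __get_pydantic_json_schema__(
        cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler
    ) -> dict[str, Any]:
        json_schema = handler(core_schema)
        json_schema = handler.resolve_ref_schema(json_schema)
        json_schema['description'] = 'An opaque bearer token'
        return json_schema

print(Token.model_json_schema()['description'])  # An opaque bearer token
# --------------------------------------------------------------------------------------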
+ # TODO: Make sure validation fails if there are still undefined types, perhaps using MockValidator pass submodel = _generics.create_generic_submodel(model_name, origin, args, params) + # Update cache _generics.set_cached_generic_type(cls, typevar_values, submodel, origin, args) return submodel - def __copy__(self) -> Self: + def __copy__(self: Model) -> Model: """Returns a shallow copy of the model.""" cls = type(self) m = cls.__new__(cls) @@ -922,7 +691,7 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): _object_setattr(m, '__pydantic_extra__', copy(self.__pydantic_extra__)) _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__)) - if not hasattr(self, '__pydantic_private__') or self.__pydantic_private__ is None: + if self.__pydantic_private__ is None: _object_setattr(m, '__pydantic_private__', None) else: _object_setattr( @@ -933,7 +702,7 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): return m - def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: + def __deepcopy__(self: Model, memo: dict[int, Any] | None = None) -> Model: """Returns a deep copy of the model.""" cls = type(self) m = cls.__new__(cls) @@ -943,7 +712,7 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): # and attempting a deepcopy would be marginally slower. _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__)) - if not hasattr(self, '__pydantic_private__') or self.__pydantic_private__ is None: + if self.__pydantic_private__ is None: _object_setattr(m, '__pydantic_private__', None) else: _object_setattr( @@ -954,9 +723,8 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): return m - if not TYPE_CHECKING: + if not typing.TYPE_CHECKING: # We put `__getattr__` in a non-TYPE_CHECKING block because otherwise, mypy allows arbitrary attribute access - # The same goes for __setattr__ and __delattr__, see: https://github.com/pydantic/pydantic/issues/8643 def __getattr__(self, item: str) -> Any: private_attributes = object.__getattribute__(self, '__private_attributes__') @@ -978,7 +746,7 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): except AttributeError: pydantic_extra = None - if pydantic_extra: + if pydantic_extra is not None: try: return pydantic_extra[item] except KeyError as exc: @@ -990,105 +758,73 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): # this is the current error raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}') - def __setattr__(self, name: str, value: Any) -> None: - if (setattr_handler := self.__pydantic_setattr_handlers__.get(name)) is not None: - setattr_handler(self, name, value) - # if None is returned from _setattr_handler, the attribute was set directly - elif (setattr_handler := self._setattr_handler(name, value)) is not None: - setattr_handler(self, name, value) # call here to not memo on possibly unknown fields - self.__pydantic_setattr_handlers__[name] = setattr_handler # memoize the handler for faster access - - def _setattr_handler(self, name: str, value: Any) -> Callable[[BaseModel, str, Any], None] | None: - """Get a handler for setting an attribute on the model instance. - - Returns: - A handler for setting an attribute on the model instance. Used for memoization of the handler. - Memoizing the handlers leads to a dramatic performance improvement in `__setattr__` - Returns `None` when memoization is not safe, then the attribute is set directly. 
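# --- Illustrative sketch (editor's note, not part of the diff) -----------------------
# Parametrizing a generic model, which exercises the __class_getitem__ logic above
# (typevar-to-argument mapping, the generic cache, and a model_rebuild attempt on the
# origin class).
from typing import Generic, TypeVar
from pydantic import BaseModel

T = TypeVar('T')

class Box(BaseModel, Generic[T]):
    item: T

IntBox = Box[int]
print(IntBox(item='5'))     # item=5 (coerced to int)
print(IntBox.__name__)      # Box[int]
print(Box[int] is IntBox)   # typically True, thanks to the parametrization cache
# --------------------------------------------------------------------------------------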
- """ - cls = self.__class__ - if name in cls.__class_vars__: - raise AttributeError( - f'{name!r} is a ClassVar of `{cls.__name__}` and cannot be set on an instance. ' - f'If you want to set a value on the class, use `{cls.__name__}.{name} = value`.' - ) - elif not _fields.is_valid_field_name(name): - if (attribute := cls.__private_attributes__.get(name)) is not None: - if hasattr(attribute, '__set__'): - return lambda model, _name, val: attribute.__set__(model, val) - else: - return _SIMPLE_SETATTR_HANDLERS['private'] - else: - _object_setattr(self, name, value) - return None # Can not return memoized handler with possibly freeform attr names - - attr = getattr(cls, name, None) - # NOTE: We currently special case properties and `cached_property`, but we might need - # to generalize this to all data/non-data descriptors at some point. For non-data descriptors - # (such as `cached_property`), it isn't obvious though. `cached_property` caches the value - # to the instance's `__dict__`, but other non-data descriptors might do things differently. - if isinstance(attr, cached_property): - return _SIMPLE_SETATTR_HANDLERS['cached_property'] - - _check_frozen(cls, name, value) - - # We allow properties to be set only on non frozen models for now (to match dataclasses). - # This can be changed if it ever gets requested. - if isinstance(attr, property): - return lambda model, _name, val: attr.__set__(model, val) - elif cls.model_config.get('validate_assignment'): - return _SIMPLE_SETATTR_HANDLERS['validate_assignment'] - elif name not in cls.__pydantic_fields__: - if cls.model_config.get('extra') != 'allow': - # TODO - matching error - raise ValueError(f'"{cls.__name__}" object has no field "{name}"') - elif attr is None: - # attribute does not exist, so put it in extra - self.__pydantic_extra__[name] = value - return None # Can not return memoized handler with possibly freeform attr names - else: - # attribute _does_ exist, and was not in extra, so update it - return _SIMPLE_SETATTR_HANDLERS['extra_known'] + def __setattr__(self, name: str, value: Any) -> None: + if name in self.__class_vars__: + raise AttributeError( + f'{name!r} is a ClassVar of `{self.__class__.__name__}` and cannot be set on an instance. ' + f'If you want to set a value on the class, use `{self.__class__.__name__}.{name} = value`.' 
+ ) + elif not _fields.is_valid_field_name(name): + if self.__pydantic_private__ is None or name not in self.__private_attributes__: + _object_setattr(self, name, value) else: - return _SIMPLE_SETATTR_HANDLERS['model_field'] + attribute = self.__private_attributes__[name] + if hasattr(attribute, '__set__'): + attribute.__set__(self, value) # type: ignore + else: + self.__pydantic_private__[name] = value + return + elif self.model_config.get('frozen', None): + error: pydantic_core.InitErrorDetails = { + 'type': 'frozen_instance', + 'loc': (name,), + 'input': value, + } + raise pydantic_core.ValidationError.from_exception_data(self.__class__.__name__, [error]) + elif getattr(self.model_fields.get(name), 'frozen', False): + error: pydantic_core.InitErrorDetails = { + 'type': 'frozen_field', + 'loc': (name,), + 'input': value, + } + raise pydantic_core.ValidationError.from_exception_data(self.__class__.__name__, [error]) - def __delattr__(self, item: str) -> Any: - cls = self.__class__ + attr = getattr(self.__class__, name, None) + if isinstance(attr, property): + attr.__set__(self, value) + elif self.model_config.get('validate_assignment', None): + self.__pydantic_validator__.validate_assignment(self, name, value) + elif self.model_config.get('extra') != 'allow' and name not in self.model_fields: + # TODO - matching error + raise ValueError(f'"{self.__class__.__name__}" object has no field "{name}"') + elif self.model_config.get('extra') == 'allow' and name not in self.model_fields: + # SAFETY: __pydantic_extra__ is not None when extra = 'allow' + self.__pydantic_extra__[name] = value # type: ignore + else: + self.__dict__[name] = value + self.__pydantic_fields_set__.add(name) - if item in self.__private_attributes__: - attribute = self.__private_attributes__[item] - if hasattr(attribute, '__delete__'): - attribute.__delete__(self) # type: ignore - return + def __delattr__(self, item: str) -> Any: + if item in self.__private_attributes__: + attribute = self.__private_attributes__[item] + if hasattr(attribute, '__delete__'): + attribute.__delete__(self) # type: ignore + return - try: - # Note: self.__pydantic_private__ cannot be None if self.__private_attributes__ has items - del self.__pydantic_private__[item] # type: ignore - return - except KeyError as exc: - raise AttributeError(f'{cls.__name__!r} object has no attribute {item!r}') from exc - - # Allow cached properties to be deleted (even if the class is frozen): - attr = getattr(cls, item, None) - if isinstance(attr, cached_property): - return object.__delattr__(self, item) - - _check_frozen(cls, name=item, value=None) - - if item in self.__pydantic_fields__: + try: + # Note: self.__pydantic_private__ cannot be None if self.__private_attributes__ has items + del self.__pydantic_private__[item] # type: ignore + except KeyError as exc: + raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}') from exc + elif item in self.model_fields: + object.__delattr__(self, item) + elif self.__pydantic_extra__ is not None and item in self.__pydantic_extra__: + del self.__pydantic_extra__[item] + else: + try: object.__delattr__(self, item) - elif self.__pydantic_extra__ is not None and item in self.__pydantic_extra__: - del self.__pydantic_extra__[item] - else: - try: - object.__delattr__(self, item) - except AttributeError: - raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}') - - # Because we make use of `@dataclass_transform()`, `__replace__` is already synthesized by - # type checkers, so we 
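The `__setattr__` body above raises `frozen_instance` (model-level `frozen=True`) and `frozen_field` (per-field `frozen=True`) validation errors instead of silently assigning. A short sketch of the user-visible effect, using an illustrative model:

```python
from pydantic import BaseModel, ConfigDict, ValidationError

class Point(BaseModel):
    model_config = ConfigDict(frozen=True)
    x: int
    y: int

p = Point(x=1, y=2)
try:
    p.x = 10
except ValidationError as err:
    # The error type matches the InitErrorDetails built in the hunk above.
    print(err.errors()[0]['type'])   # 'frozen_instance'
```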
define the implementation in this `if not TYPE_CHECKING:` block: - def __replace__(self, **changes: Any) -> Self: - return self.model_copy(update=changes) + except AttributeError: + raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}') def __getstate__(self) -> dict[Any, Any]: private = self.__pydantic_private__ @@ -1102,69 +838,29 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): } def __setstate__(self, state: dict[Any, Any]) -> None: - _object_setattr(self, '__pydantic_fields_set__', state.get('__pydantic_fields_set__', {})) - _object_setattr(self, '__pydantic_extra__', state.get('__pydantic_extra__', {})) - _object_setattr(self, '__pydantic_private__', state.get('__pydantic_private__', {})) - _object_setattr(self, '__dict__', state.get('__dict__', {})) + _object_setattr(self, '__pydantic_fields_set__', state['__pydantic_fields_set__']) + _object_setattr(self, '__pydantic_extra__', state['__pydantic_extra__']) + _object_setattr(self, '__pydantic_private__', state['__pydantic_private__']) + _object_setattr(self, '__dict__', state['__dict__']) - if not TYPE_CHECKING: + def __eq__(self, other: Any) -> bool: + if isinstance(other, BaseModel): + # When comparing instances of generic types for equality, as long as all field values are equal, + # only require their generic origin types to be equal, rather than exact type equality. + # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1). + self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__ + other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__ - def __eq__(self, other: Any) -> bool: - if isinstance(other, BaseModel): - # When comparing instances of generic types for equality, as long as all field values are equal, - # only require their generic origin types to be equal, rather than exact type equality. - # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1). - self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__ - other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__ + return ( + self_type == other_type + and self.__dict__ == other.__dict__ + and self.__pydantic_private__ == other.__pydantic_private__ + and self.__pydantic_extra__ == other.__pydantic_extra__ + ) + else: + return NotImplemented # delegate to the other item in the comparison - # Perform common checks first - if not ( - self_type == other_type - and getattr(self, '__pydantic_private__', None) == getattr(other, '__pydantic_private__', None) - and self.__pydantic_extra__ == other.__pydantic_extra__ - ): - return False - - # We only want to compare pydantic fields but ignoring fields is costly. 
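`__getstate__` and `__setstate__` above round-trip `__dict__`, `__pydantic_extra__`, `__pydantic_private__` and `__pydantic_fields_set__` through pickle. A quick check of that round trip, assuming an illustrative `Job` model:

```python
import pickle
from pydantic import BaseModel

class Job(BaseModel):
    id: int
    status: str = 'queued'

job = Job(id=7)
restored = pickle.loads(pickle.dumps(job))

print(restored == job)            # True
print(restored.model_fields_set)  # {'id'}, __pydantic_fields_set__ survives the round trip
```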
- # We'll perform a fast check first, and fallback only when needed - # See GH-7444 and GH-7825 for rationale and a performance benchmark - - # First, do the fast (and sometimes faulty) __dict__ comparison - if self.__dict__ == other.__dict__: - # If the check above passes, then pydantic fields are equal, we can return early - return True - - # We don't want to trigger unnecessary costly filtering of __dict__ on all unequal objects, so we return - # early if there are no keys to ignore (we would just return False later on anyway) - model_fields = type(self).__pydantic_fields__.keys() - if self.__dict__.keys() <= model_fields and other.__dict__.keys() <= model_fields: - return False - - # If we reach here, there are non-pydantic-fields keys, mapped to unequal values, that we need to ignore - # Resort to costly filtering of the __dict__ objects - # We use operator.itemgetter because it is much faster than dict comprehensions - # NOTE: Contrary to standard python class and instances, when the Model class has a default value for an - # attribute and the model instance doesn't have a corresponding attribute, accessing the missing attribute - # raises an error in BaseModel.__getattr__ instead of returning the class attribute - # So we can use operator.itemgetter() instead of operator.attrgetter() - getter = operator.itemgetter(*model_fields) if model_fields else lambda _: _utils._SENTINEL - try: - return getter(self.__dict__) == getter(other.__dict__) - except KeyError: - # In rare cases (such as when using the deprecated BaseModel.copy() method), - # the __dict__ may not contain all model fields, which is how we can get here. - # getter(self.__dict__) is much faster than any 'safe' method that accounts - # for missing keys, and wrapping it in a `try` doesn't slow things down much - # in the common case. - self_fields_proxy = _utils.SafeGetItemProxy(self.__dict__) - other_fields_proxy = _utils.SafeGetItemProxy(other.__dict__) - return getter(self_fields_proxy) == getter(other_fields_proxy) - - # other instance is not a BaseModel - else: - return NotImplemented # delegate to the other item in the comparison - - if TYPE_CHECKING: + if typing.TYPE_CHECKING: # We put `__init_subclass__` in a TYPE_CHECKING block because, even though we want the type-checking benefits # described in the signature of `__init_subclass__` below, we don't want to modify the default behavior of # subclass initialization. @@ -1173,10 +869,11 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): """This signature is included purely to help type-checkers check arguments to class declaration, which provides a way to conveniently set model_config key/value pairs. - ```python + ```py from pydantic import BaseModel - class MyModel(BaseModel, extra='allow'): ... + class MyModel(BaseModel, extra='allow'): + ... ``` However, this may be deceiving, since the _actual_ calls to `__init_subclass__` will not receive any @@ -1202,20 +899,11 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): return f'{self.__repr_name__()}({self.__repr_str__(", ")})' def __repr_args__(self) -> _repr.ReprArgs: - # Eagerly create the repr of computed fields, as this may trigger access of cached properties and as such - # modify the instance's `__dict__`. If we don't do it now, it could happen when iterating over the `__dict__` - # below if the instance happens to be referenced in a field, and would modify the `__dict__` size *during* iteration. 
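Both versions of `__eq__` compare the generic *origin* types, which is exactly the `MyGeneric(x=1) != MyGeneric[Any](x=1)` headache called out in the comments. A runnable illustration of that guarantee:

```python
from typing import Any, Generic, TypeVar
from pydantic import BaseModel

T = TypeVar('T')

class MyGeneric(BaseModel, Generic[T]):
    x: T

# Parametrizing the model does not break equality; only the origin type is compared.
print(MyGeneric(x=1) == MyGeneric[Any](x=1))   # True
```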
- computed_fields_repr_args = [ - (k, getattr(self, k)) for k, v in self.__pydantic_computed_fields__.items() if v.repr - ] - for k, v in self.__dict__.items(): - field = self.__pydantic_fields__.get(k) + field = self.model_fields.get(k) if field and field.repr: - if v is not self: - yield k, v - else: - yield k, self.__repr_recursion__(v) + yield k, v + # `__pydantic_extra__` can fail to be set if the model is not yet fully initialized. # This can happen if a `ValidationError` is raised during initialization and the instance's # repr is generated as part of the exception handling. Therefore, we use `getattr` here @@ -1227,11 +915,10 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): if pydantic_extra is not None: yield from ((k, v) for k, v in pydantic_extra.items()) - yield from computed_fields_repr_args + yield from ((k, getattr(self, k)) for k, v in self.model_computed_fields.items() if v.repr) # take logic from `_repr.Representation` without the side effects of inheritance, see #5740 __repr_name__ = _repr.Representation.__repr_name__ - __repr_recursion__ = _repr.Representation.__repr_recursion__ __repr_str__ = _repr.Representation.__repr_str__ __pretty__ = _repr.Representation.__pretty__ __rich_repr__ = _repr.Representation.__rich_repr__ @@ -1242,45 +929,37 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): # ##### Deprecated methods from v1 ##### @property @typing_extensions.deprecated( - 'The `__fields__` attribute is deprecated, use `model_fields` instead.', category=None + 'The `__fields__` attribute is deprecated, use `model_fields` instead.', category=PydanticDeprecatedSince20 ) def __fields__(self) -> dict[str, FieldInfo]: - warnings.warn( - 'The `__fields__` attribute is deprecated, use `model_fields` instead.', - category=PydanticDeprecatedSince20, - stacklevel=2, - ) - return getattr(type(self), '__pydantic_fields__', {}) + warnings.warn('The `__fields__` attribute is deprecated, use `model_fields` instead.', DeprecationWarning) + return self.model_fields @property @typing_extensions.deprecated( 'The `__fields_set__` attribute is deprecated, use `model_fields_set` instead.', - category=None, + category=PydanticDeprecatedSince20, ) def __fields_set__(self) -> set[str]: warnings.warn( - 'The `__fields_set__` attribute is deprecated, use `model_fields_set` instead.', - category=PydanticDeprecatedSince20, - stacklevel=2, + 'The `__fields_set__` attribute is deprecated, use `model_fields_set` instead.', DeprecationWarning ) return self.__pydantic_fields_set__ - @typing_extensions.deprecated('The `dict` method is deprecated; use `model_dump` instead.', category=None) + @typing_extensions.deprecated( + 'The `dict` method is deprecated; use `model_dump` instead.', category=PydanticDeprecatedSince20 + ) def dict( # noqa: D102 self, *, - include: IncEx | None = None, - exclude: IncEx | None = None, + include: IncEx = None, + exclude: IncEx = None, by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, - ) -> Dict[str, Any]: # noqa UP006 - warnings.warn( - 'The `dict` method is deprecated; use `model_dump` instead.', - category=PydanticDeprecatedSince20, - stacklevel=2, - ) + ) -> typing.Dict[str, Any]: # noqa UP006 + warnings.warn('The `dict` method is deprecated; use `model_dump` instead.', DeprecationWarning) return self.model_dump( include=include, exclude=exclude, @@ -1290,25 +969,23 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): exclude_none=exclude_none, ) - 
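The deprecated v1 shims above (and the ones that follow) forward to their v2 counterparts after emitting a deprecation warning, so migrating is essentially a rename. A sketch with an illustrative `User` model:

```python
import warnings
from pydantic import BaseModel

class User(BaseModel):
    id: int
    name: str = 'anon'

u = User(id=1)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    legacy = u.dict()   # v1-style, deprecated

print(all(issubclass(w.category, DeprecationWarning) for w in caught))  # True
print(legacy == u.model_dump())                                         # True, same output
```

`PydanticDeprecatedSince20` subclasses `DeprecationWarning`, so the same filter catches the warning on either side of this diff.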
@typing_extensions.deprecated('The `json` method is deprecated; use `model_dump_json` instead.', category=None) + @typing_extensions.deprecated( + 'The `json` method is deprecated; use `model_dump_json` instead.', category=PydanticDeprecatedSince20 + ) def json( # noqa: D102 self, *, - include: IncEx | None = None, - exclude: IncEx | None = None, + include: IncEx = None, + exclude: IncEx = None, by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, - encoder: Callable[[Any], Any] | None = PydanticUndefined, # type: ignore[assignment] + encoder: typing.Callable[[Any], Any] | None = PydanticUndefined, # type: ignore[assignment] models_as_dict: bool = PydanticUndefined, # type: ignore[assignment] **dumps_kwargs: Any, ) -> str: - warnings.warn( - 'The `json` method is deprecated; use `model_dump_json` instead.', - category=PydanticDeprecatedSince20, - stacklevel=2, - ) + warnings.warn('The `json` method is deprecated; use `model_dump_json` instead.', DeprecationWarning) if encoder is not PydanticUndefined: raise TypeError('The `encoder` argument is no longer supported; use field serializers instead.') if models_as_dict is not PydanticUndefined: @@ -1325,35 +1002,32 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): ) @classmethod - @typing_extensions.deprecated('The `parse_obj` method is deprecated; use `model_validate` instead.', category=None) - def parse_obj(cls, obj: Any) -> Self: # noqa: D102 - warnings.warn( - 'The `parse_obj` method is deprecated; use `model_validate` instead.', - category=PydanticDeprecatedSince20, - stacklevel=2, - ) + @typing_extensions.deprecated( + 'The `parse_obj` method is deprecated; use `model_validate` instead.', category=PydanticDeprecatedSince20 + ) + def parse_obj(cls: type[Model], obj: Any) -> Model: # noqa: D102 + warnings.warn('The `parse_obj` method is deprecated; use `model_validate` instead.', DeprecationWarning) return cls.model_validate(obj) @classmethod @typing_extensions.deprecated( 'The `parse_raw` method is deprecated; if your data is JSON use `model_validate_json`, ' 'otherwise load the data then use `model_validate` instead.', - category=None, + category=PydanticDeprecatedSince20, ) def parse_raw( # noqa: D102 - cls, + cls: type[Model], b: str | bytes, *, content_type: str | None = None, encoding: str = 'utf8', proto: DeprecatedParseProtocol | None = None, allow_pickle: bool = False, - ) -> Self: # pragma: no cover + ) -> Model: # pragma: no cover warnings.warn( 'The `parse_raw` method is deprecated; if your data is JSON use `model_validate_json`, ' 'otherwise load the data then use `model_validate` instead.', - category=PydanticDeprecatedSince20, - stacklevel=2, + DeprecationWarning, ) from .deprecated import parse @@ -1392,22 +1066,21 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): @typing_extensions.deprecated( 'The `parse_file` method is deprecated; load the data from file, then if your data is JSON ' 'use `model_validate_json`, otherwise `model_validate` instead.', - category=None, + category=PydanticDeprecatedSince20, ) def parse_file( # noqa: D102 - cls, + cls: type[Model], path: str | Path, *, content_type: str | None = None, encoding: str = 'utf8', proto: DeprecatedParseProtocol | None = None, allow_pickle: bool = False, - ) -> Self: + ) -> Model: warnings.warn( 'The `parse_file` method is deprecated; load the data from file, then if your data is JSON ' - 'use `model_validate_json`, otherwise `model_validate` instead.', - 
category=PydanticDeprecatedSince20, - stacklevel=2, + 'use `model_validate_json` otherwise `model_validate` instead.', + DeprecationWarning, ) from .deprecated import parse @@ -1422,16 +1095,15 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): @classmethod @typing_extensions.deprecated( - 'The `from_orm` method is deprecated; set ' + "The `from_orm` method is deprecated; set " "`model_config['from_attributes']=True` and use `model_validate` instead.", - category=None, + category=PydanticDeprecatedSince20, ) - def from_orm(cls, obj: Any) -> Self: # noqa: D102 + def from_orm(cls: type[Model], obj: Any) -> Model: # noqa: D102 warnings.warn( - 'The `from_orm` method is deprecated; set ' - "`model_config['from_attributes']=True` and use `model_validate` instead.", - category=PydanticDeprecatedSince20, - stacklevel=2, + 'The `from_orm` method is deprecated; set `model_config["from_attributes"]=True` ' + 'and use `model_validate` instead.', + DeprecationWarning, ) if not cls.model_config.get('from_attributes', None): raise PydanticUserError( @@ -1440,28 +1112,24 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): return cls.model_validate(obj) @classmethod - @typing_extensions.deprecated('The `construct` method is deprecated; use `model_construct` instead.', category=None) - def construct(cls, _fields_set: set[str] | None = None, **values: Any) -> Self: # noqa: D102 - warnings.warn( - 'The `construct` method is deprecated; use `model_construct` instead.', - category=PydanticDeprecatedSince20, - stacklevel=2, - ) + @typing_extensions.deprecated( + 'The `construct` method is deprecated; use `model_construct` instead.', category=PydanticDeprecatedSince20 + ) + def construct(cls: type[Model], _fields_set: set[str] | None = None, **values: Any) -> Model: # noqa: D102 + warnings.warn('The `construct` method is deprecated; use `model_construct` instead.', DeprecationWarning) return cls.model_construct(_fields_set=_fields_set, **values) @typing_extensions.deprecated( - 'The `copy` method is deprecated; use `model_copy` instead. ' - 'See the docstring of `BaseModel.copy` for details about how to handle `include` and `exclude`.', - category=None, + 'The copy method is deprecated; use `model_copy` instead.', category=PydanticDeprecatedSince20 ) def copy( - self, + self: Model, *, include: AbstractSetIntStr | MappingIntStrAny | None = None, exclude: AbstractSetIntStr | MappingIntStrAny | None = None, - update: Dict[str, Any] | None = None, # noqa UP006 + update: typing.Dict[str, Any] | None = None, # noqa UP006 deep: bool = False, - ) -> Self: # pragma: no cover + ) -> Model: # pragma: no cover """Returns a copy of the model. !!! warning "Deprecated" @@ -1469,17 +1137,20 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): If you need `include` or `exclude`, use: - ```python {test="skip" lint="skip"} + ```py data = self.model_dump(include=include, exclude=exclude, round_trip=True) data = {**data, **(update or {})} copied = self.model_validate(data) ``` Args: - include: Optional set or mapping specifying which fields to include in the copied model. - exclude: Optional set or mapping specifying which fields to exclude in the copied model. - update: Optional dictionary of field-value pairs to override field values in the copied model. - deep: If True, the values of fields that are Pydantic models will be deep-copied. + include: Optional set or mapping + specifying which fields to include in the copied model. 
+ exclude: Optional set or mapping + specifying which fields to exclude in the copied model. + update: Optional dictionary of field-value pairs to override field values + in the copied model. + deep: If True, the values of fields that are Pydantic models will be deep copied. Returns: A copy of the model with included, excluded and updated fields as specified. @@ -1487,8 +1158,7 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): warnings.warn( 'The `copy` method is deprecated; use `model_copy` instead. ' 'See the docstring of `BaseModel.copy` for details about how to handle `include` and `exclude`.', - category=PydanticDeprecatedSince20, - stacklevel=2, + DeprecationWarning, ) from .deprecated import copy_internals @@ -1527,32 +1197,29 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): return copy_internals._copy_and_set_values(self, values, fields_set, extra, private, deep=deep) @classmethod - @typing_extensions.deprecated('The `schema` method is deprecated; use `model_json_schema` instead.', category=None) + @typing_extensions.deprecated( + 'The `schema` method is deprecated; use `model_json_schema` instead.', category=PydanticDeprecatedSince20 + ) def schema( # noqa: D102 cls, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLATE - ) -> Dict[str, Any]: # noqa UP006 - warnings.warn( - 'The `schema` method is deprecated; use `model_json_schema` instead.', - category=PydanticDeprecatedSince20, - stacklevel=2, - ) + ) -> typing.Dict[str, Any]: # noqa UP006 + warnings.warn('The `schema` method is deprecated; use `model_json_schema` instead.', DeprecationWarning) return cls.model_json_schema(by_alias=by_alias, ref_template=ref_template) @classmethod @typing_extensions.deprecated( 'The `schema_json` method is deprecated; use `model_json_schema` and json.dumps instead.', - category=None, + category=PydanticDeprecatedSince20, ) def schema_json( # noqa: D102 cls, *, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLATE, **dumps_kwargs: Any ) -> str: # pragma: no cover - warnings.warn( - 'The `schema_json` method is deprecated; use `model_json_schema` and json.dumps instead.', - category=PydanticDeprecatedSince20, - stacklevel=2, - ) import json + warnings.warn( + 'The `schema_json` method is deprecated; use `model_json_schema` and json.dumps instead.', + DeprecationWarning, + ) from .deprecated.json import pydantic_encoder return json.dumps( @@ -1562,52 +1229,44 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): ) @classmethod - @typing_extensions.deprecated('The `validate` method is deprecated; use `model_validate` instead.', category=None) - def validate(cls, value: Any) -> Self: # noqa: D102 - warnings.warn( - 'The `validate` method is deprecated; use `model_validate` instead.', - category=PydanticDeprecatedSince20, - stacklevel=2, - ) + @typing_extensions.deprecated( + 'The `validate` method is deprecated; use `model_validate` instead.', category=PydanticDeprecatedSince20 + ) + def validate(cls: type[Model], value: Any) -> Model: # noqa: D102 + warnings.warn('The `validate` method is deprecated; use `model_validate` instead.', DeprecationWarning) return cls.model_validate(value) @classmethod @typing_extensions.deprecated( 'The `update_forward_refs` method is deprecated; use `model_rebuild` instead.', - category=None, + category=PydanticDeprecatedSince20, ) def update_forward_refs(cls, **localns: Any) -> None: # noqa: D102 warnings.warn( - 'The `update_forward_refs` method is deprecated; use `model_rebuild` instead.', - 
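The `copy()` docstring above spells out the supported replacement for `include`/`exclude`/`update`: dump with `round_trip=True`, merge the update, and re-validate. A runnable variant of that recipe (model and field names are illustrative, and the excluded field must be optional for re-validation to succeed):

```python
from pydantic import BaseModel

class Account(BaseModel):
    id: int
    email: str
    plan: str = 'free'

acct = Account(id=1, email='a@example.com')

# Equivalent of the deprecated acct.copy(exclude={'plan'}, update={'email': ...}):
data = acct.model_dump(exclude={'plan'}, round_trip=True)
data = {**data, **{'email': 'b@example.com'}}
copied = Account.model_validate(data)

print(copied)   # id=1 email='b@example.com' plan='free'
```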
category=PydanticDeprecatedSince20, - stacklevel=2, + 'The `update_forward_refs` method is deprecated; use `model_rebuild` instead.', DeprecationWarning ) if localns: # pragma: no cover raise TypeError('`localns` arguments are not longer accepted.') cls.model_rebuild(force=True) @typing_extensions.deprecated( - 'The private method `_iter` will be removed and should no longer be used.', category=None + 'The private method `_iter` will be removed and should no longer be used.', category=PydanticDeprecatedSince20 ) def _iter(self, *args: Any, **kwargs: Any) -> Any: - warnings.warn( - 'The private method `_iter` will be removed and should no longer be used.', - category=PydanticDeprecatedSince20, - stacklevel=2, - ) + warnings.warn('The private method `_iter` will be removed and should no longer be used.', DeprecationWarning) + from .deprecated import copy_internals return copy_internals._iter(self, *args, **kwargs) @typing_extensions.deprecated( 'The private method `_copy_and_set_values` will be removed and should no longer be used.', - category=None, + category=PydanticDeprecatedSince20, ) def _copy_and_set_values(self, *args: Any, **kwargs: Any) -> Any: warnings.warn( - 'The private method `_copy_and_set_values` will be removed and should no longer be used.', - category=PydanticDeprecatedSince20, - stacklevel=2, + 'The private method `_copy_and_set_values` will be removed and should no longer be used.', + DeprecationWarning, ) from .deprecated import copy_internals @@ -1616,101 +1275,84 @@ class BaseModel(metaclass=_model_construction.ModelMetaclass): @classmethod @typing_extensions.deprecated( 'The private method `_get_value` will be removed and should no longer be used.', - category=None, + category=PydanticDeprecatedSince20, ) def _get_value(cls, *args: Any, **kwargs: Any) -> Any: warnings.warn( - 'The private method `_get_value` will be removed and should no longer be used.', - category=PydanticDeprecatedSince20, - stacklevel=2, + 'The private method `_get_value` will be removed and should no longer be used.', DeprecationWarning ) + from .deprecated import copy_internals return copy_internals._get_value(cls, *args, **kwargs) @typing_extensions.deprecated( 'The private method `_calculate_keys` will be removed and should no longer be used.', - category=None, + category=PydanticDeprecatedSince20, ) def _calculate_keys(self, *args: Any, **kwargs: Any) -> Any: warnings.warn( - 'The private method `_calculate_keys` will be removed and should no longer be used.', - category=PydanticDeprecatedSince20, - stacklevel=2, + 'The private method `_calculate_keys` will be removed and should no longer be used.', DeprecationWarning ) + from .deprecated import copy_internals return copy_internals._calculate_keys(self, *args, **kwargs) -ModelT = TypeVar('ModelT', bound=BaseModel) - - -@overload +@typing.overload def create_model( - model_name: str, - /, + __model_name: str, *, __config__: ConfigDict | None = None, - __doc__: str | None = None, __base__: None = None, __module__: str = __name__, - __validators__: dict[str, Callable[..., Any]] | None = None, + __validators__: dict[str, AnyClassMethod] | None = None, __cls_kwargs__: dict[str, Any] | None = None, - **field_definitions: Any | tuple[str, Any], -) -> type[BaseModel]: ... + **field_definitions: Any, +) -> type[BaseModel]: + ... 
-@overload +@typing.overload def create_model( - model_name: str, - /, + __model_name: str, *, __config__: ConfigDict | None = None, - __doc__: str | None = None, - __base__: type[ModelT] | tuple[type[ModelT], ...], + __base__: type[Model] | tuple[type[Model], ...], __module__: str = __name__, - __validators__: dict[str, Callable[..., Any]] | None = None, + __validators__: dict[str, AnyClassMethod] | None = None, __cls_kwargs__: dict[str, Any] | None = None, - **field_definitions: Any | tuple[str, Any], -) -> type[ModelT]: ... + **field_definitions: Any, +) -> type[Model]: + ... -def create_model( # noqa: C901 - model_name: str, - /, +def create_model( + __model_name: str, *, __config__: ConfigDict | None = None, - __doc__: str | None = None, - __base__: type[ModelT] | tuple[type[ModelT], ...] | None = None, - __module__: str | None = None, - __validators__: dict[str, Callable[..., Any]] | None = None, + __base__: type[Model] | tuple[type[Model], ...] | None = None, + __module__: str = __name__, + __validators__: dict[str, AnyClassMethod] | None = None, __cls_kwargs__: dict[str, Any] | None = None, - # TODO PEP 747: replace `Any` by the TypeForm: - **field_definitions: Any | tuple[str, Any], -) -> type[ModelT]: - """!!! abstract "Usage Documentation" - [Dynamic Model Creation](../concepts/models.md#dynamic-model-creation) - - Dynamically creates and returns a new Pydantic model, in other words, `create_model` dynamically creates a + __slots__: tuple[str, ...] | None = None, + **field_definitions: Any, +) -> type[Model]: + """Dynamically creates and returns a new Pydantic model, in other words, `create_model` dynamically creates a subclass of [`BaseModel`][pydantic.BaseModel]. Args: - model_name: The name of the newly created model. + __model_name: The name of the newly created model. __config__: The configuration of the new model. - __doc__: The docstring of the new model. - __base__: The base class or classes for the new model. - __module__: The name of the module that the model belongs to; - if `None`, the value is taken from `sys._getframe(1)` - __validators__: A dictionary of methods that validate fields. The keys are the names of the validation methods to - be added to the model, and the values are the validation methods themselves. You can read more about functional - validators [here](https://docs.pydantic.dev/2.9/concepts/validators/#field-validators). - __cls_kwargs__: A dictionary of keyword arguments for class creation, such as `metaclass`. - **field_definitions: Field definitions of the new model. Either: - - - a single element, representing the type annotation of the field. - - a two-tuple, the first element being the type and the second element the assigned value - (either a default or the [`Field()`][pydantic.Field] function). + __base__: The base class for the new model. + __module__: The name of the module that the model belongs to. + __validators__: A dictionary of methods that validate + fields. + __cls_kwargs__: A dictionary of keyword arguments for class creation. + __slots__: Deprecated. Should not be passed to `create_model`. + **field_definitions: Attributes of the new model. They should be passed in the format: + `=(, )` or `=(, )`. Returns: The new [model][pydantic.BaseModel]. @@ -1718,56 +1360,57 @@ def create_model( # noqa: C901 Raises: PydanticUserError: If `__base__` and `__config__` are both passed. 
""" - if __base__ is None: - __base__ = (cast('type[ModelT]', BaseModel),) - elif not isinstance(__base__, tuple): - __base__ = (__base__,) + if __slots__ is not None: + # __slots__ will be ignored from here on + warnings.warn('__slots__ should not be passed to create_model', RuntimeWarning) + + if __base__ is not None: + if __config__ is not None: + raise PydanticUserError( + 'to avoid confusion `__config__` and `__base__` cannot be used together', + code='create-model-config-base', + ) + if not isinstance(__base__, tuple): + __base__ = (__base__,) + else: + __base__ = (typing.cast(typing.Type['Model'], BaseModel),) __cls_kwargs__ = __cls_kwargs__ or {} - fields: dict[str, Any] = {} - annotations: dict[str, Any] = {} + fields = {} + annotations = {} for f_name, f_def in field_definitions.items(): + if not _fields.is_valid_field_name(f_name): + warnings.warn(f'fields may not start with an underscore, ignoring "{f_name}"', RuntimeWarning) if isinstance(f_def, tuple): - if len(f_def) != 2: + f_def = typing.cast('tuple[str, Any]', f_def) + try: + f_annotation, f_value = f_def + except ValueError as e: raise PydanticUserError( - f'Field definition for {f_name!r} should a single element representing the type or a two-tuple, the first element ' - 'being the type and the second element the assigned value (either a default or the `Field()` function).', + 'Field definitions should be a `(, )`.', code='create-model-field-definitions', - ) - - annotations[f_name] = f_def[0] - fields[f_name] = f_def[1] + ) from e else: - annotations[f_name] = f_def + f_annotation, f_value = None, f_def - if __module__ is None: - f = sys._getframe(1) - __module__ = f.f_globals['__name__'] + if f_annotation: + annotations[f_name] = f_annotation + fields[f_name] = f_value namespace: dict[str, Any] = {'__annotations__': annotations, '__module__': __module__} - if __doc__: - namespace.update({'__doc__': __doc__}) if __validators__: namespace.update(__validators__) namespace.update(fields) if __config__: - namespace['model_config'] = __config__ + namespace['model_config'] = _config.ConfigWrapper(__config__).config_dict resolved_bases = types.resolve_bases(__base__) - meta, ns, kwds = types.prepare_class(model_name, resolved_bases, kwds=__cls_kwargs__) + meta, ns, kwds = types.prepare_class(__model_name, resolved_bases, kwds=__cls_kwargs__) if resolved_bases is not __base__: ns['__orig_bases__'] = __base__ namespace.update(ns) - - return meta( - model_name, - resolved_bases, - namespace, - __pydantic_reset_parent_namespace__=False, - _create_model_module=__module__, - **kwds, - ) + return meta(__model_name, resolved_bases, namespace, __pydantic_reset_parent_namespace__=False, **kwds) __getattr__ = getattr_migration(__name__) diff --git a/venv/lib/python3.12/site-packages/pydantic/mypy.py b/venv/lib/python3.12/site-packages/pydantic/mypy.py index 1ac1b87..c2eb919 100644 --- a/venv/lib/python3.12/site-packages/pydantic/mypy.py +++ b/venv/lib/python3.12/site-packages/pydantic/mypy.py @@ -3,9 +3,8 @@ from __future__ import annotations import sys -from collections.abc import Iterator from configparser import ConfigParser -from typing import Any, Callable +from typing import Any, Callable, Iterator from mypy.errorcodes import ErrorCode from mypy.expandtype import expand_type, expand_type_by_instance @@ -15,7 +14,6 @@ from mypy.nodes import ( ARG_OPT, ARG_POS, ARG_STAR2, - INVARIANT, MDEF, Argument, AssignmentStmt, @@ -47,24 +45,26 @@ from mypy.options import Options from mypy.plugin import ( CheckerPluginInterface, 
ClassDefContext, + FunctionContext, MethodContext, Plugin, ReportConfigContext, SemanticAnalyzerPluginInterface, ) +from mypy.plugins import dataclasses from mypy.plugins.common import ( deserialize_and_fixup_type, ) from mypy.semanal import set_callable_name from mypy.server.trigger import make_wildcard_trigger from mypy.state import state -from mypy.type_visitor import TypeTranslator from mypy.typeops import map_type_from_supertype from mypy.types import ( AnyType, CallableType, Instance, NoneType, + Overloaded, Type, TypeOfAny, TypeType, @@ -79,11 +79,16 @@ from mypy.version import __version__ as mypy_version from pydantic._internal import _fields from pydantic.version import parse_mypy_version +try: + from mypy.types import TypeVarDef # type: ignore[attr-defined] +except ImportError: # pragma: no cover + # Backward-compatible with TypeVarDef from Mypy 0.930. + from mypy.types import TypeVarType as TypeVarDef + CONFIGFILE_KEY = 'pydantic-mypy' METADATA_KEY = 'pydantic-mypy-metadata' BASEMODEL_FULLNAME = 'pydantic.main.BaseModel' BASESETTINGS_FULLNAME = 'pydantic_settings.main.BaseSettings' -ROOT_MODEL_FULLNAME = 'pydantic.root_model.RootModel' MODEL_METACLASS_FULLNAME = 'pydantic._internal._model_construction.ModelMetaclass' FIELD_FULLNAME = 'pydantic.fields.Field' DATACLASS_FULLNAME = 'pydantic.dataclasses.dataclass' @@ -96,11 +101,10 @@ DECORATOR_FULLNAMES = { 'pydantic.deprecated.class_validators.validator', 'pydantic.deprecated.class_validators.root_validator', } -IMPLICIT_CLASSMETHOD_DECORATOR_FULLNAMES = DECORATOR_FULLNAMES - {'pydantic.functional_serializers.model_serializer'} MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version) -BUILTINS_NAME = 'builtins' +BUILTINS_NAME = 'builtins' if MYPY_VERSION_TUPLE >= (0, 930) else '__builtins__' # Increment version if plugin changes and mypy caches should be invalidated __version__ = 2 @@ -129,12 +133,12 @@ class PydanticPlugin(Plugin): self._plugin_data = self.plugin_config.to_data() super().__init__(options) - def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None: + def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], bool] | None: """Update Pydantic model class.""" sym = self.lookup_fully_qualified(fullname) if sym and isinstance(sym.node, TypeInfo): # pragma: no branch # No branching may occur if the mypy cache has not been cleared - if sym.node.has_base(BASEMODEL_FULLNAME): + if any(base.fullname == BASEMODEL_FULLNAME for base in sym.node.mro): return self._pydantic_model_class_maker_callback return None @@ -144,12 +148,28 @@ class PydanticPlugin(Plugin): return self._pydantic_model_metaclass_marker_callback return None + def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None: + """Adjust the return type of the `Field` function.""" + sym = self.lookup_fully_qualified(fullname) + if sym and sym.fullname == FIELD_FULLNAME: + return self._pydantic_field_callback + return None + def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None: """Adjust return type of `from_orm` method call.""" if fullname.endswith('.from_orm'): return from_attributes_callback return None + def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None: + """Mark pydantic.dataclasses as dataclass. + + Mypy version 1.1.1 added support for `@dataclass_transform` decorator. 
+ """ + if fullname == DATACLASS_FULLNAME and MYPY_VERSION_TUPLE < (1, 1): + return dataclasses.dataclass_class_maker_callback # type: ignore[return-value] + return None + def report_config_data(self, ctx: ReportConfigContext) -> dict[str, Any]: """Return all plugin config data. @@ -157,9 +177,9 @@ class PydanticPlugin(Plugin): """ return self._plugin_data - def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> None: + def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> bool: transformer = PydanticModelTransformer(ctx.cls, ctx.reason, ctx.api, self.plugin_config) - transformer.transform() + return transformer.transform() def _pydantic_model_metaclass_marker_callback(self, ctx: ClassDefContext) -> None: """Reset dataclass_transform_spec attribute of ModelMetaclass. @@ -174,6 +194,54 @@ class PydanticPlugin(Plugin): if getattr(info_metaclass.type, 'dataclass_transform_spec', None): info_metaclass.type.dataclass_transform_spec = None + def _pydantic_field_callback(self, ctx: FunctionContext) -> Type: + """Extract the type of the `default` argument from the Field function, and use it as the return type. + + In particular: + * Check whether the default and default_factory argument is specified. + * Output an error if both are specified. + * Retrieve the type of the argument which is specified, and use it as return type for the function. + """ + default_any_type = ctx.default_return_type + + assert ctx.callee_arg_names[0] == 'default', '"default" is no longer first argument in Field()' + assert ctx.callee_arg_names[1] == 'default_factory', '"default_factory" is no longer second argument in Field()' + default_args = ctx.args[0] + default_factory_args = ctx.args[1] + + if default_args and default_factory_args: + error_default_and_default_factory_specified(ctx.api, ctx.context) + return default_any_type + + if default_args: + default_type = ctx.arg_types[0][0] + default_arg = default_args[0] + + # Fallback to default Any type if the field is required + if not isinstance(default_arg, EllipsisExpr): + return default_type + + elif default_factory_args: + default_factory_type = ctx.arg_types[1][0] + + # Functions which use `ParamSpec` can be overloaded, exposing the callable's types as a parameter + # Pydantic calls the default factory without any argument, so we retrieve the first item + if isinstance(default_factory_type, Overloaded): + default_factory_type = default_factory_type.items[0] + + if isinstance(default_factory_type, CallableType): + ret_type = default_factory_type.ret_type + # mypy doesn't think `ret_type` has `args`, you'd think mypy should know, + # add this check in case it varies by version + args = getattr(ret_type, 'args', None) + if args: + if all(isinstance(arg, TypeVarType) for arg in args): + # Looks like the default factory is a type like `list` or `dict`, replace all args with `Any` + ret_type.args = tuple(default_any_type for _ in args) # type: ignore[attr-defined] + return ret_type + + return default_any_type + class PydanticPluginConfig: """A Pydantic mypy plugin config holder. 
@@ -238,9 +306,6 @@ def from_attributes_callback(ctx: MethodContext) -> Type: pydantic_metadata = model_type.type.metadata.get(METADATA_KEY) if pydantic_metadata is None: return ctx.default_return_type - if not model_type.type.has_base(BASEMODEL_FULLNAME): - # not a Pydantic v2 model - return ctx.default_return_type from_attributes = pydantic_metadata.get('config', {}).get('from_attributes') if from_attributes is not True: error_from_attributes(model_type.type.name, ctx.api, ctx.context) @@ -254,10 +319,8 @@ class PydanticModelField: self, name: str, alias: str | None, - is_frozen: bool, has_dynamic_alias: bool, has_default: bool, - strict: bool | None, line: int, column: int, type: Type | None, @@ -265,103 +328,40 @@ class PydanticModelField: ): self.name = name self.alias = alias - self.is_frozen = is_frozen self.has_dynamic_alias = has_dynamic_alias self.has_default = has_default - self.strict = strict self.line = line self.column = column self.type = type self.info = info - def to_argument( - self, - current_info: TypeInfo, - typed: bool, - model_strict: bool, - force_optional: bool, - use_alias: bool, - api: SemanticAnalyzerPluginInterface, - force_typevars_invariant: bool, - is_root_model_root: bool, - ) -> Argument: + def to_argument(self, current_info: TypeInfo, typed: bool, force_optional: bool, use_alias: bool) -> Argument: """Based on mypy.plugins.dataclasses.DataclassAttribute.to_argument.""" - variable = self.to_var(current_info, api, use_alias, force_typevars_invariant) - - strict = model_strict if self.strict is None else self.strict - if typed or strict: - type_annotation = self.expand_type(current_info, api, include_root_type=True) - else: - type_annotation = AnyType(TypeOfAny.explicit) - return Argument( - variable=variable, - type_annotation=type_annotation, + variable=self.to_var(current_info, use_alias), + type_annotation=self.expand_type(current_info) if typed else AnyType(TypeOfAny.explicit), initializer=None, - kind=ARG_OPT - if is_root_model_root - else (ARG_NAMED_OPT if force_optional or self.has_default else ARG_NAMED), + kind=ARG_NAMED_OPT if force_optional or self.has_default else ARG_NAMED, ) - def expand_type( - self, - current_info: TypeInfo, - api: SemanticAnalyzerPluginInterface, - force_typevars_invariant: bool = False, - include_root_type: bool = False, - ) -> Type | None: + def expand_type(self, current_info: TypeInfo) -> Type | None: """Based on mypy.plugins.dataclasses.DataclassAttribute.expand_type.""" - if force_typevars_invariant: - # In some cases, mypy will emit an error "Cannot use a covariant type variable as a parameter" - # To prevent that, we add an option to replace typevars with invariant ones while building certain - # method signatures (in particular, `__init__`). There may be a better way to do this, if this causes - # us problems in the future, we should look into why the dataclasses plugin doesn't have this issue. - if isinstance(self.type, TypeVarType): - modified_type = self.type.copy_modified() - modified_type.variance = INVARIANT - self.type = modified_type - if self.type is not None and self.info.self_type is not None: - # In general, it is not safe to call `expand_type()` during semantic analysis, + # In general, it is not safe to call `expand_type()` during semantic analyzis, # however this plugin is called very late, so all types should be fully ready. # Also, it is tricky to avoid eager expansion of Self types here (e.g. because # we serialize attributes). 
- with state.strict_optional_set(api.options.strict_optional): - filled_with_typevars = fill_typevars(current_info) - # Cannot be TupleType as current_info represents a Pydantic model: - assert isinstance(filled_with_typevars, Instance) - if force_typevars_invariant: - for arg in filled_with_typevars.args: - if isinstance(arg, TypeVarType): - arg.variance = INVARIANT - - expanded_type = expand_type(self.type, {self.info.self_type.id: filled_with_typevars}) - if include_root_type and isinstance(expanded_type, Instance) and is_root_model(expanded_type.type): - # When a root model is used as a field, Pydantic allows both an instance of the root model - # as well as instances of the `root` field type: - root_type = expanded_type.type['root'].type - if root_type is None: - # Happens if the hint for 'root' has unsolved forward references - return expanded_type - expanded_root_type = expand_type_by_instance(root_type, expanded_type) - expanded_type = UnionType([expanded_type, expanded_root_type]) - return expanded_type + return expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)}) return self.type - def to_var( - self, - current_info: TypeInfo, - api: SemanticAnalyzerPluginInterface, - use_alias: bool, - force_typevars_invariant: bool = False, - ) -> Var: + def to_var(self, current_info: TypeInfo, use_alias: bool) -> Var: """Based on mypy.plugins.dataclasses.DataclassAttribute.to_var.""" if use_alias and self.alias is not None: name = self.alias else: name = self.name - return Var(name, self.expand_type(current_info, api, force_typevars_invariant)) + return Var(name, self.expand_type(current_info)) def serialize(self) -> JsonDict: """Based on mypy.plugins.dataclasses.DataclassAttribute.serialize.""" @@ -369,10 +369,8 @@ class PydanticModelField: return { 'name': self.name, 'alias': self.alias, - 'is_frozen': self.is_frozen, 'has_dynamic_alias': self.has_dynamic_alias, 'has_default': self.has_default, - 'strict': self.strict, 'line': self.line, 'column': self.column, 'type': self.type.serialize(), @@ -385,38 +383,12 @@ class PydanticModelField: typ = deserialize_and_fixup_type(data.pop('type'), api) return cls(type=typ, info=info, **data) - def expand_typevar_from_subtype(self, sub_type: TypeInfo, api: SemanticAnalyzerPluginInterface) -> None: + def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: """Expands type vars in the context of a subtype when an attribute is inherited from a generic super type. """ if self.type is not None: - with state.strict_optional_set(api.options.strict_optional): - self.type = map_type_from_supertype(self.type, sub_type, self.info) - - -class PydanticModelClassVar: - """Based on mypy.plugins.dataclasses.DataclassAttribute. - - ClassVars are ignored by subclasses. 
- - Attributes: - name: the ClassVar name - """ - - def __init__(self, name): - self.name = name - - @classmethod - def deserialize(cls, data: JsonDict) -> PydanticModelClassVar: - """Based on mypy.plugins.dataclasses.DataclassAttribute.deserialize.""" - data = data.copy() - return cls(**data) - - def serialize(self) -> JsonDict: - """Based on mypy.plugins.dataclasses.DataclassAttribute.serialize.""" - return { - 'name': self.name, - } + self.type = map_type_from_supertype(self.type, sub_type, self.info) class PydanticModelTransformer: @@ -431,10 +403,7 @@ class PydanticModelTransformer: 'frozen', 'from_attributes', 'populate_by_name', - 'validate_by_alias', - 'validate_by_name', 'alias_generator', - 'strict', } def __init__( @@ -461,26 +430,24 @@ class PydanticModelTransformer: * stores the fields, config, and if the class is settings in the mypy metadata for access by subclasses """ info = self._cls.info - is_a_root_model = is_root_model(info) config = self.collect_config() - fields, class_vars = self.collect_fields_and_class_vars(config, is_a_root_model) - if fields is None or class_vars is None: + fields = self.collect_fields(config) + if fields is None: # Some definitions are not ready. We need another pass. return False for field in fields: if field.type is None: return False - is_settings = info.has_base(BASESETTINGS_FULLNAME) - self.add_initializer(fields, config, is_settings, is_a_root_model) - self.add_model_construct_method(fields, config, is_settings, is_a_root_model) - self.set_frozen(fields, self._api, frozen=config.frozen is True) + is_settings = any(base.fullname == BASESETTINGS_FULLNAME for base in info.mro[:-1]) + self.add_initializer(fields, config, is_settings) + self.add_model_construct_method(fields, config, is_settings) + self.set_frozen(fields, frozen=config.frozen is True) self.adjust_decorator_signatures() info.metadata[METADATA_KEY] = { 'fields': {field.name: field.serialize() for field in fields}, - 'class_vars': {class_var.name: class_var.serialize() for class_var in class_vars}, 'config': config.get_values_dict(), } @@ -494,13 +461,13 @@ class PydanticModelTransformer: Teach mypy this by marking any function whose outermost decorator is a `validator()`, `field_validator()` or `serializer()` call as a `classmethod`. 
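`adjust_decorator_signatures` above teaches mypy that `@validator`/`@field_validator` methods are implicitly classmethods (with `@model_validator(mode='after')` as the documented exception). At runtime the explicit `@classmethod` is optional, but keeping it matches what the plugin infers. An illustrative model:

```python
from pydantic import BaseModel, field_validator

class Order(BaseModel):
    quantity: int

    @field_validator('quantity')
    @classmethod
    def check_positive(cls, v: int) -> int:
        if v <= 0:
            raise ValueError('quantity must be positive')
        return v

print(Order(quantity=2).quantity)   # 2
```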
""" - for sym in self._cls.info.names.values(): + for name, sym in self._cls.info.names.items(): if isinstance(sym.node, Decorator): first_dec = sym.node.original_decorators[0] if ( isinstance(first_dec, CallExpr) and isinstance(first_dec.callee, NameExpr) - and first_dec.callee.fullname in IMPLICIT_CLASSMETHOD_DECORATOR_FULLNAMES + and first_dec.callee.fullname in DECORATOR_FULLNAMES # @model_validator(mode="after") is an exception, it expects a regular method and not ( first_dec.callee.fullname == MODEL_VALIDATOR_FULLNAME @@ -543,7 +510,7 @@ class PydanticModelTransformer: for arg_name, arg in zip(stmt.rvalue.arg_names, stmt.rvalue.args): if arg_name is None: continue - config.update(self.get_config_update(arg_name, arg, lax_extra=True)) + config.update(self.get_config_update(arg_name, arg)) elif isinstance(stmt.rvalue, DictExpr): # dict literals for key_expr, value_expr in stmt.rvalue.items: if not isinstance(key_expr, StrExpr): @@ -574,7 +541,7 @@ class PydanticModelTransformer: if ( stmt and config.has_alias_generator - and not (config.validate_by_name or config.populate_by_name) + and not config.populate_by_name and self.plugin_config.warn_required_dynamic_aliases ): error_required_dynamic_aliases(self._api, stmt) @@ -589,13 +556,11 @@ class PydanticModelTransformer: config.setdefault(name, value) return config - def collect_fields_and_class_vars( - self, model_config: ModelConfigData, is_root_model: bool - ) -> tuple[list[PydanticModelField] | None, list[PydanticModelClassVar] | None]: + def collect_fields(self, model_config: ModelConfigData) -> list[PydanticModelField] | None: """Collects the fields for the model, accounting for parent classes.""" cls = self._cls - # First, collect fields and ClassVars belonging to any class in the MRO, ignoring duplicates. + # First, collect fields belonging to any class in the MRO, ignoring duplicates. # # We iterate through the MRO in reverse because attrs defined in the parent must appear # earlier in the attributes list than attrs defined in the child. See: @@ -605,11 +570,10 @@ class PydanticModelTransformer: # in the parent. We can implement this via a dict without disrupting the attr order # because dicts preserve insertion order in Python 3.7+. found_fields: dict[str, PydanticModelField] = {} - found_class_vars: dict[str, PydanticModelClassVar] = {} for info in reversed(cls.info.mro[1:-1]): # 0 is the current class, -2 is BaseModel, -1 is object # if BASEMODEL_METADATA_TAG_KEY in info.metadata and BASEMODEL_METADATA_KEY not in info.metadata: # # We haven't processed the base class yet. Need another pass. - # return None, None + # return None if METADATA_KEY not in info.metadata: continue @@ -622,7 +586,8 @@ class PydanticModelTransformer: # TODO: We shouldn't be performing type operations during the main # semantic analysis pass, since some TypeInfo attributes might # still be in flux. This should be performed in a later phase. - field.expand_typevar_from_subtype(cls.info, self._api) + with state.strict_optional_set(self._api.options.strict_optional): + field.expand_typevar_from_subtype(cls.info) found_fields[name] = field sym_node = cls.info.names.get(name) @@ -631,31 +596,17 @@ class PydanticModelTransformer: 'BaseModel field may only be overridden by another field', sym_node.node, ) - # Collect ClassVars - for name, data in info.metadata[METADATA_KEY]['class_vars'].items(): - found_class_vars[name] = PydanticModelClassVar.deserialize(data) - # Second, collect fields and ClassVars belonging to the current class. 
+ # Second, collect fields belonging to the current class. current_field_names: set[str] = set() - current_class_vars_names: set[str] = set() for stmt in self._get_assignment_statements_from_block(cls.defs): - maybe_field = self.collect_field_or_class_var_from_stmt(stmt, model_config, found_class_vars) - if maybe_field is None: - continue + maybe_field = self.collect_field_from_stmt(stmt, model_config) + if maybe_field is not None: + lhs = stmt.lvalues[0] + current_field_names.add(lhs.name) + found_fields[lhs.name] = maybe_field - lhs = stmt.lvalues[0] - assert isinstance(lhs, NameExpr) # collect_field_or_class_var_from_stmt guarantees this - if isinstance(maybe_field, PydanticModelField): - if is_root_model and lhs.name != 'root': - error_extra_fields_on_root_model(self._api, stmt) - else: - current_field_names.add(lhs.name) - found_fields[lhs.name] = maybe_field - elif isinstance(maybe_field, PydanticModelClassVar): - current_class_vars_names.add(lhs.name) - found_class_vars[lhs.name] = maybe_field - - return list(found_fields.values()), list(found_class_vars.values()) + return list(found_fields.values()) def _get_assignment_statements_from_if_statement(self, stmt: IfStmt) -> Iterator[AssignmentStmt]: for body in stmt.body: @@ -671,15 +622,14 @@ class PydanticModelTransformer: elif isinstance(stmt, IfStmt): yield from self._get_assignment_statements_from_if_statement(stmt) - def collect_field_or_class_var_from_stmt( # noqa C901 - self, stmt: AssignmentStmt, model_config: ModelConfigData, class_vars: dict[str, PydanticModelClassVar] - ) -> PydanticModelField | PydanticModelClassVar | None: + def collect_field_from_stmt( # noqa C901 + self, stmt: AssignmentStmt, model_config: ModelConfigData + ) -> PydanticModelField | None: """Get pydantic model field from statement. Args: stmt: The statement. model_config: Configuration settings for the model. - class_vars: ClassVars already known to be defined on the model. Returns: A pydantic model field if it could find the field in statement. Otherwise, `None`. @@ -702,10 +652,6 @@ class PydanticModelTransformer: # Eventually, we may want to attempt to respect model_config['ignored_types'] return None - if lhs.name in class_vars: - # Class vars are not fields and are not required to be annotated - return None - # The assignment does not have an annotation, and it's not anything else we recognize error_untyped_fields(self._api, stmt) return None @@ -750,7 +696,7 @@ class PydanticModelTransformer: # x: ClassVar[int] is not a field if node.is_classvar: - return PydanticModelClassVar(lhs.name) + return None # x: InitVar[int] is not supported in BaseModel node_type = get_proper_type(node.type) @@ -761,7 +707,6 @@ class PydanticModelTransformer: ) has_default = self.get_has_default(stmt) - strict = self.get_strict(stmt) if sym.type is None and node.is_final and node.is_inferred: # This follows the logic from the dataclasses plugin. 
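`collect_field_from_stmt` treats an un-annotated assignment as an error (`error_untyped_fields`), mirroring what pydantic itself does at class creation time. A sketch of the runtime counterpart (the error code shown in the comment is what current 2.x releases appear to use):

```python
from pydantic import BaseModel, PydanticUserError

try:
    class HttpOptions(BaseModel):
        retries = 3   # no annotation: not a valid field (annotate it, or mark it ClassVar)
except PydanticUserError as exc:
    print(exc.code)   # typically 'model-field-missing-annotation'
```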
The following comment is taken verbatim: @@ -781,27 +726,16 @@ class PydanticModelTransformer: ) node.type = AnyType(TypeOfAny.from_error) - if node.is_final and has_default: - # TODO this path should be removed (see https://github.com/pydantic/pydantic/issues/11119) - return PydanticModelClassVar(lhs.name) - alias, has_dynamic_alias = self.get_alias_info(stmt) - if ( - has_dynamic_alias - and not (model_config.validate_by_name or model_config.populate_by_name) - and self.plugin_config.warn_required_dynamic_aliases - ): + if has_dynamic_alias and not model_config.populate_by_name and self.plugin_config.warn_required_dynamic_aliases: error_required_dynamic_aliases(self._api, stmt) - is_frozen = self.is_field_frozen(stmt) init_type = self._infer_dataclass_attr_init_type(sym, lhs.name, stmt) return PydanticModelField( name=lhs.name, has_dynamic_alias=has_dynamic_alias, has_default=has_default, - strict=strict, alias=alias, - is_frozen=is_frozen, line=stmt.line, column=stmt.column, type=init_type, @@ -846,9 +780,7 @@ class PydanticModelTransformer: return default - def add_initializer( - self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool, is_root_model: bool - ) -> None: + def add_initializer(self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool) -> None: """Adds a fields-aware `__init__` method to the class. The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings. @@ -857,42 +789,28 @@ class PydanticModelTransformer: return # Don't generate an __init__ if one already exists typed = self.plugin_config.init_typed - model_strict = bool(config.strict) - use_alias = not (config.validate_by_name or config.populate_by_name) and config.validate_by_alias is not False - requires_dynamic_aliases = bool(config.has_alias_generator and not config.validate_by_name) - args = self.get_field_arguments( - fields, - typed=typed, - model_strict=model_strict, - requires_dynamic_aliases=requires_dynamic_aliases, - use_alias=use_alias, - is_settings=is_settings, - is_root_model=is_root_model, - force_typevars_invariant=True, - ) - - if is_settings: - base_settings_node = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node - assert isinstance(base_settings_node, TypeInfo) - if '__init__' in base_settings_node.names: - base_settings_init_node = base_settings_node.names['__init__'].node - assert isinstance(base_settings_init_node, FuncDef) - if base_settings_init_node is not None and base_settings_init_node.type is not None: - func_type = base_settings_init_node.type - assert isinstance(func_type, CallableType) - for arg_idx, arg_name in enumerate(func_type.arg_names): - if arg_name is None or arg_name.startswith('__') or not arg_name.startswith('_'): - continue - analyzed_variable_type = self._api.anal_type(func_type.arg_types[arg_idx]) - if analyzed_variable_type is not None and arg_name == '_cli_settings_source': - # _cli_settings_source is defined as CliSettingsSource[Any], and as such - # the Any causes issues with --disallow-any-explicit. 
As a workaround, change - # the Any type (as if CliSettingsSource was left unparameterized): - analyzed_variable_type = analyzed_variable_type.accept( - ChangeExplicitTypeOfAny(TypeOfAny.from_omitted_generics) - ) - variable = Var(arg_name, analyzed_variable_type) - args.append(Argument(variable, analyzed_variable_type, None, ARG_OPT)) + use_alias = config.populate_by_name is not True + requires_dynamic_aliases = bool(config.has_alias_generator and not config.populate_by_name) + with state.strict_optional_set(self._api.options.strict_optional): + args = self.get_field_arguments( + fields, + typed=typed, + requires_dynamic_aliases=requires_dynamic_aliases, + use_alias=use_alias, + is_settings=is_settings, + ) + if is_settings: + base_settings_node = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node + if '__init__' in base_settings_node.names: + base_settings_init_node = base_settings_node.names['__init__'].node + if base_settings_init_node is not None and base_settings_init_node.type is not None: + func_type = base_settings_init_node.type + for arg_idx, arg_name in enumerate(func_type.arg_names): + if arg_name.startswith('__') or not arg_name.startswith('_'): + continue + analyzed_variable_type = self._api.anal_type(func_type.arg_types[arg_idx]) + variable = Var(arg_name, analyzed_variable_type) + args.append(Argument(variable, analyzed_variable_type, None, ARG_OPT)) if not self.should_init_forbid_extra(fields, config): var = Var('kwargs') @@ -901,11 +819,7 @@ class PydanticModelTransformer: add_method(self._api, self._cls, '__init__', args=args, return_type=NoneType()) def add_model_construct_method( - self, - fields: list[PydanticModelField], - config: ModelConfigData, - is_settings: bool, - is_root_model: bool, + self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool ) -> None: """Adds a fully typed `model_construct` classmethod to the class. @@ -917,19 +831,13 @@ class PydanticModelTransformer: fields_set_argument = Argument(Var('_fields_set', optional_set_str), optional_set_str, None, ARG_OPT) with state.strict_optional_set(self._api.options.strict_optional): args = self.get_field_arguments( - fields, - typed=True, - model_strict=bool(config.strict), - requires_dynamic_aliases=False, - use_alias=False, - is_settings=is_settings, - is_root_model=is_root_model, + fields, typed=True, requires_dynamic_aliases=False, use_alias=False, is_settings=is_settings ) if not self.should_init_forbid_extra(fields, config): var = Var('kwargs') args.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2)) - args = args + [fields_set_argument] if is_root_model else [fields_set_argument] + args + args = [fields_set_argument] + args add_method( self._api, @@ -940,7 +848,7 @@ class PydanticModelTransformer: is_classmethod=True, ) - def set_frozen(self, fields: list[PydanticModelField], api: SemanticAnalyzerPluginInterface, frozen: bool) -> None: + def set_frozen(self, fields: list[PydanticModelField], frozen: bool) -> None: """Marks all fields as properties so that attempts to set them trigger mypy errors. This is the same approach used by the attrs and dataclasses plugins. 
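
For context on what `set_frozen` means for end users: per its docstring above, the plugin marks every field as a property so that assignments are rejected by mypy. The snippet below is a minimal, hypothetical sketch of that behaviour and is not part of this patch; the model and field names are invented, and the exact mypy wording may vary by version.

```python
# Minimal sketch: a frozen model whose fields the plugin exposes as read-only
# properties to mypy.
from pydantic import BaseModel, ConfigDict


class FrozenPoint(BaseModel):
    model_config = ConfigDict(frozen=True)

    x: int
    y: int


p = FrozenPoint(x=1, y=2)
# With the plugin enabled, mypy rejects the assignment below (roughly:
# 'Property "x" defined in "FrozenPoint" is read-only'); Pydantic v2 would
# also raise a ValidationError at runtime, so it is left commented out.
# p.x = 3
print(p)
```
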
@@ -951,7 +859,7 @@ class PydanticModelTransformer: if sym_node is not None: var = sym_node.node if isinstance(var, Var): - var.is_property = frozen or field.is_frozen + var.is_property = frozen elif isinstance(var, PlaceholderNode) and not self._api.final_iteration: # See https://github.com/pydantic/pydantic/issues/5191 to hit this branch for test coverage self._api.defer() @@ -965,13 +873,13 @@ class PydanticModelTransformer: detail = f'sym_node.node: {var_str} (of type {var.__class__})' error_unexpected_behavior(detail, self._api, self._cls) else: - var = field.to_var(info, api, use_alias=False) + var = field.to_var(info, use_alias=False) var.info = info var.is_property = frozen var._fullname = info.fullname + '.' + var.name info.names[var.name] = SymbolTableNode(MDEF, var) - def get_config_update(self, name: str, arg: Expression, lax_extra: bool = False) -> ModelConfigData | None: + def get_config_update(self, name: str, arg: Expression) -> ModelConfigData | None: """Determines the config update due to a single kwarg in the ConfigDict definition. Warns if a tracked config attribute is set to a value the plugin doesn't know how to interpret (e.g., an int) @@ -984,16 +892,7 @@ class PydanticModelTransformer: elif isinstance(arg, MemberExpr): forbid_extra = arg.name == 'forbid' else: - if not lax_extra: - # Only emit an error for other types of `arg` (e.g., `NameExpr`, `ConditionalExpr`, etc.) when - # reading from a config class, etc. If a ConfigDict is used, then we don't want to emit an error - # because you'll get type checking from the ConfigDict itself. - # - # It would be nice if we could introspect the types better otherwise, but I don't know what the API - # is to evaluate an expr into its type and then check if that type is compatible with the expected - # type. Note that you can still get proper type checking via: `model_config = ConfigDict(...)`, just - # if you don't use an explicit string, the plugin won't be able to infer whether extra is forbidden. - error_invalid_config_value(name, self._api, arg) + error_invalid_config_value(name, self._api, arg) return None return ModelConfigData(forbid_extra=forbid_extra) if name == 'alias_generator': @@ -1028,22 +927,6 @@ class PydanticModelTransformer: # Has no default if the "default value" is Ellipsis (i.e., `field_name: Annotation = ...`) return not isinstance(expr, EllipsisExpr) - @staticmethod - def get_strict(stmt: AssignmentStmt) -> bool | None: - """Returns a the `strict` value of a field if defined, otherwise `None`.""" - expr = stmt.rvalue - if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME: - for arg, name in zip(expr.args, expr.arg_names): - if name != 'strict': - continue - if isinstance(arg, NameExpr): - if arg.fullname == 'builtins.True': - return True - elif arg.fullname == 'builtins.False': - return False - return None - return None - @staticmethod def get_alias_info(stmt: AssignmentStmt) -> tuple[str | None, bool]: """Returns a pair (alias, has_dynamic_alias), extracted from the declaration of the field defined in `stmt`. 
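
For context on the alias handling in the hunk above: `get_alias_info` returns the alias string when `Field(alias=...)` is given a string literal, and otherwise only reports that a dynamic alias exists. The sketch below is a hypothetical illustration of those two shapes, not part of this patch; `make_alias`, the model, and the field names are invented.

```python
# Minimal sketch: static vs. dynamic aliases as seen by get_alias_info().
from pydantic import BaseModel, Field


def make_alias(name: str) -> str:
    return name.upper()


class User(BaseModel):
    # String literal: the plugin records alias='fullName' with
    # has_dynamic_alias=False and can use it in the generated __init__.
    full_name: str = Field(alias='fullName')

    # Non-literal expression: the plugin only records has_dynamic_alias=True,
    # so it cannot know the keyword name (and may warn when
    # warn_required_dynamic_aliases is enabled).
    nick: str = Field(alias=make_alias('nick'))


# At runtime the aliases still work as usual; only the static one is visible
# to the type checker.
u = User(fullName='Ada Lovelace', NICK='ada')
print(u)
```
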
@@ -1062,53 +945,23 @@ class PydanticModelTransformer: # Assigned value is not a call to pydantic.fields.Field return None, False - if 'validation_alias' in expr.arg_names: - arg = expr.args[expr.arg_names.index('validation_alias')] - elif 'alias' in expr.arg_names: - arg = expr.args[expr.arg_names.index('alias')] - else: - return None, False - - if isinstance(arg, StrExpr): - return arg.value, False - else: - return None, True - - @staticmethod - def is_field_frozen(stmt: AssignmentStmt) -> bool: - """Returns whether the field is frozen, extracted from the declaration of the field defined in `stmt`. - - Note that this is only whether the field was declared to be frozen in a ` = Field(frozen=True)` - sense; this does not determine whether the field is frozen because the entire model is frozen; that is - handled separately. - """ - expr = stmt.rvalue - if isinstance(expr, TempNode): - # TempNode means annotation-only - return False - - if not ( - isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME - ): - # Assigned value is not a call to pydantic.fields.Field - return False - for i, arg_name in enumerate(expr.arg_names): - if arg_name == 'frozen': - arg = expr.args[i] - return isinstance(arg, NameExpr) and arg.fullname == 'builtins.True' - return False + if arg_name != 'alias': + continue + arg = expr.args[i] + if isinstance(arg, StrExpr): + return arg.value, False + else: + return None, True + return None, False def get_field_arguments( self, fields: list[PydanticModelField], typed: bool, - model_strict: bool, use_alias: bool, requires_dynamic_aliases: bool, is_settings: bool, - is_root_model: bool, - force_typevars_invariant: bool = False, ) -> list[Argument]: """Helper function used during the construction of the `__init__` and `model_construct` method signatures. @@ -1117,14 +970,7 @@ class PydanticModelTransformer: info = self._cls.info arguments = [ field.to_argument( - info, - typed=typed, - model_strict=model_strict, - force_optional=requires_dynamic_aliases or is_settings, - use_alias=use_alias, - api=self._api, - force_typevars_invariant=force_typevars_invariant, - is_root_model_root=is_root_model and field.name == 'root', + info, typed=typed, force_optional=requires_dynamic_aliases or is_settings, use_alias=use_alias ) for field in fields if not (use_alias and field.has_dynamic_alias) @@ -1137,7 +983,7 @@ class PydanticModelTransformer: We disallow arbitrary kwargs if the extra config setting is "forbid", or if the plugin config says to, *unless* a required dynamic alias is present (since then we can't determine a valid signature). 
""" - if not (config.validate_by_name or config.populate_by_name): + if not config.populate_by_name: if self.is_dynamic_alias_present(fields, bool(config.has_alias_generator)): return False if config.forbid_extra: @@ -1159,20 +1005,6 @@ class PydanticModelTransformer: return False -class ChangeExplicitTypeOfAny(TypeTranslator): - """A type translator used to change type of Any's, if explicit.""" - - def __init__(self, type_of_any: int) -> None: - self._type_of_any = type_of_any - super().__init__() - - def visit_any(self, t: AnyType) -> Type: # noqa: D102 - if t.type_of_any == TypeOfAny.explicit: - return t.copy_modified(type_of_any=self._type_of_any) - else: - return t - - class ModelConfigData: """Pydantic mypy plugin model config class.""" @@ -1182,19 +1014,13 @@ class ModelConfigData: frozen: bool | None = None, from_attributes: bool | None = None, populate_by_name: bool | None = None, - validate_by_alias: bool | None = None, - validate_by_name: bool | None = None, has_alias_generator: bool | None = None, - strict: bool | None = None, ): self.forbid_extra = forbid_extra self.frozen = frozen self.from_attributes = from_attributes self.populate_by_name = populate_by_name - self.validate_by_alias = validate_by_alias - self.validate_by_name = validate_by_name self.has_alias_generator = has_alias_generator - self.strict = strict def get_values_dict(self) -> dict[str, Any]: """Returns a dict of Pydantic model config names to their values. @@ -1216,18 +1042,12 @@ class ModelConfigData: setattr(self, key, value) -def is_root_model(info: TypeInfo) -> bool: - """Return whether the type info is a root model subclass (or the `RootModel` class itself).""" - return info.has_base(ROOT_MODEL_FULLNAME) - - ERROR_ORM = ErrorCode('pydantic-orm', 'Invalid from_attributes call', 'Pydantic') ERROR_CONFIG = ErrorCode('pydantic-config', 'Invalid config value', 'Pydantic') ERROR_ALIAS = ErrorCode('pydantic-alias', 'Dynamic alias disallowed', 'Pydantic') ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic') ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic') ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic') -ERROR_EXTRA_FIELD_ROOT_MODEL = ErrorCode('pydantic-field', 'Extra field on RootModel subclass', 'Pydantic') def error_from_attributes(model_name: str, api: CheckerPluginInterface, context: Context) -> None: @@ -1264,9 +1084,9 @@ def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context) api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED) -def error_extra_fields_on_root_model(api: CheckerPluginInterface, context: Context) -> None: - """Emits an error when there is more than just a root field defined for a subclass of RootModel.""" - api.fail('Only `root` is allowed as a field of a `RootModel`', context, code=ERROR_EXTRA_FIELD_ROOT_MODEL) +def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None: + """Emits an error when `Field` has both `default` and `default_factory` together.""" + api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS) def add_method( @@ -1276,7 +1096,7 @@ def add_method( args: list[Argument], return_type: Type, self_type: Type | None = None, - tvar_def: TypeVarType | None = None, + tvar_def: TypeVarDef | None = None, is_classmethod: bool = False, ) -> None: """Very closely related to `mypy.plugins.common.add_method_to_class`, with a few 
pydantic-specific changes.""" @@ -1299,16 +1119,6 @@ def add_method( first = [Argument(Var('_cls'), self_type, None, ARG_POS, True)] else: self_type = self_type or fill_typevars(info) - # `self` is positional *ONLY* here, but this can't be expressed - # fully in the mypy internal API. ARG_POS is the closest we can get. - # Using ARG_POS will, however, give mypy errors if a `self` field - # is present on a model: - # - # Name "self" already defined (possibly by an import) [no-redef] - # - # As a workaround, we give this argument a name that will - # never conflict. By its positional nature, this name will not - # be used or exposed to users. first = [Argument(Var('__pydantic_self__'), self_type, None, ARG_POS)] args = first + args @@ -1319,9 +1129,9 @@ def add_method( arg_names.append(arg.variable.name) arg_kinds.append(arg.kind) - signature = CallableType( - arg_types, arg_kinds, arg_names, return_type, function_type, variables=[tvar_def] if tvar_def else None - ) + signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type) + if tvar_def: + signature.variables = [tvar_def] func = FuncDef(name, args, Block([PassStmt()])) func.info = info diff --git a/venv/lib/python3.12/site-packages/pydantic/networks.py b/venv/lib/python3.12/site-packages/pydantic/networks.py index 2221578..7dc1e5a 100644 --- a/venv/lib/python3.12/site-packages/pydantic/networks.py +++ b/venv/lib/python3.12/site-packages/pydantic/networks.py @@ -1,33 +1,18 @@ """The networks module contains types for common network-related fields.""" - from __future__ import annotations as _annotations import dataclasses as _dataclasses import re -from dataclasses import fields -from functools import lru_cache -from importlib.metadata import version from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network -from typing import TYPE_CHECKING, Annotated, Any, ClassVar +from typing import TYPE_CHECKING, Any -from pydantic_core import ( - MultiHostHost, - PydanticCustomError, - PydanticSerializationUnexpectedValue, - SchemaSerializer, - core_schema, -) -from pydantic_core import MultiHostUrl as _CoreMultiHostUrl -from pydantic_core import Url as _CoreUrl -from typing_extensions import Self, TypeAlias +from pydantic_core import MultiHostUrl, PydanticCustomError, Url, core_schema +from typing_extensions import Annotated, TypeAlias -from pydantic.errors import PydanticUserError - -from ._internal import _repr, _schema_generation_shared +from ._internal import _fields, _repr, _schema_generation_shared from ._migration import getattr_migration from .annotated_handlers import GetCoreSchemaHandler from .json_schema import JsonSchemaValue -from .type_adapter import TypeAdapter if TYPE_CHECKING: import email_validator @@ -42,10 +27,7 @@ __all__ = [ 'AnyUrl', 'AnyHttpUrl', 'FileUrl', - 'FtpUrl', 'HttpUrl', - 'WebsocketUrl', - 'AnyWebsocketUrl', 'UrlConstraints', 'EmailStr', 'NameEmail', @@ -58,17 +40,14 @@ __all__ = [ 'RedisDsn', 'MongoDsn', 'KafkaDsn', - 'NatsDsn', 'validate_email', 'MySQLDsn', 'MariaDBDsn', - 'ClickHouseDsn', - 'SnowflakeDsn', ] @_dataclasses.dataclass -class UrlConstraints: +class UrlConstraints(_fields.PydanticMetadata): """Url constraints. Attributes: @@ -99,655 +78,115 @@ class UrlConstraints: ) ) - @property - def defined_constraints(self) -> dict[str, Any]: - """Fetch a key / value mapping of constraints to values that are not None. 
Used for core schema updates.""" - return {field.name: value for field in fields(self) if (value := getattr(self, field.name)) is not None} - def __get_pydantic_core_schema__(self, source: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: - schema = handler(source) - - # for function-wrap schemas, url constraints is applied to the inner schema - # because when we generate schemas for urls, we wrap a core_schema.url_schema() with a function-wrap schema - # that helps with validation on initialization, see _BaseUrl and _BaseMultiHostUrl below. - schema_to_mutate = schema['schema'] if schema['type'] == 'function-wrap' else schema - if annotated_type := schema_to_mutate['type'] not in ('url', 'multi-host-url'): - raise PydanticUserError( - f"'UrlConstraints' cannot annotate '{annotated_type}'.", code='invalid-annotated-type' - ) - for constraint_key, constraint_value in self.defined_constraints.items(): - schema_to_mutate[constraint_key] = constraint_value - return schema - - -class _BaseUrl: - _constraints: ClassVar[UrlConstraints] = UrlConstraints() - _url: _CoreUrl - - def __init__(self, url: str | _CoreUrl | _BaseUrl) -> None: - self._url = _build_type_adapter(self.__class__).validate_python(url)._url - - @property - def scheme(self) -> str: - """The scheme part of the URL. - - e.g. `https` in `https://user:pass@host:port/path?query#fragment` - """ - return self._url.scheme - - @property - def username(self) -> str | None: - """The username part of the URL, or `None`. - - e.g. `user` in `https://user:pass@host:port/path?query#fragment` - """ - return self._url.username - - @property - def password(self) -> str | None: - """The password part of the URL, or `None`. - - e.g. `pass` in `https://user:pass@host:port/path?query#fragment` - """ - return self._url.password - - @property - def host(self) -> str | None: - """The host part of the URL, or `None`. - - If the URL must be punycode encoded, this is the encoded host, e.g if the input URL is `https://£££.com`, - `host` will be `xn--9aaa.com` - """ - return self._url.host - - def unicode_host(self) -> str | None: - """The host part of the URL as a unicode string, or `None`. - - e.g. `host` in `https://user:pass@host:port/path?query#fragment` - - If the URL must be punycode encoded, this is the decoded host, e.g if the input URL is `https://£££.com`, - `unicode_host()` will be `£££.com` - """ - return self._url.unicode_host() - - @property - def port(self) -> int | None: - """The port part of the URL, or `None`. - - e.g. `port` in `https://user:pass@host:port/path?query#fragment` - """ - return self._url.port - - @property - def path(self) -> str | None: - """The path part of the URL, or `None`. - - e.g. `/path` in `https://user:pass@host:port/path?query#fragment` - """ - return self._url.path - - @property - def query(self) -> str | None: - """The query part of the URL, or `None`. - - e.g. `query` in `https://user:pass@host:port/path?query#fragment` - """ - return self._url.query - - def query_params(self) -> list[tuple[str, str]]: - """The query part of the URL as a list of key-value pairs. - - e.g. `[('foo', 'bar')]` in `https://user:pass@host:port/path?foo=bar#fragment` - """ - return self._url.query_params() - - @property - def fragment(self) -> str | None: - """The fragment part of the URL, or `None`. - - e.g. 
`fragment` in `https://user:pass@host:port/path?query#fragment` - """ - return self._url.fragment - - def unicode_string(self) -> str: - """The URL as a unicode string, unlike `__str__()` this will not punycode encode the host. - - If the URL must be punycode encoded, this is the decoded string, e.g if the input URL is `https://£££.com`, - `unicode_string()` will be `https://£££.com` - """ - return self._url.unicode_string() - - def encoded_string(self) -> str: - """The URL's encoded string representation via __str__(). - - This returns the punycode-encoded host version of the URL as a string. - """ - return str(self) - - def __str__(self) -> str: - """The URL as a string, this will punycode encode the host if required.""" - return str(self._url) - - def __repr__(self) -> str: - return f'{self.__class__.__name__}({str(self._url)!r})' - - def __deepcopy__(self, memo: dict) -> Self: - return self.__class__(self._url) - - def __eq__(self, other: Any) -> bool: - return self.__class__ is other.__class__ and self._url == other._url - - def __lt__(self, other: Any) -> bool: - return self.__class__ is other.__class__ and self._url < other._url - - def __gt__(self, other: Any) -> bool: - return self.__class__ is other.__class__ and self._url > other._url - - def __le__(self, other: Any) -> bool: - return self.__class__ is other.__class__ and self._url <= other._url - - def __ge__(self, other: Any) -> bool: - return self.__class__ is other.__class__ and self._url >= other._url - - def __hash__(self) -> int: - return hash(self._url) - - def __len__(self) -> int: - return len(str(self._url)) - - @classmethod - def build( - cls, - *, - scheme: str, - username: str | None = None, - password: str | None = None, - host: str, - port: int | None = None, - path: str | None = None, - query: str | None = None, - fragment: str | None = None, - ) -> Self: - """Build a new `Url` instance from its component parts. - - Args: - scheme: The scheme part of the URL. - username: The username part of the URL, or omit for no username. - password: The password part of the URL, or omit for no password. - host: The host part of the URL. - port: The port part of the URL, or omit for no port. - path: The path part of the URL, or omit for no path. - query: The query part of the URL, or omit for no query. - fragment: The fragment part of the URL, or omit for no fragment. - - Returns: - An instance of URL - """ - return cls( - _CoreUrl.build( - scheme=scheme, - username=username, - password=password, - host=host, - port=port, - path=path, - query=query, - fragment=fragment, - ) - ) - - @classmethod - def serialize_url(cls, url: Any, info: core_schema.SerializationInfo) -> str | Self: - if not isinstance(url, cls): - raise PydanticSerializationUnexpectedValue( - f"Expected `{cls}` but got `{type(url)}` with value `'{url}'` - serialized value may not be as expected." 
- ) - if info.mode == 'json': - return str(url) - return url - - @classmethod - def __get_pydantic_core_schema__( - cls, source: type[_BaseUrl], handler: GetCoreSchemaHandler - ) -> core_schema.CoreSchema: - def wrap_val(v, h): - if isinstance(v, source): - return v - if isinstance(v, _BaseUrl): - v = str(v) - core_url = h(v) - instance = source.__new__(source) - instance._url = core_url - return instance - - return core_schema.no_info_wrap_validator_function( - wrap_val, - schema=core_schema.url_schema(**cls._constraints.defined_constraints), - serialization=core_schema.plain_serializer_function_ser_schema( - cls.serialize_url, info_arg=True, when_used='always' - ), - ) - - @classmethod - def __get_pydantic_json_schema__( - cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler - ) -> JsonSchemaValue: - # we use the url schema for json schema generation, but we might have to extract it from - # the function-wrap schema we use as a tool for validation on initialization - inner_schema = core_schema['schema'] if core_schema['type'] == 'function-wrap' else core_schema - return handler(inner_schema) - - __pydantic_serializer__ = SchemaSerializer(core_schema.any_schema(serialization=core_schema.to_string_ser_schema())) - - -class _BaseMultiHostUrl: - _constraints: ClassVar[UrlConstraints] = UrlConstraints() - _url: _CoreMultiHostUrl - - def __init__(self, url: str | _CoreMultiHostUrl | _BaseMultiHostUrl) -> None: - self._url = _build_type_adapter(self.__class__).validate_python(url)._url - - @property - def scheme(self) -> str: - """The scheme part of the URL. - - e.g. `https` in `https://foo.com,bar.com/path?query#fragment` - """ - return self._url.scheme - - @property - def path(self) -> str | None: - """The path part of the URL, or `None`. - - e.g. `/path` in `https://foo.com,bar.com/path?query#fragment` - """ - return self._url.path - - @property - def query(self) -> str | None: - """The query part of the URL, or `None`. - - e.g. `query` in `https://foo.com,bar.com/path?query#fragment` - """ - return self._url.query - - def query_params(self) -> list[tuple[str, str]]: - """The query part of the URL as a list of key-value pairs. - - e.g. `[('foo', 'bar')]` in `https://foo.com,bar.com/path?foo=bar#fragment` - """ - return self._url.query_params() - - @property - def fragment(self) -> str | None: - """The fragment part of the URL, or `None`. - - e.g. `fragment` in `https://foo.com,bar.com/path?query#fragment` - """ - return self._url.fragment - - def hosts(self) -> list[MultiHostHost]: - '''The hosts of the `MultiHostUrl` as [`MultiHostHost`][pydantic_core.MultiHostHost] typed dicts. - - ```python - from pydantic_core import MultiHostUrl - - mhu = MultiHostUrl('https://foo.com:123,foo:bar@bar.com/path') - print(mhu.hosts()) - """ - [ - {'username': None, 'password': None, 'host': 'foo.com', 'port': 123}, - {'username': 'foo', 'password': 'bar', 'host': 'bar.com', 'port': 443} - ] - ``` - Returns: - A list of dicts, each representing a host. - ''' - return self._url.hosts() - - def encoded_string(self) -> str: - """The URL's encoded string representation via __str__(). - - This returns the punycode-encoded host version of the URL as a string. 
- """ - return str(self) - - def unicode_string(self) -> str: - """The URL as a unicode string, unlike `__str__()` this will not punycode encode the hosts.""" - return self._url.unicode_string() - - def __str__(self) -> str: - """The URL as a string, this will punycode encode the host if required.""" - return str(self._url) - - def __repr__(self) -> str: - return f'{self.__class__.__name__}({str(self._url)!r})' - - def __deepcopy__(self, memo: dict) -> Self: - return self.__class__(self._url) - - def __eq__(self, other: Any) -> bool: - return self.__class__ is other.__class__ and self._url == other._url - - def __hash__(self) -> int: - return hash(self._url) - - def __len__(self) -> int: - return len(str(self._url)) - - @classmethod - def build( - cls, - *, - scheme: str, - hosts: list[MultiHostHost] | None = None, - username: str | None = None, - password: str | None = None, - host: str | None = None, - port: int | None = None, - path: str | None = None, - query: str | None = None, - fragment: str | None = None, - ) -> Self: - """Build a new `MultiHostUrl` instance from its component parts. - - This method takes either `hosts` - a list of `MultiHostHost` typed dicts, or the individual components - `username`, `password`, `host` and `port`. - - Args: - scheme: The scheme part of the URL. - hosts: Multiple hosts to build the URL from. - username: The username part of the URL. - password: The password part of the URL. - host: The host part of the URL. - port: The port part of the URL. - path: The path part of the URL. - query: The query part of the URL, or omit for no query. - fragment: The fragment part of the URL, or omit for no fragment. - - Returns: - An instance of `MultiHostUrl` - """ - return cls( - _CoreMultiHostUrl.build( - scheme=scheme, - hosts=hosts, - username=username, - password=password, - host=host, - port=port, - path=path, - query=query, - fragment=fragment, - ) - ) - - @classmethod - def serialize_url(cls, url: Any, info: core_schema.SerializationInfo) -> str | Self: - if not isinstance(url, cls): - raise PydanticSerializationUnexpectedValue( - f"Expected `{cls}` but got `{type(url)}` with value `'{url}'` - serialized value may not be as expected." 
- ) - if info.mode == 'json': - return str(url) - return url - - @classmethod - def __get_pydantic_core_schema__( - cls, source: type[_BaseMultiHostUrl], handler: GetCoreSchemaHandler - ) -> core_schema.CoreSchema: - def wrap_val(v, h): - if isinstance(v, source): - return v - if isinstance(v, _BaseMultiHostUrl): - v = str(v) - core_url = h(v) - instance = source.__new__(source) - instance._url = core_url - return instance - - return core_schema.no_info_wrap_validator_function( - wrap_val, - schema=core_schema.multi_host_url_schema(**cls._constraints.defined_constraints), - serialization=core_schema.plain_serializer_function_ser_schema( - cls.serialize_url, info_arg=True, when_used='always' - ), - ) - - @classmethod - def __get_pydantic_json_schema__( - cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler - ) -> JsonSchemaValue: - # we use the url schema for json schema generation, but we might have to extract it from - # the function-wrap schema we use as a tool for validation on initialization - inner_schema = core_schema['schema'] if core_schema['type'] == 'function-wrap' else core_schema - return handler(inner_schema) - - __pydantic_serializer__ = SchemaSerializer(core_schema.any_schema(serialization=core_schema.to_string_ser_schema())) - - -@lru_cache -def _build_type_adapter(cls: type[_BaseUrl | _BaseMultiHostUrl]) -> TypeAdapter: - return TypeAdapter(cls) - - -class AnyUrl(_BaseUrl): - """Base type for all URLs. - - * Any scheme allowed - * Top-level domain (TLD) not required - * Host not required - - Assuming an input URL of `http://samuel:pass@example.com:8000/the/path/?query=here#fragment=is;this=bit`, - the types export the following properties: - - - `scheme`: the URL scheme (`http`), always set. - - `host`: the URL host (`example.com`). - - `username`: optional username if included (`samuel`). - - `password`: optional password if included (`pass`). - - `port`: optional port (`8000`). - - `path`: optional path (`/the/path/`). - - `query`: optional URL query (for example, `GET` arguments or "search string", such as `query=here`). - - `fragment`: optional fragment (`fragment=is;this=bit`). - """ - - -# Note: all single host urls inherit from `AnyUrl` to preserve compatibility with pre-v2.10 code -# Where urls were annotated variants of `AnyUrl`, which was an alias to `pydantic_core.Url` - - -class AnyHttpUrl(AnyUrl): - """A type that will accept any http or https URL. - - * TLD not required - * Host not required - """ - - _constraints = UrlConstraints(allowed_schemes=['http', 'https']) - - -class HttpUrl(AnyUrl): - """A type that will accept any http or https URL. - - * TLD not required - * Host not required - * Max length 2083 - - ```python - from pydantic import BaseModel, HttpUrl, ValidationError - - class MyModel(BaseModel): - url: HttpUrl - - m = MyModel(url='http://www.example.com') # (1)! - print(m.url) - #> http://www.example.com/ - - try: - MyModel(url='ftp://invalid.url') - except ValidationError as e: - print(e) - ''' - 1 validation error for MyModel - url - URL scheme should be 'http' or 'https' [type=url_scheme, input_value='ftp://invalid.url', input_type=str] - ''' - - try: - MyModel(url='not a url') - except ValidationError as e: - print(e) - ''' - 1 validation error for MyModel - url - Input should be a valid URL, relative URL without a base [type=url_parsing, input_value='not a url', input_type=str] - ''' - ``` - - 1. 
Note: mypy would prefer `m = MyModel(url=HttpUrl('http://www.example.com'))`, but Pydantic will convert the string to an HttpUrl instance anyway. - - "International domains" (e.g. a URL where the host or TLD includes non-ascii characters) will be encoded via - [punycode](https://en.wikipedia.org/wiki/Punycode) (see - [this article](https://www.xudongz.com/blog/2017/idn-phishing/) for a good description of why this is important): - - ```python - from pydantic import BaseModel, HttpUrl - - class MyModel(BaseModel): - url: HttpUrl - - m1 = MyModel(url='http://puny£code.com') - print(m1.url) - #> http://xn--punycode-eja.com/ - m2 = MyModel(url='https://www.аррӏе.com/') - print(m2.url) - #> https://www.xn--80ak6aa92e.com/ - m3 = MyModel(url='https://www.example.珠宝/') - print(m3.url) - #> https://www.example.xn--pbt977c/ - ``` - - - !!! warning "Underscores in Hostnames" - In Pydantic, underscores are allowed in all parts of a domain except the TLD. - Technically this might be wrong - in theory the hostname cannot have underscores, but subdomains can. - - To explain this; consider the following two cases: - - - `exam_ple.co.uk`: the hostname is `exam_ple`, which should not be allowed since it contains an underscore. - - `foo_bar.example.com` the hostname is `example`, which should be allowed since the underscore is in the subdomain. - - Without having an exhaustive list of TLDs, it would be impossible to differentiate between these two. Therefore - underscores are allowed, but you can always do further validation in a validator if desired. - - Also, Chrome, Firefox, and Safari all currently accept `http://exam_ple.com` as a URL, so we're in good - (or at least big) company. - """ - - _constraints = UrlConstraints(max_length=2083, allowed_schemes=['http', 'https']) - - -class AnyWebsocketUrl(AnyUrl): - """A type that will accept any ws or wss URL. - - * TLD not required - * Host not required - """ - - _constraints = UrlConstraints(allowed_schemes=['ws', 'wss']) - - -class WebsocketUrl(AnyUrl): - """A type that will accept any ws or wss URL. - - * TLD not required - * Host not required - * Max length 2083 - """ - - _constraints = UrlConstraints(max_length=2083, allowed_schemes=['ws', 'wss']) - - -class FileUrl(AnyUrl): - """A type that will accept any file URL. - - * Host not required - """ - - _constraints = UrlConstraints(allowed_schemes=['file']) - - -class FtpUrl(AnyUrl): - """A type that will accept ftp URL. - - * TLD not required - * Host not required - """ - - _constraints = UrlConstraints(allowed_schemes=['ftp']) - - -class PostgresDsn(_BaseMultiHostUrl): - """A type that will accept any Postgres DSN. 
- - * User info required - * TLD not required - * Host required - * Supports multiple hosts - - If further validation is required, these properties can be used by validators to enforce specific behaviour: - - ```python - from pydantic import ( - BaseModel, - HttpUrl, - PostgresDsn, - ValidationError, - field_validator, - ) - - class MyModel(BaseModel): - url: HttpUrl - - m = MyModel(url='http://www.example.com') - - # the repr() method for a url will display all properties of the url - print(repr(m.url)) - #> HttpUrl('http://www.example.com/') - print(m.url.scheme) - #> http - print(m.url.host) - #> www.example.com - print(m.url.port) - #> 80 - - class MyDatabaseModel(BaseModel): - db: PostgresDsn - - @field_validator('db') - def check_db_name(cls, v): - assert v.path and len(v.path) > 1, 'database must be provided' - return v - - m = MyDatabaseModel(db='postgres://user:pass@localhost:5432/foobar') - print(m.db) - #> postgres://user:pass@localhost:5432/foobar - - try: - MyDatabaseModel(db='postgres://user:pass@localhost:5432') - except ValidationError as e: - print(e) - ''' - 1 validation error for MyDatabaseModel - db - Assertion failed, database must be provided - assert (None) - + where None = PostgresDsn('postgres://user:pass@localhost:5432').path [type=assertion_error, input_value='postgres://user:pass@localhost:5432', input_type=str] - ''' - ``` - """ - - _constraints = UrlConstraints( +AnyUrl = Url +"""Base type for all URLs. + +* Any scheme allowed +* Top-level domain (TLD) not required +* Host required + +Assuming an input URL of `http://samuel:pass@example.com:8000/the/path/?query=here#fragment=is;this=bit`, +the types export the following properties: + +- `scheme`: the URL scheme (`http`), always set. +- `host`: the URL host (`example.com`), always set. +- `username`: optional username if included (`samuel`). +- `password`: optional password if included (`pass`). +- `port`: optional port (`8000`). +- `path`: optional path (`/the/path/`). +- `query`: optional URL query (for example, `GET` arguments or "search string", such as `query=here`). +- `fragment`: optional fragment (`fragment=is;this=bit`). +""" +AnyHttpUrl = Annotated[Url, UrlConstraints(allowed_schemes=['http', 'https'])] +"""A type that will accept any http or https URL. + +* TLD not required +* Host required +""" +HttpUrl = Annotated[Url, UrlConstraints(max_length=2083, allowed_schemes=['http', 'https'])] +"""A type that will accept any http or https URL. + +* TLD required +* Host required +* Max length 2083 + +```py +from pydantic import BaseModel, HttpUrl, ValidationError + +class MyModel(BaseModel): + url: HttpUrl + +m = MyModel(url='http://www.example.com') +print(m.url) +#> http://www.example.com/ + +try: + MyModel(url='ftp://invalid.url') +except ValidationError as e: + print(e) + ''' + 1 validation error for MyModel + url + URL scheme should be 'http' or 'https' [type=url_scheme, input_value='ftp://invalid.url', input_type=str] + ''' + +try: + MyModel(url='not a url') +except ValidationError as e: + print(e) + ''' + 1 validation error for MyModel + url + Input should be a valid URL, relative URL without a base [type=url_parsing, input_value='not a url', input_type=str] + ''' +``` + +"International domains" (e.g. 
a URL where the host or TLD includes non-ascii characters) will be encoded via +[punycode](https://en.wikipedia.org/wiki/Punycode) (see +[this article](https://www.xudongz.com/blog/2017/idn-phishing/) for a good description of why this is important): + +```py +from pydantic import BaseModel, HttpUrl + +class MyModel(BaseModel): + url: HttpUrl + +m1 = MyModel(url='http://puny£code.com') +print(m1.url) +#> http://xn--punycode-eja.com/ +m2 = MyModel(url='https://www.аррӏе.com/') +print(m2.url) +#> https://www.xn--80ak6aa92e.com/ +m3 = MyModel(url='https://www.example.珠宝/') +print(m3.url) +#> https://www.example.xn--pbt977c/ +``` + + +!!! warning "Underscores in Hostnames" + In Pydantic, underscores are allowed in all parts of a domain except the TLD. + Technically this might be wrong - in theory the hostname cannot have underscores, but subdomains can. + + To explain this; consider the following two cases: + + - `exam_ple.co.uk`: the hostname is `exam_ple`, which should not be allowed since it contains an underscore. + - `foo_bar.example.com` the hostname is `example`, which should be allowed since the underscore is in the subdomain. + + Without having an exhaustive list of TLDs, it would be impossible to differentiate between these two. Therefore + underscores are allowed, but you can always do further validation in a validator if desired. + + Also, Chrome, Firefox, and Safari all currently accept `http://exam_ple.com` as a URL, so we're in good + (or at least big) company. +""" +FileUrl = Annotated[Url, UrlConstraints(allowed_schemes=['file'])] +"""A type that will accept any file URL. + +* Host not required +""" +PostgresDsn = Annotated[ + MultiHostUrl, + UrlConstraints( host_required=True, allowed_schemes=[ 'postgres', @@ -760,116 +199,119 @@ class PostgresDsn(_BaseMultiHostUrl): 'postgresql+py-postgresql', 'postgresql+pygresql', ], - ) + ), +] +"""A type that will accept any Postgres DSN. - @property - def host(self) -> str: - """The required URL host.""" - return self._url.host # pyright: ignore[reportAttributeAccessIssue] +* User info required +* TLD not required +* Host required +* Supports multiple hosts +If further validation is required, these properties can be used by validators to enforce specific behaviour: -class CockroachDsn(AnyUrl): - """A type that will accept any Cockroach DSN. 
+```py +from pydantic import ( + BaseModel, + HttpUrl, + PostgresDsn, + ValidationError, + field_validator, +) - * User info required - * TLD not required - * Host required - """ +class MyModel(BaseModel): + url: HttpUrl - _constraints = UrlConstraints( +m = MyModel(url='http://www.example.com') + +# the repr() method for a url will display all properties of the url +print(repr(m.url)) +#> Url('http://www.example.com/') +print(m.url.scheme) +#> http +print(m.url.host) +#> www.example.com +print(m.url.port) +#> 80 + +class MyDatabaseModel(BaseModel): + db: PostgresDsn + + @field_validator('db') + def check_db_name(cls, v): + assert v.path and len(v.path) > 1, 'database must be provided' + return v + +m = MyDatabaseModel(db='postgres://user:pass@localhost:5432/foobar') +print(m.db) +#> postgres://user:pass@localhost:5432/foobar + +try: + MyDatabaseModel(db='postgres://user:pass@localhost:5432') +except ValidationError as e: + print(e) + ''' + 1 validation error for MyDatabaseModel + db + Assertion failed, database must be provided + assert (None) + + where None = MultiHostUrl('postgres://user:pass@localhost:5432').path [type=assertion_error, input_value='postgres://user:pass@localhost:5432', input_type=str] + ''' +``` +""" + +CockroachDsn = Annotated[ + Url, + UrlConstraints( host_required=True, allowed_schemes=[ 'cockroachdb', 'cockroachdb+psycopg2', 'cockroachdb+asyncpg', ], - ) + ), +] +"""A type that will accept any Cockroach DSN. - @property - def host(self) -> str: - """The required URL host.""" - return self._url.host # pyright: ignore[reportReturnType] +* User info required +* TLD not required +* Host required +""" +AmqpDsn = Annotated[Url, UrlConstraints(allowed_schemes=['amqp', 'amqps'])] +"""A type that will accept any AMQP DSN. +* User info required +* TLD not required +* Host required +""" +RedisDsn = Annotated[ + Url, + UrlConstraints(allowed_schemes=['redis', 'rediss'], default_host='localhost', default_port=6379, default_path='/0'), +] +"""A type that will accept any Redis DSN. -class AmqpDsn(AnyUrl): - """A type that will accept any AMQP DSN. +* User info required +* TLD not required +* Host required (e.g., `rediss://:pass@localhost`) +""" +MongoDsn = Annotated[MultiHostUrl, UrlConstraints(allowed_schemes=['mongodb', 'mongodb+srv'], default_port=27017)] +"""A type that will accept any MongoDB DSN. - * User info required - * TLD not required - * Host not required - """ +* User info not required +* Database name not required +* Port not required +* User info may be passed without user part (e.g., `mongodb://mongodb0.example.com:27017`). +""" +KafkaDsn = Annotated[Url, UrlConstraints(allowed_schemes=['kafka'], default_host='localhost', default_port=9092)] +"""A type that will accept any Kafka DSN. - _constraints = UrlConstraints(allowed_schemes=['amqp', 'amqps']) - - -class RedisDsn(AnyUrl): - """A type that will accept any Redis DSN. - - * User info required - * TLD not required - * Host required (e.g., `rediss://:pass@localhost`) - """ - - _constraints = UrlConstraints( - allowed_schemes=['redis', 'rediss'], - default_host='localhost', - default_port=6379, - default_path='/0', - host_required=True, - ) - - @property - def host(self) -> str: - """The required URL host.""" - return self._url.host # pyright: ignore[reportReturnType] - - -class MongoDsn(_BaseMultiHostUrl): - """A type that will accept any MongoDB DSN. 
- - * User info not required - * Database name not required - * Port not required - * User info may be passed without user part (e.g., `mongodb://mongodb0.example.com:27017`). - """ - - _constraints = UrlConstraints(allowed_schemes=['mongodb', 'mongodb+srv'], default_port=27017) - - -class KafkaDsn(AnyUrl): - """A type that will accept any Kafka DSN. - - * User info required - * TLD not required - * Host not required - """ - - _constraints = UrlConstraints(allowed_schemes=['kafka'], default_host='localhost', default_port=9092) - - -class NatsDsn(_BaseMultiHostUrl): - """A type that will accept any NATS DSN. - - NATS is a connective technology built for the ever increasingly hyper-connected world. - It is a single technology that enables applications to securely communicate across - any combination of cloud vendors, on-premise, edge, web and mobile, and devices. - More: https://nats.io - """ - - _constraints = UrlConstraints( - allowed_schemes=['nats', 'tls', 'ws', 'wss'], default_host='localhost', default_port=4222 - ) - - -class MySQLDsn(AnyUrl): - """A type that will accept any MySQL DSN. - - * User info required - * TLD not required - * Host not required - """ - - _constraints = UrlConstraints( +* User info required +* TLD not required +* Host required +""" +MySQLDsn = Annotated[ + Url, + UrlConstraints( allowed_schemes=[ 'mysql', 'mysql+mysqlconnector', @@ -881,63 +323,27 @@ class MySQLDsn(AnyUrl): 'mysql+pyodbc', ], default_port=3306, - host_required=True, - ) + ), +] +"""A type that will accept any MySQL DSN. - -class MariaDBDsn(AnyUrl): - """A type that will accept any MariaDB DSN. - - * User info required - * TLD not required - * Host not required - """ - - _constraints = UrlConstraints( +* User info required +* TLD not required +* Host required +""" +MariaDBDsn = Annotated[ + Url, + UrlConstraints( allowed_schemes=['mariadb', 'mariadb+mariadbconnector', 'mariadb+pymysql'], default_port=3306, - ) + ), +] +"""A type that will accept any MariaDB DSN. - -class ClickHouseDsn(AnyUrl): - """A type that will accept any ClickHouse DSN. - - * User info required - * TLD not required - * Host not required - """ - - _constraints = UrlConstraints( - allowed_schemes=[ - 'clickhouse+native', - 'clickhouse+asynch', - 'clickhouse+http', - 'clickhouse', - 'clickhouses', - 'clickhousedb', - ], - default_host='localhost', - default_port=9000, - ) - - -class SnowflakeDsn(AnyUrl): - """A type that will accept any Snowflake DSN. - - * User info required - * TLD not required - * Host required - """ - - _constraints = UrlConstraints( - allowed_schemes=['snowflake'], - host_required=True, - ) - - @property - def host(self) -> str: - """The required URL host.""" - return self._url.host # pyright: ignore[reportReturnType] +* User info required +* TLD not required +* Host required +""" def import_email_validator() -> None: @@ -946,8 +352,6 @@ def import_email_validator() -> None: import email_validator except ImportError as e: raise ImportError('email-validator is not installed, run `pip install pydantic[email]`') from e - if not version('email-validator').partition('.')[0] == '2': - raise ImportError('email-validator version >= 2.0 required, run pip install -U email-validator') if TYPE_CHECKING: @@ -966,7 +370,7 @@ else: Validate email addresses. 
- ```python + ```py from pydantic import BaseModel, EmailStr class Model(BaseModel): @@ -995,8 +399,8 @@ else: return field_schema @classmethod - def _validate(cls, input_value: str, /) -> str: - return validate_email(input_value)[1] + def _validate(cls, __input_value: str) -> str: + return validate_email(__input_value)[1] class NameEmail(_repr.Representation): @@ -1015,7 +419,7 @@ class NameEmail(_repr.Representation): The `NameEmail` has two properties: `name` and `email`. In case the `name` is not provided, it's inferred from the email address. - ```python + ```py from pydantic import BaseModel, NameEmail class User(BaseModel): @@ -1059,197 +463,182 @@ class NameEmail(_repr.Representation): _handler: GetCoreSchemaHandler, ) -> core_schema.CoreSchema: import_email_validator() - return core_schema.no_info_after_validator_function( cls._validate, - core_schema.json_or_python_schema( - json_schema=core_schema.str_schema(), - python_schema=core_schema.union_schema( - [core_schema.is_instance_schema(cls), core_schema.str_schema()], - custom_error_type='name_email_type', - custom_error_message='Input is not a valid NameEmail', - ), - serialization=core_schema.to_string_ser_schema(), + core_schema.union_schema( + [core_schema.is_instance_schema(cls), core_schema.str_schema()], + custom_error_type='name_email_type', + custom_error_message='Input is not a valid NameEmail', ), + serialization=core_schema.to_string_ser_schema(), ) @classmethod - def _validate(cls, input_value: Self | str, /) -> Self: - if isinstance(input_value, str): - name, email = validate_email(input_value) - return cls(name, email) + def _validate(cls, __input_value: NameEmail | str) -> NameEmail: + if isinstance(__input_value, cls): + return __input_value else: - return input_value + name, email = validate_email(__input_value) # type: ignore[arg-type] + return cls(name, email) def __str__(self) -> str: - if '@' in self.name: - return f'"{self.name}" <{self.email}>' - return f'{self.name} <{self.email}>' -IPvAnyAddressType: TypeAlias = 'IPv4Address | IPv6Address' -IPvAnyInterfaceType: TypeAlias = 'IPv4Interface | IPv6Interface' -IPvAnyNetworkType: TypeAlias = 'IPv4Network | IPv6Network' +class IPvAnyAddress: + """Validate an IPv4 or IPv6 address. -if TYPE_CHECKING: - IPvAnyAddress = IPvAnyAddressType - IPvAnyInterface = IPvAnyInterfaceType - IPvAnyNetwork = IPvAnyNetworkType -else: + ```py + from pydantic import BaseModel + from pydantic.networks import IPvAnyAddress - class IPvAnyAddress: - """Validate an IPv4 or IPv6 address. 
+ class IpModel(BaseModel): + ip: IPvAnyAddress - ```python - from pydantic import BaseModel - from pydantic.networks import IPvAnyAddress + print(IpModel(ip='127.0.0.1')) + #> ip=IPv4Address('127.0.0.1') - class IpModel(BaseModel): - ip: IPvAnyAddress + try: + IpModel(ip='http://www.example.com') + except ValueError as e: + print(e.errors()) + ''' + [ + { + 'type': 'ip_any_address', + 'loc': ('ip',), + 'msg': 'value is not a valid IPv4 or IPv6 address', + 'input': 'http://www.example.com', + } + ] + ''' + ``` + """ - print(IpModel(ip='127.0.0.1')) - #> ip=IPv4Address('127.0.0.1') + __slots__ = () + + def __new__(cls, value: Any) -> IPv4Address | IPv6Address: + """Validate an IPv4 or IPv6 address.""" + try: + return IPv4Address(value) + except ValueError: + pass try: - IpModel(ip='http://www.example.com') - except ValueError as e: - print(e.errors()) - ''' - [ - { - 'type': 'ip_any_address', - 'loc': ('ip',), - 'msg': 'value is not a valid IPv4 or IPv6 address', - 'input': 'http://www.example.com', - } - ] - ''' - ``` - """ + return IPv6Address(value) + except ValueError: + raise PydanticCustomError('ip_any_address', 'value is not a valid IPv4 or IPv6 address') - __slots__ = () + @classmethod + def __get_pydantic_json_schema__( + cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler + ) -> JsonSchemaValue: + field_schema = {} + field_schema.update(type='string', format='ipvanyaddress') + return field_schema - def __new__(cls, value: Any) -> IPvAnyAddressType: - """Validate an IPv4 or IPv6 address.""" - try: - return IPv4Address(value) - except ValueError: - pass + @classmethod + def __get_pydantic_core_schema__( + cls, + _source: type[Any], + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + return core_schema.no_info_plain_validator_function( + cls._validate, serialization=core_schema.to_string_ser_schema() + ) - try: - return IPv6Address(value) - except ValueError: - raise PydanticCustomError('ip_any_address', 'value is not a valid IPv4 or IPv6 address') + @classmethod + def _validate(cls, __input_value: Any) -> IPv4Address | IPv6Address: + return cls(__input_value) # type: ignore[return-value] - @classmethod - def __get_pydantic_json_schema__( - cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler - ) -> JsonSchemaValue: - field_schema = {} - field_schema.update(type='string', format='ipvanyaddress') - return field_schema - @classmethod - def __get_pydantic_core_schema__( - cls, - _source: type[Any], - _handler: GetCoreSchemaHandler, - ) -> core_schema.CoreSchema: - return core_schema.no_info_plain_validator_function( - cls._validate, serialization=core_schema.to_string_ser_schema() - ) +class IPvAnyInterface: + """Validate an IPv4 or IPv6 interface.""" - @classmethod - def _validate(cls, input_value: Any, /) -> IPvAnyAddressType: - return cls(input_value) # type: ignore[return-value] + __slots__ = () - class IPvAnyInterface: + def __new__(cls, value: NetworkType) -> IPv4Interface | IPv6Interface: """Validate an IPv4 or IPv6 interface.""" + try: + return IPv4Interface(value) + except ValueError: + pass - __slots__ = () + try: + return IPv6Interface(value) + except ValueError: + raise PydanticCustomError('ip_any_interface', 'value is not a valid IPv4 or IPv6 interface') - def __new__(cls, value: NetworkType) -> IPvAnyInterfaceType: - """Validate an IPv4 or IPv6 interface.""" - try: - return IPv4Interface(value) - except ValueError: - pass + @classmethod + def 
__get_pydantic_json_schema__( + cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler + ) -> JsonSchemaValue: + field_schema = {} + field_schema.update(type='string', format='ipvanyinterface') + return field_schema - try: - return IPv6Interface(value) - except ValueError: - raise PydanticCustomError('ip_any_interface', 'value is not a valid IPv4 or IPv6 interface') + @classmethod + def __get_pydantic_core_schema__( + cls, + _source: type[Any], + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + return core_schema.no_info_plain_validator_function( + cls._validate, serialization=core_schema.to_string_ser_schema() + ) - @classmethod - def __get_pydantic_json_schema__( - cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler - ) -> JsonSchemaValue: - field_schema = {} - field_schema.update(type='string', format='ipvanyinterface') - return field_schema + @classmethod + def _validate(cls, __input_value: NetworkType) -> IPv4Interface | IPv6Interface: + return cls(__input_value) # type: ignore[return-value] - @classmethod - def __get_pydantic_core_schema__( - cls, - _source: type[Any], - _handler: GetCoreSchemaHandler, - ) -> core_schema.CoreSchema: - return core_schema.no_info_plain_validator_function( - cls._validate, serialization=core_schema.to_string_ser_schema() - ) - @classmethod - def _validate(cls, input_value: NetworkType, /) -> IPvAnyInterfaceType: - return cls(input_value) # type: ignore[return-value] +class IPvAnyNetwork: + """Validate an IPv4 or IPv6 network.""" - class IPvAnyNetwork: + __slots__ = () + + def __new__(cls, value: NetworkType) -> IPv4Network | IPv6Network: """Validate an IPv4 or IPv6 network.""" + # Assume IP Network is defined with a default value for `strict` argument. + # Define your own class if you want to specify network address check strictness. + try: + return IPv4Network(value) + except ValueError: + pass - __slots__ = () + try: + return IPv6Network(value) + except ValueError: + raise PydanticCustomError('ip_any_network', 'value is not a valid IPv4 or IPv6 network') - def __new__(cls, value: NetworkType) -> IPvAnyNetworkType: - """Validate an IPv4 or IPv6 network.""" - # Assume IP Network is defined with a default value for `strict` argument. - # Define your own class if you want to specify network address check strictness. 
- try: - return IPv4Network(value) - except ValueError: - pass + @classmethod + def __get_pydantic_json_schema__( + cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler + ) -> JsonSchemaValue: + field_schema = {} + field_schema.update(type='string', format='ipvanynetwork') + return field_schema - try: - return IPv6Network(value) - except ValueError: - raise PydanticCustomError('ip_any_network', 'value is not a valid IPv4 or IPv6 network') + @classmethod + def __get_pydantic_core_schema__( + cls, + _source: type[Any], + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + return core_schema.no_info_plain_validator_function( + cls._validate, serialization=core_schema.to_string_ser_schema() + ) - @classmethod - def __get_pydantic_json_schema__( - cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler - ) -> JsonSchemaValue: - field_schema = {} - field_schema.update(type='string', format='ipvanynetwork') - return field_schema - - @classmethod - def __get_pydantic_core_schema__( - cls, - _source: type[Any], - _handler: GetCoreSchemaHandler, - ) -> core_schema.CoreSchema: - return core_schema.no_info_plain_validator_function( - cls._validate, serialization=core_schema.to_string_ser_schema() - ) - - @classmethod - def _validate(cls, input_value: NetworkType, /) -> IPvAnyNetworkType: - return cls(input_value) # type: ignore[return-value] + @classmethod + def _validate(cls, __input_value: NetworkType) -> IPv4Network | IPv6Network: + return cls(__input_value) # type: ignore[return-value] def _build_pretty_email_regex() -> re.Pattern[str]: name_chars = r'[\w!#$%&\'*+\-/=?^_`{|}~]' - unquoted_name_group = rf'((?:{name_chars}+\s+)*{name_chars}+)' + unquoted_name_group = fr'((?:{name_chars}+\s+)*{name_chars}+)' quoted_name_group = r'"((?:[^"]|\")+)"' - email_group = r'<(.+)>' + email_group = r'<\s*(.+)\s*>' return re.compile(rf'\s*(?:{unquoted_name_group}|{quoted_name_group})?\s*{email_group}\s*') @@ -1264,13 +653,6 @@ A somewhat arbitrary but very generous number compared to what is allowed by mos def validate_email(value: str) -> tuple[str, str]: """Email address validation using [email-validator](https://pypi.org/project/email-validator/). - Returns: - A tuple containing the local part of the email (or the name for "pretty" email addresses) - and the normalized email. - - Raises: - PydanticCustomError: If the email is invalid. - Note: Note that: diff --git a/venv/lib/python3.12/site-packages/pydantic/parse.py b/venv/lib/python3.12/site-packages/pydantic/parse.py index 68b7f04..ceee634 100644 --- a/venv/lib/python3.12/site-packages/pydantic/parse.py +++ b/venv/lib/python3.12/site-packages/pydantic/parse.py @@ -1,5 +1,4 @@ """The `parse` module is a backport module from V1.""" - from ._migration import getattr_migration __getattr__ = getattr_migration(__name__) diff --git a/venv/lib/python3.12/site-packages/pydantic/plugin/__init__.py b/venv/lib/python3.12/site-packages/pydantic/plugin/__init__.py index 3620305..78d2271 100644 --- a/venv/lib/python3.12/site-packages/pydantic/plugin/__init__.py +++ b/venv/lib/python3.12/site-packages/pydantic/plugin/__init__.py @@ -1,12 +1,10 @@ -"""!!! abstract "Usage Documentation" - [Build a Plugin](../concepts/plugins.md#build-a-plugin) +"""Usage docs: https://docs.pydantic.dev/2.4/concepts/plugins#build-a-plugin Plugin interface for Pydantic plugins, and related types. 
""" - from __future__ import annotations -from typing import Any, Callable, Literal, NamedTuple +from typing import Any, Callable from pydantic_core import CoreConfig, CoreSchema, ValidationError from typing_extensions import Protocol, TypeAlias @@ -18,32 +16,17 @@ __all__ = ( 'ValidateJsonHandlerProtocol', 'ValidateStringsHandlerProtocol', 'NewSchemaReturns', - 'SchemaTypePath', - 'SchemaKind', ) NewSchemaReturns: TypeAlias = 'tuple[ValidatePythonHandlerProtocol | None, ValidateJsonHandlerProtocol | None, ValidateStringsHandlerProtocol | None]' -class SchemaTypePath(NamedTuple): - """Path defining where `schema_type` was defined, or where `TypeAdapter` was called.""" - - module: str - name: str - - -SchemaKind: TypeAlias = Literal['BaseModel', 'TypeAdapter', 'dataclass', 'create_model', 'validate_call'] - - class PydanticPluginProtocol(Protocol): """Protocol defining the interface for Pydantic plugins.""" def new_schema_validator( self, schema: CoreSchema, - schema_type: Any, - schema_type_path: SchemaTypePath, - schema_kind: SchemaKind, config: CoreConfig | None, plugin_settings: dict[str, object], ) -> tuple[ @@ -57,9 +40,6 @@ class PydanticPluginProtocol(Protocol): Args: schema: The schema to validate against. - schema_type: The original type which the schema was created from, e.g. the model class. - schema_type_path: Path defining where `schema_type` was defined, or where `TypeAdapter` was called. - schema_kind: The kind of schema to validate against. config: The config to use for validation. plugin_settings: Any plugin settings. @@ -96,14 +76,6 @@ class BaseValidateHandlerProtocol(Protocol): """ return - def on_exception(self, exception: Exception) -> None: - """Callback to be notified of validation exceptions. - - Args: - exception: The exception raised during validation. - """ - return - class ValidatePythonHandlerProtocol(BaseValidateHandlerProtocol, Protocol): """Event handler for `SchemaValidator.validate_python`.""" @@ -116,8 +88,6 @@ class ValidatePythonHandlerProtocol(BaseValidateHandlerProtocol, Protocol): from_attributes: bool | None = None, context: dict[str, Any] | None = None, self_instance: Any | None = None, - by_alias: bool | None = None, - by_name: bool | None = None, ) -> None: """Callback to be notified of validation start, and create an instance of the event handler. @@ -128,8 +98,6 @@ class ValidatePythonHandlerProtocol(BaseValidateHandlerProtocol, Protocol): context: The context to use for validation, this is passed to functional validators. self_instance: An instance of a model to set attributes on from validation, this is used when running validation from the `__init__` method of a model. - by_alias: Whether to use the field's alias to match the input data to an attribute. - by_name: Whether to use the field's name to match the input data to an attribute. """ pass @@ -144,8 +112,6 @@ class ValidateJsonHandlerProtocol(BaseValidateHandlerProtocol, Protocol): strict: bool | None = None, context: dict[str, Any] | None = None, self_instance: Any | None = None, - by_alias: bool | None = None, - by_name: bool | None = None, ) -> None: """Callback to be notified of validation start, and create an instance of the event handler. @@ -155,8 +121,6 @@ class ValidateJsonHandlerProtocol(BaseValidateHandlerProtocol, Protocol): context: The context to use for validation, this is passed to functional validators. self_instance: An instance of a model to set attributes on from validation, this is used when running validation from the `__init__` method of a model. 
- by_alias: Whether to use the field's alias to match the input data to an attribute. - by_name: Whether to use the field's name to match the input data to an attribute. """ pass @@ -168,13 +132,7 @@ class ValidateStringsHandlerProtocol(BaseValidateHandlerProtocol, Protocol): """Event handler for `SchemaValidator.validate_strings`.""" def on_enter( - self, - input: StringInput, - *, - strict: bool | None = None, - context: dict[str, Any] | None = None, - by_alias: bool | None = None, - by_name: bool | None = None, + self, input: StringInput, *, strict: bool | None = None, context: dict[str, Any] | None = None ) -> None: """Callback to be notified of validation start, and create an instance of the event handler. @@ -182,7 +140,5 @@ class ValidateStringsHandlerProtocol(BaseValidateHandlerProtocol, Protocol): input: The string data to be validated. strict: Whether to validate the object in strict mode. context: The context to use for validation, this is passed to functional validators. - by_alias: Whether to use the field's alias to match the input data to an attribute. - by_name: Whether to use the field's name to match the input data to an attribute. """ pass diff --git a/venv/lib/python3.12/site-packages/pydantic/plugin/_loader.py b/venv/lib/python3.12/site-packages/pydantic/plugin/_loader.py index 7d1f0f2..b30143b 100644 --- a/venv/lib/python3.12/site-packages/pydantic/plugin/_loader.py +++ b/venv/lib/python3.12/site-packages/pydantic/plugin/_loader.py @@ -1,10 +1,16 @@ from __future__ import annotations -import importlib.metadata as importlib_metadata -import os +import sys import warnings -from collections.abc import Iterable -from typing import TYPE_CHECKING, Final +from typing import TYPE_CHECKING, Iterable + +from typing_extensions import Final + +if sys.version_info >= (3, 8): + import importlib.metadata as importlib_metadata +else: + import importlib_metadata + if TYPE_CHECKING: from . 
import PydanticPluginProtocol @@ -24,13 +30,10 @@ def get_plugins() -> Iterable[PydanticPluginProtocol]: Inspired by: https://github.com/pytest-dev/pluggy/blob/1.3.0/src/pluggy/_manager.py#L376-L402 """ - disabled_plugins = os.getenv('PYDANTIC_DISABLE_PLUGINS') global _plugins, _loading_plugins if _loading_plugins: # this happens when plugins themselves use pydantic, we return no plugins return () - elif disabled_plugins in ('__all__', '1', 'true'): - return () elif _plugins is None: _plugins = {} # set _loading_plugins so any plugins that use pydantic don't themselves use plugins @@ -42,8 +45,6 @@ def get_plugins() -> Iterable[PydanticPluginProtocol]: continue if entry_point.value in _plugins: continue - if disabled_plugins is not None and entry_point.name in disabled_plugins.split(','): - continue try: _plugins[entry_point.value] = entry_point.load() except (ImportError, AttributeError) as e: diff --git a/venv/lib/python3.12/site-packages/pydantic/plugin/_schema_validator.py b/venv/lib/python3.12/site-packages/pydantic/plugin/_schema_validator.py index 83f2562..2ab16ce 100644 --- a/venv/lib/python3.12/site-packages/pydantic/plugin/_schema_validator.py +++ b/venv/lib/python3.12/site-packages/pydantic/plugin/_schema_validator.py @@ -1,16 +1,14 @@ """Pluggable schema validator for pydantic.""" - from __future__ import annotations import functools -from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar from pydantic_core import CoreConfig, CoreSchema, SchemaValidator, ValidationError -from typing_extensions import ParamSpec +from typing_extensions import Literal, ParamSpec if TYPE_CHECKING: - from . import BaseValidateHandlerProtocol, PydanticPluginProtocol, SchemaKind, SchemaTypePath + from . import BaseValidateHandlerProtocol, PydanticPluginProtocol P = ParamSpec('P') @@ -20,33 +18,18 @@ events: list[Event] = list(Event.__args__) # type: ignore def create_schema_validator( - schema: CoreSchema, - schema_type: Any, - schema_type_module: str, - schema_type_name: str, - schema_kind: SchemaKind, - config: CoreConfig | None = None, - plugin_settings: dict[str, Any] | None = None, -) -> SchemaValidator | PluggableSchemaValidator: + schema: CoreSchema, config: CoreConfig | None = None, plugin_settings: dict[str, Any] | None = None +) -> SchemaValidator: """Create a `SchemaValidator` or `PluggableSchemaValidator` if plugins are installed. Returns: If plugins are installed then return `PluggableSchemaValidator`, otherwise return `SchemaValidator`. """ - from . 
import SchemaTypePath from ._loader import get_plugins plugins = get_plugins() if plugins: - return PluggableSchemaValidator( - schema, - schema_type, - SchemaTypePath(schema_type_module, schema_type_name), - schema_kind, - config, - plugins, - plugin_settings or {}, - ) + return PluggableSchemaValidator(schema, config, plugins, plugin_settings or {}) # type: ignore else: return SchemaValidator(schema, config) @@ -59,9 +42,6 @@ class PluggableSchemaValidator: def __init__( self, schema: CoreSchema, - schema_type: Any, - schema_type_path: SchemaTypePath, - schema_kind: SchemaKind, config: CoreConfig | None, plugins: Iterable[PydanticPluginProtocol], plugin_settings: dict[str, Any], @@ -72,12 +52,7 @@ class PluggableSchemaValidator: json_event_handlers: list[BaseValidateHandlerProtocol] = [] strings_event_handlers: list[BaseValidateHandlerProtocol] = [] for plugin in plugins: - try: - p, j, s = plugin.new_schema_validator( - schema, schema_type, schema_type_path, schema_kind, config, plugin_settings - ) - except TypeError as e: # pragma: no cover - raise TypeError(f'Error using plugin `{plugin.__module__}:{plugin.__class__.__name__}`: {e}') from e + p, j, s = plugin.new_schema_validator(schema, config, plugin_settings) if p is not None: python_event_handlers.append(p) if j is not None: @@ -100,7 +75,6 @@ def build_wrapper(func: Callable[P, R], event_handlers: list[BaseValidateHandler on_enters = tuple(h.on_enter for h in event_handlers if filter_handlers(h, 'on_enter')) on_successes = tuple(h.on_success for h in event_handlers if filter_handlers(h, 'on_success')) on_errors = tuple(h.on_error for h in event_handlers if filter_handlers(h, 'on_error')) - on_exceptions = tuple(h.on_exception for h in event_handlers if filter_handlers(h, 'on_exception')) @functools.wraps(func) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: @@ -113,10 +87,6 @@ def build_wrapper(func: Callable[P, R], event_handlers: list[BaseValidateHandler for on_error_handler in on_errors: on_error_handler(error) raise - except Exception as exception: - for on_exception_handler in on_exceptions: - on_exception_handler(exception) - raise else: for on_success_handler in on_successes: on_success_handler(result) diff --git a/venv/lib/python3.12/site-packages/pydantic/root_model.py b/venv/lib/python3.12/site-packages/pydantic/root_model.py index 8b3ff01..da78831 100644 --- a/venv/lib/python3.12/site-packages/pydantic/root_model.py +++ b/venv/lib/python3.12/site-packages/pydantic/root_model.py @@ -8,33 +8,25 @@ from copy import copy, deepcopy from pydantic_core import PydanticUndefined from . import PydanticUserError -from ._internal import _model_construction, _repr +from ._internal import _repr from .main import BaseModel, _object_setattr if typing.TYPE_CHECKING: - from typing import Any, Literal + from typing import Any - from typing_extensions import Self, dataclass_transform + from typing_extensions import Literal - from .fields import Field as PydanticModelField - from .fields import PrivateAttr as PydanticModelPrivateAttr + Model = typing.TypeVar('Model', bound='BaseModel') - # dataclass_transform could be applied to RootModel directly, but `ModelMetaclass`'s dataclass_transform - # takes priority (at least with pyright). We trick type checkers into thinking we apply dataclass_transform - # on a new metaclass. - @dataclass_transform(kw_only_default=False, field_specifiers=(PydanticModelField, PydanticModelPrivateAttr)) - class _RootModelMetaclass(_model_construction.ModelMetaclass): ... 
-else: - _RootModelMetaclass = _model_construction.ModelMetaclass __all__ = ('RootModel',) + RootModelRootType = typing.TypeVar('RootModelRootType') -class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootModelMetaclass): - """!!! abstract "Usage Documentation" - [`RootModel` and Custom Root Types](../concepts/models.md#rootmodel-and-custom-root-types) +class RootModel(BaseModel, typing.Generic[RootModelRootType]): + """Usage docs: https://docs.pydantic.dev/2.4/concepts/models/#rootmodel-and-custom-root-types A Pydantic `BaseModel` for the root object of the model. @@ -60,7 +52,7 @@ class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootMod ) super().__init_subclass__(**kwargs) - def __init__(self, /, root: RootModelRootType = PydanticUndefined, **data) -> None: # type: ignore + def __init__(__pydantic_self__, root: RootModelRootType = PydanticUndefined, **data) -> None: # type: ignore __tracebackhide__ = True if data: if root is not PydanticUndefined: @@ -68,12 +60,12 @@ class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootMod '"RootModel.__init__" accepts either a single positional argument or arbitrary keyword arguments' ) root = data # type: ignore - self.__pydantic_validator__.validate_python(root, self_instance=self) + __pydantic_self__.__pydantic_validator__.validate_python(root, self_instance=__pydantic_self__) - __init__.__pydantic_base_init__ = True # pyright: ignore[reportFunctionMemberAccess] + __init__.__pydantic_base_init__ = True @classmethod - def model_construct(cls, root: RootModelRootType, _fields_set: set[str] | None = None) -> Self: # type: ignore + def model_construct(cls: type[Model], root: RootModelRootType, _fields_set: set[str] | None = None) -> Model: """Create a new model using the provided root object and update fields set. Args: @@ -98,7 +90,7 @@ class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootMod _object_setattr(self, '__pydantic_fields_set__', state['__pydantic_fields_set__']) _object_setattr(self, '__dict__', state['__dict__']) - def __copy__(self) -> Self: + def __copy__(self: Model) -> Model: """Returns a shallow copy of the model.""" cls = type(self) m = cls.__new__(cls) @@ -106,7 +98,7 @@ class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootMod _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__)) return m - def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Self: + def __deepcopy__(self: Model, memo: dict[int, Any] | None = None) -> Model: """Returns a deep copy of the model.""" cls = type(self) m = cls.__new__(cls) @@ -118,40 +110,30 @@ class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootMod if typing.TYPE_CHECKING: - def model_dump( # type: ignore + def model_dump( self, *, mode: Literal['json', 'python'] | str = 'python', include: Any = None, exclude: Any = None, - context: dict[str, Any] | None = None, - by_alias: bool | None = None, + by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, - warnings: bool | Literal['none', 'warn', 'error'] = True, - serialize_as_any: bool = False, - ) -> Any: + warnings: bool = True, + ) -> RootModelRootType: """This method is included just to get a more accurate return type for type checkers. It is included in this `if TYPE_CHECKING:` block since no override is actually necessary. 
See the documentation of `BaseModel.model_dump` for more details about the arguments. - - Generally, this method will have a return type of `RootModelRootType`, assuming that `RootModelRootType` is - not a `BaseModel` subclass. If `RootModelRootType` is a `BaseModel` subclass, then the return - type will likely be `dict[str, Any]`, as `model_dump` calls are recursive. The return type could - even be something different, in the case of a custom serializer. - Thus, `Any` is used here to catch all of these cases. """ ... def __eq__(self, other: Any) -> bool: if not isinstance(other, RootModel): return NotImplemented - return self.__pydantic_fields__['root'].annotation == other.__pydantic_fields__[ - 'root' - ].annotation and super().__eq__(other) + return self.model_fields['root'].annotation == other.model_fields['root'].annotation and super().__eq__(other) def __repr_args__(self) -> _repr.ReprArgs: yield 'root', self.root diff --git a/venv/lib/python3.12/site-packages/pydantic/schema.py b/venv/lib/python3.12/site-packages/pydantic/schema.py index a3245a6..e290aed 100644 --- a/venv/lib/python3.12/site-packages/pydantic/schema.py +++ b/venv/lib/python3.12/site-packages/pydantic/schema.py @@ -1,5 +1,4 @@ """The `schema` module is a backport module from V1.""" - from ._migration import getattr_migration __getattr__ = getattr_migration(__name__) diff --git a/venv/lib/python3.12/site-packages/pydantic/tools.py b/venv/lib/python3.12/site-packages/pydantic/tools.py index fdc68c4..8e317c9 100644 --- a/venv/lib/python3.12/site-packages/pydantic/tools.py +++ b/venv/lib/python3.12/site-packages/pydantic/tools.py @@ -1,5 +1,4 @@ """The `tools` module is a backport module from V1.""" - from ._migration import getattr_migration __getattr__ = getattr_migration(__name__) diff --git a/venv/lib/python3.12/site-packages/pydantic/type_adapter.py b/venv/lib/python3.12/site-packages/pydantic/type_adapter.py index a6cdaba..4ee100f 100644 --- a/venv/lib/python3.12/site-packages/pydantic/type_adapter.py +++ b/venv/lib/python3.12/site-packages/pydantic/type_adapter.py @@ -1,30 +1,93 @@ -"""Type adapter specification.""" +""" +You may have types that are not `BaseModel`s that you want to validate data against. +Or you may want to validate a `List[SomeModel]`, or dump it to JSON. +For use cases like this, Pydantic provides [`TypeAdapter`][pydantic.type_adapter.TypeAdapter], +which can be used for type validation, serialization, and JSON schema generation without creating a +[`BaseModel`][pydantic.main.BaseModel]. 
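As a quick illustration of those three capabilities on a plain annotated type, a minimal sketch (the `#>` comments show the expected output):

```py
from typing import List

from pydantic import TypeAdapter

int_list_adapter = TypeAdapter(List[int])

# Validation: lax mode coerces numeric strings to ints.
print(int_list_adapter.validate_python(['1', '2', '3']))
#> [1, 2, 3]

# Serialization: dump_json returns bytes.
print(int_list_adapter.dump_json([1, 2, 3]))
#> b'[1,2,3]'

# JSON schema generation without a BaseModel.
print(int_list_adapter.json_schema())
#> {'items': {'type': 'integer'}, 'type': 'array'}
```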
+ +A [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] instance exposes some of the functionality from +[`BaseModel`][pydantic.main.BaseModel] instance methods for types that do not have such methods +(such as dataclasses, primitive types, and more): + +```py +from typing import List + +from typing_extensions import TypedDict + +from pydantic import TypeAdapter, ValidationError + +class User(TypedDict): + name: str + id: int + +UserListValidator = TypeAdapter(List[User]) +print(repr(UserListValidator.validate_python([{'name': 'Fred', 'id': '3'}]))) +#> [{'name': 'Fred', 'id': 3}] + +try: + UserListValidator.validate_python( + [{'name': 'Fred', 'id': 'wrong', 'other': 'no'}] + ) +except ValidationError as e: + print(e) + ''' + 1 validation error for list[typed-dict] + 0.id + Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='wrong', input_type=str] + ''' +``` + +Note: + Despite some overlap in use cases with [`RootModel`][pydantic.root_model.RootModel], + [`TypeAdapter`][pydantic.type_adapter.TypeAdapter] should not be used as a type annotation for + specifying fields of a `BaseModel`, etc. + +## Parsing data into a specified type + +[`TypeAdapter`][pydantic.type_adapter.TypeAdapter] can be used to apply the parsing logic to populate Pydantic models +in a more ad-hoc way. This function behaves similarly to +[`BaseModel.model_validate`][pydantic.main.BaseModel.model_validate], +but works with arbitrary Pydantic-compatible types. + +This is especially useful when you want to parse results into a type that is not a direct subclass of +[`BaseModel`][pydantic.main.BaseModel]. For example: + +```py +from typing import List + +from pydantic import BaseModel, TypeAdapter + +class Item(BaseModel): + id: int + name: str + +# `item_data` could come from an API call, eg., via something like: +# item_data = requests.get('https://my-api.com/items').json() +item_data = [{'id': 1, 'name': 'My Item'}] + +items = TypeAdapter(List[Item]).validate_python(item_data) +print(items) +#> [Item(id=1, name='My Item')] +``` + +[`TypeAdapter`][pydantic.type_adapter.TypeAdapter] is capable of parsing data into any of the types Pydantic can +handle as fields of a [`BaseModel`][pydantic.main.BaseModel]. 
+""" # noqa: D212 from __future__ import annotations as _annotations import sys -from collections.abc import Callable, Iterable from dataclasses import is_dataclass -from types import FrameType -from typing import ( - Any, - Generic, - Literal, - TypeVar, - cast, - final, - overload, -) +from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Set, TypeVar, Union, overload from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator, Some -from typing_extensions import ParamSpec, is_typeddict +from typing_extensions import Literal, is_typeddict from pydantic.errors import PydanticUserError -from pydantic.main import BaseModel, IncEx +from pydantic.main import BaseModel -from ._internal import _config, _generate_schema, _mock_val_ser, _namespace_utils, _repr, _typing_extra, _utils +from ._internal import _config, _core_utils, _discriminated_union, _generate_schema, _typing_extra from .config import ConfigDict -from .errors import PydanticUndefinedAnnotation from .json_schema import ( DEFAULT_REF_TEMPLATE, GenerateJsonSchema, @@ -32,12 +95,67 @@ from .json_schema import ( JsonSchemaMode, JsonSchemaValue, ) -from .plugin._schema_validator import PluggableSchemaValidator, create_schema_validator +from .plugin._schema_validator import create_schema_validator T = TypeVar('T') -R = TypeVar('R') -P = ParamSpec('P') -TypeAdapterT = TypeVar('TypeAdapterT', bound='TypeAdapter') + +if TYPE_CHECKING: + # should be `set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None`, but mypy can't cope + IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]] + + +def _get_schema(type_: Any, config_wrapper: _config.ConfigWrapper, parent_depth: int) -> CoreSchema: + """`BaseModel` uses its own `__module__` to find out where it was defined + and then look for symbols to resolve forward references in those globals. + On the other hand this function can be called with arbitrary objects, + including type aliases where `__module__` (always `typing.py`) is not useful. + So instead we look at the globals in our parent stack frame. + + This works for the case where this function is called in a module that + has the target of forward references in its scope, but + does not work for more complex cases. + + For example, take the following: + + a.py + ```python + from typing import Dict, List + + IntList = List[int] + OuterDict = Dict[str, 'IntList'] + ``` + + b.py + ```python test="skip" + from a import OuterDict + + from pydantic import TypeAdapter + + IntList = int # replaces the symbol the forward reference is looking for + v = TypeAdapter(OuterDict) + v({'x': 1}) # should fail but doesn't + ``` + + If OuterDict were a `BaseModel`, this would work because it would resolve + the forward reference within the `a.py` namespace. + But `TypeAdapter(OuterDict)` + can't know what module OuterDict came from. + + In other words, the assumption that _all_ forward references exist in the + module we are being called from is not technically always true. + Although most of the time it is and it works fine for recursive models and such, + `BaseModel`'s behavior isn't perfect either and _can_ break in similar ways, + so there is no right or wrong between the two. + + But at the very least this behavior is _subtly_ different from `BaseModel`'s. 
+ """ + local_ns = _typing_extra.parent_frame_namespace(parent_depth=parent_depth) + global_ns = sys._getframe(max(parent_depth - 1, 1)).f_globals.copy() + global_ns.update(local_ns or {}) + gen = _generate_schema.GenerateSchema(config_wrapper, types_namespace=global_ns, typevars_map={}) + schema = gen.generate_schema(type_) + schema = gen.collect_definitions(schema) + return schema def _getattr_no_parents(obj: Any, attribute: str) -> Any: @@ -55,152 +173,59 @@ def _getattr_no_parents(obj: Any, attribute: str) -> Any: raise AttributeError(attribute) -def _type_has_config(type_: Any) -> bool: - """Returns whether the type has config.""" - type_ = _typing_extra.annotated_type(type_) or type_ - try: - return issubclass(type_, BaseModel) or is_dataclass(type_) or is_typeddict(type_) - except TypeError: - # type is not a class - return False - - -@final class TypeAdapter(Generic[T]): - """!!! abstract "Usage Documentation" - [`TypeAdapter`](../concepts/type_adapter.md) - - Type adapters provide a flexible way to perform validation and serialization based on a Python type. + """Type adapters provide a flexible way to perform validation and serialization based on a Python type. A `TypeAdapter` instance exposes some of the functionality from `BaseModel` instance methods for types that do not have such methods (such as dataclasses, primitive types, and more). - **Note:** `TypeAdapter` instances are not types, and cannot be used as type annotations for fields. - - Args: - type: The type associated with the `TypeAdapter`. - config: Configuration for the `TypeAdapter`, should be a dictionary conforming to - [`ConfigDict`][pydantic.config.ConfigDict]. - - !!! note - You cannot provide a configuration when instantiating a `TypeAdapter` if the type you're using - has its own config that cannot be overridden (ex: `BaseModel`, `TypedDict`, and `dataclass`). A - [`type-adapter-config-unused`](../errors/usage_errors.md#type-adapter-config-unused) error will - be raised in this case. - _parent_depth: Depth at which to search for the [parent frame][frame-objects]. This frame is used when - resolving forward annotations during schema building, by looking for the globals and locals of this - frame. Defaults to 2, which will result in the frame where the `TypeAdapter` was instantiated. - - !!! note - This parameter is named with an underscore to suggest its private nature and discourage use. - It may be deprecated in a minor version, so we only recommend using it if you're comfortable - with potential change in behavior/support. It's default value is 2 because internally, - the `TypeAdapter` class makes another call to fetch the frame. - module: The module that passes to plugin if provided. + Note that `TypeAdapter` is not an actual type, so you cannot use it in type annotations. Attributes: core_schema: The core schema for the type. - validator: The schema validator for the type. + validator (SchemaValidator): The schema validator for the type. serializer: The schema serializer for the type. - pydantic_complete: Whether the core schema for the type is successfully built. - - ??? tip "Compatibility with `mypy`" - Depending on the type used, `mypy` might raise an error when instantiating a `TypeAdapter`. As a workaround, you can explicitly - annotate your variable: - - ```py - from typing import Union - - from pydantic import TypeAdapter - - ta: TypeAdapter[Union[str, int]] = TypeAdapter(Union[str, int]) # type: ignore[arg-type] - ``` - - ??? 
info "Namespace management nuances and implementation details" - - Here, we collect some notes on namespace management, and subtle differences from `BaseModel`: - - `BaseModel` uses its own `__module__` to find out where it was defined - and then looks for symbols to resolve forward references in those globals. - On the other hand, `TypeAdapter` can be initialized with arbitrary objects, - which may not be types and thus do not have a `__module__` available. - So instead we look at the globals in our parent stack frame. - - It is expected that the `ns_resolver` passed to this function will have the correct - namespace for the type we're adapting. See the source code for `TypeAdapter.__init__` - and `TypeAdapter.rebuild` for various ways to construct this namespace. - - This works for the case where this function is called in a module that - has the target of forward references in its scope, but - does not always work for more complex cases. - - For example, take the following: - - ```python {title="a.py"} - IntList = list[int] - OuterDict = dict[str, 'IntList'] - ``` - - ```python {test="skip" title="b.py"} - from a import OuterDict - - from pydantic import TypeAdapter - - IntList = int # replaces the symbol the forward reference is looking for - v = TypeAdapter(OuterDict) - v({'x': 1}) # should fail but doesn't - ``` - - If `OuterDict` were a `BaseModel`, this would work because it would resolve - the forward reference within the `a.py` namespace. - But `TypeAdapter(OuterDict)` can't determine what module `OuterDict` came from. - - In other words, the assumption that _all_ forward references exist in the - module we are being called from is not technically always true. - Although most of the time it is and it works fine for recursive models and such, - `BaseModel`'s behavior isn't perfect either and _can_ break in similar ways, - so there is no right or wrong between the two. - - But at the very least this behavior is _subtly_ different from `BaseModel`'s. """ - core_schema: CoreSchema - validator: SchemaValidator | PluggableSchemaValidator - serializer: SchemaSerializer - pydantic_complete: bool + if TYPE_CHECKING: - @overload - def __init__( - self, - type: type[T], - *, - config: ConfigDict | None = ..., - _parent_depth: int = ..., - module: str | None = ..., - ) -> None: ... + @overload + def __new__(cls, __type: type[T], *, config: ConfigDict | None = ...) -> TypeAdapter[T]: + ... - # This second overload is for unsupported special forms (such as Annotated, Union, etc.) - # Currently there is no way to type this correctly - # See https://github.com/python/typing/pull/1618 - @overload - def __init__( - self, - type: Any, - *, - config: ConfigDict | None = ..., - _parent_depth: int = ..., - module: str | None = ..., - ) -> None: ... + # this overload is for non-type things like Union[int, str] + # Pyright currently handles this "correctly", but MyPy understands this as TypeAdapter[object] + # so an explicit type cast is needed + @overload + def __new__(cls, __type: T, *, config: ConfigDict | None = ...) -> TypeAdapter[T]: + ... - def __init__( - self, - type: Any, - *, - config: ConfigDict | None = None, - _parent_depth: int = 2, - module: str | None = None, - ) -> None: - if _type_has_config(type) and config is not None: + def __new__(cls, __type: Any, *, config: ConfigDict | None = ...) 
-> TypeAdapter[T]: + """A class representing the type adapter.""" + raise NotImplementedError + + @overload + def __init__(self, type: type[T], *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None: + ... + + # this overload is for non-type things like Union[int, str] + # Pyright currently handles this "correctly", but MyPy understands this as TypeAdapter[object] + # so an explicit type cast is needed + @overload + def __init__(self, type: T, *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None: + ... + + def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth: int = 2) -> None: + """Initializes the TypeAdapter object.""" + config_wrapper = _config.ConfigWrapper(config) + + try: + type_has_config = issubclass(type, BaseModel) or is_dataclass(type) or is_typeddict(type) + except TypeError: + # type is not a class + type_has_config = False + + if type_has_config and config is not None: raise PydanticUserError( 'Cannot use `config` when the type is a BaseModel, dataclass or TypedDict.' ' These types can have their own config and setting the config via the `config`' @@ -209,313 +234,81 @@ class TypeAdapter(Generic[T]): code='type-adapter-config-unused', ) - self._type = type - self._config = config - self._parent_depth = _parent_depth - self.pydantic_complete = False - - parent_frame = self._fetch_parent_frame() - if parent_frame is not None: - globalns = parent_frame.f_globals - # Do not provide a local ns if the type adapter happens to be instantiated at the module level: - localns = parent_frame.f_locals if parent_frame.f_locals is not globalns else {} - else: - globalns = {} - localns = {} - - self._module_name = module or cast(str, globalns.get('__name__', '')) - self._init_core_attrs( - ns_resolver=_namespace_utils.NsResolver( - namespaces_tuple=_namespace_utils.NamespacesTuple(locals=localns, globals=globalns), - parent_namespace=localns, - ), - force=False, - ) - - def _fetch_parent_frame(self) -> FrameType | None: - frame = sys._getframe(self._parent_depth) - if frame.f_globals.get('__name__') == 'typing': - # Because `TypeAdapter` is generic, explicitly parametrizing the class results - # in a `typing._GenericAlias` instance, which proxies instantiation calls to the - # "real" `TypeAdapter` class and thus adding an extra frame to the call. To avoid - # pulling anything from the `typing` module, use the correct frame (the one before): - return frame.f_back - - return frame - - def _init_core_attrs( - self, ns_resolver: _namespace_utils.NsResolver, force: bool, raise_errors: bool = False - ) -> bool: - """Initialize the core schema, validator, and serializer for the type. - - Args: - ns_resolver: The namespace resolver to use when building the core schema for the adapted type. - force: Whether to force the construction of the core schema, validator, and serializer. - If `force` is set to `False` and `_defer_build` is `True`, the core schema, validator, and serializer will be set to mocks. - raise_errors: Whether to raise errors if initializing any of the core attrs fails. - - Returns: - `True` if the core schema, validator, and serializer were successfully initialized, otherwise `False`. - - Raises: - PydanticUndefinedAnnotation: If `PydanticUndefinedAnnotation` occurs in`__get_pydantic_core_schema__` - and `raise_errors=True`. 
- """ - if not force and self._defer_build: - _mock_val_ser.set_type_adapter_mocks(self) - self.pydantic_complete = False - return False - + core_schema: CoreSchema try: - self.core_schema = _getattr_no_parents(self._type, '__pydantic_core_schema__') - self.validator = _getattr_no_parents(self._type, '__pydantic_validator__') - self.serializer = _getattr_no_parents(self._type, '__pydantic_serializer__') - - # TODO: we don't go through the rebuild logic here directly because we don't want - # to repeat all of the namespace fetching logic that we've already done - # so we simply skip to the block below that does the actual schema generation - if ( - isinstance(self.core_schema, _mock_val_ser.MockCoreSchema) - or isinstance(self.validator, _mock_val_ser.MockValSer) - or isinstance(self.serializer, _mock_val_ser.MockValSer) - ): - raise AttributeError() + core_schema = _getattr_no_parents(type, '__pydantic_core_schema__') except AttributeError: - config_wrapper = _config.ConfigWrapper(self._config) + core_schema = _get_schema(type, config_wrapper, parent_depth=_parent_depth + 1) - schema_generator = _generate_schema.GenerateSchema(config_wrapper, ns_resolver=ns_resolver) + core_schema = _discriminated_union.apply_discriminators(_core_utils.simplify_schema_references(core_schema)) - try: - core_schema = schema_generator.generate_schema(self._type) - except PydanticUndefinedAnnotation: - if raise_errors: - raise - _mock_val_ser.set_type_adapter_mocks(self) - return False + core_schema = _core_utils.validate_core_schema(core_schema) - try: - self.core_schema = schema_generator.clean_schema(core_schema) - except _generate_schema.InvalidSchemaError: - _mock_val_ser.set_type_adapter_mocks(self) - return False + core_config = config_wrapper.core_config(None) + validator: SchemaValidator + try: + validator = _getattr_no_parents(type, '__pydantic_validator__') + except AttributeError: + validator = create_schema_validator(core_schema, core_config, config_wrapper.plugin_settings) - core_config = config_wrapper.core_config(None) + serializer: SchemaSerializer + try: + serializer = _getattr_no_parents(type, '__pydantic_serializer__') + except AttributeError: + serializer = SchemaSerializer(core_schema, core_config) - self.validator = create_schema_validator( - schema=self.core_schema, - schema_type=self._type, - schema_type_module=self._module_name, - schema_type_name=str(self._type), - schema_kind='TypeAdapter', - config=core_config, - plugin_settings=config_wrapper.plugin_settings, - ) - self.serializer = SchemaSerializer(self.core_schema, core_config) - - self.pydantic_complete = True - return True - - @property - def _defer_build(self) -> bool: - config = self._config if self._config is not None else self._model_config - if config: - return config.get('defer_build') is True - return False - - @property - def _model_config(self) -> ConfigDict | None: - type_: Any = _typing_extra.annotated_type(self._type) or self._type # Eg FastAPI heavily uses Annotated - if _utils.lenient_issubclass(type_, BaseModel): - return type_.model_config - return getattr(type_, '__pydantic_config__', None) - - def __repr__(self) -> str: - return f'TypeAdapter({_repr.display_as_type(self._type)})' - - def rebuild( - self, - *, - force: bool = False, - raise_errors: bool = True, - _parent_namespace_depth: int = 2, - _types_namespace: _namespace_utils.MappingNamespace | None = None, - ) -> bool | None: - """Try to rebuild the pydantic-core schema for the adapter's type. 
- - This may be necessary when one of the annotations is a ForwardRef which could not be resolved during - the initial attempt to build the schema, and automatic rebuilding fails. - - Args: - force: Whether to force the rebuilding of the type adapter's schema, defaults to `False`. - raise_errors: Whether to raise errors, defaults to `True`. - _parent_namespace_depth: Depth at which to search for the [parent frame][frame-objects]. This - frame is used when resolving forward annotations during schema rebuilding, by looking for - the locals of this frame. Defaults to 2, which will result in the frame where the method - was called. - _types_namespace: An explicit types namespace to use, instead of using the local namespace - from the parent frame. Defaults to `None`. - - Returns: - Returns `None` if the schema is already "complete" and rebuilding was not required. - If rebuilding _was_ required, returns `True` if rebuilding was successful, otherwise `False`. - """ - if not force and self.pydantic_complete: - return None - - if _types_namespace is not None: - rebuild_ns = _types_namespace - elif _parent_namespace_depth > 0: - rebuild_ns = _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth, force=True) or {} - else: - rebuild_ns = {} - - # we have to manually fetch globals here because there's no type on the stack of the NsResolver - # and so we skip the globalns = get_module_ns_of(typ) call that would normally happen - globalns = sys._getframe(max(_parent_namespace_depth - 1, 1)).f_globals - ns_resolver = _namespace_utils.NsResolver( - namespaces_tuple=_namespace_utils.NamespacesTuple(locals=rebuild_ns, globals=globalns), - parent_namespace=rebuild_ns, - ) - return self._init_core_attrs(ns_resolver=ns_resolver, force=True, raise_errors=raise_errors) + self.core_schema = core_schema + self.validator = validator + self.serializer = serializer def validate_python( self, - object: Any, - /, + __object: Any, *, strict: bool | None = None, from_attributes: bool | None = None, context: dict[str, Any] | None = None, - experimental_allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False, - by_alias: bool | None = None, - by_name: bool | None = None, ) -> T: """Validate a Python object against the model. Args: - object: The Python object to validate against the model. + __object: The Python object to validate against the model. strict: Whether to strictly check types. from_attributes: Whether to extract data from object attributes. context: Additional context to pass to the validator. - experimental_allow_partial: **Experimental** whether to enable - [partial validation](../concepts/experimental.md#partial-validation), e.g. to process streams. - * False / 'off': Default behavior, no partial validation. - * True / 'on': Enable partial validation. - * 'trailing-strings': Enable partial validation and allow trailing strings in the input. - by_alias: Whether to use the field's alias when validating against the provided input data. - by_name: Whether to use the field's name when validating against the provided input data. - - !!! note - When using `TypeAdapter` with a Pydantic `dataclass`, the use of the `from_attributes` - argument is not supported. Returns: The validated object. 
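A brief sketch of the `strict` flag described in the arguments above; the printed values are indicative of current Pydantic 2.x behavior:

```py
from pydantic import TypeAdapter, ValidationError

ta = TypeAdapter(int)

# Lax (default) mode coerces a numeric string.
print(ta.validate_python('1'))
#> 1

# Strict mode rejects the same input.
try:
    ta.validate_python('1', strict=True)
except ValidationError as exc:
    print(exc.errors()[0]['type'])
    #> int_type
```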
""" - if by_alias is False and by_name is not True: - raise PydanticUserError( - 'At least one of `by_alias` or `by_name` must be set to True.', - code='validate-by-alias-and-name-false', - ) - - return self.validator.validate_python( - object, - strict=strict, - from_attributes=from_attributes, - context=context, - allow_partial=experimental_allow_partial, - by_alias=by_alias, - by_name=by_name, - ) + return self.validator.validate_python(__object, strict=strict, from_attributes=from_attributes, context=context) def validate_json( - self, - data: str | bytes | bytearray, - /, - *, - strict: bool | None = None, - context: dict[str, Any] | None = None, - experimental_allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False, - by_alias: bool | None = None, - by_name: bool | None = None, + self, __data: str | bytes, *, strict: bool | None = None, context: dict[str, Any] | None = None ) -> T: - """!!! abstract "Usage Documentation" - [JSON Parsing](../concepts/json.md#json-parsing) - - Validate a JSON string or bytes against the model. + """Validate a JSON string or bytes against the model. Args: - data: The JSON data to validate against the model. + __data: The JSON data to validate against the model. strict: Whether to strictly check types. context: Additional context to use during validation. - experimental_allow_partial: **Experimental** whether to enable - [partial validation](../concepts/experimental.md#partial-validation), e.g. to process streams. - * False / 'off': Default behavior, no partial validation. - * True / 'on': Enable partial validation. - * 'trailing-strings': Enable partial validation and allow trailing strings in the input. - by_alias: Whether to use the field's alias when validating against the provided input data. - by_name: Whether to use the field's name when validating against the provided input data. Returns: The validated object. """ - if by_alias is False and by_name is not True: - raise PydanticUserError( - 'At least one of `by_alias` or `by_name` must be set to True.', - code='validate-by-alias-and-name-false', - ) + return self.validator.validate_json(__data, strict=strict, context=context) - return self.validator.validate_json( - data, - strict=strict, - context=context, - allow_partial=experimental_allow_partial, - by_alias=by_alias, - by_name=by_name, - ) - - def validate_strings( - self, - obj: Any, - /, - *, - strict: bool | None = None, - context: dict[str, Any] | None = None, - experimental_allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False, - by_alias: bool | None = None, - by_name: bool | None = None, - ) -> T: + def validate_strings(self, __obj: Any, *, strict: bool | None = None, context: dict[str, Any] | None = None) -> T: """Validate object contains string data against the model. Args: - obj: The object contains string data to validate. + __obj: The object contains string data to validate. strict: Whether to strictly check types. context: Additional context to use during validation. - experimental_allow_partial: **Experimental** whether to enable - [partial validation](../concepts/experimental.md#partial-validation), e.g. to process streams. - * False / 'off': Default behavior, no partial validation. - * True / 'on': Enable partial validation. - * 'trailing-strings': Enable partial validation and allow trailing strings in the input. - by_alias: Whether to use the field's alias when validating against the provided input data. - by_name: Whether to use the field's name when validating against the provided input data. 
Returns: The validated object. """ - if by_alias is False and by_name is not True: - raise PydanticUserError( - 'At least one of `by_alias` or `by_name` must be set to True.', - code='validate-by-alias-and-name-false', - ) - - return self.validator.validate_strings( - obj, - strict=strict, - context=context, - allow_partial=experimental_allow_partial, - by_alias=by_alias, - by_name=by_name, - ) + return self.validator.validate_strings(__obj, strict=strict, context=context) def get_default_value(self, *, strict: bool | None = None, context: dict[str, Any] | None = None) -> Some[T] | None: """Get the default value for the wrapped type. @@ -531,26 +324,22 @@ class TypeAdapter(Generic[T]): def dump_python( self, - instance: T, - /, + __instance: T, *, mode: Literal['json', 'python'] = 'python', include: IncEx | None = None, exclude: IncEx | None = None, - by_alias: bool | None = None, + by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, - warnings: bool | Literal['none', 'warn', 'error'] = True, - fallback: Callable[[Any], Any] | None = None, - serialize_as_any: bool = False, - context: dict[str, Any] | None = None, + warnings: bool = True, ) -> Any: """Dump an instance of the adapted type to a Python object. Args: - instance: The Python object to serialize. + __instance: The Python object to serialize. mode: The output format. include: Fields to include in the output. exclude: Fields to exclude from the output. @@ -559,18 +348,13 @@ class TypeAdapter(Generic[T]): exclude_defaults: Whether to exclude fields with default values. exclude_none: Whether to exclude fields with None values. round_trip: Whether to output the serialized data in a way that is compatible with deserialization. - warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, - "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. - fallback: A function to call when an unknown value is encountered. If not provided, - a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. - serialize_as_any: Whether to serialize fields with duck-typing serialization behavior. - context: Additional context to pass to the serializer. + warnings: Whether to display serialization warnings. Returns: The serialized object. """ return self.serializer.to_python( - instance, + __instance, mode=mode, by_alias=by_alias, include=include, @@ -580,36 +364,26 @@ class TypeAdapter(Generic[T]): exclude_none=exclude_none, round_trip=round_trip, warnings=warnings, - fallback=fallback, - serialize_as_any=serialize_as_any, - context=context, ) def dump_json( self, - instance: T, - /, + __instance: T, *, indent: int | None = None, include: IncEx | None = None, exclude: IncEx | None = None, - by_alias: bool | None = None, + by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, - warnings: bool | Literal['none', 'warn', 'error'] = True, - fallback: Callable[[Any], Any] | None = None, - serialize_as_any: bool = False, - context: dict[str, Any] | None = None, + warnings: bool = True, ) -> bytes: - """!!! abstract "Usage Documentation" - [JSON Serialization](../concepts/json.md#json-serialization) - - Serialize an instance of the adapted type to JSON. + """Serialize an instance of the adapted type to JSON. Args: - instance: The instance to be serialized. 
+ __instance: The instance to be serialized. indent: Number of spaces for JSON indentation. include: Fields to include. exclude: Fields to exclude. @@ -618,18 +392,13 @@ class TypeAdapter(Generic[T]): exclude_defaults: Whether to exclude fields with default values. exclude_none: Whether to exclude fields with a value of `None`. round_trip: Whether to serialize and deserialize the instance to ensure round-tripping. - warnings: How to handle serialization errors. False/"none" ignores them, True/"warn" logs errors, - "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. - fallback: A function to call when an unknown value is encountered. If not provided, - a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. - serialize_as_any: Whether to serialize fields with duck-typing serialization behavior. - context: Additional context to pass to the serializer. + warnings: Whether to emit serialization warnings. Returns: The JSON representation of the given instance as bytes. """ return self.serializer.to_json( - instance, + __instance, indent=indent, include=include, exclude=exclude, @@ -639,9 +408,6 @@ class TypeAdapter(Generic[T]): exclude_none=exclude_none, round_trip=round_trip, warnings=warnings, - fallback=fallback, - serialize_as_any=serialize_as_any, - context=context, ) def json_schema( @@ -664,15 +430,11 @@ class TypeAdapter(Generic[T]): The JSON schema for the model as a dictionary. """ schema_generator_instance = schema_generator(by_alias=by_alias, ref_template=ref_template) - if isinstance(self.core_schema, _mock_val_ser.MockCoreSchema): - self.core_schema.rebuild() - assert not isinstance(self.core_schema, _mock_val_ser.MockCoreSchema), 'this is a bug! please report it' return schema_generator_instance.generate(self.core_schema, mode=mode) @staticmethod def json_schemas( - inputs: Iterable[tuple[JsonSchemaKeyT, JsonSchemaMode, TypeAdapter[Any]]], - /, + __inputs: Iterable[tuple[JsonSchemaKeyT, JsonSchemaMode, TypeAdapter[Any]]], *, by_alias: bool = True, title: str | None = None, @@ -683,7 +445,7 @@ class TypeAdapter(Generic[T]): """Generate a JSON schema including definitions from multiple type adapters. Args: - inputs: Inputs to schema generation. The first two items will form the keys of the (first) + __inputs: Inputs to schema generation. The first two items will form the keys of the (first) output mapping; the type adapters will provide the core schemas that get converted into definitions in the output JSON schema. by_alias: Whether to use alias names. @@ -704,17 +466,9 @@ class TypeAdapter(Generic[T]): """ schema_generator_instance = schema_generator(by_alias=by_alias, ref_template=ref_template) - inputs_ = [] - for key, mode, adapter in inputs: - # This is the same pattern we follow for model json schemas - we attempt a core schema rebuild if we detect a mock - if isinstance(adapter.core_schema, _mock_val_ser.MockCoreSchema): - adapter.core_schema.rebuild() - assert not isinstance(adapter.core_schema, _mock_val_ser.MockCoreSchema), ( - 'this is a bug! 
please report it' - ) - inputs_.append((key, mode, adapter.core_schema)) + inputs = [(key, mode, adapter.core_schema) for key, mode, adapter in __inputs] - json_schemas_map, definitions = schema_generator_instance.generate_definitions(inputs_) + json_schemas_map, definitions = schema_generator_instance.generate_definitions(inputs) json_schema: dict[str, Any] = {} if definitions: diff --git a/venv/lib/python3.12/site-packages/pydantic/types.py b/venv/lib/python3.12/site-packages/pydantic/types.py index b5c4fd6..5d1bffa 100644 --- a/venv/lib/python3.12/site-packages/pydantic/types.py +++ b/venv/lib/python3.12/site-packages/pydantic/types.py @@ -1,36 +1,34 @@ """The types module contains custom types used by pydantic.""" - from __future__ import annotations as _annotations import base64 import dataclasses as _dataclasses import re -from collections.abc import Hashable, Iterator from datetime import date, datetime from decimal import Decimal from enum import Enum from pathlib import Path -from re import Pattern from types import ModuleType from typing import ( TYPE_CHECKING, - Annotated, Any, Callable, ClassVar, + FrozenSet, Generic, - Literal, + Hashable, + Iterator, + List, + Set, TypeVar, - Union, cast, ) from uuid import UUID import annotated_types from annotated_types import BaseMetadata, MaxLen, MinLen -from pydantic_core import CoreSchema, PydanticCustomError, SchemaSerializer, core_schema -from typing_extensions import Protocol, TypeAlias, TypeAliasType, deprecated, get_args, get_origin -from typing_inspection.introspection import is_union_origin +from pydantic_core import CoreSchema, PydanticCustomError, core_schema +from typing_extensions import Annotated, Literal, Protocol, deprecated from ._internal import _fields, _internal_dataclass, _utils, _validators from ._migration import getattr_migration @@ -39,13 +37,9 @@ from .errors import PydanticUserError from .json_schema import JsonSchemaValue from .warnings import PydanticDeprecatedSince20 -if TYPE_CHECKING: - from ._internal._core_metadata import CoreMetadata - __all__ = ( 'Strict', 'StrictStr', - 'SocketPath', 'conbytes', 'conlist', 'conset', @@ -68,14 +62,10 @@ __all__ = ( 'UUID3', 'UUID4', 'UUID5', - 'UUID6', - 'UUID7', - 'UUID8', 'FilePath', 'DirectoryPath', 'NewPath', 'Json', - 'Secret', 'SecretStr', 'SecretBytes', 'StrictBool', @@ -102,31 +92,19 @@ __all__ = ( 'Base64UrlStr', 'GetPydanticSchema', 'StringConstraints', - 'Tag', - 'Discriminator', - 'JsonValue', - 'OnErrorOmit', - 'FailFast', ) -T = TypeVar('T') - - @_dataclasses.dataclass class Strict(_fields.PydanticMetadata, BaseMetadata): - """!!! abstract "Usage Documentation" - [Strict Mode with `Annotated` `Strict`](../concepts/strict_mode.md#strict-mode-with-annotated-strict) - - A field metadata class to indicate that a field should be validated in strict mode. - Use this class as an annotation via [`Annotated`](https://docs.python.org/3/library/typing.html#typing.Annotated), as seen below. + """A field metadata class to indicate that a field should be validated in strict mode. Attributes: strict: Whether to validate the field in strict mode. Example: ```python - from typing import Annotated + from typing_extensions import Annotated from pydantic.types import Strict @@ -168,7 +146,7 @@ def conint( The reason is that `conint` returns a type, which doesn't play well with static analysis tools. 
=== ":x: Don't do this" - ```python + ```py from pydantic import BaseModel, conint class Foo(BaseModel): @@ -176,8 +154,8 @@ def conint( ``` === ":white_check_mark: Do this" - ```python - from typing import Annotated + ```py + from typing_extensions import Annotated from pydantic import BaseModel, Field @@ -198,7 +176,7 @@ def conint( Returns: The wrapped integer type. - ```python + ```py from pydantic import BaseModel, ValidationError, conint class ConstrainedExample(BaseModel): @@ -227,7 +205,7 @@ def conint( ``` """ # noqa: D212 - return Annotated[ # pyright: ignore[reportReturnType] + return Annotated[ int, Strict(strict) if strict is not None else None, annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le), @@ -238,7 +216,7 @@ def conint( PositiveInt = Annotated[int, annotated_types.Gt(0)] """An integer that must be greater than zero. -```python +```py from pydantic import BaseModel, PositiveInt, ValidationError class Model(BaseModel): @@ -269,7 +247,7 @@ except ValidationError as e: NegativeInt = Annotated[int, annotated_types.Lt(0)] """An integer that must be less than zero. -```python +```py from pydantic import BaseModel, NegativeInt, ValidationError class Model(BaseModel): @@ -300,7 +278,7 @@ except ValidationError as e: NonPositiveInt = Annotated[int, annotated_types.Le(0)] """An integer that must be less than or equal to zero. -```python +```py from pydantic import BaseModel, NonPositiveInt, ValidationError class Model(BaseModel): @@ -331,7 +309,7 @@ except ValidationError as e: NonNegativeInt = Annotated[int, annotated_types.Ge(0)] """An integer that must be greater than or equal to zero. -```python +```py from pydantic import BaseModel, NonNegativeInt, ValidationError class Model(BaseModel): @@ -362,7 +340,7 @@ except ValidationError as e: StrictInt = Annotated[int, Strict()] """An integer that must be validated in strict mode. -```python +```py from pydantic import BaseModel, StrictInt, ValidationError class StrictIntModel(BaseModel): @@ -385,22 +363,7 @@ except ValidationError as e: @_dataclasses.dataclass class AllowInfNan(_fields.PydanticMetadata): - """A field metadata class to indicate that a field should allow `-inf`, `inf`, and `nan`. - - Use this class as an annotation via [`Annotated`](https://docs.python.org/3/library/typing.html#typing.Annotated), as seen below. - - Attributes: - allow_inf_nan: Whether to allow `-inf`, `inf`, and `nan`. Defaults to `True`. - - Example: - ```python - from typing import Annotated - - from pydantic.types import AllowInfNan - - LaxFloat = Annotated[float, AllowInfNan()] - ``` - """ + """A field metadata class to indicate that a field should allow ``-inf``, ``inf``, and ``nan``.""" allow_inf_nan: bool = True @@ -429,7 +392,7 @@ def confloat( The reason is that `confloat` returns a type, which doesn't play well with static analysis tools. === ":x: Don't do this" - ```python + ```py from pydantic import BaseModel, confloat class Foo(BaseModel): @@ -437,9 +400,8 @@ def confloat( ``` === ":white_check_mark: Do this" - ```python - from typing import Annotated - + ```py + from typing_extensions import Annotated from pydantic import BaseModel, Field class Foo(BaseModel): @@ -460,7 +422,7 @@ def confloat( Returns: The wrapped float type. 
- ```python + ```py from pydantic import BaseModel, ValidationError, confloat class ConstrainedExample(BaseModel): @@ -488,7 +450,7 @@ def confloat( ''' ``` """ # noqa: D212 - return Annotated[ # pyright: ignore[reportReturnType] + return Annotated[ float, Strict(strict) if strict is not None else None, annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le), @@ -500,7 +462,7 @@ def confloat( PositiveFloat = Annotated[float, annotated_types.Gt(0)] """A float that must be greater than zero. -```python +```py from pydantic import BaseModel, PositiveFloat, ValidationError class Model(BaseModel): @@ -531,7 +493,7 @@ except ValidationError as e: NegativeFloat = Annotated[float, annotated_types.Lt(0)] """A float that must be less than zero. -```python +```py from pydantic import BaseModel, NegativeFloat, ValidationError class Model(BaseModel): @@ -562,7 +524,7 @@ except ValidationError as e: NonPositiveFloat = Annotated[float, annotated_types.Le(0)] """A float that must be less than or equal to zero. -```python +```py from pydantic import BaseModel, NonPositiveFloat, ValidationError class Model(BaseModel): @@ -593,7 +555,7 @@ except ValidationError as e: NonNegativeFloat = Annotated[float, annotated_types.Ge(0)] """A float that must be greater than or equal to zero. -```python +```py from pydantic import BaseModel, NonNegativeFloat, ValidationError class Model(BaseModel): @@ -624,7 +586,7 @@ except ValidationError as e: StrictFloat = Annotated[float, Strict(True)] """A float that must be validated in strict mode. -```python +```py from pydantic import BaseModel, StrictFloat, ValidationError class StrictFloatModel(BaseModel): @@ -644,7 +606,7 @@ except ValidationError as e: FiniteFloat = Annotated[float, AllowInfNan(False)] """A float that must be finite (not ``-inf``, ``inf``, or ``nan``). -```python +```py from pydantic import BaseModel, FiniteFloat class Model(BaseModel): @@ -676,7 +638,7 @@ def conbytes( Returns: The wrapped bytes type. """ - return Annotated[ # pyright: ignore[reportReturnType] + return Annotated[ bytes, Strict(strict) if strict is not None else None, annotated_types.Len(min_length or 0, max_length), @@ -692,29 +654,16 @@ StrictBytes = Annotated[bytes, Strict()] @_dataclasses.dataclass(frozen=True) class StringConstraints(annotated_types.GroupedMetadata): - """!!! abstract "Usage Documentation" - [`StringConstraints`](../concepts/fields.md#string-constraints) - - A field metadata class to apply constraints to `str` types. - Use this class as an annotation via [`Annotated`](https://docs.python.org/3/library/typing.html#typing.Annotated), as seen below. + """Apply constraints to `str` types. Attributes: - strip_whitespace: Whether to remove leading and trailing whitespace. + strip_whitespace: Whether to strip whitespace from the string. to_upper: Whether to convert the string to uppercase. to_lower: Whether to convert the string to lowercase. strict: Whether to validate the string in strict mode. min_length: The minimum length of the string. max_length: The maximum length of the string. pattern: A regex pattern that the string must match. 
- - Example: - ```python - from typing import Annotated - - from pydantic.types import StringConstraints - - ConstrainedStr = Annotated[str, StringConstraints(min_length=1, max_length=10)] - ``` """ strip_whitespace: bool | None = None @@ -723,7 +672,7 @@ class StringConstraints(annotated_types.GroupedMetadata): strict: bool | None = None min_length: int | None = None max_length: int | None = None - pattern: str | Pattern[str] | None = None + pattern: str | None = None def __iter__(self) -> Iterator[BaseMetadata]: if self.min_length is not None: @@ -731,14 +680,14 @@ class StringConstraints(annotated_types.GroupedMetadata): if self.max_length is not None: yield MaxLen(self.max_length) if self.strict is not None: - yield Strict(self.strict) + yield Strict() if ( self.strip_whitespace is not None or self.pattern is not None or self.to_lower is not None or self.to_upper is not None ): - yield _fields.pydantic_general_metadata( + yield _fields.PydanticGeneralMetadata( strip_whitespace=self.strip_whitespace, to_upper=self.to_upper, to_lower=self.to_lower, @@ -754,7 +703,7 @@ def constr( strict: bool | None = None, min_length: int | None = None, max_length: int | None = None, - pattern: str | Pattern[str] | None = None, + pattern: str | None = None, ) -> type[str]: """ !!! warning "Discouraged" @@ -767,7 +716,7 @@ def constr( The reason is that `constr` returns a type, which doesn't play well with static analysis tools. === ":x: Don't do this" - ```python + ```py from pydantic import BaseModel, constr class Foo(BaseModel): @@ -775,27 +724,21 @@ def constr( ``` === ":white_check_mark: Do this" - ```python - from typing import Annotated - - from pydantic import BaseModel, StringConstraints + ```py + from pydantic import BaseModel, Annotated, StringConstraints class Foo(BaseModel): - bar: Annotated[ - str, - StringConstraints( - strip_whitespace=True, to_upper=True, pattern=r'^[A-Z]+$' - ), - ] + bar: Annotated[str, StringConstraints(strip_whitespace=True, to_upper=True, pattern=r'^[A-Z]+$')] ``` A wrapper around `str` that allows for additional constraints. - ```python + ```py from pydantic import BaseModel, constr class Foo(BaseModel): - bar: constr(strip_whitespace=True, to_upper=True) + bar: constr(strip_whitespace=True, to_upper=True, pattern=r'^[A-Z]+$') + foo = Foo(bar=' hello ') print(foo) @@ -814,7 +757,7 @@ def constr( Returns: The wrapped string type. """ # noqa: D212 - return Annotated[ # pyright: ignore[reportReturnType] + return Annotated[ str, StringConstraints( strip_whitespace=strip_whitespace, @@ -849,7 +792,7 @@ def conset( Returns: The wrapped set type. """ - return Annotated[set[item_type], annotated_types.Len(min_length or 0, max_length)] # pyright: ignore[reportReturnType] + return Annotated[Set[item_type], annotated_types.Len(min_length or 0, max_length)] def confrozenset( @@ -865,7 +808,7 @@ def confrozenset( Returns: The wrapped frozenset type. """ - return Annotated[frozenset[item_type], annotated_types.Len(min_length or 0, max_length)] # pyright: ignore[reportReturnType] + return Annotated[FrozenSet[item_type], annotated_types.Len(min_length or 0, max_length)] AnyItemType = TypeVar('AnyItemType') @@ -878,16 +821,13 @@ def conlist( max_length: int | None = None, unique_items: bool | None = None, ) -> type[list[AnyItemType]]: - """A wrapper around [`list`][] that adds validation. + """A wrapper around typing.List that adds validation. Args: item_type: The type of the items in the list. min_length: The minimum length of the list. Defaults to None. 
max_length: The maximum length of the list. Defaults to None. unique_items: Whether the items in the list must be unique. Defaults to None. - !!! warning Deprecated - The `unique_items` parameter is deprecated, use `Set` instead. - See [this issue](https://github.com/pydantic/pydantic-core/issues/296) for more details. Returns: The wrapped list type. @@ -900,7 +840,7 @@ def conlist( ), code='removed-kwargs', ) - return Annotated[list[item_type], annotated_types.Len(min_length or 0, max_length)] # pyright: ignore[reportReturnType] + return Annotated[List[item_type], annotated_types.Len(min_length or 0, max_length)] # ~~~~~~~~~~~~~~~~~~~~~~~~~~ IMPORT STRING TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -911,24 +851,31 @@ if TYPE_CHECKING: else: class ImportString: - """A type that can be used to import a Python object from a string. + """A type that can be used to import a type from a string. `ImportString` expects a string and loads the Python object importable at that dotted path. - Attributes of modules may be separated from the module by `:` or `.`, e.g. if `'math:cos'` is provided, - the resulting field value would be the function `cos`. If a `.` is used and both an attribute and submodule + Attributes of modules may be separated from the module by `:` or `.`, e.g. if `'math:cos'` was provided, + the resulting field value would be the function`cos`. If a `.` is used and both an attribute and submodule are present at the same path, the module will be preferred. On model instantiation, pointers will be evaluated and imported. There is some nuance to this behavior, demonstrated in the examples below. - ```python - import math + > A known limitation: setting a default value to a string + > won't result in validation (thus evaluation). This is actively + > being worked on. + + **Good behavior:** + ```py + from math import cos + + from pydantic import BaseModel, ImportString, ValidationError - from pydantic import BaseModel, Field, ImportString, ValidationError class ImportThings(BaseModel): obj: ImportString + # A string value will cause an automatic import my_cos = ImportThings(obj='math.cos') @@ -936,6 +883,7 @@ else: cos_of_0 = my_cos.obj(0) assert cos_of_0 == 1 + # A string whose value cannot be imported will raise an error try: ImportThings(obj='foo.bar') @@ -944,45 +892,28 @@ else: ''' 1 validation error for ImportThings obj - Invalid python path: No module named 'foo.bar' [type=import_error, input_value='foo.bar', input_type=str] + Invalid python path: No module named 'foo.bar' [type=import_error, input_value='foo.bar', input_type=str] ''' + # Actual python objects can be assigned as well - my_cos = ImportThings(obj=math.cos) + my_cos = ImportThings(obj=cos) my_cos_2 = ImportThings(obj='math.cos') - my_cos_3 = ImportThings(obj='math:cos') - assert my_cos == my_cos_2 == my_cos_3 - - # You can set default field value either as Python object: - class ImportThingsDefaultPyObj(BaseModel): - obj: ImportString = math.cos - - # or as a string value (but only if used with `validate_default=True`) - class ImportThingsDefaultString(BaseModel): - obj: ImportString = Field(default='math.cos', validate_default=True) - - my_cos_default1 = ImportThingsDefaultPyObj() - my_cos_default2 = ImportThingsDefaultString() - assert my_cos_default1.obj == my_cos_default2.obj == math.cos - - # note: this will not work! 
- class ImportThingsMissingValidateDefault(BaseModel): - obj: ImportString = 'math.cos' - - my_cos_default3 = ImportThingsMissingValidateDefault() - assert my_cos_default3.obj == 'math.cos' # just string, not evaluated + assert my_cos == my_cos_2 ``` Serializing an `ImportString` type to json is also possible. - ```python + ```py from pydantic import BaseModel, ImportString + class ImportThings(BaseModel): obj: ImportString + # Create an instance - m = ImportThings(obj='math.cos') + m = ImportThings(obj='math:cos') print(m) #> obj= print(m.model_dump_json()) @@ -1009,25 +940,12 @@ else: function=_validators.import_string, schema=handler(source), serialization=serializer ) - @classmethod - def __get_pydantic_json_schema__(cls, cs: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: - return handler(core_schema.str_schema()) - @staticmethod def _serialize(v: Any) -> str: if isinstance(v, ModuleType): return v.__name__ elif hasattr(v, '__module__') and hasattr(v, '__name__'): return f'{v.__module__}.{v.__name__}' - # Handle special cases for sys.XXX streams - # if we see more of these, we should consider a more general solution - elif hasattr(v, 'name'): - if v.name == '': - return 'sys.stdout' - elif v.name == '': - return 'sys.stdin' - elif v.name == '': - return 'sys.stderr' else: return v @@ -1061,7 +979,7 @@ def condecimal( The reason is that `condecimal` returns a type, which doesn't play well with static analysis tools. === ":x: Don't do this" - ```python + ```py from pydantic import BaseModel, condecimal class Foo(BaseModel): @@ -1069,9 +987,9 @@ def condecimal( ``` === ":white_check_mark: Do this" - ```python + ```py from decimal import Decimal - from typing import Annotated + from typing_extensions import Annotated from pydantic import BaseModel, Field @@ -1092,7 +1010,7 @@ def condecimal( decimal_places: The number of decimal places. Defaults to `None`. allow_inf_nan: Whether to allow infinity and NaN. Defaults to `None`. - ```python + ```py from decimal import Decimal from pydantic import BaseModel, ValidationError, condecimal @@ -1122,12 +1040,12 @@ def condecimal( ''' ``` """ # noqa: D212 - return Annotated[ # pyright: ignore[reportReturnType] + return Annotated[ Decimal, Strict(strict) if strict is not None else None, annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le), annotated_types.MultipleOf(multiple_of) if multiple_of is not None else None, - _fields.pydantic_general_metadata(max_digits=max_digits, decimal_places=decimal_places), + _fields.PydanticGeneralMetadata(max_digits=max_digits, decimal_places=decimal_places), AllowInfNan(allow_inf_nan) if allow_inf_nan is not None else None, ] @@ -1137,25 +1055,9 @@ def condecimal( @_dataclasses.dataclass(**_internal_dataclass.slots_true) class UuidVersion: - """A field metadata class to indicate a [UUID](https://docs.python.org/3/library/uuid.html) version. + """A field metadata class to indicate a [UUID](https://docs.python.org/3/library/uuid.html) version.""" - Use this class as an annotation via [`Annotated`](https://docs.python.org/3/library/typing.html#typing.Annotated), as seen below. - - Attributes: - uuid_version: The version of the UUID. Must be one of 1, 3, 4, 5, or 7. 
- - Example: - ```python - from typing import Annotated - from uuid import UUID - - from pydantic.types import UuidVersion - - UUID1 = Annotated[UUID, UuidVersion(1)] - ``` - """ - - uuid_version: Literal[1, 3, 4, 5, 6, 7, 8] + uuid_version: Literal[1, 3, 4, 5] def __get_pydantic_json_schema__( self, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler @@ -1166,15 +1068,7 @@ class UuidVersion: return field_schema def __get_pydantic_core_schema__(self, source: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: - if isinstance(self, source): - # used directly as a type - return core_schema.uuid_schema(version=self.uuid_version) - else: - # update existing schema with self.uuid_version - schema = handler(source) - _check_annotated_type(schema['type'], 'uuid', self.__class__.__name__) - schema['version'] = self.uuid_version # type: ignore - return schema + return core_schema.uuid_schema(version=self.uuid_version) def __hash__(self) -> int: return hash(type(self.uuid_version)) @@ -1183,10 +1077,10 @@ class UuidVersion: UUID1 = Annotated[UUID, UuidVersion(1)] """A [UUID](https://docs.python.org/3/library/uuid.html) that must be version 1. -```python +```py import uuid -from pydantic import UUID1, BaseModel +from pydantic import BaseModel, UUID1 class Model(BaseModel): uuid1: UUID1 @@ -1197,10 +1091,10 @@ Model(uuid1=uuid.uuid1()) UUID3 = Annotated[UUID, UuidVersion(3)] """A [UUID](https://docs.python.org/3/library/uuid.html) that must be version 3. -```python +```py import uuid -from pydantic import UUID3, BaseModel +from pydantic import BaseModel, UUID3 class Model(BaseModel): uuid3: UUID3 @@ -1211,10 +1105,10 @@ Model(uuid3=uuid.uuid3(uuid.NAMESPACE_DNS, 'pydantic.org')) UUID4 = Annotated[UUID, UuidVersion(4)] """A [UUID](https://docs.python.org/3/library/uuid.html) that must be version 4. -```python +```py import uuid -from pydantic import UUID4, BaseModel +from pydantic import BaseModel, UUID4 class Model(BaseModel): uuid4: UUID4 @@ -1225,10 +1119,10 @@ Model(uuid4=uuid.uuid4()) UUID5 = Annotated[UUID, UuidVersion(5)] """A [UUID](https://docs.python.org/3/library/uuid.html) that must be version 5. -```python +```py import uuid -from pydantic import UUID5, BaseModel +from pydantic import BaseModel, UUID5 class Model(BaseModel): uuid5: UUID5 @@ -1236,55 +1130,14 @@ class Model(BaseModel): Model(uuid5=uuid.uuid5(uuid.NAMESPACE_DNS, 'pydantic.org')) ``` """ -UUID6 = Annotated[UUID, UuidVersion(6)] -"""A [UUID](https://docs.python.org/3/library/uuid.html) that must be version 6. -```python -import uuid - -from pydantic import UUID6, BaseModel - -class Model(BaseModel): - uuid6: UUID6 - -Model(uuid6=uuid.UUID('1efea953-c2d6-6790-aa0a-69db8c87df97')) -``` -""" -UUID7 = Annotated[UUID, UuidVersion(7)] -"""A [UUID](https://docs.python.org/3/library/uuid.html) that must be version 7. - -```python -import uuid - -from pydantic import UUID7, BaseModel - -class Model(BaseModel): - uuid7: UUID7 - -Model(uuid7=uuid.UUID('0194fdcb-1c47-7a09-b52c-561154de0b4a')) -``` -""" -UUID8 = Annotated[UUID, UuidVersion(8)] -"""A [UUID](https://docs.python.org/3/library/uuid.html) that must be version 8. 
- -```python -import uuid - -from pydantic import UUID8, BaseModel - -class Model(BaseModel): - uuid8: UUID8 - -Model(uuid8=uuid.UUID('81a0b92e-6078-8551-9c81-8ccb666bdab8')) -``` -""" # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PATH TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @_dataclasses.dataclass class PathType: - path_type: Literal['file', 'dir', 'new', 'socket'] + path_type: Literal['file', 'dir', 'new'] def __get_pydantic_json_schema__( self, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler @@ -1299,7 +1152,6 @@ class PathType: 'file': cast(core_schema.WithInfoValidatorFunction, self.validate_file), 'dir': cast(core_schema.WithInfoValidatorFunction, self.validate_directory), 'new': cast(core_schema.WithInfoValidatorFunction, self.validate_new), - 'socket': cast(core_schema.WithInfoValidatorFunction, self.validate_socket), } return core_schema.with_info_after_validator_function( @@ -1314,13 +1166,6 @@ class PathType: else: raise PydanticCustomError('path_not_file', 'Path does not point to a file') - @staticmethod - def validate_socket(path: Path, _: core_schema.ValidationInfo) -> Path: - if path.is_socket(): - return path - else: - raise PydanticCustomError('path_not_socket', 'Path does not point to a socket') - @staticmethod def validate_directory(path: Path, _: core_schema.ValidationInfo) -> Path: if path.is_dir(): @@ -1344,7 +1189,7 @@ class PathType: FilePath = Annotated[Path, PathType('file')] """A path that must point to a file. -```python +```py from pathlib import Path from pydantic import BaseModel, FilePath, ValidationError @@ -1386,7 +1231,7 @@ except ValidationError as e: DirectoryPath = Annotated[Path, PathType('dir')] """A path that must point to a directory. -```python +```py from pathlib import Path from pydantic import BaseModel, DirectoryPath, ValidationError @@ -1426,16 +1271,13 @@ except ValidationError as e: ``` """ NewPath = Annotated[Path, PathType('new')] -"""A path for a new file or directory that must not already exist. The parent directory must already exist.""" +"""A path for a new file or directory that must not already exist.""" -SocketPath = Annotated[Path, PathType('socket')] -"""A path to an existing socket file""" # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ JSON TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ if TYPE_CHECKING: - # Json[list[str]] will be recognized by type checkers as list[str] - Json = Annotated[AnyType, ...] + Json = Annotated[AnyType, ...] 
# Json[list[str]] will be recognized by type checkers as list[str] else: @@ -1445,16 +1287,19 @@ else: You can use the `Json` data type to make Pydantic first load a raw JSON string before validating the loaded data into the parametrized type: - ```python - from typing import Any + ```py + from typing import Any, List from pydantic import BaseModel, Json, ValidationError + class AnyJsonModel(BaseModel): json_obj: Json[Any] + class ConstrainedJsonModel(BaseModel): - json_obj: Json[list[int]] + json_obj: Json[List[int]] + print(AnyJsonModel(json_obj='{"b": 1}')) #> json_obj={'b': 1} @@ -1468,7 +1313,7 @@ else: ''' 1 validation error for ConstrainedJsonModel json_obj - JSON input should be string, bytes or bytearray [type=json_type, input_value=12, input_type=int] + JSON input should be string, bytes or bytearray [type=json_type, input_value=12, input_type=int] ''' try: @@ -1478,7 +1323,7 @@ else: ''' 1 validation error for ConstrainedJsonModel json_obj - Invalid JSON: expected value at line 1 column 2 [type=json_invalid, input_value='[a, b]', input_type=str] + Invalid JSON: expected value at line 1 column 2 [type=json_invalid, input_value='[a, b]', input_type=str] ''' try: @@ -1488,20 +1333,24 @@ else: ''' 2 validation errors for ConstrainedJsonModel json_obj.0 - Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='a', input_type=str] + Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='a', input_type=str] json_obj.1 - Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='b', input_type=str] + Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='b', input_type=str] ''' ``` When you dump the model using `model_dump` or `model_dump_json`, the dumped value will be the result of validation, not the original JSON string. 
However, you can use the argument `round_trip=True` to get the original JSON string back: - ```python + ```py + from typing import List + from pydantic import BaseModel, Json + class ConstrainedJsonModel(BaseModel): - json_obj: Json[list[int]] + json_obj: Json[List[int]] + print(ConstrainedJsonModel(json_obj='[1, 2, 3]').model_dump_json()) #> {"json_obj":[1,2,3]} @@ -1530,15 +1379,15 @@ else: return hash(type(self)) def __eq__(self, other: Any) -> bool: - return type(other) is type(self) + return type(other) == type(self) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SECRET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -SecretType = TypeVar('SecretType') +SecretType = TypeVar('SecretType', str, bytes) -class _SecretBase(Generic[SecretType]): +class _SecretField(Generic[SecretType]): def __init__(self, secret_value: SecretType) -> None: self._secret_value: SecretType = secret_value @@ -1556,206 +1405,41 @@ class _SecretBase(Generic[SecretType]): def __hash__(self) -> int: return hash(self.get_secret_value()) + def __len__(self) -> int: + return len(self._secret_value) + def __str__(self) -> str: return str(self._display()) def __repr__(self) -> str: return f'{self.__class__.__name__}({self._display()!r})' - def _display(self) -> str | bytes: + def _display(self) -> SecretType: raise NotImplementedError - -def _serialize_secret(value: Secret[SecretType], info: core_schema.SerializationInfo) -> str | Secret[SecretType]: - if info.mode == 'json': - return str(value) - else: - return value - - -class Secret(_SecretBase[SecretType]): - """A generic base class used for defining a field with sensitive information that you do not want to be visible in logging or tracebacks. - - You may either directly parametrize `Secret` with a type, or subclass from `Secret` with a parametrized type. The benefit of subclassing - is that you can define a custom `_display` method, which will be used for `repr()` and `str()` methods. The examples below demonstrate both - ways of using `Secret` to create a new secret type. - - 1. Directly parametrizing `Secret` with a type: - - ```python - from pydantic import BaseModel, Secret - - SecretBool = Secret[bool] - - class Model(BaseModel): - secret_bool: SecretBool - - m = Model(secret_bool=True) - print(m.model_dump()) - #> {'secret_bool': Secret('**********')} - - print(m.model_dump_json()) - #> {"secret_bool":"**********"} - - print(m.secret_bool.get_secret_value()) - #> True - ``` - - 2. Subclassing from parametrized `Secret`: - - ```python - from datetime import date - - from pydantic import BaseModel, Secret - - class SecretDate(Secret[date]): - def _display(self) -> str: - return '****/**/**' - - class Model(BaseModel): - secret_date: SecretDate - - m = Model(secret_date=date(2022, 1, 1)) - print(m.model_dump()) - #> {'secret_date': SecretDate('****/**/**')} - - print(m.model_dump_json()) - #> {"secret_date":"****/**/**"} - - print(m.secret_date.get_secret_value()) - #> 2022-01-01 - ``` - - The value returned by the `_display` method will be used for `repr()` and `str()`. - - You can enforce constraints on the underlying type through annotations: - For example: - - ```python - from typing import Annotated - - from pydantic import BaseModel, Field, Secret, ValidationError - - SecretPosInt = Secret[Annotated[int, Field(gt=0, strict=True)]] - - class Model(BaseModel): - sensitive_int: SecretPosInt - - m = Model(sensitive_int=42) - print(m.model_dump()) - #> {'sensitive_int': Secret('**********')} - - try: - m = Model(sensitive_int=-42) # (1)! 
- except ValidationError as exc_info: - print(exc_info.errors(include_url=False, include_input=False)) - ''' - [ - { - 'type': 'greater_than', - 'loc': ('sensitive_int',), - 'msg': 'Input should be greater than 0', - 'ctx': {'gt': 0}, - } - ] - ''' - - try: - m = Model(sensitive_int='42') # (2)! - except ValidationError as exc_info: - print(exc_info.errors(include_url=False, include_input=False)) - ''' - [ - { - 'type': 'int_type', - 'loc': ('sensitive_int',), - 'msg': 'Input should be a valid integer', - } - ] - ''' - ``` - - 1. The input value is not greater than 0, so it raises a validation error. - 2. The input value is not an integer, so it raises a validation error because the `SecretPosInt` type has strict mode enabled. - """ - - def _display(self) -> str | bytes: - return '**********' if self.get_secret_value() else '' - @classmethod def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: - inner_type = None - # if origin_type is Secret, then cls is a GenericAlias, and we can extract the inner type directly - origin_type = get_origin(source) - if origin_type is not None: - inner_type = get_args(source)[0] - # otherwise, we need to get the inner type from the base class + if issubclass(source, SecretStr): + field_type = str + inner_schema = core_schema.str_schema() else: - bases = getattr(cls, '__orig_bases__', getattr(cls, '__bases__', [])) - for base in bases: - if get_origin(base) is Secret: - inner_type = get_args(base)[0] - if bases == [] or inner_type is None: - raise TypeError( - f"Can't get secret type from {cls.__name__}. " - 'Please use Secret[], or subclass from Secret[] instead.' - ) + assert issubclass(source, SecretBytes) + field_type = bytes + inner_schema = core_schema.bytes_schema() + error_kind = 'string_type' if field_type is str else 'bytes_type' - inner_schema = handler.generate_schema(inner_type) # type: ignore + def serialize( + value: _SecretField[SecretType], info: core_schema.SerializationInfo + ) -> str | _SecretField[SecretType]: + if info.mode == 'json': + # we want the output to always be string without the `b'` prefix for bytes, + # hence we just use `secret_display` + return _secret_display(value.get_secret_value()) + else: + return value - def validate_secret_value(value, handler) -> Secret[SecretType]: - if isinstance(value, Secret): - value = value.get_secret_value() - validated_inner = handler(value) - return cls(validated_inner) - - return core_schema.json_or_python_schema( - python_schema=core_schema.no_info_wrap_validator_function( - validate_secret_value, - inner_schema, - ), - json_schema=core_schema.no_info_after_validator_function(lambda x: cls(x), inner_schema), - serialization=core_schema.plain_serializer_function_ser_schema( - _serialize_secret, - info_arg=True, - when_used='always', - ), - ) - - __pydantic_serializer__ = SchemaSerializer( - core_schema.any_schema( - serialization=core_schema.plain_serializer_function_ser_schema( - _serialize_secret, - info_arg=True, - when_used='always', - ) - ) - ) - - -def _secret_display(value: SecretType) -> str: # type: ignore - return '**********' if value else '' - - -def _serialize_secret_field( - value: _SecretField[SecretType], info: core_schema.SerializationInfo -) -> str | _SecretField[SecretType]: - if info.mode == 'json': - # we want the output to always be string without the `b'` prefix for bytes, - # hence we just use `secret_display` - return _secret_display(value.get_secret_value()) - else: - return value - - -class 
_SecretField(_SecretBase[SecretType]): - _inner_schema: ClassVar[CoreSchema] - _error_kind: ClassVar[str] - - @classmethod - def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: def get_json_schema(_core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: - json_schema = handler(cls._inner_schema) + json_schema = handler(inner_schema) _utils.update_not_none( json_schema, type='string', @@ -1764,52 +1448,39 @@ class _SecretField(_SecretBase[SecretType]): ) return json_schema - def get_secret_schema(strict: bool) -> CoreSchema: - inner_schema = {**cls._inner_schema, 'strict': strict} - json_schema = core_schema.no_info_after_validator_function( - source, # construct the type - inner_schema, # pyright: ignore[reportArgumentType] - ) - return core_schema.json_or_python_schema( - python_schema=core_schema.union_schema( - [ - core_schema.is_instance_schema(source), - json_schema, - ], - custom_error_type=cls._error_kind, + s = core_schema.union_schema( + [ + core_schema.is_instance_schema(source), + core_schema.no_info_after_validator_function( + source, # construct the type + inner_schema, ), - json_schema=json_schema, - serialization=core_schema.plain_serializer_function_ser_schema( - _serialize_secret_field, - info_arg=True, - when_used='always', - ), - ) - - return core_schema.lax_or_strict_schema( - lax_schema=get_secret_schema(strict=False), - strict_schema=get_secret_schema(strict=True), - metadata={'pydantic_js_functions': [get_json_schema]}, - ) - - __pydantic_serializer__ = SchemaSerializer( - core_schema.any_schema( + ], + strict=True, + custom_error_type=error_kind, serialization=core_schema.plain_serializer_function_ser_schema( - _serialize_secret_field, + serialize, info_arg=True, - when_used='always', - ) + return_schema=core_schema.str_schema(), + when_used='json', + ), ) - ) + s.setdefault('metadata', {}).setdefault('pydantic_js_functions', []).append(get_json_schema) + return s + + +def _secret_display(value: str | bytes) -> str: + if isinstance(value, bytes): + value = value.decode() + return '**********' if value else '' class SecretStr(_SecretField[str]): """A string used for storing sensitive information that you do not want to be visible in logging or tracebacks. - When the secret value is nonempty, it is displayed as `'**********'` instead of the underlying value in - calls to `repr()` and `str()`. If the value _is_ empty, it is displayed as `''`. + It displays `'**********'` instead of the string value on `repr()` and `str()` calls. - ```python + ```py from pydantic import BaseModel, SecretStr class User(BaseModel): @@ -1822,62 +1493,19 @@ class SecretStr(_SecretField[str]): #> username='scolvin' password=SecretStr('**********') print(user.password.get_secret_value()) #> password1 - print((SecretStr('password'), SecretStr(''))) - #> (SecretStr('**********'), SecretStr('')) - ``` - - As seen above, by default, [`SecretStr`][pydantic.types.SecretStr] (and [`SecretBytes`][pydantic.types.SecretBytes]) - will be serialized as `**********` when serializing to json. - - You can use the [`field_serializer`][pydantic.functional_serializers.field_serializer] to dump the - secret as plain-text when serializing to json. 
- - ```python - from pydantic import BaseModel, SecretBytes, SecretStr, field_serializer - - class Model(BaseModel): - password: SecretStr - password_bytes: SecretBytes - - @field_serializer('password', 'password_bytes', when_used='json') - def dump_secret(self, v): - return v.get_secret_value() - - model = Model(password='IAmSensitive', password_bytes=b'IAmSensitiveBytes') - print(model) - #> password=SecretStr('**********') password_bytes=SecretBytes(b'**********') - print(model.password) - #> ********** - print(model.model_dump()) - ''' - { - 'password': SecretStr('**********'), - 'password_bytes': SecretBytes(b'**********'), - } - ''' - print(model.model_dump_json()) - #> {"password":"IAmSensitive","password_bytes":"IAmSensitiveBytes"} ``` """ - _inner_schema: ClassVar[CoreSchema] = core_schema.str_schema() - _error_kind: ClassVar[str] = 'string_type' - - def __len__(self) -> int: - return len(self._secret_value) - def _display(self) -> str: - return _secret_display(self._secret_value) + return _secret_display(self.get_secret_value()) class SecretBytes(_SecretField[bytes]): """A bytes used for storing sensitive information that you do not want to be visible in logging or tracebacks. It displays `b'**********'` instead of the string value on `repr()` and `str()` calls. - When the secret value is nonempty, it is displayed as `b'**********'` instead of the underlying value in - calls to `repr()` and `str()`. If the value _is_ empty, it is displayed as `b''`. - ```python + ```py from pydantic import BaseModel, SecretBytes class User(BaseModel): @@ -1888,19 +1516,11 @@ class SecretBytes(_SecretField[bytes]): #> username='scolvin' password=SecretBytes(b'**********') print(user.password.get_secret_value()) #> b'password1' - print((SecretBytes(b'password'), SecretBytes(b''))) - #> (SecretBytes(b'**********'), SecretBytes(b'')) ``` """ - _inner_schema: ClassVar[CoreSchema] = core_schema.bytes_schema() - _error_kind: ClassVar[str] = 'bytes_type' - - def __len__(self) -> int: - return len(self._secret_value) - def _display(self) -> bytes: - return _secret_display(self._secret_value).encode() + return _secret_display(self.get_secret_value()).encode() # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PAYMENT CARD TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1950,9 +1570,9 @@ class PaymentCardNumber(str): ) @classmethod - def validate(cls, input_value: str, /, _: core_schema.ValidationInfo) -> PaymentCardNumber: + def validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> PaymentCardNumber: """Validate the card number and return a `PaymentCardNumber` instance.""" - return cls(input_value) + return cls(__input_value) @property def masked(self) -> str: @@ -2026,6 +1646,24 @@ class PaymentCardNumber(str): # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BYTE SIZE TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +BYTE_SIZES = { + 'b': 1, + 'kb': 10**3, + 'mb': 10**6, + 'gb': 10**9, + 'tb': 10**12, + 'pb': 10**15, + 'eb': 10**18, + 'kib': 2**10, + 'mib': 2**20, + 'gib': 2**30, + 'tib': 2**40, + 'pib': 2**50, + 'eib': 2**60, +} +BYTE_SIZES.update({k.lower()[0]: v for k, v in BYTE_SIZES.items() if 'i' not in k}) +byte_string_re = re.compile(r'^\s*(\d*\.?\d+)\s*(\w+)?', re.IGNORECASE) + class ByteSize(int): """Converts a string representing a number of bytes with units (such as `'1KB'` or `'11.5MiB'`) into an integer. @@ -2040,7 +1678,7 @@ class ByteSize(int): !!! info Note that `1b` will be parsed as "1 byte" and not "1 bit". 
- ```python + ```py from pydantic import BaseModel, ByteSize class MyModel(BaseModel): @@ -2056,72 +1694,24 @@ class ByteSize(int): #> 44.4PiB print(m.size.human_readable(decimal=True)) #> 50.0PB - print(m.size.human_readable(separator=' ')) - #> 44.4 PiB print(m.size.to('TiB')) #> 45474.73508864641 ``` """ - byte_sizes = { - 'b': 1, - 'kb': 10**3, - 'mb': 10**6, - 'gb': 10**9, - 'tb': 10**12, - 'pb': 10**15, - 'eb': 10**18, - 'kib': 2**10, - 'mib': 2**20, - 'gib': 2**30, - 'tib': 2**40, - 'pib': 2**50, - 'eib': 2**60, - 'bit': 1 / 8, - 'kbit': 10**3 / 8, - 'mbit': 10**6 / 8, - 'gbit': 10**9 / 8, - 'tbit': 10**12 / 8, - 'pbit': 10**15 / 8, - 'ebit': 10**18 / 8, - 'kibit': 2**10 / 8, - 'mibit': 2**20 / 8, - 'gibit': 2**30 / 8, - 'tibit': 2**40 / 8, - 'pibit': 2**50 / 8, - 'eibit': 2**60 / 8, - } - byte_sizes.update({k.lower()[0]: v for k, v in byte_sizes.items() if 'i' not in k}) - - byte_string_pattern = r'^\s*(\d*\.?\d+)\s*(\w+)?' - byte_string_re = re.compile(byte_string_pattern, re.IGNORECASE) - @classmethod def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: - return core_schema.with_info_after_validator_function( - function=cls._validate, - schema=core_schema.union_schema( - [ - core_schema.str_schema(pattern=cls.byte_string_pattern), - core_schema.int_schema(ge=0), - ], - custom_error_type='byte_size', - custom_error_message='could not parse value and unit from byte string', - ), - serialization=core_schema.plain_serializer_function_ser_schema( - int, return_schema=core_schema.int_schema(ge=0) - ), - ) + return core_schema.with_info_plain_validator_function(cls._validate) @classmethod - def _validate(cls, input_value: Any, /, _: core_schema.ValidationInfo) -> ByteSize: + def _validate(cls, __input_value: Any, _: core_schema.ValidationInfo) -> ByteSize: try: - return cls(int(input_value)) + return cls(int(__input_value)) except ValueError: pass - str_match = cls.byte_string_re.match(str(input_value)) + str_match = byte_string_re.match(str(__input_value)) if str_match is None: raise PydanticCustomError('byte_size', 'could not parse value and unit from byte string') @@ -2130,19 +1720,18 @@ class ByteSize(int): unit = 'b' try: - unit_mult = cls.byte_sizes[unit.lower()] + unit_mult = BYTE_SIZES[unit.lower()] except KeyError: raise PydanticCustomError('byte_size_unit', 'could not interpret byte unit: {unit}', {'unit': unit}) return cls(int(float(scalar) * unit_mult)) - def human_readable(self, decimal: bool = False, separator: str = '') -> str: + def human_readable(self, decimal: bool = False) -> str: """Converts a byte size to a human readable string. Args: decimal: If True, use decimal units (e.g. 1000 bytes per KB). If False, use binary units (e.g. 1024 bytes per KiB). - separator: A string used to split the value and unit. Defaults to an empty string (''). Returns: A human readable string representation of the byte size. @@ -2160,27 +1749,25 @@ class ByteSize(int): for unit in units: if abs(num) < divisor: if unit == 'B': - return f'{num:0.0f}{separator}{unit}' + return f'{num:0.0f}{unit}' else: - return f'{num:0.1f}{separator}{unit}' + return f'{num:0.1f}{unit}' num /= divisor - return f'{num:0.1f}{separator}{final_unit}' + return f'{num:0.1f}{final_unit}' def to(self, unit: str) -> float: - """Converts a byte size to another unit, including both byte and bit units. + """Converts a byte size to another unit. Args: - unit: The unit to convert to. 
Must be one of the following: B, KB, MB, GB, TB, PB, EB, - KiB, MiB, GiB, TiB, PiB, EiB (byte units) and - bit, kbit, mbit, gbit, tbit, pbit, ebit, - kibit, mibit, gibit, tibit, pibit, eibit (bit units). + unit: The unit to convert to. Must be one of the following: B, KB, MB, GB, TB, PB, EiB, + KiB, MiB, GiB, TiB, PiB, EiB. Returns: The byte size in the new unit. """ try: - unit_div = self.byte_sizes[unit.lower()] + unit_div = BYTE_SIZES[unit.lower()] except KeyError: raise PydanticCustomError('byte_size_unit', 'Could not interpret byte unit: {unit}', {'unit': unit}) @@ -2192,7 +1779,7 @@ class ByteSize(int): def _check_annotated_type(annotated_type: str, expected_type: str, annotation: str) -> None: if annotated_type != expected_type: - raise PydanticUserError(f"'{annotation}' cannot annotate '{annotated_type}'.", code='invalid-annotated-type') + raise PydanticUserError(f"'{annotation}' cannot annotate '{annotated_type}'.", code='invalid_annotated_type') if TYPE_CHECKING: @@ -2259,7 +1846,7 @@ def condate( Returns: A date type with the specified constraints. """ - return Annotated[ # pyright: ignore[reportReturnType] + return Annotated[ date, Strict(strict) if strict is not None else None, annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le), @@ -2407,7 +1994,7 @@ class Base64Encoder(EncoderProtocol): The decoded data. """ try: - return base64.b64decode(data) + return base64.decodebytes(data) except ValueError as e: raise PydanticCustomError('base64_decode', "Base64 decoding error: '{error}'", {'error': str(e)}) @@ -2421,7 +2008,7 @@ class Base64Encoder(EncoderProtocol): Returns: The encoded data. """ - return base64.b64encode(value) + return base64.encodebytes(value) @classmethod def get_json_format(cls) -> Literal['base64']: @@ -2479,8 +2066,8 @@ class EncodedBytes: `EncodedBytes` needs an encoder that implements `EncoderProtocol` to operate. - ```python - from typing import Annotated + ```py + from typing_extensions import Annotated from pydantic import BaseModel, EncodedBytes, EncoderProtocol, ValidationError @@ -2538,11 +2125,9 @@ class EncodedBytes: return field_schema def __get_pydantic_core_schema__(self, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: - schema = handler(source) - _check_annotated_type(schema['type'], 'bytes', self.__class__.__name__) return core_schema.with_info_after_validator_function( function=self.decode, - schema=schema, + schema=core_schema.bytes_schema(), serialization=core_schema.plain_serializer_function_ser_schema(function=self.encode), ) @@ -2573,13 +2158,13 @@ class EncodedBytes: @_dataclasses.dataclass(**_internal_dataclass.slots_true) -class EncodedStr: +class EncodedStr(EncodedBytes): """A str type that is encoded and decoded using the specified encoder. `EncodedStr` needs an encoder that implements `EncoderProtocol` to operate. 
- ```python - from typing import Annotated + ```py + from typing_extensions import Annotated from pydantic import BaseModel, EncodedStr, EncoderProtocol, ValidationError @@ -2627,25 +2212,14 @@ class EncodedStr: ``` """ - encoder: type[EncoderProtocol] - - def __get_pydantic_json_schema__( - self, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler - ) -> JsonSchemaValue: - field_schema = handler(core_schema) - field_schema.update(type='string', format=self.encoder.get_json_format()) - return field_schema - def __get_pydantic_core_schema__(self, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: - schema = handler(source) - _check_annotated_type(schema['type'], 'str', self.__class__.__name__) return core_schema.with_info_after_validator_function( function=self.decode_str, - schema=schema, + schema=super(EncodedStr, self).__get_pydantic_core_schema__(source=source, handler=handler), # noqa: UP008 serialization=core_schema.plain_serializer_function_ser_schema(function=self.encode_str), ) - def decode_str(self, data: str, _: core_schema.ValidationInfo) -> str: + def decode_str(self, data: bytes, _: core_schema.ValidationInfo) -> str: """Decode the data using the specified encoder. Args: @@ -2654,7 +2228,7 @@ class EncodedStr: Returns: The decoded data. """ - return self.encoder.decode(data.encode()).decode() + return data.decode() def encode_str(self, value: str) -> str: """Encode the data using the specified encoder. @@ -2665,7 +2239,7 @@ class EncodedStr: Returns: The encoded data. """ - return self.encoder.encode(value.encode()).decode() # noqa: UP008 + return super(EncodedStr, self).encode(value=value.encode()).decode() # noqa: UP008 def __hash__(self) -> int: return hash(self.encoder) @@ -2675,52 +2249,12 @@ Base64Bytes = Annotated[bytes, EncodedBytes(encoder=Base64Encoder)] """A bytes type that is encoded and decoded using the standard (non-URL-safe) base64 encoder. Note: - Under the hood, `Base64Bytes` uses the standard library `base64.b64encode` and `base64.b64decode` functions. + Under the hood, `Base64Bytes` use standard library `base64.encodebytes` and `base64.decodebytes` functions. As a result, attempting to decode url-safe base64 data using the `Base64Bytes` type may fail or produce an incorrect decoding. -Warning: - In versions of Pydantic prior to v2.10, `Base64Bytes` used [`base64.encodebytes`][base64.encodebytes] - and [`base64.decodebytes`][base64.decodebytes] functions. According to the [base64 documentation](https://docs.python.org/3/library/base64.html), - these methods are considered legacy implementation, and thus, Pydantic v2.10+ now uses the modern - [`base64.b64encode`][base64.b64encode] and [`base64.b64decode`][base64.b64decode] functions. 
- - If you'd still like to use these legacy encoders / decoders, you can achieve this by creating a custom annotated type, - like follows: - - ```python - import base64 - from typing import Annotated, Literal - - from pydantic_core import PydanticCustomError - - from pydantic import EncodedBytes, EncoderProtocol - - class LegacyBase64Encoder(EncoderProtocol): - @classmethod - def decode(cls, data: bytes) -> bytes: - try: - return base64.decodebytes(data) - except ValueError as e: - raise PydanticCustomError( - 'base64_decode', - "Base64 decoding error: '{error}'", - {'error': str(e)}, - ) - - @classmethod - def encode(cls, value: bytes) -> bytes: - return base64.encodebytes(value) - - @classmethod - def get_json_format(cls) -> Literal['base64']: - return 'base64' - - LegacyBase64Bytes = Annotated[bytes, EncodedBytes(encoder=LegacyBase64Encoder)] - ``` - -```python +```py from pydantic import Base64Bytes, BaseModel, ValidationError class Model(BaseModel): @@ -2735,7 +2269,7 @@ print(m.base64_bytes) # Serialize into the base64 form print(m.model_dump()) -#> {'base64_bytes': b'VGhpcyBpcyB0aGUgd2F5'} +#> {'base64_bytes': b'VGhpcyBpcyB0aGUgd2F5\n'} # Validate base64 data try: @@ -2753,21 +2287,12 @@ Base64Str = Annotated[str, EncodedStr(encoder=Base64Encoder)] """A str type that is encoded and decoded using the standard (non-URL-safe) base64 encoder. Note: - Under the hood, `Base64Str` uses the standard library `base64.b64encode` and `base64.b64decode` functions. + Under the hood, `Base64Bytes` use standard library `base64.encodebytes` and `base64.decodebytes` functions. As a result, attempting to decode url-safe base64 data using the `Base64Str` type may fail or produce an incorrect decoding. -Warning: - In versions of Pydantic prior to v2.10, `Base64Str` used [`base64.encodebytes`][base64.encodebytes] - and [`base64.decodebytes`][base64.decodebytes] functions. According to the [base64 documentation](https://docs.python.org/3/library/base64.html), - these methods are considered legacy implementation, and thus, Pydantic v2.10+ now uses the modern - [`base64.b64encode`][base64.b64encode] and [`base64.b64decode`][base64.b64decode] functions. - - See the [`Base64Bytes`][pydantic.types.Base64Bytes] type for more information on how to - replicate the old behavior with the legacy encoders / decoders. - -```python +```py from pydantic import Base64Str, BaseModel, ValidationError class Model(BaseModel): @@ -2782,7 +2307,7 @@ print(m.base64_str) # Serialize into the base64 form print(m.model_dump()) -#> {'base64_str': 'VGhlc2UgYXJlbid0IHRoZSBkcm9pZHMgeW91J3JlIGxvb2tpbmcgZm9y'} +#> {'base64_str': 'VGhlc2UgYXJlbid0IHRoZSBkcm9pZHMgeW91J3JlIGxvb2tpbmcgZm9y\n'} # Validate base64 data try: @@ -2806,7 +2331,7 @@ Note: As a result, the `Base64UrlBytes` type can be used to faithfully decode "vanilla" base64 data (using `'+'` and `'/'`). -```python +```py from pydantic import Base64UrlBytes, BaseModel class Model(BaseModel): @@ -2827,7 +2352,7 @@ Note: As a result, the `Base64UrlStr` type can be used to faithfully decode "vanilla" base64 data (using `'+'` and `'/'`). -```python +```py from pydantic import Base64UrlStr, BaseModel class Model(BaseModel): @@ -2846,17 +2371,16 @@ __getattr__ = getattr_migration(__name__) @_dataclasses.dataclass(**_internal_dataclass.slots_true) class GetPydanticSchema: - """!!! 
abstract "Usage Documentation" - [Using `GetPydanticSchema` to Reduce Boilerplate](../concepts/types.md#using-getpydanticschema-to-reduce-boilerplate) - - A convenience class for creating an annotation that provides pydantic custom type hooks. + """A convenience class for creating an annotation that provides pydantic custom type hooks. This class is intended to eliminate the need to create a custom "marker" which defines the `__get_pydantic_core_schema__` and `__get_pydantic_json_schema__` custom hook methods. For example, to have a field treated by type checkers as `int`, but by pydantic as `Any`, you can do: ```python - from typing import Annotated, Any + from typing import Any + + from typing_extensions import Annotated from pydantic import BaseModel, GetPydanticSchema @@ -2889,397 +2413,3 @@ class GetPydanticSchema: return object.__getattribute__(self, item) __hash__ = object.__hash__ - - -@_dataclasses.dataclass(**_internal_dataclass.slots_true, frozen=True) -class Tag: - """Provides a way to specify the expected tag to use for a case of a (callable) discriminated union. - - Also provides a way to label a union case in error messages. - - When using a callable `Discriminator`, attach a `Tag` to each case in the `Union` to specify the tag that - should be used to identify that case. For example, in the below example, the `Tag` is used to specify that - if `get_discriminator_value` returns `'apple'`, the input should be validated as an `ApplePie`, and if it - returns `'pumpkin'`, the input should be validated as a `PumpkinPie`. - - The primary role of the `Tag` here is to map the return value from the callable `Discriminator` function to - the appropriate member of the `Union` in question. - - ```python - from typing import Annotated, Any, Literal, Union - - from pydantic import BaseModel, Discriminator, Tag - - class Pie(BaseModel): - time_to_cook: int - num_ingredients: int - - class ApplePie(Pie): - fruit: Literal['apple'] = 'apple' - - class PumpkinPie(Pie): - filling: Literal['pumpkin'] = 'pumpkin' - - def get_discriminator_value(v: Any) -> str: - if isinstance(v, dict): - return v.get('fruit', v.get('filling')) - return getattr(v, 'fruit', getattr(v, 'filling', None)) - - class ThanksgivingDinner(BaseModel): - dessert: Annotated[ - Union[ - Annotated[ApplePie, Tag('apple')], - Annotated[PumpkinPie, Tag('pumpkin')], - ], - Discriminator(get_discriminator_value), - ] - - apple_variation = ThanksgivingDinner.model_validate( - {'dessert': {'fruit': 'apple', 'time_to_cook': 60, 'num_ingredients': 8}} - ) - print(repr(apple_variation)) - ''' - ThanksgivingDinner(dessert=ApplePie(time_to_cook=60, num_ingredients=8, fruit='apple')) - ''' - - pumpkin_variation = ThanksgivingDinner.model_validate( - { - 'dessert': { - 'filling': 'pumpkin', - 'time_to_cook': 40, - 'num_ingredients': 6, - } - } - ) - print(repr(pumpkin_variation)) - ''' - ThanksgivingDinner(dessert=PumpkinPie(time_to_cook=40, num_ingredients=6, filling='pumpkin')) - ''' - ``` - - !!! note - You must specify a `Tag` for every case in a `Tag` that is associated with a - callable `Discriminator`. Failing to do so will result in a `PydanticUserError` with code - [`callable-discriminator-no-tag`](../errors/usage_errors.md#callable-discriminator-no-tag). - - See the [Discriminated Unions] concepts docs for more details on how to use `Tag`s. 
- - [Discriminated Unions]: ../concepts/unions.md#discriminated-unions - """ - - tag: str - - def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: - schema = handler(source_type) - metadata = cast('CoreMetadata', schema.setdefault('metadata', {})) - metadata['pydantic_internal_union_tag_key'] = self.tag - return schema - - -@_dataclasses.dataclass(**_internal_dataclass.slots_true, frozen=True) -class Discriminator: - """!!! abstract "Usage Documentation" - [Discriminated Unions with `Callable` `Discriminator`](../concepts/unions.md#discriminated-unions-with-callable-discriminator) - - Provides a way to use a custom callable as the way to extract the value of a union discriminator. - - This allows you to get validation behavior like you'd get from `Field(discriminator=)`, - but without needing to have a single shared field across all the union choices. This also makes it - possible to handle unions of models and primitive types with discriminated-union-style validation errors. - Finally, this allows you to use a custom callable as the way to identify which member of a union a value - belongs to, while still seeing all the performance benefits of a discriminated union. - - Consider this example, which is much more performant with the use of `Discriminator` and thus a `TaggedUnion` - than it would be as a normal `Union`. - - ```python - from typing import Annotated, Any, Literal, Union - - from pydantic import BaseModel, Discriminator, Tag - - class Pie(BaseModel): - time_to_cook: int - num_ingredients: int - - class ApplePie(Pie): - fruit: Literal['apple'] = 'apple' - - class PumpkinPie(Pie): - filling: Literal['pumpkin'] = 'pumpkin' - - def get_discriminator_value(v: Any) -> str: - if isinstance(v, dict): - return v.get('fruit', v.get('filling')) - return getattr(v, 'fruit', getattr(v, 'filling', None)) - - class ThanksgivingDinner(BaseModel): - dessert: Annotated[ - Union[ - Annotated[ApplePie, Tag('apple')], - Annotated[PumpkinPie, Tag('pumpkin')], - ], - Discriminator(get_discriminator_value), - ] - - apple_variation = ThanksgivingDinner.model_validate( - {'dessert': {'fruit': 'apple', 'time_to_cook': 60, 'num_ingredients': 8}} - ) - print(repr(apple_variation)) - ''' - ThanksgivingDinner(dessert=ApplePie(time_to_cook=60, num_ingredients=8, fruit='apple')) - ''' - - pumpkin_variation = ThanksgivingDinner.model_validate( - { - 'dessert': { - 'filling': 'pumpkin', - 'time_to_cook': 40, - 'num_ingredients': 6, - } - } - ) - print(repr(pumpkin_variation)) - ''' - ThanksgivingDinner(dessert=PumpkinPie(time_to_cook=40, num_ingredients=6, filling='pumpkin')) - ''' - ``` - - See the [Discriminated Unions] concepts docs for more details on how to use `Discriminator`s. - - [Discriminated Unions]: ../concepts/unions.md#discriminated-unions - """ - - discriminator: str | Callable[[Any], Hashable] - """The callable or field name for discriminating the type in a tagged union. - - A `Callable` discriminator must extract the value of the discriminator from the input. - A `str` discriminator must be the name of a field to discriminate against. - """ - custom_error_type: str | None = None - """Type to use in [custom errors](../errors/errors.md) replacing the standard discriminated union - validation errors. 
- """ - custom_error_message: str | None = None - """Message to use in custom errors.""" - custom_error_context: dict[str, int | str | float] | None = None - """Context to use in custom errors.""" - - def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: - if not is_union_origin(get_origin(source_type)): - raise TypeError(f'{type(self).__name__} must be used with a Union type, not {source_type}') - - if isinstance(self.discriminator, str): - from pydantic import Field - - return handler(Annotated[source_type, Field(discriminator=self.discriminator)]) - else: - original_schema = handler(source_type) - return self._convert_schema(original_schema) - - def _convert_schema(self, original_schema: core_schema.CoreSchema) -> core_schema.TaggedUnionSchema: - if original_schema['type'] != 'union': - # This likely indicates that the schema was a single-item union that was simplified. - # In this case, we do the same thing we do in - # `pydantic._internal._discriminated_union._ApplyInferredDiscriminator._apply_to_root`, namely, - # package the generated schema back into a single-item union. - original_schema = core_schema.union_schema([original_schema]) - - tagged_union_choices = {} - for choice in original_schema['choices']: - tag = None - if isinstance(choice, tuple): - choice, tag = choice - metadata = cast('CoreMetadata | None', choice.get('metadata')) - if metadata is not None: - tag = metadata.get('pydantic_internal_union_tag_key') or tag - if tag is None: - raise PydanticUserError( - f'`Tag` not provided for choice {choice} used with `Discriminator`', - code='callable-discriminator-no-tag', - ) - tagged_union_choices[tag] = choice - - # Have to do these verbose checks to ensure falsy values ('' and {}) don't get ignored - custom_error_type = self.custom_error_type - if custom_error_type is None: - custom_error_type = original_schema.get('custom_error_type') - - custom_error_message = self.custom_error_message - if custom_error_message is None: - custom_error_message = original_schema.get('custom_error_message') - - custom_error_context = self.custom_error_context - if custom_error_context is None: - custom_error_context = original_schema.get('custom_error_context') - - custom_error_type = original_schema.get('custom_error_type') if custom_error_type is None else custom_error_type - return core_schema.tagged_union_schema( - tagged_union_choices, - self.discriminator, - custom_error_type=custom_error_type, - custom_error_message=custom_error_message, - custom_error_context=custom_error_context, - strict=original_schema.get('strict'), - ref=original_schema.get('ref'), - metadata=original_schema.get('metadata'), - serialization=original_schema.get('serialization'), - ) - - -_JSON_TYPES = {int, float, str, bool, list, dict, type(None)} - - -def _get_type_name(x: Any) -> str: - type_ = type(x) - if type_ in _JSON_TYPES: - return type_.__name__ - - # Handle proper subclasses; note we don't need to handle None or bool here - if isinstance(x, int): - return 'int' - if isinstance(x, float): - return 'float' - if isinstance(x, str): - return 'str' - if isinstance(x, list): - return 'list' - if isinstance(x, dict): - return 'dict' - - # Fail by returning the type's actual name - return getattr(type_, '__name__', '') - - -class _AllowAnyJson: - @classmethod - def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: - python_schema = handler(source_type) - return 
core_schema.json_or_python_schema(json_schema=core_schema.any_schema(), python_schema=python_schema) - - -if TYPE_CHECKING: - # This seems to only be necessary for mypy - JsonValue: TypeAlias = Union[ - list['JsonValue'], - dict[str, 'JsonValue'], - str, - bool, - int, - float, - None, - ] - """A `JsonValue` is used to represent a value that can be serialized to JSON. - - It may be one of: - - * `list['JsonValue']` - * `dict[str, 'JsonValue']` - * `str` - * `bool` - * `int` - * `float` - * `None` - - The following example demonstrates how to use `JsonValue` to validate JSON data, - and what kind of errors to expect when input data is not json serializable. - - ```python - import json - - from pydantic import BaseModel, JsonValue, ValidationError - - class Model(BaseModel): - j: JsonValue - - valid_json_data = {'j': {'a': {'b': {'c': 1, 'd': [2, None]}}}} - invalid_json_data = {'j': {'a': {'b': ...}}} - - print(repr(Model.model_validate(valid_json_data))) - #> Model(j={'a': {'b': {'c': 1, 'd': [2, None]}}}) - print(repr(Model.model_validate_json(json.dumps(valid_json_data)))) - #> Model(j={'a': {'b': {'c': 1, 'd': [2, None]}}}) - - try: - Model.model_validate(invalid_json_data) - except ValidationError as e: - print(e) - ''' - 1 validation error for Model - j.dict.a.dict.b - input was not a valid JSON value [type=invalid-json-value, input_value=Ellipsis, input_type=ellipsis] - ''' - ``` - """ - -else: - JsonValue = TypeAliasType( - 'JsonValue', - Annotated[ - Union[ - Annotated[list['JsonValue'], Tag('list')], - Annotated[dict[str, 'JsonValue'], Tag('dict')], - Annotated[str, Tag('str')], - Annotated[bool, Tag('bool')], - Annotated[int, Tag('int')], - Annotated[float, Tag('float')], - Annotated[None, Tag('NoneType')], - ], - Discriminator( - _get_type_name, - custom_error_type='invalid-json-value', - custom_error_message='input was not a valid JSON value', - ), - _AllowAnyJson, - ], - ) - - -class _OnErrorOmit: - @classmethod - def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: - # there is no actual default value here but we use with_default_schema since it already has the on_error - # behavior implemented and it would be no more efficient to implement it on every other validator - # or as a standalone validator - return core_schema.with_default_schema(schema=handler(source_type), on_error='omit') - - -OnErrorOmit = Annotated[T, _OnErrorOmit] -""" -When used as an item in a list, the key type in a dict, optional values of a TypedDict, etc. -this annotation omits the item from the iteration if there is any error validating it. -That is, instead of a [`ValidationError`][pydantic_core.ValidationError] being propagated up and the entire iterable being discarded -any invalid items are discarded and the valid ones are returned. -""" - - -@_dataclasses.dataclass -class FailFast(_fields.PydanticMetadata, BaseMetadata): - """A `FailFast` annotation can be used to specify that validation should stop at the first error. - - This can be useful when you want to validate a large amount of data and you only need to know if it's valid or not. - - You might want to enable this setting if you want to validate your data faster (basically, if you use this, - validation will be more performant with the caveat that you get less information). 
- - ```python - from typing import Annotated - - from pydantic import BaseModel, FailFast, ValidationError - - class Model(BaseModel): - x: Annotated[list[int], FailFast()] - - # This will raise a single error for the first invalid value and stop validation - try: - obj = Model(x=[1, 2, 'a', 4, 5, 'b', 7, 8, 9, 'c']) - except ValidationError as e: - print(e) - ''' - 1 validation error for Model - x.2 - Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='a', input_type=str] - ''' - ``` - """ - - fail_fast: bool = True diff --git a/venv/lib/python3.12/site-packages/pydantic/typing.py b/venv/lib/python3.12/site-packages/pydantic/typing.py index 0bda22d..f1b32ba 100644 --- a/venv/lib/python3.12/site-packages/pydantic/typing.py +++ b/venv/lib/python3.12/site-packages/pydantic/typing.py @@ -1,5 +1,4 @@ """`typing` module is a backport module from V1.""" - from ._migration import getattr_migration __getattr__ = getattr_migration(__name__) diff --git a/venv/lib/python3.12/site-packages/pydantic/utils.py b/venv/lib/python3.12/site-packages/pydantic/utils.py index 8d1e2a8..1619d1d 100644 --- a/venv/lib/python3.12/site-packages/pydantic/utils.py +++ b/venv/lib/python3.12/site-packages/pydantic/utils.py @@ -1,5 +1,4 @@ """The `utils` module is a backport module from V1.""" - from ._migration import getattr_migration __getattr__ = getattr_migration(__name__) diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/__init__.py b/venv/lib/python3.12/site-packages/pydantic/v1/__init__.py index 6ad3f46..3bf1418 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/__init__.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/__init__.py @@ -1,24 +1,24 @@ # flake8: noqa -from pydantic.v1 import dataclasses -from pydantic.v1.annotated_types import create_model_from_namedtuple, create_model_from_typeddict -from pydantic.v1.class_validators import root_validator, validator -from pydantic.v1.config import BaseConfig, ConfigDict, Extra -from pydantic.v1.decorator import validate_arguments -from pydantic.v1.env_settings import BaseSettings -from pydantic.v1.error_wrappers import ValidationError -from pydantic.v1.errors import * -from pydantic.v1.fields import Field, PrivateAttr, Required -from pydantic.v1.main import * -from pydantic.v1.networks import * -from pydantic.v1.parse import Protocol -from pydantic.v1.tools import * -from pydantic.v1.types import * -from pydantic.v1.version import VERSION, compiled +from . import dataclasses +from .annotated_types import create_model_from_namedtuple, create_model_from_typeddict +from .class_validators import root_validator, validator +from .config import BaseConfig, ConfigDict, Extra +from .decorator import validate_arguments +from .env_settings import BaseSettings +from .error_wrappers import ValidationError +from .errors import * +from .fields import Field, PrivateAttr, Required +from .main import * +from .networks import * +from .parse import Protocol +from .tools import * +from .types import * +from .version import VERSION, compiled __version__ = VERSION -# WARNING __all__ from pydantic.errors is not included here, it will be removed as an export here in v2 -# please use "from pydantic.v1.errors import ..." instead +# WARNING __all__ from .errors is not included here, it will be removed as an export here in v2 +# please use "from pydantic.errors import ..." 
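(Editor's note) Unlike `FailFast`, the `OnErrorOmit` annotation removed above has no inline example. A minimal sketch of how it drops invalid items instead of failing the whole collection, assuming a Pydantic 2.x release recent enough to export `OnErrorOmit` (the `Report`/`readings` names are made up for illustration):

```python
from pydantic import BaseModel, OnErrorOmit

class Report(BaseModel):
    # Items that fail int validation are dropped instead of raising ValidationError.
    readings: list[OnErrorOmit[int]]

print(Report(readings=[1, 'oops', 2, None, 3]).readings)
#> [1, 2, 3]
```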
instead __all__ = [ # annotated types utils 'create_model_from_namedtuple', diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/_hypothesis_plugin.py b/venv/lib/python3.12/site-packages/pydantic/v1/_hypothesis_plugin.py index b62234d..0c52962 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/_hypothesis_plugin.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/_hypothesis_plugin.py @@ -35,7 +35,7 @@ import hypothesis.strategies as st import pydantic import pydantic.color import pydantic.types -from pydantic.v1.utils import lenient_issubclass +from pydantic.utils import lenient_issubclass # FilePath and DirectoryPath are explicitly unsupported, as we'd have to create # them on-disk, and that's unsafe in general without being told *where* to do so. diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/annotated_types.py b/venv/lib/python3.12/site-packages/pydantic/v1/annotated_types.py index d9eaaaf..d333457 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/annotated_types.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/annotated_types.py @@ -1,9 +1,9 @@ import sys from typing import TYPE_CHECKING, Any, Dict, FrozenSet, NamedTuple, Type -from pydantic.v1.fields import Required -from pydantic.v1.main import BaseModel, create_model -from pydantic.v1.typing import is_typeddict, is_typeddict_special +from .fields import Required +from .main import BaseModel, create_model +from .typing import is_typeddict, is_typeddict_special if TYPE_CHECKING: from typing_extensions import TypedDict diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/class_validators.py b/venv/lib/python3.12/site-packages/pydantic/v1/class_validators.py index 2f68fc8..71e6650 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/class_validators.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/class_validators.py @@ -5,12 +5,12 @@ from itertools import chain from types import FunctionType from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, overload -from pydantic.v1.errors import ConfigError -from pydantic.v1.typing import AnyCallable -from pydantic.v1.utils import ROOT_KEY, in_ipython +from .errors import ConfigError +from .typing import AnyCallable +from .utils import ROOT_KEY, in_ipython if TYPE_CHECKING: - from pydantic.v1.typing import AnyClassMethod + from .typing import AnyClassMethod class Validator: @@ -36,9 +36,9 @@ class Validator: if TYPE_CHECKING: from inspect import Signature - from pydantic.v1.config import BaseConfig - from pydantic.v1.fields import ModelField - from pydantic.v1.types import ModelOrDc + from .config import BaseConfig + from .fields import ModelField + from .types import ModelOrDc ValidatorCallable = Callable[[Optional[ModelOrDc], Any, Dict[str, Any], ModelField, Type[BaseConfig]], Any] ValidatorsList = List[ValidatorCallable] diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/color.py b/venv/lib/python3.12/site-packages/pydantic/v1/color.py index b0bbf78..6fdc9fb 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/color.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/color.py @@ -12,11 +12,11 @@ import re from colorsys import hls_to_rgb, rgb_to_hls from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast -from pydantic.v1.errors import ColorError -from pydantic.v1.utils import Representation, almost_equal_floats +from .errors import ColorError +from .utils import Representation, almost_equal_floats if TYPE_CHECKING: - from pydantic.v1.typing 
import CallableGenerator, ReprArgs + from .typing import CallableGenerator, ReprArgs ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]] ColorType = Union[ColorTuple, str] diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/config.py b/venv/lib/python3.12/site-packages/pydantic/v1/config.py index 18f7c99..a25973a 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/config.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/config.py @@ -4,15 +4,15 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, ForwardRef, Optional, Tup from typing_extensions import Literal, Protocol -from pydantic.v1.typing import AnyArgTCallable, AnyCallable -from pydantic.v1.utils import GetterDict -from pydantic.v1.version import compiled +from .typing import AnyArgTCallable, AnyCallable +from .utils import GetterDict +from .version import compiled if TYPE_CHECKING: from typing import overload - from pydantic.v1.fields import ModelField - from pydantic.v1.main import BaseModel + from .fields import ModelField + from .main import BaseModel ConfigType = Type['BaseConfig'] diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/dataclasses.py b/venv/lib/python3.12/site-packages/pydantic/v1/dataclasses.py index bd16702..86bad1e 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/dataclasses.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/dataclasses.py @@ -36,28 +36,21 @@ import dataclasses import sys from contextlib import contextmanager from functools import wraps - -try: - from functools import cached_property -except ImportError: - # cached_property available only for python3.8+ - pass - from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload from typing_extensions import dataclass_transform -from pydantic.v1.class_validators import gather_all_validators -from pydantic.v1.config import BaseConfig, ConfigDict, Extra, get_config -from pydantic.v1.error_wrappers import ValidationError -from pydantic.v1.errors import DataclassTypeError -from pydantic.v1.fields import Field, FieldInfo, Required, Undefined -from pydantic.v1.main import create_model, validate_model -from pydantic.v1.utils import ClassAttribute +from .class_validators import gather_all_validators +from .config import BaseConfig, ConfigDict, Extra, get_config +from .error_wrappers import ValidationError +from .errors import DataclassTypeError +from .fields import Field, FieldInfo, Required, Undefined +from .main import create_model, validate_model +from .utils import ClassAttribute if TYPE_CHECKING: - from pydantic.v1.main import BaseModel - from pydantic.v1.typing import CallableGenerator, NoArgAnyCallable + from .main import BaseModel + from .typing import CallableGenerator, NoArgAnyCallable DataclassT = TypeVar('DataclassT', bound='Dataclass') @@ -416,17 +409,6 @@ def create_pydantic_model_from_dataclass( return model -if sys.version_info >= (3, 8): - - def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool: - return isinstance(getattr(type(obj), k, None), cached_property) - -else: - - def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool: - return False - - def _dataclass_validate_values(self: 'Dataclass') -> None: # validation errors can occur if this function is called twice on an already initialised dataclass. 
# for example if Extra.forbid is enabled, it would consider __pydantic_initialised__ an invalid extra property @@ -435,13 +417,9 @@ def _dataclass_validate_values(self: 'Dataclass') -> None: if getattr(self, '__pydantic_has_field_info_default__', False): # We need to remove `FieldInfo` values since they are not valid as input # It's ok to do that because they are obviously the default values! - input_data = { - k: v - for k, v in self.__dict__.items() - if not (isinstance(v, FieldInfo) or _is_field_cached_property(self, k)) - } + input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)} else: - input_data = {k: v for k, v in self.__dict__.items() if not _is_field_cached_property(self, k)} + input_data = self.__dict__ d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__) if validation_error: raise validation_error diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/datetime_parse.py b/venv/lib/python3.12/site-packages/pydantic/v1/datetime_parse.py index a7598fc..cfd5459 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/datetime_parse.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/datetime_parse.py @@ -18,7 +18,7 @@ import re from datetime import date, datetime, time, timedelta, timezone from typing import Dict, Optional, Type, Union -from pydantic.v1 import errors +from . import errors date_expr = r'(?P\d{4})-(?P\d{1,2})-(?P\d{1,2})' time_expr = ( diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/decorator.py b/venv/lib/python3.12/site-packages/pydantic/v1/decorator.py index 2c7c2c2..089aab6 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/decorator.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/decorator.py @@ -1,17 +1,17 @@ from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload -from pydantic.v1 import validator -from pydantic.v1.config import Extra -from pydantic.v1.errors import ConfigError -from pydantic.v1.main import BaseModel, create_model -from pydantic.v1.typing import get_all_type_hints -from pydantic.v1.utils import to_camel +from . 
import validator +from .config import Extra +from .errors import ConfigError +from .main import BaseModel, create_model +from .typing import get_all_type_hints +from .utils import to_camel __all__ = ('validate_arguments',) if TYPE_CHECKING: - from pydantic.v1.typing import AnyCallable + from .typing import AnyCallable AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable) ConfigType = Union[None, Type[Any], Dict[str, Any]] diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/env_settings.py b/venv/lib/python3.12/site-packages/pydantic/v1/env_settings.py index 5f6f217..6c446e5 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/env_settings.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/env_settings.py @@ -3,12 +3,12 @@ import warnings from pathlib import Path from typing import AbstractSet, Any, Callable, ClassVar, Dict, List, Mapping, Optional, Tuple, Type, Union -from pydantic.v1.config import BaseConfig, Extra -from pydantic.v1.fields import ModelField -from pydantic.v1.main import BaseModel -from pydantic.v1.types import JsonWrapper -from pydantic.v1.typing import StrPath, display_as_type, get_origin, is_union -from pydantic.v1.utils import deep_update, lenient_issubclass, path_type, sequence_like +from .config import BaseConfig, Extra +from .fields import ModelField +from .main import BaseModel +from .types import JsonWrapper +from .typing import StrPath, display_as_type, get_origin, is_union +from .utils import deep_update, lenient_issubclass, path_type, sequence_like env_file_sentinel = str(object()) diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/error_wrappers.py b/venv/lib/python3.12/site-packages/pydantic/v1/error_wrappers.py index bc7f263..5d3204f 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/error_wrappers.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/error_wrappers.py @@ -1,15 +1,15 @@ import json from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple, Type, Union -from pydantic.v1.json import pydantic_encoder -from pydantic.v1.utils import Representation +from .json import pydantic_encoder +from .utils import Representation if TYPE_CHECKING: from typing_extensions import TypedDict - from pydantic.v1.config import BaseConfig - from pydantic.v1.types import ModelOrDc - from pydantic.v1.typing import ReprArgs + from .config import BaseConfig + from .types import ModelOrDc + from .typing import ReprArgs Loc = Tuple[Union[int, str], ...] @@ -101,6 +101,7 @@ def flatten_errors( ) -> Generator['ErrorDict', None, None]: for error in errors: if isinstance(error, ErrorWrapper): + if loc: error_loc = loc + error.loc_tuple() else: diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/errors.py b/venv/lib/python3.12/site-packages/pydantic/v1/errors.py index 6e86442..7bdafdd 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/errors.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/errors.py @@ -2,12 +2,12 @@ from decimal import Decimal from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Sequence, Set, Tuple, Type, Union -from pydantic.v1.typing import display_as_type +from .typing import display_as_type if TYPE_CHECKING: - from pydantic.v1.typing import DictStrAny + from .typing import DictStrAny -# explicitly state exports to avoid "from pydantic.v1.errors import *" also importing Decimal, Path etc. +# explicitly state exports to avoid "from .errors import *" also importing Decimal, Path etc. 
__all__ = ( 'PydanticTypeError', 'PydanticValueError', diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/fields.py b/venv/lib/python3.12/site-packages/pydantic/v1/fields.py index 002b60c..b1856c1 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/fields.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/fields.py @@ -28,12 +28,12 @@ from typing import ( from typing_extensions import Annotated, Final -from pydantic.v1 import errors as errors_ -from pydantic.v1.class_validators import Validator, make_generic_validator, prep_validators -from pydantic.v1.error_wrappers import ErrorWrapper -from pydantic.v1.errors import ConfigError, InvalidDiscriminator, MissingDiscriminator, NoneIsNotAllowedError -from pydantic.v1.types import Json, JsonWrapper -from pydantic.v1.typing import ( +from . import errors as errors_ +from .class_validators import Validator, make_generic_validator, prep_validators +from .error_wrappers import ErrorWrapper +from .errors import ConfigError, InvalidDiscriminator, MissingDiscriminator, NoneIsNotAllowedError +from .types import Json, JsonWrapper +from .typing import ( NoArgAnyCallable, convert_generics, display_as_type, @@ -48,7 +48,7 @@ from pydantic.v1.typing import ( is_union, new_type_supertype, ) -from pydantic.v1.utils import ( +from .utils import ( PyObjectStr, Representation, ValueItems, @@ -59,7 +59,7 @@ from pydantic.v1.utils import ( sequence_like, smart_deepcopy, ) -from pydantic.v1.validators import constant_validator, dict_validator, find_validators, validate_json +from .validators import constant_validator, dict_validator, find_validators, validate_json Required: Any = Ellipsis @@ -83,11 +83,11 @@ class UndefinedType: Undefined = UndefinedType() if TYPE_CHECKING: - from pydantic.v1.class_validators import ValidatorsList - from pydantic.v1.config import BaseConfig - from pydantic.v1.error_wrappers import ErrorList - from pydantic.v1.types import ModelOrDc - from pydantic.v1.typing import AbstractSetIntStr, MappingIntStrAny, ReprArgs + from .class_validators import ValidatorsList + from .config import BaseConfig + from .error_wrappers import ErrorList + from .types import ModelOrDc + from .typing import AbstractSetIntStr, MappingIntStrAny, ReprArgs ValidateReturn = Tuple[Optional[Any], Optional[ErrorList]] LocStr = Union[Tuple[Union[int, str], ...], str] @@ -178,6 +178,7 @@ class FieldInfo(Representation): self.extra = kwargs def __repr_args__(self) -> 'ReprArgs': + field_defaults_to_hide: Dict[str, Any] = { 'repr': True, **self.__field_constraints__, @@ -404,6 +405,7 @@ class ModelField(Representation): alias: Optional[str] = None, field_info: Optional[FieldInfo] = None, ) -> None: + self.name: str = name self.has_alias: bool = alias is not None self.alias: str = alias if alias is not None else name @@ -490,7 +492,7 @@ class ModelField(Representation): class_validators: Optional[Dict[str, Validator]], config: Type['BaseConfig'], ) -> 'ModelField': - from pydantic.v1.schema import get_annotation_from_field_info + from .schema import get_annotation_from_field_info field_info, value = cls._get_field_info(name, annotation, value, config) required: 'BoolUndefined' = Undefined @@ -850,6 +852,7 @@ class ModelField(Representation): def validate( self, v: Any, values: Dict[str, Any], *, loc: 'LocStr', cls: Optional['ModelOrDc'] = None ) -> 'ValidateReturn': + assert self.type_.__class__ is not DeferredType if self.type_.__class__ is ForwardRef: @@ -1160,7 +1163,7 @@ class ModelField(Representation): """ Whether the field is "complex" eg. 
env variables should be parsed as JSON. """ - from pydantic.v1.main import BaseModel + from .main import BaseModel return ( self.shape != SHAPE_SINGLETON diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/generics.py b/venv/lib/python3.12/site-packages/pydantic/v1/generics.py index 9a69f2b..a75b6b9 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/generics.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/generics.py @@ -22,12 +22,12 @@ from weakref import WeakKeyDictionary, WeakValueDictionary from typing_extensions import Annotated, Literal as ExtLiteral -from pydantic.v1.class_validators import gather_all_validators -from pydantic.v1.fields import DeferredType -from pydantic.v1.main import BaseModel, create_model -from pydantic.v1.types import JsonWrapper -from pydantic.v1.typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base -from pydantic.v1.utils import all_identical, lenient_issubclass +from .class_validators import gather_all_validators +from .fields import DeferredType +from .main import BaseModel, create_model +from .types import JsonWrapper +from .typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base +from .utils import all_identical, lenient_issubclass if sys.version_info >= (3, 10): from typing import _UnionGenericAlias diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/json.py b/venv/lib/python3.12/site-packages/pydantic/v1/json.py index 41d0d5f..b358b85 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/json.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/json.py @@ -9,9 +9,9 @@ from types import GeneratorType from typing import Any, Callable, Dict, Type, Union from uuid import UUID -from pydantic.v1.color import Color -from pydantic.v1.networks import NameEmail -from pydantic.v1.types import SecretBytes, SecretStr +from .color import Color +from .networks import NameEmail +from .types import SecretBytes, SecretStr __all__ = 'pydantic_encoder', 'custom_pydantic_encoder', 'timedelta_isoformat' @@ -72,7 +72,7 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { def pydantic_encoder(obj: Any) -> Any: from dataclasses import asdict, is_dataclass - from pydantic.v1.main import BaseModel + from .main import BaseModel if isinstance(obj, BaseModel): return obj.dict() diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/main.py b/venv/lib/python3.12/site-packages/pydantic/v1/main.py index 8000967..683f3f8 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/main.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/main.py @@ -26,11 +26,11 @@ from typing import ( from typing_extensions import dataclass_transform -from pydantic.v1.class_validators import ValidatorGroup, extract_root_validators, extract_validators, inherit_validators -from pydantic.v1.config import BaseConfig, Extra, inherit_config, prepare_config -from pydantic.v1.error_wrappers import ErrorWrapper, ValidationError -from pydantic.v1.errors import ConfigError, DictError, ExtraError, MissingError -from pydantic.v1.fields import ( +from .class_validators import ValidatorGroup, extract_root_validators, extract_validators, inherit_validators +from .config import BaseConfig, Extra, inherit_config, prepare_config +from .error_wrappers import ErrorWrapper, ValidationError +from .errors import ConfigError, DictError, ExtraError, MissingError +from .fields import ( MAPPING_LIKE_SHAPES, Field, ModelField, @@ -39,11 +39,11 @@ from pydantic.v1.fields import ( Undefined, is_finalvar_with_default_val, ) 
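(Editor's note) The `is_complex` comment above ("env variables should be parsed as JSON") refers to v1 `BaseSettings` behaviour: non-singleton fields are decoded from environment variables as JSON. A small illustrative sketch using the vendored v1 API touched by this diff (the `Settings`/`tags` names are made up):

```python
import os
from typing import List

from pydantic.v1 import BaseSettings

class Settings(BaseSettings):
    tags: List[str] = []  # a "complex" (non-singleton) field

# Complex fields are decoded from the environment as JSON.
os.environ['TAGS'] = '["alpha", "beta"]'
print(Settings().tags)
#> ['alpha', 'beta']
```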
-from pydantic.v1.json import custom_pydantic_encoder, pydantic_encoder -from pydantic.v1.parse import Protocol, load_file, load_str_bytes -from pydantic.v1.schema import default_ref_template, model_schema -from pydantic.v1.types import PyObject, StrBytes -from pydantic.v1.typing import ( +from .json import custom_pydantic_encoder, pydantic_encoder +from .parse import Protocol, load_file, load_str_bytes +from .schema import default_ref_template, model_schema +from .types import PyObject, StrBytes +from .typing import ( AnyCallable, get_args, get_origin, @@ -53,7 +53,7 @@ from pydantic.v1.typing import ( resolve_annotations, update_model_forward_refs, ) -from pydantic.v1.utils import ( +from .utils import ( DUNDER_ATTRIBUTES, ROOT_KEY, ClassAttribute, @@ -73,9 +73,9 @@ from pydantic.v1.utils import ( if TYPE_CHECKING: from inspect import Signature - from pydantic.v1.class_validators import ValidatorListDict - from pydantic.v1.types import ModelOrDc - from pydantic.v1.typing import ( + from .class_validators import ValidatorListDict + from .types import ModelOrDc + from .typing import ( AbstractSetIntStr, AnyClassMethod, CallableGenerator, @@ -282,12 +282,6 @@ class ModelMetaclass(ABCMeta): cls = super().__new__(mcs, name, bases, new_namespace, **kwargs) # set __signature__ attr only for model class, but not for its instances cls.__signature__ = ClassAttribute('__signature__', generate_model_signature(cls.__init__, fields, config)) - - if not _is_base_model_class_defined: - # Cython does not understand the `if TYPE_CHECKING:` condition in the - # BaseModel's body (where annotations are set), so clear them manually: - getattr(cls, '__annotations__', {}).clear() - if resolve_forward_refs: cls.__try_update_forward_refs__() @@ -307,7 +301,7 @@ class ModelMetaclass(ABCMeta): See #3829 and python/cpython#92810 """ - return hasattr(instance, '__post_root_validators__') and super().__instancecheck__(instance) + return hasattr(instance, '__fields__') and super().__instancecheck__(instance) object_setattr = object.__setattr__ @@ -675,7 +669,7 @@ class BaseModel(Representation, metaclass=ModelMetaclass): def schema_json( cls, *, by_alias: bool = True, ref_template: str = default_ref_template, **dumps_kwargs: Any ) -> str: - from pydantic.v1.json import pydantic_encoder + from .json import pydantic_encoder return cls.__config__.json_dumps( cls.schema(by_alias=by_alias, ref_template=ref_template), default=pydantic_encoder, **dumps_kwargs @@ -743,6 +737,7 @@ class BaseModel(Representation, metaclass=ModelMetaclass): exclude_defaults: bool, exclude_none: bool, ) -> Any: + if isinstance(v, BaseModel): if to_dict: v_dict = v.dict( @@ -835,6 +830,7 @@ class BaseModel(Representation, metaclass=ModelMetaclass): exclude_defaults: bool = False, exclude_none: bool = False, ) -> 'TupleGenerator': + # Merge field set excludes with explicit exclude parameter with explicit overriding field set options. # The extra "is not None" guards are not logically necessary but optimizes performance for the simple case. 
if exclude is not None or self.__exclude_fields__ is not None: diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/mypy.py b/venv/lib/python3.12/site-packages/pydantic/v1/mypy.py index 7912317..1d6d5ae 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/mypy.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/mypy.py @@ -57,7 +57,6 @@ from mypy.types import ( Type, TypeOfAny, TypeType, - TypeVarId, TypeVarType, UnionType, get_proper_type, @@ -66,7 +65,7 @@ from mypy.typevars import fill_typevars from mypy.util import get_unique_redefinition_name from mypy.version import __version__ as mypy_version -from pydantic.v1.utils import is_valid_field +from pydantic.utils import is_valid_field try: from mypy.types import TypeVarDef # type: ignore[attr-defined] @@ -499,11 +498,7 @@ class PydanticModelTransformer: tvd = TypeVarType( self_tvar_name, tvar_fullname, - ( - TypeVarId(-1, namespace=ctx.cls.fullname + '.construct') - if MYPY_VERSION_TUPLE >= (1, 11) - else TypeVarId(-1) - ), + -1, [], obj_type, AnyType(TypeOfAny.from_omitted_generics), # type: ignore[arg-type] @@ -863,9 +858,9 @@ def add_method( arg_kinds.append(arg.kind) function_type = ctx.api.named_type(f'{BUILTINS_NAME}.function') - signature = CallableType( - arg_types, arg_kinds, arg_names, return_type, function_type, variables=[tvar_def] if tvar_def else None - ) + signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type) + if tvar_def: + signature.variables = [tvar_def] func = FuncDef(name, args, Block([PassStmt()])) func.info = info diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/networks.py b/venv/lib/python3.12/site-packages/pydantic/v1/networks.py index ba07b74..cfebe58 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/networks.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/networks.py @@ -27,17 +27,17 @@ from typing import ( no_type_check, ) -from pydantic.v1 import errors -from pydantic.v1.utils import Representation, update_not_none -from pydantic.v1.validators import constr_length_validator, str_validator +from . 
import errors +from .utils import Representation, update_not_none +from .validators import constr_length_validator, str_validator if TYPE_CHECKING: import email_validator from typing_extensions import TypedDict - from pydantic.v1.config import BaseConfig - from pydantic.v1.fields import ModelField - from pydantic.v1.typing import AnyCallable + from .config import BaseConfig + from .fields import ModelField + from .typing import AnyCallable CallableGenerator = Generator[AnyCallable, None, None] diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/parse.py b/venv/lib/python3.12/site-packages/pydantic/v1/parse.py index 431d75a..7ac330c 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/parse.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/parse.py @@ -4,7 +4,7 @@ from enum import Enum from pathlib import Path from typing import Any, Callable, Union -from pydantic.v1.types import StrBytes +from .types import StrBytes class Protocol(str, Enum): diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/schema.py b/venv/lib/python3.12/site-packages/pydantic/v1/schema.py index a91fe2c..31e8ae3 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/schema.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/schema.py @@ -31,7 +31,7 @@ from uuid import UUID from typing_extensions import Annotated, Literal -from pydantic.v1.fields import ( +from .fields import ( MAPPING_LIKE_SHAPES, SHAPE_DEQUE, SHAPE_FROZENSET, @@ -46,9 +46,9 @@ from pydantic.v1.fields import ( FieldInfo, ModelField, ) -from pydantic.v1.json import pydantic_encoder -from pydantic.v1.networks import AnyUrl, EmailStr -from pydantic.v1.types import ( +from .json import pydantic_encoder +from .networks import AnyUrl, EmailStr +from .types import ( ConstrainedDecimal, ConstrainedFloat, ConstrainedFrozenSet, @@ -69,7 +69,7 @@ from pydantic.v1.types import ( conset, constr, ) -from pydantic.v1.typing import ( +from .typing import ( all_literal_values, get_args, get_origin, @@ -80,11 +80,11 @@ from pydantic.v1.typing import ( is_none_type, is_union, ) -from pydantic.v1.utils import ROOT_KEY, get_model, lenient_issubclass +from .utils import ROOT_KEY, get_model, lenient_issubclass if TYPE_CHECKING: - from pydantic.v1.dataclasses import Dataclass - from pydantic.v1.main import BaseModel + from .dataclasses import Dataclass + from .main import BaseModel default_prefix = '#/definitions/' default_ref_template = '#/definitions/{model}' @@ -198,6 +198,7 @@ def model_schema( def get_field_info_schema(field: ModelField, schema_overrides: bool = False) -> Tuple[Dict[str, Any], bool]: + # If no title is explicitly set, we don't set title in the schema for enums. # The behaviour is the same as `BaseModel` reference, where the default title # is in the definitions part of the schema. @@ -378,7 +379,7 @@ def get_flat_models_from_field(field: ModelField, known_models: TypeModelSet) -> :param known_models: used to solve circular references :return: a set with the model used in the declaration for this field, if any, and all its sub-models """ - from pydantic.v1.main import BaseModel + from .main import BaseModel flat_models: TypeModelSet = set() @@ -445,7 +446,7 @@ def field_type_schema( Take a single ``field`` and generate the schema for its type only, not including additional information as title, etc. Also return additional schema definitions, from sub-models. 
""" - from pydantic.v1.main import BaseModel # noqa: F811 + from .main import BaseModel # noqa: F811 definitions = {} nested_models: Set[str] = set() @@ -738,7 +739,7 @@ def field_singleton_sub_fields_schema( discriminator_models_refs[discriminator_value] = discriminator_model_ref['$ref'] s['discriminator'] = { - 'propertyName': field.discriminator_alias if by_alias else field.discriminator_key, + 'propertyName': field.discriminator_alias, 'mapping': discriminator_models_refs, } @@ -838,7 +839,7 @@ def field_singleton_schema( # noqa: C901 (ignore complexity) Take a single Pydantic ``ModelField``, and return its schema and any additional definitions from sub-models. """ - from pydantic.v1.main import BaseModel + from .main import BaseModel definitions: Dict[str, Any] = {} nested_models: Set[str] = set() @@ -974,7 +975,7 @@ def multitypes_literal_field_for_schema(values: Tuple[Any, ...], field: ModelFie def encode_default(dft: Any) -> Any: - from pydantic.v1.main import BaseModel + from .main import BaseModel if isinstance(dft, BaseModel) or is_dataclass(dft): dft = cast('dict[str, Any]', pydantic_encoder(dft)) @@ -1090,7 +1091,7 @@ def get_annotation_with_constraints(annotation: Any, field_info: FieldInfo) -> T if issubclass(type_, (SecretStr, SecretBytes)): attrs = ('max_length', 'min_length') - def constraint_func(**kw: Any) -> Type[Any]: # noqa: F811 + def constraint_func(**kw: Any) -> Type[Any]: return type(type_.__name__, (type_,), kw) elif issubclass(type_, str) and not issubclass(type_, (EmailStr, AnyUrl)): diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/tools.py b/venv/lib/python3.12/site-packages/pydantic/v1/tools.py index 6838a23..45be277 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/tools.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/tools.py @@ -3,16 +3,16 @@ from functools import lru_cache from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, Union -from pydantic.v1.parse import Protocol, load_file, load_str_bytes -from pydantic.v1.types import StrBytes -from pydantic.v1.typing import display_as_type +from .parse import Protocol, load_file, load_str_bytes +from .types import StrBytes +from .typing import display_as_type __all__ = ('parse_file_as', 'parse_obj_as', 'parse_raw_as', 'schema_of', 'schema_json_of') NameFactory = Union[str, Callable[[Type[Any]], str]] if TYPE_CHECKING: - from pydantic.v1.typing import DictStrAny + from .typing import DictStrAny def _generate_parsing_type_name(type_: Any) -> str: @@ -21,7 +21,7 @@ def _generate_parsing_type_name(type_: Any) -> str: @lru_cache(maxsize=2048) def _get_parsing_type(type_: Any, *, type_name: Optional[NameFactory] = None) -> Any: - from pydantic.v1.main import create_model + from .main import create_model if type_name is None: type_name = _generate_parsing_type_name diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/types.py b/venv/lib/python3.12/site-packages/pydantic/v1/types.py index 0cd789a..5881e74 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/types.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/types.py @@ -28,10 +28,10 @@ from typing import ( from uuid import UUID from weakref import WeakSet -from pydantic.v1 import errors -from pydantic.v1.datetime_parse import parse_date -from pydantic.v1.utils import import_string, update_not_none -from pydantic.v1.validators import ( +from . 
import errors +from .datetime_parse import parse_date +from .utils import import_string, update_not_none +from .validators import ( bytes_validator, constr_length_validator, constr_lower, @@ -123,9 +123,9 @@ StrIntFloat = Union[str, int, float] if TYPE_CHECKING: from typing_extensions import Annotated - from pydantic.v1.dataclasses import Dataclass - from pydantic.v1.main import BaseModel - from pydantic.v1.typing import CallableGenerator + from .dataclasses import Dataclass + from .main import BaseModel + from .typing import CallableGenerator ModelOrDc = Type[Union[BaseModel, Dataclass]] @@ -481,7 +481,6 @@ else: # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # This types superclass should be Set[T], but cython chokes on that... class ConstrainedSet(set): # type: ignore # Needed for pydantic to detect that this is a set @@ -570,7 +569,6 @@ def confrozenset( # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LIST TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # This types superclass should be List[T], but cython chokes on that... class ConstrainedList(list): # type: ignore # Needed for pydantic to detect that this is a list @@ -1096,6 +1094,7 @@ class ByteSize(int): @classmethod def validate(cls, v: StrIntFloat) -> 'ByteSize': + try: return cls(int(v)) except ValueError: @@ -1117,6 +1116,7 @@ class ByteSize(int): return cls(int(float(scalar) * unit_mult)) def human_readable(self, decimal: bool = False) -> str: + if decimal: divisor = 1000 units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB'] @@ -1135,6 +1135,7 @@ class ByteSize(int): return f'{num:0.1f}{final_unit}' def to(self, unit: str) -> float: + try: unit_div = BYTE_SIZES[unit.lower()] except KeyError: diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/typing.py b/venv/lib/python3.12/site-packages/pydantic/v1/typing.py index 3038ccd..a690a05 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/typing.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/typing.py @@ -58,21 +58,12 @@ if sys.version_info < (3, 9): def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any: return type_._evaluate(globalns, localns) -elif sys.version_info < (3, 12, 4): +else: def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any: # Even though it is the right signature for python 3.9, mypy complains with # `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast... 
- # Python 3.13/3.12.4+ made `recursive_guard` a kwarg, so name it explicitly to avoid: - # TypeError: ForwardRef._evaluate() missing 1 required keyword-only argument: 'recursive_guard' - return cast(Any, type_)._evaluate(globalns, localns, recursive_guard=set()) - -else: - - def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any: - # Pydantic 1.x will not support PEP 695 syntax, but provide `type_params` to avoid - # warnings: - return cast(Any, type_)._evaluate(globalns, localns, type_params=(), recursive_guard=set()) + return cast(Any, type_)._evaluate(globalns, localns, set()) if sys.version_info < (3, 9): @@ -265,7 +256,7 @@ StrPath = Union[str, PathLike] if TYPE_CHECKING: - from pydantic.v1.fields import ModelField + from .fields import ModelField TupleGenerator = Generator[Tuple[str, Any], None, None] DictStrAny = Dict[str, Any] @@ -406,10 +397,7 @@ def resolve_annotations(raw_annotations: Dict[str, Type[Any]], module_name: Opti else: value = ForwardRef(value, is_argument=False) try: - if sys.version_info >= (3, 13): - value = _eval_type(value, base_globals, None, type_params=()) - else: - value = _eval_type(value, base_globals, None) + value = _eval_type(value, base_globals, None) except NameError: # this is ok, it can be fixed with update_forward_refs pass @@ -447,7 +435,7 @@ def is_namedtuple(type_: Type[Any]) -> bool: Check if a given class is a named tuple. It can be either a `typing.NamedTuple` or `collections.namedtuple` """ - from pydantic.v1.utils import lenient_issubclass + from .utils import lenient_issubclass return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields') @@ -457,7 +445,7 @@ def is_typeddict(type_: Type[Any]) -> bool: Check if a given class is a typed dict (from `typing` or `typing_extensions`) In 3.10, there will be a public method (https://docs.python.org/3.10/library/typing.html#typing.is_typeddict) """ - from pydantic.v1.utils import lenient_issubclass + from .utils import lenient_issubclass return lenient_issubclass(type_, dict) and hasattr(type_, '__total__') diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/utils.py b/venv/lib/python3.12/site-packages/pydantic/v1/utils.py index 02543fd..4d0f68e 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/utils.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/utils.py @@ -28,8 +28,8 @@ from typing import ( from typing_extensions import Annotated -from pydantic.v1.errors import ConfigError -from pydantic.v1.typing import ( +from .errors import ConfigError +from .typing import ( NoneType, WithArgsTypes, all_literal_values, @@ -39,17 +39,17 @@ from pydantic.v1.typing import ( is_literal_type, is_union, ) -from pydantic.v1.version import version_info +from .version import version_info if TYPE_CHECKING: from inspect import Signature from pathlib import Path - from pydantic.v1.config import BaseConfig - from pydantic.v1.dataclasses import Dataclass - from pydantic.v1.fields import ModelField - from pydantic.v1.main import BaseModel - from pydantic.v1.typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs + from .config import BaseConfig + from .dataclasses import Dataclass + from .fields import ModelField + from .main import BaseModel + from .typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs RichReprResult = Iterable[Union[Any, Tuple[Any], Tuple[str, Any], Tuple[str, Any, Any]]] @@ -66,7 +66,6 @@ __all__ = ( 'almost_equal_floats', 'get_model', 'to_camel', - 'to_lower_camel', 'is_valid_field', 
'smart_deepcopy', 'PyObjectStr', @@ -159,7 +158,7 @@ def sequence_like(v: Any) -> bool: return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque)) -def validate_field_name(bases: Iterable[Type[Any]], field_name: str) -> None: +def validate_field_name(bases: List[Type['BaseModel']], field_name: str) -> None: """ Ensure that the field's name does not shadow an existing attribute of the model. """ @@ -241,7 +240,7 @@ def generate_model_signature( """ from inspect import Parameter, Signature, signature - from pydantic.v1.config import Extra + from .config import Extra present_params = signature(init).parameters.values() merged_params: Dict[str, Parameter] = {} @@ -299,7 +298,7 @@ def generate_model_signature( def get_model(obj: Union[Type['BaseModel'], Type['Dataclass']]) -> Type['BaseModel']: - from pydantic.v1.main import BaseModel + from .main import BaseModel try: model_cls = obj.__pydantic_model__ # type: ignore @@ -708,8 +707,6 @@ DUNDER_ATTRIBUTES = { '__orig_bases__', '__orig_class__', '__qualname__', - '__firstlineno__', - '__static_attributes__', } diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/validators.py b/venv/lib/python3.12/site-packages/pydantic/v1/validators.py index c0940e8..549a235 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/validators.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/validators.py @@ -27,11 +27,10 @@ from typing import ( Union, ) from uuid import UUID -from warnings import warn -from pydantic.v1 import errors -from pydantic.v1.datetime_parse import parse_date, parse_datetime, parse_duration, parse_time -from pydantic.v1.typing import ( +from . import errors +from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time +from .typing import ( AnyCallable, all_literal_values, display_as_type, @@ -42,14 +41,14 @@ from pydantic.v1.typing import ( is_none_type, is_typeddict, ) -from pydantic.v1.utils import almost_equal_floats, lenient_issubclass, sequence_like +from .utils import almost_equal_floats, lenient_issubclass, sequence_like if TYPE_CHECKING: from typing_extensions import Literal, TypedDict - from pydantic.v1.config import BaseConfig - from pydantic.v1.fields import ModelField - from pydantic.v1.types import ConstrainedDecimal, ConstrainedFloat, ConstrainedInt + from .config import BaseConfig + from .fields import ModelField + from .types import ConstrainedDecimal, ConstrainedFloat, ConstrainedInt ConstrainedNumber = Union[ConstrainedDecimal, ConstrainedFloat, ConstrainedInt] AnyOrderedDict = OrderedDict[Any, Any] @@ -595,7 +594,7 @@ NamedTupleT = TypeVar('NamedTupleT', bound=NamedTuple) def make_namedtuple_validator( namedtuple_cls: Type[NamedTupleT], config: Type['BaseConfig'] ) -> Callable[[Tuple[Any, ...]], NamedTupleT]: - from pydantic.v1.annotated_types import create_model_from_namedtuple + from .annotated_types import create_model_from_namedtuple NamedTupleModel = create_model_from_namedtuple( namedtuple_cls, @@ -620,7 +619,7 @@ def make_namedtuple_validator( def make_typeddict_validator( typeddict_cls: Type['TypedDict'], config: Type['BaseConfig'] # type: ignore[valid-type] ) -> Callable[[Any], Dict[str, Any]]: - from pydantic.v1.annotated_types import create_model_from_typeddict + from .annotated_types import create_model_from_typeddict TypedDictModel = create_model_from_typeddict( typeddict_cls, @@ -699,7 +698,7 @@ _VALIDATORS: List[Tuple[Type[Any], List[Any]]] = [ def find_validators( # noqa: C901 (ignore complexity) type_: Type[Any], config: Type['BaseConfig'] ) -> 
Generator[AnyCallable, None, None]: - from pydantic.v1.dataclasses import is_builtin_dataclass, make_dataclass_validator + from .dataclasses import is_builtin_dataclass, make_dataclass_validator if type_ is Any or type_ is object: return @@ -763,6 +762,4 @@ def find_validators( # noqa: C901 (ignore complexity) if config.arbitrary_types_allowed: yield make_arbitrary_type_validator(type_) else: - if hasattr(type_, '__pydantic_core_schema__'): - warn(f'Mixing V1 and V2 models is not supported. `{type_.__name__}` is a V2 model.', UserWarning) raise RuntimeError(f'no validator found for {type_}, see `arbitrary_types_allowed` in Config') diff --git a/venv/lib/python3.12/site-packages/pydantic/v1/version.py b/venv/lib/python3.12/site-packages/pydantic/v1/version.py index c77cde1..462c497 100644 --- a/venv/lib/python3.12/site-packages/pydantic/v1/version.py +++ b/venv/lib/python3.12/site-packages/pydantic/v1/version.py @@ -1,6 +1,6 @@ __all__ = 'compiled', 'VERSION', 'version_info' -VERSION = '1.10.21' +VERSION = '1.10.13' try: import cython # type: ignore diff --git a/venv/lib/python3.12/site-packages/pydantic/validate_call.py b/venv/lib/python3.12/site-packages/pydantic/validate_call.py new file mode 100644 index 0000000..8058486 --- /dev/null +++ b/venv/lib/python3.12/site-packages/pydantic/validate_call.py @@ -0,0 +1,58 @@ +"""Decorator for validating function calls.""" +from __future__ import annotations as _annotations + +from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload + +from ._internal import _validate_call + +__all__ = ('validate_call',) + +if TYPE_CHECKING: + from .config import ConfigDict + + AnyCallableT = TypeVar('AnyCallableT', bound=Callable[..., Any]) + + +@overload +def validate_call( + *, config: ConfigDict | None = None, validate_return: bool = False +) -> Callable[[AnyCallableT], AnyCallableT]: + ... + + +@overload +def validate_call(__func: AnyCallableT) -> AnyCallableT: + ... + + +def validate_call( + __func: AnyCallableT | None = None, + *, + config: ConfigDict | None = None, + validate_return: bool = False, +) -> AnyCallableT | Callable[[AnyCallableT], AnyCallableT]: + """Usage docs: https://docs.pydantic.dev/2.4/concepts/validation_decorator/ + + Returns a decorated wrapper around the function that validates the arguments and, optionally, the return value. + + Usage may be either as a plain decorator `@validate_call` or with arguments `@validate_call(...)`. + + Args: + __func: The function to be decorated. + config: The configuration dictionary. + validate_return: Whether to validate the return value. + + Returns: + The decorated function. 
+ """ + + def validate(function: AnyCallableT) -> AnyCallableT: + if isinstance(function, (classmethod, staticmethod)): + name = type(function).__name__ + raise TypeError(f'The `@{name}` decorator should be applied after `@validate_call` (put `@{name}` on top)') + return _validate_call.ValidateCallWrapper(function, config, validate_return) # type: ignore + + if __func: + return validate(__func) + else: + return validate diff --git a/venv/lib/python3.12/site-packages/pydantic/validate_call_decorator.py b/venv/lib/python3.12/site-packages/pydantic/validate_call_decorator.py deleted file mode 100644 index fe4d9c9..0000000 --- a/venv/lib/python3.12/site-packages/pydantic/validate_call_decorator.py +++ /dev/null @@ -1,116 +0,0 @@ -"""Decorator for validating function calls.""" - -from __future__ import annotations as _annotations - -import inspect -from functools import partial -from types import BuiltinFunctionType -from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast, overload - -from ._internal import _generate_schema, _typing_extra, _validate_call -from .errors import PydanticUserError - -__all__ = ('validate_call',) - -if TYPE_CHECKING: - from .config import ConfigDict - - AnyCallableT = TypeVar('AnyCallableT', bound=Callable[..., Any]) - - -_INVALID_TYPE_ERROR_CODE = 'validate-call-type' - - -def _check_function_type(function: object) -> None: - """Check if the input function is a supported type for `validate_call`.""" - if isinstance(function, _generate_schema.VALIDATE_CALL_SUPPORTED_TYPES): - try: - inspect.signature(cast(_generate_schema.ValidateCallSupportedTypes, function)) - except ValueError: - raise PydanticUserError( - f"Input function `{function}` doesn't have a valid signature", code=_INVALID_TYPE_ERROR_CODE - ) - - if isinstance(function, partial): - try: - assert not isinstance(partial.func, partial), 'Partial of partial' - _check_function_type(function.func) - except PydanticUserError as e: - raise PydanticUserError( - f'Partial of `{function.func}` is invalid because the type of `{function.func}` is not supported by `validate_call`', - code=_INVALID_TYPE_ERROR_CODE, - ) from e - - return - - if isinstance(function, BuiltinFunctionType): - raise PydanticUserError(f'Input built-in function `{function}` is not supported', code=_INVALID_TYPE_ERROR_CODE) - if isinstance(function, (classmethod, staticmethod, property)): - name = type(function).__name__ - raise PydanticUserError( - f'The `@{name}` decorator should be applied after `@validate_call` (put `@{name}` on top)', - code=_INVALID_TYPE_ERROR_CODE, - ) - - if inspect.isclass(function): - raise PydanticUserError( - f'Unable to validate {function}: `validate_call` should be applied to functions, not classes (put `@validate_call` on top of `__init__` or `__new__` instead)', - code=_INVALID_TYPE_ERROR_CODE, - ) - if callable(function): - raise PydanticUserError( - f'Unable to validate {function}: `validate_call` should be applied to functions, not instances or other callables. Use `validate_call` explicitly on `__call__` instead.', - code=_INVALID_TYPE_ERROR_CODE, - ) - - raise PydanticUserError( - f'Unable to validate {function}: `validate_call` should be applied to one of the following: function, method, partial, or lambda', - code=_INVALID_TYPE_ERROR_CODE, - ) - - -@overload -def validate_call( - *, config: ConfigDict | None = None, validate_return: bool = False -) -> Callable[[AnyCallableT], AnyCallableT]: ... - - -@overload -def validate_call(func: AnyCallableT, /) -> AnyCallableT: ... 
- - -def validate_call( - func: AnyCallableT | None = None, - /, - *, - config: ConfigDict | None = None, - validate_return: bool = False, -) -> AnyCallableT | Callable[[AnyCallableT], AnyCallableT]: - """!!! abstract "Usage Documentation" - [Validation Decorator](../concepts/validation_decorator.md) - - Returns a decorated wrapper around the function that validates the arguments and, optionally, the return value. - - Usage may be either as a plain decorator `@validate_call` or with arguments `@validate_call(...)`. - - Args: - func: The function to be decorated. - config: The configuration dictionary. - validate_return: Whether to validate the return value. - - Returns: - The decorated function. - """ - parent_namespace = _typing_extra.parent_frame_namespace() - - def validate(function: AnyCallableT) -> AnyCallableT: - _check_function_type(function) - validate_call_wrapper = _validate_call.ValidateCallWrapper( - cast(_generate_schema.ValidateCallSupportedTypes, function), config, validate_return, parent_namespace - ) - return _validate_call.update_wrapper_attributes(function, validate_call_wrapper.__call__) # type: ignore - - if func is not None: - return validate(func) - else: - return validate diff --git a/venv/lib/python3.12/site-packages/pydantic/validators.py b/venv/lib/python3.12/site-packages/pydantic/validators.py index 7921b04..55b0339 100644 --- a/venv/lib/python3.12/site-packages/pydantic/validators.py +++ b/venv/lib/python3.12/site-packages/pydantic/validators.py @@ -1,5 +1,4 @@ """The `validators` module is a backport module from V1.""" - from ._migration import getattr_migration __getattr__ = getattr_migration(__name__) diff --git a/venv/lib/python3.12/site-packages/pydantic/version.py b/venv/lib/python3.12/site-packages/pydantic/version.py index 28d77e9..d65780f 100644 --- a/venv/lib/python3.12/site-packages/pydantic/version.py +++ b/venv/lib/python3.12/site-packages/pydantic/version.py @@ -1,12 +1,9 @@ """The `version` module holds the version information for Pydantic.""" - -from __future__ import annotations as _annotations - -from pydantic_core import __version__ as __pydantic_core_version__ +from typing import Tuple __all__ = 'VERSION', 'version_info' -VERSION = '2.11.9' +VERSION = '2.4.2' """The version of Pydantic.""" @@ -20,14 +17,16 @@ def version_short() -> str: def version_info() -> str: """Return complete version information for Pydantic and its dependencies.""" - import importlib.metadata import platform import sys from pathlib import Path import pydantic_core._pydantic_core as pdc - from ._internal import _git as git + if sys.version_info >= (3, 8): + import importlib.metadata as importlib_metadata + else: + import importlib_metadata # get data about packages that are closely related to pydantic, use pydantic or often conflict with pydantic package_names = { @@ -41,44 +40,36 @@ def version_info() -> str: } related_packages = [] - for dist in importlib.metadata.distributions(): + for dist in importlib_metadata.distributions(): name = dist.metadata['Name'] if name in package_names: related_packages.append(f'{name}-{dist.version}') - pydantic_dir = Path(__file__).parents[1].resolve() - most_recent_commit = ( - git.git_revision(pydantic_dir) if git.is_git_repo(pydantic_dir) and git.have_git() else 'unknown' - ) - info = { 'pydantic version': VERSION, 'pydantic-core version': pdc.__version__, 'pydantic-core build': getattr(pdc, 'build_info', None) or pdc.build_profile, + 'install path': Path(__file__).resolve().parent, 'python version': sys.version, 'platform': 
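(Editor's note) Both versions of the decorator above (the added `validate_call.py` and the removed `validate_call_decorator.py`) share the same basic usage. A minimal sketch with an illustrative function name:

```python
from pydantic import ValidationError, validate_call

@validate_call
def repeat(text: str, times: int) -> str:
    return text * times

# Arguments are validated and coerced against the annotations.
print(repeat('ab', 3))
#> ababab

try:
    repeat('ab', 'not-a-number')
except ValidationError as exc:
    print(exc.errors()[0]['type'])
    #> int_parsing
```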
platform.platform(), 'related packages': ' '.join(related_packages), - 'commit': most_recent_commit, } return '\n'.join('{:>30} {}'.format(k + ':', str(v).replace('\n', ' ')) for k, v in info.items()) -def check_pydantic_core_version() -> bool: - """Check that the installed `pydantic-core` dependency is compatible.""" - # Keep this in sync with the version constraint in the `pyproject.toml` dependencies: - return __pydantic_core_version__ == '2.33.2' +def parse_mypy_version(version: str) -> Tuple[int, ...]: + """Parse mypy string version to tuple of ints. + This function is included here rather than the mypy plugin file because the mypy plugin file cannot be imported + outside a mypy run. -def parse_mypy_version(version: str) -> tuple[int, int, int]: - """Parse `mypy` string version to a 3-tuple of ints. - - It parses normal version like `1.11.0` and extra info followed by a `+` sign - like `1.11.0+dev.d6d9d8cd4f27c52edac1f537e236ec48a01e54cb.dirty`. + It parses normal version like `0.930` and dev version + like `0.940+dev.04cac4b5d911c4f9529e6ce86a27b44f28846f5d.dirty`. Args: version: The mypy version string. Returns: - A triple of ints, e.g. `(1, 11, 0)`. + A tuple of ints. e.g. (0, 930). """ - return tuple(map(int, version.partition('+')[0].split('.'))) # pyright: ignore[reportReturnType] + return tuple(map(int, version.partition('+')[0].split('.'))) diff --git a/venv/lib/python3.12/site-packages/pydantic/warnings.py b/venv/lib/python3.12/site-packages/pydantic/warnings.py index 6e874dd..4b7b760 100644 --- a/venv/lib/python3.12/site-packages/pydantic/warnings.py +++ b/venv/lib/python3.12/site-packages/pydantic/warnings.py @@ -1,18 +1,9 @@ """Pydantic-specific warnings.""" - from __future__ import annotations as _annotations from .version import version_short -__all__ = ( - 'PydanticDeprecatedSince20', - 'PydanticDeprecatedSince26', - 'PydanticDeprecatedSince29', - 'PydanticDeprecatedSince210', - 'PydanticDeprecatedSince211', - 'PydanticDeprecationWarning', - 'PydanticExperimentalWarning', -) +__all__ = 'PydanticDeprecatedSince20', 'PydanticDeprecationWarning' class PydanticDeprecationWarning(DeprecationWarning): @@ -54,43 +45,3 @@ class PydanticDeprecatedSince20(PydanticDeprecationWarning): def __init__(self, message: str, *args: object) -> None: super().__init__(message, *args, since=(2, 0), expected_removal=(3, 0)) - - -class PydanticDeprecatedSince26(PydanticDeprecationWarning): - """A specific `PydanticDeprecationWarning` subclass defining functionality deprecated since Pydantic 2.6.""" - - def __init__(self, message: str, *args: object) -> None: - super().__init__(message, *args, since=(2, 6), expected_removal=(3, 0)) - - -class PydanticDeprecatedSince29(PydanticDeprecationWarning): - """A specific `PydanticDeprecationWarning` subclass defining functionality deprecated since Pydantic 2.9.""" - - def __init__(self, message: str, *args: object) -> None: - super().__init__(message, *args, since=(2, 9), expected_removal=(3, 0)) - - -class PydanticDeprecatedSince210(PydanticDeprecationWarning): - """A specific `PydanticDeprecationWarning` subclass defining functionality deprecated since Pydantic 2.10.""" - - def __init__(self, message: str, *args: object) -> None: - super().__init__(message, *args, since=(2, 10), expected_removal=(3, 0)) - - -class PydanticDeprecatedSince211(PydanticDeprecationWarning): - """A specific `PydanticDeprecationWarning` subclass defining functionality deprecated since Pydantic 2.11.""" - - def __init__(self, message: str, *args: object) -> None: - 
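The `version.py` hunk pins `VERSION` back to `2.4.2` and drops the newer git/`pydantic-core` compatibility helpers. Code that relies on post-2.4 behaviour can guard on the reported version at runtime; a hedged sketch (the `(2, 5)` threshold is only an example):

```python
import pydantic

major, minor = (int(part) for part in pydantic.VERSION.split(".")[:2])

if (major, minor) < (2, 5):
    # Newer helpers such as pydantic_core.from_json are unavailable here,
    # so callers should stick to the older keyword arguments and APIs.
    print(f"running against pydantic {pydantic.VERSION}")

# Full environment report, produced by the version_info() shown in the hunk above.
print(pydantic.version.version_info())
```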
super().__init__(message, *args, since=(2, 11), expected_removal=(3, 0)) - - -class GenericBeforeBaseModelWarning(Warning): - pass - - -class PydanticExperimentalWarning(Warning): - """A Pydantic specific experimental functionality warning. - - This warning is raised when using experimental functionality in Pydantic. - It is raised to warn users that the functionality may change or be removed in future versions of Pydantic. - """ diff --git a/venv/lib/python3.12/site-packages/pydantic_settings-2.11.0.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/pydantic_core-2.10.1.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/pydantic_settings-2.11.0.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/pydantic_core-2.10.1.dist-info/INSTALLER diff --git a/venv/lib/python3.12/site-packages/pydantic_core-2.33.2.dist-info/METADATA b/venv/lib/python3.12/site-packages/pydantic_core-2.10.1.dist-info/METADATA similarity index 91% rename from venv/lib/python3.12/site-packages/pydantic_core-2.33.2.dist-info/METADATA rename to venv/lib/python3.12/site-packages/pydantic_core-2.10.1.dist-info/METADATA index 9b98dfc..b7ca267 100644 --- a/venv/lib/python3.12/site-packages/pydantic_core-2.33.2.dist-info/METADATA +++ b/venv/lib/python3.12/site-packages/pydantic_core-2.10.1.dist-info/METADATA @@ -1,15 +1,16 @@ -Metadata-Version: 2.4 +Metadata-Version: 2.1 Name: pydantic_core -Version: 2.33.2 +Version: 2.10.1 Classifier: Development Status :: 3 - Alpha Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 3 Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Classifier: Programming Language :: Python :: 3.11 Classifier: Programming Language :: Python :: 3.12 -Classifier: Programming Language :: Python :: 3.13 Classifier: Programming Language :: Rust Classifier: Framework :: Pydantic Classifier: Intended Audience :: Developers @@ -19,13 +20,12 @@ Classifier: Operating System :: POSIX :: Linux Classifier: Operating System :: Microsoft :: Windows Classifier: Operating System :: MacOS Classifier: Typing :: Typed -Requires-Dist: typing-extensions>=4.6.0,!=4.7.0 +Requires-Dist: typing-extensions >=4.6.0, !=4.7.0 License-File: LICENSE -Summary: Core functionality for Pydantic validation and serialization Home-Page: https://github.com/pydantic/pydantic-core -Author-email: Samuel Colvin , Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>, David Montague , David Hewitt , Sydney Runkle , Victorien Plot +Author-email: Samuel Colvin License: MIT -Requires-Python: >=3.9 +Requires-Python: >=3.7 Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM Project-URL: Homepage, https://github.com/pydantic/pydantic-core Project-URL: Funding, https://github.com/sponsors/samuelcolvin @@ -104,7 +104,7 @@ except ValidationError as e: You'll need rust stable [installed](https://rustup.rs/), or rust nightly if you want to generate accurate coverage. 
-With rust and python 3.9+ installed, compiling pydantic-core should be possible with roughly the following: +With rust and python 3.7+ installed, compiling pydantic-core should be possible with roughly the following: ```bash # clone this repo or your fork diff --git a/venv/lib/python3.12/site-packages/pydantic_core-2.10.1.dist-info/RECORD b/venv/lib/python3.12/site-packages/pydantic_core-2.10.1.dist-info/RECORD new file mode 100644 index 0000000..546cf34 --- /dev/null +++ b/venv/lib/python3.12/site-packages/pydantic_core-2.10.1.dist-info/RECORD @@ -0,0 +1,12 @@ +pydantic_core-2.10.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +pydantic_core-2.10.1.dist-info/METADATA,sha256=WGUIBHZlOo9atU3rxpgTG6OSeI8Bb1bd1W_euCfgKLw,6514 +pydantic_core-2.10.1.dist-info/RECORD,, +pydantic_core-2.10.1.dist-info/WHEEL,sha256=Q-k7sgnFAZDxP-9P2jtPmCpI5l3MZojpu9e7PGDJ4NQ,129 +pydantic_core-2.10.1.dist-info/license_files/LICENSE,sha256=Kv3TDVS01itvSIprzBVG6E7FBh8T9CCcA9ASNIeDeVo,1080 +pydantic_core/__init__.py,sha256=WQk9nOr2kAjsCk8pd2NBCN67BD-QdOqrJP8ML-CU0jk,4165 +pydantic_core/__pycache__/__init__.cpython-312.pyc,, +pydantic_core/__pycache__/core_schema.cpython-312.pyc,, +pydantic_core/_pydantic_core.cpython-312-x86_64-linux-gnu.so,sha256=a4S5LWAjKwdg1d0oQteJ2aqdZQJkj5GMtMTxJE41Ue4,5019712 +pydantic_core/_pydantic_core.pyi,sha256=Wm-59ewesZFPidlOEVjvT48hP0DAYW-SxVtUqeCCFHM,31733 +pydantic_core/core_schema.py,sha256=2zEA2bnYNOmM_3x4yqzoMMkJUl86803JqlpABuvjVrk,132468 +pydantic_core/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 diff --git a/venv/lib/python3.12/site-packages/pydantic_core-2.33.2.dist-info/WHEEL b/venv/lib/python3.12/site-packages/pydantic_core-2.10.1.dist-info/WHEEL similarity index 79% rename from venv/lib/python3.12/site-packages/pydantic_core-2.33.2.dist-info/WHEEL rename to venv/lib/python3.12/site-packages/pydantic_core-2.10.1.dist-info/WHEEL index 379e62d..b707a14 100644 --- a/venv/lib/python3.12/site-packages/pydantic_core-2.33.2.dist-info/WHEEL +++ b/venv/lib/python3.12/site-packages/pydantic_core-2.10.1.dist-info/WHEEL @@ -1,4 +1,4 @@ Wheel-Version: 1.0 -Generator: maturin (1.8.3) +Generator: maturin (1.2.3) Root-Is-Purelib: false Tag: cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64 diff --git a/venv/lib/python3.12/site-packages/pydantic_core-2.33.2.dist-info/licenses/LICENSE b/venv/lib/python3.12/site-packages/pydantic_core-2.10.1.dist-info/license_files/LICENSE similarity index 100% rename from venv/lib/python3.12/site-packages/pydantic_core-2.33.2.dist-info/licenses/LICENSE rename to venv/lib/python3.12/site-packages/pydantic_core-2.10.1.dist-info/license_files/LICENSE diff --git a/venv/lib/python3.12/site-packages/pydantic_core-2.33.2.dist-info/RECORD b/venv/lib/python3.12/site-packages/pydantic_core-2.33.2.dist-info/RECORD deleted file mode 100644 index 16c7172..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_core-2.33.2.dist-info/RECORD +++ /dev/null @@ -1,12 +0,0 @@ -pydantic_core-2.33.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -pydantic_core-2.33.2.dist-info/METADATA,sha256=78lBoOZz4Kzfzz_yMI_qFHMFs2SE3VgnWpeGPtjycKs,6757 -pydantic_core-2.33.2.dist-info/RECORD,, -pydantic_core-2.33.2.dist-info/WHEEL,sha256=S1kMhEh6EykiwczohhYmFNC42qiDle6DE4gvf4LHuck,129 -pydantic_core-2.33.2.dist-info/licenses/LICENSE,sha256=Kv3TDVS01itvSIprzBVG6E7FBh8T9CCcA9ASNIeDeVo,1080 -pydantic_core/__init__.py,sha256=TzOWuJMgpXaZcPiS2Yjd8OUqjPbKOupdzXp3dZjWCGc,4403 
-pydantic_core/__pycache__/__init__.cpython-312.pyc,, -pydantic_core/__pycache__/core_schema.cpython-312.pyc,, -pydantic_core/_pydantic_core.cpython-312-x86_64-linux-gnu.so,sha256=JxGjRu_zhJCQmK5aHR24qPdB8jAMWcH8ZR85w3C40LQ,4776920 -pydantic_core/_pydantic_core.pyi,sha256=xIR9CkJaClUD5HcHtGEPElBpWiD33PaxLKnss8rsuSM,43359 -pydantic_core/core_schema.py,sha256=98qpsz-jklOqmsA9h-zWg4K4jNkkk6N_nDBLW3Cjp-w,149655 -pydantic_core/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 diff --git a/venv/lib/python3.12/site-packages/pydantic_core/__init__.py b/venv/lib/python3.12/site-packages/pydantic_core/__init__.py index 98b64b8..a46a77b 100644 --- a/venv/lib/python3.12/site-packages/pydantic_core/__init__.py +++ b/venv/lib/python3.12/site-packages/pydantic_core/__init__.py @@ -22,7 +22,6 @@ from ._pydantic_core import ( Url, ValidationError, __version__, - from_json, to_json, to_jsonable_python, validate_core_schema, @@ -34,7 +33,7 @@ if _sys.version_info < (3, 11): else: from typing import NotRequired as _NotRequired -if _sys.version_info < (3, 12): +if _sys.version_info < (3, 9): from typing_extensions import TypedDict as _TypedDict else: from typing import TypedDict as _TypedDict @@ -64,7 +63,6 @@ __all__ = [ 'PydanticSerializationUnexpectedValue', 'TzInfo', 'to_json', - 'from_json', 'to_jsonable_python', 'validate_core_schema', ] @@ -89,16 +87,11 @@ class ErrorDetails(_TypedDict): Values which are required to render the error message, and could hence be useful in rendering custom error messages. Also useful for passing custom error data forward. """ - url: _NotRequired[str] - """ - The documentation URL giving information about the error. No URL is available if - a [`PydanticCustomError`][pydantic_core.PydanticCustomError] is used. - """ class InitErrorDetails(_TypedDict): type: str | PydanticCustomError - """The type of error that occurred, this should be a "slug" identifier that changes rarely or never.""" + """The type of error that occurred, this should a "slug" identifier that changes rarely or never.""" loc: _NotRequired[tuple[int | str, ...]] """Tuple of strings and ints identifying where in the schema the error occurred.""" input: _Any diff --git a/venv/lib/python3.12/site-packages/pydantic_core/_pydantic_core.cpython-312-x86_64-linux-gnu.so b/venv/lib/python3.12/site-packages/pydantic_core/_pydantic_core.cpython-312-x86_64-linux-gnu.so index e1797ec..dc56825 100755 Binary files a/venv/lib/python3.12/site-packages/pydantic_core/_pydantic_core.cpython-312-x86_64-linux-gnu.so and b/venv/lib/python3.12/site-packages/pydantic_core/_pydantic_core.cpython-312-x86_64-linux-gnu.so differ diff --git a/venv/lib/python3.12/site-packages/pydantic_core/_pydantic_core.pyi b/venv/lib/python3.12/site-packages/pydantic_core/_pydantic_core.pyi index 17098cc..8ed3092 100644 --- a/venv/lib/python3.12/site-packages/pydantic_core/_pydantic_core.pyi +++ b/venv/lib/python3.12/site-packages/pydantic_core/_pydantic_core.pyi @@ -1,13 +1,24 @@ -import datetime -from collections.abc import Mapping -from typing import Any, Callable, Generic, Literal, TypeVar, final +from __future__ import annotations -from _typeshed import SupportsAllComparisons -from typing_extensions import LiteralString, Self, TypeAlias +import datetime +import sys +from typing import Any, Callable, Generic, Optional, Type, TypeVar from pydantic_core import ErrorDetails, ErrorTypeInfo, InitErrorDetails, MultiHostHost from pydantic_core.core_schema import CoreConfig, CoreSchema, ErrorType +if sys.version_info < (3, 8): + from 
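The `pydantic_core/__init__.py` hunk removes `from_json` from the exports, since pydantic-core 2.10.1 does not ship it. A minimal fallback sketch for callers that used it (the wrapper below is illustrative and ignores the native-only options):

```python
import json

try:
    from pydantic_core import from_json  # present only in newer pydantic-core releases
except ImportError:
    def from_json(data, *, allow_inf_nan: bool = True, **_ignored):
        # Stdlib fallback: covers plain JSON, but not the partial-parsing or
        # string-caching options of the native helper.
        return json.loads(data)

print(from_json('{"a": 1, "b": [2, 3]}'))  # {'a': 1, 'b': [2, 3]}
```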
typing_extensions import final +else: + from typing import final + +if sys.version_info < (3, 11): + from typing_extensions import Literal, LiteralString, Self, TypeAlias +else: + from typing import Literal, LiteralString, Self, TypeAlias + +from _typeshed import SupportsAllComparisons + __all__ = [ '__version__', 'build_profile', @@ -30,7 +41,6 @@ __all__ = [ 'PydanticUndefinedType', 'Some', 'to_json', - 'from_json', 'to_jsonable_python', 'list_all_errors', 'TzInfo', @@ -62,7 +72,7 @@ class Some(Generic[_T]): Returns the value wrapped by `Some`. """ @classmethod - def __class_getitem__(cls, item: Any, /) -> type[Self]: ... + def __class_getitem__(cls, __item: Any) -> Type[Self]: ... @final class SchemaValidator: @@ -71,18 +81,14 @@ class SchemaValidator: `CombinedValidator` which may in turn own more `CombinedValidator`s which make up the full schema validator. """ - # note: pyo3 currently supports __new__, but not __init__, though we include __init__ stubs - # and docstrings here (and in the following classes) for documentation purposes - - def __init__(self, schema: CoreSchema, config: CoreConfig | None = None) -> None: - """Initializes the `SchemaValidator`. + def __new__(cls, schema: CoreSchema, config: CoreConfig | None = None) -> Self: + """ + Create a new SchemaValidator. Arguments: - schema: The `CoreSchema` to use for validation. + schema: The [`CoreSchema`][pydantic_core.core_schema.CoreSchema] to use for validation. config: Optionally a [`CoreConfig`][pydantic_core.core_schema.CoreConfig] to configure validation. """ - - def __new__(cls, schema: CoreSchema, config: CoreConfig | None = None) -> Self: ... @property def title(self) -> str: """ @@ -94,11 +100,8 @@ class SchemaValidator: *, strict: bool | None = None, from_attributes: bool | None = None, - context: Any | None = None, + context: 'dict[str, Any] | None' = None, self_instance: Any | None = None, - allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False, - by_alias: bool | None = None, - by_name: bool | None = None, ) -> Any: """ Validate a Python object against the schema and return the validated object. @@ -113,11 +116,6 @@ class SchemaValidator: [`info.context`][pydantic_core.core_schema.ValidationInfo.context]. self_instance: An instance of a model set attributes on from validation, this is used when running validation from the `__init__` method of a model. - allow_partial: Whether to allow partial validation; if `True` errors in the last element of sequences - and mappings are ignored. - `'trailing-strings'` means any final unfinished JSON string is included in the result. - by_alias: Whether to use the field's alias when validating against the provided input data. - by_name: Whether to use the field's name when validating against the provided input data. Raises: ValidationError: If validation fails. @@ -132,10 +130,8 @@ class SchemaValidator: *, strict: bool | None = None, from_attributes: bool | None = None, - context: Any | None = None, + context: 'dict[str, Any] | None' = None, self_instance: Any | None = None, - by_alias: bool | None = None, - by_name: bool | None = None, ) -> bool: """ Similar to [`validate_python()`][pydantic_core.SchemaValidator.validate_python] but returns a boolean. 
@@ -151,11 +147,8 @@ class SchemaValidator: input: str | bytes | bytearray, *, strict: bool | None = None, - context: Any | None = None, + context: 'dict[str, Any] | None' = None, self_instance: Any | None = None, - allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False, - by_alias: bool | None = None, - by_name: bool | None = None, ) -> Any: """ Validate JSON data directly against the schema and return the validated Python object. @@ -173,11 +166,6 @@ class SchemaValidator: context: The context to use for validation, this is passed to functional validators as [`info.context`][pydantic_core.core_schema.ValidationInfo.context]. self_instance: An instance of a model set attributes on from validation. - allow_partial: Whether to allow partial validation; if `True` incomplete JSON will be parsed successfully - and errors in the last element of sequences and mappings are ignored. - `'trailing-strings'` means any final unfinished JSON string is included in the result. - by_alias: Whether to use the field's alias when validating against the provided input data. - by_name: Whether to use the field's name when validating against the provided input data. Raises: ValidationError: If validation fails or if the JSON data is invalid. @@ -187,14 +175,7 @@ class SchemaValidator: The validated Python object. """ def validate_strings( - self, - input: _StringInput, - *, - strict: bool | None = None, - context: Any | None = None, - allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False, - by_alias: bool | None = None, - by_name: bool | None = None, + self, input: _StringInput, *, strict: bool | None = None, context: 'dict[str, Any] | None' = None ) -> Any: """ Validate a string against the schema and return the validated Python object. @@ -208,11 +189,6 @@ class SchemaValidator: If `None`, the value of [`CoreConfig.strict`][pydantic_core.core_schema.CoreConfig] is used. context: The context to use for validation, this is passed to functional validators as [`info.context`][pydantic_core.core_schema.ValidationInfo.context]. - allow_partial: Whether to allow partial validation; if `True` errors in the last element of sequences - and mappings are ignored. - `'trailing-strings'` means any final unfinished JSON string is included in the result. - by_alias: Whether to use the field's alias when validating against the provided input data. - by_name: Whether to use the field's name when validating against the provided input data. Raises: ValidationError: If validation fails or if the JSON data is invalid. @@ -229,9 +205,7 @@ class SchemaValidator: *, strict: bool | None = None, from_attributes: bool | None = None, - context: Any | None = None, - by_alias: bool | None = None, - by_name: bool | None = None, + context: 'dict[str, Any] | None' = None, ) -> dict[str, Any] | tuple[dict[str, Any], dict[str, Any] | None, set[str]]: """ Validate an assignment to a field on a model. @@ -246,8 +220,6 @@ class SchemaValidator: If `None`, the value of [`CoreConfig.from_attributes`][pydantic_core.core_schema.CoreConfig] is used. context: The context to use for validation, this is passed to functional validators as [`info.context`][pydantic_core.core_schema.ValidationInfo.context]. - by_alias: Whether to use the field's alias when validating against the provided input data. - by_name: Whether to use the field's name when validating against the provided input data. Raises: ValidationError: If validation fails. 
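With `allow_partial`, `by_alias` and `by_name` gone from the validator methods above, only the keyword arguments present in both stub versions are portable. A small sketch using the shared surface (the list-of-ints schema is an arbitrary example):

```python
from pydantic_core import SchemaValidator, core_schema

# An arbitrary schema: a list of integers, validated in lax mode.
validator = SchemaValidator(core_schema.list_schema(core_schema.int_schema()))

print(validator.validate_python(["1", 2, 3.0]))    # [1, 2, 3]
print(validator.validate_json("[1, 2, 3]"))        # [1, 2, 3]
print(validator.isinstance_python([1, "oops"]))    # False
```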
@@ -274,9 +246,7 @@ class SchemaValidator: `None` if the schema has no default value, otherwise a [`Some`][pydantic_core.Some] containing the default. """ -# In reality, `bool` should be replaced by `Literal[True]` but mypy fails to correctly apply bidirectional type inference -# (e.g. when using `{'a': {'b': True}}`). -_IncEx: TypeAlias = set[int] | set[str] | Mapping[int, _IncEx | bool] | Mapping[str, _IncEx | bool] +_IncEx: TypeAlias = set[int] | set[str] | dict[int, _IncEx] | dict[str, _IncEx] | None @final class SchemaSerializer: @@ -285,31 +255,28 @@ class SchemaSerializer: `CombinedSerializer` which may in turn own more `CombinedSerializer`s which make up the full schema serializer. """ - def __init__(self, schema: CoreSchema, config: CoreConfig | None = None) -> None: - """Initializes the `SchemaSerializer`. + def __new__(cls, schema: CoreSchema, config: CoreConfig | None = None) -> Self: + """ + Create a new SchemaSerializer. Arguments: - schema: The `CoreSchema` to use for serialization. + schema: The [`CoreSchema`][pydantic_core.core_schema.CoreSchema] to use for serialization. config: Optionally a [`CoreConfig`][pydantic_core.core_schema.CoreConfig] to to configure serialization. """ - - def __new__(cls, schema: CoreSchema, config: CoreConfig | None = None) -> Self: ... def to_python( self, value: Any, *, mode: str | None = None, - include: _IncEx | None = None, - exclude: _IncEx | None = None, - by_alias: bool | None = None, + include: _IncEx = None, + exclude: _IncEx = None, + by_alias: bool = True, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, - warnings: bool | Literal['none', 'warn', 'error'] = True, + warnings: bool = True, fallback: Callable[[Any], Any] | None = None, - serialize_as_any: bool = False, - context: Any | None = None, ) -> Any: """ Serialize/marshal a Python object to a Python object including transforming and filtering data. @@ -326,13 +293,9 @@ class SchemaSerializer: exclude_defaults: Whether to exclude fields that are equal to their default value. exclude_none: Whether to exclude fields that have a value of `None`. round_trip: Whether to enable serialization and validation round-trip support. - warnings: How to handle invalid fields. False/"none" ignores them, True/"warn" logs errors, - "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. + warnings: Whether to log warnings when invalid fields are encountered. fallback: A function to call when an unknown value is encountered, if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. - serialize_as_any: Whether to serialize fields with duck-typing serialization behavior. - context: The context to use for serialization, this is passed to functional serializers as - [`info.context`][pydantic_core.core_schema.SerializationInfo.context]. Raises: PydanticSerializationError: If serialization fails and no `fallback` function is provided. 
@@ -345,17 +308,15 @@ class SchemaSerializer: value: Any, *, indent: int | None = None, - include: _IncEx | None = None, - exclude: _IncEx | None = None, - by_alias: bool | None = None, + include: _IncEx = None, + exclude: _IncEx = None, + by_alias: bool = True, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, - warnings: bool | Literal['none', 'warn', 'error'] = True, + warnings: bool = True, fallback: Callable[[Any], Any] | None = None, - serialize_as_any: bool = False, - context: Any | None = None, ) -> bytes: """ Serialize a Python object to JSON including transforming and filtering data. @@ -371,13 +332,9 @@ class SchemaSerializer: exclude_defaults: Whether to exclude fields that are equal to their default value. exclude_none: Whether to exclude fields that have a value of `None`. round_trip: Whether to enable serialization and validation round-trip support. - warnings: How to handle invalid fields. False/"none" ignores them, True/"warn" logs errors, - "error" raises a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError]. + warnings: Whether to log warnings when invalid fields are encountered. fallback: A function to call when an unknown value is encountered, if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. - serialize_as_any: Whether to serialize fields with duck-typing serialization behavior. - context: The context to use for serialization, this is passed to functional serializers as - [`info.context`][pydantic_core.core_schema.SerializationInfo.context]. Raises: PydanticSerializationError: If serialization fails and no `fallback` function is provided. @@ -390,21 +347,15 @@ def to_json( value: Any, *, indent: int | None = None, - include: _IncEx | None = None, - exclude: _IncEx | None = None, - # Note: In Pydantic 2.11, the default value of `by_alias` on `SchemaSerializer` was changed from `True` to `None`, - # to be consistent with the Pydantic "dump" methods. However, the default of `True` was kept here for - # backwards compatibility. In Pydantic V3, `by_alias` is expected to default to `True` everywhere: + include: _IncEx = None, + exclude: _IncEx = None, by_alias: bool = True, exclude_none: bool = False, round_trip: bool = False, timedelta_mode: Literal['iso8601', 'float'] = 'iso8601', - bytes_mode: Literal['utf8', 'base64', 'hex'] = 'utf8', - inf_nan_mode: Literal['null', 'constants', 'strings'] = 'constants', + bytes_mode: Literal['utf8', 'base64'] = 'utf8', serialize_unknown: bool = False, fallback: Callable[[Any], Any] | None = None, - serialize_as_any: bool = False, - context: Any | None = None, ) -> bytes: """ Serialize a Python object to JSON including transforming and filtering data. @@ -420,15 +371,11 @@ def to_json( exclude_none: Whether to exclude fields that have a value of `None`. round_trip: Whether to enable serialization and validation round-trip support. timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`. - bytes_mode: How to serialize `bytes` objects, either `'utf8'`, `'base64'`, or `'hex'`. - inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'`, `'constants'`, or `'strings'`. + bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`. serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails `""` will be used. 
fallback: A function to call when an unknown value is encountered, if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. - serialize_as_any: Whether to serialize fields with duck-typing serialization behavior. - context: The context to use for serialization, this is passed to functional serializers as - [`info.context`][pydantic_core.core_schema.SerializationInfo.context]. Raises: PydanticSerializationError: If serialization fails and no `fallback` function is provided. @@ -437,53 +384,18 @@ def to_json( JSON bytes. """ -def from_json( - data: str | bytes | bytearray, - *, - allow_inf_nan: bool = True, - cache_strings: bool | Literal['all', 'keys', 'none'] = True, - allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False, -) -> Any: - """ - Deserialize JSON data to a Python object. - - This is effectively a faster version of `json.loads()`, with some extra functionality. - - Arguments: - data: The JSON data to deserialize. - allow_inf_nan: Whether to allow `Infinity`, `-Infinity` and `NaN` values as `json.loads()` does by default. - cache_strings: Whether to cache strings to avoid constructing new Python objects, - this should have a significant impact on performance while increasing memory usage slightly, - `all/True` means cache all strings, `keys` means cache only dict keys, `none/False` means no caching. - allow_partial: Whether to allow partial deserialization, if `True` JSON data is returned if the end of the - input is reached before the full object is deserialized, e.g. `["aa", "bb", "c` would return `['aa', 'bb']`. - `'trailing-strings'` means any final unfinished JSON string is included in the result. - - Raises: - ValueError: If deserialization fails. - - Returns: - The deserialized Python object. - """ - def to_jsonable_python( value: Any, *, - include: _IncEx | None = None, - exclude: _IncEx | None = None, - # Note: In Pydantic 2.11, the default value of `by_alias` on `SchemaSerializer` was changed from `True` to `None`, - # to be consistent with the Pydantic "dump" methods. However, the default of `True` was kept here for - # backwards compatibility. In Pydantic V3, `by_alias` is expected to default to `True` everywhere: + include: _IncEx = None, + exclude: _IncEx = None, by_alias: bool = True, exclude_none: bool = False, round_trip: bool = False, timedelta_mode: Literal['iso8601', 'float'] = 'iso8601', - bytes_mode: Literal['utf8', 'base64', 'hex'] = 'utf8', - inf_nan_mode: Literal['null', 'constants', 'strings'] = 'constants', + bytes_mode: Literal['utf8', 'base64'] = 'utf8', serialize_unknown: bool = False, fallback: Callable[[Any], Any] | None = None, - serialize_as_any: bool = False, - context: Any | None = None, ) -> Any: """ Serialize/marshal a Python object to a JSON-serializable Python object including transforming and filtering data. @@ -499,15 +411,11 @@ def to_jsonable_python( exclude_none: Whether to exclude fields that have a value of `None`. round_trip: Whether to enable serialization and validation round-trip support. timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`. - bytes_mode: How to serialize `bytes` objects, either `'utf8'`, `'base64'`, or `'hex'`. - inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'`, `'constants'`, or `'strings'`. + bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`. 
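As the hunks above show, 2.10.1's `to_json`/`to_jsonable_python` accept neither `inf_nan_mode`, `serialize_as_any`, `context` nor the `'hex'` bytes mode. A sketch restricted to options available in both versions (expected outputs are indicative):

```python
from datetime import datetime, timedelta

from pydantic_core import to_json, to_jsonable_python

payload = {"when": datetime(2024, 1, 1, 12, 0), "delay": timedelta(minutes=90)}

print(to_json(payload, timedelta_mode="float"))
# e.g. b'{"when":"2024-01-01T12:00:00","delay":5400.0}'

print(to_jsonable_python(payload))
# e.g. {'when': '2024-01-01T12:00:00', 'delay': 'PT1H30M'} (iso8601 is the default)
```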
serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails `""` will be used. fallback: A function to call when an unknown value is encountered, if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. - serialize_as_any: Whether to serialize fields with duck-typing serialization behavior. - context: The context to use for serialization, this is passed to functional serializers as - [`info.context`][pydantic_core.core_schema.SerializationInfo.context]. Raises: PydanticSerializationError: If serialization fails and no `fallback` function is provided. @@ -522,43 +430,133 @@ class Url(SupportsAllComparisons): by Mozilla. """ - def __init__(self, url: str) -> None: ... - def __new__(cls, url: str) -> Self: ... + def __new__(cls, url: str) -> Self: + """ + Create a new `Url` instance. + + Args: + url: String representation of a URL. + + Returns: + A new `Url` instance. + + Raises: + ValidationError: If the URL is invalid. + """ @property - def scheme(self) -> str: ... + def scheme(self) -> str: + """ + The scheme part of the URL. + + e.g. `https` in `https://user:pass@host:port/path?query#fragment` + """ @property - def username(self) -> str | None: ... + def username(self) -> str | None: + """ + The username part of the URL, or `None`. + + e.g. `user` in `https://user:pass@host:port/path?query#fragment` + """ @property - def password(self) -> str | None: ... + def password(self) -> str | None: + """ + The password part of the URL, or `None`. + + e.g. `pass` in `https://user:pass@host:port/path?query#fragment` + """ @property - def host(self) -> str | None: ... - def unicode_host(self) -> str | None: ... + def host(self) -> str | None: + """ + The host part of the URL, or `None`. + + If the URL must be punycode encoded, this is the encoded host, e.g if the input URL is `https://£££.com`, + `host` will be `xn--9aaa.com` + """ + def unicode_host(self) -> str | None: + """ + The host part of the URL as a unicode string, or `None`. + + e.g. `host` in `https://user:pass@host:port/path?query#fragment` + + If the URL must be punycode encoded, this is the decoded host, e.g if the input URL is `https://£££.com`, + `unicode_host()` will be `£££.com` + """ @property - def port(self) -> int | None: ... + def port(self) -> int | None: + """ + The port part of the URL, or `None`. + + e.g. `port` in `https://user:pass@host:port/path?query#fragment` + """ @property - def path(self) -> str | None: ... + def path(self) -> str | None: + """ + The path part of the URL, or `None`. + + e.g. `/path` in `https://user:pass@host:port/path?query#fragment` + """ @property - def query(self) -> str | None: ... - def query_params(self) -> list[tuple[str, str]]: ... + def query(self) -> str | None: + """ + The query part of the URL, or `None`. + + e.g. `query` in `https://user:pass@host:port/path?query#fragment` + """ + def query_params(self) -> list[tuple[str, str]]: + """ + The query part of the URL as a list of key-value pairs. + + e.g. `[('foo', 'bar')]` in `https://user:pass@host:port/path?foo=bar#fragment` + """ @property - def fragment(self) -> str | None: ... - def unicode_string(self) -> str: ... + def fragment(self) -> str | None: + """ + The fragment part of the URL, or `None`. + + e.g. `fragment` in `https://user:pass@host:port/path?query#fragment` + """ + def unicode_string(self) -> str: + """ + The URL as a unicode string, unlike `__str__()` this will not punycode encode the host. 
+ + If the URL must be punycode encoded, this is the decoded string, e.g if the input URL is `https://£££.com`, + `unicode_string()` will be `https://£££.com` + """ def __repr__(self) -> str: ... - def __str__(self) -> str: ... + def __str__(self) -> str: + """ + The URL as a string, this will punycode encode the host if required. + """ def __deepcopy__(self, memo: dict) -> str: ... @classmethod def build( cls, *, scheme: str, - username: str | None = None, - password: str | None = None, + username: Optional[str] = None, + password: Optional[str] = None, host: str, - port: int | None = None, - path: str | None = None, - query: str | None = None, - fragment: str | None = None, - ) -> Self: ... + port: Optional[int] = None, + path: Optional[str] = None, + query: Optional[str] = None, + fragment: Optional[str] = None, + ) -> Self: + """ + Build a new `Url` instance from its component parts. + + Args: + scheme: The scheme part of the URL. + username: The username part of the URL, or omit for no username. + password: The password part of the URL, or omit for no password. + host: The host part of the URL. + port: The port part of the URL, or omit for no port. + path: The path part of the URL, or omit for no path. + query: The query part of the URL, or omit for no query. + fragment: The fragment part of the URL, or omit for no fragment. + + Returns: + An instance of URL + """ class MultiHostUrl(SupportsAllComparisons): """ @@ -568,36 +566,116 @@ class MultiHostUrl(SupportsAllComparisons): by Mozilla. """ - def __init__(self, url: str) -> None: ... - def __new__(cls, url: str) -> Self: ... + def __new__(cls, url: str) -> Self: + """ + Create a new `MultiHostUrl` instance. + + Args: + url: String representation of a URL. + + Returns: + A new `MultiHostUrl` instance. + + Raises: + ValidationError: If the URL is invalid. + """ @property - def scheme(self) -> str: ... + def scheme(self) -> str: + """ + The scheme part of the URL. + + e.g. `https` in `https://foo.com,bar.com/path?query#fragment` + """ @property - def path(self) -> str | None: ... + def path(self) -> str | None: + """ + The path part of the URL, or `None`. + + e.g. `/path` in `https://foo.com,bar.com/path?query#fragment` + """ @property - def query(self) -> str | None: ... - def query_params(self) -> list[tuple[str, str]]: ... + def query(self) -> str | None: + """ + The query part of the URL, or `None`. + + e.g. `query` in `https://foo.com,bar.com/path?query#fragment` + """ + def query_params(self) -> list[tuple[str, str]]: + """ + The query part of the URL as a list of key-value pairs. + + e.g. `[('foo', 'bar')]` in `https://foo.com,bar.com/path?query#fragment` + """ @property - def fragment(self) -> str | None: ... - def hosts(self) -> list[MultiHostHost]: ... - def unicode_string(self) -> str: ... + def fragment(self) -> str | None: + """ + The fragment part of the URL, or `None`. + + e.g. `fragment` in `https://foo.com,bar.com/path?query#fragment` + """ + def hosts(self) -> list[MultiHostHost]: + ''' + + The hosts of the `MultiHostUrl` as [`MultiHostHost`][pydantic_core.MultiHostHost] typed dicts. + + ```py + from pydantic_core import MultiHostUrl + + mhu = MultiHostUrl('https://foo.com:123,foo:bar@bar.com/path') + print(mhu.hosts()) + """ + [ + {'username': None, 'password': None, 'host': 'foo.com', 'port': 123}, + {'username': 'foo', 'password': 'bar', 'host': 'bar.com', 'port': 443} + ] + ``` + Returns: + A list of dicts, each representing a host. 
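The restored docstrings above describe the `Url` accessors and `Url.build()`, which behave the same in both pinned pydantic-core versions. A short illustrative sketch (host and path values are made up):

```python
from pydantic_core import Url

url = Url("https://user:secret@example.com:8443/path?q=1#frag")
print(url.scheme, url.host, url.port, url.path)  # https example.com 8443 /path
print(url.query_params())                        # [('q', '1')]

built = Url.build(scheme="https", host="example.com", path="/health")
print(str(built))                                # e.g. https://example.com/health
```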
+ ''' + def unicode_string(self) -> str: + """ + The URL as a unicode string, unlike `__str__()` this will not punycode encode the hosts. + """ def __repr__(self) -> str: ... - def __str__(self) -> str: ... + def __str__(self) -> str: + """ + The URL as a string, this will punycode encode the hosts if required. + """ def __deepcopy__(self, memo: dict) -> Self: ... @classmethod def build( cls, *, scheme: str, - hosts: list[MultiHostHost] | None = None, - username: str | None = None, - password: str | None = None, - host: str | None = None, - port: int | None = None, - path: str | None = None, - query: str | None = None, - fragment: str | None = None, - ) -> Self: ... + hosts: Optional[list[MultiHostHost]] = None, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + path: Optional[str] = None, + query: Optional[str] = None, + fragment: Optional[str] = None, + ) -> Self: + """ + Build a new `MultiHostUrl` instance from its component parts. + + This method takes either `hosts` - a list of `MultiHostHost` typed dicts, or the individual components + `username`, `password`, `host` and `port`. + + Args: + scheme: The scheme part of the URL. + hosts: Multiple hosts to build the URL from. + username: The username part of the URL. + password: The password part of the URL. + host: The host part of the URL. + port: The port part of the URL. + path: The path part of the URL. + query: The query part of the URL, or omit for no query. + fragment: The fragment part of the URL, or omit for no fragment. + + Returns: + An instance of `MultiHostUrl` + """ @final class SchemaError(Exception): @@ -617,19 +695,20 @@ class SchemaError(Exception): A list of [`ErrorDetails`][pydantic_core.ErrorDetails] for each error in the schema. """ +@final class ValidationError(ValueError): """ `ValidationError` is the exception raised by `pydantic-core` when validation fails, it contains a list of errors which detail why validation failed. """ - @classmethod + + @staticmethod def from_exception_data( - cls, title: str, line_errors: list[InitErrorDetails], input_type: Literal['python', 'json'] = 'python', hide_input: bool = False, - ) -> Self: + ) -> ValidationError: """ Python constructor for a Validation Error. @@ -688,303 +767,56 @@ class ValidationError(ValueError): a JSON string. """ - def __repr__(self) -> str: - """ - A string representation of the validation error. - - Whether or not documentation URLs are included in the repr is controlled by the - environment variable `PYDANTIC_ERRORS_INCLUDE_URL` being set to `1` or - `true`; by default, URLs are shown. - - Due to implementation details, this environment variable can only be set once, - before the first validation error is created. - """ - +@final class PydanticCustomError(ValueError): - """A custom exception providing flexible error handling for Pydantic validators. - - You can raise this error in custom validators when you'd like flexibility in regards to the error type, message, and context. - - Example: - ```py - from pydantic_core import PydanticCustomError - - def custom_validator(v) -> None: - if v <= 10: - raise PydanticCustomError('custom_value_error', 'Value must be greater than {value}', {'value': 10, 'extra_context': 'extra_data'}) - return v - ``` - """ - - def __init__( - self, error_type: LiteralString, message_template: LiteralString, context: dict[str, Any] | None = None - ) -> None: - """Initializes the `PydanticCustomError`. - - Arguments: - error_type: The error type. 
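`from_exception_data` goes back to being a plain static method returning a `ValidationError`, but the call pattern is identical either way. A hedged sketch (the title `ExampleModel` and the `greater_than` error are arbitrary):

```python
from pydantic_core import ValidationError

error = ValidationError.from_exception_data(
    "ExampleModel",
    [{"type": "greater_than", "loc": ("count",), "input": 3, "ctx": {"gt": 5}}],
)
print(error.error_count())        # 1
print(error.errors()[0]["msg"])   # 'Input should be greater than 5'
```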
- message_template: The message template. - context: The data to inject into the message template. - """ - def __new__( cls, error_type: LiteralString, message_template: LiteralString, context: dict[str, Any] | None = None ) -> Self: ... @property - def context(self) -> dict[str, Any] | None: - """Values which are required to render the error message, and could hence be useful in passing error data forward.""" - + def context(self) -> dict[str, Any] | None: ... @property - def type(self) -> str: - """The error type associated with the error. For consistency with Pydantic, this is typically a snake_case string.""" - + def type(self) -> str: ... @property - def message_template(self) -> str: - """The message template associated with the error. This is a string that can be formatted with context variables in `{curly_braces}`.""" - - def message(self) -> str: - """The formatted message associated with the error. This presents as the message template with context variables appropriately injected.""" + def message_template(self) -> str: ... + def message(self) -> str: ... @final class PydanticKnownError(ValueError): - """A helper class for raising exceptions that mimic Pydantic's built-in exceptions, with more flexibility in regards to context. - - Unlike [`PydanticCustomError`][pydantic_core.PydanticCustomError], the `error_type` argument must be a known `ErrorType`. - - Example: - ```py - from pydantic_core import PydanticKnownError - - def custom_validator(v) -> None: - if v <= 10: - raise PydanticKnownError(error_type='greater_than', context={'gt': 10}) - return v - ``` - """ - - def __init__(self, error_type: ErrorType, context: dict[str, Any] | None = None) -> None: - """Initializes the `PydanticKnownError`. - - Arguments: - error_type: The error type. - context: The data to inject into the message template. - """ - def __new__(cls, error_type: ErrorType, context: dict[str, Any] | None = None) -> Self: ... @property - def context(self) -> dict[str, Any] | None: - """Values which are required to render the error message, and could hence be useful in passing error data forward.""" - + def context(self) -> dict[str, Any] | None: ... @property - def type(self) -> ErrorType: - """The type of the error.""" - + def type(self) -> ErrorType: ... @property - def message_template(self) -> str: - """The message template associated with the provided error type. This is a string that can be formatted with context variables in `{curly_braces}`.""" - - def message(self) -> str: - """The formatted message associated with the error. This presents as the message template with context variables appropriately injected.""" + def message_template(self) -> str: ... + def message(self) -> str: ... @final class PydanticOmit(Exception): - """An exception to signal that a field should be omitted from a generated result. - - This could span from omitting a field from a JSON Schema to omitting a field from a serialized result. - Upcoming: more robust support for using PydanticOmit in custom serializers is still in development. - Right now, this is primarily used in the JSON Schema generation process. 
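Although the 2.10.1 stubs drop the docstrings for `PydanticCustomError`, the class itself works the same in both versions. An illustrative sketch (the error type and message template are invented):

```python
from pydantic_core import PydanticCustomError

try:
    raise PydanticCustomError(
        "not_enough_stock",
        "Only {available} items left, {requested} requested",
        {"available": 2, "requested": 5},
    )
except PydanticCustomError as exc:
    print(exc.type)       # not_enough_stock
    print(exc.message())  # Only 2 items left, 5 requested
```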
- - Example: - ```py - from typing import Callable - - from pydantic_core import PydanticOmit - - from pydantic import BaseModel - from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue - - - class MyGenerateJsonSchema(GenerateJsonSchema): - def handle_invalid_for_json_schema(self, schema, error_info) -> JsonSchemaValue: - raise PydanticOmit - - - class Predicate(BaseModel): - name: str = 'no-op' - func: Callable = lambda x: x - - - instance_example = Predicate() - - validation_schema = instance_example.model_json_schema(schema_generator=MyGenerateJsonSchema, mode='validation') - print(validation_schema) - ''' - {'properties': {'name': {'default': 'no-op', 'title': 'Name', 'type': 'string'}}, 'title': 'Predicate', 'type': 'object'} - ''' - ``` - - For a more in depth example / explanation, see the [customizing JSON schema](../concepts/json_schema.md#customizing-the-json-schema-generation-process) docs. - """ - def __new__(cls) -> Self: ... @final class PydanticUseDefault(Exception): - """An exception to signal that standard validation either failed or should be skipped, and the default value should be used instead. - - This warning can be raised in custom valiation functions to redirect the flow of validation. - - Example: - ```py - from pydantic_core import PydanticUseDefault - from datetime import datetime - from pydantic import BaseModel, field_validator - - - class Event(BaseModel): - name: str = 'meeting' - time: datetime - - @field_validator('name', mode='plain') - def name_must_be_present(cls, v) -> str: - if not v or not isinstance(v, str): - raise PydanticUseDefault() - return v - - - event1 = Event(name='party', time=datetime(2024, 1, 1, 12, 0, 0)) - print(repr(event1)) - # > Event(name='party', time=datetime.datetime(2024, 1, 1, 12, 0)) - event2 = Event(time=datetime(2024, 1, 1, 12, 0, 0)) - print(repr(event2)) - # > Event(name='meeting', time=datetime.datetime(2024, 1, 1, 12, 0)) - ``` - - For an additional example, see the [validating partial json data](../concepts/json.md#partial-json-parsing) section of the Pydantic documentation. - """ - def __new__(cls) -> Self: ... @final class PydanticSerializationError(ValueError): - """An error raised when an issue occurs during serialization. - - In custom serializers, this error can be used to indicate that serialization has failed. - """ - - def __init__(self, message: str) -> None: - """Initializes the `PydanticSerializationError`. - - Arguments: - message: The message associated with the error. - """ - def __new__(cls, message: str) -> Self: ... @final class PydanticSerializationUnexpectedValue(ValueError): - """An error raised when an unexpected value is encountered during serialization. - - This error is often caught and coerced into a warning, as `pydantic-core` generally makes a best attempt - at serializing values, in contrast with validation where errors are eagerly raised. 
- - Example: - ```py - from pydantic import BaseModel, field_serializer - from pydantic_core import PydanticSerializationUnexpectedValue - - class BasicPoint(BaseModel): - x: int - y: int - - @field_serializer('*') - def serialize(self, v): - if not isinstance(v, int): - raise PydanticSerializationUnexpectedValue(f'Expected type `int`, got {type(v)} with value {v}') - return v - - point = BasicPoint(x=1, y=2) - # some sort of mutation - point.x = 'a' - - print(point.model_dump()) - ''' - UserWarning: Pydantic serializer warnings: - PydanticSerializationUnexpectedValue(Expected type `int`, got with value a) - return self.__pydantic_serializer__.to_python( - {'x': 'a', 'y': 2} - ''' - ``` - - This is often used internally in `pydantic-core` when unexpected types are encountered during serialization, - but it can also be used by users in custom serializers, as seen above. - """ - - def __init__(self, message: str) -> None: - """Initializes the `PydanticSerializationUnexpectedValue`. - - Arguments: - message: The message associated with the unexpected value. - """ - def __new__(cls, message: str | None = None) -> Self: ... @final class ArgsKwargs: - """A construct used to store arguments and keyword arguments for a function call. - - This data structure is generally used to store information for core schemas associated with functions (like in an arguments schema). - This data structure is also currently used for some validation against dataclasses. - - Example: - ```py - from pydantic.dataclasses import dataclass - from pydantic import model_validator - - - @dataclass - class Model: - a: int - b: int - - @model_validator(mode="before") - @classmethod - def no_op_validator(cls, values): - print(values) - return values - - Model(1, b=2) - #> ArgsKwargs((1,), {"b": 2}) - - Model(1, 2) - #> ArgsKwargs((1, 2), {}) - - Model(a=1, b=2) - #> ArgsKwargs((), {"a": 1, "b": 2}) - ``` - """ - - def __init__(self, args: tuple[Any, ...], kwargs: dict[str, Any] | None = None) -> None: - """Initializes the `ArgsKwargs`. - - Arguments: - args: The arguments (inherently ordered) for a function call. - kwargs: The keyword arguments for a function call - """ - def __new__(cls, args: tuple[Any, ...], kwargs: dict[str, Any] | None = None) -> Self: ... @property - def args(self) -> tuple[Any, ...]: - """The arguments (inherently ordered) for a function call.""" - + def args(self) -> tuple[Any, ...]: ... @property - def kwargs(self) -> dict[str, Any] | None: - """The keyword arguments for a function call.""" + def kwargs(self) -> dict[str, Any] | None: ... @final class PydanticUndefinedType: - """A type used as a sentinel for undefined values.""" - def __copy__(self) -> Self: ... def __deepcopy__(self, memo: Any) -> Self: ... @@ -997,41 +829,17 @@ def list_all_errors() -> list[ErrorTypeInfo]: Returns: A list of `ErrorTypeInfo` typed dicts. """ + @final class TzInfo(datetime.tzinfo): - """An `pydantic-core` implementation of the abstract [`datetime.tzinfo`][] class.""" - - # def __new__(cls, seconds: float) -> Self: ... - - # Docstrings for attributes sourced from the abstract base class, [`datetime.tzinfo`](https://docs.python.org/3/library/datetime.html#datetime.tzinfo). - - def tzname(self, dt: datetime.datetime | None) -> str | None: - """Return the time zone name corresponding to the [`datetime`][datetime.datetime] object _dt_, as a string. - - For more info, see [`tzinfo.tzname`][datetime.tzinfo.tzname]. 
- """ - - def utcoffset(self, dt: datetime.datetime | None) -> datetime.timedelta | None: - """Return offset of local time from UTC, as a [`timedelta`][datetime.timedelta] object that is positive east of UTC. If local time is west of UTC, this should be negative. - - More info can be found at [`tzinfo.utcoffset`][datetime.tzinfo.utcoffset]. - """ - - def dst(self, dt: datetime.datetime | None) -> datetime.timedelta | None: - """Return the daylight saving time (DST) adjustment, as a [`timedelta`][datetime.timedelta] object or `None` if DST information isn’t known. - - More info can be found at[`tzinfo.dst`][datetime.tzinfo.dst].""" - - def fromutc(self, dt: datetime.datetime) -> datetime.datetime: - """Adjust the date and time data associated datetime object _dt_, returning an equivalent datetime in self’s local time. - - More info can be found at [`tzinfo.fromutc`][datetime.tzinfo.fromutc].""" - - def __deepcopy__(self, _memo: dict[Any, Any]) -> TzInfo: ... + def tzname(self, _dt: datetime.datetime | None) -> str | None: ... + def utcoffset(self, _dt: datetime.datetime | None) -> datetime.timedelta: ... + def dst(self, _dt: datetime.datetime | None) -> datetime.timedelta: ... + def fromutc(self, dt: datetime.datetime) -> datetime.datetime: ... + def __deepcopy__(self, _memo: dict[Any, Any]) -> 'TzInfo': ... def validate_core_schema(schema: CoreSchema, *, strict: bool | None = None) -> CoreSchema: - """Validate a core schema. - + """Validate a CoreSchema This currently uses lax mode for validation (i.e. will coerce strings to dates and such) but may use strict mode in the future. We may also remove this function altogether, do not rely on it being present if you are diff --git a/venv/lib/python3.12/site-packages/pydantic_core/core_schema.py b/venv/lib/python3.12/site-packages/pydantic_core/core_schema.py index 0ab3dd9..2d7061f 100644 --- a/venv/lib/python3.12/site-packages/pydantic_core/core_schema.py +++ b/venv/lib/python3.12/site-packages/pydantic_core/core_schema.py @@ -7,11 +7,10 @@ from __future__ import annotations as _annotations import sys import warnings -from collections.abc import Hashable, Mapping +from collections.abc import Mapping from datetime import date, datetime, time, timedelta from decimal import Decimal -from re import Pattern -from typing import TYPE_CHECKING, Any, Callable, Literal, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Set, Tuple, Type, Union from typing_extensions import deprecated @@ -25,6 +24,11 @@ if sys.version_info < (3, 11): else: from typing import Protocol, Required, TypeAlias +if sys.version_info < (3, 9): + from typing_extensions import Literal +else: + from typing import Literal + if TYPE_CHECKING: from pydantic_core import PydanticUndefined else: @@ -54,6 +58,8 @@ class CoreConfig(TypedDict, total=False): `field_names` to construct error `loc`s. Default is `True`. revalidate_instances: Whether instances of models and dataclasses should re-validate. Default is 'never'. validate_default: Whether to validate default values during validation. Default is `False`. + populate_by_name: Whether an aliased field may be populated by its name as given by the model attribute, + as well as the alias. (Replaces 'allow_population_by_field_name' in Pydantic v1.) Default is `False`. str_max_length: The maximum length for string fields. str_min_length: The minimum length for string fields. str_strip_whitespace: Whether to strip whitespace from string fields. 
@@ -62,19 +68,11 @@ class CoreConfig(TypedDict, total=False): allow_inf_nan: Whether to allow infinity and NaN values for float fields. Default is `True`. ser_json_timedelta: The serialization option for `timedelta` values. Default is 'iso8601'. ser_json_bytes: The serialization option for `bytes` values. Default is 'utf8'. - ser_json_inf_nan: The serialization option for infinity and NaN values - in float fields. Default is 'null'. - val_json_bytes: The validation option for `bytes` values, complementing ser_json_bytes. Default is 'utf8'. hide_input_in_errors: Whether to hide input data from `ValidationError` representation. validation_error_cause: Whether to add user-python excs to the __cause__ of a ValidationError. Requires exceptiongroup backport pre Python 3.11. coerce_numbers_to_str: Whether to enable coercion of any `Number` type to `str` (not applicable in `strict` mode). regex_engine: The regex engine to use for regex pattern validation. Default is 'rust-regex'. See `StringSchema`. - cache_strings: Whether to cache strings. Default is `True`, `True` or `'all'` is required to cache strings - during general validation since validators don't know if they're in a key or a value. - validate_by_alias: Whether to use the field's alias when validating against the provided input data. Default is `True`. - validate_by_name: Whether to use the field's name when validating against the provided input data. Default is `False`. Replacement for `populate_by_name`. - serialize_by_alias: Whether to serialize by alias. Default is `False`, expected to change to `True` in V3. """ title: str @@ -92,6 +90,7 @@ class CoreConfig(TypedDict, total=False): # whether to validate default values during validation, default False validate_default: bool # used on typed-dicts and arguments + populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1 # fields related to string fields only str_max_length: int str_min_length: int @@ -102,18 +101,11 @@ class CoreConfig(TypedDict, total=False): allow_inf_nan: bool # default: True # the config options are used to customise serialization to JSON ser_json_timedelta: Literal['iso8601', 'float'] # default: 'iso8601' - ser_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8' - ser_json_inf_nan: Literal['null', 'constants', 'strings'] # default: 'null' - val_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8' + ser_json_bytes: Literal['utf8', 'base64'] # default: 'utf8' # used to hide input data from ValidationError repr hide_input_in_errors: bool validation_error_cause: bool # default: False coerce_numbers_to_str: bool # default: False - regex_engine: Literal['rust-regex', 'python-re'] # default: 'rust-regex' - cache_strings: Union[bool, Literal['all', 'keys', 'none']] # default: 'True' - validate_by_alias: bool # default: True - validate_by_name: bool # default: False - serialize_by_alias: bool # default: False IncExCall: TypeAlias = 'set[int | str] | dict[int | str, IncExCall] | None' @@ -121,46 +113,51 @@ IncExCall: TypeAlias = 'set[int | str] | dict[int | str, IncExCall] | None' class SerializationInfo(Protocol): @property - def include(self) -> IncExCall: ... + def include(self) -> IncExCall: + ... @property - def exclude(self) -> IncExCall: ... + def exclude(self) -> IncExCall: + ... @property - def context(self) -> Any | None: - """Current serialization context.""" + def mode(self) -> str: + ... @property - def mode(self) -> str: ... + def by_alias(self) -> bool: + ... @property - def by_alias(self) -> bool: ... 
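The `CoreConfig` hunks re-introduce `populate_by_name` and drop the newer `validate_by_alias`/`validate_by_name`/`serialize_by_alias` switches, so at the model level `populate_by_name` is the portable spelling. A minimal sketch (the `User` model is hypothetical):

```python
from pydantic import BaseModel, ConfigDict, Field


class User(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    user_name: str = Field(alias="userName")


print(User(userName="ada"))   # populated via the alias
print(User(user_name="ada"))  # populated via the field name, enabled by populate_by_name
```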
+ def exclude_unset(self) -> bool: + ... @property - def exclude_unset(self) -> bool: ... + def exclude_defaults(self) -> bool: + ... @property - def exclude_defaults(self) -> bool: ... + def exclude_none(self) -> bool: + ... @property - def exclude_none(self) -> bool: ... + def round_trip(self) -> bool: + ... - @property - def serialize_as_any(self) -> bool: ... + def mode_is_json(self) -> bool: + ... - @property - def round_trip(self) -> bool: ... + def __str__(self) -> str: + ... - def mode_is_json(self) -> bool: ... - - def __str__(self) -> str: ... - - def __repr__(self) -> str: ... + def __repr__(self) -> str: + ... class FieldSerializationInfo(SerializationInfo, Protocol): @property - def field_name(self) -> str: ... + def field_name(self) -> str: + ... class ValidationInfo(Protocol): @@ -184,7 +181,7 @@ class ValidationInfo(Protocol): ... @property - def data(self) -> dict[str, Any]: + def data(self) -> Dict[str, Any]: """The data being validated for this model.""" ... @@ -219,7 +216,6 @@ ExpectedSerializationTypes = Literal[ 'multi-host-url', 'json', 'uuid', - 'any', ] @@ -237,13 +233,13 @@ def simple_ser_schema(type: ExpectedSerializationTypes) -> SimpleSerSchema: return SimpleSerSchema(type=type) -# (input_value: Any, /) -> Any +# (__input_value: Any) -> Any GeneralPlainNoInfoSerializerFunction = Callable[[Any], Any] -# (input_value: Any, info: FieldSerializationInfo, /) -> Any +# (__input_value: Any, __info: FieldSerializationInfo) -> Any GeneralPlainInfoSerializerFunction = Callable[[Any, SerializationInfo], Any] -# (model: Any, input_value: Any, /) -> Any +# (__model: Any, __input_value: Any) -> Any FieldPlainNoInfoSerializerFunction = Callable[[Any, Any], Any] -# (model: Any, input_value: Any, info: FieldSerializationInfo, /) -> Any +# (__model: Any, __input_value: Any, __info: FieldSerializationInfo) -> Any FieldPlainInfoSerializerFunction = Callable[[Any, Any, FieldSerializationInfo], Any] SerializerFunction = Union[ GeneralPlainNoInfoSerializerFunction, @@ -287,7 +283,7 @@ def plain_serializer_function_ser_schema( function: The function to use for serialization is_field_serializer: Whether the serializer is for a field, e.g. takes `model` as the first argument, and `info` includes `field_name` - info_arg: Whether the function takes an `info` argument + info_arg: Whether the function takes an `__info` argument return_schema: Schema to use for serializing return value when_used: When the function should be called """ @@ -305,16 +301,17 @@ def plain_serializer_function_ser_schema( class SerializerFunctionWrapHandler(Protocol): # pragma: no cover - def __call__(self, input_value: Any, index_key: int | str | None = None, /) -> Any: ... + def __call__(self, __input_value: Any, __index_key: int | str | None = None) -> Any: + ... 
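
Editorial note (not part of the patch): the serializer hunks above only change the positional-argument naming convention of the callback protocols (`input_value` vs `__input_value`); code that passes a plain callable into `plain_serializer_function_ser_schema` behaves the same under either side of this diff. A minimal usage sketch, assuming only `pydantic_core.SchemaSerializer` and the `core_schema` helpers already shown in these hunks:

```py
from pydantic_core import SchemaSerializer, core_schema

def ser_upper(value):
    # matches the no-info plain serializer shape: (input_value) -> Any
    return value.upper()

schema = core_schema.str_schema(
    serialization=core_schema.plain_serializer_function_ser_schema(ser_upper)
)
s = SchemaSerializer(schema)
assert s.to_python('abc') == 'ABC'      # applied in python mode (when_used defaults to 'always')
assert s.to_json('abc') == b'"ABC"'     # and in JSON serialization
```
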
-# (input_value: Any, serializer: SerializerFunctionWrapHandler, /) -> Any +# (__input_value: Any, __serializer: SerializerFunctionWrapHandler) -> Any GeneralWrapNoInfoSerializerFunction = Callable[[Any, SerializerFunctionWrapHandler], Any] -# (input_value: Any, serializer: SerializerFunctionWrapHandler, info: SerializationInfo, /) -> Any +# (__input_value: Any, __serializer: SerializerFunctionWrapHandler, __info: SerializationInfo) -> Any GeneralWrapInfoSerializerFunction = Callable[[Any, SerializerFunctionWrapHandler, SerializationInfo], Any] -# (model: Any, input_value: Any, serializer: SerializerFunctionWrapHandler, /) -> Any +# (__model: Any, __input_value: Any, __serializer: SerializerFunctionWrapHandler) -> Any FieldWrapNoInfoSerializerFunction = Callable[[Any, Any, SerializerFunctionWrapHandler], Any] -# (model: Any, input_value: Any, serializer: SerializerFunctionWrapHandler, info: FieldSerializationInfo, /) -> Any +# (__model: Any, __input_value: Any, __serializer: SerializerFunctionWrapHandler, __info: FieldSerializationInfo) -> Any FieldWrapInfoSerializerFunction = Callable[[Any, Any, SerializerFunctionWrapHandler, FieldSerializationInfo], Any] WrapSerializerFunction = Union[ GeneralWrapNoInfoSerializerFunction, @@ -350,7 +347,7 @@ def wrap_serializer_function_ser_schema( function: The function to use for serialization is_field_serializer: Whether the serializer is for a field, e.g. takes `model` as the first argument, and `info` includes `field_name` - info_arg: Whether the function takes an `info` argument + info_arg: Whether the function takes an `__info` argument schema: The schema to use for the inner serialization return_schema: Schema to use for serializing return value when_used: When the function should be called @@ -410,11 +407,11 @@ def to_string_ser_schema(*, when_used: WhenUsed = 'json-unless-none') -> ToStrin class ModelSerSchema(TypedDict, total=False): type: Required[Literal['model']] - cls: Required[type[Any]] + cls: Required[Type[Any]] schema: Required[CoreSchema] -def model_ser_schema(cls: type[Any], schema: CoreSchema) -> ModelSerSchema: +def model_ser_schema(cls: Type[Any], schema: CoreSchema) -> ModelSerSchema: """ Returns a schema for serialization using a model. @@ -435,39 +432,16 @@ SerSchema = Union[ ] -class InvalidSchema(TypedDict, total=False): - type: Required[Literal['invalid']] - ref: str - metadata: dict[str, Any] - # note, we never plan to use this, but include it for type checking purposes to match - # all other CoreSchema union members - serialization: SerSchema - - -def invalid_schema(ref: str | None = None, metadata: dict[str, Any] | None = None) -> InvalidSchema: - """ - Returns an invalid schema, used to indicate that a schema is invalid. 
- - Returns a schema that matches any value, e.g.: - - Args: - ref: optional unique identifier of the schema, used to reference the schema in other places - metadata: Any other information you want to include with the schema, not used by pydantic-core - """ - - return _dict_not_none(type='invalid', ref=ref, metadata=metadata) - - class ComputedField(TypedDict, total=False): type: Required[Literal['computed-field']] property_name: Required[str] return_schema: Required[CoreSchema] alias: str - metadata: dict[str, Any] + metadata: Any def computed_field( - property_name: str, return_schema: CoreSchema, *, alias: str | None = None, metadata: dict[str, Any] | None = None + property_name: str, return_schema: CoreSchema, *, alias: str | None = None, metadata: Any = None ) -> ComputedField: """ ComputedFields are properties of a model or dataclass that are included in serialization. @@ -486,13 +460,11 @@ def computed_field( class AnySchema(TypedDict, total=False): type: Required[Literal['any']] ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema -def any_schema( - *, ref: str | None = None, metadata: dict[str, Any] | None = None, serialization: SerSchema | None = None -) -> AnySchema: +def any_schema(*, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None) -> AnySchema: """ Returns a schema that matches any value, e.g.: @@ -515,13 +487,11 @@ def any_schema( class NoneSchema(TypedDict, total=False): type: Required[Literal['none']] ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema -def none_schema( - *, ref: str | None = None, metadata: dict[str, Any] | None = None, serialization: SerSchema | None = None -) -> NoneSchema: +def none_schema(*, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None) -> NoneSchema: """ Returns a schema that matches a None value, e.g.: @@ -545,15 +515,12 @@ class BoolSchema(TypedDict, total=False): type: Required[Literal['bool']] strict: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema def bool_schema( - strict: bool | None = None, - ref: str | None = None, - metadata: dict[str, Any] | None = None, - serialization: SerSchema | None = None, + strict: bool | None = None, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None ) -> BoolSchema: """ Returns a schema that matches a bool value, e.g.: @@ -584,7 +551,7 @@ class IntSchema(TypedDict, total=False): gt: int strict: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -597,7 +564,7 @@ def int_schema( gt: int | None = None, strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> IntSchema: """ @@ -646,7 +613,7 @@ class FloatSchema(TypedDict, total=False): gt: float strict: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -660,7 +627,7 @@ def float_schema( gt: float | None = None, strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> FloatSchema: """ @@ -713,13 +680,13 @@ class DecimalSchema(TypedDict, total=False): decimal_places: int strict: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema def decimal_schema( *, - allow_inf_nan: bool | None = None, + allow_inf_nan: bool = None, multiple_of: Decimal | None = None, le: Decimal | None = 
None, ge: Decimal | None = None, @@ -729,7 +696,7 @@ def decimal_schema( decimal_places: int | None = None, strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> DecimalSchema: """ @@ -775,51 +742,9 @@ def decimal_schema( ) -class ComplexSchema(TypedDict, total=False): - type: Required[Literal['complex']] - strict: bool - ref: str - metadata: dict[str, Any] - serialization: SerSchema - - -def complex_schema( - *, - strict: bool | None = None, - ref: str | None = None, - metadata: dict[str, Any] | None = None, - serialization: SerSchema | None = None, -) -> ComplexSchema: - """ - Returns a schema that matches a complex value, e.g.: - - ```py - from pydantic_core import SchemaValidator, core_schema - - schema = core_schema.complex_schema() - v = SchemaValidator(schema) - assert v.validate_python('1+2j') == complex(1, 2) - assert v.validate_python(complex(1, 2)) == complex(1, 2) - ``` - - Args: - strict: Whether the value should be a complex object instance or a value that can be converted to a complex object - ref: optional unique identifier of the schema, used to reference the schema in other places - metadata: Any other information you want to include with the schema, not used by pydantic-core - serialization: Custom serialization schema - """ - return _dict_not_none( - type='complex', - strict=strict, - ref=ref, - metadata=metadata, - serialization=serialization, - ) - - class StringSchema(TypedDict, total=False): type: Required[Literal['str']] - pattern: Union[str, Pattern[str]] + pattern: str max_length: int min_length: int strip_whitespace: bool @@ -827,15 +752,14 @@ class StringSchema(TypedDict, total=False): to_upper: bool regex_engine: Literal['rust-regex', 'python-re'] # default: 'rust-regex' strict: bool - coerce_numbers_to_str: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema def str_schema( *, - pattern: str | Pattern[str] | None = None, + pattern: str | None = None, max_length: int | None = None, min_length: int | None = None, strip_whitespace: bool | None = None, @@ -843,9 +767,8 @@ def str_schema( to_upper: bool | None = None, regex_engine: Literal['rust-regex', 'python-re'] | None = None, strict: bool | None = None, - coerce_numbers_to_str: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> StringSchema: """ @@ -873,7 +796,6 @@ def str_schema( - `python-re` use the [`re`](https://docs.python.org/3/library/re.html) module, which supports all regex features, but may be slower. strict: Whether the value should be a string or a value that can be converted to a string - coerce_numbers_to_str: Whether to enable coercion of any `Number` type to `str` (not applicable in `strict` mode). 
ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema @@ -888,7 +810,6 @@ def str_schema( to_upper=to_upper, regex_engine=regex_engine, strict=strict, - coerce_numbers_to_str=coerce_numbers_to_str, ref=ref, metadata=metadata, serialization=serialization, @@ -901,7 +822,7 @@ class BytesSchema(TypedDict, total=False): min_length: int strict: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -911,7 +832,7 @@ def bytes_schema( min_length: int | None = None, strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> BytesSchema: """ @@ -956,7 +877,7 @@ class DateSchema(TypedDict, total=False): # value is restricted to -86_400 < offset < 86_400 by bounds in generate_self_schema.py now_utc_offset: int ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -970,7 +891,7 @@ def date_schema( now_op: Literal['past', 'future'] | None = None, now_utc_offset: int | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> DateSchema: """ @@ -1022,7 +943,7 @@ class TimeSchema(TypedDict, total=False): tz_constraint: Union[Literal['aware', 'naive'], int] microseconds_precision: Literal['truncate', 'error'] ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -1036,7 +957,7 @@ def time_schema( tz_constraint: Literal['aware', 'naive'] | int | None = None, microseconds_precision: Literal['truncate', 'error'] = 'truncate', ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> TimeSchema: """ @@ -1092,7 +1013,7 @@ class DatetimeSchema(TypedDict, total=False): now_utc_offset: int microseconds_precision: Literal['truncate', 'error'] # default: 'truncate' ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -1108,7 +1029,7 @@ def datetime_schema( now_utc_offset: int | None = None, microseconds_precision: Literal['truncate', 'error'] = 'truncate', ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> DatetimeSchema: """ @@ -1165,7 +1086,7 @@ class TimedeltaSchema(TypedDict, total=False): gt: timedelta microseconds_precision: Literal['truncate', 'error'] ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -1178,7 +1099,7 @@ def timedelta_schema( gt: timedelta | None = None, microseconds_precision: Literal['truncate', 'error'] = 'truncate', ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> TimedeltaSchema: """ @@ -1220,18 +1141,14 @@ def timedelta_schema( class LiteralSchema(TypedDict, total=False): type: Required[Literal['literal']] - expected: Required[list[Any]] + expected: Required[List[Any]] ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema def literal_schema( - expected: list[Any], - *, - ref: str | None = None, - metadata: dict[str, Any] | None = None, - serialization: SerSchema | None = None, + expected: list[Any], *, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None ) -> LiteralSchema: """ Returns a 
schema that matches a literal value, e.g.: @@ -1253,69 +1170,6 @@ def literal_schema( return _dict_not_none(type='literal', expected=expected, ref=ref, metadata=metadata, serialization=serialization) -class EnumSchema(TypedDict, total=False): - type: Required[Literal['enum']] - cls: Required[Any] - members: Required[list[Any]] - sub_type: Literal['str', 'int', 'float'] - missing: Callable[[Any], Any] - strict: bool - ref: str - metadata: dict[str, Any] - serialization: SerSchema - - -def enum_schema( - cls: Any, - members: list[Any], - *, - sub_type: Literal['str', 'int', 'float'] | None = None, - missing: Callable[[Any], Any] | None = None, - strict: bool | None = None, - ref: str | None = None, - metadata: dict[str, Any] | None = None, - serialization: SerSchema | None = None, -) -> EnumSchema: - """ - Returns a schema that matches an enum value, e.g.: - - ```py - from enum import Enum - from pydantic_core import SchemaValidator, core_schema - - class Color(Enum): - RED = 1 - GREEN = 2 - BLUE = 3 - - schema = core_schema.enum_schema(Color, list(Color.__members__.values())) - v = SchemaValidator(schema) - assert v.validate_python(2) is Color.GREEN - ``` - - Args: - cls: The enum class - members: The members of the enum, generally `list(MyEnum.__members__.values())` - sub_type: The type of the enum, either 'str' or 'int' or None for plain enums - missing: A function to use when the value is not found in the enum, from `_missing_` - strict: Whether to use strict mode, defaults to False - ref: optional unique identifier of the schema, used to reference the schema in other places - metadata: Any other information you want to include with the schema, not used by pydantic-core - serialization: Custom serialization schema - """ - return _dict_not_none( - type='enum', - cls=cls, - members=members, - sub_type=sub_type, - missing=missing, - strict=strict, - ref=ref, - metadata=metadata, - serialization=serialization, - ) - - # must match input/parse_json.rs::JsonType::try_from JsonType = Literal['null', 'bool', 'int', 'float', 'str', 'list', 'dict'] @@ -1325,7 +1179,7 @@ class IsInstanceSchema(TypedDict, total=False): cls: Required[Any] cls_repr: str ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -1334,11 +1188,11 @@ def is_instance_schema( *, cls_repr: str | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> IsInstanceSchema: """ - Returns a schema that checks if a value is an instance of a class, equivalent to python's `isinstance` method, e.g.: + Returns a schema that checks if a value is an instance of a class, equivalent to python's `isinstnace` method, e.g.: ```py from pydantic_core import SchemaValidator, core_schema @@ -1365,19 +1219,19 @@ def is_instance_schema( class IsSubclassSchema(TypedDict, total=False): type: Required[Literal['is-subclass']] - cls: Required[type[Any]] + cls: Required[Type[Any]] cls_repr: str ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema def is_subclass_schema( - cls: type[Any], + cls: Type[Any], *, cls_repr: str | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> IsInstanceSchema: """ @@ -1412,12 +1266,12 @@ def is_subclass_schema( class CallableSchema(TypedDict, total=False): type: Required[Literal['callable']] ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema def callable_schema( - *, 
ref: str | None = None, metadata: dict[str, Any] | None = None, serialization: SerSchema | None = None + *, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None ) -> CallableSchema: """ Returns a schema that checks if a value is callable, equivalent to python's `callable` method, e.g.: @@ -1440,19 +1294,19 @@ def callable_schema( class UuidSchema(TypedDict, total=False): type: Required[Literal['uuid']] - version: Literal[1, 3, 4, 5, 7] + version: Literal[1, 3, 4, 5] strict: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema def uuid_schema( *, - version: Literal[1, 3, 4, 5, 6, 7, 8] | None = None, + version: Literal[1, 3, 4, 5] | None = None, strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> UuidSchema: return _dict_not_none( @@ -1462,11 +1316,11 @@ def uuid_schema( class IncExSeqSerSchema(TypedDict, total=False): type: Required[Literal['include-exclude-sequence']] - include: set[int] - exclude: set[int] + include: Set[int] + exclude: Set[int] -def filter_seq_schema(*, include: set[int] | None = None, exclude: set[int] | None = None) -> IncExSeqSerSchema: +def filter_seq_schema(*, include: Set[int] | None = None, exclude: Set[int] | None = None) -> IncExSeqSerSchema: return _dict_not_none(type='include-exclude-sequence', include=include, exclude=exclude) @@ -1478,10 +1332,9 @@ class ListSchema(TypedDict, total=False): items_schema: CoreSchema min_length: int max_length: int - fail_fast: bool strict: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: IncExSeqOrElseSerSchema @@ -1490,10 +1343,9 @@ def list_schema( *, min_length: int | None = None, max_length: int | None = None, - fail_fast: bool | None = None, strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: IncExSeqOrElseSerSchema | None = None, ) -> ListSchema: """ @@ -1511,7 +1363,6 @@ def list_schema( items_schema: The value must be a list of items that match this schema min_length: The value must be a list with at least this many items max_length: The value must be a list with at most this many items - fail_fast: Stop validation on the first error strict: The value must be a list with exactly this many items ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core @@ -1522,7 +1373,6 @@ def list_schema( items_schema=items_schema, min_length=min_length, max_length=max_length, - fail_fast=fail_fast, strict=strict, ref=ref, metadata=metadata, @@ -1530,16 +1380,25 @@ def list_schema( ) -# @deprecated('tuple_positional_schema is deprecated. 
Use pydantic_core.core_schema.tuple_schema instead.') +class TuplePositionalSchema(TypedDict, total=False): + type: Required[Literal['tuple-positional']] + items_schema: Required[List[CoreSchema]] + extras_schema: CoreSchema + strict: bool + ref: str + metadata: Any + serialization: IncExSeqOrElseSerSchema + + def tuple_positional_schema( items_schema: list[CoreSchema], *, extras_schema: CoreSchema | None = None, strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: IncExSeqOrElseSerSchema | None = None, -) -> TupleSchema: +) -> TuplePositionalSchema: """ Returns a schema that matches a tuple of schemas, e.g.: @@ -1564,14 +1423,10 @@ def tuple_positional_schema( metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema """ - if extras_schema is not None: - variadic_item_index = len(items_schema) - items_schema = items_schema + [extras_schema] - else: - variadic_item_index = None - return tuple_schema( + return _dict_not_none( + type='tuple-positional', items_schema=items_schema, - variadic_item_index=variadic_item_index, + extras_schema=extras_schema, strict=strict, ref=ref, metadata=metadata, @@ -1579,7 +1434,17 @@ def tuple_positional_schema( ) -# @deprecated('tuple_variable_schema is deprecated. Use pydantic_core.core_schema.tuple_schema instead.') +class TupleVariableSchema(TypedDict, total=False): + type: Required[Literal['tuple-variable']] + items_schema: CoreSchema + min_length: int + max_length: int + strict: bool + ref: str + metadata: Any + serialization: IncExSeqOrElseSerSchema + + def tuple_variable_schema( items_schema: CoreSchema | None = None, *, @@ -1587,9 +1452,9 @@ def tuple_variable_schema( max_length: int | None = None, strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: IncExSeqOrElseSerSchema | None = None, -) -> TupleSchema: +) -> TupleVariableSchema: """ Returns a schema that matches a tuple of a given schema, e.g.: @@ -1608,79 +1473,15 @@ def tuple_variable_schema( min_length: The value must be a tuple with at least this many items max_length: The value must be a tuple with at most this many items strict: The value must be a tuple with exactly this many items - ref: Optional unique identifier of the schema, used to reference the schema in other places - metadata: Any other information you want to include with the schema, not used by pydantic-core - serialization: Custom serialization schema - """ - return tuple_schema( - items_schema=[items_schema or any_schema()], - variadic_item_index=0, - min_length=min_length, - max_length=max_length, - strict=strict, - ref=ref, - metadata=metadata, - serialization=serialization, - ) - - -class TupleSchema(TypedDict, total=False): - type: Required[Literal['tuple']] - items_schema: Required[list[CoreSchema]] - variadic_item_index: int - min_length: int - max_length: int - fail_fast: bool - strict: bool - ref: str - metadata: dict[str, Any] - serialization: IncExSeqOrElseSerSchema - - -def tuple_schema( - items_schema: list[CoreSchema], - *, - variadic_item_index: int | None = None, - min_length: int | None = None, - max_length: int | None = None, - fail_fast: bool | None = None, - strict: bool | None = None, - ref: str | None = None, - metadata: dict[str, Any] | None = None, - serialization: IncExSeqOrElseSerSchema | None = None, -) -> TupleSchema: - """ - Returns a schema that matches a tuple 
of schemas, with an optional variadic item, e.g.: - - ```py - from pydantic_core import SchemaValidator, core_schema - - schema = core_schema.tuple_schema( - [core_schema.int_schema(), core_schema.str_schema(), core_schema.float_schema()], - variadic_item_index=1, - ) - v = SchemaValidator(schema) - assert v.validate_python((1, 'hello', 'world', 1.5)) == (1, 'hello', 'world', 1.5) - ``` - - Args: - items_schema: The value must be a tuple with items that match these schemas - variadic_item_index: The index of the schema in `items_schema` to be treated as variadic (following PEP 646) - min_length: The value must be a tuple with at least this many items - max_length: The value must be a tuple with at most this many items - fail_fast: Stop validation on the first error - strict: The value must be a tuple with exactly this many items - ref: Optional unique identifier of the schema, used to reference the schema in other places + ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema """ return _dict_not_none( - type='tuple', + type='tuple-variable', items_schema=items_schema, - variadic_item_index=variadic_item_index, min_length=min_length, max_length=max_length, - fail_fast=fail_fast, strict=strict, ref=ref, metadata=metadata, @@ -1693,10 +1494,9 @@ class SetSchema(TypedDict, total=False): items_schema: CoreSchema min_length: int max_length: int - fail_fast: bool strict: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -1705,10 +1505,9 @@ def set_schema( *, min_length: int | None = None, max_length: int | None = None, - fail_fast: bool | None = None, strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> SetSchema: """ @@ -1728,7 +1527,6 @@ def set_schema( items_schema: The value must be a set with items that match this schema min_length: The value must be a set with at least this many items max_length: The value must be a set with at most this many items - fail_fast: Stop validation on the first error strict: The value must be a set with exactly this many items ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core @@ -1739,7 +1537,6 @@ def set_schema( items_schema=items_schema, min_length=min_length, max_length=max_length, - fail_fast=fail_fast, strict=strict, ref=ref, metadata=metadata, @@ -1752,10 +1549,9 @@ class FrozenSetSchema(TypedDict, total=False): items_schema: CoreSchema min_length: int max_length: int - fail_fast: bool strict: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -1764,10 +1560,9 @@ def frozenset_schema( *, min_length: int | None = None, max_length: int | None = None, - fail_fast: bool | None = None, strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> FrozenSetSchema: """ @@ -1787,7 +1582,6 @@ def frozenset_schema( items_schema: The value must be a frozenset with items that match this schema min_length: The value must be a frozenset with at least this many items max_length: The value must be a frozenset with at most this many items - fail_fast: Stop validation on the first error 
strict: The value must be a frozenset with exactly this many items ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core @@ -1798,7 +1592,6 @@ def frozenset_schema( items_schema=items_schema, min_length=min_length, max_length=max_length, - fail_fast=fail_fast, strict=strict, ref=ref, metadata=metadata, @@ -1812,7 +1605,7 @@ class GeneratorSchema(TypedDict, total=False): min_length: int max_length: int ref: str - metadata: dict[str, Any] + metadata: Any serialization: IncExSeqOrElseSerSchema @@ -1822,7 +1615,7 @@ def generator_schema( min_length: int | None = None, max_length: int | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: IncExSeqOrElseSerSchema | None = None, ) -> GeneratorSchema: """ @@ -1863,7 +1656,7 @@ def generator_schema( ) -IncExDict = set[Union[int, str]] +IncExDict = Set[Union[int, str]] class IncExDictSerSchema(TypedDict, total=False): @@ -1887,7 +1680,7 @@ class DictSchema(TypedDict, total=False): max_length: int strict: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: IncExDictOrElseSerSchema @@ -1899,7 +1692,7 @@ def dict_schema( max_length: int | None = None, strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> DictSchema: """ @@ -1938,7 +1731,7 @@ def dict_schema( ) -# (input_value: Any, /) -> Any +# (__input_value: Any) -> Any NoInfoValidatorFunction = Callable[[Any], Any] @@ -1947,7 +1740,7 @@ class NoInfoValidatorFunctionSchema(TypedDict): function: NoInfoValidatorFunction -# (input_value: Any, info: ValidationInfo, /) -> Any +# (__input_value: Any, __info: ValidationInfo) -> Any WithInfoValidatorFunction = Callable[[Any, ValidationInfo], Any] @@ -1964,13 +1757,12 @@ class _ValidatorFunctionSchema(TypedDict, total=False): function: Required[ValidationFunction] schema: Required[CoreSchema] ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema class BeforeValidatorFunctionSchema(_ValidatorFunctionSchema, total=False): type: Required[Literal['function-before']] - json_schema_input_schema: CoreSchema def no_info_before_validator_function( @@ -1978,8 +1770,7 @@ def no_info_before_validator_function( schema: CoreSchema, *, ref: str | None = None, - json_schema_input_schema: CoreSchema | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> BeforeValidatorFunctionSchema: """ @@ -2004,7 +1795,6 @@ def no_info_before_validator_function( function: The validator function to call schema: The schema to validate the output of the validator function ref: optional unique identifier of the schema, used to reference the schema in other places - json_schema_input_schema: The core schema to be used to generate the corresponding JSON Schema input type metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema """ @@ -2013,7 +1803,6 @@ def no_info_before_validator_function( function={'type': 'no-info', 'function': function}, schema=schema, ref=ref, - json_schema_input_schema=json_schema_input_schema, metadata=metadata, serialization=serialization, ) @@ -2025,8 +1814,7 @@ def with_info_before_validator_function( *, field_name: str | None = None, ref: str | None = None, - 
json_schema_input_schema: CoreSchema | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> BeforeValidatorFunctionSchema: """ @@ -2055,7 +1843,6 @@ def with_info_before_validator_function( field_name: The name of the field schema: The schema to validate the output of the validator function ref: optional unique identifier of the schema, used to reference the schema in other places - json_schema_input_schema: The core schema to be used to generate the corresponding JSON Schema input type metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema """ @@ -2064,7 +1851,6 @@ def with_info_before_validator_function( function=_dict_not_none(type='with-info', function=function, field_name=field_name), schema=schema, ref=ref, - json_schema_input_schema=json_schema_input_schema, metadata=metadata, serialization=serialization, ) @@ -2079,8 +1865,7 @@ def no_info_after_validator_function( schema: CoreSchema, *, ref: str | None = None, - json_schema_input_schema: CoreSchema | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> AfterValidatorFunctionSchema: """ @@ -2103,7 +1888,6 @@ def no_info_after_validator_function( function: The validator function to call after the schema is validated schema: The schema to validate before the validator function ref: optional unique identifier of the schema, used to reference the schema in other places - json_schema_input_schema: The core schema to be used to generate the corresponding JSON Schema input type metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema """ @@ -2112,7 +1896,6 @@ def no_info_after_validator_function( function={'type': 'no-info', 'function': function}, schema=schema, ref=ref, - json_schema_input_schema=json_schema_input_schema, metadata=metadata, serialization=serialization, ) @@ -2124,7 +1907,7 @@ def with_info_after_validator_function( *, field_name: str | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> AfterValidatorFunctionSchema: """ @@ -2167,11 +1950,11 @@ def with_info_after_validator_function( class ValidatorFunctionWrapHandler(Protocol): - def __call__(self, input_value: Any, outer_location: str | int | None = None, /) -> Any: # pragma: no cover + def __call__(self, input_value: Any, outer_location: str | int | None = None) -> Any: # pragma: no cover ... 
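
Editorial note (not part of the patch): the `ValidatorFunctionWrapHandler.__call__` change above is likewise only a signature-style downgrade; the handler is still invoked positionally by user code. A minimal sketch of a wrap validator built with `no_info_wrap_validator_function`, mirroring the docstring example quoted elsewhere in this diff (illustrative only):

```py
from pydantic_core import SchemaValidator, core_schema

def fn(v, handler):
    # run the wrapped str schema first, then post-process the validated value
    return handler(v) + ' world'

schema = core_schema.no_info_wrap_validator_function(fn, core_schema.str_schema())
v = SchemaValidator(schema)
assert v.validate_python('hello') == 'hello world'
```
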
-# (input_value: Any, validator: ValidatorFunctionWrapHandler, /) -> Any +# (__input_value: Any, __validator: ValidatorFunctionWrapHandler) -> Any NoInfoWrapValidatorFunction = Callable[[Any, ValidatorFunctionWrapHandler], Any] @@ -2180,7 +1963,7 @@ class NoInfoWrapValidatorFunctionSchema(TypedDict): function: NoInfoWrapValidatorFunction -# (input_value: Any, validator: ValidatorFunctionWrapHandler, info: ValidationInfo, /) -> Any +# (__input_value: Any, __validator: ValidatorFunctionWrapHandler, __info: ValidationInfo) -> Any WithInfoWrapValidatorFunction = Callable[[Any, ValidatorFunctionWrapHandler, ValidationInfo], Any] @@ -2198,8 +1981,7 @@ class WrapValidatorFunctionSchema(TypedDict, total=False): function: Required[WrapValidatorFunction] schema: Required[CoreSchema] ref: str - json_schema_input_schema: CoreSchema - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -2208,8 +1990,7 @@ def no_info_wrap_validator_function( schema: CoreSchema, *, ref: str | None = None, - json_schema_input_schema: CoreSchema | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> WrapValidatorFunctionSchema: """ @@ -2237,7 +2018,6 @@ def no_info_wrap_validator_function( function: The validator function to call schema: The schema to validate the output of the validator function ref: optional unique identifier of the schema, used to reference the schema in other places - json_schema_input_schema: The core schema to be used to generate the corresponding JSON Schema input type metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema """ @@ -2245,7 +2025,6 @@ def no_info_wrap_validator_function( type='function-wrap', function={'type': 'no-info', 'function': function}, schema=schema, - json_schema_input_schema=json_schema_input_schema, ref=ref, metadata=metadata, serialization=serialization, @@ -2257,9 +2036,8 @@ def with_info_wrap_validator_function( schema: CoreSchema, *, field_name: str | None = None, - json_schema_input_schema: CoreSchema | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> WrapValidatorFunctionSchema: """ @@ -2288,7 +2066,6 @@ def with_info_wrap_validator_function( function: The validator function to call schema: The schema to validate the output of the validator function field_name: The name of the field this validators is applied to, if any - json_schema_input_schema: The core schema to be used to generate the corresponding JSON Schema input type ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema @@ -2297,7 +2074,6 @@ def with_info_wrap_validator_function( type='function-wrap', function=_dict_not_none(type='with-info', function=function, field_name=field_name), schema=schema, - json_schema_input_schema=json_schema_input_schema, ref=ref, metadata=metadata, serialization=serialization, @@ -2308,8 +2084,7 @@ class PlainValidatorFunctionSchema(TypedDict, total=False): type: Required[Literal['function-plain']] function: Required[ValidationFunction] ref: str - json_schema_input_schema: CoreSchema - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -2317,8 +2092,7 @@ def no_info_plain_validator_function( function: 
NoInfoValidatorFunction, *, ref: str | None = None, - json_schema_input_schema: CoreSchema | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> PlainValidatorFunctionSchema: """ @@ -2339,7 +2113,6 @@ def no_info_plain_validator_function( Args: function: The validator function to call ref: optional unique identifier of the schema, used to reference the schema in other places - json_schema_input_schema: The core schema to be used to generate the corresponding JSON Schema input type metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema """ @@ -2347,7 +2120,6 @@ def no_info_plain_validator_function( type='function-plain', function={'type': 'no-info', 'function': function}, ref=ref, - json_schema_input_schema=json_schema_input_schema, metadata=metadata, serialization=serialization, ) @@ -2358,8 +2130,7 @@ def with_info_plain_validator_function( *, field_name: str | None = None, ref: str | None = None, - json_schema_input_schema: CoreSchema | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> PlainValidatorFunctionSchema: """ @@ -2381,7 +2152,6 @@ def with_info_plain_validator_function( function: The validator function to call field_name: The name of the field this validators is applied to, if any ref: optional unique identifier of the schema, used to reference the schema in other places - json_schema_input_schema: The core schema to be used to generate the corresponding JSON Schema input type metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema """ @@ -2389,7 +2159,6 @@ def with_info_plain_validator_function( type='function-plain', function=_dict_not_none(type='with-info', function=function, field_name=field_name), ref=ref, - json_schema_input_schema=json_schema_input_schema, metadata=metadata, serialization=serialization, ) @@ -2399,13 +2168,12 @@ class WithDefaultSchema(TypedDict, total=False): type: Required[Literal['default']] schema: Required[CoreSchema] default: Any - default_factory: Union[Callable[[], Any], Callable[[dict[str, Any]], Any]] - default_factory_takes_data: bool + default_factory: Callable[[], Any] on_error: Literal['raise', 'omit', 'default'] # default: 'raise' validate_default: bool # default: False strict: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -2413,13 +2181,12 @@ def with_default_schema( schema: CoreSchema, *, default: Any = PydanticUndefined, - default_factory: Union[Callable[[], Any], Callable[[dict[str, Any]], Any], None] = None, - default_factory_takes_data: bool | None = None, + default_factory: Callable[[], Any] | None = None, on_error: Literal['raise', 'omit', 'default'] | None = None, validate_default: bool | None = None, strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> WithDefaultSchema: """ @@ -2439,8 +2206,7 @@ def with_default_schema( Args: schema: The schema to add a default value to default: The default value to use - default_factory: A callable that returns the default value to use - default_factory_takes_data: Whether the default factory takes a validated data argument + default_factory: A function that returns the default value to use on_error: What to do if the schema validation fails. 
One of 'raise', 'omit', 'default' validate_default: Whether the default value should be validated strict: Whether the underlying schema should be validated with strict mode @@ -2452,7 +2218,6 @@ def with_default_schema( type='default', schema=schema, default_factory=default_factory, - default_factory_takes_data=default_factory_takes_data, on_error=on_error, validate_default=validate_default, strict=strict, @@ -2470,7 +2235,7 @@ class NullableSchema(TypedDict, total=False): schema: Required[CoreSchema] strict: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -2479,7 +2244,7 @@ def nullable_schema( *, strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> NullableSchema: """ @@ -2507,16 +2272,16 @@ def nullable_schema( class UnionSchema(TypedDict, total=False): type: Required[Literal['union']] - choices: Required[list[Union[CoreSchema, tuple[CoreSchema, str]]]] + choices: Required[List[Union[CoreSchema, Tuple[CoreSchema, str]]]] # default true, whether to automatically collapse unions with one element to the inner validator auto_collapse: bool custom_error_type: str custom_error_message: str - custom_error_context: dict[str, Union[str, int, float]] + custom_error_context: Dict[str, Union[str, int, float]] mode: Literal['smart', 'left_to_right'] # default: 'smart' strict: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -2528,8 +2293,9 @@ def union_schema( custom_error_message: str | None = None, custom_error_context: dict[str, str | int] | None = None, mode: Literal['smart', 'left_to_right'] | None = None, + strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> UnionSchema: """ @@ -2553,6 +2319,7 @@ def union_schema( mode: How to select which choice to return * `smart` (default) will try to return the choice which is the closest match to the input value * `left_to_right` will return the first choice in `choices` which succeeds validation + strict: Whether the underlying schemas should be validated with strict mode ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema @@ -2565,6 +2332,7 @@ def union_schema( custom_error_message=custom_error_message, custom_error_context=custom_error_context, mode=mode, + strict=strict, ref=ref, metadata=metadata, serialization=serialization, @@ -2573,21 +2341,21 @@ def union_schema( class TaggedUnionSchema(TypedDict, total=False): type: Required[Literal['tagged-union']] - choices: Required[dict[Hashable, CoreSchema]] - discriminator: Required[Union[str, list[Union[str, int]], list[list[Union[str, int]]], Callable[[Any], Hashable]]] + choices: Required[Dict[Hashable, CoreSchema]] + discriminator: Required[Union[str, List[Union[str, int]], List[List[Union[str, int]]], Callable[[Any], Hashable]]] custom_error_type: str custom_error_message: str - custom_error_context: dict[str, Union[str, int, float]] + custom_error_context: Dict[str, Union[str, int, float]] strict: bool from_attributes: bool # default: True ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema def tagged_union_schema( - choices: dict[Any, CoreSchema], - discriminator: str | list[str | int] | list[list[str | int]] | 
Callable[[Any], Any], + choices: Dict[Hashable, CoreSchema], + discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], Hashable], *, custom_error_type: str | None = None, custom_error_message: str | None = None, @@ -2595,7 +2363,7 @@ def tagged_union_schema( strict: bool | None = None, from_attributes: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> TaggedUnionSchema: """ @@ -2670,18 +2438,14 @@ def tagged_union_schema( class ChainSchema(TypedDict, total=False): type: Required[Literal['chain']] - steps: Required[list[CoreSchema]] + steps: Required[List[CoreSchema]] ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema def chain_schema( - steps: list[CoreSchema], - *, - ref: str | None = None, - metadata: dict[str, Any] | None = None, - serialization: SerSchema | None = None, + steps: list[CoreSchema], *, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None ) -> ChainSchema: """ Returns a schema that chains the provided validation schemas, e.g.: @@ -2716,7 +2480,7 @@ class LaxOrStrictSchema(TypedDict, total=False): strict_schema: Required[CoreSchema] strict: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -2726,7 +2490,7 @@ def lax_or_strict_schema( *, strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> LaxOrStrictSchema: """ @@ -2779,7 +2543,7 @@ class JsonOrPythonSchema(TypedDict, total=False): json_schema: Required[CoreSchema] python_schema: Required[CoreSchema] ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -2788,7 +2552,7 @@ def json_or_python_schema( python_schema: CoreSchema, *, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> JsonOrPythonSchema: """ @@ -2835,10 +2599,10 @@ class TypedDictField(TypedDict, total=False): type: Required[Literal['typed-dict-field']] schema: Required[CoreSchema] required: bool - validation_alias: Union[str, list[Union[str, int]], list[list[Union[str, int]]]] + validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]] serialization_alias: str serialization_exclude: bool # default: False - metadata: dict[str, Any] + metadata: Any def typed_dict_field( @@ -2848,7 +2612,7 @@ def typed_dict_field( validation_alias: str | list[str | int] | list[list[str | int]] | None = None, serialization_alias: str | None = None, serialization_exclude: bool | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, ) -> TypedDictField: """ Returns a schema that matches a typed dict field, e.g.: @@ -2861,7 +2625,7 @@ def typed_dict_field( Args: schema: The schema to use for the field - required: Whether the field is required, otherwise uses the value from `total` on the typed dict + required: Whether the field is required validation_alias: The alias(es) to use to find the field in the validation data serialization_alias: The alias to use as a key when serializing serialization_exclude: Whether to exclude the field when serializing @@ -2880,33 +2644,31 @@ def typed_dict_field( class TypedDictSchema(TypedDict, total=False): type: Required[Literal['typed-dict']] - fields: Required[dict[str, TypedDictField]] - cls: type[Any] - cls_name: str - computed_fields: list[ComputedField] + fields: 
Required[Dict[str, TypedDictField]] + computed_fields: List[ComputedField] strict: bool extras_schema: CoreSchema # all these values can be set via config, equivalent fields have `typed_dict_` prefix extra_behavior: ExtraBehavior total: bool # default: True + populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1 ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema config: CoreConfig def typed_dict_schema( - fields: dict[str, TypedDictField], + fields: Dict[str, TypedDictField], *, - cls: type[Any] | None = None, - cls_name: str | None = None, computed_fields: list[ComputedField] | None = None, strict: bool | None = None, extras_schema: CoreSchema | None = None, extra_behavior: ExtraBehavior | None = None, total: bool | None = None, + populate_by_name: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, config: CoreConfig | None = None, ) -> TypedDictSchema: @@ -2914,15 +2676,10 @@ def typed_dict_schema( Returns a schema that matches a typed dict, e.g.: ```py - from typing_extensions import TypedDict - from pydantic_core import SchemaValidator, core_schema - class MyTypedDict(TypedDict): - a: str - wrapper_schema = core_schema.typed_dict_schema( - {'a': core_schema.typed_dict_field(core_schema.str_schema())}, cls=MyTypedDict + {'a': core_schema.typed_dict_field(core_schema.str_schema())} ) v = SchemaValidator(wrapper_schema) assert v.validate_python({'a': 'hello'}) == {'a': 'hello'} @@ -2930,28 +2687,25 @@ def typed_dict_schema( Args: fields: The fields to use for the typed dict - cls: The class to use for the typed dict - cls_name: The name to use in error locations. Falls back to `cls.__name__`, or the validator name if no class - is provided. 
computed_fields: Computed fields to use when serializing the model, only applies when directly inside a model strict: Whether the typed dict is strict extras_schema: The extra validator to use for the typed dict ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core extra_behavior: The extra behavior to use for the typed dict - total: Whether the typed dict is total, otherwise uses `typed_dict_total` from config + total: Whether the typed dict is total + populate_by_name: Whether the typed dict should populate by name serialization: Custom serialization schema """ return _dict_not_none( type='typed-dict', fields=fields, - cls=cls, - cls_name=cls_name, computed_fields=computed_fields, strict=strict, extras_schema=extras_schema, extra_behavior=extra_behavior, total=total, + populate_by_name=populate_by_name, ref=ref, metadata=metadata, serialization=serialization, @@ -2962,11 +2716,11 @@ def typed_dict_schema( class ModelField(TypedDict, total=False): type: Required[Literal['model-field']] schema: Required[CoreSchema] - validation_alias: Union[str, list[Union[str, int]], list[list[Union[str, int]]]] + validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]] serialization_alias: str serialization_exclude: bool # default: False frozen: bool - metadata: dict[str, Any] + metadata: Any def model_field( @@ -2976,7 +2730,7 @@ def model_field( serialization_alias: str | None = None, serialization_exclude: bool | None = None, frozen: bool | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, ) -> ModelField: """ Returns a schema for a model field, e.g.: @@ -3008,35 +2762,36 @@ def model_field( class ModelFieldsSchema(TypedDict, total=False): type: Required[Literal['model-fields']] - fields: Required[dict[str, ModelField]] + fields: Required[Dict[str, ModelField]] model_name: str - computed_fields: list[ComputedField] + computed_fields: List[ComputedField] strict: bool extras_schema: CoreSchema - extras_keys_schema: CoreSchema + # all these values can be set via config, equivalent fields have `typed_dict_` prefix extra_behavior: ExtraBehavior + populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1 from_attributes: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema def model_fields_schema( - fields: dict[str, ModelField], + fields: Dict[str, ModelField], *, model_name: str | None = None, computed_fields: list[ComputedField] | None = None, strict: bool | None = None, extras_schema: CoreSchema | None = None, - extras_keys_schema: CoreSchema | None = None, extra_behavior: ExtraBehavior | None = None, + populate_by_name: bool | None = None, from_attributes: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> ModelFieldsSchema: """ - Returns a schema that matches the fields of a Pydantic model, e.g.: + Returns a schema that matches a typed dict, e.g.: ```py from pydantic_core import SchemaValidator, core_schema @@ -3050,16 +2805,16 @@ def model_fields_schema( ``` Args: - fields: The fields of the model + fields: The fields to use for the typed dict model_name: The name of the model, used for error messages, defaults to "Model" computed_fields: Computed fields to use when serializing the model, only applies when directly inside a model - strict: Whether the model is 
strict - extras_schema: The schema to use when validating extra input data - extras_keys_schema: The schema to use when validating the keys of extra input data + strict: Whether the typed dict is strict + extras_schema: The extra validator to use for the typed dict ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core - extra_behavior: The extra behavior to use for the model fields - from_attributes: Whether the model fields should be populated from attributes + extra_behavior: The extra behavior to use for the typed dict + populate_by_name: Whether the typed dict should populate by name + from_attributes: Whether the typed dict should be populated from attributes serialization: Custom serialization schema """ return _dict_not_none( @@ -3069,8 +2824,8 @@ def model_fields_schema( computed_fields=computed_fields, strict=strict, extras_schema=extras_schema, - extras_keys_schema=extras_keys_schema, extra_behavior=extra_behavior, + populate_by_name=populate_by_name, from_attributes=from_attributes, ref=ref, metadata=metadata, @@ -3080,8 +2835,7 @@ def model_fields_schema( class ModelSchema(TypedDict, total=False): type: Required[Literal['model']] - cls: Required[type[Any]] - generic_origin: type[Any] + cls: Required[Type[Any]] schema: Required[CoreSchema] custom_init: bool root_model: bool @@ -3092,15 +2846,14 @@ class ModelSchema(TypedDict, total=False): extra_behavior: ExtraBehavior config: CoreConfig ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema def model_schema( - cls: type[Any], + cls: Type[Any], schema: CoreSchema, *, - generic_origin: type[Any] | None = None, custom_init: bool | None = None, root_model: bool | None = None, post_init: str | None = None, @@ -3110,7 +2863,7 @@ def model_schema( extra_behavior: ExtraBehavior | None = None, config: CoreConfig | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> ModelSchema: """ @@ -3147,8 +2900,6 @@ def model_schema( Args: cls: The class to use for the model schema: The schema to use for the model - generic_origin: The origin type used for this model, if it's a parametrized generic. 
Ex, - if this model schema represents `SomeModel[int]`, generic_origin is `SomeModel` custom_init: Whether the model has a custom init method root_model: Whether the model is a `RootModel` post_init: The call after init to use for the model @@ -3165,7 +2916,6 @@ def model_schema( return _dict_not_none( type='model', cls=cls, - generic_origin=generic_origin, schema=schema, custom_init=custom_init, root_model=root_model, @@ -3186,13 +2936,12 @@ class DataclassField(TypedDict, total=False): name: Required[str] schema: Required[CoreSchema] kw_only: bool # default: True - init: bool # default: True init_only: bool # default: False frozen: bool # default: False - validation_alias: Union[str, list[Union[str, int]], list[list[Union[str, int]]]] + validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]] serialization_alias: str serialization_exclude: bool # default: False - metadata: dict[str, Any] + metadata: Any def dataclass_field( @@ -3200,12 +2949,11 @@ def dataclass_field( schema: CoreSchema, *, kw_only: bool | None = None, - init: bool | None = None, init_only: bool | None = None, validation_alias: str | list[str | int] | list[list[str | int]] | None = None, serialization_alias: str | None = None, serialization_exclude: bool | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, frozen: bool | None = None, ) -> DataclassField: """ @@ -3226,7 +2974,6 @@ def dataclass_field( name: The name to use for the argument parameter schema: The schema to use for the argument parameter kw_only: Whether the field can be set with a positional argument as well as a keyword argument - init: Whether the field should be validated during initialization init_only: Whether the field should be omitted from `__dict__` and passed to `__post_init__` validation_alias: The alias(es) to use to find the field in the validation data serialization_alias: The alias to use as a key when serializing @@ -3239,7 +2986,6 @@ def dataclass_field( name=name, schema=schema, kw_only=kw_only, - init=init, init_only=init_only, validation_alias=validation_alias, serialization_alias=serialization_alias, @@ -3252,11 +2998,12 @@ def dataclass_field( class DataclassArgsSchema(TypedDict, total=False): type: Required[Literal['dataclass-args']] dataclass_name: Required[str] - fields: Required[list[DataclassField]] - computed_fields: list[ComputedField] + fields: Required[List[DataclassField]] + computed_fields: List[ComputedField] + populate_by_name: bool # default: False collect_init_only: bool # default: False ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema extra_behavior: ExtraBehavior @@ -3265,10 +3012,11 @@ def dataclass_args_schema( dataclass_name: str, fields: list[DataclassField], *, - computed_fields: list[ComputedField] | None = None, + computed_fields: List[ComputedField] | None = None, + populate_by_name: bool | None = None, collect_init_only: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, extra_behavior: ExtraBehavior | None = None, ) -> DataclassArgsSchema: @@ -3293,6 +3041,7 @@ def dataclass_args_schema( dataclass_name: The name of the dataclass being validated fields: The fields to use for the dataclass computed_fields: Computed fields to use when serializing the dataclass + populate_by_name: Whether to populate by name collect_init_only: Whether to collect init only fields into a dict to pass to `__post_init__` ref: optional unique identifier of 
the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core @@ -3304,6 +3053,7 @@ def dataclass_args_schema( dataclass_name=dataclass_name, fields=fields, computed_fields=computed_fields, + populate_by_name=populate_by_name, collect_init_only=collect_init_only, ref=ref, metadata=metadata, @@ -3314,34 +3064,32 @@ def dataclass_args_schema( class DataclassSchema(TypedDict, total=False): type: Required[Literal['dataclass']] - cls: Required[type[Any]] - generic_origin: type[Any] + cls: Required[Type[Any]] schema: Required[CoreSchema] - fields: Required[list[str]] + fields: Required[List[str]] cls_name: str post_init: bool # default: False revalidate_instances: Literal['always', 'never', 'subclass-instances'] # default: 'never' strict: bool # default: False frozen: bool # default False ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema slots: bool config: CoreConfig def dataclass_schema( - cls: type[Any], + cls: Type[Any], schema: CoreSchema, - fields: list[str], + fields: List[str], *, - generic_origin: type[Any] | None = None, cls_name: str | None = None, post_init: bool | None = None, revalidate_instances: Literal['always', 'never', 'subclass-instances'] | None = None, strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, frozen: bool | None = None, slots: bool | None = None, @@ -3356,8 +3104,6 @@ def dataclass_schema( schema: The schema to use for the dataclass fields fields: Fields of the dataclass, this is used in serialization and in validation during re-validation and while validating assignment - generic_origin: The origin type used for this dataclass, if it's a parametrized generic. 
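For the dataclass pair above (`dataclass_args_schema` wrapped by `dataclass_schema`), a minimal sketch along the lines of the docstring example; `Foo` and its fields are illustrative:

```py
import dataclasses

from pydantic_core import SchemaValidator, core_schema


@dataclasses.dataclass
class Foo:
    a: str
    b: bool


args = core_schema.dataclass_args_schema(
    'Foo',
    [
        core_schema.dataclass_field(name='a', schema=core_schema.str_schema()),
        core_schema.dataclass_field(name='b', schema=core_schema.bool_schema()),
    ],
)
# the third argument lists the field names used during serialization and re-validation
v = SchemaValidator(core_schema.dataclass_schema(Foo, args, ['a', 'b']))
foo = v.validate_python({'a': 'hello', 'b': True})
assert dataclasses.asdict(foo) == {'a': 'hello', 'b': True}
```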
Ex, - if this model schema represents `SomeDataclass[int]`, generic_origin is `SomeDataclass` cls_name: The name to use in error locs, etc; this is useful for generics (default: `cls.__name__`) post_init: Whether to call `__post_init__` after validation revalidate_instances: whether instances of models and dataclasses (including subclass instances) @@ -3373,7 +3119,6 @@ def dataclass_schema( return _dict_not_none( type='dataclass', cls=cls, - generic_origin=generic_origin, fields=fields, cls_name=cls_name, schema=schema, @@ -3393,7 +3138,7 @@ class ArgumentsParameter(TypedDict, total=False): name: Required[str] schema: Required[CoreSchema] mode: Literal['positional_only', 'positional_or_keyword', 'keyword_only'] # default positional_or_keyword - alias: Union[str, list[Union[str, int]], list[list[Union[str, int]]]] + alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]] def arguments_parameter( @@ -3426,32 +3171,25 @@ def arguments_parameter( return _dict_not_none(name=name, schema=schema, mode=mode, alias=alias) -VarKwargsMode: TypeAlias = Literal['uniform', 'unpacked-typed-dict'] - - class ArgumentsSchema(TypedDict, total=False): type: Required[Literal['arguments']] - arguments_schema: Required[list[ArgumentsParameter]] - validate_by_name: bool - validate_by_alias: bool + arguments_schema: Required[List[ArgumentsParameter]] + populate_by_name: bool var_args_schema: CoreSchema - var_kwargs_mode: VarKwargsMode var_kwargs_schema: CoreSchema ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema def arguments_schema( arguments: list[ArgumentsParameter], *, - validate_by_name: bool | None = None, - validate_by_alias: bool | None = None, + populate_by_name: bool | None = None, var_args_schema: CoreSchema | None = None, - var_kwargs_mode: VarKwargsMode | None = None, var_kwargs_schema: CoreSchema | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> ArgumentsSchema: """ @@ -3473,12 +3211,8 @@ def arguments_schema( Args: arguments: The arguments to use for the arguments schema - validate_by_name: Whether to populate by the parameter names, defaults to `False`. - validate_by_alias: Whether to populate by the parameter aliases, defaults to `True`. + populate_by_name: Whether to populate by name var_args_schema: The variable args schema to use for the arguments schema - var_kwargs_mode: The validation mode to use for variadic keyword arguments. If `'uniform'`, every value of the - keyword arguments will be validated against the `var_kwargs_schema` schema. 
If `'unpacked-typed-dict'`, - the `var_kwargs_schema` argument must be a [`typed_dict_schema`][pydantic_core.core_schema.typed_dict_schema] var_kwargs_schema: The variable kwargs schema to use for the arguments schema ref: optional unique identifier of the schema, used to reference the schema in other places metadata: Any other information you want to include with the schema, not used by pydantic-core @@ -3487,10 +3221,8 @@ def arguments_schema( return _dict_not_none( type='arguments', arguments_schema=arguments, - validate_by_name=validate_by_name, - validate_by_alias=validate_by_alias, + populate_by_name=populate_by_name, var_args_schema=var_args_schema, - var_kwargs_mode=var_kwargs_mode, var_kwargs_schema=var_kwargs_schema, ref=ref, metadata=metadata, @@ -3498,120 +3230,6 @@ def arguments_schema( ) -class ArgumentsV3Parameter(TypedDict, total=False): - name: Required[str] - schema: Required[CoreSchema] - mode: Literal[ - 'positional_only', - 'positional_or_keyword', - 'keyword_only', - 'var_args', - 'var_kwargs_uniform', - 'var_kwargs_unpacked_typed_dict', - ] # default positional_or_keyword - alias: Union[str, list[Union[str, int]], list[list[Union[str, int]]]] - - -def arguments_v3_parameter( - name: str, - schema: CoreSchema, - *, - mode: Literal[ - 'positional_only', - 'positional_or_keyword', - 'keyword_only', - 'var_args', - 'var_kwargs_uniform', - 'var_kwargs_unpacked_typed_dict', - ] - | None = None, - alias: str | list[str | int] | list[list[str | int]] | None = None, -) -> ArgumentsV3Parameter: - """ - Returns a schema that matches an argument parameter, e.g.: - - ```py - from pydantic_core import SchemaValidator, core_schema - - param = core_schema.arguments_v3_parameter( - name='a', schema=core_schema.str_schema(), mode='positional_only' - ) - schema = core_schema.arguments_v3_schema([param]) - v = SchemaValidator(schema) - assert v.validate_python({'a': 'hello'}) == (('hello',), {}) - ``` - - Args: - name: The name to use for the argument parameter - schema: The schema to use for the argument parameter - mode: The mode to use for the argument parameter - alias: The alias to use for the argument parameter - """ - return _dict_not_none(name=name, schema=schema, mode=mode, alias=alias) - - -class ArgumentsV3Schema(TypedDict, total=False): - type: Required[Literal['arguments-v3']] - arguments_schema: Required[list[ArgumentsV3Parameter]] - validate_by_name: bool - validate_by_alias: bool - extra_behavior: Literal['forbid', 'ignore'] # 'allow' doesn't make sense here. 
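The `arguments-v3` variant being removed here exists only on the newer side of this diff, while the plain `arguments` schema documented above is available on both sides. A small sketch of the latter, mirroring the docstring example (parameter names are illustrative):

```py
from pydantic_core import SchemaValidator, core_schema

param_a = core_schema.arguments_parameter(
    name='a', schema=core_schema.str_schema(), mode='positional_only'
)
param_b = core_schema.arguments_parameter(
    name='b', schema=core_schema.bool_schema(), mode='keyword_only'
)
v = SchemaValidator(core_schema.arguments_schema([param_a, param_b]))
# positional-only args come back in the tuple, keyword-only args in the dict
assert v.validate_python({'a': 'hi', 'b': True}) == (('hi',), {'b': True})
```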
- ref: str - metadata: dict[str, Any] - serialization: SerSchema - - -def arguments_v3_schema( - arguments: list[ArgumentsV3Parameter], - *, - validate_by_name: bool | None = None, - validate_by_alias: bool | None = None, - extra_behavior: Literal['forbid', 'ignore'] | None = None, - ref: str | None = None, - metadata: dict[str, Any] | None = None, - serialization: SerSchema | None = None, -) -> ArgumentsV3Schema: - """ - Returns a schema that matches an arguments schema, e.g.: - - ```py - from pydantic_core import SchemaValidator, core_schema - - param_a = core_schema.arguments_v3_parameter( - name='a', schema=core_schema.str_schema(), mode='positional_only' - ) - param_b = core_schema.arguments_v3_parameter( - name='kwargs', schema=core_schema.bool_schema(), mode='var_kwargs_uniform' - ) - schema = core_schema.arguments_v3_schema([param_a, param_b]) - v = SchemaValidator(schema) - assert v.validate_python({'a': 'hi', 'kwargs': {'b': True}}) == (('hi',), {'b': True}) - ``` - - This schema is currently not used by other Pydantic components. In V3, it will most likely - become the default arguments schema for the `'call'` schema. - - Args: - arguments: The arguments to use for the arguments schema. - validate_by_name: Whether to populate by the parameter names, defaults to `False`. - validate_by_alias: Whether to populate by the parameter aliases, defaults to `True`. - extra_behavior: The extra behavior to use. - ref: optional unique identifier of the schema, used to reference the schema in other places. - metadata: Any other information you want to include with the schema, not used by pydantic-core. - serialization: Custom serialization schema. - """ - return _dict_not_none( - type='arguments-v3', - arguments_schema=arguments, - validate_by_name=validate_by_name, - validate_by_alias=validate_by_alias, - extra_behavior=extra_behavior, - ref=ref, - metadata=metadata, - serialization=serialization, - ) - - class CallSchema(TypedDict, total=False): type: Required[Literal['call']] arguments_schema: Required[CoreSchema] @@ -3619,7 +3237,7 @@ class CallSchema(TypedDict, total=False): function_name: str # default function.__name__ return_schema: CoreSchema ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -3630,7 +3248,7 @@ def call_schema( function_name: str | None = None, return_schema: CoreSchema | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> CallSchema: """ @@ -3682,9 +3300,9 @@ class CustomErrorSchema(TypedDict, total=False): schema: Required[CoreSchema] custom_error_type: Required[str] custom_error_message: str - custom_error_context: dict[str, Union[str, int, float]] + custom_error_context: Dict[str, Union[str, int, float]] ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -3695,7 +3313,7 @@ def custom_error_schema( custom_error_message: str | None = None, custom_error_context: dict[str, Any] | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> CustomErrorSchema: """ @@ -3738,7 +3356,7 @@ class JsonSchema(TypedDict, total=False): type: Required[Literal['json']] schema: CoreSchema ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -3746,7 +3364,7 @@ def json_schema( schema: CoreSchema | None = None, *, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: 
SerSchema | None = None, ) -> JsonSchema: """ @@ -3791,14 +3409,14 @@ def json_schema( class UrlSchema(TypedDict, total=False): type: Required[Literal['url']] max_length: int - allowed_schemes: list[str] + allowed_schemes: List[str] host_required: bool # default False default_host: str default_port: int default_path: str strict: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -3812,7 +3430,7 @@ def url_schema( default_path: str | None = None, strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> UrlSchema: """ @@ -3857,14 +3475,14 @@ def url_schema( class MultiHostUrlSchema(TypedDict, total=False): type: Required[Literal['multi-host-url']] max_length: int - allowed_schemes: list[str] + allowed_schemes: List[str] host_required: bool # default False default_host: str default_port: int default_path: str strict: bool ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema @@ -3878,7 +3496,7 @@ def multi_host_url_schema( default_path: str | None = None, strict: bool | None = None, ref: str | None = None, - metadata: dict[str, Any] | None = None, + metadata: Any = None, serialization: SerSchema | None = None, ) -> MultiHostUrlSchema: """ @@ -3923,8 +3541,8 @@ def multi_host_url_schema( class DefinitionsSchema(TypedDict, total=False): type: Required[Literal['definitions']] schema: Required[CoreSchema] - definitions: Required[list[CoreSchema]] - metadata: dict[str, Any] + definitions: Required[List[CoreSchema]] + metadata: Any serialization: SerSchema @@ -3954,16 +3572,12 @@ def definitions_schema(schema: CoreSchema, definitions: list[CoreSchema]) -> Def class DefinitionReferenceSchema(TypedDict, total=False): type: Required[Literal['definition-ref']] schema_ref: Required[str] - ref: str - metadata: dict[str, Any] + metadata: Any serialization: SerSchema def definition_reference_schema( - schema_ref: str, - ref: str | None = None, - metadata: dict[str, Any] | None = None, - serialization: SerSchema | None = None, + schema_ref: str, metadata: Any = None, serialization: SerSchema | None = None ) -> DefinitionReferenceSchema: """ Returns a schema that points to a schema stored in "definitions", this is useful for nested recursive @@ -3988,9 +3602,7 @@ def definition_reference_schema( metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema """ - return _dict_not_none( - type='definition-ref', schema_ref=schema_ref, ref=ref, metadata=metadata, serialization=serialization - ) + return _dict_not_none(type='definition-ref', schema_ref=schema_ref, metadata=metadata, serialization=serialization) MYPY = False @@ -3998,7 +3610,6 @@ MYPY = False # union which kills performance not just for pydantic, but even for code using pydantic if not MYPY: CoreSchema = Union[ - InvalidSchema, AnySchema, NoneSchema, BoolSchema, @@ -4012,12 +3623,12 @@ if not MYPY: DatetimeSchema, TimedeltaSchema, LiteralSchema, - EnumSchema, IsInstanceSchema, IsSubclassSchema, CallableSchema, ListSchema, - TupleSchema, + TuplePositionalSchema, + TupleVariableSchema, SetSchema, FrozenSetSchema, GeneratorSchema, @@ -4039,7 +3650,6 @@ if not MYPY: DataclassArgsSchema, DataclassSchema, ArgumentsSchema, - ArgumentsV3Schema, CallSchema, CustomErrorSchema, JsonSchema, @@ -4048,7 +3658,6 @@ if not MYPY: DefinitionsSchema, DefinitionReferenceSchema, UuidSchema, - ComplexSchema, ] elif False: CoreSchema: 
TypeAlias = Mapping[str, Any] @@ -4056,7 +3665,6 @@ elif False: # to update this, call `pytest -k test_core_schema_type_literal` and copy the output CoreSchemaType = Literal[ - 'invalid', 'any', 'none', 'bool', @@ -4070,12 +3678,12 @@ CoreSchemaType = Literal[ 'datetime', 'timedelta', 'literal', - 'enum', 'is-instance', 'is-subclass', 'callable', 'list', - 'tuple', + 'tuple-positional', + 'tuple-variable', 'set', 'frozenset', 'generator', @@ -4097,7 +3705,6 @@ CoreSchemaType = Literal[ 'dataclass-args', 'dataclass', 'arguments', - 'arguments-v3', 'call', 'custom-error', 'json', @@ -4106,7 +3713,6 @@ CoreSchemaType = Literal[ 'definitions', 'definition-ref', 'uuid', - 'complex', ] CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field'] @@ -4118,7 +3724,6 @@ ErrorType = Literal[ 'no_such_attribute', 'json_invalid', 'json_type', - 'needs_python_object', 'recursion_loop', 'missing', 'frozen_field', @@ -4153,7 +3758,6 @@ ErrorType = Literal[ 'list_type', 'tuple_type', 'set_type', - 'set_item_not_hashable', 'bool_type', 'bool_parsing', 'int_type', @@ -4165,7 +3769,6 @@ ErrorType = Literal[ 'bytes_type', 'bytes_too_short', 'bytes_too_long', - 'bytes_invalid_encoding', 'value_error', 'assertion_error', 'literal_error', @@ -4180,7 +3783,6 @@ ErrorType = Literal[ 'datetime_type', 'datetime_parsing', 'datetime_object_invalid', - 'datetime_from_date_parsing', 'datetime_past', 'datetime_future', 'timezone_naive', @@ -4214,8 +3816,6 @@ ErrorType = Literal[ 'decimal_max_digits', 'decimal_max_places', 'decimal_whole_digits', - 'complex_type', - 'complex_str_parsing', ] @@ -4258,7 +3858,7 @@ def field_after_validator_function(function: WithInfoValidatorFunction, field_na @deprecated('`general_after_validator_function` is deprecated, use `with_info_after_validator_function` instead.') def general_after_validator_function(*args, **kwargs): warnings.warn( - '`general_after_validator_function` is deprecated, use `with_info_after_validator_function` instead.', + '`with_info_after_validator_function` is deprecated, use `with_info_after_validator_function` instead.', DeprecationWarning, ) return with_info_after_validator_function(*args, **kwargs) @@ -4309,9 +3909,6 @@ _deprecated_import_lookup = { 'FieldWrapValidatorFunction': WithInfoWrapValidatorFunction, } -if TYPE_CHECKING: - FieldValidationInfo = ValidationInfo - def __getattr__(attr_name: str) -> object: new_attr = _deprecated_import_lookup.get(attr_name) diff --git a/venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/pydantic_settings-2.0.3.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/pydantic_settings-2.0.3.dist-info/INSTALLER diff --git a/venv/lib/python3.12/site-packages/pydantic_settings-2.11.0.dist-info/METADATA b/venv/lib/python3.12/site-packages/pydantic_settings-2.0.3.dist-info/METADATA similarity index 56% rename from venv/lib/python3.12/site-packages/pydantic_settings-2.11.0.dist-info/METADATA rename to venv/lib/python3.12/site-packages/pydantic_settings-2.0.3.dist-info/METADATA index b328340..53a4683 100644 --- a/venv/lib/python3.12/site-packages/pydantic_settings-2.11.0.dist-info/METADATA +++ b/venv/lib/python3.12/site-packages/pydantic_settings-2.0.3.dist-info/METADATA @@ -1,12 +1,12 @@ -Metadata-Version: 2.4 +Metadata-Version: 2.1 Name: pydantic-settings -Version: 2.11.0 +Version: 2.0.3 Summary: Settings management 
using Pydantic Project-URL: Homepage, https://github.com/pydantic/pydantic-settings Project-URL: Funding, https://github.com/sponsors/samuelcolvin Project-URL: Source, https://github.com/pydantic/pydantic-settings Project-URL: Changelog, https://github.com/pydantic/pydantic-settings/releases -Project-URL: Documentation, https://docs.pydantic.dev/dev-v2/concepts/pydantic_settings/ +Project-URL: Documentation, https://docs.pydantic.dev/dev-v2/usage/pydantic_settings/ Author-email: Samuel Colvin , Eric Jolibois , Hasan Ramezani License-Expression: MIT License-File: LICENSE @@ -24,40 +24,30 @@ Classifier: Operating System :: Unix Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 3 Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Classifier: Programming Language :: Python :: 3.11 -Classifier: Programming Language :: Python :: 3.12 -Classifier: Programming Language :: Python :: 3.13 Classifier: Topic :: Internet Classifier: Topic :: Software Development :: Libraries :: Python Modules -Requires-Python: >=3.9 -Requires-Dist: pydantic>=2.7.0 +Requires-Python: >=3.7 +Requires-Dist: pydantic>=2.0.1 Requires-Dist: python-dotenv>=0.21.0 -Requires-Dist: typing-inspection>=0.4.0 -Provides-Extra: aws-secrets-manager -Requires-Dist: boto3-stubs[secretsmanager]; extra == 'aws-secrets-manager' -Requires-Dist: boto3>=1.35.0; extra == 'aws-secrets-manager' -Provides-Extra: azure-key-vault -Requires-Dist: azure-identity>=1.16.0; extra == 'azure-key-vault' -Requires-Dist: azure-keyvault-secrets>=4.8.0; extra == 'azure-key-vault' -Provides-Extra: gcp-secret-manager -Requires-Dist: google-cloud-secret-manager>=2.23.1; extra == 'gcp-secret-manager' -Provides-Extra: toml -Requires-Dist: tomli>=2.0.1; extra == 'toml' -Provides-Extra: yaml -Requires-Dist: pyyaml>=6.0.1; extra == 'yaml' Description-Content-Type: text/markdown # pydantic-settings -[![CI](https://github.com/pydantic/pydantic-settings/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/pydantic/pydantic-settings/actions/workflows/ci.yml?query=branch%3Amain) +[![CI](https://github.com/pydantic/pydantic-settings/workflows/CI/badge.svg?event=push)](https://github.com/pydantic/pydantic-settings/actions?query=event%3Apush+branch%3Amain+workflow%3ACI) [![Coverage](https://codecov.io/gh/pydantic/pydantic-settings/branch/main/graph/badge.svg)](https://codecov.io/gh/pydantic/pydantic-settings) [![pypi](https://img.shields.io/pypi/v/pydantic-settings.svg)](https://pypi.python.org/pypi/pydantic-settings) [![license](https://img.shields.io/github/license/pydantic/pydantic-settings.svg)](https://github.com/pydantic/pydantic-settings/blob/main/LICENSE) -[![downloads](https://static.pepy.tech/badge/pydantic-settings/month)](https://pepy.tech/project/pydantic-settings) -[![versions](https://img.shields.io/pypi/pyversions/pydantic-settings.svg)](https://github.com/pydantic/pydantic-settings) -Settings management using Pydantic. +Settings management using Pydantic, this is the new official home of Pydantic's `BaseSettings`. -See [documentation](https://docs.pydantic.dev/latest/concepts/pydantic_settings/) for more details. 
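The METADATA hunk above loosens the requirements to `pydantic>=2.0.1` on Python 3.7+ while pinning `pydantic-settings` back to 2.0.3. If you want to assert those expectations at runtime in the environment that ships this venv, an illustrative check could look like this:

```py
# Illustrative only: confirm the vendored versions match what this hunk pins.
from importlib.metadata import version

assert version('pydantic-settings') == '2.0.3'
assert tuple(int(part) for part in version('pydantic').split('.')[:2]) >= (2, 0)
```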
+This package was kindly donated to the [Pydantic organisation](https://github.com/pydantic) by Daniel Daniels, see [pydantic/pydantic#4492](https://github.com/pydantic/pydantic/pull/4492) for discussion. + +For the old "Hipster-orgazmic tool to mange application settings" package, see [version 0.2.5](https://pypi.org/project/pydantic-settings/0.2.5/). + + +See [documentation](https://docs.pydantic.dev/latest/usage/pydantic_settings/) for more details. diff --git a/venv/lib/python3.12/site-packages/pydantic_settings-2.0.3.dist-info/RECORD b/venv/lib/python3.12/site-packages/pydantic_settings-2.0.3.dist-info/RECORD new file mode 100644 index 0000000..0015193 --- /dev/null +++ b/venv/lib/python3.12/site-packages/pydantic_settings-2.0.3.dist-info/RECORD @@ -0,0 +1,17 @@ +pydantic_settings-2.0.3.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +pydantic_settings-2.0.3.dist-info/METADATA,sha256=iuDM6bM6VDeLKrOyfSRQiE4Bp_SqFNmDvNYxjNlojEU,2924 +pydantic_settings-2.0.3.dist-info/RECORD,, +pydantic_settings-2.0.3.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +pydantic_settings-2.0.3.dist-info/WHEEL,sha256=9QBuHhg6FNW7lppboF2vKVbCGTVzsFykgRQjjlajrhA,87 +pydantic_settings-2.0.3.dist-info/licenses/LICENSE,sha256=6zVadT4CA0bTPYO_l2kTW4n8YQVorFMaAcKVvO5_2Zg,1103 +pydantic_settings/__init__.py,sha256=h0HRyW_I6s0YYFIB-qx8gNZOtDI8vCbXnwPbp4BqwzE,482 +pydantic_settings/__pycache__/__init__.cpython-312.pyc,, +pydantic_settings/__pycache__/main.cpython-312.pyc,, +pydantic_settings/__pycache__/sources.cpython-312.pyc,, +pydantic_settings/__pycache__/utils.cpython-312.pyc,, +pydantic_settings/__pycache__/version.cpython-312.pyc,, +pydantic_settings/main.py,sha256=DPJPyjM9g7CgaB8-zuoydot1iYVuLOb05rJZUXDt1-o,7178 +pydantic_settings/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +pydantic_settings/sources.py,sha256=ruCzD_1mL9e20o-33B7n46cTE5COCJ0524w29uED5BM,24857 +pydantic_settings/utils.py,sha256=nomYSaFO_IegfWSL9KJ8SAtLZgyhcruLgE3dTHwSmgo,557 +pydantic_settings/version.py,sha256=gemzbOzXm8MxToVh3wokBkbvZFRFfCkFQumP9kJFca4,18 diff --git a/venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/REQUESTED b/venv/lib/python3.12/site-packages/pydantic_settings-2.0.3.dist-info/REQUESTED similarity index 100% rename from venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/REQUESTED rename to venv/lib/python3.12/site-packages/pydantic_settings-2.0.3.dist-info/REQUESTED diff --git a/venv/lib/python3.12/site-packages/pydantic_settings-2.11.0.dist-info/WHEEL b/venv/lib/python3.12/site-packages/pydantic_settings-2.0.3.dist-info/WHEEL similarity index 67% rename from venv/lib/python3.12/site-packages/pydantic_settings-2.11.0.dist-info/WHEEL rename to venv/lib/python3.12/site-packages/pydantic_settings-2.0.3.dist-info/WHEEL index 12228d4..ba1a8af 100644 --- a/venv/lib/python3.12/site-packages/pydantic_settings-2.11.0.dist-info/WHEEL +++ b/venv/lib/python3.12/site-packages/pydantic_settings-2.0.3.dist-info/WHEEL @@ -1,4 +1,4 @@ Wheel-Version: 1.0 -Generator: hatchling 1.27.0 +Generator: hatchling 1.18.0 Root-Is-Purelib: true Tag: py3-none-any diff --git a/venv/lib/python3.12/site-packages/pydantic_settings-2.11.0.dist-info/licenses/LICENSE b/venv/lib/python3.12/site-packages/pydantic_settings-2.0.3.dist-info/licenses/LICENSE similarity index 100% rename from venv/lib/python3.12/site-packages/pydantic_settings-2.11.0.dist-info/licenses/LICENSE rename to venv/lib/python3.12/site-packages/pydantic_settings-2.0.3.dist-info/licenses/LICENSE diff 
--git a/venv/lib/python3.12/site-packages/pydantic_settings-2.11.0.dist-info/RECORD b/venv/lib/python3.12/site-packages/pydantic_settings-2.11.0.dist-info/RECORD deleted file mode 100644 index af36bb9..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings-2.11.0.dist-info/RECORD +++ /dev/null @@ -1,49 +0,0 @@ -pydantic_settings-2.11.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -pydantic_settings-2.11.0.dist-info/METADATA,sha256=PDGByqQ8O-pOIP1ulD_GA9MtU82OhPvAx_XwBZo2z8M,3393 -pydantic_settings-2.11.0.dist-info/RECORD,, -pydantic_settings-2.11.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -pydantic_settings-2.11.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87 -pydantic_settings-2.11.0.dist-info/licenses/LICENSE,sha256=6zVadT4CA0bTPYO_l2kTW4n8YQVorFMaAcKVvO5_2Zg,1103 -pydantic_settings/__init__.py,sha256=IUkO5TkUu6eYgRJhA1piTw4jp6-CBhV7kam0rEh1Flo,1563 -pydantic_settings/__pycache__/__init__.cpython-312.pyc,, -pydantic_settings/__pycache__/exceptions.cpython-312.pyc,, -pydantic_settings/__pycache__/main.cpython-312.pyc,, -pydantic_settings/__pycache__/utils.cpython-312.pyc,, -pydantic_settings/__pycache__/version.cpython-312.pyc,, -pydantic_settings/exceptions.py,sha256=SHLrIBHeFltPMc8abiQxw-MGqEadlYI-VdLELiZtWPU,97 -pydantic_settings/main.py,sha256=KR_ut942bw5hQLqA1aAGE7niHLHUr-Ca-gOK6ZNmL1k,32156 -pydantic_settings/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -pydantic_settings/sources/__init__.py,sha256=Ti1bRZb0r7IxkO-wJWKy-qEpeBUFKYRpa3A1AQodOyk,2052 -pydantic_settings/sources/__pycache__/__init__.cpython-312.pyc,, -pydantic_settings/sources/__pycache__/base.cpython-312.pyc,, -pydantic_settings/sources/__pycache__/types.cpython-312.pyc,, -pydantic_settings/sources/__pycache__/utils.cpython-312.pyc,, -pydantic_settings/sources/base.py,sha256=N8DOFzKuNPdzVFt22gcSHqs_GHUqCc8AfTRZuWVfl84,20921 -pydantic_settings/sources/providers/__init__.py,sha256=jBTurqBXeJvMfTl2lvHr2iDVDOvHfO-8PVNJiKt7MBk,1205 -pydantic_settings/sources/providers/__pycache__/__init__.cpython-312.pyc,, -pydantic_settings/sources/providers/__pycache__/aws.cpython-312.pyc,, -pydantic_settings/sources/providers/__pycache__/azure.cpython-312.pyc,, -pydantic_settings/sources/providers/__pycache__/cli.cpython-312.pyc,, -pydantic_settings/sources/providers/__pycache__/dotenv.cpython-312.pyc,, -pydantic_settings/sources/providers/__pycache__/env.cpython-312.pyc,, -pydantic_settings/sources/providers/__pycache__/gcp.cpython-312.pyc,, -pydantic_settings/sources/providers/__pycache__/json.cpython-312.pyc,, -pydantic_settings/sources/providers/__pycache__/pyproject.cpython-312.pyc,, -pydantic_settings/sources/providers/__pycache__/secrets.cpython-312.pyc,, -pydantic_settings/sources/providers/__pycache__/toml.cpython-312.pyc,, -pydantic_settings/sources/providers/__pycache__/yaml.cpython-312.pyc,, -pydantic_settings/sources/providers/aws.py,sha256=y-GXXP-dQ9kewMWpPQ9sHYFZ2KfrO0vMNYqVtiF1ysg,2549 -pydantic_settings/sources/providers/azure.py,sha256=X_u5hYjysUTM7B0iPlEdY0nRfEw7AOPU_ALjixNS57Q,5004 -pydantic_settings/sources/providers/cli.py,sha256=c4hb980ZLMlvqgQaEpkNaMBQTRT153jNpSrsMu6cQb4,61842 -pydantic_settings/sources/providers/dotenv.py,sha256=X4fkql4sEyaEaK9WV1xUpxRAiJhMFvgj4DMODdUV_bA,5956 -pydantic_settings/sources/providers/env.py,sha256=E2q9YHjFrFUWAid2VpY3678PDSuIDQc_47iWcz_ojQ4,10717 -pydantic_settings/sources/providers/gcp.py,sha256=3bFh75aZp6mmn12VihQycND-5CLgnYWg6HBfNvIV26U,5644 
-pydantic_settings/sources/providers/json.py,sha256=k0hWDu0fNLrI5z3zWTGtlKyR0xx-2pOPu-oWjwqmVXo,1436 -pydantic_settings/sources/providers/pyproject.py,sha256=zSQsV3-jtZhiLm3YlrlYoE2__tZBazp0KjQyKLNyLr0,2052 -pydantic_settings/sources/providers/secrets.py,sha256=JLMIj3VVwp86foGTP8fb6zWddmYpELBu95Ldzobnsw8,4303 -pydantic_settings/sources/providers/toml.py,sha256=5k9wMJbKrUqXNiCM5G1hYnCOEZNUJJBTAzFw6Pv2K6A,1827 -pydantic_settings/sources/providers/yaml.py,sha256=mhjmOkrwLT16AEGNDuYoex2PYHejusn7Y0J4KL6SVbw,2305 -pydantic_settings/sources/types.py,sha256=8TT7eJvOam2-B2M2TYS-z4XTIyckBmbluw96ayVnWHc,1513 -pydantic_settings/sources/utils.py,sha256=0fQ2yDBzxqrmvwHLeSu5ASfdHkbRviCFsqQADfSqk40,7601 -pydantic_settings/utils.py,sha256=SkOfKGo0omDB4REfg31XSO8yVmpzCQgeIcdg-qqcSrk,1382 -pydantic_settings/version.py,sha256=pneluWHKumnFZTuxtZZ1nYFMF9LFmg5igdk6rGxXKhU,19 diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/__init__.py b/venv/lib/python3.12/site-packages/pydantic_settings/__init__.py index 60990a8..7b99f88 100644 --- a/venv/lib/python3.12/site-packages/pydantic_settings/__init__.py +++ b/venv/lib/python3.12/site-packages/pydantic_settings/__init__.py @@ -1,63 +1,22 @@ -from .exceptions import SettingsError -from .main import BaseSettings, CliApp, SettingsConfigDict +from .main import BaseSettings, SettingsConfigDict from .sources import ( - CLI_SUPPRESS, - AWSSecretsManagerSettingsSource, - AzureKeyVaultSettingsSource, - CliExplicitFlag, - CliImplicitFlag, - CliMutuallyExclusiveGroup, - CliPositionalArg, - CliSettingsSource, - CliSubCommand, - CliSuppress, - CliUnknownArgs, DotEnvSettingsSource, EnvSettingsSource, - ForceDecode, - GoogleSecretManagerSettingsSource, InitSettingsSource, - JsonConfigSettingsSource, - NoDecode, PydanticBaseSettingsSource, - PyprojectTomlConfigSettingsSource, SecretsSettingsSource, - TomlConfigSettingsSource, - YamlConfigSettingsSource, - get_subcommand, ) from .version import VERSION __all__ = ( - 'CLI_SUPPRESS', - 'AWSSecretsManagerSettingsSource', - 'AzureKeyVaultSettingsSource', 'BaseSettings', - 'CliApp', - 'CliExplicitFlag', - 'CliImplicitFlag', - 'CliMutuallyExclusiveGroup', - 'CliPositionalArg', - 'CliSettingsSource', - 'CliSubCommand', - 'CliSuppress', - 'CliUnknownArgs', 'DotEnvSettingsSource', 'EnvSettingsSource', - 'ForceDecode', - 'GoogleSecretManagerSettingsSource', 'InitSettingsSource', - 'JsonConfigSettingsSource', - 'NoDecode', 'PydanticBaseSettingsSource', - 'PyprojectTomlConfigSettingsSource', 'SecretsSettingsSource', 'SettingsConfigDict', - 'SettingsError', - 'TomlConfigSettingsSource', - 'YamlConfigSettingsSource', '__version__', - 'get_subcommand', ) __version__ = VERSION diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/exceptions.py b/venv/lib/python3.12/site-packages/pydantic_settings/exceptions.py deleted file mode 100644 index 90806c6..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings/exceptions.py +++ /dev/null @@ -1,4 +0,0 @@ -class SettingsError(ValueError): - """Base exception for settings-related errors.""" - - pass diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/main.py b/venv/lib/python3.12/site-packages/pydantic_settings/main.py index f72e950..64e9d64 100644 --- a/venv/lib/python3.12/site-packages/pydantic_settings/main.py +++ b/venv/lib/python3.12/site-packages/pydantic_settings/main.py @@ -1,104 +1,31 @@ from __future__ import annotations as _annotations -import asyncio -import inspect -import threading -import warnings -from argparse import Namespace -from 
collections.abc import Mapping -from types import SimpleNamespace -from typing import Any, ClassVar, TypeVar +from pathlib import Path +from typing import Any, ClassVar from pydantic import ConfigDict from pydantic._internal._config import config_keys -from pydantic._internal._signature import _field_name_for_signature -from pydantic._internal._utils import deep_update, is_model_class -from pydantic.dataclasses import is_pydantic_dataclass +from pydantic._internal._utils import deep_update from pydantic.main import BaseModel -from .exceptions import SettingsError from .sources import ( ENV_FILE_SENTINEL, - CliSettingsSource, - DefaultSettingsSource, DotEnvSettingsSource, DotenvType, EnvSettingsSource, InitSettingsSource, - JsonConfigSettingsSource, - PathType, PydanticBaseSettingsSource, - PydanticModel, - PyprojectTomlConfigSettingsSource, SecretsSettingsSource, - TomlConfigSettingsSource, - YamlConfigSettingsSource, - get_subcommand, ) -T = TypeVar('T') - class SettingsConfigDict(ConfigDict, total=False): case_sensitive: bool - nested_model_default_partial_update: bool | None env_prefix: str env_file: DotenvType | None env_file_encoding: str | None - env_ignore_empty: bool env_nested_delimiter: str | None - env_nested_max_split: int | None - env_parse_none_str: str | None - env_parse_enums: bool | None - cli_prog_name: str | None - cli_parse_args: bool | list[str] | tuple[str, ...] | None - cli_parse_none_str: str | None - cli_hide_none_type: bool - cli_avoid_json: bool - cli_enforce_required: bool - cli_use_class_docs_for_groups: bool - cli_exit_on_error: bool - cli_prefix: str - cli_flag_prefix_char: str - cli_implicit_flags: bool | None - cli_ignore_unknown_args: bool | None - cli_kebab_case: bool | None - cli_shortcuts: Mapping[str, str | list[str]] | None - secrets_dir: PathType | None - json_file: PathType | None - json_file_encoding: str | None - yaml_file: PathType | None - yaml_file_encoding: str | None - yaml_config_section: str | None - """ - Specifies the top-level key in a YAML file from which to load the settings. - If provided, the settings will be loaded from the nested section under this key. - This is useful when the YAML file contains multiple configuration sections - and you only want to load a specific subset into your settings model. - """ - - pyproject_toml_depth: int - """ - Number of levels **up** from the current working directory to attempt to find a pyproject.toml - file. - - This is only used when a pyproject.toml file is not found in the current working directory. - """ - - pyproject_toml_table_header: tuple[str, ...] - """ - Header of the TOML table within a pyproject.toml file to use when filling variables. - This is supplied as a `tuple[str, ...]` instead of a `str` to accommodate for headers - containing a `.`. - - For example, `toml_table_header = ("tool", "my.tool", "foo")` can be used to fill variable - values from a table with header `[tool."my.tool".foo]`. - - To use the root table, exclude this config setting or provide an empty tuple. - """ - - toml_file: PathType | None - enable_decoding: bool + secrets_dir: str | Path | None # Extend `config_keys` by pydantic settings config keys to @@ -120,104 +47,35 @@ class BaseSettings(BaseModel): All the below attributes can be set via `model_config`. Args: - _case_sensitive: Whether environment and CLI variable names should be read with case-sensitivity. - Defaults to `None`. - _nested_model_default_partial_update: Whether to allow partial updates on nested model default object fields. 
- Defaults to `False`. + _case_sensitive: Whether environment variables names should be read with case-sensitivity. Defaults to `None`. _env_prefix: Prefix for all environment variables. Defaults to `None`. _env_file: The env file(s) to load settings values from. Defaults to `Path('')`, which means that the value from `model_config['env_file']` should be used. You can also pass `None` to indicate that environment variables should not be loaded from an env file. _env_file_encoding: The env file encoding, e.g. `'latin-1'`. Defaults to `None`. - _env_ignore_empty: Ignore environment variables where the value is an empty string. Default to `False`. _env_nested_delimiter: The nested env values delimiter. Defaults to `None`. - _env_nested_max_split: The nested env values maximum nesting. Defaults to `None`, which means no limit. - _env_parse_none_str: The env string value that should be parsed (e.g. "null", "void", "None", etc.) - into `None` type(None). Defaults to `None` type(None), which means no parsing should occur. - _env_parse_enums: Parse enum field names to values. Defaults to `None.`, which means no parsing should occur. - _cli_prog_name: The CLI program name to display in help text. Defaults to `None` if _cli_parse_args is `None`. - Otherwise, defaults to sys.argv[0]. - _cli_parse_args: The list of CLI arguments to parse. Defaults to None. - If set to `True`, defaults to sys.argv[1:]. - _cli_settings_source: Override the default CLI settings source with a user defined instance. Defaults to None. - _cli_parse_none_str: The CLI string value that should be parsed (e.g. "null", "void", "None", etc.) into - `None` type(None). Defaults to _env_parse_none_str value if set. Otherwise, defaults to "null" if - _cli_avoid_json is `False`, and "None" if _cli_avoid_json is `True`. - _cli_hide_none_type: Hide `None` values in CLI help text. Defaults to `False`. - _cli_avoid_json: Avoid complex JSON objects in CLI help text. Defaults to `False`. - _cli_enforce_required: Enforce required fields at the CLI. Defaults to `False`. - _cli_use_class_docs_for_groups: Use class docstrings in CLI group help text instead of field descriptions. - Defaults to `False`. - _cli_exit_on_error: Determines whether or not the internal parser exits with error info when an error occurs. - Defaults to `True`. - _cli_prefix: The root parser command line arguments prefix. Defaults to "". - _cli_flag_prefix_char: The flag prefix character to use for CLI optional arguments. Defaults to '-'. - _cli_implicit_flags: Whether `bool` fields should be implicitly converted into CLI boolean flags. - (e.g. --flag, --no-flag). Defaults to `False`. - _cli_ignore_unknown_args: Whether to ignore unknown CLI args and parse only known ones. Defaults to `False`. - _cli_kebab_case: CLI args use kebab case. Defaults to `False`. - _cli_shortcuts: Mapping of target field name to alias names. Defaults to `None`. - _secrets_dir: The secret files directory or a sequence of directories. Defaults to `None`. + _secrets_dir: The secret files directory. Defaults to `None`. 
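Given the reduced set of knobs documented above for 2.0.3 (`env_prefix`, `env_file`, `env_nested_delimiter`, `secrets_dir`), a minimal usage sketch; the class name, prefix, and fields are hypothetical:

```py
import os

from pydantic_settings import BaseSettings, SettingsConfigDict


class AppSettings(BaseSettings):
    # hypothetical application settings, read from APP_* environment variables
    model_config = SettingsConfigDict(env_prefix='APP_', env_file='.env')

    debug: bool = False
    database_url: str = 'sqlite:///./app.db'


os.environ['APP_DEBUG'] = 'true'
assert AppSettings().debug is True
```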
""" def __init__( __pydantic_self__, _case_sensitive: bool | None = None, - _nested_model_default_partial_update: bool | None = None, _env_prefix: str | None = None, _env_file: DotenvType | None = ENV_FILE_SENTINEL, _env_file_encoding: str | None = None, - _env_ignore_empty: bool | None = None, _env_nested_delimiter: str | None = None, - _env_nested_max_split: int | None = None, - _env_parse_none_str: str | None = None, - _env_parse_enums: bool | None = None, - _cli_prog_name: str | None = None, - _cli_parse_args: bool | list[str] | tuple[str, ...] | None = None, - _cli_settings_source: CliSettingsSource[Any] | None = None, - _cli_parse_none_str: str | None = None, - _cli_hide_none_type: bool | None = None, - _cli_avoid_json: bool | None = None, - _cli_enforce_required: bool | None = None, - _cli_use_class_docs_for_groups: bool | None = None, - _cli_exit_on_error: bool | None = None, - _cli_prefix: str | None = None, - _cli_flag_prefix_char: str | None = None, - _cli_implicit_flags: bool | None = None, - _cli_ignore_unknown_args: bool | None = None, - _cli_kebab_case: bool | None = None, - _cli_shortcuts: Mapping[str, str | list[str]] | None = None, - _secrets_dir: PathType | None = None, + _secrets_dir: str | Path | None = None, **values: Any, ) -> None: + # Uses something other than `self` the first arg to allow "self" as a settable attribute super().__init__( **__pydantic_self__._settings_build_values( values, _case_sensitive=_case_sensitive, - _nested_model_default_partial_update=_nested_model_default_partial_update, _env_prefix=_env_prefix, _env_file=_env_file, _env_file_encoding=_env_file_encoding, - _env_ignore_empty=_env_ignore_empty, _env_nested_delimiter=_env_nested_delimiter, - _env_nested_max_split=_env_nested_max_split, - _env_parse_none_str=_env_parse_none_str, - _env_parse_enums=_env_parse_enums, - _cli_prog_name=_cli_prog_name, - _cli_parse_args=_cli_parse_args, - _cli_settings_source=_cli_settings_source, - _cli_parse_none_str=_cli_parse_none_str, - _cli_hide_none_type=_cli_hide_none_type, - _cli_avoid_json=_cli_avoid_json, - _cli_enforce_required=_cli_enforce_required, - _cli_use_class_docs_for_groups=_cli_use_class_docs_for_groups, - _cli_exit_on_error=_cli_exit_on_error, - _cli_prefix=_cli_prefix, - _cli_flag_prefix_char=_cli_flag_prefix_char, - _cli_implicit_flags=_cli_implicit_flags, - _cli_ignore_unknown_args=_cli_ignore_unknown_args, - _cli_kebab_case=_cli_kebab_case, - _cli_shortcuts=_cli_shortcuts, _secrets_dir=_secrets_dir, ) ) @@ -250,125 +108,33 @@ class BaseSettings(BaseModel): self, init_kwargs: dict[str, Any], _case_sensitive: bool | None = None, - _nested_model_default_partial_update: bool | None = None, _env_prefix: str | None = None, _env_file: DotenvType | None = None, _env_file_encoding: str | None = None, - _env_ignore_empty: bool | None = None, _env_nested_delimiter: str | None = None, - _env_nested_max_split: int | None = None, - _env_parse_none_str: str | None = None, - _env_parse_enums: bool | None = None, - _cli_prog_name: str | None = None, - _cli_parse_args: bool | list[str] | tuple[str, ...] 
| None = None, - _cli_settings_source: CliSettingsSource[Any] | None = None, - _cli_parse_none_str: str | None = None, - _cli_hide_none_type: bool | None = None, - _cli_avoid_json: bool | None = None, - _cli_enforce_required: bool | None = None, - _cli_use_class_docs_for_groups: bool | None = None, - _cli_exit_on_error: bool | None = None, - _cli_prefix: str | None = None, - _cli_flag_prefix_char: str | None = None, - _cli_implicit_flags: bool | None = None, - _cli_ignore_unknown_args: bool | None = None, - _cli_kebab_case: bool | None = None, - _cli_shortcuts: Mapping[str, str | list[str]] | None = None, - _secrets_dir: PathType | None = None, + _secrets_dir: str | Path | None = None, ) -> dict[str, Any]: # Determine settings config values case_sensitive = _case_sensitive if _case_sensitive is not None else self.model_config.get('case_sensitive') env_prefix = _env_prefix if _env_prefix is not None else self.model_config.get('env_prefix') - nested_model_default_partial_update = ( - _nested_model_default_partial_update - if _nested_model_default_partial_update is not None - else self.model_config.get('nested_model_default_partial_update') - ) env_file = _env_file if _env_file != ENV_FILE_SENTINEL else self.model_config.get('env_file') env_file_encoding = ( _env_file_encoding if _env_file_encoding is not None else self.model_config.get('env_file_encoding') ) - env_ignore_empty = ( - _env_ignore_empty if _env_ignore_empty is not None else self.model_config.get('env_ignore_empty') - ) env_nested_delimiter = ( _env_nested_delimiter if _env_nested_delimiter is not None else self.model_config.get('env_nested_delimiter') ) - env_nested_max_split = ( - _env_nested_max_split - if _env_nested_max_split is not None - else self.model_config.get('env_nested_max_split') - ) - env_parse_none_str = ( - _env_parse_none_str if _env_parse_none_str is not None else self.model_config.get('env_parse_none_str') - ) - env_parse_enums = _env_parse_enums if _env_parse_enums is not None else self.model_config.get('env_parse_enums') - - cli_prog_name = _cli_prog_name if _cli_prog_name is not None else self.model_config.get('cli_prog_name') - cli_parse_args = _cli_parse_args if _cli_parse_args is not None else self.model_config.get('cli_parse_args') - cli_settings_source = ( - _cli_settings_source if _cli_settings_source is not None else self.model_config.get('cli_settings_source') - ) - cli_parse_none_str = ( - _cli_parse_none_str if _cli_parse_none_str is not None else self.model_config.get('cli_parse_none_str') - ) - cli_parse_none_str = cli_parse_none_str if not env_parse_none_str else env_parse_none_str - cli_hide_none_type = ( - _cli_hide_none_type if _cli_hide_none_type is not None else self.model_config.get('cli_hide_none_type') - ) - cli_avoid_json = _cli_avoid_json if _cli_avoid_json is not None else self.model_config.get('cli_avoid_json') - cli_enforce_required = ( - _cli_enforce_required - if _cli_enforce_required is not None - else self.model_config.get('cli_enforce_required') - ) - cli_use_class_docs_for_groups = ( - _cli_use_class_docs_for_groups - if _cli_use_class_docs_for_groups is not None - else self.model_config.get('cli_use_class_docs_for_groups') - ) - cli_exit_on_error = ( - _cli_exit_on_error if _cli_exit_on_error is not None else self.model_config.get('cli_exit_on_error') - ) - cli_prefix = _cli_prefix if _cli_prefix is not None else self.model_config.get('cli_prefix') - cli_flag_prefix_char = ( - _cli_flag_prefix_char - if _cli_flag_prefix_char is not None - else 
self.model_config.get('cli_flag_prefix_char') - ) - cli_implicit_flags = ( - _cli_implicit_flags if _cli_implicit_flags is not None else self.model_config.get('cli_implicit_flags') - ) - cli_ignore_unknown_args = ( - _cli_ignore_unknown_args - if _cli_ignore_unknown_args is not None - else self.model_config.get('cli_ignore_unknown_args') - ) - cli_kebab_case = _cli_kebab_case if _cli_kebab_case is not None else self.model_config.get('cli_kebab_case') - cli_shortcuts = _cli_shortcuts if _cli_shortcuts is not None else self.model_config.get('cli_shortcuts') - secrets_dir = _secrets_dir if _secrets_dir is not None else self.model_config.get('secrets_dir') # Configure built-in sources - default_settings = DefaultSettingsSource( - self.__class__, nested_model_default_partial_update=nested_model_default_partial_update - ) - init_settings = InitSettingsSource( - self.__class__, - init_kwargs=init_kwargs, - nested_model_default_partial_update=nested_model_default_partial_update, - ) + init_settings = InitSettingsSource(self.__class__, init_kwargs=init_kwargs) env_settings = EnvSettingsSource( self.__class__, case_sensitive=case_sensitive, env_prefix=env_prefix, env_nested_delimiter=env_nested_delimiter, - env_nested_max_split=env_nested_max_split, - env_ignore_empty=env_ignore_empty, - env_parse_none_str=env_parse_none_str, - env_parse_enums=env_parse_enums, ) dotenv_settings = DotEnvSettingsSource( self.__class__, @@ -377,10 +143,6 @@ class BaseSettings(BaseModel): case_sensitive=case_sensitive, env_prefix=env_prefix, env_nested_delimiter=env_nested_delimiter, - env_nested_max_split=env_nested_max_split, - env_ignore_empty=env_ignore_empty, - env_parse_none_str=env_parse_none_str, - env_parse_enums=env_parse_enums, ) file_secret_settings = SecretsSettingsSource( @@ -393,294 +155,23 @@ class BaseSettings(BaseModel): env_settings=env_settings, dotenv_settings=dotenv_settings, file_secret_settings=file_secret_settings, - ) + (default_settings,) - custom_cli_sources = [source for source in sources if isinstance(source, CliSettingsSource)] - if not any(custom_cli_sources): - if isinstance(cli_settings_source, CliSettingsSource): - sources = (cli_settings_source,) + sources - elif cli_parse_args is not None: - cli_settings = CliSettingsSource[Any]( - self.__class__, - cli_prog_name=cli_prog_name, - cli_parse_args=cli_parse_args, - cli_parse_none_str=cli_parse_none_str, - cli_hide_none_type=cli_hide_none_type, - cli_avoid_json=cli_avoid_json, - cli_enforce_required=cli_enforce_required, - cli_use_class_docs_for_groups=cli_use_class_docs_for_groups, - cli_exit_on_error=cli_exit_on_error, - cli_prefix=cli_prefix, - cli_flag_prefix_char=cli_flag_prefix_char, - cli_implicit_flags=cli_implicit_flags, - cli_ignore_unknown_args=cli_ignore_unknown_args, - cli_kebab_case=cli_kebab_case, - cli_shortcuts=cli_shortcuts, - case_sensitive=case_sensitive, - ) - sources = (cli_settings,) + sources - # We ensure that if command line arguments haven't been parsed yet, we do so. 
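Both sides of this hunk end up merging the source tuple with `deep_update`, with earlier sources taking precedence. If the downgraded code path changes behaviour for you, precedence can still be adjusted through `settings_customise_sources`, which exists in both versions; a sketch with an illustrative field name:

```py
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource


class Settings(BaseSettings):
    api_key: str = 'unset'

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        # earlier sources win; here environment variables override init kwargs
        return env_settings, init_settings, dotenv_settings, file_secret_settings
```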
- elif cli_parse_args not in (None, False) and not custom_cli_sources[0].env_vars: - custom_cli_sources[0](args=cli_parse_args) # type: ignore - - self._settings_warn_unused_config_keys(sources, self.model_config) - + ) if sources: - state: dict[str, Any] = {} - states: dict[str, dict[str, Any]] = {} - for source in sources: - if isinstance(source, PydanticBaseSettingsSource): - source._set_current_state(state) - source._set_settings_sources_data(states) - - source_name = source.__name__ if hasattr(source, '__name__') else type(source).__name__ - source_state = source() - - states[source_name] = source_state - state = deep_update(source_state, state) - return state + return deep_update(*reversed([source() for source in sources])) else: # no one should mean to do this, but I think returning an empty dict is marginally preferable # to an informative error and much better than a confusing error return {} - @staticmethod - def _settings_warn_unused_config_keys(sources: tuple[object, ...], model_config: SettingsConfigDict) -> None: - """ - Warns if any values in model_config were set but the corresponding settings source has not been initialised. - - The list alternative sources and their config keys can be found here: - https://docs.pydantic.dev/latest/concepts/pydantic_settings/#other-settings-source - - Args: - sources: The tuple of configured sources - model_config: The model config to check for unused config keys - """ - - def warn_if_not_used(source_type: type[PydanticBaseSettingsSource], keys: tuple[str, ...]) -> None: - if not any(isinstance(source, source_type) for source in sources): - for key in keys: - if model_config.get(key) is not None: - warnings.warn( - f'Config key `{key}` is set in model_config but will be ignored because no ' - f'{source_type.__name__} source is configured. 
To use this config key, add a ' - f'{source_type.__name__} source to the settings sources via the ' - 'settings_customise_sources hook.', - UserWarning, - stacklevel=3, - ) - - warn_if_not_used(JsonConfigSettingsSource, ('json_file', 'json_file_encoding')) - warn_if_not_used(PyprojectTomlConfigSettingsSource, ('pyproject_toml_depth', 'pyproject_toml_table_header')) - warn_if_not_used(TomlConfigSettingsSource, ('toml_file',)) - warn_if_not_used(YamlConfigSettingsSource, ('yaml_file', 'yaml_file_encoding', 'yaml_config_section')) - model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict( extra='forbid', arbitrary_types_allowed=True, validate_default=True, case_sensitive=False, env_prefix='', - nested_model_default_partial_update=False, env_file=None, env_file_encoding=None, - env_ignore_empty=False, env_nested_delimiter=None, - env_nested_max_split=None, - env_parse_none_str=None, - env_parse_enums=None, - cli_prog_name=None, - cli_parse_args=None, - cli_parse_none_str=None, - cli_hide_none_type=False, - cli_avoid_json=False, - cli_enforce_required=False, - cli_use_class_docs_for_groups=False, - cli_exit_on_error=True, - cli_prefix='', - cli_flag_prefix_char='-', - cli_implicit_flags=False, - cli_ignore_unknown_args=False, - cli_kebab_case=False, - cli_shortcuts=None, - json_file=None, - json_file_encoding=None, - yaml_file=None, - yaml_file_encoding=None, - yaml_config_section=None, - toml_file=None, secrets_dir=None, - protected_namespaces=('model_validate', 'model_dump', 'settings_customise_sources'), - enable_decoding=True, + protected_namespaces=('model_', 'settings_'), ) - - -class CliApp: - """ - A utility class for running Pydantic `BaseSettings`, `BaseModel`, or `pydantic.dataclasses.dataclass` as - CLI applications. - """ - - @staticmethod - def _get_base_settings_cls(model_cls: type[Any]) -> type[BaseSettings]: - if issubclass(model_cls, BaseSettings): - return model_cls - - class CliAppBaseSettings(BaseSettings, model_cls): # type: ignore - __doc__ = model_cls.__doc__ - model_config = SettingsConfigDict( - nested_model_default_partial_update=True, - case_sensitive=True, - cli_hide_none_type=True, - cli_avoid_json=True, - cli_enforce_required=True, - cli_implicit_flags=True, - cli_kebab_case=True, - ) - - return CliAppBaseSettings - - @staticmethod - def _run_cli_cmd(model: Any, cli_cmd_method_name: str, is_required: bool) -> Any: - command = getattr(type(model), cli_cmd_method_name, None) - if command is None: - if is_required: - raise SettingsError(f'Error: {type(model).__name__} class is missing {cli_cmd_method_name} entrypoint') - return model - - # If the method is asynchronous, we handle its execution based on the current event loop status. - if inspect.iscoroutinefunction(command): - # For asynchronous methods, we have two execution scenarios: - # 1. If no event loop is running in the current thread, run the coroutine directly with asyncio.run(). - # 2. If an event loop is already running in the current thread, run the coroutine in a separate thread to avoid conflicts. - try: - # Check if an event loop is currently running in this thread. - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - - if loop and loop.is_running(): - # We're in a context with an active event loop (e.g., Jupyter Notebook). - # Running asyncio.run() here would cause conflicts, so we use a separate thread. - exception_container = [] - - def run_coro() -> None: - try: - # Execute the coroutine in a new event loop in this separate thread. 
- asyncio.run(command(model)) - except Exception as e: - exception_container.append(e) - - thread = threading.Thread(target=run_coro) - thread.start() - thread.join() - if exception_container: - # Propagate exceptions from the separate thread. - raise exception_container[0] - else: - # No event loop is running; safe to run the coroutine directly. - asyncio.run(command(model)) - else: - # For synchronous methods, call them directly. - command(model) - - return model - - @staticmethod - def run( - model_cls: type[T], - cli_args: list[str] | Namespace | SimpleNamespace | dict[str, Any] | None = None, - cli_settings_source: CliSettingsSource[Any] | None = None, - cli_exit_on_error: bool | None = None, - cli_cmd_method_name: str = 'cli_cmd', - **model_init_data: Any, - ) -> T: - """ - Runs a Pydantic `BaseSettings`, `BaseModel`, or `pydantic.dataclasses.dataclass` as a CLI application. - Running a model as a CLI application requires the `cli_cmd` method to be defined in the model class. - - Args: - model_cls: The model class to run as a CLI application. - cli_args: The list of CLI arguments to parse. If `cli_settings_source` is specified, this may - also be a namespace or dictionary of pre-parsed CLI arguments. Defaults to `sys.argv[1:]`. - cli_settings_source: Override the default CLI settings source with a user defined instance. - Defaults to `None`. - cli_exit_on_error: Determines whether this function exits on error. If model is subclass of - `BaseSettings`, defaults to BaseSettings `cli_exit_on_error` value. Otherwise, defaults to - `True`. - cli_cmd_method_name: The CLI command method name to run. Defaults to "cli_cmd". - model_init_data: The model init data. - - Returns: - The ran instance of model. - - Raises: - SettingsError: If model_cls is not subclass of `BaseModel` or `pydantic.dataclasses.dataclass`. - SettingsError: If model_cls does not have a `cli_cmd` entrypoint defined. - """ - - if not (is_pydantic_dataclass(model_cls) or is_model_class(model_cls)): - raise SettingsError( - f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass' - ) - - cli_settings = None - cli_parse_args = True if cli_args is None else cli_args - if cli_settings_source is not None: - if isinstance(cli_parse_args, (Namespace, SimpleNamespace, dict)): - cli_settings = cli_settings_source(parsed_args=cli_parse_args) - else: - cli_settings = cli_settings_source(args=cli_parse_args) - elif isinstance(cli_parse_args, (Namespace, SimpleNamespace, dict)): - raise SettingsError('Error: `cli_args` must be list[str] or None when `cli_settings_source` is not used') - - model_init_data['_cli_parse_args'] = cli_parse_args - model_init_data['_cli_exit_on_error'] = cli_exit_on_error - model_init_data['_cli_settings_source'] = cli_settings - if not issubclass(model_cls, BaseSettings): - base_settings_cls = CliApp._get_base_settings_cls(model_cls) - model = base_settings_cls(**model_init_data) - model_init_data = {} - for field_name, field_info in base_settings_cls.model_fields.items(): - model_init_data[_field_name_for_signature(field_name, field_info)] = getattr(model, field_name) - - return CliApp._run_cli_cmd(model_cls(**model_init_data), cli_cmd_method_name, is_required=False) - - @staticmethod - def run_subcommand( - model: PydanticModel, cli_exit_on_error: bool | None = None, cli_cmd_method_name: str = 'cli_cmd' - ) -> PydanticModel: - """ - Runs the model subcommand. Running a model subcommand requires the `cli_cmd` method to be defined in - the nested model subcommand class. 
- - Args: - model: The model to run the subcommand from. - cli_exit_on_error: Determines whether this function exits with error if no subcommand is found. - Defaults to model_config `cli_exit_on_error` value if set. Otherwise, defaults to `True`. - cli_cmd_method_name: The CLI command method name to run. Defaults to "cli_cmd". - - Returns: - The ran subcommand model. - - Raises: - SystemExit: When no subcommand is found and cli_exit_on_error=`True` (the default). - SettingsError: When no subcommand is found and cli_exit_on_error=`False`. - """ - - subcommand = get_subcommand(model, is_required=True, cli_exit_on_error=cli_exit_on_error) - return CliApp._run_cli_cmd(subcommand, cli_cmd_method_name, is_required=True) - - @staticmethod - def serialize(model: PydanticModel) -> list[str]: - """ - Serializes the CLI arguments for a Pydantic data model. - - Args: - model: The data model to serialize. - - Returns: - The serialized CLI arguments for the data model. - """ - - base_settings_cls = CliApp._get_base_settings_cls(type(model)) - return CliSettingsSource[Any](base_settings_cls)._serialized_args(model) diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/sources.py b/venv/lib/python3.12/site-packages/pydantic_settings/sources.py new file mode 100644 index 0000000..baa10b8 --- /dev/null +++ b/venv/lib/python3.12/site-packages/pydantic_settings/sources.py @@ -0,0 +1,653 @@ +from __future__ import annotations as _annotations + +import json +import os +import warnings +from abc import ABC, abstractmethod +from collections import deque +from dataclasses import is_dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, List, Mapping, Sequence, Tuple, Union, cast + +from pydantic import AliasChoices, AliasPath, BaseModel, Json +from pydantic._internal._typing_extra import origin_is_union +from pydantic._internal._utils import deep_update, lenient_issubclass +from pydantic.fields import FieldInfo +from typing_extensions import get_args, get_origin + +from pydantic_settings.utils import path_type_label + +if TYPE_CHECKING: + from pydantic_settings.main import BaseSettings + + +DotenvType = Union[Path, str, List[Union[Path, str]], Tuple[Union[Path, str], ...]] + +# This is used as default value for `_env_file` in the `BaseSettings` class and +# `env_file` in `DotEnvSettingsSource` so the default can be distinguished from `None`. +# See the docstring of `BaseSettings` for more details. +ENV_FILE_SENTINEL: DotenvType = Path('') + + +class SettingsError(ValueError): + pass + + +class PydanticBaseSettingsSource(ABC): + """ + Abstract base class for settings sources, every settings source classes should inherit from it. + """ + + def __init__(self, settings_cls: type[BaseSettings]): + self.settings_cls = settings_cls + self.config = settings_cls.model_config + + @abstractmethod + def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: + """ + Gets the value, the key for model creation, and a flag to determine whether value is complex. + + This is an abstract method that should be overridden in every settings source classes. + + Args: + field: The field. + field_name: The field name. + + Returns: + A tuple contains the key, value and a flag to determine whether value is complex. + """ + pass + + def field_is_complex(self, field: FieldInfo) -> bool: + """ + Checks whether a field is complex, in which case it will attempt to be parsed as JSON. + + Args: + field: The field. + + Returns: + Whether the field is complex. 
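> Note: the new `sources.py` re-introduces the older `PydanticBaseSettingsSource` contract (`get_field_value` plus `__call__`, with the `prepare_field_value`/`decode_complex_value` hooks). For orientation, a minimal custom source wired in through `settings_customise_sources`; the `HardcodedSource` class and the `api_url` field are illustrative only:

```python
from typing import Any

from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource


class HardcodedSource(PydanticBaseSettingsSource):
    """Toy source that supplies one fixed value; real sources read files, env vars, etc."""

    def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
        # (value, key used for model creation, value_is_complex)
        if field_name == "api_url":
            return "https://example.invalid/api", field_name, False
        return None, field_name, False

    def __call__(self) -> dict[str, Any]:
        data: dict[str, Any] = {}
        for name, field in self.settings_cls.model_fields.items():
            value, key, is_complex = self.get_field_value(field, name)
            value = self.prepare_field_value(name, field, value, is_complex)
            if value is not None:
                data[key] = value
        return data


class Settings(BaseSettings):
    api_url: str = "http://localhost"

    @classmethod
    def settings_customise_sources(cls, settings_cls, init_settings, env_settings, dotenv_settings, file_secret_settings):
        # Highest priority first: explicit kwargs, then the custom source, then the environment.
        return init_settings, HardcodedSource(settings_cls), env_settings


print(Settings().api_url)  # https://example.invalid/api
```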
+ """ + return _annotation_is_complex(field.annotation, field.metadata) + + def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any: + """ + Prepares the value of a field. + + Args: + field_name: The field name. + field: The field. + value: The value of the field that has to be prepared. + value_is_complex: A flag to determine whether value is complex. + + Returns: + The prepared value. + """ + if value is not None and (self.field_is_complex(field) or value_is_complex): + return self.decode_complex_value(field_name, field, value) + return value + + def decode_complex_value(self, field_name: str, field: FieldInfo, value: Any) -> Any: + """ + Decode the value for a complex field + + Args: + field_name: The field name. + field: The field. + value: The value of the field that has to be prepared. + + Returns: + The decoded value for further preparation + """ + return json.loads(value) + + @abstractmethod + def __call__(self) -> dict[str, Any]: + pass + + +class InitSettingsSource(PydanticBaseSettingsSource): + """ + Source class for loading values provided during settings class initialization. + """ + + def __init__(self, settings_cls: type[BaseSettings], init_kwargs: dict[str, Any]): + self.init_kwargs = init_kwargs + super().__init__(settings_cls) + + def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: + # Nothing to do here. Only implement the return statement to make mypy happy + return None, '', False + + def __call__(self) -> dict[str, Any]: + return self.init_kwargs + + def __repr__(self) -> str: + return f'InitSettingsSource(init_kwargs={self.init_kwargs!r})' + + +class PydanticBaseEnvSettingsSource(PydanticBaseSettingsSource): + def __init__( + self, settings_cls: type[BaseSettings], case_sensitive: bool | None = None, env_prefix: str | None = None + ) -> None: + super().__init__(settings_cls) + self.case_sensitive = case_sensitive if case_sensitive is not None else self.config.get('case_sensitive', False) + self.env_prefix = env_prefix if env_prefix is not None else self.config.get('env_prefix', '') + + def _apply_case_sensitive(self, value: str) -> str: + return value.lower() if not self.case_sensitive else value + + def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]: + """ + Extracts field info. This info is used to get the value of field from environment variables. + + It returns a list of tuples, each tuple contains: + * field_key: The key of field that has to be used in model creation. + * env_name: The environment variable name of the field. + * value_is_complex: A flag to determine whether the value from environment variable + is complex and has to be parsed. + + Args: + field (FieldInfo): The field. + field_name (str): The field name. + + Returns: + list[tuple[str, str, bool]]: List of tuples, each tuple contains field_key, env_name, and value_is_complex. 
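> Note: `field_is_complex` and `decode_complex_value` above mean that any field whose annotation is a container (list, dict, nested model, and so on) is decoded from the raw environment string with `json.loads` before validation. A quick illustration; the field and variable names are placeholders:

```python
import os

from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    tags: list[str] = []
    limits: dict[str, int] = {}


# list/dict annotations are "complex", so the raw strings below go through json.loads.
os.environ["TAGS"] = '["alpha", "beta"]'
os.environ["LIMITS"] = '{"requests": 100}'

print(Settings().tags)    # ['alpha', 'beta']
print(Settings().limits)  # {'requests': 100}
```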
+ """ + field_info: list[tuple[str, str, bool]] = [] + if isinstance(field.validation_alias, (AliasChoices, AliasPath)): + v_alias: str | list[str | int] | list[list[str | int]] | None = field.validation_alias.convert_to_aliases() + else: + v_alias = field.validation_alias + + if v_alias: + if isinstance(v_alias, list): # AliasChoices, AliasPath + for alias in v_alias: + if isinstance(alias, str): # AliasPath + field_info.append((alias, self._apply_case_sensitive(alias), True if len(alias) > 1 else False)) + elif isinstance(alias, list): # AliasChoices + first_arg = cast(str, alias[0]) # first item of an AliasChoices must be a str + field_info.append( + (first_arg, self._apply_case_sensitive(first_arg), True if len(alias) > 1 else False) + ) + else: # string validation alias + field_info.append((v_alias, self._apply_case_sensitive(v_alias), False)) + else: + field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), False)) + + return field_info + + def _replace_field_names_case_insensitively(self, field: FieldInfo, field_values: dict[str, Any]) -> dict[str, Any]: + """ + Replace field names in values dict by looking in models fields insensitively. + + By having the following models: + + ```py + class SubSubSub(BaseModel): + VaL3: str + + class SubSub(BaseModel): + Val2: str + SUB_sub_SuB: SubSubSub + + class Sub(BaseModel): + VAL1: str + SUB_sub: SubSub + + class Settings(BaseSettings): + nested: Sub + + model_config = SettingsConfigDict(env_nested_delimiter='__') + ``` + + Then: + _replace_field_names_case_insensitively( + field, + {"val1": "v1", "sub_SUB": {"VAL2": "v2", "sub_SUB_sUb": {"vAl3": "v3"}}} + ) + Returns {'VAL1': 'v1', 'SUB_sub': {'Val2': 'v2', 'SUB_sub_SuB': {'VaL3': 'v3'}}} + """ + values: dict[str, Any] = {} + + for name, value in field_values.items(): + sub_model_field: FieldInfo | None = None + + # This is here to make mypy happy + # Item "None" of "Optional[Type[Any]]" has no attribute "model_fields" + if not field.annotation or not hasattr(field.annotation, 'model_fields'): + values[name] = value + continue + + # Find field in sub model by looking in fields case insensitively + for sub_model_field_name, f in field.annotation.model_fields.items(): + if not f.validation_alias and sub_model_field_name.lower() == name.lower(): + sub_model_field = f + break + + if not sub_model_field: + values[name] = value + continue + + if lenient_issubclass(sub_model_field.annotation, BaseModel) and isinstance(value, dict): + values[sub_model_field_name] = self._replace_field_names_case_insensitively(sub_model_field, value) + else: + values[sub_model_field_name] = value + + return values + + def __call__(self) -> dict[str, Any]: + data: dict[str, Any] = {} + + for field_name, field in self.settings_cls.model_fields.items(): + try: + field_value, field_key, value_is_complex = self.get_field_value(field, field_name) + except Exception as e: + raise SettingsError( + f'error getting value for field "{field_name}" from source "{self.__class__.__name__}"' + ) from e + + try: + field_value = self.prepare_field_value(field_name, field, field_value, value_is_complex) + except ValueError as e: + raise SettingsError( + f'error parsing value for field "{field_name}" from source "{self.__class__.__name__}"' + ) from e + + if field_value is not None: + if ( + not self.case_sensitive + and lenient_issubclass(field.annotation, BaseModel) + and isinstance(field_value, dict) + ): + data[field_key] = self._replace_field_names_case_insensitively(field, field_value) + else: + 
data[field_key] = field_value + + return data + + +class SecretsSettingsSource(PydanticBaseEnvSettingsSource): + """ + Source class for loading settings values from secret files. + """ + + def __init__( + self, + settings_cls: type[BaseSettings], + secrets_dir: str | Path | None = None, + case_sensitive: bool | None = None, + env_prefix: str | None = None, + ) -> None: + super().__init__(settings_cls, case_sensitive, env_prefix) + self.secrets_dir = secrets_dir if secrets_dir is not None else self.config.get('secrets_dir') + + def __call__(self) -> dict[str, Any]: + """ + Build fields from "secrets" files. + """ + secrets: dict[str, str | None] = {} + + if self.secrets_dir is None: + return secrets + + self.secrets_path = Path(self.secrets_dir).expanduser() + + if not self.secrets_path.exists(): + warnings.warn(f'directory "{self.secrets_path}" does not exist') + return secrets + + if not self.secrets_path.is_dir(): + raise SettingsError(f'secrets_dir must reference a directory, not a {path_type_label(self.secrets_path)}') + + return super().__call__() + + @classmethod + def find_case_path(cls, dir_path: Path, file_name: str, case_sensitive: bool) -> Path | None: + """ + Find a file within path's directory matching filename, optionally ignoring case. + + Args: + dir_path: Directory path. + file_name: File name. + case_sensitive: Whether to search for file name case sensitively. + + Returns: + Whether file path or `None` if file does not exist in directory. + """ + for f in dir_path.iterdir(): + if f.name == file_name: + return f + elif not case_sensitive and f.name.lower() == file_name.lower(): + return f + return None + + def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: + """ + Gets the value for field from secret file and a flag to determine whether value is complex. + + Args: + field: The field. + field_name: The field name. + + Returns: + A tuple contains the key, value if the file exists otherwise `None`, and + a flag to determine whether value is complex. + """ + + for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name): + path = self.find_case_path(self.secrets_path, env_name, self.case_sensitive) + if not path: + # path does not exist, we currently don't return a warning for this + continue + + if path.is_file(): + return path.read_text().strip(), field_key, value_is_complex + else: + warnings.warn( + f'attempted to load secret file "{path}" but found a {path_type_label(path)} instead.', + stacklevel=4, + ) + + return None, field_key, value_is_complex + + def __repr__(self) -> str: + return f'SecretsSettingsSource(secrets_dir={self.secrets_dir!r})' + + +class EnvSettingsSource(PydanticBaseEnvSettingsSource): + """ + Source class for loading settings values from environment variables. 
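> Note: `SecretsSettingsSource` above implements the Docker-style secrets layout: one file per field, where the file name is the field name (matched case-insensitively via `find_case_path`) and the file contents become the value with surrounding whitespace stripped. A self-contained sketch using a temporary directory in place of `/run/secrets`:

```python
import pathlib
import tempfile

from pydantic_settings import BaseSettings, SettingsConfigDict

# Stand-in for /run/secrets or another mounted secrets directory.
secrets_dir = pathlib.Path(tempfile.mkdtemp())
(secrets_dir / "db_password").write_text("s3cr3t\n")


class Settings(BaseSettings):
    model_config = SettingsConfigDict(secrets_dir=str(secrets_dir))

    db_password: str


print(Settings().db_password)  # 's3cr3t'; read_text().strip() drops the trailing newline
```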
+ """ + + def __init__( + self, + settings_cls: type[BaseSettings], + case_sensitive: bool | None = None, + env_prefix: str | None = None, + env_nested_delimiter: str | None = None, + ) -> None: + super().__init__(settings_cls, case_sensitive, env_prefix) + self.env_nested_delimiter = ( + env_nested_delimiter if env_nested_delimiter is not None else self.config.get('env_nested_delimiter') + ) + self.env_prefix_len = len(self.env_prefix) + + self.env_vars = self._load_env_vars() + + def _load_env_vars(self) -> Mapping[str, str | None]: + if self.case_sensitive: + return os.environ + return {k.lower(): v for k, v in os.environ.items()} + + def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: + """ + Gets the value for field from environment variables and a flag to determine whether value is complex. + + Args: + field: The field. + field_name: The field name. + + Returns: + A tuple contains the key, value if the file exists otherwise `None`, and + a flag to determine whether value is complex. + """ + + env_val: str | None = None + for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name): + env_val = self.env_vars.get(env_name) + if env_val is not None: + break + + return env_val, field_key, value_is_complex + + def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any: + """ + Prepare value for the field. + + * Extract value for nested field. + * Deserialize value to python object for complex field. + + Args: + field: The field. + field_name: The field name. + + Returns: + A tuple contains prepared value for the field. + + Raises: + ValuesError: When There is an error in deserializing value for complex field. + """ + is_complex, allow_parse_failure = self._field_is_complex(field) + if is_complex or value_is_complex: + if value is None: + # field is complex but no value found so far, try explode_env_vars + env_val_built = self.explode_env_vars(field_name, field, self.env_vars) + if env_val_built: + return env_val_built + else: + # field is complex and there's a value, decode that as JSON, then add explode_env_vars + try: + value = self.decode_complex_value(field_name, field, value) + except ValueError as e: + if not allow_parse_failure: + raise e + + if isinstance(value, dict): + return deep_update(value, self.explode_env_vars(field_name, field, self.env_vars)) + else: + return value + elif value is not None: + # simplest case, field is not complex, we only need to add the value if it was found + return value + + def _union_is_complex(self, annotation: type[Any] | None, metadata: list[Any]) -> bool: + return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation)) + + def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]: + """ + Find out if a field is complex, and if so whether JSON errors should be ignored + """ + if self.field_is_complex(field): + allow_parse_failure = False + elif origin_is_union(get_origin(field.annotation)) and self._union_is_complex(field.annotation, field.metadata): + allow_parse_failure = True + else: + return False, False + + return True, allow_parse_failure + + @staticmethod + def next_field(field: FieldInfo | None, key: str) -> FieldInfo | None: + """ + Find the field in a sub model by key(env name) + + By having the following models: + + ```py + class SubSubModel(BaseSettings): + dvals: Dict + + class SubModel(BaseSettings): + vals: list[str] + sub_sub_model: SubSubModel + + class Cfg(BaseSettings): + sub_model: 
SubModel + ``` + + Then: + next_field(sub_model, 'vals') Returns the `vals` field of `SubModel` class + next_field(sub_model, 'sub_sub_model') Returns `sub_sub_model` field of `SubModel` class + + Args: + field: The field. + key: The key (env name). + + Returns: + Field if it finds the next field otherwise `None`. + """ + if not field or origin_is_union(get_origin(field.annotation)): + # no support for Unions of complex BaseSettings fields + return None + elif field.annotation and hasattr(field.annotation, 'model_fields') and field.annotation.model_fields.get(key): + return field.annotation.model_fields[key] + + return None + + def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[str, str | None]) -> dict[str, Any]: + """ + Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries. + + This is applied to a single field, hence filtering by env_var prefix. + + Args: + field_name: The field name. + field: The field. + env_vars: Environment variables. + + Returns: + A dictionaty contains extracted values from nested env values. + """ + prefixes = [ + f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name) + ] + result: dict[str, Any] = {} + for env_name, env_val in env_vars.items(): + if not any(env_name.startswith(prefix) for prefix in prefixes): + continue + # we remove the prefix before splitting in case the prefix has characters in common with the delimiter + env_name_without_prefix = env_name[self.env_prefix_len :] + _, *keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter) + env_var = result + target_field: FieldInfo | None = field + for key in keys: + target_field = self.next_field(target_field, key) + env_var = env_var.setdefault(key, {}) + + # get proper field with last_key + target_field = self.next_field(target_field, last_key) + + # check if env_val maps to a complex field and if so, parse the env_val + if target_field and env_val: + is_complex, allow_json_failure = self._field_is_complex(target_field) + if is_complex: + try: + env_val = self.decode_complex_value(last_key, target_field, env_val) + except ValueError as e: + if not allow_json_failure: + raise e + env_var[last_key] = env_val + + return result + + def __repr__(self) -> str: + return ( + f'EnvSettingsSource(env_nested_delimiter={self.env_nested_delimiter!r}, ' + f'env_prefix_len={self.env_prefix_len!r})' + ) + + +class DotEnvSettingsSource(EnvSettingsSource): + """ + Source class for loading settings values from env files. 
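> Note: `explode_env_vars` above is what turns flat `APP_DATABASE__HOST`-style variables into nested dictionaries, walking sub-model fields via `next_field` and JSON-decoding any leaf that is itself complex. For example (model and variable names are placeholders):

```python
import os

from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict


class Database(BaseModel):
    host: str = "localhost"
    port: int = 5432


class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_prefix="APP_", env_nested_delimiter="__")

    database: Database = Database()


# explode_env_vars collapses these into {"database": {"host": ..., "port": ...}}.
os.environ["APP_DATABASE__HOST"] = "db.internal"
os.environ["APP_DATABASE__PORT"] = "6432"

print(Settings().database)  # host='db.internal' port=6432
```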
+ """ + + def __init__( + self, + settings_cls: type[BaseSettings], + env_file: DotenvType | None = ENV_FILE_SENTINEL, + env_file_encoding: str | None = None, + case_sensitive: bool | None = None, + env_prefix: str | None = None, + env_nested_delimiter: str | None = None, + ) -> None: + self.env_file = env_file if env_file != ENV_FILE_SENTINEL else settings_cls.model_config.get('env_file') + self.env_file_encoding = ( + env_file_encoding if env_file_encoding is not None else settings_cls.model_config.get('env_file_encoding') + ) + super().__init__(settings_cls, case_sensitive, env_prefix, env_nested_delimiter) + + def _load_env_vars(self) -> Mapping[str, str | None]: + return self._read_env_files(self.case_sensitive) + + def _read_env_files(self, case_sensitive: bool) -> Mapping[str, str | None]: + env_files = self.env_file + if env_files is None: + return {} + + if isinstance(env_files, (str, os.PathLike)): + env_files = [env_files] + + dotenv_vars: dict[str, str | None] = {} + for env_file in env_files: + env_path = Path(env_file).expanduser() + if env_path.is_file(): + dotenv_vars.update( + read_env_file(env_path, encoding=self.env_file_encoding, case_sensitive=case_sensitive) + ) + + return dotenv_vars + + def __call__(self) -> dict[str, Any]: + data: dict[str, Any] = super().__call__() + + data_lower_keys: list[str] = [] + if not self.case_sensitive: + data_lower_keys = [x.lower() for x in data.keys()] + + # As `extra` config is allowed in dotenv settings source, We have to + # update data with extra env variabels from dotenv file. + for env_name, env_value in self.env_vars.items(): + if env_name.startswith(self.env_prefix) and env_value is not None: + env_name_without_prefix = env_name[self.env_prefix_len :] + first_key, *_ = env_name_without_prefix.split(self.env_nested_delimiter) + + if (data_lower_keys and first_key not in data_lower_keys) or ( + not data_lower_keys and first_key not in data + ): + data[first_key] = env_value + + return data + + def __repr__(self) -> str: + return ( + f'DotEnvSettingsSource(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, ' + f'env_nested_delimiter={self.env_nested_delimiter!r}, env_prefix_len={self.env_prefix_len!r})' + ) + + +def read_env_file( + file_path: Path, *, encoding: str | None = None, case_sensitive: bool = False +) -> Mapping[str, str | None]: + try: + from dotenv import dotenv_values + except ImportError as e: + raise ImportError('python-dotenv is not installed, run `pip install pydantic[dotenv]`') from e + + file_vars: dict[str, str | None] = dotenv_values(file_path, encoding=encoding or 'utf8') + if not case_sensitive: + return {k.lower(): v for k, v in file_vars.items()} + else: + return file_vars + + +def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool: + if any(isinstance(md, Json) for md in metadata): # type: ignore[misc] + return False + origin = get_origin(annotation) + return ( + _annotation_is_complex_inner(annotation) + or _annotation_is_complex_inner(origin) + or hasattr(origin, '__pydantic_core_schema__') + or hasattr(origin, '__get_pydantic_core_schema__') + ) + + +def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool: + if lenient_issubclass(annotation, (str, bytes)): + return False + + return lenient_issubclass(annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)) or is_dataclass( + annotation + ) diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/sources/__init__.py 
b/venv/lib/python3.12/site-packages/pydantic_settings/sources/__init__.py deleted file mode 100644 index a795c49..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings/sources/__init__.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Package for handling configuration sources in pydantic-settings.""" - -from .base import ( - ConfigFileSourceMixin, - DefaultSettingsSource, - InitSettingsSource, - PydanticBaseEnvSettingsSource, - PydanticBaseSettingsSource, - get_subcommand, -) -from .providers.aws import AWSSecretsManagerSettingsSource -from .providers.azure import AzureKeyVaultSettingsSource -from .providers.cli import ( - CLI_SUPPRESS, - CliExplicitFlag, - CliImplicitFlag, - CliMutuallyExclusiveGroup, - CliPositionalArg, - CliSettingsSource, - CliSubCommand, - CliSuppress, - CliUnknownArgs, -) -from .providers.dotenv import DotEnvSettingsSource, read_env_file -from .providers.env import EnvSettingsSource -from .providers.gcp import GoogleSecretManagerSettingsSource -from .providers.json import JsonConfigSettingsSource -from .providers.pyproject import PyprojectTomlConfigSettingsSource -from .providers.secrets import SecretsSettingsSource -from .providers.toml import TomlConfigSettingsSource -from .providers.yaml import YamlConfigSettingsSource -from .types import DEFAULT_PATH, ENV_FILE_SENTINEL, DotenvType, ForceDecode, NoDecode, PathType, PydanticModel - -__all__ = [ - 'CLI_SUPPRESS', - 'ENV_FILE_SENTINEL', - 'DEFAULT_PATH', - 'AWSSecretsManagerSettingsSource', - 'AzureKeyVaultSettingsSource', - 'CliExplicitFlag', - 'CliImplicitFlag', - 'CliMutuallyExclusiveGroup', - 'CliPositionalArg', - 'CliSettingsSource', - 'CliSubCommand', - 'CliSuppress', - 'CliUnknownArgs', - 'DefaultSettingsSource', - 'DotEnvSettingsSource', - 'DotenvType', - 'EnvSettingsSource', - 'ForceDecode', - 'GoogleSecretManagerSettingsSource', - 'InitSettingsSource', - 'JsonConfigSettingsSource', - 'NoDecode', - 'PathType', - 'PydanticBaseEnvSettingsSource', - 'PydanticBaseSettingsSource', - 'ConfigFileSourceMixin', - 'PydanticModel', - 'PyprojectTomlConfigSettingsSource', - 'SecretsSettingsSource', - 'TomlConfigSettingsSource', - 'YamlConfigSettingsSource', - 'get_subcommand', - 'read_env_file', -] diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/sources/base.py b/venv/lib/python3.12/site-packages/pydantic_settings/sources/base.py deleted file mode 100644 index a5ec7e5..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings/sources/base.py +++ /dev/null @@ -1,527 +0,0 @@ -"""Base classes and core functionality for pydantic-settings sources.""" - -from __future__ import annotations as _annotations - -import json -import os -from abc import ABC, abstractmethod -from dataclasses import asdict, is_dataclass -from pathlib import Path -from typing import TYPE_CHECKING, Any, Optional, cast - -from pydantic import AliasChoices, AliasPath, BaseModel, TypeAdapter -from pydantic._internal._typing_extra import ( # type: ignore[attr-defined] - get_origin, -) -from pydantic._internal._utils import is_model_class -from pydantic.fields import FieldInfo -from typing_extensions import get_args -from typing_inspection import typing_objects -from typing_inspection.introspection import is_union_origin - -from ..exceptions import SettingsError -from ..utils import _lenient_issubclass -from .types import EnvNoneType, ForceDecode, NoDecode, PathType, PydanticModel, _CliSubCommand -from .utils import ( - _annotation_is_complex, - _get_alias_names, - _get_model_fields, - _strip_annotated, - _union_is_complex, -) - -if 
TYPE_CHECKING: - from pydantic_settings.main import BaseSettings - - -def get_subcommand( - model: PydanticModel, is_required: bool = True, cli_exit_on_error: bool | None = None -) -> Optional[PydanticModel]: - """ - Get the subcommand from a model. - - Args: - model: The model to get the subcommand from. - is_required: Determines whether a model must have subcommand set and raises error if not - found. Defaults to `True`. - cli_exit_on_error: Determines whether this function exits with error if no subcommand is found. - Defaults to model_config `cli_exit_on_error` value if set. Otherwise, defaults to `True`. - - Returns: - The subcommand model if found, otherwise `None`. - - Raises: - SystemExit: When no subcommand is found and is_required=`True` and cli_exit_on_error=`True` - (the default). - SettingsError: When no subcommand is found and is_required=`True` and - cli_exit_on_error=`False`. - """ - - model_cls = type(model) - if cli_exit_on_error is None and is_model_class(model_cls): - model_default = model_cls.model_config.get('cli_exit_on_error') - if isinstance(model_default, bool): - cli_exit_on_error = model_default - if cli_exit_on_error is None: - cli_exit_on_error = True - - subcommands: list[str] = [] - for field_name, field_info in _get_model_fields(model_cls).items(): - if _CliSubCommand in field_info.metadata: - if getattr(model, field_name) is not None: - return getattr(model, field_name) - subcommands.append(field_name) - - if is_required: - error_message = ( - f'Error: CLI subcommand is required {{{", ".join(subcommands)}}}' - if subcommands - else 'Error: CLI subcommand is required but no subcommands were found.' - ) - raise SystemExit(error_message) if cli_exit_on_error else SettingsError(error_message) - - return None - - -class PydanticBaseSettingsSource(ABC): - """ - Abstract base class for settings sources, every settings source classes should inherit from it. - """ - - def __init__(self, settings_cls: type[BaseSettings]): - self.settings_cls = settings_cls - self.config = settings_cls.model_config - self._current_state: dict[str, Any] = {} - self._settings_sources_data: dict[str, dict[str, Any]] = {} - - def _set_current_state(self, state: dict[str, Any]) -> None: - """ - Record the state of settings from the previous settings sources. This should - be called right before __call__. - """ - self._current_state = state - - def _set_settings_sources_data(self, states: dict[str, dict[str, Any]]) -> None: - """ - Record the state of settings from all previous settings sources. This should - be called right before __call__. - """ - self._settings_sources_data = states - - @property - def current_state(self) -> dict[str, Any]: - """ - The current state of the settings, populated by the previous settings sources. - """ - return self._current_state - - @property - def settings_sources_data(self) -> dict[str, dict[str, Any]]: - """ - The state of all previous settings sources. - """ - return self._settings_sources_data - - @abstractmethod - def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: - """ - Gets the value, the key for model creation, and a flag to determine whether value is complex. - - This is an abstract method that should be overridden in every settings source classes. - - Args: - field: The field. - field_name: The field name. - - Returns: - A tuple that contains the value, key and a flag to determine whether value is complex. 
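> Note: `get_subcommand` (deleted here along with the rest of the newer sources package) pairs with the `CliSubCommand` annotation: the first populated subcommand field wins, and a missing-but-required subcommand either exits or raises depending on `cli_exit_on_error`. A sketch of that flow, assuming a release that still ships the CLI source; the `Git`/`Clone`/`Init` names and the hard-coded argument list are illustrative:

```python
from pydantic import BaseModel
from pydantic_settings import BaseSettings, CliSubCommand, SettingsConfigDict, get_subcommand


class Clone(BaseModel):
    repository: str


class Init(BaseModel):
    directory: str = "."


class Git(BaseSettings):
    # Hard-coded args keep the sketch self-contained; normally cli_parse_args=True reads sys.argv[1:].
    model_config = SettingsConfigDict(cli_parse_args=["clone", "--repository", "https://example.invalid/repo.git"])

    clone: CliSubCommand[Clone]
    init: CliSubCommand[Init]


git = Git()
print(get_subcommand(git))  # Clone(repository='https://example.invalid/repo.git')
```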
- """ - pass - - def field_is_complex(self, field: FieldInfo) -> bool: - """ - Checks whether a field is complex, in which case it will attempt to be parsed as JSON. - - Args: - field: The field. - - Returns: - Whether the field is complex. - """ - return _annotation_is_complex(field.annotation, field.metadata) - - def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any: - """ - Prepares the value of a field. - - Args: - field_name: The field name. - field: The field. - value: The value of the field that has to be prepared. - value_is_complex: A flag to determine whether value is complex. - - Returns: - The prepared value. - """ - if value is not None and (self.field_is_complex(field) or value_is_complex): - return self.decode_complex_value(field_name, field, value) - return value - - def decode_complex_value(self, field_name: str, field: FieldInfo, value: Any) -> Any: - """ - Decode the value for a complex field - - Args: - field_name: The field name. - field: The field. - value: The value of the field that has to be prepared. - - Returns: - The decoded value for further preparation - """ - if field and ( - NoDecode in field.metadata - or (self.config.get('enable_decoding') is False and ForceDecode not in field.metadata) - ): - return value - - return json.loads(value) - - @abstractmethod - def __call__(self) -> dict[str, Any]: - pass - - -class ConfigFileSourceMixin(ABC): - def _read_files(self, files: PathType | None) -> dict[str, Any]: - if files is None: - return {} - if isinstance(files, (str, os.PathLike)): - files = [files] - vars: dict[str, Any] = {} - for file in files: - file_path = Path(file).expanduser() - if file_path.is_file(): - vars.update(self._read_file(file_path)) - return vars - - @abstractmethod - def _read_file(self, path: Path) -> dict[str, Any]: - pass - - -class DefaultSettingsSource(PydanticBaseSettingsSource): - """ - Source class for loading default object values. - - Args: - settings_cls: The Settings class. - nested_model_default_partial_update: Whether to allow partial updates on nested model default object fields. - Defaults to `False`. - """ - - def __init__(self, settings_cls: type[BaseSettings], nested_model_default_partial_update: bool | None = None): - super().__init__(settings_cls) - self.defaults: dict[str, Any] = {} - self.nested_model_default_partial_update = ( - nested_model_default_partial_update - if nested_model_default_partial_update is not None - else self.config.get('nested_model_default_partial_update', False) - ) - if self.nested_model_default_partial_update: - for field_name, field_info in settings_cls.model_fields.items(): - alias_names, *_ = _get_alias_names(field_name, field_info) - preferred_alias = alias_names[0] - if is_dataclass(type(field_info.default)): - self.defaults[preferred_alias] = asdict(field_info.default) - elif is_model_class(type(field_info.default)): - self.defaults[preferred_alias] = field_info.default.model_dump() - - def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: - # Nothing to do here. 
Only implement the return statement to make mypy happy - return None, '', False - - def __call__(self) -> dict[str, Any]: - return self.defaults - - def __repr__(self) -> str: - return ( - f'{self.__class__.__name__}(nested_model_default_partial_update={self.nested_model_default_partial_update})' - ) - - -class InitSettingsSource(PydanticBaseSettingsSource): - """ - Source class for loading values provided during settings class initialization. - """ - - def __init__( - self, - settings_cls: type[BaseSettings], - init_kwargs: dict[str, Any], - nested_model_default_partial_update: bool | None = None, - ): - self.init_kwargs = {} - init_kwarg_names = set(init_kwargs.keys()) - for field_name, field_info in settings_cls.model_fields.items(): - alias_names, *_ = _get_alias_names(field_name, field_info) - init_kwarg_name = init_kwarg_names & set(alias_names) - if init_kwarg_name: - preferred_alias = alias_names[0] - preferred_set_alias = next(alias for alias in alias_names if alias in init_kwarg_name) - init_kwarg_names -= init_kwarg_name - self.init_kwargs[preferred_alias] = init_kwargs[preferred_set_alias] - self.init_kwargs.update({key: val for key, val in init_kwargs.items() if key in init_kwarg_names}) - - super().__init__(settings_cls) - self.nested_model_default_partial_update = ( - nested_model_default_partial_update - if nested_model_default_partial_update is not None - else self.config.get('nested_model_default_partial_update', False) - ) - - def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: - # Nothing to do here. Only implement the return statement to make mypy happy - return None, '', False - - def __call__(self) -> dict[str, Any]: - return ( - TypeAdapter(dict[str, Any]).dump_python(self.init_kwargs) - if self.nested_model_default_partial_update - else self.init_kwargs - ) - - def __repr__(self) -> str: - return f'{self.__class__.__name__}(init_kwargs={self.init_kwargs!r})' - - -class PydanticBaseEnvSettingsSource(PydanticBaseSettingsSource): - def __init__( - self, - settings_cls: type[BaseSettings], - case_sensitive: bool | None = None, - env_prefix: str | None = None, - env_ignore_empty: bool | None = None, - env_parse_none_str: str | None = None, - env_parse_enums: bool | None = None, - ) -> None: - super().__init__(settings_cls) - self.case_sensitive = case_sensitive if case_sensitive is not None else self.config.get('case_sensitive', False) - self.env_prefix = env_prefix if env_prefix is not None else self.config.get('env_prefix', '') - self.env_ignore_empty = ( - env_ignore_empty if env_ignore_empty is not None else self.config.get('env_ignore_empty', False) - ) - self.env_parse_none_str = ( - env_parse_none_str if env_parse_none_str is not None else self.config.get('env_parse_none_str') - ) - self.env_parse_enums = env_parse_enums if env_parse_enums is not None else self.config.get('env_parse_enums') - - def _apply_case_sensitive(self, value: str) -> str: - return value.lower() if not self.case_sensitive else value - - def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]: - """ - Extracts field info. This info is used to get the value of field from environment variables. - - It returns a list of tuples, each tuple contains: - * field_key: The key of field that has to be used in model creation. - * env_name: The environment variable name of the field. - * value_is_complex: A flag to determine whether the value from environment variable - is complex and has to be parsed. 
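> Note: the `enable_decoding`/`NoDecode`/`ForceDecode` checks in `decode_complex_value` above let a field opt out of the JSON step and parse the raw string itself. For releases that export `NoDecode`, the documented pattern looks roughly like this:

```python
import os
from typing import Annotated

from pydantic import field_validator
from pydantic_settings import BaseSettings, NoDecode


class Settings(BaseSettings):
    # NoDecode skips json.loads, so the raw "1,2,3" string reaches the validator untouched.
    numbers: Annotated[list[int], NoDecode]

    @field_validator("numbers", mode="before")
    @classmethod
    def split_numbers(cls, value: str) -> list[int]:
        return [int(part) for part in value.split(",")]


os.environ["NUMBERS"] = "1,2,3"
print(Settings().numbers)  # [1, 2, 3]
```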
- - Args: - field (FieldInfo): The field. - field_name (str): The field name. - - Returns: - list[tuple[str, str, bool]]: List of tuples, each tuple contains field_key, env_name, and value_is_complex. - """ - field_info: list[tuple[str, str, bool]] = [] - if isinstance(field.validation_alias, (AliasChoices, AliasPath)): - v_alias: str | list[str | int] | list[list[str | int]] | None = field.validation_alias.convert_to_aliases() - else: - v_alias = field.validation_alias - - if v_alias: - if isinstance(v_alias, list): # AliasChoices, AliasPath - for alias in v_alias: - if isinstance(alias, str): # AliasPath - field_info.append((alias, self._apply_case_sensitive(alias), True if len(alias) > 1 else False)) - elif isinstance(alias, list): # AliasChoices - first_arg = cast(str, alias[0]) # first item of an AliasChoices must be a str - field_info.append( - (first_arg, self._apply_case_sensitive(first_arg), True if len(alias) > 1 else False) - ) - else: # string validation alias - field_info.append((v_alias, self._apply_case_sensitive(v_alias), False)) - - if not v_alias or self.config.get('populate_by_name', False): - annotation = field.annotation - if typing_objects.is_typealiastype(annotation) or typing_objects.is_typealiastype(get_origin(annotation)): - annotation = _strip_annotated(annotation.__value__) # type: ignore[union-attr] - if is_union_origin(get_origin(annotation)) and _union_is_complex(annotation, field.metadata): - field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), True)) - else: - field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), False)) - - return field_info - - def _replace_field_names_case_insensitively(self, field: FieldInfo, field_values: dict[str, Any]) -> dict[str, Any]: - """ - Replace field names in values dict by looking in models fields insensitively. 
- - By having the following models: - - ```py - class SubSubSub(BaseModel): - VaL3: str - - class SubSub(BaseModel): - Val2: str - SUB_sub_SuB: SubSubSub - - class Sub(BaseModel): - VAL1: str - SUB_sub: SubSub - - class Settings(BaseSettings): - nested: Sub - - model_config = SettingsConfigDict(env_nested_delimiter='__') - ``` - - Then: - _replace_field_names_case_insensitively( - field, - {"val1": "v1", "sub_SUB": {"VAL2": "v2", "sub_SUB_sUb": {"vAl3": "v3"}}} - ) - Returns {'VAL1': 'v1', 'SUB_sub': {'Val2': 'v2', 'SUB_sub_SuB': {'VaL3': 'v3'}}} - """ - values: dict[str, Any] = {} - - for name, value in field_values.items(): - sub_model_field: FieldInfo | None = None - - annotation = field.annotation - - # If field is Optional, we need to find the actual type - if is_union_origin(get_origin(field.annotation)): - args = get_args(annotation) - if len(args) == 2 and type(None) in args: - for arg in args: - if arg is not None: - annotation = arg - break - - # This is here to make mypy happy - # Item "None" of "Optional[Type[Any]]" has no attribute "model_fields" - if not annotation or not hasattr(annotation, 'model_fields'): - values[name] = value - continue - else: - model_fields: dict[str, FieldInfo] = annotation.model_fields - - # Find field in sub model by looking in fields case insensitively - field_key: str | None = None - for sub_model_field_name, sub_model_field in model_fields.items(): - aliases, _ = _get_alias_names(sub_model_field_name, sub_model_field) - _search = (alias for alias in aliases if alias.lower() == name.lower()) - if field_key := next(_search, None): - break - - if not field_key: - values[name] = value - continue - - if ( - sub_model_field is not None - and _lenient_issubclass(sub_model_field.annotation, BaseModel) - and isinstance(value, dict) - ): - values[field_key] = self._replace_field_names_case_insensitively(sub_model_field, value) - else: - values[field_key] = value - - return values - - def _replace_env_none_type_values(self, field_value: dict[str, Any]) -> dict[str, Any]: - """ - Recursively parse values that are of "None" type(EnvNoneType) to `None` type(None). - """ - values: dict[str, Any] = {} - - for key, value in field_value.items(): - if not isinstance(value, EnvNoneType): - values[key] = value if not isinstance(value, dict) else self._replace_env_none_type_values(value) - else: - values[key] = None - - return values - - def _get_resolved_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: - """ - Gets the value, the preferred alias key for model creation, and a flag to determine whether value - is complex. - - Note: - In V3, this method should either be made public, or, this method should be removed and the - abstract method get_field_value should be updated to include a "use_preferred_alias" flag. - - Args: - field: The field. - field_name: The field name. - - Returns: - A tuple that contains the value, preferred key and a flag to determine whether value is complex. 
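> Note: `_replace_env_none_type_values` above backs the `env_parse_none_str` option: a designated sentinel string coming from the environment is turned into a real `None` instead of failing validation. Roughly, for releases that support the option (field and sentinel names are placeholders):

```python
import os
from typing import Optional

from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_parse_none_str="null")

    timeout: Optional[int] = 30


# Without env_parse_none_str this would be a validation error; with it, "null" becomes None.
os.environ["TIMEOUT"] = "null"
print(Settings().timeout)  # None
```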
- """ - field_value, field_key, value_is_complex = self.get_field_value(field, field_name) - if not (value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name))): - field_infos = self._extract_field_info(field, field_name) - preferred_key, *_ = field_infos[0] - return field_value, preferred_key, value_is_complex - return field_value, field_key, value_is_complex - - def __call__(self) -> dict[str, Any]: - data: dict[str, Any] = {} - - for field_name, field in self.settings_cls.model_fields.items(): - try: - field_value, field_key, value_is_complex = self._get_resolved_field_value(field, field_name) - except Exception as e: - raise SettingsError( - f'error getting value for field "{field_name}" from source "{self.__class__.__name__}"' - ) from e - - try: - field_value = self.prepare_field_value(field_name, field, field_value, value_is_complex) - except ValueError as e: - raise SettingsError( - f'error parsing value for field "{field_name}" from source "{self.__class__.__name__}"' - ) from e - - if field_value is not None: - if self.env_parse_none_str is not None: - if isinstance(field_value, dict): - field_value = self._replace_env_none_type_values(field_value) - elif isinstance(field_value, EnvNoneType): - field_value = None - if ( - not self.case_sensitive - # and _lenient_issubclass(field.annotation, BaseModel) - and isinstance(field_value, dict) - ): - data[field_key] = self._replace_field_names_case_insensitively(field, field_value) - else: - data[field_key] = field_value - - return data - - -__all__ = [ - 'ConfigFileSourceMixin', - 'DefaultSettingsSource', - 'InitSettingsSource', - 'PydanticBaseEnvSettingsSource', - 'PydanticBaseSettingsSource', - 'SettingsError', -] diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/__init__.py b/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/__init__.py deleted file mode 100644 index 31759f3..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Package containing individual source implementations.""" - -from .aws import AWSSecretsManagerSettingsSource -from .azure import AzureKeyVaultSettingsSource -from .cli import ( - CliExplicitFlag, - CliImplicitFlag, - CliMutuallyExclusiveGroup, - CliPositionalArg, - CliSettingsSource, - CliSubCommand, - CliSuppress, -) -from .dotenv import DotEnvSettingsSource -from .env import EnvSettingsSource -from .gcp import GoogleSecretManagerSettingsSource -from .json import JsonConfigSettingsSource -from .pyproject import PyprojectTomlConfigSettingsSource -from .secrets import SecretsSettingsSource -from .toml import TomlConfigSettingsSource -from .yaml import YamlConfigSettingsSource - -__all__ = [ - 'AWSSecretsManagerSettingsSource', - 'AzureKeyVaultSettingsSource', - 'CliExplicitFlag', - 'CliImplicitFlag', - 'CliMutuallyExclusiveGroup', - 'CliPositionalArg', - 'CliSettingsSource', - 'CliSubCommand', - 'CliSuppress', - 'DotEnvSettingsSource', - 'EnvSettingsSource', - 'GoogleSecretManagerSettingsSource', - 'JsonConfigSettingsSource', - 'PyprojectTomlConfigSettingsSource', - 'SecretsSettingsSource', - 'TomlConfigSettingsSource', - 'YamlConfigSettingsSource', -] diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/aws.py b/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/aws.py deleted file mode 100644 index 5efa3f9..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/aws.py 
+++ /dev/null @@ -1,79 +0,0 @@ -from __future__ import annotations as _annotations # important for BaseSettings import to work - -import json -from collections.abc import Mapping -from typing import TYPE_CHECKING, Optional - -from ..utils import parse_env_vars -from .env import EnvSettingsSource - -if TYPE_CHECKING: - from pydantic_settings.main import BaseSettings - - -boto3_client = None -SecretsManagerClient = None - - -def import_aws_secrets_manager() -> None: - global boto3_client - global SecretsManagerClient - - try: - from boto3 import client as boto3_client - from mypy_boto3_secretsmanager.client import SecretsManagerClient - except ImportError as e: # pragma: no cover - raise ImportError( - 'AWS Secrets Manager dependencies are not installed, run `pip install pydantic-settings[aws-secrets-manager]`' - ) from e - - -class AWSSecretsManagerSettingsSource(EnvSettingsSource): - _secret_id: str - _secretsmanager_client: SecretsManagerClient # type: ignore - - def __init__( - self, - settings_cls: type[BaseSettings], - secret_id: str, - region_name: str | None = None, - endpoint_url: str | None = None, - case_sensitive: bool | None = True, - env_prefix: str | None = None, - env_nested_delimiter: str | None = '--', - env_parse_none_str: str | None = None, - env_parse_enums: bool | None = None, - ) -> None: - import_aws_secrets_manager() - self._secretsmanager_client = boto3_client('secretsmanager', region_name=region_name, endpoint_url=endpoint_url) # type: ignore - self._secret_id = secret_id - super().__init__( - settings_cls, - case_sensitive=case_sensitive, - env_prefix=env_prefix, - env_nested_delimiter=env_nested_delimiter, - env_ignore_empty=False, - env_parse_none_str=env_parse_none_str, - env_parse_enums=env_parse_enums, - ) - - def _load_env_vars(self) -> Mapping[str, Optional[str]]: - response = self._secretsmanager_client.get_secret_value(SecretId=self._secret_id) # type: ignore - - return parse_env_vars( - json.loads(response['SecretString']), - self.case_sensitive, - self.env_ignore_empty, - self.env_parse_none_str, - ) - - def __repr__(self) -> str: - return ( - f'{self.__class__.__name__}(secret_id={self._secret_id!r}, ' - f'env_nested_delimiter={self.env_nested_delimiter!r})' - ) - - -__all__ = [ - 'AWSSecretsManagerSettingsSource', -] diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/azure.py b/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/azure.py deleted file mode 100644 index c0c9506..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/azure.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Azure Key Vault settings source.""" - -from __future__ import annotations as _annotations - -from collections.abc import Iterator, Mapping -from typing import TYPE_CHECKING, Optional - -from pydantic.alias_generators import to_snake -from pydantic.fields import FieldInfo - -from .env import EnvSettingsSource - -if TYPE_CHECKING: - from azure.core.credentials import TokenCredential - from azure.core.exceptions import ResourceNotFoundError - from azure.keyvault.secrets import SecretClient - - from pydantic_settings.main import BaseSettings -else: - TokenCredential = None - ResourceNotFoundError = None - SecretClient = None - - -def import_azure_key_vault() -> None: - global TokenCredential - global SecretClient - global ResourceNotFoundError - - try: - from azure.core.credentials import TokenCredential - from azure.core.exceptions import ResourceNotFoundError - from azure.keyvault.secrets import 
SecretClient - except ImportError as e: # pragma: no cover - raise ImportError( - 'Azure Key Vault dependencies are not installed, run `pip install pydantic-settings[azure-key-vault]`' - ) from e - - -class AzureKeyVaultMapping(Mapping[str, Optional[str]]): - _loaded_secrets: dict[str, str | None] - _secret_client: SecretClient - _secret_names: list[str] - - def __init__( - self, - secret_client: SecretClient, - case_sensitive: bool, - snake_case_conversion: bool, - ) -> None: - self._loaded_secrets = {} - self._secret_client = secret_client - self._case_sensitive = case_sensitive - self._snake_case_conversion = snake_case_conversion - self._secret_map: dict[str, str] = self._load_remote() - - def _load_remote(self) -> dict[str, str]: - secret_names: Iterator[str] = ( - secret.name for secret in self._secret_client.list_properties_of_secrets() if secret.name and secret.enabled - ) - - if self._snake_case_conversion: - return {to_snake(name): name for name in secret_names} - - if self._case_sensitive: - return {name: name for name in secret_names} - - return {name.lower(): name for name in secret_names} - - def __getitem__(self, key: str) -> str | None: - new_key = key - - if self._snake_case_conversion: - new_key = to_snake(key) - elif not self._case_sensitive: - new_key = key.lower() - - if new_key not in self._loaded_secrets: - if new_key in self._secret_map: - self._loaded_secrets[new_key] = self._secret_client.get_secret(self._secret_map[new_key]).value - else: - raise KeyError(key) - - return self._loaded_secrets[new_key] - - def __len__(self) -> int: - return len(self._secret_map) - - def __iter__(self) -> Iterator[str]: - return iter(self._secret_map.keys()) - - -class AzureKeyVaultSettingsSource(EnvSettingsSource): - _url: str - _credential: TokenCredential - - def __init__( - self, - settings_cls: type[BaseSettings], - url: str, - credential: TokenCredential, - dash_to_underscore: bool = False, - case_sensitive: bool | None = None, - snake_case_conversion: bool = False, - env_prefix: str | None = None, - env_parse_none_str: str | None = None, - env_parse_enums: bool | None = None, - ) -> None: - import_azure_key_vault() - self._url = url - self._credential = credential - self._dash_to_underscore = dash_to_underscore - self._snake_case_conversion = snake_case_conversion - super().__init__( - settings_cls, - case_sensitive=False if snake_case_conversion else case_sensitive, - env_prefix=env_prefix, - env_nested_delimiter='__' if snake_case_conversion else '--', - env_ignore_empty=False, - env_parse_none_str=env_parse_none_str, - env_parse_enums=env_parse_enums, - ) - - def _load_env_vars(self) -> Mapping[str, Optional[str]]: - secret_client = SecretClient(vault_url=self._url, credential=self._credential) - return AzureKeyVaultMapping( - secret_client=secret_client, - case_sensitive=self.case_sensitive, - snake_case_conversion=self._snake_case_conversion, - ) - - def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]: - if self._snake_case_conversion: - return list((x[0], x[0], x[2]) for x in super()._extract_field_info(field, field_name)) - - if self._dash_to_underscore: - return list((x[0], x[1].replace('_', '-'), x[2]) for x in super()._extract_field_info(field, field_name)) - - return super()._extract_field_info(field, field_name) - - def __repr__(self) -> str: - return f'{self.__class__.__name__}(url={self._url!r}, env_nested_delimiter={self.env_nested_delimiter!r})' - - -__all__ = ['AzureKeyVaultMapping', 'AzureKeyVaultSettingsSource'] 
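> Note: the AWS and Azure providers deleted above are both thin `EnvSettingsSource` subclasses whose "environment" is a remote secret store. For reference, wiring the Azure variant in looks roughly like the following sketch; it assumes `pip install pydantic-settings[azure-key-vault]`, a reachable vault, and working credentials, and the vault URL and field name are placeholders:

```python
from azure.identity import DefaultAzureCredential
from pydantic_settings import (
    AzureKeyVaultSettingsSource,
    BaseSettings,
    PydanticBaseSettingsSource,
)


class Settings(BaseSettings):
    # Resolved from the Key Vault secret named "password" (matched case-insensitively).
    password: str

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        return (
            AzureKeyVaultSettingsSource(
                settings_cls,
                "https://example-vault.vault.azure.net",  # placeholder vault URL
                DefaultAzureCredential(),
            ),
        )


# Settings() would then fetch the secret lazily from the vault at instantiation time.
```

The removed `AWSSecretsManagerSettingsSource` follows the same shape, taking a `secret_id` (and optional region/endpoint) instead of a vault URL and credential.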
diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/cli.py b/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/cli.py deleted file mode 100644 index 87d0d58..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/cli.py +++ /dev/null @@ -1,1331 +0,0 @@ -"""Command-line interface settings source.""" - -from __future__ import annotations as _annotations - -import json -import re -import shlex -import sys -import typing -from argparse import ( - SUPPRESS, - ArgumentParser, - BooleanOptionalAction, - Namespace, - RawDescriptionHelpFormatter, - _SubParsersAction, -) -from collections import defaultdict -from collections.abc import Mapping, Sequence -from enum import Enum -from functools import cached_property -from textwrap import dedent -from types import SimpleNamespace -from typing import ( - TYPE_CHECKING, - Annotated, - Any, - Callable, - Generic, - NoReturn, - Optional, - TypeVar, - Union, - cast, - overload, -) - -import typing_extensions -from pydantic import AliasChoices, AliasPath, BaseModel, Field, PrivateAttr -from pydantic._internal._repr import Representation -from pydantic._internal._utils import is_model_class -from pydantic.dataclasses import is_pydantic_dataclass -from pydantic.fields import FieldInfo -from pydantic_core import PydanticUndefined -from typing_extensions import get_args, get_origin -from typing_inspection import typing_objects -from typing_inspection.introspection import is_union_origin - -from ...exceptions import SettingsError -from ...utils import _lenient_issubclass, _WithArgsTypes -from ..types import ( - ForceDecode, - NoDecode, - PydanticModel, - _CliExplicitFlag, - _CliImplicitFlag, - _CliPositionalArg, - _CliSubCommand, - _CliUnknownArgs, -) -from ..utils import ( - _annotation_contains_types, - _annotation_enum_val_to_name, - _get_alias_names, - _get_model_fields, - _is_function, - _strip_annotated, - parse_env_vars, -) -from .env import EnvSettingsSource - -if TYPE_CHECKING: - from pydantic_settings.main import BaseSettings - - -class _CliInternalArgParser(ArgumentParser): - def __init__(self, cli_exit_on_error: bool = True, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._cli_exit_on_error = cli_exit_on_error - - def error(self, message: str) -> NoReturn: - if not self._cli_exit_on_error: - raise SettingsError(f'error parsing CLI: {message}') - super().error(message) - - -class CliMutuallyExclusiveGroup(BaseModel): - pass - - -class _CliArg(BaseModel): - model: Any - field_name: str - arg_prefix: str - case_sensitive: bool - hide_none_type: bool - kebab_case: bool - enable_decoding: Optional[bool] - env_prefix_len: int - args: list[str] = [] - kwargs: dict[str, Any] = {} - - _alias_names: tuple[str, ...] 
= PrivateAttr(()) - _alias_paths: dict[str, Optional[int]] = PrivateAttr({}) - _is_alias_path_only: bool = PrivateAttr(False) - _field_info: FieldInfo = PrivateAttr() - - def __init__( - self, - field_info: FieldInfo, - parser_map: defaultdict[str | FieldInfo, dict[Optional[int] | str, _CliArg]], - **values: Any, - ) -> None: - super().__init__(**values) - self._field_info = field_info - self._alias_names, self._is_alias_path_only = _get_alias_names( - self.field_name, self.field_info, alias_path_args=self._alias_paths, case_sensitive=self.case_sensitive - ) - - alias_path_dests = {f'{self.arg_prefix}{name}': index for name, index in self._alias_paths.items()} - if self.subcommand_dest: - for sub_model in self.sub_models: - subcommand_alias = self.subcommand_alias(sub_model) - parser_map[self.subcommand_dest][subcommand_alias] = self.model_copy(update={'args': [], 'kwargs': {}}) - parser_map[self.field_info][subcommand_alias] = parser_map[self.subcommand_dest][subcommand_alias] - elif self.dest not in alias_path_dests: - parser_map[self.dest][None] = self - parser_map[self.field_info][None] = parser_map[self.dest][None] - for alias_path_dest, index in alias_path_dests.items(): - parser_map[alias_path_dest][index] = self.model_copy(update={'args': [], 'kwargs': {}}) - parser_map[self.field_info][index] = parser_map[alias_path_dest][index] - - @classmethod - def get_kebab_case(cls, name: str, kebab_case: Optional[bool]) -> str: - return name.replace('_', '-') if kebab_case else name - - def subcommand_alias(self, sub_model: type[BaseModel]) -> str: - return self.get_kebab_case( - sub_model.__name__ if len(self.sub_models) > 1 else self.preferred_alias, self.kebab_case - ) - - @cached_property - def field_info(self) -> FieldInfo: - return self._field_info - - @cached_property - def subcommand_dest(self) -> Optional[str]: - return f'{self.arg_prefix}:subcommand' if _CliSubCommand in self.field_info.metadata else None - - @cached_property - def dest(self) -> str: - if ( - not self.subcommand_dest - and self.arg_prefix - and self.field_info.validation_alias is not None - and not self.is_parser_submodel - ): - # Strip prefix if validation alias is set and value is not complex. - # Related https://github.com/pydantic/pydantic-settings/pull/25 - return f'{self.arg_prefix}{self.preferred_alias}'[self.env_prefix_len :] - return f'{self.arg_prefix}{self.preferred_alias}' - - @cached_property - def preferred_arg_name(self) -> str: - return self.args[0].replace('_', '-') if self.kebab_case else self.args[0] - - @cached_property - def sub_models(self) -> list[type[BaseModel]]: - field_types: tuple[Any, ...] 
= ( - (self.field_info.annotation,) - if not get_args(self.field_info.annotation) - else get_args(self.field_info.annotation) - ) - if self.hide_none_type: - field_types = tuple([type_ for type_ in field_types if type_ is not type(None)]) - - sub_models: list[type[BaseModel]] = [] - for type_ in field_types: - if _annotation_contains_types(type_, (_CliSubCommand,), is_include_origin=False): - raise SettingsError( - f'CliSubCommand is not outermost annotation for {self.model.__name__}.{self.field_name}' - ) - elif _annotation_contains_types(type_, (_CliPositionalArg,), is_include_origin=False): - raise SettingsError( - f'CliPositionalArg is not outermost annotation for {self.model.__name__}.{self.field_name}' - ) - if is_model_class(_strip_annotated(type_)) or is_pydantic_dataclass(_strip_annotated(type_)): - sub_models.append(_strip_annotated(type_)) - return sub_models - - @cached_property - def alias_names(self) -> tuple[str, ...]: - return self._alias_names - - @cached_property - def alias_paths(self) -> dict[str, Optional[int]]: - return self._alias_paths - - @cached_property - def preferred_alias(self) -> str: - return self._alias_names[0] - - @cached_property - def is_alias_path_only(self) -> bool: - return self._is_alias_path_only - - @cached_property - def is_append_action(self) -> bool: - return not self.subcommand_dest and _annotation_contains_types( - self.field_info.annotation, (list, set, dict, Sequence, Mapping), is_strip_annotated=True - ) - - @cached_property - def is_parser_submodel(self) -> bool: - return not self.subcommand_dest and bool(self.sub_models) and not self.is_append_action - - @cached_property - def is_no_decode(self) -> bool: - return self.field_info is not None and ( - NoDecode in self.field_info.metadata - or (self.enable_decoding is False and ForceDecode not in self.field_info.metadata) - ) - - -T = TypeVar('T') -CliSubCommand = Annotated[Union[T, None], _CliSubCommand] -CliPositionalArg = Annotated[T, _CliPositionalArg] -_CliBoolFlag = TypeVar('_CliBoolFlag', bound=bool) -CliImplicitFlag = Annotated[_CliBoolFlag, _CliImplicitFlag] -CliExplicitFlag = Annotated[_CliBoolFlag, _CliExplicitFlag] -CLI_SUPPRESS = SUPPRESS -CliSuppress = Annotated[T, CLI_SUPPRESS] -CliUnknownArgs = Annotated[list[str], Field(default=[]), _CliUnknownArgs, NoDecode] - - -class CliSettingsSource(EnvSettingsSource, Generic[T]): - """ - Source class for loading settings values from CLI. - - Note: - A `CliSettingsSource` connects with a `root_parser` object by using the parser methods to add - `settings_cls` fields as command line arguments. The `CliSettingsSource` internal parser representation - is based upon the `argparse` parsing library, and therefore, requires the parser methods to support - the same attributes as their `argparse` library counterparts. - - Args: - cli_prog_name: The CLI program name to display in help text. Defaults to `None` if cli_parse_args is `None`. - Otherwise, defaults to sys.argv[0]. - cli_parse_args: The list of CLI arguments to parse. Defaults to None. - If set to `True`, defaults to sys.argv[1:]. - cli_parse_none_str: The CLI string value that should be parsed (e.g. "null", "void", "None", etc.) into `None` - type(None). Defaults to "null" if cli_avoid_json is `False`, and "None" if cli_avoid_json is `True`. - cli_hide_none_type: Hide `None` values in CLI help text. Defaults to `False`. - cli_avoid_json: Avoid complex JSON objects in CLI help text. Defaults to `False`. - cli_enforce_required: Enforce required fields at the CLI. Defaults to `False`. 
- cli_use_class_docs_for_groups: Use class docstrings in CLI group help text instead of field descriptions. - Defaults to `False`. - cli_exit_on_error: Determines whether or not the internal parser exits with error info when an error occurs. - Defaults to `True`. - cli_prefix: Prefix for command line arguments added under the root parser. Defaults to "". - cli_flag_prefix_char: The flag prefix character to use for CLI optional arguments. Defaults to '-'. - cli_implicit_flags: Whether `bool` fields should be implicitly converted into CLI boolean flags. - (e.g. --flag, --no-flag). Defaults to `False`. - cli_ignore_unknown_args: Whether to ignore unknown CLI args and parse only known ones. Defaults to `False`. - cli_kebab_case: CLI args use kebab case. Defaults to `False`. - cli_shortcuts: Mapping of target field name to alias names. Defaults to `None`. - case_sensitive: Whether CLI "--arg" names should be read with case-sensitivity. Defaults to `True`. - Note: Case-insensitive matching is only supported on the internal root parser and does not apply to CLI - subcommands. - root_parser: The root parser object. - parse_args_method: The root parser parse args method. Defaults to `argparse.ArgumentParser.parse_args`. - add_argument_method: The root parser add argument method. Defaults to `argparse.ArgumentParser.add_argument`. - add_argument_group_method: The root parser add argument group method. - Defaults to `argparse.ArgumentParser.add_argument_group`. - add_parser_method: The root parser add new parser (sub-command) method. - Defaults to `argparse._SubParsersAction.add_parser`. - add_subparsers_method: The root parser add subparsers (sub-commands) method. - Defaults to `argparse.ArgumentParser.add_subparsers`. - formatter_class: A class for customizing the root parser help text. Defaults to `argparse.RawDescriptionHelpFormatter`. - """ - - def __init__( - self, - settings_cls: type[BaseSettings], - cli_prog_name: str | None = None, - cli_parse_args: bool | list[str] | tuple[str, ...] 
| None = None, - cli_parse_none_str: str | None = None, - cli_hide_none_type: bool | None = None, - cli_avoid_json: bool | None = None, - cli_enforce_required: bool | None = None, - cli_use_class_docs_for_groups: bool | None = None, - cli_exit_on_error: bool | None = None, - cli_prefix: str | None = None, - cli_flag_prefix_char: str | None = None, - cli_implicit_flags: bool | None = None, - cli_ignore_unknown_args: bool | None = None, - cli_kebab_case: bool | None = None, - cli_shortcuts: Mapping[str, str | list[str]] | None = None, - case_sensitive: bool | None = True, - root_parser: Any = None, - parse_args_method: Callable[..., Any] | None = None, - add_argument_method: Callable[..., Any] | None = ArgumentParser.add_argument, - add_argument_group_method: Callable[..., Any] | None = ArgumentParser.add_argument_group, - add_parser_method: Callable[..., Any] | None = _SubParsersAction.add_parser, - add_subparsers_method: Callable[..., Any] | None = ArgumentParser.add_subparsers, - formatter_class: Any = RawDescriptionHelpFormatter, - ) -> None: - self.cli_prog_name = ( - cli_prog_name if cli_prog_name is not None else settings_cls.model_config.get('cli_prog_name', sys.argv[0]) - ) - self.cli_hide_none_type = ( - cli_hide_none_type - if cli_hide_none_type is not None - else settings_cls.model_config.get('cli_hide_none_type', False) - ) - self.cli_avoid_json = ( - cli_avoid_json if cli_avoid_json is not None else settings_cls.model_config.get('cli_avoid_json', False) - ) - if not cli_parse_none_str: - cli_parse_none_str = 'None' if self.cli_avoid_json is True else 'null' - self.cli_parse_none_str = cli_parse_none_str - self.cli_enforce_required = ( - cli_enforce_required - if cli_enforce_required is not None - else settings_cls.model_config.get('cli_enforce_required', False) - ) - self.cli_use_class_docs_for_groups = ( - cli_use_class_docs_for_groups - if cli_use_class_docs_for_groups is not None - else settings_cls.model_config.get('cli_use_class_docs_for_groups', False) - ) - self.cli_exit_on_error = ( - cli_exit_on_error - if cli_exit_on_error is not None - else settings_cls.model_config.get('cli_exit_on_error', True) - ) - self.cli_prefix = cli_prefix if cli_prefix is not None else settings_cls.model_config.get('cli_prefix', '') - self.cli_flag_prefix_char = ( - cli_flag_prefix_char - if cli_flag_prefix_char is not None - else settings_cls.model_config.get('cli_flag_prefix_char', '-') - ) - self._cli_flag_prefix = self.cli_flag_prefix_char * 2 - if self.cli_prefix: - if cli_prefix.startswith('.') or cli_prefix.endswith('.') or not cli_prefix.replace('.', '').isidentifier(): # type: ignore - raise SettingsError(f'CLI settings source prefix is invalid: {cli_prefix}') - self.cli_prefix += '.' 
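
As a side note on the option resolution above: every one of these flags can also be supplied through `model_config`, which is where the fallbacks in this constructor read from. A minimal sketch, assuming a recent pydantic-settings v2 and hypothetical field/program names, of a settings class opting into this CLI source declaratively:

    from pydantic_settings import BaseSettings, SettingsConfigDict

    class AppSettings(BaseSettings):
        """Example application settings."""

        model_config = SettingsConfigDict(
            cli_parse_args=True,      # parse sys.argv[1:] when AppSettings() is built
            cli_prog_name='my-app',   # hypothetical program name shown in --help
            cli_kebab_case=True,      # expose log_level as --log-level
            cli_implicit_flags=True,  # bool fields become --dry-run / --no-dry-run
        )

        log_level: str = 'INFO'
        dry_run: bool = False

    if __name__ == '__main__':
        print(AppSettings().model_dump())

Invoked as `python app.py --log-level DEBUG --dry-run`, the parsed values are fed through the same environment-variable machinery this class inherits from `EnvSettingsSource`.
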
- self.cli_implicit_flags = ( - cli_implicit_flags - if cli_implicit_flags is not None - else settings_cls.model_config.get('cli_implicit_flags', False) - ) - self.cli_ignore_unknown_args = ( - cli_ignore_unknown_args - if cli_ignore_unknown_args is not None - else settings_cls.model_config.get('cli_ignore_unknown_args', False) - ) - self.cli_kebab_case = ( - cli_kebab_case if cli_kebab_case is not None else settings_cls.model_config.get('cli_kebab_case', False) - ) - self.cli_shortcuts = ( - cli_shortcuts if cli_shortcuts is not None else settings_cls.model_config.get('cli_shortcuts', None) - ) - - case_sensitive = case_sensitive if case_sensitive is not None else True - if not case_sensitive and root_parser is not None: - raise SettingsError('Case-insensitive matching is only supported on the internal root parser') - - super().__init__( - settings_cls, - env_nested_delimiter='.', - env_parse_none_str=self.cli_parse_none_str, - env_parse_enums=True, - env_prefix=self.cli_prefix, - case_sensitive=case_sensitive, - ) - - root_parser = ( - _CliInternalArgParser( - cli_exit_on_error=self.cli_exit_on_error, - prog=self.cli_prog_name, - description=None if settings_cls.__doc__ is None else dedent(settings_cls.__doc__), - formatter_class=formatter_class, - prefix_chars=self.cli_flag_prefix_char, - allow_abbrev=False, - ) - if root_parser is None - else root_parser - ) - self._connect_root_parser( - root_parser=root_parser, - parse_args_method=parse_args_method, - add_argument_method=add_argument_method, - add_argument_group_method=add_argument_group_method, - add_parser_method=add_parser_method, - add_subparsers_method=add_subparsers_method, - formatter_class=formatter_class, - ) - - if cli_parse_args not in (None, False): - if cli_parse_args is True: - cli_parse_args = sys.argv[1:] - elif not isinstance(cli_parse_args, (list, tuple)): - raise SettingsError( - f'cli_parse_args must be a list or tuple of strings, received {type(cli_parse_args)}' - ) - self._load_env_vars(parsed_args=self._parse_args(self.root_parser, cli_parse_args)) - - @overload - def __call__(self) -> dict[str, Any]: ... - - @overload - def __call__(self, *, args: list[str] | tuple[str, ...] | bool) -> CliSettingsSource[T]: - """ - Parse and load the command line arguments list into the CLI settings source. - - Args: - args: - The command line arguments to parse and load. Defaults to `None`, which means do not parse - command line arguments. If set to `True`, defaults to sys.argv[1:]. If set to `False`, does - not parse command line arguments. - - Returns: - CliSettingsSource: The object instance itself. - """ - ... - - @overload - def __call__(self, *, parsed_args: Namespace | SimpleNamespace | dict[str, Any]) -> CliSettingsSource[T]: - """ - Loads parsed command line arguments into the CLI settings source. - - Note: - The parsed args must be in `argparse.Namespace`, `SimpleNamespace`, or vars dictionary - (e.g., vars(argparse.Namespace)) format. - - Args: - parsed_args: The parsed args to load. - - Returns: - CliSettingsSource: The object instance itself. - """ - ... - - def __call__( - self, - *, - args: list[str] | tuple[str, ...] 
| bool | None = None, - parsed_args: Namespace | SimpleNamespace | dict[str, list[str] | str] | None = None, - ) -> dict[str, Any] | CliSettingsSource[T]: - if args is not None and parsed_args is not None: - raise SettingsError('`args` and `parsed_args` are mutually exclusive') - elif args is not None: - if args is False: - return self._load_env_vars(parsed_args={}) - if args is True: - args = sys.argv[1:] - return self._load_env_vars(parsed_args=self._parse_args(self.root_parser, args)) - elif parsed_args is not None: - return self._load_env_vars(parsed_args=parsed_args) - else: - return super().__call__() - - @overload - def _load_env_vars(self) -> Mapping[str, str | None]: ... - - @overload - def _load_env_vars(self, *, parsed_args: Namespace | SimpleNamespace | dict[str, Any]) -> CliSettingsSource[T]: - """ - Loads the parsed command line arguments into the CLI environment settings variables. - - Note: - The parsed args must be in `argparse.Namespace`, `SimpleNamespace`, or vars dictionary - (e.g., vars(argparse.Namespace)) format. - - Args: - parsed_args: The parsed args to load. - - Returns: - CliSettingsSource: The object instance itself. - """ - ... - - def _load_env_vars( - self, *, parsed_args: Namespace | SimpleNamespace | dict[str, list[str] | str] | None = None - ) -> Mapping[str, str | None] | CliSettingsSource[T]: - if parsed_args is None: - return {} - - if isinstance(parsed_args, (Namespace, SimpleNamespace)): - parsed_args = vars(parsed_args) - - selected_subcommands: list[str] = [] - for field_name, val in list(parsed_args.items()): - if isinstance(val, list): - if self._is_nested_alias_path_only_workaround(parsed_args, field_name, val): - # Workaround for nested alias path environment variables not being handled. - # See https://github.com/pydantic/pydantic-settings/issues/670 - continue - - cli_arg = self._parser_map.get(field_name, {}).get(None) - if cli_arg and cli_arg.is_no_decode: - parsed_args[field_name] = ','.join(val) - continue - - parsed_args[field_name] = self._merge_parsed_list(val, field_name) - elif field_name.endswith(':subcommand') and val is not None: - selected_subcommands.append(self._parser_map[field_name][val].dest) - - for arg_dest, arg_map in self._parser_map.items(): - if isinstance(arg_dest, str) and arg_dest.endswith(':subcommand'): - for subcommand_dest in [arg.dest for arg in arg_map.values()]: - if subcommand_dest not in selected_subcommands: - parsed_args[subcommand_dest] = self.cli_parse_none_str - - parsed_args = { - key: val - for key, val in parsed_args.items() - if not key.endswith(':subcommand') and val is not PydanticUndefined - } - if selected_subcommands: - last_selected_subcommand = max(selected_subcommands, key=len) - if not any(field_name for field_name in parsed_args.keys() if f'{last_selected_subcommand}.' in field_name): - parsed_args[last_selected_subcommand] = '{}' - - parsed_args.update(self._cli_unknown_args) - - self.env_vars = parse_env_vars( - cast(Mapping[str, str], parsed_args), - self.case_sensitive, - self.env_ignore_empty, - self.cli_parse_none_str, - ) - - return self - - def _is_nested_alias_path_only_workaround( - self, parsed_args: dict[str, list[str] | str], field_name: str, val: list[str] - ) -> bool: - """ - Workaround for nested alias path environment variables not being handled. 
- See https://github.com/pydantic/pydantic-settings/issues/670 - """ - known_arg = self._parser_map.get(field_name, {}).values() - if not known_arg: - return False - arg = next(iter(known_arg)) - if arg.is_alias_path_only and arg.arg_prefix.endswith('.'): - del parsed_args[field_name] - nested_dest = arg.arg_prefix[:-1] - nested_val = f'"{arg.preferred_alias}": {self._merge_parsed_list(val, field_name)}' - parsed_args[nested_dest] = ( - f'{{{nested_val}}}' - if nested_dest not in parsed_args - else f'{parsed_args[nested_dest][:-1]}, {nested_val}}}' - ) - return True - return False - - def _get_merge_parsed_list_types( - self, parsed_list: list[str], field_name: str - ) -> tuple[Optional[type], Optional[type]]: - merge_type = self._cli_dict_args.get(field_name, list) - if ( - merge_type is list - or not is_union_origin(get_origin(merge_type)) - or not any( - type_ - for type_ in get_args(merge_type) - if type_ is not type(None) and get_origin(type_) not in (dict, Mapping) - ) - ): - inferred_type = merge_type - else: - inferred_type = list if parsed_list and (len(parsed_list) > 1 or parsed_list[0].startswith('[')) else str - - return merge_type, inferred_type - - def _merged_list_to_str(self, merged_list: list[str], field_name: str) -> str: - decode_list: list[str] = [] - is_use_decode: Optional[bool] = None - cli_arg_map = self._parser_map.get(field_name, {}) - for index, item in enumerate(merged_list): - cli_arg = cli_arg_map.get(index) - is_decode = cli_arg is None or not cli_arg.is_no_decode - if is_use_decode is None: - is_use_decode = is_decode - elif is_use_decode != is_decode: - raise SettingsError('Mixing Decode and NoDecode across different AliasPath fields is not allowed') - if is_use_decode: - item = item.replace('\\', '\\\\') - elif item.startswith('"') and item.endswith('"'): - item = item[1:-1] - decode_list.append(item) - merged_list_str = ','.join(decode_list) - return f'[{merged_list_str}]' if is_use_decode else merged_list_str - - def _merge_parsed_list(self, parsed_list: list[str], field_name: str) -> str: - try: - merged_list: list[str] = [] - is_last_consumed_a_value = False - merge_type, inferred_type = self._get_merge_parsed_list_types(parsed_list, field_name) - for val in parsed_list: - if not isinstance(val, str): - # If val is not a string, it's from an external parser and we can ignore parsing the rest of the - # list. 
- break - val = val.strip() - if val.startswith('[') and val.endswith(']'): - val = val[1:-1].strip() - while val: - val = val.strip() - if val.startswith(','): - val = self._consume_comma(val, merged_list, is_last_consumed_a_value) - is_last_consumed_a_value = False - else: - if val.startswith('{') or val.startswith('['): - val = self._consume_object_or_array(val, merged_list) - else: - try: - val = self._consume_string_or_number(val, merged_list, merge_type) - except ValueError as e: - if merge_type is inferred_type: - raise e - merge_type = inferred_type - val = self._consume_string_or_number(val, merged_list, merge_type) - is_last_consumed_a_value = True - if not is_last_consumed_a_value: - val = self._consume_comma(val, merged_list, is_last_consumed_a_value) - - if merge_type is str: - return merged_list[0] - elif merge_type is list: - return self._merged_list_to_str(merged_list, field_name) - else: - merged_dict: dict[str, str] = {} - for item in merged_list: - merged_dict.update(json.loads(item)) - return json.dumps(merged_dict) - except Exception as e: - raise SettingsError(f'Parsing error encountered for {field_name}: {e}') - - def _consume_comma(self, item: str, merged_list: list[str], is_last_consumed_a_value: bool) -> str: - if not is_last_consumed_a_value: - merged_list.append('""') - return item[1:] - - def _consume_object_or_array(self, item: str, merged_list: list[str]) -> str: - count = 1 - close_delim = '}' if item.startswith('{') else ']' - in_str = False - for consumed in range(1, len(item)): - if item[consumed] == '"' and item[consumed - 1] != '\\': - in_str = not in_str - elif in_str: - continue - elif item[consumed] in ('{', '['): - count += 1 - elif item[consumed] in ('}', ']'): - count -= 1 - if item[consumed] == close_delim and count == 0: - merged_list.append(item[: consumed + 1]) - return item[consumed + 1 :] - raise SettingsError(f'Missing end delimiter "{close_delim}"') - - def _consume_string_or_number(self, item: str, merged_list: list[str], merge_type: type[Any] | None) -> str: - consumed = 0 if merge_type is not str else len(item) - is_find_end_quote = False - while consumed < len(item): - if item[consumed] == '"' and (consumed == 0 or item[consumed - 1] != '\\'): - is_find_end_quote = not is_find_end_quote - if not is_find_end_quote and item[consumed] == ',': - break - consumed += 1 - if is_find_end_quote: - raise SettingsError('Mismatched quotes') - val_string = item[:consumed].strip() - if merge_type in (list, str): - try: - float(val_string) - except ValueError: - if val_string == self.cli_parse_none_str: - val_string = 'null' - if val_string not in ('true', 'false', 'null') and not val_string.startswith('"'): - val_string = f'"{val_string}"' - merged_list.append(val_string) - else: - key, val = (kv for kv in val_string.split('=', 1)) - if key.startswith('"') and not key.endswith('"') and not val.startswith('"') and val.endswith('"'): - raise ValueError(f'Dictionary key=val parameter is a quoted string: {val_string}') - key, val = key.strip('"'), val.strip('"') - merged_list.append(json.dumps({key: val})) - return item[consumed:] - - def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, field_info: FieldInfo) -> None: - if _CliImplicitFlag in field_info.metadata: - cli_flag_name = 'CliImplicitFlag' - elif _CliExplicitFlag in field_info.metadata: - cli_flag_name = 'CliExplicitFlag' - else: - return - - if field_info.annotation is not bool: - raise SettingsError(f'{cli_flag_name} argument {model.__name__}.{field_name} is not of 
type bool') - - def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]]: - positional_variadic_arg = [] - positional_args, subcommand_args, optional_args = [], [], [] - for field_name, field_info in _get_model_fields(model).items(): - if _CliSubCommand in field_info.metadata: - if not field_info.is_required(): - raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has a default value') - else: - alias_names, *_ = _get_alias_names(field_name, field_info) - if len(alias_names) > 1: - raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple aliases') - field_types = [type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)] - for field_type in field_types: - if not (is_model_class(field_type) or is_pydantic_dataclass(field_type)): - raise SettingsError( - f'subcommand argument {model.__name__}.{field_name} has type not derived from BaseModel' - ) - subcommand_args.append((field_name, field_info)) - elif _CliPositionalArg in field_info.metadata: - alias_names, *_ = _get_alias_names(field_name, field_info) - if len(alias_names) > 1: - raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases') - is_append_action = _annotation_contains_types( - field_info.annotation, (list, set, dict, Sequence, Mapping), is_strip_annotated=True - ) - if not is_append_action: - positional_args.append((field_name, field_info)) - else: - positional_variadic_arg.append((field_name, field_info)) - else: - self._verify_cli_flag_annotations(model, field_name, field_info) - optional_args.append((field_name, field_info)) - - if positional_variadic_arg: - if len(positional_variadic_arg) > 1: - field_names = ', '.join([name for name, info in positional_variadic_arg]) - raise SettingsError(f'{model.__name__} has multiple variadic positional arguments: {field_names}') - elif subcommand_args: - field_names = ', '.join([name for name, info in positional_variadic_arg + subcommand_args]) - raise SettingsError( - f'{model.__name__} has variadic positional arguments and subcommand arguments: {field_names}' - ) - - return positional_args + positional_variadic_arg + subcommand_args + optional_args - - @property - def root_parser(self) -> T: - """The connected root parser instance.""" - return self._root_parser - - def _connect_parser_method( - self, parser_method: Callable[..., Any] | None, method_name: str, *args: Any, **kwargs: Any - ) -> Callable[..., Any]: - if ( - parser_method is not None - and self.case_sensitive is False - and method_name == 'parse_args_method' - and isinstance(self._root_parser, _CliInternalArgParser) - ): - - def parse_args_insensitive_method( - root_parser: _CliInternalArgParser, - args: list[str] | tuple[str, ...] 
| None = None, - namespace: Namespace | None = None, - ) -> Any: - insensitive_args = [] - for arg in shlex.split(shlex.join(args)) if args else []: - flag_prefix = rf'\{self.cli_flag_prefix_char}{{1,2}}' - matched = re.match(rf'^({flag_prefix}[^\s=]+)(.*)', arg) - if matched: - arg = matched.group(1).lower() + matched.group(2) - insensitive_args.append(arg) - return parser_method(root_parser, insensitive_args, namespace) - - return parse_args_insensitive_method - - elif parser_method is None: - - def none_parser_method(*args: Any, **kwargs: Any) -> Any: - raise SettingsError( - f'cannot connect CLI settings source root parser: {method_name} is set to `None` but is needed for connecting' - ) - - return none_parser_method - - else: - return parser_method - - def _connect_group_method(self, add_argument_group_method: Callable[..., Any] | None) -> Callable[..., Any]: - add_argument_group = self._connect_parser_method(add_argument_group_method, 'add_argument_group_method') - - def add_group_method(parser: Any, **kwargs: Any) -> Any: - if not kwargs.pop('_is_cli_mutually_exclusive_group'): - kwargs.pop('required') - return add_argument_group(parser, **kwargs) - else: - main_group_kwargs = {arg: kwargs.pop(arg) for arg in ['title', 'description'] if arg in kwargs} - main_group_kwargs['title'] += ' (mutually exclusive)' - group = add_argument_group(parser, **main_group_kwargs) - if not hasattr(group, 'add_mutually_exclusive_group'): - raise SettingsError( - 'cannot connect CLI settings source root parser: ' - 'group object is missing add_mutually_exclusive_group but is needed for connecting' - ) - return group.add_mutually_exclusive_group(**kwargs) - - return add_group_method - - def _connect_root_parser( - self, - root_parser: T, - parse_args_method: Callable[..., Any] | None, - add_argument_method: Callable[..., Any] | None = ArgumentParser.add_argument, - add_argument_group_method: Callable[..., Any] | None = ArgumentParser.add_argument_group, - add_parser_method: Callable[..., Any] | None = _SubParsersAction.add_parser, - add_subparsers_method: Callable[..., Any] | None = ArgumentParser.add_subparsers, - formatter_class: Any = RawDescriptionHelpFormatter, - ) -> None: - self._cli_unknown_args: dict[str, list[str]] = {} - - def _parse_known_args(*args: Any, **kwargs: Any) -> Namespace: - args, unknown_args = ArgumentParser.parse_known_args(*args, **kwargs) - for dest in self._cli_unknown_args: - self._cli_unknown_args[dest] = unknown_args - return cast(Namespace, args) - - self._root_parser = root_parser - if parse_args_method is None: - parse_args_method = _parse_known_args if self.cli_ignore_unknown_args else ArgumentParser.parse_args - self._parse_args = self._connect_parser_method(parse_args_method, 'parse_args_method') - self._add_argument = self._connect_parser_method(add_argument_method, 'add_argument_method') - self._add_group = self._connect_group_method(add_argument_group_method) - self._add_parser = self._connect_parser_method(add_parser_method, 'add_parser_method') - self._add_subparsers = self._connect_parser_method(add_subparsers_method, 'add_subparsers_method') - self._formatter_class = formatter_class - self._cli_dict_args: dict[str, type[Any] | None] = {} - self._parser_map: defaultdict[str | FieldInfo, dict[Optional[int] | str, _CliArg]] = defaultdict(dict) - self._add_parser_args( - parser=self.root_parser, - model=self.settings_cls, - added_args=[], - arg_prefix=self.env_prefix, - subcommand_prefix=self.env_prefix, - group=None, - alias_prefixes=[], - 
model_default=PydanticUndefined, - ) - - def _add_parser_args( - self, - parser: Any, - model: type[BaseModel], - added_args: list[str], - arg_prefix: str, - subcommand_prefix: str, - group: Any, - alias_prefixes: list[str], - model_default: Any, - is_model_suppressed: bool = False, - ) -> ArgumentParser: - subparsers: Any = None - alias_path_args: dict[str, Optional[int]] = {} - # Ignore model default if the default is a model and not a subclass of the current model. - model_default = ( - None - if ( - (is_model_class(type(model_default)) or is_pydantic_dataclass(type(model_default))) - and not issubclass(type(model_default), model) - ) - else model_default - ) - for field_name, field_info in self._sort_arg_fields(model): - arg = _CliArg( - field_info=field_info, - parser_map=self._parser_map, - model=model, - field_name=field_name, - arg_prefix=arg_prefix, - case_sensitive=self.case_sensitive, - hide_none_type=self.cli_hide_none_type, - kebab_case=self.cli_kebab_case, - enable_decoding=self.config.get('enable_decoding'), - env_prefix_len=self.env_prefix_len, - ) - alias_path_args.update(arg.alias_paths) - - if arg.subcommand_dest: - for sub_model in arg.sub_models: - subcommand_alias = arg.subcommand_alias(sub_model) - subcommand_arg = self._parser_map[arg.subcommand_dest][subcommand_alias] - subcommand_arg.args = [subcommand_alias] - subcommand_arg.kwargs['allow_abbrev'] = False - subcommand_arg.kwargs['formatter_class'] = self._formatter_class - subcommand_arg.kwargs['description'] = ( - None if sub_model.__doc__ is None else dedent(sub_model.__doc__) - ) - subcommand_arg.kwargs['help'] = None if len(arg.sub_models) > 1 else field_info.description - if self.cli_use_class_docs_for_groups: - subcommand_arg.kwargs['help'] = None if sub_model.__doc__ is None else dedent(sub_model.__doc__) - - subparsers = ( - self._add_subparsers( - parser, - title='subcommands', - dest=f'{arg_prefix}:subcommand', - description=field_info.description if len(arg.sub_models) > 1 else None, - ) - if subparsers is None - else subparsers - ) - - if hasattr(subparsers, 'metavar'): - subparsers.metavar = ( - f'{subparsers.metavar[:-1]},{subcommand_alias}}}' - if subparsers.metavar - else f'{{{subcommand_alias}}}' - ) - - self._add_parser_args( - parser=self._add_parser(subparsers, *subcommand_arg.args, **subcommand_arg.kwargs), - model=sub_model, - added_args=[], - arg_prefix=f'{arg.dest}.', - subcommand_prefix=f'{subcommand_prefix}{arg.preferred_alias}.', - group=None, - alias_prefixes=[], - model_default=PydanticUndefined, - ) - else: - flag_prefix: str = self._cli_flag_prefix - arg.kwargs['dest'] = arg.dest - arg.kwargs['default'] = CLI_SUPPRESS - arg.kwargs['help'] = self._help_format(field_name, field_info, model_default, is_model_suppressed) - arg.kwargs['metavar'] = self._metavar_format(field_info.annotation) - arg.kwargs['required'] = ( - self.cli_enforce_required and field_info.is_required() and model_default is PydanticUndefined - ) - - arg_names = self._get_arg_names( - arg_prefix, subcommand_prefix, alias_prefixes, arg.alias_names, added_args - ) - if not arg_names or (arg.kwargs['dest'] in added_args): - continue - - self._convert_append_action(arg.kwargs, field_info, arg.is_append_action) - - if _CliPositionalArg in field_info.metadata: - arg_names, flag_prefix = self._convert_positional_arg( - arg.kwargs, field_info, arg.preferred_alias, model_default - ) - - self._convert_bool_flag(arg.kwargs, field_info, model_default) - - if arg.is_parser_submodel and not getattr(field_info.annotation, 
'__pydantic_root_model__', False): - self._add_parser_submodels( - parser, - model, - arg.sub_models, - added_args, - arg_prefix, - subcommand_prefix, - flag_prefix, - arg_names, - arg.kwargs, - field_name, - field_info, - arg.alias_names, - model_default=model_default, - is_model_suppressed=is_model_suppressed, - ) - elif _CliUnknownArgs in field_info.metadata: - self._cli_unknown_args[arg.kwargs['dest']] = [] - elif not arg.is_alias_path_only: - if isinstance(group, dict): - group = self._add_group(parser, **group) - context = parser if group is None else group - arg.args = [f'{flag_prefix[: len(name)]}{name}' for name in arg_names] - self._add_argument(context, *arg.args, **arg.kwargs) - added_args += list(arg_names) - - self._add_parser_alias_paths(parser, alias_path_args, added_args, arg_prefix, subcommand_prefix, group) - return parser - - def _convert_append_action(self, kwargs: dict[str, Any], field_info: FieldInfo, is_append_action: bool) -> None: - if is_append_action: - kwargs['action'] = 'append' - if _annotation_contains_types(field_info.annotation, (dict, Mapping), is_strip_annotated=True): - self._cli_dict_args[kwargs['dest']] = field_info.annotation - - def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, model_default: Any) -> None: - if kwargs['metavar'] == 'bool': - if (self.cli_implicit_flags or _CliImplicitFlag in field_info.metadata) and ( - _CliExplicitFlag not in field_info.metadata - ): - del kwargs['metavar'] - kwargs['action'] = BooleanOptionalAction - - def _convert_positional_arg( - self, kwargs: dict[str, Any], field_info: FieldInfo, preferred_alias: str, model_default: Any - ) -> tuple[list[str], str]: - flag_prefix = '' - arg_names = [kwargs['dest']] - kwargs['default'] = PydanticUndefined - kwargs['metavar'] = _CliArg.get_kebab_case(preferred_alias.upper(), self.cli_kebab_case) - - # Note: CLI positional args are always strictly required at the CLI. Therefore, use field_info.is_required in - # conjunction with model_default instead of the derived kwargs['required']. - is_required = field_info.is_required() and model_default is PydanticUndefined - if kwargs.get('action') == 'append': - del kwargs['action'] - kwargs['nargs'] = '+' if is_required else '*' - elif not is_required: - kwargs['nargs'] = '?' 
- - del kwargs['dest'] - del kwargs['required'] - return arg_names, flag_prefix - - def _get_arg_names( - self, - arg_prefix: str, - subcommand_prefix: str, - alias_prefixes: list[str], - alias_names: tuple[str, ...], - added_args: list[str], - ) -> list[str]: - arg_names: list[str] = [] - for prefix in [arg_prefix] + alias_prefixes: - for name in alias_names: - arg_name = _CliArg.get_kebab_case( - f'{prefix}{name}' - if subcommand_prefix == self.env_prefix - else f'{prefix.replace(subcommand_prefix, "", 1)}{name}', - self.cli_kebab_case, - ) - if arg_name not in added_args: - arg_names.append(arg_name) - - if self.cli_shortcuts: - for target, aliases in self.cli_shortcuts.items(): - if target in arg_names: - alias_list = [aliases] if isinstance(aliases, str) else aliases - arg_names.extend(alias for alias in alias_list if alias not in added_args) - - return arg_names - - def _add_parser_submodels( - self, - parser: Any, - model: type[BaseModel], - sub_models: list[type[BaseModel]], - added_args: list[str], - arg_prefix: str, - subcommand_prefix: str, - flag_prefix: str, - arg_names: list[str], - kwargs: dict[str, Any], - field_name: str, - field_info: FieldInfo, - alias_names: tuple[str, ...], - model_default: Any, - is_model_suppressed: bool, - ) -> None: - if issubclass(model, CliMutuallyExclusiveGroup): - # Argparse has deprecated "calling add_argument_group() or add_mutually_exclusive_group() on a - # mutually exclusive group" (https://docs.python.org/3/library/argparse.html#mutual-exclusion). - # Since nested models result in a group add, raise an exception for nested models in a mutually - # exclusive group. - raise SettingsError('cannot have nested models in a CliMutuallyExclusiveGroup') - - model_group: Any = None - model_group_kwargs: dict[str, Any] = {} - model_group_kwargs['title'] = f'{arg_names[0]} options' - model_group_kwargs['description'] = field_info.description - model_group_kwargs['required'] = kwargs['required'] - model_group_kwargs['_is_cli_mutually_exclusive_group'] = any( - issubclass(model, CliMutuallyExclusiveGroup) for model in sub_models - ) - if model_group_kwargs['_is_cli_mutually_exclusive_group'] and len(sub_models) > 1: - raise SettingsError('cannot use union with CliMutuallyExclusiveGroup') - if self.cli_use_class_docs_for_groups and len(sub_models) == 1: - model_group_kwargs['description'] = None if sub_models[0].__doc__ is None else dedent(sub_models[0].__doc__) - - if model_default is not PydanticUndefined: - if is_model_class(type(model_default)) or is_pydantic_dataclass(type(model_default)): - model_default = getattr(model_default, field_name) - else: - if field_info.default is not PydanticUndefined: - model_default = field_info.default - elif field_info.default_factory is not None: - model_default = field_info.default_factory - if model_default is None: - desc_header = f'default: {self.cli_parse_none_str} (undefined)' - if model_group_kwargs['description'] is not None: - model_group_kwargs['description'] = dedent(f'{desc_header}\n{model_group_kwargs["description"]}') - else: - model_group_kwargs['description'] = desc_header - - preferred_alias = alias_names[0] - is_model_suppressed = self._is_field_suppressed(field_info) or is_model_suppressed - if is_model_suppressed: - model_group_kwargs['description'] = CLI_SUPPRESS - if not self.cli_avoid_json: - added_args.append(arg_names[0]) - kwargs['required'] = False - kwargs['nargs'] = '?' 
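
The `nargs`/`const` setup being assembled here is what lets a nested model be supplied either field-by-field (dotted options grouped under "<name> options") or all at once as a JSON string when `cli_avoid_json` is left at its default of `False`. A sketch of the resulting command-line surface, with illustrative names:

    from pydantic import BaseModel
    from pydantic_settings import BaseSettings, SettingsConfigDict

    class Database(BaseModel):
        host: str = 'localhost'
        port: int = 5432

    class Settings(BaseSettings):
        model_config = SettingsConfigDict(cli_parse_args=True)

        database: Database = Database()

    # Illustrative invocations -- both reach the same nested fields:
    #   my-app --database.host db.internal --database.port 5433
    #   my-app --database '{"host": "db.internal", "port": 5433}'
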
- kwargs['const'] = '{}' - kwargs['help'] = ( - CLI_SUPPRESS if is_model_suppressed else f'set {arg_names[0]} from JSON string (default: {{}})' - ) - model_group = self._add_group(parser, **model_group_kwargs) - self._add_argument(model_group, *(f'{flag_prefix}{name}' for name in arg_names), **kwargs) - for model in sub_models: - self._add_parser_args( - parser=parser, - model=model, - added_args=added_args, - arg_prefix=f'{arg_prefix}{preferred_alias}.', - subcommand_prefix=subcommand_prefix, - group=model_group if model_group else model_group_kwargs, - alias_prefixes=[f'{arg_prefix}{name}.' for name in alias_names[1:]], - model_default=model_default, - is_model_suppressed=is_model_suppressed, - ) - - def _add_parser_alias_paths( - self, - parser: Any, - alias_path_args: dict[str, Optional[int]], - added_args: list[str], - arg_prefix: str, - subcommand_prefix: str, - group: Any, - ) -> None: - if alias_path_args: - context = parser - if group is not None: - context = self._add_group(parser, **group) if isinstance(group, dict) else group - for name, index in alias_path_args.items(): - arg_name = ( - f'{arg_prefix}{name}' - if subcommand_prefix == self.env_prefix - else f'{arg_prefix.replace(subcommand_prefix, "", 1)}{name}' - ) - kwargs: dict[str, Any] = {} - kwargs['default'] = CLI_SUPPRESS - kwargs['help'] = 'pydantic alias path' - kwargs['action'] = 'append' - kwargs['metavar'] = 'list' - if index is None: - kwargs['metavar'] = 'dict' - self._cli_dict_args[arg_name] = dict - args = [f'{self._cli_flag_prefix}{arg_name}'] - for key, arg in self._parser_map[arg_name].items(): - arg.args, arg.kwargs = args, kwargs - self._add_argument(context, *args, **kwargs) - added_args.append(arg_name) - - def _get_modified_args(self, obj: Any) -> tuple[str, ...]: - if not self.cli_hide_none_type: - return get_args(obj) - else: - return tuple([type_ for type_ in get_args(obj) if type_ is not type(None)]) - - def _metavar_format_choices(self, args: list[str], obj_qualname: str | None = None) -> str: - if 'JSON' in args: - args = args[: args.index('JSON') + 1] + [arg for arg in args[args.index('JSON') + 1 :] if arg != 'JSON'] - metavar = ','.join(args) - if obj_qualname: - return f'{obj_qualname}[{metavar}]' - else: - return metavar if len(args) == 1 else f'{{{metavar}}}' - - def _metavar_format_recurse(self, obj: Any) -> str: - """Pretty metavar representation of a type. Adapts logic from `pydantic._repr.display_as_type`.""" - obj = _strip_annotated(obj) - if _is_function(obj): - # If function is locally defined use __name__ instead of __qualname__ - return obj.__name__ if '' in obj.__qualname__ else obj.__qualname__ - elif obj is ...: - return '...' 
- elif isinstance(obj, Representation): - return repr(obj) - elif typing_objects.is_typealiastype(obj): - return str(obj) - - origin = get_origin(obj) - if origin is None and not isinstance(obj, (type, typing.ForwardRef, typing_extensions.ForwardRef)): - obj = obj.__class__ - - if is_union_origin(origin): - return self._metavar_format_choices(list(map(self._metavar_format_recurse, self._get_modified_args(obj)))) - elif typing_objects.is_literal(origin): - return self._metavar_format_choices(list(map(str, self._get_modified_args(obj)))) - elif _lenient_issubclass(obj, Enum): - return self._metavar_format_choices([val.name for val in obj]) - elif isinstance(obj, _WithArgsTypes): - return self._metavar_format_choices( - list(map(self._metavar_format_recurse, self._get_modified_args(obj))), - obj_qualname=obj.__qualname__ if hasattr(obj, '__qualname__') else str(obj), - ) - elif obj is type(None): - return self.cli_parse_none_str - elif is_model_class(obj) or is_pydantic_dataclass(obj): - return ( - self._metavar_format_recurse(_get_model_fields(obj)['root'].annotation) - if getattr(obj, '__pydantic_root_model__', False) - else 'JSON' - ) - elif isinstance(obj, type): - return obj.__qualname__ - else: - return repr(obj).replace('typing.', '').replace('typing_extensions.', '') - - def _metavar_format(self, obj: Any) -> str: - return self._metavar_format_recurse(obj).replace(', ', ',') - - def _help_format( - self, field_name: str, field_info: FieldInfo, model_default: Any, is_model_suppressed: bool - ) -> str: - _help = field_info.description if field_info.description else '' - if is_model_suppressed or self._is_field_suppressed(field_info): - return CLI_SUPPRESS - - if field_info.is_required() and model_default in (PydanticUndefined, None): - if _CliPositionalArg not in field_info.metadata: - ifdef = 'ifdef: ' if model_default is None else '' - _help += f' ({ifdef}required)' if _help else f'({ifdef}required)' - else: - default = f'(default: {self.cli_parse_none_str})' - if is_model_class(type(model_default)) or is_pydantic_dataclass(type(model_default)): - default = f'(default: {getattr(model_default, field_name)})' - elif model_default not in (PydanticUndefined, None) and _is_function(model_default): - default = f'(default factory: {self._metavar_format(model_default)})' - elif field_info.default not in (PydanticUndefined, None): - enum_name = _annotation_enum_val_to_name(field_info.annotation, field_info.default) - default = f'(default: {field_info.default if enum_name is None else enum_name})' - elif field_info.default_factory is not None: - default = f'(default factory: {self._metavar_format(field_info.default_factory)})' - _help += f' {default}' if _help else default - return _help.replace('%', '%%') if issubclass(type(self._root_parser), ArgumentParser) else _help - - def _is_field_suppressed(self, field_info: FieldInfo) -> bool: - _help = field_info.description if field_info.description else '' - return _help == CLI_SUPPRESS or CLI_SUPPRESS in field_info.metadata - - def _update_alias_path_only_default( - self, arg_name: str, value: Any, field_info: FieldInfo, alias_path_only_defaults: dict[str, Any] - ) -> list[Any] | dict[str, Any]: - alias_path: AliasPath = [ - alias if isinstance(alias, AliasPath) else cast(AliasPath, alias.choices[0]) - for alias in (field_info.alias, field_info.validation_alias) - if isinstance(alias, (AliasPath, AliasChoices)) - ][0] - - alias_nested_paths: list[str] = alias_path.path[1:-1] # type: ignore - if not alias_nested_paths: - 
alias_path_only_defaults.setdefault(arg_name, []) - alias_default = alias_path_only_defaults[arg_name] - else: - alias_path_only_defaults.setdefault(arg_name, {}) - current_path = alias_path_only_defaults[arg_name] - - for nested_path in alias_nested_paths[:-1]: - current_path.setdefault(nested_path, {}) - current_path = current_path[nested_path] - current_path.setdefault(alias_nested_paths[-1], []) - alias_default = current_path[alias_nested_paths[-1]] - - alias_path_index = cast(int, alias_path.path[-1]) - alias_default.extend([''] * max(alias_path_index + 1 - len(alias_default), 0)) - alias_default[alias_path_index] = value - return alias_path_only_defaults[arg_name] - - def _serialized_args(self, model: PydanticModel, _is_submodel: bool = False) -> list[str]: - alias_path_only_defaults: dict[str, Any] = {} - optional_args: list[str | list[Any] | dict[str, Any]] = [] - positional_args: list[str | list[Any] | dict[str, Any]] = [] - subcommand_args: list[str] = [] - for field_name, field_info in _get_model_fields(type(model) if _is_submodel else self.settings_cls).items(): - model_default = getattr(model, field_name) - if field_info.default == model_default: - continue - if _CliSubCommand in field_info.metadata and model_default is None: - continue - arg = next(iter(self._parser_map[field_info].values())) - if arg.subcommand_dest: - subcommand_args.append(arg.subcommand_alias(type(model_default))) - subcommand_args += self._serialized_args(model_default, _is_submodel=True) - continue - if is_model_class(type(model_default)) or is_pydantic_dataclass(type(model_default)): - positional_args += self._serialized_args(model_default, _is_submodel=True) - continue - - matched = re.match(r'(-*)(.+)', arg.preferred_arg_name) - flag_chars, arg_name = matched.groups() if matched else ('', '') - value: str | list[Any] | dict[str, Any] = ( - json.dumps(model_default) if isinstance(model_default, (dict, list, set)) else str(model_default) - ) - - if arg.is_alias_path_only: - # For alias path only, we wont know the complete value until we've finished parsing the entire class. In - # this case, insert value as a non-string reference pointing to the relevant alias_path_only_defaults - # entry and convert into completed string value later. 
- value = self._update_alias_path_only_default(arg_name, value, field_info, alias_path_only_defaults) - - if _CliPositionalArg in field_info.metadata: - for value in model_default if isinstance(model_default, list) else [model_default]: - value = json.dumps(value) if isinstance(value, (dict, list, set)) else str(value) - positional_args.append(value) - continue - - # Note: prepend 'no-' for boolean optional action flag if model_default value is False and flag is not a short option - if arg.kwargs.get('action') == BooleanOptionalAction and model_default is False and flag_chars == '--': - flag_chars += 'no-' - - optional_args.append(f'{flag_chars}{arg_name}') - - # If implicit bool flag, do not add a value - if arg.kwargs.get('action') != BooleanOptionalAction: - optional_args.append(value) - - serialized_args: list[str] = [] - serialized_args += [json.dumps(value) if not isinstance(value, str) else value for value in optional_args] - serialized_args += [json.dumps(value) if not isinstance(value, str) else value for value in positional_args] - return serialized_args + subcommand_args diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/dotenv.py b/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/dotenv.py deleted file mode 100644 index 9816588..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/dotenv.py +++ /dev/null @@ -1,168 +0,0 @@ -"""Dotenv file settings source.""" - -from __future__ import annotations as _annotations - -import os -import warnings -from collections.abc import Mapping -from pathlib import Path -from typing import TYPE_CHECKING, Any - -from dotenv import dotenv_values -from pydantic._internal._typing_extra import ( # type: ignore[attr-defined] - get_origin, -) -from typing_inspection.introspection import is_union_origin - -from ..types import ENV_FILE_SENTINEL, DotenvType -from ..utils import ( - _annotation_is_complex, - _union_is_complex, - parse_env_vars, -) -from .env import EnvSettingsSource - -if TYPE_CHECKING: - from pydantic_settings.main import BaseSettings - - -class DotEnvSettingsSource(EnvSettingsSource): - """ - Source class for loading settings values from env files. 
- """ - - def __init__( - self, - settings_cls: type[BaseSettings], - env_file: DotenvType | None = ENV_FILE_SENTINEL, - env_file_encoding: str | None = None, - case_sensitive: bool | None = None, - env_prefix: str | None = None, - env_nested_delimiter: str | None = None, - env_nested_max_split: int | None = None, - env_ignore_empty: bool | None = None, - env_parse_none_str: str | None = None, - env_parse_enums: bool | None = None, - ) -> None: - self.env_file = env_file if env_file != ENV_FILE_SENTINEL else settings_cls.model_config.get('env_file') - self.env_file_encoding = ( - env_file_encoding if env_file_encoding is not None else settings_cls.model_config.get('env_file_encoding') - ) - super().__init__( - settings_cls, - case_sensitive, - env_prefix, - env_nested_delimiter, - env_nested_max_split, - env_ignore_empty, - env_parse_none_str, - env_parse_enums, - ) - - def _load_env_vars(self) -> Mapping[str, str | None]: - return self._read_env_files() - - @staticmethod - def _static_read_env_file( - file_path: Path, - *, - encoding: str | None = None, - case_sensitive: bool = False, - ignore_empty: bool = False, - parse_none_str: str | None = None, - ) -> Mapping[str, str | None]: - file_vars: dict[str, str | None] = dotenv_values(file_path, encoding=encoding or 'utf8') - return parse_env_vars(file_vars, case_sensitive, ignore_empty, parse_none_str) - - def _read_env_file( - self, - file_path: Path, - ) -> Mapping[str, str | None]: - return self._static_read_env_file( - file_path, - encoding=self.env_file_encoding, - case_sensitive=self.case_sensitive, - ignore_empty=self.env_ignore_empty, - parse_none_str=self.env_parse_none_str, - ) - - def _read_env_files(self) -> Mapping[str, str | None]: - env_files = self.env_file - if env_files is None: - return {} - - if isinstance(env_files, (str, os.PathLike)): - env_files = [env_files] - - dotenv_vars: dict[str, str | None] = {} - for env_file in env_files: - env_path = Path(env_file).expanduser() - if env_path.is_file(): - dotenv_vars.update(self._read_env_file(env_path)) - - return dotenv_vars - - def __call__(self) -> dict[str, Any]: - data: dict[str, Any] = super().__call__() - is_extra_allowed = self.config.get('extra') != 'forbid' - - # As `extra` config is allowed in dotenv settings source, We have to - # update data with extra env variables from dotenv file. 
- for env_name, env_value in self.env_vars.items(): - if not env_value or env_name in data or (self.env_prefix and env_name in self.settings_cls.model_fields): - continue - env_used = False - for field_name, field in self.settings_cls.model_fields.items(): - for _, field_env_name, _ in self._extract_field_info(field, field_name): - if env_name == field_env_name or ( - ( - _annotation_is_complex(field.annotation, field.metadata) - or ( - is_union_origin(get_origin(field.annotation)) - and _union_is_complex(field.annotation, field.metadata) - ) - ) - and env_name.startswith(field_env_name) - ): - env_used = True - break - if env_used: - break - if not env_used: - if is_extra_allowed and env_name.startswith(self.env_prefix): - # env_prefix should be respected and removed from the env_name - normalized_env_name = env_name[len(self.env_prefix) :] - data[normalized_env_name] = env_value - else: - data[env_name] = env_value - return data - - def __repr__(self) -> str: - return ( - f'{self.__class__.__name__}(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, ' - f'env_nested_delimiter={self.env_nested_delimiter!r}, env_prefix_len={self.env_prefix_len!r})' - ) - - -def read_env_file( - file_path: Path, - *, - encoding: str | None = None, - case_sensitive: bool = False, - ignore_empty: bool = False, - parse_none_str: str | None = None, -) -> Mapping[str, str | None]: - warnings.warn( - 'read_env_file will be removed in the next version, use DotEnvSettingsSource._static_read_env_file if you must', - DeprecationWarning, - ) - return DotEnvSettingsSource._static_read_env_file( - file_path, - encoding=encoding, - case_sensitive=case_sensitive, - ignore_empty=ignore_empty, - parse_none_str=parse_none_str, - ) - - -__all__ = ['DotEnvSettingsSource', 'read_env_file'] diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/env.py b/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/env.py deleted file mode 100644 index 5a350f1..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/env.py +++ /dev/null @@ -1,270 +0,0 @@ -from __future__ import annotations as _annotations - -import os -from collections.abc import Mapping -from typing import ( - TYPE_CHECKING, - Any, -) - -from pydantic._internal._utils import deep_update, is_model_class -from pydantic.dataclasses import is_pydantic_dataclass -from pydantic.fields import FieldInfo -from typing_extensions import get_args, get_origin -from typing_inspection.introspection import is_union_origin - -from ...utils import _lenient_issubclass -from ..base import PydanticBaseEnvSettingsSource -from ..types import EnvNoneType -from ..utils import ( - _annotation_enum_name_to_val, - _get_model_fields, - _union_is_complex, - parse_env_vars, -) - -if TYPE_CHECKING: - from pydantic_settings.main import BaseSettings - - -class EnvSettingsSource(PydanticBaseEnvSettingsSource): - """ - Source class for loading settings values from environment variables. 
- """ - - def __init__( - self, - settings_cls: type[BaseSettings], - case_sensitive: bool | None = None, - env_prefix: str | None = None, - env_nested_delimiter: str | None = None, - env_nested_max_split: int | None = None, - env_ignore_empty: bool | None = None, - env_parse_none_str: str | None = None, - env_parse_enums: bool | None = None, - ) -> None: - super().__init__( - settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums - ) - self.env_nested_delimiter = ( - env_nested_delimiter if env_nested_delimiter is not None else self.config.get('env_nested_delimiter') - ) - self.env_nested_max_split = ( - env_nested_max_split if env_nested_max_split is not None else self.config.get('env_nested_max_split') - ) - self.maxsplit = (self.env_nested_max_split or 0) - 1 - self.env_prefix_len = len(self.env_prefix) - - self.env_vars = self._load_env_vars() - - def _load_env_vars(self) -> Mapping[str, str | None]: - return parse_env_vars(os.environ, self.case_sensitive, self.env_ignore_empty, self.env_parse_none_str) - - def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: - """ - Gets the value for field from environment variables and a flag to determine whether value is complex. - - Args: - field: The field. - field_name: The field name. - - Returns: - A tuple that contains the value (`None` if not found), key, and - a flag to determine whether value is complex. - """ - - env_val: str | None = None - for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name): - env_val = self.env_vars.get(env_name) - if env_val is not None: - break - - return env_val, field_key, value_is_complex - - def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any: - """ - Prepare value for the field. - - * Extract value for nested field. - * Deserialize value to python object for complex field. - - Args: - field: The field. - field_name: The field name. - - Returns: - A tuple contains prepared value for the field. - - Raises: - ValuesError: When There is an error in deserializing value for complex field. 
- """ - is_complex, allow_parse_failure = self._field_is_complex(field) - if self.env_parse_enums: - enum_val = _annotation_enum_name_to_val(field.annotation, value) - value = value if enum_val is None else enum_val - - if is_complex or value_is_complex: - if isinstance(value, EnvNoneType): - return value - elif value is None: - # field is complex but no value found so far, try explode_env_vars - env_val_built = self.explode_env_vars(field_name, field, self.env_vars) - if env_val_built: - return env_val_built - else: - # field is complex and there's a value, decode that as JSON, then add explode_env_vars - try: - value = self.decode_complex_value(field_name, field, value) - except ValueError as e: - if not allow_parse_failure: - raise e - - if isinstance(value, dict): - return deep_update(value, self.explode_env_vars(field_name, field, self.env_vars)) - else: - return value - elif value is not None: - # simplest case, field is not complex, we only need to add the value if it was found - return value - - def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]: - """ - Find out if a field is complex, and if so whether JSON errors should be ignored - """ - if self.field_is_complex(field): - allow_parse_failure = False - elif is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata): - allow_parse_failure = True - else: - return False, False - - return True, allow_parse_failure - - # Default value of `case_sensitive` is `None`, because we don't want to break existing behavior. - # We have to change the method to a non-static method and use - # `self.case_sensitive` instead in V3. - def next_field( - self, field: FieldInfo | Any | None, key: str, case_sensitive: bool | None = None - ) -> FieldInfo | None: - """ - Find the field in a sub model by key(env name) - - By having the following models: - - ```py - class SubSubModel(BaseSettings): - dvals: Dict - - class SubModel(BaseSettings): - vals: list[str] - sub_sub_model: SubSubModel - - class Cfg(BaseSettings): - sub_model: SubModel - ``` - - Then: - next_field(sub_model, 'vals') Returns the `vals` field of `SubModel` class - next_field(sub_model, 'sub_sub_model') Returns `sub_sub_model` field of `SubModel` class - - Args: - field: The field. - key: The key (env name). - case_sensitive: Whether to search for key case sensitively. - - Returns: - Field if it finds the next field otherwise `None`. - """ - if not field: - return None - - annotation = field.annotation if isinstance(field, FieldInfo) else field - for type_ in get_args(annotation): - type_has_key = self.next_field(type_, key, case_sensitive) - if type_has_key: - return type_has_key - if is_model_class(annotation) or is_pydantic_dataclass(annotation): # type: ignore[arg-type] - fields = _get_model_fields(annotation) - # `case_sensitive is None` is here to be compatible with the old behavior. - # Has to be removed in V3. - for field_name, f in fields.items(): - for _, env_name, _ in self._extract_field_info(f, field_name): - if case_sensitive is None or case_sensitive: - if field_name == key or env_name == key: - return f - elif field_name.lower() == key.lower() or env_name.lower() == key.lower(): - return f - return None - - def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[str, str | None]) -> dict[str, Any]: - """ - Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries. - - This is applied to a single field, hence filtering by env_var prefix. 
- - Args: - field_name: The field name. - field: The field. - env_vars: Environment variables. - - Returns: - A dictionary contains extracted values from nested env values. - """ - if not self.env_nested_delimiter: - return {} - - ann = field.annotation - is_dict = ann is dict or _lenient_issubclass(get_origin(ann), dict) - - prefixes = [ - f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name) - ] - result: dict[str, Any] = {} - for env_name, env_val in env_vars.items(): - try: - prefix = next(prefix for prefix in prefixes if env_name.startswith(prefix)) - except StopIteration: - continue - # we remove the prefix before splitting in case the prefix has characters in common with the delimiter - env_name_without_prefix = env_name[len(prefix) :] - *keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter, self.maxsplit) - env_var = result - target_field: FieldInfo | None = field - for key in keys: - target_field = self.next_field(target_field, key, self.case_sensitive) - if isinstance(env_var, dict): - env_var = env_var.setdefault(key, {}) - - # get proper field with last_key - target_field = self.next_field(target_field, last_key, self.case_sensitive) - - # check if env_val maps to a complex field and if so, parse the env_val - if (target_field or is_dict) and env_val: - if target_field: - is_complex, allow_json_failure = self._field_is_complex(target_field) - if self.env_parse_enums: - enum_val = _annotation_enum_name_to_val(target_field.annotation, env_val) - env_val = env_val if enum_val is None else enum_val - else: - # nested field type is dict - is_complex, allow_json_failure = True, True - if is_complex: - try: - env_val = self.decode_complex_value(last_key, target_field, env_val) # type: ignore - except ValueError as e: - if not allow_json_failure: - raise e - if isinstance(env_var, dict): - if last_key not in env_var or not isinstance(env_val, EnvNoneType) or env_var[last_key] == {}: - env_var[last_key] = env_val - - return result - - def __repr__(self) -> str: - return ( - f'{self.__class__.__name__}(env_nested_delimiter={self.env_nested_delimiter!r}, ' - f'env_prefix_len={self.env_prefix_len!r})' - ) - - -__all__ = ['EnvSettingsSource'] diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/gcp.py b/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/gcp.py deleted file mode 100644 index 62f356a..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/gcp.py +++ /dev/null @@ -1,152 +0,0 @@ -from __future__ import annotations as _annotations - -from collections.abc import Iterator, Mapping -from functools import cached_property -from typing import TYPE_CHECKING, Optional - -from .env import EnvSettingsSource - -if TYPE_CHECKING: - from google.auth import default as google_auth_default - from google.auth.credentials import Credentials - from google.cloud.secretmanager import SecretManagerServiceClient - - from pydantic_settings.main import BaseSettings -else: - Credentials = None - SecretManagerServiceClient = None - google_auth_default = None - - -def import_gcp_secret_manager() -> None: - global Credentials - global SecretManagerServiceClient - global google_auth_default - - try: - from google.auth import default as google_auth_default - from google.auth.credentials import Credentials - from google.cloud.secretmanager import SecretManagerServiceClient - except ImportError as e: # pragma: no cover - raise ImportError( - 'GCP Secret Manager 
dependencies are not installed, run `pip install pydantic-settings[gcp-secret-manager]`' - ) from e - - -class GoogleSecretManagerMapping(Mapping[str, Optional[str]]): - _loaded_secrets: dict[str, str | None] - _secret_client: SecretManagerServiceClient - - def __init__(self, secret_client: SecretManagerServiceClient, project_id: str, case_sensitive: bool) -> None: - self._loaded_secrets = {} - self._secret_client = secret_client - self._project_id = project_id - self._case_sensitive = case_sensitive - - @property - def _gcp_project_path(self) -> str: - return self._secret_client.common_project_path(self._project_id) - - @cached_property - def _secret_names(self) -> list[str]: - rv: list[str] = [] - - secrets = self._secret_client.list_secrets(parent=self._gcp_project_path) - for secret in secrets: - name = self._secret_client.parse_secret_path(secret.name).get('secret', '') - if not self._case_sensitive: - name = name.lower() - rv.append(name) - return rv - - def _secret_version_path(self, key: str, version: str = 'latest') -> str: - return self._secret_client.secret_version_path(self._project_id, key, version) - - def __getitem__(self, key: str) -> str | None: - if not self._case_sensitive: - key = key.lower() - if key not in self._loaded_secrets: - # If we know the key isn't available in secret manager, raise a key error - if key not in self._secret_names: - raise KeyError(key) - - try: - self._loaded_secrets[key] = self._secret_client.access_secret_version( - name=self._secret_version_path(key) - ).payload.data.decode('UTF-8') - except Exception: - raise KeyError(key) - - return self._loaded_secrets[key] - - def __len__(self) -> int: - return len(self._secret_names) - - def __iter__(self) -> Iterator[str]: - return iter(self._secret_names) - - -class GoogleSecretManagerSettingsSource(EnvSettingsSource): - _credentials: Credentials - _secret_client: SecretManagerServiceClient - _project_id: str - - def __init__( - self, - settings_cls: type[BaseSettings], - credentials: Credentials | None = None, - project_id: str | None = None, - env_prefix: str | None = None, - env_parse_none_str: str | None = None, - env_parse_enums: bool | None = None, - secret_client: SecretManagerServiceClient | None = None, - case_sensitive: bool | None = True, - ) -> None: - # Import Google Packages if they haven't already been imported - if SecretManagerServiceClient is None or Credentials is None or google_auth_default is None: - import_gcp_secret_manager() - - # If credentials or project_id are not passed, then - # try to get them from the default function - if not credentials or not project_id: - _creds, _project_id = google_auth_default() # type: ignore[no-untyped-call] - - # Set the credentials and/or project id if they weren't specified - if credentials is None: - credentials = _creds - - if project_id is None: - if isinstance(_project_id, str): - project_id = _project_id - else: - raise AttributeError( - 'project_id is required to be specified either as an argument or from the google.auth.default. 
See https://google-auth.readthedocs.io/en/master/reference/google.auth.html#google.auth.default' - ) - - self._credentials: Credentials = credentials - self._project_id: str = project_id - - if secret_client: - self._secret_client = secret_client - else: - self._secret_client = SecretManagerServiceClient(credentials=self._credentials) - - super().__init__( - settings_cls, - case_sensitive=case_sensitive, - env_prefix=env_prefix, - env_ignore_empty=False, - env_parse_none_str=env_parse_none_str, - env_parse_enums=env_parse_enums, - ) - - def _load_env_vars(self) -> Mapping[str, Optional[str]]: - return GoogleSecretManagerMapping( - self._secret_client, project_id=self._project_id, case_sensitive=self.case_sensitive - ) - - def __repr__(self) -> str: - return f'{self.__class__.__name__}(project_id={self._project_id!r}, env_nested_delimiter={self.env_nested_delimiter!r})' - - -__all__ = ['GoogleSecretManagerSettingsSource', 'GoogleSecretManagerMapping'] diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/json.py b/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/json.py deleted file mode 100644 index 837601c..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/json.py +++ /dev/null @@ -1,47 +0,0 @@ -"""JSON file settings source.""" - -from __future__ import annotations as _annotations - -import json -from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, -) - -from ..base import ConfigFileSourceMixin, InitSettingsSource -from ..types import DEFAULT_PATH, PathType - -if TYPE_CHECKING: - from pydantic_settings.main import BaseSettings - - -class JsonConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin): - """ - A source class that loads variables from a JSON file - """ - - def __init__( - self, - settings_cls: type[BaseSettings], - json_file: PathType | None = DEFAULT_PATH, - json_file_encoding: str | None = None, - ): - self.json_file_path = json_file if json_file != DEFAULT_PATH else settings_cls.model_config.get('json_file') - self.json_file_encoding = ( - json_file_encoding - if json_file_encoding is not None - else settings_cls.model_config.get('json_file_encoding') - ) - self.json_data = self._read_files(self.json_file_path) - super().__init__(settings_cls, self.json_data) - - def _read_file(self, file_path: Path) -> dict[str, Any]: - with open(file_path, encoding=self.json_file_encoding) as json_file: - return json.load(json_file) - - def __repr__(self) -> str: - return f'{self.__class__.__name__}(json_file={self.json_file_path})' - - -__all__ = ['JsonConfigSettingsSource'] diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/pyproject.py b/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/pyproject.py deleted file mode 100644 index bb02cbb..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/pyproject.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Pyproject TOML file settings source.""" - -from __future__ import annotations as _annotations - -from pathlib import Path -from typing import ( - TYPE_CHECKING, -) - -from .toml import TomlConfigSettingsSource - -if TYPE_CHECKING: - from pydantic_settings.main import BaseSettings - - -class PyprojectTomlConfigSettingsSource(TomlConfigSettingsSource): - """ - A source class that loads variables from a `pyproject.toml` file. 
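As a hedged usage sketch (not part of the deleted module): a `PyprojectTomlConfigSettingsSource` is normally wired in through `settings_customise_sources`, assuming a pydantic-settings release that ships this class. The `('tool', 'my-app')` table header and the `name` field are hypothetical.

```python
from pydantic_settings import (
    BaseSettings,
    PydanticBaseSettingsSource,
    PyprojectTomlConfigSettingsSource,
    SettingsConfigDict,
)


class Settings(BaseSettings):
    # Read values from the [tool.my-app] table of pyproject.toml
    # (hypothetical header; the default is ('tool', 'pydantic-settings')).
    model_config = SettingsConfigDict(pyproject_toml_table_header=("tool", "my-app"))

    name: str = "unnamed"

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        # Consult only pyproject.toml in this sketch.
        return (PyprojectTomlConfigSettingsSource(settings_cls),)
```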
- """ - - def __init__( - self, - settings_cls: type[BaseSettings], - toml_file: Path | None = None, - ) -> None: - self.toml_file_path = self._pick_pyproject_toml_file( - toml_file, settings_cls.model_config.get('pyproject_toml_depth', 0) - ) - self.toml_table_header: tuple[str, ...] = settings_cls.model_config.get( - 'pyproject_toml_table_header', ('tool', 'pydantic-settings') - ) - self.toml_data = self._read_files(self.toml_file_path) - for key in self.toml_table_header: - self.toml_data = self.toml_data.get(key, {}) - super(TomlConfigSettingsSource, self).__init__(settings_cls, self.toml_data) - - @staticmethod - def _pick_pyproject_toml_file(provided: Path | None, depth: int) -> Path: - """Pick a `pyproject.toml` file path to use. - - Args: - provided: Explicit path provided when instantiating this class. - depth: Number of directories up the tree to check of a pyproject.toml. - - """ - if provided: - return provided.resolve() - rv = Path.cwd() / 'pyproject.toml' - count = 0 - if not rv.is_file(): - child = rv.parent.parent / 'pyproject.toml' - while count < depth: - if child.is_file(): - return child - if str(child.parent) == rv.root: - break # end discovery after checking system root once - child = child.parent.parent / 'pyproject.toml' - count += 1 - return rv - - -__all__ = ['PyprojectTomlConfigSettingsSource'] diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/secrets.py b/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/secrets.py deleted file mode 100644 index 00a8f47..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/secrets.py +++ /dev/null @@ -1,125 +0,0 @@ -"""Secrets file settings source.""" - -from __future__ import annotations as _annotations - -import os -import warnings -from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, -) - -from pydantic.fields import FieldInfo - -from pydantic_settings.utils import path_type_label - -from ...exceptions import SettingsError -from ..base import PydanticBaseEnvSettingsSource -from ..types import PathType - -if TYPE_CHECKING: - from pydantic_settings.main import BaseSettings - - -class SecretsSettingsSource(PydanticBaseEnvSettingsSource): - """ - Source class for loading settings values from secret files. - """ - - def __init__( - self, - settings_cls: type[BaseSettings], - secrets_dir: PathType | None = None, - case_sensitive: bool | None = None, - env_prefix: str | None = None, - env_ignore_empty: bool | None = None, - env_parse_none_str: str | None = None, - env_parse_enums: bool | None = None, - ) -> None: - super().__init__( - settings_cls, case_sensitive, env_prefix, env_ignore_empty, env_parse_none_str, env_parse_enums - ) - self.secrets_dir = secrets_dir if secrets_dir is not None else self.config.get('secrets_dir') - - def __call__(self) -> dict[str, Any]: - """ - Build fields from "secrets" files. 
- """ - secrets: dict[str, str | None] = {} - - if self.secrets_dir is None: - return secrets - - secrets_dirs = [self.secrets_dir] if isinstance(self.secrets_dir, (str, os.PathLike)) else self.secrets_dir - secrets_paths = [Path(p).expanduser() for p in secrets_dirs] - self.secrets_paths = [] - - for path in secrets_paths: - if not path.exists(): - warnings.warn(f'directory "{path}" does not exist') - else: - self.secrets_paths.append(path) - - if not len(self.secrets_paths): - return secrets - - for path in self.secrets_paths: - if not path.is_dir(): - raise SettingsError(f'secrets_dir must reference a directory, not a {path_type_label(path)}') - - return super().__call__() - - @classmethod - def find_case_path(cls, dir_path: Path, file_name: str, case_sensitive: bool) -> Path | None: - """ - Find a file within path's directory matching filename, optionally ignoring case. - - Args: - dir_path: Directory path. - file_name: File name. - case_sensitive: Whether to search for file name case sensitively. - - Returns: - Whether file path or `None` if file does not exist in directory. - """ - for f in dir_path.iterdir(): - if f.name == file_name: - return f - elif not case_sensitive and f.name.lower() == file_name.lower(): - return f - return None - - def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: - """ - Gets the value for field from secret file and a flag to determine whether value is complex. - - Args: - field: The field. - field_name: The field name. - - Returns: - A tuple that contains the value (`None` if the file does not exist), key, and - a flag to determine whether value is complex. - """ - - for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name): - # paths reversed to match the last-wins behaviour of `env_file` - for secrets_path in reversed(self.secrets_paths): - path = self.find_case_path(secrets_path, env_name, self.case_sensitive) - if not path: - # path does not exist, we currently don't return a warning for this - continue - - if path.is_file(): - return path.read_text().strip(), field_key, value_is_complex - else: - warnings.warn( - f'attempted to load secret file "{path}" but found a {path_type_label(path)} instead.', - stacklevel=4, - ) - - return None, field_key, value_is_complex - - def __repr__(self) -> str: - return f'{self.__class__.__name__}(secrets_dir={self.secrets_dir!r})' diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/toml.py b/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/toml.py deleted file mode 100644 index eaff41d..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/toml.py +++ /dev/null @@ -1,66 +0,0 @@ -"""TOML file settings source.""" - -from __future__ import annotations as _annotations - -import sys -from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, -) - -from ..base import ConfigFileSourceMixin, InitSettingsSource -from ..types import DEFAULT_PATH, PathType - -if TYPE_CHECKING: - from pydantic_settings.main import BaseSettings - - if sys.version_info >= (3, 11): - import tomllib - else: - tomllib = None - import tomli -else: - tomllib = None - tomli = None - - -def import_toml() -> None: - global tomli - global tomllib - if sys.version_info < (3, 11): - if tomli is not None: - return - try: - import tomli - except ImportError as e: # pragma: no cover - raise ImportError('tomli is not installed, run `pip install pydantic-settings[toml]`') from e - else: - if tomllib is 
not None: - return - import tomllib - - -class TomlConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin): - """ - A source class that loads variables from a TOML file - """ - - def __init__( - self, - settings_cls: type[BaseSettings], - toml_file: PathType | None = DEFAULT_PATH, - ): - self.toml_file_path = toml_file if toml_file != DEFAULT_PATH else settings_cls.model_config.get('toml_file') - self.toml_data = self._read_files(self.toml_file_path) - super().__init__(settings_cls, self.toml_data) - - def _read_file(self, file_path: Path) -> dict[str, Any]: - import_toml() - with open(file_path, mode='rb') as toml_file: - if sys.version_info < (3, 11): - return tomli.load(toml_file) - return tomllib.load(toml_file) - - def __repr__(self) -> str: - return f'{self.__class__.__name__}(toml_file={self.toml_file_path})' diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/yaml.py b/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/yaml.py deleted file mode 100644 index 82778b4..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings/sources/providers/yaml.py +++ /dev/null @@ -1,75 +0,0 @@ -"""YAML file settings source.""" - -from __future__ import annotations as _annotations - -from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, -) - -from ..base import ConfigFileSourceMixin, InitSettingsSource -from ..types import DEFAULT_PATH, PathType - -if TYPE_CHECKING: - import yaml - - from pydantic_settings.main import BaseSettings -else: - yaml = None - - -def import_yaml() -> None: - global yaml - if yaml is not None: - return - try: - import yaml - except ImportError as e: - raise ImportError('PyYAML is not installed, run `pip install pydantic-settings[yaml]`') from e - - -class YamlConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin): - """ - A source class that loads variables from a yaml file - """ - - def __init__( - self, - settings_cls: type[BaseSettings], - yaml_file: PathType | None = DEFAULT_PATH, - yaml_file_encoding: str | None = None, - yaml_config_section: str | None = None, - ): - self.yaml_file_path = yaml_file if yaml_file != DEFAULT_PATH else settings_cls.model_config.get('yaml_file') - self.yaml_file_encoding = ( - yaml_file_encoding - if yaml_file_encoding is not None - else settings_cls.model_config.get('yaml_file_encoding') - ) - self.yaml_config_section = ( - yaml_config_section - if yaml_config_section is not None - else settings_cls.model_config.get('yaml_config_section') - ) - self.yaml_data = self._read_files(self.yaml_file_path) - - if self.yaml_config_section: - try: - self.yaml_data = self.yaml_data[self.yaml_config_section] - except KeyError: - raise KeyError( - f'yaml_config_section key "{self.yaml_config_section}" not found in {self.yaml_file_path}' - ) - super().__init__(settings_cls, self.yaml_data) - - def _read_file(self, file_path: Path) -> dict[str, Any]: - import_yaml() - with open(file_path, encoding=self.yaml_file_encoding) as yaml_file: - return yaml.safe_load(yaml_file) or {} - - def __repr__(self) -> str: - return f'{self.__class__.__name__}(yaml_file={self.yaml_file_path})' - - -__all__ = ['YamlConfigSettingsSource'] diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/sources/types.py b/venv/lib/python3.12/site-packages/pydantic_settings/sources/types.py deleted file mode 100644 index 9a64979..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings/sources/types.py +++ /dev/null @@ -1,78 +0,0 @@ -"""Type definitions for 
pydantic-settings sources.""" - -from __future__ import annotations as _annotations - -from collections.abc import Sequence -from pathlib import Path -from typing import TYPE_CHECKING, Any, Union - -if TYPE_CHECKING: - from pydantic._internal._dataclasses import PydanticDataclass - from pydantic.main import BaseModel - - PydanticModel = Union[PydanticDataclass, BaseModel] -else: - PydanticModel = Any - - -class EnvNoneType(str): - pass - - -class NoDecode: - """Annotation to prevent decoding of a field value.""" - - pass - - -class ForceDecode: - """Annotation to force decoding of a field value.""" - - pass - - -DotenvType = Union[Path, str, Sequence[Union[Path, str]]] -PathType = Union[Path, str, Sequence[Union[Path, str]]] -DEFAULT_PATH: PathType = Path('') - -# This is used as default value for `_env_file` in the `BaseSettings` class and -# `env_file` in `DotEnvSettingsSource` so the default can be distinguished from `None`. -# See the docstring of `BaseSettings` for more details. -ENV_FILE_SENTINEL: DotenvType = Path('') - - -class _CliSubCommand: - pass - - -class _CliPositionalArg: - pass - - -class _CliImplicitFlag: - pass - - -class _CliExplicitFlag: - pass - - -class _CliUnknownArgs: - pass - - -__all__ = [ - 'DEFAULT_PATH', - 'ENV_FILE_SENTINEL', - 'DotenvType', - 'EnvNoneType', - 'ForceDecode', - 'NoDecode', - 'PathType', - 'PydanticModel', - '_CliExplicitFlag', - '_CliImplicitFlag', - '_CliPositionalArg', - '_CliSubCommand', - '_CliUnknownArgs', -] diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/sources/utils.py b/venv/lib/python3.12/site-packages/pydantic_settings/sources/utils.py deleted file mode 100644 index 56bfb3e..0000000 --- a/venv/lib/python3.12/site-packages/pydantic_settings/sources/utils.py +++ /dev/null @@ -1,206 +0,0 @@ -"""Utility functions for pydantic-settings sources.""" - -from __future__ import annotations as _annotations - -from collections import deque -from collections.abc import Mapping, Sequence -from dataclasses import is_dataclass -from enum import Enum -from typing import Any, Optional, cast - -from pydantic import BaseModel, Json, RootModel, Secret -from pydantic._internal._utils import is_model_class -from pydantic.dataclasses import is_pydantic_dataclass -from typing_extensions import get_args, get_origin -from typing_inspection import typing_objects - -from ..exceptions import SettingsError -from ..utils import _lenient_issubclass -from .types import EnvNoneType - - -def _get_env_var_key(key: str, case_sensitive: bool = False) -> str: - return key if case_sensitive else key.lower() - - -def _parse_env_none_str(value: str | None, parse_none_str: str | None = None) -> str | None | EnvNoneType: - return value if not (value == parse_none_str and parse_none_str is not None) else EnvNoneType(value) - - -def parse_env_vars( - env_vars: Mapping[str, str | None], - case_sensitive: bool = False, - ignore_empty: bool = False, - parse_none_str: str | None = None, -) -> Mapping[str, str | None]: - return { - _get_env_var_key(k, case_sensitive): _parse_env_none_str(v, parse_none_str) - for k, v in env_vars.items() - if not (ignore_empty and v == '') - } - - -def _annotation_is_complex(annotation: Any, metadata: list[Any]) -> bool: - # If the model is a root model, the root annotation should be used to - # evaluate the complexity. 
- if typing_objects.is_typealiastype(annotation) or typing_objects.is_typealiastype(get_origin(annotation)): - annotation = annotation.__value__ - if annotation is not None and _lenient_issubclass(annotation, RootModel) and annotation is not RootModel: - annotation = cast('type[RootModel[Any]]', annotation) - root_annotation = annotation.model_fields['root'].annotation - if root_annotation is not None: # pragma: no branch - annotation = root_annotation - - if any(isinstance(md, Json) for md in metadata): # type: ignore[misc] - return False - - origin = get_origin(annotation) - - # Check if annotation is of the form Annotated[type, metadata]. - if typing_objects.is_annotated(origin): - # Return result of recursive call on inner type. - inner, *meta = get_args(annotation) - return _annotation_is_complex(inner, meta) - - if origin is Secret: - return False - - return ( - _annotation_is_complex_inner(annotation) - or _annotation_is_complex_inner(origin) - or hasattr(origin, '__pydantic_core_schema__') - or hasattr(origin, '__get_pydantic_core_schema__') - ) - - -def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool: - if _lenient_issubclass(annotation, (str, bytes)): - return False - - return _lenient_issubclass( - annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque) - ) or is_dataclass(annotation) - - -def _union_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool: - """Check if a union type contains any complex types.""" - return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation)) - - -def _annotation_contains_types( - annotation: type[Any] | None, - types: tuple[Any, ...], - is_include_origin: bool = True, - is_strip_annotated: bool = False, -) -> bool: - """Check if a type annotation contains any of the specified types.""" - if is_strip_annotated: - annotation = _strip_annotated(annotation) - if is_include_origin is True and get_origin(annotation) in types: - return True - for type_ in get_args(annotation): - if _annotation_contains_types(type_, types, is_include_origin=True, is_strip_annotated=is_strip_annotated): - return True - return annotation in types - - -def _strip_annotated(annotation: Any) -> Any: - if typing_objects.is_annotated(get_origin(annotation)): - return annotation.__origin__ - else: - return annotation - - -def _annotation_enum_val_to_name(annotation: type[Any] | None, value: Any) -> Optional[str]: - for type_ in (annotation, get_origin(annotation), *get_args(annotation)): - if _lenient_issubclass(type_, Enum): - if value in tuple(val.value for val in type_): - return type_(value).name - return None - - -def _annotation_enum_name_to_val(annotation: type[Any] | None, name: Any) -> Any: - for type_ in (annotation, get_origin(annotation), *get_args(annotation)): - if _lenient_issubclass(type_, Enum): - if name in tuple(val.name for val in type_): - return type_[name] - return None - - -def _get_model_fields(model_cls: type[Any]) -> dict[str, Any]: - """Get fields from a pydantic model or dataclass.""" - - if is_pydantic_dataclass(model_cls) and hasattr(model_cls, '__pydantic_fields__'): - return model_cls.__pydantic_fields__ - if is_model_class(model_cls): - return model_cls.model_fields - raise SettingsError(f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass') - - -def _get_alias_names( - field_name: str, - field_info: Any, - alias_path_args: Optional[dict[str, Optional[int]]] = None, - case_sensitive: bool = True, -) -> tuple[tuple[str, ...], bool]: - 
"""Get alias names for a field, handling alias paths and case sensitivity.""" - from pydantic import AliasChoices, AliasPath - - alias_names: list[str] = [] - is_alias_path_only: bool = True - if not any((field_info.alias, field_info.validation_alias)): - alias_names += [field_name] - is_alias_path_only = False - else: - new_alias_paths: list[AliasPath] = [] - for alias in (field_info.alias, field_info.validation_alias): - if alias is None: - continue - elif isinstance(alias, str): - alias_names.append(alias) - is_alias_path_only = False - elif isinstance(alias, AliasChoices): - for name in alias.choices: - if isinstance(name, str): - alias_names.append(name) - is_alias_path_only = False - else: - new_alias_paths.append(name) - else: - new_alias_paths.append(alias) - for alias_path in new_alias_paths: - name = cast(str, alias_path.path[0]) - name = name.lower() if not case_sensitive else name - if alias_path_args is not None: - alias_path_args[name] = ( - alias_path.path[1] if len(alias_path.path) > 1 and isinstance(alias_path.path[1], int) else None - ) - if not alias_names and is_alias_path_only: - alias_names.append(name) - if not case_sensitive: - alias_names = [alias_name.lower() for alias_name in alias_names] - return tuple(dict.fromkeys(alias_names)), is_alias_path_only - - -def _is_function(obj: Any) -> bool: - """Check if an object is a function.""" - from types import BuiltinFunctionType, FunctionType - - return isinstance(obj, (FunctionType, BuiltinFunctionType)) - - -__all__ = [ - '_annotation_contains_types', - '_annotation_enum_name_to_val', - '_annotation_enum_val_to_name', - '_annotation_is_complex', - '_annotation_is_complex_inner', - '_get_alias_names', - '_get_env_var_key', - '_get_model_fields', - '_is_function', - '_parse_env_none_str', - '_strip_annotated', - '_union_is_complex', - 'parse_env_vars', -] diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/utils.py b/venv/lib/python3.12/site-packages/pydantic_settings/utils.py index 74c99be..73090b8 100644 --- a/venv/lib/python3.12/site-packages/pydantic_settings/utils.py +++ b/venv/lib/python3.12/site-packages/pydantic_settings/utils.py @@ -1,19 +1,14 @@ -import sys -import types from pathlib import Path -from typing import Any, _GenericAlias # type: ignore [attr-defined] -from typing_extensions import get_origin - -_PATH_TYPE_LABELS = { - Path.is_dir: 'directory', - Path.is_file: 'file', - Path.is_mount: 'mount point', - Path.is_symlink: 'symlink', - Path.is_block_device: 'block device', - Path.is_char_device: 'char device', - Path.is_fifo: 'FIFO', - Path.is_socket: 'socket', +path_type_labels = { + 'is_dir': 'directory', + 'is_file': 'file', + 'is_mount': 'mount point', + 'is_symlink': 'symlink', + 'is_block_device': 'block device', + 'is_char_device': 'char device', + 'is_fifo': 'FIFO', + 'is_socket': 'socket', } @@ -22,27 +17,8 @@ def path_type_label(p: Path) -> str: Find out what sort of thing a path is. """ assert p.exists(), 'path does not exist' - for method, name in _PATH_TYPE_LABELS.items(): - if method(p): + for method, name in path_type_labels.items(): + if getattr(p, method)(): return name - return 'unknown' # pragma: no cover - - -# TODO remove and replace usage by `isinstance(cls, type) and issubclass(cls, class_or_tuple)` -# once we drop support for Python 3.10. 
-def _lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool: # pragma: no cover - try: - return isinstance(cls, type) and issubclass(cls, class_or_tuple) - except TypeError: - if get_origin(cls) is not None: - # Up until Python 3.10, isinstance(, type) is True - # (e.g. list[int]) - return False - raise - - -if sys.version_info < (3, 10): - _WithArgsTypes = tuple() -else: - _WithArgsTypes = (_GenericAlias, types.GenericAlias, types.UnionType) + return 'unknown' diff --git a/venv/lib/python3.12/site-packages/pydantic_settings/version.py b/venv/lib/python3.12/site-packages/pydantic_settings/version.py index a13b4dc..2f7444b 100644 --- a/venv/lib/python3.12/site-packages/pydantic_settings/version.py +++ b/venv/lib/python3.12/site-packages/pydantic_settings/version.py @@ -1 +1 @@ -VERSION = '2.11.0' +VERSION = '2.0.3' diff --git a/venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/INSTALLER diff --git a/venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/licenses/LICENSE b/venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/LICENSE similarity index 100% rename from venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/licenses/LICENSE rename to venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/LICENSE diff --git a/venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/METADATA b/venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/METADATA similarity index 86% rename from venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/METADATA rename to venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/METADATA index 8dcf31d..a85c4f9 100644 --- a/venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/METADATA +++ b/venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/METADATA @@ -1,6 +1,6 @@ -Metadata-Version: 2.4 +Metadata-Version: 2.1 Name: python-dotenv -Version: 1.1.1 +Version: 1.0.0 Summary: Read key-value pairs from a .env file and set them as environment variables Home-page: https://github.com/theskumar/python-dotenv Author: Saurabh Kumar @@ -10,11 +10,11 @@ Keywords: environment variables,deployments,settings,env,dotenv,configurations,p Classifier: Development Status :: 5 - Production/Stable Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Classifier: Programming Language :: Python :: 3.11 Classifier: Programming Language :: Python :: 3.12 -Classifier: Programming Language :: Python :: 3.13 Classifier: Programming Language :: Python :: Implementation :: PyPy Classifier: Intended Audience :: Developers Classifier: Intended Audience :: System Administrators @@ -23,32 +23,20 @@ Classifier: Operating System :: OS Independent Classifier: Topic :: System :: Systems Administration Classifier: Topic :: Utilities Classifier: Environment :: Web Environment -Requires-Python: >=3.9 +Requires-Python: >=3.8 Description-Content-Type: text/markdown License-File: LICENSE Provides-Extra: cli -Requires-Dist: click>=5.0; extra == "cli" -Dynamic: author -Dynamic: author-email -Dynamic: classifier -Dynamic: description -Dynamic: 
description-content-type -Dynamic: home-page -Dynamic: keywords -Dynamic: license -Dynamic: license-file -Dynamic: provides-extra -Dynamic: requires-python -Dynamic: summary +Requires-Dist: click (>=5.0) ; extra == 'cli' # python-dotenv [![Build Status][build_status_badge]][build_status_link] [![PyPI version][pypi_badge]][pypi_link] -python-dotenv reads key-value pairs from a `.env` file and can set them as environment +Python-dotenv reads key-value pairs from a `.env` file and can set them as environment variables. It helps in the development of applications following the -[12-factor](https://12factor.net/) principles. +[12-factor](http://12factor.net/) principles. - [Getting Started](#getting-started) - [Other Use Cases](#other-use-cases) @@ -72,20 +60,20 @@ If your application takes its configuration from environment variables, like a 1 application, launching it in development is not very practical because you have to set those environment variables yourself. -To help you with that, you can add python-dotenv to your application to make it load the +To help you with that, you can add Python-dotenv to your application to make it load the configuration from a `.env` file when it is present (e.g. in development) while remaining configurable via the environment: ```python from dotenv import load_dotenv -load_dotenv() # take environment variables +load_dotenv() # take environment variables from .env. # Code of your application, which uses environment variables (e.g. from `os.environ` or # `os.getenv`) as if they came from the actual environment. ``` -By default, `load_dotenv` doesn't override existing environment variables and looks for a `.env` file in same directory as python script or searches for it incrementally higher up. +By default, `load_dotenv` doesn't override existing environment variables. To configure the development environment, add a `.env` in the root directory of your project: @@ -244,7 +232,7 @@ empty string. ### Variable expansion -python-dotenv can interpolate variables using POSIX variable expansion. +Python-dotenv can interpolate variables using POSIX variable expansion. With `load_dotenv(override=True)` or `dotenv_values()`, the value of a variable is the first of the values defined in the following list: @@ -286,7 +274,7 @@ people](https://github.com/theskumar/python-dotenv/graphs/contributors). [build_status_badge]: https://github.com/theskumar/python-dotenv/actions/workflows/test.yml/badge.svg [build_status_link]: https://github.com/theskumar/python-dotenv/actions/workflows/test.yml [pypi_badge]: https://badge.fury.io/py/python-dotenv.svg -[pypi_link]: https://badge.fury.io/py/python-dotenv +[pypi_link]: http://badge.fury.io/py/python-dotenv [python_streams]: https://docs.python.org/3/library/io.html # Changelog @@ -296,50 +284,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
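A compact illustration of the variable-expansion behaviour described in the README section above (an in-memory stream stands in for a real `.env` file; the keys are illustrative):

```python
from io import StringIO

from dotenv import dotenv_values

# ${DOMAIN} is resolved from the value defined earlier in the same file.
config = dotenv_values(stream=StringIO("DOMAIN=example.org\nADMIN_EMAIL=admin@${DOMAIN}\n"))
assert config["ADMIN_EMAIL"] == "admin@example.org"
```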
-## [1.1.1] - 2025-06-24 - -## Fixed - -* CLI: Ensure `find_dotenv` work reliably on python 3.13 by [@theskumar] in [#563](https://github.com/theskumar/python-dotenv/pull/563) -* CLI: revert the use of execvpe on Windows by [@wrongontheinternet] in [#566](https://github.com/theskumar/python-dotenv/pull/566) - - -## [1.1.0] - 2025-03-25 - -**Feature** - -- Add support for python 3.13 -- Enhance `dotenv run`, switch to `execvpe` for better resource management and signal handling ([#523]) by [@eekstunt] - -**Fixed** - -- `find_dotenv` and `load_dotenv` now correctly looks up at the current directory when running in debugger or pdb ([#553] by [@randomseed42]) - -**Misc** - -- Drop support for Python 3.8 - -## [1.0.1] - 2024-01-23 - -**Fixed** - -* Gracefully handle code which has been imported from a zipfile ([#456] by [@samwyma]) -* Allow modules using `load_dotenv` to be reloaded when launched in a separate thread ([#497] by [@freddyaboulton]) -* Fix file not closed after deletion, handle error in the rewrite function ([#469] by [@Qwerty-133]) - -**Misc** -* Use pathlib.Path in tests ([#466] by [@eumiro]) -* Fix year in release date in changelog.md ([#454] by [@jankislinger]) -* Use https in README links ([#474] by [@Nicals]) - -## [1.0.0] - 2023-02-24 +## [1.0.0] **Fixed** * Drop support for python 3.7, add python 3.12-dev (#449 by [@theskumar]) * Handle situations where the cwd does not exist. (#446 by [@jctanner]) -## [0.21.1] - 2023-01-21 +## [0.21.1] - 2022-01-21 **Added** @@ -631,7 +583,7 @@ project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## 0.5.1 -- Fix `find_dotenv` - it now start search from the file where this +- Fix find\_dotenv - it now start search from the file where this function is called from. ## 0.5.0 @@ -655,13 +607,6 @@ project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). [#176]: https://github.com/theskumar/python-dotenv/issues/176 [#183]: https://github.com/theskumar/python-dotenv/issues/183 [#359]: https://github.com/theskumar/python-dotenv/issues/359 -[#469]: https://github.com/theskumar/python-dotenv/issues/469 -[#456]: https://github.com/theskumar/python-dotenv/issues/456 -[#466]: https://github.com/theskumar/python-dotenv/issues/466 -[#454]: https://github.com/theskumar/python-dotenv/issues/454 -[#474]: https://github.com/theskumar/python-dotenv/issues/474 -[#523]: https://github.com/theskumar/python-dotenv/issues/523 -[#553]: https://github.com/theskumar/python-dotenv/issues/553 [@alanjds]: https://github.com/alanjds [@altendky]: https://github.com/altendky @@ -672,31 +617,24 @@ project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
[@cjauvin]: https://github.com/cjauvin [@eaf]: https://github.com/eaf [@earlbread]: https://github.com/earlbread -[@eekstunt]: https://github.com/eekstunt [@eggplants]: https://github.com/@eggplants [@ekohl]: https://github.com/ekohl [@elbehery95]: https://github.com/elbehery95 -[@eumiro]: https://github.com/eumiro [@Flimm]: https://github.com/Flimm -[@freddyaboulton]: https://github.com/freddyaboulton [@gergelyk]: https://github.com/gergelyk [@gongqingkui]: https://github.com/gongqingkui [@greyli]: https://github.com/greyli [@harveer07]: https://github.com/@harveer07 [@jadutter]: https://github.com/jadutter -[@jankislinger]: https://github.com/jankislinger [@jctanner]: https://github.com/jctanner [@larsks]: https://github.com/@larsks [@lsmith77]: https://github.com/lsmith77 [@mgorny]: https://github.com/mgorny [@naorlivne]: https://github.com/@naorlivne -[@Nicals]: https://github.com/Nicals [@Nougat-Waffle]: https://github.com/Nougat-Waffle [@qnighy]: https://github.com/qnighy -[@Qwerty-133]: https://github.com/Qwerty-133 [@rabinadk1]: https://github.com/@rabinadk1 [@sammck]: https://github.com/@sammck -[@samwyma]: https://github.com/samwyma [@snobu]: https://github.com/snobu [@techalchemy]: https://github.com/techalchemy [@theGOTOguy]: https://github.com/theGOTOguy @@ -706,13 +644,9 @@ project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). [@x-yuri]: https://github.com/x-yuri [@yannham]: https://github.com/yannham [@zueve]: https://github.com/zueve -[@randomseed42]: https://github.com/zueve -[@wrongontheinternet]: https://github.com/wrongontheinternet -[Unreleased]: https://github.com/theskumar/python-dotenv/compare/v1.1.1...HEAD -[1.1.1]: https://github.com/theskumar/python-dotenv/compare/v1.1.0...1.1.1 -[1.1.0]: https://github.com/theskumar/python-dotenv/compare/v1.0.1...v1.1.0 -[1.0.1]: https://github.com/theskumar/python-dotenv/compare/v1.0.0...v1.0.1 + +[Unreleased]: https://github.com/theskumar/python-dotenv/compare/v1.0.0...HEAD [1.0.0]: https://github.com/theskumar/python-dotenv/compare/v0.21.0...v1.0.0 [0.21.1]: https://github.com/theskumar/python-dotenv/compare/v0.21.0...v0.21.1 [0.21.0]: https://github.com/theskumar/python-dotenv/compare/v0.20.0...v0.21.0 diff --git a/venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/RECORD b/venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/RECORD similarity index 51% rename from venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/RECORD rename to venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/RECORD index 941259b..097cbb1 100644 --- a/venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/RECORD +++ b/venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/RECORD @@ -1,4 +1,4 @@ -../../../bin/dotenv,sha256=nShk7mUtz-Nvgv-GkAlo04_90ec22Y8jjj1h5J-cp3s,233 +../../../bin/dotenv,sha256=D_1tHcDz_R59CEh0UvAZiuSdFp_MxVYmSyIBtSA9oUc,237 dotenv/__init__.py,sha256=WBU5SfSiKAhS3hzu17ykNuuwbuwyDCX91Szv4vUeOuM,1292 dotenv/__main__.py,sha256=N0RhLG7nHIqtlJHwwepIo-zbJPNx9sewCCRGY528h_4,129 dotenv/__pycache__/__init__.cpython-312.pyc,, @@ -9,18 +9,18 @@ dotenv/__pycache__/main.cpython-312.pyc,, dotenv/__pycache__/parser.cpython-312.pyc,, dotenv/__pycache__/variables.cpython-312.pyc,, dotenv/__pycache__/version.cpython-312.pyc,, -dotenv/cli.py,sha256=ut83SItbWcmEahAkSOzkHqvRKhqhj0tA53vcXpyleOM,6197 +dotenv/cli.py,sha256=_ttQuR9Yl4k1PT53ByISkDjJ3kO_N_LzIDZzZ95uXEk,5809 dotenv/ipython.py,sha256=avI6aez_RxnBptYgchIquF2TSgKI-GOhY3ppiu3VuWE,1303 
-dotenv/main.py,sha256=HJgkS0XZcd0f2VZaVGxlUcrOEhqBcmQ6Lz9hQrMfaus,12467 +dotenv/main.py,sha256=6j1GW8kNeZAooqffdajLne_dq_TJLi2Mk63DRNJjXLk,11932 dotenv/parser.py,sha256=QgU5HwMwM2wMqt0vz6dHTJ4nzPmwqRqvi4MSyeVifgU,5186 dotenv/py.typed,sha256=8PjyZ1aVoQpRVvt71muvuq5qE-jTFZkK-GLHkhdebmc,26 dotenv/variables.py,sha256=CD0qXOvvpB3q5RpBQMD9qX6vHX7SyW-SuiwGMFSlt08,2348 -dotenv/version.py,sha256=q8_5C0f-8mHWNb6mMw02zlYPnEGXBqvOmP3z0CEwZKM,22 -python_dotenv-1.1.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -python_dotenv-1.1.1.dist-info/METADATA,sha256=dELvSKXwZ-NbQKAe-k-uJM8khmVN8ZM92B5tyY801yY,24628 -python_dotenv-1.1.1.dist-info/RECORD,, -python_dotenv-1.1.1.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -python_dotenv-1.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91 -python_dotenv-1.1.1.dist-info/entry_points.txt,sha256=yRl1rCbswb1nQTQ_gZRlCw5QfabztUGnfGWLhlXFNdI,47 -python_dotenv-1.1.1.dist-info/licenses/LICENSE,sha256=gGGbcEnwjIFoOtDgHwjyV6hAZS3XHugxRtNmWMfSwrk,1556 -python_dotenv-1.1.1.dist-info/top_level.txt,sha256=eyqUH4SHJNr6ahOYlxIunTr4XinE8Z5ajWLdrK3r0D8,7 +dotenv/version.py,sha256=J-j-u0itpEFT6irdmWmixQqYMadNl1X91TxUmoiLHMI,22 +python_dotenv-1.0.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +python_dotenv-1.0.0.dist-info/LICENSE,sha256=gGGbcEnwjIFoOtDgHwjyV6hAZS3XHugxRtNmWMfSwrk,1556 +python_dotenv-1.0.0.dist-info/METADATA,sha256=0oze1EyeRIUTg91jCTJGbnxQR6mz_FkOW73CmeueUak,21991 +python_dotenv-1.0.0.dist-info/RECORD,, +python_dotenv-1.0.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +python_dotenv-1.0.0.dist-info/WHEEL,sha256=2wepM1nk4DS4eFpYrW1TTqPcoGNfHhhO_i5m4cOimbo,92 +python_dotenv-1.0.0.dist-info/entry_points.txt,sha256=yRl1rCbswb1nQTQ_gZRlCw5QfabztUGnfGWLhlXFNdI,47 +python_dotenv-1.0.0.dist-info/top_level.txt,sha256=eyqUH4SHJNr6ahOYlxIunTr4XinE8Z5ajWLdrK3r0D8,7 diff --git a/venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/REQUESTED b/venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/REQUESTED similarity index 100% rename from venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/REQUESTED rename to venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/REQUESTED diff --git a/venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/WHEEL b/venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/WHEEL new file mode 100644 index 0000000..57e3d84 --- /dev/null +++ b/venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.38.4) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/entry_points.txt b/venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/entry_points.txt similarity index 100% rename from venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/entry_points.txt rename to venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/entry_points.txt diff --git a/venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/top_level.txt b/venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/top_level.txt similarity index 100% rename from venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/top_level.txt rename to venv/lib/python3.12/site-packages/python_dotenv-1.0.0.dist-info/top_level.txt diff --git a/venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/WHEEL 
b/venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/WHEEL deleted file mode 100644 index e7fa31b..0000000 --- a/venv/lib/python3.12/site-packages/python_dotenv-1.1.1.dist-info/WHEEL +++ /dev/null @@ -1,5 +0,0 @@ -Wheel-Version: 1.0 -Generator: setuptools (80.9.0) -Root-Is-Purelib: true -Tag: py3-none-any - diff --git a/venv/lib/python3.12/site-packages/python_multipart-0.0.20.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/python_multipart-0.0.20.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/INSTALLER diff --git a/venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/licenses/LICENSE b/venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/LICENSE similarity index 100% rename from venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/licenses/LICENSE rename to venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/LICENSE diff --git a/venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/METADATA b/venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/METADATA similarity index 85% rename from venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/METADATA rename to venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/METADATA index bd04b7c..314d09b 100644 --- a/venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/METADATA +++ b/venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/METADATA @@ -1,6 +1,6 @@ -Metadata-Version: 2.4 +Metadata-Version: 2.1 Name: python-jose -Version: 3.5.0 +Version: 3.3.0 Summary: JOSE implementation in Python Home-page: http://github.com/mpdavis/python-jose Author: Michael Davis @@ -11,6 +11,7 @@ Project-URL: Source, https://github.com/mpdavis/python-jose/ Project-URL: Tracker, https://github.com/mpdavis/python-jose/issues/ Project-URL: Changelog, https://github.com/mpdavis/python-jose/blob/master/CHANGELOG.md Keywords: jose jws jwe jwt json web token security signing +Platform: UNKNOWN Classifier: Development Status :: 5 - Production/Stable Classifier: Intended Audience :: Developers Classifier: Natural Language :: English @@ -18,28 +19,24 @@ Classifier: License :: OSI Approved :: MIT License Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 3 Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.6 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 -Classifier: Programming Language :: Python :: 3.10 -Classifier: Programming Language :: Python :: 3.11 -Classifier: Programming Language :: Python :: 3.12 -Classifier: Programming Language :: Python :: 3.13 Classifier: Programming Language :: Python :: Implementation :: PyPy Classifier: Topic :: Utilities -Requires-Python: >=3.9 License-File: LICENSE -Requires-Dist: ecdsa!=0.15 -Requires-Dist: rsa!=4.1.1,!=4.4,<5.0,>=4.0 -Requires-Dist: pyasn1>=0.5.0 -Provides-Extra: test -Requires-Dist: pytest; extra == "test" -Requires-Dist: pytest-cov; extra == "test" +Requires-Dist: ecdsa (!=0.15) +Requires-Dist: rsa +Requires-Dist: pyasn1 Provides-Extra: cryptography -Requires-Dist: cryptography>=3.4.0; extra == "cryptography" +Requires-Dist: cryptography (>=3.4.0) ; extra == 'cryptography' Provides-Extra: pycrypto -Requires-Dist: pycrypto<2.7.0,>=2.6.0; extra == "pycrypto" +Requires-Dist: 
pycrypto (<2.7.0,>=2.6.0) ; extra == 'pycrypto' +Requires-Dist: pyasn1 ; extra == 'pycrypto' Provides-Extra: pycryptodome -Requires-Dist: pycryptodome<4.0.0,>=3.3.1; extra == "pycryptodome" -Dynamic: license-file +Requires-Dist: pycryptodome (<4.0.0,>=3.3.1) ; extra == 'pycryptodome' +Requires-Dist: pyasn1 ; extra == 'pycryptodome' python-jose =========== @@ -130,8 +127,8 @@ This library was originally based heavily on the work of the folks over at PyJWT .. |pypi| image:: https://img.shields.io/pypi/v/python-jose?style=flat-square :target: https://pypi.org/project/python-jose/ :alt: PyPI -.. |Github Actions CI Status| image:: https://github.com/mpdavis/python-jose/actions/workflows/ci.yml/badge.svg - :target: https://github.com/mpdavis/python-jose/actions/workflows/ci.yml +.. |Github Actions CI Status| image:: https://github.com/mpdavis/python-jose/workflows/main/badge.svg?branch=master + :target: https://github.com/mpdavis/python-jose/actions?workflow=main :alt: Github Actions CI Status .. |Coverage Status| image:: http://codecov.io/github/mpdavis/python-jose/coverage.svg?branch=master :target: http://codecov.io/github/mpdavis/python-jose?branch=master @@ -147,3 +144,5 @@ This library was originally based heavily on the work of the folks over at PyJWT .. |style| image:: https://img.shields.io/badge/code%20style-black-000000.svg :target: https://github.com/psf/black :alt: Code style: black + + diff --git a/venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/RECORD b/venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/RECORD new file mode 100644 index 0000000..37d3013 --- /dev/null +++ b/venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/RECORD @@ -0,0 +1,37 @@ +jose/__init__.py,sha256=0XQau8AXQwNwztdDWVr6l7PyWq9w0qN1R0PGcrsMIGM,322 +jose/__pycache__/__init__.cpython-312.pyc,, +jose/__pycache__/constants.cpython-312.pyc,, +jose/__pycache__/exceptions.cpython-312.pyc,, +jose/__pycache__/jwe.cpython-312.pyc,, +jose/__pycache__/jwk.cpython-312.pyc,, +jose/__pycache__/jws.cpython-312.pyc,, +jose/__pycache__/jwt.cpython-312.pyc,, +jose/__pycache__/utils.cpython-312.pyc,, +jose/backends/__init__.py,sha256=yDExDpMlV6U4IBgk2Emov6cpQ2zQftFEh0J3yGaV2Lo,1091 +jose/backends/__pycache__/__init__.cpython-312.pyc,, +jose/backends/__pycache__/_asn1.cpython-312.pyc,, +jose/backends/__pycache__/base.cpython-312.pyc,, +jose/backends/__pycache__/cryptography_backend.cpython-312.pyc,, +jose/backends/__pycache__/ecdsa_backend.cpython-312.pyc,, +jose/backends/__pycache__/native.cpython-312.pyc,, +jose/backends/__pycache__/rsa_backend.cpython-312.pyc,, +jose/backends/_asn1.py,sha256=etzWxBjkt0Et19_IQ92Pj61bAe0nCgPN7bTvSuz8W3s,2655 +jose/backends/base.py,sha256=0kuposKfixAR2W3enKuYdqEZpVG56ODOQDEdgq_pmvs,2224 +jose/backends/cryptography_backend.py,sha256=28-792EKVGjjq2nUoCWdfyPGkoXfWN5vHFO7uolCtog,22763 +jose/backends/ecdsa_backend.py,sha256=ORORepIpIS9D4s6Vtmhli5GZV9kj3CJj2_Mv0ARKGqE,5055 +jose/backends/native.py,sha256=9zyounmjG1ZgVJYkseMcDosJOBILLRyu_UbzhH7ZZ1o,2289 +jose/backends/rsa_backend.py,sha256=RKIC_bphhe52t2D_jEINO_ngj50ty9wXnv7cVO1EmdE,10942 +jose/constants.py,sha256=A0yHNjsby-YVOeKhcoN0rxoM8bai1JlVDvZx82UCZeE,2596 +jose/exceptions.py,sha256=K_ueFBsmTwQySE0CU09iMthOAdPaTQ_HvzRz9lYT1ls,791 +jose/jwe.py,sha256=jSBN3aT2D7xAQ3D-5cVf_9kZebchAI3qoaf-3yMLanY,21976 +jose/jwk.py,sha256=3A1dXXfhGIMQvT43EBAQgiShQZuqLpUZk_xWvW7c9cs,2024 +jose/jws.py,sha256=qgMDRIlyGbGfAGApQfuAL5Qr66Qqa8aYUC3qUO8qM_g,7820 +jose/jwt.py,sha256=7czQxPsfOavLpY6jJTetdPN_FQDcZmmkaZ2QtV3bVPw,17310 
+jose/utils.py,sha256=_doSyRne-OygjSI3Iz1kWTSGnwVHHMA6_wYHOS1rhCw,3190 +python_jose-3.3.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +python_jose-3.3.0.dist-info/LICENSE,sha256=peYY7ubUlvd62K5w_qbt8UgVlVji0ih4fZB2yQCi-SY,1081 +python_jose-3.3.0.dist-info/METADATA,sha256=Sk_zCqxtDfFMG5lAL6EG7Br3KP0yhtw_IsJBwZaDliM,5403 +python_jose-3.3.0.dist-info/RECORD,, +python_jose-3.3.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +python_jose-3.3.0.dist-info/WHEEL,sha256=Z-nyYpwrcSqxfdux5Mbn_DQ525iP7J2DG3JgGvOYyTQ,110 +python_jose-3.3.0.dist-info/top_level.txt,sha256=WIdGzeaROX_xI9hGqyB3h4KKXKGKU2XmV1XphZWIrD8,19 diff --git a/venv/lib/python3.12/site-packages/python_multipart-0.0.20.dist-info/REQUESTED b/venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/REQUESTED similarity index 100% rename from venv/lib/python3.12/site-packages/python_multipart-0.0.20.dist-info/REQUESTED rename to venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/REQUESTED diff --git a/venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/WHEEL b/venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/WHEEL similarity index 70% rename from venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/WHEEL rename to venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/WHEEL index 5f133db..01b8fc7 100644 --- a/venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/WHEEL +++ b/venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/WHEEL @@ -1,5 +1,5 @@ Wheel-Version: 1.0 -Generator: setuptools (80.9.0) +Generator: bdist_wheel (0.36.2) Root-Is-Purelib: true Tag: py2-none-any Tag: py3-none-any diff --git a/venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/top_level.txt b/venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/top_level.txt new file mode 100644 index 0000000..3ac440a --- /dev/null +++ b/venv/lib/python3.12/site-packages/python_jose-3.3.0.dist-info/top_level.txt @@ -0,0 +1,2 @@ +jose +jose/backends diff --git a/venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/RECORD b/venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/RECORD deleted file mode 100644 index 6efc192..0000000 --- a/venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/RECORD +++ /dev/null @@ -1,37 +0,0 @@ -jose/__init__.py,sha256=x8vWB0drBqifxYnt-lognbrRuKb6X3qRxAd5l053gDw,322 -jose/__pycache__/__init__.cpython-312.pyc,, -jose/__pycache__/constants.cpython-312.pyc,, -jose/__pycache__/exceptions.cpython-312.pyc,, -jose/__pycache__/jwe.cpython-312.pyc,, -jose/__pycache__/jwk.cpython-312.pyc,, -jose/__pycache__/jws.cpython-312.pyc,, -jose/__pycache__/jwt.cpython-312.pyc,, -jose/__pycache__/utils.cpython-312.pyc,, -jose/backends/__init__.py,sha256=kaDsN5XktlfA8F_3060PeXdaL4BNdvoUTzPLAjP_v_s,861 -jose/backends/__pycache__/__init__.cpython-312.pyc,, -jose/backends/__pycache__/_asn1.cpython-312.pyc,, -jose/backends/__pycache__/base.cpython-312.pyc,, -jose/backends/__pycache__/cryptography_backend.cpython-312.pyc,, -jose/backends/__pycache__/ecdsa_backend.cpython-312.pyc,, -jose/backends/__pycache__/native.cpython-312.pyc,, -jose/backends/__pycache__/rsa_backend.cpython-312.pyc,, -jose/backends/_asn1.py,sha256=2CqnRB7LojTrNU4d1HC9BA2WkJv5OOM6gyn6B-tVwkk,2656 -jose/backends/base.py,sha256=0kuposKfixAR2W3enKuYdqEZpVG56ODOQDEdgq_pmvs,2224 -jose/backends/cryptography_backend.py,sha256=v1XqO6PIUpYwyAAsMob1FD9D4q6rPwfX7CGV-KxFlAU,22175 
-jose/backends/ecdsa_backend.py,sha256=ORORepIpIS9D4s6Vtmhli5GZV9kj3CJj2_Mv0ARKGqE,5055 -jose/backends/native.py,sha256=uZuP8EqihAPsmGdxslMyhh-DGoe1yXLXmB_P-2zXyS8,2096 -jose/backends/rsa_backend.py,sha256=-tiQF_G2v16a5PLCLjEVwSoYaeBy3h-Tj6KKmtYlAuY,10941 -jose/constants.py,sha256=tPZLo6oI8mesxFXOCiulE--GcANW1V37wkO0f1vVvqY,2625 -jose/exceptions.py,sha256=K_ueFBsmTwQySE0CU09iMthOAdPaTQ_HvzRz9lYT1ls,791 -jose/jwe.py,sha256=L7GZsKm6qc2ApDtOnM0YDs2KhP1R3hWMoSoIKi8cQQg,22700 -jose/jwk.py,sha256=TuIrPoKkVFZcwrnp_IcwSdUJL79-pAGCmauAyysmCoQ,1994 -jose/jws.py,sha256=P2SAUhO6ZxjhWk6XHFpulgpREfAHJ_ktAgzPg-OJ_3w,7894 -jose/jwt.py,sha256=OXVuHOP6g05tHyzo9eP4tLn8RzqbdpKrEWU6VwtNOrA,18158 -jose/utils.py,sha256=3R6EViEPwc2NreO1njUsab9rHKnc6fzfRJmWo9f4Y90,4824 -python_jose-3.5.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -python_jose-3.5.0.dist-info/METADATA,sha256=FA4Lhvk8-BZzGNOUbzr4aH84uj0ytjG5SMK9p7oQLwY,5508 -python_jose-3.5.0.dist-info/RECORD,, -python_jose-3.5.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -python_jose-3.5.0.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109 -python_jose-3.5.0.dist-info/licenses/LICENSE,sha256=peYY7ubUlvd62K5w_qbt8UgVlVji0ih4fZB2yQCi-SY,1081 -python_jose-3.5.0.dist-info/top_level.txt,sha256=tWZmXhRSm0aANjAdRbjirCMnYOQdMwpQqdJUSmANjtk,5 diff --git a/venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/top_level.txt b/venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/top_level.txt deleted file mode 100644 index 268baa1..0000000 --- a/venv/lib/python3.12/site-packages/python_jose-3.5.0.dist-info/top_level.txt +++ /dev/null @@ -1 +0,0 @@ -jose diff --git a/venv/lib/python3.12/site-packages/python_multipart-0.0.20.dist-info/METADATA b/venv/lib/python3.12/site-packages/python_multipart-0.0.20.dist-info/METADATA deleted file mode 100644 index 155ce8b..0000000 --- a/venv/lib/python3.12/site-packages/python_multipart-0.0.20.dist-info/METADATA +++ /dev/null @@ -1,40 +0,0 @@ -Metadata-Version: 2.4 -Name: python-multipart -Version: 0.0.20 -Summary: A streaming multipart parser for Python -Project-URL: Homepage, https://github.com/Kludex/python-multipart -Project-URL: Documentation, https://kludex.github.io/python-multipart/ -Project-URL: Changelog, https://github.com/Kludex/python-multipart/blob/master/CHANGELOG.md -Project-URL: Source, https://github.com/Kludex/python-multipart -Author-email: Andrew Dunham , Marcelo Trylesinski -License-Expression: Apache-2.0 -License-File: LICENSE.txt -Classifier: Development Status :: 5 - Production/Stable -Classifier: Environment :: Web Environment -Classifier: Intended Audience :: Developers -Classifier: License :: OSI Approved :: Apache Software License -Classifier: Operating System :: OS Independent -Classifier: Programming Language :: Python :: 3 -Classifier: Programming Language :: Python :: 3 :: Only -Classifier: Programming Language :: Python :: 3.8 -Classifier: Programming Language :: Python :: 3.9 -Classifier: Programming Language :: Python :: 3.10 -Classifier: Programming Language :: Python :: 3.11 -Classifier: Programming Language :: Python :: 3.12 -Classifier: Topic :: Software Development :: Libraries :: Python Modules -Requires-Python: >=3.8 -Description-Content-Type: text/markdown - -# [Python-Multipart](https://kludex.github.io/python-multipart/) - -[![Package version](https://badge.fury.io/py/python-multipart.svg)](https://pypi.python.org/pypi/python-multipart) -[![Supported Python 
Version](https://img.shields.io/pypi/pyversions/python-multipart.svg?color=%2334D058)](https://pypi.org/project/python-multipart) - ---- - -`python-multipart` is an Apache2-licensed streaming multipart parser for Python. -Test coverage is currently 100%. - -## Why? - -Because streaming uploads are awesome for large files. diff --git a/venv/lib/python3.12/site-packages/python_multipart-0.0.20.dist-info/RECORD b/venv/lib/python3.12/site-packages/python_multipart-0.0.20.dist-info/RECORD deleted file mode 100644 index f80836d..0000000 --- a/venv/lib/python3.12/site-packages/python_multipart-0.0.20.dist-info/RECORD +++ /dev/null @@ -1,23 +0,0 @@ -multipart/__init__.py,sha256=_ttxOAFnTN4jeac-_8NeXpaXYYo0PPEIp8Ogo4YFNHE,935 -multipart/__pycache__/__init__.cpython-312.pyc,, -multipart/__pycache__/decoders.cpython-312.pyc,, -multipart/__pycache__/exceptions.cpython-312.pyc,, -multipart/__pycache__/multipart.cpython-312.pyc,, -multipart/decoders.py,sha256=XvkAwTU9UFPiXkc0hkvovHf0W6H3vK-2ieWlhav02hQ,40 -multipart/exceptions.py,sha256=6D_X-seiOmMAlIeiGlPGUs8-vpcvIGJeQycFMDb1f7A,42 -multipart/multipart.py,sha256=8fDH14j_VMbrch_58wlzi63XNARGv80kOZAyN72aG7A,41 -python_multipart-0.0.20.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -python_multipart-0.0.20.dist-info/METADATA,sha256=h2GtPOVShbVkpBUrjp5KE3t6eiJJhd0_WCaCXrb5TgU,1817 -python_multipart-0.0.20.dist-info/RECORD,, -python_multipart-0.0.20.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -python_multipart-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87 -python_multipart-0.0.20.dist-info/licenses/LICENSE.txt,sha256=qOgzF2zWF9rwC51tOfoVyo7evG0WQwec0vSJPAwom-I,556 -python_multipart/__init__.py,sha256=Nlw6Yrc__qXnCZLo17OzbJR2w2mwiSFk69IG4Wl35EU,512 -python_multipart/__pycache__/__init__.cpython-312.pyc,, -python_multipart/__pycache__/decoders.cpython-312.pyc,, -python_multipart/__pycache__/exceptions.cpython-312.pyc,, -python_multipart/__pycache__/multipart.cpython-312.pyc,, -python_multipart/decoders.py,sha256=JM43FMNn_EKP0MI2ZkuZHhNa0MOASoIR0U5TvdG585k,6669 -python_multipart/exceptions.py,sha256=a9buSOv_eiHZoukEJhdWX9LJYSJ6t7XOK3ZEaWoQZlk,992 -python_multipart/multipart.py,sha256=pk3o3eB3KXbNxzOBxbEjCdz-1ESEZIMXVIfl12grG-o,76427 -python_multipart/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 diff --git a/venv/lib/python3.12/site-packages/redis-6.4.0.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/python_multipart-0.0.6.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/redis-6.4.0.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/python_multipart-0.0.6.dist-info/INSTALLER diff --git a/venv/lib/python3.12/site-packages/python_multipart-0.0.6.dist-info/METADATA b/venv/lib/python3.12/site-packages/python_multipart-0.0.6.dist-info/METADATA new file mode 100644 index 0000000..916367c --- /dev/null +++ b/venv/lib/python3.12/site-packages/python_multipart-0.0.6.dist-info/METADATA @@ -0,0 +1,69 @@ +Metadata-Version: 2.1 +Name: python-multipart +Version: 0.0.6 +Summary: A streaming multipart parser for Python +Project-URL: Homepage, https://github.com/andrew-d/python-multipart +Project-URL: Documentation, https://andrew-d.github.io/python-multipart/ +Project-URL: Changelog, https://github.com/andrew-d/python-multipart/tags +Project-URL: Source, https://github.com/andrew-d/python-multipart +Author-email: Andrew Dunham +License-Expression: Apache-2.0 +License-File: LICENSE.txt +Classifier: Development Status 
:: 5 - Production/Stable +Classifier: Environment :: Web Environment +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Topic :: Software Development :: Libraries :: Python Modules +Requires-Python: >=3.7 +Provides-Extra: dev +Requires-Dist: atomicwrites==1.2.1; extra == 'dev' +Requires-Dist: attrs==19.2.0; extra == 'dev' +Requires-Dist: coverage==6.5.0; extra == 'dev' +Requires-Dist: hatch; extra == 'dev' +Requires-Dist: invoke==1.7.3; extra == 'dev' +Requires-Dist: more-itertools==4.3.0; extra == 'dev' +Requires-Dist: pbr==4.3.0; extra == 'dev' +Requires-Dist: pluggy==1.0.0; extra == 'dev' +Requires-Dist: py==1.11.0; extra == 'dev' +Requires-Dist: pytest-cov==4.0.0; extra == 'dev' +Requires-Dist: pytest-timeout==2.1.0; extra == 'dev' +Requires-Dist: pytest==7.2.0; extra == 'dev' +Requires-Dist: pyyaml==5.1; extra == 'dev' +Description-Content-Type: text/x-rst + +================== + Python-Multipart +================== + +.. image:: https://github.com/andrew-d/python-multipart/actions/workflows/test.yaml/badge.svg + :target: https://github.com/andrew-d/python-multipart/actions + + +python-multipart is an Apache2 licensed streaming multipart parser for Python. +Test coverage is currently 100%. +Documentation is available `here`_. + +.. _here: https://andrew-d.github.io/python-multipart/ + +Why? +---- + +Because streaming uploads are awesome for large files. + +How to Test +----------- + +If you want to test: + +.. 
code-block:: bash + + $ pip install .[dev] + $ inv test diff --git a/venv/lib/python3.12/site-packages/python_multipart-0.0.6.dist-info/RECORD b/venv/lib/python3.12/site-packages/python_multipart-0.0.6.dist-info/RECORD new file mode 100644 index 0000000..29b2be6 --- /dev/null +++ b/venv/lib/python3.12/site-packages/python_multipart-0.0.6.dist-info/RECORD @@ -0,0 +1,62 @@ +multipart/__init__.py,sha256=EaZd7hXXXNz5RWfzZ4lr-wKWXC4anMNWE7u4tPXtWr0,335 +multipart/__pycache__/__init__.cpython-312.pyc,, +multipart/__pycache__/decoders.cpython-312.pyc,, +multipart/__pycache__/exceptions.cpython-312.pyc,, +multipart/__pycache__/multipart.cpython-312.pyc,, +multipart/decoders.py,sha256=6LeCVARmDrQgmMsaul1WUIf79Q-mLE9swhGxumQe_98,6107 +multipart/exceptions.py,sha256=yDZ9pqq3Y9ZMCvj2TkAvOcNdMjFHjLnHl4luFnzt750,1410 +multipart/multipart.py,sha256=ZRc1beZCgCIXkYe0Xwxh_g4nFdrp3eEid4XODYIfqgQ,71230 +multipart/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +multipart/tests/__pycache__/__init__.cpython-312.pyc,, +multipart/tests/__pycache__/compat.cpython-312.pyc,, +multipart/tests/__pycache__/test_multipart.cpython-312.pyc,, +multipart/tests/compat.py,sha256=3aowcimO1SYU6WqS3GlUJ3jmkgLH63e8AsUPjlta1xU,4266 +multipart/tests/test_data/http/CR_in_header.http,sha256=XEimN_BgEqQXCqK463bMgD9PKIQeLrQhWt2M3vNr9cE,149 +multipart/tests/test_data/http/CR_in_header.yaml,sha256=OEzE2PqK78fi9kjM23YOu4xM0zQ_LRwSiwqFNAmku50,73 +multipart/tests/test_data/http/CR_in_header_value.http,sha256=pf4sP-l4_hzZ8Kr51gUE6CFcCifuWSZ10-vnx6mtXDg,149 +multipart/tests/test_data/http/CR_in_header_value.yaml,sha256=WjqJNYL-cUH2n9k-Xdy1YDvSfDqqXxsiinBDn3HTUu4,73 +multipart/tests/test_data/http/almost_match_boundary.http,sha256=jIsp1M6BHQIHF9o965z3Pt8TFncVvaBj5N43hprRpBM,264 +multipart/tests/test_data/http/almost_match_boundary.yaml,sha256=Hr7WZBwZrbf4vjurjRzGGeY9tFVJLRRmV1rEFXop-6s,300 +multipart/tests/test_data/http/almost_match_boundary_without_CR.http,sha256=KviMqo_FUy1N1-b-YUfyWhs5PmN6_fU7qhMYFTGnUhI,132 +multipart/tests/test_data/http/almost_match_boundary_without_CR.yaml,sha256=HjlUni-nuX3bG2-3FILo4GLBpLD4DImQ48VPlfnfIWY,167 +multipart/tests/test_data/http/almost_match_boundary_without_LF.http,sha256=KylmJ0O-RfnUnXbjVhwJpzHsWqNTPJn29_wfsvrG7AM,133 +multipart/tests/test_data/http/almost_match_boundary_without_LF.yaml,sha256=tkzz_kOFZtkarmMnTen355nm8McPwbmPmWGMxUUBSzU,171 +multipart/tests/test_data/http/almost_match_boundary_without_final_hyphen.http,sha256=L6bzRistD4X5TTd1zBtfR6gM4EQL77_iBI_Pgaw4ufw,133 +multipart/tests/test_data/http/almost_match_boundary_without_final_hyphen.yaml,sha256=cFKxwFMYTo9PKRb04Iai__mY9KG29IPkSm3p80DgEZw,171 +multipart/tests/test_data/http/bad_end_of_headers.http,sha256=ucEDylTCg1_hdEVkIc-1k8ZQ-CBIf5uXfDKbSBsSaF0,149 +multipart/tests/test_data/http/bad_end_of_headers.yaml,sha256=1UHERY2D7tp0HEUl5xD4SiotP2skETmBOF5EjcG2HTw,73 +multipart/tests/test_data/http/bad_header_char.http,sha256=zTqXFNQ9yrbc82vubPg95T4edg1Ueh2xadlVD2lO51A,149 +multipart/tests/test_data/http/bad_header_char.yaml,sha256=9ykVsASnvYvX51qtkCJqhgegeN-hoSU40MsYQvqeVNo,73 +multipart/tests/test_data/http/bad_initial_boundary.http,sha256=IGFSkpmw21XfAXr0xOHwj0vnhxyj-uCWVjcljo68LLo,149 +multipart/tests/test_data/http/bad_initial_boundary.yaml,sha256=eBSbue0BYDYhYtKdBCnm1LGq0O_fOMwV6ZoLpZFDFM4,72 +multipart/tests/test_data/http/base64_encoding.http,sha256=fDbr4BgLdNS8kYiTO7g4HxB81hvmiD2sRUCAoijfRx0,173 +multipart/tests/test_data/http/base64_encoding.yaml,sha256=cz2KxZxoi81MiXRh7DmJQOWcdqQH5ahkrJydGYv4hpU,125 
+multipart/tests/test_data/http/empty_header.http,sha256=-wSHHSLu1D2wfdC8Zcaw5TX_USTvWz56CANpsceOZYQ,130 +multipart/tests/test_data/http/empty_header.yaml,sha256=4xdVCYJ-l88HMXkMLNkSQoLNgURoGcKzR1AclPLpkOc,73 +multipart/tests/test_data/http/multiple_fields.http,sha256=6p93ls_B7bk8mXPYhsrFwvktSX8CuRdUH4vn-EZBaRM,242 +multipart/tests/test_data/http/multiple_fields.yaml,sha256=mePM5DVfAzty7QNEEyMu2qrFI28TbG9yWRvWFpWj7Jo,197 +multipart/tests/test_data/http/multiple_files.http,sha256=EtmagVBVpsFGnCqlwfKgswQfU8lGa3QNkP6GVJBa5A0,348 +multipart/tests/test_data/http/multiple_files.yaml,sha256=QO9JMgTvkL2EmIWAl8LcbDrkfNmDk0eA5SOk3gFuFWE,260 +multipart/tests/test_data/http/quoted_printable_encoding.http,sha256=--yYceg17SmqIJsazw-SFChdxeTAq8zV4lzPVM_QMrM,180 +multipart/tests/test_data/http/quoted_printable_encoding.yaml,sha256=G_L6lnP-e4uHfGpYQFopxDdpbd_EbxL2oY8N910BTOI,127 +multipart/tests/test_data/http/single_field.http,sha256=JjdSwFiM0mG07HYzBCcjzeqgqAA9glx-VcRUjkOh8cA,149 +multipart/tests/test_data/http/single_field.yaml,sha256=HMXd14-m9sKBvTsnzWOaG12_3wve5SoXeUISF93wlRc,139 +multipart/tests/test_data/http/single_field_blocks.http,sha256=4laZAIbFmxERZtgPWzuOihvEhLWD1NGTSdqZ6Ra58Ns,115 +multipart/tests/test_data/http/single_field_blocks.yaml,sha256=6mKvHtmiXh6OxoibJsx5pUreIMyQyPb_DWy7GEG9BX8,147 +multipart/tests/test_data/http/single_field_longer.http,sha256=BTBt1MsUaxuHauu-mljb3lU-8Z2dpjRN_lkZW4pkDXA,262 +multipart/tests/test_data/http/single_field_longer.yaml,sha256=aENhQPtHaTPIvgJbdiDHvcOtcthEEUHCQIEfLj0aalY,293 +multipart/tests/test_data/http/single_field_single_file.http,sha256=G4dV0iCSjvEk5DSJ1VXWy6R8Hon3-WOExep41nPWVeQ,192 +multipart/tests/test_data/http/single_field_single_file.yaml,sha256=QO9gqdXQsoizLji9r8kdlPWHJB5vO7wszqP1fHvsNV8,189 +multipart/tests/test_data/http/single_field_with_leading_newlines.http,sha256=YfNEUdZxbi4bBGTU4T4WSQZ6QJDJlcLZUczYzGU5Jaw,153 +multipart/tests/test_data/http/single_field_with_leading_newlines.yaml,sha256=HMXd14-m9sKBvTsnzWOaG12_3wve5SoXeUISF93wlRc,139 +multipart/tests/test_data/http/single_file.http,sha256=axRB0Keb4uhAfHxt7Na1x9-PQHCiiKK8s38a2GG860E,202 +multipart/tests/test_data/http/single_file.yaml,sha256=eUKyGkNTDrXdGni4EyEDbxDBTfAKsstVQ5O5SWghYTc,170 +multipart/tests/test_data/http/utf8_filename.http,sha256=w_Ryf4hC_KJo7v-a18dJFECqm21nzA5Z18dsGyu6zjA,208 +multipart/tests/test_data/http/utf8_filename.yaml,sha256=KpDc4e-yYp_JUXa-S5lp591tzoEybgywtGian0kQFPc,177 +multipart/tests/test_multipart.py,sha256=VrxoOtXO4NWpT1OJqo7FWWIybnxGReumIWCR-FDIHCk,38988 +python_multipart-0.0.6.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +python_multipart-0.0.6.dist-info/METADATA,sha256=J4WQf99XHSSg_EDG7fGgJGotS_Hp7ViCtpY4rQ2OgyM,2459 +python_multipart-0.0.6.dist-info/RECORD,, +python_multipart-0.0.6.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +python_multipart-0.0.6.dist-info/WHEEL,sha256=Fd6mP6ydyRguakwUJ05oBE7fh2IPxgtDN9IwHJ9OqJQ,87 +python_multipart-0.0.6.dist-info/licenses/LICENSE.txt,sha256=qOgzF2zWF9rwC51tOfoVyo7evG0WQwec0vSJPAwom-I,556 diff --git a/venv/lib/python3.12/site-packages/redis-6.4.0.dist-info/REQUESTED b/venv/lib/python3.12/site-packages/python_multipart-0.0.6.dist-info/REQUESTED similarity index 100% rename from venv/lib/python3.12/site-packages/redis-6.4.0.dist-info/REQUESTED rename to venv/lib/python3.12/site-packages/python_multipart-0.0.6.dist-info/REQUESTED diff --git a/venv/lib/python3.12/site-packages/python_multipart-0.0.6.dist-info/WHEEL 
b/venv/lib/python3.12/site-packages/python_multipart-0.0.6.dist-info/WHEEL new file mode 100644 index 0000000..9d72767 --- /dev/null +++ b/venv/lib/python3.12/site-packages/python_multipart-0.0.6.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: hatchling 1.13.0 +Root-Is-Purelib: true +Tag: py3-none-any diff --git a/venv/lib/python3.12/site-packages/python_multipart-0.0.20.dist-info/licenses/LICENSE.txt b/venv/lib/python3.12/site-packages/python_multipart-0.0.6.dist-info/licenses/LICENSE.txt similarity index 100% rename from venv/lib/python3.12/site-packages/python_multipart-0.0.20.dist-info/licenses/LICENSE.txt rename to venv/lib/python3.12/site-packages/python_multipart-0.0.6.dist-info/licenses/LICENSE.txt diff --git a/venv/lib/python3.12/site-packages/python_multipart/__init__.py b/venv/lib/python3.12/site-packages/python_multipart/__init__.py deleted file mode 100644 index e426526..0000000 --- a/venv/lib/python3.12/site-packages/python_multipart/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# This is the canonical package information. -__author__ = "Andrew Dunham" -__license__ = "Apache" -__copyright__ = "Copyright (c) 2012-2013, Andrew Dunham" -__version__ = "0.0.20" - -from .multipart import ( - BaseParser, - FormParser, - MultipartParser, - OctetStreamParser, - QuerystringParser, - create_form_parser, - parse_form, -) - -__all__ = ( - "BaseParser", - "FormParser", - "MultipartParser", - "OctetStreamParser", - "QuerystringParser", - "create_form_parser", - "parse_form", -) diff --git a/venv/lib/python3.12/site-packages/python_multipart/decoders.py b/venv/lib/python3.12/site-packages/python_multipart/decoders.py deleted file mode 100644 index 82b56a1..0000000 --- a/venv/lib/python3.12/site-packages/python_multipart/decoders.py +++ /dev/null @@ -1,185 +0,0 @@ -import base64 -import binascii -from typing import TYPE_CHECKING - -from .exceptions import DecodeError - -if TYPE_CHECKING: # pragma: no cover - from typing import Protocol, TypeVar - - _T_contra = TypeVar("_T_contra", contravariant=True) - - class SupportsWrite(Protocol[_T_contra]): - def write(self, __b: _T_contra) -> object: ... - - # No way to specify optional methods. See - # https://github.com/python/typing/issues/601 - # close() [Optional] - # finalize() [Optional] - - -class Base64Decoder: - """This object provides an interface to decode a stream of Base64 data. It - is instantiated with an "underlying object", and whenever a write() - operation is performed, it will decode the incoming data as Base64, and - call write() on the underlying object. This is primarily used for decoding - form data encoded as Base64, but can be used for other purposes:: - - from python_multipart.decoders import Base64Decoder - fd = open("notb64.txt", "wb") - decoder = Base64Decoder(fd) - try: - decoder.write("Zm9vYmFy") # "foobar" in Base64 - decoder.finalize() - finally: - decoder.close() - - # The contents of "notb64.txt" should be "foobar". - - This object will also pass all finalize() and close() calls to the - underlying object, if the underlying object supports them. - - Note that this class maintains a cache of base64 chunks, so that a write of - arbitrary size can be performed. You must call :meth:`finalize` on this - object after all writes are completed to ensure that all data is flushed - to the underlying object. 
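The deleted docstring above spells out the streaming contract of the vendored Base64Decoder: base64 only decodes in whole 4-byte groups, so any trailing partial group has to be cached between write() calls and finalize() must flag leftovers. A minimal standalone sketch of that buffering idea, using only the standard library (ChunkedB64Decoder is a hypothetical name for illustration, not the vendored class):

import base64
import io


class ChunkedB64Decoder:
    """Decode base64 that arrives in arbitrary-sized chunks (sketch only)."""

    def __init__(self, underlying):
        self.cache = b""            # partial 4-byte group carried between writes
        self.underlying = underlying

    def write(self, data: bytes) -> int:
        data = self.cache + data
        usable = (len(data) // 4) * 4      # largest decodable prefix
        if usable:
            self.underlying.write(base64.b64decode(data[:usable]))
        self.cache = data[usable:]         # keep the incomplete group for later
        return len(data)

    def finalize(self) -> None:
        if self.cache:
            raise ValueError("incomplete base64 group left in cache at finalize()")


out = io.BytesIO()
dec = ChunkedB64Decoder(out)
dec.write(b"Zm9vYm")   # decodes b"Zm9v" -> b"foo", caches b"Ym"
dec.write(b"Fy")       # cache + chunk = b"YmFy" -> b"bar"
dec.finalize()
assert out.getvalue() == b"foobar"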
- - :param underlying: the underlying object to pass writes to - """ - - def __init__(self, underlying: "SupportsWrite[bytes]") -> None: - self.cache = bytearray() - self.underlying = underlying - - def write(self, data: bytes) -> int: - """Takes any input data provided, decodes it as base64, and passes it - on to the underlying object. If the data provided is invalid base64 - data, then this method will raise - a :class:`python_multipart.exceptions.DecodeError` - - :param data: base64 data to decode - """ - - # Prepend any cache info to our data. - if len(self.cache) > 0: - data = self.cache + data - - # Slice off a string that's a multiple of 4. - decode_len = (len(data) // 4) * 4 - val = data[:decode_len] - - # Decode and write, if we have any. - if len(val) > 0: - try: - decoded = base64.b64decode(val) - except binascii.Error: - raise DecodeError("There was an error raised while decoding base64-encoded data.") - - self.underlying.write(decoded) - - # Get the remaining bytes and save in our cache. - remaining_len = len(data) % 4 - if remaining_len > 0: - self.cache[:] = data[-remaining_len:] - else: - self.cache[:] = b"" - - # Return the length of the data to indicate no error. - return len(data) - - def close(self) -> None: - """Close this decoder. If the underlying object has a `close()` - method, this function will call it. - """ - if hasattr(self.underlying, "close"): - self.underlying.close() - - def finalize(self) -> None: - """Finalize this object. This should be called when no more data - should be written to the stream. This function can raise a - :class:`python_multipart.exceptions.DecodeError` if there is some remaining - data in the cache. - - If the underlying object has a `finalize()` method, this function will - call it. - """ - if len(self.cache) > 0: - raise DecodeError( - "There are %d bytes remaining in the Base64Decoder cache when finalize() is called" % len(self.cache) - ) - - if hasattr(self.underlying, "finalize"): - self.underlying.finalize() - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(underlying={self.underlying!r})" - - -class QuotedPrintableDecoder: - """This object provides an interface to decode a stream of quoted-printable - data. It is instantiated with an "underlying object", in the same manner - as the :class:`python_multipart.decoders.Base64Decoder` class. This class behaves - in exactly the same way, including maintaining a cache of quoted-printable - chunks. - - :param underlying: the underlying object to pass writes to - """ - - def __init__(self, underlying: "SupportsWrite[bytes]") -> None: - self.cache = b"" - self.underlying = underlying - - def write(self, data: bytes) -> int: - """Takes any input data provided, decodes it as quoted-printable, and - passes it on to the underlying object. - - :param data: quoted-printable data to decode - """ - # Prepend any cache info to our data. - if len(self.cache) > 0: - data = self.cache + data - - # If the last 2 characters have an '=' sign in it, then we won't be - # able to decode the encoded value and we'll need to save it for the - # next decoding step. - if data[-2:].find(b"=") != -1: - enc, rest = data[:-2], data[-2:] - else: - enc = data - rest = b"" - - # Encode and write, if we have data. - if len(enc) > 0: - self.underlying.write(binascii.a2b_qp(enc)) - - # Save remaining in cache. - self.cache = rest - return len(data) - - def close(self) -> None: - """Close this decoder. If the underlying object has a `close()` - method, this function will call it. 
- """ - if hasattr(self.underlying, "close"): - self.underlying.close() - - def finalize(self) -> None: - """Finalize this object. This should be called when no more data - should be written to the stream. This function will not raise any - exceptions, but it may write more data to the underlying object if - there is data remaining in the cache. - - If the underlying object has a `finalize()` method, this function will - call it. - """ - # If we have a cache, write and then remove it. - if len(self.cache) > 0: # pragma: no cover - self.underlying.write(binascii.a2b_qp(self.cache)) - self.cache = b"" - - # Finalize our underlying stream. - if hasattr(self.underlying, "finalize"): - self.underlying.finalize() - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(underlying={self.underlying!r})" diff --git a/venv/lib/python3.12/site-packages/python_multipart/exceptions.py b/venv/lib/python3.12/site-packages/python_multipart/exceptions.py deleted file mode 100644 index cc3671f..0000000 --- a/venv/lib/python3.12/site-packages/python_multipart/exceptions.py +++ /dev/null @@ -1,34 +0,0 @@ -class FormParserError(ValueError): - """Base error class for our form parser.""" - - -class ParseError(FormParserError): - """This exception (or a subclass) is raised when there is an error while - parsing something. - """ - - #: This is the offset in the input data chunk (*NOT* the overall stream) in - #: which the parse error occurred. It will be -1 if not specified. - offset = -1 - - -class MultipartParseError(ParseError): - """This is a specific error that is raised when the MultipartParser detects - an error while parsing. - """ - - -class QuerystringParseError(ParseError): - """This is a specific error that is raised when the QuerystringParser - detects an error while parsing. - """ - - -class DecodeError(ParseError): - """This exception is raised when there is a decoding error - for example - with the Base64Decoder or QuotedPrintableDecoder. - """ - - -class FileError(FormParserError, OSError): - """Exception class for problems with the File class.""" diff --git a/venv/lib/python3.12/site-packages/python_multipart/multipart.py b/venv/lib/python3.12/site-packages/python_multipart/multipart.py deleted file mode 100644 index f26a815..0000000 --- a/venv/lib/python3.12/site-packages/python_multipart/multipart.py +++ /dev/null @@ -1,1873 +0,0 @@ -from __future__ import annotations - -import logging -import os -import shutil -import sys -import tempfile -from email.message import Message -from enum import IntEnum -from io import BufferedRandom, BytesIO -from numbers import Number -from typing import TYPE_CHECKING, cast - -from .decoders import Base64Decoder, QuotedPrintableDecoder -from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError - -if TYPE_CHECKING: # pragma: no cover - from typing import Any, Callable, Literal, Protocol, TypedDict - - from typing_extensions import TypeAlias - - class SupportsRead(Protocol): - def read(self, __n: int) -> bytes: ... 
- - class QuerystringCallbacks(TypedDict, total=False): - on_field_start: Callable[[], None] - on_field_name: Callable[[bytes, int, int], None] - on_field_data: Callable[[bytes, int, int], None] - on_field_end: Callable[[], None] - on_end: Callable[[], None] - - class OctetStreamCallbacks(TypedDict, total=False): - on_start: Callable[[], None] - on_data: Callable[[bytes, int, int], None] - on_end: Callable[[], None] - - class MultipartCallbacks(TypedDict, total=False): - on_part_begin: Callable[[], None] - on_part_data: Callable[[bytes, int, int], None] - on_part_end: Callable[[], None] - on_header_begin: Callable[[], None] - on_header_field: Callable[[bytes, int, int], None] - on_header_value: Callable[[bytes, int, int], None] - on_header_end: Callable[[], None] - on_headers_finished: Callable[[], None] - on_end: Callable[[], None] - - class FormParserConfig(TypedDict): - UPLOAD_DIR: str | None - UPLOAD_KEEP_FILENAME: bool - UPLOAD_KEEP_EXTENSIONS: bool - UPLOAD_ERROR_ON_BAD_CTE: bool - MAX_MEMORY_FILE_SIZE: int - MAX_BODY_SIZE: float - - class FileConfig(TypedDict, total=False): - UPLOAD_DIR: str | bytes | None - UPLOAD_DELETE_TMP: bool - UPLOAD_KEEP_FILENAME: bool - UPLOAD_KEEP_EXTENSIONS: bool - MAX_MEMORY_FILE_SIZE: int - - class _FormProtocol(Protocol): - def write(self, data: bytes) -> int: ... - - def finalize(self) -> None: ... - - def close(self) -> None: ... - - class FieldProtocol(_FormProtocol, Protocol): - def __init__(self, name: bytes | None) -> None: ... - - def set_none(self) -> None: ... - - class FileProtocol(_FormProtocol, Protocol): - def __init__(self, file_name: bytes | None, field_name: bytes | None, config: FileConfig) -> None: ... - - OnFieldCallback = Callable[[FieldProtocol], None] - OnFileCallback = Callable[[FileProtocol], None] - - CallbackName: TypeAlias = Literal[ - "start", - "data", - "end", - "field_start", - "field_name", - "field_data", - "field_end", - "part_begin", - "part_data", - "part_end", - "header_begin", - "header_field", - "header_value", - "header_end", - "headers_finished", - ] - -# Unique missing object. -_missing = object() - - -class QuerystringState(IntEnum): - """Querystring parser states. - - These are used to keep track of the state of the parser, and are used to determine - what to do when new data is encountered. - """ - - BEFORE_FIELD = 0 - FIELD_NAME = 1 - FIELD_DATA = 2 - - -class MultipartState(IntEnum): - """Multipart parser states. - - These are used to keep track of the state of the parser, and are used to determine - what to do when new data is encountered. - """ - - START = 0 - START_BOUNDARY = 1 - HEADER_FIELD_START = 2 - HEADER_FIELD = 3 - HEADER_VALUE_START = 4 - HEADER_VALUE = 5 - HEADER_VALUE_ALMOST_DONE = 6 - HEADERS_ALMOST_DONE = 7 - PART_DATA_START = 8 - PART_DATA = 9 - PART_DATA_END = 10 - END_BOUNDARY = 11 - END = 12 - - -# Flags for the multipart parser. -FLAG_PART_BOUNDARY = 1 -FLAG_LAST_BOUNDARY = 2 - -# Get constants. Since iterating over a str on Python 2 gives you a 1-length -# string, but iterating over a bytes object on Python 3 gives you an integer, -# we need to save these constants. -CR = b"\r"[0] -LF = b"\n"[0] -COLON = b":"[0] -SPACE = b" "[0] -HYPHEN = b"-"[0] -AMPERSAND = b"&"[0] -SEMICOLON = b";"[0] -LOWER_A = b"a"[0] -LOWER_Z = b"z"[0] -NULL = b"\x00"[0] - -# fmt: off -# Mask for ASCII characters that can be http tokens. 
-# Per RFC7230 - 3.2.6, this is all alpha-numeric characters -# and these: !#$%&'*+-.^_`|~ -TOKEN_CHARS_SET = frozenset( - b"ABCDEFGHIJKLMNOPQRSTUVWXYZ" - b"abcdefghijklmnopqrstuvwxyz" - b"0123456789" - b"!#$%&'*+-.^_`|~") -# fmt: on - - -def parse_options_header(value: str | bytes | None) -> tuple[bytes, dict[bytes, bytes]]: - """Parses a Content-Type header into a value in the following format: (content_type, {parameters}).""" - # Uses email.message.Message to parse the header as described in PEP 594. - # Ref: https://peps.python.org/pep-0594/#cgi - if not value: - return (b"", {}) - - # If we are passed bytes, we assume that it conforms to WSGI, encoding in latin-1. - if isinstance(value, bytes): # pragma: no cover - value = value.decode("latin-1") - - # For types - assert isinstance(value, str), "Value should be a string by now" - - # If we have no options, return the string as-is. - if ";" not in value: - return (value.lower().strip().encode("latin-1"), {}) - - # Split at the first semicolon, to get our value and then options. - # ctype, rest = value.split(b';', 1) - message = Message() - message["content-type"] = value - params = message.get_params() - # If there were no parameters, this would have already returned above - assert params, "At least the content type value should be present" - ctype = params.pop(0)[0].encode("latin-1") - options: dict[bytes, bytes] = {} - for param in params: - key, value = param - # If the value returned from get_params() is a 3-tuple, the last - # element corresponds to the value. - # See: https://docs.python.org/3/library/email.compat32-message.html - if isinstance(value, tuple): - value = value[-1] - # If the value is a filename, we need to fix a bug on IE6 that sends - # the full file path instead of the filename. - if key == "filename": - if value[1:3] == ":\\" or value[:2] == "\\\\": - value = value.split("\\")[-1] - options[key.encode("latin-1")] = value.encode("latin-1") - return ctype, options - - -class Field: - """A Field object represents a (parsed) form field. It represents a single - field with a corresponding name and value. - - The name that a :class:`Field` will be instantiated with is the same name - that would be found in the following HTML:: - - - - This class defines two methods, :meth:`on_data` and :meth:`on_end`, that - will be called when data is written to the Field, and when the Field is - finalized, respectively. - - Args: - name: The name of the form field. - """ - - def __init__(self, name: bytes | None) -> None: - self._name = name - self._value: list[bytes] = [] - - # We cache the joined version of _value for speed. - self._cache = _missing - - @classmethod - def from_value(cls, name: bytes, value: bytes | None) -> Field: - """Create an instance of a :class:`Field`, and set the corresponding - value - either None or an actual value. This method will also - finalize the Field itself. - - Args: - name: the name of the form field. - value: the value of the form field - either a bytestring or None. - - Returns: - A new instance of a [`Field`][python_multipart.Field]. - """ - - f = cls(name) - if value is None: - f.set_none() - else: - f.write(value) - f.finalize() - return f - - def write(self, data: bytes) -> int: - """Write some data into the form field. - - Args: - data: The data to write to the field. - - Returns: - The number of bytes written. - """ - return self.on_data(data) - - def on_data(self, data: bytes) -> int: - """This method is a callback that will be called whenever data is - written to the Field. 
- - Args: - data: The data to write to the field. - - Returns: - The number of bytes written. - """ - self._value.append(data) - self._cache = _missing - return len(data) - - def on_end(self) -> None: - """This method is called whenever the Field is finalized.""" - if self._cache is _missing: - self._cache = b"".join(self._value) - - def finalize(self) -> None: - """Finalize the form field.""" - self.on_end() - - def close(self) -> None: - """Close the Field object. This will free any underlying cache.""" - # Free our value array. - if self._cache is _missing: - self._cache = b"".join(self._value) - - del self._value - - def set_none(self) -> None: - """Some fields in a querystring can possibly have a value of None - for - example, the string "foo&bar=&baz=asdf" will have a field with the - name "foo" and value None, one with name "bar" and value "", and one - with name "baz" and value "asdf". Since the write() interface doesn't - support writing None, this function will set the field value to None. - """ - self._cache = None - - @property - def field_name(self) -> bytes | None: - """This property returns the name of the field.""" - return self._name - - @property - def value(self) -> bytes | None: - """This property returns the value of the form field.""" - if self._cache is _missing: - self._cache = b"".join(self._value) - - assert isinstance(self._cache, bytes) or self._cache is None - return self._cache - - def __eq__(self, other: object) -> bool: - if isinstance(other, Field): - return self.field_name == other.field_name and self.value == other.value - else: - return NotImplemented - - def __repr__(self) -> str: - if self.value is not None and len(self.value) > 97: - # We get the repr, and then insert three dots before the final - # quote. - v = repr(self.value[:97])[:-1] + "...'" - else: - v = repr(self.value) - - return "{}(field_name={!r}, value={})".format(self.__class__.__name__, self.field_name, v) - - -class File: - """This class represents an uploaded file. It handles writing file data to - either an in-memory file or a temporary file on-disk, if the optional - threshold is passed. - - There are some options that can be passed to the File to change behavior - of the class. Valid options are as follows: - - | Name | Type | Default | Description | - |-----------------------|-------|---------|-------------| - | UPLOAD_DIR | `str` | None | The directory to store uploaded files in. If this is None, a temporary file will be created in the system's standard location. | - | UPLOAD_DELETE_TMP | `bool`| True | Delete automatically created TMP file | - | UPLOAD_KEEP_FILENAME | `bool`| False | Whether or not to keep the filename of the uploaded file. If True, then the filename will be converted to a safe representation (e.g. by removing any invalid path segments), and then saved with the same name). Otherwise, a temporary name will be used. | - | UPLOAD_KEEP_EXTENSIONS| `bool`| False | Whether or not to keep the uploaded file's extension. If False, the file will be saved with the default temporary extension (usually ".tmp"). Otherwise, the file's extension will be maintained. Note that this will properly combine with the UPLOAD_KEEP_FILENAME setting. | - | MAX_MEMORY_FILE_SIZE | `int` | 1 MiB | The maximum number of bytes of a File to keep in memory. By default, the contents of a File are kept into memory until a certain limit is reached, after which the contents of the File are written to a temporary file. 
This behavior can be disabled by setting this value to an appropriately large value (or, for example, infinity, such as `float('inf')`. | - - Args: - file_name: The name of the file that this [`File`][python_multipart.File] represents. - field_name: The name of the form field that this file was uploaded with. This can be None, if, for example, - the file was uploaded with Content-Type application/octet-stream. - config: The configuration for this File. See above for valid configuration keys and their corresponding values. - """ # noqa: E501 - - def __init__(self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig = {}) -> None: - # Save configuration, set other variables default. - self.logger = logging.getLogger(__name__) - self._config = config - self._in_memory = True - self._bytes_written = 0 - self._fileobj: BytesIO | BufferedRandom = BytesIO() - - # Save the provided field/file name. - self._field_name = field_name - self._file_name = file_name - - # Our actual file name is None by default, since, depending on our - # config, we may not actually use the provided name. - self._actual_file_name: bytes | None = None - - # Split the extension from the filename. - if file_name is not None: - base, ext = os.path.splitext(file_name) - self._file_base = base - self._ext = ext - - @property - def field_name(self) -> bytes | None: - """The form field associated with this file. May be None if there isn't - one, for example when we have an application/octet-stream upload. - """ - return self._field_name - - @property - def file_name(self) -> bytes | None: - """The file name given in the upload request.""" - return self._file_name - - @property - def actual_file_name(self) -> bytes | None: - """The file name that this file is saved as. Will be None if it's not - currently saved on disk. - """ - return self._actual_file_name - - @property - def file_object(self) -> BytesIO | BufferedRandom: - """The file object that we're currently writing to. Note that this - will either be an instance of a :class:`io.BytesIO`, or a regular file - object. - """ - return self._fileobj - - @property - def size(self) -> int: - """The total size of this file, counted as the number of bytes that - currently have been written to the file. - """ - return self._bytes_written - - @property - def in_memory(self) -> bool: - """A boolean representing whether or not this file object is currently - stored in-memory or on-disk. - """ - return self._in_memory - - def flush_to_disk(self) -> None: - """If the file is already on-disk, do nothing. Otherwise, copy from - the in-memory buffer to a disk file, and then reassign our internal - file object to this new disk file. - - Note that if you attempt to flush a file that is already on-disk, a - warning will be logged to this module's logger. - """ - if not self._in_memory: - self.logger.warning("Trying to flush to disk when we're not in memory") - return - - # Go back to the start of our file. - self._fileobj.seek(0) - - # Open a new file. - new_file = self._get_disk_file() - - # Copy the file objects. - shutil.copyfileobj(self._fileobj, new_file) - - # Seek to the new position in our new file. - new_file.seek(self._bytes_written) - - # Reassign the fileobject. - old_fileobj = self._fileobj - self._fileobj = new_file - - # We're no longer in memory. - self._in_memory = False - - # Close the old file object. 
- old_fileobj.close() - - def _get_disk_file(self) -> BufferedRandom: - """This function is responsible for getting a file object on-disk for us.""" - self.logger.info("Opening a file on disk") - - file_dir = self._config.get("UPLOAD_DIR") - keep_filename = self._config.get("UPLOAD_KEEP_FILENAME", False) - keep_extensions = self._config.get("UPLOAD_KEEP_EXTENSIONS", False) - delete_tmp = self._config.get("UPLOAD_DELETE_TMP", True) - tmp_file: None | BufferedRandom = None - - # If we have a directory and are to keep the filename... - if file_dir is not None and keep_filename: - self.logger.info("Saving with filename in: %r", file_dir) - - # Build our filename. - # TODO: what happens if we don't have a filename? - fname = self._file_base + self._ext if keep_extensions else self._file_base - - path = os.path.join(file_dir, fname) # type: ignore[arg-type] - try: - self.logger.info("Opening file: %r", path) - tmp_file = open(path, "w+b") - except OSError: - tmp_file = None - - self.logger.exception("Error opening temporary file") - raise FileError("Error opening temporary file: %r" % path) - else: - # Build options array. - # Note that on Python 3, tempfile doesn't support byte names. We - # encode our paths using the default filesystem encoding. - suffix = self._ext.decode(sys.getfilesystemencoding()) if keep_extensions else None - - if file_dir is None: - dir = None - elif isinstance(file_dir, bytes): - dir = file_dir.decode(sys.getfilesystemencoding()) - else: - dir = file_dir # pragma: no cover - - # Create a temporary (named) file with the appropriate settings. - self.logger.info( - "Creating a temporary file with options: %r", {"suffix": suffix, "delete": delete_tmp, "dir": dir} - ) - try: - tmp_file = cast(BufferedRandom, tempfile.NamedTemporaryFile(suffix=suffix, delete=delete_tmp, dir=dir)) - except OSError: - self.logger.exception("Error creating named temporary file") - raise FileError("Error creating named temporary file") - - assert tmp_file is not None - # Encode filename as bytes. - if isinstance(tmp_file.name, str): - fname = tmp_file.name.encode(sys.getfilesystemencoding()) - else: - fname = cast(bytes, tmp_file.name) # pragma: no cover - - self._actual_file_name = fname - return tmp_file - - def write(self, data: bytes) -> int: - """Write some data to the File. - - :param data: a bytestring - """ - return self.on_data(data) - - def on_data(self, data: bytes) -> int: - """This method is a callback that will be called whenever data is - written to the File. - - Args: - data: The data to write to the file. - - Returns: - The number of bytes written. - """ - bwritten = self._fileobj.write(data) - - # If the bytes written isn't the same as the length, just return. - if bwritten != len(data): - self.logger.warning("bwritten != len(data) (%d != %d)", bwritten, len(data)) - return bwritten - - # Keep track of how many bytes we've written. - self._bytes_written += bwritten - - # If we're in-memory and are over our limit, we create a file. - max_memory_file_size = self._config.get("MAX_MEMORY_FILE_SIZE") - if self._in_memory and max_memory_file_size is not None and (self._bytes_written > max_memory_file_size): - self.logger.info("Flushing to disk") - self.flush_to_disk() - - # Return the number of bytes written. - return bwritten - - def on_end(self) -> None: - """This method is called whenever the Field is finalized.""" - # Flush the underlying file object - self._fileobj.flush() - - def finalize(self) -> None: - """Finalize the form file. 
This will not close the underlying file, - but simply signal that we are finished writing to the File. - """ - self.on_end() - - def close(self) -> None: - """Close the File object. This will actually close the underlying - file object (whether it's a :class:`io.BytesIO` or an actual file - object). - """ - self._fileobj.close() - - def __repr__(self) -> str: - return "{}(file_name={!r}, field_name={!r})".format(self.__class__.__name__, self.file_name, self.field_name) - - -class BaseParser: - """This class is the base class for all parsers. It contains the logic for - calling and adding callbacks. - - A callback can be one of two different forms. "Notification callbacks" are - callbacks that are called when something happens - for example, when a new - part of a multipart message is encountered by the parser. "Data callbacks" - are called when we get some sort of data - for example, part of the body of - a multipart chunk. Notification callbacks are called with no parameters, - whereas data callbacks are called with three, as follows:: - - data_callback(data, start, end) - - The "data" parameter is a bytestring (i.e. "foo" on Python 2, or b"foo" on - Python 3). "start" and "end" are integer indexes into the "data" string - that represent the data of interest. Thus, in a data callback, the slice - `data[start:end]` represents the data that the callback is "interested in". - The callback is not passed a copy of the data, since copying severely hurts - performance. - """ - - def __init__(self) -> None: - self.logger = logging.getLogger(__name__) - self.callbacks: QuerystringCallbacks | OctetStreamCallbacks | MultipartCallbacks = {} - - def callback( - self, name: CallbackName, data: bytes | None = None, start: int | None = None, end: int | None = None - ) -> None: - """This function calls a provided callback with some data. If the - callback is not set, will do nothing. - - Args: - name: The name of the callback to call (as a string). - data: Data to pass to the callback. If None, then it is assumed that the callback is a notification - callback, and no parameters are given. - end: An integer that is passed to the data callback. - start: An integer that is passed to the data callback. - """ - on_name = "on_" + name - func = self.callbacks.get(on_name) - if func is None: - return - func = cast("Callable[..., Any]", func) - # Depending on whether we're given a buffer... - if data is not None: - # Don't do anything if we have start == end. - if start is not None and start == end: - return - - self.logger.debug("Calling %s with data[%d:%d]", on_name, start, end) - func(data, start, end) - else: - self.logger.debug("Calling %s with no data", on_name) - func() - - def set_callback(self, name: CallbackName, new_func: Callable[..., Any] | None) -> None: - """Update the function for a callback. Removes from the callbacks dict - if new_func is None. - - :param name: The name of the callback to call (as a string). - - :param new_func: The new function for the callback. If None, then the - callback will be removed (with no error if it does not - exist). 
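The BaseParser docstring above distinguishes notification callbacks, which take no arguments, from data callbacks, which receive (data, start, end) and are expected to look only at the slice data[start:end] rather than take a copy. A small sketch of that slicing contract in isolation (the offsets are hard-coded here purely for illustration; a real parser computes them while scanning the chunk):

received = bytearray()


def on_data(data: bytes, start: int, end: int) -> None:
    # Only data[start:end] belongs to this event; the rest of the buffer
    # may hold boundary bytes or data destined for other callbacks.
    received.extend(data[start:end])


chunk = b"--boundary\r\nhello world\r\n--boundary--"
on_data(chunk, 12, 23)
assert bytes(received) == b"hello world"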
- """ - if new_func is None: - self.callbacks.pop("on_" + name, None) # type: ignore[misc] - else: - self.callbacks["on_" + name] = new_func # type: ignore[literal-required] - - def close(self) -> None: - pass # pragma: no cover - - def finalize(self) -> None: - pass # pragma: no cover - - def __repr__(self) -> str: - return "%s()" % self.__class__.__name__ - - -class OctetStreamParser(BaseParser): - """This parser parses an octet-stream request body and calls callbacks when - incoming data is received. Callbacks are as follows: - - | Callback Name | Parameters | Description | - |----------------|-----------------|-----------------------------------------------------| - | on_start | None | Called when the first data is parsed. | - | on_data | data, start, end| Called for each data chunk that is parsed. | - | on_end | None | Called when the parser is finished parsing all data.| - - Args: - callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser]. - max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded. - """ - - def __init__(self, callbacks: OctetStreamCallbacks = {}, max_size: float = float("inf")): - super().__init__() - self.callbacks = callbacks - self._started = False - - if not isinstance(max_size, Number) or max_size < 1: - raise ValueError("max_size must be a positive number, not %r" % max_size) - self.max_size: int | float = max_size - self._current_size = 0 - - def write(self, data: bytes) -> int: - """Write some data to the parser, which will perform size verification, - and then pass the data to the underlying callback. - - Args: - data: The data to write to the parser. - - Returns: - The number of bytes written. - """ - if not self._started: - self.callback("start") - self._started = True - - # Truncate data length. - data_len = len(data) - if (self._current_size + data_len) > self.max_size: - # We truncate the length of data that we are to process. - new_size = int(self.max_size - self._current_size) - self.logger.warning( - "Current size is %d (max %d), so truncating data length from %d to %d", - self._current_size, - self.max_size, - data_len, - new_size, - ) - data_len = new_size - - # Increment size, then callback, in case there's an exception. - self._current_size += data_len - self.callback("data", data, 0, data_len) - return data_len - - def finalize(self) -> None: - """Finalize this parser, which signals to that we are finished parsing, - and sends the on_end callback. - """ - self.callback("end") - - def __repr__(self) -> str: - return "%s()" % self.__class__.__name__ - - -class QuerystringParser(BaseParser): - """This is a streaming querystring parser. It will consume data, and call - the callbacks given when it has data. - - | Callback Name | Parameters | Description | - |----------------|-----------------|-----------------------------------------------------| - | on_field_start | None | Called when a new field is encountered. | - | on_field_name | data, start, end| Called when a portion of a field's name is encountered. | - | on_field_data | data, start, end| Called when a portion of a field's data is encountered. | - | on_field_end | None | Called when the end of a field is encountered. | - | on_end | None | Called when the parser is finished parsing all data.| - - Args: - callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser]. - strict_parsing: Whether or not to parse the body strictly. Defaults to False. 
If this is set to True, then the - behavior of the parser changes as the following: if a field has a value with an equal sign - (e.g. "foo=bar", or "foo="), it is always included. If a field has no equals sign (e.g. "...&name&..."), - it will be treated as an error if 'strict_parsing' is True, otherwise included. If an error is encountered, - then a [`QuerystringParseError`][python_multipart.exceptions.QuerystringParseError] will be raised. - max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded. - """ # noqa: E501 - - state: QuerystringState - - def __init__( - self, callbacks: QuerystringCallbacks = {}, strict_parsing: bool = False, max_size: float = float("inf") - ) -> None: - super().__init__() - self.state = QuerystringState.BEFORE_FIELD - self._found_sep = False - - self.callbacks = callbacks - - # Max-size stuff - if not isinstance(max_size, Number) or max_size < 1: - raise ValueError("max_size must be a positive number, not %r" % max_size) - self.max_size: int | float = max_size - self._current_size = 0 - - # Should parsing be strict? - self.strict_parsing = strict_parsing - - def write(self, data: bytes) -> int: - """Write some data to the parser, which will perform size verification, - parse into either a field name or value, and then pass the - corresponding data to the underlying callback. If an error is - encountered while parsing, a QuerystringParseError will be raised. The - "offset" attribute of the raised exception will be set to the offset in - the input data chunk (NOT the overall stream) that caused the error. - - Args: - data: The data to write to the parser. - - Returns: - The number of bytes written. - """ - # Handle sizing. - data_len = len(data) - if (self._current_size + data_len) > self.max_size: - # We truncate the length of data that we are to process. - new_size = int(self.max_size - self._current_size) - self.logger.warning( - "Current size is %d (max %d), so truncating data length from %d to %d", - self._current_size, - self.max_size, - data_len, - new_size, - ) - data_len = new_size - - l = 0 - try: - l = self._internal_write(data, data_len) - finally: - self._current_size += l - - return l - - def _internal_write(self, data: bytes, length: int) -> int: - state = self.state - strict_parsing = self.strict_parsing - found_sep = self._found_sep - - i = 0 - while i < length: - ch = data[i] - - # Depending on our state... - if state == QuerystringState.BEFORE_FIELD: - # If the 'found_sep' flag is set, we've already encountered - # and skipped a single separator. If so, we check our strict - # parsing flag and decide what to do. Otherwise, we haven't - # yet reached a separator, and thus, if we do, we need to skip - # it as it will be the boundary between fields that's supposed - # to be there. - if ch == AMPERSAND or ch == SEMICOLON: - if found_sep: - # If we're parsing strictly, we disallow blank chunks. - if strict_parsing: - e = QuerystringParseError("Skipping duplicate ampersand/semicolon at %d" % i) - e.offset = i - raise e - else: - self.logger.debug("Skipping duplicate ampersand/semicolon at %d", i) - else: - # This case is when we're skipping the (first) - # separator between fields, so we just set our flag - # and continue on. - found_sep = True - else: - # Emit a field-start event, and go to that state. Also, - # reset the "found_sep" flag, for the next time we get to - # this state. 
- self.callback("field_start") - i -= 1 - state = QuerystringState.FIELD_NAME - found_sep = False - - elif state == QuerystringState.FIELD_NAME: - # Try and find a separator - we ensure that, if we do, we only - # look for the equal sign before it. - sep_pos = data.find(b"&", i) - if sep_pos == -1: - sep_pos = data.find(b";", i) - - # See if we can find an equals sign in the remaining data. If - # so, we can immediately emit the field name and jump to the - # data state. - if sep_pos != -1: - equals_pos = data.find(b"=", i, sep_pos) - else: - equals_pos = data.find(b"=", i) - - if equals_pos != -1: - # Emit this name. - self.callback("field_name", data, i, equals_pos) - - # Jump i to this position. Note that it will then have 1 - # added to it below, which means the next iteration of this - # loop will inspect the character after the equals sign. - i = equals_pos - state = QuerystringState.FIELD_DATA - else: - # No equals sign found. - if not strict_parsing: - # See also comments in the QuerystringState.FIELD_DATA case below. - # If we found the separator, we emit the name and just - # end - there's no data callback at all (not even with - # a blank value). - if sep_pos != -1: - self.callback("field_name", data, i, sep_pos) - self.callback("field_end") - - i = sep_pos - 1 - state = QuerystringState.BEFORE_FIELD - else: - # Otherwise, no separator in this block, so the - # rest of this chunk must be a name. - self.callback("field_name", data, i, length) - i = length - - else: - # We're parsing strictly. If we find a separator, - # this is an error - we require an equals sign. - if sep_pos != -1: - e = QuerystringParseError( - "When strict_parsing is True, we require an " - "equals sign in all field chunks. Did not " - "find one in the chunk that starts at %d" % (i,) - ) - e.offset = i - raise e - - # No separator in the rest of this chunk, so it's just - # a field name. - self.callback("field_name", data, i, length) - i = length - - elif state == QuerystringState.FIELD_DATA: - # Try finding either an ampersand or a semicolon after this - # position. - sep_pos = data.find(b"&", i) - if sep_pos == -1: - sep_pos = data.find(b";", i) - - # If we found it, callback this bit as data and then go back - # to expecting to find a field. - if sep_pos != -1: - self.callback("field_data", data, i, sep_pos) - self.callback("field_end") - - # Note that we go to the separator, which brings us to the - # "before field" state. This allows us to properly emit - # "field_start" events only when we actually have data for - # a field of some sort. - i = sep_pos - 1 - state = QuerystringState.BEFORE_FIELD - - # Otherwise, emit the rest as data and finish. - else: - self.callback("field_data", data, i, length) - i = length - - else: # pragma: no cover (error case) - msg = "Reached an unknown state %d at %d" % (state, i) - self.logger.warning(msg) - e = QuerystringParseError(msg) - e.offset = i - raise e - - i += 1 - - self.state = state - self._found_sep = found_sep - return len(data) - - def finalize(self) -> None: - """Finalize this parser, which signals to that we are finished parsing, - if we're still in the middle of a field, an on_field_end callback, and - then the on_end callback. - """ - # If we're currently in the middle of a field, we finish it. 
- if self.state == QuerystringState.FIELD_DATA: - self.callback("field_end") - self.callback("end") - - def __repr__(self) -> str: - return "{}(strict_parsing={!r}, max_size={!r})".format( - self.__class__.__name__, self.strict_parsing, self.max_size - ) - - -class MultipartParser(BaseParser): - """This class is a streaming multipart/form-data parser. - - | Callback Name | Parameters | Description | - |--------------------|-----------------|-------------| - | on_part_begin | None | Called when a new part of the multipart message is encountered. | - | on_part_data | data, start, end| Called when a portion of a part's data is encountered. | - | on_part_end | None | Called when the end of a part is reached. | - | on_header_begin | None | Called when we've found a new header in a part of a multipart message | - | on_header_field | data, start, end| Called each time an additional portion of a header is read (i.e. the part of the header that is before the colon; the "Foo" in "Foo: Bar"). | - | on_header_value | data, start, end| Called when we get data for a header. | - | on_header_end | None | Called when the current header is finished - i.e. we've reached the newline at the end of the header. | - | on_headers_finished| None | Called when all headers are finished, and before the part data starts. | - | on_end | None | Called when the parser is finished parsing all data. | - - Args: - boundary: The multipart boundary. This is required, and must match what is given in the HTTP request - usually in the Content-Type header. - callbacks: A dictionary of callbacks. See the documentation for [`BaseParser`][python_multipart.BaseParser]. - max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded. - """ # noqa: E501 - - def __init__( - self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, max_size: float = float("inf") - ) -> None: - # Initialize parser state. - super().__init__() - self.state = MultipartState.START - self.index = self.flags = 0 - - self.callbacks = callbacks - - if not isinstance(max_size, Number) or max_size < 1: - raise ValueError("max_size must be a positive number, not %r" % max_size) - self.max_size = max_size - self._current_size = 0 - - # Setup marks. These are used to track the state of data received. - self.marks: dict[str, int] = {} - - # Save our boundary. - if isinstance(boundary, str): # pragma: no cover - boundary = boundary.encode("latin-1") - self.boundary = b"\r\n--" + boundary - - def write(self, data: bytes) -> int: - """Write some data to the parser, which will perform size verification, - and then parse the data into the appropriate location (e.g. header, - data, etc.), and pass this on to the underlying callback. If an error - is encountered, a MultipartParseError will be raised. The "offset" - attribute on the raised exception will be set to the offset of the byte - in the input chunk that caused the error. - - Args: - data: The data to write to the parser. - - Returns: - The number of bytes written. - """ - # Handle sizing. - data_len = len(data) - if (self._current_size + data_len) > self.max_size: - # We truncate the length of data that we are to process. 
- new_size = int(self.max_size - self._current_size) - self.logger.warning( - "Current size is %d (max %d), so truncating data length from %d to %d", - self._current_size, - self.max_size, - data_len, - new_size, - ) - data_len = new_size - - l = 0 - try: - l = self._internal_write(data, data_len) - finally: - self._current_size += l - - return l - - def _internal_write(self, data: bytes, length: int) -> int: - # Get values from locals. - boundary = self.boundary - - # Get our state, flags and index. These are persisted between calls to - # this function. - state = self.state - index = self.index - flags = self.flags - - # Our index defaults to 0. - i = 0 - - # Set a mark. - def set_mark(name: str) -> None: - self.marks[name] = i - - # Remove a mark. - def delete_mark(name: str, reset: bool = False) -> None: - self.marks.pop(name, None) - - # Helper function that makes calling a callback with data easier. The - # 'remaining' parameter will callback from the marked value until the - # end of the buffer, and reset the mark, instead of deleting it. This - # is used at the end of the function to call our callbacks with any - # remaining data in this chunk. - def data_callback(name: CallbackName, end_i: int, remaining: bool = False) -> None: - marked_index = self.marks.get(name) - if marked_index is None: - return - - # Otherwise, we call it from the mark to the current byte we're - # processing. - if end_i <= marked_index: - # There is no additional data to send. - pass - elif marked_index >= 0: - # We are emitting data from the local buffer. - self.callback(name, data, marked_index, end_i) - else: - # Some of the data comes from a partial boundary match. - # and requires look-behind. - # We need to use self.flags (and not flags) because we care about - # the state when we entered the loop. - lookbehind_len = -marked_index - if lookbehind_len <= len(boundary): - self.callback(name, boundary, 0, lookbehind_len) - elif self.flags & FLAG_PART_BOUNDARY: - lookback = boundary + b"\r\n" - self.callback(name, lookback, 0, lookbehind_len) - elif self.flags & FLAG_LAST_BOUNDARY: - lookback = boundary + b"--\r\n" - self.callback(name, lookback, 0, lookbehind_len) - else: # pragma: no cover (error case) - self.logger.warning("Look-back buffer error") - - if end_i > 0: - self.callback(name, data, 0, end_i) - # If we're getting remaining data, we have got all the data we - # can be certain is not a boundary, leaving only a partial boundary match. - if remaining: - self.marks[name] = end_i - length - else: - self.marks.pop(name, None) - - # For each byte... - while i < length: - c = data[i] - - if state == MultipartState.START: - # Skip leading newlines - if c == CR or c == LF: - i += 1 - continue - - # index is used as in index into our boundary. Set to 0. - index = 0 - - # Move to the next state, but decrement i so that we re-process - # this character. - state = MultipartState.START_BOUNDARY - i -= 1 - - elif state == MultipartState.START_BOUNDARY: - # Check to ensure that the last 2 characters in our boundary - # are CRLF. - if index == len(boundary) - 2: - if c == HYPHEN: - # Potential empty message. - state = MultipartState.END_BOUNDARY - elif c != CR: - # Error! 
- msg = "Did not find CR at end of boundary (%d)" % (i,) - self.logger.warning(msg) - e = MultipartParseError(msg) - e.offset = i - raise e - - index += 1 - - elif index == len(boundary) - 2 + 1: - if c != LF: - msg = "Did not find LF at end of boundary (%d)" % (i,) - self.logger.warning(msg) - e = MultipartParseError(msg) - e.offset = i - raise e - - # The index is now used for indexing into our boundary. - index = 0 - - # Callback for the start of a part. - self.callback("part_begin") - - # Move to the next character and state. - state = MultipartState.HEADER_FIELD_START - - else: - # Check to ensure our boundary matches - if c != boundary[index + 2]: - msg = "Expected boundary character %r, got %r at index %d" % (boundary[index + 2], c, index + 2) - self.logger.warning(msg) - e = MultipartParseError(msg) - e.offset = i - raise e - - # Increment index into boundary and continue. - index += 1 - - elif state == MultipartState.HEADER_FIELD_START: - # Mark the start of a header field here, reset the index, and - # continue parsing our header field. - index = 0 - - # Set a mark of our header field. - set_mark("header_field") - - # Notify that we're starting a header if the next character is - # not a CR; a CR at the beginning of the header will cause us - # to stop parsing headers in the MultipartState.HEADER_FIELD state, - # below. - if c != CR: - self.callback("header_begin") - - # Move to parsing header fields. - state = MultipartState.HEADER_FIELD - i -= 1 - - elif state == MultipartState.HEADER_FIELD: - # If we've reached a CR at the beginning of a header, it means - # that we've reached the second of 2 newlines, and so there are - # no more headers to parse. - if c == CR and index == 0: - delete_mark("header_field") - state = MultipartState.HEADERS_ALMOST_DONE - i += 1 - continue - - # Increment our index in the header. - index += 1 - - # If we've reached a colon, we're done with this header. - if c == COLON: - # A 0-length header is an error. - if index == 1: - msg = "Found 0-length header at %d" % (i,) - self.logger.warning(msg) - e = MultipartParseError(msg) - e.offset = i - raise e - - # Call our callback with the header field. - data_callback("header_field", i) - - # Move to parsing the header value. - state = MultipartState.HEADER_VALUE_START - - elif c not in TOKEN_CHARS_SET: - msg = "Found invalid character %r in header at %d" % (c, i) - self.logger.warning(msg) - e = MultipartParseError(msg) - e.offset = i - raise e - - elif state == MultipartState.HEADER_VALUE_START: - # Skip leading spaces. - if c == SPACE: - i += 1 - continue - - # Mark the start of the header value. - set_mark("header_value") - - # Move to the header-value state, reprocessing this character. - state = MultipartState.HEADER_VALUE - i -= 1 - - elif state == MultipartState.HEADER_VALUE: - # If we've got a CR, we're nearly done our headers. Otherwise, - # we do nothing and just move past this character. - if c == CR: - data_callback("header_value", i) - self.callback("header_end") - state = MultipartState.HEADER_VALUE_ALMOST_DONE - - elif state == MultipartState.HEADER_VALUE_ALMOST_DONE: - # The last character should be a LF. If not, it's an error. - if c != LF: - msg = "Did not find LF character at end of header " "(found %r)" % (c,) - self.logger.warning(msg) - e = MultipartParseError(msg) - e.offset = i - raise e - - # Move back to the start of another header. Note that if that - # state detects ANOTHER newline, it'll trigger the end of our - # headers. 
- state = MultipartState.HEADER_FIELD_START - - elif state == MultipartState.HEADERS_ALMOST_DONE: - # We're almost done our headers. This is reached when we parse - # a CR at the beginning of a header, so our next character - # should be a LF, or it's an error. - if c != LF: - msg = f"Did not find LF at end of headers (found {c!r})" - self.logger.warning(msg) - e = MultipartParseError(msg) - e.offset = i - raise e - - self.callback("headers_finished") - state = MultipartState.PART_DATA_START - - elif state == MultipartState.PART_DATA_START: - # Mark the start of our part data. - set_mark("part_data") - - # Start processing part data, including this character. - state = MultipartState.PART_DATA - i -= 1 - - elif state == MultipartState.PART_DATA: - # We're processing our part data right now. During this, we - # need to efficiently search for our boundary, since any data - # on any number of lines can be a part of the current data. - - # Save the current value of our index. We use this in case we - # find part of a boundary, but it doesn't match fully. - prev_index = index - - # Set up variables. - boundary_length = len(boundary) - data_length = length - - # If our index is 0, we're starting a new part, so start our - # search. - if index == 0: - # The most common case is likely to be that the whole - # boundary is present in the buffer. - # Calling `find` is much faster than iterating here. - i0 = data.find(boundary, i, data_length) - if i0 >= 0: - # We matched the whole boundary string. - index = boundary_length - 1 - i = i0 + boundary_length - 1 - else: - # No match found for whole string. - # There may be a partial boundary at the end of the - # data, which the find will not match. - # Since the length should to be searched is limited to - # the boundary length, just perform a naive search. - i = max(i, data_length - boundary_length) - - # Search forward until we either hit the end of our buffer, - # or reach a potential start of the boundary. - while i < data_length - 1 and data[i] != boundary[0]: - i += 1 - - c = data[i] - - # Now, we have a couple of cases here. If our index is before - # the end of the boundary... - if index < boundary_length: - # If the character matches... - if boundary[index] == c: - # The current character matches, so continue! - index += 1 - else: - index = 0 - - # Our index is equal to the length of our boundary! - elif index == boundary_length: - # First we increment it. - index += 1 - - # Now, if we've reached a newline, we need to set this as - # the potential end of our boundary. - if c == CR: - flags |= FLAG_PART_BOUNDARY - - # Otherwise, if this is a hyphen, we might be at the last - # of all boundaries. - elif c == HYPHEN: - flags |= FLAG_LAST_BOUNDARY - - # Otherwise, we reset our index, since this isn't either a - # newline or a hyphen. - else: - index = 0 - - # Our index is right after the part boundary, which should be - # a LF. - elif index == boundary_length + 1: - # If we're at a part boundary (i.e. we've seen a CR - # character already)... - if flags & FLAG_PART_BOUNDARY: - # We need a LF character next. - if c == LF: - # Unset the part boundary flag. - flags &= ~FLAG_PART_BOUNDARY - - # We have identified a boundary, callback for any data before it. - data_callback("part_data", i - index) - # Callback indicating that we've reached the end of - # a part, and are starting a new one. - self.callback("part_end") - self.callback("part_begin") - - # Move to parsing new headers. 
- index = 0 - state = MultipartState.HEADER_FIELD_START - i += 1 - continue - - # We didn't find an LF character, so no match. Reset - # our index and clear our flag. - index = 0 - flags &= ~FLAG_PART_BOUNDARY - - # Otherwise, if we're at the last boundary (i.e. we've - # seen a hyphen already)... - elif flags & FLAG_LAST_BOUNDARY: - # We need a second hyphen here. - if c == HYPHEN: - # We have identified a boundary, callback for any data before it. - data_callback("part_data", i - index) - # Callback to end the current part, and then the - # message. - self.callback("part_end") - self.callback("end") - state = MultipartState.END - else: - # No match, so reset index. - index = 0 - - # Otherwise, our index is 0. If the previous index is not, it - # means we reset something, and we need to take the data we - # thought was part of our boundary and send it along as actual - # data. - if index == 0 and prev_index > 0: - # Overwrite our previous index. - prev_index = 0 - - # Re-consider the current character, since this could be - # the start of the boundary itself. - i -= 1 - - elif state == MultipartState.END_BOUNDARY: - if index == len(boundary) - 2 + 1: - if c != HYPHEN: - msg = "Did not find - at end of boundary (%d)" % (i,) - self.logger.warning(msg) - e = MultipartParseError(msg) - e.offset = i - raise e - index += 1 - self.callback("end") - state = MultipartState.END - - elif state == MultipartState.END: - # Don't do anything if chunk ends with CRLF. - if c == CR and i + 1 < length and data[i + 1] == LF: - i += 2 - continue - # Skip data after the last boundary. - self.logger.warning("Skipping data after last boundary") - i = length - break - - else: # pragma: no cover (error case) - # We got into a strange state somehow! Just stop processing. - msg = "Reached an unknown state %d at %d" % (state, i) - self.logger.warning(msg) - e = MultipartParseError(msg) - e.offset = i - raise e - - # Move to the next byte. - i += 1 - - # We call our callbacks with any remaining data. Note that we pass - # the 'remaining' flag, which sets the mark back to 0 instead of - # deleting it, if it's found. This is because, if the mark is found - # at this point, we assume that there's data for one of these things - # that has been parsed, but not yet emitted. And, as such, it implies - # that we haven't yet reached the end of this 'thing'. So, by setting - # the mark to 0, we cause any data callbacks that take place in future - # calls to this function to start from the beginning of that buffer. - data_callback("header_field", length, True) - data_callback("header_value", length, True) - data_callback("part_data", length - index, True) - - # Save values to locals. - self.state = state - self.index = index - self.flags = flags - - # Return our data length to indicate no errors, and that we processed - # all of it. - return length - - def finalize(self) -> None: - """Finalize this parser, which signals to that we are finished parsing. - - Note: It does not currently, but in the future, it will verify that we - are in the final state of the parser (i.e. the end of the multipart - message is well-formed), and, if not, throw an error. - """ - # TODO: verify that we're in the state MultipartState.END, otherwise throw an - # error or otherwise state that we're not finished parsing. - pass - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(boundary={self.boundary!r})" - - -class FormParser: - """This class is the all-in-one form parser. 
Given all the information - necessary to parse a form, it will instantiate the correct parser, create - the proper :class:`Field` and :class:`File` classes to store the data that - is parsed, and call the two given callbacks with each field and file as - they become available. - - Args: - content_type: The Content-Type of the incoming request. This is used to select the appropriate parser. - on_field: The callback to call when a field has been parsed and is ready for usage. See above for parameters. - on_file: The callback to call when a file has been parsed and is ready for usage. See above for parameters. - on_end: An optional callback to call when all fields and files in a request has been parsed. Can be None. - boundary: If the request is a multipart/form-data request, this should be the boundary of the request, as given - in the Content-Type header, as a bytestring. - file_name: If the request is of type application/octet-stream, then the body of the request will not contain any - information about the uploaded file. In such cases, you can provide the file name of the uploaded file - manually. - FileClass: The class to use for uploaded files. Defaults to :class:`File`, but you can provide your own class - if you wish to customize behaviour. The class will be instantiated as FileClass(file_name, field_name), and - it must provide the following functions:: - - file_instance.write(data) - - file_instance.finalize() - - file_instance.close() - FieldClass: The class to use for uploaded fields. Defaults to :class:`Field`, but you can provide your own - class if you wish to customize behaviour. The class will be instantiated as FieldClass(field_name), and it - must provide the following functions:: - - field_instance.write(data) - - field_instance.finalize() - - field_instance.close() - - field_instance.set_none() - config: Configuration to use for this FormParser. The default values are taken from the DEFAULT_CONFIG value, - and then any keys present in this dictionary will overwrite the default values. - """ - - #: This is the default configuration for our form parser. - #: Note: all file sizes should be in bytes. - DEFAULT_CONFIG: FormParserConfig = { - "MAX_BODY_SIZE": float("inf"), - "MAX_MEMORY_FILE_SIZE": 1 * 1024 * 1024, - "UPLOAD_DIR": None, - "UPLOAD_KEEP_FILENAME": False, - "UPLOAD_KEEP_EXTENSIONS": False, - # Error on invalid Content-Transfer-Encoding? - "UPLOAD_ERROR_ON_BAD_CTE": False, - } - - def __init__( - self, - content_type: str, - on_field: OnFieldCallback | None, - on_file: OnFileCallback | None, - on_end: Callable[[], None] | None = None, - boundary: bytes | str | None = None, - file_name: bytes | None = None, - FileClass: type[FileProtocol] = File, - FieldClass: type[FieldProtocol] = Field, - config: dict[Any, Any] = {}, - ) -> None: - self.logger = logging.getLogger(__name__) - - # Save variables. - self.content_type = content_type - self.boundary = boundary - self.bytes_received = 0 - self.parser = None - - # Save callbacks. - self.on_field = on_field - self.on_file = on_file - self.on_end = on_end - - # Save classes. - self.FileClass = File - self.FieldClass = Field - - # Set configuration options. - self.config: FormParserConfig = self.DEFAULT_CONFIG.copy() - self.config.update(config) # type: ignore[typeddict-item] - - parser: OctetStreamParser | MultipartParser | QuerystringParser | None = None - - # Depending on the Content-Type, we instantiate the correct parser. 
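Editor's sketch (not part of the diff): the dispatch below selects a parser per Content-Type. For the urlencoded branch, usage reduces to constructing a `FormParser` and feeding it bytes; the sketch assumes `Field` objects expose `field_name` and `value` after finalization, as this module's `Field` class does.

```python
from python_multipart import FormParser  # assumed top-level export

collected = {}

def on_field(field) -> None:
    # Both attributes are expected to be bytes once the field is finalized.
    collected[field.field_name] = field.value

fp = FormParser("application/x-www-form-urlencoded", on_field, on_file=None)
fp.write(b"name=alice&lang=en")
fp.finalize()
# collected is expected to equal {b"name": b"alice", b"lang": b"en"}
```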
- if content_type == "application/octet-stream": - file: FileProtocol = None # type: ignore - - def on_start() -> None: - nonlocal file - file = FileClass(file_name, None, config=cast("FileConfig", self.config)) - - def on_data(data: bytes, start: int, end: int) -> None: - nonlocal file - file.write(data[start:end]) - - def _on_end() -> None: - nonlocal file - # Finalize the file itself. - file.finalize() - - # Call our callback. - if on_file: - on_file(file) - - # Call the on-end callback. - if self.on_end is not None: - self.on_end() - - # Instantiate an octet-stream parser - parser = OctetStreamParser( - callbacks={"on_start": on_start, "on_data": on_data, "on_end": _on_end}, - max_size=self.config["MAX_BODY_SIZE"], - ) - - elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded": - name_buffer: list[bytes] = [] - - f: FieldProtocol | None = None - - def on_field_start() -> None: - pass - - def on_field_name(data: bytes, start: int, end: int) -> None: - name_buffer.append(data[start:end]) - - def on_field_data(data: bytes, start: int, end: int) -> None: - nonlocal f - if f is None: - f = FieldClass(b"".join(name_buffer)) - del name_buffer[:] - f.write(data[start:end]) - - def on_field_end() -> None: - nonlocal f - # Finalize and call callback. - if f is None: - # If we get here, it's because there was no field data. - # We create a field, set it to None, and then continue. - f = FieldClass(b"".join(name_buffer)) - del name_buffer[:] - f.set_none() - - f.finalize() - if on_field: - on_field(f) - f = None - - def _on_end() -> None: - if self.on_end is not None: - self.on_end() - - # Instantiate parser. - parser = QuerystringParser( - callbacks={ - "on_field_start": on_field_start, - "on_field_name": on_field_name, - "on_field_data": on_field_data, - "on_field_end": on_field_end, - "on_end": _on_end, - }, - max_size=self.config["MAX_BODY_SIZE"], - ) - - elif content_type == "multipart/form-data": - if boundary is None: - self.logger.error("No boundary given") - raise FormParserError("No boundary given") - - header_name: list[bytes] = [] - header_value: list[bytes] = [] - headers: dict[bytes, bytes] = {} - - f_multi: FileProtocol | FieldProtocol | None = None - writer = None - is_file = False - - def on_part_begin() -> None: - # Reset headers in case this isn't the first part. - nonlocal headers - headers = {} - - def on_part_data(data: bytes, start: int, end: int) -> None: - nonlocal writer - assert writer is not None - writer.write(data[start:end]) - # TODO: check for error here. - - def on_part_end() -> None: - nonlocal f_multi, is_file - assert f_multi is not None - f_multi.finalize() - if is_file: - if on_file: - on_file(f_multi) - else: - if on_field: - on_field(cast("FieldProtocol", f_multi)) - - def on_header_field(data: bytes, start: int, end: int) -> None: - header_name.append(data[start:end]) - - def on_header_value(data: bytes, start: int, end: int) -> None: - header_value.append(data[start:end]) - - def on_header_end() -> None: - headers[b"".join(header_name)] = b"".join(header_value) - del header_name[:] - del header_value[:] - - def on_headers_finished() -> None: - nonlocal is_file, f_multi, writer - # Reset the 'is file' flag. - is_file = False - - # Parse the content-disposition header. - # TODO: handle mixed case - content_disp = headers.get(b"Content-Disposition") - disp, options = parse_options_header(content_disp) - - # Get the field and filename. 
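Editor's sketch (not part of the diff): the `parse_options_header` call just above splits a part's Content-Disposition header into its value and its options. With a made-up header, the helper is expected to behave roughly like this:

```python
from python_multipart.multipart import parse_options_header

disp, options = parse_options_header(b'form-data; name="avatar"; filename="photo.png"')
# Expected: disp == b"form-data"
# Expected: options == {b"name": b"avatar", b"filename": b"photo.png"}
```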
- field_name = options.get(b"name") - file_name = options.get(b"filename") - # TODO: check for errors - - # Create the proper class. - if file_name is None: - f_multi = FieldClass(field_name) - else: - f_multi = FileClass(file_name, field_name, config=cast("FileConfig", self.config)) - is_file = True - - # Parse the given Content-Transfer-Encoding to determine what - # we need to do with the incoming data. - # TODO: check that we properly handle 8bit / 7bit encoding. - transfer_encoding = headers.get(b"Content-Transfer-Encoding", b"7bit") - - if transfer_encoding in (b"binary", b"8bit", b"7bit"): - writer = f_multi - - elif transfer_encoding == b"base64": - writer = Base64Decoder(f_multi) - - elif transfer_encoding == b"quoted-printable": - writer = QuotedPrintableDecoder(f_multi) - - else: - self.logger.warning("Unknown Content-Transfer-Encoding: %r", transfer_encoding) - if self.config["UPLOAD_ERROR_ON_BAD_CTE"]: - raise FormParserError('Unknown Content-Transfer-Encoding "{!r}"'.format(transfer_encoding)) - else: - # If we aren't erroring, then we just treat this as an - # unencoded Content-Transfer-Encoding. - writer = f_multi - - def _on_end() -> None: - nonlocal writer - if writer is not None: - writer.finalize() - if self.on_end is not None: - self.on_end() - - # Instantiate a multipart parser. - parser = MultipartParser( - boundary, - callbacks={ - "on_part_begin": on_part_begin, - "on_part_data": on_part_data, - "on_part_end": on_part_end, - "on_header_field": on_header_field, - "on_header_value": on_header_value, - "on_header_end": on_header_end, - "on_headers_finished": on_headers_finished, - "on_end": _on_end, - }, - max_size=self.config["MAX_BODY_SIZE"], - ) - - else: - self.logger.warning("Unknown Content-Type: %r", content_type) - raise FormParserError("Unknown Content-Type: {}".format(content_type)) - - self.parser = parser - - def write(self, data: bytes) -> int: - """Write some data. The parser will forward this to the appropriate - underlying parser. - - Args: - data: The data to write. - - Returns: - The number of bytes processed. - """ - self.bytes_received += len(data) - # TODO: check the parser's return value for errors? - assert self.parser is not None - return self.parser.write(data) - - def finalize(self) -> None: - """Finalize the parser.""" - if self.parser is not None and hasattr(self.parser, "finalize"): - self.parser.finalize() - - def close(self) -> None: - """Close the parser.""" - if self.parser is not None and hasattr(self.parser, "close"): - self.parser.close() - - def __repr__(self) -> str: - return "{}(content_type={!r}, parser={!r})".format(self.__class__.__name__, self.content_type, self.parser) - - -def create_form_parser( - headers: dict[str, bytes], - on_field: OnFieldCallback | None, - on_file: OnFileCallback | None, - trust_x_headers: bool = False, - config: dict[Any, Any] = {}, -) -> FormParser: - """This function is a helper function to aid in creating a FormParser - instances. Given a dictionary-like headers object, it will determine - the correct information needed, instantiate a FormParser with the - appropriate values and given callbacks, and then return the corresponding - parser. - - Args: - headers: A dictionary-like object of HTTP headers. The only required header is Content-Type. - on_field: Callback to call with each parsed field. - on_file: Callback to call with each parsed file. - trust_x_headers: Whether or not to trust information received from certain X-Headers - for example, the file - name from X-File-Name. 
- config: Configuration variables to pass to the FormParser. - """ - content_type: str | bytes | None = headers.get("Content-Type") - if content_type is None: - logging.getLogger(__name__).warning("No Content-Type header given") - raise ValueError("No Content-Type header given!") - - # Boundaries are optional (the FormParser will raise if one is needed - # but not given). - content_type, params = parse_options_header(content_type) - boundary = params.get(b"boundary") - - # We need content_type to be a string, not a bytes object. - content_type = content_type.decode("latin-1") - - # File names are optional. - file_name = headers.get("X-File-Name") - - # Instantiate a form parser. - form_parser = FormParser(content_type, on_field, on_file, boundary=boundary, file_name=file_name, config=config) - - # Return our parser. - return form_parser - - -def parse_form( - headers: dict[str, bytes], - input_stream: SupportsRead, - on_field: OnFieldCallback | None, - on_file: OnFileCallback | None, - chunk_size: int = 1048576, -) -> None: - """This function is useful if you just want to parse a request body, - without too much work. Pass it a dictionary-like object of the request's - headers, and a file-like object for the input stream, along with two - callbacks that will get called whenever a field or file is parsed. - - Args: - headers: A dictionary-like object of HTTP headers. The only required header is Content-Type. - input_stream: A file-like object that represents the request body. The read() method must return bytestrings. - on_field: Callback to call with each parsed field. - on_file: Callback to call with each parsed file. - chunk_size: The maximum size to read from the input stream and write to the parser at one time. - Defaults to 1 MiB. - """ - # Create our form parser. - parser = create_form_parser(headers, on_field, on_file) - - # Read chunks of 1MiB and write to the parser, but never read more than - # the given Content-Length, if any. - content_length: int | float | bytes | None = headers.get("Content-Length") - if content_length is not None: - content_length = int(content_length) - else: - content_length = float("inf") - bytes_read = 0 - - while True: - # Read only up to the Content-Length given. - max_readable = int(min(content_length - bytes_read, chunk_size)) - buff = input_stream.read(max_readable) - - # Write to the parser and update our length. - parser.write(buff) - bytes_read += len(buff) - - # If we get a buffer that's smaller than the size requested, or if we - # have read up to our content length, we're done. - if len(buff) != max_readable or bytes_read == content_length: - break - - # Tell our parser that we're done writing data. 
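Editor's sketch (not part of the diff): an end-to-end call of the `parse_form` helper defined above, with a made-up boundary and body. It assumes the package's top-level export and that small uploads stay in memory, so `file.file_object` is a readable in-memory buffer.

```python
import io
from python_multipart import parse_form  # assumed top-level export

def on_field(field) -> None:
    print("field:", field.field_name, field.value)

def on_file(file) -> None:
    file.file_object.seek(0)
    print("file:", file.file_name, file.file_object.read())

body = (
    b"--b0undary\r\n"
    b'Content-Disposition: form-data; name="note"\r\n\r\n'
    b"hi\r\n"
    b"--b0undary\r\n"
    b'Content-Disposition: form-data; name="doc"; filename="a.txt"\r\n\r\n'
    b"file contents\r\n"
    b"--b0undary--\r\n"
)
headers = {
    "Content-Type": b"multipart/form-data; boundary=b0undary",
    "Content-Length": str(len(body)).encode(),
}
parse_form(headers, io.BytesIO(body), on_field, on_file)
```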
- parser.finalize() diff --git a/venv/lib/python3.12/site-packages/python_multipart/py.typed b/venv/lib/python3.12/site-packages/python_multipart/py.typed deleted file mode 100644 index e69de29..0000000 diff --git a/venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/INSTALLER b/venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/INSTALLER similarity index 100% rename from venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/INSTALLER rename to venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/INSTALLER diff --git a/venv/lib/python3.12/site-packages/redis-6.4.0.dist-info/licenses/LICENSE b/venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/LICENSE similarity index 100% rename from venv/lib/python3.12/site-packages/redis-6.4.0.dist-info/licenses/LICENSE rename to venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/LICENSE diff --git a/venv/lib/python3.12/site-packages/redis-6.4.0.dist-info/METADATA b/venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/METADATA similarity index 68% rename from venv/lib/python3.12/site-packages/redis-6.4.0.dist-info/METADATA rename to venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/METADATA index 5262348..98242bc 100644 --- a/venv/lib/python3.12/site-packages/redis-6.4.0.dist-info/METADATA +++ b/venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/METADATA @@ -1,16 +1,17 @@ -Metadata-Version: 2.4 +Metadata-Version: 2.1 Name: redis -Version: 6.4.0 +Version: 5.0.1 Summary: Python client for Redis database and key-value store +Home-page: https://github.com/redis/redis-py +Author: Redis Inc. +Author-email: oss@redis.com +License: MIT +Project-URL: Documentation, https://redis.readthedocs.io/en/latest/ Project-URL: Changes, https://github.com/redis/redis-py/releases Project-URL: Code, https://github.com/redis/redis-py -Project-URL: Documentation, https://redis.readthedocs.io/en/latest/ -Project-URL: Homepage, https://github.com/redis/redis-py Project-URL: Issue tracker, https://github.com/redis/redis-py/issues -Author-email: "Redis Inc." 
-License-Expression: MIT -License-File: LICENSE -Keywords: Redis,database,key-value-store +Keywords: Redis,key-value store,database +Platform: UNKNOWN Classifier: Development Status :: 5 - Production/Stable Classifier: Environment :: Console Classifier: Intended Audience :: Developers @@ -19,31 +20,32 @@ Classifier: Operating System :: OS Independent Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 3 Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Classifier: Programming Language :: Python :: 3.11 -Classifier: Programming Language :: Python :: 3.12 -Classifier: Programming Language :: Python :: 3.13 Classifier: Programming Language :: Python :: Implementation :: CPython Classifier: Programming Language :: Python :: Implementation :: PyPy -Requires-Python: >=3.9 -Requires-Dist: async-timeout>=4.0.3; python_full_version < '3.11.3' -Provides-Extra: hiredis -Requires-Dist: hiredis>=3.2.0; extra == 'hiredis' -Provides-Extra: jwt -Requires-Dist: pyjwt>=2.9.0; extra == 'jwt' -Provides-Extra: ocsp -Requires-Dist: cryptography>=36.0.1; extra == 'ocsp' -Requires-Dist: pyopenssl>=20.0.1; extra == 'ocsp' -Requires-Dist: requests>=2.31.0; extra == 'ocsp' +Requires-Python: >=3.7 Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: async-timeout >=4.0.2 ; python_full_version <= "3.11.2" +Requires-Dist: importlib-metadata >=1.0 ; python_version < "3.8" +Requires-Dist: typing-extensions ; python_version < "3.8" +Provides-Extra: hiredis +Requires-Dist: hiredis >=1.0.0 ; extra == 'hiredis' +Provides-Extra: ocsp +Requires-Dist: cryptography >=36.0.1 ; extra == 'ocsp' +Requires-Dist: pyopenssl ==20.0.1 ; extra == 'ocsp' +Requires-Dist: requests >=2.26.0 ; extra == 'ocsp' # redis-py The Python interface to the Redis key-value store. [![CI](https://github.com/redis/redis-py/workflows/CI/badge.svg?branch=master)](https://github.com/redis/redis-py/actions?query=workflow%3ACI+branch%3Amaster) -[![docs](https://readthedocs.org/projects/redis/badge/?version=stable&style=flat)](https://redis.readthedocs.io/en/stable/) +[![docs](https://readthedocs.org/projects/redis/badge/?version=stable&style=flat)](https://redis-py.readthedocs.io/en/stable/) [![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE) [![pypi](https://badge.fury.io/py/redis.svg)](https://pypi.org/project/redis/) [![pre-release](https://img.shields.io/github/v/release/redis/redis-py?include_prereleases&label=latest-prerelease)](https://github.com/redis/redis-py/releases) @@ -53,35 +55,18 @@ The Python interface to the Redis key-value store. --------------------------------------------- -**Note:** redis-py 5.0 will be the last version of redis-py to support Python 3.7, as it has reached [end of life](https://devguide.python.org/versions/). redis-py 5.1 will support Python 3.8+. -**Note:** redis-py 6.1.0 will be the last version of redis-py to support Python 3.8, as it has reached [end of life](https://devguide.python.org/versions/). redis-py 6.2.0 will support Python 3.9+. +**Note: ** redis-py 5.0 will be the last version of redis-py to support Python 3.7, as it has reached [end of life](https://devguide.python.org/versions/). redis-py 5.1 will support Python 3.8+. + --------------------------------------------- -## How do I Redis? 
- -[Learn for free at Redis University](https://redis.io/learn/university) - -[Try the Redis Cloud](https://redis.io/try-free/) - -[Dive in developer tutorials](https://redis.io/learn) - -[Join the Redis community](https://redis.io/community/) - -[Work at Redis](https://redis.io/careers/) - ## Installation -Start a redis via docker (for Redis versions >= 8.0): - -``` bash -docker run -p 6379:6379 -it redis:latest -``` - -Start a redis via docker (for Redis versions < 8.0): +Start a redis via docker: ``` bash docker run -p 6379:6379 -it redis/redis-stack:latest ``` + To install redis-py, simply: ``` bash @@ -99,7 +84,7 @@ Looking for a high-level library to handle object mapping? See [redis-om-python] ## Supported Redis Versions -The most recent version of this library supports Redis version [7.2](https://github.com/redis/redis/blob/7.2/00-RELEASENOTES), [7.4](https://github.com/redis/redis/blob/7.4/00-RELEASENOTES) and [8.0](https://github.com/redis/redis/blob/8.0/00-RELEASENOTES). +The most recent version of this library supports redis version [5.0](https://github.com/redis/redis/blob/5.0/00-RELEASENOTES), [6.0](https://github.com/redis/redis/blob/6.0/00-RELEASENOTES), [6.2](https://github.com/redis/redis/blob/6.2/00-RELEASENOTES), [7.0](https://github.com/redis/redis/blob/7.0/00-RELEASENOTES) and [7.2](https://github.com/redis/redis/blob/7.2/00-RELEASENOTES). The table below highlights version compatibility of the most-recent library versions and redis versions. @@ -107,8 +92,7 @@ The table below highlights version compatibility of the most-recent library vers |-----------------|-------------------| | 3.5.3 | <= 6.2 Family of releases | | >= 4.5.0 | Version 5.0 to 7.0 | -| >= 5.0.0 | Version 5.0 to 7.4 | -| >= 6.0.0 | Version 7.2 to current | +| >= 5.0.0 | Version 5.0 to current | ## Usage @@ -198,46 +182,12 @@ The following example shows how to utilize [Redis Pub/Sub](https://redis.io/docs {'pattern': None, 'type': 'subscribe', 'channel': b'my-second-channel', 'data': 1} ``` -### Redis’ search and query capabilities default dialect -Release 6.0.0 introduces a client-side default dialect for Redis’ search and query capabilities. -By default, the client now overrides the server-side dialect with version 2, automatically appending *DIALECT 2* to commands like *FT.AGGREGATE* and *FT.SEARCH*. - -**Important**: Be aware that the query dialect may impact the results returned. If needed, you can revert to a different dialect version by configuring the client accordingly. - -``` python ->>> from redis.commands.search.field import TextField ->>> from redis.commands.search.query import Query ->>> from redis.commands.search.index_definition import IndexDefinition ->>> import redis - ->>> r = redis.Redis(host='localhost', port=6379, db=0) ->>> r.ft().create_index( ->>> (TextField("name"), TextField("lastname")), ->>> definition=IndexDefinition(prefix=["test:"]), ->>> ) - ->>> r.hset("test:1", "name", "James") ->>> r.hset("test:1", "lastname", "Brown") - ->>> # Query with default DIALECT 2 ->>> query = "@name: James Brown" ->>> q = Query(query) ->>> res = r.ft().search(q) - ->>> # Query with explicit DIALECT 1 ->>> query = "@name: James Brown" ->>> q = Query(query).dialect(1) ->>> res = r.ft().search(q) -``` - -You can find further details in the [query dialect documentation](https://redis.io/docs/latest/develop/interact/search-and-query/advanced-concepts/dialects/). 
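Editor's sketch (not part of the diff): the METADATA hunk above records this environment moving from redis-py 6.4.0 back to 5.0.1. Basic client usage is identical across both lines; the snippet below assumes a local Redis server on the default port.

```python
import redis

r = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)
r.set("greeting", "hello")
print(r.get("greeting"))  # -> "hello"
```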
- ---------------------------------------------- +-------------------------- ### Author -redis-py is developed and maintained by [Redis Inc](https://redis.io). It can be found [here]( +redis-py is developed and maintained by [Redis Inc](https://redis.com). It can be found [here]( https://github.com/redis/redis-py), or downloaded from [pypi](https://pypi.org/project/redis/). Special thanks to: @@ -249,4 +199,5 @@ Special thanks to: system. - Paul Hubbard for initial packaging support. -[![Redis](./docs/_static/logo-redis.svg)](https://redis.io) +[![Redis](./docs/logo-redis.png)](https://www.redis.com) + diff --git a/venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/RECORD b/venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/RECORD new file mode 100644 index 0000000..537058f --- /dev/null +++ b/venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/RECORD @@ -0,0 +1,148 @@ +redis-5.0.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +redis-5.0.1.dist-info/LICENSE,sha256=pXslClvwPXr-VbdAYzE_Ktt7ANVGwKsUmok5gzP-PMg,1074 +redis-5.0.1.dist-info/METADATA,sha256=xLwWid1Pns_mCEX6qn3qtFxtf7pphgPFPWOwEg5LWrQ,8910 +redis-5.0.1.dist-info/RECORD,, +redis-5.0.1.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +redis-5.0.1.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92 +redis-5.0.1.dist-info/top_level.txt,sha256=OMAefszlde6ZoOtlM35AWzpRIrwtcqAMHGlRit-w2-4,6 +redis/__init__.py,sha256=PthSOEfXKlYV9xBgroOnO2tJD7uu0BWwvztgsKUvK48,2110 +redis/__pycache__/__init__.cpython-312.pyc,, +redis/__pycache__/backoff.cpython-312.pyc,, +redis/__pycache__/client.cpython-312.pyc,, +redis/__pycache__/cluster.cpython-312.pyc,, +redis/__pycache__/compat.cpython-312.pyc,, +redis/__pycache__/connection.cpython-312.pyc,, +redis/__pycache__/crc.cpython-312.pyc,, +redis/__pycache__/credentials.cpython-312.pyc,, +redis/__pycache__/exceptions.cpython-312.pyc,, +redis/__pycache__/lock.cpython-312.pyc,, +redis/__pycache__/ocsp.cpython-312.pyc,, +redis/__pycache__/retry.cpython-312.pyc,, +redis/__pycache__/sentinel.cpython-312.pyc,, +redis/__pycache__/typing.cpython-312.pyc,, +redis/__pycache__/utils.cpython-312.pyc,, +redis/_parsers/__init__.py,sha256=qkfgV2X9iyvQAvbLdSelwgz0dCk9SGAosCvuZC9-qDc,550 +redis/_parsers/__pycache__/__init__.cpython-312.pyc,, +redis/_parsers/__pycache__/base.cpython-312.pyc,, +redis/_parsers/__pycache__/commands.cpython-312.pyc,, +redis/_parsers/__pycache__/encoders.cpython-312.pyc,, +redis/_parsers/__pycache__/helpers.cpython-312.pyc,, +redis/_parsers/__pycache__/hiredis.cpython-312.pyc,, +redis/_parsers/__pycache__/resp2.cpython-312.pyc,, +redis/_parsers/__pycache__/resp3.cpython-312.pyc,, +redis/_parsers/__pycache__/socket.cpython-312.pyc,, +redis/_parsers/base.py,sha256=95SoPNwt4xJQB-ONIjxsR46n4EHnxnmkv9f0ReZSIR0,7480 +redis/_parsers/commands.py,sha256=pmR4hl4u93UvCmeDgePHFc6pWDr4slrKEvCsdMmtj_M,11052 +redis/_parsers/encoders.py,sha256=X0jvTp-E4TZUlZxV5LJJ88TuVrF1vly5tuC0xjxGaSc,1734 +redis/_parsers/helpers.py,sha256=xcRjjns6uQPb2pp0AOlOK9LhMJL4ofyEMFqVA7CwzsE,27947 +redis/_parsers/hiredis.py,sha256=X8yk0ElEEjHlhUgjs9fdHSOijlxYtunTrTJSLzkGrvQ,7581 +redis/_parsers/resp2.py,sha256=f22kH-_ZP2iNtOn6xOe65MSy_fJpu8OEn1u_hgeeojI,4813 +redis/_parsers/resp3.py,sha256=rXDA0R-wjCj2vyGaaWEf50NXN7UFBzefRnK3NGzWz2E,9657 +redis/_parsers/socket.py,sha256=CKD8QW_wFSNlIZzxlbNduaGpiv0I8wBcsGuAIojDfJg,5403 +redis/asyncio/__init__.py,sha256=uoDD8XYVi0Kj6mcufYwLDUTQXmBRx7a0bhKF9stZr7I,1489 +redis/asyncio/__pycache__/__init__.cpython-312.pyc,, 
+redis/asyncio/__pycache__/client.cpython-312.pyc,, +redis/asyncio/__pycache__/cluster.cpython-312.pyc,, +redis/asyncio/__pycache__/connection.cpython-312.pyc,, +redis/asyncio/__pycache__/lock.cpython-312.pyc,, +redis/asyncio/__pycache__/retry.cpython-312.pyc,, +redis/asyncio/__pycache__/sentinel.cpython-312.pyc,, +redis/asyncio/__pycache__/utils.cpython-312.pyc,, +redis/asyncio/client.py,sha256=BYurDT13lsw0N3a8sLqQFl00tFFolpET7_EujLw2Nbc,58826 +redis/asyncio/cluster.py,sha256=a0Za2icr03ytjF_WVohDMvEZejixUdVMhpsKWeMxYHY,63076 +redis/asyncio/connection.py,sha256=ZwClasZ2x0SQY90gDZvraFIx2lhGPnDm-xUUPPsb424,43426 +redis/asyncio/lock.py,sha256=lLasXEO2E1CskhX5ZZoaSGpmwZP1Q782R3HAUNG3wD4,11967 +redis/asyncio/retry.py,sha256=SnPPOlo5gcyIFtkC4DY7HFvmDgUaILsJ3DeHioogdB8,2219 +redis/asyncio/sentinel.py,sha256=sTVJCbi1KtIbHJc3fkHRZb_LGav_UtCAq-ipxltkGsE,14198 +redis/asyncio/utils.py,sha256=Yxc5YQumhLjtDDwCS4mgxI6yy2Z21AzLlFxVbxCohic,704 +redis/backoff.py,sha256=x-sAjV7u4MmdOjFZSZ8RnUnCaQtPhCBbGNBgICvCW3I,2966 +redis/client.py,sha256=IkqYEPg2WA35jBjPCpEgcKcVW3Hx8lm89j_IQ2dnoOw,57514 +redis/cluster.py,sha256=HcH2YM057xpWMQhGYBLWv5l9yrb7hzcSuPXXbqJl_DY,92754 +redis/commands/__init__.py,sha256=cTUH-MGvaLYS0WuoytyqtN1wniw2A1KbkUXcpvOSY3I,576 +redis/commands/__pycache__/__init__.cpython-312.pyc,, +redis/commands/__pycache__/cluster.cpython-312.pyc,, +redis/commands/__pycache__/core.cpython-312.pyc,, +redis/commands/__pycache__/helpers.cpython-312.pyc,, +redis/commands/__pycache__/redismodules.cpython-312.pyc,, +redis/commands/__pycache__/sentinel.cpython-312.pyc,, +redis/commands/bf/__init__.py,sha256=ESmQXH4p9Dp37tNCwQGDiF_BHDEaKnXSF7ZfASEqkFY,8027 +redis/commands/bf/__pycache__/__init__.cpython-312.pyc,, +redis/commands/bf/__pycache__/commands.cpython-312.pyc,, +redis/commands/bf/__pycache__/info.cpython-312.pyc,, +redis/commands/bf/commands.py,sha256=kVWUatdS0zLcu8-fVIqLLQBU5u8fJWIOCVUD3fqYVp0,21462 +redis/commands/bf/info.py,sha256=tpE4hv1zApxoOgyV9_8BEDZcl4Wf6tS1dSvtlxV7uTE,3395 +redis/commands/cluster.py,sha256=5BDwdeUnWVWOalF5fHD12HPQeDq_rc2vhuCI3sChrYE,31562 +redis/commands/core.py,sha256=2WM9nZ3f0Xqny8o5yucORe0fLRItJO4SWU68W5Wr1mw,223552 +redis/commands/graph/__init__.py,sha256=NmklyOuzIa20yEWrhnKQxgQlaXKYkcwBkGHpvQyo5J8,7237 +redis/commands/graph/__pycache__/__init__.cpython-312.pyc,, +redis/commands/graph/__pycache__/commands.cpython-312.pyc,, +redis/commands/graph/__pycache__/edge.cpython-312.pyc,, +redis/commands/graph/__pycache__/exceptions.cpython-312.pyc,, +redis/commands/graph/__pycache__/execution_plan.cpython-312.pyc,, +redis/commands/graph/__pycache__/node.cpython-312.pyc,, +redis/commands/graph/__pycache__/path.cpython-312.pyc,, +redis/commands/graph/__pycache__/query_result.cpython-312.pyc,, +redis/commands/graph/commands.py,sha256=rLGV58ZJKEf6yxzk1oD3IwiS03lP6bpbo0249pFI0OY,10379 +redis/commands/graph/edge.py,sha256=_TljVB4a1pPS9pb8_Cvw8rclbBOOI__-fY9fybU4djQ,2460 +redis/commands/graph/exceptions.py,sha256=kRDBsYLgwIaM4vqioO_Bp_ugWvjfqCH7DIv4Gpc9HCM,107 +redis/commands/graph/execution_plan.py,sha256=Pxr8_zhPWT_EdZSgGrbiWw8wFL6q5JF7O-Z6Xzm55iw,6742 +redis/commands/graph/node.py,sha256=Pasfsl5dF6WqT9KCNFAKKwGubyK_2ORCoAQE4VtnXkQ,2400 +redis/commands/graph/path.py,sha256=m6Gz4DYfMIQ8VReDLHlnQw_KI2rVdepWYk_AU0_x_GM,2080 +redis/commands/graph/query_result.py,sha256=GTEnBE0rAiUk4JquaxcVKdL1kzSMDWW5ky-iFTvRN84,17040 +redis/commands/helpers.py,sha256=WgfhdH3NCBW2Vqg-9PcP2EIKwzBkzb5CeqfdnPm2tTQ,4531 +redis/commands/json/__init__.py,sha256=llpDQz2kBNnJyfQfuh0-2oY-knMb6gAS0ADtPmaTKsM,4854 
+redis/commands/json/__pycache__/__init__.cpython-312.pyc,, +redis/commands/json/__pycache__/_util.cpython-312.pyc,, +redis/commands/json/__pycache__/commands.cpython-312.pyc,, +redis/commands/json/__pycache__/decoders.cpython-312.pyc,, +redis/commands/json/__pycache__/path.cpython-312.pyc,, +redis/commands/json/_util.py,sha256=b_VQTh10FyLl8BtREfJfDagOJCyd6wTQQs8g63pi5GI,116 +redis/commands/json/commands.py,sha256=9P3NBFyWuRxWer5i__NtJx7oJZNnTOisfrHGhwaRfoA,15603 +redis/commands/json/decoders.py,sha256=a_IoMV_wgeJyUifD4P6HTcM9s6FhricwmzQcZRmc-Gw,1411 +redis/commands/json/path.py,sha256=0zaO6_q_FVMk1Bkhkb7Wcr8AF2Tfr69VhkKy1IBVhpA,393 +redis/commands/redismodules.py,sha256=7TfVzLj319mhsA6WEybsOdIPk4pC-1hScJg3H5hv3T4,2454 +redis/commands/search/__init__.py,sha256=happQFVF0j7P87p7LQsUK5AK0kuem9cA-xvVRdQWpos,5744 +redis/commands/search/__pycache__/__init__.cpython-312.pyc,, +redis/commands/search/__pycache__/_util.cpython-312.pyc,, +redis/commands/search/__pycache__/aggregation.cpython-312.pyc,, +redis/commands/search/__pycache__/commands.cpython-312.pyc,, +redis/commands/search/__pycache__/document.cpython-312.pyc,, +redis/commands/search/__pycache__/field.cpython-312.pyc,, +redis/commands/search/__pycache__/indexDefinition.cpython-312.pyc,, +redis/commands/search/__pycache__/query.cpython-312.pyc,, +redis/commands/search/__pycache__/querystring.cpython-312.pyc,, +redis/commands/search/__pycache__/reducers.cpython-312.pyc,, +redis/commands/search/__pycache__/result.cpython-312.pyc,, +redis/commands/search/__pycache__/suggestion.cpython-312.pyc,, +redis/commands/search/_util.py,sha256=VAguSwh_3dNtJwNU6Vle2CNdPE10_NUkPffD7GWFX48,193 +redis/commands/search/aggregation.py,sha256=8yQ1P31Qiy29xehlmN2ToCh73e-MHmOg_y0_UXfQDS8,10772 +redis/commands/search/commands.py,sha256=dpSMZ7hXjbAlrUL4h5GX6BtP4WibQZCO6Ylfo8qkAF0,36751 +redis/commands/search/document.py,sha256=g2R-PRgq-jN33_GLXzavvse4cpIHBMfjPfPK7tnE9Gc,413 +redis/commands/search/field.py,sha256=WxtOHgtm9S82_C0nzeT7fHRrWPkGflJnSXQRIiaVJmU,4518 +redis/commands/search/indexDefinition.py,sha256=VL2CMzjxN0HEIaTn88evnHX1fCEmytbik4vAmiiYSC8,2489 +redis/commands/search/query.py,sha256=blBcgFnurT9rkg4gI6j14EekWU_J9e_aDlryVCCWDjM,11564 +redis/commands/search/querystring.py,sha256=dE577kOqkCErNgO-IXI4xFVHI8kQE-JiH5ZRI_CKjHE,7597 +redis/commands/search/reducers.py,sha256=Scceylx8BjyqS-TJOdhNW63n6tecL9ojt4U5Sqho5UY,4220 +redis/commands/search/result.py,sha256=4H7LnOVWScti7WO2XYxjhiTu3QNIt2pZHO1eptXZDBk,2149 +redis/commands/search/suggestion.py,sha256=V_re6suDCoNc0ETn_P1t51FeK4pCamPwxZRxCY8jscE,1612 +redis/commands/sentinel.py,sha256=hRcIQ9x9nEkdcCsJzo6Ves6vk-3tsfQqfJTT_v3oLY0,4110 +redis/commands/timeseries/__init__.py,sha256=gkz6wshEzzQQryBOnrAqqQzttS-AHfXmuN_H1J38EbM,3459 +redis/commands/timeseries/__pycache__/__init__.cpython-312.pyc,, +redis/commands/timeseries/__pycache__/commands.cpython-312.pyc,, +redis/commands/timeseries/__pycache__/info.cpython-312.pyc,, +redis/commands/timeseries/__pycache__/utils.cpython-312.pyc,, +redis/commands/timeseries/commands.py,sha256=bFdk-609CnL-dTqMU5yQEiY-UCjVpLknHGDENQ2t-1U,33438 +redis/commands/timeseries/info.py,sha256=5deBInBtLPb3ZrVoSB4EhWkRPkSIW5Qd_98rMDnutnk,3207 +redis/commands/timeseries/utils.py,sha256=o7q7Fe1wgpdTLKyGY8Qi2VV6XKEBprhzmPdrFz3OIvo,1309 +redis/compat.py,sha256=tr-t9oHdeosrK3TvZySaLvP3ZlGqTZQaXtlTqiqp_8I,242 +redis/connection.py,sha256=fxHl5icHS3Mk2AhHeSGxcpMcY5aeHmq5589g2XyI_xg,50524 +redis/crc.py,sha256=Z3kXFtkY2LdgefnQMud1xr4vG5UYvA9LCMqNMX1ywu4,729 
+redis/credentials.py,sha256=6VvFeReFp6vernGIWlIVOm8OmbNgoFYdd1wgsjZTnlk,738 +redis/exceptions.py,sha256=AzWeYEpVR1koUddMgvz0WZxmPX_jyksagoRf8FSSWKA,5103 +redis/lock.py,sha256=CwB_qo7ADDGSt_JqjQKSL1nKDCwdb-ASJsAlv0JO6mA,11564 +redis/ocsp.py,sha256=WwiGby6yZYR0D3lgnnQYmPKy-UAgYqGXi6A4jDBZGL4,11450 +redis/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +redis/retry.py,sha256=Ssp9s2hhDfyRs0rCRCaTgRtLR7NAYO5QMw4QflourGo,1817 +redis/sentinel.py,sha256=CErsD-c3mYFnXDttCY1OvpyUdfKcyD5F9Jv9Fd3iHuU,14175 +redis/typing.py,sha256=wjyihEjyGiJrigcs0-zhy7K-MzVy7uLidjszNdPHMug,2212 +redis/utils.py,sha256=87p7ImnihyIhiaqalVYh9Qq9JeaVwi_Y4GBzNaHAXJg,3381 diff --git a/venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/REQUESTED b/venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/REQUESTED similarity index 100% rename from venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/REQUESTED rename to venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/REQUESTED diff --git a/venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/WHEEL b/venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/WHEEL new file mode 100644 index 0000000..7e68873 --- /dev/null +++ b/venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.41.2) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/top_level.txt b/venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/top_level.txt new file mode 100644 index 0000000..7800f0f --- /dev/null +++ b/venv/lib/python3.12/site-packages/redis-5.0.1.dist-info/top_level.txt @@ -0,0 +1 @@ +redis diff --git a/venv/lib/python3.12/site-packages/redis-6.4.0.dist-info/RECORD b/venv/lib/python3.12/site-packages/redis-6.4.0.dist-info/RECORD deleted file mode 100644 index 8c32ae0..0000000 --- a/venv/lib/python3.12/site-packages/redis-6.4.0.dist-info/RECORD +++ /dev/null @@ -1,153 +0,0 @@ -redis-6.4.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -redis-6.4.0.dist-info/METADATA,sha256=bNX_u48QF0Co6COOwBo5eycG2FlBbBG8OeWnz2pO9jQ,10784 -redis-6.4.0.dist-info/RECORD,, -redis-6.4.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -redis-6.4.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87 -redis-6.4.0.dist-info/licenses/LICENSE,sha256=pXslClvwPXr-VbdAYzE_Ktt7ANVGwKsUmok5gzP-PMg,1074 -redis/__init__.py,sha256=fD_AFZRhHReFMbpmRFqSPaltxmtapfIPWyFVziJd0eI,2048 -redis/__pycache__/__init__.cpython-312.pyc,, -redis/__pycache__/backoff.cpython-312.pyc,, -redis/__pycache__/cache.cpython-312.pyc,, -redis/__pycache__/client.cpython-312.pyc,, -redis/__pycache__/cluster.cpython-312.pyc,, -redis/__pycache__/connection.cpython-312.pyc,, -redis/__pycache__/crc.cpython-312.pyc,, -redis/__pycache__/credentials.cpython-312.pyc,, -redis/__pycache__/event.cpython-312.pyc,, -redis/__pycache__/exceptions.cpython-312.pyc,, -redis/__pycache__/lock.cpython-312.pyc,, -redis/__pycache__/ocsp.cpython-312.pyc,, -redis/__pycache__/retry.cpython-312.pyc,, -redis/__pycache__/sentinel.cpython-312.pyc,, -redis/__pycache__/typing.cpython-312.pyc,, -redis/__pycache__/utils.cpython-312.pyc,, -redis/_parsers/__init__.py,sha256=gyf5dp918NuJAkWFl8sX1Z-qAvbX_40-_7YCTM6Rvjc,693 -redis/_parsers/__pycache__/__init__.cpython-312.pyc,, -redis/_parsers/__pycache__/base.cpython-312.pyc,, -redis/_parsers/__pycache__/commands.cpython-312.pyc,, -redis/_parsers/__pycache__/encoders.cpython-312.pyc,, 
-redis/_parsers/__pycache__/helpers.cpython-312.pyc,, -redis/_parsers/__pycache__/hiredis.cpython-312.pyc,, -redis/_parsers/__pycache__/resp2.cpython-312.pyc,, -redis/_parsers/__pycache__/resp3.cpython-312.pyc,, -redis/_parsers/__pycache__/socket.cpython-312.pyc,, -redis/_parsers/base.py,sha256=k6n7-oTmmzAUiiZpaB6Vfjzlj_torwBsaPBEYdOTDak,9908 -redis/_parsers/commands.py,sha256=pmR4hl4u93UvCmeDgePHFc6pWDr4slrKEvCsdMmtj_M,11052 -redis/_parsers/encoders.py,sha256=X0jvTp-E4TZUlZxV5LJJ88TuVrF1vly5tuC0xjxGaSc,1734 -redis/_parsers/helpers.py,sha256=Y6n14fE0eCYbF3TBuJxhycnJ1yHKiYoAJrOCUaiWolg,29223 -redis/_parsers/hiredis.py,sha256=iUjLT5OEgD4zqF_tg3Szmg1c_73RozXyjjAFsVYKCWM,10893 -redis/_parsers/resp2.py,sha256=f22kH-_ZP2iNtOn6xOe65MSy_fJpu8OEn1u_hgeeojI,4813 -redis/_parsers/resp3.py,sha256=tiZRbyJAnObqll2LQJ57Br-3jxwQcMocV4GQE_LpC6g,9883 -redis/_parsers/socket.py,sha256=CKD8QW_wFSNlIZzxlbNduaGpiv0I8wBcsGuAIojDfJg,5403 -redis/asyncio/__init__.py,sha256=uoDD8XYVi0Kj6mcufYwLDUTQXmBRx7a0bhKF9stZr7I,1489 -redis/asyncio/__pycache__/__init__.cpython-312.pyc,, -redis/asyncio/__pycache__/client.cpython-312.pyc,, -redis/asyncio/__pycache__/cluster.cpython-312.pyc,, -redis/asyncio/__pycache__/connection.cpython-312.pyc,, -redis/asyncio/__pycache__/lock.cpython-312.pyc,, -redis/asyncio/__pycache__/retry.cpython-312.pyc,, -redis/asyncio/__pycache__/sentinel.cpython-312.pyc,, -redis/asyncio/__pycache__/utils.cpython-312.pyc,, -redis/asyncio/client.py,sha256=6a5-txYcRMtObkb7Bfi08MKQQY01oy5NKpHAlfhIFNM,61905 -redis/asyncio/cluster.py,sha256=0nilDMyz_obavxJetO3S8fgBob8X7w4KIdfxdKftsZw,90146 -redis/asyncio/connection.py,sha256=D28OecfufSf6c2gJ8UhJhorhWMpHeFHxxIaWxvvQHoc,49197 -redis/asyncio/lock.py,sha256=GxgV6EsyKpMjh74KtaOPxh4fNPuwApz6Th46qhvrAws,12801 -redis/asyncio/retry.py,sha256=Ikm0rsvnFItracA89DdPcejLqb_Sr4QBz73Ow_LUmwU,1880 -redis/asyncio/sentinel.py,sha256=Ppk-jlTubcHpa0lvinZ1pPTtQ5rFHXZkkaCZ7G_TCQs,14868 -redis/asyncio/utils.py,sha256=31xFzXczDgSRyf6hSjiwue1eDQ_XlP_OJdp5dKxW_aE,718 -redis/auth/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -redis/auth/__pycache__/__init__.cpython-312.pyc,, -redis/auth/__pycache__/err.cpython-312.pyc,, -redis/auth/__pycache__/idp.cpython-312.pyc,, -redis/auth/__pycache__/token.cpython-312.pyc,, -redis/auth/__pycache__/token_manager.cpython-312.pyc,, -redis/auth/err.py,sha256=WYkbuDIzwp1S-eAvsya6QMlO6g9QIXbzMITOsTWX0xk,694 -redis/auth/idp.py,sha256=IMDIIb9q72vbIwtFN8vPdaAKZVTdh0HuC5uj5ufqmw4,631 -redis/auth/token.py,sha256=qYwAgxFW3S93QDUqp1BTsj7Pj9ZohnixGeOX0s7AsjY,3317 -redis/auth/token_manager.py,sha256=ShBsYXiBZBJBOMB_Y-pXfLwEOAmc9s1okaCECinNZ7g,12018 -redis/backoff.py,sha256=tQM6Lh2g2FjMH8iXg94br2sU9eri4mEW9FbOrMt0azs,5285 -redis/cache.py,sha256=68rJDNogvNwgdgBel6zSX9QziL11qsKIMhmvQvHvznM,9549 -redis/client.py,sha256=Xmo6va8oKg7ksD8tv5-EErCFq3OhpfeISuR-nWBIRSA,62463 -redis/cluster.py,sha256=CgKGFnprziYjsr--qWbhY--2oaaWQRbuKofi1Qr9m5c,124120 -redis/commands/__init__.py,sha256=cTUH-MGvaLYS0WuoytyqtN1wniw2A1KbkUXcpvOSY3I,576 -redis/commands/__pycache__/__init__.cpython-312.pyc,, -redis/commands/__pycache__/cluster.cpython-312.pyc,, -redis/commands/__pycache__/core.cpython-312.pyc,, -redis/commands/__pycache__/helpers.cpython-312.pyc,, -redis/commands/__pycache__/redismodules.cpython-312.pyc,, -redis/commands/__pycache__/sentinel.cpython-312.pyc,, -redis/commands/bf/__init__.py,sha256=qk4DA9KsMiP4WYqYeP1T5ScBwctsVtlLyMhrYIyq1Zc,8019 -redis/commands/bf/__pycache__/__init__.cpython-312.pyc,, -redis/commands/bf/__pycache__/commands.cpython-312.pyc,, 
-redis/commands/bf/__pycache__/info.cpython-312.pyc,, -redis/commands/bf/commands.py,sha256=xeKt8E7G8HB-l922J0DLg07CEIZTVNGx_2Lfyw1gIck,21283 -redis/commands/bf/info.py,sha256=_OB2v_hAPI9mdVNiBx8jUtH2MhMoct9ZRm-e8In6wQo,3355 -redis/commands/cluster.py,sha256=vdWdpl4mP51oqfYBZHg5CUXt6jPaNp7aCLHyTieDrt8,31248 -redis/commands/core.py,sha256=RjVbTxe_vfnraVOqREH6ofNU2LMX8-ZGSAzd5g3ypvE,241132 -redis/commands/helpers.py,sha256=VCoPdBMCr4wxdWBw1EB9R7ZBbQM0exAG1kws4XwsCII,3318 -redis/commands/json/__init__.py,sha256=bznXhLYR652rfLfLp8cz0ZN0Yr8IRx4FgON_tq9_2Io,4845 -redis/commands/json/__pycache__/__init__.cpython-312.pyc,, -redis/commands/json/__pycache__/_util.cpython-312.pyc,, -redis/commands/json/__pycache__/commands.cpython-312.pyc,, -redis/commands/json/__pycache__/decoders.cpython-312.pyc,, -redis/commands/json/__pycache__/path.cpython-312.pyc,, -redis/commands/json/_util.py,sha256=hIBQ1TLCTgUifcLsg0x8kJlecxmXhA9I0zMnHlQk0Ho,137 -redis/commands/json/commands.py,sha256=ih8upnxeOpjPZXNfqeFBYxiCN2Cmyv8UGu3AlQnT6JQ,15723 -redis/commands/json/decoders.py,sha256=a_IoMV_wgeJyUifD4P6HTcM9s6FhricwmzQcZRmc-Gw,1411 -redis/commands/json/path.py,sha256=0zaO6_q_FVMk1Bkhkb7Wcr8AF2Tfr69VhkKy1IBVhpA,393 -redis/commands/redismodules.py,sha256=-kLM4RBklDhNh-MXCra81ZTSstIQ-ulRab6v0dYUTdA,2573 -redis/commands/search/__init__.py,sha256=happQFVF0j7P87p7LQsUK5AK0kuem9cA-xvVRdQWpos,5744 -redis/commands/search/__pycache__/__init__.cpython-312.pyc,, -redis/commands/search/__pycache__/_util.cpython-312.pyc,, -redis/commands/search/__pycache__/aggregation.cpython-312.pyc,, -redis/commands/search/__pycache__/commands.cpython-312.pyc,, -redis/commands/search/__pycache__/dialect.cpython-312.pyc,, -redis/commands/search/__pycache__/document.cpython-312.pyc,, -redis/commands/search/__pycache__/field.cpython-312.pyc,, -redis/commands/search/__pycache__/index_definition.cpython-312.pyc,, -redis/commands/search/__pycache__/profile_information.cpython-312.pyc,, -redis/commands/search/__pycache__/query.cpython-312.pyc,, -redis/commands/search/__pycache__/querystring.cpython-312.pyc,, -redis/commands/search/__pycache__/reducers.cpython-312.pyc,, -redis/commands/search/__pycache__/result.cpython-312.pyc,, -redis/commands/search/__pycache__/suggestion.cpython-312.pyc,, -redis/commands/search/_util.py,sha256=9Mp72OO5Ib5UbfN7uXb-iB7hQCm1jQLV90ms2P9XSGU,219 -redis/commands/search/aggregation.py,sha256=R2ul26mH10dQxUdQNKqH-Os1thOz88m4taTK08khiZc,11564 -redis/commands/search/commands.py,sha256=4lnL7MXsp9XqMyUgPxJ9S6p8BRnsIrjXuwvSTL9qo3E,38436 -redis/commands/search/dialect.py,sha256=-7M6kkr33x0FkMtKmUsbeRAE6qxLUbqdJCqIo0UKIXo,105 -redis/commands/search/document.py,sha256=g2R-PRgq-jN33_GLXzavvse4cpIHBMfjPfPK7tnE9Gc,413 -redis/commands/search/field.py,sha256=g9I1LHrVJKO1KtiUwotxrQvpg89e-sx26oClHuaKTn8,5935 -redis/commands/search/index_definition.py,sha256=VL2CMzjxN0HEIaTn88evnHX1fCEmytbik4vAmiiYSC8,2489 -redis/commands/search/profile_information.py,sha256=w9SbMiHbcZ1TpsZMe8cMIyO1hGkm5GhnZ_Gqg1feLtc,249 -redis/commands/search/query.py,sha256=MbSs-cY7hG1OEkO-i6LJ_Ui1D3d2VyDTXPrmb-rty7w,12199 -redis/commands/search/querystring.py,sha256=dE577kOqkCErNgO-IXI4xFVHI8kQE-JiH5ZRI_CKjHE,7597 -redis/commands/search/reducers.py,sha256=Scceylx8BjyqS-TJOdhNW63n6tecL9ojt4U5Sqho5UY,4220 -redis/commands/search/result.py,sha256=iuqmwOeCNo_7N4a_YxxDzVdOTpbwfF1T2uuq5sTqzMo,2624 -redis/commands/search/suggestion.py,sha256=V_re6suDCoNc0ETn_P1t51FeK4pCamPwxZRxCY8jscE,1612 -redis/commands/sentinel.py,sha256=Q1Xuw7qXA0YRZXGlIKsuOtah8UfF0QnkLywOTRvjiMY,5299 
__version__ = "99.99.99" +try: + VERSION = tuple(map(int_or_str, __version__.split("."))) +except AttributeError: + VERSION = tuple([99, 99, 99]) + __all__ = [ "AuthenticationError", "AuthenticationWrongNumberOfArgsError", @@ -60,19 +70,15 @@ __all__ = [ "ConnectionError", "ConnectionPool", "CredentialProvider", - "CrossSlotTransactionError", "DataError", "from_url", "default_backoff", - "InvalidPipelineStack", "InvalidResponse", - "MaxConnectionsError", "OutOfMemoryError", "PubSubError", "ReadOnlyError", "Redis", "RedisCluster", - "RedisClusterException", "RedisError", "ResponseError", "Sentinel", diff --git a/venv/lib/python3.12/site-packages/redis/_parsers/__init__.py b/venv/lib/python3.12/site-packages/redis/_parsers/__init__.py index 30cb1cd..6cc32e3 100644 --- a/venv/lib/python3.12/site-packages/redis/_parsers/__init__.py +++ b/venv/lib/python3.12/site-packages/redis/_parsers/__init__.py @@ -1,9 +1,4 @@ -from .base import ( - AsyncPushNotificationsParser, - BaseParser, - PushNotificationsParser, - _AsyncRESPBase, -) +from .base import BaseParser, _AsyncRESPBase from .commands import AsyncCommandsParser, CommandsParser from .encoders import Encoder from .hiredis import _AsyncHiredisParser, _HiredisParser @@ -16,12 +11,10 @@ __all__ = [ "_AsyncRESPBase", "_AsyncRESP2Parser", "_AsyncRESP3Parser", - "AsyncPushNotificationsParser", "CommandsParser", "Encoder", "BaseParser", "_HiredisParser", "_RESP2Parser", "_RESP3Parser", - "PushNotificationsParser", ] diff --git a/venv/lib/python3.12/site-packages/redis/_parsers/base.py b/venv/lib/python3.12/site-packages/redis/_parsers/base.py index 69d7b58..8e59249 100644 --- a/venv/lib/python3.12/site-packages/redis/_parsers/base.py +++ b/venv/lib/python3.12/site-packages/redis/_parsers/base.py @@ -1,7 +1,7 @@ import sys from abc import ABC from asyncio import IncompleteReadError, StreamReader, TimeoutError -from typing import Callable, List, Optional, Protocol, Union +from typing import List, Optional, Union if sys.version_info.major >= 3 and sys.version_info.minor >= 11: from asyncio import timeout as async_timeout @@ -9,32 +9,26 @@ else: from async_timeout import timeout as async_timeout from ..exceptions import ( - AskError, AuthenticationError, AuthenticationWrongNumberOfArgsError, BusyLoadingError, - ClusterCrossSlotError, - ClusterDownError, ConnectionError, ExecAbortError, - MasterDownError, ModuleError, - MovedError, NoPermissionError, NoScriptError, OutOfMemoryError, ReadOnlyError, RedisError, ResponseError, - TryAgainError, ) from ..typing import EncodableT from .encoders import Encoder from .socket import SERVER_CLOSED_CONNECTION_ERROR, SocketBuffer -MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." +MODULE_LOAD_ERROR = "Error loading the extension. " "Please check the server logs." NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" -MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." +MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not " "possible." 
MODULE_EXPORTS_DATA_TYPES_ERROR = ( "Error unloading module: the module " "exports one or more module-side data " @@ -78,12 +72,6 @@ class BaseParser(ABC): "READONLY": ReadOnlyError, "NOAUTH": AuthenticationError, "NOPERM": NoPermissionError, - "ASK": AskError, - "TRYAGAIN": TryAgainError, - "MOVED": MovedError, - "CLUSTERDOWN": ClusterDownError, - "CROSSSLOT": ClusterCrossSlotError, - "MASTERDOWN": MasterDownError, } @classmethod @@ -158,58 +146,6 @@ class AsyncBaseParser(BaseParser): raise NotImplementedError() -_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"] - - -class PushNotificationsParser(Protocol): - """Protocol defining RESP3-specific parsing functionality""" - - pubsub_push_handler_func: Callable - invalidation_push_handler_func: Optional[Callable] = None - - def handle_pubsub_push_response(self, response): - """Handle pubsub push responses""" - raise NotImplementedError() - - def handle_push_response(self, response, **kwargs): - if response[0] not in _INVALIDATION_MESSAGE: - return self.pubsub_push_handler_func(response) - if self.invalidation_push_handler_func: - return self.invalidation_push_handler_func(response) - - def set_pubsub_push_handler(self, pubsub_push_handler_func): - self.pubsub_push_handler_func = pubsub_push_handler_func - - def set_invalidation_push_handler(self, invalidation_push_handler_func): - self.invalidation_push_handler_func = invalidation_push_handler_func - - -class AsyncPushNotificationsParser(Protocol): - """Protocol defining async RESP3-specific parsing functionality""" - - pubsub_push_handler_func: Callable - invalidation_push_handler_func: Optional[Callable] = None - - async def handle_pubsub_push_response(self, response): - """Handle pubsub push responses asynchronously""" - raise NotImplementedError() - - async def handle_push_response(self, response, **kwargs): - """Handle push responses asynchronously""" - if response[0] not in _INVALIDATION_MESSAGE: - return await self.pubsub_push_handler_func(response) - if self.invalidation_push_handler_func: - return await self.invalidation_push_handler_func(response) - - def set_pubsub_push_handler(self, pubsub_push_handler_func): - """Set the pubsub push handler function""" - self.pubsub_push_handler_func = pubsub_push_handler_func - - def set_invalidation_push_handler(self, invalidation_push_handler_func): - """Set the invalidation push handler function""" - self.invalidation_push_handler_func = invalidation_push_handler_func - - class _AsyncRESPBase(AsyncBaseParser): """Base class for async resp parsing""" @@ -246,7 +182,7 @@ class _AsyncRESPBase(AsyncBaseParser): return True try: async with async_timeout(0): - return self._stream.at_eof() + return await self._stream.read(1) except TimeoutError: return False diff --git a/venv/lib/python3.12/site-packages/redis/_parsers/helpers.py b/venv/lib/python3.12/site-packages/redis/_parsers/helpers.py index 154dc66..fb5da83 100644 --- a/venv/lib/python3.12/site-packages/redis/_parsers/helpers.py +++ b/venv/lib/python3.12/site-packages/redis/_parsers/helpers.py @@ -38,7 +38,7 @@ def parse_info(response): response = str_if_bytes(response) def get_value(value): - if "," not in value and "=" not in value: + if "," not in value or "=" not in value: try: if "." 
in value: return float(value) @@ -46,18 +46,11 @@ def parse_info(response): return int(value) except ValueError: return value - elif "=" not in value: - return [get_value(v) for v in value.split(",") if v] else: sub_dict = {} for item in value.split(","): - if not item: - continue - if "=" in item: - k, v = item.rsplit("=", 1) - sub_dict[k] = get_value(v) - else: - sub_dict[item] = True + k, v = item.rsplit("=", 1) + sub_dict[k] = get_value(v) return sub_dict for line in response.splitlines(): @@ -87,7 +80,7 @@ def parse_memory_stats(response, **kwargs): """Parse the results of MEMORY STATS""" stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True) for key, value in stats.items(): - if key.startswith("db.") and isinstance(value, list): + if key.startswith("db."): stats[key] = pairs_to_dict( value, decode_keys=True, decode_string_values=True ) @@ -275,22 +268,17 @@ def parse_xinfo_stream(response, **options): data = {str_if_bytes(k): v for k, v in response.items()} if not options.get("full", False): first = data.get("first-entry") - if first is not None and first[0] is not None: + if first is not None: data["first-entry"] = (first[0], pairs_to_dict(first[1])) last = data["last-entry"] - if last is not None and last[0] is not None: + if last is not None: data["last-entry"] = (last[0], pairs_to_dict(last[1])) else: data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]} - if len(data["groups"]) > 0 and isinstance(data["groups"][0], list): + if isinstance(data["groups"][0], list): data["groups"] = [ pairs_to_dict(group, decode_keys=True) for group in data["groups"] ] - for g in data["groups"]: - if g["consumers"] and g["consumers"][0] is not None: - g["consumers"] = [ - pairs_to_dict(c, decode_keys=True) for c in g["consumers"] - ] else: data["groups"] = [ {str_if_bytes(k): v for k, v in group.items()} @@ -334,7 +322,7 @@ def float_or_none(response): return float(response) -def bool_ok(response, **options): +def bool_ok(response): return str_if_bytes(response) == "OK" @@ -366,12 +354,7 @@ def parse_scan(response, **options): def parse_hscan(response, **options): cursor, r = response - no_values = options.get("no_values", False) - if no_values: - payload = r or [] - else: - payload = r and pairs_to_dict(r) or {} - return int(cursor), payload + return int(cursor), r and pairs_to_dict(r) or {} def parse_zscan(response, **options): @@ -396,20 +379,13 @@ def parse_slowlog_get(response, **options): # an O(N) complexity) instead of the command. if isinstance(item[3], list): result["command"] = space.join(item[3]) - - # These fields are optional, depends on environment. - if len(item) >= 6: - result["client_address"] = item[4] - result["client_name"] = item[5] + result["client_address"] = item[4] + result["client_name"] = item[5] else: result["complexity"] = item[3] result["command"] = space.join(item[4]) - - # These fields are optional, depends on environment. 
- if len(item) >= 7: - result["client_address"] = item[5] - result["client_name"] = item[6] - + result["client_address"] = item[5] + result["client_name"] = item[6] return result return [parse_item(item) for item in response] @@ -452,11 +428,9 @@ def parse_cluster_info(response, **options): def _parse_node_line(line): line_items = line.split(" ") node_id, addr, flags, master_id, ping, pong, epoch, connected = line.split(" ")[:8] - ip = addr.split("@")[0] - hostname = addr.split("@")[1].split(",")[1] if "@" in addr and "," in addr else "" + addr = addr.split("@")[0] node_dict = { "node_id": node_id, - "hostname": hostname, "flags": flags, "master_id": master_id, "last_ping_sent": ping, @@ -469,7 +443,7 @@ def _parse_node_line(line): if len(line_items) >= 9: slots, migrations = _parse_slots(line_items[8:]) node_dict["slots"], node_dict["migrations"] = slots, migrations - return ip, node_dict + return addr, node_dict def _parse_slots(slot_ranges): @@ -516,7 +490,7 @@ def parse_geosearch_generic(response, **options): except KeyError: # it means the command was sent via execute_command return response - if not isinstance(response, list): + if type(response) != list: response_list = [response] else: response_list = response @@ -676,8 +650,7 @@ def parse_client_info(value): "omem", "tot-mem", }: - if int_key in client_info: - client_info[int_key] = int(client_info[int_key]) + client_info[int_key] = int(client_info[int_key]) return client_info @@ -840,28 +813,24 @@ _RedisCallbacksRESP2 = { _RedisCallbacksRESP3 = { - **string_keys_to_dict( - "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() - ), **string_keys_to_dict( "ZRANGE ZINTER ZPOPMAX ZPOPMIN ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE " "ZUNION HGETALL XREADGROUP", lambda r, **kwargs: r, ), **string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3), - "ACL LOG": lambda r: ( - [ - {str_if_bytes(key): str_if_bytes(value) for key, value in x.items()} - for x in r - ] - if isinstance(r, list) - else bool_ok(r) - ), + "ACL LOG": lambda r: [ + {str_if_bytes(key): str_if_bytes(value) for key, value in x.items()} for x in r + ] + if isinstance(r, list) + else bool_ok(r), "COMMAND": parse_command_resp3, "CONFIG GET": lambda r: { - str_if_bytes(key) if key is not None else None: ( - str_if_bytes(value) if value is not None else None - ) + str_if_bytes(key) + if key is not None + else None: str_if_bytes(value) + if value is not None + else None for key, value in r.items() }, "MEMORY STATS": lambda r: {str_if_bytes(key): value for key, value in r.items()}, @@ -869,11 +838,11 @@ _RedisCallbacksRESP3 = { "SENTINEL MASTERS": parse_sentinel_masters_resp3, "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels_resp3, "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels_resp3, - "STRALGO": lambda r, **options: ( - {str_if_bytes(key): str_if_bytes(value) for key, value in r.items()} - if isinstance(r, dict) - else str_if_bytes(r) - ), + "STRALGO": lambda r, **options: { + str_if_bytes(key): str_if_bytes(value) for key, value in r.items() + } + if isinstance(r, dict) + else str_if_bytes(r), "XINFO CONSUMERS": lambda r: [ {str_if_bytes(key): value for key, value in x.items()} for x in r ], diff --git a/venv/lib/python3.12/site-packages/redis/_parsers/hiredis.py b/venv/lib/python3.12/site-packages/redis/_parsers/hiredis.py index 521a58b..b3247b7 100644 --- a/venv/lib/python3.12/site-packages/redis/_parsers/hiredis.py +++ b/venv/lib/python3.12/site-packages/redis/_parsers/hiredis.py @@ -1,23 +1,19 @@ import asyncio import socket import sys -from 
logging import getLogger -from typing import Callable, List, Optional, TypedDict, Union +from typing import Callable, List, Optional, Union if sys.version_info.major >= 3 and sys.version_info.minor >= 11: from asyncio import timeout as async_timeout else: from async_timeout import timeout as async_timeout +from redis.compat import TypedDict + from ..exceptions import ConnectionError, InvalidResponse, RedisError from ..typing import EncodableT from ..utils import HIREDIS_AVAILABLE -from .base import ( - AsyncBaseParser, - AsyncPushNotificationsParser, - BaseParser, - PushNotificationsParser, -) +from .base import AsyncBaseParser, BaseParser from .socket import ( NONBLOCKING_EXCEPTION_ERROR_NUMBERS, NONBLOCKING_EXCEPTIONS, @@ -25,11 +21,6 @@ from .socket import ( SERVER_CLOSED_CONNECTION_ERROR, ) -# Used to signal that hiredis-py does not have enough data to parse. -# Using `False` or `None` is not reliable, given that the parser can -# return `False` or `None` for legitimate reasons from RESP payloads. -NOT_ENOUGH_DATA = object() - class _HiredisReaderArgs(TypedDict, total=False): protocolError: Callable[[str], Exception] @@ -38,7 +29,7 @@ class _HiredisReaderArgs(TypedDict, total=False): errors: Optional[str] -class _HiredisParser(BaseParser, PushNotificationsParser): +class _HiredisParser(BaseParser): "Parser class for connections using Hiredis" def __init__(self, socket_read_size): @@ -46,9 +37,6 @@ class _HiredisParser(BaseParser, PushNotificationsParser): raise RedisError("Hiredis is not installed") self.socket_read_size = socket_read_size self._buffer = bytearray(socket_read_size) - self.pubsub_push_handler_func = self.handle_pubsub_push_response - self.invalidation_push_handler_func = None - self._hiredis_PushNotificationType = None def __del__(self): try: @@ -56,11 +44,6 @@ class _HiredisParser(BaseParser, PushNotificationsParser): except Exception: pass - def handle_pubsub_push_response(self, response): - logger = getLogger("push_response") - logger.debug("Push response: " + str(response)) - return response - def on_connect(self, connection, **kwargs): import hiredis @@ -70,32 +53,25 @@ class _HiredisParser(BaseParser, PushNotificationsParser): "protocolError": InvalidResponse, "replyError": self.parse_error, "errors": connection.encoder.encoding_errors, - "notEnoughData": NOT_ENOUGH_DATA, } if connection.encoder.decode_responses: kwargs["encoding"] = connection.encoder.encoding self._reader = hiredis.Reader(**kwargs) - self._next_response = NOT_ENOUGH_DATA - - try: - self._hiredis_PushNotificationType = hiredis.PushNotification - except AttributeError: - # hiredis < 3.2 - self._hiredis_PushNotificationType = None + self._next_response = False def on_disconnect(self): self._sock = None self._reader = None - self._next_response = NOT_ENOUGH_DATA + self._next_response = False def can_read(self, timeout): if not self._reader: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - if self._next_response is NOT_ENOUGH_DATA: + if self._next_response is False: self._next_response = self._reader.gets() - if self._next_response is NOT_ENOUGH_DATA: + if self._next_response is False: return self.read_from_socket(timeout=timeout, raise_on_timeout=False) return True @@ -129,24 +105,14 @@ class _HiredisParser(BaseParser, PushNotificationsParser): if custom_timeout: sock.settimeout(self._socket_timeout) - def read_response(self, disable_decoding=False, push_request=False): + def read_response(self, disable_decoding=False): if not self._reader: raise 
ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) # _next_response might be cached from a can_read() call - if self._next_response is not NOT_ENOUGH_DATA: + if self._next_response is not False: response = self._next_response - self._next_response = NOT_ENOUGH_DATA - if self._hiredis_PushNotificationType is not None and isinstance( - response, self._hiredis_PushNotificationType - ): - response = self.handle_push_response(response) - if not push_request: - return self.read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: - return response + self._next_response = False return response if disable_decoding: @@ -154,7 +120,7 @@ class _HiredisParser(BaseParser, PushNotificationsParser): else: response = self._reader.gets() - while response is NOT_ENOUGH_DATA: + while response is False: self.read_from_socket() if disable_decoding: response = self._reader.gets(False) @@ -165,16 +131,6 @@ class _HiredisParser(BaseParser, PushNotificationsParser): # happened if isinstance(response, ConnectionError): raise response - elif self._hiredis_PushNotificationType is not None and isinstance( - response, self._hiredis_PushNotificationType - ): - response = self.handle_push_response(response) - if not push_request: - return self.read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: - return response elif ( isinstance(response, list) and response @@ -184,7 +140,7 @@ class _HiredisParser(BaseParser, PushNotificationsParser): return response -class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser): +class _AsyncHiredisParser(AsyncBaseParser): """Async implementation of parser class for connections using Hiredis""" __slots__ = ("_reader",) @@ -194,14 +150,6 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser): raise RedisError("Hiredis is not available.") super().__init__(socket_read_size=socket_read_size) self._reader = None - self.pubsub_push_handler_func = self.handle_pubsub_push_response - self.invalidation_push_handler_func = None - self._hiredis_PushNotificationType = None - - async def handle_pubsub_push_response(self, response): - logger = getLogger("push_response") - logger.debug("Push response: " + str(response)) - return response def on_connect(self, connection): import hiredis @@ -210,7 +158,6 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser): kwargs: _HiredisReaderArgs = { "protocolError": InvalidResponse, "replyError": self.parse_error, - "notEnoughData": NOT_ENOUGH_DATA, } if connection.encoder.decode_responses: kwargs["encoding"] = connection.encoder.encoding @@ -219,21 +166,13 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser): self._reader = hiredis.Reader(**kwargs) self._connected = True - try: - self._hiredis_PushNotificationType = getattr( - hiredis, "PushNotification", None - ) - except AttributeError: - # hiredis < 3.2 - self._hiredis_PushNotificationType = None - def on_disconnect(self): self._connected = False async def can_read_destructive(self): if not self._connected: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) - if self._reader.gets() is not NOT_ENOUGH_DATA: + if self._reader.gets(): return True try: async with async_timeout(0): @@ -251,7 +190,7 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser): return True async def read_response( - self, disable_decoding: bool = False, push_request: bool = False + self, disable_decoding: bool = False ) -> Union[EncodableT, List[EncodableT]]: # If 
`on_disconnect()` has been called, prohibit any more reads # even if they could happen because data might be present. @@ -259,33 +198,16 @@ class _AsyncHiredisParser(AsyncBaseParser, AsyncPushNotificationsParser): if not self._connected: raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None - if disable_decoding: - response = self._reader.gets(False) - else: - response = self._reader.gets() - - while response is NOT_ENOUGH_DATA: + response = self._reader.gets() + while response is False: await self.read_from_socket() - if disable_decoding: - response = self._reader.gets(False) - else: - response = self._reader.gets() + response = self._reader.gets() # if the response is a ConnectionError or the response is a list and # the first item is a ConnectionError, raise it as something bad # happened if isinstance(response, ConnectionError): raise response - elif self._hiredis_PushNotificationType is not None and isinstance( - response, self._hiredis_PushNotificationType - ): - response = await self.handle_push_response(response) - if not push_request: - return await self.read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: - return response elif ( isinstance(response, list) and response diff --git a/venv/lib/python3.12/site-packages/redis/_parsers/resp3.py b/venv/lib/python3.12/site-packages/redis/_parsers/resp3.py index 42c6652..ad766a8 100644 --- a/venv/lib/python3.12/site-packages/redis/_parsers/resp3.py +++ b/venv/lib/python3.12/site-packages/redis/_parsers/resp3.py @@ -3,26 +3,20 @@ from typing import Any, Union from ..exceptions import ConnectionError, InvalidResponse, ResponseError from ..typing import EncodableT -from .base import ( - AsyncPushNotificationsParser, - PushNotificationsParser, - _AsyncRESPBase, - _RESPBase, -) +from .base import _AsyncRESPBase, _RESPBase from .socket import SERVER_CLOSED_CONNECTION_ERROR -class _RESP3Parser(_RESPBase, PushNotificationsParser): +class _RESP3Parser(_RESPBase): """RESP3 protocol implementation""" def __init__(self, socket_read_size): super().__init__(socket_read_size) - self.pubsub_push_handler_func = self.handle_pubsub_push_response - self.invalidation_push_handler_func = None + self.push_handler_func = self.handle_push_response - def handle_pubsub_push_response(self, response): + def handle_push_response(self, response): logger = getLogger("push_response") - logger.debug("Push response: " + str(response)) + logger.info("Push response: " + str(response)) return response def read_response(self, disable_decoding=False, push_request=False): @@ -91,16 +85,19 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser): # set response elif byte == b"~": # redis can return unhashable types (like dict) in a set, - # so we return sets as list, all the time, for predictability + # so we need to first convert to a list, and then try to convert it to a set response = [ self._read_response(disable_decoding=disable_decoding) for _ in range(int(response)) ] + try: + response = set(response) + except TypeError: + pass # map response elif byte == b"%": - # We cannot use a dict-comprehension to parse stream. 
- # Evaluation order of key:val expression in dict comprehension only - # became defined to be left-right in version 3.8 + # we use this approach and not dict comprehension here + # because this dict comprehension fails in python 3.7 resp_dict = {} for _ in range(int(response)): key = self._read_response(disable_decoding=disable_decoding) @@ -116,13 +113,13 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser): ) for _ in range(int(response)) ] - response = self.handle_push_response(response) + res = self.push_handler_func(response) if not push_request: return self._read_response( disable_decoding=disable_decoding, push_request=push_request ) else: - return response + return res else: raise InvalidResponse(f"Protocol Error: {raw!r}") @@ -130,16 +127,18 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser): response = self.encoder.decode(response) return response + def set_push_handler(self, push_handler_func): + self.push_handler_func = push_handler_func -class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser): + +class _AsyncRESP3Parser(_AsyncRESPBase): def __init__(self, socket_read_size): super().__init__(socket_read_size) - self.pubsub_push_handler_func = self.handle_pubsub_push_response - self.invalidation_push_handler_func = None + self.push_handler_func = self.handle_push_response - async def handle_pubsub_push_response(self, response): + def handle_push_response(self, response): logger = getLogger("push_response") - logger.debug("Push response: " + str(response)) + logger.info("Push response: " + str(response)) return response async def read_response( @@ -215,23 +214,23 @@ class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser): # set response elif byte == b"~": # redis can return unhashable types (like dict) in a set, - # so we always convert to a list, to have predictable return types + # so we need to first convert to a list, and then try to convert it to a set response = [ (await self._read_response(disable_decoding=disable_decoding)) for _ in range(int(response)) ] + try: + response = set(response) + except TypeError: + pass # map response elif byte == b"%": - # We cannot use a dict-comprehension to parse stream. 
- # Evaluation order of key:val expression in dict comprehension only - # became defined to be left-right in version 3.8 - resp_dict = {} - for _ in range(int(response)): - key = await self._read_response(disable_decoding=disable_decoding) - resp_dict[key] = await self._read_response( - disable_decoding=disable_decoding, push_request=push_request + response = { + (await self._read_response(disable_decoding=disable_decoding)): ( + await self._read_response(disable_decoding=disable_decoding) ) - response = resp_dict + for _ in range(int(response)) + } # push response elif byte == b">": response = [ @@ -242,16 +241,19 @@ class _AsyncRESP3Parser(_AsyncRESPBase, AsyncPushNotificationsParser): ) for _ in range(int(response)) ] - response = await self.handle_push_response(response) + res = self.push_handler_func(response) if not push_request: return await self._read_response( disable_decoding=disable_decoding, push_request=push_request ) else: - return response + return res else: raise InvalidResponse(f"Protocol Error: {raw!r}") if isinstance(response, bytes) and disable_decoding is False: response = self.encoder.decode(response) return response + + def set_push_handler(self, push_handler_func): + self.push_handler_func = push_handler_func diff --git a/venv/lib/python3.12/site-packages/redis/asyncio/client.py b/venv/lib/python3.12/site-packages/redis/asyncio/client.py index aac4090..e4d2e77 100644 --- a/venv/lib/python3.12/site-packages/redis/asyncio/client.py +++ b/venv/lib/python3.12/site-packages/redis/asyncio/client.py @@ -15,11 +15,9 @@ from typing import ( Mapping, MutableMapping, Optional, - Protocol, Set, Tuple, Type, - TypedDict, TypeVar, Union, cast, @@ -39,7 +37,6 @@ from redis.asyncio.connection import ( ) from redis.asyncio.lock import Lock from redis.asyncio.retry import Retry -from redis.backoff import ExponentialWithJitterBackoff from redis.client import ( EMPTY_RESPONSE, NEVER_DECODE, @@ -52,40 +49,27 @@ from redis.commands import ( AsyncSentinelCommands, list_or_args, ) +from redis.compat import Protocol, TypedDict from redis.credentials import CredentialProvider -from redis.event import ( - AfterPooledConnectionsInstantiationEvent, - AfterPubSubConnectionInstantiationEvent, - AfterSingleConnectionInstantiationEvent, - ClientType, - EventDispatcher, -) from redis.exceptions import ( ConnectionError, ExecAbortError, PubSubError, RedisError, ResponseError, + TimeoutError, WatchError, ) from redis.typing import ChannelT, EncodableT, KeyT from redis.utils import ( - SSL_AVAILABLE, + HIREDIS_AVAILABLE, _set_info_logger, - deprecated_args, deprecated_function, get_lib_version, safe_str, str_if_bytes, - truncate_text, ) -if TYPE_CHECKING and SSL_AVAILABLE: - from ssl import TLSVersion, VerifyMode -else: - TLSVersion = None - VerifyMode = None - PubSubHandler = Callable[[Dict[str, str]], Awaitable[None]] _KeyT = TypeVar("_KeyT", bound=KeyT) _ArgT = TypeVar("_ArgT", KeyT, EncodableT) @@ -96,11 +80,13 @@ if TYPE_CHECKING: class ResponseCallbackProtocol(Protocol): - def __call__(self, response: Any, **kwargs): ... + def __call__(self, response: Any, **kwargs): + ... class AsyncResponseCallbackProtocol(Protocol): - async def __call__(self, response: Any, **kwargs): ... + async def __call__(self, response: Any, **kwargs): + ... ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol] @@ -182,7 +168,7 @@ class Redis( warnings.warn( DeprecationWarning( '"auto_close_connection_pool" is deprecated ' - "since version 5.0.1. " + "since version 5.0.0. 
" "Please create a ConnectionPool explicitly and " "provide to the Redis() constructor instead." ) @@ -208,11 +194,6 @@ class Redis( client.auto_close_connection_pool = True return client - @deprecated_args( - args_to_warn=["retry_on_timeout"], - reason="TimeoutError is included by default.", - version="6.0.0", - ) def __init__( self, *, @@ -230,19 +211,14 @@ class Redis( encoding_errors: str = "strict", decode_responses: bool = False, retry_on_timeout: bool = False, - retry: Retry = Retry( - backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 - ), retry_on_error: Optional[list] = None, ssl: bool = False, ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None, - ssl_cert_reqs: Union[str, VerifyMode] = "required", + ssl_cert_reqs: str = "required", ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, - ssl_check_hostname: bool = True, - ssl_min_version: Optional[TLSVersion] = None, - ssl_ciphers: Optional[str] = None, + ssl_check_hostname: bool = False, max_connections: Optional[int] = None, single_connection_client: bool = False, health_check_interval: int = 0, @@ -250,38 +226,20 @@ class Redis( lib_name: Optional[str] = "redis-py", lib_version: Optional[str] = get_lib_version(), username: Optional[str] = None, + retry: Optional[Retry] = None, auto_close_connection_pool: Optional[bool] = None, redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - event_dispatcher: Optional[EventDispatcher] = None, ): """ Initialize a new Redis client. - - To specify a retry policy for specific errors, you have two options: - - 1. Set the `retry_on_error` to a list of the error/s to retry on, and - you can also set `retry` to a valid `Retry` object(in case the default - one is not appropriate) - with this approach the retries will be triggered - on the default errors specified in the Retry object enriched with the - errors specified in `retry_on_error`. - - 2. Define a `Retry` object with configured 'supported_errors' and set - it to the `retry` parameter - with this approach you completely redefine - the errors on which retries will happen. - - `retry_on_timeout` is deprecated - please include the TimeoutError - either in the Retry object or in the `retry_on_error` list. - - When 'connection_pool' is provided - the retry configuration of the - provided pool will be used. + To specify a retry policy for specific errors, first set + `retry_on_error` to a list of the error/s to retry on, then set + `retry` to a valid `Retry` object. + To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. """ kwargs: Dict[str, Any] - if event_dispatcher is None: - self._event_dispatcher = EventDispatcher() - else: - self._event_dispatcher = event_dispatcher # auto_close_connection_pool only has an effect if connection_pool is # None. It is assumed that if connection_pool is not None, the user # wants to manage the connection pool themselves. @@ -289,7 +247,7 @@ class Redis( warnings.warn( DeprecationWarning( '"auto_close_connection_pool" is deprecated ' - "since version 5.0.1. " + "since version 5.0.0. " "Please create a ConnectionPool explicitly and " "provide to the Redis() constructor instead." 
) @@ -301,6 +259,8 @@ class Redis( # Create internal connection pool, expected to be closed by Redis instance if not retry_on_error: retry_on_error = [] + if retry_on_timeout is True: + retry_on_error.append(TimeoutError) kwargs = { "db": db, "username": username, @@ -310,6 +270,7 @@ class Redis( "encoding": encoding, "encoding_errors": encoding_errors, "decode_responses": decode_responses, + "retry_on_timeout": retry_on_timeout, "retry_on_error": retry_on_error, "retry": copy.deepcopy(retry), "max_connections": max_connections, @@ -350,26 +311,14 @@ class Redis( "ssl_ca_certs": ssl_ca_certs, "ssl_ca_data": ssl_ca_data, "ssl_check_hostname": ssl_check_hostname, - "ssl_min_version": ssl_min_version, - "ssl_ciphers": ssl_ciphers, } ) # This arg only used if no pool is passed in self.auto_close_connection_pool = auto_close_connection_pool connection_pool = ConnectionPool(**kwargs) - self._event_dispatcher.dispatch( - AfterPooledConnectionsInstantiationEvent( - [connection_pool], ClientType.ASYNC, credential_provider - ) - ) else: # If a pool is passed in, do not close it self.auto_close_connection_pool = False - self._event_dispatcher.dispatch( - AfterPooledConnectionsInstantiationEvent( - [connection_pool], ClientType.ASYNC, credential_provider - ) - ) self.connection_pool = connection_pool self.single_connection_client = single_connection_client @@ -388,10 +337,7 @@ class Redis( self._single_conn_lock = asyncio.Lock() def __repr__(self): - return ( - f"<{self.__class__.__module__}.{self.__class__.__name__}" - f"({self.connection_pool!r})>" - ) + return f"{self.__class__.__name__}<{self.connection_pool!r}>" def __await__(self): return self.initialize().__await__() @@ -400,13 +346,7 @@ class Redis( if self.single_connection_client: async with self._single_conn_lock: if self.connection is None: - self.connection = await self.connection_pool.get_connection() - - self._event_dispatcher.dispatch( - AfterSingleConnectionInstantiationEvent( - self.connection, ClientType.ASYNC, self._single_conn_lock - ) - ) + self.connection = await self.connection_pool.get_connection("_") return self def set_response_callback(self, command: str, callback: ResponseCallbackT): @@ -421,10 +361,10 @@ class Redis( """Get the connection's key-word arguments""" return self.connection_pool.connection_kwargs - def get_retry(self) -> Optional[Retry]: + def get_retry(self) -> Optional["Retry"]: return self.get_connection_kwargs().get("retry") - def set_retry(self, retry: Retry) -> None: + def set_retry(self, retry: "Retry") -> None: self.get_connection_kwargs().update({"retry": retry}) self.connection_pool.set_retry(retry) @@ -503,7 +443,6 @@ class Redis( blocking_timeout: Optional[float] = None, lock_class: Optional[Type[Lock]] = None, thread_local: bool = True, - raise_on_release_error: bool = True, ) -> Lock: """ Return a new Lock object using key ``name`` that mimics @@ -550,11 +489,6 @@ class Redis( thread-1 would see the token value as "xyz" and would be able to successfully release the thread-2's lock. - ``raise_on_release_error`` indicates whether to raise an exception when - the lock is no longer owned when exiting the context manager. By default, - this is True, meaning an exception will be raised. If False, the warning - will be logged and the exception will be suppressed. - In some use cases it's necessary to disable thread local storage. For example, if you have code where one thread acquires a lock and passes that lock instance to a worker thread to release later. 
If thread @@ -572,7 +506,6 @@ class Redis( blocking=blocking, blocking_timeout=blocking_timeout, thread_local=thread_local, - raise_on_release_error=raise_on_release_error, ) def pubsub(self, **kwargs) -> "PubSub": @@ -581,9 +514,7 @@ class Redis( subscribe to channels and listen for messages that get published to them. """ - return PubSub( - self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs - ) + return PubSub(self.connection_pool, **kwargs) def monitor(self) -> "Monitor": return Monitor(self.connection_pool) @@ -615,18 +546,15 @@ class Redis( _grl().call_exception_handler(context) except RuntimeError: pass - self.connection._close() async def aclose(self, close_connection_pool: Optional[bool] = None) -> None: """ Closes Redis client connection - Args: - close_connection_pool: - decides whether to close the connection pool used by this Redis client, - overriding Redis.auto_close_connection_pool. - By default, let Redis.auto_close_connection_pool decide - whether to close the connection pool. + :param close_connection_pool: decides whether to close the connection pool used + by this Redis client, overriding Redis.auto_close_connection_pool. By default, + let Redis.auto_close_connection_pool decide whether to close the connection + pool. """ conn = self.connection if conn: @@ -637,7 +565,7 @@ class Redis( ): await self.connection_pool.disconnect() - @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close") + @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close") async def close(self, close_connection_pool: Optional[bool] = None) -> None: """ Alias for aclose(), for backwards compatibility @@ -651,17 +579,18 @@ class Redis( await conn.send_command(*args) return await self.parse_response(conn, command_name, **options) - async def _close_connection(self, conn: Connection): + async def _disconnect_raise(self, conn: Connection, error: Exception): """ - Close the connection before retrying. - - The supported exceptions are already checked in the - retry object so we don't need to do it here. - - After we disconnect the connection, it will try to reconnect and - do a health check as part of the send_command logic(on connection level). + Close the connection and raise an exception + if retry_on_error is not set or the error + is not one of the specified error types """ await conn.disconnect() + if ( + conn.retry_on_error is None + or isinstance(error, tuple(conn.retry_on_error)) is False + ): + raise error # COMMAND EXECUTION AND PROTOCOL PARSING async def execute_command(self, *args, **options): @@ -669,7 +598,7 @@ class Redis( await self.initialize() pool = self.connection_pool command_name = args[0] - conn = self.connection or await pool.get_connection() + conn = self.connection or await pool.get_connection(command_name, **options) if self.single_connection_client: await self._single_conn_lock.acquire() @@ -678,7 +607,7 @@ class Redis( lambda: self._send_command_parse_response( conn, command_name, *args, **options ), - lambda _: self._close_connection(conn), + lambda error: self._disconnect_raise(conn, error), ) finally: if self.single_connection_client: @@ -704,9 +633,6 @@ class Redis( if EMPTY_RESPONSE in options: options.pop(EMPTY_RESPONSE) - # Remove keys entry, it needs only for cache. 
- options.pop("keys", None) - if command_name in self.response_callbacks: # Mypy bug: https://github.com/python/mypy/issues/10977 command_name = cast(str, command_name) @@ -743,7 +669,7 @@ class Monitor: async def connect(self): if self.connection is None: - self.connection = await self.connection_pool.get_connection() + self.connection = await self.connection_pool.get_connection("MONITOR") async def __aenter__(self): await self.connect() @@ -820,12 +746,7 @@ class PubSub: ignore_subscribe_messages: bool = False, encoder=None, push_handler_func: Optional[Callable] = None, - event_dispatcher: Optional["EventDispatcher"] = None, ): - if event_dispatcher is None: - self._event_dispatcher = EventDispatcher() - else: - self._event_dispatcher = event_dispatcher self.connection_pool = connection_pool self.shard_hint = shard_hint self.ignore_subscribe_messages = ignore_subscribe_messages @@ -862,7 +783,7 @@ class PubSub: def __del__(self): if self.connection: - self.connection.deregister_connect_callback(self.on_connect) + self.connection._deregister_connect_callback(self.on_connect) async def aclose(self): # In case a connection property does not yet exist @@ -873,7 +794,7 @@ class PubSub: async with self._lock: if self.connection: await self.connection.disconnect() - self.connection.deregister_connect_callback(self.on_connect) + self.connection._deregister_connect_callback(self.on_connect) await self.connection_pool.release(self.connection) self.connection = None self.channels = {} @@ -881,12 +802,12 @@ class PubSub: self.patterns = {} self.pending_unsubscribe_patterns = set() - @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close") + @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close") async def close(self) -> None: """Alias for aclose(), for backwards compatibility""" await self.aclose() - @deprecated_function(version="5.0.1", reason="Use aclose() instead", name="reset") + @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="reset") async def reset(self) -> None: """Alias for aclose(), for backwards compatibility""" await self.aclose() @@ -931,26 +852,26 @@ class PubSub: Ensure that the PubSub is connected """ if self.connection is None: - self.connection = await self.connection_pool.get_connection() + self.connection = await self.connection_pool.get_connection( + "pubsub", self.shard_hint + ) # register a callback that re-subscribes to any channels we # were listening to when we were disconnected - self.connection.register_connect_callback(self.on_connect) + self.connection._register_connect_callback(self.on_connect) else: await self.connection.connect() - if self.push_handler_func is not None: - self.connection._parser.set_pubsub_push_handler(self.push_handler_func) + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) - self._event_dispatcher.dispatch( - AfterPubSubConnectionInstantiationEvent( - self.connection, self.connection_pool, ClientType.ASYNC, self._lock - ) - ) - - async def _reconnect(self, conn): + async def _disconnect_raise_connect(self, conn, error): """ - Try to reconnect + Close the connection and raise an exception + if retry_on_timeout is not set or the error + is not a TimeoutError. 
Otherwise, try to reconnect """ await conn.disconnect() + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + raise error await conn.connect() async def _execute(self, conn, command, *args, **kwargs): @@ -963,7 +884,7 @@ class PubSub: """ return await conn.retry.call_with_retry( lambda: command(*args, **kwargs), - lambda _: self._reconnect(conn), + lambda error: self._disconnect_raise_connect(conn, error), ) async def parse_response(self, block: bool = True, timeout: float = 0): @@ -1232,11 +1153,13 @@ class PubSub: class PubsubWorkerExceptionHandler(Protocol): - def __call__(self, e: BaseException, pubsub: PubSub): ... + def __call__(self, e: BaseException, pubsub: PubSub): + ... class AsyncPubsubWorkerExceptionHandler(Protocol): - async def __call__(self, e: BaseException, pubsub: PubSub): ... + async def __call__(self, e: BaseException, pubsub: PubSub): + ... PSWorkerThreadExcHandlerT = Union[ @@ -1254,8 +1177,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass] in one transmission. This is convenient for batch processing, such as saving all the values in a list to Redis. - All commands executed within a pipeline(when running in transactional mode, - which is the default behavior) are wrapped with MULTI and EXEC + All commands executed within a pipeline are wrapped with MULTI and EXEC calls. This guarantees all commands executed in the pipeline will be executed atomically. @@ -1284,7 +1206,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass] self.shard_hint = shard_hint self.watching = False self.command_stack: CommandStackT = [] - self.scripts: Set[Script] = set() + self.scripts: Set["Script"] = set() self.explicit_transaction = False async def __aenter__(self: _RedisT) -> _RedisT: @@ -1356,50 +1278,49 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass] return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) - async def _disconnect_reset_raise_on_watching( - self, - conn: Connection, - error: Exception, - ): + async def _disconnect_reset_raise(self, conn, error): """ - Close the connection reset watching state and - raise an exception if we were watching. - - The supported exceptions are already checked in the - retry object so we don't need to do it here. - - After we disconnect the connection, it will try to reconnect and - do a health check as part of the send_command logic(on connection level). + Close the connection, reset watching state and + raise an exception if we were watching, + retry_on_timeout is not set, + or the error is not a TimeoutError """ await conn.disconnect() # if we were already watching a variable, the watch is no longer # valid since this connection has died. raise a WatchError, which # indicates the user should retry this transaction. if self.watching: - await self.reset() + await self.aclose() raise WatchError( - f"A {type(error).__name__} occurred while watching one or more keys" + "A ConnectionError occurred on while watching one or more keys" ) + # if retry_on_timeout is not set, or the error is not + # a TimeoutError, raise it + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + await self.aclose() + raise async def immediate_execute_command(self, *args, **options): """ - Execute a command immediately, but don't auto-retry on the supported - errors for retry if we're already WATCHing a variable. 
- Used when issuing WATCH or subsequent commands retrieving their values but before + Execute a command immediately, but don't auto-retry on a + ConnectionError if we're already WATCHing a variable. Used when + issuing WATCH or subsequent commands retrieving their values but before MULTI is called. """ command_name = args[0] conn = self.connection # if this is the first call, we need a connection if not conn: - conn = await self.connection_pool.get_connection() + conn = await self.connection_pool.get_connection( + command_name, self.shard_hint + ) self.connection = conn return await conn.retry.call_with_retry( lambda: self._send_command_parse_response( conn, command_name, *args, **options ), - lambda error: self._disconnect_reset_raise_on_watching(conn, error), + lambda error: self._disconnect_reset_raise(conn, error), ) def pipeline_execute_command(self, *args, **options): @@ -1484,10 +1405,6 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass] if not isinstance(r, Exception): args, options = cmd command_name = args[0] - - # Remove keys entry, it needs only for cache. - options.pop("keys", None) - if command_name in self.response_callbacks: r = self.response_callbacks[command_name](r, **options) if inspect.isawaitable(r): @@ -1525,10 +1442,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass] self, exception: Exception, number: int, command: Iterable[object] ) -> None: cmd = " ".join(map(safe_str, command)) - msg = ( - f"Command # {number} ({truncate_text(cmd)}) " - "of pipeline caused error: {exception.args}" - ) + msg = f"Command # {number} ({cmd}) of pipeline caused error: {exception.args}" exception.args = (msg,) + exception.args[1:] async def parse_response( @@ -1554,15 +1468,11 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass] if not exist: s.sha = await immediate("SCRIPT LOAD", s.script) - async def _disconnect_raise_on_watching(self, conn: Connection, error: Exception): + async def _disconnect_raise_reset(self, conn: Connection, error: Exception): """ - Close the connection, raise an exception if we were watching. - - The supported exceptions are already checked in the - retry object so we don't need to do it here. - - After we disconnect the connection, it will try to reconnect and - do a health check as part of the send_command logic(on connection level). + Close the connection, raise an exception if we were watching, + and raise an exception if retry_on_timeout is not set, + or the error is not a TimeoutError """ await conn.disconnect() # if we were watching a variable, the watch is no longer valid @@ -1570,10 +1480,15 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass] # indicates the user should retry this transaction. 
if self.watching: raise WatchError( - f"A {type(error).__name__} occurred while watching one or more keys" + "A ConnectionError occurred on while watching one or more keys" ) + # if retry_on_timeout is not set, or the error is not + # a TimeoutError, raise it + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + await self.reset() + raise - async def execute(self, raise_on_error: bool = True) -> List[Any]: + async def execute(self, raise_on_error: bool = True): """Execute all the commands in the current pipeline""" stack = self.command_stack if not stack and not self.watching: @@ -1587,7 +1502,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass] conn = self.connection if not conn: - conn = await self.connection_pool.get_connection() + conn = await self.connection_pool.get_connection("MULTI", self.shard_hint) # assign to self.connection so reset() releases the connection # back to the pool after we're done self.connection = conn @@ -1596,7 +1511,7 @@ class Pipeline(Redis): # lgtm [py/init-calls-subclass] try: return await conn.retry.call_with_retry( lambda: execute(conn, stack, raise_on_error), - lambda error: self._disconnect_raise_on_watching(conn, error), + lambda error: self._disconnect_raise_reset(conn, error), ) finally: await self.reset() diff --git a/venv/lib/python3.12/site-packages/redis/asyncio/cluster.py b/venv/lib/python3.12/site-packages/redis/asyncio/cluster.py index e8434d0..636144a 100644 --- a/venv/lib/python3.12/site-packages/redis/asyncio/cluster.py +++ b/venv/lib/python3.12/site-packages/redis/asyncio/cluster.py @@ -2,23 +2,16 @@ import asyncio import collections import random import socket -import threading -import time import warnings -from abc import ABC, abstractmethod -from copy import copy -from itertools import chain from typing import ( Any, Callable, - Coroutine, Deque, Dict, Generator, List, Mapping, Optional, - Set, Tuple, Type, TypeVar, @@ -32,11 +25,10 @@ from redis._parsers.helpers import ( _RedisCallbacksRESP3, ) from redis.asyncio.client import ResponseCallbackT -from redis.asyncio.connection import Connection, SSLConnection, parse_url +from redis.asyncio.connection import Connection, DefaultParser, SSLConnection, parse_url from redis.asyncio.lock import Lock from redis.asyncio.retry import Retry -from redis.auth.token import TokenInterface -from redis.backoff import ExponentialWithJitterBackoff, NoBackoff +from redis.backoff import default_backoff from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis from redis.cluster import ( PIPELINE_BLOCKED_COMMANDS, @@ -45,7 +37,6 @@ from redis.cluster import ( SLOT_ID, AbstractRedisCluster, LoadBalancer, - LoadBalancingStrategy, block_pipeline_command, get_node_name, parse_cluster_slots, @@ -53,49 +44,51 @@ from redis.cluster import ( from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot from redis.credentials import CredentialProvider -from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher from redis.exceptions import ( AskError, BusyLoadingError, + ClusterCrossSlotError, ClusterDownError, ClusterError, ConnectionError, - CrossSlotTransactionError, DataError, - ExecAbortError, - InvalidPipelineStack, + MasterDownError, MaxConnectionsError, MovedError, RedisClusterException, - RedisError, ResponseError, SlotNotCoveredError, TimeoutError, TryAgainError, - WatchError, ) from redis.typing import AnyKeyT, EncodableT, KeyT from redis.utils import ( - SSL_AVAILABLE, - deprecated_args, 
deprecated_function, + dict_merge, get_lib_version, safe_str, str_if_bytes, - truncate_text, ) -if SSL_AVAILABLE: - from ssl import TLSVersion, VerifyMode -else: - TLSVersion = None - VerifyMode = None - TargetNodesT = TypeVar( "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] ) +class ClusterParser(DefaultParser): + EXCEPTION_CLASSES = dict_merge( + DefaultParser.EXCEPTION_CLASSES, + { + "ASK": AskError, + "CLUSTERDOWN": ClusterDownError, + "CROSSSLOT": ClusterCrossSlotError, + "MASTERDOWN": MasterDownError, + "MOVED": MovedError, + "TRYAGAIN": TryAgainError, + }, + ) + + class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands): """ Create a new RedisCluster client. @@ -136,23 +129,9 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand | See: https://redis.io/docs/manual/scaling/#redis-cluster-configuration-parameters :param read_from_replicas: - | @deprecated - please use load_balancing_strategy instead - | Enable read from replicas in READONLY mode. + | Enable read from replicas in READONLY mode. You can read possibly stale data. When set to true, read commands will be assigned between the primary and its replications in a Round-Robin manner. - The data read from replicas is eventually consistent with the data in primary nodes. - :param load_balancing_strategy: - | Enable read from replicas in READONLY mode and defines the load balancing - strategy that will be used for cluster node selection. - The data read from replicas is eventually consistent with the data in primary nodes. - :param dynamic_startup_nodes: - | Set the RedisCluster's startup nodes to all the discovered nodes. - If true (default value), the cluster's discovered nodes will be used to - determine the cluster nodes-slots mapping in the next topology refresh. - It will remove the initial passed startup nodes if their endpoints aren't - listed in the CLUSTER SLOTS output. - If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists - specific IP addresses, it is best to set it to false. :param reinitialize_steps: | Specifies the number of MOVED errors that need to occur before reinitializing the whole cluster topology. If a MOVED error occurs and the cluster does not @@ -162,23 +141,19 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand To avoid reinitializing the cluster on moved errors, set reinitialize_steps to 0. :param cluster_error_retry_attempts: - | @deprecated - Please configure the 'retry' object instead - In case 'retry' object is set - this argument is ignored! - - Number of times to retry before raising an error when :class:`~.TimeoutError`, - :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError` - or :class:`~.ClusterDownError` are encountered - :param retry: - | A retry object that defines the retry strategy and the number of - retries for the cluster client. - In current implementation for the cluster client (starting form redis-py version 6.0.0) - the retry object is not yet fully utilized, instead it is used just to determine - the number of retries for the cluster client. - In the future releases the retry object will be used to handle the cluster client retries! 
+ | Number of times to retry before raising an error when :class:`~.TimeoutError` + or :class:`~.ConnectionError` or :class:`~.ClusterDownError` are encountered + :param connection_error_retry_attempts: + | Number of times to retry before reinitializing when :class:`~.TimeoutError` + or :class:`~.ConnectionError` are encountered. + The default backoff strategy will be set if Retry object is not passed (see + default_backoff in backoff.py). To change it, pass a custom Retry object + using the "retry" keyword. :param max_connections: | Maximum number of connections per node. If there are no free connections & the maximum number of connections are already created, a - :class:`~.MaxConnectionsError` is raised. + :class:`~.MaxConnectionsError` is raised. This error may be retried as defined + by :attr:`connection_error_retry_attempts` :param address_remap: | An optional callable which, when provided with an internal network address of a node, e.g. a `(host, port)` tuple, will return the address @@ -234,9 +209,10 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand __slots__ = ( "_initialize", "_lock", - "retry", + "cluster_error_retry_attempts", "command_flags", "commands_parser", + "connection_error_retry_attempts", "connection_kwargs", "encoder", "node_flags", @@ -248,18 +224,6 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand "result_callbacks", ) - @deprecated_args( - args_to_warn=["read_from_replicas"], - reason="Please configure the 'load_balancing_strategy' instead", - version="5.3.0", - ) - @deprecated_args( - args_to_warn=[ - "cluster_error_retry_attempts", - ], - reason="Please configure the 'retry' object instead", - version="6.0.0", - ) def __init__( self, host: Optional[str] = None, @@ -268,13 +232,10 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand startup_nodes: Optional[List["ClusterNode"]] = None, require_full_coverage: bool = True, read_from_replicas: bool = False, - load_balancing_strategy: Optional[LoadBalancingStrategy] = None, - dynamic_startup_nodes: bool = True, reinitialize_steps: int = 5, cluster_error_retry_attempts: int = 3, + connection_error_retry_attempts: int = 3, max_connections: int = 2**31, - retry: Optional["Retry"] = None, - retry_on_error: Optional[List[Type[Exception]]] = None, # Client related kwargs db: Union[str, int] = 0, path: Optional[str] = None, @@ -294,19 +255,18 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand socket_keepalive: bool = False, socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, socket_timeout: Optional[float] = None, + retry: Optional["Retry"] = None, + retry_on_error: Optional[List[Type[Exception]]] = None, # SSL related kwargs ssl: bool = False, ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, - ssl_cert_reqs: Union[str, VerifyMode] = "required", + ssl_cert_reqs: str = "required", ssl_certfile: Optional[str] = None, - ssl_check_hostname: bool = True, + ssl_check_hostname: bool = False, ssl_keyfile: Optional[str] = None, - ssl_min_version: Optional[TLSVersion] = None, - ssl_ciphers: Optional[str] = None, protocol: Optional[int] = 2, - address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, - event_dispatcher: Optional[EventDispatcher] = None, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: if db: raise RedisClusterException( @@ -330,6 +290,7 @@ class RedisCluster(AbstractRedis, 
AbstractRedisCluster, AsyncRedisClusterCommand kwargs: Dict[str, Any] = { "max_connections": max_connections, "connection_class": Connection, + "parser_class": ClusterParser, # Client related kwargs "credential_provider": credential_provider, "username": username, @@ -347,6 +308,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand "socket_keepalive": socket_keepalive, "socket_keepalive_options": socket_keepalive_options, "socket_timeout": socket_timeout, + "retry": retry, "protocol": protocol, } @@ -361,24 +323,24 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand "ssl_certfile": ssl_certfile, "ssl_check_hostname": ssl_check_hostname, "ssl_keyfile": ssl_keyfile, - "ssl_min_version": ssl_min_version, - "ssl_ciphers": ssl_ciphers, } ) - if read_from_replicas or load_balancing_strategy: + if read_from_replicas: # Call our on_connect function to configure READONLY mode kwargs["redis_connect_func"] = self.on_connect - if retry: - self.retry = retry - else: - self.retry = Retry( - backoff=ExponentialWithJitterBackoff(base=1, cap=10), - retries=cluster_error_retry_attempts, + self.retry = retry + if retry or retry_on_error or connection_error_retry_attempts > 0: + # Set a retry object for all cluster nodes + self.retry = retry or Retry( + default_backoff(), connection_error_retry_attempts ) - if retry_on_error: + if not retry_on_error: + # Default errors for retrying + retry_on_error = [ConnectionError, TimeoutError] self.retry.update_supported_errors(retry_on_error) + kwargs.update({"retry": self.retry}) kwargs["response_callbacks"] = _RedisCallbacks.copy() if kwargs.get("protocol") in ["3", 3]: @@ -399,33 +361,27 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand if host and port: startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs)) - if event_dispatcher is None: - self._event_dispatcher = EventDispatcher() - else: - self._event_dispatcher = event_dispatcher - self.nodes_manager = NodesManager( startup_nodes, require_full_coverage, kwargs, - dynamic_startup_nodes=dynamic_startup_nodes, address_remap=address_remap, - event_dispatcher=self._event_dispatcher, ) self.encoder = Encoder(encoding, encoding_errors, decode_responses) self.read_from_replicas = read_from_replicas - self.load_balancing_strategy = load_balancing_strategy self.reinitialize_steps = reinitialize_steps + self.cluster_error_retry_attempts = cluster_error_retry_attempts + self.connection_error_retry_attempts = connection_error_retry_attempts self.reinitialize_counter = 0 self.commands_parser = AsyncCommandsParser() self.node_flags = self.__class__.NODE_FLAGS.copy() self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.response_callbacks = kwargs["response_callbacks"] self.result_callbacks = self.__class__.RESULT_CALLBACKS.copy() - self.result_callbacks["CLUSTER SLOTS"] = ( - lambda cmd, res, **kwargs: parse_cluster_slots( - list(res.values())[0], **kwargs - ) + self.result_callbacks[ + "CLUSTER SLOTS" + ] = lambda cmd, res, **kwargs: parse_cluster_slots( + list(res.values())[0], **kwargs ) self._initialize = True @@ -586,8 +542,15 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand """Get the kwargs passed to :class:`~redis.asyncio.connection.Connection`.""" return self.connection_kwargs - def set_retry(self, retry: Retry) -> None: + def get_retry(self) -> Optional["Retry"]: + return self.retry + + def set_retry(self, retry: "Retry") -> None: self.retry = retry + for node in 
self.get_nodes(): + node.connection_kwargs.update({"retry": retry}) + for conn in node._connections: + conn.retry = retry def set_response_callback(self, command: str, callback: ResponseCallbackT) -> None: """Set a custom response callback.""" @@ -624,7 +587,6 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand self.nodes_manager.get_node_from_slot( await self._determine_slot(command, *args), self.read_from_replicas and command in READ_COMMANDS, - self.load_balancing_strategy if command in READ_COMMANDS else None, ) ] @@ -706,8 +668,8 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand """ Execute a raw command on the appropriate cluster node or target_nodes. - It will retry the command as specified by the retries property of - the :attr:`retry` & then raise an exception. + It will retry the command as specified by :attr:`cluster_error_retry_attempts` & + then raise an exception. :param args: | Raw command args @@ -723,7 +685,7 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand command = args[0] target_nodes = [] target_nodes_specified = False - retry_attempts = self.retry.get_retries() + retry_attempts = self.cluster_error_retry_attempts passed_targets = kwargs.pop("target_nodes", None) if passed_targets and not self._is_node_flag(passed_targets): @@ -805,22 +767,12 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand # refresh the target node slot = await self._determine_slot(*args) target_node = self.nodes_manager.get_node_from_slot( - slot, - self.read_from_replicas and args[0] in READ_COMMANDS, - self.load_balancing_strategy - if args[0] in READ_COMMANDS - else None, + slot, self.read_from_replicas and args[0] in READ_COMMANDS ) moved = False return await target_node.execute_command(*args, **kwargs) - except BusyLoadingError: - raise - except MaxConnectionsError: - # MaxConnectionsError indicates client-side resource exhaustion - # (too many connections in the pool), not a node failure. - # Don't treat this as a node failure - just re-raise the error - # without reinitializing the cluster. + except (BusyLoadingError, MaxConnectionsError): raise except (ConnectionError, TimeoutError): # Connection retries are being handled in the node's @@ -832,16 +784,10 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand # and try again with the new setup await self.aclose() raise - except (ClusterDownError, SlotNotCoveredError): + except ClusterDownError: # ClusterDownError can occur during a failover and to get # self-healed, we will try to reinitialize the cluster layout # and retry executing the command - - # SlotNotCoveredError can occur when the cluster is not fully - # initialized or can be temporary issue. 
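
The target_nodes handling above can also be driven explicitly from user code. A hedged sketch, assuming an already-created RedisCluster client named rc; PING and the ALL_NODES flag are only illustrative choices.

from redis.asyncio.cluster import RedisCluster


async def ping_everywhere(rc: RedisCluster) -> None:
    # Keyless commands cannot be routed by slot, so a node flag (or an
    # explicit list of ClusterNode objects) is passed as target_nodes.
    replies = await rc.execute_command("PING", target_nodes=RedisCluster.ALL_NODES)
    print(replies)
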
- # We will try to reinitialize the cluster topology - # and retry executing the command - await self.aclose() await asyncio.sleep(0.25) raise @@ -887,7 +833,10 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand if shard_hint: raise RedisClusterException("shard_hint is deprecated in cluster mode") - return ClusterPipeline(self, transaction) + if transaction: + raise RedisClusterException("transaction is deprecated in cluster mode") + + return ClusterPipeline(self) def lock( self, @@ -898,7 +847,6 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand blocking_timeout: Optional[float] = None, lock_class: Optional[Type[Lock]] = None, thread_local: bool = True, - raise_on_release_error: bool = True, ) -> Lock: """ Return a new Lock object using key ``name`` that mimics @@ -945,11 +893,6 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand thread-1 would see the token value as "xyz" and would be able to successfully release the thread-2's lock. - ``raise_on_release_error`` indicates whether to raise an exception when - the lock is no longer owned when exiting the context manager. By default, - this is True, meaning an exception will be raised. If False, the warning - will be logged and the exception will be suppressed. - In some use cases it's necessary to disable thread local storage. For example, if you have code where one thread acquires a lock and passes that lock instance to a worker thread to release later. If thread @@ -967,33 +910,8 @@ class RedisCluster(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommand blocking=blocking, blocking_timeout=blocking_timeout, thread_local=thread_local, - raise_on_release_error=raise_on_release_error, ) - async def transaction( - self, func: Coroutine[None, "ClusterPipeline", Any], *watches, **kwargs - ): - """ - Convenience method for executing the callable `func` as a transaction - while watching all keys specified in `watches`. The 'func' callable - should expect a single argument which is a Pipeline object. - """ - shard_hint = kwargs.pop("shard_hint", None) - value_from_callable = kwargs.pop("value_from_callable", False) - watch_delay = kwargs.pop("watch_delay", None) - async with self.pipeline(True, shard_hint) as pipe: - while True: - try: - if watches: - await pipe.watch(*watches) - func_value = await func(pipe) - exec_value = await pipe.execute() - return func_value if value_from_callable else exec_value - except WatchError: - if watch_delay is not None and watch_delay > 0: - time.sleep(watch_delay) - continue - class ClusterNode: """ @@ -1006,8 +924,6 @@ class ClusterNode: __slots__ = ( "_connections", "_free", - "_lock", - "_event_dispatcher", "connection_class", "connection_kwargs", "host", @@ -1045,9 +961,6 @@ class ClusterNode: self._connections: List[Connection] = [] self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections) - self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) - if self._event_dispatcher is None: - self._event_dispatcher = EventDispatcher() def __repr__(self) -> str: return ( @@ -1093,34 +1006,12 @@ class ClusterNode: return self._free.popleft() except IndexError: if len(self._connections) < self.max_connections: - # We are configuring the connection pool not to retry - # connections on lower level clients to avoid retrying - # connections to nodes that are not reachable - # and to avoid blocking the connection pool. 
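
The Lock object returned by lock() above supports the async context-manager protocol. A small usage sketch; the key name and timeout values are illustrative, and rc is assumed to be a connected RedisCluster client.

async def update_counter(rc) -> None:
    # Wait up to blocking_timeout seconds to acquire the lock, hold it for
    # at most `timeout` seconds, and release it automatically on exit.
    async with rc.lock("locks:counter", timeout=10, blocking_timeout=5):
        value = await rc.get("counter")
        await rc.set("counter", int(value or 0) + 1)
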
- # The only error that will have some handling in the lower - # level clients is ConnectionError which will trigger disconnection - # of the socket. - # The retries will be handled on cluster client level - # where we will have proper handling of the cluster topology - retry = Retry( - backoff=NoBackoff(), - retries=0, - supported_errors=(ConnectionError,), - ) - connection_kwargs = self.connection_kwargs.copy() - connection_kwargs["retry"] = retry - connection = self.connection_class(**connection_kwargs) + connection = self.connection_class(**self.connection_kwargs) self._connections.append(connection) return connection raise MaxConnectionsError() - def release(self, connection: Connection) -> None: - """ - Release connection back to free queue. - """ - self._free.append(connection) - async def parse_response( self, connection: Connection, command: str, **kwargs: Any ) -> Any: @@ -1138,9 +1029,6 @@ class ClusterNode: if EMPTY_RESPONSE in kwargs: kwargs.pop(EMPTY_RESPONSE) - # Remove keys entry, it needs only for cache. - kwargs.pop("keys", None) - # Return response if command in self.response_callbacks: return self.response_callbacks[command](response, **kwargs) @@ -1186,39 +1074,10 @@ class ClusterNode: return ret - async def re_auth_callback(self, token: TokenInterface): - tmp_queue = collections.deque() - while self._free: - conn = self._free.popleft() - await conn.retry.call_with_retry( - lambda: conn.send_command( - "AUTH", token.try_get("oid"), token.get_value() - ), - lambda error: self._mock(error), - ) - await conn.retry.call_with_retry( - lambda: conn.read_response(), lambda error: self._mock(error) - ) - tmp_queue.append(conn) - - while tmp_queue: - conn = tmp_queue.popleft() - self._free.append(conn) - - async def _mock(self, error: RedisError): - """ - Dummy functions, needs to be passed as error callback to retry object. - :param error: - :return: - """ - pass - class NodesManager: __slots__ = ( - "_dynamic_startup_nodes", "_moved_exception", - "_event_dispatcher", "connection_kwargs", "default_node", "nodes_cache", @@ -1234,9 +1093,7 @@ class NodesManager: startup_nodes: List["ClusterNode"], require_full_coverage: bool, connection_kwargs: Dict[str, Any], - dynamic_startup_nodes: bool = True, - address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, - event_dispatcher: Optional[EventDispatcher] = None, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, ) -> None: self.startup_nodes = {node.name: node for node in startup_nodes} self.require_full_coverage = require_full_coverage @@ -1247,13 +1104,7 @@ class NodesManager: self.nodes_cache: Dict[str, "ClusterNode"] = {} self.slots_cache: Dict[int, List["ClusterNode"]] = {} self.read_load_balancer = LoadBalancer() - - self._dynamic_startup_nodes: bool = dynamic_startup_nodes self._moved_exception: MovedError = None - if event_dispatcher is None: - self._event_dispatcher = EventDispatcher() - else: - self._event_dispatcher = event_dispatcher def get_node( self, @@ -1270,7 +1121,9 @@ class NodesManager: return self.nodes_cache.get(node_name) else: raise DataError( - "get_node requires one of the following: 1. node name 2. host and port" + "get_node requires one of the following: " + "1. node name " + "2. 
host and port" ) def set_nodes( @@ -1291,9 +1144,6 @@ class NodesManager: task = asyncio.create_task(old[name].disconnect()) # noqa old[name] = node - def update_moved_exception(self, exception): - self._moved_exception = exception - def _update_moved_slots(self) -> None: e = self._moved_exception redirected_node = self.get_node(host=e.host, port=e.port) @@ -1333,23 +1183,17 @@ class NodesManager: self._moved_exception = None def get_node_from_slot( - self, - slot: int, - read_from_replicas: bool = False, - load_balancing_strategy=None, + self, slot: int, read_from_replicas: bool = False ) -> "ClusterNode": if self._moved_exception: self._update_moved_slots() - if read_from_replicas is True and load_balancing_strategy is None: - load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN - try: - if len(self.slots_cache[slot]) > 1 and load_balancing_strategy: - # get the server index using the strategy defined in load_balancing_strategy + if read_from_replicas: + # get the server index in a Round-Robin manner primary_name = self.slots_cache[slot][0].name node_idx = self.read_load_balancer.get_server_index( - primary_name, len(self.slots_cache[slot]), load_balancing_strategy + primary_name, len(self.slots_cache[slot]) ) return self.slots_cache[slot][node_idx] return self.slots_cache[slot][0] @@ -1374,23 +1218,16 @@ class NodesManager: startup_nodes_reachable = False fully_covered = False exception = None - # Convert to tuple to prevent RuntimeError if self.startup_nodes - # is modified during iteration - for startup_node in tuple(self.startup_nodes.values()): + for startup_node in self.startup_nodes.values(): try: # Make sure cluster mode is enabled on this node - try: - self._event_dispatcher.dispatch( - AfterAsyncClusterInstantiationEvent( - self.nodes_cache, - self.connection_kwargs.get("credential_provider", None), - ) - ) - cluster_slots = await startup_node.execute_command("CLUSTER SLOTS") - except ResponseError: + if not (await startup_node.execute_command("INFO")).get( + "cluster_enabled" + ): raise RedisClusterException( "Cluster mode is not enabled on this node" ) + cluster_slots = await startup_node.execute_command("CLUSTER SLOTS") startup_nodes_reachable = True except Exception as e: # Try the next startup node. 
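
The slot lookup performed by get_node_from_slot above can be reproduced with the key_slot helper this module already imports. A sketch under the assumption that rc is an initialized RedisCluster, so nodes_manager.slots_cache is populated; it peeks at internals purely for illustration.

from redis.crc import key_slot


def describe_routing(rc, key: str) -> str:
    # key_slot() expects bytes; the NodesManager then maps the slot to the
    # primary (or to a replica when read_from_replicas is enabled).
    slot = key_slot(key.encode())
    node = rc.nodes_manager.get_node_from_slot(slot, read_from_replicas=False)
    return f"{key!r} -> slot {slot} -> {node.host}:{node.port}"
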
@@ -1422,8 +1259,6 @@ class NodesManager: port = int(primary_node[1]) host, port = self.remap_host_port(host, port) - nodes_for_slot = [] - target_node = tmp_nodes_cache.get(get_node_name(host, port)) if not target_node: target_node = ClusterNode( @@ -1431,26 +1266,30 @@ class NodesManager: ) # add this node to the nodes cache tmp_nodes_cache[target_node.name] = target_node - nodes_for_slot.append(target_node) - - replica_nodes = slot[3:] - for replica_node in replica_nodes: - host = replica_node[0] - port = replica_node[1] - host, port = self.remap_host_port(host, port) - - target_replica_node = tmp_nodes_cache.get(get_node_name(host, port)) - if not target_replica_node: - target_replica_node = ClusterNode( - host, port, REPLICA, **self.connection_kwargs - ) - # add this node to the nodes cache - tmp_nodes_cache[target_replica_node.name] = target_replica_node - nodes_for_slot.append(target_replica_node) for i in range(int(slot[0]), int(slot[1]) + 1): if i not in tmp_slots: - tmp_slots[i] = nodes_for_slot + tmp_slots[i] = [] + tmp_slots[i].append(target_node) + replica_nodes = [slot[j] for j in range(3, len(slot))] + + for replica_node in replica_nodes: + host = replica_node[0] + port = replica_node[1] + host, port = self.remap_host_port(host, port) + + target_replica_node = tmp_nodes_cache.get( + get_node_name(host, port) + ) + if not target_replica_node: + target_replica_node = ClusterNode( + host, port, REPLICA, **self.connection_kwargs + ) + tmp_slots[i].append(target_replica_node) + # add this node to the nodes cache + tmp_nodes_cache[ + target_replica_node.name + ] = target_replica_node else: # Validate that 2 nodes want to use the same slot cache # setup @@ -1463,7 +1302,7 @@ class NodesManager: if len(disagreements) > 5: raise RedisClusterException( f"startup_nodes could not agree on a valid " - f"slots cache: {', '.join(disagreements)}" + f'slots cache: {", ".join(disagreements)}' ) # Validate if all slots are covered or if we should try next startup node @@ -1494,10 +1333,8 @@ class NodesManager: # Set the tmp variables to the real variables self.slots_cache = tmp_slots self.set_nodes(self.nodes_cache, tmp_nodes_cache, remove_old=True) - - if self._dynamic_startup_nodes: - # Populate the startup nodes with all discovered nodes - self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True) + # Populate the startup nodes with all discovered nodes + self.set_nodes(self.startup_nodes, self.nodes_cache, remove_old=True) # Set the default node self.default_node = self.get_nodes_by_server_type(PRIMARY)[0] @@ -1561,38 +1398,40 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm | Existing :class:`~.RedisCluster` client """ - __slots__ = ("cluster_client", "_transaction", "_execution_strategy") + __slots__ = ("_command_stack", "_client") - def __init__( - self, client: RedisCluster, transaction: Optional[bool] = None - ) -> None: - self.cluster_client = client - self._transaction = transaction - self._execution_strategy: ExecutionStrategy = ( - PipelineStrategy(self) - if not self._transaction - else TransactionStrategy(self) - ) + def __init__(self, client: RedisCluster) -> None: + self._client = client + + self._command_stack: List["PipelineCommand"] = [] async def initialize(self) -> "ClusterPipeline": - await self._execution_strategy.initialize() + if self._client._initialize: + await self._client.initialize() + self._command_stack = [] return self async def __aenter__(self) -> "ClusterPipeline": return await self.initialize() async def 
__aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: - await self.reset() + self._command_stack = [] def __await__(self) -> Generator[Any, None, "ClusterPipeline"]: return self.initialize().__await__() + def __enter__(self) -> "ClusterPipeline": + self._command_stack = [] + return self + + def __exit__(self, exc_type: None, exc_value: None, traceback: None) -> None: + self._command_stack = [] + def __bool__(self) -> bool: - "Pipeline instances should always evaluate to True on Python 3+" - return True + return bool(self._command_stack) def __len__(self) -> int: - return len(self._execution_strategy) + return len(self._command_stack) def execute_command( self, *args: Union[KeyT, EncodableT], **kwargs: Any @@ -1608,7 +1447,10 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`] - Rest of the kwargs are passed to the Redis connection """ - return self._execution_strategy.execute_command(*args, **kwargs) + self._command_stack.append( + PipelineCommand(len(self._command_stack), *args, **kwargs) + ) + return self async def execute( self, raise_on_error: bool = True, allow_redirections: bool = True @@ -1616,7 +1458,7 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm """ Execute the pipeline. - It will retry the commands as specified by retries specified in :attr:`retry` + It will retry the commands as specified by :attr:`cluster_error_retry_attempts` & then raise an exception. :param raise_on_error: @@ -1628,294 +1470,35 @@ class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterComm :raises RedisClusterException: if target_nodes is not provided & the command can't be mapped to a slot """ - try: - return await self._execution_strategy.execute( - raise_on_error, allow_redirections - ) - finally: - await self.reset() - - def _split_command_across_slots( - self, command: str, *keys: KeyT - ) -> "ClusterPipeline": - for slot_keys in self.cluster_client._partition_keys_by_slot(keys).values(): - self.execute_command(command, *slot_keys) - - return self - - async def reset(self): - """ - Reset back to empty pipeline. - """ - await self._execution_strategy.reset() - - def multi(self): - """ - Start a transactional block of the pipeline after WATCH commands - are issued. End the transactional block with `execute`. 
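
A minimal usage sketch for the pipeline above: execute_command() only buffers PipelineCommand objects, and the buffered commands are grouped by node and sent when execute() is awaited. Key names are illustrative; rc is assumed to be a RedisCluster client.

async def record_hit(rc) -> list:
    async with rc.pipeline() as pipe:
        # Both keys share the {user:1} hash tag, so they map to one slot,
        # but a cluster pipeline may also span several nodes.
        pipe.incr("hits:{user:1}")
        pipe.expire("hits:{user:1}", 3600)
        return await pipe.execute()
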
- """ - self._execution_strategy.multi() - - async def discard(self): - """ """ - await self._execution_strategy.discard() - - async def watch(self, *names): - """Watches the values at keys ``names``""" - await self._execution_strategy.watch(*names) - - async def unwatch(self): - """Unwatches all previously specified keys""" - await self._execution_strategy.unwatch() - - async def unlink(self, *names): - await self._execution_strategy.unlink(*names) - - def mset_nonatomic( - self, mapping: Mapping[AnyKeyT, EncodableT] - ) -> "ClusterPipeline": - return self._execution_strategy.mset_nonatomic(mapping) - - -for command in PIPELINE_BLOCKED_COMMANDS: - command = command.replace(" ", "_").lower() - if command == "mset_nonatomic": - continue - - setattr(ClusterPipeline, command, block_pipeline_command(command)) - - -class PipelineCommand: - def __init__(self, position: int, *args: Any, **kwargs: Any) -> None: - self.args = args - self.kwargs = kwargs - self.position = position - self.result: Union[Any, Exception] = None - - def __repr__(self) -> str: - return f"[{self.position}] {self.args} ({self.kwargs})" - - -class ExecutionStrategy(ABC): - @abstractmethod - async def initialize(self) -> "ClusterPipeline": - """ - Initialize the execution strategy. - - See ClusterPipeline.initialize() - """ - pass - - @abstractmethod - def execute_command( - self, *args: Union[KeyT, EncodableT], **kwargs: Any - ) -> "ClusterPipeline": - """ - Append a raw command to the pipeline. - - See ClusterPipeline.execute_command() - """ - pass - - @abstractmethod - async def execute( - self, raise_on_error: bool = True, allow_redirections: bool = True - ) -> List[Any]: - """ - Execute the pipeline. - - It will retry the commands as specified by retries specified in :attr:`retry` - & then raise an exception. - - See ClusterPipeline.execute() - """ - pass - - @abstractmethod - def mset_nonatomic( - self, mapping: Mapping[AnyKeyT, EncodableT] - ) -> "ClusterPipeline": - """ - Executes multiple MSET commands according to the provided slot/pairs mapping. - - See ClusterPipeline.mset_nonatomic() - """ - pass - - @abstractmethod - async def reset(self): - """ - Resets current execution strategy. - - See: ClusterPipeline.reset() - """ - pass - - @abstractmethod - def multi(self): - """ - Starts transactional context. - - See: ClusterPipeline.multi() - """ - pass - - @abstractmethod - async def watch(self, *names): - """ - Watch given keys. 
- - See: ClusterPipeline.watch() - """ - pass - - @abstractmethod - async def unwatch(self): - """ - Unwatches all previously specified keys - - See: ClusterPipeline.unwatch() - """ - pass - - @abstractmethod - async def discard(self): - pass - - @abstractmethod - async def unlink(self, *names): - """ - "Unlink a key specified by ``names``" - - See: ClusterPipeline.unlink() - """ - pass - - @abstractmethod - def __len__(self) -> int: - pass - - -class AbstractStrategy(ExecutionStrategy): - def __init__(self, pipe: ClusterPipeline) -> None: - self._pipe: ClusterPipeline = pipe - self._command_queue: List["PipelineCommand"] = [] - - async def initialize(self) -> "ClusterPipeline": - if self._pipe.cluster_client._initialize: - await self._pipe.cluster_client.initialize() - self._command_queue = [] - return self._pipe - - def execute_command( - self, *args: Union[KeyT, EncodableT], **kwargs: Any - ) -> "ClusterPipeline": - self._command_queue.append( - PipelineCommand(len(self._command_queue), *args, **kwargs) - ) - return self._pipe - - def _annotate_exception(self, exception, number, command): - """ - Provides extra context to the exception prior to it being handled - """ - cmd = " ".join(map(safe_str, command)) - msg = ( - f"Command # {number} ({truncate_text(cmd)}) of pipeline " - f"caused error: {exception.args[0]}" - ) - exception.args = (msg,) + exception.args[1:] - - @abstractmethod - def mset_nonatomic( - self, mapping: Mapping[AnyKeyT, EncodableT] - ) -> "ClusterPipeline": - pass - - @abstractmethod - async def execute( - self, raise_on_error: bool = True, allow_redirections: bool = True - ) -> List[Any]: - pass - - @abstractmethod - async def reset(self): - pass - - @abstractmethod - def multi(self): - pass - - @abstractmethod - async def watch(self, *names): - pass - - @abstractmethod - async def unwatch(self): - pass - - @abstractmethod - async def discard(self): - pass - - @abstractmethod - async def unlink(self, *names): - pass - - def __len__(self) -> int: - return len(self._command_queue) - - -class PipelineStrategy(AbstractStrategy): - def __init__(self, pipe: ClusterPipeline) -> None: - super().__init__(pipe) - - def mset_nonatomic( - self, mapping: Mapping[AnyKeyT, EncodableT] - ) -> "ClusterPipeline": - encoder = self._pipe.cluster_client.encoder - - slots_pairs = {} - for pair in mapping.items(): - slot = key_slot(encoder.encode(pair[0])) - slots_pairs.setdefault(slot, []).extend(pair) - - for pairs in slots_pairs.values(): - self.execute_command("MSET", *pairs) - - return self._pipe - - async def execute( - self, raise_on_error: bool = True, allow_redirections: bool = True - ) -> List[Any]: - if not self._command_queue: + if not self._command_stack: return [] try: - retry_attempts = self._pipe.cluster_client.retry.get_retries() - while True: + for _ in range(self._client.cluster_error_retry_attempts): + if self._client._initialize: + await self._client.initialize() + try: - if self._pipe.cluster_client._initialize: - await self._pipe.cluster_client.initialize() return await self._execute( - self._pipe.cluster_client, - self._command_queue, + self._client, + self._command_stack, raise_on_error=raise_on_error, allow_redirections=allow_redirections, ) - - except RedisCluster.ERRORS_ALLOW_RETRY as e: - if retry_attempts > 0: - # Try again with the new cluster setup. All other errors - # should be raised. 
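
mset_nonatomic, shown above, partitions the mapping by hash slot and issues one MSET per slot, so the write is deliberately not atomic across the cluster. A hedged usage sketch with made-up keys; rc is assumed to be a RedisCluster client.

async def seed_defaults(rc) -> None:
    # One MSET is sent per hash slot; if a node fails mid-way, keys on other
    # slots may already have been written.
    pipe = rc.pipeline()
    pipe.mset_nonatomic({"cfg:a": "1", "cfg:b": "2", "cfg:c": "3"})
    await pipe.execute()
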
- retry_attempts -= 1 - await self._pipe.cluster_client.aclose() + except BaseException as e: + if type(e) in self.__class__.ERRORS_ALLOW_RETRY: + # Try again with the new cluster setup. + exception = e + await self._client.aclose() await asyncio.sleep(0.25) else: # All other errors should be raised. - raise e + raise + + # If it fails the configured number of times then raise an exception + raise exception finally: - await self.reset() + self._command_stack = [] async def _execute( self, @@ -1973,424 +1556,65 @@ class PipelineStrategy(AbstractStrategy): if isinstance(result, Exception): command = " ".join(map(safe_str, cmd.args)) msg = ( - f"Command # {cmd.position + 1} " - f"({truncate_text(command)}) " - f"of pipeline caused error: {result.args}" + f"Command # {cmd.position + 1} ({command}) of pipeline " + f"caused error: {result.args}" ) result.args = (msg,) + result.args[1:] raise result - default_cluster_node = client.get_default_node() - - # Check whether the default node was used. In some cases, - # 'client.get_default_node()' may return None. The check below - # prevents a potential AttributeError. - if default_cluster_node is not None: - default_node = nodes.get(default_cluster_node.name) - if default_node is not None: - # This pipeline execution used the default node, check if we need - # to replace it. - # Note: when the error is raised we'll reset the default node in the - # caller function. - for cmd in default_node[1]: - # Check if it has a command that failed with a relevant - # exception - if type(cmd.result) in RedisCluster.ERRORS_ALLOW_RETRY: - client.replace_default_node() - break + default_node = nodes.get(client.get_default_node().name) + if default_node is not None: + # This pipeline execution used the default node, check if we need + # to replace it. + # Note: when the error is raised we'll reset the default node in the + # caller function. + for cmd in default_node[1]: + # Check if it has a command that failed with a relevant + # exception + if type(cmd.result) in self.__class__.ERRORS_ALLOW_RETRY: + client.replace_default_node() + break return [cmd.result for cmd in stack] - async def reset(self): - """ - Reset back to empty pipeline. 
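
The error-annotation path above is what surfaces per-command failures. With raise_on_error=False, the exception objects are returned in the result list instead of being raised; a hedged sketch with illustrative key names.

async def tolerant_batch(rc) -> None:
    pipe = rc.pipeline()
    pipe.set("k1", "hello")
    pipe.incr("k1")  # fails server-side: value is not an integer
    results = await pipe.execute(raise_on_error=False)
    for result in results:
        if isinstance(result, Exception):
            print("command failed:", result)
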
- """ - self._command_queue = [] + def _split_command_across_slots( + self, command: str, *keys: KeyT + ) -> "ClusterPipeline": + for slot_keys in self._client._partition_keys_by_slot(keys).values(): + self.execute_command(command, *slot_keys) - def multi(self): - raise RedisClusterException( - "method multi() is not supported outside of transactional context" - ) - - async def watch(self, *names): - raise RedisClusterException( - "method watch() is not supported outside of transactional context" - ) - - async def unwatch(self): - raise RedisClusterException( - "method unwatch() is not supported outside of transactional context" - ) - - async def discard(self): - raise RedisClusterException( - "method discard() is not supported outside of transactional context" - ) - - async def unlink(self, *names): - if len(names) != 1: - raise RedisClusterException( - "unlinking multiple keys is not implemented in pipeline command" - ) - - return self.execute_command("UNLINK", names[0]) - - -class TransactionStrategy(AbstractStrategy): - NO_SLOTS_COMMANDS = {"UNWATCH"} - IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"} - UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} - SLOT_REDIRECT_ERRORS = (AskError, MovedError) - CONNECTION_ERRORS = ( - ConnectionError, - OSError, - ClusterDownError, - SlotNotCoveredError, - ) - - def __init__(self, pipe: ClusterPipeline) -> None: - super().__init__(pipe) - self._explicit_transaction = False - self._watching = False - self._pipeline_slots: Set[int] = set() - self._transaction_node: Optional[ClusterNode] = None - self._transaction_connection: Optional[Connection] = None - self._executing = False - self._retry = copy(self._pipe.cluster_client.retry) - self._retry.update_supported_errors( - RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS - ) - - def _get_client_and_connection_for_transaction( - self, - ) -> Tuple[ClusterNode, Connection]: - """ - Find a connection for a pipeline transaction. - - For running an atomic transaction, watch keys ensure that contents have not been - altered as long as the watch commands for those keys were sent over the same - connection. So once we start watching a key, we fetch a connection to the - node that owns that slot and reuse it. - """ - if not self._pipeline_slots: - raise RedisClusterException( - "At least a command with a key is needed to identify a node" - ) - - node: ClusterNode = self._pipe.cluster_client.nodes_manager.get_node_from_slot( - list(self._pipeline_slots)[0], False - ) - self._transaction_node = node - - if not self._transaction_connection: - connection: Connection = self._transaction_node.acquire_connection() - self._transaction_connection = connection - - return self._transaction_node, self._transaction_connection - - def execute_command(self, *args: Union[KeyT, EncodableT], **kwargs: Any) -> "Any": - # Given the limitation of ClusterPipeline sync API, we have to run it in thread. 
- response = None - error = None - - def runner(): - nonlocal response - nonlocal error - try: - response = asyncio.run(self._execute_command(*args, **kwargs)) - except Exception as e: - error = e - - thread = threading.Thread(target=runner) - thread.start() - thread.join() - - if error: - raise error - - return response - - async def _execute_command( - self, *args: Union[KeyT, EncodableT], **kwargs: Any - ) -> Any: - if self._pipe.cluster_client._initialize: - await self._pipe.cluster_client.initialize() - - slot_number: Optional[int] = None - if args[0] not in self.NO_SLOTS_COMMANDS: - slot_number = await self._pipe.cluster_client._determine_slot(*args) - - if ( - self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS - ) and not self._explicit_transaction: - if args[0] == "WATCH": - self._validate_watch() - - if slot_number is not None: - if self._pipeline_slots and slot_number not in self._pipeline_slots: - raise CrossSlotTransactionError( - "Cannot watch or send commands on different slots" - ) - - self._pipeline_slots.add(slot_number) - elif args[0] not in self.NO_SLOTS_COMMANDS: - raise RedisClusterException( - f"Cannot identify slot number for command: {args[0]}," - "it cannot be triggered in a transaction" - ) - - return self._immediate_execute_command(*args, **kwargs) - else: - if slot_number is not None: - self._pipeline_slots.add(slot_number) - - return super().execute_command(*args, **kwargs) - - def _validate_watch(self): - if self._explicit_transaction: - raise RedisError("Cannot issue a WATCH after a MULTI") - - self._watching = True - - async def _immediate_execute_command(self, *args, **options): - return await self._retry.call_with_retry( - lambda: self._get_connection_and_send_command(*args, **options), - self._reinitialize_on_error, - ) - - async def _get_connection_and_send_command(self, *args, **options): - redis_node, connection = self._get_client_and_connection_for_transaction() - return await self._send_command_parse_response( - connection, redis_node, args[0], *args, **options - ) - - async def _send_command_parse_response( - self, - connection: Connection, - redis_node: ClusterNode, - command_name, - *args, - **options, - ): - """ - Send a command and parse the response - """ - - await connection.send_command(*args) - output = await redis_node.parse_response(connection, command_name, **options) - - if command_name in self.UNWATCH_COMMANDS: - self._watching = False - return output - - async def _reinitialize_on_error(self, error): - if self._watching: - if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing: - raise WatchError("Slot rebalancing occurred while watching keys") - - if ( - type(error) in self.SLOT_REDIRECT_ERRORS - or type(error) in self.CONNECTION_ERRORS - ): - if self._transaction_connection: - self._transaction_connection = None - - self._pipe.cluster_client.reinitialize_counter += 1 - if ( - self._pipe.cluster_client.reinitialize_steps - and self._pipe.cluster_client.reinitialize_counter - % self._pipe.cluster_client.reinitialize_steps - == 0 - ): - await self._pipe.cluster_client.nodes_manager.initialize() - self.reinitialize_counter = 0 - else: - self._pipe.cluster_client.nodes_manager.update_moved_exception(error) - - self._executing = False - - def _raise_first_error(self, responses, stack): - """ - Raise the first exception on the stack - """ - for r, cmd in zip(responses, stack): - if isinstance(r, Exception): - self._annotate_exception(r, cmd.position + 1, cmd.args) - raise r + return self def mset_nonatomic( self, mapping: 
Mapping[AnyKeyT, EncodableT] ) -> "ClusterPipeline": - raise NotImplementedError("Method is not supported in transactional context.") + encoder = self._client.encoder - async def execute( - self, raise_on_error: bool = True, allow_redirections: bool = True - ) -> List[Any]: - stack = self._command_queue - if not stack and (not self._watching or not self._pipeline_slots): - return [] + slots_pairs = {} + for pair in mapping.items(): + slot = key_slot(encoder.encode(pair[0])) + slots_pairs.setdefault(slot, []).extend(pair) - return await self._execute_transaction_with_retries(stack, raise_on_error) + for pairs in slots_pairs.values(): + self.execute_command("MSET", *pairs) - async def _execute_transaction_with_retries( - self, stack: List["PipelineCommand"], raise_on_error: bool - ): - return await self._retry.call_with_retry( - lambda: self._execute_transaction(stack, raise_on_error), - self._reinitialize_on_error, - ) + return self - async def _execute_transaction( - self, stack: List["PipelineCommand"], raise_on_error: bool - ): - if len(self._pipeline_slots) > 1: - raise CrossSlotTransactionError( - "All keys involved in a cluster transaction must map to the same slot" - ) - self._executing = True +for command in PIPELINE_BLOCKED_COMMANDS: + command = command.replace(" ", "_").lower() + if command == "mset_nonatomic": + continue - redis_node, connection = self._get_client_and_connection_for_transaction() + setattr(ClusterPipeline, command, block_pipeline_command(command)) - stack = chain( - [PipelineCommand(0, "MULTI")], - stack, - [PipelineCommand(0, "EXEC")], - ) - commands = [c.args for c in stack if EMPTY_RESPONSE not in c.kwargs] - packed_commands = connection.pack_commands(commands) - await connection.send_packed_command(packed_commands) - errors = [] - # parse off the response for MULTI - # NOTE: we need to handle ResponseErrors here and continue - # so that we read all the additional command messages from - # the socket - try: - await redis_node.parse_response(connection, "MULTI") - except ResponseError as e: - self._annotate_exception(e, 0, "MULTI") - errors.append(e) - except self.CONNECTION_ERRORS as cluster_error: - self._annotate_exception(cluster_error, 0, "MULTI") - raise +class PipelineCommand: + def __init__(self, position: int, *args: Any, **kwargs: Any) -> None: + self.args = args + self.kwargs = kwargs + self.position = position + self.result: Union[Any, Exception] = None - # and all the other commands - for i, command in enumerate(self._command_queue): - if EMPTY_RESPONSE in command.kwargs: - errors.append((i, command.kwargs[EMPTY_RESPONSE])) - else: - try: - _ = await redis_node.parse_response(connection, "_") - except self.SLOT_REDIRECT_ERRORS as slot_error: - self._annotate_exception(slot_error, i + 1, command.args) - errors.append(slot_error) - except self.CONNECTION_ERRORS as cluster_error: - self._annotate_exception(cluster_error, i + 1, command.args) - raise - except ResponseError as e: - self._annotate_exception(e, i + 1, command.args) - errors.append(e) - - response = None - # parse the EXEC. 
- try: - response = await redis_node.parse_response(connection, "EXEC") - except ExecAbortError: - if errors: - raise errors[0] - raise - - self._executing = False - - # EXEC clears any watched keys - self._watching = False - - if response is None: - raise WatchError("Watched variable changed.") - - # put any parse errors into the response - for i, e in errors: - response.insert(i, e) - - if len(response) != len(self._command_queue): - raise InvalidPipelineStack( - "Unexpected response length for cluster pipeline EXEC." - " Command stack was {} but response had length {}".format( - [c.args[0] for c in self._command_queue], len(response) - ) - ) - - # find any errors in the response and raise if necessary - if raise_on_error or len(errors) > 0: - self._raise_first_error( - response, - self._command_queue, - ) - - # We have to run response callbacks manually - data = [] - for r, cmd in zip(response, self._command_queue): - if not isinstance(r, Exception): - command_name = cmd.args[0] - if command_name in self._pipe.cluster_client.response_callbacks: - r = self._pipe.cluster_client.response_callbacks[command_name]( - r, **cmd.kwargs - ) - data.append(r) - return data - - async def reset(self): - self._command_queue = [] - - # make sure to reset the connection state in the event that we were - # watching something - if self._transaction_connection: - try: - if self._watching: - # call this manually since our unwatch or - # immediate_execute_command methods can call reset() - await self._transaction_connection.send_command("UNWATCH") - await self._transaction_connection.read_response() - # we can safely return the connection to the pool here since we're - # sure we're no longer WATCHing anything - self._transaction_node.release(self._transaction_connection) - self._transaction_connection = None - except self.CONNECTION_ERRORS: - # disconnect will also remove any previous WATCHes - if self._transaction_connection: - await self._transaction_connection.disconnect() - - # clean up the other instance attributes - self._transaction_node = None - self._watching = False - self._explicit_transaction = False - self._pipeline_slots = set() - self._executing = False - - def multi(self): - if self._explicit_transaction: - raise RedisError("Cannot issue nested calls to MULTI") - if self._command_queue: - raise RedisError( - "Commands without an initial WATCH have already been issued" - ) - self._explicit_transaction = True - - async def watch(self, *names): - if self._explicit_transaction: - raise RedisError("Cannot issue a WATCH after a MULTI") - - return await self.execute_command("WATCH", *names) - - async def unwatch(self): - if self._watching: - return await self.execute_command("UNWATCH") - - return True - - async def discard(self): - await self.reset() - - async def unlink(self, *names): - return self.execute_command("UNLINK", *names) + def __repr__(self) -> str: + return f"[{self.position}] {self.args} ({self.kwargs})" diff --git a/venv/lib/python3.12/site-packages/redis/asyncio/connection.py b/venv/lib/python3.12/site-packages/redis/asyncio/connection.py index 4efd868..65fa586 100644 --- a/venv/lib/python3.12/site-packages/redis/asyncio/connection.py +++ b/venv/lib/python3.12/site-packages/redis/asyncio/connection.py @@ -3,8 +3,8 @@ import copy import enum import inspect import socket +import ssl import sys -import warnings import weakref from abc import abstractmethod from itertools import chain @@ -16,30 +16,14 @@ from typing import ( List, Mapping, Optional, - Protocol, Set, Tuple, Type, - 
TypedDict, TypeVar, Union, ) from urllib.parse import ParseResult, parse_qs, unquote, urlparse -from ..utils import SSL_AVAILABLE - -if SSL_AVAILABLE: - import ssl - from ssl import SSLContext, TLSVersion -else: - ssl = None - TLSVersion = None - SSLContext = None - -from ..auth.token import TokenInterface -from ..event import AsyncAfterConnectionReleasedEvent, EventDispatcher -from ..utils import deprecated_args, format_error_message - # the functionality is available in 3.11.x but has a major issue before # 3.11.3. See https://github.com/redis/redis-py/issues/2633 if sys.version_info >= (3, 11, 3): @@ -49,6 +33,7 @@ else: from redis.asyncio.retry import Retry from redis.backoff import NoBackoff +from redis.compat import Protocol, TypedDict from redis.connection import DEFAULT_RESP_VERSION from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider from redis.exceptions import ( @@ -93,11 +78,13 @@ else: class ConnectCallbackProtocol(Protocol): - def __call__(self, connection: "AbstractConnection"): ... + def __call__(self, connection: "AbstractConnection"): + ... class AsyncConnectCallbackProtocol(Protocol): - async def __call__(self, connection: "AbstractConnection"): ... + async def __call__(self, connection: "AbstractConnection"): + ... ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol] @@ -159,7 +146,6 @@ class AbstractConnection: encoder_class: Type[Encoder] = Encoder, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - event_dispatcher: Optional[EventDispatcher] = None, ): if (username or password) and credential_provider is not None: raise DataError( @@ -168,10 +154,6 @@ class AbstractConnection: "1. 'password' and (optional) 'username'\n" "2. 'credential_provider'" ) - if event_dispatcher is None: - self._event_dispatcher = EventDispatcher() - else: - self._event_dispatcher = event_dispatcher self.db = db self.client_name = client_name self.lib_name = lib_name @@ -211,8 +193,6 @@ class AbstractConnection: self.set_parser(parser_class) self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = [] self._buffer_cutoff = 6000 - self._re_auth_token: Optional[TokenInterface] = None - try: p = int(protocol) except TypeError: @@ -224,33 +204,9 @@ class AbstractConnection: raise ConnectionError("protocol must be either 2 or 3") self.protocol = protocol - def __del__(self, _warnings: Any = warnings): - # For some reason, the individual streams don't get properly garbage - # collected and therefore produce no resource warnings. We add one - # here, in the same style as those from the stdlib. - if getattr(self, "_writer", None): - _warnings.warn( - f"unclosed Connection {self!r}", ResourceWarning, source=self - ) - - try: - asyncio.get_running_loop() - self._close() - except RuntimeError: - # No actions been taken if pool already closed. 
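
A short sketch of the credential_provider option validated above: only one of password/username or a credential provider may be supplied, otherwise a DataError is raised. Host, port, and credentials below are placeholders.

import redis.asyncio as redis
from redis.credentials import UsernamePasswordCredentialProvider

# The provider supplies the (username, password) pair during on_connect().
creds = UsernamePasswordCredentialProvider("app-user", "app-secret")
client = redis.Redis(host="localhost", port=6379, credential_provider=creds)
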
- pass - - def _close(self): - """ - Internal method to silently close the connection without waiting - """ - if self._writer: - self._writer.close() - self._writer = self._reader = None - def __repr__(self): repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces())) - return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>" + return f"{self.__class__.__name__}<{repr_args}>" @abstractmethod def repr_pieces(self): @@ -260,24 +216,12 @@ class AbstractConnection: def is_connected(self): return self._reader is not None and self._writer is not None - def register_connect_callback(self, callback): - """ - Register a callback to be called when the connection is established either - initially or reconnected. This allows listeners to issue commands that - are ephemeral to the connection, for example pub/sub subscription or - key tracking. The callback must be a _method_ and will be kept as - a weak reference. - """ + def _register_connect_callback(self, callback): wm = weakref.WeakMethod(callback) if wm not in self._connect_callbacks: self._connect_callbacks.append(wm) - def deregister_connect_callback(self, callback): - """ - De-register a previously registered callback. It will no-longer receive - notifications on connection events. Calling this is not required when the - listener goes away, since the callbacks are kept as weak methods. - """ + def _deregister_connect_callback(self, callback): try: self._connect_callbacks.remove(weakref.WeakMethod(callback)) except ValueError: @@ -293,20 +237,12 @@ class AbstractConnection: async def connect(self): """Connects to the Redis server if not already connected""" - await self.connect_check_health(check_health=True) - - async def connect_check_health( - self, check_health: bool = True, retry_socket_connect: bool = True - ): if self.is_connected: return try: - if retry_socket_connect: - await self.retry.call_with_retry( - lambda: self._connect(), lambda error: self.disconnect() - ) - else: - await self._connect() + await self.retry.call_with_retry( + lambda: self._connect(), lambda error: self.disconnect() + ) except asyncio.CancelledError: raise # in 3.7 and earlier, this is an Exception, not BaseException except (socket.timeout, asyncio.TimeoutError): @@ -319,14 +255,12 @@ class AbstractConnection: try: if not self.redis_connect_func: # Use the default on_connect function - await self.on_connect_check_health(check_health=check_health) + await self.on_connect() else: # Use the passed function redis_connect_func - ( - await self.redis_connect_func(self) - if asyncio.iscoroutinefunction(self.redis_connect_func) - else self.redis_connect_func(self) - ) + await self.redis_connect_func(self) if asyncio.iscoroutinefunction( + self.redis_connect_func + ) else self.redis_connect_func(self) except RedisError: # clean up after any error in on_connect await self.disconnect() @@ -350,17 +284,12 @@ class AbstractConnection: def _host_error(self) -> str: pass + @abstractmethod def _error_message(self, exception: BaseException) -> str: - return format_error_message(self._host_error(), exception) - - def get_protocol(self): - return self.protocol + pass async def on_connect(self) -> None: """Initialize the connection, authenticate and select a database""" - await self.on_connect_check_health(check_health=True) - - async def on_connect_check_health(self, check_health: bool = True) -> None: self._parser.on_connect(self) parser = self._parser @@ -371,8 +300,7 @@ class AbstractConnection: self.credential_provider or 
UsernamePasswordCredentialProvider(self.username, self.password) ) - auth_args = await cred_provider.get_credentials_async() - + auth_args = cred_provider.get_credentials() # if resp version is specified and we have auth args, # we need to send them via HELLO if auth_args and self.protocol not in [2, "2"]: @@ -383,11 +311,7 @@ class AbstractConnection: self._parser.on_connect(self) if len(auth_args) == 1: auth_args = ["default", auth_args[0]] - # avoid checking health here -- PING will fail if we try - # to check the health prior to the AUTH - await self.send_command( - "HELLO", self.protocol, "AUTH", *auth_args, check_health=False - ) + await self.send_command("HELLO", self.protocol, "AUTH", *auth_args) response = await self.read_response() if response.get(b"proto") != int(self.protocol) and response.get( "proto" @@ -418,7 +342,7 @@ class AbstractConnection: # update cluster exception classes self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES self._parser.on_connect(self) - await self.send_command("HELLO", self.protocol, check_health=check_health) + await self.send_command("HELLO", self.protocol) response = await self.read_response() # if response.get(b"proto") != self.protocol and response.get( # "proto" @@ -427,35 +351,18 @@ class AbstractConnection: # if a client_name is given, set it if self.client_name: - await self.send_command( - "CLIENT", - "SETNAME", - self.client_name, - check_health=check_health, - ) + await self.send_command("CLIENT", "SETNAME", self.client_name) if str_if_bytes(await self.read_response()) != "OK": raise ConnectionError("Error setting client name") # set the library name and version, pipeline for lower startup latency if self.lib_name: - await self.send_command( - "CLIENT", - "SETINFO", - "LIB-NAME", - self.lib_name, - check_health=check_health, - ) + await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name) if self.lib_version: - await self.send_command( - "CLIENT", - "SETINFO", - "LIB-VER", - self.lib_version, - check_health=check_health, - ) + await self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version) # if a database is specified, switch to it. 
Also pipeline this if self.db: - await self.send_command("SELECT", self.db, check_health=check_health) + await self.send_command("SELECT", self.db) # read responses from pipeline for _ in (sent for sent in (self.lib_name, self.lib_version) if sent): @@ -517,8 +424,8 @@ class AbstractConnection: self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True ) -> None: if not self.is_connected: - await self.connect_check_health(check_health=False) - if check_health: + await self.connect() + elif check_health: await self.check_health() try: @@ -581,7 +488,11 @@ class AbstractConnection: read_timeout = timeout if timeout is not None else self.socket_timeout host_error = self._host_error() try: - if read_timeout is not None and self.protocol in ["3", 3]: + if ( + read_timeout is not None + and self.protocol in ["3", 3] + and not HIREDIS_AVAILABLE + ): async with async_timeout(read_timeout): response = await self._parser.read_response( disable_decoding=disable_decoding, push_request=push_request @@ -591,7 +502,7 @@ class AbstractConnection: response = await self._parser.read_response( disable_decoding=disable_decoding ) - elif self.protocol in ["3", 3]: + elif self.protocol in ["3", 3] and not HIREDIS_AVAILABLE: response = await self._parser.read_response( disable_decoding=disable_decoding, push_request=push_request ) @@ -703,27 +614,6 @@ class AbstractConnection: output.append(SYM_EMPTY.join(pieces)) return output - def _socket_is_empty(self): - """Check if the socket is empty""" - return len(self._reader._buffer) == 0 - - async def process_invalidation_messages(self): - while not self._socket_is_empty(): - await self.read_response(push_request=True) - - def set_re_auth_token(self, token: TokenInterface): - self._re_auth_token = token - - async def re_auth(self): - if self._re_auth_token is not None: - await self.send_command( - "AUTH", - self._re_auth_token.try_get("oid"), - self._re_auth_token.get_value(), - ) - await self.read_response() - self._re_auth_token = None - class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" @@ -781,6 +671,27 @@ class Connection(AbstractConnection): def _host_error(self) -> str: return f"{self.host}:{self.port}" + def _error_message(self, exception: BaseException) -> str: + # args for socket.error can either be (errno, "message") + # or just "message" + + host_error = self._host_error() + + if not exception.args: + # asyncio has a bug where on Connection reset by peer, the + # exception is not instanciated, so args is empty. This is the + # workaround. + # See: https://github.com/redis/redis-py/issues/2237 + # See: https://github.com/python/cpython/issues/94061 + return f"Error connecting to {host_error}. Connection reset by peer" + elif len(exception.args) == 1: + return f"Error connecting to {host_error}. {exception.args[0]}." + else: + return ( + f"Error {exception.args[0]} connecting to {host_error}. " + f"{exception.args[0]}." + ) + class SSLConnection(Connection): """Manages SSL connections to and from the Redis server(s). 
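
The ssl_* keyword arguments consumed by SSLConnection here are normally passed through the high-level client rather than constructed directly. An illustrative sketch; the endpoint and CA path are placeholders.

import redis.asyncio as redis

client = redis.Redis(
    host="redis.example.com",
    port=6380,
    ssl=True,                              # selects SSLConnection
    ssl_cert_reqs="required",
    ssl_ca_certs="/etc/ssl/certs/ca.pem",  # placeholder path
    ssl_check_hostname=True,
)
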
@@ -792,17 +703,12 @@ class SSLConnection(Connection): self, ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None, - ssl_cert_reqs: Union[str, ssl.VerifyMode] = "required", + ssl_cert_reqs: str = "required", ssl_ca_certs: Optional[str] = None, ssl_ca_data: Optional[str] = None, - ssl_check_hostname: bool = True, - ssl_min_version: Optional[TLSVersion] = None, - ssl_ciphers: Optional[str] = None, + ssl_check_hostname: bool = False, **kwargs, ): - if not SSL_AVAILABLE: - raise RedisError("Python wasn't built with SSL support") - self.ssl_context: RedisSSLContext = RedisSSLContext( keyfile=ssl_keyfile, certfile=ssl_certfile, @@ -810,8 +716,6 @@ class SSLConnection(Connection): ca_certs=ssl_ca_certs, ca_data=ssl_ca_data, check_hostname=ssl_check_hostname, - min_version=ssl_min_version, - ciphers=ssl_ciphers, ) super().__init__(**kwargs) @@ -844,10 +748,6 @@ class SSLConnection(Connection): def check_hostname(self): return self.ssl_context.check_hostname - @property - def min_version(self): - return self.ssl_context.min_version - class RedisSSLContext: __slots__ = ( @@ -858,30 +758,23 @@ class RedisSSLContext: "ca_data", "context", "check_hostname", - "min_version", - "ciphers", ) def __init__( self, keyfile: Optional[str] = None, certfile: Optional[str] = None, - cert_reqs: Optional[Union[str, ssl.VerifyMode]] = None, + cert_reqs: Optional[str] = None, ca_certs: Optional[str] = None, ca_data: Optional[str] = None, check_hostname: bool = False, - min_version: Optional[TLSVersion] = None, - ciphers: Optional[str] = None, ): - if not SSL_AVAILABLE: - raise RedisError("Python wasn't built with SSL support") - self.keyfile = keyfile self.certfile = certfile if cert_reqs is None: - cert_reqs = ssl.CERT_NONE + self.cert_reqs = ssl.CERT_NONE elif isinstance(cert_reqs, str): - CERT_REQS = { # noqa: N806 + CERT_REQS = { "none": ssl.CERT_NONE, "optional": ssl.CERT_OPTIONAL, "required": ssl.CERT_REQUIRED, @@ -890,18 +783,13 @@ class RedisSSLContext: raise RedisError( f"Invalid SSL Certificate Requirements Flag: {cert_reqs}" ) - cert_reqs = CERT_REQS[cert_reqs] - self.cert_reqs = cert_reqs + self.cert_reqs = CERT_REQS[cert_reqs] self.ca_certs = ca_certs self.ca_data = ca_data - self.check_hostname = ( - check_hostname if self.cert_reqs != ssl.CERT_NONE else False - ) - self.min_version = min_version - self.ciphers = ciphers - self.context: Optional[SSLContext] = None + self.check_hostname = check_hostname + self.context: Optional[ssl.SSLContext] = None - def get(self) -> SSLContext: + def get(self) -> ssl.SSLContext: if not self.context: context = ssl.create_default_context() context.check_hostname = self.check_hostname @@ -910,10 +798,6 @@ class RedisSSLContext: context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile) if self.ca_certs or self.ca_data: context.load_verify_locations(cafile=self.ca_certs, cadata=self.ca_data) - if self.min_version is not None: - context.minimum_version = self.min_version - if self.ciphers is not None: - context.set_ciphers(self.ciphers) self.context = context return self.context @@ -941,6 +825,20 @@ class UnixDomainSocketConnection(AbstractConnection): def _host_error(self) -> str: return self.path + def _error_message(self, exception: BaseException) -> str: + # args for socket.error can either be (errno, "message") + # or just "message" + host_error = self._host_error() + if len(exception.args) == 1: + return ( + f"Error connecting to unix socket: {host_error}. {exception.args[0]}." 
+ ) + else: + return ( + f"Error {exception.args[0]} connecting to unix socket: " + f"{host_error}. {exception.args[1]}." + ) + FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") @@ -963,7 +861,6 @@ URL_QUERY_ARGUMENT_PARSERS: Mapping[str, Callable[..., object]] = MappingProxyTy "max_connections": int, "health_check_interval": int, "ssl_check_hostname": to_bool, - "timeout": float, } ) @@ -990,7 +887,7 @@ def parse_url(url: str) -> ConnectKwargs: try: kwargs[name] = parser(value) except (TypeError, ValueError): - raise ValueError(f"Invalid value for '{name}' in connection URL.") + raise ValueError(f"Invalid value for `{name}` in connection URL.") else: kwargs[name] = value @@ -1042,7 +939,6 @@ class ConnectionPool: By default, TCP connections are created unless ``connection_class`` is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for unix sockets. - :py:class:`~redis.SSLConnection` can be used for SSL enabled connections. Any additional keyword arguments are passed to the constructor of ``connection_class``. @@ -1112,22 +1008,16 @@ class ConnectionPool: self._available_connections: List[AbstractConnection] = [] self._in_use_connections: Set[AbstractConnection] = set() self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder) - self._lock = asyncio.Lock() - self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) - if self._event_dispatcher is None: - self._event_dispatcher = EventDispatcher() def __repr__(self): - conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()]) return ( - f"<{self.__class__.__module__}.{self.__class__.__name__}" - f"(<{self.connection_class.__module__}.{self.connection_class.__name__}" - f"({conn_kwargs})>)>" + f"{self.__class__.__name__}" + f"<{self.connection_class(**self.connection_kwargs)!r}>" ) def reset(self): self._available_connections = [] - self._in_use_connections = weakref.WeakSet() + self._in_use_connections = set() def can_get_connection(self) -> bool: """Return True if a connection can be retrieved from the pool.""" @@ -1136,25 +1026,8 @@ class ConnectionPool: or len(self._in_use_connections) < self.max_connections ) - @deprecated_args( - args_to_warn=["*"], - reason="Use get_connection() without args instead", - version="5.3.0", - ) - async def get_connection(self, command_name=None, *keys, **options): - async with self._lock: - """Get a connected connection from the pool""" - connection = self.get_available_connection() - try: - await self.ensure_connection(connection) - except BaseException: - await self.release(connection) - raise - - return connection - - def get_available_connection(self): - """Get a connection from the pool, without making sure it is connected""" + async def get_connection(self, command_name, *keys, **options): + """Get a connection from the pool""" try: connection = self._available_connections.pop() except IndexError: @@ -1162,6 +1035,13 @@ class ConnectionPool: raise ConnectionError("Too many connections") from None connection = self.make_connection() self._in_use_connections.add(connection) + + try: + await self.ensure_connection(connection) + except BaseException: + await self.release(connection) + raise + return connection def get_encoder(self): @@ -1187,7 +1067,7 @@ class ConnectionPool: try: if await connection.can_read_destructive(): raise ConnectionError("Connection has data") from None - except (ConnectionError, TimeoutError, OSError): + except (ConnectionError, OSError): await connection.disconnect() await connection.connect() if await 
connection.can_read_destructive(): @@ -1199,9 +1079,6 @@ # not doing so is an error that will cause an exception here. self._in_use_connections.remove(connection) self._available_connections.append(connection) - await self._event_dispatcher.dispatch_async( - AsyncAfterConnectionReleasedEvent(connection) - ) async def disconnect(self, inuse_connections: bool = True): """ @@ -1235,29 +1112,6 @@ for conn in self._in_use_connections: conn.retry = retry - async def re_auth_callback(self, token: TokenInterface): - async with self._lock: - for conn in self._available_connections: - await conn.retry.call_with_retry( - lambda: conn.send_command( - "AUTH", token.try_get("oid"), token.get_value() - ), - lambda error: self._mock(error), - ) - await conn.retry.call_with_retry( - lambda: conn.read_response(), lambda error: self._mock(error) - ) - for conn in self._in_use_connections: - conn.set_re_auth_token(token) - - async def _mock(self, error: RedisError): - """ - Dummy functions, needs to be passed as error callback to retry object. - :param error: - :return: - """ - pass - class BlockingConnectionPool(ConnectionPool): """ @@ -1275,7 +1129,7 @@ class BlockingConnectionPool(ConnectionPool): connection from the pool when all of connections are in use, rather than raising a :py:class:`~redis.ConnectionError` (as the default :py:class:`~redis.asyncio.ConnectionPool` implementation does), it - blocks the current `Task` for a specified number of seconds until + blocks the current `Task` for a specified number of seconds until a connection becomes available. Use ``max_connections`` to increase / decrease the pool size:: @@ -1309,29 +1163,16 @@ self._condition = asyncio.Condition() self.timeout = timeout - @deprecated_args( - args_to_warn=["*"], - reason="Use get_connection() without args instead", - version="5.3.0", - ) - async def get_connection(self, command_name=None, *keys, **options): + async def get_connection(self, command_name, *keys, **options): """Gets a connection from the pool, blocking until one is available""" try: - async with self._condition: - async with async_timeout(self.timeout): + async with async_timeout(self.timeout): + async with self._condition: await self._condition.wait_for(self.can_get_connection) - connection = super().get_available_connection() + return await super().get_connection(command_name, *keys, **options) except asyncio.TimeoutError as err: raise ConnectionError("No connection available.") from err - # We now perform the connection check outside of the lock.
- try: - await self.ensure_connection(connection) - return connection - except BaseException: - await self.release(connection) - raise - async def release(self, connection: AbstractConnection): """Releases the connection back to the pool.""" async with self._condition: diff --git a/venv/lib/python3.12/site-packages/redis/asyncio/lock.py b/venv/lib/python3.12/site-packages/redis/asyncio/lock.py index 16d7fb6..e1d11a8 100644 --- a/venv/lib/python3.12/site-packages/redis/asyncio/lock.py +++ b/venv/lib/python3.12/site-packages/redis/asyncio/lock.py @@ -1,18 +1,14 @@ import asyncio -import logging import threading import uuid from types import SimpleNamespace from typing import TYPE_CHECKING, Awaitable, Optional, Union from redis.exceptions import LockError, LockNotOwnedError -from redis.typing import Number if TYPE_CHECKING: from redis.asyncio import Redis, RedisCluster -logger = logging.getLogger(__name__) - class Lock: """ @@ -86,9 +82,8 @@ class Lock: timeout: Optional[float] = None, sleep: float = 0.1, blocking: bool = True, - blocking_timeout: Optional[Number] = None, + blocking_timeout: Optional[float] = None, thread_local: bool = True, - raise_on_release_error: bool = True, ): """ Create a new Lock instance named ``name`` using the Redis client @@ -132,11 +127,6 @@ class Lock: thread-1 would see the token value as "xyz" and would be able to successfully release the thread-2's lock. - ``raise_on_release_error`` indicates whether to raise an exception when - the lock is no longer owned when exiting the context manager. By default, - this is True, meaning an exception will be raised. If False, the warning - will be logged and the exception will be suppressed. - In some use cases it's necessary to disable thread local storage. For example, if you have code where one thread acquires a lock and passes that lock instance to a worker thread to release later. If thread @@ -153,7 +143,6 @@ class Lock: self.blocking_timeout = blocking_timeout self.thread_local = bool(thread_local) self.local = threading.local() if self.thread_local else SimpleNamespace() - self.raise_on_release_error = raise_on_release_error self.local.token = None self.register_scripts() @@ -173,19 +162,12 @@ class Lock: raise LockError("Unable to acquire lock within the time specified") async def __aexit__(self, exc_type, exc_value, traceback): - try: - await self.release() - except LockError: - if self.raise_on_release_error: - raise - logger.warning( - "Lock was unlocked or no longer owned when exiting context manager." - ) + await self.release() async def acquire( self, blocking: Optional[bool] = None, - blocking_timeout: Optional[Number] = None, + blocking_timeout: Optional[float] = None, token: Optional[Union[str, bytes]] = None, ): """ @@ -267,10 +249,7 @@ class Lock: """Releases the already acquired lock""" expected_token = self.local.token if expected_token is None: - raise LockError( - "Cannot release a lock that's not owned or is already unlocked.", - lock_name=self.name, - ) + raise LockError("Cannot release an unlocked lock") self.local.token = None return self.do_release(expected_token) @@ -283,7 +262,7 @@ class Lock: raise LockNotOwnedError("Cannot release a lock that's no longer owned") def extend( - self, additional_time: Number, replace_ttl: bool = False + self, additional_time: float, replace_ttl: bool = False ) -> Awaitable[bool]: """ Adds more time to an already acquired lock. 
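The lock.py hunks above revert redis.asyncio.lock.Lock to the simpler behaviour: __aexit__ just awaits release(), there is no raise_on_release_error switch, and an acquire timeout or the release of an expired lock surfaces as LockError. A minimal usage sketch under that assumption follows; the connection URL and the key name "my-lock" are illustrative placeholders, not values taken from this diff.

import asyncio

from redis.asyncio import Redis
from redis.exceptions import LockError


async def main() -> None:
    client = Redis.from_url("redis://localhost:6379/0")  # placeholder URL
    try:
        # timeout=5 lets the lock expire if its holder dies;
        # blocking_timeout=1 bounds how long acquire() may wait.
        async with client.lock("my-lock", timeout=5, blocking_timeout=1):
            await asyncio.sleep(0.1)  # critical section
    except LockError:
        # With the reverted code path, failing to acquire in time or
        # releasing a lock that has already expired ends up here.
        pass
    finally:
        await client.close()


asyncio.run(main())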
diff --git a/venv/lib/python3.12/site-packages/redis/asyncio/retry.py b/venv/lib/python3.12/site-packages/redis/asyncio/retry.py index 98b2d9c..7c5e3b0 100644 --- a/venv/lib/python3.12/site-packages/redis/asyncio/retry.py +++ b/venv/lib/python3.12/site-packages/redis/asyncio/retry.py @@ -2,16 +2,18 @@ from asyncio import sleep from typing import TYPE_CHECKING, Any, Awaitable, Callable, Tuple, Type, TypeVar from redis.exceptions import ConnectionError, RedisError, TimeoutError -from redis.retry import AbstractRetry - -T = TypeVar("T") if TYPE_CHECKING: from redis.backoff import AbstractBackoff -class Retry(AbstractRetry[RedisError]): - __hash__ = AbstractRetry.__hash__ +T = TypeVar("T") + + +class Retry: + """Retry a specific number of times after a failure""" + + __slots__ = "_backoff", "_retries", "_supported_errors" def __init__( self, @@ -22,16 +24,23 @@ class Retry(AbstractRetry[RedisError]): TimeoutError, ), ): - super().__init__(backoff, retries, supported_errors) + """ + Initialize a `Retry` object with a `Backoff` object + that retries a maximum of `retries` times. + `retries` can be negative to retry forever. + You can specify the types of supported errors which trigger + a retry with the `supported_errors` parameter. + """ + self._backoff = backoff + self._retries = retries + self._supported_errors = supported_errors - def __eq__(self, other: Any) -> bool: - if not isinstance(other, Retry): - return NotImplemented - - return ( - self._backoff == other._backoff - and self._retries == other._retries - and set(self._supported_errors) == set(other._supported_errors) + def update_supported_errors(self, specified_errors: list): + """ + Updates the supported errors with the specified error types + """ + self._supported_errors = tuple( + set(self._supported_errors + tuple(specified_errors)) ) async def call_with_retry( diff --git a/venv/lib/python3.12/site-packages/redis/asyncio/sentinel.py b/venv/lib/python3.12/site-packages/redis/asyncio/sentinel.py index d0455ab..6834fb1 100644 --- a/venv/lib/python3.12/site-packages/redis/asyncio/sentinel.py +++ b/venv/lib/python3.12/site-packages/redis/asyncio/sentinel.py @@ -11,12 +11,8 @@ from redis.asyncio.connection import ( SSLConnection, ) from redis.commands import AsyncSentinelCommands -from redis.exceptions import ( - ConnectionError, - ReadOnlyError, - ResponseError, - TimeoutError, -) +from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError +from redis.utils import str_if_bytes class MasterNotFoundError(ConnectionError): @@ -33,18 +29,20 @@ class SentinelManagedConnection(Connection): super().__init__(**kwargs) def __repr__(self): - s = f"<{self.__class__.__module__}.{self.__class__.__name__}" + pool = self.connection_pool + s = f"{self.__class__.__name__}" + return s + ">" async def connect_to(self, address): self.host, self.port = address - await self.connect_check_health( - check_health=self.connection_pool.check_connection, - retry_socket_connect=False, - ) + await super().connect() + if self.connection_pool.check_connection: + await self.send_command("PING") + if str_if_bytes(await self.read_response()) != "PONG": + raise ConnectionError("PING failed") async def _connect_retry(self): if self._reader: @@ -107,11 +105,9 @@ class SentinelConnectionPool(ConnectionPool): def __init__(self, service_name, sentinel_manager, **kwargs): kwargs["connection_class"] = kwargs.get( "connection_class", - ( - SentinelManagedSSLConnection - if kwargs.pop("ssl", False) - else SentinelManagedConnection - ), + 
SentinelManagedSSLConnection + if kwargs.pop("ssl", False) + else SentinelManagedConnection, ) self.is_master = kwargs.pop("is_master", True) self.check_connection = kwargs.pop("check_connection", False) @@ -124,8 +120,8 @@ class SentinelConnectionPool(ConnectionPool): def __repr__(self): return ( - f"<{self.__class__.__module__}.{self.__class__.__name__}" - f"(service={self.service_name}({self.is_master and 'master' or 'slave'}))>" + f"{self.__class__.__name__}" + f"<service={self.service_name}({self.is_master and 'master' or 'slave'})>" ) def reset(self): @@ -201,7 +197,6 @@ class Sentinel(AsyncSentinelCommands): sentinels, min_other_sentinels=0, sentinel_kwargs=None, - force_master_ip=None, **connection_kwargs, ): # if sentinel_kwargs isn't defined, use the socket_* options from @@ -218,7 +213,6 @@ class Sentinel(AsyncSentinelCommands): ] self.min_other_sentinels = min_other_sentinels self.connection_kwargs = connection_kwargs - self._force_master_ip = force_master_ip async def execute_command(self, *args, **kwargs): """ @@ -226,31 +220,19 @@ class Sentinel(AsyncSentinelCommands): once - If set to True, then execute the resulting command on a single node at random, rather than across the entire sentinel cluster. """ - once = bool(kwargs.pop("once", False)) - - # Check if command is supposed to return the original - # responses instead of boolean value. - return_responses = bool(kwargs.pop("return_responses", False)) + once = bool(kwargs.get("once", False)) + if "once" in kwargs.keys(): + kwargs.pop("once") if once: - response = await random.choice(self.sentinels).execute_command( - *args, **kwargs - ) - if return_responses: - return [response] - else: - return True if response else False - - tasks = [ - asyncio.Task(sentinel.execute_command(*args, **kwargs)) - for sentinel in self.sentinels - ] - responses = await asyncio.gather(*tasks) - - if return_responses: - return responses - - return all(responses) + await random.choice(self.sentinels).execute_command(*args, **kwargs) + else: + tasks = [ + asyncio.Task(sentinel.execute_command(*args, **kwargs)) + for sentinel in self.sentinels + ] + await asyncio.gather(*tasks) + return True def __repr__(self): sentinel_addresses = [] @@ -259,10 +241,7 @@ f"{sentinel.connection_pool.connection_kwargs['host']}:" f"{sentinel.connection_pool.connection_kwargs['port']}" ) - return ( - f"<{self.__class__}.{self.__class__.__name__}" - f"(sentinels=[{','.join(sentinel_addresses)}])>" - ) + return f"{self.__class__.__name__}<sentinels=[{','.join(sentinel_addresses)}]>" def check_master_state(self, state: dict, service_name: str) -> bool: if not state["is_master"] or state["is_sdown"] or state["is_odown"]: @@ -294,13 +273,7 @@ sentinel, self.sentinels[0], ) - - ip = ( - self._force_master_ip - if self._force_master_ip is not None - else state["ip"] - ) - return ip, state["port"] + return state["ip"], state["port"] error_info = "" if len(collected_errors) > 0: @@ -341,8 +314,6 @@ ): """ Returns a redis client instance for the ``service_name`` master. - Sentinel client will detect failover and reconnect Redis clients - automatically.
A :py:class:`~redis.sentinel.SentinelConnectionPool` class is used to retrieve the master's address before establishing a new diff --git a/venv/lib/python3.12/site-packages/redis/asyncio/utils.py b/venv/lib/python3.12/site-packages/redis/asyncio/utils.py index fa01451..5a55b36 100644 --- a/venv/lib/python3.12/site-packages/redis/asyncio/utils.py +++ b/venv/lib/python3.12/site-packages/redis/asyncio/utils.py @@ -16,7 +16,7 @@ def from_url(url, **kwargs): return Redis.from_url(url, **kwargs) -class pipeline: # noqa: N801 +class pipeline: def __init__(self, redis_obj: "Redis"): self.p: "Pipeline" = redis_obj.pipeline() diff --git a/venv/lib/python3.12/site-packages/redis/auth/__init__.py b/venv/lib/python3.12/site-packages/redis/auth/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/venv/lib/python3.12/site-packages/redis/auth/err.py b/venv/lib/python3.12/site-packages/redis/auth/err.py deleted file mode 100644 index 743dab1..0000000 --- a/venv/lib/python3.12/site-packages/redis/auth/err.py +++ /dev/null @@ -1,31 +0,0 @@ -from typing import Iterable - - -class RequestTokenErr(Exception): - """ - Represents an exception during token request. - """ - - def __init__(self, *args): - super().__init__(*args) - - -class InvalidTokenSchemaErr(Exception): - """ - Represents an exception related to invalid token schema. - """ - - def __init__(self, missing_fields: Iterable[str] = []): - super().__init__( - "Unexpected token schema. Following fields are missing: " - + ", ".join(missing_fields) - ) - - -class TokenRenewalErr(Exception): - """ - Represents an exception during token renewal process. - """ - - def __init__(self, *args): - super().__init__(*args) diff --git a/venv/lib/python3.12/site-packages/redis/auth/idp.py b/venv/lib/python3.12/site-packages/redis/auth/idp.py deleted file mode 100644 index 0951d95..0000000 --- a/venv/lib/python3.12/site-packages/redis/auth/idp.py +++ /dev/null @@ -1,28 +0,0 @@ -from abc import ABC, abstractmethod - -from redis.auth.token import TokenInterface - -""" -This interface is the facade of an identity provider -""" - - -class IdentityProviderInterface(ABC): - """ - Receive a token from the identity provider. - Receiving a token only works when being authenticated. - """ - - @abstractmethod - def request_token(self, force_refresh=False) -> TokenInterface: - pass - - -class IdentityProviderConfigInterface(ABC): - """ - Configuration class that provides a configured identity provider. 
- """ - - @abstractmethod - def get_provider(self) -> IdentityProviderInterface: - pass diff --git a/venv/lib/python3.12/site-packages/redis/auth/token.py b/venv/lib/python3.12/site-packages/redis/auth/token.py deleted file mode 100644 index 1f613af..0000000 --- a/venv/lib/python3.12/site-packages/redis/auth/token.py +++ /dev/null @@ -1,130 +0,0 @@ -from abc import ABC, abstractmethod -from datetime import datetime, timezone - -from redis.auth.err import InvalidTokenSchemaErr - - -class TokenInterface(ABC): - @abstractmethod - def is_expired(self) -> bool: - pass - - @abstractmethod - def ttl(self) -> float: - pass - - @abstractmethod - def try_get(self, key: str) -> str: - pass - - @abstractmethod - def get_value(self) -> str: - pass - - @abstractmethod - def get_expires_at_ms(self) -> float: - pass - - @abstractmethod - def get_received_at_ms(self) -> float: - pass - - -class TokenResponse: - def __init__(self, token: TokenInterface): - self._token = token - - def get_token(self) -> TokenInterface: - return self._token - - def get_ttl_ms(self) -> float: - return self._token.get_expires_at_ms() - self._token.get_received_at_ms() - - -class SimpleToken(TokenInterface): - def __init__( - self, value: str, expires_at_ms: float, received_at_ms: float, claims: dict - ) -> None: - self.value = value - self.expires_at = expires_at_ms - self.received_at = received_at_ms - self.claims = claims - - def ttl(self) -> float: - if self.expires_at == -1: - return -1 - - return self.expires_at - (datetime.now(timezone.utc).timestamp() * 1000) - - def is_expired(self) -> bool: - if self.expires_at == -1: - return False - - return self.ttl() <= 0 - - def try_get(self, key: str) -> str: - return self.claims.get(key) - - def get_value(self) -> str: - return self.value - - def get_expires_at_ms(self) -> float: - return self.expires_at - - def get_received_at_ms(self) -> float: - return self.received_at - - -class JWToken(TokenInterface): - REQUIRED_FIELDS = {"exp"} - - def __init__(self, token: str): - try: - import jwt - except ImportError as ie: - raise ImportError( - f"The PyJWT library is required for {self.__class__.__name__}.", - ) from ie - self._value = token - self._decoded = jwt.decode( - self._value, - options={"verify_signature": False}, - algorithms=[jwt.get_unverified_header(self._value).get("alg")], - ) - self._validate_token() - - def is_expired(self) -> bool: - exp = self._decoded["exp"] - if exp == -1: - return False - - return ( - self._decoded["exp"] * 1000 <= datetime.now(timezone.utc).timestamp() * 1000 - ) - - def ttl(self) -> float: - exp = self._decoded["exp"] - if exp == -1: - return -1 - - return ( - self._decoded["exp"] * 1000 - datetime.now(timezone.utc).timestamp() * 1000 - ) - - def try_get(self, key: str) -> str: - return self._decoded.get(key) - - def get_value(self) -> str: - return self._value - - def get_expires_at_ms(self) -> float: - return float(self._decoded["exp"] * 1000) - - def get_received_at_ms(self) -> float: - return datetime.now(timezone.utc).timestamp() * 1000 - - def _validate_token(self): - actual_fields = {x for x in self._decoded.keys()} - - if len(self.REQUIRED_FIELDS - actual_fields) != 0: - raise InvalidTokenSchemaErr(self.REQUIRED_FIELDS - actual_fields) diff --git a/venv/lib/python3.12/site-packages/redis/auth/token_manager.py b/venv/lib/python3.12/site-packages/redis/auth/token_manager.py deleted file mode 100644 index dd8d162..0000000 --- a/venv/lib/python3.12/site-packages/redis/auth/token_manager.py +++ /dev/null @@ -1,370 +0,0 @@ -import asyncio 
-import logging -import threading -from datetime import datetime, timezone -from time import sleep -from typing import Any, Awaitable, Callable, Union - -from redis.auth.err import RequestTokenErr, TokenRenewalErr -from redis.auth.idp import IdentityProviderInterface -from redis.auth.token import TokenResponse - -logger = logging.getLogger(__name__) - - -class CredentialsListener: - """ - Listeners that will be notified on events related to credentials. - Accepts callbacks and awaitable callbacks. - """ - - def __init__(self): - self._on_next = None - self._on_error = None - - @property - def on_next(self) -> Union[Callable[[Any], None], Awaitable]: - return self._on_next - - @on_next.setter - def on_next(self, callback: Union[Callable[[Any], None], Awaitable]) -> None: - self._on_next = callback - - @property - def on_error(self) -> Union[Callable[[Exception], None], Awaitable]: - return self._on_error - - @on_error.setter - def on_error(self, callback: Union[Callable[[Exception], None], Awaitable]) -> None: - self._on_error = callback - - -class RetryPolicy: - def __init__(self, max_attempts: int, delay_in_ms: float): - self.max_attempts = max_attempts - self.delay_in_ms = delay_in_ms - - def get_max_attempts(self) -> int: - """ - Retry attempts before exception will be thrown. - - :return: int - """ - return self.max_attempts - - def get_delay_in_ms(self) -> float: - """ - Delay between retries in seconds. - - :return: int - """ - return self.delay_in_ms - - -class TokenManagerConfig: - def __init__( - self, - expiration_refresh_ratio: float, - lower_refresh_bound_millis: int, - token_request_execution_timeout_in_ms: int, - retry_policy: RetryPolicy, - ): - self._expiration_refresh_ratio = expiration_refresh_ratio - self._lower_refresh_bound_millis = lower_refresh_bound_millis - self._token_request_execution_timeout_in_ms = ( - token_request_execution_timeout_in_ms - ) - self._retry_policy = retry_policy - - def get_expiration_refresh_ratio(self) -> float: - """ - Represents the ratio of a token's lifetime at which a refresh should be triggered. # noqa: E501 - For example, a value of 0.75 means the token should be refreshed - when 75% of its lifetime has elapsed (or when 25% of its lifetime remains). - - :return: float - """ - - return self._expiration_refresh_ratio - - def get_lower_refresh_bound_millis(self) -> int: - """ - Represents the minimum time in milliseconds before token expiration - to trigger a refresh, in milliseconds. - This value sets a fixed lower bound for when a token refresh should occur, - regardless of the token's total lifetime. - If set to 0 there will be no lower bound and the refresh will be triggered - based on the expirationRefreshRatio only. - - :return: int - """ - return self._lower_refresh_bound_millis - - def get_token_request_execution_timeout_in_ms(self) -> int: - """ - Represents the maximum time in milliseconds to wait - for a token request to complete. - - :return: int - """ - return self._token_request_execution_timeout_in_ms - - def get_retry_policy(self) -> RetryPolicy: - """ - Represents the retry policy for token requests. 
- - :return: RetryPolicy - """ - return self._retry_policy - - -class TokenManager: - def __init__( - self, identity_provider: IdentityProviderInterface, config: TokenManagerConfig - ): - self._idp = identity_provider - self._config = config - self._next_timer = None - self._listener = None - self._init_timer = None - self._retries = 0 - - def __del__(self): - logger.info("Token manager are disposed") - self.stop() - - def start( - self, - listener: CredentialsListener, - skip_initial: bool = False, - ) -> Callable[[], None]: - self._listener = listener - - try: - loop = asyncio.get_running_loop() - except RuntimeError: - # Run loop in a separate thread to unblock main thread. - loop = asyncio.new_event_loop() - thread = threading.Thread( - target=_start_event_loop_in_thread, args=(loop,), daemon=True - ) - thread.start() - - # Event to block for initial execution. - init_event = asyncio.Event() - self._init_timer = loop.call_later( - 0, self._renew_token, skip_initial, init_event - ) - logger.info("Token manager started") - - # Blocks in thread-safe manner. - asyncio.run_coroutine_threadsafe(init_event.wait(), loop).result() - return self.stop - - async def start_async( - self, - listener: CredentialsListener, - block_for_initial: bool = False, - initial_delay_in_ms: float = 0, - skip_initial: bool = False, - ) -> Callable[[], None]: - self._listener = listener - - loop = asyncio.get_running_loop() - init_event = asyncio.Event() - - # Wraps the async callback with async wrapper to schedule with loop.call_later() - wrapped = _async_to_sync_wrapper( - loop, self._renew_token_async, skip_initial, init_event - ) - self._init_timer = loop.call_later(initial_delay_in_ms / 1000, wrapped) - logger.info("Token manager started") - - if block_for_initial: - await init_event.wait() - - return self.stop - - def stop(self): - if self._init_timer is not None: - self._init_timer.cancel() - if self._next_timer is not None: - self._next_timer.cancel() - - def acquire_token(self, force_refresh=False) -> TokenResponse: - try: - token = self._idp.request_token(force_refresh) - except RequestTokenErr as e: - if self._retries < self._config.get_retry_policy().get_max_attempts(): - self._retries += 1 - sleep(self._config.get_retry_policy().get_delay_in_ms() / 1000) - return self.acquire_token(force_refresh) - else: - raise e - - self._retries = 0 - return TokenResponse(token) - - async def acquire_token_async(self, force_refresh=False) -> TokenResponse: - try: - token = self._idp.request_token(force_refresh) - except RequestTokenErr as e: - if self._retries < self._config.get_retry_policy().get_max_attempts(): - self._retries += 1 - await asyncio.sleep( - self._config.get_retry_policy().get_delay_in_ms() / 1000 - ) - return await self.acquire_token_async(force_refresh) - else: - raise e - - self._retries = 0 - return TokenResponse(token) - - def _calculate_renewal_delay(self, expire_date: float, issue_date: float) -> float: - delay_for_lower_refresh = self._delay_for_lower_refresh(expire_date) - delay_for_ratio_refresh = self._delay_for_ratio_refresh(expire_date, issue_date) - delay = min(delay_for_ratio_refresh, delay_for_lower_refresh) - - return 0 if delay < 0 else delay / 1000 - - def _delay_for_lower_refresh(self, expire_date: float): - return ( - expire_date - - self._config.get_lower_refresh_bound_millis() - - (datetime.now(timezone.utc).timestamp() * 1000) - ) - - def _delay_for_ratio_refresh(self, expire_date: float, issue_date: float): - token_ttl = expire_date - issue_date - refresh_before = token_ttl 
- ( - token_ttl * self._config.get_expiration_refresh_ratio() - ) - - return ( - expire_date - - refresh_before - - (datetime.now(timezone.utc).timestamp() * 1000) - ) - - def _renew_token( - self, skip_initial: bool = False, init_event: asyncio.Event = None - ): - """ - Task to renew token from identity provider. - Schedules renewal tasks based on token TTL. - """ - - try: - token_res = self.acquire_token(force_refresh=True) - delay = self._calculate_renewal_delay( - token_res.get_token().get_expires_at_ms(), - token_res.get_token().get_received_at_ms(), - ) - - if token_res.get_token().is_expired(): - raise TokenRenewalErr("Requested token is expired") - - if self._listener.on_next is None: - logger.warning( - "No registered callback for token renewal task. Renewal cancelled" - ) - return - - if not skip_initial: - try: - self._listener.on_next(token_res.get_token()) - except Exception as e: - raise TokenRenewalErr(e) - - if delay <= 0: - return - - loop = asyncio.get_running_loop() - self._next_timer = loop.call_later(delay, self._renew_token) - logger.info(f"Next token renewal scheduled in {delay} seconds") - return token_res - except Exception as e: - if self._listener.on_error is None: - raise e - - self._listener.on_error(e) - finally: - if init_event: - init_event.set() - - async def _renew_token_async( - self, skip_initial: bool = False, init_event: asyncio.Event = None - ): - """ - Async task to renew tokens from identity provider. - Schedules renewal tasks based on token TTL. - """ - - try: - token_res = await self.acquire_token_async(force_refresh=True) - delay = self._calculate_renewal_delay( - token_res.get_token().get_expires_at_ms(), - token_res.get_token().get_received_at_ms(), - ) - - if token_res.get_token().is_expired(): - raise TokenRenewalErr("Requested token is expired") - - if self._listener.on_next is None: - logger.warning( - "No registered callback for token renewal task. Renewal cancelled" - ) - return - - if not skip_initial: - try: - await self._listener.on_next(token_res.get_token()) - except Exception as e: - raise TokenRenewalErr(e) - - if delay <= 0: - return - - loop = asyncio.get_running_loop() - wrapped = _async_to_sync_wrapper(loop, self._renew_token_async) - logger.info(f"Next token renewal scheduled in {delay} seconds") - loop.call_later(delay, wrapped) - except Exception as e: - if self._listener.on_error is None: - raise e - - await self._listener.on_error(e) - finally: - if init_event: - init_event.set() - - -def _async_to_sync_wrapper(loop, coro_func, *args, **kwargs): - """ - Wraps an asynchronous function so it can be used with loop.call_later. - - :param loop: The event loop in which the coroutine will be executed. - :param coro_func: The coroutine function to wrap. - :param args: Positional arguments to pass to the coroutine function. - :param kwargs: Keyword arguments to pass to the coroutine function. - :return: A regular function suitable for loop.call_later. - """ - - def wrapped(): - # Schedule the coroutine in the event loop - asyncio.ensure_future(coro_func(*args, **kwargs), loop=loop) - - return wrapped - - -def _start_event_loop_in_thread(event_loop: asyncio.AbstractEventLoop): - """ - Starts event loop in a thread. - Used to be able to schedule tasks using loop.call_later. 
- - :param event_loop: - :return: - """ - asyncio.set_event_loop(event_loop) - event_loop.run_forever() diff --git a/venv/lib/python3.12/site-packages/redis/backoff.py b/venv/lib/python3.12/site-packages/redis/backoff.py index 6e1f68a..c62e760 100644 --- a/venv/lib/python3.12/site-packages/redis/backoff.py +++ b/venv/lib/python3.12/site-packages/redis/backoff.py @@ -19,7 +19,7 @@ class AbstractBackoff(ABC): pass @abstractmethod - def compute(self, failures: int) -> float: + def compute(self, failures): """Compute backoff in seconds upon failure""" pass @@ -27,34 +27,25 @@ class AbstractBackoff(ABC): class ConstantBackoff(AbstractBackoff): """Constant backoff upon failure""" - def __init__(self, backoff: float) -> None: + def __init__(self, backoff): """`backoff`: backoff time in seconds""" self._backoff = backoff - def __hash__(self) -> int: - return hash((self._backoff,)) - - def __eq__(self, other) -> bool: - if not isinstance(other, ConstantBackoff): - return NotImplemented - - return self._backoff == other._backoff - - def compute(self, failures: int) -> float: + def compute(self, failures): return self._backoff class NoBackoff(ConstantBackoff): """No backoff upon failure""" - def __init__(self) -> None: + def __init__(self): super().__init__(0) class ExponentialBackoff(AbstractBackoff): """Exponential backoff upon failure""" - def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE): + def __init__(self, cap=DEFAULT_CAP, base=DEFAULT_BASE): """ `cap`: maximum backoff time in seconds `base`: base backoff time in seconds @@ -62,23 +53,14 @@ class ExponentialBackoff(AbstractBackoff): self._cap = cap self._base = base - def __hash__(self) -> int: - return hash((self._base, self._cap)) - - def __eq__(self, other) -> bool: - if not isinstance(other, ExponentialBackoff): - return NotImplemented - - return self._base == other._base and self._cap == other._cap - - def compute(self, failures: int) -> float: + def compute(self, failures): return min(self._cap, self._base * 2**failures) class FullJitterBackoff(AbstractBackoff): """Full jitter backoff upon failure""" - def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None: + def __init__(self, cap=DEFAULT_CAP, base=DEFAULT_BASE): """ `cap`: maximum backoff time in seconds `base`: base backoff time in seconds @@ -86,23 +68,14 @@ class FullJitterBackoff(AbstractBackoff): self._cap = cap self._base = base - def __hash__(self) -> int: - return hash((self._base, self._cap)) - - def __eq__(self, other) -> bool: - if not isinstance(other, FullJitterBackoff): - return NotImplemented - - return self._base == other._base and self._cap == other._cap - - def compute(self, failures: int) -> float: + def compute(self, failures): return random.uniform(0, min(self._cap, self._base * 2**failures)) class EqualJitterBackoff(AbstractBackoff): """Equal jitter backoff upon failure""" - def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None: + def __init__(self, cap=DEFAULT_CAP, base=DEFAULT_BASE): """ `cap`: maximum backoff time in seconds `base`: base backoff time in seconds @@ -110,16 +83,7 @@ class EqualJitterBackoff(AbstractBackoff): self._cap = cap self._base = base - def __hash__(self) -> int: - return hash((self._base, self._cap)) - - def __eq__(self, other) -> bool: - if not isinstance(other, EqualJitterBackoff): - return NotImplemented - - return self._base == other._base and self._cap == other._cap - - def compute(self, failures: int) -> float: + def compute(self, failures): temp = 
min(self._cap, self._base * 2**failures) / 2 return temp + random.uniform(0, temp) @@ -127,7 +91,7 @@ class EqualJitterBackoff(AbstractBackoff): class DecorrelatedJitterBackoff(AbstractBackoff): """Decorrelated jitter backoff upon failure""" - def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None: + def __init__(self, cap=DEFAULT_CAP, base=DEFAULT_BASE): """ `cap`: maximum backoff time in seconds `base`: base backoff time in seconds @@ -136,48 +100,15 @@ class DecorrelatedJitterBackoff(AbstractBackoff): self._base = base self._previous_backoff = 0 - def __hash__(self) -> int: - return hash((self._base, self._cap)) - - def __eq__(self, other) -> bool: - if not isinstance(other, DecorrelatedJitterBackoff): - return NotImplemented - - return self._base == other._base and self._cap == other._cap - - def reset(self) -> None: + def reset(self): self._previous_backoff = 0 - def compute(self, failures: int) -> float: + def compute(self, failures): max_backoff = max(self._base, self._previous_backoff * 3) temp = random.uniform(self._base, max_backoff) self._previous_backoff = min(self._cap, temp) return self._previous_backoff -class ExponentialWithJitterBackoff(AbstractBackoff): - """Exponential backoff upon failure, with jitter""" - - def __init__(self, cap: float = DEFAULT_CAP, base: float = DEFAULT_BASE) -> None: - """ - `cap`: maximum backoff time in seconds - `base`: base backoff time in seconds - """ - self._cap = cap - self._base = base - - def __hash__(self) -> int: - return hash((self._base, self._cap)) - - def __eq__(self, other) -> bool: - if not isinstance(other, ExponentialWithJitterBackoff): - return NotImplemented - - return self._base == other._base and self._cap == other._cap - - def compute(self, failures: int) -> float: - return min(self._cap, random.random() * self._base * 2**failures) - - def default_backoff(): return EqualJitterBackoff() diff --git a/venv/lib/python3.12/site-packages/redis/cache.py b/venv/lib/python3.12/site-packages/redis/cache.py deleted file mode 100644 index 9971edd..0000000 --- a/venv/lib/python3.12/site-packages/redis/cache.py +++ /dev/null @@ -1,401 +0,0 @@ -from abc import ABC, abstractmethod -from collections import OrderedDict -from dataclasses import dataclass -from enum import Enum -from typing import Any, List, Optional, Union - - -class CacheEntryStatus(Enum): - VALID = "VALID" - IN_PROGRESS = "IN_PROGRESS" - - -class EvictionPolicyType(Enum): - time_based = "time_based" - frequency_based = "frequency_based" - - -@dataclass(frozen=True) -class CacheKey: - command: str - redis_keys: tuple - - -class CacheEntry: - def __init__( - self, - cache_key: CacheKey, - cache_value: bytes, - status: CacheEntryStatus, - connection_ref, - ): - self.cache_key = cache_key - self.cache_value = cache_value - self.status = status - self.connection_ref = connection_ref - - def __hash__(self): - return hash( - (self.cache_key, self.cache_value, self.status, self.connection_ref) - ) - - def __eq__(self, other): - return hash(self) == hash(other) - - -class EvictionPolicyInterface(ABC): - @property - @abstractmethod - def cache(self): - pass - - @cache.setter - def cache(self, value): - pass - - @property - @abstractmethod - def type(self) -> EvictionPolicyType: - pass - - @abstractmethod - def evict_next(self) -> CacheKey: - pass - - @abstractmethod - def evict_many(self, count: int) -> List[CacheKey]: - pass - - @abstractmethod - def touch(self, cache_key: CacheKey) -> None: - pass - - -class CacheConfigurationInterface(ABC): - 
@abstractmethod - def get_cache_class(self): - pass - - @abstractmethod - def get_max_size(self) -> int: - pass - - @abstractmethod - def get_eviction_policy(self): - pass - - @abstractmethod - def is_exceeds_max_size(self, count: int) -> bool: - pass - - @abstractmethod - def is_allowed_to_cache(self, command: str) -> bool: - pass - - -class CacheInterface(ABC): - @property - @abstractmethod - def collection(self) -> OrderedDict: - pass - - @property - @abstractmethod - def config(self) -> CacheConfigurationInterface: - pass - - @property - @abstractmethod - def eviction_policy(self) -> EvictionPolicyInterface: - pass - - @property - @abstractmethod - def size(self) -> int: - pass - - @abstractmethod - def get(self, key: CacheKey) -> Union[CacheEntry, None]: - pass - - @abstractmethod - def set(self, entry: CacheEntry) -> bool: - pass - - @abstractmethod - def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]: - pass - - @abstractmethod - def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]: - pass - - @abstractmethod - def flush(self) -> int: - pass - - @abstractmethod - def is_cachable(self, key: CacheKey) -> bool: - pass - - -class DefaultCache(CacheInterface): - def __init__( - self, - cache_config: CacheConfigurationInterface, - ) -> None: - self._cache = OrderedDict() - self._cache_config = cache_config - self._eviction_policy = self._cache_config.get_eviction_policy().value() - self._eviction_policy.cache = self - - @property - def collection(self) -> OrderedDict: - return self._cache - - @property - def config(self) -> CacheConfigurationInterface: - return self._cache_config - - @property - def eviction_policy(self) -> EvictionPolicyInterface: - return self._eviction_policy - - @property - def size(self) -> int: - return len(self._cache) - - def set(self, entry: CacheEntry) -> bool: - if not self.is_cachable(entry.cache_key): - return False - - self._cache[entry.cache_key] = entry - self._eviction_policy.touch(entry.cache_key) - - if self._cache_config.is_exceeds_max_size(len(self._cache)): - self._eviction_policy.evict_next() - - return True - - def get(self, key: CacheKey) -> Union[CacheEntry, None]: - entry = self._cache.get(key, None) - - if entry is None: - return None - - self._eviction_policy.touch(key) - return entry - - def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]: - response = [] - - for key in cache_keys: - if self.get(key) is not None: - self._cache.pop(key) - response.append(True) - else: - response.append(False) - - return response - - def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]: - response = [] - keys_to_delete = [] - - for redis_key in redis_keys: - if isinstance(redis_key, bytes): - redis_key = redis_key.decode() - for cache_key in self._cache: - if redis_key in cache_key.redis_keys: - keys_to_delete.append(cache_key) - response.append(True) - - for key in keys_to_delete: - self._cache.pop(key) - - return response - - def flush(self) -> int: - elem_count = len(self._cache) - self._cache.clear() - return elem_count - - def is_cachable(self, key: CacheKey) -> bool: - return self._cache_config.is_allowed_to_cache(key.command) - - -class LRUPolicy(EvictionPolicyInterface): - def __init__(self): - self.cache = None - - @property - def cache(self): - return self._cache - - @cache.setter - def cache(self, cache: CacheInterface): - self._cache = cache - - @property - def type(self) -> EvictionPolicyType: - return EvictionPolicyType.time_based - - def evict_next(self) -> CacheKey: - 
self._assert_cache() - popped_entry = self._cache.collection.popitem(last=False) - return popped_entry[0] - - def evict_many(self, count: int) -> List[CacheKey]: - self._assert_cache() - if count > len(self._cache.collection): - raise ValueError("Evictions count is above cache size") - - popped_keys = [] - - for _ in range(count): - popped_entry = self._cache.collection.popitem(last=False) - popped_keys.append(popped_entry[0]) - - return popped_keys - - def touch(self, cache_key: CacheKey) -> None: - self._assert_cache() - - if self._cache.collection.get(cache_key) is None: - raise ValueError("Given entry does not belong to the cache") - - self._cache.collection.move_to_end(cache_key) - - def _assert_cache(self): - if self.cache is None or not isinstance(self.cache, CacheInterface): - raise ValueError("Eviction policy should be associated with valid cache.") - - -class EvictionPolicy(Enum): - LRU = LRUPolicy - - -class CacheConfig(CacheConfigurationInterface): - DEFAULT_CACHE_CLASS = DefaultCache - DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU - DEFAULT_MAX_SIZE = 10000 - - DEFAULT_ALLOW_LIST = [ - "BITCOUNT", - "BITFIELD_RO", - "BITPOS", - "EXISTS", - "GEODIST", - "GEOHASH", - "GEOPOS", - "GEORADIUSBYMEMBER_RO", - "GEORADIUS_RO", - "GEOSEARCH", - "GET", - "GETBIT", - "GETRANGE", - "HEXISTS", - "HGET", - "HGETALL", - "HKEYS", - "HLEN", - "HMGET", - "HSTRLEN", - "HVALS", - "JSON.ARRINDEX", - "JSON.ARRLEN", - "JSON.GET", - "JSON.MGET", - "JSON.OBJKEYS", - "JSON.OBJLEN", - "JSON.RESP", - "JSON.STRLEN", - "JSON.TYPE", - "LCS", - "LINDEX", - "LLEN", - "LPOS", - "LRANGE", - "MGET", - "SCARD", - "SDIFF", - "SINTER", - "SINTERCARD", - "SISMEMBER", - "SMEMBERS", - "SMISMEMBER", - "SORT_RO", - "STRLEN", - "SUBSTR", - "SUNION", - "TS.GET", - "TS.INFO", - "TS.RANGE", - "TS.REVRANGE", - "TYPE", - "XLEN", - "XPENDING", - "XRANGE", - "XREAD", - "XREVRANGE", - "ZCARD", - "ZCOUNT", - "ZDIFF", - "ZINTER", - "ZINTERCARD", - "ZLEXCOUNT", - "ZMSCORE", - "ZRANGE", - "ZRANGEBYLEX", - "ZRANGEBYSCORE", - "ZRANK", - "ZREVRANGE", - "ZREVRANGEBYLEX", - "ZREVRANGEBYSCORE", - "ZREVRANK", - "ZSCORE", - "ZUNION", - ] - - def __init__( - self, - max_size: int = DEFAULT_MAX_SIZE, - cache_class: Any = DEFAULT_CACHE_CLASS, - eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY, - ): - self._cache_class = cache_class - self._max_size = max_size - self._eviction_policy = eviction_policy - - def get_cache_class(self): - return self._cache_class - - def get_max_size(self) -> int: - return self._max_size - - def get_eviction_policy(self) -> EvictionPolicy: - return self._eviction_policy - - def is_exceeds_max_size(self, count: int) -> bool: - return count > self._max_size - - def is_allowed_to_cache(self, command: str) -> bool: - return command in self.DEFAULT_ALLOW_LIST - - -class CacheFactoryInterface(ABC): - @abstractmethod - def get_cache(self) -> CacheInterface: - pass - - -class CacheFactory(CacheFactoryInterface): - def __init__(self, cache_config: Optional[CacheConfig] = None): - self._config = cache_config - - if self._config is None: - self._config = CacheConfig() - - def get_cache(self) -> CacheInterface: - cache_class = self._config.get_cache_class() - return cache_class(cache_config=self._config) diff --git a/venv/lib/python3.12/site-packages/redis/client.py b/venv/lib/python3.12/site-packages/redis/client.py old mode 100755 new mode 100644 index 0e05b6f..4923143 --- a/venv/lib/python3.12/site-packages/redis/client.py +++ b/venv/lib/python3.12/site-packages/redis/client.py @@ -2,19 +2,9 @@ import copy import re 
import threading import time +import warnings from itertools import chain -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Mapping, - Optional, - Set, - Type, - Union, -) +from typing import Any, Callable, Dict, List, Optional, Type, Union from redis._parsers.encoders import Encoder from redis._parsers.helpers import ( @@ -23,54 +13,33 @@ from redis._parsers.helpers import ( _RedisCallbacksRESP3, bool_ok, ) -from redis.backoff import ExponentialWithJitterBackoff -from redis.cache import CacheConfig, CacheInterface from redis.commands import ( CoreCommands, RedisModuleCommands, SentinelCommands, list_or_args, ) -from redis.commands.core import Script -from redis.connection import ( - AbstractConnection, - Connection, - ConnectionPool, - SSLConnection, - UnixDomainSocketConnection, -) +from redis.connection import ConnectionPool, SSLConnection, UnixDomainSocketConnection from redis.credentials import CredentialProvider -from redis.event import ( - AfterPooledConnectionsInstantiationEvent, - AfterPubSubConnectionInstantiationEvent, - AfterSingleConnectionInstantiationEvent, - ClientType, - EventDispatcher, -) from redis.exceptions import ( ConnectionError, ExecAbortError, PubSubError, RedisError, ResponseError, + TimeoutError, WatchError, ) from redis.lock import Lock from redis.retry import Retry from redis.utils import ( + HIREDIS_AVAILABLE, _set_info_logger, - deprecated_args, get_lib_version, safe_str, str_if_bytes, - truncate_text, ) -if TYPE_CHECKING: - import ssl - - import OpenSSL - SYM_EMPTY = b"" EMPTY_RESPONSE = "EMPTY_RESPONSE" @@ -125,7 +94,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): """ @classmethod - def from_url(cls, url: str, **kwargs) -> "Redis": + def from_url(cls, url: str, **kwargs) -> None: """ Return a Redis client object configured from the given URL @@ -191,80 +160,56 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): client.auto_close_connection_pool = True return client - @deprecated_args( - args_to_warn=["retry_on_timeout"], - reason="TimeoutError is included by default.", - version="6.0.0", - ) def __init__( self, - host: str = "localhost", - port: int = 6379, - db: int = 0, - password: Optional[str] = None, - socket_timeout: Optional[float] = None, - socket_connect_timeout: Optional[float] = None, - socket_keepalive: Optional[bool] = None, - socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, - connection_pool: Optional[ConnectionPool] = None, - unix_socket_path: Optional[str] = None, - encoding: str = "utf-8", - encoding_errors: str = "strict", - decode_responses: bool = False, - retry_on_timeout: bool = False, - retry: Retry = Retry( - backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 - ), - retry_on_error: Optional[List[Type[Exception]]] = None, - ssl: bool = False, - ssl_keyfile: Optional[str] = None, - ssl_certfile: Optional[str] = None, - ssl_cert_reqs: Union[str, "ssl.VerifyMode"] = "required", - ssl_ca_certs: Optional[str] = None, - ssl_ca_path: Optional[str] = None, - ssl_ca_data: Optional[str] = None, - ssl_check_hostname: bool = True, - ssl_password: Optional[str] = None, - ssl_validate_ocsp: bool = False, - ssl_validate_ocsp_stapled: bool = False, - ssl_ocsp_context: Optional["OpenSSL.SSL.Context"] = None, - ssl_ocsp_expected_cert: Optional[str] = None, - ssl_min_version: Optional["ssl.TLSVersion"] = None, - ssl_ciphers: Optional[str] = None, - max_connections: Optional[int] = None, - single_connection_client: bool = False, - 
health_check_interval: int = 0, - client_name: Optional[str] = None, - lib_name: Optional[str] = "redis-py", - lib_version: Optional[str] = get_lib_version(), - username: Optional[str] = None, - redis_connect_func: Optional[Callable[[], None]] = None, + host="localhost", + port=6379, + db=0, + password=None, + socket_timeout=None, + socket_connect_timeout=None, + socket_keepalive=None, + socket_keepalive_options=None, + connection_pool=None, + unix_socket_path=None, + encoding="utf-8", + encoding_errors="strict", + charset=None, + errors=None, + decode_responses=False, + retry_on_timeout=False, + retry_on_error=None, + ssl=False, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs="required", + ssl_ca_certs=None, + ssl_ca_path=None, + ssl_ca_data=None, + ssl_check_hostname=False, + ssl_password=None, + ssl_validate_ocsp=False, + ssl_validate_ocsp_stapled=False, + ssl_ocsp_context=None, + ssl_ocsp_expected_cert=None, + max_connections=None, + single_connection_client=False, + health_check_interval=0, + client_name=None, + lib_name="redis-py", + lib_version=get_lib_version(), + username=None, + retry=None, + redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, - cache: Optional[CacheInterface] = None, - cache_config: Optional[CacheConfig] = None, - event_dispatcher: Optional[EventDispatcher] = None, ) -> None: """ Initialize a new Redis client. - - To specify a retry policy for specific errors, you have two options: - - 1. Set the `retry_on_error` to a list of the error/s to retry on, and - you can also set `retry` to a valid `Retry` object(in case the default - one is not appropriate) - with this approach the retries will be triggered - on the default errors specified in the Retry object enriched with the - errors specified in `retry_on_error`. - - 2. Define a `Retry` object with configured 'supported_errors' and set - it to the `retry` parameter - with this approach you completely redefine - the errors on which retries will happen. - - `retry_on_timeout` is deprecated - please include the TimeoutError - either in the Retry object or in the `retry_on_error` list. - - When 'connection_pool' is provided - the retry configuration of the - provided pool will be used. + To specify a retry policy for specific errors, first set + `retry_on_error` to a list of the error/s to retry on, then set + `retry` to a valid `Retry` object. + To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. Args: @@ -272,13 +217,25 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): if `True`, connection pool is not used. In that case `Redis` instance use is not thread safe. """ - if event_dispatcher is None: - self._event_dispatcher = EventDispatcher() - else: - self._event_dispatcher = event_dispatcher if not connection_pool: + if charset is not None: + warnings.warn( + DeprecationWarning( + '"charset" is deprecated. Use "encoding" instead' + ) + ) + encoding = charset + if errors is not None: + warnings.warn( + DeprecationWarning( + '"errors" is deprecated. 
Use "encoding_errors" instead' + ) + ) + encoding_errors = errors if not retry_on_error: retry_on_error = [] + if retry_on_timeout is True: + retry_on_error.append(TimeoutError) kwargs = { "db": db, "username": username, @@ -334,50 +291,17 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): "ssl_validate_ocsp": ssl_validate_ocsp, "ssl_ocsp_context": ssl_ocsp_context, "ssl_ocsp_expected_cert": ssl_ocsp_expected_cert, - "ssl_min_version": ssl_min_version, - "ssl_ciphers": ssl_ciphers, - } - ) - if (cache_config or cache) and protocol in [3, "3"]: - kwargs.update( - { - "cache": cache, - "cache_config": cache_config, } ) connection_pool = ConnectionPool(**kwargs) - self._event_dispatcher.dispatch( - AfterPooledConnectionsInstantiationEvent( - [connection_pool], ClientType.SYNC, credential_provider - ) - ) self.auto_close_connection_pool = True else: self.auto_close_connection_pool = False - self._event_dispatcher.dispatch( - AfterPooledConnectionsInstantiationEvent( - [connection_pool], ClientType.SYNC, credential_provider - ) - ) self.connection_pool = connection_pool - - if (cache_config or cache) and self.connection_pool.get_protocol() not in [ - 3, - "3", - ]: - raise RedisError("Client caching is only supported with RESP version 3") - - self.single_connection_lock = threading.RLock() self.connection = None - self._single_connection_client = single_connection_client - if self._single_connection_client: - self.connection = self.connection_pool.get_connection() - self._event_dispatcher.dispatch( - AfterSingleConnectionInstantiationEvent( - self.connection, ClientType.SYNC, self.single_connection_lock - ) - ) + if single_connection_client: + self.connection = self.connection_pool.get_connection("_") self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks) @@ -387,10 +311,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): self.response_callbacks.update(_RedisCallbacksRESP2) def __repr__(self) -> str: - return ( - f"<{type(self).__module__}.{type(self).__name__}" - f"({repr(self.connection_pool)})>" - ) + return f"{type(self).__name__}<{repr(self.connection_pool)}>" def get_encoder(self) -> "Encoder": """Get the connection pool's encoder""" @@ -400,10 +321,10 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): """Get the connection's key-word arguments""" return self.connection_pool.connection_kwargs - def get_retry(self) -> Optional[Retry]: + def get_retry(self) -> Optional["Retry"]: return self.get_connection_kwargs().get("retry") - def set_retry(self, retry: Retry) -> None: + def set_retry(self, retry: "Retry") -> None: self.get_connection_kwargs().update({"retry": retry}) self.connection_pool.set_retry(retry) @@ -448,7 +369,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): def transaction( self, func: Callable[["Pipeline"], None], *watches, **kwargs - ) -> Union[List[Any], Any, None]: + ) -> None: """ Convenience method for executing the callable `func` as a transaction while watching all keys specified in `watches`. 
The 'func' callable @@ -479,7 +400,6 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): blocking_timeout: Optional[float] = None, lock_class: Union[None, Any] = None, thread_local: bool = True, - raise_on_release_error: bool = True, ): """ Return a new Lock object using key ``name`` that mimics @@ -526,11 +446,6 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): thread-1 would see the token value as "xyz" and would be able to successfully release the thread-2's lock. - ``raise_on_release_error`` indicates whether to raise an exception when - the lock is no longer owned when exiting the context manager. By default, - this is True, meaning an exception will be raised. If False, the warning - will be logged and the exception will be suppressed. - In some use cases it's necessary to disable thread local storage. For example, if you have code where one thread acquires a lock and passes that lock instance to a worker thread to release later. If thread @@ -548,7 +463,6 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): blocking=blocking, blocking_timeout=blocking_timeout, thread_local=thread_local, - raise_on_release_error=raise_on_release_error, ) def pubsub(self, **kwargs): @@ -557,9 +471,7 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): subscribe to channels and listen for messages that get published to them. """ - return PubSub( - self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs - ) + return PubSub(self.connection_pool, **kwargs) def monitor(self): return Monitor(self.connection_pool) @@ -576,12 +488,9 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): self.close() def __del__(self): - try: - self.close() - except Exception: - pass + self.close() - def close(self) -> None: + def close(self): # In case a connection property does not yet exist # (due to a crash earlier in the Redis() constructor), return # immediately as there is nothing to clean-up. @@ -600,44 +509,37 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): """ Send a command and parse the response """ - conn.send_command(*args, **options) + conn.send_command(*args) return self.parse_response(conn, command_name, **options) - def _close_connection(self, conn) -> None: + def _disconnect_raise(self, conn, error): """ - Close the connection before retrying. - - The supported exceptions are already checked in the - retry object so we don't need to do it here. - - After we disconnect the connection, it will try to reconnect and - do a health check as part of the send_command logic(on connection level). 
+ Close the connection and raise an exception + if retry_on_error is not set or the error + is not one of the specified error types """ - conn.disconnect() + if ( + conn.retry_on_error is None + or isinstance(error, tuple(conn.retry_on_error)) is False + ): + raise error # COMMAND EXECUTION AND PROTOCOL PARSING def execute_command(self, *args, **options): - return self._execute_command(*args, **options) - - def _execute_command(self, *args, **options): """Execute a command and return a parsed response""" pool = self.connection_pool command_name = args[0] - conn = self.connection or pool.get_connection() + conn = self.connection or pool.get_connection(command_name, **options) - if self._single_connection_client: - self.single_connection_lock.acquire() try: return conn.retry.call_with_retry( lambda: self._send_command_parse_response( conn, command_name, *args, **options ), - lambda _: self._close_connection(conn), + lambda error: self._disconnect_raise(conn, error), ) finally: - if self._single_connection_client: - self.single_connection_lock.release() if not self.connection: pool.release(conn) @@ -657,16 +559,10 @@ class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): if EMPTY_RESPONSE in options: options.pop(EMPTY_RESPONSE) - # Remove keys entry, it needs only for cache. - options.pop("keys", None) - if command_name in self.response_callbacks: return self.response_callbacks[command_name](response, **options) return response - def get_cache(self) -> Optional[CacheInterface]: - return self.connection_pool.cache - StrictRedis = Redis @@ -683,7 +579,7 @@ class Monitor: def __init__(self, connection_pool): self.connection_pool = connection_pool - self.connection = self.connection_pool.get_connection() + self.connection = self.connection_pool.get_connection("MONITOR") def __enter__(self): self.connection.send_command("MONITOR") @@ -758,7 +654,6 @@ class PubSub: ignore_subscribe_messages: bool = False, encoder: Optional["Encoder"] = None, push_handler_func: Union[None, Callable[[str], None]] = None, - event_dispatcher: Optional["EventDispatcher"] = None, ): self.connection_pool = connection_pool self.shard_hint = shard_hint @@ -769,12 +664,6 @@ class PubSub: # to lookup channel and pattern names for callback handlers. 
self.encoder = encoder self.push_handler_func = push_handler_func - if event_dispatcher is None: - self._event_dispatcher = EventDispatcher() - else: - self._event_dispatcher = event_dispatcher - - self._lock = threading.RLock() if self.encoder is None: self.encoder = self.connection_pool.get_encoder() self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE) @@ -804,7 +693,7 @@ class PubSub: def reset(self) -> None: if self.connection: self.connection.disconnect() - self.connection.deregister_connect_callback(self.on_connect) + self.connection._deregister_connect_callback(self.on_connect) self.connection_pool.release(self.connection) self.connection = None self.health_check_response_counter = 0 @@ -857,23 +746,19 @@ class PubSub: # subscribed to one or more channels if self.connection is None: - self.connection = self.connection_pool.get_connection() + self.connection = self.connection_pool.get_connection( + "pubsub", self.shard_hint + ) # register a callback that re-subscribes to any channels we # were listening to when we were disconnected - self.connection.register_connect_callback(self.on_connect) - if self.push_handler_func is not None: - self.connection._parser.set_pubsub_push_handler(self.push_handler_func) - self._event_dispatcher.dispatch( - AfterPubSubConnectionInstantiationEvent( - self.connection, self.connection_pool, ClientType.SYNC, self._lock - ) - ) + self.connection._register_connect_callback(self.on_connect) + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) connection = self.connection kwargs = {"check_health": not self.subscribed} if not self.subscribed: self.clean_health_check_responses() - with self._lock: - self._execute(connection, connection.send_command, *args, **kwargs) + self._execute(connection, connection.send_command, *args, **kwargs) def clean_health_check_responses(self) -> None: """ @@ -889,18 +774,19 @@ class PubSub: else: raise PubSubError( "A non health check response was cleaned by " - "execute_command: {}".format(response) + "execute_command: {0}".format(response) ) ttl -= 1 - def _reconnect(self, conn) -> None: + def _disconnect_raise_connect(self, conn, error) -> None: """ - The supported exceptions are already checked in the - retry object so we don't need to do it here. - - In this error handler we are trying to reconnect to the server. + Close the connection and raise an exception + if retry_on_timeout is not set or the error + is not a TimeoutError. Otherwise, try to reconnect """ conn.disconnect() + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + raise error conn.connect() def _execute(self, conn, command, *args, **kwargs): @@ -913,7 +799,7 @@ class PubSub: """ return conn.retry.call_with_retry( lambda: command(*args, **kwargs), - lambda _: self._reconnect(conn), + lambda error: self._disconnect_raise_connect(conn, error), ) def parse_response(self, block=True, timeout=0): @@ -962,7 +848,7 @@ class PubSub: "did you forget to call subscribe() or psubscribe()?" 
) - if conn.health_check_interval and time.monotonic() > conn.next_health_check: + if conn.health_check_interval and time.time() > conn.next_health_check: conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False) self.health_check_response_counter += 1 @@ -1112,12 +998,12 @@ class PubSub: """ if not self.subscribed: # Wait for subscription - start_time = time.monotonic() + start_time = time.time() if self.subscribed_event.wait(timeout) is True: # The connection was subscribed during the timeout time frame. # The timeout should be adjusted based on the time spent # waiting for the subscription - time_spent = time.monotonic() - start_time + time_spent = time.time() - start_time timeout = max(0.0, timeout - time_spent) else: # The connection isn't subscribed to any channels or patterns, @@ -1214,7 +1100,7 @@ class PubSub: def run_in_thread( self, - sleep_time: float = 0.0, + sleep_time: int = 0, daemon: bool = False, exception_handler: Optional[Callable] = None, ) -> "PubSubWorkerThread": @@ -1282,8 +1168,7 @@ class Pipeline(Redis): in one transmission. This is convenient for batch processing, such as saving all the values in a list to Redis. - All commands executed within a pipeline(when running in transactional mode, - which is the default behavior) are wrapped with MULTI and EXEC + All commands executed within a pipeline are wrapped with MULTI and EXEC calls. This guarantees all commands executed in the pipeline will be executed atomically. @@ -1298,22 +1183,15 @@ class Pipeline(Redis): UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} - def __init__( - self, - connection_pool: ConnectionPool, - response_callbacks, - transaction, - shard_hint, - ): + def __init__(self, connection_pool, response_callbacks, transaction, shard_hint): self.connection_pool = connection_pool - self.connection: Optional[Connection] = None + self.connection = None self.response_callbacks = response_callbacks self.transaction = transaction self.shard_hint = shard_hint + self.watching = False - self.command_stack = [] - self.scripts: Set[Script] = set() - self.explicit_transaction = False + self.reset() def __enter__(self) -> "Pipeline": return self @@ -1379,51 +1257,47 @@ class Pipeline(Redis): return self.immediate_execute_command(*args, **kwargs) return self.pipeline_execute_command(*args, **kwargs) - def _disconnect_reset_raise_on_watching( - self, - conn: AbstractConnection, - error: Exception, - ) -> None: + def _disconnect_reset_raise(self, conn, error) -> None: """ - Close the connection reset watching state and - raise an exception if we were watching. - - The supported exceptions are already checked in the - retry object so we don't need to do it here. - - After we disconnect the connection, it will try to reconnect and - do a health check as part of the send_command logic(on connection level). + Close the connection, reset watching state and + raise an exception if we were watching, + retry_on_timeout is not set, + or the error is not a TimeoutError """ conn.disconnect() - # if we were already watching a variable, the watch is no longer # valid since this connection has died. raise a WatchError, which # indicates the user should retry this transaction. 
if self.watching: self.reset() raise WatchError( - f"A {type(error).__name__} occurred while watching one or more keys" + "A ConnectionError occurred on while watching one or more keys" ) + # if retry_on_timeout is not set, or the error is not + # a TimeoutError, raise it + if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): + self.reset() + raise def immediate_execute_command(self, *args, **options): """ - Execute a command immediately, but don't auto-retry on the supported - errors for retry if we're already WATCHing a variable. - Used when issuing WATCH or subsequent commands retrieving their values but before + Execute a command immediately, but don't auto-retry on a + ConnectionError if we're already WATCHing a variable. Used when + issuing WATCH or subsequent commands retrieving their values but before MULTI is called. """ command_name = args[0] conn = self.connection # if this is the first call, we need a connection if not conn: - conn = self.connection_pool.get_connection() + conn = self.connection_pool.get_connection(command_name, self.shard_hint) self.connection = conn return conn.retry.call_with_retry( lambda: self._send_command_parse_response( conn, command_name, *args, **options ), - lambda error: self._disconnect_reset_raise_on_watching(conn, error), + lambda error: self._disconnect_reset_raise(conn, error), ) def pipeline_execute_command(self, *args, **options) -> "Pipeline": @@ -1441,9 +1315,7 @@ class Pipeline(Redis): self.command_stack.append((args, options)) return self - def _execute_transaction( - self, connection: Connection, commands, raise_on_error - ) -> List: + def _execute_transaction(self, connection, commands, raise_on_error) -> List: cmds = chain([(("MULTI",), {})], commands, [(("EXEC",), {})]) all_cmds = connection.pack_commands( [args for args, options in cmds if EMPTY_RESPONSE not in options] @@ -1504,8 +1376,6 @@ class Pipeline(Redis): for r, cmd in zip(response, commands): if not isinstance(r, Exception): args, options = cmd - # Remove keys entry, it needs only for cache. - options.pop("keys", None) command_name = args[0] if command_name in self.response_callbacks: r = self.response_callbacks[command_name](r, **options) @@ -1537,7 +1407,7 @@ class Pipeline(Redis): def annotate_exception(self, exception, number, command): cmd = " ".join(map(safe_str, command)) msg = ( - f"Command # {number} ({truncate_text(cmd)}) of pipeline " + f"Command # {number} ({cmd}) of pipeline " f"caused error: {exception.args[0]}" ) exception.args = (msg,) + exception.args[1:] @@ -1563,19 +1433,11 @@ class Pipeline(Redis): if not exist: s.sha = immediate("SCRIPT LOAD", s.script) - def _disconnect_raise_on_watching( - self, - conn: AbstractConnection, - error: Exception, - ) -> None: + def _disconnect_raise_reset(self, conn: Redis, error: Exception) -> None: """ - Close the connection, raise an exception if we were watching. - - The supported exceptions are already checked in the - retry object so we don't need to do it here. - - After we disconnect the connection, it will try to reconnect and - do a health check as part of the send_command logic(on connection level). + Close the connection, raise an exception if we were watching, + and raise an exception if TimeoutError is not part of retry_on_error, + or the error is not a TimeoutError """ conn.disconnect() # if we were watching a variable, the watch is no longer valid @@ -1583,10 +1445,17 @@ class Pipeline(Redis): # indicates the user should retry this transaction. 
if self.watching: raise WatchError( - f"A {type(error).__name__} occurred while watching one or more keys" + "A ConnectionError occurred on while watching one or more keys" ) + # if TimeoutError is not part of retry_on_error, or the error + # is not a TimeoutError, raise it + if not ( + TimeoutError in conn.retry_on_error and isinstance(error, TimeoutError) + ): + self.reset() + raise error - def execute(self, raise_on_error: bool = True) -> List[Any]: + def execute(self, raise_on_error=True): """Execute all the commands in the current pipeline""" stack = self.command_stack if not stack and not self.watching: @@ -1600,7 +1469,7 @@ class Pipeline(Redis): conn = self.connection if not conn: - conn = self.connection_pool.get_connection() + conn = self.connection_pool.get_connection("MULTI", self.shard_hint) # assign to self.connection so reset() releases the connection # back to the pool after we're done self.connection = conn @@ -1608,7 +1477,7 @@ class Pipeline(Redis): try: return conn.retry.call_with_retry( lambda: execute(conn, stack, raise_on_error), - lambda error: self._disconnect_raise_on_watching(conn, error), + lambda error: self._disconnect_raise_reset(conn, error), ) finally: self.reset() diff --git a/venv/lib/python3.12/site-packages/redis/cluster.py b/venv/lib/python3.12/site-packages/redis/cluster.py index 4b971cf..873d586 100644 --- a/venv/lib/python3.12/site-packages/redis/cluster.py +++ b/venv/lib/python3.12/site-packages/redis/cluster.py @@ -3,43 +3,26 @@ import socket import sys import threading import time -from abc import ABC, abstractmethod from collections import OrderedDict -from copy import copy -from enum import Enum -from itertools import chain -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from redis._parsers import CommandsParser, Encoder from redis._parsers.helpers import parse_scan -from redis.backoff import ExponentialWithJitterBackoff, NoBackoff -from redis.cache import CacheConfig, CacheFactory, CacheFactoryInterface, CacheInterface -from redis.client import EMPTY_RESPONSE, CaseInsensitiveDict, PubSub, Redis +from redis.backoff import default_backoff +from redis.client import CaseInsensitiveDict, PubSub, Redis from redis.commands import READ_COMMANDS, RedisClusterCommands from redis.commands.helpers import list_or_args -from redis.connection import ( - Connection, - ConnectionPool, - parse_url, -) +from redis.connection import ConnectionPool, DefaultParser, parse_url from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot -from redis.event import ( - AfterPooledConnectionsInstantiationEvent, - AfterPubSubConnectionInstantiationEvent, - ClientType, - EventDispatcher, -) from redis.exceptions import ( AskError, AuthenticationError, + ClusterCrossSlotError, ClusterDownError, ClusterError, ConnectionError, - CrossSlotTransactionError, DataError, - ExecAbortError, - InvalidPipelineStack, - MaxConnectionsError, + MasterDownError, MovedError, RedisClusterException, RedisError, @@ -47,18 +30,16 @@ from redis.exceptions import ( SlotNotCoveredError, TimeoutError, TryAgainError, - WatchError, ) from redis.lock import Lock from redis.retry import Retry from redis.utils import ( - deprecated_args, + HIREDIS_AVAILABLE, dict_merge, list_keys_to_dict, merge_result, safe_str, str_if_bytes, - truncate_text, ) @@ -66,13 +47,10 @@ def get_node_name(host: str, port: Union[str, int]) -> str: return f"{host}:{port}" -@deprecated_args( - allowed_args=["redis_node"], - reason="Use 
get_connection(redis_node) instead", - version="5.3.0", -) -def get_connection(redis_node: Redis, *args, **options) -> Connection: - return redis_node.connection or redis_node.connection_pool.get_connection() +def get_connection(redis_node, *args, **options): + return redis_node.connection or redis_node.connection_pool.get_connection( + args[0], **options + ) def parse_scan_result(command, res, **options): @@ -153,6 +131,7 @@ REPLICA = "replica" SLOT_ID = "slot-id" REDIS_ALLOWED_KEYS = ( + "charset", "connection_class", "connection_pool", "connection_pool_class", @@ -162,6 +141,7 @@ REDIS_ALLOWED_KEYS = ( "decode_responses", "encoding", "encoding_errors", + "errors", "host", "lib_name", "lib_version", @@ -185,13 +165,10 @@ REDIS_ALLOWED_KEYS = ( "ssl_cert_reqs", "ssl_keyfile", "ssl_password", - "ssl_check_hostname", "unix_socket_path", "username", - "cache", - "cache_config", ) -KWARGS_DISABLED_KEYS = ("host", "port", "retry") +KWARGS_DISABLED_KEYS = ("host", "port") def cleanup_kwargs(**kwargs): @@ -207,6 +184,20 @@ def cleanup_kwargs(**kwargs): return connection_kwargs +class ClusterParser(DefaultParser): + EXCEPTION_CLASSES = dict_merge( + DefaultParser.EXCEPTION_CLASSES, + { + "ASK": AskError, + "TRYAGAIN": TryAgainError, + "MOVED": MovedError, + "CLUSTERDOWN": ClusterDownError, + "CROSSSLOT": ClusterCrossSlotError, + "MASTERDOWN": MasterDownError, + }, + ) + + class AbstractRedisCluster: RedisClusterRequestTTL = 16 @@ -300,6 +291,7 @@ class AbstractRedisCluster: "TFUNCTION LIST", "TFCALL", "TFCALLASYNC", + "GRAPH.CONFIG", "LATENCY HISTORY", "LATENCY LATEST", "LATENCY RESET", @@ -319,6 +311,7 @@ class AbstractRedisCluster: "FUNCTION LIST", "FUNCTION LOAD", "FUNCTION RESTORE", + "REDISGEARS_2.REFRESHCLUSTER", "SCAN", "SCRIPT EXISTS", "SCRIPT FLUSH", @@ -422,12 +415,7 @@ class AbstractRedisCluster: list_keys_to_dict(["SCRIPT FLUSH"], lambda command, res: all(res.values())), ) - ERRORS_ALLOW_RETRY = ( - ConnectionError, - TimeoutError, - ClusterDownError, - SlotNotCoveredError, - ) + ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, ClusterDownError) def replace_default_node(self, target_node: "ClusterNode" = None) -> None: """Replace the default cluster node. 
@@ -448,7 +436,7 @@ class AbstractRedisCluster: # Choose a primary if the cluster contains different primaries self.nodes_manager.default_node = random.choice(primaries) else: - # Otherwise, choose a primary if the cluster contains different primaries + # Otherwise, hoose a primary if the cluster contains different primaries replicas = [node for node in self.get_replicas() if node != curr_node] if replicas: self.nodes_manager.default_node = random.choice(replicas) @@ -499,18 +487,6 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): """ return cls(url=url, **kwargs) - @deprecated_args( - args_to_warn=["read_from_replicas"], - reason="Please configure the 'load_balancing_strategy' instead", - version="5.3.0", - ) - @deprecated_args( - args_to_warn=[ - "cluster_error_retry_attempts", - ], - reason="Please configure the 'retry' object instead", - version="6.0.0", - ) def __init__( self, host: Optional[str] = None, @@ -518,16 +494,12 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): startup_nodes: Optional[List["ClusterNode"]] = None, cluster_error_retry_attempts: int = 3, retry: Optional["Retry"] = None, - require_full_coverage: bool = True, + require_full_coverage: bool = False, reinitialize_steps: int = 5, read_from_replicas: bool = False, - load_balancing_strategy: Optional["LoadBalancingStrategy"] = None, dynamic_startup_nodes: bool = True, url: Optional[str] = None, - address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, - cache: Optional[CacheInterface] = None, - cache_config: Optional[CacheConfig] = None, - event_dispatcher: Optional[EventDispatcher] = None, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, **kwargs, ): """ @@ -550,16 +522,11 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): cluster client. If not all slots are covered, RedisClusterException will be thrown. :param read_from_replicas: - @deprecated - please use load_balancing_strategy instead Enable read from replicas in READONLY mode. You can read possibly stale data. When set to true, read commands will be assigned between the primary and its replications in a Round-Robin manner. - :param load_balancing_strategy: - Enable read from replicas in READONLY mode and defines the load balancing - strategy that will be used for cluster node selection. - The data read from replicas is eventually consistent with the data in primary nodes. - :param dynamic_startup_nodes: + :param dynamic_startup_nodes: Set the RedisCluster's startup nodes to all of the discovered nodes. If true (default value), the cluster's discovered nodes will be used to determine the cluster nodes-slots mapping in the next topology refresh. @@ -568,19 +535,9 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): If you use dynamic DNS endpoints for startup nodes but CLUSTER SLOTS lists specific IP addresses, it is best to set it to false. :param cluster_error_retry_attempts: - @deprecated - Please configure the 'retry' object instead - In case 'retry' object is set - this argument is ignored! - Number of times to retry before raising an error when - :class:`~.TimeoutError` or :class:`~.ConnectionError`, :class:`~.SlotNotCoveredError` or + :class:`~.TimeoutError` or :class:`~.ConnectionError` or :class:`~.ClusterDownError` are encountered - :param retry: - A retry object that defines the retry strategy and the number of - retries for the cluster client. 
- In current implementation for the cluster client (starting form redis-py version 6.0.0) - the retry object is not yet fully utilized, instead it is used just to determine - the number of retries for the cluster client. - In the future releases the retry object will be used to handle the cluster client retries! :param reinitialize_steps: Specifies the number of MOVED errors that need to occur before reinitializing the whole cluster topology. If a MOVED error occurs @@ -600,8 +557,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): :**kwargs: Extra arguments that will be sent into Redis instance when created - (See Official redis-py doc for supported kwargs - the only limitation - is that you can't provide 'retry' object as part of kwargs. + (See Official redis-py doc for supported kwargs [https://github.com/andymccurdy/redis-py/blob/master/redis/client.py]) Some kwargs are not supported and will raise a RedisClusterException: @@ -616,15 +572,6 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): "Argument 'db' is not possible to use in cluster mode" ) - if "retry" in kwargs: - # Argument 'retry' is not possible to be used in kwargs when in cluster mode - # the kwargs are set to the lower level connections to the cluster nodes - # and there we provide retry configuration without retries allowed. - # The retries should be handled on cluster client level. - raise RedisClusterException( - "The 'retry' argument cannot be used in kwargs when running in cluster mode." - ) - # Get the startup node/s from_url = False if url is not None: @@ -667,40 +614,27 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): kwargs = cleanup_kwargs(**kwargs) if retry: self.retry = retry + kwargs.update({"retry": self.retry}) else: - self.retry = Retry( - backoff=ExponentialWithJitterBackoff(base=1, cap=10), - retries=cluster_error_retry_attempts, - ) + kwargs.update({"retry": Retry(default_backoff(), 0)}) self.encoder = Encoder( kwargs.get("encoding", "utf-8"), kwargs.get("encoding_errors", "strict"), kwargs.get("decode_responses", False), ) - protocol = kwargs.get("protocol", None) - if (cache_config or cache) and protocol not in [3, "3"]: - raise RedisError("Client caching is only supported with RESP version 3") - + self.cluster_error_retry_attempts = cluster_error_retry_attempts self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.node_flags = self.__class__.NODE_FLAGS.copy() self.read_from_replicas = read_from_replicas - self.load_balancing_strategy = load_balancing_strategy self.reinitialize_counter = 0 self.reinitialize_steps = reinitialize_steps - if event_dispatcher is None: - self._event_dispatcher = EventDispatcher() - else: - self._event_dispatcher = event_dispatcher self.nodes_manager = NodesManager( startup_nodes=startup_nodes, from_url=from_url, require_full_coverage=require_full_coverage, dynamic_startup_nodes=dynamic_startup_nodes, address_remap=address_remap, - cache=cache, - cache_config=cache_config, - event_dispatcher=self._event_dispatcher, **kwargs, ) @@ -708,9 +642,8 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): self.__class__.CLUSTER_COMMANDS_RESPONSE_CALLBACKS ) self.result_callbacks = CaseInsensitiveDict(self.__class__.RESULT_CALLBACKS) - self.commands_parser = CommandsParser(self) - self._lock = threading.RLock() + self._lock = threading.Lock() def __enter__(self): return self @@ -719,10 +652,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): self.close() def __del__(self): - try: - self.close() - 
except Exception: - pass + self.close() def disconnect_connection_pools(self): for node in self.get_nodes(): @@ -738,9 +668,10 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): Initialize the connection, authenticate and select a database and send READONLY if it is set during object initialization. """ + connection.set_parser(ClusterParser) connection.on_connect() - if self.read_from_replicas or self.load_balancing_strategy: + if self.read_from_replicas: # Sending READONLY command to server to configure connection as # readonly. Since each cluster node may change its server type due # to a failover, we should establish a READONLY connection @@ -753,7 +684,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): if self.user_on_connect_func is not None: self.user_on_connect_func(connection) - def get_redis_connection(self, node: "ClusterNode") -> Redis: + def get_redis_connection(self, node): if not node.redis_connection: with self._lock: if not node.redis_connection: @@ -812,8 +743,13 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): self.nodes_manager.default_node = node return True - def set_retry(self, retry: Retry) -> None: + def get_retry(self) -> Optional["Retry"]: + return self.retry + + def set_retry(self, retry: "Retry") -> None: self.retry = retry + for node in self.get_nodes(): + node.redis_connection.set_retry(retry) def monitor(self, target_node=None): """ @@ -851,18 +787,19 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): if shard_hint: raise RedisClusterException("shard_hint is deprecated in cluster mode") + if transaction: + raise RedisClusterException("transaction is deprecated in cluster mode") + return ClusterPipeline( nodes_manager=self.nodes_manager, commands_parser=self.commands_parser, startup_nodes=self.nodes_manager.startup_nodes, result_callbacks=self.result_callbacks, cluster_response_callbacks=self.cluster_response_callbacks, + cluster_error_retry_attempts=self.cluster_error_retry_attempts, read_from_replicas=self.read_from_replicas, - load_balancing_strategy=self.load_balancing_strategy, reinitialize_steps=self.reinitialize_steps, - retry=self.retry, lock=self._lock, - transaction=transaction, ) def lock( @@ -874,7 +811,6 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): blocking_timeout=None, lock_class=None, thread_local=True, - raise_on_release_error: bool = True, ): """ Return a new Lock object using key ``name`` that mimics @@ -921,11 +857,6 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): thread-1 would see the token value as "xyz" and would be able to successfully release the thread-2's lock. - ``raise_on_release_error`` indicates whether to raise an exception when - the lock is no longer owned when exiting the context manager. By default, - this is True, meaning an exception will be raised. If False, the warning - will be logged and the exception will be suppressed. - In some use cases it's necessary to disable thread local storage. For example, if you have code where one thread acquires a lock and passes that lock instance to a worker thread to release later. 
If thread @@ -943,7 +874,6 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): blocking=blocking, blocking_timeout=blocking_timeout, thread_local=thread_local, - raise_on_release_error=raise_on_release_error, ) def set_response_callback(self, command, callback): @@ -985,9 +915,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): # get the node that holds the key's slot slot = self.determine_slot(*args) node = self.nodes_manager.get_node_from_slot( - slot, - self.read_from_replicas and command in READ_COMMANDS, - self.load_balancing_strategy if command in READ_COMMANDS else None, + slot, self.read_from_replicas and command in READ_COMMANDS ) return [node] @@ -1024,7 +952,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): redis_conn = self.get_default_node().redis_connection return self.commands_parser.get_keys(redis_conn, *args) - def determine_slot(self, *args) -> int: + def determine_slot(self, *args): """ Figure out what slot to use based on args. @@ -1117,14 +1045,11 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): return nodes def execute_command(self, *args, **kwargs): - return self._internal_execute_command(*args, **kwargs) - - def _internal_execute_command(self, *args, **kwargs): """ Wrapper for ERRORS_ALLOW_RETRY error handling. - It will try the number of times specified by the retries property from - config option "self.retry" which defaults to 3 unless manually + It will try the number of times specified by the config option + "self.cluster_error_retry_attempts" which defaults to 3 unless manually configured. If it reaches the number of times, the command will raise the exception @@ -1150,7 +1075,9 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): # execution since the nodes may not be valid anymore after the tables # were reinitialized. So in case of passed target nodes, # retry_attempts will be set to 0. - retry_attempts = 0 if target_nodes_specified else self.retry.get_retries() + retry_attempts = ( + 0 if target_nodes_specified else self.cluster_error_retry_attempts + ) # Add one for the first execution execute_attempts = 1 + retry_attempts for _ in range(execute_attempts): @@ -1209,26 +1136,19 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): # refresh the target node slot = self.determine_slot(*args) target_node = self.nodes_manager.get_node_from_slot( - slot, - self.read_from_replicas and command in READ_COMMANDS, - self.load_balancing_strategy - if command in READ_COMMANDS - else None, + slot, self.read_from_replicas and command in READ_COMMANDS ) moved = False redis_node = self.get_redis_connection(target_node) - connection = get_connection(redis_node) + connection = get_connection(redis_node, *args, **kwargs) if asking: connection.send_command("ASKING") redis_node.parse_response(connection, "ASKING", **kwargs) asking = False - connection.send_command(*args, **kwargs) + + connection.send_command(*args) response = redis_node.parse_response(connection, command, **kwargs) - - # Remove keys entry, it needs only for cache. - kwargs.pop("keys", None) - if command in self.cluster_response_callbacks: response = self.cluster_response_callbacks[command]( response, **kwargs @@ -1236,13 +1156,9 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): return response except AuthenticationError: raise - except MaxConnectionsError: - # MaxConnectionsError indicates client-side resource exhaustion - # (too many connections in the pool), not a node failure. 
- # Don't treat this as a node failure - just re-raise the error - # without reinitializing the cluster. - raise except (ConnectionError, TimeoutError) as e: + # Connection retries are being handled in the node's + # Retry object. # ConnectionError can also be raised if we couldn't get a # connection from the pool before timing out, so check that # this is an actual connection before attempting to disconnect. @@ -1279,19 +1195,13 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): except AskError as e: redirect_addr = get_node_name(host=e.host, port=e.port) asking = True - except (ClusterDownError, SlotNotCoveredError): + except ClusterDownError as e: # ClusterDownError can occur during a failover and to get # self-healed, we will try to reinitialize the cluster layout # and retry executing the command - - # SlotNotCoveredError can occur when the cluster is not fully - # initialized or can be temporary issue. - # We will try to reinitialize the cluster topology - # and retry executing the command - time.sleep(0.25) self.nodes_manager.initialize() - raise + raise e except ResponseError: raise except Exception as e: @@ -1304,7 +1214,7 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): raise ClusterError("TTL exhausted.") - def close(self) -> None: + def close(self): try: with self._lock: if self.nodes_manager: @@ -1343,28 +1253,6 @@ class RedisCluster(AbstractRedisCluster, RedisClusterCommands): """ setattr(self, funcname, func) - def transaction(self, func, *watches, **kwargs): - """ - Convenience method for executing the callable `func` as a transaction - while watching all keys specified in `watches`. The 'func' callable - should expect a single argument which is a Pipeline object. - """ - shard_hint = kwargs.pop("shard_hint", None) - value_from_callable = kwargs.pop("value_from_callable", False) - watch_delay = kwargs.pop("watch_delay", None) - with self.pipeline(True, shard_hint) as pipe: - while True: - try: - if watches: - pipe.watch(*watches) - func_value = func(pipe) - exec_value = pipe.execute() - return func_value if value_from_callable else exec_value - except WatchError: - if watch_delay is not None and watch_delay > 0: - time.sleep(watch_delay) - continue - class ClusterNode: def __init__(self, host, port, server_type=None, redis_connection=None): @@ -1390,18 +1278,8 @@ class ClusterNode: return isinstance(obj, ClusterNode) and obj.name == self.name def __del__(self): - try: - if self.redis_connection is not None: - self.redis_connection.close() - except Exception: - # Ignore errors when closing the connection - pass - - -class LoadBalancingStrategy(Enum): - ROUND_ROBIN = "round_robin" - ROUND_ROBIN_REPLICAS = "round_robin_replicas" - RANDOM_REPLICA = "random_replica" + if self.redis_connection is not None: + self.redis_connection.close() class LoadBalancer: @@ -1413,38 +1291,15 @@ class LoadBalancer: self.primary_to_idx = {} self.start_index = start_index - def get_server_index( - self, - primary: str, - list_size: int, - load_balancing_strategy: LoadBalancingStrategy = LoadBalancingStrategy.ROUND_ROBIN, - ) -> int: - if load_balancing_strategy == LoadBalancingStrategy.RANDOM_REPLICA: - return self._get_random_replica_index(list_size) - else: - return self._get_round_robin_index( - primary, - list_size, - load_balancing_strategy == LoadBalancingStrategy.ROUND_ROBIN_REPLICAS, - ) + def get_server_index(self, primary: str, list_size: int) -> int: + server_index = self.primary_to_idx.setdefault(primary, self.start_index) + # Update the index + 
self.primary_to_idx[primary] = (server_index + 1) % list_size + return server_index def reset(self) -> None: self.primary_to_idx.clear() - def _get_random_replica_index(self, list_size: int) -> int: - return random.randint(1, list_size - 1) - - def _get_round_robin_index( - self, primary: str, list_size: int, replicas_only: bool - ) -> int: - server_index = self.primary_to_idx.setdefault(primary, self.start_index) - if replicas_only and server_index == 0: - # skip the primary node index - server_index = 1 - # Update the index for the next round - self.primary_to_idx[primary] = (server_index + 1) % list_size - return server_index - class NodesManager: def __init__( @@ -1455,14 +1310,10 @@ class NodesManager: lock=None, dynamic_startup_nodes=True, connection_pool_class=ConnectionPool, - address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None, - cache: Optional[CacheInterface] = None, - cache_config: Optional[CacheConfig] = None, - cache_factory: Optional[CacheFactoryInterface] = None, - event_dispatcher: Optional[EventDispatcher] = None, + address_remap: Optional[Callable[[str, int], Tuple[str, int]]] = None, **kwargs, ): - self.nodes_cache: Dict[str, Redis] = {} + self.nodes_cache = {} self.slots_cache = {} self.startup_nodes = {} self.default_node = None @@ -1472,22 +1323,12 @@ class NodesManager: self._dynamic_startup_nodes = dynamic_startup_nodes self.connection_pool_class = connection_pool_class self.address_remap = address_remap - self._cache = cache - self._cache_config = cache_config - self._cache_factory = cache_factory self._moved_exception = None self.connection_kwargs = kwargs self.read_load_balancer = LoadBalancer() if lock is None: - lock = threading.RLock() + lock = threading.Lock() self._lock = lock - if event_dispatcher is None: - self._event_dispatcher = EventDispatcher() - else: - self._event_dispatcher = event_dispatcher - self._credential_provider = self.connection_kwargs.get( - "credential_provider", None - ) self.initialize() def get_node(self, host=None, port=None, node_name=None): @@ -1548,21 +1389,7 @@ class NodesManager: # Reset moved_exception self._moved_exception = None - @deprecated_args( - args_to_warn=["server_type"], - reason=( - "In case you need select some load balancing strategy " - "that will use replicas, please set it through 'load_balancing_strategy'" - ), - version="5.3.0", - ) - def get_node_from_slot( - self, - slot, - read_from_replicas=False, - load_balancing_strategy=None, - server_type=None, - ) -> ClusterNode: + def get_node_from_slot(self, slot, read_from_replicas=False, server_type=None): """ Gets a node that servers this hash slot """ @@ -1577,14 +1404,11 @@ class NodesManager: f'"require_full_coverage={self._require_full_coverage}"' ) - if read_from_replicas is True and load_balancing_strategy is None: - load_balancing_strategy = LoadBalancingStrategy.ROUND_ROBIN - - if len(self.slots_cache[slot]) > 1 and load_balancing_strategy: - # get the server index using the strategy defined in load_balancing_strategy + if read_from_replicas is True: + # get the server index in a Round-Robin manner primary_name = self.slots_cache[slot][0].name node_idx = self.read_load_balancer.get_server_index( - primary_name, len(self.slots_cache[slot]), load_balancing_strategy + primary_name, len(self.slots_cache[slot]) ) elif ( server_type is None @@ -1631,49 +1455,20 @@ class NodesManager: """ This function will create a redis connection to all nodes in :nodes: """ - connection_pools = [] for node in nodes: if node.redis_connection is None: 
node.redis_connection = self.create_redis_node( host=node.host, port=node.port, **self.connection_kwargs ) - connection_pools.append(node.redis_connection.connection_pool) - - self._event_dispatcher.dispatch( - AfterPooledConnectionsInstantiationEvent( - connection_pools, ClientType.SYNC, self._credential_provider - ) - ) def create_redis_node(self, host, port, **kwargs): - # We are configuring the connection pool not to retry - # connections on lower level clients to avoid retrying - # connections to nodes that are not reachable - # and to avoid blocking the connection pool. - # The only error that will have some handling in the lower - # level clients is ConnectionError which will trigger disconnection - # of the socket. - # The retries will be handled on cluster client level - # where we will have proper handling of the cluster topology - node_retry_config = Retry( - backoff=NoBackoff(), retries=0, supported_errors=(ConnectionError,) - ) - if self.from_url: # Create a redis node with a costumed connection pool kwargs.update({"host": host}) kwargs.update({"port": port}) - kwargs.update({"cache": self._cache}) - kwargs.update({"retry": node_retry_config}) r = Redis(connection_pool=self.connection_pool_class(**kwargs)) else: - r = Redis( - host=host, - port=port, - cache=self._cache, - retry=node_retry_config, - **kwargs, - ) + r = Redis(host=host, port=port, **kwargs) return r def _get_or_create_cluster_node(self, host, port, role, tmp_nodes_cache): @@ -1690,8 +1485,6 @@ class NodesManager: target_node = ClusterNode(host, port, role) if target_node.server_type != role: target_node.server_type = role - # add this node to the nodes cache - tmp_nodes_cache[target_node.name] = target_node return target_node @@ -1709,9 +1502,7 @@ class NodesManager: fully_covered = False kwargs = self.connection_kwargs exception = None - # Convert to tuple to prevent RuntimeError if self.startup_nodes - # is modified during iteration - for startup_node in tuple(self.startup_nodes.values()): + for startup_node in self.startup_nodes.values(): try: if startup_node.redis_connection: r = startup_node.redis_connection @@ -1722,13 +1513,11 @@ class NodesManager: ) self.startup_nodes[startup_node.name].redis_connection = r # Make sure cluster mode is enabled on this node - try: - cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS")) - r.connection_pool.disconnect() - except ResponseError: + if bool(r.info().get("cluster_enabled")) is False: raise RedisClusterException( "Cluster mode is not enabled on this node" ) + cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS")) startup_nodes_reachable = True except Exception as e: # Try the next startup node. 
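# A small sketch (assumed, simplified) of the slot-table construction that
# NodesManager.initialize() performs: the CLUSTER SLOTS reply fetched above is a list
# of [start_slot, end_slot, primary, *replicas] entries, and in the hunk that follows
# every slot in each range is mapped to the primary followed by its replicas.
cluster_slots = [
    [0, 5460, ["127.0.0.1", 7000, "node-id-0"], ["127.0.0.1", 7003, "node-id-3"]],
    [5461, 10922, ["127.0.0.1", 7001, "node-id-1"], ["127.0.0.1", 7004, "node-id-4"]],
]

slots_cache = {}
for start, end, primary, *replicas in cluster_slots:
    nodes_for_slot = [tuple(primary[:2])] + [tuple(r[:2]) for r in replicas]
    for slot in range(start, end + 1):
        slots_cache[slot] = nodes_for_slot

print(slots_cache[100])   # [('127.0.0.1', 7000), ('127.0.0.1', 7003)]
# Index 0 is always the primary; get_node_from_slot() uses index 0 for writes and,
# when read_from_replicas is enabled, round-robins over the full list for reads.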
@@ -1758,26 +1547,31 @@ class NodesManager: port = int(primary_node[1]) host, port = self.remap_host_port(host, port) - nodes_for_slot = [] - target_node = self._get_or_create_cluster_node( host, port, PRIMARY, tmp_nodes_cache ) - nodes_for_slot.append(target_node) - - replica_nodes = slot[3:] - for replica_node in replica_nodes: - host = str_if_bytes(replica_node[0]) - port = int(replica_node[1]) - host, port = self.remap_host_port(host, port) - target_replica_node = self._get_or_create_cluster_node( - host, port, REPLICA, tmp_nodes_cache - ) - nodes_for_slot.append(target_replica_node) + # add this node to the nodes cache + tmp_nodes_cache[target_node.name] = target_node for i in range(int(slot[0]), int(slot[1]) + 1): if i not in tmp_slots: - tmp_slots[i] = nodes_for_slot + tmp_slots[i] = [] + tmp_slots[i].append(target_node) + replica_nodes = [slot[j] for j in range(3, len(slot))] + + for replica_node in replica_nodes: + host = str_if_bytes(replica_node[0]) + port = replica_node[1] + host, port = self.remap_host_port(host, port) + + target_replica_node = self._get_or_create_cluster_node( + host, port, REPLICA, tmp_nodes_cache + ) + tmp_slots[i].append(target_replica_node) + # add this node to the nodes cache + tmp_nodes_cache[ + target_replica_node.name + ] = target_replica_node else: # Validate that 2 nodes want to use the same slot cache # setup @@ -1790,7 +1584,7 @@ class NodesManager: if len(disagreements) > 5: raise RedisClusterException( f"startup_nodes could not agree on a valid " - f"slots cache: {', '.join(disagreements)}" + f'slots cache: {", ".join(disagreements)}' ) fully_covered = self.check_slots_coverage(tmp_slots) @@ -1805,12 +1599,6 @@ class NodesManager: f"one reachable node: {str(exception)}" ) from exception - if self._cache is None and self._cache_config is not None: - if self._cache_factory is None: - self._cache = CacheFactory(self._cache_config).get_cache() - else: - self._cache = self._cache_factory.get_cache() - # Create Redis connections to all nodes self.create_redis_connections(list(tmp_nodes_cache.values())) @@ -1835,7 +1623,7 @@ class NodesManager: # If initialize was called after a MovedError, clear it self._moved_exception = None - def close(self) -> None: + def close(self): self.default_node = None for node in self.nodes_cache.values(): if node.redis_connection: @@ -1858,16 +1646,6 @@ class NodesManager: return self.address_remap((host, port)) return host, port - def find_connection_owner(self, connection: Connection) -> Optional[Redis]: - node_name = get_node_name(connection.host, connection.port) - for node in tuple(self.nodes_cache.values()): - if node.redis_connection: - conn_args = node.redis_connection.connection_pool.connection_kwargs - if node_name == get_node_name( - conn_args.get("host"), conn_args.get("port") - ): - return node - class ClusterPubSub(PubSub): """ @@ -1885,7 +1663,6 @@ class ClusterPubSub(PubSub): host=None, port=None, push_handler_func=None, - event_dispatcher: Optional["EventDispatcher"] = None, **kwargs, ): """ @@ -1894,7 +1671,7 @@ class ClusterPubSub(PubSub): first command execution. The node will be determined by: 1. Hashing the channel name in the request to find its keyslot 2. Selecting a node that handles the keyslot: If read_from_replicas is - set to true or load_balancing_strategy is set, a replica can be selected. + set to true, a replica can be selected. 
:type redis_cluster: RedisCluster :type node: ClusterNode @@ -1911,15 +1688,10 @@ class ClusterPubSub(PubSub): self.cluster = redis_cluster self.node_pubsub_mapping = {} self._pubsubs_generator = self._pubsubs_generator() - if event_dispatcher is None: - self._event_dispatcher = EventDispatcher() - else: - self._event_dispatcher = event_dispatcher super().__init__( connection_pool=connection_pool, encoder=redis_cluster.encoder, push_handler_func=push_handler_func, - event_dispatcher=self._event_dispatcher, **kwargs, ) @@ -1990,9 +1762,7 @@ class ClusterPubSub(PubSub): channel = args[1] slot = self.cluster.keyslot(channel) node = self.cluster.nodes_manager.get_node_from_slot( - slot, - self.cluster.read_from_replicas, - self.cluster.load_balancing_strategy, + slot, self.cluster.read_from_replicas ) else: # Get a random node @@ -2000,17 +1770,14 @@ class ClusterPubSub(PubSub): self.node = node redis_connection = self.cluster.get_redis_connection(node) self.connection_pool = redis_connection.connection_pool - self.connection = self.connection_pool.get_connection() + self.connection = self.connection_pool.get_connection( + "pubsub", self.shard_hint + ) # register a callback that re-subscribes to any channels we # were listening to when we were disconnected - self.connection.register_connect_callback(self.on_connect) - if self.push_handler_func is not None: - self.connection._parser.set_pubsub_push_handler(self.push_handler_func) - self._event_dispatcher.dispatch( - AfterPubSubConnectionInstantiationEvent( - self.connection, self.connection_pool, ClientType.SYNC, self._lock - ) - ) + self.connection._register_connect_callback(self.on_connect) + if self.push_handler_func is not None and not HIREDIS_AVAILABLE: + self.connection._parser.set_push_handler(self.push_handler_func) connection = self.connection self._execute(connection, connection.send_command, *args) @@ -2034,7 +1801,8 @@ class ClusterPubSub(PubSub): def _pubsubs_generator(self): while True: - yield from self.node_pubsub_mapping.values() + for pubsub in self.node_pubsub_mapping.values(): + yield pubsub def get_sharded_message( self, ignore_subscribe_messages=False, timeout=0.0, target_node=None @@ -2127,17 +1895,6 @@ class ClusterPipeline(RedisCluster): TryAgainError, ) - NO_SLOTS_COMMANDS = {"UNWATCH"} - IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"} - UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} - - @deprecated_args( - args_to_warn=[ - "cluster_error_retry_attempts", - ], - reason="Please configure the 'retry' object instead", - version="6.0.0", - ) def __init__( self, nodes_manager: "NodesManager", @@ -2146,12 +1903,9 @@ class ClusterPipeline(RedisCluster): cluster_response_callbacks: Optional[Dict[str, Callable]] = None, startup_nodes: Optional[List["ClusterNode"]] = None, read_from_replicas: bool = False, - load_balancing_strategy: Optional[LoadBalancingStrategy] = None, cluster_error_retry_attempts: int = 3, reinitialize_steps: int = 5, - retry: Optional[Retry] = None, lock=None, - transaction=False, **kwargs, ): """ """ @@ -2164,31 +1918,19 @@ class ClusterPipeline(RedisCluster): ) self.startup_nodes = startup_nodes if startup_nodes else [] self.read_from_replicas = read_from_replicas - self.load_balancing_strategy = load_balancing_strategy self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.cluster_response_callbacks = cluster_response_callbacks + self.cluster_error_retry_attempts = cluster_error_retry_attempts self.reinitialize_counter = 0 self.reinitialize_steps = reinitialize_steps - if retry is not None: - 
self.retry = retry - else: - self.retry = Retry( - backoff=ExponentialWithJitterBackoff(base=1, cap=10), - retries=cluster_error_retry_attempts, - ) - self.encoder = Encoder( kwargs.get("encoding", "utf-8"), kwargs.get("encoding_errors", "strict"), kwargs.get("decode_responses", False), ) if lock is None: - lock = threading.RLock() + lock = threading.Lock() self._lock = lock - self.parent_execute_command = super().execute_command - self._execution_strategy: ExecutionStrategy = ( - PipelineStrategy(self) if not transaction else TransactionStrategy(self) - ) def __repr__(self): """ """ @@ -2210,7 +1952,7 @@ class ClusterPipeline(RedisCluster): def __len__(self): """ """ - return len(self._execution_strategy.command_queue) + return len(self.command_stack) def __bool__(self): "Pipeline instances should always evaluate to True on Python 3+" @@ -2220,35 +1962,45 @@ class ClusterPipeline(RedisCluster): """ Wrapper function for pipeline_execute_command """ - return self._execution_strategy.execute_command(*args, **kwargs) + return self.pipeline_execute_command(*args, **kwargs) def pipeline_execute_command(self, *args, **options): """ - Stage a command to be executed when execute() is next called - - Returns the current Pipeline object back so commands can be - chained together, such as: - - pipe = pipe.set('foo', 'bar').incr('baz').decr('bang') - - At some other point, you can then run: pipe.execute(), - which will execute all commands queued in the pipe. + Appends the executed command to the pipeline's command stack """ - return self._execution_strategy.execute_command(*args, **options) + self.command_stack.append( + PipelineCommand(args, options, len(self.command_stack)) + ) + return self + + def raise_first_error(self, stack): + """ + Raise the first exception on the stack + """ + for c in stack: + r = c.result + if isinstance(r, Exception): + self.annotate_exception(r, c.position + 1, c.args) + raise r def annotate_exception(self, exception, number, command): """ Provides extra context to the exception prior to it being handled """ - self._execution_strategy.annotate_exception(exception, number, command) + cmd = " ".join(map(safe_str, command)) + msg = ( + f"Command # {number} ({cmd}) of pipeline " + f"caused error: {exception.args[0]}" + ) + exception.args = (msg,) + exception.args[1:] - def execute(self, raise_on_error: bool = True) -> List[Any]: + def execute(self, raise_on_error=True): """ Execute all the commands in the current pipeline """ - + stack = self.command_stack try: - return self._execution_strategy.execute(raise_on_error) + return self.send_cluster_commands(stack, raise_on_error) finally: self.reset() @@ -2256,53 +2008,306 @@ class ClusterPipeline(RedisCluster): """ Reset back to empty pipeline. 
""" - self._execution_strategy.reset() + self.command_stack = [] + + self.scripts = set() + + # TODO: Implement + # make sure to reset the connection state in the event that we were + # watching something + # if self.watching and self.connection: + # try: + # # call this manually since our unwatch or + # # immediate_execute_command methods can call reset() + # self.connection.send_command('UNWATCH') + # self.connection.read_response() + # except ConnectionError: + # # disconnect will also remove any previous WATCHes + # self.connection.disconnect() + + # clean up the other instance attributes + self.watching = False + self.explicit_transaction = False + + # TODO: Implement + # we can safely return the connection to the pool here since we're + # sure we're no longer WATCHing anything + # if self.connection: + # self.connection_pool.release(self.connection) + # self.connection = None def send_cluster_commands( self, stack, raise_on_error=True, allow_redirections=True ): - return self._execution_strategy.send_cluster_commands( - stack, raise_on_error=raise_on_error, allow_redirections=allow_redirections + """ + Wrapper for CLUSTERDOWN error handling. + + If the cluster reports it is down it is assumed that: + - connection_pool was disconnected + - connection_pool was reseted + - refereh_table_asap set to True + + It will try the number of times specified by + the config option "self.cluster_error_retry_attempts" + which defaults to 3 unless manually configured. + + If it reaches the number of times, the command will + raises ClusterDownException. + """ + if not stack: + return [] + retry_attempts = self.cluster_error_retry_attempts + while True: + try: + return self._send_cluster_commands( + stack, + raise_on_error=raise_on_error, + allow_redirections=allow_redirections, + ) + except (ClusterDownError, ConnectionError) as e: + if retry_attempts > 0: + # Try again with the new cluster setup. All other errors + # should be raised. + retry_attempts -= 1 + pass + else: + raise e + + def _send_cluster_commands( + self, stack, raise_on_error=True, allow_redirections=True + ): + """ + Send a bunch of cluster commands to the redis cluster. + + `allow_redirections` If the pipeline should follow + `ASK` & `MOVED` responses automatically. If set + to false it will raise RedisClusterException. + """ + # the first time sending the commands we send all of + # the commands that were queued up. + # if we have to run through it again, we only retry + # the commands that failed. + attempt = sorted(stack, key=lambda x: x.position) + is_default_node = False + # build a list of node objects based on node names we need to + nodes = {} + + # as we move through each command that still needs to be processed, + # we figure out the slot number that command maps to, then from + # the slot determine the node. + for c in attempt: + while True: + # refer to our internal node -> slot table that + # tells us where a given command should route to. 
+ # (it might be possible we have a cached node that no longer + # exists in the cluster, which is why we do this in a loop) + passed_targets = c.options.pop("target_nodes", None) + if passed_targets and not self._is_nodes_flag(passed_targets): + target_nodes = self._parse_target_nodes(passed_targets) + else: + target_nodes = self._determine_nodes( + *c.args, node_flag=passed_targets + ) + if not target_nodes: + raise RedisClusterException( + f"No targets were found to execute {c.args} command on" + ) + if len(target_nodes) > 1: + raise RedisClusterException( + f"Too many targets for command {c.args}" + ) + + node = target_nodes[0] + if node == self.get_default_node(): + is_default_node = True + + # now that we know the name of the node + # ( it's just a string in the form of host:port ) + # we can build a list of commands for each node. + node_name = node.name + if node_name not in nodes: + redis_node = self.get_redis_connection(node) + try: + connection = get_connection(redis_node, c.args) + except ConnectionError: + # Connection retries are being handled in the node's + # Retry object. Reinitialize the node -> slot table. + self.nodes_manager.initialize() + if is_default_node: + self.replace_default_node() + raise + nodes[node_name] = NodeCommands( + redis_node.parse_response, + redis_node.connection_pool, + connection, + ) + nodes[node_name].append(c) + break + + # send the commands in sequence. + # we write to all the open sockets for each node first, + # before reading anything + # this allows us to flush all the requests out across the + # network essentially in parallel + # so that we can read them all in parallel as they come back. + # we dont' multiplex on the sockets as they come available, + # but that shouldn't make too much difference. + node_commands = nodes.values() + for n in node_commands: + n.write() + + for n in node_commands: + n.read() + + # release all of the redis connections we allocated earlier + # back into the connection pool. + # we used to do this step as part of a try/finally block, + # but it is really dangerous to + # release connections back into the pool if for some + # reason the socket has data still left in it + # from a previous operation. The write and + # read operations already have try/catch around them for + # all known types of errors including connection + # and socket level errors. + # So if we hit an exception, something really bad + # happened and putting any oF + # these connections back into the pool is a very bad idea. + # the socket might have unread buffer still sitting in it, + # and then the next time we read from it we pass the + # buffered result back from a previous command and + # every single request after to that connection will always get + # a mismatched result. + for n in nodes.values(): + n.connection_pool.release(n.connection) + + # if the response isn't an exception it is a + # valid response from the node + # we're all done with that command, YAY! + # if we have more commands to attempt, we've run into problems. + # collect all the commands we are allowed to retry. + # (MOVED, ASK, or connection errors or timeout errors) + attempt = sorted( + ( + c + for c in attempt + if isinstance(c.result, ClusterPipeline.ERRORS_ALLOW_RETRY) + ), + key=lambda x: x.position, ) + if attempt and allow_redirections: + # RETRY MAGIC HAPPENS HERE! + # send these remaing commands one at a time using `execute_command` + # in the main client. 
This keeps our retry logic + # in one place mostly, + # and allows us to be more confident in correctness of behavior. + # at this point any speed gains from pipelining have been lost + # anyway, so we might as well make the best + # attempt to get the correct behavior. + # + # The client command will handle retries for each + # individual command sequentially as we pass each + # one into `execute_command`. Any exceptions + # that bubble out should only appear once all + # retries have been exhausted. + # + # If a lot of commands have failed, we'll be setting the + # flag to rebuild the slots table from scratch. + # So MOVED errors should correct themselves fairly quickly. + self.reinitialize_counter += 1 + if self._should_reinitialized(): + self.nodes_manager.initialize() + if is_default_node: + self.replace_default_node() + for c in attempt: + try: + # send each command individually like we + # do in the main client. + c.result = super().execute_command(*c.args, **c.options) + except RedisError as e: + c.result = e + + # turn the response back into a simple flat array that corresponds + # to the sequence of commands issued in the stack in pipeline.execute() + response = [] + for c in sorted(stack, key=lambda x: x.position): + if c.args[0] in self.cluster_response_callbacks: + c.result = self.cluster_response_callbacks[c.args[0]]( + c.result, **c.options + ) + response.append(c.result) + + if raise_on_error: + self.raise_first_error(stack) + + return response + + def _fail_on_redirect(self, allow_redirections): + """ """ + if not allow_redirections: + raise RedisClusterException( + "ASK & MOVED redirection not allowed in this pipeline" + ) def exists(self, *keys): - return self._execution_strategy.exists(*keys) + return self.execute_command("EXISTS", *keys) def eval(self): """ """ - return self._execution_strategy.eval() + raise RedisClusterException("method eval() is not implemented") def multi(self): - """ - Start a transactional block of the pipeline after WATCH commands - are issued. End the transactional block with `execute`. 
- """ - self._execution_strategy.multi() + """ """ + raise RedisClusterException("method multi() is not implemented") + + def immediate_execute_command(self, *args, **options): + """ """ + raise RedisClusterException( + "method immediate_execute_command() is not implemented" + ) + + def _execute_transaction(self, *args, **kwargs): + """ """ + raise RedisClusterException("method _execute_transaction() is not implemented") def load_scripts(self): """ """ - self._execution_strategy.load_scripts() - - def discard(self): - """ """ - self._execution_strategy.discard() + raise RedisClusterException("method load_scripts() is not implemented") def watch(self, *names): - """Watches the values at keys ``names``""" - self._execution_strategy.watch(*names) + """ """ + raise RedisClusterException("method watch() is not implemented") def unwatch(self): - """Unwatches all previously specified keys""" - self._execution_strategy.unwatch() + """ """ + raise RedisClusterException("method unwatch() is not implemented") def script_load_for_pipeline(self, *args, **kwargs): - self._execution_strategy.script_load_for_pipeline(*args, **kwargs) + """ """ + raise RedisClusterException( + "method script_load_for_pipeline() is not implemented" + ) def delete(self, *names): - self._execution_strategy.delete(*names) + """ + "Delete a key specified by ``names``" + """ + if len(names) != 1: + raise RedisClusterException( + "deleting multiple keys is not implemented in pipeline command" + ) + + return self.execute_command("DEL", names[0]) def unlink(self, *names): - self._execution_strategy.unlink(*names) + """ + "Unlink a key specified by ``names``" + """ + if len(names) != 1: + raise RedisClusterException( + "unlinking multiple keys is not implemented in pipeline command" + ) + + return self.execute_command("UNLINK", names[0]) def block_pipeline_command(name: str) -> Callable[..., Any]: @@ -2479,881 +2484,3 @@ class NodeCommands: return except RedisError: c.result = sys.exc_info()[1] - - -class ExecutionStrategy(ABC): - @property - @abstractmethod - def command_queue(self): - pass - - @abstractmethod - def execute_command(self, *args, **kwargs): - """ - Execution flow for current execution strategy. - - See: ClusterPipeline.execute_command() - """ - pass - - @abstractmethod - def annotate_exception(self, exception, number, command): - """ - Annotate exception according to current execution strategy. - - See: ClusterPipeline.annotate_exception() - """ - pass - - @abstractmethod - def pipeline_execute_command(self, *args, **options): - """ - Pipeline execution flow for current execution strategy. - - See: ClusterPipeline.pipeline_execute_command() - """ - pass - - @abstractmethod - def execute(self, raise_on_error: bool = True) -> List[Any]: - """ - Executes current execution strategy. - - See: ClusterPipeline.execute() - """ - pass - - @abstractmethod - def send_cluster_commands( - self, stack, raise_on_error=True, allow_redirections=True - ): - """ - Sends commands according to current execution strategy. - - See: ClusterPipeline.send_cluster_commands() - """ - pass - - @abstractmethod - def reset(self): - """ - Resets current execution strategy. - - See: ClusterPipeline.reset() - """ - pass - - @abstractmethod - def exists(self, *keys): - pass - - @abstractmethod - def eval(self): - pass - - @abstractmethod - def multi(self): - """ - Starts transactional context. 
- - See: ClusterPipeline.multi() - """ - pass - - @abstractmethod - def load_scripts(self): - pass - - @abstractmethod - def watch(self, *names): - pass - - @abstractmethod - def unwatch(self): - """ - Unwatches all previously specified keys - - See: ClusterPipeline.unwatch() - """ - pass - - @abstractmethod - def script_load_for_pipeline(self, *args, **kwargs): - pass - - @abstractmethod - def delete(self, *names): - """ - "Delete a key specified by ``names``" - - See: ClusterPipeline.delete() - """ - pass - - @abstractmethod - def unlink(self, *names): - """ - "Unlink a key specified by ``names``" - - See: ClusterPipeline.unlink() - """ - pass - - @abstractmethod - def discard(self): - pass - - -class AbstractStrategy(ExecutionStrategy): - def __init__( - self, - pipe: ClusterPipeline, - ): - self._command_queue: List[PipelineCommand] = [] - self._pipe = pipe - self._nodes_manager = self._pipe.nodes_manager - - @property - def command_queue(self): - return self._command_queue - - @command_queue.setter - def command_queue(self, queue: List[PipelineCommand]): - self._command_queue = queue - - @abstractmethod - def execute_command(self, *args, **kwargs): - pass - - def pipeline_execute_command(self, *args, **options): - self._command_queue.append( - PipelineCommand(args, options, len(self._command_queue)) - ) - return self._pipe - - @abstractmethod - def execute(self, raise_on_error: bool = True) -> List[Any]: - pass - - @abstractmethod - def send_cluster_commands( - self, stack, raise_on_error=True, allow_redirections=True - ): - pass - - @abstractmethod - def reset(self): - pass - - def exists(self, *keys): - return self.execute_command("EXISTS", *keys) - - def eval(self): - """ """ - raise RedisClusterException("method eval() is not implemented") - - def load_scripts(self): - """ """ - raise RedisClusterException("method load_scripts() is not implemented") - - def script_load_for_pipeline(self, *args, **kwargs): - """ """ - raise RedisClusterException( - "method script_load_for_pipeline() is not implemented" - ) - - def annotate_exception(self, exception, number, command): - """ - Provides extra context to the exception prior to it being handled - """ - cmd = " ".join(map(safe_str, command)) - msg = ( - f"Command # {number} ({truncate_text(cmd)}) of pipeline " - f"caused error: {exception.args[0]}" - ) - exception.args = (msg,) + exception.args[1:] - - -class PipelineStrategy(AbstractStrategy): - def __init__(self, pipe: ClusterPipeline): - super().__init__(pipe) - self.command_flags = pipe.command_flags - - def execute_command(self, *args, **kwargs): - return self.pipeline_execute_command(*args, **kwargs) - - def _raise_first_error(self, stack): - """ - Raise the first exception on the stack - """ - for c in stack: - r = c.result - if isinstance(r, Exception): - self.annotate_exception(r, c.position + 1, c.args) - raise r - - def execute(self, raise_on_error: bool = True) -> List[Any]: - stack = self._command_queue - if not stack: - return [] - - try: - return self.send_cluster_commands(stack, raise_on_error) - finally: - self.reset() - - def reset(self): - """ - Reset back to empty pipeline. - """ - self._command_queue = [] - - def send_cluster_commands( - self, stack, raise_on_error=True, allow_redirections=True - ): - """ - Wrapper for RedisCluster.ERRORS_ALLOW_RETRY errors handling. 
- - If one of the retryable exceptions has been thrown we assume that: - - connection_pool was disconnected - - connection_pool was reseted - - refereh_table_asap set to True - - It will try the number of times specified by - the retries in config option "self.retry" - which defaults to 3 unless manually configured. - - If it reaches the number of times, the command will - raises ClusterDownException. - """ - if not stack: - return [] - retry_attempts = self._pipe.retry.get_retries() - while True: - try: - return self._send_cluster_commands( - stack, - raise_on_error=raise_on_error, - allow_redirections=allow_redirections, - ) - except RedisCluster.ERRORS_ALLOW_RETRY as e: - if retry_attempts > 0: - # Try again with the new cluster setup. All other errors - # should be raised. - retry_attempts -= 1 - pass - else: - raise e - - def _send_cluster_commands( - self, stack, raise_on_error=True, allow_redirections=True - ): - """ - Send a bunch of cluster commands to the redis cluster. - - `allow_redirections` If the pipeline should follow - `ASK` & `MOVED` responses automatically. If set - to false it will raise RedisClusterException. - """ - # the first time sending the commands we send all of - # the commands that were queued up. - # if we have to run through it again, we only retry - # the commands that failed. - attempt = sorted(stack, key=lambda x: x.position) - is_default_node = False - # build a list of node objects based on node names we need to - nodes = {} - - # as we move through each command that still needs to be processed, - # we figure out the slot number that command maps to, then from - # the slot determine the node. - for c in attempt: - while True: - # refer to our internal node -> slot table that - # tells us where a given command should route to. - # (it might be possible we have a cached node that no longer - # exists in the cluster, which is why we do this in a loop) - passed_targets = c.options.pop("target_nodes", None) - if passed_targets and not self._is_nodes_flag(passed_targets): - target_nodes = self._parse_target_nodes(passed_targets) - else: - target_nodes = self._determine_nodes( - *c.args, node_flag=passed_targets - ) - if not target_nodes: - raise RedisClusterException( - f"No targets were found to execute {c.args} command on" - ) - if len(target_nodes) > 1: - raise RedisClusterException( - f"Too many targets for command {c.args}" - ) - - node = target_nodes[0] - if node == self._pipe.get_default_node(): - is_default_node = True - - # now that we know the name of the node - # ( it's just a string in the form of host:port ) - # we can build a list of commands for each node. - node_name = node.name - if node_name not in nodes: - redis_node = self._pipe.get_redis_connection(node) - try: - connection = get_connection(redis_node) - except (ConnectionError, TimeoutError): - for n in nodes.values(): - n.connection_pool.release(n.connection) - # Connection retries are being handled in the node's - # Retry object. Reinitialize the node -> slot table. - self._nodes_manager.initialize() - if is_default_node: - self._pipe.replace_default_node() - raise - nodes[node_name] = NodeCommands( - redis_node.parse_response, - redis_node.connection_pool, - connection, - ) - nodes[node_name].append(c) - break - - # send the commands in sequence. - # we write to all the open sockets for each node first, - # before reading anything - # this allows us to flush all the requests out across the - # network - # so that we can read them from different sockets as they come back. 
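The comments above describe the batching idea: commands are grouped per target node, every node's socket is written first, and only then are the replies read back. A rough sketch of that flush-then-read step, assuming objects with write()/read() like the NodeCommands helper in this module:

def flush_then_read(nodes):
    # nodes maps node name -> a NodeCommands-like object buffering that node's commands.
    node_commands = list(nodes.values())
    for n in node_commands:   # push every node's buffered commands onto the wire first
        n.write()
    for n in node_commands:   # only then drain the replies, node by node
        n.read()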
- # we dont' multiplex on the sockets as they come available, - # but that shouldn't make too much difference. - try: - node_commands = nodes.values() - for n in node_commands: - n.write() - - for n in node_commands: - n.read() - finally: - # release all of the redis connections we allocated earlier - # back into the connection pool. - # we used to do this step as part of a try/finally block, - # but it is really dangerous to - # release connections back into the pool if for some - # reason the socket has data still left in it - # from a previous operation. The write and - # read operations already have try/catch around them for - # all known types of errors including connection - # and socket level errors. - # So if we hit an exception, something really bad - # happened and putting any oF - # these connections back into the pool is a very bad idea. - # the socket might have unread buffer still sitting in it, - # and then the next time we read from it we pass the - # buffered result back from a previous command and - # every single request after to that connection will always get - # a mismatched result. - for n in nodes.values(): - n.connection_pool.release(n.connection) - - # if the response isn't an exception it is a - # valid response from the node - # we're all done with that command, YAY! - # if we have more commands to attempt, we've run into problems. - # collect all the commands we are allowed to retry. - # (MOVED, ASK, or connection errors or timeout errors) - attempt = sorted( - ( - c - for c in attempt - if isinstance(c.result, ClusterPipeline.ERRORS_ALLOW_RETRY) - ), - key=lambda x: x.position, - ) - if attempt and allow_redirections: - # RETRY MAGIC HAPPENS HERE! - # send these remaining commands one at a time using `execute_command` - # in the main client. This keeps our retry logic - # in one place mostly, - # and allows us to be more confident in correctness of behavior. - # at this point any speed gains from pipelining have been lost - # anyway, so we might as well make the best - # attempt to get the correct behavior. - # - # The client command will handle retries for each - # individual command sequentially as we pass each - # one into `execute_command`. Any exceptions - # that bubble out should only appear once all - # retries have been exhausted. - # - # If a lot of commands have failed, we'll be setting the - # flag to rebuild the slots table from scratch. - # So MOVED errors should correct themselves fairly quickly. - self._pipe.reinitialize_counter += 1 - if self._pipe._should_reinitialized(): - self._nodes_manager.initialize() - if is_default_node: - self._pipe.replace_default_node() - for c in attempt: - try: - # send each command individually like we - # do in the main client. - c.result = self._pipe.parent_execute_command(*c.args, **c.options) - except RedisError as e: - c.result = e - - # turn the response back into a simple flat array that corresponds - # to the sequence of commands issued in the stack in pipeline.execute() - response = [] - for c in sorted(stack, key=lambda x: x.position): - if c.args[0] in self._pipe.cluster_response_callbacks: - # Remove keys entry, it needs only for cache. 
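A condensed sketch of the fallback described above: commands whose result is a retryable error are replayed one at a time through the single-command path so ordinary retry and redirection handling applies (execute_one and RETRYABLE are placeholders for the client's execute_command and its MOVED/ASK/connection error classes):

def replay_failures(attempted, execute_one, RETRYABLE):
    for cmd in attempted:
        if isinstance(cmd.result, RETRYABLE):
            try:
                cmd.result = execute_one(*cmd.args, **cmd.options)
            except Exception as exc:   # keep the error in place of a result
                cmd.result = exc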
- c.options.pop("keys", None) - c.result = self._pipe.cluster_response_callbacks[c.args[0]]( - c.result, **c.options - ) - response.append(c.result) - - if raise_on_error: - self._raise_first_error(stack) - - return response - - def _is_nodes_flag(self, target_nodes): - return isinstance(target_nodes, str) and target_nodes in self._pipe.node_flags - - def _parse_target_nodes(self, target_nodes): - if isinstance(target_nodes, list): - nodes = target_nodes - elif isinstance(target_nodes, ClusterNode): - # Supports passing a single ClusterNode as a variable - nodes = [target_nodes] - elif isinstance(target_nodes, dict): - # Supports dictionaries of the format {node_name: node}. - # It enables to execute commands with multi nodes as follows: - # rc.cluster_save_config(rc.get_primaries()) - nodes = target_nodes.values() - else: - raise TypeError( - "target_nodes type can be one of the following: " - "node_flag (PRIMARIES, REPLICAS, RANDOM, ALL_NODES)," - "ClusterNode, list, or dict. " - f"The passed type is {type(target_nodes)}" - ) - return nodes - - def _determine_nodes(self, *args, **kwargs) -> List["ClusterNode"]: - # Determine which nodes should be executed the command on. - # Returns a list of target nodes. - command = args[0].upper() - if ( - len(args) >= 2 - and f"{args[0]} {args[1]}".upper() in self._pipe.command_flags - ): - command = f"{args[0]} {args[1]}".upper() - - nodes_flag = kwargs.pop("nodes_flag", None) - if nodes_flag is not None: - # nodes flag passed by the user - command_flag = nodes_flag - else: - # get the nodes group for this command if it was predefined - command_flag = self._pipe.command_flags.get(command) - if command_flag == self._pipe.RANDOM: - # return a random node - return [self._pipe.get_random_node()] - elif command_flag == self._pipe.PRIMARIES: - # return all primaries - return self._pipe.get_primaries() - elif command_flag == self._pipe.REPLICAS: - # return all replicas - return self._pipe.get_replicas() - elif command_flag == self._pipe.ALL_NODES: - # return all nodes - return self._pipe.get_nodes() - elif command_flag == self._pipe.DEFAULT_NODE: - # return the cluster's default node - return [self._nodes_manager.default_node] - elif command in self._pipe.SEARCH_COMMANDS[0]: - return [self._nodes_manager.default_node] - else: - # get the node that holds the key's slot - slot = self._pipe.determine_slot(*args) - node = self._nodes_manager.get_node_from_slot( - slot, - self._pipe.read_from_replicas and command in READ_COMMANDS, - self._pipe.load_balancing_strategy - if command in READ_COMMANDS - else None, - ) - return [node] - - def multi(self): - raise RedisClusterException( - "method multi() is not supported outside of transactional context" - ) - - def discard(self): - raise RedisClusterException( - "method discard() is not supported outside of transactional context" - ) - - def watch(self, *names): - raise RedisClusterException( - "method watch() is not supported outside of transactional context" - ) - - def unwatch(self, *names): - raise RedisClusterException( - "method unwatch() is not supported outside of transactional context" - ) - - def delete(self, *names): - if len(names) != 1: - raise RedisClusterException( - "deleting multiple keys is not implemented in pipeline command" - ) - - return self.execute_command("DEL", names[0]) - - def unlink(self, *names): - if len(names) != 1: - raise RedisClusterException( - "unlinking multiple keys is not implemented in pipeline command" - ) - - return self.execute_command("UNLINK", names[0]) - - -class 
TransactionStrategy(AbstractStrategy): - NO_SLOTS_COMMANDS = {"UNWATCH"} - IMMEDIATE_EXECUTE_COMMANDS = {"WATCH", "UNWATCH"} - UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} - SLOT_REDIRECT_ERRORS = (AskError, MovedError) - CONNECTION_ERRORS = ( - ConnectionError, - OSError, - ClusterDownError, - SlotNotCoveredError, - ) - - def __init__(self, pipe: ClusterPipeline): - super().__init__(pipe) - self._explicit_transaction = False - self._watching = False - self._pipeline_slots: Set[int] = set() - self._transaction_connection: Optional[Connection] = None - self._executing = False - self._retry = copy(self._pipe.retry) - self._retry.update_supported_errors( - RedisCluster.ERRORS_ALLOW_RETRY + self.SLOT_REDIRECT_ERRORS - ) - - def _get_client_and_connection_for_transaction(self) -> Tuple[Redis, Connection]: - """ - Find a connection for a pipeline transaction. - - For running an atomic transaction, watch keys ensure that contents have not been - altered as long as the watch commands for those keys were sent over the same - connection. So once we start watching a key, we fetch a connection to the - node that owns that slot and reuse it. - """ - if not self._pipeline_slots: - raise RedisClusterException( - "At least a command with a key is needed to identify a node" - ) - - node: ClusterNode = self._nodes_manager.get_node_from_slot( - list(self._pipeline_slots)[0], False - ) - redis_node: Redis = self._pipe.get_redis_connection(node) - if self._transaction_connection: - if not redis_node.connection_pool.owns_connection( - self._transaction_connection - ): - previous_node = self._nodes_manager.find_connection_owner( - self._transaction_connection - ) - previous_node.connection_pool.release(self._transaction_connection) - self._transaction_connection = None - - if not self._transaction_connection: - self._transaction_connection = get_connection(redis_node) - - return redis_node, self._transaction_connection - - def execute_command(self, *args, **kwargs): - slot_number: Optional[int] = None - if args[0] not in ClusterPipeline.NO_SLOTS_COMMANDS: - slot_number = self._pipe.determine_slot(*args) - - if ( - self._watching or args[0] in self.IMMEDIATE_EXECUTE_COMMANDS - ) and not self._explicit_transaction: - if args[0] == "WATCH": - self._validate_watch() - - if slot_number is not None: - if self._pipeline_slots and slot_number not in self._pipeline_slots: - raise CrossSlotTransactionError( - "Cannot watch or send commands on different slots" - ) - - self._pipeline_slots.add(slot_number) - elif args[0] not in self.NO_SLOTS_COMMANDS: - raise RedisClusterException( - f"Cannot identify slot number for command: {args[0]}," - "it cannot be triggered in a transaction" - ) - - return self._immediate_execute_command(*args, **kwargs) - else: - if slot_number is not None: - self._pipeline_slots.add(slot_number) - - return self.pipeline_execute_command(*args, **kwargs) - - def _validate_watch(self): - if self._explicit_transaction: - raise RedisError("Cannot issue a WATCH after a MULTI") - - self._watching = True - - def _immediate_execute_command(self, *args, **options): - return self._retry.call_with_retry( - lambda: self._get_connection_and_send_command(*args, **options), - self._reinitialize_on_error, - ) - - def _get_connection_and_send_command(self, *args, **options): - redis_node, connection = self._get_client_and_connection_for_transaction() - return self._send_command_parse_response( - connection, redis_node, args[0], *args, **options - ) - - def _send_command_parse_response( - self, conn, redis_node: 
Redis, command_name, *args, **options - ): - """ - Send a command and parse the response - """ - - conn.send_command(*args) - output = redis_node.parse_response(conn, command_name, **options) - - if command_name in self.UNWATCH_COMMANDS: - self._watching = False - return output - - def _reinitialize_on_error(self, error): - if self._watching: - if type(error) in self.SLOT_REDIRECT_ERRORS and self._executing: - raise WatchError("Slot rebalancing occurred while watching keys") - - if ( - type(error) in self.SLOT_REDIRECT_ERRORS - or type(error) in self.CONNECTION_ERRORS - ): - if self._transaction_connection: - self._transaction_connection = None - - self._pipe.reinitialize_counter += 1 - if self._pipe._should_reinitialized(): - self._nodes_manager.initialize() - self.reinitialize_counter = 0 - else: - self._nodes_manager.update_moved_exception(error) - - self._executing = False - - def _raise_first_error(self, responses, stack): - """ - Raise the first exception on the stack - """ - for r, cmd in zip(responses, stack): - if isinstance(r, Exception): - self.annotate_exception(r, cmd.position + 1, cmd.args) - raise r - - def execute(self, raise_on_error: bool = True) -> List[Any]: - stack = self._command_queue - if not stack and (not self._watching or not self._pipeline_slots): - return [] - - return self._execute_transaction_with_retries(stack, raise_on_error) - - def _execute_transaction_with_retries( - self, stack: List["PipelineCommand"], raise_on_error: bool - ): - return self._retry.call_with_retry( - lambda: self._execute_transaction(stack, raise_on_error), - self._reinitialize_on_error, - ) - - def _execute_transaction( - self, stack: List["PipelineCommand"], raise_on_error: bool - ): - if len(self._pipeline_slots) > 1: - raise CrossSlotTransactionError( - "All keys involved in a cluster transaction must map to the same slot" - ) - - self._executing = True - - redis_node, connection = self._get_client_and_connection_for_transaction() - - stack = chain( - [PipelineCommand(("MULTI",))], - stack, - [PipelineCommand(("EXEC",))], - ) - commands = [c.args for c in stack if EMPTY_RESPONSE not in c.options] - packed_commands = connection.pack_commands(commands) - connection.send_packed_command(packed_commands) - errors = [] - - # parse off the response for MULTI - # NOTE: we need to handle ResponseErrors here and continue - # so that we read all the additional command messages from - # the socket - try: - redis_node.parse_response(connection, "MULTI") - except ResponseError as e: - self.annotate_exception(e, 0, "MULTI") - errors.append(e) - except self.CONNECTION_ERRORS as cluster_error: - self.annotate_exception(cluster_error, 0, "MULTI") - raise - - # and all the other commands - for i, command in enumerate(self._command_queue): - if EMPTY_RESPONSE in command.options: - errors.append((i, command.options[EMPTY_RESPONSE])) - else: - try: - _ = redis_node.parse_response(connection, "_") - except self.SLOT_REDIRECT_ERRORS as slot_error: - self.annotate_exception(slot_error, i + 1, command.args) - errors.append(slot_error) - except self.CONNECTION_ERRORS as cluster_error: - self.annotate_exception(cluster_error, i + 1, command.args) - raise - except ResponseError as e: - self.annotate_exception(e, i + 1, command.args) - errors.append(e) - - response = None - # parse the EXEC. 
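For orientation, the transaction path above packs MULTI, every queued command, and EXEC into a single round trip on one connection, reads the per-command QUEUED replies, and only then parses the EXEC reply. A stripped-down sketch of that flow, with error handling omitted (parse stands in for redis_node.parse_response):

def run_transaction(connection, parse, queued_args):
    # One round trip: MULTI, every queued command, then EXEC.
    commands = [("MULTI",)] + list(queued_args) + [("EXEC",)]
    connection.send_packed_command(connection.pack_commands(commands))

    parse(connection, "MULTI")        # +OK
    for _ in queued_args:             # one +QUEUED (or error) per queued command
        parse(connection, "_")
    return parse(connection, "EXEC")  # list of results, or None if a WATCH failed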
- try: - response = redis_node.parse_response(connection, "EXEC") - except ExecAbortError: - if errors: - raise errors[0] - raise - - self._executing = False - - # EXEC clears any watched keys - self._watching = False - - if response is None: - raise WatchError("Watched variable changed.") - - # put any parse errors into the response - for i, e in errors: - response.insert(i, e) - - if len(response) != len(self._command_queue): - raise InvalidPipelineStack( - "Unexpected response length for cluster pipeline EXEC." - " Command stack was {} but response had length {}".format( - [c.args[0] for c in self._command_queue], len(response) - ) - ) - - # find any errors in the response and raise if necessary - if raise_on_error or len(errors) > 0: - self._raise_first_error( - response, - self._command_queue, - ) - - # We have to run response callbacks manually - data = [] - for r, cmd in zip(response, self._command_queue): - if not isinstance(r, Exception): - command_name = cmd.args[0] - if command_name in self._pipe.cluster_response_callbacks: - r = self._pipe.cluster_response_callbacks[command_name]( - r, **cmd.options - ) - data.append(r) - return data - - def reset(self): - self._command_queue = [] - - # make sure to reset the connection state in the event that we were - # watching something - if self._transaction_connection: - try: - if self._watching: - # call this manually since our unwatch or - # immediate_execute_command methods can call reset() - self._transaction_connection.send_command("UNWATCH") - self._transaction_connection.read_response() - # we can safely return the connection to the pool here since we're - # sure we're no longer WATCHing anything - node = self._nodes_manager.find_connection_owner( - self._transaction_connection - ) - node.redis_connection.connection_pool.release( - self._transaction_connection - ) - self._transaction_connection = None - except self.CONNECTION_ERRORS: - # disconnect will also remove any previous WATCHes - if self._transaction_connection: - self._transaction_connection.disconnect() - - # clean up the other instance attributes - self._watching = False - self._explicit_transaction = False - self._pipeline_slots = set() - self._executing = False - - def send_cluster_commands( - self, stack, raise_on_error=True, allow_redirections=True - ): - raise NotImplementedError( - "send_cluster_commands cannot be executed in transactional context." 
- ) - - def multi(self): - if self._explicit_transaction: - raise RedisError("Cannot issue nested calls to MULTI") - if self._command_queue: - raise RedisError( - "Commands without an initial WATCH have already been issued" - ) - self._explicit_transaction = True - - def watch(self, *names): - if self._explicit_transaction: - raise RedisError("Cannot issue a WATCH after a MULTI") - - return self.execute_command("WATCH", *names) - - def unwatch(self): - if self._watching: - return self.execute_command("UNWATCH") - - return True - - def discard(self): - self.reset() - - def delete(self, *names): - return self.execute_command("DEL", *names) - - def unlink(self, *names): - return self.execute_command("UNLINK", *names) diff --git a/venv/lib/python3.12/site-packages/redis/commands/bf/__init__.py b/venv/lib/python3.12/site-packages/redis/commands/bf/__init__.py index 29c5c18..959358f 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/bf/__init__.py +++ b/venv/lib/python3.12/site-packages/redis/commands/bf/__init__.py @@ -5,7 +5,7 @@ from .commands import * # noqa from .info import BFInfo, CFInfo, CMSInfo, TDigestInfo, TopKInfo -class AbstractBloom: +class AbstractBloom(object): """ The client allows to interact with RedisBloom and use all of it's functionality. diff --git a/venv/lib/python3.12/site-packages/redis/commands/bf/commands.py b/venv/lib/python3.12/site-packages/redis/commands/bf/commands.py index 0a88505..447f844 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/bf/commands.py +++ b/venv/lib/python3.12/site-packages/redis/commands/bf/commands.py @@ -1,5 +1,6 @@ from redis.client import NEVER_DECODE -from redis.utils import deprecated_function +from redis.exceptions import ModuleError +from redis.utils import HIREDIS_AVAILABLE, deprecated_function BF_RESERVE = "BF.RESERVE" BF_ADD = "BF.ADD" @@ -138,6 +139,9 @@ class BFCommands: This command will return successive (iter, data) pairs until (0, NULL) to indicate completion. For more information see `BF.SCANDUMP `_. 
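As a usage note for the SCANDUMP docstring above, the usual pattern is to loop until the returned iterator is 0 and feed each chunk back with LOADCHUNK on the target key. A small sketch (r is an assumed client exposing the bf() module commands):

def copy_bloom_filter(r, src, dst):
    # Dump src chunk by chunk and restore it into dst; iter == 0 marks completion.
    it = 0
    while True:
        it, data = r.bf().scandump(src, it)
        if it == 0:
            break
        r.bf().loadchunk(dst, it, data)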
""" # noqa + if HIREDIS_AVAILABLE: + raise ModuleError("This command cannot be used when hiredis is available.") + params = [key, iter] options = {} options[NEVER_DECODE] = [] diff --git a/venv/lib/python3.12/site-packages/redis/commands/bf/info.py b/venv/lib/python3.12/site-packages/redis/commands/bf/info.py index 1a876c1..e1f0208 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/bf/info.py +++ b/venv/lib/python3.12/site-packages/redis/commands/bf/info.py @@ -1,7 +1,7 @@ from ..helpers import nativestr -class BFInfo: +class BFInfo(object): capacity = None size = None filterNum = None @@ -26,7 +26,7 @@ class BFInfo: return getattr(self, item) -class CFInfo: +class CFInfo(object): size = None bucketNum = None filterNum = None @@ -57,7 +57,7 @@ class CFInfo: return getattr(self, item) -class CMSInfo: +class CMSInfo(object): width = None depth = None count = None @@ -72,7 +72,7 @@ class CMSInfo: return getattr(self, item) -class TopKInfo: +class TopKInfo(object): k = None width = None depth = None @@ -89,7 +89,7 @@ class TopKInfo: return getattr(self, item) -class TDigestInfo: +class TDigestInfo(object): compression = None capacity = None merged_nodes = None diff --git a/venv/lib/python3.12/site-packages/redis/commands/cluster.py b/venv/lib/python3.12/site-packages/redis/commands/cluster.py index 13f2035..14b8741 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/cluster.py +++ b/venv/lib/python3.12/site-packages/redis/commands/cluster.py @@ -7,13 +7,13 @@ from typing import ( Iterable, Iterator, List, - Literal, Mapping, NoReturn, Optional, Union, ) +from redis.compat import Literal from redis.crc import key_slot from redis.exceptions import RedisClusterException, RedisError from redis.typing import ( @@ -23,7 +23,6 @@ from redis.typing import ( KeysT, KeyT, PatternT, - ResponseT, ) from .core import ( @@ -31,18 +30,21 @@ from .core import ( AsyncACLCommands, AsyncDataAccessCommands, AsyncFunctionCommands, + AsyncGearsCommands, AsyncManagementCommands, AsyncModuleCommands, AsyncScriptCommands, DataAccessCommands, FunctionCommands, + GearsCommands, ManagementCommands, ModuleCommands, PubSubCommands, + ResponseT, ScriptCommands, ) from .helpers import list_or_args -from .redismodules import AsyncRedisModuleCommands, RedisModuleCommands +from .redismodules import RedisModuleCommands if TYPE_CHECKING: from redis.asyncio.cluster import TargetNodesT @@ -223,7 +225,7 @@ class ClusterMultiKeyCommands(ClusterCommandsProtocol): The keys are first split up into slots and then an DEL command is sent for every slot - Non-existent keys are ignored. + Non-existant keys are ignored. Returns the number of keys that were deleted. For more information see https://redis.io/commands/del @@ -238,7 +240,7 @@ class ClusterMultiKeyCommands(ClusterCommandsProtocol): The keys are first split up into slots and then an TOUCH command is sent for every slot - Non-existent keys are ignored. + Non-existant keys are ignored. Returns the number of keys that were touched. For more information see https://redis.io/commands/touch @@ -252,7 +254,7 @@ class ClusterMultiKeyCommands(ClusterCommandsProtocol): The keys are first split up into slots and then an TOUCH command is sent for every slot - Non-existent keys are ignored. + Non-existant keys are ignored. Returns the number of keys that were unlinked. 
For more information see https://redis.io/commands/unlink @@ -593,7 +595,7 @@ class ClusterManagementCommands(ManagementCommands): "CLUSTER SETSLOT", slot_id, state, node_id, target_nodes=target_node ) elif state.upper() == "STABLE": - raise RedisError('For "stable" state please use cluster_setslot_stable') + raise RedisError('For "stable" state please use ' "cluster_setslot_stable") else: raise RedisError(f"Invalid slot state: {state}") @@ -691,6 +693,12 @@ class ClusterManagementCommands(ManagementCommands): self.read_from_replicas = False return self.execute_command("READWRITE", target_nodes=target_nodes) + def gears_refresh_cluster(self, **kwargs) -> ResponseT: + """ + On an OSS cluster, before executing any gears function, you must call this command. # noqa + """ + return self.execute_command("REDISGEARS_2.REFRESHCLUSTER", **kwargs) + class AsyncClusterManagementCommands( ClusterManagementCommands, AsyncManagementCommands @@ -866,6 +874,7 @@ class RedisClusterCommands( ClusterDataAccessCommands, ScriptCommands, FunctionCommands, + GearsCommands, ModuleCommands, RedisModuleCommands, ): @@ -896,8 +905,8 @@ class AsyncRedisClusterCommands( AsyncClusterDataAccessCommands, AsyncScriptCommands, AsyncFunctionCommands, + AsyncGearsCommands, AsyncModuleCommands, - AsyncRedisModuleCommands, ): """ A class for all Redis Cluster commands diff --git a/venv/lib/python3.12/site-packages/redis/commands/core.py b/venv/lib/python3.12/site-packages/redis/commands/core.py index d6fb550..e73553e 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/core.py +++ b/venv/lib/python3.12/site-packages/redis/commands/core.py @@ -3,7 +3,6 @@ import datetime import hashlib import warnings -from enum import Enum from typing import ( TYPE_CHECKING, Any, @@ -14,7 +13,6 @@ from typing import ( Iterable, Iterator, List, - Literal, Mapping, Optional, Sequence, @@ -23,6 +21,7 @@ from typing import ( Union, ) +from redis.compat import Literal from redis.exceptions import ConnectionError, DataError, NoScriptError, RedisError from redis.typing import ( AbsExpiryT, @@ -37,24 +36,20 @@ from redis.typing import ( GroupT, KeysT, KeyT, - Number, PatternT, - ResponseT, ScriptTextT, StreamIdT, TimeoutSecT, ZScoreBoundT, ) -from redis.utils import ( - deprecated_function, - extract_expire_flags, -) from .helpers import list_or_args if TYPE_CHECKING: - import redis.asyncio.client - import redis.client + from redis.asyncio.client import Redis as AsyncRedis + from redis.client import Redis + +ResponseT = Union[Awaitable, Any] class ACLCommands(CommandsProtocol): @@ -63,7 +58,7 @@ class ACLCommands(CommandsProtocol): see: https://redis.io/topics/acl """ - def acl_cat(self, category: Optional[str] = None, **kwargs) -> ResponseT: + def acl_cat(self, category: Union[str, None] = None, **kwargs) -> ResponseT: """ Returns a list of categories or commands within a category. @@ -86,13 +81,13 @@ class ACLCommands(CommandsProtocol): def acl_deluser(self, *username: str, **kwargs) -> ResponseT: """ - Delete the ACL for the specified ``username``\\s + Delete the ACL for the specified ``username``s For more information see https://redis.io/commands/acl-deluser """ return self.execute_command("ACL DELUSER", *username, **kwargs) - def acl_genpass(self, bits: Optional[int] = None, **kwargs) -> ResponseT: + def acl_genpass(self, bits: Union[int, None] = None, **kwargs) -> ResponseT: """Generate a random password value. If ``bits`` is supplied then use this number of bits, rounded to the next multiple of 4. 
@@ -104,7 +99,6 @@ class ACLCommands(CommandsProtocol): b = int(bits) if b < 0 or b > 4096: raise ValueError - pieces.append(b) except ValueError: raise DataError( "genpass optionally accepts a bits argument, between 0 and 4096." @@ -137,7 +131,7 @@ class ACLCommands(CommandsProtocol): """ return self.execute_command("ACL LIST", **kwargs) - def acl_log(self, count: Optional[int] = None, **kwargs) -> ResponseT: + def acl_log(self, count: Union[int, None] = None, **kwargs) -> ResponseT: """ Get ACL logs as a list. :param int count: Get logs[0:count]. @@ -190,8 +184,8 @@ class ACLCommands(CommandsProtocol): username: str, enabled: bool = False, nopass: bool = False, - passwords: Optional[Union[str, Iterable[str]]] = None, - hashed_passwords: Optional[Union[str, Iterable[str]]] = None, + passwords: Union[str, Iterable[str], None] = None, + hashed_passwords: Union[str, Iterable[str], None] = None, categories: Optional[Iterable[str]] = None, commands: Optional[Iterable[str]] = None, keys: Optional[Iterable[KeyT]] = None, @@ -206,59 +200,69 @@ class ACLCommands(CommandsProtocol): """ Create or update an ACL user. - Create or update the ACL for `username`. If the user already exists, + Create or update the ACL for ``username``. If the user already exists, the existing ACL is completely overwritten and replaced with the specified values. - For more information, see https://redis.io/commands/acl-setuser + ``enabled`` is a boolean indicating whether the user should be allowed + to authenticate or not. Defaults to ``False``. - Args: - username: The name of the user whose ACL is to be created or updated. - enabled: Indicates whether the user should be allowed to authenticate. - Defaults to `False`. - nopass: Indicates whether the user can authenticate without a password. - This cannot be `True` if `passwords` are also specified. - passwords: A list of plain text passwords to add to or remove from the user. - Each password must be prefixed with a '+' to add or a '-' to - remove. For convenience, a single prefixed string can be used - when adding or removing a single password. - hashed_passwords: A list of SHA-256 hashed passwords to add to or remove - from the user. Each hashed password must be prefixed with - a '+' to add or a '-' to remove. For convenience, a single - prefixed string can be used when adding or removing a - single password. - categories: A list of strings representing category permissions. Each string - must be prefixed with either a '+' to add the category - permission or a '-' to remove the category permission. - commands: A list of strings representing command permissions. Each string - must be prefixed with either a '+' to add the command permission - or a '-' to remove the command permission. - keys: A list of key patterns to grant the user access to. Key patterns allow - ``'*'`` to support wildcard matching. For example, ``'*'`` grants - access to all keys while ``'cache:*'`` grants access to all keys that - are prefixed with ``cache:``. - `keys` should not be prefixed with a ``'~'``. - reset: Indicates whether the user should be fully reset prior to applying - the new ACL. Setting this to `True` will remove all existing - passwords, flags, and privileges from the user and then apply the - specified rules. If `False`, the user's existing passwords, flags, - and privileges will be kept and any new specified rules will be - applied on top. - reset_keys: Indicates whether the user's key permissions should be reset - prior to applying any new key permissions specified in `keys`. 
- If `False`, the user's existing key permissions will be kept and - any new specified key permissions will be applied on top. - reset_channels: Indicates whether the user's channel permissions should be - reset prior to applying any new channel permissions - specified in `channels`. If `False`, the user's existing - channel permissions will be kept and any new specified - channel permissions will be applied on top. - reset_passwords: Indicates whether to remove all existing passwords and the - `nopass` flag from the user prior to applying any new - passwords specified in `passwords` or `hashed_passwords`. - If `False`, the user's existing passwords and `nopass` - status will be kept and any new specified passwords or - hashed passwords will be applied on top. + ``nopass`` is a boolean indicating whether the can authenticate without + a password. This cannot be True if ``passwords`` are also specified. + + ``passwords`` if specified is a list of plain text passwords + to add to or remove from the user. Each password must be prefixed with + a '+' to add or a '-' to remove. For convenience, the value of + ``passwords`` can be a simple prefixed string when adding or + removing a single password. + + ``hashed_passwords`` if specified is a list of SHA-256 hashed passwords + to add to or remove from the user. Each hashed password must be + prefixed with a '+' to add or a '-' to remove. For convenience, + the value of ``hashed_passwords`` can be a simple prefixed string when + adding or removing a single password. + + ``categories`` if specified is a list of strings representing category + permissions. Each string must be prefixed with either a '+' to add the + category permission or a '-' to remove the category permission. + + ``commands`` if specified is a list of strings representing command + permissions. Each string must be prefixed with either a '+' to add the + command permission or a '-' to remove the command permission. + + ``keys`` if specified is a list of key patterns to grant the user + access to. Keys patterns allow '*' to support wildcard matching. For + example, '*' grants access to all keys while 'cache:*' grants access + to all keys that are prefixed with 'cache:'. ``keys`` should not be + prefixed with a '~'. + + ``reset`` is a boolean indicating whether the user should be fully + reset prior to applying the new ACL. Setting this to True will + remove all existing passwords, flags and privileges from the user and + then apply the specified rules. If this is False, the user's existing + passwords, flags and privileges will be kept and any new specified + rules will be applied on top. + + ``reset_keys`` is a boolean indicating whether the user's key + permissions should be reset prior to applying any new key permissions + specified in ``keys``. If this is False, the user's existing + key permissions will be kept and any new specified key permissions + will be applied on top. + + ``reset_channels`` is a boolean indicating whether the user's channel + permissions should be reset prior to applying any new channel permissions + specified in ``channels``.If this is False, the user's existing + channel permissions will be kept and any new specified channel permissions + will be applied on top. + + ``reset_passwords`` is a boolean indicating whether to remove all + existing passwords and the 'nopass' flag from the user prior to + applying any new passwords specified in 'passwords' or + 'hashed_passwords'. 
If this is False, the user's existing passwords + and 'nopass' status will be kept and any new specified passwords + or hashed_passwords will be applied on top. + + For more information see https://redis.io/commands/acl-setuser """ encoder = self.get_encoder() pieces: List[EncodableT] = [username] @@ -450,13 +454,12 @@ class ManagementCommands(CommandsProtocol): def client_kill_filter( self, - _id: Optional[str] = None, - _type: Optional[str] = None, - addr: Optional[str] = None, - skipme: Optional[bool] = None, - laddr: Optional[bool] = None, - user: Optional[str] = None, - maxage: Optional[int] = None, + _id: Union[str, None] = None, + _type: Union[str, None] = None, + addr: Union[str, None] = None, + skipme: Union[bool, None] = None, + laddr: Union[bool, None] = None, + user: str = None, **kwargs, ) -> ResponseT: """ @@ -470,7 +473,6 @@ class ManagementCommands(CommandsProtocol): options. If skipme is not provided, the server defaults to skipme=True :param laddr: Kills a client by its 'local (bind) address:port' :param user: Kills a client for a specific user name - :param maxage: Kills clients that are older than the specified age in seconds """ args = [] if _type is not None: @@ -493,8 +495,6 @@ class ManagementCommands(CommandsProtocol): args.extend((b"LADDR", laddr)) if user is not None: args.extend((b"USER", user)) - if maxage is not None: - args.extend((b"MAXAGE", maxage)) if not args: raise DataError( "CLIENT KILL ... ... " @@ -512,7 +512,7 @@ class ManagementCommands(CommandsProtocol): return self.execute_command("CLIENT INFO", **kwargs) def client_list( - self, _type: Optional[str] = None, client_id: List[EncodableT] = [], **kwargs + self, _type: Union[str, None] = None, client_id: List[EncodableT] = [], **kwargs ) -> ResponseT: """ Returns a list of currently connected clients. @@ -535,7 +535,7 @@ class ManagementCommands(CommandsProtocol): raise DataError("client_id must be a list") if client_id: args.append(b"ID") - args += client_id + args.append(" ".join(client_id)) return self.execute_command("CLIENT LIST", *args, **kwargs) def client_getname(self, **kwargs) -> ResponseT: @@ -589,7 +589,7 @@ class ManagementCommands(CommandsProtocol): def client_tracking_on( self, - clientid: Optional[int] = None, + clientid: Union[int, None] = None, prefix: Sequence[KeyT] = [], bcast: bool = False, optin: bool = False, @@ -608,7 +608,7 @@ class ManagementCommands(CommandsProtocol): def client_tracking_off( self, - clientid: Optional[int] = None, + clientid: Union[int, None] = None, prefix: Sequence[KeyT] = [], bcast: bool = False, optin: bool = False, @@ -628,7 +628,7 @@ class ManagementCommands(CommandsProtocol): def client_tracking( self, on: bool = True, - clientid: Optional[int] = None, + clientid: Union[int, None] = None, prefix: Sequence[KeyT] = [], bcast: bool = False, optin: bool = False, @@ -738,19 +738,16 @@ class ManagementCommands(CommandsProtocol): For more information see https://redis.io/commands/client-pause - Args: - timeout: milliseconds to pause clients - all: If true (default) all client commands are blocked. - otherwise, clients are only blocked if they attempt to execute - a write command. - + :param timeout: milliseconds to pause clients + :param all: If true (default) all client commands are blocked. + otherwise, clients are only blocked if they attempt to execute + a write command. For the WRITE mode, some commands have special behavior: - - * EVAL/EVALSHA: Will block client for all scripts. - * PUBLISH: Will block client. - * PFCOUNT: Will block client. 
- * WAIT: Acknowledgments will be delayed, so this command will - appear blocked. + EVAL/EVALSHA: Will block client for all scripts. + PUBLISH: Will block client. + PFCOUNT: Will block client. + WAIT: Acknowledgments will be delayed, so this command will + appear blocked. """ args = ["CLIENT PAUSE", str(timeout)] if not isinstance(timeout, int): @@ -988,7 +985,7 @@ class ManagementCommands(CommandsProtocol): return self.execute_command("SELECT", index, **kwargs) def info( - self, section: Optional[str] = None, *args: List[str], **kwargs + self, section: Union[str, None] = None, *args: List[str], **kwargs ) -> ResponseT: """ Returns a dictionary containing information about the Redis server @@ -1070,7 +1067,7 @@ class ManagementCommands(CommandsProtocol): timeout: int, copy: bool = False, replace: bool = False, - auth: Optional[str] = None, + auth: Union[str, None] = None, **kwargs, ) -> ResponseT: """ @@ -1152,7 +1149,7 @@ class ManagementCommands(CommandsProtocol): return self.execute_command("MEMORY MALLOC-STATS", **kwargs) def memory_usage( - self, key: KeyT, samples: Optional[int] = None, **kwargs + self, key: KeyT, samples: Union[int, None] = None, **kwargs ) -> ResponseT: """ Return the total memory usage for key, its value and associated @@ -1291,7 +1288,7 @@ class ManagementCommands(CommandsProtocol): raise RedisError("SHUTDOWN seems to have failed.") def slaveof( - self, host: Optional[str] = None, port: Optional[int] = None, **kwargs + self, host: Union[str, None] = None, port: Union[int, None] = None, **kwargs ) -> ResponseT: """ Set the server to be a replicated slave of the instance identified @@ -1304,7 +1301,7 @@ class ManagementCommands(CommandsProtocol): return self.execute_command("SLAVEOF", b"NO", b"ONE", **kwargs) return self.execute_command("SLAVEOF", host, port, **kwargs) - def slowlog_get(self, num: Optional[int] = None, **kwargs) -> ResponseT: + def slowlog_get(self, num: Union[int, None] = None, **kwargs) -> ResponseT: """ Get the entries from the slowlog. If ``num`` is specified, get the most recent ``num`` items. @@ -1391,6 +1388,9 @@ class ManagementCommands(CommandsProtocol): ) +AsyncManagementCommands = ManagementCommands + + class AsyncManagementCommands(ManagementCommands): async def command_info(self, **kwargs) -> None: return super().command_info(**kwargs) @@ -1449,9 +1449,9 @@ class BitFieldOperation: def __init__( self, - client: Union["redis.client.Redis", "redis.asyncio.client.Redis"], + client: Union["Redis", "AsyncRedis"], key: str, - default_overflow: Optional[str] = None, + default_overflow: Union[str, None] = None, ): self.client = client self.key = key @@ -1487,7 +1487,7 @@ class BitFieldOperation: fmt: str, offset: BitfieldOffsetT, increment: int, - overflow: Optional[str] = None, + overflow: Union[str, None] = None, ): """ Increment a bitfield by a given amount. 
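As a brief usage illustration of the BitFieldOperation builder above: sub-commands are chained on the object returned by bitfield() and sent together by execute(). A sketch assuming a locally reachable server and an example key named 'counter':

import redis

r = redis.Redis()  # assumed local server; purely illustrative
results = (
    r.bitfield("counter", default_overflow="WRAP")
     .set("u8", 0, 200)                      # write 200 into the first byte
     .incrby("u8", 0, 100, overflow="SAT")   # saturating add caps at 255
     .get("u8", 0)
     .execute()                              # sends a single BITFIELD command
)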
@@ -1572,8 +1572,8 @@ class BasicKeyCommands(CommandsProtocol): def bitcount( self, key: KeyT, - start: Optional[int] = None, - end: Optional[int] = None, + start: Union[int, None] = None, + end: Union[int, None] = None, mode: Optional[str] = None, ) -> ResponseT: """ @@ -1590,12 +1590,12 @@ class BasicKeyCommands(CommandsProtocol): raise DataError("Both start and end must be specified") if mode is not None: params.append(mode) - return self.execute_command("BITCOUNT", *params, keys=[key]) + return self.execute_command("BITCOUNT", *params) def bitfield( - self: Union["redis.client.Redis", "redis.asyncio.client.Redis"], + self: Union["Redis", "AsyncRedis"], key: KeyT, - default_overflow: Optional[str] = None, + default_overflow: Union[str, None] = None, ) -> BitFieldOperation: """ Return a BitFieldOperation instance to conveniently construct one or @@ -1606,7 +1606,7 @@ class BasicKeyCommands(CommandsProtocol): return BitFieldOperation(self, key, default_overflow=default_overflow) def bitfield_ro( - self: Union["redis.client.Redis", "redis.asyncio.client.Redis"], + self: Union["Redis", "AsyncRedis"], key: KeyT, encoding: str, offset: BitfieldOffsetT, @@ -1626,7 +1626,7 @@ class BasicKeyCommands(CommandsProtocol): items = items or [] for encoding, offset in items: params.extend(["GET", encoding, offset]) - return self.execute_command("BITFIELD_RO", *params, keys=[key]) + return self.execute_command("BITFIELD_RO", *params) def bitop(self, operation: str, dest: KeyT, *keys: KeyT) -> ResponseT: """ @@ -1641,8 +1641,8 @@ class BasicKeyCommands(CommandsProtocol): self, key: KeyT, bit: int, - start: Optional[int] = None, - end: Optional[int] = None, + start: Union[int, None] = None, + end: Union[int, None] = None, mode: Optional[str] = None, ) -> ResponseT: """ @@ -1666,13 +1666,13 @@ class BasicKeyCommands(CommandsProtocol): if mode is not None: params.append(mode) - return self.execute_command("BITPOS", *params, keys=[key]) + return self.execute_command("BITPOS", *params) def copy( self, source: str, destination: str, - destination_db: Optional[str] = None, + destination_db: Union[str, None] = None, replace: bool = False, ) -> ResponseT: """ @@ -1733,7 +1733,7 @@ class BasicKeyCommands(CommandsProtocol): For more information see https://redis.io/commands/exists """ - return self.execute_command("EXISTS", *names, keys=names) + return self.execute_command("EXISTS", *names) __contains__ = exists @@ -1826,7 +1826,7 @@ class BasicKeyCommands(CommandsProtocol): For more information see https://redis.io/commands/get """ - return self.execute_command("GET", name, keys=[name]) + return self.execute_command("GET", name) def getdel(self, name: KeyT) -> ResponseT: """ @@ -1842,10 +1842,10 @@ class BasicKeyCommands(CommandsProtocol): def getex( self, name: KeyT, - ex: Optional[ExpiryT] = None, - px: Optional[ExpiryT] = None, - exat: Optional[AbsExpiryT] = None, - pxat: Optional[AbsExpiryT] = None, + ex: Union[ExpiryT, None] = None, + px: Union[ExpiryT, None] = None, + exat: Union[AbsExpiryT, None] = None, + pxat: Union[AbsExpiryT, None] = None, persist: bool = False, ) -> ResponseT: """ @@ -1868,6 +1868,7 @@ class BasicKeyCommands(CommandsProtocol): For more information see https://redis.io/commands/getex """ + opset = {ex, px, exat, pxat} if len(opset) > 2 or len(opset) > 1 and persist: raise DataError( @@ -1875,12 +1876,33 @@ class BasicKeyCommands(CommandsProtocol): "and ``persist`` are mutually exclusive." 
) - exp_options: list[EncodableT] = extract_expire_flags(ex, px, exat, pxat) - + pieces: list[EncodableT] = [] + # similar to set command + if ex is not None: + pieces.append("EX") + if isinstance(ex, datetime.timedelta): + ex = int(ex.total_seconds()) + pieces.append(ex) + if px is not None: + pieces.append("PX") + if isinstance(px, datetime.timedelta): + px = int(px.total_seconds() * 1000) + pieces.append(px) + # similar to pexpireat command + if exat is not None: + pieces.append("EXAT") + if isinstance(exat, datetime.datetime): + exat = int(exat.timestamp()) + pieces.append(exat) + if pxat is not None: + pieces.append("PXAT") + if isinstance(pxat, datetime.datetime): + pxat = int(pxat.timestamp() * 1000) + pieces.append(pxat) if persist: - exp_options.append("PERSIST") + pieces.append("PERSIST") - return self.execute_command("GETEX", name, *exp_options) + return self.execute_command("GETEX", name, *pieces) def __getitem__(self, name: KeyT): """ @@ -1898,7 +1920,7 @@ class BasicKeyCommands(CommandsProtocol): For more information see https://redis.io/commands/getbit """ - return self.execute_command("GETBIT", name, offset, keys=[name]) + return self.execute_command("GETBIT", name, offset) def getrange(self, key: KeyT, start: int, end: int) -> ResponseT: """ @@ -1907,7 +1929,7 @@ class BasicKeyCommands(CommandsProtocol): For more information see https://redis.io/commands/getrange """ - return self.execute_command("GETRANGE", key, start, end, keys=[key]) + return self.execute_command("GETRANGE", key, start, end) def getset(self, name: KeyT, value: EncodableT) -> ResponseT: """ @@ -1990,7 +2012,6 @@ class BasicKeyCommands(CommandsProtocol): options = {} if not args: options[EMPTY_RESPONSE] = [] - options["keys"] = args return self.execute_command("MGET", *args, **options) def mset(self, mapping: Mapping[AnyKeyT, EncodableT]) -> ResponseT: @@ -2137,7 +2158,7 @@ class BasicKeyCommands(CommandsProtocol): return self.execute_command("PTTL", name) def hrandfield( - self, key: str, count: Optional[int] = None, withvalues: bool = False + self, key: str, count: int = None, withvalues: bool = False ) -> ResponseT: """ Return a random field from the hash value stored at key. @@ -2191,8 +2212,8 @@ class BasicKeyCommands(CommandsProtocol): value: EncodableT, replace: bool = False, absttl: bool = False, - idletime: Optional[int] = None, - frequency: Optional[int] = None, + idletime: Union[int, None] = None, + frequency: Union[int, None] = None, ) -> ResponseT: """ Create a key using the provided serialized value, previously obtained @@ -2238,14 +2259,14 @@ class BasicKeyCommands(CommandsProtocol): self, name: KeyT, value: EncodableT, - ex: Optional[ExpiryT] = None, - px: Optional[ExpiryT] = None, + ex: Union[ExpiryT, None] = None, + px: Union[ExpiryT, None] = None, nx: bool = False, xx: bool = False, keepttl: bool = False, get: bool = False, - exat: Optional[AbsExpiryT] = None, - pxat: Optional[AbsExpiryT] = None, + exat: Union[AbsExpiryT, None] = None, + pxat: Union[AbsExpiryT, None] = None, ) -> ResponseT: """ Set the value at key ``name`` to ``value`` @@ -2275,21 +2296,36 @@ class BasicKeyCommands(CommandsProtocol): For more information see https://redis.io/commands/set """ - opset = {ex, px, exat, pxat} - if len(opset) > 2 or len(opset) > 1 and keepttl: - raise DataError( - "``ex``, ``px``, ``exat``, ``pxat``, " - "and ``keepttl`` are mutually exclusive." 
- ) - - if nx and xx: - raise DataError("``nx`` and ``xx`` are mutually exclusive.") - pieces: list[EncodableT] = [name, value] options = {} - - pieces.extend(extract_expire_flags(ex, px, exat, pxat)) - + if ex is not None: + pieces.append("EX") + if isinstance(ex, datetime.timedelta): + pieces.append(int(ex.total_seconds())) + elif isinstance(ex, int): + pieces.append(ex) + elif isinstance(ex, str) and ex.isdigit(): + pieces.append(int(ex)) + else: + raise DataError("ex must be datetime.timedelta or int") + if px is not None: + pieces.append("PX") + if isinstance(px, datetime.timedelta): + pieces.append(int(px.total_seconds() * 1000)) + elif isinstance(px, int): + pieces.append(px) + else: + raise DataError("px must be datetime.timedelta or int") + if exat is not None: + pieces.append("EXAT") + if isinstance(exat, datetime.datetime): + exat = int(exat.timestamp()) + pieces.append(exat) + if pxat is not None: + pieces.append("PXAT") + if isinstance(pxat, datetime.datetime): + pxat = int(pxat.timestamp() * 1000) + pieces.append(pxat) if keepttl: pieces.append("KEEPTTL") @@ -2360,7 +2396,7 @@ class BasicKeyCommands(CommandsProtocol): specific_argument: Union[Literal["strings"], Literal["keys"]] = "strings", len: bool = False, idx: bool = False, - minmatchlen: Optional[int] = None, + minmatchlen: Union[int, None] = None, withmatchlen: bool = False, **kwargs, ) -> ResponseT: @@ -2422,14 +2458,14 @@ class BasicKeyCommands(CommandsProtocol): For more information see https://redis.io/commands/strlen """ - return self.execute_command("STRLEN", name, keys=[name]) + return self.execute_command("STRLEN", name) def substr(self, name: KeyT, start: int, end: int = -1) -> ResponseT: """ Return a substring of the string at key ``name``. ``start`` and ``end`` are 0-based integers specifying the portion of the string to return. 
""" - return self.execute_command("SUBSTR", name, start, end, keys=[name]) + return self.execute_command("SUBSTR", name, start, end) def touch(self, *args: KeyT) -> ResponseT: """ @@ -2454,7 +2490,7 @@ class BasicKeyCommands(CommandsProtocol): For more information see https://redis.io/commands/type """ - return self.execute_command("TYPE", name, keys=[name]) + return self.execute_command("TYPE", name) def watch(self, *names: KeyT) -> None: """ @@ -2466,7 +2502,7 @@ class BasicKeyCommands(CommandsProtocol): def unwatch(self) -> None: """ - Unwatches all previously watched keys for a transaction + Unwatches the value at key ``name``, or None of the key doesn't exist For more information see https://redis.io/commands/unwatch """ @@ -2507,7 +2543,7 @@ class BasicKeyCommands(CommandsProtocol): pieces.extend(["MINMATCHLEN", minmatchlen]) if withmatchlen: pieces.append("WITHMATCHLEN") - return self.execute_command("LCS", *pieces, keys=[key1, key2]) + return self.execute_command("LCS", *pieces) class AsyncBasicKeyCommands(BasicKeyCommands): @@ -2537,7 +2573,7 @@ class ListCommands(CommandsProtocol): """ def blpop( - self, keys: List, timeout: Optional[Number] = 0 + self, keys: List, timeout: Optional[int] = 0 ) -> Union[Awaitable[list], list]: """ LPOP a value off of the first non-empty list @@ -2558,7 +2594,7 @@ class ListCommands(CommandsProtocol): return self.execute_command("BLPOP", *keys) def brpop( - self, keys: List, timeout: Optional[Number] = 0 + self, keys: List, timeout: Optional[int] = 0 ) -> Union[Awaitable[list], list]: """ RPOP a value off of the first non-empty list @@ -2579,7 +2615,7 @@ class ListCommands(CommandsProtocol): return self.execute_command("BRPOP", *keys) def brpoplpush( - self, src: str, dst: str, timeout: Optional[Number] = 0 + self, src: str, dst: str, timeout: Optional[int] = 0 ) -> Union[Awaitable[Optional[str]], Optional[str]]: """ Pop a value off the tail of ``src``, push it on the head of ``dst`` @@ -2646,7 +2682,7 @@ class ListCommands(CommandsProtocol): For more information see https://redis.io/commands/lindex """ - return self.execute_command("LINDEX", name, index, keys=[name]) + return self.execute_command("LINDEX", name, index) def linsert( self, name: str, where: str, refvalue: str, value: str @@ -2668,7 +2704,7 @@ class ListCommands(CommandsProtocol): For more information see https://redis.io/commands/llen """ - return self.execute_command("LLEN", name, keys=[name]) + return self.execute_command("LLEN", name) def lpop( self, @@ -2715,7 +2751,7 @@ class ListCommands(CommandsProtocol): For more information see https://redis.io/commands/lrange """ - return self.execute_command("LRANGE", name, start, end, keys=[name]) + return self.execute_command("LRANGE", name, start, end) def lrem(self, name: str, count: int, value: str) -> Union[Awaitable[int], int]: """ @@ -2838,7 +2874,7 @@ class ListCommands(CommandsProtocol): if maxlen is not None: pieces.extend(["MAXLEN", maxlen]) - return self.execute_command("LPOS", *pieces, keys=[name]) + return self.execute_command("LPOS", *pieces) def sort( self, @@ -2910,7 +2946,6 @@ class ListCommands(CommandsProtocol): ) options = {"groups": len(get) if groups else None} - options["keys"] = [name] return self.execute_command("SORT", *pieces, **options) def sort_ro( @@ -2960,8 +2995,8 @@ class ScanCommands(CommandsProtocol): self, cursor: int = 0, match: Union[PatternT, None] = None, - count: Optional[int] = None, - _type: Optional[str] = None, + count: Union[int, None] = None, + _type: Union[str, None] = None, **kwargs, ) -> 
ResponseT: """ @@ -2992,8 +3027,8 @@ class ScanCommands(CommandsProtocol): def scan_iter( self, match: Union[PatternT, None] = None, - count: Optional[int] = None, - _type: Optional[str] = None, + count: Union[int, None] = None, + _type: Union[str, None] = None, **kwargs, ) -> Iterator: """ @@ -3022,7 +3057,7 @@ class ScanCommands(CommandsProtocol): name: KeyT, cursor: int = 0, match: Union[PatternT, None] = None, - count: Optional[int] = None, + count: Union[int, None] = None, ) -> ResponseT: """ Incrementally return lists of elements in a set. Also return a cursor @@ -3045,7 +3080,7 @@ class ScanCommands(CommandsProtocol): self, name: KeyT, match: Union[PatternT, None] = None, - count: Optional[int] = None, + count: Union[int, None] = None, ) -> Iterator: """ Make an iterator using the SSCAN command so that the client doesn't @@ -3065,8 +3100,7 @@ class ScanCommands(CommandsProtocol): name: KeyT, cursor: int = 0, match: Union[PatternT, None] = None, - count: Optional[int] = None, - no_values: Union[bool, None] = None, + count: Union[int, None] = None, ) -> ResponseT: """ Incrementally return key/value slices in a hash. Also return a cursor @@ -3076,8 +3110,6 @@ class ScanCommands(CommandsProtocol): ``count`` allows for hint the minimum number of returns - ``no_values`` indicates to return only the keys, without values. - For more information see https://redis.io/commands/hscan """ pieces: list[EncodableT] = [name, cursor] @@ -3085,16 +3117,13 @@ class ScanCommands(CommandsProtocol): pieces.extend([b"MATCH", match]) if count is not None: pieces.extend([b"COUNT", count]) - if no_values is not None: - pieces.extend([b"NOVALUES"]) - return self.execute_command("HSCAN", *pieces, no_values=no_values) + return self.execute_command("HSCAN", *pieces) def hscan_iter( self, name: str, match: Union[PatternT, None] = None, - count: Optional[int] = None, - no_values: Union[bool, None] = None, + count: Union[int, None] = None, ) -> Iterator: """ Make an iterator using the HSCAN command so that the client doesn't @@ -3103,25 +3132,18 @@ class ScanCommands(CommandsProtocol): ``match`` allows for filtering the keys by pattern ``count`` allows for hint the minimum number of returns - - ``no_values`` indicates to return only the keys, without values """ cursor = "0" while cursor != 0: - cursor, data = self.hscan( - name, cursor=cursor, match=match, count=count, no_values=no_values - ) - if no_values: - yield from data - else: - yield from data.items() + cursor, data = self.hscan(name, cursor=cursor, match=match, count=count) + yield from data.items() def zscan( self, name: KeyT, cursor: int = 0, match: Union[PatternT, None] = None, - count: Optional[int] = None, + count: Union[int, None] = None, score_cast_func: Union[type, Callable] = float, ) -> ResponseT: """ @@ -3148,7 +3170,7 @@ class ScanCommands(CommandsProtocol): self, name: KeyT, match: Union[PatternT, None] = None, - count: Optional[int] = None, + count: Union[int, None] = None, score_cast_func: Union[type, Callable] = float, ) -> Iterator: """ @@ -3177,8 +3199,8 @@ class AsyncScanCommands(ScanCommands): async def scan_iter( self, match: Union[PatternT, None] = None, - count: Optional[int] = None, - _type: Optional[str] = None, + count: Union[int, None] = None, + _type: Union[str, None] = None, **kwargs, ) -> AsyncIterator: """ @@ -3207,7 +3229,7 @@ class AsyncScanCommands(ScanCommands): self, name: KeyT, match: Union[PatternT, None] = None, - count: Optional[int] = None, + count: Union[int, None] = None, ) -> AsyncIterator: """ Make an iterator 
using the SSCAN command so that the client doesn't @@ -3229,8 +3251,7 @@ class AsyncScanCommands(ScanCommands): self, name: str, match: Union[PatternT, None] = None, - count: Optional[int] = None, - no_values: Union[bool, None] = None, + count: Union[int, None] = None, ) -> AsyncIterator: """ Make an iterator using the HSCAN command so that the client doesn't @@ -3239,26 +3260,20 @@ class AsyncScanCommands(ScanCommands): ``match`` allows for filtering the keys by pattern ``count`` allows for hint the minimum number of returns - - ``no_values`` indicates to return only the keys, without values """ cursor = "0" while cursor != 0: cursor, data = await self.hscan( - name, cursor=cursor, match=match, count=count, no_values=no_values + name, cursor=cursor, match=match, count=count ) - if no_values: - for it in data: - yield it - else: - for it in data.items(): - yield it + for it in data.items(): + yield it async def zscan_iter( self, name: KeyT, match: Union[PatternT, None] = None, - count: Optional[int] = None, + count: Union[int, None] = None, score_cast_func: Union[type, Callable] = float, ) -> AsyncIterator: """ @@ -3290,7 +3305,7 @@ class SetCommands(CommandsProtocol): see: https://redis.io/topics/data-types#sets """ - def sadd(self, name: KeyT, *values: FieldT) -> Union[Awaitable[int], int]: + def sadd(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ Add ``value(s)`` to set ``name`` @@ -3298,13 +3313,13 @@ class SetCommands(CommandsProtocol): """ return self.execute_command("SADD", name, *values) - def scard(self, name: KeyT) -> Union[Awaitable[int], int]: + def scard(self, name: str) -> Union[Awaitable[int], int]: """ Return the number of elements in set ``name`` For more information see https://redis.io/commands/scard """ - return self.execute_command("SCARD", name, keys=[name]) + return self.execute_command("SCARD", name) def sdiff(self, keys: List, *args: List) -> Union[Awaitable[list], list]: """ @@ -3313,7 +3328,7 @@ class SetCommands(CommandsProtocol): For more information see https://redis.io/commands/sdiff """ args = list_or_args(keys, args) - return self.execute_command("SDIFF", *args, keys=args) + return self.execute_command("SDIFF", *args) def sdiffstore( self, dest: str, keys: List, *args: List @@ -3334,13 +3349,13 @@ class SetCommands(CommandsProtocol): For more information see https://redis.io/commands/sinter """ args = list_or_args(keys, args) - return self.execute_command("SINTER", *args, keys=args) + return self.execute_command("SINTER", *args) def sintercard( - self, numkeys: int, keys: List[KeyT], limit: int = 0 + self, numkeys: int, keys: List[str], limit: int = 0 ) -> Union[Awaitable[int], int]: """ - Return the cardinality of the intersect of multiple sets specified by ``keys``. + Return the cardinality of the intersect of multiple sets specified by ``keys`. 
When LIMIT provided (defaults to 0 and means unlimited), if the intersection cardinality reaches limit partway through the computation, the algorithm will @@ -3349,10 +3364,10 @@ class SetCommands(CommandsProtocol): For more information see https://redis.io/commands/sintercard """ args = [numkeys, *keys, "LIMIT", limit] - return self.execute_command("SINTERCARD", *args, keys=keys) + return self.execute_command("SINTERCARD", *args) def sinterstore( - self, dest: KeyT, keys: List, *args: List + self, dest: str, keys: List, *args: List ) -> Union[Awaitable[int], int]: """ Store the intersection of sets specified by ``keys`` into a new @@ -3364,7 +3379,7 @@ class SetCommands(CommandsProtocol): return self.execute_command("SINTERSTORE", dest, *args) def sismember( - self, name: KeyT, value: str + self, name: str, value: str ) -> Union[Awaitable[Union[Literal[0], Literal[1]]], Union[Literal[0], Literal[1]]]: """ Return whether ``value`` is a member of set ``name``: @@ -3373,18 +3388,18 @@ class SetCommands(CommandsProtocol): For more information see https://redis.io/commands/sismember """ - return self.execute_command("SISMEMBER", name, value, keys=[name]) + return self.execute_command("SISMEMBER", name, value) - def smembers(self, name: KeyT) -> Union[Awaitable[Set], Set]: + def smembers(self, name: str) -> Union[Awaitable[Set], Set]: """ Return all members of the set ``name`` For more information see https://redis.io/commands/smembers """ - return self.execute_command("SMEMBERS", name, keys=[name]) + return self.execute_command("SMEMBERS", name) def smismember( - self, name: KeyT, values: List, *args: List + self, name: str, values: List, *args: List ) -> Union[ Awaitable[List[Union[Literal[0], Literal[1]]]], List[Union[Literal[0], Literal[1]]], @@ -3398,9 +3413,9 @@ class SetCommands(CommandsProtocol): For more information see https://redis.io/commands/smismember """ args = list_or_args(values, args) - return self.execute_command("SMISMEMBER", name, *args, keys=[name]) + return self.execute_command("SMISMEMBER", name, *args) - def smove(self, src: KeyT, dst: KeyT, value: str) -> Union[Awaitable[bool], bool]: + def smove(self, src: str, dst: str, value: str) -> Union[Awaitable[bool], bool]: """ Move ``value`` from set ``src`` to set ``dst`` atomically @@ -3408,7 +3423,7 @@ class SetCommands(CommandsProtocol): """ return self.execute_command("SMOVE", src, dst, value) - def spop(self, name: KeyT, count: Optional[int] = None) -> Union[str, List, None]: + def spop(self, name: str, count: Optional[int] = None) -> Union[str, List, None]: """ Remove and return a random member of set ``name`` @@ -3418,7 +3433,7 @@ class SetCommands(CommandsProtocol): return self.execute_command("SPOP", name, *args) def srandmember( - self, name: KeyT, number: Optional[int] = None + self, name: str, number: Optional[int] = None ) -> Union[str, List, None]: """ If ``number`` is None, returns a random member of set ``name``. 
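# --- illustrative sketch (not part of the patch above) ----------------------
# The *_iter helpers touched in this diff (scan_iter, sscan_iter, hscan_iter,
# zscan_iter) all wrap the same cursor loop: keep re-issuing SCAN/SSCAN/HSCAN
# with the returned cursor until the server hands back 0. A minimal usage
# sketch, assuming a locally running Redis on localhost:6379 and the redis-py
# client whose code is being patched here; the key pattern "user:*" is only an
# example.
import redis

r = redis.Redis(host="localhost", port=6379, decode_responses=True)

# Manual cursor loop, mirroring the iterator implementations shown in the diff:
cursor = 0
while True:
    cursor, keys = r.scan(cursor=cursor, match="user:*", count=100)
    for key in keys:
        print(key)
    if cursor == 0:
        break

# Equivalent, using the iterator helper so the client manages the cursor:
for key in r.scan_iter(match="user:*", count=100):
    print(key)
# -----------------------------------------------------------------------------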
@@ -3432,7 +3447,7 @@ class SetCommands(CommandsProtocol): args = (number is not None) and [number] or [] return self.execute_command("SRANDMEMBER", name, *args) - def srem(self, name: KeyT, *values: FieldT) -> Union[Awaitable[int], int]: + def srem(self, name: str, *values: FieldT) -> Union[Awaitable[int], int]: """ Remove ``values`` from set ``name`` @@ -3447,10 +3462,10 @@ class SetCommands(CommandsProtocol): For more information see https://redis.io/commands/sunion """ args = list_or_args(keys, args) - return self.execute_command("SUNION", *args, keys=args) + return self.execute_command("SUNION", *args) def sunionstore( - self, dest: KeyT, keys: List, *args: List + self, dest: str, keys: List, *args: List ) -> Union[Awaitable[int], int]: """ Store the union of sets specified by ``keys`` into a new @@ -3474,49 +3489,24 @@ class StreamCommands(CommandsProtocol): def xack(self, name: KeyT, groupname: GroupT, *ids: StreamIdT) -> ResponseT: """ Acknowledges the successful processing of one or more messages. - - Args: - name: name of the stream. - groupname: name of the consumer group. - *ids: message ids to acknowledge. + name: name of the stream. + groupname: name of the consumer group. + *ids: message ids to acknowledge. For more information see https://redis.io/commands/xack """ return self.execute_command("XACK", name, groupname, *ids) - def xackdel( - self, - name: KeyT, - groupname: GroupT, - *ids: StreamIdT, - ref_policy: Literal["KEEPREF", "DELREF", "ACKED"] = "KEEPREF", - ) -> ResponseT: - """ - Combines the functionality of XACK and XDEL. Acknowledges the specified - message IDs in the given consumer group and simultaneously attempts to - delete the corresponding entries from the stream. - """ - if not ids: - raise DataError("XACKDEL requires at least one message ID") - - if ref_policy not in {"KEEPREF", "DELREF", "ACKED"}: - raise DataError("XACKDEL ref_policy must be one of: KEEPREF, DELREF, ACKED") - - pieces = [name, groupname, ref_policy, "IDS", len(ids)] - pieces.extend(ids) - return self.execute_command("XACKDEL", *pieces) - def xadd( self, name: KeyT, fields: Dict[FieldT, EncodableT], id: StreamIdT = "*", - maxlen: Optional[int] = None, + maxlen: Union[int, None] = None, approximate: bool = True, nomkstream: bool = False, minid: Union[StreamIdT, None] = None, - limit: Optional[int] = None, - ref_policy: Optional[Literal["KEEPREF", "DELREF", "ACKED"]] = None, + limit: Union[int, None] = None, ) -> ResponseT: """ Add to a stream. @@ -3530,10 +3520,6 @@ class StreamCommands(CommandsProtocol): minid: the minimum id in the stream to query. Can't be specified with maxlen. 
limit: specifies the maximum number of entries to retrieve - ref_policy: optional reference policy for consumer groups when trimming: - - KEEPREF (default): When trimming, preserves references in consumer groups' PEL - - DELREF: When trimming, removes all references from consumer groups' PEL - - ACKED: When trimming, only removes entries acknowledged by all consumer groups For more information see https://redis.io/commands/xadd """ @@ -3541,9 +3527,6 @@ class StreamCommands(CommandsProtocol): if maxlen is not None and minid is not None: raise DataError("Only one of ```maxlen``` or ```minid``` may be specified") - if ref_policy is not None and ref_policy not in {"KEEPREF", "DELREF", "ACKED"}: - raise DataError("XADD ref_policy must be one of: KEEPREF, DELREF, ACKED") - if maxlen is not None: if not isinstance(maxlen, int) or maxlen < 0: raise DataError("XADD maxlen must be non-negative integer") @@ -3560,8 +3543,6 @@ class StreamCommands(CommandsProtocol): pieces.extend([b"LIMIT", limit]) if nomkstream: pieces.append(b"NOMKSTREAM") - if ref_policy is not None: - pieces.append(ref_policy) pieces.append(id) if not isinstance(fields, dict) or len(fields) == 0: raise DataError("XADD fields must be a non-empty dict") @@ -3576,7 +3557,7 @@ class StreamCommands(CommandsProtocol): consumername: ConsumerT, min_idle_time: int, start_id: StreamIdT = "0-0", - count: Optional[int] = None, + count: Union[int, None] = None, justid: bool = False, ) -> ResponseT: """ @@ -3627,9 +3608,9 @@ class StreamCommands(CommandsProtocol): consumername: ConsumerT, min_idle_time: int, message_ids: Union[List[StreamIdT], Tuple[StreamIdT]], - idle: Optional[int] = None, - time: Optional[int] = None, - retrycount: Optional[int] = None, + idle: Union[int, None] = None, + time: Union[int, None] = None, + retrycount: Union[int, None] = None, force: bool = False, justid: bool = False, ) -> ResponseT: @@ -3706,35 +3687,13 @@ class StreamCommands(CommandsProtocol): def xdel(self, name: KeyT, *ids: StreamIdT) -> ResponseT: """ Deletes one or more messages from a stream. - - Args: - name: name of the stream. - *ids: message ids to delete. + name: name of the stream. + *ids: message ids to delete. For more information see https://redis.io/commands/xdel """ return self.execute_command("XDEL", name, *ids) - def xdelex( - self, - name: KeyT, - *ids: StreamIdT, - ref_policy: Literal["KEEPREF", "DELREF", "ACKED"] = "KEEPREF", - ) -> ResponseT: - """ - Extended version of XDEL that provides more control over how message entries - are deleted concerning consumer groups. 
- """ - if not ids: - raise DataError("XDELEX requires at least one message ID") - - if ref_policy not in {"KEEPREF", "DELREF", "ACKED"}: - raise DataError("XDELEX ref_policy must be one of: KEEPREF, DELREF, ACKED") - - pieces = [name, ref_policy, "IDS", len(ids)] - pieces.extend(ids) - return self.execute_command("XDELEX", *pieces) - def xgroup_create( self, name: KeyT, @@ -3861,7 +3820,7 @@ class StreamCommands(CommandsProtocol): For more information see https://redis.io/commands/xlen """ - return self.execute_command("XLEN", name, keys=[name]) + return self.execute_command("XLEN", name) def xpending(self, name: KeyT, groupname: GroupT) -> ResponseT: """ @@ -3871,7 +3830,7 @@ class StreamCommands(CommandsProtocol): For more information see https://redis.io/commands/xpending """ - return self.execute_command("XPENDING", name, groupname, keys=[name]) + return self.execute_command("XPENDING", name, groupname) def xpending_range( self, @@ -3881,7 +3840,7 @@ class StreamCommands(CommandsProtocol): max: StreamIdT, count: int, consumername: Union[ConsumerT, None] = None, - idle: Optional[int] = None, + idle: Union[int, None] = None, ) -> ResponseT: """ Returns information about pending messages, in a range. @@ -3935,7 +3894,7 @@ class StreamCommands(CommandsProtocol): name: KeyT, min: StreamIdT = "-", max: StreamIdT = "+", - count: Optional[int] = None, + count: Union[int, None] = None, ) -> ResponseT: """ Read stream values within an interval. @@ -3960,13 +3919,13 @@ class StreamCommands(CommandsProtocol): pieces.append(b"COUNT") pieces.append(str(count)) - return self.execute_command("XRANGE", name, *pieces, keys=[name]) + return self.execute_command("XRANGE", name, *pieces) def xread( self, streams: Dict[KeyT, StreamIdT], - count: Optional[int] = None, - block: Optional[int] = None, + count: Union[int, None] = None, + block: Union[int, None] = None, ) -> ResponseT: """ Block and monitor multiple streams for new data. @@ -3998,15 +3957,15 @@ class StreamCommands(CommandsProtocol): keys, values = zip(*streams.items()) pieces.extend(keys) pieces.extend(values) - return self.execute_command("XREAD", *pieces, keys=keys) + return self.execute_command("XREAD", *pieces) def xreadgroup( self, groupname: str, consumername: str, streams: Dict[KeyT, StreamIdT], - count: Optional[int] = None, - block: Optional[int] = None, + count: Union[int, None] = None, + block: Union[int, None] = None, noack: bool = False, ) -> ResponseT: """ @@ -4052,7 +4011,7 @@ class StreamCommands(CommandsProtocol): name: KeyT, max: StreamIdT = "+", min: StreamIdT = "-", - count: Optional[int] = None, + count: Union[int, None] = None, ) -> ResponseT: """ Read stream values within an interval, in reverse order. @@ -4077,16 +4036,15 @@ class StreamCommands(CommandsProtocol): pieces.append(b"COUNT") pieces.append(str(count)) - return self.execute_command("XREVRANGE", name, *pieces, keys=[name]) + return self.execute_command("XREVRANGE", name, *pieces) def xtrim( self, name: KeyT, - maxlen: Optional[int] = None, + maxlen: Union[int, None] = None, approximate: bool = True, minid: Union[StreamIdT, None] = None, - limit: Optional[int] = None, - ref_policy: Optional[Literal["KEEPREF", "DELREF", "ACKED"]] = None, + limit: Union[int, None] = None, ) -> ResponseT: """ Trims old messages from a stream. @@ -4097,10 +4055,6 @@ class StreamCommands(CommandsProtocol): minid: the minimum id in the stream to query Can't be specified with maxlen. 
limit: specifies the maximum number of entries to retrieve - ref_policy: optional reference policy for consumer groups: - - KEEPREF (default): Trims entries but preserves references in consumer groups' PEL - - DELREF: Trims entries and removes all references from consumer groups' PEL - - ACKED: Only trims entries that were read and acknowledged by all consumer groups For more information see https://redis.io/commands/xtrim """ @@ -4111,9 +4065,6 @@ class StreamCommands(CommandsProtocol): if maxlen is None and minid is None: raise DataError("One of ``maxlen`` or ``minid`` must be specified") - if ref_policy is not None and ref_policy not in {"KEEPREF", "DELREF", "ACKED"}: - raise DataError("XTRIM ref_policy must be one of: KEEPREF, DELREF, ACKED") - if maxlen is not None: pieces.append(b"MAXLEN") if minid is not None: @@ -4127,8 +4078,6 @@ class StreamCommands(CommandsProtocol): if limit is not None: pieces.append(b"LIMIT") pieces.append(limit) - if ref_policy is not None: - pieces.append(ref_policy) return self.execute_command("XTRIM", name, *pieces) @@ -4194,7 +4143,8 @@ class SortedSetCommands(CommandsProtocol): raise DataError("ZADD allows either 'gt' or 'lt', not both") if incr and len(mapping) != 1: raise DataError( - "ZADD option 'incr' only works when passing a single element/score pair" + "ZADD option 'incr' only works when passing a " + "single element/score pair" ) if nx and (gt or lt): raise DataError("Only one of 'nx', 'lt', or 'gr' may be defined.") @@ -4225,7 +4175,7 @@ class SortedSetCommands(CommandsProtocol): For more information see https://redis.io/commands/zcard """ - return self.execute_command("ZCARD", name, keys=[name]) + return self.execute_command("ZCARD", name) def zcount(self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT) -> ResponseT: """ @@ -4234,7 +4184,7 @@ class SortedSetCommands(CommandsProtocol): For more information see https://redis.io/commands/zcount """ - return self.execute_command("ZCOUNT", name, min, max, keys=[name]) + return self.execute_command("ZCOUNT", name, min, max) def zdiff(self, keys: KeysT, withscores: bool = False) -> ResponseT: """ @@ -4246,7 +4196,7 @@ class SortedSetCommands(CommandsProtocol): pieces = [len(keys), *keys] if withscores: pieces.append("WITHSCORES") - return self.execute_command("ZDIFF", *pieces, keys=keys) + return self.execute_command("ZDIFF", *pieces) def zdiffstore(self, dest: KeyT, keys: KeysT) -> ResponseT: """ @@ -4267,7 +4217,7 @@ class SortedSetCommands(CommandsProtocol): return self.execute_command("ZINCRBY", name, amount, value) def zinter( - self, keys: KeysT, aggregate: Optional[str] = None, withscores: bool = False + self, keys: KeysT, aggregate: Union[str, None] = None, withscores: bool = False ) -> ResponseT: """ Return the intersect of multiple sorted sets specified by ``keys``. @@ -4286,7 +4236,7 @@ class SortedSetCommands(CommandsProtocol): self, dest: KeyT, keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], - aggregate: Optional[str] = None, + aggregate: Union[str, None] = None, ) -> ResponseT: """ Intersect multiple sorted sets specified by ``keys`` into a new @@ -4306,7 +4256,7 @@ class SortedSetCommands(CommandsProtocol): ) -> Union[Awaitable[int], int]: """ Return the cardinality of the intersect of multiple sorted sets - specified by ``keys``. + specified by ``keys`. 
When LIMIT provided (defaults to 0 and means unlimited), if the intersection cardinality reaches limit partway through the computation, the algorithm will exit and yield limit as the cardinality @@ -4314,7 +4264,7 @@ class SortedSetCommands(CommandsProtocol): For more information see https://redis.io/commands/zintercard """ args = [numkeys, *keys, "LIMIT", limit] - return self.execute_command("ZINTERCARD", *args, keys=keys) + return self.execute_command("ZINTERCARD", *args) def zlexcount(self, name, min, max): """ @@ -4323,9 +4273,9 @@ class SortedSetCommands(CommandsProtocol): For more information see https://redis.io/commands/zlexcount """ - return self.execute_command("ZLEXCOUNT", name, min, max, keys=[name]) + return self.execute_command("ZLEXCOUNT", name, min, max) - def zpopmax(self, name: KeyT, count: Optional[int] = None) -> ResponseT: + def zpopmax(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: """ Remove and return up to ``count`` members with the highest scores from the sorted set ``name``. @@ -4336,7 +4286,7 @@ class SortedSetCommands(CommandsProtocol): options = {"withscores": True} return self.execute_command("ZPOPMAX", name, *args, **options) - def zpopmin(self, name: KeyT, count: Optional[int] = None) -> ResponseT: + def zpopmin(self, name: KeyT, count: Union[int, None] = None) -> ResponseT: """ Remove and return up to ``count`` members with the lowest scores from the sorted set ``name``. @@ -4348,7 +4298,7 @@ class SortedSetCommands(CommandsProtocol): return self.execute_command("ZPOPMIN", name, *args, **options) def zrandmember( - self, key: KeyT, count: Optional[int] = None, withscores: bool = False + self, key: KeyT, count: int = None, withscores: bool = False ) -> ResponseT: """ Return a random element from the sorted set value stored at key. 
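# --- illustrative sketch (not part of the patch above) ----------------------
# The ZADD validation visible in this diff enforces that 'nx'/'xx' and
# 'gt'/'lt' are mutually exclusive and that 'incr' accepts exactly one
# member/score pair. A minimal usage sketch under the same assumptions as the
# earlier sketch (local Redis, this redis-py client); key and member names are
# examples only.
import redis

r = redis.Redis(host="localhost", port=6379, decode_responses=True)

r.zadd("scores", {"alice": 10, "bob": 7})            # plain add
r.zadd("scores", {"alice": 5}, gt=True)              # no update: 5 is not greater than 10
new_score = r.zadd("scores", {"bob": 3}, incr=True)  # ZINCRBY-style, returns 10.0 here

# Range query using the LIMIT offset/num arguments shown for zrangebyscore:
top = r.zrangebyscore("scores", 5, "+inf", start=0, num=10, withscores=True)
print(new_score, top)
# -----------------------------------------------------------------------------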
@@ -4480,8 +4430,8 @@ class SortedSetCommands(CommandsProtocol): bylex: bool = False, withscores: bool = False, score_cast_func: Union[type, Callable, None] = float, - offset: Optional[int] = None, - num: Optional[int] = None, + offset: Union[int, None] = None, + num: Union[int, None] = None, ) -> ResponseT: if byscore and bylex: raise DataError("``byscore`` and ``bylex`` can not be specified together.") @@ -4506,7 +4456,6 @@ class SortedSetCommands(CommandsProtocol): if withscores: pieces.append("WITHSCORES") options = {"withscores": withscores, "score_cast_func": score_cast_func} - options["keys"] = [name] return self.execute_command(*pieces, **options) def zrange( @@ -4519,8 +4468,8 @@ class SortedSetCommands(CommandsProtocol): score_cast_func: Union[type, Callable] = float, byscore: bool = False, bylex: bool = False, - offset: Optional[int] = None, - num: Optional[int] = None, + offset: int = None, + num: int = None, ) -> ResponseT: """ Return a range of values from sorted set ``name`` between @@ -4595,7 +4544,6 @@ class SortedSetCommands(CommandsProtocol): if withscores: pieces.append(b"WITHSCORES") options = {"withscores": withscores, "score_cast_func": score_cast_func} - options["keys"] = name return self.execute_command(*pieces, **options) def zrangestore( @@ -4607,8 +4555,8 @@ class SortedSetCommands(CommandsProtocol): byscore: bool = False, bylex: bool = False, desc: bool = False, - offset: Optional[int] = None, - num: Optional[int] = None, + offset: Union[int, None] = None, + num: Union[int, None] = None, ) -> ResponseT: """ Stores in ``dest`` the result of a range of values from sorted set @@ -4653,8 +4601,8 @@ class SortedSetCommands(CommandsProtocol): name: KeyT, min: EncodableT, max: EncodableT, - start: Optional[int] = None, - num: Optional[int] = None, + start: Union[int, None] = None, + num: Union[int, None] = None, ) -> ResponseT: """ Return the lexicographical range of values from sorted set ``name`` @@ -4670,15 +4618,15 @@ class SortedSetCommands(CommandsProtocol): pieces = ["ZRANGEBYLEX", name, min, max] if start is not None and num is not None: pieces.extend([b"LIMIT", start, num]) - return self.execute_command(*pieces, keys=[name]) + return self.execute_command(*pieces) def zrevrangebylex( self, name: KeyT, max: EncodableT, min: EncodableT, - start: Optional[int] = None, - num: Optional[int] = None, + start: Union[int, None] = None, + num: Union[int, None] = None, ) -> ResponseT: """ Return the reversed lexicographical range of values from sorted set @@ -4694,15 +4642,15 @@ class SortedSetCommands(CommandsProtocol): pieces = ["ZREVRANGEBYLEX", name, max, min] if start is not None and num is not None: pieces.extend(["LIMIT", start, num]) - return self.execute_command(*pieces, keys=[name]) + return self.execute_command(*pieces) def zrangebyscore( self, name: KeyT, min: ZScoreBoundT, max: ZScoreBoundT, - start: Optional[int] = None, - num: Optional[int] = None, + start: Union[int, None] = None, + num: Union[int, None] = None, withscores: bool = False, score_cast_func: Union[type, Callable] = float, ) -> ResponseT: @@ -4728,7 +4676,6 @@ class SortedSetCommands(CommandsProtocol): if withscores: pieces.append("WITHSCORES") options = {"withscores": withscores, "score_cast_func": score_cast_func} - options["keys"] = [name] return self.execute_command(*pieces, **options) def zrevrangebyscore( @@ -4736,8 +4683,8 @@ class SortedSetCommands(CommandsProtocol): name: KeyT, max: ZScoreBoundT, min: ZScoreBoundT, - start: Optional[int] = None, - num: Optional[int] = None, + start: 
Union[int, None] = None, + num: Union[int, None] = None, withscores: bool = False, score_cast_func: Union[type, Callable] = float, ): @@ -4763,7 +4710,6 @@ class SortedSetCommands(CommandsProtocol): if withscores: pieces.append("WITHSCORES") options = {"withscores": withscores, "score_cast_func": score_cast_func} - options["keys"] = [name] return self.execute_command(*pieces, **options) def zrank( @@ -4781,8 +4727,8 @@ class SortedSetCommands(CommandsProtocol): For more information see https://redis.io/commands/zrank """ if withscore: - return self.execute_command("ZRANK", name, value, "WITHSCORE", keys=[name]) - return self.execute_command("ZRANK", name, value, keys=[name]) + return self.execute_command("ZRANK", name, value, "WITHSCORE") + return self.execute_command("ZRANK", name, value) def zrem(self, name: KeyT, *values: FieldT) -> ResponseT: """ @@ -4840,10 +4786,8 @@ class SortedSetCommands(CommandsProtocol): For more information see https://redis.io/commands/zrevrank """ if withscore: - return self.execute_command( - "ZREVRANK", name, value, "WITHSCORE", keys=[name] - ) - return self.execute_command("ZREVRANK", name, value, keys=[name]) + return self.execute_command("ZREVRANK", name, value, "WITHSCORE") + return self.execute_command("ZREVRANK", name, value) def zscore(self, name: KeyT, value: EncodableT) -> ResponseT: """ @@ -4851,12 +4795,12 @@ class SortedSetCommands(CommandsProtocol): For more information see https://redis.io/commands/zscore """ - return self.execute_command("ZSCORE", name, value, keys=[name]) + return self.execute_command("ZSCORE", name, value) def zunion( self, keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], - aggregate: Optional[str] = None, + aggregate: Union[str, None] = None, withscores: bool = False, ) -> ResponseT: """ @@ -4873,7 +4817,7 @@ class SortedSetCommands(CommandsProtocol): self, dest: KeyT, keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], - aggregate: Optional[str] = None, + aggregate: Union[str, None] = None, ) -> ResponseT: """ Union multiple sorted sets specified by ``keys`` into @@ -4898,14 +4842,14 @@ class SortedSetCommands(CommandsProtocol): if not members: raise DataError("ZMSCORE members must be a non-empty list") pieces = [key] + members - return self.execute_command("ZMSCORE", *pieces, keys=[key]) + return self.execute_command("ZMSCORE", *pieces) def _zaggregate( self, command: str, dest: Union[KeyT, None], keys: Union[Sequence[KeyT], Mapping[AnyKeyT, float]], - aggregate: Optional[str] = None, + aggregate: Union[str, None] = None, **options, ) -> ResponseT: pieces: list[EncodableT] = [command] @@ -4928,7 +4872,6 @@ class SortedSetCommands(CommandsProtocol): raise DataError("aggregate can be sum, min or max.") if options.get("withscores", False): pieces.append(b"WITHSCORES") - options["keys"] = keys return self.execute_command(*pieces, **options) @@ -4970,23 +4913,13 @@ class HyperlogCommands(CommandsProtocol): AsyncHyperlogCommands = HyperlogCommands -class HashDataPersistOptions(Enum): - # set the value for each provided key to each - # provided value only if all do not already exist. - FNX = "FNX" - - # set the value for each provided key to each - # provided value only if all already exist. - FXX = "FXX" - - class HashCommands(CommandsProtocol): """ Redis commands for Hash data type. 
see: https://redis.io/topics/data-types-intro#redis-hashes """ - def hdel(self, name: str, *keys: str) -> Union[Awaitable[int], int]: + def hdel(self, name: str, *keys: List) -> Union[Awaitable[int], int]: """ Delete ``keys`` from hash ``name`` @@ -5000,7 +4933,7 @@ class HashCommands(CommandsProtocol): For more information see https://redis.io/commands/hexists """ - return self.execute_command("HEXISTS", name, key, keys=[name]) + return self.execute_command("HEXISTS", name, key) def hget( self, name: str, key: str @@ -5010,7 +4943,7 @@ class HashCommands(CommandsProtocol): For more information see https://redis.io/commands/hget """ - return self.execute_command("HGET", name, key, keys=[name]) + return self.execute_command("HGET", name, key) def hgetall(self, name: str) -> Union[Awaitable[dict], dict]: """ @@ -5018,81 +4951,7 @@ class HashCommands(CommandsProtocol): For more information see https://redis.io/commands/hgetall """ - return self.execute_command("HGETALL", name, keys=[name]) - - def hgetdel( - self, name: str, *keys: str - ) -> Union[ - Awaitable[Optional[List[Union[str, bytes]]]], Optional[List[Union[str, bytes]]] - ]: - """ - Return the value of ``key`` within the hash ``name`` and - delete the field in the hash. - This command is similar to HGET, except for the fact that it also deletes - the key on success from the hash with the provided ```name```. - - Available since Redis 8.0 - For more information see https://redis.io/commands/hgetdel - """ - if len(keys) == 0: - raise DataError("'hgetdel' should have at least one key provided") - - return self.execute_command("HGETDEL", name, "FIELDS", len(keys), *keys) - - def hgetex( - self, - name: KeyT, - *keys: str, - ex: Optional[ExpiryT] = None, - px: Optional[ExpiryT] = None, - exat: Optional[AbsExpiryT] = None, - pxat: Optional[AbsExpiryT] = None, - persist: bool = False, - ) -> Union[ - Awaitable[Optional[List[Union[str, bytes]]]], Optional[List[Union[str, bytes]]] - ]: - """ - Return the values of ``key`` and ``keys`` within the hash ``name`` - and optionally set their expiration. - - ``ex`` sets an expire flag on ``kyes`` for ``ex`` seconds. - - ``px`` sets an expire flag on ``keys`` for ``px`` milliseconds. - - ``exat`` sets an expire flag on ``keys`` for ``ex`` seconds, - specified in unix time. - - ``pxat`` sets an expire flag on ``keys`` for ``ex`` milliseconds, - specified in unix time. - - ``persist`` remove the time to live associated with the ``keys``. - - Available since Redis 8.0 - For more information see https://redis.io/commands/hgetex - """ - if not keys: - raise DataError("'hgetex' should have at least one key provided") - - opset = {ex, px, exat, pxat} - if len(opset) > 2 or len(opset) > 1 and persist: - raise DataError( - "``ex``, ``px``, ``exat``, ``pxat``, " - "and ``persist`` are mutually exclusive." 
- ) - - exp_options: list[EncodableT] = extract_expire_flags(ex, px, exat, pxat) - - if persist: - exp_options.append("PERSIST") - - return self.execute_command( - "HGETEX", - name, - *exp_options, - "FIELDS", - len(keys), - *keys, - ) + return self.execute_command("HGETALL", name) def hincrby( self, name: str, key: str, amount: int = 1 @@ -5120,7 +4979,7 @@ class HashCommands(CommandsProtocol): For more information see https://redis.io/commands/hkeys """ - return self.execute_command("HKEYS", name, keys=[name]) + return self.execute_command("HKEYS", name) def hlen(self, name: str) -> Union[Awaitable[int], int]: """ @@ -5128,7 +4987,7 @@ class HashCommands(CommandsProtocol): For more information see https://redis.io/commands/hlen """ - return self.execute_command("HLEN", name, keys=[name]) + return self.execute_command("HLEN", name) def hset( self, @@ -5148,103 +5007,16 @@ class HashCommands(CommandsProtocol): For more information see https://redis.io/commands/hset """ - if key is None and not mapping and not items: raise DataError("'hset' with no key value pairs") - - pieces = [] - if items: - pieces.extend(items) + items = items or [] if key is not None: - pieces.extend((key, value)) + items.extend((key, value)) if mapping: for pair in mapping.items(): - pieces.extend(pair) + items.extend(pair) - return self.execute_command("HSET", name, *pieces) - - def hsetex( - self, - name: str, - key: Optional[str] = None, - value: Optional[str] = None, - mapping: Optional[dict] = None, - items: Optional[list] = None, - ex: Optional[ExpiryT] = None, - px: Optional[ExpiryT] = None, - exat: Optional[AbsExpiryT] = None, - pxat: Optional[AbsExpiryT] = None, - data_persist_option: Optional[HashDataPersistOptions] = None, - keepttl: bool = False, - ) -> Union[Awaitable[int], int]: - """ - Set ``key`` to ``value`` within hash ``name`` - - ``mapping`` accepts a dict of key/value pairs that will be - added to hash ``name``. - - ``items`` accepts a list of key/value pairs that will be - added to hash ``name``. - - ``ex`` sets an expire flag on ``keys`` for ``ex`` seconds. - - ``px`` sets an expire flag on ``keys`` for ``px`` milliseconds. - - ``exat`` sets an expire flag on ``keys`` for ``ex`` seconds, - specified in unix time. - - ``pxat`` sets an expire flag on ``keys`` for ``ex`` milliseconds, - specified in unix time. - - ``data_persist_option`` can be set to ``FNX`` or ``FXX`` to control the - behavior of the command. - ``FNX`` will set the value for each provided key to each - provided value only if all do not already exist. - ``FXX`` will set the value for each provided key to each - provided value only if all already exist. - - ``keepttl`` if True, retain the time to live associated with the keys. - - Returns the number of fields that were added. - - Available since Redis 8.0 - For more information see https://redis.io/commands/hsetex - """ - if key is None and not mapping and not items: - raise DataError("'hsetex' with no key value pairs") - - if items and len(items) % 2 != 0: - raise DataError( - "'hsetex' with odd number of items. " - "'items' must contain a list of key/value pairs." - ) - - opset = {ex, px, exat, pxat} - if len(opset) > 2 or len(opset) > 1 and keepttl: - raise DataError( - "``ex``, ``px``, ``exat``, ``pxat``, " - "and ``keepttl`` are mutually exclusive." 
- ) - - exp_options: list[EncodableT] = extract_expire_flags(ex, px, exat, pxat) - if data_persist_option: - exp_options.append(data_persist_option.value) - - if keepttl: - exp_options.append("KEEPTTL") - - pieces = [] - if items: - pieces.extend(items) - if key is not None: - pieces.extend((key, value)) - if mapping: - for pair in mapping.items(): - pieces.extend(pair) - - return self.execute_command( - "HSETEX", name, *exp_options, "FIELDS", int(len(pieces) / 2), *pieces - ) + return self.execute_command("HSET", name, *items) def hsetnx(self, name: str, key: str, value: str) -> Union[Awaitable[bool], bool]: """ @@ -5255,11 +5027,6 @@ class HashCommands(CommandsProtocol): """ return self.execute_command("HSETNX", name, key, value) - @deprecated_function( - version="4.0.0", - reason="Use 'hset' instead.", - name="hmset", - ) def hmset(self, name: str, mapping: dict) -> Union[Awaitable[str], str]: """ Set key to value within hash ``name`` for each corresponding @@ -5267,6 +5034,12 @@ class HashCommands(CommandsProtocol): For more information see https://redis.io/commands/hmset """ + warnings.warn( + f"{self.__class__.__name__}.hmset() is deprecated. " + f"Use {self.__class__.__name__}.hset() instead.", + DeprecationWarning, + stacklevel=2, + ) if not mapping: raise DataError("'hmset' with 'mapping' of length 0") items = [] @@ -5281,7 +5054,7 @@ class HashCommands(CommandsProtocol): For more information see https://redis.io/commands/hmget """ args = list_or_args(keys, args) - return self.execute_command("HMGET", name, *args, keys=[name]) + return self.execute_command("HMGET", name, *args) def hvals(self, name: str) -> Union[Awaitable[List], List]: """ @@ -5289,7 +5062,7 @@ class HashCommands(CommandsProtocol): For more information see https://redis.io/commands/hvals """ - return self.execute_command("HVALS", name, keys=[name]) + return self.execute_command("HVALS", name) def hstrlen(self, name: str, key: str) -> Union[Awaitable[int], int]: """ @@ -5298,366 +5071,7 @@ class HashCommands(CommandsProtocol): For more information see https://redis.io/commands/hstrlen """ - return self.execute_command("HSTRLEN", name, key, keys=[name]) - - def hexpire( - self, - name: KeyT, - seconds: ExpiryT, - *fields: str, - nx: bool = False, - xx: bool = False, - gt: bool = False, - lt: bool = False, - ) -> ResponseT: - """ - Sets or updates the expiration time for fields within a hash key, using relative - time in seconds. - - If a field already has an expiration time, the behavior of the update can be - controlled using the `nx`, `xx`, `gt`, and `lt` parameters. - - The return value provides detailed information about the outcome for each field. - - For more information, see https://redis.io/commands/hexpire - - Args: - name: The name of the hash key. - seconds: Expiration time in seconds, relative. Can be an integer, or a - Python `timedelta` object. - fields: List of fields within the hash to apply the expiration time to. - nx: Set expiry only when the field has no expiry. - xx: Set expiry only when the field has an existing expiry. - gt: Set expiry only when the new expiry is greater than the current one. - lt: Set expiry only when the new expiry is less than the current one. - - Returns: - Returns a list which contains for each field in the request: - - `-2` if the field does not exist, or if the key does not exist. - - `0` if the specified NX | XX | GT | LT condition was not met. - - `1` if the expiration time was set or updated. 
- - `2` if the field was deleted because the specified expiration time is - in the past. - """ - conditions = [nx, xx, gt, lt] - if sum(conditions) > 1: - raise ValueError("Only one of 'nx', 'xx', 'gt', 'lt' can be specified.") - - if isinstance(seconds, datetime.timedelta): - seconds = int(seconds.total_seconds()) - - options = [] - if nx: - options.append("NX") - if xx: - options.append("XX") - if gt: - options.append("GT") - if lt: - options.append("LT") - - return self.execute_command( - "HEXPIRE", name, seconds, *options, "FIELDS", len(fields), *fields - ) - - def hpexpire( - self, - name: KeyT, - milliseconds: ExpiryT, - *fields: str, - nx: bool = False, - xx: bool = False, - gt: bool = False, - lt: bool = False, - ) -> ResponseT: - """ - Sets or updates the expiration time for fields within a hash key, using relative - time in milliseconds. - - If a field already has an expiration time, the behavior of the update can be - controlled using the `nx`, `xx`, `gt`, and `lt` parameters. - - The return value provides detailed information about the outcome for each field. - - For more information, see https://redis.io/commands/hpexpire - - Args: - name: The name of the hash key. - milliseconds: Expiration time in milliseconds, relative. Can be an integer, - or a Python `timedelta` object. - fields: List of fields within the hash to apply the expiration time to. - nx: Set expiry only when the field has no expiry. - xx: Set expiry only when the field has an existing expiry. - gt: Set expiry only when the new expiry is greater than the current one. - lt: Set expiry only when the new expiry is less than the current one. - - Returns: - Returns a list which contains for each field in the request: - - `-2` if the field does not exist, or if the key does not exist. - - `0` if the specified NX | XX | GT | LT condition was not met. - - `1` if the expiration time was set or updated. - - `2` if the field was deleted because the specified expiration time is - in the past. - """ - conditions = [nx, xx, gt, lt] - if sum(conditions) > 1: - raise ValueError("Only one of 'nx', 'xx', 'gt', 'lt' can be specified.") - - if isinstance(milliseconds, datetime.timedelta): - milliseconds = int(milliseconds.total_seconds() * 1000) - - options = [] - if nx: - options.append("NX") - if xx: - options.append("XX") - if gt: - options.append("GT") - if lt: - options.append("LT") - - return self.execute_command( - "HPEXPIRE", name, milliseconds, *options, "FIELDS", len(fields), *fields - ) - - def hexpireat( - self, - name: KeyT, - unix_time_seconds: AbsExpiryT, - *fields: str, - nx: bool = False, - xx: bool = False, - gt: bool = False, - lt: bool = False, - ) -> ResponseT: - """ - Sets or updates the expiration time for fields within a hash key, using an - absolute Unix timestamp in seconds. - - If a field already has an expiration time, the behavior of the update can be - controlled using the `nx`, `xx`, `gt`, and `lt` parameters. - - The return value provides detailed information about the outcome for each field. - - For more information, see https://redis.io/commands/hexpireat - - Args: - name: The name of the hash key. - unix_time_seconds: Expiration time as Unix timestamp in seconds. Can be an - integer or a Python `datetime` object. - fields: List of fields within the hash to apply the expiration time to. - nx: Set expiry only when the field has no expiry. - xx: Set expiry only when the field has an existing expiration time. - gt: Set expiry only when the new expiry is greater than the current one. 
- lt: Set expiry only when the new expiry is less than the current one. - - Returns: - Returns a list which contains for each field in the request: - - `-2` if the field does not exist, or if the key does not exist. - - `0` if the specified NX | XX | GT | LT condition was not met. - - `1` if the expiration time was set or updated. - - `2` if the field was deleted because the specified expiration time is - in the past. - """ - conditions = [nx, xx, gt, lt] - if sum(conditions) > 1: - raise ValueError("Only one of 'nx', 'xx', 'gt', 'lt' can be specified.") - - if isinstance(unix_time_seconds, datetime.datetime): - unix_time_seconds = int(unix_time_seconds.timestamp()) - - options = [] - if nx: - options.append("NX") - if xx: - options.append("XX") - if gt: - options.append("GT") - if lt: - options.append("LT") - - return self.execute_command( - "HEXPIREAT", - name, - unix_time_seconds, - *options, - "FIELDS", - len(fields), - *fields, - ) - - def hpexpireat( - self, - name: KeyT, - unix_time_milliseconds: AbsExpiryT, - *fields: str, - nx: bool = False, - xx: bool = False, - gt: bool = False, - lt: bool = False, - ) -> ResponseT: - """ - Sets or updates the expiration time for fields within a hash key, using an - absolute Unix timestamp in milliseconds. - - If a field already has an expiration time, the behavior of the update can be - controlled using the `nx`, `xx`, `gt`, and `lt` parameters. - - The return value provides detailed information about the outcome for each field. - - For more information, see https://redis.io/commands/hpexpireat - - Args: - name: The name of the hash key. - unix_time_milliseconds: Expiration time as Unix timestamp in milliseconds. - Can be an integer or a Python `datetime` object. - fields: List of fields within the hash to apply the expiry. - nx: Set expiry only when the field has no expiry. - xx: Set expiry only when the field has an existing expiry. - gt: Set expiry only when the new expiry is greater than the current one. - lt: Set expiry only when the new expiry is less than the current one. - - Returns: - Returns a list which contains for each field in the request: - - `-2` if the field does not exist, or if the key does not exist. - - `0` if the specified NX | XX | GT | LT condition was not met. - - `1` if the expiration time was set or updated. - - `2` if the field was deleted because the specified expiration time is - in the past. - """ - conditions = [nx, xx, gt, lt] - if sum(conditions) > 1: - raise ValueError("Only one of 'nx', 'xx', 'gt', 'lt' can be specified.") - - if isinstance(unix_time_milliseconds, datetime.datetime): - unix_time_milliseconds = int(unix_time_milliseconds.timestamp() * 1000) - - options = [] - if nx: - options.append("NX") - if xx: - options.append("XX") - if gt: - options.append("GT") - if lt: - options.append("LT") - - return self.execute_command( - "HPEXPIREAT", - name, - unix_time_milliseconds, - *options, - "FIELDS", - len(fields), - *fields, - ) - - def hpersist(self, name: KeyT, *fields: str) -> ResponseT: - """ - Removes the expiration time for each specified field in a hash. - - For more information, see https://redis.io/commands/hpersist - - Args: - name: The name of the hash key. - fields: A list of fields within the hash from which to remove the - expiration time. - - Returns: - Returns a list which contains for each field in the request: - - `-2` if the field does not exist, or if the key does not exist. - - `-1` if the field exists but has no associated expiration time. 
- - `1` if the expiration time was successfully removed from the field. - """ - return self.execute_command("HPERSIST", name, "FIELDS", len(fields), *fields) - - def hexpiretime(self, key: KeyT, *fields: str) -> ResponseT: - """ - Returns the expiration times of hash fields as Unix timestamps in seconds. - - For more information, see https://redis.io/commands/hexpiretime - - Args: - key: The hash key. - fields: A list of fields within the hash for which to get the expiration - time. - - Returns: - Returns a list which contains for each field in the request: - - `-2` if the field does not exist, or if the key does not exist. - - `-1` if the field exists but has no associated expire time. - - A positive integer representing the expiration Unix timestamp in - seconds, if the field has an associated expiration time. - """ - return self.execute_command( - "HEXPIRETIME", key, "FIELDS", len(fields), *fields, keys=[key] - ) - - def hpexpiretime(self, key: KeyT, *fields: str) -> ResponseT: - """ - Returns the expiration times of hash fields as Unix timestamps in milliseconds. - - For more information, see https://redis.io/commands/hpexpiretime - - Args: - key: The hash key. - fields: A list of fields within the hash for which to get the expiration - time. - - Returns: - Returns a list which contains for each field in the request: - - `-2` if the field does not exist, or if the key does not exist. - - `-1` if the field exists but has no associated expire time. - - A positive integer representing the expiration Unix timestamp in - milliseconds, if the field has an associated expiration time. - """ - return self.execute_command( - "HPEXPIRETIME", key, "FIELDS", len(fields), *fields, keys=[key] - ) - - def httl(self, key: KeyT, *fields: str) -> ResponseT: - """ - Returns the TTL (Time To Live) in seconds for each specified field within a hash - key. - - For more information, see https://redis.io/commands/httl - - Args: - key: The hash key. - fields: A list of fields within the hash for which to get the TTL. - - Returns: - Returns a list which contains for each field in the request: - - `-2` if the field does not exist, or if the key does not exist. - - `-1` if the field exists but has no associated expire time. - - A positive integer representing the TTL in seconds if the field has - an associated expiration time. - """ - return self.execute_command( - "HTTL", key, "FIELDS", len(fields), *fields, keys=[key] - ) - - def hpttl(self, key: KeyT, *fields: str) -> ResponseT: - """ - Returns the TTL (Time To Live) in milliseconds for each specified field within a - hash key. - - For more information, see https://redis.io/commands/hpttl - - Args: - key: The hash key. - fields: A list of fields within the hash for which to get the TTL. - - Returns: - Returns a list which contains for each field in the request: - - `-2` if the field does not exist, or if the key does not exist. - - `-1` if the field exists but has no associated expire time. - - A positive integer representing the TTL in milliseconds if the field - has an associated expiration time. 
- """ - return self.execute_command( - "HPTTL", key, "FIELDS", len(fields), *fields, keys=[key] - ) + return self.execute_command("HSTRLEN", name, key) AsyncHashCommands = HashCommands @@ -5668,7 +5082,7 @@ class Script: An executable Lua script object returned by ``register_script`` """ - def __init__(self, registered_client: "redis.client.Redis", script: ScriptTextT): + def __init__(self, registered_client: "Redis", script: ScriptTextT): self.registered_client = registered_client self.script = script # Precalculate and store the SHA1 hex digest of the script. @@ -5676,7 +5090,11 @@ class Script: if isinstance(script, str): # We need the encoding from the client in order to generate an # accurate byte representation of the script - encoder = self.get_encoder() + try: + encoder = registered_client.connection_pool.get_encoder() + except AttributeError: + # Cluster + encoder = registered_client.get_encoder() script = encoder.encode(script) self.sha = hashlib.sha1(script).hexdigest() @@ -5684,7 +5102,7 @@ class Script: self, keys: Union[Sequence[KeyT], None] = None, args: Union[Iterable[EncodableT], None] = None, - client: Union["redis.client.Redis", None] = None, + client: Union["Redis", None] = None, ): """Execute the script, passing any required ``args``""" keys = keys or [] @@ -5707,35 +5125,13 @@ class Script: self.sha = client.script_load(self.script) return client.evalsha(self.sha, len(keys), *args) - def get_encoder(self): - """Get the encoder to encode string scripts into bytes.""" - try: - return self.registered_client.get_encoder() - except AttributeError: - # DEPRECATED - # In version <=4.1.2, this was the code we used to get the encoder. - # However, after 4.1.2 we added support for scripting in clustered - # redis. ClusteredRedis doesn't have a `.connection_pool` attribute - # so we changed the Script class to use - # `self.registered_client.get_encoder` (see above). - # However, that is technically a breaking change, as consumers who - # use Scripts directly might inject a `registered_client` that - # doesn't have a `.get_encoder` field. This try/except prevents us - # from breaking backward-compatibility. Ideally, it would be - # removed in the next major release. - return self.registered_client.connection_pool.get_encoder() - class AsyncScript: """ An executable Lua script object returned by ``register_script`` """ - def __init__( - self, - registered_client: "redis.asyncio.client.Redis", - script: ScriptTextT, - ): + def __init__(self, registered_client: "AsyncRedis", script: ScriptTextT): self.registered_client = registered_client self.script = script # Precalculate and store the SHA1 hex digest of the script. @@ -5755,7 +5151,7 @@ class AsyncScript: self, keys: Union[Sequence[KeyT], None] = None, args: Union[Iterable[EncodableT], None] = None, - client: Union["redis.asyncio.client.Redis", None] = None, + client: Union["AsyncRedis", None] = None, ): """Execute the script, passing any required ``args``""" keys = keys or [] @@ -5852,16 +5248,16 @@ AsyncPubSubCommands = PubSubCommands class ScriptCommands(CommandsProtocol): """ Redis Lua script commands. 
see: - https://redis.io/ebook/part-3-next-steps/chapter-11-scripting-redis-with-lua/ + https://redis.com/ebook/part-3-next-steps/chapter-11-scripting-redis-with-lua/ """ def _eval( - self, command: str, script: str, numkeys: int, *keys_and_args: str + self, command: str, script: str, numkeys: int, *keys_and_args: list ) -> Union[Awaitable[str], str]: return self.execute_command(command, script, numkeys, *keys_and_args) def eval( - self, script: str, numkeys: int, *keys_and_args: str + self, script: str, numkeys: int, *keys_and_args: list ) -> Union[Awaitable[str], str]: """ Execute the Lua ``script``, specifying the ``numkeys`` the script @@ -5876,7 +5272,7 @@ class ScriptCommands(CommandsProtocol): return self._eval("EVAL", script, numkeys, *keys_and_args) def eval_ro( - self, script: str, numkeys: int, *keys_and_args: str + self, script: str, numkeys: int, *keys_and_args: list ) -> Union[Awaitable[str], str]: """ The read-only variant of the EVAL command @@ -5895,7 +5291,7 @@ class ScriptCommands(CommandsProtocol): return self.execute_command(command, sha, numkeys, *keys_and_args) def evalsha( - self, sha: str, numkeys: int, *keys_and_args: str + self, sha: str, numkeys: int, *keys_and_args: list ) -> Union[Awaitable[str], str]: """ Use the ``sha`` to execute a Lua script already registered via EVAL @@ -5911,7 +5307,7 @@ class ScriptCommands(CommandsProtocol): return self._evalsha("EVALSHA", sha, numkeys, *keys_and_args) def evalsha_ro( - self, sha: str, numkeys: int, *keys_and_args: str + self, sha: str, numkeys: int, *keys_and_args: list ) -> Union[Awaitable[str], str]: """ The read-only variant of the EVALSHA command @@ -5929,7 +5325,7 @@ class ScriptCommands(CommandsProtocol): """ Check if a script exists in the script cache by specifying the SHAs of each script as ``args``. Returns a list of boolean values indicating if - if each already script exists in the cache_data. + if each already script exists in the cache. For more information see https://redis.io/commands/script-exists """ @@ -5943,7 +5339,7 @@ class ScriptCommands(CommandsProtocol): def script_flush( self, sync_type: Union[Literal["SYNC"], Literal["ASYNC"]] = None ) -> ResponseT: - """Flush all scripts from the script cache_data. + """Flush all scripts from the script cache. ``sync_type`` is by default SYNC (synchronous) but it can also be ASYNC. @@ -5974,13 +5370,13 @@ class ScriptCommands(CommandsProtocol): def script_load(self, script: ScriptTextT) -> ResponseT: """ - Load a Lua ``script`` into the script cache_data. Returns the SHA. + Load a Lua ``script`` into the script cache. Returns the SHA. For more information see https://redis.io/commands/script-load """ return self.execute_command("SCRIPT LOAD", script) - def register_script(self: "redis.client.Redis", script: ScriptTextT) -> Script: + def register_script(self: "Redis", script: ScriptTextT) -> Script: """ Register a Lua ``script`` specifying the ``keys`` it will touch. Returns a Script object that is callable and hides the complexity of @@ -5994,10 +5390,7 @@ class AsyncScriptCommands(ScriptCommands): async def script_debug(self, *args) -> None: return super().script_debug() - def register_script( - self: "redis.asyncio.client.Redis", - script: ScriptTextT, - ) -> AsyncScript: + def register_script(self: "AsyncRedis", script: ScriptTextT) -> AsyncScript: """ Register a Lua ``script`` specifying the ``keys`` it will touch. 
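As a hedged sketch of how the scripting commands above fit together (SCRIPT LOAD, EVALSHA and the Script object returned by register_script), assuming a local Redis instance; the key name "counter" is illustrative:

import redis

r = redis.Redis()

# INCRBY implemented in Lua; KEYS/ARGV are passed separately so the script
# text, and therefore its SHA, stays constant across calls.
lua = "return redis.call('INCRBY', KEYS[1], ARGV[1])"
incr_by = r.register_script(lua)

print(incr_by(keys=["counter"], args=[5]))   # 5
print(incr_by(keys=["counter"], args=[2]))   # 7; re-sent via SCRIPT LOAD only on NoScriptError

# The same script can also be driven by hand:
sha = r.script_load(lua)
print(r.evalsha(sha, 1, "counter", 10))      # 17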
Returns a Script object that is callable and hides the complexity of @@ -6056,7 +5449,7 @@ class GeoCommands(CommandsProtocol): return self.execute_command("GEOADD", *pieces) def geodist( - self, name: KeyT, place1: FieldT, place2: FieldT, unit: Optional[str] = None + self, name: KeyT, place1: FieldT, place2: FieldT, unit: Union[str, None] = None ) -> ResponseT: """ Return the distance between ``place1`` and ``place2`` members of the @@ -6071,7 +5464,7 @@ class GeoCommands(CommandsProtocol): raise DataError("GEODIST invalid unit") elif unit: pieces.append(unit) - return self.execute_command("GEODIST", *pieces, keys=[name]) + return self.execute_command("GEODIST", *pieces) def geohash(self, name: KeyT, *values: FieldT) -> ResponseT: """ @@ -6080,7 +5473,7 @@ class GeoCommands(CommandsProtocol): For more information see https://redis.io/commands/geohash """ - return self.execute_command("GEOHASH", name, *values, keys=[name]) + return self.execute_command("GEOHASH", name, *values) def geopos(self, name: KeyT, *values: FieldT) -> ResponseT: """ @@ -6090,7 +5483,7 @@ class GeoCommands(CommandsProtocol): For more information see https://redis.io/commands/geopos """ - return self.execute_command("GEOPOS", name, *values, keys=[name]) + return self.execute_command("GEOPOS", name, *values) def georadius( self, @@ -6098,14 +5491,14 @@ class GeoCommands(CommandsProtocol): longitude: float, latitude: float, radius: float, - unit: Optional[str] = None, + unit: Union[str, None] = None, withdist: bool = False, withcoord: bool = False, withhash: bool = False, - count: Optional[int] = None, - sort: Optional[str] = None, - store: Optional[KeyT] = None, - store_dist: Optional[KeyT] = None, + count: Union[int, None] = None, + sort: Union[str, None] = None, + store: Union[KeyT, None] = None, + store_dist: Union[KeyT, None] = None, any: bool = False, ) -> ResponseT: """ @@ -6160,12 +5553,12 @@ class GeoCommands(CommandsProtocol): name: KeyT, member: FieldT, radius: float, - unit: Optional[str] = None, + unit: Union[str, None] = None, withdist: bool = False, withcoord: bool = False, withhash: bool = False, - count: Optional[int] = None, - sort: Optional[str] = None, + count: Union[int, None] = None, + sort: Union[str, None] = None, store: Union[KeyT, None] = None, store_dist: Union[KeyT, None] = None, any: bool = False, @@ -6250,8 +5643,8 @@ class GeoCommands(CommandsProtocol): radius: Union[float, None] = None, width: Union[float, None] = None, height: Union[float, None] = None, - sort: Optional[str] = None, - count: Optional[int] = None, + sort: Union[str, None] = None, + count: Union[int, None] = None, any: bool = False, withcoord: bool = False, withdist: bool = False, @@ -6325,15 +5718,15 @@ class GeoCommands(CommandsProtocol): self, dest: KeyT, name: KeyT, - member: Optional[FieldT] = None, - longitude: Optional[float] = None, - latitude: Optional[float] = None, + member: Union[FieldT, None] = None, + longitude: Union[float, None] = None, + latitude: Union[float, None] = None, unit: str = "m", - radius: Optional[float] = None, - width: Optional[float] = None, - height: Optional[float] = None, - sort: Optional[str] = None, - count: Optional[int] = None, + radius: Union[float, None] = None, + width: Union[float, None] = None, + height: Union[float, None] = None, + sort: Union[str, None] = None, + count: Union[int, None] = None, any: bool = False, storedist: bool = False, ) -> ResponseT: @@ -6430,8 +5823,6 @@ class GeoCommands(CommandsProtocol): if kwargs[arg_name]: pieces.append(byte_repr) - kwargs["keys"] = 
[args[0] if command == "GEOSEARCH" else args[1]] - return self.execute_command(command, *pieces, **kwargs) @@ -6508,6 +5899,62 @@ class ModuleCommands(CommandsProtocol): return self.execute_command("COMMAND") +class Script: + """ + An executable Lua script object returned by ``register_script`` + """ + + def __init__(self, registered_client, script): + self.registered_client = registered_client + self.script = script + # Precalculate and store the SHA1 hex digest of the script. + + if isinstance(script, str): + # We need the encoding from the client in order to generate an + # accurate byte representation of the script + encoder = self.get_encoder() + script = encoder.encode(script) + self.sha = hashlib.sha1(script).hexdigest() + + def __call__(self, keys=[], args=[], client=None): + "Execute the script, passing any required ``args``" + if client is None: + client = self.registered_client + args = tuple(keys) + tuple(args) + # make sure the Redis server knows about the script + from redis.client import Pipeline + + if isinstance(client, Pipeline): + # Make sure the pipeline can register the script before executing. + client.scripts.add(self) + try: + return client.evalsha(self.sha, len(keys), *args) + except NoScriptError: + # Maybe the client is pointed to a different server than the client + # that created this instance? + # Overwrite the sha just in case there was a discrepancy. + self.sha = client.script_load(self.script) + return client.evalsha(self.sha, len(keys), *args) + + def get_encoder(self): + """Get the encoder to encode string scripts into bytes.""" + try: + return self.registered_client.get_encoder() + except AttributeError: + # DEPRECATED + # In version <=4.1.2, this was the code we used to get the encoder. + # However, after 4.1.2 we added support for scripting in clustered + # redis. ClusteredRedis doesn't have a `.connection_pool` attribute + # so we changed the Script class to use + # `self.registered_client.get_encoder` (see above). + # However, that is technically a breaking change, as consumers who + # use Scripts directly might inject a `registered_client` that + # doesn't have a `.get_encoder` field. This try/except prevents us + # from breaking backward-compatibility. Ideally, it would be + # removed in the next major release. + return self.registered_client.connection_pool.get_encoder() + + class AsyncModuleCommands(ModuleCommands): async def command_info(self) -> None: return super().command_info() @@ -6584,12 +6031,9 @@ class FunctionCommands: ) -> Union[Awaitable[List], List]: """ Return information about the functions and libraries. - - Args: - - library: specify a pattern for matching library names - withcode: cause the server to include the libraries source implementation - in the reply + :param library: pecify a pattern for matching library names + :param withcode: cause the server to include the libraries source + implementation in the reply """ args = ["LIBRARYNAME", library] if withcode: @@ -6597,12 +6041,12 @@ class FunctionCommands: return self.execute_command("FUNCTION LIST", *args) def _fcall( - self, command: str, function, numkeys: int, *keys_and_args: Any + self, command: str, function, numkeys: int, *keys_and_args: Optional[List] ) -> Union[Awaitable[str], str]: return self.execute_command(command, function, numkeys, *keys_and_args) def fcall( - self, function, numkeys: int, *keys_and_args: Any + self, function, numkeys: int, *keys_and_args: Optional[List] ) -> Union[Awaitable[str], str]: """ Invoke a function. 
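A hedged sketch of the server-side function workflow these methods wrap (FUNCTION LOAD, FUNCTION LIST, FCALL; Redis 7.0+). The library name "mylib" and function name "myfunc" are made up for illustration:

import redis

r = redis.Redis()

lib_code = """#!lua name=mylib
redis.register_function('myfunc', function(keys, args)
  return redis.call('SET', keys[1], args[1])
end)
"""

r.function_load(lib_code, replace=True)             # load (or replace) the library
print(r.function_list(library="mylib", withcode=False))
print(r.fcall("myfunc", 1, "greeting", "hello"))    # simple-string reply from SET ('OK')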
@@ -6612,13 +6056,13 @@ class FunctionCommands: return self._fcall("FCALL", function, numkeys, *keys_and_args) def fcall_ro( - self, function, numkeys: int, *keys_and_args: Any + self, function, numkeys: int, *keys_and_args: Optional[List] ) -> Union[Awaitable[str], str]: """ This is a read-only variant of the FCALL command that cannot execute commands that modify data. - For more information see https://redis.io/commands/fcall_ro + For more information see https://redis.io/commands/fcal_ro """ return self._fcall("FCALL_RO", function, numkeys, *keys_and_args) @@ -6668,6 +6112,131 @@ class FunctionCommands: AsyncFunctionCommands = FunctionCommands +class GearsCommands: + def tfunction_load( + self, lib_code: str, replace: bool = False, config: Union[str, None] = None + ) -> ResponseT: + """ + Load a new library to RedisGears. + + ``lib_code`` - the library code. + ``config`` - a string representation of a JSON object + that will be provided to the library on load time, + for more information refer to + https://github.com/RedisGears/RedisGears/blob/master/docs/function_advance_topics.md#library-configuration + ``replace`` - an optional argument, instructs RedisGears to replace the + function if its already exists + + For more information see https://redis.io/commands/tfunction-load/ + """ + pieces = [] + if replace: + pieces.append("REPLACE") + if config is not None: + pieces.extend(["CONFIG", config]) + pieces.append(lib_code) + return self.execute_command("TFUNCTION LOAD", *pieces) + + def tfunction_delete(self, lib_name: str) -> ResponseT: + """ + Delete a library from RedisGears. + + ``lib_name`` the library name to delete. + + For more information see https://redis.io/commands/tfunction-delete/ + """ + return self.execute_command("TFUNCTION DELETE", lib_name) + + def tfunction_list( + self, + with_code: bool = False, + verbose: int = 0, + lib_name: Union[str, None] = None, + ) -> ResponseT: + """ + List the functions with additional information about each function. + + ``with_code`` Show libraries code. + ``verbose`` output verbosity level, higher number will increase verbosity level + ``lib_name`` specifying a library name (can be used multiple times to show multiple libraries in a single command) # noqa + + For more information see https://redis.io/commands/tfunction-list/ + """ + pieces = [] + if with_code: + pieces.append("WITHCODE") + if verbose >= 1 and verbose <= 3: + pieces.append("v" * verbose) + else: + raise DataError("verbose can be 1, 2 or 3") + if lib_name is not None: + pieces.append("LIBRARY") + pieces.append(lib_name) + + return self.execute_command("TFUNCTION LIST", *pieces) + + def _tfcall( + self, + lib_name: str, + func_name: str, + keys: KeysT = None, + _async: bool = False, + *args: List, + ) -> ResponseT: + pieces = [f"{lib_name}.{func_name}"] + if keys is not None: + pieces.append(len(keys)) + pieces.extend(keys) + else: + pieces.append(0) + if args is not None: + pieces.extend(args) + if _async: + return self.execute_command("TFCALLASYNC", *pieces) + return self.execute_command("TFCALL", *pieces) + + def tfcall( + self, + lib_name: str, + func_name: str, + keys: KeysT = None, + *args: List, + ) -> ResponseT: + """ + Invoke a function. + + ``lib_name`` - the library name contains the function. + ``func_name`` - the function name to run. + ``keys`` - the keys that will be touched by the function. + ``args`` - Additional argument to pass to the function. 
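The GearsCommands methods above map onto the RedisGears 2.x triggers-and-functions API. A hedged sketch, assuming the module is loaded on the server; the JavaScript library below is purely illustrative:

import redis

r = redis.Redis()

js_lib = """#!js api_version=1.0 name=examplelib
redis.registerFunction('echo', function(client, arg) {
    return arg;
});
"""

r.tfunction_load(js_lib, replace=True)           # TFUNCTION LOAD REPLACE ... <code>
print(r.tfunction_list(verbose=1))               # libraries and their functions
print(r.tfcall("examplelib", "echo", [], "hi"))  # TFCALL examplelib.echo 0 hi
r.tfunction_delete("examplelib")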
+ + For more information see https://redis.io/commands/tfcall/ + """ + return self._tfcall(lib_name, func_name, keys, False, *args) + + def tfcall_async( + self, + lib_name: str, + func_name: str, + keys: KeysT = None, + *args: List, + ) -> ResponseT: + """ + Invoke an async function (coroutine). + + ``lib_name`` - the library name contains the function. + ``func_name`` - the function name to run. + ``keys`` - the keys that will be touched by the function. + ``args`` - Additional argument to pass to the function. + + For more information see https://redis.io/commands/tfcall/ + """ + return self._tfcall(lib_name, func_name, keys, True, *args) + + +AsyncGearsCommands = GearsCommands + + class DataAccessCommands( BasicKeyCommands, HyperlogCommands, @@ -6711,6 +6280,7 @@ class CoreCommands( PubSubCommands, ScriptCommands, FunctionCommands, + GearsCommands, ): """ A class containing all of the implemented redis commands. This class is @@ -6727,6 +6297,7 @@ class AsyncCoreCommands( AsyncPubSubCommands, AsyncScriptCommands, AsyncFunctionCommands, + AsyncGearsCommands, ): """ A class containing all of the implemented redis commands. This class is diff --git a/venv/lib/python3.12/site-packages/redis/commands/graph/__init__.py b/venv/lib/python3.12/site-packages/redis/commands/graph/__init__.py new file mode 100644 index 0000000..ffaf1fb --- /dev/null +++ b/venv/lib/python3.12/site-packages/redis/commands/graph/__init__.py @@ -0,0 +1,263 @@ +import warnings + +from ..helpers import quote_string, random_string, stringify_param_value +from .commands import AsyncGraphCommands, GraphCommands +from .edge import Edge # noqa +from .node import Node # noqa +from .path import Path # noqa + +DB_LABELS = "DB.LABELS" +DB_RAELATIONSHIPTYPES = "DB.RELATIONSHIPTYPES" +DB_PROPERTYKEYS = "DB.PROPERTYKEYS" + + +class Graph(GraphCommands): + """ + Graph, collection of nodes and edges. + """ + + def __init__(self, client, name=random_string()): + """ + Create a new graph. + """ + warnings.warn( + DeprecationWarning( + "RedisGraph support is deprecated as of Redis Stack 7.2 \ + (https://redis.com/blog/redisgraph-eol/)" + ) + ) + self.NAME = name # Graph key + self.client = client + self.execute_command = client.execute_command + + self.nodes = {} + self.edges = [] + self._labels = [] # List of node labels. + self._properties = [] # List of properties. + self._relationship_types = [] # List of relation types. + self.version = 0 # Graph version + + @property + def name(self): + return self.NAME + + def _clear_schema(self): + self._labels = [] + self._properties = [] + self._relationship_types = [] + + def _refresh_schema(self): + self._clear_schema() + self._refresh_labels() + self._refresh_relations() + self._refresh_attributes() + + def _refresh_labels(self): + lbls = self.labels() + + # Unpack data. + self._labels = [l[0] for _, l in enumerate(lbls)] + + def _refresh_relations(self): + rels = self.relationship_types() + + # Unpack data. + self._relationship_types = [r[0] for _, r in enumerate(rels)] + + def _refresh_attributes(self): + props = self.property_keys() + + # Unpack data. + self._properties = [p[0] for _, p in enumerate(props)] + + def get_label(self, idx): + """ + Returns a label by it's index + + Args: + + idx: + The index of the label + """ + try: + label = self._labels[idx] + except IndexError: + # Refresh labels. 
+ self._refresh_labels() + label = self._labels[idx] + return label + + def get_relation(self, idx): + """ + Returns a relationship type by it's index + + Args: + + idx: + The index of the relation + """ + try: + relationship_type = self._relationship_types[idx] + except IndexError: + # Refresh relationship types. + self._refresh_relations() + relationship_type = self._relationship_types[idx] + return relationship_type + + def get_property(self, idx): + """ + Returns a property by it's index + + Args: + + idx: + The index of the property + """ + try: + p = self._properties[idx] + except IndexError: + # Refresh properties. + self._refresh_attributes() + p = self._properties[idx] + return p + + def add_node(self, node): + """ + Adds a node to the graph. + """ + if node.alias is None: + node.alias = random_string() + self.nodes[node.alias] = node + + def add_edge(self, edge): + """ + Adds an edge to the graph. + """ + if not (self.nodes[edge.src_node.alias] and self.nodes[edge.dest_node.alias]): + raise AssertionError("Both edge's end must be in the graph") + + self.edges.append(edge) + + def _build_params_header(self, params): + if params is None: + return "" + if not isinstance(params, dict): + raise TypeError("'params' must be a dict") + # Header starts with "CYPHER" + params_header = "CYPHER " + for key, value in params.items(): + params_header += str(key) + "=" + stringify_param_value(value) + " " + return params_header + + # Procedures. + def call_procedure(self, procedure, *args, read_only=False, **kwagrs): + args = [quote_string(arg) for arg in args] + q = f"CALL {procedure}({','.join(args)})" + + y = kwagrs.get("y", None) + if y is not None: + q += f"YIELD {','.join(y)}" + + return self.query(q, read_only=read_only) + + def labels(self): + return self.call_procedure(DB_LABELS, read_only=True).result_set + + def relationship_types(self): + return self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True).result_set + + def property_keys(self): + return self.call_procedure(DB_PROPERTYKEYS, read_only=True).result_set + + +class AsyncGraph(Graph, AsyncGraphCommands): + """Async version for Graph""" + + async def _refresh_labels(self): + lbls = await self.labels() + + # Unpack data. + self._labels = [l[0] for _, l in enumerate(lbls)] + + async def _refresh_attributes(self): + props = await self.property_keys() + + # Unpack data. + self._properties = [p[0] for _, p in enumerate(props)] + + async def _refresh_relations(self): + rels = await self.relationship_types() + + # Unpack data. + self._relationship_types = [r[0] for _, r in enumerate(rels)] + + async def get_label(self, idx): + """ + Returns a label by it's index + + Args: + + idx: + The index of the label + """ + try: + label = self._labels[idx] + except IndexError: + # Refresh labels. + await self._refresh_labels() + label = self._labels[idx] + return label + + async def get_property(self, idx): + """ + Returns a property by it's index + + Args: + + idx: + The index of the property + """ + try: + p = self._properties[idx] + except IndexError: + # Refresh properties. + await self._refresh_attributes() + p = self._properties[idx] + return p + + async def get_relation(self, idx): + """ + Returns a relationship type by it's index + + Args: + + idx: + The index of the relation + """ + try: + relationship_type = self._relationship_types[idx] + except IndexError: + # Refresh relationship types. 
+ await self._refresh_relations() + relationship_type = self._relationship_types[idx] + return relationship_type + + async def call_procedure(self, procedure, *args, read_only=False, **kwagrs): + args = [quote_string(arg) for arg in args] + q = f"CALL {procedure}({','.join(args)})" + + y = kwagrs.get("y", None) + if y is not None: + f"YIELD {','.join(y)}" + return await self.query(q, read_only=read_only) + + async def labels(self): + return ((await self.call_procedure(DB_LABELS, read_only=True))).result_set + + async def property_keys(self): + return (await self.call_procedure(DB_PROPERTYKEYS, read_only=True)).result_set + + async def relationship_types(self): + return ( + await self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True) + ).result_set diff --git a/venv/lib/python3.12/site-packages/redis/commands/graph/commands.py b/venv/lib/python3.12/site-packages/redis/commands/graph/commands.py new file mode 100644 index 0000000..762ab42 --- /dev/null +++ b/venv/lib/python3.12/site-packages/redis/commands/graph/commands.py @@ -0,0 +1,313 @@ +from redis import DataError +from redis.exceptions import ResponseError + +from .exceptions import VersionMismatchException +from .execution_plan import ExecutionPlan +from .query_result import AsyncQueryResult, QueryResult + +PROFILE_CMD = "GRAPH.PROFILE" +RO_QUERY_CMD = "GRAPH.RO_QUERY" +QUERY_CMD = "GRAPH.QUERY" +DELETE_CMD = "GRAPH.DELETE" +SLOWLOG_CMD = "GRAPH.SLOWLOG" +CONFIG_CMD = "GRAPH.CONFIG" +LIST_CMD = "GRAPH.LIST" +EXPLAIN_CMD = "GRAPH.EXPLAIN" + + +class GraphCommands: + """RedisGraph Commands""" + + def commit(self): + """ + Create entire graph. + """ + if len(self.nodes) == 0 and len(self.edges) == 0: + return None + + query = "CREATE " + for _, node in self.nodes.items(): + query += str(node) + "," + + query += ",".join([str(edge) for edge in self.edges]) + + # Discard leading comma. + if query[-1] == ",": + query = query[:-1] + + return self.query(query) + + def query(self, q, params=None, timeout=None, read_only=False, profile=False): + """ + Executes a query against the graph. + For more information see `GRAPH.QUERY `_. # noqa + + Args: + + q : str + The query. + params : dict + Query parameters. + timeout : int + Maximum runtime for read queries in milliseconds. + read_only : bool + Executes a readonly query if set to True. + profile : bool + Return details on results produced by and time + spent in each operation. + """ + + # maintain original 'q' + query = q + + # handle query parameters + query = self._build_params_header(params) + query + + # construct query command + # ask for compact result-set format + # specify known graph version + if profile: + cmd = PROFILE_CMD + else: + cmd = RO_QUERY_CMD if read_only else QUERY_CMD + command = [cmd, self.name, query, "--compact"] + + # include timeout is specified + if isinstance(timeout, int): + command.extend(["timeout", timeout]) + elif timeout is not None: + raise Exception("Timeout argument must be a positive integer") + + # issue query + try: + response = self.execute_command(*command) + return QueryResult(self, response, profile) + except ResponseError as e: + if "unknown command" in str(e) and read_only: + # `GRAPH.RO_QUERY` is unavailable in older versions. 
+ return self.query(q, params, timeout, read_only=False) + raise e + except VersionMismatchException as e: + # client view over the graph schema is out of sync + # set client version and refresh local schema + self.version = e.version + self._refresh_schema() + # re-issue query + return self.query(q, params, timeout, read_only) + + def merge(self, pattern): + """ + Merge pattern. + """ + query = "MERGE " + query += str(pattern) + + return self.query(query) + + def delete(self): + """ + Deletes graph. + For more information see `DELETE `_. # noqa + """ + self._clear_schema() + return self.execute_command(DELETE_CMD, self.name) + + # declared here, to override the built in redis.db.flush() + def flush(self): + """ + Commit the graph and reset the edges and the nodes to zero length. + """ + self.commit() + self.nodes = {} + self.edges = [] + + def bulk(self, **kwargs): + """Internal only. Not supported.""" + raise NotImplementedError( + "GRAPH.BULK is internal only. " + "Use https://github.com/redisgraph/redisgraph-bulk-loader." + ) + + def profile(self, query): + """ + Execute a query and produce an execution plan augmented with metrics + for each operation's execution. Return a string representation of a + query execution plan, with details on results produced by and time + spent in each operation. + For more information see `GRAPH.PROFILE `_. # noqa + """ + return self.query(query, profile=True) + + def slowlog(self): + """ + Get a list containing up to 10 of the slowest queries issued + against the given graph ID. + For more information see `GRAPH.SLOWLOG `_. # noqa + + Each item in the list has the following structure: + 1. A unix timestamp at which the log entry was processed. + 2. The issued command. + 3. The issued query. + 4. The amount of time needed for its execution, in milliseconds. + """ + return self.execute_command(SLOWLOG_CMD, self.name) + + def config(self, name, value=None, set=False): + """ + Retrieve or update a RedisGraph configuration. + For more information see `https://redis.io/commands/graph.config-get/>`_. # noqa + + Args: + + name : str + The name of the configuration + value : + The value we want to set (can be used only when `set` is on) + set : bool + Turn on to set a configuration. Default behavior is get. + """ + params = ["SET" if set else "GET", name] + if value is not None: + if set: + params.append(value) + else: + raise DataError( + "``value`` can be provided only when ``set`` is True" + ) # noqa + return self.execute_command(CONFIG_CMD, *params) + + def list_keys(self): + """ + Lists all graph keys in the keyspace. + For more information see `GRAPH.LIST `_. # noqa + """ + return self.execute_command(LIST_CMD) + + def execution_plan(self, query, params=None): + """ + Get the execution plan for given query, + GRAPH.EXPLAIN returns an array of operations. + + Args: + query: the query that will be executed + params: query parameters + """ + query = self._build_params_header(params) + query + + plan = self.execute_command(EXPLAIN_CMD, self.name, query) + if isinstance(plan[0], bytes): + plan = [b.decode() for b in plan] + return "\n".join(plan) + + def explain(self, query, params=None): + """ + Get the execution plan for given query, + GRAPH.EXPLAIN returns ExecutionPlan object. + For more information see `GRAPH.EXPLAIN `_. 
# noqa + + Args: + query: the query that will be executed + params: query parameters + """ + query = self._build_params_header(params) + query + + plan = self.execute_command(EXPLAIN_CMD, self.name, query) + return ExecutionPlan(plan) + + +class AsyncGraphCommands(GraphCommands): + async def query(self, q, params=None, timeout=None, read_only=False, profile=False): + """ + Executes a query against the graph. + For more information see `GRAPH.QUERY `_. # noqa + + Args: + + q : str + The query. + params : dict + Query parameters. + timeout : int + Maximum runtime for read queries in milliseconds. + read_only : bool + Executes a readonly query if set to True. + profile : bool + Return details on results produced by and time + spent in each operation. + """ + + # maintain original 'q' + query = q + + # handle query parameters + query = self._build_params_header(params) + query + + # construct query command + # ask for compact result-set format + # specify known graph version + if profile: + cmd = PROFILE_CMD + else: + cmd = RO_QUERY_CMD if read_only else QUERY_CMD + command = [cmd, self.name, query, "--compact"] + + # include timeout is specified + if isinstance(timeout, int): + command.extend(["timeout", timeout]) + elif timeout is not None: + raise Exception("Timeout argument must be a positive integer") + + # issue query + try: + response = await self.execute_command(*command) + return await AsyncQueryResult().initialize(self, response, profile) + except ResponseError as e: + if "unknown command" in str(e) and read_only: + # `GRAPH.RO_QUERY` is unavailable in older versions. + return await self.query(q, params, timeout, read_only=False) + raise e + except VersionMismatchException as e: + # client view over the graph schema is out of sync + # set client version and refresh local schema + self.version = e.version + self._refresh_schema() + # re-issue query + return await self.query(q, params, timeout, read_only) + + async def execution_plan(self, query, params=None): + """ + Get the execution plan for given query, + GRAPH.EXPLAIN returns an array of operations. + + Args: + query: the query that will be executed + params: query parameters + """ + query = self._build_params_header(params) + query + + plan = await self.execute_command(EXPLAIN_CMD, self.name, query) + if isinstance(plan[0], bytes): + plan = [b.decode() for b in plan] + return "\n".join(plan) + + async def explain(self, query, params=None): + """ + Get the execution plan for given query, + GRAPH.EXPLAIN returns ExecutionPlan object. + + Args: + query: the query that will be executed + params: query parameters + """ + query = self._build_params_header(params) + query + + plan = await self.execute_command(EXPLAIN_CMD, self.name, query) + return ExecutionPlan(plan) + + async def flush(self): + """ + Commit the graph and reset the edges and the nodes to zero length. + """ + await self.commit() + self.nodes = {} + self.edges = [] diff --git a/venv/lib/python3.12/site-packages/redis/commands/graph/edge.py b/venv/lib/python3.12/site-packages/redis/commands/graph/edge.py new file mode 100644 index 0000000..6ee195f --- /dev/null +++ b/venv/lib/python3.12/site-packages/redis/commands/graph/edge.py @@ -0,0 +1,91 @@ +from ..helpers import quote_string +from .node import Node + + +class Edge: + """ + An edge connecting two nodes. + """ + + def __init__(self, src_node, relation, dest_node, edge_id=None, properties=None): + """ + Create a new edge. 
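Putting the Graph, Node and Edge pieces together with GraphCommands.query and explain: a hedged sketch assuming the RedisGraph module is available (the graph key "social" is illustrative):

import redis
from redis.commands.graph.node import Node
from redis.commands.graph.edge import Edge

r = redis.Redis()
g = r.graph("social")                       # Graph bound to the key "social"

alice = Node(label="Person", properties={"name": "Alice"})
bob = Node(label="Person", properties={"name": "Bob"})
g.add_node(alice)
g.add_node(bob)
g.add_edge(Edge(alice, "KNOWS", bob))
g.commit()                                  # issues a single CREATE query

result = g.query(
    "MATCH (a:Person)-[:KNOWS]->(b:Person) RETURN a.name, b.name",
    read_only=True,                         # sent as GRAPH.RO_QUERY
)
for a_name, b_name in result.result_set:
    print(a_name, "knows", b_name)

print(g.explain("MATCH (p:Person) RETURN p"))   # ExecutionPlan parsed from GRAPH.EXPLAIN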
+ """ + if src_node is None or dest_node is None: + # NOTE(bors-42): It makes sense to change AssertionError to + # ValueError here + raise AssertionError("Both src_node & dest_node must be provided") + + self.id = edge_id + self.relation = relation or "" + self.properties = properties or {} + self.src_node = src_node + self.dest_node = dest_node + + def to_string(self): + res = "" + if self.properties: + props = ",".join( + key + ":" + str(quote_string(val)) + for key, val in sorted(self.properties.items()) + ) + res += "{" + props + "}" + + return res + + def __str__(self): + # Source node. + if isinstance(self.src_node, Node): + res = str(self.src_node) + else: + res = "()" + + # Edge + res += "-[" + if self.relation: + res += ":" + self.relation + if self.properties: + props = ",".join( + key + ":" + str(quote_string(val)) + for key, val in sorted(self.properties.items()) + ) + res += "{" + props + "}" + res += "]->" + + # Dest node. + if isinstance(self.dest_node, Node): + res += str(self.dest_node) + else: + res += "()" + + return res + + def __eq__(self, rhs): + # Type checking + if not isinstance(rhs, Edge): + return False + + # Quick positive check, if both IDs are set. + if self.id is not None and rhs.id is not None and self.id == rhs.id: + return True + + # Source and destination nodes should match. + if self.src_node != rhs.src_node: + return False + + if self.dest_node != rhs.dest_node: + return False + + # Relation should match. + if self.relation != rhs.relation: + return False + + # Quick check for number of properties. + if len(self.properties) != len(rhs.properties): + return False + + # Compare properties. + if self.properties != rhs.properties: + return False + + return True diff --git a/venv/lib/python3.12/site-packages/redis/commands/graph/exceptions.py b/venv/lib/python3.12/site-packages/redis/commands/graph/exceptions.py new file mode 100644 index 0000000..4bbac10 --- /dev/null +++ b/venv/lib/python3.12/site-packages/redis/commands/graph/exceptions.py @@ -0,0 +1,3 @@ +class VersionMismatchException(Exception): + def __init__(self, version): + self.version = version diff --git a/venv/lib/python3.12/site-packages/redis/commands/graph/execution_plan.py b/venv/lib/python3.12/site-packages/redis/commands/graph/execution_plan.py new file mode 100644 index 0000000..179a80c --- /dev/null +++ b/venv/lib/python3.12/site-packages/redis/commands/graph/execution_plan.py @@ -0,0 +1,211 @@ +import re + + +class ProfileStats: + """ + ProfileStats, runtime execution statistics of operation. + """ + + def __init__(self, records_produced, execution_time): + self.records_produced = records_produced + self.execution_time = execution_time + + +class Operation: + """ + Operation, single operation within execution plan. + """ + + def __init__(self, name, args=None, profile_stats=None): + """ + Create a new operation. 
+ + Args: + name: string that represents the name of the operation + args: operation arguments + profile_stats: profile statistics + """ + self.name = name + self.args = args + self.profile_stats = profile_stats + self.children = [] + + def append_child(self, child): + if not isinstance(child, Operation) or self is child: + raise Exception("child must be Operation") + + self.children.append(child) + return self + + def child_count(self): + return len(self.children) + + def __eq__(self, o: object) -> bool: + if not isinstance(o, Operation): + return False + + return self.name == o.name and self.args == o.args + + def __str__(self) -> str: + args_str = "" if self.args is None else " | " + self.args + return f"{self.name}{args_str}" + + +class ExecutionPlan: + """ + ExecutionPlan, collection of operations. + """ + + def __init__(self, plan): + """ + Create a new execution plan. + + Args: + plan: array of strings that represents the collection operations + the output from GRAPH.EXPLAIN + """ + if not isinstance(plan, list): + raise Exception("plan must be an array") + + if isinstance(plan[0], bytes): + plan = [b.decode() for b in plan] + + self.plan = plan + self.structured_plan = self._operation_tree() + + def _compare_operations(self, root_a, root_b): + """ + Compare execution plan operation tree + + Return: True if operation trees are equal, False otherwise + """ + + # compare current root + if root_a != root_b: + return False + + # make sure root have the same number of children + if root_a.child_count() != root_b.child_count(): + return False + + # recursively compare children + for i in range(root_a.child_count()): + if not self._compare_operations(root_a.children[i], root_b.children[i]): + return False + + return True + + def __str__(self) -> str: + def aggraget_str(str_children): + return "\n".join( + [ + " " + line + for str_child in str_children + for line in str_child.splitlines() + ] + ) + + def combine_str(x, y): + return f"{x}\n{y}" + + return self._operation_traverse( + self.structured_plan, str, aggraget_str, combine_str + ) + + def __eq__(self, o: object) -> bool: + """Compares two execution plans + + Return: True if the two plans are equal False otherwise + """ + # make sure 'o' is an execution-plan + if not isinstance(o, ExecutionPlan): + return False + + # get root for both plans + root_a = self.structured_plan + root_b = o.structured_plan + + # compare execution trees + return self._compare_operations(root_a, root_b) + + def _operation_traverse(self, op, op_f, aggregate_f, combine_f): + """ + Traverse operation tree recursively applying functions + + Args: + op: operation to traverse + op_f: function applied for each operation + aggregate_f: aggregation function applied for all children of a single operation + combine_f: combine function applied for the operation result and the children result + """ # noqa + # apply op_f for each operation + op_res = op_f(op) + if len(op.children) == 0: + return op_res # no children return + else: + # apply _operation_traverse recursively + children = [ + self._operation_traverse(child, op_f, aggregate_f, combine_f) + for child in op.children + ] + # combine the operation result with the children aggregated result + return combine_f(op_res, aggregate_f(children)) + + def _operation_tree(self): + """Build the operation tree from the string representation""" + + # initial state + i = 0 + level = 0 + stack = [] + current = None + + def _create_operation(args): + profile_stats = None + name = args[0].strip() + args.pop(0) + if len(args) > 0 
and "Records produced" in args[-1]: + records_produced = int( + re.search("Records produced: (\\d+)", args[-1]).group(1) + ) + execution_time = float( + re.search("Execution time: (\\d+.\\d+) ms", args[-1]).group(1) + ) + profile_stats = ProfileStats(records_produced, execution_time) + args.pop(-1) + return Operation( + name, None if len(args) == 0 else args[0].strip(), profile_stats + ) + + # iterate plan operations + while i < len(self.plan): + current_op = self.plan[i] + op_level = current_op.count(" ") + if op_level == level: + # if the operation level equal to the current level + # set the current operation and move next + child = _create_operation(current_op.split("|")) + if current: + current = stack.pop() + current.append_child(child) + current = child + i += 1 + elif op_level == level + 1: + # if the operation is child of the current operation + # add it as child and set as current operation + child = _create_operation(current_op.split("|")) + current.append_child(child) + stack.append(current) + current = child + level += 1 + i += 1 + elif op_level < level: + # if the operation is not child of current operation + # go back to it's parent operation + levels_back = level - op_level + 1 + for _ in range(levels_back): + current = stack.pop() + level -= levels_back + else: + raise Exception("corrupted plan") + return stack[0] diff --git a/venv/lib/python3.12/site-packages/redis/commands/graph/node.py b/venv/lib/python3.12/site-packages/redis/commands/graph/node.py new file mode 100644 index 0000000..4546a39 --- /dev/null +++ b/venv/lib/python3.12/site-packages/redis/commands/graph/node.py @@ -0,0 +1,88 @@ +from ..helpers import quote_string + + +class Node: + """ + A node within the graph. + """ + + def __init__(self, node_id=None, alias=None, label=None, properties=None): + """ + Create a new node. + """ + self.id = node_id + self.alias = alias + if isinstance(label, list): + label = [inner_label for inner_label in label if inner_label != ""] + + if ( + label is None + or label == "" + or (isinstance(label, list) and len(label) == 0) + ): + self.label = None + self.labels = None + elif isinstance(label, str): + self.label = label + self.labels = [label] + elif isinstance(label, list) and all( + [isinstance(inner_label, str) for inner_label in label] + ): + self.label = label[0] + self.labels = label + else: + raise AssertionError( + "label should be either None, string or a list of strings" + ) + + self.properties = properties or {} + + def to_string(self): + res = "" + if self.properties: + props = ",".join( + key + ":" + str(quote_string(val)) + for key, val in sorted(self.properties.items()) + ) + res += "{" + props + "}" + + return res + + def __str__(self): + res = "(" + if self.alias: + res += self.alias + if self.labels: + res += ":" + ":".join(self.labels) + if self.properties: + props = ",".join( + key + ":" + str(quote_string(val)) + for key, val in sorted(self.properties.items()) + ) + res += "{" + props + "}" + res += ")" + + return res + + def __eq__(self, rhs): + # Type checking + if not isinstance(rhs, Node): + return False + + # Quick positive check, if both IDs are set. + if self.id is not None and rhs.id is not None and self.id != rhs.id: + return False + + # Label should match. + if self.label != rhs.label: + return False + + # Quick check for number of properties. + if len(self.properties) != len(rhs.properties): + return False + + # Compare properties. 
+ if self.properties != rhs.properties: + return False + + return True diff --git a/venv/lib/python3.12/site-packages/redis/commands/graph/path.py b/venv/lib/python3.12/site-packages/redis/commands/graph/path.py new file mode 100644 index 0000000..ee22dc8 --- /dev/null +++ b/venv/lib/python3.12/site-packages/redis/commands/graph/path.py @@ -0,0 +1,78 @@ +from .edge import Edge +from .node import Node + + +class Path: + def __init__(self, nodes, edges): + if not (isinstance(nodes, list) and isinstance(edges, list)): + raise TypeError("nodes and edges must be list") + + self._nodes = nodes + self._edges = edges + self.append_type = Node + + @classmethod + def new_empty_path(cls): + return cls([], []) + + def nodes(self): + return self._nodes + + def edges(self): + return self._edges + + def get_node(self, index): + return self._nodes[index] + + def get_relationship(self, index): + return self._edges[index] + + def first_node(self): + return self._nodes[0] + + def last_node(self): + return self._nodes[-1] + + def edge_count(self): + return len(self._edges) + + def nodes_count(self): + return len(self._nodes) + + def add_node(self, node): + if not isinstance(node, self.append_type): + raise AssertionError("Add Edge before adding Node") + self._nodes.append(node) + self.append_type = Edge + return self + + def add_edge(self, edge): + if not isinstance(edge, self.append_type): + raise AssertionError("Add Node before adding Edge") + self._edges.append(edge) + self.append_type = Node + return self + + def __eq__(self, other): + # Type checking + if not isinstance(other, Path): + return False + + return self.nodes() == other.nodes() and self.edges() == other.edges() + + def __str__(self): + res = "<" + edge_count = self.edge_count() + for i in range(0, edge_count): + node_id = self.get_node(i).id + res += "(" + str(node_id) + ")" + edge = self.get_relationship(i) + res += ( + "-[" + str(int(edge.id)) + "]->" + if edge.src_node == node_id + else "<-[" + str(int(edge.id)) + "]-" + ) + node_id = self.get_node(edge_count).id + res += "(" + str(node_id) + ")" + res += ">" + return res diff --git a/venv/lib/python3.12/site-packages/redis/commands/graph/query_result.py b/venv/lib/python3.12/site-packages/redis/commands/graph/query_result.py new file mode 100644 index 0000000..7c7f58b --- /dev/null +++ b/venv/lib/python3.12/site-packages/redis/commands/graph/query_result.py @@ -0,0 +1,573 @@ +import sys +from collections import OrderedDict +from distutils.util import strtobool + +# from prettytable import PrettyTable +from redis import ResponseError + +from .edge import Edge +from .exceptions import VersionMismatchException +from .node import Node +from .path import Path + +LABELS_ADDED = "Labels added" +LABELS_REMOVED = "Labels removed" +NODES_CREATED = "Nodes created" +NODES_DELETED = "Nodes deleted" +RELATIONSHIPS_DELETED = "Relationships deleted" +PROPERTIES_SET = "Properties set" +PROPERTIES_REMOVED = "Properties removed" +RELATIONSHIPS_CREATED = "Relationships created" +INDICES_CREATED = "Indices created" +INDICES_DELETED = "Indices deleted" +CACHED_EXECUTION = "Cached execution" +INTERNAL_EXECUTION_TIME = "internal execution time" + +STATS = [ + LABELS_ADDED, + LABELS_REMOVED, + NODES_CREATED, + PROPERTIES_SET, + PROPERTIES_REMOVED, + RELATIONSHIPS_CREATED, + NODES_DELETED, + RELATIONSHIPS_DELETED, + INDICES_CREATED, + INDICES_DELETED, + CACHED_EXECUTION, + INTERNAL_EXECUTION_TIME, +] + + +class ResultSetColumnTypes: + COLUMN_UNKNOWN = 0 + COLUMN_SCALAR = 1 + COLUMN_NODE = 2 # Unused as of 
RedisGraph v2.1.0, retained for backwards compatibility. # noqa + COLUMN_RELATION = 3 # Unused as of RedisGraph v2.1.0, retained for backwards compatibility. # noqa + + +class ResultSetScalarTypes: + VALUE_UNKNOWN = 0 + VALUE_NULL = 1 + VALUE_STRING = 2 + VALUE_INTEGER = 3 + VALUE_BOOLEAN = 4 + VALUE_DOUBLE = 5 + VALUE_ARRAY = 6 + VALUE_EDGE = 7 + VALUE_NODE = 8 + VALUE_PATH = 9 + VALUE_MAP = 10 + VALUE_POINT = 11 + + +class QueryResult: + def __init__(self, graph, response, profile=False): + """ + A class that represents a result of the query operation. + + Args: + + graph: + The graph on which the query was executed. + response: + The response from the server. + profile: + A boolean indicating if the query command was "GRAPH.PROFILE" + """ + self.graph = graph + self.header = [] + self.result_set = [] + + # in case of an error an exception will be raised + self._check_for_errors(response) + + if len(response) == 1: + self.parse_statistics(response[0]) + elif profile: + self.parse_profile(response) + else: + # start by parsing statistics, matches the one we have + self.parse_statistics(response[-1]) # Last element. + self.parse_results(response) + + def _check_for_errors(self, response): + """ + Check if the response contains an error. + """ + if isinstance(response[0], ResponseError): + error = response[0] + if str(error) == "version mismatch": + version = response[1] + error = VersionMismatchException(version) + raise error + + # If we encountered a run-time error, the last response + # element will be an exception + if isinstance(response[-1], ResponseError): + raise response[-1] + + def parse_results(self, raw_result_set): + """ + Parse the query execution result returned from the server. + """ + self.header = self.parse_header(raw_result_set) + + # Empty header. + if len(self.header) == 0: + return + + self.result_set = self.parse_records(raw_result_set) + + def parse_statistics(self, raw_statistics): + """ + Parse the statistics returned in the response. + """ + self.statistics = {} + + # decode statistics + for idx, stat in enumerate(raw_statistics): + if isinstance(stat, bytes): + raw_statistics[idx] = stat.decode() + + for s in STATS: + v = self._get_value(s, raw_statistics) + if v is not None: + self.statistics[s] = v + + def parse_header(self, raw_result_set): + """ + Parse the header of the result. + """ + # An array of column name/column type pairs. + header = raw_result_set[0] + return header + + def parse_records(self, raw_result_set): + """ + Parses the result set and returns a list of records. + """ + records = [ + [ + self.parse_record_types[self.header[idx][0]](cell) + for idx, cell in enumerate(row) + ] + for row in raw_result_set[1] + ] + + return records + + def parse_entity_properties(self, props): + """ + Parse node / edge properties. + """ + # [[name, value type, value] X N] + properties = {} + for prop in props: + prop_name = self.graph.get_property(prop[0]) + prop_value = self.parse_scalar(prop[1:]) + properties[prop_name] = prop_value + + return properties + + def parse_string(self, cell): + """ + Parse the cell as a string. + """ + if isinstance(cell, bytes): + return cell.decode() + elif not isinstance(cell, str): + return str(cell) + else: + return cell + + def parse_node(self, cell): + """ + Parse the cell to a node. 
+ """ + # Node ID (integer), + # [label string offset (integer)], + # [[name, value type, value] X N] + + node_id = int(cell[0]) + labels = None + if len(cell[1]) > 0: + labels = [] + for inner_label in cell[1]: + labels.append(self.graph.get_label(inner_label)) + properties = self.parse_entity_properties(cell[2]) + return Node(node_id=node_id, label=labels, properties=properties) + + def parse_edge(self, cell): + """ + Parse the cell to an edge. + """ + # Edge ID (integer), + # reltype string offset (integer), + # src node ID offset (integer), + # dest node ID offset (integer), + # [[name, value, value type] X N] + + edge_id = int(cell[0]) + relation = self.graph.get_relation(cell[1]) + src_node_id = int(cell[2]) + dest_node_id = int(cell[3]) + properties = self.parse_entity_properties(cell[4]) + return Edge( + src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties + ) + + def parse_path(self, cell): + """ + Parse the cell to a path. + """ + nodes = self.parse_scalar(cell[0]) + edges = self.parse_scalar(cell[1]) + return Path(nodes, edges) + + def parse_map(self, cell): + """ + Parse the cell as a map. + """ + m = OrderedDict() + n_entries = len(cell) + + # A map is an array of key value pairs. + # 1. key (string) + # 2. array: (value type, value) + for i in range(0, n_entries, 2): + key = self.parse_string(cell[i]) + m[key] = self.parse_scalar(cell[i + 1]) + + return m + + def parse_point(self, cell): + """ + Parse the cell to point. + """ + p = {} + # A point is received an array of the form: [latitude, longitude] + # It is returned as a map of the form: {"latitude": latitude, "longitude": longitude} # noqa + p["latitude"] = float(cell[0]) + p["longitude"] = float(cell[1]) + return p + + def parse_null(self, cell): + """ + Parse a null value. + """ + return None + + def parse_integer(self, cell): + """ + Parse the integer value from the cell. + """ + return int(cell) + + def parse_boolean(self, value): + """ + Parse the cell value as a boolean. + """ + value = value.decode() if isinstance(value, bytes) else value + try: + scalar = True if strtobool(value) else False + except ValueError: + sys.stderr.write("unknown boolean type\n") + scalar = None + return scalar + + def parse_double(self, cell): + """ + Parse the cell as a double. + """ + return float(cell) + + def parse_array(self, value): + """ + Parse an array of values. + """ + scalar = [self.parse_scalar(value[i]) for i in range(len(value))] + return scalar + + def parse_unknown(self, cell): + """ + Parse a cell of unknown type. + """ + sys.stderr.write("Unknown type\n") + return None + + def parse_scalar(self, cell): + """ + Parse a scalar value from a cell in the result set. 
+ """ + scalar_type = int(cell[0]) + value = cell[1] + scalar = self.parse_scalar_types[scalar_type](value) + + return scalar + + def parse_profile(self, response): + self.result_set = [x[0 : x.index(",")].strip() for x in response] + + def is_empty(self): + return len(self.result_set) == 0 + + @staticmethod + def _get_value(prop, statistics): + for stat in statistics: + if prop in stat: + return float(stat.split(": ")[1].split(" ")[0]) + + return None + + def _get_stat(self, stat): + return self.statistics[stat] if stat in self.statistics else 0 + + @property + def labels_added(self): + """Returns the number of labels added in the query""" + return self._get_stat(LABELS_ADDED) + + @property + def labels_removed(self): + """Returns the number of labels removed in the query""" + return self._get_stat(LABELS_REMOVED) + + @property + def nodes_created(self): + """Returns the number of nodes created in the query""" + return self._get_stat(NODES_CREATED) + + @property + def nodes_deleted(self): + """Returns the number of nodes deleted in the query""" + return self._get_stat(NODES_DELETED) + + @property + def properties_set(self): + """Returns the number of properties set in the query""" + return self._get_stat(PROPERTIES_SET) + + @property + def properties_removed(self): + """Returns the number of properties removed in the query""" + return self._get_stat(PROPERTIES_REMOVED) + + @property + def relationships_created(self): + """Returns the number of relationships created in the query""" + return self._get_stat(RELATIONSHIPS_CREATED) + + @property + def relationships_deleted(self): + """Returns the number of relationships deleted in the query""" + return self._get_stat(RELATIONSHIPS_DELETED) + + @property + def indices_created(self): + """Returns the number of indices created in the query""" + return self._get_stat(INDICES_CREATED) + + @property + def indices_deleted(self): + """Returns the number of indices deleted in the query""" + return self._get_stat(INDICES_DELETED) + + @property + def cached_execution(self): + """Returns whether or not the query execution plan was cached""" + return self._get_stat(CACHED_EXECUTION) == 1 + + @property + def run_time_ms(self): + """Returns the server execution time of the query""" + return self._get_stat(INTERNAL_EXECUTION_TIME) + + @property + def parse_scalar_types(self): + return { + ResultSetScalarTypes.VALUE_NULL: self.parse_null, + ResultSetScalarTypes.VALUE_STRING: self.parse_string, + ResultSetScalarTypes.VALUE_INTEGER: self.parse_integer, + ResultSetScalarTypes.VALUE_BOOLEAN: self.parse_boolean, + ResultSetScalarTypes.VALUE_DOUBLE: self.parse_double, + ResultSetScalarTypes.VALUE_ARRAY: self.parse_array, + ResultSetScalarTypes.VALUE_NODE: self.parse_node, + ResultSetScalarTypes.VALUE_EDGE: self.parse_edge, + ResultSetScalarTypes.VALUE_PATH: self.parse_path, + ResultSetScalarTypes.VALUE_MAP: self.parse_map, + ResultSetScalarTypes.VALUE_POINT: self.parse_point, + ResultSetScalarTypes.VALUE_UNKNOWN: self.parse_unknown, + } + + @property + def parse_record_types(self): + return { + ResultSetColumnTypes.COLUMN_SCALAR: self.parse_scalar, + ResultSetColumnTypes.COLUMN_NODE: self.parse_node, + ResultSetColumnTypes.COLUMN_RELATION: self.parse_edge, + ResultSetColumnTypes.COLUMN_UNKNOWN: self.parse_unknown, + } + + +class AsyncQueryResult(QueryResult): + """ + Async version for the QueryResult class - a class that + represents a result of the query operation. 
+ """ + + def __init__(self): + """ + To init the class you must call self.initialize() + """ + pass + + async def initialize(self, graph, response, profile=False): + """ + Initializes the class. + Args: + + graph: + The graph on which the query was executed. + response: + The response from the server. + profile: + A boolean indicating if the query command was "GRAPH.PROFILE" + """ + self.graph = graph + self.header = [] + self.result_set = [] + + # in case of an error an exception will be raised + self._check_for_errors(response) + + if len(response) == 1: + self.parse_statistics(response[0]) + elif profile: + self.parse_profile(response) + else: + # start by parsing statistics, matches the one we have + self.parse_statistics(response[-1]) # Last element. + await self.parse_results(response) + + return self + + async def parse_node(self, cell): + """ + Parses a node from the cell. + """ + # Node ID (integer), + # [label string offset (integer)], + # [[name, value type, value] X N] + + labels = None + if len(cell[1]) > 0: + labels = [] + for inner_label in cell[1]: + labels.append(await self.graph.get_label(inner_label)) + properties = await self.parse_entity_properties(cell[2]) + node_id = int(cell[0]) + return Node(node_id=node_id, label=labels, properties=properties) + + async def parse_scalar(self, cell): + """ + Parses a scalar value from the server response. + """ + scalar_type = int(cell[0]) + value = cell[1] + try: + scalar = await self.parse_scalar_types[scalar_type](value) + except TypeError: + # Not all of the functions are async + scalar = self.parse_scalar_types[scalar_type](value) + + return scalar + + async def parse_records(self, raw_result_set): + """ + Parses the result set and returns a list of records. + """ + records = [] + for row in raw_result_set[1]: + record = [ + await self.parse_record_types[self.header[idx][0]](cell) + for idx, cell in enumerate(row) + ] + records.append(record) + + return records + + async def parse_results(self, raw_result_set): + """ + Parse the query execution result returned from the server. + """ + self.header = self.parse_header(raw_result_set) + + # Empty header. + if len(self.header) == 0: + return + + self.result_set = await self.parse_records(raw_result_set) + + async def parse_entity_properties(self, props): + """ + Parse node / edge properties. + """ + # [[name, value type, value] X N] + properties = {} + for prop in props: + prop_name = await self.graph.get_property(prop[0]) + prop_value = await self.parse_scalar(prop[1:]) + properties[prop_name] = prop_value + + return properties + + async def parse_edge(self, cell): + """ + Parse the cell to an edge. + """ + # Edge ID (integer), + # reltype string offset (integer), + # src node ID offset (integer), + # dest node ID offset (integer), + # [[name, value, value type] X N] + + edge_id = int(cell[0]) + relation = await self.graph.get_relation(cell[1]) + src_node_id = int(cell[2]) + dest_node_id = int(cell[3]) + properties = await self.parse_entity_properties(cell[4]) + return Edge( + src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties + ) + + async def parse_path(self, cell): + """ + Parse the cell to a path. + """ + nodes = await self.parse_scalar(cell[0]) + edges = await self.parse_scalar(cell[1]) + return Path(nodes, edges) + + async def parse_map(self, cell): + """ + Parse the cell to a map. + """ + m = OrderedDict() + n_entries = len(cell) + + # A map is an array of key value pairs. + # 1. key (string) + # 2. 
array: (value type, value) + for i in range(0, n_entries, 2): + key = self.parse_string(cell[i]) + m[key] = await self.parse_scalar(cell[i + 1]) + + return m + + async def parse_array(self, value): + """ + Parse array value. + """ + scalar = [await self.parse_scalar(value[i]) for i in range(len(value))] + return scalar diff --git a/venv/lib/python3.12/site-packages/redis/commands/helpers.py b/venv/lib/python3.12/site-packages/redis/commands/helpers.py index 859a43a..324d981 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/helpers.py +++ b/venv/lib/python3.12/site-packages/redis/commands/helpers.py @@ -43,32 +43,19 @@ def parse_to_list(response): """Optimistically parse the response to a list.""" res = [] - special_values = {"infinity", "nan", "-infinity"} - if response is None: return res for item in response: - if item is None: - res.append(None) - continue try: - item_str = nativestr(item) + res.append(int(item)) + except ValueError: + try: + res.append(float(item)) + except ValueError: + res.append(nativestr(item)) except TypeError: res.append(None) - continue - - if isinstance(item_str, str) and item_str.lower() in special_values: - res.append(item_str) # Keep as string - else: - try: - res.append(int(item)) - except ValueError: - try: - res.append(float(item)) - except ValueError: - res.append(item_str) - return res @@ -77,11 +64,6 @@ def parse_list_to_dict(response): for i in range(0, len(response), 2): if isinstance(response[i], list): res["Child iterators"].append(parse_list_to_dict(response[i])) - try: - if isinstance(response[i + 1], list): - res["Child iterators"].append(parse_list_to_dict(response[i + 1])) - except IndexError: - pass elif isinstance(response[i + 1], list): res["Child iterators"] = [parse_list_to_dict(response[i + 1])] else: @@ -92,6 +74,25 @@ def parse_list_to_dict(response): return res +def parse_to_dict(response): + if response is None: + return {} + + res = {} + for det in response: + if isinstance(det[1], list): + res[det[0]] = parse_list_to_dict(det[1]) + else: + try: # try to set the attribute. may be provided without value + try: # try to convert the value to float + res[det[0]] = float(det[1]) + except (TypeError, ValueError): + res[det[0]] = det[1] + except IndexError: + pass + return res + + def random_string(length=10): """ Returns a random N character long string. @@ -101,6 +102,26 @@ def random_string(length=10): ) +def quote_string(v): + """ + RedisGraph strings must be quoted, + quote_string wraps given v with quotes incase + v is a string. + """ + + if isinstance(v, bytes): + v = v.decode() + elif not isinstance(v, str): + return v + if len(v) == 0: + return '""' + + v = v.replace("\\", "\\\\") + v = v.replace('"', '\\"') + + return f'"{v}"' + + def decode_dict_keys(obj): """Decode the keys of the given dictionary with utf-8.""" newobj = copy.copy(obj) @@ -111,6 +132,33 @@ def decode_dict_keys(obj): return newobj +def stringify_param_value(value): + """ + Turn a parameter value into a string suitable for the params header of + a Cypher command. + You may pass any value that would be accepted by `json.dumps()`. + + Ways in which output differs from that of `str()`: + * Strings are quoted. + * None --> "null". + * In dictionaries, keys are _not_ quoted. + + :param value: The parameter value to be turned into a string. 
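# An illustrative sketch of how the two helpers above behave; it only imports
# this module and needs no server. The example values are made up.
from redis.commands.helpers import quote_string, stringify_param_value

quote_string("Alice")                                     # -> '"Alice"'
stringify_param_value(None)                               # -> 'null'
stringify_param_value(["a", 1, None])                     # -> '["a",1,null]'
stringify_param_value({"name": "Alice", "ids": [1, 2]})   # -> '{name:"Alice",ids:[1,2]}'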
+ :return: string + """ + + if isinstance(value, str): + return quote_string(value) + elif value is None: + return "null" + elif isinstance(value, (list, tuple)): + return f'[{",".join(map(stringify_param_value, value))}]' + elif isinstance(value, dict): + return f'{{{",".join(f"{k}:{stringify_param_value(v)}" for k, v in value.items())}}}' # noqa + else: + return str(value) + + def get_protocol_version(client): if isinstance(client, redis.Redis) or isinstance(client, redis.asyncio.Redis): return client.connection_pool.connection_kwargs.get("protocol") diff --git a/venv/lib/python3.12/site-packages/redis/commands/json/__init__.py b/venv/lib/python3.12/site-packages/redis/commands/json/__init__.py index 0e717b3..01077e6 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/json/__init__.py +++ b/venv/lib/python3.12/site-packages/redis/commands/json/__init__.py @@ -120,7 +120,7 @@ class JSON(JSONCommands): startup_nodes=self.client.nodes_manager.startup_nodes, result_callbacks=self.client.result_callbacks, cluster_response_callbacks=self.client.cluster_response_callbacks, - cluster_error_retry_attempts=self.client.retry.get_retries(), + cluster_error_retry_attempts=self.client.cluster_error_retry_attempts, read_from_replicas=self.client.read_from_replicas, reinitialize_steps=self.client.reinitialize_steps, lock=self.client._lock, diff --git a/venv/lib/python3.12/site-packages/redis/commands/json/_util.py b/venv/lib/python3.12/site-packages/redis/commands/json/_util.py index 5ef2edc..3400bcd 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/json/_util.py +++ b/venv/lib/python3.12/site-packages/redis/commands/json/_util.py @@ -1,5 +1,3 @@ -from typing import List, Mapping, Union +from typing import Any, Dict, List, Union -JsonType = Union[ - str, int, float, bool, None, Mapping[str, "JsonType"], List["JsonType"] -] +JsonType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]] diff --git a/venv/lib/python3.12/site-packages/redis/commands/json/commands.py b/venv/lib/python3.12/site-packages/redis/commands/json/commands.py index 48849e1..0f92e0d 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/json/commands.py +++ b/venv/lib/python3.12/site-packages/redis/commands/json/commands.py @@ -15,7 +15,7 @@ class JSONCommands: def arrappend( self, name: str, path: Optional[str] = Path.root_path(), *args: List[JsonType] - ) -> List[Optional[int]]: + ) -> List[Union[int, None]]: """Append the objects ``args`` to the array under the ``path` in key ``name``. @@ -33,7 +33,7 @@ class JSONCommands: scalar: int, start: Optional[int] = None, stop: Optional[int] = None, - ) -> List[Optional[int]]: + ) -> List[Union[int, None]]: """ Return the index of ``scalar`` in the JSON array under ``path`` at key ``name``. @@ -49,11 +49,11 @@ class JSONCommands: if stop is not None: pieces.append(stop) - return self.execute_command("JSON.ARRINDEX", *pieces, keys=[name]) + return self.execute_command("JSON.ARRINDEX", *pieces) def arrinsert( self, name: str, path: str, index: int, *args: List[JsonType] - ) -> List[Optional[int]]: + ) -> List[Union[int, None]]: """Insert the objects ``args`` to the array at index ``index`` under the ``path` in key ``name``. @@ -66,20 +66,20 @@ class JSONCommands: def arrlen( self, name: str, path: Optional[str] = Path.root_path() - ) -> List[Optional[int]]: + ) -> List[Union[int, None]]: """Return the length of the array JSON value under ``path`` at key``name``. For more information see `JSON.ARRLEN `_. 
""" # noqa - return self.execute_command("JSON.ARRLEN", name, str(path), keys=[name]) + return self.execute_command("JSON.ARRLEN", name, str(path)) def arrpop( self, name: str, path: Optional[str] = Path.root_path(), index: Optional[int] = -1, - ) -> List[Optional[str]]: + ) -> List[Union[str, None]]: """Pop the element at ``index`` in the array JSON value under ``path`` at key ``name``. @@ -89,7 +89,7 @@ class JSONCommands: def arrtrim( self, name: str, path: str, start: int, stop: int - ) -> List[Optional[int]]: + ) -> List[Union[int, None]]: """Trim the array JSON value under ``path`` at key ``name`` to the inclusive range given by ``start`` and ``stop``. @@ -102,34 +102,32 @@ class JSONCommands: For more information see `JSON.TYPE `_. """ # noqa - return self.execute_command("JSON.TYPE", name, str(path), keys=[name]) + return self.execute_command("JSON.TYPE", name, str(path)) def resp(self, name: str, path: Optional[str] = Path.root_path()) -> List: """Return the JSON value under ``path`` at key ``name``. For more information see `JSON.RESP `_. """ # noqa - return self.execute_command("JSON.RESP", name, str(path), keys=[name]) + return self.execute_command("JSON.RESP", name, str(path)) def objkeys( self, name: str, path: Optional[str] = Path.root_path() - ) -> List[Optional[List[str]]]: + ) -> List[Union[List[str], None]]: """Return the key names in the dictionary JSON value under ``path`` at key ``name``. For more information see `JSON.OBJKEYS `_. """ # noqa - return self.execute_command("JSON.OBJKEYS", name, str(path), keys=[name]) + return self.execute_command("JSON.OBJKEYS", name, str(path)) - def objlen( - self, name: str, path: Optional[str] = Path.root_path() - ) -> List[Optional[int]]: + def objlen(self, name: str, path: Optional[str] = Path.root_path()) -> int: """Return the length of the dictionary JSON value under ``path`` at key ``name``. For more information see `JSON.OBJLEN `_. """ # noqa - return self.execute_command("JSON.OBJLEN", name, str(path), keys=[name]) + return self.execute_command("JSON.OBJLEN", name, str(path)) def numincrby(self, name: str, path: str, number: int) -> str: """Increment the numeric (integer or floating point) JSON value under @@ -175,7 +173,7 @@ class JSONCommands: def get( self, name: str, *args, no_escape: Optional[bool] = False - ) -> Optional[List[JsonType]]: + ) -> List[JsonType]: """ Get the object stored as a JSON value at key ``name``. @@ -199,7 +197,7 @@ class JSONCommands: # Handle case where key doesn't exist. The JSONDecoder would raise a # TypeError exception since it can't decode None try: - return self.execute_command("JSON.GET", *pieces, keys=[name]) + return self.execute_command("JSON.GET", *pieces) except TypeError: return None @@ -213,7 +211,7 @@ class JSONCommands: pieces = [] pieces += keys pieces.append(str(path)) - return self.execute_command("JSON.MGET", *pieces, keys=keys) + return self.execute_command("JSON.MGET", *pieces) def set( self, @@ -314,7 +312,7 @@ class JSONCommands: """ - with open(file_name) as fp: + with open(file_name, "r") as fp: file_content = loads(fp.read()) return self.set(name, path, file_content, nx=nx, xx=xx, decode_keys=decode_keys) @@ -326,7 +324,7 @@ class JSONCommands: nx: Optional[bool] = False, xx: Optional[bool] = False, decode_keys: Optional[bool] = False, - ) -> Dict[str, bool]: + ) -> List[Dict[str, bool]]: """ Iterate over ``root_folder`` and set each JSON file to a value under ``json_path`` with the file name as the key. 
@@ -357,7 +355,7 @@ class JSONCommands: return set_files_result - def strlen(self, name: str, path: Optional[str] = None) -> List[Optional[int]]: + def strlen(self, name: str, path: Optional[str] = None) -> List[Union[int, None]]: """Return the length of the string JSON value under ``path`` at key ``name``. @@ -366,7 +364,7 @@ class JSONCommands: pieces = [name] if path is not None: pieces.append(str(path)) - return self.execute_command("JSON.STRLEN", *pieces, keys=[name]) + return self.execute_command("JSON.STRLEN", *pieces) def toggle( self, name: str, path: Optional[str] = Path.root_path() @@ -379,7 +377,7 @@ class JSONCommands: return self.execute_command("JSON.TOGGLE", name, str(path)) def strappend( - self, name: str, value: str, path: Optional[str] = Path.root_path() + self, name: str, value: str, path: Optional[int] = Path.root_path() ) -> Union[int, List[Optional[int]]]: """Append to the string JSON value. If two options are specified after the key name, the path is determined to be the first. If a single diff --git a/venv/lib/python3.12/site-packages/redis/commands/redismodules.py b/venv/lib/python3.12/site-packages/redis/commands/redismodules.py index 078844f..7e2045a 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/redismodules.py +++ b/venv/lib/python3.12/site-packages/redis/commands/redismodules.py @@ -1,14 +1,4 @@ -from __future__ import annotations - from json import JSONDecoder, JSONEncoder -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from .bf import BFBloom, CFBloom, CMSBloom, TDigestBloom, TOPKBloom - from .json import JSON - from .search import AsyncSearch, Search - from .timeseries import TimeSeries - from .vectorset import VectorSet class RedisModuleCommands: @@ -16,7 +6,7 @@ class RedisModuleCommands: modules into the command namespace. """ - def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()) -> JSON: + def json(self, encoder=JSONEncoder(), decoder=JSONDecoder()): """Access the json namespace, providing support for redis json.""" from .json import JSON @@ -24,7 +14,7 @@ class RedisModuleCommands: jj = JSON(client=self, encoder=encoder, decoder=decoder) return jj - def ft(self, index_name="idx") -> Search: + def ft(self, index_name="idx"): """Access the search namespace, providing support for redis search.""" from .search import Search @@ -32,7 +22,7 @@ class RedisModuleCommands: s = Search(client=self, index_name=index_name) return s - def ts(self) -> TimeSeries: + def ts(self): """Access the timeseries namespace, providing support for redis timeseries data. 
""" @@ -42,7 +32,7 @@ class RedisModuleCommands: s = TimeSeries(client=self) return s - def bf(self) -> BFBloom: + def bf(self): """Access the bloom namespace.""" from .bf import BFBloom @@ -50,7 +40,7 @@ class RedisModuleCommands: bf = BFBloom(client=self) return bf - def cf(self) -> CFBloom: + def cf(self): """Access the bloom namespace.""" from .bf import CFBloom @@ -58,7 +48,7 @@ class RedisModuleCommands: cf = CFBloom(client=self) return cf - def cms(self) -> CMSBloom: + def cms(self): """Access the bloom namespace.""" from .bf import CMSBloom @@ -66,7 +56,7 @@ class RedisModuleCommands: cms = CMSBloom(client=self) return cms - def topk(self) -> TOPKBloom: + def topk(self): """Access the bloom namespace.""" from .bf import TOPKBloom @@ -74,7 +64,7 @@ class RedisModuleCommands: topk = TOPKBloom(client=self) return topk - def tdigest(self) -> TDigestBloom: + def tdigest(self): """Access the bloom namespace.""" from .bf import TDigestBloom @@ -82,20 +72,32 @@ class RedisModuleCommands: tdigest = TDigestBloom(client=self) return tdigest - def vset(self) -> VectorSet: - """Access the VectorSet commands namespace.""" + def graph(self, index_name="idx"): + """Access the graph namespace, providing support for + redis graph data. + """ - from .vectorset import VectorSet + from .graph import Graph - vset = VectorSet(client=self) - return vset + g = Graph(client=self, name=index_name) + return g class AsyncRedisModuleCommands(RedisModuleCommands): - def ft(self, index_name="idx") -> AsyncSearch: + def ft(self, index_name="idx"): """Access the search namespace, providing support for redis search.""" from .search import AsyncSearch s = AsyncSearch(client=self, index_name=index_name) return s + + def graph(self, index_name="idx"): + """Access the graph namespace, providing support for + redis graph data. + """ + + from .graph import AsyncGraph + + g = AsyncGraph(client=self, name=index_name) + return g diff --git a/venv/lib/python3.12/site-packages/redis/commands/search/_util.py b/venv/lib/python3.12/site-packages/redis/commands/search/_util.py index 191600d..dd1dff3 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/search/_util.py +++ b/venv/lib/python3.12/site-packages/redis/commands/search/_util.py @@ -1,7 +1,7 @@ -def to_string(s, encoding: str = "utf-8"): +def to_string(s): if isinstance(s, str): return s elif isinstance(s, bytes): - return s.decode(encoding, "ignore") + return s.decode("utf-8", "ignore") else: return s # Not a string we care about diff --git a/venv/lib/python3.12/site-packages/redis/commands/search/aggregation.py b/venv/lib/python3.12/site-packages/redis/commands/search/aggregation.py index 00435f6..50d18f4 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/search/aggregation.py +++ b/venv/lib/python3.12/site-packages/redis/commands/search/aggregation.py @@ -1,7 +1,5 @@ from typing import List, Union -from redis.commands.search.dialect import DEFAULT_DIALECT - FIELDNAME = object() @@ -26,7 +24,7 @@ class Reducer: NAME = None - def __init__(self, *args: str) -> None: + def __init__(self, *args: List[str]) -> None: self._args = args self._field = None self._alias = None @@ -112,11 +110,9 @@ class AggregateRequest: self._with_schema = False self._verbatim = False self._cursor = [] - self._dialect = DEFAULT_DIALECT - self._add_scores = False - self._scorer = "TFIDF" + self._dialect = None - def load(self, *fields: str) -> "AggregateRequest": + def load(self, *fields: List[str]) -> "AggregateRequest": """ Indicate the fields to be returned in the response. 
These fields are returned in addition to any others implicitly specified. @@ -223,7 +219,7 @@ class AggregateRequest: self._aggregateplan.extend(_limit.build_args()) return self - def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest": + def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest": """ Indicate how the results should be sorted. This can also be used for *top-N* style queries @@ -296,24 +292,6 @@ class AggregateRequest: self._with_schema = True return self - def add_scores(self) -> "AggregateRequest": - """ - If set, includes the score as an ordinary field of the row. - """ - self._add_scores = True - return self - - def scorer(self, scorer: str) -> "AggregateRequest": - """ - Use a different scoring function to evaluate document relevance. - Default is `TFIDF`. - - :param scorer: The scoring function to use - (e.g. `TFIDF.DOCNORM` or `BM25`) - """ - self._scorer = scorer - return self - def verbatim(self) -> "AggregateRequest": self._verbatim = True return self @@ -337,19 +315,12 @@ class AggregateRequest: if self._verbatim: ret.append("VERBATIM") - if self._scorer: - ret.extend(["SCORER", self._scorer]) - - if self._add_scores: - ret.append("ADDSCORES") - if self._cursor: ret += self._cursor if self._loadall: ret.append("LOAD") ret.append("*") - elif self._loadfields: ret.append("LOAD") ret.append(str(len(self._loadfields))) diff --git a/venv/lib/python3.12/site-packages/redis/commands/search/commands.py b/venv/lib/python3.12/site-packages/redis/commands/search/commands.py index 80d9b35..2df2b5a 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/search/commands.py +++ b/venv/lib/python3.12/site-packages/redis/commands/search/commands.py @@ -2,16 +2,13 @@ import itertools import time from typing import Dict, List, Optional, Union -from redis.client import NEVER_DECODE, Pipeline +from redis.client import Pipeline from redis.utils import deprecated_function -from ..helpers import get_protocol_version +from ..helpers import get_protocol_version, parse_to_dict from ._util import to_string from .aggregation import AggregateRequest, AggregateResult, Cursor from .document import Document -from .field import Field -from .index_definition import IndexDefinition -from .profile_information import ProfileInformation from .query import Query from .result import Result from .suggestion import SuggestionParser @@ -23,6 +20,7 @@ ALTER_CMD = "FT.ALTER" SEARCH_CMD = "FT.SEARCH" ADD_CMD = "FT.ADD" ADDHASH_CMD = "FT.ADDHASH" +DROP_CMD = "FT.DROP" DROPINDEX_CMD = "FT.DROPINDEX" EXPLAIN_CMD = "FT.EXPLAIN" EXPLAINCLI_CMD = "FT.EXPLAINCLI" @@ -34,6 +32,7 @@ SPELLCHECK_CMD = "FT.SPELLCHECK" DICT_ADD_CMD = "FT.DICTADD" DICT_DEL_CMD = "FT.DICTDEL" DICT_DUMP_CMD = "FT.DICTDUMP" +GET_CMD = "FT.GET" MGET_CMD = "FT.MGET" CONFIG_CMD = "FT.CONFIG" TAGVALS_CMD = "FT.TAGVALS" @@ -66,7 +65,7 @@ class SearchCommands: def _parse_results(self, cmd, res, **kwargs): if get_protocol_version(self.client) in ["3", 3]: - return ProfileInformation(res) if cmd == "FT.PROFILE" else res + return res else: return self._RESP2_MODULE_CALLBACKS[cmd](res, **kwargs) @@ -81,7 +80,6 @@ class SearchCommands: duration=kwargs["duration"], has_payload=kwargs["query"]._with_payloads, with_scores=kwargs["query"]._with_scores, - field_encodings=kwargs["query"]._return_fields_decode_as, ) def _parse_aggregate(self, res, **kwargs): @@ -100,7 +98,7 @@ class SearchCommands: with_scores=query._with_scores, ) - return result, ProfileInformation(res[1]) + return result, parse_to_dict(res[1]) def 
_parse_spellcheck(self, res, **kwargs): corrections = {} @@ -153,43 +151,44 @@ class SearchCommands: def create_index( self, - fields: List[Field], - no_term_offsets: bool = False, - no_field_flags: bool = False, - stopwords: Optional[List[str]] = None, - definition: Optional[IndexDefinition] = None, + fields, + no_term_offsets=False, + no_field_flags=False, + stopwords=None, + definition=None, max_text_fields=False, temporary=None, - no_highlight: bool = False, - no_term_frequencies: bool = False, - skip_initial_scan: bool = False, + no_highlight=False, + no_term_frequencies=False, + skip_initial_scan=False, ): """ - Creates the search index. The index must not already exist. + Create the search index. The index must not already exist. - For more information, see https://redis.io/commands/ft.create/ + ### Parameters: - Args: - fields: A list of Field objects. - no_term_offsets: If `true`, term offsets will not be saved in the index. - no_field_flags: If true, field flags that allow searching in specific fields - will not be saved. - stopwords: If provided, the index will be created with this custom stopword - list. The list can be empty. - definition: If provided, the index will be created with this custom index - definition. - max_text_fields: If true, indexes will be encoded as if there were more than - 32 text fields, allowing for additional fields beyond 32. - temporary: Creates a lightweight temporary index which will expire after the - specified period of inactivity. The internal idle timer is reset - whenever the index is searched or added to. - no_highlight: If true, disables highlighting support. Also implied by - `no_term_offsets`. - no_term_frequencies: If true, term frequencies will not be saved in the - index. - skip_initial_scan: If true, the initial scan and indexing will be skipped. + - **fields**: a list of TextField or NumericField objects + - **no_term_offsets**: If true, we will not save term offsets in + the index + - **no_field_flags**: If true, we will not save field flags that + allow searching in specific fields + - **stopwords**: If not None, we create the index with this custom + stopword list. The list can be empty + - **max_text_fields**: If true, we will encode indexes as if there + were more than 32 text fields which allows you to add additional + fields (beyond 32). + - **temporary**: Create a lightweight temporary index which will + expire after the specified period of inactivity (in seconds). The + internal idle timer is reset whenever the index is searched or added to. + - **no_highlight**: If true, disabling highlighting support. + Also implied by no_term_offsets. + - **no_term_frequencies**: If true, we avoid saving the term frequencies + in the index. + - **skip_initial_scan**: If true, we do not scan and index. + + For more information see `FT.CREATE `_. + """ # noqa - """ args = [CREATE_CMD, self.index_name] if definition is not None: args += definition.args @@ -253,18 +252,8 @@ class SearchCommands: For more information see `FT.DROPINDEX `_. """ # noqa - args = [DROPINDEX_CMD, self.index_name] - - delete_str = ( - "DD" - if isinstance(delete_documents, bool) and delete_documents is True - else "" - ) - - if delete_str: - args.append(delete_str) - - return self.execute_command(*args) + delete_str = "DD" if delete_documents else "" + return self.execute_command(DROPINDEX_CMD, self.index_name, delete_str) def _add_document( self, @@ -346,30 +335,30 @@ class SearchCommands: """ Add a single document to the index. 
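# A hedged sketch of create_index() as documented above, assuming a server
# with RediSearch loaded; the index name, fields and prefix are made up.
# IndexDefinition is imported from the renamed indexDefinition module in
# this tree.
import redis
from redis.commands.search.field import TextField, NumericField
from redis.commands.search.indexDefinition import IndexDefinition

r = redis.Redis(decode_responses=True)
r.ft("idx").create_index(
    [TextField("title"), NumericField("price")],
    definition=IndexDefinition(prefix=["doc:"]),
)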
- Args: + ### Parameters - doc_id: the id of the saved document. - nosave: if set to true, we just index the document, and don't + - **doc_id**: the id of the saved document. + - **nosave**: if set to true, we just index the document, and don't save a copy of it. This means that searches will just return ids. - score: the document ranking, between 0.0 and 1.0 - payload: optional inner-index payload we can save for fast - access in scoring functions - replace: if True, and the document already is in the index, - we perform an update and reindex the document - partial: if True, the fields specified will be added to the + - **score**: the document ranking, between 0.0 and 1.0 + - **payload**: optional inner-index payload we can save for fast + i access in scoring functions + - **replace**: if True, and the document already is in the index, + we perform an update and reindex the document + - **partial**: if True, the fields specified will be added to the existing document. This has the added benefit that any fields specified with `no_index` will not be reindexed again. Implies `replace` - language: Specify the language used for document tokenization. - no_create: if True, the document is only updated and reindexed + - **language**: Specify the language used for document tokenization. + - **no_create**: if True, the document is only updated and reindexed if it already exists. If the document does not exist, an error will be returned. Implies `replace` - fields: kwargs dictionary of the document fields to be saved - and/or indexed. - NOTE: Geo points shoule be encoded as strings of "lon,lat" + - **fields** kwargs dictionary of the document fields to be saved + and/or indexed. + NOTE: Geo points shoule be encoded as strings of "lon,lat" """ # noqa return self._add_document( doc_id, @@ -404,7 +393,6 @@ class SearchCommands: doc_id, conn=None, score=score, language=language, replace=replace ) - @deprecated_function(version="2.0.0", reason="deprecated since redisearch 2.0") def delete_document(self, doc_id, conn=None, delete_actual_document=False): """ Delete a document from index @@ -439,7 +427,6 @@ class SearchCommands: return Document(id=id, **fields) - @deprecated_function(version="2.0.0", reason="deprecated since redisearch 2.0") def get(self, *ids): """ Returns the full contents of multiple documents. @@ -510,19 +497,14 @@ class SearchCommands: For more information see `FT.SEARCH `_. 
""" # noqa args, query = self._mk_query_args(query, query_params=query_params) - st = time.monotonic() - - options = {} - if get_protocol_version(self.client) not in ["3", 3]: - options[NEVER_DECODE] = True - - res = self.execute_command(SEARCH_CMD, *args, **options) + st = time.time() + res = self.execute_command(SEARCH_CMD, *args) if isinstance(res, Pipeline): return res return self._parse_results( - SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0 + SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0 ) def explain( @@ -542,7 +524,7 @@ class SearchCommands: def aggregate( self, - query: Union[AggregateRequest, Cursor], + query: Union[str, Query], query_params: Dict[str, Union[str, int, float]] = None, ): """ @@ -573,7 +555,7 @@ class SearchCommands: ) def _get_aggregate_result( - self, raw: List, query: Union[AggregateRequest, Cursor], has_cursor: bool + self, raw: List, query: Union[str, Query, AggregateRequest], has_cursor: bool ): if has_cursor: if isinstance(query, Cursor): @@ -596,7 +578,7 @@ class SearchCommands: def profile( self, - query: Union[Query, AggregateRequest], + query: Union[str, Query, AggregateRequest], limited: bool = False, query_params: Optional[Dict[str, Union[str, int, float]]] = None, ): @@ -606,13 +588,13 @@ class SearchCommands: ### Parameters - **query**: This can be either an `AggregateRequest` or `Query`. + **query**: This can be either an `AggregateRequest`, `Query` or string. **limited**: If set to True, removes details of reader iterator. **query_params**: Define one or more value parameters. Each parameter has a name and a value. """ - st = time.monotonic() + st = time.time() cmd = [PROFILE_CMD, self.index_name, ""] if limited: cmd.append("LIMITED") @@ -631,20 +613,20 @@ class SearchCommands: res = self.execute_command(*cmd) return self._parse_results( - PROFILE_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0 + PROFILE_CMD, res, query=query, duration=(time.time() - st) * 1000.0 ) def spellcheck(self, query, distance=None, include=None, exclude=None): """ Issue a spellcheck query - Args: + ### Parameters - query: search query. - distance: the maximal Levenshtein distance for spelling + **query**: search query. + **distance***: the maximal Levenshtein distance for spelling suggestions (default: 1, max: 4). - include: specifies an inclusion custom dictionary. - exclude: specifies an exclusion custom dictionary. + **include**: specifies an inclusion custom dictionary. + **exclude**: specifies an exclusion custom dictionary. For more information see `FT.SPELLCHECK `_. """ # noqa @@ -702,10 +684,6 @@ class SearchCommands: cmd = [DICT_DUMP_CMD, name] return self.execute_command(*cmd) - @deprecated_function( - version="8.0.0", - reason="deprecated since Redis 8.0, call config_set from core module instead", - ) def config_set(self, option: str, value: str) -> bool: """Set runtime configuration option. @@ -720,10 +698,6 @@ class SearchCommands: raw = self.execute_command(*cmd) return raw == "OK" - @deprecated_function( - version="8.0.0", - reason="deprecated since Redis 8.0, call config_get from core module instead", - ) def config_get(self, option: str) -> str: """Get runtime configuration option value. @@ -950,24 +924,19 @@ class AsyncSearchCommands(SearchCommands): For more information see `FT.SEARCH `_. 
""" # noqa args, query = self._mk_query_args(query, query_params=query_params) - st = time.monotonic() - - options = {} - if get_protocol_version(self.client) not in ["3", 3]: - options[NEVER_DECODE] = True - - res = await self.execute_command(SEARCH_CMD, *args, **options) + st = time.time() + res = await self.execute_command(SEARCH_CMD, *args) if isinstance(res, Pipeline): return res return self._parse_results( - SEARCH_CMD, res, query=query, duration=(time.monotonic() - st) * 1000.0 + SEARCH_CMD, res, query=query, duration=(time.time() - st) * 1000.0 ) async def aggregate( self, - query: Union[AggregateResult, Cursor], + query: Union[str, Query], query_params: Dict[str, Union[str, int, float]] = None, ): """ @@ -1025,10 +994,6 @@ class AsyncSearchCommands(SearchCommands): return self._parse_results(SPELLCHECK_CMD, res) - @deprecated_function( - version="8.0.0", - reason="deprecated since Redis 8.0, call config_set from core module instead", - ) async def config_set(self, option: str, value: str) -> bool: """Set runtime configuration option. @@ -1043,10 +1008,6 @@ class AsyncSearchCommands(SearchCommands): raw = await self.execute_command(*cmd) return raw == "OK" - @deprecated_function( - version="8.0.0", - reason="deprecated since Redis 8.0, call config_get from core module instead", - ) async def config_get(self, option: str) -> str: """Get runtime configuration option value. diff --git a/venv/lib/python3.12/site-packages/redis/commands/search/dialect.py b/venv/lib/python3.12/site-packages/redis/commands/search/dialect.py deleted file mode 100644 index 828b3f2..0000000 --- a/venv/lib/python3.12/site-packages/redis/commands/search/dialect.py +++ /dev/null @@ -1,3 +0,0 @@ -# Value for the default dialect to be used as a part of -# Search or Aggregate query. -DEFAULT_DIALECT = 2 diff --git a/venv/lib/python3.12/site-packages/redis/commands/search/field.py b/venv/lib/python3.12/site-packages/redis/commands/search/field.py index 45cd403..76eb58c 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/search/field.py +++ b/venv/lib/python3.12/site-packages/redis/commands/search/field.py @@ -4,10 +4,6 @@ from redis import DataError class Field: - """ - A class representing a field in a document. - """ - NUMERIC = "NUMERIC" TEXT = "TEXT" WEIGHT = "WEIGHT" @@ -17,9 +13,6 @@ class Field: SORTABLE = "SORTABLE" NOINDEX = "NOINDEX" AS = "AS" - GEOSHAPE = "GEOSHAPE" - INDEX_MISSING = "INDEXMISSING" - INDEX_EMPTY = "INDEXEMPTY" def __init__( self, @@ -27,24 +20,8 @@ class Field: args: List[str] = None, sortable: bool = False, no_index: bool = False, - index_missing: bool = False, - index_empty: bool = False, as_name: str = None, ): - """ - Create a new field object. - - Args: - name: The name of the field. - args: - sortable: If `True`, the field will be sortable. - no_index: If `True`, the field will not be indexed. - index_missing: If `True`, it will be possible to search for documents that - have this field missing. - index_empty: If `True`, it will be possible to search for documents that - have this field empty. - as_name: If provided, this alias will be used for the field. 
- """ if args is None: args = [] self.name = name @@ -56,10 +33,6 @@ class Field: self.args_suffix.append(Field.SORTABLE) if no_index: self.args_suffix.append(Field.NOINDEX) - if index_missing: - self.args_suffix.append(Field.INDEX_MISSING) - if index_empty: - self.args_suffix.append(Field.INDEX_EMPTY) if no_index and not sortable: raise ValueError("Non-Sortable non-Indexable fields are ignored") @@ -118,21 +91,6 @@ class NumericField(Field): Field.__init__(self, name, args=[Field.NUMERIC], **kwargs) -class GeoShapeField(Field): - """ - GeoShapeField is used to enable within/contain indexing/searching - """ - - SPHERICAL = "SPHERICAL" - FLAT = "FLAT" - - def __init__(self, name: str, coord_system=None, **kwargs): - args = [Field.GEOSHAPE] - if coord_system: - args.append(coord_system) - Field.__init__(self, name, args=args, **kwargs) - - class GeoField(Field): """ GeoField is used to define a geo-indexing field in a schema definition @@ -181,7 +139,7 @@ class VectorField(Field): ``name`` is the name of the field. - ``algorithm`` can be "FLAT", "HNSW", or "SVS-VAMANA". + ``algorithm`` can be "FLAT" or "HNSW". ``attributes`` each algorithm can have specific attributes. Some of them are mandatory and some of them are optional. See @@ -194,10 +152,10 @@ class VectorField(Field): if sort or noindex: raise DataError("Cannot set 'sortable' or 'no_index' in Vector fields.") - if algorithm.upper() not in ["FLAT", "HNSW", "SVS-VAMANA"]: + if algorithm.upper() not in ["FLAT", "HNSW"]: raise DataError( - "Realtime vector indexing supporting 3 Indexing Methods:" - "'FLAT', 'HNSW', and 'SVS-VAMANA'." + "Realtime vector indexing supporting 2 Indexing Methods:" + "'FLAT' and 'HNSW'." ) attr_li = [] diff --git a/venv/lib/python3.12/site-packages/redis/commands/search/index_definition.py b/venv/lib/python3.12/site-packages/redis/commands/search/indexDefinition.py similarity index 100% rename from venv/lib/python3.12/site-packages/redis/commands/search/index_definition.py rename to venv/lib/python3.12/site-packages/redis/commands/search/indexDefinition.py diff --git a/venv/lib/python3.12/site-packages/redis/commands/search/profile_information.py b/venv/lib/python3.12/site-packages/redis/commands/search/profile_information.py deleted file mode 100644 index 23551be..0000000 --- a/venv/lib/python3.12/site-packages/redis/commands/search/profile_information.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import Any - - -class ProfileInformation: - """ - Wrapper around FT.PROFILE response - """ - - def __init__(self, info: Any) -> None: - self._info: Any = info - - @property - def info(self) -> Any: - return self._info diff --git a/venv/lib/python3.12/site-packages/redis/commands/search/query.py b/venv/lib/python3.12/site-packages/redis/commands/search/query.py index a8312a2..113ddf9 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/search/query.py +++ b/venv/lib/python3.12/site-packages/redis/commands/search/query.py @@ -1,7 +1,5 @@ from typing import List, Optional, Union -from redis.commands.search.dialect import DEFAULT_DIALECT - class Query: """ @@ -37,12 +35,11 @@ class Query: self._in_order: bool = False self._sortby: Optional[SortbyField] = None self._return_fields: List = [] - self._return_fields_decode_as: dict = {} self._summarize_fields: List = [] self._highlight_fields: List = [] self._language: Optional[str] = None self._expander: Optional[str] = None - self._dialect: int = DEFAULT_DIALECT + self._dialect: Optional[int] = None def query_string(self) -> str: """Return the query string of 
this query only.""" @@ -56,27 +53,13 @@ class Query: def return_fields(self, *fields) -> "Query": """Add fields to return fields.""" - for field in fields: - self.return_field(field) + self._return_fields += fields return self - def return_field( - self, - field: str, - as_field: Optional[str] = None, - decode_field: Optional[bool] = True, - encoding: Optional[str] = "utf8", - ) -> "Query": - """ - Add a field to the list of fields to return. - - - **field**: The field to include in query results - - **as_field**: The alias for the field - - **decode_field**: Whether to decode the field from bytes to string - - **encoding**: The encoding to use when decoding the field - """ + def return_field(self, field: str, as_field: Optional[str] = None) -> "Query": + """Add field to return fields (Optional: add 'AS' name + to the field).""" self._return_fields.append(field) - self._return_fields_decode_as[field] = encoding if decode_field else None if as_field is not None: self._return_fields += ("AS", as_field) return self @@ -179,8 +162,6 @@ class Query: Use a different scoring function to evaluate document relevance. Default is `TFIDF`. - Since Redis 8.0 default was changed to BM25STD. - :param scorer: The scoring function to use (e.g. `TFIDF.DOCNORM` or `BM25`) """ diff --git a/venv/lib/python3.12/site-packages/redis/commands/search/result.py b/venv/lib/python3.12/site-packages/redis/commands/search/result.py index e2c7efb..5b19e6f 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/search/result.py +++ b/venv/lib/python3.12/site-packages/redis/commands/search/result.py @@ -1,5 +1,3 @@ -from typing import Optional - from ._util import to_string from .document import Document @@ -11,19 +9,11 @@ class Result: """ def __init__( - self, - res, - hascontent, - duration=0, - has_payload=False, - with_scores=False, - field_encodings: Optional[dict] = None, + self, res, hascontent, duration=0, has_payload=False, with_scores=False ): """ - - duration: the execution time of the query - - has_payload: whether the query has payloads - - with_scores: whether the query has scores - - field_encodings: a dictionary of field encodings if any is provided + - **snippets**: An optional dictionary of the form + {field: snippet_size} for snippet formatting """ self.total = res[0] @@ -49,22 +39,18 @@ class Result: fields = {} if hascontent and res[i + fields_offset] is not None: - keys = map(to_string, res[i + fields_offset][::2]) - values = res[i + fields_offset][1::2] - - for key, value in zip(keys, values): - if field_encodings is None or key not in field_encodings: - fields[key] = to_string(value) - continue - - encoding = field_encodings[key] - - # If the encoding is None, we don't need to decode the value - if encoding is None: - fields[key] = value - else: - fields[key] = to_string(value, encoding=encoding) - + fields = ( + dict( + dict( + zip( + map(to_string, res[i + fields_offset][::2]), + map(to_string, res[i + fields_offset][1::2]), + ) + ) + ) + if hascontent + else {} + ) try: del fields["id"] except KeyError: diff --git a/venv/lib/python3.12/site-packages/redis/commands/sentinel.py b/venv/lib/python3.12/site-packages/redis/commands/sentinel.py index b2879b2..f745757 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/sentinel.py +++ b/venv/lib/python3.12/site-packages/redis/commands/sentinel.py @@ -11,35 +11,16 @@ class SentinelCommands: """Redis Sentinel's SENTINEL command.""" warnings.warn(DeprecationWarning("Use the individual sentinel_* methods")) - def 
sentinel_get_master_addr_by_name(self, service_name, return_responses=False): - """ - Returns a (host, port) pair for the given ``service_name`` when return_responses is True, - otherwise returns a boolean value that indicates if the command was successful. - """ - return self.execute_command( - "SENTINEL GET-MASTER-ADDR-BY-NAME", - service_name, - once=True, - return_responses=return_responses, - ) + def sentinel_get_master_addr_by_name(self, service_name): + """Returns a (host, port) pair for the given ``service_name``""" + return self.execute_command("SENTINEL GET-MASTER-ADDR-BY-NAME", service_name) - def sentinel_master(self, service_name, return_responses=False): - """ - Returns a dictionary containing the specified masters state, when return_responses is True, - otherwise returns a boolean value that indicates if the command was successful. - """ - return self.execute_command( - "SENTINEL MASTER", service_name, return_responses=return_responses - ) + def sentinel_master(self, service_name): + """Returns a dictionary containing the specified masters state.""" + return self.execute_command("SENTINEL MASTER", service_name) def sentinel_masters(self): - """ - Returns a list of dictionaries containing each master's state. - - Important: This function is called by the Sentinel implementation and is - called directly on the Redis standalone client for sentinels, - so it doesn't support the "once" and "return_responses" options. - """ + """Returns a list of dictionaries containing each master's state.""" return self.execute_command("SENTINEL MASTERS") def sentinel_monitor(self, name, ip, port, quorum): @@ -50,27 +31,16 @@ class SentinelCommands: """Remove a master from Sentinel's monitoring""" return self.execute_command("SENTINEL REMOVE", name) - def sentinel_sentinels(self, service_name, return_responses=False): - """ - Returns a list of sentinels for ``service_name``, when return_responses is True, - otherwise returns a boolean value that indicates if the command was successful. - """ - return self.execute_command( - "SENTINEL SENTINELS", service_name, return_responses=return_responses - ) + def sentinel_sentinels(self, service_name): + """Returns a list of sentinels for ``service_name``""" + return self.execute_command("SENTINEL SENTINELS", service_name) def sentinel_set(self, name, option, value): """Set Sentinel monitoring parameters for a given master""" return self.execute_command("SENTINEL SET", name, option, value) def sentinel_slaves(self, service_name): - """ - Returns a list of slaves for ``service_name`` - - Important: This function is called by the Sentinel implementation and is - called directly on the Redis standalone client for sentinels, - so it doesn't support the "once" and "return_responses" options. 
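# A rough usage sketch for the SENTINEL helpers above, assuming a client
# connected to a Sentinel on localhost:26379 that monitors a master named
# "mymaster", and that the SentinelCommands mixin is exposed on the regular
# Redis client as in this tree.
import redis

sentinel = redis.Redis(port=26379, decode_responses=True)
sentinel.sentinel_get_master_addr_by_name("mymaster")   # (host, port) pair
sentinel.sentinel_masters()                             # per-master state dicts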
- """ + """Returns a list of slaves for ``service_name``""" return self.execute_command("SENTINEL SLAVES", service_name) def sentinel_reset(self, pattern): diff --git a/venv/lib/python3.12/site-packages/redis/commands/timeseries/__init__.py b/venv/lib/python3.12/site-packages/redis/commands/timeseries/__init__.py index 3fbf821..4188b93 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/timeseries/__init__.py +++ b/venv/lib/python3.12/site-packages/redis/commands/timeseries/__init__.py @@ -84,7 +84,7 @@ class TimeSeries(TimeSeriesCommands): startup_nodes=self.client.nodes_manager.startup_nodes, result_callbacks=self.client.result_callbacks, cluster_response_callbacks=self.client.cluster_response_callbacks, - cluster_error_retry_attempts=self.client.retry.get_retries(), + cluster_error_retry_attempts=self.client.cluster_error_retry_attempts, read_from_replicas=self.client.read_from_replicas, reinitialize_steps=self.client.reinitialize_steps, lock=self.client._lock, diff --git a/venv/lib/python3.12/site-packages/redis/commands/timeseries/commands.py b/venv/lib/python3.12/site-packages/redis/commands/timeseries/commands.py index b0cb864..13e3cdf 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/timeseries/commands.py +++ b/venv/lib/python3.12/site-packages/redis/commands/timeseries/commands.py @@ -33,67 +33,41 @@ class TimeSeriesCommands: labels: Optional[Dict[str, str]] = None, chunk_size: Optional[int] = None, duplicate_policy: Optional[str] = None, - ignore_max_time_diff: Optional[int] = None, - ignore_max_val_diff: Optional[Number] = None, ): """ Create a new time-series. - For more information see https://redis.io/commands/ts.create/ - Args: - key: - The time-series key. - retention_msecs: - Maximum age for samples, compared to the highest reported timestamp in - milliseconds. If `None` or `0` is passed, the series is not trimmed at - all. - uncompressed: - Changes data storage from compressed (default) to uncompressed. - labels: - A dictionary of label-value pairs that represent metadata labels of the - key. - chunk_size: - Memory size, in bytes, allocated for each data chunk. Must be a multiple - of 8 in the range `[48..1048576]`. In earlier versions of the module the - minimum value was different. - duplicate_policy: - Policy for handling multiple samples with identical timestamps. Can be - one of: - - 'block': An error will occur and the new value will be ignored. - - 'first': Ignore the new value. - - 'last': Override with the latest value. - - 'min': Only override if the value is lower than the existing value. - - 'max': Only override if the value is higher than the existing value. - - 'sum': If a previous sample exists, add the new sample to it so - that the updated value is equal to (previous + new). If no - previous sample exists, set the updated value equal to the new - value. + key: + time-series key + retention_msecs: + Maximum age for samples compared to highest reported timestamp (in milliseconds). + If None or 0 is passed then the series is not trimmed at all. + uncompressed: + Changes data storage from compressed (by default) to uncompressed + labels: + Set of label-value pairs that represent metadata labels of the key. + chunk_size: + Memory size, in bytes, allocated for each data chunk. + Must be a multiple of 8 in the range [128 .. 1048576]. + duplicate_policy: + Policy for handling multiple samples with identical timestamps. + Can be one of: + - 'block': an error will occur for any out of order sample. + - 'first': ignore the new value. 
+ - 'last': override with latest value. + - 'min': only override if the value is lower than the existing value. + - 'max': only override if the value is higher than the existing value. - ignore_max_time_diff: - A non-negative integer value, in milliseconds, that sets an ignore - threshold for added timestamps. If the difference between the last - timestamp and the new timestamp is lower than this threshold, the new - entry is ignored. Only applicable if `duplicate_policy` is set to - `last`, and if `ignore_max_val_diff` is also set. Available since - RedisTimeSeries version 1.12.0. - ignore_max_val_diff: - A non-negative floating point value, that sets an ignore threshold for - added values. If the difference between the last value and the new value - is lower than this threshold, the new entry is ignored. Only applicable - if `duplicate_policy` is set to `last`, and if `ignore_max_time_diff` is - also set. Available since RedisTimeSeries version 1.12.0. - """ + For more information: https://redis.io/commands/ts.create/ + """ # noqa params = [key] self._append_retention(params, retention_msecs) self._append_uncompressed(params, uncompressed) self._append_chunk_size(params, chunk_size) - self._append_duplicate_policy(params, duplicate_policy) + self._append_duplicate_policy(params, CREATE_CMD, duplicate_policy) self._append_labels(params, labels) - self._append_insertion_filters( - params, ignore_max_time_diff, ignore_max_val_diff - ) return self.execute_command(CREATE_CMD, *params) @@ -104,65 +78,39 @@ class TimeSeriesCommands: labels: Optional[Dict[str, str]] = None, chunk_size: Optional[int] = None, duplicate_policy: Optional[str] = None, - ignore_max_time_diff: Optional[int] = None, - ignore_max_val_diff: Optional[Number] = None, ): """ - Update an existing time series. - - For more information see https://redis.io/commands/ts.alter/ + Update the retention, chunk size, duplicate policy, and labels of an existing + time series. Args: - key: - The time-series key. - retention_msecs: - Maximum age for samples, compared to the highest reported timestamp in - milliseconds. If `None` or `0` is passed, the series is not trimmed at - all. - labels: - A dictionary of label-value pairs that represent metadata labels of the - key. - chunk_size: - Memory size, in bytes, allocated for each data chunk. Must be a multiple - of 8 in the range `[48..1048576]`. In earlier versions of the module the - minimum value was different. Changing this value does not affect - existing chunks. - duplicate_policy: - Policy for handling multiple samples with identical timestamps. Can be - one of: - - 'block': An error will occur and the new value will be ignored. - - 'first': Ignore the new value. - - 'last': Override with the latest value. - - 'min': Only override if the value is lower than the existing value. - - 'max': Only override if the value is higher than the existing value. - - 'sum': If a previous sample exists, add the new sample to it so - that the updated value is equal to (previous + new). If no - previous sample exists, set the updated value equal to the new - value. + key: + time-series key + retention_msecs: + Maximum retention period, compared to maximal existing timestamp (in milliseconds). + If None or 0 is passed then the series is not trimmed at all. + labels: + Set of label-value pairs that represent metadata labels of the key. + chunk_size: + Memory size, in bytes, allocated for each data chunk. + Must be a multiple of 8 in the range [128 .. 1048576]. 
+ duplicate_policy: + Policy for handling multiple samples with identical timestamps. + Can be one of: + - 'block': an error will occur for any out of order sample. + - 'first': ignore the new value. + - 'last': override with latest value. + - 'min': only override if the value is lower than the existing value. + - 'max': only override if the value is higher than the existing value. - ignore_max_time_diff: - A non-negative integer value, in milliseconds, that sets an ignore - threshold for added timestamps. If the difference between the last - timestamp and the new timestamp is lower than this threshold, the new - entry is ignored. Only applicable if `duplicate_policy` is set to - `last`, and if `ignore_max_val_diff` is also set. Available since - RedisTimeSeries version 1.12.0. - ignore_max_val_diff: - A non-negative floating point value, that sets an ignore threshold for - added values. If the difference between the last value and the new value - is lower than this threshold, the new entry is ignored. Only applicable - if `duplicate_policy` is set to `last`, and if `ignore_max_time_diff` is - also set. Available since RedisTimeSeries version 1.12.0. - """ + For more information: https://redis.io/commands/ts.alter/ + """ # noqa params = [key] self._append_retention(params, retention_msecs) self._append_chunk_size(params, chunk_size) - self._append_duplicate_policy(params, duplicate_policy) + self._append_duplicate_policy(params, ALTER_CMD, duplicate_policy) self._append_labels(params, labels) - self._append_insertion_filters( - params, ignore_max_time_diff, ignore_max_val_diff - ) return self.execute_command(ALTER_CMD, *params) @@ -176,104 +124,57 @@ class TimeSeriesCommands: labels: Optional[Dict[str, str]] = None, chunk_size: Optional[int] = None, duplicate_policy: Optional[str] = None, - ignore_max_time_diff: Optional[int] = None, - ignore_max_val_diff: Optional[Number] = None, - on_duplicate: Optional[str] = None, ): """ - Append a sample to a time series. When the specified key does not exist, a new - time series is created. - - For more information see https://redis.io/commands/ts.add/ + Append (or create and append) a new sample to a time series. Args: - key: - The time-series key. - timestamp: - Timestamp of the sample. `*` can be used for automatic timestamp (using - the system clock). - value: - Numeric data value of the sample. - retention_msecs: - Maximum age for samples, compared to the highest reported timestamp in - milliseconds. If `None` or `0` is passed, the series is not trimmed at - all. - uncompressed: - Changes data storage from compressed (default) to uncompressed. - labels: - A dictionary of label-value pairs that represent metadata labels of the - key. - chunk_size: - Memory size, in bytes, allocated for each data chunk. Must be a multiple - of 8 in the range `[48..1048576]`. In earlier versions of the module the - minimum value was different. - duplicate_policy: - Policy for handling multiple samples with identical timestamps. Can be - one of: - - 'block': An error will occur and the new value will be ignored. - - 'first': Ignore the new value. - - 'last': Override with the latest value. - - 'min': Only override if the value is lower than the existing value. - - 'max': Only override if the value is higher than the existing value. - - 'sum': If a previous sample exists, add the new sample to it so - that the updated value is equal to (previous + new). If no - previous sample exists, set the updated value equal to the new - value. 
+ key: + time-series key + timestamp: + Timestamp of the sample. * can be used for automatic timestamp (using the system clock). + value: + Numeric data value of the sample + retention_msecs: + Maximum retention period, compared to maximal existing timestamp (in milliseconds). + If None or 0 is passed then the series is not trimmed at all. + uncompressed: + Changes data storage from compressed (by default) to uncompressed + labels: + Set of label-value pairs that represent metadata labels of the key. + chunk_size: + Memory size, in bytes, allocated for each data chunk. + Must be a multiple of 8 in the range [128 .. 1048576]. + duplicate_policy: + Policy for handling multiple samples with identical timestamps. + Can be one of: + - 'block': an error will occur for any out of order sample. + - 'first': ignore the new value. + - 'last': override with latest value. + - 'min': only override if the value is lower than the existing value. + - 'max': only override if the value is higher than the existing value. - ignore_max_time_diff: - A non-negative integer value, in milliseconds, that sets an ignore - threshold for added timestamps. If the difference between the last - timestamp and the new timestamp is lower than this threshold, the new - entry is ignored. Only applicable if `duplicate_policy` is set to - `last`, and if `ignore_max_val_diff` is also set. Available since - RedisTimeSeries version 1.12.0. - ignore_max_val_diff: - A non-negative floating point value, that sets an ignore threshold for - added values. If the difference between the last value and the new value - is lower than this threshold, the new entry is ignored. Only applicable - if `duplicate_policy` is set to `last`, and if `ignore_max_time_diff` is - also set. Available since RedisTimeSeries version 1.12.0. - on_duplicate: - Use a specific duplicate policy for the specified timestamp. Overrides - the duplicate policy set by `duplicate_policy`. - """ + For more information: https://redis.io/commands/ts.add/ + """ # noqa params = [key, timestamp, value] self._append_retention(params, retention_msecs) self._append_uncompressed(params, uncompressed) self._append_chunk_size(params, chunk_size) - self._append_duplicate_policy(params, duplicate_policy) + self._append_duplicate_policy(params, ADD_CMD, duplicate_policy) self._append_labels(params, labels) - self._append_insertion_filters( - params, ignore_max_time_diff, ignore_max_val_diff - ) - self._append_on_duplicate(params, on_duplicate) return self.execute_command(ADD_CMD, *params) def madd(self, ktv_tuples: List[Tuple[KeyT, Union[int, str], Number]]): """ - Append new samples to one or more time series. + Append (or create and append) a new `value` to series + `key` with `timestamp`. + Expects a list of `tuples` as (`key`,`timestamp`, `value`). + Return value is an array with timestamps of insertions. - Each time series must already exist. - - The method expects a list of tuples. Each tuple should contain three elements: - (`key`, `timestamp`, `value`). The `value` will be appended to the time series - identified by 'key', at the given 'timestamp'. - - For more information see https://redis.io/commands/ts.madd/ - - Args: - ktv_tuples: - A list of tuples, where each tuple contains: - - `key`: The key of the time series. - - `timestamp`: The timestamp at which the value should be appended. - - `value`: The value to append to the time series. - - Returns: - A list that contains, for each sample, either the timestamp that was used, - or an error, if the sample could not be added. 
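# A compact sketch of the TS.CREATE / TS.ADD / TS.MADD / TS.INCRBY calls
# documented in this hunk, assuming a server with RedisTimeSeries loaded;
# key names and labels are illustrative.
import redis

r = redis.Redis(decode_responses=True)
r.ts().create("temp:1", retention_msecs=86400000, labels={"room": "kitchen"})
r.ts().add("temp:1", "*", 21.5)    # '*' = server-assigned timestamp
r.ts().madd([("temp:1", "*", 21.7), ("temp:1", "*", 21.9)])
r.ts().incrby("hits:page", 1)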
- """ + For more information: https://redis.io/commands/ts.madd/ + """ # noqa params = [] for ktv in ktv_tuples: params.extend(ktv) @@ -289,86 +190,37 @@ class TimeSeriesCommands: uncompressed: Optional[bool] = False, labels: Optional[Dict[str, str]] = None, chunk_size: Optional[int] = None, - duplicate_policy: Optional[str] = None, - ignore_max_time_diff: Optional[int] = None, - ignore_max_val_diff: Optional[Number] = None, ): """ - Increment the latest sample's of a series. When the specified key does not - exist, a new time series is created. - - This command can be used as a counter or gauge that automatically gets history - as a time series. - - For more information see https://redis.io/commands/ts.incrby/ + Increment (or create an time-series and increment) the latest sample's of a series. + This command can be used as a counter or gauge that automatically gets history as a time series. Args: - key: - The time-series key. - value: - Numeric value to be added (addend). - timestamp: - Timestamp of the sample. `*` can be used for automatic timestamp (using - the system clock). `timestamp` must be equal to or higher than the - maximum existing timestamp in the series. When equal, the value of the - sample with the maximum existing timestamp is increased. If it is - higher, a new sample with a timestamp set to `timestamp` is created, and - its value is set to the value of the sample with the maximum existing - timestamp plus the addend. - retention_msecs: - Maximum age for samples, compared to the highest reported timestamp in - milliseconds. If `None` or `0` is passed, the series is not trimmed at - all. - uncompressed: - Changes data storage from compressed (default) to uncompressed. - labels: - A dictionary of label-value pairs that represent metadata labels of the - key. - chunk_size: - Memory size, in bytes, allocated for each data chunk. Must be a multiple - of 8 in the range `[48..1048576]`. In earlier versions of the module the - minimum value was different. - duplicate_policy: - Policy for handling multiple samples with identical timestamps. Can be - one of: - - 'block': An error will occur and the new value will be ignored. - - 'first': Ignore the new value. - - 'last': Override with the latest value. - - 'min': Only override if the value is lower than the existing value. - - 'max': Only override if the value is higher than the existing value. - - 'sum': If a previous sample exists, add the new sample to it so - that the updated value is equal to (previous + new). If no - previous sample exists, set the updated value equal to the new - value. + key: + time-series key + value: + Numeric data value of the sample + timestamp: + Timestamp of the sample. * can be used for automatic timestamp (using the system clock). + retention_msecs: + Maximum age for samples compared to last event time (in milliseconds). + If None or 0 is passed then the series is not trimmed at all. + uncompressed: + Changes data storage from compressed (by default) to uncompressed + labels: + Set of label-value pairs that represent metadata labels of the key. + chunk_size: + Memory size, in bytes, allocated for each data chunk. - ignore_max_time_diff: - A non-negative integer value, in milliseconds, that sets an ignore - threshold for added timestamps. If the difference between the last - timestamp and the new timestamp is lower than this threshold, the new - entry is ignored. Only applicable if `duplicate_policy` is set to - `last`, and if `ignore_max_val_diff` is also set. 
Available since - RedisTimeSeries version 1.12.0. - ignore_max_val_diff: - A non-negative floating point value, that sets an ignore threshold for - added values. If the difference between the last value and the new value - is lower than this threshold, the new entry is ignored. Only applicable - if `duplicate_policy` is set to `last`, and if `ignore_max_time_diff` is - also set. Available since RedisTimeSeries version 1.12.0. - - Returns: - The timestamp of the sample that was modified or added. - """ + For more information: https://redis.io/commands/ts.incrby/ + """ # noqa params = [key, value] self._append_timestamp(params, timestamp) self._append_retention(params, retention_msecs) self._append_uncompressed(params, uncompressed) self._append_chunk_size(params, chunk_size) - self._append_duplicate_policy(params, duplicate_policy) self._append_labels(params, labels) - self._append_insertion_filters( - params, ignore_max_time_diff, ignore_max_val_diff - ) return self.execute_command(INCRBY_CMD, *params) @@ -381,86 +233,37 @@ class TimeSeriesCommands: uncompressed: Optional[bool] = False, labels: Optional[Dict[str, str]] = None, chunk_size: Optional[int] = None, - duplicate_policy: Optional[str] = None, - ignore_max_time_diff: Optional[int] = None, - ignore_max_val_diff: Optional[Number] = None, ): """ - Decrement the latest sample's of a series. When the specified key does not - exist, a new time series is created. - - This command can be used as a counter or gauge that automatically gets history - as a time series. - - For more information see https://redis.io/commands/ts.decrby/ + Decrement (or create an time-series and decrement) the latest sample's of a series. + This command can be used as a counter or gauge that automatically gets history as a time series. Args: - key: - The time-series key. - value: - Numeric value to subtract (subtrahend). - timestamp: - Timestamp of the sample. `*` can be used for automatic timestamp (using - the system clock). `timestamp` must be equal to or higher than the - maximum existing timestamp in the series. When equal, the value of the - sample with the maximum existing timestamp is decreased. If it is - higher, a new sample with a timestamp set to `timestamp` is created, and - its value is set to the value of the sample with the maximum existing - timestamp minus subtrahend. - retention_msecs: - Maximum age for samples, compared to the highest reported timestamp in - milliseconds. If `None` or `0` is passed, the series is not trimmed at - all. - uncompressed: - Changes data storage from compressed (default) to uncompressed. - labels: - A dictionary of label-value pairs that represent metadata labels of the - key. - chunk_size: - Memory size, in bytes, allocated for each data chunk. Must be a multiple - of 8 in the range `[48..1048576]`. In earlier versions of the module the - minimum value was different. - duplicate_policy: - Policy for handling multiple samples with identical timestamps. Can be - one of: - - 'block': An error will occur and the new value will be ignored. - - 'first': Ignore the new value. - - 'last': Override with the latest value. - - 'min': Only override if the value is lower than the existing value. - - 'max': Only override if the value is higher than the existing value. - - 'sum': If a previous sample exists, add the new sample to it so - that the updated value is equal to (previous + new). If no - previous sample exists, set the updated value equal to the new - value. 
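A short counter sketch for the `incrby()`/`decrby()` wrappers covered by this hunk, again using only the keyword arguments the older signature accepts; the key name is illustrative.

```python
import redis

ts = redis.Redis(decode_responses=True).ts()

# The series is created on first use and every increment is kept as a sample.
ts.incrby("pageviews:home", 1, timestamp="*", labels={"page": "home"})
ts.incrby("pageviews:home", 3)

# TS.DECRBY mirrors the same call shape in the opposite direction.
ts.decrby("pageviews:home", 2)

print(ts.get("pageviews:home"))  # latest (timestamp, value) pair
```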
+ key: + time-series key + value: + Numeric data value of the sample + timestamp: + Timestamp of the sample. * can be used for automatic timestamp (using the system clock). + retention_msecs: + Maximum age for samples compared to last event time (in milliseconds). + If None or 0 is passed then the series is not trimmed at all. + uncompressed: + Changes data storage from compressed (by default) to uncompressed + labels: + Set of label-value pairs that represent metadata labels of the key. + chunk_size: + Memory size, in bytes, allocated for each data chunk. - ignore_max_time_diff: - A non-negative integer value, in milliseconds, that sets an ignore - threshold for added timestamps. If the difference between the last - timestamp and the new timestamp is lower than this threshold, the new - entry is ignored. Only applicable if `duplicate_policy` is set to - `last`, and if `ignore_max_val_diff` is also set. Available since - RedisTimeSeries version 1.12.0. - ignore_max_val_diff: - A non-negative floating point value, that sets an ignore threshold for - added values. If the difference between the last value and the new value - is lower than this threshold, the new entry is ignored. Only applicable - if `duplicate_policy` is set to `last`, and if `ignore_max_time_diff` is - also set. Available since RedisTimeSeries version 1.12.0. - - Returns: - The timestamp of the sample that was modified or added. - """ + For more information: https://redis.io/commands/ts.decrby/ + """ # noqa params = [key, value] self._append_timestamp(params, timestamp) self._append_retention(params, retention_msecs) self._append_uncompressed(params, uncompressed) self._append_chunk_size(params, chunk_size) - self._append_duplicate_policy(params, duplicate_policy) self._append_labels(params, labels) - self._append_insertion_filters( - params, ignore_max_time_diff, ignore_max_val_diff - ) return self.execute_command(DECRBY_CMD, *params) @@ -468,22 +271,17 @@ class TimeSeriesCommands: """ Delete all samples between two timestamps for a given time series. - The given timestamp interval is closed (inclusive), meaning that samples whose - timestamp equals `from_time` or `to_time` are also deleted. - - For more information see https://redis.io/commands/ts.del/ - Args: - key: - The time-series key. - from_time: - Start timestamp for the range deletion. - to_time: - End timestamp for the range deletion. - Returns: - The number of samples deleted. - """ + key: + time-series key. + from_time: + Start timestamp for the range deletion. + to_time: + End timestamp for the range deletion. + + For more information: https://redis.io/commands/ts.del/ + """ # noqa return self.execute_command(DEL_CMD, key, from_time, to_time) def createrule( @@ -497,23 +295,24 @@ class TimeSeriesCommands: """ Create a compaction rule from values added to `source_key` into `dest_key`. - For more information see https://redis.io/commands/ts.createrule/ - Args: - source_key: - Key name for source time series. - dest_key: - Key name for destination (compacted) time series. - aggregation_type: - Aggregation type: One of the following: - [`avg`, `sum`, `min`, `max`, `range`, `count`, `first`, `last`, `std.p`, - `std.s`, `var.p`, `var.s`, `twa`] - bucket_size_msec: - Duration of each bucket, in milliseconds. - align_timestamp: - Assure that there is a bucket that starts at exactly align_timestamp and - align all other buckets accordingly. 
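The `delete()` wrapper above maps to TS.DEL; per the docstring it removes samples over a closed (inclusive) interval. A minimal sketch with illustrative epoch-millisecond bounds:

```python
import redis

ts = redis.Redis(decode_responses=True).ts()

# Remove every sample whose timestamp falls inside the closed interval
# [1650000000000, 1650003600000]; the bounds are illustrative.
removed = ts.delete("sensor:1:temp", 1650000000000, 1650003600000)
print(f"deleted {removed} samples")
```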
- """ + + source_key: + Key name for source time series + dest_key: + Key name for destination (compacted) time series + aggregation_type: + Aggregation type: One of the following: + [`avg`, `sum`, `min`, `max`, `range`, `count`, `first`, `last`, `std.p`, + `std.s`, `var.p`, `var.s`, `twa`] + bucket_size_msec: + Duration of each bucket, in milliseconds + align_timestamp: + Assure that there is a bucket that starts at exactly align_timestamp and + align all other buckets accordingly. + + For more information: https://redis.io/commands/ts.createrule/ + """ # noqa params = [source_key, dest_key] self._append_aggregation(params, aggregation_type, bucket_size_msec) if align_timestamp is not None: @@ -523,10 +322,10 @@ class TimeSeriesCommands: def deleterule(self, source_key: KeyT, dest_key: KeyT): """ - Delete a compaction rule from `source_key` to `dest_key`. + Delete a compaction rule from `source_key` to `dest_key`.. - For more information see https://redis.io/commands/ts.deleterule/ - """ + For more information: https://redis.io/commands/ts.deleterule/ + """ # noqa return self.execute_command(DELETERULE_CMD, source_key, dest_key) def __range_params( @@ -575,46 +374,42 @@ class TimeSeriesCommands: empty: Optional[bool] = False, ): """ - Query a range in forward direction for a specific time-series. - - For more information see https://redis.io/commands/ts.range/ + Query a range in forward direction for a specific time-serie. Args: - key: - Key name for timeseries. - from_time: - Start timestamp for the range query. `-` can be used to express the - minimum possible timestamp (0). - to_time: - End timestamp for range query, `+` can be used to express the maximum - possible timestamp. - count: - Limits the number of returned samples. - aggregation_type: - Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, - `range`, `count`, `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`, - `twa`] - bucket_size_msec: - Time bucket for aggregation in milliseconds. - filter_by_ts: - List of timestamps to filter the result by specific timestamps. - filter_by_min_value: - Filter result by minimum value (must mention also - `filter by_max_value`). - filter_by_max_value: - Filter result by maximum value (must mention also - `filter by_min_value`). - align: - Timestamp for alignment control for aggregation. - latest: - Used when a time series is a compaction, reports the compacted value of - the latest possibly partial bucket. - bucket_timestamp: - Controls how bucket timestamps are reported. Can be one of [`-`, `low`, - `+`, `high`, `~`, `mid`]. - empty: - Reports aggregations for empty buckets. - """ + + key: + Key name for timeseries. + from_time: + Start timestamp for the range query. - can be used to express the minimum possible timestamp (0). + to_time: + End timestamp for range query, + can be used to express the maximum possible timestamp. + count: + Limits the number of returned samples. + aggregation_type: + Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, + `range`, `count`, `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`, `twa`] + bucket_size_msec: + Time bucket for aggregation in milliseconds. + filter_by_ts: + List of timestamps to filter the result by specific timestamps. + filter_by_min_value: + Filter result by minimum value (must mention also filter by_max_value). + filter_by_max_value: + Filter result by maximum value (must mention also filter by_min_value). + align: + Timestamp for alignment control for aggregation. 
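A downsampling sketch for the `createrule()`/`deleterule()` wrappers above. Key names and the bucket size are illustrative; the destination series has to exist before the rule is created.

```python
import redis

ts = redis.Redis(decode_responses=True).ts()

# Source and destination must both exist before TS.CREATERULE.
ts.create("sensor:1:temp")
ts.create("sensor:1:temp:hourly_avg")

# Compact raw samples into hourly averages (bucket size in milliseconds).
ts.createrule("sensor:1:temp", "sensor:1:temp:hourly_avg", "avg", 3_600_000)

# Drop the rule again when it is no longer needed.
ts.deleterule("sensor:1:temp", "sensor:1:temp:hourly_avg")
```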
+ latest: + Used when a time series is a compaction, reports the compacted value of the + latest possibly partial bucket + bucket_timestamp: + Controls how bucket timestamps are reported. Can be one of [`-`, `low`, `+`, + `high`, `~`, `mid`]. + empty: + Reports aggregations for empty buckets. + + For more information: https://redis.io/commands/ts.range/ + """ # noqa params = self.__range_params( key, from_time, @@ -630,7 +425,7 @@ class TimeSeriesCommands: bucket_timestamp, empty, ) - return self.execute_command(RANGE_CMD, *params, keys=[key]) + return self.execute_command(RANGE_CMD, *params) def revrange( self, @@ -653,44 +448,40 @@ class TimeSeriesCommands: **Note**: This command is only available since RedisTimeSeries >= v1.4 - For more information see https://redis.io/commands/ts.revrange/ - Args: - key: - Key name for timeseries. - from_time: - Start timestamp for the range query. `-` can be used to express the - minimum possible timestamp (0). - to_time: - End timestamp for range query, `+` can be used to express the maximum - possible timestamp. - count: - Limits the number of returned samples. - aggregation_type: - Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, - `range`, `count`, `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`, - `twa`] - bucket_size_msec: - Time bucket for aggregation in milliseconds. - filter_by_ts: - List of timestamps to filter the result by specific timestamps. - filter_by_min_value: - Filter result by minimum value (must mention also - `filter_by_max_value`). - filter_by_max_value: - Filter result by maximum value (must mention also - `filter_by_min_value`). - align: - Timestamp for alignment control for aggregation. - latest: - Used when a time series is a compaction, reports the compacted value of - the latest possibly partial bucket. - bucket_timestamp: - Controls how bucket timestamps are reported. Can be one of [`-`, `low`, - `+`, `high`, `~`, `mid`]. - empty: - Reports aggregations for empty buckets. - """ + + key: + Key name for timeseries. + from_time: + Start timestamp for the range query. - can be used to express the minimum possible timestamp (0). + to_time: + End timestamp for range query, + can be used to express the maximum possible timestamp. + count: + Limits the number of returned samples. + aggregation_type: + Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, + `range`, `count`, `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`, `twa`] + bucket_size_msec: + Time bucket for aggregation in milliseconds. + filter_by_ts: + List of timestamps to filter the result by specific timestamps. + filter_by_min_value: + Filter result by minimum value (must mention also filter_by_max_value). + filter_by_max_value: + Filter result by maximum value (must mention also filter_by_min_value). + align: + Timestamp for alignment control for aggregation. + latest: + Used when a time series is a compaction, reports the compacted value of the + latest possibly partial bucket + bucket_timestamp: + Controls how bucket timestamps are reported. Can be one of [`-`, `low`, `+`, + `high`, `~`, `mid`]. + empty: + Reports aggregations for empty buckets. 
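The `range()`/`revrange()` wrappers read samples back as (timestamp, value) pairs. A usage sketch with illustrative keys and bounds:

```python
import redis

ts = redis.Redis(decode_responses=True).ts()

# Full history, oldest first; "-" and "+" are the open-ended bounds.
print(ts.range("sensor:1:temp", "-", "+"))

# Ten-minute averages over an explicit window.
print(ts.range(
    "sensor:1:temp", 1650000000000, 1650086400000,
    aggregation_type="avg", bucket_size_msec=600_000,
))

# Same series newest-first, limited to the last five samples.
print(ts.revrange("sensor:1:temp", "-", "+", count=5))
```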
+ + For more information: https://redis.io/commands/ts.revrange/ + """ # noqa params = self.__range_params( key, from_time, @@ -706,7 +497,7 @@ class TimeSeriesCommands: bucket_timestamp, empty, ) - return self.execute_command(REVRANGE_CMD, *params, keys=[key]) + return self.execute_command(REVRANGE_CMD, *params) def __mrange_params( self, @@ -767,55 +558,49 @@ class TimeSeriesCommands: """ Query a range across multiple time-series by filters in forward direction. - For more information see https://redis.io/commands/ts.mrange/ - Args: - from_time: - Start timestamp for the range query. `-` can be used to express the - minimum possible timestamp (0). - to_time: - End timestamp for range query, `+` can be used to express the maximum - possible timestamp. - filters: - Filter to match the time-series labels. - count: - Limits the number of returned samples. - aggregation_type: - Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, - `range`, `count`, `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`, - `twa`] - bucket_size_msec: - Time bucket for aggregation in milliseconds. - with_labels: - Include in the reply all label-value pairs representing metadata labels - of the time series. - filter_by_ts: - List of timestamps to filter the result by specific timestamps. - filter_by_min_value: - Filter result by minimum value (must mention also - `filter_by_max_value`). - filter_by_max_value: - Filter result by maximum value (must mention also - `filter_by_min_value`). - groupby: - Grouping by fields the results (must mention also `reduce`). - reduce: - Applying reducer functions on each group. Can be one of [`avg` `sum`, - `min`, `max`, `range`, `count`, `std.p`, `std.s`, `var.p`, `var.s`]. - select_labels: - Include in the reply only a subset of the key-value pair labels of a - series. - align: - Timestamp for alignment control for aggregation. - latest: - Used when a time series is a compaction, reports the compacted value of - the latest possibly partial bucket. - bucket_timestamp: - Controls how bucket timestamps are reported. Can be one of [`-`, `low`, - `+`, `high`, `~`, `mid`]. - empty: - Reports aggregations for empty buckets. - """ + + from_time: + Start timestamp for the range query. `-` can be used to express the minimum possible timestamp (0). + to_time: + End timestamp for range query, `+` can be used to express the maximum possible timestamp. + filters: + filter to match the time-series labels. + count: + Limits the number of returned samples. + aggregation_type: + Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, + `range`, `count`, `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`, `twa`] + bucket_size_msec: + Time bucket for aggregation in milliseconds. + with_labels: + Include in the reply all label-value pairs representing metadata labels of the time series. + filter_by_ts: + List of timestamps to filter the result by specific timestamps. + filter_by_min_value: + Filter result by minimum value (must mention also filter_by_max_value). + filter_by_max_value: + Filter result by maximum value (must mention also filter_by_min_value). + groupby: + Grouping by fields the results (must mention also reduce). + reduce: + Applying reducer functions on each group. Can be one of [`avg` `sum`, `min`, + `max`, `range`, `count`, `std.p`, `std.s`, `var.p`, `var.s`]. + select_labels: + Include in the reply only a subset of the key-value pair labels of a series. + align: + Timestamp for alignment control for aggregation. 
+ latest: + Used when a time series is a compaction, reports the compacted + value of the latest possibly partial bucket + bucket_timestamp: + Controls how bucket timestamps are reported. Can be one of [`-`, `low`, `+`, + `high`, `~`, `mid`]. + empty: + Reports aggregations for empty buckets. + + For more information: https://redis.io/commands/ts.mrange/ + """ # noqa params = self.__mrange_params( aggregation_type, bucket_size_msec, @@ -861,55 +646,49 @@ class TimeSeriesCommands: """ Query a range across multiple time-series by filters in reverse direction. - For more information see https://redis.io/commands/ts.mrevrange/ - Args: - from_time: - Start timestamp for the range query. '-' can be used to express the - minimum possible timestamp (0). - to_time: - End timestamp for range query, '+' can be used to express the maximum - possible timestamp. - filters: - Filter to match the time-series labels. - count: - Limits the number of returned samples. - aggregation_type: - Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, - `range`, `count`, `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`, - `twa`]. - bucket_size_msec: - Time bucket for aggregation in milliseconds. - with_labels: - Include in the reply all label-value pairs representing metadata labels - of the time series. - filter_by_ts: - List of timestamps to filter the result by specific timestamps. - filter_by_min_value: - Filter result by minimum value (must mention also - `filter_by_max_value`). - filter_by_max_value: - Filter result by maximum value (must mention also - `filter_by_min_value`). - groupby: - Grouping by fields the results (must mention also `reduce`). - reduce: - Applying reducer functions on each group. Can be one of [`avg` `sum`, - `min`, `max`, `range`, `count`, `std.p`, `std.s`, `var.p`, `var.s`]. - select_labels: - Include in the reply only a subset of the key-value pair labels of a - series. - align: - Timestamp for alignment control for aggregation. - latest: - Used when a time series is a compaction, reports the compacted value of - the latest possibly partial bucket. - bucket_timestamp: - Controls how bucket timestamps are reported. Can be one of [`-`, `low`, - `+`, `high`, `~`, `mid`]. - empty: - Reports aggregations for empty buckets. - """ + + from_time: + Start timestamp for the range query. - can be used to express the minimum possible timestamp (0). + to_time: + End timestamp for range query, + can be used to express the maximum possible timestamp. + filters: + Filter to match the time-series labels. + count: + Limits the number of returned samples. + aggregation_type: + Optional aggregation type. Can be one of [`avg`, `sum`, `min`, `max`, + `range`, `count`, `first`, `last`, `std.p`, `std.s`, `var.p`, `var.s`, `twa`] + bucket_size_msec: + Time bucket for aggregation in milliseconds. + with_labels: + Include in the reply all label-value pairs representing metadata labels of the time series. + filter_by_ts: + List of timestamps to filter the result by specific timestamps. + filter_by_min_value: + Filter result by minimum value (must mention also filter_by_max_value). + filter_by_max_value: + Filter result by maximum value (must mention also filter_by_min_value). + groupby: + Grouping by fields the results (must mention also reduce). + reduce: + Applying reducer functions on each group. Can be one of [`avg` `sum`, `min`, + `max`, `range`, `count`, `std.p`, `std.s`, `var.p`, `var.s`]. + select_labels: + Include in the reply only a subset of the key-value pair labels of a series. 
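A cross-series sketch for `mrange()`/`mrevrange()`, which match series by label filters as described above; the label names are illustrative.

```python
import redis

ts = redis.Redis(decode_responses=True).ts()

# All series labelled sensor=1, with their labels included in the reply.
print(ts.mrange("-", "+", filters=["sensor=1"], with_labels=True))

# Aggregate each series, then group by a label and reduce each group with max.
print(ts.mrange(
    "-", "+",
    filters=["unit=C"],
    aggregation_type="avg", bucket_size_msec=600_000,
    groupby="unit", reduce="max",
))

# Reverse-order variant over the same filter.
print(ts.mrevrange("-", "+", filters=["sensor=1"], count=5))
```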
+ align: + Timestamp for alignment control for aggregation. + latest: + Used when a time series is a compaction, reports the compacted + value of the latest possibly partial bucket + bucket_timestamp: + Controls how bucket timestamps are reported. Can be one of [`-`, `low`, `+`, + `high`, `~`, `mid`]. + empty: + Reports aggregations for empty buckets. + + For more information: https://redis.io/commands/ts.mrevrange/ + """ # noqa params = self.__mrange_params( aggregation_type, bucket_size_msec, @@ -933,19 +712,16 @@ class TimeSeriesCommands: return self.execute_command(MREVRANGE_CMD, *params) def get(self, key: KeyT, latest: Optional[bool] = False): - """ + """# noqa Get the last sample of `key`. + `latest` used when a time series is a compaction, reports the compacted + value of the latest (possibly partial) bucket - For more information see https://redis.io/commands/ts.get/ - - Args: - latest: - Used when a time series is a compaction, reports the compacted value of - the latest (possibly partial) bucket. - """ + For more information: https://redis.io/commands/ts.get/ + """ # noqa params = [key] self._append_latest(params, latest) - return self.execute_command(GET_CMD, *params, keys=[key]) + return self.execute_command(GET_CMD, *params) def mget( self, @@ -954,24 +730,24 @@ class TimeSeriesCommands: select_labels: Optional[List[str]] = None, latest: Optional[bool] = False, ): - """ + """# noqa Get the last samples matching the specific `filter`. - For more information see https://redis.io/commands/ts.mget/ - Args: - filters: - Filter to match the time-series labels. - with_labels: - Include in the reply all label-value pairs representing metadata labels - of the time series. - select_labels: - Include in the reply only a subset of the key-value pair labels o the - time series. - latest: - Used when a time series is a compaction, reports the compacted value of - the latest possibly partial bucket. - """ + + filters: + Filter to match the time-series labels. + with_labels: + Include in the reply all label-value pairs representing metadata + labels of the time series. + select_labels: + Include in the reply only a subset of the key-value pair labels of a series. + latest: + Used when a time series is a compaction, reports the compacted + value of the latest possibly partial bucket + + For more information: https://redis.io/commands/ts.mget/ + """ # noqa params = [] self._append_latest(params, latest) self._append_with_labels(params, with_labels, select_labels) @@ -980,26 +756,26 @@ class TimeSeriesCommands: return self.execute_command(MGET_CMD, *params) def info(self, key: KeyT): - """ + """# noqa Get information of `key`. - For more information see https://redis.io/commands/ts.info/ - """ - return self.execute_command(INFO_CMD, key, keys=[key]) + For more information: https://redis.io/commands/ts.info/ + """ # noqa + return self.execute_command(INFO_CMD, key) def queryindex(self, filters: List[str]): - """ + """# noqa Get all time series keys matching the `filter` list. 
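Reading only the latest samples through the `get()`/`mget()` wrappers above; the label filter is illustrative.

```python
import redis

ts = redis.Redis(decode_responses=True).ts()

# Last sample of one series as a (timestamp, value) pair.
print(ts.get("sensor:1:temp"))

# Last sample of every series matching a label filter, labels included.
print(ts.mget(["sensor=1"], with_labels=True))
```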
- For more information see https://redis.io/commands/ts.queryindex/ - """ + For more information: https://redis.io/commands/ts.queryindex/ + """ # noq return self.execute_command(QUERYINDEX_CMD, *filters) @staticmethod def _append_uncompressed(params: List[str], uncompressed: Optional[bool]): """Append UNCOMPRESSED tag to params.""" if uncompressed: - params.extend(["ENCODING", "UNCOMPRESSED"]) + params.extend(["UNCOMPRESSED"]) @staticmethod def _append_with_labels( @@ -1075,16 +851,17 @@ class TimeSeriesCommands: params.extend(["CHUNK_SIZE", chunk_size]) @staticmethod - def _append_duplicate_policy(params: List[str], duplicate_policy: Optional[str]): - """Append DUPLICATE_POLICY property to params.""" + def _append_duplicate_policy( + params: List[str], command: Optional[str], duplicate_policy: Optional[str] + ): + """Append DUPLICATE_POLICY property to params on CREATE + and ON_DUPLICATE on ADD. + """ if duplicate_policy is not None: - params.extend(["DUPLICATE_POLICY", duplicate_policy]) - - @staticmethod - def _append_on_duplicate(params: List[str], on_duplicate: Optional[str]): - """Append ON_DUPLICATE property to params.""" - if on_duplicate is not None: - params.extend(["ON_DUPLICATE", on_duplicate]) + if command == "TS.ADD": + params.extend(["ON_DUPLICATE", duplicate_policy]) + else: + params.extend(["DUPLICATE_POLICY", duplicate_policy]) @staticmethod def _append_filer_by_ts(params: List[str], ts_list: Optional[List[int]]): @@ -1117,20 +894,3 @@ class TimeSeriesCommands: """Append EMPTY property to params.""" if empty: params.append("EMPTY") - - @staticmethod - def _append_insertion_filters( - params: List[str], - ignore_max_time_diff: Optional[int] = None, - ignore_max_val_diff: Optional[Number] = None, - ): - """Append insertion filters to params.""" - if (ignore_max_time_diff is None) != (ignore_max_val_diff is None): - raise ValueError( - "Both ignore_max_time_diff and ignore_max_val_diff must be set." - ) - - if ignore_max_time_diff is not None and ignore_max_val_diff is not None: - params.extend( - ["IGNORE", str(ignore_max_time_diff), str(ignore_max_val_diff)] - ) diff --git a/venv/lib/python3.12/site-packages/redis/commands/timeseries/info.py b/venv/lib/python3.12/site-packages/redis/commands/timeseries/info.py index 861e3ef..3a384dc 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/timeseries/info.py +++ b/venv/lib/python3.12/site-packages/redis/commands/timeseries/info.py @@ -6,7 +6,7 @@ class TSInfo: """ Hold information and statistics on the time-series. Can be created using ``tsinfo`` command - https://redis.io/docs/latest/commands/ts.info/ + https://oss.redis.com/redistimeseries/commands/#tsinfo. """ rules = [] @@ -57,7 +57,7 @@ class TSInfo: Policy that will define handling of duplicate samples. 
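The command-aware `_append_duplicate_policy` introduced in this hunk emits ON_DUPLICATE for TS.ADD and DUPLICATE_POLICY everywhere else. A standalone copy of that branch, fed illustrative parameter lists, shows the difference:

```python
# Standalone copy of the helper shown above, for illustration only.
def _append_duplicate_policy(params, command, duplicate_policy):
    if duplicate_policy is not None:
        if command == "TS.ADD":
            params.extend(["ON_DUPLICATE", duplicate_policy])
        else:
            params.extend(["DUPLICATE_POLICY", duplicate_policy])

add_params = ["temp:1", "*", 21.5]
_append_duplicate_policy(add_params, "TS.ADD", "last")
print(add_params)     # ['temp:1', '*', 21.5, 'ON_DUPLICATE', 'last']

create_params = ["temp:1"]
_append_duplicate_policy(create_params, "TS.CREATE", "last")
print(create_params)  # ['temp:1', 'DUPLICATE_POLICY', 'last']
```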
Can read more about on - https://redis.io/docs/latest/develop/data-types/timeseries/configuration/#duplicate_policy + https://oss.redis.com/redistimeseries/configuration/#duplicate_policy """ response = dict(zip(map(nativestr, args[::2]), args[1::2])) self.rules = response.get("rules") @@ -78,7 +78,7 @@ class TSInfo: self.chunk_size = response["chunkSize"] if "duplicatePolicy" in response: self.duplicate_policy = response["duplicatePolicy"] - if isinstance(self.duplicate_policy, bytes): + if type(self.duplicate_policy) == bytes: self.duplicate_policy = self.duplicate_policy.decode() def get(self, item): diff --git a/venv/lib/python3.12/site-packages/redis/commands/timeseries/utils.py b/venv/lib/python3.12/site-packages/redis/commands/timeseries/utils.py index 12ed656..c49b040 100644 --- a/venv/lib/python3.12/site-packages/redis/commands/timeseries/utils.py +++ b/venv/lib/python3.12/site-packages/redis/commands/timeseries/utils.py @@ -5,7 +5,7 @@ def list_to_dict(aList): return {nativestr(aList[i][0]): nativestr(aList[i][1]) for i in range(len(aList))} -def parse_range(response, **kwargs): +def parse_range(response): """Parse range response. Used by TS.RANGE and TS.REVRANGE.""" return [tuple((r[0], float(r[1]))) for r in response] diff --git a/venv/lib/python3.12/site-packages/redis/commands/vectorset/__init__.py b/venv/lib/python3.12/site-packages/redis/commands/vectorset/__init__.py deleted file mode 100644 index d78580a..0000000 --- a/venv/lib/python3.12/site-packages/redis/commands/vectorset/__init__.py +++ /dev/null @@ -1,46 +0,0 @@ -import json - -from redis._parsers.helpers import pairs_to_dict -from redis.commands.vectorset.utils import ( - parse_vemb_result, - parse_vlinks_result, - parse_vsim_result, -) - -from ..helpers import get_protocol_version -from .commands import ( - VEMB_CMD, - VGETATTR_CMD, - VINFO_CMD, - VLINKS_CMD, - VSIM_CMD, - VectorSetCommands, -) - - -class VectorSet(VectorSetCommands): - def __init__(self, client, **kwargs): - """Create a new VectorSet client.""" - # Set the module commands' callbacks - self._MODULE_CALLBACKS = { - VEMB_CMD: parse_vemb_result, - VGETATTR_CMD: lambda r: r and json.loads(r) or None, - } - - self._RESP2_MODULE_CALLBACKS = { - VINFO_CMD: lambda r: r and pairs_to_dict(r) or None, - VSIM_CMD: parse_vsim_result, - VLINKS_CMD: parse_vlinks_result, - } - self._RESP3_MODULE_CALLBACKS = {} - - self.client = client - self.execute_command = client.execute_command - - if get_protocol_version(self.client) in ["3", 3]: - self._MODULE_CALLBACKS.update(self._RESP3_MODULE_CALLBACKS) - else: - self._MODULE_CALLBACKS.update(self._RESP2_MODULE_CALLBACKS) - - for k, v in self._MODULE_CALLBACKS.items(): - self.client.set_response_callback(k, v) diff --git a/venv/lib/python3.12/site-packages/redis/commands/vectorset/commands.py b/venv/lib/python3.12/site-packages/redis/commands/vectorset/commands.py deleted file mode 100644 index 0f23dba..0000000 --- a/venv/lib/python3.12/site-packages/redis/commands/vectorset/commands.py +++ /dev/null @@ -1,374 +0,0 @@ -import json -from enum import Enum -from typing import Awaitable, Dict, List, Optional, Union - -from redis.client import NEVER_DECODE -from redis.commands.helpers import get_protocol_version -from redis.exceptions import DataError -from redis.typing import CommandsProtocol, EncodableT, KeyT, Number - -VADD_CMD = "VADD" -VSIM_CMD = "VSIM" -VREM_CMD = "VREM" -VDIM_CMD = "VDIM" -VCARD_CMD = "VCARD" -VEMB_CMD = "VEMB" -VLINKS_CMD = "VLINKS" -VINFO_CMD = "VINFO" -VSETATTR_CMD = "VSETATTR" -VGETATTR_CMD = 
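`parse_range` above is the callback that turns raw TS.RANGE/TS.REVRANGE replies into (timestamp, float) tuples. A standalone copy run against an illustrative raw reply:

```python
# Standalone copy of the callback above; the raw reply below is illustrative.
def parse_range(response):
    return [tuple((r[0], float(r[1]))) for r in response]

raw_reply = [[1650000000000, b"21.5"], [1650000060000, b"21.7"]]
print(parse_range(raw_reply))  # [(1650000000000, 21.5), (1650000060000, 21.7)]
```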
"VGETATTR" -VRANDMEMBER_CMD = "VRANDMEMBER" - - -class QuantizationOptions(Enum): - """Quantization options for the VADD command.""" - - NOQUANT = "NOQUANT" - BIN = "BIN" - Q8 = "Q8" - - -class CallbacksOptions(Enum): - """Options that can be set for the commands callbacks""" - - RAW = "RAW" - WITHSCORES = "WITHSCORES" - ALLOW_DECODING = "ALLOW_DECODING" - RESP3 = "RESP3" - - -class VectorSetCommands(CommandsProtocol): - """Redis VectorSet commands""" - - def vadd( - self, - key: KeyT, - vector: Union[List[float], bytes], - element: str, - reduce_dim: Optional[int] = None, - cas: Optional[bool] = False, - quantization: Optional[QuantizationOptions] = None, - ef: Optional[Number] = None, - attributes: Optional[Union[dict, str]] = None, - numlinks: Optional[int] = None, - ) -> Union[Awaitable[int], int]: - """ - Add vector ``vector`` for element ``element`` to a vector set ``key``. - - ``reduce_dim`` sets the dimensions to reduce the vector to. - If not provided, the vector is not reduced. - - ``cas`` is a boolean flag that indicates whether to use CAS (check-and-set style) - when adding the vector. If not provided, CAS is not used. - - ``quantization`` sets the quantization type to use. - If not provided, int8 quantization is used. - The options are: - - NOQUANT: No quantization - - BIN: Binary quantization - - Q8: Signed 8-bit quantization - - ``ef`` sets the exploration factor to use. - If not provided, the default exploration factor is used. - - ``attributes`` is a dictionary or json string that contains the attributes to set for the vector. - If not provided, no attributes are set. - - ``numlinks`` sets the number of links to create for the vector. - If not provided, the default number of links is used. - - For more information see https://redis.io/commands/vadd - """ - if not vector or not element: - raise DataError("Both vector and element must be provided") - - pieces = [] - if reduce_dim: - pieces.extend(["REDUCE", reduce_dim]) - - values_pieces = [] - if isinstance(vector, bytes): - values_pieces.extend(["FP32", vector]) - else: - values_pieces.extend(["VALUES", len(vector)]) - values_pieces.extend(vector) - pieces.extend(values_pieces) - - pieces.append(element) - - if cas: - pieces.append("CAS") - - if quantization: - pieces.append(quantization.value) - - if ef: - pieces.extend(["EF", ef]) - - if attributes: - if isinstance(attributes, dict): - # transform attributes to json string - attributes_json = json.dumps(attributes) - else: - attributes_json = attributes - pieces.extend(["SETATTR", attributes_json]) - - if numlinks: - pieces.extend(["M", numlinks]) - - return self.execute_command(VADD_CMD, key, *pieces) - - def vsim( - self, - key: KeyT, - input: Union[List[float], bytes, str], - with_scores: Optional[bool] = False, - count: Optional[int] = None, - ef: Optional[Number] = None, - filter: Optional[str] = None, - filter_ef: Optional[str] = None, - truth: Optional[bool] = False, - no_thread: Optional[bool] = False, - epsilon: Optional[Number] = None, - ) -> Union[ - Awaitable[Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]]], - Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]], - ]: - """ - Compare a vector or element ``input`` with the other vectors in a vector set ``key``. - - ``with_scores`` sets if the results should be returned with the - similarity scores of the elements in the result. - - ``count`` sets the number of results to return. - - ``ef`` sets the exploration factor. 
- - ``filter`` sets filter that should be applied for the search. - - ``filter_ef`` sets the max filtering effort. - - ``truth`` when enabled forces the command to perform linear scan. - - ``no_thread`` when enabled forces the command to execute the search - on the data structure in the main thread. - - ``epsilon`` floating point between 0 and 1, if specified will return - only elements with distance no further than the specified one. - - For more information see https://redis.io/commands/vsim - """ - - if not input: - raise DataError("'input' should be provided") - - pieces = [] - options = {} - - if isinstance(input, bytes): - pieces.extend(["FP32", input]) - elif isinstance(input, list): - pieces.extend(["VALUES", len(input)]) - pieces.extend(input) - else: - pieces.extend(["ELE", input]) - - if with_scores: - pieces.append("WITHSCORES") - options[CallbacksOptions.WITHSCORES.value] = True - - if count: - pieces.extend(["COUNT", count]) - - if epsilon: - pieces.extend(["EPSILON", epsilon]) - - if ef: - pieces.extend(["EF", ef]) - - if filter: - pieces.extend(["FILTER", filter]) - - if filter_ef: - pieces.extend(["FILTER-EF", filter_ef]) - - if truth: - pieces.append("TRUTH") - - if no_thread: - pieces.append("NOTHREAD") - - return self.execute_command(VSIM_CMD, key, *pieces, **options) - - def vdim(self, key: KeyT) -> Union[Awaitable[int], int]: - """ - Get the dimension of a vector set. - - In the case of vectors that were populated using the `REDUCE` - option, for random projection, the vector set will report the size of - the projected (reduced) dimension. - - Raises `redis.exceptions.ResponseError` if the vector set doesn't exist. - - For more information see https://redis.io/commands/vdim - """ - return self.execute_command(VDIM_CMD, key) - - def vcard(self, key: KeyT) -> Union[Awaitable[int], int]: - """ - Get the cardinality(the number of elements) of a vector set with key ``key``. - - Raises `redis.exceptions.ResponseError` if the vector set doesn't exist. - - For more information see https://redis.io/commands/vcard - """ - return self.execute_command(VCARD_CMD, key) - - def vrem(self, key: KeyT, element: str) -> Union[Awaitable[int], int]: - """ - Remove an element from a vector set. - - For more information see https://redis.io/commands/vrem - """ - return self.execute_command(VREM_CMD, key, element) - - def vemb( - self, key: KeyT, element: str, raw: Optional[bool] = False - ) -> Union[ - Awaitable[Optional[Union[List[EncodableT], Dict[str, EncodableT]]]], - Optional[Union[List[EncodableT], Dict[str, EncodableT]]], - ]: - """ - Get the approximated vector of an element ``element`` from vector set ``key``. - - ``raw`` is a boolean flag that indicates whether to return the - interal representation used by the vector. 
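The vector-set wrappers being deleted here (`vadd`, `vsim`, `vdim`, `vcard`) are easiest to read alongside a usage sketch. This assumes a Redis build with vector-set support and that the client exposes the module as `r.vset()`, as in the redis-py releases that ship these files; keys, vectors and attributes are illustrative.

```python
import redis
from redis.commands.vectorset.commands import QuantizationOptions

r = redis.Redis(decode_responses=True)
vs = r.vset()  # module accessor on redis-py versions that still include vectorset

# VADD with the VALUES form (a bytes payload would be sent as FP32 instead).
vs.vadd(
    "docs:embeddings",
    [0.12, -0.03, 0.88, 0.41],
    "doc:1",
    quantization=QuantizationOptions.Q8,
    attributes={"title": "intro", "lang": "en"},
)

print(vs.vdim("docs:embeddings"))    # dimension of the stored vectors
print(vs.vcard("docs:embeddings"))   # number of elements in the set

# Nearest neighbours of a query vector, with similarity scores.
print(vs.vsim("docs:embeddings", [0.10, -0.05, 0.90, 0.40],
              with_scores=True, count=3))

# Or compare against an element already in the set.
print(vs.vsim("docs:embeddings", "doc:1", count=3))
```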
- - - For more information see https://redis.io/commands/vembed - """ - options = {} - pieces = [] - pieces.extend([key, element]) - - if get_protocol_version(self.client) in ["3", 3]: - options[CallbacksOptions.RESP3.value] = True - - if raw: - pieces.append("RAW") - - options[NEVER_DECODE] = True - if ( - hasattr(self.client, "connection_pool") - and self.client.connection_pool.connection_kwargs["decode_responses"] - ) or ( - hasattr(self.client, "nodes_manager") - and self.client.nodes_manager.connection_kwargs["decode_responses"] - ): - # allow decoding in the postprocessing callback - # if the user set decode_responses=True - # in the connection pool - options[CallbacksOptions.ALLOW_DECODING.value] = True - - options[CallbacksOptions.RAW.value] = True - - return self.execute_command(VEMB_CMD, *pieces, **options) - - def vlinks( - self, key: KeyT, element: str, with_scores: Optional[bool] = False - ) -> Union[ - Awaitable[ - Optional[ - List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]] - ] - ], - Optional[List[Union[List[Union[str, bytes]], Dict[Union[str, bytes], Number]]]], - ]: - """ - Returns the neighbors for each level the element ``element`` exists in the vector set ``key``. - - The result is a list of lists, where each list contains the neighbors for one level. - If the element does not exist, or if the vector set does not exist, None is returned. - - If the ``WITHSCORES`` option is provided, the result is a list of dicts, - where each dict contains the neighbors for one level, with the scores as values. - - For more information see https://redis.io/commands/vlinks - """ - options = {} - pieces = [] - pieces.extend([key, element]) - - if with_scores: - pieces.append("WITHSCORES") - options[CallbacksOptions.WITHSCORES.value] = True - - return self.execute_command(VLINKS_CMD, *pieces, **options) - - def vinfo(self, key: KeyT) -> Union[Awaitable[dict], dict]: - """ - Get information about a vector set. - - For more information see https://redis.io/commands/vinfo - """ - return self.execute_command(VINFO_CMD, key) - - def vsetattr( - self, key: KeyT, element: str, attributes: Optional[Union[dict, str]] = None - ) -> Union[Awaitable[int], int]: - """ - Associate or remove JSON attributes ``attributes`` of element ``element`` - for vector set ``key``. - - For more information see https://redis.io/commands/vsetattr - """ - if attributes is None: - attributes_json = "{}" - elif isinstance(attributes, dict): - # transform attributes to json string - attributes_json = json.dumps(attributes) - else: - attributes_json = attributes - - return self.execute_command(VSETATTR_CMD, key, element, attributes_json) - - def vgetattr( - self, key: KeyT, element: str - ) -> Union[Optional[Awaitable[dict]], Optional[dict]]: - """ - Retrieve the JSON attributes of an element ``elemet`` for vector set ``key``. - - If the element does not exist, or if the vector set does not exist, None is - returned. - - For more information see https://redis.io/commands/vgetattr - """ - return self.execute_command(VGETATTR_CMD, key, element) - - def vrandmember( - self, key: KeyT, count: Optional[int] = None - ) -> Union[ - Awaitable[Optional[Union[List[str], str]]], Optional[Union[List[str], str]] - ]: - """ - Returns random elements from a vector set ``key``. - - ``count`` is the number of elements to return. - If ``count`` is not provided, a single element is returned as a single string. 
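Continuing the same illustrative set, the per-element helpers removed in this hunk (`vemb`, `vlinks`, `vsetattr`, `vgetattr`):

```python
import redis

vs = redis.Redis(decode_responses=True).vset()

# Attach JSON attributes to an element, read them back, then clear them.
vs.vsetattr("docs:embeddings", "doc:1", {"title": "intro", "stars": 5})
print(vs.vgetattr("docs:embeddings", "doc:1"))
vs.vsetattr("docs:embeddings", "doc:1", None)   # sends "{}" and removes them

# Approximated stored vector and the element's neighbour lists per level.
print(vs.vemb("docs:embeddings", "doc:1"))
print(vs.vlinks("docs:embeddings", "doc:1", with_scores=True))
```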
- If ``count`` is positive(smaller than the number of elements - in the vector set), the command returns a list with up to ``count`` - distinct elements from the vector set - If ``count`` is negative, the command returns a list with ``count`` random elements, - potentially with duplicates. - If ``count`` is greater than the number of elements in the vector set, - only the entire set is returned as a list. - - If the vector set does not exist, ``None`` is returned. - - For more information see https://redis.io/commands/vrandmember - """ - pieces = [] - pieces.append(key) - if count is not None: - pieces.append(count) - return self.execute_command(VRANDMEMBER_CMD, *pieces) diff --git a/venv/lib/python3.12/site-packages/redis/commands/vectorset/utils.py b/venv/lib/python3.12/site-packages/redis/commands/vectorset/utils.py deleted file mode 100644 index ed6d194..0000000 --- a/venv/lib/python3.12/site-packages/redis/commands/vectorset/utils.py +++ /dev/null @@ -1,94 +0,0 @@ -from redis._parsers.helpers import pairs_to_dict -from redis.commands.vectorset.commands import CallbacksOptions - - -def parse_vemb_result(response, **options): - """ - Handle VEMB result since the command can returning different result - structures depending on input options and on quantization type of the vector set. - - Parsing VEMB result into: - - List[Union[bytes, Union[int, float]]] - - Dict[str, Union[bytes, str, float]] - """ - if response is None: - return response - - if options.get(CallbacksOptions.RAW.value): - result = {} - result["quantization"] = ( - response[0].decode("utf-8") - if options.get(CallbacksOptions.ALLOW_DECODING.value) - else response[0] - ) - result["raw"] = response[1] - result["l2"] = float(response[2]) - if len(response) > 3: - result["range"] = float(response[3]) - return result - else: - if options.get(CallbacksOptions.RESP3.value): - return response - - result = [] - for i in range(len(response)): - try: - result.append(int(response[i])) - except ValueError: - # if the value is not an integer, it should be a float - result.append(float(response[i])) - - return result - - -def parse_vlinks_result(response, **options): - """ - Handle VLINKS result since the command can be returning different result - structures depending on input options. - Parsing VLINKS result into: - - List[List[str]] - - List[Dict[str, Number]] - """ - if response is None: - return response - - if options.get(CallbacksOptions.WITHSCORES.value): - result = [] - # Redis will return a list of list of strings. - # This list have to be transformed to list of dicts - for level_item in response: - level_data_dict = {} - for key, value in pairs_to_dict(level_item).items(): - value = float(value) - level_data_dict[key] = value - result.append(level_data_dict) - return result - else: - # return the list of elements for each level - # list of lists - return response - - -def parse_vsim_result(response, **options): - """ - Handle VSIM result since the command can be returning different result - structures depending on input options. - Parsing VSIM result into: - - List[List[str]] - - List[Dict[str, Number]] - """ - if response is None: - return response - - if options.get(CallbacksOptions.WITHSCORES.value): - # Redis will return a list of list of pairs. 
- # This list have to be transformed to dict - result_dict = {} - for key, value in pairs_to_dict(response).items(): - value = float(value) - result_dict[key] = value - return result_dict - else: - # return the list of elements for each level - # list of lists - return response diff --git a/venv/lib/python3.12/site-packages/redis/compat.py b/venv/lib/python3.12/site-packages/redis/compat.py new file mode 100644 index 0000000..e478493 --- /dev/null +++ b/venv/lib/python3.12/site-packages/redis/compat.py @@ -0,0 +1,6 @@ +# flake8: noqa +try: + from typing import Literal, Protocol, TypedDict # lgtm [py/unused-import] +except ImportError: + from typing_extensions import Literal # lgtm [py/unused-import] + from typing_extensions import Protocol, TypedDict diff --git a/venv/lib/python3.12/site-packages/redis/connection.py b/venv/lib/python3.12/site-packages/redis/connection.py index 47cb589..b39ba28 100644 --- a/venv/lib/python3.12/site-packages/redis/connection.py +++ b/venv/lib/python3.12/site-packages/redis/connection.py @@ -1,37 +1,26 @@ import copy import os import socket +import ssl import sys import threading -import time import weakref from abc import abstractmethod from itertools import chain from queue import Empty, Full, LifoQueue -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from time import time +from typing import Any, Callable, List, Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse -from redis.cache import ( - CacheEntry, - CacheEntryStatus, - CacheFactory, - CacheFactoryInterface, - CacheInterface, - CacheKey, -) - from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser -from .auth.token import TokenInterface from .backoff import NoBackoff from .credentials import CredentialProvider, UsernamePasswordCredentialProvider -from .event import AfterConnectionReleasedEvent, EventDispatcher from .exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, ChildDeadlockedError, ConnectionError, DataError, - MaxConnectionsError, RedisError, ResponseError, TimeoutError, @@ -40,20 +29,12 @@ from .retry import Retry from .utils import ( CRYPTOGRAPHY_AVAILABLE, HIREDIS_AVAILABLE, + HIREDIS_PACK_AVAILABLE, SSL_AVAILABLE, - compare_versions, - deprecated_args, - ensure_string, - format_error_message, get_lib_version, str_if_bytes, ) -if SSL_AVAILABLE: - import ssl -else: - ssl = None - if HIREDIS_AVAILABLE: import hiredis @@ -142,88 +123,7 @@ class PythonRespSerializer: return output -class ConnectionInterface: - @abstractmethod - def repr_pieces(self): - pass - - @abstractmethod - def register_connect_callback(self, callback): - pass - - @abstractmethod - def deregister_connect_callback(self, callback): - pass - - @abstractmethod - def set_parser(self, parser_class): - pass - - @abstractmethod - def get_protocol(self): - pass - - @abstractmethod - def connect(self): - pass - - @abstractmethod - def on_connect(self): - pass - - @abstractmethod - def disconnect(self, *args): - pass - - @abstractmethod - def check_health(self): - pass - - @abstractmethod - def send_packed_command(self, command, check_health=True): - pass - - @abstractmethod - def send_command(self, *args, **kwargs): - pass - - @abstractmethod - def can_read(self, timeout=0): - pass - - @abstractmethod - def read_response( - self, - disable_decoding=False, - *, - disconnect_on_error=True, - push_request=False, - ): - pass - - @abstractmethod - def pack_command(self, *args): - pass - - @abstractmethod - def pack_commands(self, commands): 
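The remaining hunks touch the connection layer. From the application side that layer is normally configured through a pool rather than by instantiating `Connection` directly; a generic sketch with illustrative values:

```python
import redis
from redis.backoff import ExponentialBackoff
from redis.retry import Retry

# Host, pool size and retry policy are illustrative.
pool = redis.ConnectionPool(
    host="localhost",
    port=6379,
    db=0,
    max_connections=20,
    socket_timeout=5,
    retry=Retry(ExponentialBackoff(), retries=3),
)
r = redis.Redis(connection_pool=pool)
r.ping()
```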
- pass - - @property - @abstractmethod - def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: - pass - - @abstractmethod - def set_re_auth_token(self, token: TokenInterface): - pass - - @abstractmethod - def re_auth(self): - pass - - -class AbstractConnection(ConnectionInterface): +class AbstractConnection: "Manages communication to and from a Redis server" def __init__( @@ -249,7 +149,6 @@ class AbstractConnection(ConnectionInterface): credential_provider: Optional[CredentialProvider] = None, protocol: Optional[int] = 2, command_packer: Optional[Callable[[], None]] = None, - event_dispatcher: Optional[EventDispatcher] = None, ): """ Initialize a new Connection. @@ -265,10 +164,6 @@ class AbstractConnection(ConnectionInterface): "1. 'password' and (optional) 'username'\n" "2. 'credential_provider'" ) - if event_dispatcher is None: - self._event_dispatcher = EventDispatcher() - else: - self._event_dispatcher = event_dispatcher self.pid = os.getpid() self.db = db self.client_name = client_name @@ -302,13 +197,11 @@ class AbstractConnection(ConnectionInterface): self.next_health_check = 0 self.redis_connect_func = redis_connect_func self.encoder = Encoder(encoding, encoding_errors, decode_responses) - self.handshake_metadata = None self._sock = None self._socket_read_size = socket_read_size self.set_parser(parser_class) self._connect_callbacks = [] self._buffer_cutoff = 6000 - self._re_auth_token: Optional[TokenInterface] = None try: p = int(protocol) except TypeError: @@ -324,7 +217,7 @@ class AbstractConnection(ConnectionInterface): def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) - return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>" + return f"{self.__class__.__name__}<{repr_args}>" @abstractmethod def repr_pieces(self): @@ -339,29 +232,17 @@ class AbstractConnection(ConnectionInterface): def _construct_command_packer(self, packer): if packer is not None: return packer - elif HIREDIS_AVAILABLE: + elif HIREDIS_PACK_AVAILABLE: return HiredisRespSerializer() else: return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode) - def register_connect_callback(self, callback): - """ - Register a callback to be called when the connection is established either - initially or reconnected. This allows listeners to issue commands that - are ephemeral to the connection, for example pub/sub subscription or - key tracking. The callback must be a _method_ and will be kept as - a weak reference. - """ + def _register_connect_callback(self, callback): wm = weakref.WeakMethod(callback) if wm not in self._connect_callbacks: self._connect_callbacks.append(wm) - def deregister_connect_callback(self, callback): - """ - De-register a previously registered callback. It will no-longer receive - notifications on connection events. Calling this is not required when the - listener goes away, since the callbacks are kept as weak methods. 
- """ + def _deregister_connect_callback(self, callback): try: self._connect_callbacks.remove(weakref.WeakMethod(callback)) except ValueError: @@ -377,20 +258,12 @@ class AbstractConnection(ConnectionInterface): def connect(self): "Connects to the Redis server if not already connected" - self.connect_check_health(check_health=True) - - def connect_check_health( - self, check_health: bool = True, retry_socket_connect: bool = True - ): if self._sock: return try: - if retry_socket_connect: - sock = self.retry.call_with_retry( - lambda: self._connect(), lambda error: self.disconnect(error) - ) - else: - sock = self._connect() + sock = self.retry.call_with_retry( + lambda: self._connect(), lambda error: self.disconnect(error) + ) except socket.timeout: raise TimeoutError("Timeout connecting to server") except OSError as e: @@ -400,7 +273,7 @@ class AbstractConnection(ConnectionInterface): try: if self.redis_connect_func is None: # Use the default on_connect function - self.on_connect_check_health(check_health=check_health) + self.on_connect() else: # Use the passed function redis_connect_func self.redis_connect_func(self) @@ -426,13 +299,11 @@ class AbstractConnection(ConnectionInterface): def _host_error(self): pass + @abstractmethod def _error_message(self, exception): - return format_error_message(self._host_error(), exception) + pass def on_connect(self): - self.on_connect_check_health(check_health=True) - - def on_connect_check_health(self, check_health: bool = True): "Initialize the connection, authenticate and select a database" self._parser.on_connect(self) parser = self._parser @@ -456,12 +327,8 @@ class AbstractConnection(ConnectionInterface): self._parser.on_connect(self) if len(auth_args) == 1: auth_args = ["default", auth_args[0]] - # avoid checking health here -- PING will fail if we try - # to check the health prior to the AUTH - self.send_command( - "HELLO", self.protocol, "AUTH", *auth_args, check_health=False - ) - self.handshake_metadata = self.read_response() + self.send_command("HELLO", self.protocol, "AUTH", *auth_args) + response = self.read_response() # if response.get(b"proto") != self.protocol and response.get( # "proto" # ) != self.protocol: @@ -491,55 +358,38 @@ class AbstractConnection(ConnectionInterface): # update cluster exception classes self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES self._parser.on_connect(self) - self.send_command("HELLO", self.protocol, check_health=check_health) - self.handshake_metadata = self.read_response() + self.send_command("HELLO", self.protocol) + response = self.read_response() if ( - self.handshake_metadata.get(b"proto") != self.protocol - and self.handshake_metadata.get("proto") != self.protocol + response.get(b"proto") != self.protocol + and response.get("proto") != self.protocol ): raise ConnectionError("Invalid RESP version") # if a client_name is given, set it if self.client_name: - self.send_command( - "CLIENT", - "SETNAME", - self.client_name, - check_health=check_health, - ) + self.send_command("CLIENT", "SETNAME", self.client_name) if str_if_bytes(self.read_response()) != "OK": raise ConnectionError("Error setting client name") try: # set the library name and version if self.lib_name: - self.send_command( - "CLIENT", - "SETINFO", - "LIB-NAME", - self.lib_name, - check_health=check_health, - ) + self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name) self.read_response() except ResponseError: pass try: if self.lib_version: - self.send_command( - "CLIENT", - "SETINFO", - "LIB-VER", - self.lib_version, - 
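The `on_connect` path above drives the HELLO/AUTH handshake, CLIENT SETNAME and SELECT. From the caller's perspective these are all controlled by constructor arguments; the credentials and names below are illustrative.

```python
import redis

r = redis.Redis(
    host="localhost",
    port=6379,
    protocol=3,              # RESP3: handshake via HELLO 3
    client_name="worker-1",  # triggers CLIENT SETNAME after the handshake
    username="default",
    password="secret",       # sent with HELLO ... AUTH (or legacy AUTH on RESP2)
    db=1,                    # issues SELECT 1 once connected
)
r.ping()
```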
check_health=check_health, - ) + self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version) self.read_response() except ResponseError: pass # if a database is specified, switch to it if self.db: - self.send_command("SELECT", self.db, check_health=check_health) + self.send_command("SELECT", self.db) if str_if_bytes(self.read_response()) != "OK": raise ConnectionError("Invalid Database") @@ -555,7 +405,7 @@ class AbstractConnection(ConnectionInterface): if os.getpid() == self.pid: try: conn_sock.shutdown(socket.SHUT_RDWR) - except (OSError, TypeError): + except OSError: pass try: @@ -575,13 +425,13 @@ class AbstractConnection(ConnectionInterface): def check_health(self): """Check the health of the connection with a PING/PONG""" - if self.health_check_interval and time.monotonic() > self.next_health_check: + if self.health_check_interval and time() > self.next_health_check: self.retry.call_with_retry(self._send_ping, self._ping_failed) def send_packed_command(self, command, check_health=True): """Send an already packed command to the Redis server""" if not self._sock: - self.connect_check_health(check_health=False) + self.connect() # guard against health check recursion if check_health: self.check_health() @@ -642,7 +492,7 @@ class AbstractConnection(ConnectionInterface): host_error = self._host_error() try: - if self.protocol in ["3", 3]: + if self.protocol in ["3", 3] and not HIREDIS_AVAILABLE: response = self._parser.read_response( disable_decoding=disable_decoding, push_request=push_request ) @@ -655,7 +505,9 @@ class AbstractConnection(ConnectionInterface): except OSError as e: if disconnect_on_error: self.disconnect() - raise ConnectionError(f"Error while reading from {host_error} : {e.args}") + raise ConnectionError( + f"Error while reading from {host_error}" f" : {e.args}" + ) except BaseException: # Also by default close in case of BaseException. A lot of code # relies on this behaviour when doing Command/Response pairs. 
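The `check_health` path above only runs when a health-check interval is configured on the client; a minimal sketch:

```python
import redis

# PING the connection before reuse once it has been idle for more than
# 30 seconds; an interval of 0 (the default) disables the check.
r = redis.Redis(host="localhost", port=6379, health_check_interval=30)
r.set("greeting", "hello")
print(r.get("greeting"))
```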
@@ -665,7 +517,7 @@ class AbstractConnection(ConnectionInterface): raise if self.health_check_interval: - self.next_health_check = time.monotonic() + self.health_check_interval + self.next_health_check = time() + self.health_check_interval if isinstance(response, ResponseError): try: @@ -708,30 +560,6 @@ class AbstractConnection(ConnectionInterface): output.append(SYM_EMPTY.join(pieces)) return output - def get_protocol(self) -> Union[int, str]: - return self.protocol - - @property - def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: - return self._handshake_metadata - - @handshake_metadata.setter - def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]): - self._handshake_metadata = value - - def set_re_auth_token(self, token: TokenInterface): - self._re_auth_token = token - - def re_auth(self): - if self._re_auth_token is not None: - self.send_command( - "AUTH", - self._re_auth_token.try_get("oid"), - self._re_auth_token.get_value(), - ) - self.read_response() - self._re_auth_token = None - class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" @@ -793,10 +621,6 @@ class Connection(AbstractConnection): except OSError as _: err = _ if sock is not None: - try: - sock.shutdown(socket.SHUT_RDWR) # ensure a clean close - except OSError: - pass sock.close() if err is not None: @@ -806,219 +630,26 @@ class Connection(AbstractConnection): def _host_error(self): return f"{self.host}:{self.port}" + def _error_message(self, exception): + # args for socket.error can either be (errno, "message") + # or just "message" -class CacheProxyConnection(ConnectionInterface): - DUMMY_CACHE_VALUE = b"foo" - MIN_ALLOWED_VERSION = "7.4.0" - DEFAULT_SERVER_NAME = "redis" + host_error = self._host_error() - def __init__( - self, - conn: ConnectionInterface, - cache: CacheInterface, - pool_lock: threading.RLock, - ): - self.pid = os.getpid() - self._conn = conn - self.retry = self._conn.retry - self.host = self._conn.host - self.port = self._conn.port - self.credential_provider = conn.credential_provider - self._pool_lock = pool_lock - self._cache = cache - self._cache_lock = threading.RLock() - self._current_command_cache_key = None - self._current_options = None - self.register_connect_callback(self._enable_tracking_callback) - - def repr_pieces(self): - return self._conn.repr_pieces() - - def register_connect_callback(self, callback): - self._conn.register_connect_callback(callback) - - def deregister_connect_callback(self, callback): - self._conn.deregister_connect_callback(callback) - - def set_parser(self, parser_class): - self._conn.set_parser(parser_class) - - def connect(self): - self._conn.connect() - - server_name = self._conn.handshake_metadata.get(b"server", None) - if server_name is None: - server_name = self._conn.handshake_metadata.get("server", None) - server_ver = self._conn.handshake_metadata.get(b"version", None) - if server_ver is None: - server_ver = self._conn.handshake_metadata.get("version", None) - if server_ver is None or server_ver is None: - raise ConnectionError("Cannot retrieve information about server version") - - server_ver = ensure_string(server_ver) - server_name = ensure_string(server_name) - - if ( - server_name != self.DEFAULT_SERVER_NAME - or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1 - ): - raise ConnectionError( - "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later" # noqa: E501 - ) - - def on_connect(self): - 
self._conn.on_connect() - - def disconnect(self, *args): - with self._cache_lock: - self._cache.flush() - self._conn.disconnect(*args) - - def check_health(self): - self._conn.check_health() - - def send_packed_command(self, command, check_health=True): - # TODO: Investigate if it's possible to unpack command - # or extract keys from packed command - self._conn.send_packed_command(command) - - def send_command(self, *args, **kwargs): - self._process_pending_invalidations() - - with self._cache_lock: - # Command is write command or not allowed - # to be cached. - if not self._cache.is_cachable(CacheKey(command=args[0], redis_keys=())): - self._current_command_cache_key = None - self._conn.send_command(*args, **kwargs) - return - - if kwargs.get("keys") is None: - raise ValueError("Cannot create cache key.") - - # Creates cache key. - self._current_command_cache_key = CacheKey( - command=args[0], redis_keys=tuple(kwargs.get("keys")) - ) - - with self._cache_lock: - # We have to trigger invalidation processing in case if - # it was cached by another connection to avoid - # queueing invalidations in stale connections. - if self._cache.get(self._current_command_cache_key): - entry = self._cache.get(self._current_command_cache_key) - - if entry.connection_ref != self._conn: - with self._pool_lock: - while entry.connection_ref.can_read(): - entry.connection_ref.read_response(push_request=True) - - return - - # Set temporary entry value to prevent - # race condition from another connection. - self._cache.set( - CacheEntry( - cache_key=self._current_command_cache_key, - cache_value=self.DUMMY_CACHE_VALUE, - status=CacheEntryStatus.IN_PROGRESS, - connection_ref=self._conn, + if len(exception.args) == 1: + try: + return f"Error connecting to {host_error}. \ + {exception.args[0]}." + except AttributeError: + return f"Connection Error: {exception.args[0]}" + else: + try: + return ( + f"Error {exception.args[0]} connecting to " + f"{host_error}. {exception.args[1]}." ) - ) - - # Send command over socket only if it's allowed - # read-only command that not yet cached. - self._conn.send_command(*args, **kwargs) - - def can_read(self, timeout=0): - return self._conn.can_read(timeout) - - def read_response( - self, disable_decoding=False, *, disconnect_on_error=True, push_request=False - ): - with self._cache_lock: - # Check if command response exists in a cache and it's not in progress. - if ( - self._current_command_cache_key is not None - and self._cache.get(self._current_command_cache_key) is not None - and self._cache.get(self._current_command_cache_key).status - != CacheEntryStatus.IN_PROGRESS - ): - res = copy.deepcopy( - self._cache.get(self._current_command_cache_key).cache_value - ) - self._current_command_cache_key = None - return res - - response = self._conn.read_response( - disable_decoding=disable_decoding, - disconnect_on_error=disconnect_on_error, - push_request=push_request, - ) - - with self._cache_lock: - # Prevent not-allowed command from caching. - if self._current_command_cache_key is None: - return response - # If response is None prevent from caching. - if response is None: - self._cache.delete_by_cache_keys([self._current_command_cache_key]) - return response - - cache_entry = self._cache.get(self._current_command_cache_key) - - # Cache only responses that still valid - # and wasn't invalidated by another connection in meantime. 
- if cache_entry is not None: - cache_entry.status = CacheEntryStatus.VALID - cache_entry.cache_value = response - self._cache.set(cache_entry) - - self._current_command_cache_key = None - - return response - - def pack_command(self, *args): - return self._conn.pack_command(*args) - - def pack_commands(self, commands): - return self._conn.pack_commands(commands) - - @property - def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]: - return self._conn.handshake_metadata - - def _connect(self): - self._conn._connect() - - def _host_error(self): - self._conn._host_error() - - def _enable_tracking_callback(self, conn: ConnectionInterface) -> None: - conn.send_command("CLIENT", "TRACKING", "ON") - conn.read_response() - conn._parser.set_invalidation_push_handler(self._on_invalidation_callback) - - def _process_pending_invalidations(self): - while self.can_read(): - self._conn.read_response(push_request=True) - - def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]): - with self._cache_lock: - # Flush cache when DB flushed on server-side - if data[1] is None: - self._cache.flush() - else: - self._cache.delete_by_redis_keys(data[1]) - - def get_protocol(self): - return self._conn.get_protocol() - - def set_re_auth_token(self, token: TokenInterface): - self._conn.set_re_auth_token(token) - - def re_auth(self): - self._conn.re_auth() + except AttributeError: + return f"Connection Error: {exception.args[0]}" class SSLConnection(Connection): @@ -1034,15 +665,13 @@ class SSLConnection(Connection): ssl_cert_reqs="required", ssl_ca_certs=None, ssl_ca_data=None, - ssl_check_hostname=True, + ssl_check_hostname=False, ssl_ca_path=None, ssl_password=None, ssl_validate_ocsp=False, ssl_validate_ocsp_stapled=False, ssl_ocsp_context=None, ssl_ocsp_expected_cert=None, - ssl_min_version=None, - ssl_ciphers=None, **kwargs, ): """Constructor @@ -1050,7 +679,7 @@ class SSLConnection(Connection): Args: ssl_keyfile: Path to an ssl private key. Defaults to None. ssl_certfile: Path to an ssl certificate. Defaults to None. - ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required), or an ssl.VerifyMode. Defaults to "required". + ssl_cert_reqs: The string value for the SSLContext.verify_mode (none, optional, required). Defaults to "required". ssl_ca_certs: The path to a file of concatenated CA certificates in PEM format. Defaults to None. ssl_ca_data: Either an ASCII string of one or more PEM-encoded certificates or a bytes-like object of DER-encoded certificates. ssl_check_hostname: If set, match the hostname during the SSL handshake. Defaults to False. @@ -1061,8 +690,6 @@ class SSLConnection(Connection): ssl_validate_ocsp_stapled: If set, perform a validation on a stapled ocsp response ssl_ocsp_context: A fully initialized OpenSSL.SSL.Context object to be used in verifying the ssl_ocsp_expected_cert ssl_ocsp_expected_cert: A PEM armoured string containing the expected certificate to be returned from the ocsp verification service. - ssl_min_version: The lowest supported SSL version. It affects the supported SSL versions of the SSLContext. None leaves the default provided by ssl module. - ssl_ciphers: A string listing the ciphers that are allowed to be used. Defaults to None, which means that the default ciphers are used. See https://docs.python.org/3/library/ssl.html#ssl.SSLContext.set_ciphers for more information. 
Raises: RedisError @@ -1075,7 +702,7 @@ class SSLConnection(Connection): if ssl_cert_reqs is None: ssl_cert_reqs = ssl.CERT_NONE elif isinstance(ssl_cert_reqs, str): - CERT_REQS = { # noqa: N806 + CERT_REQS = { "none": ssl.CERT_NONE, "optional": ssl.CERT_OPTIONAL, "required": ssl.CERT_REQUIRED, @@ -1089,39 +716,17 @@ class SSLConnection(Connection): self.ca_certs = ssl_ca_certs self.ca_data = ssl_ca_data self.ca_path = ssl_ca_path - self.check_hostname = ( - ssl_check_hostname if self.cert_reqs != ssl.CERT_NONE else False - ) + self.check_hostname = ssl_check_hostname self.certificate_password = ssl_password self.ssl_validate_ocsp = ssl_validate_ocsp self.ssl_validate_ocsp_stapled = ssl_validate_ocsp_stapled self.ssl_ocsp_context = ssl_ocsp_context self.ssl_ocsp_expected_cert = ssl_ocsp_expected_cert - self.ssl_min_version = ssl_min_version - self.ssl_ciphers = ssl_ciphers super().__init__(**kwargs) def _connect(self): - """ - Wrap the socket with SSL support, handling potential errors. - """ + "Wrap the socket with SSL support" sock = super()._connect() - try: - return self._wrap_socket_with_ssl(sock) - except (OSError, RedisError): - sock.close() - raise - - def _wrap_socket_with_ssl(self, sock): - """ - Wraps the socket with SSL support. - - Args: - sock: The plain socket to wrap with SSL. - - Returns: - An SSL wrapped socket. - """ context = ssl.create_default_context() context.check_hostname = self.check_hostname context.verify_mode = self.cert_reqs @@ -1139,10 +744,7 @@ class SSLConnection(Connection): context.load_verify_locations( cafile=self.ca_certs, capath=self.ca_path, cadata=self.ca_data ) - if self.ssl_min_version is not None: - context.minimum_version = self.ssl_min_version - if self.ssl_ciphers: - context.set_ciphers(self.ssl_ciphers) + sslsock = context.wrap_socket(sock, server_hostname=self.host) if self.ssl_validate_ocsp is True and CRYPTOGRAPHY_AVAILABLE is False: raise RedisError("cryptography is not installed.") @@ -1152,8 +754,6 @@ class SSLConnection(Connection): "- not both." ) - sslsock = context.wrap_socket(sock, server_hostname=self.host) - # validation for the stapled case if self.ssl_validate_ocsp_stapled: import OpenSSL @@ -1196,9 +796,9 @@ class UnixDomainSocketConnection(AbstractConnection): "Manages UDS communication to and from a Redis server" def __init__(self, path="", socket_timeout=None, **kwargs): - super().__init__(**kwargs) self.path = path self.socket_timeout = socket_timeout + super().__init__(**kwargs) def repr_pieces(self): pieces = [("path", self.path), ("db", self.db)] @@ -1210,22 +810,27 @@ class UnixDomainSocketConnection(AbstractConnection): "Create a Unix domain socket connection" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock.settimeout(self.socket_connect_timeout) - try: - sock.connect(self.path) - except OSError: - # Prevent ResourceWarnings for unclosed sockets. - try: - sock.shutdown(socket.SHUT_RDWR) # ensure a clean close - except OSError: - pass - sock.close() - raise + sock.connect(self.path) sock.settimeout(self.socket_timeout) return sock def _host_error(self): return self.path + def _error_message(self, exception): + # args for socket.error can either be (errno, "message") + # or just "message" + host_error = self._host_error() + if len(exception.args) == 1: + return ( + f"Error connecting to unix socket: {host_error}. {exception.args[0]}." + ) + else: + return ( + f"Error {exception.args[0]} connecting to unix socket: " + f"{host_error}. {exception.args[1]}." 
+ ) + FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") @@ -1248,7 +853,6 @@ URL_QUERY_ARGUMENT_PARSERS = { "max_connections": int, "health_check_interval": int, "ssl_check_hostname": to_bool, - "timeout": float, } @@ -1274,7 +878,7 @@ def parse_url(url): try: kwargs[name] = parser(value) except (TypeError, ValueError): - raise ValueError(f"Invalid value for '{name}' in connection URL.") + raise ValueError(f"Invalid value for `{name}` in connection URL.") else: kwargs[name] = value @@ -1309,9 +913,6 @@ def parse_url(url): return kwargs -_CP = TypeVar("_CP", bound="ConnectionPool") - - class ConnectionPool: """ Create a connection pool. ``If max_connections`` is set, then this @@ -1321,14 +922,13 @@ class ConnectionPool: By default, TCP connections are created unless ``connection_class`` is specified. Use class:`.UnixDomainSocketConnection` for unix sockets. - :py:class:`~redis.SSLConnection` can be used for SSL enabled connections. Any additional keyword arguments are passed to the constructor of ``connection_class``. """ @classmethod - def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: + def from_url(cls, url, **kwargs): """ Return a connection pool configured from the given URL. @@ -1380,7 +980,6 @@ class ConnectionPool: self, connection_class=Connection, max_connections: Optional[int] = None, - cache_factory: Optional[CacheFactoryInterface] = None, **connection_kwargs, ): max_connections = max_connections or 2**31 @@ -1390,34 +989,6 @@ class ConnectionPool: self.connection_class = connection_class self.connection_kwargs = connection_kwargs self.max_connections = max_connections - self.cache = None - self._cache_factory = cache_factory - - if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"): - if connection_kwargs.get("protocol") not in [3, "3"]: - raise RedisError("Client caching is only supported with RESP version 3") - - cache = self.connection_kwargs.get("cache") - - if cache is not None: - if not isinstance(cache, CacheInterface): - raise ValueError("Cache must implement CacheInterface") - - self.cache = cache - else: - if self._cache_factory is not None: - self.cache = self._cache_factory.get_cache() - else: - self.cache = CacheFactory( - self.connection_kwargs.get("cache_config") - ).get_cache() - - connection_kwargs.pop("cache", None) - connection_kwargs.pop("cache_config", None) - - self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) - if self._event_dispatcher is None: - self._event_dispatcher = EventDispatcher() # a lock to protect the critical section in _checkpid(). # this lock is acquired when the process id changes, such as @@ -1427,29 +998,17 @@ class ConnectionPool: # object of this pool. subsequent threads acquiring this lock # will notice the first thread already did the work and simply # release the lock. - - self._fork_lock = threading.RLock() - self._lock = threading.RLock() - + self._fork_lock = threading.Lock() self.reset() - def __repr__(self) -> str: - conn_kwargs = ",".join([f"{k}={v}" for k, v in self.connection_kwargs.items()]) + def __repr__(self) -> (str, str): return ( - f"<{self.__class__.__module__}.{self.__class__.__name__}" - f"(<{self.connection_class.__module__}.{self.connection_class.__name__}" - f"({conn_kwargs})>)>" + f"{type(self).__name__}" + f"<{repr(self.connection_class(**self.connection_kwargs))}>" ) - def get_protocol(self): - """ - Returns: - The RESP protocol version, or ``None`` if the protocol is not specified, - in which case the server default will be used. 
- """ - return self.connection_kwargs.get("protocol", None) - def reset(self) -> None: + self._lock = threading.Lock() self._created_connections = 0 self._available_connections = [] self._in_use_connections = set() @@ -1512,14 +1071,8 @@ class ConnectionPool: finally: self._fork_lock.release() - @deprecated_args( - args_to_warn=["*"], - reason="Use get_connection() without args instead", - version="5.3.0", - ) - def get_connection(self, command_name=None, *keys, **options) -> "Connection": + def get_connection(self, command_name: str, *keys, **options) -> "Connection": "Get a connection from the pool" - self._checkpid() with self._lock: try: @@ -1536,9 +1089,9 @@ class ConnectionPool: # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. try: - if connection.can_read() and self.cache is None: + if connection.can_read(): raise ConnectionError("Connection has data") - except (ConnectionError, TimeoutError, OSError): + except (ConnectionError, OSError): connection.disconnect() connection.connect() if connection.can_read(): @@ -1560,17 +1113,11 @@ class ConnectionPool: decode_responses=kwargs.get("decode_responses", False), ) - def make_connection(self) -> "ConnectionInterface": + def make_connection(self) -> "Connection": "Create a new connection" if self._created_connections >= self.max_connections: - raise MaxConnectionsError("Too many connections") + raise ConnectionError("Too many connections") self._created_connections += 1 - - if self.cache is not None: - return CacheProxyConnection( - self.connection_class(**self.connection_kwargs), self.cache, self._lock - ) - return self.connection_class(**self.connection_kwargs) def release(self, connection: "Connection") -> None: @@ -1582,18 +1129,15 @@ class ConnectionPool: except KeyError: # Gracefully fail when a connection is returned to this pool # that the pool doesn't actually own - return + pass if self.owns_connection(connection): self._available_connections.append(connection) - self._event_dispatcher.dispatch( - AfterConnectionReleasedEvent(connection) - ) else: - # Pool doesn't own this connection, do not add it back - # to the pool. - # The created connections count should not be changed, - # because the connection was not created by the pool. + # pool doesn't own this connection. do not add it back + # to the pool and decrement the count so that another + # connection can take its place if needed + self._created_connections -= 1 connection.disconnect() return @@ -1624,36 +1168,13 @@ class ConnectionPool: """Close the pool, disconnecting all connections""" self.disconnect() - def set_retry(self, retry: Retry) -> None: + def set_retry(self, retry: "Retry") -> None: self.connection_kwargs.update({"retry": retry}) for conn in self._available_connections: conn.retry = retry for conn in self._in_use_connections: conn.retry = retry - def re_auth_callback(self, token: TokenInterface): - with self._lock: - for conn in self._available_connections: - conn.retry.call_with_retry( - lambda: conn.send_command( - "AUTH", token.try_get("oid"), token.get_value() - ), - lambda error: self._mock(error), - ) - conn.retry.call_with_retry( - lambda: conn.read_response(), lambda error: self._mock(error) - ) - for conn in self._in_use_connections: - conn.set_re_auth_token(token) - - async def _mock(self, error: RedisError): - """ - Dummy functions, needs to be passed as error callback to retry object. 
- :param error: - :return: - """ - pass - class BlockingConnectionPool(ConnectionPool): """ @@ -1731,21 +1252,11 @@ class BlockingConnectionPool(ConnectionPool): def make_connection(self): "Make a fresh connection." - if self.cache is not None: - connection = CacheProxyConnection( - self.connection_class(**self.connection_kwargs), self.cache, self._lock - ) - else: - connection = self.connection_class(**self.connection_kwargs) + connection = self.connection_class(**self.connection_kwargs) self._connections.append(connection) return connection - @deprecated_args( - args_to_warn=["*"], - reason="Use get_connection() without args instead", - version="5.3.0", - ) - def get_connection(self, command_name=None, *keys, **options): + def get_connection(self, command_name, *keys, **options): """ Get a connection, blocking for ``self.timeout`` until a connection is available from the pool. @@ -1785,7 +1296,7 @@ class BlockingConnectionPool(ConnectionPool): try: if connection.can_read(): raise ConnectionError("Connection has data") - except (ConnectionError, TimeoutError, OSError): + except (ConnectionError, OSError): connection.disconnect() connection.connect() if connection.can_read(): diff --git a/venv/lib/python3.12/site-packages/redis/credentials.py b/venv/lib/python3.12/site-packages/redis/credentials.py index 6e59454..7ba26dc 100644 --- a/venv/lib/python3.12/site-packages/redis/credentials.py +++ b/venv/lib/python3.12/site-packages/redis/credentials.py @@ -1,8 +1,4 @@ -import logging -from abc import ABC, abstractmethod -from typing import Any, Callable, Optional, Tuple, Union - -logger = logging.getLogger(__name__) +from typing import Optional, Tuple, Union class CredentialProvider: @@ -13,38 +9,6 @@ class CredentialProvider: def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]: raise NotImplementedError("get_credentials must be implemented") - async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]: - logger.warning( - "This method is added for backward compatability. " - "Please override it in your implementation." - ) - return self.get_credentials() - - -class StreamingCredentialProvider(CredentialProvider, ABC): - """ - Credential provider that streams credentials in the background. - """ - - @abstractmethod - def on_next(self, callback: Callable[[Any], None]): - """ - Specifies the callback that should be invoked - when the next credentials will be retrieved. 
- - :param callback: Callback with - :return: - """ - pass - - @abstractmethod - def on_error(self, callback: Callable[[Exception], None]): - pass - - @abstractmethod - def is_streaming(self) -> bool: - pass - class UsernamePasswordCredentialProvider(CredentialProvider): """ @@ -60,6 +24,3 @@ class UsernamePasswordCredentialProvider(CredentialProvider): if self.username: return self.username, self.password return (self.password,) - - async def get_credentials_async(self) -> Union[Tuple[str], Tuple[str, str]]: - return self.get_credentials() diff --git a/venv/lib/python3.12/site-packages/redis/event.py b/venv/lib/python3.12/site-packages/redis/event.py deleted file mode 100644 index b86c66b..0000000 --- a/venv/lib/python3.12/site-packages/redis/event.py +++ /dev/null @@ -1,394 +0,0 @@ -import asyncio -import threading -from abc import ABC, abstractmethod -from enum import Enum -from typing import List, Optional, Union - -from redis.auth.token import TokenInterface -from redis.credentials import CredentialProvider, StreamingCredentialProvider - - -class EventListenerInterface(ABC): - """ - Represents a listener for given event object. - """ - - @abstractmethod - def listen(self, event: object): - pass - - -class AsyncEventListenerInterface(ABC): - """ - Represents an async listener for given event object. - """ - - @abstractmethod - async def listen(self, event: object): - pass - - -class EventDispatcherInterface(ABC): - """ - Represents a dispatcher that dispatches events to listeners - associated with given event. - """ - - @abstractmethod - def dispatch(self, event: object): - pass - - @abstractmethod - async def dispatch_async(self, event: object): - pass - - -class EventException(Exception): - """ - Exception wrapper that adds an event object into exception context. - """ - - def __init__(self, exception: Exception, event: object): - self.exception = exception - self.event = event - super().__init__(exception) - - -class EventDispatcher(EventDispatcherInterface): - # TODO: Make dispatcher to accept external mappings. - def __init__(self): - """ - Mapping should be extended for any new events or listeners to be added. - """ - self._event_listeners_mapping = { - AfterConnectionReleasedEvent: [ - ReAuthConnectionListener(), - ], - AfterPooledConnectionsInstantiationEvent: [ - RegisterReAuthForPooledConnections() - ], - AfterSingleConnectionInstantiationEvent: [ - RegisterReAuthForSingleConnection() - ], - AfterPubSubConnectionInstantiationEvent: [RegisterReAuthForPubSub()], - AfterAsyncClusterInstantiationEvent: [RegisterReAuthForAsyncClusterNodes()], - AsyncAfterConnectionReleasedEvent: [ - AsyncReAuthConnectionListener(), - ], - } - - def dispatch(self, event: object): - listeners = self._event_listeners_mapping.get(type(event)) - - for listener in listeners: - listener.listen(event) - - async def dispatch_async(self, event: object): - listeners = self._event_listeners_mapping.get(type(event)) - - for listener in listeners: - await listener.listen(event) - - -class AfterConnectionReleasedEvent: - """ - Event that will be fired before each command execution. - """ - - def __init__(self, connection): - self._connection = connection - - @property - def connection(self): - return self._connection - - -class AsyncAfterConnectionReleasedEvent(AfterConnectionReleasedEvent): - pass - - -class ClientType(Enum): - SYNC = ("sync",) - ASYNC = ("async",) - - -class AfterPooledConnectionsInstantiationEvent: - """ - Event that will be fired after pooled connection instances was created. 
- """ - - def __init__( - self, - connection_pools: List, - client_type: ClientType, - credential_provider: Optional[CredentialProvider] = None, - ): - self._connection_pools = connection_pools - self._client_type = client_type - self._credential_provider = credential_provider - - @property - def connection_pools(self): - return self._connection_pools - - @property - def client_type(self) -> ClientType: - return self._client_type - - @property - def credential_provider(self) -> Union[CredentialProvider, None]: - return self._credential_provider - - -class AfterSingleConnectionInstantiationEvent: - """ - Event that will be fired after single connection instances was created. - - :param connection_lock: For sync client thread-lock should be provided, - for async asyncio.Lock - """ - - def __init__( - self, - connection, - client_type: ClientType, - connection_lock: Union[threading.RLock, asyncio.Lock], - ): - self._connection = connection - self._client_type = client_type - self._connection_lock = connection_lock - - @property - def connection(self): - return self._connection - - @property - def client_type(self) -> ClientType: - return self._client_type - - @property - def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]: - return self._connection_lock - - -class AfterPubSubConnectionInstantiationEvent: - def __init__( - self, - pubsub_connection, - connection_pool, - client_type: ClientType, - connection_lock: Union[threading.RLock, asyncio.Lock], - ): - self._pubsub_connection = pubsub_connection - self._connection_pool = connection_pool - self._client_type = client_type - self._connection_lock = connection_lock - - @property - def pubsub_connection(self): - return self._pubsub_connection - - @property - def connection_pool(self): - return self._connection_pool - - @property - def client_type(self) -> ClientType: - return self._client_type - - @property - def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]: - return self._connection_lock - - -class AfterAsyncClusterInstantiationEvent: - """ - Event that will be fired after async cluster instance was created. - - Async cluster doesn't use connection pools, - instead ClusterNode object manages connections. - """ - - def __init__( - self, - nodes: dict, - credential_provider: Optional[CredentialProvider] = None, - ): - self._nodes = nodes - self._credential_provider = credential_provider - - @property - def nodes(self) -> dict: - return self._nodes - - @property - def credential_provider(self) -> Union[CredentialProvider, None]: - return self._credential_provider - - -class ReAuthConnectionListener(EventListenerInterface): - """ - Listener that performs re-authentication of given connection. - """ - - def listen(self, event: AfterConnectionReleasedEvent): - event.connection.re_auth() - - -class AsyncReAuthConnectionListener(AsyncEventListenerInterface): - """ - Async listener that performs re-authentication of given connection. - """ - - async def listen(self, event: AsyncAfterConnectionReleasedEvent): - await event.connection.re_auth() - - -class RegisterReAuthForPooledConnections(EventListenerInterface): - """ - Listener that registers a re-authentication callback for pooled connections. - Required by :class:`StreamingCredentialProvider`. 
- """ - - def __init__(self): - self._event = None - - def listen(self, event: AfterPooledConnectionsInstantiationEvent): - if isinstance(event.credential_provider, StreamingCredentialProvider): - self._event = event - - if event.client_type == ClientType.SYNC: - event.credential_provider.on_next(self._re_auth) - event.credential_provider.on_error(self._raise_on_error) - else: - event.credential_provider.on_next(self._re_auth_async) - event.credential_provider.on_error(self._raise_on_error_async) - - def _re_auth(self, token): - for pool in self._event.connection_pools: - pool.re_auth_callback(token) - - async def _re_auth_async(self, token): - for pool in self._event.connection_pools: - await pool.re_auth_callback(token) - - def _raise_on_error(self, error: Exception): - raise EventException(error, self._event) - - async def _raise_on_error_async(self, error: Exception): - raise EventException(error, self._event) - - -class RegisterReAuthForSingleConnection(EventListenerInterface): - """ - Listener that registers a re-authentication callback for single connection. - Required by :class:`StreamingCredentialProvider`. - """ - - def __init__(self): - self._event = None - - def listen(self, event: AfterSingleConnectionInstantiationEvent): - if isinstance( - event.connection.credential_provider, StreamingCredentialProvider - ): - self._event = event - - if event.client_type == ClientType.SYNC: - event.connection.credential_provider.on_next(self._re_auth) - event.connection.credential_provider.on_error(self._raise_on_error) - else: - event.connection.credential_provider.on_next(self._re_auth_async) - event.connection.credential_provider.on_error( - self._raise_on_error_async - ) - - def _re_auth(self, token): - with self._event.connection_lock: - self._event.connection.send_command( - "AUTH", token.try_get("oid"), token.get_value() - ) - self._event.connection.read_response() - - async def _re_auth_async(self, token): - async with self._event.connection_lock: - await self._event.connection.send_command( - "AUTH", token.try_get("oid"), token.get_value() - ) - await self._event.connection.read_response() - - def _raise_on_error(self, error: Exception): - raise EventException(error, self._event) - - async def _raise_on_error_async(self, error: Exception): - raise EventException(error, self._event) - - -class RegisterReAuthForAsyncClusterNodes(EventListenerInterface): - def __init__(self): - self._event = None - - def listen(self, event: AfterAsyncClusterInstantiationEvent): - if isinstance(event.credential_provider, StreamingCredentialProvider): - self._event = event - event.credential_provider.on_next(self._re_auth) - event.credential_provider.on_error(self._raise_on_error) - - async def _re_auth(self, token: TokenInterface): - for key in self._event.nodes: - await self._event.nodes[key].re_auth_callback(token) - - async def _raise_on_error(self, error: Exception): - raise EventException(error, self._event) - - -class RegisterReAuthForPubSub(EventListenerInterface): - def __init__(self): - self._connection = None - self._connection_pool = None - self._client_type = None - self._connection_lock = None - self._event = None - - def listen(self, event: AfterPubSubConnectionInstantiationEvent): - if isinstance( - event.pubsub_connection.credential_provider, StreamingCredentialProvider - ) and event.pubsub_connection.get_protocol() in [3, "3"]: - self._event = event - self._connection = event.pubsub_connection - self._connection_pool = event.connection_pool - self._client_type = event.client_type - 
self._connection_lock = event.connection_lock - - if self._client_type == ClientType.SYNC: - self._connection.credential_provider.on_next(self._re_auth) - self._connection.credential_provider.on_error(self._raise_on_error) - else: - self._connection.credential_provider.on_next(self._re_auth_async) - self._connection.credential_provider.on_error( - self._raise_on_error_async - ) - - def _re_auth(self, token: TokenInterface): - with self._connection_lock: - self._connection.send_command( - "AUTH", token.try_get("oid"), token.get_value() - ) - self._connection.read_response() - - self._connection_pool.re_auth_callback(token) - - async def _re_auth_async(self, token: TokenInterface): - async with self._connection_lock: - await self._connection.send_command( - "AUTH", token.try_get("oid"), token.get_value() - ) - await self._connection.read_response() - - await self._connection_pool.re_auth_callback(token) - - def _raise_on_error(self, error: Exception): - raise EventException(error, self._event) - - async def _raise_on_error_async(self, error: Exception): - raise EventException(error, self._event) diff --git a/venv/lib/python3.12/site-packages/redis/exceptions.py b/venv/lib/python3.12/site-packages/redis/exceptions.py index 6434449..7cf15a7 100644 --- a/venv/lib/python3.12/site-packages/redis/exceptions.py +++ b/venv/lib/python3.12/site-packages/redis/exceptions.py @@ -79,24 +79,18 @@ class ModuleError(ResponseError): class LockError(RedisError, ValueError): "Errors acquiring or releasing a lock" - # NOTE: For backwards compatibility, this class derives from ValueError. # This was originally chosen to behave like threading.Lock. - - def __init__(self, message=None, lock_name=None): - self.message = message - self.lock_name = lock_name + pass class LockNotOwnedError(LockError): - "Error trying to extend or release a lock that is not owned (anymore)" - + "Error trying to extend or release a lock that is (no longer) owned" pass class ChildDeadlockedError(Exception): "Error indicating that a child process is deadlocked after a fork()" - pass @@ -221,27 +215,4 @@ class SlotNotCoveredError(RedisClusterException): class MaxConnectionsError(ConnectionError): - """ - Raised when a connection pool has reached its max_connections limit. - This indicates pool exhaustion rather than an actual connection failure. - """ - - pass - - -class CrossSlotTransactionError(RedisClusterException): - """ - Raised when a transaction or watch is triggered in a pipeline - and not all keys or all commands belong to the same slot. - """ - - pass - - -class InvalidPipelineStack(RedisClusterException): - """ - Raised on unexpected response length on pipelines. This is - most likely a handling error on the stack. - """ - - pass + ... 
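Illustrative sketch, not part of the diff above: in both the old and the new exceptions module, LockNotOwnedError remains a subclass of LockError, so callers that catch LockError behave the same across the two redis-py versions shown in this patch. The server address and the key name "reports:refresh" below are assumptions made for the example only.

    import redis
    from redis.exceptions import LockError

    r = redis.Redis(host="localhost", port=6379)  # assumed local test server

    # timeout is the lock TTL; blocking_timeout bounds how long acquire() waits.
    lock = r.lock("reports:refresh", timeout=10, blocking_timeout=2)
    try:
        with lock:
            pass  # critical section goes here
    except LockError as exc:
        # Raised when the lock cannot be acquired in time, or when release()
        # finds the lock expired or stolen (LockNotOwnedError is a LockError).
        print(f"lock problem: {exc}")

Either version of the library surfaces the failure through the same LockError hierarchy; only the message wording and the optional lock_name attribute differ.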
diff --git a/venv/lib/python3.12/site-packages/redis/lock.py b/venv/lib/python3.12/site-packages/redis/lock.py index 0288496..4cca102 100644 --- a/venv/lib/python3.12/site-packages/redis/lock.py +++ b/venv/lib/python3.12/site-packages/redis/lock.py @@ -1,4 +1,3 @@ -import logging import threading import time as mod_time import uuid @@ -8,8 +7,6 @@ from typing import Optional, Type from redis.exceptions import LockError, LockNotOwnedError from redis.typing import Number -logger = logging.getLogger(__name__) - class Lock: """ @@ -85,7 +82,6 @@ class Lock: blocking: bool = True, blocking_timeout: Optional[Number] = None, thread_local: bool = True, - raise_on_release_error: bool = True, ): """ Create a new Lock instance named ``name`` using the Redis client @@ -129,11 +125,6 @@ class Lock: thread-1 would see the token value as "xyz" and would be able to successfully release the thread-2's lock. - ``raise_on_release_error`` indicates whether to raise an exception when - the lock is no longer owned when exiting the context manager. By default, - this is True, meaning an exception will be raised. If False, the warning - will be logged and the exception will be suppressed. - In some use cases it's necessary to disable thread local storage. For example, if you have code where one thread acquires a lock and passes that lock instance to a worker thread to release later. If thread @@ -149,7 +140,6 @@ class Lock: self.blocking = blocking self.blocking_timeout = blocking_timeout self.thread_local = bool(thread_local) - self.raise_on_release_error = raise_on_release_error self.local = threading.local() if self.thread_local else SimpleNamespace() self.local.token = None self.register_scripts() @@ -167,10 +157,7 @@ class Lock: def __enter__(self) -> "Lock": if self.acquire(): return self - raise LockError( - "Unable to acquire lock within the time specified", - lock_name=self.name, - ) + raise LockError("Unable to acquire lock within the time specified") def __exit__( self, @@ -178,14 +165,7 @@ class Lock: exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: - try: - self.release() - except LockError: - if self.raise_on_release_error: - raise - logger.warning( - "Lock was unlocked or no longer owned when exiting context manager." - ) + self.release() def acquire( self, @@ -268,10 +248,7 @@ class Lock: """ expected_token = self.local.token if expected_token is None: - raise LockError( - "Cannot release a lock that's not owned or is already unlocked.", - lock_name=self.name, - ) + raise LockError("Cannot release an unlocked lock") self.local.token = None self.do_release(expected_token) @@ -279,12 +256,9 @@ class Lock: if not bool( self.lua_release(keys=[self.name], args=[expected_token], client=self.redis) ): - raise LockNotOwnedError( - "Cannot release a lock that's no longer owned", - lock_name=self.name, - ) + raise LockNotOwnedError("Cannot release a lock that's no longer owned") - def extend(self, additional_time: Number, replace_ttl: bool = False) -> bool: + def extend(self, additional_time: int, replace_ttl: bool = False) -> bool: """ Adds more time to an already acquired lock. @@ -296,12 +270,12 @@ class Lock: `additional_time`. 
""" if self.local.token is None: - raise LockError("Cannot extend an unlocked lock", lock_name=self.name) + raise LockError("Cannot extend an unlocked lock") if self.timeout is None: - raise LockError("Cannot extend a lock with no timeout", lock_name=self.name) + raise LockError("Cannot extend a lock with no timeout") return self.do_extend(additional_time, replace_ttl) - def do_extend(self, additional_time: Number, replace_ttl: bool) -> bool: + def do_extend(self, additional_time: int, replace_ttl: bool) -> bool: additional_time = int(additional_time * 1000) if not bool( self.lua_extend( @@ -310,10 +284,7 @@ class Lock: client=self.redis, ) ): - raise LockNotOwnedError( - "Cannot extend a lock that's no longer owned", - lock_name=self.name, - ) + raise LockNotOwnedError("Cannot extend a lock that's no longer owned") return True def reacquire(self) -> bool: @@ -321,12 +292,9 @@ class Lock: Resets a TTL of an already acquired lock back to a timeout value. """ if self.local.token is None: - raise LockError("Cannot reacquire an unlocked lock", lock_name=self.name) + raise LockError("Cannot reacquire an unlocked lock") if self.timeout is None: - raise LockError( - "Cannot reacquire a lock with no timeout", - lock_name=self.name, - ) + raise LockError("Cannot reacquire a lock with no timeout") return self.do_reacquire() def do_reacquire(self) -> bool: @@ -336,8 +304,5 @@ class Lock: keys=[self.name], args=[self.local.token, timeout], client=self.redis ) ): - raise LockNotOwnedError( - "Cannot reacquire a lock that's no longer owned", - lock_name=self.name, - ) + raise LockNotOwnedError("Cannot reacquire a lock that's no longer owned") return True diff --git a/venv/lib/python3.12/site-packages/redis/ocsp.py b/venv/lib/python3.12/site-packages/redis/ocsp.py index d69c914..b0420b4 100644 --- a/venv/lib/python3.12/site-packages/redis/ocsp.py +++ b/venv/lib/python3.12/site-packages/redis/ocsp.py @@ -15,7 +15,6 @@ from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey from cryptography.hazmat.primitives.hashes import SHA1, Hash from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat from cryptography.x509 import ocsp - from redis.exceptions import AuthorizationError, ConnectionError @@ -57,12 +56,12 @@ def _check_certificate(issuer_cert, ocsp_bytes, validate=True): if ocsp_response.response_status == ocsp.OCSPResponseStatus.SUCCESSFUL: if ocsp_response.certificate_status != ocsp.OCSPCertStatus.GOOD: raise ConnectionError( - f"Received an {str(ocsp_response.certificate_status).split('.')[1]} " + f'Received an {str(ocsp_response.certificate_status).split(".")[1]} ' "ocsp certificate status" ) else: raise ConnectionError( - "failed to retrieve a successful response from the ocsp responder" + "failed to retrieve a sucessful response from the ocsp responder" ) if ocsp_response.this_update >= datetime.datetime.now(): @@ -140,7 +139,7 @@ def _get_pubkey_hash(certificate): def ocsp_staple_verifier(con, ocsp_bytes, expected=None): - """An implementation of a function for set_ocsp_client_callback in PyOpenSSL. + """An implemention of a function for set_ocsp_client_callback in PyOpenSSL. This function validates that the provide ocsp_bytes response is valid, and matches the expected, stapled responses. 
@@ -267,7 +266,7 @@ class OCSPVerifier: return url def check_certificate(self, server, cert, issuer_url): - """Checks the validity of an ocsp server for an issuer""" + """Checks the validitity of an ocsp server for an issuer""" r = requests.get(issuer_url) if not r.ok: diff --git a/venv/lib/python3.12/site-packages/redis/retry.py b/venv/lib/python3.12/site-packages/redis/retry.py index 7577863..6064430 100644 --- a/venv/lib/python3.12/site-packages/redis/retry.py +++ b/venv/lib/python3.12/site-packages/redis/retry.py @@ -1,27 +1,17 @@ -import abc import socket from time import sleep -from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable, Tuple, Type, TypeVar from redis.exceptions import ConnectionError, TimeoutError -T = TypeVar("T") -E = TypeVar("E", bound=Exception, covariant=True) -if TYPE_CHECKING: - from redis.backoff import AbstractBackoff - - -class AbstractRetry(Generic[E], abc.ABC): +class Retry: """Retry a specific number of times after a failure""" - _supported_errors: Tuple[Type[E], ...] - def __init__( self, - backoff: "AbstractBackoff", - retries: int, - supported_errors: Tuple[Type[E], ...], + backoff, + retries, + supported_errors=(ConnectionError, TimeoutError, socket.timeout), ): """ Initialize a `Retry` object with a `Backoff` object @@ -34,14 +24,7 @@ class AbstractRetry(Generic[E], abc.ABC): self._retries = retries self._supported_errors = supported_errors - @abc.abstractmethod - def __eq__(self, other: Any) -> bool: - return NotImplemented - - def __hash__(self) -> int: - return hash((self._backoff, self._retries, frozenset(self._supported_errors))) - - def update_supported_errors(self, specified_errors: Iterable[Type[E]]) -> None: + def update_supported_errors(self, specified_errors: list): """ Updates the supported errors with the specified error types """ @@ -49,49 +32,7 @@ class AbstractRetry(Generic[E], abc.ABC): set(self._supported_errors + tuple(specified_errors)) ) - def get_retries(self) -> int: - """ - Get the number of retries. - """ - return self._retries - - def update_retries(self, value: int) -> None: - """ - Set the number of retries. - """ - self._retries = value - - -class Retry(AbstractRetry[Exception]): - __hash__ = AbstractRetry.__hash__ - - def __init__( - self, - backoff: "AbstractBackoff", - retries: int, - supported_errors: Tuple[Type[Exception], ...] = ( - ConnectionError, - TimeoutError, - socket.timeout, - ), - ): - super().__init__(backoff, retries, supported_errors) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, Retry): - return NotImplemented - - return ( - self._backoff == other._backoff - and self._retries == other._retries - and set(self._supported_errors) == set(other._supported_errors) - ) - - def call_with_retry( - self, - do: Callable[[], T], - fail: Callable[[Exception], Any], - ) -> T: + def call_with_retry(self, do, fail): """ Execute an operation that might fail and returns its result, or raise the exception that was thrown depending on the `Backoff` object. 
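Illustrative sketch, not part of the diff above: Retry.call_with_retry(do, fail) keeps the same behaviour in both versions of retry.py; the revision only strips the type annotations and the __eq__/__hash__ helpers. The flaky() helper below is invented for the example, and redis.exceptions.ConnectionError is in the default supported_errors tuple in both versions.

    from redis.backoff import ExponentialBackoff
    from redis.exceptions import ConnectionError
    from redis.retry import Retry

    retry = Retry(ExponentialBackoff(), retries=3)

    state = {"attempts": 0}

    def flaky():
        # Pretend the first two attempts hit a transient network error.
        state["attempts"] += 1
        if state["attempts"] < 3:
            raise ConnectionError("transient failure")
        return "ok"

    result = retry.call_with_retry(
        flaky,                                        # operation to (re)run
        lambda err: print(f"retrying after: {err}"),  # invoked on each failure
    )
    print(result)  # "ok" after two failed attempts

The fail callback runs before each backoff sleep; once the retry budget is exhausted the last exception propagates to the caller.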
diff --git a/venv/lib/python3.12/site-packages/redis/sentinel.py b/venv/lib/python3.12/site-packages/redis/sentinel.py index f12bd8d..41f308d 100644 --- a/venv/lib/python3.12/site-packages/redis/sentinel.py +++ b/venv/lib/python3.12/site-packages/redis/sentinel.py @@ -5,12 +5,8 @@ from typing import Optional
 from redis.client import Redis
 from redis.commands import SentinelCommands
 from redis.connection import Connection, ConnectionPool, SSLConnection
-from redis.exceptions import (
-    ConnectionError,
-    ReadOnlyError,
-    ResponseError,
-    TimeoutError,
-)
+from redis.exceptions import ConnectionError, ReadOnlyError, ResponseError, TimeoutError
+from redis.utils import str_if_bytes


 class MasterNotFoundError(ConnectionError):
@@ -28,10 +24,7 @@ class SentinelManagedConnection(Connection):

     def __repr__(self):
         pool = self.connection_pool
-        s = (
-            f"<{type(self).__module__}.{type(self).__name__}"
-            f"(service={pool.service_name}%s)>"
-        )
+        s = f"{type(self).__name__}<service={pool.service_name}%s>"
         if self.host:
             host_info = f",host={self.host},port={self.port}"
             s = s % host_info
@@ -39,11 +32,11 @@ class SentinelManagedConnection(Connection):

     def connect_to(self, address):
         self.host, self.port = address
-
-        self.connect_check_health(
-            check_health=self.connection_pool.check_connection,
-            retry_socket_connect=False,
-        )
+        super().connect()
+        if self.connection_pool.check_connection:
+            self.send_command("PING")
+            if str_if_bytes(self.read_response()) != "PONG":
+                raise ConnectionError("PING failed")

     def _connect_retry(self):
         if self._sock:
@@ -149,11 +142,9 @@ class SentinelConnectionPool(ConnectionPool):
     def __init__(self, service_name, sentinel_manager, **kwargs):
         kwargs["connection_class"] = kwargs.get(
             "connection_class",
-            (
-                SentinelManagedSSLConnection
-                if kwargs.pop("ssl", False)
-                else SentinelManagedConnection
-            ),
+            SentinelManagedSSLConnection
+            if kwargs.pop("ssl", False)
+            else SentinelManagedConnection,
         )
         self.is_master = kwargs.pop("is_master", True)
         self.check_connection = kwargs.pop("check_connection", False)
@@ -171,10 +162,7 @@ class SentinelConnectionPool(ConnectionPool):

     def __repr__(self):
         role = "master" if self.is_master else "slave"
-        return (
-            f"<{type(self).__module__}.{type(self).__name__}"
-            f"(service={self.service_name}({role}))>"
-        )
+        return f"{type(self).__name__}<service={self.service_name}({role})>"

     def check_master_state(self, state, service_name):
         if not state["is_master"] or state["is_sdown"] or state["is_odown"]:
@@ -321,13 +293,7 @@ class Sentinel(SentinelCommands):
                     sentinel,
                     self.sentinels[0],
                 )
-
-                ip = (
-                    self._force_master_ip
-                    if self._force_master_ip is not None
-                    else state["ip"]
-                )
-                return ip, state["port"]
+                return state["ip"], state["port"]

         error_info = ""
         if len(collected_errors) > 0:
@@ -364,8 +330,6 @@ class Sentinel(SentinelCommands):
     ):
         """
         Returns a redis client instance for the ``service_name`` master.
-        Sentinel client will detect failover and reconnect Redis clients
-        automatically.
A :py:class:`~redis.sentinel.SentinelConnectionPool` class is used to retrieve the master's address before establishing a new diff --git a/venv/lib/python3.12/site-packages/redis/typing.py b/venv/lib/python3.12/site-packages/redis/typing.py index ede5385..56a1e99 100644 --- a/venv/lib/python3.12/site-packages/redis/typing.py +++ b/venv/lib/python3.12/site-packages/redis/typing.py @@ -7,18 +7,21 @@ from typing import ( Awaitable, Iterable, Mapping, - Protocol, Type, TypeVar, Union, ) +from redis.compat import Protocol + if TYPE_CHECKING: from redis._parsers import Encoder + from redis.asyncio.connection import ConnectionPool as AsyncConnectionPool + from redis.connection import ConnectionPool Number = Union[int, float] -EncodedT = Union[bytes, bytearray, memoryview] +EncodedT = Union[bytes, memoryview] DecodedT = Union[str, int, float] EncodableT = Union[EncodedT, DecodedT] AbsExpiryT = Union[int, datetime] @@ -30,7 +33,6 @@ KeyT = _StringLikeT # Main redis key space PatternT = _StringLikeT # Patterns matched against keys, fields etc FieldT = EncodableT # Fields within hash tables, streams and geo commands KeysT = Union[KeyT, Iterable[KeyT]] -ResponseT = Union[Awaitable[Any], Any] ChannelT = _StringLikeT GroupT = _StringLikeT # Consumer group ConsumerT = _StringLikeT # Consumer name @@ -50,8 +52,14 @@ ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Except class CommandsProtocol(Protocol): - def execute_command(self, *args, **options) -> ResponseT: ... + connection_pool: Union["AsyncConnectionPool", "ConnectionPool"] + + def execute_command(self, *args, **options): + ... -class ClusterCommandsProtocol(CommandsProtocol): +class ClusterCommandsProtocol(CommandsProtocol, Protocol): encoder: "Encoder" + + def execute_command(self, *args, **options) -> Union[Any, Awaitable]: + ... 
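Illustrative sketch, not part of the diff above: the aliases touched in typing.py (KeyT, EncodableT, CommandsProtocol) are importable from redis.typing in both versions, so a helper typed against the protocol still accepts any client object that exposes execute_command. The function name and the raw SET/EX arguments are assumptions made for the example.

    from typing import Optional

    from redis.typing import CommandsProtocol, EncodableT, KeyT

    def set_with_ttl(client: CommandsProtocol, key: KeyT, value: EncodableT,
                     ttl: Optional[int] = None):
        # Build a raw SET command; EX adds an expiry in seconds when given.
        args = ["SET", key, value]
        if ttl is not None:
            args += ["EX", ttl]
        return client.execute_command(*args)

    # Usage against a real client would look like:
    #   set_with_ttl(redis.Redis(), "greeting", "hi", ttl=60)

Because the annotation is structural, the same helper works with a standalone client, a cluster client, or a pipeline, under either revision of the typing module.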
diff --git a/venv/lib/python3.12/site-packages/redis/utils.py b/venv/lib/python3.12/site-packages/redis/utils.py index 79c23c8..01fdfed 100644 --- a/venv/lib/python3.12/site-packages/redis/utils.py +++ b/venv/lib/python3.12/site-packages/redis/utils.py @@ -1,26 +1,18 @@ -import datetime import logging -import textwrap -from collections.abc import Callable +import sys from contextlib import contextmanager from functools import wraps -from typing import Any, Dict, List, Mapping, Optional, TypeVar, Union - -from redis.exceptions import DataError -from redis.typing import AbsExpiryT, EncodableT, ExpiryT +from typing import Any, Dict, Mapping, Union try: import hiredis # noqa - # Only support Hiredis >= 3.0: - hiredis_version = hiredis.__version__.split(".") - HIREDIS_AVAILABLE = int(hiredis_version[0]) > 3 or ( - int(hiredis_version[0]) == 3 and int(hiredis_version[1]) >= 2 - ) - if not HIREDIS_AVAILABLE: - raise ImportError("hiredis package should be >= 3.2.0") + # Only support Hiredis >= 1.0: + HIREDIS_AVAILABLE = not hiredis.__version__.startswith("0.") + HIREDIS_PACK_AVAILABLE = hasattr(hiredis, "pack_command") except ImportError: HIREDIS_AVAILABLE = False + HIREDIS_PACK_AVAILABLE = False try: import ssl # noqa @@ -36,7 +28,10 @@ try: except ImportError: CRYPTOGRAPHY_AVAILABLE = False -from importlib import metadata +if sys.version_info >= (3, 8): + from importlib import metadata +else: + import importlib_metadata as metadata def from_url(url, **kwargs): @@ -131,74 +126,6 @@ def deprecated_function(reason="", version="", name=None): return decorator -def warn_deprecated_arg_usage( - arg_name: Union[list, str], - function_name: str, - reason: str = "", - version: str = "", - stacklevel: int = 2, -): - import warnings - - msg = ( - f"Call to '{function_name}' function with deprecated" - f" usage of input argument/s '{arg_name}'." - ) - if reason: - msg += f" ({reason})" - if version: - msg += f" -- Deprecated since version {version}." - warnings.warn(msg, category=DeprecationWarning, stacklevel=stacklevel) - - -C = TypeVar("C", bound=Callable) - - -def deprecated_args( - args_to_warn: list = ["*"], - allowed_args: list = [], - reason: str = "", - version: str = "", -) -> Callable[[C], C]: - """ - Decorator to mark specified args of a function as deprecated. - If '*' is in args_to_warn, all arguments will be marked as deprecated. - """ - - def decorator(func: C) -> C: - @wraps(func) - def wrapper(*args, **kwargs): - # Get function argument names - arg_names = func.__code__.co_varnames[: func.__code__.co_argcount] - - provided_args = dict(zip(arg_names, args)) - provided_args.update(kwargs) - - provided_args.pop("self", None) - for allowed_arg in allowed_args: - provided_args.pop(allowed_arg, None) - - for arg in args_to_warn: - if arg == "*" and len(provided_args) > 0: - warn_deprecated_arg_usage( - list(provided_args.keys()), - func.__name__, - reason, - version, - stacklevel=3, - ) - elif arg in provided_args: - warn_deprecated_arg_usage( - arg, func.__name__, reason, version, stacklevel=3 - ) - - return func(*args, **kwargs) - - return wrapper - - return decorator - - def _set_info_logger(): """ Set up a logger that log info logs to stdout. @@ -218,97 +145,3 @@ def get_lib_version(): except metadata.PackageNotFoundError: libver = "99.99.99" return libver - - -def format_error_message(host_error: str, exception: BaseException) -> str: - if not exception.args: - return f"Error connecting to {host_error}." 
- elif len(exception.args) == 1: - return f"Error {exception.args[0]} connecting to {host_error}." - else: - return ( - f"Error {exception.args[0]} connecting to {host_error}. " - f"{exception.args[1]}." - ) - - -def compare_versions(version1: str, version2: str) -> int: - """ - Compare two versions. - - :return: -1 if version1 > version2 - 0 if both versions are equal - 1 if version1 < version2 - """ - - num_versions1 = list(map(int, version1.split("."))) - num_versions2 = list(map(int, version2.split("."))) - - if len(num_versions1) > len(num_versions2): - diff = len(num_versions1) - len(num_versions2) - for _ in range(diff): - num_versions2.append(0) - elif len(num_versions1) < len(num_versions2): - diff = len(num_versions2) - len(num_versions1) - for _ in range(diff): - num_versions1.append(0) - - for i, ver in enumerate(num_versions1): - if num_versions1[i] > num_versions2[i]: - return -1 - elif num_versions1[i] < num_versions2[i]: - return 1 - - return 0 - - -def ensure_string(key): - if isinstance(key, bytes): - return key.decode("utf-8") - elif isinstance(key, str): - return key - else: - raise TypeError("Key must be either a string or bytes") - - -def extract_expire_flags( - ex: Optional[ExpiryT] = None, - px: Optional[ExpiryT] = None, - exat: Optional[AbsExpiryT] = None, - pxat: Optional[AbsExpiryT] = None, -) -> List[EncodableT]: - exp_options: list[EncodableT] = [] - if ex is not None: - exp_options.append("EX") - if isinstance(ex, datetime.timedelta): - exp_options.append(int(ex.total_seconds())) - elif isinstance(ex, int): - exp_options.append(ex) - elif isinstance(ex, str) and ex.isdigit(): - exp_options.append(int(ex)) - else: - raise DataError("ex must be datetime.timedelta or int") - elif px is not None: - exp_options.append("PX") - if isinstance(px, datetime.timedelta): - exp_options.append(int(px.total_seconds() * 1000)) - elif isinstance(px, int): - exp_options.append(px) - else: - raise DataError("px must be datetime.timedelta or int") - elif exat is not None: - if isinstance(exat, datetime.datetime): - exat = int(exat.timestamp()) - exp_options.extend(["EXAT", exat]) - elif pxat is not None: - if isinstance(pxat, datetime.datetime): - pxat = int(pxat.timestamp() * 1000) - exp_options.extend(["PXAT", pxat]) - - return exp_options - - -def truncate_text(txt, max_length=100): - return textwrap.shorten( - text=txt, width=max_length, placeholder="...", break_long_words=True - ) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/RECORD b/venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/RECORD deleted file mode 100644 index b2d951f..0000000 --- a/venv/lib/python3.12/site-packages/sqlalchemy-2.0.43.dist-info/RECORD +++ /dev/null @@ -1,532 +0,0 @@ -sqlalchemy-2.0.43.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -sqlalchemy-2.0.43.dist-info/METADATA,sha256=6StIsiY_vKcG9DPqObgaUSVms9cc12bYmm3KbXl9yMw,9577 -sqlalchemy-2.0.43.dist-info/RECORD,, -sqlalchemy-2.0.43.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -sqlalchemy-2.0.43.dist-info/WHEEL,sha256=aSgG0F4rGPZtV0iTEIfy6dtHq6g67Lze3uLfk0vWn88,151 -sqlalchemy-2.0.43.dist-info/licenses/LICENSE,sha256=mCFyC1jUpWW2EyEAeorUOraZGjlZ5mzV203Z6uacffw,1100 -sqlalchemy-2.0.43.dist-info/top_level.txt,sha256=rp-ZgB7D8G11ivXON5VGPjupT1voYmWqkciDt5Uaw_Q,11 -sqlalchemy/__init__.py,sha256=Oi26seKKS4YLZt2VPHQvkIIoTKAcEWD4BCjHQZRG8BE,12659 -sqlalchemy/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/__pycache__/events.cpython-312.pyc,, 
-sqlalchemy/__pycache__/exc.cpython-312.pyc,, -sqlalchemy/__pycache__/inspection.cpython-312.pyc,, -sqlalchemy/__pycache__/log.cpython-312.pyc,, -sqlalchemy/__pycache__/schema.cpython-312.pyc,, -sqlalchemy/__pycache__/types.cpython-312.pyc,, -sqlalchemy/connectors/__init__.py,sha256=YeSHsOB0YhdM6jZUvHFQFwKqNXO02MlklmGW0yCywjI,476 -sqlalchemy/connectors/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/connectors/__pycache__/aioodbc.cpython-312.pyc,, -sqlalchemy/connectors/__pycache__/asyncio.cpython-312.pyc,, -sqlalchemy/connectors/__pycache__/pyodbc.cpython-312.pyc,, -sqlalchemy/connectors/aioodbc.py,sha256=-OKbnvR-kLCKHyrOIBkAZwTASAbQZ5qmrozm0dwbtNE,5577 -sqlalchemy/connectors/asyncio.py,sha256=OPhwvKQo7l3CUSY7YsL3W8oBqc_zQIAytIvqLjZLwTA,10122 -sqlalchemy/connectors/pyodbc.py,sha256=ZGWBmYYYVgqUHjex3d_lYHZyAhQJGowp9cWGYnj1200,8618 -sqlalchemy/cyextension/__init__.py,sha256=4npVIjitKfUs0NQ6f3UdQBDq4ipJ0_ZNB2mpKqtc5ik,244 -sqlalchemy/cyextension/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/cyextension/collections.cpython-312-x86_64-linux-gnu.so,sha256=lYDyNVPL2jy2Rb_gGvfhb06-wudWxwj9iQzDTaN3wTg,2011024 -sqlalchemy/cyextension/collections.pyx,sha256=L7DZ3DGKpgw2MT2ZZRRxCnrcyE5pU1NAFowWgAzQPEc,12571 -sqlalchemy/cyextension/immutabledict.cpython-312-x86_64-linux-gnu.so,sha256=OlABGDt9_2Qlbh57tcrsiJOKAx8uh5oWarXKkeSwKy8,813560 -sqlalchemy/cyextension/immutabledict.pxd,sha256=3x3-rXG5eRQ7bBnktZ-OJ9-6ft8zToPmTDOd92iXpB0,291 -sqlalchemy/cyextension/immutabledict.pyx,sha256=KfDTYbTfebstE8xuqAtuXsHNAK0_b5q_ymUiinUe_xs,3535 -sqlalchemy/cyextension/processors.cpython-312-x86_64-linux-gnu.so,sha256=zkD9YxSNVRP3CLJdpTH8BAAOA8R5Ga_5vYw5HLmO72o,613448 -sqlalchemy/cyextension/processors.pyx,sha256=R1rHsGLEaGeBq5VeCydjClzYlivERIJ9B-XLOJlf2MQ,1792 -sqlalchemy/cyextension/resultproxy.cpython-312-x86_64-linux-gnu.so,sha256=L83Uxcx_wzzbrYE4xDBVrYxjBa99lQVtd8JfSwayiOA,631984 -sqlalchemy/cyextension/resultproxy.pyx,sha256=eWLdyBXiBy_CLQrF5ScfWJm7X0NeelscSXedtj1zv9Q,2725 -sqlalchemy/cyextension/util.cpython-312-x86_64-linux-gnu.so,sha256=XS0EThpDqN2vddpA8aamUzaqOAgutMk8Gz0m8lq2fKM,990328 -sqlalchemy/cyextension/util.pyx,sha256=Tt5VwTUtO3YKQK2PHfYOLhV2Jr5GMRJcp2DzH4fjGOs,2569 -sqlalchemy/dialects/__init__.py,sha256=oOkVOr98g-6jxaUXld8szIgxkXMBae5IPfAzBrcpLaw,1798 -sqlalchemy/dialects/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/dialects/__pycache__/_typing.cpython-312.pyc,, -sqlalchemy/dialects/_typing.py,sha256=8YwrkOa8IvmBojwwegbL5mL_0UAuzdqYiKHKANpvHMw,971 -sqlalchemy/dialects/mssql/__init__.py,sha256=6t_aNpgbMLdPE9gpHYTf9o6QfVavncztRLbr21l2NaY,1880 -sqlalchemy/dialects/mssql/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/dialects/mssql/__pycache__/aioodbc.cpython-312.pyc,, -sqlalchemy/dialects/mssql/__pycache__/base.cpython-312.pyc,, -sqlalchemy/dialects/mssql/__pycache__/information_schema.cpython-312.pyc,, -sqlalchemy/dialects/mssql/__pycache__/json.cpython-312.pyc,, -sqlalchemy/dialects/mssql/__pycache__/provision.cpython-312.pyc,, -sqlalchemy/dialects/mssql/__pycache__/pymssql.cpython-312.pyc,, -sqlalchemy/dialects/mssql/__pycache__/pyodbc.cpython-312.pyc,, -sqlalchemy/dialects/mssql/aioodbc.py,sha256=4CmhwIkZrabpG-r7_ogRVajD-nhRZSFJ0Swz2d0jIHM,2021 -sqlalchemy/dialects/mssql/base.py,sha256=bsDGdlI9UJ3o_K_FQm-lryn28Gjcss8jpiUwV-rduwo,133927 -sqlalchemy/dialects/mssql/information_schema.py,sha256=CDNPC1ZDjj-DumMgzZdm1oNY6FiO-_Fn2DWJuPVnni0,8963 -sqlalchemy/dialects/mssql/json.py,sha256=F53pibuOVRzgDtjoclOI7LnkKXNVsaVfJyBH1XAhyDo,4756 
-sqlalchemy/dialects/mssql/provision.py,sha256=P1tqxZ4f6Oeqn2gNi7dXl82LRLCg1-OB4eWiZc6CHek,5593 -sqlalchemy/dialects/mssql/pymssql.py,sha256=C7yAs3Pw81W1KTVNc6_0sHQuYlJ5iH82vKByY4TkB1g,4097 -sqlalchemy/dialects/mssql/pyodbc.py,sha256=CnO7KDWxbxb7AoZhp_PMDBvVSMuzwq1h4Cav2IWFWDo,27173 -sqlalchemy/dialects/mysql/__init__.py,sha256=ropOMUWrAcL-Q7h-9jQ_tb3ISAFIsNRQ8YVXvn0URl0,2206 -sqlalchemy/dialects/mysql/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/aiomysql.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/asyncmy.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/base.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/cymysql.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/dml.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/enumerated.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/expression.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/json.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/mariadb.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/mariadbconnector.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/mysqlconnector.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/mysqldb.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/provision.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/pymysql.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/pyodbc.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/reflection.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/reserved_words.cpython-312.pyc,, -sqlalchemy/dialects/mysql/__pycache__/types.cpython-312.pyc,, -sqlalchemy/dialects/mysql/aiomysql.py,sha256=XpHS7KvZF_XQFlghvqyZfPuLD890M7GTgMLCaeXA67E,7728 -sqlalchemy/dialects/mysql/asyncmy.py,sha256=kuX02tRZ-0kKbwRgs3dL5T-mRyc5oBSFoIzQDgaHgYk,7093 -sqlalchemy/dialects/mysql/base.py,sha256=V2CE2XB6eiFG3doNdzH3NZPhgXgt3OL7QN8F3dg_9Pg,137763 -sqlalchemy/dialects/mysql/cymysql.py,sha256=ihH4kZ273nvf0R0p8keD71ZIaTXRHyZePXMlobwgbpI,3215 -sqlalchemy/dialects/mysql/dml.py,sha256=VjnTobe_SBNF2RN6tvqa5LOn-9x4teVUyzUedZkOmdc,7768 -sqlalchemy/dialects/mysql/enumerated.py,sha256=si2hGv5jMNGS78n_JDgswIhbBZuTqjwbxjiWg5ZUdy4,10292 -sqlalchemy/dialects/mysql/expression.py,sha256=C8LhU-CM6agqKCS1tl1_ChSqwZbqt3zP_dSGBqgBgLg,4241 -sqlalchemy/dialects/mysql/json.py,sha256=ckYT_lihvqr28iHJTUUwvPPUIoYVLL_wUXWFDTCna_M,2806 -sqlalchemy/dialects/mysql/mariadb.py,sha256=yaiZnnbjfrBqHm1ykaRSFYKrrYUqu-GBYvt97EGYSzs,1886 -sqlalchemy/dialects/mysql/mariadbconnector.py,sha256=lJuS3euMlVBbJDJ10ntqe3TnrjzneLEUlE8sLZl6Qoc,10385 -sqlalchemy/dialects/mysql/mysqlconnector.py,sha256=aaAiF32rQVoLNVIdgGKHMsnMei--0ig3OqmhWq45MrA,10097 -sqlalchemy/dialects/mysql/mysqldb.py,sha256=8wIxcxQxT-X6nywLJkjg9_JdIKGYOhlrtVL8lP_WFcM,9943 -sqlalchemy/dialects/mysql/provision.py,sha256=MaQ9eeHnRL4EXAebIInwarCIiDbYcz_sMCss3wyV12Q,3717 -sqlalchemy/dialects/mysql/pymysql.py,sha256=Qlc9XToIqAfHz0c_ODs97uk1TlV1ZrEl_TidTjoeByU,4886 -sqlalchemy/dialects/mysql/pyodbc.py,sha256=v-Zo4M7blxdff--KJiIantCwbPO6H-GBkNCTN4nBgU4,5111 -sqlalchemy/dialects/mysql/reflection.py,sha256=CBxBiv1mCLLNHz-I8hgJKACTF3K0eYEpWd0ndCBCq5I,24690 -sqlalchemy/dialects/mysql/reserved_words.py,sha256=iG6zb78sn-RdqWQRk2F_Tuufk5tUodkcoHbxTdgZYkw,9236 -sqlalchemy/dialects/mysql/types.py,sha256=lAkkNRVPBHP8H7AQQ7NykfJ8YxgdUDAHkfd7qD-Lwvo,26459 -sqlalchemy/dialects/oracle/__init__.py,sha256=5qrJcFTF3vgB9B4PkwBJj3iXE7P57LdaHNkxMa1NXug,1898 -sqlalchemy/dialects/oracle/__pycache__/__init__.cpython-312.pyc,, 
-sqlalchemy/dialects/oracle/__pycache__/base.cpython-312.pyc,, -sqlalchemy/dialects/oracle/__pycache__/cx_oracle.cpython-312.pyc,, -sqlalchemy/dialects/oracle/__pycache__/dictionary.cpython-312.pyc,, -sqlalchemy/dialects/oracle/__pycache__/oracledb.cpython-312.pyc,, -sqlalchemy/dialects/oracle/__pycache__/provision.cpython-312.pyc,, -sqlalchemy/dialects/oracle/__pycache__/types.cpython-312.pyc,, -sqlalchemy/dialects/oracle/__pycache__/vector.cpython-312.pyc,, -sqlalchemy/dialects/oracle/base.py,sha256=zEl885-lRs07FGdWFuSzBfa1FqrUPT7l2wpcBr9joIs,139156 -sqlalchemy/dialects/oracle/cx_oracle.py,sha256=mYrXD0nJzuTY1h878b50fNXIUBgjc9Q1LJjjY1VHx3w,56717 -sqlalchemy/dialects/oracle/dictionary.py,sha256=J7tGVE0KyUPZKpPLOary3HdDq1DWd29arF5udLgv8_o,19519 -sqlalchemy/dialects/oracle/oracledb.py,sha256=veqto1AUIbSxRmpUQin0ysMV8Y6sWAkzXt7W8IIl118,33771 -sqlalchemy/dialects/oracle/provision.py,sha256=ga1gNQZlXZKk7DYuYegllUejJxZXRKDGa7dbi_S_poc,8313 -sqlalchemy/dialects/oracle/types.py,sha256=axN6Yidx9tGRIUAbDpBrhMWXE-C8jSllFpTghpGOOzU,9058 -sqlalchemy/dialects/oracle/vector.py,sha256=YtN7E5TbDIQR2FCICaSeeaOnvzHP_O0mXNq1gk02S4Q,10874 -sqlalchemy/dialects/postgresql/__init__.py,sha256=kD8W-SV5e2CesvWg2MQAtncXuZFwGPfR_UODvmRXE08,3892 -sqlalchemy/dialects/postgresql/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/_psycopg_common.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/array.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/asyncpg.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/base.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/dml.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/ext.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/hstore.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/json.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/named_types.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/operators.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/pg8000.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/pg_catalog.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/provision.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/psycopg.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/psycopg2.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/psycopg2cffi.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/ranges.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/__pycache__/types.cpython-312.pyc,, -sqlalchemy/dialects/postgresql/_psycopg_common.py,sha256=h4JmkHWxy_Nspn6Bi9YKpa9l0OkwInwQzYKue-fJnVA,5783 -sqlalchemy/dialects/postgresql/array.py,sha256=FyyJ1f3RSAhHtgxKydfMkUAGEh-LJyLOZ31jiAdDo74,16956 -sqlalchemy/dialects/postgresql/asyncpg.py,sha256=QPvyV6YYZ9--ULoMYC5pl7axct79H8DbYrKAUQASqzg,41548 -sqlalchemy/dialects/postgresql/base.py,sha256=RDuehOZL3hLPhq4_7G-91BgAM9LeToHiiIU-RjFGVmU,186421 -sqlalchemy/dialects/postgresql/dml.py,sha256=2SmyMeYveAgm7OnT_CJvwad2nh8BP37yT6gFs8dBYN8,12126 -sqlalchemy/dialects/postgresql/ext.py,sha256=voxpAz-zoCOO-fjpCzrw7UASzNIvdz2u4kFSuGcshlI,17347 -sqlalchemy/dialects/postgresql/hstore.py,sha256=wR4gmvfQWPssHwYTXEsPJTb4LkBS6x4e4XXE6smtDH4,11934 -sqlalchemy/dialects/postgresql/json.py,sha256=YO6yuDnUKh-mHNtc7DavFMpYNUrJ_dNb24gw333uH0M,12842 -sqlalchemy/dialects/postgresql/named_types.py,sha256=D1WFTcxE-PKYRaB75gWvnAvpgGJRTcFkW9nSGpC4WCo,17812 
-sqlalchemy/dialects/postgresql/operators.py,sha256=ay3ckNsWtqDjxDseTdKMGGqYVzST6lmfhbbYHG_bxCw,2808 -sqlalchemy/dialects/postgresql/pg8000.py,sha256=r6Lg5tgwuf4FE_RA_kHcfHPW5GXUdNWWr3E846Z4aI0,18743 -sqlalchemy/dialects/postgresql/pg_catalog.py,sha256=wnzFm9S0JFag1TBdySDJH3VOFSkJWmwAjVcIAQ25jHg,9999 -sqlalchemy/dialects/postgresql/provision.py,sha256=7pg9-nOnaK5XBzqByXNPuvi3rxtnRa3dJxdSPVq4eeA,5770 -sqlalchemy/dialects/postgresql/psycopg.py,sha256=k7zXsJj35aOXCrhsbMxwTQX5JWegrqirFJ1Hgbq-GjQ,23326 -sqlalchemy/dialects/postgresql/psycopg2.py,sha256=1KXw9RzsQEAXJazCBywdP5CwLu-HsCSDAD_Khc_rPTM,32032 -sqlalchemy/dialects/postgresql/psycopg2cffi.py,sha256=nKilJfvO9mJwk5NRw5iZDekKY5vi379tvdUJ2vn5eyQ,1756 -sqlalchemy/dialects/postgresql/ranges.py,sha256=rsvhfZ63OVtHHeBDXb_6hULg0HkVx18hkChfoznlhcg,32946 -sqlalchemy/dialects/postgresql/types.py,sha256=oKhDsFiITKbZcCP66L3dhif54pmsFvVfv-MZQWA3sYo,7629 -sqlalchemy/dialects/sqlite/__init__.py,sha256=6Xcz3nPsl8lqCcZ4-VzPRmkMrkKgAp2buKsClZelU7c,1182 -sqlalchemy/dialects/sqlite/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/dialects/sqlite/__pycache__/aiosqlite.cpython-312.pyc,, -sqlalchemy/dialects/sqlite/__pycache__/base.cpython-312.pyc,, -sqlalchemy/dialects/sqlite/__pycache__/dml.cpython-312.pyc,, -sqlalchemy/dialects/sqlite/__pycache__/json.cpython-312.pyc,, -sqlalchemy/dialects/sqlite/__pycache__/provision.cpython-312.pyc,, -sqlalchemy/dialects/sqlite/__pycache__/pysqlcipher.cpython-312.pyc,, -sqlalchemy/dialects/sqlite/__pycache__/pysqlite.cpython-312.pyc,, -sqlalchemy/dialects/sqlite/aiosqlite.py,sha256=eZW4NFpLS6z02keIHeJLI5tFUkzhn0MpS8r2kkl0G0I,14619 -sqlalchemy/dialects/sqlite/base.py,sha256=rRYahtQDySw-4v6ljEomUdvjigGTNXqaqPuiQ5eOpa4,102859 -sqlalchemy/dialects/sqlite/dml.py,sha256=4N8qh06RuMphLoQgWw7wv5nXIrka57jIFvK2x9xTZqg,9138 -sqlalchemy/dialects/sqlite/json.py,sha256=A62xPyLRZxl2hvgTMM92jd_7jlw9UE_4Y6Udqt-8g04,2777 -sqlalchemy/dialects/sqlite/provision.py,sha256=VhqDjDALqxKQY_3Z3hjzkmPQJ-vtk2Dkk1A4qLTs-G8,5596 -sqlalchemy/dialects/sqlite/pysqlcipher.py,sha256=di8rYryfL0KAn3pRGepmunHyIRGy-4Hhr-2q_ehPzss,5371 -sqlalchemy/dialects/sqlite/pysqlite.py,sha256=42jPDi1nZ_9YVKKWaKnkurL8NOFUX_8Rbn7baqRw0J8,25999 -sqlalchemy/dialects/type_migration_guidelines.txt,sha256=-uHNdmYFGB7bzUNT6i8M5nb4j6j9YUKAtW4lcBZqsMg,8239 -sqlalchemy/engine/__init__.py,sha256=EF4haWCPu95WtWx1GzcHRJ_bBmtJMznno3I2TQ-ZIHE,2818 -sqlalchemy/engine/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/_py_processors.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/_py_row.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/_py_util.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/base.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/characteristics.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/create.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/cursor.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/default.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/events.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/interfaces.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/mock.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/processors.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/reflection.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/result.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/row.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/strategies.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/url.cpython-312.pyc,, -sqlalchemy/engine/__pycache__/util.cpython-312.pyc,, 
-sqlalchemy/engine/_py_processors.py,sha256=7QxgkVOd5h1Qd22qFh-pPZdM7RBRzNjj8lWAMWrilcI,3744 -sqlalchemy/engine/_py_row.py,sha256=yNdrZe36yw6mO7x0OEbG0dGojH7CQkNReIwn9LMUPUs,3787 -sqlalchemy/engine/_py_util.py,sha256=Nvd4pVdXRs89khRevK-Ux4Y9p2f2vnALboNrSwhqS1U,2465 -sqlalchemy/engine/base.py,sha256=aNp2tGNBWlBz2pHiOveJ3PeaJRDJlLknekUQ50MJDjU,123090 -sqlalchemy/engine/characteristics.py,sha256=PepmGApo1sL01dS1qtSbmHplu9ZCdtuSegiGI7L7NZY,4765 -sqlalchemy/engine/create.py,sha256=uIAiU-ANj7fk_6A3dbJw_SEU8Qfd0_YF8yEHGxD0r1g,33847 -sqlalchemy/engine/cursor.py,sha256=63KLS-IKKAYh2uADJytpT1i9-qpG9E0iVBIcKTtKkwI,76567 -sqlalchemy/engine/default.py,sha256=PpySUqbAliGjw80ZxhDdZwyiFEMCpNPcC1XmyJynyEE,85721 -sqlalchemy/engine/events.py,sha256=4_e6Ip32ar2Eb27R4ipamiKC-7Tpg4lVz3txabhT5Rc,37400 -sqlalchemy/engine/interfaces.py,sha256=fNGMov1byIOkPxh7dJervp-UUNyHHm3jpIB0HrCMucc,115119 -sqlalchemy/engine/mock.py,sha256=L07bSIkgEbIkih-pYvFWh7k7adHVp5tBFBekKlD7GHs,4156 -sqlalchemy/engine/processors.py,sha256=XK32bULBkuVVRa703u4-SrTCDi_a18Dxq1M09QFBEPw,2379 -sqlalchemy/engine/reflection.py,sha256=QNOAXvKtdzVddpbkMOyM380y3olKdJKQkmF0Bfwia-Q,75565 -sqlalchemy/engine/result.py,sha256=46J3rP0ZwDwsqU-4CAaEHXTpx8OqCEP9Dy4LQwtHUEg,77805 -sqlalchemy/engine/row.py,sha256=BPtAwsceiRxB9ANpDNM24uQ1M_Zs0xFkSXoKR_I8xyY,12031 -sqlalchemy/engine/strategies.py,sha256=3DixBdeTa824XjuID2o7UxIyg7GyNwdBI8hOOT0SQnc,439 -sqlalchemy/engine/url.py,sha256=GJfZo0KtbMtkOIHBPI_KcKASsyrI5UYkX-UoN62FQxc,31067 -sqlalchemy/engine/util.py,sha256=4OmXwFlmnq6_vBlfUBHnz5LrI_8bT3TwgynX4wcJfnw,5682 -sqlalchemy/event/__init__.py,sha256=ZjVxFGbt9neH5AC4GFiUN5IG2O4j6Z9v2LdmyagJi9w,997 -sqlalchemy/event/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/event/__pycache__/api.cpython-312.pyc,, -sqlalchemy/event/__pycache__/attr.cpython-312.pyc,, -sqlalchemy/event/__pycache__/base.cpython-312.pyc,, -sqlalchemy/event/__pycache__/legacy.cpython-312.pyc,, -sqlalchemy/event/__pycache__/registry.cpython-312.pyc,, -sqlalchemy/event/api.py,sha256=x-VlMFJXzubD6fuB4VRTTeAJeeQNUZ5jHZXD1aL0Qkg,8109 -sqlalchemy/event/attr.py,sha256=YhPXVBPj63Cfyn0nS6h8Ljq0SEbD3mtAZn9HYlzGbtw,20751 -sqlalchemy/event/base.py,sha256=g5eRGX4e949srBK2gUxLYM0RrDUdtUEPS2FT_9IKZeI,15254 -sqlalchemy/event/legacy.py,sha256=lGafKAOF6PY8Bz0AqhN9Q6n-lpXqFLwdv-0T6-UBpow,8227 -sqlalchemy/event/registry.py,sha256=MNEMyR8HZhzQFgxk4Jk_Em6nXTihmGXiSIwPdUnalPM,11144 -sqlalchemy/events.py,sha256=VBRvtckn9JS3tfUfi6UstqUrvQ15J2xamcDByFysIrI,525 -sqlalchemy/exc.py,sha256=AjFBCrOl_V4vQdGegn72Y951RSRMPL6T5qjxnFTGFbM,23978 -sqlalchemy/ext/__init__.py,sha256=BkTNuOg454MpCY9QA3FLK8td7KQhD1W74fOEXxnWibE,322 -sqlalchemy/ext/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/ext/__pycache__/associationproxy.cpython-312.pyc,, -sqlalchemy/ext/__pycache__/automap.cpython-312.pyc,, -sqlalchemy/ext/__pycache__/baked.cpython-312.pyc,, -sqlalchemy/ext/__pycache__/compiler.cpython-312.pyc,, -sqlalchemy/ext/__pycache__/horizontal_shard.cpython-312.pyc,, -sqlalchemy/ext/__pycache__/hybrid.cpython-312.pyc,, -sqlalchemy/ext/__pycache__/indexable.cpython-312.pyc,, -sqlalchemy/ext/__pycache__/instrumentation.cpython-312.pyc,, -sqlalchemy/ext/__pycache__/mutable.cpython-312.pyc,, -sqlalchemy/ext/__pycache__/orderinglist.cpython-312.pyc,, -sqlalchemy/ext/__pycache__/serializer.cpython-312.pyc,, -sqlalchemy/ext/associationproxy.py,sha256=QAo0GssILBua9wRNT3gajwZMEct3KCCu-gWVtAG-MA0,66442 -sqlalchemy/ext/asyncio/__init__.py,sha256=kTIfpwsHWhqZ-VMOBZFBq66kt1XeF0hNuwOToEDe4_Y,1317 
-sqlalchemy/ext/asyncio/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/ext/asyncio/__pycache__/base.cpython-312.pyc,, -sqlalchemy/ext/asyncio/__pycache__/engine.cpython-312.pyc,, -sqlalchemy/ext/asyncio/__pycache__/exc.cpython-312.pyc,, -sqlalchemy/ext/asyncio/__pycache__/result.cpython-312.pyc,, -sqlalchemy/ext/asyncio/__pycache__/scoping.cpython-312.pyc,, -sqlalchemy/ext/asyncio/__pycache__/session.cpython-312.pyc,, -sqlalchemy/ext/asyncio/base.py,sha256=40VvRDZqVW_WQ1o-CRaB4c8Zx37rmiLGfQm4PNXWwdQ,9033 -sqlalchemy/ext/asyncio/engine.py,sha256=mMuD_Yq-BdVR5gUchSQzR1TI6mkov9bhtlqFnhvntdI,48321 -sqlalchemy/ext/asyncio/exc.py,sha256=npijuILDXH2p4Q5RzhHzutKwZ5CjtqTcP-U0h9TZUmk,639 -sqlalchemy/ext/asyncio/result.py,sha256=SqG9K9ar9AhzDQDIzt6tu60SoBu63uY1Hlzc7k1GtKQ,30548 -sqlalchemy/ext/asyncio/scoping.py,sha256=5DDH3Ne54yYLHIGaWVxS390JlHn0h3OvH5pj-dGrW_s,52570 -sqlalchemy/ext/asyncio/session.py,sha256=BzwqmXGEdT4K9WMxM6SO_d_xq9eCIatD4yl30nUSybk,63743 -sqlalchemy/ext/automap.py,sha256=n88mktqvExwjqfsDu3yLIA4wbOIWUpQ1S35Uw3X6ffQ,61675 -sqlalchemy/ext/baked.py,sha256=w3SeRoqnPkIhPL2nRAxfVhyir2ypsiW4kmtmUGKs8qo,17753 -sqlalchemy/ext/compiler.py,sha256=f7o4qhUUldpsx4F1sQoUvdVaT2BhiemqNBCF4r_uQUo,20889 -sqlalchemy/ext/declarative/__init__.py,sha256=SuVflXOGDxx2sB2QSTqNEvqS0fyhOkh3-sy2lRsSOLA,1818 -sqlalchemy/ext/declarative/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/ext/declarative/__pycache__/extensions.cpython-312.pyc,, -sqlalchemy/ext/declarative/extensions.py,sha256=yHUPcztU-5E1JrNyELDFWKchAnaYK6Y9-dLcqyc1nUI,19531 -sqlalchemy/ext/horizontal_shard.py,sha256=vouIehpQAuwT0HXyWyynTL3m_gcBuLcB-X8lDB0uQ8U,16691 -sqlalchemy/ext/hybrid.py,sha256=DkvNGtiQYzlEBvs1rYEDXhM8vJEXXh_6DMigsHH9w4k,52531 -sqlalchemy/ext/indexable.py,sha256=AfRoQgBWUKfTxx4jnRaQ97ex8k2FsJLQqc2eKK3ps-k,11066 -sqlalchemy/ext/instrumentation.py,sha256=iCp89rvfK7buW0jJyzKTBDKyMsd06oTRJDItOk4OVSw,15707 -sqlalchemy/ext/mutable.py,sha256=J8ix6T51DkVfr9XDe93Md_92Zf6tzXmdEjMiyORX90E,37603 -sqlalchemy/ext/mypy/__init__.py,sha256=yVNtoBDNeTl1sqRoA_fSY3o1g6M8NxqUVvAHPRLmFTw,241 -sqlalchemy/ext/mypy/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/ext/mypy/__pycache__/apply.cpython-312.pyc,, -sqlalchemy/ext/mypy/__pycache__/decl_class.cpython-312.pyc,, -sqlalchemy/ext/mypy/__pycache__/infer.cpython-312.pyc,, -sqlalchemy/ext/mypy/__pycache__/names.cpython-312.pyc,, -sqlalchemy/ext/mypy/__pycache__/plugin.cpython-312.pyc,, -sqlalchemy/ext/mypy/__pycache__/util.cpython-312.pyc,, -sqlalchemy/ext/mypy/apply.py,sha256=v_Svc1WiBz9yBXqBVBKoCuPGN286TfVmuuCVZPlbyzo,10591 -sqlalchemy/ext/mypy/decl_class.py,sha256=Nuca4ofHkASAkdqEQlULYB7iLm_KID7Mp384seDhVGg,17384 -sqlalchemy/ext/mypy/infer.py,sha256=29vgn22Hi8E8oIZL6UJCBl6oipiPSAQjxccCEkVb410,19367 -sqlalchemy/ext/mypy/names.py,sha256=_Q7J_F8KBSMHcVRw746fsosSJ3RAdDL6RpGAuGa-XJA,10480 -sqlalchemy/ext/mypy/plugin.py,sha256=9YHBp0Bwo92DbDZIUWwIr0hwXPcE4XvHs0-xshvSwUw,9750 -sqlalchemy/ext/mypy/util.py,sha256=CuW2fJ-g9YtkjcypzmrPRaFc-rAvQTzW5A2-w5VTANg,9960 -sqlalchemy/ext/orderinglist.py,sha256=LDHIRpMbl8w0mjDuz6phjnWhApmLRU0PrqouVUDTu-I,15163 -sqlalchemy/ext/serializer.py,sha256=_z95wZMTn3G3sCGN52gwzD4CuKjrhGMr5Eu8g9MxQNg,6169 -sqlalchemy/future/__init__.py,sha256=R1h8VBwMiIUdP3QHv_tFNby557425FJOAGhUoXGvCmc,512 -sqlalchemy/future/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/future/__pycache__/engine.cpython-312.pyc,, -sqlalchemy/future/engine.py,sha256=2nJFBQAXAE8pqe1cs-D3JjC6wUX2ya2h2e_tniuaBq0,495 -sqlalchemy/inspection.py,sha256=qKEKG37N1OjxpQeVzob1q9VwWjBbjI1x0movJG7fYJ4,5063 
-sqlalchemy/log.py,sha256=e_ztNUfZM08FmTWeXN9-doD5YKW44nXxgKCUxxNs6Ow,8607 -sqlalchemy/orm/__init__.py,sha256=BICvTXpLaTNe2AiUaxnZHWzjL5miT9fd_IU-ip3OFNk,8463 -sqlalchemy/orm/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/_orm_constructors.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/_typing.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/attributes.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/base.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/bulk_persistence.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/clsregistry.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/collections.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/context.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/decl_api.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/decl_base.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/dependency.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/descriptor_props.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/dynamic.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/evaluator.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/events.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/exc.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/identity.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/instrumentation.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/interfaces.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/loading.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/mapped_collection.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/mapper.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/path_registry.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/persistence.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/properties.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/query.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/relationships.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/scoping.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/session.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/state.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/state_changes.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/strategies.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/strategy_options.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/sync.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/unitofwork.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/util.cpython-312.pyc,, -sqlalchemy/orm/__pycache__/writeonly.cpython-312.pyc,, -sqlalchemy/orm/_orm_constructors.py,sha256=0pVhF06N8RHm3P418xpkZOBwKtrUsY7sQI2xz0f8zT4,105600 -sqlalchemy/orm/_typing.py,sha256=vaYRl4_K3n-sjc9u0Rb4eWWpBOoOi92--OHqaGogRvA,4973 -sqlalchemy/orm/attributes.py,sha256=oh9lKob8z-wChCQuAnW6MokQcaah6x9mNQI9_jbAX7Q,93117 -sqlalchemy/orm/base.py,sha256=J8rTiYm2xTyjTCJdSzaZRh8zasOiIK9FVXtFUits8AU,27501 -sqlalchemy/orm/bulk_persistence.py,sha256=evxOQKnfLRaByNXkudFyH8uFPmtVlCjP80CiIT4Lyb8,72984 -sqlalchemy/orm/clsregistry.py,sha256=-ZD3iO6qXropVH3gSf1nouKWG_xwMl_z5SE6sqOaYOA,17952 -sqlalchemy/orm/collections.py,sha256=cIoXIagPBv4B-TQN7BJssGwQcU0SgEhnKa6wLWsitys,52281 -sqlalchemy/orm/context.py,sha256=9OOJxvXJ_01Sd5-wny-WqVGtak4IA78TyLG_zMOHYmA,115082 -sqlalchemy/orm/decl_api.py,sha256=ViRNRYA1jXcxJCX2UPW7ugymozqbV55WbIj1c96XPpQ,65038 -sqlalchemy/orm/decl_base.py,sha256=N13zJJ0Yejcwu0yOWz8WI38ab56WTeHioYr2PlRCal0,83486 -sqlalchemy/orm/dependency.py,sha256=eiYTsSnW94uGXEFQWj6-KFn25ivz_a2dPN3P6_nMou4,47619 -sqlalchemy/orm/descriptor_props.py,sha256=dh97zKu5-OHDNEhHA3H2YHwdpT8wVT06faeHDzED4pk,37795 -sqlalchemy/orm/dynamic.py,sha256=Z4GpcVL8rM8gi0bytQOZXw-_kKi-sExbRWGjU30dK3g,9816 -sqlalchemy/orm/evaluator.py,sha256=PKrUW1zEOvmv1XEgc_hBdYqNcyk4zjWr_rJhCEQBFIc,12353 
-sqlalchemy/orm/events.py,sha256=rdqxmaiaZ7MZ5LQwY5cz6irLkGpJzr1C66zkTsW-QgA,127780 -sqlalchemy/orm/exc.py,sha256=V7cUPl9Kw4qZHLyjOvU1C5WMJ-0MKpNN10qM0C0YG5Y,7636 -sqlalchemy/orm/identity.py,sha256=5NFtF9ZPZWAOmtOqCPyVX2-_pQq9A5XeN2ns3Wirpv8,9249 -sqlalchemy/orm/instrumentation.py,sha256=WhElvvOWOn3Fuc-Asc5HmcKDX6EzFtBleLJKPZEc5A0,24321 -sqlalchemy/orm/interfaces.py,sha256=C0RL0aOVB7E14EVp7MD9C55F2yrOfuOMZ0X-oZg3FCg,49072 -sqlalchemy/orm/loading.py,sha256=SMv9Q5bC-kdvsBpOqBNGqNWlL3I75fxByUeEpLC3qtg,58488 -sqlalchemy/orm/mapped_collection.py,sha256=FAqaTlOUCYqdws2KR_fW0T8mMWIrLuAxJGU5f4W1aGs,19682 -sqlalchemy/orm/mapper.py,sha256=-7q3rHqj3x_acv6prq3sDEXZmHx7kGSV9G-gW_JwaX4,171834 -sqlalchemy/orm/path_registry.py,sha256=tRk3osC5BmU7kkcKJCeeibpg2witjyVzO0rX0pu8vmc,25914 -sqlalchemy/orm/persistence.py,sha256=laKaHW7XsVDYhXfDLnxqAJ5lPB8vhUZ0lEhLvtx-fb4,61812 -sqlalchemy/orm/properties.py,sha256=yXxd40V25FIF9vSEev-AxH58yZie8mZMCGQtgFmoUe8,30127 -sqlalchemy/orm/query.py,sha256=hPLslLL50lThw--5G8l3GtPgEdIY07hqIDOEO-0-wT8,118724 -sqlalchemy/orm/relationships.py,sha256=t3yqixZ41chMVOnmelNaps7jwj5vwN9dZFSB0gKK9Pw,128763 -sqlalchemy/orm/scoping.py,sha256=I_-BL8xAFQsZraFtA1wf5wgZ1WywBwBk-9OwiSAjPTM,78600 -sqlalchemy/orm/session.py,sha256=tNdUDRhTx0qFB6cCbnORatW4aWoNfJKuxNwch4KTd3E,195877 -sqlalchemy/orm/state.py,sha256=1vtlz674sGFmwZ8Ih9TdrslA-0nhU2G52WgV-FoG2j0,37670 -sqlalchemy/orm/state_changes.py,sha256=al74Ymt3vqqtWfzZUHQhIKmBZXbT1ovLxgfDurW6XRc,6813 -sqlalchemy/orm/strategies.py,sha256=zk2sg-5D05dBJlzEzpLD5Sfnd5WcCH6dDm4-bxZdMKI,119803 -sqlalchemy/orm/strategy_options.py,sha256=6QFEsOoOsyP2yNJHiJ4j9urfwQxfHFuSVJpoD9TxHcA,85627 -sqlalchemy/orm/sync.py,sha256=RdoxnhvgNjn3Lhtoq4QjvXpj8qfOz__wyibh0FMON0A,5779 -sqlalchemy/orm/unitofwork.py,sha256=hkSIcVonoSt0WWHk019bCDEw0g2o2fg4m4yqoTGyAoo,27033 -sqlalchemy/orm/util.py,sha256=t7lHq0-2FdSpPT558v674-6j9j4DTCmWTOI9xbDy3nY,80889 -sqlalchemy/orm/writeonly.py,sha256=OmFqL9SaJxgZkuvISHwa5WZlipMf3X6t5UJPDwxv_pA,22225 -sqlalchemy/pool/__init__.py,sha256=niqzCv2uOZT07DOiV2inlmjrW3lZyqDXGCjnOl1IqJ4,1804 -sqlalchemy/pool/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/pool/__pycache__/base.cpython-312.pyc,, -sqlalchemy/pool/__pycache__/events.cpython-312.pyc,, -sqlalchemy/pool/__pycache__/impl.cpython-312.pyc,, -sqlalchemy/pool/base.py,sha256=_UnrUVppwH0gBkiqPWPcxh1FgU4rjEsCDuCBBw73uAg,52383 -sqlalchemy/pool/events.py,sha256=wdFfvat0fSrVF84Zzsz5E3HnVY0bhL7MPsGME-b2qa8,13149 -sqlalchemy/pool/impl.py,sha256=2cg6RVfaXHOH-JPvJx0ITN-xDvjNP-eokhmqpDjsBgE,18899 -sqlalchemy/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -sqlalchemy/schema.py,sha256=huwl6-8J9j8ZkMiV3ISminNA7BPa8GrYmdX-q4Lvy9M,3251 -sqlalchemy/sql/__init__.py,sha256=Y-bZ25Zf-bxqsF2zUkpRGTjFuozNNVQHxUJV3Qmaq2M,5820 -sqlalchemy/sql/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/_dml_constructors.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/_elements_constructors.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/_orm_types.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/_py_util.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/_selectable_constructors.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/_typing.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/annotation.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/base.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/cache_key.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/coercions.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/compiler.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/crud.cpython-312.pyc,, 
-sqlalchemy/sql/__pycache__/ddl.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/default_comparator.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/dml.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/elements.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/events.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/expression.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/functions.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/lambdas.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/naming.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/operators.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/roles.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/schema.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/selectable.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/sqltypes.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/traversals.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/type_api.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/util.cpython-312.pyc,, -sqlalchemy/sql/__pycache__/visitors.cpython-312.pyc,, -sqlalchemy/sql/_dml_constructors.py,sha256=JF_XucNTfAk6Vz9fYiPWOgpIGtUkDj6VPILysLcrVhk,3795 -sqlalchemy/sql/_elements_constructors.py,sha256=0fOsjr_UVUnpJJyP7FL0dd1-tqcqIU5uc0vsNfPNApo,63096 -sqlalchemy/sql/_orm_types.py,sha256=0zeMit-V4rYZe-bB9X3xugnjFnPXH0gmeqkJou9Fows,625 -sqlalchemy/sql/_py_util.py,sha256=4KFXNvBq3hhfrr-A1J1uBml3b3CGguIf1dat9gsEHqE,2173 -sqlalchemy/sql/_selectable_constructors.py,sha256=2xSSQEkjhsOim8nvuzQgSN_jpfKdJM9_jVNR91n-wuM,22171 -sqlalchemy/sql/_typing.py,sha256=lV12dX4kWMC1IIEyD3fgOJo_plMq0-qfE5h_oiQzTuQ,13029 -sqlalchemy/sql/annotation.py,sha256=qHUEwbdmMD3Ybr0ez-Dyiw9l9UB_RUMHWAUIeO_r3gE,18245 -sqlalchemy/sql/base.py,sha256=lwxhzQumtS7GA0Hb7v3TgUT9pbwELEkGoyj9XqRcS2Y,75859 -sqlalchemy/sql/cache_key.py,sha256=hnOYFbU_vmtpqorW-dE1Z9h_CK_Yi_3YXZpOAp30ZbM,33653 -sqlalchemy/sql/coercions.py,sha256=8jZUTu7NqukXTVvz9jqJ7Pr3u762qrP2AUVgmOgoUTc,40705 -sqlalchemy/sql/compiler.py,sha256=63-a8RYtgbU-UKDLerrMidaZvRUqmsT7H_4fS0PZ4qc,283319 -sqlalchemy/sql/crud.py,sha256=zfJdQsRZgAwxcxmo4-WjhgxJKpJ7FRoAAuZ7NgNNUx0,59455 -sqlalchemy/sql/ddl.py,sha256=6Za5sdcpC2D0rJ7_tPSnyp6XR-B0zaDR6MCn032g0eE,47993 -sqlalchemy/sql/default_comparator.py,sha256=YL0lb3TGlmfoUfcMWEo5FkvBQVPa1ZnDcYxoUq97f_4,16706 -sqlalchemy/sql/dml.py,sha256=hUubKQK2dT91uMXyWuK1OpdJ6L4R_VyBw_rKH82lt7U,66232 -sqlalchemy/sql/elements.py,sha256=E0lCoqQJbWwQ34xdhdxGXqGcFgrvla_xrnSpWgs4Uwo,178317 -sqlalchemy/sql/events.py,sha256=iWjc_nm1vClDBLg4ZhDnY75CkBdnlDPSPe0MGBSmbiM,18312 -sqlalchemy/sql/expression.py,sha256=CsOkmAQgaB-Rnwe7eK60FdBC5R9kY5pczCGrVw2BwGs,7583 -sqlalchemy/sql/functions.py,sha256=DQkV7asOlWaBtFTqRIC663oNkloy5EUhHexjo87GtUY,64826 -sqlalchemy/sql/lambdas.py,sha256=W5b75ojie3EOm7poR27qsnQHQYdz-NxfSrgb5ATT2H0,49401 -sqlalchemy/sql/naming.py,sha256=5Tk6nm4xqy8d9gzXzDvdiqqS7IptUaf1d7IuVdslplU,6855 -sqlalchemy/sql/operators.py,sha256=h5bgu31gukGdsYsN_0-1C7IGAdSCFpBxuRjOUnu1Two,76792 -sqlalchemy/sql/roles.py,sha256=drAeWbevjgFAKNcMrH_EuJ-9sSvcq4aeXwAqMXXZGYw,7662 -sqlalchemy/sql/schema.py,sha256=UW3cJhz8YhdGNp5VuUcFy0qVkGpbwmgj7ejdyklSr4s,230401 -sqlalchemy/sql/selectable.py,sha256=5L3itqHaRCyd7isvo3VE32jyajdV8VZQ7ybnzWgmu14,242155 -sqlalchemy/sql/sqltypes.py,sha256=kMNNxP0z3xfK8OeZCI4wMsexAN07O31O1Wj6uaFNzdk,132156 -sqlalchemy/sql/traversals.py,sha256=7GALHt5mFceUv2SMUikIdAb9SUcSbACqhwoei5rPkxc,33664 -sqlalchemy/sql/type_api.py,sha256=ZaRtirCvkY2-LOv2TeRFX8r8aVOl5fZhplLWBqexctE,85425 -sqlalchemy/sql/util.py,sha256=NSyop8VMFspSPhnUeTc6-ffWEnBgS12FasZKSo-e1-w,48110 
-sqlalchemy/sql/visitors.py,sha256=nMK_ddPg4NvEhEgKorD0rGoy-jqs-dT-uou-S8HAEyY,36316 -sqlalchemy/testing/__init__.py,sha256=GgUEqxUNCxg-92_GgBDnljUHsdCxaGPMG1TWy5tjwgk,3160 -sqlalchemy/testing/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/testing/__pycache__/assertions.cpython-312.pyc,, -sqlalchemy/testing/__pycache__/assertsql.cpython-312.pyc,, -sqlalchemy/testing/__pycache__/asyncio.cpython-312.pyc,, -sqlalchemy/testing/__pycache__/config.cpython-312.pyc,, -sqlalchemy/testing/__pycache__/engines.cpython-312.pyc,, -sqlalchemy/testing/__pycache__/entities.cpython-312.pyc,, -sqlalchemy/testing/__pycache__/exclusions.cpython-312.pyc,, -sqlalchemy/testing/__pycache__/pickleable.cpython-312.pyc,, -sqlalchemy/testing/__pycache__/profiling.cpython-312.pyc,, -sqlalchemy/testing/__pycache__/provision.cpython-312.pyc,, -sqlalchemy/testing/__pycache__/requirements.cpython-312.pyc,, -sqlalchemy/testing/__pycache__/schema.cpython-312.pyc,, -sqlalchemy/testing/__pycache__/util.cpython-312.pyc,, -sqlalchemy/testing/__pycache__/warnings.cpython-312.pyc,, -sqlalchemy/testing/assertions.py,sha256=9FLeP4Q5nPCP-NAVutOse9ej0SD1uEGtW5YKIy8s5dA,31564 -sqlalchemy/testing/assertsql.py,sha256=cmhtZrgPBjrqIfzFz3VBWxVNvxWoRllvmoWcUCoqsio,16817 -sqlalchemy/testing/asyncio.py,sha256=QsMzDWARFRrpLoWhuYqzYQPTUZ80fymlKrqOoDkmCmQ,3830 -sqlalchemy/testing/config.py,sha256=HySdB5_FgCW1iHAJVxYo-4wq5gUAEi0N8E93IC6M86Q,12058 -sqlalchemy/testing/engines.py,sha256=c1gFXfpo5S1dvNjGIL03mbW2eVYtUD_9M_ZEfQO2ArM,13414 -sqlalchemy/testing/entities.py,sha256=KdgTVPSALhi9KkAXj2giOYl62ld-1yZziIDBSV8E3vw,3354 -sqlalchemy/testing/exclusions.py,sha256=0Byf3DIMQXN0-HOS6M2MPJ-fOm_n5MzE1yIfHgE0nLs,12473 -sqlalchemy/testing/fixtures/__init__.py,sha256=e5YtfSlkKDRuyIZhEKBCycMX5BOO4MZ-0d97l1JDhJE,1198 -sqlalchemy/testing/fixtures/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/testing/fixtures/__pycache__/base.cpython-312.pyc,, -sqlalchemy/testing/fixtures/__pycache__/mypy.cpython-312.pyc,, -sqlalchemy/testing/fixtures/__pycache__/orm.cpython-312.pyc,, -sqlalchemy/testing/fixtures/__pycache__/sql.cpython-312.pyc,, -sqlalchemy/testing/fixtures/base.py,sha256=n1wws2ziMfP5CcmKx1R-1bFitUDvIAjJH0atWKMI5Oc,12385 -sqlalchemy/testing/fixtures/mypy.py,sha256=tzCaKeO6SX_6uhdBFrKo6iBB7abdZxhyj7SFUlRQINc,12755 -sqlalchemy/testing/fixtures/orm.py,sha256=3JJoYdI2tj5-LL7AN8bVa79NV3Guo4d9p6IgheHkWGc,6095 -sqlalchemy/testing/fixtures/sql.py,sha256=ht-OD6fMZ0inxucRzRZG4kEMNicqY8oJdlKbZzHhAJc,15900 -sqlalchemy/testing/pickleable.py,sha256=G3L0xL9OtbX7wThfreRjWd0GW7q0kUKcTUuCN5ETGno,2833 -sqlalchemy/testing/plugin/__init__.py,sha256=vRfF7M763cGm9tLQDWK6TyBNHc80J1nX2fmGGxN14wY,247 -sqlalchemy/testing/plugin/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/testing/plugin/__pycache__/bootstrap.cpython-312.pyc,, -sqlalchemy/testing/plugin/__pycache__/plugin_base.cpython-312.pyc,, -sqlalchemy/testing/plugin/__pycache__/pytestplugin.cpython-312.pyc,, -sqlalchemy/testing/plugin/bootstrap.py,sha256=VYnVSMb-u30hGY6xGn6iG-LqiF0CubT90AJPFY_6UiY,1685 -sqlalchemy/testing/plugin/plugin_base.py,sha256=TBWdg2XgXB6QgUUFdKLv1O9-SXMitjHLm2rNNIzXZhQ,21578 -sqlalchemy/testing/plugin/pytestplugin.py,sha256=X49CojfNqAPSqBjzYZb6lLxj_Qxz37-onCYBI6-xOCk,27624 -sqlalchemy/testing/profiling.py,sha256=SWhWiZImJvDsNn0rQyNki70xdNxZL53ZI98ihxiykbQ,10148 -sqlalchemy/testing/provision.py,sha256=6r2FTnm-t7u8MMbWo7eMhAH3qkL0w0WlmE29MUSEIu4,14702 -sqlalchemy/testing/requirements.py,sha256=3u8lfzSOLE-_QUD6iHkhzRRbXDyEucmz2T8VRO8QG08,55757 
-sqlalchemy/testing/schema.py,sha256=IImFumAdpzOyoKAs0WnaGakq8D3sSU4snD9W4LVOV3s,6513 -sqlalchemy/testing/suite/__init__.py,sha256=S8TLwTiif8xX67qlZUo5I9fl9UjZAFGSzvlptp2WoWc,722 -sqlalchemy/testing/suite/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/testing/suite/__pycache__/test_cte.cpython-312.pyc,, -sqlalchemy/testing/suite/__pycache__/test_ddl.cpython-312.pyc,, -sqlalchemy/testing/suite/__pycache__/test_deprecations.cpython-312.pyc,, -sqlalchemy/testing/suite/__pycache__/test_dialect.cpython-312.pyc,, -sqlalchemy/testing/suite/__pycache__/test_insert.cpython-312.pyc,, -sqlalchemy/testing/suite/__pycache__/test_reflection.cpython-312.pyc,, -sqlalchemy/testing/suite/__pycache__/test_results.cpython-312.pyc,, -sqlalchemy/testing/suite/__pycache__/test_rowcount.cpython-312.pyc,, -sqlalchemy/testing/suite/__pycache__/test_select.cpython-312.pyc,, -sqlalchemy/testing/suite/__pycache__/test_sequence.cpython-312.pyc,, -sqlalchemy/testing/suite/__pycache__/test_types.cpython-312.pyc,, -sqlalchemy/testing/suite/__pycache__/test_unicode_ddl.cpython-312.pyc,, -sqlalchemy/testing/suite/__pycache__/test_update_delete.cpython-312.pyc,, -sqlalchemy/testing/suite/test_cte.py,sha256=_GnADXRnhm37RdSRBR5SthQenTeb5VVo3HoCuO0Vifw,7262 -sqlalchemy/testing/suite/test_ddl.py,sha256=MItp-votCzvahlRqHRagte2Omyq9XUOFdFsgzCb6_-g,12031 -sqlalchemy/testing/suite/test_deprecations.py,sha256=7C6IbxRmq7wg_DLq56f1V5RCS9iVrAv3epJZQTB-dOo,5337 -sqlalchemy/testing/suite/test_dialect.py,sha256=j3srr7k2aUd_kPtJPgqI1g1aYD6ko4MvuGu1a1HQgS8,24215 -sqlalchemy/testing/suite/test_insert.py,sha256=pR0VWMQ9JJPbnANE6634PzR0VFmWMF8im6OTahc4vsQ,18824 -sqlalchemy/testing/suite/test_reflection.py,sha256=oRqwm8ZUjDdXcE3mooIg5513FpNiwEl76IoJaa_aK-Q,114101 -sqlalchemy/testing/suite/test_results.py,sha256=S7Vqqh_Wuqf7uhM8h0cBVeV1GS5GJRO_ZTVYmT7kwuc,17042 -sqlalchemy/testing/suite/test_rowcount.py,sha256=UVyHHQsU0TxkzV_dqCOKR1aROvIq7frKYMVjwUqLWfE,7900 -sqlalchemy/testing/suite/test_select.py,sha256=U6WHUBzko_x6dK32PCXY7-5xN9j0VuAS5z3C-zjDE8I,62041 -sqlalchemy/testing/suite/test_sequence.py,sha256=DMqyJkL1o4GClrNjzoy7GDn_jPNPTZNvk9t5e-MVXeo,9923 -sqlalchemy/testing/suite/test_types.py,sha256=C3wJn3DGlGf58eNr02SoYR3iFAl-vnnHPJS_SSWIu80,68013 -sqlalchemy/testing/suite/test_unicode_ddl.py,sha256=0zVc2e3zbCQag_xL4b0i7F062HblHwV46JHLMweYtcE,6141 -sqlalchemy/testing/suite/test_update_delete.py,sha256=_OxH0wggHUqPImalGEPI48RiRx6mO985Om1PtRYOCzA,3994 -sqlalchemy/testing/util.py,sha256=BuA4q-8cmNhrUVqPP35Rr15MnYGSjmW0hmUdS1SI0_I,14526 -sqlalchemy/testing/warnings.py,sha256=sj4vfTtjodcfoX6FPH_Zykb4fomjmgqIYj81QPpSwH8,1546 -sqlalchemy/types.py,sha256=Iq_rKisaj_zhHtzD2R2cxvg3jkug5frikbkcKG0S4Lg,3166 -sqlalchemy/util/__init__.py,sha256=fAnlZil8ImzO2ZQghrQ-S2H1PO1ViKPaJcI3LD8bMUk,8314 -sqlalchemy/util/__pycache__/__init__.cpython-312.pyc,, -sqlalchemy/util/__pycache__/_collections.cpython-312.pyc,, -sqlalchemy/util/__pycache__/_concurrency_py3k.cpython-312.pyc,, -sqlalchemy/util/__pycache__/_has_cy.cpython-312.pyc,, -sqlalchemy/util/__pycache__/_py_collections.cpython-312.pyc,, -sqlalchemy/util/__pycache__/compat.cpython-312.pyc,, -sqlalchemy/util/__pycache__/concurrency.cpython-312.pyc,, -sqlalchemy/util/__pycache__/deprecations.cpython-312.pyc,, -sqlalchemy/util/__pycache__/langhelpers.cpython-312.pyc,, -sqlalchemy/util/__pycache__/preloaded.cpython-312.pyc,, -sqlalchemy/util/__pycache__/queue.cpython-312.pyc,, -sqlalchemy/util/__pycache__/tool_support.cpython-312.pyc,, -sqlalchemy/util/__pycache__/topological.cpython-312.pyc,, 
-sqlalchemy/util/__pycache__/typing.cpython-312.pyc,, -sqlalchemy/util/_collections.py,sha256=JQkGm3MBq3RWr5WKG1-SwocPK3PwQHNslW8QqT7CAq0,20151 -sqlalchemy/util/_concurrency_py3k.py,sha256=UtPDkb67OOVWYvBqYaQgENg0k_jOA2mQOE04XmrbYq0,9170 -sqlalchemy/util/_has_cy.py,sha256=3oh7s5iQtW9qcI8zYunCfGAKG6fzo2DIpzP5p1BnE8Q,1247 -sqlalchemy/util/_py_collections.py,sha256=nxdOFQkO05ijXw-0u_InaH19pPj4VsFcat7tZNoIjt8,16650 -sqlalchemy/util/compat.py,sha256=ahh0y6bVwOTkT6CdRvxXFGXJSsDQL_RTPyT3AQjw9xo,8848 -sqlalchemy/util/concurrency.py,sha256=eQVS3YDH3GwB3Uw5pbzmqEBSYTK90EbnE5mQ05fHERg,3304 -sqlalchemy/util/deprecations.py,sha256=L7D4GqeIozpjO8iVybf7jL9dDlgfTbAaQH4TQAX74qE,12012 -sqlalchemy/util/langhelpers.py,sha256=veH0KW61Pz8hooiM9xMmTEzQqnjZ0KxBGdxW5Z_Rbtc,68371 -sqlalchemy/util/preloaded.py,sha256=RMarsuhtMW8ZuvqLSuR0kwbp45VRlzKpJMLUe7p__qY,5904 -sqlalchemy/util/queue.py,sha256=w1ufhuiC7lzyiZDhciRtRz1uyxU72jRI7SWhhL-p600,10185 -sqlalchemy/util/tool_support.py,sha256=e7lWu6o1QlKq4e6c9PyDsuyFyiWe79vO72UQ_YX2pUA,6135 -sqlalchemy/util/topological.py,sha256=tbkMRY0TTgNiq44NUJpnazXR4xb9v4Q4mQ8BygMp0vY,3451 -sqlalchemy/util/typing.py,sha256=iwyZIgOJUN2o9cRz8YTH093iY5iNvpXiDQG3pce0cc4,22466 diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/__init__.py index 0ff2665..fbb8d66 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/__init__.py @@ -1,5 +1,5 @@ -# __init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sqlalchemy/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -55,7 +55,7 @@ from .pool import Pool as Pool from .pool import PoolProxiedConnection as PoolProxiedConnection from .pool import PoolResetState as PoolResetState from .pool import QueuePool as QueuePool -from .pool import SingletonThreadPool as SingletonThreadPool +from .pool import SingletonThreadPool as SingleonThreadPool from .pool import StaticPool as StaticPool from .schema import BaseDDLElement as BaseDDLElement from .schema import BLANK_SCHEMA as BLANK_SCHEMA @@ -269,11 +269,13 @@ from .types import Uuid as Uuid from .types import VARBINARY as VARBINARY from .types import VARCHAR as VARCHAR -__version__ = "2.0.43" +__version__ = "2.0.23" def __go(lcls: Any) -> None: - _util.preloaded.import_prefix("sqlalchemy") + from . import util as _sa_util + + _sa_util.preloaded.import_prefix("sqlalchemy") from . 
import exc diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/connectors/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/connectors/__init__.py index 43cd103..1969d72 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/connectors/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/connectors/__init__.py @@ -1,5 +1,5 @@ # connectors/__init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/connectors/aioodbc.py b/venv/lib/python3.12/site-packages/sqlalchemy/connectors/aioodbc.py index 6e4b864..c698636 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/connectors/aioodbc.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/connectors/aioodbc.py @@ -1,5 +1,5 @@ # connectors/aioodbc.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -20,7 +20,6 @@ from .. import util from ..util.concurrency import await_fallback from ..util.concurrency import await_only - if TYPE_CHECKING: from ..engine.interfaces import ConnectArgsType from ..engine.url import URL @@ -59,15 +58,6 @@ class AsyncAdapt_aioodbc_connection(AsyncAdapt_dbapi_connection): self._connection._conn.autocommit = value - def ping(self, reconnect): - return self.await_(self._connection.ping(reconnect)) - - def add_output_converter(self, *arg, **kw): - self._connection.add_output_converter(*arg, **kw) - - def character_set_name(self): - return self._connection.character_set_name() - def cursor(self, server_side=False): # aioodbc sets connection=None when closed and just fails with # AttributeError here. 
Here we use the same ProgrammingError + @@ -180,5 +170,18 @@ class aiodbcConnector(PyODBCConnector): else: return pool.AsyncAdaptedQueuePool + def _do_isolation_level(self, connection, autocommit, isolation_level): + connection.set_autocommit(autocommit) + connection.set_isolation_level(isolation_level) + + def _do_autocommit(self, connection, value): + connection.set_autocommit(value) + + def set_readonly(self, connection, value): + connection.set_read_only(value) + + def set_deferrable(self, connection, value): + connection.set_deferrable(value) + def get_driver_connection(self, connection): return connection._connection diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/connectors/asyncio.py b/venv/lib/python3.12/site-packages/sqlalchemy/connectors/asyncio.py index fda21b6..997407c 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/connectors/asyncio.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/connectors/asyncio.py @@ -1,124 +1,22 @@ # connectors/asyncio.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors """generic asyncio-adapted versions of DBAPI connection and cursor""" from __future__ import annotations -import asyncio import collections -import sys -from typing import Any -from typing import AsyncIterator -from typing import Deque -from typing import Iterator -from typing import NoReturn -from typing import Optional -from typing import Sequence -from typing import TYPE_CHECKING +import itertools from ..engine import AdaptedConnection +from ..util.concurrency import asyncio from ..util.concurrency import await_fallback from ..util.concurrency import await_only -from ..util.typing import Protocol - -if TYPE_CHECKING: - from ..engine.interfaces import _DBAPICursorDescription - from ..engine.interfaces import _DBAPIMultiExecuteParams - from ..engine.interfaces import _DBAPISingleExecuteParams - from ..engine.interfaces import DBAPIModule - from ..util.typing import Self - - -class AsyncIODBAPIConnection(Protocol): - """protocol representing an async adapted version of a - :pep:`249` database connection. - - - """ - - # note that async DBAPIs dont agree if close() should be awaitable, - # so it is omitted here and picked up by the __getattr__ hook below - - async def commit(self) -> None: ... - - def cursor(self, *args: Any, **kwargs: Any) -> AsyncIODBAPICursor: ... - - async def rollback(self) -> None: ... - - def __getattr__(self, key: str) -> Any: ... - - def __setattr__(self, key: str, value: Any) -> None: ... - - -class AsyncIODBAPICursor(Protocol): - """protocol representing an async adapted version - of a :pep:`249` database cursor. - - - """ - - def __aenter__(self) -> Any: ... - - @property - def description( - self, - ) -> _DBAPICursorDescription: - """The description attribute of the Cursor.""" - ... - - @property - def rowcount(self) -> int: ... - - arraysize: int - - lastrowid: int - - async def close(self) -> None: ... - - async def execute( - self, - operation: Any, - parameters: Optional[_DBAPISingleExecuteParams] = None, - ) -> Any: ... - - async def executemany( - self, - operation: Any, - parameters: _DBAPIMultiExecuteParams, - ) -> Any: ... - - async def fetchone(self) -> Optional[Any]: ... - - async def fetchmany(self, size: Optional[int] = ...) -> Sequence[Any]: ... 
- - async def fetchall(self) -> Sequence[Any]: ... - - async def setinputsizes(self, sizes: Sequence[Any]) -> None: ... - - def setoutputsize(self, size: Any, column: Any) -> None: ... - - async def callproc( - self, procname: str, parameters: Sequence[Any] = ... - ) -> Any: ... - - async def nextset(self) -> Optional[bool]: ... - - def __aiter__(self) -> AsyncIterator[Any]: ... - - -class AsyncAdapt_dbapi_module: - if TYPE_CHECKING: - Error = DBAPIModule.Error - OperationalError = DBAPIModule.OperationalError - InterfaceError = DBAPIModule.InterfaceError - IntegrityError = DBAPIModule.IntegrityError - - def __getattr__(self, key: str) -> Any: ... class AsyncAdapt_dbapi_cursor: @@ -131,136 +29,99 @@ class AsyncAdapt_dbapi_cursor: "_rows", ) - _cursor: AsyncIODBAPICursor - _adapt_connection: AsyncAdapt_dbapi_connection - _connection: AsyncIODBAPIConnection - _rows: Deque[Any] - - def __init__(self, adapt_connection: AsyncAdapt_dbapi_connection): + def __init__(self, adapt_connection): self._adapt_connection = adapt_connection self._connection = adapt_connection._connection - self.await_ = adapt_connection.await_ - cursor = self._make_new_cursor(self._connection) - self._cursor = self._aenter_cursor(cursor) + cursor = self._connection.cursor() - if not self.server_side: - self._rows = collections.deque() - - def _aenter_cursor(self, cursor: AsyncIODBAPICursor) -> AsyncIODBAPICursor: - return self.await_(cursor.__aenter__()) # type: ignore[no-any-return] - - def _make_new_cursor( - self, connection: AsyncIODBAPIConnection - ) -> AsyncIODBAPICursor: - return connection.cursor() + self._cursor = self.await_(cursor.__aenter__()) + self._rows = collections.deque() @property - def description(self) -> Optional[_DBAPICursorDescription]: + def description(self): return self._cursor.description @property - def rowcount(self) -> int: + def rowcount(self): return self._cursor.rowcount @property - def arraysize(self) -> int: + def arraysize(self): return self._cursor.arraysize @arraysize.setter - def arraysize(self, value: int) -> None: + def arraysize(self, value): self._cursor.arraysize = value @property - def lastrowid(self) -> int: + def lastrowid(self): return self._cursor.lastrowid - def close(self) -> None: + def close(self): # note we aren't actually closing the cursor here, # we are just letting GC do it. 
see notes in aiomysql dialect self._rows.clear() - def execute( - self, - operation: Any, - parameters: Optional[_DBAPISingleExecuteParams] = None, - ) -> Any: - try: - return self.await_(self._execute_async(operation, parameters)) - except Exception as error: - self._adapt_connection._handle_exception(error) + def execute(self, operation, parameters=None): + return self.await_(self._execute_async(operation, parameters)) - def executemany( - self, - operation: Any, - seq_of_parameters: _DBAPIMultiExecuteParams, - ) -> Any: - try: - return self.await_( - self._executemany_async(operation, seq_of_parameters) - ) - except Exception as error: - self._adapt_connection._handle_exception(error) + def executemany(self, operation, seq_of_parameters): + return self.await_( + self._executemany_async(operation, seq_of_parameters) + ) - async def _execute_async( - self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams] - ) -> Any: + async def _execute_async(self, operation, parameters): async with self._adapt_connection._execute_mutex: - if parameters is None: - result = await self._cursor.execute(operation) - else: - result = await self._cursor.execute(operation, parameters) + result = await self._cursor.execute(operation, parameters or ()) if self._cursor.description and not self.server_side: + # aioodbc has a "fake" async result, so we have to pull it out + # of that here since our default result is not async. + # we could just as easily grab "_rows" here and be done with it + # but this is safer. self._rows = collections.deque(await self._cursor.fetchall()) return result - async def _executemany_async( - self, - operation: Any, - seq_of_parameters: _DBAPIMultiExecuteParams, - ) -> Any: + async def _executemany_async(self, operation, seq_of_parameters): async with self._adapt_connection._execute_mutex: return await self._cursor.executemany(operation, seq_of_parameters) - def nextset(self) -> None: + def nextset(self): self.await_(self._cursor.nextset()) if self._cursor.description and not self.server_side: self._rows = collections.deque( self.await_(self._cursor.fetchall()) ) - def setinputsizes(self, *inputsizes: Any) -> None: + def setinputsizes(self, *inputsizes): # NOTE: this is overrridden in aioodbc due to # see https://github.com/aio-libs/aioodbc/issues/451 # right now return self.await_(self._cursor.setinputsizes(*inputsizes)) - def __enter__(self) -> Self: - return self - - def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: - self.close() - - def __iter__(self) -> Iterator[Any]: + def __iter__(self): while self._rows: yield self._rows.popleft() - def fetchone(self) -> Optional[Any]: + def fetchone(self): if self._rows: return self._rows.popleft() else: return None - def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]: + def fetchmany(self, size=None): if size is None: size = self.arraysize - rr = self._rows - return [rr.popleft() for _ in range(min(size, len(rr)))] - def fetchall(self) -> Sequence[Any]: + rr = iter(self._rows) + retval = list(itertools.islice(rr, 0, size)) + self._rows = collections.deque(rr) + return retval + + def fetchall(self): retval = list(self._rows) self._rows.clear() return retval @@ -270,78 +131,75 @@ class AsyncAdapt_dbapi_ss_cursor(AsyncAdapt_dbapi_cursor): __slots__ = () server_side = True - def close(self) -> None: + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + + cursor = 
self._connection.cursor() + + self._cursor = self.await_(cursor.__aenter__()) + + def close(self): if self._cursor is not None: self.await_(self._cursor.close()) - self._cursor = None # type: ignore + self._cursor = None - def fetchone(self) -> Optional[Any]: + def fetchone(self): return self.await_(self._cursor.fetchone()) - def fetchmany(self, size: Optional[int] = None) -> Any: + def fetchmany(self, size=None): return self.await_(self._cursor.fetchmany(size=size)) - def fetchall(self) -> Sequence[Any]: + def fetchall(self): return self.await_(self._cursor.fetchall()) - def __iter__(self) -> Iterator[Any]: - iterator = self._cursor.__aiter__() - while True: - try: - yield self.await_(iterator.__anext__()) - except StopAsyncIteration: - break - class AsyncAdapt_dbapi_connection(AdaptedConnection): _cursor_cls = AsyncAdapt_dbapi_cursor _ss_cursor_cls = AsyncAdapt_dbapi_ss_cursor await_ = staticmethod(await_only) - __slots__ = ("dbapi", "_execute_mutex") - _connection: AsyncIODBAPIConnection - - def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection): + def __init__(self, dbapi, connection): self.dbapi = dbapi self._connection = connection self._execute_mutex = asyncio.Lock() - def cursor(self, server_side: bool = False) -> AsyncAdapt_dbapi_cursor: + def ping(self, reconnect): + return self.await_(self._connection.ping(reconnect)) + + def add_output_converter(self, *arg, **kw): + self._connection.add_output_converter(*arg, **kw) + + def character_set_name(self): + return self._connection.character_set_name() + + @property + def autocommit(self): + return self._connection.autocommit + + @autocommit.setter + def autocommit(self, value): + # https://github.com/aio-libs/aioodbc/issues/448 + # self._connection.autocommit = value + + self._connection._conn.autocommit = value + + def cursor(self, server_side=False): if server_side: return self._ss_cursor_cls(self) else: return self._cursor_cls(self) - def execute( - self, - operation: Any, - parameters: Optional[_DBAPISingleExecuteParams] = None, - ) -> Any: - """lots of DBAPIs seem to provide this, so include it""" - cursor = self.cursor() - cursor.execute(operation, parameters) - return cursor + def rollback(self): + self.await_(self._connection.rollback()) - def _handle_exception(self, error: Exception) -> NoReturn: - exc_info = sys.exc_info() + def commit(self): + self.await_(self._connection.commit()) - raise error.with_traceback(exc_info[2]) - - def rollback(self) -> None: - try: - self.await_(self._connection.rollback()) - except Exception as error: - self._handle_exception(error) - - def commit(self) -> None: - try: - self.await_(self._connection.commit()) - except Exception as error: - self._handle_exception(error) - - def close(self) -> None: + def close(self): self.await_(self._connection.close()) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/connectors/pyodbc.py b/venv/lib/python3.12/site-packages/sqlalchemy/connectors/pyodbc.py index dee2616..49712a5 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/connectors/pyodbc.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/connectors/pyodbc.py @@ -1,5 +1,5 @@ # connectors/pyodbc.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -8,6 +8,7 @@ from __future__ import annotations import re +from types import ModuleType import typing from typing import Any from typing import Dict @@ -28,7 +29,6 @@ from ..engine 
import URL from ..sql.type_api import TypeEngine if typing.TYPE_CHECKING: - from ..engine.interfaces import DBAPIModule from ..engine.interfaces import IsolationLevel @@ -48,13 +48,15 @@ class PyODBCConnector(Connector): # hold the desired driver name pyodbc_driver_name: Optional[str] = None + dbapi: ModuleType + def __init__(self, use_setinputsizes: bool = False, **kw: Any): super().__init__(**kw) if use_setinputsizes: self.bind_typing = interfaces.BindTyping.SETINPUTSIZES @classmethod - def import_dbapi(cls) -> DBAPIModule: + def import_dbapi(cls) -> ModuleType: return __import__("pyodbc") def create_connect_args(self, url: URL) -> ConnectArgsType: @@ -148,7 +150,7 @@ class PyODBCConnector(Connector): ], cursor: Optional[interfaces.DBAPICursor], ) -> bool: - if isinstance(e, self.loaded_dbapi.ProgrammingError): + if isinstance(e, self.dbapi.ProgrammingError): return "The cursor's connection has been closed." in str( e ) or "Attempt to use a closed connection." in str(e) @@ -215,19 +217,19 @@ class PyODBCConnector(Connector): cursor.setinputsizes( [ - ( - (dbtype, None, None) - if not isinstance(dbtype, tuple) - else dbtype - ) + (dbtype, None, None) + if not isinstance(dbtype, tuple) + else dbtype for key, dbtype, sqltype in list_of_tuples ] ) def get_isolation_level_values( - self, dbapi_conn: interfaces.DBAPIConnection + self, dbapi_connection: interfaces.DBAPIConnection ) -> List[IsolationLevel]: - return [*super().get_isolation_level_values(dbapi_conn), "AUTOCOMMIT"] + return super().get_isolation_level_values(dbapi_connection) + [ + "AUTOCOMMIT" + ] def set_isolation_level( self, @@ -243,8 +245,3 @@ class PyODBCConnector(Connector): else: dbapi_connection.autocommit = False super().set_isolation_level(dbapi_connection, level) - - def detect_autocommit_setting( - self, dbapi_conn: interfaces.DBAPIConnection - ) -> bool: - return bool(dbapi_conn.autocommit) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/__init__.py index cb8dc2c..e69de29 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/__init__.py @@ -1,6 +0,0 @@ -# cyextension/__init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/collections.cpython-312-x86_64-linux-gnu.so b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/collections.cpython-312-x86_64-linux-gnu.so index 3bd3dee..9a70340 100755 Binary files a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/collections.cpython-312-x86_64-linux-gnu.so and b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/collections.cpython-312-x86_64-linux-gnu.so differ diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/collections.pyx b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/collections.pyx index 86d2485..4d134cc 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/collections.pyx +++ b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/collections.pyx @@ -1,9 +1,3 @@ -# cyextension/collections.pyx -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php 
cimport cython from cpython.long cimport PyLong_FromLongLong from cpython.set cimport PySet_Add diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/immutabledict.cpython-312-x86_64-linux-gnu.so b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/immutabledict.cpython-312-x86_64-linux-gnu.so index 325bf9d..acfee19 100755 Binary files a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/immutabledict.cpython-312-x86_64-linux-gnu.so and b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/immutabledict.cpython-312-x86_64-linux-gnu.so differ diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/immutabledict.pxd b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/immutabledict.pxd index 76f2289..fe7ad6a 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/immutabledict.pxd +++ b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/immutabledict.pxd @@ -1,8 +1,2 @@ -# cyextension/immutabledict.pxd -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php cdef class immutabledict(dict): pass diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/immutabledict.pyx b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/immutabledict.pyx index b37eccc..100287b 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/immutabledict.pyx +++ b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/immutabledict.pyx @@ -1,9 +1,3 @@ -# cyextension/immutabledict.pyx -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php from cpython.dict cimport PyDict_New, PyDict_Update, PyDict_Size diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/processors.cpython-312-x86_64-linux-gnu.so b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/processors.cpython-312-x86_64-linux-gnu.so index 9382f23..90b59ea 100755 Binary files a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/processors.cpython-312-x86_64-linux-gnu.so and b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/processors.cpython-312-x86_64-linux-gnu.so differ diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/processors.pyx b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/processors.pyx index 3d71456..b0ad865 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/processors.pyx +++ b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/processors.pyx @@ -1,9 +1,3 @@ -# cyextension/processors.pyx -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php import datetime from datetime import datetime as datetime_cls from datetime import time as time_cls diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/resultproxy.cpython-312-x86_64-linux-gnu.so b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/resultproxy.cpython-312-x86_64-linux-gnu.so index 3d51373..133fd55 100755 Binary files a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/resultproxy.cpython-312-x86_64-linux-gnu.so and b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/resultproxy.cpython-312-x86_64-linux-gnu.so differ diff 
--git a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/resultproxy.pyx b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/resultproxy.pyx index b6e357a..0d7eeec 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/resultproxy.pyx +++ b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/resultproxy.pyx @@ -1,9 +1,3 @@ -# cyextension/resultproxy.pyx -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php import operator cdef class BaseRow: diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/util.cpython-312-x86_64-linux-gnu.so b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/util.cpython-312-x86_64-linux-gnu.so index 3feefd9..838115c 100755 Binary files a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/util.cpython-312-x86_64-linux-gnu.so and b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/util.cpython-312-x86_64-linux-gnu.so differ diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/util.pyx b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/util.pyx index 68e4f9f..92e91a6 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/util.pyx +++ b/venv/lib/python3.12/site-packages/sqlalchemy/cyextension/util.pyx @@ -1,33 +1,31 @@ -# cyextension/util.pyx -# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php from collections.abc import Mapping from sqlalchemy import exc cdef tuple _Empty_Tuple = () -cdef inline bint _is_mapping_or_tuple(object value): +cdef inline bint _mapping_or_tuple(object value): return isinstance(value, dict) or isinstance(value, tuple) or isinstance(value, Mapping) - -cdef inline bint _is_mapping(object value): - return isinstance(value, dict) or isinstance(value, Mapping) - +cdef inline bint _check_item(object params) except 0: + cdef object item + cdef bint ret = 1 + if params: + item = params[0] + if not _mapping_or_tuple(item): + ret = 0 + raise exc.ArgumentError( + "List argument must consist only of tuples or dictionaries" + ) + return ret def _distill_params_20(object params): if params is None: return _Empty_Tuple elif isinstance(params, list) or isinstance(params, tuple): - if params and not _is_mapping(params[0]): - raise exc.ArgumentError( - "List argument must consist only of dictionaries" - ) + _check_item(params) return params - elif _is_mapping(params): + elif isinstance(params, dict) or isinstance(params, Mapping): return [params] else: raise exc.ArgumentError("mapping or list expected for parameters") @@ -37,12 +35,9 @@ def _distill_raw_params(object params): if params is None: return _Empty_Tuple elif isinstance(params, list): - if params and not _is_mapping_or_tuple(params[0]): - raise exc.ArgumentError( - "List argument must consist only of tuples or dictionaries" - ) + _check_item(params) return params - elif _is_mapping_or_tuple(params): + elif _mapping_or_tuple(params): return [params] else: raise exc.ArgumentError("mapping or sequence expected for parameters") diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/__init__.py index 30928a9..055d087 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/__init__.py +++ 
b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/__init__.py @@ -1,5 +1,5 @@ # dialects/__init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -7,7 +7,6 @@ from __future__ import annotations -from typing import Any from typing import Callable from typing import Optional from typing import Type @@ -40,7 +39,7 @@ def _auto_fn(name: str) -> Optional[Callable[[], Type[Dialect]]]: # hardcoded. if mysql / mariadb etc were third party dialects # they would just publish all the entrypoints, which would actually # look much nicer. - module: Any = __import__( + module = __import__( "sqlalchemy.dialects.mysql.mariadb" ).dialects.mysql.mariadb return module.loader(driver) # type: ignore diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/_typing.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/_typing.py index 4dd40d7..932742b 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/_typing.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/_typing.py @@ -1,9 +1,3 @@ -# dialects/_typing.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php from __future__ import annotations from typing import Any @@ -12,19 +6,14 @@ from typing import Mapping from typing import Optional from typing import Union -from ..sql import roles -from ..sql.base import ColumnCollection -from ..sql.schema import Column +from ..sql._typing import _DDLColumnArgument +from ..sql.elements import DQLDMLClauseElement from ..sql.schema import ColumnCollectionConstraint from ..sql.schema import Index _OnConflictConstraintT = Union[str, ColumnCollectionConstraint, Index, None] -_OnConflictIndexElementsT = Optional[ - Iterable[Union[Column[Any], str, roles.DDLConstraintColumnRole]] -] -_OnConflictIndexWhereT = Optional[roles.WhereHavingRole] -_OnConflictSetT = Optional[ - Union[Mapping[Any, Any], ColumnCollection[Any, Any]] -] -_OnConflictWhereT = Optional[roles.WhereHavingRole] +_OnConflictIndexElementsT = Optional[Iterable[_DDLColumnArgument]] +_OnConflictIndexWhereT = Optional[DQLDMLClauseElement] +_OnConflictSetT = Optional[Mapping[Any, Any]] +_OnConflictWhereT = Union[DQLDMLClauseElement, str, None] diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/__init__.py index 20140fd..6bbb934 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/__init__.py @@ -1,5 +1,5 @@ -# dialects/mssql/__init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mssql/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/aioodbc.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/aioodbc.py index 522ad1d..23c2790 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/aioodbc.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/aioodbc.py @@ -1,5 +1,5 @@ -# dialects/mssql/aioodbc.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mssql/aioodbc.py +# Copyright (C) 
2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -32,12 +32,13 @@ This dialect should normally be used only with the styles are otherwise equivalent to those documented in the pyodbc section:: from sqlalchemy.ext.asyncio import create_async_engine - engine = create_async_engine( "mssql+aioodbc://scott:tiger@mssql2017:1433/test?" "driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes" ) + + """ from __future__ import annotations diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/base.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/base.py index 368abaf..687de04 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/base.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/base.py @@ -1,5 +1,5 @@ -# dialects/mssql/base.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mssql/base.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -9,6 +9,7 @@ """ .. dialect:: mssql :name: Microsoft SQL Server + :full_support: 2017 :normal_support: 2012+ :best_effort: 2005+ @@ -39,12 +40,9 @@ considered to be the identity column - unless it is associated with a from sqlalchemy import Table, MetaData, Column, Integer m = MetaData() - t = Table( - "t", - m, - Column("id", Integer, primary_key=True), - Column("x", Integer), - ) + t = Table('t', m, + Column('id', Integer, primary_key=True), + Column('x', Integer)) m.create_all(engine) The above example will generate DDL as: @@ -62,12 +60,9 @@ specify ``False`` for the :paramref:`_schema.Column.autoincrement` flag, on the first integer primary key column:: m = MetaData() - t = Table( - "t", - m, - Column("id", Integer, primary_key=True, autoincrement=False), - Column("x", Integer), - ) + t = Table('t', m, + Column('id', Integer, primary_key=True, autoincrement=False), + Column('x', Integer)) m.create_all(engine) To add the ``IDENTITY`` keyword to a non-primary key column, specify @@ -77,12 +72,9 @@ To add the ``IDENTITY`` keyword to a non-primary key column, specify is set to ``False`` on any integer primary key column:: m = MetaData() - t = Table( - "t", - m, - Column("id", Integer, primary_key=True, autoincrement=False), - Column("x", Integer, autoincrement=True), - ) + t = Table('t', m, + Column('id', Integer, primary_key=True, autoincrement=False), + Column('x', Integer, autoincrement=True)) m.create_all(engine) .. versionchanged:: 1.4 Added :class:`_schema.Identity` construct @@ -145,12 +137,14 @@ parameters passed to the :class:`_schema.Identity` object:: from sqlalchemy import Table, Integer, Column, Identity test = Table( - "test", - metadata, + 'test', metadata, Column( - "id", Integer, primary_key=True, Identity(start=100, increment=10) + 'id', + Integer, + primary_key=True, + Identity(start=100, increment=10) ), - Column("name", String(20)), + Column('name', String(20)) ) The CREATE TABLE for the above :class:`_schema.Table` object would be: @@ -160,7 +154,7 @@ The CREATE TABLE for the above :class:`_schema.Table` object would be: CREATE TABLE test ( id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY, name VARCHAR(20) NULL, - ) + ) .. 
note:: @@ -193,7 +187,6 @@ type deployed to the SQL Server database can be specified as ``Numeric`` using Base = declarative_base() - class TestTable(Base): __tablename__ = "test" id = Column( @@ -219,9 +212,8 @@ integer values in Python 3), use :class:`_types.TypeDecorator` as follows:: from sqlalchemy import TypeDecorator - class NumericAsInteger(TypeDecorator): - "normalize floating point return values into ints" + '''normalize floating point return values into ints''' impl = Numeric(10, 0, asdecimal=False) cache_ok = True @@ -231,7 +223,6 @@ integer values in Python 3), use :class:`_types.TypeDecorator` as follows:: value = int(value) return value - class TestTable(Base): __tablename__ = "test" id = Column( @@ -280,11 +271,11 @@ The process for fetching this value has several variants: fetched in order to receive the value. Given a table as:: t = Table( - "t", + 't', metadata, - Column("id", Integer, primary_key=True), - Column("x", Integer), - implicit_returning=False, + Column('id', Integer, primary_key=True), + Column('x', Integer), + implicit_returning=False ) an INSERT will look like: @@ -310,13 +301,12 @@ statement proceeding, and ``SET IDENTITY_INSERT OFF`` subsequent to the execution. Given this example:: m = MetaData() - t = Table( - "t", m, Column("id", Integer, primary_key=True), Column("x", Integer) - ) + t = Table('t', m, Column('id', Integer, primary_key=True), + Column('x', Integer)) m.create_all(engine) with engine.begin() as conn: - conn.execute(t.insert(), {"id": 1, "x": 1}, {"id": 2, "x": 2}) + conn.execute(t.insert(), {'id': 1, 'x':1}, {'id':2, 'x':2}) The above column will be created with IDENTITY, however the INSERT statement we emit is specifying explicit values. In the echo output we can see @@ -352,11 +342,7 @@ The :class:`.Sequence` object creates "real" sequences, i.e., >>> from sqlalchemy import Sequence >>> from sqlalchemy.schema import CreateSequence >>> from sqlalchemy.dialects import mssql - >>> print( - ... CreateSequence(Sequence("my_seq", start=1)).compile( - ... dialect=mssql.dialect() - ... ) - ... ) + >>> print(CreateSequence(Sequence("my_seq", start=1)).compile(dialect=mssql.dialect())) {printsql}CREATE SEQUENCE my_seq START WITH 1 For integer primary key generation, SQL Server's ``IDENTITY`` construct should @@ -390,12 +376,12 @@ more than one backend without using dialect-specific types. To build a SQL Server VARCHAR or NVARCHAR with MAX length, use None:: my_table = Table( - "my_table", - metadata, - Column("my_data", VARCHAR(None)), - Column("my_n_data", NVARCHAR(None)), + 'my_table', metadata, + Column('my_data', VARCHAR(None)), + Column('my_n_data', NVARCHAR(None)) ) + Collation Support ----------------- @@ -403,13 +389,10 @@ Character collations are supported by the base string types, specified by the string argument "collation":: from sqlalchemy import VARCHAR - - Column("login", VARCHAR(32, collation="Latin1_General_CI_AS")) + Column('login', VARCHAR(32, collation='Latin1_General_CI_AS')) When such a column is associated with a :class:`_schema.Table`, the -CREATE TABLE statement for this column will yield: - -.. sourcecode:: sql +CREATE TABLE statement for this column will yield:: login VARCHAR(32) COLLATE Latin1_General_CI_AS NULL @@ -429,9 +412,7 @@ versions when no OFFSET clause is present. A statement such as:: select(some_table).limit(5) -will render similarly to: - -.. sourcecode:: sql +will render similarly to:: SELECT TOP 5 col1, col2.. 
FROM table @@ -441,9 +422,7 @@ LIMIT and OFFSET, or just OFFSET alone, will be rendered using the select(some_table).order_by(some_table.c.col3).limit(5).offset(10) -will render similarly to: - -.. sourcecode:: sql +will render similarly to:: SELECT anon_1.col1, anon_1.col2 FROM (SELECT col1, col2, ROW_NUMBER() OVER (ORDER BY col3) AS @@ -496,13 +475,16 @@ each new connection. To set isolation level using :func:`_sa.create_engine`:: engine = create_engine( - "mssql+pyodbc://scott:tiger@ms_2008", isolation_level="REPEATABLE READ" + "mssql+pyodbc://scott:tiger@ms_2008", + isolation_level="REPEATABLE READ" ) To set using per-connection execution options:: connection = engine.connect() - connection = connection.execution_options(isolation_level="READ COMMITTED") + connection = connection.execution_options( + isolation_level="READ COMMITTED" + ) Valid values for ``isolation_level`` include: @@ -552,6 +534,7 @@ will remain consistent with the state of the transaction:: mssql_engine = create_engine( "mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+17+for+SQL+Server", + # disable default reset-on-return scheme pool_reset_on_return=None, ) @@ -580,17 +563,13 @@ Nullability ----------- MSSQL has support for three levels of column nullability. The default nullability allows nulls and is explicit in the CREATE TABLE -construct: - -.. sourcecode:: sql +construct:: name VARCHAR(20) NULL If ``nullable=None`` is specified then no specification is made. In other words the database's configured default is used. This will -render: - -.. sourcecode:: sql +render:: name VARCHAR(20) @@ -646,9 +625,8 @@ behavior of this flag is as follows: * The flag can be set to either ``True`` or ``False`` when the dialect is created, typically via :func:`_sa.create_engine`:: - eng = create_engine( - "mssql+pymssql://user:pass@host/db", deprecate_large_types=True - ) + eng = create_engine("mssql+pymssql://user:pass@host/db", + deprecate_large_types=True) * Complete control over whether the "old" or "new" types are rendered is available in all SQLAlchemy versions by using the UPPERCASE type objects @@ -670,10 +648,9 @@ at once using the :paramref:`_schema.Table.schema` argument of :class:`_schema.Table`:: Table( - "some_table", - metadata, + "some_table", metadata, Column("q", String(50)), - schema="mydatabase.dbo", + schema="mydatabase.dbo" ) When performing operations such as table or component reflection, a schema @@ -685,10 +662,9 @@ components will be quoted separately for case sensitive names and other special characters. Given an argument as below:: Table( - "some_table", - metadata, + "some_table", metadata, Column("q", String(50)), - schema="MyDataBase.dbo", + schema="MyDataBase.dbo" ) The above schema would be rendered as ``[MyDataBase].dbo``, and also in @@ -701,22 +677,21 @@ Below, the "owner" will be considered as ``MyDataBase.dbo`` and the "database" will be None:: Table( - "some_table", - metadata, + "some_table", metadata, Column("q", String(50)), - schema="[MyDataBase.dbo]", + schema="[MyDataBase.dbo]" ) To individually specify both database and owner name with special characters or embedded dots, use two sets of brackets:: Table( - "some_table", - metadata, + "some_table", metadata, Column("q", String(50)), - schema="[MyDataBase.Period].[MyOwner.Dot]", + schema="[MyDataBase.Period].[MyOwner.Dot]" ) + .. 
versionchanged:: 1.2 the SQL Server dialect now treats brackets as identifier delimiters splitting the schema into separate database and owner tokens, to allow dots within either name itself. @@ -731,11 +706,10 @@ schema-qualified table would be auto-aliased when used in a SELECT statement; given a table:: account_table = Table( - "account", - metadata, - Column("id", Integer, primary_key=True), - Column("info", String(100)), - schema="customer_schema", + 'account', metadata, + Column('id', Integer, primary_key=True), + Column('info', String(100)), + schema="customer_schema" ) this legacy mode of rendering would assume that "customer_schema.account" @@ -778,55 +752,37 @@ which renders the index as ``CREATE CLUSTERED INDEX my_index ON table (x)``. To generate a clustered primary key use:: - Table( - "my_table", - metadata, - Column("x", ...), - Column("y", ...), - PrimaryKeyConstraint("x", "y", mssql_clustered=True), - ) + Table('my_table', metadata, + Column('x', ...), + Column('y', ...), + PrimaryKeyConstraint("x", "y", mssql_clustered=True)) -which will render the table, for example, as: +which will render the table, for example, as:: -.. sourcecode:: sql - - CREATE TABLE my_table ( - x INTEGER NOT NULL, - y INTEGER NOT NULL, - PRIMARY KEY CLUSTERED (x, y) - ) + CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL, + PRIMARY KEY CLUSTERED (x, y)) Similarly, we can generate a clustered unique constraint using:: - Table( - "my_table", - metadata, - Column("x", ...), - Column("y", ...), - PrimaryKeyConstraint("x"), - UniqueConstraint("y", mssql_clustered=True), - ) + Table('my_table', metadata, + Column('x', ...), + Column('y', ...), + PrimaryKeyConstraint("x"), + UniqueConstraint("y", mssql_clustered=True), + ) To explicitly request a non-clustered primary key (for example, when a separate clustered index is desired), use:: - Table( - "my_table", - metadata, - Column("x", ...), - Column("y", ...), - PrimaryKeyConstraint("x", "y", mssql_clustered=False), - ) + Table('my_table', metadata, + Column('x', ...), + Column('y', ...), + PrimaryKeyConstraint("x", "y", mssql_clustered=False)) -which will render the table, for example, as: +which will render the table, for example, as:: -.. sourcecode:: sql - - CREATE TABLE my_table ( - x INTEGER NOT NULL, - y INTEGER NOT NULL, - PRIMARY KEY NONCLUSTERED (x, y) - ) + CREATE TABLE my_table (x INTEGER NOT NULL, y INTEGER NOT NULL, + PRIMARY KEY NONCLUSTERED (x, y)) Columnstore Index Support ------------------------- @@ -864,7 +820,7 @@ INCLUDE The ``mssql_include`` option renders INCLUDE(colname) for the given string names:: - Index("my_index", table.c.x, mssql_include=["y"]) + Index("my_index", table.c.x, mssql_include=['y']) would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)`` @@ -919,19 +875,18 @@ To disable the usage of OUTPUT INSERTED on a per-table basis, specify ``implicit_returning=False`` for each :class:`_schema.Table` which has triggers:: - Table( - "mytable", - metadata, - Column("id", Integer, primary_key=True), + Table('mytable', metadata, + Column('id', Integer, primary_key=True), # ..., - implicit_returning=False, + implicit_returning=False ) Declarative form:: class MyClass(Base): # ... - __table_args__ = {"implicit_returning": False} + __table_args__ = {'implicit_returning':False} + .. _mssql_rowcount_versioning: @@ -965,9 +920,7 @@ isolation mode that locks entire tables, and causes even mildly concurrent applications to have long held locks and frequent deadlocks. 
Enabling snapshot isolation for the database as a whole is recommended for modern levels of concurrency support. This is accomplished via the -following ALTER DATABASE commands executed at the SQL prompt: - -.. sourcecode:: sql +following ALTER DATABASE commands executed at the SQL prompt:: ALTER DATABASE MyDatabase SET ALLOW_SNAPSHOT_ISOLATION ON @@ -1473,6 +1426,7 @@ class ROWVERSION(TIMESTAMP): class NTEXT(sqltypes.UnicodeText): + """MSSQL NTEXT type, for variable-length unicode text up to 2^30 characters.""" @@ -1597,11 +1551,44 @@ class MSUUid(sqltypes.Uuid): def process(value): return f"""'{ - value.replace("-", "").replace("'", "''") - }'""" + value.replace("-", "").replace("'", "''") + }'""" return process + def _sentinel_value_resolver(self, dialect): + """Return a callable that will receive the uuid object or string + as it is normally passed to the DB in the parameter set, after + bind_processor() is called. Convert this value to match + what it would be as coming back from an INSERT..OUTPUT inserted. + + for the UUID type, there are four varieties of settings so here + we seek to convert to the string or UUID representation that comes + back from the driver. + + """ + character_based_uuid = ( + not dialect.supports_native_uuid or not self.native_uuid + ) + + if character_based_uuid: + if self.native_uuid: + # for pyodbc, uuid.uuid() objects are accepted for incoming + # data, as well as strings. but the driver will always return + # uppercase strings in result sets. + def process(value): + return str(value).upper() + + else: + + def process(value): + return str(value) + + return process + else: + # for pymssql, we get uuid.uuid() objects back. + return None + class UNIQUEIDENTIFIER(sqltypes.Uuid[sqltypes._UUID_RETURN]): __visit_name__ = "UNIQUEIDENTIFIER" @@ -1609,12 +1596,12 @@ class UNIQUEIDENTIFIER(sqltypes.Uuid[sqltypes._UUID_RETURN]): @overload def __init__( self: UNIQUEIDENTIFIER[_python_UUID], as_uuid: Literal[True] = ... - ): ... + ): + ... @overload - def __init__( - self: UNIQUEIDENTIFIER[str], as_uuid: Literal[False] = ... - ): ... + def __init__(self: UNIQUEIDENTIFIER[str], as_uuid: Literal[False] = ...): + ... def __init__(self, as_uuid: bool = True): """Construct a :class:`_mssql.UNIQUEIDENTIFIER` type. 
@@ -1865,6 +1852,7 @@ class MSExecutionContext(default.DefaultExecutionContext): _enable_identity_insert = False _select_lastrowid = False _lastrowid = None + _rowcount = None dialect: MSDialect @@ -1984,6 +1972,13 @@ class MSExecutionContext(default.DefaultExecutionContext): def get_lastrowid(self): return self._lastrowid + @property + def rowcount(self): + if self._rowcount is not None: + return self._rowcount + else: + return self.cursor.rowcount + def handle_dbapi_exception(self, e): if self._enable_identity_insert: try: @@ -2035,10 +2030,6 @@ class MSSQLCompiler(compiler.SQLCompiler): self.tablealiases = {} super().__init__(*args, **kwargs) - def _format_frame_clause(self, range_, **kw): - kw["literal_execute"] = True - return super()._format_frame_clause(range_, **kw) - def _with_legacy_schema_aliasing(fn): def decorate(self, *arg, **kw): if self.dialect.legacy_schema_aliasing: @@ -2492,12 +2483,10 @@ class MSSQLCompiler(compiler.SQLCompiler): type_expression = "ELSE CAST(JSON_VALUE(%s, %s) AS %s)" % ( self.process(binary.left, **kw), self.process(binary.right, **kw), - ( - "FLOAT" - if isinstance(binary.type, sqltypes.Float) - else "NUMERIC(%s, %s)" - % (binary.type.precision, binary.type.scale) - ), + "FLOAT" + if isinstance(binary.type, sqltypes.Float) + else "NUMERIC(%s, %s)" + % (binary.type.precision, binary.type.scale), ) elif binary.type._type_affinity is sqltypes.Boolean: # the NULL handling is particularly weird with boolean, so @@ -2533,6 +2522,7 @@ class MSSQLCompiler(compiler.SQLCompiler): class MSSQLStrictCompiler(MSSQLCompiler): + """A subclass of MSSQLCompiler which disables the usage of bind parameters where not allowed natively by MS-SQL. @@ -3632,36 +3622,27 @@ where @reflection.cache @_db_plus_owner def get_columns(self, connection, tablename, dbname, owner, schema, **kw): - sys_columns = ischema.sys_columns - sys_types = ischema.sys_types - sys_default_constraints = ischema.sys_default_constraints - computed_cols = ischema.computed_columns - identity_cols = ischema.identity_columns - extended_properties = ischema.extended_properties - - # to access sys tables, need an object_id. - # object_id() can normally match to the unquoted name even if it - # has special characters. however it also accepts quoted names, - # which means for the special case that the name itself has - # "quotes" (e.g. brackets for SQL Server) we need to "quote" (e.g. - # bracket) that name anyway. Fixed as part of #12654 - is_temp_table = tablename.startswith("#") if is_temp_table: owner, tablename = self._get_internal_temp_table_name( connection, tablename ) - object_id_tokens = [self.identifier_preparer.quote(tablename)] + columns = ischema.mssql_temp_table_columns + else: + columns = ischema.columns + + computed_cols = ischema.computed_columns + identity_cols = ischema.identity_columns if owner: - object_id_tokens.insert(0, self.identifier_preparer.quote(owner)) - - if is_temp_table: - object_id_tokens.insert(0, "tempdb") - - object_id = func.object_id(".".join(object_id_tokens)) - - whereclause = sys_columns.c.object_id == object_id + whereclause = sql.and_( + columns.c.table_name == tablename, + columns.c.table_schema == owner, + ) + full_name = columns.c.table_schema + "." 
+ columns.c.table_name + else: + whereclause = columns.c.table_name == tablename + full_name = columns.c.table_name if self._supports_nvarchar_max: computed_definition = computed_cols.c.definition @@ -3671,112 +3652,92 @@ where computed_cols.c.definition, NVARCHAR(4000) ) + object_id = func.object_id(full_name) + s = ( sql.select( - sys_columns.c.name, - sys_types.c.name, - sys_columns.c.is_nullable, - sys_columns.c.max_length, - sys_columns.c.precision, - sys_columns.c.scale, - sys_default_constraints.c.definition, - sys_columns.c.collation_name, + columns.c.column_name, + columns.c.data_type, + columns.c.is_nullable, + columns.c.character_maximum_length, + columns.c.numeric_precision, + columns.c.numeric_scale, + columns.c.column_default, + columns.c.collation_name, computed_definition, computed_cols.c.is_persisted, identity_cols.c.is_identity, identity_cols.c.seed_value, identity_cols.c.increment_value, - extended_properties.c.value.label("comment"), - ) - .select_from(sys_columns) - .join( - sys_types, - onclause=sys_columns.c.user_type_id - == sys_types.c.user_type_id, - ) - .outerjoin( - sys_default_constraints, - sql.and_( - sys_default_constraints.c.object_id - == sys_columns.c.default_object_id, - sys_default_constraints.c.parent_column_id - == sys_columns.c.column_id, - ), + ischema.extended_properties.c.value.label("comment"), ) + .select_from(columns) .outerjoin( computed_cols, onclause=sql.and_( - computed_cols.c.object_id == sys_columns.c.object_id, - computed_cols.c.column_id == sys_columns.c.column_id, + computed_cols.c.object_id == object_id, + computed_cols.c.name + == columns.c.column_name.collate("DATABASE_DEFAULT"), ), ) .outerjoin( identity_cols, onclause=sql.and_( - identity_cols.c.object_id == sys_columns.c.object_id, - identity_cols.c.column_id == sys_columns.c.column_id, + identity_cols.c.object_id == object_id, + identity_cols.c.name + == columns.c.column_name.collate("DATABASE_DEFAULT"), ), ) .outerjoin( - extended_properties, + ischema.extended_properties, onclause=sql.and_( - extended_properties.c["class"] == 1, - extended_properties.c.name == "MS_Description", - sys_columns.c.object_id == extended_properties.c.major_id, - sys_columns.c.column_id == extended_properties.c.minor_id, + ischema.extended_properties.c["class"] == 1, + ischema.extended_properties.c.major_id == object_id, + ischema.extended_properties.c.minor_id + == columns.c.ordinal_position, + ischema.extended_properties.c.name == "MS_Description", ), ) .where(whereclause) - .order_by(sys_columns.c.column_id) + .order_by(columns.c.ordinal_position) ) - if is_temp_table: - exec_opts = {"schema_translate_map": {"sys": "tempdb.sys"}} - else: - exec_opts = {"schema_translate_map": {}} - c = connection.execution_options(**exec_opts).execute(s) + c = connection.execution_options(future_result=True).execute(s) cols = [] for row in c.mappings(): - name = row[sys_columns.c.name] - type_ = row[sys_types.c.name] - nullable = row[sys_columns.c.is_nullable] == 1 - maxlen = row[sys_columns.c.max_length] - numericprec = row[sys_columns.c.precision] - numericscale = row[sys_columns.c.scale] - default = row[sys_default_constraints.c.definition] - collation = row[sys_columns.c.collation_name] + name = row[columns.c.column_name] + type_ = row[columns.c.data_type] + nullable = row[columns.c.is_nullable] == "YES" + charlen = row[columns.c.character_maximum_length] + numericprec = row[columns.c.numeric_precision] + numericscale = row[columns.c.numeric_scale] + default = row[columns.c.column_default] + collation = 
row[columns.c.collation_name] definition = row[computed_definition] is_persisted = row[computed_cols.c.is_persisted] is_identity = row[identity_cols.c.is_identity] identity_start = row[identity_cols.c.seed_value] identity_increment = row[identity_cols.c.increment_value] - comment = row[extended_properties.c.value] + comment = row[ischema.extended_properties.c.value] coltype = self.ischema_names.get(type_, None) kwargs = {} - if coltype in ( + MSString, + MSChar, + MSNVarchar, + MSNChar, + MSText, + MSNText, MSBinary, MSVarBinary, sqltypes.LargeBinary, ): - kwargs["length"] = maxlen if maxlen != -1 else None - elif coltype in ( - MSString, - MSChar, - MSText, - ): - kwargs["length"] = maxlen if maxlen != -1 else None - if collation: - kwargs["collation"] = collation - elif coltype in ( - MSNVarchar, - MSNChar, - MSNText, - ): - kwargs["length"] = maxlen // 2 if maxlen != -1 else None + if charlen == -1: + charlen = None + kwargs["length"] = charlen if collation: kwargs["collation"] = collation @@ -4020,8 +3981,10 @@ index_info AS ( ) # group rows by constraint ID, to handle multi-column FKs - fkeys = util.defaultdict( - lambda: { + fkeys = [] + + def fkey_rec(): + return { "name": None, "constrained_columns": [], "referred_schema": None, @@ -4029,7 +3992,8 @@ index_info AS ( "referred_columns": [], "options": {}, } - ) + + fkeys = util.defaultdict(fkey_rec) for r in connection.execute(s).all(): ( diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/information_schema.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/information_schema.py index 5a68e3a..e770313 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/information_schema.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/information_schema.py @@ -1,5 +1,5 @@ -# dialects/mssql/information_schema.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mssql/information_schema.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -88,41 +88,23 @@ columns = Table( schema="INFORMATION_SCHEMA", ) -sys_columns = Table( - "columns", +mssql_temp_table_columns = Table( + "COLUMNS", ischema, - Column("object_id", Integer), - Column("name", CoerceUnicode), - Column("column_id", Integer), - Column("default_object_id", Integer), - Column("user_type_id", Integer), - Column("is_nullable", Integer), - Column("ordinal_position", Integer), - Column("max_length", Integer), - Column("precision", Integer), - Column("scale", Integer), - Column("collation_name", String), - schema="sys", -) - -sys_types = Table( - "types", - ischema, - Column("name", CoerceUnicode, key="name"), - Column("system_type_id", Integer, key="system_type_id"), - Column("user_type_id", Integer, key="user_type_id"), - Column("schema_id", Integer, key="schema_id"), - Column("max_length", Integer, key="max_length"), - Column("precision", Integer, key="precision"), - Column("scale", Integer, key="scale"), - Column("collation_name", CoerceUnicode, key="collation_name"), - Column("is_nullable", Boolean, key="is_nullable"), - Column("is_user_defined", Boolean, key="is_user_defined"), - Column("is_assembly_type", Boolean, key="is_assembly_type"), - Column("default_object_id", Integer, key="default_object_id"), - Column("rule_object_id", Integer, key="rule_object_id"), - Column("is_table_type", Boolean, key="is_table_type"), - schema="sys", + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + 
Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column( + "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length" + ), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="tempdb.INFORMATION_SCHEMA", ) constraints = Table( @@ -135,17 +117,6 @@ constraints = Table( schema="INFORMATION_SCHEMA", ) -sys_default_constraints = Table( - "default_constraints", - ischema, - Column("object_id", Integer), - Column("name", CoerceUnicode), - Column("schema_id", Integer), - Column("parent_column_id", Integer), - Column("definition", CoerceUnicode), - schema="sys", -) - column_constraints = Table( "CONSTRAINT_COLUMN_USAGE", ischema, @@ -211,7 +182,6 @@ computed_columns = Table( ischema, Column("object_id", Integer), Column("name", CoerceUnicode), - Column("column_id", Integer), Column("is_computed", Boolean), Column("is_persisted", Boolean), Column("definition", CoerceUnicode), @@ -237,7 +207,6 @@ class NumericSqlVariant(TypeDecorator): int 1 is returned as "\x01\x00\x00\x00". On python 3 it returns the correct value as string. """ - impl = Unicode cache_ok = True @@ -250,7 +219,6 @@ identity_columns = Table( ischema, Column("object_id", Integer), Column("name", CoerceUnicode), - Column("column_id", Integer), Column("is_identity", Boolean), Column("seed_value", NumericSqlVariant), Column("increment_value", NumericSqlVariant), diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/json.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/json.py index a2d3ce8..815b5d2 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/json.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/json.py @@ -1,9 +1,3 @@ -# dialects/mssql/json.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors from ... 
import types as sqltypes @@ -54,7 +48,9 @@ class JSON(sqltypes.JSON): dictionary or list, the :meth:`_types.JSON.Comparator.as_json` accessor should be used:: - stmt = select(data_table.c.data["some key"].as_json()).where( + stmt = select( + data_table.c.data["some key"].as_json() + ).where( data_table.c.data["some key"].as_json() == {"sub": "structure"} ) @@ -65,7 +61,9 @@ class JSON(sqltypes.JSON): :meth:`_types.JSON.Comparator.as_integer`, :meth:`_types.JSON.Comparator.as_float`:: - stmt = select(data_table.c.data["some key"].as_string()).where( + stmt = select( + data_table.c.data["some key"].as_string() + ).where( data_table.c.data["some key"].as_string() == "some string" ) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/provision.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/provision.py index 1016585..096ae03 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/provision.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/provision.py @@ -1,9 +1,3 @@ -# dialects/mssql/provision.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors from sqlalchemy import inspect @@ -22,17 +16,10 @@ from ...testing.provision import generate_driver_url from ...testing.provision import get_temp_table_name from ...testing.provision import log from ...testing.provision import normalize_sequence -from ...testing.provision import post_configure_engine from ...testing.provision import run_reap_dbs from ...testing.provision import temp_table_keyword_args -@post_configure_engine.for_db("mssql") -def post_configure_engine(url, engine, follower_ident): - if engine.driver == "pyodbc": - engine.dialect.dbapi.pooling = False - - @generate_driver_url.for_db("mssql") def generate_driver_url(url, driver, query_str): backend = url.get_backend_name() @@ -42,9 +29,6 @@ def generate_driver_url(url, driver, query_str): if driver not in ("pyodbc", "aioodbc"): new_url = new_url.set(query="") - if driver == "aioodbc": - new_url = new_url.update_query_dict({"MARS_Connection": "Yes"}) - if query_str: new_url = new_url.update_query_string(query_str) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/pymssql.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/pymssql.py index 301a98e..3823db9 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/pymssql.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/pymssql.py @@ -1,5 +1,5 @@ -# dialects/mssql/pymssql.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mssql/pymssql.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -103,7 +103,6 @@ class MSDialect_pymssql(MSDialect): "message 20006", # Write to the server failed "message 20017", # Unexpected EOF from the server "message 20047", # DBPROCESS is dead or not enabled - "The server failed to resume the transaction", ): if msg in str(e): return True diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/pyodbc.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/pyodbc.py index cbf0adb..a8f12fd 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/pyodbc.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mssql/pyodbc.py @@ -1,5 +1,5 @@ -# 
dialects/mssql/pyodbc.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mssql/pyodbc.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -30,9 +30,7 @@ is configured on the client, a basic DSN-based connection looks like:: engine = create_engine("mssql+pyodbc://scott:tiger@some_dsn") -Which above, will pass the following connection string to PyODBC: - -.. sourcecode:: text +Which above, will pass the following connection string to PyODBC:: DSN=some_dsn;UID=scott;PWD=tiger @@ -51,9 +49,7 @@ When using a hostname connection, the driver name must also be specified in the query parameters of the URL. As these names usually have spaces in them, the name must be URL encoded which means using plus signs for spaces:: - engine = create_engine( - "mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=ODBC+Driver+17+for+SQL+Server" - ) + engine = create_engine("mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=ODBC+Driver+17+for+SQL+Server") The ``driver`` keyword is significant to the pyodbc dialect and must be specified in lowercase. @@ -73,7 +69,6 @@ internally:: The equivalent URL can be constructed using :class:`_sa.engine.URL`:: from sqlalchemy.engine import URL - connection_url = URL.create( "mssql+pyodbc", username="scott", @@ -88,6 +83,7 @@ The equivalent URL can be constructed using :class:`_sa.engine.URL`:: }, ) + Pass through exact Pyodbc string ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -98,11 +94,8 @@ using the parameter ``odbc_connect``. A :class:`_sa.engine.URL` object can help make this easier:: from sqlalchemy.engine import URL - connection_string = "DRIVER={SQL Server Native Client 10.0};SERVER=dagger;DATABASE=test;UID=user;PWD=password" - connection_url = URL.create( - "mssql+pyodbc", query={"odbc_connect": connection_string} - ) + connection_url = URL.create("mssql+pyodbc", query={"odbc_connect": connection_string}) engine = create_engine(connection_url) @@ -134,8 +127,7 @@ database using Azure credentials:: from sqlalchemy.engine.url import URL from azure import identity - # Connection option for access tokens, as defined in msodbcsql.h - SQL_COPT_SS_ACCESS_TOKEN = 1256 + SQL_COPT_SS_ACCESS_TOKEN = 1256 # Connection option for access tokens, as defined in msodbcsql.h TOKEN_URL = "https://database.windows.net/" # The token URL for any Azure SQL database connection_string = "mssql+pyodbc://@my-server.database.windows.net/myDb?driver=ODBC+Driver+17+for+SQL+Server" @@ -144,19 +136,14 @@ database using Azure credentials:: azure_credentials = identity.DefaultAzureCredential() - @event.listens_for(engine, "do_connect") def provide_token(dialect, conn_rec, cargs, cparams): # remove the "Trusted_Connection" parameter that SQLAlchemy adds cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "") # create token credential - raw_token = azure_credentials.get_token(TOKEN_URL).token.encode( - "utf-16-le" - ) - token_struct = struct.pack( - f" 7 into strings. 
The routines here are needed for older pyodbc versions diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/__init__.py index 9174c54..b6af683 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/__init__.py @@ -1,5 +1,5 @@ -# dialects/mysql/__init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mysql/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -53,8 +53,7 @@ from .base import YEAR from .dml import Insert from .dml import insert from .expression import match -from .mariadb import INET4 -from .mariadb import INET6 +from ...util import compat # default dialect base.dialect = dialect = mysqldb.dialect @@ -72,8 +71,6 @@ __all__ = ( "DOUBLE", "ENUM", "FLOAT", - "INET4", - "INET6", "INTEGER", "INTEGER", "JSON", diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/aiomysql.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/aiomysql.py index af1ac2f..2a0c6ba 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/aiomysql.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/aiomysql.py @@ -1,9 +1,10 @@ -# dialects/mysql/aiomysql.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors r""" .. dialect:: mysql+aiomysql @@ -22,108 +23,207 @@ This dialect should normally be used only with the :func:`_asyncio.create_async_engine` engine creation function:: from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4") - engine = create_async_engine( - "mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4" - ) """ # noqa -from __future__ import annotations - -from types import ModuleType -from typing import Any -from typing import Dict -from typing import Optional -from typing import Tuple -from typing import TYPE_CHECKING -from typing import Union - from .pymysql import MySQLDialect_pymysql from ... import pool from ... 
import util -from ...connectors.asyncio import AsyncAdapt_dbapi_connection -from ...connectors.asyncio import AsyncAdapt_dbapi_cursor -from ...connectors.asyncio import AsyncAdapt_dbapi_module -from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor +from ...engine import AdaptedConnection +from ...util.concurrency import asyncio from ...util.concurrency import await_fallback from ...util.concurrency import await_only -if TYPE_CHECKING: - from ...connectors.asyncio import AsyncIODBAPIConnection - from ...connectors.asyncio import AsyncIODBAPICursor - from ...engine.interfaces import ConnectArgsType - from ...engine.interfaces import DBAPIConnection - from ...engine.interfaces import DBAPICursor - from ...engine.interfaces import DBAPIModule - from ...engine.interfaces import PoolProxiedConnection - from ...engine.url import URL +class AsyncAdapt_aiomysql_cursor: + # TODO: base on connectors/asyncio.py + # see #10415 + server_side = False + __slots__ = ( + "_adapt_connection", + "_connection", + "await_", + "_cursor", + "_rows", + ) + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ -class AsyncAdapt_aiomysql_cursor(AsyncAdapt_dbapi_cursor): - __slots__ = () + cursor = self._connection.cursor(adapt_connection.dbapi.Cursor) - def _make_new_cursor( - self, connection: AsyncIODBAPIConnection - ) -> AsyncIODBAPICursor: - return connection.cursor(self._adapt_connection.dbapi.Cursor) + # see https://github.com/aio-libs/aiomysql/issues/543 + self._cursor = self.await_(cursor.__aenter__()) + self._rows = [] + @property + def description(self): + return self._cursor.description -class AsyncAdapt_aiomysql_ss_cursor( - AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_aiomysql_cursor -): - __slots__ = () + @property + def rowcount(self): + return self._cursor.rowcount - def _make_new_cursor( - self, connection: AsyncIODBAPIConnection - ) -> AsyncIODBAPICursor: - return connection.cursor( - self._adapt_connection.dbapi.aiomysql.cursors.SSCursor + @property + def arraysize(self): + return self._cursor.arraysize + + @arraysize.setter + def arraysize(self, value): + self._cursor.arraysize = value + + @property + def lastrowid(self): + return self._cursor.lastrowid + + def close(self): + # note we aren't actually closing the cursor here, + # we are just letting GC do it. to allow this to be async + # we would need the Result to change how it does "Safe close cursor". + # MySQL "cursors" don't actually have state to be "closed" besides + # exhausting rows, which we already have done for sync cursor. + # another option would be to emulate aiosqlite dialect and assign + # cursor only if we are doing server side cursor operation. + self._rows[:] = [] + + def execute(self, operation, parameters=None): + return self.await_(self._execute_async(operation, parameters)) + + def executemany(self, operation, seq_of_parameters): + return self.await_( + self._executemany_async(operation, seq_of_parameters) ) + async def _execute_async(self, operation, parameters): + async with self._adapt_connection._execute_mutex: + result = await self._cursor.execute(operation, parameters) -class AsyncAdapt_aiomysql_connection(AsyncAdapt_dbapi_connection): + if not self.server_side: + # aiomysql has a "fake" async result, so we have to pull it out + # of that here since our default result is not async. + # we could just as easily grab "_rows" here and be done with it + # but this is safer. 
+ self._rows = list(await self._cursor.fetchall()) + return result + + async def _executemany_async(self, operation, seq_of_parameters): + async with self._adapt_connection._execute_mutex: + return await self._cursor.executemany(operation, seq_of_parameters) + + def setinputsizes(self, *inputsizes): + pass + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor): + # TODO: base on connectors/asyncio.py + # see #10415 __slots__ = () + server_side = True - _cursor_cls = AsyncAdapt_aiomysql_cursor - _ss_cursor_cls = AsyncAdapt_aiomysql_ss_cursor + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ - def ping(self, reconnect: bool) -> None: - assert not reconnect - self.await_(self._connection.ping(reconnect)) + cursor = self._connection.cursor(adapt_connection.dbapi.SSCursor) - def character_set_name(self) -> Optional[str]: - return self._connection.character_set_name() # type: ignore[no-any-return] # noqa: E501 + self._cursor = self.await_(cursor.__aenter__()) - def autocommit(self, value: Any) -> None: + def close(self): + if self._cursor is not None: + self.await_(self._cursor.close()) + self._cursor = None + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=None): + return self.await_(self._cursor.fetchmany(size=size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + +class AsyncAdapt_aiomysql_connection(AdaptedConnection): + # TODO: base on connectors/asyncio.py + # see #10415 + await_ = staticmethod(await_only) + __slots__ = ("dbapi", "_execute_mutex") + + def __init__(self, dbapi, connection): + self.dbapi = dbapi + self._connection = connection + self._execute_mutex = asyncio.Lock() + + def ping(self, reconnect): + return self.await_(self._connection.ping(reconnect)) + + def character_set_name(self): + return self._connection.character_set_name() + + def autocommit(self, value): self.await_(self._connection.autocommit(value)) - def get_autocommit(self) -> bool: - return self._connection.get_autocommit() # type: ignore + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_aiomysql_ss_cursor(self) + else: + return AsyncAdapt_aiomysql_cursor(self) - def terminate(self) -> None: + def rollback(self): + self.await_(self._connection.rollback()) + + def commit(self): + self.await_(self._connection.commit()) + + def close(self): # it's not awaitable. 
self._connection.close() - def close(self) -> None: - self.await_(self._connection.ensure_closed()) - class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection): + # TODO: base on connectors/asyncio.py + # see #10415 __slots__ = () await_ = staticmethod(await_fallback) -class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module): - def __init__(self, aiomysql: ModuleType, pymysql: ModuleType): +class AsyncAdapt_aiomysql_dbapi: + def __init__(self, aiomysql, pymysql): self.aiomysql = aiomysql self.pymysql = pymysql self.paramstyle = "format" self._init_dbapi_attributes() self.Cursor, self.SSCursor = self._init_cursors_subclasses() - def _init_dbapi_attributes(self) -> None: + def _init_dbapi_attributes(self): for name in ( "Warning", "Error", @@ -149,7 +249,7 @@ class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module): ): setattr(self, name, getattr(self.pymysql, name)) - def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiomysql_connection: + def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect) @@ -164,23 +264,17 @@ class AsyncAdapt_aiomysql_dbapi(AsyncAdapt_dbapi_module): await_only(creator_fn(*arg, **kw)), ) - def _init_cursors_subclasses( - self, - ) -> Tuple[AsyncIODBAPICursor, AsyncIODBAPICursor]: + def _init_cursors_subclasses(self): # suppress unconditional warning emitted by aiomysql - class Cursor(self.aiomysql.Cursor): # type: ignore[misc, name-defined] - async def _show_warnings( - self, conn: AsyncIODBAPIConnection - ) -> None: + class Cursor(self.aiomysql.Cursor): + async def _show_warnings(self, conn): pass - class SSCursor(self.aiomysql.SSCursor): # type: ignore[misc, name-defined] # noqa: E501 - async def _show_warnings( - self, conn: AsyncIODBAPIConnection - ) -> None: + class SSCursor(self.aiomysql.SSCursor): + async def _show_warnings(self, conn): pass - return Cursor, SSCursor # type: ignore[return-value] + return Cursor, SSCursor class MySQLDialect_aiomysql(MySQLDialect_pymysql): @@ -191,16 +285,15 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql): _sscursor = AsyncAdapt_aiomysql_ss_cursor is_async = True - has_terminate = True @classmethod - def import_dbapi(cls) -> AsyncAdapt_aiomysql_dbapi: + def import_dbapi(cls): return AsyncAdapt_aiomysql_dbapi( __import__("aiomysql"), __import__("pymysql") ) @classmethod - def get_pool_class(cls, url: URL) -> type: + def get_pool_class(cls, url): async_fallback = url.query.get("async_fallback", False) if util.asbool(async_fallback): @@ -208,37 +301,25 @@ class MySQLDialect_aiomysql(MySQLDialect_pymysql): else: return pool.AsyncAdaptedQueuePool - def do_terminate(self, dbapi_connection: DBAPIConnection) -> None: - dbapi_connection.terminate() - - def create_connect_args( - self, url: URL, _translate_args: Optional[Dict[str, Any]] = None - ) -> ConnectArgsType: + def create_connect_args(self, url): return super().create_connect_args( url, _translate_args=dict(username="user", database="db") ) - def is_disconnect( - self, - e: DBAPIModule.Error, - connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], - cursor: Optional[DBAPICursor], - ) -> bool: + def is_disconnect(self, e, connection, cursor): if super().is_disconnect(e, connection, cursor): return True else: str_e = str(e).lower() return "not connected" in str_e - def _found_rows_client_flag(self) -> int: - from pymysql.constants import CLIENT # type: ignore + def _found_rows_client_flag(self): + from pymysql.constants import CLIENT - return 
CLIENT.FOUND_ROWS # type: ignore[no-any-return] + return CLIENT.FOUND_ROWS - def get_driver_connection( - self, connection: DBAPIConnection - ) -> AsyncIODBAPIConnection: - return connection._connection # type: ignore[no-any-return] + def get_driver_connection(self, connection): + return connection._connection dialect = MySQLDialect_aiomysql diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/asyncmy.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/asyncmy.py index 61157fa..92058d6 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/asyncmy.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/asyncmy.py @@ -1,9 +1,10 @@ -# dialects/mysql/asyncmy.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors r""" .. dialect:: mysql+asyncmy @@ -20,100 +21,210 @@ This dialect should normally be used only with the :func:`_asyncio.create_async_engine` engine creation function:: from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4") - engine = create_async_engine( - "mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4" - ) """ # noqa -from __future__ import annotations - -from types import ModuleType -from typing import Any -from typing import NoReturn -from typing import Optional -from typing import TYPE_CHECKING -from typing import Union +from contextlib import asynccontextmanager from .pymysql import MySQLDialect_pymysql from ... import pool from ... import util -from ...connectors.asyncio import AsyncAdapt_dbapi_connection -from ...connectors.asyncio import AsyncAdapt_dbapi_cursor -from ...connectors.asyncio import AsyncAdapt_dbapi_module -from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor +from ...engine import AdaptedConnection +from ...util.concurrency import asyncio from ...util.concurrency import await_fallback from ...util.concurrency import await_only -if TYPE_CHECKING: - from ...connectors.asyncio import AsyncIODBAPIConnection - from ...connectors.asyncio import AsyncIODBAPICursor - from ...engine.interfaces import ConnectArgsType - from ...engine.interfaces import DBAPIConnection - from ...engine.interfaces import DBAPICursor - from ...engine.interfaces import DBAPIModule - from ...engine.interfaces import PoolProxiedConnection - from ...engine.url import URL +class AsyncAdapt_asyncmy_cursor: + # TODO: base on connectors/asyncio.py + # see #10415 + server_side = False + __slots__ = ( + "_adapt_connection", + "_connection", + "await_", + "_cursor", + "_rows", + ) -class AsyncAdapt_asyncmy_cursor(AsyncAdapt_dbapi_cursor): - __slots__ = () + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ + cursor = self._connection.cursor() -class AsyncAdapt_asyncmy_ss_cursor( - AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_asyncmy_cursor -): - __slots__ = () + self._cursor = self.await_(cursor.__aenter__()) + self._rows = [] - def _make_new_cursor( - self, connection: AsyncIODBAPIConnection - ) -> AsyncIODBAPICursor: - return connection.cursor( - self._adapt_connection.dbapi.asyncmy.cursors.SSCursor + @property + def description(self): + return self._cursor.description + + @property + def rowcount(self): + return self._cursor.rowcount + + @property 
+ def arraysize(self): + return self._cursor.arraysize + + @arraysize.setter + def arraysize(self, value): + self._cursor.arraysize = value + + @property + def lastrowid(self): + return self._cursor.lastrowid + + def close(self): + # note we aren't actually closing the cursor here, + # we are just letting GC do it. to allow this to be async + # we would need the Result to change how it does "Safe close cursor". + # MySQL "cursors" don't actually have state to be "closed" besides + # exhausting rows, which we already have done for sync cursor. + # another option would be to emulate aiosqlite dialect and assign + # cursor only if we are doing server side cursor operation. + self._rows[:] = [] + + def execute(self, operation, parameters=None): + return self.await_(self._execute_async(operation, parameters)) + + def executemany(self, operation, seq_of_parameters): + return self.await_( + self._executemany_async(operation, seq_of_parameters) ) + async def _execute_async(self, operation, parameters): + async with self._adapt_connection._mutex_and_adapt_errors(): + if parameters is None: + result = await self._cursor.execute(operation) + else: + result = await self._cursor.execute(operation, parameters) -class AsyncAdapt_asyncmy_connection(AsyncAdapt_dbapi_connection): + if not self.server_side: + # asyncmy has a "fake" async result, so we have to pull it out + # of that here since our default result is not async. + # we could just as easily grab "_rows" here and be done with it + # but this is safer. + self._rows = list(await self._cursor.fetchall()) + return result + + async def _executemany_async(self, operation, seq_of_parameters): + async with self._adapt_connection._mutex_and_adapt_errors(): + return await self._cursor.executemany(operation, seq_of_parameters) + + def setinputsizes(self, *inputsizes): + pass + + def __iter__(self): + while self._rows: + yield self._rows.pop(0) + + def fetchone(self): + if self._rows: + return self._rows.pop(0) + else: + return None + + def fetchmany(self, size=None): + if size is None: + size = self.arraysize + + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval + + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] + return retval + + +class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor): + # TODO: base on connectors/asyncio.py + # see #10415 __slots__ = () + server_side = True - _cursor_cls = AsyncAdapt_asyncmy_cursor - _ss_cursor_cls = AsyncAdapt_asyncmy_ss_cursor + def __init__(self, adapt_connection): + self._adapt_connection = adapt_connection + self._connection = adapt_connection._connection + self.await_ = adapt_connection.await_ - def _handle_exception(self, error: Exception) -> NoReturn: - if isinstance(error, AttributeError): - raise self.dbapi.InternalError( - "network operation failed due to asyncmy attribute error" - ) + cursor = self._connection.cursor( + adapt_connection.dbapi.asyncmy.cursors.SSCursor + ) - raise error + self._cursor = self.await_(cursor.__aenter__()) - def ping(self, reconnect: bool) -> None: + def close(self): + if self._cursor is not None: + self.await_(self._cursor.close()) + self._cursor = None + + def fetchone(self): + return self.await_(self._cursor.fetchone()) + + def fetchmany(self, size=None): + return self.await_(self._cursor.fetchmany(size=size)) + + def fetchall(self): + return self.await_(self._cursor.fetchall()) + + +class AsyncAdapt_asyncmy_connection(AdaptedConnection): + # TODO: base on connectors/asyncio.py + # see #10415 + await_ = 
staticmethod(await_only) + __slots__ = ("dbapi", "_execute_mutex") + + def __init__(self, dbapi, connection): + self.dbapi = dbapi + self._connection = connection + self._execute_mutex = asyncio.Lock() + + @asynccontextmanager + async def _mutex_and_adapt_errors(self): + async with self._execute_mutex: + try: + yield + except AttributeError: + raise self.dbapi.InternalError( + "network operation failed due to asyncmy attribute error" + ) + + def ping(self, reconnect): assert not reconnect return self.await_(self._do_ping()) - async def _do_ping(self) -> None: - try: - async with self._execute_mutex: - await self._connection.ping(False) - except Exception as error: - self._handle_exception(error) + async def _do_ping(self): + async with self._mutex_and_adapt_errors(): + return await self._connection.ping(False) - def character_set_name(self) -> Optional[str]: - return self._connection.character_set_name() # type: ignore[no-any-return] # noqa: E501 + def character_set_name(self): + return self._connection.character_set_name() - def autocommit(self, value: Any) -> None: + def autocommit(self, value): self.await_(self._connection.autocommit(value)) - def get_autocommit(self) -> bool: - return self._connection.get_autocommit() # type: ignore + def cursor(self, server_side=False): + if server_side: + return AsyncAdapt_asyncmy_ss_cursor(self) + else: + return AsyncAdapt_asyncmy_cursor(self) - def terminate(self) -> None: + def rollback(self): + self.await_(self._connection.rollback()) + + def commit(self): + self.await_(self._connection.commit()) + + def close(self): # it's not awaitable. self._connection.close() - def close(self) -> None: - self.await_(self._connection.ensure_closed()) - class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection): __slots__ = () @@ -121,13 +232,18 @@ class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection): await_ = staticmethod(await_fallback) -class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module): - def __init__(self, asyncmy: ModuleType): +def _Binary(x): + """Return x as a binary type.""" + return bytes(x) + + +class AsyncAdapt_asyncmy_dbapi: + def __init__(self, asyncmy): self.asyncmy = asyncmy self.paramstyle = "format" self._init_dbapi_attributes() - def _init_dbapi_attributes(self) -> None: + def _init_dbapi_attributes(self): for name in ( "Warning", "Error", @@ -148,9 +264,9 @@ class AsyncAdapt_asyncmy_dbapi(AsyncAdapt_dbapi_module): BINARY = util.symbol("BINARY") DATETIME = util.symbol("DATETIME") TIMESTAMP = util.symbol("TIMESTAMP") - Binary = staticmethod(bytes) + Binary = staticmethod(_Binary) - def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_asyncmy_connection: + def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect) @@ -174,14 +290,13 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql): _sscursor = AsyncAdapt_asyncmy_ss_cursor is_async = True - has_terminate = True @classmethod - def import_dbapi(cls) -> DBAPIModule: + def import_dbapi(cls): return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy")) @classmethod - def get_pool_class(cls, url: URL) -> type: + def get_pool_class(cls, url): async_fallback = url.query.get("async_fallback", False) if util.asbool(async_fallback): @@ -189,20 +304,12 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql): else: return pool.AsyncAdaptedQueuePool - def do_terminate(self, dbapi_connection: DBAPIConnection) -> None: - dbapi_connection.terminate() - - def create_connect_args(self, url: URL) 
-> ConnectArgsType: # type: ignore[override] # noqa: E501 + def create_connect_args(self, url): return super().create_connect_args( url, _translate_args=dict(username="user", database="db") ) - def is_disconnect( - self, - e: DBAPIModule.Error, - connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], - cursor: Optional[DBAPICursor], - ) -> bool: + def is_disconnect(self, e, connection, cursor): if super().is_disconnect(e, connection, cursor): return True else: @@ -211,15 +318,13 @@ class MySQLDialect_asyncmy(MySQLDialect_pymysql): "not connected" in str_e or "network operation failed" in str_e ) - def _found_rows_client_flag(self) -> int: - from asyncmy.constants import CLIENT # type: ignore + def _found_rows_client_flag(self): + from asyncmy.constants import CLIENT - return CLIENT.FOUND_ROWS # type: ignore[no-any-return] + return CLIENT.FOUND_ROWS - def get_driver_connection( - self, connection: DBAPIConnection - ) -> AsyncIODBAPIConnection: - return connection._connection # type: ignore[no-any-return] + def get_driver_connection(self, connection): + return connection._connection dialect = MySQLDialect_asyncmy diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/base.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/base.py index f398fe8..92f9077 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/base.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/base.py @@ -1,15 +1,17 @@ -# dialects/mysql/base.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mysql/base.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors r""" .. dialect:: mysql :name: MySQL / MariaDB + :full_support: 5.6, 5.7, 8.0 / 10.8, 10.9 :normal_support: 5.6+ / 10+ :best_effort: 5.0.2+ / 5.0.2+ @@ -33,9 +35,7 @@ syntactical and behavioral differences that SQLAlchemy accommodates automaticall To connect to a MariaDB database, no changes to the database URL are required:: - engine = create_engine( - "mysql+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4" - ) + engine = create_engine("mysql+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4") Upon first connect, the SQLAlchemy dialect employs a server version detection scheme that determines if the @@ -53,9 +53,7 @@ useful for the case where an application makes use of MariaDB-specific features and is not compatible with a MySQL database. To use this mode of operation, replace the "mysql" token in the above URL with "mariadb":: - engine = create_engine( - "mariadb+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4" - ) + engine = create_engine("mariadb+pymysql://user:pass@some_mariadb/dbname?charset=utf8mb4") The above engine, upon first connect, will raise an error if the server version detection detects that the backing database is not MariaDB. 
@@ -101,7 +99,7 @@ the :paramref:`_sa.create_engine.pool_recycle` option which ensures that a connection will be discarded and replaced with a new one if it has been present in the pool for a fixed number of seconds:: - engine = create_engine("mysql+mysqldb://...", pool_recycle=3600) + engine = create_engine('mysql+mysqldb://...', pool_recycle=3600) For more comprehensive disconnect detection of pooled connections, including accommodation of server restarts and network issues, a pre-ping approach may @@ -125,14 +123,12 @@ To accommodate the rendering of these arguments, specify the form ``ENGINE`` of ``InnoDB``, ``CHARSET`` of ``utf8mb4``, and ``KEY_BLOCK_SIZE`` of ``1024``:: - Table( - "mytable", - metadata, - Column("data", String(32)), - mysql_engine="InnoDB", - mysql_charset="utf8mb4", - mysql_key_block_size="1024", - ) + Table('mytable', metadata, + Column('data', String(32)), + mysql_engine='InnoDB', + mysql_charset='utf8mb4', + mysql_key_block_size="1024" + ) When supporting :ref:`mysql_mariadb_only_mode` mode, similar keys against the "mariadb" prefix must be included as well. The values can of course @@ -141,17 +137,19 @@ be maintained:: # support both "mysql" and "mariadb-only" engine URLs - Table( - "mytable", - metadata, - Column("data", String(32)), - mysql_engine="InnoDB", - mariadb_engine="InnoDB", - mysql_charset="utf8mb4", - mariadb_charset="utf8", - mysql_key_block_size="1024", - mariadb_key_block_size="1024", - ) + Table('mytable', metadata, + Column('data', String(32)), + + mysql_engine='InnoDB', + mariadb_engine='InnoDB', + + mysql_charset='utf8mb4', + mariadb_charset='utf8', + + mysql_key_block_size="1024" + mariadb_key_block_size="1024" + + ) The MySQL / MariaDB dialects will normally transfer any keyword specified as ``mysql_keyword_name`` to be rendered as ``KEYWORD_NAME`` in the @@ -181,31 +179,6 @@ For fully atomic transactions as well as support for foreign key constraints, all participating ``CREATE TABLE`` statements must specify a transactional engine, which in the vast majority of cases is ``InnoDB``. -Partitioning can similarly be specified using similar options. -In the example below the create table will specify ``PARTITION_BY``, -``PARTITIONS``, ``SUBPARTITIONS`` and ``SUBPARTITION_BY``:: - - # can also use mariadb_* prefix - Table( - "testtable", - MetaData(), - Column("id", Integer(), primary_key=True, autoincrement=True), - Column("other_id", Integer(), primary_key=True, autoincrement=False), - mysql_partitions="2", - mysql_partition_by="KEY(other_id)", - mysql_subpartition_by="HASH(some_expr)", - mysql_subpartitions="2", - ) - -This will render: - -.. sourcecode:: sql - - CREATE TABLE testtable ( - id INTEGER NOT NULL AUTO_INCREMENT, - other_id INTEGER NOT NULL, - PRIMARY KEY (id, other_id) - )PARTITION BY KEY(other_id) PARTITIONS 2 SUBPARTITION BY HASH(some_expr) SUBPARTITIONS 2 Case Sensitivity and Table Reflection ------------------------------------- @@ -242,14 +215,16 @@ techniques are used. 
To set isolation level using :func:`_sa.create_engine`:: engine = create_engine( - "mysql+mysqldb://scott:tiger@localhost/test", - isolation_level="READ UNCOMMITTED", - ) + "mysql+mysqldb://scott:tiger@localhost/test", + isolation_level="READ UNCOMMITTED" + ) To set using per-connection execution options:: connection = engine.connect() - connection = connection.execution_options(isolation_level="READ COMMITTED") + connection = connection.execution_options( + isolation_level="READ COMMITTED" + ) Valid values for ``isolation_level`` include: @@ -281,8 +256,8 @@ When creating tables, SQLAlchemy will automatically set ``AUTO_INCREMENT`` on the first :class:`.Integer` primary key column which is not marked as a foreign key:: - >>> t = Table( - ... "mytable", metadata, Column("mytable_id", Integer, primary_key=True) + >>> t = Table('mytable', metadata, + ... Column('mytable_id', Integer, primary_key=True) ... ) >>> t.create() CREATE TABLE mytable ( @@ -296,12 +271,10 @@ This flag can also be used to enable auto-increment on a secondary column in a multi-column key for some storage engines:: - Table( - "mytable", - metadata, - Column("gid", Integer, primary_key=True, autoincrement=False), - Column("id", Integer, primary_key=True), - ) + Table('mytable', metadata, + Column('gid', Integer, primary_key=True, autoincrement=False), + Column('id', Integer, primary_key=True) + ) .. _mysql_ss_cursors: @@ -319,9 +292,7 @@ Server side cursors are enabled on a per-statement basis by using the option:: with engine.connect() as conn: - result = conn.execution_options(stream_results=True).execute( - text("select * from table") - ) + result = conn.execution_options(stream_results=True).execute(text("select * from table")) Note that some kinds of SQL statements may not be supported with server side cursors; generally, only SQL statements that return rows should be @@ -349,8 +320,7 @@ a connection. This is typically delivered using the ``charset`` parameter in the URL, such as:: e = create_engine( - "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4" - ) + "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4") This charset is the **client character set** for the connection. Some MySQL DBAPIs will default this to a value such as ``latin1``, and some @@ -370,8 +340,7 @@ charset is preferred, if supported by both the database as well as the client DBAPI, as in:: e = create_engine( - "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4" - ) + "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4") All modern DBAPIs should support the ``utf8mb4`` charset. @@ -393,9 +362,7 @@ Dealing with Binary Data Warnings and Unicode MySQL versions 5.6, 5.7 and later (not MariaDB at the time of this writing) now emit a warning when attempting to pass binary data to the database, while a character set encoding is also in place, when the binary data itself is not -valid for that encoding: - -.. sourcecode:: text +valid for that encoding:: default.py:509: Warning: (1300, "Invalid utf8mb4 character string: 'F9876A'") @@ -405,9 +372,7 @@ This warning is due to the fact that the MySQL client library is attempting to interpret the binary string as a unicode object even if a datatype such as :class:`.LargeBinary` is in use. To resolve this, the SQL statement requires a binary "character set introducer" be present before any non-NULL value -that renders like this: - -.. 
sourcecode:: sql +that renders like this:: INSERT INTO table (data) VALUES (_binary %s) @@ -417,13 +382,12 @@ string parameter ``binary_prefix=true`` to the URL to repair this warning:: # mysqlclient engine = create_engine( - "mysql+mysqldb://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true" - ) + "mysql+mysqldb://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true") # PyMySQL engine = create_engine( - "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true" - ) + "mysql+pymysql://scott:tiger@localhost/test?charset=utf8mb4&binary_prefix=true") + The ``binary_prefix`` flag may or may not be supported by other MySQL drivers. @@ -466,10 +430,7 @@ the ``first_connect`` and ``connect`` events:: from sqlalchemy import create_engine, event - eng = create_engine( - "mysql+mysqldb://scott:tiger@localhost/test", echo="debug" - ) - + eng = create_engine("mysql+mysqldb://scott:tiger@localhost/test", echo='debug') # `insert=True` will ensure this is the very first listener to run @event.listens_for(eng, "connect", insert=True) @@ -477,7 +438,6 @@ the ``first_connect`` and ``connect`` events:: cursor = dbapi_connection.cursor() cursor.execute("SET sql_mode = 'STRICT_ALL_TABLES'") - conn = eng.connect() In the example illustrated above, the "connect" event will invoke the "SET" @@ -494,8 +454,8 @@ MySQL / MariaDB SQL Extensions Many of the MySQL / MariaDB SQL extensions are handled through SQLAlchemy's generic function and operator support:: - table.select(table.c.password == func.md5("plaintext")) - table.select(table.c.username.op("regexp")("^[a-d]")) + table.select(table.c.password==func.md5('plaintext')) + table.select(table.c.username.op('regexp')('^[a-d]')) And of course any valid SQL statement can be executed as a string as well. @@ -508,18 +468,11 @@ available. * SELECT pragma, use :meth:`_expression.Select.prefix_with` and :meth:`_query.Query.prefix_with`:: - select(...).prefix_with(["HIGH_PRIORITY", "SQL_SMALL_RESULT"]) + select(...).prefix_with(['HIGH_PRIORITY', 'SQL_SMALL_RESULT']) * UPDATE with LIMIT:: - update(...).with_dialect_options(mysql_limit=10, mariadb_limit=10) - -* DELETE - with LIMIT:: - - delete(...).with_dialect_options(mysql_limit=10, mariadb_limit=10) - - .. versionadded:: 2.0.37 Added delete with limit + update(..., mysql_limit=10, mariadb_limit=10) * optimizer hints, use :meth:`_expression.Select.prefix_with` and :meth:`_query.Query.prefix_with`:: @@ -531,16 +484,14 @@ available. select(...).with_hint(some_table, "USE INDEX xyz") -* MATCH - operator support:: +* MATCH operator support:: - from sqlalchemy.dialects.mysql import match + from sqlalchemy.dialects.mysql import match + select(...).where(match(col1, col2, against="some expr").in_boolean_mode()) - select(...).where(match(col1, col2, against="some expr").in_boolean_mode()) + .. seealso:: - .. seealso:: - - :class:`_mysql.match` + :class:`_mysql.match` INSERT/DELETE...RETURNING ------------------------- @@ -557,15 +508,17 @@ To specify an explicit ``RETURNING`` clause, use the # INSERT..RETURNING result = connection.execute( - table.insert().values(name="foo").returning(table.c.col1, table.c.col2) + table.insert(). + values(name='foo'). + returning(table.c.col1, table.c.col2) ) print(result.all()) # DELETE..RETURNING result = connection.execute( - table.delete() - .where(table.c.name == "foo") - .returning(table.c.col1, table.c.col2) + table.delete(). + where(table.c.name=='foo'). 
+ returning(table.c.col1, table.c.col2) ) print(result.all()) @@ -592,11 +545,12 @@ the generative method :meth:`~.mysql.Insert.on_duplicate_key_update`: >>> from sqlalchemy.dialects.mysql import insert >>> insert_stmt = insert(my_table).values( - ... id="some_existing_id", data="inserted value" - ... ) + ... id='some_existing_id', + ... data='inserted value') >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update( - ... data=insert_stmt.inserted.data, status="U" + ... data=insert_stmt.inserted.data, + ... status='U' ... ) >>> print(on_duplicate_key_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%s, %s) @@ -621,8 +575,8 @@ as values: .. sourcecode:: pycon+sql >>> insert_stmt = insert(my_table).values( - ... id="some_existing_id", data="inserted value" - ... ) + ... id='some_existing_id', + ... data='inserted value') >>> on_duplicate_key_stmt = insert_stmt.on_duplicate_key_update( ... data="some data", @@ -685,11 +639,13 @@ table: .. sourcecode:: pycon+sql >>> stmt = insert(my_table).values( - ... id="some_id", data="inserted value", author="jlh" - ... ) + ... id='some_id', + ... data='inserted value', + ... author='jlh') >>> do_update_stmt = stmt.on_duplicate_key_update( - ... data="updated value", author=stmt.inserted.author + ... data="updated value", + ... author=stmt.inserted.author ... ) >>> print(do_update_stmt) @@ -734,13 +690,13 @@ MySQL and MariaDB both provide an option to create index entries with a certain become part of the index. SQLAlchemy provides this feature via the ``mysql_length`` and/or ``mariadb_length`` parameters:: - Index("my_index", my_table.c.data, mysql_length=10, mariadb_length=10) + Index('my_index', my_table.c.data, mysql_length=10, mariadb_length=10) - Index("a_b_idx", my_table.c.a, my_table.c.b, mysql_length={"a": 4, "b": 9}) + Index('a_b_idx', my_table.c.a, my_table.c.b, mysql_length={'a': 4, + 'b': 9}) - Index( - "a_b_idx", my_table.c.a, my_table.c.b, mariadb_length={"a": 4, "b": 9} - ) + Index('a_b_idx', my_table.c.a, my_table.c.b, mariadb_length={'a': 4, + 'b': 9}) Prefix lengths are given in characters for nonbinary string types and in bytes for binary string types. The value passed to the keyword argument *must* be @@ -757,7 +713,7 @@ MySQL storage engines permit you to specify an index prefix when creating an index. SQLAlchemy provides this feature via the ``mysql_prefix`` parameter on :class:`.Index`:: - Index("my_index", my_table.c.data, mysql_prefix="FULLTEXT") + Index('my_index', my_table.c.data, mysql_prefix='FULLTEXT') The value passed to the keyword argument will be simply passed through to the underlying CREATE INDEX, so it *must* be a valid index prefix for your MySQL @@ -774,13 +730,11 @@ Some MySQL storage engines permit you to specify an index type when creating an index or primary key constraint. SQLAlchemy provides this feature via the ``mysql_using`` parameter on :class:`.Index`:: - Index( - "my_index", my_table.c.data, mysql_using="hash", mariadb_using="hash" - ) + Index('my_index', my_table.c.data, mysql_using='hash', mariadb_using='hash') As well as the ``mysql_using`` parameter on :class:`.PrimaryKeyConstraint`:: - PrimaryKeyConstraint("data", mysql_using="hash", mariadb_using="hash") + PrimaryKeyConstraint("data", mysql_using='hash', mariadb_using='hash') The value passed to the keyword argument will be simply passed through to the underlying CREATE INDEX or PRIMARY KEY clause, so it *must* be a valid index @@ -799,12 +753,9 @@ CREATE FULLTEXT INDEX in MySQL also supports a "WITH PARSER" option. 
This is available using the keyword argument ``mysql_with_parser``:: Index( - "my_index", - my_table.c.data, - mysql_prefix="FULLTEXT", - mysql_with_parser="ngram", - mariadb_prefix="FULLTEXT", - mariadb_with_parser="ngram", + 'my_index', my_table.c.data, + mysql_prefix='FULLTEXT', mysql_with_parser="ngram", + mariadb_prefix='FULLTEXT', mariadb_with_parser="ngram", ) .. versionadded:: 1.3 @@ -831,7 +782,6 @@ them ignored on a MySQL / MariaDB backend, use a custom compile rule:: from sqlalchemy.ext.compiler import compiles from sqlalchemy.schema import ForeignKeyConstraint - @compiles(ForeignKeyConstraint, "mysql", "mariadb") def process(element, compiler, **kw): element.deferrable = element.initially = None @@ -853,12 +803,10 @@ very common ``MyISAM`` MySQL storage engine, the information loaded by table reflection will not include foreign keys. For these tables, you may supply a :class:`~sqlalchemy.ForeignKeyConstraint` at reflection time:: - Table( - "mytable", - metadata, - ForeignKeyConstraint(["other_id"], ["othertable.other_id"]), - autoload_with=engine, - ) + Table('mytable', metadata, + ForeignKeyConstraint(['other_id'], ['othertable.other_id']), + autoload_with=engine + ) .. seealso:: @@ -930,15 +878,13 @@ parameter and pass a textual clause that also includes the ON UPDATE clause:: mytable = Table( "mytable", metadata, - Column("id", Integer, primary_key=True), - Column("data", String(50)), + Column('id', Integer, primary_key=True), + Column('data', String(50)), Column( - "last_updated", + 'last_updated', TIMESTAMP, - server_default=text( - "CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP" - ), - ), + server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") + ) ) The same instructions apply to use of the :class:`_types.DateTime` and @@ -949,37 +895,34 @@ The same instructions apply to use of the :class:`_types.DateTime` and mytable = Table( "mytable", metadata, - Column("id", Integer, primary_key=True), - Column("data", String(50)), + Column('id', Integer, primary_key=True), + Column('data', String(50)), Column( - "last_updated", + 'last_updated', DateTime, - server_default=text( - "CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP" - ), - ), + server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP") + ) ) + Even though the :paramref:`_schema.Column.server_onupdate` feature does not generate this DDL, it still may be desirable to signal to the ORM that this updated value should be fetched. This syntax looks like the following:: from sqlalchemy.schema import FetchedValue - class MyClass(Base): - __tablename__ = "mytable" + __tablename__ = 'mytable' id = Column(Integer, primary_key=True) data = Column(String(50)) last_updated = Column( TIMESTAMP, - server_default=text( - "CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP" - ), - server_onupdate=FetchedValue(), + server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"), + server_onupdate=FetchedValue() ) + .. _mysql_timestamp_null: TIMESTAMP Columns and NULL @@ -989,9 +932,7 @@ MySQL historically enforces that a column which specifies the TIMESTAMP datatype implicitly includes a default value of CURRENT_TIMESTAMP, even though this is not stated, and additionally sets the column as NOT NULL, the opposite behavior vs. that of all -other datatypes: - -.. 
sourcecode:: text +other datatypes:: mysql> CREATE TABLE ts_test ( -> a INTEGER, @@ -1036,24 +977,19 @@ SQLAlchemy also emits NOT NULL for TIMESTAMP columns that do specify from sqlalchemy.dialects.mysql import TIMESTAMP m = MetaData() - t = Table( - "ts_test", - m, - Column("a", Integer), - Column("b", Integer, nullable=False), - Column("c", TIMESTAMP), - Column("d", TIMESTAMP, nullable=False), - ) + t = Table('ts_test', m, + Column('a', Integer), + Column('b', Integer, nullable=False), + Column('c', TIMESTAMP), + Column('d', TIMESTAMP, nullable=False) + ) from sqlalchemy import create_engine - e = create_engine("mysql+mysqldb://scott:tiger@localhost/test", echo=True) m.create_all(e) -output: - -.. sourcecode:: sql +output:: CREATE TABLE ts_test ( a INTEGER, @@ -1065,22 +1001,11 @@ output: """ # noqa from __future__ import annotations +from array import array as _array from collections import defaultdict from itertools import compress import re -from typing import Any -from typing import Callable from typing import cast -from typing import DefaultDict -from typing import Dict -from typing import List -from typing import NoReturn -from typing import Optional -from typing import overload -from typing import Sequence -from typing import Tuple -from typing import TYPE_CHECKING -from typing import Union from . import reflection as _reflection from .enumerated import ENUM @@ -1123,6 +1048,7 @@ from .types import VARCHAR from .types import YEAR from ... import exc from ... import literal_column +from ... import log from ... import schema as sa_schema from ... import sql from ... import util @@ -1146,46 +1072,10 @@ from ...types import BINARY from ...types import BLOB from ...types import BOOLEAN from ...types import DATE -from ...types import LargeBinary from ...types import UUID from ...types import VARBINARY from ...util import topological -if TYPE_CHECKING: - - from ...dialects.mysql import expression - from ...dialects.mysql.dml import OnDuplicateClause - from ...engine.base import Connection - from ...engine.cursor import CursorResult - from ...engine.interfaces import DBAPIConnection - from ...engine.interfaces import DBAPICursor - from ...engine.interfaces import DBAPIModule - from ...engine.interfaces import IsolationLevel - from ...engine.interfaces import PoolProxiedConnection - from ...engine.interfaces import ReflectedCheckConstraint - from ...engine.interfaces import ReflectedColumn - from ...engine.interfaces import ReflectedForeignKeyConstraint - from ...engine.interfaces import ReflectedIndex - from ...engine.interfaces import ReflectedPrimaryKeyConstraint - from ...engine.interfaces import ReflectedTableComment - from ...engine.interfaces import ReflectedUniqueConstraint - from ...engine.row import Row - from ...engine.url import URL - from ...schema import Table - from ...sql import ddl - from ...sql import selectable - from ...sql.dml import _DMLTableElement - from ...sql.dml import Delete - from ...sql.dml import Update - from ...sql.dml import ValuesBase - from ...sql.functions import aggregate_strings - from ...sql.functions import random - from ...sql.functions import rollup - from ...sql.functions import sysdate - from ...sql.schema import Sequence as Sequence_SchemaItem - from ...sql.type_api import TypeEngine - from ...sql.visitors import ExternallyTraversible - SET_RE = re.compile( r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE @@ -1280,7 +1170,7 @@ ischema_names = { class MySQLExecutionContext(default.DefaultExecutionContext): - def post_exec(self) -> 
None: + def post_exec(self): if ( self.isdelete and cast(SQLCompiler, self.compiled).effective_returning @@ -1297,7 +1187,7 @@ class MySQLExecutionContext(default.DefaultExecutionContext): _cursor.FullyBufferedCursorFetchStrategy( self.cursor, [ - (entry.keyname, None) # type: ignore[misc] + (entry.keyname, None) for entry in cast( SQLCompiler, self.compiled )._result_columns @@ -1306,18 +1196,14 @@ class MySQLExecutionContext(default.DefaultExecutionContext): ) ) - def create_server_side_cursor(self) -> DBAPICursor: + def create_server_side_cursor(self): if self.dialect.supports_server_side_cursors: - return self._dbapi_connection.cursor( - self.dialect._sscursor # type: ignore[attr-defined] - ) + return self._dbapi_connection.cursor(self.dialect._sscursor) else: raise NotImplementedError() - def fire_sequence( - self, seq: Sequence_SchemaItem, type_: sqltypes.Integer - ) -> int: - return self._execute_scalar( # type: ignore[no-any-return] + def fire_sequence(self, seq, type_): + return self._execute_scalar( ( "select nextval(%s)" % self.identifier_preparer.format_sequence(seq) @@ -1327,51 +1213,46 @@ class MySQLExecutionContext(default.DefaultExecutionContext): class MySQLCompiler(compiler.SQLCompiler): - dialect: MySQLDialect render_table_with_column_in_update_from = True """Overridden from base SQLCompiler value""" extract_map = compiler.SQLCompiler.extract_map.copy() extract_map.update({"milliseconds": "millisecond"}) - def default_from(self) -> str: + def default_from(self): """Called when a ``SELECT`` statement has no froms, and no ``FROM`` clause is to be appended. """ if self.stack: stmt = self.stack[-1]["selectable"] - if stmt._where_criteria: # type: ignore[attr-defined] + if stmt._where_criteria: return " FROM DUAL" return "" - def visit_random_func(self, fn: random, **kw: Any) -> str: + def visit_random_func(self, fn, **kw): return "rand%s" % self.function_argspec(fn) - def visit_rollup_func(self, fn: rollup[Any], **kw: Any) -> str: + def visit_rollup_func(self, fn, **kw): clause = ", ".join( elem._compiler_dispatch(self, **kw) for elem in fn.clauses ) return f"{clause} WITH ROLLUP" - def visit_aggregate_strings_func( - self, fn: aggregate_strings, **kw: Any - ) -> str: + def visit_aggregate_strings_func(self, fn, **kw): expr, delimeter = ( elem._compiler_dispatch(self, **kw) for elem in fn.clauses ) return f"group_concat({expr} SEPARATOR {delimeter})" - def visit_sequence(self, sequence: sa_schema.Sequence, **kw: Any) -> str: - return "nextval(%s)" % self.preparer.format_sequence(sequence) + def visit_sequence(self, seq, **kw): + return "nextval(%s)" % self.preparer.format_sequence(seq) - def visit_sysdate_func(self, fn: sysdate, **kw: Any) -> str: + def visit_sysdate_func(self, fn, **kw): return "SYSDATE()" - def _render_json_extract_from_binary( - self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any - ) -> str: + def _render_json_extract_from_binary(self, binary, operator, **kw): # note we are intentionally calling upon the process() calls in the # order in which they appear in the SQL String as this is used # by positional parameter rendering @@ -1398,10 +1279,9 @@ class MySQLCompiler(compiler.SQLCompiler): ) ) elif binary.type._type_affinity is sqltypes.Numeric: - binary_type = cast(sqltypes.Numeric[Any], binary.type) if ( - binary_type.scale is not None - and binary_type.precision is not None + binary.type.scale is not None + and binary.type.precision is not None ): # using DECIMAL here because MySQL does not recognize NUMERIC type_expression = ( @@ 
-1409,8 +1289,8 @@ class MySQLCompiler(compiler.SQLCompiler): % ( self.process(binary.left, **kw), self.process(binary.right, **kw), - binary_type.precision, - binary_type.scale, + binary.type.precision, + binary.type.scale, ) ) else: @@ -1444,22 +1324,15 @@ class MySQLCompiler(compiler.SQLCompiler): return case_expression + " " + type_expression + " END" - def visit_json_getitem_op_binary( - self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any - ) -> str: + def visit_json_getitem_op_binary(self, binary, operator, **kw): return self._render_json_extract_from_binary(binary, operator, **kw) - def visit_json_path_getitem_op_binary( - self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any - ) -> str: + def visit_json_path_getitem_op_binary(self, binary, operator, **kw): return self._render_json_extract_from_binary(binary, operator, **kw) - def visit_on_duplicate_key_update( - self, on_duplicate: OnDuplicateClause, **kw: Any - ) -> str: - statement: ValuesBase = self.current_executable + def visit_on_duplicate_key_update(self, on_duplicate, **kw): + statement = self.current_executable - cols: List[elements.KeyedColumnElement[Any]] if on_duplicate._parameter_ordering: parameter_ordering = [ coercions.expect(roles.DMLColumnRole, key) @@ -1472,56 +1345,49 @@ class MySQLCompiler(compiler.SQLCompiler): if key in statement.table.c ] + [c for c in statement.table.c if c.key not in ordered_keys] else: - cols = list(statement.table.c) + cols = statement.table.c clauses = [] - requires_mysql8_alias = statement.select is None and ( + requires_mysql8_alias = ( self.dialect._requires_alias_for_on_duplicate_key ) if requires_mysql8_alias: - if statement.table.name.lower() == "new": # type: ignore[union-attr] # noqa: E501 + if statement.table.name.lower() == "new": _on_dup_alias_name = "new_1" else: _on_dup_alias_name = "new" - on_duplicate_update = { - coercions.expect_as_key(roles.DMLColumnRole, key): value - for key, value in on_duplicate.update.items() - } - # traverses through all table columns to preserve table column order - for column in (col for col in cols if col.key in on_duplicate_update): - val = on_duplicate_update[column.key] + for column in (col for col in cols if col.key in on_duplicate.update): + val = on_duplicate.update[column.key] - # TODO: this coercion should be up front. we can't cache - # SQL constructs with non-bound literals buried in them if coercions._is_literal(val): val = elements.BindParameter(None, val, type_=column.type) value_text = self.process(val.self_group(), use_schema=False) else: - def replace( - element: ExternallyTraversible, **kw: Any - ) -> Optional[ExternallyTraversible]: + def replace(obj): if ( - isinstance(element, elements.BindParameter) - and element.type._isnull + isinstance(obj, elements.BindParameter) + and obj.type._isnull ): - return element._with_binary_element_type(column.type) + obj = obj._clone() + obj.type = column.type + return obj elif ( - isinstance(element, elements.ColumnClause) - and element.table is on_duplicate.inserted_alias + isinstance(obj, elements.ColumnClause) + and obj.table is on_duplicate.inserted_alias ): if requires_mysql8_alias: column_literal_clause = ( f"{_on_dup_alias_name}." 
- f"{self.preparer.quote(element.name)}" + f"{self.preparer.quote(obj.name)}" ) else: column_literal_clause = ( - f"VALUES({self.preparer.quote(element.name)})" + f"VALUES({self.preparer.quote(obj.name)})" ) return literal_column(column_literal_clause) else: @@ -1534,13 +1400,13 @@ class MySQLCompiler(compiler.SQLCompiler): name_text = self.preparer.quote(column.name) clauses.append("%s = %s" % (name_text, value_text)) - non_matching = set(on_duplicate_update) - {c.key for c in cols} + non_matching = set(on_duplicate.update) - {c.key for c in cols} if non_matching: util.warn( "Additional column names not matching " "any column keys in table '%s': %s" % ( - self.statement.table.name, # type: ignore[union-attr] + self.statement.table.name, (", ".join("'%s'" % c for c in non_matching)), ) ) @@ -1554,15 +1420,13 @@ class MySQLCompiler(compiler.SQLCompiler): return f"ON DUPLICATE KEY UPDATE {', '.join(clauses)}" def visit_concat_op_expression_clauselist( - self, clauselist: elements.ClauseList, operator: Any, **kw: Any - ) -> str: + self, clauselist, operator, **kw + ): return "concat(%s)" % ( ", ".join(self.process(elem, **kw) for elem in clauselist.clauses) ) - def visit_concat_op_binary( - self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any - ) -> str: + def visit_concat_op_binary(self, binary, operator, **kw): return "concat(%s, %s)" % ( self.process(binary.left, **kw), self.process(binary.right, **kw), @@ -1585,12 +1449,10 @@ class MySQLCompiler(compiler.SQLCompiler): "WITH QUERY EXPANSION", ) - def visit_mysql_match(self, element: expression.match, **kw: Any) -> str: + def visit_mysql_match(self, element, **kw): return self.visit_match_op_binary(element, element.operator, **kw) - def visit_match_op_binary( - self, binary: expression.match, operator: Any, **kw: Any - ) -> str: + def visit_match_op_binary(self, binary, operator, **kw): """ Note that `mysql_boolean_mode` is enabled by default because of backward compatibility @@ -1611,11 +1473,12 @@ class MySQLCompiler(compiler.SQLCompiler): "with_query_expansion=%s" % query_expansion, ) - flags_str = ", ".join(flags) + flags = ", ".join(flags) - raise exc.CompileError("Invalid MySQL match flags: %s" % flags_str) + raise exc.CompileError("Invalid MySQL match flags: %s" % flags) - match_clause = self.process(binary.left, **kw) + match_clause = binary.left + match_clause = self.process(match_clause, **kw) against_clause = self.process(binary.right, **kw) if any(flag_combination): @@ -1624,25 +1487,21 @@ class MySQLCompiler(compiler.SQLCompiler): flag_combination, ) - against_clause = " ".join([against_clause, *flag_expressions]) + against_clause = [against_clause] + against_clause.extend(flag_expressions) + + against_clause = " ".join(against_clause) return "MATCH (%s) AGAINST (%s)" % (match_clause, against_clause) - def get_from_hint_text( - self, table: selectable.FromClause, text: Optional[str] - ) -> Optional[str]: + def get_from_hint_text(self, table, text): return text - def visit_typeclause( - self, - typeclause: elements.TypeClause, - type_: Optional[TypeEngine[Any]] = None, - **kw: Any, - ) -> Optional[str]: + def visit_typeclause(self, typeclause, type_=None, **kw): if type_ is None: type_ = typeclause.type.dialect_impl(self.dialect) if isinstance(type_, sqltypes.TypeDecorator): - return self.visit_typeclause(typeclause, type_.impl, **kw) # type: ignore[arg-type] # noqa: E501 + return self.visit_typeclause(typeclause, type_.impl, **kw) elif isinstance(type_, sqltypes.Integer): if getattr(type_, "unsigned", False): 
return "UNSIGNED INTEGER" @@ -1681,7 +1540,7 @@ class MySQLCompiler(compiler.SQLCompiler): else: return None - def visit_cast(self, cast: elements.Cast[Any], **kw: Any) -> str: + def visit_cast(self, cast, **kw): type_ = self.process(cast.typeclause) if type_ is None: util.warn( @@ -1695,9 +1554,7 @@ class MySQLCompiler(compiler.SQLCompiler): return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), type_) - def render_literal_value( - self, value: Optional[str], type_: TypeEngine[Any] - ) -> str: + def render_literal_value(self, value, type_): value = super().render_literal_value(value, type_) if self.dialect._backslash_escapes: value = value.replace("\\", "\\\\") @@ -1705,18 +1562,16 @@ class MySQLCompiler(compiler.SQLCompiler): # override native_boolean=False behavior here, as # MySQL still supports native boolean - def visit_true(self, expr: elements.True_, **kw: Any) -> str: + def visit_true(self, element, **kw): return "true" - def visit_false(self, expr: elements.False_, **kw: Any) -> str: + def visit_false(self, element, **kw): return "false" - def get_select_precolumns( - self, select: selectable.Select[Any], **kw: Any - ) -> str: + def get_select_precolumns(self, select, **kw): """Add special MySQL keywords in place of DISTINCT. - .. deprecated:: 1.4 This usage is deprecated. + .. deprecated 1.4:: this usage is deprecated. :meth:`_expression.Select.prefix_with` should be used for special keywords at the start of a SELECT. @@ -1733,13 +1588,7 @@ class MySQLCompiler(compiler.SQLCompiler): return super().get_select_precolumns(select, **kw) - def visit_join( - self, - join: selectable.Join, - asfrom: bool = False, - from_linter: Optional[compiler.FromLinter] = None, - **kwargs: Any, - ) -> str: + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): if from_linter: from_linter.edges.add((join.left, join.right)) @@ -1760,21 +1609,18 @@ class MySQLCompiler(compiler.SQLCompiler): join.right, asfrom=True, from_linter=from_linter, **kwargs ), " ON ", - self.process(join.onclause, from_linter=from_linter, **kwargs), # type: ignore[arg-type] # noqa: E501 + self.process(join.onclause, from_linter=from_linter, **kwargs), ) ) - def for_update_clause( - self, select: selectable.GenerativeSelect, **kw: Any - ) -> str: - assert select._for_update_arg is not None + def for_update_clause(self, select, **kw): if select._for_update_arg.read: tmp = " LOCK IN SHARE MODE" else: tmp = " FOR UPDATE" if select._for_update_arg.of and self.dialect.supports_for_update_of: - tables: util.OrderedSet[elements.ClauseElement] = util.OrderedSet() + tables = util.OrderedSet() for c in select._for_update_arg.of: tables.update(sql_util.surface_selectables_only(c)) @@ -1791,9 +1637,7 @@ class MySQLCompiler(compiler.SQLCompiler): return tmp - def limit_clause( - self, select: selectable.GenerativeSelect, **kw: Any - ) -> str: + def limit_clause(self, select, **kw): # MySQL supports: # LIMIT # LIMIT , @@ -1829,31 +1673,17 @@ class MySQLCompiler(compiler.SQLCompiler): self.process(limit_clause, **kw), ) else: - assert limit_clause is not None # No offset provided, so just use the limit return " \n LIMIT %s" % (self.process(limit_clause, **kw),) - def update_limit_clause(self, update_stmt: Update) -> Optional[str]: + def update_limit_clause(self, update_stmt): limit = update_stmt.kwargs.get("%s_limit" % self.dialect.name, None) - if limit is not None: - return f"LIMIT {int(limit)}" + if limit: + return "LIMIT %s" % limit else: return None - def delete_limit_clause(self, delete_stmt: Delete) -> 
Optional[str]: - limit = delete_stmt.kwargs.get("%s_limit" % self.dialect.name, None) - if limit is not None: - return f"LIMIT {int(limit)}" - else: - return None - - def update_tables_clause( - self, - update_stmt: Update, - from_table: _DMLTableElement, - extra_froms: List[selectable.FromClause], - **kw: Any, - ) -> str: + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): kw["asfrom"] = True return ", ".join( t._compiler_dispatch(self, **kw) @@ -1861,22 +1691,11 @@ class MySQLCompiler(compiler.SQLCompiler): ) def update_from_clause( - self, - update_stmt: Update, - from_table: _DMLTableElement, - extra_froms: List[selectable.FromClause], - from_hints: Any, - **kw: Any, - ) -> None: + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): return None - def delete_table_clause( - self, - delete_stmt: Delete, - from_table: _DMLTableElement, - extra_froms: List[selectable.FromClause], - **kw: Any, - ) -> str: + def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw): """If we have extra froms make sure we render any alias as hint.""" ashint = False if extra_froms: @@ -1886,13 +1705,8 @@ class MySQLCompiler(compiler.SQLCompiler): ) def delete_extra_from_clause( - self, - delete_stmt: Delete, - from_table: _DMLTableElement, - extra_froms: List[selectable.FromClause], - from_hints: Any, - **kw: Any, - ) -> str: + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): """Render the DELETE .. USING clause specific to MySQL.""" kw["asfrom"] = True return "USING " + ", ".join( @@ -1900,9 +1714,7 @@ class MySQLCompiler(compiler.SQLCompiler): for t in [from_table] + extra_froms ) - def visit_empty_set_expr( - self, element_types: List[TypeEngine[Any]], **kw: Any - ) -> str: + def visit_empty_set_expr(self, element_types, **kw): return ( "SELECT %(outer)s FROM (SELECT %(inner)s) " "as _empty_set WHERE 1!=1" @@ -1917,38 +1729,25 @@ class MySQLCompiler(compiler.SQLCompiler): } ) - def visit_is_distinct_from_binary( - self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any - ) -> str: + def visit_is_distinct_from_binary(self, binary, operator, **kw): return "NOT (%s <=> %s)" % ( self.process(binary.left), self.process(binary.right), ) - def visit_is_not_distinct_from_binary( - self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any - ) -> str: + def visit_is_not_distinct_from_binary(self, binary, operator, **kw): return "%s <=> %s" % ( self.process(binary.left), self.process(binary.right), ) - def _mariadb_regexp_flags( - self, flags: str, pattern: elements.ColumnElement[Any], **kw: Any - ) -> str: + def _mariadb_regexp_flags(self, flags, pattern, **kw): return "CONCAT('(?', %s, ')', %s)" % ( self.render_literal_value(flags, sqltypes.STRINGTYPE), self.process(pattern, **kw), ) - def _regexp_match( - self, - op_string: str, - binary: elements.BinaryExpression[Any], - operator: Any, - **kw: Any, - ) -> str: - assert binary.modifiers is not None + def _regexp_match(self, op_string, binary, operator, **kw): flags = binary.modifiers["flags"] if flags is None: return self._generate_generic_binary(binary, op_string, **kw) @@ -1969,20 +1768,13 @@ class MySQLCompiler(compiler.SQLCompiler): else: return text - def visit_regexp_match_op_binary( - self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any - ) -> str: + def visit_regexp_match_op_binary(self, binary, operator, **kw): return self._regexp_match(" REGEXP ", binary, operator, **kw) - def visit_not_regexp_match_op_binary( - self, binary: 
elements.BinaryExpression[Any], operator: Any, **kw: Any - ) -> str: + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): return self._regexp_match(" NOT REGEXP ", binary, operator, **kw) - def visit_regexp_replace_op_binary( - self, binary: elements.BinaryExpression[Any], operator: Any, **kw: Any - ) -> str: - assert binary.modifiers is not None + def visit_regexp_replace_op_binary(self, binary, operator, **kw): flags = binary.modifiers["flags"] if flags is None: return "REGEXP_REPLACE(%s, %s)" % ( @@ -2004,11 +1796,7 @@ class MySQLCompiler(compiler.SQLCompiler): class MySQLDDLCompiler(compiler.DDLCompiler): - dialect: MySQLDialect - - def get_column_specification( - self, column: sa_schema.Column[Any], **kw: Any - ) -> str: + def get_column_specification(self, column, **kw): """Builds column DDL.""" if ( self.dialect.is_mariadb is True @@ -2061,25 +1849,11 @@ class MySQLDDLCompiler(compiler.DDLCompiler): colspec.append("AUTO_INCREMENT") else: default = self.get_column_default_string(column) - if default is not None: - if ( - self.dialect._support_default_function - and not re.match(r"^\s*[\'\"\(]", default) - and not re.search(r"ON +UPDATE", default, re.I) - and not re.match( - r"\bnow\(\d+\)|\bcurrent_timestamp\(\d+\)", - default, - re.I, - ) - and re.match(r".*\W.*", default) - ): - colspec.append(f"DEFAULT ({default})") - else: - colspec.append("DEFAULT " + default) + colspec.append("DEFAULT " + default) return " ".join(colspec) - def post_create_table(self, table: sa_schema.Table) -> str: + def post_create_table(self, table): """Build table-level CREATE options like ENGINE and COLLATE.""" table_opts = [] @@ -2163,27 +1937,25 @@ class MySQLDDLCompiler(compiler.DDLCompiler): return " ".join(table_opts) - def visit_create_index(self, create: ddl.CreateIndex, **kw: Any) -> str: # type: ignore[override] # noqa: E501 + def visit_create_index(self, create, **kw): index = create.element self._verify_index_table(index) preparer = self.preparer - table = preparer.format_table(index.table) # type: ignore[arg-type] + table = preparer.format_table(index.table) columns = [ self.sql_compiler.process( - ( - elements.Grouping(expr) # type: ignore[arg-type] - if ( - isinstance(expr, elements.BinaryExpression) - or ( - isinstance(expr, elements.UnaryExpression) - and expr.modifier - not in (operators.desc_op, operators.asc_op) - ) - or isinstance(expr, functions.FunctionElement) + elements.Grouping(expr) + if ( + isinstance(expr, elements.BinaryExpression) + or ( + isinstance(expr, elements.UnaryExpression) + and expr.modifier + not in (operators.desc_op, operators.asc_op) ) - else expr - ), + or isinstance(expr, functions.FunctionElement) + ) + else expr, include_table=False, literal_binds=True, ) @@ -2211,27 +1983,25 @@ class MySQLDDLCompiler(compiler.DDLCompiler): # length value can be a (column_name --> integer value) # mapping specifying the prefix length for each column of the # index - columns_str = ", ".join( - ( - "%s(%d)" % (expr, length[col.name]) # type: ignore[union-attr] # noqa: E501 - if col.name in length # type: ignore[union-attr] - else ( - "%s(%d)" % (expr, length[expr]) - if expr in length - else "%s" % expr - ) + columns = ", ".join( + "%s(%d)" % (expr, length[col.name]) + if col.name in length + else ( + "%s(%d)" % (expr, length[expr]) + if expr in length + else "%s" % expr ) for col, expr in zip(index.expressions, columns) ) else: # or can be an integer value specifying the same # prefix length for all columns of the index - columns_str = ", ".join( + columns = 
", ".join( "%s(%d)" % (col, length) for col in columns ) else: - columns_str = ", ".join(columns) - text += "(%s)" % columns_str + columns = ", ".join(columns) + text += "(%s)" % columns parser = index.dialect_options["mysql"]["with_parser"] if parser is not None: @@ -2243,16 +2013,14 @@ class MySQLDDLCompiler(compiler.DDLCompiler): return text - def visit_primary_key_constraint( - self, constraint: sa_schema.PrimaryKeyConstraint, **kw: Any - ) -> str: + def visit_primary_key_constraint(self, constraint, **kw): text = super().visit_primary_key_constraint(constraint) using = constraint.dialect_options["mysql"]["using"] if using: text += " USING %s" % (self.preparer.quote(using)) return text - def visit_drop_index(self, drop: ddl.DropIndex, **kw: Any) -> str: + def visit_drop_index(self, drop, **kw): index = drop.element text = "\nDROP INDEX " if drop.if_exists: @@ -2260,12 +2028,10 @@ class MySQLDDLCompiler(compiler.DDLCompiler): return text + "%s ON %s" % ( self._prepared_index_name(index, include_schema=False), - self.preparer.format_table(index.table), # type: ignore[arg-type] + self.preparer.format_table(index.table), ) - def visit_drop_constraint( - self, drop: ddl.DropConstraint, **kw: Any - ) -> str: + def visit_drop_constraint(self, drop, **kw): constraint = drop.element if isinstance(constraint, sa_schema.ForeignKeyConstraint): qual = "FOREIGN KEY " @@ -2291,9 +2057,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler): const, ) - def define_constraint_match( - self, constraint: sa_schema.ForeignKeyConstraint - ) -> str: + def define_constraint_match(self, constraint): if constraint.match is not None: raise exc.CompileError( "MySQL ignores the 'MATCH' keyword while at the same time " @@ -2301,9 +2065,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler): ) return "" - def visit_set_table_comment( - self, create: ddl.SetTableComment, **kw: Any - ) -> str: + def visit_set_table_comment(self, create, **kw): return "ALTER TABLE %s COMMENT %s" % ( self.preparer.format_table(create.element), self.sql_compiler.render_literal_value( @@ -2311,16 +2073,12 @@ class MySQLDDLCompiler(compiler.DDLCompiler): ), ) - def visit_drop_table_comment( - self, drop: ddl.DropTableComment, **kw: Any - ) -> str: + def visit_drop_table_comment(self, create, **kw): return "ALTER TABLE %s COMMENT ''" % ( - self.preparer.format_table(drop.element) + self.preparer.format_table(create.element) ) - def visit_set_column_comment( - self, create: ddl.SetColumnComment, **kw: Any - ) -> str: + def visit_set_column_comment(self, create, **kw): return "ALTER TABLE %s CHANGE %s %s" % ( self.preparer.format_table(create.element.table), self.preparer.format_column(create.element), @@ -2329,7 +2087,7 @@ class MySQLDDLCompiler(compiler.DDLCompiler): class MySQLTypeCompiler(compiler.GenericTypeCompiler): - def _extend_numeric(self, type_: _NumericType, spec: str) -> str: + def _extend_numeric(self, type_, spec): "Extend a numeric-type declaration with MySQL specific extensions." if not self._mysql_type(type_): @@ -2341,15 +2099,13 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): spec += " ZEROFILL" return spec - def _extend_string( - self, type_: _StringType, defaults: Dict[str, Any], spec: str - ) -> str: + def _extend_string(self, type_, defaults, spec): """Extend a string-type declaration with standard SQL CHARACTER SET / COLLATE annotations and MySQL specific extensions. 
""" - def attr(name: str) -> Any: + def attr(name): return getattr(type_, name, defaults.get(name)) if attr("charset"): @@ -2359,7 +2115,6 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): elif attr("unicode"): charset = "UNICODE" else: - charset = None if attr("collation"): @@ -2378,10 +2133,10 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): [c for c in (spec, charset, collation) if c is not None] ) - def _mysql_type(self, type_: Any) -> bool: + def _mysql_type(self, type_): return isinstance(type_, (_StringType, _NumericType)) - def visit_NUMERIC(self, type_: NUMERIC, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_NUMERIC(self, type_, **kw): if type_.precision is None: return self._extend_numeric(type_, "NUMERIC") elif type_.scale is None: @@ -2396,7 +2151,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): % {"precision": type_.precision, "scale": type_.scale}, ) - def visit_DECIMAL(self, type_: DECIMAL, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_DECIMAL(self, type_, **kw): if type_.precision is None: return self._extend_numeric(type_, "DECIMAL") elif type_.scale is None: @@ -2411,7 +2166,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): % {"precision": type_.precision, "scale": type_.scale}, ) - def visit_DOUBLE(self, type_: DOUBLE, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_DOUBLE(self, type_, **kw): if type_.precision is not None and type_.scale is not None: return self._extend_numeric( type_, @@ -2421,7 +2176,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "DOUBLE") - def visit_REAL(self, type_: REAL, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_REAL(self, type_, **kw): if type_.precision is not None and type_.scale is not None: return self._extend_numeric( type_, @@ -2431,7 +2186,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "REAL") - def visit_FLOAT(self, type_: FLOAT, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_FLOAT(self, type_, **kw): if ( self._mysql_type(type_) and type_.scale is not None @@ -2447,7 +2202,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "FLOAT") - def visit_INTEGER(self, type_: INTEGER, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_INTEGER(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, @@ -2457,7 +2212,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "INTEGER") - def visit_BIGINT(self, type_: BIGINT, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_BIGINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, @@ -2467,7 +2222,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "BIGINT") - def visit_MEDIUMINT(self, type_: MEDIUMINT, **kw: Any) -> str: + def visit_MEDIUMINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, @@ -2477,7 +2232,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "MEDIUMINT") - def visit_TINYINT(self, type_: TINYINT, **kw: Any) -> str: + def visit_TINYINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is 
not None: return self._extend_numeric( type_, "TINYINT(%s)" % type_.display_width @@ -2485,7 +2240,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "TINYINT") - def visit_SMALLINT(self, type_: SMALLINT, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_SMALLINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( type_, @@ -2495,55 +2250,55 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_numeric(type_, "SMALLINT") - def visit_BIT(self, type_: BIT, **kw: Any) -> str: + def visit_BIT(self, type_, **kw): if type_.length is not None: return "BIT(%s)" % type_.length else: return "BIT" - def visit_DATETIME(self, type_: DATETIME, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_DATETIME(self, type_, **kw): if getattr(type_, "fsp", None): - return "DATETIME(%d)" % type_.fsp # type: ignore[str-format] + return "DATETIME(%d)" % type_.fsp else: return "DATETIME" - def visit_DATE(self, type_: DATE, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_DATE(self, type_, **kw): return "DATE" - def visit_TIME(self, type_: TIME, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_TIME(self, type_, **kw): if getattr(type_, "fsp", None): - return "TIME(%d)" % type_.fsp # type: ignore[str-format] + return "TIME(%d)" % type_.fsp else: return "TIME" - def visit_TIMESTAMP(self, type_: TIMESTAMP, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_TIMESTAMP(self, type_, **kw): if getattr(type_, "fsp", None): - return "TIMESTAMP(%d)" % type_.fsp # type: ignore[str-format] + return "TIMESTAMP(%d)" % type_.fsp else: return "TIMESTAMP" - def visit_YEAR(self, type_: YEAR, **kw: Any) -> str: + def visit_YEAR(self, type_, **kw): if type_.display_width is None: return "YEAR" else: return "YEAR(%s)" % type_.display_width - def visit_TEXT(self, type_: TEXT, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_TEXT(self, type_, **kw): if type_.length is not None: return self._extend_string(type_, {}, "TEXT(%d)" % type_.length) else: return self._extend_string(type_, {}, "TEXT") - def visit_TINYTEXT(self, type_: TINYTEXT, **kw: Any) -> str: + def visit_TINYTEXT(self, type_, **kw): return self._extend_string(type_, {}, "TINYTEXT") - def visit_MEDIUMTEXT(self, type_: MEDIUMTEXT, **kw: Any) -> str: + def visit_MEDIUMTEXT(self, type_, **kw): return self._extend_string(type_, {}, "MEDIUMTEXT") - def visit_LONGTEXT(self, type_: LONGTEXT, **kw: Any) -> str: + def visit_LONGTEXT(self, type_, **kw): return self._extend_string(type_, {}, "LONGTEXT") - def visit_VARCHAR(self, type_: VARCHAR, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_VARCHAR(self, type_, **kw): if type_.length is not None: return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length) else: @@ -2551,7 +2306,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): "VARCHAR requires a length on dialect %s" % self.dialect.name ) - def visit_CHAR(self, type_: CHAR, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_CHAR(self, type_, **kw): if type_.length is not None: return self._extend_string( type_, {}, "CHAR(%(length)s)" % {"length": type_.length} @@ -2559,7 +2314,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_string(type_, {}, "CHAR") - def visit_NVARCHAR(self, type_: NVARCHAR, **kw: Any) -> str: # type: ignore[override] # NOQA: 
E501 + def visit_NVARCHAR(self, type_, **kw): # We'll actually generate the equiv. "NATIONAL VARCHAR" instead # of "NVARCHAR". if type_.length is not None: @@ -2573,7 +2328,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): "NVARCHAR requires a length on dialect %s" % self.dialect.name ) - def visit_NCHAR(self, type_: NCHAR, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_NCHAR(self, type_, **kw): # We'll actually generate the equiv. # "NATIONAL CHAR" instead of "NCHAR". if type_.length is not None: @@ -2585,70 +2340,61 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): else: return self._extend_string(type_, {"national": True}, "CHAR") - def visit_UUID(self, type_: UUID[Any], **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_UUID(self, type_, **kw): return "UUID" - def visit_VARBINARY(self, type_: VARBINARY, **kw: Any) -> str: - return "VARBINARY(%d)" % type_.length # type: ignore[str-format] + def visit_VARBINARY(self, type_, **kw): + return "VARBINARY(%d)" % type_.length - def visit_JSON(self, type_: JSON, **kw: Any) -> str: + def visit_JSON(self, type_, **kw): return "JSON" - def visit_large_binary(self, type_: LargeBinary, **kw: Any) -> str: + def visit_large_binary(self, type_, **kw): return self.visit_BLOB(type_) - def visit_enum(self, type_: ENUM, **kw: Any) -> str: # type: ignore[override] # NOQA: E501 + def visit_enum(self, type_, **kw): if not type_.native_enum: return super().visit_enum(type_) else: return self._visit_enumerated_values("ENUM", type_, type_.enums) - def visit_BLOB(self, type_: LargeBinary, **kw: Any) -> str: + def visit_BLOB(self, type_, **kw): if type_.length is not None: return "BLOB(%d)" % type_.length else: return "BLOB" - def visit_TINYBLOB(self, type_: TINYBLOB, **kw: Any) -> str: + def visit_TINYBLOB(self, type_, **kw): return "TINYBLOB" - def visit_MEDIUMBLOB(self, type_: MEDIUMBLOB, **kw: Any) -> str: + def visit_MEDIUMBLOB(self, type_, **kw): return "MEDIUMBLOB" - def visit_LONGBLOB(self, type_: LONGBLOB, **kw: Any) -> str: + def visit_LONGBLOB(self, type_, **kw): return "LONGBLOB" - def _visit_enumerated_values( - self, name: str, type_: _StringType, enumerated_values: Sequence[str] - ) -> str: + def _visit_enumerated_values(self, name, type_, enumerated_values): quoted_enums = [] for e in enumerated_values: - if self.dialect.identifier_preparer._double_percents: - e = e.replace("%", "%%") quoted_enums.append("'%s'" % e.replace("'", "''")) return self._extend_string( type_, {}, "%s(%s)" % (name, ",".join(quoted_enums)) ) - def visit_ENUM(self, type_: ENUM, **kw: Any) -> str: + def visit_ENUM(self, type_, **kw): return self._visit_enumerated_values("ENUM", type_, type_.enums) - def visit_SET(self, type_: SET, **kw: Any) -> str: + def visit_SET(self, type_, **kw): return self._visit_enumerated_values("SET", type_, type_.values) - def visit_BOOLEAN(self, type_: sqltypes.Boolean, **kw: Any) -> str: + def visit_BOOLEAN(self, type_, **kw): return "BOOL" class MySQLIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS_MYSQL - def __init__( - self, - dialect: default.DefaultDialect, - server_ansiquotes: bool = False, - **kw: Any, - ): + def __init__(self, dialect, server_ansiquotes=False, **kw): if not server_ansiquotes: quote = "`" else: @@ -2656,7 +2402,7 @@ class MySQLIdentifierPreparer(compiler.IdentifierPreparer): super().__init__(dialect, initial_quote=quote, escape_quote=quote) - def _quote_free_identifiers(self, *ids: Optional[str]) -> Tuple[str, ...]: + def 
_quote_free_identifiers(self, *ids): """Unilaterally identifier-quote any number of strings.""" return tuple([self.quote_identifier(i) for i in ids if i is not None]) @@ -2666,6 +2412,7 @@ class MariaDBIdentifierPreparer(MySQLIdentifierPreparer): reserved_words = RESERVED_WORDS_MARIADB +@log.class_logger class MySQLDialect(default.DefaultDialect): """Details of the MySQL dialect. Not used directly in application code. @@ -2680,10 +2427,6 @@ class MySQLDialect(default.DefaultDialect): # allow for the "true" and "false" keywords, however supports_native_boolean = False - # support for BIT type; mysqlconnector coerces result values automatically, - # all other MySQL DBAPIs require a conversion routine - supports_native_bit = False - # identifiers are 64, however aliases can be 255... max_identifier_length = 255 max_index_name_length = 64 @@ -2732,9 +2475,9 @@ class MySQLDialect(default.DefaultDialect): ddl_compiler = MySQLDDLCompiler type_compiler_cls = MySQLTypeCompiler ischema_names = ischema_names - preparer: type[MySQLIdentifierPreparer] = MySQLIdentifierPreparer + preparer = MySQLIdentifierPreparer - is_mariadb: bool = False + is_mariadb = False _mariadb_normalized_version_info = None # default SQL compilation settings - @@ -2743,13 +2486,9 @@ class MySQLDialect(default.DefaultDialect): _backslash_escapes = True _server_ansiquotes = False - server_version_info: Tuple[int, ...] - identifier_preparer: MySQLIdentifierPreparer - construct_arguments = [ (sa_schema.Table, {"*": None}), (sql.Update, {"limit": None}), - (sql.Delete, {"limit": None}), (sa_schema.PrimaryKeyConstraint, {"using": None}), ( sa_schema.Index, @@ -2764,20 +2503,18 @@ class MySQLDialect(default.DefaultDialect): def __init__( self, - json_serializer: Optional[Callable[..., Any]] = None, - json_deserializer: Optional[Callable[..., Any]] = None, - is_mariadb: Optional[bool] = None, - **kwargs: Any, - ) -> None: + json_serializer=None, + json_deserializer=None, + is_mariadb=None, + **kwargs, + ): kwargs.pop("use_ansiquotes", None) # legacy default.DefaultDialect.__init__(self, **kwargs) self._json_serializer = json_serializer self._json_deserializer = json_deserializer - self._set_mariadb(is_mariadb, ()) + self._set_mariadb(is_mariadb, None) - def get_isolation_level_values( - self, dbapi_conn: DBAPIConnection - ) -> Sequence[IsolationLevel]: + def get_isolation_level_values(self, dbapi_conn): return ( "SERIALIZABLE", "READ UNCOMMITTED", @@ -2785,17 +2522,13 @@ class MySQLDialect(default.DefaultDialect): "REPEATABLE READ", ) - def set_isolation_level( - self, dbapi_connection: DBAPIConnection, level: IsolationLevel - ) -> None: + def set_isolation_level(self, dbapi_connection, level): cursor = dbapi_connection.cursor() cursor.execute(f"SET SESSION TRANSACTION ISOLATION LEVEL {level}") cursor.execute("COMMIT") cursor.close() - def get_isolation_level( - self, dbapi_connection: DBAPIConnection - ) -> IsolationLevel: + def get_isolation_level(self, dbapi_connection): cursor = dbapi_connection.cursor() if self._is_mysql and self.server_version_info >= (5, 7, 20): cursor.execute("SELECT @@transaction_isolation") @@ -2812,10 +2545,10 @@ class MySQLDialect(default.DefaultDialect): cursor.close() if isinstance(val, bytes): val = val.decode() - return val.upper().replace("-", " ") # type: ignore[no-any-return] + return val.upper().replace("-", " ") @classmethod - def _is_mariadb_from_url(cls, url: URL) -> bool: + def _is_mariadb_from_url(cls, url): dbapi = cls.import_dbapi() dialect = cls(dbapi=dbapi) @@ -2824,7 +2557,7 @@ class 
MySQLDialect(default.DefaultDialect): try: cursor = conn.cursor() cursor.execute("SELECT VERSION() LIKE '%MariaDB%'") - val = cursor.fetchone()[0] # type: ignore[index] + val = cursor.fetchone()[0] except: raise else: @@ -2832,25 +2565,22 @@ class MySQLDialect(default.DefaultDialect): finally: conn.close() - def _get_server_version_info( - self, connection: Connection - ) -> Tuple[int, ...]: + def _get_server_version_info(self, connection): # get database server version info explicitly over the wire # to avoid proxy servers like MaxScale getting in the # way with their own values, see #4205 dbapi_con = connection.connection cursor = dbapi_con.cursor() cursor.execute("SELECT VERSION()") - - val = cursor.fetchone()[0] # type: ignore[index] + val = cursor.fetchone()[0] cursor.close() if isinstance(val, bytes): val = val.decode() return self._parse_server_version(val) - def _parse_server_version(self, val: str) -> Tuple[int, ...]: - version: List[int] = [] + def _parse_server_version(self, val): + version = [] is_mariadb = False r = re.compile(r"[.\-+]") @@ -2871,7 +2601,7 @@ class MySQLDialect(default.DefaultDialect): server_version_info = tuple(version) self._set_mariadb( - bool(server_version_info and is_mariadb), server_version_info + server_version_info and is_mariadb, server_version_info ) if not is_mariadb: @@ -2887,9 +2617,7 @@ class MySQLDialect(default.DefaultDialect): self.server_version_info = server_version_info return server_version_info - def _set_mariadb( - self, is_mariadb: Optional[bool], server_version_info: Tuple[int, ...] - ) -> None: + def _set_mariadb(self, is_mariadb, server_version_info): if is_mariadb is None: return @@ -2899,12 +2627,10 @@ class MySQLDialect(default.DefaultDialect): % (".".join(map(str, server_version_info)),) ) if is_mariadb: - - if not issubclass(self.preparer, MariaDBIdentifierPreparer): - self.preparer = MariaDBIdentifierPreparer - # this would have been set by the default dialect already, - # so set it again - self.identifier_preparer = self.preparer(self) + self.preparer = MariaDBIdentifierPreparer + # this would have been set by the default dialect already, + # so set it again + self.identifier_preparer = self.preparer(self) # this will be updated on first connect in initialize() # if using older mariadb version @@ -2913,54 +2639,38 @@ class MySQLDialect(default.DefaultDialect): self.is_mariadb = is_mariadb - def do_begin_twophase(self, connection: Connection, xid: Any) -> None: + def do_begin_twophase(self, connection, xid): connection.execute(sql.text("XA BEGIN :xid"), dict(xid=xid)) - def do_prepare_twophase(self, connection: Connection, xid: Any) -> None: + def do_prepare_twophase(self, connection, xid): connection.execute(sql.text("XA END :xid"), dict(xid=xid)) connection.execute(sql.text("XA PREPARE :xid"), dict(xid=xid)) def do_rollback_twophase( - self, - connection: Connection, - xid: Any, - is_prepared: bool = True, - recover: bool = False, - ) -> None: + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: connection.execute(sql.text("XA END :xid"), dict(xid=xid)) connection.execute(sql.text("XA ROLLBACK :xid"), dict(xid=xid)) def do_commit_twophase( - self, - connection: Connection, - xid: Any, - is_prepared: bool = True, - recover: bool = False, - ) -> None: + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: self.do_prepare_twophase(connection, xid) connection.execute(sql.text("XA COMMIT :xid"), dict(xid=xid)) - def do_recover_twophase(self, connection: Connection) -> 
List[Any]: + def do_recover_twophase(self, connection): resultset = connection.exec_driver_sql("XA RECOVER") - return [ - row["data"][0 : row["gtrid_length"]] - for row in resultset.mappings() - ] + return [row["data"][0 : row["gtrid_length"]] for row in resultset] - def is_disconnect( - self, - e: DBAPIModule.Error, - connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], - cursor: Optional[DBAPICursor], - ) -> bool: + def is_disconnect(self, e, connection, cursor): if isinstance( e, ( - self.dbapi.OperationalError, # type: ignore - self.dbapi.ProgrammingError, # type: ignore - self.dbapi.InterfaceError, # type: ignore + self.dbapi.OperationalError, + self.dbapi.ProgrammingError, + self.dbapi.InterfaceError, ), ) and self._extract_error_code(e) in ( 1927, @@ -2973,7 +2683,7 @@ class MySQLDialect(default.DefaultDialect): ): return True elif isinstance( - e, (self.dbapi.InterfaceError, self.dbapi.InternalError) # type: ignore # noqa: E501 + e, (self.dbapi.InterfaceError, self.dbapi.InternalError) ): # if underlying connection is closed, # this is the error you get @@ -2981,17 +2691,13 @@ class MySQLDialect(default.DefaultDialect): else: return False - def _compat_fetchall( - self, rp: CursorResult[Any], charset: Optional[str] = None - ) -> Union[Sequence[Row[Any]], Sequence[_DecodingRow]]: + def _compat_fetchall(self, rp, charset=None): """Proxy result rows to smooth over MySQL-Python driver inconsistencies.""" return [_DecodingRow(row, charset) for row in rp.fetchall()] - def _compat_fetchone( - self, rp: CursorResult[Any], charset: Optional[str] = None - ) -> Union[Row[Any], None, _DecodingRow]: + def _compat_fetchone(self, rp, charset=None): """Proxy a result row to smooth over MySQL-Python driver inconsistencies.""" @@ -3001,9 +2707,7 @@ class MySQLDialect(default.DefaultDialect): else: return None - def _compat_first( - self, rp: CursorResult[Any], charset: Optional[str] = None - ) -> Optional[_DecodingRow]: + def _compat_first(self, rp, charset=None): """Proxy a result row to smooth over MySQL-Python driver inconsistencies.""" @@ -3013,22 +2717,14 @@ class MySQLDialect(default.DefaultDialect): else: return None - def _extract_error_code( - self, exception: DBAPIModule.Error - ) -> Optional[int]: + def _extract_error_code(self, exception): raise NotImplementedError() - def _get_default_schema_name(self, connection: Connection) -> str: - return connection.exec_driver_sql("SELECT DATABASE()").scalar() # type: ignore[return-value] # noqa: E501 + def _get_default_schema_name(self, connection): + return connection.exec_driver_sql("SELECT DATABASE()").scalar() @reflection.cache - def has_table( - self, - connection: Connection, - table_name: str, - schema: Optional[str] = None, - **kw: Any, - ) -> bool: + def has_table(self, connection, table_name, schema=None, **kw): self._ensure_has_table_connection(connection) if schema is None: @@ -3069,18 +2765,12 @@ class MySQLDialect(default.DefaultDialect): # # there's more "doesn't exist" kinds of messages but they are # less clear if mysql 8 would suddenly start using one of those - if self._extract_error_code(e.orig) in (1146, 1049, 1051): # type: ignore # noqa: E501 + if self._extract_error_code(e.orig) in (1146, 1049, 1051): return False raise @reflection.cache - def has_sequence( - self, - connection: Connection, - sequence_name: str, - schema: Optional[str] = None, - **kw: Any, - ) -> bool: + def has_sequence(self, connection, sequence_name, schema=None, **kw): if not self.supports_sequences: self._sequences_not_supported() if 
not schema: @@ -3100,16 +2790,14 @@ class MySQLDialect(default.DefaultDialect): ) return cursor.first() is not None - def _sequences_not_supported(self) -> NoReturn: + def _sequences_not_supported(self): raise NotImplementedError( "Sequences are supported only by the " "MariaDB series 10.3 or greater" ) @reflection.cache - def get_sequence_names( - self, connection: Connection, schema: Optional[str] = None, **kw: Any - ) -> List[str]: + def get_sequence_names(self, connection, schema=None, **kw): if not self.supports_sequences: self._sequences_not_supported() if not schema: @@ -3129,12 +2817,10 @@ class MySQLDialect(default.DefaultDialect): ) ] - def initialize(self, connection: Connection) -> None: + def initialize(self, connection): # this is driver-based, does not need server version info # and is fairly critical for even basic SQL operations - self._connection_charset: Optional[str] = self._detect_charset( - connection - ) + self._connection_charset = self._detect_charset(connection) # call super().initialize() because we need to have # server_version_info set up. in 1.4 under python 2 only this does the @@ -3178,10 +2864,9 @@ class MySQLDialect(default.DefaultDialect): self._warn_for_known_db_issues() - def _warn_for_known_db_issues(self) -> None: + def _warn_for_known_db_issues(self): if self.is_mariadb: mdb_version = self._mariadb_normalized_version_info - assert mdb_version is not None if mdb_version > (10, 2) and mdb_version < (10, 2, 9): util.warn( "MariaDB %r before 10.2.9 has known issues regarding " @@ -3194,7 +2879,7 @@ class MySQLDialect(default.DefaultDialect): ) @property - def _support_float_cast(self) -> bool: + def _support_float_cast(self): if not self.server_version_info: return False elif self.is_mariadb: @@ -3205,49 +2890,32 @@ class MySQLDialect(default.DefaultDialect): return self.server_version_info >= (8, 0, 17) @property - def _support_default_function(self) -> bool: - if not self.server_version_info: - return False - elif self.is_mariadb: - # ref https://mariadb.com/kb/en/mariadb-1021-release-notes/ - return self.server_version_info >= (10, 2, 1) - else: - # ref https://dev.mysql.com/doc/refman/8.0/en/data-type-defaults.html # noqa - return self.server_version_info >= (8, 0, 13) - - @property - def _is_mariadb(self) -> bool: + def _is_mariadb(self): return self.is_mariadb @property - def _is_mysql(self) -> bool: + def _is_mysql(self): return not self.is_mariadb @property - def _is_mariadb_102(self) -> bool: - return ( - self.is_mariadb - and self._mariadb_normalized_version_info # type:ignore[operator] - > ( - 10, - 2, - ) + def _is_mariadb_102(self): + return self.is_mariadb and self._mariadb_normalized_version_info > ( + 10, + 2, ) @reflection.cache - def get_schema_names(self, connection: Connection, **kw: Any) -> List[str]: + def get_schema_names(self, connection, **kw): rp = connection.exec_driver_sql("SHOW schemas") return [r[0] for r in rp] @reflection.cache - def get_table_names( - self, connection: Connection, schema: Optional[str] = None, **kw: Any - ) -> List[str]: + def get_table_names(self, connection, schema=None, **kw): """Return a Unicode SHOW TABLES from a given schema.""" if schema is not None: - current_schema: str = schema + current_schema = schema else: - current_schema = self.default_schema_name # type: ignore + current_schema = self.default_schema_name charset = self._connection_charset @@ -3263,12 +2931,9 @@ class MySQLDialect(default.DefaultDialect): ] @reflection.cache - def get_view_names( - self, connection: Connection, schema: 
Optional[str] = None, **kw: Any - ) -> List[str]: + def get_view_names(self, connection, schema=None, **kw): if schema is None: schema = self.default_schema_name - assert schema is not None charset = self._connection_charset rp = connection.exec_driver_sql( "SHOW FULL TABLES FROM %s" @@ -3281,13 +2946,7 @@ class MySQLDialect(default.DefaultDialect): ] @reflection.cache - def get_table_options( - self, - connection: Connection, - table_name: str, - schema: Optional[str] = None, - **kw: Any, - ) -> Dict[str, Any]: + def get_table_options(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) @@ -3297,13 +2956,7 @@ class MySQLDialect(default.DefaultDialect): return ReflectionDefaults.table_options() @reflection.cache - def get_columns( - self, - connection: Connection, - table_name: str, - schema: Optional[str] = None, - **kw: Any, - ) -> List[ReflectedColumn]: + def get_columns(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) @@ -3313,13 +2966,7 @@ class MySQLDialect(default.DefaultDialect): return ReflectionDefaults.columns() @reflection.cache - def get_pk_constraint( - self, - connection: Connection, - table_name: str, - schema: Optional[str] = None, - **kw: Any, - ) -> ReflectedPrimaryKeyConstraint: + def get_pk_constraint(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) @@ -3331,19 +2978,13 @@ class MySQLDialect(default.DefaultDialect): return ReflectionDefaults.pk_constraint() @reflection.cache - def get_foreign_keys( - self, - connection: Connection, - table_name: str, - schema: Optional[str] = None, - **kw: Any, - ) -> List[ReflectedForeignKeyConstraint]: + def get_foreign_keys(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) default_schema = None - fkeys: List[ReflectedForeignKeyConstraint] = [] + fkeys = [] for spec in parsed_state.fk_constraints: ref_name = spec["table"][-1] @@ -3363,7 +3004,7 @@ class MySQLDialect(default.DefaultDialect): if spec.get(opt, False) not in ("NO ACTION", None): con_kw[opt] = spec[opt] - fkey_d: ReflectedForeignKeyConstraint = { + fkey_d = { "name": spec["name"], "constrained_columns": loc_names, "referred_schema": ref_schema, @@ -3378,11 +3019,7 @@ class MySQLDialect(default.DefaultDialect): return fkeys if fkeys else ReflectionDefaults.foreign_keys() - def _correct_for_mysql_bugs_88718_96365( - self, - fkeys: List[ReflectedForeignKeyConstraint], - connection: Connection, - ) -> None: + def _correct_for_mysql_bugs_88718_96365(self, fkeys, connection): # Foreign key is always in lower case (MySQL 8.0) # https://bugs.mysql.com/bug.php?id=88718 # issue #4344 for SQLAlchemy @@ -3398,60 +3035,38 @@ class MySQLDialect(default.DefaultDialect): if self._casing in (1, 2): - def lower(s: str) -> str: + def lower(s): return s.lower() else: # if on case sensitive, there can be two tables referenced # with the same name different casing, so we need to use # case-sensitive matching. - def lower(s: str) -> str: + def lower(s): return s - default_schema_name: str = connection.dialect.default_schema_name # type: ignore # noqa: E501 - - # NOTE: using (table_schema, table_name, lower(column_name)) in (...) - # is very slow since mysql does not seem able to properly use indexse. - # Unpack the where condition instead. 
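For illustration only (not part of the patch above): a minimal, standalone sketch of the grouping strategy described in the NOTE in `_correct_for_mysql_bugs_88718_96365`, where referred columns are bucketed per schema and table instead of building one large `(table_schema, table_name, lower(column_name)) IN (...)` predicate. The `fkeys` sample data and `default_schema_name` value are made up, and lower-casing is applied unconditionally here for brevity (the real code only lowers when the server is case-insensitive).

    from collections import defaultdict

    fkeys = [
        {"referred_schema": None, "referred_table": "Parent", "referred_columns": ["ID", "Name"]},
        {"referred_schema": "other", "referred_table": "Target", "referred_columns": ["id"]},
    ]
    default_schema_name = "test"

    # bucket referred column names per (schema, table)
    schema_by_table_by_column = defaultdict(lambda: defaultdict(list))
    for rec in fkeys:
        sch = (rec["referred_schema"] or default_schema_name).lower()
        tbl = rec["referred_table"].lower()
        for col_name in rec["referred_columns"]:
            schema_by_table_by_column[sch][tbl].append(col_name)

    print({s: dict(t) for s, t in schema_by_table_by_column.items()})
    # {'test': {'parent': ['ID', 'Name']}, 'other': {'target': ['id']}}

Each schema/table bucket then becomes one `AND(... table_schema == s, table_name == t, lower(column_name) IN cols ...)` term joined by `OR`, as in the removed lines above.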
- schema_by_table_by_column: DefaultDict[ - str, DefaultDict[str, List[str]] - ] = DefaultDict(lambda: DefaultDict(list)) - for rec in fkeys: - sch = lower(rec["referred_schema"] or default_schema_name) - tbl = lower(rec["referred_table"]) - for col_name in rec["referred_columns"]: - schema_by_table_by_column[sch][tbl].append(col_name) - - if schema_by_table_by_column: - - condition = sql.or_( - *( - sql.and_( - _info_columns.c.table_schema == schema, - sql.or_( - *( - sql.and_( - _info_columns.c.table_name == table, - sql.func.lower( - _info_columns.c.column_name - ).in_(columns), - ) - for table, columns in tables.items() - ) - ), - ) - for schema, tables in schema_by_table_by_column.items() - ) + default_schema_name = connection.dialect.default_schema_name + col_tuples = [ + ( + lower(rec["referred_schema"] or default_schema_name), + lower(rec["referred_table"]), + col_name, ) + for rec in fkeys + for col_name in rec["referred_columns"] + ] - select = sql.select( - _info_columns.c.table_schema, - _info_columns.c.table_name, - _info_columns.c.column_name, - ).where(condition) - - correct_for_wrong_fk_case: CursorResult[Tuple[str, str, str]] = ( - connection.execute(select) + if col_tuples: + correct_for_wrong_fk_case = connection.execute( + sql.text( + """ + select table_schema, table_name, column_name + from information_schema.columns + where (table_schema, table_name, lower(column_name)) in + :table_data; + """ + ).bindparams(sql.bindparam("table_data", expanding=True)), + dict(table_data=col_tuples), ) # in casing=0, table name and schema name come back in their @@ -3464,41 +3079,35 @@ class MySQLDialect(default.DefaultDialect): # SHOW CREATE TABLE converts them to *lower case*, therefore # not matching. So for this case, case-insensitive lookup # is necessary - d: DefaultDict[Tuple[str, str], Dict[str, str]] = defaultdict(dict) + d = defaultdict(dict) for schema, tname, cname in correct_for_wrong_fk_case: d[(lower(schema), lower(tname))]["SCHEMANAME"] = schema d[(lower(schema), lower(tname))]["TABLENAME"] = tname d[(lower(schema), lower(tname))][cname.lower()] = cname for fkey in fkeys: - rec_b = d[ + rec = d[ ( lower(fkey["referred_schema"] or default_schema_name), lower(fkey["referred_table"]), ) ] - fkey["referred_table"] = rec_b["TABLENAME"] + fkey["referred_table"] = rec["TABLENAME"] if fkey["referred_schema"] is not None: - fkey["referred_schema"] = rec_b["SCHEMANAME"] + fkey["referred_schema"] = rec["SCHEMANAME"] fkey["referred_columns"] = [ - rec_b[col.lower()] for col in fkey["referred_columns"] + rec[col.lower()] for col in fkey["referred_columns"] ] @reflection.cache - def get_check_constraints( - self, - connection: Connection, - table_name: str, - schema: Optional[str] = None, - **kw: Any, - ) -> List[ReflectedCheckConstraint]: + def get_check_constraints(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) - cks: List[ReflectedCheckConstraint] = [ + cks = [ {"name": spec["name"], "sqltext": spec["sqltext"]} for spec in parsed_state.ck_constraints ] @@ -3506,13 +3115,7 @@ class MySQLDialect(default.DefaultDialect): return cks if cks else ReflectionDefaults.check_constraints() @reflection.cache - def get_table_comment( - self, - connection: Connection, - table_name: str, - schema: Optional[str] = None, - **kw: Any, - ) -> ReflectedTableComment: + def get_table_comment(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( connection, table_name, schema, 
**kw ) @@ -3523,18 +3126,12 @@ class MySQLDialect(default.DefaultDialect): return ReflectionDefaults.table_comment() @reflection.cache - def get_indexes( - self, - connection: Connection, - table_name: str, - schema: Optional[str] = None, - **kw: Any, - ) -> List[ReflectedIndex]: + def get_indexes(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) - indexes: List[ReflectedIndex] = [] + indexes = [] for spec in parsed_state.keys: dialect_options = {} @@ -3546,30 +3143,32 @@ class MySQLDialect(default.DefaultDialect): unique = True elif flavor in ("FULLTEXT", "SPATIAL"): dialect_options["%s_prefix" % self.name] = flavor - elif flavor is not None: - util.warn( + elif flavor is None: + pass + else: + self.logger.info( "Converting unknown KEY type %s to a plain KEY", flavor ) + pass if spec["parser"]: dialect_options["%s_with_parser" % (self.name)] = spec[ "parser" ] - index_d: ReflectedIndex = { - "name": spec["name"], - "column_names": [s[0] for s in spec["columns"]], - "unique": unique, - } + index_d = {} + index_d["name"] = spec["name"] + index_d["column_names"] = [s[0] for s in spec["columns"]] mysql_length = { s[0]: s[1] for s in spec["columns"] if s[1] is not None } if mysql_length: dialect_options["%s_length" % self.name] = mysql_length + index_d["unique"] = unique if flavor: - index_d["type"] = flavor # type: ignore[typeddict-unknown-key] + index_d["type"] = flavor if dialect_options: index_d["dialect_options"] = dialect_options @@ -3580,17 +3179,13 @@ class MySQLDialect(default.DefaultDialect): @reflection.cache def get_unique_constraints( - self, - connection: Connection, - table_name: str, - schema: Optional[str] = None, - **kw: Any, - ) -> List[ReflectedUniqueConstraint]: + self, connection, table_name, schema=None, **kw + ): parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) - ucs: List[ReflectedUniqueConstraint] = [ + ucs = [ { "name": key["name"], "column_names": [col[0] for col in key["columns"]], @@ -3606,13 +3201,7 @@ class MySQLDialect(default.DefaultDialect): return ReflectionDefaults.unique_constraints() @reflection.cache - def get_view_definition( - self, - connection: Connection, - view_name: str, - schema: Optional[str] = None, - **kw: Any, - ) -> str: + def get_view_definition(self, connection, view_name, schema=None, **kw): charset = self._connection_charset full_name = ".".join( self.identifier_preparer._quote_free_identifiers(schema, view_name) @@ -3626,12 +3215,8 @@ class MySQLDialect(default.DefaultDialect): return sql def _parsed_state_or_create( - self, - connection: Connection, - table_name: str, - schema: Optional[str] = None, - **kw: Any, - ) -> _reflection.ReflectedState: + self, connection, table_name, schema=None, **kw + ): return self._setup_parser( connection, table_name, @@ -3640,7 +3225,7 @@ class MySQLDialect(default.DefaultDialect): ) @util.memoized_property - def _tabledef_parser(self) -> _reflection.MySQLTableDefinitionParser: + def _tabledef_parser(self): """return the MySQLTableDefinitionParser, generate if needed. 
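For context (not part of the patch): the reflected index dictionaries that `get_indexes()` above assembles are normally consumed through the public Inspector API. The connection URL and table name below are hypothetical, and a reachable MySQL database is assumed.

    from sqlalchemy import create_engine, inspect

    engine = create_engine("mysql+pymysql://user:pass@localhost/test")  # hypothetical DSN
    insp = inspect(engine)

    for idx in insp.get_indexes("some_table"):  # hypothetical table name
        # every entry carries "name", "column_names" and "unique"; MySQL-specific
        # details (FULLTEXT/SPATIAL flavor, prefix lengths, parser) appear under
        # "type" and "dialect_options" when present
        print(idx["name"], idx["column_names"], idx["unique"], idx.get("dialect_options", {}))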
The deferred creation ensures that the dialect has @@ -3651,13 +3236,7 @@ class MySQLDialect(default.DefaultDialect): return _reflection.MySQLTableDefinitionParser(self, preparer) @reflection.cache - def _setup_parser( - self, - connection: Connection, - table_name: str, - schema: Optional[str] = None, - **kw: Any, - ) -> _reflection.ReflectedState: + def _setup_parser(self, connection, table_name, schema=None, **kw): charset = self._connection_charset parser = self._tabledef_parser full_name = ".".join( @@ -3673,14 +3252,10 @@ class MySQLDialect(default.DefaultDialect): columns = self._describe_table( connection, None, charset, full_name=full_name ) - sql = parser._describe_to_create( - table_name, columns # type: ignore[arg-type] - ) + sql = parser._describe_to_create(table_name, columns) return parser.parse(sql, charset) - def _fetch_setting( - self, connection: Connection, setting_name: str - ) -> Optional[str]: + def _fetch_setting(self, connection, setting_name): charset = self._connection_charset if self.server_version_info and self.server_version_info < (5, 6): @@ -3695,12 +3270,12 @@ class MySQLDialect(default.DefaultDialect): if not row: return None else: - return cast(Optional[str], row[fetch_col]) + return row[fetch_col] - def _detect_charset(self, connection: Connection) -> str: + def _detect_charset(self, connection): raise NotImplementedError() - def _detect_casing(self, connection: Connection) -> int: + def _detect_casing(self, connection): """Sniff out identifier case sensitivity. Cached per-connection. This value can not change without a server @@ -3724,7 +3299,7 @@ class MySQLDialect(default.DefaultDialect): self._casing = cs return cs - def _detect_collations(self, connection: Connection) -> Dict[str, str]: + def _detect_collations(self, connection): """Pull the active COLLATIONS list from the server. Cached per-connection. @@ -3737,7 +3312,7 @@ class MySQLDialect(default.DefaultDialect): collations[row[0]] = row[1] return collations - def _detect_sql_mode(self, connection: Connection) -> None: + def _detect_sql_mode(self, connection): setting = self._fetch_setting(connection, "sql_mode") if setting is None: @@ -3749,7 +3324,7 @@ class MySQLDialect(default.DefaultDialect): else: self._sql_mode = setting or "" - def _detect_ansiquotes(self, connection: Connection) -> None: + def _detect_ansiquotes(self, connection): """Detect and adjust for the ANSI_QUOTES sql mode.""" mode = self._sql_mode @@ -3764,81 +3339,34 @@ class MySQLDialect(default.DefaultDialect): # as of MySQL 5.0.1 self._backslash_escapes = "NO_BACKSLASH_ESCAPES" not in mode - @overload def _show_create_table( - self, - connection: Connection, - table: Optional[Table], - charset: Optional[str], - full_name: str, - ) -> str: ... - - @overload - def _show_create_table( - self, - connection: Connection, - table: Table, - charset: Optional[str] = None, - full_name: None = None, - ) -> str: ... 
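For illustration (not part of the patch): a rough sketch of the SHOW CREATE TABLE round trip that `_show_create_table()` below performs, and how a missing table surfaces as `NoSuchTableError`. The engine URL and table name are placeholders; the real dialect additionally checks the driver-specific error code (1146) before converting the exception.

    from sqlalchemy import create_engine, exc

    engine = create_engine("mysql+pymysql://user:pass@localhost/test")  # hypothetical DSN

    with engine.connect() as conn:
        try:
            row = conn.exec_driver_sql("SHOW CREATE TABLE some_table").first()
        except exc.DBAPIError:
            # the real dialect inspects the driver error code here; MySQL error
            # 1146 ("table doesn't exist") becomes exc.NoSuchTableError
            raise
        if row is None:
            raise exc.NoSuchTableError("some_table")
        create_stmt = row[1].strip()  # second column of SHOW CREATE TABLE holds the DDL
        print(create_stmt)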
- - def _show_create_table( - self, - connection: Connection, - table: Optional[Table], - charset: Optional[str] = None, - full_name: Optional[str] = None, - ) -> str: + self, connection, table, charset=None, full_name=None + ): """Run SHOW CREATE TABLE for a ``Table``.""" if full_name is None: - assert table is not None full_name = self.identifier_preparer.format_table(table) st = "SHOW CREATE TABLE %s" % full_name + rp = None try: rp = connection.execution_options( skip_user_error_events=True ).exec_driver_sql(st) except exc.DBAPIError as e: - if self._extract_error_code(e.orig) == 1146: # type: ignore[arg-type] # noqa: E501 + if self._extract_error_code(e.orig) == 1146: raise exc.NoSuchTableError(full_name) from e else: raise row = self._compat_first(rp, charset=charset) if not row: raise exc.NoSuchTableError(full_name) - return cast(str, row[1]).strip() + return row[1].strip() - @overload - def _describe_table( - self, - connection: Connection, - table: Optional[Table], - charset: Optional[str], - full_name: str, - ) -> Union[Sequence[Row[Any]], Sequence[_DecodingRow]]: ... - - @overload - def _describe_table( - self, - connection: Connection, - table: Table, - charset: Optional[str] = None, - full_name: None = None, - ) -> Union[Sequence[Row[Any]], Sequence[_DecodingRow]]: ... - - def _describe_table( - self, - connection: Connection, - table: Optional[Table], - charset: Optional[str] = None, - full_name: Optional[str] = None, - ) -> Union[Sequence[Row[Any]], Sequence[_DecodingRow]]: + def _describe_table(self, connection, table, charset=None, full_name=None): """Run DESCRIBE for a ``Table`` and return processed rows.""" if full_name is None: - assert table is not None full_name = self.identifier_preparer.format_table(table) st = "DESCRIBE %s" % full_name @@ -3849,7 +3377,7 @@ class MySQLDialect(default.DefaultDialect): skip_user_error_events=True ).exec_driver_sql(st) except exc.DBAPIError as e: - code = self._extract_error_code(e.orig) # type: ignore[arg-type] # noqa: E501 + code = self._extract_error_code(e.orig) if code == 1146: raise exc.NoSuchTableError(full_name) from e @@ -3881,7 +3409,7 @@ class _DecodingRow: # sets.Set(['value']) (seriously) but thankfully that doesn't # seem to come up in DDL queries. 
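For illustration (not part of the patch): a minimal, self-contained sketch of the decoding-proxy idea that `_DecodingRow` below implements. Bytes values returned by the driver are decoded with the connection charset, with MySQL charset names mapped to Python codec names where they differ; the mapping and sample row here are abridged and illustrative.

    _ENCODING_COMPAT = {"koi8r": "koi8_r", "koi8u": "koi8_u", "utf8mb4": "utf-8"}

    class DecodingRow:
        def __init__(self, row, charset):
            self.row = row
            # translate the MySQL charset name to a Python codec name if needed
            self.charset = _ENCODING_COMPAT.get(charset, charset) if charset else None

        def __getitem__(self, index):
            item = self.row[index]
            if self.charset and isinstance(item, bytes):
                return item.decode(self.charset)
            return item

    row = DecodingRow((b"utf8mb4_general_ci", 42), "utf8mb4")
    print(row[0], row[1])  # -> utf8mb4_general_ci 42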
- _encoding_compat: Dict[str, str] = { + _encoding_compat = { "koi8r": "koi8_r", "koi8u": "koi8_u", "utf16": "utf-16-be", # MySQL's uft16 is always bigendian @@ -3891,33 +3419,25 @@ class _DecodingRow: "eucjpms": "ujis", } - def __init__(self, rowproxy: Row[Any], charset: Optional[str]): + def __init__(self, rowproxy, charset): self.rowproxy = rowproxy - self.charset = ( - self._encoding_compat.get(charset, charset) - if charset is not None - else None - ) + self.charset = self._encoding_compat.get(charset, charset) - def __getitem__(self, index: int) -> Any: + def __getitem__(self, index): item = self.rowproxy[index] + if isinstance(item, _array): + item = item.tostring() + if self.charset and isinstance(item, bytes): return item.decode(self.charset) else: return item - def __getattr__(self, attr: str) -> Any: + def __getattr__(self, attr): item = getattr(self.rowproxy, attr) + if isinstance(item, _array): + item = item.tostring() if self.charset and isinstance(item, bytes): return item.decode(self.charset) else: return item - - -_info_columns = sql.table( - "columns", - sql.column("table_schema", VARCHAR(64)), - sql.column("table_name", VARCHAR(64)), - sql.column("column_name", VARCHAR(64)), - schema="information_schema", -) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/cymysql.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/cymysql.py index 1d48c4e..ed3c606 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/cymysql.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/cymysql.py @@ -1,9 +1,10 @@ -# dialects/mysql/cymysql.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mysql/cymysql.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors r""" @@ -20,36 +21,18 @@ r""" dialects are mysqlclient and PyMySQL. """ # noqa -from __future__ import annotations - -from typing import Any -from typing import Iterable -from typing import Optional -from typing import TYPE_CHECKING -from typing import Union +from .base import BIT from .base import MySQLDialect from .mysqldb import MySQLDialect_mysqldb -from .types import BIT from ... 
import util -if TYPE_CHECKING: - from ...engine.base import Connection - from ...engine.interfaces import DBAPIConnection - from ...engine.interfaces import DBAPICursor - from ...engine.interfaces import DBAPIModule - from ...engine.interfaces import Dialect - from ...engine.interfaces import PoolProxiedConnection - from ...sql.type_api import _ResultProcessorType - class _cymysqlBIT(BIT): - def result_processor( - self, dialect: Dialect, coltype: object - ) -> Optional[_ResultProcessorType[Any]]: + def result_processor(self, dialect, coltype): """Convert MySQL's 64 bit, variable length binary string to a long.""" - def process(value: Optional[Iterable[int]]) -> Optional[int]: + def process(value): if value is not None: v = 0 for i in iter(value): @@ -72,22 +55,17 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb): colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT}) @classmethod - def import_dbapi(cls) -> DBAPIModule: + def import_dbapi(cls): return __import__("cymysql") - def _detect_charset(self, connection: Connection) -> str: - return connection.connection.charset # type: ignore[no-any-return] + def _detect_charset(self, connection): + return connection.connection.charset - def _extract_error_code(self, exception: DBAPIModule.Error) -> int: - return exception.errno # type: ignore[no-any-return] + def _extract_error_code(self, exception): + return exception.errno - def is_disconnect( - self, - e: DBAPIModule.Error, - connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], - cursor: Optional[DBAPICursor], - ) -> bool: - if isinstance(e, self.loaded_dbapi.OperationalError): + def is_disconnect(self, e, connection, cursor): + if isinstance(e, self.dbapi.OperationalError): return self._extract_error_code(e) in ( 2006, 2013, @@ -95,7 +73,7 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb): 2045, 2055, ) - elif isinstance(e, self.loaded_dbapi.InterfaceError): + elif isinstance(e, self.dbapi.InterfaceError): # if underlying connection is closed, # this is the error you get return True diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/dml.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/dml.py index cceb081..dfa39f6 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/dml.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/dml.py @@ -1,5 +1,5 @@ -# dialects/mysql/dml.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mysql/dml.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -7,7 +7,6 @@ from __future__ import annotations from typing import Any -from typing import Dict from typing import List from typing import Mapping from typing import Optional @@ -142,11 +141,7 @@ class Insert(StandardInsert): in :ref:`tutorial_parameter_ordered_updates`:: insert().on_duplicate_key_update( - [ - ("name", "some name"), - ("value", "some value"), - ] - ) + [("name", "some name"), ("value", "some value")]) .. 
versionchanged:: 1.3 parameters can be specified as a dictionary or list of 2-tuples; the latter form provides for parameter @@ -186,7 +181,6 @@ class OnDuplicateClause(ClauseElement): _parameter_ordering: Optional[List[str]] = None - update: Dict[str, Any] stringify_dialect = "mysql" def __init__( diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/enumerated.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/enumerated.py index ab30520..2e1d3c3 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/enumerated.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/enumerated.py @@ -1,51 +1,34 @@ -# dialects/mysql/enumerated.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mysql/enumerated.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors -from __future__ import annotations -import enum import re -from typing import Any -from typing import Dict -from typing import Optional -from typing import Set -from typing import Type -from typing import TYPE_CHECKING -from typing import Union from .types import _StringType from ... import exc from ... import sql from ... import util from ...sql import sqltypes -from ...sql import type_api - -if TYPE_CHECKING: - from ...engine.interfaces import Dialect - from ...sql.elements import ColumnElement - from ...sql.type_api import _BindProcessorType - from ...sql.type_api import _ResultProcessorType - from ...sql.type_api import TypeEngine - from ...sql.type_api import TypeEngineMixin -class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType): +class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType): """MySQL ENUM type.""" __visit_name__ = "ENUM" native_enum = True - def __init__(self, *enums: Union[str, Type[enum.Enum]], **kw: Any) -> None: + def __init__(self, *enums, **kw): """Construct an ENUM. E.g.:: - Column("myenum", ENUM("foo", "bar", "baz")) + Column('myenum', ENUM("foo", "bar", "baz")) :param enums: The range of valid values for this ENUM. Values in enums are not quoted, they will be escaped and surrounded by single @@ -79,27 +62,21 @@ class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType): """ kw.pop("strict", None) - self._enum_init(enums, kw) # type: ignore[arg-type] + self._enum_init(enums, kw) _StringType.__init__(self, length=self.length, **kw) @classmethod - def adapt_emulated_to_native( - cls, - impl: Union[TypeEngine[Any], TypeEngineMixin], - **kw: Any, - ) -> ENUM: + def adapt_emulated_to_native(cls, impl, **kw): """Produce a MySQL native :class:`.mysql.ENUM` from plain :class:`.Enum`. 
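For context (not part of the patch): the adaptation above is what lets a plain `sqlalchemy.Enum` render as a native MySQL ENUM in DDL. A small, runnable example (table and column names are arbitrary; no database connection is needed):

    from sqlalchemy import Column, Enum, Integer, MetaData, Table
    from sqlalchemy.dialects import mysql
    from sqlalchemy.schema import CreateTable

    metadata = MetaData()
    shirts = Table(
        "shirts",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("size", Enum("small", "medium", "large")),
    )

    # renders "size ENUM('small','medium','large')" in the CREATE TABLE output
    print(CreateTable(shirts).compile(dialect=mysql.dialect()))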
""" - if TYPE_CHECKING: - assert isinstance(impl, ENUM) kw.setdefault("validate_strings", impl.validate_strings) kw.setdefault("values_callable", impl.values_callable) kw.setdefault("omit_aliases", impl._omit_aliases) return cls(**kw) - def _object_value_for_elem(self, elem: str) -> Union[str, enum.Enum]: + def _object_value_for_elem(self, elem): # mysql sends back a blank string for any value that # was persisted that was not in the enums; that is, it does no # validation on the incoming data, it "truncates" it to be @@ -109,27 +86,24 @@ class ENUM(type_api.NativeForEmulated, sqltypes.Enum, _StringType): else: return super()._object_value_for_elem(elem) - def __repr__(self) -> str: + def __repr__(self): return util.generic_repr( self, to_inspect=[ENUM, _StringType, sqltypes.Enum] ) -# TODO: SET is a string as far as configuration but does not act like -# a string at the python level. We either need to make a py-type agnostic -# version of String as a base to be used for this, make this some kind of -# TypeDecorator, or just vendor it out as its own type. class SET(_StringType): """MySQL SET type.""" __visit_name__ = "SET" - def __init__(self, *values: str, **kw: Any): + def __init__(self, *values, **kw): """Construct a SET. E.g.:: - Column("myset", SET("foo", "bar", "baz")) + Column('myset', SET("foo", "bar", "baz")) + The list of potential values is required in the case that this set will be used to generate DDL for a table, or if the @@ -177,19 +151,17 @@ class SET(_StringType): "setting retrieve_as_bitwise=True" ) if self.retrieve_as_bitwise: - self._inversed_bitmap: Dict[str, int] = { + self._bitmap = { value: 2**idx for idx, value in enumerate(self.values) } - self._bitmap: Dict[int, str] = { - 2**idx: value for idx, value in enumerate(self.values) - } + self._bitmap.update( + (2**idx, value) for idx, value in enumerate(self.values) + ) length = max([len(v) for v in values] + [0]) kw.setdefault("length", length) super().__init__(**kw) - def column_expression( - self, colexpr: ColumnElement[Any] - ) -> ColumnElement[Any]: + def column_expression(self, colexpr): if self.retrieve_as_bitwise: return sql.type_coerce( sql.type_coerce(colexpr, sqltypes.Integer) + 0, self @@ -197,12 +169,10 @@ class SET(_StringType): else: return colexpr - def result_processor( - self, dialect: Dialect, coltype: Any - ) -> Optional[_ResultProcessorType[Any]]: + def result_processor(self, dialect, coltype): if self.retrieve_as_bitwise: - def process(value: Union[str, int, None]) -> Optional[Set[str]]: + def process(value): if value is not None: value = int(value) @@ -213,14 +183,11 @@ class SET(_StringType): else: super_convert = super().result_processor(dialect, coltype) - def process(value: Union[str, Set[str], None]) -> Optional[Set[str]]: # type: ignore[misc] # noqa: E501 + def process(value): if isinstance(value, str): # MySQLdb returns a string, let's parse if super_convert: value = super_convert(value) - assert value is not None - if TYPE_CHECKING: - assert isinstance(value, str) return set(re.findall(r"[^,]+", value)) else: # mysql-connector-python does a naive @@ -231,48 +198,43 @@ class SET(_StringType): return process - def bind_processor( - self, dialect: Dialect - ) -> _BindProcessorType[Union[str, int]]: + def bind_processor(self, dialect): super_convert = super().bind_processor(dialect) if self.retrieve_as_bitwise: - def process( - value: Union[str, int, set[str], None], - ) -> Union[str, int, None]: + def process(value): if value is None: return None elif isinstance(value, (int, str)): if 
super_convert: - return super_convert(value) # type: ignore[arg-type, no-any-return] # noqa: E501 + return super_convert(value) else: return value else: int_value = 0 for v in value: - int_value |= self._inversed_bitmap[v] + int_value |= self._bitmap[v] return int_value else: - def process( - value: Union[str, int, set[str], None], - ) -> Union[str, int, None]: + def process(value): # accept strings and int (actually bitflag) values directly if value is not None and not isinstance(value, (int, str)): value = ",".join(value) + if super_convert: - return super_convert(value) # type: ignore + return super_convert(value) else: return value return process - def adapt(self, cls: type, **kw: Any) -> Any: + def adapt(self, impltype, **kw): kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise - return util.constructor_copy(self, cls, *self.values, **kw) + return util.constructor_copy(self, impltype, *self.values, **kw) - def __repr__(self) -> str: + def __repr__(self): return util.generic_repr( self, to_inspect=[SET, _StringType], diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/expression.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/expression.py index 9d19d52..c5bd0be 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/expression.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/expression.py @@ -1,13 +1,10 @@ -# dialects/mysql/expression.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors -from __future__ import annotations - -from typing import Any from ... import exc from ... import util @@ -20,7 +17,7 @@ from ...sql.base import Generative from ...util.typing import Self -class match(Generative, elements.BinaryExpression[Any]): +class match(Generative, elements.BinaryExpression): """Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause. E.g.:: @@ -40,9 +37,7 @@ class match(Generative, elements.BinaryExpression[Any]): .order_by(desc(match_expr)) ) - Would produce SQL resembling: - - .. sourcecode:: sql + Would produce SQL resembling:: SELECT id, firstname, lastname FROM user @@ -75,9 +70,8 @@ class match(Generative, elements.BinaryExpression[Any]): __visit_name__ = "mysql_match" inherit_cache = True - modifiers: util.immutabledict[str, Any] - def __init__(self, *cols: elements.ColumnElement[Any], **kw: Any): + def __init__(self, *cols, **kw): if not cols: raise exc.ArgumentError("columns are required") diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/json.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/json.py index e654a61..66fcb71 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/json.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/json.py @@ -1,21 +1,13 @@ -# dialects/mysql/json.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mysql/json.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -from __future__ import annotations - -from typing import Any -from typing import TYPE_CHECKING +# mypy: ignore-errors from ... 
import types as sqltypes -if TYPE_CHECKING: - from ...engine.interfaces import Dialect - from ...sql.type_api import _BindProcessorType - from ...sql.type_api import _LiteralProcessorType - class JSON(sqltypes.JSON): """MySQL JSON type. @@ -42,13 +34,13 @@ class JSON(sqltypes.JSON): class _FormatTypeMixin: - def _format_value(self, value: Any) -> str: + def _format_value(self, value): raise NotImplementedError() - def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]: - super_proc = self.string_bind_processor(dialect) # type: ignore[attr-defined] # noqa: E501 + def bind_processor(self, dialect): + super_proc = self.string_bind_processor(dialect) - def process(value: Any) -> Any: + def process(value): value = self._format_value(value) if super_proc: value = super_proc(value) @@ -56,31 +48,29 @@ class _FormatTypeMixin: return process - def literal_processor( - self, dialect: Dialect - ) -> _LiteralProcessorType[Any]: - super_proc = self.string_literal_processor(dialect) # type: ignore[attr-defined] # noqa: E501 + def literal_processor(self, dialect): + super_proc = self.string_literal_processor(dialect) - def process(value: Any) -> str: + def process(value): value = self._format_value(value) if super_proc: value = super_proc(value) - return value # type: ignore[no-any-return] + return value return process class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): - def _format_value(self, value: Any) -> str: + def _format_value(self, value): if isinstance(value, int): - formatted_value = "$[%s]" % value + value = "$[%s]" % value else: - formatted_value = '$."%s"' % value - return formatted_value + value = '$."%s"' % value + return value class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): - def _format_value(self, value: Any) -> str: + def _format_value(self, value): return "$%s" % ( "".join( [ diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/mariadb.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/mariadb.py index 508820e..a6ee5df 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/mariadb.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/mariadb.py @@ -1,73 +1,32 @@ -# dialects/mysql/mariadb.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mysql/mariadb.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php - -from __future__ import annotations - -from typing import Any -from typing import Callable - +# mypy: ignore-errors from .base import MariaDBIdentifierPreparer from .base import MySQLDialect -from .base import MySQLIdentifierPreparer -from .base import MySQLTypeCompiler -from ...sql import sqltypes - - -class INET4(sqltypes.TypeEngine[str]): - """INET4 column type for MariaDB - - .. versionadded:: 2.0.37 - """ - - __visit_name__ = "INET4" - - -class INET6(sqltypes.TypeEngine[str]): - """INET6 column type for MariaDB - - .. 
versionadded:: 2.0.37 - """ - - __visit_name__ = "INET6" - - -class MariaDBTypeCompiler(MySQLTypeCompiler): - def visit_INET4(self, type_: INET4, **kwargs: Any) -> str: - return "INET4" - - def visit_INET6(self, type_: INET6, **kwargs: Any) -> str: - return "INET6" class MariaDBDialect(MySQLDialect): is_mariadb = True supports_statement_cache = True name = "mariadb" - preparer: type[MySQLIdentifierPreparer] = MariaDBIdentifierPreparer - type_compiler_cls = MariaDBTypeCompiler + preparer = MariaDBIdentifierPreparer -def loader(driver: str) -> Callable[[], type[MariaDBDialect]]: - dialect_mod = __import__( +def loader(driver): + driver_mod = __import__( "sqlalchemy.dialects.mysql.%s" % driver ).dialects.mysql + driver_cls = getattr(driver_mod, driver).dialect - driver_mod = getattr(dialect_mod, driver) - if hasattr(driver_mod, "mariadb_dialect"): - driver_cls = driver_mod.mariadb_dialect - return driver_cls # type: ignore[no-any-return] - else: - driver_cls = driver_mod.dialect - - return type( - "MariaDBDialect_%s" % driver, - ( - MariaDBDialect, - driver_cls, - ), - {"supports_statement_cache": True}, - ) + return type( + "MariaDBDialect_%s" % driver, + ( + MariaDBDialect, + driver_cls, + ), + {"supports_statement_cache": True}, + ) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/mariadbconnector.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/mariadbconnector.py index b2d3d63..9730c9b 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/mariadbconnector.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/mariadbconnector.py @@ -1,9 +1,11 @@ -# dialects/mysql/mariadbconnector.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mysql/mariadbconnector.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + """ @@ -27,15 +29,7 @@ be ``mysqldb``. ``mariadb+mariadbconnector://`` is required to use this driver. .. mariadb: https://github.com/mariadb-corporation/mariadb-connector-python """ # noqa -from __future__ import annotations - import re -from typing import Any -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import TYPE_CHECKING -from typing import Union from uuid import UUID as _python_UUID from .base import MySQLCompiler @@ -45,19 +39,6 @@ from ... import sql from ... import util from ...sql import sqltypes -if TYPE_CHECKING: - from ...engine.base import Connection - from ...engine.interfaces import ConnectArgsType - from ...engine.interfaces import DBAPIConnection - from ...engine.interfaces import DBAPICursor - from ...engine.interfaces import DBAPIModule - from ...engine.interfaces import Dialect - from ...engine.interfaces import IsolationLevel - from ...engine.interfaces import PoolProxiedConnection - from ...engine.url import URL - from ...sql.compiler import SQLCompiler - from ...sql.type_api import _ResultProcessorType - mariadb_cpy_minimum_version = (1, 0, 1) @@ -66,12 +47,10 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]): # work around JIRA issue # https://jira.mariadb.org/browse/CONPY-270. When that issue is fixed, # this type can be removed. 
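For illustration (not part of the patch): a standalone sketch of what the `_MariaDBUUID` result processor below does for the CONPY-270 workaround. The driver may return the UUID value as bytes, so it is decoded as ASCII and, when `as_uuid` is in effect, wrapped in a `uuid.UUID`. The sample value is made up.

    from uuid import UUID

    def process(value, as_uuid=True):
        if value is None:
            return None
        if hasattr(value, "decode"):
            value = value.decode("ascii")  # bytes -> str
        return UUID(value) if as_uuid else value

    print(process(b"12345678-1234-5678-1234-567812345678"))
    # -> UUID('12345678-1234-5678-1234-567812345678')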
- def result_processor( - self, dialect: Dialect, coltype: object - ) -> Optional[_ResultProcessorType[Any]]: + def result_processor(self, dialect, coltype): if self.as_uuid: - def process(value: Any) -> Any: + def process(value): if value is not None: if hasattr(value, "decode"): value = value.decode("ascii") @@ -81,7 +60,7 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]): return process else: - def process(value: Any) -> Any: + def process(value): if value is not None: if hasattr(value, "decode"): value = value.decode("ascii") @@ -92,27 +71,30 @@ class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]): class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext): - _lastrowid: Optional[int] = None + _lastrowid = None - def create_server_side_cursor(self) -> DBAPICursor: + def create_server_side_cursor(self): return self._dbapi_connection.cursor(buffered=False) - def create_default_cursor(self) -> DBAPICursor: + def create_default_cursor(self): return self._dbapi_connection.cursor(buffered=True) - def post_exec(self) -> None: + def post_exec(self): super().post_exec() self._rowcount = self.cursor.rowcount - if TYPE_CHECKING: - assert isinstance(self.compiled, SQLCompiler) if self.isinsert and self.compiled.postfetch_lastrowid: self._lastrowid = self.cursor.lastrowid - def get_lastrowid(self) -> int: - if TYPE_CHECKING: - assert self._lastrowid is not None + @property + def rowcount(self): + if self._rowcount is not None: + return self._rowcount + else: + return self.cursor.rowcount + + def get_lastrowid(self): return self._lastrowid @@ -151,7 +133,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect): ) @util.memoized_property - def _dbapi_version(self) -> Tuple[int, ...]: + def _dbapi_version(self): if self.dbapi and hasattr(self.dbapi, "__version__"): return tuple( [ @@ -164,7 +146,7 @@ class MySQLDialect_mariadbconnector(MySQLDialect): else: return (99, 99, 99) - def __init__(self, **kwargs: Any) -> None: + def __init__(self, **kwargs): super().__init__(**kwargs) self.paramstyle = "qmark" if self.dbapi is not None: @@ -176,26 +158,20 @@ class MySQLDialect_mariadbconnector(MySQLDialect): ) @classmethod - def import_dbapi(cls) -> DBAPIModule: + def import_dbapi(cls): return __import__("mariadb") - def is_disconnect( - self, - e: DBAPIModule.Error, - connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], - cursor: Optional[DBAPICursor], - ) -> bool: + def is_disconnect(self, e, connection, cursor): if super().is_disconnect(e, connection, cursor): return True - elif isinstance(e, self.loaded_dbapi.Error): + elif isinstance(e, self.dbapi.Error): str_e = str(e).lower() return "not connected" in str_e or "isn't valid" in str_e else: return False - def create_connect_args(self, url: URL) -> ConnectArgsType: + def create_connect_args(self, url): opts = url.translate_connect_args() - opts.update(url.query) int_params = [ "connect_timeout", @@ -210,7 +186,6 @@ class MySQLDialect_mariadbconnector(MySQLDialect): "ssl_verify_cert", "ssl", "pool_reset_connection", - "compress", ] for key in int_params: @@ -230,21 +205,19 @@ class MySQLDialect_mariadbconnector(MySQLDialect): except (AttributeError, ImportError): self.supports_sane_rowcount = False opts["client_flag"] = client_flag - return [], opts + return [[], opts] - def _extract_error_code(self, exception: DBAPIModule.Error) -> int: + def _extract_error_code(self, exception): try: - rc: int = exception.errno + rc = exception.errno except: rc = -1 return rc - def _detect_charset(self, connection: Connection) -> 
str: + def _detect_charset(self, connection): return "utf8mb4" - def get_isolation_level_values( - self, dbapi_conn: DBAPIConnection - ) -> Sequence[IsolationLevel]: + def get_isolation_level_values(self, dbapi_connection): return ( "SERIALIZABLE", "READ UNCOMMITTED", @@ -253,26 +226,21 @@ class MySQLDialect_mariadbconnector(MySQLDialect): "AUTOCOMMIT", ) - def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool: - return bool(dbapi_conn.autocommit) - - def set_isolation_level( - self, dbapi_connection: DBAPIConnection, level: IsolationLevel - ) -> None: + def set_isolation_level(self, connection, level): if level == "AUTOCOMMIT": - dbapi_connection.autocommit = True + connection.autocommit = True else: - dbapi_connection.autocommit = False - super().set_isolation_level(dbapi_connection, level) + connection.autocommit = False + super().set_isolation_level(connection, level) - def do_begin_twophase(self, connection: Connection, xid: Any) -> None: + def do_begin_twophase(self, connection, xid): connection.execute( sql.text("XA BEGIN :xid").bindparams( sql.bindparam("xid", xid, literal_execute=True) ) ) - def do_prepare_twophase(self, connection: Connection, xid: Any) -> None: + def do_prepare_twophase(self, connection, xid): connection.execute( sql.text("XA END :xid").bindparams( sql.bindparam("xid", xid, literal_execute=True) @@ -285,12 +253,8 @@ class MySQLDialect_mariadbconnector(MySQLDialect): ) def do_rollback_twophase( - self, - connection: Connection, - xid: Any, - is_prepared: bool = True, - recover: bool = False, - ) -> None: + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: connection.execute( sql.text("XA END :xid").bindparams( @@ -304,12 +268,8 @@ class MySQLDialect_mariadbconnector(MySQLDialect): ) def do_commit_twophase( - self, - connection: Connection, - xid: Any, - is_prepared: bool = True, - recover: bool = False, - ) -> None: + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: self.do_prepare_twophase(connection, xid) connection.execute( diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/mysqlconnector.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/mysqlconnector.py index feaf520..fc90c65 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/mysqlconnector.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -1,9 +1,10 @@ -# dialects/mysql/mysqlconnector.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mysql/mysqlconnector.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors r""" @@ -13,85 +14,26 @@ r""" :connectstring: mysql+mysqlconnector://:@[:]/ :url: https://pypi.org/project/mysql-connector-python/ -Driver Status -------------- - -MySQL Connector/Python is supported as of SQLAlchemy 2.0.39 to the -degree which the driver is functional. There are still ongoing issues -with features such as server side cursors which remain disabled until -upstream issues are repaired. - -.. warning:: The MySQL Connector/Python driver published by Oracle is subject - to frequent, major regressions of essential functionality such as being able - to correctly persist simple binary strings which indicate it is not well - tested. 
The SQLAlchemy project is not able to maintain this dialect fully as - regressions in the driver prevent it from being included in continuous - integration. - -.. versionchanged:: 2.0.39 - - The MySQL Connector/Python dialect has been updated to support the - latest version of this DBAPI. Previously, MySQL Connector/Python - was not fully supported. However, support remains limited due to ongoing - regressions introduced in this driver. - -Connecting to MariaDB with MySQL Connector/Python --------------------------------------------------- - -MySQL Connector/Python may attempt to pass an incompatible collation to the -database when connecting to MariaDB. Experimentation has shown that using -``?charset=utf8mb4&collation=utfmb4_general_ci`` or similar MariaDB-compatible -charset/collation will allow connectivity. +.. note:: + The MySQL Connector/Python DBAPI has had many issues since its release, + some of which may remain unresolved, and the mysqlconnector dialect is + **not tested as part of SQLAlchemy's continuous integration**. + The recommended MySQL dialects are mysqlclient and PyMySQL. """ # noqa -from __future__ import annotations import re -from typing import Any -from typing import cast -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import TYPE_CHECKING -from typing import Union -from .base import MariaDBIdentifierPreparer +from .base import BIT from .base import MySQLCompiler from .base import MySQLDialect -from .base import MySQLExecutionContext from .base import MySQLIdentifierPreparer -from .mariadb import MariaDBDialect -from .types import BIT from ... import util -if TYPE_CHECKING: - - from ...engine.base import Connection - from ...engine.cursor import CursorResult - from ...engine.interfaces import ConnectArgsType - from ...engine.interfaces import DBAPIConnection - from ...engine.interfaces import DBAPICursor - from ...engine.interfaces import DBAPIModule - from ...engine.interfaces import IsolationLevel - from ...engine.interfaces import PoolProxiedConnection - from ...engine.row import Row - from ...engine.url import URL - from ...sql.elements import BinaryExpression - - -class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext): - def create_server_side_cursor(self) -> DBAPICursor: - return self._dbapi_connection.cursor(buffered=False) - - def create_default_cursor(self) -> DBAPICursor: - return self._dbapi_connection.cursor(buffered=True) - class MySQLCompiler_mysqlconnector(MySQLCompiler): - def visit_mod_binary( - self, binary: BinaryExpression[Any], operator: Any, **kw: Any - ) -> str: + def visit_mod_binary(self, binary, operator, **kw): return ( self.process(binary.left, **kw) + " % " @@ -99,37 +41,22 @@ class MySQLCompiler_mysqlconnector(MySQLCompiler): ) -class IdentifierPreparerCommon_mysqlconnector: +class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer): @property - def _double_percents(self) -> bool: + def _double_percents(self): return False @_double_percents.setter - def _double_percents(self, value: Any) -> None: + def _double_percents(self, value): pass - def _escape_identifier(self, value: str) -> str: - value = value.replace( - self.escape_quote, # type:ignore[attr-defined] - self.escape_to_quote, # type:ignore[attr-defined] - ) + def _escape_identifier(self, value): + value = value.replace(self.escape_quote, self.escape_to_quote) return value -class MySQLIdentifierPreparer_mysqlconnector( - IdentifierPreparerCommon_mysqlconnector, MySQLIdentifierPreparer -): - pass - - -class 
MariaDBIdentifierPreparer_mysqlconnector( - IdentifierPreparerCommon_mysqlconnector, MariaDBIdentifierPreparer -): - pass - - class _myconnpyBIT(BIT): - def result_processor(self, dialect: Any, coltype: Any) -> None: + def result_processor(self, dialect, coltype): """MySQL-connector already converts mysql bits, so.""" return None @@ -144,31 +71,24 @@ class MySQLDialect_mysqlconnector(MySQLDialect): supports_native_decimal = True - supports_native_bit = True - - # not until https://bugs.mysql.com/bug.php?id=117548 - supports_server_side_cursors = False - default_paramstyle = "format" statement_compiler = MySQLCompiler_mysqlconnector - execution_ctx_cls = MySQLExecutionContext_mysqlconnector - - preparer: type[MySQLIdentifierPreparer] = ( - MySQLIdentifierPreparer_mysqlconnector - ) + preparer = MySQLIdentifierPreparer_mysqlconnector colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT}) @classmethod - def import_dbapi(cls) -> DBAPIModule: - return cast("DBAPIModule", __import__("mysql.connector").connector) + def import_dbapi(cls): + from mysql import connector - def do_ping(self, dbapi_connection: DBAPIConnection) -> bool: + return connector + + def do_ping(self, dbapi_connection): dbapi_connection.ping(False) return True - def create_connect_args(self, url: URL) -> ConnectArgsType: + def create_connect_args(self, url): opts = url.translate_connect_args(username="user") opts.update(url.query) @@ -176,7 +96,6 @@ class MySQLDialect_mysqlconnector(MySQLDialect): util.coerce_kw_type(opts, "allow_local_infile", bool) util.coerce_kw_type(opts, "autocommit", bool) util.coerce_kw_type(opts, "buffered", bool) - util.coerce_kw_type(opts, "client_flag", int) util.coerce_kw_type(opts, "compress", bool) util.coerce_kw_type(opts, "connection_timeout", int) util.coerce_kw_type(opts, "connect_timeout", int) @@ -191,21 +110,15 @@ class MySQLDialect_mysqlconnector(MySQLDialect): util.coerce_kw_type(opts, "use_pure", bool) util.coerce_kw_type(opts, "use_unicode", bool) - # note that "buffered" is set to False by default in MySQL/connector - # python. If you set it to True, then there is no way to get a server - # side cursor because the logic is written to disallow that. - - # leaving this at True until - # https://bugs.mysql.com/bug.php?id=117548 can be fixed - opts["buffered"] = True + # unfortunately, MySQL/connector python refuses to release a + # cursor without reading fully, so non-buffered isn't an option + opts.setdefault("buffered", True) # FOUND_ROWS must be set in ClientFlag to enable # supports_sane_rowcount. 
if self.dbapi is not None: try: - from mysql.connector import constants # type: ignore - - ClientFlag = constants.ClientFlag + from mysql.connector.constants import ClientFlag client_flags = opts.get( "client_flags", ClientFlag.get_default() @@ -214,35 +127,24 @@ class MySQLDialect_mysqlconnector(MySQLDialect): opts["client_flags"] = client_flags except Exception: pass - - return [], opts + return [[], opts] @util.memoized_property - def _mysqlconnector_version_info(self) -> Optional[Tuple[int, ...]]: + def _mysqlconnector_version_info(self): if self.dbapi and hasattr(self.dbapi, "__version__"): m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__) if m: return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) - return None - def _detect_charset(self, connection: Connection) -> str: - return connection.connection.charset # type: ignore + def _detect_charset(self, connection): + return connection.connection.charset - def _extract_error_code(self, exception: BaseException) -> int: - return exception.errno # type: ignore + def _extract_error_code(self, exception): + return exception.errno - def is_disconnect( - self, - e: Exception, - connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], - cursor: Optional[DBAPICursor], - ) -> bool: + def is_disconnect(self, e, connection, cursor): errnos = (2006, 2013, 2014, 2045, 2055, 2048) - exceptions = ( - self.loaded_dbapi.OperationalError, # - self.loaded_dbapi.InterfaceError, - self.loaded_dbapi.ProgrammingError, - ) + exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError) if isinstance(e, exceptions): return ( e.errno in errnos @@ -252,51 +154,26 @@ class MySQLDialect_mysqlconnector(MySQLDialect): else: return False - def _compat_fetchall( - self, - rp: CursorResult[Tuple[Any, ...]], - charset: Optional[str] = None, - ) -> Sequence[Row[Tuple[Any, ...]]]: + def _compat_fetchall(self, rp, charset=None): return rp.fetchall() - def _compat_fetchone( - self, - rp: CursorResult[Tuple[Any, ...]], - charset: Optional[str] = None, - ) -> Optional[Row[Tuple[Any, ...]]]: + def _compat_fetchone(self, rp, charset=None): return rp.fetchone() - def get_isolation_level_values( - self, dbapi_conn: DBAPIConnection - ) -> Sequence[IsolationLevel]: - return ( - "SERIALIZABLE", - "READ UNCOMMITTED", - "READ COMMITTED", - "REPEATABLE READ", - "AUTOCOMMIT", - ) + _isolation_lookup = { + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + } - def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool: - return bool(dbapi_conn.autocommit) - - def set_isolation_level( - self, dbapi_connection: DBAPIConnection, level: IsolationLevel - ) -> None: + def _set_isolation_level(self, connection, level): if level == "AUTOCOMMIT": - dbapi_connection.autocommit = True + connection.autocommit = True else: - dbapi_connection.autocommit = False - super().set_isolation_level(dbapi_connection, level) - - -class MariaDBDialect_mysqlconnector( - MariaDBDialect, MySQLDialect_mysqlconnector -): - supports_statement_cache = True - _allows_uuid_binds = False - preparer = MariaDBIdentifierPreparer_mysqlconnector + connection.autocommit = False + super()._set_isolation_level(connection, level) dialect = MySQLDialect_mysqlconnector -mariadb_dialect = MariaDBDialect_mysqlconnector diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/mysqldb.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/mysqldb.py index a5b0ca2..d1cf835 100644 --- 
a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/mysqldb.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/mysqldb.py @@ -1,9 +1,11 @@ -# dialects/mysql/mysqldb.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mysql/mysqldb.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + """ @@ -46,9 +48,9 @@ key "ssl", which may be specified using the "ssl": { "ca": "/home/gord/client-ssl/ca.pem", "cert": "/home/gord/client-ssl/client-cert.pem", - "key": "/home/gord/client-ssl/client-key.pem", + "key": "/home/gord/client-ssl/client-key.pem" } - }, + } ) For convenience, the following keys may also be specified inline within the URL @@ -72,9 +74,7 @@ Using MySQLdb with Google Cloud SQL ----------------------------------- Google Cloud SQL now recommends use of the MySQLdb dialect. Connect -using a URL like the following: - -.. sourcecode:: text +using a URL like the following:: mysql+mysqldb://root@/?unix_socket=/cloudsql/: @@ -84,39 +84,25 @@ Server Side Cursors The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`. """ -from __future__ import annotations import re -from typing import Any -from typing import Callable -from typing import cast -from typing import Dict -from typing import Optional -from typing import Tuple -from typing import TYPE_CHECKING from .base import MySQLCompiler from .base import MySQLDialect from .base import MySQLExecutionContext from .base import MySQLIdentifierPreparer +from .base import TEXT +from ... import sql from ... import util -from ...util.typing import Literal - -if TYPE_CHECKING: - - from ...engine.base import Connection - from ...engine.interfaces import _DBAPIMultiExecuteParams - from ...engine.interfaces import ConnectArgsType - from ...engine.interfaces import DBAPIConnection - from ...engine.interfaces import DBAPICursor - from ...engine.interfaces import DBAPIModule - from ...engine.interfaces import ExecutionContext - from ...engine.interfaces import IsolationLevel - from ...engine.url import URL class MySQLExecutionContext_mysqldb(MySQLExecutionContext): - pass + @property + def rowcount(self): + if hasattr(self, "_rowcount"): + return self._rowcount + else: + return self.cursor.rowcount class MySQLCompiler_mysqldb(MySQLCompiler): @@ -136,9 +122,8 @@ class MySQLDialect_mysqldb(MySQLDialect): execution_ctx_cls = MySQLExecutionContext_mysqldb statement_compiler = MySQLCompiler_mysqldb preparer = MySQLIdentifierPreparer - server_version_info: Tuple[int, ...] 
- def __init__(self, **kwargs: Any): + def __init__(self, **kwargs): super().__init__(**kwargs) self._mysql_dbapi_version = ( self._parse_dbapi_version(self.dbapi.__version__) @@ -146,7 +131,7 @@ class MySQLDialect_mysqldb(MySQLDialect): else (0, 0, 0) ) - def _parse_dbapi_version(self, version: str) -> Tuple[int, ...]: + def _parse_dbapi_version(self, version): m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version) if m: return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) @@ -154,7 +139,7 @@ class MySQLDialect_mysqldb(MySQLDialect): return (0, 0, 0) @util.langhelpers.memoized_property - def supports_server_side_cursors(self) -> bool: + def supports_server_side_cursors(self): try: cursors = __import__("MySQLdb.cursors").cursors self._sscursor = cursors.SSCursor @@ -163,13 +148,13 @@ class MySQLDialect_mysqldb(MySQLDialect): return False @classmethod - def import_dbapi(cls) -> DBAPIModule: + def import_dbapi(cls): return __import__("MySQLdb") - def on_connect(self) -> Callable[[DBAPIConnection], None]: + def on_connect(self): super_ = super().on_connect() - def on_connect(conn: DBAPIConnection) -> None: + def on_connect(conn): if super_ is not None: super_(conn) @@ -182,24 +167,43 @@ class MySQLDialect_mysqldb(MySQLDialect): return on_connect - def do_ping(self, dbapi_connection: DBAPIConnection) -> Literal[True]: + def do_ping(self, dbapi_connection): dbapi_connection.ping() return True - def do_executemany( - self, - cursor: DBAPICursor, - statement: str, - parameters: _DBAPIMultiExecuteParams, - context: Optional[ExecutionContext] = None, - ) -> None: + def do_executemany(self, cursor, statement, parameters, context=None): rowcount = cursor.executemany(statement, parameters) if context is not None: - cast(MySQLExecutionContext, context)._rowcount = rowcount + context._rowcount = rowcount - def create_connect_args( - self, url: URL, _translate_args: Optional[Dict[str, Any]] = None - ) -> ConnectArgsType: + def _check_unicode_returns(self, connection): + # work around issue fixed in + # https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8 + # specific issue w/ the utf8mb4_bin collation and unicode returns + + collation = connection.exec_driver_sql( + "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'" + % ( + self.identifier_preparer.quote("Charset"), + self.identifier_preparer.quote("Collation"), + ) + ).scalar() + has_utf8mb4_bin = self.server_version_info > (5,) and collation + if has_utf8mb4_bin: + additional_tests = [ + sql.collate( + sql.cast( + sql.literal_column("'test collated returns'"), + TEXT(charset="utf8mb4"), + ), + "utf8mb4_bin", + ) + ] + else: + additional_tests = [] + return super()._check_unicode_returns(connection, additional_tests) + + def create_connect_args(self, url, _translate_args=None): if _translate_args is None: _translate_args = dict( database="db", username="user", password="passwd" @@ -213,7 +217,7 @@ class MySQLDialect_mysqldb(MySQLDialect): util.coerce_kw_type(opts, "read_timeout", int) util.coerce_kw_type(opts, "write_timeout", int) util.coerce_kw_type(opts, "client_flag", int) - util.coerce_kw_type(opts, "local_infile", bool) + util.coerce_kw_type(opts, "local_infile", int) # Note: using either of the below will cause all strings to be # returned as Unicode, both in raw SQL operations and with column # types like String and MSString. 
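# --- Illustrative aside, not part of the patch ---------------------------
# A minimal sketch of how the "ssl" key documented for this dialect is
# passed through create_engine(); the DSN and certificate paths are
# hypothetical.
from sqlalchemy import create_engine

engine = create_engine(
    "mysql+mysqldb://scott:tiger@db.example.com/test",  # hypothetical DSN
    connect_args={
        "ssl": {
            "ca": "/path/to/ca.pem",             # hypothetical paths
            "cert": "/path/to/client-cert.pem",
            "key": "/path/to/client-key.pem",
        }
    },
)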
@@ -248,9 +252,9 @@ class MySQLDialect_mysqldb(MySQLDialect): if client_flag_found_rows is not None: client_flag |= client_flag_found_rows opts["client_flag"] = client_flag - return [], opts + return [[], opts] - def _found_rows_client_flag(self) -> Optional[int]: + def _found_rows_client_flag(self): if self.dbapi is not None: try: CLIENT_FLAGS = __import__( @@ -259,23 +263,20 @@ class MySQLDialect_mysqldb(MySQLDialect): except (AttributeError, ImportError): return None else: - return CLIENT_FLAGS.FOUND_ROWS # type: ignore + return CLIENT_FLAGS.FOUND_ROWS else: return None - def _extract_error_code(self, exception: DBAPIModule.Error) -> int: - return exception.args[0] # type: ignore[no-any-return] + def _extract_error_code(self, exception): + return exception.args[0] - def _detect_charset(self, connection: Connection) -> str: + def _detect_charset(self, connection): """Sniff out the character set in use for connection results.""" try: # note: the SQL here would be # "SHOW VARIABLES LIKE 'character_set%%'" - - cset_name: Callable[[], str] = ( - connection.connection.character_set_name - ) + cset_name = connection.connection.character_set_name except AttributeError: util.warn( "No 'character_set_name' can be detected with " @@ -287,9 +288,7 @@ class MySQLDialect_mysqldb(MySQLDialect): else: return cset_name() - def get_isolation_level_values( - self, dbapi_conn: DBAPIConnection - ) -> Tuple[IsolationLevel, ...]: + def get_isolation_level_values(self, dbapi_connection): return ( "SERIALIZABLE", "READ UNCOMMITTED", @@ -298,12 +297,7 @@ class MySQLDialect_mysqldb(MySQLDialect): "AUTOCOMMIT", ) - def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool: - return dbapi_conn.get_autocommit() # type: ignore[no-any-return] - - def set_isolation_level( - self, dbapi_connection: DBAPIConnection, level: IsolationLevel - ) -> None: + def set_isolation_level(self, dbapi_connection, level): if level == "AUTOCOMMIT": dbapi_connection.autocommit(True) else: diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/provision.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/provision.py index fe97672..b7faf77 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/provision.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/provision.py @@ -1,10 +1,5 @@ -# dialects/mysql/provision.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors + from ... 
import exc from ...testing.provision import configure_follower from ...testing.provision import create_db @@ -39,13 +34,6 @@ def generate_driver_url(url, driver, query_str): drivername="%s+%s" % (backend, driver) ).update_query_string(query_str) - if driver == "mariadbconnector": - new_url = new_url.difference_update_query(["charset"]) - elif driver == "mysqlconnector": - new_url = new_url.update_query_pairs( - [("collation", "utf8mb4_general_ci")] - ) - try: new_url.get_dialect() except exc.NoSuchModuleError: diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/pymysql.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/pymysql.py index 48b7994..6567202 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/pymysql.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/pymysql.py @@ -1,9 +1,11 @@ -# dialects/mysql/pymysql.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mysql/pymysql.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + r""" @@ -39,6 +41,7 @@ necessary to indicate ``ssl_check_hostname=false`` in PyMySQL:: "&ssl_check_hostname=false" ) + MySQL-Python Compatibility -------------------------- @@ -47,26 +50,9 @@ and targets 100% compatibility. Most behavioral notes for MySQL-python apply to the pymysql driver as well. """ # noqa -from __future__ import annotations - -from typing import Any -from typing import Dict -from typing import Optional -from typing import TYPE_CHECKING -from typing import Union from .mysqldb import MySQLDialect_mysqldb from ...util import langhelpers -from ...util.typing import Literal - -if TYPE_CHECKING: - - from ...engine.interfaces import ConnectArgsType - from ...engine.interfaces import DBAPIConnection - from ...engine.interfaces import DBAPICursor - from ...engine.interfaces import DBAPIModule - from ...engine.interfaces import PoolProxiedConnection - from ...engine.url import URL class MySQLDialect_pymysql(MySQLDialect_mysqldb): @@ -76,7 +62,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): description_encoding = None @langhelpers.memoized_property - def supports_server_side_cursors(self) -> bool: + def supports_server_side_cursors(self): try: cursors = __import__("pymysql.cursors").cursors self._sscursor = cursors.SSCursor @@ -85,11 +71,11 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): return False @classmethod - def import_dbapi(cls) -> DBAPIModule: + def import_dbapi(cls): return __import__("pymysql") @langhelpers.memoized_property - def _send_false_to_ping(self) -> bool: + def _send_false_to_ping(self): """determine if pymysql has deprecated, changed the default of, or removed the 'reconnect' argument of connection.ping(). 
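# --- Illustrative aside, not part of the patch ---------------------------
# A minimal sketch of what the ping handling above is used for: with
# pool_pre_ping=True, SQLAlchemy calls the dialect's do_ping() before
# reusing a pooled connection and recycles dead ones. The DSN is
# hypothetical and assumes PyMySQL is installed.
from sqlalchemy import create_engine, text

engine = create_engine(
    "mysql+pymysql://scott:tiger@localhost/test",  # hypothetical DSN
    pool_pre_ping=True,  # exercises dialect.do_ping() on connection checkout
)

with engine.connect() as conn:
    conn.execute(text("SELECT 1"))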
@@ -100,9 +86,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): """ # noqa: E501 try: - Connection = __import__( - "pymysql.connections" - ).connections.Connection + Connection = __import__("pymysql.connections").Connection except (ImportError, AttributeError): return True else: @@ -116,7 +100,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): not insp.defaults or insp.defaults[0] is not False ) - def do_ping(self, dbapi_connection: DBAPIConnection) -> Literal[True]: + def do_ping(self, dbapi_connection): if self._send_false_to_ping: dbapi_connection.ping(False) else: @@ -124,24 +108,17 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): return True - def create_connect_args( - self, url: URL, _translate_args: Optional[Dict[str, Any]] = None - ) -> ConnectArgsType: + def create_connect_args(self, url, _translate_args=None): if _translate_args is None: _translate_args = dict(username="user") return super().create_connect_args( url, _translate_args=_translate_args ) - def is_disconnect( - self, - e: DBAPIModule.Error, - connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], - cursor: Optional[DBAPICursor], - ) -> bool: + def is_disconnect(self, e, connection, cursor): if super().is_disconnect(e, connection, cursor): return True - elif isinstance(e, self.loaded_dbapi.Error): + elif isinstance(e, self.dbapi.Error): str_e = str(e).lower() return ( "already closed" in str_e or "connection was killed" in str_e @@ -149,7 +126,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): else: return False - def _extract_error_code(self, exception: BaseException) -> Any: + def _extract_error_code(self, exception): if isinstance(exception.args[0], Exception): exception = exception.args[0] return exception.args[0] diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/pyodbc.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/pyodbc.py index 86f1b3c..e4b1177 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/pyodbc.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/pyodbc.py @@ -1,13 +1,15 @@ -# dialects/mysql/pyodbc.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mysql/pyodbc.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors r""" + .. dialect:: mysql+pyodbc :name: PyODBC :dbapi: pyodbc @@ -28,30 +30,21 @@ r""" Pass through exact pyodbc connection string:: import urllib - connection_string = ( - "DRIVER=MySQL ODBC 8.0 ANSI Driver;" - "SERVER=localhost;" - "PORT=3307;" - "DATABASE=mydb;" - "UID=root;" - "PWD=(whatever);" - "charset=utf8mb4;" + 'DRIVER=MySQL ODBC 8.0 ANSI Driver;' + 'SERVER=localhost;' + 'PORT=3307;' + 'DATABASE=mydb;' + 'UID=root;' + 'PWD=(whatever);' + 'charset=utf8mb4;' ) params = urllib.parse.quote_plus(connection_string) connection_uri = "mysql+pyodbc:///?odbc_connect=%s" % params """ # noqa -from __future__ import annotations -import datetime import re -from typing import Any -from typing import Callable -from typing import Optional -from typing import Tuple -from typing import TYPE_CHECKING -from typing import Union from .base import MySQLDialect from .base import MySQLExecutionContext @@ -61,31 +54,23 @@ from ... 
import util from ...connectors.pyodbc import PyODBCConnector from ...sql.sqltypes import Time -if TYPE_CHECKING: - from ...engine import Connection - from ...engine.interfaces import DBAPIConnection - from ...engine.interfaces import Dialect - from ...sql.type_api import _ResultProcessorType - class _pyodbcTIME(TIME): - def result_processor( - self, dialect: Dialect, coltype: object - ) -> _ResultProcessorType[datetime.time]: - def process(value: Any) -> Union[datetime.time, None]: + def result_processor(self, dialect, coltype): + def process(value): # pyodbc returns a datetime.time object; no need to convert - return value # type: ignore[no-any-return] + return value return process class MySQLExecutionContext_pyodbc(MySQLExecutionContext): - def get_lastrowid(self) -> int: + def get_lastrowid(self): cursor = self.create_cursor() cursor.execute("SELECT LAST_INSERT_ID()") - lastrowid = cursor.fetchone()[0] # type: ignore[index] + lastrowid = cursor.fetchone()[0] cursor.close() - return lastrowid # type: ignore[no-any-return] + return lastrowid class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): @@ -96,7 +81,7 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): pyodbc_driver_name = "MySQL" - def _detect_charset(self, connection: Connection) -> str: + def _detect_charset(self, connection): """Sniff out the character set in use for connection results.""" # Prefer 'character_set_results' for the current connection over the @@ -121,25 +106,21 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): ) return "latin1" - def _get_server_version_info( - self, connection: Connection - ) -> Tuple[int, ...]: + def _get_server_version_info(self, connection): return MySQLDialect._get_server_version_info(self, connection) - def _extract_error_code(self, exception: BaseException) -> Optional[int]: + def _extract_error_code(self, exception): m = re.compile(r"\((\d+)\)").search(str(exception.args)) - if m is None: - return None - c: Optional[str] = m.group(1) + c = m.group(1) if c: return int(c) else: return None - def on_connect(self) -> Callable[[DBAPIConnection], None]: + def on_connect(self): super_ = super().on_connect() - def on_connect(conn: DBAPIConnection) -> None: + def on_connect(conn): if super_ is not None: super_(conn) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/reflection.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/reflection.py index 71bd8c4..c4909fe 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/reflection.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/reflection.py @@ -1,65 +1,46 @@ -# dialects/mysql/reflection.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mysql/reflection.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -from __future__ import annotations +# mypy: ignore-errors + import re -from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Optional -from typing import overload -from typing import Sequence -from typing import Tuple -from typing import TYPE_CHECKING -from typing import Union from .enumerated import ENUM from .enumerated import SET from .types import DATETIME from .types import TIME from .types import TIMESTAMP +from ... import log from ... import types as sqltypes from ... 
import util -from ...util.typing import Literal - -if TYPE_CHECKING: - from .base import MySQLDialect - from .base import MySQLIdentifierPreparer - from ...engine.interfaces import ReflectedColumn class ReflectedState: """Stores raw information about a SHOW CREATE TABLE statement.""" - charset: Optional[str] - - def __init__(self) -> None: - self.columns: List[ReflectedColumn] = [] - self.table_options: Dict[str, str] = {} - self.table_name: Optional[str] = None - self.keys: List[Dict[str, Any]] = [] - self.fk_constraints: List[Dict[str, Any]] = [] - self.ck_constraints: List[Dict[str, Any]] = [] + def __init__(self): + self.columns = [] + self.table_options = {} + self.table_name = None + self.keys = [] + self.fk_constraints = [] + self.ck_constraints = [] +@log.class_logger class MySQLTableDefinitionParser: """Parses the results of a SHOW CREATE TABLE statement.""" - def __init__( - self, dialect: MySQLDialect, preparer: MySQLIdentifierPreparer - ): + def __init__(self, dialect, preparer): self.dialect = dialect self.preparer = preparer self._prep_regexes() - def parse( - self, show_create: str, charset: Optional[str] - ) -> ReflectedState: + def parse(self, show_create, charset): state = ReflectedState() state.charset = charset for line in re.split(r"\r?\n", show_create): @@ -84,11 +65,11 @@ class MySQLTableDefinitionParser: if type_ is None: util.warn("Unknown schema content: %r" % line) elif type_ == "key": - state.keys.append(spec) # type: ignore[arg-type] + state.keys.append(spec) elif type_ == "fk_constraint": - state.fk_constraints.append(spec) # type: ignore[arg-type] + state.fk_constraints.append(spec) elif type_ == "ck_constraint": - state.ck_constraints.append(spec) # type: ignore[arg-type] + state.ck_constraints.append(spec) else: pass return state @@ -96,13 +77,7 @@ class MySQLTableDefinitionParser: def _check_view(self, sql: str) -> bool: return bool(self._re_is_view.match(sql)) - def _parse_constraints(self, line: str) -> Union[ - Tuple[None, str], - Tuple[Literal["partition"], str], - Tuple[ - Literal["ck_constraint", "fk_constraint", "key"], Dict[str, str] - ], - ]: + def _parse_constraints(self, line): """Parse a KEY or CONSTRAINT line. :param line: A line of SHOW CREATE TABLE output @@ -152,7 +127,7 @@ class MySQLTableDefinitionParser: # No match. return (None, line) - def _parse_table_name(self, line: str, state: ReflectedState) -> None: + def _parse_table_name(self, line, state): """Extract the table name. :param line: The first line of SHOW CREATE TABLE @@ -163,7 +138,7 @@ class MySQLTableDefinitionParser: if m: state.table_name = cleanup(m.group("name")) - def _parse_table_options(self, line: str, state: ReflectedState) -> None: + def _parse_table_options(self, line, state): """Build a dictionary of all reflected table-level options. :param line: The final line of SHOW CREATE TABLE output. @@ -189,9 +164,7 @@ class MySQLTableDefinitionParser: for opt, val in options.items(): state.table_options["%s_%s" % (self.dialect.name, opt)] = val - def _parse_partition_options( - self, line: str, state: ReflectedState - ) -> None: + def _parse_partition_options(self, line, state): options = {} new_line = line[:] @@ -247,7 +220,7 @@ class MySQLTableDefinitionParser: else: state.table_options["%s_%s" % (self.dialect.name, opt)] = val - def _parse_column(self, line: str, state: ReflectedState) -> None: + def _parse_column(self, line, state): """Extract column details. Falls back to a 'minimal support' variant if full parse fails. 
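# --- Illustrative aside, not part of the patch ---------------------------
# A minimal sketch of how this parser surfaces in practice: Inspector-based
# reflection on MySQL runs SHOW CREATE TABLE and feeds the output through
# MySQLTableDefinitionParser. The DSN and table name are hypothetical.
from sqlalchemy import create_engine, inspect

engine = create_engine("mysql+pymysql://scott:tiger@localhost/test")
insp = inspect(engine)

# Column details recovered from the parsed SHOW CREATE TABLE output
for col in insp.get_columns("some_table"):
    print(col["name"], col["type"], col["nullable"])

# Table-level options (engine, charset, ...) collected by the same machinery
print(insp.get_table_options("some_table"))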
@@ -310,16 +283,13 @@ class MySQLTableDefinitionParser: type_instance = col_type(*type_args, **type_kw) - col_kw: Dict[str, Any] = {} + col_kw = {} # NOT NULL col_kw["nullable"] = True # this can be "NULL" in the case of TIMESTAMP if spec.get("notnull", False) == "NOT NULL": col_kw["nullable"] = False - # For generated columns, the nullability is marked in a different place - if spec.get("notnull_generated", False) == "NOT NULL": - col_kw["nullable"] = False # AUTO_INCREMENT if spec.get("autoincr", False): @@ -351,13 +321,9 @@ class MySQLTableDefinitionParser: name=name, type=type_instance, default=default, comment=comment ) col_d.update(col_kw) - state.columns.append(col_d) # type: ignore[arg-type] + state.columns.append(col_d) - def _describe_to_create( - self, - table_name: str, - columns: Sequence[Tuple[str, str, str, str, str, str]], - ) -> str: + def _describe_to_create(self, table_name, columns): """Re-format DESCRIBE output as a SHOW CREATE TABLE string. DESCRIBE is a much simpler reflection and is sufficient for @@ -410,9 +376,7 @@ class MySQLTableDefinitionParser: ] ) - def _parse_keyexprs( - self, identifiers: str - ) -> List[Tuple[str, Optional[int], str]]: + def _parse_keyexprs(self, identifiers): """Unpack '"col"(2),"col" ASC'-ish strings into components.""" return [ @@ -422,12 +386,11 @@ class MySQLTableDefinitionParser: ) ] - def _prep_regexes(self) -> None: + def _prep_regexes(self): """Pre-compile regular expressions.""" - self._pr_options: List[ - Tuple[re.Pattern[Any], Optional[Callable[[str], str]]] - ] = [] + self._re_columns = [] + self._pr_options = [] _final = self.preparer.final_quote @@ -485,13 +448,11 @@ class MySQLTableDefinitionParser: r"(?: +COLLATE +(?P[\w_]+))?" r"(?: +(?P(?:NOT )?NULL))?" r"(?: +DEFAULT +(?P" - r"(?:NULL|'(?:''|[^'])*'|\(.+?\)|[\-\w\.\(\)]+" + r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+" r"(?: +ON UPDATE [\-\w\.\(\)]+)?)" r"))?" r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P\(" - r".*\))? ?(?PVIRTUAL|STORED)?" - r"(?: +(?P(?:NOT )?NULL))?" - r")?" + r".*\))? ?(?PVIRTUAL|STORED)?)?" r"(?: +(?PAUTO_INCREMENT))?" r"(?: +COMMENT +'(?P(?:''|[^'])*)')?" r"(?: +COLUMN_FORMAT +(?P\w+))?" 
@@ -539,7 +500,7 @@ class MySQLTableDefinitionParser: # # unique constraints come back as KEYs kw = quotes.copy() - kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT" + kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION" self._re_fk_constraint = _re_compile( r" " r"CONSTRAINT +" @@ -616,21 +577,21 @@ class MySQLTableDefinitionParser: _optional_equals = r"(?:\s*(?:=\s*)|\s+)" - def _add_option_string(self, directive: str) -> None: + def _add_option_string(self, directive): regex = r"(?P%s)%s" r"'(?P(?:[^']|'')*?)'(?!')" % ( re.escape(directive), self._optional_equals, ) self._pr_options.append(_pr_compile(regex, cleanup_text)) - def _add_option_word(self, directive: str) -> None: + def _add_option_word(self, directive): regex = r"(?P%s)%s" r"(?P\w+)" % ( re.escape(directive), self._optional_equals, ) self._pr_options.append(_pr_compile(regex)) - def _add_partition_option_word(self, directive: str) -> None: + def _add_partition_option_word(self, directive): if directive == "PARTITION BY" or directive == "SUBPARTITION BY": regex = r"(?%s)%s" r"(?P\w+.*)" % ( re.escape(directive), @@ -645,7 +606,7 @@ class MySQLTableDefinitionParser: regex = r"(?%s)(?!\S)" % (re.escape(directive),) self._pr_options.append(_pr_compile(regex)) - def _add_option_regex(self, directive: str, regex: str) -> None: + def _add_option_regex(self, directive, regex): regex = r"(?P%s)%s" r"(?P%s)" % ( re.escape(directive), self._optional_equals, @@ -663,35 +624,21 @@ _options_of_type_string = ( ) -@overload -def _pr_compile( - regex: str, cleanup: Callable[[str], str] -) -> Tuple[re.Pattern[Any], Callable[[str], str]]: ... - - -@overload -def _pr_compile( - regex: str, cleanup: None = None -) -> Tuple[re.Pattern[Any], None]: ... - - -def _pr_compile( - regex: str, cleanup: Optional[Callable[[str], str]] = None -) -> Tuple[re.Pattern[Any], Optional[Callable[[str], str]]]: +def _pr_compile(regex, cleanup=None): """Prepare a 2-tuple of compiled regex and callable.""" return (_re_compile(regex), cleanup) -def _re_compile(regex: str) -> re.Pattern[Any]: +def _re_compile(regex): """Compile a string to regex, I and UNICODE.""" return re.compile(regex, re.I | re.UNICODE) -def _strip_values(values: Sequence[str]) -> List[str]: +def _strip_values(values): "Strip reflected values quotes" - strip_values: List[str] = [] + strip_values = [] for a in values: if a[0:1] == '"' or a[0:1] == "'": # strip enclosing quotes and unquote interior @@ -703,9 +650,7 @@ def _strip_values(values: Sequence[str]) -> List[str]: def cleanup_text(raw_text: str) -> str: if "\\" in raw_text: raw_text = re.sub( - _control_char_regexp, - lambda s: _control_char_map[s[0]], # type: ignore[index] - raw_text, + _control_char_regexp, lambda s: _control_char_map[s[0]], raw_text ) return raw_text.replace("''", "'") diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/reserved_words.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/reserved_words.py index ff52639..9f3436e 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/reserved_words.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/reserved_words.py @@ -1,5 +1,5 @@ -# dialects/mysql/reserved_words.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mysql/reserved_words.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -11,6 +11,7 @@ # https://mariadb.com/kb/en/reserved-words/ # includes: Reserved Words, Oracle Mode (separate set 
unioned) # excludes: Exceptions, Function Names +# mypy: ignore-errors RESERVED_WORDS_MARIADB = { "accessible", @@ -281,7 +282,6 @@ RESERVED_WORDS_MARIADB = { } ) -# https://dev.mysql.com/doc/refman/8.3/en/keywords.html # https://dev.mysql.com/doc/refman/8.0/en/keywords.html # https://dev.mysql.com/doc/refman/5.7/en/keywords.html # https://dev.mysql.com/doc/refman/5.6/en/keywords.html @@ -403,7 +403,6 @@ RESERVED_WORDS_MYSQL = { "int4", "int8", "integer", - "intersect", "interval", "into", "io_after_gtids", @@ -469,7 +468,6 @@ RESERVED_WORDS_MYSQL = { "outfile", "over", "parse_gcol_expr", - "parallel", "partition", "percent_rank", "persist", @@ -478,7 +476,6 @@ RESERVED_WORDS_MYSQL = { "primary", "procedure", "purge", - "qualify", "range", "rank", "read", diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/types.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/types.py index 117df4d..aa1de1b 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/types.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/mysql/types.py @@ -1,30 +1,18 @@ -# dialects/mysql/types.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# mysql/types.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -from __future__ import annotations +# mypy: ignore-errors + import datetime -import decimal -from typing import Any -from typing import Iterable -from typing import Optional -from typing import TYPE_CHECKING -from typing import Union from ... import exc from ... import util from ...sql import sqltypes -if TYPE_CHECKING: - from .base import MySQLDialect - from ...engine.interfaces import Dialect - from ...sql.type_api import _BindProcessorType - from ...sql.type_api import _ResultProcessorType - from ...sql.type_api import TypeEngine - class _NumericType: """Base for MySQL numeric types. 
@@ -34,27 +22,19 @@ class _NumericType: """ - def __init__( - self, unsigned: bool = False, zerofill: bool = False, **kw: Any - ): + def __init__(self, unsigned=False, zerofill=False, **kw): self.unsigned = unsigned self.zerofill = zerofill super().__init__(**kw) - def __repr__(self) -> str: + def __repr__(self): return util.generic_repr( self, to_inspect=[_NumericType, sqltypes.Numeric] ) -class _FloatType(_NumericType, sqltypes.Float[Union[decimal.Decimal, float]]): - def __init__( - self, - precision: Optional[int] = None, - scale: Optional[int] = None, - asdecimal: bool = True, - **kw: Any, - ): +class _FloatType(_NumericType, sqltypes.Float): + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): if isinstance(self, (REAL, DOUBLE)) and ( (precision is None and scale is not None) or (precision is not None and scale is None) @@ -66,18 +46,18 @@ class _FloatType(_NumericType, sqltypes.Float[Union[decimal.Decimal, float]]): super().__init__(precision=precision, asdecimal=asdecimal, **kw) self.scale = scale - def __repr__(self) -> str: + def __repr__(self): return util.generic_repr( self, to_inspect=[_FloatType, _NumericType, sqltypes.Float] ) class _IntegerType(_NumericType, sqltypes.Integer): - def __init__(self, display_width: Optional[int] = None, **kw: Any): + def __init__(self, display_width=None, **kw): self.display_width = display_width super().__init__(**kw) - def __repr__(self) -> str: + def __repr__(self): return util.generic_repr( self, to_inspect=[_IntegerType, _NumericType, sqltypes.Integer] ) @@ -88,13 +68,13 @@ class _StringType(sqltypes.String): def __init__( self, - charset: Optional[str] = None, - collation: Optional[str] = None, - ascii: bool = False, # noqa - binary: bool = False, - unicode: bool = False, - national: bool = False, - **kw: Any, + charset=None, + collation=None, + ascii=False, # noqa + binary=False, + unicode=False, + national=False, + **kw, ): self.charset = charset @@ -107,33 +87,25 @@ class _StringType(sqltypes.String): self.national = national super().__init__(**kw) - def __repr__(self) -> str: + def __repr__(self): return util.generic_repr( self, to_inspect=[_StringType, sqltypes.String] ) -class _MatchType( - sqltypes.Float[Union[decimal.Decimal, float]], sqltypes.MatchType -): - def __init__(self, **kw: Any): +class _MatchType(sqltypes.Float, sqltypes.MatchType): + def __init__(self, **kw): # TODO: float arguments? - sqltypes.Float.__init__(self) # type: ignore[arg-type] + sqltypes.Float.__init__(self) sqltypes.MatchType.__init__(self) -class NUMERIC(_NumericType, sqltypes.NUMERIC[Union[decimal.Decimal, float]]): +class NUMERIC(_NumericType, sqltypes.NUMERIC): """MySQL NUMERIC type.""" __visit_name__ = "NUMERIC" - def __init__( - self, - precision: Optional[int] = None, - scale: Optional[int] = None, - asdecimal: bool = True, - **kw: Any, - ): + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a NUMERIC. :param precision: Total digits in this number. If scale and precision @@ -154,18 +126,12 @@ class NUMERIC(_NumericType, sqltypes.NUMERIC[Union[decimal.Decimal, float]]): ) -class DECIMAL(_NumericType, sqltypes.DECIMAL[Union[decimal.Decimal, float]]): +class DECIMAL(_NumericType, sqltypes.DECIMAL): """MySQL DECIMAL type.""" __visit_name__ = "DECIMAL" - def __init__( - self, - precision: Optional[int] = None, - scale: Optional[int] = None, - asdecimal: bool = True, - **kw: Any, - ): + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a DECIMAL. 
:param precision: Total digits in this number. If scale and precision @@ -186,18 +152,12 @@ class DECIMAL(_NumericType, sqltypes.DECIMAL[Union[decimal.Decimal, float]]): ) -class DOUBLE(_FloatType, sqltypes.DOUBLE[Union[decimal.Decimal, float]]): +class DOUBLE(_FloatType, sqltypes.DOUBLE): """MySQL DOUBLE type.""" __visit_name__ = "DOUBLE" - def __init__( - self, - precision: Optional[int] = None, - scale: Optional[int] = None, - asdecimal: bool = True, - **kw: Any, - ): + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a DOUBLE. .. note:: @@ -226,18 +186,12 @@ class DOUBLE(_FloatType, sqltypes.DOUBLE[Union[decimal.Decimal, float]]): ) -class REAL(_FloatType, sqltypes.REAL[Union[decimal.Decimal, float]]): +class REAL(_FloatType, sqltypes.REAL): """MySQL REAL type.""" __visit_name__ = "REAL" - def __init__( - self, - precision: Optional[int] = None, - scale: Optional[int] = None, - asdecimal: bool = True, - **kw: Any, - ): + def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a REAL. .. note:: @@ -266,18 +220,12 @@ class REAL(_FloatType, sqltypes.REAL[Union[decimal.Decimal, float]]): ) -class FLOAT(_FloatType, sqltypes.FLOAT[Union[decimal.Decimal, float]]): +class FLOAT(_FloatType, sqltypes.FLOAT): """MySQL FLOAT type.""" __visit_name__ = "FLOAT" - def __init__( - self, - precision: Optional[int] = None, - scale: Optional[int] = None, - asdecimal: bool = False, - **kw: Any, - ): + def __init__(self, precision=None, scale=None, asdecimal=False, **kw): """Construct a FLOAT. :param precision: Total digits in this number. If scale and precision @@ -297,9 +245,7 @@ class FLOAT(_FloatType, sqltypes.FLOAT[Union[decimal.Decimal, float]]): precision=precision, scale=scale, asdecimal=asdecimal, **kw ) - def bind_processor( - self, dialect: Dialect - ) -> Optional[_BindProcessorType[Union[decimal.Decimal, float]]]: + def bind_processor(self, dialect): return None @@ -308,7 +254,7 @@ class INTEGER(_IntegerType, sqltypes.INTEGER): __visit_name__ = "INTEGER" - def __init__(self, display_width: Optional[int] = None, **kw: Any): + def __init__(self, display_width=None, **kw): """Construct an INTEGER. :param display_width: Optional, maximum display width for this number. @@ -329,7 +275,7 @@ class BIGINT(_IntegerType, sqltypes.BIGINT): __visit_name__ = "BIGINT" - def __init__(self, display_width: Optional[int] = None, **kw: Any): + def __init__(self, display_width=None, **kw): """Construct a BIGINTEGER. :param display_width: Optional, maximum display width for this number. @@ -350,7 +296,7 @@ class MEDIUMINT(_IntegerType): __visit_name__ = "MEDIUMINT" - def __init__(self, display_width: Optional[int] = None, **kw: Any): + def __init__(self, display_width=None, **kw): """Construct a MEDIUMINTEGER :param display_width: Optional, maximum display width for this number. @@ -371,7 +317,7 @@ class TINYINT(_IntegerType): __visit_name__ = "TINYINT" - def __init__(self, display_width: Optional[int] = None, **kw: Any): + def __init__(self, display_width=None, **kw): """Construct a TINYINT. :param display_width: Optional, maximum display width for this number. 
@@ -386,19 +332,13 @@ class TINYINT(_IntegerType): """ super().__init__(display_width=display_width, **kw) - def _compare_type_affinity(self, other: TypeEngine[Any]) -> bool: - return ( - self._type_affinity is other._type_affinity - or other._type_affinity is sqltypes.Boolean - ) - class SMALLINT(_IntegerType, sqltypes.SMALLINT): """MySQL SMALLINTEGER type.""" __visit_name__ = "SMALLINT" - def __init__(self, display_width: Optional[int] = None, **kw: Any): + def __init__(self, display_width=None, **kw): """Construct a SMALLINTEGER. :param display_width: Optional, maximum display width for this number. @@ -414,7 +354,7 @@ class SMALLINT(_IntegerType, sqltypes.SMALLINT): super().__init__(display_width=display_width, **kw) -class BIT(sqltypes.TypeEngine[Any]): +class BIT(sqltypes.TypeEngine): """MySQL BIT type. This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater @@ -425,7 +365,7 @@ class BIT(sqltypes.TypeEngine[Any]): __visit_name__ = "BIT" - def __init__(self, length: Optional[int] = None): + def __init__(self, length=None): """Construct a BIT. :param length: Optional, number of bits. @@ -433,19 +373,20 @@ class BIT(sqltypes.TypeEngine[Any]): """ self.length = length - def result_processor( - self, dialect: MySQLDialect, coltype: object # type: ignore[override] - ) -> Optional[_ResultProcessorType[Any]]: - """Convert a MySQL's 64 bit, variable length binary string to a - long.""" + def result_processor(self, dialect, coltype): + """Convert a MySQL's 64 bit, variable length binary string to a long. - if dialect.supports_native_bit: - return None + TODO: this is MySQL-db, pyodbc specific. OurSQL and mysqlconnector + already do this, so this logic should be moved to those dialects. - def process(value: Optional[Iterable[int]]) -> Optional[int]: + """ + + def process(value): if value is not None: v = 0 for i in value: + if not isinstance(i, int): + i = ord(i) # convert byte to int on Python 2 v = v << 8 | i return v return value @@ -458,7 +399,7 @@ class TIME(sqltypes.TIME): __visit_name__ = "TIME" - def __init__(self, timezone: bool = False, fsp: Optional[int] = None): + def __init__(self, timezone=False, fsp=None): """Construct a MySQL TIME type. :param timezone: not used by the MySQL dialect. @@ -477,12 +418,10 @@ class TIME(sqltypes.TIME): super().__init__(timezone=timezone) self.fsp = fsp - def result_processor( - self, dialect: Dialect, coltype: object - ) -> _ResultProcessorType[datetime.time]: + def result_processor(self, dialect, coltype): time = datetime.time - def process(value: Any) -> Optional[datetime.time]: + def process(value): # convert from a timedelta value if value is not None: microseconds = value.microseconds @@ -505,7 +444,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP): __visit_name__ = "TIMESTAMP" - def __init__(self, timezone: bool = False, fsp: Optional[int] = None): + def __init__(self, timezone=False, fsp=None): """Construct a MySQL TIMESTAMP type. :param timezone: not used by the MySQL dialect. @@ -530,7 +469,7 @@ class DATETIME(sqltypes.DATETIME): __visit_name__ = "DATETIME" - def __init__(self, timezone: bool = False, fsp: Optional[int] = None): + def __init__(self, timezone=False, fsp=None): """Construct a MySQL DATETIME type. :param timezone: not used by the MySQL dialect. 
@@ -550,26 +489,26 @@ class DATETIME(sqltypes.DATETIME): self.fsp = fsp -class YEAR(sqltypes.TypeEngine[Any]): +class YEAR(sqltypes.TypeEngine): """MySQL YEAR type, for single byte storage of years 1901-2155.""" __visit_name__ = "YEAR" - def __init__(self, display_width: Optional[int] = None): + def __init__(self, display_width=None): self.display_width = display_width class TEXT(_StringType, sqltypes.TEXT): - """MySQL TEXT type, for character storage encoded up to 2^16 bytes.""" + """MySQL TEXT type, for text up to 2^16 characters.""" __visit_name__ = "TEXT" - def __init__(self, length: Optional[int] = None, **kw: Any): + def __init__(self, length=None, **kw): """Construct a TEXT. :param length: Optional, if provided the server may optimize storage by substituting the smallest TEXT type sufficient to store - ``length`` bytes of characters. + ``length`` characters. :param charset: Optional, a column-level character set for this string value. Takes precedence to 'ascii' or 'unicode' short-hand. @@ -596,11 +535,11 @@ class TEXT(_StringType, sqltypes.TEXT): class TINYTEXT(_StringType): - """MySQL TINYTEXT type, for character storage encoded up to 2^8 bytes.""" + """MySQL TINYTEXT type, for text up to 2^8 characters.""" __visit_name__ = "TINYTEXT" - def __init__(self, **kwargs: Any): + def __init__(self, **kwargs): """Construct a TINYTEXT. :param charset: Optional, a column-level character set for this string @@ -628,12 +567,11 @@ class TINYTEXT(_StringType): class MEDIUMTEXT(_StringType): - """MySQL MEDIUMTEXT type, for character storage encoded up - to 2^24 bytes.""" + """MySQL MEDIUMTEXT type, for text up to 2^24 characters.""" __visit_name__ = "MEDIUMTEXT" - def __init__(self, **kwargs: Any): + def __init__(self, **kwargs): """Construct a MEDIUMTEXT. :param charset: Optional, a column-level character set for this string @@ -661,11 +599,11 @@ class MEDIUMTEXT(_StringType): class LONGTEXT(_StringType): - """MySQL LONGTEXT type, for character storage encoded up to 2^32 bytes.""" + """MySQL LONGTEXT type, for text up to 2^32 characters.""" __visit_name__ = "LONGTEXT" - def __init__(self, **kwargs: Any): + def __init__(self, **kwargs): """Construct a LONGTEXT. :param charset: Optional, a column-level character set for this string @@ -697,7 +635,7 @@ class VARCHAR(_StringType, sqltypes.VARCHAR): __visit_name__ = "VARCHAR" - def __init__(self, length: Optional[int] = None, **kwargs: Any) -> None: + def __init__(self, length=None, **kwargs): """Construct a VARCHAR. :param charset: Optional, a column-level character set for this string @@ -729,7 +667,7 @@ class CHAR(_StringType, sqltypes.CHAR): __visit_name__ = "CHAR" - def __init__(self, length: Optional[int] = None, **kwargs: Any): + def __init__(self, length=None, **kwargs): """Construct a CHAR. :param length: Maximum data length, in characters. @@ -745,7 +683,7 @@ class CHAR(_StringType, sqltypes.CHAR): super().__init__(length=length, **kwargs) @classmethod - def _adapt_string_for_cast(cls, type_: sqltypes.String) -> sqltypes.CHAR: + def _adapt_string_for_cast(self, type_): # copy the given string type into a CHAR # for the purposes of rendering a CAST expression type_ = sqltypes.to_instance(type_) @@ -774,7 +712,7 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR): __visit_name__ = "NVARCHAR" - def __init__(self, length: Optional[int] = None, **kwargs: Any): + def __init__(self, length=None, **kwargs): """Construct an NVARCHAR. :param length: Maximum data length, in characters. 
@@ -800,7 +738,7 @@ class NCHAR(_StringType, sqltypes.NCHAR): __visit_name__ = "NCHAR" - def __init__(self, length: Optional[int] = None, **kwargs: Any): + def __init__(self, length=None, **kwargs): """Construct an NCHAR. :param length: Maximum data length, in characters. diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/__init__.py index 566edf1..46a5d0a 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/__init__.py @@ -1,11 +1,11 @@ -# dialects/oracle/__init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# oracle/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors -from types import ModuleType + from . import base # noqa from . import cx_oracle # noqa @@ -32,18 +32,7 @@ from .base import ROWID from .base import TIMESTAMP from .base import VARCHAR from .base import VARCHAR2 -from .base import VECTOR -from .base import VectorIndexConfig -from .base import VectorIndexType -from .vector import SparseVector -from .vector import VectorDistanceType -from .vector import VectorStorageFormat -from .vector import VectorStorageType -# Alias oracledb also as oracledb_async -oracledb_async = type( - "oracledb_async", (ModuleType,), {"dialect": oracledb.dialect_async} -) base.dialect = dialect = cx_oracle.dialect @@ -71,11 +60,4 @@ __all__ = ( "NVARCHAR2", "ROWID", "REAL", - "VECTOR", - "VectorDistanceType", - "VectorIndexType", - "VectorIndexConfig", - "VectorStorageFormat", - "VectorStorageType", - "SparseVector", ) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/base.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/base.py index 2d6d6eb..d993ef2 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/base.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/base.py @@ -1,5 +1,5 @@ -# dialects/oracle/base.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# oracle/base.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -9,7 +9,8 @@ r""" .. dialect:: oracle - :name: Oracle Database + :name: Oracle + :full_support: 18c :normal_support: 11+ :best_effort: 9+ @@ -17,24 +18,21 @@ r""" Auto Increment Behavior ----------------------- -SQLAlchemy Table objects which include integer primary keys are usually assumed -to have "autoincrementing" behavior, meaning they can generate their own -primary key values upon INSERT. For use within Oracle Database, two options are -available, which are the use of IDENTITY columns (Oracle Database 12 and above -only) or the association of a SEQUENCE with the column. +SQLAlchemy Table objects which include integer primary keys are usually +assumed to have "autoincrementing" behavior, meaning they can generate their +own primary key values upon INSERT. For use within Oracle, two options are +available, which are the use of IDENTITY columns (Oracle 12 and above only) +or the association of a SEQUENCE with the column. 
-Specifying GENERATED AS IDENTITY (Oracle Database 12 and above) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Specifying GENERATED AS IDENTITY (Oracle 12 and above) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Starting from version 12, Oracle Database can make use of identity columns -using the :class:`_sql.Identity` to specify the autoincrementing behavior:: +Starting from version 12 Oracle can make use of identity columns using +the :class:`_sql.Identity` to specify the autoincrementing behavior:: - t = Table( - "mytable", - metadata, - Column("id", Integer, Identity(start=3), primary_key=True), - Column(...), - ..., + t = Table('mytable', metadata, + Column('id', Integer, Identity(start=3), primary_key=True), + Column(...), ... ) The CREATE TABLE for the above :class:`_schema.Table` object would be: @@ -49,38 +47,34 @@ The CREATE TABLE for the above :class:`_schema.Table` object would be: The :class:`_schema.Identity` object support many options to control the "autoincrementing" behavior of the column, like the starting value, the -incrementing value, etc. In addition to the standard options, Oracle Database -supports setting :paramref:`_schema.Identity.always` to ``None`` to use the -default generated mode, rendering GENERATED AS IDENTITY in the DDL. It also supports +incrementing value, etc. +In addition to the standard options, Oracle supports setting +:paramref:`_schema.Identity.always` to ``None`` to use the default +generated mode, rendering GENERATED AS IDENTITY in the DDL. It also supports setting :paramref:`_schema.Identity.on_null` to ``True`` to specify ON NULL in conjunction with a 'BY DEFAULT' identity column. -Using a SEQUENCE (all Oracle Database versions) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Using a SEQUENCE (all Oracle versions) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Older version of Oracle Database had no "autoincrement" feature: SQLAlchemy -relies upon sequences to produce these values. With the older Oracle Database -versions, *a sequence must always be explicitly specified to enable -autoincrement*. This is divergent with the majority of documentation examples -which assume the usage of an autoincrement-capable database. To specify -sequences, use the sqlalchemy.schema.Sequence object which is passed to a -Column construct:: +Older version of Oracle had no "autoincrement" +feature, SQLAlchemy relies upon sequences to produce these values. With the +older Oracle versions, *a sequence must always be explicitly specified to +enable autoincrement*. This is divergent with the majority of documentation +examples which assume the usage of an autoincrement-capable database. To +specify sequences, use the sqlalchemy.schema.Sequence object which is passed +to a Column construct:: - t = Table( - "mytable", - metadata, - Column("id", Integer, Sequence("id_seq", start=1), primary_key=True), - Column(...), - ..., + t = Table('mytable', metadata, + Column('id', Integer, Sequence('id_seq', start=1), primary_key=True), + Column(...), ... ) This step is also required when using table reflection, i.e. autoload_with=engine:: - t = Table( - "mytable", - metadata, - Column("id", Integer, Sequence("id_seq", start=1), primary_key=True), - autoload_with=engine, + t = Table('mytable', metadata, + Column('id', Integer, Sequence('id_seq', start=1), primary_key=True), + autoload_with=engine ) .. versionchanged:: 1.4 Added :class:`_schema.Identity` construct @@ -92,18 +86,21 @@ This step is also required when using table reflection, i.e. 
autoload_with=engin Transaction Isolation Level / Autocommit ---------------------------------------- -Oracle Database supports "READ COMMITTED" and "SERIALIZABLE" modes of -isolation. The AUTOCOMMIT isolation level is also supported by the -python-oracledb and cx_Oracle dialects. +The Oracle database supports "READ COMMITTED" and "SERIALIZABLE" modes of +isolation. The AUTOCOMMIT isolation level is also supported by the cx_Oracle +dialect. To set using per-connection execution options:: connection = engine.connect() - connection = connection.execution_options(isolation_level="AUTOCOMMIT") + connection = connection.execution_options( + isolation_level="AUTOCOMMIT" + ) -For ``READ COMMITTED`` and ``SERIALIZABLE``, the Oracle Database dialects sets -the level at the session level using ``ALTER SESSION``, which is reverted back -to its default setting when the connection is returned to the connection pool. +For ``READ COMMITTED`` and ``SERIALIZABLE``, the Oracle dialect sets the +level at the session level using ``ALTER SESSION``, which is reverted back +to its default setting when the connection is returned to the connection +pool. Valid values for ``isolation_level`` include: @@ -113,28 +110,28 @@ Valid values for ``isolation_level`` include: .. note:: The implementation for the :meth:`_engine.Connection.get_isolation_level` method as implemented by the - Oracle Database dialects necessarily force the start of a transaction using the - Oracle Database DBMS_TRANSACTION.LOCAL_TRANSACTION_ID function; otherwise no - level is normally readable. + Oracle dialect necessarily forces the start of a transaction using the + Oracle LOCAL_TRANSACTION_ID function; otherwise no level is normally + readable. Additionally, the :meth:`_engine.Connection.get_isolation_level` method will raise an exception if the ``v$transaction`` view is not available due to - permissions or other reasons, which is a common occurrence in Oracle Database + permissions or other reasons, which is a common occurrence in Oracle installations. - The python-oracledb and cx_Oracle dialects attempt to call the + The cx_Oracle dialect attempts to call the :meth:`_engine.Connection.get_isolation_level` method when the dialect makes its first connection to the database in order to acquire the "default"isolation level. This default level is necessary so that the level can be reset on a connection after it has been temporarily modified using - :meth:`_engine.Connection.execution_options` method. In the common event + :meth:`_engine.Connection.execution_options` method. In the common event that the :meth:`_engine.Connection.get_isolation_level` method raises an exception due to ``v$transaction`` not being readable as well as any other database-related failure, the level is assumed to be "READ COMMITTED". No warning is emitted for this initial first-connect condition as it is expected to be a common restriction on Oracle databases. -.. versionadded:: 1.3.16 added support for AUTOCOMMIT to the cx_Oracle dialect +.. versionadded:: 1.3.16 added support for AUTOCOMMIT to the cx_oracle dialect as well as the notion of a default isolation level .. versionadded:: 1.3.21 Added support for SERIALIZABLE as well as live @@ -152,182 +149,56 @@ Valid values for ``isolation_level`` include: Identifier Casing ----------------- -In Oracle Database, the data dictionary represents all case insensitive -identifier names using UPPERCASE text. 
This is in contradiction to the -expectations of SQLAlchemy, which assume a case insensitive name is represented -as lowercase text. - -As an example of case insensitive identifier names, consider the following table: - -.. sourcecode:: sql - - CREATE TABLE MyTable (Identifier INTEGER PRIMARY KEY) - -If you were to ask Oracle Database for information about this table, the -table name would be reported as ``MYTABLE`` and the column name would -be reported as ``IDENTIFIER``. Compare to most other databases such as -PostgreSQL and MySQL which would report these names as ``mytable`` and -``identifier``. The names are **not quoted, therefore are case insensitive**. -The special casing of ``MyTable`` and ``Identifier`` would only be maintained -if they were quoted in the table definition: - -.. sourcecode:: sql - - CREATE TABLE "MyTable" ("Identifier" INTEGER PRIMARY KEY) - -When constructing a SQLAlchemy :class:`.Table` object, **an all lowercase name -is considered to be case insensitive**. So the following table assumes -case insensitive names:: - - Table("mytable", metadata, Column("identifier", Integer, primary_key=True)) - -Whereas when mixed case or UPPERCASE names are used, case sensitivity is -assumed:: - - Table("MyTable", metadata, Column("Identifier", Integer, primary_key=True)) - -A similar situation occurs at the database driver level when emitting a -textual SQL SELECT statement and looking at column names in the DBAPI -``cursor.description`` attribute. A database like PostgreSQL will normalize -case insensitive names to be lowercase:: - - >>> pg_engine = create_engine("postgresql://scott:tiger@localhost/test") - >>> pg_connection = pg_engine.connect() - >>> result = pg_connection.exec_driver_sql("SELECT 1 AS SomeName") - >>> result.cursor.description - (Column(name='somename', type_code=23),) - -Whereas Oracle normalizes them to UPPERCASE:: - - >>> oracle_engine = create_engine("oracle+oracledb://scott:tiger@oracle18c/xe") - >>> oracle_connection = oracle_engine.connect() - >>> result = oracle_connection.exec_driver_sql( - ... "SELECT 1 AS SomeName FROM DUAL" - ... ) - >>> result.cursor.description - [('SOMENAME', , 127, None, 0, -127, True)] - -In order to achieve cross-database parity for the two cases of a. table -reflection and b. textual-only SQL statement round trips, SQLAlchemy performs a step -called **name normalization** when using the Oracle dialect. This process may -also apply to other third party dialects that have similar UPPERCASE handling -of case insensitive names. - -When using name normalization, SQLAlchemy attempts to detect if a name is -case insensitive by checking if all characters are UPPERCASE letters only; -if so, then it assumes this is a case insensitive name and is delivered as -a lowercase name. - -For table reflection, a tablename that is seen represented as all UPPERCASE -in Oracle Database's catalog tables will be assumed to have a case insensitive -name. This is what allows the ``Table`` definition to use lower case names -and be equally compatible from a reflection point of view on Oracle Database -and all other databases such as PostgreSQL and MySQL:: - - # matches a table created with CREATE TABLE mytable - Table("mytable", metadata, autoload_with=some_engine) - -Above, the all lowercase name ``"mytable"`` is case insensitive; it will match -a table reported by PostgreSQL as ``"mytable"`` and a table reported by -Oracle as ``"MYTABLE"``. 
If name normalization were not present, it would -not be possible for the above :class:`.Table` definition to be introspectable -in a cross-database way, since we are dealing with a case insensitive name -that is not reported by each database in the same way. - -Case sensitivity can be forced on in this case, such as if we wanted to represent -the quoted tablename ``"MYTABLE"`` with that exact casing, most simply by using -that casing directly, which will be seen as a case sensitive name:: - - # matches a table created with CREATE TABLE "MYTABLE" - Table("MYTABLE", metadata, autoload_with=some_engine) - -For the unusual case of a quoted all-lowercase name, the :class:`.quoted_name` -construct may be used:: - - from sqlalchemy import quoted_name - - # matches a table created with CREATE TABLE "mytable" - Table( - quoted_name("mytable", quote=True), metadata, autoload_with=some_engine - ) - -Name normalization also takes place when handling result sets from **purely -textual SQL strings**, that have no other :class:`.Table` or :class:`.Column` -metadata associated with them. This includes SQL strings executed using -:meth:`.Connection.exec_driver_sql` and SQL strings executed using the -:func:`.text` construct which do not include :class:`.Column` metadata. - -Returning to the Oracle Database SELECT statement, we see that even though -``cursor.description`` reports the column name as ``SOMENAME``, SQLAlchemy -name normalizes this to ``somename``:: - - >>> oracle_engine = create_engine("oracle+oracledb://scott:tiger@oracle18c/xe") - >>> oracle_connection = oracle_engine.connect() - >>> result = oracle_connection.exec_driver_sql( - ... "SELECT 1 AS SomeName FROM DUAL" - ... ) - >>> result.cursor.description - [('SOMENAME', , 127, None, 0, -127, True)] - >>> result.keys() - RMKeyView(['somename']) - -The single scenario where the above behavior produces inaccurate results -is when using an all-uppercase, quoted name. SQLAlchemy has no way to determine -that a particular name in ``cursor.description`` was quoted, and is therefore -case sensitive, or was not quoted, and should be name normalized:: - - >>> result = oracle_connection.exec_driver_sql( - ... 'SELECT 1 AS "SOMENAME" FROM DUAL' - ... ) - >>> result.cursor.description - [('SOMENAME', , 127, None, 0, -127, True)] - >>> result.keys() - RMKeyView(['somename']) - -For this case, a new feature will be available in SQLAlchemy 2.1 to disable -the name normalization behavior in specific cases. - +In Oracle, the data dictionary represents all case insensitive identifier +names using UPPERCASE text. SQLAlchemy on the other hand considers an +all-lower case identifier name to be case insensitive. The Oracle dialect +converts all case insensitive identifiers to and from those two formats during +schema level communication, such as reflection of tables and indexes. Using +an UPPERCASE name on the SQLAlchemy side indicates a case sensitive +identifier, and SQLAlchemy will quote the name - this will cause mismatches +against data dictionary data received from Oracle, so unless identifier names +have been truly created as case sensitive (i.e. using quoted names), all +lowercase names should be used on the SQLAlchemy side. .. _oracle_max_identifier_lengths: -Maximum Identifier Lengths --------------------------- +Max Identifier Lengths +---------------------- -SQLAlchemy is sensitive to the maximum identifier length supported by Oracle -Database. 
This affects generated SQL label names as well as the generation of -constraint names, particularly in the case where the constraint naming -convention feature described at :ref:`constraint_naming_conventions` is being -used. +Oracle has changed the default max identifier length as of Oracle Server +version 12.2. Prior to this version, the length was 30, and for 12.2 and +greater it is now 128. This change impacts SQLAlchemy in the area of +generated SQL label names as well as the generation of constraint names, +particularly in the case where the constraint naming convention feature +described at :ref:`constraint_naming_conventions` is being used. -Oracle Database 12.2 increased the default maximum identifier length from 30 to -128. As of SQLAlchemy 1.4, the default maximum identifier length for the Oracle -dialects is 128 characters. Upon first connection, the maximum length actually -supported by the database is obtained. In all cases, setting the +To assist with this change and others, Oracle includes the concept of a +"compatibility" version, which is a version number that is independent of the +actual server version in order to assist with migration of Oracle databases, +and may be configured within the Oracle server itself. This compatibility +version is retrieved using the query ``SELECT value FROM v$parameter WHERE +name = 'compatible';``. The SQLAlchemy Oracle dialect, when tasked with +determining the default max identifier length, will attempt to use this query +upon first connect in order to determine the effective compatibility version of +the server, which determines what the maximum allowed identifier length is for +the server. If the table is not available, the server version information is +used instead. + +As of SQLAlchemy 1.4, the default max identifier length for the Oracle dialect +is 128 characters. Upon first connect, the compatibility version is detected +and if it is less than Oracle version 12.2, the max identifier length is +changed to be 30 characters. In all cases, setting the :paramref:`_sa.create_engine.max_identifier_length` parameter will bypass this change and the value given will be used as is:: engine = create_engine( - "oracle+oracledb://scott:tiger@localhost:1521?service_name=freepdb1", - max_identifier_length=30, - ) - -If :paramref:`_sa.create_engine.max_identifier_length` is not set, the oracledb -dialect internally uses the ``max_identifier_length`` attribute available on -driver connections since python-oracledb version 2.5. When using an older -driver version, or using the cx_Oracle dialect, SQLAlchemy will instead attempt -to use the query ``SELECT value FROM v$parameter WHERE name = 'compatible'`` -upon first connect in order to determine the effective compatibility version of -the database. The "compatibility" version is a version number that is -independent of the actual database version. It is used to assist database -migration. It is configured by an Oracle Database initialization parameter. The -compatibility version then determines the maximum allowed identifier length for -the database. If the V$ view is not available, the database version information -is used instead. + "oracle+cx_oracle://scott:tiger@oracle122", + max_identifier_length=30) The maximum identifier length comes into play both when generating anonymized SQL labels in SELECT statements, but more crucially when generating constraint names from a naming convention. It is this area that has created the need for -SQLAlchemy to change this default conservatively. 
For example, the following +SQLAlchemy to change this default conservatively. For example, the following naming convention produces two very different constraint names based on the identifier length:: @@ -359,71 +230,68 @@ identifier length:: oracle_dialect = oracle.dialect(max_identifier_length=30) print(CreateIndex(ix).compile(dialect=oracle_dialect)) -With an identifier length of 30, the above CREATE INDEX looks like: - -.. sourcecode:: sql +With an identifier length of 30, the above CREATE INDEX looks like:: CREATE INDEX ix_some_column_name_1s_70cd ON t (some_column_name_1, some_column_name_2, some_column_name_3) -However with length of 128, it becomes:: - -.. sourcecode:: sql +However with length=128, it becomes:: CREATE INDEX ix_some_column_name_1some_column_name_2some_column_name_3 ON t (some_column_name_1, some_column_name_2, some_column_name_3) -Applications which have run versions of SQLAlchemy prior to 1.4 on Oracle -Database version 12.2 or greater are therefore subject to the scenario of a +Applications which have run versions of SQLAlchemy prior to 1.4 on an Oracle +server version 12.2 or greater are therefore subject to the scenario of a database migration that wishes to "DROP CONSTRAINT" on a name that was previously generated with the shorter length. This migration will fail when the identifier length is changed without the name of the index or constraint first being adjusted. Such applications are strongly advised to make use of -:paramref:`_sa.create_engine.max_identifier_length` in order to maintain -control of the generation of truncated names, and to fully review and test all -database migrations in a staging environment when changing this value to ensure -that the impact of this change has been mitigated. +:paramref:`_sa.create_engine.max_identifier_length` +in order to maintain control +of the generation of truncated names, and to fully review and test all database +migrations in a staging environment when changing this value to ensure that the +impact of this change has been mitigated. -.. versionchanged:: 1.4 the default max_identifier_length for Oracle Database - is 128 characters, which is adjusted down to 30 upon first connect if the - Oracle Database, or its compatibility setting, are lower than version 12.2. +.. versionchanged:: 1.4 the default max_identifier_length for Oracle is 128 + characters, which is adjusted down to 30 upon first connect if an older + version of Oracle server (compatibility version < 12.2) is detected. LIMIT/OFFSET/FETCH Support -------------------------- -Methods like :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset` make use -of ``FETCH FIRST N ROW / OFFSET N ROWS`` syntax assuming Oracle Database 12c or -above, and assuming the SELECT statement is not embedded within a compound -statement like UNION. This syntax is also available directly by using the -:meth:`_sql.Select.fetch` method. +Methods like :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset` make +use of ``FETCH FIRST N ROW / OFFSET N ROWS`` syntax assuming +Oracle 12c or above, and assuming the SELECT statement is not embedded within +a compound statement like UNION. This syntax is also available directly by using +the :meth:`_sql.Select.fetch` method. -.. versionchanged:: 2.0 the Oracle Database dialects now use ``FETCH FIRST N - ROW / OFFSET N ROWS`` for all :meth:`_sql.Select.limit` and - :meth:`_sql.Select.offset` usage including within the ORM and legacy - :class:`_orm.Query`. 
To force the legacy behavior using window functions, - specify the ``enable_offset_fetch=False`` dialect parameter to - :func:`_sa.create_engine`. +.. versionchanged:: 2.0 the Oracle dialect now uses + ``FETCH FIRST N ROW / OFFSET N ROWS`` for all + :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset` usage including + within the ORM and legacy :class:`_orm.Query`. To force the legacy + behavior using window functions, specify the ``enable_offset_fetch=False`` + dialect parameter to :func:`_sa.create_engine`. -The use of ``FETCH FIRST / OFFSET`` may be disabled on any Oracle Database -version by passing ``enable_offset_fetch=False`` to :func:`_sa.create_engine`, -which will force the use of "legacy" mode that makes use of window functions. +The use of ``FETCH FIRST / OFFSET`` may be disabled on any Oracle version +by passing ``enable_offset_fetch=False`` to :func:`_sa.create_engine`, which +will force the use of "legacy" mode that makes use of window functions. This mode is also selected automatically when using a version of Oracle -Database prior to 12c. +prior to 12c. -When using legacy mode, or when a :class:`.Select` statement with limit/offset -is embedded in a compound statement, an emulated approach for LIMIT / OFFSET -based on window functions is used, which involves creation of a subquery using -``ROW_NUMBER`` that is prone to performance issues as well as SQL construction -issues for complex statements. However, this approach is supported by all -Oracle Database versions. See notes below. +When using legacy mode, or when a :class:`.Select` statement +with limit/offset is embedded in a compound statement, an emulated approach for +LIMIT / OFFSET based on window functions is used, which involves creation of a +subquery using ``ROW_NUMBER`` that is prone to performance issues as well as +SQL construction issues for complex statements. However, this approach is +supported by all Oracle versions. See notes below. Notes on LIMIT / OFFSET emulation (when fetch() method cannot be used) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ If using :meth:`_sql.Select.limit` and :meth:`_sql.Select.offset`, or with the ORM the :meth:`_orm.Query.limit` and :meth:`_orm.Query.offset` methods on an -Oracle Database version prior to 12c, the following notes apply: +Oracle version prior to 12c, the following notes apply: * SQLAlchemy currently makes use of ROWNUM to achieve LIMIT/OFFSET; the exact methodology is taken from @@ -434,11 +302,10 @@ Oracle Database version prior to 12c, the following notes apply: to :func:`_sa.create_engine`. .. versionchanged:: 1.4 - - The Oracle Database dialect renders limit/offset integer values using a - "post compile" scheme which renders the integer directly before passing - the statement to the cursor for execution. The ``use_binds_for_limits`` - flag no longer has an effect. + The Oracle dialect renders limit/offset integer values using a "post + compile" scheme which renders the integer directly before passing the + statement to the cursor for execution. The ``use_binds_for_limits`` flag + no longer has an effect. .. 
seealso:: @@ -449,36 +316,37 @@ Oracle Database version prior to 12c, the following notes apply: RETURNING Support ----------------- -Oracle Database supports RETURNING fully for INSERT, UPDATE and DELETE -statements that are invoked with a single collection of bound parameters (that -is, a ``cursor.execute()`` style statement; SQLAlchemy does not generally +The Oracle database supports RETURNING fully for INSERT, UPDATE and DELETE +statements that are invoked with a single collection of bound parameters +(that is, a ``cursor.execute()`` style statement; SQLAlchemy does not generally support RETURNING with :term:`executemany` statements). Multiple rows may be returned as well. -.. versionchanged:: 2.0 the Oracle Database backend has full support for - RETURNING on parity with other backends. +.. versionchanged:: 2.0 the Oracle backend has full support for RETURNING + on parity with other backends. + ON UPDATE CASCADE ----------------- -Oracle Database doesn't have native ON UPDATE CASCADE functionality. A trigger -based solution is available at -https://web.archive.org/web/20090317041251/https://asktom.oracle.com/tkyte/update_cascade/index.html +Oracle doesn't have native ON UPDATE CASCADE functionality. A trigger based +solution is available at +https://asktom.oracle.com/tkyte/update_cascade/index.html . When using the SQLAlchemy ORM, the ORM has limited ability to manually issue cascading updates - specify ForeignKey objects using the "deferrable=True, initially='deferred'" keyword arguments, and specify "passive_updates=False" on each relationship(). -Oracle Database 8 Compatibility -------------------------------- +Oracle 8 Compatibility +---------------------- -.. warning:: The status of Oracle Database 8 compatibility is not known for - SQLAlchemy 2.0. +.. warning:: The status of Oracle 8 compatibility is not known for SQLAlchemy + 2.0. -When Oracle Database 8 is detected, the dialect internally configures itself to -the following behaviors: +When Oracle 8 is detected, the dialect internally configures itself to the +following behaviors: * the use_ansi flag is set to False. This has the effect of converting all JOIN phrases into the WHERE clause, and in the case of LEFT OUTER JOIN @@ -500,15 +368,14 @@ for tables indicated by synonyms, either in local or remote schemas or accessed over DBLINK, by passing the flag ``oracle_resolve_synonyms=True`` as a keyword argument to the :class:`_schema.Table` construct:: - some_table = Table( - "some_table", autoload_with=some_engine, oracle_resolve_synonyms=True - ) + some_table = Table('some_table', autoload_with=some_engine, + oracle_resolve_synonyms=True) -When this flag is set, the given name (such as ``some_table`` above) will be -searched not just in the ``ALL_TABLES`` view, but also within the +When this flag is set, the given name (such as ``some_table`` above) will +be searched not just in the ``ALL_TABLES`` view, but also within the ``ALL_SYNONYMS`` view to see if this name is actually a synonym to another -name. If the synonym is located and refers to a DBLINK, the Oracle Database -dialects know how to locate the table's information using DBLINK syntax(e.g. +name. If the synonym is located and refers to a DBLINK, the oracle dialect +knows how to locate the table's information using DBLINK syntax(e.g. ``@dblink``). ``oracle_resolve_synonyms`` is accepted wherever reflection arguments are @@ -522,8 +389,8 @@ If synonyms are not in use, this flag should be left disabled. 
Constraint Reflection --------------------- -The Oracle Database dialects can return information about foreign key, unique, -and CHECK constraints, as well as indexes on tables. +The Oracle dialect can return information about foreign key, unique, and +CHECK constraints, as well as indexes on tables. Raw information regarding these constraints can be acquired using :meth:`_reflection.Inspector.get_foreign_keys`, @@ -531,7 +398,7 @@ Raw information regarding these constraints can be acquired using :meth:`_reflection.Inspector.get_check_constraints`, and :meth:`_reflection.Inspector.get_indexes`. -.. versionchanged:: 1.2 The Oracle Database dialect can now reflect UNIQUE and +.. versionchanged:: 1.2 The Oracle dialect can now reflect UNIQUE and CHECK constraints. When using reflection at the :class:`_schema.Table` level, the @@ -541,29 +408,29 @@ will also include these constraints. Note the following caveats: * When using the :meth:`_reflection.Inspector.get_check_constraints` method, - Oracle Database builds a special "IS NOT NULL" constraint for columns that - specify "NOT NULL". This constraint is **not** returned by default; to - include the "IS NOT NULL" constraints, pass the flag ``include_all=True``:: + Oracle + builds a special "IS NOT NULL" constraint for columns that specify + "NOT NULL". This constraint is **not** returned by default; to include + the "IS NOT NULL" constraints, pass the flag ``include_all=True``:: from sqlalchemy import create_engine, inspect - engine = create_engine( - "oracle+oracledb://scott:tiger@localhost:1521?service_name=freepdb1" - ) + engine = create_engine("oracle+cx_oracle://s:t@dsn") inspector = inspect(engine) all_check_constraints = inspector.get_check_constraints( - "some_table", include_all=True - ) + "some_table", include_all=True) -* in most cases, when reflecting a :class:`_schema.Table`, a UNIQUE constraint - will **not** be available as a :class:`.UniqueConstraint` object, as Oracle - Database mirrors unique constraints with a UNIQUE index in most cases (the - exception seems to be when two or more unique constraints represent the same - columns); the :class:`_schema.Table` will instead represent these using - :class:`.Index` with the ``unique=True`` flag set. +* in most cases, when reflecting a :class:`_schema.Table`, + a UNIQUE constraint will + **not** be available as a :class:`.UniqueConstraint` object, as Oracle + mirrors unique constraints with a UNIQUE index in most cases (the exception + seems to be when two or more unique constraints represent the same columns); + the :class:`_schema.Table` will instead represent these using + :class:`.Index` + with the ``unique=True`` flag set. -* Oracle Database creates an implicit index for the primary key of a table; - this index is **excluded** from all index results. +* Oracle creates an implicit index for the primary key of a table; this index + is **excluded** from all index results. * the list of columns reflected for an index will not include column names that start with SYS_NC. @@ -583,112 +450,50 @@ the ``exclude_tablespaces`` parameter:: # exclude SYSAUX and SOME_TABLESPACE, but not SYSTEM e = create_engine( - "oracle+oracledb://scott:tiger@localhost:1521/?service_name=freepdb1", - exclude_tablespaces=["SYSAUX", "SOME_TABLESPACE"], - ) - -.. 
_oracle_float_support: - -FLOAT / DOUBLE Support and Behaviors ------------------------------------- - -The SQLAlchemy :class:`.Float` and :class:`.Double` datatypes are generic -datatypes that resolve to the "least surprising" datatype for a given backend. -For Oracle Database, this means they resolve to the ``FLOAT`` and ``DOUBLE`` -types:: - - >>> from sqlalchemy import cast, literal, Float - >>> from sqlalchemy.dialects import oracle - >>> float_datatype = Float() - >>> print(cast(literal(5.0), float_datatype).compile(dialect=oracle.dialect())) - CAST(:param_1 AS FLOAT) - -Oracle's ``FLOAT`` / ``DOUBLE`` datatypes are aliases for ``NUMBER``. Oracle -Database stores ``NUMBER`` values with full precision, not floating point -precision, which means that ``FLOAT`` / ``DOUBLE`` do not actually behave like -native FP values. Oracle Database instead offers special datatypes -``BINARY_FLOAT`` and ``BINARY_DOUBLE`` to deliver real 4- and 8- byte FP -values. - -SQLAlchemy supports these datatypes directly using :class:`.BINARY_FLOAT` and -:class:`.BINARY_DOUBLE`. To use the :class:`.Float` or :class:`.Double` -datatypes in a database agnostic way, while allowing Oracle backends to utilize -one of these types, use the :meth:`.TypeEngine.with_variant` method to set up a -variant:: - - >>> from sqlalchemy import cast, literal, Float - >>> from sqlalchemy.dialects import oracle - >>> float_datatype = Float().with_variant(oracle.BINARY_FLOAT(), "oracle") - >>> print(cast(literal(5.0), float_datatype).compile(dialect=oracle.dialect())) - CAST(:param_1 AS BINARY_FLOAT) - -E.g. to use this datatype in a :class:`.Table` definition:: - - my_table = Table( - "my_table", - metadata, - Column( - "fp_data", Float().with_variant(oracle.BINARY_FLOAT(), "oracle") - ), - ) + "oracle+cx_oracle://scott:tiger@xe", + exclude_tablespaces=["SYSAUX", "SOME_TABLESPACE"]) DateTime Compatibility ---------------------- -Oracle Database has no datatype known as ``DATETIME``, it instead has only -``DATE``, which can actually store a date and time value. For this reason, the -Oracle Database dialects provide a type :class:`_oracle.DATE` which is a -subclass of :class:`.DateTime`. This type has no special behavior, and is only -present as a "marker" for this type; additionally, when a database column is -reflected and the type is reported as ``DATE``, the time-supporting +Oracle has no datatype known as ``DATETIME``, it instead has only ``DATE``, +which can actually store a date and time value. For this reason, the Oracle +dialect provides a type :class:`_oracle.DATE` which is a subclass of +:class:`.DateTime`. This type has no special behavior, and is only +present as a "marker" for this type; additionally, when a database column +is reflected and the type is reported as ``DATE``, the time-supporting :class:`_oracle.DATE` type is used. .. 
_oracle_table_options: -Oracle Database Table Options ------------------------------ +Oracle Table Options +------------------------- -The CREATE TABLE phrase supports the following options with Oracle Database -dialects in conjunction with the :class:`_schema.Table` construct: +The CREATE TABLE phrase supports the following options with Oracle +in conjunction with the :class:`_schema.Table` construct: * ``ON COMMIT``:: Table( - "some_table", - metadata, - ..., - prefixes=["GLOBAL TEMPORARY"], - oracle_on_commit="PRESERVE ROWS", - ) + "some_table", metadata, ..., + prefixes=['GLOBAL TEMPORARY'], oracle_on_commit='PRESERVE ROWS') -* - ``COMPRESS``:: +* ``COMPRESS``:: - Table( - "mytable", metadata, Column("data", String(32)), oracle_compress=True - ) + Table('mytable', metadata, Column('data', String(32)), + oracle_compress=True) - Table("mytable", metadata, Column("data", String(32)), oracle_compress=6) + Table('mytable', metadata, Column('data', String(32)), + oracle_compress=6) - The ``oracle_compress`` parameter accepts either an integer compression - level, or ``True`` to use the default compression level. - -* - ``TABLESPACE``:: - - Table("mytable", metadata, ..., oracle_tablespace="EXAMPLE_TABLESPACE") - - The ``oracle_tablespace`` parameter specifies the tablespace in which the - table is to be created. This is useful when you want to create a table in a - tablespace other than the default tablespace of the user. - - .. versionadded:: 2.0.37 + The ``oracle_compress`` parameter accepts either an integer compression + level, or ``True`` to use the default compression level. .. _oracle_index_options: -Oracle Database Specific Index Options --------------------------------------- +Oracle Specific Index Options +----------------------------- Bitmap Indexes ~~~~~~~~~~~~~~ @@ -696,7 +501,7 @@ Bitmap Indexes You can specify the ``oracle_bitmap`` parameter to create a bitmap index instead of a B-tree index:: - Index("my_index", my_table.c.data, oracle_bitmap=True) + Index('my_index', my_table.c.data, oracle_bitmap=True) Bitmap indexes cannot be unique and cannot be compressed. SQLAlchemy will not check for such limitations, only the database will. @@ -704,248 +509,24 @@ check for such limitations, only the database will. Index compression ~~~~~~~~~~~~~~~~~ -Oracle Database has a more efficient storage mode for indexes containing lots -of repeated values. Use the ``oracle_compress`` parameter to turn on key +Oracle has a more efficient storage mode for indexes containing lots of +repeated values. Use the ``oracle_compress`` parameter to turn on key compression:: - Index("my_index", my_table.c.data, oracle_compress=True) + Index('my_index', my_table.c.data, oracle_compress=True) - Index( - "my_index", - my_table.c.data1, - my_table.c.data2, - unique=True, - oracle_compress=1, - ) + Index('my_index', my_table.c.data1, my_table.c.data2, unique=True, + oracle_compress=1) The ``oracle_compress`` parameter accepts either an integer specifying the number of prefix columns to compress, or ``True`` to use the default (all columns for non-unique indexes, all but the last column for unique indexes). -.. _oracle_vector_datatype: - -VECTOR Datatype ---------------- - -Oracle Database 23ai introduced a new VECTOR datatype for artificial intelligence -and machine learning search operations. The VECTOR datatype is a homogeneous array -of 8-bit signed integers, 8-bit unsigned integers (binary), 32-bit floating-point -numbers, or 64-bit floating-point numbers. 
- -A vector's storage type can be either DENSE or SPARSE. A dense vector contains -meaningful values in most or all of its dimensions. In contrast, a sparse vector -has non-zero values in only a few dimensions, with the majority being zero. - -Sparse vectors are represented by the total number of vector dimensions, an array -of indices, and an array of values where each value’s location in the vector is -indicated by the corresponding indices array position. All other vector values are -treated as zero. - -The storage formats that can be used with sparse vectors are float32, float64, and -int8. Note that the binary storage format cannot be used with sparse vectors. - -Sparse vectors are supported when you are using Oracle Database 23.7 or later. - -.. seealso:: - - `Using VECTOR Data - `_ - in the documentation - for the :ref:`oracledb` driver. - -.. versionadded:: 2.0.41 - Added VECTOR datatype - -.. versionadded:: 2.0.43 - Added DENSE/SPARSE support - -CREATE TABLE support for VECTOR -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -With the :class:`.VECTOR` datatype, you can specify the number of dimensions, -the storage format, and the storage type for the data. Valid values for the -storage format are enum members of :class:`.VectorStorageFormat`. Valid values -for the storage type are enum members of :class:`.VectorStorageType`. If -storage type is not specified, a DENSE vector is created by default. - -To create a table that includes a :class:`.VECTOR` column:: - - from sqlalchemy.dialects.oracle import ( - VECTOR, - VectorStorageFormat, - VectorStorageType, - ) - - t = Table( - "t1", - metadata, - Column("id", Integer, primary_key=True), - Column( - "embedding", - VECTOR( - dim=3, - storage_format=VectorStorageFormat.FLOAT32, - storage_type=VectorStorageType.SPARSE, - ), - ), - Column(...), - ..., - ) - -Vectors can also be defined with an arbitrary number of dimensions and formats. -This allows you to specify vectors of different dimensions with the various -storage formats mentioned below. - -**Examples** - -* In this case, the storage format is flexible, allowing any vector type data to be - inserted, such as INT8 or BINARY etc:: - - vector_col: Mapped[array.array] = mapped_column(VECTOR(dim=3)) - -* The dimension is flexible in this case, meaning that any dimension vector can - be used:: - - vector_col: Mapped[array.array] = mapped_column( - VECTOR(storage_format=VectorStorageType.INT8) - ) - -* Both the dimensions and the storage format are flexible. It creates a DENSE vector:: - - vector_col: Mapped[array.array] = mapped_column(VECTOR) - -* To create a SPARSE vector with both dimensions and the storage format as flexible, - use the :attr:`.VectorStorageType.SPARSE` storage type:: - - vector_col: Mapped[array.array] = mapped_column( - VECTOR(storage_type=VectorStorageType.SPARSE) - ) - -Python Datatypes for VECTOR -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -VECTOR data can be inserted using Python list or Python ``array.array()`` objects. 
-Python arrays of type FLOAT (32-bit), DOUBLE (64-bit), INT (8-bit signed integers), -or BINARY (8-bit unsigned integers) are used as bind values when inserting -VECTOR columns:: - - from sqlalchemy import insert, select - - with engine.begin() as conn: - conn.execute( - insert(t1), - {"id": 1, "embedding": [1, 2, 3]}, - ) - -Data can be inserted into a sparse vector using the :class:`_oracle.SparseVector` -class, creating an object consisting of the number of dimensions, an array of indices, and a -corresponding array of values:: - - from sqlalchemy import insert, select - from sqlalchemy.dialects.oracle import SparseVector - - sparse_val = SparseVector(10, [1, 2], array.array("d", [23.45, 221.22])) - - with engine.begin() as conn: - conn.execute( - insert(t1), - {"id": 1, "embedding": sparse_val}, - ) - -VECTOR Indexes -~~~~~~~~~~~~~~ - -The VECTOR feature supports an Oracle-specific parameter ``oracle_vector`` -on the :class:`.Index` construct, which allows the construction of VECTOR -indexes. - -SPARSE vectors cannot be used in the creation of vector indexes. - -To utilize VECTOR indexing, set the ``oracle_vector`` parameter to True to use -the default values provided by Oracle. HNSW is the default indexing method:: - - from sqlalchemy import Index - - Index( - "vector_index", - t1.c.embedding, - oracle_vector=True, - ) - -The full range of parameters for vector indexes are available by using the -:class:`.VectorIndexConfig` dataclass in place of a boolean; this dataclass -allows full configuration of the index:: - - Index( - "hnsw_vector_index", - t1.c.embedding, - oracle_vector=VectorIndexConfig( - index_type=VectorIndexType.HNSW, - distance=VectorDistanceType.COSINE, - accuracy=90, - hnsw_neighbors=5, - hnsw_efconstruction=20, - parallel=10, - ), - ) - - Index( - "ivf_vector_index", - t1.c.embedding, - oracle_vector=VectorIndexConfig( - index_type=VectorIndexType.IVF, - distance=VectorDistanceType.DOT, - accuracy=90, - ivf_neighbor_partitions=5, - ), - ) - -For complete explanation of these parameters, see the Oracle documentation linked -below. - -.. seealso:: - - `CREATE VECTOR INDEX `_ - in the Oracle documentation - - - -Similarity Searching -~~~~~~~~~~~~~~~~~~~~ - -When using the :class:`_oracle.VECTOR` datatype with a :class:`.Column` or similar -ORM mapped construct, additional comparison functions are available, including: - -* ``l2_distance`` -* ``cosine_distance`` -* ``inner_product`` - -Example Usage:: - - result_vector = connection.scalars( - select(t1).order_by(t1.embedding.l2_distance([2, 3, 4])).limit(3) - ) - - for user in vector: - print(user.id, user.embedding) - -FETCH APPROXIMATE support -~~~~~~~~~~~~~~~~~~~~~~~~~ - -Approximate vector search can only be performed when all syntax and semantic -rules are satisfied, the corresponding vector index is available, and the -query optimizer determines to perform it. If any of these conditions are -unmet, then an approximate search is not performed. In this case the query -returns exact results. 
- -To enable approximate searching during similarity searches on VECTORS, the -``oracle_fetch_approximate`` parameter may be used with the :meth:`.Select.fetch` -clause to add ``FETCH APPROX`` to the SELECT statement:: - - select(users_table).fetch(5, oracle_fetch_approximate=True) - """ # noqa from __future__ import annotations from collections import defaultdict -from dataclasses import fields from functools import lru_cache from functools import wraps import re @@ -968,9 +549,6 @@ from .types import RAW from .types import ROWID # noqa from .types import TIMESTAMP from .types import VARCHAR2 # noqa -from .vector import VECTOR -from .vector import VectorIndexConfig -from .vector import VectorIndexType from ... import Computed from ... import exc from ... import schema as sa_schema @@ -989,7 +567,6 @@ from ...sql import func from ...sql import null from ...sql import or_ from ...sql import select -from ...sql import selectable as sa_selectable from ...sql import sqltypes from ...sql import util as sql_util from ...sql import visitors @@ -1017,7 +594,7 @@ RESERVED_WORDS = set( ) NO_ARG_FNS = set( - "UID CURRENT_DATE SYSDATE USER CURRENT_TIME CURRENT_TIMESTAMP".split() + "UID CURRENT_DATE SYSDATE USER " "CURRENT_TIME CURRENT_TIMESTAMP".split() ) @@ -1051,7 +628,6 @@ ischema_names = { "BINARY_DOUBLE": BINARY_DOUBLE, "BINARY_FLOAT": BINARY_FLOAT, "ROWID": ROWID, - "VECTOR": VECTOR, } @@ -1132,16 +708,16 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): # https://www.oracletutorial.com/oracle-basics/oracle-float/ estimated_binary_precision = int(precision / 0.30103) raise exc.ArgumentError( - "Oracle Database FLOAT types use 'binary precision', " - "which does not convert cleanly from decimal " - "'precision'. Please specify " - "this type with a separate Oracle Database variant, such " - f"as {type_.__class__.__name__}(precision={precision})." + "Oracle FLOAT types use 'binary precision', which does " + "not convert cleanly from decimal 'precision'. Please " + "specify " + f"this type with a separate Oracle variant, such as " + f"{type_.__class__.__name__}(precision={precision})." f"with_variant(oracle.FLOAT" f"(binary_precision=" f"{estimated_binary_precision}), 'oracle'), so that the " - "Oracle Database specific 'binary_precision' may be " - "specified accurately." + "Oracle specific 'binary_precision' may be specified " + "accurately." 
) else: precision = binary_precision @@ -1209,18 +785,6 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): def visit_ROWID(self, type_, **kw): return "ROWID" - def visit_VECTOR(self, type_, **kw): - dim = type_.dim if type_.dim is not None else "*" - storage_format = ( - type_.storage_format.value - if type_.storage_format is not None - else "*" - ) - storage_type = ( - type_.storage_type.value if type_.storage_type is not None else "*" - ) - return f"VECTOR({dim},{storage_format},{storage_type})" - class OracleCompiler(compiler.SQLCompiler): """Oracle compiler modifies the lexical structure of Select @@ -1275,7 +839,7 @@ class OracleCompiler(compiler.SQLCompiler): def visit_function(self, func, **kw): text = super().visit_function(func, **kw) - if kw.get("asfrom", False) and func.name.lower() != "table": + if kw.get("asfrom", False): text = "TABLE (%s)" % text return text @@ -1382,13 +946,13 @@ class OracleCompiler(compiler.SQLCompiler): and not self.dialect._supports_update_returning_computed_cols ): util.warn( - "Computed columns don't work with Oracle Database UPDATE " + "Computed columns don't work with Oracle UPDATE " "statements that use RETURNING; the value of the column " "*before* the UPDATE takes place is returned. It is " - "advised to not use RETURNING with an Oracle Database " - "computed column. Consider setting implicit_returning " - "to False on the Table object in order to avoid implicit " - "RETURNING clauses from being generated for this Table." + "advised to not use RETURNING with an Oracle computed " + "column. Consider setting implicit_returning to False on " + "the Table object in order to avoid implicit RETURNING " + "clauses from being generated for this Table." ) if column.type._has_column_expression: col_expr = column.type.column_expression(column) @@ -1412,7 +976,7 @@ class OracleCompiler(compiler.SQLCompiler): raise exc.InvalidRequestError( "Using explicit outparam() objects with " "UpdateBase.returning() in the same Core DML statement " - "is not supported in the Oracle Database dialects." + "is not supported in the Oracle dialect." 
) self._oracle_returning = True @@ -1433,7 +997,7 @@ class OracleCompiler(compiler.SQLCompiler): return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds) def _row_limit_clause(self, select, **kw): - """Oracle Database 12c supports OFFSET/FETCH operators + """ORacle 12c supports OFFSET/FETCH operators Use it instead subquery with row_number """ @@ -1459,29 +1023,6 @@ class OracleCompiler(compiler.SQLCompiler): else: return select._fetch_clause - def fetch_clause( - self, - select, - fetch_clause=None, - require_offset=False, - use_literal_execute_for_simple_int=False, - **kw, - ): - text = super().fetch_clause( - select, - fetch_clause=fetch_clause, - require_offset=require_offset, - use_literal_execute_for_simple_int=( - use_literal_execute_for_simple_int - ), - **kw, - ) - - if select.dialect_options["oracle"]["fetch_approximate"]: - text = re.sub("FETCH FIRST", "FETCH APPROX FIRST", text) - - return text - def translate_select_structure(self, select_stmt, **kwargs): select = select_stmt @@ -1703,75 +1244,8 @@ class OracleCompiler(compiler.SQLCompiler): def visit_aggregate_strings_func(self, fn, **kw): return "LISTAGG%s" % self.function_argspec(fn, **kw) - def _visit_bitwise(self, binary, fn_name, custom_right=None, **kw): - left = self.process(binary.left, **kw) - right = self.process( - custom_right if custom_right is not None else binary.right, **kw - ) - return f"{fn_name}({left}, {right})" - - def visit_bitwise_xor_op_binary(self, binary, operator, **kw): - return self._visit_bitwise(binary, "BITXOR", **kw) - - def visit_bitwise_or_op_binary(self, binary, operator, **kw): - return self._visit_bitwise(binary, "BITOR", **kw) - - def visit_bitwise_and_op_binary(self, binary, operator, **kw): - return self._visit_bitwise(binary, "BITAND", **kw) - - def visit_bitwise_rshift_op_binary(self, binary, operator, **kw): - raise exc.CompileError("Cannot compile bitwise_rshift in oracle") - - def visit_bitwise_lshift_op_binary(self, binary, operator, **kw): - raise exc.CompileError("Cannot compile bitwise_lshift in oracle") - - def visit_bitwise_not_op_unary_operator(self, element, operator, **kw): - raise exc.CompileError("Cannot compile bitwise_not in oracle") - class OracleDDLCompiler(compiler.DDLCompiler): - - def _build_vector_index_config( - self, vector_index_config: VectorIndexConfig - ) -> str: - parts = [] - sql_param_name = { - "hnsw_neighbors": "neighbors", - "hnsw_efconstruction": "efconstruction", - "ivf_neighbor_partitions": "neighbor partitions", - "ivf_sample_per_partition": "sample_per_partition", - "ivf_min_vectors_per_partition": "min_vectors_per_partition", - } - if vector_index_config.index_type == VectorIndexType.HNSW: - parts.append("ORGANIZATION INMEMORY NEIGHBOR GRAPH") - elif vector_index_config.index_type == VectorIndexType.IVF: - parts.append("ORGANIZATION NEIGHBOR PARTITIONS") - if vector_index_config.distance is not None: - parts.append(f"DISTANCE {vector_index_config.distance.value}") - - if vector_index_config.accuracy is not None: - parts.append( - f"WITH TARGET ACCURACY {vector_index_config.accuracy}" - ) - - parameters_str = [f"type {vector_index_config.index_type.name}"] - prefix = vector_index_config.index_type.name.lower() + "_" - - for field in fields(vector_index_config): - if field.name.startswith(prefix): - key = sql_param_name.get(field.name) - value = getattr(vector_index_config, field.name) - if value is not None: - parameters_str.append(f"{key} {value}") - - parameters_str = ", ".join(parameters_str) - parts.append(f"PARAMETERS 
({parameters_str})") - - if vector_index_config.parallel is not None: - parts.append(f"PARALLEL {vector_index_config.parallel}") - - return " ".join(parts) - def define_constraint_cascades(self, constraint): text = "" if constraint.ondelete is not None: @@ -1779,10 +1253,10 @@ class OracleDDLCompiler(compiler.DDLCompiler): # oracle has no ON UPDATE CASCADE - # its only available via triggers - # https://web.archive.org/web/20090317041251/https://asktom.oracle.com/tkyte/update_cascade/index.html + # https://asktom.oracle.com/tkyte/update_cascade/index.html if constraint.onupdate is not None: util.warn( - "Oracle Database does not contain native UPDATE CASCADE " + "Oracle does not contain native UPDATE CASCADE " "functionality - onupdates will not be rendered for foreign " "keys. Consider using deferrable=True, initially='deferred' " "or triggers." @@ -1804,9 +1278,6 @@ class OracleDDLCompiler(compiler.DDLCompiler): text += "UNIQUE " if index.dialect_options["oracle"]["bitmap"]: text += "BITMAP " - vector_options = index.dialect_options["oracle"]["vector"] - if vector_options: - text += "VECTOR " text += "INDEX %s ON %s (%s)" % ( self._prepared_index_name(index, include_schema=True), preparer.format_table(index.table, use_schema=True), @@ -1824,11 +1295,6 @@ class OracleDDLCompiler(compiler.DDLCompiler): text += " COMPRESS %d" % ( index.dialect_options["oracle"]["compress"] ) - if vector_options: - if vector_options is True: - vector_options = VectorIndexConfig() - - text += " " + self._build_vector_index_config(vector_options) return text def post_create_table(self, table): @@ -1844,10 +1310,7 @@ class OracleDDLCompiler(compiler.DDLCompiler): table_opts.append("\n COMPRESS") else: table_opts.append("\n COMPRESS FOR %s" % (opts["compress"])) - if opts["tablespace"]: - table_opts.append( - "\n TABLESPACE %s" % self.preparer.quote(opts["tablespace"]) - ) + return "".join(table_opts) def get_identity_options(self, identity_options): @@ -1865,9 +1328,8 @@ class OracleDDLCompiler(compiler.DDLCompiler): ) if generated.persisted is True: raise exc.CompileError( - "Oracle Database computed columns do not support 'stored' " - "persistence; set the 'persisted' flag to None or False for " - "Oracle Database support." + "Oracle computed columns do not support 'stored' persistence; " + "set the 'persisted' flag to None or False for Oracle support." ) elif generated.persisted is False: text += " VIRTUAL" @@ -1972,30 +1434,16 @@ class OracleDialect(default.DefaultDialect): construct_arguments = [ ( sa_schema.Table, - { - "resolve_synonyms": False, - "on_commit": None, - "compress": False, - "tablespace": None, - }, + {"resolve_synonyms": False, "on_commit": None, "compress": False}, ), - ( - sa_schema.Index, - { - "bitmap": False, - "compress": False, - "vector": False, - }, - ), - (sa_selectable.Select, {"fetch_approximate": False}), - (sa_selectable.CompoundSelect, {"fetch_approximate": False}), + (sa_schema.Index, {"bitmap": False, "compress": False}), ] @util.deprecated_params( use_binds_for_limits=( "1.4", - "The ``use_binds_for_limits`` Oracle Database dialect parameter " - "is deprecated. The dialect now renders LIMIT / OFFSET integers " + "The ``use_binds_for_limits`` Oracle dialect parameter is " + "deprecated. 
The dialect now renders LIMIT /OFFSET integers " "inline in all cases using a post-compilation hook, so that the " "value is still represented by a 'bound parameter' on the Core " "Expression side.", @@ -2016,9 +1464,9 @@ class OracleDialect(default.DefaultDialect): self.use_ansi = use_ansi self.optimize_limits = optimize_limits self.exclude_tablespaces = exclude_tablespaces - self.enable_offset_fetch = self._supports_offset_fetch = ( - enable_offset_fetch - ) + self.enable_offset_fetch = ( + self._supports_offset_fetch + ) = enable_offset_fetch def initialize(self, connection): super().initialize(connection) @@ -2588,17 +2036,8 @@ class OracleDialect(default.DefaultDialect): ): query = select( dictionary.all_tables.c.table_name, - ( - dictionary.all_tables.c.compression - if self._supports_table_compression - else sql.null().label("compression") - ), - ( - dictionary.all_tables.c.compress_for - if self._supports_table_compress_for - else sql.null().label("compress_for") - ), - dictionary.all_tables.c.tablespace_name, + dictionary.all_tables.c.compression, + dictionary.all_tables.c.compress_for, ).where(dictionary.all_tables.c.owner == owner) if has_filter_names: query = query.where( @@ -2690,12 +2129,11 @@ class OracleDialect(default.DefaultDialect): connection, query, dblink, returns_long=False, params=params ) - for table, compression, compress_for, tablespace in result: - data = default() + for table, compression, compress_for in result: if compression == "ENABLED": - data["oracle_compress"] = compress_for - if tablespace: - data["oracle_tablespace"] = tablespace + data = {"oracle_compress": compress_for} + else: + data = default() options[(schema, self.normalize_name(table))] = data if ObjectKind.VIEW in kind and ObjectScope.DEFAULT in scope: # add the views (no temporary views) @@ -3085,12 +2523,10 @@ class OracleDialect(default.DefaultDialect): return ( ( (schema, self.normalize_name(table)), - ( - {"text": comment} - if comment is not None - and not comment.startswith(ignore_mat_view) - else default() - ), + {"text": comment} + if comment is not None + and not comment.startswith(ignore_mat_view) + else default(), ) for table, comment in result ) @@ -3632,11 +3068,9 @@ class OracleDialect(default.DefaultDialect): table_uc[constraint_name] = uc = { "name": constraint_name, "column_names": [], - "duplicates_index": ( - constraint_name - if constraint_name_orig in index_names - else None - ), + "duplicates_index": constraint_name + if constraint_name_orig in index_names + else None, } else: uc = table_uc[constraint_name] @@ -3648,11 +3082,9 @@ class OracleDialect(default.DefaultDialect): return ( ( key, - ( - list(unique_cons[key].values()) - if key in unique_cons - else default() - ), + list(unique_cons[key].values()) + if key in unique_cons + else default(), ) for key in ( (schema, self.normalize_name(obj_name)) @@ -3775,11 +3207,9 @@ class OracleDialect(default.DefaultDialect): return ( ( key, - ( - check_constraints[key] - if key in check_constraints - else default() - ), + check_constraints[key] + if key in check_constraints + else default(), ) for key in ( (schema, self.normalize_name(obj_name)) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/cx_oracle.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/cx_oracle.py index 69bb7f3..c595b56 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/cx_oracle.py @@ -1,5 +1,4 @@ -# 
dialects/oracle/cx_oracle.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -7,18 +6,13 @@ # mypy: ignore-errors -r""".. dialect:: oracle+cx_oracle +r""" +.. dialect:: oracle+cx_oracle :name: cx-Oracle :dbapi: cx_oracle :connectstring: oracle+cx_oracle://user:pass@hostname:port[/dbname][?service_name=[&key=value&key=value...]] :url: https://oracle.github.io/python-cx_Oracle/ -Description ------------ - -cx_Oracle was the original driver for Oracle Database. It was superseded by -python-oracledb which should be used instead. - DSN vs. Hostname connections ----------------------------- @@ -28,41 +22,27 @@ dialect translates from a series of different URL forms. Hostname Connections with Easy Connect Syntax ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Given a hostname, port and service name of the target database, for example -from Oracle Database's Easy Connect syntax then connect in SQLAlchemy using the -``service_name`` query string parameter:: +Given a hostname, port and service name of the target Oracle Database, for +example from Oracle's `Easy Connect syntax +`_, +then connect in SQLAlchemy using the ``service_name`` query string parameter:: - engine = create_engine( - "oracle+cx_oracle://scott:tiger@hostname:port?service_name=myservice&encoding=UTF-8&nencoding=UTF-8" - ) + engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:port/?service_name=myservice&encoding=UTF-8&nencoding=UTF-8") -Note that the default driver value for encoding and nencoding was changed to -“UTF-8” in cx_Oracle 8.0 so these parameters can be omitted when using that -version, or later. +The `full Easy Connect syntax +`_ +is not supported. Instead, use a ``tnsnames.ora`` file and connect using a +DSN. -To use a full Easy Connect string, pass it as the ``dsn`` key value in a -:paramref:`_sa.create_engine.connect_args` dictionary:: +Connections with tnsnames.ora or Oracle Cloud +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - import cx_Oracle - - e = create_engine( - "oracle+cx_oracle://@", - connect_args={ - "user": "scott", - "password": "tiger", - "dsn": "hostname:port/myservice?transport_connect_timeout=30&expire_time=60", - }, - ) - -Connections with tnsnames.ora or to Oracle Autonomous Database -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Alternatively, if no port, database name, or service name is provided, the -dialect will use an Oracle Database DSN "connection string". This takes the -"hostname" portion of the URL as the data source name. For example, if the -``tnsnames.ora`` file contains a TNS Alias of ``myalias`` as below: - -.. sourcecode:: text +Alternatively, if no port, database name, or ``service_name`` is provided, the +dialect will use an Oracle DSN "connection string". This takes the "hostname" +portion of the URL as the data source name. For example, if the +``tnsnames.ora`` file contains a `Net Service Name +`_ +of ``myalias`` as below:: myalias = (DESCRIPTION = @@ -77,22 +57,19 @@ The cx_Oracle dialect connects to this database service when ``myalias`` is the hostname portion of the URL, without specifying a port, database name or ``service_name``:: - engine = create_engine("oracle+cx_oracle://scott:tiger@myalias") + engine = create_engine("oracle+cx_oracle://scott:tiger@myalias/?encoding=UTF-8&nencoding=UTF-8") -Users of Oracle Autonomous Database should use this syntax. 
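The URL forms discussed here can also be assembled programmatically, which sidesteps quoting problems with passwords and query parameters. A minimal sketch, assuming placeholder credentials and host names::

    from sqlalchemy import create_engine
    from sqlalchemy.engine import URL

    url = URL.create(
        "oracle+cx_oracle",
        username="scott",
        password="tiger",
        host="dbhost.example.com",
        port=1521,
        query={"service_name": "myservice"},
    )
    # renders roughly:
    # oracle+cx_oracle://scott:tiger@dbhost.example.com:1521?service_name=myservice
    engine = create_engine(url)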
If the database is -configured for mutural TLS ("mTLS"), then you must also configure the cloud +Users of Oracle Cloud should use this syntax and also configure the cloud wallet as shown in cx_Oracle documentation `Connecting to Autononmous Databases -`_. +`_. SID Connections ^^^^^^^^^^^^^^^ -To use Oracle Database's obsolete System Identifier connection syntax, the SID -can be passed in a "database name" portion of the URL:: +To use Oracle's obsolete SID connection syntax, the SID can be passed in a +"database name" portion of the URL as below:: - engine = create_engine( - "oracle+cx_oracle://scott:tiger@hostname:port/dbname" - ) + engine = create_engine("oracle+cx_oracle://scott:tiger@hostname:1521/dbname?encoding=UTF-8&nencoding=UTF-8") Above, the DSN passed to cx_Oracle is created by ``cx_Oracle.makedsn()`` as follows:: @@ -101,23 +78,17 @@ follows:: >>> cx_Oracle.makedsn("hostname", 1521, sid="dbname") '(DESCRIPTION=(ADDRESS=(PROTOCOL=TCP)(HOST=hostname)(PORT=1521))(CONNECT_DATA=(SID=dbname)))' -Note that although the SQLAlchemy syntax ``hostname:port/dbname`` looks like -Oracle's Easy Connect syntax it is different. It uses a SID in place of the -service name required by Easy Connect. The Easy Connect syntax does not -support SIDs. - Passing cx_Oracle connect arguments ----------------------------------- -Additional connection arguments can usually be passed via the URL query string; -particular symbols like ``SYSDBA`` are intercepted and converted to the correct -symbol:: +Additional connection arguments can usually be passed via the URL +query string; particular symbols like ``cx_Oracle.SYSDBA`` are intercepted +and converted to the correct symbol:: e = create_engine( - "oracle+cx_oracle://user:pass@dsn?encoding=UTF-8&nencoding=UTF-8&mode=SYSDBA&events=true" - ) + "oracle+cx_oracle://user:pass@dsn?encoding=UTF-8&nencoding=UTF-8&mode=SYSDBA&events=true") -.. versionchanged:: 1.3 the cx_Oracle dialect now accepts all argument names +.. versionchanged:: 1.3 the cx_oracle dialect now accepts all argument names within the URL string itself, to be passed to the cx_Oracle DBAPI. As was the case earlier but not correctly documented, the :paramref:`_sa.create_engine.connect_args` parameter also accepts all @@ -128,20 +99,19 @@ string, use the :paramref:`_sa.create_engine.connect_args` dictionary. Any cx_Oracle parameter value and/or constant may be passed, such as:: import cx_Oracle - e = create_engine( "oracle+cx_oracle://user:pass@dsn", connect_args={ "encoding": "UTF-8", "nencoding": "UTF-8", "mode": cx_Oracle.SYSDBA, - "events": True, - }, + "events": True + } ) -Note that the default driver value for ``encoding`` and ``nencoding`` was -changed to "UTF-8" in cx_Oracle 8.0 so these parameters can be omitted when -using that version, or later. +Note that the default value for ``encoding`` and ``nencoding`` was changed to +"UTF-8" in cx_Oracle 8.0 so these parameters can be omitted when using that +version, or later. Options consumed by the SQLAlchemy cx_Oracle dialect outside of the driver -------------------------------------------------------------------------- @@ -151,19 +121,14 @@ itself. 
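To complement the options listed just below, a hedged example of passing several of them at :func:`_sa.create_engine` time; the DSN is a placeholder and the values are illustrative, not recommendations::

    from sqlalchemy import create_engine

    engine = create_engine(
        "oracle+cx_oracle://scott:tiger@tnsalias",
        arraysize=1000,             # cursor.arraysize used for fetches
        auto_convert_lobs=True,     # deliver LOBs as plain str/bytes
        coerce_to_decimal=False,    # plain-string SQL numerics come back as float
        encoding_errors="replace",  # lenient decoding of badly encoded data
    )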
These options are always passed directly to :func:`_sa.create_engine` , such as:: e = create_engine( - "oracle+cx_oracle://user:pass@dsn", coerce_to_decimal=False - ) + "oracle+cx_oracle://user:pass@dsn", coerce_to_decimal=False) The parameters accepted by the cx_oracle dialect are as follows: -* ``arraysize`` - set the cx_oracle.arraysize value on cursors; defaults - to ``None``, indicating that the driver default should be used (typically - the value is 100). This setting controls how many rows are buffered when - fetching rows, and can have a significant effect on performance when - modified. - - .. versionchanged:: 2.0.26 - changed the default value from 50 to None, - to use the default value of the driver itself. +* ``arraysize`` - set the cx_oracle.arraysize value on cursors, defaulted + to 50. This setting is significant with cx_Oracle as the contents of LOB + objects are only readable within a "live" row (e.g. within a batch of + 50 rows). * ``auto_convert_lobs`` - defaults to True; See :ref:`cx_oracle_lob`. @@ -176,16 +141,10 @@ The parameters accepted by the cx_oracle dialect are as follows: Using cx_Oracle SessionPool --------------------------- -The cx_Oracle driver provides its own connection pool implementation that may -be used in place of SQLAlchemy's pooling functionality. The driver pool -supports Oracle Database features such dead connection detection, connection -draining for planned database downtime, support for Oracle Application -Continuity and Transparent Application Continuity, and gives support for -Database Resident Connection Pooling (DRCP). - -Using the driver pool can be achieved by using the -:paramref:`_sa.create_engine.creator` parameter to provide a function that -returns a new connection, along with setting +The cx_Oracle library provides its own connection pool implementation that may +be used in place of SQLAlchemy's pooling functionality. This can be achieved +by using the :paramref:`_sa.create_engine.creator` parameter to provide a +function that returns a new connection, along with setting :paramref:`_sa.create_engine.pool_class` to ``NullPool`` to disable SQLAlchemy's pooling:: @@ -194,41 +153,32 @@ SQLAlchemy's pooling:: from sqlalchemy.pool import NullPool pool = cx_Oracle.SessionPool( - user="scott", - password="tiger", - dsn="orclpdb", - min=1, - max=4, - increment=1, - threaded=True, - encoding="UTF-8", - nencoding="UTF-8", + user="scott", password="tiger", dsn="orclpdb", + min=2, max=5, increment=1, threaded=True, + encoding="UTF-8", nencoding="UTF-8" ) - engine = create_engine( - "oracle+cx_oracle://", creator=pool.acquire, poolclass=NullPool - ) + engine = create_engine("oracle+cx_oracle://", creator=pool.acquire, poolclass=NullPool) The above engine may then be used normally where cx_Oracle's pool handles connection pooling:: with engine.connect() as conn: - print(conn.scalar("select 1 from dual")) + print(conn.scalar("select 1 FROM dual")) + As well as providing a scalable solution for multi-user applications, the cx_Oracle session pool supports some Oracle features such as DRCP and `Application Continuity `_. -Note that the pool creation parameters ``threaded``, ``encoding`` and -``nencoding`` were deprecated in later cx_Oracle releases. - Using Oracle Database Resident Connection Pooling (DRCP) -------------------------------------------------------- -When using Oracle Database's DRCP, the best practice is to pass a connection -class and "purity" when acquiring a connection from the SessionPool. 
Refer to -the `cx_Oracle DRCP documentation +When using Oracle's `DRCP +`_, +the best practice is to pass a connection class and "purity" when acquiring a +connection from the SessionPool. Refer to the `cx_Oracle DRCP documentation `_. This can be achieved by wrapping ``pool.acquire()``:: @@ -238,33 +188,21 @@ This can be achieved by wrapping ``pool.acquire()``:: from sqlalchemy.pool import NullPool pool = cx_Oracle.SessionPool( - user="scott", - password="tiger", - dsn="orclpdb", - min=2, - max=5, - increment=1, - threaded=True, - encoding="UTF-8", - nencoding="UTF-8", + user="scott", password="tiger", dsn="orclpdb", + min=2, max=5, increment=1, threaded=True, + encoding="UTF-8", nencoding="UTF-8" ) - def creator(): - return pool.acquire( - cclass="MYCLASS", purity=cx_Oracle.ATTR_PURITY_SELF - ) + return pool.acquire(cclass="MYCLASS", purity=cx_Oracle.ATTR_PURITY_SELF) - - engine = create_engine( - "oracle+cx_oracle://", creator=creator, poolclass=NullPool - ) + engine = create_engine("oracle+cx_oracle://", creator=creator, poolclass=NullPool) The above engine may then be used normally where cx_Oracle handles session pooling and Oracle Database additionally uses DRCP:: with engine.connect() as conn: - print(conn.scalar("select 1 from dual")) + print(conn.scalar("select 1 FROM dual")) .. _cx_oracle_unicode: @@ -272,28 +210,24 @@ Unicode ------- As is the case for all DBAPIs under Python 3, all strings are inherently -Unicode strings. In all cases however, the driver requires an explicit +Unicode strings. In all cases however, the driver requires an explicit encoding configuration. Ensuring the Correct Client Encoding ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The long accepted standard for establishing client encoding for nearly all -Oracle Database related software is via the `NLS_LANG -`_ environment -variable. Older versions of cx_Oracle use this environment variable as the -source of its encoding configuration. The format of this variable is -Territory_Country.CharacterSet; a typical value would be -``AMERICAN_AMERICA.AL32UTF8``. cx_Oracle version 8 and later use the character -set "UTF-8" by default, and ignore the character set component of NLS_LANG. +Oracle related software is via the `NLS_LANG `_ +environment variable. cx_Oracle like most other Oracle drivers will use +this environment variable as the source of its encoding configuration. The +format of this variable is idiosyncratic; a typical value would be +``AMERICAN_AMERICA.AL32UTF8``. -The cx_Oracle driver also supported a programmatic alternative which is to pass -the ``encoding`` and ``nencoding`` parameters directly to its ``.connect()`` -function. These can be present in the URL as follows:: +The cx_Oracle driver also supports a programmatic alternative which is to +pass the ``encoding`` and ``nencoding`` parameters directly to its +``.connect()`` function. These can be present in the URL as follows:: - engine = create_engine( - "oracle+cx_oracle://scott:tiger@tnsalias?encoding=UTF-8&nencoding=UTF-8" - ) + engine = create_engine("oracle+cx_oracle://scott:tiger@orclpdb/?encoding=UTF-8&nencoding=UTF-8") For the meaning of the ``encoding`` and ``nencoding`` parameters, please consult @@ -308,24 +242,25 @@ consult Unicode-specific Column datatypes ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The Core expression language handles unicode data by use of the -:class:`.Unicode` and :class:`.UnicodeText` datatypes. These types correspond -to the VARCHAR2 and CLOB Oracle Database datatypes by default. 
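A compile-only sketch of the ``use_nchar_for_unicode=True`` behaviour described here, using illustrative table and column names; no database connection is required::

    from sqlalchemy import Column, Integer, MetaData, Table, Unicode
    from sqlalchemy.dialects import oracle
    from sqlalchemy.schema import CreateTable

    notes = Table(
        "notes",
        MetaData(),
        Column("id", Integer, primary_key=True),
        Column("body", Unicode(200)),
    )

    # default: Unicode renders as VARCHAR2
    print(CreateTable(notes).compile(dialect=oracle.dialect()))

    # with the flag, the same column renders as NVARCHAR2 instead
    print(
        CreateTable(notes).compile(
            dialect=oracle.dialect(use_nchar_for_unicode=True)
        )
    )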
When using -these datatypes with Unicode data, it is expected that the database is -configured with a Unicode-aware character set, as well as that the ``NLS_LANG`` -environment variable is set appropriately (this applies to older versions of -cx_Oracle), so that the VARCHAR2 and CLOB datatypes can accommodate the data. +The Core expression language handles unicode data by use of the :class:`.Unicode` +and :class:`.UnicodeText` +datatypes. These types correspond to the VARCHAR2 and CLOB Oracle datatypes by +default. When using these datatypes with Unicode data, it is expected that +the Oracle database is configured with a Unicode-aware character set, as well +as that the ``NLS_LANG`` environment variable is set appropriately, so that +the VARCHAR2 and CLOB datatypes can accommodate the data. -In the case that Oracle Database is not configured with a Unicode character +In the case that the Oracle database is not configured with a Unicode character set, the two options are to use the :class:`_types.NCHAR` and :class:`_oracle.NCLOB` datatypes explicitly, or to pass the flag -``use_nchar_for_unicode=True`` to :func:`_sa.create_engine`, which will cause -the SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` / +``use_nchar_for_unicode=True`` to :func:`_sa.create_engine`, +which will cause the +SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` / :class:`.UnicodeText` datatypes instead of VARCHAR/CLOB. -.. versionchanged:: 1.3 The :class:`.Unicode` and :class:`.UnicodeText` - datatypes now correspond to the ``VARCHAR2`` and ``CLOB`` Oracle Database - datatypes unless the ``use_nchar_for_unicode=True`` is passed to the dialect +.. versionchanged:: 1.3 The :class:`.Unicode` and :class:`.UnicodeText` + datatypes now correspond to the ``VARCHAR2`` and ``CLOB`` Oracle datatypes + unless the ``use_nchar_for_unicode=True`` is passed to the dialect when :func:`_sa.create_engine` is called. @@ -334,7 +269,7 @@ the SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` / Encoding Errors ^^^^^^^^^^^^^^^ -For the unusual case that data in Oracle Database is present with a broken +For the unusual case that data in the Oracle database is present with a broken encoding, the dialect accepts a parameter ``encoding_errors`` which will be passed to Unicode decoding functions in order to affect how decoding errors are handled. The value is ultimately consumed by the Python `decode @@ -352,13 +287,13 @@ Fine grained control over cx_Oracle data binding performance with setinputsizes ------------------------------------------------------------------------------- The cx_Oracle DBAPI has a deep and fundamental reliance upon the usage of the -DBAPI ``setinputsizes()`` call. The purpose of this call is to establish the +DBAPI ``setinputsizes()`` call. The purpose of this call is to establish the datatypes that are bound to a SQL statement for Python values being passed as parameters. 
While virtually no other DBAPI assigns any use to the ``setinputsizes()`` call, the cx_Oracle DBAPI relies upon it heavily in its -interactions with the Oracle Database client interface, and in some scenarios -it is not possible for SQLAlchemy to know exactly how data should be bound, as -some settings can cause profoundly different performance characteristics, while +interactions with the Oracle client interface, and in some scenarios it is not +possible for SQLAlchemy to know exactly how data should be bound, as some +settings can cause profoundly different performance characteristics, while altering the type coercion behavior at the same time. Users of the cx_Oracle dialect are **strongly encouraged** to read through @@ -387,16 +322,13 @@ objects which have a ``.key`` and a ``.type`` attribute:: engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe") - @event.listens_for(engine, "do_setinputsizes") def _log_setinputsizes(inputsizes, cursor, statement, parameters, context): for bindparam, dbapitype in inputsizes.items(): - log.info( - "Bound parameter name: %s SQLAlchemy type: %r DBAPI object: %s", - bindparam.key, - bindparam.type, - dbapitype, - ) + log.info( + "Bound parameter name: %s SQLAlchemy type: %r " + "DBAPI object: %s", + bindparam.key, bindparam.type, dbapitype) Example 2 - remove all bindings to CLOB ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -410,28 +342,12 @@ series. This setting can be modified as follows:: engine = create_engine("oracle+cx_oracle://scott:tiger@host/xe") - @event.listens_for(engine, "do_setinputsizes") def _remove_clob(inputsizes, cursor, statement, parameters, context): for bindparam, dbapitype in list(inputsizes.items()): if dbapitype is CLOB: del inputsizes[bindparam] -.. _cx_oracle_lob: - -LOB Datatypes --------------- - -LOB datatypes refer to the "large object" datatypes such as CLOB, NCLOB and -BLOB. Modern versions of cx_Oracle is optimized for these datatypes to be -delivered as a single buffer. As such, SQLAlchemy makes use of these newer type -handlers by default. - -To disable the use of newer type handlers and deliver LOB objects as classic -buffered objects with a ``read()`` method, the parameter -``auto_convert_lobs=False`` may be passed to :func:`_sa.create_engine`, -which takes place only engine-wide. - .. _cx_oracle_returning: RETURNING Support @@ -440,12 +356,29 @@ RETURNING Support The cx_Oracle dialect implements RETURNING using OUT parameters. The dialect supports RETURNING fully. -Two Phase Transactions Not Supported ------------------------------------- +.. _cx_oracle_lob: -Two phase transactions are **not supported** under cx_Oracle due to poor driver -support. The newer :ref:`oracledb` dialect however **does** support two phase -transactions. +LOB Datatypes +-------------- + +LOB datatypes refer to the "large object" datatypes such as CLOB, NCLOB and +BLOB. Modern versions of cx_Oracle and oracledb are optimized for these +datatypes to be delivered as a single buffer. As such, SQLAlchemy makes use of +these newer type handlers by default. + +To disable the use of newer type handlers and deliver LOB objects as classic +buffered objects with a ``read()`` method, the parameter +``auto_convert_lobs=False`` may be passed to :func:`_sa.create_engine`, +which takes place only engine-wide. + +Two Phase Transactions Not Supported +------------------------------------- + +Two phase transactions are **not supported** under cx_Oracle due to poor +driver support. 
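A hedged sketch of the ``auto_convert_lobs=False`` behaviour covered in the LOB section above; the DSN and the ``documents`` table are placeholders and a reachable database is assumed::

    from sqlalchemy import create_engine

    engine = create_engine(
        "oracle+cx_oracle://scott:tiger@tnsalias",
        auto_convert_lobs=False,   # keep classic buffered LOB objects
    )

    with engine.connect() as conn:
        lob = conn.exec_driver_sql(
            "SELECT body FROM documents WHERE id = 1"
        ).scalar()
        text_value = lob.read()    # LOB exposes read() instead of a plain str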
As of cx_Oracle 6.0b1, the interface for +two phase transactions has been changed to be more of a direct pass-through +to the underlying OCI layer with less automation. The additional logic +to support this system is not implemented in SQLAlchemy. .. _cx_oracle_numeric: @@ -456,21 +389,20 @@ SQLAlchemy's numeric types can handle receiving and returning values as Python ``Decimal`` objects or float objects. When a :class:`.Numeric` object, or a subclass such as :class:`.Float`, :class:`_oracle.DOUBLE_PRECISION` etc. is in use, the :paramref:`.Numeric.asdecimal` flag determines if values should be -coerced to ``Decimal`` upon return, or returned as float objects. To make -matters more complicated under Oracle Database, the ``NUMBER`` type can also -represent integer values if the "scale" is zero, so the Oracle -Database-specific :class:`_oracle.NUMBER` type takes this into account as well. +coerced to ``Decimal`` upon return, or returned as float objects. To make +matters more complicated under Oracle, Oracle's ``NUMBER`` type can also +represent integer values if the "scale" is zero, so the Oracle-specific +:class:`_oracle.NUMBER` type takes this into account as well. The cx_Oracle dialect makes extensive use of connection- and cursor-level "outputtypehandler" callables in order to coerce numeric values as requested. These callables are specific to the specific flavor of :class:`.Numeric` in -use, as well as if no SQLAlchemy typing objects are present. There are -observed scenarios where Oracle Database may send incomplete or ambiguous -information about the numeric types being returned, such as a query where the -numeric types are buried under multiple levels of subquery. The type handlers -do their best to make the right decision in all cases, deferring to the -underlying cx_Oracle DBAPI for all those cases where the driver can make the -best decision. +use, as well as if no SQLAlchemy typing objects are present. There are +observed scenarios where Oracle may sends incomplete or ambiguous information +about the numeric types being returned, such as a query where the numeric types +are buried under multiple levels of subquery. The type handlers do their best +to make the right decision in all cases, deferring to the underlying cx_Oracle +DBAPI for all those cases where the driver can make the best decision. 
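As a concrete illustration of the paragraph that follows, a plain-string query returns ``Decimal`` under the default output type handler and ``float`` once ``coerce_to_decimal=False`` is passed; the DSN is a placeholder and a live database is assumed::

    from sqlalchemy import create_engine

    default_engine = create_engine("oracle+cx_oracle://scott:tiger@tnsalias")
    with default_engine.connect() as conn:
        print(type(conn.exec_driver_sql("SELECT 10/3 FROM dual").scalar()))
        # expected: <class 'decimal.Decimal'>

    float_engine = create_engine(
        "oracle+cx_oracle://scott:tiger@tnsalias", coerce_to_decimal=False
    )
    with float_engine.connect() as conn:
        print(type(conn.exec_driver_sql("SELECT 10/3 FROM dual").scalar()))
        # expected: <class 'float'>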
When no typing objects are present, as when executing plain SQL strings, a default "outputtypehandler" is present which will generally return numeric @@ -882,8 +814,6 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): out_parameters[name] = self.cursor.var( dbtype, - # this is fine also in oracledb_async since - # the driver will await the read coroutine outconverter=lambda value: value.read(), arraysize=len_params, ) @@ -902,9 +832,9 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): ) for param in self.parameters: - param[quoted_bind_names.get(name, name)] = ( - out_parameters[name] - ) + param[ + quoted_bind_names.get(name, name) + ] = out_parameters[name] def _generate_cursor_outputtype_handler(self): output_handlers = {} @@ -1100,7 +1030,7 @@ class OracleDialect_cx_oracle(OracleDialect): self, auto_convert_lobs=True, coerce_to_decimal=True, - arraysize=None, + arraysize=50, encoding_errors=None, threaded=None, **kwargs, @@ -1234,9 +1164,6 @@ class OracleDialect_cx_oracle(OracleDialect): with dbapi_connection.cursor() as cursor: cursor.execute(f"ALTER SESSION SET ISOLATION_LEVEL={level}") - def detect_autocommit_setting(self, dbapi_conn) -> bool: - return bool(dbapi_conn.autocommit) - def _detect_decimal_char(self, connection): # we have the option to change this setting upon connect, # or just look at what it is upon connect and convert. @@ -1356,13 +1283,8 @@ class OracleDialect_cx_oracle(OracleDialect): cx_Oracle.CLOB, cx_Oracle.NCLOB, ): - typ = ( - cx_Oracle.DB_TYPE_VARCHAR - if default_type is cx_Oracle.CLOB - else cx_Oracle.DB_TYPE_NVARCHAR - ) return cursor.var( - typ, + cx_Oracle.DB_TYPE_NVARCHAR, _CX_ORACLE_MAGIC_LOB_SIZE, cursor.arraysize, **dialect._cursor_var_unicode_kwargs, @@ -1493,6 +1415,13 @@ class OracleDialect_cx_oracle(OracleDialect): return False def create_xid(self): + """create a two-phase transaction ID. + + this id will be passed to do_begin_twophase(), do_rollback_twophase(), + do_commit_twophase(). its format is unspecified. + + """ + id_ = random.randint(0, 2**128) return (0x1234, "%032x" % id_, "%032x" % 9) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/dictionary.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/dictionary.py index f785a66..fdf47ef 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/dictionary.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/dictionary.py @@ -1,5 +1,4 @@ -# dialects/oracle/dictionary.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/oracledb.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/oracledb.py index c09d2ba..7defbc9 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/oracledb.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/oracledb.py @@ -1,639 +1,68 @@ -# dialects/oracle/oracledb.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors -r""".. dialect:: oracle+oracledb +r""" +.. 
dialect:: oracle+oracledb :name: python-oracledb :dbapi: oracledb :connectstring: oracle+oracledb://user:pass@hostname:port[/dbname][?service_name=[&key=value&key=value...]] :url: https://oracle.github.io/python-oracledb/ -Description ------------ +python-oracledb is released by Oracle to supersede the cx_Oracle driver. +It is fully compatible with cx_Oracle and features both a "thin" client +mode that requires no dependencies, as well as a "thick" mode that uses +the Oracle Client Interface in the same way as cx_Oracle. -Python-oracledb is the Oracle Database driver for Python. It features a default -"thin" client mode that requires no dependencies, and an optional "thick" mode -that uses Oracle Client libraries. It supports SQLAlchemy features including -two phase transactions and Asyncio. +.. seealso:: -Python-oracle is the renamed, updated cx_Oracle driver. Oracle is no longer -doing any releases in the cx_Oracle namespace. - -The SQLAlchemy ``oracledb`` dialect provides both a sync and an async -implementation under the same dialect name. The proper version is -selected depending on how the engine is created: - -* calling :func:`_sa.create_engine` with ``oracle+oracledb://...`` will - automatically select the sync version:: - - from sqlalchemy import create_engine - - sync_engine = create_engine( - "oracle+oracledb://scott:tiger@localhost?service_name=FREEPDB1" - ) - -* calling :func:`_asyncio.create_async_engine` with ``oracle+oracledb://...`` - will automatically select the async version:: - - from sqlalchemy.ext.asyncio import create_async_engine - - asyncio_engine = create_async_engine( - "oracle+oracledb://scott:tiger@localhost?service_name=FREEPDB1" - ) - - The asyncio version of the dialect may also be specified explicitly using the - ``oracledb_async`` suffix:: - - from sqlalchemy.ext.asyncio import create_async_engine - - asyncio_engine = create_async_engine( - "oracle+oracledb_async://scott:tiger@localhost?service_name=FREEPDB1" - ) - -.. versionadded:: 2.0.25 added support for the async version of oracledb. + :ref:`cx_oracle` - all of cx_Oracle's notes apply to the oracledb driver + as well. Thick mode support ------------------ -By default, the python-oracledb driver runs in a "thin" mode that does not -require Oracle Client libraries to be installed. The driver also supports a -"thick" mode that uses Oracle Client libraries to get functionality such as -Oracle Application Continuity. +By default the ``python-oracledb`` is started in thin mode, that does not +require oracle client libraries to be installed in the system. The +``python-oracledb`` driver also support a "thick" mode, that behaves +similarly to ``cx_oracle`` and requires that Oracle Client Interface (OCI) +is installed. -To enable thick mode, call `oracledb.init_oracle_client() -`_ -explicitly, or pass the parameter ``thick_mode=True`` to -:func:`_sa.create_engine`. To pass custom arguments to -``init_oracle_client()``, like the ``lib_dir`` path, a dict may be passed, for -example:: +To enable this mode, the user may call ``oracledb.init_oracle_client`` +manually, or by passing the parameter ``thick_mode=True`` to +:func:`_sa.create_engine`. 
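For completeness, the sync/async engine selection described in the removed text above looks roughly like this in application code; the host and service name are placeholders, and the ``oracledb_async`` form assumes a SQLAlchemy/python-oracledb combination recent enough to ship the async dialect::

    from sqlalchemy import create_engine, text
    from sqlalchemy.ext.asyncio import create_async_engine

    sync_engine = create_engine(
        "oracle+oracledb://scott:tiger@dbhost?service_name=freepdb1"
    )

    async_engine = create_async_engine(
        "oracle+oracledb_async://scott:tiger@dbhost?service_name=freepdb1"
    )

    async def ping():
        # simple round trip using the async connection API
        async with async_engine.connect() as conn:
            return await conn.scalar(text("SELECT 1 FROM dual"))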
To pass custom arguments to ``init_oracle_client``, +like the ``lib_dir`` path, a dict may be passed to this parameter, as in:: - engine = sa.create_engine( - "oracle+oracledb://...", - thick_mode={ - "lib_dir": "/path/to/oracle/client/lib", - "config_dir": "/path/to/network_config_file_directory", - "driver_name": "my-app : 1.0.0", - }, - ) - -Note that passing a ``lib_dir`` path should only be done on macOS or -Windows. On Linux it does not behave as you might expect. + engine = sa.create_engine("oracle+oracledb://...", thick_mode={ + "lib_dir": "/path/to/oracle/client/lib", "driver_name": "my-app" + }) .. seealso:: - python-oracledb documentation `Enabling python-oracledb Thick mode - `_ + https://python-oracledb.readthedocs.io/en/latest/api_manual/module.html#oracledb.init_oracle_client -Connecting to Oracle Database ------------------------------ -python-oracledb provides several methods of indicating the target database. -The dialect translates from a series of different URL forms. - -Given the hostname, port and service name of the target database, you can -connect in SQLAlchemy using the ``service_name`` query string parameter:: - - engine = create_engine( - "oracle+oracledb://scott:tiger@hostname:port?service_name=myservice" - ) - -Connecting with Easy Connect strings -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -You can pass any valid python-oracledb connection string as the ``dsn`` key -value in a :paramref:`_sa.create_engine.connect_args` dictionary. See -python-oracledb documentation `Oracle Net Services Connection Strings -`_. - -For example to use an `Easy Connect string -`_ -with a timeout to prevent connection establishment from hanging if the network -transport to the database cannot be establishd in 30 seconds, and also setting -a keep-alive time of 60 seconds to stop idle network connections from being -terminated by a firewall:: - - e = create_engine( - "oracle+oracledb://@", - connect_args={ - "user": "scott", - "password": "tiger", - "dsn": "hostname:port/myservice?transport_connect_timeout=30&expire_time=60", - }, - ) - -The Easy Connect syntax has been enhanced during the life of Oracle Database. -Review the documentation for your database version. The current documentation -is at `Understanding the Easy Connect Naming Method -`_. - -The general syntax is similar to: - -.. sourcecode:: text - - [[protocol:]//]host[:port][/[service_name]][?parameter_name=value{¶meter_name=value}] - -Note that although the SQLAlchemy URL syntax ``hostname:port/dbname`` looks -like Oracle's Easy Connect syntax, it is different. SQLAlchemy's URL requires a -system identifier (SID) for the ``dbname`` component:: - - engine = create_engine("oracle+oracledb://scott:tiger@hostname:port/sid") - -Easy Connect syntax does not support SIDs. It uses services names, which are -the preferred choice for connecting to Oracle Database. - -Passing python-oracledb connect arguments -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Other python-oracledb driver `connection options -`_ -can be passed in ``connect_args``. For example:: - - e = create_engine( - "oracle+oracledb://@", - connect_args={ - "user": "scott", - "password": "tiger", - "dsn": "hostname:port/myservice", - "events": True, - "mode": oracledb.AUTH_MODE_SYSDBA, - }, - ) - -Connecting with tnsnames.ora TNS aliases -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -If no port, database name, or service name is provided, the dialect will use an -Oracle Database DSN "connection string". 
This takes the "hostname" portion of -the URL as the data source name. For example, if the ``tnsnames.ora`` file -contains a `TNS Alias -`_ -of ``myalias`` as below: - -.. sourcecode:: text - - myalias = - (DESCRIPTION = - (ADDRESS = (PROTOCOL = TCP)(HOST = mymachine.example.com)(PORT = 1521)) - (CONNECT_DATA = - (SERVER = DEDICATED) - (SERVICE_NAME = orclpdb1) - ) - ) - -The python-oracledb dialect connects to this database service when ``myalias`` is the -hostname portion of the URL, without specifying a port, database name or -``service_name``:: - - engine = create_engine("oracle+oracledb://scott:tiger@myalias") - -Connecting to Oracle Autonomous Database -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Users of Oracle Autonomous Database should use either use the TNS Alias URL -shown above, or pass the TNS Alias as the ``dsn`` key value in a -:paramref:`_sa.create_engine.connect_args` dictionary. - -If Oracle Autonomous Database is configured for mutual TLS ("mTLS") -connections, then additional configuration is required as shown in `Connecting -to Oracle Cloud Autonomous Databases -`_. In -summary, Thick mode users should configure file locations and set the wallet -path in ``sqlnet.ora`` appropriately:: - - e = create_engine( - "oracle+oracledb://@", - thick_mode={ - # directory containing tnsnames.ora and cwallet.so - "config_dir": "/opt/oracle/wallet_dir", - }, - connect_args={ - "user": "scott", - "password": "tiger", - "dsn": "mydb_high", - }, - ) - -Thin mode users of mTLS should pass the appropriate directories and PEM wallet -password when creating the engine, similar to:: - - e = create_engine( - "oracle+oracledb://@", - connect_args={ - "user": "scott", - "password": "tiger", - "dsn": "mydb_high", - "config_dir": "/opt/oracle/wallet_dir", # directory containing tnsnames.ora - "wallet_location": "/opt/oracle/wallet_dir", # directory containing ewallet.pem - "wallet_password": "top secret", # password for the PEM file - }, - ) - -Typically ``config_dir`` and ``wallet_location`` are the same directory, which -is where the Oracle Autonomous Database wallet zip file was extracted. Note -this directory should be protected. - -Connection Pooling ------------------- - -Applications with multiple concurrent users should use connection pooling. A -minimal sized connection pool is also beneficial for long-running, single-user -applications that do not frequently use a connection. - -The python-oracledb driver provides its own connection pool implementation that -may be used in place of SQLAlchemy's pooling functionality. The driver pool -gives support for high availability features such as dead connection detection, -connection draining for planned database downtime, support for Oracle -Application Continuity and Transparent Application Continuity, and gives -support for `Database Resident Connection Pooling (DRCP) -`_. - -To take advantage of python-oracledb's pool, use the -:paramref:`_sa.create_engine.creator` parameter to provide a function that -returns a new connection, along with setting -:paramref:`_sa.create_engine.pool_class` to ``NullPool`` to disable -SQLAlchemy's pooling:: - - import oracledb - from sqlalchemy import create_engine - from sqlalchemy import text - from sqlalchemy.pool import NullPool - - # Uncomment to use the optional python-oracledb Thick mode. 
- # Review the python-oracledb doc for the appropriate parameters - # oracledb.init_oracle_client() - - pool = oracledb.create_pool( - user="scott", - password="tiger", - dsn="localhost:1521/freepdb1", - min=1, - max=4, - increment=1, - ) - engine = create_engine( - "oracle+oracledb://", creator=pool.acquire, poolclass=NullPool - ) - -The above engine may then be used normally. Internally, python-oracledb handles -connection pooling:: - - with engine.connect() as conn: - print(conn.scalar(text("select 1 from dual"))) - -Refer to the python-oracledb documentation for `oracledb.create_pool() -`_ -for the arguments that can be used when creating a connection pool. - -.. _drcp: - -Using Oracle Database Resident Connection Pooling (DRCP) --------------------------------------------------------- - -When using Oracle Database's Database Resident Connection Pooling (DRCP), the -best practice is to specify a connection class and "purity". Refer to the -`python-oracledb documentation on DRCP -`_. -For example:: - - import oracledb - from sqlalchemy import create_engine - from sqlalchemy import text - from sqlalchemy.pool import NullPool - - # Uncomment to use the optional python-oracledb Thick mode. - # Review the python-oracledb doc for the appropriate parameters - # oracledb.init_oracle_client() - - pool = oracledb.create_pool( - user="scott", - password="tiger", - dsn="localhost:1521/freepdb1", - min=1, - max=4, - increment=1, - cclass="MYCLASS", - purity=oracledb.PURITY_SELF, - ) - engine = create_engine( - "oracle+oracledb://", creator=pool.acquire, poolclass=NullPool - ) - -The above engine may then be used normally where python-oracledb handles -application connection pooling and Oracle Database additionally uses DRCP:: - - with engine.connect() as conn: - print(conn.scalar(text("select 1 from dual"))) - -If you wish to use different connection classes or purities for different -connections, then wrap ``pool.acquire()``:: - - import oracledb - from sqlalchemy import create_engine - from sqlalchemy import text - from sqlalchemy.pool import NullPool - - # Uncomment to use python-oracledb Thick mode. - # Review the python-oracledb doc for the appropriate parameters - # oracledb.init_oracle_client() - - pool = oracledb.create_pool( - user="scott", - password="tiger", - dsn="localhost:1521/freepdb1", - min=1, - max=4, - increment=1, - cclass="MYCLASS", - purity=oracledb.PURITY_SELF, - ) - - - def creator(): - return pool.acquire(cclass="MYOTHERCLASS", purity=oracledb.PURITY_NEW) - - - engine = create_engine( - "oracle+oracledb://", creator=creator, poolclass=NullPool - ) - -Engine Options consumed by the SQLAlchemy oracledb dialect outside of the driver --------------------------------------------------------------------------------- - -There are also options that are consumed by the SQLAlchemy oracledb dialect -itself. These options are always passed directly to :func:`_sa.create_engine`, -such as:: - - e = create_engine("oracle+oracledb://user:pass@tnsalias", arraysize=500) - -The parameters accepted by the oracledb dialect are as follows: - -* ``arraysize`` - set the driver cursor.arraysize value. It defaults to - ``None``, indicating that the driver default value of 100 should be used. - This setting controls how many rows are buffered when fetching rows, and can - have a significant effect on performance if increased for queries that return - large numbers of rows. - - .. versionchanged:: 2.0.26 - changed the default value from 50 to None, - to use the default value of the driver itself. 
- -* ``auto_convert_lobs`` - defaults to True; See :ref:`oracledb_lob`. - -* ``coerce_to_decimal`` - see :ref:`oracledb_numeric` for detail. - -* ``encoding_errors`` - see :ref:`oracledb_unicode_encoding_errors` for detail. - -.. _oracledb_unicode: - -Unicode -------- - -As is the case for all DBAPIs under Python 3, all strings are inherently -Unicode strings. - -Ensuring the Correct Client Encoding -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -In python-oracledb, the encoding used for all character data is "UTF-8". - -Unicode-specific Column datatypes -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The Core expression language handles unicode data by use of the -:class:`.Unicode` and :class:`.UnicodeText` datatypes. These types correspond -to the VARCHAR2 and CLOB Oracle Database datatypes by default. When using -these datatypes with Unicode data, it is expected that the database is -configured with a Unicode-aware character set so that the VARCHAR2 and CLOB -datatypes can accommodate the data. - -In the case that Oracle Database is not configured with a Unicode character -set, the two options are to use the :class:`_types.NCHAR` and -:class:`_oracle.NCLOB` datatypes explicitly, or to pass the flag -``use_nchar_for_unicode=True`` to :func:`_sa.create_engine`, which will cause -the SQLAlchemy dialect to use NCHAR/NCLOB for the :class:`.Unicode` / -:class:`.UnicodeText` datatypes instead of VARCHAR/CLOB. - -.. versionchanged:: 1.3 The :class:`.Unicode` and :class:`.UnicodeText` - datatypes now correspond to the ``VARCHAR2`` and ``CLOB`` Oracle Database - datatypes unless the ``use_nchar_for_unicode=True`` is passed to the dialect - when :func:`_sa.create_engine` is called. - - -.. _oracledb_unicode_encoding_errors: - -Encoding Errors -^^^^^^^^^^^^^^^ - -For the unusual case that data in Oracle Database is present with a broken -encoding, the dialect accepts a parameter ``encoding_errors`` which will be -passed to Unicode decoding functions in order to affect how decoding errors are -handled. The value is ultimately consumed by the Python `decode -`_ function, and -is passed both via python-oracledb's ``encodingErrors`` parameter consumed by -``Cursor.var()``, as well as SQLAlchemy's own decoding function, as the -python-oracledb dialect makes use of both under different circumstances. - -.. versionadded:: 1.3.11 - - -.. _oracledb_setinputsizes: - -Fine grained control over python-oracledb data binding with setinputsizes -------------------------------------------------------------------------- - -The python-oracle DBAPI has a deep and fundamental reliance upon the usage of -the DBAPI ``setinputsizes()`` call. The purpose of this call is to establish -the datatypes that are bound to a SQL statement for Python values being passed -as parameters. While virtually no other DBAPI assigns any use to the -``setinputsizes()`` call, the python-oracledb DBAPI relies upon it heavily in -its interactions with the Oracle Database, and in some scenarios it is not -possible for SQLAlchemy to know exactly how data should be bound, as some -settings can cause profoundly different performance characteristics, while -altering the type coercion behavior at the same time. - -Users of the oracledb dialect are **strongly encouraged** to read through -python-oracledb's list of built-in datatype symbols at `Database Types -`_ -Note that in some cases, significant performance degradation can occur when -using these types vs. not. 
- -On the SQLAlchemy side, the :meth:`.DialectEvents.do_setinputsizes` event can -be used both for runtime visibility (e.g. logging) of the setinputsizes step as -well as to fully control how ``setinputsizes()`` is used on a per-statement -basis. - -.. versionadded:: 1.2.9 Added :meth:`.DialectEvents.setinputsizes` - - -Example 1 - logging all setinputsizes calls -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The following example illustrates how to log the intermediary values from a -SQLAlchemy perspective before they are converted to the raw ``setinputsizes()`` -parameter dictionary. The keys of the dictionary are :class:`.BindParameter` -objects which have a ``.key`` and a ``.type`` attribute:: - - from sqlalchemy import create_engine, event - - engine = create_engine( - "oracle+oracledb://scott:tiger@localhost:1521?service_name=freepdb1" - ) - - - @event.listens_for(engine, "do_setinputsizes") - def _log_setinputsizes(inputsizes, cursor, statement, parameters, context): - for bindparam, dbapitype in inputsizes.items(): - log.info( - "Bound parameter name: %s SQLAlchemy type: %r DBAPI object: %s", - bindparam.key, - bindparam.type, - dbapitype, - ) - -Example 2 - remove all bindings to CLOB -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -For performance, fetching LOB datatypes from Oracle Database is set by default -for the ``Text`` type within SQLAlchemy. This setting can be modified as -follows:: - - - from sqlalchemy import create_engine, event - from oracledb import CLOB - - engine = create_engine( - "oracle+oracledb://scott:tiger@localhost:1521?service_name=freepdb1" - ) - - - @event.listens_for(engine, "do_setinputsizes") - def _remove_clob(inputsizes, cursor, statement, parameters, context): - for bindparam, dbapitype in list(inputsizes.items()): - if dbapitype is CLOB: - del inputsizes[bindparam] - -.. _oracledb_lob: - -LOB Datatypes --------------- - -LOB datatypes refer to the "large object" datatypes such as CLOB, NCLOB and -BLOB. Oracle Database can efficiently return these datatypes as a single -buffer. SQLAlchemy makes use of type handlers to do this by default. - -To disable the use of the type handlers and deliver LOB objects as classic -buffered objects with a ``read()`` method, the parameter -``auto_convert_lobs=False`` may be passed to :func:`_sa.create_engine`. - -.. _oracledb_returning: - -RETURNING Support ------------------ - -The oracledb dialect implements RETURNING using OUT parameters. The dialect -supports RETURNING fully. - -Two Phase Transaction Support ------------------------------ - -Two phase transactions are fully supported with python-oracledb. (Thin mode -requires python-oracledb 2.3). APIs for two phase transactions are provided at -the Core level via :meth:`_engine.Connection.begin_twophase` and -:paramref:`_orm.Session.twophase` for transparent ORM use. - -.. versionchanged:: 2.0.32 added support for two phase transactions - -.. _oracledb_numeric: - -Precision Numerics ------------------- - -SQLAlchemy's numeric types can handle receiving and returning values as Python -``Decimal`` objects or float objects. When a :class:`.Numeric` object, or a -subclass such as :class:`.Float`, :class:`_oracle.DOUBLE_PRECISION` etc. is in -use, the :paramref:`.Numeric.asdecimal` flag determines if values should be -coerced to ``Decimal`` upon return, or returned as float objects. 
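Stepping back to the two-phase support documented a few paragraphs above, a hedged Core-level sketch; the DSN and the ``audit_log`` table are placeholders, and thin mode reportedly requires a sufficiently recent python-oracledb::

    from sqlalchemy import create_engine, text

    engine = create_engine(
        "oracle+oracledb://scott:tiger@dbhost?service_name=freepdb1"
    )

    with engine.connect() as conn:
        tx = conn.begin_twophase()
        conn.execute(
            text("INSERT INTO audit_log (msg) VALUES (:msg)"), {"msg": "hello"}
        )
        tx.prepare()
        tx.commit()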
To make -matters more complicated under Oracle Database, the ``NUMBER`` type can also -represent integer values if the "scale" is zero, so the Oracle -Database-specific :class:`_oracle.NUMBER` type takes this into account as well. - -The oracledb dialect makes extensive use of connection- and cursor-level -"outputtypehandler" callables in order to coerce numeric values as requested. -These callables are specific to the specific flavor of :class:`.Numeric` in -use, as well as if no SQLAlchemy typing objects are present. There are -observed scenarios where Oracle Database may send incomplete or ambiguous -information about the numeric types being returned, such as a query where the -numeric types are buried under multiple levels of subquery. The type handlers -do their best to make the right decision in all cases, deferring to the -underlying python-oracledb DBAPI for all those cases where the driver can make -the best decision. - -When no typing objects are present, as when executing plain SQL strings, a -default "outputtypehandler" is present which will generally return numeric -values which specify precision and scale as Python ``Decimal`` objects. To -disable this coercion to decimal for performance reasons, pass the flag -``coerce_to_decimal=False`` to :func:`_sa.create_engine`:: - - engine = create_engine( - "oracle+oracledb://scott:tiger@tnsalias", coerce_to_decimal=False - ) - -The ``coerce_to_decimal`` flag only impacts the results of plain string -SQL statements that are not otherwise associated with a :class:`.Numeric` -SQLAlchemy type (or a subclass of such). - -.. versionchanged:: 1.2 The numeric handling system for the oracle dialects has - been reworked to take advantage of newer driver features as well as better - integration of outputtypehandlers. - -.. versionadded:: 2.0.0 added support for the python-oracledb driver. +.. versionadded:: 2.0.0 added support for oracledb driver. """ # noqa -from __future__ import annotations - -import collections import re -from typing import Any -from typing import TYPE_CHECKING -from . import cx_oracle as _cx_oracle +from .cx_oracle import OracleDialect_cx_oracle as _OracleDialect_cx_oracle from ... import exc -from ... 
import pool -from ...connectors.asyncio import AsyncAdapt_dbapi_connection -from ...connectors.asyncio import AsyncAdapt_dbapi_cursor -from ...connectors.asyncio import AsyncAdapt_dbapi_ss_cursor -from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection -from ...engine import default -from ...util import asbool -from ...util import await_fallback -from ...util import await_only - -if TYPE_CHECKING: - from oracledb import AsyncConnection - from oracledb import AsyncCursor -class OracleExecutionContext_oracledb( - _cx_oracle.OracleExecutionContext_cx_oracle -): - pass - - -class OracleDialect_oracledb(_cx_oracle.OracleDialect_cx_oracle): +class OracleDialect_oracledb(_OracleDialect_cx_oracle): supports_statement_cache = True - execution_ctx_cls = OracleExecutionContext_oracledb - driver = "oracledb" - _min_version = (1,) def __init__( self, auto_convert_lobs=True, coerce_to_decimal=True, - arraysize=None, + arraysize=50, encoding_errors=None, thick_mode=None, **kwargs, @@ -662,10 +91,6 @@ class OracleDialect_oracledb(_cx_oracle.OracleDialect_cx_oracle): def is_thin_mode(cls, connection): return connection.connection.dbapi_connection.thin - @classmethod - def get_async_dialect_cls(cls, url): - return OracleDialectAsync_oracledb - def _load_version(self, dbapi_module): version = (0, 0, 0) if dbapi_module is not None: @@ -675,273 +100,10 @@ class OracleDialect_oracledb(_cx_oracle.OracleDialect_cx_oracle): int(x) for x in m.group(1, 2, 3) if x is not None ) self.oracledb_ver = version - if ( - self.oracledb_ver > (0, 0, 0) - and self.oracledb_ver < self._min_version - ): + if self.oracledb_ver < (1,) and self.oracledb_ver > (0, 0, 0): raise exc.InvalidRequestError( - f"oracledb version {self._min_version} and above are supported" + "oracledb version 1 and above are supported" ) - def do_begin_twophase(self, connection, xid): - conn_xis = connection.connection.xid(*xid) - connection.connection.tpc_begin(conn_xis) - connection.connection.info["oracledb_xid"] = conn_xis - - def do_prepare_twophase(self, connection, xid): - should_commit = connection.connection.tpc_prepare() - connection.info["oracledb_should_commit"] = should_commit - - def do_rollback_twophase( - self, connection, xid, is_prepared=True, recover=False - ): - if recover: - conn_xid = connection.connection.xid(*xid) - else: - conn_xid = None - connection.connection.tpc_rollback(conn_xid) - - def do_commit_twophase( - self, connection, xid, is_prepared=True, recover=False - ): - conn_xid = None - if not is_prepared: - should_commit = connection.connection.tpc_prepare() - elif recover: - conn_xid = connection.connection.xid(*xid) - should_commit = True - else: - should_commit = connection.info["oracledb_should_commit"] - if should_commit: - connection.connection.tpc_commit(conn_xid) - - def do_recover_twophase(self, connection): - return [ - # oracledb seems to return bytes - ( - fi, - gti.decode() if isinstance(gti, bytes) else gti, - bq.decode() if isinstance(bq, bytes) else bq, - ) - for fi, gti, bq in connection.connection.tpc_recover() - ] - - def _check_max_identifier_length(self, connection): - if self.oracledb_ver >= (2, 5): - max_len = connection.connection.max_identifier_length - if max_len is not None: - return max_len - return super()._check_max_identifier_length(connection) - - -class AsyncAdapt_oracledb_cursor(AsyncAdapt_dbapi_cursor): - _cursor: AsyncCursor - __slots__ = () - - @property - def outputtypehandler(self): - return self._cursor.outputtypehandler - - @outputtypehandler.setter - def 
outputtypehandler(self, value): - self._cursor.outputtypehandler = value - - def var(self, *args, **kwargs): - return self._cursor.var(*args, **kwargs) - - def close(self): - self._rows.clear() - self._cursor.close() - - def setinputsizes(self, *args: Any, **kwargs: Any) -> Any: - return self._cursor.setinputsizes(*args, **kwargs) - - def _aenter_cursor(self, cursor: AsyncCursor) -> AsyncCursor: - try: - return cursor.__enter__() - except Exception as error: - self._adapt_connection._handle_exception(error) - - async def _execute_async(self, operation, parameters): - # override to not use mutex, oracledb already has a mutex - - if parameters is None: - result = await self._cursor.execute(operation) - else: - result = await self._cursor.execute(operation, parameters) - - if self._cursor.description and not self.server_side: - self._rows = collections.deque(await self._cursor.fetchall()) - return result - - async def _executemany_async( - self, - operation, - seq_of_parameters, - ): - # override to not use mutex, oracledb already has a mutex - return await self._cursor.executemany(operation, seq_of_parameters) - - def __enter__(self): - return self - - def __exit__(self, type_: Any, value: Any, traceback: Any) -> None: - self.close() - - -class AsyncAdapt_oracledb_ss_cursor( - AsyncAdapt_dbapi_ss_cursor, AsyncAdapt_oracledb_cursor -): - __slots__ = () - - def close(self) -> None: - if self._cursor is not None: - self._cursor.close() - self._cursor = None # type: ignore - - -class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection): - _connection: AsyncConnection - __slots__ = () - - thin = True - - _cursor_cls = AsyncAdapt_oracledb_cursor - _ss_cursor_cls = None - - @property - def autocommit(self): - return self._connection.autocommit - - @autocommit.setter - def autocommit(self, value): - self._connection.autocommit = value - - @property - def outputtypehandler(self): - return self._connection.outputtypehandler - - @outputtypehandler.setter - def outputtypehandler(self, value): - self._connection.outputtypehandler = value - - @property - def version(self): - return self._connection.version - - @property - def stmtcachesize(self): - return self._connection.stmtcachesize - - @stmtcachesize.setter - def stmtcachesize(self, value): - self._connection.stmtcachesize = value - - @property - def max_identifier_length(self): - return self._connection.max_identifier_length - - def cursor(self): - return AsyncAdapt_oracledb_cursor(self) - - def ss_cursor(self): - return AsyncAdapt_oracledb_ss_cursor(self) - - def xid(self, *args: Any, **kwargs: Any) -> Any: - return self._connection.xid(*args, **kwargs) - - def tpc_begin(self, *args: Any, **kwargs: Any) -> Any: - return self.await_(self._connection.tpc_begin(*args, **kwargs)) - - def tpc_commit(self, *args: Any, **kwargs: Any) -> Any: - return self.await_(self._connection.tpc_commit(*args, **kwargs)) - - def tpc_prepare(self, *args: Any, **kwargs: Any) -> Any: - return self.await_(self._connection.tpc_prepare(*args, **kwargs)) - - def tpc_recover(self, *args: Any, **kwargs: Any) -> Any: - return self.await_(self._connection.tpc_recover(*args, **kwargs)) - - def tpc_rollback(self, *args: Any, **kwargs: Any) -> Any: - return self.await_(self._connection.tpc_rollback(*args, **kwargs)) - - -class AsyncAdaptFallback_oracledb_connection( - AsyncAdaptFallback_dbapi_connection, AsyncAdapt_oracledb_connection -): - __slots__ = () - - -class OracledbAdaptDBAPI: - def __init__(self, oracledb) -> None: - self.oracledb = oracledb - - for k, v in 
self.oracledb.__dict__.items(): - if k != "connect": - self.__dict__[k] = v - - def connect(self, *arg, **kw): - async_fallback = kw.pop("async_fallback", False) - creator_fn = kw.pop("async_creator_fn", self.oracledb.connect_async) - - if asbool(async_fallback): - return AsyncAdaptFallback_oracledb_connection( - self, await_fallback(creator_fn(*arg, **kw)) - ) - - else: - return AsyncAdapt_oracledb_connection( - self, await_only(creator_fn(*arg, **kw)) - ) - - -class OracleExecutionContextAsync_oracledb(OracleExecutionContext_oracledb): - # restore default create cursor - create_cursor = default.DefaultExecutionContext.create_cursor - - def create_default_cursor(self): - # copy of OracleExecutionContext_cx_oracle.create_cursor - c = self._dbapi_connection.cursor() - if self.dialect.arraysize: - c.arraysize = self.dialect.arraysize - - return c - - def create_server_side_cursor(self): - c = self._dbapi_connection.ss_cursor() - if self.dialect.arraysize: - c.arraysize = self.dialect.arraysize - - return c - - -class OracleDialectAsync_oracledb(OracleDialect_oracledb): - is_async = True - supports_server_side_cursors = True - supports_statement_cache = True - execution_ctx_cls = OracleExecutionContextAsync_oracledb - - _min_version = (2,) - - # thick_mode mode is not supported by asyncio, oracledb will raise - @classmethod - def import_dbapi(cls): - import oracledb - - return OracledbAdaptDBAPI(oracledb) - - @classmethod - def get_pool_class(cls, url): - async_fallback = url.query.get("async_fallback", False) - - if asbool(async_fallback): - return pool.FallbackAsyncAdaptedQueuePool - else: - return pool.AsyncAdaptedQueuePool - - def get_driver_connection(self, connection): - return connection._connection - dialect = OracleDialect_oracledb -dialect_async = OracleDialectAsync_oracledb diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/provision.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/provision.py index 3587de9..c8599e8 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/provision.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/provision.py @@ -1,9 +1,3 @@ -# dialects/oracle/provision.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors from ... import create_engine @@ -89,7 +83,7 @@ def _oracle_drop_db(cfg, eng, ident): # cx_Oracle seems to occasionally leak open connections when a large # suite it run, even if we confirm we have zero references to # connection objects. - # while there is a "kill session" command in Oracle Database, + # while there is a "kill session" command in Oracle, # it unfortunately does not release the connection sufficiently. 
_ora_drop_ignore(conn, ident) _ora_drop_ignore(conn, "%s_ts1" % ident) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/types.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/types.py index 06aeaac..4f82c43 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/types.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/types.py @@ -1,5 +1,4 @@ -# dialects/oracle/types.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -64,18 +63,17 @@ class NUMBER(sqltypes.Numeric, sqltypes.Integer): class FLOAT(sqltypes.FLOAT): - """Oracle Database FLOAT. + """Oracle FLOAT. This is the same as :class:`_sqltypes.FLOAT` except that - an Oracle Database -specific :paramref:`_oracle.FLOAT.binary_precision` + an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision` parameter is accepted, and the :paramref:`_sqltypes.Float.precision` parameter is not accepted. - Oracle Database FLOAT types indicate precision in terms of "binary - precision", which defaults to 126. For a REAL type, the value is 63. This - parameter does not cleanly map to a specific number of decimal places but - is roughly equivalent to the desired number of decimal places divided by - 0.3103. + Oracle FLOAT types indicate precision in terms of "binary precision", which + defaults to 126. For a REAL type, the value is 63. This parameter does not + cleanly map to a specific number of decimal places but is roughly + equivalent to the desired number of decimal places divided by 0.3103. .. versionadded:: 2.0 @@ -92,11 +90,10 @@ class FLOAT(sqltypes.FLOAT): r""" Construct a FLOAT - :param binary_precision: Oracle Database binary precision value to be - rendered in DDL. This may be approximated to the number of decimal - characters using the formula "decimal precision = 0.30103 * binary - precision". The default value used by Oracle Database for FLOAT / - DOUBLE PRECISION is 126. + :param binary_precision: Oracle binary precision value to be rendered + in DDL. This may be approximated to the number of decimal characters + using the formula "decimal precision = 0.30103 * binary precision". + The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126. :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal` @@ -111,36 +108,10 @@ class FLOAT(sqltypes.FLOAT): class BINARY_DOUBLE(sqltypes.Double): - """Implement the Oracle ``BINARY_DOUBLE`` datatype. - - This datatype differs from the Oracle ``DOUBLE`` datatype in that it - delivers a true 8-byte FP value. The datatype may be combined with a - generic :class:`.Double` datatype using :meth:`.TypeEngine.with_variant`. - - .. seealso:: - - :ref:`oracle_float_support` - - - """ - __visit_name__ = "BINARY_DOUBLE" class BINARY_FLOAT(sqltypes.Float): - """Implement the Oracle ``BINARY_FLOAT`` datatype. - - This datatype differs from the Oracle ``FLOAT`` datatype in that it - delivers a true 4-byte FP value. The datatype may be combined with a - generic :class:`.Float` datatype using :meth:`.TypeEngine.with_variant`. - - .. seealso:: - - :ref:`oracle_float_support` - - - """ - __visit_name__ = "BINARY_FLOAT" @@ -191,10 +162,10 @@ class _OracleDateLiteralRender: class DATE(_OracleDateLiteralRender, sqltypes.DateTime): - """Provide the Oracle Database DATE type. + """Provide the oracle DATE type. 
This type has no special Python behavior, except that it subclasses - :class:`_types.DateTime`; this is to suit the fact that the Oracle Database + :class:`_types.DateTime`; this is to suit the fact that the Oracle ``DATE`` type supports a time value. """ @@ -274,8 +245,8 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): class TIMESTAMP(sqltypes.TIMESTAMP): - """Oracle Database implementation of ``TIMESTAMP``, which supports - additional Oracle Database-specific modes + """Oracle implementation of ``TIMESTAMP``, which supports additional + Oracle-specific modes .. versionadded:: 2.0 @@ -285,11 +256,10 @@ class TIMESTAMP(sqltypes.TIMESTAMP): """Construct a new :class:`_oracle.TIMESTAMP`. :param timezone: boolean. Indicates that the TIMESTAMP type should - use Oracle Database's ``TIMESTAMP WITH TIME ZONE`` datatype. + use Oracle's ``TIMESTAMP WITH TIME ZONE`` datatype. :param local_timezone: boolean. Indicates that the TIMESTAMP type - should use Oracle Database's ``TIMESTAMP WITH LOCAL TIME ZONE`` - datatype. + should use Oracle's ``TIMESTAMP WITH LOCAL TIME ZONE`` datatype. """ @@ -302,7 +272,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP): class ROWID(sqltypes.TypeEngine): - """Oracle Database ROWID type. + """Oracle ROWID type. When used in a cast() or similar, generates ROWID. diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/vector.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/vector.py deleted file mode 100644 index 88d47ea..0000000 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/oracle/vector.py +++ /dev/null @@ -1,364 +0,0 @@ -# dialects/oracle/vector.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors - - -from __future__ import annotations - -import array -from dataclasses import dataclass -from enum import Enum -from typing import Optional -from typing import Union - -import sqlalchemy.types as types -from sqlalchemy.types import Float - - -class VectorIndexType(Enum): - """Enum representing different types of VECTOR index structures. - - See :ref:`oracle_vector_datatype` for background. - - .. versionadded:: 2.0.41 - - """ - - HNSW = "HNSW" - """ - The HNSW (Hierarchical Navigable Small World) index type. - """ - IVF = "IVF" - """ - The IVF (Inverted File Index) index type - """ - - -class VectorDistanceType(Enum): - """Enum representing different types of vector distance metrics. - - See :ref:`oracle_vector_datatype` for background. - - .. versionadded:: 2.0.41 - - """ - - EUCLIDEAN = "EUCLIDEAN" - """Euclidean distance (L2 norm). - - Measures the straight-line distance between two vectors in space. - """ - DOT = "DOT" - """Dot product similarity. - - Measures the algebraic similarity between two vectors. - """ - COSINE = "COSINE" - """Cosine similarity. - - Measures the cosine of the angle between two vectors. - """ - MANHATTAN = "MANHATTAN" - """Manhattan distance (L1 norm). - - Calculates the sum of absolute differences across dimensions. - """ - - -class VectorStorageFormat(Enum): - """Enum representing the data format used to store vector components. - - See :ref:`oracle_vector_datatype` for background. - - .. versionadded:: 2.0.41 - - """ - - INT8 = "INT8" - """ - 8-bit integer format. - """ - BINARY = "BINARY" - """ - Binary format. - """ - FLOAT32 = "FLOAT32" - """ - 32-bit floating-point format. 
- """ - FLOAT64 = "FLOAT64" - """ - 64-bit floating-point format. - """ - - -class VectorStorageType(Enum): - """Enum representing the vector type, - - See :ref:`oracle_vector_datatype` for background. - - .. versionadded:: 2.0.43 - - """ - - SPARSE = "SPARSE" - """ - A Sparse vector is a vector which has zero value for - most of its dimensions. - """ - DENSE = "DENSE" - """ - A Dense vector is a vector where most, if not all, elements - hold meaningful values. - """ - - -@dataclass -class VectorIndexConfig: - """Define the configuration for Oracle VECTOR Index. - - See :ref:`oracle_vector_datatype` for background. - - .. versionadded:: 2.0.41 - - :param index_type: Enum value from :class:`.VectorIndexType` - Specifies the indexing method. For HNSW, this must be - :attr:`.VectorIndexType.HNSW`. - - :param distance: Enum value from :class:`.VectorDistanceType` - specifies the metric for calculating distance between VECTORS. - - :param accuracy: interger. Should be in the range 0 to 100 - Specifies the accuracy of the nearest neighbor search during - query execution. - - :param parallel: integer. Specifies degree of parallelism. - - :param hnsw_neighbors: interger. Should be in the range 0 to - 2048. Specifies the number of nearest neighbors considered - during the search. The attribute :attr:`.VectorIndexConfig.hnsw_neighbors` - is HNSW index specific. - - :param hnsw_efconstruction: integer. Should be in the range 0 - to 65535. Controls the trade-off between indexing speed and - recall quality during index construction. The attribute - :attr:`.VectorIndexConfig.hnsw_efconstruction` is HNSW index - specific. - - :param ivf_neighbor_partitions: integer. Should be in the range - 0 to 10,000,000. Specifies the number of partitions used to - divide the dataset. The attribute - :attr:`.VectorIndexConfig.ivf_neighbor_partitions` is IVF index - specific. - - :param ivf_sample_per_partition: integer. Should be between 1 - and ``num_vectors / neighbor partitions``. Specifies the - number of samples used per partition. The attribute - :attr:`.VectorIndexConfig.ivf_sample_per_partition` is IVF index - specific. - - :param ivf_min_vectors_per_partition: integer. From 0 (no trimming) - to the total number of vectors (results in 1 partition). Specifies - the minimum number of vectors per partition. The attribute - :attr:`.VectorIndexConfig.ivf_min_vectors_per_partition` - is IVF index specific. - - """ - - index_type: VectorIndexType = VectorIndexType.HNSW - distance: Optional[VectorDistanceType] = None - accuracy: Optional[int] = None - hnsw_neighbors: Optional[int] = None - hnsw_efconstruction: Optional[int] = None - ivf_neighbor_partitions: Optional[int] = None - ivf_sample_per_partition: Optional[int] = None - ivf_min_vectors_per_partition: Optional[int] = None - parallel: Optional[int] = None - - def __post_init__(self): - self.index_type = VectorIndexType(self.index_type) - for field in [ - "hnsw_neighbors", - "hnsw_efconstruction", - "ivf_neighbor_partitions", - "ivf_sample_per_partition", - "ivf_min_vectors_per_partition", - "parallel", - "accuracy", - ]: - value = getattr(self, field) - if value is not None and not isinstance(value, int): - raise TypeError( - f"{field} must be an integer if" - f"provided, got {type(value).__name__}" - ) - - -class SparseVector: - """ - Lightweight SQLAlchemy-side version of SparseVector. - This mimics oracledb.SparseVector. - - .. 
versionadded:: 2.0.43 - - """ - - def __init__( - self, - num_dimensions: int, - indices: Union[list, array.array], - values: Union[list, array.array], - ): - if not isinstance(indices, array.array) or indices.typecode != "I": - indices = array.array("I", indices) - if not isinstance(values, array.array): - values = array.array("d", values) - if len(indices) != len(values): - raise TypeError("indices and values must be of the same length!") - - self.num_dimensions = num_dimensions - self.indices = indices - self.values = values - - def __str__(self): - return ( - f"SparseVector(num_dimensions={self.num_dimensions}, " - f"size={len(self.indices)}, typecode={self.values.typecode})" - ) - - -class VECTOR(types.TypeEngine): - """Oracle VECTOR datatype. - - For complete background on using this type, see - :ref:`oracle_vector_datatype`. - - .. versionadded:: 2.0.41 - - """ - - cache_ok = True - __visit_name__ = "VECTOR" - - _typecode_map = { - VectorStorageFormat.INT8: "b", # Signed int - VectorStorageFormat.BINARY: "B", # Unsigned int - VectorStorageFormat.FLOAT32: "f", # Float - VectorStorageFormat.FLOAT64: "d", # Double - } - - def __init__(self, dim=None, storage_format=None, storage_type=None): - """Construct a VECTOR. - - :param dim: integer. The dimension of the VECTOR datatype. This - should be an integer value. - - :param storage_format: VectorStorageFormat. The VECTOR storage - type format. This should be Enum values form - :class:`.VectorStorageFormat` INT8, BINARY, FLOAT32, or FLOAT64. - - :param storage_type: VectorStorageType. The Vector storage type. This - should be Enum values from :class:`.VectorStorageType` SPARSE or - DENSE. - - """ - - if dim is not None and not isinstance(dim, int): - raise TypeError("dim must be an interger") - if storage_format is not None and not isinstance( - storage_format, VectorStorageFormat - ): - raise TypeError( - "storage_format must be an enum of type VectorStorageFormat" - ) - if storage_type is not None and not isinstance( - storage_type, VectorStorageType - ): - raise TypeError( - "storage_type must be an enum of type VectorStorageType" - ) - - self.dim = dim - self.storage_format = storage_format - self.storage_type = storage_type - - def _cached_bind_processor(self, dialect): - """ - Converts a Python-side SparseVector instance into an - oracledb.SparseVectormor a compatible array format before - binding it to the database. - """ - - def process(value): - if value is None or isinstance(value, array.array): - return value - - # Convert list to a array.array - elif isinstance(value, list): - typecode = self._array_typecode(self.storage_format) - value = array.array(typecode, value) - return value - - # Convert SqlAlchemy SparseVector to oracledb SparseVector object - elif isinstance(value, SparseVector): - return dialect.dbapi.SparseVector( - value.num_dimensions, - value.indices, - value.values, - ) - - else: - raise TypeError( - """ - Invalid input for VECTOR: expected a list, an array.array, - or a SparseVector object. - """ - ) - - return process - - def _cached_result_processor(self, dialect, coltype): - """ - Converts database-returned values into Python-native representations. - If the value is an oracledb.SparseVector, it is converted into the - SQLAlchemy-side SparseVector class. - If the value is a array.array, it is converted to a plain Python list. 
- - """ - - def process(value): - if value is None: - return None - - elif isinstance(value, array.array): - return list(value) - - # Convert Oracledb SparseVector to SqlAlchemy SparseVector object - elif isinstance(value, dialect.dbapi.SparseVector): - return SparseVector( - num_dimensions=value.num_dimensions, - indices=value.indices, - values=value.values, - ) - - return process - - def _array_typecode(self, typecode): - """ - Map storage format to array typecode. - """ - return self._typecode_map.get(typecode, "d") - - class comparator_factory(types.TypeEngine.Comparator): - def l2_distance(self, other): - return self.op("<->", return_type=Float)(other) - - def inner_product(self, other): - return self.op("<#>", return_type=Float)(other) - - def cosine_distance(self, other): - return self.op("<=>", return_type=Float)(other) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/__init__.py index 88935e2..c3ed7c1 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/__init__.py @@ -1,5 +1,5 @@ -# dialects/postgresql/__init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# postgresql/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -8,7 +8,6 @@ from types import ModuleType -from . import array as arraylib # noqa # keep above base and other dialects from . import asyncpg # noqa from . import base from . import pg8000 # noqa @@ -57,14 +56,12 @@ from .named_types import ENUM from .named_types import NamedType from .ranges import AbstractMultiRange from .ranges import AbstractRange -from .ranges import AbstractSingleRange from .ranges import DATEMULTIRANGE from .ranges import DATERANGE from .ranges import INT4MULTIRANGE from .ranges import INT4RANGE from .ranges import INT8MULTIRANGE from .ranges import INT8RANGE -from .ranges import MultiRange from .ranges import NUMMULTIRANGE from .ranges import NUMRANGE from .ranges import Range @@ -89,7 +86,6 @@ from .types import TIMESTAMP from .types import TSQUERY from .types import TSVECTOR - # Alias psycopg also as psycopg_async psycopg_async = type( "psycopg_async", (ModuleType,), {"dialect": psycopg.dialect_async} diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/_psycopg_common.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/_psycopg_common.py index 0ff301e..dfb25a5 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/_psycopg_common.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/_psycopg_common.py @@ -1,5 +1,4 @@ -# dialects/postgresql/_psycopg_common.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -170,10 +169,8 @@ class _PGDialect_common_psycopg(PGDialect): def _do_autocommit(self, connection, value): connection.autocommit = value - def detect_autocommit_setting(self, dbapi_connection): - return bool(dbapi_connection.autocommit) - def do_ping(self, dbapi_connection): + cursor = None before_autocommit = dbapi_connection.autocommit if not before_autocommit: diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/array.py 
b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/array.py index 96f6dc2..3496ed6 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/array.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/array.py @@ -1,21 +1,18 @@ -# dialects/postgresql/array.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# postgresql/array.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors from __future__ import annotations import re -from typing import Any as typing_Any -from typing import Iterable +from typing import Any from typing import Optional -from typing import Sequence -from typing import TYPE_CHECKING from typing import TypeVar -from typing import Union from .operators import CONTAINED_BY from .operators import CONTAINS @@ -24,55 +21,32 @@ from ... import types as sqltypes from ... import util from ...sql import expression from ...sql import operators -from ...sql.visitors import InternalTraversal - -if TYPE_CHECKING: - from ...engine.interfaces import Dialect - from ...sql._typing import _ColumnExpressionArgument - from ...sql._typing import _TypeEngineArgument - from ...sql.elements import ColumnElement - from ...sql.elements import Grouping - from ...sql.expression import BindParameter - from ...sql.operators import OperatorType - from ...sql.selectable import _SelectIterable - from ...sql.type_api import _BindProcessorType - from ...sql.type_api import _LiteralProcessorType - from ...sql.type_api import _ResultProcessorType - from ...sql.type_api import TypeEngine - from ...sql.visitors import _TraverseInternalsType - from ...util.typing import Self +from ...sql._typing import _TypeEngineArgument -_T = TypeVar("_T", bound=typing_Any) +_T = TypeVar("_T", bound=Any) -def Any( - other: typing_Any, - arrexpr: _ColumnExpressionArgument[_T], - operator: OperatorType = operators.eq, -) -> ColumnElement[bool]: +def Any(other, arrexpr, operator=operators.eq): """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.any` method. See that method for details. """ - return arrexpr.any(other, operator) # type: ignore[no-any-return, union-attr] # noqa: E501 + return arrexpr.any(other, operator) -def All( - other: typing_Any, - arrexpr: _ColumnExpressionArgument[_T], - operator: OperatorType = operators.eq, -) -> ColumnElement[bool]: +def All(other, arrexpr, operator=operators.eq): """A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.all` method. See that method for details. """ - return arrexpr.all(other, operator) # type: ignore[no-any-return, union-attr] # noqa: E501 + return arrexpr.all(other, operator) class array(expression.ExpressionClauseList[_T]): + """A PostgreSQL ARRAY literal. This is used to produce ARRAY literals in SQL expressions, e.g.:: @@ -81,43 +55,20 @@ class array(expression.ExpressionClauseList[_T]): from sqlalchemy.dialects import postgresql from sqlalchemy import select, func - stmt = select(array([1, 2]) + array([3, 4, 5])) + stmt = select(array([1,2]) + array([3,4,5])) print(stmt.compile(dialect=postgresql.dialect())) - Produces the SQL: - - .. sourcecode:: sql + Produces the SQL:: SELECT ARRAY[%(param_1)s, %(param_2)s] || ARRAY[%(param_3)s, %(param_4)s, %(param_5)s]) AS anon_1 An instance of :class:`.array` will always have the datatype - :class:`_types.ARRAY`. 
The "inner" type of the array is inferred from the - values present, unless the :paramref:`_postgresql.array.type_` keyword - argument is passed:: + :class:`_types.ARRAY`. The "inner" type of the array is inferred from + the values present, unless the ``type_`` keyword argument is passed:: - array(["foo", "bar"], type_=CHAR) - - When constructing an empty array, the :paramref:`_postgresql.array.type_` - argument is particularly important as PostgreSQL server typically requires - a cast to be rendered for the inner type in order to render an empty array. - SQLAlchemy's compilation for the empty array will produce this cast so - that:: - - stmt = array([], type_=Integer) - print(stmt.compile(dialect=postgresql.dialect())) - - Produces: - - .. sourcecode:: sql - - ARRAY[]::INTEGER[] - - As required by PostgreSQL for empty arrays. - - .. versionadded:: 2.0.40 added support to render empty PostgreSQL array - literals with a required cast. + array(['foo', 'bar'], type_=CHAR) Multidimensional arrays are produced by nesting :class:`.array` constructs. The dimensionality of the final :class:`_types.ARRAY` @@ -126,21 +77,16 @@ class array(expression.ExpressionClauseList[_T]): type:: stmt = select( - array( - [array([1, 2]), array([3, 4]), array([column("q"), column("x")])] - ) + array([ + array([1, 2]), array([3, 4]), array([column('q'), column('x')]) + ]) ) print(stmt.compile(dialect=postgresql.dialect())) - Produces: + Produces:: - .. sourcecode:: sql - - SELECT ARRAY[ - ARRAY[%(param_1)s, %(param_2)s], - ARRAY[%(param_3)s, %(param_4)s], - ARRAY[q, x] - ] AS anon_1 + SELECT ARRAY[ARRAY[%(param_1)s, %(param_2)s], + ARRAY[%(param_3)s, %(param_4)s], ARRAY[q, x]] AS anon_1 .. versionadded:: 1.3.6 added support for multidimensional array literals @@ -148,63 +94,42 @@ class array(expression.ExpressionClauseList[_T]): :class:`_postgresql.ARRAY` - """ # noqa: E501 + """ __visit_name__ = "array" stringify_dialect = "postgresql" + inherit_cache = True - _traverse_internals: _TraverseInternalsType = [ - ("clauses", InternalTraversal.dp_clauseelement_tuple), - ("type", InternalTraversal.dp_type), - ] - - def __init__( - self, - clauses: Iterable[_T], - *, - type_: Optional[_TypeEngineArgument[_T]] = None, - **kw: typing_Any, - ): - r"""Construct an ARRAY literal. - - :param clauses: iterable, such as a list, containing elements to be - rendered in the array - :param type\_: optional type. If omitted, the type is inferred - from the contents of the array. 
- - """ + def __init__(self, clauses, **kw): + type_arg = kw.pop("type_", None) super().__init__(operators.comma_op, *clauses, **kw) + self._type_tuple = [arg.type for arg in self.clauses] + main_type = ( - type_ - if type_ is not None - else self.clauses[0].type if self.clauses else sqltypes.NULLTYPE + type_arg + if type_arg is not None + else self._type_tuple[0] + if self._type_tuple + else sqltypes.NULLTYPE ) if isinstance(main_type, ARRAY): self.type = ARRAY( main_type.item_type, - dimensions=( - main_type.dimensions + 1 - if main_type.dimensions is not None - else 2 - ), - ) # type: ignore[assignment] + dimensions=main_type.dimensions + 1 + if main_type.dimensions is not None + else 2, + ) else: - self.type = ARRAY(main_type) # type: ignore[assignment] + self.type = ARRAY(main_type) @property - def _select_iterable(self) -> _SelectIterable: + def _select_iterable(self): return (self,) - def _bind_param( - self, - operator: OperatorType, - obj: typing_Any, - type_: Optional[TypeEngine[_T]] = None, - _assume_scalar: bool = False, - ) -> BindParameter[_T]: + def _bind_param(self, operator, obj, _assume_scalar=False, type_=None): if _assume_scalar or operator is operators.getitem: return expression.BindParameter( None, @@ -223,18 +148,16 @@ class array(expression.ExpressionClauseList[_T]): ) for o in obj ] - ) # type: ignore[return-value] + ) - def self_group( - self, against: Optional[OperatorType] = None - ) -> Union[Self, Grouping[_T]]: + def self_group(self, against=None): if against in (operators.any_op, operators.all_op, operators.getitem): return expression.Grouping(self) else: return self -class ARRAY(sqltypes.ARRAY[_T]): +class ARRAY(sqltypes.ARRAY): """PostgreSQL ARRAY type. The :class:`_postgresql.ARRAY` type is constructed in the same way @@ -244,11 +167,9 @@ class ARRAY(sqltypes.ARRAY[_T]): from sqlalchemy.dialects import postgresql - mytable = Table( - "mytable", - metadata, - Column("data", postgresql.ARRAY(Integer, dimensions=2)), - ) + mytable = Table("mytable", metadata, + Column("data", postgresql.ARRAY(Integer, dimensions=2)) + ) The :class:`_postgresql.ARRAY` type provides all operations defined on the core :class:`_types.ARRAY` type, including support for "dimensions", @@ -263,9 +184,8 @@ class ARRAY(sqltypes.ARRAY[_T]): mytable.c.data.contains([1, 2]) - Indexed access is one-based by default, to match that of PostgreSQL; - for zero-based indexed access, set - :paramref:`_postgresql.ARRAY.zero_indexes`. + The :class:`_postgresql.ARRAY` type may not be supported on all + PostgreSQL DBAPIs; it is currently known to work on psycopg2 only. Additionally, the :class:`_postgresql.ARRAY` type does not work directly in @@ -284,7 +204,6 @@ class ARRAY(sqltypes.ARRAY[_T]): from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.ext.mutable import MutableList - class SomeOrmClass(Base): # ... @@ -306,9 +225,45 @@ class ARRAY(sqltypes.ARRAY[_T]): """ + class Comparator(sqltypes.ARRAY.Comparator): + + """Define comparison operations for :class:`_types.ARRAY`. + + Note that these operations are in addition to those provided + by the base :class:`.types.ARRAY.Comparator` class, including + :meth:`.types.ARRAY.Comparator.any` and + :meth:`.types.ARRAY.Comparator.all`. + + """ + + def contains(self, other, **kwargs): + """Boolean expression. Test if elements are a superset of the + elements of the argument array expression. + + kwargs may be ignored by this operator but are required for API + conformance. 
+ """ + return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) + + def contained_by(self, other): + """Boolean expression. Test if elements are a proper subset of the + elements of the argument array expression. + """ + return self.operate( + CONTAINED_BY, other, result_type=sqltypes.Boolean + ) + + def overlap(self, other): + """Boolean expression. Test if array has elements in common with + an argument array expression. + """ + return self.operate(OVERLAP, other, result_type=sqltypes.Boolean) + + comparator_factory = Comparator + def __init__( self, - item_type: _TypeEngineArgument[_T], + item_type: _TypeEngineArgument[Any], as_tuple: bool = False, dimensions: Optional[int] = None, zero_indexes: bool = False, @@ -317,7 +272,7 @@ class ARRAY(sqltypes.ARRAY[_T]): E.g.:: - Column("myarray", ARRAY(Integer)) + Column('myarray', ARRAY(Integer)) Arguments are: @@ -357,63 +312,35 @@ class ARRAY(sqltypes.ARRAY[_T]): self.dimensions = dimensions self.zero_indexes = zero_indexes - class Comparator(sqltypes.ARRAY.Comparator[_T]): - """Define comparison operations for :class:`_types.ARRAY`. + @property + def hashable(self): + return self.as_tuple - Note that these operations are in addition to those provided - by the base :class:`.types.ARRAY.Comparator` class, including - :meth:`.types.ARRAY.Comparator.any` and - :meth:`.types.ARRAY.Comparator.all`. + @property + def python_type(self): + return list - """ - - def contains( - self, other: typing_Any, **kwargs: typing_Any - ) -> ColumnElement[bool]: - """Boolean expression. Test if elements are a superset of the - elements of the argument array expression. - - kwargs may be ignored by this operator but are required for API - conformance. - """ - return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) - - def contained_by(self, other: typing_Any) -> ColumnElement[bool]: - """Boolean expression. Test if elements are a proper subset of the - elements of the argument array expression. - """ - return self.operate( - CONTAINED_BY, other, result_type=sqltypes.Boolean - ) - - def overlap(self, other: typing_Any) -> ColumnElement[bool]: - """Boolean expression. Test if array has elements in common with - an argument array expression. 
- """ - return self.operate(OVERLAP, other, result_type=sqltypes.Boolean) - - comparator_factory = Comparator + def compare_values(self, x, y): + return x == y @util.memoized_property - def _against_native_enum(self) -> bool: + def _against_native_enum(self): return ( isinstance(self.item_type, sqltypes.Enum) and self.item_type.native_enum ) - def literal_processor( - self, dialect: Dialect - ) -> Optional[_LiteralProcessorType[_T]]: + def literal_processor(self, dialect): item_proc = self.item_type.dialect_impl(dialect).literal_processor( dialect ) if item_proc is None: return None - def to_str(elements: Iterable[typing_Any]) -> str: + def to_str(elements): return f"ARRAY[{', '.join(elements)}]" - def process(value: Sequence[typing_Any]) -> str: + def process(value): inner = self._apply_item_processor( value, item_proc, self.dimensions, to_str ) @@ -421,16 +348,12 @@ class ARRAY(sqltypes.ARRAY[_T]): return process - def bind_processor( - self, dialect: Dialect - ) -> Optional[_BindProcessorType[Sequence[typing_Any]]]: + def bind_processor(self, dialect): item_proc = self.item_type.dialect_impl(dialect).bind_processor( dialect ) - def process( - value: Optional[Sequence[typing_Any]], - ) -> Optional[list[typing_Any]]: + def process(value): if value is None: return value else: @@ -440,16 +363,12 @@ class ARRAY(sqltypes.ARRAY[_T]): return process - def result_processor( - self, dialect: Dialect, coltype: object - ) -> _ResultProcessorType[Sequence[typing_Any]]: + def result_processor(self, dialect, coltype): item_proc = self.item_type.dialect_impl(dialect).result_processor( dialect, coltype ) - def process( - value: Sequence[typing_Any], - ) -> Optional[Sequence[typing_Any]]: + def process(value): if value is None: return value else: @@ -464,13 +383,11 @@ class ARRAY(sqltypes.ARRAY[_T]): super_rp = process pattern = re.compile(r"^{(.*)}$") - def handle_raw_string(value: str) -> list[str]: - inner = pattern.match(value).group(1) # type: ignore[union-attr] # noqa: E501 + def handle_raw_string(value): + inner = pattern.match(value).group(1) return _split_enum_values(inner) - def process( - value: Sequence[typing_Any], - ) -> Optional[Sequence[typing_Any]]: + def process(value): if value is None: return value # isinstance(value, str) is required to handle @@ -485,7 +402,7 @@ class ARRAY(sqltypes.ARRAY[_T]): return process -def _split_enum_values(array_string: str) -> list[str]: +def _split_enum_values(array_string): if '"' not in array_string: # no escape char is present so it can just split on the comma return array_string.split(",") if array_string else [] diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/asyncpg.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/asyncpg.py index 5b3073a..ca35bf9 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/asyncpg.py @@ -1,5 +1,5 @@ -# dialects/postgresql/asyncpg.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # This module is part of SQLAlchemy and is released under @@ -23,10 +23,18 @@ This dialect should normally be used only with the :func:`_asyncio.create_async_engine` engine creation function:: from sqlalchemy.ext.asyncio import create_async_engine + engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname") + +The dialect can also be run as a "synchronous" dialect within the +:func:`_sa.create_engine` function, which will pass "await" calls into +an 
ad-hoc event loop. This mode of operation is of **limited use** +and is for special testing scenarios only. The mode can be enabled by +adding the SQLAlchemy-specific flag ``async_fallback`` to the URL +in conjunction with :func:`_sa.create_engine`:: + + # for testing purposes only; do not use in production! + engine = create_engine("postgresql+asyncpg://user:pass@hostname/dbname?async_fallback=true") - engine = create_async_engine( - "postgresql+asyncpg://user:pass@hostname/dbname" - ) .. versionadded:: 1.4 @@ -81,15 +89,11 @@ asyncpg dialect, therefore is handled as a DBAPI argument, not a dialect argument):: - engine = create_async_engine( - "postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500" - ) + engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500") To disable the prepared statement cache, use a value of zero:: - engine = create_async_engine( - "postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0" - ) + engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0") .. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg. @@ -119,8 +123,8 @@ To disable the prepared statement cache, use a value of zero:: .. _asyncpg_prepared_statement_name: -Prepared Statement Name with PGBouncer --------------------------------------- +Prepared Statement Name +----------------------- By default, asyncpg enumerates prepared statements in numeric order, which can lead to errors if a name has already been taken for another prepared @@ -135,10 +139,10 @@ a prepared statement is prepared:: from uuid import uuid4 engine = create_async_engine( - "postgresql+asyncpg://user:pass@somepgbouncer/dbname", + "postgresql+asyncpg://user:pass@hostname/dbname", poolclass=NullPool, connect_args={ - "prepared_statement_name_func": lambda: f"__asyncpg_{uuid4()}__", + 'prepared_statement_name_func': lambda: f'__asyncpg_{uuid4()}__', }, ) @@ -148,7 +152,7 @@ a prepared statement is prepared:: https://github.com/sqlalchemy/sqlalchemy/issues/6467 -.. warning:: When using PGBouncer, to prevent a buildup of useless prepared statements in +.. warning:: To prevent a buildup of useless prepared statements in your application, it's important to use the :class:`.NullPool` pool class, and to configure PgBouncer to use `DISCARD `_ when returning connections. The DISCARD command is used to release resources held by the db connection, @@ -178,11 +182,13 @@ client using this setting passed to :func:`_asyncio.create_async_engine`:: from __future__ import annotations -from collections import deque +import collections import decimal import json as _py_json import re import time +from typing import cast +from typing import TYPE_CHECKING from . import json from . 
import ranges @@ -212,6 +218,9 @@ from ...util.concurrency import asyncio from ...util.concurrency import await_fallback from ...util.concurrency import await_only +if TYPE_CHECKING: + from typing import Iterable + class AsyncpgARRAY(PGARRAY): render_bind_cast = True @@ -265,20 +274,20 @@ class AsyncpgInteger(sqltypes.Integer): render_bind_cast = True -class AsyncpgSmallInteger(sqltypes.SmallInteger): - render_bind_cast = True - - class AsyncpgBigInteger(sqltypes.BigInteger): render_bind_cast = True class AsyncpgJSON(json.JSON): + render_bind_cast = True + def result_processor(self, dialect, coltype): return None class AsyncpgJSONB(json.JSONB): + render_bind_cast = True + def result_processor(self, dialect, coltype): return None @@ -363,7 +372,7 @@ class AsyncpgCHAR(sqltypes.CHAR): render_bind_cast = True -class _AsyncpgRange(ranges.AbstractSingleRangeImpl): +class _AsyncpgRange(ranges.AbstractRangeImpl): def bind_processor(self, dialect): asyncpg_Range = dialect.dbapi.asyncpg.Range @@ -417,7 +426,10 @@ class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl): ) return value - return [to_range(element) for element in value] + return [ + to_range(element) + for element in cast("Iterable[ranges.Range]", value) + ] return to_range @@ -436,7 +448,7 @@ class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl): return rvalue if value is not None: - value = ranges.MultiRange(to_range(elem) for elem in value) + value = [to_range(elem) for elem in value] return value @@ -494,7 +506,7 @@ class AsyncAdapt_asyncpg_cursor: def __init__(self, adapt_connection): self._adapt_connection = adapt_connection self._connection = adapt_connection._connection - self._rows = deque() + self._rows = [] self._cursor = None self.description = None self.arraysize = 1 @@ -502,7 +514,7 @@ class AsyncAdapt_asyncpg_cursor: self._invalidate_schema_cache_asof = 0 def close(self): - self._rows.clear() + self._rows[:] = [] def _handle_exception(self, error): self._adapt_connection._handle_exception(error) @@ -542,12 +554,11 @@ class AsyncAdapt_asyncpg_cursor: self._cursor = await prepared_stmt.cursor(*parameters) self.rowcount = -1 else: - self._rows = deque(await prepared_stmt.fetch(*parameters)) + self._rows = await prepared_stmt.fetch(*parameters) status = prepared_stmt.get_statusmsg() reg = re.match( - r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)", - status or "", + r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)", status ) if reg: self.rowcount = int(reg.group(1)) @@ -591,11 +602,11 @@ class AsyncAdapt_asyncpg_cursor: def __iter__(self): while self._rows: - yield self._rows.popleft() + yield self._rows.pop(0) def fetchone(self): if self._rows: - return self._rows.popleft() + return self._rows.pop(0) else: return None @@ -603,12 +614,13 @@ class AsyncAdapt_asyncpg_cursor: if size is None: size = self.arraysize - rr = self._rows - return [rr.popleft() for _ in range(min(size, len(rr)))] + retval = self._rows[0:size] + self._rows[:] = self._rows[size:] + return retval def fetchall(self): - retval = list(self._rows) - self._rows.clear() + retval = self._rows[:] + self._rows[:] = [] return retval @@ -618,21 +630,23 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): def __init__(self, adapt_connection): super().__init__(adapt_connection) - self._rowbuffer = deque() + self._rowbuffer = None def close(self): self._cursor = None - self._rowbuffer.clear() + self._rowbuffer = None def _buffer_rows(self): - assert self._cursor is not None new_rows = self._adapt_connection.await_(self._cursor.fetch(50)) - 
self._rowbuffer.extend(new_rows) + self._rowbuffer = collections.deque(new_rows) def __aiter__(self): return self async def __anext__(self): + if not self._rowbuffer: + self._buffer_rows() + while True: while self._rowbuffer: yield self._rowbuffer.popleft() @@ -655,19 +669,21 @@ class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): if not self._rowbuffer: self._buffer_rows() - assert self._cursor is not None - rb = self._rowbuffer - lb = len(rb) + buf = list(self._rowbuffer) + lb = len(buf) if size > lb: - rb.extend( + buf.extend( self._adapt_connection.await_(self._cursor.fetch(size - lb)) ) - return [rb.popleft() for _ in range(min(size, len(rb)))] + result = buf[0:size] + self._rowbuffer = collections.deque(buf[size:]) + return result def fetchall(self): - ret = list(self._rowbuffer) - ret.extend(self._adapt_connection.await_(self._all())) + ret = list(self._rowbuffer) + list( + self._adapt_connection.await_(self._all()) + ) self._rowbuffer.clear() return ret @@ -717,7 +733,7 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection): ): self.dbapi = dbapi self._connection = connection - self.isolation_level = self._isolation_setting = None + self.isolation_level = self._isolation_setting = "read_committed" self.readonly = False self.deferrable = False self._transaction = None @@ -786,9 +802,9 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection): translated_error = exception_mapping[super_]( "%s: %s" % (type(error), error) ) - translated_error.pgcode = translated_error.sqlstate = ( - getattr(error, "sqlstate", None) - ) + translated_error.pgcode = ( + translated_error.sqlstate + ) = getattr(error, "sqlstate", None) raise translated_error from error else: raise error @@ -852,45 +868,25 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection): else: return AsyncAdapt_asyncpg_cursor(self) - async def _rollback_and_discard(self): - try: - await self._transaction.rollback() - finally: - # if asyncpg .rollback() was actually called, then whether or - # not it raised or succeeded, the transation is done, discard it - self._transaction = None - self._started = False - - async def _commit_and_discard(self): - try: - await self._transaction.commit() - finally: - # if asyncpg .commit() was actually called, then whether or - # not it raised or succeeded, the transation is done, discard it - self._transaction = None - self._started = False - def rollback(self): if self._started: try: - self.await_(self._rollback_and_discard()) + self.await_(self._transaction.rollback()) + except Exception as error: + self._handle_exception(error) + finally: self._transaction = None self._started = False - except Exception as error: - # don't dereference asyncpg transaction if we didn't - # actually try to call rollback() on it - self._handle_exception(error) def commit(self): if self._started: try: - self.await_(self._commit_and_discard()) + self.await_(self._transaction.commit()) + except Exception as error: + self._handle_exception(error) + finally: self._transaction = None self._started = False - except Exception as error: - # don't dereference asyncpg transaction if we didn't - # actually try to call commit() on it - self._handle_exception(error) def close(self): self.rollback() @@ -898,31 +894,7 @@ class AsyncAdapt_asyncpg_connection(AdaptedConnection): self.await_(self._connection.close()) def terminate(self): - if util.concurrency.in_greenlet(): - # in a greenlet; this is the connection was invalidated - # case. 
- try: - # try to gracefully close; see #10717 - # timeout added in asyncpg 0.14.0 December 2017 - self.await_(asyncio.shield(self._connection.close(timeout=2))) - except ( - asyncio.TimeoutError, - asyncio.CancelledError, - OSError, - self.dbapi.asyncpg.PostgresError, - ) as e: - # in the case where we are recycling an old connection - # that may have already been disconnected, close() will - # fail with the above timeout. in this case, terminate - # the connection without any further waiting. - # see issue #8419 - self._connection.terminate() - if isinstance(e, asyncio.CancelledError): - # re-raise CancelledError if we were cancelled - raise - else: - # not in a greenlet; this is the gc cleanup case - self._connection.terminate() + self._connection.terminate() self._started = False @staticmethod @@ -1059,7 +1031,6 @@ class PGDialect_asyncpg(PGDialect): INTERVAL: AsyncPgInterval, sqltypes.Boolean: AsyncpgBoolean, sqltypes.Integer: AsyncpgInteger, - sqltypes.SmallInteger: AsyncpgSmallInteger, sqltypes.BigInteger: AsyncpgBigInteger, sqltypes.Numeric: AsyncpgNumeric, sqltypes.Float: AsyncpgFloat, @@ -1074,7 +1045,7 @@ class PGDialect_asyncpg(PGDialect): OID: AsyncpgOID, REGCLASS: AsyncpgREGCLASS, sqltypes.CHAR: AsyncpgCHAR, - ranges.AbstractSingleRange: _AsyncpgRange, + ranges.AbstractRange: _AsyncpgRange, ranges.AbstractMultiRange: _AsyncpgMultiRange, }, ) @@ -1117,9 +1088,6 @@ class PGDialect_asyncpg(PGDialect): def set_isolation_level(self, dbapi_connection, level): dbapi_connection.set_isolation_level(self._isolation_lookup[level]) - def detect_autocommit_setting(self, dbapi_conn) -> bool: - return bool(dbapi_conn.autocommit) - def set_readonly(self, connection, value): connection.readonly = value diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/base.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/base.py index 25570c2..b9fd8c8 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/base.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/base.py @@ -1,5 +1,5 @@ -# dialects/postgresql/base.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# postgresql/base.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -9,6 +9,7 @@ r""" .. dialect:: postgresql :name: PostgreSQL + :full_support: 12, 13, 14, 15 :normal_support: 9.6+ :best_effort: 9+ @@ -31,7 +32,7 @@ use the :func:`~sqlalchemy.schema.Sequence` construct:: metadata, Column( "id", Integer, Sequence("some_id_seq", start=1), primary_key=True - ), + ) ) When SQLAlchemy issues a single INSERT statement, to fulfill the contract of @@ -63,9 +64,9 @@ of SERIAL. 
The :class:`_schema.Identity` construct in a "data", metadata, Column( - "id", Integer, Identity(start=42, cycle=True), primary_key=True + 'id', Integer, Identity(start=42, cycle=True), primary_key=True ), - Column("data", String), + Column('data', String) ) The CREATE TABLE for the above :class:`_schema.Table` object would be: @@ -92,21 +93,23 @@ The CREATE TABLE for the above :class:`_schema.Table` object would be: from sqlalchemy.ext.compiler import compiles - @compiles(CreateColumn, "postgresql") + @compiles(CreateColumn, 'postgresql') def use_identity(element, compiler, **kw): text = compiler.visit_create_column(element, **kw) - text = text.replace("SERIAL", "INT GENERATED BY DEFAULT AS IDENTITY") + text = text.replace( + "SERIAL", "INT GENERATED BY DEFAULT AS IDENTITY" + ) return text Using the above, a table such as:: t = Table( - "t", m, Column("id", Integer, primary_key=True), Column("data", String) + 't', m, + Column('id', Integer, primary_key=True), + Column('data', String) ) - Will generate on the backing database as: - - .. sourcecode:: sql + Will generate on the backing database as:: CREATE TABLE t ( id INT GENERATED BY DEFAULT AS IDENTITY, @@ -127,9 +130,7 @@ Server side cursors are enabled on a per-statement basis by using the option:: with engine.connect() as conn: - result = conn.execution_options(stream_results=True).execute( - text("select * from table") - ) + result = conn.execution_options(stream_results=True).execute(text("select * from table")) Note that some kinds of SQL statements may not be supported with server side cursors; generally, only SQL statements that return rows should be @@ -168,15 +169,17 @@ To set isolation level using :func:`_sa.create_engine`:: engine = create_engine( "postgresql+pg8000://scott:tiger@localhost/test", - isolation_level="REPEATABLE READ", + isolation_level = "REPEATABLE READ" ) To set using per-connection execution options:: with engine.connect() as conn: - conn = conn.execution_options(isolation_level="REPEATABLE READ") + conn = conn.execution_options( + isolation_level="REPEATABLE READ" + ) with conn.begin(): - ... # work with transaction + # ... work with transaction There are also more options for isolation level configurations, such as "sub-engine" objects linked to a main :class:`_engine.Engine` which each apply @@ -219,10 +222,10 @@ passing the ``"SERIALIZABLE"`` isolation level at the same time as setting conn = conn.execution_options( isolation_level="SERIALIZABLE", postgresql_readonly=True, - postgresql_deferrable=True, + postgresql_deferrable=True ) with conn.begin(): - ... # work with transaction + # ... work with transaction Note that some DBAPIs such as asyncpg only support "readonly" with SERIALIZABLE isolation. @@ -266,7 +269,8 @@ will remain consistent with the state of the transaction:: from sqlalchemy import event postgresql_engine = create_engine( - "postgresql+psycopg2://scott:tiger@hostname/dbname", + "postgresql+pyscopg2://scott:tiger@hostname/dbname", + # disable default reset-on-return scheme pool_reset_on_return=None, ) @@ -313,7 +317,6 @@ at :ref:`schema_set_default_connections`:: engine = create_engine("postgresql+psycopg2://scott:tiger@host/dbname") - @event.listens_for(engine, "connect", insert=True) def set_search_path(dbapi_connection, connection_record): existing_autocommit = dbapi_connection.autocommit @@ -332,6 +335,9 @@ be reverted when the DBAPI connection has a rollback. :ref:`schema_set_default_connections` - in the :ref:`metadata_toplevel` documentation + + + .. 
_postgresql_schema_reflection: Remote-Schema Table Introspection and PostgreSQL search_path @@ -340,9 +346,7 @@ Remote-Schema Table Introspection and PostgreSQL search_path .. admonition:: Section Best Practices Summarized keep the ``search_path`` variable set to its default of ``public``, without - any other schema names. Ensure the username used to connect **does not** - match remote schemas, or ensure the ``"$user"`` token is **removed** from - ``search_path``. For other schema names, name these explicitly + any other schema names. For other schema names, name these explicitly within :class:`_schema.Table` definitions. Alternatively, the ``postgresql_ignore_search_path`` option will cause all reflected :class:`_schema.Table` objects to have a :attr:`_schema.Table.schema` @@ -351,78 +355,19 @@ Remote-Schema Table Introspection and PostgreSQL search_path The PostgreSQL dialect can reflect tables from any schema, as outlined in :ref:`metadata_reflection_schemas`. -In all cases, the first thing SQLAlchemy does when reflecting tables is -to **determine the default schema for the current database connection**. -It does this using the PostgreSQL ``current_schema()`` -function, illustated below using a PostgreSQL client session (i.e. using -the ``psql`` tool): - -.. sourcecode:: sql - - test=> select current_schema(); - current_schema - ---------------- - public - (1 row) - -Above we see that on a plain install of PostgreSQL, the default schema name -is the name ``public``. - -However, if your database username **matches the name of a schema**, PostgreSQL's -default is to then **use that name as the default schema**. Below, we log in -using the username ``scott``. When we create a schema named ``scott``, **it -implicitly changes the default schema**: - -.. sourcecode:: sql - - test=> select current_schema(); - current_schema - ---------------- - public - (1 row) - - test=> create schema scott; - CREATE SCHEMA - test=> select current_schema(); - current_schema - ---------------- - scott - (1 row) - -The behavior of ``current_schema()`` is derived from the -`PostgreSQL search path -`_ -variable ``search_path``, which in modern PostgreSQL versions defaults to this: - -.. sourcecode:: sql - - test=> show search_path; - search_path - ----------------- - "$user", public - (1 row) - -Where above, the ``"$user"`` variable will inject the current username as the -default schema, if one exists. Otherwise, ``public`` is used. - -When a :class:`_schema.Table` object is reflected, if it is present in the -schema indicated by the ``current_schema()`` function, **the schema name assigned -to the ".schema" attribute of the Table is the Python "None" value**. Otherwise, the -".schema" attribute will be assigned the string name of that schema. - With regards to tables which these :class:`_schema.Table` objects refer to via foreign key constraint, a decision must be made as to how the ``.schema`` is represented in those remote tables, in the case where that -remote schema name is also a member of the current ``search_path``. +remote schema name is also a member of the current +`PostgreSQL search path +`_. By default, the PostgreSQL dialect mimics the behavior encouraged by PostgreSQL's own ``pg_get_constraintdef()`` builtin procedure. This function returns a sample definition for a particular foreign key constraint, omitting the referenced schema name from that definition when the name is also in the PostgreSQL schema search path. The interaction below -illustrates this behavior: - -.. 
sourcecode:: sql +illustrates this behavior:: test=> CREATE TABLE test_schema.referred(id INTEGER PRIMARY KEY); CREATE TABLE @@ -449,17 +394,13 @@ PG ``search_path`` and then asked ``pg_get_constraintdef()`` for the the function. On the other hand, if we set the search path back to the typical default -of ``public``: - -.. sourcecode:: sql +of ``public``:: test=> SET search_path TO public; SET The same query against ``pg_get_constraintdef()`` now returns the fully -schema-qualified name for us: - -.. sourcecode:: sql +schema-qualified name for us:: test=> SELECT pg_catalog.pg_get_constraintdef(r.oid, true) FROM test-> pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n @@ -481,14 +422,16 @@ reflection process as follows:: >>> with engine.connect() as conn: ... conn.execute(text("SET search_path TO test_schema, public")) ... metadata_obj = MetaData() - ... referring = Table("referring", metadata_obj, autoload_with=conn) + ... referring = Table('referring', metadata_obj, + ... autoload_with=conn) + ... The above process would deliver to the :attr:`_schema.MetaData.tables` collection ``referred`` table named **without** the schema:: - >>> metadata_obj.tables["referred"].schema is None + >>> metadata_obj.tables['referred'].schema is None True To alter the behavior of reflection such that the referred schema is @@ -500,17 +443,15 @@ dialect-specific argument to both :class:`_schema.Table` as well as >>> with engine.connect() as conn: ... conn.execute(text("SET search_path TO test_schema, public")) ... metadata_obj = MetaData() - ... referring = Table( - ... "referring", - ... metadata_obj, - ... autoload_with=conn, - ... postgresql_ignore_search_path=True, - ... ) + ... referring = Table('referring', metadata_obj, + ... autoload_with=conn, + ... postgresql_ignore_search_path=True) + ... We will now have ``test_schema.referred`` stored as schema-qualified:: - >>> metadata_obj.tables["test_schema.referred"].schema + >>> metadata_obj.tables['test_schema.referred'].schema 'test_schema' .. sidebar:: Best Practices for PostgreSQL Schema reflection @@ -525,6 +466,13 @@ We will now have ``test_schema.referred`` stored as schema-qualified:: described here are only for those users who can't, or prefer not to, stay within these guidelines. +Note that **in all cases**, the "default" schema is always reflected as +``None``. The "default" schema on PostgreSQL is that which is returned by the +PostgreSQL ``current_schema()`` function. On a typical PostgreSQL +installation, this is the name ``public``. So a table that refers to another +which is in the ``public`` (i.e. default) schema will always have the +``.schema`` attribute set to ``None``. + .. seealso:: :ref:`reflection_schema_qualified_interaction` - discussion of the issue @@ -544,26 +492,18 @@ primary key identifiers. 
To specify an explicit ``RETURNING`` clause, use the :meth:`._UpdateBase.returning` method on a per-statement basis:: # INSERT..RETURNING - result = ( - table.insert().returning(table.c.col1, table.c.col2).values(name="foo") - ) + result = table.insert().returning(table.c.col1, table.c.col2).\ + values(name='foo') print(result.fetchall()) # UPDATE..RETURNING - result = ( - table.update() - .returning(table.c.col1, table.c.col2) - .where(table.c.name == "foo") - .values(name="bar") - ) + result = table.update().returning(table.c.col1, table.c.col2).\ + where(table.c.name=='foo').values(name='bar') print(result.fetchall()) # DELETE..RETURNING - result = ( - table.delete() - .returning(table.c.col1, table.c.col2) - .where(table.c.name == "foo") - ) + result = table.delete().returning(table.c.col1, table.c.col2).\ + where(table.c.name=='foo') print(result.fetchall()) .. _postgresql_insert_on_conflict: @@ -593,16 +533,19 @@ and :meth:`~.postgresql.Insert.on_conflict_do_nothing`: >>> from sqlalchemy.dialects.postgresql import insert >>> insert_stmt = insert(my_table).values( - ... id="some_existing_id", data="inserted value" + ... id='some_existing_id', + ... data='inserted value') + >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing( + ... index_elements=['id'] ... ) - >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["id"]) >>> print(do_nothing_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) ON CONFLICT (id) DO NOTHING {stop} >>> do_update_stmt = insert_stmt.on_conflict_do_update( - ... constraint="pk_my_table", set_=dict(data="updated value") + ... constraint='pk_my_table', + ... set_=dict(data='updated value') ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) @@ -628,7 +571,8 @@ named constraint or by column inference: .. sourcecode:: pycon+sql >>> do_update_stmt = insert_stmt.on_conflict_do_update( - ... index_elements=["id"], set_=dict(data="updated value") + ... index_elements=['id'], + ... set_=dict(data='updated value') ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) @@ -636,7 +580,8 @@ named constraint or by column inference: {stop} >>> do_update_stmt = insert_stmt.on_conflict_do_update( - ... index_elements=[my_table.c.id], set_=dict(data="updated value") + ... index_elements=[my_table.c.id], + ... set_=dict(data='updated value') ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) @@ -648,11 +593,11 @@ named constraint or by column inference: .. sourcecode:: pycon+sql - >>> stmt = insert(my_table).values(user_email="a@b.com", data="inserted data") + >>> stmt = insert(my_table).values(user_email='a@b.com', data='inserted data') >>> stmt = stmt.on_conflict_do_update( ... index_elements=[my_table.c.user_email], - ... index_where=my_table.c.user_email.like("%@gmail.com"), - ... set_=dict(data=stmt.excluded.data), + ... index_where=my_table.c.user_email.like('%@gmail.com'), + ... set_=dict(data=stmt.excluded.data) ... ) >>> print(stmt) {printsql}INSERT INTO my_table (data, user_email) @@ -666,7 +611,8 @@ named constraint or by column inference: .. sourcecode:: pycon+sql >>> do_update_stmt = insert_stmt.on_conflict_do_update( - ... constraint="my_table_idx_1", set_=dict(data="updated value") + ... constraint='my_table_idx_1', + ... set_=dict(data='updated value') ... 
) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) @@ -674,7 +620,8 @@ named constraint or by column inference: {stop} >>> do_update_stmt = insert_stmt.on_conflict_do_update( - ... constraint="my_table_pk", set_=dict(data="updated value") + ... constraint='my_table_pk', + ... set_=dict(data='updated value') ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) @@ -696,7 +643,8 @@ named constraint or by column inference: .. sourcecode:: pycon+sql >>> do_update_stmt = insert_stmt.on_conflict_do_update( - ... constraint=my_table.primary_key, set_=dict(data="updated value") + ... constraint=my_table.primary_key, + ... set_=dict(data='updated value') ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) @@ -714,9 +662,10 @@ for UPDATE: .. sourcecode:: pycon+sql - >>> stmt = insert(my_table).values(id="some_id", data="inserted value") + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') >>> do_update_stmt = stmt.on_conflict_do_update( - ... index_elements=["id"], set_=dict(data="updated value") + ... index_elements=['id'], + ... set_=dict(data='updated value') ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) @@ -745,11 +694,13 @@ table: .. sourcecode:: pycon+sql >>> stmt = insert(my_table).values( - ... id="some_id", data="inserted value", author="jlh" + ... id='some_id', + ... data='inserted value', + ... author='jlh' ... ) >>> do_update_stmt = stmt.on_conflict_do_update( - ... index_elements=["id"], - ... set_=dict(data="updated value", author=stmt.excluded.author), + ... index_elements=['id'], + ... set_=dict(data='updated value', author=stmt.excluded.author) ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data, author) @@ -766,12 +717,14 @@ parameter, which will limit those rows which receive an UPDATE: .. sourcecode:: pycon+sql >>> stmt = insert(my_table).values( - ... id="some_id", data="inserted value", author="jlh" + ... id='some_id', + ... data='inserted value', + ... author='jlh' ... ) >>> on_update_stmt = stmt.on_conflict_do_update( - ... index_elements=["id"], - ... set_=dict(data="updated value", author=stmt.excluded.author), - ... where=(my_table.c.status == 2), + ... index_elements=['id'], + ... set_=dict(data='updated value', author=stmt.excluded.author), + ... where=(my_table.c.status == 2) ... ) >>> print(on_update_stmt) {printsql}INSERT INTO my_table (id, data, author) @@ -789,8 +742,8 @@ this is illustrated using the .. sourcecode:: pycon+sql - >>> stmt = insert(my_table).values(id="some_id", data="inserted value") - >>> stmt = stmt.on_conflict_do_nothing(index_elements=["id"]) + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = stmt.on_conflict_do_nothing(index_elements=['id']) >>> print(stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) ON CONFLICT (id) DO NOTHING @@ -801,7 +754,7 @@ constraint violation which occurs: .. sourcecode:: pycon+sql - >>> stmt = insert(my_table).values(id="some_id", data="inserted value") + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') >>> stmt = stmt.on_conflict_do_nothing() >>> print(stmt) {printsql}INSERT INTO my_table (id, data) VALUES (%(id)s, %(data)s) @@ -832,9 +785,7 @@ On the PostgreSQL dialect, an expression like the following:: select(sometable.c.text.match("search string")) -would emit to the database: - -.. 
sourcecode:: sql +would emit to the database:: SELECT text @@ plainto_tsquery('search string') FROM table @@ -850,11 +801,11 @@ with other backends. from sqlalchemy import func - select(sometable.c.text.bool_op("@@")(func.to_tsquery("search string"))) + select( + sometable.c.text.bool_op("@@")(func.to_tsquery("search string")) + ) - Which would emit: - - .. sourcecode:: sql + Which would emit:: SELECT text @@ to_tsquery('search string') FROM table @@ -868,7 +819,9 @@ any boolean operator. For example, the query:: - select(func.to_tsquery("cat").bool_op("@>")(func.to_tsquery("cat & rat"))) + select( + func.to_tsquery('cat').bool_op("@>")(func.to_tsquery('cat & rat')) + ) would generate: @@ -881,12 +834,9 @@ The :class:`_postgresql.TSVECTOR` type can provide for explicit CAST:: from sqlalchemy.dialects.postgresql import TSVECTOR from sqlalchemy import select, cast - select(cast("some text", TSVECTOR)) -produces a statement equivalent to: - -.. sourcecode:: sql +produces a statement equivalent to:: SELECT CAST('some text' AS TSVECTOR) AS anon_1 @@ -914,12 +864,10 @@ When using :meth:`.Operators.match`, this additional parameter may be specified using the ``postgresql_regconfig`` parameter, such as:: select(mytable.c.id).where( - mytable.c.title.match("somestring", postgresql_regconfig="english") + mytable.c.title.match('somestring', postgresql_regconfig='english') ) -Which would emit: - -.. sourcecode:: sql +Which would emit:: SELECT mytable.id FROM mytable WHERE mytable.title @@ plainto_tsquery('english', 'somestring') @@ -933,9 +881,7 @@ When using other PostgreSQL search functions with :data:`.func`, the ) ) -produces a statement equivalent to: - -.. sourcecode:: sql +produces a statement equivalent to:: SELECT mytable.id FROM mytable WHERE to_tsvector('english', mytable.title) @@ @@ -959,16 +905,16 @@ table in an inheritance hierarchy. This can be used to produce the syntaxes. It uses SQLAlchemy's hints mechanism:: # SELECT ... FROM ONLY ... - result = table.select().with_hint(table, "ONLY", "postgresql") + result = table.select().with_hint(table, 'ONLY', 'postgresql') print(result.fetchall()) # UPDATE ONLY ... - table.update(values=dict(foo="bar")).with_hint( - "ONLY", dialect_name="postgresql" - ) + table.update(values=dict(foo='bar')).with_hint('ONLY', + dialect_name='postgresql') # DELETE FROM ONLY ... - table.delete().with_hint("ONLY", dialect_name="postgresql") + table.delete().with_hint('ONLY', dialect_name='postgresql') + .. _postgresql_indexes: @@ -978,24 +924,18 @@ PostgreSQL-Specific Index Options Several extensions to the :class:`.Index` construct are available, specific to the PostgreSQL dialect. -.. _postgresql_covering_indexes: - Covering Indexes ^^^^^^^^^^^^^^^^ The ``postgresql_include`` option renders INCLUDE(colname) for the given string names:: - Index("my_index", table.c.x, postgresql_include=["y"]) + Index("my_index", table.c.x, postgresql_include=['y']) would render the index as ``CREATE INDEX my_index ON table (x) INCLUDE (y)`` Note that this feature requires PostgreSQL 11 or later. -.. seealso:: - - :ref:`postgresql_constraint_options` - .. versionadded:: 1.4 .. _postgresql_partial_indexes: @@ -1007,7 +947,7 @@ Partial indexes add criterion to the index definition so that the index is applied to a subset of rows. These can be specified on :class:`.Index` using the ``postgresql_where`` keyword argument:: - Index("my_index", my_table.c.id, postgresql_where=my_table.c.value > 10) + Index('my_index', my_table.c.id, postgresql_where=my_table.c.value > 10) .. 
_postgresql_operator_classes: @@ -1021,11 +961,11 @@ The :class:`.Index` construct allows these to be specified via the ``postgresql_ops`` keyword argument:: Index( - "my_index", - my_table.c.id, - my_table.c.data, - postgresql_ops={"data": "text_pattern_ops", "id": "int4_ops"}, - ) + 'my_index', my_table.c.id, my_table.c.data, + postgresql_ops={ + 'data': 'text_pattern_ops', + 'id': 'int4_ops' + }) Note that the keys in the ``postgresql_ops`` dictionaries are the "key" name of the :class:`_schema.Column`, i.e. the name used to access it from @@ -1037,11 +977,12 @@ as a function call, then to apply to the column it must be given a label that is identified in the dictionary by name, e.g.:: Index( - "my_index", - my_table.c.id, - func.lower(my_table.c.data).label("data_lower"), - postgresql_ops={"data_lower": "text_pattern_ops", "id": "int4_ops"}, - ) + 'my_index', my_table.c.id, + func.lower(my_table.c.data).label('data_lower'), + postgresql_ops={ + 'data_lower': 'text_pattern_ops', + 'id': 'int4_ops' + }) Operator classes are also supported by the :class:`_postgresql.ExcludeConstraint` construct using the @@ -1060,7 +1001,7 @@ as the ability for users to create their own (see https://www.postgresql.org/docs/current/static/indexes-types.html). These can be specified on :class:`.Index` using the ``postgresql_using`` keyword argument:: - Index("my_index", my_table.c.data, postgresql_using="gin") + Index('my_index', my_table.c.data, postgresql_using='gin') The value passed to the keyword argument will be simply passed through to the underlying CREATE INDEX command, so it *must* be a valid index type for your @@ -1076,13 +1017,13 @@ parameters available depend on the index method used by the index. Storage parameters can be specified on :class:`.Index` using the ``postgresql_with`` keyword argument:: - Index("my_index", my_table.c.data, postgresql_with={"fillfactor": 50}) + Index('my_index', my_table.c.data, postgresql_with={"fillfactor": 50}) PostgreSQL allows to define the tablespace in which to create the index. The tablespace can be specified on :class:`.Index` using the ``postgresql_tablespace`` keyword argument:: - Index("my_index", my_table.c.data, postgresql_tablespace="my_tablespace") + Index('my_index', my_table.c.data, postgresql_tablespace='my_tablespace') Note that the same option is available on :class:`_schema.Table` as well. @@ -1094,21 +1035,17 @@ Indexes with CONCURRENTLY The PostgreSQL index option CONCURRENTLY is supported by passing the flag ``postgresql_concurrently`` to the :class:`.Index` construct:: - tbl = Table("testtbl", m, Column("data", Integer)) + tbl = Table('testtbl', m, Column('data', Integer)) - idx1 = Index("test_idx1", tbl.c.data, postgresql_concurrently=True) + idx1 = Index('test_idx1', tbl.c.data, postgresql_concurrently=True) The above index construct will render DDL for CREATE INDEX, assuming -PostgreSQL 8.2 or higher is detected or for a connection-less dialect, as: - -.. sourcecode:: sql +PostgreSQL 8.2 or higher is detected or for a connection-less dialect, as:: CREATE INDEX CONCURRENTLY test_idx1 ON testtbl (data) For DROP INDEX, assuming PostgreSQL 9.2 or higher is detected or for -a connection-less dialect, it will emit: - -.. 
sourcecode:: sql +a connection-less dialect, it will emit:: DROP INDEX CONCURRENTLY test_idx1 @@ -1118,11 +1055,14 @@ even for a single statement, a transaction is present, so to use this construct, the DBAPI's "autocommit" mode must be used:: metadata = MetaData() - table = Table("foo", metadata, Column("id", String)) - index = Index("foo_idx", table.c.id, postgresql_concurrently=True) + table = Table( + "foo", metadata, + Column("id", String)) + index = Index( + "foo_idx", table.c.id, postgresql_concurrently=True) with engine.connect() as conn: - with conn.execution_options(isolation_level="AUTOCOMMIT"): + with conn.execution_options(isolation_level='AUTOCOMMIT'): table.create(conn) .. seealso:: @@ -1172,41 +1112,15 @@ PostgreSQL Table Options Several options for CREATE TABLE are supported directly by the PostgreSQL dialect in conjunction with the :class:`_schema.Table` construct: -* ``INHERITS``:: +* ``TABLESPACE``:: - Table("some_table", metadata, ..., postgresql_inherits="some_supertable") - - Table("some_table", metadata, ..., postgresql_inherits=("t1", "t2", ...)) - -* ``ON COMMIT``:: - - Table("some_table", metadata, ..., postgresql_on_commit="PRESERVE ROWS") - -* - ``PARTITION BY``:: - - Table( - "some_table", - metadata, - ..., - postgresql_partition_by="LIST (part_column)", - ) - - .. versionadded:: 1.2.6 - -* - ``TABLESPACE``:: - - Table("some_table", metadata, ..., postgresql_tablespace="some_tablespace") + Table("some_table", metadata, ..., postgresql_tablespace='some_tablespace') The above option is also available on the :class:`.Index` construct. -* - ``USING``:: +* ``ON COMMIT``:: - Table("some_table", metadata, ..., postgresql_using="heap") - - .. versionadded:: 2.0.26 + Table("some_table", metadata, ..., postgresql_on_commit='PRESERVE ROWS') * ``WITH OIDS``:: @@ -1216,6 +1130,19 @@ dialect in conjunction with the :class:`_schema.Table` construct: Table("some_table", metadata, ..., postgresql_with_oids=False) +* ``INHERITS``:: + + Table("some_table", metadata, ..., postgresql_inherits="some_supertable") + + Table("some_table", metadata, ..., postgresql_inherits=("t1", "t2", ...)) + +* ``PARTITION BY``:: + + Table("some_table", metadata, ..., + postgresql_partition_by='LIST (part_column)') + + .. versionadded:: 1.2.6 + .. seealso:: `PostgreSQL CREATE TABLE options @@ -1247,7 +1174,7 @@ with selected constraint constructs: "user", ["user_id"], ["id"], - postgresql_not_valid=True, + postgresql_not_valid=True ) The keyword is ultimately accepted directly by the @@ -1258,9 +1185,7 @@ with selected constraint constructs: CheckConstraint("some_field IS NOT NULL", postgresql_not_valid=True) - ForeignKeyConstraint( - ["some_id"], ["some_table.some_id"], postgresql_not_valid=True - ) + ForeignKeyConstraint(["some_id"], ["some_table.some_id"], postgresql_not_valid=True) .. versionadded:: 1.4.32 @@ -1270,65 +1195,6 @@ with selected constraint constructs: `_ - in the PostgreSQL documentation. -* ``INCLUDE``: This option adds one or more columns as a "payload" to the - unique index created automatically by PostgreSQL for the constraint. - For example, the following table definition:: - - Table( - "mytable", - metadata, - Column("id", Integer, nullable=False), - Column("value", Integer, nullable=False), - UniqueConstraint("id", postgresql_include=["value"]), - ) - - would produce the DDL statement - - .. 
sourcecode:: sql - - CREATE TABLE mytable ( - id INTEGER NOT NULL, - value INTEGER NOT NULL, - UNIQUE (id) INCLUDE (value) - ) - - Note that this feature requires PostgreSQL 11 or later. - - .. versionadded:: 2.0.41 - - .. seealso:: - - :ref:`postgresql_covering_indexes` - - .. seealso:: - - `PostgreSQL CREATE TABLE options - `_ - - in the PostgreSQL documentation. - -* Column list with foreign key ``ON DELETE SET`` actions: This applies to - :class:`.ForeignKey` and :class:`.ForeignKeyConstraint`, the :paramref:`.ForeignKey.ondelete` - parameter will accept on the PostgreSQL backend only a string list of column - names inside parenthesis, following the ``SET NULL`` or ``SET DEFAULT`` - phrases, which will limit the set of columns that are subject to the - action:: - - fktable = Table( - "fktable", - metadata, - Column("tid", Integer), - Column("id", Integer), - Column("fk_id_del_set_null", Integer), - ForeignKeyConstraint( - columns=["tid", "fk_id_del_set_null"], - refcolumns=[pktable.c.tid, pktable.c.id], - ondelete="SET NULL (fk_id_del_set_null)", - ), - ) - - .. versionadded:: 2.0.40 - - .. _postgresql_table_valued_overview: Table values, Table and Column valued functions, Row and Tuple objects @@ -1362,9 +1228,7 @@ Examples from PostgreSQL's reference documentation follow below: .. sourcecode:: pycon+sql >>> from sqlalchemy import select, func - >>> stmt = select( - ... func.json_each('{"a":"foo", "b":"bar"}').table_valued("key", "value") - ... ) + >>> stmt = select(func.json_each('{"a":"foo", "b":"bar"}').table_valued("key", "value")) >>> print(stmt) {printsql}SELECT anon_1.key, anon_1.value FROM json_each(:json_each_1) AS anon_1 @@ -1376,7 +1240,8 @@ Examples from PostgreSQL's reference documentation follow below: >>> from sqlalchemy import select, func, literal_column >>> stmt = select( ... func.json_populate_record( - ... literal_column("null::myrowtype"), '{"a":1,"b":2}' + ... literal_column("null::myrowtype"), + ... '{"a":1,"b":2}' ... ).table_valued("a", "b", name="x") ... ) >>> print(stmt) @@ -1394,13 +1259,9 @@ Examples from PostgreSQL's reference documentation follow below: >>> from sqlalchemy import select, func, column, Integer, Text >>> stmt = select( - ... func.json_to_record('{"a":1,"b":[1,2,3],"c":"bar"}') - ... .table_valued( - ... column("a", Integer), - ... column("b", Text), - ... column("d", Text), - ... ) - ... .render_derived(name="x", with_types=True) + ... func.json_to_record('{"a":1,"b":[1,2,3],"c":"bar"}').table_valued( + ... column("a", Integer), column("b", Text), column("d", Text), + ... ).render_derived(name="x", with_types=True) ... ) >>> print(stmt) {printsql}SELECT x.a, x.b, x.d @@ -1417,9 +1278,9 @@ Examples from PostgreSQL's reference documentation follow below: >>> from sqlalchemy import select, func >>> stmt = select( - ... func.generate_series(4, 1, -1) - ... .table_valued("value", with_ordinality="ordinality") - ... .render_derived() + ... func.generate_series(4, 1, -1). + ... table_valued("value", with_ordinality="ordinality"). + ... render_derived() ... ) >>> print(stmt) {printsql}SELECT anon_1.value, anon_1.ordinality @@ -1448,9 +1309,7 @@ scalar value. PostgreSQL functions such as ``json_array_elements()``, .. sourcecode:: pycon+sql >>> from sqlalchemy import select, func - >>> stmt = select( - ... func.json_array_elements('["one", "two"]').column_valued("x") - ... 
) + >>> stmt = select(func.json_array_elements('["one", "two"]').column_valued("x")) >>> print(stmt) {printsql}SELECT x FROM json_array_elements(:json_array_elements_1) AS x @@ -1474,7 +1333,7 @@ scalar value. PostgreSQL functions such as ``json_array_elements()``, >>> from sqlalchemy import table, column, ARRAY, Integer >>> from sqlalchemy import select, func - >>> t = table("t", column("value", ARRAY(Integer))) + >>> t = table("t", column('value', ARRAY(Integer))) >>> stmt = select(func.unnest(t.c.value).column_valued("unnested_value")) >>> print(stmt) {printsql}SELECT unnested_value @@ -1496,10 +1355,10 @@ Built-in support for rendering a ``ROW`` may be approximated using >>> from sqlalchemy import table, column, func, tuple_ >>> t = table("t", column("id"), column("fk")) - >>> stmt = ( - ... t.select() - ... .where(tuple_(t.c.id, t.c.fk) > (1, 2)) - ... .where(func.ROW(t.c.id, t.c.fk) < func.ROW(3, 7)) + >>> stmt = t.select().where( + ... tuple_(t.c.id, t.c.fk) > (1,2) + ... ).where( + ... func.ROW(t.c.id, t.c.fk) < func.ROW(3, 7) ... ) >>> print(stmt) {printsql}SELECT t.id, t.fk @@ -1528,7 +1387,7 @@ itself: .. sourcecode:: pycon+sql >>> from sqlalchemy import table, column, func, select - >>> a = table("a", column("id"), column("x"), column("y")) + >>> a = table( "a", column("id"), column("x"), column("y")) >>> stmt = select(func.row_to_json(a.table_valued())) >>> print(stmt) {printsql}SELECT row_to_json(a) AS row_to_json_1 @@ -1547,20 +1406,19 @@ from functools import lru_cache import re from typing import Any from typing import cast -from typing import Dict from typing import List from typing import Optional from typing import Tuple from typing import TYPE_CHECKING from typing import Union -from . import arraylib as _array +from . import array as _array +from . import hstore as _hstore from . import json as _json from . import pg_catalog from . 
import ranges as _ranges from .ext import _regconfig_fn from .ext import aggregate_order_by -from .hstore import HSTORE from .named_types import CreateDomainType as CreateDomainType # noqa: F401 from .named_types import CreateEnumType as CreateEnumType # noqa: F401 from .named_types import DOMAIN as DOMAIN # noqa: F401 @@ -1738,7 +1596,6 @@ RESERVED_WORDS = { "verbose", } - colspecs = { sqltypes.ARRAY: _array.ARRAY, sqltypes.Interval: INTERVAL, @@ -1751,7 +1608,7 @@ colspecs = { ischema_names = { "_array": _array.ARRAY, - "hstore": HSTORE, + "hstore": _hstore.HSTORE, "json": _json.JSON, "jsonb": _json.JSONB, "int4range": _ranges.INT4RANGE, @@ -1849,14 +1706,12 @@ class PGCompiler(compiler.SQLCompiler): # see #9511 dbapi_type = sqltypes.STRINGTYPE return f"""{sqltext}::{ - self.dialect.type_compiler_instance.process( - dbapi_type, identifier_preparer=self.preparer - ) - }""" + self.dialect.type_compiler_instance.process( + dbapi_type, identifier_preparer=self.preparer + ) + }""" def visit_array(self, element, **kw): - if not element.clauses and not element.type.item_type._isnull: - return "ARRAY[]::%s" % element.type.compile(self.dialect) return "ARRAY[%s]" % self.visit_clauselist(element, **kw) def visit_slice(self, element, **kw): @@ -1880,23 +1735,9 @@ class PGCompiler(compiler.SQLCompiler): kw["eager_grouping"] = True - if ( - not _cast_applied - and isinstance(binary.left.type, _json.JSONB) - and self.dialect._supports_jsonb_subscripting - ): - # for pg14+JSONB use subscript notation: col['key'] instead - # of col -> 'key' - return "%s[%s]" % ( - self.process(binary.left, **kw), - self.process(binary.right, **kw), - ) - else: - # Fall back to arrow notation for older versions or when cast - # is applied - return self._generate_generic_binary( - binary, " -> " if not _cast_applied else " ->> ", **kw - ) + return self._generate_generic_binary( + binary, " -> " if not _cast_applied else " ->> ", **kw + ) def visit_json_path_getitem_op_binary( self, binary, operator, _cast_applied=False, **kw @@ -2084,10 +1925,9 @@ class PGCompiler(compiler.SQLCompiler): for c in select._for_update_arg.of: tables.update(sql_util.surface_selectables_only(c)) - of_kw = dict(kw) - of_kw.update(ashint=True, use_schema=False) tmp += " OF " + ", ".join( - self.process(table, **of_kw) for table in tables + self.process(table, ashint=True, use_schema=False, **kw) + for table in tables ) if select._for_update_arg.nowait: @@ -2169,8 +2009,6 @@ class PGCompiler(compiler.SQLCompiler): else: continue - # TODO: this coercion should be up front. 
we can't cache - # SQL constructs with non-bound literals buried in them if coercions._is_literal(value): value = elements.BindParameter(None, value, type_=c.type) @@ -2248,11 +2086,9 @@ class PGCompiler(compiler.SQLCompiler): text += "\n FETCH FIRST (%s)%s ROWS %s" % ( self.process(select._fetch_clause, **kw), " PERCENT" if select._fetch_clause_options["percent"] else "", - ( - "WITH TIES" - if select._fetch_clause_options["with_ties"] - else "ONLY" - ), + "WITH TIES" + if select._fetch_clause_options["with_ties"] + else "ONLY", ) return text @@ -2316,18 +2152,6 @@ class PGDDLCompiler(compiler.DDLCompiler): not_valid = constraint.dialect_options["postgresql"]["not_valid"] return " NOT VALID" if not_valid else "" - def _define_include(self, obj): - includeclause = obj.dialect_options["postgresql"]["include"] - if not includeclause: - return "" - inclusions = [ - obj.table.c[col] if isinstance(col, str) else col - for col in includeclause - ] - return " INCLUDE (%s)" % ", ".join( - [self.preparer.quote(c.name) for c in inclusions] - ) - def visit_check_constraint(self, constraint, **kw): if constraint._type_bound: typ = list(constraint.columns)[0].type @@ -2351,29 +2175,6 @@ class PGDDLCompiler(compiler.DDLCompiler): text += self._define_constraint_validity(constraint) return text - def visit_primary_key_constraint(self, constraint, **kw): - text = super().visit_primary_key_constraint(constraint) - text += self._define_include(constraint) - return text - - def visit_unique_constraint(self, constraint, **kw): - text = super().visit_unique_constraint(constraint) - text += self._define_include(constraint) - return text - - @util.memoized_property - def _fk_ondelete_pattern(self): - return re.compile( - r"^(?:RESTRICT|CASCADE|SET (?:NULL|DEFAULT)(?:\s*\(.+\))?" 
- r"|NO ACTION)$", - re.I, - ) - - def define_constraint_ondelete_cascade(self, constraint): - return " ON DELETE %s" % self.preparer.validate_sql_phrase( - constraint.ondelete, self._fk_ondelete_pattern - ) - def visit_create_enum_type(self, create, **kw): type_ = create.element @@ -2457,11 +2258,9 @@ class PGDDLCompiler(compiler.DDLCompiler): ", ".join( [ self.sql_compiler.process( - ( - expr.self_group() - if not isinstance(expr, expression.ColumnClause) - else expr - ), + expr.self_group() + if not isinstance(expr, expression.ColumnClause) + else expr, include_table=False, literal_binds=True, ) @@ -2475,7 +2274,15 @@ class PGDDLCompiler(compiler.DDLCompiler): ) ) - text += self._define_include(index) + includeclause = index.dialect_options["postgresql"]["include"] + if includeclause: + inclusions = [ + index.table.c[col] if isinstance(col, str) else col + for col in includeclause + ] + text += " INCLUDE (%s)" % ", ".join( + [preparer.quote(c.name) for c in inclusions] + ) nulls_not_distinct = index.dialect_options["postgresql"][ "nulls_not_distinct" @@ -2588,9 +2395,6 @@ class PGDDLCompiler(compiler.DDLCompiler): if pg_opts["partition_by"]: table_opts.append("\n PARTITION BY %s" % pg_opts["partition_by"]) - if pg_opts["using"]: - table_opts.append("\n USING %s" % pg_opts["using"]) - if pg_opts["with_oids"] is True: table_opts.append("\n WITH OIDS") elif pg_opts["with_oids"] is False: @@ -2778,21 +2582,17 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_TIMESTAMP(self, type_, **kw): return "TIMESTAMP%s %s" % ( - ( - "(%d)" % type_.precision - if getattr(type_, "precision", None) is not None - else "" - ), + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "", (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", ) def visit_TIME(self, type_, **kw): return "TIME%s %s" % ( - ( - "(%d)" % type_.precision - if getattr(type_, "precision", None) is not None - else "" - ), + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "", (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", ) @@ -2913,8 +2713,6 @@ class ReflectedDomain(ReflectedNamedType): """The constraints defined in the domain, if any. The constraint are in order of evaluation by postgresql. """ - collation: Optional[str] - """The collation for the domain.""" class ReflectedEnum(ReflectedNamedType): @@ -3208,7 +3006,6 @@ class PGDialect(default.DefaultDialect): "with_oids": None, "on_commit": None, "inherits": None, - "using": None, }, ), ( @@ -3223,16 +3020,9 @@ class PGDialect(default.DefaultDialect): "not_valid": False, }, ), - ( - schema.PrimaryKeyConstraint, - {"include": None}, - ), ( schema.UniqueConstraint, - { - "include": None, - "nulls_not_distinct": None, - }, + {"nulls_not_distinct": None}, ), ] @@ -3241,7 +3031,6 @@ class PGDialect(default.DefaultDialect): _backslash_escapes = True _supports_create_index_concurrently = True _supports_drop_index_concurrently = True - _supports_jsonb_subscripting = True def __init__( self, @@ -3270,8 +3059,6 @@ class PGDialect(default.DefaultDialect): ) self.supports_identity_columns = self.server_version_info >= (10,) - self._supports_jsonb_subscripting = self.server_version_info >= (14,) - def get_isolation_level_values(self, dbapi_conn): # note the generic dialect doesn't have AUTOCOMMIT, however # all postgresql dialects should include AUTOCOMMIT. 
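For context on the isolation levels handled here, a minimal sketch (placeholder
URL, hypothetical index and table names) of selecting the ``AUTOCOMMIT`` level
per-connection, as needed for statements such as ``CREATE INDEX CONCURRENTLY``::

    from sqlalchemy import create_engine, text

    # placeholder URL; substitute real credentials
    engine = create_engine("postgresql+psycopg2://user:pass@localhost/test")

    with engine.connect() as conn:
        # run outside of a transaction block
        conn = conn.execution_options(isolation_level="AUTOCOMMIT")
        conn.execute(text("CREATE INDEX CONCURRENTLY ix_data ON my_table (data)"))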
@@ -3310,7 +3097,9 @@ class PGDialect(default.DefaultDialect): def get_deferrable(self, connection): raise NotImplementedError() - def _split_multihost_from_url(self, url: URL) -> Union[ + def _split_multihost_from_url( + self, url: URL + ) -> Union[ Tuple[None, None], Tuple[Tuple[Optional[str], ...], Tuple[Optional[int], ...]], ]: @@ -3722,7 +3511,6 @@ class PGDialect(default.DefaultDialect): pg_catalog.pg_sequence.c.seqcache, "cycle", pg_catalog.pg_sequence.c.seqcycle, - type_=sqltypes.JSON(), ) ) .select_from(pg_catalog.pg_sequence) @@ -3843,11 +3631,9 @@ class PGDialect(default.DefaultDialect): # dictionary with (name, ) if default search path or (schema, name) # as keys enums = dict( - ( - ((rec["name"],), rec) - if rec["visible"] - else ((rec["schema"], rec["name"]), rec) - ) + ((rec["name"],), rec) + if rec["visible"] + else ((rec["schema"], rec["name"]), rec) for rec in self._load_enums( connection, schema="*", info_cache=kw.get("info_cache") ) @@ -3857,188 +3643,155 @@ class PGDialect(default.DefaultDialect): return columns.items() - _format_type_args_pattern = re.compile(r"\((.*)\)") - _format_type_args_delim = re.compile(r"\s*,\s*") - _format_array_spec_pattern = re.compile(r"((?:\[\])*)$") - - def _reflect_type( - self, - format_type: Optional[str], - domains: Dict[str, ReflectedDomain], - enums: Dict[str, ReflectedEnum], - type_description: str, - ) -> sqltypes.TypeEngine[Any]: - """ - Attempts to reconstruct a column type defined in ischema_names based - on the information available in the format_type. - - If the `format_type` cannot be associated with a known `ischema_names`, - it is treated as a reference to a known PostgreSQL named `ENUM` or - `DOMAIN` type. - """ - type_description = type_description or "unknown type" - if format_type is None: - util.warn( - "PostgreSQL format_type() returned NULL for %s" - % type_description - ) - return sqltypes.NULLTYPE - - attype_args_match = self._format_type_args_pattern.search(format_type) - if attype_args_match and attype_args_match.group(1): - attype_args = self._format_type_args_delim.split( - attype_args_match.group(1) - ) - else: - attype_args = () - - match_array_dim = self._format_array_spec_pattern.search(format_type) - # Each "[]" in array specs corresponds to an array dimension - array_dim = len(match_array_dim.group(1) or "") // 2 - - # Remove all parameters and array specs from format_type to obtain an - # ischema_name candidate - attype = self._format_type_args_pattern.sub("", format_type) - attype = self._format_array_spec_pattern.sub("", attype) - - schema_type = self.ischema_names.get(attype.lower(), None) - args, kwargs = (), {} - - if attype == "numeric": - if len(attype_args) == 2: - precision, scale = map(int, attype_args) - args = (precision, scale) - - elif attype == "double precision": - args = (53,) - - elif attype == "integer": - args = () - - elif attype in ("timestamp with time zone", "time with time zone"): - kwargs["timezone"] = True - if len(attype_args) == 1: - kwargs["precision"] = int(attype_args[0]) - - elif attype in ( - "timestamp without time zone", - "time without time zone", - "time", - ): - kwargs["timezone"] = False - if len(attype_args) == 1: - kwargs["precision"] = int(attype_args[0]) - - elif attype == "bit varying": - kwargs["varying"] = True - if len(attype_args) == 1: - charlen = int(attype_args[0]) - args = (charlen,) - - # a domain or enum can start with interval, so be mindful of that. 
- elif attype == "interval" or attype.startswith("interval "): - schema_type = INTERVAL - - field_match = re.match(r"interval (.+)", attype) - if field_match: - kwargs["fields"] = field_match.group(1) - - if len(attype_args) == 1: - kwargs["precision"] = int(attype_args[0]) - - else: - enum_or_domain_key = tuple(util.quoted_token_parser(attype)) - - if enum_or_domain_key in enums: - schema_type = ENUM - enum = enums[enum_or_domain_key] - - kwargs["name"] = enum["name"] - - if not enum["visible"]: - kwargs["schema"] = enum["schema"] - args = tuple(enum["labels"]) - elif enum_or_domain_key in domains: - schema_type = DOMAIN - domain = domains[enum_or_domain_key] - - data_type = self._reflect_type( - domain["type"], - domains, - enums, - type_description="DOMAIN '%s'" % domain["name"], - ) - args = (domain["name"], data_type) - - kwargs["collation"] = domain["collation"] - kwargs["default"] = domain["default"] - kwargs["not_null"] = not domain["nullable"] - kwargs["create_type"] = False - - if domain["constraints"]: - # We only support a single constraint - check_constraint = domain["constraints"][0] - - kwargs["constraint_name"] = check_constraint["name"] - kwargs["check"] = check_constraint["check"] - - if not domain["visible"]: - kwargs["schema"] = domain["schema"] - - else: - try: - charlen = int(attype_args[0]) - args = (charlen, *attype_args[1:]) - except (ValueError, IndexError): - args = attype_args - - if not schema_type: - util.warn( - "Did not recognize type '%s' of %s" - % (attype, type_description) - ) - return sqltypes.NULLTYPE - - data_type = schema_type(*args, **kwargs) - if array_dim >= 1: - # postgres does not preserve dimensionality or size of array types. - data_type = _array.ARRAY(data_type) - - return data_type - def _get_columns_info(self, rows, domains, enums, schema): + array_type_pattern = re.compile(r"\[\]$") + attype_pattern = re.compile(r"\(.*\)") + charlen_pattern = re.compile(r"\(([\d,]+)\)") + args_pattern = re.compile(r"\((.*)\)") + args_split_pattern = re.compile(r"\s*,\s*") + + def _handle_array_type(attype): + return ( + # strip '[]' from integer[], etc. + array_type_pattern.sub("", attype), + attype.endswith("[]"), + ) + columns = defaultdict(list) for row_dict in rows: # ensure that each table has an entry, even if it has no columns if row_dict["name"] is None: - columns[(schema, row_dict["table_name"])] = ( - ReflectionDefaults.columns() - ) + columns[ + (schema, row_dict["table_name"]) + ] = ReflectionDefaults.columns() continue table_cols = columns[(schema, row_dict["table_name"])] - coltype = self._reflect_type( - row_dict["format_type"], - domains, - enums, - type_description="column '%s'" % row_dict["name"], - ) - + format_type = row_dict["format_type"] default = row_dict["default"] name = row_dict["name"] generated = row_dict["generated"] + identity = row_dict["identity_options"] + + if format_type is None: + no_format_type = True + attype = format_type = "no format_type()" + is_array = False + else: + no_format_type = False + + # strip (*) from character varying(5), timestamp(5) + # with time zone, geometry(POLYGON), etc. + attype = attype_pattern.sub("", format_type) + + # strip '[]' from integer[], etc. 
and check if an array + attype, is_array = _handle_array_type(attype) + + # strip quotes from case sensitive enum or domain names + enum_or_domain_key = tuple(util.quoted_token_parser(attype)) + nullable = not row_dict["not_null"] - if isinstance(coltype, DOMAIN): - if not default: - # domain can override the default value but - # cant set it to None - if coltype.default is not None: - default = coltype.default + charlen = charlen_pattern.search(format_type) + if charlen: + charlen = charlen.group(1) + args = args_pattern.search(format_type) + if args and args.group(1): + args = tuple(args_split_pattern.split(args.group(1))) + else: + args = () + kwargs = {} - nullable = nullable and not coltype.not_null + if attype == "numeric": + if charlen: + prec, scale = charlen.split(",") + args = (int(prec), int(scale)) + else: + args = () + elif attype == "double precision": + args = (53,) + elif attype == "integer": + args = () + elif attype in ("timestamp with time zone", "time with time zone"): + kwargs["timezone"] = True + if charlen: + kwargs["precision"] = int(charlen) + args = () + elif attype in ( + "timestamp without time zone", + "time without time zone", + "time", + ): + kwargs["timezone"] = False + if charlen: + kwargs["precision"] = int(charlen) + args = () + elif attype == "bit varying": + kwargs["varying"] = True + if charlen: + args = (int(charlen),) + else: + args = () + elif attype.startswith("interval"): + field_match = re.match(r"interval (.+)", attype, re.I) + if charlen: + kwargs["precision"] = int(charlen) + if field_match: + kwargs["fields"] = field_match.group(1) + attype = "interval" + args = () + elif charlen: + args = (int(charlen),) - identity = row_dict["identity_options"] + while True: + # looping here to suit nested domains + if attype in self.ischema_names: + coltype = self.ischema_names[attype] + break + elif enum_or_domain_key in enums: + enum = enums[enum_or_domain_key] + coltype = ENUM + kwargs["name"] = enum["name"] + if not enum["visible"]: + kwargs["schema"] = enum["schema"] + args = tuple(enum["labels"]) + break + elif enum_or_domain_key in domains: + domain = domains[enum_or_domain_key] + attype = domain["type"] + attype, is_array = _handle_array_type(attype) + # strip quotes from case sensitive enum or domain names + enum_or_domain_key = tuple( + util.quoted_token_parser(attype) + ) + # A table can't override a not null on the domain, + # but can override nullable + nullable = nullable and domain["nullable"] + if domain["default"] and not default: + # It can, however, override the default + # value, but can't set it to null. + default = domain["default"] + continue + else: + coltype = None + break + + if coltype: + coltype = coltype(*args, **kwargs) + if is_array: + coltype = self.ischema_names["_array"](coltype) + elif no_format_type: + util.warn( + "PostgreSQL format_type() returned NULL for column '%s'" + % (name,) + ) + coltype = sqltypes.NULLTYPE + else: + util.warn( + "Did not recognize type '%s' of column '%s'" + % (attype, name) + ) + coltype = sqltypes.NULLTYPE # If a zero byte or blank string depending on driver (is also # absent for older PG versions), then not a generated column. 
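For a rough sense of how the column information assembled by this routine
surfaces to callers, a sketch using the public inspection API (placeholder URL
and table name, not part of the patch above)::

    from sqlalchemy import create_engine, inspect

    # placeholder URL; substitute real credentials
    engine = create_engine("postgresql+psycopg2://user:pass@localhost/test")

    insp = inspect(engine)
    for col in insp.get_columns("my_table", schema="public"):
        # each entry is a dict with keys such as "name", "type",
        # "nullable" and "default"
        print(col["name"], col["type"], col["nullable"])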
@@ -4117,35 +3870,21 @@ class PGDialect(default.DefaultDialect): result = connection.execute(oid_q, params) return result.all() - @util.memoized_property - def _constraint_query(self): - if self.server_version_info >= (11, 0): - indnkeyatts = pg_catalog.pg_index.c.indnkeyatts - else: - indnkeyatts = pg_catalog.pg_index.c.indnatts.label("indnkeyatts") - - if self.server_version_info >= (15,): - indnullsnotdistinct = pg_catalog.pg_index.c.indnullsnotdistinct - else: - indnullsnotdistinct = sql.false().label("indnullsnotdistinct") - + @lru_cache() + def _constraint_query(self, is_unique): con_sq = ( select( pg_catalog.pg_constraint.c.conrelid, pg_catalog.pg_constraint.c.conname, - sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"), + pg_catalog.pg_constraint.c.conindid, + sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label( + "attnum" + ), sql.func.generate_subscripts( - pg_catalog.pg_index.c.indkey, 1 + pg_catalog.pg_constraint.c.conkey, 1 ).label("ord"), - indnkeyatts, - indnullsnotdistinct, pg_catalog.pg_description.c.description, ) - .join( - pg_catalog.pg_index, - pg_catalog.pg_constraint.c.conindid - == pg_catalog.pg_index.c.indexrelid, - ) .outerjoin( pg_catalog.pg_description, pg_catalog.pg_description.c.objoid @@ -4154,9 +3893,6 @@ class PGDialect(default.DefaultDialect): .where( pg_catalog.pg_constraint.c.contype == bindparam("contype"), pg_catalog.pg_constraint.c.conrelid.in_(bindparam("oids")), - # NOTE: filtering also on pg_index.indrelid for oids does - # not seem to have a performance effect, but it may be an - # option if perf problems are reported ) .subquery("con") ) @@ -4165,10 +3901,9 @@ class PGDialect(default.DefaultDialect): select( con_sq.c.conrelid, con_sq.c.conname, + con_sq.c.conindid, con_sq.c.description, con_sq.c.ord, - con_sq.c.indnkeyatts, - con_sq.c.indnullsnotdistinct, pg_catalog.pg_attribute.c.attname, ) .select_from(pg_catalog.pg_attribute) @@ -4191,7 +3926,7 @@ class PGDialect(default.DefaultDialect): .subquery("attr") ) - return ( + constraint_query = ( select( attr_sq.c.conrelid, sql.func.array_agg( @@ -4203,15 +3938,31 @@ class PGDialect(default.DefaultDialect): ).label("cols"), attr_sq.c.conname, sql.func.min(attr_sq.c.description).label("description"), - sql.func.min(attr_sq.c.indnkeyatts).label("indnkeyatts"), - sql.func.bool_and(attr_sq.c.indnullsnotdistinct).label( - "indnullsnotdistinct" - ), ) .group_by(attr_sq.c.conrelid, attr_sq.c.conname) .order_by(attr_sq.c.conrelid, attr_sq.c.conname) ) + if is_unique: + if self.server_version_info >= (15,): + constraint_query = constraint_query.join( + pg_catalog.pg_index, + attr_sq.c.conindid == pg_catalog.pg_index.c.indexrelid, + ).add_columns( + sql.func.bool_and( + pg_catalog.pg_index.c.indnullsnotdistinct + ).label("indnullsnotdistinct") + ) + else: + constraint_query = constraint_query.add_columns( + sql.false().label("indnullsnotdistinct") + ) + else: + constraint_query = constraint_query.add_columns( + sql.null().label("extra") + ) + return constraint_query + def _reflect_constraint( self, connection, contype, schema, filter_names, scope, kind, **kw ): @@ -4227,42 +3978,26 @@ class PGDialect(default.DefaultDialect): batches[0:3000] = [] result = connection.execute( - self._constraint_query, + self._constraint_query(is_unique), {"oids": [r[0] for r in batch], "contype": contype}, - ).mappings() + ) result_by_oid = defaultdict(list) - for row_dict in result: - result_by_oid[row_dict["conrelid"]].append(row_dict) + for oid, cols, constraint_name, comment, extra in result: + 
result_by_oid[oid].append( + (cols, constraint_name, comment, extra) + ) for oid, tablename in batch: for_oid = result_by_oid.get(oid, ()) if for_oid: - for row in for_oid: - # See note in get_multi_indexes - all_cols = row["cols"] - indnkeyatts = row["indnkeyatts"] - if len(all_cols) > indnkeyatts: - inc_cols = all_cols[indnkeyatts:] - cst_cols = all_cols[:indnkeyatts] - else: - inc_cols = [] - cst_cols = all_cols - - opts = {} - if self.server_version_info >= (11,): - opts["postgresql_include"] = inc_cols + for cols, constraint, comment, extra in for_oid: if is_unique: - opts["postgresql_nulls_not_distinct"] = row[ - "indnullsnotdistinct" - ] - yield ( - tablename, - cst_cols, - row["conname"], - row["description"], - opts, - ) + yield tablename, cols, constraint, comment, { + "nullsnotdistinct": extra + } + else: + yield tablename, cols, constraint, comment, None else: yield tablename, None, None, None, None @@ -4288,27 +4023,18 @@ class PGDialect(default.DefaultDialect): # only a single pk can be present for each table. Return an entry # even if a table has no primary key default = ReflectionDefaults.pk_constraint - - def pk_constraint(pk_name, cols, comment, opts): - info = { - "constrained_columns": cols, - "name": pk_name, - "comment": comment, - } - if opts: - info["dialect_options"] = opts - return info - return ( ( (schema, table_name), - ( - pk_constraint(pk_name, cols, comment, opts) - if pk_name is not None - else default() - ), + { + "constrained_columns": [] if cols is None else cols, + "name": pk_name, + "comment": comment, + } + if pk_name is not None + else default(), ) - for table_name, cols, pk_name, comment, opts in result + for table_name, cols, pk_name, comment, _ in result ) @reflection.cache @@ -4402,8 +4128,7 @@ class PGDialect(default.DefaultDialect): r"[\s]?(ON UPDATE " r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?" r"[\s]?(ON DELETE " - r"(CASCADE|RESTRICT|NO ACTION|" - r"SET (?:NULL|DEFAULT)(?:\s\(.+\))?)+)?" + r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?" r"[\s]?(DEFERRABLE|NOT DEFERRABLE)?" r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?" 
) @@ -4519,10 +4244,7 @@ class PGDialect(default.DefaultDialect): @util.memoized_property def _index_query(self): - # NOTE: pg_index is used as from two times to improve performance, - # since extraing all the index information from `idx_sq` to avoid - # the second pg_index use leads to a worse performing query in - # particular when querying for a single table (as of pg 17) + pg_class_index = pg_catalog.pg_class.alias("cls_idx") # NOTE: repeating oids clause improve query performance # subquery to get the columns @@ -4531,9 +4253,6 @@ class PGDialect(default.DefaultDialect): pg_catalog.pg_index.c.indexrelid, pg_catalog.pg_index.c.indrelid, sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"), - sql.func.unnest(pg_catalog.pg_index.c.indclass).label( - "att_opclass" - ), sql.func.generate_subscripts( pg_catalog.pg_index.c.indkey, 1 ).label("ord"), @@ -4565,8 +4284,6 @@ class PGDialect(default.DefaultDialect): else_=pg_catalog.pg_attribute.c.attname.cast(TEXT), ).label("element"), (idx_sq.c.attnum == 0).label("is_expr"), - pg_catalog.pg_opclass.c.opcname, - pg_catalog.pg_opclass.c.opcdefault, ) .select_from(idx_sq) .outerjoin( @@ -4577,10 +4294,6 @@ class PGDialect(default.DefaultDialect): pg_catalog.pg_attribute.c.attrelid == idx_sq.c.indrelid, ), ) - .outerjoin( - pg_catalog.pg_opclass, - pg_catalog.pg_opclass.c.oid == idx_sq.c.att_opclass, - ) .where(idx_sq.c.indrelid.in_(bindparam("oids"))) .subquery("idx_attr") ) @@ -4595,12 +4308,6 @@ class PGDialect(default.DefaultDialect): sql.func.array_agg( aggregate_order_by(attr_sq.c.is_expr, attr_sq.c.ord) ).label("elements_is_expr"), - sql.func.array_agg( - aggregate_order_by(attr_sq.c.opcname, attr_sq.c.ord) - ).label("elements_opclass"), - sql.func.array_agg( - aggregate_order_by(attr_sq.c.opcdefault, attr_sq.c.ord) - ).label("elements_opdefault"), ) .group_by(attr_sq.c.indexrelid) .subquery("idx_cols") @@ -4609,7 +4316,7 @@ class PGDialect(default.DefaultDialect): if self.server_version_info >= (11, 0): indnkeyatts = pg_catalog.pg_index.c.indnkeyatts else: - indnkeyatts = pg_catalog.pg_index.c.indnatts.label("indnkeyatts") + indnkeyatts = sql.null().label("indnkeyatts") if self.server_version_info >= (15,): nulls_not_distinct = pg_catalog.pg_index.c.indnullsnotdistinct @@ -4619,13 +4326,13 @@ class PGDialect(default.DefaultDialect): return ( select( pg_catalog.pg_index.c.indrelid, - pg_catalog.pg_class.c.relname, + pg_class_index.c.relname.label("relname_index"), pg_catalog.pg_index.c.indisunique, pg_catalog.pg_constraint.c.conrelid.is_not(None).label( "has_constraint" ), pg_catalog.pg_index.c.indoption, - pg_catalog.pg_class.c.reloptions, + pg_class_index.c.reloptions, pg_catalog.pg_am.c.amname, # NOTE: pg_get_expr is very fast so this case has almost no # performance impact @@ -4643,8 +4350,6 @@ class PGDialect(default.DefaultDialect): nulls_not_distinct, cols_sq.c.elements, cols_sq.c.elements_is_expr, - cols_sq.c.elements_opclass, - cols_sq.c.elements_opdefault, ) .select_from(pg_catalog.pg_index) .where( @@ -4652,12 +4357,12 @@ class PGDialect(default.DefaultDialect): ~pg_catalog.pg_index.c.indisprimary, ) .join( - pg_catalog.pg_class, - pg_catalog.pg_index.c.indexrelid == pg_catalog.pg_class.c.oid, + pg_class_index, + pg_catalog.pg_index.c.indexrelid == pg_class_index.c.oid, ) .join( pg_catalog.pg_am, - pg_catalog.pg_class.c.relam == pg_catalog.pg_am.c.oid, + pg_class_index.c.relam == pg_catalog.pg_am.c.oid, ) .outerjoin( cols_sq, @@ -4674,9 +4379,7 @@ class PGDialect(default.DefaultDialect): == sql.any_(_array.array(("p", "u", 
"x"))), ), ) - .order_by( - pg_catalog.pg_index.c.indrelid, pg_catalog.pg_class.c.relname - ) + .order_by(pg_catalog.pg_index.c.indrelid, pg_class_index.c.relname) ) def get_multi_indexes( @@ -4711,19 +4414,17 @@ class PGDialect(default.DefaultDialect): continue for row in result_by_oid[oid]: - index_name = row["relname"] + index_name = row["relname_index"] table_indexes = indexes[(schema, table_name)] all_elements = row["elements"] all_elements_is_expr = row["elements_is_expr"] - all_elements_opclass = row["elements_opclass"] - all_elements_opdefault = row["elements_opdefault"] indnkeyatts = row["indnkeyatts"] # "The number of key columns in the index, not counting any # included columns, which are merely stored and do not # participate in the index semantics" - if len(all_elements) > indnkeyatts: + if indnkeyatts and len(all_elements) > indnkeyatts: # this is a "covering index" which has INCLUDE columns # as well as regular index columns inc_cols = all_elements[indnkeyatts:] @@ -4738,18 +4439,10 @@ class PGDialect(default.DefaultDialect): not is_expr for is_expr in all_elements_is_expr[indnkeyatts:] ) - idx_elements_opclass = all_elements_opclass[ - :indnkeyatts - ] - idx_elements_opdefault = all_elements_opdefault[ - :indnkeyatts - ] else: idx_elements = all_elements idx_elements_is_expr = all_elements_is_expr inc_cols = [] - idx_elements_opclass = all_elements_opclass - idx_elements_opdefault = all_elements_opdefault index = {"name": index_name, "unique": row["indisunique"]} if any(idx_elements_is_expr): @@ -4763,19 +4456,6 @@ class PGDialect(default.DefaultDialect): else: index["column_names"] = idx_elements - dialect_options = {} - - if not all(idx_elements_opdefault): - dialect_options["postgresql_ops"] = { - name: opclass - for name, opclass, is_default in zip( - idx_elements, - idx_elements_opclass, - idx_elements_opdefault, - ) - if not is_default - } - sorting = {} for col_index, col_flags in enumerate(row["indoption"]): col_sorting = () @@ -4795,12 +4475,10 @@ class PGDialect(default.DefaultDialect): if row["has_constraint"]: index["duplicates_constraint"] = index_name + dialect_options = {} if row["reloptions"]: dialect_options["postgresql_with"] = dict( - [ - option.split("=", 1) - for option in row["reloptions"] - ] + [option.split("=") for option in row["reloptions"]] ) # it *might* be nice to include that this is 'btree' in the # reflection info. 
But we don't want an Index object @@ -4873,7 +4551,12 @@ class PGDialect(default.DefaultDialect): "comment": comment, } if options: - uc_dict["dialect_options"] = options + if options["nullsnotdistinct"]: + uc_dict["dialect_options"] = { + "postgresql_nulls_not_distinct": options[ + "nullsnotdistinct" + ] + } uniques[(schema, table_name)].append(uc_dict) return uniques.items() @@ -4905,8 +4588,6 @@ class PGDialect(default.DefaultDialect): pg_catalog.pg_class.c.oid == pg_catalog.pg_description.c.objoid, pg_catalog.pg_description.c.objsubid == 0, - pg_catalog.pg_description.c.classoid - == sql.func.cast("pg_catalog.pg_class", REGCLASS), ), ) .where(self._pg_class_relkind_condition(relkinds)) @@ -5015,13 +4696,9 @@ class PGDialect(default.DefaultDialect): # "CHECK (((a > 1) AND (a < 5))) NOT VALID" # "CHECK (some_boolean_function(a))" # "CHECK (((a\n < 1)\n OR\n (a\n >= 5))\n)" - # "CHECK (a NOT NULL) NO INHERIT" - # "CHECK (a NOT NULL) NO INHERIT NOT VALID" m = re.match( - r"^CHECK *\((.+)\)( NO INHERIT)?( NOT VALID)?$", - src, - flags=re.DOTALL, + r"^CHECK *\((.+)\)( NOT VALID)?$", src, flags=re.DOTALL ) if not m: util.warn("Could not parse CHECK constraint text: %r" % src) @@ -5035,14 +4712,8 @@ class PGDialect(default.DefaultDialect): "sqltext": sqltext, "comment": comment, } - if m: - do = {} - if " NOT VALID" in m.groups(): - do["not_valid"] = True - if " NO INHERIT" in m.groups(): - do["no_inherit"] = True - if do: - entry["dialect_options"] = do + if m and m.group(2): + entry["dialect_options"] = {"not_valid": True} check_constraints[(schema, table_name)].append(entry) return check_constraints.items() @@ -5157,18 +4828,12 @@ class PGDialect(default.DefaultDialect): pg_catalog.pg_namespace.c.nspname.label("schema"), con_sq.c.condefs, con_sq.c.connames, - pg_catalog.pg_collation.c.collname, ) .join( pg_catalog.pg_namespace, pg_catalog.pg_namespace.c.oid == pg_catalog.pg_type.c.typnamespace, ) - .outerjoin( - pg_catalog.pg_collation, - pg_catalog.pg_type.c.typcollation - == pg_catalog.pg_collation.c.oid, - ) .outerjoin( con_sq, pg_catalog.pg_type.c.oid == con_sq.c.contypid, @@ -5182,13 +4847,14 @@ class PGDialect(default.DefaultDialect): @reflection.cache def _load_domains(self, connection, schema=None, **kw): + # Load data types for domains: result = connection.execute(self._domain_query(schema)) - domains: List[ReflectedDomain] = [] + domains = [] for domain in result.mappings(): # strip (30) from character varying(30) attype = re.search(r"([^\(]+)", domain["attype"]).group(1) - constraints: List[ReflectedDomainConstraint] = [] + constraints = [] if domain["connames"]: # When a domain has multiple CHECK constraints, they will # be tested in alphabetical order by name. @@ -5197,13 +4863,12 @@ class PGDialect(default.DefaultDialect): key=lambda t: t[0], ) for name, def_ in sorted_constraints: - # constraint is in the form "CHECK (expression)" - # or "NOT NULL". Ignore the "NOT NULL" and + # constraint is in the form "CHECK (expression)". # remove "CHECK (" and the tailing ")". 
- if def_.casefold().startswith("check"): - check = def_[7:-1] - constraints.append({"name": name, "check": check}) - domain_rec: ReflectedDomain = { + check = def_[7:-1] + constraints.append({"name": name, "check": check}) + + domain_rec = { "name": domain["name"], "schema": domain["schema"], "visible": domain["visible"], @@ -5211,7 +4876,6 @@ class PGDialect(default.DefaultDialect): "nullable": domain["nullable"], "default": domain["default"], "constraints": constraints, - "collation": domain["collname"], } domains.append(domain_rec) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/dml.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/dml.py index 1187b6b..dee7af3 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/dml.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/dml.py @@ -1,5 +1,5 @@ -# dialects/postgresql/dml.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# postgresql/dml.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -7,10 +7,7 @@ from __future__ import annotations from typing import Any -from typing import List from typing import Optional -from typing import Tuple -from typing import Union from . import ext from .._typing import _OnConflictConstraintT @@ -29,9 +26,7 @@ from ...sql.base import ColumnCollection from ...sql.base import ReadOnlyColumnCollection from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement -from ...sql.elements import ColumnElement from ...sql.elements import KeyedColumnElement -from ...sql.elements import TextClause from ...sql.expression import alias from ...util.typing import Self @@ -158,10 +153,11 @@ class Insert(StandardInsert): :paramref:`.Insert.on_conflict_do_update.set_` dictionary. :param where: - Optional argument. An expression object representing a ``WHERE`` - clause that restricts the rows affected by ``DO UPDATE SET``. Rows not - meeting the ``WHERE`` condition will not be updated (effectively a - ``DO NOTHING`` for those rows). + Optional argument. If present, can be a literal SQL + string or an acceptable expression for a ``WHERE`` clause + that restricts the rows affected by ``DO UPDATE SET``. Rows + not meeting the ``WHERE`` condition will not be updated + (effectively a ``DO NOTHING`` for those rows). .. 
seealso:: @@ -216,10 +212,8 @@ class OnConflictClause(ClauseElement): stringify_dialect = "postgresql" constraint_target: Optional[str] - inferred_target_elements: Optional[List[Union[str, schema.Column[Any]]]] - inferred_target_whereclause: Optional[ - Union[ColumnElement[Any], TextClause] - ] + inferred_target_elements: _OnConflictIndexElementsT + inferred_target_whereclause: _OnConflictIndexWhereT def __init__( self, @@ -260,28 +254,12 @@ class OnConflictClause(ClauseElement): if index_elements is not None: self.constraint_target = None - self.inferred_target_elements = [ - coercions.expect(roles.DDLConstraintColumnRole, column) - for column in index_elements - ] - - self.inferred_target_whereclause = ( - coercions.expect( - ( - roles.StatementOptionRole - if isinstance(constraint, ext.ExcludeConstraint) - else roles.WhereHavingRole - ), - index_where, - ) - if index_where is not None - else None - ) - + self.inferred_target_elements = index_elements + self.inferred_target_whereclause = index_where elif constraint is None: - self.constraint_target = self.inferred_target_elements = ( - self.inferred_target_whereclause - ) = None + self.constraint_target = ( + self.inferred_target_elements + ) = self.inferred_target_whereclause = None class OnConflictDoNothing(OnConflictClause): @@ -291,9 +269,6 @@ class OnConflictDoNothing(OnConflictClause): class OnConflictDoUpdate(OnConflictClause): __visit_name__ = "on_conflict_do_update" - update_values_to_set: List[Tuple[Union[schema.Column[Any], str], Any]] - update_whereclause: Optional[ColumnElement[Any]] - def __init__( self, constraint: _OnConflictConstraintT = None, @@ -332,8 +307,4 @@ class OnConflictDoUpdate(OnConflictClause): (coercions.expect(roles.DMLColumnRole, key), value) for key, value in set_.items() ] - self.update_whereclause = ( - coercions.expect(roles.WhereHavingRole, where) - if where is not None - else None - ) + self.update_whereclause = where diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/ext.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/ext.py index 54bacd9..ad12677 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/ext.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/ext.py @@ -1,5 +1,5 @@ -# dialects/postgresql/ext.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# postgresql/ext.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -8,10 +8,6 @@ from __future__ import annotations from typing import Any -from typing import Iterable -from typing import List -from typing import Optional -from typing import overload from typing import TYPE_CHECKING from typing import TypeVar @@ -27,44 +23,34 @@ from ...sql.schema import ColumnCollectionConstraint from ...sql.sqltypes import TEXT from ...sql.visitors import InternalTraversal -if TYPE_CHECKING: - from ...sql._typing import _ColumnExpressionArgument - from ...sql.elements import ClauseElement - from ...sql.elements import ColumnElement - from ...sql.operators import OperatorType - from ...sql.selectable import FromClause - from ...sql.visitors import _CloneCallableType - from ...sql.visitors import _TraverseInternalsType - _T = TypeVar("_T", bound=Any) +if TYPE_CHECKING: + from ...sql.visitors import _TraverseInternalsType -class aggregate_order_by(expression.ColumnElement[_T]): + +class aggregate_order_by(expression.ColumnElement): """Represent a PostgreSQL aggregate 
order by expression. E.g.:: from sqlalchemy.dialects.postgresql import aggregate_order_by - expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc())) stmt = select(expr) - would represent the expression: - - .. sourcecode:: sql + would represent the expression:: SELECT array_agg(a ORDER BY b DESC) FROM table; Similarly:: expr = func.string_agg( - table.c.a, aggregate_order_by(literal_column("','"), table.c.a) + table.c.a, + aggregate_order_by(literal_column("','"), table.c.a) ) stmt = select(expr) - Would represent: - - .. sourcecode:: sql + Would represent:: SELECT string_agg(a, ',' ORDER BY a) FROM table; @@ -85,32 +71,11 @@ class aggregate_order_by(expression.ColumnElement[_T]): ("order_by", InternalTraversal.dp_clauseelement), ] - @overload - def __init__( - self, - target: ColumnElement[_T], - *order_by: _ColumnExpressionArgument[Any], - ): ... - - @overload - def __init__( - self, - target: _ColumnExpressionArgument[_T], - *order_by: _ColumnExpressionArgument[Any], - ): ... - - def __init__( - self, - target: _ColumnExpressionArgument[_T], - *order_by: _ColumnExpressionArgument[Any], - ): - self.target: ClauseElement = coercions.expect( - roles.ExpressionElementRole, target - ) + def __init__(self, target, *order_by): + self.target = coercions.expect(roles.ExpressionElementRole, target) self.type = self.target.type _lob = len(order_by) - self.order_by: ClauseElement if _lob == 0: raise TypeError("at least one ORDER BY element is required") elif _lob == 1: @@ -122,22 +87,18 @@ class aggregate_order_by(expression.ColumnElement[_T]): *order_by, _literal_as_text_role=roles.ExpressionElementRole ) - def self_group( - self, against: Optional[OperatorType] = None - ) -> ClauseElement: + def self_group(self, against=None): return self - def get_children(self, **kwargs: Any) -> Iterable[ClauseElement]: + def get_children(self, **kwargs): return self.target, self.order_by - def _copy_internals( - self, clone: _CloneCallableType = elements._clone, **kw: Any - ) -> None: + def _copy_internals(self, clone=elements._clone, **kw): self.target = clone(self.target, **kw) self.order_by = clone(self.order_by, **kw) @property - def _from_objects(self) -> List[FromClause]: + def _from_objects(self): return self.target._from_objects + self.order_by._from_objects @@ -170,10 +131,10 @@ class ExcludeConstraint(ColumnCollectionConstraint): E.g.:: const = ExcludeConstraint( - (Column("period"), "&&"), - (Column("group"), "="), - where=(Column("group") != "some group"), - ops={"group": "my_operator_class"}, + (Column('period'), '&&'), + (Column('group'), '='), + where=(Column('group') != 'some group'), + ops={'group': 'my_operator_class'} ) The constraint is normally embedded into the :class:`_schema.Table` @@ -181,20 +142,19 @@ class ExcludeConstraint(ColumnCollectionConstraint): directly, or added later using :meth:`.append_constraint`:: some_table = Table( - "some_table", - metadata, - Column("id", Integer, primary_key=True), - Column("period", TSRANGE()), - Column("group", String), + 'some_table', metadata, + Column('id', Integer, primary_key=True), + Column('period', TSRANGE()), + Column('group', String) ) some_table.append_constraint( ExcludeConstraint( - (some_table.c.period, "&&"), - (some_table.c.group, "="), - where=some_table.c.group != "some group", - name="some_table_excl_const", - ops={"group": "my_operator_class"}, + (some_table.c.period, '&&'), + (some_table.c.group, '='), + where=some_table.c.group != 'some group', + name='some_table_excl_const', + ops={'group': 
'my_operator_class'} ) ) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/hstore.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/hstore.py index 0a915b1..83c4932 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/hstore.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/hstore.py @@ -1,5 +1,5 @@ -# dialects/postgresql/hstore.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# postgresql/hstore.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -28,29 +28,28 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): The :class:`.HSTORE` type stores dictionaries containing strings, e.g.:: - data_table = Table( - "data_table", - metadata, - Column("id", Integer, primary_key=True), - Column("data", HSTORE), + data_table = Table('data_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', HSTORE) ) with engine.connect() as conn: conn.execute( - data_table.insert(), data={"key1": "value1", "key2": "value2"} + data_table.insert(), + data = {"key1": "value1", "key2": "value2"} ) :class:`.HSTORE` provides for a wide range of operations, including: * Index operations:: - data_table.c.data["some key"] == "some value" + data_table.c.data['some key'] == 'some value' * Containment operations:: - data_table.c.data.has_key("some key") + data_table.c.data.has_key('some key') - data_table.c.data.has_all(["one", "two", "three"]) + data_table.c.data.has_all(['one', 'two', 'three']) * Concatenation:: @@ -73,19 +72,17 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): from sqlalchemy.ext.mutable import MutableDict - class MyClass(Base): - __tablename__ = "data_table" + __tablename__ = 'data_table' id = Column(Integer, primary_key=True) data = Column(MutableDict.as_mutable(HSTORE)) - my_object = session.query(MyClass).one() # in-place mutation, requires Mutable extension # in order for the ORM to detect - my_object.data["some_key"] = "some value" + my_object.data['some_key'] = 'some value' session.commit() @@ -99,7 +96,7 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): :class:`.hstore` - render the PostgreSQL ``hstore()`` function. 
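Tying the fragments above together, a short end-to-end Core sketch (the engine URL and table name are placeholders, and the ``hstore`` extension must already be installed in the target database)::

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine, select
    from sqlalchemy.dialects.postgresql import HSTORE

    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
    metadata = MetaData()
    data_table = Table(
        "data_table",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("data", HSTORE),
    )
    metadata.create_all(engine)

    with engine.begin() as conn:
        # plain dicts are serialized to hstore on the way in
        conn.execute(data_table.insert(), [{"data": {"key1": "value1"}}])
        # index operation plus has_key() containment test on the way out
        value = conn.execute(
            select(data_table.c.data["key1"]).where(
                data_table.c.data.has_key("key1")
            )
        ).scalar_one()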
- """ # noqa: E501 + """ __visit_name__ = "HSTORE" hashable = False @@ -195,9 +192,6 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): comparator_factory = Comparator def bind_processor(self, dialect): - # note that dialect-specific types like that of psycopg and - # psycopg2 will override this method to allow driver-level conversion - # instead, see _PsycopgHStore def process(value): if isinstance(value, dict): return _serialize_hstore(value) @@ -207,9 +201,6 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): return process def result_processor(self, dialect, coltype): - # note that dialect-specific types like that of psycopg and - # psycopg2 will override this method to allow driver-level conversion - # instead, see _PsycopgHStore def process(value): if value is not None: return _parse_hstore(value) @@ -230,12 +221,12 @@ class hstore(sqlfunc.GenericFunction): from sqlalchemy.dialects.postgresql import array, hstore - select(hstore("key1", "value1")) + select(hstore('key1', 'value1')) select( hstore( - array(["key1", "key2", "key3"]), - array(["value1", "value2", "value3"]), + array(['key1', 'key2', 'key3']), + array(['value1', 'value2', 'value3']) ) ) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/json.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/json.py index 06f8db5..ee56a74 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/json.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/json.py @@ -1,18 +1,11 @@ -# dialects/postgresql/json.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# postgresql/json.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors -from __future__ import annotations - -from typing import Any -from typing import Callable -from typing import List -from typing import Optional -from typing import TYPE_CHECKING -from typing import Union from .array import ARRAY from .array import array as _pg_array @@ -28,23 +21,13 @@ from .operators import PATH_EXISTS from .operators import PATH_MATCH from ... import types as sqltypes from ...sql import cast -from ...sql._typing import _T - -if TYPE_CHECKING: - from ...engine.interfaces import Dialect - from ...sql.elements import ColumnElement - from ...sql.type_api import _BindProcessorType - from ...sql.type_api import _LiteralProcessorType - from ...sql.type_api import TypeEngine __all__ = ("JSON", "JSONB") class JSONPathType(sqltypes.JSON.JSONPathType): - def _processor( - self, dialect: Dialect, super_proc: Optional[Callable[[Any], Any]] - ) -> Callable[[Any], Any]: - def process(value: Any) -> Any: + def _processor(self, dialect, super_proc): + def process(value): if isinstance(value, str): # If it's already a string assume that it's in json path # format. 
This allows using cast with json paths literals @@ -61,13 +44,11 @@ class JSONPathType(sqltypes.JSON.JSONPathType): return process - def bind_processor(self, dialect: Dialect) -> _BindProcessorType[Any]: - return self._processor(dialect, self.string_bind_processor(dialect)) # type: ignore[return-value] # noqa: E501 + def bind_processor(self, dialect): + return self._processor(dialect, self.string_bind_processor(dialect)) - def literal_processor( - self, dialect: Dialect - ) -> _LiteralProcessorType[Any]: - return self._processor(dialect, self.string_literal_processor(dialect)) # type: ignore[return-value] # noqa: E501 + def literal_processor(self, dialect): + return self._processor(dialect, self.string_literal_processor(dialect)) class JSONPATH(JSONPathType): @@ -109,14 +90,14 @@ class JSON(sqltypes.JSON): * Index operations (the ``->`` operator):: - data_table.c.data["some key"] + data_table.c.data['some key'] data_table.c.data[5] - * Index operations returning text - (the ``->>`` operator):: - data_table.c.data["some key"].astext == "some value" + * Index operations returning text (the ``->>`` operator):: + + data_table.c.data['some key'].astext == 'some value' Note that equivalent functionality is available via the :attr:`.JSON.Comparator.as_string` accessor. @@ -124,20 +105,18 @@ class JSON(sqltypes.JSON): * Index operations with CAST (equivalent to ``CAST(col ->> ['some key'] AS )``):: - data_table.c.data["some key"].astext.cast(Integer) == 5 + data_table.c.data['some key'].astext.cast(Integer) == 5 Note that equivalent functionality is available via the :attr:`.JSON.Comparator.as_integer` and similar accessors. * Path index operations (the ``#>`` operator):: - data_table.c.data[("key_1", "key_2", 5, ..., "key_n")] + data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')] * Path index operations returning text (the ``#>>`` operator):: - data_table.c.data[ - ("key_1", "key_2", 5, ..., "key_n") - ].astext == "some value" + data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')].astext == 'some value' Index operations return an expression object whose type defaults to :class:`_types.JSON` by default, @@ -149,11 +128,10 @@ class JSON(sqltypes.JSON): using psycopg2, the DBAPI only allows serializers at the per-cursor or per-connection level. E.g.:: - engine = create_engine( - "postgresql+psycopg2://scott:tiger@localhost/test", - json_serializer=my_serialize_fn, - json_deserializer=my_deserialize_fn, - ) + engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test", + json_serializer=my_serialize_fn, + json_deserializer=my_deserialize_fn + ) When using the psycopg2 dialect, the json_deserializer is registered against the database using ``psycopg2.extras.register_default_json``. @@ -166,14 +144,9 @@ class JSON(sqltypes.JSON): """ # noqa - render_bind_cast = True - astext_type: TypeEngine[str] = sqltypes.Text() + astext_type = sqltypes.Text() - def __init__( - self, - none_as_null: bool = False, - astext_type: Optional[TypeEngine[str]] = None, - ): + def __init__(self, none_as_null=False, astext_type=None): """Construct a :class:`_types.JSON` type. :param none_as_null: if True, persist the value ``None`` as a @@ -182,8 +155,7 @@ class JSON(sqltypes.JSON): be used to persist a NULL value:: from sqlalchemy import null - - conn.execute(table.insert(), {"data": null()}) + conn.execute(table.insert(), data=null()) .. 
seealso:: @@ -198,19 +170,17 @@ class JSON(sqltypes.JSON): if astext_type is not None: self.astext_type = astext_type - class Comparator(sqltypes.JSON.Comparator[_T]): + class Comparator(sqltypes.JSON.Comparator): """Define comparison operations for :class:`_types.JSON`.""" - type: JSON - @property - def astext(self) -> ColumnElement[str]: + def astext(self): """On an indexed expression, use the "astext" (e.g. "->>") conversion when rendered in SQL. E.g.:: - select(data_table.c.data["some key"].astext) + select(data_table.c.data['some key'].astext) .. seealso:: @@ -218,13 +188,13 @@ class JSON(sqltypes.JSON): """ if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType): - return self.expr.left.operate( # type: ignore[no-any-return] + return self.expr.left.operate( JSONPATH_ASTEXT, self.expr.right, result_type=self.type.astext_type, ) else: - return self.expr.left.operate( # type: ignore[no-any-return] + return self.expr.left.operate( ASTEXT, self.expr.right, result_type=self.type.astext_type ) @@ -237,16 +207,15 @@ class JSONB(JSON): The :class:`_postgresql.JSONB` type stores arbitrary JSONB format data, e.g.:: - data_table = Table( - "data_table", - metadata, - Column("id", Integer, primary_key=True), - Column("data", JSONB), + data_table = Table('data_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', JSONB) ) with engine.connect() as conn: conn.execute( - data_table.insert(), data={"key1": "value1", "key2": "value2"} + data_table.insert(), + data = {"key1": "value1", "key2": "value2"} ) The :class:`_postgresql.JSONB` type includes all operations provided by @@ -283,53 +252,43 @@ class JSONB(JSON): __visit_name__ = "JSONB" - class Comparator(JSON.Comparator[_T]): + class Comparator(JSON.Comparator): """Define comparison operations for :class:`_types.JSON`.""" - type: JSONB - - def has_key(self, other: Any) -> ColumnElement[bool]: - """Boolean expression. Test for presence of a key (equivalent of - the ``?`` operator). Note that the key may be a SQLA expression. + def has_key(self, other): + """Boolean expression. Test for presence of a key. Note that the + key may be a SQLA expression. """ return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean) - def has_all(self, other: Any) -> ColumnElement[bool]: - """Boolean expression. Test for presence of all keys in jsonb - (equivalent of the ``?&`` operator) - """ + def has_all(self, other): + """Boolean expression. Test for presence of all keys in jsonb""" return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean) - def has_any(self, other: Any) -> ColumnElement[bool]: - """Boolean expression. Test for presence of any key in jsonb - (equivalent of the ``?|`` operator) - """ + def has_any(self, other): + """Boolean expression. Test for presence of any key in jsonb""" return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean) - def contains(self, other: Any, **kwargs: Any) -> ColumnElement[bool]: + def contains(self, other, **kwargs): """Boolean expression. Test if keys (or array) are a superset - of/contained the keys of the argument jsonb expression - (equivalent of the ``@>`` operator). + of/contained the keys of the argument jsonb expression. kwargs may be ignored by this operator but are required for API conformance. """ return self.operate(CONTAINS, other, result_type=sqltypes.Boolean) - def contained_by(self, other: Any) -> ColumnElement[bool]: + def contained_by(self, other): """Boolean expression. 
Test if keys are a proper subset of the - keys of the argument jsonb expression - (equivalent of the ``<@`` operator). + keys of the argument jsonb expression. """ return self.operate( CONTAINED_BY, other, result_type=sqltypes.Boolean ) - def delete_path( - self, array: Union[List[str], _pg_array[str]] - ) -> ColumnElement[JSONB]: + def delete_path(self, array): """JSONB expression. Deletes field or array element specified in - the argument array (equivalent of the ``#-`` operator). + the argument array. The input may be a list of strings that will be coerced to an ``ARRAY`` or an instance of :meth:`_postgres.array`. @@ -341,9 +300,9 @@ class JSONB(JSON): right_side = cast(array, ARRAY(sqltypes.TEXT)) return self.operate(DELETE_PATH, right_side, result_type=JSONB) - def path_exists(self, other: Any) -> ColumnElement[bool]: + def path_exists(self, other): """Boolean expression. Test for presence of item given by the - argument JSONPath expression (equivalent of the ``@?`` operator). + argument JSONPath expression. .. versionadded:: 2.0 """ @@ -351,10 +310,9 @@ class JSONB(JSON): PATH_EXISTS, other, result_type=sqltypes.Boolean ) - def path_match(self, other: Any) -> ColumnElement[bool]: + def path_match(self, other): """Boolean expression. Test if JSONPath predicate given by the - argument JSONPath expression matches - (equivalent of the ``@@`` operator). + argument JSONPath expression matches. Only the first item of the result is taken into account. diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/named_types.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/named_types.py index 5807041..19994d4 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/named_types.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/named_types.py @@ -1,5 +1,5 @@ -# dialects/postgresql/named_types.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# postgresql/named_types.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -7,9 +7,7 @@ # mypy: ignore-errors from __future__ import annotations -from types import ModuleType from typing import Any -from typing import Dict from typing import Optional from typing import Type from typing import TYPE_CHECKING @@ -27,11 +25,10 @@ from ...sql.ddl import InvokeCreateDDLBase from ...sql.ddl import InvokeDropDDLBase if TYPE_CHECKING: - from ...sql._typing import _CreateDropBind from ...sql._typing import _TypeEngineArgument -class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine): +class NamedType(sqltypes.TypeEngine): """Base for named types.""" __abstract__ = True @@ -39,9 +36,7 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine): DDLDropper: Type[NamedTypeDropper] create_type: bool - def create( - self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any - ) -> None: + def create(self, bind, checkfirst=True, **kw): """Emit ``CREATE`` DDL for this type. :param bind: a connectable :class:`_engine.Engine`, @@ -55,9 +50,7 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine): """ bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst) - def drop( - self, bind: _CreateDropBind, checkfirst: bool = True, **kw: Any - ) -> None: + def drop(self, bind, checkfirst=True, **kw): """Emit ``DROP`` DDL for this type. 
:param bind: a connectable :class:`_engine.Engine`, @@ -70,9 +63,7 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine): """ bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst) - def _check_for_name_in_memos( - self, checkfirst: bool, kw: Dict[str, Any] - ) -> bool: + def _check_for_name_in_memos(self, checkfirst, kw): """Look in the 'ddl runner' for 'memos', then note our name in that collection. @@ -96,13 +87,7 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine): else: return False - def _on_table_create( - self, - target: Any, - bind: _CreateDropBind, - checkfirst: bool = False, - **kw: Any, - ) -> None: + def _on_table_create(self, target, bind, checkfirst=False, **kw): if ( checkfirst or ( @@ -112,13 +97,7 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine): ) and not self._check_for_name_in_memos(checkfirst, kw): self.create(bind=bind, checkfirst=checkfirst) - def _on_table_drop( - self, - target: Any, - bind: _CreateDropBind, - checkfirst: bool = False, - **kw: Any, - ) -> None: + def _on_table_drop(self, target, bind, checkfirst=False, **kw): if ( not self.metadata and not kw.get("_is_metadata_operation", False) @@ -126,23 +105,11 @@ class NamedType(schema.SchemaVisitable, sqltypes.TypeEngine): ): self.drop(bind=bind, checkfirst=checkfirst) - def _on_metadata_create( - self, - target: Any, - bind: _CreateDropBind, - checkfirst: bool = False, - **kw: Any, - ) -> None: + def _on_metadata_create(self, target, bind, checkfirst=False, **kw): if not self._check_for_name_in_memos(checkfirst, kw): self.create(bind=bind, checkfirst=checkfirst) - def _on_metadata_drop( - self, - target: Any, - bind: _CreateDropBind, - checkfirst: bool = False, - **kw: Any, - ) -> None: + def _on_metadata_drop(self, target, bind, checkfirst=False, **kw): if not self._check_for_name_in_memos(checkfirst, kw): self.drop(bind=bind, checkfirst=checkfirst) @@ -196,6 +163,7 @@ class EnumDropper(NamedTypeDropper): class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): + """PostgreSQL ENUM type. This is a subclass of :class:`_types.Enum` which includes @@ -218,10 +186,8 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): :meth:`_schema.Table.drop` methods are called:: - table = Table( - "sometable", - metadata, - Column("some_enum", ENUM("a", "b", "c", name="myenum")), + table = Table('sometable', metadata, + Column('some_enum', ENUM('a', 'b', 'c', name='myenum')) ) table.create(engine) # will emit CREATE ENUM and CREATE TABLE @@ -232,17 +198,21 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): :class:`_postgresql.ENUM` independently, and associate it with the :class:`_schema.MetaData` object itself:: - my_enum = ENUM("a", "b", "c", name="myenum", metadata=metadata) + my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata) - t1 = Table("sometable_one", metadata, Column("some_enum", myenum)) + t1 = Table('sometable_one', metadata, + Column('some_enum', myenum) + ) - t2 = Table("sometable_two", metadata, Column("some_enum", myenum)) + t2 = Table('sometable_two', metadata, + Column('some_enum', myenum) + ) When this pattern is used, care must still be taken at the level of individual table creates. 
Emitting CREATE TABLE without also specifying ``checkfirst=True`` will still cause issues:: - t1.create(engine) # will fail: no such type 'myenum' + t1.create(engine) # will fail: no such type 'myenum' If we specify ``checkfirst=True``, the individual table-level create operation will check for the ``ENUM`` and create if not exists:: @@ -347,7 +317,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): return cls(**kw) - def create(self, bind: _CreateDropBind, checkfirst: bool = True) -> None: + def create(self, bind=None, checkfirst=True): """Emit ``CREATE TYPE`` for this :class:`_postgresql.ENUM`. @@ -368,7 +338,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): super().create(bind, checkfirst=checkfirst) - def drop(self, bind: _CreateDropBind, checkfirst: bool = True) -> None: + def drop(self, bind=None, checkfirst=True): """Emit ``DROP TYPE`` for this :class:`_postgresql.ENUM`. @@ -388,7 +358,7 @@ class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): super().drop(bind, checkfirst=checkfirst) - def get_dbapi_type(self, dbapi: ModuleType) -> None: + def get_dbapi_type(self, dbapi): """dont return dbapi.STRING for ENUM in PostgreSQL, since that's a different type""" @@ -418,12 +388,14 @@ class DOMAIN(NamedType, sqltypes.SchemaType): A domain is essentially a data type with optional constraints that restrict the allowed set of values. E.g.:: - PositiveInt = DOMAIN("pos_int", Integer, check="VALUE > 0", not_null=True) + PositiveInt = DOMAIN( + "pos_int", Integer, check="VALUE > 0", not_null=True + ) UsPostalCode = DOMAIN( "us_postal_code", Text, - check="VALUE ~ '^\d{5}$' OR VALUE ~ '^\d{5}-\d{4}$'", + check="VALUE ~ '^\d{5}$' OR VALUE ~ '^\d{5}-\d{4}$'" ) See the `PostgreSQL documentation`__ for additional details @@ -432,7 +404,7 @@ class DOMAIN(NamedType, sqltypes.SchemaType): .. 
versionadded:: 2.0 - """ # noqa: E501 + """ DDLGenerator = DomainGenerator DDLDropper = DomainDropper @@ -445,10 +417,10 @@ class DOMAIN(NamedType, sqltypes.SchemaType): data_type: _TypeEngineArgument[Any], *, collation: Optional[str] = None, - default: Union[elements.TextClause, str, None] = None, + default: Optional[Union[str, elements.TextClause]] = None, constraint_name: Optional[str] = None, not_null: Optional[bool] = None, - check: Union[elements.TextClause, str, None] = None, + check: Optional[str] = None, create_type: bool = True, **kw: Any, ): @@ -492,7 +464,7 @@ class DOMAIN(NamedType, sqltypes.SchemaType): self.default = default self.collation = collation self.constraint_name = constraint_name - self.not_null = bool(not_null) + self.not_null = not_null if check is not None: check = coercions.expect(roles.DDLExpressionRole, check) self.check = check diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/operators.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/operators.py index ebcafcb..f393451 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/operators.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/operators.py @@ -1,5 +1,5 @@ -# dialects/postgresql/operators.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# postgresql/operators.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/pg8000.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/pg8000.py index 47016b4..71ee4eb 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/pg8000.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/pg8000.py @@ -1,5 +1,5 @@ -# dialects/postgresql/pg8000.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors # # This module is part of SQLAlchemy and is released under @@ -27,21 +27,19 @@ PostgreSQL ``client_encoding`` parameter; by default this is the value in the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``. Typically, this can be changed to ``utf-8``, as a more useful default:: - # client_encoding = sql_ascii # actually, defaults to database encoding + #client_encoding = sql_ascii # actually, defaults to database + # encoding client_encoding = utf8 The ``client_encoding`` can be overridden for a session by executing the SQL: -.. sourcecode:: sql - - SET CLIENT_ENCODING TO 'utf8'; +SET CLIENT_ENCODING TO 'utf8'; SQLAlchemy will execute this SQL on all new connections based on the value passed to :func:`_sa.create_engine` using the ``client_encoding`` parameter:: engine = create_engine( - "postgresql+pg8000://user:pass@host/dbname", client_encoding="utf8" - ) + "postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8') .. 
_pg8000_ssl: @@ -52,7 +50,6 @@ pg8000 accepts a Python ``SSLContext`` object which may be specified using the :paramref:`_sa.create_engine.connect_args` dictionary:: import ssl - ssl_context = ssl.create_default_context() engine = sa.create_engine( "postgresql+pg8000://scott:tiger@192.168.0.199/test", @@ -64,7 +61,6 @@ or does not match the host name (as seen from the client), it may also be necessary to disable hostname checking:: import ssl - ssl_context = ssl.create_default_context() ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE @@ -257,7 +253,7 @@ class _PGOIDVECTOR(_SpaceVector, OIDVECTOR): pass -class _Pg8000Range(ranges.AbstractSingleRangeImpl): +class _Pg8000Range(ranges.AbstractRangeImpl): def bind_processor(self, dialect): pg8000_Range = dialect.dbapi.Range @@ -308,13 +304,15 @@ class _Pg8000MultiRange(ranges.AbstractMultiRangeImpl): def to_multirange(value): if value is None: return None - else: - return ranges.MultiRange( + + mr = [] + for v in value: + mr.append( ranges.Range( v.lower, v.upper, bounds=v.bounds, empty=v.is_empty ) - for v in value ) + return mr return to_multirange @@ -540,9 +538,6 @@ class PGDialect_pg8000(PGDialect): cursor.execute("COMMIT") cursor.close() - def detect_autocommit_setting(self, dbapi_conn) -> bool: - return bool(dbapi_conn.autocommit) - def set_readonly(self, connection, value): cursor = connection.cursor() try: @@ -589,8 +584,8 @@ class PGDialect_pg8000(PGDialect): cursor = dbapi_connection.cursor() cursor.execute( f"""SET CLIENT_ENCODING TO '{ - client_encoding.replace("'", "''") - }'""" + client_encoding.replace("'", "''") + }'""" ) cursor.execute("COMMIT") cursor.close() diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/pg_catalog.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/pg_catalog.py index 9625ccf..fa4b30f 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/pg_catalog.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/pg_catalog.py @@ -1,16 +1,10 @@ -# dialects/postgresql/pg_catalog.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# postgresql/pg_catalog.py +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php - -from __future__ import annotations - -from typing import Any -from typing import Optional -from typing import Sequence -from typing import TYPE_CHECKING +# mypy: ignore-errors from .array import ARRAY from .types import OID @@ -29,37 +23,31 @@ from ...types import String from ...types import Text from ...types import TypeDecorator -if TYPE_CHECKING: - from ...engine.interfaces import Dialect - from ...sql.type_api import _ResultProcessorType - # types -class NAME(TypeDecorator[str]): +class NAME(TypeDecorator): impl = String(64, collation="C") cache_ok = True -class PG_NODE_TREE(TypeDecorator[str]): +class PG_NODE_TREE(TypeDecorator): impl = Text(collation="C") cache_ok = True -class INT2VECTOR(TypeDecorator[Sequence[int]]): +class INT2VECTOR(TypeDecorator): impl = ARRAY(SmallInteger) cache_ok = True -class OIDVECTOR(TypeDecorator[Sequence[int]]): +class OIDVECTOR(TypeDecorator): impl = ARRAY(OID) cache_ok = True class _SpaceVector: - def result_processor( - self, dialect: Dialect, coltype: object - ) -> _ResultProcessorType[list[int]]: - def process(value: Any) -> Optional[list[int]]: + def result_processor(self, dialect, coltype): 
+ def process(value): if value is None: return value return [int(p) for p in value.split(" ")] @@ -89,7 +77,7 @@ RELKINDS_MAT_VIEW = ("m",) RELKINDS_ALL_TABLE_LIKE = RELKINDS_TABLE + RELKINDS_VIEW + RELKINDS_MAT_VIEW # tables -pg_catalog_meta = MetaData(schema="pg_catalog") +pg_catalog_meta = MetaData() pg_namespace = Table( "pg_namespace", @@ -97,6 +85,7 @@ pg_namespace = Table( Column("oid", OID), Column("nspname", NAME), Column("nspowner", OID), + schema="pg_catalog", ) pg_class = Table( @@ -131,6 +120,7 @@ pg_class = Table( Column("relispartition", Boolean, info={"server_version": (10,)}), Column("relrewrite", OID, info={"server_version": (11,)}), Column("reloptions", ARRAY(Text)), + schema="pg_catalog", ) pg_type = Table( @@ -165,6 +155,7 @@ pg_type = Table( Column("typndims", Integer), Column("typcollation", OID, info={"server_version": (9, 1)}), Column("typdefault", Text), + schema="pg_catalog", ) pg_index = Table( @@ -191,6 +182,7 @@ pg_index = Table( Column("indoption", INT2VECTOR), Column("indexprs", PG_NODE_TREE), Column("indpred", PG_NODE_TREE), + schema="pg_catalog", ) pg_attribute = Table( @@ -217,6 +209,7 @@ pg_attribute = Table( Column("attislocal", Boolean), Column("attinhcount", Integer), Column("attcollation", OID, info={"server_version": (9, 1)}), + schema="pg_catalog", ) pg_constraint = Table( @@ -242,6 +235,7 @@ pg_constraint = Table( Column("connoinherit", Boolean, info={"server_version": (9, 2)}), Column("conkey", ARRAY(SmallInteger)), Column("confkey", ARRAY(SmallInteger)), + schema="pg_catalog", ) pg_sequence = Table( @@ -255,6 +249,7 @@ pg_sequence = Table( Column("seqmin", BigInteger), Column("seqcache", BigInteger), Column("seqcycle", Boolean), + schema="pg_catalog", info={"server_version": (10,)}, ) @@ -265,6 +260,7 @@ pg_attrdef = Table( Column("adrelid", OID), Column("adnum", SmallInteger), Column("adbin", PG_NODE_TREE), + schema="pg_catalog", ) pg_description = Table( @@ -274,6 +270,7 @@ pg_description = Table( Column("classoid", OID), Column("objsubid", Integer), Column("description", Text(collation="C")), + schema="pg_catalog", ) pg_enum = Table( @@ -283,6 +280,7 @@ pg_enum = Table( Column("enumtypid", OID), Column("enumsortorder", Float(), info={"server_version": (9, 1)}), Column("enumlabel", NAME), + schema="pg_catalog", ) pg_am = Table( @@ -292,35 +290,5 @@ pg_am = Table( Column("amname", NAME), Column("amhandler", REGPROC, info={"server_version": (9, 6)}), Column("amtype", CHAR, info={"server_version": (9, 6)}), -) - -pg_collation = Table( - "pg_collation", - pg_catalog_meta, - Column("oid", OID, info={"server_version": (9, 3)}), - Column("collname", NAME), - Column("collnamespace", OID), - Column("collowner", OID), - Column("collprovider", CHAR, info={"server_version": (10,)}), - Column("collisdeterministic", Boolean, info={"server_version": (12,)}), - Column("collencoding", Integer), - Column("collcollate", Text), - Column("collctype", Text), - Column("colliculocale", Text), - Column("collicurules", Text, info={"server_version": (16,)}), - Column("collversion", Text, info={"server_version": (10,)}), -) - -pg_opclass = Table( - "pg_opclass", - pg_catalog_meta, - Column("oid", OID, info={"server_version": (9, 3)}), - Column("opcmethod", NAME), - Column("opcname", NAME), - Column("opsnamespace", OID), - Column("opsowner", OID), - Column("opcfamily", OID), - Column("opcintype", OID), - Column("opcdefault", Boolean), - Column("opckeytype", OID), + schema="pg_catalog", ) diff --git 
a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/provision.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/provision.py index c76f5f5..87f1c9a 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/provision.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/provision.py @@ -1,9 +1,3 @@ -# dialects/postgresql/provision.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors import time @@ -97,7 +91,7 @@ def drop_all_schema_objects_pre_tables(cfg, eng): for xid in conn.exec_driver_sql( "select gid from pg_prepared_xacts" ).scalars(): - conn.exec_driver_sql("ROLLBACK PREPARED '%s'" % xid) + conn.execute("ROLLBACK PREPARED '%s'" % xid) @drop_all_schema_objects_post_tables.for_db("postgresql") diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/psycopg.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/psycopg.py index 0554048..dcd69ce 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/psycopg.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/psycopg.py @@ -1,5 +1,5 @@ -# dialects/postgresql/psycopg.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# postgresql/psycopg2.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -29,29 +29,20 @@ selected depending on how the engine is created: automatically select the sync version, e.g.:: from sqlalchemy import create_engine - - sync_engine = create_engine( - "postgresql+psycopg://scott:tiger@localhost/test" - ) + sync_engine = create_engine("postgresql+psycopg://scott:tiger@localhost/test") * calling :func:`_asyncio.create_async_engine` with ``postgresql+psycopg://...`` will automatically select the async version, e.g.:: from sqlalchemy.ext.asyncio import create_async_engine - - asyncio_engine = create_async_engine( - "postgresql+psycopg://scott:tiger@localhost/test" - ) + asyncio_engine = create_async_engine("postgresql+psycopg://scott:tiger@localhost/test") The asyncio version of the dialect may also be specified explicitly using the ``psycopg_async`` suffix, as:: from sqlalchemy.ext.asyncio import create_async_engine - - asyncio_engine = create_async_engine( - "postgresql+psycopg_async://scott:tiger@localhost/test" - ) + asyncio_engine = create_async_engine("postgresql+psycopg_async://scott:tiger@localhost/test") .. seealso:: @@ -59,42 +50,9 @@ The asyncio version of the dialect may also be specified explicitly using the dialect shares most of its behavior with the ``psycopg2`` dialect. Further documentation is available there. -Using a different Cursor class ------------------------------- - -One of the differences between ``psycopg`` and the older ``psycopg2`` -is how bound parameters are handled: ``psycopg2`` would bind them -client side, while ``psycopg`` by default will bind them server side. 
- -It's possible to configure ``psycopg`` to do client side binding by -specifying the ``cursor_factory`` to be ``ClientCursor`` when creating -the engine:: - - from psycopg import ClientCursor - - client_side_engine = create_engine( - "postgresql+psycopg://...", - connect_args={"cursor_factory": ClientCursor}, - ) - -Similarly when using an async engine the ``AsyncClientCursor`` can be -specified:: - - from psycopg import AsyncClientCursor - - client_side_engine = create_async_engine( - "postgresql+psycopg://...", - connect_args={"cursor_factory": AsyncClientCursor}, - ) - -.. seealso:: - - `Client-side-binding cursors `_ - """ # noqa from __future__ import annotations -from collections import deque import logging import re from typing import cast @@ -121,8 +79,6 @@ from ...util.concurrency import await_only if TYPE_CHECKING: from typing import Iterable - from psycopg import AsyncConnection - logger = logging.getLogger("sqlalchemy.dialects.postgresql") @@ -135,6 +91,8 @@ class _PGREGCONFIG(REGCONFIG): class _PGJSON(JSON): + render_bind_cast = True + def bind_processor(self, dialect): return self._make_bind_processor(None, dialect._psycopg_Json) @@ -143,6 +101,8 @@ class _PGJSON(JSON): class _PGJSONB(JSONB): + render_bind_cast = True + def bind_processor(self, dialect): return self._make_bind_processor(None, dialect._psycopg_Jsonb) @@ -202,7 +162,7 @@ class _PGBoolean(sqltypes.Boolean): render_bind_cast = True -class _PsycopgRange(ranges.AbstractSingleRangeImpl): +class _PsycopgRange(ranges.AbstractRangeImpl): def bind_processor(self, dialect): psycopg_Range = cast(PGDialect_psycopg, dialect)._psycopg_Range @@ -258,10 +218,8 @@ class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl): def result_processor(self, dialect, coltype): def to_range(value): - if value is None: - return None - else: - return ranges.MultiRange( + if value is not None: + value = [ ranges.Range( elem._lower, elem._upper, @@ -269,7 +227,9 @@ class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl): empty=not elem._bounds, ) for elem in value - ) + ] + + return value return to_range @@ -326,7 +286,7 @@ class PGDialect_psycopg(_PGDialect_common_psycopg): sqltypes.Integer: _PGInteger, sqltypes.SmallInteger: _PGSmallInteger, sqltypes.BigInteger: _PGBigInteger, - ranges.AbstractSingleRange: _PsycopgRange, + ranges.AbstractRange: _PsycopgRange, ranges.AbstractMultiRange: _PsycopgMultiRange, }, ) @@ -406,12 +366,10 @@ class PGDialect_psycopg(_PGDialect_common_psycopg): # register the adapter for connections made subsequent to # this one - assert self._psycopg_adapters_map register_hstore(info, self._psycopg_adapters_map) # register the adapter for this connection - assert connection.connection - register_hstore(info, connection.connection.driver_connection) + register_hstore(info, connection.connection) @classmethod def import_dbapi(cls): @@ -572,7 +530,7 @@ class AsyncAdapt_psycopg_cursor: def __init__(self, cursor, await_) -> None: self._cursor = cursor self.await_ = await_ - self._rows = deque() + self._rows = [] def __getattr__(self, name): return getattr(self._cursor, name) @@ -599,19 +557,24 @@ class AsyncAdapt_psycopg_cursor: # eq/ne if res and res.status == self._psycopg_ExecStatus.TUPLES_OK: rows = self.await_(self._cursor.fetchall()) - self._rows = deque(rows) + if not isinstance(rows, list): + self._rows = list(rows) + else: + self._rows = rows return result def executemany(self, query, params_seq): return self.await_(self._cursor.executemany(query, params_seq)) def __iter__(self): + # TODO: try to avoid pop(0) on a 
list while self._rows: - yield self._rows.popleft() + yield self._rows.pop(0) def fetchone(self): if self._rows: - return self._rows.popleft() + # TODO: try to avoid pop(0) on a list + return self._rows.pop(0) else: return None @@ -619,12 +582,13 @@ class AsyncAdapt_psycopg_cursor: if size is None: size = self._cursor.arraysize - rr = self._rows - return [rr.popleft() for _ in range(min(size, len(rr)))] + retval = self._rows[0:size] + self._rows = self._rows[size:] + return retval def fetchall(self): - retval = list(self._rows) - self._rows.clear() + retval = self._rows + self._rows = [] return retval @@ -655,7 +619,6 @@ class AsyncAdapt_psycopg_ss_cursor(AsyncAdapt_psycopg_cursor): class AsyncAdapt_psycopg_connection(AdaptedConnection): - _connection: AsyncConnection __slots__ = () await_ = staticmethod(await_only) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/psycopg2.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/psycopg2.py index eeb7604..2719f3d 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/psycopg2.py @@ -1,5 +1,5 @@ -# dialects/postgresql/psycopg2.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# postgresql/psycopg2.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -88,6 +88,7 @@ connection URI:: "postgresql+psycopg2://scott:tiger@192.168.0.199:5432/test?sslmode=require" ) + Unix Domain Connections ------------------------ @@ -102,17 +103,13 @@ in ``/tmp``, or whatever socket directory was specified when PostgreSQL was built. This value can be overridden by passing a pathname to psycopg2, using ``host`` as an additional keyword argument:: - create_engine( - "postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql" - ) + create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql") .. warning:: The format accepted here allows for a hostname in the main URL in addition to the "host" query string argument. **When using this URL format, the initial host is silently ignored**. That is, this URL:: - engine = create_engine( - "postgresql+psycopg2://user:password@myhost1/dbname?host=myhost2" - ) + engine = create_engine("postgresql+psycopg2://user:password@myhost1/dbname?host=myhost2") Above, the hostname ``myhost1`` is **silently ignored and discarded.** The host which is connected is the ``myhost2`` host. @@ -193,7 +190,7 @@ any or all elements of the connection string. For this form, the URL can be passed without any elements other than the initial scheme:: - engine = create_engine("postgresql+psycopg2://") + engine = create_engine('postgresql+psycopg2://') In the above form, a blank "dsn" string is passed to the ``psycopg2.connect()`` function which in turn represents an empty DSN passed to libpq. @@ -245,7 +242,7 @@ Psycopg2 Fast Execution Helpers Modern versions of psycopg2 include a feature known as `Fast Execution Helpers \ -`_, which +`_, which have been shown in benchmarking to improve psycopg2's executemany() performance, primarily with INSERT statements, by at least an order of magnitude. @@ -267,8 +264,8 @@ used feature. 
The use of this extension may be enabled using the engine = create_engine( "postgresql+psycopg2://scott:tiger@host/dbname", - executemany_mode="values_plus_batch", - ) + executemany_mode='values_plus_batch') + Possible options for ``executemany_mode`` include: @@ -314,10 +311,8 @@ is below:: engine = create_engine( "postgresql+psycopg2://scott:tiger@host/dbname", - executemany_mode="values_plus_batch", - insertmanyvalues_page_size=5000, - executemany_batch_page_size=500, - ) + executemany_mode='values_plus_batch', + insertmanyvalues_page_size=5000, executemany_batch_page_size=500) .. seealso:: @@ -343,9 +338,7 @@ in the following ways: passed in the database URL; this parameter is consumed by the underlying ``libpq`` PostgreSQL client library:: - engine = create_engine( - "postgresql+psycopg2://user:pass@host/dbname?client_encoding=utf8" - ) + engine = create_engine("postgresql+psycopg2://user:pass@host/dbname?client_encoding=utf8") Alternatively, the above ``client_encoding`` value may be passed using :paramref:`_sa.create_engine.connect_args` for programmatic establishment with @@ -353,7 +346,7 @@ in the following ways: engine = create_engine( "postgresql+psycopg2://user:pass@host/dbname", - connect_args={"client_encoding": "utf8"}, + connect_args={'client_encoding': 'utf8'} ) * For all PostgreSQL versions, psycopg2 supports a client-side encoding @@ -362,7 +355,8 @@ in the following ways: ``client_encoding`` parameter passed to :func:`_sa.create_engine`:: engine = create_engine( - "postgresql+psycopg2://user:pass@host/dbname", client_encoding="utf8" + "postgresql+psycopg2://user:pass@host/dbname", + client_encoding="utf8" ) .. tip:: The above ``client_encoding`` parameter admittedly is very similar @@ -381,9 +375,11 @@ in the following ways: # postgresql.conf file # client_encoding = sql_ascii # actually, defaults to database - # encoding + # encoding client_encoding = utf8 + + Transactions ------------ @@ -430,15 +426,15 @@ is set to the ``logging.INFO`` level, notice messages will be logged:: import logging - logging.getLogger("sqlalchemy.dialects.postgresql").setLevel(logging.INFO) + logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO) Above, it is assumed that logging is configured externally. If this is not the case, configuration such as ``logging.basicConfig()`` must be utilized:: import logging - logging.basicConfig() # log messages to stdout - logging.getLogger("sqlalchemy.dialects.postgresql").setLevel(logging.INFO) + logging.basicConfig() # log messages to stdout + logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO) .. seealso:: @@ -475,10 +471,8 @@ textual HSTORE expression. If this behavior is not desired, disable the use of the hstore extension by setting ``use_native_hstore`` to ``False`` as follows:: - engine = create_engine( - "postgresql+psycopg2://scott:tiger@localhost/test", - use_native_hstore=False, - ) + engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test", + use_native_hstore=False) The ``HSTORE`` type is **still supported** when the ``psycopg2.extensions.register_hstore()`` extension is not used. It merely @@ -519,7 +513,7 @@ class _PGJSONB(JSONB): return None -class _Psycopg2Range(ranges.AbstractSingleRangeImpl): +class _Psycopg2Range(ranges.AbstractRangeImpl): _psycopg2_range_cls = "none" def bind_processor(self, dialect): @@ -850,43 +844,33 @@ class PGDialect_psycopg2(_PGDialect_common_psycopg): # checks based on strings. in the case that .closed # didn't cut it, fall back onto these. 
str_e = str(e).partition("\n")[0] - for msg in self._is_disconnect_messages: + for msg in [ + # these error messages from libpq: interfaces/libpq/fe-misc.c + # and interfaces/libpq/fe-secure.c. + "terminating connection", + "closed the connection", + "connection not open", + "could not receive data from server", + "could not send data to server", + # psycopg2 client errors, psycopg2/connection.h, + # psycopg2/cursor.h + "connection already closed", + "cursor already closed", + # not sure where this path is originally from, it may + # be obsolete. It really says "losed", not "closed". + "losed the connection unexpectedly", + # these can occur in newer SSL + "connection has been closed unexpectedly", + "SSL error: decryption failed or bad record mac", + "SSL SYSCALL error: Bad file descriptor", + "SSL SYSCALL error: EOF detected", + "SSL SYSCALL error: Operation timed out", + "SSL SYSCALL error: Bad address", + ]: idx = str_e.find(msg) if idx >= 0 and '"' not in str_e[:idx]: return True return False - @util.memoized_property - def _is_disconnect_messages(self): - return ( - # these error messages from libpq: interfaces/libpq/fe-misc.c - # and interfaces/libpq/fe-secure.c. - "terminating connection", - "closed the connection", - "connection not open", - "could not receive data from server", - "could not send data to server", - # psycopg2 client errors, psycopg2/connection.h, - # psycopg2/cursor.h - "connection already closed", - "cursor already closed", - # not sure where this path is originally from, it may - # be obsolete. It really says "losed", not "closed". - "losed the connection unexpectedly", - # these can occur in newer SSL - "connection has been closed unexpectedly", - "SSL error: decryption failed or bad record mac", - "SSL SYSCALL error: Bad file descriptor", - "SSL SYSCALL error: EOF detected", - "SSL SYSCALL error: Operation timed out", - "SSL SYSCALL error: Bad address", - # This can occur in OpenSSL 1 when an unexpected EOF occurs. - # https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html#BUGS - # It may also occur in newer OpenSSL for a non-recoverable I/O - # error as a result of a system call that does not set 'errno' - # in libc. 
- "SSL SYSCALL error: Success", - ) - dialect = PGDialect_psycopg2 diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/psycopg2cffi.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/psycopg2cffi.py index 55e1760..211432c 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/psycopg2cffi.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/psycopg2cffi.py @@ -1,5 +1,5 @@ -# dialects/postgresql/psycopg2cffi.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# testing/engines.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/ranges.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/ranges.py index 0ce4ea2..f1c2989 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/ranges.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/ranges.py @@ -1,5 +1,4 @@ -# dialects/postgresql/ranges.py -# Copyright (C) 2013-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2013-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -15,10 +14,8 @@ from decimal import Decimal from typing import Any from typing import cast from typing import Generic -from typing import List from typing import Optional from typing import overload -from typing import Sequence from typing import Tuple from typing import Type from typing import TYPE_CHECKING @@ -154,8 +151,8 @@ class Range(Generic[_T]): return not self.empty and self.upper is None @property - def __sa_type_engine__(self) -> AbstractSingleRange[_T]: - return AbstractSingleRange() + def __sa_type_engine__(self) -> AbstractRange[Range[_T]]: + return AbstractRange() def _contains_value(self, value: _T) -> bool: """Return True if this range contains the given value.""" @@ -271,9 +268,9 @@ class Range(Generic[_T]): value2 += step value2_inc = False - if value1 < value2: + if value1 < value2: # type: ignore return -1 - elif value1 > value2: + elif value1 > value2: # type: ignore return 1 elif only_values: return 0 @@ -360,8 +357,6 @@ class Range(Generic[_T]): else: return self._contains_value(value) - __contains__ = contains - def overlaps(self, other: Range[_T]) -> bool: "Determine whether this range overlaps with `other`." @@ -712,46 +707,27 @@ class Range(Generic[_T]): return f"{b0}{l},{r}{b1}" -class MultiRange(List[Range[_T]]): - """Represents a multirange sequence. - - This list subclass is an utility to allow automatic type inference of - the proper multi-range SQL type depending on the single range values. - This is useful when operating on literal multi-ranges:: - - import sqlalchemy as sa - from sqlalchemy.dialects.postgresql import MultiRange, Range - - value = literal(MultiRange([Range(2, 4)])) - - select(tbl).where(tbl.c.value.op("@")(MultiRange([Range(-3, 7)]))) - - .. versionadded:: 2.0.26 +class AbstractRange(sqltypes.TypeEngine[Range[_T]]): + """ + Base for PostgreSQL RANGE types. .. seealso:: - - :ref:`postgresql_multirange_list_use`. 
- """ + `PostgreSQL range functions `_ - @property - def __sa_type_engine__(self) -> AbstractMultiRange[_T]: - return AbstractMultiRange() - - -class AbstractRange(sqltypes.TypeEngine[_T]): - """Base class for single and multi Range SQL types.""" + """ # noqa: E501 render_bind_cast = True __abstract__ = True @overload - def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ... + def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: + ... @overload - def adapt( - self, cls: Type[TypeEngineMixin], **kw: Any - ) -> TypeEngine[Any]: ... + def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]: + ... def adapt( self, @@ -765,10 +741,7 @@ class AbstractRange(sqltypes.TypeEngine[_T]): and also render as ``INT4RANGE`` in SQL and DDL. """ - if ( - issubclass(cls, (AbstractSingleRangeImpl, AbstractMultiRangeImpl)) - and cls is not self.__class__ - ): + if issubclass(cls, AbstractRangeImpl) and cls is not self.__class__: # two ways to do this are: 1. create a new type on the fly # or 2. have AbstractRangeImpl(visit_name) constructor and a # visit_abstract_range_impl() method in the PG compiler. @@ -787,6 +760,21 @@ class AbstractRange(sqltypes.TypeEngine[_T]): else: return super().adapt(cls) + def _resolve_for_literal(self, value: Any) -> Any: + spec = value.lower if value.lower is not None else value.upper + + if isinstance(spec, int): + return INT8RANGE() + elif isinstance(spec, (Decimal, float)): + return NUMRANGE() + elif isinstance(spec, datetime): + return TSRANGE() if not spec.tzinfo else TSTZRANGE() + elif isinstance(spec, date): + return DATERANGE() + else: + # empty Range, SQL datatype can't be determined here + return sqltypes.NULLTYPE + class comparator_factory(TypeEngine.Comparator[Range[Any]]): """Define comparison operations for range types.""" @@ -868,164 +856,91 @@ class AbstractRange(sqltypes.TypeEngine[_T]): return self.expr.operate(operators.mul, other) -class AbstractSingleRange(AbstractRange[Range[_T]]): - """Base for PostgreSQL RANGE types. - - These are types that return a single :class:`_postgresql.Range` object. - - .. seealso:: - - `PostgreSQL range functions `_ - - """ # noqa: E501 - - __abstract__ = True - - def _resolve_for_literal(self, value: Range[Any]) -> Any: - spec = value.lower if value.lower is not None else value.upper - - if isinstance(spec, int): - # pg is unreasonably picky here: the query - # "select 1::INTEGER <@ '[1, 4)'::INT8RANGE" raises - # "operator does not exist: integer <@ int8range" as of pg 16 - if _is_int32(value): - return INT4RANGE() - else: - return INT8RANGE() - elif isinstance(spec, (Decimal, float)): - return NUMRANGE() - elif isinstance(spec, datetime): - return TSRANGE() if not spec.tzinfo else TSTZRANGE() - elif isinstance(spec, date): - return DATERANGE() - else: - # empty Range, SQL datatype can't be determined here - return sqltypes.NULLTYPE - - -class AbstractSingleRangeImpl(AbstractSingleRange[_T]): - """Marker for AbstractSingleRange that will apply a subclass-specific +class AbstractRangeImpl(AbstractRange[Range[_T]]): + """Marker for AbstractRange that will apply a subclass-specific adaptation""" -class AbstractMultiRange(AbstractRange[Sequence[Range[_T]]]): - """Base for PostgreSQL MULTIRANGE types. - - these are types that return a sequence of :class:`_postgresql.Range` - objects. 
- - """ +class AbstractMultiRange(AbstractRange[Range[_T]]): + """base for PostgreSQL MULTIRANGE types""" __abstract__ = True - def _resolve_for_literal(self, value: Sequence[Range[Any]]) -> Any: - if not value: - # empty MultiRange, SQL datatype can't be determined here - return sqltypes.NULLTYPE - first = value[0] - spec = first.lower if first.lower is not None else first.upper - if isinstance(spec, int): - # pg is unreasonably picky here: the query - # "select 1::INTEGER <@ '{[1, 4),[6,19)}'::INT8MULTIRANGE" raises - # "operator does not exist: integer <@ int8multirange" as of pg 16 - if all(_is_int32(r) for r in value): - return INT4MULTIRANGE() - else: - return INT8MULTIRANGE() - elif isinstance(spec, (Decimal, float)): - return NUMMULTIRANGE() - elif isinstance(spec, datetime): - return TSMULTIRANGE() if not spec.tzinfo else TSTZMULTIRANGE() - elif isinstance(spec, date): - return DATEMULTIRANGE() - else: - # empty Range, SQL datatype can't be determined here - return sqltypes.NULLTYPE - - -class AbstractMultiRangeImpl(AbstractMultiRange[_T]): - """Marker for AbstractMultiRange that will apply a subclass-specific +class AbstractMultiRangeImpl( + AbstractRangeImpl[Range[_T]], AbstractMultiRange[Range[_T]] +): + """Marker for AbstractRange that will apply a subclass-specific adaptation""" -class INT4RANGE(AbstractSingleRange[int]): +class INT4RANGE(AbstractRange[Range[int]]): """Represent the PostgreSQL INT4RANGE type.""" __visit_name__ = "INT4RANGE" -class INT8RANGE(AbstractSingleRange[int]): +class INT8RANGE(AbstractRange[Range[int]]): """Represent the PostgreSQL INT8RANGE type.""" __visit_name__ = "INT8RANGE" -class NUMRANGE(AbstractSingleRange[Decimal]): +class NUMRANGE(AbstractRange[Range[Decimal]]): """Represent the PostgreSQL NUMRANGE type.""" __visit_name__ = "NUMRANGE" -class DATERANGE(AbstractSingleRange[date]): +class DATERANGE(AbstractRange[Range[date]]): """Represent the PostgreSQL DATERANGE type.""" __visit_name__ = "DATERANGE" -class TSRANGE(AbstractSingleRange[datetime]): +class TSRANGE(AbstractRange[Range[datetime]]): """Represent the PostgreSQL TSRANGE type.""" __visit_name__ = "TSRANGE" -class TSTZRANGE(AbstractSingleRange[datetime]): +class TSTZRANGE(AbstractRange[Range[datetime]]): """Represent the PostgreSQL TSTZRANGE type.""" __visit_name__ = "TSTZRANGE" -class INT4MULTIRANGE(AbstractMultiRange[int]): +class INT4MULTIRANGE(AbstractMultiRange[Range[int]]): """Represent the PostgreSQL INT4MULTIRANGE type.""" __visit_name__ = "INT4MULTIRANGE" -class INT8MULTIRANGE(AbstractMultiRange[int]): +class INT8MULTIRANGE(AbstractMultiRange[Range[int]]): """Represent the PostgreSQL INT8MULTIRANGE type.""" __visit_name__ = "INT8MULTIRANGE" -class NUMMULTIRANGE(AbstractMultiRange[Decimal]): +class NUMMULTIRANGE(AbstractMultiRange[Range[Decimal]]): """Represent the PostgreSQL NUMMULTIRANGE type.""" __visit_name__ = "NUMMULTIRANGE" -class DATEMULTIRANGE(AbstractMultiRange[date]): +class DATEMULTIRANGE(AbstractMultiRange[Range[date]]): """Represent the PostgreSQL DATEMULTIRANGE type.""" __visit_name__ = "DATEMULTIRANGE" -class TSMULTIRANGE(AbstractMultiRange[datetime]): +class TSMULTIRANGE(AbstractMultiRange[Range[datetime]]): """Represent the PostgreSQL TSRANGE type.""" __visit_name__ = "TSMULTIRANGE" -class TSTZMULTIRANGE(AbstractMultiRange[datetime]): +class TSTZMULTIRANGE(AbstractMultiRange[Range[datetime]]): """Represent the PostgreSQL TSTZRANGE type.""" __visit_name__ = "TSTZMULTIRANGE" - - -_max_int_32 = 2**31 - 1 -_min_int_32 = -(2**31) - - -def _is_int32(r: Range[int]) -> 
bool: - return (r.lower is None or _min_int_32 <= r.lower <= _max_int_32) and ( - r.upper is None or _min_int_32 <= r.upper <= _max_int_32 - ) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/types.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/types.py index 1aed2bf..2cac5d8 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/types.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/postgresql/types.py @@ -1,5 +1,4 @@ -# dialects/postgresql/types.py -# Copyright (C) 2013-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2013-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -38,52 +37,43 @@ class PGUuid(sqltypes.UUID[sqltypes._UUID_RETURN]): @overload def __init__( self: PGUuid[_python_UUID], as_uuid: Literal[True] = ... - ) -> None: ... + ) -> None: + ... @overload - def __init__( - self: PGUuid[str], as_uuid: Literal[False] = ... - ) -> None: ... + def __init__(self: PGUuid[str], as_uuid: Literal[False] = ...) -> None: + ... - def __init__(self, as_uuid: bool = True) -> None: ... + def __init__(self, as_uuid: bool = True) -> None: + ... class BYTEA(sqltypes.LargeBinary): __visit_name__ = "BYTEA" -class _NetworkAddressTypeMixin: - - def coerce_compared_value( - self, op: Optional[OperatorType], value: Any - ) -> TypeEngine[Any]: - if TYPE_CHECKING: - assert isinstance(self, TypeEngine) - return self - - -class INET(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]): +class INET(sqltypes.TypeEngine[str]): __visit_name__ = "INET" PGInet = INET -class CIDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]): +class CIDR(sqltypes.TypeEngine[str]): __visit_name__ = "CIDR" PGCidr = CIDR -class MACADDR(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]): +class MACADDR(sqltypes.TypeEngine[str]): __visit_name__ = "MACADDR" PGMacAddr = MACADDR -class MACADDR8(_NetworkAddressTypeMixin, sqltypes.TypeEngine[str]): +class MACADDR8(sqltypes.TypeEngine[str]): __visit_name__ = "MACADDR8" @@ -104,11 +94,12 @@ class MONEY(sqltypes.TypeEngine[str]): from sqlalchemy import Dialect from sqlalchemy import TypeDecorator - class NumericMoney(TypeDecorator): impl = MONEY - def process_result_value(self, value: Any, dialect: Dialect) -> None: + def process_result_value( + self, value: Any, dialect: Dialect + ) -> None: if value is not None: # adjust this for the currency and numeric m = re.match(r"\$([\d.]+)", value) @@ -123,7 +114,6 @@ class MONEY(sqltypes.TypeEngine[str]): from sqlalchemy import cast from sqlalchemy import TypeDecorator - class NumericMoney(TypeDecorator): impl = MONEY @@ -132,18 +122,20 @@ class MONEY(sqltypes.TypeEngine[str]): .. versionadded:: 1.2 - """ # noqa: E501 + """ __visit_name__ = "MONEY" class OID(sqltypes.TypeEngine[int]): + """Provide the PostgreSQL OID type.""" __visit_name__ = "OID" class REGCONFIG(sqltypes.TypeEngine[str]): + """Provide the PostgreSQL REGCONFIG type. .. versionadded:: 2.0.0rc1 @@ -154,6 +146,7 @@ class REGCONFIG(sqltypes.TypeEngine[str]): class TSQUERY(sqltypes.TypeEngine[str]): + """Provide the PostgreSQL TSQUERY type. .. versionadded:: 2.0.0rc1 @@ -164,6 +157,7 @@ class TSQUERY(sqltypes.TypeEngine[str]): class REGCLASS(sqltypes.TypeEngine[str]): + """Provide the PostgreSQL REGCLASS type. .. 
versionadded:: 1.2.7 @@ -174,6 +168,7 @@ class REGCLASS(sqltypes.TypeEngine[str]): class TIMESTAMP(sqltypes.TIMESTAMP): + """Provide the PostgreSQL TIMESTAMP type.""" __visit_name__ = "TIMESTAMP" @@ -194,6 +189,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP): class TIME(sqltypes.TIME): + """PostgreSQL TIME type.""" __visit_name__ = "TIME" @@ -214,6 +210,7 @@ class TIME(sqltypes.TIME): class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval): + """PostgreSQL INTERVAL type.""" __visit_name__ = "INTERVAL" @@ -283,6 +280,7 @@ PGBit = BIT class TSVECTOR(sqltypes.TypeEngine[str]): + """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL text search type TSVECTOR. @@ -299,6 +297,7 @@ class TSVECTOR(sqltypes.TypeEngine[str]): class CITEXT(sqltypes.TEXT): + """Provide the PostgreSQL CITEXT type. .. versionadded:: 2.0.7 diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/__init__.py index 7b381fa..56bca47 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/__init__.py @@ -1,5 +1,5 @@ -# dialects/sqlite/__init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sqlite/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/aiosqlite.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/aiosqlite.py index 3f39d4d..d9438d1 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/aiosqlite.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/aiosqlite.py @@ -1,9 +1,10 @@ -# dialects/sqlite/aiosqlite.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sqlite/aiosqlite.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors r""" @@ -30,7 +31,6 @@ This dialect should normally be used only with the :func:`_asyncio.create_async_engine` engine creation function:: from sqlalchemy.ext.asyncio import create_async_engine - engine = create_async_engine("sqlite+aiosqlite:///filename") The URL passes through all arguments to the ``pysqlite`` driver, so all @@ -49,71 +49,45 @@ in Python and use them directly in SQLite queries as described here: :ref:`pysql Serializable isolation / Savepoints / Transactional DDL (asyncio version) ------------------------------------------------------------------------- -A newly revised version of this important section is now available -at the top level of the SQLAlchemy SQLite documentation, in the section -:ref:`sqlite_transactions`. +Similarly to pysqlite, aiosqlite does not support SAVEPOINT feature. +The solution is similar to :ref:`pysqlite_serializable`. This is achieved by the event listeners in async:: -.. 
_aiosqlite_pooling: + from sqlalchemy import create_engine, event + from sqlalchemy.ext.asyncio import create_async_engine -Pooling Behavior ----------------- + engine = create_async_engine("sqlite+aiosqlite:///myfile.db") -The SQLAlchemy ``aiosqlite`` DBAPI establishes the connection pool differently -based on the kind of SQLite database that's requested: + @event.listens_for(engine.sync_engine, "connect") + def do_connect(dbapi_connection, connection_record): + # disable aiosqlite's emitting of the BEGIN statement entirely. + # also stops it from emitting COMMIT before any DDL. + dbapi_connection.isolation_level = None -* When a ``:memory:`` SQLite database is specified, the dialect by default - will use :class:`.StaticPool`. This pool maintains a single - connection, so that all access to the engine - use the same ``:memory:`` database. -* When a file-based database is specified, the dialect will use - :class:`.AsyncAdaptedQueuePool` as the source of connections. + @event.listens_for(engine.sync_engine, "begin") + def do_begin(conn): + # emit our own BEGIN + conn.exec_driver_sql("BEGIN") - .. versionchanged:: 2.0.38 - - SQLite file database engines now use :class:`.AsyncAdaptedQueuePool` by default. - Previously, :class:`.NullPool` were used. The :class:`.NullPool` class - may be used by specifying it via the - :paramref:`_sa.create_engine.poolclass` parameter. +.. warning:: When using the above recipe, it is advised to not use the + :paramref:`.Connection.execution_options.isolation_level` setting on + :class:`_engine.Connection` and :func:`_sa.create_engine` + with the SQLite driver, + as this function necessarily will also alter the ".isolation_level" setting. """ # noqa -from __future__ import annotations import asyncio -from collections import deque from functools import partial -from types import ModuleType -from typing import Any -from typing import cast -from typing import Deque -from typing import Iterator -from typing import NoReturn -from typing import Optional -from typing import Sequence -from typing import TYPE_CHECKING -from typing import Union from .base import SQLiteExecutionContext from .pysqlite import SQLiteDialect_pysqlite from ... import pool from ... 
import util -from ...connectors.asyncio import AsyncAdapt_dbapi_module from ...engine import AdaptedConnection from ...util.concurrency import await_fallback from ...util.concurrency import await_only -if TYPE_CHECKING: - from ...connectors.asyncio import AsyncIODBAPIConnection - from ...connectors.asyncio import AsyncIODBAPICursor - from ...engine.interfaces import _DBAPICursorDescription - from ...engine.interfaces import _DBAPIMultiExecuteParams - from ...engine.interfaces import _DBAPISingleExecuteParams - from ...engine.interfaces import DBAPIConnection - from ...engine.interfaces import DBAPICursor - from ...engine.interfaces import DBAPIModule - from ...engine.url import URL - from ...pool.base import PoolProxiedConnection - class AsyncAdapt_aiosqlite_cursor: # TODO: base on connectors/asyncio.py @@ -132,26 +106,21 @@ class AsyncAdapt_aiosqlite_cursor: server_side = False - def __init__(self, adapt_connection: AsyncAdapt_aiosqlite_connection): + def __init__(self, adapt_connection): self._adapt_connection = adapt_connection self._connection = adapt_connection._connection self.await_ = adapt_connection.await_ self.arraysize = 1 self.rowcount = -1 - self.description: Optional[_DBAPICursorDescription] = None - self._rows: Deque[Any] = deque() + self.description = None + self._rows = [] - def close(self) -> None: - self._rows.clear() - - def execute( - self, - operation: Any, - parameters: Optional[_DBAPISingleExecuteParams] = None, - ) -> Any: + def close(self): + self._rows[:] = [] + def execute(self, operation, parameters=None): try: - _cursor: AsyncIODBAPICursor = self.await_(self._connection.cursor()) # type: ignore[arg-type] # noqa: E501 + _cursor = self.await_(self._connection.cursor()) if parameters is None: self.await_(_cursor.execute(operation)) @@ -163,7 +132,7 @@ class AsyncAdapt_aiosqlite_cursor: self.lastrowid = self.rowcount = -1 if not self.server_side: - self._rows = deque(self.await_(_cursor.fetchall())) + self._rows = self.await_(_cursor.fetchall()) else: self.description = None self.lastrowid = _cursor.lastrowid @@ -172,17 +141,13 @@ class AsyncAdapt_aiosqlite_cursor: if not self.server_side: self.await_(_cursor.close()) else: - self._cursor = _cursor # type: ignore[misc] + self._cursor = _cursor except Exception as error: self._adapt_connection._handle_exception(error) - def executemany( - self, - operation: Any, - seq_of_parameters: _DBAPIMultiExecuteParams, - ) -> Any: + def executemany(self, operation, seq_of_parameters): try: - _cursor: AsyncIODBAPICursor = self.await_(self._connection.cursor()) # type: ignore[arg-type] # noqa: E501 + _cursor = self.await_(self._connection.cursor()) self.await_(_cursor.executemany(operation, seq_of_parameters)) self.description = None self.lastrowid = _cursor.lastrowid @@ -191,29 +156,30 @@ class AsyncAdapt_aiosqlite_cursor: except Exception as error: self._adapt_connection._handle_exception(error) - def setinputsizes(self, *inputsizes: Any) -> None: + def setinputsizes(self, *inputsizes): pass - def __iter__(self) -> Iterator[Any]: + def __iter__(self): while self._rows: - yield self._rows.popleft() + yield self._rows.pop(0) - def fetchone(self) -> Optional[Any]: + def fetchone(self): if self._rows: - return self._rows.popleft() + return self._rows.pop(0) else: return None - def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]: + def fetchmany(self, size=None): if size is None: size = self.arraysize - rr = self._rows - return [rr.popleft() for _ in range(min(size, len(rr)))] + retval = self._rows[0:size] + 
self._rows[:] = self._rows[size:] + return retval - def fetchall(self) -> Sequence[Any]: - retval = list(self._rows) - self._rows.clear() + def fetchall(self): + retval = self._rows[:] + self._rows[:] = [] return retval @@ -224,27 +190,24 @@ class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor): server_side = True - def __init__(self, *arg: Any, **kw: Any) -> None: + def __init__(self, *arg, **kw): super().__init__(*arg, **kw) - self._cursor: Optional[AsyncIODBAPICursor] = None + self._cursor = None - def close(self) -> None: + def close(self): if self._cursor is not None: self.await_(self._cursor.close()) self._cursor = None - def fetchone(self) -> Optional[Any]: - assert self._cursor is not None + def fetchone(self): return self.await_(self._cursor.fetchone()) - def fetchmany(self, size: Optional[int] = None) -> Sequence[Any]: - assert self._cursor is not None + def fetchmany(self, size=None): if size is None: size = self.arraysize return self.await_(self._cursor.fetchmany(size=size)) - def fetchall(self) -> Sequence[Any]: - assert self._cursor is not None + def fetchall(self): return self.await_(self._cursor.fetchall()) @@ -252,24 +215,22 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection): await_ = staticmethod(await_only) __slots__ = ("dbapi",) - def __init__(self, dbapi: Any, connection: AsyncIODBAPIConnection) -> None: + def __init__(self, dbapi, connection): self.dbapi = dbapi self._connection = connection @property - def isolation_level(self) -> Optional[str]: - return cast(str, self._connection.isolation_level) + def isolation_level(self): + return self._connection.isolation_level @isolation_level.setter - def isolation_level(self, value: Optional[str]) -> None: + def isolation_level(self, value): # aiosqlite's isolation_level setter works outside the Thread # that it's supposed to, necessitating setting check_same_thread=False. # for improved stability, we instead invent our own awaitable version # using aiosqlite's async queue directly. 
- def set_iso( - connection: AsyncAdapt_aiosqlite_connection, value: Optional[str] - ) -> None: + def set_iso(connection, value): connection.isolation_level = value function = partial(set_iso, self._connection._conn, value) @@ -278,38 +239,38 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection): self._connection._tx.put_nowait((future, function)) try: - self.await_(future) + return self.await_(future) except Exception as error: self._handle_exception(error) - def create_function(self, *args: Any, **kw: Any) -> None: + def create_function(self, *args, **kw): try: self.await_(self._connection.create_function(*args, **kw)) except Exception as error: self._handle_exception(error) - def cursor(self, server_side: bool = False) -> AsyncAdapt_aiosqlite_cursor: + def cursor(self, server_side=False): if server_side: return AsyncAdapt_aiosqlite_ss_cursor(self) else: return AsyncAdapt_aiosqlite_cursor(self) - def execute(self, *args: Any, **kw: Any) -> Any: + def execute(self, *args, **kw): return self.await_(self._connection.execute(*args, **kw)) - def rollback(self) -> None: + def rollback(self): try: self.await_(self._connection.rollback()) except Exception as error: self._handle_exception(error) - def commit(self) -> None: + def commit(self): try: self.await_(self._connection.commit()) except Exception as error: self._handle_exception(error) - def close(self) -> None: + def close(self): try: self.await_(self._connection.close()) except ValueError: @@ -325,7 +286,7 @@ class AsyncAdapt_aiosqlite_connection(AdaptedConnection): except Exception as error: self._handle_exception(error) - def _handle_exception(self, error: Exception) -> NoReturn: + def _handle_exception(self, error): if ( isinstance(error, ValueError) and error.args[0] == "no active connection" @@ -343,14 +304,14 @@ class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection): await_ = staticmethod(await_fallback) -class AsyncAdapt_aiosqlite_dbapi(AsyncAdapt_dbapi_module): - def __init__(self, aiosqlite: ModuleType, sqlite: ModuleType): +class AsyncAdapt_aiosqlite_dbapi: + def __init__(self, aiosqlite, sqlite): self.aiosqlite = aiosqlite self.sqlite = sqlite self.paramstyle = "qmark" self._init_dbapi_attributes() - def _init_dbapi_attributes(self) -> None: + def _init_dbapi_attributes(self): for name in ( "DatabaseError", "Error", @@ -369,7 +330,7 @@ class AsyncAdapt_aiosqlite_dbapi(AsyncAdapt_dbapi_module): for name in ("Binary",): setattr(self, name, getattr(self.sqlite, name)) - def connect(self, *arg: Any, **kw: Any) -> AsyncAdapt_aiosqlite_connection: + def connect(self, *arg, **kw): async_fallback = kw.pop("async_fallback", False) creator_fn = kw.pop("async_creator_fn", None) @@ -393,7 +354,7 @@ class AsyncAdapt_aiosqlite_dbapi(AsyncAdapt_dbapi_module): class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext): - def create_server_side_cursor(self) -> DBAPICursor: + def create_server_side_cursor(self): return self._dbapi_connection.cursor(server_side=True) @@ -408,25 +369,19 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite): execution_ctx_cls = SQLiteExecutionContext_aiosqlite @classmethod - def import_dbapi(cls) -> AsyncAdapt_aiosqlite_dbapi: + def import_dbapi(cls): return AsyncAdapt_aiosqlite_dbapi( __import__("aiosqlite"), __import__("sqlite3") ) @classmethod - def get_pool_class(cls, url: URL) -> type[pool.Pool]: + def get_pool_class(cls, url): if cls._is_url_file_db(url): - return pool.AsyncAdaptedQueuePool + return pool.NullPool else: return pool.StaticPool - def is_disconnect( - self, - 
e: DBAPIModule.Error, - connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], - cursor: Optional[DBAPICursor], - ) -> bool: - self.dbapi = cast("DBAPIModule", self.dbapi) + def is_disconnect(self, e, connection, cursor): if isinstance( e, self.dbapi.OperationalError ) and "no active connection" in str(e): @@ -434,10 +389,8 @@ class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite): return super().is_disconnect(e, connection, cursor) - def get_driver_connection( - self, connection: DBAPIConnection - ) -> AsyncIODBAPIConnection: - return connection._connection # type: ignore[no-any-return] + def get_driver_connection(self, connection): + return connection._connection dialect = SQLiteDialect_aiosqlite diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/base.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/base.py index 5dbac00..d4eb3bc 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/base.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/base.py @@ -1,5 +1,5 @@ -# dialects/sqlite/base.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sqlite/base.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -7,9 +7,10 @@ # mypy: ignore-errors -r''' +r""" .. dialect:: sqlite :name: SQLite + :full_support: 3.36.0 :normal_support: 3.12+ :best_effort: 3.7.16+ @@ -69,12 +70,9 @@ To specifically render the AUTOINCREMENT keyword on the primary key column when rendering DDL, add the flag ``sqlite_autoincrement=True`` to the Table construct:: - Table( - "sometable", - metadata, - Column("id", Integer, primary_key=True), - sqlite_autoincrement=True, - ) + Table('sometable', metadata, + Column('id', Integer, primary_key=True), + sqlite_autoincrement=True) Allowing autoincrement behavior SQLAlchemy types other than Integer/INTEGER ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -94,13 +92,8 @@ One approach to achieve this is to use :class:`.Integer` on SQLite only using :meth:`.TypeEngine.with_variant`:: table = Table( - "my_table", - metadata, - Column( - "id", - BigInteger().with_variant(Integer, "sqlite"), - primary_key=True, - ), + "my_table", metadata, + Column("id", BigInteger().with_variant(Integer, "sqlite"), primary_key=True) ) Another is to use a subclass of :class:`.BigInteger` that overrides its DDL @@ -109,23 +102,21 @@ name to be ``INTEGER`` when compiled against SQLite:: from sqlalchemy import BigInteger from sqlalchemy.ext.compiler import compiles - class SLBigInteger(BigInteger): pass - - @compiles(SLBigInteger, "sqlite") + @compiles(SLBigInteger, 'sqlite') def bi_c(element, compiler, **kw): return "INTEGER" - @compiles(SLBigInteger) def bi_c(element, compiler, **kw): return compiler.visit_BIGINT(element, **kw) table = Table( - "my_table", metadata, Column("id", SLBigInteger(), primary_key=True) + "my_table", metadata, + Column("id", SLBigInteger(), primary_key=True) ) .. seealso:: @@ -136,199 +127,99 @@ name to be ``INTEGER`` when compiled against SQLite:: `Datatypes In SQLite Version 3 `_ -.. _sqlite_transactions: +.. _sqlite_concurrency: -Transactions with SQLite and the sqlite3 driver ------------------------------------------------ +Database Locking Behavior / Concurrency +--------------------------------------- -As a file-based database, SQLite's approach to transactions differs from -traditional databases in many ways. 
Additionally, the ``sqlite3`` driver -standard with Python (as well as the async version ``aiosqlite`` which builds -on top of it) has several quirks, workarounds, and API features in the -area of transaction control, all of which generally need to be addressed when -constructing a SQLAlchemy application that uses SQLite. +SQLite is not designed for a high level of write concurrency. The database +itself, being a file, is locked completely during write operations within +transactions, meaning exactly one "connection" (in reality a file handle) +has exclusive access to the database during this period - all other +"connections" will be blocked during this time. -Legacy Transaction Mode with the sqlite3 driver -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +The Python DBAPI specification also calls for a connection model that is +always in a transaction; there is no ``connection.begin()`` method, +only ``connection.commit()`` and ``connection.rollback()``, upon which a +new transaction is to be begun immediately. This may seem to imply +that the SQLite driver would in theory allow only a single filehandle on a +particular database file at any time; however, there are several +factors both within SQLite itself as well as within the pysqlite driver +which loosen this restriction significantly. -The most important aspect of transaction handling with the sqlite3 driver is -that it defaults (which will continue through Python 3.15 before being -removed in Python 3.16) to legacy transactional behavior which does -not strictly follow :pep:`249`. The way in which the driver diverges from the -PEP is that it does not "begin" a transaction automatically as dictated by -:pep:`249` except in the case of DML statements, e.g. INSERT, UPDATE, and -DELETE. Normally, :pep:`249` dictates that a BEGIN must be emitted upon -the first SQL statement of any kind, so that all subsequent operations will -be established within a transaction until ``connection.commit()`` has been -called. The ``sqlite3`` driver, in an effort to be easier to use in -highly concurrent environments, skips this step for DQL (e.g. SELECT) statements, -and also skips it for DDL (e.g. CREATE TABLE etc.) statements for more legacy -reasons. Statements such as SAVEPOINT are also skipped. +However, no matter what locking modes are used, SQLite will still always +lock the database file once a transaction is started and DML (e.g. INSERT, +UPDATE, DELETE) has at least been emitted, and this will block +other transactions at least at the point that they also attempt to emit DML. +By default, the length of time on this block is very short before it times out +with an error. -In modern versions of the ``sqlite3`` driver as of Python 3.12, this legacy -mode of operation is referred to as -`"legacy transaction control" `_, and is in -effect by default due to the ``Connection.autocommit`` parameter being set to -the constant ``sqlite3.LEGACY_TRANSACTION_CONTROL``. Prior to Python 3.12, -the ``Connection.autocommit`` attribute did not exist. +This behavior becomes more critical when used in conjunction with the +SQLAlchemy ORM. SQLAlchemy's :class:`.Session` object by default runs +within a transaction, and with its autoflush model, may emit DML preceding +any SELECT statement. This may lead to a SQLite database that locks +more quickly than is expected. 
The locking mode of SQLite and the pysqlite +driver can be manipulated to some degree, however it should be noted that +achieving a high degree of write-concurrency with SQLite is a losing battle. -The implications of legacy transaction mode include: +For more information on SQLite's lack of write concurrency by design, please +see +`Situations Where Another RDBMS May Work Better - High Concurrency +`_ near the bottom of the page. -* **Incorrect support for transactional DDL** - statements like CREATE TABLE, ALTER TABLE, - CREATE INDEX etc. will not automatically BEGIN a transaction if one were not - started already, leading to the changes by each statement being - "autocommitted" immediately unless BEGIN were otherwise emitted first. Very - old (pre Python 3.6) versions of SQLite would also force a COMMIT for these - operations even if a transaction were present, however this is no longer the - case. -* **SERIALIZABLE behavior not fully functional** - SQLite's transaction isolation - behavior is normally consistent with SERIALIZABLE isolation, as it is a file- - based system that locks the database file entirely for write operations, - preventing COMMIT until all reader transactions (and associated file locks) - have completed. However, sqlite3's legacy transaction mode fails to emit BEGIN for SELECT - statements, which causes these SELECT statements to no longer be "repeatable", - failing one of the consistency guarantees of SERIALIZABLE. -* **Incorrect behavior for SAVEPOINT** - as the SAVEPOINT statement does not - imply a BEGIN, a new SAVEPOINT emitted before a BEGIN will function on its - own but fails to participate in the enclosing transaction, meaning a ROLLBACK - of the transaction will not rollback elements that were part of a released - savepoint. - -Legacy transaction mode first existed in order to faciliate working around -SQLite's file locks. Because SQLite relies upon whole-file locks, it is easy to -get "database is locked" errors, particularly when newer features like "write -ahead logging" are disabled. This is a key reason why ``sqlite3``'s legacy -transaction mode is still the default mode of operation; disabling it will -produce behavior that is more susceptible to locked database errors. However -note that **legacy transaction mode will no longer be the default** in a future -Python version (3.16 as of this writing). - -.. _sqlite_enabling_transactions: - -Enabling Non-Legacy SQLite Transactional Modes with the sqlite3 or aiosqlite driver -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Current SQLAlchemy support allows either for setting the -``.Connection.autocommit`` attribute, most directly by using a -:func:`._sa.create_engine` parameter, or if on an older version of Python where -the attribute is not available, using event hooks to control the behavior of -BEGIN. 
- -* **Enabling modern sqlite3 transaction control via the autocommit connect parameter** (Python 3.12 and above) - - To use SQLite in the mode described at `Transaction control via the autocommit attribute `_, - the most straightforward approach is to set the attribute to its recommended value - of ``False`` at the connect level using :paramref:`_sa.create_engine.connect_args``:: - - from sqlalchemy import create_engine - - engine = create_engine( - "sqlite:///myfile.db", connect_args={"autocommit": False} - ) - - This parameter is also passed through when using the aiosqlite driver:: - - from sqlalchemy.ext.asyncio import create_async_engine - - engine = create_async_engine( - "sqlite+aiosqlite:///myfile.db", connect_args={"autocommit": False} - ) - - The parameter can also be set at the attribute level using the :meth:`.PoolEvents.connect` - event hook, however this will only work for sqlite3, as aiosqlite does not yet expose this - attribute on its ``Connection`` object:: - - from sqlalchemy import create_engine, event - - engine = create_engine("sqlite:///myfile.db") - - - @event.listens_for(engine, "connect") - def do_connect(dbapi_connection, connection_record): - # enable autocommit=False mode - dbapi_connection.autocommit = False - -* **Using SQLAlchemy to emit BEGIN in lieu of SQLite's transaction control** (all Python versions, sqlite3 and aiosqlite) - - For older versions of ``sqlite3`` or for cross-compatiblity with older and - newer versions, SQLAlchemy can also take over the job of transaction control. - This is achieved by using the :meth:`.ConnectionEvents.begin` hook - to emit the "BEGIN" command directly, while also disabling SQLite's control - of this command using the :meth:`.PoolEvents.connect` event hook to set the - ``Connection.isolation_level`` attribute to ``None``:: - - - from sqlalchemy import create_engine, event - - engine = create_engine("sqlite:///myfile.db") - - - @event.listens_for(engine, "connect") - def do_connect(dbapi_connection, connection_record): - # disable sqlite3's emitting of the BEGIN statement entirely. - dbapi_connection.isolation_level = None - - - @event.listens_for(engine, "begin") - def do_begin(conn): - # emit our own BEGIN. sqlite3 still emits COMMIT/ROLLBACK correctly - conn.exec_driver_sql("BEGIN") - - When using the asyncio variant ``aiosqlite``, refer to ``engine.sync_engine`` - as in the example below:: - - from sqlalchemy import create_engine, event - from sqlalchemy.ext.asyncio import create_async_engine - - engine = create_async_engine("sqlite+aiosqlite:///myfile.db") - - - @event.listens_for(engine.sync_engine, "connect") - def do_connect(dbapi_connection, connection_record): - # disable aiosqlite's emitting of the BEGIN statement entirely. - dbapi_connection.isolation_level = None - - - @event.listens_for(engine.sync_engine, "begin") - def do_begin(conn): - # emit our own BEGIN. aiosqlite still emits COMMIT/ROLLBACK correctly - conn.exec_driver_sql("BEGIN") +The following subsections introduce areas that are impacted by SQLite's +file-based architecture and additionally will usually require workarounds to +work when using the pysqlite driver. .. 
_sqlite_isolation_level: -Using SQLAlchemy's Driver Level AUTOCOMMIT Feature with SQLite -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Transaction Isolation Level / Autocommit +---------------------------------------- -SQLAlchemy has a comprehensive database isolation feature with optional -autocommit support that is introduced in the section :ref:`dbapi_autocommit`. +SQLite supports "transaction isolation" in a non-standard way, along two +axes. One is that of the +`PRAGMA read_uncommitted `_ +instruction. This setting can essentially switch SQLite between its +default mode of ``SERIALIZABLE`` isolation, and a "dirty read" isolation +mode normally referred to as ``READ UNCOMMITTED``. -For the ``sqlite3`` and ``aiosqlite`` drivers, SQLAlchemy only includes -built-in support for "AUTOCOMMIT". Note that this mode is currently incompatible -with the non-legacy isolation mode hooks documented in the previous -section at :ref:`sqlite_enabling_transactions`. +SQLAlchemy ties into this PRAGMA statement using the +:paramref:`_sa.create_engine.isolation_level` parameter of +:func:`_sa.create_engine`. +Valid values for this parameter when used with SQLite are ``"SERIALIZABLE"`` +and ``"READ UNCOMMITTED"`` corresponding to a value of 0 and 1, respectively. +SQLite defaults to ``SERIALIZABLE``, however its behavior is impacted by +the pysqlite driver's default behavior. -To use the ``sqlite3`` driver with SQLAlchemy driver-level autocommit, -create an engine setting the :paramref:`_sa.create_engine.isolation_level` -parameter to "AUTOCOMMIT":: +When using the pysqlite driver, the ``"AUTOCOMMIT"`` isolation level is also +available, which will alter the pysqlite connection using the ``.isolation_level`` +attribute on the DBAPI connection and set it to None for the duration +of the setting. - eng = create_engine("sqlite:///myfile.db", isolation_level="AUTOCOMMIT") +.. versionadded:: 1.3.16 added support for SQLite AUTOCOMMIT isolation level + when using the pysqlite / sqlite3 SQLite driver. -When using the above mode, any event hooks that set the sqlite3 ``Connection.autocommit`` -parameter away from its default of ``sqlite3.LEGACY_TRANSACTION_CONTROL`` -as well as hooks that emit ``BEGIN`` should be disabled. -Additional Reading for SQLite / sqlite3 transaction control -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +The other axis along which SQLite's transactional locking is impacted is +via the nature of the ``BEGIN`` statement used. The three varieties +are "deferred", "immediate", and "exclusive", as described at +`BEGIN TRANSACTION `_. A straight +``BEGIN`` statement uses the "deferred" mode, where the database file is +not locked until the first read or write operation, and read access remains +open to other transactions until the first write operation. But again, +it is critical to note that the pysqlite driver interferes with this behavior +by *not even emitting BEGIN* until the first write operation. -Links with important information on SQLite, the sqlite3 driver, -as well as long historical conversations on how things got to their current state: +.. warning:: -* `Isolation in SQLite `_ - on the SQLite website -* `Transaction control `_ - describes the sqlite3 autocommit attribute as well - as the legacy isolation_level attribute. 
-* `sqlite3 SELECT does not BEGIN a transaction, but should according to spec `_ - imported Python standard library issue on github -* `sqlite3 module breaks transactions and potentially corrupts data `_ - imported Python standard library issue on github + SQLite's transactional scope is impacted by unresolved + issues in the pysqlite driver, which defers BEGIN statements to a greater + degree than is often feasible. See the section :ref:`pysqlite_serializable` + or :ref:`aiosqlite_serializable` for techniques to work around this behavior. +.. seealso:: + + :ref:`dbapi_autocommit` INSERT/UPDATE/DELETE...RETURNING --------------------------------- @@ -345,29 +236,63 @@ To specify an explicit ``RETURNING`` clause, use the # INSERT..RETURNING result = connection.execute( - table.insert().values(name="foo").returning(table.c.col1, table.c.col2) + table.insert(). + values(name='foo'). + returning(table.c.col1, table.c.col2) ) print(result.all()) # UPDATE..RETURNING result = connection.execute( - table.update() - .where(table.c.name == "foo") - .values(name="bar") - .returning(table.c.col1, table.c.col2) + table.update(). + where(table.c.name=='foo'). + values(name='bar'). + returning(table.c.col1, table.c.col2) ) print(result.all()) # DELETE..RETURNING result = connection.execute( - table.delete() - .where(table.c.name == "foo") - .returning(table.c.col1, table.c.col2) + table.delete(). + where(table.c.name=='foo'). + returning(table.c.col1, table.c.col2) ) print(result.all()) .. versionadded:: 2.0 Added support for SQLite RETURNING +SAVEPOINT Support +---------------------------- + +SQLite supports SAVEPOINTs, which only function once a transaction is +begun. SQLAlchemy's SAVEPOINT support is available using the +:meth:`_engine.Connection.begin_nested` method at the Core level, and +:meth:`.Session.begin_nested` at the ORM level. However, SAVEPOINTs +won't work at all with pysqlite unless workarounds are taken. + +.. warning:: + + SQLite's SAVEPOINT feature is impacted by unresolved + issues in the pysqlite and aiosqlite drivers, which defer BEGIN statements + to a greater degree than is often feasible. See the sections + :ref:`pysqlite_serializable` and :ref:`aiosqlite_serializable` + for techniques to work around this behavior. + +Transactional DDL +---------------------------- + +The SQLite database supports transactional :term:`DDL` as well. +In this case, the pysqlite driver is not only failing to start transactions, +it also is ending any existing transaction when DDL is detected, so again, +workarounds are required. + +.. warning:: + + SQLite's transactional DDL is impacted by unresolved issues + in the pysqlite driver, which fails to emit BEGIN and additionally + forces a COMMIT to cancel any transaction when DDL is encountered. + See the section :ref:`pysqlite_serializable` + for techniques to work around this behavior. .. _sqlite_foreign_keys: @@ -393,21 +318,12 @@ new connections through the usage of events:: from sqlalchemy.engine import Engine from sqlalchemy import event - @event.listens_for(Engine, "connect") def set_sqlite_pragma(dbapi_connection, connection_record): - # the sqlite3 driver will not set PRAGMA foreign_keys - # if autocommit=False; set to True temporarily - ac = dbapi_connection.autocommit - dbapi_connection.autocommit = True - cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() - # restore previous autocommit setting - dbapi_connection.autocommit = ac - .. 
warning:: When SQLite foreign keys are enabled, it is **not possible** @@ -464,16 +380,13 @@ ABORT, FAIL, IGNORE, and REPLACE. For example, to add a UNIQUE constraint that specifies the IGNORE algorithm:: some_table = Table( - "some_table", - metadata, - Column("id", Integer, primary_key=True), - Column("data", Integer), - UniqueConstraint("id", "data", sqlite_on_conflict="IGNORE"), + 'some_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', Integer), + UniqueConstraint('id', 'data', sqlite_on_conflict='IGNORE') ) -The above renders CREATE TABLE DDL as: - -.. sourcecode:: sql +The above renders CREATE TABLE DDL as:: CREATE TABLE some_table ( id INTEGER NOT NULL, @@ -490,17 +403,13 @@ be added to the :class:`_schema.Column` as well, which will be added to the UNIQUE constraint in the DDL:: some_table = Table( - "some_table", - metadata, - Column("id", Integer, primary_key=True), - Column( - "data", Integer, unique=True, sqlite_on_conflict_unique="IGNORE" - ), + 'some_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', Integer, unique=True, + sqlite_on_conflict_unique='IGNORE') ) -rendering: - -.. sourcecode:: sql +rendering:: CREATE TABLE some_table ( id INTEGER NOT NULL, @@ -513,17 +422,13 @@ To apply the FAIL algorithm for a NOT NULL constraint, ``sqlite_on_conflict_not_null`` is used:: some_table = Table( - "some_table", - metadata, - Column("id", Integer, primary_key=True), - Column( - "data", Integer, nullable=False, sqlite_on_conflict_not_null="FAIL" - ), + 'some_table', metadata, + Column('id', Integer, primary_key=True), + Column('data', Integer, nullable=False, + sqlite_on_conflict_not_null='FAIL') ) -this renders the column inline ON CONFLICT phrase: - -.. sourcecode:: sql +this renders the column inline ON CONFLICT phrase:: CREATE TABLE some_table ( id INTEGER NOT NULL, @@ -535,20 +440,13 @@ this renders the column inline ON CONFLICT phrase: Similarly, for an inline primary key, use ``sqlite_on_conflict_primary_key``:: some_table = Table( - "some_table", - metadata, - Column( - "id", - Integer, - primary_key=True, - sqlite_on_conflict_primary_key="FAIL", - ), + 'some_table', metadata, + Column('id', Integer, primary_key=True, + sqlite_on_conflict_primary_key='FAIL') ) SQLAlchemy renders the PRIMARY KEY constraint separately, so the conflict -resolution algorithm is applied to the constraint itself: - -.. sourcecode:: sql +resolution algorithm is applied to the constraint itself:: CREATE TABLE some_table ( id INTEGER NOT NULL, @@ -558,7 +456,7 @@ resolution algorithm is applied to the constraint itself: .. _sqlite_on_conflict_insert: INSERT...ON CONFLICT (Upsert) ------------------------------ +----------------------------------- .. seealso:: This section describes the :term:`DML` version of "ON CONFLICT" for SQLite, which occurs within an INSERT statement. For "ON CONFLICT" as @@ -586,18 +484,21 @@ and :meth:`_sqlite.Insert.on_conflict_do_nothing`: >>> from sqlalchemy.dialects.sqlite import insert >>> insert_stmt = insert(my_table).values( - ... id="some_existing_id", data="inserted value" - ... ) + ... id='some_existing_id', + ... data='inserted value') >>> do_update_stmt = insert_stmt.on_conflict_do_update( - ... index_elements=["id"], set_=dict(data="updated value") + ... index_elements=['id'], + ... set_=dict(data='updated value') ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) 
ON CONFLICT (id) DO UPDATE SET data = ?{stop} - >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["id"]) + >>> do_nothing_stmt = insert_stmt.on_conflict_do_nothing( + ... index_elements=['id'] + ... ) >>> print(do_nothing_stmt) {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) @@ -628,13 +529,13 @@ Both methods supply the "target" of the conflict using column inference: .. sourcecode:: pycon+sql - >>> stmt = insert(my_table).values(user_email="a@b.com", data="inserted data") + >>> stmt = insert(my_table).values(user_email='a@b.com', data='inserted data') >>> do_update_stmt = stmt.on_conflict_do_update( ... index_elements=[my_table.c.user_email], - ... index_where=my_table.c.user_email.like("%@gmail.com"), - ... set_=dict(data=stmt.excluded.data), - ... ) + ... index_where=my_table.c.user_email.like('%@gmail.com'), + ... set_=dict(data=stmt.excluded.data) + ... ) >>> print(do_update_stmt) {printsql}INSERT INTO my_table (data, user_email) VALUES (?, ?) @@ -654,10 +555,11 @@ for UPDATE: .. sourcecode:: pycon+sql - >>> stmt = insert(my_table).values(id="some_id", data="inserted value") + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') >>> do_update_stmt = stmt.on_conflict_do_update( - ... index_elements=["id"], set_=dict(data="updated value") + ... index_elements=['id'], + ... set_=dict(data='updated value') ... ) >>> print(do_update_stmt) @@ -685,12 +587,14 @@ would have been inserted had the constraint not failed: .. sourcecode:: pycon+sql >>> stmt = insert(my_table).values( - ... id="some_id", data="inserted value", author="jlh" + ... id='some_id', + ... data='inserted value', + ... author='jlh' ... ) >>> do_update_stmt = stmt.on_conflict_do_update( - ... index_elements=["id"], - ... set_=dict(data="updated value", author=stmt.excluded.author), + ... index_elements=['id'], + ... set_=dict(data='updated value', author=stmt.excluded.author) ... ) >>> print(do_update_stmt) @@ -707,13 +611,15 @@ parameter, which will limit those rows which receive an UPDATE: .. sourcecode:: pycon+sql >>> stmt = insert(my_table).values( - ... id="some_id", data="inserted value", author="jlh" + ... id='some_id', + ... data='inserted value', + ... author='jlh' ... ) >>> on_update_stmt = stmt.on_conflict_do_update( - ... index_elements=["id"], - ... set_=dict(data="updated value", author=stmt.excluded.author), - ... where=(my_table.c.status == 2), + ... index_elements=['id'], + ... set_=dict(data='updated value', author=stmt.excluded.author), + ... where=(my_table.c.status == 2) ... ) >>> print(on_update_stmt) {printsql}INSERT INTO my_table (id, data, author) VALUES (?, ?, ?) @@ -730,8 +636,8 @@ using the :meth:`_sqlite.Insert.on_conflict_do_nothing` method: .. sourcecode:: pycon+sql - >>> stmt = insert(my_table).values(id="some_id", data="inserted value") - >>> stmt = stmt.on_conflict_do_nothing(index_elements=["id"]) + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') + >>> stmt = stmt.on_conflict_do_nothing(index_elements=['id']) >>> print(stmt) {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) ON CONFLICT (id) DO NOTHING @@ -742,7 +648,7 @@ occurs: .. sourcecode:: pycon+sql - >>> stmt = insert(my_table).values(id="some_id", data="inserted value") + >>> stmt = insert(my_table).values(id='some_id', data='inserted value') >>> stmt = stmt.on_conflict_do_nothing() >>> print(stmt) {printsql}INSERT INTO my_table (id, data) VALUES (?, ?) ON CONFLICT DO NOTHING @@ -802,16 +708,11 @@ Partial Indexes A partial index, e.g. 
one which uses a WHERE clause, can be specified with the DDL system using the argument ``sqlite_where``:: - tbl = Table("testtbl", m, Column("data", Integer)) - idx = Index( - "test_idx1", - tbl.c.data, - sqlite_where=and_(tbl.c.data > 5, tbl.c.data < 10), - ) + tbl = Table('testtbl', m, Column('data', Integer)) + idx = Index('test_idx1', tbl.c.data, + sqlite_where=and_(tbl.c.data > 5, tbl.c.data < 10)) -The index will be rendered at create time as: - -.. sourcecode:: sql +The index will be rendered at create time as:: CREATE INDEX test_idx1 ON testtbl (data) WHERE data > 5 AND data < 10 @@ -831,11 +732,7 @@ The bug, entirely outside of SQLAlchemy, can be illustrated thusly:: import sqlite3 - assert sqlite3.sqlite_version_info < ( - 3, - 10, - 0, - ), "bug is fixed in this version" + assert sqlite3.sqlite_version_info < (3, 10, 0), "bug is fixed in this version" conn = sqlite3.connect(":memory:") cursor = conn.cursor() @@ -845,22 +742,17 @@ The bug, entirely outside of SQLAlchemy, can be illustrated thusly:: cursor.execute("insert into x (a, b) values (2, 2)") cursor.execute("select x.a, x.b from x") - assert [c[0] for c in cursor.description] == ["a", "b"] + assert [c[0] for c in cursor.description] == ['a', 'b'] - cursor.execute( - """ + cursor.execute(''' select x.a, x.b from x where a=1 union select x.a, x.b from x where a=2 - """ - ) - assert [c[0] for c in cursor.description] == ["a", "b"], [ - c[0] for c in cursor.description - ] + ''') + assert [c[0] for c in cursor.description] == ['a', 'b'], \ + [c[0] for c in cursor.description] -The second assertion fails: - -.. sourcecode:: text +The second assertion fails:: Traceback (most recent call last): File "test.py", line 19, in @@ -888,13 +780,11 @@ to filter these out:: result = conn.exec_driver_sql("select x.a, x.b from x") assert result.keys() == ["a", "b"] - result = conn.exec_driver_sql( - """ + result = conn.exec_driver_sql(''' select x.a, x.b from x where a=1 union select x.a, x.b from x where a=2 - """ - ) + ''') assert result.keys() == ["a", "b"] Note that above, even though SQLAlchemy filters out the dots, *both @@ -918,20 +808,16 @@ contain dots, and the functionality of :meth:`_engine.CursorResult.keys` and the ``sqlite_raw_colnames`` execution option may be provided, either on a per-:class:`_engine.Connection` basis:: - result = conn.execution_options(sqlite_raw_colnames=True).exec_driver_sql( - """ + result = conn.execution_options(sqlite_raw_colnames=True).exec_driver_sql(''' select x.a, x.b from x where a=1 union select x.a, x.b from x where a=2 - """ - ) + ''') assert result.keys() == ["x.a", "x.b"] or on a per-:class:`_engine.Engine` basis:: - engine = create_engine( - "sqlite://", execution_options={"sqlite_raw_colnames": True} - ) + engine = create_engine("sqlite://", execution_options={"sqlite_raw_colnames": True}) When using the per-:class:`_engine.Engine` execution option, note that **Core and ORM queries that use UNION may not function properly**. @@ -946,18 +832,12 @@ dialect in conjunction with the :class:`_schema.Table` construct: Table("some_table", metadata, ..., sqlite_with_rowid=False) -* - ``STRICT``:: - - Table("some_table", metadata, ..., sqlite_strict=True) - - .. versionadded:: 2.0.37 - .. seealso:: `SQLite CREATE TABLE options `_ + .. _sqlite_include_internal: Reflecting internal schema tables @@ -986,7 +866,7 @@ passed to methods such as :meth:`_schema.MetaData.reflect` or `SQLite Internal Schema Objects `_ - in the SQLite documentation. 
-''' # noqa +""" # noqa from __future__ import annotations import datetime @@ -1008,6 +888,7 @@ from ...engine import processors from ...engine import reflection from ...engine.reflection import ReflectionDefaults from ...sql import coercions +from ...sql import ColumnElement from ...sql import compiler from ...sql import elements from ...sql import roles @@ -1099,9 +980,7 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): "%(year)04d-%(month)02d-%(day)02d %(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" - e.g.: - - .. sourcecode:: text + e.g.:: 2021-03-15 12:05:57.105542 @@ -1117,17 +996,11 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): import re from sqlalchemy.dialects.sqlite import DATETIME - dt = DATETIME( - storage_format=( - "%(year)04d/%(month)02d/%(day)02d %(hour)02d:%(minute)02d:%(second)02d" - ), - regexp=r"(\d+)/(\d+)/(\d+) (\d+)-(\d+)-(\d+)", + dt = DATETIME(storage_format="%(year)04d/%(month)02d/%(day)02d " + "%(hour)02d:%(minute)02d:%(second)02d", + regexp=r"(\d+)/(\d+)/(\d+) (\d+)-(\d+)-(\d+)" ) - :param truncate_microseconds: when ``True`` microseconds will be truncated - from the datetime. Can't be specified together with ``storage_format`` - or ``regexp``. - :param storage_format: format string which will be applied to the dict with keys year, month, day, hour, minute, second, and microsecond. @@ -1215,9 +1088,7 @@ class DATE(_DateTimeMixin, sqltypes.Date): "%(year)04d-%(month)02d-%(day)02d" - e.g.: - - .. sourcecode:: text + e.g.:: 2011-03-15 @@ -1235,9 +1106,9 @@ class DATE(_DateTimeMixin, sqltypes.Date): from sqlalchemy.dialects.sqlite import DATE d = DATE( - storage_format="%(month)02d/%(day)02d/%(year)04d", - regexp=re.compile("(?P\d+)/(?P\d+)/(?P\d+)"), - ) + storage_format="%(month)02d/%(day)02d/%(year)04d", + regexp=re.compile("(?P\d+)/(?P\d+)/(?P\d+)") + ) :param storage_format: format string which will be applied to the dict with keys year, month, and day. @@ -1291,9 +1162,7 @@ class TIME(_DateTimeMixin, sqltypes.Time): "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" - e.g.: - - .. sourcecode:: text + e.g.:: 12:05:57.10558 @@ -1309,15 +1178,11 @@ class TIME(_DateTimeMixin, sqltypes.Time): import re from sqlalchemy.dialects.sqlite import TIME - t = TIME( - storage_format="%(hour)02d-%(minute)02d-%(second)02d-%(microsecond)06d", - regexp=re.compile("(\d+)-(\d+)-(\d+)-(?:-(\d+))?"), + t = TIME(storage_format="%(hour)02d-%(minute)02d-" + "%(second)02d-%(microsecond)06d", + regexp=re.compile("(\d+)-(\d+)-(\d+)-(?:-(\d+))?") ) - :param truncate_microseconds: when ``True`` microseconds will be truncated - from the time. Can't be specified together with ``storage_format`` - or ``regexp``. - :param storage_format: format string which will be applied to the dict with keys hour, minute, second, and microsecond. 
@@ -1443,7 +1308,7 @@ class SQLiteCompiler(compiler.SQLCompiler): return "CURRENT_TIMESTAMP" def visit_localtimestamp_func(self, func, **kw): - return "DATETIME(CURRENT_TIMESTAMP, 'localtime')" + return 'DATETIME(CURRENT_TIMESTAMP, "localtime")' def visit_true(self, expr, **kw): return "1" @@ -1564,7 +1429,9 @@ class SQLiteCompiler(compiler.SQLCompiler): return self._generate_generic_binary(binary, " NOT REGEXP ", **kw) def _on_conflict_target(self, clause, **kw): - if clause.inferred_target_elements is not None: + if clause.constraint_target is not None: + target_text = "(%s)" % clause.constraint_target + elif clause.inferred_target_elements is not None: target_text = "(%s)" % ", ".join( ( self.preparer.quote(c) @@ -1578,7 +1445,7 @@ class SQLiteCompiler(compiler.SQLCompiler): clause.inferred_target_whereclause, include_table=False, use_schema=False, - literal_execute=True, + literal_binds=True, ) else: @@ -1661,13 +1528,6 @@ class SQLiteCompiler(compiler.SQLCompiler): return "ON CONFLICT %s DO UPDATE SET %s" % (target_text, action_text) - def visit_bitwise_xor_op_binary(self, binary, operator, **kw): - # sqlite has no xor. Use "a XOR b" = "(a | b) - (a & b)". - kw["eager_grouping"] = True - or_ = self._generate_generic_binary(binary, " | ", **kw) - and_ = self._generate_generic_binary(binary, " & ", **kw) - return f"({or_} - {and_})" - class SQLiteDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): @@ -1677,13 +1537,9 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): colspec = self.preparer.format_column(column) + " " + coltype default = self.get_column_default_string(column) if default is not None: - - if not re.match(r"""^\s*[\'\"\(]""", default) and re.match( - r".*\W.*", default - ): - colspec += f" DEFAULT ({default})" - else: - colspec += f" DEFAULT {default}" + if isinstance(column.server_default.arg, ColumnElement): + default = "(" + default + ")" + colspec += " DEFAULT " + default if not column.nullable: colspec += " NOT NULL" @@ -1845,18 +1701,9 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text def post_create_table(self, table): - table_options = [] - - if not table.dialect_options["sqlite"]["with_rowid"]: - table_options.append("WITHOUT ROWID") - - if table.dialect_options["sqlite"]["strict"]: - table_options.append("STRICT") - - if table_options: - return "\n " + ",\n ".join(table_options) - else: - return "" + if table.dialect_options["sqlite"]["with_rowid"] is False: + return "\n WITHOUT ROWID" + return "" class SQLiteTypeCompiler(compiler.GenericTypeCompiler): @@ -2091,7 +1938,6 @@ class SQLiteDialect(default.DefaultDialect): { "autoincrement": False, "with_rowid": True, - "strict": False, }, ), (sa_schema.Index, {"where": None}), @@ -2184,9 +2030,9 @@ class SQLiteDialect(default.DefaultDialect): ) if self.dbapi.sqlite_version_info < (3, 35) or util.pypy: - self.update_returning = self.delete_returning = ( - self.insert_returning - ) = False + self.update_returning = ( + self.delete_returning + ) = self.insert_returning = False if self.dbapi.sqlite_version_info < (3, 32, 0): # https://www.sqlite.org/limits.html @@ -2385,14 +2231,6 @@ class SQLiteDialect(default.DefaultDialect): tablesql = self._get_table_sql( connection, table_name, schema, **kw ) - # remove create table - match = re.match( - r"create table .*?\((.*)\)$", - tablesql.strip(), - re.DOTALL | re.IGNORECASE, - ) - assert match, f"create table not found in {tablesql}" - tablesql = match.group(1).strip() columns.append( self._get_column_info( @@ -2447,10 +2285,7 
@@ class SQLiteDialect(default.DefaultDialect): if generated: sqltext = "" if tablesql: - pattern = ( - r"[^,]*\s+GENERATED\s+ALWAYS\s+AS" - r"\s+\((.*)\)\s*(?:virtual|stored)?" - ) + pattern = r"[^,]*\s+AS\s+\(([^,]*)\)\s*(?:virtual|stored)?" match = re.search( re.escape(name) + pattern, tablesql, re.IGNORECASE ) @@ -2735,8 +2570,8 @@ class SQLiteDialect(default.DefaultDialect): return UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)' INLINE_UNIQUE_PATTERN = ( - r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?)[\t ]' - r"+[a-z0-9_ ]+?[\t ]+UNIQUE" + r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?) ' + r"+[a-z0-9_ ]+? +UNIQUE" ) for match in re.finditer(UNIQUE_PATTERN, table_data, re.I): @@ -2771,21 +2606,15 @@ class SQLiteDialect(default.DefaultDialect): connection, table_name, schema=schema, **kw ) - # NOTE NOTE NOTE - # DO NOT CHANGE THIS REGULAR EXPRESSION. There is no known way - # to parse CHECK constraints that contain newlines themselves using - # regular expressions, and the approach here relies upon each - # individual - # CHECK constraint being on a single line by itself. This - # necessarily makes assumptions as to how the CREATE TABLE - # was emitted. A more comprehensive DDL parsing solution would be - # needed to improve upon the current situation. See #11840 for - # background - CHECK_PATTERN = r"(?:CONSTRAINT (.+) +)?CHECK *\( *(.+) *\),? *" + CHECK_PATTERN = r"(?:CONSTRAINT (.+) +)?" r"CHECK *\( *(.+) *\),? *" cks = [] + # NOTE: we aren't using re.S here because we actually are + # taking advantage of each CHECK constraint being all on one + # line in the table definition in order to delineate. This + # necessarily makes assumptions as to how the CREATE TABLE + # was emitted. for match in re.finditer(CHECK_PATTERN, table_data or "", re.I): - name = match.group(1) if name: diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/dml.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/dml.py index 84cdb8b..ec428f5 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/dml.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/dml.py @@ -1,5 +1,5 @@ -# dialects/sqlite/dml.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sqlite/dml.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -7,10 +7,6 @@ from __future__ import annotations from typing import Any -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union from .._typing import _OnConflictIndexElementsT from .._typing import _OnConflictIndexWhereT @@ -19,7 +15,6 @@ from .._typing import _OnConflictWhereT from ... import util from ...sql import coercions from ...sql import roles -from ...sql import schema from ...sql._typing import _DMLTableArgument from ...sql.base import _exclusive_against from ...sql.base import _generative @@ -27,9 +22,7 @@ from ...sql.base import ColumnCollection from ...sql.base import ReadOnlyColumnCollection from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement -from ...sql.elements import ColumnElement from ...sql.elements import KeyedColumnElement -from ...sql.elements import TextClause from ...sql.expression import alias from ...util.typing import Self @@ -148,10 +141,11 @@ class Insert(StandardInsert): :paramref:`.Insert.on_conflict_do_update.set_` dictionary. :param where: - Optional argument. 
An expression object representing a ``WHERE`` - clause that restricts the rows affected by ``DO UPDATE SET``. Rows not - meeting the ``WHERE`` condition will not be updated (effectively a - ``DO NOTHING`` for those rows). + Optional argument. If present, can be a literal SQL + string or an acceptable expression for a ``WHERE`` clause + that restricts the rows affected by ``DO UPDATE SET``. Rows + not meeting the ``WHERE`` condition will not be updated + (effectively a ``DO NOTHING`` for those rows). """ @@ -190,10 +184,9 @@ class Insert(StandardInsert): class OnConflictClause(ClauseElement): stringify_dialect = "sqlite" - inferred_target_elements: Optional[List[Union[str, schema.Column[Any]]]] - inferred_target_whereclause: Optional[ - Union[ColumnElement[Any], TextClause] - ] + constraint_target: None + inferred_target_elements: _OnConflictIndexElementsT + inferred_target_whereclause: _OnConflictIndexWhereT def __init__( self, @@ -201,22 +194,13 @@ class OnConflictClause(ClauseElement): index_where: _OnConflictIndexWhereT = None, ): if index_elements is not None: - self.inferred_target_elements = [ - coercions.expect(roles.DDLConstraintColumnRole, column) - for column in index_elements - ] - self.inferred_target_whereclause = ( - coercions.expect( - roles.WhereHavingRole, - index_where, - ) - if index_where is not None - else None - ) + self.constraint_target = None + self.inferred_target_elements = index_elements + self.inferred_target_whereclause = index_where else: - self.inferred_target_elements = ( - self.inferred_target_whereclause - ) = None + self.constraint_target = ( + self.inferred_target_elements + ) = self.inferred_target_whereclause = None class OnConflictDoNothing(OnConflictClause): @@ -226,9 +210,6 @@ class OnConflictDoNothing(OnConflictClause): class OnConflictDoUpdate(OnConflictClause): __visit_name__ = "on_conflict_do_update" - update_values_to_set: List[Tuple[Union[schema.Column[Any], str], Any]] - update_whereclause: Optional[ColumnElement[Any]] - def __init__( self, index_elements: _OnConflictIndexElementsT = None, @@ -256,8 +237,4 @@ class OnConflictDoUpdate(OnConflictClause): (coercions.expect(roles.DMLColumnRole, key), value) for key, value in set_.items() ] - self.update_whereclause = ( - coercions.expect(roles.WhereHavingRole, where) - if where is not None - else None - ) + self.update_whereclause = where diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/json.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/json.py index 02f4ea4..69df317 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/json.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/json.py @@ -1,9 +1,3 @@ -# dialects/sqlite/json.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors from ... 
import types as sqltypes diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/provision.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/provision.py index e1df005..2ed8253 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/provision.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/provision.py @@ -1,9 +1,3 @@ -# dialects/sqlite/provision.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors import os @@ -52,6 +46,8 @@ def _format_url(url, driver, ident): assert "test_schema" not in filename tokens = re.split(r"[_\.]", filename) + new_filename = f"{driver}" + for token in tokens: if token in _drivernames: if driver is None: diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/pysqlcipher.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/pysqlcipher.py index 7a3dc1b..28b900e 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/pysqlcipher.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/pysqlcipher.py @@ -1,5 +1,5 @@ -# dialects/sqlite/pysqlcipher.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sqlite/pysqlcipher.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -39,7 +39,7 @@ Current dialect selection logic is: e = create_engine( "sqlite+pysqlcipher://:password@/dbname.db", - module=sqlcipher_compatible_driver, + module=sqlcipher_compatible_driver ) These drivers make use of the SQLCipher engine. This system essentially @@ -55,12 +55,12 @@ The format of the connect string is in every way the same as that of the :mod:`~sqlalchemy.dialects.sqlite.pysqlite` driver, except that the "password" field is now accepted, which should contain a passphrase:: - e = create_engine("sqlite+pysqlcipher://:testing@/foo.db") + e = create_engine('sqlite+pysqlcipher://:testing@/foo.db') For an absolute file path, two leading slashes should be used for the database name:: - e = create_engine("sqlite+pysqlcipher://:testing@//path/to/foo.db") + e = create_engine('sqlite+pysqlcipher://:testing@//path/to/foo.db') A selection of additional encryption-related pragmas supported by SQLCipher as documented at https://www.zetetic.net/sqlcipher/sqlcipher-api/ can be passed @@ -68,9 +68,7 @@ in the query string, and will result in that PRAGMA being called for each new connection. Currently, ``cipher``, ``kdf_iter`` ``cipher_page_size`` and ``cipher_use_hmac`` are supported:: - e = create_engine( - "sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000" - ) + e = create_engine('sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000') .. 
warning:: Previous versions of sqlalchemy did not take into consideration the encryption-related pragmas passed in the url string, that were silently diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/pysqlite.py b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/pysqlite.py index 1f9a55c..3cd6e5f 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/dialects/sqlite/pysqlite.py @@ -1,5 +1,5 @@ -# dialects/sqlite/pysqlite.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sqlite/pysqlite.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -28,9 +28,7 @@ Connect Strings --------------- The file specification for the SQLite database is taken as the "database" -portion of the URL. Note that the format of a SQLAlchemy url is: - -.. sourcecode:: text +portion of the URL. Note that the format of a SQLAlchemy url is:: driver://user:pass@host/database @@ -39,28 +37,25 @@ the **right** of the third slash. So connecting to a relative filepath looks like:: # relative path - e = create_engine("sqlite:///path/to/database.db") + e = create_engine('sqlite:///path/to/database.db') An absolute path, which is denoted by starting with a slash, means you need **four** slashes:: # absolute path - e = create_engine("sqlite:////path/to/database.db") + e = create_engine('sqlite:////path/to/database.db') To use a Windows path, regular drive specifications and backslashes can be used. Double backslashes are probably needed:: # absolute path on Windows - e = create_engine("sqlite:///C:\\path\\to\\database.db") + e = create_engine('sqlite:///C:\\path\\to\\database.db') -To use sqlite ``:memory:`` database specify it as the filename using -``sqlite:///:memory:``. It's also the default if no filepath is -present, specifying only ``sqlite://`` and nothing else:: +The sqlite ``:memory:`` identifier is the default if no filepath is +present. Specify ``sqlite://`` and nothing else:: - # in-memory database (note three slashes) - e = create_engine("sqlite:///:memory:") - # also in-memory database - e2 = create_engine("sqlite://") + # in-memory database + e = create_engine('sqlite://') .. _pysqlite_uri_connections: @@ -100,9 +95,7 @@ Above, the pysqlite / sqlite3 DBAPI would be passed arguments as:: sqlite3.connect( "file:path/to/database?mode=ro&nolock=1", - check_same_thread=True, - timeout=10, - uri=True, + check_same_thread=True, timeout=10, uri=True ) Regarding future parameters added to either the Python or native drivers. 
new @@ -148,11 +141,8 @@ as follows:: def regexp(a, b): return re.search(a, b) is not None - sqlite_connection.create_function( - "regexp", - 2, - regexp, + "regexp", 2, regexp, ) There is currently no support for regular expression flags as a separate @@ -193,12 +183,10 @@ Keeping in mind that pysqlite's parsing option is not recommended, nor should be necessary, for use with SQLAlchemy, usage of PARSE_DECLTYPES can be forced if one configures "native_datetime=True" on create_engine():: - engine = create_engine( - "sqlite://", - connect_args={ - "detect_types": sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES - }, - native_datetime=True, + engine = create_engine('sqlite://', + connect_args={'detect_types': + sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES}, + native_datetime=True ) With this flag enabled, the DATE and TIMESTAMP types (but note - not the @@ -253,7 +241,6 @@ Pooling may be disabled for a file based database by specifying the parameter:: from sqlalchemy import NullPool - engine = create_engine("sqlite:///myfile.db", poolclass=NullPool) It's been observed that the :class:`.NullPool` implementation incurs an @@ -273,12 +260,9 @@ globally, and the ``check_same_thread`` flag can be passed to Pysqlite as ``False``:: from sqlalchemy.pool import StaticPool - - engine = create_engine( - "sqlite://", - connect_args={"check_same_thread": False}, - poolclass=StaticPool, - ) + engine = create_engine('sqlite://', + connect_args={'check_same_thread':False}, + poolclass=StaticPool) Note that using a ``:memory:`` database in multiple threads requires a recent version of SQLite. @@ -297,14 +281,14 @@ needed within multiple threads for this case:: # maintain the same connection per thread from sqlalchemy.pool import SingletonThreadPool - - engine = create_engine("sqlite:///mydb.db", poolclass=SingletonThreadPool) + engine = create_engine('sqlite:///mydb.db', + poolclass=SingletonThreadPool) # maintain the same connection across all threads from sqlalchemy.pool import StaticPool - - engine = create_engine("sqlite:///mydb.db", poolclass=StaticPool) + engine = create_engine('sqlite:///mydb.db', + poolclass=StaticPool) Note that :class:`.SingletonThreadPool` should be configured for the number of threads that are to be used; beyond that number, connections will be @@ -333,14 +317,13 @@ same column, use a custom type that will check each row individually:: from sqlalchemy import String from sqlalchemy import TypeDecorator - class MixedBinary(TypeDecorator): impl = String cache_ok = True def process_result_value(self, value, dialect): if isinstance(value, str): - value = bytes(value, "utf-8") + value = bytes(value, 'utf-8') elif value is not None: value = bytes(value) @@ -354,10 +337,74 @@ Then use the above ``MixedBinary`` datatype in the place where Serializable isolation / Savepoints / Transactional DDL ------------------------------------------------------- -A newly revised version of this important section is now available -at the top level of the SQLAlchemy SQLite documentation, in the section -:ref:`sqlite_transactions`. +In the section :ref:`sqlite_concurrency`, we refer to the pysqlite +driver's assortment of issues that prevent several features of SQLite +from working correctly. The pysqlite DBAPI driver has several +long-standing bugs which impact the correctness of its transactional +behavior. 
In its default mode of operation, SQLite features such as +SERIALIZABLE isolation, transactional DDL, and SAVEPOINT support are +non-functional, and in order to use these features, workarounds must +be taken. +The issue is essentially that the driver attempts to second-guess the user's +intent, failing to start transactions and sometimes ending them prematurely, in +an effort to minimize the SQLite databases's file locking behavior, even +though SQLite itself uses "shared" locks for read-only activities. + +SQLAlchemy chooses to not alter this behavior by default, as it is the +long-expected behavior of the pysqlite driver; if and when the pysqlite +driver attempts to repair these issues, that will be more of a driver towards +defaults for SQLAlchemy. + +The good news is that with a few events, we can implement transactional +support fully, by disabling pysqlite's feature entirely and emitting BEGIN +ourselves. This is achieved using two event listeners:: + + from sqlalchemy import create_engine, event + + engine = create_engine("sqlite:///myfile.db") + + @event.listens_for(engine, "connect") + def do_connect(dbapi_connection, connection_record): + # disable pysqlite's emitting of the BEGIN statement entirely. + # also stops it from emitting COMMIT before any DDL. + dbapi_connection.isolation_level = None + + @event.listens_for(engine, "begin") + def do_begin(conn): + # emit our own BEGIN + conn.exec_driver_sql("BEGIN") + +.. warning:: When using the above recipe, it is advised to not use the + :paramref:`.Connection.execution_options.isolation_level` setting on + :class:`_engine.Connection` and :func:`_sa.create_engine` + with the SQLite driver, + as this function necessarily will also alter the ".isolation_level" setting. + + +Above, we intercept a new pysqlite connection and disable any transactional +integration. Then, at the point at which SQLAlchemy knows that transaction +scope is to begin, we emit ``"BEGIN"`` ourselves. + +When we take control of ``"BEGIN"``, we can also control directly SQLite's +locking modes, introduced at +`BEGIN TRANSACTION `_, +by adding the desired locking mode to our ``"BEGIN"``:: + + @event.listens_for(engine, "begin") + def do_begin(conn): + conn.exec_driver_sql("BEGIN EXCLUSIVE") + +.. seealso:: + + `BEGIN TRANSACTION `_ - + on the SQLite site + + `sqlite3 SELECT does not BEGIN a transaction `_ - + on the Python bug tracker + + `sqlite3 module breaks transactions and potentially corrupts data `_ - + on the Python bug tracker .. _pysqlite_udfs: @@ -392,16 +439,12 @@ connection when it is created. That is accomplished with an event listener:: with engine.connect() as conn: print(conn.scalar(text("SELECT UDF()"))) + """ # noqa -from __future__ import annotations import math import os import re -from typing import cast -from typing import Optional -from typing import TYPE_CHECKING -from typing import Union from .base import DATE from .base import DATETIME @@ -411,13 +454,6 @@ from ... import pool from ... import types as sqltypes from ... 
import util -if TYPE_CHECKING: - from ...engine.interfaces import DBAPIConnection - from ...engine.interfaces import DBAPICursor - from ...engine.interfaces import DBAPIModule - from ...engine.url import URL - from ...pool.base import PoolProxiedConnection - class _SQLite_pysqliteTimeStamp(DATETIME): def bind_processor(self, dialect): @@ -471,7 +507,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect): return sqlite @classmethod - def _is_url_file_db(cls, url: URL): + def _is_url_file_db(cls, url): if (url.database and url.database != ":memory:") and ( url.query.get("mode", None) != "memory" ): @@ -502,9 +538,6 @@ class SQLiteDialect_pysqlite(SQLiteDialect): dbapi_connection.isolation_level = "" return super().set_isolation_level(dbapi_connection, level) - def detect_autocommit_setting(self, dbapi_connection): - return dbapi_connection.isolation_level is None - def on_connect(self): def regexp(a, b): if b is None: @@ -604,13 +637,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect): return ([filename], pysqlite_opts) - def is_disconnect( - self, - e: DBAPIModule.Error, - connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], - cursor: Optional[DBAPICursor], - ) -> bool: - self.dbapi = cast("DBAPIModule", self.dbapi) + def is_disconnect(self, e, connection, cursor): return isinstance( e, self.dbapi.ProgrammingError ) and "Cannot operate on a closed database." in str(e) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/__init__.py index f4205d8..843f970 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/__init__.py @@ -1,5 +1,5 @@ # engine/__init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/_py_processors.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/_py_processors.py index 8536d53..1cc5e8d 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/_py_processors.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/_py_processors.py @@ -1,5 +1,5 @@ -# engine/_py_processors.py -# Copyright (C) 2010-2025 the SQLAlchemy authors and contributors +# sqlalchemy/processors.py +# Copyright (C) 2010-2023 the SQLAlchemy authors and contributors # # Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com # diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/_py_row.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/_py_row.py index 38c60fc..3358abd 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/_py_row.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/_py_row.py @@ -1,9 +1,3 @@ -# engine/_py_row.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php from __future__ import annotations import operator diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/_py_util.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/_py_util.py index c717660..538c075 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/_py_util.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/_py_util.py @@ -1,9 +1,3 @@ -# engine/_py_util.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and 
contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php from __future__ import annotations import typing @@ -32,9 +26,9 @@ def _distill_params_20( # Assume list is more likely than tuple elif isinstance(params, list) or isinstance(params, tuple): # collections_abc.MutableSequence): # avoid abc.__instancecheck__ - if params and not isinstance(params[0], Mapping): + if params and not isinstance(params[0], (tuple, Mapping)): raise exc.ArgumentError( - "List argument must consist only of dictionaries" + "List argument must consist only of tuples or dictionaries" ) return params diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/base.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/base.py index 82729ee..0000e28 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/base.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/base.py @@ -1,10 +1,12 @@ # engine/base.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Defines :class:`_engine.Connection` and :class:`_engine.Engine`.""" +"""Defines :class:`_engine.Connection` and :class:`_engine.Engine`. + +""" from __future__ import annotations import contextlib @@ -68,11 +70,12 @@ if typing.TYPE_CHECKING: from ..sql._typing import _InfoType from ..sql.compiler import Compiled from ..sql.ddl import ExecutableDDLElement - from ..sql.ddl import InvokeDDLBase + from ..sql.ddl import SchemaDropper + from ..sql.ddl import SchemaGenerator from ..sql.functions import FunctionElement from ..sql.schema import DefaultGenerator from ..sql.schema import HasSchemaAttr - from ..sql.schema import SchemaVisitable + from ..sql.schema import SchemaItem from ..sql.selectable import TypedReturnsRows @@ -106,7 +109,6 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ - dialect: Dialect dispatch: dispatcher[ConnectionEventsTarget] _sqla_logger_namespace = "sqlalchemy.engine.Connection" @@ -171,9 +173,13 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if self._has_events or self.engine._has_events: self.dispatch.engine_connect(self) - # this can be assigned differently via - # characteristics.LoggingTokenCharacteristic - _message_formatter: Any = None + @util.memoized_property + def _message_formatter(self) -> Any: + if "logging_token" in self._execution_options: + token = self._execution_options["logging_token"] + return lambda msg: "[%s] %s" % (token, msg) + else: + return None def _log_info(self, message: str, *arg: Any, **kw: Any) -> None: fmt = self._message_formatter @@ -199,9 +205,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): @property def _schema_translate_map(self) -> Optional[SchemaTranslateMapType]: - schema_translate_map: Optional[SchemaTranslateMapType] = ( - self._execution_options.get("schema_translate_map", None) - ) + schema_translate_map: Optional[ + SchemaTranslateMapType + ] = self._execution_options.get("schema_translate_map", None) return schema_translate_map @@ -212,9 +218,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ name = obj.schema - schema_translate_map: Optional[SchemaTranslateMapType] = ( - self._execution_options.get("schema_translate_map", None) 
- ) + schema_translate_map: Optional[ + SchemaTranslateMapType + ] = self._execution_options.get("schema_translate_map", None) if ( schema_translate_map @@ -244,12 +250,13 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): yield_per: int = ..., insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., - preserve_rowcount: bool = False, **opt: Any, - ) -> Connection: ... + ) -> Connection: + ... @overload - def execution_options(self, **opt: Any) -> Connection: ... + def execution_options(self, **opt: Any) -> Connection: + ... def execution_options(self, **opt: Any) -> Connection: r"""Set non-SQL options for the connection which take effect @@ -375,11 +382,12 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): :param stream_results: Available on: :class:`_engine.Connection`, :class:`_sql.Executable`. - Indicate to the dialect that results should be "streamed" and not - pre-buffered, if possible. For backends such as PostgreSQL, MySQL - and MariaDB, this indicates the use of a "server side cursor" as - opposed to a client side cursor. Other backends such as that of - Oracle Database may already use server side cursors by default. + Indicate to the dialect that results should be + "streamed" and not pre-buffered, if possible. For backends + such as PostgreSQL, MySQL and MariaDB, this indicates the use of + a "server side cursor" as opposed to a client side cursor. + Other backends such as that of Oracle may already use server + side cursors by default. The usage of :paramref:`_engine.Connection.execution_options.stream_results` is @@ -484,18 +492,6 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): :ref:`schema_translating` - :param preserve_rowcount: Boolean; when True, the ``cursor.rowcount`` - attribute will be unconditionally memoized within the result and - made available via the :attr:`.CursorResult.rowcount` attribute. - Normally, this attribute is only preserved for UPDATE and DELETE - statements. Using this option, the DBAPIs rowcount value can - be accessed for other kinds of statements such as INSERT and SELECT, - to the degree that the DBAPI supports these statements. See - :attr:`.CursorResult.rowcount` for notes regarding the behavior - of this attribute. - - .. versionadded:: 2.0.28 - .. seealso:: :meth:`_engine.Engine.execution_options` @@ -797,6 +793,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): with conn.begin() as trans: conn.execute(table.insert(), {"username": "sandy"}) + The returned object is an instance of :class:`_engine.RootTransaction`. This object represents the "scope" of the transaction, which completes when either the :meth:`_engine.Transaction.rollback` @@ -902,7 +899,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): trans.rollback() # rollback to savepoint # outer transaction continues - connection.execute(...) + connection.execute( ... ) If :meth:`_engine.Connection.begin_nested` is called without first calling :meth:`_engine.Connection.begin` or @@ -912,11 +909,11 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): with engine.connect() as connection: # begin() wasn't called - with connection.begin_nested(): # will auto-"begin()" first - connection.execute(...) + with connection.begin_nested(): will auto-"begin()" first + connection.execute( ... ) # savepoint is released - connection.execute(...) + connection.execute( ... 
) # explicitly commit outer transaction connection.commit() @@ -1112,16 +1109,10 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if self._still_open_and_dbapi_connection_is_valid: if self._echo: if self._is_autocommit_isolation(): - if self.dialect.skip_autocommit_rollback: - self._log_info( - "ROLLBACK will be skipped by " - "skip_autocommit_rollback" - ) - else: - self._log_info( - "ROLLBACK using DBAPI connection.rollback(); " - "set skip_autocommit_rollback to prevent fully" - ) + self._log_info( + "ROLLBACK using DBAPI connection.rollback(), " + "DBAPI should ignore due to autocommit mode" + ) else: self._log_info("ROLLBACK") try: @@ -1137,7 +1128,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if self._is_autocommit_isolation(): self._log_info( "COMMIT using DBAPI connection.commit(), " - "has no effect due to autocommit mode" + "DBAPI should ignore due to autocommit mode" ) else: self._log_info("COMMIT") @@ -1271,7 +1262,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> Optional[_T]: ... + ) -> Optional[_T]: + ... @overload def scalar( @@ -1280,7 +1272,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> Any: ... + ) -> Any: + ... def scalar( self, @@ -1318,7 +1311,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> ScalarResult[_T]: ... + ) -> ScalarResult[_T]: + ... @overload def scalars( @@ -1327,7 +1321,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> ScalarResult[Any]: ... + ) -> ScalarResult[Any]: + ... def scalars( self, @@ -1361,7 +1356,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> CursorResult[_T]: ... + ) -> CursorResult[_T]: + ... @overload def execute( @@ -1370,7 +1366,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> CursorResult[Any]: ... + ) -> CursorResult[Any]: + ... 
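Tying together the :meth:`_engine.Connection.begin_nested` savepoint behavior described above with the pysqlite transactional workaround shown earlier in this diff, a sketch under those assumptions (the ``accounts`` table is hypothetical)::

    from sqlalchemy import (
        Column,
        Integer,
        MetaData,
        String,
        Table,
        create_engine,
        event,
        insert,
        select,
    )

    engine = create_engine("sqlite://")

    @event.listens_for(engine, "connect")
    def do_connect(dbapi_connection, connection_record):
        # disable pysqlite's own BEGIN handling so SAVEPOINT behaves (see the
        # "Serializable isolation / Savepoints / Transactional DDL" notes above)
        dbapi_connection.isolation_level = None

    @event.listens_for(engine, "begin")
    def do_begin(conn):
        # emit our own BEGIN, per the documented recipe
        conn.exec_driver_sql("BEGIN")

    metadata = MetaData()
    accounts = Table(
        "accounts",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )
    metadata.create_all(engine)

    with engine.connect() as conn:
        with conn.begin():  # outer transaction
            conn.execute(insert(accounts).values(name="kept"))

            savepoint = conn.begin_nested()  # emits SAVEPOINT
            conn.execute(insert(accounts).values(name="discarded"))
            savepoint.rollback()  # rolls back to the savepoint only

        # outer transaction committed; only the first row remains
        print(conn.execute(select(accounts.c.name)).all())  # [('kept',)]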
def execute( self, @@ -1501,7 +1498,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): ) -> CursorResult[Any]: """Execute a schema.DDL object.""" - exec_opts = ddl._execution_options.merge_with( + execution_options = ddl._execution_options.merge_with( self._execution_options, execution_options ) @@ -1515,11 +1512,12 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): event_multiparams, event_params, ) = self._invoke_before_exec_event( - ddl, distilled_parameters, exec_opts + ddl, distilled_parameters, execution_options ) else: event_multiparams = event_params = None + exec_opts = self._execution_options.merge_with(execution_options) schema_translate_map = exec_opts.get("schema_translate_map", None) dialect = self.dialect @@ -1532,7 +1530,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): dialect.execution_ctx_cls._init_ddl, compiled, None, - exec_opts, + execution_options, compiled, ) if self._has_events or self.engine._has_events: @@ -1541,7 +1539,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): ddl, event_multiparams, event_params, - exec_opts, + execution_options, ret, ) return ret @@ -1739,20 +1737,21 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): conn.exec_driver_sql( "INSERT INTO table (id, value) VALUES (%(id)s, %(value)s)", - [{"id": 1, "value": "v1"}, {"id": 2, "value": "v2"}], + [{"id":1, "value":"v1"}, {"id":2, "value":"v2"}] ) Single dictionary:: conn.exec_driver_sql( "INSERT INTO table (id, value) VALUES (%(id)s, %(value)s)", - dict(id=1, value="v1"), + dict(id=1, value="v1") ) Single tuple:: conn.exec_driver_sql( - "INSERT INTO table (id, value) VALUES (?, ?)", (1, "v1") + "INSERT INTO table (id, value) VALUES (?, ?)", + (1, 'v1') ) .. 
note:: The :meth:`_engine.Connection.exec_driver_sql` method does @@ -1841,7 +1840,10 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): context.pre_exec() if context.execute_style is ExecuteStyle.INSERTMANYVALUES: - return self._exec_insertmany_context(dialect, context) + return self._exec_insertmany_context( + dialect, + context, + ) else: return self._exec_single_context( dialect, context, statement, parameters @@ -2016,22 +2018,16 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): engine_events = self._has_events or self.engine._has_events if self.dialect._has_events: - do_execute_dispatch: Iterable[Any] = ( - self.dialect.dispatch.do_execute - ) + do_execute_dispatch: Iterable[ + Any + ] = self.dialect.dispatch.do_execute else: do_execute_dispatch = () if self._echo: stats = context._get_cache_stats() + " (insertmanyvalues)" - preserve_rowcount = context.execution_options.get( - "preserve_rowcount", False - ) - rowcount = 0 - for imv_batch in dialect._deliver_insertmanyvalues_batches( - self, cursor, str_statement, effective_parameters, @@ -2052,7 +2048,6 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): imv_batch.replaced_parameters, None, context, - is_sub_exec=True, ) sub_stmt = imv_batch.replaced_statement @@ -2072,16 +2067,15 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): if self._echo: self._log_info(sql_util._long_statement(sub_stmt)) - imv_stats = f""" {imv_batch.batchnum}/{ - imv_batch.total_batches - } ({ - 'ordered' - if imv_batch.rows_sorted else 'unordered' - }{ - '; batch not supported' - if imv_batch.is_downgraded - else '' - })""" + imv_stats = f""" { + imv_batch.batchnum}/{imv_batch.total_batches} ({ + 'ordered' + if imv_batch.rows_sorted else 'unordered' + }{ + '; batch not supported' + if imv_batch.is_downgraded + else '' + })""" if imv_batch.batchnum == 1: stats += imv_stats @@ -2142,15 +2136,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): context.executemany, ) - if preserve_rowcount: - rowcount += imv_batch.current_batch_size - try: context.post_exec() - if preserve_rowcount: - context._rowcount = rowcount # type: ignore[attr-defined] - result = context._setup_result_proxy() except BaseException as e: @@ -2392,9 +2380,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): None, cast(Exception, e), dialect.loaded_dbapi.Error, - hide_parameters=( - engine.hide_parameters if engine is not None else False - ), + hide_parameters=engine.hide_parameters + if engine is not None + else False, connection_invalidated=is_disconnect, dialect=dialect, ) @@ -2431,7 +2419,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): break if sqlalchemy_exception and is_disconnect != ctx.is_disconnect: - sqlalchemy_exception.connection_invalidated = ctx.is_disconnect + sqlalchemy_exception.connection_invalidated = ( + is_disconnect + ) = ctx.is_disconnect if newraise: raise newraise.with_traceback(exc_info[2]) from e @@ -2444,8 +2434,8 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): def _run_ddl_visitor( self, - visitorcallable: Type[InvokeDDLBase], - element: SchemaVisitable, + visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], + element: SchemaItem, **kwargs: Any, ) -> None: """run a DDL visitor. 
@@ -2454,9 +2444,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): options given to the visitor so that "checkfirst" is skipped. """ - visitorcallable( - dialect=self.dialect, connection=self, **kwargs - ).traverse_single(element) + visitorcallable(self.dialect, self, **kwargs).traverse_single(element) class ExceptionContextImpl(ExceptionContext): @@ -2514,7 +2502,6 @@ class Transaction(TransactionalContext): :class:`_engine.Connection`:: from sqlalchemy import create_engine - engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test") connection = engine.connect() trans = connection.begin() @@ -3003,7 +2990,7 @@ class Engine( This applies **only** to the built-in cache that is established via the :paramref:`_engine.create_engine.query_cache_size` parameter. It will not impact any dictionary caches that were passed via the - :paramref:`.Connection.execution_options.compiled_cache` parameter. + :paramref:`.Connection.execution_options.query_cache` parameter. .. versionadded:: 1.4 @@ -3042,10 +3029,12 @@ class Engine( insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., **opt: Any, - ) -> OptionEngine: ... + ) -> OptionEngine: + ... @overload - def execution_options(self, **opt: Any) -> OptionEngine: ... + def execution_options(self, **opt: Any) -> OptionEngine: + ... def execution_options(self, **opt: Any) -> OptionEngine: """Return a new :class:`_engine.Engine` that will provide @@ -3092,10 +3081,10 @@ class Engine( shards = {"default": "base", "shard_1": "db1", "shard_2": "db2"} - @event.listens_for(Engine, "before_cursor_execute") - def _switch_shard(conn, cursor, stmt, params, context, executemany): - shard_id = conn.get_execution_options().get("shard_id", "default") + def _switch_shard(conn, cursor, stmt, + params, context, executemany): + shard_id = conn.get_execution_options().get('shard_id', "default") current_shard = conn.info.get("current_shard", None) if current_shard != shard_id: @@ -3221,7 +3210,9 @@ class Engine( E.g.:: with engine.begin() as conn: - conn.execute(text("insert into table (x, y, z) values (1, 2, 3)")) + conn.execute( + text("insert into table (x, y, z) values (1, 2, 3)") + ) conn.execute(text("my_special_procedure(5)")) Upon successful operation, the :class:`.Transaction` @@ -3237,15 +3228,15 @@ class Engine( :meth:`_engine.Connection.begin` - start a :class:`.Transaction` for a particular :class:`_engine.Connection`. 
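As a counterpart to the ``engine.begin()`` block documented above, a brief sketch of the "commit as you go" style on a plain :meth:`_engine.Engine.connect` connection; the table and values are placeholders::

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite://")

    with engine.connect() as conn:
        conn.execute(text("CREATE TABLE some_table (x int, y int)"))
        conn.execute(
            text("INSERT INTO some_table (x, y) VALUES (:x, :y)"),
            [{"x": 1, "y": 2}, {"x": 3, "y": 4}],
        )
        conn.commit()  # explicit; engine.begin() would commit on block exit

    with engine.begin() as conn:  # "begin once": commit on success, rollback on error
        conn.execute(
            text("INSERT INTO some_table (x, y) VALUES (:x, :y)"),
            {"x": 5, "y": 6},
        )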
- """ # noqa: E501 + """ with self.connect() as conn: with conn.begin(): yield conn def _run_ddl_visitor( self, - visitorcallable: Type[InvokeDDLBase], - element: SchemaVisitable, + visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], + element: SchemaItem, **kwargs: Any, ) -> None: with self.begin() as conn: diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/characteristics.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/characteristics.py index 322c28b..c0feb00 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/characteristics.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/characteristics.py @@ -1,9 +1,3 @@ -# engine/characteristics.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php from __future__ import annotations import abc @@ -12,7 +6,6 @@ from typing import Any from typing import ClassVar if typing.TYPE_CHECKING: - from .base import Connection from .interfaces import DBAPIConnection from .interfaces import Dialect @@ -45,30 +38,13 @@ class ConnectionCharacteristic(abc.ABC): def reset_characteristic( self, dialect: Dialect, dbapi_conn: DBAPIConnection ) -> None: - """Reset the characteristic on the DBAPI connection to its default - value.""" + """Reset the characteristic on the connection to its default value.""" @abc.abstractmethod def set_characteristic( self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any ) -> None: - """set characteristic on the DBAPI connection to a given value.""" - - def set_connection_characteristic( - self, - dialect: Dialect, - conn: Connection, - dbapi_conn: DBAPIConnection, - value: Any, - ) -> None: - """set characteristic on the :class:`_engine.Connection` to a given - value. - - .. versionadded:: 2.0.30 - added to support elements that are local - to the :class:`_engine.Connection` itself. - - """ - self.set_characteristic(dialect, dbapi_conn, value) + """set characteristic on the connection to a given value.""" @abc.abstractmethod def get_characteristic( @@ -79,22 +55,8 @@ class ConnectionCharacteristic(abc.ABC): """ - def get_connection_characteristic( - self, dialect: Dialect, conn: Connection, dbapi_conn: DBAPIConnection - ) -> Any: - """Given a :class:`_engine.Connection`, get the current value of the - characteristic. - - .. versionadded:: 2.0.30 - added to support elements that are local - to the :class:`_engine.Connection` itself. - - """ - return self.get_characteristic(dialect, dbapi_conn) - class IsolationLevelCharacteristic(ConnectionCharacteristic): - """Manage the isolation level on a DBAPI connection""" - transactional: ClassVar[bool] = True def reset_characteristic( @@ -111,45 +73,3 @@ class IsolationLevelCharacteristic(ConnectionCharacteristic): self, dialect: Dialect, dbapi_conn: DBAPIConnection ) -> Any: return dialect.get_isolation_level(dbapi_conn) - - -class LoggingTokenCharacteristic(ConnectionCharacteristic): - """Manage the 'logging_token' option of a :class:`_engine.Connection`. - - .. 
versionadded:: 2.0.30 - - """ - - transactional: ClassVar[bool] = False - - def reset_characteristic( - self, dialect: Dialect, dbapi_conn: DBAPIConnection - ) -> None: - pass - - def set_characteristic( - self, dialect: Dialect, dbapi_conn: DBAPIConnection, value: Any - ) -> None: - raise NotImplementedError() - - def set_connection_characteristic( - self, - dialect: Dialect, - conn: Connection, - dbapi_conn: DBAPIConnection, - value: Any, - ) -> None: - if value: - conn._message_formatter = lambda msg: "[%s] %s" % (value, msg) - else: - del conn._message_formatter - - def get_characteristic( - self, dialect: Dialect, dbapi_conn: DBAPIConnection - ) -> Any: - raise NotImplementedError() - - def get_connection_characteristic( - self, dialect: Dialect, conn: Connection, dbapi_conn: DBAPIConnection - ) -> Any: - return conn._execution_options.get("logging_token", None) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/create.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/create.py index bf1ede6..684550e 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/create.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/create.py @@ -1,5 +1,5 @@ # engine/create.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -82,11 +82,13 @@ def create_engine( query_cache_size: int = ..., use_insertmanyvalues: bool = ..., **kwargs: Any, -) -> Engine: ... +) -> Engine: + ... @overload -def create_engine(url: Union[str, URL], **kwargs: Any) -> Engine: ... +def create_engine(url: Union[str, URL], **kwargs: Any) -> Engine: + ... @util.deprecated_params( @@ -133,11 +135,8 @@ def create_engine(url: Union[str, _url.URL], **kwargs: Any) -> Engine: and its underlying :class:`.Dialect` and :class:`_pool.Pool` constructs:: - engine = create_engine( - "mysql+mysqldb://scott:tiger@hostname/dbname", - pool_recycle=3600, - echo=True, - ) + engine = create_engine("mysql+mysqldb://scott:tiger@hostname/dbname", + pool_recycle=3600, echo=True) The string form of the URL is ``dialect[+driver]://user:password@host/dbname[?key=value..]``, where @@ -468,9 +467,6 @@ def create_engine(url: Union[str, _url.URL], **kwargs: Any) -> Engine: :ref:`pool_reset_on_return` - :ref:`dbapi_autocommit_skip_rollback` - a more modern approach - to using connections with no transactional instructions - :param pool_timeout=30: number of seconds to wait before giving up on getting a connection from the pool. This is only used with :class:`~sqlalchemy.pool.QueuePool`. This can be a float but is @@ -527,18 +523,6 @@ def create_engine(url: Union[str, _url.URL], **kwargs: Any) -> Engine: .. versionadded:: 1.4 - :param skip_autocommit_rollback: When True, the dialect will - unconditionally skip all calls to the DBAPI ``connection.rollback()`` - method if the DBAPI connection is confirmed to be in "autocommit" mode. - The availability of this feature is dialect specific; if not available, - a ``NotImplementedError`` is raised by the dialect when rollback occurs. - - .. seealso:: - - :ref:`dbapi_autocommit_skip_rollback` - - .. versionadded:: 2.0.43 - :param use_insertmanyvalues: True by default, use the "insertmanyvalues" execution style for INSERT..RETURNING statements by default. 
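A short sketch combining several of the :func:`_sa.create_engine` parameters described above; the URL is a placeholder and the values are illustrative only::

    from sqlalchemy import create_engine

    engine = create_engine(
        "postgresql+psycopg2://scott:tiger@localhost/test",  # placeholder DSN
        echo=True,                  # log emitted SQL via the sqlalchemy.engine logger
        pool_recycle=3600,          # recycle pooled connections after one hour
        pool_pre_ping=True,         # check connection liveness on checkout
        query_cache_size=1200,      # size of the compiled SQL statement cache
        use_insertmanyvalues=True,  # the default; batched INSERT..RETURNING
    )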
@@ -632,14 +616,6 @@ def create_engine(url: Union[str, _url.URL], **kwargs: Any) -> Engine: # assemble connection arguments (cargs_tup, cparams) = dialect.create_connect_args(u) cparams.update(pop_kwarg("connect_args", {})) - - if "async_fallback" in cparams and util.asbool(cparams["async_fallback"]): - util.warn_deprecated( - "The async_fallback dialect argument is deprecated and will be " - "removed in SQLAlchemy 2.1.", - "2.0", - ) - cargs = list(cargs_tup) # allow mutability # look for existing pool or create @@ -681,17 +657,6 @@ def create_engine(url: Union[str, _url.URL], **kwargs: Any) -> Engine: else: pool._dialect = dialect - if ( - hasattr(pool, "_is_asyncio") - and pool._is_asyncio is not dialect.is_async - ): - raise exc.ArgumentError( - f"Pool class {pool.__class__.__name__} cannot be " - f"used with {'non-' if not dialect.is_async else ''}" - "asyncio engine", - code="pcls", - ) - # create engine. if not pop_kwarg("future", True): raise exc.ArgumentError( @@ -851,11 +816,13 @@ def create_pool_from_url( timeout: float = ..., use_lifo: bool = ..., **kwargs: Any, -) -> Pool: ... +) -> Pool: + ... @overload -def create_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool: ... +def create_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool: + ... def create_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool: diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/cursor.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/cursor.py index 8e2348e..45af49a 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/cursor.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/cursor.py @@ -1,5 +1,5 @@ # engine/cursor.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -20,7 +20,6 @@ from typing import Any from typing import cast from typing import ClassVar from typing import Dict -from typing import Iterable from typing import Iterator from typing import List from typing import Mapping @@ -121,7 +120,7 @@ _CursorKeyMapRecType = Tuple[ List[Any], # MD_OBJECTS str, # MD_LOOKUP_KEY str, # MD_RENDERED_NAME - Optional["_ResultProcessorType[Any]"], # MD_PROCESSOR + Optional["_ResultProcessorType"], # MD_PROCESSOR Optional[str], # MD_UNTRANSLATED ] @@ -135,7 +134,7 @@ _NonAmbigCursorKeyMapRecType = Tuple[ List[Any], str, str, - Optional["_ResultProcessorType[Any]"], + Optional["_ResultProcessorType"], str, ] @@ -152,7 +151,7 @@ class CursorResultMetaData(ResultMetaData): "_translated_indexes", "_safe_for_cache", "_unpickled", - "_key_to_index", + "_key_to_index" # don't need _unique_filters support here for now. Can be added # if a need arises. 
) @@ -226,11 +225,9 @@ class CursorResultMetaData(ResultMetaData): { key: ( # int index should be None for ambiguous key - ( - value[0] + offset - if value[0] is not None and key not in keymap - else None - ), + value[0] + offset + if value[0] is not None and key not in keymap + else None, value[1] + offset, *value[2:], ) @@ -365,11 +362,13 @@ class CursorResultMetaData(ResultMetaData): ) = context.result_column_struct num_ctx_cols = len(result_columns) else: - result_columns = cols_are_ordered = ( # type: ignore + result_columns = ( # type: ignore + cols_are_ordered + ) = ( num_ctx_cols - ) = ad_hoc_textual = loose_column_name_matching = ( - textual_ordered - ) = False + ) = ( + ad_hoc_textual + ) = loose_column_name_matching = textual_ordered = False # merge cursor.description with the column info # present in the compiled structure, if any @@ -689,7 +688,6 @@ class CursorResultMetaData(ResultMetaData): % (num_ctx_cols, len(cursor_description)) ) seen = set() - for ( idx, colname, @@ -1163,7 +1161,7 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy): result = conn.execution_options( stream_results=True, max_row_buffer=50 - ).execute(text("select * from table")) + ).execute(text("select * from table")) .. versionadded:: 1.4 ``max_row_buffer`` may now exceed 1000 rows. @@ -1248,9 +1246,8 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy): if size is None: return self.fetchall(result, dbapi_cursor) - rb = self._rowbuffer - lb = len(rb) - close = False + buf = list(self._rowbuffer) + lb = len(buf) if size > lb: try: new = dbapi_cursor.fetchmany(size - lb) @@ -1258,15 +1255,13 @@ class BufferedRowCursorFetchStrategy(CursorFetchStrategy): self.handle_exception(result, dbapi_cursor, e) else: if not new: - # defer closing since it may clear the row buffer - close = True + result._soft_close() else: - rb.extend(new) + buf.extend(new) - res = [rb.popleft() for _ in range(min(size, len(rb)))] - if close: - result._soft_close() - return res + result = buf[0:size] + self._rowbuffer = collections.deque(buf[size:]) + return result def fetchall(self, result, dbapi_cursor): try: @@ -1290,16 +1285,12 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy): __slots__ = ("_rowbuffer", "alternate_cursor_description") def __init__( - self, - dbapi_cursor: Optional[DBAPICursor], - alternate_description: Optional[_DBAPICursorDescription] = None, - initial_buffer: Optional[Iterable[Any]] = None, + self, dbapi_cursor, alternate_description=None, initial_buffer=None ): self.alternate_cursor_description = alternate_description if initial_buffer is not None: self._rowbuffer = collections.deque(initial_buffer) else: - assert dbapi_cursor is not None self._rowbuffer = collections.deque(dbapi_cursor.fetchall()) def yield_per(self, result, dbapi_cursor, num): @@ -1324,8 +1315,9 @@ class FullyBufferedCursorFetchStrategy(CursorFetchStrategy): if size is None: return self.fetchall(result, dbapi_cursor) - rb = self._rowbuffer - rows = [rb.popleft() for _ in range(min(size, len(rb)))] + buf = list(self._rowbuffer) + rows = buf[0:size] + self._rowbuffer = collections.deque(buf[size:]) if not rows: result._soft_close() return rows @@ -1358,15 +1350,15 @@ class _NoResultMetaData(ResultMetaData): self._we_dont_return_rows() @property - def _keymap(self): # type: ignore[override] + def _keymap(self): self._we_dont_return_rows() @property - def _key_to_index(self): # type: ignore[override] + def _key_to_index(self): self._we_dont_return_rows() @property - def _processors(self): # type: ignore[override] + 
def _processors(self): self._we_dont_return_rows() @property @@ -1446,7 +1438,6 @@ class CursorResult(Result[_T]): metadata = self._init_metadata(context, cursor_description) - _make_row: Any _make_row = functools.partial( Row, metadata, @@ -1619,11 +1610,11 @@ class CursorResult(Result[_T]): """ if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled expression construct." + "Statement is not a compiled " "expression construct." ) elif not self.context.isinsert: raise exc.InvalidRequestError( - "Statement is not an insert() expression construct." + "Statement is not an insert() " "expression construct." ) elif self.context._is_explicit_returning: raise exc.InvalidRequestError( @@ -1690,11 +1681,11 @@ class CursorResult(Result[_T]): """ if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled expression construct." + "Statement is not a compiled " "expression construct." ) elif not self.context.isupdate: raise exc.InvalidRequestError( - "Statement is not an update() expression construct." + "Statement is not an update() " "expression construct." ) elif self.context.executemany: return self.context.compiled_parameters @@ -1712,11 +1703,11 @@ class CursorResult(Result[_T]): """ if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled expression construct." + "Statement is not a compiled " "expression construct." ) elif not self.context.isinsert: raise exc.InvalidRequestError( - "Statement is not an insert() expression construct." + "Statement is not an insert() " "expression construct." ) elif self.context.executemany: return self.context.compiled_parameters @@ -1761,9 +1752,11 @@ class CursorResult(Result[_T]): r1 = connection.execute( users.insert().returning( - users.c.user_name, users.c.user_id, sort_by_parameter_order=True + users.c.user_name, + users.c.user_id, + sort_by_parameter_order=True ), - user_values, + user_values ) r2 = connection.execute( @@ -1771,16 +1764,19 @@ class CursorResult(Result[_T]): addresses.c.address_id, addresses.c.address, addresses.c.user_id, - sort_by_parameter_order=True, + sort_by_parameter_order=True ), - address_values, + address_values ) rows = r1.splice_horizontally(r2).all() - assert rows == [ - ("john", 1, 1, "foo@bar.com", 1), - ("jack", 2, 2, "bar@bat.com", 2), - ] + assert ( + rows == + [ + ("john", 1, 1, "foo@bar.com", 1), + ("jack", 2, 2, "bar@bat.com", 2), + ] + ) .. versionadded:: 2.0 @@ -1789,7 +1785,7 @@ class CursorResult(Result[_T]): :meth:`.CursorResult.splice_vertically` - """ # noqa: E501 + """ clone = self._generate() total_rows = [ @@ -1924,7 +1920,7 @@ class CursorResult(Result[_T]): if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled expression construct." + "Statement is not a compiled " "expression construct." ) elif not self.context.isinsert and not self.context.isupdate: raise exc.InvalidRequestError( @@ -1947,7 +1943,7 @@ class CursorResult(Result[_T]): if not self.context.compiled: raise exc.InvalidRequestError( - "Statement is not a compiled expression construct." + "Statement is not a compiled " "expression construct." ) elif not self.context.isinsert and not self.context.isupdate: raise exc.InvalidRequestError( @@ -1978,28 +1974,8 @@ class CursorResult(Result[_T]): def rowcount(self) -> int: """Return the 'rowcount' for this result. 
- The primary purpose of 'rowcount' is to report the number of rows - matched by the WHERE criterion of an UPDATE or DELETE statement - executed once (i.e. for a single parameter set), which may then be - compared to the number of rows expected to be updated or deleted as a - means of asserting data integrity. - - This attribute is transferred from the ``cursor.rowcount`` attribute - of the DBAPI before the cursor is closed, to support DBAPIs that - don't make this value available after cursor close. Some DBAPIs may - offer meaningful values for other kinds of statements, such as INSERT - and SELECT statements as well. In order to retrieve ``cursor.rowcount`` - for these statements, set the - :paramref:`.Connection.execution_options.preserve_rowcount` - execution option to True, which will cause the ``cursor.rowcount`` - value to be unconditionally memoized before any results are returned - or the cursor is closed, regardless of statement type. - - For cases where the DBAPI does not support rowcount for a particular - kind of statement and/or execution, the returned value will be ``-1``, - which is delivered directly from the DBAPI and is part of :pep:`249`. - All DBAPIs should support rowcount for single-parameter-set - UPDATE and DELETE statements, however. + The 'rowcount' reports the number of rows *matched* + by the WHERE criterion of an UPDATE or DELETE statement. .. note:: @@ -2008,47 +1984,38 @@ class CursorResult(Result[_T]): * This attribute returns the number of rows *matched*, which is not necessarily the same as the number of rows - that were actually *modified*. For example, an UPDATE statement + that were actually *modified* - an UPDATE statement, for example, may have no net change on a given row if the SET values given are the same as those present in the row already. Such a row would be matched but not modified. On backends that feature both styles, such as MySQL, - rowcount is configured to return the match + rowcount is configured by default to return the match count in all cases. - * :attr:`_engine.CursorResult.rowcount` in the default case is - *only* useful in conjunction with an UPDATE or DELETE statement, - and only with a single set of parameters. For other kinds of - statements, SQLAlchemy will not attempt to pre-memoize the value - unless the - :paramref:`.Connection.execution_options.preserve_rowcount` - execution option is used. Note that contrary to :pep:`249`, many - DBAPIs do not support rowcount values for statements that are not - UPDATE or DELETE, particularly when rows are being returned which - are not fully pre-buffered. DBAPIs that dont support rowcount - for a particular kind of statement should return the value ``-1`` - for such statements. + * :attr:`_engine.CursorResult.rowcount` + is *only* useful in conjunction + with an UPDATE or DELETE statement. Contrary to what the Python + DBAPI says, it does *not* reliably return the + number of rows available from the results of a SELECT statement + as DBAPIs cannot support this functionality when rows are + unbuffered. - * :attr:`_engine.CursorResult.rowcount` may not be meaningful - when executing a single statement with multiple parameter sets - (i.e. an :term:`executemany`). Most DBAPIs do not sum "rowcount" - values across multiple parameter sets and will return ``-1`` - when accessed. + * :attr:`_engine.CursorResult.rowcount` + may not be fully implemented by + all dialects. In particular, most DBAPIs do not support an + aggregate rowcount result from an executemany call. 
+ The :meth:`_engine.CursorResult.supports_sane_rowcount` and + :meth:`_engine.CursorResult.supports_sane_multi_rowcount` methods + will report from the dialect if each usage is known to be + supported. - * SQLAlchemy's :ref:`engine_insertmanyvalues` feature does support - a correct population of :attr:`_engine.CursorResult.rowcount` - when the :paramref:`.Connection.execution_options.preserve_rowcount` - execution option is set to True. - - * Statements that use RETURNING may not support rowcount, returning - a ``-1`` value instead. + * Statements that use RETURNING may not return a correct + rowcount. .. seealso:: :ref:`tutorial_update_delete_rowcount` - in the :ref:`unified_tutorial` - :paramref:`.Connection.execution_options.preserve_rowcount` - """ # noqa: E501 try: return self.context.rowcount @@ -2142,7 +2109,8 @@ class CursorResult(Result[_T]): def merge(self, *others: Result[Any]) -> MergedResult[Any]: merged_result = super().merge(*others) - if self.context._has_rowcount: + setup_rowcounts = self.context._has_rowcount + if setup_rowcounts: merged_result.rowcount = sum( cast("CursorResult[Any]", result).rowcount for result in (self,) + others diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/default.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/default.py index a241e90..553d8f0 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/default.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/default.py @@ -1,5 +1,5 @@ # engine/default.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -58,7 +58,6 @@ from ..sql import compiler from ..sql import dml from ..sql import expression from ..sql import type_api -from ..sql import util as sql_util from ..sql._typing import is_tuple_type from ..sql.base import _NoArg from ..sql.compiler import DDLCompiler @@ -77,13 +76,10 @@ if typing.TYPE_CHECKING: from .interfaces import _CoreSingleExecuteParams from .interfaces import _DBAPICursorDescription from .interfaces import _DBAPIMultiExecuteParams - from .interfaces import _DBAPISingleExecuteParams from .interfaces import _ExecuteOptions from .interfaces import _MutableCoreSingleExecuteParams from .interfaces import _ParamStyle - from .interfaces import ConnectArgsType from .interfaces import DBAPIConnection - from .interfaces import DBAPIModule from .interfaces import IsolationLevel from .row import Row from .url import URL @@ -99,10 +95,8 @@ if typing.TYPE_CHECKING: from ..sql.elements import BindParameter from ..sql.schema import Column from ..sql.type_api import _BindProcessorType - from ..sql.type_api import _ResultProcessorType from ..sql.type_api import TypeEngine - # When we're handed literal SQL, ensure it's a SELECT query SERVER_SIDE_CURSOR_RE = re.compile(r"\s*SELECT", re.I | re.UNICODE) @@ -173,10 +167,7 @@ class DefaultDialect(Dialect): tuple_in_values = False connection_characteristics = util.immutabledict( - { - "isolation_level": characteristics.IsolationLevelCharacteristic(), - "logging_token": characteristics.LoggingTokenCharacteristic(), - } + {"isolation_level": characteristics.IsolationLevelCharacteristic()} ) engine_config_types: Mapping[str, Any] = util.immutabledict( @@ -258,7 +249,7 @@ class DefaultDialect(Dialect): default_schema_name: Optional[str] = None # indicates symbol names are - # UPPERCASED if they are case insensitive + # UPPERCASEd if they are case insensitive # within 
the database. # if this is True, the methods normalize_name() # and denormalize_name() must be provided. @@ -307,7 +298,6 @@ class DefaultDialect(Dialect): # Linting.NO_LINTING constant compiler_linting: Linting = int(compiler.NO_LINTING), # type: ignore server_side_cursors: bool = False, - skip_autocommit_rollback: bool = False, **kwargs: Any, ): if server_side_cursors: @@ -332,8 +322,6 @@ class DefaultDialect(Dialect): self.dbapi = dbapi - self.skip_autocommit_rollback = skip_autocommit_rollback - if paramstyle is not None: self.paramstyle = paramstyle elif self.dbapi is not None: @@ -399,8 +387,7 @@ class DefaultDialect(Dialect): available if the dialect in use has opted into using the "use_insertmanyvalues" feature. If they haven't opted into that, then this attribute is False, unless the dialect in question overrides this - and provides some other implementation (such as the Oracle Database - dialects). + and provides some other implementation (such as the Oracle dialect). """ return self.insert_returning and self.use_insertmanyvalues @@ -423,7 +410,7 @@ class DefaultDialect(Dialect): If the dialect in use hasn't opted into that, then this attribute is False, unless the dialect in question overrides this and provides some - other implementation (such as the Oracle Database dialects). + other implementation (such as the Oracle dialect). """ return self.insert_returning and self.use_insertmanyvalues @@ -432,7 +419,7 @@ class DefaultDialect(Dialect): delete_executemany_returning = False @util.memoized_property - def loaded_dbapi(self) -> DBAPIModule: + def loaded_dbapi(self) -> ModuleType: if self.dbapi is None: raise exc.InvalidRequestError( f"Dialect {self} does not have a Python DBAPI established " @@ -444,7 +431,7 @@ class DefaultDialect(Dialect): def _bind_typing_render_casts(self): return self.bind_typing is interfaces.BindTyping.RENDER_CASTS - def _ensure_has_table_connection(self, arg: Connection) -> None: + def _ensure_has_table_connection(self, arg): if not isinstance(arg, Connection): raise exc.ArgumentError( "The argument passed to Dialect.has_table() should be a " @@ -481,7 +468,7 @@ class DefaultDialect(Dialect): return weakref.WeakKeyDictionary() @property - def dialect_description(self): # type: ignore[override] + def dialect_description(self): return self.name + "+" + self.driver @property @@ -522,7 +509,7 @@ class DefaultDialect(Dialect): else: return None - def initialize(self, connection: Connection) -> None: + def initialize(self, connection): try: self.server_version_info = self._get_server_version_info( connection @@ -558,7 +545,7 @@ class DefaultDialect(Dialect): % (self.label_length, self.max_identifier_length) ) - def on_connect(self) -> Optional[Callable[[Any], None]]: + def on_connect(self): # inherits the docstring from interfaces.Dialect.on_connect return None @@ -617,18 +604,18 @@ class DefaultDialect(Dialect): ) -> bool: return schema_name in self.get_schema_names(connection, **kw) - def validate_identifier(self, ident: str) -> None: + def validate_identifier(self, ident): if len(ident) > self.max_identifier_length: raise exc.IdentifierError( "Identifier '%s' exceeds maximum length of %d characters" % (ident, self.max_identifier_length) ) - def connect(self, *cargs: Any, **cparams: Any) -> DBAPIConnection: + def connect(self, *cargs, **cparams): # inherits the docstring from interfaces.Dialect.connect - return self.loaded_dbapi.connect(*cargs, **cparams) # type: ignore[no-any-return] # NOQA: E501 + return self.loaded_dbapi.connect(*cargs, **cparams) - def 
create_connect_args(self, url: URL) -> ConnectArgsType: + def create_connect_args(self, url): # inherits the docstring from interfaces.Dialect.create_connect_args opts = url.translate_connect_args() opts.update(url.query) @@ -672,7 +659,7 @@ class DefaultDialect(Dialect): if connection.in_transaction(): trans_objs = [ (name, obj) - for name, obj, _ in characteristic_values + for name, obj, value in characteristic_values if obj.transactional ] if trans_objs: @@ -685,10 +672,8 @@ class DefaultDialect(Dialect): ) dbapi_connection = connection.connection.dbapi_connection - for _, characteristic, value in characteristic_values: - characteristic.set_connection_characteristic( - self, connection, dbapi_connection, value - ) + for name, characteristic, value in characteristic_values: + characteristic.set_characteristic(self, dbapi_connection, value) connection.connection._connection_record.finalize_callback.append( functools.partial(self._reset_characteristics, characteristics) ) @@ -704,10 +689,6 @@ class DefaultDialect(Dialect): pass def do_rollback(self, dbapi_connection): - if self.skip_autocommit_rollback and self.detect_autocommit_setting( - dbapi_connection - ): - return dbapi_connection.rollback() def do_commit(self, dbapi_connection): @@ -747,6 +728,8 @@ class DefaultDialect(Dialect): raise def do_ping(self, dbapi_connection: DBAPIConnection) -> bool: + cursor = None + cursor = dbapi_connection.cursor() try: cursor.execute(self._dialect_specific_select_one) @@ -773,25 +756,11 @@ class DefaultDialect(Dialect): connection.execute(expression.ReleaseSavepointClause(name)) def _deliver_insertmanyvalues_batches( - self, - connection, - cursor, - statement, - parameters, - generic_setinputsizes, - context, + self, cursor, statement, parameters, generic_setinputsizes, context ): context = cast(DefaultExecutionContext, context) compiled = cast(SQLCompiler, context.compiled) - _composite_sentinel_proc: Sequence[ - Optional[_ResultProcessorType[Any]] - ] = () - _scalar_sentinel_proc: Optional[_ResultProcessorType[Any]] = None - _sentinel_proc_initialized: bool = False - - compiled_parameters = context.compiled_parameters - imv = compiled._insertmanyvalues assert imv is not None @@ -800,12 +769,7 @@ class DefaultDialect(Dialect): "insertmanyvalues_page_size", self.insertmanyvalues_page_size ) - if compiled.schema_translate_map: - schema_translate_map = context.execution_options.get( - "schema_translate_map", {} - ) - else: - schema_translate_map = None + sentinel_value_resolvers = None if is_returning: result: Optional[List[Any]] = [] @@ -813,6 +777,10 @@ class DefaultDialect(Dialect): sort_by_parameter_order = imv.sort_by_parameter_order + if imv.num_sentinel_columns: + sentinel_value_resolvers = ( + compiled._imv_sentinel_value_resolvers + ) else: sort_by_parameter_order = False result = None @@ -820,27 +788,14 @@ class DefaultDialect(Dialect): for imv_batch in compiled._deliver_insertmanyvalues_batches( statement, parameters, - compiled_parameters, generic_setinputsizes, batch_size, sort_by_parameter_order, - schema_translate_map, ): yield imv_batch if is_returning: - - try: - rows = context.fetchall_for_returning(cursor) - except BaseException as be: - connection._handle_dbapi_exception( - be, - sql_util._long_statement(imv_batch.replaced_statement), - imv_batch.replaced_parameters, - None, - context, - is_sub_exec=True, - ) + rows = context.fetchall_for_returning(cursor) # I would have thought "is_returning: Final[bool]" # would have assured this but pylance thinks not @@ -860,46 +815,11 @@ 
class DefaultDialect(Dialect): # otherwise, create dictionaries to match up batches # with parameters assert imv.sentinel_param_keys - assert imv.sentinel_columns - _nsc = imv.num_sentinel_columns - - if not _sentinel_proc_initialized: - if composite_sentinel: - _composite_sentinel_proc = [ - col.type._cached_result_processor( - self, cursor_desc[1] - ) - for col, cursor_desc in zip( - imv.sentinel_columns, - cursor.description[-_nsc:], - ) - ] - else: - _scalar_sentinel_proc = ( - imv.sentinel_columns[0] - ).type._cached_result_processor( - self, cursor.description[-1][1] - ) - _sentinel_proc_initialized = True - - rows_by_sentinel: Union[ - Dict[Tuple[Any, ...], Any], - Dict[Any, Any], - ] if composite_sentinel: + _nsc = imv.num_sentinel_columns rows_by_sentinel = { - tuple( - (proc(val) if proc else val) - for val, proc in zip( - row[-_nsc:], _composite_sentinel_proc - ) - ): row - for row in rows - } - elif _scalar_sentinel_proc: - rows_by_sentinel = { - _scalar_sentinel_proc(row[-1]): row for row in rows + tuple(row[-_nsc:]): row for row in rows } else: rows_by_sentinel = {row[-1]: row for row in rows} @@ -918,10 +838,61 @@ class DefaultDialect(Dialect): ) try: - ordered_rows = [ - rows_by_sentinel[sentinel_keys] - for sentinel_keys in imv_batch.sentinel_values - ] + if composite_sentinel: + if sentinel_value_resolvers: + # composite sentinel (PK) with value resolvers + ordered_rows = [ + rows_by_sentinel[ + tuple( + _resolver(parameters[_spk]) # type: ignore # noqa: E501 + if _resolver + else parameters[_spk] # type: ignore # noqa: E501 + for _resolver, _spk in zip( + sentinel_value_resolvers, + imv.sentinel_param_keys, + ) + ) + ] + for parameters in imv_batch.batch + ] + else: + # composite sentinel (PK) with no value + # resolvers + ordered_rows = [ + rows_by_sentinel[ + tuple( + parameters[_spk] # type: ignore + for _spk in imv.sentinel_param_keys + ) + ] + for parameters in imv_batch.batch + ] + else: + _sentinel_param_key = imv.sentinel_param_keys[0] + if ( + sentinel_value_resolvers + and sentinel_value_resolvers[0] + ): + # single-column sentinel with value resolver + _sentinel_value_resolver = ( + sentinel_value_resolvers[0] + ) + ordered_rows = [ + rows_by_sentinel[ + _sentinel_value_resolver( + parameters[_sentinel_param_key] # type: ignore # noqa: E501 + ) + ] + for parameters in imv_batch.batch + ] + else: + # single-column sentinel with no value resolver + ordered_rows = [ + rows_by_sentinel[ + parameters[_sentinel_param_key] # type: ignore # noqa: E501 + ] + for parameters in imv_batch.batch + ] except KeyError as ke: # see test_insert_exec.py:: # IMVSentinelTest::test_sentinel_cant_match_keys @@ -953,14 +924,7 @@ class DefaultDialect(Dialect): def do_execute_no_params(self, cursor, statement, context=None): cursor.execute(statement) - def is_disconnect( - self, - e: DBAPIModule.Error, - connection: Union[ - pool.PoolProxiedConnection, interfaces.DBAPIConnection, None - ], - cursor: Optional[interfaces.DBAPICursor], - ) -> bool: + def is_disconnect(self, e, connection, cursor): return False @util.memoized_instancemethod @@ -1060,7 +1024,7 @@ class DefaultDialect(Dialect): name = name_upper return name - def get_driver_connection(self, connection: DBAPIConnection) -> Any: + def get_driver_connection(self, connection): return connection def _overrides_default(self, method): @@ -1232,7 +1196,7 @@ class DefaultExecutionContext(ExecutionContext): _soft_closed = False - _rowcount: Optional[int] = None + _has_rowcount = False # a hook for SQLite's translation of # result 
column names @@ -1489,11 +1453,9 @@ class DefaultExecutionContext(ExecutionContext): assert positiontup is not None for compiled_params in self.compiled_parameters: l_param: List[Any] = [ - ( - flattened_processors[key](compiled_params[key]) - if key in flattened_processors - else compiled_params[key] - ) + flattened_processors[key](compiled_params[key]) + if key in flattened_processors + else compiled_params[key] for key in positiontup ] core_positional_parameters.append( @@ -1514,20 +1476,18 @@ class DefaultExecutionContext(ExecutionContext): for compiled_params in self.compiled_parameters: if escaped_names: d_param = { - escaped_names.get(key, key): ( - flattened_processors[key](compiled_params[key]) - if key in flattened_processors - else compiled_params[key] + escaped_names.get(key, key): flattened_processors[key]( + compiled_params[key] ) + if key in flattened_processors + else compiled_params[key] for key in compiled_params } else: d_param = { - key: ( - flattened_processors[key](compiled_params[key]) - if key in flattened_processors - else compiled_params[key] - ) + key: flattened_processors[key](compiled_params[key]) + if key in flattened_processors + else compiled_params[key] for key in compiled_params } @@ -1617,13 +1577,7 @@ class DefaultExecutionContext(ExecutionContext): elif ch is CACHE_MISS: return "generated in %.5fs" % (now - gen_time,) elif ch is CACHING_DISABLED: - if "_cache_disable_reason" in self.execution_options: - return "caching disabled (%s) %.5fs " % ( - self.execution_options["_cache_disable_reason"], - now - gen_time, - ) - else: - return "caching disabled %.5fs" % (now - gen_time,) + return "caching disabled %.5fs" % (now - gen_time,) elif ch is NO_DIALECT_SUPPORT: return "dialect %s+%s does not support caching %.5fs" % ( self.dialect.name, @@ -1634,7 +1588,7 @@ class DefaultExecutionContext(ExecutionContext): return "unknown" @property - def executemany(self): # type: ignore[override] + def executemany(self): return self.execute_style in ( ExecuteStyle.EXECUTEMANY, ExecuteStyle.INSERTMANYVALUES, @@ -1676,12 +1630,7 @@ class DefaultExecutionContext(ExecutionContext): def no_parameters(self): return self.execution_options.get("no_parameters", False) - def _execute_scalar( - self, - stmt: str, - type_: Optional[TypeEngine[Any]], - parameters: Optional[_DBAPISingleExecuteParams] = None, - ) -> Any: + def _execute_scalar(self, stmt, type_, parameters=None): """Execute a string statement on the current cursor, returning a scalar result. 
@@ -1755,7 +1704,7 @@ class DefaultExecutionContext(ExecutionContext): return use_server_side - def create_cursor(self) -> DBAPICursor: + def create_cursor(self): if ( # inlining initial preference checks for SS cursors self.dialect.supports_server_side_cursors @@ -1776,10 +1725,10 @@ class DefaultExecutionContext(ExecutionContext): def fetchall_for_returning(self, cursor): return cursor.fetchall() - def create_default_cursor(self) -> DBAPICursor: + def create_default_cursor(self): return self._dbapi_connection.cursor() - def create_server_side_cursor(self) -> DBAPICursor: + def create_server_side_cursor(self): raise NotImplementedError() def pre_exec(self): @@ -1827,14 +1776,7 @@ class DefaultExecutionContext(ExecutionContext): @util.non_memoized_property def rowcount(self) -> int: - if self._rowcount is not None: - return self._rowcount - else: - return self.cursor.rowcount - - @property - def _has_rowcount(self): - return self._rowcount is not None + return self.cursor.rowcount def supports_sane_rowcount(self): return self.dialect.supports_sane_rowcount @@ -1845,13 +1787,9 @@ class DefaultExecutionContext(ExecutionContext): def _setup_result_proxy(self): exec_opt = self.execution_options - if self._rowcount is None and exec_opt.get("preserve_rowcount", False): - self._rowcount = self.cursor.rowcount - - yp: Optional[Union[int, bool]] if self.is_crud or self.is_text: result = self._setup_dml_or_text_result() - yp = False + yp = sr = False else: yp = exec_opt.get("yield_per", None) sr = self._is_server_side or exec_opt.get("stream_results", False) @@ -2005,7 +1943,8 @@ class DefaultExecutionContext(ExecutionContext): if rows: self.returned_default_rows = rows - self._rowcount = len(rows) + result.rowcount = len(rows) + self._has_rowcount = True if self._is_supplemental_returning: result._rewind(rows) @@ -2019,12 +1958,12 @@ class DefaultExecutionContext(ExecutionContext): elif not result._metadata.returns_rows: # no results, get rowcount # (which requires open cursor on some drivers) - if self._rowcount is None: - self._rowcount = self.cursor.rowcount + result.rowcount + self._has_rowcount = True result._soft_close() elif self.isupdate or self.isdelete: - if self._rowcount is None: - self._rowcount = self.cursor.rowcount + result.rowcount + self._has_rowcount = True return result @util.memoized_property @@ -2073,11 +2012,10 @@ class DefaultExecutionContext(ExecutionContext): style of ``setinputsizes()`` on the cursor, using DB-API types from the bind parameter's ``TypeEngine`` objects. - This method only called by those dialects which set the - :attr:`.Dialect.bind_typing` attribute to - :attr:`.BindTyping.SETINPUTSIZES`. Python-oracledb and cx_Oracle are - the only DBAPIs that requires setinputsizes(); pyodbc offers it as an - option. + This method only called by those dialects which set + the :attr:`.Dialect.bind_typing` attribute to + :attr:`.BindTyping.SETINPUTSIZES`. cx_Oracle is the only DBAPI + that requires setinputsizes(), pyodbc offers it as an option. 
Prior to SQLAlchemy 2.0, the setinputsizes() approach was also used for pg8000 and asyncpg, which has been changed to inline rendering @@ -2205,21 +2143,17 @@ class DefaultExecutionContext(ExecutionContext): if compiled.positional: parameters = self.dialect.execute_sequence_format( [ - ( - processors[key](compiled_params[key]) # type: ignore - if key in processors - else compiled_params[key] - ) + processors[key](compiled_params[key]) # type: ignore + if key in processors + else compiled_params[key] for key in compiled.positiontup or () ] ) else: parameters = { - key: ( - processors[key](compiled_params[key]) # type: ignore - if key in processors - else compiled_params[key] - ) + key: processors[key](compiled_params[key]) # type: ignore + if key in processors + else compiled_params[key] for key in compiled_params } return self._execute_scalar( diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/events.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/events.py index b759382..aac756d 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/events.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/events.py @@ -1,5 +1,5 @@ -# engine/events.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sqlalchemy/engine/events.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -54,24 +54,19 @@ class ConnectionEvents(event.Events[ConnectionEventsTarget]): from sqlalchemy import event, create_engine - - def before_cursor_execute( - conn, cursor, statement, parameters, context, executemany - ): + def before_cursor_execute(conn, cursor, statement, parameters, context, + executemany): log.info("Received statement: %s", statement) - - engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test") + engine = create_engine('postgresql+psycopg2://scott:tiger@localhost/test') event.listen(engine, "before_cursor_execute", before_cursor_execute) or with a specific :class:`_engine.Connection`:: with engine.begin() as conn: - - @event.listens_for(conn, "before_cursor_execute") - def before_cursor_execute( - conn, cursor, statement, parameters, context, executemany - ): + @event.listens_for(conn, 'before_cursor_execute') + def before_cursor_execute(conn, cursor, statement, parameters, + context, executemany): log.info("Received statement: %s", statement) When the methods are called with a `statement` parameter, such as in @@ -89,11 +84,9 @@ class ConnectionEvents(event.Events[ConnectionEventsTarget]): from sqlalchemy.engine import Engine from sqlalchemy import event - @event.listens_for(Engine, "before_cursor_execute", retval=True) - def comment_sql_calls( - conn, cursor, statement, parameters, context, executemany - ): + def comment_sql_calls(conn, cursor, statement, parameters, + context, executemany): statement = statement + " -- some comment" return statement, parameters @@ -323,9 +316,8 @@ class ConnectionEvents(event.Events[ConnectionEventsTarget]): returned as a two-tuple in this case:: @event.listens_for(Engine, "before_cursor_execute", retval=True) - def before_cursor_execute( - conn, cursor, statement, parameters, context, executemany - ): + def before_cursor_execute(conn, cursor, statement, + parameters, context, executemany): # do something with statement, parameters return statement, parameters @@ -774,9 +766,9 @@ class DialectEvents(event.Events[Dialect]): @event.listens_for(Engine, "handle_error") def handle_exception(context): - if isinstance( - 
context.original_exception, psycopg2.OperationalError - ) and "failed" in str(context.original_exception): + if isinstance(context.original_exception, + psycopg2.OperationalError) and \ + "failed" in str(context.original_exception): raise MySpecialException("failed operation") .. warning:: Because the @@ -799,13 +791,10 @@ class DialectEvents(event.Events[Dialect]): @event.listens_for(Engine, "handle_error", retval=True) def handle_exception(context): - if ( - context.chained_exception is not None - and "special" in context.chained_exception.message - ): - return MySpecialException( - "failed", cause=context.chained_exception - ) + if context.chained_exception is not None and \ + "special" in context.chained_exception.message: + return MySpecialException("failed", + cause=context.chained_exception) Handlers that return ``None`` may be used within the chain; when a handler returns ``None``, the previous exception instance, @@ -847,8 +836,7 @@ class DialectEvents(event.Events[Dialect]): e = create_engine("postgresql+psycopg2://user@host/dbname") - - @event.listens_for(e, "do_connect") + @event.listens_for(e, 'do_connect') def receive_do_connect(dialect, conn_rec, cargs, cparams): cparams["password"] = "some_password" @@ -857,8 +845,7 @@ class DialectEvents(event.Events[Dialect]): e = create_engine("postgresql+psycopg2://user@host/dbname") - - @event.listens_for(e, "do_connect") + @event.listens_for(e, 'do_connect') def receive_do_connect(dialect, conn_rec, cargs, cparams): return psycopg2.connect(*cargs, **cparams) @@ -941,8 +928,7 @@ class DialectEvents(event.Events[Dialect]): The setinputsizes hook overall is only used for dialects which include the flag ``use_setinputsizes=True``. Dialects which use this - include python-oracledb, cx_Oracle, pg8000, asyncpg, and pyodbc - dialects. + include cx_Oracle, pg8000, asyncpg, and pyodbc dialects. .. note:: diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/interfaces.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/interfaces.py index 37093e8..ea1f27d 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/interfaces.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/interfaces.py @@ -1,5 +1,5 @@ # engine/interfaces.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -10,6 +10,7 @@ from __future__ import annotations from enum import Enum +from types import ModuleType from typing import Any from typing import Awaitable from typing import Callable @@ -33,7 +34,7 @@ from typing import Union from .. 
import util from ..event import EventTarget from ..pool import Pool -from ..pool import PoolProxiedConnection as PoolProxiedConnection +from ..pool import PoolProxiedConnection from ..sql.compiler import Compiled as Compiled from ..sql.compiler import Compiled # noqa from ..sql.compiler import TypeCompiler as TypeCompiler @@ -50,7 +51,6 @@ if TYPE_CHECKING: from .base import Engine from .cursor import CursorResult from .url import URL - from ..connectors.asyncio import AsyncIODBAPIConnection from ..event import _ListenerFnType from ..event import dispatcher from ..exc import StatementError @@ -70,7 +70,6 @@ if TYPE_CHECKING: from ..sql.sqltypes import Integer from ..sql.type_api import _TypeMemoDict from ..sql.type_api import TypeEngine - from ..util.langhelpers import generic_fn_descriptor ConnectArgsType = Tuple[Sequence[str], MutableMapping[str, Any]] @@ -107,22 +106,6 @@ class ExecuteStyle(Enum): """ -class DBAPIModule(Protocol): - class Error(Exception): - def __getattr__(self, key: str) -> Any: ... - - class OperationalError(Error): - pass - - class InterfaceError(Error): - pass - - class IntegrityError(Error): - pass - - def __getattr__(self, key: str) -> Any: ... - - class DBAPIConnection(Protocol): """protocol representing a :pep:`249` database connection. @@ -135,17 +118,19 @@ class DBAPIConnection(Protocol): """ # noqa: E501 - def close(self) -> None: ... + def close(self) -> None: + ... - def commit(self) -> None: ... + def commit(self) -> None: + ... - def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: ... + def cursor(self) -> DBAPICursor: + ... - def rollback(self) -> None: ... + def rollback(self) -> None: + ... - def __getattr__(self, key: str) -> Any: ... - - def __setattr__(self, key: str, value: Any) -> None: ... + autocommit: bool class DBAPIType(Protocol): @@ -189,43 +174,53 @@ class DBAPICursor(Protocol): ... @property - def rowcount(self) -> int: ... + def rowcount(self) -> int: + ... arraysize: int lastrowid: int - def close(self) -> None: ... + def close(self) -> None: + ... def execute( self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams] = None, - ) -> Any: ... + ) -> Any: + ... def executemany( self, operation: Any, - parameters: _DBAPIMultiExecuteParams, - ) -> Any: ... + parameters: Sequence[_DBAPIMultiExecuteParams], + ) -> Any: + ... - def fetchone(self) -> Optional[Any]: ... + def fetchone(self) -> Optional[Any]: + ... - def fetchmany(self, size: int = ...) -> Sequence[Any]: ... + def fetchmany(self, size: int = ...) -> Sequence[Any]: + ... - def fetchall(self) -> Sequence[Any]: ... + def fetchall(self) -> Sequence[Any]: + ... - def setinputsizes(self, sizes: Sequence[Any]) -> None: ... + def setinputsizes(self, sizes: Sequence[Any]) -> None: + ... - def setoutputsize(self, size: Any, column: Any) -> None: ... + def setoutputsize(self, size: Any, column: Any) -> None: + ... - def callproc( - self, procname: str, parameters: Sequence[Any] = ... - ) -> Any: ... + def callproc(self, procname: str, parameters: Sequence[Any] = ...) -> Any: + ... - def nextset(self) -> Optional[bool]: ... + def nextset(self) -> Optional[bool]: + ... - def __getattr__(self, key: str) -> Any: ... + def __getattr__(self, key: str) -> Any: + ... 
_CoreSingleExecuteParams = Mapping[str, Any] @@ -289,7 +284,6 @@ class _CoreKnownExecutionOptions(TypedDict, total=False): yield_per: int insertmanyvalues_page_size: int schema_translate_map: Optional[SchemaTranslateMapType] - preserve_rowcount: bool _ExecuteOptions = immutabledict[str, Any] @@ -599,8 +593,8 @@ class BindTyping(Enum): """Use the pep-249 setinputsizes method. This is only implemented for DBAPIs that support this method and for which - the SQLAlchemy dialect has the appropriate infrastructure for that dialect - set up. Current dialects include python-oracledb, cx_Oracle as well as + the SQLAlchemy dialect has the appropriate infrastructure for that + dialect set up. Current dialects include cx_Oracle as well as optional support for SQL Server using pyodbc. When using setinputsizes, dialects also have a means of only using the @@ -677,7 +671,7 @@ class Dialect(EventTarget): dialect_description: str - dbapi: Optional[DBAPIModule] + dbapi: Optional[ModuleType] """A reference to the DBAPI module object itself. SQLAlchemy dialects import DBAPI modules using the classmethod @@ -701,7 +695,7 @@ class Dialect(EventTarget): """ @util.non_memoized_property - def loaded_dbapi(self) -> DBAPIModule: + def loaded_dbapi(self) -> ModuleType: """same as .dbapi, but is never None; will raise an error if no DBAPI was set up. @@ -779,14 +773,6 @@ class Dialect(EventTarget): default_isolation_level: Optional[IsolationLevel] """the isolation that is implicitly present on new connections""" - skip_autocommit_rollback: bool - """Whether or not the :paramref:`.create_engine.skip_autocommit_rollback` - parameter was set. - - .. versionadded:: 2.0.43 - - """ - # create_engine() -> isolation_level currently goes here _on_connect_isolation_level: Optional[IsolationLevel] @@ -806,14 +792,8 @@ class Dialect(EventTarget): max_identifier_length: int """The maximum length of identifier names.""" - max_index_name_length: Optional[int] - """The maximum length of index names if different from - ``max_identifier_length``.""" - max_constraint_name_length: Optional[int] - """The maximum length of constraint names if different from - ``max_identifier_length``.""" - supports_server_side_cursors: Union[generic_fn_descriptor[bool], bool] + supports_server_side_cursors: bool """indicates if the dialect supports server side cursors""" server_side_cursors: bool @@ -904,12 +884,12 @@ class Dialect(EventTarget): the statement multiple times for a series of batches when large numbers of rows are given. - The parameter is False for the default dialect, and is set to True for - SQLAlchemy internal dialects SQLite, MySQL/MariaDB, PostgreSQL, SQL Server. - It remains at False for Oracle Database, which provides native "executemany - with RETURNING" support and also does not support - ``supports_multivalues_insert``. For MySQL/MariaDB, those MySQL dialects - that don't support RETURNING will not report + The parameter is False for the default dialect, and is set to + True for SQLAlchemy internal dialects SQLite, MySQL/MariaDB, PostgreSQL, + SQL Server. It remains at False for Oracle, which provides native + "executemany with RETURNING" support and also does not support + ``supports_multivalues_insert``. For MySQL/MariaDB, those MySQL + dialects that don't support RETURNING will not report ``insert_executemany_returning`` as True. .. 
versionadded:: 2.0 @@ -1093,7 +1073,11 @@ class Dialect(EventTarget): To implement, establish as a series of tuples, as in:: construct_arguments = [ - (schema.Index, {"using": False, "where": None, "ops": None}), + (schema.Index, { + "using": False, + "where": None, + "ops": None + }) ] If the above construct is established on the PostgreSQL dialect, @@ -1122,8 +1106,7 @@ class Dialect(EventTarget): established on a :class:`.Table` object which will be passed as "reflection options" when using :paramref:`.Table.autoload_with`. - Current example is "oracle_resolve_synonyms" in the Oracle Database - dialects. + Current example is "oracle_resolve_synonyms" in the Oracle dialect. """ @@ -1147,7 +1130,7 @@ class Dialect(EventTarget): supports_constraint_comments: bool """Indicates if the dialect supports comment DDL on constraints. - .. versionadded:: 2.0 + .. versionadded: 2.0 """ _has_events = False @@ -1266,7 +1249,7 @@ class Dialect(EventTarget): raise NotImplementedError() @classmethod - def import_dbapi(cls) -> DBAPIModule: + def import_dbapi(cls) -> ModuleType: """Import the DBAPI module that is used by this dialect. The Python module object returned here will be assigned as an @@ -1283,7 +1266,8 @@ class Dialect(EventTarget): """ raise NotImplementedError() - def type_descriptor(self, typeobj: TypeEngine[_T]) -> TypeEngine[_T]: + @classmethod + def type_descriptor(cls, typeobj: TypeEngine[_T]) -> TypeEngine[_T]: """Transform a generic type to a dialect-specific type. Dialect classes will usually use the @@ -1315,9 +1299,12 @@ class Dialect(EventTarget): """ + pass + if TYPE_CHECKING: - def _overrides_default(self, method_name: str) -> bool: ... + def _overrides_default(self, method_name: str) -> bool: + ... def get_columns( self, @@ -1343,7 +1330,6 @@ class Dialect(EventTarget): def get_multi_columns( self, connection: Connection, - *, schema: Optional[str] = None, filter_names: Optional[Collection[str]] = None, **kw: Any, @@ -1392,7 +1378,6 @@ class Dialect(EventTarget): def get_multi_pk_constraint( self, connection: Connection, - *, schema: Optional[str] = None, filter_names: Optional[Collection[str]] = None, **kw: Any, @@ -1439,7 +1424,6 @@ class Dialect(EventTarget): def get_multi_foreign_keys( self, connection: Connection, - *, schema: Optional[str] = None, filter_names: Optional[Collection[str]] = None, **kw: Any, @@ -1599,7 +1583,6 @@ class Dialect(EventTarget): def get_multi_indexes( self, connection: Connection, - *, schema: Optional[str] = None, filter_names: Optional[Collection[str]] = None, **kw: Any, @@ -1646,7 +1629,6 @@ class Dialect(EventTarget): def get_multi_unique_constraints( self, connection: Connection, - *, schema: Optional[str] = None, filter_names: Optional[Collection[str]] = None, **kw: Any, @@ -1694,7 +1676,6 @@ class Dialect(EventTarget): def get_multi_check_constraints( self, connection: Connection, - *, schema: Optional[str] = None, filter_names: Optional[Collection[str]] = None, **kw: Any, @@ -1737,7 +1718,6 @@ class Dialect(EventTarget): def get_multi_table_options( self, connection: Connection, - *, schema: Optional[str] = None, filter_names: Optional[Collection[str]] = None, **kw: Any, @@ -1789,7 +1769,6 @@ class Dialect(EventTarget): def get_multi_table_comment( self, connection: Connection, - *, schema: Optional[str] = None, filter_names: Optional[Collection[str]] = None, **kw: Any, @@ -2182,7 +2161,6 @@ class Dialect(EventTarget): def _deliver_insertmanyvalues_batches( self, - connection: Connection, cursor: DBAPICursor, statement: str, parameters: 
_DBAPIMultiExecuteParams, @@ -2236,7 +2214,7 @@ class Dialect(EventTarget): def is_disconnect( self, - e: DBAPIModule.Error, + e: Exception, connection: Optional[Union[PoolProxiedConnection, DBAPIConnection]], cursor: Optional[DBAPICursor], ) -> bool: @@ -2340,7 +2318,7 @@ class Dialect(EventTarget): """ return self.on_connect() - def on_connect(self) -> Optional[Callable[[Any], None]]: + def on_connect(self) -> Optional[Callable[[Any], Any]]: """return a callable which sets up a newly created DBAPI connection. The callable should accept a single argument "conn" which is the @@ -2489,30 +2467,6 @@ class Dialect(EventTarget): raise NotImplementedError() - def detect_autocommit_setting(self, dbapi_conn: DBAPIConnection) -> bool: - """Detect the current autocommit setting for a DBAPI connection. - - :param dbapi_connection: a DBAPI connection object - :return: True if autocommit is enabled, False if disabled - :rtype: bool - - This method inspects the given DBAPI connection to determine - whether autocommit mode is currently enabled. The specific - mechanism for detecting autocommit varies by database dialect - and DBAPI driver, however it should be done **without** network - round trips. - - .. note:: - - Not all dialects support autocommit detection. Dialects - that do not support this feature will raise - :exc:`NotImplementedError`. - - """ - raise NotImplementedError( - "This dialect cannot detect autocommit on a DBAPI connection" - ) - def get_default_isolation_level( self, dbapi_conn: DBAPIConnection ) -> IsolationLevel: @@ -2537,7 +2491,7 @@ class Dialect(EventTarget): def get_isolation_level_values( self, dbapi_conn: DBAPIConnection - ) -> Sequence[IsolationLevel]: + ) -> List[IsolationLevel]: """return a sequence of string isolation level names that are accepted by this dialect. @@ -2550,7 +2504,7 @@ class Dialect(EventTarget): ``REPEATABLE READ``. isolation level names will have underscores converted to spaces before being passed along to the dialect. * The names for the four standard isolation names to the extent that - they are supported by the backend should be ``READ UNCOMMITTED``, + they are supported by the backend should be ``READ UNCOMMITTED`` ``READ COMMITTED``, ``REPEATABLE READ``, ``SERIALIZABLE`` * if the dialect supports an autocommit option it should be provided using the isolation level name ``AUTOCOMMIT``. 
@@ -2711,9 +2665,6 @@ class Dialect(EventTarget): """return a Pool class to use for a given URL""" raise NotImplementedError() - def validate_identifier(self, ident: str) -> None: - """Validates an identifier name, raising an exception if invalid""" - class CreateEnginePlugin: """A set of hooks intended to augment the construction of an @@ -2739,14 +2690,11 @@ class CreateEnginePlugin: from sqlalchemy.engine import CreateEnginePlugin from sqlalchemy import event - class LogCursorEventsPlugin(CreateEnginePlugin): def __init__(self, url, kwargs): # consume the parameter "log_cursor_logging_name" from the # URL query - logging_name = url.query.get( - "log_cursor_logging_name", "log_cursor" - ) + logging_name = url.query.get("log_cursor_logging_name", "log_cursor") self.log = logging.getLogger(logging_name) @@ -2758,6 +2706,7 @@ class CreateEnginePlugin: "attach an event listener after the new Engine is constructed" event.listen(engine, "before_cursor_execute", self._log_event) + def _log_event( self, conn, @@ -2765,19 +2714,19 @@ class CreateEnginePlugin: statement, parameters, context, - executemany, - ): + executemany): self.log.info("Plugin logged cursor event: %s", statement) + + Plugins are registered using entry points in a similar way as that of dialects:: - entry_points = { - "sqlalchemy.plugins": [ - "log_cursor_plugin = myapp.plugins:LogCursorEventsPlugin" + entry_points={ + 'sqlalchemy.plugins': [ + 'log_cursor_plugin = myapp.plugins:LogCursorEventsPlugin' ] - } A plugin that uses the above names would be invoked from a database URL as in:: @@ -2794,16 +2743,15 @@ class CreateEnginePlugin: in the URL:: engine = create_engine( - "mysql+pymysql://scott:tiger@localhost/test?" - "plugin=plugin_one&plugin=plugin_twp&plugin=plugin_three" - ) + "mysql+pymysql://scott:tiger@localhost/test?" + "plugin=plugin_one&plugin=plugin_twp&plugin=plugin_three") The plugin names may also be passed directly to :func:`_sa.create_engine` using the :paramref:`_sa.create_engine.plugins` argument:: engine = create_engine( - "mysql+pymysql://scott:tiger@localhost/test", plugins=["myplugin"] - ) + "mysql+pymysql://scott:tiger@localhost/test", + plugins=["myplugin"]) .. versionadded:: 1.2.3 plugin names can also be specified to :func:`_sa.create_engine` as a list @@ -2825,9 +2773,9 @@ class CreateEnginePlugin: class MyPlugin(CreateEnginePlugin): def __init__(self, url, kwargs): - self.my_argument_one = url.query["my_argument_one"] - self.my_argument_two = url.query["my_argument_two"] - self.my_argument_three = kwargs.pop("my_argument_three", None) + self.my_argument_one = url.query['my_argument_one'] + self.my_argument_two = url.query['my_argument_two'] + self.my_argument_three = kwargs.pop('my_argument_three', None) def update_url(self, url): return url.difference_update_query( @@ -2840,9 +2788,9 @@ class CreateEnginePlugin: from sqlalchemy import create_engine engine = create_engine( - "mysql+pymysql://scott:tiger@localhost/test?" - "plugin=myplugin&my_argument_one=foo&my_argument_two=bar", - my_argument_three="bat", + "mysql+pymysql://scott:tiger@localhost/test?" + "plugin=myplugin&my_argument_one=foo&my_argument_two=bar", + my_argument_three='bat' ) .. 
versionchanged:: 1.4 @@ -2861,15 +2809,15 @@ class CreateEnginePlugin: def __init__(self, url, kwargs): if hasattr(CreateEnginePlugin, "update_url"): # detect the 1.4 API - self.my_argument_one = url.query["my_argument_one"] - self.my_argument_two = url.query["my_argument_two"] + self.my_argument_one = url.query['my_argument_one'] + self.my_argument_two = url.query['my_argument_two'] else: # detect the 1.3 and earlier API - mutate the # URL directly - self.my_argument_one = url.query.pop("my_argument_one") - self.my_argument_two = url.query.pop("my_argument_two") + self.my_argument_one = url.query.pop('my_argument_one') + self.my_argument_two = url.query.pop('my_argument_two') - self.my_argument_three = kwargs.pop("my_argument_three", None) + self.my_argument_three = kwargs.pop('my_argument_three', None) def update_url(self, url): # this method is only called in the 1.4 version @@ -3044,9 +2992,6 @@ class ExecutionContext: inline SQL expression value was fired off. Applies to inserts and updates.""" - execution_options: _ExecuteOptions - """Execution options associated with the current statement execution""" - @classmethod def _init_ddl( cls, @@ -3421,7 +3366,7 @@ class AdaptedConnection: __slots__ = ("_connection",) - _connection: AsyncIODBAPIConnection + _connection: Any @property def driver_connection(self) -> Any: @@ -3440,14 +3385,11 @@ class AdaptedConnection: engine = create_async_engine(...) - @event.listens_for(engine.sync_engine, "connect") - def register_custom_types( - dbapi_connection, # ... - ): + def register_custom_types(dbapi_connection, ...): dbapi_connection.run_async( lambda connection: connection.set_type_codec( - "MyCustomType", encoder, decoder, ... + 'MyCustomType', encoder, decoder, ... ) ) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/mock.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/mock.py index a96af36..618ea1d 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/mock.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/mock.py @@ -1,5 +1,5 @@ # engine/mock.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -27,9 +27,10 @@ if typing.TYPE_CHECKING: from .interfaces import Dialect from .url import URL from ..sql.base import Executable - from ..sql.ddl import InvokeDDLBase + from ..sql.ddl import SchemaDropper + from ..sql.ddl import SchemaGenerator from ..sql.schema import HasSchemaAttr - from ..sql.visitors import Visitable + from ..sql.schema import SchemaItem class MockConnection: @@ -52,14 +53,12 @@ class MockConnection: def _run_ddl_visitor( self, - visitorcallable: Type[InvokeDDLBase], - element: Visitable, + visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]], + element: SchemaItem, **kwargs: Any, ) -> None: kwargs["checkfirst"] = False - visitorcallable( - dialect=self.dialect, connection=self, **kwargs - ).traverse_single(element) + visitorcallable(self.dialect, self, **kwargs).traverse_single(element) def execute( self, @@ -91,12 +90,10 @@ def create_mock_engine( from sqlalchemy import create_mock_engine - def dump(sql, *multiparams, **params): print(sql.compile(dialect=engine.dialect)) - - engine = create_mock_engine("postgresql+psycopg2://", dump) + engine = create_mock_engine('postgresql+psycopg2://', dump) metadata.create_all(engine, checkfirst=False) :param url: A string URL which typically needs to contain only the diff --git 
a/venv/lib/python3.12/site-packages/sqlalchemy/engine/processors.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/processors.py index b3f9330..c01d3b7 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/processors.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/processors.py @@ -1,5 +1,5 @@ -# engine/processors.py -# Copyright (C) 2010-2025 the SQLAlchemy authors and contributors +# sqlalchemy/processors.py +# Copyright (C) 2010-2023 the SQLAlchemy authors and contributors # # Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com # diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/reflection.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/reflection.py index 23009c6..6d2a8a2 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/reflection.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/reflection.py @@ -1,5 +1,5 @@ # engine/reflection.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -55,7 +55,6 @@ from .. import util from ..sql import operators from ..sql import schema as sa_schema from ..sql.cache_key import _ad_hoc_cache_key_from_args -from ..sql.elements import quoted_name from ..sql.elements import TextClause from ..sql.type_api import TypeEngine from ..sql.visitors import InternalTraversal @@ -90,16 +89,8 @@ def cache( exclude = {"info_cache", "unreflectable"} key = ( fn.__name__, - tuple( - (str(a), a.quote) if isinstance(a, quoted_name) else a - for a in args - if isinstance(a, str) - ), - tuple( - (k, (str(v), v.quote) if isinstance(v, quoted_name) else v) - for k, v in kw.items() - if k not in exclude - ), + tuple(a for a in args if isinstance(a, str)), + tuple((k, v) for k, v in kw.items() if k not in exclude), ) ret: _R = info_cache.get(key) if ret is None: @@ -193,8 +184,7 @@ class Inspector(inspection.Inspectable["Inspector"]): or a :class:`_engine.Connection`:: from sqlalchemy import inspect, create_engine - - engine = create_engine("...") + engine = create_engine('...') insp = inspect(engine) Where above, the :class:`~sqlalchemy.engine.interfaces.Dialect` associated @@ -631,7 +621,7 @@ class Inspector(inspection.Inspectable["Inspector"]): r"""Return a list of temporary table names for the current bind. This method is unsupported by most dialects; currently - only Oracle Database, PostgreSQL and SQLite implements it. + only Oracle, PostgreSQL and SQLite implements it. :param \**kw: Additional keyword argument to pass to the dialect specific implementation. See the documentation of the dialect @@ -667,7 +657,7 @@ class Inspector(inspection.Inspectable["Inspector"]): given name was created. This currently includes some options that apply to MySQL and Oracle - Database tables. + tables. :param table_name: string name of the table. For special quoting, use :class:`.quoted_name`. 
@@ -1493,9 +1483,9 @@ class Inspector(inspection.Inspectable["Inspector"]): from sqlalchemy import create_engine, MetaData, Table from sqlalchemy import inspect - engine = create_engine("...") + engine = create_engine('...') meta = MetaData() - user_table = Table("user", meta) + user_table = Table('user', meta) insp = inspect(engine) insp.reflect_table(user_table, None) @@ -1714,12 +1704,9 @@ class Inspector(inspection.Inspectable["Inspector"]): if pk in cols_by_orig_name and pk not in exclude_columns ] - # update pk constraint name, comment and dialect_kwargs + # update pk constraint name and comment table.primary_key.name = pk_cons.get("name") table.primary_key.comment = pk_cons.get("comment", None) - dialect_options = pk_cons.get("dialect_options") - if dialect_options: - table.primary_key.dialect_kwargs.update(dialect_options) # tell the PKConstraint to re-initialize # its column collection @@ -1856,7 +1843,7 @@ class Inspector(inspection.Inspectable["Inspector"]): if not expressions: util.warn( f"Skipping {flavor} {name!r} because key " - f"{index + 1} reflected as None but no " + f"{index+1} reflected as None but no " "'expressions' were returned" ) break diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/result.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/result.py index c54a965..132ae88 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/result.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/result.py @@ -1,5 +1,5 @@ # engine/result.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -52,11 +52,11 @@ else: from sqlalchemy.cyextension.resultproxy import tuplegetter as tuplegetter if typing.TYPE_CHECKING: - from ..sql.elements import SQLCoreOperations + from ..sql.schema import Column from ..sql.type_api import _ResultProcessorType -_KeyType = Union[str, "SQLCoreOperations[Any]"] -_KeyIndexType = Union[_KeyType, int] +_KeyType = Union[str, "Column[Any]"] +_KeyIndexType = Union[str, "Column[Any]", int] # is overridden in cursor using _CursorKeyMapRecType _KeyMapRecType = Any @@ -64,7 +64,7 @@ _KeyMapRecType = Any _KeyMapType = Mapping[_KeyType, _KeyMapRecType] -_RowData = Union[Row[Any], RowMapping, Any] +_RowData = Union[Row, RowMapping, Any] """A generic form of "row" that accommodates for the different kinds of "rows" that different result objects return, including row, row mapping, and scalar values""" @@ -82,7 +82,7 @@ across all the result types """ -_InterimSupportsScalarsRowType = Union[Row[Any], Any] +_InterimSupportsScalarsRowType = Union[Row, Any] _ProcessorsType = Sequence[Optional["_ResultProcessorType[Any]"]] _TupleGetterType = Callable[[Sequence[Any]], Sequence[Any]] @@ -116,7 +116,8 @@ class ResultMetaData: @overload def _key_fallback( self, key: Any, err: Optional[Exception], raiseerr: Literal[True] = ... - ) -> NoReturn: ... + ) -> NoReturn: + ... @overload def _key_fallback( @@ -124,12 +125,14 @@ class ResultMetaData: key: Any, err: Optional[Exception], raiseerr: Literal[False] = ..., - ) -> None: ... + ) -> None: + ... @overload def _key_fallback( self, key: Any, err: Optional[Exception], raiseerr: bool = ... - ) -> Optional[NoReturn]: ... + ) -> Optional[NoReturn]: + ... 
def _key_fallback( self, key: Any, err: Optional[Exception], raiseerr: bool = True @@ -326,6 +329,9 @@ class SimpleResultMetaData(ResultMetaData): _tuplefilter=_tuplefilter, ) + def _contains(self, value: Any, row: Row[Any]) -> bool: + return value in row._data + def _index_for_key(self, key: Any, raiseerr: bool = True) -> int: if int in key.__class__.__mro__: key = self._keys[key] @@ -722,21 +728,14 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): return manyrows - @overload - def _only_one_row( - self: ResultInternal[Row[Any]], - raise_for_second_row: bool, - raise_for_none: bool, - scalar: Literal[True], - ) -> Any: ... - @overload def _only_one_row( self, raise_for_second_row: bool, raise_for_none: Literal[True], scalar: bool, - ) -> _R: ... + ) -> _R: + ... @overload def _only_one_row( @@ -744,7 +743,8 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): raise_for_second_row: bool, raise_for_none: bool, scalar: bool, - ) -> Optional[_R]: ... + ) -> Optional[_R]: + ... def _only_one_row( self, @@ -817,6 +817,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): "was required" ) else: + next_row = _NO_ROW # if we checked for second row then that would have # closed us :) self._soft_close(hard=True) @@ -1106,15 +1107,17 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): statement = select(table.c.x, table.c.y, table.c.z) result = connection.execute(statement) - for z, y in result.columns("z", "y"): - ... + for z, y in result.columns('z', 'y'): + # ... + Example of using the column objects from the statement itself:: for z, y in result.columns( - statement.selected_columns.c.z, statement.selected_columns.c.y + statement.selected_columns.c.z, + statement.selected_columns.c.y ): - ... + # ... .. versionadded:: 1.4 @@ -1129,15 +1132,18 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): return self._column_slices(col_expressions) @overload - def scalars(self: Result[Tuple[_T]]) -> ScalarResult[_T]: ... + def scalars(self: Result[Tuple[_T]]) -> ScalarResult[_T]: + ... @overload def scalars( self: Result[Tuple[_T]], index: Literal[0] - ) -> ScalarResult[_T]: ... + ) -> ScalarResult[_T]: + ... @overload - def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: ... + def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: + ... def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: """Return a :class:`_engine.ScalarResult` filtering object which @@ -1346,7 +1352,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): def fetchmany(self, size: Optional[int] = None) -> Sequence[Row[_TP]]: """Fetch many rows. - When all rows are exhausted, returns an empty sequence. + When all rows are exhausted, returns an empty list. This method is provided for backwards compatibility with SQLAlchemy 1.x.x. @@ -1354,7 +1360,7 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): To fetch rows in groups, use the :meth:`_engine.Result.partitions` method. - :return: a sequence of :class:`_engine.Row` objects. + :return: a list of :class:`_engine.Row` objects. .. seealso:: @@ -1365,14 +1371,14 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): return self._manyrow_getter(self, size) def all(self) -> Sequence[Row[_TP]]: - """Return all rows in a sequence. + """Return all rows in a list. Closes the result set after invocation. Subsequent invocations - will return an empty sequence. + will return an empty list. .. versionadded:: 1.4 - :return: a sequence of :class:`_engine.Row` objects. + :return: a list of :class:`_engine.Row` objects. .. 
seealso:: @@ -1448,20 +1454,22 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): ) @overload - def scalar_one(self: Result[Tuple[_T]]) -> _T: ... + def scalar_one(self: Result[Tuple[_T]]) -> _T: + ... @overload - def scalar_one(self) -> Any: ... + def scalar_one(self) -> Any: + ... def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. This is equivalent to calling :meth:`_engine.Result.scalars` and - then :meth:`_engine.ScalarResult.one`. + then :meth:`_engine.Result.one`. .. seealso:: - :meth:`_engine.ScalarResult.one` + :meth:`_engine.Result.one` :meth:`_engine.Result.scalars` @@ -1471,20 +1479,22 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): ) @overload - def scalar_one_or_none(self: Result[Tuple[_T]]) -> Optional[_T]: ... + def scalar_one_or_none(self: Result[Tuple[_T]]) -> Optional[_T]: + ... @overload - def scalar_one_or_none(self) -> Optional[Any]: ... + def scalar_one_or_none(self) -> Optional[Any]: + ... def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one scalar result or ``None``. This is equivalent to calling :meth:`_engine.Result.scalars` and - then :meth:`_engine.ScalarResult.one_or_none`. + then :meth:`_engine.Result.one_or_none`. .. seealso:: - :meth:`_engine.ScalarResult.one_or_none` + :meth:`_engine.Result.one_or_none` :meth:`_engine.Result.scalars` @@ -1496,8 +1506,8 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): def one(self) -> Row[_TP]: """Return exactly one row or raise an exception. - Raises :class:`_exc.NoResultFound` if the result returns no - rows, or :class:`_exc.MultipleResultsFound` if multiple rows + Raises :class:`.NoResultFound` if the result returns no + rows, or :class:`.MultipleResultsFound` if multiple rows would be returned. .. note:: This method returns one **row**, e.g. tuple, by default. @@ -1527,10 +1537,12 @@ class Result(_WithKeys, ResultInternal[Row[_TP]]): ) @overload - def scalar(self: Result[Tuple[_T]]) -> Optional[_T]: ... + def scalar(self: Result[Tuple[_T]]) -> Optional[_T]: + ... @overload - def scalar(self) -> Any: ... + def scalar(self) -> Any: + ... def scalar(self) -> Any: """Fetch the first column of the first row, and close the result set. @@ -1764,7 +1776,7 @@ class ScalarResult(FilterResult[_R]): return self._manyrow_getter(self, size) def all(self) -> Sequence[_R]: - """Return all scalar values in a sequence. + """Return all scalar values in a list. Equivalent to :meth:`_engine.Result.all` except that scalar values, rather than :class:`_engine.Row` objects, @@ -1868,7 +1880,7 @@ class TupleResult(FilterResult[_R], util.TypingOnly): ... def all(self) -> Sequence[_R]: # noqa: A001 - """Return all scalar values in a sequence. + """Return all scalar values in a list. Equivalent to :meth:`_engine.Result.all` except that tuple values, rather than :class:`_engine.Row` objects, @@ -1877,9 +1889,11 @@ class TupleResult(FilterResult[_R], util.TypingOnly): """ ... - def __iter__(self) -> Iterator[_R]: ... + def __iter__(self) -> Iterator[_R]: + ... - def __next__(self) -> _R: ... + def __next__(self) -> _R: + ... def first(self) -> Optional[_R]: """Fetch the first object or ``None`` if no object is present. @@ -1913,20 +1927,22 @@ class TupleResult(FilterResult[_R], util.TypingOnly): ... @overload - def scalar_one(self: TupleResult[Tuple[_T]]) -> _T: ... + def scalar_one(self: TupleResult[Tuple[_T]]) -> _T: + ... @overload - def scalar_one(self) -> Any: ... + def scalar_one(self) -> Any: + ... 
def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. This is equivalent to calling :meth:`_engine.Result.scalars` - and then :meth:`_engine.ScalarResult.one`. + and then :meth:`_engine.Result.one`. .. seealso:: - :meth:`_engine.ScalarResult.one` + :meth:`_engine.Result.one` :meth:`_engine.Result.scalars` @@ -1934,22 +1950,22 @@ class TupleResult(FilterResult[_R], util.TypingOnly): ... @overload - def scalar_one_or_none( - self: TupleResult[Tuple[_T]], - ) -> Optional[_T]: ... + def scalar_one_or_none(self: TupleResult[Tuple[_T]]) -> Optional[_T]: + ... @overload - def scalar_one_or_none(self) -> Optional[Any]: ... + def scalar_one_or_none(self) -> Optional[Any]: + ... def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one or no scalar result. This is equivalent to calling :meth:`_engine.Result.scalars` - and then :meth:`_engine.ScalarResult.one_or_none`. + and then :meth:`_engine.Result.one_or_none`. .. seealso:: - :meth:`_engine.ScalarResult.one_or_none` + :meth:`_engine.Result.one_or_none` :meth:`_engine.Result.scalars` @@ -1957,10 +1973,12 @@ class TupleResult(FilterResult[_R], util.TypingOnly): ... @overload - def scalar(self: TupleResult[Tuple[_T]]) -> Optional[_T]: ... + def scalar(self: TupleResult[Tuple[_T]]) -> Optional[_T]: + ... @overload - def scalar(self) -> Any: ... + def scalar(self) -> Any: + ... def scalar(self) -> Any: """Fetch the first column of the first row, and close the result @@ -2013,7 +2031,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]): return self def columns(self, *col_expressions: _KeyIndexType) -> Self: - """Establish the columns that should be returned in each row.""" + r"""Establish the columns that should be returned in each row.""" return self._column_slices(col_expressions) def partitions( @@ -2068,7 +2086,7 @@ class MappingResult(_WithKeys, FilterResult[RowMapping]): return self._manyrow_getter(self, size) def all(self) -> Sequence[RowMapping]: - """Return all scalar values in a sequence. + """Return all scalar values in a list. Equivalent to :meth:`_engine.Result.all` except that :class:`_engine.RowMapping` values, rather than :class:`_engine.Row` diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/row.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/row.py index da7ae9a..9017537 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/row.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/row.py @@ -1,5 +1,5 @@ # engine/row.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -213,12 +213,15 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]): if TYPE_CHECKING: @overload - def __getitem__(self, index: int) -> Any: ... + def __getitem__(self, index: int) -> Any: + ... @overload - def __getitem__(self, index: slice) -> Sequence[Any]: ... + def __getitem__(self, index: slice) -> Sequence[Any]: + ... - def __getitem__(self, index: Union[int, slice]) -> Any: ... + def __getitem__(self, index: Union[int, slice]) -> Any: + ... 
def __lt__(self, other: Any) -> bool: return self._op(other, operator.lt) @@ -293,8 +296,8 @@ class ROMappingView(ABC): def __init__( self, mapping: Mapping["_KeyType", Any], items: Sequence[Any] ): - self._mapping = mapping # type: ignore[misc] - self._items = items # type: ignore[misc] + self._mapping = mapping + self._items = items def __len__(self) -> int: return len(self._items) @@ -318,11 +321,11 @@ class ROMappingView(ABC): class ROMappingKeysValuesView( ROMappingView, typing.KeysView["_KeyType"], typing.ValuesView[Any] ): - __slots__ = ("_items",) # mapping slot is provided by KeysView + __slots__ = ("_items",) class ROMappingItemsView(ROMappingView, typing.ItemsView["_KeyType", Any]): - __slots__ = ("_items",) # mapping slot is provided by ItemsView + __slots__ = ("_items",) class RowMapping(BaseRow, typing.Mapping["_KeyType", Any]): @@ -340,11 +343,12 @@ class RowMapping(BaseRow, typing.Mapping["_KeyType", Any]): as iteration of keys, values, and items:: for row in result: - if "a" in row._mapping: - print("Column 'a': %s" % row._mapping["a"]) + if 'a' in row._mapping: + print("Column 'a': %s" % row._mapping['a']) print("Column b: %s" % row._mapping[table.c.b]) + .. versionadded:: 1.4 The :class:`.RowMapping` object replaces the mapping-like access previously provided by a database result row, which now seeks to behave mostly like a named tuple. @@ -355,7 +359,8 @@ class RowMapping(BaseRow, typing.Mapping["_KeyType", Any]): if TYPE_CHECKING: - def __getitem__(self, key: _KeyType) -> Any: ... + def __getitem__(self, key: _KeyType) -> Any: + ... else: __getitem__ = BaseRow._get_by_key_impl_mapping diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/strategies.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/strategies.py index b4b8077..f884f20 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/strategies.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/strategies.py @@ -1,11 +1,14 @@ # engine/strategies.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Deprecated mock engine strategy used by Alembic.""" +"""Deprecated mock engine strategy used by Alembic. + + +""" from __future__ import annotations diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/url.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/url.py index 20079a6..5cf5ec7 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/url.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/url.py @@ -1,5 +1,5 @@ # engine/url.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -32,7 +32,6 @@ from typing import Tuple from typing import Type from typing import Union from urllib.parse import parse_qsl -from urllib.parse import quote from urllib.parse import quote_plus from urllib.parse import unquote @@ -122,9 +121,7 @@ class URL(NamedTuple): for keys and either strings or tuples of strings for values, e.g.:: >>> from sqlalchemy.engine import make_url - >>> url = make_url( - ... "postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt" - ... 
) + >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt") >>> url.query immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': '/path/to/crt'}) @@ -173,11 +170,6 @@ class URL(NamedTuple): :param password: database password. Is typically a string, but may also be an object that can be stringified with ``str()``. - .. note:: The password string should **not** be URL encoded when - passed as an argument to :meth:`_engine.URL.create`; the string - should contain the password characters exactly as they would be - typed. - .. note:: A password-producing object will be stringified only **once** per :class:`_engine.Engine` object. For dynamic password generation per connect, see :ref:`engines_dynamic_tokens`. @@ -255,12 +247,14 @@ class URL(NamedTuple): @overload def _assert_value( val: str, - ) -> str: ... + ) -> str: + ... @overload def _assert_value( val: Sequence[str], - ) -> Union[str, Tuple[str, ...]]: ... + ) -> Union[str, Tuple[str, ...]]: + ... def _assert_value( val: Union[str, Sequence[str]], @@ -373,9 +367,7 @@ class URL(NamedTuple): >>> from sqlalchemy.engine import make_url >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname") - >>> url = url.update_query_string( - ... "alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt" - ... ) + >>> url = url.update_query_string("alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt") >>> str(url) 'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt' @@ -411,13 +403,7 @@ class URL(NamedTuple): >>> from sqlalchemy.engine import make_url >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname") - >>> url = url.update_query_pairs( - ... [ - ... ("alt_host", "host1"), - ... ("alt_host", "host2"), - ... ("ssl_cipher", "/path/to/crt"), - ... ] - ... ) + >>> url = url.update_query_pairs([("alt_host", "host1"), ("alt_host", "host2"), ("ssl_cipher", "/path/to/crt")]) >>> str(url) 'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt' @@ -499,9 +485,7 @@ class URL(NamedTuple): >>> from sqlalchemy.engine import make_url >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname") - >>> url = url.update_query_dict( - ... {"alt_host": ["host1", "host2"], "ssl_cipher": "/path/to/crt"} - ... ) + >>> url = url.update_query_dict({"alt_host": ["host1", "host2"], "ssl_cipher": "/path/to/crt"}) >>> str(url) 'postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt' @@ -539,14 +523,14 @@ class URL(NamedTuple): E.g.:: - url = url.difference_update_query(["foo", "bar"]) + url = url.difference_update_query(['foo', 'bar']) Equivalent to using :meth:`_engine.URL.set` as follows:: url = url.set( query={ key: url.query[key] - for key in set(url.query).difference(["foo", "bar"]) + for key in set(url.query).difference(['foo', 'bar']) } ) @@ -595,9 +579,7 @@ class URL(NamedTuple): >>> from sqlalchemy.engine import make_url - >>> url = make_url( - ... "postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt" - ... 
) + >>> url = make_url("postgresql+psycopg2://user:pass@host/dbname?alt_host=host1&alt_host=host2&ssl_cipher=%2Fpath%2Fto%2Fcrt") >>> url.query immutabledict({'alt_host': ('host1', 'host2'), 'ssl_cipher': '/path/to/crt'}) >>> url.normalized_query @@ -639,17 +621,17 @@ class URL(NamedTuple): """ s = self.drivername + "://" if self.username is not None: - s += quote(self.username, safe=" +") + s += _sqla_url_quote(self.username) if self.password is not None: s += ":" + ( "***" if hide_password - else quote(str(self.password), safe=" +") + else _sqla_url_quote(str(self.password)) ) s += "@" if self.host is not None: if ":" in self.host: - s += f"[{self.host}]" + s += "[%s]" % self.host else: s += self.host if self.port is not None: @@ -660,7 +642,7 @@ class URL(NamedTuple): keys = list(self.query) keys.sort() s += "?" + "&".join( - f"{quote_plus(k)}={quote_plus(element)}" + "%s=%s" % (quote_plus(k), quote_plus(element)) for k in keys for element in util.to_list(self.query[k]) ) @@ -903,10 +885,10 @@ def _parse_url(name: str) -> URL: components["query"] = query if components["username"] is not None: - components["username"] = unquote(components["username"]) + components["username"] = _sqla_url_unquote(components["username"]) if components["password"] is not None: - components["password"] = unquote(components["password"]) + components["password"] = _sqla_url_unquote(components["password"]) ipv4host = components.pop("ipv4host") ipv6host = components.pop("ipv6host") @@ -920,5 +902,12 @@ def _parse_url(name: str) -> URL: else: raise exc.ArgumentError( - "Could not parse SQLAlchemy URL from given URL string" + "Could not parse SQLAlchemy URL from string '%s'" % name ) + + +def _sqla_url_quote(text: str) -> str: + return re.sub(r"[:@/]", lambda m: "%%%X" % ord(m.group(0)), text) + + +_sqla_url_unquote = unquote diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/engine/util.py b/venv/lib/python3.12/site-packages/sqlalchemy/engine/util.py index e499efa..9b147a7 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/engine/util.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/engine/util.py @@ -1,5 +1,5 @@ # engine/util.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -17,7 +17,6 @@ from .. import exc from .. import util from ..util._has_cy import HAS_CYEXTENSION from ..util.typing import Protocol -from ..util.typing import Self if typing.TYPE_CHECKING or not HAS_CYEXTENSION: from ._py_util import _distill_params_20 as _distill_params_20 @@ -114,7 +113,7 @@ class TransactionalContext: "before emitting further commands." 
) - def __enter__(self) -> Self: + def __enter__(self) -> TransactionalContext: subject = self._get_subject() # none for outer transaction, may be non-None for nested diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/event/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/event/__init__.py index 309b7bd..20a20d1 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/event/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/event/__init__.py @@ -1,5 +1,5 @@ # event/__init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/event/api.py b/venv/lib/python3.12/site-packages/sqlalchemy/event/api.py index 01dd4bd..bb1dbea 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/event/api.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/event/api.py @@ -1,11 +1,13 @@ # event/api.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Public API functions for the event system.""" +"""Public API functions for the event system. + +""" from __future__ import annotations from typing import Any @@ -49,14 +51,15 @@ def listen( from sqlalchemy import event from sqlalchemy.schema import UniqueConstraint - def unique_constraint_name(const, table): - const.name = "uq_%s_%s" % (table.name, list(const.columns)[0].name) - - + const.name = "uq_%s_%s" % ( + table.name, + list(const.columns)[0].name + ) event.listen( - UniqueConstraint, "after_parent_attach", unique_constraint_name - ) + UniqueConstraint, + "after_parent_attach", + unique_constraint_name) :param bool insert: The default behavior for event handlers is to append the decorated user defined function to an internal list of registered @@ -129,17 +132,19 @@ def listens_for( The :func:`.listens_for` decorator is part of the primary interface for the SQLAlchemy event system, documented at :ref:`event_toplevel`. - This function generally shares the same kwargs as :func:`.listen`. + This function generally shares the same kwargs as :func:`.listens`. e.g.:: from sqlalchemy import event from sqlalchemy.schema import UniqueConstraint - @event.listens_for(UniqueConstraint, "after_parent_attach") def unique_constraint_name(const, table): - const.name = "uq_%s_%s" % (table.name, list(const.columns)[0].name) + const.name = "uq_%s_%s" % ( + table.name, + list(const.columns)[0].name + ) A given function can also be invoked for only the first invocation of the event using the ``once`` argument:: @@ -148,6 +153,7 @@ def listens_for( def on_config(): do_config() + .. warning:: The ``once`` argument does not imply automatic de-registration of the listener function after it has been invoked a first time; a listener entry will remain associated with the target object. @@ -183,7 +189,6 @@ def remove(target: Any, identifier: str, fn: Callable[..., Any]) -> None: def my_listener_function(*arg): pass - # ... 
it's removed like this event.remove(SomeMappedClass, "before_insert", my_listener_function) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/event/attr.py b/venv/lib/python3.12/site-packages/sqlalchemy/event/attr.py index ec5d582..0aa3419 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/event/attr.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/event/attr.py @@ -1,5 +1,5 @@ # event/attr.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -391,23 +391,20 @@ class _EmptyListener(_InstanceLevelDispatch[_ET]): class _MutexProtocol(Protocol): - def __enter__(self) -> bool: ... + def __enter__(self) -> bool: + ... def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], - ) -> Optional[bool]: ... + ) -> Optional[bool]: + ... class _CompoundListener(_InstanceLevelDispatch[_ET]): - __slots__ = ( - "_exec_once_mutex", - "_exec_once", - "_exec_w_sync_once", - "_is_asyncio", - ) + __slots__ = "_exec_once_mutex", "_exec_once", "_exec_w_sync_once" _exec_once_mutex: _MutexProtocol parent_listeners: Collection[_ListenerFnType] @@ -415,18 +412,11 @@ class _CompoundListener(_InstanceLevelDispatch[_ET]): _exec_once: bool _exec_w_sync_once: bool - def __init__(self, *arg: Any, **kw: Any): - super().__init__(*arg, **kw) - self._is_asyncio = False - def _set_asyncio(self) -> None: - self._is_asyncio = True + self._exec_once_mutex = AsyncAdaptedLock() def _memoized_attr__exec_once_mutex(self) -> _MutexProtocol: - if self._is_asyncio: - return AsyncAdaptedLock() - else: - return threading.Lock() + return threading.Lock() def _exec_once_impl( self, retry_on_exception: bool, *args: Any, **kw: Any @@ -535,7 +525,6 @@ class _ListenerCollection(_CompoundListener[_ET]): propagate: Set[_ListenerFnType] def __init__(self, parent: _ClsLevelDispatch[_ET], target_cls: Type[_ET]): - super().__init__() if target_cls not in parent._clslevel: parent.update_subclass(target_cls) self._exec_once = False @@ -575,9 +564,6 @@ class _ListenerCollection(_CompoundListener[_ET]): existing_listeners.extend(other_listeners) - if other._is_asyncio: - self._set_asyncio() - to_associate = other.propagate.union(other_listeners) registry._stored_in_collection_multi(self, other, to_associate) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/event/base.py b/venv/lib/python3.12/site-packages/sqlalchemy/event/base.py index 66dc129..f92b2ed 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/event/base.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/event/base.py @@ -1,5 +1,5 @@ # event/base.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -42,9 +42,9 @@ from .registry import _EventKey from .. import util from ..util.typing import Literal -_registrars: MutableMapping[str, List[Type[_HasEventsDispatch[Any]]]] = ( - util.defaultdict(list) -) +_registrars: MutableMapping[ + str, List[Type[_HasEventsDispatch[Any]]] +] = util.defaultdict(list) def _is_event_name(name: str) -> bool: @@ -191,8 +191,13 @@ class _Dispatch(_DispatchCommon[_ET]): :class:`._Dispatch` objects. 
""" - assert "_joined_dispatch_cls" in self.__class__.__dict__ - + if "_joined_dispatch_cls" not in self.__class__.__dict__: + cls = type( + "Joined%s" % self.__class__.__name__, + (_JoinedDispatcher,), + {"__slots__": self._event_names}, + ) + self.__class__._joined_dispatch_cls = cls return self._joined_dispatch_cls(self, other) def __reduce__(self) -> Union[str, Tuple[Any, ...]]: @@ -235,7 +240,8 @@ class _HasEventsDispatch(Generic[_ET]): if typing.TYPE_CHECKING: - def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: ... + def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: + ... def __init_subclass__(cls) -> None: """Intercept new Event subclasses and create associated _Dispatch @@ -323,51 +329,6 @@ class _HasEventsDispatch(Generic[_ET]): else: dispatch_target_cls.dispatch = dispatcher(cls) - klass = type( - "Joined%s" % dispatch_cls.__name__, - (_JoinedDispatcher,), - {"__slots__": event_names}, - ) - dispatch_cls._joined_dispatch_cls = klass - - # establish pickle capability by adding it to this module - globals()[klass.__name__] = klass - - -class _JoinedDispatcher(_DispatchCommon[_ET]): - """Represent a connection between two _Dispatch objects.""" - - __slots__ = "local", "parent", "_instance_cls" - - local: _DispatchCommon[_ET] - parent: _DispatchCommon[_ET] - _instance_cls: Optional[Type[_ET]] - - def __init__( - self, local: _DispatchCommon[_ET], parent: _DispatchCommon[_ET] - ): - self.local = local - self.parent = parent - self._instance_cls = self.local._instance_cls - - def __reduce__(self) -> Any: - return (self.__class__, (self.local, self.parent)) - - def __getattr__(self, name: str) -> _JoinedListener[_ET]: - # Assign _JoinedListeners as attributes on demand - # to reduce startup time for new dispatch objects. - ls = getattr(self.local, name) - jl = _JoinedListener(self.parent, ls.name, ls) - setattr(self, ls.name, jl) - return jl - - def _listen(self, event_key: _EventKey[_ET], **kw: Any) -> None: - return self.parent._listen(event_key, **kw) - - @property - def _events(self) -> Type[_HasEventsDispatch[_ET]]: - return self.parent._events - class Events(_HasEventsDispatch[_ET]): """Define event listening functions for a particular target type.""" @@ -380,11 +341,9 @@ class Events(_HasEventsDispatch[_ET]): return all(isinstance(target.dispatch, t) for t in types) def dispatch_parent_is(t: Type[Any]) -> bool: - parent = cast("_JoinedDispatcher[_ET]", target.dispatch).parent - while isinstance(parent, _JoinedDispatcher): - parent = cast("_JoinedDispatcher[_ET]", parent).parent - - return isinstance(parent, t) + return isinstance( + cast("_JoinedDispatcher[_ET]", target.dispatch).parent, t + ) # Mapper, ClassManager, Session override this to # also accept classes, scoped_sessions, sessionmakers, etc. @@ -424,6 +383,38 @@ class Events(_HasEventsDispatch[_ET]): cls.dispatch._clear() +class _JoinedDispatcher(_DispatchCommon[_ET]): + """Represent a connection between two _Dispatch objects.""" + + __slots__ = "local", "parent", "_instance_cls" + + local: _DispatchCommon[_ET] + parent: _DispatchCommon[_ET] + _instance_cls: Optional[Type[_ET]] + + def __init__( + self, local: _DispatchCommon[_ET], parent: _DispatchCommon[_ET] + ): + self.local = local + self.parent = parent + self._instance_cls = self.local._instance_cls + + def __getattr__(self, name: str) -> _JoinedListener[_ET]: + # Assign _JoinedListeners as attributes on demand + # to reduce startup time for new dispatch objects. 
+ ls = getattr(self.local, name) + jl = _JoinedListener(self.parent, ls.name, ls) + setattr(self, ls.name, jl) + return jl + + def _listen(self, event_key: _EventKey[_ET], **kw: Any) -> None: + return self.parent._listen(event_key, **kw) + + @property + def _events(self) -> Type[_HasEventsDispatch[_ET]]: + return self.parent._events + + class dispatcher(Generic[_ET]): """Descriptor used by target classes to deliver the _Dispatch class at the class level @@ -439,10 +430,12 @@ class dispatcher(Generic[_ET]): @overload def __get__( self, obj: Literal[None], cls: Type[Any] - ) -> Type[_Dispatch[_ET]]: ... + ) -> Type[_Dispatch[_ET]]: + ... @overload - def __get__(self, obj: Any, cls: Type[Any]) -> _DispatchCommon[_ET]: ... + def __get__(self, obj: Any, cls: Type[Any]) -> _DispatchCommon[_ET]: + ... def __get__(self, obj: Any, cls: Type[Any]) -> Any: if obj is None: diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/event/legacy.py b/venv/lib/python3.12/site-packages/sqlalchemy/event/legacy.py index e60fd9a..f3a7d04 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/event/legacy.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/event/legacy.py @@ -1,5 +1,5 @@ # event/legacy.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -147,9 +147,9 @@ def _standard_listen_example( ) text %= { - "current_since": ( - " (arguments as of %s)" % current_since if current_since else "" - ), + "current_since": " (arguments as of %s)" % current_since + if current_since + else "", "event_name": fn.__name__, "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "", "named_event_arguments": ", ".join(dispatch_collection.arg_names), @@ -177,9 +177,9 @@ def _legacy_listen_examples( % { "since": since, "event_name": fn.__name__, - "has_kw_arguments": ( - " **kw" if dispatch_collection.has_kw else "" - ), + "has_kw_arguments": " **kw" + if dispatch_collection.has_kw + else "", "named_event_arguments": ", ".join(args), "sample_target": sample_target, } diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/event/registry.py b/venv/lib/python3.12/site-packages/sqlalchemy/event/registry.py index d7e4b32..fb2fed8 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/event/registry.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/event/registry.py @@ -1,5 +1,5 @@ # event/registry.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -66,9 +66,9 @@ _RefCollectionToListenerType = Dict[ "weakref.ref[_ListenerFnType]", ] -_key_to_collection: Dict[_EventKeyTupleType, _RefCollectionToListenerType] = ( - collections.defaultdict(dict) -) +_key_to_collection: Dict[ + _EventKeyTupleType, _RefCollectionToListenerType +] = collections.defaultdict(dict) """ Given an original listen() argument, can locate all listener collections and the listener fn contained @@ -154,11 +154,7 @@ def _removed_from_collection( if owner_ref in _collection_to_key: listener_to_key = _collection_to_key[owner_ref] - # see #12216 - this guards against a removal that already occurred - # here. 
however, I cannot come up with a test that shows any negative - # side effects occurring from this removal happening, even though an - # event key may still be referenced from a clsleveldispatch here - listener_to_key.pop(listen_ref, None) + listener_to_key.pop(listen_ref) def _stored_in_collection_multi( diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/events.py b/venv/lib/python3.12/site-packages/sqlalchemy/events.py index ce83243..2f7b23d 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/events.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/events.py @@ -1,5 +1,5 @@ -# events.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sqlalchemy/events.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/exc.py b/venv/lib/python3.12/site-packages/sqlalchemy/exc.py index 71e5dd8..a5a66de 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/exc.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/exc.py @@ -1,5 +1,5 @@ -# exc.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sqlalchemy/exc.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -432,16 +432,14 @@ class DontWrapMixin: from sqlalchemy.exc import DontWrapMixin - class MyCustomException(Exception, DontWrapMixin): pass - class MySpecialType(TypeDecorator): impl = String def process_bind_param(self, value, dialect): - if value == "invalid": + if value == 'invalid': raise MyCustomException("invalid!") """ @@ -573,7 +571,8 @@ class DBAPIError(StatementError): connection_invalidated: bool = False, dialect: Optional[Dialect] = None, ismulti: Optional[bool] = None, - ) -> StatementError: ... + ) -> StatementError: + ... @overload @classmethod @@ -587,7 +586,8 @@ class DBAPIError(StatementError): connection_invalidated: bool = False, dialect: Optional[Dialect] = None, ismulti: Optional[bool] = None, - ) -> DontWrapMixin: ... + ) -> DontWrapMixin: + ... @overload @classmethod @@ -601,7 +601,8 @@ class DBAPIError(StatementError): connection_invalidated: bool = False, dialect: Optional[Dialect] = None, ismulti: Optional[bool] = None, - ) -> BaseException: ... + ) -> BaseException: + ... 
@classmethod def instance( diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/__init__.py index 2751bcf..e3af738 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/__init__.py @@ -1,5 +1,5 @@ # ext/__init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/associationproxy.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/associationproxy.py index d72cfc3..31df134 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/associationproxy.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/associationproxy.py @@ -1,5 +1,5 @@ # ext/associationproxy.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -98,8 +98,6 @@ def association_proxy( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, - hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 - dataclass_metadata: Union[_NoArg, Mapping[Any, Any], None] = _NoArg.NO_ARG, ) -> AssociationProxy[Any]: r"""Return a Python property implementing a view of a target attribute which references an attribute on members of the @@ -200,19 +198,6 @@ def association_proxy( .. versionadded:: 2.0.0b4 - :param hash: Specific to - :ref:`orm_declarative_native_dataclasses`, controls if this field - is included when generating the ``__hash__()`` method for the mapped - class. - - .. versionadded:: 2.0.36 - - :param dataclass_metadata: Specific to - :ref:`orm_declarative_native_dataclasses`, supplies metadata - to be attached to the generated dataclass field. - - .. versionadded:: 2.0.42 - :param info: optional, will be assigned to :attr:`.AssociationProxy.info` if present. @@ -252,14 +237,7 @@ def association_proxy( cascade_scalar_deletes=cascade_scalar_deletes, create_on_none_assignment=create_on_none_assignment, attribute_options=_AttributeOptions( - init, - repr, - default, - default_factory, - compare, - kw_only, - hash, - dataclass_metadata, + init, repr, default, default_factory, compare, kw_only ), ) @@ -276,39 +254,45 @@ class AssociationProxyExtensionType(InspectionAttrExtensionType): class _GetterProtocol(Protocol[_T_co]): - def __call__(self, instance: Any) -> _T_co: ... + def __call__(self, instance: Any) -> _T_co: + ... # mypy 0.990 we are no longer allowed to make this Protocol[_T_con] -class _SetterProtocol(Protocol): ... +class _SetterProtocol(Protocol): + ... class _PlainSetterProtocol(_SetterProtocol, Protocol[_T_con]): - def __call__(self, instance: Any, value: _T_con) -> None: ... + def __call__(self, instance: Any, value: _T_con) -> None: + ... class _DictSetterProtocol(_SetterProtocol, Protocol[_T_con]): - def __call__(self, instance: Any, key: Any, value: _T_con) -> None: ... + def __call__(self, instance: Any, key: Any, value: _T_con) -> None: + ... # mypy 0.990 we are no longer allowed to make this Protocol[_T_con] -class _CreatorProtocol(Protocol): ... +class _CreatorProtocol(Protocol): + ... class _PlainCreatorProtocol(_CreatorProtocol, Protocol[_T_con]): - def __call__(self, value: _T_con) -> Any: ... 
+ def __call__(self, value: _T_con) -> Any: + ... class _KeyCreatorProtocol(_CreatorProtocol, Protocol[_T_con]): - def __call__(self, key: Any, value: Optional[_T_con]) -> Any: ... + def __call__(self, key: Any, value: Optional[_T_con]) -> Any: + ... class _LazyCollectionProtocol(Protocol[_T]): def __call__( self, - ) -> Union[ - MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T] - ]: ... + ) -> Union[MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T]]: + ... class _GetSetFactoryProtocol(Protocol): @@ -316,7 +300,8 @@ class _GetSetFactoryProtocol(Protocol): self, collection_class: Optional[Type[Any]], assoc_instance: AssociationProxyInstance[Any], - ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: ... + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: + ... class _ProxyFactoryProtocol(Protocol): @@ -326,13 +311,15 @@ class _ProxyFactoryProtocol(Protocol): creator: _CreatorProtocol, value_attr: str, parent: AssociationProxyInstance[Any], - ) -> Any: ... + ) -> Any: + ... class _ProxyBulkSetProtocol(Protocol): def __call__( self, proxy: _AssociationCollection[Any], collection: Iterable[Any] - ) -> None: ... + ) -> None: + ... class _AssociationProxyProtocol(Protocol[_T]): @@ -350,15 +337,18 @@ class _AssociationProxyProtocol(Protocol[_T]): proxy_bulk_set: Optional[_ProxyBulkSetProtocol] @util.ro_memoized_property - def info(self) -> _InfoType: ... + def info(self) -> _InfoType: + ... def for_class( self, class_: Type[Any], obj: Optional[object] = None - ) -> AssociationProxyInstance[_T]: ... + ) -> AssociationProxyInstance[_T]: + ... def _default_getset( self, collection_class: Any - ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: ... + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: + ... class AssociationProxy( @@ -429,17 +419,18 @@ class AssociationProxy( self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS @overload - def __get__( - self, instance: Literal[None], owner: Literal[None] - ) -> Self: ... + def __get__(self, instance: Literal[None], owner: Literal[None]) -> Self: + ... @overload def __get__( self, instance: Literal[None], owner: Any - ) -> AssociationProxyInstance[_T]: ... + ) -> AssociationProxyInstance[_T]: + ... @overload - def __get__(self, instance: object, owner: Any) -> _T: ... + def __get__(self, instance: object, owner: Any) -> _T: + ... def __get__( self, instance: object, owner: Any @@ -472,7 +463,7 @@ class AssociationProxy( class User(Base): # ... - keywords = association_proxy("kws", "keyword") + keywords = association_proxy('kws', 'keyword') If we access this :class:`.AssociationProxy` from :attr:`_orm.Mapper.all_orm_descriptors`, and we want to view the @@ -792,9 +783,9 @@ class AssociationProxyInstance(SQLORMOperations[_T]): :attr:`.AssociationProxyInstance.remote_attr` attributes separately:: stmt = ( - select(Parent) - .join(Parent.proxied.local_attr) - .join(Parent.proxied.remote_attr) + select(Parent). + join(Parent.proxied.local_attr). + join(Parent.proxied.remote_attr) ) A future release may seek to provide a more succinct join pattern @@ -870,10 +861,12 @@ class AssociationProxyInstance(SQLORMOperations[_T]): return self.parent.info @overload - def get(self: _Self, obj: Literal[None]) -> _Self: ... + def get(self: _Self, obj: Literal[None]) -> _Self: + ... @overload - def get(self, obj: Any) -> _T: ... + def get(self, obj: Any) -> _T: + ... 
def get( self, obj: Any @@ -1096,7 +1089,7 @@ class AssociationProxyInstance(SQLORMOperations[_T]): and (not self._target_is_object or self._value_is_scalar) ): raise exc.InvalidRequestError( - "'any()' not implemented for scalar attributes. Use has()." + "'any()' not implemented for scalar " "attributes. Use has()." ) return self._criterion_exists( criterion=criterion, is_has=False, **kwargs @@ -1120,7 +1113,7 @@ class AssociationProxyInstance(SQLORMOperations[_T]): or (self._target_is_object and not self._value_is_scalar) ): raise exc.InvalidRequestError( - "'has()' not implemented for collections. Use any()." + "'has()' not implemented for collections. " "Use any()." ) return self._criterion_exists( criterion=criterion, is_has=True, **kwargs @@ -1439,10 +1432,12 @@ class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]): self.setter(object_, value) @overload - def __getitem__(self, index: int) -> _T: ... + def __getitem__(self, index: int) -> _T: + ... @overload - def __getitem__(self, index: slice) -> MutableSequence[_T]: ... + def __getitem__(self, index: slice) -> MutableSequence[_T]: + ... def __getitem__( self, index: Union[int, slice] @@ -1453,10 +1448,12 @@ class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]): return [self._get(member) for member in self.col[index]] @overload - def __setitem__(self, index: int, value: _T) -> None: ... + def __setitem__(self, index: int, value: _T) -> None: + ... @overload - def __setitem__(self, index: slice, value: Iterable[_T]) -> None: ... + def __setitem__(self, index: slice, value: Iterable[_T]) -> None: + ... def __setitem__( self, index: Union[int, slice], value: Union[_T, Iterable[_T]] @@ -1495,10 +1492,12 @@ class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]): self._set(self.col[i], item) @overload - def __delitem__(self, index: int) -> None: ... + def __delitem__(self, index: int) -> None: + ... @overload - def __delitem__(self, index: slice) -> None: ... + def __delitem__(self, index: slice) -> None: + ... def __delitem__(self, index: Union[slice, int]) -> None: del self.col[index] @@ -1625,9 +1624,8 @@ class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]): if typing.TYPE_CHECKING: # TODO: no idea how to do this without separate "stub" - def index( - self, value: Any, start: int = ..., stop: int = ... - ) -> int: ... + def index(self, value: Any, start: int = ..., stop: int = ...) -> int: + ... else: @@ -1703,10 +1701,12 @@ class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]): return repr(dict(self)) @overload - def get(self, __key: _KT) -> Optional[_VT]: ... + def get(self, __key: _KT) -> Optional[_VT]: + ... @overload - def get(self, __key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ... + def get(self, __key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: + ... def get( self, key: _KT, default: Optional[Union[_VT, _T]] = None @@ -1738,12 +1738,12 @@ class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]): return ValuesView(self) @overload - def pop(self, __key: _KT) -> _VT: ... + def pop(self, __key: _KT) -> _VT: + ... @overload - def pop( - self, __key: _KT, default: Union[_VT, _T] = ... - ) -> Union[_VT, _T]: ... + def pop(self, __key: _KT, default: Union[_VT, _T] = ...) -> Union[_VT, _T]: + ... 
def pop(self, __key: _KT, *arg: Any, **kw: Any) -> Union[_VT, _T]: member = self.col.pop(__key, *arg, **kw) @@ -1756,15 +1756,16 @@ class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]): @overload def update( self, __m: SupportsKeysAndGetItem[_KT, _VT], **kwargs: _VT - ) -> None: ... + ) -> None: + ... @overload - def update( - self, __m: Iterable[tuple[_KT, _VT]], **kwargs: _VT - ) -> None: ... + def update(self, __m: Iterable[tuple[_KT, _VT]], **kwargs: _VT) -> None: + ... @overload - def update(self, **kwargs: _VT) -> None: ... + def update(self, **kwargs: _VT) -> None: + ... def update(self, *a: Any, **kw: Any) -> None: up: Dict[_KT, _VT] = {} diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/__init__.py index 7d8a04b..8564db6 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/__init__.py @@ -1,5 +1,5 @@ # ext/asyncio/__init__.py -# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/base.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/base.py index 72a617f..251f521 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/base.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/base.py @@ -1,5 +1,5 @@ # ext/asyncio/base.py -# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -44,10 +44,12 @@ class ReversibleProxy(Generic[_PT]): __slots__ = ("__weakref__",) @overload - def _assign_proxied(self, target: _PT) -> _PT: ... + def _assign_proxied(self, target: _PT) -> _PT: + ... @overload - def _assign_proxied(self, target: None) -> None: ... + def _assign_proxied(self, target: None) -> None: + ... def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]: if target is not None: @@ -71,26 +73,28 @@ class ReversibleProxy(Generic[_PT]): cls._proxy_objects.pop(ref, None) @classmethod - def _regenerate_proxy_for_target( - cls, target: _PT, **additional_kw: Any - ) -> Self: + def _regenerate_proxy_for_target(cls, target: _PT) -> Self: raise NotImplementedError() @overload @classmethod def _retrieve_proxy_for_target( - cls, target: _PT, regenerate: Literal[True] = ..., **additional_kw: Any - ) -> Self: ... + cls, + target: _PT, + regenerate: Literal[True] = ..., + ) -> Self: + ... @overload @classmethod def _retrieve_proxy_for_target( - cls, target: _PT, regenerate: bool = True, **additional_kw: Any - ) -> Optional[Self]: ... + cls, target: _PT, regenerate: bool = True + ) -> Optional[Self]: + ... 
@classmethod def _retrieve_proxy_for_target( - cls, target: _PT, regenerate: bool = True, **additional_kw: Any + cls, target: _PT, regenerate: bool = True ) -> Optional[Self]: try: proxy_ref = cls._proxy_objects[weakref.ref(target)] @@ -102,7 +106,7 @@ class ReversibleProxy(Generic[_PT]): return proxy # type: ignore if regenerate: - return cls._regenerate_proxy_for_target(target, **additional_kw) + return cls._regenerate_proxy_for_target(target) else: return None @@ -178,7 +182,7 @@ class GeneratorStartableContext(StartableContext[_T_co]): # tell if we get the same exception back value = typ() try: - await self.gen.athrow(value) + await util.athrow(self.gen, typ, value, traceback) except StopAsyncIteration as exc: # Suppress StopIteration *unless* it's the same exception that # was passed to throw(). This prevents a StopIteration @@ -215,7 +219,7 @@ class GeneratorStartableContext(StartableContext[_T_co]): def asyncstartablecontext( - func: Callable[..., AsyncIterator[_T_co]], + func: Callable[..., AsyncIterator[_T_co]] ) -> Callable[..., GeneratorStartableContext[_T_co]]: """@asyncstartablecontext decorator. @@ -224,9 +228,7 @@ def asyncstartablecontext( ``@contextlib.asynccontextmanager`` supports, and the usage pattern is different as well. - Typical usage: - - .. sourcecode:: text + Typical usage:: @asyncstartablecontext async def some_async_generator(): diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/engine.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/engine.py index d4ecbda..bf968cc 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/engine.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/engine.py @@ -1,5 +1,5 @@ # ext/asyncio/engine.py -# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -41,8 +41,6 @@ from ...engine.base import NestedTransaction from ...engine.base import Transaction from ...exc import ArgumentError from ...util.concurrency import greenlet_spawn -from ...util.typing import Concatenate -from ...util.typing import ParamSpec if TYPE_CHECKING: from ...engine.cursor import CursorResult @@ -63,7 +61,6 @@ if TYPE_CHECKING: from ...sql.base import Executable from ...sql.selectable import TypedReturnsRows -_P = ParamSpec("_P") _T = TypeVar("_T", bound=Any) @@ -198,7 +195,6 @@ class AsyncConnection( method of :class:`_asyncio.AsyncEngine`:: from sqlalchemy.ext.asyncio import create_async_engine - engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname") async with engine.connect() as conn: @@ -255,7 +251,7 @@ class AsyncConnection( @classmethod def _regenerate_proxy_for_target( - cls, target: Connection, **additional_kw: Any # noqa: U100 + cls, target: Connection ) -> AsyncConnection: return AsyncConnection( AsyncEngine._retrieve_proxy_for_target(target.engine), target @@ -418,12 +414,13 @@ class AsyncConnection( yield_per: int = ..., insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., - preserve_rowcount: bool = False, **opt: Any, - ) -> AsyncConnection: ... + ) -> AsyncConnection: + ... @overload - async def execution_options(self, **opt: Any) -> AsyncConnection: ... + async def execution_options(self, **opt: Any) -> AsyncConnection: + ... 
async def execution_options(self, **opt: Any) -> AsyncConnection: r"""Set non-SQL options for the connection which take effect @@ -521,7 +518,8 @@ class AsyncConnection( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> GeneratorStartableContext[AsyncResult[_T]]: ... + ) -> GeneratorStartableContext[AsyncResult[_T]]: + ... @overload def stream( @@ -530,7 +528,8 @@ class AsyncConnection( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> GeneratorStartableContext[AsyncResult[Any]]: ... + ) -> GeneratorStartableContext[AsyncResult[Any]]: + ... @asyncstartablecontext async def stream( @@ -545,7 +544,7 @@ class AsyncConnection( E.g.:: - result = await conn.stream(stmt) + result = await conn.stream(stmt): async for row in result: print(f"{row}") @@ -574,11 +573,6 @@ class AsyncConnection( :meth:`.AsyncConnection.stream_scalars` """ - if not self.dialect.supports_server_side_cursors: - raise exc.InvalidRequestError( - "Cant use `stream` or `stream_scalars` with the current " - "dialect since it does not support server side cursors." - ) result = await greenlet_spawn( self._proxied.execute, @@ -606,7 +600,8 @@ class AsyncConnection( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> CursorResult[_T]: ... + ) -> CursorResult[_T]: + ... @overload async def execute( @@ -615,7 +610,8 @@ class AsyncConnection( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> CursorResult[Any]: ... + ) -> CursorResult[Any]: + ... async def execute( self, @@ -671,7 +667,8 @@ class AsyncConnection( parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> Optional[_T]: ... + ) -> Optional[_T]: + ... @overload async def scalar( @@ -680,7 +677,8 @@ class AsyncConnection( parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> Any: ... + ) -> Any: + ... async def scalar( self, @@ -711,7 +709,8 @@ class AsyncConnection( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> ScalarResult[_T]: ... + ) -> ScalarResult[_T]: + ... @overload async def scalars( @@ -720,7 +719,8 @@ class AsyncConnection( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> ScalarResult[Any]: ... + ) -> ScalarResult[Any]: + ... async def scalars( self, @@ -752,7 +752,8 @@ class AsyncConnection( parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> GeneratorStartableContext[AsyncScalarResult[_T]]: ... + ) -> GeneratorStartableContext[AsyncScalarResult[_T]]: + ... @overload def stream_scalars( @@ -761,7 +762,8 @@ class AsyncConnection( parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> GeneratorStartableContext[AsyncScalarResult[Any]]: ... + ) -> GeneratorStartableContext[AsyncScalarResult[Any]]: + ... 
@asyncstartablecontext async def stream_scalars( @@ -817,12 +819,9 @@ class AsyncConnection( yield result.scalars() async def run_sync( - self, - fn: Callable[Concatenate[Connection, _P], _T], - *arg: _P.args, - **kw: _P.kwargs, + self, fn: Callable[..., _T], *arg: Any, **kw: Any ) -> _T: - '''Invoke the given synchronous (i.e. not async) callable, + """Invoke the given synchronous (i.e. not async) callable, passing a synchronous-style :class:`_engine.Connection` as the first argument. @@ -832,26 +831,26 @@ class AsyncConnection( E.g.:: def do_something_with_core(conn: Connection, arg1: int, arg2: str) -> str: - """A synchronous function that does not require awaiting + '''A synchronous function that does not require awaiting :param conn: a Core SQLAlchemy Connection, used synchronously :return: an optional return value is supported - """ - conn.execute(some_table.insert().values(int_col=arg1, str_col=arg2)) + ''' + conn.execute( + some_table.insert().values(int_col=arg1, str_col=arg2) + ) return "success" async def do_something_async(async_engine: AsyncEngine) -> None: - """an async function that uses awaiting""" + '''an async function that uses awaiting''' async with async_engine.begin() as async_conn: # run do_something_with_core() with a sync-style # Connection, proxied into an awaitable - return_code = await async_conn.run_sync( - do_something_with_core, 5, "strval" - ) + return_code = await async_conn.run_sync(do_something_with_core, 5, "strval") print(return_code) This method maintains the asyncio event loop all the way through @@ -882,11 +881,9 @@ class AsyncConnection( :ref:`session_run_sync` - ''' # noqa: E501 + """ # noqa: E501 - return await greenlet_spawn( - fn, self._proxied, *arg, _require_await=False, **kw - ) + return await greenlet_spawn(fn, self._proxied, *arg, **kw) def __await__(self) -> Generator[Any, None, AsyncConnection]: return self.start().__await__() @@ -931,7 +928,7 @@ class AsyncConnection( return self._proxied.invalidated @property - def dialect(self) -> Dialect: + def dialect(self) -> Any: r"""Proxy for the :attr:`_engine.Connection.dialect` attribute on behalf of the :class:`_asyncio.AsyncConnection` class. @@ -940,7 +937,7 @@ class AsyncConnection( return self._proxied.dialect @dialect.setter - def dialect(self, attr: Dialect) -> None: + def dialect(self, attr: Any) -> None: self._proxied.dialect = attr @property @@ -1001,7 +998,6 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): :func:`_asyncio.create_async_engine` function:: from sqlalchemy.ext.asyncio import create_async_engine - engine = create_async_engine("postgresql+asyncpg://user:pass@host/dbname") .. versionadded:: 1.4 @@ -1041,9 +1037,7 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): return self.sync_engine @classmethod - def _regenerate_proxy_for_target( - cls, target: Engine, **additional_kw: Any # noqa: U100 - ) -> AsyncEngine: + def _regenerate_proxy_for_target(cls, target: Engine) -> AsyncEngine: return AsyncEngine(target) @contextlib.asynccontextmanager @@ -1060,6 +1054,7 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): ) await conn.execute(text("my_special_procedure(5)")) + """ conn = self.connect() @@ -1105,10 +1100,12 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., **opt: Any, - ) -> AsyncEngine: ... + ) -> AsyncEngine: + ... @overload - def execution_options(self, **opt: Any) -> AsyncEngine: ... 
+ def execution_options(self, **opt: Any) -> AsyncEngine: + ... def execution_options(self, **opt: Any) -> AsyncEngine: """Return a new :class:`_asyncio.AsyncEngine` that will provide @@ -1163,7 +1160,7 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): This applies **only** to the built-in cache that is established via the :paramref:`_engine.create_engine.query_cache_size` parameter. It will not impact any dictionary caches that were passed via the - :paramref:`.Connection.execution_options.compiled_cache` parameter. + :paramref:`.Connection.execution_options.query_cache` parameter. .. versionadded:: 1.4 @@ -1346,7 +1343,7 @@ class AsyncTransaction( @classmethod def _regenerate_proxy_for_target( - cls, target: Transaction, **additional_kw: Any # noqa: U100 + cls, target: Transaction ) -> AsyncTransaction: sync_connection = target.connection sync_transaction = target @@ -1421,17 +1418,19 @@ class AsyncTransaction( @overload -def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine: ... +def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine: + ... @overload def _get_sync_engine_or_connection( async_engine: AsyncConnection, -) -> Connection: ... +) -> Connection: + ... def _get_sync_engine_or_connection( - async_engine: Union[AsyncEngine, AsyncConnection], + async_engine: Union[AsyncEngine, AsyncConnection] ) -> Union[Engine, Connection]: if isinstance(async_engine, AsyncConnection): return async_engine._proxied diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/exc.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/exc.py index 558187c..3f93767 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/exc.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/exc.py @@ -1,5 +1,5 @@ # ext/asyncio/exc.py -# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/result.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/result.py index f1df53b..a13e106 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/result.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/result.py @@ -1,5 +1,5 @@ # ext/asyncio/result.py -# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -93,7 +93,6 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): self._metadata = real_result._metadata self._unique_filter_state = real_result._unique_filter_state - self._source_supports_scalars = real_result._source_supports_scalars self._post_creational_filter = None # BaseCursorResult pre-generates the "_row_getter". Use that @@ -325,20 +324,22 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): return await greenlet_spawn(self._only_one_row, True, False, False) @overload - async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T: ... + async def scalar_one(self: AsyncResult[Tuple[_T]]) -> _T: + ... @overload - async def scalar_one(self) -> Any: ... + async def scalar_one(self) -> Any: + ... async def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and - then :meth:`_asyncio.AsyncScalarResult.one`. 
+ then :meth:`_asyncio.AsyncResult.one`. .. seealso:: - :meth:`_asyncio.AsyncScalarResult.one` + :meth:`_asyncio.AsyncResult.one` :meth:`_asyncio.AsyncResult.scalars` @@ -348,20 +349,22 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): @overload async def scalar_one_or_none( self: AsyncResult[Tuple[_T]], - ) -> Optional[_T]: ... + ) -> Optional[_T]: + ... @overload - async def scalar_one_or_none(self) -> Optional[Any]: ... + async def scalar_one_or_none(self) -> Optional[Any]: + ... async def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one scalar result or ``None``. This is equivalent to calling :meth:`_asyncio.AsyncResult.scalars` and - then :meth:`_asyncio.AsyncScalarResult.one_or_none`. + then :meth:`_asyncio.AsyncResult.one_or_none`. .. seealso:: - :meth:`_asyncio.AsyncScalarResult.one_or_none` + :meth:`_asyncio.AsyncResult.one_or_none` :meth:`_asyncio.AsyncResult.scalars` @@ -400,10 +403,12 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): return await greenlet_spawn(self._only_one_row, True, True, False) @overload - async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]: ... + async def scalar(self: AsyncResult[Tuple[_T]]) -> Optional[_T]: + ... @overload - async def scalar(self) -> Any: ... + async def scalar(self) -> Any: + ... async def scalar(self) -> Any: """Fetch the first column of the first row, and close the result set. @@ -447,13 +452,16 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[_TP]]): @overload def scalars( self: AsyncResult[Tuple[_T]], index: Literal[0] - ) -> AsyncScalarResult[_T]: ... + ) -> AsyncScalarResult[_T]: + ... @overload - def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]: ... + def scalars(self: AsyncResult[Tuple[_T]]) -> AsyncScalarResult[_T]: + ... @overload - def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: ... + def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: + ... def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: """Return an :class:`_asyncio.AsyncScalarResult` filtering object which @@ -825,9 +833,11 @@ class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly): """ ... - def __aiter__(self) -> AsyncIterator[_R]: ... + async def __aiter__(self) -> AsyncIterator[_R]: + ... - async def __anext__(self) -> _R: ... + async def __anext__(self) -> _R: + ... async def first(self) -> Optional[_R]: """Fetch the first object or ``None`` if no object is present. @@ -861,20 +871,22 @@ class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly): ... @overload - async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T: ... + async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T: + ... @overload - async def scalar_one(self) -> Any: ... + async def scalar_one(self) -> Any: + ... async def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. This is equivalent to calling :meth:`_engine.Result.scalars` - and then :meth:`_engine.AsyncScalarResult.one`. + and then :meth:`_engine.Result.one`. .. seealso:: - :meth:`_engine.AsyncScalarResult.one` + :meth:`_engine.Result.one` :meth:`_engine.Result.scalars` @@ -884,20 +896,22 @@ class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly): @overload async def scalar_one_or_none( self: AsyncTupleResult[Tuple[_T]], - ) -> Optional[_T]: ... + ) -> Optional[_T]: + ... @overload - async def scalar_one_or_none(self) -> Optional[Any]: ... + async def scalar_one_or_none(self) -> Optional[Any]: + ... 
async def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one or no scalar result. This is equivalent to calling :meth:`_engine.Result.scalars` - and then :meth:`_engine.AsyncScalarResult.one_or_none`. + and then :meth:`_engine.Result.one_or_none`. .. seealso:: - :meth:`_engine.AsyncScalarResult.one_or_none` + :meth:`_engine.Result.one_or_none` :meth:`_engine.Result.scalars` @@ -905,12 +919,12 @@ class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly): ... @overload - async def scalar( - self: AsyncTupleResult[Tuple[_T]], - ) -> Optional[_T]: ... + async def scalar(self: AsyncTupleResult[Tuple[_T]]) -> Optional[_T]: + ... @overload - async def scalar(self) -> Any: ... + async def scalar(self) -> Any: + ... async def scalar(self) -> Any: """Fetch the first column of the first row, and close the result diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/scoping.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/scoping.py index d2a9a51..4c68f53 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/scoping.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/scoping.py @@ -1,5 +1,5 @@ # ext/asyncio/scoping.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -364,7 +364,7 @@ class async_scoped_session(Generic[_AS]): object is entered:: async with async_session.begin(): - ... # ORM transaction is begun + # .. ORM transaction is begun Note that database IO will not normally occur when the session-level transaction is begun, as database transactions begin on an @@ -536,7 +536,8 @@ class async_scoped_session(Generic[_AS]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[_T]: ... + ) -> Result[_T]: + ... @overload async def execute( @@ -548,7 +549,8 @@ class async_scoped_session(Generic[_AS]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> CursorResult[Any]: ... + ) -> CursorResult[Any]: + ... @overload async def execute( @@ -560,7 +562,8 @@ class async_scoped_session(Generic[_AS]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Any]: ... + ) -> Result[Any]: + ... 
async def execute( self, @@ -808,28 +811,28 @@ class async_scoped_session(Generic[_AS]): # construct async engines w/ async drivers engines = { - "leader": create_async_engine("sqlite+aiosqlite:///leader.db"), - "other": create_async_engine("sqlite+aiosqlite:///other.db"), - "follower1": create_async_engine("sqlite+aiosqlite:///follower1.db"), - "follower2": create_async_engine("sqlite+aiosqlite:///follower2.db"), + 'leader':create_async_engine("sqlite+aiosqlite:///leader.db"), + 'other':create_async_engine("sqlite+aiosqlite:///other.db"), + 'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"), + 'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"), } - class RoutingSession(Session): def get_bind(self, mapper=None, clause=None, **kw): # within get_bind(), return sync engines if mapper and issubclass(mapper.class_, MyOtherClass): - return engines["other"].sync_engine + return engines['other'].sync_engine elif self._flushing or isinstance(clause, (Update, Delete)): - return engines["leader"].sync_engine + return engines['leader'].sync_engine else: return engines[ - random.choice(["follower1", "follower2"]) + random.choice(['follower1','follower2']) ].sync_engine - # apply to AsyncSession using sync_session_class - AsyncSessionMaker = async_sessionmaker(sync_session_class=RoutingSession) + AsyncSessionMaker = async_sessionmaker( + sync_session_class=RoutingSession + ) The :meth:`_orm.Session.get_bind` method is called in a non-asyncio, implicitly non-blocking context in the same manner as ORM event hooks @@ -864,7 +867,7 @@ class async_scoped_session(Generic[_AS]): This method retrieves the history for each instrumented attribute on the instance and performs a comparison of the current - value to its previously flushed or committed value, if any. + value to its previously committed value, if any. It is in effect a more expensive and accurate version of checking for the given instance in the @@ -1012,7 +1015,8 @@ class async_scoped_session(Generic[_AS]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Optional[_T]: ... + ) -> Optional[_T]: + ... @overload async def scalar( @@ -1023,7 +1027,8 @@ class async_scoped_session(Generic[_AS]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Any: ... + ) -> Any: + ... async def scalar( self, @@ -1065,7 +1070,8 @@ class async_scoped_session(Generic[_AS]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[_T]: ... + ) -> ScalarResult[_T]: + ... @overload async def scalars( @@ -1076,7 +1082,8 @@ class async_scoped_session(Generic[_AS]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[Any]: ... + ) -> ScalarResult[Any]: + ... async def scalars( self, @@ -1175,7 +1182,8 @@ class async_scoped_session(Generic[_AS]): Proxied for the :class:`_asyncio.AsyncSession` class on behalf of the :class:`_asyncio.scoping.async_scoped_session` class. - Raises :class:`_exc.NoResultFound` if the query selects no rows. + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects + no rows. 
..versionadded: 2.0.22 @@ -1205,7 +1213,8 @@ class async_scoped_session(Generic[_AS]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncResult[_T]: ... + ) -> AsyncResult[_T]: + ... @overload async def stream( @@ -1216,7 +1225,8 @@ class async_scoped_session(Generic[_AS]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncResult[Any]: ... + ) -> AsyncResult[Any]: + ... async def stream( self, @@ -1255,7 +1265,8 @@ class async_scoped_session(Generic[_AS]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncScalarResult[_T]: ... + ) -> AsyncScalarResult[_T]: + ... @overload async def stream_scalars( @@ -1266,7 +1277,8 @@ class async_scoped_session(Generic[_AS]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncScalarResult[Any]: ... + ) -> AsyncScalarResult[Any]: + ... async def stream_scalars( self, diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/session.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/session.py index 68cbb59..30232e5 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/session.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/asyncio/session.py @@ -1,5 +1,5 @@ # ext/asyncio/session.py -# Copyright (C) 2020-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2020-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -38,9 +38,6 @@ from ...orm import Session from ...orm import SessionTransaction from ...orm import state as _instance_state from ...util.concurrency import greenlet_spawn -from ...util.typing import Concatenate -from ...util.typing import ParamSpec - if TYPE_CHECKING: from .engine import AsyncConnection @@ -74,7 +71,6 @@ if TYPE_CHECKING: _AsyncSessionBind = Union["AsyncEngine", "AsyncConnection"] -_P = ParamSpec("_P") _T = TypeVar("_T", bound=Any) @@ -336,12 +332,9 @@ class AsyncSession(ReversibleProxy[Session]): ) async def run_sync( - self, - fn: Callable[Concatenate[Session, _P], _T], - *arg: _P.args, - **kw: _P.kwargs, + self, fn: Callable[..., _T], *arg: Any, **kw: Any ) -> _T: - '''Invoke the given synchronous (i.e. not async) callable, + """Invoke the given synchronous (i.e. not async) callable, passing a synchronous-style :class:`_orm.Session` as the first argument. 
@@ -351,27 +344,25 @@ class AsyncSession(ReversibleProxy[Session]): E.g.:: def some_business_method(session: Session, param: str) -> str: - """A synchronous function that does not require awaiting + '''A synchronous function that does not require awaiting :param session: a SQLAlchemy Session, used synchronously :return: an optional return value is supported - """ + ''' session.add(MyObject(param=param)) session.flush() return "success" async def do_something_async(async_engine: AsyncEngine) -> None: - """an async function that uses awaiting""" + '''an async function that uses awaiting''' with AsyncSession(async_engine) as async_session: # run some_business_method() with a sync-style # Session, proxied into an awaitable - return_code = await async_session.run_sync( - some_business_method, param="param1" - ) + return_code = await async_session.run_sync(some_business_method, param="param1") print(return_code) This method maintains the asyncio event loop all the way through @@ -393,11 +384,9 @@ class AsyncSession(ReversibleProxy[Session]): :meth:`.AsyncConnection.run_sync` :ref:`session_run_sync` - ''' # noqa: E501 + """ # noqa: E501 - return await greenlet_spawn( - fn, self.sync_session, *arg, _require_await=False, **kw - ) + return await greenlet_spawn(fn, self.sync_session, *arg, **kw) @overload async def execute( @@ -409,7 +398,8 @@ class AsyncSession(ReversibleProxy[Session]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[_T]: ... + ) -> Result[_T]: + ... @overload async def execute( @@ -421,7 +411,8 @@ class AsyncSession(ReversibleProxy[Session]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> CursorResult[Any]: ... + ) -> CursorResult[Any]: + ... @overload async def execute( @@ -433,7 +424,8 @@ class AsyncSession(ReversibleProxy[Session]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Any]: ... + ) -> Result[Any]: + ... async def execute( self, @@ -479,7 +471,8 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Optional[_T]: ... + ) -> Optional[_T]: + ... @overload async def scalar( @@ -490,7 +483,8 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Any: ... + ) -> Any: + ... async def scalar( self, @@ -534,7 +528,8 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[_T]: ... + ) -> ScalarResult[_T]: + ... @overload async def scalars( @@ -545,7 +540,8 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[Any]: ... + ) -> ScalarResult[Any]: + ... async def scalars( self, @@ -628,7 +624,8 @@ class AsyncSession(ReversibleProxy[Session]): """Return an instance based on the given primary key identifier, or raise an exception if not found. - Raises :class:`_exc.NoResultFound` if the query selects no rows. + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects + no rows. 
..versionadded: 2.0.22 @@ -658,7 +655,8 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncResult[_T]: ... + ) -> AsyncResult[_T]: + ... @overload async def stream( @@ -669,7 +667,8 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncResult[Any]: ... + ) -> AsyncResult[Any]: + ... async def stream( self, @@ -711,7 +710,8 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncScalarResult[_T]: ... + ) -> AsyncScalarResult[_T]: + ... @overload async def stream_scalars( @@ -722,7 +722,8 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncScalarResult[Any]: ... + ) -> AsyncScalarResult[Any]: + ... async def stream_scalars( self, @@ -811,9 +812,7 @@ class AsyncSession(ReversibleProxy[Session]): """ trans = self.sync_session.get_transaction() if trans is not None: - return AsyncSessionTransaction._retrieve_proxy_for_target( - trans, async_session=self - ) + return AsyncSessionTransaction._retrieve_proxy_for_target(trans) else: return None @@ -829,9 +828,7 @@ class AsyncSession(ReversibleProxy[Session]): trans = self.sync_session.get_nested_transaction() if trans is not None: - return AsyncSessionTransaction._retrieve_proxy_for_target( - trans, async_session=self - ) + return AsyncSessionTransaction._retrieve_proxy_for_target(trans) else: return None @@ -882,28 +879,28 @@ class AsyncSession(ReversibleProxy[Session]): # construct async engines w/ async drivers engines = { - "leader": create_async_engine("sqlite+aiosqlite:///leader.db"), - "other": create_async_engine("sqlite+aiosqlite:///other.db"), - "follower1": create_async_engine("sqlite+aiosqlite:///follower1.db"), - "follower2": create_async_engine("sqlite+aiosqlite:///follower2.db"), + 'leader':create_async_engine("sqlite+aiosqlite:///leader.db"), + 'other':create_async_engine("sqlite+aiosqlite:///other.db"), + 'follower1':create_async_engine("sqlite+aiosqlite:///follower1.db"), + 'follower2':create_async_engine("sqlite+aiosqlite:///follower2.db"), } - class RoutingSession(Session): def get_bind(self, mapper=None, clause=None, **kw): # within get_bind(), return sync engines if mapper and issubclass(mapper.class_, MyOtherClass): - return engines["other"].sync_engine + return engines['other'].sync_engine elif self._flushing or isinstance(clause, (Update, Delete)): - return engines["leader"].sync_engine + return engines['leader'].sync_engine else: return engines[ - random.choice(["follower1", "follower2"]) + random.choice(['follower1','follower2']) ].sync_engine - # apply to AsyncSession using sync_session_class - AsyncSessionMaker = async_sessionmaker(sync_session_class=RoutingSession) + AsyncSessionMaker = async_sessionmaker( + sync_session_class=RoutingSession + ) The :meth:`_orm.Session.get_bind` method is called in a non-asyncio, implicitly non-blocking context in the same manner as ORM event hooks @@ -959,7 +956,7 @@ class AsyncSession(ReversibleProxy[Session]): object is entered:: async with async_session.begin(): - ... # ORM transaction is begun + # .. 
ORM transaction is begun Note that database IO will not normally occur when the session-level transaction is begun, as database transactions begin on an @@ -1312,7 +1309,7 @@ class AsyncSession(ReversibleProxy[Session]): This method retrieves the history for each instrumented attribute on the instance and performs a comparison of the current - value to its previously flushed or committed value, if any. + value to its previously committed value, if any. It is in effect a more expensive and accurate version of checking for the given instance in the @@ -1636,22 +1633,16 @@ class async_sessionmaker(Generic[_AS]): from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import async_sessionmaker - - async def run_some_sql( - async_session: async_sessionmaker[AsyncSession], - ) -> None: + async def run_some_sql(async_session: async_sessionmaker[AsyncSession]) -> None: async with async_session() as session: session.add(SomeObject(data="object")) session.add(SomeOtherObject(name="other object")) await session.commit() - async def main() -> None: # an AsyncEngine, which the AsyncSession will use for connection # resources - engine = create_async_engine( - "postgresql+asyncpg://scott:tiger@localhost/" - ) + engine = create_async_engine('postgresql+asyncpg://scott:tiger@localhost/') # create a reusable factory for new AsyncSession instances async_session = async_sessionmaker(engine) @@ -1695,7 +1686,8 @@ class async_sessionmaker(Generic[_AS]): expire_on_commit: bool = ..., info: Optional[_InfoType] = ..., **kw: Any, - ): ... + ): + ... @overload def __init__( @@ -1706,7 +1698,8 @@ class async_sessionmaker(Generic[_AS]): expire_on_commit: bool = ..., info: Optional[_InfoType] = ..., **kw: Any, - ): ... + ): + ... def __init__( self, @@ -1750,6 +1743,7 @@ class async_sessionmaker(Generic[_AS]): # commits transaction, closes session + """ session = self() @@ -1782,7 +1776,7 @@ class async_sessionmaker(Generic[_AS]): AsyncSession = async_sessionmaker(some_engine) - AsyncSession.configure(bind=create_async_engine("sqlite+aiosqlite://")) + AsyncSession.configure(bind=create_async_engine('sqlite+aiosqlite://')) """ # noqa E501 self.kw.update(new_kw) @@ -1868,27 +1862,12 @@ class AsyncSessionTransaction( await greenlet_spawn(self._sync_transaction().commit) - @classmethod - def _regenerate_proxy_for_target( # type: ignore[override] - cls, - target: SessionTransaction, - async_session: AsyncSession, - **additional_kw: Any, # noqa: U100 - ) -> AsyncSessionTransaction: - sync_transaction = target - nested = target.nested - obj = cls.__new__(cls) - obj.session = async_session - obj.sync_transaction = obj._assign_proxied(sync_transaction) - obj.nested = nested - return obj - async def start( self, is_ctxmanager: bool = False ) -> AsyncSessionTransaction: self.sync_transaction = self._assign_proxied( await greenlet_spawn( - self.session.sync_session.begin_nested + self.session.sync_session.begin_nested # type: ignore if self.nested else self.session.sync_session.begin ) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/automap.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/automap.py index 817f91d..18568c7 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/automap.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/automap.py @@ -1,5 +1,5 @@ # ext/automap.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -11,7 
+11,7 @@ schema, typically though not necessarily one which is reflected. It is hoped that the :class:`.AutomapBase` system provides a quick and modernized solution to the problem that the very famous -`SQLSoup `_ +`SQLSoup `_ also tries to solve, that of generating a quick and rudimentary object model from an existing database on the fly. By addressing the issue strictly at the mapper configuration level, and integrating fully with existing @@ -64,7 +64,7 @@ asking it to reflect the schema and produce mappings:: # collection-based relationships are by default named # "_collection" u1 = session.query(User).first() - print(u1.address_collection) + print (u1.address_collection) Above, calling :meth:`.AutomapBase.prepare` while passing along the :paramref:`.AutomapBase.prepare.reflect` parameter indicates that the @@ -101,7 +101,6 @@ explicit table declaration:: from sqlalchemy import create_engine, MetaData, Table, Column, ForeignKey from sqlalchemy.ext.automap import automap_base - engine = create_engine("sqlite:///mydatabase.db") # produce our own MetaData object @@ -109,15 +108,13 @@ explicit table declaration:: # we can reflect it ourselves from a database, using options # such as 'only' to limit what tables we look at... - metadata.reflect(engine, only=["user", "address"]) + metadata.reflect(engine, only=['user', 'address']) # ... or just define our own Table objects with it (or combine both) - Table( - "user_order", - metadata, - Column("id", Integer, primary_key=True), - Column("user_id", ForeignKey("user.id")), - ) + Table('user_order', metadata, + Column('id', Integer, primary_key=True), + Column('user_id', ForeignKey('user.id')) + ) # we can then produce a set of mappings from this MetaData. Base = automap_base(metadata=metadata) @@ -126,9 +123,8 @@ explicit table declaration:: Base.prepare() # mapped classes are ready - User = Base.classes.user - Address = Base.classes.address - Order = Base.classes.user_order + User, Address, Order = Base.classes.user, Base.classes.address,\ + Base.classes.user_order .. _automap_by_module: @@ -181,23 +177,18 @@ the schema name ``default`` is used if no schema is present:: Base.metadata.create_all(e) - def module_name_for_table(cls, tablename, table): if table.schema is not None: return f"mymodule.{table.schema}" else: return f"mymodule.default" - Base = automap_base() Base.prepare(e, modulename_for_table=module_name_for_table) - Base.prepare( - e, schema="test_schema", modulename_for_table=module_name_for_table - ) - Base.prepare( - e, schema="test_schema_2", modulename_for_table=module_name_for_table - ) + Base.prepare(e, schema="test_schema", modulename_for_table=module_name_for_table) + Base.prepare(e, schema="test_schema_2", modulename_for_table=module_name_for_table) + The same named-classes are organized into a hierarchical collection available at :attr:`.AutomapBase.by_module`. This collection is traversed using the @@ -260,13 +251,12 @@ established based on the table name we use. If our schema contains tables # automap base Base = automap_base() - # pre-declare User for the 'user' table class User(Base): - __tablename__ = "user" + __tablename__ = 'user' # override schema elements like Columns - user_name = Column("name", String) + user_name = Column('name', String) # override relationships too, if desired. # we must use the same name that automap would use for the @@ -274,7 +264,6 @@ established based on the table name we use. 
If our schema contains tables # generate for "address" address_collection = relationship("address", collection_class=set) - # reflect engine = create_engine("sqlite:///mydatabase.db") Base.prepare(autoload_with=engine) @@ -285,11 +274,11 @@ established based on the table name we use. If our schema contains tables Address = Base.classes.address u1 = session.query(User).first() - print(u1.address_collection) + print (u1.address_collection) # the backref is still there: a1 = session.query(Address).first() - print(a1.user) + print (a1.user) Above, one of the more intricate details is that we illustrated overriding one of the :func:`_orm.relationship` objects that automap would have created. @@ -316,49 +305,35 @@ scheme for class names and a "pluralizer" for collection names using the import re import inflect - def camelize_classname(base, tablename, table): - "Produce a 'camelized' class name, e.g." + "Produce a 'camelized' class name, e.g. " "'words_and_underscores' -> 'WordsAndUnderscores'" - return str( - tablename[0].upper() - + re.sub( - r"_([a-z])", - lambda m: m.group(1).upper(), - tablename[1:], - ) - ) - + return str(tablename[0].upper() + \ + re.sub(r'_([a-z])', lambda m: m.group(1).upper(), tablename[1:])) _pluralizer = inflect.engine() - - def pluralize_collection(base, local_cls, referred_cls, constraint): - "Produce an 'uncamelized', 'pluralized' class name, e.g." + "Produce an 'uncamelized', 'pluralized' class name, e.g. " "'SomeTerm' -> 'some_terms'" referred_name = referred_cls.__name__ - uncamelized = re.sub( - r"[A-Z]", - lambda m: "_%s" % m.group(0).lower(), - referred_name, - )[1:] + uncamelized = re.sub(r'[A-Z]', + lambda m: "_%s" % m.group(0).lower(), + referred_name)[1:] pluralized = _pluralizer.plural(uncamelized) return pluralized - from sqlalchemy.ext.automap import automap_base Base = automap_base() engine = create_engine("sqlite:///mydatabase.db") - Base.prepare( - autoload_with=engine, - classname_for_table=camelize_classname, - name_for_collection_relationship=pluralize_collection, - ) + Base.prepare(autoload_with=engine, + classname_for_table=camelize_classname, + name_for_collection_relationship=pluralize_collection + ) From the above mapping, we would now have classes ``User`` and ``Address``, where the collection from ``User`` to ``Address`` is called @@ -447,21 +422,16 @@ Below is an illustration of how to send options along to all one-to-many relationships:: from sqlalchemy.ext.automap import generate_relationship - from sqlalchemy.orm import interfaces - - def _gen_relationship( - base, direction, return_fn, attrname, local_cls, referred_cls, **kw - ): + def _gen_relationship(base, direction, return_fn, + attrname, local_cls, referred_cls, **kw): if direction is interfaces.ONETOMANY: - kw["cascade"] = "all, delete-orphan" - kw["passive_deletes"] = True + kw['cascade'] = 'all, delete-orphan' + kw['passive_deletes'] = True # make use of the built-in function to actually return # the result. 
- return generate_relationship( - base, direction, return_fn, attrname, local_cls, referred_cls, **kw - ) - + return generate_relationship(base, direction, return_fn, + attrname, local_cls, referred_cls, **kw) from sqlalchemy.ext.automap import automap_base from sqlalchemy import create_engine @@ -470,7 +440,8 @@ options along to all one-to-many relationships:: Base = automap_base() engine = create_engine("sqlite:///mydatabase.db") - Base.prepare(autoload_with=engine, generate_relationship=_gen_relationship) + Base.prepare(autoload_with=engine, + generate_relationship=_gen_relationship) Many-to-Many relationships -------------------------- @@ -511,20 +482,18 @@ two classes that are in an inheritance relationship. That is, with two classes given as follows:: class Employee(Base): - __tablename__ = "employee" + __tablename__ = 'employee' id = Column(Integer, primary_key=True) type = Column(String(50)) __mapper_args__ = { - "polymorphic_identity": "employee", - "polymorphic_on": type, + 'polymorphic_identity':'employee', 'polymorphic_on': type } - class Engineer(Employee): - __tablename__ = "engineer" - id = Column(Integer, ForeignKey("employee.id"), primary_key=True) + __tablename__ = 'engineer' + id = Column(Integer, ForeignKey('employee.id'), primary_key=True) __mapper_args__ = { - "polymorphic_identity": "engineer", + 'polymorphic_identity':'engineer', } The foreign key from ``Engineer`` to ``Employee`` is used not for a @@ -539,28 +508,25 @@ we want as well as the ``inherit_condition``, as these are not things SQLAlchemy can guess:: class Employee(Base): - __tablename__ = "employee" + __tablename__ = 'employee' id = Column(Integer, primary_key=True) type = Column(String(50)) __mapper_args__ = { - "polymorphic_identity": "employee", - "polymorphic_on": type, + 'polymorphic_identity':'employee', 'polymorphic_on':type } - class Engineer(Employee): - __tablename__ = "engineer" - id = Column(Integer, ForeignKey("employee.id"), primary_key=True) - favorite_employee_id = Column(Integer, ForeignKey("employee.id")) + __tablename__ = 'engineer' + id = Column(Integer, ForeignKey('employee.id'), primary_key=True) + favorite_employee_id = Column(Integer, ForeignKey('employee.id')) - favorite_employee = relationship( - Employee, foreign_keys=favorite_employee_id - ) + favorite_employee = relationship(Employee, + foreign_keys=favorite_employee_id) __mapper_args__ = { - "polymorphic_identity": "engineer", - "inherit_condition": id == Employee.id, + 'polymorphic_identity':'engineer', + 'inherit_condition': id == Employee.id } Handling Simple Naming Conflicts @@ -593,24 +559,20 @@ and will emit an error on mapping. We can resolve this conflict by using an underscore as follows:: - def name_for_scalar_relationship( - base, local_cls, referred_cls, constraint - ): + def name_for_scalar_relationship(base, local_cls, referred_cls, constraint): name = referred_cls.__name__.lower() local_table = local_cls.__table__ if name in local_table.columns: newname = name + "_" warnings.warn( - "Already detected name %s present. using %s" % (name, newname) - ) + "Already detected name %s present. using %s" % + (name, newname)) return newname return name - Base.prepare( - autoload_with=engine, - name_for_scalar_relationship=name_for_scalar_relationship, - ) + Base.prepare(autoload_with=engine, + name_for_scalar_relationship=name_for_scalar_relationship) Alternatively, we can change the name on the column side. 
The columns that are mapped can be modified using the technique described at @@ -619,14 +581,13 @@ to a new name:: Base = automap_base() - class TableB(Base): - __tablename__ = "table_b" - _table_a = Column("table_a", ForeignKey("table_a.id")) - + __tablename__ = 'table_b' + _table_a = Column('table_a', ForeignKey('table_a.id')) Base.prepare(autoload_with=engine) + Using Automap with Explicit Declarations ======================================== @@ -642,29 +603,26 @@ defines table metadata:: Base = automap_base() - class User(Base): - __tablename__ = "user" + __tablename__ = 'user' id = Column(Integer, primary_key=True) name = Column(String) - class Address(Base): - __tablename__ = "address" + __tablename__ = 'address' id = Column(Integer, primary_key=True) email = Column(String) - user_id = Column(ForeignKey("user.id")) - + user_id = Column(ForeignKey('user.id')) # produce relationships Base.prepare() # mapping is complete, with "address_collection" and # "user" relationships - a1 = Address(email="u1") - a2 = Address(email="u2") + a1 = Address(email='u1') + a2 = Address(email='u2') u1 = User(address_collection=[a1, a2]) assert a1.user is u1 @@ -693,8 +651,7 @@ be applied as:: @event.listens_for(Base.metadata, "column_reflect") def column_reflect(inspector, table, column_info): # set column.key = "attr_" - column_info["key"] = "attr_%s" % column_info["name"].lower() - + column_info['key'] = "attr_%s" % column_info['name'].lower() # run reflection Base.prepare(autoload_with=engine) @@ -758,9 +715,8 @@ _VT = TypeVar("_VT", bound=Any) class PythonNameForTableType(Protocol): - def __call__( - self, base: Type[Any], tablename: str, table: Table - ) -> str: ... + def __call__(self, base: Type[Any], tablename: str, table: Table) -> str: + ... def classname_for_table( @@ -807,7 +763,8 @@ class NameForScalarRelationshipType(Protocol): local_cls: Type[Any], referred_cls: Type[Any], constraint: ForeignKeyConstraint, - ) -> str: ... + ) -> str: + ... def name_for_scalar_relationship( @@ -847,7 +804,8 @@ class NameForCollectionRelationshipType(Protocol): local_cls: Type[Any], referred_cls: Type[Any], constraint: ForeignKeyConstraint, - ) -> str: ... + ) -> str: + ... def name_for_collection_relationship( @@ -892,7 +850,8 @@ class GenerateRelationshipType(Protocol): local_cls: Type[Any], referred_cls: Type[Any], **kw: Any, - ) -> Relationship[Any]: ... + ) -> Relationship[Any]: + ... @overload def __call__( @@ -904,7 +863,8 @@ class GenerateRelationshipType(Protocol): local_cls: Type[Any], referred_cls: Type[Any], **kw: Any, - ) -> ORMBackrefArgument: ... + ) -> ORMBackrefArgument: + ... def __call__( self, @@ -917,7 +877,8 @@ class GenerateRelationshipType(Protocol): local_cls: Type[Any], referred_cls: Type[Any], **kw: Any, - ) -> Union[ORMBackrefArgument, Relationship[Any]]: ... + ) -> Union[ORMBackrefArgument, Relationship[Any]]: + ... @overload @@ -929,7 +890,8 @@ def generate_relationship( local_cls: Type[Any], referred_cls: Type[Any], **kw: Any, -) -> Relationship[Any]: ... +) -> Relationship[Any]: + ... @overload @@ -941,7 +903,8 @@ def generate_relationship( local_cls: Type[Any], referred_cls: Type[Any], **kw: Any, -) -> ORMBackrefArgument: ... +) -> ORMBackrefArgument: + ... 
def generate_relationship( @@ -1045,12 +1008,6 @@ class AutomapBase: User, Address = Base.classes.User, Base.classes.Address - For class names that overlap with a method name of - :class:`.util.Properties`, such as ``items()``, the getitem form - is also supported:: - - Item = Base.classes["items"] - """ by_module: ClassVar[ByModuleProperties] diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/baked.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/baked.py index cd3e087..64c9ce6 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/baked.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/baked.py @@ -1,5 +1,5 @@ -# ext/baked.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sqlalchemy/ext/baked.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -258,19 +258,23 @@ class BakedQuery: is passed to the lambda:: sub_bq = self.bakery(lambda s: s.query(User.name)) - sub_bq += lambda q: q.filter(User.id == Address.user_id).correlate(Address) + sub_bq += lambda q: q.filter( + User.id == Address.user_id).correlate(Address) main_bq = self.bakery(lambda s: s.query(Address)) - main_bq += lambda q: q.filter(sub_bq.to_query(q).exists()) + main_bq += lambda q: q.filter( + sub_bq.to_query(q).exists()) In the case where the subquery is used in the first callable against a :class:`.Session`, the :class:`.Session` is also accepted:: sub_bq = self.bakery(lambda s: s.query(User.name)) - sub_bq += lambda q: q.filter(User.id == Address.user_id).correlate(Address) + sub_bq += lambda q: q.filter( + User.id == Address.user_id).correlate(Address) main_bq = self.bakery( - lambda s: s.query(Address.id, sub_bq.to_query(q).scalar_subquery()) + lambda s: s.query( + Address.id, sub_bq.to_query(q).scalar_subquery()) ) :param query_or_session: a :class:`_query.Query` object or a class @@ -281,7 +285,7 @@ class BakedQuery: .. versionadded:: 1.3 - """ # noqa: E501 + """ if isinstance(query_or_session, Session): session = query_or_session diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/compiler.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/compiler.py index cc64477..39a5541 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/compiler.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/compiler.py @@ -1,9 +1,10 @@ # ext/compiler.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors r"""Provides an API for creation of custom ClauseElements and compilers. @@ -17,11 +18,9 @@ more callables defining its compilation:: from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql.expression import ColumnClause - class MyColumn(ColumnClause): inherit_cache = True - @compiles(MyColumn) def compile_mycolumn(element, compiler, **kw): return "[%s]" % element.name @@ -33,12 +32,10 @@ when the object is compiled to a string:: from sqlalchemy import select - s = select(MyColumn("x"), MyColumn("y")) + s = select(MyColumn('x'), MyColumn('y')) print(str(s)) -Produces: - -.. 
sourcecode:: sql +Produces:: SELECT [x], [y] @@ -50,7 +47,6 @@ invoked for the dialect in use:: from sqlalchemy.schema import DDLElement - class AlterColumn(DDLElement): inherit_cache = False @@ -58,18 +54,14 @@ invoked for the dialect in use:: self.column = column self.cmd = cmd - @compiles(AlterColumn) def visit_alter_column(element, compiler, **kw): return "ALTER COLUMN %s ..." % element.column.name - - @compiles(AlterColumn, "postgresql") + @compiles(AlterColumn, 'postgresql') def visit_alter_column(element, compiler, **kw): - return "ALTER TABLE %s ALTER COLUMN %s ..." % ( - element.table.name, - element.column.name, - ) + return "ALTER TABLE %s ALTER COLUMN %s ..." % (element.table.name, + element.column.name) The second ``visit_alter_table`` will be invoked when any ``postgresql`` dialect is used. @@ -89,7 +81,6 @@ method which can be used for compilation of embedded attributes:: from sqlalchemy.sql.expression import Executable, ClauseElement - class InsertFromSelect(Executable, ClauseElement): inherit_cache = False @@ -97,27 +88,20 @@ method which can be used for compilation of embedded attributes:: self.table = table self.select = select - @compiles(InsertFromSelect) def visit_insert_from_select(element, compiler, **kw): return "INSERT INTO %s (%s)" % ( compiler.process(element.table, asfrom=True, **kw), - compiler.process(element.select, **kw), + compiler.process(element.select, **kw) ) - - insert = InsertFromSelect(t1, select(t1).where(t1.c.x > 5)) + insert = InsertFromSelect(t1, select(t1).where(t1.c.x>5)) print(insert) -Produces (formatted for readability): +Produces:: -.. sourcecode:: sql - - INSERT INTO mytable ( - SELECT mytable.x, mytable.y, mytable.z - FROM mytable - WHERE mytable.x > :x_1 - ) + "INSERT INTO mytable (SELECT mytable.x, mytable.y, mytable.z + FROM mytable WHERE mytable.x > :x_1)" .. note:: @@ -137,10 +121,11 @@ below where we generate a CHECK constraint that embeds a SQL expression:: @compiles(MyConstraint) def compile_my_constraint(constraint, ddlcompiler, **kw): - kw["literal_binds"] = True + kw['literal_binds'] = True return "CONSTRAINT %s CHECK (%s)" % ( constraint.name, - ddlcompiler.sql_compiler.process(constraint.expression, **kw), + ddlcompiler.sql_compiler.process( + constraint.expression, **kw) ) Above, we add an additional flag to the process step as called by @@ -168,7 +153,6 @@ an endless loop. 
Such as, to add "prefix" to all insert statements:: from sqlalchemy.sql.expression import Insert - @compiles(Insert) def prefix_inserts(insert, compiler, **kw): return compiler.visit_insert(insert.prefix_with("some prefix"), **kw) @@ -184,16 +168,17 @@ Changing Compilation of Types ``compiler`` works for types, too, such as below where we implement the MS-SQL specific 'max' keyword for ``String``/``VARCHAR``:: - @compiles(String, "mssql") - @compiles(VARCHAR, "mssql") + @compiles(String, 'mssql') + @compiles(VARCHAR, 'mssql') def compile_varchar(element, compiler, **kw): - if element.length == "max": + if element.length == 'max': return "VARCHAR('max')" else: return compiler.visit_VARCHAR(element, **kw) - - foo = Table("foo", metadata, Column("data", VARCHAR("max"))) + foo = Table('foo', metadata, + Column('data', VARCHAR('max')) + ) Subclassing Guidelines ====================== @@ -231,23 +216,18 @@ A synopsis is as follows: from sqlalchemy.sql.expression import FunctionElement - class coalesce(FunctionElement): - name = "coalesce" + name = 'coalesce' inherit_cache = True - @compiles(coalesce) def compile(element, compiler, **kw): return "coalesce(%s)" % compiler.process(element.clauses, **kw) - - @compiles(coalesce, "oracle") + @compiles(coalesce, 'oracle') def compile(element, compiler, **kw): if len(element.clauses) > 2: - raise TypeError( - "coalesce only supports two arguments on " "Oracle Database" - ) + raise TypeError("coalesce only supports two arguments on Oracle") return "nvl(%s)" % compiler.process(element.clauses, **kw) * :class:`.ExecutableDDLElement` - The root of all DDL expressions, @@ -301,7 +281,6 @@ for example to the "synopsis" example indicated previously:: class MyColumn(ColumnClause): inherit_cache = True - @compiles(MyColumn) def compile_mycolumn(element, compiler, **kw): return "[%s]" % element.name @@ -340,12 +319,11 @@ caching:: self.table = table self.select = select - @compiles(InsertFromSelect) def visit_insert_from_select(element, compiler, **kw): return "INSERT INTO %s (%s)" % ( compiler.process(element.table, asfrom=True, **kw), - compiler.process(element.select, **kw), + compiler.process(element.select, **kw) ) While it is also possible that the above ``InsertFromSelect`` could be made to @@ -381,32 +359,28 @@ For PostgreSQL and Microsoft SQL Server:: from sqlalchemy.ext.compiler import compiles from sqlalchemy.types import DateTime - class utcnow(expression.FunctionElement): type = DateTime() inherit_cache = True - - @compiles(utcnow, "postgresql") + @compiles(utcnow, 'postgresql') def pg_utcnow(element, compiler, **kw): return "TIMEZONE('utc', CURRENT_TIMESTAMP)" - - @compiles(utcnow, "mssql") + @compiles(utcnow, 'mssql') def ms_utcnow(element, compiler, **kw): return "GETUTCDATE()" Example usage:: - from sqlalchemy import Table, Column, Integer, String, DateTime, MetaData - + from sqlalchemy import ( + Table, Column, Integer, String, DateTime, MetaData + ) metadata = MetaData() - event = Table( - "event", - metadata, + event = Table("event", metadata, Column("id", Integer, primary_key=True), Column("description", String(50), nullable=False), - Column("timestamp", DateTime, server_default=utcnow()), + Column("timestamp", DateTime, server_default=utcnow()) ) "GREATEST" function @@ -421,30 +395,30 @@ accommodates two arguments:: from sqlalchemy.ext.compiler import compiles from sqlalchemy.types import Numeric - class greatest(expression.FunctionElement): type = Numeric() - name = "greatest" + name = 'greatest' inherit_cache = True - @compiles(greatest) 
def default_greatest(element, compiler, **kw): return compiler.visit_function(element) - - @compiles(greatest, "sqlite") - @compiles(greatest, "mssql") - @compiles(greatest, "oracle") + @compiles(greatest, 'sqlite') + @compiles(greatest, 'mssql') + @compiles(greatest, 'oracle') def case_greatest(element, compiler, **kw): arg1, arg2 = list(element.clauses) return compiler.process(case((arg1 > arg2, arg1), else_=arg2), **kw) Example usage:: - Session.query(Account).filter( - greatest(Account.checking_balance, Account.savings_balance) > 10000 - ) + Session.query(Account).\ + filter( + greatest( + Account.checking_balance, + Account.savings_balance) > 10000 + ) "false" expression ------------------ @@ -455,19 +429,16 @@ don't have a "false" constant:: from sqlalchemy.sql import expression from sqlalchemy.ext.compiler import compiles - class sql_false(expression.ColumnElement): inherit_cache = True - @compiles(sql_false) def default_false(element, compiler, **kw): return "false" - - @compiles(sql_false, "mssql") - @compiles(sql_false, "mysql") - @compiles(sql_false, "oracle") + @compiles(sql_false, 'mssql') + @compiles(sql_false, 'mysql') + @compiles(sql_false, 'oracle') def int_false(element, compiler, **kw): return "0" @@ -477,33 +448,19 @@ Example usage:: exp = union_all( select(users.c.name, sql_false().label("enrolled")), - select(customers.c.name, customers.c.enrolled), + select(customers.c.name, customers.c.enrolled) ) """ -from __future__ import annotations - -from typing import Any -from typing import Callable -from typing import Dict -from typing import Type -from typing import TYPE_CHECKING -from typing import TypeVar - from .. import exc from ..sql import sqltypes -if TYPE_CHECKING: - from ..sql.compiler import SQLCompiler -_F = TypeVar("_F", bound=Callable[..., Any]) - - -def compiles(class_: Type[Any], *specs: str) -> Callable[[_F], _F]: +def compiles(class_, *specs): """Register a function as a compiler for a given :class:`_expression.ClauseElement` type.""" - def decorate(fn: _F) -> _F: + def decorate(fn): # get an existing @compiles handler existing = class_.__dict__.get("_compiler_dispatcher", None) @@ -516,9 +473,7 @@ def compiles(class_: Type[Any], *specs: str) -> Callable[[_F], _F]: if existing_dispatch: - def _wrap_existing_dispatch( - element: Any, compiler: SQLCompiler, **kw: Any - ) -> Any: + def _wrap_existing_dispatch(element, compiler, **kw): try: return existing_dispatch(element, compiler, **kw) except exc.UnsupportedCompilationError as uce: @@ -550,7 +505,7 @@ def compiles(class_: Type[Any], *specs: str) -> Callable[[_F], _F]: return decorate -def deregister(class_: Type[Any]) -> None: +def deregister(class_): """Remove all custom compilers associated with a given :class:`_expression.ClauseElement` type. @@ -562,10 +517,10 @@ def deregister(class_: Type[Any]) -> None: class _dispatcher: - def __init__(self) -> None: - self.specs: Dict[str, Callable[..., Any]] = {} + def __init__(self): + self.specs = {} - def __call__(self, element: Any, compiler: SQLCompiler, **kw: Any) -> Any: + def __call__(self, element, compiler, **kw): # TODO: yes, this could also switch off of DBAPI in use. 
fn = self.specs.get(compiler.dialect.name, None) if not fn: diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/declarative/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/declarative/__init__.py index 0383f9d..2f6b2f2 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/declarative/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/declarative/__init__.py @@ -1,5 +1,5 @@ # ext/declarative/__init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/declarative/extensions.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/declarative/extensions.py index 3dc6bf6..acc9d08 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/declarative/extensions.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/declarative/extensions.py @@ -1,5 +1,5 @@ # ext/declarative/extensions.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -50,26 +50,23 @@ class ConcreteBase: from sqlalchemy.ext.declarative import ConcreteBase - class Employee(ConcreteBase, Base): - __tablename__ = "employee" + __tablename__ = 'employee' employee_id = Column(Integer, primary_key=True) name = Column(String(50)) __mapper_args__ = { - "polymorphic_identity": "employee", - "concrete": True, - } - + 'polymorphic_identity':'employee', + 'concrete':True} class Manager(Employee): - __tablename__ = "manager" + __tablename__ = 'manager' employee_id = Column(Integer, primary_key=True) name = Column(String(50)) manager_data = Column(String(40)) __mapper_args__ = { - "polymorphic_identity": "manager", - "concrete": True, - } + 'polymorphic_identity':'manager', + 'concrete':True} + The name of the discriminator column used by :func:`.polymorphic_union` defaults to the name ``type``. To suit the use case of a mapping where an @@ -78,7 +75,7 @@ class ConcreteBase: ``_concrete_discriminator_name`` attribute:: class Employee(ConcreteBase, Base): - _concrete_discriminator_name = "_concrete_discriminator" + _concrete_discriminator_name = '_concrete_discriminator' .. 
versionadded:: 1.3.19 Added the ``_concrete_discriminator_name`` attribute to :class:`_declarative.ConcreteBase` so that the @@ -171,27 +168,23 @@ class AbstractConcreteBase(ConcreteBase): from sqlalchemy.orm import DeclarativeBase from sqlalchemy.ext.declarative import AbstractConcreteBase - class Base(DeclarativeBase): pass - class Employee(AbstractConcreteBase, Base): pass - class Manager(Employee): - __tablename__ = "manager" + __tablename__ = 'manager' employee_id = Column(Integer, primary_key=True) name = Column(String(50)) manager_data = Column(String(40)) __mapper_args__ = { - "polymorphic_identity": "manager", - "concrete": True, + 'polymorphic_identity':'manager', + 'concrete':True } - Base.registry.configure() The abstract base class is handled by declarative in a special way; @@ -207,12 +200,10 @@ class AbstractConcreteBase(ConcreteBase): from sqlalchemy.ext.declarative import AbstractConcreteBase - class Company(Base): - __tablename__ = "company" + __tablename__ = 'company' id = Column(Integer, primary_key=True) - class Employee(AbstractConcreteBase, Base): strict_attrs = True @@ -220,31 +211,31 @@ class AbstractConcreteBase(ConcreteBase): @declared_attr def company_id(cls): - return Column(ForeignKey("company.id")) + return Column(ForeignKey('company.id')) @declared_attr def company(cls): return relationship("Company") - class Manager(Employee): - __tablename__ = "manager" + __tablename__ = 'manager' name = Column(String(50)) manager_data = Column(String(40)) __mapper_args__ = { - "polymorphic_identity": "manager", - "concrete": True, + 'polymorphic_identity':'manager', + 'concrete':True } - Base.registry.configure() When we make use of our mappings however, both ``Manager`` and ``Employee`` will have an independently usable ``.company`` attribute:: - session.execute(select(Employee).filter(Employee.company.has(id=5))) + session.execute( + select(Employee).filter(Employee.company.has(id=5)) + ) :param strict_attrs: when specified on the base class, "strict" attribute mode is enabled which attempts to limit ORM mapped attributes on the @@ -375,12 +366,10 @@ class DeferredReflection: from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import DeferredReflection - Base = declarative_base() - class MyClass(DeferredReflection, Base): - __tablename__ = "mytable" + __tablename__ = 'mytable' Above, ``MyClass`` is not yet mapped. After a series of classes have been defined in the above fashion, all tables @@ -402,22 +391,17 @@ class DeferredReflection: class ReflectedOne(DeferredReflection, Base): __abstract__ = True - class ReflectedTwo(DeferredReflection, Base): __abstract__ = True - class MyClass(ReflectedOne): - __tablename__ = "mytable" - + __tablename__ = 'mytable' class MyOtherClass(ReflectedOne): - __tablename__ = "myothertable" - + __tablename__ = 'myothertable' class YetAnotherClass(ReflectedTwo): - __tablename__ = "yetanothertable" - + __tablename__ = 'yetanothertable' # ... etc. 
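A minimal, self-contained sketch of how deferred classes like those in the ``DeferredReflection`` docstring above are typically completed with ``prepare()`` — the in-memory SQLite database, the ``mytable`` DDL, and the class names are illustrative assumptions, not taken from the diff::

    from sqlalchemy import create_engine, text
    from sqlalchemy.ext.declarative import DeferredReflection
    from sqlalchemy.orm import Session, declarative_base

    engine = create_engine("sqlite://")

    # create something to reflect; in real use the table already exists
    with engine.begin() as conn:
        conn.execute(
            text("CREATE TABLE mytable (id INTEGER PRIMARY KEY, name VARCHAR(50))")
        )

    Base = declarative_base()


    class Reflected(DeferredReflection, Base):
        __abstract__ = True


    class MyTable(Reflected):
        __tablename__ = "mytable"   # columns are not declared; they arrive via reflection


    # MyTable is not mapped until prepare() reflects the table and completes the mapping
    Reflected.prepare(engine)

    with Session(engine) as session:
        session.add(MyTable(name="demo"))
        session.commit()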
diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/horizontal_shard.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/horizontal_shard.py index 3ea3304..963bd00 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/horizontal_shard.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/horizontal_shard.py @@ -1,5 +1,5 @@ # ext/horizontal_shard.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -83,7 +83,8 @@ class ShardChooser(Protocol): mapper: Optional[Mapper[_T]], instance: Any, clause: Optional[ClauseElement], - ) -> Any: ... + ) -> Any: + ... class IdentityChooser(Protocol): @@ -96,7 +97,8 @@ class IdentityChooser(Protocol): execution_options: OrmExecuteOptionsParameter, bind_arguments: _BindArguments, **kw: Any, - ) -> Any: ... + ) -> Any: + ... class ShardedQuery(Query[_T]): @@ -125,9 +127,12 @@ class ShardedQuery(Query[_T]): The shard_id can be passed for a 2.0 style execution to the bind_arguments dictionary of :meth:`.Session.execute`:: - results = session.execute(stmt, bind_arguments={"shard_id": "my_shard"}) + results = session.execute( + stmt, + bind_arguments={"shard_id": "my_shard"} + ) - """ # noqa: E501 + """ return self.execution_options(_sa_shard_id=shard_id) @@ -318,7 +323,7 @@ class ShardedSession(Session): state.identity_token = shard_id return shard_id - def connection_callable( + def connection_callable( # type: ignore [override] self, mapper: Optional[Mapper[_T]] = None, instance: Optional[Any] = None, @@ -379,9 +384,9 @@ class set_shard_id(ORMOption): the :meth:`_sql.Executable.options` method of any executable statement:: stmt = ( - select(MyObject) - .where(MyObject.name == "some name") - .options(set_shard_id("shard1")) + select(MyObject). + where(MyObject.name == 'some name'). + options(set_shard_id("shard1")) ) Above, the statement when invoked will limit to the "shard1" shard diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/hybrid.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/hybrid.py index c1c46e7..615f166 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/hybrid.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/hybrid.py @@ -1,5 +1,5 @@ # ext/hybrid.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -34,9 +34,8 @@ may receive the class directly, depending on context:: class Base(DeclarativeBase): pass - class Interval(Base): - __tablename__ = "interval" + __tablename__ = 'interval' id: Mapped[int] = mapped_column(primary_key=True) start: Mapped[int] @@ -58,6 +57,7 @@ may receive the class directly, depending on context:: def intersects(self, other: Interval) -> bool: return self.contains(other.start) | self.contains(other.end) + Above, the ``length`` property returns the difference between the ``end`` and ``start`` attributes. With an instance of ``Interval``, this subtraction occurs in Python, using normal Python descriptor @@ -150,7 +150,6 @@ the absolute value function:: from sqlalchemy import func from sqlalchemy import type_coerce - class Interval(Base): # ... @@ -215,7 +214,6 @@ example below that illustrates the use of :meth:`.hybrid_property.setter` and # correct use, however is not accepted by pep-484 tooling - class Interval(Base): # ... 
@@ -258,7 +256,6 @@ a single decorator under one name:: # correct use which is also accepted by pep-484 tooling - class Interval(Base): # ... @@ -333,7 +330,6 @@ expression is used as the column that's the target of the SET. If our ``Interval.start``, this could be substituted directly:: from sqlalchemy import update - stmt = update(Interval).values({Interval.start_point: 10}) However, when using a composite hybrid like ``Interval.length``, this @@ -344,7 +340,6 @@ A handler that works similarly to our setter would be:: from typing import List, Tuple, Any - class Interval(Base): # ... @@ -357,10 +352,10 @@ A handler that works similarly to our setter would be:: self.end = self.start + value @length.inplace.update_expression - def _length_update_expression( - cls, value: Any - ) -> List[Tuple[Any, Any]]: - return [(cls.end, cls.start + value)] + def _length_update_expression(cls, value: Any) -> List[Tuple[Any, Any]]: + return [ + (cls.end, cls.start + value) + ] Above, if we use ``Interval.length`` in an UPDATE expression, we get a hybrid SET expression: @@ -417,16 +412,15 @@ mapping which relates a ``User`` to a ``SavingsAccount``:: class SavingsAccount(Base): - __tablename__ = "account" + __tablename__ = 'account' id: Mapped[int] = mapped_column(primary_key=True) - user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) + user_id: Mapped[int] = mapped_column(ForeignKey('user.id')) balance: Mapped[Decimal] = mapped_column(Numeric(15, 5)) owner: Mapped[User] = relationship(back_populates="accounts") - class User(Base): - __tablename__ = "user" + __tablename__ = 'user' id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(String(100)) @@ -454,10 +448,7 @@ mapping which relates a ``User`` to a ``SavingsAccount``:: @balance.inplace.expression @classmethod def _balance_expression(cls) -> SQLColumnExpression[Optional[Decimal]]: - return cast( - "SQLColumnExpression[Optional[Decimal]]", - SavingsAccount.balance, - ) + return cast("SQLColumnExpression[Optional[Decimal]]", SavingsAccount.balance) The above hybrid property ``balance`` works with the first ``SavingsAccount`` entry in the list of accounts for this user. The @@ -480,11 +471,8 @@ be used in an appropriate context such that an appropriate join to .. sourcecode:: pycon+sql >>> from sqlalchemy import select - >>> print( - ... select(User, User.balance) - ... .join(User.accounts) - ... .filter(User.balance > 5000) - ... ) + >>> print(select(User, User.balance). + ... join(User.accounts).filter(User.balance > 5000)) {printsql}SELECT "user".id AS user_id, "user".name AS user_name, account.balance AS account_balance FROM "user" JOIN account ON "user".id = account.user_id @@ -499,11 +487,8 @@ would use an outer join: >>> from sqlalchemy import select >>> from sqlalchemy import or_ - >>> print( - ... select(User, User.balance) - ... .outerjoin(User.accounts) - ... .filter(or_(User.balance < 5000, User.balance == None)) - ... ) + >>> print (select(User, User.balance).outerjoin(User.accounts). + ... 
filter(or_(User.balance < 5000, User.balance == None))) {printsql}SELECT "user".id AS user_id, "user".name AS user_name, account.balance AS account_balance FROM "user" LEFT OUTER JOIN account ON "user".id = account.user_id @@ -543,16 +528,15 @@ we can adjust our ``SavingsAccount`` example to aggregate the balances for class SavingsAccount(Base): - __tablename__ = "account" + __tablename__ = 'account' id: Mapped[int] = mapped_column(primary_key=True) - user_id: Mapped[int] = mapped_column(ForeignKey("user.id")) + user_id: Mapped[int] = mapped_column(ForeignKey('user.id')) balance: Mapped[Decimal] = mapped_column(Numeric(15, 5)) owner: Mapped[User] = relationship(back_populates="accounts") - class User(Base): - __tablename__ = "user" + __tablename__ = 'user' id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(String(100)) @@ -562,9 +546,7 @@ we can adjust our ``SavingsAccount`` example to aggregate the balances for @hybrid_property def balance(self) -> Decimal: - return sum( - (acc.balance for acc in self.accounts), start=Decimal("0") - ) + return sum((acc.balance for acc in self.accounts), start=Decimal("0")) @balance.inplace.expression @classmethod @@ -575,6 +557,7 @@ we can adjust our ``SavingsAccount`` example to aggregate the balances for .label("total_balance") ) + The above recipe will give us the ``balance`` column which renders a correlated SELECT: @@ -621,7 +604,6 @@ named ``word_insensitive``:: from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column - class Base(DeclarativeBase): pass @@ -630,9 +612,8 @@ named ``word_insensitive``:: def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 return func.lower(self.__clause_element__()) == func.lower(other) - class SearchWord(Base): - __tablename__ = "searchword" + __tablename__ = 'searchword' id: Mapped[int] = mapped_column(primary_key=True) word: Mapped[str] @@ -694,7 +675,6 @@ how the standard Python ``@property`` object works:: def _name_setter(self, value: str) -> None: self.first_name = value - class FirstNameLastName(FirstNameOnly): # ... @@ -704,11 +684,11 @@ how the standard Python ``@property`` object works:: # of FirstNameOnly.name that is local to FirstNameLastName @FirstNameOnly.name.getter def name(self) -> str: - return self.first_name + " " + self.last_name + return self.first_name + ' ' + self.last_name @name.inplace.setter def _name_setter(self, value: str) -> None: - self.first_name, self.last_name = value.split(" ", 1) + self.first_name, self.last_name = value.split(' ', 1) Above, the ``FirstNameLastName`` class refers to the hybrid from ``FirstNameOnly.name`` to repurpose its getter and setter for the subclass. @@ -729,7 +709,8 @@ reference the instrumented attribute back to the hybrid object:: @FirstNameOnly.name.overrides.expression @classmethod def name(cls): - return func.concat(cls.first_name, " ", cls.last_name) + return func.concat(cls.first_name, ' ', cls.last_name) + Hybrid Value Objects -------------------- @@ -770,7 +751,7 @@ Replacing the previous ``CaseInsensitiveComparator`` class with a new def __str__(self): return self.word - key = "word" + key = 'word' "Label to apply to Query tuple results" Above, the ``CaseInsensitiveWord`` object represents ``self.word``, which may @@ -781,7 +762,7 @@ SQL side or Python side. 
Our ``SearchWord`` class can now deliver the ``CaseInsensitiveWord`` object unconditionally from a single hybrid call:: class SearchWord(Base): - __tablename__ = "searchword" + __tablename__ = 'searchword' id: Mapped[int] = mapped_column(primary_key=True) word: Mapped[str] @@ -923,11 +904,13 @@ class HybridExtensionType(InspectionAttrExtensionType): class _HybridGetterType(Protocol[_T_co]): - def __call__(s, self: Any) -> _T_co: ... + def __call__(s, self: Any) -> _T_co: + ... class _HybridSetterType(Protocol[_T_con]): - def __call__(s, self: Any, value: _T_con) -> None: ... + def __call__(s, self: Any, value: _T_con) -> None: + ... class _HybridUpdaterType(Protocol[_T_con]): @@ -935,21 +918,25 @@ class _HybridUpdaterType(Protocol[_T_con]): s, cls: Any, value: Union[_T_con, _ColumnExpressionArgument[_T_con]], - ) -> List[Tuple[_DMLColumnArgument, Any]]: ... + ) -> List[Tuple[_DMLColumnArgument, Any]]: + ... class _HybridDeleterType(Protocol[_T_co]): - def __call__(s, self: Any) -> None: ... + def __call__(s, self: Any) -> None: + ... class _HybridExprCallableType(Protocol[_T_co]): def __call__( s, cls: Any - ) -> Union[_HasClauseElement[_T_co], SQLColumnExpression[_T_co]]: ... + ) -> Union[_HasClauseElement, SQLColumnExpression[_T_co]]: + ... class _HybridComparatorCallableType(Protocol[_T]): - def __call__(self, cls: Any) -> Comparator[_T]: ... + def __call__(self, cls: Any) -> Comparator[_T]: + ... class _HybridClassLevelAccessor(QueryableAttribute[_T]): @@ -960,24 +947,23 @@ class _HybridClassLevelAccessor(QueryableAttribute[_T]): if TYPE_CHECKING: - def getter( - self, fget: _HybridGetterType[_T] - ) -> hybrid_property[_T]: ... + def getter(self, fget: _HybridGetterType[_T]) -> hybrid_property[_T]: + ... - def setter( - self, fset: _HybridSetterType[_T] - ) -> hybrid_property[_T]: ... + def setter(self, fset: _HybridSetterType[_T]) -> hybrid_property[_T]: + ... - def deleter( - self, fdel: _HybridDeleterType[_T] - ) -> hybrid_property[_T]: ... + def deleter(self, fdel: _HybridDeleterType[_T]) -> hybrid_property[_T]: + ... @property - def overrides(self) -> hybrid_property[_T]: ... + def overrides(self) -> hybrid_property[_T]: + ... def update_expression( self, meth: _HybridUpdaterType[_T] - ) -> hybrid_property[_T]: ... + ) -> hybrid_property[_T]: + ... class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]): @@ -1002,7 +988,6 @@ class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]): from sqlalchemy.ext.hybrid import hybrid_method - class SomeClass: @hybrid_method def value(self, x, y): @@ -1040,12 +1025,14 @@ class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]): @overload def __get__( self, instance: Literal[None], owner: Type[object] - ) -> Callable[_P, SQLCoreOperations[_R]]: ... + ) -> Callable[_P, SQLCoreOperations[_R]]: + ... @overload def __get__( self, instance: object, owner: Type[object] - ) -> Callable[_P, _R]: ... + ) -> Callable[_P, _R]: + ... 
def __get__( self, instance: Optional[object], owner: Type[object] @@ -1100,7 +1087,6 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): from sqlalchemy.ext.hybrid import hybrid_property - class SomeClass: @hybrid_property def value(self): @@ -1117,18 +1103,21 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): self.expr = _unwrap_classmethod(expr) self.custom_comparator = _unwrap_classmethod(custom_comparator) self.update_expr = _unwrap_classmethod(update_expr) - util.update_wrapper(self, fget) # type: ignore[arg-type] + util.update_wrapper(self, fget) @overload - def __get__(self, instance: Any, owner: Literal[None]) -> Self: ... + def __get__(self, instance: Any, owner: Literal[None]) -> Self: + ... @overload def __get__( self, instance: Literal[None], owner: Type[object] - ) -> _HybridClassLevelAccessor[_T]: ... + ) -> _HybridClassLevelAccessor[_T]: + ... @overload - def __get__(self, instance: object, owner: Type[object]) -> _T: ... + def __get__(self, instance: object, owner: Type[object]) -> _T: + ... def __get__( self, instance: Optional[object], owner: Optional[Type[object]] @@ -1179,7 +1168,6 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): def foobar(self): return self._foobar - class SubClass(SuperClass): # ... @@ -1389,7 +1377,10 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): @fullname.update_expression def fullname(cls, value): fname, lname = value.split(" ", 1) - return [(cls.first_name, fname), (cls.last_name, lname)] + return [ + (cls.first_name, fname), + (cls.last_name, lname) + ] .. versionadded:: 1.2 @@ -1456,7 +1447,7 @@ class Comparator(interfaces.PropComparator[_T]): classes for usage with hybrids.""" def __init__( - self, expression: Union[_HasClauseElement[_T], SQLColumnExpression[_T]] + self, expression: Union[_HasClauseElement, SQLColumnExpression[_T]] ): self.expression = expression @@ -1491,7 +1482,7 @@ class ExprComparator(Comparator[_T]): def __init__( self, cls: Type[Any], - expression: Union[_HasClauseElement[_T], SQLColumnExpression[_T]], + expression: Union[_HasClauseElement, SQLColumnExpression[_T]], hybrid: hybrid_property[_T], ): self.cls = cls diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/indexable.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/indexable.py index 883d974..dbaad3c 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/indexable.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/indexable.py @@ -1,5 +1,5 @@ -# ext/indexable.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# ext/index.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -36,19 +36,19 @@ as a dedicated attribute which behaves like a standalone column:: Base = declarative_base() - class Person(Base): - __tablename__ = "person" + __tablename__ = 'person' id = Column(Integer, primary_key=True) data = Column(JSON) - name = index_property("data", "name") + name = index_property('data', 'name') + Above, the ``name`` attribute now behaves like a mapped column. 
We can compose a new ``Person`` and set the value of ``name``:: - >>> person = Person(name="Alchemist") + >>> person = Person(name='Alchemist') The value is now accessible:: @@ -59,11 +59,11 @@ Behind the scenes, the JSON field was initialized to a new blank dictionary and the field was set:: >>> person.data - {'name': 'Alchemist'} + {"name": "Alchemist'} The field is mutable in place:: - >>> person.name = "Renamed" + >>> person.name = 'Renamed' >>> person.name 'Renamed' >>> person.data @@ -87,17 +87,18 @@ A missing key will produce ``AttributeError``:: >>> person = Person() >>> person.name + ... AttributeError: 'name' Unless you set a default value:: >>> class Person(Base): - ... __tablename__ = "person" - ... - ... id = Column(Integer, primary_key=True) - ... data = Column(JSON) - ... - ... name = index_property("data", "name", default=None) # See default + >>> __tablename__ = 'person' + >>> + >>> id = Column(Integer, primary_key=True) + >>> data = Column(JSON) + >>> + >>> name = index_property('data', 'name', default=None) # See default >>> person = Person() >>> print(person.name) @@ -110,11 +111,11 @@ an indexed SQL criteria:: >>> from sqlalchemy.orm import Session >>> session = Session() - >>> query = session.query(Person).filter(Person.name == "Alchemist") + >>> query = session.query(Person).filter(Person.name == 'Alchemist') The above query is equivalent to:: - >>> query = session.query(Person).filter(Person.data["name"] == "Alchemist") + >>> query = session.query(Person).filter(Person.data['name'] == 'Alchemist') Multiple :class:`.index_property` objects can be chained to produce multiple levels of indexing:: @@ -125,25 +126,22 @@ multiple levels of indexing:: Base = declarative_base() - class Person(Base): - __tablename__ = "person" + __tablename__ = 'person' id = Column(Integer, primary_key=True) data = Column(JSON) - birthday = index_property("data", "birthday") - year = index_property("birthday", "year") - month = index_property("birthday", "month") - day = index_property("birthday", "day") + birthday = index_property('data', 'birthday') + year = index_property('birthday', 'year') + month = index_property('birthday', 'month') + day = index_property('birthday', 'day') Above, a query such as:: - q = session.query(Person).filter(Person.year == "1980") + q = session.query(Person).filter(Person.year == '1980') -On a PostgreSQL backend, the above query will render as: - -.. sourcecode:: sql +On a PostgreSQL backend, the above query will render as:: SELECT person.id, person.data FROM person @@ -200,14 +198,13 @@ version of :class:`_postgresql.JSON`:: Base = declarative_base() - class Person(Base): - __tablename__ = "person" + __tablename__ = 'person' id = Column(Integer, primary_key=True) data = Column(JSON) - age = pg_json_property("data", "age", Integer) + age = pg_json_property('data', 'age', Integer) The ``age`` attribute at the instance level works as before; however when rendering SQL, PostgreSQL's ``->>`` operator will be used @@ -215,9 +212,7 @@ for indexed access, instead of the usual index operator of ``->``:: >>> query = session.query(Person).filter(Person.age < 20) -The above query will render: - -.. 
sourcecode:: sql +The above query will render:: SELECT person.id, person.data FROM person diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/instrumentation.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/instrumentation.py index 8bb0198..688c762 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/instrumentation.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/instrumentation.py @@ -1,5 +1,5 @@ # ext/instrumentation.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -214,9 +214,9 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): )(instance) -orm_instrumentation._instrumentation_factory = _instrumentation_factory = ( - ExtendedInstrumentationRegistry() -) +orm_instrumentation._instrumentation_factory = ( + _instrumentation_factory +) = ExtendedInstrumentationRegistry() orm_instrumentation.instrumentation_finders = instrumentation_finders @@ -436,15 +436,17 @@ def _install_lookups(lookups): instance_dict = lookups["instance_dict"] manager_of_class = lookups["manager_of_class"] opt_manager_of_class = lookups["opt_manager_of_class"] - orm_base.instance_state = attributes.instance_state = ( - orm_instrumentation.instance_state - ) = instance_state - orm_base.instance_dict = attributes.instance_dict = ( - orm_instrumentation.instance_dict - ) = instance_dict - orm_base.manager_of_class = attributes.manager_of_class = ( - orm_instrumentation.manager_of_class - ) = manager_of_class - orm_base.opt_manager_of_class = orm_util.opt_manager_of_class = ( + orm_base.instance_state = ( + attributes.instance_state + ) = orm_instrumentation.instance_state = instance_state + orm_base.instance_dict = ( + attributes.instance_dict + ) = orm_instrumentation.instance_dict = instance_dict + orm_base.manager_of_class = ( + attributes.manager_of_class + ) = orm_instrumentation.manager_of_class = manager_of_class + orm_base.opt_manager_of_class = ( + orm_util.opt_manager_of_class + ) = ( attributes.opt_manager_of_class ) = orm_instrumentation.opt_manager_of_class = opt_manager_of_class diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/mutable.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/mutable.py index 3d568fc..0f82518 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/mutable.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/mutable.py @@ -1,5 +1,5 @@ # ext/mutable.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -21,7 +21,6 @@ JSON strings before being persisted:: from sqlalchemy.types import TypeDecorator, VARCHAR import json - class JSONEncodedDict(TypeDecorator): "Represents an immutable structure as a json-encoded string." @@ -49,7 +48,6 @@ the :class:`.Mutable` mixin to a plain Python dictionary:: from sqlalchemy.ext.mutable import Mutable - class MutableDict(Mutable, dict): @classmethod def coerce(cls, key, value): @@ -103,11 +101,9 @@ attribute. 
Such as, with classical table metadata:: from sqlalchemy import Table, Column, Integer - my_data = Table( - "my_data", - metadata, - Column("id", Integer, primary_key=True), - Column("data", MutableDict.as_mutable(JSONEncodedDict)), + my_data = Table('my_data', metadata, + Column('id', Integer, primary_key=True), + Column('data', MutableDict.as_mutable(JSONEncodedDict)) ) Above, :meth:`~.Mutable.as_mutable` returns an instance of ``JSONEncodedDict`` @@ -119,17 +115,13 @@ mapping against the ``my_data`` table:: from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column - class Base(DeclarativeBase): pass - class MyDataClass(Base): - __tablename__ = "my_data" + __tablename__ = 'my_data' id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[dict[str, str]] = mapped_column( - MutableDict.as_mutable(JSONEncodedDict) - ) + data: Mapped[dict[str, str]] = mapped_column(MutableDict.as_mutable(JSONEncodedDict)) The ``MyDataClass.data`` member will now be notified of in place changes to its value. @@ -140,11 +132,11 @@ will flag the attribute as "dirty" on the parent object:: >>> from sqlalchemy.orm import Session >>> sess = Session(some_engine) - >>> m1 = MyDataClass(data={"value1": "foo"}) + >>> m1 = MyDataClass(data={'value1':'foo'}) >>> sess.add(m1) >>> sess.commit() - >>> m1.data["value1"] = "bar" + >>> m1.data['value1'] = 'bar' >>> assert m1 in sess.dirty True @@ -161,16 +153,15 @@ the need to declare it individually:: MutableDict.associate_with(JSONEncodedDict) - class Base(DeclarativeBase): pass - class MyDataClass(Base): - __tablename__ = "my_data" + __tablename__ = 'my_data' id: Mapped[int] = mapped_column(primary_key=True) data: Mapped[dict[str, str]] = mapped_column(JSONEncodedDict) + Supporting Pickling -------------------- @@ -189,7 +180,7 @@ stream:: class MyMutableType(Mutable): def __getstate__(self): d = self.__dict__.copy() - d.pop("_parents", None) + d.pop('_parents', None) return d With our dictionary example, we need to return the contents of the dict itself @@ -222,18 +213,13 @@ from within the mutable extension:: from sqlalchemy.orm import mapped_column from sqlalchemy import event - class Base(DeclarativeBase): pass - class MyDataClass(Base): - __tablename__ = "my_data" + __tablename__ = 'my_data' id: Mapped[int] = mapped_column(primary_key=True) - data: Mapped[dict[str, str]] = mapped_column( - MutableDict.as_mutable(JSONEncodedDict) - ) - + data: Mapped[dict[str, str]] = mapped_column(MutableDict.as_mutable(JSONEncodedDict)) @event.listens_for(MyDataClass.data, "modified") def modified_json(instance, initiator): @@ -261,7 +247,6 @@ class introduced in :ref:`mapper_composite` to include import dataclasses from sqlalchemy.ext.mutable import MutableComposite - @dataclasses.dataclass class Point(MutableComposite): x: int @@ -276,6 +261,7 @@ class introduced in :ref:`mapper_composite` to include # alert all parents to the change self.changed() + The :class:`.MutableComposite` class makes use of class mapping events to automatically establish listeners for any usage of :func:`_orm.composite` that specifies our ``Point`` type. 
Below, when ``Point`` is mapped to the ``Vertex`` @@ -285,7 +271,6 @@ objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes:: from sqlalchemy.orm import DeclarativeBase, Mapped from sqlalchemy.orm import composite, mapped_column - class Base(DeclarativeBase): pass @@ -295,12 +280,8 @@ objects to each of the ``Vertex.start`` and ``Vertex.end`` attributes:: id: Mapped[int] = mapped_column(primary_key=True) - start: Mapped[Point] = composite( - mapped_column("x1"), mapped_column("y1") - ) - end: Mapped[Point] = composite( - mapped_column("x2"), mapped_column("y2") - ) + start: Mapped[Point] = composite(mapped_column("x1"), mapped_column("y1")) + end: Mapped[Point] = composite(mapped_column("x2"), mapped_column("y2")) def __repr__(self): return f"Vertex(start={self.start}, end={self.end})" @@ -397,7 +378,6 @@ from weakref import WeakKeyDictionary from .. import event from .. import inspect from .. import types -from .. import util from ..orm import Mapper from ..orm._typing import _ExternalEntityType from ..orm._typing import _O @@ -410,7 +390,6 @@ from ..orm.context import QueryContext from ..orm.decl_api import DeclarativeAttributeIntercept from ..orm.state import InstanceState from ..orm.unitofwork import UOWTransaction -from ..sql._typing import _TypeEngineArgument from ..sql.base import SchemaEventTarget from ..sql.schema import Column from ..sql.type_api import TypeEngine @@ -524,7 +503,6 @@ class MutableBase: if val is not None: if coerce: val = cls.coerce(key, val) - assert val is not None state.dict[key] = val val._parents[state] = key @@ -659,7 +637,7 @@ class Mutable(MutableBase): event.listen(Mapper, "mapper_configured", listen_for_type) @classmethod - def as_mutable(cls, sqltype: _TypeEngineArgument[_T]) -> TypeEngine[_T]: + def as_mutable(cls, sqltype: TypeEngine[_T]) -> TypeEngine[_T]: """Associate a SQL type with this mutable Python type. This establishes listeners that will detect ORM mappings against @@ -668,11 +646,9 @@ class Mutable(MutableBase): The type is returned, unconditionally as an instance, so that :meth:`.as_mutable` can be used inline:: - Table( - "mytable", - metadata, - Column("id", Integer, primary_key=True), - Column("data", MyMutableType.as_mutable(PickleType)), + Table('mytable', metadata, + Column('id', Integer, primary_key=True), + Column('data', MyMutableType.as_mutable(PickleType)) ) Note that the returned type is always an instance, even if a class @@ -823,12 +799,15 @@ class MutableDict(Mutable, Dict[_KT, _VT]): @overload def setdefault( self: MutableDict[_KT, Optional[_T]], key: _KT, value: None = None - ) -> Optional[_T]: ... + ) -> Optional[_T]: + ... @overload - def setdefault(self, key: _KT, value: _VT) -> _VT: ... + def setdefault(self, key: _KT, value: _VT) -> _VT: + ... - def setdefault(self, key: _KT, value: object = None) -> object: ... + def setdefault(self, key: _KT, value: object = None) -> object: + ... else: @@ -849,14 +828,17 @@ class MutableDict(Mutable, Dict[_KT, _VT]): if TYPE_CHECKING: @overload - def pop(self, __key: _KT) -> _VT: ... + def pop(self, __key: _KT) -> _VT: + ... @overload - def pop(self, __key: _KT, __default: _VT | _T) -> _VT | _T: ... + def pop(self, __key: _KT, __default: _VT | _T) -> _VT | _T: + ... def pop( self, __key: _KT, __default: _VT | _T | None = None - ) -> _VT | _T: ... + ) -> _VT | _T: + ... 
else: @@ -927,10 +909,10 @@ class MutableList(Mutable, List[_T]): self[:] = state def is_scalar(self, value: _T | Iterable[_T]) -> TypeGuard[_T]: - return not util.is_non_string_iterable(value) + return not isinstance(value, Iterable) def is_iterable(self, value: _T | Iterable[_T]) -> TypeGuard[Iterable[_T]]: - return util.is_non_string_iterable(value) + return isinstance(value, Iterable) def __setitem__( self, index: SupportsIndex | slice, value: _T | Iterable[_T] diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/__init__.py index b5827cb..e69de29 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/__init__.py @@ -1,6 +0,0 @@ -# ext/mypy/__init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors -# -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/apply.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/apply.py index 02908cc..1bfaf1d 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/apply.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/apply.py @@ -1,5 +1,5 @@ # ext/mypy/apply.py -# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2021 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -161,9 +161,9 @@ def re_apply_declarative_assignments( # update the SQLAlchemyAttribute with the better # information - mapped_attr_lookup[stmt.lvalues[0].name].type = ( - python_type_for_type - ) + mapped_attr_lookup[ + stmt.lvalues[0].name + ].type = python_type_for_type update_cls_metadata = True @@ -199,15 +199,11 @@ def apply_type_to_mapped_statement( To one that describes the final Python behavior to Mypy:: - ... format: off - class User(Base): # ... attrname : Mapped[Optional[int]] = - ... 
format: on - """ left_node = lvalue.node assert isinstance(left_node, Var) @@ -227,11 +223,9 @@ def apply_type_to_mapped_statement( lvalue.is_inferred_def = False left_node.type = api.named_type( NAMED_TYPE_SQLA_MAPPED, - ( - [AnyType(TypeOfAny.special_form)] - if python_type_for_type is None - else [python_type_for_type] - ), + [AnyType(TypeOfAny.special_form)] + if python_type_for_type is None + else [python_type_for_type], ) # so to have it skip the right side totally, we can do this: diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/decl_class.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/decl_class.py index 2ce7ad5..9c7b44b 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/decl_class.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/decl_class.py @@ -1,5 +1,5 @@ # ext/mypy/decl_class.py -# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2021 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -58,9 +58,9 @@ def scan_declarative_assignments_and_apply_types( elif cls.fullname.startswith("builtins"): return None - mapped_attributes: Optional[List[util.SQLAlchemyAttribute]] = ( - util.get_mapped_attributes(info, api) - ) + mapped_attributes: Optional[ + List[util.SQLAlchemyAttribute] + ] = util.get_mapped_attributes(info, api) # used by assign.add_additional_orm_attributes among others util.establish_as_sqlalchemy(info) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/infer.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/infer.py index 26a83cc..e8345d0 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/infer.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/infer.py @@ -1,5 +1,5 @@ # ext/mypy/infer.py -# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2021 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -385,9 +385,9 @@ def _infer_type_from_decl_column( class MyClass: # ... 
- a: Mapped[int] + a : Mapped[int] - b: Mapped[str] + b : Mapped[str] c: Mapped[int] diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/names.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/names.py index 1eaef77..ae55ca4 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/names.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/names.py @@ -1,5 +1,5 @@ # ext/mypy/names.py -# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2021 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -58,14 +58,6 @@ NAMED_TYPE_BUILTINS_STR = "builtins.str" NAMED_TYPE_BUILTINS_LIST = "builtins.list" NAMED_TYPE_SQLA_MAPPED = "sqlalchemy.orm.base.Mapped" -_RelFullNames = { - "sqlalchemy.orm.relationships.Relationship", - "sqlalchemy.orm.relationships.RelationshipProperty", - "sqlalchemy.orm.relationships._RelationshipDeclared", - "sqlalchemy.orm.Relationship", - "sqlalchemy.orm.RelationshipProperty", -} - _lookup: Dict[str, Tuple[int, Set[str]]] = { "Column": ( COLUMN, @@ -74,9 +66,24 @@ _lookup: Dict[str, Tuple[int, Set[str]]] = { "sqlalchemy.sql.Column", }, ), - "Relationship": (RELATIONSHIP, _RelFullNames), - "RelationshipProperty": (RELATIONSHIP, _RelFullNames), - "_RelationshipDeclared": (RELATIONSHIP, _RelFullNames), + "Relationship": ( + RELATIONSHIP, + { + "sqlalchemy.orm.relationships.Relationship", + "sqlalchemy.orm.relationships.RelationshipProperty", + "sqlalchemy.orm.Relationship", + "sqlalchemy.orm.RelationshipProperty", + }, + ), + "RelationshipProperty": ( + RELATIONSHIP, + { + "sqlalchemy.orm.relationships.Relationship", + "sqlalchemy.orm.relationships.RelationshipProperty", + "sqlalchemy.orm.Relationship", + "sqlalchemy.orm.RelationshipProperty", + }, + ), "registry": ( REGISTRY, { @@ -297,7 +304,7 @@ def type_id_for_callee(callee: Expression) -> Optional[int]: def type_id_for_named_node( - node: Union[NameExpr, MemberExpr, SymbolNode], + node: Union[NameExpr, MemberExpr, SymbolNode] ) -> Optional[int]: type_id, fullnames = _lookup.get(node.name, (None, None)) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/plugin.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/plugin.py index 1ec2c02..862d7d2 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/plugin.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/plugin.py @@ -1,5 +1,5 @@ # ext/mypy/plugin.py -# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2021-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/util.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/util.py index 16761b9..238c82a 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/util.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/mypy/util.py @@ -1,5 +1,5 @@ # ext/mypy/util.py -# Copyright (C) 2021-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2021-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -80,7 +80,7 @@ class SQLAlchemyAttribute: "name": self.name, "line": self.line, "column": self.column, - "type": serialize_type(self.type), + "type": self.type.serialize(), } def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: @@ -212,7 +212,8 @@ def add_global( @overload def get_callexpr_kwarg( callexpr: CallExpr, name: 
str, *, expr_types: None = ... -) -> Optional[Union[CallExpr, NameExpr]]: ... +) -> Optional[Union[CallExpr, NameExpr]]: + ... @overload @@ -221,7 +222,8 @@ def get_callexpr_kwarg( name: str, *, expr_types: Tuple[TypingType[_TArgType], ...], -) -> Optional[_TArgType]: ... +) -> Optional[_TArgType]: + ... def get_callexpr_kwarg( @@ -313,11 +315,9 @@ def unbound_to_instance( return Instance( bound_type, [ - ( - unbound_to_instance(api, arg) - if isinstance(arg, UnboundType) - else arg - ) + unbound_to_instance(api, arg) + if isinstance(arg, UnboundType) + else arg for arg in typ.args ], ) @@ -336,22 +336,3 @@ def info_for_cls( return sym.node return cls.info - - -def serialize_type(typ: Type) -> Union[str, JsonDict]: - try: - return typ.serialize() - except Exception: - pass - if hasattr(typ, "args"): - typ.args = tuple( - ( - a.resolve_string_annotation() - if hasattr(a, "resolve_string_annotation") - else a - ) - for a in typ.args - ) - elif hasattr(typ, "resolve_string_annotation"): - typ = typ.resolve_string_annotation() - return typ.serialize() diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/orderinglist.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/orderinglist.py index a39d216..a6c42ff 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/orderinglist.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/orderinglist.py @@ -1,9 +1,10 @@ # ext/orderinglist.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors """A custom list that manages index/position information for contained elements. @@ -25,20 +26,18 @@ displayed in order based on the value of the ``position`` column in the Base = declarative_base() - class Slide(Base): - __tablename__ = "slide" + __tablename__ = 'slide' id = Column(Integer, primary_key=True) name = Column(String) bullets = relationship("Bullet", order_by="Bullet.position") - class Bullet(Base): - __tablename__ = "bullet" + __tablename__ = 'bullet' id = Column(Integer, primary_key=True) - slide_id = Column(Integer, ForeignKey("slide.id")) + slide_id = Column(Integer, ForeignKey('slide.id')) position = Column(Integer) text = Column(String) @@ -58,24 +57,19 @@ constructed using the :func:`.ordering_list` factory:: Base = declarative_base() - class Slide(Base): - __tablename__ = "slide" + __tablename__ = 'slide' id = Column(Integer, primary_key=True) name = Column(String) - bullets = relationship( - "Bullet", - order_by="Bullet.position", - collection_class=ordering_list("position"), - ) - + bullets = relationship("Bullet", order_by="Bullet.position", + collection_class=ordering_list('position')) class Bullet(Base): - __tablename__ = "bullet" + __tablename__ = 'bullet' id = Column(Integer, primary_key=True) - slide_id = Column(Integer, ForeignKey("slide.id")) + slide_id = Column(Integer, ForeignKey('slide.id')) position = Column(Integer) text = Column(String) @@ -128,24 +122,17 @@ start numbering at 1 or some other integer, provide ``count_from=1``. 
""" from __future__ import annotations -from typing import Any from typing import Callable -from typing import Dict -from typing import Iterable from typing import List from typing import Optional -from typing import overload from typing import Sequence -from typing import Type from typing import TypeVar -from typing import Union from ..orm.collections import collection from ..orm.collections import collection_adapter -from ..util.typing import SupportsIndex _T = TypeVar("_T") -OrderingFunc = Callable[[int, Sequence[_T]], object] +OrderingFunc = Callable[[int, Sequence[_T]], int] __all__ = ["ordering_list"] @@ -154,9 +141,9 @@ __all__ = ["ordering_list"] def ordering_list( attr: str, count_from: Optional[int] = None, - ordering_func: Optional[OrderingFunc[_T]] = None, + ordering_func: Optional[OrderingFunc] = None, reorder_on_append: bool = False, -) -> Callable[[], OrderingList[_T]]: +) -> Callable[[], OrderingList]: """Prepares an :class:`OrderingList` factory for use in mapper definitions. Returns an object suitable for use as an argument to a Mapper @@ -164,18 +151,14 @@ def ordering_list( from sqlalchemy.ext.orderinglist import ordering_list - class Slide(Base): - __tablename__ = "slide" + __tablename__ = 'slide' id = Column(Integer, primary_key=True) name = Column(String) - bullets = relationship( - "Bullet", - order_by="Bullet.position", - collection_class=ordering_list("position"), - ) + bullets = relationship("Bullet", order_by="Bullet.position", + collection_class=ordering_list('position')) :param attr: Name of the mapped attribute to use for storage and retrieval of @@ -202,22 +185,22 @@ def ordering_list( # Ordering utility functions -def count_from_0(index: int, collection: object) -> int: +def count_from_0(index, collection): """Numbering function: consecutive integers starting at 0.""" return index -def count_from_1(index: int, collection: object) -> int: +def count_from_1(index, collection): """Numbering function: consecutive integers starting at 1.""" return index + 1 -def count_from_n_factory(start: int) -> OrderingFunc[Any]: +def count_from_n_factory(start): """Numbering function: consecutive integers starting at arbitrary start.""" - def f(index: int, collection: object) -> int: + def f(index, collection): return index + start try: @@ -227,7 +210,7 @@ def count_from_n_factory(start: int) -> OrderingFunc[Any]: return f -def _unsugar_count_from(**kw: Any) -> Dict[str, Any]: +def _unsugar_count_from(**kw): """Builds counting functions from keyword arguments. Keyword argument filter, prepares a simple ``ordering_func`` from a @@ -255,13 +238,13 @@ class OrderingList(List[_T]): """ ordering_attr: str - ordering_func: OrderingFunc[_T] + ordering_func: OrderingFunc reorder_on_append: bool def __init__( self, - ordering_attr: str, - ordering_func: Optional[OrderingFunc[_T]] = None, + ordering_attr: Optional[str] = None, + ordering_func: Optional[OrderingFunc] = None, reorder_on_append: bool = False, ): """A custom list that manages position information for its children. @@ -321,10 +304,10 @@ class OrderingList(List[_T]): # More complex serialization schemes (multi column, e.g.) are possible by # subclassing and reimplementing these two methods. 
- def _get_order_value(self, entity: _T) -> Any: + def _get_order_value(self, entity): return getattr(entity, self.ordering_attr) - def _set_order_value(self, entity: _T, value: Any) -> None: + def _set_order_value(self, entity, value): setattr(entity, self.ordering_attr, value) def reorder(self) -> None: @@ -340,9 +323,7 @@ class OrderingList(List[_T]): # As of 0.5, _reorder is no longer semi-private _reorder = reorder - def _order_entity( - self, index: int, entity: _T, reorder: bool = True - ) -> None: + def _order_entity(self, index, entity, reorder=True): have = self._get_order_value(entity) # Don't disturb existing ordering if reorder is False @@ -353,44 +334,34 @@ class OrderingList(List[_T]): if have != should_be: self._set_order_value(entity, should_be) - def append(self, entity: _T) -> None: + def append(self, entity): super().append(entity) self._order_entity(len(self) - 1, entity, self.reorder_on_append) - def _raw_append(self, entity: _T) -> None: + def _raw_append(self, entity): """Append without any ordering behavior.""" super().append(entity) _raw_append = collection.adds(1)(_raw_append) - def insert(self, index: SupportsIndex, entity: _T) -> None: + def insert(self, index, entity): super().insert(index, entity) self._reorder() - def remove(self, entity: _T) -> None: + def remove(self, entity): super().remove(entity) adapter = collection_adapter(self) if adapter and adapter._referenced_by_owner: self._reorder() - def pop(self, index: SupportsIndex = -1) -> _T: + def pop(self, index=-1): entity = super().pop(index) self._reorder() return entity - @overload - def __setitem__(self, index: SupportsIndex, entity: _T) -> None: ... - - @overload - def __setitem__(self, index: slice, entity: Iterable[_T]) -> None: ... - - def __setitem__( - self, - index: Union[SupportsIndex, slice], - entity: Union[_T, Iterable[_T]], - ) -> None: + def __setitem__(self, index, entity): if isinstance(index, slice): step = index.step or 1 start = index.start or 0 @@ -399,18 +370,26 @@ class OrderingList(List[_T]): stop = index.stop or len(self) if stop < 0: stop += len(self) - entities = list(entity) # type: ignore[arg-type] - for i in range(start, stop, step): - self.__setitem__(i, entities[i]) - else: - self._order_entity(int(index), entity, True) # type: ignore[arg-type] # noqa: E501 - super().__setitem__(index, entity) # type: ignore[assignment] - def __delitem__(self, index: Union[SupportsIndex, slice]) -> None: + for i in range(start, stop, step): + self.__setitem__(i, entity[i]) + else: + self._order_entity(index, entity, True) + super().__setitem__(index, entity) + + def __delitem__(self, index): super().__delitem__(index) self._reorder() - def __reduce__(self) -> Any: + def __setslice__(self, start, end, values): + super().__setslice__(start, end, values) + self._reorder() + + def __delslice__(self, start, end): + super().__delslice__(start, end) + self._reorder() + + def __reduce__(self): return _reconstitute, (self.__class__, self.__dict__, list(self)) for func_name, func in list(locals().items()): @@ -424,9 +403,7 @@ class OrderingList(List[_T]): del func_name, func -def _reconstitute( - cls: Type[OrderingList[_T]], dict_: Dict[str, Any], items: List[_T] -) -> OrderingList[_T]: +def _reconstitute(cls, dict_, items): """Reconstitute an :class:`.OrderingList`. This is the adjoint to :meth:`.OrderingList.__reduce__`. 
It is used for diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/ext/serializer.py b/venv/lib/python3.12/site-packages/sqlalchemy/ext/serializer.py index b7032b6..706bff2 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/ext/serializer.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/ext/serializer.py @@ -1,5 +1,5 @@ # ext/serializer.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -28,17 +28,13 @@ when it is deserialized. Usage is nearly the same as that of the standard Python pickle module:: from sqlalchemy.ext.serializer import loads, dumps - metadata = MetaData(bind=some_engine) Session = scoped_session(sessionmaker()) # ... define mappers - query = ( - Session.query(MyClass) - .filter(MyClass.somedata == "foo") - .order_by(MyClass.sortkey) - ) + query = Session.query(MyClass). + filter(MyClass.somedata=='foo').order_by(MyClass.sortkey) # pickle the query serialized = dumps(query) @@ -46,7 +42,7 @@ Usage is nearly the same as that of the standard Python pickle module:: # unpickle. Pass in metadata + scoped_session query2 = loads(serialized, metadata, Session) - print(query2.all()) + print query2.all() Similar restrictions as when using raw pickle apply; mapped classes must be themselves be pickleable, meaning they are importable from a module-level @@ -86,9 +82,10 @@ from ..util import b64encode __all__ = ["Serializer", "Deserializer", "dumps", "loads"] -class Serializer(pickle.Pickler): +def Serializer(*args, **kw): + pickler = pickle.Pickler(*args, **kw) - def persistent_id(self, obj): + def persistent_id(obj): # print "serializing:", repr(obj) if isinstance(obj, Mapper) and not obj.non_primary: id_ = "mapper:" + b64encode(pickle.dumps(obj.class_)) @@ -116,6 +113,9 @@ class Serializer(pickle.Pickler): return None return id_ + pickler.persistent_id = persistent_id + return pickler + our_ids = re.compile( r"(mapperprop|mapper|mapper_selectable|table|column|" @@ -123,23 +123,20 @@ our_ids = re.compile( ) -class Deserializer(pickle.Unpickler): +def Deserializer(file, metadata=None, scoped_session=None, engine=None): + unpickler = pickle.Unpickler(file) - def __init__(self, file, metadata=None, scoped_session=None, engine=None): - super().__init__(file) - self.metadata = metadata - self.scoped_session = scoped_session - self.engine = engine - - def get_engine(self): - if self.engine: - return self.engine - elif self.scoped_session and self.scoped_session().bind: - return self.scoped_session().bind + def get_engine(): + if engine: + return engine + elif scoped_session and scoped_session().bind: + return scoped_session().bind + elif metadata and metadata.bind: + return metadata.bind else: return None - def persistent_load(self, id_): + def persistent_load(id_): m = our_ids.match(str(id_)) if not m: return None @@ -160,17 +157,20 @@ class Deserializer(pickle.Unpickler): cls = pickle.loads(b64decode(mapper)) return class_mapper(cls).attrs[keyname] elif type_ == "table": - return self.metadata.tables[args] + return metadata.tables[args] elif type_ == "column": table, colname = args.split(":") - return self.metadata.tables[table].c[colname] + return metadata.tables[table].c[colname] elif type_ == "session": - return self.scoped_session() + return scoped_session() elif type_ == "engine": - return self.get_engine() + return get_engine() else: raise Exception("Unknown token: %s" % type_) + unpickler.persistent_load = 
persistent_load + return unpickler + def dumps(obj, protocol=pickle.HIGHEST_PROTOCOL): buf = BytesIO() diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/future/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/future/__init__.py index ef9afb1..bfc31d4 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/future/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/future/__init__.py @@ -1,5 +1,5 @@ -# future/__init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sql/future/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/future/engine.py b/venv/lib/python3.12/site-packages/sqlalchemy/future/engine.py index 0449c3d..1984f34 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/future/engine.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/future/engine.py @@ -1,5 +1,5 @@ -# future/engine.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sql/future/engine.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/inspection.py b/venv/lib/python3.12/site-packages/sqlalchemy/inspection.py index 2e5b220..7d8479b 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/inspection.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/inspection.py @@ -1,5 +1,5 @@ -# inspection.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sqlalchemy/inspect.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -74,7 +74,8 @@ class _InspectableTypeProtocol(Protocol[_TCov]): """ - def _sa_inspect_type(self) -> _TCov: ... + def _sa_inspect_type(self) -> _TCov: + ... class _InspectableProtocol(Protocol[_TCov]): @@ -83,31 +84,35 @@ class _InspectableProtocol(Protocol[_TCov]): """ - def _sa_inspect_instance(self) -> _TCov: ... + def _sa_inspect_instance(self) -> _TCov: + ... @overload def inspect( subject: Type[_InspectableTypeProtocol[_IN]], raiseerr: bool = True -) -> _IN: ... +) -> _IN: + ... @overload -def inspect( - subject: _InspectableProtocol[_IN], raiseerr: bool = True -) -> _IN: ... +def inspect(subject: _InspectableProtocol[_IN], raiseerr: bool = True) -> _IN: + ... @overload -def inspect(subject: Inspectable[_IN], raiseerr: bool = True) -> _IN: ... +def inspect(subject: Inspectable[_IN], raiseerr: bool = True) -> _IN: + ... @overload -def inspect(subject: Any, raiseerr: Literal[False] = ...) -> Optional[Any]: ... +def inspect(subject: Any, raiseerr: Literal[False] = ...) -> Optional[Any]: + ... @overload -def inspect(subject: Any, raiseerr: bool = True) -> Any: ... +def inspect(subject: Any, raiseerr: bool = True) -> Any: + ... 
def inspect(subject: Any, raiseerr: bool = True) -> Any: @@ -157,7 +162,9 @@ def _inspects( def decorate(fn_or_cls: _F) -> _F: for type_ in types: if type_ in _registrars: - raise AssertionError("Type %s is already registered" % type_) + raise AssertionError( + "Type %s is already " "registered" % type_ + ) _registrars[type_] = fn_or_cls return fn_or_cls @@ -169,6 +176,6 @@ _TT = TypeVar("_TT", bound="Type[Any]") def _self_inspects(cls: _TT) -> _TT: if cls in _registrars: - raise AssertionError("Type %s is already registered" % cls) + raise AssertionError("Type %s is already " "registered" % cls) _registrars[cls] = True return cls diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/log.py b/venv/lib/python3.12/site-packages/sqlalchemy/log.py index 849a0bf..8de6d18 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/log.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/log.py @@ -1,5 +1,5 @@ -# log.py -# Copyright (C) 2006-2025 the SQLAlchemy authors and contributors +# sqlalchemy/log.py +# Copyright (C) 2006-2023 the SQLAlchemy authors and contributors # # Includes alterations by Vinay Sajip vinay_sajip@yahoo.co.uk # @@ -269,12 +269,14 @@ class echo_property: @overload def __get__( self, instance: Literal[None], owner: Type[Identified] - ) -> echo_property: ... + ) -> echo_property: + ... @overload def __get__( self, instance: Identified, owner: Type[Identified] - ) -> _EchoFlagType: ... + ) -> _EchoFlagType: + ... def __get__( self, instance: Optional[Identified], owner: Type[Identified] diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/__init__.py index 7771de4..f6888ae 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/__init__.py @@ -1,5 +1,5 @@ # orm/__init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/_orm_constructors.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/_orm_constructors.py index 9c07bf1..df36c38 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/_orm_constructors.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/_orm_constructors.py @@ -1,5 +1,5 @@ # orm/_orm_constructors.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -12,7 +12,6 @@ from typing import Any from typing import Callable from typing import Collection from typing import Iterable -from typing import Mapping from typing import NoReturn from typing import Optional from typing import overload @@ -29,8 +28,8 @@ from .properties import MappedColumn from .properties import MappedSQLExpression from .query import AliasOption from .relationships import _RelationshipArgumentType -from .relationships import _RelationshipDeclared from .relationships import _RelationshipSecondaryArgument +from .relationships import Relationship from .relationships import RelationshipProperty from .session import Session from .util import _ORMJoin @@ -71,7 +70,7 @@ if TYPE_CHECKING: from ..sql._typing import _TypeEngineArgument from ..sql.elements import ColumnElement from ..sql.schema import _ServerDefaultArgument - from ..sql.schema import _ServerOnUpdateArgument + from ..sql.schema 
import FetchedValue from ..sql.selectable import Alias from ..sql.selectable import Subquery @@ -109,7 +108,6 @@ def mapped_column( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, - hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 nullable: Optional[ Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] ] = SchemaConst.NULL_UNSPECIFIED, @@ -129,13 +127,12 @@ def mapped_column( onupdate: Optional[Any] = None, insert_default: Optional[Any] = _NoArg.NO_ARG, server_default: Optional[_ServerDefaultArgument] = None, - server_onupdate: Optional[_ServerOnUpdateArgument] = None, + server_onupdate: Optional[FetchedValue] = None, active_history: bool = False, quote: Optional[bool] = None, system: bool = False, comment: Optional[str] = None, sort_order: Union[_NoArg, int] = _NoArg.NO_ARG, - dataclass_metadata: Union[_NoArg, Mapping[Any, Any], None] = _NoArg.NO_ARG, **kw: Any, ) -> MappedColumn[Any]: r"""declare a new ORM-mapped :class:`_schema.Column` construct @@ -189,9 +186,9 @@ def mapped_column( :class:`_schema.Column`. :param nullable: Optional bool, whether the column should be "NULL" or "NOT NULL". If omitted, the nullability is derived from the type - annotation based on whether or not ``typing.Optional`` (or its equivalent) - is present. ``nullable`` defaults to ``True`` otherwise for non-primary - key columns, and ``False`` for primary key columns. + annotation based on whether or not ``typing.Optional`` is present. + ``nullable`` defaults to ``True`` otherwise for non-primary key columns, + and ``False`` for primary key columns. :param primary_key: optional bool, indicates the :class:`_schema.Column` would be part of the table's primary key or not. :param deferred: Optional bool - this keyword argument is consumed by the @@ -258,28 +255,12 @@ def mapped_column( be used instead**. This is necessary to disambiguate the callable from being interpreted as a dataclass level default. - .. seealso:: - - :ref:`defaults_default_factory_insert_default` - - :paramref:`_orm.mapped_column.insert_default` - - :paramref:`_orm.mapped_column.default_factory` - :param insert_default: Passed directly to the :paramref:`_schema.Column.default` parameter; will supersede the value of :paramref:`_orm.mapped_column.default` when present, however :paramref:`_orm.mapped_column.default` will always apply to the constructor default for a dataclasses mapping. - .. seealso:: - - :ref:`defaults_default_factory_insert_default` - - :paramref:`_orm.mapped_column.default` - - :paramref:`_orm.mapped_column.default_factory` - :param sort_order: An integer that indicates how this mapped column should be sorted compared to the others when the ORM is creating a :class:`_schema.Table`. Among mapped columns that have the same @@ -314,15 +295,6 @@ def mapped_column( specifies a default-value generation function that will take place as part of the ``__init__()`` method as generated by the dataclass process. - - .. seealso:: - - :ref:`defaults_default_factory_insert_default` - - :paramref:`_orm.mapped_column.default` - - :paramref:`_orm.mapped_column.insert_default` - :param compare: Specific to :ref:`orm_declarative_native_dataclasses`, indicates if this field should be included in comparison operations when generating the @@ -334,19 +306,6 @@ def mapped_column( :ref:`orm_declarative_native_dataclasses`, indicates if this field should be marked as keyword-only when generating the ``__init__()``. 
- :param hash: Specific to - :ref:`orm_declarative_native_dataclasses`, controls if this field - is included when generating the ``__hash__()`` method for the mapped - class. - - .. versionadded:: 2.0.36 - - :param dataclass_metadata: Specific to - :ref:`orm_declarative_native_dataclasses`, supplies metadata - to be attached to the generated dataclass field. - - .. versionadded:: 2.0.42 - :param \**kw: All remaining keyword arguments are passed through to the constructor for the :class:`_schema.Column`. @@ -361,14 +320,7 @@ def mapped_column( autoincrement=autoincrement, insert_default=insert_default, attribute_options=_AttributeOptions( - init, - repr, - default, - default_factory, - compare, - kw_only, - hash, - dataclass_metadata, + init, repr, default, default_factory, compare, kw_only ), doc=doc, key=key, @@ -433,9 +385,9 @@ def orm_insert_sentinel( return mapped_column( name=name, - default=( - default if default is not None else _InsertSentinelColumnDefault() - ), + default=default + if default is not None + else _InsertSentinelColumnDefault(), _omit_from_statements=omit_from_statements, insert_sentinel=True, use_existing_column=True, @@ -463,18 +415,16 @@ def column_property( deferred: bool = False, raiseload: bool = False, comparator_factory: Optional[Type[PropComparator[_T]]] = None, - init: Union[_NoArg, bool] = _NoArg.NO_ARG, + init: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 repr: Union[_NoArg, bool] = _NoArg.NO_ARG, # noqa: A002 default: Optional[Any] = _NoArg.NO_ARG, default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, - hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 active_history: bool = False, expire_on_flush: bool = True, info: Optional[_InfoType] = None, doc: Optional[str] = None, - dataclass_metadata: Union[_NoArg, Mapping[Any, Any], None] = _NoArg.NO_ARG, ) -> MappedSQLExpression[_T]: r"""Provide a column-level property for use with a mapping. @@ -559,49 +509,13 @@ def column_property( :ref:`orm_queryguide_deferred_raiseload` - :param init: Specific to :ref:`orm_declarative_native_dataclasses`, - specifies if the mapped attribute should be part of the ``__init__()`` - method as generated by the dataclass process. - :param repr: Specific to :ref:`orm_declarative_native_dataclasses`, - specifies if the mapped attribute should be part of the ``__repr__()`` - method as generated by the dataclass process. - :param default_factory: Specific to - :ref:`orm_declarative_native_dataclasses`, - specifies a default-value generation function that will take place - as part of the ``__init__()`` - method as generated by the dataclass process. + :param init: - .. seealso:: + :param default: - :ref:`defaults_default_factory_insert_default` + :param default_factory: - :paramref:`_orm.mapped_column.default` - - :paramref:`_orm.mapped_column.insert_default` - - :param compare: Specific to - :ref:`orm_declarative_native_dataclasses`, indicates if this field - should be included in comparison operations when generating the - ``__eq__()`` and ``__ne__()`` methods for the mapped class. - - .. versionadded:: 2.0.0b4 - - :param kw_only: Specific to - :ref:`orm_declarative_native_dataclasses`, indicates if this field - should be marked as keyword-only when generating the ``__init__()``. - - :param hash: Specific to - :ref:`orm_declarative_native_dataclasses`, controls if this field - is included when generating the ``__hash__()`` method for the mapped - class. - - .. 
versionadded:: 2.0.36 - - :param dataclass_metadata: Specific to - :ref:`orm_declarative_native_dataclasses`, supplies metadata - to be attached to the generated dataclass field. - - .. versionadded:: 2.0.42 + :param kw_only: """ return MappedSQLExpression( @@ -614,8 +528,6 @@ def column_property( default_factory, compare, kw_only, - hash, - dataclass_metadata, ), group=group, deferred=deferred, @@ -644,12 +556,11 @@ def composite( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, - hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 info: Optional[_InfoType] = None, doc: Optional[str] = None, - dataclass_metadata: Union[_NoArg, Mapping[Any, Any], None] = _NoArg.NO_ARG, **__kw: Any, -) -> Composite[Any]: ... +) -> Composite[Any]: + ... @overload @@ -667,11 +578,11 @@ def composite( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, - hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 info: Optional[_InfoType] = None, doc: Optional[str] = None, **__kw: Any, -) -> Composite[_CC]: ... +) -> Composite[_CC]: + ... @overload @@ -689,11 +600,11 @@ def composite( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, - hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 info: Optional[_InfoType] = None, doc: Optional[str] = None, **__kw: Any, -) -> Composite[_CC]: ... +) -> Composite[_CC]: + ... def composite( @@ -712,10 +623,8 @@ def composite( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, - hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 info: Optional[_InfoType] = None, doc: Optional[str] = None, - dataclass_metadata: Union[_NoArg, Mapping[Any, Any], None] = _NoArg.NO_ARG, **__kw: Any, ) -> Composite[Any]: r"""Return a composite column-based property for use with a Mapper. @@ -788,19 +697,6 @@ def composite( :ref:`orm_declarative_native_dataclasses`, indicates if this field should be marked as keyword-only when generating the ``__init__()``. - :param hash: Specific to - :ref:`orm_declarative_native_dataclasses`, controls if this field - is included when generating the ``__hash__()`` method for the mapped - class. - - .. versionadded:: 2.0.36 - - :param dataclass_metadata: Specific to - :ref:`orm_declarative_native_dataclasses`, supplies metadata - to be attached to the generated dataclass field. - - .. 
versionadded:: 2.0.42 - """ if __kw: raise _no_kw() @@ -809,14 +705,7 @@ def composite( _class_or_attr, *attrs, attribute_options=_AttributeOptions( - init, - repr, - default, - default_factory, - compare, - kw_only, - hash, - dataclass_metadata, + init, repr, default, default_factory, compare, kw_only ), group=group, deferred=deferred, @@ -830,10 +719,7 @@ def composite( def with_loader_criteria( entity_or_base: _EntityType[Any], - where_criteria: Union[ - _ColumnExpressionArgument[bool], - Callable[[Any], _ColumnExpressionArgument[bool]], - ], + where_criteria: _ColumnExpressionArgument[bool], loader_only: bool = False, include_aliases: bool = False, propagate_to_loaders: bool = True, @@ -862,7 +748,7 @@ def with_loader_criteria( stmt = select(User).options( selectinload(User.addresses), - with_loader_criteria(Address, Address.email_address != "foo"), + with_loader_criteria(Address, Address.email_address != 'foo')) ) Above, the "selectinload" for ``User.addresses`` will apply the @@ -872,10 +758,8 @@ def with_loader_criteria( ON clause of the join, in this example using :term:`1.x style` queries:: - q = ( - session.query(User) - .outerjoin(User.addresses) - .options(with_loader_criteria(Address, Address.email_address != "foo")) + q = session.query(User).outerjoin(User.addresses).options( + with_loader_criteria(Address, Address.email_address != 'foo')) ) The primary purpose of :func:`_orm.with_loader_criteria` is to use @@ -888,7 +772,6 @@ def with_loader_criteria( session = Session(bind=engine) - @event.listens_for("do_orm_execute", session) def _add_filtering_criteria(execute_state): @@ -900,8 +783,8 @@ def with_loader_criteria( execute_state.statement = execute_state.statement.options( with_loader_criteria( SecurityRole, - lambda cls: cls.role.in_(["some_role"]), - include_aliases=True, + lambda cls: cls.role.in_(['some_role']), + include_aliases=True ) ) @@ -938,19 +821,16 @@ def with_loader_criteria( ``A -> A.bs -> B``, the given :func:`_orm.with_loader_criteria` option will affect the way in which the JOIN is rendered:: - stmt = ( - select(A) - .join(A.bs) - .options(contains_eager(A.bs), with_loader_criteria(B, B.flag == 1)) + stmt = select(A).join(A.bs).options( + contains_eager(A.bs), + with_loader_criteria(B, B.flag == 1) ) Above, the given :func:`_orm.with_loader_criteria` option will affect the ON clause of the JOIN that is specified by ``.join(A.bs)``, so is applied as expected. The :func:`_orm.contains_eager` option has the effect that columns from - ``B`` are added to the columns clause: - - .. sourcecode:: sql + ``B`` are added to the columns clause:: SELECT b.id, b.a_id, b.data, b.flag, @@ -1016,7 +896,7 @@ def with_loader_criteria( .. versionadded:: 1.4.0b2 - """ # noqa: E501 + """ return LoaderCriteriaOption( entity_or_base, where_criteria, @@ -1050,7 +930,6 @@ def relationship( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, - hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 lazy: _LazyLoadArgumentType = "select", passive_deletes: Union[Literal["all"], bool] = False, passive_updates: bool = True, @@ -1070,9 +949,8 @@ def relationship( info: Optional[_InfoType] = None, omit_join: Literal[None, False] = None, sync_backref: Optional[bool] = None, - dataclass_metadata: Union[_NoArg, Mapping[Any, Any], None] = _NoArg.NO_ARG, **kw: Any, -) -> _RelationshipDeclared[Any]: +) -> Relationship[Any]: """Provide a relationship between two mapped classes. 
This corresponds to a parent-child or associative table relationship. @@ -1810,10 +1688,19 @@ def relationship( the full set of related objects, to prevent modifications of the collection from resulting in persistence operations. + When using the :paramref:`_orm.relationship.viewonly` flag in + conjunction with backrefs, the originating relationship for a + particular state change will not produce state changes within the + viewonly relationship. This is the behavior implied by + :paramref:`_orm.relationship.sync_backref` being set to False. + + .. versionchanged:: 1.3.17 - the + :paramref:`_orm.relationship.sync_backref` flag is set to False + when using viewonly in conjunction with backrefs. + .. seealso:: - :ref:`relationship_viewonly_notes` - more details on best practices - when using :paramref:`_orm.relationship.viewonly`. + :paramref:`_orm.relationship.sync_backref` :param sync_backref: A boolean that enables the events used to synchronize the in-Python @@ -1875,22 +1762,10 @@ def relationship( :ref:`orm_declarative_native_dataclasses`, indicates if this field should be marked as keyword-only when generating the ``__init__()``. - :param hash: Specific to - :ref:`orm_declarative_native_dataclasses`, controls if this field - is included when generating the ``__hash__()`` method for the mapped - class. - - .. versionadded:: 2.0.36 - - :param dataclass_metadata: Specific to - :ref:`orm_declarative_native_dataclasses`, supplies metadata - to be attached to the generated dataclass field. - - .. versionadded:: 2.0.42 """ - return _RelationshipDeclared( + return Relationship( argument, secondary=secondary, uselist=uselist, @@ -1905,14 +1780,7 @@ def relationship( cascade=cascade, viewonly=viewonly, attribute_options=_AttributeOptions( - init, - repr, - default, - default_factory, - compare, - kw_only, - hash, - dataclass_metadata, + init, repr, default, default_factory, compare, kw_only ), lazy=lazy, passive_deletes=passive_deletes, @@ -1947,10 +1815,8 @@ def synonym( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, - hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 info: Optional[_InfoType] = None, doc: Optional[str] = None, - dataclass_metadata: Union[_NoArg, Mapping[Any, Any], None] = _NoArg.NO_ARG, ) -> Synonym[Any]: """Denote an attribute name as a synonym to a mapped property, in that the attribute will mirror the value and expression behavior @@ -1959,13 +1825,14 @@ def synonym( e.g.:: class MyClass(Base): - __tablename__ = "my_table" + __tablename__ = 'my_table' id = Column(Integer, primary_key=True) job_status = Column(String(50)) status = synonym("job_status") + :param name: the name of the existing mapped property. 
This can refer to the string name ORM-mapped attribute configured on the class, including column-bound attributes @@ -1993,13 +1860,11 @@ def synonym( :paramref:`.synonym.descriptor` parameter:: my_table = Table( - "my_table", - metadata, - Column("id", Integer, primary_key=True), - Column("job_status", String(50)), + "my_table", metadata, + Column('id', Integer, primary_key=True), + Column('job_status', String(50)) ) - class MyClass: @property def _job_status_descriptor(self): @@ -2007,15 +1872,11 @@ def synonym( mapper( - MyClass, - my_table, - properties={ + MyClass, my_table, properties={ "job_status": synonym( - "_job_status", - map_column=True, - descriptor=MyClass._job_status_descriptor, - ) - }, + "_job_status", map_column=True, + descriptor=MyClass._job_status_descriptor) + } ) Above, the attribute named ``_job_status`` is automatically @@ -2064,14 +1925,7 @@ def synonym( descriptor=descriptor, comparator_factory=comparator_factory, attribute_options=_AttributeOptions( - init, - repr, - default, - default_factory, - compare, - kw_only, - hash, - dataclass_metadata, + init, repr, default, default_factory, compare, kw_only ), doc=doc, info=info, @@ -2172,7 +2026,8 @@ def backref(name: str, **kwargs: Any) -> ORMBackrefArgument: E.g.:: - "items": relationship(SomeItem, backref=backref("parent", lazy="subquery")) + 'items':relationship( + SomeItem, backref=backref('parent', lazy='subquery')) The :paramref:`_orm.relationship.backref` parameter is generally considered to be legacy; for modern applications, using @@ -2184,7 +2039,7 @@ def backref(name: str, **kwargs: Any) -> ORMBackrefArgument: :ref:`relationships_backref` - background on backrefs - """ # noqa: E501 + """ return (name, kwargs) @@ -2201,12 +2056,10 @@ def deferred( default_factory: Union[_NoArg, Callable[[], _T]] = _NoArg.NO_ARG, compare: Union[_NoArg, bool] = _NoArg.NO_ARG, kw_only: Union[_NoArg, bool] = _NoArg.NO_ARG, - hash: Union[_NoArg, bool, None] = _NoArg.NO_ARG, # noqa: A002 active_history: bool = False, expire_on_flush: bool = True, info: Optional[_InfoType] = None, doc: Optional[str] = None, - dataclass_metadata: Union[_NoArg, Mapping[Any, Any], None] = _NoArg.NO_ARG, ) -> MappedSQLExpression[_T]: r"""Indicate a column-based mapped attribute that by default will not load unless accessed. @@ -2237,14 +2090,7 @@ def deferred( column, *additional_columns, attribute_options=_AttributeOptions( - init, - repr, - default, - default_factory, - compare, - kw_only, - hash, - dataclass_metadata, + init, repr, default, default_factory, compare, kw_only ), group=group, deferred=True, @@ -2287,8 +2133,6 @@ def query_expression( _NoArg.NO_ARG, compare, _NoArg.NO_ARG, - _NoArg.NO_ARG, - _NoArg.NO_ARG, ), expire_on_flush=expire_on_flush, info=info, @@ -2342,7 +2186,8 @@ def aliased( name: Optional[str] = None, flat: bool = False, adapt_on_names: bool = False, -) -> AliasedType[_O]: ... +) -> AliasedType[_O]: + ... @overload @@ -2352,7 +2197,8 @@ def aliased( name: Optional[str] = None, flat: bool = False, adapt_on_names: bool = False, -) -> AliasedClass[_O]: ... +) -> AliasedClass[_O]: + ... @overload @@ -2362,7 +2208,8 @@ def aliased( name: Optional[str] = None, flat: bool = False, adapt_on_names: bool = False, -) -> FromClause: ... +) -> FromClause: + ... def aliased( @@ -2435,16 +2282,6 @@ def aliased( supported by all modern databases with regards to right-nested joins and generally produces more efficient queries. 
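To illustrate the flat=True behavior described here, a small sketch using a hypothetical joined-inheritance Person/Engineer hierarchy (not part of this diff); with flat=True the alias targets the individual tables inside the inheritance JOIN rather than wrapping that JOIN in a subquery::

    from sqlalchemy import ForeignKey, String, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, aliased, mapped_column


    class Base(DeclarativeBase):
        pass


    class Person(Base):
        __tablename__ = "person"

        id: Mapped[int] = mapped_column(primary_key=True)
        type: Mapped[str] = mapped_column(String(20))

        __mapper_args__ = {"polymorphic_on": "type", "polymorphic_identity": "person"}


    class Engineer(Person):
        __tablename__ = "engineer"

        id: Mapped[int] = mapped_column(ForeignKey("person.id"), primary_key=True)
        specialty: Mapped[str] = mapped_column(String(50))

        __mapper_args__ = {"polymorphic_identity": "engineer"}


    # person and engineer are aliased individually inside the JOIN,
    # e.g. "person AS person_1 JOIN engineer AS engineer_1"
    eng = aliased(Engineer, flat=True)
    print(select(eng).where(eng.specialty == "databases"))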
- When :paramref:`_orm.aliased.flat` is combined with - :paramref:`_orm.aliased.name`, the resulting joins will alias individual - tables using a naming scheme similar to ``_``. This - naming scheme is for visibility / debugging purposes only and the - specific scheme is subject to change without notice. - - .. versionadded:: 2.0.32 added support for combining - :paramref:`_orm.aliased.name` with :paramref:`_orm.aliased.flat`. - Previously, this would raise ``NotImplementedError``. - :param adapt_on_names: if True, more liberal "matching" will be used when mapping the mapped columns of the ORM entity to those of the given selectable - a name-based match will be performed if the @@ -2454,21 +2291,17 @@ def aliased( aggregate functions:: class UnitPrice(Base): - __tablename__ = "unit_price" + __tablename__ = 'unit_price' ... unit_id = Column(Integer) price = Column(Numeric) + aggregated_unit_price = Session.query( + func.sum(UnitPrice.price).label('price') + ).group_by(UnitPrice.unit_id).subquery() - aggregated_unit_price = ( - Session.query(func.sum(UnitPrice.price).label("price")) - .group_by(UnitPrice.unit_id) - .subquery() - ) - - aggregated_unit_price = aliased( - UnitPrice, alias=aggregated_unit_price, adapt_on_names=True - ) + aggregated_unit_price = aliased(UnitPrice, + alias=aggregated_unit_price, adapt_on_names=True) Above, functions on ``aggregated_unit_price`` which refer to ``.price`` will return the @@ -2496,7 +2329,6 @@ def with_polymorphic( aliased: bool = False, innerjoin: bool = False, adapt_on_names: bool = False, - name: Optional[str] = None, _use_mapper_path: bool = False, ) -> AliasedClass[_O]: """Produce an :class:`.AliasedClass` construct which specifies @@ -2568,10 +2400,6 @@ def with_polymorphic( .. versionadded:: 1.4.33 - :param name: Name given to the generated :class:`.AliasedClass`. - - .. versionadded:: 2.0.31 - """ return AliasedInsp._with_polymorphic_factory( base, @@ -2582,7 +2410,6 @@ def with_polymorphic( adapt_on_names=adapt_on_names, aliased=aliased, innerjoin=innerjoin, - name=name, _use_mapper_path=_use_mapper_path, ) @@ -2614,21 +2441,16 @@ def join( :meth:`_sql.Select.select_from` method, as in:: from sqlalchemy.orm import join - - stmt = ( - select(User) - .select_from(join(User, Address, User.addresses)) - .filter(Address.email_address == "foo@bar.com") - ) + stmt = select(User).\ + select_from(join(User, Address, User.addresses)).\ + filter(Address.email_address=='foo@bar.com') In modern SQLAlchemy the above join can be written more succinctly as:: - stmt = ( - select(User) - .join(User.addresses) - .filter(Address.email_address == "foo@bar.com") - ) + stmt = select(User).\ + join(User.addresses).\ + filter(Address.email_address=='foo@bar.com') .. warning:: using :func:`_orm.join` directly may not work properly with modern ORM options such as :func:`_orm.with_loader_criteria`. 
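Tying together the with_loader_criteria() and do_orm_execute examples from the docstrings above, a runnable sketch of session-wide filtering criteria; the Order class and its archived flag are illustrative assumptions, not part of this changeset::

    from sqlalchemy import create_engine, event, select
    from sqlalchemy.orm import (
        DeclarativeBase,
        Mapped,
        Session,
        mapped_column,
        with_loader_criteria,
    )


    class Base(DeclarativeBase):
        pass


    class Order(Base):
        __tablename__ = "orders"

        id: Mapped[int] = mapped_column(primary_key=True)
        archived: Mapped[bool] = mapped_column(default=False)


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)


    @event.listens_for(session, "do_orm_execute")
    def _add_filtering_criteria(execute_state):
        # attach the criteria to every SELECT against Order, including
        # relationship and column loaders, unless explicitly opted out
        if execute_state.is_select:
            execute_state.statement = execute_state.statement.options(
                with_loader_criteria(
                    Order,
                    lambda cls: cls.archived.is_(False),
                    include_aliases=True,
                )
            )


    session.add_all([Order(archived=False), Order(archived=True)])
    session.commit()

    # only the non-archived row comes back
    print(session.scalars(select(Order)).all())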
diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/_typing.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/_typing.py
index ccb8413..3085351 100644
--- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/_typing.py
+++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/_typing.py
@@ -1,5 +1,5 @@
 # orm/_typing.py
-# Copyright (C) 2022-2025 the SQLAlchemy authors and contributors
+# Copyright (C) 2022 the SQLAlchemy authors and contributors
 #
 #
 # This module is part of SQLAlchemy and is released under
@@ -78,7 +78,7 @@ _IdentityKeyType = Tuple[Type[_T], Tuple[Any, ...], Optional[Any]]
 
 _ORMColumnExprArgument = Union[
     ColumnElement[_T],
-    _HasClauseElement[_T],
+    _HasClauseElement,
     roles.ExpressionElementRole[_T],
 ]
 
@@ -108,13 +108,13 @@ class _ORMAdapterProto(Protocol):
 
     """
 
-    def __call__(self, obj: _CE, key: Optional[str] = None) -> _CE: ...
+    def __call__(self, obj: _CE, key: Optional[str] = None) -> _CE:
+        ...
 
 
 class _LoaderCallable(Protocol):
-    def __call__(
-        self, state: InstanceState[Any], passive: PassiveFlag
-    ) -> Any: ...
+    def __call__(self, state: InstanceState[Any], passive: PassiveFlag) -> Any:
+        ...
 
 
 def is_orm_option(
@@ -138,33 +138,39 @@ def is_composite_class(obj: Any) -> bool:
 
 if TYPE_CHECKING:
 
-    def insp_is_mapper_property(
-        obj: Any,
-    ) -> TypeGuard[MapperProperty[Any]]: ...
+    def insp_is_mapper_property(obj: Any) -> TypeGuard[MapperProperty[Any]]:
+        ...
 
-    def insp_is_mapper(obj: Any) -> TypeGuard[Mapper[Any]]: ...
+    def insp_is_mapper(obj: Any) -> TypeGuard[Mapper[Any]]:
+        ...
 
-    def insp_is_aliased_class(obj: Any) -> TypeGuard[AliasedInsp[Any]]: ...
+    def insp_is_aliased_class(obj: Any) -> TypeGuard[AliasedInsp[Any]]:
+        ...
 
     def insp_is_attribute(
         obj: InspectionAttr,
-    ) -> TypeGuard[QueryableAttribute[Any]]: ...
+    ) -> TypeGuard[QueryableAttribute[Any]]:
+        ...
 
     def attr_is_internal_proxy(
         obj: InspectionAttr,
-    ) -> TypeGuard[QueryableAttribute[Any]]: ...
+    ) -> TypeGuard[QueryableAttribute[Any]]:
+        ...
 
     def prop_is_relationship(
         prop: MapperProperty[Any],
-    ) -> TypeGuard[RelationshipProperty[Any]]: ...
+    ) -> TypeGuard[RelationshipProperty[Any]]:
+        ...
 
     def is_collection_impl(
         impl: AttributeImpl,
-    ) -> TypeGuard[CollectionAttributeImpl]: ...
+    ) -> TypeGuard[CollectionAttributeImpl]:
+        ...
 
     def is_has_collection_adapter(
         impl: AttributeImpl,
-    ) -> TypeGuard[HasCollectionAdapter]: ...
+    ) -> TypeGuard[HasCollectionAdapter]:
+        ...
else: insp_is_mapper_property = operator.attrgetter("is_property") diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/attributes.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/attributes.py index 9c67936..1098359 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/attributes.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/attributes.py @@ -1,5 +1,5 @@ # orm/attributes.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -401,7 +401,7 @@ class QueryableAttribute( parententity=adapt_to_entity, ) - def of_type(self, entity: _EntityType[_T]) -> QueryableAttribute[_T]: + def of_type(self, entity: _EntityType[Any]) -> QueryableAttribute[_T]: return QueryableAttribute( self.class_, self.key, @@ -462,9 +462,6 @@ class QueryableAttribute( ) -> bool: return self.impl.hasparent(state, optimistic=optimistic) is not False - def _column_strategy_attrs(self) -> Sequence[QueryableAttribute[Any]]: - return (self,) - def __getattr__(self, key: str) -> Any: try: return util.MemoizedSlots.__getattr__(self, key) @@ -506,7 +503,7 @@ def _queryable_attribute_unreduce( return getattr(entity, key) -class InstrumentedAttribute(QueryableAttribute[_T_co]): +class InstrumentedAttribute(QueryableAttribute[_T]): """Class bound instrumented attribute which adds basic :term:`descriptor` methods. @@ -545,16 +542,16 @@ class InstrumentedAttribute(QueryableAttribute[_T_co]): self.impl.delete(instance_state(instance), instance_dict(instance)) @overload - def __get__( - self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T_co]: ... + def __get__(self, instance: None, owner: Any) -> InstrumentedAttribute[_T]: + ... @overload - def __get__(self, instance: object, owner: Any) -> _T_co: ... + def __get__(self, instance: object, owner: Any) -> _T: + ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T_co], _T_co]: + ) -> Union[InstrumentedAttribute[_T], _T]: if instance is None: return self @@ -598,7 +595,7 @@ def create_proxied_attribute( # TODO: can move this to descriptor_props if the need for this # function is removed from ext/hybrid.py - class Proxy(QueryableAttribute[_T_co]): + class Proxy(QueryableAttribute[Any]): """Presents the :class:`.QueryableAttribute` interface as a proxy on top of a Python descriptor / :class:`.PropComparator` combination. 
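The __get__() overloads above implement the usual descriptor split, sketched here with a hypothetical User mapping: class-level access yields the InstrumentedAttribute itself for use in SQL expressions, while instance access returns the loaded Python value::

    from sqlalchemy import String, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
    from sqlalchemy.orm.attributes import InstrumentedAttribute


    class Base(DeclarativeBase):
        pass


    class User(Base):
        __tablename__ = "user_account"

        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column(String(30))


    # __get__(None, owner): the attribute itself, a SQL expression element
    assert isinstance(User.name, InstrumentedAttribute)
    print(select(User).where(User.name == "spongebob"))

    # __get__(instance, owner): the per-instance value
    u = User(name="spongebob")
    assert u.name == "spongebob"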
@@ -613,13 +610,13 @@ def create_proxied_attribute( def __init__( self, - class_: _ExternalEntityType[Any], - key: str, - descriptor: Any, - comparator: interfaces.PropComparator[_T_co], - adapt_to_entity: Optional[AliasedInsp[Any]] = None, - doc: Optional[str] = None, - original_property: Optional[QueryableAttribute[_T_co]] = None, + class_, + key, + descriptor, + comparator, + adapt_to_entity=None, + doc=None, + original_property=None, ): self.class_ = class_ self.key = key @@ -630,11 +627,11 @@ def create_proxied_attribute( self._doc = self.__doc__ = doc @property - def _parententity(self): # type: ignore[override] + def _parententity(self): return inspection.inspect(self.class_, raiseerr=False) @property - def parent(self): # type: ignore[override] + def parent(self): return inspection.inspect(self.class_, raiseerr=False) _is_internal_proxy = True @@ -644,13 +641,6 @@ def create_proxied_attribute( ("_parententity", visitors.ExtendedInternalTraversal.dp_multi), ] - def _column_strategy_attrs(self) -> Sequence[QueryableAttribute[Any]]: - prop = self.original_property - if prop is None: - return () - else: - return prop._column_strategy_attrs() - @property def _impl_uses_objects(self): return ( @@ -1548,7 +1538,8 @@ class HasCollectionAdapter: dict_: _InstanceDict, user_data: Literal[None] = ..., passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., - ) -> CollectionAdapter: ... + ) -> CollectionAdapter: + ... @overload def get_collection( @@ -1557,7 +1548,8 @@ class HasCollectionAdapter: dict_: _InstanceDict, user_data: _AdaptedCollectionProtocol = ..., passive: PassiveFlag = ..., - ) -> CollectionAdapter: ... + ) -> CollectionAdapter: + ... @overload def get_collection( @@ -1568,7 +1560,8 @@ class HasCollectionAdapter: passive: PassiveFlag = ..., ) -> Union[ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter - ]: ... + ]: + ... def get_collection( self, @@ -1599,7 +1592,8 @@ if TYPE_CHECKING: def _is_collection_attribute_impl( impl: AttributeImpl, - ) -> TypeGuard[CollectionAttributeImpl]: ... + ) -> TypeGuard[CollectionAttributeImpl]: + ... else: _is_collection_attribute_impl = operator.attrgetter("collection") @@ -2055,7 +2049,8 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): dict_: _InstanceDict, user_data: Literal[None] = ..., passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., - ) -> CollectionAdapter: ... + ) -> CollectionAdapter: + ... @overload def get_collection( @@ -2064,7 +2059,8 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): dict_: _InstanceDict, user_data: _AdaptedCollectionProtocol = ..., passive: PassiveFlag = ..., - ) -> CollectionAdapter: ... + ) -> CollectionAdapter: + ... @overload def get_collection( @@ -2075,7 +2071,8 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): passive: PassiveFlag = PASSIVE_OFF, ) -> Union[ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter - ]: ... + ]: + ... def get_collection( self, @@ -2673,7 +2670,7 @@ def init_collection(obj: object, key: str) -> CollectionAdapter: This function is used to provide direct access to collection internals for a previously unloaded attribute. 
e.g.:: - collection_adapter = init_collection(someobject, "elements") + collection_adapter = init_collection(someobject, 'elements') for elem in values: collection_adapter.append_without_event(elem) @@ -2717,7 +2714,7 @@ def init_state_collection( return adapter -def set_committed_value(instance: object, key: str, value: Any) -> None: +def set_committed_value(instance, key, value): """Set the value of an attribute with no history events. Cancels any previous history present. The value should be diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/base.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/base.py index b9f8d32..362346c 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/base.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/base.py @@ -1,11 +1,13 @@ # orm/base.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Constants and rudimental functions used throughout the ORM.""" +"""Constants and rudimental functions used throughout the ORM. + +""" from __future__ import annotations @@ -19,7 +21,6 @@ from typing import Generic from typing import no_type_check from typing import Optional from typing import overload -from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import TypeVar @@ -143,7 +144,7 @@ class PassiveFlag(FastIntFlag): """ NO_AUTOFLUSH = 64 - """Loader callables should disable autoflush.""" + """Loader callables should disable autoflush.""", NO_RAISE = 128 """Loader callables should not raise any assertions""" @@ -281,8 +282,6 @@ _never_set = frozenset([NEVER_SET]) _none_set = frozenset([None, NEVER_SET, PASSIVE_NO_RESULT]) -_none_only_set = frozenset([None]) - _SET_DEFERRED_EXPIRED = util.symbol("SET_DEFERRED_EXPIRED") _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE") @@ -309,23 +308,29 @@ def _assertions( if TYPE_CHECKING: - def manager_of_class(cls: Type[_O]) -> ClassManager[_O]: ... + def manager_of_class(cls: Type[_O]) -> ClassManager[_O]: + ... @overload - def opt_manager_of_class(cls: AliasedClass[Any]) -> None: ... + def opt_manager_of_class(cls: AliasedClass[Any]) -> None: + ... @overload def opt_manager_of_class( cls: _ExternalEntityType[_O], - ) -> Optional[ClassManager[_O]]: ... + ) -> Optional[ClassManager[_O]]: + ... def opt_manager_of_class( cls: _ExternalEntityType[_O], - ) -> Optional[ClassManager[_O]]: ... + ) -> Optional[ClassManager[_O]]: + ... - def instance_state(instance: _O) -> InstanceState[_O]: ... + def instance_state(instance: _O) -> InstanceState[_O]: + ... - def instance_dict(instance: object) -> Dict[str, Any]: ... + def instance_dict(instance: object) -> Dict[str, Any]: + ... else: # these can be replaced by sqlalchemy.ext.instrumentation @@ -433,7 +438,7 @@ def _inspect_mapped_object(instance: _T) -> Optional[InstanceState[_T]]: def _class_to_mapper( - class_or_mapper: Union[Mapper[_T], Type[_T]], + class_or_mapper: Union[Mapper[_T], Type[_T]] ) -> Mapper[_T]: # can't get mypy to see an overload for this insp = inspection.inspect(class_or_mapper, False) @@ -445,7 +450,7 @@ def _class_to_mapper( def _mapper_or_none( - entity: Union[Type[_T], _InternalEntityType[_T]], + entity: Union[Type[_T], _InternalEntityType[_T]] ) -> Optional[Mapper[_T]]: """Return the :class:`_orm.Mapper` for the given class or None if the class is not mapped. 
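A short sketch of the set_committed_value() behavior documented above, using a hypothetical User mapping: the value is installed as if it had been loaded from the database, so no change history is recorded::

    from sqlalchemy import String, inspect
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
    from sqlalchemy.orm.attributes import set_committed_value


    class Base(DeclarativeBase):
        pass


    class User(Base):
        __tablename__ = "user_account"

        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column(String(30))


    u = User(id=1)

    # populate "name" as though it were loaded; no history event is emitted
    set_committed_value(u, "name", "cached value")

    print(u.name)  # "cached value"
    print(inspect(u).attrs.name.history.has_changes())  # False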
@@ -507,7 +512,8 @@ def _entity_descriptor(entity: _EntityType[Any], key: str) -> Any: if TYPE_CHECKING: - def _state_mapper(state: InstanceState[_O]) -> Mapper[_O]: ... + def _state_mapper(state: InstanceState[_O]) -> Mapper[_O]: + ... else: _state_mapper = util.dottedgetter("manager.mapper") @@ -580,7 +586,7 @@ class InspectionAttr: """ - __slots__: Tuple[str, ...] = () + __slots__ = () is_selectable = False """Return True if this object is an instance of @@ -678,25 +684,27 @@ class SQLORMOperations(SQLCoreOperations[_T_co], TypingOnly): if typing.TYPE_CHECKING: - def of_type( - self, class_: _EntityType[Any] - ) -> PropComparator[_T_co]: ... + def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T_co]: + ... def and_( self, *criteria: _ColumnExpressionArgument[bool] - ) -> PropComparator[bool]: ... + ) -> PropComparator[bool]: + ... def any( # noqa: A001 self, criterion: Optional[_ColumnExpressionArgument[bool]] = None, **kwargs: Any, - ) -> ColumnElement[bool]: ... + ) -> ColumnElement[bool]: + ... def has( self, criterion: Optional[_ColumnExpressionArgument[bool]] = None, **kwargs: Any, - ) -> ColumnElement[bool]: ... + ) -> ColumnElement[bool]: + ... class ORMDescriptor(Generic[_T_co], TypingOnly): @@ -710,19 +718,23 @@ class ORMDescriptor(Generic[_T_co], TypingOnly): @overload def __get__( self, instance: Any, owner: Literal[None] - ) -> ORMDescriptor[_T_co]: ... + ) -> ORMDescriptor[_T_co]: + ... @overload def __get__( self, instance: Literal[None], owner: Any - ) -> SQLCoreOperations[_T_co]: ... + ) -> SQLCoreOperations[_T_co]: + ... @overload - def __get__(self, instance: object, owner: Any) -> _T_co: ... + def __get__(self, instance: object, owner: Any) -> _T_co: + ... def __get__( self, instance: object, owner: Any - ) -> Union[ORMDescriptor[_T_co], SQLCoreOperations[_T_co], _T_co]: ... + ) -> Union[ORMDescriptor[_T_co], SQLCoreOperations[_T_co], _T_co]: + ... class _MappedAnnotationBase(Generic[_T_co], TypingOnly): @@ -808,23 +820,29 @@ class Mapped( @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T_co]: ... + ) -> InstrumentedAttribute[_T_co]: + ... @overload - def __get__(self, instance: object, owner: Any) -> _T_co: ... + def __get__(self, instance: object, owner: Any) -> _T_co: + ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T_co], _T_co]: ... + ) -> Union[InstrumentedAttribute[_T_co], _T_co]: + ... @classmethod - def _empty_constructor(cls, arg1: Any) -> Mapped[_T_co]: ... + def _empty_constructor(cls, arg1: Any) -> Mapped[_T_co]: + ... def __set__( self, instance: Any, value: Union[SQLCoreOperations[_T_co], _T_co] - ) -> None: ... + ) -> None: + ... - def __delete__(self, instance: Any) -> None: ... + def __delete__(self, instance: Any) -> None: + ... class _MappedAttribute(Generic[_T_co], TypingOnly): @@ -901,20 +919,24 @@ class DynamicMapped(_MappedAnnotationBase[_T_co]): @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T_co]: ... + ) -> InstrumentedAttribute[_T_co]: + ... @overload def __get__( self, instance: object, owner: Any - ) -> AppenderQuery[_T_co]: ... + ) -> AppenderQuery[_T_co]: + ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T_co], AppenderQuery[_T_co]]: ... + ) -> Union[InstrumentedAttribute[_T_co], AppenderQuery[_T_co]]: + ... def __set__( self, instance: Any, value: typing.Collection[_T_co] - ) -> None: ... + ) -> None: + ... 
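For the DynamicMapped annotation whose descriptor overloads appear above, a small sketch with hypothetical Account/Entry classes: annotating with DynamicMapped selects the "dynamic" loader style, so attribute access returns an AppenderQuery rather than loading the full collection::

    from __future__ import annotations

    from sqlalchemy import ForeignKey, create_engine
    from sqlalchemy.orm import (
        DeclarativeBase,
        DynamicMapped,
        Mapped,
        Session,
        mapped_column,
        relationship,
    )


    class Base(DeclarativeBase):
        pass


    class Account(Base):
        __tablename__ = "account"

        id: Mapped[int] = mapped_column(primary_key=True)
        entries: DynamicMapped["Entry"] = relationship()


    class Entry(Base):
        __tablename__ = "entry"

        id: Mapped[int] = mapped_column(primary_key=True)
        account_id: Mapped[int] = mapped_column(ForeignKey("account.id"))


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        acct = Account()
        acct.entries.append(Entry())
        acct.entries.append(Entry())
        session.add(acct)
        session.commit()

        # the collection is filtered in SQL instead of being fully loaded
        print(acct.entries.filter(Entry.id > 1).all())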
class WriteOnlyMapped(_MappedAnnotationBase[_T_co]): @@ -953,19 +975,21 @@ class WriteOnlyMapped(_MappedAnnotationBase[_T_co]): @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T_co]: ... + ) -> InstrumentedAttribute[_T_co]: + ... @overload def __get__( self, instance: object, owner: Any - ) -> WriteOnlyCollection[_T_co]: ... + ) -> WriteOnlyCollection[_T_co]: + ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[ - InstrumentedAttribute[_T_co], WriteOnlyCollection[_T_co] - ]: ... + ) -> Union[InstrumentedAttribute[_T_co], WriteOnlyCollection[_T_co]]: + ... def __set__( self, instance: Any, value: typing.Collection[_T_co] - ) -> None: ... + ) -> None: + ... diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/bulk_persistence.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/bulk_persistence.py index 86ea53b..31caedc 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/bulk_persistence.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/bulk_persistence.py @@ -1,5 +1,5 @@ # orm/bulk_persistence.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -76,13 +76,13 @@ def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, - *, isstates: bool, return_defaults: bool, render_nulls: bool, use_orm_insert_stmt: Literal[None] = ..., execution_options: Optional[OrmExecuteOptionsParameter] = ..., -) -> None: ... +) -> None: + ... @overload @@ -90,20 +90,19 @@ def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, - *, isstates: bool, return_defaults: bool, render_nulls: bool, use_orm_insert_stmt: Optional[dml.Insert] = ..., execution_options: Optional[OrmExecuteOptionsParameter] = ..., -) -> cursor.CursorResult[Any]: ... +) -> cursor.CursorResult[Any]: + ... def _bulk_insert( mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, - *, isstates: bool, return_defaults: bool, render_nulls: bool, @@ -119,35 +118,13 @@ def _bulk_insert( ) if isstates: - if TYPE_CHECKING: - mappings = cast(Iterable[InstanceState[_O]], mappings) - if return_defaults: - # list of states allows us to attach .key for return_defaults case states = [(state, state.dict) for state in mappings] mappings = [dict_ for (state, dict_) in states] else: mappings = [state.dict for state in mappings] else: - if TYPE_CHECKING: - mappings = cast(Iterable[Dict[str, Any]], mappings) - - if return_defaults: - # use dictionaries given, so that newly populated defaults - # can be delivered back to the caller (see #11661). This is **not** - # compatible with other use cases such as a session-executed - # insert() construct, as this will confuse the case of - # insert-per-subclass for joined inheritance cases (see - # test_bulk_statements.py::BulkDMLReturningJoinedInhTest). 
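The return_defaults handling above is the code path behind Session.bulk_insert_mappings(); a sketch with a hypothetical User mapping follows. Writing the generated primary keys back into the caller's dictionaries is the newer behavior referenced by issue #11661; on the older code path restored in this hunk the dictionaries are copied first::

    from sqlalchemy import String, create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class User(Base):
        __tablename__ = "user_account"

        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column(String(30))


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    rows = [{"name": "a"}, {"name": "b"}]

    with Session(engine) as session:
        # return_defaults=True makes newly generated primary keys and server
        # defaults available after these mappings are flushed
        session.bulk_insert_mappings(User, rows, return_defaults=True)
        session.commit()

    # on releases that deliver defaults back to the given dictionaries (#11661)
    # this prints the new ids, e.g. [1, 2]
    print([row.get("id") for row in rows])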
- # - # So in this conditional, we have **only** called - # session.bulk_insert_mappings() which does not have this - # requirement - mappings = list(mappings) - else: - # for all other cases we need to establish a local dictionary - # so that the incoming dictionaries aren't mutated - mappings = [dict(m) for m in mappings] + mappings = [dict(m) for m in mappings] _expand_composites(mapper, mappings) connection = session_transaction.connection(base_mapper) @@ -243,7 +220,6 @@ def _bulk_insert( state.key = ( identity_cls, tuple([dict_[key] for key in identity_props]), - None, ) if use_orm_insert_stmt is not None: @@ -256,12 +232,12 @@ def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, - *, isstates: bool, update_changed_only: bool, use_orm_update_stmt: Literal[None] = ..., enable_check_rowcount: bool = True, -) -> None: ... +) -> None: + ... @overload @@ -269,19 +245,18 @@ def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, - *, isstates: bool, update_changed_only: bool, use_orm_update_stmt: Optional[dml.Update] = ..., enable_check_rowcount: bool = True, -) -> _result.Result[Any]: ... +) -> _result.Result[Any]: + ... def _bulk_update( mapper: Mapper[Any], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], session_transaction: SessionTransaction, - *, isstates: bool, update_changed_only: bool, use_orm_update_stmt: Optional[dml.Update] = None, @@ -402,16 +377,14 @@ class ORMDMLState(AbstractORMCompileState): if desc is NO_VALUE: yield ( coercions.expect(roles.DMLColumnRole, k), - ( - coercions.expect( - roles.ExpressionElementRole, - v, - type_=sqltypes.NullType(), - is_crud=True, - ) - if needs_to_be_cacheable - else v - ), + coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ) + if needs_to_be_cacheable + else v, ) else: yield from core_get_crud_kv_pairs( @@ -432,36 +405,21 @@ class ORMDMLState(AbstractORMCompileState): else: yield ( k, - ( - v - if not needs_to_be_cacheable - else coercions.expect( - roles.ExpressionElementRole, - v, - type_=sqltypes.NullType(), - is_crud=True, - ) + v + if not needs_to_be_cacheable + else coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, ), ) - @classmethod - def _get_dml_plugin_subject(cls, statement): - plugin_subject = statement.table._propagate_attrs.get("plugin_subject") - - if ( - not plugin_subject - or not plugin_subject.mapper - or plugin_subject - is not statement._propagate_attrs["plugin_subject"] - ): - return None - return plugin_subject - @classmethod def _get_multi_crud_kv_pairs(cls, statement, kv_iterator): - plugin_subject = cls._get_dml_plugin_subject(statement) + plugin_subject = statement._propagate_attrs["plugin_subject"] - if not plugin_subject: + if not plugin_subject or not plugin_subject.mapper: return UpdateDMLState._get_multi_crud_kv_pairs( statement, kv_iterator ) @@ -481,12 +439,13 @@ class ORMDMLState(AbstractORMCompileState): needs_to_be_cacheable ), "no test coverage for needs_to_be_cacheable=False" - plugin_subject = cls._get_dml_plugin_subject(statement) + plugin_subject = statement._propagate_attrs["plugin_subject"] - if not plugin_subject: + if not plugin_subject or not plugin_subject.mapper: return UpdateDMLState._get_crud_kv_pairs( statement, kv_iterator, needs_to_be_cacheable ) + return list( 
cls._get_orm_crud_kv_pairs( plugin_subject.mapper, @@ -569,9 +528,9 @@ class ORMDMLState(AbstractORMCompileState): fs = fs.execution_options(**orm_level_statement._execution_options) fs = fs.options(*orm_level_statement._with_options) self.select_statement = fs - self.from_statement_ctx = fsc = ( - ORMFromStatementCompileState.create_for_statement(fs, compiler) - ) + self.from_statement_ctx = ( + fsc + ) = ORMFromStatementCompileState.create_for_statement(fs, compiler) fsc.setup_dml_returning_compile_state(dml_mapper) dml_level_statement = dml_level_statement._generate() @@ -631,7 +590,6 @@ class ORMDMLState(AbstractORMCompileState): querycontext = QueryContext( compile_state.from_statement_ctx, compile_state.select_statement, - statement, params, session, load_options, @@ -656,7 +614,6 @@ class BulkUDCompileState(ORMDMLState): _eval_condition = None _matched_rows = None _identity_token = None - _populate_existing: bool = False @classmethod def can_use_returning( @@ -689,7 +646,6 @@ class BulkUDCompileState(ORMDMLState): { "synchronize_session", "autoflush", - "populate_existing", "identity_token", "is_delete_using", "is_update_from", @@ -874,39 +830,53 @@ class BulkUDCompileState(ORMDMLState): return return_crit @classmethod - def _interpret_returning_rows(cls, result, mapper, rows): - """return rows that indicate PK cols in mapper.primary_key position - for RETURNING rows. + def _interpret_returning_rows(cls, mapper, rows): + """translate from local inherited table columns to base mapper + primary key columns. - Prior to 2.0.36, this method seemed to be written for some kind of - inheritance scenario but the scenario was unused for actual joined - inheritance, and the function instead seemed to perform some kind of - partial translation that would remove non-PK cols if the PK cols - happened to be first in the row, but not otherwise. The joined - inheritance walk feature here seems to have never been used as it was - always skipped by the "local_table" check. + Joined inheritance mappers always establish the primary key in terms of + the base table. When we UPDATE a sub-table, we can only get + RETURNING for the sub-table's columns. - As of 2.0.36 the function strips away non-PK cols and provides the - PK cols for the table in mapper PK order. + Here, we create a lookup from the local sub table's primary key + columns to the base table PK columns so that we can get identity + key values from RETURNING that's against the joined inheritance + sub-table. + + the complexity here is to support more than one level deep of + inheritance, where we have to link columns to each other across + the inheritance hierarchy. """ - try: - if mapper.local_table is not mapper.base_mapper.local_table: - # TODO: dive more into how a local table PK is used for fetch - # sync, not clear if this is correct as it depends on the - # downstream routine to fetch rows using - # local_table.primary_key order - pk_keys = result._tuple_getter(mapper.local_table.primary_key) - else: - pk_keys = result._tuple_getter(mapper.primary_key) - except KeyError: - # can't use these rows, they don't have PK cols in them - # this is an unusual case where the user would have used - # .return_defaults() - return [] + if mapper.local_table is not mapper.base_mapper.local_table: + return rows - return [pk_keys(row) for row in rows] + # this starts as a mapping of + # local_pk_col: local_pk_col. 
+ # we will then iteratively rewrite the "value" of the dict with + # each successive superclass column + local_pk_to_base_pk = {pk: pk for pk in mapper.local_table.primary_key} + + for mp in mapper.iterate_to_root(): + if mp.inherits is None: + break + elif mp.local_table is mp.inherits.local_table: + continue + + t_to_e = dict(mp._table_to_equated[mp.inherits.local_table]) + col_to_col = {sub_pk: super_pk for super_pk, sub_pk in t_to_e[mp]} + for pk, super_ in local_pk_to_base_pk.items(): + local_pk_to_base_pk[pk] = col_to_col[super_] + + lookup = { + local_pk_to_base_pk[lpk]: idx + for idx, lpk in enumerate(mapper.local_table.primary_key) + } + primary_key_convert = [ + lookup[bpk] for bpk in mapper.base_mapper.primary_key + ] + return [tuple(row[idx] for idx in primary_key_convert) for row in rows] @classmethod def _get_matched_objects_on_criteria(cls, update_options, states): @@ -1469,9 +1439,6 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): new_stmt = statement._clone() - if new_stmt.table._annotations["parententity"] is mapper: - new_stmt.table = mapper.local_table - # note if the statement has _multi_values, these # are passed through to the new statement, which will then raise # InvalidRequestError because UPDATE doesn't support multi_values @@ -1590,20 +1557,10 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): bind_arguments: _BindArguments, conn: Connection, ) -> _result.Result: - update_options = execution_options.get( "_sa_orm_update_options", cls.default_update_options ) - if update_options._populate_existing: - load_options = execution_options.get( - "_sa_orm_load_options", QueryContext.default_load_options - ) - load_options += {"_populate_existing": True} - execution_options = execution_options.union( - {"_sa_orm_load_options": load_options} - ) - if update_options._dml_strategy not in ( "orm", "auto", @@ -1759,10 +1716,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): session, update_options, statement, - result.context.compiled_parameters[0], [(obj, state, dict_) for obj, state, dict_, _ in matched_objects], - result.prefetch_cols(), - result.postfetch_cols(), ) @classmethod @@ -1774,8 +1728,9 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): returned_defaults_rows = result.returned_defaults_rows if returned_defaults_rows: pk_rows = cls._interpret_returning_rows( - result, target_mapper, returned_defaults_rows + target_mapper, returned_defaults_rows ) + matched_rows = [ tuple(row) + (update_options._identity_token,) for row in pk_rows @@ -1806,7 +1761,6 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): session, update_options, statement, - result.context.compiled_parameters[0], [ ( obj, @@ -1815,26 +1769,16 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): ) for obj in objs ], - result.prefetch_cols(), - result.postfetch_cols(), ) @classmethod def _apply_update_set_values_to_objects( - cls, - session, - update_options, - statement, - effective_params, - matched_objects, - prefetch_cols, - postfetch_cols, + cls, session, update_options, statement, matched_objects ): """apply values to objects derived from an update statement, e.g. 
UPDATE..SET """ - mapper = update_options._subject_mapper target_cls = mapper.class_ evaluator_compiler = evaluator._EvaluatorCompiler(target_cls) @@ -1857,35 +1801,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): attrib = {k for k, v in resolved_keys_as_propnames} states = set() - - to_prefetch = { - c - for c in prefetch_cols - if c.key in effective_params - and c in mapper._columntoproperty - and c.key not in evaluated_keys - } - to_expire = { - mapper._columntoproperty[c].key - for c in postfetch_cols - if c in mapper._columntoproperty - }.difference(evaluated_keys) - - prefetch_transfer = [ - (mapper._columntoproperty[c].key, c.key) for c in to_prefetch - ] - for obj, state, dict_ in matched_objects: - - dict_.update( - { - col_to_prop: effective_params[c_key] - for col_to_prop, c_key in prefetch_transfer - } - ) - - state._expire_attributes(state.dict, to_expire) - to_evaluate = state.unmodified.intersection(evaluated_keys) for key in to_evaluate: @@ -1942,9 +1858,6 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): new_stmt = statement._clone() - if new_stmt.table._annotations["parententity"] is mapper: - new_stmt.table = mapper.local_table - new_crit = cls._adjust_for_extra_criteria( self.global_attributes, mapper ) @@ -2105,7 +2018,7 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): if returned_defaults_rows: pk_rows = cls._interpret_returning_rows( - result, target_mapper, returned_defaults_rows + target_mapper, returned_defaults_rows ) matched_rows = [ diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/clsregistry.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/clsregistry.py index fd4828e..10f1db0 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/clsregistry.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/clsregistry.py @@ -1,5 +1,5 @@ -# orm/clsregistry.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# ext/declarative/clsregistry.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -72,7 +72,7 @@ def add_class( # class already exists. 
existing = decl_class_registry[classname] if not isinstance(existing, _MultipleClassMarker): - decl_class_registry[classname] = _MultipleClassMarker( + existing = decl_class_registry[classname] = _MultipleClassMarker( [cls, cast("Type[Any]", existing)] ) else: @@ -83,9 +83,9 @@ def add_class( _ModuleMarker, decl_class_registry["_sa_module_registry"] ) except KeyError: - decl_class_registry["_sa_module_registry"] = root_module = ( - _ModuleMarker("_sa_module_registry", None) - ) + decl_class_registry[ + "_sa_module_registry" + ] = root_module = _ModuleMarker("_sa_module_registry", None) tokens = cls.__module__.split(".") @@ -239,10 +239,10 @@ class _MultipleClassMarker(ClsRegistryToken): def add_item(self, item: Type[Any]) -> None: # protect against class registration race condition against # asynchronous garbage collection calling _remove_item, - # [ticket:3208] and [ticket:10782] + # [ticket:3208] modules = { cls.__module__ - for cls in [ref() for ref in list(self.contents)] + for cls in [ref() for ref in self.contents] if cls is not None } if item.__module__ in modules: @@ -287,9 +287,8 @@ class _ModuleMarker(ClsRegistryToken): def _remove_item(self, name: str) -> None: self.contents.pop(name, None) - if not self.contents: - if self.parent is not None: - self.parent._remove_item(self.name) + if not self.contents and self.parent is not None: + self.parent._remove_item(self.name) _registries.discard(self) def resolve_attr(self, key: str) -> Union[_ModNS, Type[Any]]: @@ -317,7 +316,7 @@ class _ModuleMarker(ClsRegistryToken): else: raise else: - self.contents[name] = _MultipleClassMarker( + existing = self.contents[name] = _MultipleClassMarker( [cls], on_remove=lambda: self._remove_item(name) ) @@ -543,7 +542,9 @@ class _class_resolver: _fallback_dict: Mapping[str, Any] = None # type: ignore -def _resolver(cls: Type[Any], prop: RelationshipProperty[Any]) -> Tuple[ +def _resolver( + cls: Type[Any], prop: RelationshipProperty[Any] +) -> Tuple[ Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]], Callable[[str, bool], _class_resolver], ]: diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/collections.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/collections.py index b698c2e..3a4964c 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/collections.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/collections.py @@ -1,5 +1,5 @@ # orm/collections.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -21,8 +21,6 @@ provided. One is a bundle of generic decorators that map function arguments and return values to events:: from sqlalchemy.orm.collections import collection - - class MyClass: # ... @@ -34,6 +32,7 @@ and return values to events:: def pop(self): return self.data.pop() + The second approach is a bundle of targeted decorators that wrap appropriate append and remove notifiers around the mutation methods present in the standard Python ``list``, ``set`` and ``dict`` interfaces. These could be @@ -74,11 +73,10 @@ generally not needed. Odds are, the extension method will delegate to a method that's already instrumented. For example:: class QueueIsh(list): - def push(self, item): - self.append(item) - - def shift(self): - return self.pop(0) + def push(self, item): + self.append(item) + def shift(self): + return self.pop(0) There's no need to decorate these methods. 
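Following the decorator overview above, a compact sketch of a from-scratch collection class; MyBag, Parent and Child are hypothetical. Only an appender, a remover and an iterator are required for a class that does not extend list, set or dict::

    from sqlalchemy import ForeignKey
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
    from sqlalchemy.orm.collections import collection


    class MyBag:
        """A minimal custom collection instrumented via the decorators."""

        def __init__(self):
            self._data = []

        @collection.appender
        def add(self, item):
            self._data.append(item)

        @collection.remover
        def discard(self, item):
            self._data.remove(item)

        @collection.iterator
        def __iter__(self):
            return iter(self._data)


    class Base(DeclarativeBase):
        pass


    class Parent(Base):
        __tablename__ = "parent"

        id: Mapped[int] = mapped_column(primary_key=True)
        children = relationship("Child", collection_class=MyBag)


    class Child(Base):
        __tablename__ = "child"

        id: Mapped[int] = mapped_column(primary_key=True)
        parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))


    p = Parent()
    p.children.add(Child())
    print(len(list(p.children)))  # 1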
``append`` and ``pop`` are already instrumented as part of the ``list`` interface. Decorating them would fire @@ -150,12 +148,10 @@ __all__ = [ "keyfunc_mapping", "column_keyed_dict", "attribute_keyed_dict", - "KeyFuncDict", - # old names in < 2.0 - "mapped_collection", - "column_mapped_collection", - "attribute_mapped_collection", + "column_keyed_dict", + "attribute_keyed_dict", "MappedCollection", + "KeyFuncDict", ] __instrumentation_mutex = threading.Lock() @@ -171,7 +167,8 @@ _FN = TypeVar("_FN", bound="Callable[..., Any]") class _CollectionConverterProtocol(Protocol): - def __call__(self, collection: _COL) -> _COL: ... + def __call__(self, collection: _COL) -> _COL: + ... class _AdaptedCollectionProtocol(Protocol): @@ -197,10 +194,9 @@ class collection: The recipe decorators all require parens, even those that take no arguments:: - @collection.adds("entity") + @collection.adds('entity') def insert(self, position, entity): ... - @collection.removes_return() def popitem(self): ... @@ -220,13 +216,11 @@ class collection: @collection.appender def add(self, append): ... - # or, equivalently @collection.appender @collection.adds(1) def add(self, append): ... - # for mapping type, an 'append' may kick out a previous value # that occupies that slot. consider d['a'] = 'foo'- any previous # value in d['a'] is discarded. @@ -266,11 +260,10 @@ class collection: @collection.remover def zap(self, entity): ... - # or, equivalently @collection.remover @collection.removes_return() - def zap(self): ... + def zap(self, ): ... If the value to remove is not present in the collection, you may raise an exception or return None to ignore the error. @@ -359,7 +352,7 @@ class collection: return fn @staticmethod - def adds(arg: int) -> Callable[[_FN], _FN]: + def adds(arg): """Mark the method as adding an entity to the collection. Adds "add to collection" handling to the method. The decorator @@ -370,8 +363,7 @@ class collection: @collection.adds(1) def push(self, item): ... - - @collection.adds("entity") + @collection.adds('entity') def do_stuff(self, thing, entity=None): ... 
""" @@ -556,9 +548,9 @@ class CollectionAdapter: self.empty ), "This collection adapter is not in the 'empty' state" self.empty = False - self.owner_state.dict[self._key] = ( - self.owner_state._empty_collections.pop(self._key) - ) + self.owner_state.dict[ + self._key + ] = self.owner_state._empty_collections.pop(self._key) def _refuse_empty(self) -> NoReturn: raise sa_exc.InvalidRequestError( @@ -1562,14 +1554,14 @@ class InstrumentedDict(Dict[_KT, _VT]): """An instrumented version of the built-in dict.""" -__canned_instrumentation: util.immutabledict[Any, _CollectionFactoryType] = ( - util.immutabledict( - { - list: InstrumentedList, - set: InstrumentedSet, - dict: InstrumentedDict, - } - ) +__canned_instrumentation: util.immutabledict[ + Any, _CollectionFactoryType +] = util.immutabledict( + { + list: InstrumentedList, + set: InstrumentedSet, + dict: InstrumentedDict, + } ) __interfaces: util.immutabledict[ diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/context.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/context.py index 30b0594..79b43f5 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/context.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/context.py @@ -1,5 +1,5 @@ # orm/context.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -104,7 +104,6 @@ class QueryContext: "top_level_context", "compile_state", "query", - "user_passed_query", "params", "load_options", "bind_arguments", @@ -148,12 +147,7 @@ class QueryContext: def __init__( self, compile_state: CompileState, - statement: Union[Select[Any], FromStatement[Any], UpdateBase], - user_passed_query: Union[ - Select[Any], - FromStatement[Any], - UpdateBase, - ], + statement: Union[Select[Any], FromStatement[Any]], params: _CoreSingleExecuteParams, session: Session, load_options: Union[ @@ -168,13 +162,6 @@ class QueryContext: self.bind_arguments = bind_arguments or _EMPTY_DICT self.compile_state = compile_state self.query = statement - - # the query that the end user passed to Session.execute() or similar. - # this is usually the same as .query, except in the bulk_persistence - # routines where a separate FromStatement is manufactured in the - # compile stage; this allows differentiation in that case. - self.user_passed_query = user_passed_query - self.session = session self.loaders_require_buffering = False self.loaders_require_uniquing = False @@ -182,7 +169,7 @@ class QueryContext: self.top_level_context = load_options._sa_top_level_orm_context cached_options = compile_state.select_statement._with_options - uncached_options = user_passed_query._with_options + uncached_options = statement._with_options # see issue #7447 , #8399 for some background # propagated loader options will be present on loaded InstanceState @@ -231,7 +218,7 @@ class AbstractORMCompileState(CompileState): if compiler is None: # this is the legacy / testing only ORM _compile_state() use case. # there is no need to apply criteria options for this. 
- self.global_attributes = {} + self.global_attributes = ga = {} assert toplevel return else: @@ -265,10 +252,10 @@ class AbstractORMCompileState(CompileState): @classmethod def create_for_statement( cls, - statement: Executable, - compiler: SQLCompiler, + statement: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], **kw: Any, - ) -> CompileState: + ) -> AbstractORMCompileState: """Create a context for a statement given a :class:`.Compiler`. This method is always invoked in the context of SQLCompiler.process(). @@ -414,8 +401,8 @@ class ORMCompileState(AbstractORMCompileState): attributes: Dict[Any, Any] global_attributes: Dict[Any, Any] - statement: Union[Select[Any], FromStatement[Any], UpdateBase] - select_statement: Union[Select[Any], FromStatement[Any], UpdateBase] + statement: Union[Select[Any], FromStatement[Any]] + select_statement: Union[Select[Any], FromStatement[Any]] _entities: List[_QueryEntity] _polymorphic_adapters: Dict[_InternalEntityType, ORMAdapter] compile_options: Union[ @@ -437,30 +424,16 @@ class ORMCompileState(AbstractORMCompileState): def __init__(self, *arg, **kw): raise NotImplementedError() - @classmethod - def create_for_statement( - cls, - statement: Executable, - compiler: SQLCompiler, - **kw: Any, - ) -> ORMCompileState: - return cls._create_orm_context( - cast("Union[Select, FromStatement]", statement), - toplevel=not compiler.stack, - compiler=compiler, - **kw, - ) + if TYPE_CHECKING: - @classmethod - def _create_orm_context( - cls, - statement: Union[Select, FromStatement], - *, - toplevel: bool, - compiler: Optional[SQLCompiler], - **kw: Any, - ) -> ORMCompileState: - raise NotImplementedError() + @classmethod + def create_for_statement( + cls, + statement: Union[Select, FromStatement], + compiler: Optional[SQLCompiler], + **kw: Any, + ) -> ORMCompileState: + ... def _append_dedupe_col_collection(self, obj, col_collection): dedupe = self.dedupe_columns @@ -544,14 +517,15 @@ class ORMCompileState(AbstractORMCompileState): and len(statement._compile_options._current_path) > 10 and execution_options.get("compiled_cache", True) is not None ): - execution_options: util.immutabledict[str, Any] = ( - execution_options.union( - { - "compiled_cache": None, - "_cache_disable_reason": "excess depth for " - "ORM loader options", - } - ) + util.warn( + "Loader depth for query is excessively deep; caching will " + "be disabled for additional loaders. Consider using the " + "recursion_depth feature for deeply nested recursive eager " + "loaders. Use the compiled_cache=None execution option to " + "skip this warning." + ) + execution_options = execution_options.union( + {"compiled_cache": None} ) bind_arguments["clause"] = statement @@ -606,7 +580,6 @@ class ORMCompileState(AbstractORMCompileState): querycontext = QueryContext( compile_state, statement, - statement, params, session, load_options, @@ -670,8 +643,8 @@ class ORMCompileState(AbstractORMCompileState): ) -class _DMLReturningColFilter: - """a base for an adapter used for the DML RETURNING cases +class DMLReturningColFilter: + """an adapter used for the DML RETURNING case. 
Has a subset of the interface used by :class:`.ORMAdapter` and is used for :class:`._QueryEntity` @@ -705,21 +678,6 @@ class _DMLReturningColFilter: else: return None - def adapt_check_present(self, col): - raise NotImplementedError() - - -class _DMLBulkInsertReturningColFilter(_DMLReturningColFilter): - """an adapter used for the DML RETURNING case specifically - for ORM bulk insert (or any hypothetical DML that is splitting out a class - hierarchy among multiple DML statements....ORM bulk insert is the only - example right now) - - its main job is to limit the columns in a RETURNING to only a specific - mapped table in a hierarchy. - - """ - def adapt_check_present(self, col): mapper = self.mapper prop = mapper._columntoproperty.get(col, None) @@ -728,30 +686,6 @@ class _DMLBulkInsertReturningColFilter(_DMLReturningColFilter): return mapper.local_table.c.corresponding_column(col) -class _DMLUpdateDeleteReturningColFilter(_DMLReturningColFilter): - """an adapter used for the DML RETURNING case specifically - for ORM enabled UPDATE/DELETE - - its main job is to limit the columns in a RETURNING to include - only direct persisted columns from the immediate selectable, not - expressions like column_property(), or to also allow columns from other - mappers for the UPDATE..FROM use case. - - """ - - def adapt_check_present(self, col): - mapper = self.mapper - prop = mapper._columntoproperty.get(col, None) - if prop is not None: - # if the col is from the immediate mapper, only return a persisted - # column, not any kind of column_property expression - return mapper.persist_selectable.c.corresponding_column(col) - - # if the col is from some other mapper, just return it, assume the - # user knows what they are doing - return col - - @sql.base.CompileState.plugin_for("orm", "orm_from_statement") class ORMFromStatementCompileState(ORMCompileState): _from_obj_alias = None @@ -770,16 +704,12 @@ class ORMFromStatementCompileState(ORMCompileState): eager_joins = _EMPTY_DICT @classmethod - def _create_orm_context( + def create_for_statement( cls, - statement: Union[Select, FromStatement], - *, - toplevel: bool, + statement_container: Union[Select, FromStatement], compiler: Optional[SQLCompiler], **kw: Any, ) -> ORMFromStatementCompileState: - statement_container = statement - assert isinstance(statement_container, FromStatement) if compiler is not None and compiler.stack: @@ -821,11 +751,9 @@ class ORMFromStatementCompileState(ORMCompileState): self.statement = statement self._label_convention = self._column_naming_convention( - ( - statement._label_style - if not statement._is_textual and not statement.is_dml - else LABEL_STYLE_NONE - ), + statement._label_style + if not statement._is_textual and not statement.is_dml + else LABEL_STYLE_NONE, self.use_legacy_query_style, ) @@ -871,9 +799,9 @@ class ORMFromStatementCompileState(ORMCompileState): for entity in self._entities: entity.setup_compile_state(self) - compiler._ordered_columns = compiler._textual_ordered_columns = ( - False - ) + compiler._ordered_columns = ( + compiler._textual_ordered_columns + ) = False # enable looser result column matching. this is shown to be # needed by test_query.py::TextTest @@ -910,24 +838,14 @@ class ORMFromStatementCompileState(ORMCompileState): return None def setup_dml_returning_compile_state(self, dml_mapper): - """used by BulkORMInsert, Update, Delete to set up a handler + """used by BulkORMInsert (and Update / Delete?) 
to set up a handler for RETURNING to return ORM objects and expressions """ target_mapper = self.statement._propagate_attrs.get( "plugin_subject", None ) - - if self.statement.is_insert: - adapter = _DMLBulkInsertReturningColFilter( - target_mapper, dml_mapper - ) - elif self.statement.is_update or self.statement.is_delete: - adapter = _DMLUpdateDeleteReturningColFilter( - target_mapper, dml_mapper - ) - else: - adapter = None + adapter = DMLReturningColFilter(target_mapper, dml_mapper) if self.compile_options._is_star and (len(self._entities) != 1): raise sa_exc.CompileError( @@ -970,8 +888,6 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): ("_compile_options", InternalTraversal.dp_has_cache_key) ] - is_from_statement = True - def __init__( self, entities: Iterable[_ColumnsClauseArgument[Any]], @@ -989,10 +905,6 @@ class FromStatement(GroupedElement, Generative, TypedReturnsRows[_TP]): ] self.element = element self.is_dml = element.is_dml - self.is_select = element.is_select - self.is_delete = element.is_delete - self.is_insert = element.is_insert - self.is_update = element.is_update self._label_style = ( element._label_style if is_select_base(element) else None ) @@ -1086,17 +998,21 @@ class ORMSelectCompileState(ORMCompileState, SelectState): _having_criteria = () @classmethod - def _create_orm_context( + def create_for_statement( cls, statement: Union[Select, FromStatement], - *, - toplevel: bool, compiler: Optional[SQLCompiler], **kw: Any, ) -> ORMSelectCompileState: + """compiler hook, we arrive here from compiler.visit_select() only.""" self = cls.__new__(cls) + if compiler is not None: + toplevel = not compiler.stack + else: + toplevel = True + select_statement = statement # if we are a select() that was never a legacy Query, we won't @@ -1452,15 +1368,11 @@ class ORMSelectCompileState(ORMCompileState, SelectState): def get_columns_clause_froms(cls, statement): return cls._normalize_froms( itertools.chain.from_iterable( - ( - element._from_objects - if "parententity" not in element._annotations - else [ - element._annotations[ - "parententity" - ].__clause_element__() - ] - ) + element._from_objects + if "parententity" not in element._annotations + else [ + element._annotations["parententity"].__clause_element__() + ] for element in statement._raw_columns ) ) @@ -1589,11 +1501,9 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # the original expressions outside of the label references # in order to have them render. 
unwrapped_order_by = [ - ( - elem.element - if isinstance(elem, sql.elements._label_reference) - else elem - ) + elem.element + if isinstance(elem, sql.elements._label_reference) + else elem for elem in self.order_by ] @@ -1635,10 +1545,10 @@ class ORMSelectCompileState(ORMCompileState, SelectState): ) statement._label_style = self.label_style - # Oracle Database however does not allow FOR UPDATE on the subquery, - # and the Oracle Database dialects ignore it, plus for PostgreSQL, - # MySQL we expect that all elements of the row are locked, so also put - # it on the outside (except in the case of PG when OF is used) + # Oracle however does not allow FOR UPDATE on the subquery, + # and the Oracle dialect ignores it, plus for PostgreSQL, MySQL + # we expect that all elements of the row are locked, so also put it + # on the outside (except in the case of PG when OF is used) if ( self._for_update_arg is not None and self._for_update_arg.of is None @@ -1864,6 +1774,8 @@ class ORMSelectCompileState(ORMCompileState, SelectState): "selectable/table as join target" ) + of_type = None + if isinstance(onclause, interfaces.PropComparator): # descriptor/property given (or determined); this tells us # explicitly what the expected "left" side of the join is. @@ -2510,12 +2422,9 @@ def _column_descriptions( "type": ent.type, "aliased": getattr(insp_ent, "is_aliased_class", False), "expr": ent.expr, - "entity": ( - getattr(insp_ent, "entity", None) - if ent.entity_zero is not None - and not insp_ent.is_clause_element - else None - ), + "entity": getattr(insp_ent, "entity", None) + if ent.entity_zero is not None and not insp_ent.is_clause_element + else None, } for ent, insp_ent in [ (_ent, _ent.entity_zero) for _ent in ctx._entities @@ -2525,7 +2434,7 @@ def _column_descriptions( def _legacy_filter_by_entity_zero( - query_or_augmented_select: Union[Query[Any], Select[Any]], + query_or_augmented_select: Union[Query[Any], Select[Any]] ) -> Optional[_InternalEntityType[Any]]: self = query_or_augmented_select if self._setup_joins: @@ -2540,7 +2449,7 @@ def _legacy_filter_by_entity_zero( def _entity_from_pre_ent_zero( - query_or_augmented_select: Union[Query[Any], Select[Any]], + query_or_augmented_select: Union[Query[Any], Select[Any]] ) -> Optional[_InternalEntityType[Any]]: self = query_or_augmented_select if not self._raw_columns: @@ -2598,7 +2507,7 @@ class _QueryEntity: def setup_dml_returning_compile_state( self, compile_state: ORMCompileState, - adapter: Optional[_DMLReturningColFilter], + adapter: DMLReturningColFilter, ) -> None: raise NotImplementedError() @@ -2800,7 +2709,7 @@ class _MapperEntity(_QueryEntity): def setup_dml_returning_compile_state( self, compile_state: ORMCompileState, - adapter: Optional[_DMLReturningColFilter], + adapter: DMLReturningColFilter, ) -> None: loading._setup_entity_query( compile_state, @@ -2956,13 +2865,6 @@ class _BundleEntity(_QueryEntity): for ent in self._entities: ent.setup_compile_state(compile_state) - def setup_dml_returning_compile_state( - self, - compile_state: ORMCompileState, - adapter: Optional[_DMLReturningColFilter], - ) -> None: - return self.setup_compile_state(compile_state) - def row_processor(self, context, result): procs, labels, extra = zip( *[ent.row_processor(context, result) for ent in self._entities] @@ -3126,10 +3028,7 @@ class _RawColumnEntity(_ColumnEntity): if not is_current_entities or column._is_text_clause: self._label_name = None else: - if parent_bundle: - self._label_name = column._proxy_key - else: - self._label_name = 
compile_state._label_convention(column) + self._label_name = compile_state._label_convention(column) if parent_bundle: parent_bundle._entities.append(self) @@ -3149,7 +3048,7 @@ class _RawColumnEntity(_ColumnEntity): def setup_dml_returning_compile_state( self, compile_state: ORMCompileState, - adapter: Optional[_DMLReturningColFilter], + adapter: DMLReturningColFilter, ) -> None: return self.setup_compile_state(compile_state) @@ -3223,12 +3122,9 @@ class _ORMColumnEntity(_ColumnEntity): self.raw_column_index = raw_column_index if is_current_entities: - if parent_bundle: - self._label_name = orm_key if orm_key else column._proxy_key - else: - self._label_name = compile_state._label_convention( - column, col_name=orm_key - ) + self._label_name = compile_state._label_convention( + column, col_name=orm_key + ) else: self._label_name = None @@ -3266,13 +3162,10 @@ class _ORMColumnEntity(_ColumnEntity): def setup_dml_returning_compile_state( self, compile_state: ORMCompileState, - adapter: Optional[_DMLReturningColFilter], + adapter: DMLReturningColFilter, ) -> None: - - self._fetch_column = column = self.column - if adapter: - column = adapter(column, False) - + self._fetch_column = self.column + column = adapter(self.column, False) if column is not None: compile_state.dedupe_columns.add(column) compile_state.primary_columns.append(column) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/decl_api.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/decl_api.py index 8d2e90f..80c85f1 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/decl_api.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/decl_api.py @@ -1,5 +1,5 @@ -# orm/decl_api.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# orm/declarative/api.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -14,6 +14,7 @@ import re import typing from typing import Any from typing import Callable +from typing import cast from typing import ClassVar from typing import Dict from typing import FrozenSet @@ -71,16 +72,12 @@ from ..sql.selectable import FromClause from ..util import hybridmethod from ..util import hybridproperty from ..util import typing as compat_typing -from ..util import warn_deprecated from ..util.typing import CallableReference -from ..util.typing import de_optionalize_union_types from ..util.typing import flatten_newtype from ..util.typing import is_generic from ..util.typing import is_literal from ..util.typing import is_newtype -from ..util.typing import is_pep695 from ..util.typing import Literal -from ..util.typing import LITERAL_TYPES from ..util.typing import Self if TYPE_CHECKING: @@ -209,7 +206,7 @@ def synonym_for( :paramref:`.orm.synonym.descriptor` parameter:: class MyClass(Base): - __tablename__ = "my_table" + __tablename__ = 'my_table' id = Column(Integer, primary_key=True) _job_status = Column("job_status", String(50)) @@ -315,13 +312,17 @@ class _declared_directive(_declared_attr_common, Generic[_T]): self, fn: Callable[..., _T], cascading: bool = False, - ): ... + ): + ... - def __get__(self, instance: Optional[object], owner: Any) -> _T: ... + def __get__(self, instance: Optional[object], owner: Any) -> _T: + ... - def __set__(self, instance: Any, value: Any) -> None: ... + def __set__(self, instance: Any, value: Any) -> None: + ... - def __delete__(self, instance: Any) -> None: ... + def __delete__(self, instance: Any) -> None: + ... 
def __call__(self, fn: Callable[..., _TT]) -> _declared_directive[_TT]: # extensive fooling of mypy underway... @@ -375,21 +376,20 @@ class declared_attr(interfaces._MappedAttribute[_T], _declared_attr_common): for subclasses:: class Employee(Base): - __tablename__ = "employee" + __tablename__ = 'employee' id: Mapped[int] = mapped_column(primary_key=True) type: Mapped[str] = mapped_column(String(50)) @declared_attr.directive def __mapper_args__(cls) -> Dict[str, Any]: - if cls.__name__ == "Employee": + if cls.__name__ == 'Employee': return { - "polymorphic_on": cls.type, - "polymorphic_identity": "Employee", + "polymorphic_on":cls.type, + "polymorphic_identity":"Employee" } else: - return {"polymorphic_identity": cls.__name__} - + return {"polymorphic_identity":cls.__name__} class Engineer(Employee): pass @@ -427,11 +427,14 @@ class declared_attr(interfaces._MappedAttribute[_T], _declared_attr_common): self, fn: _DeclaredAttrDecorated[_T], cascading: bool = False, - ): ... + ): + ... - def __set__(self, instance: Any, value: Any) -> None: ... + def __set__(self, instance: Any, value: Any) -> None: + ... - def __delete__(self, instance: Any) -> None: ... + def __delete__(self, instance: Any) -> None: + ... # this is the Mapped[] API where at class descriptor get time we want # the type checker to see InstrumentedAttribute[_T]. However the @@ -440,14 +443,17 @@ class declared_attr(interfaces._MappedAttribute[_T], _declared_attr_common): @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T]: ... + ) -> InstrumentedAttribute[_T]: + ... @overload - def __get__(self, instance: object, owner: Any) -> _T: ... + def __get__(self, instance: object, owner: Any) -> _T: + ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T], _T]: ... + ) -> Union[InstrumentedAttribute[_T], _T]: + ... @hybridmethod def _stateful(cls, **kw: Any) -> _stateful_declared_attr[_T]: @@ -488,7 +494,6 @@ def declarative_mixin(cls: Type[_T]) -> Type[_T]: from sqlalchemy.orm import declared_attr from sqlalchemy.orm import declarative_mixin - @declarative_mixin class MyMixin: @@ -496,11 +501,10 @@ def declarative_mixin(cls: Type[_T]) -> Type[_T]: def __tablename__(cls): return cls.__name__.lower() - __table_args__ = {"mysql_engine": "InnoDB"} - __mapper_args__ = {"always_refresh": True} - - id = Column(Integer, primary_key=True) + __table_args__ = {'mysql_engine': 'InnoDB'} + __mapper_args__= {'always_refresh': True} + id = Column(Integer, primary_key=True) class MyModel(MyMixin, Base): name = Column(String(1000)) @@ -512,9 +516,6 @@ def declarative_mixin(cls: Type[_T]) -> Type[_T]: .. versionadded:: 1.4.6 - .. legacy:: This api is considered legacy and will be deprecated in the next - SQLAlchemy version. - .. 
seealso:: :ref:`orm_mixins_toplevel` @@ -593,7 +594,6 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative): dataclass_callable: Union[ _NoArg, Callable[..., Type[Any]] ] = _NoArg.NO_ARG, - **kw: Any, ) -> None: apply_dc_transforms: _DataclassArguments = { "init": init, @@ -618,11 +618,11 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative): for k, v in apply_dc_transforms.items() } else: - cls._sa_apply_dc_transforms = current_transforms = ( - apply_dc_transforms - ) + cls._sa_apply_dc_transforms = ( + current_transforms + ) = apply_dc_transforms - super().__init_subclass__(**kw) + super().__init_subclass__() if not _is_mapped_class(cls): new_anno = ( @@ -646,10 +646,10 @@ class DeclarativeBase( from sqlalchemy.orm import DeclarativeBase - class Base(DeclarativeBase): pass + The above ``Base`` class is now usable as the base for new declarative mappings. The superclass makes use of the ``__init_subclass__()`` method to set up new classes and metaclasses aren't used. @@ -672,12 +672,11 @@ class DeclarativeBase( bigint = Annotated[int, "bigint"] my_metadata = MetaData() - class Base(DeclarativeBase): metadata = my_metadata type_annotation_map = { str: String().with_variant(String(255), "mysql", "mariadb"), - bigint: BigInteger(), + bigint: BigInteger() } Class-level attributes which may be specified include: @@ -752,9 +751,11 @@ class DeclarativeBase( if typing.TYPE_CHECKING: - def _sa_inspect_type(self) -> Mapper[Self]: ... + def _sa_inspect_type(self) -> Mapper[Self]: + ... - def _sa_inspect_instance(self) -> InstanceState[Self]: ... + def _sa_inspect_instance(self) -> InstanceState[Self]: + ... _sa_registry: ClassVar[_RegistryType] @@ -835,15 +836,16 @@ class DeclarativeBase( """ - def __init__(self, **kw: Any): ... + def __init__(self, **kw: Any): + ... - def __init_subclass__(cls, **kw: Any) -> None: + def __init_subclass__(cls) -> None: if DeclarativeBase in cls.__bases__: _check_not_declarative(cls, DeclarativeBase) _setup_declarative_base(cls) else: _as_declarative(cls._sa_registry, cls, cls.__dict__) - super().__init_subclass__(**kw) + super().__init_subclass__() def _check_not_declarative(cls: Type[Any], base: Type[Any]) -> None: @@ -920,9 +922,11 @@ class DeclarativeBaseNoMeta( if typing.TYPE_CHECKING: - def _sa_inspect_type(self) -> Mapper[Self]: ... + def _sa_inspect_type(self) -> Mapper[Self]: + ... - def _sa_inspect_instance(self) -> InstanceState[Self]: ... + def _sa_inspect_instance(self) -> InstanceState[Self]: + ... __tablename__: Any """String name to assign to the generated @@ -957,15 +961,15 @@ class DeclarativeBaseNoMeta( """ - def __init__(self, **kw: Any): ... + def __init__(self, **kw: Any): + ... 
- def __init_subclass__(cls, **kw: Any) -> None: + def __init_subclass__(cls) -> None: if DeclarativeBaseNoMeta in cls.__bases__: _check_not_declarative(cls, DeclarativeBaseNoMeta) _setup_declarative_base(cls) else: _as_declarative(cls._sa_registry, cls, cls.__dict__) - super().__init_subclass__(**kw) def add_mapped_attribute( @@ -1230,34 +1234,38 @@ class registry: self.type_annotation_map.update( { - de_optionalize_union_types(typ): sqltype + sub_type: sqltype for typ, sqltype in type_annotation_map.items() + for sub_type in compat_typing.expand_unions( + typ, include_union=True, discard_none=True + ) } ) def _resolve_type( - self, python_type: _MatchedOnType, _do_fallbacks: bool = True + self, python_type: _MatchedOnType ) -> Optional[sqltypes.TypeEngine[Any]]: - python_type_type: Type[Any] search: Iterable[Tuple[_MatchedOnType, Type[Any]]] + python_type_type: Type[Any] if is_generic(python_type): if is_literal(python_type): - python_type_type = python_type # type: ignore[assignment] + python_type_type = cast("Type[Any]", python_type) - search = ( + search = ( # type: ignore[assignment] (python_type, python_type_type), - *((lt, python_type_type) for lt in LITERAL_TYPES), + (Literal, python_type_type), ) else: python_type_type = python_type.__origin__ search = ((python_type, python_type_type),) - elif isinstance(python_type, type): - python_type_type = python_type - search = ((pt, pt) for pt in python_type_type.__mro__) - else: - python_type_type = python_type # type: ignore[assignment] + elif is_newtype(python_type): + python_type_type = flatten_newtype(python_type) search = ((python_type, python_type_type),) + else: + python_type_type = cast("Type[Any]", python_type) + flattened = None + search = ((pt, pt) for pt in python_type_type.__mro__) for pt, flattened in search: # we search through full __mro__ for types. however... @@ -1281,39 +1289,6 @@ class registry: if resolved_sql_type is not None: return resolved_sql_type - # 2.0 fallbacks - if _do_fallbacks: - python_type_to_check: Any = None - kind = None - if is_pep695(python_type): - # NOTE: assume there aren't type alias types of new types. - python_type_to_check = python_type - while is_pep695(python_type_to_check): - python_type_to_check = python_type_to_check.__value__ - python_type_to_check = de_optionalize_union_types( - python_type_to_check - ) - kind = "TypeAliasType" - if is_newtype(python_type): - python_type_to_check = flatten_newtype(python_type) - kind = "NewType" - - if python_type_to_check is not None: - res_after_fallback = self._resolve_type( - python_type_to_check, False - ) - if res_after_fallback is not None: - assert kind is not None - warn_deprecated( - f"Matching the provided {kind} '{python_type}' on " - "its resolved value without matching it in the " - "type_annotation_map is deprecated; add this type to " - "the type_annotation_map to allow it to match " - "explicitly.", - "2.0", - ) - return res_after_fallback - return None @property @@ -1506,7 +1481,6 @@ class registry: Base = mapper_registry.generate_base() - class MyClass(Base): __tablename__ = "my_table" id = Column(Integer, primary_key=True) @@ -1519,7 +1493,6 @@ class registry: mapper_registry = registry() - class Base(metaclass=DeclarativeMeta): __abstract__ = True registry = mapper_registry @@ -1605,7 +1578,8 @@ class registry: ), ) @overload - def mapped_as_dataclass(self, __cls: Type[_O]) -> Type[_O]: ... + def mapped_as_dataclass(self, __cls: Type[_O]) -> Type[_O]: + ... 
@overload def mapped_as_dataclass( @@ -1620,7 +1594,8 @@ class registry: match_args: Union[_NoArg, bool] = ..., kw_only: Union[_NoArg, bool] = ..., dataclass_callable: Union[_NoArg, Callable[..., Type[Any]]] = ..., - ) -> Callable[[Type[_O]], Type[_O]]: ... + ) -> Callable[[Type[_O]], Type[_O]]: + ... def mapped_as_dataclass( self, @@ -1685,10 +1660,9 @@ class registry: mapper_registry = registry() - @mapper_registry.mapped class Foo: - __tablename__ = "some_table" + __tablename__ = 'some_table' id = Column(Integer, primary_key=True) name = Column(String) @@ -1728,17 +1702,15 @@ class registry: mapper_registry = registry() - @mapper_registry.as_declarative_base() class Base: @declared_attr def __tablename__(cls): return cls.__name__.lower() - id = Column(Integer, primary_key=True) - - class MyMappedClass(Base): ... + class MyMappedClass(Base): + # ... All keyword arguments passed to :meth:`_orm.registry.as_declarative_base` are passed @@ -1768,14 +1740,12 @@ class registry: mapper_registry = registry() - class Foo: - __tablename__ = "some_table" + __tablename__ = 'some_table' id = Column(Integer, primary_key=True) name = Column(String) - mapper = mapper_registry.map_declaratively(Foo) This function is more conveniently invoked indirectly via either the @@ -1828,14 +1798,12 @@ class registry: my_table = Table( "my_table", mapper_registry.metadata, - Column("id", Integer, primary_key=True), + Column('id', Integer, primary_key=True) ) - class MyClass: pass - mapper_registry.map_imperatively(MyClass, my_table) See the section :ref:`orm_imperative_mapping` for complete background @@ -1882,17 +1850,15 @@ def as_declarative(**kw: Any) -> Callable[[Type[_T]], Type[_T]]: from sqlalchemy.orm import as_declarative - @as_declarative() class Base: @declared_attr def __tablename__(cls): return cls.__name__.lower() - id = Column(Integer, primary_key=True) - - class MyMappedClass(Base): ... + class MyMappedClass(Base): + # ... .. seealso:: diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/decl_base.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/decl_base.py index 418e312..d5ef3db 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/decl_base.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/decl_base.py @@ -1,5 +1,5 @@ -# orm/decl_base.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# ext/declarative/base.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -65,11 +65,11 @@ from ..sql.schema import Column from ..sql.schema import Table from ..util import topological from ..util.typing import _AnnotationScanType -from ..util.typing import get_args from ..util.typing import is_fwd_ref from ..util.typing import is_literal from ..util.typing import Protocol from ..util.typing import TypedDict +from ..util.typing import typing_get_args if TYPE_CHECKING: from ._typing import _ClassDict @@ -98,12 +98,12 @@ class MappedClassProtocol(Protocol[_O]): __mapper__: Mapper[_O] __table__: FromClause - def __call__(self, **kw: Any) -> _O: ... + def __call__(self, **kw: Any) -> _O: + ... class _DeclMappedClassProtocol(MappedClassProtocol[_O], Protocol): "Internal more detailed version of ``MappedClassProtocol``." - metadata: MetaData __tablename__: str __mapper_args__: _MapperKwArgs @@ -111,9 +111,11 @@ class _DeclMappedClassProtocol(MappedClassProtocol[_O], Protocol): _sa_apply_dc_transforms: Optional[_DataclassArguments] - def __declare_first__(self) -> None: ... 
+ def __declare_first__(self) -> None: + ... - def __declare_last__(self) -> None: ... + def __declare_last__(self) -> None: + ... class _DataclassArguments(TypedDict): @@ -432,7 +434,7 @@ class _ImperativeMapperConfig(_MapperConfig): class _CollectedAnnotation(NamedTuple): raw_annotation: _AnnotationScanType mapped_container: Optional[Type[Mapped[Any]]] - extracted_mapped_annotation: Union[_AnnotationScanType, str] + extracted_mapped_annotation: Union[Type[Any], str] is_dataclass: bool attr_value: Any originating_module: str @@ -454,7 +456,6 @@ class _ClassScanMapperConfig(_MapperConfig): "tablename", "mapper_args", "mapper_args_fn", - "table_fn", "inherits", "single", "allow_dataclass_fields", @@ -761,7 +762,7 @@ class _ClassScanMapperConfig(_MapperConfig): _include_dunders = self._include_dunders mapper_args_fn = None table_args = inherited_table_args = None - table_fn = None + tablename = None fixed_table = "__table__" in clsdict_view @@ -842,22 +843,6 @@ class _ClassScanMapperConfig(_MapperConfig): ) if not tablename and (not class_mapped or check_decl): tablename = cls_as_Decl.__tablename__ - elif name == "__table__": - check_decl = _check_declared_props_nocascade( - obj, name, cls - ) - # if a @declared_attr using "__table__" is detected, - # wrap up a callable to look for "__table__" from - # the final concrete class when we set up a table. - # this was fixed by - # #11509, regression in 2.0 from version 1.4. - if check_decl and not table_fn: - # don't even invoke __table__ until we're ready - def _table_fn() -> FromClause: - return cls_as_Decl.__table__ - - table_fn = _table_fn - elif name == "__table_args__": check_decl = _check_declared_props_nocascade( obj, name, cls @@ -874,10 +859,9 @@ class _ClassScanMapperConfig(_MapperConfig): if base is not cls: inherited_table_args = True else: - # any other dunder names; should not be here - # as we have tested for all four names in - # _include_dunders - assert False + # skip all other dunder names, which at the moment + # should only be __table__ + continue elif class_mapped: if _is_declarative_props(obj) and not obj._quiet: util.warn( @@ -924,9 +908,9 @@ class _ClassScanMapperConfig(_MapperConfig): "@declared_attr.cascading; " "skipping" % (name, cls) ) - collected_attributes[name] = column_copies[obj] = ( - ret - ) = obj.__get__(obj, cls) + collected_attributes[name] = column_copies[ + obj + ] = ret = obj.__get__(obj, cls) setattr(cls, name, ret) else: if is_dataclass_field: @@ -963,9 +947,9 @@ class _ClassScanMapperConfig(_MapperConfig): ): ret = ret.descriptor - collected_attributes[name] = column_copies[obj] = ( - ret - ) + collected_attributes[name] = column_copies[ + obj + ] = ret if ( isinstance(ret, (Column, MapperProperty)) @@ -1050,7 +1034,6 @@ class _ClassScanMapperConfig(_MapperConfig): self.table_args = table_args self.tablename = tablename self.mapper_args_fn = mapper_args_fn - self.table_fn = table_fn def _setup_dataclasses_transforms(self) -> None: dataclass_setup_arguments = self.dataclass_setup_arguments @@ -1068,16 +1051,6 @@ class _ClassScanMapperConfig(_MapperConfig): "'@registry.mapped_as_dataclass'" ) - # can't create a dataclass if __table__ is already there. This would - # fail an assertion when calling _get_arguments_for_make_dataclass: - # assert False, "Mapped[] received without a mapping declaration" - if "__table__" in self.cls.__dict__: - raise exc.InvalidRequestError( - f"Class {self.cls} already defines a '__table__'. 
" - "ORM Annotated Dataclasses do not support a pre-existing " - "'__table__' element" - ) - warn_for_non_dc_attrs = collections.defaultdict(list) def _allow_dataclass_field( @@ -1157,9 +1130,9 @@ class _ClassScanMapperConfig(_MapperConfig): defaults = {} for item in field_list: if len(item) == 2: - name, tp = item + name, tp = item # type: ignore elif len(item) == 3: - name, tp, spec = item + name, tp, spec = item # type: ignore defaults[name] = spec else: assert False @@ -1224,9 +1197,9 @@ class _ClassScanMapperConfig(_MapperConfig): restored = None try: - dataclass_callable( # type: ignore[call-overload] + dataclass_callable( klass, - **{ # type: ignore[call-overload,unused-ignore] + **{ k: v for k, v in dataclass_setup_arguments.items() if v is not _NoArg.NO_ARG and k != "dataclass_callable" @@ -1297,6 +1270,8 @@ class _ClassScanMapperConfig(_MapperConfig): or isinstance(attr_value, _MappedAttribute) ) ) + else: + is_dataclass_field = False is_dataclass_field = False extracted = _extract_mapped_subtype( @@ -1307,8 +1282,10 @@ class _ClassScanMapperConfig(_MapperConfig): type(attr_value), required=False, is_dataclass_field=is_dataclass_field, - expect_mapped=expect_mapped and not is_dataclass, + expect_mapped=expect_mapped + and not is_dataclass, # self.allow_dataclass_fields, ) + if extracted is None: # ClassVar can come out here return None @@ -1316,9 +1293,9 @@ class _ClassScanMapperConfig(_MapperConfig): extracted_mapped_annotation, mapped_container = extracted if attr_value is None and not is_literal(extracted_mapped_annotation): - for elem in get_args(extracted_mapped_annotation): - if is_fwd_ref( - elem, check_generic=True, check_for_plain_string=True + for elem in typing_get_args(extracted_mapped_annotation): + if isinstance(elem, str) or is_fwd_ref( + elem, check_generic=True ): elem = de_stringify_annotation( self.cls, @@ -1576,7 +1553,7 @@ class _ClassScanMapperConfig(_MapperConfig): is_dataclass, ) except NameError as ne: - raise orm_exc.MappedAnnotationError( + raise exc.ArgumentError( f"Could not resolve all types within mapped " f'annotation: "{annotation}". 
Ensure all ' f"types are written correctly and are " @@ -1600,15 +1577,9 @@ class _ClassScanMapperConfig(_MapperConfig): "default_factory", "repr", "default", - "dataclass_metadata", ] else: - argnames = [ - "init", - "default_factory", - "repr", - "dataclass_metadata", - ] + argnames = ["init", "default_factory", "repr"] args = { a @@ -1719,11 +1690,7 @@ class _ClassScanMapperConfig(_MapperConfig): manager = attributes.manager_of_class(cls) - if ( - self.table_fn is None - and "__table__" not in clsdict_view - and table is None - ): + if "__table__" not in clsdict_view and table is None: if hasattr(cls, "__table_cls__"): table_cls = cast( Type[Table], @@ -1769,12 +1736,7 @@ class _ClassScanMapperConfig(_MapperConfig): ) else: if table is None: - if self.table_fn: - table = self.set_cls_attribute( - "__table__", self.table_fn() - ) - else: - table = cls_as_Decl.__table__ + table = cls_as_Decl.__table__ if declared_columns: for c in declared_columns: if not table.c.contains_column(c): @@ -2023,7 +1985,8 @@ class _DeferredMapperConfig(_ClassScanMapperConfig): def _early_mapping(self, mapper_kw: _MapperKwArgs) -> None: pass - @property + # mypy disallows plain property override of variable + @property # type: ignore def cls(self) -> Type[Any]: return self._cls() # type: ignore diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/dependency.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/dependency.py index a8cafdd..e941dbc 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/dependency.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/dependency.py @@ -1,5 +1,5 @@ # orm/dependency.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -7,7 +7,9 @@ # mypy: ignore-errors -"""Relationship dependencies.""" +"""Relationship dependencies. + +""" from __future__ import annotations @@ -165,11 +167,9 @@ class DependencyProcessor: sum_ = state.manager[self.key].impl.get_all_pending( state, state.dict, - ( - self._passive_delete_flag - if isdelete - else attributes.PASSIVE_NO_INITIALIZE - ), + self._passive_delete_flag + if isdelete + else attributes.PASSIVE_NO_INITIALIZE, ) if not sum_: @@ -1052,7 +1052,7 @@ class ManyToManyDP(DependencyProcessor): # so that prop_has_changes() returns True for state in states: if self._pks_changed(uowcommit, state): - uowcommit.get_attribute_history( + history = uowcommit.get_attribute_history( state, self.key, attributes.PASSIVE_OFF ) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/descriptor_props.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/descriptor_props.py index 43c4aa3..c1fe9de 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/descriptor_props.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/descriptor_props.py @@ -1,5 +1,5 @@ # orm/descriptor_props.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -53,10 +53,9 @@ from .. 
import util from ..sql import expression from ..sql import operators from ..sql.elements import BindParameter -from ..util.typing import get_args from ..util.typing import is_fwd_ref from ..util.typing import is_pep593 - +from ..util.typing import typing_get_args if typing.TYPE_CHECKING: from ._typing import _InstanceDict @@ -99,11 +98,6 @@ class DescriptorProperty(MapperProperty[_T]): descriptor: DescriptorReference[Any] - def _column_strategy_attrs(self) -> Sequence[QueryableAttribute[Any]]: - raise NotImplementedError( - "This MapperProperty does not implement column loader strategies" - ) - def get_history( self, state: InstanceState[Any], @@ -370,7 +364,7 @@ class CompositeProperty( argument = extracted_mapped_annotation if is_pep593(argument): - argument = get_args(argument)[0] + argument = typing_get_args(argument)[0] if argument and self.composite_class is None: if isinstance(argument, str) or is_fwd_ref( @@ -393,9 +387,7 @@ class CompositeProperty( self.composite_class = argument if is_dataclass(self.composite_class): - self._setup_for_dataclass( - decl_scan, registry, cls, originating_module, key - ) + self._setup_for_dataclass(registry, cls, originating_module, key) else: for attr in self.attrs: if ( @@ -427,19 +419,18 @@ class CompositeProperty( and self.composite_class not in _composite_getters ): if self._generated_composite_accessor is not None: - _composite_getters[self.composite_class] = ( - self._generated_composite_accessor - ) + _composite_getters[ + self.composite_class + ] = self._generated_composite_accessor elif hasattr(self.composite_class, "__composite_values__"): - _composite_getters[self.composite_class] = ( - lambda obj: obj.__composite_values__() - ) + _composite_getters[ + self.composite_class + ] = lambda obj: obj.__composite_values__() @util.preload_module("sqlalchemy.orm.properties") @util.preload_module("sqlalchemy.orm.decl_base") def _setup_for_dataclass( self, - decl_scan: _ClassScanMapperConfig, registry: _RegistryType, cls: Type[Any], originating_module: Optional[str], @@ -467,7 +458,6 @@ class CompositeProperty( if isinstance(attr, MappedColumn): attr.declarative_scan_for_composite( - decl_scan, registry, cls, originating_module, @@ -509,9 +499,6 @@ class CompositeProperty( props.append(prop) return props - def _column_strategy_attrs(self) -> Sequence[QueryableAttribute[Any]]: - return self._comparable_elements - @util.non_memoized_property @util.preload_module("orm.properties") def columns(self) -> Sequence[Column[Any]]: @@ -794,9 +781,7 @@ class CompositeProperty( elif isinstance(self.prop.composite_class, type) and isinstance( value, self.prop.composite_class ): - values = self.prop._composite_values_from_instance( - value # type: ignore[arg-type] - ) + values = self.prop._composite_values_from_instance(value) else: raise sa_exc.ArgumentError( "Can't UPDATE composite attribute %s to %r" @@ -1011,9 +996,6 @@ class SynonymProperty(DescriptorProperty[_T]): ) return attr.property - def _column_strategy_attrs(self) -> Sequence[QueryableAttribute[Any]]: - return (getattr(self.parent.class_, self.name),) - def _comparator_factory(self, mapper: Mapper[Any]) -> SQLORMOperations[_T]: prop = self._proxied_object diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/dynamic.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/dynamic.py index 3c81c39..1d0c036 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/dynamic.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/dynamic.py @@ -1,5 +1,5 @@ # orm/dynamic.py -# Copyright (C) 
2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -161,12 +161,10 @@ class AppenderMixin(AbstractCollectionWriter[_T]): return result.IteratorResult( result.SimpleResultMetaData([self.attr.class_.__name__]), - iter( - self.attr._get_collection_history( - attributes.instance_state(self.instance), - PassiveFlag.PASSIVE_NO_INITIALIZE, - ).added_items - ), + self.attr._get_collection_history( # type: ignore[arg-type] + attributes.instance_state(self.instance), + PassiveFlag.PASSIVE_NO_INITIALIZE, + ).added_items, _source_supports_scalars=True, ).scalars() else: @@ -174,7 +172,8 @@ class AppenderMixin(AbstractCollectionWriter[_T]): if TYPE_CHECKING: - def __iter__(self) -> Iterator[_T]: ... + def __iter__(self) -> Iterator[_T]: + ... def __getitem__(self, index: Any) -> Union[_T, List[_T]]: sess = self.session diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/evaluator.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/evaluator.py index 57aae5a..f3796f0 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/evaluator.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/evaluator.py @@ -1,5 +1,5 @@ # orm/evaluator.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -28,7 +28,6 @@ from .. import exc from .. import inspect from ..sql import and_ from ..sql import operators -from ..sql.sqltypes import Concatenable from ..sql.sqltypes import Integer from ..sql.sqltypes import Numeric from ..util import warn_deprecated @@ -312,16 +311,6 @@ class _EvaluatorCompiler: def visit_concat_op_binary_op( self, operator, eval_left, eval_right, clause ): - - if not issubclass( - clause.left.type._type_affinity, Concatenable - ) or not issubclass(clause.right.type._type_affinity, Concatenable): - raise UnevaluatableError( - f"Cannot evaluate concatenate operator " - f'"{operator.__name__}" for ' - f"datatypes {clause.left.type}, {clause.right.type}" - ) - return self._straight_evaluate( lambda a, b: a + b, eval_left, eval_right, clause ) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/events.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/events.py index 5af78fc..e7e3e32 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/events.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/events.py @@ -1,11 +1,13 @@ # orm/events.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""ORM event interfaces.""" +"""ORM event interfaces. 
+ +""" from __future__ import annotations from typing import Any @@ -205,12 +207,10 @@ class InstanceEvents(event.Events[ClassManager[Any]]): from sqlalchemy import event - def my_load_listener(target, context): print("on load!") - - event.listen(SomeClass, "load", my_load_listener) + event.listen(SomeClass, 'load', my_load_listener) Available targets include: @@ -466,7 +466,8 @@ class InstanceEvents(event.Events[ClassManager[Any]]): the existing loading context is maintained for the object after the event is called:: - @event.listens_for(SomeClass, "load", restore_load_context=True) + @event.listens_for( + SomeClass, "load", restore_load_context=True) def on_load(instance, context): instance.some_unloaded_attribute @@ -493,15 +494,15 @@ class InstanceEvents(event.Events[ClassManager[Any]]): .. seealso:: - :ref:`mapped_class_load_events` - :meth:`.InstanceEvents.init` :meth:`.InstanceEvents.refresh` :meth:`.SessionEvents.loaded_as_persistent` - """ # noqa: E501 + :ref:`mapping_constructors` + + """ def refresh( self, target: _O, context: QueryContext, attrs: Optional[Iterable[str]] @@ -533,8 +534,6 @@ class InstanceEvents(event.Events[ClassManager[Any]]): .. seealso:: - :ref:`mapped_class_load_events` - :meth:`.InstanceEvents.load` """ @@ -578,8 +577,6 @@ class InstanceEvents(event.Events[ClassManager[Any]]): .. seealso:: - :ref:`mapped_class_load_events` - :ref:`orm_server_defaults` :ref:`metadata_defaults_toplevel` @@ -728,9 +725,9 @@ class _EventsHold(event.RefCollection[_ET]): class _InstanceEventsHold(_EventsHold[_ET]): - all_holds: weakref.WeakKeyDictionary[Any, Any] = ( - weakref.WeakKeyDictionary() - ) + all_holds: weakref.WeakKeyDictionary[ + Any, Any + ] = weakref.WeakKeyDictionary() def resolve(self, class_: Type[_O]) -> Optional[ClassManager[_O]]: return instrumentation.opt_manager_of_class(class_) @@ -748,7 +745,6 @@ class MapperEvents(event.Events[mapperlib.Mapper[Any]]): from sqlalchemy import event - def my_before_insert_listener(mapper, connection, target): # execute a stored procedure upon INSERT, # apply the value to the row to be inserted @@ -756,10 +752,10 @@ class MapperEvents(event.Events[mapperlib.Mapper[Any]]): text("select my_special_function(%d)" % target.special_number) ).scalar() - # associate the listener function with SomeClass, # to execute during the "before_insert" hook - event.listen(SomeClass, "before_insert", my_before_insert_listener) + event.listen( + SomeClass, 'before_insert', my_before_insert_listener) Available targets include: @@ -925,10 +921,9 @@ class MapperEvents(event.Events[mapperlib.Mapper[Any]]): Base = declarative_base() - @event.listens_for(Base, "instrument_class", propagate=True) def on_new_class(mapper, cls_): - "..." + " ... " :param mapper: the :class:`_orm.Mapper` which is the target of this event. @@ -984,7 +979,7 @@ class MapperEvents(event.Events[mapperlib.Mapper[Any]]): symbol which indicates to the :func:`.configure_mappers` call that this particular mapper (or hierarchy of mappers, if ``propagate=True`` is used) should be skipped in the current configuration run. When one or - more mappers are skipped, the "new mappers" flag will remain set, + more mappers are skipped, the he "new mappers" flag will remain set, meaning the :func:`.configure_mappers` function will continue to be called when mappers are used, to continue to try to configure all available mappers. 
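(Illustrative sketch, not part of the patch above: the ``before_mapper_configured`` hook described in this docstring can be registered against a declarative base with ``retval=True`` so that returning ``EXT_SKIP`` excludes that hierarchy from the current configure run. The base name ``DontConfigureBase`` is the hypothetical one from the docstring.)

    from sqlalchemy import event
    from sqlalchemy.orm import declarative_base
    from sqlalchemy.orm.interfaces import EXT_SKIP

    DontConfigureBase = declarative_base()

    @event.listens_for(
        DontConfigureBase,
        "before_mapper_configured",
        retval=True,
        propagate=True,
    )
    def dont_configure(mapper, cls):
        # EXT_SKIP tells configure_mappers() to leave this hierarchy
        # unconfigured for now; the "new mappers" flag stays set, so a
        # later configure_mappers() call will attempt it again.
        return EXT_SKIP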
@@ -993,7 +988,7 @@ class MapperEvents(event.Events[mapperlib.Mapper[Any]]): :meth:`.MapperEvents.before_configured`, :meth:`.MapperEvents.after_configured`, and :meth:`.MapperEvents.mapper_configured`, the - :meth:`.MapperEvents.before_mapper_configured` event provides for a + :meth;`.MapperEvents.before_mapper_configured` event provides for a meaningful return value when it is registered with the ``retval=True`` parameter. @@ -1007,16 +1002,13 @@ class MapperEvents(event.Events[mapperlib.Mapper[Any]]): DontConfigureBase = declarative_base() - @event.listens_for( DontConfigureBase, - "before_mapper_configured", - retval=True, - propagate=True, - ) + "before_mapper_configured", retval=True, propagate=True) def dont_configure(mapper, cls): return EXT_SKIP + .. seealso:: :meth:`.MapperEvents.before_configured` @@ -1098,9 +1090,9 @@ class MapperEvents(event.Events[mapperlib.Mapper[Any]]): from sqlalchemy.orm import Mapper - @event.listens_for(Mapper, "before_configured") - def go(): ... + def go(): + ... Contrast this event to :meth:`.MapperEvents.after_configured`, which is invoked after the series of mappers has been configured, @@ -1118,9 +1110,10 @@ class MapperEvents(event.Events[mapperlib.Mapper[Any]]): from sqlalchemy.orm import mapper - @event.listens_for(mapper, "before_configured", once=True) - def go(): ... + def go(): + ... + .. seealso:: @@ -1157,9 +1150,9 @@ class MapperEvents(event.Events[mapperlib.Mapper[Any]]): from sqlalchemy.orm import Mapper - @event.listens_for(Mapper, "after_configured") - def go(): ... + def go(): + # ... Theoretically this event is called once per application, but is actually called any time new mappers @@ -1171,9 +1164,9 @@ class MapperEvents(event.Events[mapperlib.Mapper[Any]]): from sqlalchemy.orm import mapper - @event.listens_for(mapper, "after_configured", once=True) - def go(): ... + def go(): + # ... .. seealso:: @@ -1560,11 +1553,9 @@ class SessionEvents(event.Events[Session]): from sqlalchemy import event from sqlalchemy.orm import sessionmaker - def my_before_commit(session): print("before commit!") - Session = sessionmaker() event.listen(Session, "before_commit", my_before_commit) @@ -1600,7 +1591,7 @@ class SessionEvents(event.Events[Session]): _dispatch_target = Session def _lifecycle_event( # type: ignore [misc] - fn: Callable[[SessionEvents, Session, Any], None], + fn: Callable[[SessionEvents, Session, Any], None] ) -> Callable[[SessionEvents, Session, Any], None]: _sessionevents_lifecycle_event_names.add(fn.__name__) return fn @@ -1784,7 +1775,7 @@ class SessionEvents(event.Events[Session]): @event.listens_for(session, "after_transaction_create") def after_transaction_create(session, transaction): if transaction.parent is None: - ... # work with top-level transaction + # work with top-level transaction To detect if the :class:`.SessionTransaction` is a SAVEPOINT, use the :attr:`.SessionTransaction.nested` attribute:: @@ -1792,7 +1783,8 @@ class SessionEvents(event.Events[Session]): @event.listens_for(session, "after_transaction_create") def after_transaction_create(session, transaction): if transaction.nested: - ... # work with SAVEPOINT transaction + # work with SAVEPOINT transaction + .. seealso:: @@ -1824,7 +1816,7 @@ class SessionEvents(event.Events[Session]): @event.listens_for(session, "after_transaction_create") def after_transaction_end(session, transaction): if transaction.parent is None: - ... 
# work with top-level transaction + # work with top-level transaction To detect if the :class:`.SessionTransaction` is a SAVEPOINT, use the :attr:`.SessionTransaction.nested` attribute:: @@ -1832,7 +1824,8 @@ class SessionEvents(event.Events[Session]): @event.listens_for(session, "after_transaction_create") def after_transaction_end(session, transaction): if transaction.nested: - ... # work with SAVEPOINT transaction + # work with SAVEPOINT transaction + .. seealso:: @@ -1942,7 +1935,7 @@ class SessionEvents(event.Events[Session]): @event.listens_for(Session, "after_soft_rollback") def do_something(session, previous_transaction): if session.is_active: - session.execute(text("select * from some_table")) + session.execute("select * from some_table") :param session: The target :class:`.Session`. :param previous_transaction: The :class:`.SessionTransaction` @@ -2042,14 +2035,7 @@ class SessionEvents(event.Events[Session]): transaction: SessionTransaction, connection: Connection, ) -> None: - """Execute after a transaction is begun on a connection. - - .. note:: This event is called within the process of the - :class:`_orm.Session` modifying its own internal state. - To invoke SQL operations within this hook, use the - :class:`_engine.Connection` provided to the event; - do not run SQL operations using the :class:`_orm.Session` - directly. + """Execute after a transaction is begun on a connection :param session: The target :class:`.Session`. :param transaction: The :class:`.SessionTransaction`. @@ -2458,11 +2444,11 @@ class AttributeEvents(event.Events[QueryableAttribute[Any]]): from sqlalchemy import event - - @event.listens_for(MyClass.collection, "append", propagate=True) + @event.listens_for(MyClass.collection, 'append', propagate=True) def my_append_listener(target, value, initiator): print("received append event for target: %s" % target) + Listeners have the option to return a possibly modified version of the value, when the :paramref:`.AttributeEvents.retval` flag is passed to :func:`.event.listen` or :func:`.event.listens_for`, such as below, @@ -2471,12 +2457,11 @@ class AttributeEvents(event.Events[QueryableAttribute[Any]]): def validate_phone(target, value, oldvalue, initiator): "Strip non-numeric characters from a phone number" - return re.sub(r"\D", "", value) - + return re.sub(r'\D', '', value) # setup listener on UserContact.phone attribute, instructing # it to use the return value - listen(UserContact.phone, "set", validate_phone, retval=True) + listen(UserContact.phone, 'set', validate_phone, retval=True) A validation function like the above can also raise an exception such as :exc:`ValueError` to halt the operation. 
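(For reference, a minimal runnable sketch of the ``retval=True`` validator pattern shown in the docstring above, using a hypothetical ``UserContact`` mapping; the model and column names are assumptions, not taken from the patch.)

    import re

    from sqlalchemy import String, event
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    class Base(DeclarativeBase):
        pass

    class UserContact(Base):
        __tablename__ = "user_contact"
        id: Mapped[int] = mapped_column(primary_key=True)
        phone: Mapped[str] = mapped_column(String(20))

    @event.listens_for(UserContact.phone, "set", retval=True)
    def validate_phone(target, value, oldvalue, initiator):
        # with retval=True, the value returned here replaces the one
        # being assigned to the attribute
        return re.sub(r"\D", "", value)

    uc = UserContact()
    uc.phone = "(555) 123-4567"
    print(uc.phone)  # 5551234567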
@@ -2486,7 +2471,7 @@ class AttributeEvents(event.Events[QueryableAttribute[Any]]): as when using mapper inheritance patterns:: - @event.listens_for(MySuperClass.attr, "set", propagate=True) + @event.listens_for(MySuperClass.attr, 'set', propagate=True) def receive_set(target, value, initiator): print("value set: %s" % target) @@ -2719,12 +2704,10 @@ class AttributeEvents(event.Events[QueryableAttribute[Any]]): from sqlalchemy.orm.attributes import OP_BULK_REPLACE - @event.listens_for(SomeObject.collection, "bulk_replace") def process_collection(target, values, initiator): values[:] = [_make_value(value) for value in values] - @event.listens_for(SomeObject.collection, "append", retval=True) def process_collection(target, value, initiator): # make sure bulk_replace didn't already do it @@ -2872,18 +2855,16 @@ class AttributeEvents(event.Events[QueryableAttribute[Any]]): SOME_CONSTANT = 3.1415926 - class MyClass(Base): # ... some_attribute = Column(Numeric, default=SOME_CONSTANT) - @event.listens_for( - MyClass.some_attribute, "init_scalar", retval=True, propagate=True - ) + MyClass.some_attribute, "init_scalar", + retval=True, propagate=True) def _init_some_attribute(target, dict_, value): - dict_["some_attribute"] = SOME_CONSTANT + dict_['some_attribute'] = SOME_CONSTANT return SOME_CONSTANT Above, we initialize the attribute ``MyClass.some_attribute`` to the @@ -2919,10 +2900,9 @@ class AttributeEvents(event.Events[QueryableAttribute[Any]]): SOME_CONSTANT = 3.1415926 - @event.listens_for( - MyClass.some_attribute, "init_scalar", retval=True, propagate=True - ) + MyClass.some_attribute, "init_scalar", + retval=True, propagate=True) def _init_some_attribute(target, dict_, value): # will also fire off attribute set events target.some_attribute = SOME_CONSTANT @@ -2959,7 +2939,7 @@ class AttributeEvents(event.Events[QueryableAttribute[Any]]): :ref:`examples_instrumentation` - see the ``active_column_defaults.py`` example. 
- """ # noqa: E501 + """ def init_collection( self, @@ -3097,8 +3077,8 @@ class QueryEvents(event.Events[Query[Any]]): @event.listens_for(Query, "before_compile", retval=True) def no_deleted(query): for desc in query.column_descriptions: - if desc["type"] is User: - entity = desc["entity"] + if desc['type'] is User: + entity = desc['entity'] query = query.filter(entity.deleted == False) return query @@ -3114,11 +3094,12 @@ class QueryEvents(event.Events[Query[Any]]): re-establish the query being cached, apply the event adding the ``bake_ok`` flag:: - @event.listens_for(Query, "before_compile", retval=True, bake_ok=True) + @event.listens_for( + Query, "before_compile", retval=True, bake_ok=True) def my_event(query): for desc in query.column_descriptions: - if desc["type"] is User: - entity = desc["entity"] + if desc['type'] is User: + entity = desc['entity'] query = query.filter(entity.deleted == False) return query @@ -3139,7 +3120,7 @@ class QueryEvents(event.Events[Query[Any]]): :ref:`baked_with_before_compile` - """ # noqa: E501 + """ def before_compile_update( self, query: Query[Any], update_context: BulkUpdate @@ -3159,13 +3140,11 @@ class QueryEvents(event.Events[Query[Any]]): @event.listens_for(Query, "before_compile_update", retval=True) def no_deleted(query, update_context): for desc in query.column_descriptions: - if desc["type"] is User: - entity = desc["entity"] + if desc['type'] is User: + entity = desc['entity'] query = query.filter(entity.deleted == False) - update_context.values["timestamp"] = datetime.datetime.now( - datetime.UTC - ) + update_context.values['timestamp'] = datetime.utcnow() return query The ``.values`` dictionary of the "update context" object can also @@ -3193,7 +3172,7 @@ class QueryEvents(event.Events[Query[Any]]): :meth:`.QueryEvents.before_compile_delete` - """ # noqa: E501 + """ def before_compile_delete( self, query: Query[Any], delete_context: BulkDelete @@ -3212,8 +3191,8 @@ class QueryEvents(event.Events[Query[Any]]): @event.listens_for(Query, "before_compile_delete", retval=True) def no_deleted(query, delete_context): for desc in query.column_descriptions: - if desc["type"] is User: - entity = desc["entity"] + if desc['type'] is User: + entity = desc['entity'] query = query.filter(entity.deleted == False) return query diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/exc.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/exc.py index a2f7c9f..f30e503 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/exc.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/exc.py @@ -1,5 +1,5 @@ # orm/exc.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -16,7 +16,6 @@ from typing import Type from typing import TYPE_CHECKING from typing import TypeVar -from .util import _mapper_property_as_plain_name from .. import exc as sa_exc from .. import util from ..exc import MultipleResultsFound # noqa @@ -65,15 +64,6 @@ class FlushError(sa_exc.SQLAlchemyError): """A invalid condition was detected during flush().""" -class MappedAnnotationError(sa_exc.ArgumentError): - """Raised when ORM annotated declarative cannot interpret the - expression present inside of the :class:`.Mapped` construct. - - .. 
versionadded:: 2.0.40 - - """ - - class UnmappedError(sa_exc.InvalidRequestError): """Base for exceptions that involve expected mappings not present.""" @@ -201,8 +191,8 @@ class LoaderStrategyException(sa_exc.InvalidRequestError): % ( util.clsname_as_plain_name(actual_strategy_type), requesting_property, - _mapper_property_as_plain_name(applied_to_property_type), - _mapper_property_as_plain_name(applies_to), + util.clsname_as_plain_name(applied_to_property_type), + util.clsname_as_plain_name(applies_to), ), ) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/identity.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/identity.py index 1808b2d..81140a9 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/identity.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/identity.py @@ -1,5 +1,5 @@ # orm/identity.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/instrumentation.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/instrumentation.py index f87023f..b12d80a 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/instrumentation.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/instrumentation.py @@ -1,5 +1,5 @@ # orm/instrumentation.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -85,11 +85,13 @@ class _ExpiredAttributeLoaderProto(Protocol): state: state.InstanceState[Any], toload: Set[str], passive: base.PassiveFlag, - ) -> None: ... + ) -> None: + ... class _ManagerFactory(Protocol): - def __call__(self, class_: Type[_O]) -> ClassManager[_O]: ... + def __call__(self, class_: Type[_O]) -> ClassManager[_O]: + ... class ClassManager( diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/interfaces.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/interfaces.py index dfcd130..a118b2a 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/interfaces.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/interfaces.py @@ -1,5 +1,5 @@ # orm/interfaces.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -29,7 +29,6 @@ from typing import Dict from typing import Generic from typing import Iterator from typing import List -from typing import Mapping from typing import NamedTuple from typing import NoReturn from typing import Optional @@ -116,7 +115,7 @@ _TLS = TypeVar("_TLS", bound="Type[LoaderStrategy]") class ORMStatementRole(roles.StatementRole): __slots__ = () _role_name = ( - "Executable SQL or text() construct, including ORM aware objects" + "Executable SQL or text() construct, including ORM " "aware objects" ) @@ -150,17 +149,13 @@ class ORMColumnDescription(TypedDict): class _IntrospectsAnnotations: __slots__ = () - @classmethod - def _mapper_property_name(cls) -> str: - return cls.__name__ - def found_in_pep593_annotated(self) -> Any: """return a copy of this object to use in declarative when the object is found inside of an Annotated object.""" raise NotImplementedError( - f"Use of the {self._mapper_property_name()!r} " - "construct inside of an Annotated object is not yet supported." 
+ f"Use of the {self.__class__} construct inside of an " + f"Annotated object is not yet supported." ) def declarative_scan( @@ -186,8 +181,7 @@ class _IntrospectsAnnotations: raise sa_exc.ArgumentError( f"Python typing annotation is required for attribute " f'"{cls.__name__}.{key}" when primary argument(s) for ' - f'"{self._mapper_property_name()}" ' - "construct are None or not present" + f'"{self.__class__.__name__}" construct are None or not present' ) @@ -207,8 +201,6 @@ class _AttributeOptions(NamedTuple): dataclasses_default_factory: Union[_NoArg, Callable[[], Any]] dataclasses_compare: Union[_NoArg, bool] dataclasses_kw_only: Union[_NoArg, bool] - dataclasses_hash: Union[_NoArg, bool, None] - dataclasses_dataclass_metadata: Union[_NoArg, Mapping[Any, Any], None] def _as_dataclass_field(self, key: str) -> Any: """Return a ``dataclasses.Field`` object given these arguments.""" @@ -226,10 +218,6 @@ class _AttributeOptions(NamedTuple): kw["compare"] = self.dataclasses_compare if self.dataclasses_kw_only is not _NoArg.NO_ARG: kw["kw_only"] = self.dataclasses_kw_only - if self.dataclasses_hash is not _NoArg.NO_ARG: - kw["hash"] = self.dataclasses_hash - if self.dataclasses_dataclass_metadata is not _NoArg.NO_ARG: - kw["metadata"] = self.dataclasses_dataclass_metadata if "default" in kw and callable(kw["default"]): # callable defaults are ambiguous. deprecate them in favour of @@ -267,7 +255,7 @@ class _AttributeOptions(NamedTuple): key: str, annotation: _AnnotationScanType, mapped_container: Optional[Any], - elem: Any, + elem: _T, ) -> Union[ Tuple[str, _AnnotationScanType], Tuple[str, _AnnotationScanType, dataclasses.Field[Any]], @@ -309,8 +297,6 @@ _DEFAULT_ATTRIBUTE_OPTIONS = _AttributeOptions( _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, - _NoArg.NO_ARG, - _NoArg.NO_ARG, ) _DEFAULT_READONLY_ATTRIBUTE_OPTIONS = _AttributeOptions( @@ -320,8 +306,6 @@ _DEFAULT_READONLY_ATTRIBUTE_OPTIONS = _AttributeOptions( _NoArg.NO_ARG, _NoArg.NO_ARG, _NoArg.NO_ARG, - _NoArg.NO_ARG, - _NoArg.NO_ARG, ) @@ -691,37 +675,27 @@ class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators): # definition of custom PropComparator subclasses - from sqlalchemy.orm.properties import ( - ColumnProperty, - Composite, - Relationship, - ) - + from sqlalchemy.orm.properties import \ + ColumnProperty,\ + Composite,\ + Relationship class MyColumnComparator(ColumnProperty.Comparator): def __eq__(self, other): return self.__clause_element__() == other - class MyRelationshipComparator(Relationship.Comparator): def any(self, expression): "define the 'any' operation" # ... 
- class MyCompositeComparator(Composite.Comparator): def __gt__(self, other): "redefine the 'greater than' operation" - return sql.and_( - *[ - a > b - for a, b in zip( - self.__clause_element__().clauses, - other.__composite_values__(), - ) - ] - ) + return sql.and_(*[a>b for a, b in + zip(self.__clause_element__().clauses, + other.__composite_values__())]) # application of custom PropComparator subclasses @@ -729,22 +703,17 @@ class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators): from sqlalchemy.orm import column_property, relationship, composite from sqlalchemy import Column, String - class SomeMappedClass(Base): - some_column = column_property( - Column("some_column", String), - comparator_factory=MyColumnComparator, - ) + some_column = column_property(Column("some_column", String), + comparator_factory=MyColumnComparator) - some_relationship = relationship( - SomeOtherClass, comparator_factory=MyRelationshipComparator - ) + some_relationship = relationship(SomeOtherClass, + comparator_factory=MyRelationshipComparator) some_composite = composite( - Column("a", String), - Column("b", String), - comparator_factory=MyCompositeComparator, - ) + Column("a", String), Column("b", String), + comparator_factory=MyCompositeComparator + ) Note that for column-level operator redefinition, it's usually simpler to define the operators at the Core level, using the @@ -766,7 +735,6 @@ class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators): :attr:`.TypeEngine.comparator_factory` """ - __slots__ = "prop", "_parententity", "_adapt_to_entity" __visit_name__ = "orm_prop_comparator" @@ -786,7 +754,7 @@ class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators): self._adapt_to_entity = adapt_to_entity @util.non_memoized_property - def property(self) -> MapperProperty[_T_co]: + def property(self) -> MapperProperty[_T]: """Return the :class:`.MapperProperty` associated with this :class:`.PropComparator`. @@ -816,7 +784,7 @@ class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators): def adapt_to_entity( self, adapt_to_entity: AliasedInsp[Any] - ) -> PropComparator[_T_co]: + ) -> PropComparator[_T]: """Return a copy of this PropComparator which will use the given :class:`.AliasedInsp` to produce corresponding expressions. """ @@ -870,13 +838,15 @@ class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators): def operate( self, op: OperatorType, *other: Any, **kwargs: Any - ) -> ColumnElement[Any]: ... + ) -> ColumnElement[Any]: + ... def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any - ) -> ColumnElement[Any]: ... + ) -> ColumnElement[Any]: + ... - def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T_co]: + def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T]: r"""Redefine this object in terms of a polymorphic subclass, :func:`_orm.with_polymorphic` construct, or :func:`_orm.aliased` construct. @@ -886,9 +856,8 @@ class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators): e.g.:: - query.join(Company.employees.of_type(Engineer)).filter( - Engineer.name == "foo" - ) + query.join(Company.employees.of_type(Engineer)).\ + filter(Engineer.name=='foo') :param \class_: a class or mapper indicating that criterion will be against this specific subclass. 
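(A minimal sketch of ``of_type()`` in a join, assuming a hypothetical single-table ``Company``/``Employee``/``Engineer`` mapping similar to the one referenced in the docstring above; all class names here are illustrative.)

    from sqlalchemy import ForeignKey, String, select
    from sqlalchemy.orm import (
        DeclarativeBase,
        Mapped,
        mapped_column,
        relationship,
    )

    class Base(DeclarativeBase):
        pass

    class Company(Base):
        __tablename__ = "company"
        id: Mapped[int] = mapped_column(primary_key=True)
        employees: Mapped[list["Employee"]] = relationship(
            back_populates="company"
        )

    class Employee(Base):
        __tablename__ = "employee"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str] = mapped_column(String(50))
        type: Mapped[str] = mapped_column(String(50))
        company_id: Mapped[int] = mapped_column(ForeignKey("company.id"))
        company: Mapped["Company"] = relationship(back_populates="employees")
        __mapper_args__ = {
            "polymorphic_on": "type",
            "polymorphic_identity": "employee",
        }

    class Engineer(Employee):
        __mapper_args__ = {"polymorphic_identity": "engineer"}

    # of_type() narrows the relationship join target to the Engineer
    # subclass, so Engineer columns can be used in the WHERE clause
    stmt = (
        select(Company)
        .join(Company.employees.of_type(Engineer))
        .where(Engineer.name == "foo")
    )
    print(stmt)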
@@ -914,11 +883,11 @@ class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators): stmt = select(User).join( - User.addresses.and_(Address.email_address != "foo") + User.addresses.and_(Address.email_address != 'foo') ) stmt = select(User).options( - joinedload(User.addresses.and_(Address.email_address != "foo")) + joinedload(User.addresses.and_(Address.email_address != 'foo')) ) .. versionadded:: 1.4 diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/loading.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/loading.py index 41fa18f..cae6f0b 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/loading.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/loading.py @@ -1,5 +1,5 @@ # orm/loading.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -39,7 +39,6 @@ from .base import PassiveFlag from .context import FromStatement from .context import ORMCompileState from .context import QueryContext -from .strategies import SelectInLoader from .util import _none_set from .util import state_str from .. import exc as sa_exc @@ -150,11 +149,9 @@ def instances(cursor: CursorResult[Any], context: QueryContext) -> Result[Any]: raise sa_exc.InvalidRequestError( "Can't apply uniqueness to row tuple containing value of " - f"""type {datatype!r}; { - 'the values returned appear to be' - if uncertain - else 'this datatype produces' - } non-hashable values""" + f"""type {datatype!r}; {'the values returned appear to be' + if uncertain else 'this datatype produces'} """ + "non-hashable values" ) return go @@ -182,22 +179,20 @@ def instances(cursor: CursorResult[Any], context: QueryContext) -> Result[Any]: return go unique_filters = [ - ( - _no_unique - if context.yield_per - else ( - _not_hashable( - ent.column.type, # type: ignore - legacy=context.load_options._legacy_uniquing, - uncertain=ent._null_column_type, - ) - if ( - not ent.use_id_for_hash - and (ent._non_hashable_value or ent._null_column_type) - ) - else id if ent.use_id_for_hash else None - ) + _no_unique + if context.yield_per + else _not_hashable( + ent.column.type, # type: ignore + legacy=context.load_options._legacy_uniquing, + uncertain=ent._null_column_type, ) + if ( + not ent.use_id_for_hash + and (ent._non_hashable_value or ent._null_column_type) + ) + else id + if ent.use_id_for_hash + else None for ent in context.compile_state._entities ] @@ -1011,38 +1006,21 @@ def _instance_processor( # loading does not apply assert only_load_props is None - if selectin_load_via.is_mapper: - _load_supers = [] - _endmost_mapper = selectin_load_via - while ( - _endmost_mapper - and _endmost_mapper is not _polymorphic_from - ): - _load_supers.append(_endmost_mapper) - _endmost_mapper = _endmost_mapper.inherits - else: - _load_supers = [selectin_load_via] - - for _selectinload_entity in _load_supers: - if PostLoad.path_exists( - context, load_path, _selectinload_entity - ): - continue - callable_ = _load_subclass_via_in( - context, - path, - _selectinload_entity, - _polymorphic_from, - option_entities, - ) - PostLoad.callable_for_path( - context, - load_path, - _selectinload_entity.mapper, - _selectinload_entity, - callable_, - _selectinload_entity, - ) + callable_ = _load_subclass_via_in( + context, + path, + selectin_load_via, + _polymorphic_from, + option_entities, + ) + PostLoad.callable_for_path( + context, + load_path, + selectin_load_via.mapper, + 
selectin_load_via, + callable_, + selectin_load_via, + ) post_load = PostLoad.for_context(context, load_path, only_load_props) @@ -1310,18 +1288,15 @@ def _load_subclass_via_in( if context.populate_existing: q2 = q2.execution_options(populate_existing=True) - while states: - chunk = states[0 : SelectInLoader._chunksize] - states = states[SelectInLoader._chunksize :] - context.session.execute( - q2, - dict( - primary_keys=[ - state.key[1][0] if zero_idx else state.key[1] - for state, load_attrs in chunk - ] - ), - ).unique().scalars().all() + context.session.execute( + q2, + dict( + primary_keys=[ + state.key[1][0] if zero_idx else state.key[1] + for state, load_attrs in states + ] + ), + ).unique().scalars().all() return do_load diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/mapped_collection.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/mapped_collection.py index ca085c4..9e479d0 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/mapped_collection.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/mapped_collection.py @@ -1,5 +1,5 @@ -# orm/mapped_collection.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# orm/collections.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -29,8 +29,6 @@ from .. import util from ..sql import coercions from ..sql import expression from ..sql import roles -from ..util.langhelpers import Missing -from ..util.langhelpers import MissingOr from ..util.typing import Literal if TYPE_CHECKING: @@ -42,6 +40,8 @@ if TYPE_CHECKING: _KT = TypeVar("_KT", bound=Any) _VT = TypeVar("_VT", bound=Any) +_F = TypeVar("_F", bound=Callable[[Any], Any]) + class _PlainColumnGetter(Generic[_KT]): """Plain column getter, stores collection of Column objects @@ -70,7 +70,7 @@ class _PlainColumnGetter(Generic[_KT]): def _cols(self, mapper: Mapper[_KT]) -> Sequence[ColumnElement[_KT]]: return self.cols - def __call__(self, value: _KT) -> MissingOr[Union[_KT, Tuple[_KT, ...]]]: + def __call__(self, value: _KT) -> Union[_KT, Tuple[_KT, ...]]: state = base.instance_state(value) m = base._state_mapper(state) @@ -83,7 +83,7 @@ class _PlainColumnGetter(Generic[_KT]): else: obj = key[0] if obj is None: - return Missing + return _UNMAPPED_AMBIGUOUS_NONE else: return obj @@ -117,7 +117,9 @@ class _SerializableColumnGetterV2(_PlainColumnGetter[_KT]): return self.__class__, (self.colkeys,) @classmethod - def _reduce_from_cols(cls, cols: Sequence[ColumnElement[_KT]]) -> Tuple[ + def _reduce_from_cols( + cls, cols: Sequence[ColumnElement[_KT]] + ) -> Tuple[ Type[_SerializableColumnGetterV2[_KT]], Tuple[Sequence[Tuple[Optional[str], Optional[str]]]], ]: @@ -198,6 +200,9 @@ def column_keyed_dict( ) +_UNMAPPED_AMBIGUOUS_NONE = object() + + class _AttrGetter: __slots__ = ("attr_name", "getter") @@ -214,9 +219,9 @@ class _AttrGetter: dict_ = state.dict obj = dict_.get(self.attr_name, base.NO_VALUE) if obj is None: - return Missing + return _UNMAPPED_AMBIGUOUS_NONE else: - return Missing + return _UNMAPPED_AMBIGUOUS_NONE return obj @@ -226,7 +231,7 @@ class _AttrGetter: def attribute_keyed_dict( attr_name: str, *, ignore_unpopulated_attribute: bool = False -) -> Type[KeyFuncDict[Any, Any]]: +) -> Type[KeyFuncDict[_KT, _KT]]: """A dictionary-based collection type with attribute-based keying. .. 
versionchanged:: 2.0 Renamed :data:`.attribute_mapped_collection` to @@ -274,7 +279,7 @@ def attribute_keyed_dict( def keyfunc_mapping( - keyfunc: Callable[[Any], Any], + keyfunc: _F, *, ignore_unpopulated_attribute: bool = False, ) -> Type[KeyFuncDict[_KT, Any]]: @@ -350,7 +355,7 @@ class KeyFuncDict(Dict[_KT, _VT]): def __init__( self, - keyfunc: Callable[[Any], Any], + keyfunc: _F, *dict_args: Any, ignore_unpopulated_attribute: bool = False, ) -> None: @@ -374,7 +379,7 @@ class KeyFuncDict(Dict[_KT, _VT]): @classmethod def _unreduce( cls, - keyfunc: Callable[[Any], Any], + keyfunc: _F, values: Dict[_KT, _KT], adapter: Optional[CollectionAdapter] = None, ) -> "KeyFuncDict[_KT, _KT]": @@ -461,7 +466,7 @@ class KeyFuncDict(Dict[_KT, _VT]): ) else: return - elif key is Missing: + elif key is _UNMAPPED_AMBIGUOUS_NONE: if not self.ignore_unpopulated_attribute: self._raise_for_unpopulated( value, _sa_initiator, warn_only=True @@ -489,7 +494,7 @@ class KeyFuncDict(Dict[_KT, _VT]): value, _sa_initiator, warn_only=False ) return - elif key is Missing: + elif key is _UNMAPPED_AMBIGUOUS_NONE: if not self.ignore_unpopulated_attribute: self._raise_for_unpopulated( value, _sa_initiator, warn_only=True @@ -511,7 +516,7 @@ class KeyFuncDict(Dict[_KT, _VT]): def _mapped_collection_cls( - keyfunc: Callable[[Any], Any], ignore_unpopulated_attribute: bool + keyfunc: _F, ignore_unpopulated_attribute: bool ) -> Type[KeyFuncDict[_KT, _KT]]: class _MKeyfuncMapped(KeyFuncDict[_KT, _KT]): def __init__(self, *dict_args: Any) -> None: diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/mapper.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/mapper.py index ae7f8f2..c66d876 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/mapper.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/mapper.py @@ -1,5 +1,5 @@ # orm/mapper.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -132,9 +132,9 @@ _WithPolymorphicArg = Union[ ] -_mapper_registries: weakref.WeakKeyDictionary[_RegistryType, bool] = ( - weakref.WeakKeyDictionary() -) +_mapper_registries: weakref.WeakKeyDictionary[ + _RegistryType, bool +] = weakref.WeakKeyDictionary() def _all_registries() -> Set[registry]: @@ -296,17 +296,6 @@ class Mapper( particular primary key value. A "partial primary key" can occur if one has mapped to an OUTER JOIN, for example. - The :paramref:`.orm.Mapper.allow_partial_pks` parameter also - indicates to the ORM relationship lazy loader, when loading a - many-to-one related object, if a composite primary key that has - partial NULL values should result in an attempt to load from the - database, or if a load attempt is not necessary. - - .. versionadded:: 2.0.36 :paramref:`.orm.Mapper.allow_partial_pks` - is consulted by the relationship lazy loader strategy, such that - when set to False, a SELECT for a composite primary key that - has partial NULL values will not be emitted. - :param batch: Defaults to ``True``, indicating that save operations of multiple entities can be batched together for efficiency. 
Setting to False indicates @@ -329,7 +318,7 @@ class Mapper( class User(Base): __table__ = user_table - __mapper_args__ = {"column_prefix": "_"} + __mapper_args__ = {'column_prefix':'_'} The above mapping will assign the ``user_id``, ``user_name``, and ``password`` columns to attributes named ``_user_id``, @@ -453,7 +442,7 @@ class Mapper( mapping of the class to an alternate selectable, for loading only. - .. seealso:: + .. seealso:: :ref:`relationship_aliased_class` - the new pattern that removes the need for the :paramref:`_orm.Mapper.non_primary` flag. @@ -545,14 +534,14 @@ class Mapper( base-most mapped :class:`.Table`:: class Employee(Base): - __tablename__ = "employee" + __tablename__ = 'employee' id: Mapped[int] = mapped_column(primary_key=True) discriminator: Mapped[str] = mapped_column(String(50)) __mapper_args__ = { - "polymorphic_on": discriminator, - "polymorphic_identity": "employee", + "polymorphic_on":discriminator, + "polymorphic_identity":"employee" } It may also be specified @@ -561,18 +550,17 @@ class Mapper( approach:: class Employee(Base): - __tablename__ = "employee" + __tablename__ = 'employee' id: Mapped[int] = mapped_column(primary_key=True) discriminator: Mapped[str] = mapped_column(String(50)) __mapper_args__ = { - "polymorphic_on": case( + "polymorphic_on":case( (discriminator == "EN", "engineer"), (discriminator == "MA", "manager"), - else_="employee", - ), - "polymorphic_identity": "employee", + else_="employee"), + "polymorphic_identity":"employee" } It may also refer to any attribute using its string name, @@ -580,14 +568,14 @@ class Mapper( configurations:: class Employee(Base): - __tablename__ = "employee" + __tablename__ = 'employee' id: Mapped[int] = mapped_column(primary_key=True) discriminator: Mapped[str] __mapper_args__ = { "polymorphic_on": "discriminator", - "polymorphic_identity": "employee", + "polymorphic_identity": "employee" } When setting ``polymorphic_on`` to reference an @@ -604,7 +592,6 @@ class Mapper( from sqlalchemy import event from sqlalchemy.orm import object_mapper - @event.listens_for(Employee, "init", propagate=True) def set_identity(instance, *arg, **kw): mapper = object_mapper(instance) @@ -1056,7 +1043,7 @@ class Mapper( """ - primary_key: Tuple[ColumnElement[Any], ...] + primary_key: Tuple[Column[Any], ...] 
"""An iterable containing the collection of :class:`_schema.Column` objects which comprise the 'primary key' of the mapped table, from the @@ -1619,11 +1606,9 @@ class Mapper( if self._primary_key_argument: coerced_pk_arg = [ - ( - self._str_arg_to_mapped_col("primary_key", c) - if isinstance(c, str) - else c - ) + self._str_arg_to_mapped_col("primary_key", c) + if isinstance(c, str) + else c for c in ( coercions.expect( roles.DDLConstraintColumnRole, @@ -2480,11 +2465,9 @@ class Mapper( return "Mapper[%s%s(%s)]" % ( self.class_.__name__, self.non_primary and " (non-primary)" or "", - ( - self.local_table.description - if self.local_table is not None - else self.persist_selectable.description - ), + self.local_table.description + if self.local_table is not None + else self.persist_selectable.description, ) def _is_orphan(self, state: InstanceState[_O]) -> bool: @@ -2554,7 +2537,7 @@ class Mapper( if spec == "*": mappers = list(self.self_and_descendants) elif spec: - mapper_set: Set[Mapper[Any]] = set() + mapper_set = set() for m in util.to_list(spec): m = _class_to_mapper(m) if not m.isa(self): @@ -3261,9 +3244,14 @@ class Mapper( The resulting structure is a dictionary of columns mapped to lists of equivalent columns, e.g.:: - {tablea.col1: {tableb.col1, tablec.col1}, tablea.col2: {tabled.col2}} + { + tablea.col1: + {tableb.col1, tablec.col1}, + tablea.col2: + {tabled.col2} + } - """ # noqa: E501 + """ result: _EquivalentColumnMap = {} def visit_binary(binary): @@ -3428,11 +3416,9 @@ class Mapper( return self.class_manager.mapper.base_mapper def _result_has_identity_key(self, result, adapter=None): - pk_cols: Sequence[ColumnElement[Any]] - if adapter is not None: - pk_cols = [adapter.columns[c] for c in self.primary_key] - else: - pk_cols = self.primary_key + pk_cols: Sequence[ColumnClause[Any]] = self.primary_key + if adapter: + pk_cols = [adapter.columns[c] for c in pk_cols] rk = result.keys() for col in pk_cols: if col not in rk: @@ -3442,7 +3428,7 @@ class Mapper( def identity_key_from_row( self, - row: Union[Row[Any], RowMapping], + row: Optional[Union[Row[Any], RowMapping]], identity_token: Optional[Any] = None, adapter: Optional[ORMAdapter] = None, ) -> _IdentityKeyType[_O]: @@ -3457,21 +3443,18 @@ class Mapper( for the "row" argument """ - pk_cols: Sequence[ColumnElement[Any]] - if adapter is not None: - pk_cols = [adapter.columns[c] for c in self.primary_key] - else: - pk_cols = self.primary_key + pk_cols: Sequence[ColumnClause[Any]] = self.primary_key + if adapter: + pk_cols = [adapter.columns[c] for c in pk_cols] - mapping: RowMapping if hasattr(row, "_mapping"): - mapping = row._mapping + mapping = row._mapping # type: ignore else: - mapping = row # type: ignore[assignment] + mapping = cast("Mapping[Any, Any]", row) return ( self._identity_class, - tuple(mapping[column] for column in pk_cols), + tuple(mapping[column] for column in pk_cols), # type: ignore identity_token, ) @@ -3741,15 +3724,14 @@ class Mapper( given:: - class A: ... - + class A: + ... class B(A): __mapper_args__ = {"polymorphic_load": "selectin"} - - class C(B): ... - + class C(B): + ... class D(B): __mapper_args__ = {"polymorphic_load": "selectin"} @@ -3819,7 +3801,6 @@ class Mapper( this subclass as a SELECT with IN. 
""" - strategy_options = util.preloaded.orm_strategy_options assert self.inherits @@ -3843,7 +3824,7 @@ class Mapper( classes_to_include.add(m) m = m.inherits - for prop in self.column_attrs + self.relationships: + for prop in self.attrs: # skip prop keys that are not instrumented on the mapped class. # this is primarily the "_sa_polymorphic_on" property that gets # created for an ad-hoc polymorphic_on SQL expression, issue #8704 @@ -4308,7 +4289,7 @@ def _dispose_registries(registries: Set[_RegistryType], cascade: bool) -> None: reg._new_mappers = False -def reconstructor(fn: _Fn) -> _Fn: +def reconstructor(fn): """Decorate a method as the 'reconstructor' hook. Designates a single method as the "reconstructor", an ``__init__``-like @@ -4334,7 +4315,7 @@ def reconstructor(fn: _Fn) -> _Fn: :meth:`.InstanceEvents.load` """ - fn.__sa_reconstructor__ = True # type: ignore[attr-defined] + fn.__sa_reconstructor__ = True return fn diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/path_registry.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/path_registry.py index bb03e53..354552a 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/path_registry.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/path_registry.py @@ -1,10 +1,12 @@ # orm/path_registry.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Path tracking utilities, representing mapper graph traversals.""" +"""Path tracking utilities, representing mapper graph traversals. + +""" from __future__ import annotations @@ -33,7 +35,7 @@ from ..sql.cache_key import HasCacheKey if TYPE_CHECKING: from ._typing import _InternalEntityType - from .interfaces import StrategizedProperty + from .interfaces import MapperProperty from .mapper import Mapper from .relationships import RelationshipProperty from .util import AliasedInsp @@ -43,9 +45,11 @@ if TYPE_CHECKING: from ..util.typing import _LiteralStar from ..util.typing import TypeGuard - def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]: ... + def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]: + ... - def is_entity(path: PathRegistry) -> TypeGuard[AbstractEntityRegistry]: ... + def is_entity(path: PathRegistry) -> TypeGuard[AbstractEntityRegistry]: + ... else: is_root = operator.attrgetter("is_root") @@ -55,13 +59,13 @@ else: _SerializedPath = List[Any] _StrPathToken = str _PathElementType = Union[ - _StrPathToken, "_InternalEntityType[Any]", "StrategizedProperty[Any]" + _StrPathToken, "_InternalEntityType[Any]", "MapperProperty[Any]" ] # the representation is in fact # a tuple with alternating: -# [_InternalEntityType[Any], Union[str, StrategizedProperty[Any]], -# _InternalEntityType[Any], Union[str, StrategizedProperty[Any]], ...] +# [_InternalEntityType[Any], Union[str, MapperProperty[Any]], +# _InternalEntityType[Any], Union[str, MapperProperty[Any]], ...] # this might someday be a tuple of 2-tuples instead, but paths can be # chopped at odd intervals as well so this is less flexible _PathRepresentation = Tuple[_PathElementType, ...] @@ -69,7 +73,7 @@ _PathRepresentation = Tuple[_PathElementType, ...] 
# NOTE: these names are weird since the array is 0-indexed, # the "_Odd" entries are at 0, 2, 4, etc _OddPathRepresentation = Sequence["_InternalEntityType[Any]"] -_EvenPathRepresentation = Sequence[Union["StrategizedProperty[Any]", str]] +_EvenPathRepresentation = Sequence[Union["MapperProperty[Any]", str]] log = logging.getLogger(__name__) @@ -181,23 +185,26 @@ class PathRegistry(HasCacheKey): return id(self) @overload - def __getitem__(self, entity: _StrPathToken) -> TokenRegistry: ... + def __getitem__(self, entity: _StrPathToken) -> TokenRegistry: + ... @overload - def __getitem__(self, entity: int) -> _PathElementType: ... + def __getitem__(self, entity: int) -> _PathElementType: + ... @overload - def __getitem__(self, entity: slice) -> _PathRepresentation: ... + def __getitem__(self, entity: slice) -> _PathRepresentation: + ... @overload def __getitem__( self, entity: _InternalEntityType[Any] - ) -> AbstractEntityRegistry: ... + ) -> AbstractEntityRegistry: + ... @overload - def __getitem__( - self, entity: StrategizedProperty[Any] - ) -> PropRegistry: ... + def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry: + ... def __getitem__( self, @@ -206,7 +213,7 @@ class PathRegistry(HasCacheKey): int, slice, _InternalEntityType[Any], - StrategizedProperty[Any], + MapperProperty[Any], ], ) -> Union[ TokenRegistry, @@ -225,7 +232,7 @@ class PathRegistry(HasCacheKey): def pairs( self, ) -> Iterator[ - Tuple[_InternalEntityType[Any], Union[str, StrategizedProperty[Any]]] + Tuple[_InternalEntityType[Any], Union[str, MapperProperty[Any]]] ]: odd_path = cast(_OddPathRepresentation, self.path) even_path = cast(_EvenPathRepresentation, odd_path) @@ -313,11 +320,13 @@ class PathRegistry(HasCacheKey): @overload @classmethod - def per_mapper(cls, mapper: Mapper[Any]) -> CachingEntityRegistry: ... + def per_mapper(cls, mapper: Mapper[Any]) -> CachingEntityRegistry: + ... @overload @classmethod - def per_mapper(cls, mapper: AliasedInsp[Any]) -> SlotsEntityRegistry: ... + def per_mapper(cls, mapper: AliasedInsp[Any]) -> SlotsEntityRegistry: + ... @classmethod def per_mapper( @@ -531,16 +540,15 @@ class PropRegistry(PathRegistry): inherit_cache = True is_property = True - prop: StrategizedProperty[Any] + prop: MapperProperty[Any] mapper: Optional[Mapper[Any]] entity: Optional[_InternalEntityType[Any]] def __init__( - self, parent: AbstractEntityRegistry, prop: StrategizedProperty[Any] + self, parent: AbstractEntityRegistry, prop: MapperProperty[Any] ): - # restate this path in terms of the - # given StrategizedProperty's parent. + # given MapperProperty's parent. insp = cast("_InternalEntityType[Any]", parent[-1]) natural_parent: AbstractEntityRegistry = parent @@ -564,7 +572,7 @@ class PropRegistry(PathRegistry): # entities are used. # # here we are trying to distinguish between a path that starts - # on a with_polymorphic entity vs. one that starts on a + # on a the with_polymorhpic entity vs. one that starts on a # normal entity that introduces a with_polymorphic() in the # middle using of_type(): # @@ -800,9 +808,11 @@ if TYPE_CHECKING: def path_is_entity( path: PathRegistry, - ) -> TypeGuard[AbstractEntityRegistry]: ... + ) -> TypeGuard[AbstractEntityRegistry]: + ... - def path_is_property(path: PathRegistry) -> TypeGuard[PropRegistry]: ... + def path_is_property(path: PathRegistry) -> TypeGuard[PropRegistry]: + ... 
else: path_is_entity = operator.attrgetter("is_entity") diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/persistence.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/persistence.py index 0b48d8e..6729b47 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/persistence.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/persistence.py @@ -1,5 +1,5 @@ # orm/persistence.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -140,13 +140,11 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols): state_dict, sub_mapper, connection, - ( - mapper._get_committed_state_attr_by_column( - state, state_dict, mapper.version_id_col - ) - if mapper.version_id_col is not None - else None - ), + mapper._get_committed_state_attr_by_column( + state, state_dict, mapper.version_id_col + ) + if mapper.version_id_col is not None + else None, ) for state, state_dict, sub_mapper, connection in states_to_update if table in sub_mapper._pks_by_table @@ -561,8 +559,7 @@ def _collect_update_commands( f"No primary key value supplied for column(s) " f"""{ ', '.join( - str(c) for c in pks if pk_params[c._label] is None - ) + str(c) for c in pks if pk_params[c._label] is None) }; """ "per-row ORM Bulk UPDATE by Primary Key requires that " "records contain primary key values", @@ -705,10 +702,10 @@ def _collect_delete_commands( params = {} for col in mapper._pks_by_table[table]: - params[col.key] = value = ( - mapper._get_committed_state_attr_by_column( - state, state_dict, col - ) + params[ + col.key + ] = value = mapper._get_committed_state_attr_by_column( + state, state_dict, col ) if value is None: raise orm_exc.FlushError( @@ -936,11 +933,9 @@ def _emit_update_statements( c.context.compiled_parameters[0], value_params, True, - ( - c.returned_defaults - if not c.context.executemany - else None - ), + c.returned_defaults + if not c.context.executemany + else None, ) if check_rowcount: @@ -1073,11 +1068,9 @@ def _emit_insert_statements( last_inserted_params, value_params, False, - ( - result.returned_defaults - if not result.context.executemany - else None - ), + result.returned_defaults + if not result.context.executemany + else None, ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) @@ -1267,11 +1260,9 @@ def _emit_insert_statements( result.context.compiled_parameters[0], value_params, False, - ( - result.returned_defaults - if not result.context.executemany - else None - ), + result.returned_defaults + if not result.context.executemany + else None, ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) @@ -1374,13 +1365,7 @@ def _emit_post_update_statements( ) rows += c.rowcount - for i, ( - state, - state_dict, - mapper_rec, - connection, - params, - ) in enumerate(records): + for state, state_dict, mapper_rec, connection, params in records: _postfetch_post_update( mapper_rec, uowtransaction, @@ -1388,7 +1373,7 @@ def _emit_post_update_statements( state, state_dict, c, - c.context.compiled_parameters[i], + c.context.compiled_parameters[0], ) if check_rowcount: @@ -1584,25 +1569,16 @@ def _finalize_insert_update_commands(base_mapper, uowtransaction, states): def _postfetch_post_update( mapper, uowtransaction, table, state, dict_, result, params ): - needs_version_id = ( - mapper.version_id_col is not None - and mapper.version_id_col in mapper._cols_by_table[table] - ) - - if not 
uowtransaction.is_deleted(state): - # post updating after a regular INSERT or UPDATE, do a full postfetch - prefetch_cols = result.context.compiled.prefetch - postfetch_cols = result.context.compiled.postfetch - elif needs_version_id: - # post updating before a DELETE with a version_id_col, need to - # postfetch just version_id_col - prefetch_cols = postfetch_cols = () - else: - # post updating before a DELETE without a version_id_col, - # don't need to postfetch + if uowtransaction.is_deleted(state): return - if needs_version_id: + prefetch_cols = result.context.compiled.prefetch + postfetch_cols = result.context.compiled.postfetch + + if ( + mapper.version_id_col is not None + and mapper.version_id_col in mapper._cols_by_table[table] + ): prefetch_cols = list(prefetch_cols) + [mapper.version_id_col] refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush) @@ -1682,18 +1658,9 @@ def _postfetch( for c in prefetch_cols: if c.key in params and c in mapper._columntoproperty: - pkey = mapper._columntoproperty[c].key - - # set prefetched value in dict and also pop from committed_state, - # since this is new database state that replaces whatever might - # have previously been fetched (see #10800). this is essentially a - # shorthand version of set_committed_value(), which could also be - # used here directly (with more overhead) - dict_[pkey] = params[c.key] - state.committed_state.pop(pkey, None) - + dict_[mapper._columntoproperty[c].key] = params[c.key] if refresh_flush: - load_evt_attrs.append(pkey) + load_evt_attrs.append(mapper._columntoproperty[c].key) if refresh_flush and load_evt_attrs: mapper.class_manager.dispatch.refresh_flush( diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/properties.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/properties.py index 88540be..4bb396e 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/properties.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/properties.py @@ -1,5 +1,5 @@ # orm/properties.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -28,7 +28,6 @@ from typing import TypeVar from typing import Union from . import attributes -from . import exc as orm_exc from . import strategy_options from .base import _DeclarativeMapped from .base import class_mapper @@ -44,6 +43,7 @@ from .interfaces import PropComparator from .interfaces import StrategizedProperty from .relationships import RelationshipProperty from .util import de_stringify_annotation +from .util import de_stringify_union_elements from .. import exc as sa_exc from .. import ForeignKey from .. 
import log @@ -55,13 +55,12 @@ from ..sql.schema import Column from ..sql.schema import SchemaConst from ..sql.type_api import TypeEngine from ..util.typing import de_optionalize_union_types -from ..util.typing import get_args -from ..util.typing import includes_none -from ..util.typing import is_a_type from ..util.typing import is_fwd_ref +from ..util.typing import is_optional_union from ..util.typing import is_pep593 -from ..util.typing import is_pep695 +from ..util.typing import is_union from ..util.typing import Self +from ..util.typing import typing_get_args if TYPE_CHECKING: from ._typing import _IdentityKeyType @@ -234,7 +233,7 @@ class ColumnProperty( return self.strategy._have_default_expression # type: ignore return ("deferred", True) not in self.strategy_key or ( - self not in self.parent._readonly_props + self not in self.parent._readonly_props # type: ignore ) @util.preload_module("sqlalchemy.orm.state", "sqlalchemy.orm.strategies") @@ -280,8 +279,8 @@ class ColumnProperty( name = Column(String(64)) extension = Column(String(8)) - filename = column_property(name + "." + extension) - path = column_property("C:/" + filename.expression) + filename = column_property(name + '.' + extension) + path = column_property('C:/' + filename.expression) .. seealso:: @@ -430,7 +429,8 @@ class ColumnProperty( if TYPE_CHECKING: - def __clause_element__(self) -> NamedColumn[_PT]: ... + def __clause_element__(self) -> NamedColumn[_PT]: + ... def _memoized_method___clause_element__( self, @@ -636,11 +636,9 @@ class MappedColumn( return [ ( self.column, - ( - self._sort_order - if self._sort_order is not _NoArg.NO_ARG - else 0 - ), + self._sort_order + if self._sort_order is not _NoArg.NO_ARG + else 0, ) ] @@ -663,31 +661,6 @@ class MappedColumn( # Column will be merged into it in _init_column_for_annotation(). 
return MappedColumn() - def _adjust_for_existing_column( - self, - decl_scan: _ClassScanMapperConfig, - key: str, - given_column: Column[_T], - ) -> Column[_T]: - if ( - self._use_existing_column - and decl_scan.inherits - and decl_scan.single - ): - if decl_scan.is_deferred: - raise sa_exc.ArgumentError( - "Can't use use_existing_column with deferred mappers" - ) - supercls_mapper = class_mapper(decl_scan.inherits, False) - - colname = ( - given_column.name if given_column.name is not None else key - ) - given_column = supercls_mapper.local_table.c.get( # type: ignore[assignment] # noqa: E501 - colname, given_column - ) - return given_column - def declarative_scan( self, decl_scan: _ClassScanMapperConfig, @@ -702,9 +675,21 @@ class MappedColumn( ) -> None: column = self.column - column = self.column = self._adjust_for_existing_column( - decl_scan, key, self.column - ) + if ( + self._use_existing_column + and decl_scan.inherits + and decl_scan.single + ): + if decl_scan.is_deferred: + raise sa_exc.ArgumentError( + "Can't use use_existing_column with deferred mappers" + ) + supercls_mapper = class_mapper(decl_scan.inherits, False) + + colname = column.name if column.name is not None else key + column = self.column = supercls_mapper.local_table.c.get( # type: ignore # noqa: E501 + colname, column + ) if column.key is None: column.key = key @@ -721,8 +706,6 @@ class MappedColumn( self._init_column_for_annotation( cls, - decl_scan, - key, registry, extracted_mapped_annotation, originating_module, @@ -731,7 +714,6 @@ class MappedColumn( @util.preload_module("sqlalchemy.orm.decl_base") def declarative_scan_for_composite( self, - decl_scan: _ClassScanMapperConfig, registry: _RegistryType, cls: Type[Any], originating_module: Optional[str], @@ -742,65 +724,61 @@ class MappedColumn( decl_base = util.preloaded.orm_decl_base decl_base._undefer_column_name(param_name, self.column) self._init_column_for_annotation( - cls, decl_scan, key, registry, param_annotation, originating_module + cls, registry, param_annotation, originating_module ) def _init_column_for_annotation( self, cls: Type[Any], - decl_scan: _ClassScanMapperConfig, - key: str, registry: _RegistryType, argument: _AnnotationScanType, originating_module: Optional[str], ) -> None: sqltype = self.column.type - if is_fwd_ref( - argument, check_generic=True, check_for_plain_string=True + if isinstance(argument, str) or is_fwd_ref( + argument, check_generic=True ): assert originating_module is not None argument = de_stringify_annotation( cls, argument, originating_module, include_generic=True ) - nullable = includes_none(argument) + if is_union(argument): + assert originating_module is not None + argument = de_stringify_union_elements( + cls, argument, originating_module + ) + + nullable = is_optional_union(argument) if not self._has_nullable: self.column.nullable = nullable our_type = de_optionalize_union_types(argument) - find_mapped_in: Tuple[Any, ...] 
= () - our_type_is_pep593 = False - raw_pep_593_type = None + use_args_from = None if is_pep593(our_type): our_type_is_pep593 = True - pep_593_components = get_args(our_type) + pep_593_components = typing_get_args(our_type) raw_pep_593_type = pep_593_components[0] - if nullable: + if is_optional_union(raw_pep_593_type): raw_pep_593_type = de_optionalize_union_types(raw_pep_593_type) - find_mapped_in = pep_593_components[1:] - elif is_pep695(argument) and is_pep593(argument.__value__): - # do not support nested annotation inside unions ets - find_mapped_in = get_args(argument.__value__)[1:] - use_args_from: Optional[MappedColumn[Any]] - for elem in find_mapped_in: - if isinstance(elem, MappedColumn): - use_args_from = elem - break + nullable = True + if not self._has_nullable: + self.column.nullable = nullable + for elem in pep_593_components[1:]: + if isinstance(elem, MappedColumn): + use_args_from = elem + break else: - use_args_from = None + our_type_is_pep593 = False + raw_pep_593_type = None if use_args_from is not None: - - self.column = use_args_from._adjust_for_existing_column( - decl_scan, key, self.column - ) - if ( not self._has_insert_default and use_args_from.column.default is not None @@ -870,7 +848,8 @@ class MappedColumn( ) if sqltype._isnull and not self.column.foreign_keys: - checks: List[Any] + new_sqltype = None + if our_type_is_pep593: checks = [our_type, raw_pep_593_type] else: @@ -885,23 +864,16 @@ class MappedColumn( isinstance(our_type, type) and issubclass(our_type, TypeEngine) ): - raise orm_exc.MappedAnnotationError( + raise sa_exc.ArgumentError( f"The type provided inside the {self.column.key!r} " "attribute Mapped annotation is the SQLAlchemy type " f"{our_type}. Expected a Python type instead" ) - elif is_a_type(our_type): - raise orm_exc.MappedAnnotationError( + else: + raise sa_exc.ArgumentError( "Could not locate SQLAlchemy Core type for Python " f"type {our_type} inside the {self.column.key!r} " "attribute Mapped annotation" ) - else: - raise orm_exc.MappedAnnotationError( - f"The object provided inside the {self.column.key!r} " - "attribute Mapped annotation is not a Python type, " - f"it's the object {our_type!r}. Expected a Python " - "type." - ) self.column._set_type(new_sqltype) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/query.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/query.py index 3489c15..5da7ee9 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/query.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/query.py @@ -1,5 +1,5 @@ # orm/query.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -134,7 +134,6 @@ if TYPE_CHECKING: from ..sql._typing import _TypedColumnClauseArgument as _TCCA from ..sql.base import CacheableOptions from ..sql.base import ExecutableOption - from ..sql.dml import UpdateBase from ..sql.elements import ColumnElement from ..sql.elements import Label from ..sql.selectable import _ForUpdateOfArgument @@ -167,6 +166,7 @@ class Query( Executable, Generic[_T], ): + """ORM-level SQL construction object. .. 
legacy:: The ORM :class:`.Query` object is a legacy construct @@ -205,9 +205,9 @@ class Query( _memoized_select_entities = () - _compile_options: Union[Type[CacheableOptions], CacheableOptions] = ( - ORMCompileState.default_compile_options - ) + _compile_options: Union[ + Type[CacheableOptions], CacheableOptions + ] = ORMCompileState.default_compile_options _with_options: Tuple[ExecutableOption, ...] load_options = QueryContext.default_load_options + { @@ -493,7 +493,7 @@ class Query( return cast("Select[_T]", self.statement) @property - def statement(self) -> Union[Select[_T], FromStatement[_T], UpdateBase]: + def statement(self) -> Union[Select[_T], FromStatement[_T]]: """The full SELECT statement represented by this Query. The statement by default will not have disambiguating labels @@ -521,8 +521,6 @@ class Query( # from there, it starts to look much like Query itself won't be # passed into the execute process and won't generate its own cache # key; this will all occur in terms of the ORM-enabled Select. - stmt: Union[Select[_T], FromStatement[_T], UpdateBase] - if not self._compile_options._set_base_alias: # if we don't have legacy top level aliasing features in use # then convert to a future select() directly @@ -675,38 +673,41 @@ class Query( from sqlalchemy.orm import aliased - class Part(Base): - __tablename__ = "part" + __tablename__ = 'part' part = Column(String, primary_key=True) sub_part = Column(String, primary_key=True) quantity = Column(Integer) - - included_parts = ( - session.query(Part.sub_part, Part.part, Part.quantity) - .filter(Part.part == "our part") - .cte(name="included_parts", recursive=True) - ) + included_parts = session.query( + Part.sub_part, + Part.part, + Part.quantity).\ + filter(Part.part=="our part").\ + cte(name="included_parts", recursive=True) incl_alias = aliased(included_parts, name="pr") parts_alias = aliased(Part, name="p") included_parts = included_parts.union_all( session.query( - parts_alias.sub_part, parts_alias.part, parts_alias.quantity - ).filter(parts_alias.part == incl_alias.c.sub_part) - ) + parts_alias.sub_part, + parts_alias.part, + parts_alias.quantity).\ + filter(parts_alias.part==incl_alias.c.sub_part) + ) q = session.query( - included_parts.c.sub_part, - func.sum(included_parts.c.quantity).label("total_quantity"), - ).group_by(included_parts.c.sub_part) + included_parts.c.sub_part, + func.sum(included_parts.c.quantity). + label('total_quantity') + ).\ + group_by(included_parts.c.sub_part) .. seealso:: :meth:`_sql.Select.cte` - v2 equivalent method. - """ # noqa: E501 + """ return ( self.enable_eagerloads(False) ._get_select_statement_only() @@ -731,17 +732,20 @@ class Query( ) @overload - def as_scalar( # type: ignore[overload-overlap] + def as_scalar( self: Query[Tuple[_MAYBE_ENTITY]], - ) -> ScalarSelect[_MAYBE_ENTITY]: ... + ) -> ScalarSelect[_MAYBE_ENTITY]: + ... @overload def as_scalar( self: Query[Tuple[_NOT_ENTITY]], - ) -> ScalarSelect[_NOT_ENTITY]: ... + ) -> ScalarSelect[_NOT_ENTITY]: + ... @overload - def as_scalar(self) -> ScalarSelect[Any]: ... + def as_scalar(self) -> ScalarSelect[Any]: + ... @util.deprecated( "1.4", @@ -759,15 +763,18 @@ class Query( @overload def scalar_subquery( self: Query[Tuple[_MAYBE_ENTITY]], - ) -> ScalarSelect[Any]: ... + ) -> ScalarSelect[Any]: + ... @overload def scalar_subquery( self: Query[Tuple[_NOT_ENTITY]], - ) -> ScalarSelect[_NOT_ENTITY]: ... + ) -> ScalarSelect[_NOT_ENTITY]: + ... @overload - def scalar_subquery(self) -> ScalarSelect[Any]: ... 
+ def scalar_subquery(self) -> ScalarSelect[Any]: + ... def scalar_subquery(self) -> ScalarSelect[Any]: """Return the full SELECT statement represented by this @@ -792,7 +799,7 @@ class Query( ) @property - def selectable(self) -> Union[Select[_T], FromStatement[_T], UpdateBase]: + def selectable(self) -> Union[Select[_T], FromStatement[_T]]: """Return the :class:`_expression.Select` object emitted by this :class:`_query.Query`. @@ -803,9 +810,7 @@ class Query( """ return self.__clause_element__() - def __clause_element__( - self, - ) -> Union[Select[_T], FromStatement[_T], UpdateBase]: + def __clause_element__(self) -> Union[Select[_T], FromStatement[_T]]: return ( self._with_compile_options( _enable_eagerloads=False, _render_for_subquery=True @@ -817,12 +822,14 @@ class Query( @overload def only_return_tuples( self: Query[_O], value: Literal[True] - ) -> RowReturningQuery[Tuple[_O]]: ... + ) -> RowReturningQuery[Tuple[_O]]: + ... @overload def only_return_tuples( self: Query[_O], value: Literal[False] - ) -> Query[_O]: ... + ) -> Query[_O]: + ... @_generative def only_return_tuples(self, value: bool) -> Query[Any]: @@ -943,7 +950,9 @@ class Query( :attr:`_query.Query.statement` using :meth:`.Session.execute`:: result = session.execute( - query.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL).statement + query + .set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) + .statement ) .. versionadded:: 1.4 @@ -1052,7 +1061,8 @@ class Query( some_object = session.query(VersionedFoo).get((5, 10)) - some_object = session.query(VersionedFoo).get({"id": 5, "version_id": 10}) + some_object = session.query(VersionedFoo).get( + {"id": 5, "version_id": 10}) :meth:`_query.Query.get` is special in that it provides direct access to the identity map of the owning :class:`.Session`. @@ -1118,7 +1128,7 @@ class Query( :return: The object instance, or ``None``. - """ # noqa: E501 + """ self._no_criterion_assertion("get", order_by=False, distinct=False) # we still implement _get_impl() so that baked query can override @@ -1465,13 +1475,15 @@ class Query( return None @overload - def with_entities(self, _entity: _EntityType[_O]) -> Query[_O]: ... + def with_entities(self, _entity: _EntityType[_O]) -> Query[_O]: + ... @overload def with_entities( self, _colexpr: roles.TypedColumnsClauseRole[_T], - ) -> RowReturningQuery[Tuple[_T]]: ... + ) -> RowReturningQuery[Tuple[_T]]: + ... # START OVERLOADED FUNCTIONS self.with_entities RowReturningQuery 2-8 @@ -1481,12 +1493,14 @@ class Query( @overload def with_entities( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] - ) -> RowReturningQuery[Tuple[_T0, _T1]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1]]: + ... @overload def with_entities( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: + ... @overload def with_entities( @@ -1495,7 +1509,8 @@ class Query( __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: + ... @overload def with_entities( @@ -1505,7 +1520,8 @@ class Query( __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... @overload def with_entities( @@ -1516,7 +1532,8 @@ class Query( __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... 
+ ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... @overload def with_entities( @@ -1528,7 +1545,8 @@ class Query( __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... @overload def with_entities( @@ -1541,14 +1559,16 @@ class Query( __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... # END OVERLOADED FUNCTIONS self.with_entities @overload def with_entities( self, *entities: _ColumnsClauseArgument[Any] - ) -> Query[Any]: ... + ) -> Query[Any]: + ... @_generative def with_entities( @@ -1562,22 +1582,19 @@ class Query( # Users, filtered on some arbitrary criterion # and then ordered by related email address - q = ( - session.query(User) - .join(User.address) - .filter(User.name.like("%ed%")) - .order_by(Address.email) - ) + q = session.query(User).\ + join(User.address).\ + filter(User.name.like('%ed%')).\ + order_by(Address.email) # given *only* User.id==5, Address.email, and 'q', what # would the *next* User in the result be ? - subq = ( - q.with_entities(Address.email) - .order_by(None) - .filter(User.id == 5) - .subquery() - ) - q = q.join((subq, subq.c.email < Address.email)).limit(1) + subq = q.with_entities(Address.email).\ + order_by(None).\ + filter(User.id==5).\ + subquery() + q = q.join((subq, subq.c.email < Address.email)).\ + limit(1) .. seealso:: @@ -1673,11 +1690,9 @@ class Query( def filter_something(criterion): def transform(q): return q.filter(criterion) - return transform - - q = q.with_transformation(filter_something(x == 5)) + q = q.with_transformation(filter_something(x==5)) This allows ad-hoc recipes to be created for :class:`_query.Query` objects. @@ -1714,12 +1729,13 @@ class Query( schema_translate_map: Optional[SchemaTranslateMapType] = ..., populate_existing: bool = False, autoflush: bool = False, - preserve_rowcount: bool = False, **opt: Any, - ) -> Self: ... + ) -> Self: + ... @overload - def execution_options(self, **opt: Any) -> Self: ... + def execution_options(self, **opt: Any) -> Self: + ... @_generative def execution_options(self, **kwargs: Any) -> Self: @@ -1794,15 +1810,9 @@ class Query( E.g.:: - q = ( - sess.query(User) - .populate_existing() - .with_for_update(nowait=True, of=User) - ) + q = sess.query(User).populate_existing().with_for_update(nowait=True, of=User) - The above query on a PostgreSQL backend will render like: - - .. sourcecode:: sql + The above query on a PostgreSQL backend will render like:: SELECT users.id AS users_id FROM users FOR UPDATE OF users NOWAIT @@ -1880,13 +1890,14 @@ class Query( e.g.:: - session.query(MyClass).filter(MyClass.name == "some name") + session.query(MyClass).filter(MyClass.name == 'some name') Multiple criteria may be specified as comma separated; the effect is that they will be joined together using the :func:`.and_` function:: - session.query(MyClass).filter(MyClass.name == "some name", MyClass.id > 5) + session.query(MyClass).\ + filter(MyClass.name == 'some name', MyClass.id > 5) The criterion is any SQL expression object applicable to the WHERE clause of a select. String expressions are coerced @@ -1899,7 +1910,7 @@ class Query( :meth:`_sql.Select.where` - v2 equivalent method. 
- """ # noqa: E501 + """ for crit in list(criterion): crit = coercions.expect( roles.WhereHavingRole, crit, apply_propagate_attrs=self @@ -1967,13 +1978,14 @@ class Query( e.g.:: - session.query(MyClass).filter_by(name="some name") + session.query(MyClass).filter_by(name = 'some name') Multiple criteria may be specified as comma separated; the effect is that they will be joined together using the :func:`.and_` function:: - session.query(MyClass).filter_by(name="some name", id=5) + session.query(MyClass).\ + filter_by(name = 'some name', id = 5) The keyword expressions are extracted from the primary entity of the query, or the last entity that was the @@ -2100,12 +2112,10 @@ class Query( HAVING criterion makes it possible to use filters on aggregate functions like COUNT, SUM, AVG, MAX, and MIN, eg.:: - q = ( - session.query(User.id) - .join(User.addresses) - .group_by(User.id) - .having(func.count(Address.id) > 2) - ) + q = session.query(User.id).\ + join(User.addresses).\ + group_by(User.id).\ + having(func.count(Address.id) > 2) .. seealso:: @@ -2129,8 +2139,8 @@ class Query( e.g.:: - q1 = sess.query(SomeClass).filter(SomeClass.foo == "bar") - q2 = sess.query(SomeClass).filter(SomeClass.bar == "foo") + q1 = sess.query(SomeClass).filter(SomeClass.foo=='bar') + q2 = sess.query(SomeClass).filter(SomeClass.bar=='foo') q3 = q1.union(q2) @@ -2139,9 +2149,7 @@ class Query( x.union(y).union(z).all() - will nest on each ``union()``, and produces: - - .. sourcecode:: sql + will nest on each ``union()``, and produces:: SELECT * FROM (SELECT * FROM (SELECT * FROM X UNION SELECT * FROM y) UNION SELECT * FROM Z) @@ -2150,9 +2158,7 @@ class Query( x.union(y, z).all() - produces: - - .. sourcecode:: sql + produces:: SELECT * FROM (SELECT * FROM X UNION SELECT * FROM y UNION SELECT * FROM Z) @@ -2264,9 +2270,7 @@ class Query( q = session.query(User).join(User.addresses) Where above, the call to :meth:`_query.Query.join` along - ``User.addresses`` will result in SQL approximately equivalent to: - - .. sourcecode:: sql + ``User.addresses`` will result in SQL approximately equivalent to:: SELECT user.id, user.name FROM user JOIN address ON user.id = address.user_id @@ -2279,12 +2283,10 @@ class Query( calls may be used. The relationship-bound attribute implies both the left and right side of the join at once:: - q = ( - session.query(User) - .join(User.orders) - .join(Order.items) - .join(Item.keywords) - ) + q = session.query(User).\ + join(User.orders).\ + join(Order.items).\ + join(Item.keywords) .. note:: as seen in the above example, **the order in which each call to the join() method occurs is important**. Query would not, @@ -2323,7 +2325,7 @@ class Query( as the ON clause to be passed explicitly. 
A example that includes a SQL expression as the ON clause is as follows:: - q = session.query(User).join(Address, User.id == Address.user_id) + q = session.query(User).join(Address, User.id==Address.user_id) The above form may also use a relationship-bound attribute as the ON clause as well:: @@ -2338,13 +2340,11 @@ class Query( a1 = aliased(Address) a2 = aliased(Address) - q = ( - session.query(User) - .join(a1, User.addresses) - .join(a2, User.addresses) - .filter(a1.email_address == "ed@foo.com") - .filter(a2.email_address == "ed@bar.com") - ) + q = session.query(User).\ + join(a1, User.addresses).\ + join(a2, User.addresses).\ + filter(a1.email_address=='ed@foo.com').\ + filter(a2.email_address=='ed@bar.com') The relationship-bound calling form can also specify a target entity using the :meth:`_orm.PropComparator.of_type` method; a query @@ -2353,13 +2353,11 @@ class Query( a1 = aliased(Address) a2 = aliased(Address) - q = ( - session.query(User) - .join(User.addresses.of_type(a1)) - .join(User.addresses.of_type(a2)) - .filter(a1.email_address == "ed@foo.com") - .filter(a2.email_address == "ed@bar.com") - ) + q = session.query(User).\ + join(User.addresses.of_type(a1)).\ + join(User.addresses.of_type(a2)).\ + filter(a1.email_address == 'ed@foo.com').\ + filter(a2.email_address == 'ed@bar.com') **Augmenting Built-in ON Clauses** @@ -2370,7 +2368,7 @@ class Query( with the default criteria using AND:: q = session.query(User).join( - User.addresses.and_(Address.email_address != "foo@bar.com") + User.addresses.and_(Address.email_address != 'foo@bar.com') ) .. versionadded:: 1.4 @@ -2383,28 +2381,29 @@ class Query( appropriate ``.subquery()`` method in order to make a subquery out of a query:: - subq = ( - session.query(Address) - .filter(Address.email_address == "ed@foo.com") - .subquery() + subq = session.query(Address).\ + filter(Address.email_address == 'ed@foo.com').\ + subquery() + + + q = session.query(User).join( + subq, User.id == subq.c.user_id ) - - q = session.query(User).join(subq, User.id == subq.c.user_id) - Joining to a subquery in terms of a specific relationship and/or target entity may be achieved by linking the subquery to the entity using :func:`_orm.aliased`:: - subq = ( - session.query(Address) - .filter(Address.email_address == "ed@foo.com") - .subquery() - ) + subq = session.query(Address).\ + filter(Address.email_address == 'ed@foo.com').\ + subquery() address_subq = aliased(Address, subq) - q = session.query(User).join(User.addresses.of_type(address_subq)) + q = session.query(User).join( + User.addresses.of_type(address_subq) + ) + **Controlling what to Join From** @@ -2412,16 +2411,11 @@ class Query( :class:`_query.Query` is not in line with what we want to join from, the :meth:`_query.Query.select_from` method may be used:: - q = ( - session.query(Address) - .select_from(User) - .join(User.addresses) - .filter(User.name == "ed") - ) + q = session.query(Address).select_from(User).\ + join(User.addresses).\ + filter(User.name == 'ed') - Which will produce SQL similar to: - - .. sourcecode:: sql + Which will produce SQL similar to:: SELECT address.* FROM user JOIN address ON user.id=address.user_id @@ -2525,16 +2519,11 @@ class Query( A typical example:: - q = ( - session.query(Address) - .select_from(User) - .join(User.addresses) - .filter(User.name == "ed") - ) + q = session.query(Address).select_from(User).\ + join(User.addresses).\ + filter(User.name == 'ed') - Which produces SQL equivalent to: - - .. 
sourcecode:: sql + Which produces SQL equivalent to:: SELECT address.* FROM user JOIN address ON user.id=address.user_id @@ -2787,10 +2776,11 @@ class Query( def one(self) -> _T: """Return exactly one result or raise an exception. - Raises :class:`_exc.NoResultFound` if the query selects no rows. - Raises :class:`_exc.MultipleResultsFound` if multiple object identities - are returned, or if multiple rows are returned for a query that returns - only scalar values as opposed to full identity-mapped entities. + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query selects + no rows. Raises ``sqlalchemy.orm.exc.MultipleResultsFound`` + if multiple object identities are returned, or if multiple + rows are returned for a query that returns only scalar values + as opposed to full identity-mapped entities. Calling :meth:`.one` results in an execution of the underlying query. @@ -2810,7 +2800,7 @@ class Query( def scalar(self) -> Any: """Return the first element of the first result or None if no rows present. If multiple rows are returned, - raises :class:`_exc.MultipleResultsFound`. + raises MultipleResultsFound. >>> session.query(Item).scalar() @@ -2896,7 +2886,7 @@ class Query( Format is a list of dictionaries:: - user_alias = aliased(User, name="user2") + user_alias = aliased(User, name='user2') q = sess.query(User, User.id, user_alias) # this expression: @@ -2905,26 +2895,26 @@ class Query( # would return: [ { - "name": "User", - "type": User, - "aliased": False, - "expr": User, - "entity": User, + 'name':'User', + 'type':User, + 'aliased':False, + 'expr':User, + 'entity': User }, { - "name": "id", - "type": Integer(), - "aliased": False, - "expr": User.id, - "entity": User, + 'name':'id', + 'type':Integer(), + 'aliased':False, + 'expr':User.id, + 'entity': User }, { - "name": "user2", - "type": User, - "aliased": True, - "expr": user_alias, - "entity": user_alias, - }, + 'name':'user2', + 'type':User, + 'aliased':True, + 'expr':user_alias, + 'entity': user_alias + } ] .. seealso:: @@ -2969,7 +2959,6 @@ class Query( context = QueryContext( compile_state, compile_state.statement, - compile_state.statement, self._params, self.session, self.load_options, @@ -3033,12 +3022,10 @@ class Query( e.g.:: - q = session.query(User).filter(User.name == "fred") + q = session.query(User).filter(User.name == 'fred') session.query(q.exists()) - Producing SQL similar to: - - .. sourcecode:: sql + Producing SQL similar to:: SELECT EXISTS ( SELECT 1 FROM users WHERE users.name = :name_1 @@ -3087,9 +3074,7 @@ class Query( r"""Return a count of rows this the SQL formed by this :class:`Query` would return. - This generates the SQL for this Query as follows: - - .. sourcecode:: sql + This generates the SQL for this Query as follows:: SELECT count(1) AS count_1 FROM ( SELECT @@ -3129,7 +3114,8 @@ class Query( # return count of user "id" grouped # by "name" - session.query(func.count(User.id)).group_by(User.name) + session.query(func.count(User.id)).\ + group_by(User.name) from sqlalchemy import distinct @@ -3147,9 +3133,7 @@ class Query( ) def delete( - self, - synchronize_session: SynchronizeSessionArgument = "auto", - delete_args: Optional[Dict[Any, Any]] = None, + self, synchronize_session: SynchronizeSessionArgument = "auto" ) -> int: r"""Perform a DELETE with an arbitrary WHERE clause. 
@@ -3157,11 +3141,11 @@ class Query( E.g.:: - sess.query(User).filter(User.age == 25).delete(synchronize_session=False) + sess.query(User).filter(User.age == 25).\ + delete(synchronize_session=False) - sess.query(User).filter(User.age == 25).delete( - synchronize_session="evaluate" - ) + sess.query(User).filter(User.age == 25).\ + delete(synchronize_session='evaluate') .. warning:: @@ -3174,13 +3158,6 @@ class Query( :ref:`orm_expression_update_delete` for a discussion of these strategies. - :param delete_args: Optional dictionary, if present will be passed - to the underlying :func:`_expression.delete` construct as the ``**kw`` - for the object. May be used to pass dialect-specific arguments such - as ``mysql_limit``. - - .. versionadded:: 2.0.37 - :return: the count of rows matched as returned by the database's "row count" feature. @@ -3188,9 +3165,9 @@ class Query( :ref:`orm_expression_update_delete` - """ # noqa: E501 + """ - bulk_del = BulkDelete(self, delete_args) + bulk_del = BulkDelete(self) if self.dispatch.before_compile_delete: for fn in self.dispatch.before_compile_delete: new_query = fn(bulk_del.query, bulk_del) @@ -3200,10 +3177,6 @@ class Query( self = bulk_del.query delete_ = sql.delete(*self._raw_columns) # type: ignore - - if delete_args: - delete_ = delete_.with_dialect_options(**delete_args) - delete_._where_criteria = self._where_criteria result: CursorResult[Any] = self.session.execute( delete_, @@ -3230,13 +3203,11 @@ class Query( E.g.:: - sess.query(User).filter(User.age == 25).update( - {User.age: User.age - 10}, synchronize_session=False - ) + sess.query(User).filter(User.age == 25).\ + update({User.age: User.age - 10}, synchronize_session=False) - sess.query(User).filter(User.age == 25).update( - {"age": User.age - 10}, synchronize_session="evaluate" - ) + sess.query(User).filter(User.age == 25).\ + update({"age": User.age - 10}, synchronize_session='evaluate') .. warning:: @@ -3259,8 +3230,9 @@ class Query( strategies. :param update_args: Optional dictionary, if present will be passed - to the underlying :func:`_expression.update` construct as the ``**kw`` - for the object. May be used to pass dialect-specific arguments such + to the underlying :func:`_expression.update` + construct as the ``**kw`` for + the object. May be used to pass dialect-specific arguments such as ``mysql_limit``, as well as other special arguments such as :paramref:`~sqlalchemy.sql.expression.update.preserve_parameter_order`. 
@@ -3339,16 +3311,13 @@ class Query( ORMCompileState._get_plugin_class_for_plugin(stmt, "orm"), ) - return compile_state_cls._create_orm_context( - stmt, toplevel=True, compiler=None - ) + return compile_state_cls.create_for_statement(stmt, None) def _compile_context(self, for_statement: bool = False) -> QueryContext: compile_state = self._compile_state(for_statement=for_statement) context = QueryContext( compile_state, compile_state.statement, - compile_state.statement, self._params, self.session, self.load_options, @@ -3437,14 +3406,6 @@ class BulkUpdate(BulkUD): class BulkDelete(BulkUD): """BulkUD which handles DELETEs.""" - def __init__( - self, - query: Query[Any], - delete_kwargs: Optional[Dict[Any, Any]], - ): - super().__init__(query) - self.delete_kwargs = delete_kwargs - class RowReturningQuery(Query[Row[_TP]]): if TYPE_CHECKING: diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/relationships.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/relationships.py index 15b63d1..7ea30d7 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/relationships.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/relationships.py @@ -1,5 +1,5 @@ # orm/relationships.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -19,7 +19,6 @@ import collections from collections import abc import dataclasses import inspect as _py_inspect -import itertools import re import typing from typing import Any @@ -27,7 +26,6 @@ from typing import Callable from typing import cast from typing import Collection from typing import Dict -from typing import FrozenSet from typing import Generic from typing import Iterable from typing import Iterator @@ -181,10 +179,7 @@ _ORMOrderByArgument = Union[ ORMBackrefArgument = Union[str, Tuple[str, Dict[str, Any]]] _ORMColCollectionElement = Union[ - ColumnClause[Any], - _HasClauseElement[Any], - roles.DMLColumnRole, - "Mapped[Any]", + ColumnClause[Any], _HasClauseElement, roles.DMLColumnRole, "Mapped[Any]" ] _ORMColCollectionArgument = Union[ str, @@ -486,7 +481,8 @@ class RelationshipProperty( else: self._overlaps = () - self.cascade = cascade + # mypy ignoring the @property setter + self.cascade = cascade # type: ignore self.back_populates = back_populates @@ -708,16 +704,12 @@ class RelationshipProperty( def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 """Implement the ``==`` operator. - In a many-to-one context, such as: - - .. sourcecode:: text + In a many-to-one context, such as:: MyClass.some_prop == this will typically produce a - clause such as: - - .. sourcecode:: text + clause such as:: mytable.related_id == @@ -880,12 +872,11 @@ class RelationshipProperty( An expression like:: session.query(MyClass).filter( - MyClass.somereference.any(SomeRelated.x == 2) + MyClass.somereference.any(SomeRelated.x==2) ) - Will produce a query like: - .. sourcecode:: sql + Will produce a query like:: SELECT * FROM my_table WHERE EXISTS (SELECT 1 FROM related WHERE related.my_id=my_table.id @@ -899,11 +890,11 @@ class RelationshipProperty( :meth:`~.Relationship.Comparator.any` is particularly useful for testing for empty collections:: - session.query(MyClass).filter(~MyClass.somereference.any()) + session.query(MyClass).filter( + ~MyClass.somereference.any() + ) - will produce: - - .. 
sourcecode:: sql + will produce:: SELECT * FROM my_table WHERE NOT (EXISTS (SELECT 1 FROM related WHERE @@ -934,12 +925,11 @@ class RelationshipProperty( An expression like:: session.query(MyClass).filter( - MyClass.somereference.has(SomeRelated.x == 2) + MyClass.somereference.has(SomeRelated.x==2) ) - Will produce a query like: - .. sourcecode:: sql + Will produce a query like:: SELECT * FROM my_table WHERE EXISTS (SELECT 1 FROM related WHERE @@ -958,7 +948,7 @@ class RelationshipProperty( """ if self.property.uselist: raise sa_exc.InvalidRequestError( - "'has()' not implemented for collections. Use any()." + "'has()' not implemented for collections. " "Use any()." ) return self._criterion_exists(criterion, **kwargs) @@ -978,9 +968,7 @@ class RelationshipProperty( MyClass.contains(other) - Produces a clause like: - - .. sourcecode:: sql + Produces a clause like:: mytable.id == @@ -1000,9 +988,7 @@ class RelationshipProperty( query(MyClass).filter(MyClass.contains(other)) - Produces a query like: - - .. sourcecode:: sql + Produces a query like:: SELECT * FROM my_table, my_association_table AS my_association_table_1 WHERE @@ -1098,15 +1084,11 @@ class RelationshipProperty( def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 """Implement the ``!=`` operator. - In a many-to-one context, such as: - - .. sourcecode:: text + In a many-to-one context, such as:: MyClass.some_prop != - This will typically produce a clause such as: - - .. sourcecode:: sql + This will typically produce a clause such as:: mytable.related_id != @@ -1322,11 +1304,9 @@ class RelationshipProperty( state, dict_, column, - passive=( - PassiveFlag.PASSIVE_OFF - if state.persistent - else PassiveFlag.PASSIVE_NO_FETCH ^ PassiveFlag.INIT_OK - ), + passive=PassiveFlag.PASSIVE_OFF + if state.persistent + else PassiveFlag.PASSIVE_NO_FETCH ^ PassiveFlag.INIT_OK, ) if current_value is LoaderCallableStatus.NEVER_SET: @@ -1757,6 +1737,8 @@ class RelationshipProperty( extracted_mapped_annotation: Optional[_AnnotationScanType], is_dataclass_field: bool, ) -> None: + argument = extracted_mapped_annotation + if extracted_mapped_annotation is None: if self.argument is None: self._raise_for_required(key, cls) @@ -1766,17 +1748,19 @@ class RelationshipProperty( argument = extracted_mapped_annotation assert originating_module is not None - if mapped_container is not None: - is_write_only = issubclass(mapped_container, WriteOnlyMapped) - is_dynamic = issubclass(mapped_container, DynamicMapped) - if is_write_only: - self.lazy = "write_only" - self.strategy_key = (("lazy", self.lazy),) - elif is_dynamic: - self.lazy = "dynamic" - self.strategy_key = (("lazy", self.lazy),) - else: - is_write_only = is_dynamic = False + is_write_only = mapped_container is not None and issubclass( + mapped_container, WriteOnlyMapped + ) + if is_write_only: + self.lazy = "write_only" + self.strategy_key = (("lazy", self.lazy),) + + is_dynamic = mapped_container is not None and issubclass( + mapped_container, DynamicMapped + ) + if is_dynamic: + self.lazy = "dynamic" + self.strategy_key = (("lazy", self.lazy),) argument = de_optionalize_union_types(argument) @@ -1827,12 +1811,15 @@ class RelationshipProperty( argument, originating_module ) - if ( - self.collection_class is None - and not is_write_only - and not is_dynamic - ): - self.uselist = False + # we don't allow the collection class to be a + # __forward_arg__ right now, so if we see a forward arg here, + # we know there was no collection class either + if ( + 
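The comparator methods documented above in practice, assuming hypothetical
``User`` / ``Address`` mappings where ``User.addresses`` is a collection,
``Address.user`` is the corresponding many-to-one, and ``some_address`` stands
in for an already-loaded instance::

    # any(): correlated EXISTS against the related table
    q = session.query(User).filter(
        User.addresses.any(Address.email_address.like("%@aol.com"))
    )

    # testing for an empty collection
    q = session.query(User).filter(~User.addresses.any())

    # has(): the scalar (many-to-one) counterpart of any()
    q = session.query(Address).filter(Address.user.has(User.name == "fred"))

    # contains(): membership test against a specific persistent instance
    q = session.query(User).filter(User.addresses.contains(some_address))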
self.collection_class is None + and not is_write_only + and not is_dynamic + ): + self.uselist = False # ticket #8759 # if a lead argument was given to relationship(), like @@ -2012,11 +1999,9 @@ class RelationshipProperty( "the single_parent=True flag." % { "rel": self, - "direction": ( - "many-to-one" - if self.direction is MANYTOONE - else "many-to-many" - ), + "direction": "many-to-one" + if self.direction is MANYTOONE + else "many-to-many", "clsname": self.parent.class_.__name__, "relatedcls": self.mapper.class_.__name__, }, @@ -2909,6 +2894,9 @@ class JoinCondition: ) -> None: """Check the foreign key columns collected and emit error messages.""" + + can_sync = False + foreign_cols = self._gather_columns_with_annotation( join_condition, "foreign" ) @@ -3064,9 +3052,9 @@ class JoinCondition: def _setup_pairs(self) -> None: sync_pairs: _MutableColumnPairs = [] - lrp: util.OrderedSet[Tuple[ColumnElement[Any], ColumnElement[Any]]] = ( - util.OrderedSet([]) - ) + lrp: util.OrderedSet[ + Tuple[ColumnElement[Any], ColumnElement[Any]] + ] = util.OrderedSet([]) secondary_sync_pairs: _MutableColumnPairs = [] def go( @@ -3143,9 +3131,9 @@ class JoinCondition: # level configuration that benefits from this warning. if to_ not in self._track_overlapping_sync_targets: - self._track_overlapping_sync_targets[to_] = ( - weakref.WeakKeyDictionary({self.prop: from_}) - ) + self._track_overlapping_sync_targets[ + to_ + ] = weakref.WeakKeyDictionary({self.prop: from_}) else: other_props = [] prop_to_from = self._track_overlapping_sync_targets[to_] @@ -3243,15 +3231,6 @@ class JoinCondition: if annotation_set.issubset(col._annotations) } - @util.memoized_property - def _secondary_lineage_set(self) -> FrozenSet[ColumnElement[Any]]: - if self.secondary is not None: - return frozenset( - itertools.chain(*[c.proxy_set for c in self.secondary.c]) - ) - else: - return util.EMPTY_SET - def join_targets( self, source_selectable: Optional[FromClause], @@ -3302,25 +3281,23 @@ class JoinCondition: if extra_criteria: - def mark_exclude_cols( + def mark_unrelated_columns_as_ok_to_adapt( elem: SupportsAnnotations, annotations: _AnnotationDict ) -> SupportsAnnotations: - """note unrelated columns in the "extra criteria" as either - should be adapted or not adapted, even though they are not - part of our "local" or "remote" side. + """note unrelated columns in the "extra criteria" as OK + to adapt, even though they are not part of our "local" + or "remote" side. 
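The ``WriteOnlyMapped`` / ``DynamicMapped`` detection above is what lets the
annotation container choose the loader style; a minimal declarative sketch with
hypothetical ``Parent`` / ``Child`` classes::

    from sqlalchemy import ForeignKey
    from sqlalchemy.orm import (
        DeclarativeBase,
        Mapped,
        WriteOnlyMapped,
        mapped_column,
        relationship,
    )

    class Base(DeclarativeBase):
        pass

    class Parent(Base):
        __tablename__ = "parent"
        id: Mapped[int] = mapped_column(primary_key=True)

        # Mapped[List["Child"]]    -> plain list collection (uselist=True)
        # DynamicMapped["Child"]   -> lazy="dynamic"
        # WriteOnlyMapped["Child"] -> lazy="write_only"
        children: WriteOnlyMapped["Child"] = relationship(back_populates="parent")

    class Child(Base):
        __tablename__ = "child"
        id: Mapped[int] = mapped_column(primary_key=True)
        parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id"))
        parent: Mapped["Parent"] = relationship(back_populates="children")

    # usage note: parent.children.add(child) stages the row without loading
    # the collection; parent.children.select() returns a SELECT to iterate it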
- see #9779 for this case, as well as #11010 for a follow up + see #9779 for this case """ parentmapper_for_element = elem._annotations.get( "parentmapper", None ) - if ( parentmapper_for_element is not self.prop.parent and parentmapper_for_element is not self.prop.mapper - and elem not in self._secondary_lineage_set ): return _safe_annotate(elem, annotations) else: @@ -3329,8 +3306,8 @@ class JoinCondition: extra_criteria = tuple( _deep_annotate( elem, - {"should_not_adapt": True}, - annotate_callable=mark_exclude_cols, + {"ok_to_adapt_in_join_condition": True}, + annotate_callable=mark_unrelated_columns_as_ok_to_adapt, ) for elem in extra_criteria ) @@ -3344,16 +3321,14 @@ class JoinCondition: if secondary is not None: secondary = secondary._anonymous_fromclause(flat=True) primary_aliasizer = ClauseAdapter( - secondary, - exclude_fn=_local_col_exclude, + secondary, exclude_fn=_ColInAnnotations("local") ) secondary_aliasizer = ClauseAdapter( dest_selectable, equivalents=self.child_equivalents ).chain(primary_aliasizer) if source_selectable is not None: primary_aliasizer = ClauseAdapter( - secondary, - exclude_fn=_local_col_exclude, + secondary, exclude_fn=_ColInAnnotations("local") ).chain( ClauseAdapter( source_selectable, @@ -3365,14 +3340,14 @@ class JoinCondition: else: primary_aliasizer = ClauseAdapter( dest_selectable, - exclude_fn=_local_col_exclude, + exclude_fn=_ColInAnnotations("local"), equivalents=self.child_equivalents, ) if source_selectable is not None: primary_aliasizer.chain( ClauseAdapter( source_selectable, - exclude_fn=_remote_col_exclude, + exclude_fn=_ColInAnnotations("remote"), equivalents=self.parent_equivalents, ) ) @@ -3391,7 +3366,9 @@ class JoinCondition: dest_selectable, ) - def create_lazy_clause(self, reverse_direction: bool = False) -> Tuple[ + def create_lazy_clause( + self, reverse_direction: bool = False + ) -> Tuple[ ColumnElement[bool], Dict[str, ColumnElement[Any]], Dict[ColumnElement[Any], ColumnElement[Any]], @@ -3451,29 +3428,25 @@ class JoinCondition: class _ColInAnnotations: - """Serializable object that tests for names in c._annotations. + """Serializable object that tests for a name in c._annotations.""" - TODO: does this need to be serializable anymore? can we find what the - use case was for that? + __slots__ = ("name",) - """ - - __slots__ = ("names",) - - def __init__(self, *names: str): - self.names = frozenset(names) + def __init__(self, name: str): + self.name = name def __call__(self, c: ClauseElement) -> bool: - return bool(self.names.intersection(c._annotations)) + return ( + self.name in c._annotations + or "ok_to_adapt_in_join_condition" in c._annotations + ) -_local_col_exclude = _ColInAnnotations("local", "should_not_adapt") -_remote_col_exclude = _ColInAnnotations("remote", "should_not_adapt") - - -class Relationship( +class Relationship( # type: ignore RelationshipProperty[_T], _DeclarativeMapped[_T], + WriteOnlyMapped[_T], # not compatible with Mapped[_T] + DynamicMapped[_T], # not compatible with Mapped[_T] ): """Describes an object property that holds a single item or list of items that correspond to a related database table. 
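The "extra criteria" handling above is what backs loader options that carry an
``and_()`` clause; a sketch assuming hypothetical ``User`` / ``Address``
mappings::

    from sqlalchemy import select
    from sqlalchemy.orm import joinedload

    # eagerly load only the matching subset of the collection; the
    # relationship's own join condition is AND-ed with the extra criteria
    stmt = select(User).options(
        joinedload(
            User.addresses.and_(Address.email_address.like("%@example.com"))
        )
    )
    users = session.scalars(stmt).unique().all()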
@@ -3491,18 +3464,3 @@ class Relationship( inherit_cache = True """:meta private:""" - - -class _RelationshipDeclared( # type: ignore[misc] - Relationship[_T], - WriteOnlyMapped[_T], # not compatible with Mapped[_T] - DynamicMapped[_T], # not compatible with Mapped[_T] -): - """Relationship subclass used implicitly for declarative mapping.""" - - inherit_cache = True - """:meta private:""" - - @classmethod - def _mapper_property_name(cls) -> str: - return "Relationship" diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/scoping.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/scoping.py index df5a653..ab632bd 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/scoping.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/scoping.py @@ -1,5 +1,5 @@ # orm/scoping.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -86,7 +86,8 @@ class QueryPropertyDescriptor(Protocol): """ - def __get__(self, instance: Any, owner: Type[_T]) -> Query[_T]: ... + def __get__(self, instance: Any, owner: Type[_T]) -> Query[_T]: + ... _O = TypeVar("_O", bound=object) @@ -280,13 +281,11 @@ class scoped_session(Generic[_S]): Session = scoped_session(sessionmaker()) - class MyClass: query: QueryPropertyDescriptor = Session.query_property() - # after mappers are defined - result = MyClass.query.filter(MyClass.name == "foo").all() + result = MyClass.query.filter(MyClass.name=='foo').all() Produces instances of the session's configured query class by default. To override and use a custom implementation, provide @@ -535,12 +534,12 @@ class scoped_session(Generic[_S]): behalf of the :class:`_orm.scoping.scoped_session` class. This method provides for same "reset-only" behavior that the - :meth:`_orm.Session.close` method has provided historically, where the + :meth:_orm.Session.close method has provided historically, where the state of the :class:`_orm.Session` is reset as though the object were brand new, and ready to be used again. - This method may then be useful for :class:`_orm.Session` objects + The method may then be useful for :class:`_orm.Session` objects which set :paramref:`_orm.Session.close_resets_only` to ``False``, - so that "reset only" behavior is still available. + so that "reset only" behavior is still available from this method. .. versionadded:: 2.0.22 @@ -683,7 +682,8 @@ class scoped_session(Generic[_S]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[_T]: ... + ) -> Result[_T]: + ... @overload def execute( @@ -695,7 +695,8 @@ class scoped_session(Generic[_S]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> CursorResult[Any]: ... + ) -> CursorResult[Any]: + ... @overload def execute( @@ -707,7 +708,8 @@ class scoped_session(Generic[_S]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Any]: ... + ) -> Result[Any]: + ... 
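A compact sketch of the ``scoped_session`` registry documented here (the
engine URL and ``some_object`` are placeholders)::

    from sqlalchemy import create_engine
    from sqlalchemy.orm import scoped_session, sessionmaker

    engine = create_engine("sqlite://")
    Session = scoped_session(sessionmaker(bind=engine))

    # repeated calls in the same scope (thread, by default) return the
    # same Session instance
    s1 = Session()
    s2 = Session()
    assert s1 is s2

    # Session methods are proxied directly on the registry object
    Session.add(some_object)
    Session.commit()

    # discard the current scoped Session; the next call creates a fresh one
    Session.remove()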
def execute( self, @@ -732,8 +734,9 @@ class scoped_session(Generic[_S]): E.g.:: from sqlalchemy import select - - result = session.execute(select(User).where(User.id == 5)) + result = session.execute( + select(User).where(User.id == 5) + ) The API contract of :meth:`_orm.Session.execute` is similar to that of :meth:`_engine.Connection.execute`, the :term:`2.0 style` version @@ -963,7 +966,10 @@ class scoped_session(Generic[_S]): some_object = session.get(VersionedFoo, (5, 10)) - some_object = session.get(VersionedFoo, {"id": 5, "version_id": 10}) + some_object = session.get( + VersionedFoo, + {"id": 5, "version_id": 10} + ) .. versionadded:: 1.4 Added :meth:`_orm.Session.get`, which is moved from the now legacy :meth:`_orm.Query.get` method. @@ -1086,7 +1092,8 @@ class scoped_session(Generic[_S]): Proxied for the :class:`_orm.Session` class on behalf of the :class:`_orm.scoping.scoped_session` class. - Raises :class:`_exc.NoResultFound` if the query selects no rows. + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query + selects no rows. For a detailed documentation of the arguments see the method :meth:`.Session.get`. @@ -1225,7 +1232,7 @@ class scoped_session(Generic[_S]): This method retrieves the history for each instrumented attribute on the instance and performs a comparison of the current - value to its previously flushed or committed value, if any. + value to its previously committed value, if any. It is in effect a more expensive and accurate version of checking for the given instance in the @@ -1567,12 +1574,14 @@ class scoped_session(Generic[_S]): return self._proxied.merge(instance, load=load, options=options) @overload - def query(self, _entity: _EntityType[_O]) -> Query[_O]: ... + def query(self, _entity: _EntityType[_O]) -> Query[_O]: + ... @overload def query( self, _colexpr: TypedColumnsClauseRole[_T] - ) -> RowReturningQuery[Tuple[_T]]: ... + ) -> RowReturningQuery[Tuple[_T]]: + ... # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 @@ -1582,12 +1591,14 @@ class scoped_session(Generic[_S]): @overload def query( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] - ) -> RowReturningQuery[Tuple[_T0, _T1]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1]]: + ... @overload def query( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: + ... @overload def query( @@ -1596,7 +1607,8 @@ class scoped_session(Generic[_S]): __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: + ... @overload def query( @@ -1606,7 +1618,8 @@ class scoped_session(Generic[_S]): __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... @overload def query( @@ -1617,7 +1630,8 @@ class scoped_session(Generic[_S]): __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... @overload def query( @@ -1629,7 +1643,8 @@ class scoped_session(Generic[_S]): __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... 
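The ``Session.get()`` / ``Session.get_one()`` calls proxied above, sketched
with the ``VersionedFoo`` example from the docstring plus a hypothetical
``User``::

    from sqlalchemy.exc import NoResultFound

    # scalar primary key; returns None if not found
    user = session.get(User, 5)

    # composite primary key: positional tuple or named dict
    obj = session.get(VersionedFoo, (5, 10))
    obj = session.get(VersionedFoo, {"id": 5, "version_id": 10})

    # get_one() raises instead of returning None
    try:
        obj = session.get_one(VersionedFoo, (5, 10))
    except NoResultFound:
        obj = None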
@overload def query( @@ -1642,14 +1657,16 @@ class scoped_session(Generic[_S]): __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... # END OVERLOADED FUNCTIONS self.query @overload def query( self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any - ) -> Query[Any]: ... + ) -> Query[Any]: + ... def query( self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any @@ -1801,7 +1818,8 @@ class scoped_session(Generic[_S]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Optional[_T]: ... + ) -> Optional[_T]: + ... @overload def scalar( @@ -1812,7 +1830,8 @@ class scoped_session(Generic[_S]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Any: ... + ) -> Any: + ... def scalar( self, @@ -1854,7 +1873,8 @@ class scoped_session(Generic[_S]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[_T]: ... + ) -> ScalarResult[_T]: + ... @overload def scalars( @@ -1865,7 +1885,8 @@ class scoped_session(Generic[_S]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[Any]: ... + ) -> ScalarResult[Any]: + ... def scalars( self, diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/session.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/session.py index 6a589f3..d861981 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/session.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/session.py @@ -1,5 +1,5 @@ # orm/session.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -146,9 +146,9 @@ __all__ = [ "object_session", ] -_sessions: weakref.WeakValueDictionary[int, Session] = ( - weakref.WeakValueDictionary() -) +_sessions: weakref.WeakValueDictionary[ + int, Session +] = weakref.WeakValueDictionary() """Weak-referencing dictionary of :class:`.Session` objects. """ @@ -188,7 +188,8 @@ class _ConnectionCallableProto(Protocol): mapper: Optional[Mapper[Any]] = None, instance: Optional[object] = None, **kw: Any, - ) -> Connection: ... + ) -> Connection: + ... def _state_session(state: InstanceState[Any]) -> Optional[Session]: @@ -575,67 +576,22 @@ class ORMExecuteState(util.MemoizedSlots): @property def is_select(self) -> bool: - """return True if this is a SELECT operation. - - .. versionchanged:: 2.0.30 - the attribute is also True for a - :meth:`_sql.Select.from_statement` construct that is itself against - a :class:`_sql.Select` construct, such as - ``select(Entity).from_statement(select(..))`` - - """ + """return True if this is a SELECT operation.""" return self.statement.is_select - @property - def is_from_statement(self) -> bool: - """return True if this operation is a - :meth:`_sql.Select.from_statement` operation. - - This is independent from :attr:`_orm.ORMExecuteState.is_select`, as a - ``select().from_statement()`` construct can be used with - INSERT/UPDATE/DELETE RETURNING types of statements as well. 
- :attr:`_orm.ORMExecuteState.is_select` will only be set if the - :meth:`_sql.Select.from_statement` is itself against a - :class:`_sql.Select` construct. - - .. versionadded:: 2.0.30 - - """ - return self.statement.is_from_statement - @property def is_insert(self) -> bool: - """return True if this is an INSERT operation. - - .. versionchanged:: 2.0.30 - the attribute is also True for a - :meth:`_sql.Select.from_statement` construct that is itself against - a :class:`_sql.Insert` construct, such as - ``select(Entity).from_statement(insert(..))`` - - """ + """return True if this is an INSERT operation.""" return self.statement.is_dml and self.statement.is_insert @property def is_update(self) -> bool: - """return True if this is an UPDATE operation. - - .. versionchanged:: 2.0.30 - the attribute is also True for a - :meth:`_sql.Select.from_statement` construct that is itself against - a :class:`_sql.Update` construct, such as - ``select(Entity).from_statement(update(..))`` - - """ + """return True if this is an UPDATE operation.""" return self.statement.is_dml and self.statement.is_update @property def is_delete(self) -> bool: - """return True if this is a DELETE operation. - - .. versionchanged:: 2.0.30 - the attribute is also True for a - :meth:`_sql.Select.from_statement` construct that is itself against - a :class:`_sql.Delete` construct, such as - ``select(Entity).from_statement(delete(..))`` - - """ + """return True if this is a DELETE operation.""" return self.statement.is_dml and self.statement.is_delete @property @@ -1044,11 +1000,9 @@ class SessionTransaction(_StateChange, TransactionalContext): def _begin(self, nested: bool = False) -> SessionTransaction: return SessionTransaction( self.session, - ( - SessionTransactionOrigin.BEGIN_NESTED - if nested - else SessionTransactionOrigin.SUBTRANSACTION - ), + SessionTransactionOrigin.BEGIN_NESTED + if nested + else SessionTransactionOrigin.SUBTRANSACTION, self, ) @@ -1211,17 +1165,6 @@ class SessionTransaction(_StateChange, TransactionalContext): else: join_transaction_mode = "rollback_only" - if local_connect: - util.warn( - "The engine provided as bind produced a " - "connection that is already in a transaction. " - "This is usually caused by a core event, " - "such as 'engine_connect', that has left a " - "transaction open. The effective join " - "transaction mode used by this session is " - f"{join_transaction_mode!r}. To silence this " - "warning, do not leave transactions open" - ) if join_transaction_mode in ( "control_fully", "rollback_only", @@ -1569,16 +1512,12 @@ class Session(_SessionClassMethods, EventTarget): operation. The complete heuristics for resolution are described at :meth:`.Session.get_bind`. Usage looks like:: - Session = sessionmaker( - binds={ - SomeMappedClass: create_engine("postgresql+psycopg2://engine1"), - SomeDeclarativeBase: create_engine( - "postgresql+psycopg2://engine2" - ), - some_mapper: create_engine("postgresql+psycopg2://engine3"), - some_table: create_engine("postgresql+psycopg2://engine4"), - } - ) + Session = sessionmaker(binds={ + SomeMappedClass: create_engine('postgresql+psycopg2://engine1'), + SomeDeclarativeBase: create_engine('postgresql+psycopg2://engine2'), + some_mapper: create_engine('postgresql+psycopg2://engine3'), + some_table: create_engine('postgresql+psycopg2://engine4'), + }) .. 
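The ``is_select`` / ``is_insert`` / ``is_update`` / ``is_delete`` accessors
above are typically consulted from a ``do_orm_execute`` hook; a sketch::

    from sqlalchemy import event
    from sqlalchemy.orm import Session

    @event.listens_for(Session, "do_orm_execute")
    def _classify(orm_execute_state):
        if orm_execute_state.is_select:
            # e.g. attach caching options or route to a read-only bind
            pass
        elif orm_execute_state.is_update or orm_execute_state.is_delete:
            # e.g. audit ORM-enabled UPDATE / DELETE statements
            pass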
seealso:: @@ -1773,7 +1712,7 @@ class Session(_SessionClassMethods, EventTarget): # the idea is that at some point NO_ARG will warn that in the future # the default will switch to close_resets_only=False. - if close_resets_only in (True, _NoArg.NO_ARG): + if close_resets_only or close_resets_only is _NoArg.NO_ARG: self._close_state = _SessionCloseState.CLOSE_IS_RESET else: self._close_state = _SessionCloseState.ACTIVE @@ -1880,11 +1819,9 @@ class Session(_SessionClassMethods, EventTarget): ) trans = SessionTransaction( self, - ( - SessionTransactionOrigin.BEGIN - if begin - else SessionTransactionOrigin.AUTOBEGIN - ), + SessionTransactionOrigin.BEGIN + if begin + else SessionTransactionOrigin.AUTOBEGIN, ) assert self._transaction is trans return trans @@ -2120,7 +2057,8 @@ class Session(_SessionClassMethods, EventTarget): _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, _scalar_result: Literal[True] = ..., - ) -> Any: ... + ) -> Any: + ... @overload def _execute_internal( @@ -2133,7 +2071,8 @@ class Session(_SessionClassMethods, EventTarget): _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, _scalar_result: bool = ..., - ) -> Result[Any]: ... + ) -> Result[Any]: + ... def _execute_internal( self, @@ -2276,7 +2215,8 @@ class Session(_SessionClassMethods, EventTarget): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[_T]: ... + ) -> Result[_T]: + ... @overload def execute( @@ -2288,7 +2228,8 @@ class Session(_SessionClassMethods, EventTarget): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> CursorResult[Any]: ... + ) -> CursorResult[Any]: + ... @overload def execute( @@ -2300,7 +2241,8 @@ class Session(_SessionClassMethods, EventTarget): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Any]: ... + ) -> Result[Any]: + ... def execute( self, @@ -2320,8 +2262,9 @@ class Session(_SessionClassMethods, EventTarget): E.g.:: from sqlalchemy import select - - result = session.execute(select(User).where(User.id == 5)) + result = session.execute( + select(User).where(User.id == 5) + ) The API contract of :meth:`_orm.Session.execute` is similar to that of :meth:`_engine.Connection.execute`, the :term:`2.0 style` version @@ -2380,7 +2323,8 @@ class Session(_SessionClassMethods, EventTarget): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Optional[_T]: ... + ) -> Optional[_T]: + ... @overload def scalar( @@ -2391,7 +2335,8 @@ class Session(_SessionClassMethods, EventTarget): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Any: ... + ) -> Any: + ... def scalar( self, @@ -2428,7 +2373,8 @@ class Session(_SessionClassMethods, EventTarget): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[_T]: ... + ) -> ScalarResult[_T]: + ... @overload def scalars( @@ -2439,7 +2385,8 @@ class Session(_SessionClassMethods, EventTarget): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[Any]: ... + ) -> ScalarResult[Any]: + ... 
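The 2.0-style execution methods referenced above, assuming a hypothetical
mapped ``User``::

    from sqlalchemy import func, select

    # execute(): a Result of Row objects
    result = session.execute(select(User).where(User.id == 5))
    user = result.scalar_one_or_none()

    # scalars(): unwrap the first column, convenient for single-entity selects
    users = session.scalars(select(User).order_by(User.id)).all()

    # scalar(): first column of the first row, or None
    n = session.scalar(select(func.count()).select_from(User))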
def scalars( self, @@ -2525,12 +2472,12 @@ class Session(_SessionClassMethods, EventTarget): :class:`_orm.Session`, resetting the session to its initial state. This method provides for same "reset-only" behavior that the - :meth:`_orm.Session.close` method has provided historically, where the + :meth:_orm.Session.close method has provided historically, where the state of the :class:`_orm.Session` is reset as though the object were brand new, and ready to be used again. - This method may then be useful for :class:`_orm.Session` objects + The method may then be useful for :class:`_orm.Session` objects which set :paramref:`_orm.Session.close_resets_only` to ``False``, - so that "reset only" behavior is still available. + so that "reset only" behavior is still available from this method. .. versionadded:: 2.0.22 @@ -2848,12 +2795,14 @@ class Session(_SessionClassMethods, EventTarget): ) @overload - def query(self, _entity: _EntityType[_O]) -> Query[_O]: ... + def query(self, _entity: _EntityType[_O]) -> Query[_O]: + ... @overload def query( self, _colexpr: TypedColumnsClauseRole[_T] - ) -> RowReturningQuery[Tuple[_T]]: ... + ) -> RowReturningQuery[Tuple[_T]]: + ... # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 @@ -2863,12 +2812,14 @@ class Session(_SessionClassMethods, EventTarget): @overload def query( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] - ) -> RowReturningQuery[Tuple[_T0, _T1]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1]]: + ... @overload def query( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2]]: + ... @overload def query( @@ -2877,7 +2828,8 @@ class Session(_SessionClassMethods, EventTarget): __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3]]: + ... @overload def query( @@ -2887,7 +2839,8 @@ class Session(_SessionClassMethods, EventTarget): __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... @overload def query( @@ -2898,7 +2851,8 @@ class Session(_SessionClassMethods, EventTarget): __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... @overload def query( @@ -2910,7 +2864,8 @@ class Session(_SessionClassMethods, EventTarget): __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... @overload def query( @@ -2923,14 +2878,16 @@ class Session(_SessionClassMethods, EventTarget): __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], - ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: ... + ) -> RowReturningQuery[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... # END OVERLOADED FUNCTIONS self.query @overload def query( self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any - ) -> Query[Any]: ... + ) -> Query[Any]: + ... 
def query( self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any @@ -2973,7 +2930,7 @@ class Session(_SessionClassMethods, EventTarget): e.g.:: - obj = session._identity_lookup(inspect(SomeClass), (1,)) + obj = session._identity_lookup(inspect(SomeClass), (1, )) :param mapper: mapper in use :param primary_key_identity: the primary key we are searching for, as @@ -3044,8 +3001,7 @@ class Session(_SessionClassMethods, EventTarget): @util.langhelpers.tag_method_for_warnings( "This warning originated from the Session 'autoflush' process, " "which was invoked automatically in response to a user-initiated " - "operation. Consider using ``no_autoflush`` context manager if this " - "warning happened while initializing objects.", + "operation.", sa_exc.SAWarning, ) def _autoflush(self) -> None: @@ -3601,7 +3557,10 @@ class Session(_SessionClassMethods, EventTarget): some_object = session.get(VersionedFoo, (5, 10)) - some_object = session.get(VersionedFoo, {"id": 5, "version_id": 10}) + some_object = session.get( + VersionedFoo, + {"id": 5, "version_id": 10} + ) .. versionadded:: 1.4 Added :meth:`_orm.Session.get`, which is moved from the now legacy :meth:`_orm.Query.get` method. @@ -3690,7 +3649,7 @@ class Session(_SessionClassMethods, EventTarget): :return: The object instance, or ``None``. - """ # noqa: E501 + """ return self._get_impl( entity, ident, @@ -3718,7 +3677,8 @@ class Session(_SessionClassMethods, EventTarget): """Return exactly one instance based on the given primary key identifier, or raise an exception if not found. - Raises :class:`_exc.NoResultFound` if the query selects no rows. + Raises ``sqlalchemy.orm.exc.NoResultFound`` if the query + selects no rows. For a detailed documentation of the arguments see the method :meth:`.Session.get`. 
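The autoflush warning above points at the ``no_autoflush`` context manager;
a sketch with hypothetical ``Parent`` / ``Child`` classes::

    # suspend the flush that normally precedes query execution, e.g.
    # while wiring up interdependent, half-initialized objects
    with session.no_autoflush:
        parent = Parent()
        session.add(parent)
        # this query will not flush the incomplete Parent row first
        orphans = session.query(Child).filter_by(parent_id=None).all()

    session.flush()  # flush explicitly once the objects are consistent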
@@ -3808,9 +3768,9 @@ class Session(_SessionClassMethods, EventTarget): if correct_keys: primary_key_identity = dict(primary_key_identity) for k in correct_keys: - primary_key_identity[pk_synonyms[k]] = ( - primary_key_identity[k] - ) + primary_key_identity[ + pk_synonyms[k] + ] = primary_key_identity[k] try: primary_key_identity = list( @@ -4014,7 +3974,14 @@ class Session(_SessionClassMethods, EventTarget): else: key_is_persistent = True - merged = self.identity_map.get(key) + if key in self.identity_map: + try: + merged = self.identity_map[key] + except KeyError: + # object was GC'ed right as we checked for it + merged = None + else: + merged = None if merged is None: if key_is_persistent and key in _resolve_conflict_map: @@ -4578,11 +4545,11 @@ class Session(_SessionClassMethods, EventTarget): self._bulk_save_mappings( mapper, states, - isupdate=isupdate, - isstates=True, - return_defaults=return_defaults, - update_changed_only=update_changed_only, - render_nulls=False, + isupdate, + True, + return_defaults, + update_changed_only, + False, ) def bulk_insert_mappings( @@ -4661,11 +4628,11 @@ class Session(_SessionClassMethods, EventTarget): self._bulk_save_mappings( mapper, mappings, - isupdate=False, - isstates=False, - return_defaults=return_defaults, - update_changed_only=False, - render_nulls=render_nulls, + False, + False, + return_defaults, + False, + render_nulls, ) def bulk_update_mappings( @@ -4707,20 +4674,13 @@ class Session(_SessionClassMethods, EventTarget): """ self._bulk_save_mappings( - mapper, - mappings, - isupdate=True, - isstates=False, - return_defaults=False, - update_changed_only=False, - render_nulls=False, + mapper, mappings, True, False, False, False, False ) def _bulk_save_mappings( self, mapper: Mapper[_O], mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]], - *, isupdate: bool, isstates: bool, return_defaults: bool, @@ -4737,17 +4697,17 @@ class Session(_SessionClassMethods, EventTarget): mapper, mappings, transaction, - isstates=isstates, - update_changed_only=update_changed_only, + isstates, + update_changed_only, ) else: bulk_persistence._bulk_insert( mapper, mappings, transaction, - isstates=isstates, - return_defaults=return_defaults, - render_nulls=render_nulls, + isstates, + return_defaults, + render_nulls, ) transaction.commit() @@ -4765,7 +4725,7 @@ class Session(_SessionClassMethods, EventTarget): This method retrieves the history for each instrumented attribute on the instance and performs a comparison of the current - value to its previously flushed or committed value, if any. + value to its previously committed value, if any. It is in effect a more expensive and accurate version of checking for the given instance in the @@ -4935,7 +4895,7 @@ class sessionmaker(_SessionClassMethods, Generic[_S]): # an Engine, which the Session will use for connection # resources - engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/") + engine = create_engine('postgresql+psycopg2://scott:tiger@localhost/') Session = sessionmaker(engine) @@ -4988,7 +4948,7 @@ class sessionmaker(_SessionClassMethods, Generic[_S]): with engine.connect() as connection: with Session(bind=connection) as session: - ... # work with session + # work with session The class also includes a method :meth:`_orm.sessionmaker.configure`, which can be used to specify additional keyword arguments to the factory, which @@ -5003,7 +4963,7 @@ class sessionmaker(_SessionClassMethods, Generic[_S]): # ... 
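The keyword-only refactor above concerns the bulk persistence methods; basic
usage, assuming a hypothetical mapped ``User``::

    # bulk INSERT from plain dictionaries (no identity-map bookkeeping)
    session.bulk_insert_mappings(
        User,
        [{"name": "u1"}, {"name": "u2"}, {"name": "u3"}],
    )

    # bulk UPDATE keyed on the primary key
    session.bulk_update_mappings(
        User,
        [{"id": 1, "name": "u1-renamed"}, {"id": 2, "name": "u2-renamed"}],
    )

    session.commit()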
later, when an engine URL is read from a configuration # file or other events allow the engine to be created - engine = create_engine("sqlite:///foo.db") + engine = create_engine('sqlite:///foo.db') Session.configure(bind=engine) sess = Session() @@ -5028,7 +4988,8 @@ class sessionmaker(_SessionClassMethods, Generic[_S]): expire_on_commit: bool = ..., info: Optional[_InfoType] = ..., **kw: Any, - ): ... + ): + ... @overload def __init__( @@ -5039,7 +5000,8 @@ class sessionmaker(_SessionClassMethods, Generic[_S]): expire_on_commit: bool = ..., info: Optional[_InfoType] = ..., **kw: Any, - ): ... + ): + ... def __init__( self, @@ -5141,7 +5103,7 @@ class sessionmaker(_SessionClassMethods, Generic[_S]): Session = sessionmaker() - Session.configure(bind=create_engine("sqlite://")) + Session.configure(bind=create_engine('sqlite://')) """ self.kw.update(new_kw) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/state.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/state.py index d4bbf92..d9e1f85 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/state.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/state.py @@ -1,5 +1,5 @@ # orm/state.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -78,7 +78,8 @@ if not TYPE_CHECKING: class _InstanceDictProto(Protocol): - def __call__(self) -> Optional[IdentityMap]: ... + def __call__(self) -> Optional[IdentityMap]: + ... class _InstallLoaderCallableProto(Protocol[_O]): @@ -93,12 +94,13 @@ class _InstallLoaderCallableProto(Protocol[_O]): def __call__( self, state: InstanceState[_O], dict_: _InstanceDict, row: Row[Any] - ) -> None: ... + ) -> None: + ... @inspection._self_inspects class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): - """Tracks state information at the instance level. + """tracks state information at the instance level. The :class:`.InstanceState` is a key object used by the SQLAlchemy ORM in order to track the state of an object; @@ -148,14 +150,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): committed_state: Dict[str, Any] modified: bool = False - """When ``True`` the object was modified.""" expired: bool = False - """When ``True`` the object is :term:`expired`. - - .. seealso:: - - :ref:`session_expire` - """ _deleted: bool = False _load_pending: bool = False _orphaned_outside_of_session: bool = False @@ -176,12 +171,11 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): expired_attributes: Set[str] """The set of keys which are 'expired' to be loaded by - the manager's deferred scalar loader, assuming no pending - changes. + the manager's deferred scalar loader, assuming no pending + changes. - See also the ``unmodified`` collection which is intersected - against this set when a refresh operation occurs. - """ + see also the ``unmodified`` collection which is intersected + against this set when a refresh operation occurs.""" callables: Dict[str, Callable[[InstanceState[_O], PassiveFlag], Any]] """A namespace where a per-state loader callable can be associated. @@ -236,6 +230,7 @@ class InstanceState(interfaces.InspectionAttrInfo, Generic[_O]): def pending(self) -> bool: """Return ``True`` if the object is :term:`pending`. + .. 
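The ``InstanceState`` flags documented above can be read via ``inspect()``;
a sketch assuming a hypothetical ``User`` and the default ``expire_on_commit``::

    from sqlalchemy import inspect

    user = User(name="fred")
    state = inspect(user)          # InstanceState for this instance

    assert state.transient        # no identity, not attached
    session.add(user)
    assert state.pending          # attached, INSERT not yet flushed
    session.commit()
    assert state.persistent       # identity present, attached
    assert state.expired          # expire_on_commit=True expired the attributes

    user.name = "ed"
    assert state.modified         # unflushed attribute change recorded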
seealso:: :ref:`session_object_states` diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/state_changes.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/state_changes.py index a79874e..3d74ff2 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/state_changes.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/state_changes.py @@ -1,11 +1,13 @@ # orm/state_changes.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""State tracking utilities used by :class:`_orm.Session`.""" +"""State tracking utilities used by :class:`_orm.Session`. + +""" from __future__ import annotations diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/strategies.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/strategies.py index 8ac34e2..1e58f40 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/strategies.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/strategies.py @@ -1,5 +1,5 @@ # orm/strategies.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -8,7 +8,7 @@ """sqlalchemy.orm.interfaces.LoaderStrategy -implementations, and related MapperOptions.""" + implementations, and related MapperOptions.""" from __future__ import annotations @@ -16,10 +16,8 @@ import collections import itertools from typing import Any from typing import Dict -from typing import Optional from typing import Tuple from typing import TYPE_CHECKING -from typing import Union from . import attributes from . import exc as orm_exc @@ -47,7 +45,7 @@ from .interfaces import StrategizedProperty from .session import _state_session from .state import InstanceState from .strategy_options import Load -from .util import _none_only_set +from .util import _none_set from .util import AliasedClass from .. import event from .. import exc as sa_exc @@ -59,10 +57,8 @@ from ..sql import util as sql_util from ..sql import visitors from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL from ..sql.selectable import Select -from ..util.typing import Literal if TYPE_CHECKING: - from .mapper import Mapper from .relationships import RelationshipProperty from ..sql.elements import ColumnElement @@ -388,7 +384,7 @@ class DeferredColumnLoader(LoaderStrategy): super().__init__(parent, strategy_key) if hasattr(self.parent_property, "composite_class"): raise NotImplementedError( - "Deferred loading for composite types not implemented yet" + "Deferred loading for composite " "types not implemented yet" ) self.raiseload = self.strategy_opts.get("raiseload", False) self.columns = self.parent_property.columns @@ -762,7 +758,7 @@ class LazyLoader( self._equated_columns[c] = self._equated_columns[col] self.logger.info( - "%s will use Session.get() to optimize instance loads", self + "%s will use Session.get() to " "optimize instance loads", self ) def init_class_attribute(self, mapper): @@ -936,15 +932,8 @@ class LazyLoader( elif LoaderCallableStatus.NEVER_SET in primary_key_identity: return LoaderCallableStatus.NEVER_SET - # test for None alone in primary_key_identity based on - # allow_partial_pks preference. 
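``DeferredColumnLoader`` above implements column deferral; the public-facing
options, assuming a hypothetical ``User.bio`` column::

    from sqlalchemy import select
    from sqlalchemy.orm import defer, undefer

    # leave "bio" out of the initial SELECT; it loads on first access
    stmt = select(User).options(defer(User.bio))
    user = session.scalars(stmt).first()
    _ = user.bio  # emits the deferred SELECT here

    # raiseload=True turns that late access into an error instead
    stmt = select(User).options(defer(User.bio, raiseload=True))

    # undefer() reverses a mapper-level deferred() configuration
    stmt = select(User).options(undefer(User.bio))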
PASSIVE_NO_RESULT and NEVER_SET - # have already been tested above - if not self.mapper.allow_partial_pks: - if _none_only_set.intersection(primary_key_identity): - return None - else: - if _none_only_set.issuperset(primary_key_identity): - return None + if _none_set.issuperset(primary_key_identity): + return None if ( self.key in state.dict @@ -1206,11 +1195,9 @@ class LazyLoader( key, self, loadopt, - ( - loadopt._generate_extra_criteria(context) - if loadopt._extra_criteria - else None - ), + loadopt._generate_extra_criteria(context) + if loadopt._extra_criteria + else None, ), key, ) @@ -1384,16 +1371,12 @@ class ImmediateLoader(PostLoader): adapter, populators, ): - if not context.compile_state.compile_options._enable_eagerloads: - return - ( effective_path, run_loader, execution_options, recursion_depth, ) = self._setup_for_recursion(context, path, loadopt, self.join_depth) - if not run_loader: # this will not emit SQL and will only emit for a many-to-one # "use get" load. the "_RELATED" part means it may return @@ -1435,6 +1418,7 @@ class ImmediateLoader(PostLoader): alternate_effective_path = path._truncate_recursive() extra_options = (new_opt,) else: + new_opt = None alternate_effective_path = path extra_options = () @@ -1688,11 +1672,9 @@ class SubqueryLoader(PostLoader): elif ltj > 2: middle = [ ( - ( - orm_util.AliasedClass(item[0]) - if not inspect(item[0]).is_aliased_class - else item[0].entity - ), + orm_util.AliasedClass(item[0]) + if not inspect(item[0]).is_aliased_class + else item[0].entity, item[1], ) for item in to_join[1:-1] @@ -1971,18 +1953,6 @@ class SubqueryLoader(PostLoader): adapter, populators, ): - if ( - loadopt - and context.compile_state.statement is not None - and context.compile_state.statement.is_dml - ): - util.warn_deprecated( - "The subqueryload loader option is not compatible with DML " - "statements such as INSERT, UPDATE. Only SELECT may be used." - "This warning will become an exception in a future release.", - "2.0", - ) - if context.refresh_state: return self._immediateload_create_row_processor( context, @@ -2148,22 +2118,13 @@ class JoinedLoader(AbstractRelationshipLoader): if not compile_state.compile_options._enable_eagerloads: return - elif ( - loadopt - and compile_state.statement is not None - and compile_state.statement.is_dml - ): - util.warn_deprecated( - "The joinedload loader option is not compatible with DML " - "statements such as INSERT, UPDATE. Only SELECT may be used." 
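The "use Session.get() to optimize instance loads" path above means a simple
many-to-one lazy load can be satisfied from the identity map with no SQL; a
sketch with hypothetical ``User`` / ``Address`` mappings::

    from sqlalchemy import select

    user = session.get(User, 1)  # User(id=1) now sits in the identity map
    addr = session.scalars(
        select(Address).where(Address.user_id == 1)
    ).first()

    parent = addr.user  # lazy load resolved from the identity map, no SELECT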
- "This warning will become an exception in a future release.", - "2.0", - ) elif self.uselist: compile_state.multi_row_eager_loaders = True path = path[self.parent_property] + with_polymorphic = None + user_defined_adapter = ( self._init_user_defined_eager_proc( loadopt, compile_state, compile_state.attributes @@ -2367,11 +2328,9 @@ class JoinedLoader(AbstractRelationshipLoader): to_adapt = orm_util.AliasedClass( self.mapper, - alias=( - alt_selectable._anonymous_fromclause(flat=True) - if alt_selectable is not None - else None - ), + alias=alt_selectable._anonymous_fromclause(flat=True) + if alt_selectable is not None + else None, flat=True, use_mapper_path=True, ) @@ -2541,13 +2500,13 @@ class JoinedLoader(AbstractRelationshipLoader): or query_entity.entity_zero.represents_outer_join or (chained_from_outerjoin and isinstance(towrap, sql.Join)), _left_memo=self.parent, - _right_memo=path[self.mapper], + _right_memo=self.mapper, _extra_criteria=extra_join_criteria, ) else: # all other cases are innerjoin=='nested' approach eagerjoin = self._splice_nested_inner_join( - path, path[-2], towrap, clauses, onclause, extra_join_criteria + path, towrap, clauses, onclause, extra_join_criteria ) compile_state.eager_joins[query_entity_key] = eagerjoin @@ -2581,177 +2540,93 @@ class JoinedLoader(AbstractRelationshipLoader): ) def _splice_nested_inner_join( - self, - path, - entity_we_want_to_splice_onto, - join_obj, - clauses, - onclause, - extra_criteria, - entity_inside_join_structure: Union[ - Mapper, None, Literal[False] - ] = False, - detected_existing_path: Optional[path_registry.PathRegistry] = None, + self, path, join_obj, clauses, onclause, extra_criteria, splicing=False ): # recursive fn to splice a nested join into an existing one. - # entity_inside_join_structure=False means this is the outermost call, - # and it should return a value. entity_inside_join_structure= - # indicates we've descended into a join and are looking at a FROM - # clause representing this mapper; if this is not - # entity_we_want_to_splice_onto then return None to end the recursive - # branch + # splicing=False means this is the outermost call, and it + # should return a value. 
splicing= is the recursive + # form, where it can return None to indicate the end of the recursion - assert entity_we_want_to_splice_onto is path[-2] - - if entity_inside_join_structure is False: + if splicing is False: + # first call is always handed a join object + # from the outside assert isinstance(join_obj, orm_util._ORMJoin) - - if isinstance(join_obj, sql.selectable.FromGrouping): - # FromGrouping - continue descending into the structure + elif isinstance(join_obj, sql.selectable.FromGrouping): return self._splice_nested_inner_join( path, - entity_we_want_to_splice_onto, join_obj.element, clauses, onclause, extra_criteria, - entity_inside_join_structure, + splicing, ) - elif isinstance(join_obj, orm_util._ORMJoin): - # _ORMJoin - continue descending into the structure + elif not isinstance(join_obj, orm_util._ORMJoin): + if path[-2].isa(splicing): + return orm_util._ORMJoin( + join_obj, + clauses.aliased_insp, + onclause, + isouter=False, + _left_memo=splicing, + _right_memo=path[-1].mapper, + _extra_criteria=extra_criteria, + ) + else: + return None - join_right_path = join_obj._right_memo - - # see if right side of join is viable + target_join = self._splice_nested_inner_join( + path, + join_obj.right, + clauses, + onclause, + extra_criteria, + join_obj._right_memo, + ) + if target_join is None: + right_splice = False target_join = self._splice_nested_inner_join( path, - entity_we_want_to_splice_onto, - join_obj.right, + join_obj.left, clauses, onclause, extra_criteria, - entity_inside_join_structure=( - join_right_path[-1].mapper - if join_right_path is not None - else None - ), + join_obj._left_memo, + ) + if target_join is None: + # should only return None when recursively called, + # e.g. splicing refers to a from obj + assert ( + splicing is not False + ), "assertion failed attempting to produce joined eager loads" + return None + else: + right_splice = True + + if right_splice: + # for a right splice, attempt to flatten out + # a JOIN b JOIN c JOIN .. to avoid needless + # parenthesis nesting + if not join_obj.isouter and not target_join.isouter: + eagerjoin = join_obj._splice_into_center(target_join) + else: + eagerjoin = orm_util._ORMJoin( + join_obj.left, + target_join, + join_obj.onclause, + isouter=join_obj.isouter, + _left_memo=join_obj._left_memo, + ) + else: + eagerjoin = orm_util._ORMJoin( + target_join, + join_obj.right, + join_obj.onclause, + isouter=join_obj.isouter, + _right_memo=join_obj._right_memo, ) - if target_join is not None: - # for a right splice, attempt to flatten out - # a JOIN b JOIN c JOIN .. 
to avoid needless - # parenthesis nesting - if not join_obj.isouter and not target_join.isouter: - eagerjoin = join_obj._splice_into_center(target_join) - else: - eagerjoin = orm_util._ORMJoin( - join_obj.left, - target_join, - join_obj.onclause, - isouter=join_obj.isouter, - _left_memo=join_obj._left_memo, - ) - - eagerjoin._target_adapter = target_join._target_adapter - return eagerjoin - - else: - # see if left side of join is viable - target_join = self._splice_nested_inner_join( - path, - entity_we_want_to_splice_onto, - join_obj.left, - clauses, - onclause, - extra_criteria, - entity_inside_join_structure=join_obj._left_memo, - detected_existing_path=join_right_path, - ) - - if target_join is not None: - eagerjoin = orm_util._ORMJoin( - target_join, - join_obj.right, - join_obj.onclause, - isouter=join_obj.isouter, - _right_memo=join_obj._right_memo, - ) - eagerjoin._target_adapter = target_join._target_adapter - return eagerjoin - - # neither side viable, return None, or fail if this was the top - # most call - if entity_inside_join_structure is False: - assert ( - False - ), "assertion failed attempting to produce joined eager loads" - return None - - # reached an endpoint (e.g. a table that's mapped, or an alias of that - # table). determine if we can use this endpoint to splice onto - - # is this the entity we want to splice onto in the first place? - if not entity_we_want_to_splice_onto.isa(entity_inside_join_structure): - return None - - # path check. if we know the path how this join endpoint got here, - # lets look at our path we are satisfying and see if we're in the - # wrong place. This is specifically for when our entity may - # appear more than once in the path, issue #11449 - # updated in issue #11965. - if detected_existing_path and len(detected_existing_path) > 2: - # this assertion is currently based on how this call is made, - # where given a join_obj, the call will have these parameters as - # entity_inside_join_structure=join_obj._left_memo - # and entity_inside_join_structure=join_obj._right_memo.mapper - assert detected_existing_path[-3] is entity_inside_join_structure - - # from that, see if the path we are targeting matches the - # "existing" path of this join all the way up to the midpoint - # of this join object (e.g. the relationship). - # if not, then this is not our target - # - # a test condition where this test is false looks like: - # - # desired splice: Node->kind->Kind - # path of desired splice: NodeGroup->nodes->Node->kind - # path we've located: NodeGroup->nodes->Node->common_node->Node - # - # above, because we want to splice kind->Kind onto - # NodeGroup->nodes->Node, this is not our path because it actually - # goes more steps than we want into self-referential - # ->common_node->Node - # - # a test condition where this test is true looks like: - # - # desired splice: B->c2s->C2 - # path of desired splice: A->bs->B->c2s - # path we've located: A->bs->B->c1s->C1 - # - # above, we want to splice c2s->C2 onto B, and the located path - # shows that the join ends with B->c1s->C1. so we will - # add another join onto that, which would create a "branch" that - # we might represent in a pseudopath as: - # - # B->c1s->C1 - # ->c2s->C2 - # - # i.e. 
A JOIN B ON JOIN C1 ON - # JOIN C2 ON - # - - if detected_existing_path[0:-2] != path.path[0:-1]: - return None - - return orm_util._ORMJoin( - join_obj, - clauses.aliased_insp, - onclause, - isouter=False, - _left_memo=entity_inside_join_structure, - _right_memo=path[path[-1].mapper], - _extra_criteria=extra_criteria, - ) + eagerjoin._target_adapter = target_join._target_adapter + return eagerjoin def _create_eager_adapter(self, context, result, adapter, path, loadopt): compile_state = context.compile_state @@ -2800,10 +2675,6 @@ class JoinedLoader(AbstractRelationshipLoader): adapter, populators, ): - - if not context.compile_state.compile_options._enable_eagerloads: - return - if not self.parent.class_manager[self.key].impl.supports_population: raise sa_exc.InvalidRequestError( "'%s' does not support object " @@ -3083,9 +2954,6 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): if not run_loader: return - if not context.compile_state.compile_options._enable_eagerloads: - return - if not self.parent.class_manager[self.key].impl.supports_population: raise sa_exc.InvalidRequestError( "'%s' does not support object " @@ -3243,7 +3111,7 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): orig_query = context.compile_state.select_statement # the actual statement that was requested is this one: - # context_query = context.user_passed_query + # context_query = context.query # # that's not the cached one, however. So while it is of the identical # structure, if it has entities like AliasedInsp, which we get from @@ -3267,11 +3135,11 @@ class SelectInLoader(PostLoader, util.MemoizedSlots): effective_path = path[self.parent_property] - if orig_query is context.user_passed_query: + if orig_query is context.query: new_options = orig_query._with_options else: cached_options = orig_query._with_options - uncached_options = context.user_passed_query._with_options + uncached_options = context.query._with_options # propagate compile state options from the original query, # updating their "extra_criteria" as necessary. diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/strategy_options.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/strategy_options.py index 17bbe35..6c81e8f 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/strategy_options.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/strategy_options.py @@ -1,12 +1,13 @@ -# orm/strategy_options.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: allow-untyped-defs, allow-untyped-calls -""" """ +""" + +""" from __future__ import annotations @@ -96,7 +97,6 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): attr: _AttrType, alias: Optional[_FromClauseArgument] = None, _is_chain: bool = False, - _propagate_to_loaders: bool = False, ) -> Self: r"""Indicate that the given attribute should be eagerly loaded from columns stated manually in the query. 
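``SelectInLoader`` above implements the ``selectinload()`` strategy; basic
usage with a hypothetical ``User.addresses`` collection::

    from sqlalchemy import select
    from sqlalchemy.orm import selectinload

    # one extra SELECT ... WHERE user_id IN (...) batched over the parents
    stmt = select(User).options(selectinload(User.addresses))
    users = session.scalars(stmt).all()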
@@ -107,7 +107,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): The option is used in conjunction with an explicit join that loads the desired rows, i.e.:: - sess.query(Order).join(Order.user).options(contains_eager(Order.user)) + sess.query(Order).\ + join(Order.user).\ + options(contains_eager(Order.user)) The above query would join from the ``Order`` entity to its related ``User`` entity, and the returned ``Order`` objects would have the @@ -118,9 +120,11 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): :ref:`orm_queryguide_populate_existing` execution option assuming the primary collection of parent objects may already have been loaded:: - sess.query(User).join(User.addresses).filter( - Address.email_address.like("%@aol.com") - ).options(contains_eager(User.addresses)).populate_existing() + sess.query(User).\ + join(User.addresses).\ + filter(Address.email_address.like('%@aol.com')).\ + options(contains_eager(User.addresses)).\ + populate_existing() See the section :ref:`contains_eager` for complete usage details. @@ -155,7 +159,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): cloned = self._set_relationship_strategy( attr, {"lazy": "joined"}, - propagate_to_loaders=_propagate_to_loaders, + propagate_to_loaders=False, opts={"eager_from_alias": coerced_alias}, _reconcile_to_other=True if _is_chain else None, ) @@ -186,18 +190,10 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): the lead entity can be specifically referred to using the :class:`_orm.Load` constructor:: - stmt = ( - select(User, Address) - .join(User.addresses) - .options( - Load(User).load_only(User.name, User.fullname), - Load(Address).load_only(Address.email_address), - ) - ) - - When used together with the - :ref:`populate_existing ` - execution option only the attributes listed will be refreshed. + stmt = select(User, Address).join(User.addresses).options( + Load(User).load_only(User.name, User.fullname), + Load(Address).load_only(Address.email_address) + ) :param \*attrs: Attributes to be loaded, all others will be deferred. 
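
For quick reference alongside the ``contains_eager()`` docstring touched above, a minimal 2.0-style sketch of the same pattern; the ``User``/``Address`` models and the ``session`` are assumed to exist already and are illustrative only::

    from sqlalchemy import select
    from sqlalchemy.orm import contains_eager

    # Join to the collection ourselves and tell the ORM to populate
    # User.addresses from those same rows instead of lazy loading.
    stmt = (
        select(User)
        .join(User.addresses)
        .where(Address.email_address.like("%@aol.com"))
        .options(contains_eager(User.addresses))
        .execution_options(populate_existing=True)
    )

    # .unique() is required when a joined collection is part of the result
    for user in session.scalars(stmt).unique():
        print(user.name, [a.email_address for a in user.addresses])
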
@@ -222,7 +218,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): """ cloned = self._set_column_strategy( - _expand_column_strategy_attrs(attrs), + attrs, {"deferred": False, "instrument": True}, ) @@ -250,25 +246,28 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): examples:: # joined-load the "orders" collection on "User" - select(User).options(joinedload(User.orders)) + query(User).options(joinedload(User.orders)) # joined-load Order.items and then Item.keywords - select(Order).options(joinedload(Order.items).joinedload(Item.keywords)) + query(Order).options( + joinedload(Order.items).joinedload(Item.keywords)) # lazily load Order.items, but when Items are loaded, # joined-load the keywords collection - select(Order).options(lazyload(Order.items).joinedload(Item.keywords)) + query(Order).options( + lazyload(Order.items).joinedload(Item.keywords)) :param innerjoin: if ``True``, indicates that the joined eager load should use an inner join instead of the default of left outer join:: - select(Order).options(joinedload(Order.user, innerjoin=True)) + query(Order).options(joinedload(Order.user, innerjoin=True)) In order to chain multiple eager joins together where some may be OUTER and others INNER, right-nested joins are used to link them:: - select(A).options( - joinedload(A.bs, innerjoin=False).joinedload(B.cs, innerjoin=True) + query(A).options( + joinedload(A.bs, innerjoin=False). + joinedload(B.cs, innerjoin=True) ) The above query, linking A.bs via "outer" join and B.cs via "inner" @@ -283,7 +282,10 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): will render as LEFT OUTER JOIN. For example, supposing ``A.bs`` is an outerjoin:: - select(A).options(joinedload(A.bs).joinedload(B.cs, innerjoin="unnested")) + query(A).options( + joinedload(A.bs). + joinedload(B.cs, innerjoin="unnested") + ) The above join will render as "a LEFT OUTER JOIN b LEFT OUTER JOIN c", rather than as "a LEFT OUTER JOIN (b JOIN c)". @@ -313,15 +315,13 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): :ref:`joined_eager_loading` - """ # noqa: E501 + """ loader = self._set_relationship_strategy( attr, {"lazy": "joined"}, - opts=( - {"innerjoin": innerjoin} - if innerjoin is not None - else util.EMPTY_DICT - ), + opts={"innerjoin": innerjoin} + if innerjoin is not None + else util.EMPTY_DICT, ) return loader @@ -335,16 +335,17 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): examples:: # subquery-load the "orders" collection on "User" - select(User).options(subqueryload(User.orders)) + query(User).options(subqueryload(User.orders)) # subquery-load Order.items and then Item.keywords - select(Order).options( - subqueryload(Order.items).subqueryload(Item.keywords) - ) + query(Order).options( + subqueryload(Order.items).subqueryload(Item.keywords)) # lazily load Order.items, but when Items are loaded, # subquery-load the keywords collection - select(Order).options(lazyload(Order.items).subqueryload(Item.keywords)) + query(Order).options( + lazyload(Order.items).subqueryload(Item.keywords)) + .. 
seealso:: @@ -369,16 +370,16 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): examples:: # selectin-load the "orders" collection on "User" - select(User).options(selectinload(User.orders)) + query(User).options(selectinload(User.orders)) # selectin-load Order.items and then Item.keywords - select(Order).options( - selectinload(Order.items).selectinload(Item.keywords) - ) + query(Order).options( + selectinload(Order.items).selectinload(Item.keywords)) # lazily load Order.items, but when Items are loaded, # selectin-load the keywords collection - select(Order).options(lazyload(Order.items).selectinload(Item.keywords)) + query(Order).options( + lazyload(Order.items).selectinload(Item.keywords)) :param recursion_depth: optional int; when set to a positive integer in conjunction with a self-referential relationship, @@ -489,10 +490,10 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): :func:`_orm.noload` applies to :func:`_orm.relationship` attributes only. - .. legacy:: The :func:`_orm.noload` option is **legacy**. As it - forces collections to be empty, which invariably leads to - non-intuitive and difficult to predict results. There are no - legitimate uses for this option in modern SQLAlchemy. + .. note:: Setting this loading strategy as the default strategy + for a relationship using the :paramref:`.orm.relationship.lazy` + parameter may cause issues with flushes, such if a delete operation + needs to load related objects and instead ``None`` was returned. .. seealso:: @@ -554,20 +555,17 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): element of an element:: session.query(MyClass).options( - defaultload(MyClass.someattribute).joinedload( - MyOtherClass.someotherattribute - ) + defaultload(MyClass.someattribute). + joinedload(MyOtherClass.someotherattribute) ) :func:`.defaultload` is also useful for setting column-level options on a related class, namely that of :func:`.defer` and :func:`.undefer`:: - session.scalars( - select(MyClass).options( - defaultload(MyClass.someattribute) - .defer("some_column") - .undefer("some_other_column") - ) + session.query(MyClass).options( + defaultload(MyClass.someattribute). + defer("some_column"). + undefer("some_other_column") ) .. 
seealso:: @@ -591,7 +589,8 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): from sqlalchemy.orm import defer session.query(MyClass).options( - defer(MyClass.attribute_one), defer(MyClass.attribute_two) + defer(MyClass.attribute_one), + defer(MyClass.attribute_two) ) To specify a deferred load of an attribute on a related class, @@ -607,11 +606,11 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): at once using :meth:`_orm.Load.options`:: - select(MyClass).options( + session.query(MyClass).options( defaultload(MyClass.someattr).options( defer(RelatedClass.some_column), defer(RelatedClass.some_other_column), - defer(RelatedClass.another_column), + defer(RelatedClass.another_column) ) ) @@ -636,9 +635,7 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): strategy = {"deferred": True, "instrument": True} if raiseload: strategy["raiseload"] = True - return self._set_column_strategy( - _expand_column_strategy_attrs((key,)), strategy - ) + return self._set_column_strategy((key,), strategy) def undefer(self, key: _AttrType) -> Self: r"""Indicate that the given column-oriented attribute should be @@ -659,10 +656,12 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): ) # undefer all columns specific to a single class using Load + * - session.query(MyClass, MyOtherClass).options(Load(MyClass).undefer("*")) + session.query(MyClass, MyOtherClass).options( + Load(MyClass).undefer("*")) # undefer a column on a related object - select(MyClass).options(defaultload(MyClass.items).undefer(MyClass.text)) + session.query(MyClass).options( + defaultload(MyClass.items).undefer(MyClass.text)) :param key: Attribute to be undeferred. @@ -675,10 +674,9 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): :func:`_orm.undefer_group` - """ # noqa: E501 + """ return self._set_column_strategy( - _expand_column_strategy_attrs((key,)), - {"deferred": False, "instrument": True}, + (key,), {"deferred": False, "instrument": True} ) def undefer_group(self, name: str) -> Self: @@ -696,9 +694,8 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): spelled out using relationship loader options, such as :func:`_orm.defaultload`:: - select(MyClass).options( - defaultload("someattr").undefer_group("large_attrs") - ) + session.query(MyClass).options( + defaultload("someattr").undefer_group("large_attrs")) .. seealso:: @@ -779,10 +776,12 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): return self @overload - def _coerce_strat(self, strategy: _StrategySpec) -> _StrategyKey: ... + def _coerce_strat(self, strategy: _StrategySpec) -> _StrategyKey: + ... @overload - def _coerce_strat(self, strategy: Literal[None]) -> None: ... + def _coerce_strat(self, strategy: Literal[None]) -> None: + ... def _coerce_strat( self, strategy: Optional[_StrategySpec] @@ -1034,8 +1033,6 @@ class Load(_AbstractLoad): def _adapt_cached_option_to_uncached_option( self, context: QueryContext, uncached_opt: ORMOption ) -> ORMOption: - if uncached_opt is self: - return self return self._adjust_for_extra_criteria(context) def _prepend_path(self, path: PathRegistry) -> Load: @@ -1051,51 +1048,47 @@ class Load(_AbstractLoad): returning a new instance of this ``Load`` object. 
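
To tie the ``defer()`` / ``undefer()`` / ``defaultload()`` examples above together, a small sketch of combining them on one statement; ``MyClass``, ``RelatedClass`` and the specific attribute names are assumptions for illustration only::

    from sqlalchemy import select
    from sqlalchemy.orm import defaultload, defer, undefer

    stmt = select(MyClass).options(
        # don't load a wide column on the lead entity up front
        defer(MyClass.big_payload),
        # keep the relationship's configured lazy strategy, but adjust
        # which columns load on the related class when it is fetched
        defaultload(MyClass.someattr).options(
            defer(RelatedClass.some_column),
            undefer(RelatedClass.some_other_column),
        ),
    )
    objects = session.scalars(stmt).all()
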
""" + orig_query = context.compile_state.select_statement + + orig_cache_key: Optional[CacheKey] = None + replacement_cache_key: Optional[CacheKey] = None + found_crit = False + + def process(opt: _LoadElement) -> _LoadElement: + nonlocal orig_cache_key, replacement_cache_key, found_crit + + found_crit = True + + if orig_cache_key is None or replacement_cache_key is None: + orig_cache_key = orig_query._generate_cache_key() + replacement_cache_key = context.query._generate_cache_key() + + assert orig_cache_key is not None + assert replacement_cache_key is not None + + opt._extra_criteria = tuple( + replacement_cache_key._apply_params_to_element( + orig_cache_key, crit + ) + for crit in opt._extra_criteria + ) + + return opt # avoid generating cache keys for the queries if we don't # actually have any extra_criteria options, which is the # common case - for value in self.context: - if value._extra_criteria: - break - else: - return self - - replacement_cache_key = context.user_passed_query._generate_cache_key() - - if replacement_cache_key is None: - return self - - orig_query = context.compile_state.select_statement - orig_cache_key = orig_query._generate_cache_key() - assert orig_cache_key is not None - - def process( - opt: _LoadElement, - replacement_cache_key: CacheKey, - orig_cache_key: CacheKey, - ) -> _LoadElement: - cloned_opt = opt._clone() - - cloned_opt._extra_criteria = tuple( - replacement_cache_key._apply_params_to_element( - orig_cache_key, crit - ) - for crit in cloned_opt._extra_criteria - ) - - return cloned_opt - - cloned = self._clone() - cloned.context = tuple( - ( - process(value, replacement_cache_key, orig_cache_key) - if value._extra_criteria - else value - ) + new_context = tuple( + process(value._clone()) if value._extra_criteria else value for value in self.context ) - return cloned + + if found_crit: + cloned = self._clone() + cloned.context = new_context + return cloned + else: + return self def _reconcile_query_entities_with_us(self, mapper_entities, raiseerr): """called at process time to allow adjustment of the root @@ -1104,6 +1097,7 @@ class Load(_AbstractLoad): """ path = self.path + ezero = None for ent in mapper_entities: ezero = ent.entity_zero if ezero and orm_util._entity_corresponds_to( @@ -1126,20 +1120,7 @@ class Load(_AbstractLoad): mapper_entities, raiseerr ) - # if the context has a current path, this is a lazy load - has_current_path = bool(compile_state.compile_options._current_path) - for loader in self.context: - # issue #11292 - # historically, propagate_to_loaders was only considered at - # object loading time, whether or not to carry along options - # onto an object's loaded state where it would be used by lazyload. - # however, the defaultload() option needs to propagate in case - # its sub-options propagate_to_loaders, but its sub-options - # that dont propagate should not be applied for lazy loaders. 
- # so we check again - if has_current_path and not loader.propagate_to_loaders: - continue loader.process_compile_state( self, compile_state, @@ -1197,11 +1178,13 @@ class Load(_AbstractLoad): query = session.query(Author) query = query.options( - joinedload(Author.book).options( - load_only(Book.summary, Book.excerpt), - joinedload(Book.citations).options(joinedload(Citation.author)), - ) - ) + joinedload(Author.book).options( + load_only(Book.summary, Book.excerpt), + joinedload(Book.citations).options( + joinedload(Citation.author) + ) + ) + ) :param \*opts: A series of loader option objects (ultimately :class:`_orm.Load` objects) which should be applied to the path @@ -1628,10 +1611,9 @@ class _LoadElement( f"Mapped class {path[0]} does not apply to any of the " f"root entities in this query, e.g. " f"""{ - ", ".join( - str(x.entity_zero) - for x in mapper_entities if x.entity_zero - )}. Please """ + ", ".join(str(x.entity_zero) + for x in mapper_entities if x.entity_zero + )}. Please """ "specify the full path " "from one of the root entities to the target " "attribute. " @@ -1645,17 +1627,13 @@ class _LoadElement( loads, and adjusts the given path to be relative to the current_path. - E.g. given a loader path and current path: - - .. sourcecode:: text + E.g. given a loader path and current path:: lp: User -> orders -> Order -> items -> Item -> keywords -> Keyword cp: User -> orders -> Order -> items - The adjusted path would be: - - .. sourcecode:: text + The adjusted path would be:: Item -> keywords -> Keyword @@ -2101,9 +2079,9 @@ class _AttributeStrategyLoad(_LoadElement): d["_extra_criteria"] = () if self._path_with_polymorphic_path: - d["_path_with_polymorphic_path"] = ( - self._path_with_polymorphic_path.serialize() - ) + d[ + "_path_with_polymorphic_path" + ] = self._path_with_polymorphic_path.serialize() if self._of_type: if self._of_type.is_aliased_class: @@ -2136,11 +2114,11 @@ class _TokenStrategyLoad(_LoadElement): e.g.:: - raiseload("*") - Load(User).lazyload("*") - defer("*") + raiseload('*') + Load(User).lazyload('*') + defer('*') load_only(User.name, User.email) # will create a defer('*') - joinedload(User.addresses).raiseload("*") + joinedload(User.addresses).raiseload('*') """ @@ -2395,23 +2373,6 @@ See :func:`_orm.{fn.__name__}` for usage examples. return fn -def _expand_column_strategy_attrs( - attrs: Tuple[_AttrType, ...], -) -> Tuple[_AttrType, ...]: - return cast( - "Tuple[_AttrType, ...]", - tuple( - a - for attr in attrs - for a in ( - cast("QueryableAttribute[Any]", attr)._column_strategy_attrs() - if hasattr(attr, "_column_strategy_attrs") - else (attr,) - ) - ), - ) - - # standalone functions follow. docstrings are filled in # by the ``@loader_unbound_fn`` decorator. @@ -2425,7 +2386,6 @@ def contains_eager(*keys: _AttrType, **kw: Any) -> _AbstractLoad: def load_only(*attrs: _AttrType, raiseload: bool = False) -> _AbstractLoad: # TODO: attrs against different classes. 
we likely have to # add some extra state to Load of some kind - attrs = _expand_column_strategy_attrs(attrs) _, lead_element, _ = _parse_attr_argument(attrs[0]) return Load(lead_element).load_only(*attrs, raiseload=raiseload) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/sync.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/sync.py index 8f85a41..036c26d 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/sync.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/sync.py @@ -1,5 +1,5 @@ # orm/sync.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -86,9 +86,8 @@ def clear(dest, dest_mapper, synchronize_pairs): not in orm_util._none_set ): raise AssertionError( - f"Dependency rule on column '{l}' " - "tried to blank-out primary key " - f"column '{r}' on instance '{orm_util.state_str(dest)}'" + "Dependency rule tried to blank-out primary key " + "column '%s' on instance '%s'" % (r, orm_util.state_str(dest)) ) try: dest_mapper._set_state_attr_by_column(dest, dest.dict, r, None) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/unitofwork.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/unitofwork.py index 80897f2..20fe022 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/unitofwork.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/unitofwork.py @@ -1,5 +1,5 @@ # orm/unitofwork.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/util.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/util.py index ca607af..ea2f1a1 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/util.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/util.py @@ -1,5 +1,5 @@ # orm/util.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -35,7 +35,6 @@ import weakref from . import attributes # noqa from . import exc -from . 
import exc as orm_exc from ._typing import _O from ._typing import insp_is_aliased_class from ._typing import insp_is_mapper @@ -43,7 +42,6 @@ from ._typing import prop_is_relationship from .base import _class_to_mapper as _class_to_mapper from .base import _MappedAnnotationBase from .base import _never_set as _never_set # noqa: F401 -from .base import _none_only_set as _none_only_set # noqa: F401 from .base import _none_set as _none_set # noqa: F401 from .base import attribute_str as attribute_str # noqa: F401 from .base import class_mapper as class_mapper @@ -87,12 +85,14 @@ from ..sql.elements import KeyedColumnElement from ..sql.selectable import FromClause from ..util.langhelpers import MemoizedSlots from ..util.typing import de_stringify_annotation as _de_stringify_annotation +from ..util.typing import ( + de_stringify_union_elements as _de_stringify_union_elements, +) from ..util.typing import eval_name_only as _eval_name_only -from ..util.typing import fixup_container_fwd_refs -from ..util.typing import get_origin from ..util.typing import is_origin_of_cls from ..util.typing import Literal from ..util.typing import Protocol +from ..util.typing import typing_get_origin if typing.TYPE_CHECKING: from ._typing import _EntityType @@ -121,6 +121,7 @@ if typing.TYPE_CHECKING: from ..sql.selectable import Selectable from ..sql.visitors import anon_map from ..util.typing import _AnnotationScanType + from ..util.typing import ArgsTypeProcotol _T = TypeVar("_T", bound=Any) @@ -137,6 +138,7 @@ all_cascades = frozenset( ) ) + _de_stringify_partial = functools.partial( functools.partial, locals_=util.immutabledict( @@ -161,7 +163,8 @@ class _DeStringifyAnnotation(Protocol): *, str_cleanup_fn: Optional[Callable[[str, str], str]] = None, include_generic: bool = False, - ) -> Type[Any]: ... + ) -> Type[Any]: + ... de_stringify_annotation = cast( @@ -169,8 +172,27 @@ de_stringify_annotation = cast( ) +class _DeStringifyUnionElements(Protocol): + def __call__( + self, + cls: Type[Any], + annotation: ArgsTypeProcotol, + originating_module: str, + *, + str_cleanup_fn: Optional[Callable[[str, str], str]] = None, + ) -> Type[Any]: + ... + + +de_stringify_union_elements = cast( + _DeStringifyUnionElements, + _de_stringify_partial(_de_stringify_union_elements), +) + + class _EvalNameOnly(Protocol): - def __call__(self, name: str, module_name: str) -> Any: ... + def __call__(self, name: str, module_name: str) -> Any: + ... eval_name_only = cast(_EvalNameOnly, _de_stringify_partial(_eval_name_only)) @@ -228,7 +250,7 @@ class CascadeOptions(FrozenSet[str]): values.clear() values.discard("all") - self = super().__new__(cls, values) + self = super().__new__(cls, values) # type: ignore self.save_update = "save-update" in values self.delete = "delete" in values self.refresh_expire = "refresh-expire" in values @@ -237,7 +259,9 @@ class CascadeOptions(FrozenSet[str]): self.delete_orphan = "delete-orphan" in values if self.delete_orphan and not self.delete: - util.warn("The 'delete-orphan' cascade option requires 'delete'.") + util.warn( + "The 'delete-orphan' cascade " "option requires 'delete'." + ) return self def __repr__(self): @@ -454,7 +478,9 @@ def identity_key( E.g.:: - >>> row = engine.execute(text("select * from table where a=1 and b=2")).first() + >>> row = engine.execute(\ + text("select * from table where a=1 and b=2")\ + ).first() >>> identity_key(MyClass, row=row) (, (1, 2), None) @@ -465,7 +491,7 @@ def identity_key( .. 
versionadded:: 1.2 added identity_token - """ # noqa: E501 + """ if class_ is not None: mapper = class_mapper(class_) if row is None: @@ -643,9 +669,9 @@ class AliasedClass( # find all pairs of users with the same name user_alias = aliased(User) - session.query(User, user_alias).join( - (user_alias, User.id > user_alias.id) - ).filter(User.name == user_alias.name) + session.query(User, user_alias).\ + join((user_alias, User.id > user_alias.id)).\ + filter(User.name == user_alias.name) :class:`.AliasedClass` is also capable of mapping an existing mapped class to an entirely new selectable, provided this selectable is column- @@ -669,7 +695,6 @@ class AliasedClass( using :func:`_sa.inspect`:: from sqlalchemy import inspect - my_alias = aliased(MyClass) insp = inspect(my_alias) @@ -730,16 +755,12 @@ class AliasedClass( insp, alias, name, - ( - with_polymorphic_mappers - if with_polymorphic_mappers - else mapper.with_polymorphic_mappers - ), - ( - with_polymorphic_discriminator - if with_polymorphic_discriminator is not None - else mapper.polymorphic_on - ), + with_polymorphic_mappers + if with_polymorphic_mappers + else mapper.with_polymorphic_mappers, + with_polymorphic_discriminator + if with_polymorphic_discriminator is not None + else mapper.polymorphic_on, base_alias, use_mapper_path, adapt_on_names, @@ -950,9 +971,9 @@ class AliasedInsp( self._weak_entity = weakref.ref(entity) self.mapper = mapper - self.selectable = self.persist_selectable = self.local_table = ( - selectable - ) + self.selectable = ( + self.persist_selectable + ) = self.local_table = selectable self.name = name self.polymorphic_on = polymorphic_on self._base_alias = weakref.ref(_base_alias or self) @@ -1047,7 +1068,6 @@ class AliasedInsp( aliased: bool = False, innerjoin: bool = False, adapt_on_names: bool = False, - name: Optional[str] = None, _use_mapper_path: bool = False, ) -> AliasedClass[_O]: primary_mapper = _class_to_mapper(base) @@ -1068,7 +1088,6 @@ class AliasedInsp( return AliasedClass( base, selectable, - name=name, with_polymorphic_mappers=mappers, adapt_on_names=adapt_on_names, with_polymorphic_discriminator=polymorphic_on, @@ -1210,7 +1229,8 @@ class AliasedInsp( self, obj: _CE, key: Optional[str] = None, - ) -> _CE: ... + ) -> _CE: + ... else: _orm_adapt_element = _adapt_element @@ -1360,10 +1380,7 @@ class LoaderCriteriaOption(CriteriaOption): def __init__( self, entity_or_base: _EntityType[Any], - where_criteria: Union[ - _ColumnExpressionArgument[bool], - Callable[[Any], _ColumnExpressionArgument[bool]], - ], + where_criteria: _ColumnExpressionArgument[bool], loader_only: bool = False, include_aliases: bool = False, propagate_to_loaders: bool = True, @@ -1522,7 +1539,7 @@ GenericAlias = type(List[Any]) def _inspect_generic_alias( class_: Type[_O], ) -> Optional[Mapper[_O]]: - origin = cast("Type[_O]", get_origin(class_)) + origin = cast("Type[_O]", typing_get_origin(class_)) return _inspect_mc(origin) @@ -1566,7 +1583,7 @@ class Bundle( _propagate_attrs: _PropagateAttrsType = util.immutabledict() - proxy_set = util.EMPTY_SET + proxy_set = util.EMPTY_SET # type: ignore exprs: List[_ColumnsClauseElement] @@ -1579,7 +1596,8 @@ class Bundle( bn = Bundle("mybundle", MyClass.x, MyClass.y) - for row in session.query(bn).filter(bn.c.x == 5).filter(bn.c.y == 4): + for row in session.query(bn).filter( + bn.c.x == 5).filter(bn.c.y == 4): print(row.mybundle.x, row.mybundle.y) :param name: name of the bundle. 
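
A compact sketch of the ``Bundle`` usage described above, written against ``select()``; a mapped ``MyClass`` with columns ``x`` and ``y`` is assumed::

    from sqlalchemy import select
    from sqlalchemy.orm import Bundle

    bn = Bundle("mybundle", MyClass.x, MyClass.y)

    # the bundle's columns are addressable via bn.c for criteria
    stmt = select(bn).where(bn.c.x == 5)

    # each result row carries the bundle as a single named element
    for row in session.execute(stmt):
        print(row.mybundle.x, row.mybundle.y)
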
@@ -1588,7 +1606,7 @@ class Bundle( can be returned as a "single entity" outside of any enclosing tuple in the same manner as a mapped entity. - """ # noqa: E501 + """ self.name = self._label = name coerced_exprs = [ coercions.expect( @@ -1643,24 +1661,24 @@ class Bundle( Nesting of bundles is also supported:: - b1 = Bundle( - "b1", - Bundle("b2", MyClass.a, MyClass.b), - Bundle("b3", MyClass.x, MyClass.y), - ) + b1 = Bundle("b1", + Bundle('b2', MyClass.a, MyClass.b), + Bundle('b3', MyClass.x, MyClass.y) + ) - q = sess.query(b1).filter(b1.c.b2.c.a == 5).filter(b1.c.b3.c.y == 9) + q = sess.query(b1).filter( + b1.c.b2.c.a == 5).filter(b1.c.b3.c.y == 9) .. seealso:: :attr:`.Bundle.c` - """ # noqa: E501 + """ c: ReadOnlyColumnCollection[str, KeyedColumnElement[Any]] """An alias for :attr:`.Bundle.columns`.""" - def _clone(self, **kw): + def _clone(self): cloned = self.__class__.__new__(self.__class__) cloned.__dict__.update(self.__dict__) return cloned @@ -1721,24 +1739,25 @@ class Bundle( from sqlalchemy.orm import Bundle - class DictBundle(Bundle): def create_row_processor(self, query, procs, labels): - "Override create_row_processor to return values as dictionaries" + 'Override create_row_processor to return values as + dictionaries' def proc(row): - return dict(zip(labels, (proc(row) for proc in procs))) - + return dict( + zip(labels, (proc(row) for proc in procs)) + ) return proc A result from the above :class:`_orm.Bundle` will return dictionary values:: - bn = DictBundle("mybundle", MyClass.data1, MyClass.data2) - for row in session.execute(select(bn)).where(bn.c.data1 == "d1"): - print(row.mybundle["data1"], row.mybundle["data2"]) + bn = DictBundle('mybundle', MyClass.data1, MyClass.data2) + for row in session.execute(select(bn)).where(bn.c.data1 == 'd1'): + print(row.mybundle['data1'], row.mybundle['data2']) - """ # noqa: E501 + """ keyed_tuple = result_tuple(labels, [() for l in labels]) def proc(row: Row[Any]) -> Any: @@ -1921,7 +1940,7 @@ class _ORMJoin(expression.Join): self.onclause, isouter=self.isouter, _left_memo=self._left_memo, - _right_memo=other._left_memo._path_registry, + _right_memo=other._left_memo, ) return _ORMJoin( @@ -1964,6 +1983,7 @@ def with_parent( stmt = select(Address).where(with_parent(some_user, User.addresses)) + The SQL rendered is the same as that rendered when a lazy loader would fire off from the given parent on that attribute, meaning that the appropriate state is taken from the parent object in @@ -1976,7 +1996,9 @@ def with_parent( a1 = aliased(Address) a2 = aliased(Address) - stmt = select(a1, a2).where(with_parent(u1, User.addresses.of_type(a2))) + stmt = select(a1, a2).where( + with_parent(u1, User.addresses.of_type(a2)) + ) The above use is equivalent to using the :func:`_orm.with_parent.from_entity` argument:: @@ -2001,7 +2023,7 @@ def with_parent( .. 
versionadded:: 1.2 - """ # noqa: E501 + """ prop_t: RelationshipProperty[Any] if isinstance(prop, str): @@ -2095,13 +2117,14 @@ def _entity_corresponds_to_use_path_impl( someoption(A).someoption(C.d) # -> fn(A, C) -> False a1 = aliased(A) - someoption(a1).someoption(A.b) # -> fn(a1, A) -> False - someoption(a1).someoption(a1.b) # -> fn(a1, a1) -> True + someoption(a1).someoption(A.b) # -> fn(a1, A) -> False + someoption(a1).someoption(a1.b) # -> fn(a1, a1) -> True wp = with_polymorphic(A, [A1, A2]) someoption(wp).someoption(A1.foo) # -> fn(wp, A1) -> False someoption(wp).someoption(wp.A1.foo) # -> fn(wp, wp.A1) -> True + """ if insp_is_aliased_class(given): return ( @@ -2128,7 +2151,7 @@ def _entity_isa(given: _InternalEntityType[Any], mapper: Mapper[Any]) -> bool: mapper ) elif given.with_polymorphic_mappers: - return mapper in given.with_polymorphic_mappers or given.isa(mapper) + return mapper in given.with_polymorphic_mappers else: return given.isa(mapper) @@ -2210,7 +2233,7 @@ def _cleanup_mapped_str_annotation( inner: Optional[Match[str]] - mm = re.match(r"^([^ \|]+?)\[(.+)\]$", annotation) + mm = re.match(r"^(.+?)\[(.+)\]$", annotation) if not mm: return annotation @@ -2250,7 +2273,7 @@ def _cleanup_mapped_str_annotation( while True: stack.append(real_symbol if mm is inner else inner.group(1)) g2 = inner.group(2) - inner = re.match(r"^([^ \|]+?)\[(.+)\]$", g2) + inner = re.match(r"^(.+?)\[(.+)\]$", g2) if inner is None: stack.append(g2) break @@ -2272,10 +2295,8 @@ def _cleanup_mapped_str_annotation( # ['Mapped', "'Optional[Dict[str, str]]'"] not re.match(r"""^["'].*["']$""", stack[-1]) # avoid further generics like Dict[] such as - # ['Mapped', 'dict[str, str] | None'], - # ['Mapped', 'list[int] | list[str]'], - # ['Mapped', 'Union[list[int], list[str]]'], - and not re.search(r"[\[\]]", stack[-1]) + # ['Mapped', 'dict[str, str] | None'] + and not re.match(r".*\[.*\]", stack[-1]) ): stripchars = "\"' " stack[-1] = ", ".join( @@ -2297,7 +2318,7 @@ def _extract_mapped_subtype( is_dataclass_field: bool, expect_mapped: bool = True, raiseerr: bool = True, -) -> Optional[Tuple[Union[_AnnotationScanType, str], Optional[type]]]: +) -> Optional[Tuple[Union[type, str], Optional[type]]]: """given an annotation, figure out if it's ``Mapped[something]`` and if so, return the ``something`` part. @@ -2307,7 +2328,7 @@ def _extract_mapped_subtype( if raw_annotation is None: if required: - raise orm_exc.MappedAnnotationError( + raise sa_exc.ArgumentError( f"Python typing annotation is required for attribute " f'"{cls.__name__}.{key}" when primary argument(s) for ' f'"{attr_cls.__name__}" construct are None or not present' @@ -2315,11 +2336,6 @@ def _extract_mapped_subtype( return None try: - # destringify the "outside" of the annotation. note we are not - # adding include_generic so it will *not* dig into generic contents, - # which will remain as ForwardRef or plain str under future annotations - # mode. The full destringify happens later when mapped_column goes - # to do a full lookup in the registry type_annotations_map. annotated = de_stringify_annotation( cls, raw_annotation, @@ -2327,14 +2343,14 @@ def _extract_mapped_subtype( str_cleanup_fn=_cleanup_mapped_str_annotation, ) except _CleanupError as ce: - raise orm_exc.MappedAnnotationError( + raise sa_exc.ArgumentError( f"Could not interpret annotation {raw_annotation}. " "Check that it uses names that are correctly imported at the " "module level. See chained stack trace for more hints." 
) from ce except NameError as ne: if raiseerr and "Mapped[" in raw_annotation: # type: ignore - raise orm_exc.MappedAnnotationError( + raise sa_exc.ArgumentError( f"Could not interpret annotation {raw_annotation}. " "Check that it uses names that are correctly imported at the " "module level. See chained stack trace for more hints." @@ -2363,7 +2379,7 @@ def _extract_mapped_subtype( ): return None - raise orm_exc.MappedAnnotationError( + raise sa_exc.ArgumentError( f'Type annotation for "{cls.__name__}.{key}" ' "can't be correctly interpreted for " "Annotated Declarative Table form. ORM annotations " @@ -2384,20 +2400,8 @@ def _extract_mapped_subtype( return annotated, None if len(annotated.__args__) != 1: - raise orm_exc.MappedAnnotationError( + raise sa_exc.ArgumentError( "Expected sub-type for Mapped[] annotation" ) - return ( - # fix dict/list/set args to be ForwardRef, see #11814 - fixup_container_fwd_refs(annotated.__args__[0]), - annotated.__origin__, - ) - - -def _mapper_property_as_plain_name(prop: Type[Any]) -> str: - if hasattr(prop, "_mapper_property_name"): - name = prop._mapper_property_name() - else: - name = None - return util.clsname_as_plain_name(prop, name) + return annotated.__args__[0], annotated.__origin__ diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/orm/writeonly.py b/venv/lib/python3.12/site-packages/sqlalchemy/orm/writeonly.py index fe9c8e9..416a039 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/orm/writeonly.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/orm/writeonly.py @@ -1,5 +1,5 @@ # orm/writeonly.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -196,7 +196,8 @@ class WriteOnlyAttributeImpl( dict_: _InstanceDict, user_data: Literal[None] = ..., passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., - ) -> CollectionAdapter: ... + ) -> CollectionAdapter: + ... @overload def get_collection( @@ -205,7 +206,8 @@ class WriteOnlyAttributeImpl( dict_: _InstanceDict, user_data: _AdaptedCollectionProtocol = ..., passive: PassiveFlag = ..., - ) -> CollectionAdapter: ... + ) -> CollectionAdapter: + ... @overload def get_collection( @@ -216,7 +218,8 @@ class WriteOnlyAttributeImpl( passive: PassiveFlag = ..., ) -> Union[ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter - ]: ... + ]: + ... 
def get_collection( self, @@ -236,11 +239,15 @@ class WriteOnlyAttributeImpl( return DynamicCollectionAdapter(data) # type: ignore[return-value] @util.memoized_property - def _append_token(self) -> attributes.AttributeEventToken: + def _append_token( # type:ignore[override] + self, + ) -> attributes.AttributeEventToken: return attributes.AttributeEventToken(self, attributes.OP_APPEND) @util.memoized_property - def _remove_token(self) -> attributes.AttributeEventToken: + def _remove_token( # type:ignore[override] + self, + ) -> attributes.AttributeEventToken: return attributes.AttributeEventToken(self, attributes.OP_REMOVE) def fire_append_event( diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/pool/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/pool/__init__.py index 51bf0ec..7929b6e 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/pool/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/pool/__init__.py @@ -1,5 +1,5 @@ -# pool/__init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sqlalchemy/pool/__init__.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/pool/base.py b/venv/lib/python3.12/site-packages/sqlalchemy/pool/base.py index ed4d7c1..90ed32e 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/pool/base.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/pool/base.py @@ -1,12 +1,14 @@ -# pool/base.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sqlalchemy/pool.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Base constructs for connection pools.""" +"""Base constructs for connection pools. + +""" from __future__ import annotations @@ -145,14 +147,17 @@ class _AsyncConnDialect(_ConnDialect): class _CreatorFnType(Protocol): - def __call__(self) -> DBAPIConnection: ... + def __call__(self) -> DBAPIConnection: + ... class _CreatorWRecFnType(Protocol): - def __call__(self, rec: ConnectionPoolEntry) -> DBAPIConnection: ... + def __call__(self, rec: ConnectionPoolEntry) -> DBAPIConnection: + ... class Pool(log.Identified, event.EventTarget): + """Abstract base class for connection pools.""" dispatch: dispatcher[Pool] @@ -466,7 +471,6 @@ class Pool(log.Identified, event.EventTarget): raise NotImplementedError() def status(self) -> str: - """Returns a brief description of the state of this pool.""" raise NotImplementedError() @@ -629,6 +633,7 @@ class ConnectionPoolEntry(ManagesConnection): class _ConnectionRecord(ConnectionPoolEntry): + """Maintains a position in a connection pool which references a pooled connection. @@ -724,13 +729,11 @@ class _ConnectionRecord(ConnectionPoolEntry): rec.fairy_ref = ref = weakref.ref( fairy, - lambda ref: ( - _finalize_fairy( - None, rec, pool, ref, echo, transaction_was_reset=False - ) - if _finalize_fairy is not None - else None - ), + lambda ref: _finalize_fairy( + None, rec, pool, ref, echo, transaction_was_reset=False + ) + if _finalize_fairy is not None + else None, ) _strong_ref_connection_records[ref] = rec if echo: @@ -1071,13 +1074,14 @@ class PoolProxiedConnection(ManagesConnection): if typing.TYPE_CHECKING: - def commit(self) -> None: ... + def commit(self) -> None: + ... - def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: ... 
+ def cursor(self) -> DBAPICursor: + ... - def rollback(self) -> None: ... - - def __getattr__(self, key: str) -> Any: ... + def rollback(self) -> None: + ... @property def is_valid(self) -> bool: @@ -1185,6 +1189,7 @@ class _AdhocProxiedConnection(PoolProxiedConnection): class _ConnectionFairy(PoolProxiedConnection): + """Proxies a DBAPI connection and provides return-on-dereference support. diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/pool/events.py b/venv/lib/python3.12/site-packages/sqlalchemy/pool/events.py index 4ceb260..762418b 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/pool/events.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/pool/events.py @@ -1,5 +1,5 @@ -# pool/events.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sqlalchemy/pool/events.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -35,12 +35,10 @@ class PoolEvents(event.Events[Pool]): from sqlalchemy import event - def my_on_checkout(dbapi_conn, connection_rec, connection_proxy): "handle an on checkout event" - - event.listen(Pool, "checkout", my_on_checkout) + event.listen(Pool, 'checkout', my_on_checkout) In addition to accepting the :class:`_pool.Pool` class and :class:`_pool.Pool` instances, :class:`_events.PoolEvents` also accepts @@ -51,7 +49,7 @@ class PoolEvents(event.Events[Pool]): engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test") # will associate with engine.pool - event.listen(engine, "checkout", my_on_checkout) + event.listen(engine, 'checkout', my_on_checkout) """ # noqa: E501 @@ -175,7 +173,7 @@ class PoolEvents(event.Events[Pool]): def checkin( self, - dbapi_connection: Optional[DBAPIConnection], + dbapi_connection: DBAPIConnection, connection_record: ConnectionPoolEntry, ) -> None: """Called when a connection returns to the pool. diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/pool/impl.py b/venv/lib/python3.12/site-packages/sqlalchemy/pool/impl.py index f3d53dd..af4f788 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/pool/impl.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/pool/impl.py @@ -1,12 +1,14 @@ -# pool/impl.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# sqlalchemy/pool.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Pool implementation classes.""" +"""Pool implementation classes. + +""" from __future__ import annotations import threading @@ -41,30 +43,21 @@ if typing.TYPE_CHECKING: class QueuePool(Pool): + """A :class:`_pool.Pool` that imposes a limit on the number of open connections. :class:`.QueuePool` is the default pooling implementation used for - all :class:`_engine.Engine` objects other than SQLite with a ``:memory:`` - database. - - The :class:`.QueuePool` class **is not compatible** with asyncio and - :func:`_asyncio.create_async_engine`. The - :class:`.AsyncAdaptedQueuePool` class is used automatically when - using :func:`_asyncio.create_async_engine`, if no other kind of pool - is specified. - - .. seealso:: - - :class:`.AsyncAdaptedQueuePool` + all :class:`_engine.Engine` objects, unless the SQLite dialect is + in use with a ``:memory:`` database. 
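
As a practical companion to the ``QueuePool`` notes above, a sketch of the pool-related ``create_engine()`` parameters; the connection URL is a placeholder and the numbers are illustrative, not recommendations::

    from sqlalchemy import create_engine

    engine = create_engine(
        "postgresql+psycopg2://scott:tiger@localhost/test",
        pool_size=5,         # connections held persistently by QueuePool
        max_overflow=10,     # additional connections allowed beyond pool_size
        pool_timeout=30,     # seconds to wait for a checkout before erroring
        pool_recycle=1800,   # proactively replace connections older than 30 min
        pool_pre_ping=True,  # verify liveness on checkout
    )
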
""" - _is_asyncio = False + _is_asyncio = False # type: ignore[assignment] - _queue_class: Type[sqla_queue.QueueCommon[ConnectionPoolEntry]] = ( - sqla_queue.Queue - ) + _queue_class: Type[ + sqla_queue.QueueCommon[ConnectionPoolEntry] + ] = sqla_queue.Queue _pool: sqla_queue.QueueCommon[ConnectionPoolEntry] @@ -131,7 +124,6 @@ class QueuePool(Pool): :class:`_pool.Pool` constructor. """ - Pool.__init__(self, creator, **kw) self._pool = self._queue_class(pool_size, use_lifo=use_lifo) self._overflow = 0 - pool_size @@ -257,31 +249,20 @@ class QueuePool(Pool): class AsyncAdaptedQueuePool(QueuePool): - """An asyncio-compatible version of :class:`.QueuePool`. - - This pool is used by default when using :class:`.AsyncEngine` engines that - were generated from :func:`_asyncio.create_async_engine`. It uses an - asyncio-compatible queue implementation that does not use - ``threading.Lock``. - - The arguments and operation of :class:`.AsyncAdaptedQueuePool` are - otherwise identical to that of :class:`.QueuePool`. - - """ - - _is_asyncio = True - _queue_class: Type[sqla_queue.QueueCommon[ConnectionPoolEntry]] = ( - sqla_queue.AsyncAdaptedQueue - ) + _is_asyncio = True # type: ignore[assignment] + _queue_class: Type[ + sqla_queue.QueueCommon[ConnectionPoolEntry] + ] = sqla_queue.AsyncAdaptedQueue _dialect = _AsyncConnDialect() class FallbackAsyncAdaptedQueuePool(AsyncAdaptedQueuePool): - _queue_class = sqla_queue.FallbackAsyncAdaptedQueue # type: ignore[assignment] # noqa: E501 + _queue_class = sqla_queue.FallbackAsyncAdaptedQueue class NullPool(Pool): + """A Pool which does not pool connections. Instead it literally opens and closes the underlying DB-API connection @@ -291,9 +272,6 @@ class NullPool(Pool): invalidation are not supported by this Pool implementation, since no connections are held persistently. - The :class:`.NullPool` class **is compatible** with asyncio and - :func:`_asyncio.create_async_engine`. - """ def status(self) -> str: @@ -324,6 +302,7 @@ class NullPool(Pool): class SingletonThreadPool(Pool): + """A Pool that maintains one connection per thread. Maintains one connection per each thread, never moving a connection to a @@ -341,9 +320,6 @@ class SingletonThreadPool(Pool): scenarios using a SQLite ``:memory:`` database and is not recommended for production use. - The :class:`.SingletonThreadPool` class **is not compatible** with asyncio - and :func:`_asyncio.create_async_engine`. - Options are the same as those of :class:`_pool.Pool`, as well as: @@ -356,7 +332,7 @@ class SingletonThreadPool(Pool): """ - _is_asyncio = False + _is_asyncio = False # type: ignore[assignment] def __init__( self, @@ -446,14 +422,13 @@ class SingletonThreadPool(Pool): class StaticPool(Pool): + """A Pool of exactly one connection, used for all requests. Reconnect-related functions such as ``recycle`` and connection invalidation (which is also used to support auto-reconnect) are only partially supported right now and may not yield good results. - The :class:`.StaticPool` class **is compatible** with asyncio and - :func:`_asyncio.create_async_engine`. """ @@ -511,6 +486,7 @@ class StaticPool(Pool): class AssertionPool(Pool): + """A :class:`_pool.Pool` that allows at most one checked out connection at any given time. @@ -518,9 +494,6 @@ class AssertionPool(Pool): at a time. Useful for debugging code that is using more connections than desired. - The :class:`.AssertionPool` class **is compatible** with asyncio and - :func:`_asyncio.create_async_engine`. 
- """ _conn: Optional[ConnectionPoolEntry] diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/schema.py b/venv/lib/python3.12/site-packages/sqlalchemy/schema.py index 56b90ec..19782bd 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/schema.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/schema.py @@ -1,11 +1,13 @@ # schema.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -"""Compatibility namespace for sqlalchemy.sql.schema and related.""" +"""Compatibility namespace for sqlalchemy.sql.schema and related. + +""" from __future__ import annotations @@ -63,7 +65,6 @@ from .sql.schema import MetaData as MetaData from .sql.schema import PrimaryKeyConstraint as PrimaryKeyConstraint from .sql.schema import SchemaConst as SchemaConst from .sql.schema import SchemaItem as SchemaItem -from .sql.schema import SchemaVisitable as SchemaVisitable from .sql.schema import Sequence as Sequence from .sql.schema import Table as Table from .sql.schema import UniqueConstraint as UniqueConstraint diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/sql/__init__.py b/venv/lib/python3.12/site-packages/sqlalchemy/sql/__init__.py index 188f709..a81509f 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/sql/__init__.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/sql/__init__.py @@ -1,5 +1,5 @@ # sql/__init__.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/sql/_dml_constructors.py b/venv/lib/python3.12/site-packages/sqlalchemy/sql/_dml_constructors.py index 0a6f601..5c0cc62 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/sql/_dml_constructors.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/sql/_dml_constructors.py @@ -1,5 +1,5 @@ # sql/_dml_constructors.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -24,7 +24,10 @@ def insert(table: _DMLTableArgument) -> Insert: from sqlalchemy import insert - stmt = insert(user_table).values(name="username", fullname="Full Username") + stmt = ( + insert(user_table). + values(name='username', fullname='Full Username') + ) Similar functionality is available via the :meth:`_expression.TableClause.insert` method on @@ -75,7 +78,7 @@ def insert(table: _DMLTableArgument) -> Insert: :ref:`tutorial_core_insert` - in the :ref:`unified_tutorial` - """ # noqa: E501 + """ return Insert(table) @@ -87,7 +90,9 @@ def update(table: _DMLTableArgument) -> Update: from sqlalchemy import update stmt = ( - update(user_table).where(user_table.c.id == 5).values(name="user #5") + update(user_table). + where(user_table.c.id == 5). + values(name='user #5') ) Similar functionality is available via the @@ -104,7 +109,7 @@ def update(table: _DMLTableArgument) -> Update: :ref:`tutorial_core_update_delete` - in the :ref:`unified_tutorial` - """ # noqa: E501 + """ return Update(table) @@ -115,7 +120,10 @@ def delete(table: _DMLTableArgument) -> Delete: from sqlalchemy import delete - stmt = delete(user_table).where(user_table.c.id == 5) + stmt = ( + delete(user_table). 
+ where(user_table.c.id == 5) + ) Similar functionality is available via the :meth:`_expression.TableClause.delete` method on diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/sql/_elements_constructors.py b/venv/lib/python3.12/site-packages/sqlalchemy/sql/_elements_constructors.py index 3359998..2719737 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/sql/_elements_constructors.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/sql/_elements_constructors.py @@ -1,5 +1,5 @@ # sql/_elements_constructors.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -10,6 +10,7 @@ from __future__ import annotations import typing from typing import Any from typing import Callable +from typing import Iterable from typing import Mapping from typing import Optional from typing import overload @@ -48,7 +49,6 @@ from .functions import FunctionElement from ..util.typing import Literal if typing.TYPE_CHECKING: - from ._typing import _ByArgument from ._typing import _ColumnExpressionArgument from ._typing import _ColumnExpressionOrLiteralArgument from ._typing import _ColumnExpressionOrStrLabelArgument @@ -125,8 +125,11 @@ def and_( # type: ignore[empty-body] from sqlalchemy import and_ stmt = select(users_table).where( - and_(users_table.c.name == "wendy", users_table.c.enrolled == True) - ) + and_( + users_table.c.name == 'wendy', + users_table.c.enrolled == True + ) + ) The :func:`.and_` conjunction is also available using the Python ``&`` operator (though note that compound expressions @@ -134,8 +137,9 @@ def and_( # type: ignore[empty-body] operator precedence behavior):: stmt = select(users_table).where( - (users_table.c.name == "wendy") & (users_table.c.enrolled == True) - ) + (users_table.c.name == 'wendy') & + (users_table.c.enrolled == True) + ) The :func:`.and_` operation is also implicit in some cases; the :meth:`_expression.Select.where` @@ -143,11 +147,9 @@ def and_( # type: ignore[empty-body] times against a statement, which will have the effect of each clause being combined using :func:`.and_`:: - stmt = ( - select(users_table) - .where(users_table.c.name == "wendy") - .where(users_table.c.enrolled == True) - ) + stmt = select(users_table).\ + where(users_table.c.name == 'wendy').\ + where(users_table.c.enrolled == True) The :func:`.and_` construct must be given at least one positional argument in order to be valid; a :func:`.and_` construct with no @@ -157,7 +159,6 @@ def and_( # type: ignore[empty-body] specified:: from sqlalchemy import true - criteria = and_(true(), *expressions) The above expression will compile to SQL as the expression ``true`` @@ -189,8 +190,11 @@ if not TYPE_CHECKING: from sqlalchemy import and_ stmt = select(users_table).where( - and_(users_table.c.name == "wendy", users_table.c.enrolled == True) - ) + and_( + users_table.c.name == 'wendy', + users_table.c.enrolled == True + ) + ) The :func:`.and_` conjunction is also available using the Python ``&`` operator (though note that compound expressions @@ -198,8 +202,9 @@ if not TYPE_CHECKING: operator precedence behavior):: stmt = select(users_table).where( - (users_table.c.name == "wendy") & (users_table.c.enrolled == True) - ) + (users_table.c.name == 'wendy') & + (users_table.c.enrolled == True) + ) The :func:`.and_` operation is also implicit in some cases; the :meth:`_expression.Select.where` @@ -207,11 +212,9 @@ if not TYPE_CHECKING: times against a 
statement, which will have the effect of each clause being combined using :func:`.and_`:: - stmt = ( - select(users_table) - .where(users_table.c.name == "wendy") - .where(users_table.c.enrolled == True) - ) + stmt = select(users_table).\ + where(users_table.c.name == 'wendy').\ + where(users_table.c.enrolled == True) The :func:`.and_` construct must be given at least one positional argument in order to be valid; a :func:`.and_` construct with no @@ -221,7 +224,6 @@ if not TYPE_CHECKING: specified:: from sqlalchemy import true - criteria = and_(true(), *expressions) The above expression will compile to SQL as the expression ``true`` @@ -239,7 +241,7 @@ if not TYPE_CHECKING: :func:`.or_` - """ # noqa: E501 + """ return BooleanClauseList.and_(*clauses) @@ -305,12 +307,9 @@ def asc( e.g.:: from sqlalchemy import asc - stmt = select(users_table).order_by(asc(users_table.c.name)) - will produce SQL as: - - .. sourcecode:: sql + will produce SQL as:: SELECT id, name FROM user ORDER BY name ASC @@ -347,11 +346,9 @@ def collate( e.g.:: - collate(mycolumn, "utf8_bin") + collate(mycolumn, 'utf8_bin') - produces: - - .. sourcecode:: sql + produces:: mycolumn COLLATE utf8_bin @@ -376,12 +373,9 @@ def between( E.g.:: from sqlalchemy import between - stmt = select(users_table).where(between(users_table.c.id, 5, 7)) - Would produce SQL resembling: - - .. sourcecode:: sql + Would produce SQL resembling:: SELECT id, name FROM user WHERE id BETWEEN :id_1 AND :id_2 @@ -442,12 +436,16 @@ def outparam( return BindParameter(key, None, type_=type_, unique=False, isoutparam=True) +# mypy insists that BinaryExpression and _HasClauseElement protocol overlap. +# they do not. at all. bug in mypy? @overload -def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]: ... +def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]: # type: ignore + ... @overload -def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: ... +def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: + ... def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: @@ -499,13 +497,10 @@ def bindparam( from sqlalchemy import bindparam - stmt = select(users_table).where( - users_table.c.name == bindparam("username") - ) + stmt = select(users_table).\ + where(users_table.c.name == bindparam('username')) - The above statement, when rendered, will produce SQL similar to: - - .. sourcecode:: sql + The above statement, when rendered, will produce SQL similar to:: SELECT id, name FROM user WHERE name = :username @@ -513,25 +508,22 @@ def bindparam( would typically be applied at execution time to a method like :meth:`_engine.Connection.execute`:: - result = connection.execute(stmt, {"username": "wendy"}) + result = connection.execute(stmt, username='wendy') Explicit use of :func:`.bindparam` is also common when producing UPDATE or DELETE statements that are to be invoked multiple times, where the WHERE criterion of the statement is to change on each invocation, such as:: - stmt = ( - users_table.update() - .where(user_table.c.name == bindparam("username")) - .values(fullname=bindparam("fullname")) - ) + stmt = (users_table.update(). + where(user_table.c.name == bindparam('username')). 
+ values(fullname=bindparam('fullname')) + ) connection.execute( - stmt, - [ - {"username": "wendy", "fullname": "Wendy Smith"}, - {"username": "jack", "fullname": "Jack Jones"}, - ], + stmt, [{"username": "wendy", "fullname": "Wendy Smith"}, + {"username": "jack", "fullname": "Jack Jones"}, + ] ) SQLAlchemy's Core expression system makes wide use of @@ -540,7 +532,7 @@ def bindparam( coerced into fixed :func:`.bindparam` constructs. For example, given a comparison operation such as:: - expr = users_table.c.name == "Wendy" + expr = users_table.c.name == 'Wendy' The above expression will produce a :class:`.BinaryExpression` construct, where the left side is the :class:`_schema.Column` object @@ -548,11 +540,9 @@ def bindparam( :class:`.BindParameter` representing the literal value:: print(repr(expr.right)) - BindParameter("%(4327771088 name)s", "Wendy", type_=String()) + BindParameter('%(4327771088 name)s', 'Wendy', type_=String()) - The expression above will render SQL such as: - - .. sourcecode:: sql + The expression above will render SQL such as:: user.name = :name_1 @@ -561,12 +551,10 @@ def bindparam( along where it is later used within statement execution. If we invoke a statement like the following:: - stmt = select(users_table).where(users_table.c.name == "Wendy") + stmt = select(users_table).where(users_table.c.name == 'Wendy') result = connection.execute(stmt) - We would see SQL logging output as: - - .. sourcecode:: sql + We would see SQL logging output as:: SELECT "user".id, "user".name FROM "user" @@ -584,11 +572,9 @@ def bindparam( bound placeholders based on the arguments passed, as in:: stmt = users_table.insert() - result = connection.execute(stmt, {"name": "Wendy"}) + result = connection.execute(stmt, name='Wendy') - The above will produce SQL output as: - - .. sourcecode:: sql + The above will produce SQL output as:: INSERT INTO "user" (name) VALUES (%(name)s) {'name': 'Wendy'} @@ -661,12 +647,12 @@ def bindparam( :param quote: True if this parameter name requires quoting and is not currently known as a SQLAlchemy reserved word; this currently - only applies to the Oracle Database backends, where bound names must + only applies to the Oracle backend, where bound names must sometimes be quoted. :param isoutparam: if True, the parameter should be treated like a stored procedure - "OUT" parameter. This applies to backends such as Oracle Database which + "OUT" parameter. This applies to backends such as Oracle which support OUT parameters. :param expanding: @@ -752,17 +738,16 @@ def case( from sqlalchemy import case - stmt = select(users_table).where( - case( - (users_table.c.name == "wendy", "W"), - (users_table.c.name == "jack", "J"), - else_="E", - ) - ) + stmt = select(users_table).\ + where( + case( + (users_table.c.name == 'wendy', 'W'), + (users_table.c.name == 'jack', 'J'), + else_='E' + ) + ) - The above statement will produce SQL resembling: - - .. sourcecode:: sql + The above statement will produce SQL resembling:: SELECT id, name FROM user WHERE CASE @@ -780,9 +765,14 @@ def case( compared against keyed to result expressions. 
The statement below is equivalent to the preceding statement:: - stmt = select(users_table).where( - case({"wendy": "W", "jack": "J"}, value=users_table.c.name, else_="E") - ) + stmt = select(users_table).\ + where( + case( + {"wendy": "W", "jack": "J"}, + value=users_table.c.name, + else_='E' + ) + ) The values which are accepted as result values in :paramref:`.case.whens` as well as with :paramref:`.case.else_` are @@ -797,16 +787,20 @@ def case( from sqlalchemy import case, literal_column case( - (orderline.c.qty > 100, literal_column("'greaterthan100'")), - (orderline.c.qty > 10, literal_column("'greaterthan10'")), - else_=literal_column("'lessthan10'"), + ( + orderline.c.qty > 100, + literal_column("'greaterthan100'") + ), + ( + orderline.c.qty > 10, + literal_column("'greaterthan10'") + ), + else_=literal_column("'lessthan10'") ) The above will render the given constants without using bound parameters for the result values (but still for the comparison - values), as in: - - .. sourcecode:: sql + values), as in:: CASE WHEN (orderline.qty > :qty_1) THEN 'greaterthan100' @@ -827,8 +821,8 @@ def case( resulting value, e.g.:: case( - (users_table.c.name == "wendy", "W"), - (users_table.c.name == "jack", "J"), + (users_table.c.name == 'wendy', 'W'), + (users_table.c.name == 'jack', 'J') ) In the second form, it accepts a Python dictionary of comparison @@ -836,7 +830,10 @@ def case( :paramref:`.case.value` to be present, and values will be compared using the ``==`` operator, e.g.:: - case({"wendy": "W", "jack": "J"}, value=users_table.c.name) + case( + {"wendy": "W", "jack": "J"}, + value=users_table.c.name + ) :param value: An optional SQL expression which will be used as a fixed "comparison point" for candidate values within a dictionary @@ -849,7 +846,7 @@ def case( expressions evaluate to true. - """ # noqa: E501 + """ return Case(*whens, value=value, else_=else_) @@ -867,9 +864,7 @@ def cast( stmt = select(cast(product_table.c.unit_price, Numeric(10, 4))) - The above statement will produce SQL resembling: - - .. sourcecode:: sql + The above statement will produce SQL resembling:: SELECT CAST(unit_price AS NUMERIC(10, 4)) FROM product @@ -938,11 +933,11 @@ def try_cast( from sqlalchemy import select, try_cast, Numeric - stmt = select(try_cast(product_table.c.unit_price, Numeric(10, 4))) + stmt = select( + try_cast(product_table.c.unit_price, Numeric(10, 4)) + ) - The above would render on Microsoft SQL Server as: - - .. sourcecode:: sql + The above would render on Microsoft SQL Server as:: SELECT TRY_CAST (product_table.unit_price AS NUMERIC(10, 4)) FROM product_table @@ -973,9 +968,7 @@ def column( id, name = column("id"), column("name") stmt = select(id, name).select_from("user") - The above statement would produce SQL like: - - .. sourcecode:: sql + The above statement would produce SQL like:: SELECT id, name FROM user @@ -1011,14 +1004,13 @@ def column( from sqlalchemy import table, column, select - user = table( - "user", - column("id"), - column("name"), - column("description"), + user = table("user", + column("id"), + column("name"), + column("description"), ) - stmt = select(user.c.description).where(user.c.name == "wendy") + stmt = select(user.c.description).where(user.c.name == 'wendy') A :func:`_expression.column` / :func:`.table` construct like that illustrated @@ -1065,9 +1057,7 @@ def desc( stmt = select(users_table).order_by(desc(users_table.c.name)) - will produce SQL as: - - .. 
sourcecode:: sql + will produce SQL as:: SELECT id, name FROM user ORDER BY name DESC @@ -1100,26 +1090,16 @@ def desc( def distinct(expr: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: """Produce an column-expression-level unary ``DISTINCT`` clause. - This applies the ``DISTINCT`` keyword to an **individual column - expression** (e.g. not the whole statement), and renders **specifically - in that column position**; this is used for containment within - an aggregate function, as in:: + This applies the ``DISTINCT`` keyword to an individual column + expression, and is typically contained within an aggregate function, + as in:: from sqlalchemy import distinct, func + stmt = select(func.count(distinct(users_table.c.name))) - stmt = select(users_table.c.id, func.count(distinct(users_table.c.name))) + The above would produce an expression resembling:: - The above would produce an statement resembling: - - .. sourcecode:: sql - - SELECT user.id, count(DISTINCT user.name) FROM user - - .. tip:: The :func:`_sql.distinct` function does **not** apply DISTINCT - to the full SELECT statement, instead applying a DISTINCT modifier - to **individual column expressions**. For general ``SELECT DISTINCT`` - support, use the - :meth:`_sql.Select.distinct` method on :class:`_sql.Select`. + SELECT COUNT(DISTINCT name) FROM user The :func:`.distinct` function is also available as a column-level method, e.g. :meth:`_expression.ColumnElement.distinct`, as in:: @@ -1142,7 +1122,7 @@ def distinct(expr: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: :data:`.func` - """ # noqa: E501 + """ return UnaryExpression._create_distinct(expr) @@ -1172,9 +1152,6 @@ def extract(field: str, expr: _ColumnExpressionArgument[Any]) -> Extract: :param field: The field to extract. - .. warning:: This field is used as a literal SQL string. - **DO NOT PASS UNTRUSTED INPUT TO THIS STRING**. - :param expr: A column or Python scalar expression serving as the right side of the ``EXTRACT`` expression. @@ -1183,10 +1160,9 @@ def extract(field: str, expr: _ColumnExpressionArgument[Any]) -> Extract: from sqlalchemy import extract from sqlalchemy import table, column - logged_table = table( - "user", - column("id"), - column("date_created"), + logged_table = table("user", + column("id"), + column("date_created"), ) stmt = select(logged_table.c.id).where( @@ -1198,9 +1174,9 @@ def extract(field: str, expr: _ColumnExpressionArgument[Any]) -> Extract: Similarly, one can also select an extracted component:: - stmt = select(extract("YEAR", logged_table.c.date_created)).where( - logged_table.c.id == 1 - ) + stmt = select( + extract("YEAR", logged_table.c.date_created) + ).where(logged_table.c.id == 1) The implementation of ``EXTRACT`` may vary across database backends. Users are reminded to consult their database documentation. @@ -1259,8 +1235,7 @@ def funcfilter( E.g.:: from sqlalchemy import funcfilter - - funcfilter(func.count(1), MyClass.name == "some name") + funcfilter(func.count(1), MyClass.name == 'some name') Would produce "COUNT(1) FILTER (WHERE myclass.name = 'some name')". @@ -1317,11 +1292,10 @@ def nulls_first(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: from sqlalchemy import desc, nulls_first - stmt = select(users_table).order_by(nulls_first(desc(users_table.c.name))) + stmt = select(users_table).order_by( + nulls_first(desc(users_table.c.name))) - The SQL expression from the above would resemble: - - .. 
sourcecode:: sql + The SQL expression from the above would resemble:: SELECT id, name FROM user ORDER BY name DESC NULLS FIRST @@ -1332,8 +1306,7 @@ def nulls_first(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: function version, as in:: stmt = select(users_table).order_by( - users_table.c.name.desc().nulls_first() - ) + users_table.c.name.desc().nulls_first()) .. versionchanged:: 1.4 :func:`.nulls_first` is renamed from :func:`.nullsfirst` in previous releases. @@ -1349,7 +1322,7 @@ def nulls_first(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: :meth:`_expression.Select.order_by` - """ # noqa: E501 + """ return UnaryExpression._create_nulls_first(column) @@ -1363,11 +1336,10 @@ def nulls_last(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: from sqlalchemy import desc, nulls_last - stmt = select(users_table).order_by(nulls_last(desc(users_table.c.name))) + stmt = select(users_table).order_by( + nulls_last(desc(users_table.c.name))) - The SQL expression from the above would resemble: - - .. sourcecode:: sql + The SQL expression from the above would resemble:: SELECT id, name FROM user ORDER BY name DESC NULLS LAST @@ -1377,7 +1349,8 @@ def nulls_last(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: rather than as its standalone function version, as in:: - stmt = select(users_table).order_by(users_table.c.name.desc().nulls_last()) + stmt = select(users_table).order_by( + users_table.c.name.desc().nulls_last()) .. versionchanged:: 1.4 :func:`.nulls_last` is renamed from :func:`.nullslast` in previous releases. @@ -1393,7 +1366,7 @@ def nulls_last(column: _ColumnExpressionArgument[_T]) -> UnaryExpression[_T]: :meth:`_expression.Select.order_by` - """ # noqa: E501 + """ return UnaryExpression._create_nulls_last(column) @@ -1408,8 +1381,11 @@ def or_( # type: ignore[empty-body] from sqlalchemy import or_ stmt = select(users_table).where( - or_(users_table.c.name == "wendy", users_table.c.name == "jack") - ) + or_( + users_table.c.name == 'wendy', + users_table.c.name == 'jack' + ) + ) The :func:`.or_` conjunction is also available using the Python ``|`` operator (though note that compound expressions @@ -1417,8 +1393,9 @@ def or_( # type: ignore[empty-body] operator precedence behavior):: stmt = select(users_table).where( - (users_table.c.name == "wendy") | (users_table.c.name == "jack") - ) + (users_table.c.name == 'wendy') | + (users_table.c.name == 'jack') + ) The :func:`.or_` construct must be given at least one positional argument in order to be valid; a :func:`.or_` construct with no @@ -1428,7 +1405,6 @@ def or_( # type: ignore[empty-body] specified:: from sqlalchemy import false - or_criteria = or_(false(), *expressions) The above expression will compile to SQL as the expression ``false`` @@ -1460,8 +1436,11 @@ if not TYPE_CHECKING: from sqlalchemy import or_ stmt = select(users_table).where( - or_(users_table.c.name == "wendy", users_table.c.name == "jack") - ) + or_( + users_table.c.name == 'wendy', + users_table.c.name == 'jack' + ) + ) The :func:`.or_` conjunction is also available using the Python ``|`` operator (though note that compound expressions @@ -1469,8 +1448,9 @@ if not TYPE_CHECKING: operator precedence behavior):: stmt = select(users_table).where( - (users_table.c.name == "wendy") | (users_table.c.name == "jack") - ) + (users_table.c.name == 'wendy') | + (users_table.c.name == 'jack') + ) The :func:`.or_` construct must be given at least one positional argument in order to be valid; a :func:`.or_` construct with 
no @@ -1480,7 +1460,6 @@ if not TYPE_CHECKING: specified:: from sqlalchemy import false - or_criteria = or_(false(), *expressions) The above expression will compile to SQL as the expression ``false`` @@ -1498,17 +1477,26 @@ if not TYPE_CHECKING: :func:`.and_` - """ # noqa: E501 + """ return BooleanClauseList.or_(*clauses) def over( element: FunctionElement[_T], - partition_by: Optional[_ByArgument] = None, - order_by: Optional[_ByArgument] = None, + partition_by: Optional[ + Union[ + Iterable[_ColumnExpressionArgument[Any]], + _ColumnExpressionArgument[Any], + ] + ] = None, + order_by: Optional[ + Union[ + Iterable[_ColumnExpressionArgument[Any]], + _ColumnExpressionArgument[Any], + ] + ] = None, range_: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, rows: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, - groups: Optional[typing_Tuple[Optional[int], Optional[int]]] = None, ) -> Over[_T]: r"""Produce an :class:`.Over` object against a function. @@ -1520,23 +1508,19 @@ def over( func.row_number().over(order_by=mytable.c.some_column) - Would produce: - - .. sourcecode:: sql + Would produce:: ROW_NUMBER() OVER(ORDER BY some_column) - Ranges are also possible using the :paramref:`.expression.over.range_`, - :paramref:`.expression.over.rows`, and :paramref:`.expression.over.groups` - parameters. These + Ranges are also possible using the :paramref:`.expression.over.range_` + and :paramref:`.expression.over.rows` parameters. These mutually-exclusive parameters each accept a 2-tuple, which contains a combination of integers and None:: - func.row_number().over(order_by=my_table.c.some_column, range_=(None, 0)) + func.row_number().over( + order_by=my_table.c.some_column, range_=(None, 0)) - The above would produce: - - .. sourcecode:: sql + The above would produce:: ROW_NUMBER() OVER(ORDER BY some_column RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) @@ -1547,23 +1531,19 @@ def over( * RANGE BETWEEN 5 PRECEDING AND 10 FOLLOWING:: - func.row_number().over(order_by="x", range_=(-5, 10)) + func.row_number().over(order_by='x', range_=(-5, 10)) * ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW:: - func.row_number().over(order_by="x", rows=(None, 0)) + func.row_number().over(order_by='x', rows=(None, 0)) * RANGE BETWEEN 2 PRECEDING AND UNBOUNDED FOLLOWING:: - func.row_number().over(order_by="x", range_=(-2, None)) + func.row_number().over(order_by='x', range_=(-2, None)) * RANGE BETWEEN 1 FOLLOWING AND 3 FOLLOWING:: - func.row_number().over(order_by="x", range_=(1, 3)) - - * GROUPS BETWEEN 1 FOLLOWING AND 3 FOLLOWING:: - - func.row_number().over(order_by="x", groups=(1, 3)) + func.row_number().over(order_by='x', range_=(1, 3)) :param element: a :class:`.FunctionElement`, :class:`.WithinGroup`, or other compatible construct. @@ -1576,14 +1556,10 @@ def over( :param range\_: optional range clause for the window. This is a tuple value which can contain integer values or ``None``, and will render a RANGE BETWEEN PRECEDING / FOLLOWING clause. + :param rows: optional rows clause for the window. This is a tuple value which can contain integer values or None, and will render a ROWS BETWEEN PRECEDING / FOLLOWING clause. - :param groups: optional groups clause for the window. This is a - tuple value which can contain integer values or ``None``, - and will render a GROUPS BETWEEN PRECEDING / FOLLOWING clause. - - .. versionadded:: 2.0.40 This function is also available from the :data:`~.expression.func` construct itself via the :meth:`.FunctionElement.over` method. 
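# Editor's note (not part of the diff): a minimal usage sketch of the over()
# parameters touched above -- partition_by, order_by, and the rows frame.
# The "payments" table and its columns are assumptions for illustration only.
from sqlalchemy import Column, Integer, MetaData, Numeric, String, Table, func, select

metadata = MetaData()
payments = Table(
    "payments",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("region", String(50)),
    Column("amount", Numeric(10, 2)),
)

# ROW_NUMBER() OVER (PARTITION BY region ORDER BY amount DESC)
ranked = select(
    payments.c.id,
    func.row_number()
    .over(partition_by=payments.c.region, order_by=payments.c.amount.desc())
    .label("rank_in_region"),
)

# SUM(amount) OVER (ORDER BY id ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
running_total = func.sum(payments.c.amount).over(order_by=payments.c.id, rows=(None, 0))
print(select(payments.c.id, running_total.label("running_total")))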
@@ -1596,8 +1572,8 @@ def over( :func:`_expression.within_group` - """ # noqa: E501 - return Over(element, partition_by, order_by, range_, rows, groups) + """ + return Over(element, partition_by, order_by, range_, rows) @_document_text_coercion("text", ":func:`.text`", ":paramref:`.text.text`") @@ -1627,7 +1603,7 @@ def text(text: str) -> TextClause: E.g.:: t = text("SELECT * FROM users WHERE id=:user_id") - result = connection.execute(t, {"user_id": 12}) + result = connection.execute(t, user_id=12) For SQL statements where a colon is required verbatim, as within an inline string, use a backslash to escape:: @@ -1645,11 +1621,9 @@ def text(text: str) -> TextClause: method allows specification of return columns including names and types:: - t = ( - text("SELECT * FROM users WHERE id=:user_id") - .bindparams(user_id=7) - .columns(id=Integer, name=String) - ) + t = text("SELECT * FROM users WHERE id=:user_id").\ + bindparams(user_id=7).\ + columns(id=Integer, name=String) for id, name in connection.execute(t): print(id, name) @@ -1659,7 +1633,7 @@ def text(text: str) -> TextClause: such as for the WHERE clause of a SELECT statement:: s = select(users.c.id, users.c.name).where(text("id=:user_id")) - result = connection.execute(s, {"user_id": 12}) + result = connection.execute(s, user_id=12) :func:`_expression.text` is also used for the construction of a full, standalone statement using plain text. @@ -1731,7 +1705,9 @@ def tuple_( from sqlalchemy import tuple_ - tuple_(table.c.col1, table.c.col2).in_([(1, 2), (5, 12), (10, 19)]) + tuple_(table.c.col1, table.c.col2).in_( + [(1, 2), (5, 12), (10, 19)] + ) .. versionchanged:: 1.3.6 Added support for SQLite IN tuples. @@ -1781,9 +1757,10 @@ def type_coerce( :meth:`_expression.ColumnElement.label`:: stmt = select( - type_coerce(log_table.date_string, StringDateTime()).label("date") + type_coerce(log_table.date_string, StringDateTime()).label('date') ) + A type that features bound-value handling will also have that behavior take effect when literal values or :func:`.bindparam` constructs are passed to :func:`.type_coerce` as targets. 
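# Editor's note (not part of the diff): an illustrative sketch of the text()
# patterns documented above, using the 2.0-style dictionary parameter passing;
# the in-memory SQLite URL and the "users" table are assumptions.
from sqlalchemy import Integer, String, create_engine, text

engine = create_engine("sqlite+pysqlite:///:memory:")

with engine.connect() as conn:
    conn.execute(text("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)"))
    conn.execute(
        text("INSERT INTO users (id, name) VALUES (:id, :name)"),
        [{"id": 7, "name": "wendy"}, {"id": 12, "name": "jack"}],
    )

    # typed result columns via .columns(), a fixed bound value via .bindparams()
    stmt = (
        text("SELECT id, name FROM users WHERE id = :user_id")
        .bindparams(user_id=7)
        .columns(id=Integer, name=String)
    )
    for id_, name in conn.execute(stmt):
        print(id_, name)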
@@ -1844,10 +1821,11 @@ def within_group( the :meth:`.FunctionElement.within_group` method, e.g.:: from sqlalchemy import within_group - stmt = select( department.c.id, - func.percentile_cont(0.5).within_group(department.c.salary.desc()), + func.percentile_cont(0.5).within_group( + department.c.salary.desc() + ) ) The above statement would produce SQL similar to diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/sql/_orm_types.py b/venv/lib/python3.12/site-packages/sqlalchemy/sql/_orm_types.py index c37d805..90986ec 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/sql/_orm_types.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/sql/_orm_types.py @@ -1,5 +1,5 @@ # sql/_orm_types.py -# Copyright (C) 2022-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2022 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/sql/_py_util.py b/venv/lib/python3.12/site-packages/sqlalchemy/sql/_py_util.py index 9e1a084..edff0d6 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/sql/_py_util.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/sql/_py_util.py @@ -1,5 +1,5 @@ # sql/_py_util.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/sql/_selectable_constructors.py b/venv/lib/python3.12/site-packages/sqlalchemy/sql/_selectable_constructors.py index dfb5ad0..41e8b6e 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/sql/_selectable_constructors.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/sql/_selectable_constructors.py @@ -1,5 +1,5 @@ # sql/_selectable_constructors.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -12,6 +12,7 @@ from typing import Optional from typing import overload from typing import Tuple from typing import TYPE_CHECKING +from typing import TypeVar from typing import Union from . import coercions @@ -46,7 +47,6 @@ if TYPE_CHECKING: from ._typing import _T7 from ._typing import _T8 from ._typing import _T9 - from ._typing import _TP from ._typing import _TypedColumnClauseArgument as _TCCA from .functions import Function from .selectable import CTE @@ -55,6 +55,9 @@ if TYPE_CHECKING: from .selectable import SelectBase +_T = TypeVar("_T", bound=Any) + + def alias( selectable: FromClause, name: Optional[str] = None, flat: bool = False ) -> NamedFromClause: @@ -103,28 +106,9 @@ def cte( ) -# TODO: mypy requires the _TypedSelectable overloads in all compound select -# constructors since _SelectStatementForCompoundArgument includes -# untyped args that make it return CompoundSelect[Unpack[tuple[Never, ...]]] -# pyright does not have this issue -_TypedSelectable = Union["Select[_TP]", "CompoundSelect[_TP]"] - - -@overload def except_( - *selects: _TypedSelectable[_TP], -) -> CompoundSelect[_TP]: ... - - -@overload -def except_( - *selects: _SelectStatementForCompoundArgument[_TP], -) -> CompoundSelect[_TP]: ... - - -def except_( - *selects: _SelectStatementForCompoundArgument[_TP], -) -> CompoundSelect[_TP]: + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return an ``EXCEPT`` of multiple selectables. 
The returned object is an instance of @@ -137,21 +121,9 @@ def except_( return CompoundSelect._create_except(*selects) -@overload def except_all( - *selects: _TypedSelectable[_TP], -) -> CompoundSelect[_TP]: ... - - -@overload -def except_all( - *selects: _SelectStatementForCompoundArgument[_TP], -) -> CompoundSelect[_TP]: ... - - -def except_all( - *selects: _SelectStatementForCompoundArgument[_TP], -) -> CompoundSelect[_TP]: + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return an ``EXCEPT ALL`` of multiple selectables. The returned object is an instance of @@ -183,16 +155,16 @@ def exists( :meth:`_sql.SelectBase.exists` method:: exists_criteria = ( - select(table2.c.col2).where(table1.c.col1 == table2.c.col2).exists() + select(table2.c.col2). + where(table1.c.col1 == table2.c.col2). + exists() ) The EXISTS criteria is then used inside of an enclosing SELECT:: stmt = select(table1.c.col1).where(exists_criteria) - The above statement will then be of the form: - - .. sourcecode:: sql + The above statement will then be of the form:: SELECT col1 FROM table1 WHERE EXISTS (SELECT table2.col2 FROM table2 WHERE table2.col2 = table1.col1) @@ -209,21 +181,9 @@ def exists( return Exists(__argument) -@overload def intersect( - *selects: _TypedSelectable[_TP], -) -> CompoundSelect[_TP]: ... - - -@overload -def intersect( - *selects: _SelectStatementForCompoundArgument[_TP], -) -> CompoundSelect[_TP]: ... - - -def intersect( - *selects: _SelectStatementForCompoundArgument[_TP], -) -> CompoundSelect[_TP]: + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return an ``INTERSECT`` of multiple selectables. The returned object is an instance of @@ -236,21 +196,9 @@ def intersect( return CompoundSelect._create_intersect(*selects) -@overload def intersect_all( - *selects: _TypedSelectable[_TP], -) -> CompoundSelect[_TP]: ... - - -@overload -def intersect_all( - *selects: _SelectStatementForCompoundArgument[_TP], -) -> CompoundSelect[_TP]: ... - - -def intersect_all( - *selects: _SelectStatementForCompoundArgument[_TP], -) -> CompoundSelect[_TP]: + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return an ``INTERSECT ALL`` of multiple selectables. The returned object is an instance of @@ -277,14 +225,11 @@ def join( E.g.:: - j = join( - user_table, address_table, user_table.c.id == address_table.c.user_id - ) + j = join(user_table, address_table, + user_table.c.id == address_table.c.user_id) stmt = select(user_table).select_from(j) - would emit SQL along the lines of: - - .. sourcecode:: sql + would emit SQL along the lines of:: SELECT user.id, user.name FROM user JOIN address ON user.id = address.user_id @@ -318,7 +263,7 @@ def join( :class:`_expression.Join` - the type of object produced. - """ # noqa: E501 + """ return Join(left, right, onclause, isouter, full) @@ -385,19 +330,20 @@ def outerjoin( @overload -def select(__ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]: ... +def select(__ent0: _TCCA[_T0]) -> Select[Tuple[_T0]]: + ... @overload -def select( - __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] -) -> Select[Tuple[_T0, _T1]]: ... +def select(__ent0: _TCCA[_T0], __ent1: _TCCA[_T1]) -> Select[Tuple[_T0, _T1]]: + ... @overload def select( __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] -) -> Select[Tuple[_T0, _T1, _T2]]: ... +) -> Select[Tuple[_T0, _T1, _T2]]: + ... @overload @@ -406,7 +352,8 @@ def select( __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], -) -> Select[Tuple[_T0, _T1, _T2, _T3]]: ... 
+) -> Select[Tuple[_T0, _T1, _T2, _T3]]: + ... @overload @@ -416,7 +363,8 @@ def select( __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: ... +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4]]: + ... @overload @@ -427,7 +375,8 @@ def select( __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: ... +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5]]: + ... @overload @@ -439,7 +388,8 @@ def select( __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: ... +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6]]: + ... @overload @@ -452,7 +402,8 @@ def select( __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: ... +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]]: + ... @overload @@ -466,7 +417,8 @@ def select( __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], __ent8: _TCCA[_T8], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]]: ... +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]]: + ... @overload @@ -481,16 +433,16 @@ def select( __ent7: _TCCA[_T7], __ent8: _TCCA[_T8], __ent9: _TCCA[_T9], -) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]]: ... +) -> Select[Tuple[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9]]: + ... # END OVERLOADED FUNCTIONS select @overload -def select( - *entities: _ColumnsClauseArgument[Any], **__kw: Any -) -> Select[Any]: ... +def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]: + ... def select(*entities: _ColumnsClauseArgument[Any], **__kw: Any) -> Select[Any]: @@ -584,14 +536,13 @@ def tablesample( from sqlalchemy import func selectable = people.tablesample( - func.bernoulli(1), name="alias", seed=func.random() - ) + func.bernoulli(1), + name='alias', + seed=func.random()) stmt = select(selectable.c.people_id) Assuming ``people`` with a column ``people_id``, the above - statement would render as: - - .. sourcecode:: sql + statement would render as:: SELECT alias.people_id FROM people AS alias TABLESAMPLE bernoulli(:bernoulli_1) @@ -609,21 +560,9 @@ def tablesample( return TableSample._factory(selectable, sampling, name=name, seed=seed) -@overload def union( - *selects: _TypedSelectable[_TP], -) -> CompoundSelect[_TP]: ... - - -@overload -def union( - *selects: _SelectStatementForCompoundArgument[_TP], -) -> CompoundSelect[_TP]: ... - - -def union( - *selects: _SelectStatementForCompoundArgument[_TP], -) -> CompoundSelect[_TP]: + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return a ``UNION`` of multiple selectables. The returned object is an instance of @@ -643,21 +582,9 @@ def union( return CompoundSelect._create_union(*selects) -@overload def union_all( - *selects: _TypedSelectable[_TP], -) -> CompoundSelect[_TP]: ... - - -@overload -def union_all( - *selects: _SelectStatementForCompoundArgument[_TP], -) -> CompoundSelect[_TP]: ... - - -def union_all( - *selects: _SelectStatementForCompoundArgument[_TP], -) -> CompoundSelect[_TP]: + *selects: _SelectStatementForCompoundArgument, +) -> CompoundSelect: r"""Return a ``UNION ALL`` of multiple selectables. The returned object is an instance of @@ -678,75 +605,28 @@ def values( name: Optional[str] = None, literal_binds: bool = False, ) -> Values: - r"""Construct a :class:`_expression.Values` construct representing the - SQL ``VALUES`` clause. 
+ r"""Construct a :class:`_expression.Values` construct. - - The column expressions and the actual data for :class:`_expression.Values` - are given in two separate steps. The constructor receives the column - expressions typically as :func:`_expression.column` constructs, and the - data is then passed via the :meth:`_expression.Values.data` method as a - list, which can be called multiple times to add more data, e.g.:: + The column expressions and the actual data for + :class:`_expression.Values` are given in two separate steps. The + constructor receives the column expressions typically as + :func:`_expression.column` constructs, + and the data is then passed via the + :meth:`_expression.Values.data` method as a list, + which can be called multiple + times to add more data, e.g.:: from sqlalchemy import column from sqlalchemy import values - from sqlalchemy import Integer - from sqlalchemy import String - - value_expr = ( - values( - column("id", Integer), - column("name", String), - ) - .data([(1, "name1"), (2, "name2")]) - .data([(3, "name3")]) - ) - - Would represent a SQL fragment like:: - - VALUES(1, "name1"), (2, "name2"), (3, "name3") - - The :class:`_sql.values` construct has an optional - :paramref:`_sql.values.name` field; when using this field, the - PostgreSQL-specific "named VALUES" clause may be generated:: value_expr = values( - column("id", Integer), column("name", String), name="somename" - ).data([(1, "name1"), (2, "name2"), (3, "name3")]) - - When selecting from the above construct, the name and column names will - be listed out using a PostgreSQL-specific syntax:: - - >>> print(value_expr.select()) - SELECT somename.id, somename.name - FROM (VALUES (:param_1, :param_2), (:param_3, :param_4), - (:param_5, :param_6)) AS somename (id, name) - - For a more database-agnostic means of SELECTing named columns from a - VALUES expression, the :meth:`.Values.cte` method may be used, which - produces a named CTE with explicit column names against the VALUES - construct within; this syntax works on PostgreSQL, SQLite, and MariaDB:: - - value_expr = ( - values( - column("id", Integer), - column("name", String), - ) - .data([(1, "name1"), (2, "name2"), (3, "name3")]) - .cte() + column('id', Integer), + column('name', String), + name="my_values" + ).data( + [(1, 'name1'), (2, 'name2'), (3, 'name3')] ) - Rendering as:: - - >>> print(value_expr.select()) - WITH anon_1(id, name) AS - (VALUES (:param_1, :param_2), (:param_3, :param_4), (:param_5, :param_6)) - SELECT anon_1.id, anon_1.name - FROM anon_1 - - .. versionadded:: 2.0.42 Added the :meth:`.Values.cte` method to - :class:`.Values` - :param \*columns: column expressions, typically composed using :func:`_expression.column` objects. @@ -758,6 +638,5 @@ def values( the data values inline in the SQL output, rather than using bound parameters. 
- """ # noqa: E501 - + """ return Values(*columns, literal_binds=literal_binds, name=name) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/sql/_typing.py b/venv/lib/python3.12/site-packages/sqlalchemy/sql/_typing.py index 8e3c66e..c9e1830 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/sql/_typing.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/sql/_typing.py @@ -1,5 +1,5 @@ # sql/_typing.py -# Copyright (C) 2022-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2022 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -11,8 +11,6 @@ import operator from typing import Any from typing import Callable from typing import Dict -from typing import Generic -from typing import Iterable from typing import Mapping from typing import NoReturn from typing import Optional @@ -53,10 +51,10 @@ if TYPE_CHECKING: from .elements import SQLCoreOperations from .elements import TextClause from .lambdas import LambdaElement + from .roles import ColumnsClauseRole from .roles import FromClauseRole from .schema import Column from .selectable import Alias - from .selectable import CompoundSelect from .selectable import CTE from .selectable import FromClause from .selectable import Join @@ -70,14 +68,9 @@ if TYPE_CHECKING: from .sqltypes import TableValueType from .sqltypes import TupleType from .type_api import TypeEngine - from ..engine import Connection - from ..engine import Dialect - from ..engine import Engine - from ..engine.mock import MockConnection from ..util.typing import TypeGuard _T = TypeVar("_T", bound=Any) -_T_co = TypeVar("_T_co", bound=Any, covariant=True) _CE = TypeVar("_CE", bound="ColumnElement[Any]") @@ -85,25 +78,18 @@ _CE = TypeVar("_CE", bound="ColumnElement[Any]") _CLE = TypeVar("_CLE", bound="ClauseElement") -class _HasClauseElement(Protocol, Generic[_T_co]): +class _HasClauseElement(Protocol): """indicates a class that has a __clause_element__() method""" - def __clause_element__(self) -> roles.ExpressionElementRole[_T_co]: ... + def __clause_element__(self) -> ColumnsClauseRole: + ... class _CoreAdapterProto(Protocol): """protocol for the ClauseAdapter/ColumnAdapter.traverse() method.""" - def __call__(self, obj: _CE) -> _CE: ... - - -class _HasDialect(Protocol): - """protocol for Engine/Connection-like objects that have dialect - attribute. - """ - - @property - def dialect(self) -> Dialect: ... + def __call__(self, obj: _CE) -> _CE: + ... # match column types that are not ORM entities @@ -111,7 +97,6 @@ _NOT_ENTITY = TypeVar( "_NOT_ENTITY", int, str, - bool, "datetime", "date", "time", @@ -121,15 +106,13 @@ _NOT_ENTITY = TypeVar( "Decimal", ) -_StarOrOne = Literal["*", 1] - _MAYBE_ENTITY = TypeVar( "_MAYBE_ENTITY", roles.ColumnsClauseRole, - _StarOrOne, + Literal["*", 1], Type[Any], - Inspectable[_HasClauseElement[Any]], - _HasClauseElement[Any], + Inspectable[_HasClauseElement], + _HasClauseElement, ) @@ -143,7 +126,7 @@ _TextCoercedExpressionArgument = Union[ str, "TextClause", "ColumnElement[_T]", - _HasClauseElement[_T], + _HasClauseElement, roles.ExpressionElementRole[_T], ] @@ -151,10 +134,10 @@ _ColumnsClauseArgument = Union[ roles.TypedColumnsClauseRole[_T], roles.ColumnsClauseRole, "SQLCoreOperations[_T]", - _StarOrOne, + Literal["*", 1], Type[_T], - Inspectable[_HasClauseElement[_T]], - _HasClauseElement[_T], + Inspectable[_HasClauseElement], + _HasClauseElement, ] """open-ended SELECT columns clause argument. 
@@ -188,10 +171,9 @@ _T9 = TypeVar("_T9", bound=Any) _ColumnExpressionArgument = Union[ "ColumnElement[_T]", - _HasClauseElement[_T], + _HasClauseElement, "SQLCoreOperations[_T]", roles.ExpressionElementRole[_T], - roles.TypedColumnsClauseRole[_T], Callable[[], "ColumnElement[_T]"], "LambdaElement", ] @@ -216,12 +198,6 @@ _ColumnExpressionOrLiteralArgument = Union[Any, _ColumnExpressionArgument[_T]] _ColumnExpressionOrStrLabelArgument = Union[str, _ColumnExpressionArgument[_T]] -_ByArgument = Union[ - Iterable[_ColumnExpressionOrStrLabelArgument[Any]], - _ColumnExpressionOrStrLabelArgument[Any], -] -"""Used for keyword-based ``order_by`` and ``partition_by`` parameters.""" - _InfoType = Dict[Any, Any] """the .info dictionary accepted and used throughout Core /ORM""" @@ -229,8 +205,8 @@ _InfoType = Dict[Any, Any] _FromClauseArgument = Union[ roles.FromClauseRole, Type[Any], - Inspectable[_HasClauseElement[Any]], - _HasClauseElement[Any], + Inspectable[_HasClauseElement], + _HasClauseElement, ] """A FROM clause, like we would send to select().select_from(). @@ -251,15 +227,13 @@ come from the ORM. """ _SelectStatementForCompoundArgument = Union[ - "Select[_TP]", - "CompoundSelect[_TP]", - roles.CompoundElementRole, + "SelectBase", roles.CompoundElementRole ] """SELECT statement acceptable by ``union()`` and other SQL set operations""" _DMLColumnArgument = Union[ str, - _HasClauseElement[Any], + _HasClauseElement, roles.DMLColumnRole, "SQLCoreOperations[Any]", ] @@ -290,8 +264,8 @@ _DMLTableArgument = Union[ "Alias", "CTE", Type[Any], - Inspectable[_HasClauseElement[Any]], - _HasClauseElement[Any], + Inspectable[_HasClauseElement], + _HasClauseElement, ] _PropagateAttrsType = util.immutabledict[str, Any] @@ -304,51 +278,58 @@ _LimitOffsetType = Union[int, _ColumnExpressionArgument[int], None] _AutoIncrementType = Union[bool, Literal["auto", "ignore_fk"]] -_CreateDropBind = Union["Engine", "Connection", "MockConnection"] - if TYPE_CHECKING: - def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: ... + def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: + ... - def is_ddl_compiler(c: Compiled) -> TypeGuard[DDLCompiler]: ... + def is_ddl_compiler(c: Compiled) -> TypeGuard[DDLCompiler]: + ... - def is_named_from_clause( - t: FromClauseRole, - ) -> TypeGuard[NamedFromClause]: ... + def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]: + ... - def is_column_element( - c: ClauseElement, - ) -> TypeGuard[ColumnElement[Any]]: ... + def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]: + ... def is_keyed_column_element( c: ClauseElement, - ) -> TypeGuard[KeyedColumnElement[Any]]: ... + ) -> TypeGuard[KeyedColumnElement[Any]]: + ... - def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: ... + def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: + ... - def is_from_clause(c: ClauseElement) -> TypeGuard[FromClause]: ... + def is_from_clause(c: ClauseElement) -> TypeGuard[FromClause]: + ... - def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: ... + def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: + ... - def is_table_value_type( - t: TypeEngine[Any], - ) -> TypeGuard[TableValueType]: ... + def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]: + ... - def is_selectable(t: Any) -> TypeGuard[Selectable]: ... + def is_selectable(t: Any) -> TypeGuard[Selectable]: + ... def is_select_base( - t: Union[Executable, ReturnsRows], - ) -> TypeGuard[SelectBase]: ... 
+ t: Union[Executable, ReturnsRows] + ) -> TypeGuard[SelectBase]: + ... def is_select_statement( - t: Union[Executable, ReturnsRows], - ) -> TypeGuard[Select[Any]]: ... + t: Union[Executable, ReturnsRows] + ) -> TypeGuard[Select[Any]]: + ... - def is_table(t: FromClause) -> TypeGuard[TableClause]: ... + def is_table(t: FromClause) -> TypeGuard[TableClause]: + ... - def is_subquery(t: FromClause) -> TypeGuard[Subquery]: ... + def is_subquery(t: FromClause) -> TypeGuard[Subquery]: + ... - def is_dml(c: ClauseElement) -> TypeGuard[UpdateBase]: ... + def is_dml(c: ClauseElement) -> TypeGuard[UpdateBase]: + ... else: is_sql_compiler = operator.attrgetter("is_sql") @@ -376,7 +357,7 @@ def is_quoted_name(s: str) -> TypeGuard[quoted_name]: return hasattr(s, "quote") -def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement[Any]]: +def is_has_clause_element(s: object) -> TypeGuard[_HasClauseElement]: return hasattr(s, "__clause_element__") @@ -399,17 +380,20 @@ def _unexpected_kw(methname: str, kw: Dict[str, Any]) -> NoReturn: @overload def Nullable( val: "SQLCoreOperations[_T]", -) -> "SQLCoreOperations[Optional[_T]]": ... +) -> "SQLCoreOperations[Optional[_T]]": + ... @overload def Nullable( val: roles.ExpressionElementRole[_T], -) -> roles.ExpressionElementRole[Optional[_T]]: ... +) -> roles.ExpressionElementRole[Optional[_T]]: + ... @overload -def Nullable(val: Type[_T]) -> Type[Optional[_T]]: ... +def Nullable(val: Type[_T]) -> Type[Optional[_T]]: + ... def Nullable( @@ -433,21 +417,25 @@ def Nullable( @overload def NotNullable( val: "SQLCoreOperations[Optional[_T]]", -) -> "SQLCoreOperations[_T]": ... +) -> "SQLCoreOperations[_T]": + ... @overload def NotNullable( val: roles.ExpressionElementRole[Optional[_T]], -) -> roles.ExpressionElementRole[_T]: ... +) -> roles.ExpressionElementRole[_T]: + ... @overload -def NotNullable(val: Type[Optional[_T]]) -> Type[_T]: ... +def NotNullable(val: Type[Optional[_T]]) -> Type[_T]: + ... @overload -def NotNullable(val: Optional[Type[_T]]) -> Type[_T]: ... +def NotNullable(val: Optional[Type[_T]]) -> Type[_T]: + ... def NotNullable( diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/sql/annotation.py b/venv/lib/python3.12/site-packages/sqlalchemy/sql/annotation.py index bf445ff..08ff47d 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/sql/annotation.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/sql/annotation.py @@ -1,5 +1,5 @@ # sql/annotation.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -67,14 +67,16 @@ class SupportsAnnotations(ExternallyTraversible): self, values: Literal[None] = ..., clone: bool = ..., - ) -> Self: ... + ) -> Self: + ... @overload def _deannotate( self, values: Sequence[str] = ..., clone: bool = ..., - ) -> SupportsAnnotations: ... + ) -> SupportsAnnotations: + ... def _deannotate( self, @@ -97,11 +99,9 @@ class SupportsAnnotations(ExternallyTraversible): tuple( ( key, - ( - value._gen_cache_key(anon_map, []) - if isinstance(value, HasCacheKey) - else value - ), + value._gen_cache_key(anon_map, []) + if isinstance(value, HasCacheKey) + else value, ) for key, value in [ (key, self._annotations[key]) @@ -119,7 +119,8 @@ class SupportsWrappingAnnotations(SupportsAnnotations): if TYPE_CHECKING: @util.ro_non_memoized_property - def entity_namespace(self) -> _EntityNamespace: ... + def entity_namespace(self) -> _EntityNamespace: + ... 
def _annotate(self, values: _AnnotationDict) -> Self: """return a copy of this ClauseElement with annotations @@ -140,14 +141,16 @@ class SupportsWrappingAnnotations(SupportsAnnotations): self, values: Literal[None] = ..., clone: bool = ..., - ) -> Self: ... + ) -> Self: + ... @overload def _deannotate( self, values: Sequence[str] = ..., clone: bool = ..., - ) -> SupportsAnnotations: ... + ) -> SupportsAnnotations: + ... def _deannotate( self, @@ -211,14 +214,16 @@ class SupportsCloneAnnotations(SupportsWrappingAnnotations): self, values: Literal[None] = ..., clone: bool = ..., - ) -> Self: ... + ) -> Self: + ... @overload def _deannotate( self, values: Sequence[str] = ..., clone: bool = ..., - ) -> SupportsAnnotations: ... + ) -> SupportsAnnotations: + ... def _deannotate( self, @@ -311,14 +316,16 @@ class Annotated(SupportsAnnotations): self, values: Literal[None] = ..., clone: bool = ..., - ) -> Self: ... + ) -> Self: + ... @overload def _deannotate( self, values: Sequence[str] = ..., clone: bool = ..., - ) -> Annotated: ... + ) -> Annotated: + ... def _deannotate( self, @@ -388,9 +395,9 @@ class Annotated(SupportsAnnotations): # so that the resulting objects are pickleable; additionally, other # decisions can be made up front about the type of object being annotated # just once per class rather than per-instance. -annotated_classes: Dict[Type[SupportsWrappingAnnotations], Type[Annotated]] = ( - {} -) +annotated_classes: Dict[ + Type[SupportsWrappingAnnotations], Type[Annotated] +] = {} _SA = TypeVar("_SA", bound="SupportsAnnotations") @@ -480,13 +487,15 @@ def _deep_annotate( @overload def _deep_deannotate( element: Literal[None], values: Optional[Sequence[str]] = None -) -> Literal[None]: ... +) -> Literal[None]: + ... @overload def _deep_deannotate( element: _SA, values: Optional[Sequence[str]] = None -) -> _SA: ... +) -> _SA: + ... def _deep_deannotate( diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/sql/base.py b/venv/lib/python3.12/site-packages/sqlalchemy/sql/base.py index 21c2201..104c595 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/sql/base.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/sql/base.py @@ -1,12 +1,14 @@ # sql/base.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: allow-untyped-defs, allow-untyped-calls -"""Foundational utilities common to many sql modules.""" +"""Foundational utilities common to many sql modules. + +""" from __future__ import annotations @@ -22,7 +24,6 @@ from typing import Callable from typing import cast from typing import Dict from typing import FrozenSet -from typing import Generator from typing import Generic from typing import Iterable from typing import Iterator @@ -56,7 +57,6 @@ from .. 
import util from ..util import HasMemoized as HasMemoized from ..util import hybridmethod from ..util import typing as compat_typing -from ..util.typing import Final from ..util.typing import Protocol from ..util.typing import Self from ..util.typing import TypeGuard @@ -68,12 +68,11 @@ if TYPE_CHECKING: from ._orm_types import DMLStrategyArgument from ._orm_types import SynchronizeSessionArgument from ._typing import _CLE - from .cache_key import CacheKey - from .compiler import SQLCompiler from .elements import BindParameter from .elements import ClauseList from .elements import ColumnClause # noqa from .elements import ColumnElement + from .elements import KeyedColumnElement from .elements import NamedColumn from .elements import SQLCoreOperations from .elements import TextClause @@ -82,7 +81,6 @@ if TYPE_CHECKING: from .selectable import _JoinTargetElement from .selectable import _SelectIterable from .selectable import FromClause - from .visitors import anon_map from ..engine import Connection from ..engine import CursorResult from ..engine.interfaces import _CoreMultiExecuteParams @@ -110,7 +108,7 @@ class _NoArg(Enum): return f"_NoArg.{self.name}" -NO_ARG: Final = _NoArg.NO_ARG +NO_ARG = _NoArg.NO_ARG class _NoneName(Enum): @@ -118,7 +116,7 @@ class _NoneName(Enum): """indicate a 'deferred' name that was ultimately the value None.""" -_NONE_NAME: Final = _NoneName.NONE_NAME +_NONE_NAME = _NoneName.NONE_NAME _T = TypeVar("_T", bound=Any) @@ -153,18 +151,18 @@ class _DefaultDescriptionTuple(NamedTuple): ) -_never_select_column: operator.attrgetter[Any] = operator.attrgetter( - "_omit_from_statements" -) +_never_select_column = operator.attrgetter("_omit_from_statements") class _EntityNamespace(Protocol): - def __getattr__(self, key: str) -> SQLCoreOperations[Any]: ... + def __getattr__(self, key: str) -> SQLCoreOperations[Any]: + ... class _HasEntityNamespace(Protocol): @util.ro_non_memoized_property - def entity_namespace(self) -> _EntityNamespace: ... + def entity_namespace(self) -> _EntityNamespace: + ... def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]: @@ -190,12 +188,12 @@ class Immutable: __slots__ = () - _is_immutable: bool = True + _is_immutable = True - def unique_params(self, *optionaldict: Any, **kwargs: Any) -> NoReturn: + def unique_params(self, *optionaldict, **kwargs): raise NotImplementedError("Immutable objects do not support copying") - def params(self, *optionaldict: Any, **kwargs: Any) -> NoReturn: + def params(self, *optionaldict, **kwargs): raise NotImplementedError("Immutable objects do not support copying") def _clone(self: _Self, **kw: Any) -> _Self: @@ -210,7 +208,7 @@ class Immutable: class SingletonConstant(Immutable): """Represent SQL constants like NULL, TRUE, FALSE""" - _is_singleton_constant: bool = True + _is_singleton_constant = True _singleton: SingletonConstant @@ -222,7 +220,7 @@ class SingletonConstant(Immutable): raise NotImplementedError() @classmethod - def _create_singleton(cls) -> None: + def _create_singleton(cls): obj = object.__new__(cls) obj.__init__() # type: ignore @@ -263,7 +261,8 @@ _SelfGenerativeType = TypeVar("_SelfGenerativeType", bound="_GenerativeType") class _GenerativeType(compat_typing.Protocol): - def _generate(self) -> Self: ... + def _generate(self) -> Self: + ... 
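# Editor's note (not part of the diff): illustrative only -- what the
# "generative" decorator defined just below provides from the caller's point
# of view: builder methods return new statement objects rather than mutating
# in place. The "users" table here is an assumption.
from sqlalchemy import column, select, table

users = table("users", column("id"), column("name"))

base = select(users.c.id)
filtered = base.where(users.c.name == "wendy")

print(base is filtered)          # False: where() generated a copy
print("WHERE" in str(base))      # False: the original statement is unchanged
print("WHERE" in str(filtered))  # True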
def _generative(fn: _Fn) -> _Fn: @@ -291,17 +290,17 @@ def _generative(fn: _Fn) -> _Fn: def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]: - msgs: Dict[str, str] = kw.pop("msgs", {}) + msgs = kw.pop("msgs", {}) - defaults: Dict[str, str] = kw.pop("defaults", {}) + defaults = kw.pop("defaults", {}) - getters: List[Tuple[str, operator.attrgetter[Any], Optional[str]]] = [ + getters = [ (name, operator.attrgetter(name), defaults.get(name, None)) for name in names ] @util.decorator - def check(fn: _Fn, *args: Any, **kw: Any) -> Any: + def check(fn, *args, **kw): # make pylance happy by not including "self" in the argument # list self = args[0] @@ -350,16 +349,12 @@ def _cloned_intersection(a: Iterable[_CLE], b: Iterable[_CLE]) -> Set[_CLE]: The returned set is in terms of the entities present within 'a'. """ - all_overlap: Set[_CLE] = set(_expand_cloned(a)).intersection( - _expand_cloned(b) - ) + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) return {elem for elem in a if all_overlap.intersection(elem._cloned_set)} def _cloned_difference(a: Iterable[_CLE], b: Iterable[_CLE]) -> Set[_CLE]: - all_overlap: Set[_CLE] = set(_expand_cloned(a)).intersection( - _expand_cloned(b) - ) + all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b)) return { elem for elem in a if not all_overlap.intersection(elem._cloned_set) } @@ -371,12 +366,10 @@ class _DialectArgView(MutableMapping[str, Any]): """ - __slots__ = ("obj",) - - def __init__(self, obj: DialectKWArgs) -> None: + def __init__(self, obj): self.obj = obj - def _key(self, key: str) -> Tuple[str, str]: + def _key(self, key): try: dialect, value_key = key.split("_", 1) except ValueError as err: @@ -384,7 +377,7 @@ class _DialectArgView(MutableMapping[str, Any]): else: return dialect, value_key - def __getitem__(self, key: str) -> Any: + def __getitem__(self, key): dialect, value_key = self._key(key) try: @@ -394,7 +387,7 @@ class _DialectArgView(MutableMapping[str, Any]): else: return opt[value_key] - def __setitem__(self, key: str, value: Any) -> None: + def __setitem__(self, key, value): try: dialect, value_key = self._key(key) except KeyError as err: @@ -404,17 +397,17 @@ class _DialectArgView(MutableMapping[str, Any]): else: self.obj.dialect_options[dialect][value_key] = value - def __delitem__(self, key: str) -> None: + def __delitem__(self, key): dialect, value_key = self._key(key) del self.obj.dialect_options[dialect][value_key] - def __len__(self) -> int: + def __len__(self): return sum( len(args._non_defaults) for args in self.obj.dialect_options.values() ) - def __iter__(self) -> Generator[str, None, None]: + def __iter__(self): return ( "%s_%s" % (dialect_name, value_name) for dialect_name in self.obj.dialect_options @@ -433,31 +426,31 @@ class _DialectArgDict(MutableMapping[str, Any]): """ - def __init__(self) -> None: - self._non_defaults: Dict[str, Any] = {} - self._defaults: Dict[str, Any] = {} + def __init__(self): + self._non_defaults = {} + self._defaults = {} - def __len__(self) -> int: + def __len__(self): return len(set(self._non_defaults).union(self._defaults)) - def __iter__(self) -> Iterator[str]: + def __iter__(self): return iter(set(self._non_defaults).union(self._defaults)) - def __getitem__(self, key: str) -> Any: + def __getitem__(self, key): if key in self._non_defaults: return self._non_defaults[key] else: return self._defaults[key] - def __setitem__(self, key: str, value: Any) -> None: + def __setitem__(self, key, value): self._non_defaults[key] = value - def 
__delitem__(self, key: str) -> None: + def __delitem__(self, key): del self._non_defaults[key] @util.preload_module("sqlalchemy.dialects") -def _kw_reg_for_dialect(dialect_name: str) -> Optional[Dict[Any, Any]]: +def _kw_reg_for_dialect(dialect_name): dialect_cls = util.preloaded.dialects.registry.load(dialect_name) if dialect_cls.construct_arguments is None: return None @@ -479,21 +472,19 @@ class DialectKWArgs: __slots__ = () - _dialect_kwargs_traverse_internals: List[Tuple[str, Any]] = [ + _dialect_kwargs_traverse_internals = [ ("dialect_options", InternalTraversal.dp_dialect_options) ] @classmethod - def argument_for( - cls, dialect_name: str, argument_name: str, default: Any - ) -> None: + def argument_for(cls, dialect_name, argument_name, default): """Add a new kind of dialect-specific keyword argument for this class. E.g.:: Index.argument_for("mydialect", "length", None) - some_index = Index("a", "b", mydialect_length=5) + some_index = Index('a', 'b', mydialect_length=5) The :meth:`.DialectKWArgs.argument_for` method is a per-argument way adding extra arguments to the @@ -523,9 +514,7 @@ class DialectKWArgs: """ - construct_arg_dictionary: Optional[Dict[Any, Any]] = ( - DialectKWArgs._kw_registry[dialect_name] - ) + construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name] if construct_arg_dictionary is None: raise exc.ArgumentError( "Dialect '%s' does have keyword-argument " @@ -535,8 +524,8 @@ class DialectKWArgs: construct_arg_dictionary[cls] = {} construct_arg_dictionary[cls][argument_name] = default - @property - def dialect_kwargs(self) -> _DialectArgView: + @util.memoized_property + def dialect_kwargs(self): """A collection of keyword arguments specified as dialect-specific options to this construct. @@ -557,29 +546,26 @@ class DialectKWArgs: return _DialectArgView(self) @property - def kwargs(self) -> _DialectArgView: + def kwargs(self): """A synonym for :attr:`.DialectKWArgs.dialect_kwargs`.""" return self.dialect_kwargs - _kw_registry: util.PopulateDict[str, Optional[Dict[Any, Any]]] = ( - util.PopulateDict(_kw_reg_for_dialect) - ) + _kw_registry = util.PopulateDict(_kw_reg_for_dialect) - @classmethod - def _kw_reg_for_dialect_cls(cls, dialect_name: str) -> _DialectArgDict: + def _kw_reg_for_dialect_cls(self, dialect_name): construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name] d = _DialectArgDict() if construct_arg_dictionary is None: d._defaults.update({"*": None}) else: - for cls in reversed(cls.__mro__): + for cls in reversed(self.__class__.__mro__): if cls in construct_arg_dictionary: d._defaults.update(construct_arg_dictionary[cls]) return d @util.memoized_property - def dialect_options(self) -> util.PopulateDict[str, _DialectArgDict]: + def dialect_options(self): """A collection of keyword arguments specified as dialect-specific options to this construct. @@ -587,7 +573,7 @@ class DialectKWArgs: and ````. For example, the ``postgresql_where`` argument would be locatable as:: - arg = my_object.dialect_options["postgresql"]["where"] + arg = my_object.dialect_options['postgresql']['where'] .. 
versionadded:: 0.9.2 @@ -597,7 +583,9 @@ class DialectKWArgs: """ - return util.PopulateDict(self._kw_reg_for_dialect_cls) + return util.PopulateDict( + util.portable_instancemethod(self._kw_reg_for_dialect_cls) + ) def _validate_dialect_kwargs(self, kwargs: Dict[str, Any]) -> None: # validate remaining kwargs that they all specify DB prefixes @@ -673,9 +661,7 @@ class CompileState: _ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] @classmethod - def create_for_statement( - cls, statement: Executable, compiler: SQLCompiler, **kw: Any - ) -> CompileState: + def create_for_statement(cls, statement, compiler, **kw): # factory construction. if statement._propagate_attrs: @@ -815,11 +801,14 @@ class _MetaOptions(type): if TYPE_CHECKING: - def __getattr__(self, key: str) -> Any: ... + def __getattr__(self, key: str) -> Any: + ... - def __setattr__(self, key: str, value: Any) -> None: ... + def __setattr__(self, key: str, value: Any) -> None: + ... - def __delattr__(self, key: str) -> None: ... + def __delattr__(self, key: str) -> None: + ... class Options(metaclass=_MetaOptions): @@ -841,7 +830,7 @@ class Options(metaclass=_MetaOptions): ) super().__init_subclass__() - def __init__(self, **kw: Any) -> None: + def __init__(self, **kw): self.__dict__.update(kw) def __add__(self, other): @@ -866,7 +855,7 @@ class Options(metaclass=_MetaOptions): return False return True - def __repr__(self) -> str: + def __repr__(self): # TODO: fairly inefficient, used only in debugging right now. return "%s(%s)" % ( @@ -883,7 +872,7 @@ class Options(metaclass=_MetaOptions): return issubclass(cls, klass) @hybridmethod - def add_to_element(self, name: str, value: str) -> Any: + def add_to_element(self, name, value): return self + {name: getattr(self, name) + value} @hybridmethod @@ -897,7 +886,7 @@ class Options(metaclass=_MetaOptions): return cls._state_dict_const @classmethod - def safe_merge(cls, other: "Options") -> Any: + def safe_merge(cls, other): d = other._state_dict() # only support a merge with another object of our class @@ -923,12 +912,8 @@ class Options(metaclass=_MetaOptions): @classmethod def from_execution_options( - cls, - key: str, - attrs: set[str], - exec_options: Mapping[str, Any], - statement_exec_options: Mapping[str, Any], - ) -> Tuple["Options", Mapping[str, Any]]: + cls, key, attrs, exec_options, statement_exec_options + ): """process Options argument in terms of execution options. @@ -939,7 +924,11 @@ class Options(metaclass=_MetaOptions): execution_options, ) = QueryContext.default_load_options.from_execution_options( "_sa_orm_load_options", - {"populate_existing", "autoflush", "yield_per"}, + { + "populate_existing", + "autoflush", + "yield_per" + }, execution_options, statement._execution_options, ) @@ -967,8 +956,8 @@ class Options(metaclass=_MetaOptions): result[local] = statement_exec_options[argname] new_options = existing_options + result - exec_options = util.immutabledict(exec_options).merge_with( - {key: new_options} + exec_options = util.immutabledict().merge_with( + exec_options, {key: new_options} ) return new_options, exec_options @@ -977,43 +966,42 @@ class Options(metaclass=_MetaOptions): if TYPE_CHECKING: - def __getattr__(self, key: str) -> Any: ... + def __getattr__(self, key: str) -> Any: + ... - def __setattr__(self, key: str, value: Any) -> None: ... + def __setattr__(self, key: str, value: Any) -> None: + ... - def __delattr__(self, key: str) -> None: ... + def __delattr__(self, key: str) -> None: + ... 
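# Editor's note (not part of the diff): illustrative only -- the
# dialect_kwargs / dialect_options accessors of DialectKWArgs shown earlier in
# this file's diff, using a hypothetical table and a PostgreSQL partial-index
# option.
from sqlalchemy import Column, Index, Integer, MetaData, Table

metadata = MetaData()
t = Table("t", metadata, Column("x", Integer), Column("y", Integer))

# dialect-specific keyword arguments follow the <dialectname>_<argname> scheme
idx = Index("ix_t_x", t.c.x, postgresql_where=t.c.y > 10)

print(idx.dialect_kwargs["postgresql_where"])      # renders as "t.y > :y_1"
print(idx.dialect_options["postgresql"]["where"])  # same value, nested access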
class CacheableOptions(Options, HasCacheKey): __slots__ = () @hybridmethod - def _gen_cache_key_inst( - self, anon_map: Any, bindparams: List[BindParameter[Any]] - ) -> Optional[Tuple[Any]]: + def _gen_cache_key_inst(self, anon_map, bindparams): return HasCacheKey._gen_cache_key(self, anon_map, bindparams) @_gen_cache_key_inst.classlevel - def _gen_cache_key( - cls, anon_map: "anon_map", bindparams: List[BindParameter[Any]] - ) -> Tuple[CacheableOptions, Any]: + def _gen_cache_key(cls, anon_map, bindparams): return (cls, ()) @hybridmethod - def _generate_cache_key(self) -> Optional[CacheKey]: + def _generate_cache_key(self): return HasCacheKey._generate_cache_key_for_object(self) class ExecutableOption(HasCopyInternals): __slots__ = () - _annotations: _ImmutableExecuteOptions = util.EMPTY_DICT + _annotations = util.EMPTY_DICT - __visit_name__: str = "executable_option" + __visit_name__ = "executable_option" - _is_has_cache_key: bool = False + _is_has_cache_key = False - _is_core: bool = True + _is_core = True def _clone(self, **kw): """Create a shallow copy of this ExecutableOption.""" @@ -1033,7 +1021,7 @@ class Executable(roles.StatementRole): supports_execution: bool = True _execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT - _is_default_generator: bool = False + _is_default_generator = False _with_options: Tuple[ExecutableOption, ...] = () _with_context_options: Tuple[ Tuple[Callable[[CompileState], None], Any], ... @@ -1049,13 +1037,12 @@ class Executable(roles.StatementRole): ("_propagate_attrs", ExtendedInternalTraversal.dp_propagate_attrs), ] - is_select: bool = False - is_from_statement: bool = False - is_update: bool = False - is_insert: bool = False - is_text: bool = False - is_delete: bool = False - is_dml: bool = False + is_select = False + is_update = False + is_insert = False + is_text = False + is_delete = False + is_dml = False if TYPE_CHECKING: __visit_name__: str @@ -1071,24 +1058,27 @@ class Executable(roles.StatementRole): **kw: Any, ) -> Tuple[ Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats - ]: ... + ]: + ... def _execute_on_connection( self, connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: CoreExecuteOptionsParameter, - ) -> CursorResult[Any]: ... + ) -> CursorResult[Any]: + ... def _execute_on_scalar( self, connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: CoreExecuteOptionsParameter, - ) -> Any: ... + ) -> Any: + ... @util.ro_non_memoized_property - def _all_selected_columns(self) -> _SelectIterable: + def _all_selected_columns(self): raise NotImplementedError() @property @@ -1189,12 +1179,13 @@ class Executable(roles.StatementRole): render_nulls: bool = ..., is_delete_using: bool = ..., is_update_from: bool = ..., - preserve_rowcount: bool = False, **opt: Any, - ) -> Self: ... + ) -> Self: + ... @overload - def execution_options(self, **opt: Any) -> Self: ... + def execution_options(self, **opt: Any) -> Self: + ... 
@_generative def execution_options(self, **kw: Any) -> Self: @@ -1246,7 +1237,6 @@ class Executable(roles.StatementRole): from sqlalchemy import event - @event.listens_for(some_engine, "before_execute") def _process_opt(conn, statement, multiparams, params, execution_options): "run a SQL function before invoking a statement" @@ -1348,21 +1338,10 @@ class SchemaEventTarget(event.EventTarget): self.dispatch.after_parent_attach(self, parent) -class SchemaVisitable(SchemaEventTarget, visitors.Visitable): - """Base class for elements that are targets of a :class:`.SchemaVisitor`. - - .. versionadded:: 2.0.41 - - """ - - class SchemaVisitor(ClauseVisitor): - """Define the visiting for ``SchemaItem`` and more - generally ``SchemaVisitable`` objects. + """Define the visiting for ``SchemaItem`` objects.""" - """ - - __traverse_options__: Dict[str, Any] = {"schema_visitor": True} + __traverse_options__ = {"schema_visitor": True} class _SentinelDefaultCharacterization(Enum): @@ -1387,7 +1366,7 @@ class _SentinelColumnCharacterization(NamedTuple): _COLKEY = TypeVar("_COLKEY", Union[None, str], str) _COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True) -_COL = TypeVar("_COL", bound="ColumnElement[Any]") +_COL = TypeVar("_COL", bound="KeyedColumnElement[Any]") class _ColumnMetrics(Generic[_COL_co]): @@ -1397,7 +1376,7 @@ class _ColumnMetrics(Generic[_COL_co]): def __init__( self, collection: ColumnCollection[Any, _COL_co], col: _COL_co - ) -> None: + ): self.column = col # proxy_index being non-empty means it was initialized. @@ -1407,10 +1386,10 @@ class _ColumnMetrics(Generic[_COL_co]): for eps_col in col._expanded_proxy_set: pi[eps_col].add(self) - def get_expanded_proxy_set(self) -> FrozenSet[ColumnElement[Any]]: + def get_expanded_proxy_set(self): return self.column._expanded_proxy_set - def dispose(self, collection: ColumnCollection[_COLKEY, _COL_co]) -> None: + def dispose(self, collection): pi = collection._proxy_index if not pi: return @@ -1509,14 +1488,14 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): mean either two columns with the same key, in which case the column returned by key access is **arbitrary**:: - >>> x1, x2 = Column("x", Integer), Column("x", Integer) + >>> x1, x2 = Column('x', Integer), Column('x', Integer) >>> cc = ColumnCollection(columns=[(x1.name, x1), (x2.name, x2)]) >>> list(cc) [Column('x', Integer(), table=None), Column('x', Integer(), table=None)] - >>> cc["x"] is x1 + >>> cc['x'] is x1 False - >>> cc["x"] is x2 + >>> cc['x'] is x2 True Or it can also mean the same column multiple times. These cases are @@ -1543,7 +1522,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): """ - __slots__ = ("_collection", "_index", "_colset", "_proxy_index") + __slots__ = "_collection", "_index", "_colset", "_proxy_index" _collection: List[Tuple[_COLKEY, _COL_co, _ColumnMetrics[_COL_co]]] _index: Dict[Union[None, str, int], Tuple[_COLKEY, _COL_co]] @@ -1612,17 +1591,20 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): return iter([col for _, col, _ in self._collection]) @overload - def __getitem__(self, key: Union[str, int]) -> _COL_co: ... + def __getitem__(self, key: Union[str, int]) -> _COL_co: + ... @overload def __getitem__( self, key: Tuple[Union[str, int], ...] - ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: ... + ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: + ... @overload def __getitem__( self, key: slice - ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: ... + ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: + ... 
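A short usage sketch of the access styles the __getitem__ overloads above describe, exercised through the public Table.c collection (assumes stock SQLAlchemy 2.0 behavior; table and column names are illustrative): lookup by string key, by integer position, by a tuple of keys yielding a read-only sub-collection, and by slice.

from sqlalchemy import Column, Integer, MetaData, String, Table

t = Table(
    "person",
    MetaData(),
    Column("id", Integer, primary_key=True),
    Column("name", String(50)),
    Column("email", String(50)),
)

assert t.c["name"] is t.c.name      # string key
assert t.c[0] is t.c.id             # integer position
sub = t.c["name", "email"]          # tuple of keys -> read-only sub-collection
assert {c.name for c in sub} == {"name", "email"}
assert {c.name for c in t.c[1:]} == {"name", "email"}   # slice -> sub-collection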
def __getitem__( self, key: Union[str, int, slice, Tuple[Union[str, int], ...]] @@ -1662,7 +1644,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): else: return True - def compare(self, other: ColumnCollection[_COLKEY, _COL_co]) -> bool: + def compare(self, other: ColumnCollection[Any, Any]) -> bool: """Compare this :class:`_expression.ColumnCollection` to another based on the names of the keys""" @@ -1675,15 +1657,9 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): def __eq__(self, other: Any) -> bool: return self.compare(other) - @overload - def get(self, key: str, default: None = None) -> Optional[_COL_co]: ... - - @overload - def get(self, key: str, default: _COL) -> Union[_COL_co, _COL]: ... - def get( - self, key: str, default: Optional[_COL] = None - ) -> Optional[Union[_COL_co, _COL]]: + self, key: str, default: Optional[_COL_co] = None + ) -> Optional[_COL_co]: """Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object based on a string key name from this :class:`_expression.ColumnCollection`.""" @@ -1713,7 +1689,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): :class:`_sql.ColumnCollection`.""" raise NotImplementedError() - def remove(self, column: Any) -> NoReturn: + def remove(self, column: Any) -> None: raise NotImplementedError() def update(self, iter_: Any) -> NoReturn: @@ -1722,7 +1698,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): raise NotImplementedError() # https://github.com/python/mypy/issues/4266 - __hash__: Optional[int] = None # type: ignore + __hash__ = None # type: ignore def _populate_separate_keys( self, iter_: Iterable[Tuple[_COLKEY, _COL_co]] @@ -1815,7 +1791,7 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): return ReadOnlyColumnCollection(self) - def _init_proxy_index(self) -> None: + def _init_proxy_index(self): """populate the "proxy index", if empty. 
proxy index is added in 2.0 to provide more efficient operation @@ -1964,15 +1940,16 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): """ - def add( # type: ignore[override] - self, column: _NAMEDCOL, key: Optional[str] = None + def add( + self, column: ColumnElement[Any], key: Optional[str] = None ) -> None: - if key is not None and column.key != key: + named_column = cast(_NAMEDCOL, column) + if key is not None and named_column.key != key: raise exc.ArgumentError( "DedupeColumnCollection requires columns be under " "the same key as their .key" ) - key = column.key + key = named_column.key if key is None: raise exc.ArgumentError( @@ -1982,17 +1959,17 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): if key in self._index: existing = self._index[key][1] - if existing is column: + if existing is named_column: return - self.replace(column) + self.replace(named_column) # pop out memoized proxy_set as this # operation may very well be occurring # in a _make_proxy operation - util.memoized_property.reset(column, "proxy_set") + util.memoized_property.reset(named_column, "proxy_set") else: - self._append_new_column(key, column) + self._append_new_column(key, named_column) def _append_new_column(self, key: str, named_column: _NAMEDCOL) -> None: l = len(self._collection) @@ -2034,7 +2011,7 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): def extend(self, iter_: Iterable[_NAMEDCOL]) -> None: self._populate_separate_keys((col.key, col) for col in iter_) - def remove(self, column: _NAMEDCOL) -> None: # type: ignore[override] + def remove(self, column: _NAMEDCOL) -> None: if column not in self._colset: raise ValueError( "Can't remove column %r; column is not in this collection" @@ -2067,8 +2044,8 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): e.g.:: - t = Table("sometable", metadata, Column("col1", Integer)) - t.columns.replace(Column("col1", Integer, key="columnone")) + t = Table('sometable', metadata, Column('col1', Integer)) + t.columns.replace(Column('col1', Integer, key='columnone')) will remove the original 'col1' from the collection, and add the new column under the name 'columnname'. 
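A hedged expansion of the replace() docstring example just above, exercising the collection directly as that example does (in a real schema, columns are normally attached through Table.append_column()): after replace(), the original column no longer occupies its key, and the replacement is reachable under its own .key.

from sqlalchemy import Column, Integer, MetaData, Table

metadata = MetaData()
t = Table("sometable", metadata, Column("col1", Integer))

t.columns.replace(Column("col1", Integer, key="columnone"))

assert "columnone" in t.columns                      # new column is keyed by its .key
assert "col1" not in [c.key for c in t.columns]      # original 'col1' entry is gone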
@@ -2131,17 +2108,17 @@ class ReadOnlyColumnCollection( ): __slots__ = ("_parent",) - def __init__(self, collection: ColumnCollection[_COLKEY, _COL_co]): + def __init__(self, collection): object.__setattr__(self, "_parent", collection) object.__setattr__(self, "_colset", collection._colset) object.__setattr__(self, "_index", collection._index) object.__setattr__(self, "_collection", collection._collection) object.__setattr__(self, "_proxy_index", collection._proxy_index) - def __getstate__(self) -> Dict[str, _COL_co]: + def __getstate__(self): return {"_parent": self._parent} - def __setstate__(self, state: Dict[str, Any]) -> None: + def __setstate__(self, state): parent = state["_parent"] self.__init__(parent) # type: ignore @@ -2156,10 +2133,10 @@ class ReadOnlyColumnCollection( class ColumnSet(util.OrderedSet["ColumnClause[Any]"]): - def contains_column(self, col: ColumnClause[Any]) -> bool: + def contains_column(self, col): return col in self - def extend(self, cols: Iterable[Any]) -> None: + def extend(self, cols): for col in cols: self.add(col) @@ -2171,12 +2148,12 @@ class ColumnSet(util.OrderedSet["ColumnClause[Any]"]): l.append(c == local) return elements.and_(*l) - def __hash__(self) -> int: # type: ignore[override] + def __hash__(self): return hash(tuple(x for x in self)) def _entity_namespace( - entity: Union[_HasEntityNamespace, ExternallyTraversible], + entity: Union[_HasEntityNamespace, ExternallyTraversible] ) -> _EntityNamespace: """Return the nearest .entity_namespace for the given entity. diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/sql/cache_key.py b/venv/lib/python3.12/site-packages/sqlalchemy/sql/cache_key.py index cec0450..500e3e4 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/sql/cache_key.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/sql/cache_key.py @@ -1,5 +1,5 @@ # sql/cache_key.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -11,7 +11,6 @@ import enum from itertools import zip_longest import typing from typing import Any -from typing import Callable from typing import Dict from typing import Iterable from typing import Iterator @@ -37,7 +36,6 @@ from ..util.typing import Protocol if typing.TYPE_CHECKING: from .elements import BindParameter from .elements import ClauseElement - from .elements import ColumnElement from .visitors import _TraverseInternalsType from ..engine.interfaces import _CoreSingleExecuteParams @@ -45,7 +43,8 @@ if typing.TYPE_CHECKING: class _CacheKeyTraversalDispatchType(Protocol): def __call__( s, self: HasCacheKey, visitor: _CacheKeyTraversal - ) -> _CacheKeyTraversalDispatchTypeReturn: ... + ) -> CacheKey: + ... class CacheConst(enum.Enum): @@ -76,18 +75,6 @@ class CacheTraverseTarget(enum.Enum): ANON_NAME, ) = tuple(CacheTraverseTarget) -_CacheKeyTraversalDispatchTypeReturn = Sequence[ - Tuple[ - str, - Any, - Union[ - Callable[..., Tuple[Any, ...]], - CacheTraverseTarget, - InternalTraversal, - ], - ] -] - class HasCacheKey: """Mixin for objects which can produce a cache key. 
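A small demonstration sketch of the property the cache-key machinery in this module provides, using the internal _generate_cache_key() entry point (internal API, shown for illustration only): two statements that differ only in bound literal values produce an equal CacheKey, with the differing values carried separately on .bindparams.

from sqlalchemy import Column, Integer, MetaData, Table, select

t = Table("account", MetaData(), Column("id", Integer))

k1 = select(t).where(t.c.id == 5)._generate_cache_key()
k2 = select(t).where(t.c.id == 10)._generate_cache_key()

assert k1 == k2 and k1.key == k2.key                 # same SQL structure, same key
assert [p.value for p in k1.bindparams] == [5]       # literal values tracked apart
assert [p.value for p in k2.bindparams] == [10]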
@@ -303,13 +290,11 @@ class HasCacheKey: result += ( attrname, obj["compile_state_plugin"], - ( - obj["plugin_subject"]._gen_cache_key( - anon_map, bindparams - ) - if obj["plugin_subject"] - else None - ), + obj["plugin_subject"]._gen_cache_key( + anon_map, bindparams + ) + if obj["plugin_subject"] + else None, ) elif meth is InternalTraversal.dp_annotations_key: # obj is here is the _annotations dict. Table uses @@ -339,7 +324,7 @@ class HasCacheKey: ), ) else: - result += meth( # type: ignore + result += meth( attrname, obj, self, anon_map, bindparams ) return result @@ -516,7 +501,7 @@ class CacheKey(NamedTuple): e2, ) else: - stack.pop(-1) + pickup_index = stack.pop(-1) break def _diff(self, other: CacheKey) -> str: @@ -558,17 +543,18 @@ class CacheKey(NamedTuple): _anon_map = prefix_anon_map() return {b.key % _anon_map: b.effective_value for b in self.bindparams} - @util.preload_module("sqlalchemy.sql.elements") def _apply_params_to_element( - self, original_cache_key: CacheKey, target_element: ColumnElement[Any] - ) -> ColumnElement[Any]: - if target_element._is_immutable or original_cache_key is self: + self, original_cache_key: CacheKey, target_element: ClauseElement + ) -> ClauseElement: + if target_element._is_immutable: return target_element - elements = util.preloaded.sql_elements - return elements._OverrideBinds( - target_element, self.bindparams, original_cache_key.bindparams - ) + translate = { + k.key: v.value + for k, v in zip(original_cache_key.bindparams, self.bindparams) + } + + return target_element.params(translate) def _ad_hoc_cache_key_from_args( @@ -620,9 +606,9 @@ class _CacheKeyTraversal(HasTraversalDispatch): InternalTraversal.dp_memoized_select_entities ) - visit_string = visit_boolean = visit_operator = visit_plain_obj = ( - CACHE_IN_PLACE - ) + visit_string = ( + visit_boolean + ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE visit_statement_hint_list = CACHE_IN_PLACE visit_type = STATIC_CACHE_KEY visit_anon_name = ANON_NAME @@ -669,11 +655,9 @@ class _CacheKeyTraversal(HasTraversalDispatch): ) -> Tuple[Any, ...]: return ( attrname, - ( - obj._gen_cache_key(anon_map, bindparams) - if isinstance(obj, HasCacheKey) - else obj - ), + obj._gen_cache_key(anon_map, bindparams) + if isinstance(obj, HasCacheKey) + else obj, ) def visit_multi_list( @@ -687,11 +671,9 @@ class _CacheKeyTraversal(HasTraversalDispatch): return ( attrname, tuple( - ( - elem._gen_cache_key(anon_map, bindparams) - if isinstance(elem, HasCacheKey) - else elem - ) + elem._gen_cache_key(anon_map, bindparams) + if isinstance(elem, HasCacheKey) + else elem for elem in obj ), ) @@ -852,16 +834,12 @@ class _CacheKeyTraversal(HasTraversalDispatch): return tuple( ( target._gen_cache_key(anon_map, bindparams), - ( - onclause._gen_cache_key(anon_map, bindparams) - if onclause is not None - else None - ), - ( - from_._gen_cache_key(anon_map, bindparams) - if from_ is not None - else None - ), + onclause._gen_cache_key(anon_map, bindparams) + if onclause is not None + else None, + from_._gen_cache_key(anon_map, bindparams) + if from_ is not None + else None, tuple([(key, flags[key]) for key in sorted(flags)]), ) for (target, onclause, from_, flags) in obj @@ -955,11 +933,9 @@ class _CacheKeyTraversal(HasTraversalDispatch): tuple( ( key, - ( - value._gen_cache_key(anon_map, bindparams) - if isinstance(value, HasCacheKey) - else value - ), + value._gen_cache_key(anon_map, bindparams) + if isinstance(value, HasCacheKey) + else value, ) for key, value in [(key, obj[key]) for key in sorted(obj)] ), @@ 
-1005,11 +981,9 @@ class _CacheKeyTraversal(HasTraversalDispatch): attrname, tuple( ( - ( - key._gen_cache_key(anon_map, bindparams) - if hasattr(key, "__clause_element__") - else key - ), + key._gen_cache_key(anon_map, bindparams) + if hasattr(key, "__clause_element__") + else key, value._gen_cache_key(anon_map, bindparams), ) for key, value in obj @@ -1030,11 +1004,9 @@ class _CacheKeyTraversal(HasTraversalDispatch): attrname, tuple( ( - ( - k._gen_cache_key(anon_map, bindparams) - if hasattr(k, "__clause_element__") - else k - ), + k._gen_cache_key(anon_map, bindparams) + if hasattr(k, "__clause_element__") + else k, obj[k]._gen_cache_key(anon_map, bindparams), ) for k in obj diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/sql/coercions.py b/venv/lib/python3.12/site-packages/sqlalchemy/sql/coercions.py index ac0393a..c4d3407 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/sql/coercions.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/sql/coercions.py @@ -1,5 +1,5 @@ # sql/coercions.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -29,6 +29,7 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +from . import operators from . import roles from . import visitors from ._typing import is_from_clause @@ -57,9 +58,9 @@ if typing.TYPE_CHECKING: from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement + from .elements import DQLDMLClauseElement from .elements import NamedColumn from .elements import SQLCoreOperations - from .elements import TextClause from .schema import Column from .selectable import _ColumnsClauseElement from .selectable import _JoinTargetProtocol @@ -75,7 +76,7 @@ _StringOnlyR = TypeVar("_StringOnlyR", bound=roles.StringRole) _T = TypeVar("_T", bound=Any) -def _is_literal(element: Any) -> bool: +def _is_literal(element): """Return whether or not the element is a "literal" in the context of a SQL expression construct. @@ -164,7 +165,8 @@ def expect( role: Type[roles.TruncatedLabelRole], element: Any, **kw: Any, -) -> str: ... +) -> str: + ... @overload @@ -174,7 +176,8 @@ def expect( *, as_key: Literal[True] = ..., **kw: Any, -) -> str: ... +) -> str: + ... @overload @@ -182,7 +185,8 @@ def expect( role: Type[roles.LiteralValueRole], element: Any, **kw: Any, -) -> BindParameter[Any]: ... +) -> BindParameter[Any]: + ... @overload @@ -190,7 +194,8 @@ def expect( role: Type[roles.DDLReferredColumnRole], element: Any, **kw: Any, -) -> Union[Column[Any], str]: ... +) -> Column[Any]: + ... @overload @@ -198,7 +203,8 @@ def expect( role: Type[roles.DDLConstraintColumnRole], element: Any, **kw: Any, -) -> Union[Column[Any], str]: ... +) -> Union[Column[Any], str]: + ... @overload @@ -206,7 +212,8 @@ def expect( role: Type[roles.StatementOptionRole], element: Any, **kw: Any, -) -> Union[ColumnElement[Any], TextClause]: ... +) -> DQLDMLClauseElement: + ... @overload @@ -214,7 +221,8 @@ def expect( role: Type[roles.LabeledColumnExprRole[Any]], element: _ColumnExpressionArgument[_T], **kw: Any, -) -> NamedColumn[_T]: ... +) -> NamedColumn[_T]: + ... @overload @@ -226,7 +234,8 @@ def expect( ], element: _ColumnExpressionArgument[_T], **kw: Any, -) -> ColumnElement[_T]: ... +) -> ColumnElement[_T]: + ... @overload @@ -240,7 +249,8 @@ def expect( ], element: Any, **kw: Any, -) -> ColumnElement[Any]: ... 
+) -> ColumnElement[Any]: + ... @overload @@ -248,7 +258,8 @@ def expect( role: Type[roles.DMLTableRole], element: _DMLTableArgument, **kw: Any, -) -> _DMLTableElement: ... +) -> _DMLTableElement: + ... @overload @@ -256,7 +267,8 @@ def expect( role: Type[roles.HasCTERole], element: HasCTE, **kw: Any, -) -> HasCTE: ... +) -> HasCTE: + ... @overload @@ -264,7 +276,8 @@ def expect( role: Type[roles.SelectStatementRole], element: SelectBase, **kw: Any, -) -> SelectBase: ... +) -> SelectBase: + ... @overload @@ -272,7 +285,8 @@ def expect( role: Type[roles.FromClauseRole], element: _FromClauseArgument, **kw: Any, -) -> FromClause: ... +) -> FromClause: + ... @overload @@ -282,7 +296,8 @@ def expect( *, explicit_subquery: Literal[True] = ..., **kw: Any, -) -> Subquery: ... +) -> Subquery: + ... @overload @@ -290,7 +305,8 @@ def expect( role: Type[roles.ColumnsClauseRole], element: _ColumnsClauseArgument[Any], **kw: Any, -) -> _ColumnsClauseElement: ... +) -> _ColumnsClauseElement: + ... @overload @@ -298,7 +314,8 @@ def expect( role: Type[roles.JoinTargetRole], element: _JoinTargetProtocol, **kw: Any, -) -> _JoinTargetProtocol: ... +) -> _JoinTargetProtocol: + ... # catchall for not-yet-implemented overloads @@ -307,7 +324,8 @@ def expect( role: Type[_SR], element: Any, **kw: Any, -) -> Any: ... +) -> Any: + ... def expect( @@ -492,7 +510,6 @@ class RoleImpl: element: Any, argname: Optional[str] = None, resolved: Optional[Any] = None, - *, advice: Optional[str] = None, code: Optional[str] = None, err: Optional[Exception] = None, @@ -595,7 +612,7 @@ def _no_text_coercion( class _NoTextCoercion(RoleImpl): __slots__ = () - def _literal_coercion(self, element, *, argname=None, **kw): + def _literal_coercion(self, element, argname=None, **kw): if isinstance(element, str) and issubclass( elements.TextClause, self._role_class ): @@ -613,7 +630,7 @@ class _CoerceLiterals(RoleImpl): def _text_coercion(self, element, argname=None): return _no_text_coercion(element, argname) - def _literal_coercion(self, element, *, argname=None, **kw): + def _literal_coercion(self, element, argname=None, **kw): if isinstance(element, str): if self._coerce_star and element == "*": return elements.ColumnClause("*", is_literal=True) @@ -641,8 +658,7 @@ class LiteralValueImpl(RoleImpl): self, element, resolved, - argname=None, - *, + argname, type_=None, literal_execute=False, **kw, @@ -660,7 +676,7 @@ class LiteralValueImpl(RoleImpl): literal_execute=literal_execute, ) - def _literal_coercion(self, element, **kw): + def _literal_coercion(self, element, argname=None, type_=None, **kw): return element @@ -672,7 +688,6 @@ class _SelectIsNotFrom(RoleImpl): element: Any, argname: Optional[str] = None, resolved: Optional[Any] = None, - *, advice: Optional[str] = None, code: Optional[str] = None, err: Optional[Exception] = None, @@ -747,7 +762,7 @@ class ExpressionElementImpl(_ColumnCoercions, RoleImpl): __slots__ = () def _literal_coercion( - self, element, *, name=None, type_=None, is_crud=False, **kw + self, element, name=None, type_=None, argname=None, is_crud=False, **kw ): if ( element is None @@ -789,22 +804,15 @@ class ExpressionElementImpl(_ColumnCoercions, RoleImpl): class BinaryElementImpl(ExpressionElementImpl, RoleImpl): __slots__ = () - def _literal_coercion( # type: ignore[override] - self, - element, - *, - expr, - operator, - bindparam_type=None, - argname=None, - **kw, + def _literal_coercion( + self, element, expr, operator, bindparam_type=None, argname=None, **kw ): try: return expr._bind_param(operator, 
element, type_=bindparam_type) except exc.ArgumentError as err: self._raise_for_expected(element, err=err) - def _post_coercion(self, resolved, *, expr, bindparam_type=None, **kw): + def _post_coercion(self, resolved, expr, bindparam_type=None, **kw): if resolved.type._isnull and not expr.type._isnull: resolved = resolved._with_binary_element_type( bindparam_type if bindparam_type is not None else expr.type @@ -842,32 +850,31 @@ class InElementImpl(RoleImpl): % (elem.__class__.__name__) ) - @util.preload_module("sqlalchemy.sql.elements") - def _literal_coercion(self, element, *, expr, operator, **kw): # type: ignore[override] # noqa: E501 - if util.is_non_string_iterable(element): + def _literal_coercion(self, element, expr, operator, **kw): + if isinstance(element, collections_abc.Iterable) and not isinstance( + element, str + ): non_literal_expressions: Dict[ - Optional[_ColumnExpressionArgument[Any]], - _ColumnExpressionArgument[Any], + Optional[operators.ColumnOperators], + operators.ColumnOperators, ] = {} element = list(element) for o in element: if not _is_literal(o): - if not isinstance( - o, util.preloaded.sql_elements.ColumnElement - ) and not hasattr(o, "__clause_element__"): + if not isinstance(o, operators.ColumnOperators): self._raise_for_expected(element, **kw) else: non_literal_expressions[o] = o + elif o is None: + non_literal_expressions[o] = elements.Null() if non_literal_expressions: return elements.ClauseList( *[ - ( - non_literal_expressions[o] - if o in non_literal_expressions - else expr._bind_param(operator, o) - ) + non_literal_expressions[o] + if o in non_literal_expressions + else expr._bind_param(operator, o) for o in element ] ) @@ -877,7 +884,7 @@ class InElementImpl(RoleImpl): else: self._raise_for_expected(element, **kw) - def _post_coercion(self, element, *, expr, operator, **kw): + def _post_coercion(self, element, expr, operator, **kw): if element._is_select_base: # for IN, we are doing scalar_subquery() coercion without # a warning @@ -903,10 +910,12 @@ class OnClauseImpl(_ColumnCoercions, RoleImpl): _coerce_consts = True - def _literal_coercion(self, element, **kw): + def _literal_coercion( + self, element, name=None, type_=None, argname=None, is_crud=False, **kw + ): self._raise_for_expected(element) - def _post_coercion(self, resolved, *, original_element=None, **kw): + def _post_coercion(self, resolved, original_element=None, **kw): # this is a hack right now as we want to use coercion on an # ORM InstrumentedAttribute, but we want to return the object # itself if it is one, not its clause element. @@ -991,7 +1000,7 @@ class GroupByImpl(ByOfImpl, RoleImpl): class DMLColumnImpl(_ReturnsStringKey, RoleImpl): __slots__ = () - def _post_coercion(self, element, *, as_key=False, **kw): + def _post_coercion(self, element, as_key=False, **kw): if as_key: return element.key else: @@ -1001,7 +1010,7 @@ class DMLColumnImpl(_ReturnsStringKey, RoleImpl): class ConstExprImpl(RoleImpl): __slots__ = () - def _literal_coercion(self, element, *, argname=None, **kw): + def _literal_coercion(self, element, argname=None, **kw): if element is None: return elements.Null() elif element is False: @@ -1027,7 +1036,7 @@ class TruncatedLabelImpl(_StringOnly, RoleImpl): else: self._raise_for_expected(element, argname, resolved) - def _literal_coercion(self, element, **kw): + def _literal_coercion(self, element, argname=None, **kw): """coerce the given value to :class:`._truncated_label`. 
Existing :class:`._truncated_label` and @@ -1077,9 +1086,7 @@ class LimitOffsetImpl(RoleImpl): else: self._raise_for_expected(element, argname, resolved) - def _literal_coercion( # type: ignore[override] - self, element, *, name, type_, **kw - ): + def _literal_coercion(self, element, name, type_, **kw): if element is None: return None else: @@ -1121,7 +1128,7 @@ class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl): _guess_straight_column = re.compile(r"^\w\S*$", re.I) def _raise_for_expected( - self, element, argname=None, resolved=None, *, advice=None, **kw + self, element, argname=None, resolved=None, advice=None, **kw ): if not advice and isinstance(element, list): advice = ( @@ -1145,9 +1152,9 @@ class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl): % { "column": util.ellipses_string(element), "argname": "for argument %s" % (argname,) if argname else "", - "literal_column": ( - "literal_column" if guess_is_literal else "column" - ), + "literal_column": "literal_column" + if guess_is_literal + else "column", } ) @@ -1159,9 +1166,7 @@ class ReturnsRowsImpl(RoleImpl): class StatementImpl(_CoerceLiterals, RoleImpl): __slots__ = () - def _post_coercion( - self, resolved, *, original_element, argname=None, **kw - ): + def _post_coercion(self, resolved, original_element, argname=None, **kw): if resolved is not original_element and not isinstance( original_element, str ): @@ -1227,7 +1232,7 @@ class JoinTargetImpl(RoleImpl): _skip_clauseelement_for_target_match = True - def _literal_coercion(self, element, *, argname=None, **kw): + def _literal_coercion(self, element, argname=None, **kw): self._raise_for_expected(element, argname) def _implicit_coercions( @@ -1235,7 +1240,6 @@ class JoinTargetImpl(RoleImpl): element: Any, resolved: Any, argname: Optional[str] = None, - *, legacy: bool = False, **kw: Any, ) -> Any: @@ -1269,7 +1273,6 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): element: Any, resolved: Any, argname: Optional[str] = None, - *, explicit_subquery: bool = False, allow_select: bool = True, **kw: Any, @@ -1291,7 +1294,7 @@ class FromClauseImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): else: self._raise_for_expected(element, argname, resolved) - def _post_coercion(self, element, *, deannotate=False, **kw): + def _post_coercion(self, element, deannotate=False, **kw): if deannotate: return element._deannotate() else: @@ -1306,7 +1309,7 @@ class StrictFromClauseImpl(FromClauseImpl): element: Any, resolved: Any, argname: Optional[str] = None, - *, + explicit_subquery: bool = False, allow_select: bool = False, **kw: Any, ) -> Any: @@ -1326,7 +1329,7 @@ class StrictFromClauseImpl(FromClauseImpl): class AnonymizedFromClauseImpl(StrictFromClauseImpl): __slots__ = () - def _post_coercion(self, element, *, flat=False, name=None, **kw): + def _post_coercion(self, element, flat=False, name=None, **kw): assert name is None return element._anonymous_fromclause(flat=flat) diff --git a/venv/lib/python3.12/site-packages/sqlalchemy/sql/compiler.py b/venv/lib/python3.12/site-packages/sqlalchemy/sql/compiler.py index 3f20c93..cb6899c 100644 --- a/venv/lib/python3.12/site-packages/sqlalchemy/sql/compiler.py +++ b/venv/lib/python3.12/site-packages/sqlalchemy/sql/compiler.py @@ -1,5 +1,5 @@ # sql/compiler.py -# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under @@ -29,7 +29,6 @@ import collections import 
collections.abc as collections_abc import contextlib from enum import IntEnum -import functools import itertools import operator import re @@ -74,49 +73,38 @@ from .base import _de_clone from .base import _from_objects from .base import _NONE_NAME from .base import _SentinelDefaultCharacterization +from .base import Executable from .base import NO_ARG +from .elements import ClauseElement from .elements import quoted_name +from .schema import Column from .sqltypes import TupleType +from .type_api import TypeEngine from .visitors import prefix_anon_map +from .visitors import Visitable from .. import exc from .. import util from ..util import FastIntFlag from ..util.typing import Literal from ..util.typing import Protocol -from ..util.typing import Self from ..util.typing import TypedDict if typing.TYPE_CHECKING: from .annotation import _AnnotationDict from .base import _AmbiguousTableNameMap from .base import CompileState - from .base import Executable from .cache_key import CacheKey from .ddl import ExecutableDDLElement from .dml import Insert - from .dml import Update from .dml import UpdateBase - from .dml import UpdateDMLState from .dml import ValuesBase from .elements import _truncated_label - from .elements import BinaryExpression from .elements import BindParameter - from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement - from .elements import False_ from .elements import Label - from .elements import Null - from .elements import True_ from .functions import Function - from .schema import Column - from .schema import Constraint - from .schema import ForeignKeyConstraint - from .schema import Index - from .schema import PrimaryKeyConstraint from .schema import Table - from .schema import UniqueConstraint - from .selectable import _ColumnsClauseElement from .selectable import AliasedReturnsRows from .selectable import CompoundSelectState from .selectable import CTE @@ -126,10 +114,7 @@ if typing.TYPE_CHECKING: from .selectable import Select from .selectable import SelectState from .type_api import _BindProcessorType - from .type_api import TypeDecorator - from .type_api import TypeEngine - from .type_api import UserDefinedType - from .visitors import Visitable + from .type_api import _SentinelProcessorType from ..engine.cursor import CursorResultMetaData from ..engine.interfaces import _CoreSingleExecuteParams from ..engine.interfaces import _DBAPIAnyExecuteParams @@ -141,7 +126,6 @@ if typing.TYPE_CHECKING: from ..engine.interfaces import Dialect from ..engine.interfaces import SchemaTranslateMapType - _FromHintsType = Dict["FromClause", str] RESERVED_WORDS = { @@ -398,7 +382,8 @@ class _ResultMapAppender(Protocol): name: str, objects: Sequence[Any], type_: TypeEngine[Any], - ) -> None: ... + ) -> None: + ... # integer indexes into ResultColumnsEntry used by cursor.py. @@ -561,8 +546,8 @@ class _InsertManyValues(NamedTuple): """ - sentinel_param_keys: Optional[Sequence[str]] = None - """parameter str keys in each param dictionary / tuple + sentinel_param_keys: Optional[Sequence[Union[str, int]]] = None + """parameter str keys / int indexes in each param dictionary / tuple that would link to the client side "sentinel" values for that row, which we can use to match up parameter sets to result rows. @@ -572,10 +557,6 @@ class _InsertManyValues(NamedTuple): .. versionadded:: 2.0.10 - .. 
versionchanged:: 2.0.29 - the sequence is now string dictionary keys - only, used against the "compiled parameteters" collection before - the parameters were converted by bound parameter processors - """ implicit_sentinel: bool = False @@ -620,8 +601,7 @@ class _InsertManyValuesBatch(NamedTuple): replaced_parameters: _DBAPIAnyExecuteParams processed_setinputsizes: Optional[_GenericSetInputSizesType] batch: Sequence[_DBAPISingleExecuteParams] - sentinel_values: Sequence[Tuple[Any, ...]] - current_batch_size: int + batch_size: int batchnum: int total_batches: int rows_sorted: bool @@ -757,6 +737,7 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): class Compiled: + """Represent a compiled SQL or DDL expression. The ``__str__`` method of the ``Compiled`` object should produce @@ -886,7 +867,6 @@ class Compiled: self.string = self.process(self.statement, **compile_kwargs) if render_schema_translate: - assert schema_translate_map is not None self.string = self.preparer._render_schema_translates( self.string, schema_translate_map ) @@ -919,7 +899,7 @@ class Compiled: raise exc.UnsupportedCompilationError(self, type(element)) from err @property - def sql_compiler(self) -> SQLCompiler: + def sql_compiler(self): """Return a Compiled that is capable of processing SQL expressions. If this compiler is one, it would likely just return 'self'. @@ -987,6 +967,7 @@ class TypeCompiler(util.EnsureKWArg): class _CompileLabel( roles.BinaryElementRole[Any], elements.CompilerColumnElement ): + """lightweight label object which acts as an expression.Label.""" __visit_name__ = "label" @@ -1056,19 +1037,19 @@ class SQLCompiler(Compiled): extract_map = EXTRACT_MAP - bindname_escape_characters: ClassVar[Mapping[str, str]] = ( - util.immutabledict( - { - "%": "P", - "(": "A", - ")": "Z", - ":": "C", - ".": "_", - "[": "_", - "]": "_", - " ": "_", - } - ) + bindname_escape_characters: ClassVar[ + Mapping[str, str] + ] = util.immutabledict( + { + "%": "P", + "(": "A", + ")": "Z", + ":": "C", + ".": "_", + "[": "_", + "]": "_", + " ": "_", + } ) """A mapping (e.g. dict or similar) containing a lookup of characters keyed to replacement characters which will be applied to all @@ -1362,7 +1343,6 @@ class SQLCompiler(Compiled): column_keys: Optional[Sequence[str]] = None, for_executemany: bool = False, linting: Linting = NO_LINTING, - _supporting_against: Optional[SQLCompiler] = None, **kwargs: Any, ): """Construct a new :class:`.SQLCompiler` object. 
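A brief usage sketch of the Compiled / SQLCompiler machinery from the caller's side (public API; the generated parameter name name_1 is dialect-assigned and shown only as an example): ClauseElement.compile() constructs a dialect-specific compiler, and the resulting object exposes the rendered string and the collected parameters.

from sqlalchemy import Column, Integer, MetaData, String, Table, select
from sqlalchemy.dialects import postgresql

t = Table("account", MetaData(), Column("id", Integer), Column("name", String(50)))

stmt = select(t).where(t.c.name == "spongebob")
compiled = stmt.compile(dialect=postgresql.dialect())

print(compiled.string)   # e.g. SELECT account.id, account.name FROM account WHERE account.name = %(name_1)s
print(compiled.params)   # e.g. {'name_1': 'spongebob'}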
@@ -1465,24 +1445,6 @@ class SQLCompiler(Compiled): self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle] - if _supporting_against: - self.__dict__.update( - { - k: v - for k, v in _supporting_against.__dict__.items() - if k - not in { - "state", - "dialect", - "preparer", - "positional", - "_numeric_binds", - "compilation_bindtemplate", - "bindtemplate", - } - } - ) - if self.state is CompilerState.STRING_APPLIED: if self.positional: if self._numeric_binds: @@ -1697,9 +1659,19 @@ class SQLCompiler(Compiled): for v in self._insertmanyvalues.insert_crud_params ] + sentinel_param_int_idxs = ( + [ + self.positiontup.index(cast(str, _param_key)) + for _param_key in self._insertmanyvalues.sentinel_param_keys # noqa: E501 + ] + if self._insertmanyvalues.sentinel_param_keys is not None + else None + ) + self._insertmanyvalues = self._insertmanyvalues._replace( single_values_expr=single_values_expr, insert_crud_params=insert_crud_params, + sentinel_param_keys=sentinel_param_int_idxs, ) def _process_numeric(self): @@ -1768,11 +1740,21 @@ class SQLCompiler(Compiled): for v in self._insertmanyvalues.insert_crud_params ] + sentinel_param_int_idxs = ( + [ + self.positiontup.index(cast(str, _param_key)) + for _param_key in self._insertmanyvalues.sentinel_param_keys # noqa: E501 + ] + if self._insertmanyvalues.sentinel_param_keys is not None + else None + ) + self._insertmanyvalues = self._insertmanyvalues._replace( # This has the numbers (:1, :2) single_values_expr=single_values_expr, # The single binds are instead %s so they can be formatted insert_crud_params=insert_crud_params, + sentinel_param_keys=sentinel_param_int_idxs, ) @util.memoized_property @@ -1788,15 +1770,11 @@ class SQLCompiler(Compiled): for key, value in ( ( self.bind_names[bindparam], - ( - bindparam.type._cached_bind_processor(self.dialect) - if not bindparam.type._is_tuple_type - else tuple( - elem_type._cached_bind_processor(self.dialect) - for elem_type in cast( - TupleType, bindparam.type - ).types - ) + bindparam.type._cached_bind_processor(self.dialect) + if not bindparam.type._is_tuple_type + else tuple( + elem_type._cached_bind_processor(self.dialect) + for elem_type in cast(TupleType, bindparam.type).types ), ) for bindparam in self.bind_names @@ -1804,11 +1782,28 @@ class SQLCompiler(Compiled): if value is not None } + @util.memoized_property + def _imv_sentinel_value_resolvers( + self, + ) -> Optional[Sequence[Optional[_SentinelProcessorType[Any]]]]: + imv = self._insertmanyvalues + if imv is None or imv.sentinel_columns is None: + return None + + sentinel_value_resolvers = [ + _scol.type._cached_sentinel_value_processor(self.dialect) + for _scol in imv.sentinel_columns + ] + if util.NONE_SET.issuperset(sentinel_value_resolvers): + return None + else: + return sentinel_value_resolvers + def is_subquery(self): return len(self.stack) > 1 @property - def sql_compiler(self) -> Self: + def sql_compiler(self): return self def construct_expanded_state( @@ -2085,11 +2080,11 @@ class SQLCompiler(Compiled): if parameter in self.literal_execute_params: if escaped_name not in replacement_expressions: - replacement_expressions[escaped_name] = ( - self.render_literal_bindparam( - parameter, - render_literal_value=parameters.pop(escaped_name), - ) + replacement_expressions[ + escaped_name + ] = self.render_literal_bindparam( + parameter, + render_literal_value=parameters.pop(escaped_name), ) continue @@ -2298,14 +2293,12 @@ class SQLCompiler(Compiled): else: return row_fn( ( - ( - autoinc_getter(lastrowid, parameters) - if 
autoinc_getter is not None - else lastrowid - ) - if col is autoinc_col - else getter(parameters) + autoinc_getter(lastrowid, parameters) + if autoinc_getter is not None + else lastrowid ) + if col is autoinc_col + else getter(parameters) for getter, col in getters ) @@ -2314,7 +2307,10 @@ class SQLCompiler(Compiled): @util.memoized_property @util.preload_module("sqlalchemy.engine.result") def _inserted_primary_key_from_returning_getter(self): - result = util.preloaded.engine_result + if typing.TYPE_CHECKING: + from ..engine import result + else: + result = util.preloaded.engine_result assert self.compile_state is not None statement = self.compile_state.statement @@ -2332,15 +2328,11 @@ class SQLCompiler(Compiled): getters = cast( "List[Tuple[Callable[[Any], Any], bool]]", [ - ( - (operator.itemgetter(ret[col]), True) - if col in ret - else ( - operator.methodcaller( - "get", param_key_getter(col), None - ), - False, - ) + (operator.itemgetter(ret[col]), True) + if col in ret + else ( + operator.methodcaller("get", param_key_getter(col), None), + False, ) for col in table.primary_key ], @@ -2356,80 +2348,15 @@ class SQLCompiler(Compiled): return get - def default_from(self) -> str: + def default_from(self): """Called when a SELECT statement has no froms, and no FROM clause is to be appended. - Gives Oracle Database a chance to tack on a ``FROM DUAL`` to the string - output. + Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output. """ return "" - def visit_override_binds(self, override_binds, **kw): - """SQL compile the nested element of an _OverrideBinds with - bindparams swapped out. - - The _OverrideBinds is not normally expected to be compiled; it - is meant to be used when an already cached statement is to be used, - the compilation was already performed, and only the bound params should - be swapped in at execution time. - - However, there are test cases that exericise this object, and - additionally the ORM subquery loader is known to feed in expressions - which include this construct into new queries (discovered in #11173), - so it has to do the right thing at compile time as well. - - """ - - # get SQL text first - sqltext = override_binds.element._compiler_dispatch(self, **kw) - - # for a test compile that is not for caching, change binds after the - # fact. note that we don't try to - # swap the bindparam as we compile, because our element may be - # elsewhere in the statement already (e.g. a subquery or perhaps a - # CTE) and was already visited / compiled. See - # test_relationship_criteria.py -> - # test_selectinload_local_criteria_subquery - for k in override_binds.translate: - if k not in self.binds: - continue - bp = self.binds[k] - - # so this would work, just change the value of bp in place. - # but we dont want to mutate things outside. 
- # bp.value = override_binds.translate[bp.key] - # continue - - # instead, need to replace bp with new_bp or otherwise accommodate - # in all internal collections - new_bp = bp._with_value( - override_binds.translate[bp.key], - maintain_key=True, - required=False, - ) - - name = self.bind_names[bp] - self.binds[k] = self.binds[name] = new_bp - self.bind_names[new_bp] = name - self.bind_names.pop(bp, None) - - if bp in self.post_compile_params: - self.post_compile_params |= {new_bp} - if bp in self.literal_execute_params: - self.literal_execute_params |= {new_bp} - - ckbm_tuple = self._cache_key_bind_match - if ckbm_tuple: - ckbm, cksm = ckbm_tuple - for bp in bp._cloned_set: - if bp.key in cksm: - cb = cksm[bp.key] - ckbm[cb].append(new_bp) - - return sqltext - def visit_grouping(self, grouping, asfrom=False, **kwargs): return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")" @@ -2474,9 +2401,9 @@ class SQLCompiler(Compiled): resolve_dict[order_by_elem.name] ) ): - kwargs["render_label_as_label"] = ( - element.element._order_by_label_element - ) + kwargs[ + "render_label_as_label" + ] = element.element._order_by_label_element return self.process( element.element, within_columns_clause=within_columns_clause, @@ -2579,7 +2506,7 @@ class SQLCompiler(Compiled): def _fallback_column_name(self, column): raise exc.CompileError( - "Cannot compile Column object until its 'name' is assigned." + "Cannot compile Column object until " "its 'name' is assigned." ) def visit_lambda_element(self, element, **kw): @@ -2722,9 +2649,9 @@ class SQLCompiler(Compiled): ) if populate_result_map: - self._ordered_columns = self._textual_ordered_columns = ( - taf.positional - ) + self._ordered_columns = ( + self._textual_ordered_columns + ) = taf.positional # enable looser result column matching when the SQL text links to # Column objects by name only @@ -2748,16 +2675,16 @@ class SQLCompiler(Compiled): return text - def visit_null(self, expr: Null, **kw: Any) -> str: + def visit_null(self, expr, **kw): return "NULL" - def visit_true(self, expr: True_, **kw: Any) -> str: + def visit_true(self, expr, **kw): if self.dialect.supports_native_boolean: return "true" else: return "1" - def visit_false(self, expr: False_, **kw: Any) -> str: + def visit_false(self, expr, **kw): if self.dialect.supports_native_boolean: return "false" else: @@ -2851,60 +2778,36 @@ class SQLCompiler(Compiled): def _format_frame_clause(self, range_, **kw): return "%s AND %s" % ( - ( - "UNBOUNDED PRECEDING" - if range_[0] is elements.RANGE_UNBOUNDED - else ( - "CURRENT ROW" - if range_[0] is elements.RANGE_CURRENT - else ( - "%s PRECEDING" - % ( - self.process( - elements.literal(abs(range_[0])), **kw - ), - ) - if range_[0] < 0 - else "%s FOLLOWING" - % (self.process(elements.literal(range_[0]), **kw),) - ) - ) - ), - ( - "UNBOUNDED FOLLOWING" - if range_[1] is elements.RANGE_UNBOUNDED - else ( - "CURRENT ROW" - if range_[1] is elements.RANGE_CURRENT - else ( - "%s PRECEDING" - % ( - self.process( - elements.literal(abs(range_[1])), **kw - ), - ) - if range_[1] < 0 - else "%s FOLLOWING" - % (self.process(elements.literal(range_[1]), **kw),) - ) - ) - ), + "UNBOUNDED PRECEDING" + if range_[0] is elements.RANGE_UNBOUNDED + else "CURRENT ROW" + if range_[0] is elements.RANGE_CURRENT + else "%s PRECEDING" + % (self.process(elements.literal(abs(range_[0])), **kw),) + if range_[0] < 0 + else "%s FOLLOWING" + % (self.process(elements.literal(range_[0]), **kw),), + "UNBOUNDED FOLLOWING" + if range_[1] is elements.RANGE_UNBOUNDED + else 
"CURRENT ROW" + if range_[1] is elements.RANGE_CURRENT + else "%s PRECEDING" + % (self.process(elements.literal(abs(range_[1])), **kw),) + if range_[1] < 0 + else "%s FOLLOWING" + % (self.process(elements.literal(range_[1]), **kw),), ) def visit_over(self, over, **kwargs): text = over.element._compiler_dispatch(self, **kwargs) - if over.range_ is not None: + if over.range_: range_ = "RANGE BETWEEN %s" % self._format_frame_clause( over.range_, **kwargs ) - elif over.rows is not None: + elif over.rows: range_ = "ROWS BETWEEN %s" % self._format_frame_clause( over.rows, **kwargs ) - elif over.groups is not None: - range_ = "GROUPS BETWEEN %s" % self._format_frame_clause( - over.groups, **kwargs - ) else: range_ = None @@ -2955,7 +2858,7 @@ class SQLCompiler(Compiled): **kwargs: Any, ) -> str: if add_to_result_map is not None: - add_to_result_map(func.name, func.name, (func.name,), func.type) + add_to_result_map(func.name, func.name, (), func.type) disp = getattr(self, "visit_%s_func" % func.name.lower(), None) @@ -3003,7 +2906,7 @@ class SQLCompiler(Compiled): % self.dialect.name ) - def function_argspec(self, func: Function[Any], **kwargs: Any) -> str: + def function_argspec(self, func, **kwargs): return func.clause_expr._compiler_dispatch(self, **kwargs) def visit_compound_select( @@ -3133,12 +3036,9 @@ class SQLCompiler(Compiled): + self.process( elements.Cast( binary.right, - ( - binary.right.type - if binary.right.type._type_affinity - is sqltypes.Numeric - else sqltypes.Numeric() - ), + binary.right.type + if binary.right.type._type_affinity is sqltypes.Numeric + else sqltypes.Numeric(), ), **kw, ) @@ -3467,12 +3367,8 @@ class SQLCompiler(Compiled): ) def _generate_generic_binary( - self, - binary: BinaryExpression[Any], - opstring: str, - eager_grouping: bool = False, - **kw: Any, - ) -> str: + self, binary, opstring, eager_grouping=False, **kw + ): _in_operator_expression = kw.get("_in_operator_expression", False) kw["_in_operator_expression"] = True @@ -3641,25 +3537,19 @@ class SQLCompiler(Compiled): **kw, ) - def visit_regexp_match_op_binary( - self, binary: BinaryExpression[Any], operator: Any, **kw: Any - ) -> str: + def visit_regexp_match_op_binary(self, binary, operator, **kw): raise exc.CompileError( "%s dialect does not support regular expressions" % self.dialect.name ) - def visit_not_regexp_match_op_binary( - self, binary: BinaryExpression[Any], operator: Any, **kw: Any - ) -> str: + def visit_not_regexp_match_op_binary(self, binary, operator, **kw): raise exc.CompileError( "%s dialect does not support regular expressions" % self.dialect.name ) - def visit_regexp_replace_op_binary( - self, binary: BinaryExpression[Any], operator: Any, **kw: Any - ) -> str: + def visit_regexp_replace_op_binary(self, binary, operator, **kw): raise exc.CompileError( "%s dialect does not support regular expression replacements" % self.dialect.name @@ -3675,7 +3565,6 @@ class SQLCompiler(Compiled): render_postcompile=False, **kwargs, ): - if not skip_bind_expression: impl = bindparam.type.dialect_impl(self.dialect) if impl._has_bind_expression: @@ -3866,9 +3755,7 @@ class SQLCompiler(Compiled): else: return self.render_literal_value(value, bindparam.type) - def render_literal_value( - self, value: Any, type_: sqltypes.TypeEngine[Any] - ) -> str: + def render_literal_value(self, value, type_): """Render the value of a bind parameter as a quoted literal. 
This is used for statement sections that do not accept bind parameters @@ -4104,28 +3991,15 @@ class SQLCompiler(Compiled): del self.level_name_by_cte[existing_cte_reference_cte] else: + # if the two CTEs are deep-copy identical, consider them + # the same, **if** they are clones, that is, they came from + # the ORM or other visit method if ( - # if the two CTEs have the same hash, which we expect - # here means that one/both is an annotated of the other - (hash(cte) == hash(existing_cte)) - # or... - or ( - ( - # if they are clones, i.e. they came from the ORM - # or some other visit method - cte._is_clone_of is not None - or existing_cte._is_clone_of is not None - ) - # and are deep-copy identical - and cte.compare(existing_cte) - ) - ): - # then consider these two CTEs the same + cte._is_clone_of is not None + or existing_cte._is_clone_of is not None + ) and cte.compare(existing_cte): is_new_cte = False else: - # otherwise these are two CTEs that either will render - # differently, or were indicated separately by the user, - # with the same name raise exc.CompileError( "Multiple, unrelated CTEs found with " "the same name: %r" % cte_name @@ -4158,7 +4032,7 @@ class SQLCompiler(Compiled): if cte.recursive: self.ctes_recursive = True text = self.preparer.format_alias(cte, cte_name) - if cte.recursive or cte.element.name_cte_columns: + if cte.recursive: col_source = cte.element # TODO: can we get at the .columns_plus_names collection @@ -4227,7 +4101,7 @@ class SQLCompiler(Compiled): if self.preparer._requires_quotes(cte_name): cte_name = self.preparer.quote(cte_name) text += self.get_render_as_alias_suffix(cte_name) - return text # type: ignore[no-any-return] + return text else: return self.preparer.format_alias(cte, cte_name) @@ -4289,7 +4163,7 @@ class SQLCompiler(Compiled): inner = "(%s)" % (inner,) return inner else: - kwargs["enclosing_alias"] = alias + enclosing_alias = kwargs["enclosing_alias"] = alias if asfrom or ashint: if isinstance(alias.name, elements._truncated_label): @@ -4319,14 +4193,12 @@ class SQLCompiler(Compiled): "%s%s" % ( self.preparer.quote(col.name), - ( - " %s" - % self.dialect.type_compiler_instance.process( - col.type, **kwargs - ) - if alias._render_derived_w_types - else "" - ), + " %s" + % self.dialect.type_compiler_instance.process( + col.type, **kwargs + ) + if alias._render_derived_w_types + else "", ) for col in alias.c ) @@ -4379,13 +4251,7 @@ class SQLCompiler(Compiled): ) return f"VALUES {tuples}" - def visit_values( - self, element, asfrom=False, from_linter=None, visiting_cte=None, **kw - ): - - if element._independent_ctes: - self._dispatch_independent_ctes(element, kw) - + def visit_values(self, element, asfrom=False, from_linter=None, **kw): v = self._render_values(element, **kw) if element._unnamed: @@ -4406,12 +4272,7 @@ class SQLCompiler(Compiled): name if name is not None else "(unnamed VALUES element)" ) - if visiting_cte is not None and visiting_cte.element is element: - if element._is_lateral: - raise exc.CompileError( - "Can't use a LATERAL VALUES expression inside of a CTE" - ) - elif name: + if name: kw["include_table"] = False v = "%s(%s)%s (%s)" % ( lateral, @@ -4441,11 +4302,6 @@ class SQLCompiler(Compiled): objects: Tuple[Any, ...], type_: TypeEngine[Any], ) -> None: - - # note objects must be non-empty for cursor.py to handle the - # collection properly - assert objects - if keyname is None or keyname == "*": self._ordered_columns = False self._ad_hoc_textual = True @@ -4519,7 +4375,7 @@ class SQLCompiler(Compiled): 
_add_to_result_map = add_to_result_map def add_to_result_map(keyname, name, objects, type_): - _add_to_result_map(keyname, name, (keyname,), type_) + _add_to_result_map(keyname, name, (), type_) # if we redefined col_expr for type expressions, wrap the # callable with one that adds the original column to the targets @@ -4595,52 +4451,7 @@ class SQLCompiler(Compiled): elif isinstance(column, elements.TextClause): render_with_label = False elif isinstance(column, elements.UnaryExpression): - # unary expression. notes added as of #12681 - # - # By convention, the visit_unary() method - # itself does not add an entry to the result map, and relies - # upon either the inner expression creating a result map - # entry, or if not, by creating a label here that produces - # the result map entry. Where that happens is based on whether - # or not the element immediately inside the unary is a - # NamedColumn subclass or not. - # - # Now, this also impacts how the SELECT is written; if - # we decide to generate a label here, we get the usual - # "~(x+y) AS anon_1" thing in the columns clause. If we - # don't, we don't get an AS at all, we get like - # "~table.column". - # - # But here is the important thing as of modernish (like 1.4) - # versions of SQLAlchemy - **whether or not the AS