206 lines
8.0 KiB
Python
206 lines
8.0 KiB
Python
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy.future import select
|
||
from sqlalchemy.orm import selectinload
|
||
from app.models import Group, Message, MessageGroup
|
||
from datetime import datetime, timedelta
|
||
from typing import List, Optional
|
||
|
||
|
||
class GroupRepository:
|
||
"""Репозиторий для работы с группами"""
|
||
|
||
def __init__(self, session: AsyncSession):
|
||
self.session = session
|
||
|
||
async def add_group(self, chat_id: str, title: str, slow_mode_delay: int = 0) -> Group:
|
||
"""Добавить новую группу"""
|
||
group = Group(
|
||
chat_id=chat_id,
|
||
title=title,
|
||
slow_mode_delay=slow_mode_delay
|
||
)
|
||
self.session.add(group)
|
||
await self.session.commit()
|
||
await self.session.refresh(group)
|
||
return group
|
||
|
||
async def get_group_by_chat_id(self, chat_id: str) -> Optional[Group]:
|
||
"""Получить группу по ID чата"""
|
||
result = await self.session.execute(
|
||
select(Group).where(Group.chat_id == chat_id)
|
||
)
|
||
return result.scalar_one_or_none()
|
||
|
||
async def get_all_active_groups(self) -> List[Group]:
|
||
"""Получить все активные группы"""
|
||
result = await self.session.execute(
|
||
select(Group).where(Group.is_active == True)
|
||
)
|
||
return result.scalars().all()
|
||
|
||
async def update_group_slow_mode(self, group_id: int, delay: int) -> None:
|
||
"""Обновить slow mode задержку группы"""
|
||
group = await self.session.get(Group, group_id)
|
||
if group:
|
||
group.slow_mode_delay = delay
|
||
group.updated_at = datetime.utcnow()
|
||
await self.session.commit()
|
||
|
||
async def update_last_message_time(self, group_id: int) -> None:
|
||
"""Обновить время последнего сообщения"""
|
||
group = await self.session.get(Group, group_id)
|
||
if group:
|
||
group.last_message_time = datetime.utcnow()
|
||
await self.session.commit()
|
||
|
||
async def deactivate_group(self, group_id: int) -> None:
|
||
"""Деактивировать группу"""
|
||
group = await self.session.get(Group, group_id)
|
||
if group:
|
||
group.is_active = False
|
||
await self.session.commit()
|
||
|
||
async def activate_group(self, group_id: int) -> None:
|
||
"""Активировать группу"""
|
||
group = await self.session.get(Group, group_id)
|
||
if group:
|
||
group.is_active = True
|
||
await self.session.commit()
|
||
|
||
|
||
class MessageRepository:
|
||
"""Репозиторий для работы с сообщениями"""
|
||
|
||
def __init__(self, session: AsyncSession):
|
||
self.session = session
|
||
|
||
async def add_message(self, text: str, title: str, parse_mode: str = 'HTML') -> Message:
|
||
"""Добавить новое сообщение"""
|
||
message = Message(
|
||
text=text,
|
||
title=title,
|
||
parse_mode=parse_mode
|
||
)
|
||
self.session.add(message)
|
||
await self.session.commit()
|
||
await self.session.refresh(message)
|
||
return message
|
||
|
||
async def get_message(self, message_id: int) -> Optional[Message]:
|
||
"""Получить сообщение по ID"""
|
||
result = await self.session.execute(
|
||
select(Message).where(Message.id == message_id)
|
||
)
|
||
return result.scalar_one_or_none()
|
||
|
||
async def get_all_messages(self, active_only: bool = True) -> List[Message]:
|
||
"""Получить все сообщения"""
|
||
query = select(Message)
|
||
if active_only:
|
||
query = query.where(Message.is_active == True)
|
||
result = await self.session.execute(query)
|
||
return result.scalars().all()
|
||
|
||
async def update_message(self, message_id: int, text: str = None, title: str = None) -> None:
|
||
"""Обновить сообщение"""
|
||
message = await self.session.get(Message, message_id)
|
||
if message:
|
||
if text:
|
||
message.text = text
|
||
if title:
|
||
message.title = title
|
||
message.updated_at = datetime.utcnow()
|
||
await self.session.commit()
|
||
|
||
async def deactivate_message(self, message_id: int) -> None:
|
||
"""Деактивировать сообщение"""
|
||
message = await self.session.get(Message, message_id)
|
||
if message:
|
||
message.is_active = False
|
||
await self.session.commit()
|
||
|
||
async def delete_message(self, message_id: int) -> None:
|
||
"""Удалить сообщение"""
|
||
message = await self.session.get(Message, message_id)
|
||
if message:
|
||
await self.session.delete(message)
|
||
await self.session.commit()
|
||
|
||
|
||
class MessageGroupRepository:
|
||
"""Репозиторий для работы со связями сообщение-группа"""
|
||
|
||
def __init__(self, session: AsyncSession):
|
||
self.session = session
|
||
|
||
async def add_message_to_group(self, message_id: int, group_id: int) -> MessageGroup:
|
||
"""Добавить сообщение в группу"""
|
||
# Проверить, не существует ли уже
|
||
result = await self.session.execute(
|
||
select(MessageGroup).where(
|
||
(MessageGroup.message_id == message_id) &
|
||
(MessageGroup.group_id == group_id)
|
||
)
|
||
)
|
||
existing = result.scalar_one_or_none()
|
||
if existing:
|
||
return existing
|
||
|
||
link = MessageGroup(message_id=message_id, group_id=group_id)
|
||
self.session.add(link)
|
||
await self.session.commit()
|
||
await self.session.refresh(link)
|
||
return link
|
||
|
||
async def get_message_groups_to_send(self, message_id: int) -> List[MessageGroup]:
|
||
"""Получить группы, куда еще не отправлено сообщение"""
|
||
result = await self.session.execute(
|
||
select(MessageGroup)
|
||
.where((MessageGroup.message_id == message_id) & (MessageGroup.is_sent == False))
|
||
.options(selectinload(MessageGroup.group))
|
||
)
|
||
return result.scalars().all()
|
||
|
||
async def get_unsent_messages_for_group(self, group_id: int) -> List[MessageGroup]:
|
||
"""Получить неотправленные сообщения для группы"""
|
||
result = await self.session.execute(
|
||
select(MessageGroup)
|
||
.where((MessageGroup.group_id == group_id) & (MessageGroup.is_sent == False))
|
||
.options(selectinload(MessageGroup.message))
|
||
)
|
||
return result.scalars().all()
|
||
|
||
async def mark_as_sent(self, message_group_id: int, error: str = None) -> None:
|
||
"""Отметить как отправленное"""
|
||
link = await self.session.get(MessageGroup, message_group_id)
|
||
if link:
|
||
link.is_sent = True
|
||
link.sent_at = datetime.utcnow()
|
||
if error:
|
||
link.error = error
|
||
link.is_sent = False
|
||
await self.session.commit()
|
||
|
||
async def get_messages_for_group(self, group_id: int) -> List[MessageGroup]:
|
||
"""Получить все сообщения для группы с их статусом"""
|
||
result = await self.session.execute(
|
||
select(MessageGroup)
|
||
.where(MessageGroup.group_id == group_id)
|
||
.options(selectinload(MessageGroup.message))
|
||
.order_by(MessageGroup.created_at.desc())
|
||
)
|
||
return result.scalars().all()
|
||
|
||
async def remove_message_from_group(self, message_id: int, group_id: int) -> None:
|
||
"""Удалить сообщение из группы"""
|
||
result = await self.session.execute(
|
||
select(MessageGroup).where(
|
||
(MessageGroup.message_id == message_id) &
|
||
(MessageGroup.group_id == group_id)
|
||
)
|
||
)
|
||
link = result.scalar_one_or_none()
|
||
if link:
|
||
await self.session.delete(link)
|
||
await self.session.commit()
|