151 lines
5.1 KiB
Python
151 lines
5.1 KiB
Python
from __future__ import annotations
|
||
import time
|
||
import logging
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
from typing import Any, Dict, Optional, List
|
||
from threading import Lock
|
||
|
||
from app.bots.editor.messages import MessageType
|
||
from app.models.post import PostType
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
DEFAULT_TTL = 60 * 60 # 1 час
|
||
|
||
# Тип сообщения, используемый в сессии
|
||
SessionType = MessageType | PostType
|
||
|
||
|
||
@dataclass
|
||
class UserSession:
|
||
"""Сессия пользователя при создании поста."""
|
||
|
||
# Основные данные поста
|
||
channel_id: Optional[int] = None
|
||
type: Optional[SessionType] = None
|
||
parse_mode: Optional[str] = None # HTML/MarkdownV2
|
||
text: Optional[str] = None
|
||
media_file_id: Optional[str] = None
|
||
keyboard: Optional[dict] = None # {"rows": [[{"text","url"}], ...]}
|
||
|
||
# Данные шаблона
|
||
template_name: Optional[str] = None
|
||
template_id: Optional[str] = None
|
||
template_vars: Dict[str, str] = field(default_factory=dict)
|
||
missing_vars: List[str] = field(default_factory=list)
|
||
|
||
# Метаданные отправки
|
||
schedule_time: Optional[datetime] = None
|
||
|
||
def update(self, data: Dict[str, Any]) -> None:
|
||
"""Обновляет поля сессии из словаря."""
|
||
for key, value in data.items():
|
||
if hasattr(self, key):
|
||
setattr(self, key, value)
|
||
|
||
# Метаданные
|
||
last_activity: float = field(default_factory=time.time)
|
||
state: Optional[int] = None
|
||
|
||
def touch(self) -> None:
|
||
"""Обновляет время последней активности."""
|
||
self.last_activity = time.time()
|
||
|
||
def clear(self) -> None:
|
||
"""Очищает все данные сессии."""
|
||
self.channel_id = None
|
||
self.type = None
|
||
self.parse_mode = None
|
||
self.text = None
|
||
self.media_file_id = None
|
||
self.keyboard = None
|
||
self.template_name = None
|
||
self.template_vars.clear()
|
||
self.missing_vars.clear()
|
||
self.state = None
|
||
self.touch()
|
||
|
||
def is_complete(self) -> bool:
|
||
"""Проверяет, заполнены ли все необходимые поля."""
|
||
if not self.channel_id or not self.type:
|
||
return False
|
||
|
||
if self.type == MessageType.TEXT:
|
||
return bool(self.text)
|
||
else:
|
||
return bool(self.text and self.media_file_id)
|
||
|
||
def to_dict(self) -> Dict[str, Any]:
|
||
"""Конвертирует сессию в словарь для отправки."""
|
||
return {
|
||
"type": self.type.value if self.type else None,
|
||
"text": self.text,
|
||
"media_file_id": self.media_file_id,
|
||
"parse_mode": self.parse_mode or "HTML",
|
||
"keyboard": self.keyboard,
|
||
"template_id": self.template_id,
|
||
"template_name": self.template_name,
|
||
"template_vars": self.template_vars
|
||
}
|
||
|
||
as_dict = to_dict
|
||
|
||
|
||
class SessionStore:
|
||
"""Thread-safe хранилище сессий с автоочисткой."""
|
||
|
||
_instance: Optional["SessionStore"] = None
|
||
|
||
def __init__(self, ttl: int = DEFAULT_TTL) -> None:
|
||
self._data: Dict[int, UserSession] = {}
|
||
self._ttl = ttl
|
||
self._lock = Lock()
|
||
|
||
@classmethod
|
||
def get_instance(cls) -> "SessionStore":
|
||
"""Возвращает глобальный экземпляр."""
|
||
if cls._instance is None:
|
||
cls._instance = cls()
|
||
return cls._instance
|
||
|
||
def get(self, uid: int) -> UserSession:
|
||
"""Получает или создает сессию пользователя."""
|
||
with self._lock:
|
||
s = self._data.get(uid)
|
||
if not s:
|
||
s = UserSession()
|
||
self._data[uid] = s
|
||
s.touch()
|
||
self._cleanup()
|
||
return s
|
||
|
||
def drop(self, uid: int) -> None:
|
||
"""Удаляет сессию пользователя."""
|
||
with self._lock:
|
||
if uid in self._data:
|
||
logger.info(f"Dropping session for user {uid}")
|
||
del self._data[uid]
|
||
|
||
def _cleanup(self) -> None:
|
||
"""Удаляет истекшие сессии."""
|
||
now = time.time()
|
||
expired = []
|
||
|
||
for uid, session in self._data.items():
|
||
if now - session.last_activity > self._ttl:
|
||
expired.append(uid)
|
||
|
||
for uid in expired:
|
||
logger.info(f"Session expired for user {uid}")
|
||
del self._data[uid]
|
||
|
||
def get_active_count(self) -> int:
|
||
"""Возвращает количество активных сессий."""
|
||
return len(self._data)
|
||
|
||
|
||
def get_session_store() -> SessionStore:
|
||
"""Возвращает глобальный экземпляр хранилища сессий."""
|
||
return SessionStore.get_instance()
|