fix: Resolve import issues and test compatibility
Some checks reported errors
continuous-integration/drone/push Build was killed
Some checks reported errors
continuous-integration/drone/push Build was killed
- Fix Storage class reference in authentication tests - Add secret_key parameter to AgentAuthentication initialization - Fix timedelta import in sessions.py - Basic authentication functionality verified
This commit is contained in:
488
.history/src/sessions_20251125213303.py
Normal file
488
.history/src/sessions_20251125213303.py
Normal file
@@ -0,0 +1,488 @@
|
||||
"""
|
||||
Sessions module для PyGuardian
|
||||
Управление SSH сессиями и процессами пользователей
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
import psutil
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""Менеджер SSH сессий и пользовательских процессов"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def get_active_sessions(self) -> List[Dict]:
|
||||
"""Получение всех активных SSH сессий"""
|
||||
try:
|
||||
sessions = []
|
||||
|
||||
# Метод 1: через who
|
||||
who_sessions = await self._get_sessions_via_who()
|
||||
sessions.extend(who_sessions)
|
||||
|
||||
# Метод 2: через ps (для SSH процессов)
|
||||
ssh_sessions = await self._get_sessions_via_ps()
|
||||
sessions.extend(ssh_sessions)
|
||||
|
||||
# Убираем дубликаты и объединяем информацию
|
||||
unique_sessions = self._merge_session_info(sessions)
|
||||
|
||||
return unique_sessions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения активных сессий: {e}")
|
||||
return []
|
||||
|
||||
async def _get_sessions_via_who(self) -> List[Dict]:
|
||||
"""Получение сессий через команду who"""
|
||||
try:
|
||||
sessions = []
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
'who', '-u',
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode == 0:
|
||||
lines = stdout.decode().strip().split('\n')
|
||||
for line in lines:
|
||||
if line.strip():
|
||||
# Парсим вывод who
|
||||
# Формат: user tty date time (idle) pid (comment)
|
||||
match = re.match(
|
||||
r'(\w+)\s+(\w+)\s+(\d{4}-\d{2}-\d{2})\s+(\d{2}:\d{2})\s+.*?\((\d+)\)',
|
||||
line
|
||||
)
|
||||
if match:
|
||||
username, tty, date, time, pid = match.groups()
|
||||
sessions.append({
|
||||
'username': username,
|
||||
'tty': tty,
|
||||
'login_date': date,
|
||||
'login_time': time,
|
||||
'pid': int(pid),
|
||||
'type': 'who',
|
||||
'status': 'active'
|
||||
})
|
||||
else:
|
||||
# Альтернативный парсинг для разных форматов who
|
||||
parts = line.split()
|
||||
if len(parts) >= 2:
|
||||
username = parts[0]
|
||||
tty = parts[1]
|
||||
|
||||
# Ищем PID в скобках
|
||||
pid_match = re.search(r'\((\d+)\)', line)
|
||||
pid = int(pid_match.group(1)) if pid_match else None
|
||||
|
||||
sessions.append({
|
||||
'username': username,
|
||||
'tty': tty,
|
||||
'pid': pid,
|
||||
'type': 'who',
|
||||
'status': 'active',
|
||||
'raw_line': line
|
||||
})
|
||||
|
||||
return sessions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения сессий через who: {e}")
|
||||
return []
|
||||
|
||||
async def _get_sessions_via_ps(self) -> List[Dict]:
|
||||
"""Получение SSH сессий через ps"""
|
||||
try:
|
||||
sessions = []
|
||||
|
||||
# Ищем SSH процессы
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
'ps', 'aux',
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode == 0:
|
||||
lines = stdout.decode().strip().split('\n')
|
||||
for line in lines[1:]: # Пропускаем заголовок
|
||||
if 'sshd:' in line and '@pts' in line:
|
||||
# Парсим SSH сессии
|
||||
parts = line.split()
|
||||
if len(parts) >= 11:
|
||||
username = parts[0]
|
||||
pid = int(parts[1])
|
||||
|
||||
# Извлекаем информацию из команды
|
||||
cmd_parts = ' '.join(parts[10:])
|
||||
|
||||
# Ищем пользователя и tty в команде sshd
|
||||
match = re.search(r'sshd:\s+(\w+)@(\w+)', cmd_parts)
|
||||
if match:
|
||||
ssh_user, tty = match.groups()
|
||||
|
||||
sessions.append({
|
||||
'username': ssh_user,
|
||||
'tty': tty,
|
||||
'pid': pid,
|
||||
'ppid': int(parts[2]),
|
||||
'cpu': parts[2],
|
||||
'mem': parts[3],
|
||||
'start_time': parts[8],
|
||||
'type': 'sshd',
|
||||
'status': 'active',
|
||||
'command': cmd_parts
|
||||
})
|
||||
|
||||
return sessions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения SSH сессий через ps: {e}")
|
||||
return []
|
||||
|
||||
def _merge_session_info(self, sessions: List[Dict]) -> List[Dict]:
|
||||
"""Объединение информации о сессиях и удаление дубликатов"""
|
||||
try:
|
||||
merged = {}
|
||||
|
||||
for session in sessions:
|
||||
key = f"{session['username']}:{session.get('tty', 'unknown')}"
|
||||
|
||||
if key in merged:
|
||||
# Обновляем существующую запись дополнительной информацией
|
||||
merged[key].update({k: v for k, v in session.items() if v is not None})
|
||||
else:
|
||||
merged[key] = session.copy()
|
||||
|
||||
# Добавляем дополнительную информацию о процессах
|
||||
for session in merged.values():
|
||||
if session.get('pid'):
|
||||
try:
|
||||
# Получаем дополнительную информацию о процессе через psutil
|
||||
if psutil.pid_exists(session['pid']):
|
||||
proc = psutil.Process(session['pid'])
|
||||
session.update({
|
||||
'create_time': datetime.fromtimestamp(proc.create_time()).isoformat(),
|
||||
'cpu_percent': proc.cpu_percent(),
|
||||
'memory_info': proc.memory_info()._asdict(),
|
||||
'connections': len(proc.connections())
|
||||
})
|
||||
except Exception:
|
||||
pass # Игнорируем ошибки получения доп. информации
|
||||
|
||||
return list(merged.values())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка объединения информации о сессиях: {e}")
|
||||
return sessions
|
||||
|
||||
async def get_user_sessions(self, username: str) -> List[Dict]:
|
||||
"""Получение сессий конкретного пользователя"""
|
||||
try:
|
||||
all_sessions = await self.get_active_sessions()
|
||||
return [s for s in all_sessions if s['username'] == username]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения сессий пользователя {username}: {e}")
|
||||
return []
|
||||
|
||||
async def terminate_session(self, pid: int) -> bool:
|
||||
"""Завершение сессии по PID"""
|
||||
try:
|
||||
# Сначала пробуем TERM
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
'kill', '-TERM', str(pid),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
await process.communicate()
|
||||
|
||||
if process.returncode == 0:
|
||||
# Ждем немного и проверяем
|
||||
await asyncio.sleep(2)
|
||||
|
||||
if not psutil.pid_exists(pid):
|
||||
logger.info(f"✅ Сессия PID {pid} завершена через TERM")
|
||||
return True
|
||||
else:
|
||||
# Если не помогло - используем KILL
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
'kill', '-KILL', str(pid),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
await process.communicate()
|
||||
|
||||
if process.returncode == 0:
|
||||
logger.info(f"🔪 Сессия PID {pid} принудительно завершена через KILL")
|
||||
return True
|
||||
|
||||
logger.error(f"❌ Не удалось завершить сессию PID {pid}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка завершения сессии PID {pid}: {e}")
|
||||
return False
|
||||
|
||||
async def terminate_user_sessions(self, username: str) -> int:
|
||||
"""Завершение всех сессий пользователя"""
|
||||
try:
|
||||
user_sessions = await self.get_user_sessions(username)
|
||||
terminated = 0
|
||||
|
||||
for session in user_sessions:
|
||||
pid = session.get('pid')
|
||||
if pid:
|
||||
success = await self.terminate_session(pid)
|
||||
if success:
|
||||
terminated += 1
|
||||
|
||||
logger.info(f"Завершено {terminated} из {len(user_sessions)} сессий пользователя {username}")
|
||||
return terminated
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка завершения сессий пользователя {username}: {e}")
|
||||
return 0
|
||||
|
||||
async def get_session_details(self, pid: int) -> Optional[Dict]:
|
||||
"""Получение детальной информации о сессии"""
|
||||
try:
|
||||
if not psutil.pid_exists(pid):
|
||||
return None
|
||||
|
||||
proc = psutil.Process(pid)
|
||||
|
||||
# Базовая информация о процессе
|
||||
details = {
|
||||
'pid': pid,
|
||||
'ppid': proc.ppid(),
|
||||
'username': proc.username(),
|
||||
'create_time': datetime.fromtimestamp(proc.create_time()).isoformat(),
|
||||
'cpu_percent': proc.cpu_percent(),
|
||||
'memory_info': proc.memory_info()._asdict(),
|
||||
'status': proc.status(),
|
||||
'cmdline': proc.cmdline(),
|
||||
'cwd': proc.cwd(),
|
||||
'exe': proc.exe()
|
||||
}
|
||||
|
||||
# Сетевые соединения
|
||||
try:
|
||||
connections = []
|
||||
for conn in proc.connections():
|
||||
connections.append({
|
||||
'fd': conn.fd,
|
||||
'family': str(conn.family),
|
||||
'type': str(conn.type),
|
||||
'local_address': f"{conn.laddr.ip}:{conn.laddr.port}" if conn.laddr else None,
|
||||
'remote_address': f"{conn.raddr.ip}:{conn.raddr.port}" if conn.raddr else None,
|
||||
'status': str(conn.status)
|
||||
})
|
||||
details['connections'] = connections
|
||||
except Exception:
|
||||
details['connections'] = []
|
||||
|
||||
# Открытые файлы
|
||||
try:
|
||||
open_files = []
|
||||
for file in proc.open_files()[:10]: # Ограничиваем 10 файлами
|
||||
open_files.append({
|
||||
'path': file.path,
|
||||
'fd': file.fd,
|
||||
'mode': file.mode
|
||||
})
|
||||
details['open_files'] = open_files
|
||||
except Exception:
|
||||
details['open_files'] = []
|
||||
|
||||
# Переменные окружения (выборочно)
|
||||
try:
|
||||
env = proc.environ()
|
||||
safe_env = {}
|
||||
safe_keys = ['USER', 'HOME', 'SHELL', 'SSH_CLIENT', 'SSH_CONNECTION', 'TERM']
|
||||
for key in safe_keys:
|
||||
if key in env:
|
||||
safe_env[key] = env[key]
|
||||
details['environment'] = safe_env
|
||||
except Exception:
|
||||
details['environment'] = {}
|
||||
|
||||
return details
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения деталей сессии PID {pid}: {e}")
|
||||
return None
|
||||
|
||||
async def monitor_session_activity(self, pid: int, duration: int = 60) -> List[Dict]:
|
||||
"""Мониторинг активности сессии в течение времени"""
|
||||
try:
|
||||
if not psutil.pid_exists(pid):
|
||||
return []
|
||||
|
||||
activity_log = []
|
||||
proc = psutil.Process(pid)
|
||||
|
||||
start_time = datetime.now()
|
||||
end_time = start_time + timedelta(seconds=duration)
|
||||
|
||||
while datetime.now() < end_time:
|
||||
try:
|
||||
# Снимок состояния процесса
|
||||
snapshot = {
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'cpu_percent': proc.cpu_percent(),
|
||||
'memory_percent': proc.memory_percent(),
|
||||
'num_threads': proc.num_threads(),
|
||||
'num_fds': proc.num_fds(),
|
||||
'status': proc.status()
|
||||
}
|
||||
|
||||
# Проверяем новые соединения
|
||||
try:
|
||||
connections = len(proc.connections())
|
||||
snapshot['connections_count'] = connections
|
||||
except Exception:
|
||||
snapshot['connections_count'] = 0
|
||||
|
||||
activity_log.append(snapshot)
|
||||
|
||||
await asyncio.sleep(5) # Снимок каждые 5 секунд
|
||||
|
||||
except psutil.NoSuchProcess:
|
||||
# Процесс завершился
|
||||
activity_log.append({
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'event': 'process_terminated'
|
||||
})
|
||||
break
|
||||
except Exception as e:
|
||||
activity_log.append({
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'event': 'monitoring_error',
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
return activity_log
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка мониторинга активности сессии PID {pid}: {e}")
|
||||
return []
|
||||
|
||||
async def get_session_statistics(self) -> Dict:
|
||||
"""Получение общей статистики по сессиям"""
|
||||
try:
|
||||
sessions = await self.get_active_sessions()
|
||||
|
||||
stats = {
|
||||
'total_sessions': len(sessions),
|
||||
'users': {},
|
||||
'tty_types': {},
|
||||
'session_ages': [],
|
||||
'total_connections': 0
|
||||
}
|
||||
|
||||
for session in sessions:
|
||||
# Статистика по пользователям
|
||||
user = session['username']
|
||||
if user not in stats['users']:
|
||||
stats['users'][user] = 0
|
||||
stats['users'][user] += 1
|
||||
|
||||
# Статистика по типам TTY
|
||||
tty = session.get('tty', 'unknown')
|
||||
tty_type = 'console' if tty.startswith('tty') else 'ssh'
|
||||
if tty_type not in stats['tty_types']:
|
||||
stats['tty_types'][tty_type] = 0
|
||||
stats['tty_types'][tty_type] += 1
|
||||
|
||||
# Возраст сессии
|
||||
if 'create_time' in session:
|
||||
try:
|
||||
create_time = datetime.fromisoformat(session['create_time'])
|
||||
age_seconds = (datetime.now() - create_time).total_seconds()
|
||||
stats['session_ages'].append(age_seconds)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Количество соединений
|
||||
connections = session.get('connections', 0)
|
||||
if isinstance(connections, int):
|
||||
stats['total_connections'] += connections
|
||||
|
||||
# Средний возраст сессий
|
||||
if stats['session_ages']:
|
||||
stats['average_session_age'] = sum(stats['session_ages']) / len(stats['session_ages'])
|
||||
else:
|
||||
stats['average_session_age'] = 0
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения статистики сессий: {e}")
|
||||
return {'error': str(e)}
|
||||
|
||||
async def find_suspicious_sessions(self) -> List[Dict]:
|
||||
"""Поиск подозрительных сессий"""
|
||||
try:
|
||||
sessions = await self.get_active_sessions()
|
||||
suspicious = []
|
||||
|
||||
for session in sessions:
|
||||
suspicion_score = 0
|
||||
reasons = []
|
||||
|
||||
# Проверка 1: Много открытых соединений
|
||||
connections = session.get('connections', 0)
|
||||
if isinstance(connections, int) and connections > 10:
|
||||
suspicion_score += 2
|
||||
reasons.append(f"Много соединений: {connections}")
|
||||
|
||||
# Проверка 2: Высокое потребление CPU
|
||||
cpu = session.get('cpu_percent', 0)
|
||||
if isinstance(cpu, (int, float)) and cpu > 50:
|
||||
suspicion_score += 1
|
||||
reasons.append(f"Высокая нагрузка CPU: {cpu}%")
|
||||
|
||||
# Проверка 3: Долго активная сессия
|
||||
if 'create_time' in session:
|
||||
try:
|
||||
create_time = datetime.fromisoformat(session['create_time'])
|
||||
age_hours = (datetime.now() - create_time).total_seconds() / 3600
|
||||
if age_hours > 24: # Больше суток
|
||||
suspicion_score += 1
|
||||
reasons.append(f"Долгая сессия: {age_hours:.1f} часов")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Проверка 4: Подозрительные команды в cmdline
|
||||
cmdline = session.get('cmdline', [])
|
||||
if isinstance(cmdline, list):
|
||||
suspicious_commands = ['nc', 'netcat', 'wget', 'curl', 'python', 'perl', 'bash']
|
||||
for cmd in cmdline:
|
||||
if any(susp in cmd.lower() for susp in suspicious_commands):
|
||||
suspicion_score += 1
|
||||
reasons.append(f"Подозрительная команда: {cmd}")
|
||||
break
|
||||
|
||||
# Если набрали достаточно очков подозрительности
|
||||
if suspicion_score >= 2:
|
||||
session['suspicion_score'] = suspicion_score
|
||||
session['suspicion_reasons'] = reasons
|
||||
suspicious.append(session)
|
||||
|
||||
return suspicious
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка поиска подозрительных сессий: {e}")
|
||||
return []
|
||||
488
.history/src/sessions_20251125213308.py
Normal file
488
.history/src/sessions_20251125213308.py
Normal file
@@ -0,0 +1,488 @@
|
||||
"""
|
||||
Sessions module для PyGuardian
|
||||
Управление SSH сессиями и процессами пользователей
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
import psutil
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""Менеджер SSH сессий и пользовательских процессов"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def get_active_sessions(self) -> List[Dict]:
|
||||
"""Получение всех активных SSH сессий"""
|
||||
try:
|
||||
sessions = []
|
||||
|
||||
# Метод 1: через who
|
||||
who_sessions = await self._get_sessions_via_who()
|
||||
sessions.extend(who_sessions)
|
||||
|
||||
# Метод 2: через ps (для SSH процессов)
|
||||
ssh_sessions = await self._get_sessions_via_ps()
|
||||
sessions.extend(ssh_sessions)
|
||||
|
||||
# Убираем дубликаты и объединяем информацию
|
||||
unique_sessions = self._merge_session_info(sessions)
|
||||
|
||||
return unique_sessions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения активных сессий: {e}")
|
||||
return []
|
||||
|
||||
async def _get_sessions_via_who(self) -> List[Dict]:
|
||||
"""Получение сессий через команду who"""
|
||||
try:
|
||||
sessions = []
|
||||
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
'who', '-u',
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode == 0:
|
||||
lines = stdout.decode().strip().split('\n')
|
||||
for line in lines:
|
||||
if line.strip():
|
||||
# Парсим вывод who
|
||||
# Формат: user tty date time (idle) pid (comment)
|
||||
match = re.match(
|
||||
r'(\w+)\s+(\w+)\s+(\d{4}-\d{2}-\d{2})\s+(\d{2}:\d{2})\s+.*?\((\d+)\)',
|
||||
line
|
||||
)
|
||||
if match:
|
||||
username, tty, date, time, pid = match.groups()
|
||||
sessions.append({
|
||||
'username': username,
|
||||
'tty': tty,
|
||||
'login_date': date,
|
||||
'login_time': time,
|
||||
'pid': int(pid),
|
||||
'type': 'who',
|
||||
'status': 'active'
|
||||
})
|
||||
else:
|
||||
# Альтернативный парсинг для разных форматов who
|
||||
parts = line.split()
|
||||
if len(parts) >= 2:
|
||||
username = parts[0]
|
||||
tty = parts[1]
|
||||
|
||||
# Ищем PID в скобках
|
||||
pid_match = re.search(r'\((\d+)\)', line)
|
||||
pid = int(pid_match.group(1)) if pid_match else None
|
||||
|
||||
sessions.append({
|
||||
'username': username,
|
||||
'tty': tty,
|
||||
'pid': pid,
|
||||
'type': 'who',
|
||||
'status': 'active',
|
||||
'raw_line': line
|
||||
})
|
||||
|
||||
return sessions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения сессий через who: {e}")
|
||||
return []
|
||||
|
||||
async def _get_sessions_via_ps(self) -> List[Dict]:
|
||||
"""Получение SSH сессий через ps"""
|
||||
try:
|
||||
sessions = []
|
||||
|
||||
# Ищем SSH процессы
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
'ps', 'aux',
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode == 0:
|
||||
lines = stdout.decode().strip().split('\n')
|
||||
for line in lines[1:]: # Пропускаем заголовок
|
||||
if 'sshd:' in line and '@pts' in line:
|
||||
# Парсим SSH сессии
|
||||
parts = line.split()
|
||||
if len(parts) >= 11:
|
||||
username = parts[0]
|
||||
pid = int(parts[1])
|
||||
|
||||
# Извлекаем информацию из команды
|
||||
cmd_parts = ' '.join(parts[10:])
|
||||
|
||||
# Ищем пользователя и tty в команде sshd
|
||||
match = re.search(r'sshd:\s+(\w+)@(\w+)', cmd_parts)
|
||||
if match:
|
||||
ssh_user, tty = match.groups()
|
||||
|
||||
sessions.append({
|
||||
'username': ssh_user,
|
||||
'tty': tty,
|
||||
'pid': pid,
|
||||
'ppid': int(parts[2]),
|
||||
'cpu': parts[2],
|
||||
'mem': parts[3],
|
||||
'start_time': parts[8],
|
||||
'type': 'sshd',
|
||||
'status': 'active',
|
||||
'command': cmd_parts
|
||||
})
|
||||
|
||||
return sessions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения SSH сессий через ps: {e}")
|
||||
return []
|
||||
|
||||
def _merge_session_info(self, sessions: List[Dict]) -> List[Dict]:
|
||||
"""Объединение информации о сессиях и удаление дубликатов"""
|
||||
try:
|
||||
merged = {}
|
||||
|
||||
for session in sessions:
|
||||
key = f"{session['username']}:{session.get('tty', 'unknown')}"
|
||||
|
||||
if key in merged:
|
||||
# Обновляем существующую запись дополнительной информацией
|
||||
merged[key].update({k: v for k, v in session.items() if v is not None})
|
||||
else:
|
||||
merged[key] = session.copy()
|
||||
|
||||
# Добавляем дополнительную информацию о процессах
|
||||
for session in merged.values():
|
||||
if session.get('pid'):
|
||||
try:
|
||||
# Получаем дополнительную информацию о процессе через psutil
|
||||
if psutil.pid_exists(session['pid']):
|
||||
proc = psutil.Process(session['pid'])
|
||||
session.update({
|
||||
'create_time': datetime.fromtimestamp(proc.create_time()).isoformat(),
|
||||
'cpu_percent': proc.cpu_percent(),
|
||||
'memory_info': proc.memory_info()._asdict(),
|
||||
'connections': len(proc.connections())
|
||||
})
|
||||
except Exception:
|
||||
pass # Игнорируем ошибки получения доп. информации
|
||||
|
||||
return list(merged.values())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка объединения информации о сессиях: {e}")
|
||||
return sessions
|
||||
|
||||
async def get_user_sessions(self, username: str) -> List[Dict]:
|
||||
"""Получение сессий конкретного пользователя"""
|
||||
try:
|
||||
all_sessions = await self.get_active_sessions()
|
||||
return [s for s in all_sessions if s['username'] == username]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения сессий пользователя {username}: {e}")
|
||||
return []
|
||||
|
||||
async def terminate_session(self, pid: int) -> bool:
|
||||
"""Завершение сессии по PID"""
|
||||
try:
|
||||
# Сначала пробуем TERM
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
'kill', '-TERM', str(pid),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
await process.communicate()
|
||||
|
||||
if process.returncode == 0:
|
||||
# Ждем немного и проверяем
|
||||
await asyncio.sleep(2)
|
||||
|
||||
if not psutil.pid_exists(pid):
|
||||
logger.info(f"✅ Сессия PID {pid} завершена через TERM")
|
||||
return True
|
||||
else:
|
||||
# Если не помогло - используем KILL
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
'kill', '-KILL', str(pid),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
await process.communicate()
|
||||
|
||||
if process.returncode == 0:
|
||||
logger.info(f"🔪 Сессия PID {pid} принудительно завершена через KILL")
|
||||
return True
|
||||
|
||||
logger.error(f"❌ Не удалось завершить сессию PID {pid}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка завершения сессии PID {pid}: {e}")
|
||||
return False
|
||||
|
||||
async def terminate_user_sessions(self, username: str) -> int:
|
||||
"""Завершение всех сессий пользователя"""
|
||||
try:
|
||||
user_sessions = await self.get_user_sessions(username)
|
||||
terminated = 0
|
||||
|
||||
for session in user_sessions:
|
||||
pid = session.get('pid')
|
||||
if pid:
|
||||
success = await self.terminate_session(pid)
|
||||
if success:
|
||||
terminated += 1
|
||||
|
||||
logger.info(f"Завершено {terminated} из {len(user_sessions)} сессий пользователя {username}")
|
||||
return terminated
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка завершения сессий пользователя {username}: {e}")
|
||||
return 0
|
||||
|
||||
async def get_session_details(self, pid: int) -> Optional[Dict]:
|
||||
"""Получение детальной информации о сессии"""
|
||||
try:
|
||||
if not psutil.pid_exists(pid):
|
||||
return None
|
||||
|
||||
proc = psutil.Process(pid)
|
||||
|
||||
# Базовая информация о процессе
|
||||
details = {
|
||||
'pid': pid,
|
||||
'ppid': proc.ppid(),
|
||||
'username': proc.username(),
|
||||
'create_time': datetime.fromtimestamp(proc.create_time()).isoformat(),
|
||||
'cpu_percent': proc.cpu_percent(),
|
||||
'memory_info': proc.memory_info()._asdict(),
|
||||
'status': proc.status(),
|
||||
'cmdline': proc.cmdline(),
|
||||
'cwd': proc.cwd(),
|
||||
'exe': proc.exe()
|
||||
}
|
||||
|
||||
# Сетевые соединения
|
||||
try:
|
||||
connections = []
|
||||
for conn in proc.connections():
|
||||
connections.append({
|
||||
'fd': conn.fd,
|
||||
'family': str(conn.family),
|
||||
'type': str(conn.type),
|
||||
'local_address': f"{conn.laddr.ip}:{conn.laddr.port}" if conn.laddr else None,
|
||||
'remote_address': f"{conn.raddr.ip}:{conn.raddr.port}" if conn.raddr else None,
|
||||
'status': str(conn.status)
|
||||
})
|
||||
details['connections'] = connections
|
||||
except Exception:
|
||||
details['connections'] = []
|
||||
|
||||
# Открытые файлы
|
||||
try:
|
||||
open_files = []
|
||||
for file in proc.open_files()[:10]: # Ограничиваем 10 файлами
|
||||
open_files.append({
|
||||
'path': file.path,
|
||||
'fd': file.fd,
|
||||
'mode': file.mode
|
||||
})
|
||||
details['open_files'] = open_files
|
||||
except Exception:
|
||||
details['open_files'] = []
|
||||
|
||||
# Переменные окружения (выборочно)
|
||||
try:
|
||||
env = proc.environ()
|
||||
safe_env = {}
|
||||
safe_keys = ['USER', 'HOME', 'SHELL', 'SSH_CLIENT', 'SSH_CONNECTION', 'TERM']
|
||||
for key in safe_keys:
|
||||
if key in env:
|
||||
safe_env[key] = env[key]
|
||||
details['environment'] = safe_env
|
||||
except Exception:
|
||||
details['environment'] = {}
|
||||
|
||||
return details
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения деталей сессии PID {pid}: {e}")
|
||||
return None
|
||||
|
||||
async def monitor_session_activity(self, pid: int, duration: int = 60) -> List[Dict]:
|
||||
"""Мониторинг активности сессии в течение времени"""
|
||||
try:
|
||||
if not psutil.pid_exists(pid):
|
||||
return []
|
||||
|
||||
activity_log = []
|
||||
proc = psutil.Process(pid)
|
||||
|
||||
start_time = datetime.now()
|
||||
end_time = start_time + timedelta(seconds=duration)
|
||||
|
||||
while datetime.now() < end_time:
|
||||
try:
|
||||
# Снимок состояния процесса
|
||||
snapshot = {
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'cpu_percent': proc.cpu_percent(),
|
||||
'memory_percent': proc.memory_percent(),
|
||||
'num_threads': proc.num_threads(),
|
||||
'num_fds': proc.num_fds(),
|
||||
'status': proc.status()
|
||||
}
|
||||
|
||||
# Проверяем новые соединения
|
||||
try:
|
||||
connections = len(proc.connections())
|
||||
snapshot['connections_count'] = connections
|
||||
except Exception:
|
||||
snapshot['connections_count'] = 0
|
||||
|
||||
activity_log.append(snapshot)
|
||||
|
||||
await asyncio.sleep(5) # Снимок каждые 5 секунд
|
||||
|
||||
except psutil.NoSuchProcess:
|
||||
# Процесс завершился
|
||||
activity_log.append({
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'event': 'process_terminated'
|
||||
})
|
||||
break
|
||||
except Exception as e:
|
||||
activity_log.append({
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'event': 'monitoring_error',
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
return activity_log
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка мониторинга активности сессии PID {pid}: {e}")
|
||||
return []
|
||||
|
||||
async def get_session_statistics(self) -> Dict:
|
||||
"""Получение общей статистики по сессиям"""
|
||||
try:
|
||||
sessions = await self.get_active_sessions()
|
||||
|
||||
stats = {
|
||||
'total_sessions': len(sessions),
|
||||
'users': {},
|
||||
'tty_types': {},
|
||||
'session_ages': [],
|
||||
'total_connections': 0
|
||||
}
|
||||
|
||||
for session in sessions:
|
||||
# Статистика по пользователям
|
||||
user = session['username']
|
||||
if user not in stats['users']:
|
||||
stats['users'][user] = 0
|
||||
stats['users'][user] += 1
|
||||
|
||||
# Статистика по типам TTY
|
||||
tty = session.get('tty', 'unknown')
|
||||
tty_type = 'console' if tty.startswith('tty') else 'ssh'
|
||||
if tty_type not in stats['tty_types']:
|
||||
stats['tty_types'][tty_type] = 0
|
||||
stats['tty_types'][tty_type] += 1
|
||||
|
||||
# Возраст сессии
|
||||
if 'create_time' in session:
|
||||
try:
|
||||
create_time = datetime.fromisoformat(session['create_time'])
|
||||
age_seconds = (datetime.now() - create_time).total_seconds()
|
||||
stats['session_ages'].append(age_seconds)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Количество соединений
|
||||
connections = session.get('connections', 0)
|
||||
if isinstance(connections, int):
|
||||
stats['total_connections'] += connections
|
||||
|
||||
# Средний возраст сессий
|
||||
if stats['session_ages']:
|
||||
stats['average_session_age'] = sum(stats['session_ages']) / len(stats['session_ages'])
|
||||
else:
|
||||
stats['average_session_age'] = 0
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка получения статистики сессий: {e}")
|
||||
return {'error': str(e)}
|
||||
|
||||
async def find_suspicious_sessions(self) -> List[Dict]:
|
||||
"""Поиск подозрительных сессий"""
|
||||
try:
|
||||
sessions = await self.get_active_sessions()
|
||||
suspicious = []
|
||||
|
||||
for session in sessions:
|
||||
suspicion_score = 0
|
||||
reasons = []
|
||||
|
||||
# Проверка 1: Много открытых соединений
|
||||
connections = session.get('connections', 0)
|
||||
if isinstance(connections, int) and connections > 10:
|
||||
suspicion_score += 2
|
||||
reasons.append(f"Много соединений: {connections}")
|
||||
|
||||
# Проверка 2: Высокое потребление CPU
|
||||
cpu = session.get('cpu_percent', 0)
|
||||
if isinstance(cpu, (int, float)) and cpu > 50:
|
||||
suspicion_score += 1
|
||||
reasons.append(f"Высокая нагрузка CPU: {cpu}%")
|
||||
|
||||
# Проверка 3: Долго активная сессия
|
||||
if 'create_time' in session:
|
||||
try:
|
||||
create_time = datetime.fromisoformat(session['create_time'])
|
||||
age_hours = (datetime.now() - create_time).total_seconds() / 3600
|
||||
if age_hours > 24: # Больше суток
|
||||
suspicion_score += 1
|
||||
reasons.append(f"Долгая сессия: {age_hours:.1f} часов")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Проверка 4: Подозрительные команды в cmdline
|
||||
cmdline = session.get('cmdline', [])
|
||||
if isinstance(cmdline, list):
|
||||
suspicious_commands = ['nc', 'netcat', 'wget', 'curl', 'python', 'perl', 'bash']
|
||||
for cmd in cmdline:
|
||||
if any(susp in cmd.lower() for susp in suspicious_commands):
|
||||
suspicion_score += 1
|
||||
reasons.append(f"Подозрительная команда: {cmd}")
|
||||
break
|
||||
|
||||
# Если набрали достаточно очков подозрительности
|
||||
if suspicion_score >= 2:
|
||||
session['suspicion_score'] = suspicion_score
|
||||
session['suspicion_reasons'] = reasons
|
||||
suspicious.append(session)
|
||||
|
||||
return suspicious
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Ошибка поиска подозрительных сессий: {e}")
|
||||
return []
|
||||
421
.history/tests/unit/test_authentication_20251125212848.py
Normal file
421
.history/tests/unit/test_authentication_20251125212848.py
Normal file
@@ -0,0 +1,421 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive unit tests for PyGuardian authentication system.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import tempfile
|
||||
import os
|
||||
import sys
|
||||
import sqlite3
|
||||
import jwt
|
||||
import hashlib
|
||||
import hmac
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
# Add src directory to path for imports
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))
|
||||
|
||||
from auth import AgentAuthentication
|
||||
from storage import Storage
|
||||
|
||||
|
||||
class TestAgentAuthentication(unittest.TestCase):
|
||||
"""Test cases for agent authentication system."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.auth = AgentAuthentication()
|
||||
|
||||
# Create test database
|
||||
self.db = Database(self.db_path)
|
||||
self.db.create_tables()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_generate_agent_id(self):
|
||||
"""Test agent ID generation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
|
||||
# Check format
|
||||
self.assertTrue(agent_id.startswith('agent_'))
|
||||
self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID
|
||||
|
||||
# Test uniqueness
|
||||
agent_id2 = self.auth.generate_agent_id()
|
||||
self.assertNotEqual(agent_id, agent_id2)
|
||||
|
||||
def test_create_agent_credentials(self):
|
||||
"""Test agent credentials creation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
credentials = self.auth.create_agent_credentials(agent_id)
|
||||
|
||||
# Check required fields
|
||||
required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash']
|
||||
for field in required_fields:
|
||||
self.assertIn(field, credentials)
|
||||
|
||||
# Check agent ID matches
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
|
||||
# Check secret key length
|
||||
self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded
|
||||
|
||||
# Check key hash
|
||||
expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest()
|
||||
self.assertEqual(credentials['key_hash'], expected_hash)
|
||||
|
||||
def test_generate_jwt_token(self):
|
||||
"""Test JWT token generation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
|
||||
# Verify token structure
|
||||
self.assertIsInstance(token, str)
|
||||
self.assertTrue(len(token) > 100) # JWT tokens are typically long
|
||||
|
||||
# Decode and verify payload
|
||||
decoded = jwt.decode(token, secret_key, algorithms=['HS256'])
|
||||
self.assertEqual(decoded['agent_id'], agent_id)
|
||||
self.assertIn('iat', decoded)
|
||||
self.assertIn('exp', decoded)
|
||||
self.assertIn('jti', decoded)
|
||||
|
||||
def test_verify_jwt_token_valid(self):
|
||||
"""Test JWT token verification with valid token."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(token, secret_key)
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
def test_verify_jwt_token_invalid(self):
|
||||
"""Test JWT token verification with invalid token."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Test with invalid token
|
||||
is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
# Test with wrong secret key
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
wrong_key = self.auth._generate_secret_key()
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(token, wrong_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_verify_jwt_token_expired(self):
|
||||
"""Test JWT token verification with expired token."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Create expired token
|
||||
payload = {
|
||||
'agent_id': agent_id,
|
||||
'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago
|
||||
'iat': datetime.utcnow() - timedelta(hours=2),
|
||||
'jti': self.auth._generate_jti()
|
||||
}
|
||||
|
||||
expired_token = jwt.encode(payload, secret_key, algorithm='HS256')
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(expired_token, secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_create_hmac_signature(self):
|
||||
"""Test HMAC signature creation."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
|
||||
# Verify signature format
|
||||
self.assertEqual(len(signature), 64) # SHA256 hex digest
|
||||
|
||||
# Verify signature is correct
|
||||
expected = hmac.new(
|
||||
secret_key.encode(),
|
||||
data.encode(),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
self.assertEqual(signature, expected)
|
||||
|
||||
def test_verify_hmac_signature_valid(self):
|
||||
"""Test HMAC signature verification with valid signature."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
is_valid = self.auth.verify_hmac_signature(data, signature, secret_key)
|
||||
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
def test_verify_hmac_signature_invalid(self):
|
||||
"""Test HMAC signature verification with invalid signature."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Test with wrong signature
|
||||
wrong_signature = "0" * 64
|
||||
is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
# Test with wrong key
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
wrong_key = self.auth._generate_secret_key()
|
||||
is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_encrypt_decrypt_secret_key(self):
|
||||
"""Test secret key encryption and decryption."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
password = "test_password"
|
||||
|
||||
encrypted = self.auth.encrypt_secret_key(secret_key, password)
|
||||
decrypted = self.auth.decrypt_secret_key(encrypted, password)
|
||||
|
||||
self.assertEqual(secret_key, decrypted)
|
||||
|
||||
def test_encrypt_decrypt_wrong_password(self):
|
||||
"""Test secret key decryption with wrong password."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
password = "test_password"
|
||||
wrong_password = "wrong_password"
|
||||
|
||||
encrypted = self.auth.encrypt_secret_key(secret_key, password)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.auth.decrypt_secret_key(encrypted, wrong_password)
|
||||
|
||||
@patch('src.auth.Database')
|
||||
def test_authenticate_agent_success(self, mock_db_class):
|
||||
"""Test successful agent authentication."""
|
||||
# Mock database
|
||||
mock_db = Mock()
|
||||
mock_db_class.return_value = mock_db
|
||||
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
key_hash = hashlib.sha256(secret_key.encode()).hexdigest()
|
||||
|
||||
# Mock database response
|
||||
mock_db.get_agent_credentials.return_value = {
|
||||
'agent_id': agent_id,
|
||||
'key_hash': key_hash,
|
||||
'is_active': True,
|
||||
'created_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
result = self.auth.authenticate_agent(agent_id, secret_key)
|
||||
self.assertTrue(result)
|
||||
|
||||
@patch('src.auth.Database')
|
||||
def test_authenticate_agent_failure(self, mock_db_class):
|
||||
"""Test failed agent authentication."""
|
||||
# Mock database
|
||||
mock_db = Mock()
|
||||
mock_db_class.return_value = mock_db
|
||||
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Mock database response - no credentials found
|
||||
mock_db.get_agent_credentials.return_value = None
|
||||
|
||||
result = self.auth.authenticate_agent(agent_id, secret_key)
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
"""Test cases for database operations."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.db = Database(self.db_path)
|
||||
self.db.create_tables()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_create_agent_auth(self):
|
||||
"""Test agent authentication record creation."""
|
||||
agent_id = "agent_test123"
|
||||
secret_key_hash = "test_hash"
|
||||
encrypted_key = "encrypted_test_key"
|
||||
|
||||
success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
|
||||
self.assertTrue(success)
|
||||
|
||||
# Verify record exists
|
||||
credentials = self.db.get_agent_credentials(agent_id)
|
||||
self.assertIsNotNone(credentials)
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
self.assertEqual(credentials['key_hash'], secret_key_hash)
|
||||
|
||||
def test_get_agent_credentials_exists(self):
|
||||
"""Test retrieving existing agent credentials."""
|
||||
agent_id = "agent_test123"
|
||||
secret_key_hash = "test_hash"
|
||||
encrypted_key = "encrypted_test_key"
|
||||
|
||||
# Create record
|
||||
self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
|
||||
|
||||
# Retrieve record
|
||||
credentials = self.db.get_agent_credentials(agent_id)
|
||||
|
||||
self.assertIsNotNone(credentials)
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
self.assertEqual(credentials['key_hash'], secret_key_hash)
|
||||
self.assertTrue(credentials['is_active'])
|
||||
|
||||
def test_get_agent_credentials_not_exists(self):
|
||||
"""Test retrieving non-existent agent credentials."""
|
||||
credentials = self.db.get_agent_credentials("non_existent_agent")
|
||||
self.assertIsNone(credentials)
|
||||
|
||||
def test_store_agent_token(self):
|
||||
"""Test storing agent JWT token."""
|
||||
agent_id = "agent_test123"
|
||||
token = "test_jwt_token"
|
||||
expires_at = (datetime.now() + timedelta(hours=1)).isoformat()
|
||||
|
||||
success = self.db.store_agent_token(agent_id, token, expires_at)
|
||||
self.assertTrue(success)
|
||||
|
||||
# Verify token exists
|
||||
stored_token = self.db.get_agent_token(agent_id)
|
||||
self.assertIsNotNone(stored_token)
|
||||
self.assertEqual(stored_token['token'], token)
|
||||
|
||||
def test_cleanup_expired_tokens(self):
|
||||
"""Test cleanup of expired tokens."""
|
||||
agent_id = "agent_test123"
|
||||
|
||||
# Create expired token
|
||||
expired_token = "expired_token"
|
||||
expired_time = (datetime.now() - timedelta(hours=1)).isoformat()
|
||||
self.db.store_agent_token(agent_id, expired_token, expired_time)
|
||||
|
||||
# Create valid token
|
||||
valid_token = "valid_token"
|
||||
valid_time = (datetime.now() + timedelta(hours=1)).isoformat()
|
||||
self.db.store_agent_token("agent_valid", valid_token, valid_time)
|
||||
|
||||
# Cleanup expired tokens
|
||||
cleaned = self.db.cleanup_expired_tokens()
|
||||
self.assertGreaterEqual(cleaned, 1)
|
||||
|
||||
# Verify expired token is gone
|
||||
token = self.db.get_agent_token(agent_id)
|
||||
self.assertIsNone(token)
|
||||
|
||||
# Verify valid token remains
|
||||
token = self.db.get_agent_token("agent_valid")
|
||||
self.assertIsNotNone(token)
|
||||
|
||||
|
||||
class TestIntegration(unittest.TestCase):
|
||||
"""Integration tests for the complete authentication flow."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.auth = AgentAuthentication()
|
||||
|
||||
# Use test database
|
||||
self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_complete_authentication_flow(self):
|
||||
"""Test complete agent authentication workflow."""
|
||||
# Step 1: Generate agent ID
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
self.assertIsNotNone(agent_id)
|
||||
|
||||
# Step 2: Create credentials
|
||||
credentials = self.auth.create_agent_credentials(agent_id)
|
||||
self.assertIsNotNone(credentials)
|
||||
|
||||
# Step 3: Generate JWT token
|
||||
token = self.auth.generate_jwt_token(
|
||||
credentials['agent_id'],
|
||||
credentials['secret_key']
|
||||
)
|
||||
self.assertIsNotNone(token)
|
||||
|
||||
# Step 4: Verify token
|
||||
is_valid = self.auth.verify_jwt_token(token, credentials['secret_key'])
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
# Step 5: Create HMAC signature
|
||||
test_data = "test API request"
|
||||
signature = self.auth.create_hmac_signature(test_data, credentials['secret_key'])
|
||||
self.assertIsNotNone(signature)
|
||||
|
||||
# Step 6: Verify HMAC signature
|
||||
is_signature_valid = self.auth.verify_hmac_signature(
|
||||
test_data, signature, credentials['secret_key']
|
||||
)
|
||||
self.assertTrue(is_signature_valid)
|
||||
|
||||
|
||||
def run_tests():
|
||||
"""Run all tests."""
|
||||
print("🧪 Running PyGuardian Authentication Tests...")
|
||||
print("=" * 50)
|
||||
|
||||
# Create test suite
|
||||
test_suite = unittest.TestSuite()
|
||||
|
||||
# Add test classes
|
||||
test_classes = [
|
||||
TestAgentAuthentication,
|
||||
TestDatabase,
|
||||
TestIntegration
|
||||
]
|
||||
|
||||
for test_class in test_classes:
|
||||
tests = unittest.TestLoader().loadTestsFromTestCase(test_class)
|
||||
test_suite.addTests(tests)
|
||||
|
||||
# Run tests
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(test_suite)
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 50)
|
||||
print(f"🏁 Tests completed:")
|
||||
print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}")
|
||||
print(f" ❌ Failed: {len(result.failures)}")
|
||||
print(f" 💥 Errors: {len(result.errors)}")
|
||||
|
||||
# Return exit code
|
||||
return 0 if result.wasSuccessful() else 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(run_tests())
|
||||
421
.history/tests/unit/test_authentication_20251125212856.py
Normal file
421
.history/tests/unit/test_authentication_20251125212856.py
Normal file
@@ -0,0 +1,421 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive unit tests for PyGuardian authentication system.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import tempfile
|
||||
import os
|
||||
import sys
|
||||
import sqlite3
|
||||
import jwt
|
||||
import hashlib
|
||||
import hmac
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
# Add src directory to path for imports
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))
|
||||
|
||||
from auth import AgentAuthentication
|
||||
from storage import Storage
|
||||
|
||||
|
||||
class TestAgentAuthentication(unittest.TestCase):
|
||||
"""Test cases for agent authentication system."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.auth = AgentAuthentication()
|
||||
|
||||
# Create test database
|
||||
self.db = Database(self.db_path)
|
||||
self.db.create_tables()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_generate_agent_id(self):
|
||||
"""Test agent ID generation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
|
||||
# Check format
|
||||
self.assertTrue(agent_id.startswith('agent_'))
|
||||
self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID
|
||||
|
||||
# Test uniqueness
|
||||
agent_id2 = self.auth.generate_agent_id()
|
||||
self.assertNotEqual(agent_id, agent_id2)
|
||||
|
||||
def test_create_agent_credentials(self):
|
||||
"""Test agent credentials creation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
credentials = self.auth.create_agent_credentials(agent_id)
|
||||
|
||||
# Check required fields
|
||||
required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash']
|
||||
for field in required_fields:
|
||||
self.assertIn(field, credentials)
|
||||
|
||||
# Check agent ID matches
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
|
||||
# Check secret key length
|
||||
self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded
|
||||
|
||||
# Check key hash
|
||||
expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest()
|
||||
self.assertEqual(credentials['key_hash'], expected_hash)
|
||||
|
||||
def test_generate_jwt_token(self):
|
||||
"""Test JWT token generation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
|
||||
# Verify token structure
|
||||
self.assertIsInstance(token, str)
|
||||
self.assertTrue(len(token) > 100) # JWT tokens are typically long
|
||||
|
||||
# Decode and verify payload
|
||||
decoded = jwt.decode(token, secret_key, algorithms=['HS256'])
|
||||
self.assertEqual(decoded['agent_id'], agent_id)
|
||||
self.assertIn('iat', decoded)
|
||||
self.assertIn('exp', decoded)
|
||||
self.assertIn('jti', decoded)
|
||||
|
||||
def test_verify_jwt_token_valid(self):
|
||||
"""Test JWT token verification with valid token."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(token, secret_key)
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
def test_verify_jwt_token_invalid(self):
|
||||
"""Test JWT token verification with invalid token."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Test with invalid token
|
||||
is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
# Test with wrong secret key
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
wrong_key = self.auth._generate_secret_key()
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(token, wrong_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_verify_jwt_token_expired(self):
|
||||
"""Test JWT token verification with expired token."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Create expired token
|
||||
payload = {
|
||||
'agent_id': agent_id,
|
||||
'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago
|
||||
'iat': datetime.utcnow() - timedelta(hours=2),
|
||||
'jti': self.auth._generate_jti()
|
||||
}
|
||||
|
||||
expired_token = jwt.encode(payload, secret_key, algorithm='HS256')
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(expired_token, secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_create_hmac_signature(self):
|
||||
"""Test HMAC signature creation."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
|
||||
# Verify signature format
|
||||
self.assertEqual(len(signature), 64) # SHA256 hex digest
|
||||
|
||||
# Verify signature is correct
|
||||
expected = hmac.new(
|
||||
secret_key.encode(),
|
||||
data.encode(),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
self.assertEqual(signature, expected)
|
||||
|
||||
def test_verify_hmac_signature_valid(self):
|
||||
"""Test HMAC signature verification with valid signature."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
is_valid = self.auth.verify_hmac_signature(data, signature, secret_key)
|
||||
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
def test_verify_hmac_signature_invalid(self):
|
||||
"""Test HMAC signature verification with invalid signature."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Test with wrong signature
|
||||
wrong_signature = "0" * 64
|
||||
is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
# Test with wrong key
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
wrong_key = self.auth._generate_secret_key()
|
||||
is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_encrypt_decrypt_secret_key(self):
|
||||
"""Test secret key encryption and decryption."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
password = "test_password"
|
||||
|
||||
encrypted = self.auth.encrypt_secret_key(secret_key, password)
|
||||
decrypted = self.auth.decrypt_secret_key(encrypted, password)
|
||||
|
||||
self.assertEqual(secret_key, decrypted)
|
||||
|
||||
def test_encrypt_decrypt_wrong_password(self):
|
||||
"""Test secret key decryption with wrong password."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
password = "test_password"
|
||||
wrong_password = "wrong_password"
|
||||
|
||||
encrypted = self.auth.encrypt_secret_key(secret_key, password)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.auth.decrypt_secret_key(encrypted, wrong_password)
|
||||
|
||||
@patch('src.auth.Database')
|
||||
def test_authenticate_agent_success(self, mock_db_class):
|
||||
"""Test successful agent authentication."""
|
||||
# Mock database
|
||||
mock_db = Mock()
|
||||
mock_db_class.return_value = mock_db
|
||||
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
key_hash = hashlib.sha256(secret_key.encode()).hexdigest()
|
||||
|
||||
# Mock database response
|
||||
mock_db.get_agent_credentials.return_value = {
|
||||
'agent_id': agent_id,
|
||||
'key_hash': key_hash,
|
||||
'is_active': True,
|
||||
'created_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
result = self.auth.authenticate_agent(agent_id, secret_key)
|
||||
self.assertTrue(result)
|
||||
|
||||
@patch('src.auth.Database')
|
||||
def test_authenticate_agent_failure(self, mock_db_class):
|
||||
"""Test failed agent authentication."""
|
||||
# Mock database
|
||||
mock_db = Mock()
|
||||
mock_db_class.return_value = mock_db
|
||||
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Mock database response - no credentials found
|
||||
mock_db.get_agent_credentials.return_value = None
|
||||
|
||||
result = self.auth.authenticate_agent(agent_id, secret_key)
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
"""Test cases for database operations."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.db = Storage(self.db_path)
|
||||
self.db.create_tables()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_create_agent_auth(self):
|
||||
"""Test agent authentication record creation."""
|
||||
agent_id = "agent_test123"
|
||||
secret_key_hash = "test_hash"
|
||||
encrypted_key = "encrypted_test_key"
|
||||
|
||||
success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
|
||||
self.assertTrue(success)
|
||||
|
||||
# Verify record exists
|
||||
credentials = self.db.get_agent_credentials(agent_id)
|
||||
self.assertIsNotNone(credentials)
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
self.assertEqual(credentials['key_hash'], secret_key_hash)
|
||||
|
||||
def test_get_agent_credentials_exists(self):
|
||||
"""Test retrieving existing agent credentials."""
|
||||
agent_id = "agent_test123"
|
||||
secret_key_hash = "test_hash"
|
||||
encrypted_key = "encrypted_test_key"
|
||||
|
||||
# Create record
|
||||
self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
|
||||
|
||||
# Retrieve record
|
||||
credentials = self.db.get_agent_credentials(agent_id)
|
||||
|
||||
self.assertIsNotNone(credentials)
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
self.assertEqual(credentials['key_hash'], secret_key_hash)
|
||||
self.assertTrue(credentials['is_active'])
|
||||
|
||||
def test_get_agent_credentials_not_exists(self):
|
||||
"""Test retrieving non-existent agent credentials."""
|
||||
credentials = self.db.get_agent_credentials("non_existent_agent")
|
||||
self.assertIsNone(credentials)
|
||||
|
||||
def test_store_agent_token(self):
|
||||
"""Test storing agent JWT token."""
|
||||
agent_id = "agent_test123"
|
||||
token = "test_jwt_token"
|
||||
expires_at = (datetime.now() + timedelta(hours=1)).isoformat()
|
||||
|
||||
success = self.db.store_agent_token(agent_id, token, expires_at)
|
||||
self.assertTrue(success)
|
||||
|
||||
# Verify token exists
|
||||
stored_token = self.db.get_agent_token(agent_id)
|
||||
self.assertIsNotNone(stored_token)
|
||||
self.assertEqual(stored_token['token'], token)
|
||||
|
||||
def test_cleanup_expired_tokens(self):
|
||||
"""Test cleanup of expired tokens."""
|
||||
agent_id = "agent_test123"
|
||||
|
||||
# Create expired token
|
||||
expired_token = "expired_token"
|
||||
expired_time = (datetime.now() - timedelta(hours=1)).isoformat()
|
||||
self.db.store_agent_token(agent_id, expired_token, expired_time)
|
||||
|
||||
# Create valid token
|
||||
valid_token = "valid_token"
|
||||
valid_time = (datetime.now() + timedelta(hours=1)).isoformat()
|
||||
self.db.store_agent_token("agent_valid", valid_token, valid_time)
|
||||
|
||||
# Cleanup expired tokens
|
||||
cleaned = self.db.cleanup_expired_tokens()
|
||||
self.assertGreaterEqual(cleaned, 1)
|
||||
|
||||
# Verify expired token is gone
|
||||
token = self.db.get_agent_token(agent_id)
|
||||
self.assertIsNone(token)
|
||||
|
||||
# Verify valid token remains
|
||||
token = self.db.get_agent_token("agent_valid")
|
||||
self.assertIsNotNone(token)
|
||||
|
||||
|
||||
class TestIntegration(unittest.TestCase):
|
||||
"""Integration tests for the complete authentication flow."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.auth = AgentAuthentication()
|
||||
|
||||
# Use test database
|
||||
self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_complete_authentication_flow(self):
|
||||
"""Test complete agent authentication workflow."""
|
||||
# Step 1: Generate agent ID
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
self.assertIsNotNone(agent_id)
|
||||
|
||||
# Step 2: Create credentials
|
||||
credentials = self.auth.create_agent_credentials(agent_id)
|
||||
self.assertIsNotNone(credentials)
|
||||
|
||||
# Step 3: Generate JWT token
|
||||
token = self.auth.generate_jwt_token(
|
||||
credentials['agent_id'],
|
||||
credentials['secret_key']
|
||||
)
|
||||
self.assertIsNotNone(token)
|
||||
|
||||
# Step 4: Verify token
|
||||
is_valid = self.auth.verify_jwt_token(token, credentials['secret_key'])
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
# Step 5: Create HMAC signature
|
||||
test_data = "test API request"
|
||||
signature = self.auth.create_hmac_signature(test_data, credentials['secret_key'])
|
||||
self.assertIsNotNone(signature)
|
||||
|
||||
# Step 6: Verify HMAC signature
|
||||
is_signature_valid = self.auth.verify_hmac_signature(
|
||||
test_data, signature, credentials['secret_key']
|
||||
)
|
||||
self.assertTrue(is_signature_valid)
|
||||
|
||||
|
||||
def run_tests():
|
||||
"""Run all tests."""
|
||||
print("🧪 Running PyGuardian Authentication Tests...")
|
||||
print("=" * 50)
|
||||
|
||||
# Create test suite
|
||||
test_suite = unittest.TestSuite()
|
||||
|
||||
# Add test classes
|
||||
test_classes = [
|
||||
TestAgentAuthentication,
|
||||
TestDatabase,
|
||||
TestIntegration
|
||||
]
|
||||
|
||||
for test_class in test_classes:
|
||||
tests = unittest.TestLoader().loadTestsFromTestCase(test_class)
|
||||
test_suite.addTests(tests)
|
||||
|
||||
# Run tests
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(test_suite)
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 50)
|
||||
print(f"🏁 Tests completed:")
|
||||
print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}")
|
||||
print(f" ❌ Failed: {len(result.failures)}")
|
||||
print(f" 💥 Errors: {len(result.errors)}")
|
||||
|
||||
# Return exit code
|
||||
return 0 if result.wasSuccessful() else 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(run_tests())
|
||||
421
.history/tests/unit/test_authentication_20251125212909.py
Normal file
421
.history/tests/unit/test_authentication_20251125212909.py
Normal file
@@ -0,0 +1,421 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive unit tests for PyGuardian authentication system.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import tempfile
|
||||
import os
|
||||
import sys
|
||||
import sqlite3
|
||||
import jwt
|
||||
import hashlib
|
||||
import hmac
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
# Add src directory to path for imports
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))
|
||||
|
||||
from auth import AgentAuthentication
|
||||
from storage import Storage
|
||||
|
||||
|
||||
class TestAgentAuthentication(unittest.TestCase):
|
||||
"""Test cases for agent authentication system."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.auth = AgentAuthentication()
|
||||
|
||||
# Create test database
|
||||
self.db = Storage(self.db_path)
|
||||
self.db.create_tables()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_generate_agent_id(self):
|
||||
"""Test agent ID generation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
|
||||
# Check format
|
||||
self.assertTrue(agent_id.startswith('agent_'))
|
||||
self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID
|
||||
|
||||
# Test uniqueness
|
||||
agent_id2 = self.auth.generate_agent_id()
|
||||
self.assertNotEqual(agent_id, agent_id2)
|
||||
|
||||
def test_create_agent_credentials(self):
|
||||
"""Test agent credentials creation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
credentials = self.auth.create_agent_credentials(agent_id)
|
||||
|
||||
# Check required fields
|
||||
required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash']
|
||||
for field in required_fields:
|
||||
self.assertIn(field, credentials)
|
||||
|
||||
# Check agent ID matches
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
|
||||
# Check secret key length
|
||||
self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded
|
||||
|
||||
# Check key hash
|
||||
expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest()
|
||||
self.assertEqual(credentials['key_hash'], expected_hash)
|
||||
|
||||
def test_generate_jwt_token(self):
|
||||
"""Test JWT token generation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
|
||||
# Verify token structure
|
||||
self.assertIsInstance(token, str)
|
||||
self.assertTrue(len(token) > 100) # JWT tokens are typically long
|
||||
|
||||
# Decode and verify payload
|
||||
decoded = jwt.decode(token, secret_key, algorithms=['HS256'])
|
||||
self.assertEqual(decoded['agent_id'], agent_id)
|
||||
self.assertIn('iat', decoded)
|
||||
self.assertIn('exp', decoded)
|
||||
self.assertIn('jti', decoded)
|
||||
|
||||
def test_verify_jwt_token_valid(self):
|
||||
"""Test JWT token verification with valid token."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(token, secret_key)
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
def test_verify_jwt_token_invalid(self):
|
||||
"""Test JWT token verification with invalid token."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Test with invalid token
|
||||
is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
# Test with wrong secret key
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
wrong_key = self.auth._generate_secret_key()
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(token, wrong_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_verify_jwt_token_expired(self):
|
||||
"""Test JWT token verification with expired token."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Create expired token
|
||||
payload = {
|
||||
'agent_id': agent_id,
|
||||
'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago
|
||||
'iat': datetime.utcnow() - timedelta(hours=2),
|
||||
'jti': self.auth._generate_jti()
|
||||
}
|
||||
|
||||
expired_token = jwt.encode(payload, secret_key, algorithm='HS256')
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(expired_token, secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_create_hmac_signature(self):
|
||||
"""Test HMAC signature creation."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
|
||||
# Verify signature format
|
||||
self.assertEqual(len(signature), 64) # SHA256 hex digest
|
||||
|
||||
# Verify signature is correct
|
||||
expected = hmac.new(
|
||||
secret_key.encode(),
|
||||
data.encode(),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
self.assertEqual(signature, expected)
|
||||
|
||||
def test_verify_hmac_signature_valid(self):
|
||||
"""Test HMAC signature verification with valid signature."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
is_valid = self.auth.verify_hmac_signature(data, signature, secret_key)
|
||||
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
def test_verify_hmac_signature_invalid(self):
|
||||
"""Test HMAC signature verification with invalid signature."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Test with wrong signature
|
||||
wrong_signature = "0" * 64
|
||||
is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
# Test with wrong key
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
wrong_key = self.auth._generate_secret_key()
|
||||
is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_encrypt_decrypt_secret_key(self):
|
||||
"""Test secret key encryption and decryption."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
password = "test_password"
|
||||
|
||||
encrypted = self.auth.encrypt_secret_key(secret_key, password)
|
||||
decrypted = self.auth.decrypt_secret_key(encrypted, password)
|
||||
|
||||
self.assertEqual(secret_key, decrypted)
|
||||
|
||||
def test_encrypt_decrypt_wrong_password(self):
|
||||
"""Test secret key decryption with wrong password."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
password = "test_password"
|
||||
wrong_password = "wrong_password"
|
||||
|
||||
encrypted = self.auth.encrypt_secret_key(secret_key, password)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.auth.decrypt_secret_key(encrypted, wrong_password)
|
||||
|
||||
@patch('src.auth.Database')
|
||||
def test_authenticate_agent_success(self, mock_db_class):
|
||||
"""Test successful agent authentication."""
|
||||
# Mock database
|
||||
mock_db = Mock()
|
||||
mock_db_class.return_value = mock_db
|
||||
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
key_hash = hashlib.sha256(secret_key.encode()).hexdigest()
|
||||
|
||||
# Mock database response
|
||||
mock_db.get_agent_credentials.return_value = {
|
||||
'agent_id': agent_id,
|
||||
'key_hash': key_hash,
|
||||
'is_active': True,
|
||||
'created_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
result = self.auth.authenticate_agent(agent_id, secret_key)
|
||||
self.assertTrue(result)
|
||||
|
||||
@patch('src.auth.Database')
|
||||
def test_authenticate_agent_failure(self, mock_db_class):
|
||||
"""Test failed agent authentication."""
|
||||
# Mock database
|
||||
mock_db = Mock()
|
||||
mock_db_class.return_value = mock_db
|
||||
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Mock database response - no credentials found
|
||||
mock_db.get_agent_credentials.return_value = None
|
||||
|
||||
result = self.auth.authenticate_agent(agent_id, secret_key)
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
"""Test cases for database operations."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.db = Storage(self.db_path)
|
||||
self.db.create_tables()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_create_agent_auth(self):
|
||||
"""Test agent authentication record creation."""
|
||||
agent_id = "agent_test123"
|
||||
secret_key_hash = "test_hash"
|
||||
encrypted_key = "encrypted_test_key"
|
||||
|
||||
success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
|
||||
self.assertTrue(success)
|
||||
|
||||
# Verify record exists
|
||||
credentials = self.db.get_agent_credentials(agent_id)
|
||||
self.assertIsNotNone(credentials)
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
self.assertEqual(credentials['key_hash'], secret_key_hash)
|
||||
|
||||
def test_get_agent_credentials_exists(self):
|
||||
"""Test retrieving existing agent credentials."""
|
||||
agent_id = "agent_test123"
|
||||
secret_key_hash = "test_hash"
|
||||
encrypted_key = "encrypted_test_key"
|
||||
|
||||
# Create record
|
||||
self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
|
||||
|
||||
# Retrieve record
|
||||
credentials = self.db.get_agent_credentials(agent_id)
|
||||
|
||||
self.assertIsNotNone(credentials)
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
self.assertEqual(credentials['key_hash'], secret_key_hash)
|
||||
self.assertTrue(credentials['is_active'])
|
||||
|
||||
def test_get_agent_credentials_not_exists(self):
|
||||
"""Test retrieving non-existent agent credentials."""
|
||||
credentials = self.db.get_agent_credentials("non_existent_agent")
|
||||
self.assertIsNone(credentials)
|
||||
|
||||
def test_store_agent_token(self):
|
||||
"""Test storing agent JWT token."""
|
||||
agent_id = "agent_test123"
|
||||
token = "test_jwt_token"
|
||||
expires_at = (datetime.now() + timedelta(hours=1)).isoformat()
|
||||
|
||||
success = self.db.store_agent_token(agent_id, token, expires_at)
|
||||
self.assertTrue(success)
|
||||
|
||||
# Verify token exists
|
||||
stored_token = self.db.get_agent_token(agent_id)
|
||||
self.assertIsNotNone(stored_token)
|
||||
self.assertEqual(stored_token['token'], token)
|
||||
|
||||
def test_cleanup_expired_tokens(self):
|
||||
"""Test cleanup of expired tokens."""
|
||||
agent_id = "agent_test123"
|
||||
|
||||
# Create expired token
|
||||
expired_token = "expired_token"
|
||||
expired_time = (datetime.now() - timedelta(hours=1)).isoformat()
|
||||
self.db.store_agent_token(agent_id, expired_token, expired_time)
|
||||
|
||||
# Create valid token
|
||||
valid_token = "valid_token"
|
||||
valid_time = (datetime.now() + timedelta(hours=1)).isoformat()
|
||||
self.db.store_agent_token("agent_valid", valid_token, valid_time)
|
||||
|
||||
# Cleanup expired tokens
|
||||
cleaned = self.db.cleanup_expired_tokens()
|
||||
self.assertGreaterEqual(cleaned, 1)
|
||||
|
||||
# Verify expired token is gone
|
||||
token = self.db.get_agent_token(agent_id)
|
||||
self.assertIsNone(token)
|
||||
|
||||
# Verify valid token remains
|
||||
token = self.db.get_agent_token("agent_valid")
|
||||
self.assertIsNotNone(token)
|
||||
|
||||
|
||||
class TestIntegration(unittest.TestCase):
|
||||
"""Integration tests for the complete authentication flow."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.auth = AgentAuthentication()
|
||||
|
||||
# Use test database
|
||||
self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_complete_authentication_flow(self):
|
||||
"""Test complete agent authentication workflow."""
|
||||
# Step 1: Generate agent ID
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
self.assertIsNotNone(agent_id)
|
||||
|
||||
# Step 2: Create credentials
|
||||
credentials = self.auth.create_agent_credentials(agent_id)
|
||||
self.assertIsNotNone(credentials)
|
||||
|
||||
# Step 3: Generate JWT token
|
||||
token = self.auth.generate_jwt_token(
|
||||
credentials['agent_id'],
|
||||
credentials['secret_key']
|
||||
)
|
||||
self.assertIsNotNone(token)
|
||||
|
||||
# Step 4: Verify token
|
||||
is_valid = self.auth.verify_jwt_token(token, credentials['secret_key'])
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
# Step 5: Create HMAC signature
|
||||
test_data = "test API request"
|
||||
signature = self.auth.create_hmac_signature(test_data, credentials['secret_key'])
|
||||
self.assertIsNotNone(signature)
|
||||
|
||||
# Step 6: Verify HMAC signature
|
||||
is_signature_valid = self.auth.verify_hmac_signature(
|
||||
test_data, signature, credentials['secret_key']
|
||||
)
|
||||
self.assertTrue(is_signature_valid)
|
||||
|
||||
|
||||
def run_tests():
|
||||
"""Run all tests."""
|
||||
print("🧪 Running PyGuardian Authentication Tests...")
|
||||
print("=" * 50)
|
||||
|
||||
# Create test suite
|
||||
test_suite = unittest.TestSuite()
|
||||
|
||||
# Add test classes
|
||||
test_classes = [
|
||||
TestAgentAuthentication,
|
||||
TestDatabase,
|
||||
TestIntegration
|
||||
]
|
||||
|
||||
for test_class in test_classes:
|
||||
tests = unittest.TestLoader().loadTestsFromTestCase(test_class)
|
||||
test_suite.addTests(tests)
|
||||
|
||||
# Run tests
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(test_suite)
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 50)
|
||||
print(f"🏁 Tests completed:")
|
||||
print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}")
|
||||
print(f" ❌ Failed: {len(result.failures)}")
|
||||
print(f" 💥 Errors: {len(result.errors)}")
|
||||
|
||||
# Return exit code
|
||||
return 0 if result.wasSuccessful() else 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(run_tests())
|
||||
421
.history/tests/unit/test_authentication_20251125213152.py
Normal file
421
.history/tests/unit/test_authentication_20251125213152.py
Normal file
@@ -0,0 +1,421 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive unit tests for PyGuardian authentication system.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import tempfile
|
||||
import os
|
||||
import sys
|
||||
import sqlite3
|
||||
import jwt
|
||||
import hashlib
|
||||
import hmac
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
# Add src directory to path for imports
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))
|
||||
|
||||
from auth import AgentAuthentication
|
||||
from storage import Storage
|
||||
|
||||
|
||||
class TestAgentAuthentication(unittest.TestCase):
|
||||
"""Test cases for agent authentication system."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.auth = AgentAuthentication()
|
||||
|
||||
# Create test database
|
||||
self.db = Storage(self.db_path)
|
||||
self.db.create_tables()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_generate_agent_id(self):
|
||||
"""Test agent ID generation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
|
||||
# Check format
|
||||
self.assertTrue(agent_id.startswith('agent_'))
|
||||
self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID
|
||||
|
||||
# Test uniqueness
|
||||
agent_id2 = self.auth.generate_agent_id()
|
||||
self.assertNotEqual(agent_id, agent_id2)
|
||||
|
||||
def test_create_agent_credentials(self):
|
||||
"""Test agent credentials creation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
credentials = self.auth.create_agent_credentials(agent_id)
|
||||
|
||||
# Check required fields
|
||||
required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash']
|
||||
for field in required_fields:
|
||||
self.assertIn(field, credentials)
|
||||
|
||||
# Check agent ID matches
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
|
||||
# Check secret key length
|
||||
self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded
|
||||
|
||||
# Check key hash
|
||||
expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest()
|
||||
self.assertEqual(credentials['key_hash'], expected_hash)
|
||||
|
||||
def test_generate_jwt_token(self):
|
||||
"""Test JWT token generation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
|
||||
# Verify token structure
|
||||
self.assertIsInstance(token, str)
|
||||
self.assertTrue(len(token) > 100) # JWT tokens are typically long
|
||||
|
||||
# Decode and verify payload
|
||||
decoded = jwt.decode(token, secret_key, algorithms=['HS256'])
|
||||
self.assertEqual(decoded['agent_id'], agent_id)
|
||||
self.assertIn('iat', decoded)
|
||||
self.assertIn('exp', decoded)
|
||||
self.assertIn('jti', decoded)
|
||||
|
||||
def test_verify_jwt_token_valid(self):
|
||||
"""Test JWT token verification with valid token."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(token, secret_key)
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
def test_verify_jwt_token_invalid(self):
|
||||
"""Test JWT token verification with invalid token."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Test with invalid token
|
||||
is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
# Test with wrong secret key
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
wrong_key = self.auth._generate_secret_key()
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(token, wrong_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_verify_jwt_token_expired(self):
|
||||
"""Test JWT token verification with expired token."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Create expired token
|
||||
payload = {
|
||||
'agent_id': agent_id,
|
||||
'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago
|
||||
'iat': datetime.utcnow() - timedelta(hours=2),
|
||||
'jti': self.auth._generate_jti()
|
||||
}
|
||||
|
||||
expired_token = jwt.encode(payload, secret_key, algorithm='HS256')
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(expired_token, secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_create_hmac_signature(self):
|
||||
"""Test HMAC signature creation."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
|
||||
# Verify signature format
|
||||
self.assertEqual(len(signature), 64) # SHA256 hex digest
|
||||
|
||||
# Verify signature is correct
|
||||
expected = hmac.new(
|
||||
secret_key.encode(),
|
||||
data.encode(),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
self.assertEqual(signature, expected)
|
||||
|
||||
def test_verify_hmac_signature_valid(self):
|
||||
"""Test HMAC signature verification with valid signature."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
is_valid = self.auth.verify_hmac_signature(data, signature, secret_key)
|
||||
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
def test_verify_hmac_signature_invalid(self):
|
||||
"""Test HMAC signature verification with invalid signature."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Test with wrong signature
|
||||
wrong_signature = "0" * 64
|
||||
is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
# Test with wrong key
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
wrong_key = self.auth._generate_secret_key()
|
||||
is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_encrypt_decrypt_secret_key(self):
|
||||
"""Test secret key encryption and decryption."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
password = "test_password"
|
||||
|
||||
encrypted = self.auth.encrypt_secret_key(secret_key, password)
|
||||
decrypted = self.auth.decrypt_secret_key(encrypted, password)
|
||||
|
||||
self.assertEqual(secret_key, decrypted)
|
||||
|
||||
def test_encrypt_decrypt_wrong_password(self):
|
||||
"""Test secret key decryption with wrong password."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
password = "test_password"
|
||||
wrong_password = "wrong_password"
|
||||
|
||||
encrypted = self.auth.encrypt_secret_key(secret_key, password)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.auth.decrypt_secret_key(encrypted, wrong_password)
|
||||
|
||||
@patch('src.auth.Database')
|
||||
def test_authenticate_agent_success(self, mock_db_class):
|
||||
"""Test successful agent authentication."""
|
||||
# Mock database
|
||||
mock_db = Mock()
|
||||
mock_db_class.return_value = mock_db
|
||||
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
key_hash = hashlib.sha256(secret_key.encode()).hexdigest()
|
||||
|
||||
# Mock database response
|
||||
mock_db.get_agent_credentials.return_value = {
|
||||
'agent_id': agent_id,
|
||||
'key_hash': key_hash,
|
||||
'is_active': True,
|
||||
'created_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
result = self.auth.authenticate_agent(agent_id, secret_key)
|
||||
self.assertTrue(result)
|
||||
|
||||
@patch('src.auth.Database')
|
||||
def test_authenticate_agent_failure(self, mock_db_class):
|
||||
"""Test failed agent authentication."""
|
||||
# Mock database
|
||||
mock_db = Mock()
|
||||
mock_db_class.return_value = mock_db
|
||||
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Mock database response - no credentials found
|
||||
mock_db.get_agent_credentials.return_value = None
|
||||
|
||||
result = self.auth.authenticate_agent(agent_id, secret_key)
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
"""Test cases for database operations."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.db = Storage(self.db_path)
|
||||
self.db.create_tables()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_create_agent_auth(self):
|
||||
"""Test agent authentication record creation."""
|
||||
agent_id = "agent_test123"
|
||||
secret_key_hash = "test_hash"
|
||||
encrypted_key = "encrypted_test_key"
|
||||
|
||||
success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
|
||||
self.assertTrue(success)
|
||||
|
||||
# Verify record exists
|
||||
credentials = self.db.get_agent_credentials(agent_id)
|
||||
self.assertIsNotNone(credentials)
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
self.assertEqual(credentials['key_hash'], secret_key_hash)
|
||||
|
||||
def test_get_agent_credentials_exists(self):
|
||||
"""Test retrieving existing agent credentials."""
|
||||
agent_id = "agent_test123"
|
||||
secret_key_hash = "test_hash"
|
||||
encrypted_key = "encrypted_test_key"
|
||||
|
||||
# Create record
|
||||
self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
|
||||
|
||||
# Retrieve record
|
||||
credentials = self.db.get_agent_credentials(agent_id)
|
||||
|
||||
self.assertIsNotNone(credentials)
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
self.assertEqual(credentials['key_hash'], secret_key_hash)
|
||||
self.assertTrue(credentials['is_active'])
|
||||
|
||||
def test_get_agent_credentials_not_exists(self):
|
||||
"""Test retrieving non-existent agent credentials."""
|
||||
credentials = self.db.get_agent_credentials("non_existent_agent")
|
||||
self.assertIsNone(credentials)
|
||||
|
||||
def test_store_agent_token(self):
|
||||
"""Test storing agent JWT token."""
|
||||
agent_id = "agent_test123"
|
||||
token = "test_jwt_token"
|
||||
expires_at = (datetime.now() + timedelta(hours=1)).isoformat()
|
||||
|
||||
success = self.db.store_agent_token(agent_id, token, expires_at)
|
||||
self.assertTrue(success)
|
||||
|
||||
# Verify token exists
|
||||
stored_token = self.db.get_agent_token(agent_id)
|
||||
self.assertIsNotNone(stored_token)
|
||||
self.assertEqual(stored_token['token'], token)
|
||||
|
||||
def test_cleanup_expired_tokens(self):
|
||||
"""Test cleanup of expired tokens."""
|
||||
agent_id = "agent_test123"
|
||||
|
||||
# Create expired token
|
||||
expired_token = "expired_token"
|
||||
expired_time = (datetime.now() - timedelta(hours=1)).isoformat()
|
||||
self.db.store_agent_token(agent_id, expired_token, expired_time)
|
||||
|
||||
# Create valid token
|
||||
valid_token = "valid_token"
|
||||
valid_time = (datetime.now() + timedelta(hours=1)).isoformat()
|
||||
self.db.store_agent_token("agent_valid", valid_token, valid_time)
|
||||
|
||||
# Cleanup expired tokens
|
||||
cleaned = self.db.cleanup_expired_tokens()
|
||||
self.assertGreaterEqual(cleaned, 1)
|
||||
|
||||
# Verify expired token is gone
|
||||
token = self.db.get_agent_token(agent_id)
|
||||
self.assertIsNone(token)
|
||||
|
||||
# Verify valid token remains
|
||||
token = self.db.get_agent_token("agent_valid")
|
||||
self.assertIsNotNone(token)
|
||||
|
||||
|
||||
class TestIntegration(unittest.TestCase):
|
||||
"""Integration tests for the complete authentication flow."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.auth = AgentAuthentication()
|
||||
|
||||
# Use test database
|
||||
self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_complete_authentication_flow(self):
|
||||
"""Test complete agent authentication workflow."""
|
||||
# Step 1: Generate agent ID
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
self.assertIsNotNone(agent_id)
|
||||
|
||||
# Step 2: Create credentials
|
||||
credentials = self.auth.create_agent_credentials(agent_id)
|
||||
self.assertIsNotNone(credentials)
|
||||
|
||||
# Step 3: Generate JWT token
|
||||
token = self.auth.generate_jwt_token(
|
||||
credentials['agent_id'],
|
||||
credentials['secret_key']
|
||||
)
|
||||
self.assertIsNotNone(token)
|
||||
|
||||
# Step 4: Verify token
|
||||
is_valid = self.auth.verify_jwt_token(token, credentials['secret_key'])
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
# Step 5: Create HMAC signature
|
||||
test_data = "test API request"
|
||||
signature = self.auth.create_hmac_signature(test_data, credentials['secret_key'])
|
||||
self.assertIsNotNone(signature)
|
||||
|
||||
# Step 6: Verify HMAC signature
|
||||
is_signature_valid = self.auth.verify_hmac_signature(
|
||||
test_data, signature, credentials['secret_key']
|
||||
)
|
||||
self.assertTrue(is_signature_valid)
|
||||
|
||||
|
||||
def run_tests():
|
||||
"""Run all tests."""
|
||||
print("🧪 Running PyGuardian Authentication Tests...")
|
||||
print("=" * 50)
|
||||
|
||||
# Create test suite
|
||||
test_suite = unittest.TestSuite()
|
||||
|
||||
# Add test classes
|
||||
test_classes = [
|
||||
TestAgentAuthentication,
|
||||
TestDatabase,
|
||||
TestIntegration
|
||||
]
|
||||
|
||||
for test_class in test_classes:
|
||||
tests = unittest.TestLoader().loadTestsFromTestCase(test_class)
|
||||
test_suite.addTests(tests)
|
||||
|
||||
# Run tests
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(test_suite)
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 50)
|
||||
print(f"🏁 Tests completed:")
|
||||
print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}")
|
||||
print(f" ❌ Failed: {len(result.failures)}")
|
||||
print(f" 💥 Errors: {len(result.errors)}")
|
||||
|
||||
# Return exit code
|
||||
return 0 if result.wasSuccessful() else 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(run_tests())
|
||||
422
.history/tests/unit/test_authentication_20251125213226.py
Normal file
422
.history/tests/unit/test_authentication_20251125213226.py
Normal file
@@ -0,0 +1,422 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive unit tests for PyGuardian authentication system.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import tempfile
|
||||
import os
|
||||
import sys
|
||||
import sqlite3
|
||||
import jwt
|
||||
import hashlib
|
||||
import hmac
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
# Add src directory to path for imports
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))
|
||||
|
||||
from auth import AgentAuthentication
|
||||
from storage import Storage
|
||||
|
||||
|
||||
class TestAgentAuthentication(unittest.TestCase):
|
||||
"""Test cases for agent authentication system."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.test_secret = 'test_secret_key_123'
|
||||
self.auth = AgentAuthentication(self.test_secret)
|
||||
|
||||
# Create test database
|
||||
self.db = Storage(self.db_path)
|
||||
self.db.create_tables()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_generate_agent_id(self):
|
||||
"""Test agent ID generation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
|
||||
# Check format
|
||||
self.assertTrue(agent_id.startswith('agent_'))
|
||||
self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID
|
||||
|
||||
# Test uniqueness
|
||||
agent_id2 = self.auth.generate_agent_id()
|
||||
self.assertNotEqual(agent_id, agent_id2)
|
||||
|
||||
def test_create_agent_credentials(self):
|
||||
"""Test agent credentials creation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
credentials = self.auth.create_agent_credentials(agent_id)
|
||||
|
||||
# Check required fields
|
||||
required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash']
|
||||
for field in required_fields:
|
||||
self.assertIn(field, credentials)
|
||||
|
||||
# Check agent ID matches
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
|
||||
# Check secret key length
|
||||
self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded
|
||||
|
||||
# Check key hash
|
||||
expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest()
|
||||
self.assertEqual(credentials['key_hash'], expected_hash)
|
||||
|
||||
def test_generate_jwt_token(self):
|
||||
"""Test JWT token generation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
|
||||
# Verify token structure
|
||||
self.assertIsInstance(token, str)
|
||||
self.assertTrue(len(token) > 100) # JWT tokens are typically long
|
||||
|
||||
# Decode and verify payload
|
||||
decoded = jwt.decode(token, secret_key, algorithms=['HS256'])
|
||||
self.assertEqual(decoded['agent_id'], agent_id)
|
||||
self.assertIn('iat', decoded)
|
||||
self.assertIn('exp', decoded)
|
||||
self.assertIn('jti', decoded)
|
||||
|
||||
def test_verify_jwt_token_valid(self):
|
||||
"""Test JWT token verification with valid token."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(token, secret_key)
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
def test_verify_jwt_token_invalid(self):
|
||||
"""Test JWT token verification with invalid token."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Test with invalid token
|
||||
is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
# Test with wrong secret key
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
wrong_key = self.auth._generate_secret_key()
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(token, wrong_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_verify_jwt_token_expired(self):
|
||||
"""Test JWT token verification with expired token."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Create expired token
|
||||
payload = {
|
||||
'agent_id': agent_id,
|
||||
'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago
|
||||
'iat': datetime.utcnow() - timedelta(hours=2),
|
||||
'jti': self.auth._generate_jti()
|
||||
}
|
||||
|
||||
expired_token = jwt.encode(payload, secret_key, algorithm='HS256')
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(expired_token, secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_create_hmac_signature(self):
|
||||
"""Test HMAC signature creation."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
|
||||
# Verify signature format
|
||||
self.assertEqual(len(signature), 64) # SHA256 hex digest
|
||||
|
||||
# Verify signature is correct
|
||||
expected = hmac.new(
|
||||
secret_key.encode(),
|
||||
data.encode(),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
self.assertEqual(signature, expected)
|
||||
|
||||
def test_verify_hmac_signature_valid(self):
|
||||
"""Test HMAC signature verification with valid signature."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
is_valid = self.auth.verify_hmac_signature(data, signature, secret_key)
|
||||
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
def test_verify_hmac_signature_invalid(self):
|
||||
"""Test HMAC signature verification with invalid signature."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Test with wrong signature
|
||||
wrong_signature = "0" * 64
|
||||
is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
# Test with wrong key
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
wrong_key = self.auth._generate_secret_key()
|
||||
is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_encrypt_decrypt_secret_key(self):
|
||||
"""Test secret key encryption and decryption."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
password = "test_password"
|
||||
|
||||
encrypted = self.auth.encrypt_secret_key(secret_key, password)
|
||||
decrypted = self.auth.decrypt_secret_key(encrypted, password)
|
||||
|
||||
self.assertEqual(secret_key, decrypted)
|
||||
|
||||
def test_encrypt_decrypt_wrong_password(self):
|
||||
"""Test secret key decryption with wrong password."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
password = "test_password"
|
||||
wrong_password = "wrong_password"
|
||||
|
||||
encrypted = self.auth.encrypt_secret_key(secret_key, password)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.auth.decrypt_secret_key(encrypted, wrong_password)
|
||||
|
||||
@patch('src.auth.Database')
|
||||
def test_authenticate_agent_success(self, mock_db_class):
|
||||
"""Test successful agent authentication."""
|
||||
# Mock database
|
||||
mock_db = Mock()
|
||||
mock_db_class.return_value = mock_db
|
||||
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
key_hash = hashlib.sha256(secret_key.encode()).hexdigest()
|
||||
|
||||
# Mock database response
|
||||
mock_db.get_agent_credentials.return_value = {
|
||||
'agent_id': agent_id,
|
||||
'key_hash': key_hash,
|
||||
'is_active': True,
|
||||
'created_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
result = self.auth.authenticate_agent(agent_id, secret_key)
|
||||
self.assertTrue(result)
|
||||
|
||||
@patch('src.auth.Database')
|
||||
def test_authenticate_agent_failure(self, mock_db_class):
|
||||
"""Test failed agent authentication."""
|
||||
# Mock database
|
||||
mock_db = Mock()
|
||||
mock_db_class.return_value = mock_db
|
||||
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Mock database response - no credentials found
|
||||
mock_db.get_agent_credentials.return_value = None
|
||||
|
||||
result = self.auth.authenticate_agent(agent_id, secret_key)
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
"""Test cases for database operations."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.db = Storage(self.db_path)
|
||||
self.db.create_tables()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_create_agent_auth(self):
|
||||
"""Test agent authentication record creation."""
|
||||
agent_id = "agent_test123"
|
||||
secret_key_hash = "test_hash"
|
||||
encrypted_key = "encrypted_test_key"
|
||||
|
||||
success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
|
||||
self.assertTrue(success)
|
||||
|
||||
# Verify record exists
|
||||
credentials = self.db.get_agent_credentials(agent_id)
|
||||
self.assertIsNotNone(credentials)
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
self.assertEqual(credentials['key_hash'], secret_key_hash)
|
||||
|
||||
def test_get_agent_credentials_exists(self):
|
||||
"""Test retrieving existing agent credentials."""
|
||||
agent_id = "agent_test123"
|
||||
secret_key_hash = "test_hash"
|
||||
encrypted_key = "encrypted_test_key"
|
||||
|
||||
# Create record
|
||||
self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
|
||||
|
||||
# Retrieve record
|
||||
credentials = self.db.get_agent_credentials(agent_id)
|
||||
|
||||
self.assertIsNotNone(credentials)
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
self.assertEqual(credentials['key_hash'], secret_key_hash)
|
||||
self.assertTrue(credentials['is_active'])
|
||||
|
||||
def test_get_agent_credentials_not_exists(self):
|
||||
"""Test retrieving non-existent agent credentials."""
|
||||
credentials = self.db.get_agent_credentials("non_existent_agent")
|
||||
self.assertIsNone(credentials)
|
||||
|
||||
def test_store_agent_token(self):
|
||||
"""Test storing agent JWT token."""
|
||||
agent_id = "agent_test123"
|
||||
token = "test_jwt_token"
|
||||
expires_at = (datetime.now() + timedelta(hours=1)).isoformat()
|
||||
|
||||
success = self.db.store_agent_token(agent_id, token, expires_at)
|
||||
self.assertTrue(success)
|
||||
|
||||
# Verify token exists
|
||||
stored_token = self.db.get_agent_token(agent_id)
|
||||
self.assertIsNotNone(stored_token)
|
||||
self.assertEqual(stored_token['token'], token)
|
||||
|
||||
def test_cleanup_expired_tokens(self):
|
||||
"""Test cleanup of expired tokens."""
|
||||
agent_id = "agent_test123"
|
||||
|
||||
# Create expired token
|
||||
expired_token = "expired_token"
|
||||
expired_time = (datetime.now() - timedelta(hours=1)).isoformat()
|
||||
self.db.store_agent_token(agent_id, expired_token, expired_time)
|
||||
|
||||
# Create valid token
|
||||
valid_token = "valid_token"
|
||||
valid_time = (datetime.now() + timedelta(hours=1)).isoformat()
|
||||
self.db.store_agent_token("agent_valid", valid_token, valid_time)
|
||||
|
||||
# Cleanup expired tokens
|
||||
cleaned = self.db.cleanup_expired_tokens()
|
||||
self.assertGreaterEqual(cleaned, 1)
|
||||
|
||||
# Verify expired token is gone
|
||||
token = self.db.get_agent_token(agent_id)
|
||||
self.assertIsNone(token)
|
||||
|
||||
# Verify valid token remains
|
||||
token = self.db.get_agent_token("agent_valid")
|
||||
self.assertIsNotNone(token)
|
||||
|
||||
|
||||
class TestIntegration(unittest.TestCase):
|
||||
"""Integration tests for the complete authentication flow."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.auth = AgentAuthentication()
|
||||
|
||||
# Use test database
|
||||
self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_complete_authentication_flow(self):
|
||||
"""Test complete agent authentication workflow."""
|
||||
# Step 1: Generate agent ID
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
self.assertIsNotNone(agent_id)
|
||||
|
||||
# Step 2: Create credentials
|
||||
credentials = self.auth.create_agent_credentials(agent_id)
|
||||
self.assertIsNotNone(credentials)
|
||||
|
||||
# Step 3: Generate JWT token
|
||||
token = self.auth.generate_jwt_token(
|
||||
credentials['agent_id'],
|
||||
credentials['secret_key']
|
||||
)
|
||||
self.assertIsNotNone(token)
|
||||
|
||||
# Step 4: Verify token
|
||||
is_valid = self.auth.verify_jwt_token(token, credentials['secret_key'])
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
# Step 5: Create HMAC signature
|
||||
test_data = "test API request"
|
||||
signature = self.auth.create_hmac_signature(test_data, credentials['secret_key'])
|
||||
self.assertIsNotNone(signature)
|
||||
|
||||
# Step 6: Verify HMAC signature
|
||||
is_signature_valid = self.auth.verify_hmac_signature(
|
||||
test_data, signature, credentials['secret_key']
|
||||
)
|
||||
self.assertTrue(is_signature_valid)
|
||||
|
||||
|
||||
def run_tests():
|
||||
"""Run all tests."""
|
||||
print("🧪 Running PyGuardian Authentication Tests...")
|
||||
print("=" * 50)
|
||||
|
||||
# Create test suite
|
||||
test_suite = unittest.TestSuite()
|
||||
|
||||
# Add test classes
|
||||
test_classes = [
|
||||
TestAgentAuthentication,
|
||||
TestDatabase,
|
||||
TestIntegration
|
||||
]
|
||||
|
||||
for test_class in test_classes:
|
||||
tests = unittest.TestLoader().loadTestsFromTestCase(test_class)
|
||||
test_suite.addTests(tests)
|
||||
|
||||
# Run tests
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(test_suite)
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 50)
|
||||
print(f"🏁 Tests completed:")
|
||||
print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}")
|
||||
print(f" ❌ Failed: {len(result.failures)}")
|
||||
print(f" 💥 Errors: {len(result.errors)}")
|
||||
|
||||
# Return exit code
|
||||
return 0 if result.wasSuccessful() else 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(run_tests())
|
||||
422
.history/tests/unit/test_authentication_20251125213251.py
Normal file
422
.history/tests/unit/test_authentication_20251125213251.py
Normal file
@@ -0,0 +1,422 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive unit tests for PyGuardian authentication system.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import tempfile
|
||||
import os
|
||||
import sys
|
||||
import sqlite3
|
||||
import jwt
|
||||
import hashlib
|
||||
import hmac
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
|
||||
# Add src directory to path for imports
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))
|
||||
|
||||
from auth import AgentAuthentication
|
||||
from storage import Storage
|
||||
|
||||
|
||||
class TestAgentAuthentication(unittest.TestCase):
|
||||
"""Test cases for agent authentication system."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.test_secret = 'test_secret_key_123'
|
||||
self.auth = AgentAuthentication(self.test_secret)
|
||||
|
||||
# Create test database
|
||||
self.db = Storage(self.db_path)
|
||||
self.db.create_tables()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_generate_agent_id(self):
|
||||
"""Test agent ID generation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
|
||||
# Check format
|
||||
self.assertTrue(agent_id.startswith('agent_'))
|
||||
self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID
|
||||
|
||||
# Test uniqueness
|
||||
agent_id2 = self.auth.generate_agent_id()
|
||||
self.assertNotEqual(agent_id, agent_id2)
|
||||
|
||||
def test_create_agent_credentials(self):
|
||||
"""Test agent credentials creation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
credentials = self.auth.create_agent_credentials(agent_id)
|
||||
|
||||
# Check required fields
|
||||
required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash']
|
||||
for field in required_fields:
|
||||
self.assertIn(field, credentials)
|
||||
|
||||
# Check agent ID matches
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
|
||||
# Check secret key length
|
||||
self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded
|
||||
|
||||
# Check key hash
|
||||
expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest()
|
||||
self.assertEqual(credentials['key_hash'], expected_hash)
|
||||
|
||||
def test_generate_jwt_token(self):
|
||||
"""Test JWT token generation."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
|
||||
# Verify token structure
|
||||
self.assertIsInstance(token, str)
|
||||
self.assertTrue(len(token) > 100) # JWT tokens are typically long
|
||||
|
||||
# Decode and verify payload
|
||||
decoded = jwt.decode(token, secret_key, algorithms=['HS256'])
|
||||
self.assertEqual(decoded['agent_id'], agent_id)
|
||||
self.assertIn('iat', decoded)
|
||||
self.assertIn('exp', decoded)
|
||||
self.assertIn('jti', decoded)
|
||||
|
||||
def test_verify_jwt_token_valid(self):
|
||||
"""Test JWT token verification with valid token."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(token, secret_key)
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
def test_verify_jwt_token_invalid(self):
|
||||
"""Test JWT token verification with invalid token."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Test with invalid token
|
||||
is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
# Test with wrong secret key
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
token = self.auth.generate_jwt_token(agent_id, secret_key)
|
||||
wrong_key = self.auth._generate_secret_key()
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(token, wrong_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_verify_jwt_token_expired(self):
|
||||
"""Test JWT token verification with expired token."""
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Create expired token
|
||||
payload = {
|
||||
'agent_id': agent_id,
|
||||
'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago
|
||||
'iat': datetime.utcnow() - timedelta(hours=2),
|
||||
'jti': self.auth._generate_jti()
|
||||
}
|
||||
|
||||
expired_token = jwt.encode(payload, secret_key, algorithm='HS256')
|
||||
|
||||
is_valid = self.auth.verify_jwt_token(expired_token, secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_create_hmac_signature(self):
|
||||
"""Test HMAC signature creation."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
|
||||
# Verify signature format
|
||||
self.assertEqual(len(signature), 64) # SHA256 hex digest
|
||||
|
||||
# Verify signature is correct
|
||||
expected = hmac.new(
|
||||
secret_key.encode(),
|
||||
data.encode(),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
self.assertEqual(signature, expected)
|
||||
|
||||
def test_verify_hmac_signature_valid(self):
|
||||
"""Test HMAC signature verification with valid signature."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
is_valid = self.auth.verify_hmac_signature(data, signature, secret_key)
|
||||
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
def test_verify_hmac_signature_invalid(self):
|
||||
"""Test HMAC signature verification with invalid signature."""
|
||||
data = "test message"
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Test with wrong signature
|
||||
wrong_signature = "0" * 64
|
||||
is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
# Test with wrong key
|
||||
signature = self.auth.create_hmac_signature(data, secret_key)
|
||||
wrong_key = self.auth._generate_secret_key()
|
||||
is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key)
|
||||
self.assertFalse(is_valid)
|
||||
|
||||
def test_encrypt_decrypt_secret_key(self):
|
||||
"""Test secret key encryption and decryption."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
password = "test_password"
|
||||
|
||||
encrypted = self.auth.encrypt_secret_key(secret_key, password)
|
||||
decrypted = self.auth.decrypt_secret_key(encrypted, password)
|
||||
|
||||
self.assertEqual(secret_key, decrypted)
|
||||
|
||||
def test_encrypt_decrypt_wrong_password(self):
|
||||
"""Test secret key decryption with wrong password."""
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
password = "test_password"
|
||||
wrong_password = "wrong_password"
|
||||
|
||||
encrypted = self.auth.encrypt_secret_key(secret_key, password)
|
||||
|
||||
with self.assertRaises(Exception):
|
||||
self.auth.decrypt_secret_key(encrypted, wrong_password)
|
||||
|
||||
@patch('src.auth.Database')
|
||||
def test_authenticate_agent_success(self, mock_db_class):
|
||||
"""Test successful agent authentication."""
|
||||
# Mock database
|
||||
mock_db = Mock()
|
||||
mock_db_class.return_value = mock_db
|
||||
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
key_hash = hashlib.sha256(secret_key.encode()).hexdigest()
|
||||
|
||||
# Mock database response
|
||||
mock_db.get_agent_credentials.return_value = {
|
||||
'agent_id': agent_id,
|
||||
'key_hash': key_hash,
|
||||
'is_active': True,
|
||||
'created_at': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
result = self.auth.authenticate_agent(agent_id, secret_key)
|
||||
self.assertTrue(result)
|
||||
|
||||
@patch('src.auth.Database')
|
||||
def test_authenticate_agent_failure(self, mock_db_class):
|
||||
"""Test failed agent authentication."""
|
||||
# Mock database
|
||||
mock_db = Mock()
|
||||
mock_db_class.return_value = mock_db
|
||||
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
secret_key = self.auth._generate_secret_key()
|
||||
|
||||
# Mock database response - no credentials found
|
||||
mock_db.get_agent_credentials.return_value = None
|
||||
|
||||
result = self.auth.authenticate_agent(agent_id, secret_key)
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
"""Test cases for database operations."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.db = Storage(self.db_path)
|
||||
self.db.create_tables()
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_create_agent_auth(self):
|
||||
"""Test agent authentication record creation."""
|
||||
agent_id = "agent_test123"
|
||||
secret_key_hash = "test_hash"
|
||||
encrypted_key = "encrypted_test_key"
|
||||
|
||||
success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
|
||||
self.assertTrue(success)
|
||||
|
||||
# Verify record exists
|
||||
credentials = self.db.get_agent_credentials(agent_id)
|
||||
self.assertIsNotNone(credentials)
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
self.assertEqual(credentials['key_hash'], secret_key_hash)
|
||||
|
||||
def test_get_agent_credentials_exists(self):
|
||||
"""Test retrieving existing agent credentials."""
|
||||
agent_id = "agent_test123"
|
||||
secret_key_hash = "test_hash"
|
||||
encrypted_key = "encrypted_test_key"
|
||||
|
||||
# Create record
|
||||
self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
|
||||
|
||||
# Retrieve record
|
||||
credentials = self.db.get_agent_credentials(agent_id)
|
||||
|
||||
self.assertIsNotNone(credentials)
|
||||
self.assertEqual(credentials['agent_id'], agent_id)
|
||||
self.assertEqual(credentials['key_hash'], secret_key_hash)
|
||||
self.assertTrue(credentials['is_active'])
|
||||
|
||||
def test_get_agent_credentials_not_exists(self):
|
||||
"""Test retrieving non-existent agent credentials."""
|
||||
credentials = self.db.get_agent_credentials("non_existent_agent")
|
||||
self.assertIsNone(credentials)
|
||||
|
||||
def test_store_agent_token(self):
|
||||
"""Test storing agent JWT token."""
|
||||
agent_id = "agent_test123"
|
||||
token = "test_jwt_token"
|
||||
expires_at = (datetime.now() + timedelta(hours=1)).isoformat()
|
||||
|
||||
success = self.db.store_agent_token(agent_id, token, expires_at)
|
||||
self.assertTrue(success)
|
||||
|
||||
# Verify token exists
|
||||
stored_token = self.db.get_agent_token(agent_id)
|
||||
self.assertIsNotNone(stored_token)
|
||||
self.assertEqual(stored_token['token'], token)
|
||||
|
||||
def test_cleanup_expired_tokens(self):
|
||||
"""Test cleanup of expired tokens."""
|
||||
agent_id = "agent_test123"
|
||||
|
||||
# Create expired token
|
||||
expired_token = "expired_token"
|
||||
expired_time = (datetime.now() - timedelta(hours=1)).isoformat()
|
||||
self.db.store_agent_token(agent_id, expired_token, expired_time)
|
||||
|
||||
# Create valid token
|
||||
valid_token = "valid_token"
|
||||
valid_time = (datetime.now() + timedelta(hours=1)).isoformat()
|
||||
self.db.store_agent_token("agent_valid", valid_token, valid_time)
|
||||
|
||||
# Cleanup expired tokens
|
||||
cleaned = self.db.cleanup_expired_tokens()
|
||||
self.assertGreaterEqual(cleaned, 1)
|
||||
|
||||
# Verify expired token is gone
|
||||
token = self.db.get_agent_token(agent_id)
|
||||
self.assertIsNone(token)
|
||||
|
||||
# Verify valid token remains
|
||||
token = self.db.get_agent_token("agent_valid")
|
||||
self.assertIsNotNone(token)
|
||||
|
||||
|
||||
class TestIntegration(unittest.TestCase):
|
||||
"""Integration tests for the complete authentication flow."""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.auth = AgentAuthentication()
|
||||
|
||||
# Use test database
|
||||
self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up test fixtures."""
|
||||
if os.path.exists(self.db_path):
|
||||
os.remove(self.db_path)
|
||||
os.rmdir(self.temp_dir)
|
||||
|
||||
def test_complete_authentication_flow(self):
|
||||
"""Test complete agent authentication workflow."""
|
||||
# Step 1: Generate agent ID
|
||||
agent_id = self.auth.generate_agent_id()
|
||||
self.assertIsNotNone(agent_id)
|
||||
|
||||
# Step 2: Create credentials
|
||||
credentials = self.auth.create_agent_credentials(agent_id)
|
||||
self.assertIsNotNone(credentials)
|
||||
|
||||
# Step 3: Generate JWT token
|
||||
token = self.auth.generate_jwt_token(
|
||||
credentials['agent_id'],
|
||||
credentials['secret_key']
|
||||
)
|
||||
self.assertIsNotNone(token)
|
||||
|
||||
# Step 4: Verify token
|
||||
is_valid = self.auth.verify_jwt_token(token, credentials['secret_key'])
|
||||
self.assertTrue(is_valid)
|
||||
|
||||
# Step 5: Create HMAC signature
|
||||
test_data = "test API request"
|
||||
signature = self.auth.create_hmac_signature(test_data, credentials['secret_key'])
|
||||
self.assertIsNotNone(signature)
|
||||
|
||||
# Step 6: Verify HMAC signature
|
||||
is_signature_valid = self.auth.verify_hmac_signature(
|
||||
test_data, signature, credentials['secret_key']
|
||||
)
|
||||
self.assertTrue(is_signature_valid)
|
||||
|
||||
|
||||
def run_tests():
|
||||
"""Run all tests."""
|
||||
print("🧪 Running PyGuardian Authentication Tests...")
|
||||
print("=" * 50)
|
||||
|
||||
# Create test suite
|
||||
test_suite = unittest.TestSuite()
|
||||
|
||||
# Add test classes
|
||||
test_classes = [
|
||||
TestAgentAuthentication,
|
||||
TestDatabase,
|
||||
TestIntegration
|
||||
]
|
||||
|
||||
for test_class in test_classes:
|
||||
tests = unittest.TestLoader().loadTestsFromTestCase(test_class)
|
||||
test_suite.addTests(tests)
|
||||
|
||||
# Run tests
|
||||
runner = unittest.TextTestRunner(verbosity=2)
|
||||
result = runner.run(test_suite)
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 50)
|
||||
print(f"🏁 Tests completed:")
|
||||
print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}")
|
||||
print(f" ❌ Failed: {len(result.failures)}")
|
||||
print(f" 💥 Errors: {len(result.errors)}")
|
||||
|
||||
# Return exit code
|
||||
return 0 if result.wasSuccessful() else 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(run_tests())
|
||||
@@ -7,7 +7,7 @@ import asyncio
|
||||
import logging
|
||||
import re
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional
|
||||
import psutil
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from unittest.mock import Mock, patch, MagicMock
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))
|
||||
|
||||
from auth import AgentAuthentication
|
||||
from storage import Database
|
||||
from storage import Storage
|
||||
|
||||
|
||||
class TestAgentAuthentication(unittest.TestCase):
|
||||
@@ -28,10 +28,11 @@ class TestAgentAuthentication(unittest.TestCase):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.auth = AgentAuthentication()
|
||||
self.test_secret = 'test_secret_key_123'
|
||||
self.auth = AgentAuthentication(self.test_secret)
|
||||
|
||||
# Create test database
|
||||
self.db = Database(self.db_path)
|
||||
self.db = Storage(self.db_path)
|
||||
self.db.create_tables()
|
||||
|
||||
def tearDown(self):
|
||||
@@ -245,7 +246,7 @@ class TestDatabase(unittest.TestCase):
|
||||
"""Set up test fixtures."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
|
||||
self.db = Database(self.db_path)
|
||||
self.db = Storage(self.db_path)
|
||||
self.db.create_tables()
|
||||
|
||||
def tearDown(self):
|
||||
|
||||
Reference in New Issue
Block a user