fix: Resolve import issues and test compatibility
Some checks reported errors
continuous-integration/drone/push Build was killed

- Fix Storage class reference in authentication tests
- Add secret_key parameter to AgentAuthentication initialization
- Fix timedelta import in sessions.py
- Basic authentication functionality verified
This commit is contained in:
2025-11-25 21:33:17 +09:00
parent d00fc9fd61
commit 9f2cc216d5
10 changed files with 3510 additions and 5 deletions

View File

@@ -0,0 +1,488 @@
"""
Sessions module для PyGuardian
Управление SSH сессиями и процессами пользователей
"""
import asyncio
import logging
import re
import os
from datetime import datetime, timedelta
from typing import Dict, List, Optional
import psutil
logger = logging.getLogger(__name__)
class SessionManager:
"""Менеджер SSH сессий и пользовательских процессов"""
def __init__(self):
pass
async def get_active_sessions(self) -> List[Dict]:
"""Получение всех активных SSH сессий"""
try:
sessions = []
# Метод 1: через who
who_sessions = await self._get_sessions_via_who()
sessions.extend(who_sessions)
# Метод 2: через ps (для SSH процессов)
ssh_sessions = await self._get_sessions_via_ps()
sessions.extend(ssh_sessions)
# Убираем дубликаты и объединяем информацию
unique_sessions = self._merge_session_info(sessions)
return unique_sessions
except Exception as e:
logger.error(f"Ошибка получения активных сессий: {e}")
return []
async def _get_sessions_via_who(self) -> List[Dict]:
"""Получение сессий через команду who"""
try:
sessions = []
process = await asyncio.create_subprocess_exec(
'who', '-u',
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode == 0:
lines = stdout.decode().strip().split('\n')
for line in lines:
if line.strip():
# Парсим вывод who
# Формат: user tty date time (idle) pid (comment)
match = re.match(
r'(\w+)\s+(\w+)\s+(\d{4}-\d{2}-\d{2})\s+(\d{2}:\d{2})\s+.*?\((\d+)\)',
line
)
if match:
username, tty, date, time, pid = match.groups()
sessions.append({
'username': username,
'tty': tty,
'login_date': date,
'login_time': time,
'pid': int(pid),
'type': 'who',
'status': 'active'
})
else:
# Альтернативный парсинг для разных форматов who
parts = line.split()
if len(parts) >= 2:
username = parts[0]
tty = parts[1]
# Ищем PID в скобках
pid_match = re.search(r'\((\d+)\)', line)
pid = int(pid_match.group(1)) if pid_match else None
sessions.append({
'username': username,
'tty': tty,
'pid': pid,
'type': 'who',
'status': 'active',
'raw_line': line
})
return sessions
except Exception as e:
logger.error(f"Ошибка получения сессий через who: {e}")
return []
async def _get_sessions_via_ps(self) -> List[Dict]:
"""Получение SSH сессий через ps"""
try:
sessions = []
# Ищем SSH процессы
process = await asyncio.create_subprocess_exec(
'ps', 'aux',
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode == 0:
lines = stdout.decode().strip().split('\n')
for line in lines[1:]: # Пропускаем заголовок
if 'sshd:' in line and '@pts' in line:
# Парсим SSH сессии
parts = line.split()
if len(parts) >= 11:
username = parts[0]
pid = int(parts[1])
# Извлекаем информацию из команды
cmd_parts = ' '.join(parts[10:])
# Ищем пользователя и tty в команде sshd
match = re.search(r'sshd:\s+(\w+)@(\w+)', cmd_parts)
if match:
ssh_user, tty = match.groups()
sessions.append({
'username': ssh_user,
'tty': tty,
'pid': pid,
'ppid': int(parts[2]),
'cpu': parts[2],
'mem': parts[3],
'start_time': parts[8],
'type': 'sshd',
'status': 'active',
'command': cmd_parts
})
return sessions
except Exception as e:
logger.error(f"Ошибка получения SSH сессий через ps: {e}")
return []
def _merge_session_info(self, sessions: List[Dict]) -> List[Dict]:
"""Объединение информации о сессиях и удаление дубликатов"""
try:
merged = {}
for session in sessions:
key = f"{session['username']}:{session.get('tty', 'unknown')}"
if key in merged:
# Обновляем существующую запись дополнительной информацией
merged[key].update({k: v for k, v in session.items() if v is not None})
else:
merged[key] = session.copy()
# Добавляем дополнительную информацию о процессах
for session in merged.values():
if session.get('pid'):
try:
# Получаем дополнительную информацию о процессе через psutil
if psutil.pid_exists(session['pid']):
proc = psutil.Process(session['pid'])
session.update({
'create_time': datetime.fromtimestamp(proc.create_time()).isoformat(),
'cpu_percent': proc.cpu_percent(),
'memory_info': proc.memory_info()._asdict(),
'connections': len(proc.connections())
})
except Exception:
pass # Игнорируем ошибки получения доп. информации
return list(merged.values())
except Exception as e:
logger.error(f"Ошибка объединения информации о сессиях: {e}")
return sessions
async def get_user_sessions(self, username: str) -> List[Dict]:
"""Получение сессий конкретного пользователя"""
try:
all_sessions = await self.get_active_sessions()
return [s for s in all_sessions if s['username'] == username]
except Exception as e:
logger.error(f"Ошибка получения сессий пользователя {username}: {e}")
return []
async def terminate_session(self, pid: int) -> bool:
"""Завершение сессии по PID"""
try:
# Сначала пробуем TERM
process = await asyncio.create_subprocess_exec(
'kill', '-TERM', str(pid),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
await process.communicate()
if process.returncode == 0:
# Ждем немного и проверяем
await asyncio.sleep(2)
if not psutil.pid_exists(pid):
logger.info(f"✅ Сессия PID {pid} завершена через TERM")
return True
else:
# Если не помогло - используем KILL
process = await asyncio.create_subprocess_exec(
'kill', '-KILL', str(pid),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
await process.communicate()
if process.returncode == 0:
logger.info(f"🔪 Сессия PID {pid} принудительно завершена через KILL")
return True
logger.error(f"Не удалось завершить сессию PID {pid}")
return False
except Exception as e:
logger.error(f"Ошибка завершения сессии PID {pid}: {e}")
return False
async def terminate_user_sessions(self, username: str) -> int:
"""Завершение всех сессий пользователя"""
try:
user_sessions = await self.get_user_sessions(username)
terminated = 0
for session in user_sessions:
pid = session.get('pid')
if pid:
success = await self.terminate_session(pid)
if success:
terminated += 1
logger.info(f"Завершено {terminated} из {len(user_sessions)} сессий пользователя {username}")
return terminated
except Exception as e:
logger.error(f"Ошибка завершения сессий пользователя {username}: {e}")
return 0
async def get_session_details(self, pid: int) -> Optional[Dict]:
"""Получение детальной информации о сессии"""
try:
if not psutil.pid_exists(pid):
return None
proc = psutil.Process(pid)
# Базовая информация о процессе
details = {
'pid': pid,
'ppid': proc.ppid(),
'username': proc.username(),
'create_time': datetime.fromtimestamp(proc.create_time()).isoformat(),
'cpu_percent': proc.cpu_percent(),
'memory_info': proc.memory_info()._asdict(),
'status': proc.status(),
'cmdline': proc.cmdline(),
'cwd': proc.cwd(),
'exe': proc.exe()
}
# Сетевые соединения
try:
connections = []
for conn in proc.connections():
connections.append({
'fd': conn.fd,
'family': str(conn.family),
'type': str(conn.type),
'local_address': f"{conn.laddr.ip}:{conn.laddr.port}" if conn.laddr else None,
'remote_address': f"{conn.raddr.ip}:{conn.raddr.port}" if conn.raddr else None,
'status': str(conn.status)
})
details['connections'] = connections
except Exception:
details['connections'] = []
# Открытые файлы
try:
open_files = []
for file in proc.open_files()[:10]: # Ограничиваем 10 файлами
open_files.append({
'path': file.path,
'fd': file.fd,
'mode': file.mode
})
details['open_files'] = open_files
except Exception:
details['open_files'] = []
# Переменные окружения (выборочно)
try:
env = proc.environ()
safe_env = {}
safe_keys = ['USER', 'HOME', 'SHELL', 'SSH_CLIENT', 'SSH_CONNECTION', 'TERM']
for key in safe_keys:
if key in env:
safe_env[key] = env[key]
details['environment'] = safe_env
except Exception:
details['environment'] = {}
return details
except Exception as e:
logger.error(f"Ошибка получения деталей сессии PID {pid}: {e}")
return None
async def monitor_session_activity(self, pid: int, duration: int = 60) -> List[Dict]:
"""Мониторинг активности сессии в течение времени"""
try:
if not psutil.pid_exists(pid):
return []
activity_log = []
proc = psutil.Process(pid)
start_time = datetime.now()
end_time = start_time + timedelta(seconds=duration)
while datetime.now() < end_time:
try:
# Снимок состояния процесса
snapshot = {
'timestamp': datetime.now().isoformat(),
'cpu_percent': proc.cpu_percent(),
'memory_percent': proc.memory_percent(),
'num_threads': proc.num_threads(),
'num_fds': proc.num_fds(),
'status': proc.status()
}
# Проверяем новые соединения
try:
connections = len(proc.connections())
snapshot['connections_count'] = connections
except Exception:
snapshot['connections_count'] = 0
activity_log.append(snapshot)
await asyncio.sleep(5) # Снимок каждые 5 секунд
except psutil.NoSuchProcess:
# Процесс завершился
activity_log.append({
'timestamp': datetime.now().isoformat(),
'event': 'process_terminated'
})
break
except Exception as e:
activity_log.append({
'timestamp': datetime.now().isoformat(),
'event': 'monitoring_error',
'error': str(e)
})
return activity_log
except Exception as e:
logger.error(f"Ошибка мониторинга активности сессии PID {pid}: {e}")
return []
async def get_session_statistics(self) -> Dict:
"""Получение общей статистики по сессиям"""
try:
sessions = await self.get_active_sessions()
stats = {
'total_sessions': len(sessions),
'users': {},
'tty_types': {},
'session_ages': [],
'total_connections': 0
}
for session in sessions:
# Статистика по пользователям
user = session['username']
if user not in stats['users']:
stats['users'][user] = 0
stats['users'][user] += 1
# Статистика по типам TTY
tty = session.get('tty', 'unknown')
tty_type = 'console' if tty.startswith('tty') else 'ssh'
if tty_type not in stats['tty_types']:
stats['tty_types'][tty_type] = 0
stats['tty_types'][tty_type] += 1
# Возраст сессии
if 'create_time' in session:
try:
create_time = datetime.fromisoformat(session['create_time'])
age_seconds = (datetime.now() - create_time).total_seconds()
stats['session_ages'].append(age_seconds)
except Exception:
pass
# Количество соединений
connections = session.get('connections', 0)
if isinstance(connections, int):
stats['total_connections'] += connections
# Средний возраст сессий
if stats['session_ages']:
stats['average_session_age'] = sum(stats['session_ages']) / len(stats['session_ages'])
else:
stats['average_session_age'] = 0
return stats
except Exception as e:
logger.error(f"Ошибка получения статистики сессий: {e}")
return {'error': str(e)}
async def find_suspicious_sessions(self) -> List[Dict]:
"""Поиск подозрительных сессий"""
try:
sessions = await self.get_active_sessions()
suspicious = []
for session in sessions:
suspicion_score = 0
reasons = []
# Проверка 1: Много открытых соединений
connections = session.get('connections', 0)
if isinstance(connections, int) and connections > 10:
suspicion_score += 2
reasons.append(f"Много соединений: {connections}")
# Проверка 2: Высокое потребление CPU
cpu = session.get('cpu_percent', 0)
if isinstance(cpu, (int, float)) and cpu > 50:
suspicion_score += 1
reasons.append(f"Высокая нагрузка CPU: {cpu}%")
# Проверка 3: Долго активная сессия
if 'create_time' in session:
try:
create_time = datetime.fromisoformat(session['create_time'])
age_hours = (datetime.now() - create_time).total_seconds() / 3600
if age_hours > 24: # Больше суток
suspicion_score += 1
reasons.append(f"Долгая сессия: {age_hours:.1f} часов")
except Exception:
pass
# Проверка 4: Подозрительные команды в cmdline
cmdline = session.get('cmdline', [])
if isinstance(cmdline, list):
suspicious_commands = ['nc', 'netcat', 'wget', 'curl', 'python', 'perl', 'bash']
for cmd in cmdline:
if any(susp in cmd.lower() for susp in suspicious_commands):
suspicion_score += 1
reasons.append(f"Подозрительная команда: {cmd}")
break
# Если набрали достаточно очков подозрительности
if suspicion_score >= 2:
session['suspicion_score'] = suspicion_score
session['suspicion_reasons'] = reasons
suspicious.append(session)
return suspicious
except Exception as e:
logger.error(f"Ошибка поиска подозрительных сессий: {e}")
return []

View File

@@ -0,0 +1,488 @@
"""
Sessions module для PyGuardian
Управление SSH сессиями и процессами пользователей
"""
import asyncio
import logging
import re
import os
from datetime import datetime, timedelta
from typing import Dict, List, Optional
import psutil
logger = logging.getLogger(__name__)
class SessionManager:
"""Менеджер SSH сессий и пользовательских процессов"""
def __init__(self):
pass
async def get_active_sessions(self) -> List[Dict]:
"""Получение всех активных SSH сессий"""
try:
sessions = []
# Метод 1: через who
who_sessions = await self._get_sessions_via_who()
sessions.extend(who_sessions)
# Метод 2: через ps (для SSH процессов)
ssh_sessions = await self._get_sessions_via_ps()
sessions.extend(ssh_sessions)
# Убираем дубликаты и объединяем информацию
unique_sessions = self._merge_session_info(sessions)
return unique_sessions
except Exception as e:
logger.error(f"Ошибка получения активных сессий: {e}")
return []
async def _get_sessions_via_who(self) -> List[Dict]:
"""Получение сессий через команду who"""
try:
sessions = []
process = await asyncio.create_subprocess_exec(
'who', '-u',
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode == 0:
lines = stdout.decode().strip().split('\n')
for line in lines:
if line.strip():
# Парсим вывод who
# Формат: user tty date time (idle) pid (comment)
match = re.match(
r'(\w+)\s+(\w+)\s+(\d{4}-\d{2}-\d{2})\s+(\d{2}:\d{2})\s+.*?\((\d+)\)',
line
)
if match:
username, tty, date, time, pid = match.groups()
sessions.append({
'username': username,
'tty': tty,
'login_date': date,
'login_time': time,
'pid': int(pid),
'type': 'who',
'status': 'active'
})
else:
# Альтернативный парсинг для разных форматов who
parts = line.split()
if len(parts) >= 2:
username = parts[0]
tty = parts[1]
# Ищем PID в скобках
pid_match = re.search(r'\((\d+)\)', line)
pid = int(pid_match.group(1)) if pid_match else None
sessions.append({
'username': username,
'tty': tty,
'pid': pid,
'type': 'who',
'status': 'active',
'raw_line': line
})
return sessions
except Exception as e:
logger.error(f"Ошибка получения сессий через who: {e}")
return []
async def _get_sessions_via_ps(self) -> List[Dict]:
"""Получение SSH сессий через ps"""
try:
sessions = []
# Ищем SSH процессы
process = await asyncio.create_subprocess_exec(
'ps', 'aux',
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode == 0:
lines = stdout.decode().strip().split('\n')
for line in lines[1:]: # Пропускаем заголовок
if 'sshd:' in line and '@pts' in line:
# Парсим SSH сессии
parts = line.split()
if len(parts) >= 11:
username = parts[0]
pid = int(parts[1])
# Извлекаем информацию из команды
cmd_parts = ' '.join(parts[10:])
# Ищем пользователя и tty в команде sshd
match = re.search(r'sshd:\s+(\w+)@(\w+)', cmd_parts)
if match:
ssh_user, tty = match.groups()
sessions.append({
'username': ssh_user,
'tty': tty,
'pid': pid,
'ppid': int(parts[2]),
'cpu': parts[2],
'mem': parts[3],
'start_time': parts[8],
'type': 'sshd',
'status': 'active',
'command': cmd_parts
})
return sessions
except Exception as e:
logger.error(f"Ошибка получения SSH сессий через ps: {e}")
return []
def _merge_session_info(self, sessions: List[Dict]) -> List[Dict]:
"""Объединение информации о сессиях и удаление дубликатов"""
try:
merged = {}
for session in sessions:
key = f"{session['username']}:{session.get('tty', 'unknown')}"
if key in merged:
# Обновляем существующую запись дополнительной информацией
merged[key].update({k: v for k, v in session.items() if v is not None})
else:
merged[key] = session.copy()
# Добавляем дополнительную информацию о процессах
for session in merged.values():
if session.get('pid'):
try:
# Получаем дополнительную информацию о процессе через psutil
if psutil.pid_exists(session['pid']):
proc = psutil.Process(session['pid'])
session.update({
'create_time': datetime.fromtimestamp(proc.create_time()).isoformat(),
'cpu_percent': proc.cpu_percent(),
'memory_info': proc.memory_info()._asdict(),
'connections': len(proc.connections())
})
except Exception:
pass # Игнорируем ошибки получения доп. информации
return list(merged.values())
except Exception as e:
logger.error(f"Ошибка объединения информации о сессиях: {e}")
return sessions
async def get_user_sessions(self, username: str) -> List[Dict]:
"""Получение сессий конкретного пользователя"""
try:
all_sessions = await self.get_active_sessions()
return [s for s in all_sessions if s['username'] == username]
except Exception as e:
logger.error(f"Ошибка получения сессий пользователя {username}: {e}")
return []
async def terminate_session(self, pid: int) -> bool:
"""Завершение сессии по PID"""
try:
# Сначала пробуем TERM
process = await asyncio.create_subprocess_exec(
'kill', '-TERM', str(pid),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
await process.communicate()
if process.returncode == 0:
# Ждем немного и проверяем
await asyncio.sleep(2)
if not psutil.pid_exists(pid):
logger.info(f"✅ Сессия PID {pid} завершена через TERM")
return True
else:
# Если не помогло - используем KILL
process = await asyncio.create_subprocess_exec(
'kill', '-KILL', str(pid),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
await process.communicate()
if process.returncode == 0:
logger.info(f"🔪 Сессия PID {pid} принудительно завершена через KILL")
return True
logger.error(f"Не удалось завершить сессию PID {pid}")
return False
except Exception as e:
logger.error(f"Ошибка завершения сессии PID {pid}: {e}")
return False
async def terminate_user_sessions(self, username: str) -> int:
"""Завершение всех сессий пользователя"""
try:
user_sessions = await self.get_user_sessions(username)
terminated = 0
for session in user_sessions:
pid = session.get('pid')
if pid:
success = await self.terminate_session(pid)
if success:
terminated += 1
logger.info(f"Завершено {terminated} из {len(user_sessions)} сессий пользователя {username}")
return terminated
except Exception as e:
logger.error(f"Ошибка завершения сессий пользователя {username}: {e}")
return 0
async def get_session_details(self, pid: int) -> Optional[Dict]:
"""Получение детальной информации о сессии"""
try:
if not psutil.pid_exists(pid):
return None
proc = psutil.Process(pid)
# Базовая информация о процессе
details = {
'pid': pid,
'ppid': proc.ppid(),
'username': proc.username(),
'create_time': datetime.fromtimestamp(proc.create_time()).isoformat(),
'cpu_percent': proc.cpu_percent(),
'memory_info': proc.memory_info()._asdict(),
'status': proc.status(),
'cmdline': proc.cmdline(),
'cwd': proc.cwd(),
'exe': proc.exe()
}
# Сетевые соединения
try:
connections = []
for conn in proc.connections():
connections.append({
'fd': conn.fd,
'family': str(conn.family),
'type': str(conn.type),
'local_address': f"{conn.laddr.ip}:{conn.laddr.port}" if conn.laddr else None,
'remote_address': f"{conn.raddr.ip}:{conn.raddr.port}" if conn.raddr else None,
'status': str(conn.status)
})
details['connections'] = connections
except Exception:
details['connections'] = []
# Открытые файлы
try:
open_files = []
for file in proc.open_files()[:10]: # Ограничиваем 10 файлами
open_files.append({
'path': file.path,
'fd': file.fd,
'mode': file.mode
})
details['open_files'] = open_files
except Exception:
details['open_files'] = []
# Переменные окружения (выборочно)
try:
env = proc.environ()
safe_env = {}
safe_keys = ['USER', 'HOME', 'SHELL', 'SSH_CLIENT', 'SSH_CONNECTION', 'TERM']
for key in safe_keys:
if key in env:
safe_env[key] = env[key]
details['environment'] = safe_env
except Exception:
details['environment'] = {}
return details
except Exception as e:
logger.error(f"Ошибка получения деталей сессии PID {pid}: {e}")
return None
async def monitor_session_activity(self, pid: int, duration: int = 60) -> List[Dict]:
"""Мониторинг активности сессии в течение времени"""
try:
if not psutil.pid_exists(pid):
return []
activity_log = []
proc = psutil.Process(pid)
start_time = datetime.now()
end_time = start_time + timedelta(seconds=duration)
while datetime.now() < end_time:
try:
# Снимок состояния процесса
snapshot = {
'timestamp': datetime.now().isoformat(),
'cpu_percent': proc.cpu_percent(),
'memory_percent': proc.memory_percent(),
'num_threads': proc.num_threads(),
'num_fds': proc.num_fds(),
'status': proc.status()
}
# Проверяем новые соединения
try:
connections = len(proc.connections())
snapshot['connections_count'] = connections
except Exception:
snapshot['connections_count'] = 0
activity_log.append(snapshot)
await asyncio.sleep(5) # Снимок каждые 5 секунд
except psutil.NoSuchProcess:
# Процесс завершился
activity_log.append({
'timestamp': datetime.now().isoformat(),
'event': 'process_terminated'
})
break
except Exception as e:
activity_log.append({
'timestamp': datetime.now().isoformat(),
'event': 'monitoring_error',
'error': str(e)
})
return activity_log
except Exception as e:
logger.error(f"Ошибка мониторинга активности сессии PID {pid}: {e}")
return []
async def get_session_statistics(self) -> Dict:
"""Получение общей статистики по сессиям"""
try:
sessions = await self.get_active_sessions()
stats = {
'total_sessions': len(sessions),
'users': {},
'tty_types': {},
'session_ages': [],
'total_connections': 0
}
for session in sessions:
# Статистика по пользователям
user = session['username']
if user not in stats['users']:
stats['users'][user] = 0
stats['users'][user] += 1
# Статистика по типам TTY
tty = session.get('tty', 'unknown')
tty_type = 'console' if tty.startswith('tty') else 'ssh'
if tty_type not in stats['tty_types']:
stats['tty_types'][tty_type] = 0
stats['tty_types'][tty_type] += 1
# Возраст сессии
if 'create_time' in session:
try:
create_time = datetime.fromisoformat(session['create_time'])
age_seconds = (datetime.now() - create_time).total_seconds()
stats['session_ages'].append(age_seconds)
except Exception:
pass
# Количество соединений
connections = session.get('connections', 0)
if isinstance(connections, int):
stats['total_connections'] += connections
# Средний возраст сессий
if stats['session_ages']:
stats['average_session_age'] = sum(stats['session_ages']) / len(stats['session_ages'])
else:
stats['average_session_age'] = 0
return stats
except Exception as e:
logger.error(f"Ошибка получения статистики сессий: {e}")
return {'error': str(e)}
async def find_suspicious_sessions(self) -> List[Dict]:
"""Поиск подозрительных сессий"""
try:
sessions = await self.get_active_sessions()
suspicious = []
for session in sessions:
suspicion_score = 0
reasons = []
# Проверка 1: Много открытых соединений
connections = session.get('connections', 0)
if isinstance(connections, int) and connections > 10:
suspicion_score += 2
reasons.append(f"Много соединений: {connections}")
# Проверка 2: Высокое потребление CPU
cpu = session.get('cpu_percent', 0)
if isinstance(cpu, (int, float)) and cpu > 50:
suspicion_score += 1
reasons.append(f"Высокая нагрузка CPU: {cpu}%")
# Проверка 3: Долго активная сессия
if 'create_time' in session:
try:
create_time = datetime.fromisoformat(session['create_time'])
age_hours = (datetime.now() - create_time).total_seconds() / 3600
if age_hours > 24: # Больше суток
suspicion_score += 1
reasons.append(f"Долгая сессия: {age_hours:.1f} часов")
except Exception:
pass
# Проверка 4: Подозрительные команды в cmdline
cmdline = session.get('cmdline', [])
if isinstance(cmdline, list):
suspicious_commands = ['nc', 'netcat', 'wget', 'curl', 'python', 'perl', 'bash']
for cmd in cmdline:
if any(susp in cmd.lower() for susp in suspicious_commands):
suspicion_score += 1
reasons.append(f"Подозрительная команда: {cmd}")
break
# Если набрали достаточно очков подозрительности
if suspicion_score >= 2:
session['suspicion_score'] = suspicion_score
session['suspicion_reasons'] = reasons
suspicious.append(session)
return suspicious
except Exception as e:
logger.error(f"Ошибка поиска подозрительных сессий: {e}")
return []

View File

@@ -0,0 +1,421 @@
#!/usr/bin/env python3
"""
Comprehensive unit tests for PyGuardian authentication system.
"""
import unittest
import tempfile
import os
import sys
import sqlite3
import jwt
import hashlib
import hmac
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, MagicMock
# Add src directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))
from auth import AgentAuthentication
from storage import Storage
class TestAgentAuthentication(unittest.TestCase):
"""Test cases for agent authentication system."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.auth = AgentAuthentication()
# Create test database
self.db = Database(self.db_path)
self.db.create_tables()
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_generate_agent_id(self):
"""Test agent ID generation."""
agent_id = self.auth.generate_agent_id()
# Check format
self.assertTrue(agent_id.startswith('agent_'))
self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID
# Test uniqueness
agent_id2 = self.auth.generate_agent_id()
self.assertNotEqual(agent_id, agent_id2)
def test_create_agent_credentials(self):
"""Test agent credentials creation."""
agent_id = self.auth.generate_agent_id()
credentials = self.auth.create_agent_credentials(agent_id)
# Check required fields
required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash']
for field in required_fields:
self.assertIn(field, credentials)
# Check agent ID matches
self.assertEqual(credentials['agent_id'], agent_id)
# Check secret key length
self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded
# Check key hash
expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest()
self.assertEqual(credentials['key_hash'], expected_hash)
def test_generate_jwt_token(self):
"""Test JWT token generation."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
token = self.auth.generate_jwt_token(agent_id, secret_key)
# Verify token structure
self.assertIsInstance(token, str)
self.assertTrue(len(token) > 100) # JWT tokens are typically long
# Decode and verify payload
decoded = jwt.decode(token, secret_key, algorithms=['HS256'])
self.assertEqual(decoded['agent_id'], agent_id)
self.assertIn('iat', decoded)
self.assertIn('exp', decoded)
self.assertIn('jti', decoded)
def test_verify_jwt_token_valid(self):
"""Test JWT token verification with valid token."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
token = self.auth.generate_jwt_token(agent_id, secret_key)
is_valid = self.auth.verify_jwt_token(token, secret_key)
self.assertTrue(is_valid)
def test_verify_jwt_token_invalid(self):
"""Test JWT token verification with invalid token."""
secret_key = self.auth._generate_secret_key()
# Test with invalid token
is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key)
self.assertFalse(is_valid)
# Test with wrong secret key
agent_id = self.auth.generate_agent_id()
token = self.auth.generate_jwt_token(agent_id, secret_key)
wrong_key = self.auth._generate_secret_key()
is_valid = self.auth.verify_jwt_token(token, wrong_key)
self.assertFalse(is_valid)
def test_verify_jwt_token_expired(self):
"""Test JWT token verification with expired token."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
# Create expired token
payload = {
'agent_id': agent_id,
'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago
'iat': datetime.utcnow() - timedelta(hours=2),
'jti': self.auth._generate_jti()
}
expired_token = jwt.encode(payload, secret_key, algorithm='HS256')
is_valid = self.auth.verify_jwt_token(expired_token, secret_key)
self.assertFalse(is_valid)
def test_create_hmac_signature(self):
"""Test HMAC signature creation."""
data = "test message"
secret_key = self.auth._generate_secret_key()
signature = self.auth.create_hmac_signature(data, secret_key)
# Verify signature format
self.assertEqual(len(signature), 64) # SHA256 hex digest
# Verify signature is correct
expected = hmac.new(
secret_key.encode(),
data.encode(),
hashlib.sha256
).hexdigest()
self.assertEqual(signature, expected)
def test_verify_hmac_signature_valid(self):
"""Test HMAC signature verification with valid signature."""
data = "test message"
secret_key = self.auth._generate_secret_key()
signature = self.auth.create_hmac_signature(data, secret_key)
is_valid = self.auth.verify_hmac_signature(data, signature, secret_key)
self.assertTrue(is_valid)
def test_verify_hmac_signature_invalid(self):
"""Test HMAC signature verification with invalid signature."""
data = "test message"
secret_key = self.auth._generate_secret_key()
# Test with wrong signature
wrong_signature = "0" * 64
is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key)
self.assertFalse(is_valid)
# Test with wrong key
signature = self.auth.create_hmac_signature(data, secret_key)
wrong_key = self.auth._generate_secret_key()
is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key)
self.assertFalse(is_valid)
def test_encrypt_decrypt_secret_key(self):
"""Test secret key encryption and decryption."""
secret_key = self.auth._generate_secret_key()
password = "test_password"
encrypted = self.auth.encrypt_secret_key(secret_key, password)
decrypted = self.auth.decrypt_secret_key(encrypted, password)
self.assertEqual(secret_key, decrypted)
def test_encrypt_decrypt_wrong_password(self):
"""Test secret key decryption with wrong password."""
secret_key = self.auth._generate_secret_key()
password = "test_password"
wrong_password = "wrong_password"
encrypted = self.auth.encrypt_secret_key(secret_key, password)
with self.assertRaises(Exception):
self.auth.decrypt_secret_key(encrypted, wrong_password)
@patch('src.auth.Database')
def test_authenticate_agent_success(self, mock_db_class):
"""Test successful agent authentication."""
# Mock database
mock_db = Mock()
mock_db_class.return_value = mock_db
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
key_hash = hashlib.sha256(secret_key.encode()).hexdigest()
# Mock database response
mock_db.get_agent_credentials.return_value = {
'agent_id': agent_id,
'key_hash': key_hash,
'is_active': True,
'created_at': datetime.now().isoformat()
}
result = self.auth.authenticate_agent(agent_id, secret_key)
self.assertTrue(result)
@patch('src.auth.Database')
def test_authenticate_agent_failure(self, mock_db_class):
"""Test failed agent authentication."""
# Mock database
mock_db = Mock()
mock_db_class.return_value = mock_db
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
# Mock database response - no credentials found
mock_db.get_agent_credentials.return_value = None
result = self.auth.authenticate_agent(agent_id, secret_key)
self.assertFalse(result)
class TestDatabase(unittest.TestCase):
"""Test cases for database operations."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.db = Database(self.db_path)
self.db.create_tables()
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_create_agent_auth(self):
"""Test agent authentication record creation."""
agent_id = "agent_test123"
secret_key_hash = "test_hash"
encrypted_key = "encrypted_test_key"
success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
self.assertTrue(success)
# Verify record exists
credentials = self.db.get_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
self.assertEqual(credentials['agent_id'], agent_id)
self.assertEqual(credentials['key_hash'], secret_key_hash)
def test_get_agent_credentials_exists(self):
"""Test retrieving existing agent credentials."""
agent_id = "agent_test123"
secret_key_hash = "test_hash"
encrypted_key = "encrypted_test_key"
# Create record
self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
# Retrieve record
credentials = self.db.get_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
self.assertEqual(credentials['agent_id'], agent_id)
self.assertEqual(credentials['key_hash'], secret_key_hash)
self.assertTrue(credentials['is_active'])
def test_get_agent_credentials_not_exists(self):
"""Test retrieving non-existent agent credentials."""
credentials = self.db.get_agent_credentials("non_existent_agent")
self.assertIsNone(credentials)
def test_store_agent_token(self):
"""Test storing agent JWT token."""
agent_id = "agent_test123"
token = "test_jwt_token"
expires_at = (datetime.now() + timedelta(hours=1)).isoformat()
success = self.db.store_agent_token(agent_id, token, expires_at)
self.assertTrue(success)
# Verify token exists
stored_token = self.db.get_agent_token(agent_id)
self.assertIsNotNone(stored_token)
self.assertEqual(stored_token['token'], token)
def test_cleanup_expired_tokens(self):
"""Test cleanup of expired tokens."""
agent_id = "agent_test123"
# Create expired token
expired_token = "expired_token"
expired_time = (datetime.now() - timedelta(hours=1)).isoformat()
self.db.store_agent_token(agent_id, expired_token, expired_time)
# Create valid token
valid_token = "valid_token"
valid_time = (datetime.now() + timedelta(hours=1)).isoformat()
self.db.store_agent_token("agent_valid", valid_token, valid_time)
# Cleanup expired tokens
cleaned = self.db.cleanup_expired_tokens()
self.assertGreaterEqual(cleaned, 1)
# Verify expired token is gone
token = self.db.get_agent_token(agent_id)
self.assertIsNone(token)
# Verify valid token remains
token = self.db.get_agent_token("agent_valid")
self.assertIsNotNone(token)
class TestIntegration(unittest.TestCase):
"""Integration tests for the complete authentication flow."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.auth = AgentAuthentication()
# Use test database
self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_complete_authentication_flow(self):
"""Test complete agent authentication workflow."""
# Step 1: Generate agent ID
agent_id = self.auth.generate_agent_id()
self.assertIsNotNone(agent_id)
# Step 2: Create credentials
credentials = self.auth.create_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
# Step 3: Generate JWT token
token = self.auth.generate_jwt_token(
credentials['agent_id'],
credentials['secret_key']
)
self.assertIsNotNone(token)
# Step 4: Verify token
is_valid = self.auth.verify_jwt_token(token, credentials['secret_key'])
self.assertTrue(is_valid)
# Step 5: Create HMAC signature
test_data = "test API request"
signature = self.auth.create_hmac_signature(test_data, credentials['secret_key'])
self.assertIsNotNone(signature)
# Step 6: Verify HMAC signature
is_signature_valid = self.auth.verify_hmac_signature(
test_data, signature, credentials['secret_key']
)
self.assertTrue(is_signature_valid)
def run_tests():
"""Run all tests."""
print("🧪 Running PyGuardian Authentication Tests...")
print("=" * 50)
# Create test suite
test_suite = unittest.TestSuite()
# Add test classes
test_classes = [
TestAgentAuthentication,
TestDatabase,
TestIntegration
]
for test_class in test_classes:
tests = unittest.TestLoader().loadTestsFromTestCase(test_class)
test_suite.addTests(tests)
# Run tests
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(test_suite)
# Print summary
print("\n" + "=" * 50)
print(f"🏁 Tests completed:")
print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}")
print(f" ❌ Failed: {len(result.failures)}")
print(f" 💥 Errors: {len(result.errors)}")
# Return exit code
return 0 if result.wasSuccessful() else 1
if __name__ == '__main__':
sys.exit(run_tests())

View File

@@ -0,0 +1,421 @@
#!/usr/bin/env python3
"""
Comprehensive unit tests for PyGuardian authentication system.
"""
import unittest
import tempfile
import os
import sys
import sqlite3
import jwt
import hashlib
import hmac
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, MagicMock
# Add src directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))
from auth import AgentAuthentication
from storage import Storage
class TestAgentAuthentication(unittest.TestCase):
"""Test cases for agent authentication system."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.auth = AgentAuthentication()
# Create test database
self.db = Database(self.db_path)
self.db.create_tables()
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_generate_agent_id(self):
"""Test agent ID generation."""
agent_id = self.auth.generate_agent_id()
# Check format
self.assertTrue(agent_id.startswith('agent_'))
self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID
# Test uniqueness
agent_id2 = self.auth.generate_agent_id()
self.assertNotEqual(agent_id, agent_id2)
def test_create_agent_credentials(self):
"""Test agent credentials creation."""
agent_id = self.auth.generate_agent_id()
credentials = self.auth.create_agent_credentials(agent_id)
# Check required fields
required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash']
for field in required_fields:
self.assertIn(field, credentials)
# Check agent ID matches
self.assertEqual(credentials['agent_id'], agent_id)
# Check secret key length
self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded
# Check key hash
expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest()
self.assertEqual(credentials['key_hash'], expected_hash)
def test_generate_jwt_token(self):
"""Test JWT token generation."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
token = self.auth.generate_jwt_token(agent_id, secret_key)
# Verify token structure
self.assertIsInstance(token, str)
self.assertTrue(len(token) > 100) # JWT tokens are typically long
# Decode and verify payload
decoded = jwt.decode(token, secret_key, algorithms=['HS256'])
self.assertEqual(decoded['agent_id'], agent_id)
self.assertIn('iat', decoded)
self.assertIn('exp', decoded)
self.assertIn('jti', decoded)
def test_verify_jwt_token_valid(self):
"""Test JWT token verification with valid token."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
token = self.auth.generate_jwt_token(agent_id, secret_key)
is_valid = self.auth.verify_jwt_token(token, secret_key)
self.assertTrue(is_valid)
def test_verify_jwt_token_invalid(self):
"""Test JWT token verification with invalid token."""
secret_key = self.auth._generate_secret_key()
# Test with invalid token
is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key)
self.assertFalse(is_valid)
# Test with wrong secret key
agent_id = self.auth.generate_agent_id()
token = self.auth.generate_jwt_token(agent_id, secret_key)
wrong_key = self.auth._generate_secret_key()
is_valid = self.auth.verify_jwt_token(token, wrong_key)
self.assertFalse(is_valid)
def test_verify_jwt_token_expired(self):
"""Test JWT token verification with expired token."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
# Create expired token
payload = {
'agent_id': agent_id,
'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago
'iat': datetime.utcnow() - timedelta(hours=2),
'jti': self.auth._generate_jti()
}
expired_token = jwt.encode(payload, secret_key, algorithm='HS256')
is_valid = self.auth.verify_jwt_token(expired_token, secret_key)
self.assertFalse(is_valid)
def test_create_hmac_signature(self):
"""Test HMAC signature creation."""
data = "test message"
secret_key = self.auth._generate_secret_key()
signature = self.auth.create_hmac_signature(data, secret_key)
# Verify signature format
self.assertEqual(len(signature), 64) # SHA256 hex digest
# Verify signature is correct
expected = hmac.new(
secret_key.encode(),
data.encode(),
hashlib.sha256
).hexdigest()
self.assertEqual(signature, expected)
def test_verify_hmac_signature_valid(self):
"""Test HMAC signature verification with valid signature."""
data = "test message"
secret_key = self.auth._generate_secret_key()
signature = self.auth.create_hmac_signature(data, secret_key)
is_valid = self.auth.verify_hmac_signature(data, signature, secret_key)
self.assertTrue(is_valid)
def test_verify_hmac_signature_invalid(self):
"""Test HMAC signature verification with invalid signature."""
data = "test message"
secret_key = self.auth._generate_secret_key()
# Test with wrong signature
wrong_signature = "0" * 64
is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key)
self.assertFalse(is_valid)
# Test with wrong key
signature = self.auth.create_hmac_signature(data, secret_key)
wrong_key = self.auth._generate_secret_key()
is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key)
self.assertFalse(is_valid)
def test_encrypt_decrypt_secret_key(self):
"""Test secret key encryption and decryption."""
secret_key = self.auth._generate_secret_key()
password = "test_password"
encrypted = self.auth.encrypt_secret_key(secret_key, password)
decrypted = self.auth.decrypt_secret_key(encrypted, password)
self.assertEqual(secret_key, decrypted)
def test_encrypt_decrypt_wrong_password(self):
"""Test secret key decryption with wrong password."""
secret_key = self.auth._generate_secret_key()
password = "test_password"
wrong_password = "wrong_password"
encrypted = self.auth.encrypt_secret_key(secret_key, password)
with self.assertRaises(Exception):
self.auth.decrypt_secret_key(encrypted, wrong_password)
@patch('src.auth.Database')
def test_authenticate_agent_success(self, mock_db_class):
"""Test successful agent authentication."""
# Mock database
mock_db = Mock()
mock_db_class.return_value = mock_db
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
key_hash = hashlib.sha256(secret_key.encode()).hexdigest()
# Mock database response
mock_db.get_agent_credentials.return_value = {
'agent_id': agent_id,
'key_hash': key_hash,
'is_active': True,
'created_at': datetime.now().isoformat()
}
result = self.auth.authenticate_agent(agent_id, secret_key)
self.assertTrue(result)
@patch('src.auth.Database')
def test_authenticate_agent_failure(self, mock_db_class):
"""Test failed agent authentication."""
# Mock database
mock_db = Mock()
mock_db_class.return_value = mock_db
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
# Mock database response - no credentials found
mock_db.get_agent_credentials.return_value = None
result = self.auth.authenticate_agent(agent_id, secret_key)
self.assertFalse(result)
class TestDatabase(unittest.TestCase):
"""Test cases for database operations."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.db = Storage(self.db_path)
self.db.create_tables()
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_create_agent_auth(self):
"""Test agent authentication record creation."""
agent_id = "agent_test123"
secret_key_hash = "test_hash"
encrypted_key = "encrypted_test_key"
success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
self.assertTrue(success)
# Verify record exists
credentials = self.db.get_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
self.assertEqual(credentials['agent_id'], agent_id)
self.assertEqual(credentials['key_hash'], secret_key_hash)
def test_get_agent_credentials_exists(self):
"""Test retrieving existing agent credentials."""
agent_id = "agent_test123"
secret_key_hash = "test_hash"
encrypted_key = "encrypted_test_key"
# Create record
self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
# Retrieve record
credentials = self.db.get_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
self.assertEqual(credentials['agent_id'], agent_id)
self.assertEqual(credentials['key_hash'], secret_key_hash)
self.assertTrue(credentials['is_active'])
def test_get_agent_credentials_not_exists(self):
"""Test retrieving non-existent agent credentials."""
credentials = self.db.get_agent_credentials("non_existent_agent")
self.assertIsNone(credentials)
def test_store_agent_token(self):
"""Test storing agent JWT token."""
agent_id = "agent_test123"
token = "test_jwt_token"
expires_at = (datetime.now() + timedelta(hours=1)).isoformat()
success = self.db.store_agent_token(agent_id, token, expires_at)
self.assertTrue(success)
# Verify token exists
stored_token = self.db.get_agent_token(agent_id)
self.assertIsNotNone(stored_token)
self.assertEqual(stored_token['token'], token)
def test_cleanup_expired_tokens(self):
"""Test cleanup of expired tokens."""
agent_id = "agent_test123"
# Create expired token
expired_token = "expired_token"
expired_time = (datetime.now() - timedelta(hours=1)).isoformat()
self.db.store_agent_token(agent_id, expired_token, expired_time)
# Create valid token
valid_token = "valid_token"
valid_time = (datetime.now() + timedelta(hours=1)).isoformat()
self.db.store_agent_token("agent_valid", valid_token, valid_time)
# Cleanup expired tokens
cleaned = self.db.cleanup_expired_tokens()
self.assertGreaterEqual(cleaned, 1)
# Verify expired token is gone
token = self.db.get_agent_token(agent_id)
self.assertIsNone(token)
# Verify valid token remains
token = self.db.get_agent_token("agent_valid")
self.assertIsNotNone(token)
class TestIntegration(unittest.TestCase):
"""Integration tests for the complete authentication flow."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.auth = AgentAuthentication()
# Use test database
self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_complete_authentication_flow(self):
"""Test complete agent authentication workflow."""
# Step 1: Generate agent ID
agent_id = self.auth.generate_agent_id()
self.assertIsNotNone(agent_id)
# Step 2: Create credentials
credentials = self.auth.create_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
# Step 3: Generate JWT token
token = self.auth.generate_jwt_token(
credentials['agent_id'],
credentials['secret_key']
)
self.assertIsNotNone(token)
# Step 4: Verify token
is_valid = self.auth.verify_jwt_token(token, credentials['secret_key'])
self.assertTrue(is_valid)
# Step 5: Create HMAC signature
test_data = "test API request"
signature = self.auth.create_hmac_signature(test_data, credentials['secret_key'])
self.assertIsNotNone(signature)
# Step 6: Verify HMAC signature
is_signature_valid = self.auth.verify_hmac_signature(
test_data, signature, credentials['secret_key']
)
self.assertTrue(is_signature_valid)
def run_tests():
"""Run all tests."""
print("🧪 Running PyGuardian Authentication Tests...")
print("=" * 50)
# Create test suite
test_suite = unittest.TestSuite()
# Add test classes
test_classes = [
TestAgentAuthentication,
TestDatabase,
TestIntegration
]
for test_class in test_classes:
tests = unittest.TestLoader().loadTestsFromTestCase(test_class)
test_suite.addTests(tests)
# Run tests
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(test_suite)
# Print summary
print("\n" + "=" * 50)
print(f"🏁 Tests completed:")
print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}")
print(f" ❌ Failed: {len(result.failures)}")
print(f" 💥 Errors: {len(result.errors)}")
# Return exit code
return 0 if result.wasSuccessful() else 1
if __name__ == '__main__':
sys.exit(run_tests())

View File

@@ -0,0 +1,421 @@
#!/usr/bin/env python3
"""
Comprehensive unit tests for PyGuardian authentication system.
"""
import unittest
import tempfile
import os
import sys
import sqlite3
import jwt
import hashlib
import hmac
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, MagicMock
# Add src directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))
from auth import AgentAuthentication
from storage import Storage
class TestAgentAuthentication(unittest.TestCase):
"""Test cases for agent authentication system."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.auth = AgentAuthentication()
# Create test database
self.db = Storage(self.db_path)
self.db.create_tables()
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_generate_agent_id(self):
"""Test agent ID generation."""
agent_id = self.auth.generate_agent_id()
# Check format
self.assertTrue(agent_id.startswith('agent_'))
self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID
# Test uniqueness
agent_id2 = self.auth.generate_agent_id()
self.assertNotEqual(agent_id, agent_id2)
def test_create_agent_credentials(self):
"""Test agent credentials creation."""
agent_id = self.auth.generate_agent_id()
credentials = self.auth.create_agent_credentials(agent_id)
# Check required fields
required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash']
for field in required_fields:
self.assertIn(field, credentials)
# Check agent ID matches
self.assertEqual(credentials['agent_id'], agent_id)
# Check secret key length
self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded
# Check key hash
expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest()
self.assertEqual(credentials['key_hash'], expected_hash)
def test_generate_jwt_token(self):
"""Test JWT token generation."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
token = self.auth.generate_jwt_token(agent_id, secret_key)
# Verify token structure
self.assertIsInstance(token, str)
self.assertTrue(len(token) > 100) # JWT tokens are typically long
# Decode and verify payload
decoded = jwt.decode(token, secret_key, algorithms=['HS256'])
self.assertEqual(decoded['agent_id'], agent_id)
self.assertIn('iat', decoded)
self.assertIn('exp', decoded)
self.assertIn('jti', decoded)
def test_verify_jwt_token_valid(self):
"""Test JWT token verification with valid token."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
token = self.auth.generate_jwt_token(agent_id, secret_key)
is_valid = self.auth.verify_jwt_token(token, secret_key)
self.assertTrue(is_valid)
def test_verify_jwt_token_invalid(self):
"""Test JWT token verification with invalid token."""
secret_key = self.auth._generate_secret_key()
# Test with invalid token
is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key)
self.assertFalse(is_valid)
# Test with wrong secret key
agent_id = self.auth.generate_agent_id()
token = self.auth.generate_jwt_token(agent_id, secret_key)
wrong_key = self.auth._generate_secret_key()
is_valid = self.auth.verify_jwt_token(token, wrong_key)
self.assertFalse(is_valid)
def test_verify_jwt_token_expired(self):
"""Test JWT token verification with expired token."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
# Create expired token
payload = {
'agent_id': agent_id,
'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago
'iat': datetime.utcnow() - timedelta(hours=2),
'jti': self.auth._generate_jti()
}
expired_token = jwt.encode(payload, secret_key, algorithm='HS256')
is_valid = self.auth.verify_jwt_token(expired_token, secret_key)
self.assertFalse(is_valid)
def test_create_hmac_signature(self):
"""Test HMAC signature creation."""
data = "test message"
secret_key = self.auth._generate_secret_key()
signature = self.auth.create_hmac_signature(data, secret_key)
# Verify signature format
self.assertEqual(len(signature), 64) # SHA256 hex digest
# Verify signature is correct
expected = hmac.new(
secret_key.encode(),
data.encode(),
hashlib.sha256
).hexdigest()
self.assertEqual(signature, expected)
def test_verify_hmac_signature_valid(self):
"""Test HMAC signature verification with valid signature."""
data = "test message"
secret_key = self.auth._generate_secret_key()
signature = self.auth.create_hmac_signature(data, secret_key)
is_valid = self.auth.verify_hmac_signature(data, signature, secret_key)
self.assertTrue(is_valid)
def test_verify_hmac_signature_invalid(self):
"""Test HMAC signature verification with invalid signature."""
data = "test message"
secret_key = self.auth._generate_secret_key()
# Test with wrong signature
wrong_signature = "0" * 64
is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key)
self.assertFalse(is_valid)
# Test with wrong key
signature = self.auth.create_hmac_signature(data, secret_key)
wrong_key = self.auth._generate_secret_key()
is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key)
self.assertFalse(is_valid)
def test_encrypt_decrypt_secret_key(self):
"""Test secret key encryption and decryption."""
secret_key = self.auth._generate_secret_key()
password = "test_password"
encrypted = self.auth.encrypt_secret_key(secret_key, password)
decrypted = self.auth.decrypt_secret_key(encrypted, password)
self.assertEqual(secret_key, decrypted)
def test_encrypt_decrypt_wrong_password(self):
"""Test secret key decryption with wrong password."""
secret_key = self.auth._generate_secret_key()
password = "test_password"
wrong_password = "wrong_password"
encrypted = self.auth.encrypt_secret_key(secret_key, password)
with self.assertRaises(Exception):
self.auth.decrypt_secret_key(encrypted, wrong_password)
@patch('src.auth.Database')
def test_authenticate_agent_success(self, mock_db_class):
"""Test successful agent authentication."""
# Mock database
mock_db = Mock()
mock_db_class.return_value = mock_db
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
key_hash = hashlib.sha256(secret_key.encode()).hexdigest()
# Mock database response
mock_db.get_agent_credentials.return_value = {
'agent_id': agent_id,
'key_hash': key_hash,
'is_active': True,
'created_at': datetime.now().isoformat()
}
result = self.auth.authenticate_agent(agent_id, secret_key)
self.assertTrue(result)
@patch('src.auth.Database')
def test_authenticate_agent_failure(self, mock_db_class):
"""Test failed agent authentication."""
# Mock database
mock_db = Mock()
mock_db_class.return_value = mock_db
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
# Mock database response - no credentials found
mock_db.get_agent_credentials.return_value = None
result = self.auth.authenticate_agent(agent_id, secret_key)
self.assertFalse(result)
class TestDatabase(unittest.TestCase):
"""Test cases for database operations."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.db = Storage(self.db_path)
self.db.create_tables()
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_create_agent_auth(self):
"""Test agent authentication record creation."""
agent_id = "agent_test123"
secret_key_hash = "test_hash"
encrypted_key = "encrypted_test_key"
success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
self.assertTrue(success)
# Verify record exists
credentials = self.db.get_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
self.assertEqual(credentials['agent_id'], agent_id)
self.assertEqual(credentials['key_hash'], secret_key_hash)
def test_get_agent_credentials_exists(self):
"""Test retrieving existing agent credentials."""
agent_id = "agent_test123"
secret_key_hash = "test_hash"
encrypted_key = "encrypted_test_key"
# Create record
self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
# Retrieve record
credentials = self.db.get_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
self.assertEqual(credentials['agent_id'], agent_id)
self.assertEqual(credentials['key_hash'], secret_key_hash)
self.assertTrue(credentials['is_active'])
def test_get_agent_credentials_not_exists(self):
"""Test retrieving non-existent agent credentials."""
credentials = self.db.get_agent_credentials("non_existent_agent")
self.assertIsNone(credentials)
def test_store_agent_token(self):
"""Test storing agent JWT token."""
agent_id = "agent_test123"
token = "test_jwt_token"
expires_at = (datetime.now() + timedelta(hours=1)).isoformat()
success = self.db.store_agent_token(agent_id, token, expires_at)
self.assertTrue(success)
# Verify token exists
stored_token = self.db.get_agent_token(agent_id)
self.assertIsNotNone(stored_token)
self.assertEqual(stored_token['token'], token)
def test_cleanup_expired_tokens(self):
"""Test cleanup of expired tokens."""
agent_id = "agent_test123"
# Create expired token
expired_token = "expired_token"
expired_time = (datetime.now() - timedelta(hours=1)).isoformat()
self.db.store_agent_token(agent_id, expired_token, expired_time)
# Create valid token
valid_token = "valid_token"
valid_time = (datetime.now() + timedelta(hours=1)).isoformat()
self.db.store_agent_token("agent_valid", valid_token, valid_time)
# Cleanup expired tokens
cleaned = self.db.cleanup_expired_tokens()
self.assertGreaterEqual(cleaned, 1)
# Verify expired token is gone
token = self.db.get_agent_token(agent_id)
self.assertIsNone(token)
# Verify valid token remains
token = self.db.get_agent_token("agent_valid")
self.assertIsNotNone(token)
class TestIntegration(unittest.TestCase):
"""Integration tests for the complete authentication flow."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.auth = AgentAuthentication()
# Use test database
self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_complete_authentication_flow(self):
"""Test complete agent authentication workflow."""
# Step 1: Generate agent ID
agent_id = self.auth.generate_agent_id()
self.assertIsNotNone(agent_id)
# Step 2: Create credentials
credentials = self.auth.create_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
# Step 3: Generate JWT token
token = self.auth.generate_jwt_token(
credentials['agent_id'],
credentials['secret_key']
)
self.assertIsNotNone(token)
# Step 4: Verify token
is_valid = self.auth.verify_jwt_token(token, credentials['secret_key'])
self.assertTrue(is_valid)
# Step 5: Create HMAC signature
test_data = "test API request"
signature = self.auth.create_hmac_signature(test_data, credentials['secret_key'])
self.assertIsNotNone(signature)
# Step 6: Verify HMAC signature
is_signature_valid = self.auth.verify_hmac_signature(
test_data, signature, credentials['secret_key']
)
self.assertTrue(is_signature_valid)
def run_tests():
"""Run all tests."""
print("🧪 Running PyGuardian Authentication Tests...")
print("=" * 50)
# Create test suite
test_suite = unittest.TestSuite()
# Add test classes
test_classes = [
TestAgentAuthentication,
TestDatabase,
TestIntegration
]
for test_class in test_classes:
tests = unittest.TestLoader().loadTestsFromTestCase(test_class)
test_suite.addTests(tests)
# Run tests
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(test_suite)
# Print summary
print("\n" + "=" * 50)
print(f"🏁 Tests completed:")
print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}")
print(f" ❌ Failed: {len(result.failures)}")
print(f" 💥 Errors: {len(result.errors)}")
# Return exit code
return 0 if result.wasSuccessful() else 1
if __name__ == '__main__':
sys.exit(run_tests())

View File

@@ -0,0 +1,421 @@
#!/usr/bin/env python3
"""
Comprehensive unit tests for PyGuardian authentication system.
"""
import unittest
import tempfile
import os
import sys
import sqlite3
import jwt
import hashlib
import hmac
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, MagicMock
# Add src directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))
from auth import AgentAuthentication
from storage import Storage
class TestAgentAuthentication(unittest.TestCase):
"""Test cases for agent authentication system."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.auth = AgentAuthentication()
# Create test database
self.db = Storage(self.db_path)
self.db.create_tables()
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_generate_agent_id(self):
"""Test agent ID generation."""
agent_id = self.auth.generate_agent_id()
# Check format
self.assertTrue(agent_id.startswith('agent_'))
self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID
# Test uniqueness
agent_id2 = self.auth.generate_agent_id()
self.assertNotEqual(agent_id, agent_id2)
def test_create_agent_credentials(self):
"""Test agent credentials creation."""
agent_id = self.auth.generate_agent_id()
credentials = self.auth.create_agent_credentials(agent_id)
# Check required fields
required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash']
for field in required_fields:
self.assertIn(field, credentials)
# Check agent ID matches
self.assertEqual(credentials['agent_id'], agent_id)
# Check secret key length
self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded
# Check key hash
expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest()
self.assertEqual(credentials['key_hash'], expected_hash)
def test_generate_jwt_token(self):
"""Test JWT token generation."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
token = self.auth.generate_jwt_token(agent_id, secret_key)
# Verify token structure
self.assertIsInstance(token, str)
self.assertTrue(len(token) > 100) # JWT tokens are typically long
# Decode and verify payload
decoded = jwt.decode(token, secret_key, algorithms=['HS256'])
self.assertEqual(decoded['agent_id'], agent_id)
self.assertIn('iat', decoded)
self.assertIn('exp', decoded)
self.assertIn('jti', decoded)
def test_verify_jwt_token_valid(self):
"""Test JWT token verification with valid token."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
token = self.auth.generate_jwt_token(agent_id, secret_key)
is_valid = self.auth.verify_jwt_token(token, secret_key)
self.assertTrue(is_valid)
def test_verify_jwt_token_invalid(self):
"""Test JWT token verification with invalid token."""
secret_key = self.auth._generate_secret_key()
# Test with invalid token
is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key)
self.assertFalse(is_valid)
# Test with wrong secret key
agent_id = self.auth.generate_agent_id()
token = self.auth.generate_jwt_token(agent_id, secret_key)
wrong_key = self.auth._generate_secret_key()
is_valid = self.auth.verify_jwt_token(token, wrong_key)
self.assertFalse(is_valid)
def test_verify_jwt_token_expired(self):
"""Test JWT token verification with expired token."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
# Create expired token
payload = {
'agent_id': agent_id,
'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago
'iat': datetime.utcnow() - timedelta(hours=2),
'jti': self.auth._generate_jti()
}
expired_token = jwt.encode(payload, secret_key, algorithm='HS256')
is_valid = self.auth.verify_jwt_token(expired_token, secret_key)
self.assertFalse(is_valid)
def test_create_hmac_signature(self):
"""Test HMAC signature creation."""
data = "test message"
secret_key = self.auth._generate_secret_key()
signature = self.auth.create_hmac_signature(data, secret_key)
# Verify signature format
self.assertEqual(len(signature), 64) # SHA256 hex digest
# Verify signature is correct
expected = hmac.new(
secret_key.encode(),
data.encode(),
hashlib.sha256
).hexdigest()
self.assertEqual(signature, expected)
def test_verify_hmac_signature_valid(self):
"""Test HMAC signature verification with valid signature."""
data = "test message"
secret_key = self.auth._generate_secret_key()
signature = self.auth.create_hmac_signature(data, secret_key)
is_valid = self.auth.verify_hmac_signature(data, signature, secret_key)
self.assertTrue(is_valid)
def test_verify_hmac_signature_invalid(self):
"""Test HMAC signature verification with invalid signature."""
data = "test message"
secret_key = self.auth._generate_secret_key()
# Test with wrong signature
wrong_signature = "0" * 64
is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key)
self.assertFalse(is_valid)
# Test with wrong key
signature = self.auth.create_hmac_signature(data, secret_key)
wrong_key = self.auth._generate_secret_key()
is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key)
self.assertFalse(is_valid)
def test_encrypt_decrypt_secret_key(self):
"""Test secret key encryption and decryption."""
secret_key = self.auth._generate_secret_key()
password = "test_password"
encrypted = self.auth.encrypt_secret_key(secret_key, password)
decrypted = self.auth.decrypt_secret_key(encrypted, password)
self.assertEqual(secret_key, decrypted)
def test_encrypt_decrypt_wrong_password(self):
"""Test secret key decryption with wrong password."""
secret_key = self.auth._generate_secret_key()
password = "test_password"
wrong_password = "wrong_password"
encrypted = self.auth.encrypt_secret_key(secret_key, password)
with self.assertRaises(Exception):
self.auth.decrypt_secret_key(encrypted, wrong_password)
@patch('src.auth.Database')
def test_authenticate_agent_success(self, mock_db_class):
"""Test successful agent authentication."""
# Mock database
mock_db = Mock()
mock_db_class.return_value = mock_db
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
key_hash = hashlib.sha256(secret_key.encode()).hexdigest()
# Mock database response
mock_db.get_agent_credentials.return_value = {
'agent_id': agent_id,
'key_hash': key_hash,
'is_active': True,
'created_at': datetime.now().isoformat()
}
result = self.auth.authenticate_agent(agent_id, secret_key)
self.assertTrue(result)
@patch('src.auth.Database')
def test_authenticate_agent_failure(self, mock_db_class):
"""Test failed agent authentication."""
# Mock database
mock_db = Mock()
mock_db_class.return_value = mock_db
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
# Mock database response - no credentials found
mock_db.get_agent_credentials.return_value = None
result = self.auth.authenticate_agent(agent_id, secret_key)
self.assertFalse(result)
class TestDatabase(unittest.TestCase):
"""Test cases for database operations."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.db = Storage(self.db_path)
self.db.create_tables()
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_create_agent_auth(self):
"""Test agent authentication record creation."""
agent_id = "agent_test123"
secret_key_hash = "test_hash"
encrypted_key = "encrypted_test_key"
success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
self.assertTrue(success)
# Verify record exists
credentials = self.db.get_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
self.assertEqual(credentials['agent_id'], agent_id)
self.assertEqual(credentials['key_hash'], secret_key_hash)
def test_get_agent_credentials_exists(self):
"""Test retrieving existing agent credentials."""
agent_id = "agent_test123"
secret_key_hash = "test_hash"
encrypted_key = "encrypted_test_key"
# Create record
self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
# Retrieve record
credentials = self.db.get_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
self.assertEqual(credentials['agent_id'], agent_id)
self.assertEqual(credentials['key_hash'], secret_key_hash)
self.assertTrue(credentials['is_active'])
def test_get_agent_credentials_not_exists(self):
"""Test retrieving non-existent agent credentials."""
credentials = self.db.get_agent_credentials("non_existent_agent")
self.assertIsNone(credentials)
def test_store_agent_token(self):
"""Test storing agent JWT token."""
agent_id = "agent_test123"
token = "test_jwt_token"
expires_at = (datetime.now() + timedelta(hours=1)).isoformat()
success = self.db.store_agent_token(agent_id, token, expires_at)
self.assertTrue(success)
# Verify token exists
stored_token = self.db.get_agent_token(agent_id)
self.assertIsNotNone(stored_token)
self.assertEqual(stored_token['token'], token)
def test_cleanup_expired_tokens(self):
"""Test cleanup of expired tokens."""
agent_id = "agent_test123"
# Create expired token
expired_token = "expired_token"
expired_time = (datetime.now() - timedelta(hours=1)).isoformat()
self.db.store_agent_token(agent_id, expired_token, expired_time)
# Create valid token
valid_token = "valid_token"
valid_time = (datetime.now() + timedelta(hours=1)).isoformat()
self.db.store_agent_token("agent_valid", valid_token, valid_time)
# Cleanup expired tokens
cleaned = self.db.cleanup_expired_tokens()
self.assertGreaterEqual(cleaned, 1)
# Verify expired token is gone
token = self.db.get_agent_token(agent_id)
self.assertIsNone(token)
# Verify valid token remains
token = self.db.get_agent_token("agent_valid")
self.assertIsNotNone(token)
class TestIntegration(unittest.TestCase):
"""Integration tests for the complete authentication flow."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.auth = AgentAuthentication()
# Use test database
self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_complete_authentication_flow(self):
"""Test complete agent authentication workflow."""
# Step 1: Generate agent ID
agent_id = self.auth.generate_agent_id()
self.assertIsNotNone(agent_id)
# Step 2: Create credentials
credentials = self.auth.create_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
# Step 3: Generate JWT token
token = self.auth.generate_jwt_token(
credentials['agent_id'],
credentials['secret_key']
)
self.assertIsNotNone(token)
# Step 4: Verify token
is_valid = self.auth.verify_jwt_token(token, credentials['secret_key'])
self.assertTrue(is_valid)
# Step 5: Create HMAC signature
test_data = "test API request"
signature = self.auth.create_hmac_signature(test_data, credentials['secret_key'])
self.assertIsNotNone(signature)
# Step 6: Verify HMAC signature
is_signature_valid = self.auth.verify_hmac_signature(
test_data, signature, credentials['secret_key']
)
self.assertTrue(is_signature_valid)
def run_tests():
"""Run all tests."""
print("🧪 Running PyGuardian Authentication Tests...")
print("=" * 50)
# Create test suite
test_suite = unittest.TestSuite()
# Add test classes
test_classes = [
TestAgentAuthentication,
TestDatabase,
TestIntegration
]
for test_class in test_classes:
tests = unittest.TestLoader().loadTestsFromTestCase(test_class)
test_suite.addTests(tests)
# Run tests
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(test_suite)
# Print summary
print("\n" + "=" * 50)
print(f"🏁 Tests completed:")
print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}")
print(f" ❌ Failed: {len(result.failures)}")
print(f" 💥 Errors: {len(result.errors)}")
# Return exit code
return 0 if result.wasSuccessful() else 1
if __name__ == '__main__':
sys.exit(run_tests())

View File

@@ -0,0 +1,422 @@
#!/usr/bin/env python3
"""
Comprehensive unit tests for PyGuardian authentication system.
"""
import unittest
import tempfile
import os
import sys
import sqlite3
import jwt
import hashlib
import hmac
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, MagicMock
# Add src directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))
from auth import AgentAuthentication
from storage import Storage
class TestAgentAuthentication(unittest.TestCase):
"""Test cases for agent authentication system."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.test_secret = 'test_secret_key_123'
self.auth = AgentAuthentication(self.test_secret)
# Create test database
self.db = Storage(self.db_path)
self.db.create_tables()
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_generate_agent_id(self):
"""Test agent ID generation."""
agent_id = self.auth.generate_agent_id()
# Check format
self.assertTrue(agent_id.startswith('agent_'))
self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID
# Test uniqueness
agent_id2 = self.auth.generate_agent_id()
self.assertNotEqual(agent_id, agent_id2)
def test_create_agent_credentials(self):
"""Test agent credentials creation."""
agent_id = self.auth.generate_agent_id()
credentials = self.auth.create_agent_credentials(agent_id)
# Check required fields
required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash']
for field in required_fields:
self.assertIn(field, credentials)
# Check agent ID matches
self.assertEqual(credentials['agent_id'], agent_id)
# Check secret key length
self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded
# Check key hash
expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest()
self.assertEqual(credentials['key_hash'], expected_hash)
def test_generate_jwt_token(self):
"""Test JWT token generation."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
token = self.auth.generate_jwt_token(agent_id, secret_key)
# Verify token structure
self.assertIsInstance(token, str)
self.assertTrue(len(token) > 100) # JWT tokens are typically long
# Decode and verify payload
decoded = jwt.decode(token, secret_key, algorithms=['HS256'])
self.assertEqual(decoded['agent_id'], agent_id)
self.assertIn('iat', decoded)
self.assertIn('exp', decoded)
self.assertIn('jti', decoded)
def test_verify_jwt_token_valid(self):
"""Test JWT token verification with valid token."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
token = self.auth.generate_jwt_token(agent_id, secret_key)
is_valid = self.auth.verify_jwt_token(token, secret_key)
self.assertTrue(is_valid)
def test_verify_jwt_token_invalid(self):
"""Test JWT token verification with invalid token."""
secret_key = self.auth._generate_secret_key()
# Test with invalid token
is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key)
self.assertFalse(is_valid)
# Test with wrong secret key
agent_id = self.auth.generate_agent_id()
token = self.auth.generate_jwt_token(agent_id, secret_key)
wrong_key = self.auth._generate_secret_key()
is_valid = self.auth.verify_jwt_token(token, wrong_key)
self.assertFalse(is_valid)
def test_verify_jwt_token_expired(self):
"""Test JWT token verification with expired token."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
# Create expired token
payload = {
'agent_id': agent_id,
'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago
'iat': datetime.utcnow() - timedelta(hours=2),
'jti': self.auth._generate_jti()
}
expired_token = jwt.encode(payload, secret_key, algorithm='HS256')
is_valid = self.auth.verify_jwt_token(expired_token, secret_key)
self.assertFalse(is_valid)
def test_create_hmac_signature(self):
"""Test HMAC signature creation."""
data = "test message"
secret_key = self.auth._generate_secret_key()
signature = self.auth.create_hmac_signature(data, secret_key)
# Verify signature format
self.assertEqual(len(signature), 64) # SHA256 hex digest
# Verify signature is correct
expected = hmac.new(
secret_key.encode(),
data.encode(),
hashlib.sha256
).hexdigest()
self.assertEqual(signature, expected)
def test_verify_hmac_signature_valid(self):
"""Test HMAC signature verification with valid signature."""
data = "test message"
secret_key = self.auth._generate_secret_key()
signature = self.auth.create_hmac_signature(data, secret_key)
is_valid = self.auth.verify_hmac_signature(data, signature, secret_key)
self.assertTrue(is_valid)
def test_verify_hmac_signature_invalid(self):
"""Test HMAC signature verification with invalid signature."""
data = "test message"
secret_key = self.auth._generate_secret_key()
# Test with wrong signature
wrong_signature = "0" * 64
is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key)
self.assertFalse(is_valid)
# Test with wrong key
signature = self.auth.create_hmac_signature(data, secret_key)
wrong_key = self.auth._generate_secret_key()
is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key)
self.assertFalse(is_valid)
def test_encrypt_decrypt_secret_key(self):
"""Test secret key encryption and decryption."""
secret_key = self.auth._generate_secret_key()
password = "test_password"
encrypted = self.auth.encrypt_secret_key(secret_key, password)
decrypted = self.auth.decrypt_secret_key(encrypted, password)
self.assertEqual(secret_key, decrypted)
def test_encrypt_decrypt_wrong_password(self):
"""Test secret key decryption with wrong password."""
secret_key = self.auth._generate_secret_key()
password = "test_password"
wrong_password = "wrong_password"
encrypted = self.auth.encrypt_secret_key(secret_key, password)
with self.assertRaises(Exception):
self.auth.decrypt_secret_key(encrypted, wrong_password)
@patch('src.auth.Database')
def test_authenticate_agent_success(self, mock_db_class):
"""Test successful agent authentication."""
# Mock database
mock_db = Mock()
mock_db_class.return_value = mock_db
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
key_hash = hashlib.sha256(secret_key.encode()).hexdigest()
# Mock database response
mock_db.get_agent_credentials.return_value = {
'agent_id': agent_id,
'key_hash': key_hash,
'is_active': True,
'created_at': datetime.now().isoformat()
}
result = self.auth.authenticate_agent(agent_id, secret_key)
self.assertTrue(result)
@patch('src.auth.Database')
def test_authenticate_agent_failure(self, mock_db_class):
"""Test failed agent authentication."""
# Mock database
mock_db = Mock()
mock_db_class.return_value = mock_db
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
# Mock database response - no credentials found
mock_db.get_agent_credentials.return_value = None
result = self.auth.authenticate_agent(agent_id, secret_key)
self.assertFalse(result)
class TestDatabase(unittest.TestCase):
"""Test cases for database operations."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.db = Storage(self.db_path)
self.db.create_tables()
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_create_agent_auth(self):
"""Test agent authentication record creation."""
agent_id = "agent_test123"
secret_key_hash = "test_hash"
encrypted_key = "encrypted_test_key"
success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
self.assertTrue(success)
# Verify record exists
credentials = self.db.get_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
self.assertEqual(credentials['agent_id'], agent_id)
self.assertEqual(credentials['key_hash'], secret_key_hash)
def test_get_agent_credentials_exists(self):
"""Test retrieving existing agent credentials."""
agent_id = "agent_test123"
secret_key_hash = "test_hash"
encrypted_key = "encrypted_test_key"
# Create record
self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
# Retrieve record
credentials = self.db.get_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
self.assertEqual(credentials['agent_id'], agent_id)
self.assertEqual(credentials['key_hash'], secret_key_hash)
self.assertTrue(credentials['is_active'])
def test_get_agent_credentials_not_exists(self):
"""Test retrieving non-existent agent credentials."""
credentials = self.db.get_agent_credentials("non_existent_agent")
self.assertIsNone(credentials)
def test_store_agent_token(self):
"""Test storing agent JWT token."""
agent_id = "agent_test123"
token = "test_jwt_token"
expires_at = (datetime.now() + timedelta(hours=1)).isoformat()
success = self.db.store_agent_token(agent_id, token, expires_at)
self.assertTrue(success)
# Verify token exists
stored_token = self.db.get_agent_token(agent_id)
self.assertIsNotNone(stored_token)
self.assertEqual(stored_token['token'], token)
def test_cleanup_expired_tokens(self):
"""Test cleanup of expired tokens."""
agent_id = "agent_test123"
# Create expired token
expired_token = "expired_token"
expired_time = (datetime.now() - timedelta(hours=1)).isoformat()
self.db.store_agent_token(agent_id, expired_token, expired_time)
# Create valid token
valid_token = "valid_token"
valid_time = (datetime.now() + timedelta(hours=1)).isoformat()
self.db.store_agent_token("agent_valid", valid_token, valid_time)
# Cleanup expired tokens
cleaned = self.db.cleanup_expired_tokens()
self.assertGreaterEqual(cleaned, 1)
# Verify expired token is gone
token = self.db.get_agent_token(agent_id)
self.assertIsNone(token)
# Verify valid token remains
token = self.db.get_agent_token("agent_valid")
self.assertIsNotNone(token)
class TestIntegration(unittest.TestCase):
"""Integration tests for the complete authentication flow."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.auth = AgentAuthentication()
# Use test database
self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_complete_authentication_flow(self):
"""Test complete agent authentication workflow."""
# Step 1: Generate agent ID
agent_id = self.auth.generate_agent_id()
self.assertIsNotNone(agent_id)
# Step 2: Create credentials
credentials = self.auth.create_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
# Step 3: Generate JWT token
token = self.auth.generate_jwt_token(
credentials['agent_id'],
credentials['secret_key']
)
self.assertIsNotNone(token)
# Step 4: Verify token
is_valid = self.auth.verify_jwt_token(token, credentials['secret_key'])
self.assertTrue(is_valid)
# Step 5: Create HMAC signature
test_data = "test API request"
signature = self.auth.create_hmac_signature(test_data, credentials['secret_key'])
self.assertIsNotNone(signature)
# Step 6: Verify HMAC signature
is_signature_valid = self.auth.verify_hmac_signature(
test_data, signature, credentials['secret_key']
)
self.assertTrue(is_signature_valid)
def run_tests():
"""Run all tests."""
print("🧪 Running PyGuardian Authentication Tests...")
print("=" * 50)
# Create test suite
test_suite = unittest.TestSuite()
# Add test classes
test_classes = [
TestAgentAuthentication,
TestDatabase,
TestIntegration
]
for test_class in test_classes:
tests = unittest.TestLoader().loadTestsFromTestCase(test_class)
test_suite.addTests(tests)
# Run tests
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(test_suite)
# Print summary
print("\n" + "=" * 50)
print(f"🏁 Tests completed:")
print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}")
print(f" ❌ Failed: {len(result.failures)}")
print(f" 💥 Errors: {len(result.errors)}")
# Return exit code
return 0 if result.wasSuccessful() else 1
if __name__ == '__main__':
sys.exit(run_tests())

View File

@@ -0,0 +1,422 @@
#!/usr/bin/env python3
"""
Comprehensive unit tests for PyGuardian authentication system.
"""
import unittest
import tempfile
import os
import sys
import sqlite3
import jwt
import hashlib
import hmac
from datetime import datetime, timedelta
from unittest.mock import Mock, patch, MagicMock
# Add src directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))
from auth import AgentAuthentication
from storage import Storage
class TestAgentAuthentication(unittest.TestCase):
"""Test cases for agent authentication system."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.test_secret = 'test_secret_key_123'
self.auth = AgentAuthentication(self.test_secret)
# Create test database
self.db = Storage(self.db_path)
self.db.create_tables()
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_generate_agent_id(self):
"""Test agent ID generation."""
agent_id = self.auth.generate_agent_id()
# Check format
self.assertTrue(agent_id.startswith('agent_'))
self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID
# Test uniqueness
agent_id2 = self.auth.generate_agent_id()
self.assertNotEqual(agent_id, agent_id2)
def test_create_agent_credentials(self):
"""Test agent credentials creation."""
agent_id = self.auth.generate_agent_id()
credentials = self.auth.create_agent_credentials(agent_id)
# Check required fields
required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash']
for field in required_fields:
self.assertIn(field, credentials)
# Check agent ID matches
self.assertEqual(credentials['agent_id'], agent_id)
# Check secret key length
self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded
# Check key hash
expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest()
self.assertEqual(credentials['key_hash'], expected_hash)
def test_generate_jwt_token(self):
"""Test JWT token generation."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
token = self.auth.generate_jwt_token(agent_id, secret_key)
# Verify token structure
self.assertIsInstance(token, str)
self.assertTrue(len(token) > 100) # JWT tokens are typically long
# Decode and verify payload
decoded = jwt.decode(token, secret_key, algorithms=['HS256'])
self.assertEqual(decoded['agent_id'], agent_id)
self.assertIn('iat', decoded)
self.assertIn('exp', decoded)
self.assertIn('jti', decoded)
def test_verify_jwt_token_valid(self):
"""Test JWT token verification with valid token."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
token = self.auth.generate_jwt_token(agent_id, secret_key)
is_valid = self.auth.verify_jwt_token(token, secret_key)
self.assertTrue(is_valid)
def test_verify_jwt_token_invalid(self):
"""Test JWT token verification with invalid token."""
secret_key = self.auth._generate_secret_key()
# Test with invalid token
is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key)
self.assertFalse(is_valid)
# Test with wrong secret key
agent_id = self.auth.generate_agent_id()
token = self.auth.generate_jwt_token(agent_id, secret_key)
wrong_key = self.auth._generate_secret_key()
is_valid = self.auth.verify_jwt_token(token, wrong_key)
self.assertFalse(is_valid)
def test_verify_jwt_token_expired(self):
"""Test JWT token verification with expired token."""
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
# Create expired token
payload = {
'agent_id': agent_id,
'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago
'iat': datetime.utcnow() - timedelta(hours=2),
'jti': self.auth._generate_jti()
}
expired_token = jwt.encode(payload, secret_key, algorithm='HS256')
is_valid = self.auth.verify_jwt_token(expired_token, secret_key)
self.assertFalse(is_valid)
def test_create_hmac_signature(self):
"""Test HMAC signature creation."""
data = "test message"
secret_key = self.auth._generate_secret_key()
signature = self.auth.create_hmac_signature(data, secret_key)
# Verify signature format
self.assertEqual(len(signature), 64) # SHA256 hex digest
# Verify signature is correct
expected = hmac.new(
secret_key.encode(),
data.encode(),
hashlib.sha256
).hexdigest()
self.assertEqual(signature, expected)
def test_verify_hmac_signature_valid(self):
"""Test HMAC signature verification with valid signature."""
data = "test message"
secret_key = self.auth._generate_secret_key()
signature = self.auth.create_hmac_signature(data, secret_key)
is_valid = self.auth.verify_hmac_signature(data, signature, secret_key)
self.assertTrue(is_valid)
def test_verify_hmac_signature_invalid(self):
"""Test HMAC signature verification with invalid signature."""
data = "test message"
secret_key = self.auth._generate_secret_key()
# Test with wrong signature
wrong_signature = "0" * 64
is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key)
self.assertFalse(is_valid)
# Test with wrong key
signature = self.auth.create_hmac_signature(data, secret_key)
wrong_key = self.auth._generate_secret_key()
is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key)
self.assertFalse(is_valid)
def test_encrypt_decrypt_secret_key(self):
"""Test secret key encryption and decryption."""
secret_key = self.auth._generate_secret_key()
password = "test_password"
encrypted = self.auth.encrypt_secret_key(secret_key, password)
decrypted = self.auth.decrypt_secret_key(encrypted, password)
self.assertEqual(secret_key, decrypted)
def test_encrypt_decrypt_wrong_password(self):
"""Test secret key decryption with wrong password."""
secret_key = self.auth._generate_secret_key()
password = "test_password"
wrong_password = "wrong_password"
encrypted = self.auth.encrypt_secret_key(secret_key, password)
with self.assertRaises(Exception):
self.auth.decrypt_secret_key(encrypted, wrong_password)
@patch('src.auth.Database')
def test_authenticate_agent_success(self, mock_db_class):
"""Test successful agent authentication."""
# Mock database
mock_db = Mock()
mock_db_class.return_value = mock_db
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
key_hash = hashlib.sha256(secret_key.encode()).hexdigest()
# Mock database response
mock_db.get_agent_credentials.return_value = {
'agent_id': agent_id,
'key_hash': key_hash,
'is_active': True,
'created_at': datetime.now().isoformat()
}
result = self.auth.authenticate_agent(agent_id, secret_key)
self.assertTrue(result)
@patch('src.auth.Database')
def test_authenticate_agent_failure(self, mock_db_class):
"""Test failed agent authentication."""
# Mock database
mock_db = Mock()
mock_db_class.return_value = mock_db
agent_id = self.auth.generate_agent_id()
secret_key = self.auth._generate_secret_key()
# Mock database response - no credentials found
mock_db.get_agent_credentials.return_value = None
result = self.auth.authenticate_agent(agent_id, secret_key)
self.assertFalse(result)
class TestDatabase(unittest.TestCase):
"""Test cases for database operations."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.db = Storage(self.db_path)
self.db.create_tables()
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_create_agent_auth(self):
"""Test agent authentication record creation."""
agent_id = "agent_test123"
secret_key_hash = "test_hash"
encrypted_key = "encrypted_test_key"
success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
self.assertTrue(success)
# Verify record exists
credentials = self.db.get_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
self.assertEqual(credentials['agent_id'], agent_id)
self.assertEqual(credentials['key_hash'], secret_key_hash)
def test_get_agent_credentials_exists(self):
"""Test retrieving existing agent credentials."""
agent_id = "agent_test123"
secret_key_hash = "test_hash"
encrypted_key = "encrypted_test_key"
# Create record
self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
# Retrieve record
credentials = self.db.get_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
self.assertEqual(credentials['agent_id'], agent_id)
self.assertEqual(credentials['key_hash'], secret_key_hash)
self.assertTrue(credentials['is_active'])
def test_get_agent_credentials_not_exists(self):
"""Test retrieving non-existent agent credentials."""
credentials = self.db.get_agent_credentials("non_existent_agent")
self.assertIsNone(credentials)
def test_store_agent_token(self):
"""Test storing agent JWT token."""
agent_id = "agent_test123"
token = "test_jwt_token"
expires_at = (datetime.now() + timedelta(hours=1)).isoformat()
success = self.db.store_agent_token(agent_id, token, expires_at)
self.assertTrue(success)
# Verify token exists
stored_token = self.db.get_agent_token(agent_id)
self.assertIsNotNone(stored_token)
self.assertEqual(stored_token['token'], token)
def test_cleanup_expired_tokens(self):
"""Test cleanup of expired tokens."""
agent_id = "agent_test123"
# Create expired token
expired_token = "expired_token"
expired_time = (datetime.now() - timedelta(hours=1)).isoformat()
self.db.store_agent_token(agent_id, expired_token, expired_time)
# Create valid token
valid_token = "valid_token"
valid_time = (datetime.now() + timedelta(hours=1)).isoformat()
self.db.store_agent_token("agent_valid", valid_token, valid_time)
# Cleanup expired tokens
cleaned = self.db.cleanup_expired_tokens()
self.assertGreaterEqual(cleaned, 1)
# Verify expired token is gone
token = self.db.get_agent_token(agent_id)
self.assertIsNone(token)
# Verify valid token remains
token = self.db.get_agent_token("agent_valid")
self.assertIsNotNone(token)
class TestIntegration(unittest.TestCase):
"""Integration tests for the complete authentication flow."""
def setUp(self):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.auth = AgentAuthentication()
# Use test database
self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None
def tearDown(self):
"""Clean up test fixtures."""
if os.path.exists(self.db_path):
os.remove(self.db_path)
os.rmdir(self.temp_dir)
def test_complete_authentication_flow(self):
"""Test complete agent authentication workflow."""
# Step 1: Generate agent ID
agent_id = self.auth.generate_agent_id()
self.assertIsNotNone(agent_id)
# Step 2: Create credentials
credentials = self.auth.create_agent_credentials(agent_id)
self.assertIsNotNone(credentials)
# Step 3: Generate JWT token
token = self.auth.generate_jwt_token(
credentials['agent_id'],
credentials['secret_key']
)
self.assertIsNotNone(token)
# Step 4: Verify token
is_valid = self.auth.verify_jwt_token(token, credentials['secret_key'])
self.assertTrue(is_valid)
# Step 5: Create HMAC signature
test_data = "test API request"
signature = self.auth.create_hmac_signature(test_data, credentials['secret_key'])
self.assertIsNotNone(signature)
# Step 6: Verify HMAC signature
is_signature_valid = self.auth.verify_hmac_signature(
test_data, signature, credentials['secret_key']
)
self.assertTrue(is_signature_valid)
def run_tests():
"""Run all tests."""
print("🧪 Running PyGuardian Authentication Tests...")
print("=" * 50)
# Create test suite
test_suite = unittest.TestSuite()
# Add test classes
test_classes = [
TestAgentAuthentication,
TestDatabase,
TestIntegration
]
for test_class in test_classes:
tests = unittest.TestLoader().loadTestsFromTestCase(test_class)
test_suite.addTests(tests)
# Run tests
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(test_suite)
# Print summary
print("\n" + "=" * 50)
print(f"🏁 Tests completed:")
print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}")
print(f" ❌ Failed: {len(result.failures)}")
print(f" 💥 Errors: {len(result.errors)}")
# Return exit code
return 0 if result.wasSuccessful() else 1
if __name__ == '__main__':
sys.exit(run_tests())

View File

@@ -7,7 +7,7 @@ import asyncio
import logging
import re
import os
from datetime import datetime
from datetime import datetime, timedelta
from typing import Dict, List, Optional
import psutil

View File

@@ -18,7 +18,7 @@ from unittest.mock import Mock, patch, MagicMock
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))
from auth import AgentAuthentication
from storage import Database
from storage import Storage
class TestAgentAuthentication(unittest.TestCase):
@@ -28,10 +28,11 @@ class TestAgentAuthentication(unittest.TestCase):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.auth = AgentAuthentication()
self.test_secret = 'test_secret_key_123'
self.auth = AgentAuthentication(self.test_secret)
# Create test database
self.db = Database(self.db_path)
self.db = Storage(self.db_path)
self.db.create_tables()
def tearDown(self):
@@ -245,7 +246,7 @@ class TestDatabase(unittest.TestCase):
"""Set up test fixtures."""
self.temp_dir = tempfile.mkdtemp()
self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
self.db = Database(self.db_path)
self.db = Storage(self.db_path)
self.db.create_tables()
def tearDown(self):