diff --git a/.history/src/sessions_20251125213303.py b/.history/src/sessions_20251125213303.py new file mode 100644 index 0000000..6458cb1 --- /dev/null +++ b/.history/src/sessions_20251125213303.py @@ -0,0 +1,488 @@ +""" +Sessions module для PyGuardian +Управление SSH сессиями и процессами пользователей +""" + +import asyncio +import logging +import re +import os +from datetime import datetime, timedelta +from typing import Dict, List, Optional +import psutil + +logger = logging.getLogger(__name__) + + +class SessionManager: + """Менеджер SSH сессий и пользовательских процессов""" + + def __init__(self): + pass + + async def get_active_sessions(self) -> List[Dict]: + """Получение всех активных SSH сессий""" + try: + sessions = [] + + # Метод 1: через who + who_sessions = await self._get_sessions_via_who() + sessions.extend(who_sessions) + + # Метод 2: через ps (для SSH процессов) + ssh_sessions = await self._get_sessions_via_ps() + sessions.extend(ssh_sessions) + + # Убираем дубликаты и объединяем информацию + unique_sessions = self._merge_session_info(sessions) + + return unique_sessions + + except Exception as e: + logger.error(f"Ошибка получения активных сессий: {e}") + return [] + + async def _get_sessions_via_who(self) -> List[Dict]: + """Получение сессий через команду who""" + try: + sessions = [] + + process = await asyncio.create_subprocess_exec( + 'who', '-u', + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + stdout, stderr = await process.communicate() + + if process.returncode == 0: + lines = stdout.decode().strip().split('\n') + for line in lines: + if line.strip(): + # Парсим вывод who + # Формат: user tty date time (idle) pid (comment) + match = re.match( + r'(\w+)\s+(\w+)\s+(\d{4}-\d{2}-\d{2})\s+(\d{2}:\d{2})\s+.*?\((\d+)\)', + line + ) + if match: + username, tty, date, time, pid = match.groups() + sessions.append({ + 'username': username, + 'tty': tty, + 'login_date': date, + 'login_time': time, + 'pid': int(pid), + 'type': 'who', + 'status': 'active' + }) + else: + # Альтернативный парсинг для разных форматов who + parts = line.split() + if len(parts) >= 2: + username = parts[0] + tty = parts[1] + + # Ищем PID в скобках + pid_match = re.search(r'\((\d+)\)', line) + pid = int(pid_match.group(1)) if pid_match else None + + sessions.append({ + 'username': username, + 'tty': tty, + 'pid': pid, + 'type': 'who', + 'status': 'active', + 'raw_line': line + }) + + return sessions + + except Exception as e: + logger.error(f"Ошибка получения сессий через who: {e}") + return [] + + async def _get_sessions_via_ps(self) -> List[Dict]: + """Получение SSH сессий через ps""" + try: + sessions = [] + + # Ищем SSH процессы + process = await asyncio.create_subprocess_exec( + 'ps', 'aux', + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + stdout, stderr = await process.communicate() + + if process.returncode == 0: + lines = stdout.decode().strip().split('\n') + for line in lines[1:]: # Пропускаем заголовок + if 'sshd:' in line and '@pts' in line: + # Парсим SSH сессии + parts = line.split() + if len(parts) >= 11: + username = parts[0] + pid = int(parts[1]) + + # Извлекаем информацию из команды + cmd_parts = ' '.join(parts[10:]) + + # Ищем пользователя и tty в команде sshd + match = re.search(r'sshd:\s+(\w+)@(\w+)', cmd_parts) + if match: + ssh_user, tty = match.groups() + + sessions.append({ + 'username': ssh_user, + 'tty': tty, + 'pid': pid, + 'ppid': int(parts[2]), + 'cpu': parts[2], + 'mem': parts[3], + 'start_time': parts[8], + 'type': 'sshd', + 'status': 'active', + 'command': cmd_parts + }) + + return sessions + + except Exception as e: + logger.error(f"Ошибка получения SSH сессий через ps: {e}") + return [] + + def _merge_session_info(self, sessions: List[Dict]) -> List[Dict]: + """Объединение информации о сессиях и удаление дубликатов""" + try: + merged = {} + + for session in sessions: + key = f"{session['username']}:{session.get('tty', 'unknown')}" + + if key in merged: + # Обновляем существующую запись дополнительной информацией + merged[key].update({k: v for k, v in session.items() if v is not None}) + else: + merged[key] = session.copy() + + # Добавляем дополнительную информацию о процессах + for session in merged.values(): + if session.get('pid'): + try: + # Получаем дополнительную информацию о процессе через psutil + if psutil.pid_exists(session['pid']): + proc = psutil.Process(session['pid']) + session.update({ + 'create_time': datetime.fromtimestamp(proc.create_time()).isoformat(), + 'cpu_percent': proc.cpu_percent(), + 'memory_info': proc.memory_info()._asdict(), + 'connections': len(proc.connections()) + }) + except Exception: + pass # Игнорируем ошибки получения доп. информации + + return list(merged.values()) + + except Exception as e: + logger.error(f"Ошибка объединения информации о сессиях: {e}") + return sessions + + async def get_user_sessions(self, username: str) -> List[Dict]: + """Получение сессий конкретного пользователя""" + try: + all_sessions = await self.get_active_sessions() + return [s for s in all_sessions if s['username'] == username] + + except Exception as e: + logger.error(f"Ошибка получения сессий пользователя {username}: {e}") + return [] + + async def terminate_session(self, pid: int) -> bool: + """Завершение сессии по PID""" + try: + # Сначала пробуем TERM + process = await asyncio.create_subprocess_exec( + 'kill', '-TERM', str(pid), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + await process.communicate() + + if process.returncode == 0: + # Ждем немного и проверяем + await asyncio.sleep(2) + + if not psutil.pid_exists(pid): + logger.info(f"✅ Сессия PID {pid} завершена через TERM") + return True + else: + # Если не помогло - используем KILL + process = await asyncio.create_subprocess_exec( + 'kill', '-KILL', str(pid), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + await process.communicate() + + if process.returncode == 0: + logger.info(f"🔪 Сессия PID {pid} принудительно завершена через KILL") + return True + + logger.error(f"❌ Не удалось завершить сессию PID {pid}") + return False + + except Exception as e: + logger.error(f"Ошибка завершения сессии PID {pid}: {e}") + return False + + async def terminate_user_sessions(self, username: str) -> int: + """Завершение всех сессий пользователя""" + try: + user_sessions = await self.get_user_sessions(username) + terminated = 0 + + for session in user_sessions: + pid = session.get('pid') + if pid: + success = await self.terminate_session(pid) + if success: + terminated += 1 + + logger.info(f"Завершено {terminated} из {len(user_sessions)} сессий пользователя {username}") + return terminated + + except Exception as e: + logger.error(f"Ошибка завершения сессий пользователя {username}: {e}") + return 0 + + async def get_session_details(self, pid: int) -> Optional[Dict]: + """Получение детальной информации о сессии""" + try: + if not psutil.pid_exists(pid): + return None + + proc = psutil.Process(pid) + + # Базовая информация о процессе + details = { + 'pid': pid, + 'ppid': proc.ppid(), + 'username': proc.username(), + 'create_time': datetime.fromtimestamp(proc.create_time()).isoformat(), + 'cpu_percent': proc.cpu_percent(), + 'memory_info': proc.memory_info()._asdict(), + 'status': proc.status(), + 'cmdline': proc.cmdline(), + 'cwd': proc.cwd(), + 'exe': proc.exe() + } + + # Сетевые соединения + try: + connections = [] + for conn in proc.connections(): + connections.append({ + 'fd': conn.fd, + 'family': str(conn.family), + 'type': str(conn.type), + 'local_address': f"{conn.laddr.ip}:{conn.laddr.port}" if conn.laddr else None, + 'remote_address': f"{conn.raddr.ip}:{conn.raddr.port}" if conn.raddr else None, + 'status': str(conn.status) + }) + details['connections'] = connections + except Exception: + details['connections'] = [] + + # Открытые файлы + try: + open_files = [] + for file in proc.open_files()[:10]: # Ограничиваем 10 файлами + open_files.append({ + 'path': file.path, + 'fd': file.fd, + 'mode': file.mode + }) + details['open_files'] = open_files + except Exception: + details['open_files'] = [] + + # Переменные окружения (выборочно) + try: + env = proc.environ() + safe_env = {} + safe_keys = ['USER', 'HOME', 'SHELL', 'SSH_CLIENT', 'SSH_CONNECTION', 'TERM'] + for key in safe_keys: + if key in env: + safe_env[key] = env[key] + details['environment'] = safe_env + except Exception: + details['environment'] = {} + + return details + + except Exception as e: + logger.error(f"Ошибка получения деталей сессии PID {pid}: {e}") + return None + + async def monitor_session_activity(self, pid: int, duration: int = 60) -> List[Dict]: + """Мониторинг активности сессии в течение времени""" + try: + if not psutil.pid_exists(pid): + return [] + + activity_log = [] + proc = psutil.Process(pid) + + start_time = datetime.now() + end_time = start_time + timedelta(seconds=duration) + + while datetime.now() < end_time: + try: + # Снимок состояния процесса + snapshot = { + 'timestamp': datetime.now().isoformat(), + 'cpu_percent': proc.cpu_percent(), + 'memory_percent': proc.memory_percent(), + 'num_threads': proc.num_threads(), + 'num_fds': proc.num_fds(), + 'status': proc.status() + } + + # Проверяем новые соединения + try: + connections = len(proc.connections()) + snapshot['connections_count'] = connections + except Exception: + snapshot['connections_count'] = 0 + + activity_log.append(snapshot) + + await asyncio.sleep(5) # Снимок каждые 5 секунд + + except psutil.NoSuchProcess: + # Процесс завершился + activity_log.append({ + 'timestamp': datetime.now().isoformat(), + 'event': 'process_terminated' + }) + break + except Exception as e: + activity_log.append({ + 'timestamp': datetime.now().isoformat(), + 'event': 'monitoring_error', + 'error': str(e) + }) + + return activity_log + + except Exception as e: + logger.error(f"Ошибка мониторинга активности сессии PID {pid}: {e}") + return [] + + async def get_session_statistics(self) -> Dict: + """Получение общей статистики по сессиям""" + try: + sessions = await self.get_active_sessions() + + stats = { + 'total_sessions': len(sessions), + 'users': {}, + 'tty_types': {}, + 'session_ages': [], + 'total_connections': 0 + } + + for session in sessions: + # Статистика по пользователям + user = session['username'] + if user not in stats['users']: + stats['users'][user] = 0 + stats['users'][user] += 1 + + # Статистика по типам TTY + tty = session.get('tty', 'unknown') + tty_type = 'console' if tty.startswith('tty') else 'ssh' + if tty_type not in stats['tty_types']: + stats['tty_types'][tty_type] = 0 + stats['tty_types'][tty_type] += 1 + + # Возраст сессии + if 'create_time' in session: + try: + create_time = datetime.fromisoformat(session['create_time']) + age_seconds = (datetime.now() - create_time).total_seconds() + stats['session_ages'].append(age_seconds) + except Exception: + pass + + # Количество соединений + connections = session.get('connections', 0) + if isinstance(connections, int): + stats['total_connections'] += connections + + # Средний возраст сессий + if stats['session_ages']: + stats['average_session_age'] = sum(stats['session_ages']) / len(stats['session_ages']) + else: + stats['average_session_age'] = 0 + + return stats + + except Exception as e: + logger.error(f"Ошибка получения статистики сессий: {e}") + return {'error': str(e)} + + async def find_suspicious_sessions(self) -> List[Dict]: + """Поиск подозрительных сессий""" + try: + sessions = await self.get_active_sessions() + suspicious = [] + + for session in sessions: + suspicion_score = 0 + reasons = [] + + # Проверка 1: Много открытых соединений + connections = session.get('connections', 0) + if isinstance(connections, int) and connections > 10: + suspicion_score += 2 + reasons.append(f"Много соединений: {connections}") + + # Проверка 2: Высокое потребление CPU + cpu = session.get('cpu_percent', 0) + if isinstance(cpu, (int, float)) and cpu > 50: + suspicion_score += 1 + reasons.append(f"Высокая нагрузка CPU: {cpu}%") + + # Проверка 3: Долго активная сессия + if 'create_time' in session: + try: + create_time = datetime.fromisoformat(session['create_time']) + age_hours = (datetime.now() - create_time).total_seconds() / 3600 + if age_hours > 24: # Больше суток + suspicion_score += 1 + reasons.append(f"Долгая сессия: {age_hours:.1f} часов") + except Exception: + pass + + # Проверка 4: Подозрительные команды в cmdline + cmdline = session.get('cmdline', []) + if isinstance(cmdline, list): + suspicious_commands = ['nc', 'netcat', 'wget', 'curl', 'python', 'perl', 'bash'] + for cmd in cmdline: + if any(susp in cmd.lower() for susp in suspicious_commands): + suspicion_score += 1 + reasons.append(f"Подозрительная команда: {cmd}") + break + + # Если набрали достаточно очков подозрительности + if suspicion_score >= 2: + session['suspicion_score'] = suspicion_score + session['suspicion_reasons'] = reasons + suspicious.append(session) + + return suspicious + + except Exception as e: + logger.error(f"Ошибка поиска подозрительных сессий: {e}") + return [] \ No newline at end of file diff --git a/.history/src/sessions_20251125213308.py b/.history/src/sessions_20251125213308.py new file mode 100644 index 0000000..6458cb1 --- /dev/null +++ b/.history/src/sessions_20251125213308.py @@ -0,0 +1,488 @@ +""" +Sessions module для PyGuardian +Управление SSH сессиями и процессами пользователей +""" + +import asyncio +import logging +import re +import os +from datetime import datetime, timedelta +from typing import Dict, List, Optional +import psutil + +logger = logging.getLogger(__name__) + + +class SessionManager: + """Менеджер SSH сессий и пользовательских процессов""" + + def __init__(self): + pass + + async def get_active_sessions(self) -> List[Dict]: + """Получение всех активных SSH сессий""" + try: + sessions = [] + + # Метод 1: через who + who_sessions = await self._get_sessions_via_who() + sessions.extend(who_sessions) + + # Метод 2: через ps (для SSH процессов) + ssh_sessions = await self._get_sessions_via_ps() + sessions.extend(ssh_sessions) + + # Убираем дубликаты и объединяем информацию + unique_sessions = self._merge_session_info(sessions) + + return unique_sessions + + except Exception as e: + logger.error(f"Ошибка получения активных сессий: {e}") + return [] + + async def _get_sessions_via_who(self) -> List[Dict]: + """Получение сессий через команду who""" + try: + sessions = [] + + process = await asyncio.create_subprocess_exec( + 'who', '-u', + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + stdout, stderr = await process.communicate() + + if process.returncode == 0: + lines = stdout.decode().strip().split('\n') + for line in lines: + if line.strip(): + # Парсим вывод who + # Формат: user tty date time (idle) pid (comment) + match = re.match( + r'(\w+)\s+(\w+)\s+(\d{4}-\d{2}-\d{2})\s+(\d{2}:\d{2})\s+.*?\((\d+)\)', + line + ) + if match: + username, tty, date, time, pid = match.groups() + sessions.append({ + 'username': username, + 'tty': tty, + 'login_date': date, + 'login_time': time, + 'pid': int(pid), + 'type': 'who', + 'status': 'active' + }) + else: + # Альтернативный парсинг для разных форматов who + parts = line.split() + if len(parts) >= 2: + username = parts[0] + tty = parts[1] + + # Ищем PID в скобках + pid_match = re.search(r'\((\d+)\)', line) + pid = int(pid_match.group(1)) if pid_match else None + + sessions.append({ + 'username': username, + 'tty': tty, + 'pid': pid, + 'type': 'who', + 'status': 'active', + 'raw_line': line + }) + + return sessions + + except Exception as e: + logger.error(f"Ошибка получения сессий через who: {e}") + return [] + + async def _get_sessions_via_ps(self) -> List[Dict]: + """Получение SSH сессий через ps""" + try: + sessions = [] + + # Ищем SSH процессы + process = await asyncio.create_subprocess_exec( + 'ps', 'aux', + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + stdout, stderr = await process.communicate() + + if process.returncode == 0: + lines = stdout.decode().strip().split('\n') + for line in lines[1:]: # Пропускаем заголовок + if 'sshd:' in line and '@pts' in line: + # Парсим SSH сессии + parts = line.split() + if len(parts) >= 11: + username = parts[0] + pid = int(parts[1]) + + # Извлекаем информацию из команды + cmd_parts = ' '.join(parts[10:]) + + # Ищем пользователя и tty в команде sshd + match = re.search(r'sshd:\s+(\w+)@(\w+)', cmd_parts) + if match: + ssh_user, tty = match.groups() + + sessions.append({ + 'username': ssh_user, + 'tty': tty, + 'pid': pid, + 'ppid': int(parts[2]), + 'cpu': parts[2], + 'mem': parts[3], + 'start_time': parts[8], + 'type': 'sshd', + 'status': 'active', + 'command': cmd_parts + }) + + return sessions + + except Exception as e: + logger.error(f"Ошибка получения SSH сессий через ps: {e}") + return [] + + def _merge_session_info(self, sessions: List[Dict]) -> List[Dict]: + """Объединение информации о сессиях и удаление дубликатов""" + try: + merged = {} + + for session in sessions: + key = f"{session['username']}:{session.get('tty', 'unknown')}" + + if key in merged: + # Обновляем существующую запись дополнительной информацией + merged[key].update({k: v for k, v in session.items() if v is not None}) + else: + merged[key] = session.copy() + + # Добавляем дополнительную информацию о процессах + for session in merged.values(): + if session.get('pid'): + try: + # Получаем дополнительную информацию о процессе через psutil + if psutil.pid_exists(session['pid']): + proc = psutil.Process(session['pid']) + session.update({ + 'create_time': datetime.fromtimestamp(proc.create_time()).isoformat(), + 'cpu_percent': proc.cpu_percent(), + 'memory_info': proc.memory_info()._asdict(), + 'connections': len(proc.connections()) + }) + except Exception: + pass # Игнорируем ошибки получения доп. информации + + return list(merged.values()) + + except Exception as e: + logger.error(f"Ошибка объединения информации о сессиях: {e}") + return sessions + + async def get_user_sessions(self, username: str) -> List[Dict]: + """Получение сессий конкретного пользователя""" + try: + all_sessions = await self.get_active_sessions() + return [s for s in all_sessions if s['username'] == username] + + except Exception as e: + logger.error(f"Ошибка получения сессий пользователя {username}: {e}") + return [] + + async def terminate_session(self, pid: int) -> bool: + """Завершение сессии по PID""" + try: + # Сначала пробуем TERM + process = await asyncio.create_subprocess_exec( + 'kill', '-TERM', str(pid), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + await process.communicate() + + if process.returncode == 0: + # Ждем немного и проверяем + await asyncio.sleep(2) + + if not psutil.pid_exists(pid): + logger.info(f"✅ Сессия PID {pid} завершена через TERM") + return True + else: + # Если не помогло - используем KILL + process = await asyncio.create_subprocess_exec( + 'kill', '-KILL', str(pid), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE + ) + await process.communicate() + + if process.returncode == 0: + logger.info(f"🔪 Сессия PID {pid} принудительно завершена через KILL") + return True + + logger.error(f"❌ Не удалось завершить сессию PID {pid}") + return False + + except Exception as e: + logger.error(f"Ошибка завершения сессии PID {pid}: {e}") + return False + + async def terminate_user_sessions(self, username: str) -> int: + """Завершение всех сессий пользователя""" + try: + user_sessions = await self.get_user_sessions(username) + terminated = 0 + + for session in user_sessions: + pid = session.get('pid') + if pid: + success = await self.terminate_session(pid) + if success: + terminated += 1 + + logger.info(f"Завершено {terminated} из {len(user_sessions)} сессий пользователя {username}") + return terminated + + except Exception as e: + logger.error(f"Ошибка завершения сессий пользователя {username}: {e}") + return 0 + + async def get_session_details(self, pid: int) -> Optional[Dict]: + """Получение детальной информации о сессии""" + try: + if not psutil.pid_exists(pid): + return None + + proc = psutil.Process(pid) + + # Базовая информация о процессе + details = { + 'pid': pid, + 'ppid': proc.ppid(), + 'username': proc.username(), + 'create_time': datetime.fromtimestamp(proc.create_time()).isoformat(), + 'cpu_percent': proc.cpu_percent(), + 'memory_info': proc.memory_info()._asdict(), + 'status': proc.status(), + 'cmdline': proc.cmdline(), + 'cwd': proc.cwd(), + 'exe': proc.exe() + } + + # Сетевые соединения + try: + connections = [] + for conn in proc.connections(): + connections.append({ + 'fd': conn.fd, + 'family': str(conn.family), + 'type': str(conn.type), + 'local_address': f"{conn.laddr.ip}:{conn.laddr.port}" if conn.laddr else None, + 'remote_address': f"{conn.raddr.ip}:{conn.raddr.port}" if conn.raddr else None, + 'status': str(conn.status) + }) + details['connections'] = connections + except Exception: + details['connections'] = [] + + # Открытые файлы + try: + open_files = [] + for file in proc.open_files()[:10]: # Ограничиваем 10 файлами + open_files.append({ + 'path': file.path, + 'fd': file.fd, + 'mode': file.mode + }) + details['open_files'] = open_files + except Exception: + details['open_files'] = [] + + # Переменные окружения (выборочно) + try: + env = proc.environ() + safe_env = {} + safe_keys = ['USER', 'HOME', 'SHELL', 'SSH_CLIENT', 'SSH_CONNECTION', 'TERM'] + for key in safe_keys: + if key in env: + safe_env[key] = env[key] + details['environment'] = safe_env + except Exception: + details['environment'] = {} + + return details + + except Exception as e: + logger.error(f"Ошибка получения деталей сессии PID {pid}: {e}") + return None + + async def monitor_session_activity(self, pid: int, duration: int = 60) -> List[Dict]: + """Мониторинг активности сессии в течение времени""" + try: + if not psutil.pid_exists(pid): + return [] + + activity_log = [] + proc = psutil.Process(pid) + + start_time = datetime.now() + end_time = start_time + timedelta(seconds=duration) + + while datetime.now() < end_time: + try: + # Снимок состояния процесса + snapshot = { + 'timestamp': datetime.now().isoformat(), + 'cpu_percent': proc.cpu_percent(), + 'memory_percent': proc.memory_percent(), + 'num_threads': proc.num_threads(), + 'num_fds': proc.num_fds(), + 'status': proc.status() + } + + # Проверяем новые соединения + try: + connections = len(proc.connections()) + snapshot['connections_count'] = connections + except Exception: + snapshot['connections_count'] = 0 + + activity_log.append(snapshot) + + await asyncio.sleep(5) # Снимок каждые 5 секунд + + except psutil.NoSuchProcess: + # Процесс завершился + activity_log.append({ + 'timestamp': datetime.now().isoformat(), + 'event': 'process_terminated' + }) + break + except Exception as e: + activity_log.append({ + 'timestamp': datetime.now().isoformat(), + 'event': 'monitoring_error', + 'error': str(e) + }) + + return activity_log + + except Exception as e: + logger.error(f"Ошибка мониторинга активности сессии PID {pid}: {e}") + return [] + + async def get_session_statistics(self) -> Dict: + """Получение общей статистики по сессиям""" + try: + sessions = await self.get_active_sessions() + + stats = { + 'total_sessions': len(sessions), + 'users': {}, + 'tty_types': {}, + 'session_ages': [], + 'total_connections': 0 + } + + for session in sessions: + # Статистика по пользователям + user = session['username'] + if user not in stats['users']: + stats['users'][user] = 0 + stats['users'][user] += 1 + + # Статистика по типам TTY + tty = session.get('tty', 'unknown') + tty_type = 'console' if tty.startswith('tty') else 'ssh' + if tty_type not in stats['tty_types']: + stats['tty_types'][tty_type] = 0 + stats['tty_types'][tty_type] += 1 + + # Возраст сессии + if 'create_time' in session: + try: + create_time = datetime.fromisoformat(session['create_time']) + age_seconds = (datetime.now() - create_time).total_seconds() + stats['session_ages'].append(age_seconds) + except Exception: + pass + + # Количество соединений + connections = session.get('connections', 0) + if isinstance(connections, int): + stats['total_connections'] += connections + + # Средний возраст сессий + if stats['session_ages']: + stats['average_session_age'] = sum(stats['session_ages']) / len(stats['session_ages']) + else: + stats['average_session_age'] = 0 + + return stats + + except Exception as e: + logger.error(f"Ошибка получения статистики сессий: {e}") + return {'error': str(e)} + + async def find_suspicious_sessions(self) -> List[Dict]: + """Поиск подозрительных сессий""" + try: + sessions = await self.get_active_sessions() + suspicious = [] + + for session in sessions: + suspicion_score = 0 + reasons = [] + + # Проверка 1: Много открытых соединений + connections = session.get('connections', 0) + if isinstance(connections, int) and connections > 10: + suspicion_score += 2 + reasons.append(f"Много соединений: {connections}") + + # Проверка 2: Высокое потребление CPU + cpu = session.get('cpu_percent', 0) + if isinstance(cpu, (int, float)) and cpu > 50: + suspicion_score += 1 + reasons.append(f"Высокая нагрузка CPU: {cpu}%") + + # Проверка 3: Долго активная сессия + if 'create_time' in session: + try: + create_time = datetime.fromisoformat(session['create_time']) + age_hours = (datetime.now() - create_time).total_seconds() / 3600 + if age_hours > 24: # Больше суток + suspicion_score += 1 + reasons.append(f"Долгая сессия: {age_hours:.1f} часов") + except Exception: + pass + + # Проверка 4: Подозрительные команды в cmdline + cmdline = session.get('cmdline', []) + if isinstance(cmdline, list): + suspicious_commands = ['nc', 'netcat', 'wget', 'curl', 'python', 'perl', 'bash'] + for cmd in cmdline: + if any(susp in cmd.lower() for susp in suspicious_commands): + suspicion_score += 1 + reasons.append(f"Подозрительная команда: {cmd}") + break + + # Если набрали достаточно очков подозрительности + if suspicion_score >= 2: + session['suspicion_score'] = suspicion_score + session['suspicion_reasons'] = reasons + suspicious.append(session) + + return suspicious + + except Exception as e: + logger.error(f"Ошибка поиска подозрительных сессий: {e}") + return [] \ No newline at end of file diff --git a/.history/tests/unit/test_authentication_20251125212848.py b/.history/tests/unit/test_authentication_20251125212848.py new file mode 100644 index 0000000..176dd39 --- /dev/null +++ b/.history/tests/unit/test_authentication_20251125212848.py @@ -0,0 +1,421 @@ +#!/usr/bin/env python3 +""" +Comprehensive unit tests for PyGuardian authentication system. +""" + +import unittest +import tempfile +import os +import sys +import sqlite3 +import jwt +import hashlib +import hmac +from datetime import datetime, timedelta +from unittest.mock import Mock, patch, MagicMock + +# Add src directory to path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src')) + +from auth import AgentAuthentication +from storage import Storage + + +class TestAgentAuthentication(unittest.TestCase): + """Test cases for agent authentication system.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.auth = AgentAuthentication() + + # Create test database + self.db = Database(self.db_path) + self.db.create_tables() + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_generate_agent_id(self): + """Test agent ID generation.""" + agent_id = self.auth.generate_agent_id() + + # Check format + self.assertTrue(agent_id.startswith('agent_')) + self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID + + # Test uniqueness + agent_id2 = self.auth.generate_agent_id() + self.assertNotEqual(agent_id, agent_id2) + + def test_create_agent_credentials(self): + """Test agent credentials creation.""" + agent_id = self.auth.generate_agent_id() + credentials = self.auth.create_agent_credentials(agent_id) + + # Check required fields + required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash'] + for field in required_fields: + self.assertIn(field, credentials) + + # Check agent ID matches + self.assertEqual(credentials['agent_id'], agent_id) + + # Check secret key length + self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded + + # Check key hash + expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest() + self.assertEqual(credentials['key_hash'], expected_hash) + + def test_generate_jwt_token(self): + """Test JWT token generation.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + token = self.auth.generate_jwt_token(agent_id, secret_key) + + # Verify token structure + self.assertIsInstance(token, str) + self.assertTrue(len(token) > 100) # JWT tokens are typically long + + # Decode and verify payload + decoded = jwt.decode(token, secret_key, algorithms=['HS256']) + self.assertEqual(decoded['agent_id'], agent_id) + self.assertIn('iat', decoded) + self.assertIn('exp', decoded) + self.assertIn('jti', decoded) + + def test_verify_jwt_token_valid(self): + """Test JWT token verification with valid token.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + token = self.auth.generate_jwt_token(agent_id, secret_key) + + is_valid = self.auth.verify_jwt_token(token, secret_key) + self.assertTrue(is_valid) + + def test_verify_jwt_token_invalid(self): + """Test JWT token verification with invalid token.""" + secret_key = self.auth._generate_secret_key() + + # Test with invalid token + is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key) + self.assertFalse(is_valid) + + # Test with wrong secret key + agent_id = self.auth.generate_agent_id() + token = self.auth.generate_jwt_token(agent_id, secret_key) + wrong_key = self.auth._generate_secret_key() + + is_valid = self.auth.verify_jwt_token(token, wrong_key) + self.assertFalse(is_valid) + + def test_verify_jwt_token_expired(self): + """Test JWT token verification with expired token.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + # Create expired token + payload = { + 'agent_id': agent_id, + 'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago + 'iat': datetime.utcnow() - timedelta(hours=2), + 'jti': self.auth._generate_jti() + } + + expired_token = jwt.encode(payload, secret_key, algorithm='HS256') + + is_valid = self.auth.verify_jwt_token(expired_token, secret_key) + self.assertFalse(is_valid) + + def test_create_hmac_signature(self): + """Test HMAC signature creation.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + signature = self.auth.create_hmac_signature(data, secret_key) + + # Verify signature format + self.assertEqual(len(signature), 64) # SHA256 hex digest + + # Verify signature is correct + expected = hmac.new( + secret_key.encode(), + data.encode(), + hashlib.sha256 + ).hexdigest() + + self.assertEqual(signature, expected) + + def test_verify_hmac_signature_valid(self): + """Test HMAC signature verification with valid signature.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + signature = self.auth.create_hmac_signature(data, secret_key) + is_valid = self.auth.verify_hmac_signature(data, signature, secret_key) + + self.assertTrue(is_valid) + + def test_verify_hmac_signature_invalid(self): + """Test HMAC signature verification with invalid signature.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + # Test with wrong signature + wrong_signature = "0" * 64 + is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key) + self.assertFalse(is_valid) + + # Test with wrong key + signature = self.auth.create_hmac_signature(data, secret_key) + wrong_key = self.auth._generate_secret_key() + is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key) + self.assertFalse(is_valid) + + def test_encrypt_decrypt_secret_key(self): + """Test secret key encryption and decryption.""" + secret_key = self.auth._generate_secret_key() + password = "test_password" + + encrypted = self.auth.encrypt_secret_key(secret_key, password) + decrypted = self.auth.decrypt_secret_key(encrypted, password) + + self.assertEqual(secret_key, decrypted) + + def test_encrypt_decrypt_wrong_password(self): + """Test secret key decryption with wrong password.""" + secret_key = self.auth._generate_secret_key() + password = "test_password" + wrong_password = "wrong_password" + + encrypted = self.auth.encrypt_secret_key(secret_key, password) + + with self.assertRaises(Exception): + self.auth.decrypt_secret_key(encrypted, wrong_password) + + @patch('src.auth.Database') + def test_authenticate_agent_success(self, mock_db_class): + """Test successful agent authentication.""" + # Mock database + mock_db = Mock() + mock_db_class.return_value = mock_db + + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + key_hash = hashlib.sha256(secret_key.encode()).hexdigest() + + # Mock database response + mock_db.get_agent_credentials.return_value = { + 'agent_id': agent_id, + 'key_hash': key_hash, + 'is_active': True, + 'created_at': datetime.now().isoformat() + } + + result = self.auth.authenticate_agent(agent_id, secret_key) + self.assertTrue(result) + + @patch('src.auth.Database') + def test_authenticate_agent_failure(self, mock_db_class): + """Test failed agent authentication.""" + # Mock database + mock_db = Mock() + mock_db_class.return_value = mock_db + + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + # Mock database response - no credentials found + mock_db.get_agent_credentials.return_value = None + + result = self.auth.authenticate_agent(agent_id, secret_key) + self.assertFalse(result) + + +class TestDatabase(unittest.TestCase): + """Test cases for database operations.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.db = Database(self.db_path) + self.db.create_tables() + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_create_agent_auth(self): + """Test agent authentication record creation.""" + agent_id = "agent_test123" + secret_key_hash = "test_hash" + encrypted_key = "encrypted_test_key" + + success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key) + self.assertTrue(success) + + # Verify record exists + credentials = self.db.get_agent_credentials(agent_id) + self.assertIsNotNone(credentials) + self.assertEqual(credentials['agent_id'], agent_id) + self.assertEqual(credentials['key_hash'], secret_key_hash) + + def test_get_agent_credentials_exists(self): + """Test retrieving existing agent credentials.""" + agent_id = "agent_test123" + secret_key_hash = "test_hash" + encrypted_key = "encrypted_test_key" + + # Create record + self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key) + + # Retrieve record + credentials = self.db.get_agent_credentials(agent_id) + + self.assertIsNotNone(credentials) + self.assertEqual(credentials['agent_id'], agent_id) + self.assertEqual(credentials['key_hash'], secret_key_hash) + self.assertTrue(credentials['is_active']) + + def test_get_agent_credentials_not_exists(self): + """Test retrieving non-existent agent credentials.""" + credentials = self.db.get_agent_credentials("non_existent_agent") + self.assertIsNone(credentials) + + def test_store_agent_token(self): + """Test storing agent JWT token.""" + agent_id = "agent_test123" + token = "test_jwt_token" + expires_at = (datetime.now() + timedelta(hours=1)).isoformat() + + success = self.db.store_agent_token(agent_id, token, expires_at) + self.assertTrue(success) + + # Verify token exists + stored_token = self.db.get_agent_token(agent_id) + self.assertIsNotNone(stored_token) + self.assertEqual(stored_token['token'], token) + + def test_cleanup_expired_tokens(self): + """Test cleanup of expired tokens.""" + agent_id = "agent_test123" + + # Create expired token + expired_token = "expired_token" + expired_time = (datetime.now() - timedelta(hours=1)).isoformat() + self.db.store_agent_token(agent_id, expired_token, expired_time) + + # Create valid token + valid_token = "valid_token" + valid_time = (datetime.now() + timedelta(hours=1)).isoformat() + self.db.store_agent_token("agent_valid", valid_token, valid_time) + + # Cleanup expired tokens + cleaned = self.db.cleanup_expired_tokens() + self.assertGreaterEqual(cleaned, 1) + + # Verify expired token is gone + token = self.db.get_agent_token(agent_id) + self.assertIsNone(token) + + # Verify valid token remains + token = self.db.get_agent_token("agent_valid") + self.assertIsNotNone(token) + + +class TestIntegration(unittest.TestCase): + """Integration tests for the complete authentication flow.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.auth = AgentAuthentication() + + # Use test database + self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_complete_authentication_flow(self): + """Test complete agent authentication workflow.""" + # Step 1: Generate agent ID + agent_id = self.auth.generate_agent_id() + self.assertIsNotNone(agent_id) + + # Step 2: Create credentials + credentials = self.auth.create_agent_credentials(agent_id) + self.assertIsNotNone(credentials) + + # Step 3: Generate JWT token + token = self.auth.generate_jwt_token( + credentials['agent_id'], + credentials['secret_key'] + ) + self.assertIsNotNone(token) + + # Step 4: Verify token + is_valid = self.auth.verify_jwt_token(token, credentials['secret_key']) + self.assertTrue(is_valid) + + # Step 5: Create HMAC signature + test_data = "test API request" + signature = self.auth.create_hmac_signature(test_data, credentials['secret_key']) + self.assertIsNotNone(signature) + + # Step 6: Verify HMAC signature + is_signature_valid = self.auth.verify_hmac_signature( + test_data, signature, credentials['secret_key'] + ) + self.assertTrue(is_signature_valid) + + +def run_tests(): + """Run all tests.""" + print("🧪 Running PyGuardian Authentication Tests...") + print("=" * 50) + + # Create test suite + test_suite = unittest.TestSuite() + + # Add test classes + test_classes = [ + TestAgentAuthentication, + TestDatabase, + TestIntegration + ] + + for test_class in test_classes: + tests = unittest.TestLoader().loadTestsFromTestCase(test_class) + test_suite.addTests(tests) + + # Run tests + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(test_suite) + + # Print summary + print("\n" + "=" * 50) + print(f"🏁 Tests completed:") + print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}") + print(f" ❌ Failed: {len(result.failures)}") + print(f" 💥 Errors: {len(result.errors)}") + + # Return exit code + return 0 if result.wasSuccessful() else 1 + + +if __name__ == '__main__': + sys.exit(run_tests()) \ No newline at end of file diff --git a/.history/tests/unit/test_authentication_20251125212856.py b/.history/tests/unit/test_authentication_20251125212856.py new file mode 100644 index 0000000..de39089 --- /dev/null +++ b/.history/tests/unit/test_authentication_20251125212856.py @@ -0,0 +1,421 @@ +#!/usr/bin/env python3 +""" +Comprehensive unit tests for PyGuardian authentication system. +""" + +import unittest +import tempfile +import os +import sys +import sqlite3 +import jwt +import hashlib +import hmac +from datetime import datetime, timedelta +from unittest.mock import Mock, patch, MagicMock + +# Add src directory to path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src')) + +from auth import AgentAuthentication +from storage import Storage + + +class TestAgentAuthentication(unittest.TestCase): + """Test cases for agent authentication system.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.auth = AgentAuthentication() + + # Create test database + self.db = Database(self.db_path) + self.db.create_tables() + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_generate_agent_id(self): + """Test agent ID generation.""" + agent_id = self.auth.generate_agent_id() + + # Check format + self.assertTrue(agent_id.startswith('agent_')) + self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID + + # Test uniqueness + agent_id2 = self.auth.generate_agent_id() + self.assertNotEqual(agent_id, agent_id2) + + def test_create_agent_credentials(self): + """Test agent credentials creation.""" + agent_id = self.auth.generate_agent_id() + credentials = self.auth.create_agent_credentials(agent_id) + + # Check required fields + required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash'] + for field in required_fields: + self.assertIn(field, credentials) + + # Check agent ID matches + self.assertEqual(credentials['agent_id'], agent_id) + + # Check secret key length + self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded + + # Check key hash + expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest() + self.assertEqual(credentials['key_hash'], expected_hash) + + def test_generate_jwt_token(self): + """Test JWT token generation.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + token = self.auth.generate_jwt_token(agent_id, secret_key) + + # Verify token structure + self.assertIsInstance(token, str) + self.assertTrue(len(token) > 100) # JWT tokens are typically long + + # Decode and verify payload + decoded = jwt.decode(token, secret_key, algorithms=['HS256']) + self.assertEqual(decoded['agent_id'], agent_id) + self.assertIn('iat', decoded) + self.assertIn('exp', decoded) + self.assertIn('jti', decoded) + + def test_verify_jwt_token_valid(self): + """Test JWT token verification with valid token.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + token = self.auth.generate_jwt_token(agent_id, secret_key) + + is_valid = self.auth.verify_jwt_token(token, secret_key) + self.assertTrue(is_valid) + + def test_verify_jwt_token_invalid(self): + """Test JWT token verification with invalid token.""" + secret_key = self.auth._generate_secret_key() + + # Test with invalid token + is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key) + self.assertFalse(is_valid) + + # Test with wrong secret key + agent_id = self.auth.generate_agent_id() + token = self.auth.generate_jwt_token(agent_id, secret_key) + wrong_key = self.auth._generate_secret_key() + + is_valid = self.auth.verify_jwt_token(token, wrong_key) + self.assertFalse(is_valid) + + def test_verify_jwt_token_expired(self): + """Test JWT token verification with expired token.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + # Create expired token + payload = { + 'agent_id': agent_id, + 'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago + 'iat': datetime.utcnow() - timedelta(hours=2), + 'jti': self.auth._generate_jti() + } + + expired_token = jwt.encode(payload, secret_key, algorithm='HS256') + + is_valid = self.auth.verify_jwt_token(expired_token, secret_key) + self.assertFalse(is_valid) + + def test_create_hmac_signature(self): + """Test HMAC signature creation.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + signature = self.auth.create_hmac_signature(data, secret_key) + + # Verify signature format + self.assertEqual(len(signature), 64) # SHA256 hex digest + + # Verify signature is correct + expected = hmac.new( + secret_key.encode(), + data.encode(), + hashlib.sha256 + ).hexdigest() + + self.assertEqual(signature, expected) + + def test_verify_hmac_signature_valid(self): + """Test HMAC signature verification with valid signature.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + signature = self.auth.create_hmac_signature(data, secret_key) + is_valid = self.auth.verify_hmac_signature(data, signature, secret_key) + + self.assertTrue(is_valid) + + def test_verify_hmac_signature_invalid(self): + """Test HMAC signature verification with invalid signature.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + # Test with wrong signature + wrong_signature = "0" * 64 + is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key) + self.assertFalse(is_valid) + + # Test with wrong key + signature = self.auth.create_hmac_signature(data, secret_key) + wrong_key = self.auth._generate_secret_key() + is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key) + self.assertFalse(is_valid) + + def test_encrypt_decrypt_secret_key(self): + """Test secret key encryption and decryption.""" + secret_key = self.auth._generate_secret_key() + password = "test_password" + + encrypted = self.auth.encrypt_secret_key(secret_key, password) + decrypted = self.auth.decrypt_secret_key(encrypted, password) + + self.assertEqual(secret_key, decrypted) + + def test_encrypt_decrypt_wrong_password(self): + """Test secret key decryption with wrong password.""" + secret_key = self.auth._generate_secret_key() + password = "test_password" + wrong_password = "wrong_password" + + encrypted = self.auth.encrypt_secret_key(secret_key, password) + + with self.assertRaises(Exception): + self.auth.decrypt_secret_key(encrypted, wrong_password) + + @patch('src.auth.Database') + def test_authenticate_agent_success(self, mock_db_class): + """Test successful agent authentication.""" + # Mock database + mock_db = Mock() + mock_db_class.return_value = mock_db + + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + key_hash = hashlib.sha256(secret_key.encode()).hexdigest() + + # Mock database response + mock_db.get_agent_credentials.return_value = { + 'agent_id': agent_id, + 'key_hash': key_hash, + 'is_active': True, + 'created_at': datetime.now().isoformat() + } + + result = self.auth.authenticate_agent(agent_id, secret_key) + self.assertTrue(result) + + @patch('src.auth.Database') + def test_authenticate_agent_failure(self, mock_db_class): + """Test failed agent authentication.""" + # Mock database + mock_db = Mock() + mock_db_class.return_value = mock_db + + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + # Mock database response - no credentials found + mock_db.get_agent_credentials.return_value = None + + result = self.auth.authenticate_agent(agent_id, secret_key) + self.assertFalse(result) + + +class TestDatabase(unittest.TestCase): + """Test cases for database operations.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.db = Storage(self.db_path) + self.db.create_tables() + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_create_agent_auth(self): + """Test agent authentication record creation.""" + agent_id = "agent_test123" + secret_key_hash = "test_hash" + encrypted_key = "encrypted_test_key" + + success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key) + self.assertTrue(success) + + # Verify record exists + credentials = self.db.get_agent_credentials(agent_id) + self.assertIsNotNone(credentials) + self.assertEqual(credentials['agent_id'], agent_id) + self.assertEqual(credentials['key_hash'], secret_key_hash) + + def test_get_agent_credentials_exists(self): + """Test retrieving existing agent credentials.""" + agent_id = "agent_test123" + secret_key_hash = "test_hash" + encrypted_key = "encrypted_test_key" + + # Create record + self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key) + + # Retrieve record + credentials = self.db.get_agent_credentials(agent_id) + + self.assertIsNotNone(credentials) + self.assertEqual(credentials['agent_id'], agent_id) + self.assertEqual(credentials['key_hash'], secret_key_hash) + self.assertTrue(credentials['is_active']) + + def test_get_agent_credentials_not_exists(self): + """Test retrieving non-existent agent credentials.""" + credentials = self.db.get_agent_credentials("non_existent_agent") + self.assertIsNone(credentials) + + def test_store_agent_token(self): + """Test storing agent JWT token.""" + agent_id = "agent_test123" + token = "test_jwt_token" + expires_at = (datetime.now() + timedelta(hours=1)).isoformat() + + success = self.db.store_agent_token(agent_id, token, expires_at) + self.assertTrue(success) + + # Verify token exists + stored_token = self.db.get_agent_token(agent_id) + self.assertIsNotNone(stored_token) + self.assertEqual(stored_token['token'], token) + + def test_cleanup_expired_tokens(self): + """Test cleanup of expired tokens.""" + agent_id = "agent_test123" + + # Create expired token + expired_token = "expired_token" + expired_time = (datetime.now() - timedelta(hours=1)).isoformat() + self.db.store_agent_token(agent_id, expired_token, expired_time) + + # Create valid token + valid_token = "valid_token" + valid_time = (datetime.now() + timedelta(hours=1)).isoformat() + self.db.store_agent_token("agent_valid", valid_token, valid_time) + + # Cleanup expired tokens + cleaned = self.db.cleanup_expired_tokens() + self.assertGreaterEqual(cleaned, 1) + + # Verify expired token is gone + token = self.db.get_agent_token(agent_id) + self.assertIsNone(token) + + # Verify valid token remains + token = self.db.get_agent_token("agent_valid") + self.assertIsNotNone(token) + + +class TestIntegration(unittest.TestCase): + """Integration tests for the complete authentication flow.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.auth = AgentAuthentication() + + # Use test database + self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_complete_authentication_flow(self): + """Test complete agent authentication workflow.""" + # Step 1: Generate agent ID + agent_id = self.auth.generate_agent_id() + self.assertIsNotNone(agent_id) + + # Step 2: Create credentials + credentials = self.auth.create_agent_credentials(agent_id) + self.assertIsNotNone(credentials) + + # Step 3: Generate JWT token + token = self.auth.generate_jwt_token( + credentials['agent_id'], + credentials['secret_key'] + ) + self.assertIsNotNone(token) + + # Step 4: Verify token + is_valid = self.auth.verify_jwt_token(token, credentials['secret_key']) + self.assertTrue(is_valid) + + # Step 5: Create HMAC signature + test_data = "test API request" + signature = self.auth.create_hmac_signature(test_data, credentials['secret_key']) + self.assertIsNotNone(signature) + + # Step 6: Verify HMAC signature + is_signature_valid = self.auth.verify_hmac_signature( + test_data, signature, credentials['secret_key'] + ) + self.assertTrue(is_signature_valid) + + +def run_tests(): + """Run all tests.""" + print("🧪 Running PyGuardian Authentication Tests...") + print("=" * 50) + + # Create test suite + test_suite = unittest.TestSuite() + + # Add test classes + test_classes = [ + TestAgentAuthentication, + TestDatabase, + TestIntegration + ] + + for test_class in test_classes: + tests = unittest.TestLoader().loadTestsFromTestCase(test_class) + test_suite.addTests(tests) + + # Run tests + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(test_suite) + + # Print summary + print("\n" + "=" * 50) + print(f"🏁 Tests completed:") + print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}") + print(f" ❌ Failed: {len(result.failures)}") + print(f" 💥 Errors: {len(result.errors)}") + + # Return exit code + return 0 if result.wasSuccessful() else 1 + + +if __name__ == '__main__': + sys.exit(run_tests()) \ No newline at end of file diff --git a/.history/tests/unit/test_authentication_20251125212909.py b/.history/tests/unit/test_authentication_20251125212909.py new file mode 100644 index 0000000..6387fbf --- /dev/null +++ b/.history/tests/unit/test_authentication_20251125212909.py @@ -0,0 +1,421 @@ +#!/usr/bin/env python3 +""" +Comprehensive unit tests for PyGuardian authentication system. +""" + +import unittest +import tempfile +import os +import sys +import sqlite3 +import jwt +import hashlib +import hmac +from datetime import datetime, timedelta +from unittest.mock import Mock, patch, MagicMock + +# Add src directory to path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src')) + +from auth import AgentAuthentication +from storage import Storage + + +class TestAgentAuthentication(unittest.TestCase): + """Test cases for agent authentication system.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.auth = AgentAuthentication() + + # Create test database + self.db = Storage(self.db_path) + self.db.create_tables() + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_generate_agent_id(self): + """Test agent ID generation.""" + agent_id = self.auth.generate_agent_id() + + # Check format + self.assertTrue(agent_id.startswith('agent_')) + self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID + + # Test uniqueness + agent_id2 = self.auth.generate_agent_id() + self.assertNotEqual(agent_id, agent_id2) + + def test_create_agent_credentials(self): + """Test agent credentials creation.""" + agent_id = self.auth.generate_agent_id() + credentials = self.auth.create_agent_credentials(agent_id) + + # Check required fields + required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash'] + for field in required_fields: + self.assertIn(field, credentials) + + # Check agent ID matches + self.assertEqual(credentials['agent_id'], agent_id) + + # Check secret key length + self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded + + # Check key hash + expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest() + self.assertEqual(credentials['key_hash'], expected_hash) + + def test_generate_jwt_token(self): + """Test JWT token generation.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + token = self.auth.generate_jwt_token(agent_id, secret_key) + + # Verify token structure + self.assertIsInstance(token, str) + self.assertTrue(len(token) > 100) # JWT tokens are typically long + + # Decode and verify payload + decoded = jwt.decode(token, secret_key, algorithms=['HS256']) + self.assertEqual(decoded['agent_id'], agent_id) + self.assertIn('iat', decoded) + self.assertIn('exp', decoded) + self.assertIn('jti', decoded) + + def test_verify_jwt_token_valid(self): + """Test JWT token verification with valid token.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + token = self.auth.generate_jwt_token(agent_id, secret_key) + + is_valid = self.auth.verify_jwt_token(token, secret_key) + self.assertTrue(is_valid) + + def test_verify_jwt_token_invalid(self): + """Test JWT token verification with invalid token.""" + secret_key = self.auth._generate_secret_key() + + # Test with invalid token + is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key) + self.assertFalse(is_valid) + + # Test with wrong secret key + agent_id = self.auth.generate_agent_id() + token = self.auth.generate_jwt_token(agent_id, secret_key) + wrong_key = self.auth._generate_secret_key() + + is_valid = self.auth.verify_jwt_token(token, wrong_key) + self.assertFalse(is_valid) + + def test_verify_jwt_token_expired(self): + """Test JWT token verification with expired token.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + # Create expired token + payload = { + 'agent_id': agent_id, + 'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago + 'iat': datetime.utcnow() - timedelta(hours=2), + 'jti': self.auth._generate_jti() + } + + expired_token = jwt.encode(payload, secret_key, algorithm='HS256') + + is_valid = self.auth.verify_jwt_token(expired_token, secret_key) + self.assertFalse(is_valid) + + def test_create_hmac_signature(self): + """Test HMAC signature creation.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + signature = self.auth.create_hmac_signature(data, secret_key) + + # Verify signature format + self.assertEqual(len(signature), 64) # SHA256 hex digest + + # Verify signature is correct + expected = hmac.new( + secret_key.encode(), + data.encode(), + hashlib.sha256 + ).hexdigest() + + self.assertEqual(signature, expected) + + def test_verify_hmac_signature_valid(self): + """Test HMAC signature verification with valid signature.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + signature = self.auth.create_hmac_signature(data, secret_key) + is_valid = self.auth.verify_hmac_signature(data, signature, secret_key) + + self.assertTrue(is_valid) + + def test_verify_hmac_signature_invalid(self): + """Test HMAC signature verification with invalid signature.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + # Test with wrong signature + wrong_signature = "0" * 64 + is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key) + self.assertFalse(is_valid) + + # Test with wrong key + signature = self.auth.create_hmac_signature(data, secret_key) + wrong_key = self.auth._generate_secret_key() + is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key) + self.assertFalse(is_valid) + + def test_encrypt_decrypt_secret_key(self): + """Test secret key encryption and decryption.""" + secret_key = self.auth._generate_secret_key() + password = "test_password" + + encrypted = self.auth.encrypt_secret_key(secret_key, password) + decrypted = self.auth.decrypt_secret_key(encrypted, password) + + self.assertEqual(secret_key, decrypted) + + def test_encrypt_decrypt_wrong_password(self): + """Test secret key decryption with wrong password.""" + secret_key = self.auth._generate_secret_key() + password = "test_password" + wrong_password = "wrong_password" + + encrypted = self.auth.encrypt_secret_key(secret_key, password) + + with self.assertRaises(Exception): + self.auth.decrypt_secret_key(encrypted, wrong_password) + + @patch('src.auth.Database') + def test_authenticate_agent_success(self, mock_db_class): + """Test successful agent authentication.""" + # Mock database + mock_db = Mock() + mock_db_class.return_value = mock_db + + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + key_hash = hashlib.sha256(secret_key.encode()).hexdigest() + + # Mock database response + mock_db.get_agent_credentials.return_value = { + 'agent_id': agent_id, + 'key_hash': key_hash, + 'is_active': True, + 'created_at': datetime.now().isoformat() + } + + result = self.auth.authenticate_agent(agent_id, secret_key) + self.assertTrue(result) + + @patch('src.auth.Database') + def test_authenticate_agent_failure(self, mock_db_class): + """Test failed agent authentication.""" + # Mock database + mock_db = Mock() + mock_db_class.return_value = mock_db + + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + # Mock database response - no credentials found + mock_db.get_agent_credentials.return_value = None + + result = self.auth.authenticate_agent(agent_id, secret_key) + self.assertFalse(result) + + +class TestDatabase(unittest.TestCase): + """Test cases for database operations.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.db = Storage(self.db_path) + self.db.create_tables() + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_create_agent_auth(self): + """Test agent authentication record creation.""" + agent_id = "agent_test123" + secret_key_hash = "test_hash" + encrypted_key = "encrypted_test_key" + + success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key) + self.assertTrue(success) + + # Verify record exists + credentials = self.db.get_agent_credentials(agent_id) + self.assertIsNotNone(credentials) + self.assertEqual(credentials['agent_id'], agent_id) + self.assertEqual(credentials['key_hash'], secret_key_hash) + + def test_get_agent_credentials_exists(self): + """Test retrieving existing agent credentials.""" + agent_id = "agent_test123" + secret_key_hash = "test_hash" + encrypted_key = "encrypted_test_key" + + # Create record + self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key) + + # Retrieve record + credentials = self.db.get_agent_credentials(agent_id) + + self.assertIsNotNone(credentials) + self.assertEqual(credentials['agent_id'], agent_id) + self.assertEqual(credentials['key_hash'], secret_key_hash) + self.assertTrue(credentials['is_active']) + + def test_get_agent_credentials_not_exists(self): + """Test retrieving non-existent agent credentials.""" + credentials = self.db.get_agent_credentials("non_existent_agent") + self.assertIsNone(credentials) + + def test_store_agent_token(self): + """Test storing agent JWT token.""" + agent_id = "agent_test123" + token = "test_jwt_token" + expires_at = (datetime.now() + timedelta(hours=1)).isoformat() + + success = self.db.store_agent_token(agent_id, token, expires_at) + self.assertTrue(success) + + # Verify token exists + stored_token = self.db.get_agent_token(agent_id) + self.assertIsNotNone(stored_token) + self.assertEqual(stored_token['token'], token) + + def test_cleanup_expired_tokens(self): + """Test cleanup of expired tokens.""" + agent_id = "agent_test123" + + # Create expired token + expired_token = "expired_token" + expired_time = (datetime.now() - timedelta(hours=1)).isoformat() + self.db.store_agent_token(agent_id, expired_token, expired_time) + + # Create valid token + valid_token = "valid_token" + valid_time = (datetime.now() + timedelta(hours=1)).isoformat() + self.db.store_agent_token("agent_valid", valid_token, valid_time) + + # Cleanup expired tokens + cleaned = self.db.cleanup_expired_tokens() + self.assertGreaterEqual(cleaned, 1) + + # Verify expired token is gone + token = self.db.get_agent_token(agent_id) + self.assertIsNone(token) + + # Verify valid token remains + token = self.db.get_agent_token("agent_valid") + self.assertIsNotNone(token) + + +class TestIntegration(unittest.TestCase): + """Integration tests for the complete authentication flow.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.auth = AgentAuthentication() + + # Use test database + self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_complete_authentication_flow(self): + """Test complete agent authentication workflow.""" + # Step 1: Generate agent ID + agent_id = self.auth.generate_agent_id() + self.assertIsNotNone(agent_id) + + # Step 2: Create credentials + credentials = self.auth.create_agent_credentials(agent_id) + self.assertIsNotNone(credentials) + + # Step 3: Generate JWT token + token = self.auth.generate_jwt_token( + credentials['agent_id'], + credentials['secret_key'] + ) + self.assertIsNotNone(token) + + # Step 4: Verify token + is_valid = self.auth.verify_jwt_token(token, credentials['secret_key']) + self.assertTrue(is_valid) + + # Step 5: Create HMAC signature + test_data = "test API request" + signature = self.auth.create_hmac_signature(test_data, credentials['secret_key']) + self.assertIsNotNone(signature) + + # Step 6: Verify HMAC signature + is_signature_valid = self.auth.verify_hmac_signature( + test_data, signature, credentials['secret_key'] + ) + self.assertTrue(is_signature_valid) + + +def run_tests(): + """Run all tests.""" + print("🧪 Running PyGuardian Authentication Tests...") + print("=" * 50) + + # Create test suite + test_suite = unittest.TestSuite() + + # Add test classes + test_classes = [ + TestAgentAuthentication, + TestDatabase, + TestIntegration + ] + + for test_class in test_classes: + tests = unittest.TestLoader().loadTestsFromTestCase(test_class) + test_suite.addTests(tests) + + # Run tests + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(test_suite) + + # Print summary + print("\n" + "=" * 50) + print(f"🏁 Tests completed:") + print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}") + print(f" ❌ Failed: {len(result.failures)}") + print(f" 💥 Errors: {len(result.errors)}") + + # Return exit code + return 0 if result.wasSuccessful() else 1 + + +if __name__ == '__main__': + sys.exit(run_tests()) \ No newline at end of file diff --git a/.history/tests/unit/test_authentication_20251125213152.py b/.history/tests/unit/test_authentication_20251125213152.py new file mode 100644 index 0000000..6387fbf --- /dev/null +++ b/.history/tests/unit/test_authentication_20251125213152.py @@ -0,0 +1,421 @@ +#!/usr/bin/env python3 +""" +Comprehensive unit tests for PyGuardian authentication system. +""" + +import unittest +import tempfile +import os +import sys +import sqlite3 +import jwt +import hashlib +import hmac +from datetime import datetime, timedelta +from unittest.mock import Mock, patch, MagicMock + +# Add src directory to path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src')) + +from auth import AgentAuthentication +from storage import Storage + + +class TestAgentAuthentication(unittest.TestCase): + """Test cases for agent authentication system.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.auth = AgentAuthentication() + + # Create test database + self.db = Storage(self.db_path) + self.db.create_tables() + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_generate_agent_id(self): + """Test agent ID generation.""" + agent_id = self.auth.generate_agent_id() + + # Check format + self.assertTrue(agent_id.startswith('agent_')) + self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID + + # Test uniqueness + agent_id2 = self.auth.generate_agent_id() + self.assertNotEqual(agent_id, agent_id2) + + def test_create_agent_credentials(self): + """Test agent credentials creation.""" + agent_id = self.auth.generate_agent_id() + credentials = self.auth.create_agent_credentials(agent_id) + + # Check required fields + required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash'] + for field in required_fields: + self.assertIn(field, credentials) + + # Check agent ID matches + self.assertEqual(credentials['agent_id'], agent_id) + + # Check secret key length + self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded + + # Check key hash + expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest() + self.assertEqual(credentials['key_hash'], expected_hash) + + def test_generate_jwt_token(self): + """Test JWT token generation.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + token = self.auth.generate_jwt_token(agent_id, secret_key) + + # Verify token structure + self.assertIsInstance(token, str) + self.assertTrue(len(token) > 100) # JWT tokens are typically long + + # Decode and verify payload + decoded = jwt.decode(token, secret_key, algorithms=['HS256']) + self.assertEqual(decoded['agent_id'], agent_id) + self.assertIn('iat', decoded) + self.assertIn('exp', decoded) + self.assertIn('jti', decoded) + + def test_verify_jwt_token_valid(self): + """Test JWT token verification with valid token.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + token = self.auth.generate_jwt_token(agent_id, secret_key) + + is_valid = self.auth.verify_jwt_token(token, secret_key) + self.assertTrue(is_valid) + + def test_verify_jwt_token_invalid(self): + """Test JWT token verification with invalid token.""" + secret_key = self.auth._generate_secret_key() + + # Test with invalid token + is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key) + self.assertFalse(is_valid) + + # Test with wrong secret key + agent_id = self.auth.generate_agent_id() + token = self.auth.generate_jwt_token(agent_id, secret_key) + wrong_key = self.auth._generate_secret_key() + + is_valid = self.auth.verify_jwt_token(token, wrong_key) + self.assertFalse(is_valid) + + def test_verify_jwt_token_expired(self): + """Test JWT token verification with expired token.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + # Create expired token + payload = { + 'agent_id': agent_id, + 'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago + 'iat': datetime.utcnow() - timedelta(hours=2), + 'jti': self.auth._generate_jti() + } + + expired_token = jwt.encode(payload, secret_key, algorithm='HS256') + + is_valid = self.auth.verify_jwt_token(expired_token, secret_key) + self.assertFalse(is_valid) + + def test_create_hmac_signature(self): + """Test HMAC signature creation.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + signature = self.auth.create_hmac_signature(data, secret_key) + + # Verify signature format + self.assertEqual(len(signature), 64) # SHA256 hex digest + + # Verify signature is correct + expected = hmac.new( + secret_key.encode(), + data.encode(), + hashlib.sha256 + ).hexdigest() + + self.assertEqual(signature, expected) + + def test_verify_hmac_signature_valid(self): + """Test HMAC signature verification with valid signature.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + signature = self.auth.create_hmac_signature(data, secret_key) + is_valid = self.auth.verify_hmac_signature(data, signature, secret_key) + + self.assertTrue(is_valid) + + def test_verify_hmac_signature_invalid(self): + """Test HMAC signature verification with invalid signature.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + # Test with wrong signature + wrong_signature = "0" * 64 + is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key) + self.assertFalse(is_valid) + + # Test with wrong key + signature = self.auth.create_hmac_signature(data, secret_key) + wrong_key = self.auth._generate_secret_key() + is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key) + self.assertFalse(is_valid) + + def test_encrypt_decrypt_secret_key(self): + """Test secret key encryption and decryption.""" + secret_key = self.auth._generate_secret_key() + password = "test_password" + + encrypted = self.auth.encrypt_secret_key(secret_key, password) + decrypted = self.auth.decrypt_secret_key(encrypted, password) + + self.assertEqual(secret_key, decrypted) + + def test_encrypt_decrypt_wrong_password(self): + """Test secret key decryption with wrong password.""" + secret_key = self.auth._generate_secret_key() + password = "test_password" + wrong_password = "wrong_password" + + encrypted = self.auth.encrypt_secret_key(secret_key, password) + + with self.assertRaises(Exception): + self.auth.decrypt_secret_key(encrypted, wrong_password) + + @patch('src.auth.Database') + def test_authenticate_agent_success(self, mock_db_class): + """Test successful agent authentication.""" + # Mock database + mock_db = Mock() + mock_db_class.return_value = mock_db + + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + key_hash = hashlib.sha256(secret_key.encode()).hexdigest() + + # Mock database response + mock_db.get_agent_credentials.return_value = { + 'agent_id': agent_id, + 'key_hash': key_hash, + 'is_active': True, + 'created_at': datetime.now().isoformat() + } + + result = self.auth.authenticate_agent(agent_id, secret_key) + self.assertTrue(result) + + @patch('src.auth.Database') + def test_authenticate_agent_failure(self, mock_db_class): + """Test failed agent authentication.""" + # Mock database + mock_db = Mock() + mock_db_class.return_value = mock_db + + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + # Mock database response - no credentials found + mock_db.get_agent_credentials.return_value = None + + result = self.auth.authenticate_agent(agent_id, secret_key) + self.assertFalse(result) + + +class TestDatabase(unittest.TestCase): + """Test cases for database operations.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.db = Storage(self.db_path) + self.db.create_tables() + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_create_agent_auth(self): + """Test agent authentication record creation.""" + agent_id = "agent_test123" + secret_key_hash = "test_hash" + encrypted_key = "encrypted_test_key" + + success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key) + self.assertTrue(success) + + # Verify record exists + credentials = self.db.get_agent_credentials(agent_id) + self.assertIsNotNone(credentials) + self.assertEqual(credentials['agent_id'], agent_id) + self.assertEqual(credentials['key_hash'], secret_key_hash) + + def test_get_agent_credentials_exists(self): + """Test retrieving existing agent credentials.""" + agent_id = "agent_test123" + secret_key_hash = "test_hash" + encrypted_key = "encrypted_test_key" + + # Create record + self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key) + + # Retrieve record + credentials = self.db.get_agent_credentials(agent_id) + + self.assertIsNotNone(credentials) + self.assertEqual(credentials['agent_id'], agent_id) + self.assertEqual(credentials['key_hash'], secret_key_hash) + self.assertTrue(credentials['is_active']) + + def test_get_agent_credentials_not_exists(self): + """Test retrieving non-existent agent credentials.""" + credentials = self.db.get_agent_credentials("non_existent_agent") + self.assertIsNone(credentials) + + def test_store_agent_token(self): + """Test storing agent JWT token.""" + agent_id = "agent_test123" + token = "test_jwt_token" + expires_at = (datetime.now() + timedelta(hours=1)).isoformat() + + success = self.db.store_agent_token(agent_id, token, expires_at) + self.assertTrue(success) + + # Verify token exists + stored_token = self.db.get_agent_token(agent_id) + self.assertIsNotNone(stored_token) + self.assertEqual(stored_token['token'], token) + + def test_cleanup_expired_tokens(self): + """Test cleanup of expired tokens.""" + agent_id = "agent_test123" + + # Create expired token + expired_token = "expired_token" + expired_time = (datetime.now() - timedelta(hours=1)).isoformat() + self.db.store_agent_token(agent_id, expired_token, expired_time) + + # Create valid token + valid_token = "valid_token" + valid_time = (datetime.now() + timedelta(hours=1)).isoformat() + self.db.store_agent_token("agent_valid", valid_token, valid_time) + + # Cleanup expired tokens + cleaned = self.db.cleanup_expired_tokens() + self.assertGreaterEqual(cleaned, 1) + + # Verify expired token is gone + token = self.db.get_agent_token(agent_id) + self.assertIsNone(token) + + # Verify valid token remains + token = self.db.get_agent_token("agent_valid") + self.assertIsNotNone(token) + + +class TestIntegration(unittest.TestCase): + """Integration tests for the complete authentication flow.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.auth = AgentAuthentication() + + # Use test database + self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_complete_authentication_flow(self): + """Test complete agent authentication workflow.""" + # Step 1: Generate agent ID + agent_id = self.auth.generate_agent_id() + self.assertIsNotNone(agent_id) + + # Step 2: Create credentials + credentials = self.auth.create_agent_credentials(agent_id) + self.assertIsNotNone(credentials) + + # Step 3: Generate JWT token + token = self.auth.generate_jwt_token( + credentials['agent_id'], + credentials['secret_key'] + ) + self.assertIsNotNone(token) + + # Step 4: Verify token + is_valid = self.auth.verify_jwt_token(token, credentials['secret_key']) + self.assertTrue(is_valid) + + # Step 5: Create HMAC signature + test_data = "test API request" + signature = self.auth.create_hmac_signature(test_data, credentials['secret_key']) + self.assertIsNotNone(signature) + + # Step 6: Verify HMAC signature + is_signature_valid = self.auth.verify_hmac_signature( + test_data, signature, credentials['secret_key'] + ) + self.assertTrue(is_signature_valid) + + +def run_tests(): + """Run all tests.""" + print("🧪 Running PyGuardian Authentication Tests...") + print("=" * 50) + + # Create test suite + test_suite = unittest.TestSuite() + + # Add test classes + test_classes = [ + TestAgentAuthentication, + TestDatabase, + TestIntegration + ] + + for test_class in test_classes: + tests = unittest.TestLoader().loadTestsFromTestCase(test_class) + test_suite.addTests(tests) + + # Run tests + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(test_suite) + + # Print summary + print("\n" + "=" * 50) + print(f"🏁 Tests completed:") + print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}") + print(f" ❌ Failed: {len(result.failures)}") + print(f" 💥 Errors: {len(result.errors)}") + + # Return exit code + return 0 if result.wasSuccessful() else 1 + + +if __name__ == '__main__': + sys.exit(run_tests()) \ No newline at end of file diff --git a/.history/tests/unit/test_authentication_20251125213226.py b/.history/tests/unit/test_authentication_20251125213226.py new file mode 100644 index 0000000..3f5ddfc --- /dev/null +++ b/.history/tests/unit/test_authentication_20251125213226.py @@ -0,0 +1,422 @@ +#!/usr/bin/env python3 +""" +Comprehensive unit tests for PyGuardian authentication system. +""" + +import unittest +import tempfile +import os +import sys +import sqlite3 +import jwt +import hashlib +import hmac +from datetime import datetime, timedelta +from unittest.mock import Mock, patch, MagicMock + +# Add src directory to path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src')) + +from auth import AgentAuthentication +from storage import Storage + + +class TestAgentAuthentication(unittest.TestCase): + """Test cases for agent authentication system.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.test_secret = 'test_secret_key_123' + self.auth = AgentAuthentication(self.test_secret) + + # Create test database + self.db = Storage(self.db_path) + self.db.create_tables() + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_generate_agent_id(self): + """Test agent ID generation.""" + agent_id = self.auth.generate_agent_id() + + # Check format + self.assertTrue(agent_id.startswith('agent_')) + self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID + + # Test uniqueness + agent_id2 = self.auth.generate_agent_id() + self.assertNotEqual(agent_id, agent_id2) + + def test_create_agent_credentials(self): + """Test agent credentials creation.""" + agent_id = self.auth.generate_agent_id() + credentials = self.auth.create_agent_credentials(agent_id) + + # Check required fields + required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash'] + for field in required_fields: + self.assertIn(field, credentials) + + # Check agent ID matches + self.assertEqual(credentials['agent_id'], agent_id) + + # Check secret key length + self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded + + # Check key hash + expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest() + self.assertEqual(credentials['key_hash'], expected_hash) + + def test_generate_jwt_token(self): + """Test JWT token generation.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + token = self.auth.generate_jwt_token(agent_id, secret_key) + + # Verify token structure + self.assertIsInstance(token, str) + self.assertTrue(len(token) > 100) # JWT tokens are typically long + + # Decode and verify payload + decoded = jwt.decode(token, secret_key, algorithms=['HS256']) + self.assertEqual(decoded['agent_id'], agent_id) + self.assertIn('iat', decoded) + self.assertIn('exp', decoded) + self.assertIn('jti', decoded) + + def test_verify_jwt_token_valid(self): + """Test JWT token verification with valid token.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + token = self.auth.generate_jwt_token(agent_id, secret_key) + + is_valid = self.auth.verify_jwt_token(token, secret_key) + self.assertTrue(is_valid) + + def test_verify_jwt_token_invalid(self): + """Test JWT token verification with invalid token.""" + secret_key = self.auth._generate_secret_key() + + # Test with invalid token + is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key) + self.assertFalse(is_valid) + + # Test with wrong secret key + agent_id = self.auth.generate_agent_id() + token = self.auth.generate_jwt_token(agent_id, secret_key) + wrong_key = self.auth._generate_secret_key() + + is_valid = self.auth.verify_jwt_token(token, wrong_key) + self.assertFalse(is_valid) + + def test_verify_jwt_token_expired(self): + """Test JWT token verification with expired token.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + # Create expired token + payload = { + 'agent_id': agent_id, + 'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago + 'iat': datetime.utcnow() - timedelta(hours=2), + 'jti': self.auth._generate_jti() + } + + expired_token = jwt.encode(payload, secret_key, algorithm='HS256') + + is_valid = self.auth.verify_jwt_token(expired_token, secret_key) + self.assertFalse(is_valid) + + def test_create_hmac_signature(self): + """Test HMAC signature creation.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + signature = self.auth.create_hmac_signature(data, secret_key) + + # Verify signature format + self.assertEqual(len(signature), 64) # SHA256 hex digest + + # Verify signature is correct + expected = hmac.new( + secret_key.encode(), + data.encode(), + hashlib.sha256 + ).hexdigest() + + self.assertEqual(signature, expected) + + def test_verify_hmac_signature_valid(self): + """Test HMAC signature verification with valid signature.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + signature = self.auth.create_hmac_signature(data, secret_key) + is_valid = self.auth.verify_hmac_signature(data, signature, secret_key) + + self.assertTrue(is_valid) + + def test_verify_hmac_signature_invalid(self): + """Test HMAC signature verification with invalid signature.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + # Test with wrong signature + wrong_signature = "0" * 64 + is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key) + self.assertFalse(is_valid) + + # Test with wrong key + signature = self.auth.create_hmac_signature(data, secret_key) + wrong_key = self.auth._generate_secret_key() + is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key) + self.assertFalse(is_valid) + + def test_encrypt_decrypt_secret_key(self): + """Test secret key encryption and decryption.""" + secret_key = self.auth._generate_secret_key() + password = "test_password" + + encrypted = self.auth.encrypt_secret_key(secret_key, password) + decrypted = self.auth.decrypt_secret_key(encrypted, password) + + self.assertEqual(secret_key, decrypted) + + def test_encrypt_decrypt_wrong_password(self): + """Test secret key decryption with wrong password.""" + secret_key = self.auth._generate_secret_key() + password = "test_password" + wrong_password = "wrong_password" + + encrypted = self.auth.encrypt_secret_key(secret_key, password) + + with self.assertRaises(Exception): + self.auth.decrypt_secret_key(encrypted, wrong_password) + + @patch('src.auth.Database') + def test_authenticate_agent_success(self, mock_db_class): + """Test successful agent authentication.""" + # Mock database + mock_db = Mock() + mock_db_class.return_value = mock_db + + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + key_hash = hashlib.sha256(secret_key.encode()).hexdigest() + + # Mock database response + mock_db.get_agent_credentials.return_value = { + 'agent_id': agent_id, + 'key_hash': key_hash, + 'is_active': True, + 'created_at': datetime.now().isoformat() + } + + result = self.auth.authenticate_agent(agent_id, secret_key) + self.assertTrue(result) + + @patch('src.auth.Database') + def test_authenticate_agent_failure(self, mock_db_class): + """Test failed agent authentication.""" + # Mock database + mock_db = Mock() + mock_db_class.return_value = mock_db + + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + # Mock database response - no credentials found + mock_db.get_agent_credentials.return_value = None + + result = self.auth.authenticate_agent(agent_id, secret_key) + self.assertFalse(result) + + +class TestDatabase(unittest.TestCase): + """Test cases for database operations.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.db = Storage(self.db_path) + self.db.create_tables() + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_create_agent_auth(self): + """Test agent authentication record creation.""" + agent_id = "agent_test123" + secret_key_hash = "test_hash" + encrypted_key = "encrypted_test_key" + + success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key) + self.assertTrue(success) + + # Verify record exists + credentials = self.db.get_agent_credentials(agent_id) + self.assertIsNotNone(credentials) + self.assertEqual(credentials['agent_id'], agent_id) + self.assertEqual(credentials['key_hash'], secret_key_hash) + + def test_get_agent_credentials_exists(self): + """Test retrieving existing agent credentials.""" + agent_id = "agent_test123" + secret_key_hash = "test_hash" + encrypted_key = "encrypted_test_key" + + # Create record + self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key) + + # Retrieve record + credentials = self.db.get_agent_credentials(agent_id) + + self.assertIsNotNone(credentials) + self.assertEqual(credentials['agent_id'], agent_id) + self.assertEqual(credentials['key_hash'], secret_key_hash) + self.assertTrue(credentials['is_active']) + + def test_get_agent_credentials_not_exists(self): + """Test retrieving non-existent agent credentials.""" + credentials = self.db.get_agent_credentials("non_existent_agent") + self.assertIsNone(credentials) + + def test_store_agent_token(self): + """Test storing agent JWT token.""" + agent_id = "agent_test123" + token = "test_jwt_token" + expires_at = (datetime.now() + timedelta(hours=1)).isoformat() + + success = self.db.store_agent_token(agent_id, token, expires_at) + self.assertTrue(success) + + # Verify token exists + stored_token = self.db.get_agent_token(agent_id) + self.assertIsNotNone(stored_token) + self.assertEqual(stored_token['token'], token) + + def test_cleanup_expired_tokens(self): + """Test cleanup of expired tokens.""" + agent_id = "agent_test123" + + # Create expired token + expired_token = "expired_token" + expired_time = (datetime.now() - timedelta(hours=1)).isoformat() + self.db.store_agent_token(agent_id, expired_token, expired_time) + + # Create valid token + valid_token = "valid_token" + valid_time = (datetime.now() + timedelta(hours=1)).isoformat() + self.db.store_agent_token("agent_valid", valid_token, valid_time) + + # Cleanup expired tokens + cleaned = self.db.cleanup_expired_tokens() + self.assertGreaterEqual(cleaned, 1) + + # Verify expired token is gone + token = self.db.get_agent_token(agent_id) + self.assertIsNone(token) + + # Verify valid token remains + token = self.db.get_agent_token("agent_valid") + self.assertIsNotNone(token) + + +class TestIntegration(unittest.TestCase): + """Integration tests for the complete authentication flow.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.auth = AgentAuthentication() + + # Use test database + self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_complete_authentication_flow(self): + """Test complete agent authentication workflow.""" + # Step 1: Generate agent ID + agent_id = self.auth.generate_agent_id() + self.assertIsNotNone(agent_id) + + # Step 2: Create credentials + credentials = self.auth.create_agent_credentials(agent_id) + self.assertIsNotNone(credentials) + + # Step 3: Generate JWT token + token = self.auth.generate_jwt_token( + credentials['agent_id'], + credentials['secret_key'] + ) + self.assertIsNotNone(token) + + # Step 4: Verify token + is_valid = self.auth.verify_jwt_token(token, credentials['secret_key']) + self.assertTrue(is_valid) + + # Step 5: Create HMAC signature + test_data = "test API request" + signature = self.auth.create_hmac_signature(test_data, credentials['secret_key']) + self.assertIsNotNone(signature) + + # Step 6: Verify HMAC signature + is_signature_valid = self.auth.verify_hmac_signature( + test_data, signature, credentials['secret_key'] + ) + self.assertTrue(is_signature_valid) + + +def run_tests(): + """Run all tests.""" + print("🧪 Running PyGuardian Authentication Tests...") + print("=" * 50) + + # Create test suite + test_suite = unittest.TestSuite() + + # Add test classes + test_classes = [ + TestAgentAuthentication, + TestDatabase, + TestIntegration + ] + + for test_class in test_classes: + tests = unittest.TestLoader().loadTestsFromTestCase(test_class) + test_suite.addTests(tests) + + # Run tests + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(test_suite) + + # Print summary + print("\n" + "=" * 50) + print(f"🏁 Tests completed:") + print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}") + print(f" ❌ Failed: {len(result.failures)}") + print(f" 💥 Errors: {len(result.errors)}") + + # Return exit code + return 0 if result.wasSuccessful() else 1 + + +if __name__ == '__main__': + sys.exit(run_tests()) \ No newline at end of file diff --git a/.history/tests/unit/test_authentication_20251125213251.py b/.history/tests/unit/test_authentication_20251125213251.py new file mode 100644 index 0000000..3f5ddfc --- /dev/null +++ b/.history/tests/unit/test_authentication_20251125213251.py @@ -0,0 +1,422 @@ +#!/usr/bin/env python3 +""" +Comprehensive unit tests for PyGuardian authentication system. +""" + +import unittest +import tempfile +import os +import sys +import sqlite3 +import jwt +import hashlib +import hmac +from datetime import datetime, timedelta +from unittest.mock import Mock, patch, MagicMock + +# Add src directory to path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src')) + +from auth import AgentAuthentication +from storage import Storage + + +class TestAgentAuthentication(unittest.TestCase): + """Test cases for agent authentication system.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.test_secret = 'test_secret_key_123' + self.auth = AgentAuthentication(self.test_secret) + + # Create test database + self.db = Storage(self.db_path) + self.db.create_tables() + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_generate_agent_id(self): + """Test agent ID generation.""" + agent_id = self.auth.generate_agent_id() + + # Check format + self.assertTrue(agent_id.startswith('agent_')) + self.assertEqual(len(agent_id), 42) # 'agent_' + 36 char UUID + + # Test uniqueness + agent_id2 = self.auth.generate_agent_id() + self.assertNotEqual(agent_id, agent_id2) + + def test_create_agent_credentials(self): + """Test agent credentials creation.""" + agent_id = self.auth.generate_agent_id() + credentials = self.auth.create_agent_credentials(agent_id) + + # Check required fields + required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash'] + for field in required_fields: + self.assertIn(field, credentials) + + # Check agent ID matches + self.assertEqual(credentials['agent_id'], agent_id) + + # Check secret key length + self.assertEqual(len(credentials['secret_key']), 64) # 32 bytes hex encoded + + # Check key hash + expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest() + self.assertEqual(credentials['key_hash'], expected_hash) + + def test_generate_jwt_token(self): + """Test JWT token generation.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + token = self.auth.generate_jwt_token(agent_id, secret_key) + + # Verify token structure + self.assertIsInstance(token, str) + self.assertTrue(len(token) > 100) # JWT tokens are typically long + + # Decode and verify payload + decoded = jwt.decode(token, secret_key, algorithms=['HS256']) + self.assertEqual(decoded['agent_id'], agent_id) + self.assertIn('iat', decoded) + self.assertIn('exp', decoded) + self.assertIn('jti', decoded) + + def test_verify_jwt_token_valid(self): + """Test JWT token verification with valid token.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + token = self.auth.generate_jwt_token(agent_id, secret_key) + + is_valid = self.auth.verify_jwt_token(token, secret_key) + self.assertTrue(is_valid) + + def test_verify_jwt_token_invalid(self): + """Test JWT token verification with invalid token.""" + secret_key = self.auth._generate_secret_key() + + # Test with invalid token + is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key) + self.assertFalse(is_valid) + + # Test with wrong secret key + agent_id = self.auth.generate_agent_id() + token = self.auth.generate_jwt_token(agent_id, secret_key) + wrong_key = self.auth._generate_secret_key() + + is_valid = self.auth.verify_jwt_token(token, wrong_key) + self.assertFalse(is_valid) + + def test_verify_jwt_token_expired(self): + """Test JWT token verification with expired token.""" + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + # Create expired token + payload = { + 'agent_id': agent_id, + 'exp': datetime.utcnow() - timedelta(hours=1), # Expired 1 hour ago + 'iat': datetime.utcnow() - timedelta(hours=2), + 'jti': self.auth._generate_jti() + } + + expired_token = jwt.encode(payload, secret_key, algorithm='HS256') + + is_valid = self.auth.verify_jwt_token(expired_token, secret_key) + self.assertFalse(is_valid) + + def test_create_hmac_signature(self): + """Test HMAC signature creation.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + signature = self.auth.create_hmac_signature(data, secret_key) + + # Verify signature format + self.assertEqual(len(signature), 64) # SHA256 hex digest + + # Verify signature is correct + expected = hmac.new( + secret_key.encode(), + data.encode(), + hashlib.sha256 + ).hexdigest() + + self.assertEqual(signature, expected) + + def test_verify_hmac_signature_valid(self): + """Test HMAC signature verification with valid signature.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + signature = self.auth.create_hmac_signature(data, secret_key) + is_valid = self.auth.verify_hmac_signature(data, signature, secret_key) + + self.assertTrue(is_valid) + + def test_verify_hmac_signature_invalid(self): + """Test HMAC signature verification with invalid signature.""" + data = "test message" + secret_key = self.auth._generate_secret_key() + + # Test with wrong signature + wrong_signature = "0" * 64 + is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key) + self.assertFalse(is_valid) + + # Test with wrong key + signature = self.auth.create_hmac_signature(data, secret_key) + wrong_key = self.auth._generate_secret_key() + is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key) + self.assertFalse(is_valid) + + def test_encrypt_decrypt_secret_key(self): + """Test secret key encryption and decryption.""" + secret_key = self.auth._generate_secret_key() + password = "test_password" + + encrypted = self.auth.encrypt_secret_key(secret_key, password) + decrypted = self.auth.decrypt_secret_key(encrypted, password) + + self.assertEqual(secret_key, decrypted) + + def test_encrypt_decrypt_wrong_password(self): + """Test secret key decryption with wrong password.""" + secret_key = self.auth._generate_secret_key() + password = "test_password" + wrong_password = "wrong_password" + + encrypted = self.auth.encrypt_secret_key(secret_key, password) + + with self.assertRaises(Exception): + self.auth.decrypt_secret_key(encrypted, wrong_password) + + @patch('src.auth.Database') + def test_authenticate_agent_success(self, mock_db_class): + """Test successful agent authentication.""" + # Mock database + mock_db = Mock() + mock_db_class.return_value = mock_db + + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + key_hash = hashlib.sha256(secret_key.encode()).hexdigest() + + # Mock database response + mock_db.get_agent_credentials.return_value = { + 'agent_id': agent_id, + 'key_hash': key_hash, + 'is_active': True, + 'created_at': datetime.now().isoformat() + } + + result = self.auth.authenticate_agent(agent_id, secret_key) + self.assertTrue(result) + + @patch('src.auth.Database') + def test_authenticate_agent_failure(self, mock_db_class): + """Test failed agent authentication.""" + # Mock database + mock_db = Mock() + mock_db_class.return_value = mock_db + + agent_id = self.auth.generate_agent_id() + secret_key = self.auth._generate_secret_key() + + # Mock database response - no credentials found + mock_db.get_agent_credentials.return_value = None + + result = self.auth.authenticate_agent(agent_id, secret_key) + self.assertFalse(result) + + +class TestDatabase(unittest.TestCase): + """Test cases for database operations.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.db = Storage(self.db_path) + self.db.create_tables() + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_create_agent_auth(self): + """Test agent authentication record creation.""" + agent_id = "agent_test123" + secret_key_hash = "test_hash" + encrypted_key = "encrypted_test_key" + + success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key) + self.assertTrue(success) + + # Verify record exists + credentials = self.db.get_agent_credentials(agent_id) + self.assertIsNotNone(credentials) + self.assertEqual(credentials['agent_id'], agent_id) + self.assertEqual(credentials['key_hash'], secret_key_hash) + + def test_get_agent_credentials_exists(self): + """Test retrieving existing agent credentials.""" + agent_id = "agent_test123" + secret_key_hash = "test_hash" + encrypted_key = "encrypted_test_key" + + # Create record + self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key) + + # Retrieve record + credentials = self.db.get_agent_credentials(agent_id) + + self.assertIsNotNone(credentials) + self.assertEqual(credentials['agent_id'], agent_id) + self.assertEqual(credentials['key_hash'], secret_key_hash) + self.assertTrue(credentials['is_active']) + + def test_get_agent_credentials_not_exists(self): + """Test retrieving non-existent agent credentials.""" + credentials = self.db.get_agent_credentials("non_existent_agent") + self.assertIsNone(credentials) + + def test_store_agent_token(self): + """Test storing agent JWT token.""" + agent_id = "agent_test123" + token = "test_jwt_token" + expires_at = (datetime.now() + timedelta(hours=1)).isoformat() + + success = self.db.store_agent_token(agent_id, token, expires_at) + self.assertTrue(success) + + # Verify token exists + stored_token = self.db.get_agent_token(agent_id) + self.assertIsNotNone(stored_token) + self.assertEqual(stored_token['token'], token) + + def test_cleanup_expired_tokens(self): + """Test cleanup of expired tokens.""" + agent_id = "agent_test123" + + # Create expired token + expired_token = "expired_token" + expired_time = (datetime.now() - timedelta(hours=1)).isoformat() + self.db.store_agent_token(agent_id, expired_token, expired_time) + + # Create valid token + valid_token = "valid_token" + valid_time = (datetime.now() + timedelta(hours=1)).isoformat() + self.db.store_agent_token("agent_valid", valid_token, valid_time) + + # Cleanup expired tokens + cleaned = self.db.cleanup_expired_tokens() + self.assertGreaterEqual(cleaned, 1) + + # Verify expired token is gone + token = self.db.get_agent_token(agent_id) + self.assertIsNone(token) + + # Verify valid token remains + token = self.db.get_agent_token("agent_valid") + self.assertIsNotNone(token) + + +class TestIntegration(unittest.TestCase): + """Integration tests for the complete authentication flow.""" + + def setUp(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') + self.auth = AgentAuthentication() + + # Use test database + self.original_db_path = self.auth.db_path if hasattr(self.auth, 'db_path') else None + + def tearDown(self): + """Clean up test fixtures.""" + if os.path.exists(self.db_path): + os.remove(self.db_path) + os.rmdir(self.temp_dir) + + def test_complete_authentication_flow(self): + """Test complete agent authentication workflow.""" + # Step 1: Generate agent ID + agent_id = self.auth.generate_agent_id() + self.assertIsNotNone(agent_id) + + # Step 2: Create credentials + credentials = self.auth.create_agent_credentials(agent_id) + self.assertIsNotNone(credentials) + + # Step 3: Generate JWT token + token = self.auth.generate_jwt_token( + credentials['agent_id'], + credentials['secret_key'] + ) + self.assertIsNotNone(token) + + # Step 4: Verify token + is_valid = self.auth.verify_jwt_token(token, credentials['secret_key']) + self.assertTrue(is_valid) + + # Step 5: Create HMAC signature + test_data = "test API request" + signature = self.auth.create_hmac_signature(test_data, credentials['secret_key']) + self.assertIsNotNone(signature) + + # Step 6: Verify HMAC signature + is_signature_valid = self.auth.verify_hmac_signature( + test_data, signature, credentials['secret_key'] + ) + self.assertTrue(is_signature_valid) + + +def run_tests(): + """Run all tests.""" + print("🧪 Running PyGuardian Authentication Tests...") + print("=" * 50) + + # Create test suite + test_suite = unittest.TestSuite() + + # Add test classes + test_classes = [ + TestAgentAuthentication, + TestDatabase, + TestIntegration + ] + + for test_class in test_classes: + tests = unittest.TestLoader().loadTestsFromTestCase(test_class) + test_suite.addTests(tests) + + # Run tests + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(test_suite) + + # Print summary + print("\n" + "=" * 50) + print(f"🏁 Tests completed:") + print(f" ✅ Passed: {result.testsRun - len(result.failures) - len(result.errors)}") + print(f" ❌ Failed: {len(result.failures)}") + print(f" 💥 Errors: {len(result.errors)}") + + # Return exit code + return 0 if result.wasSuccessful() else 1 + + +if __name__ == '__main__': + sys.exit(run_tests()) \ No newline at end of file diff --git a/src/sessions.py b/src/sessions.py index e2c6eb0..6458cb1 100644 --- a/src/sessions.py +++ b/src/sessions.py @@ -7,7 +7,7 @@ import asyncio import logging import re import os -from datetime import datetime +from datetime import datetime, timedelta from typing import Dict, List, Optional import psutil diff --git a/tests/unit/test_authentication.py b/tests/unit/test_authentication.py index d7e49ba..3f5ddfc 100644 --- a/tests/unit/test_authentication.py +++ b/tests/unit/test_authentication.py @@ -18,7 +18,7 @@ from unittest.mock import Mock, patch, MagicMock sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src')) from auth import AgentAuthentication -from storage import Database +from storage import Storage class TestAgentAuthentication(unittest.TestCase): @@ -28,10 +28,11 @@ class TestAgentAuthentication(unittest.TestCase): """Set up test fixtures.""" self.temp_dir = tempfile.mkdtemp() self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') - self.auth = AgentAuthentication() + self.test_secret = 'test_secret_key_123' + self.auth = AgentAuthentication(self.test_secret) # Create test database - self.db = Database(self.db_path) + self.db = Storage(self.db_path) self.db.create_tables() def tearDown(self): @@ -245,7 +246,7 @@ class TestDatabase(unittest.TestCase): """Set up test fixtures.""" self.temp_dir = tempfile.mkdtemp() self.db_path = os.path.join(self.temp_dir, 'test_guardian.db') - self.db = Database(self.db_path) + self.db = Storage(self.db_path) self.db.create_tables() def tearDown(self):