#!/usr/bin/env python3
"""
Comprehensive unit tests for PyGuardian authentication system.

Covers ``auth.AgentAuthentication`` (IDs, credentials, JWT, HMAC, key
encryption), the storage layer's auth/token records, and an end-to-end flow.
Run directly (``python this_file.py``) for a summary, or via any unittest
runner.
"""

import unittest
import tempfile
import os
import shutil
import sys
import sqlite3
import jwt
import hashlib
import hmac
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock, patch, MagicMock

# Add src directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))

from auth import AgentAuthentication
from storage import Storage

# NOTE(review): the tests below were written against a ``Database`` class, but
# the only storage class imported is ``Storage``; without this alias every
# ``Database(...)`` call raised NameError. Confirm that ``storage.Storage``
# provides create_tables / create_agent_auth / get_agent_credentials /
# store_agent_token / get_agent_token / cleanup_expired_tokens.
Database = Storage


def _make_temp_db_path(test_case):
    """Create a throwaway directory for *test_case* and a DB path inside it.

    Cleanup is registered with ``addCleanup`` immediately, so the directory is
    removed even if the rest of ``setUp`` raises (``tearDown`` is skipped in
    that case). ``shutil.rmtree`` also removes any extra files the storage
    layer may have created next to the database.

    Returns:
        tuple[str, str]: ``(temp_dir, db_path)``.
    """
    temp_dir = tempfile.mkdtemp()
    test_case.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
    return temp_dir, os.path.join(temp_dir, 'test_guardian.db')


class TestAgentAuthentication(unittest.TestCase):
    """Test cases for agent authentication system."""

    def setUp(self):
        """Create a fresh AgentAuthentication and a temp database."""
        self.temp_dir, self.db_path = _make_temp_db_path(self)
        self.auth = AgentAuthentication()

        # Create test database
        self.db = Database(self.db_path)
        self.db.create_tables()

    def test_generate_agent_id(self):
        """Test agent ID generation."""
        agent_id = self.auth.generate_agent_id()

        # Check format
        self.assertTrue(agent_id.startswith('agent_'))
        self.assertEqual(len(agent_id), 42)  # 'agent_' + 36 char UUID

        # Test uniqueness
        agent_id2 = self.auth.generate_agent_id()
        self.assertNotEqual(agent_id, agent_id2)

    def test_create_agent_credentials(self):
        """Test agent credentials creation."""
        agent_id = self.auth.generate_agent_id()
        credentials = self.auth.create_agent_credentials(agent_id)

        # Check required fields
        required_fields = ['agent_id', 'secret_key', 'encrypted_key', 'key_hash']
        for field in required_fields:
            self.assertIn(field, credentials)

        # Check agent ID matches
        self.assertEqual(credentials['agent_id'], agent_id)

        # Check secret key length
        self.assertEqual(len(credentials['secret_key']), 64)  # 32 bytes hex encoded

        # Check key hash
        expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest()
        self.assertEqual(credentials['key_hash'], expected_hash)

    def test_generate_jwt_token(self):
        """Test JWT token generation."""
        agent_id = self.auth.generate_agent_id()
        secret_key = self.auth._generate_secret_key()

        token = self.auth.generate_jwt_token(agent_id, secret_key)

        # Verify token structure
        self.assertIsInstance(token, str)
        self.assertTrue(len(token) > 100)  # JWT tokens are typically long

        # Decode and verify payload
        decoded = jwt.decode(token, secret_key, algorithms=['HS256'])
        self.assertEqual(decoded['agent_id'], agent_id)
        self.assertIn('iat', decoded)
        self.assertIn('exp', decoded)
        self.assertIn('jti', decoded)

    def test_verify_jwt_token_valid(self):
        """Test JWT token verification with valid token."""
        agent_id = self.auth.generate_agent_id()
        secret_key = self.auth._generate_secret_key()

        token = self.auth.generate_jwt_token(agent_id, secret_key)
        is_valid = self.auth.verify_jwt_token(token, secret_key)

        self.assertTrue(is_valid)

    def test_verify_jwt_token_invalid(self):
        """Test JWT token verification with invalid token."""
        secret_key = self.auth._generate_secret_key()

        # Test with invalid token
        is_valid = self.auth.verify_jwt_token("invalid.jwt.token", secret_key)
        self.assertFalse(is_valid)

        # Test with wrong secret key
        agent_id = self.auth.generate_agent_id()
        token = self.auth.generate_jwt_token(agent_id, secret_key)
        wrong_key = self.auth._generate_secret_key()

        is_valid = self.auth.verify_jwt_token(token, wrong_key)
        self.assertFalse(is_valid)

    def test_verify_jwt_token_expired(self):
        """Test JWT token verification with expired token."""
        agent_id = self.auth.generate_agent_id()
        secret_key = self.auth._generate_secret_key()

        # Create expired token. Use one timezone-aware "now" so exp/iat are
        # consistent (datetime.utcnow() is deprecated since Python 3.12).
        now = datetime.now(timezone.utc)
        payload = {
            'agent_id': agent_id,
            'exp': now - timedelta(hours=1),  # Expired 1 hour ago
            'iat': now - timedelta(hours=2),
            'jti': self.auth._generate_jti(),
        }

        expired_token = jwt.encode(payload, secret_key, algorithm='HS256')
        is_valid = self.auth.verify_jwt_token(expired_token, secret_key)

        self.assertFalse(is_valid)

    def test_create_hmac_signature(self):
        """Test HMAC signature creation."""
        data = "test message"
        secret_key = self.auth._generate_secret_key()

        signature = self.auth.create_hmac_signature(data, secret_key)

        # Verify signature format
        self.assertEqual(len(signature), 64)  # SHA256 hex digest

        # Verify signature is correct
        expected = hmac.new(
            secret_key.encode(),
            data.encode(),
            hashlib.sha256
        ).hexdigest()
        self.assertEqual(signature, expected)

    def test_verify_hmac_signature_valid(self):
        """Test HMAC signature verification with valid signature."""
        data = "test message"
        secret_key = self.auth._generate_secret_key()

        signature = self.auth.create_hmac_signature(data, secret_key)
        is_valid = self.auth.verify_hmac_signature(data, signature, secret_key)

        self.assertTrue(is_valid)

    def test_verify_hmac_signature_invalid(self):
        """Test HMAC signature verification with invalid signature."""
        data = "test message"
        secret_key = self.auth._generate_secret_key()

        # Test with wrong signature
        wrong_signature = "0" * 64
        is_valid = self.auth.verify_hmac_signature(data, wrong_signature, secret_key)
        self.assertFalse(is_valid)

        # Test with wrong key
        signature = self.auth.create_hmac_signature(data, secret_key)
        wrong_key = self.auth._generate_secret_key()
        is_valid = self.auth.verify_hmac_signature(data, signature, wrong_key)
        self.assertFalse(is_valid)

    def test_encrypt_decrypt_secret_key(self):
        """Test secret key encryption and decryption."""
        secret_key = self.auth._generate_secret_key()
        password = "test_password"

        encrypted = self.auth.encrypt_secret_key(secret_key, password)
        decrypted = self.auth.decrypt_secret_key(encrypted, password)

        self.assertEqual(secret_key, decrypted)

    def test_encrypt_decrypt_wrong_password(self):
        """Test secret key decryption with wrong password."""
        secret_key = self.auth._generate_secret_key()
        password = "test_password"
        wrong_password = "wrong_password"

        encrypted = self.auth.encrypt_secret_key(secret_key, password)

        with self.assertRaises(Exception):
            self.auth.decrypt_secret_key(encrypted, wrong_password)

    # Patch the name where AgentAuthentication looks it up: the class is
    # imported from the module ``auth`` (src/ is on sys.path), so patching
    # ``src.auth`` would touch a different module object, or fail to import
    # if src/ is not a package.
    # NOTE(review): assumes ``auth`` references a module-level ``Database``,
    # and that AgentAuthentication resolves it at call time rather than in
    # __init__ (self.auth is built in setUp, before the patch is active).
    @patch('auth.Database')
    def test_authenticate_agent_success(self, mock_db_class):
        """Test successful agent authentication."""
        # Mock database
        mock_db = Mock()
        mock_db_class.return_value = mock_db

        agent_id = self.auth.generate_agent_id()
        secret_key = self.auth._generate_secret_key()
        key_hash = hashlib.sha256(secret_key.encode()).hexdigest()

        # Mock database response
        mock_db.get_agent_credentials.return_value = {
            'agent_id': agent_id,
            'key_hash': key_hash,
            'is_active': True,
            'created_at': datetime.now().isoformat(),
        }

        result = self.auth.authenticate_agent(agent_id, secret_key)
        self.assertTrue(result)

    @patch('auth.Database')  # see note on test_authenticate_agent_success
    def test_authenticate_agent_failure(self, mock_db_class):
        """Test failed agent authentication."""
        # Mock database
        mock_db = Mock()
        mock_db_class.return_value = mock_db

        agent_id = self.auth.generate_agent_id()
        secret_key = self.auth._generate_secret_key()

        # Mock database response - no credentials found
        mock_db.get_agent_credentials.return_value = None

        result = self.auth.authenticate_agent(agent_id, secret_key)
        self.assertFalse(result)


class TestDatabase(unittest.TestCase):
    """Test cases for database operations."""

    def setUp(self):
        """Create a fresh database with tables in a temp directory."""
        self.temp_dir, self.db_path = _make_temp_db_path(self)
        self.db = Database(self.db_path)
        self.db.create_tables()

    def test_create_agent_auth(self):
        """Test agent authentication record creation."""
        agent_id = "agent_test123"
        secret_key_hash = "test_hash"
        encrypted_key = "encrypted_test_key"

        success = self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)
        self.assertTrue(success)

        # Verify record exists
        credentials = self.db.get_agent_credentials(agent_id)
        self.assertIsNotNone(credentials)
        self.assertEqual(credentials['agent_id'], agent_id)
        self.assertEqual(credentials['key_hash'], secret_key_hash)

    def test_get_agent_credentials_exists(self):
        """Test retrieving existing agent credentials."""
        agent_id = "agent_test123"
        secret_key_hash = "test_hash"
        encrypted_key = "encrypted_test_key"

        # Create record
        self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)

        # Retrieve record
        credentials = self.db.get_agent_credentials(agent_id)

        self.assertIsNotNone(credentials)
        self.assertEqual(credentials['agent_id'], agent_id)
        self.assertEqual(credentials['key_hash'], secret_key_hash)
        self.assertTrue(credentials['is_active'])

    def test_get_agent_credentials_not_exists(self):
        """Test retrieving non-existent agent credentials."""
        credentials = self.db.get_agent_credentials("non_existent_agent")
        self.assertIsNone(credentials)

    def test_store_agent_token(self):
        """Test storing agent JWT token."""
        agent_id = "agent_test123"
        token = "test_jwt_token"
        expires_at = (datetime.now() + timedelta(hours=1)).isoformat()

        success = self.db.store_agent_token(agent_id, token, expires_at)
        self.assertTrue(success)

        # Verify token exists
        stored_token = self.db.get_agent_token(agent_id)
        self.assertIsNotNone(stored_token)
        self.assertEqual(stored_token['token'], token)

    def test_cleanup_expired_tokens(self):
        """Test cleanup of expired tokens."""
        agent_id = "agent_test123"

        # Create expired token
        expired_token = "expired_token"
        expired_time = (datetime.now() - timedelta(hours=1)).isoformat()
        self.db.store_agent_token(agent_id, expired_token, expired_time)

        # Create valid token
        valid_token = "valid_token"
        valid_time = (datetime.now() + timedelta(hours=1)).isoformat()
        self.db.store_agent_token("agent_valid", valid_token, valid_time)

        # Cleanup expired tokens
        cleaned = self.db.cleanup_expired_tokens()
        self.assertGreaterEqual(cleaned, 1)

        # Verify expired token is gone
        token = self.db.get_agent_token(agent_id)
        self.assertIsNone(token)

        # Verify valid token remains
        token = self.db.get_agent_token("agent_valid")
        self.assertIsNotNone(token)


class TestIntegration(unittest.TestCase):
    """Integration tests for the complete authentication flow."""

    def setUp(self):
        """Create an AgentAuthentication and a temp directory for the test."""
        self.temp_dir, self.db_path = _make_temp_db_path(self)
        self.auth = AgentAuthentication()

        # Remember the instance's db_path, if it has one. Nothing here
        # redirects AgentAuthentication to self.db_path.
        self.original_db_path = getattr(self.auth, 'db_path', None)

    def test_complete_authentication_flow(self):
        """Test complete agent authentication workflow."""
        # Step 1: Generate agent ID
        agent_id = self.auth.generate_agent_id()
        self.assertIsNotNone(agent_id)

        # Step 2: Create credentials
        credentials = self.auth.create_agent_credentials(agent_id)
        self.assertIsNotNone(credentials)

        # Step 3: Generate JWT token
        token = self.auth.generate_jwt_token(
            credentials['agent_id'],
            credentials['secret_key']
        )
        self.assertIsNotNone(token)

        # Step 4: Verify token
        is_valid = self.auth.verify_jwt_token(token, credentials['secret_key'])
        self.assertTrue(is_valid)

        # Step 5: Create HMAC signature
        test_data = "test API request"
        signature = self.auth.create_hmac_signature(test_data, credentials['secret_key'])
        self.assertIsNotNone(signature)

        # Step 6: Verify HMAC signature
        is_signature_valid = self.auth.verify_hmac_signature(
            test_data, signature, credentials['secret_key']
        )
        self.assertTrue(is_signature_valid)


def run_tests():
    """Run every test class in this module and print a summary.

    Returns:
        int: process exit code, 0 if the run was successful, otherwise 1.
    """
    print("🧪 Running PyGuardian Authentication Tests...")
    print("=" * 50)

    # Build one suite from all test classes with a single loader.
    loader = unittest.TestLoader()
    test_suite = unittest.TestSuite()
    for test_class in (TestAgentAuthentication, TestDatabase, TestIntegration):
        test_suite.addTests(loader.loadTestsFromTestCase(test_class))

    # Run tests
    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(test_suite)

    failed = len(result.failures)
    errors = len(result.errors)
    skipped = len(result.skipped)
    # Skipped tests are included in testsRun but did not pass; exclude them.
    passed = result.testsRun - failed - errors - skipped

    # Print summary
    print("\n" + "=" * 50)
    print("🏁 Tests completed:")
    print(f"   ✅ Passed: {passed}")
    print(f"   ❌ Failed: {failed}")
    print(f"   💥 Errors: {errors}")
    print(f"   ⏭️ Skipped: {skipped}")

    # Return exit code
    return 0 if result.wasSuccessful() else 1


if __name__ == '__main__':
    sys.exit(run_tests())