Files
PyGuardian/.history/tests/unit/test_authentication_20251125212848.py
Andrey K. Choi 9f2cc216d5
Some checks reported errors
continuous-integration/drone/push Build was killed
fix: Resolve import issues and test compatibility
- Fix Storage class reference in authentication tests
- Add secret_key parameter to AgentAuthentication initialization
- Fix timedelta import in sessions.py
- Basic authentication functionality verified
2025-11-25 21:33:17 +09:00

421 lines
15 KiB
Python

#!/usr/bin/env python3
"""
Comprehensive unit tests for PyGuardian authentication system.
"""
import hashlib
import hmac
import os
import shutil
import sqlite3
import sys
import tempfile
import unittest
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock, patch, MagicMock

import jwt

# Add src directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../src'))
from auth import AgentAuthentication
from storage import Storage
class TestAgentAuthentication(unittest.TestCase):
    """Unit tests for AgentAuthentication: IDs, credentials, JWT, HMAC and key encryption."""

    def setUp(self):
        """Create a fresh authenticator and a throwaway on-disk database for each test."""
        self.temp_dir = tempfile.mkdtemp()
        # Registered immediately so the directory is removed even if the rest of
        # setUp raises (tearDown is not called in that case) and even if the
        # database leaves extra files beside the main one (os.rmdir would fail).
        self.addCleanup(shutil.rmtree, self.temp_dir, ignore_errors=True)
        self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
        self.auth = AgentAuthentication()
        # Create test database. `Storage` is the storage class this module
        # imports; `Database` is not defined here and raised NameError.
        # NOTE(review): assumes Storage(path).create_tables() is the storage API — confirm.
        self.db = Storage(self.db_path)
        self.db.create_tables()

    def test_generate_agent_id(self):
        """Agent IDs have the expected prefix/length and differ between calls."""
        agent_id = self.auth.generate_agent_id()

        # Expected format: 'agent_' (6 chars) + 36-char UUID string
        self.assertTrue(agent_id.startswith('agent_'))
        self.assertEqual(len(agent_id), 42)

        # Two consecutive IDs must not collide
        agent_id2 = self.auth.generate_agent_id()
        self.assertNotEqual(agent_id, agent_id2)

    def test_create_agent_credentials(self):
        """Credentials carry the agent ID, a 64-hex-char secret and its SHA-256 hash."""
        agent_id = self.auth.generate_agent_id()
        credentials = self.auth.create_agent_credentials(agent_id)

        for field in ('agent_id', 'secret_key', 'encrypted_key', 'key_hash'):
            self.assertIn(field, credentials)

        self.assertEqual(credentials['agent_id'], agent_id)
        # 32 bytes, hex encoded
        self.assertEqual(len(credentials['secret_key']), 64)
        # key_hash must be the SHA-256 hex digest of the secret key
        expected_hash = hashlib.sha256(credentials['secret_key'].encode()).hexdigest()
        self.assertEqual(credentials['key_hash'], expected_hash)

    def test_generate_jwt_token(self):
        """Generated tokens decode with HS256 and contain agent_id, iat, exp and jti."""
        agent_id = self.auth.generate_agent_id()
        secret_key = self.auth._generate_secret_key()
        token = self.auth.generate_jwt_token(agent_id, secret_key)

        self.assertIsInstance(token, str)
        # Loose sanity check: signed JWTs with these claims are typically long
        self.assertGreater(len(token), 100)

        decoded = jwt.decode(token, secret_key, algorithms=['HS256'])
        self.assertEqual(decoded['agent_id'], agent_id)
        for claim in ('iat', 'exp', 'jti'):
            self.assertIn(claim, decoded)

    def test_verify_jwt_token_valid(self):
        """A freshly generated token verifies against the key that signed it."""
        agent_id = self.auth.generate_agent_id()
        secret_key = self.auth._generate_secret_key()
        token = self.auth.generate_jwt_token(agent_id, secret_key)

        self.assertTrue(self.auth.verify_jwt_token(token, secret_key))

    def test_verify_jwt_token_invalid(self):
        """Malformed tokens and tokens checked with the wrong key are rejected."""
        secret_key = self.auth._generate_secret_key()

        # Structurally invalid token
        self.assertFalse(self.auth.verify_jwt_token("invalid.jwt.token", secret_key))

        # Valid token, but verified with a different key
        agent_id = self.auth.generate_agent_id()
        token = self.auth.generate_jwt_token(agent_id, secret_key)
        wrong_key = self.auth._generate_secret_key()
        self.assertFalse(self.auth.verify_jwt_token(token, wrong_key))

    def test_verify_jwt_token_expired(self):
        """A correctly signed token whose exp lies in the past is rejected."""
        agent_id = self.auth.generate_agent_id()
        secret_key = self.auth._generate_secret_key()

        # Build the token by hand so exp can be placed in the past. An aware
        # UTC timestamp replaces datetime.utcnow(), deprecated since Python 3.12.
        now = datetime.now(timezone.utc)
        payload = {
            'agent_id': agent_id,
            'exp': now - timedelta(hours=1),  # expired 1 hour ago
            'iat': now - timedelta(hours=2),
            'jti': self.auth._generate_jti(),
        }
        expired_token = jwt.encode(payload, secret_key, algorithm='HS256')

        self.assertFalse(self.auth.verify_jwt_token(expired_token, secret_key))

    def test_create_hmac_signature(self):
        """Signatures equal the HMAC-SHA256 hex digest of the data under the key."""
        data = "test message"
        secret_key = self.auth._generate_secret_key()
        signature = self.auth.create_hmac_signature(data, secret_key)

        # SHA-256 hex digest length
        self.assertEqual(len(signature), 64)

        expected = hmac.new(
            secret_key.encode(),
            data.encode(),
            hashlib.sha256,
        ).hexdigest()
        self.assertEqual(signature, expected)

    def test_verify_hmac_signature_valid(self):
        """A signature produced by create_hmac_signature verifies successfully."""
        data = "test message"
        secret_key = self.auth._generate_secret_key()
        signature = self.auth.create_hmac_signature(data, secret_key)

        self.assertTrue(self.auth.verify_hmac_signature(data, signature, secret_key))

    def test_verify_hmac_signature_invalid(self):
        """Wrong signatures and wrong keys both fail verification."""
        data = "test message"
        secret_key = self.auth._generate_secret_key()

        # Well-formed but incorrect signature
        wrong_signature = "0" * 64
        self.assertFalse(self.auth.verify_hmac_signature(data, wrong_signature, secret_key))

        # Correct signature checked against a different key
        signature = self.auth.create_hmac_signature(data, secret_key)
        wrong_key = self.auth._generate_secret_key()
        self.assertFalse(self.auth.verify_hmac_signature(data, signature, wrong_key))

    def test_encrypt_decrypt_secret_key(self):
        """Encrypting then decrypting with the same password round-trips the key."""
        secret_key = self.auth._generate_secret_key()
        password = "test_password"

        encrypted = self.auth.encrypt_secret_key(secret_key, password)
        decrypted = self.auth.decrypt_secret_key(encrypted, password)

        self.assertEqual(secret_key, decrypted)

    def test_encrypt_decrypt_wrong_password(self):
        """Decrypting with the wrong password raises."""
        secret_key = self.auth._generate_secret_key()
        password = "test_password"
        wrong_password = "wrong_password"

        encrypted = self.auth.encrypt_secret_key(secret_key, password)

        # NOTE(review): the concrete exception type is not visible here; narrow
        # this once the decrypt_secret_key contract is confirmed.
        with self.assertRaises(Exception):
            self.auth.decrypt_secret_key(encrypted, wrong_password)

    # Patch the name inside the `auth` module: that is how this file imports the
    # code under test (`from auth import ...`). 'src.auth' would be a separate
    # module object, or not importable at all with only src/ on sys.path.
    # NOTE(review): self.auth is built in setUp, before this patch is active, so
    # the mock only takes effect if authenticate_agent() creates its Database
    # itself — confirm against auth.py.
    @patch('auth.Database')
    def test_authenticate_agent_success(self, mock_db_class):
        """authenticate_agent succeeds when stored key_hash matches the secret."""
        mock_db = Mock()
        mock_db_class.return_value = mock_db

        agent_id = self.auth.generate_agent_id()
        secret_key = self.auth._generate_secret_key()
        key_hash = hashlib.sha256(secret_key.encode()).hexdigest()

        # Stored record for an active agent with the matching hash
        mock_db.get_agent_credentials.return_value = {
            'agent_id': agent_id,
            'key_hash': key_hash,
            'is_active': True,
            'created_at': datetime.now().isoformat(),
        }

        self.assertTrue(self.auth.authenticate_agent(agent_id, secret_key))

    # See the patch-target note on test_authenticate_agent_success.
    @patch('auth.Database')
    def test_authenticate_agent_failure(self, mock_db_class):
        """authenticate_agent fails when no credentials are stored for the agent."""
        mock_db = Mock()
        mock_db_class.return_value = mock_db

        agent_id = self.auth.generate_agent_id()
        secret_key = self.auth._generate_secret_key()

        # No credentials found
        mock_db.get_agent_credentials.return_value = None

        self.assertFalse(self.auth.authenticate_agent(agent_id, secret_key))
class TestDatabase(unittest.TestCase):
    """Tests for agent-auth persistence: credential records and JWT token storage."""

    def setUp(self):
        """Create an empty database inside a fresh temporary directory."""
        self.temp_dir = tempfile.mkdtemp()
        # Registered immediately so cleanup runs even if setUp fails part-way
        # (tearDown would be skipped) and even if extra DB files are left behind.
        self.addCleanup(shutil.rmtree, self.temp_dir, ignore_errors=True)
        self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
        # `Storage` is the storage class this module imports; `Database` is not
        # defined here and raised NameError.
        # NOTE(review): the tests below assume Storage provides create_tables,
        # create_agent_auth, get_agent_credentials, store_agent_token,
        # get_agent_token and cleanup_expired_tokens — confirm against storage.py.
        self.db = Storage(self.db_path)
        self.db.create_tables()

    def test_create_agent_auth(self):
        """create_agent_auth reports success and the record becomes readable."""
        agent_id = "agent_test123"
        secret_key_hash = "test_hash"
        encrypted_key = "encrypted_test_key"

        self.assertTrue(self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key))

        credentials = self.db.get_agent_credentials(agent_id)
        self.assertIsNotNone(credentials)
        self.assertEqual(credentials['agent_id'], agent_id)
        self.assertEqual(credentials['key_hash'], secret_key_hash)

    def test_get_agent_credentials_exists(self):
        """A newly created agent record is returned and marked active."""
        agent_id = "agent_test123"
        secret_key_hash = "test_hash"
        encrypted_key = "encrypted_test_key"
        self.db.create_agent_auth(agent_id, secret_key_hash, encrypted_key)

        credentials = self.db.get_agent_credentials(agent_id)

        self.assertIsNotNone(credentials)
        self.assertEqual(credentials['agent_id'], agent_id)
        self.assertEqual(credentials['key_hash'], secret_key_hash)
        self.assertTrue(credentials['is_active'])

    def test_get_agent_credentials_not_exists(self):
        """Looking up an unknown agent returns None."""
        self.assertIsNone(self.db.get_agent_credentials("non_existent_agent"))

    def test_store_agent_token(self):
        """A stored token can be read back for the same agent."""
        agent_id = "agent_test123"
        token = "test_jwt_token"
        expires_at = (datetime.now() + timedelta(hours=1)).isoformat()

        self.assertTrue(self.db.store_agent_token(agent_id, token, expires_at))

        stored_token = self.db.get_agent_token(agent_id)
        self.assertIsNotNone(stored_token)
        self.assertEqual(stored_token['token'], token)

    def test_cleanup_expired_tokens(self):
        """cleanup_expired_tokens removes past-expiry tokens and keeps live ones."""
        agent_id = "agent_test123"

        # One token that expired an hour ago (naive local-time ISO string)...
        expired_time = (datetime.now() - timedelta(hours=1)).isoformat()
        self.db.store_agent_token(agent_id, "expired_token", expired_time)

        # ...and one that is still valid for another hour, on another agent
        valid_time = (datetime.now() + timedelta(hours=1)).isoformat()
        self.db.store_agent_token("agent_valid", "valid_token", valid_time)

        cleaned = self.db.cleanup_expired_tokens()
        self.assertGreaterEqual(cleaned, 1)

        # Expired token is gone, valid token remains
        self.assertIsNone(self.db.get_agent_token(agent_id))
        self.assertIsNotNone(self.db.get_agent_token("agent_valid"))
class TestIntegration(unittest.TestCase):
    """Integration tests for the complete authentication flow."""

    def setUp(self):
        """Create an authenticator and a per-test scratch directory."""
        self.temp_dir = tempfile.mkdtemp()
        # Cleanup runs even if setUp fails later on; unlike os.rmdir it also
        # tolerates leftover files in the directory.
        self.addCleanup(shutil.rmtree, self.temp_dir, ignore_errors=True)
        self.db_path = os.path.join(self.temp_dir, 'test_guardian.db')
        self.auth = AgentAuthentication()
        # NOTE(review): db_path is not passed to self.auth, so this flow does not
        # use a test database; wire it in if authentication should hit storage.

    def test_complete_authentication_flow(self):
        """Walk through ID -> credentials -> JWT -> HMAC, verifying each step."""
        # Step 1: Generate agent ID
        agent_id = self.auth.generate_agent_id()
        self.assertIsNotNone(agent_id)

        # Step 2: Create credentials
        credentials = self.auth.create_agent_credentials(agent_id)
        self.assertIsNotNone(credentials)
        secret_key = credentials['secret_key']

        # Step 3: Generate JWT token
        token = self.auth.generate_jwt_token(credentials['agent_id'], secret_key)
        self.assertIsNotNone(token)

        # Step 4: Verify token
        self.assertTrue(self.auth.verify_jwt_token(token, secret_key))

        # Step 5: Create HMAC signature over a sample request body
        test_data = "test API request"
        signature = self.auth.create_hmac_signature(test_data, secret_key)
        self.assertIsNotNone(signature)

        # Step 6: Verify HMAC signature
        self.assertTrue(
            self.auth.verify_hmac_signature(test_data, signature, secret_key)
        )
def run_tests():
    """Run every test class in this module and print a short summary.

    Returns:
        0 if the run was successful, 1 otherwise (suitable for sys.exit).
    """
    print("🧪 Running PyGuardian Authentication Tests...")
    print("=" * 50)

    # Build a single suite from all test classes, reusing one loader
    loader = unittest.TestLoader()
    test_suite = unittest.TestSuite()
    for test_class in (TestAgentAuthentication, TestDatabase, TestIntegration):
        test_suite.addTests(loader.loadTestsFromTestCase(test_class))

    runner = unittest.TextTestRunner(verbosity=2)
    result = runner.run(test_suite)

    failed = len(result.failures)
    errored = len(result.errors)
    skipped = len(result.skipped)
    # testsRun also counts skipped tests; they did not pass, so exclude them.
    passed = result.testsRun - failed - errored - skipped

    print("\n" + "=" * 50)
    print("🏁 Tests completed:")
    print(f" ✅ Passed: {passed}")
    print(f" ❌ Failed: {failed}")
    print(f" 💥 Errors: {errored}")
    if skipped:
        print(f" ⏭️ Skipped: {skipped}")

    return 0 if result.wasSuccessful() else 1
if __name__ == '__main__':
    # Script entry point: propagate run_tests()'s 0/1 result as the exit status.
    raise SystemExit(run_tests())