init commit
This commit is contained in:
3
app/__init__.py
Normal file
3
app/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""Finance Bot Application Package"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
1
app/api/__init__.py
Normal file
1
app/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""API routes"""
|
||||
279
app/api/auth.py
Normal file
279
app/api/auth.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""
|
||||
Authentication API Endpoints - Login, Token Management, Telegram Binding
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.database import get_db
|
||||
from app.services.auth_service import AuthService
|
||||
from app.security.jwt_manager import jwt_manager
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/v1/auth", tags=["authentication"])
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
user_id: int
|
||||
expires_in: int # seconds
|
||||
|
||||
|
||||
class TelegramBindingStartRequest(BaseModel):
|
||||
chat_id: int
|
||||
|
||||
|
||||
class TelegramBindingStartResponse(BaseModel):
|
||||
code: str
|
||||
expires_in: int # seconds
|
||||
|
||||
|
||||
class TelegramBindingConfirmRequest(BaseModel):
|
||||
code: str
|
||||
chat_id: int
|
||||
username: Optional[str] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
|
||||
|
||||
class TelegramBindingConfirmResponse(BaseModel):
|
||||
success: bool
|
||||
user_id: int
|
||||
jwt_token: str
|
||||
expires_at: str
|
||||
|
||||
|
||||
class TokenRefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class TokenRefreshResponse(BaseModel):
|
||||
access_token: str
|
||||
expires_in: int
|
||||
|
||||
|
||||
@router.post(
|
||||
"/login",
|
||||
response_model=LoginResponse,
|
||||
summary="User login with email & password",
|
||||
)
|
||||
async def login(
|
||||
request: LoginRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> LoginResponse:
|
||||
"""
|
||||
Authenticate user and create session.
|
||||
|
||||
**Returns:**
|
||||
- access_token: Short-lived JWT (15 min)
|
||||
- refresh_token: Long-lived refresh token (30 days)
|
||||
|
||||
**Usage:**
|
||||
```
|
||||
Authorization: Bearer <access_token>
|
||||
X-Device-Id: device_uuid # For tracking
|
||||
```
|
||||
"""
|
||||
|
||||
# TODO: Verify email + password
|
||||
# For MVP: Assume credentials are valid
|
||||
|
||||
from app.db.models import User
|
||||
|
||||
user = db.query(User).filter(User.email == request.email).first()
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
|
||||
service = AuthService(db)
|
||||
access_token, refresh_token = await service.create_session(
|
||||
user_id=user.id,
|
||||
device_id=request.__dict__.get("device_id"),
|
||||
)
|
||||
|
||||
return LoginResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
user_id=user.id,
|
||||
expires_in=15 * 60, # 15 minutes
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/refresh",
|
||||
response_model=TokenRefreshResponse,
|
||||
summary="Refresh access token",
|
||||
)
|
||||
async def refresh_token(
|
||||
request: TokenRefreshRequest,
|
||||
db: Session = Depends(get_db),
|
||||
) -> TokenRefreshResponse:
|
||||
"""
|
||||
Issue new access token using refresh token.
|
||||
|
||||
**Flow:**
|
||||
1. Access token expires
|
||||
2. Send refresh_token to this endpoint
|
||||
3. Receive new access_token (without creating new session)
|
||||
"""
|
||||
|
||||
try:
|
||||
token_payload = jwt_manager.verify_token(request.refresh_token)
|
||||
if token_payload.type != "refresh":
|
||||
raise ValueError("Not a refresh token")
|
||||
|
||||
service = AuthService(db)
|
||||
new_access_token = await service.refresh_access_token(
|
||||
refresh_token=request.refresh_token,
|
||||
user_id=token_payload.sub,
|
||||
)
|
||||
|
||||
return TokenRefreshResponse(
|
||||
access_token=new_access_token,
|
||||
expires_in=15 * 60,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=401, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/telegram/start",
|
||||
response_model=TelegramBindingStartResponse,
|
||||
summary="Start Telegram binding flow",
|
||||
)
|
||||
async def telegram_binding_start(
|
||||
request: TelegramBindingStartRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Generate binding code for Telegram user.
|
||||
|
||||
**Bot Flow:**
|
||||
1. User sends /start
|
||||
2. Bot calls this endpoint: POST /auth/telegram/start
|
||||
3. Bot receives code and generates link
|
||||
4. Bot sends message with link to user
|
||||
5. User clicks link (goes to confirm endpoint)
|
||||
"""
|
||||
|
||||
service = AuthService(db)
|
||||
code = await service.create_telegram_binding_code(chat_id=request.chat_id)
|
||||
|
||||
return TelegramBindingStartResponse(
|
||||
code=code,
|
||||
expires_in=600, # 10 minutes
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/telegram/confirm",
|
||||
response_model=TelegramBindingConfirmResponse,
|
||||
summary="Confirm Telegram binding",
|
||||
)
|
||||
async def telegram_binding_confirm(
|
||||
request: TelegramBindingConfirmRequest,
|
||||
current_request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Confirm Telegram binding and issue JWT.
|
||||
|
||||
**Flow:**
|
||||
1. User logs in or creates account
|
||||
2. User clicks binding link with code
|
||||
3. Frontend calls this endpoint with code + user context
|
||||
4. Backend creates TelegramIdentity record
|
||||
5. Backend returns JWT for bot to use
|
||||
|
||||
**Bot Usage:**
|
||||
```python
|
||||
# Bot stores JWT for user
|
||||
redis.setex(f"chat_id:{chat_id}:jwt", 86400*30, jwt_token)
|
||||
|
||||
# Bot makes API calls
|
||||
api_request.headers['Authorization'] = f'Bearer {jwt_token}'
|
||||
```
|
||||
"""
|
||||
|
||||
# Get authenticated user from JWT
|
||||
user_id = getattr(current_request.state, "user_id", None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User must be authenticated")
|
||||
|
||||
service = AuthService(db)
|
||||
result = await service.confirm_telegram_binding(
|
||||
user_id=user_id,
|
||||
chat_id=request.chat_id,
|
||||
code=request.code,
|
||||
username=request.username,
|
||||
first_name=request.first_name,
|
||||
last_name=request.last_name,
|
||||
)
|
||||
|
||||
if not result.get("success"):
|
||||
raise HTTPException(status_code=400, detail="Binding failed")
|
||||
|
||||
return TelegramBindingConfirmResponse(**result)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/telegram/authenticate",
|
||||
response_model=dict,
|
||||
summary="Authenticate by Telegram chat_id",
|
||||
)
|
||||
async def telegram_authenticate(
|
||||
chat_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get JWT token for Telegram user.
|
||||
|
||||
**Usage in Bot:**
|
||||
```python
|
||||
# After user binding is confirmed
|
||||
response = api.post("/auth/telegram/authenticate?chat_id=12345")
|
||||
jwt_token = response["jwt_token"]
|
||||
```
|
||||
"""
|
||||
|
||||
service = AuthService(db)
|
||||
result = await service.authenticate_telegram_user(chat_id=chat_id)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Telegram identity not found")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post(
|
||||
"/logout",
|
||||
summary="Logout user",
|
||||
)
|
||||
async def logout(
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Revoke session and blacklist tokens.
|
||||
|
||||
**TODO:** Implement token blacklisting in Redis
|
||||
"""
|
||||
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
# TODO: Add token to Redis blacklist
|
||||
# redis.setex(f"blacklist:{token}", token_expiry_time, "1")
|
||||
|
||||
return {"message": "Logged out successfully"}
|
||||
41
app/api/main.py
Normal file
41
app/api/main.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""FastAPI application"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.core.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
app = FastAPI(
|
||||
title="Finance Bot API",
|
||||
description="REST API for family finance management",
|
||||
version="0.1.0"
|
||||
)
|
||||
|
||||
# CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
return {
|
||||
"status": "ok",
|
||||
"environment": settings.app_env
|
||||
}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root():
|
||||
"""Root endpoint"""
|
||||
return {
|
||||
"message": "Finance Bot API",
|
||||
"docs": "/docs",
|
||||
"version": "0.1.0"
|
||||
}
|
||||
275
app/api/transactions.py
Normal file
275
app/api/transactions.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Transaction API Endpoints - CRUD + Approval Workflow
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
from decimal import Decimal
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.database import get_db
|
||||
from app.services.transaction_service import TransactionService
|
||||
from app.security.rbac import UserContext, RBACEngine, MemberRole, Permission
|
||||
from app.core.config import settings
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/api/v1/transactions", tags=["transactions"])
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class TransactionCreateRequest(BaseModel):
|
||||
family_id: int
|
||||
from_wallet_id: Optional[int] = None
|
||||
to_wallet_id: Optional[int] = None
|
||||
category_id: Optional[int] = None
|
||||
amount: Decimal
|
||||
description: str
|
||||
notes: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"family_id": 1,
|
||||
"from_wallet_id": 10,
|
||||
"to_wallet_id": 11,
|
||||
"category_id": 5,
|
||||
"amount": 50.00,
|
||||
"description": "Rent payment",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class TransactionResponse(BaseModel):
|
||||
id: int
|
||||
status: str # draft, pending_approval, executed, reversed
|
||||
amount: Decimal
|
||||
description: str
|
||||
confirmation_required: bool
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TransactionConfirmRequest(BaseModel):
|
||||
confirmation_token: Optional[str] = None
|
||||
|
||||
|
||||
class TransactionReverseRequest(BaseModel):
|
||||
reason: Optional[str] = None
|
||||
|
||||
|
||||
# Dependency to extract user context
|
||||
async def get_user_context(request: Request) -> UserContext:
|
||||
"""Extract user context from JWT"""
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
family_id = getattr(request.state, "family_id", None)
|
||||
|
||||
if not user_id or not family_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid authentication")
|
||||
|
||||
# Load user role from DB (simplified for MVP)
|
||||
# In production: Load from users->family_members join
|
||||
role = MemberRole.OWNER # TODO: Load from DB
|
||||
permissions = RBACEngine.get_permissions(role)
|
||||
|
||||
return UserContext(
|
||||
user_id=user_id,
|
||||
family_id=family_id,
|
||||
role=role,
|
||||
permissions=permissions,
|
||||
family_ids=[family_id],
|
||||
device_id=getattr(request.state, "device_id", None),
|
||||
client_id=getattr(request.state, "client_id", None),
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
response_model=TransactionResponse,
|
||||
status_code=201,
|
||||
summary="Create new transaction",
|
||||
)
|
||||
async def create_transaction(
|
||||
request: TransactionCreateRequest,
|
||||
user_context: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
) -> TransactionResponse:
|
||||
"""
|
||||
Create a new financial transaction.
|
||||
|
||||
**Request Headers Required:**
|
||||
- Authorization: Bearer <jwt_token>
|
||||
- X-Client-Id: telegram_bot | web_frontend | ios_app
|
||||
- X-Signature: HMAC_SHA256(...)
|
||||
- X-Timestamp: unix timestamp
|
||||
|
||||
**Response:**
|
||||
- If amount ≤ threshold: status="executed" immediately
|
||||
- If amount > threshold: status="pending_approval", requires confirmation
|
||||
|
||||
**Events Emitted:**
|
||||
- transaction.created
|
||||
"""
|
||||
|
||||
try:
|
||||
service = TransactionService(db)
|
||||
result = await service.create_transaction(
|
||||
user_context=user_context,
|
||||
family_id=request.family_id,
|
||||
from_wallet_id=request.from_wallet_id,
|
||||
to_wallet_id=request.to_wallet_id,
|
||||
amount=request.amount,
|
||||
category_id=request.category_id,
|
||||
description=request.description,
|
||||
)
|
||||
|
||||
return TransactionResponse(**result)
|
||||
|
||||
except PermissionError as e:
|
||||
logger.warning(f"Permission denied: {e} (user: {user_context.user_id})")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning(f"Validation error: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating transaction: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{transaction_id}/confirm",
|
||||
response_model=TransactionResponse,
|
||||
summary="Confirm pending transaction",
|
||||
)
|
||||
async def confirm_transaction(
|
||||
transaction_id: int,
|
||||
request: TransactionConfirmRequest,
|
||||
user_context: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Approve a pending transaction for execution.
|
||||
|
||||
Only owner or designated approver can confirm.
|
||||
|
||||
**Events Emitted:**
|
||||
- transaction.confirmed
|
||||
- transaction.executed
|
||||
"""
|
||||
|
||||
try:
|
||||
service = TransactionService(db)
|
||||
result = await service.confirm_transaction(
|
||||
user_context=user_context,
|
||||
transaction_id=transaction_id,
|
||||
confirmation_token=request.confirmation_token,
|
||||
)
|
||||
|
||||
return TransactionResponse(**result)
|
||||
|
||||
except (PermissionError, ValueError) as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{transaction_id}",
|
||||
response_model=dict,
|
||||
summary="Reverse (cancel) transaction",
|
||||
)
|
||||
async def reverse_transaction(
|
||||
transaction_id: int,
|
||||
request: TransactionReverseRequest,
|
||||
user_context: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Reverse (cancel) executed transaction.
|
||||
|
||||
Creates a compensation (reverse) transaction instead of deletion.
|
||||
Original transaction status changes to "reversed".
|
||||
|
||||
**Events Emitted:**
|
||||
- transaction.reversed
|
||||
- transaction.created (compensation)
|
||||
"""
|
||||
|
||||
try:
|
||||
service = TransactionService(db)
|
||||
result = await service.reverse_transaction(
|
||||
user_context=user_context,
|
||||
transaction_id=transaction_id,
|
||||
reason=request.reason,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except (PermissionError, ValueError) as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
response_model=List[TransactionResponse],
|
||||
summary="List transactions",
|
||||
)
|
||||
async def list_transactions(
|
||||
family_id: int,
|
||||
skip: int = 0,
|
||||
limit: int = 20,
|
||||
user_context: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
List all transactions for family.
|
||||
|
||||
**Filtering:**
|
||||
- ?family_id=1
|
||||
- ?wallet_id=10
|
||||
- ?category_id=5
|
||||
- ?status=executed
|
||||
- ?from_date=2023-12-01&to_date=2023-12-31
|
||||
|
||||
**Pagination:**
|
||||
- ?skip=0&limit=20
|
||||
"""
|
||||
|
||||
# Verify family access
|
||||
RBACEngine.check_family_access(user_context, family_id)
|
||||
|
||||
from app.db.models import Transaction
|
||||
|
||||
transactions = db.query(Transaction).filter(
|
||||
Transaction.family_id == family_id,
|
||||
).offset(skip).limit(limit).all()
|
||||
|
||||
return [TransactionResponse.from_orm(t) for t in transactions]
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{transaction_id}",
|
||||
response_model=TransactionResponse,
|
||||
summary="Get transaction details",
|
||||
)
|
||||
async def get_transaction(
|
||||
transaction_id: int,
|
||||
user_context: UserContext = Depends(get_user_context),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get detailed transaction information"""
|
||||
|
||||
from app.db.models import Transaction
|
||||
|
||||
transaction = db.query(Transaction).filter(
|
||||
Transaction.id == transaction_id,
|
||||
Transaction.family_id == user_context.family_id,
|
||||
).first()
|
||||
|
||||
if not transaction:
|
||||
raise HTTPException(status_code=404, detail="Transaction not found")
|
||||
|
||||
return TransactionResponse.from_orm(transaction)
|
||||
6
app/bot/__init__.py
Normal file
6
app/bot/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Bot module"""
|
||||
|
||||
from app.bot.handlers import register_handlers
|
||||
from app.bot.keyboards import *
|
||||
|
||||
__all__ = ["register_handlers"]
|
||||
328
app/bot/client.py
Normal file
328
app/bot/client.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""
|
||||
Telegram Bot - API-First Client
|
||||
All database operations go through API endpoints, not direct SQLAlchemy.
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from decimal import Decimal
|
||||
import aiohttp
|
||||
import time
|
||||
from aiogram import Bot, Dispatcher, types, F
|
||||
from aiogram.filters import Command
|
||||
from aiogram.types import Message
|
||||
import redis
|
||||
import json
|
||||
from app.security.hmac_manager import hmac_manager
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TelegramBotClient:
|
||||
"""
|
||||
Telegram Bot that communicates exclusively via API calls.
|
||||
|
||||
Features:
|
||||
- User authentication via JWT tokens stored in Redis
|
||||
- All operations through API (no direct DB access)
|
||||
- Async HTTP requests with aiohttp
|
||||
- Event listening via Redis Streams
|
||||
"""
|
||||
|
||||
def __init__(self, bot_token: str, api_base_url: str, redis_client: redis.Redis):
|
||||
self.bot = Bot(token=bot_token)
|
||||
self.dp = Dispatcher()
|
||||
self.api_base_url = api_base_url
|
||||
self.redis_client = redis_client
|
||||
self.session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
# Register handlers
|
||||
self._setup_handlers()
|
||||
|
||||
def _setup_handlers(self):
|
||||
"""Register message handlers"""
|
||||
self.dp.message.register(self.cmd_start, Command("start"))
|
||||
self.dp.message.register(self.cmd_help, Command("help"))
|
||||
self.dp.message.register(self.cmd_balance, Command("balance"))
|
||||
self.dp.message.register(self.cmd_add_transaction, Command("add"))
|
||||
|
||||
async def start(self):
|
||||
"""Start bot polling"""
|
||||
self.session = aiohttp.ClientSession()
|
||||
logger.info("Telegram bot started")
|
||||
|
||||
# Start polling
|
||||
try:
|
||||
await self.dp.start_polling(self.bot)
|
||||
finally:
|
||||
await self.session.close()
|
||||
|
||||
# ========== Handler: /start (Binding) ==========
|
||||
async def cmd_start(self, message: Message):
|
||||
"""
|
||||
/start - Begin Telegram binding process.
|
||||
|
||||
Flow:
|
||||
1. Check if user already bound
|
||||
2. If not: Generate binding code
|
||||
3. Send link to user
|
||||
"""
|
||||
chat_id = message.chat.id
|
||||
|
||||
# Check if already bound
|
||||
jwt_key = f"chat_id:{chat_id}:jwt"
|
||||
existing_token = self.redis_client.get(jwt_key)
|
||||
|
||||
if existing_token:
|
||||
await message.answer("✅ You're already connected!\n\nUse /help for commands.")
|
||||
return
|
||||
|
||||
# Generate binding code
|
||||
try:
|
||||
code = await self._api_call(
|
||||
method="POST",
|
||||
endpoint="/api/v1/auth/telegram/start",
|
||||
data={"chat_id": chat_id},
|
||||
use_jwt=False,
|
||||
)
|
||||
|
||||
binding_code = code.get("code")
|
||||
|
||||
# Send binding link to user
|
||||
binding_url = f"https://your-app.com/auth/telegram?code={binding_code}&chat_id={chat_id}"
|
||||
|
||||
await message.answer(
|
||||
f"🔗 Click to bind your account:\n\n"
|
||||
f"[Open Account Binding]({binding_url})\n\n"
|
||||
f"Code expires in 10 minutes.",
|
||||
parse_mode="Markdown"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Binding start error: {e}")
|
||||
await message.answer("❌ Binding failed. Try again later.")
|
||||
|
||||
# ========== Handler: /balance ==========
|
||||
async def cmd_balance(self, message: Message):
|
||||
"""
|
||||
/balance - Show wallet balances.
|
||||
|
||||
Requires:
|
||||
- User must be bound (JWT token in Redis)
|
||||
- API call with JWT auth
|
||||
"""
|
||||
chat_id = message.chat.id
|
||||
|
||||
# Get JWT token
|
||||
jwt_token = self._get_user_jwt(chat_id)
|
||||
if not jwt_token:
|
||||
await message.answer("❌ Not connected. Use /start to bind your account.")
|
||||
return
|
||||
|
||||
try:
|
||||
# Call API: GET /api/v1/wallets/summary?family_id=1
|
||||
wallets = await self._api_call(
|
||||
method="GET",
|
||||
endpoint="/api/v1/wallets/summary",
|
||||
jwt_token=jwt_token,
|
||||
params={"family_id": 1}, # TODO: Get from context
|
||||
)
|
||||
|
||||
# Format response
|
||||
response = "💰 **Your Wallets:**\n\n"
|
||||
for wallet in wallets:
|
||||
response += f"📊 {wallet['name']}: ${wallet['balance']}\n"
|
||||
|
||||
await message.answer(response, parse_mode="Markdown")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Balance fetch error: {e}")
|
||||
await message.answer("❌ Could not fetch balance. Try again later.")
|
||||
|
||||
# ========== Handler: /add (Create Transaction) ==========
|
||||
async def cmd_add_transaction(self, message: Message):
|
||||
"""
|
||||
/add - Create new transaction (interactive).
|
||||
|
||||
Flow:
|
||||
1. Ask for amount
|
||||
2. Ask for category
|
||||
3. Ask for wallet (from/to)
|
||||
4. Create transaction via API
|
||||
"""
|
||||
|
||||
chat_id = message.chat.id
|
||||
jwt_token = self._get_user_jwt(chat_id)
|
||||
|
||||
if not jwt_token:
|
||||
await message.answer("❌ Not connected. Use /start first.")
|
||||
return
|
||||
|
||||
# Store conversation state in Redis
|
||||
state_key = f"chat_id:{chat_id}:state"
|
||||
self.redis_client.setex(state_key, 300, json.dumps({
|
||||
"action": "add_transaction",
|
||||
"step": 1, # Waiting for amount
|
||||
}))
|
||||
|
||||
await message.answer("💵 How much?\n\nEnter amount (e.g., 50.00)")
|
||||
|
||||
async def handle_transaction_input(self, message: Message, state: Dict[str, Any]):
|
||||
"""Handle transaction creation in steps"""
|
||||
chat_id = message.chat.id
|
||||
jwt_token = self._get_user_jwt(chat_id)
|
||||
|
||||
step = state.get("step", 1)
|
||||
|
||||
if step == 1:
|
||||
# Amount entered
|
||||
try:
|
||||
amount = Decimal(message.text)
|
||||
except:
|
||||
await message.answer("❌ Invalid amount. Try again.")
|
||||
return
|
||||
|
||||
state["amount"] = float(amount)
|
||||
state["step"] = 2
|
||||
self.redis_client.setex(f"chat_id:{chat_id}:state", 300, json.dumps(state))
|
||||
|
||||
await message.answer("📂 Which category?\n\n/food /transport /other")
|
||||
|
||||
elif step == 2:
|
||||
# Category selected
|
||||
state["category"] = message.text
|
||||
state["step"] = 3
|
||||
self.redis_client.setex(f"chat_id:{chat_id}:state", 300, json.dumps(state))
|
||||
|
||||
await message.answer("💬 Any notes?\n\n(or /skip)")
|
||||
|
||||
elif step == 3:
|
||||
# Notes entered (or skipped)
|
||||
state["notes"] = message.text if message.text != "/skip" else ""
|
||||
|
||||
# Create transaction via API
|
||||
try:
|
||||
result = await self._api_call(
|
||||
method="POST",
|
||||
endpoint="/api/v1/transactions",
|
||||
jwt_token=jwt_token,
|
||||
data={
|
||||
"family_id": 1,
|
||||
"from_wallet_id": 10,
|
||||
"amount": state["amount"],
|
||||
"category_id": 5, # TODO: Map category
|
||||
"description": state["category"],
|
||||
"notes": state["notes"],
|
||||
}
|
||||
)
|
||||
|
||||
tx_id = result.get("id")
|
||||
await message.answer(f"✅ Transaction #{tx_id} created!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Transaction creation error: {e}")
|
||||
await message.answer("❌ Creation failed. Try again.")
|
||||
|
||||
finally:
|
||||
# Clean up state
|
||||
self.redis_client.delete(f"chat_id:{chat_id}:state")
|
||||
|
||||
# ========== Handler: /help ==========
|
||||
async def cmd_help(self, message: Message):
|
||||
"""Show available commands"""
|
||||
help_text = """
|
||||
🤖 **Finance Bot Commands:**
|
||||
|
||||
/start - Bind your Telegram account
|
||||
/balance - Show wallet balances
|
||||
/add - Add new transaction
|
||||
/reports - View reports (daily/weekly/monthly)
|
||||
/help - This message
|
||||
"""
|
||||
await message.answer(help_text, parse_mode="Markdown")
|
||||
|
||||
# ========== API Communication Methods ==========
|
||||
async def _api_call(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
data: Dict = None,
|
||||
params: Dict = None,
|
||||
jwt_token: Optional[str] = None,
|
||||
use_jwt: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Make HTTP request to API with proper auth headers.
|
||||
|
||||
Headers:
|
||||
- Authorization: Bearer <jwt_token>
|
||||
- X-Client-Id: telegram_bot
|
||||
- X-Signature: HMAC_SHA256(...)
|
||||
- X-Timestamp: unix timestamp
|
||||
"""
|
||||
|
||||
if not self.session:
|
||||
raise RuntimeError("Session not initialized")
|
||||
|
||||
# Build headers
|
||||
headers = {
|
||||
"X-Client-Id": "telegram_bot",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# Add JWT if provided
|
||||
if use_jwt and jwt_token:
|
||||
headers["Authorization"] = f"Bearer {jwt_token}"
|
||||
|
||||
# Add HMAC signature
|
||||
timestamp = int(time.time())
|
||||
headers["X-Timestamp"] = str(timestamp)
|
||||
|
||||
signature = hmac_manager.create_signature(
|
||||
method=method,
|
||||
endpoint=endpoint,
|
||||
timestamp=timestamp,
|
||||
body=data,
|
||||
)
|
||||
headers["X-Signature"] = signature
|
||||
|
||||
# Make request
|
||||
url = f"{self.api_base_url}{endpoint}"
|
||||
|
||||
async with self.session.request(
|
||||
method=method,
|
||||
url=url,
|
||||
json=data,
|
||||
params=params,
|
||||
headers=headers,
|
||||
) as response:
|
||||
if response.status >= 400:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"API error {response.status}: {error_text}")
|
||||
|
||||
return await response.json()
|
||||
|
||||
def _get_user_jwt(self, chat_id: int) -> Optional[str]:
|
||||
"""Get JWT token for chat_id from Redis"""
|
||||
jwt_key = f"chat_id:{chat_id}:jwt"
|
||||
token = self.redis_client.get(jwt_key)
|
||||
return token.decode() if token else None
|
||||
|
||||
async def send_notification(self, chat_id: int, message: str):
|
||||
"""Send notification to user"""
|
||||
try:
|
||||
await self.bot.send_message(chat_id=chat_id, text=message)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send notification to {chat_id}: {e}")
|
||||
|
||||
|
||||
# Bot factory
|
||||
async def create_telegram_bot(
|
||||
bot_token: str,
|
||||
api_base_url: str,
|
||||
redis_client: redis.Redis,
|
||||
) -> TelegramBotClient:
|
||||
"""Create and start Telegram bot"""
|
||||
bot = TelegramBotClient(bot_token, api_base_url, redis_client)
|
||||
return bot
|
||||
14
app/bot/handlers/__init__.py
Normal file
14
app/bot/handlers/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Bot handlers"""
|
||||
|
||||
from app.bot.handlers.start import register_start_handlers
|
||||
from app.bot.handlers.user import register_user_handlers
|
||||
from app.bot.handlers.family import register_family_handlers
|
||||
from app.bot.handlers.transaction import register_transaction_handlers
|
||||
|
||||
|
||||
def register_handlers(dp):
|
||||
"""Register all bot handlers"""
|
||||
register_start_handlers(dp)
|
||||
register_user_handlers(dp)
|
||||
register_family_handlers(dp)
|
||||
register_transaction_handlers(dp)
|
||||
18
app/bot/handlers/family.py
Normal file
18
app/bot/handlers/family.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Family-related handlers"""
|
||||
|
||||
from aiogram import Router
|
||||
from aiogram.types import Message
|
||||
|
||||
|
||||
router = Router()
|
||||
|
||||
|
||||
@router.message()
|
||||
async def family_menu(message: Message):
|
||||
"""Handle family menu interactions"""
|
||||
pass
|
||||
|
||||
|
||||
def register_family_handlers(dp):
|
||||
"""Register family handlers"""
|
||||
dp.include_router(router)
|
||||
60
app/bot/handlers/start.py
Normal file
60
app/bot/handlers/start.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Start and help handlers"""
|
||||
|
||||
from aiogram import Router, F
|
||||
from aiogram.filters import CommandStart
|
||||
from aiogram.types import Message
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.database import SessionLocal
|
||||
from app.db.repositories import UserRepository, FamilyRepository
|
||||
from app.bot.keyboards import main_menu_keyboard
|
||||
|
||||
|
||||
router = Router()
|
||||
|
||||
|
||||
@router.message(CommandStart())
|
||||
async def cmd_start(message: Message):
|
||||
"""Handle /start command"""
|
||||
user_repo = UserRepository(SessionLocal())
|
||||
|
||||
# Create or update user
|
||||
user = user_repo.get_or_create(
|
||||
telegram_id=message.from_user.id,
|
||||
username=message.from_user.username,
|
||||
first_name=message.from_user.first_name,
|
||||
last_name=message.from_user.last_name,
|
||||
)
|
||||
|
||||
welcome_text = (
|
||||
"👋 Добро пожаловать в Finance Bot!\n\n"
|
||||
"Я помогу вам управлять семейными финансами:\n"
|
||||
"💰 Отслеживать доходы и расходы\n"
|
||||
"👨👩👧👦 Управлять семейной группой\n"
|
||||
"📊 Видеть аналитику\n"
|
||||
"🎯 Ставить финансовые цели\n\n"
|
||||
"Выберите действие:"
|
||||
)
|
||||
|
||||
await message.answer(welcome_text, reply_markup=main_menu_keyboard())
|
||||
|
||||
|
||||
@router.message(CommandStart())
|
||||
async def cmd_help(message: Message):
|
||||
"""Handle /help command"""
|
||||
help_text = (
|
||||
"📚 **Справка по командам:**\n\n"
|
||||
"/start - Главное меню\n"
|
||||
"/help - Эта справка\n"
|
||||
"/account - Мои счета\n"
|
||||
"/transaction - Новая операция\n"
|
||||
"/budget - Управление бюджетом\n"
|
||||
"/analytics - Аналитика\n"
|
||||
"/family - Управление семьей\n"
|
||||
"/settings - Параметры\n"
|
||||
)
|
||||
await message.answer(help_text)
|
||||
|
||||
|
||||
def register_start_handlers(dp):
|
||||
"""Register start handlers"""
|
||||
dp.include_router(router)
|
||||
18
app/bot/handlers/transaction.py
Normal file
18
app/bot/handlers/transaction.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Transaction-related handlers"""
|
||||
|
||||
from aiogram import Router
|
||||
from aiogram.types import Message
|
||||
|
||||
|
||||
router = Router()
|
||||
|
||||
|
||||
@router.message()
|
||||
async def transaction_menu(message: Message):
|
||||
"""Handle transaction operations"""
|
||||
pass
|
||||
|
||||
|
||||
def register_transaction_handlers(dp):
|
||||
"""Register transaction handlers"""
|
||||
dp.include_router(router)
|
||||
18
app/bot/handlers/user.py
Normal file
18
app/bot/handlers/user.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""User-related handlers"""
|
||||
|
||||
from aiogram import Router
|
||||
from aiogram.types import Message
|
||||
|
||||
|
||||
router = Router()
|
||||
|
||||
|
||||
@router.message()
|
||||
async def user_menu(message: Message):
|
||||
"""Handle user menu interactions"""
|
||||
pass
|
||||
|
||||
|
||||
def register_user_handlers(dp):
|
||||
"""Register user handlers"""
|
||||
dp.include_router(router)
|
||||
56
app/bot/keyboards/__init__.py
Normal file
56
app/bot/keyboards/__init__.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Bot keyboards"""
|
||||
|
||||
from aiogram.types import ReplyKeyboardMarkup, KeyboardButton
|
||||
from aiogram.types import InlineKeyboardMarkup, InlineKeyboardButton
|
||||
|
||||
|
||||
def main_menu_keyboard() -> ReplyKeyboardMarkup:
|
||||
"""Main menu keyboard"""
|
||||
return ReplyKeyboardMarkup(
|
||||
keyboard=[
|
||||
[
|
||||
KeyboardButton(text="💰 Новая операция"),
|
||||
KeyboardButton(text="📊 Аналитика"),
|
||||
],
|
||||
[
|
||||
KeyboardButton(text="👨👩👧👦 Семья"),
|
||||
KeyboardButton(text="🎯 Цели"),
|
||||
],
|
||||
[
|
||||
KeyboardButton(text="💳 Счета"),
|
||||
KeyboardButton(text="⚙️ Параметры"),
|
||||
],
|
||||
[
|
||||
KeyboardButton(text="📞 Помощь"),
|
||||
],
|
||||
],
|
||||
resize_keyboard=True,
|
||||
input_field_placeholder="Выберите действие...",
|
||||
)
|
||||
|
||||
|
||||
def transaction_type_keyboard() -> InlineKeyboardMarkup:
|
||||
"""Transaction type selection"""
|
||||
return InlineKeyboardMarkup(
|
||||
inline_keyboard=[
|
||||
[InlineKeyboardButton(text="💸 Расход", callback_data="tx_expense")],
|
||||
[InlineKeyboardButton(text="💵 Доход", callback_data="tx_income")],
|
||||
[InlineKeyboardButton(text="🔄 Перевод", callback_data="tx_transfer")],
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def cancel_keyboard() -> InlineKeyboardMarkup:
|
||||
"""Cancel button"""
|
||||
return InlineKeyboardMarkup(
|
||||
inline_keyboard=[
|
||||
[InlineKeyboardButton(text="❌ Отменить", callback_data="cancel")],
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"main_menu_keyboard",
|
||||
"transaction_type_keyboard",
|
||||
"cancel_keyboard",
|
||||
]
|
||||
36
app/bot_main.py
Normal file
36
app/bot_main.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
Telegram Bot Entry Point
|
||||
Runs the bot polling service
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from app.bot.client import TelegramBotClient
|
||||
from app.core.config import settings
|
||||
import redis
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Start Telegram bot"""
|
||||
try:
|
||||
redis_client = redis.from_url(settings.redis_url)
|
||||
|
||||
bot = TelegramBotClient(
|
||||
bot_token=settings.bot_token,
|
||||
api_base_url="http://web:8000",
|
||||
redis_client=redis_client
|
||||
)
|
||||
|
||||
logger.info("Starting Telegram bot...")
|
||||
await bot.start()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Bot error: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
5
app/core/__init__.py
Normal file
5
app/core/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Core module - configuration and utilities"""
|
||||
|
||||
from app.core.config import Settings
|
||||
|
||||
__all__ = ["Settings"]
|
||||
70
app/core/config.py
Normal file
70
app/core/config.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""Application configuration using pydantic-settings"""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic_settings import BaseSettings
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Main application settings"""
|
||||
|
||||
# Bot Configuration
|
||||
bot_token: str
|
||||
bot_admin_id: int
|
||||
|
||||
# Database Configuration
|
||||
database_url: str
|
||||
database_echo: bool = False
|
||||
|
||||
# Database Credentials (for Docker)
|
||||
db_password: Optional[str] = None
|
||||
db_user: Optional[str] = None
|
||||
db_name: Optional[str] = None
|
||||
|
||||
# Redis Configuration
|
||||
redis_url: str = "redis://localhost:6379/0"
|
||||
|
||||
# Application Configuration
|
||||
app_debug: bool = False
|
||||
app_env: str = "development"
|
||||
log_level: str = "INFO"
|
||||
|
||||
# API Configuration
|
||||
api_host: str = "0.0.0.0"
|
||||
api_port: int = 8000
|
||||
|
||||
# Timezone
|
||||
tz: str = "Europe/Moscow"
|
||||
|
||||
# Security Configuration
|
||||
jwt_secret_key: str = "your-secret-key-change-in-production"
|
||||
hmac_secret_key: str = "your-hmac-secret-change-in-production"
|
||||
require_hmac_verification: bool = False # Disabled by default in MVP
|
||||
access_token_expire_minutes: int = 15
|
||||
refresh_token_expire_days: int = 30
|
||||
|
||||
# CORS Configuration
|
||||
cors_allowed_origins: list[str] = ["http://localhost:3000", "http://localhost:8081"]
|
||||
cors_allow_credentials: bool = True
|
||||
cors_allow_methods: list[str] = ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
|
||||
cors_allow_headers: list[str] = ["*"]
|
||||
|
||||
# Feature Flags
|
||||
feature_telegram_bot_enabled: bool = True
|
||||
feature_transaction_approval: bool = True
|
||||
feature_event_logging: bool = True
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
case_sensitive = False
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
"""Get cached settings instance"""
|
||||
return Settings()
|
||||
|
||||
|
||||
# Global settings instance for direct imports
|
||||
settings = get_settings()
|
||||
5
app/db/__init__.py
Normal file
5
app/db/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Database module - models, repositories, and session management"""
|
||||
|
||||
from app.db.database import SessionLocal, engine, Base
|
||||
|
||||
__all__ = ["SessionLocal", "engine", "Base"]
|
||||
36
app/db/database.py
Normal file
36
app/db/database.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Database connection and session management"""
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.core.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
# Create database engine
|
||||
engine = create_engine(
|
||||
settings.database_url,
|
||||
echo=settings.database_echo,
|
||||
pool_pre_ping=True, # Verify connections before using them
|
||||
pool_recycle=3600, # Recycle connections every hour
|
||||
)
|
||||
|
||||
# Create session factory
|
||||
SessionLocal = sessionmaker(
|
||||
bind=engine,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
# Create declarative base for models
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def get_db():
|
||||
"""Dependency for FastAPI to get database session"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
28
app/db/models/__init__.py
Normal file
28
app/db/models/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Database models"""
|
||||
|
||||
from app.db.models.user import User
|
||||
from app.db.models.family import Family, FamilyMember, FamilyInvite, FamilyRole
|
||||
from app.db.models.account import Account, AccountType
|
||||
from app.db.models.category import Category, CategoryType
|
||||
from app.db.models.transaction import Transaction, TransactionType
|
||||
from app.db.models.budget import Budget, BudgetPeriod
|
||||
from app.db.models.goal import Goal
|
||||
|
||||
__all__ = [
|
||||
# Models
|
||||
"User",
|
||||
"Family",
|
||||
"FamilyMember",
|
||||
"FamilyInvite",
|
||||
"Account",
|
||||
"Category",
|
||||
"Transaction",
|
||||
"Budget",
|
||||
"Goal",
|
||||
# Enums
|
||||
"FamilyRole",
|
||||
"AccountType",
|
||||
"CategoryType",
|
||||
"TransactionType",
|
||||
"BudgetPeriod",
|
||||
]
|
||||
50
app/db/models/account.py
Normal file
50
app/db/models/account.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Account (wallet) model"""
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Float, DateTime, Boolean, ForeignKey, Enum
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from enum import Enum as PyEnum
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class AccountType(str, PyEnum):
|
||||
"""Types of accounts"""
|
||||
CARD = "card"
|
||||
CASH = "cash"
|
||||
DEPOSIT = "deposit"
|
||||
GOAL = "goal"
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class Account(Base):
|
||||
"""Account model - represents a user's wallet or account"""
|
||||
|
||||
__tablename__ = "accounts"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
family_id = Column(Integer, ForeignKey("families.id"), nullable=False, index=True)
|
||||
owner_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
||||
|
||||
name = Column(String(255), nullable=False)
|
||||
account_type = Column(Enum(AccountType), default=AccountType.CARD)
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
# Balance
|
||||
balance = Column(Float, default=0.0)
|
||||
initial_balance = Column(Float, default=0.0)
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_archived = Column(Boolean, default=False)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
family = relationship("Family", back_populates="accounts")
|
||||
owner = relationship("User", back_populates="accounts")
|
||||
transactions = relationship("Transaction", back_populates="account")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Account(id={self.id}, name={self.name}, balance={self.balance})>"
|
||||
50
app/db/models/budget.py
Normal file
50
app/db/models/budget.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Budget model for budget tracking"""
|
||||
|
||||
from sqlalchemy import Column, Integer, Float, String, DateTime, Boolean, ForeignKey, Enum
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from enum import Enum as PyEnum
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class BudgetPeriod(str, PyEnum):
|
||||
"""Budget periods"""
|
||||
DAILY = "daily"
|
||||
WEEKLY = "weekly"
|
||||
MONTHLY = "monthly"
|
||||
YEARLY = "yearly"
|
||||
|
||||
|
||||
class Budget(Base):
|
||||
"""Budget model - spending limits"""
|
||||
|
||||
__tablename__ = "budgets"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
family_id = Column(Integer, ForeignKey("families.id"), nullable=False, index=True)
|
||||
category_id = Column(Integer, ForeignKey("categories.id"), nullable=True)
|
||||
|
||||
# Budget details
|
||||
name = Column(String(255), nullable=False)
|
||||
limit_amount = Column(Float, nullable=False)
|
||||
spent_amount = Column(Float, default=0.0)
|
||||
period = Column(Enum(BudgetPeriod), default=BudgetPeriod.MONTHLY)
|
||||
|
||||
# Alert threshold (percentage)
|
||||
alert_threshold = Column(Float, default=80.0)
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
# Timestamps
|
||||
start_date = Column(DateTime, nullable=False)
|
||||
end_date = Column(DateTime, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
family = relationship("Family", back_populates="budgets")
|
||||
category = relationship("Category", back_populates="budgets")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Budget(id={self.id}, name={self.name}, limit={self.limit_amount})>"
|
||||
47
app/db/models/category.py
Normal file
47
app/db/models/category.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Category model for income/expense categories"""
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, Enum
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from enum import Enum as PyEnum
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class CategoryType(str, PyEnum):
|
||||
"""Types of categories"""
|
||||
EXPENSE = "expense"
|
||||
INCOME = "income"
|
||||
|
||||
|
||||
class Category(Base):
|
||||
"""Category model - income/expense categories"""
|
||||
|
||||
__tablename__ = "categories"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
family_id = Column(Integer, ForeignKey("families.id"), nullable=False, index=True)
|
||||
|
||||
name = Column(String(255), nullable=False)
|
||||
category_type = Column(Enum(CategoryType), nullable=False)
|
||||
emoji = Column(String(10), nullable=True)
|
||||
color = Column(String(7), nullable=True) # Hex color
|
||||
description = Column(String(500), nullable=True)
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_default = Column(Boolean, default=False)
|
||||
|
||||
# Order for UI
|
||||
order = Column(Integer, default=0)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
family = relationship("Family", back_populates="categories")
|
||||
transactions = relationship("Transaction", back_populates="category")
|
||||
budgets = relationship("Budget", back_populates="category")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Category(id={self.id}, name={self.name}, type={self.category_type})>"
|
||||
98
app/db/models/family.py
Normal file
98
app/db/models/family.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Family and Family-related models"""
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, ForeignKey, Enum
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from enum import Enum as PyEnum
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class FamilyRole(str, PyEnum):
|
||||
"""Roles in family"""
|
||||
OWNER = "owner"
|
||||
MEMBER = "member"
|
||||
RESTRICTED = "restricted"
|
||||
|
||||
|
||||
class Family(Base):
|
||||
"""Family model - represents a family group"""
|
||||
|
||||
__tablename__ = "families"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
owner_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
name = Column(String(255), nullable=False)
|
||||
description = Column(String(500), nullable=True)
|
||||
currency = Column(String(3), default="RUB") # ISO 4217 code
|
||||
invite_code = Column(String(20), unique=True, nullable=False, index=True)
|
||||
|
||||
# Settings
|
||||
notification_level = Column(String(50), default="all") # all, important, none
|
||||
accounting_period = Column(String(20), default="month") # week, month, year
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
members = relationship("FamilyMember", back_populates="family", cascade="all, delete-orphan")
|
||||
invites = relationship("FamilyInvite", back_populates="family", cascade="all, delete-orphan")
|
||||
accounts = relationship("Account", back_populates="family", cascade="all, delete-orphan")
|
||||
categories = relationship("Category", back_populates="family", cascade="all, delete-orphan")
|
||||
budgets = relationship("Budget", back_populates="family", cascade="all, delete-orphan")
|
||||
goals = relationship("Goal", back_populates="family", cascade="all, delete-orphan")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Family(id={self.id}, name={self.name}, currency={self.currency})>"
|
||||
|
||||
|
||||
class FamilyMember(Base):
|
||||
"""Family member model - user membership in family"""
|
||||
|
||||
__tablename__ = "family_members"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
family_id = Column(Integer, ForeignKey("families.id"), nullable=False, index=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
||||
role = Column(Enum(FamilyRole), default=FamilyRole.MEMBER)
|
||||
|
||||
# Permissions
|
||||
can_edit_budget = Column(Boolean, default=True)
|
||||
can_manage_members = Column(Boolean, default=False)
|
||||
|
||||
# Timestamps
|
||||
joined_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
|
||||
# Relationships
|
||||
family = relationship("Family", back_populates="members")
|
||||
user = relationship("User", back_populates="family_members")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<FamilyMember(family_id={self.family_id}, user_id={self.user_id}, role={self.role})>"
|
||||
|
||||
|
||||
class FamilyInvite(Base):
|
||||
"""Family invite model - pending invitations"""
|
||||
|
||||
__tablename__ = "family_invites"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
family_id = Column(Integer, ForeignKey("families.id"), nullable=False, index=True)
|
||||
invite_code = Column(String(20), unique=True, nullable=False, index=True)
|
||||
created_by_id = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
|
||||
# Invite validity
|
||||
is_active = Column(Boolean, default=True)
|
||||
expires_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
|
||||
# Relationships
|
||||
family = relationship("Family", back_populates="invites")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<FamilyInvite(id={self.id}, family_id={self.family_id})>"
|
||||
44
app/db/models/goal.py
Normal file
44
app/db/models/goal.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Savings goal model"""
|
||||
|
||||
from sqlalchemy import Column, Integer, Float, String, DateTime, Boolean, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class Goal(Base):
|
||||
"""Goal model - savings goals with progress tracking"""
|
||||
|
||||
__tablename__ = "goals"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
family_id = Column(Integer, ForeignKey("families.id"), nullable=False, index=True)
|
||||
account_id = Column(Integer, ForeignKey("accounts.id"), nullable=True)
|
||||
|
||||
# Goal details
|
||||
name = Column(String(255), nullable=False)
|
||||
description = Column(String(500), nullable=True)
|
||||
target_amount = Column(Float, nullable=False)
|
||||
current_amount = Column(Float, default=0.0)
|
||||
|
||||
# Priority
|
||||
priority = Column(Integer, default=0)
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_completed = Column(Boolean, default=False)
|
||||
|
||||
# Deadlines
|
||||
target_date = Column(DateTime, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
completed_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
family = relationship("Family", back_populates="goals")
|
||||
account = relationship("Account")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Goal(id={self.id}, name={self.name}, target={self.target_amount})>"
|
||||
57
app/db/models/transaction.py
Normal file
57
app/db/models/transaction.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Transaction model for income/expense records"""
|
||||
|
||||
from sqlalchemy import Column, Integer, Float, String, DateTime, Boolean, ForeignKey, Text, Enum
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from enum import Enum as PyEnum
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class TransactionType(str, PyEnum):
|
||||
"""Types of transactions"""
|
||||
EXPENSE = "expense"
|
||||
INCOME = "income"
|
||||
TRANSFER = "transfer"
|
||||
|
||||
|
||||
class Transaction(Base):
|
||||
"""Transaction model - represents income/expense transaction"""
|
||||
|
||||
__tablename__ = "transactions"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
family_id = Column(Integer, ForeignKey("families.id"), nullable=False, index=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
||||
account_id = Column(Integer, ForeignKey("accounts.id"), nullable=False, index=True)
|
||||
category_id = Column(Integer, ForeignKey("categories.id"), nullable=True)
|
||||
|
||||
# Transaction details
|
||||
amount = Column(Float, nullable=False)
|
||||
transaction_type = Column(Enum(TransactionType), nullable=False)
|
||||
description = Column(String(500), nullable=True)
|
||||
notes = Column(Text, nullable=True)
|
||||
tags = Column(String(500), nullable=True) # Comma-separated tags
|
||||
|
||||
# Receipt
|
||||
receipt_photo_url = Column(String(500), nullable=True)
|
||||
|
||||
# Recurring transaction
|
||||
is_recurring = Column(Boolean, default=False)
|
||||
recurrence_pattern = Column(String(50), nullable=True) # daily, weekly, monthly, etc.
|
||||
|
||||
# Status
|
||||
is_confirmed = Column(Boolean, default=True)
|
||||
|
||||
# Timestamps
|
||||
transaction_date = Column(DateTime, nullable=False, index=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
# Relationships
|
||||
family = relationship("Family")
|
||||
user = relationship("User", back_populates="transactions")
|
||||
account = relationship("Account", back_populates="transactions")
|
||||
category = relationship("Category", back_populates="transactions")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Transaction(id={self.id}, amount={self.amount}, type={self.transaction_type})>"
|
||||
35
app/db/models/user.py
Normal file
35
app/db/models/user.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""User model"""
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, Text
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime
|
||||
from app.db.database import Base
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""User model - represents a Telegram user"""
|
||||
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
telegram_id = Column(Integer, unique=True, nullable=False, index=True)
|
||||
username = Column(String(255), nullable=True)
|
||||
first_name = Column(String(255), nullable=True)
|
||||
last_name = Column(String(255), nullable=True)
|
||||
phone = Column(String(20), nullable=True)
|
||||
|
||||
# Account status
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
last_activity = Column(DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
family_members = relationship("FamilyMember", back_populates="user")
|
||||
accounts = relationship("Account", back_populates="owner")
|
||||
transactions = relationship("Transaction", back_populates="user")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<User(id={self.id}, telegram_id={self.telegram_id}, username={self.username})>"
|
||||
21
app/db/repositories/__init__.py
Normal file
21
app/db/repositories/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Repository layer for database access"""
|
||||
|
||||
from app.db.repositories.base import BaseRepository
|
||||
from app.db.repositories.user import UserRepository
|
||||
from app.db.repositories.family import FamilyRepository
|
||||
from app.db.repositories.account import AccountRepository
|
||||
from app.db.repositories.category import CategoryRepository
|
||||
from app.db.repositories.transaction import TransactionRepository
|
||||
from app.db.repositories.budget import BudgetRepository
|
||||
from app.db.repositories.goal import GoalRepository
|
||||
|
||||
__all__ = [
|
||||
"BaseRepository",
|
||||
"UserRepository",
|
||||
"FamilyRepository",
|
||||
"AccountRepository",
|
||||
"CategoryRepository",
|
||||
"TransactionRepository",
|
||||
"BudgetRepository",
|
||||
"GoalRepository",
|
||||
]
|
||||
54
app/db/repositories/account.py
Normal file
54
app/db/repositories/account.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Account repository"""
|
||||
|
||||
from typing import Optional, List
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.models import Account
|
||||
from app.db.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class AccountRepository(BaseRepository[Account]):
|
||||
"""Account data access operations"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, Account)
|
||||
|
||||
def get_family_accounts(self, family_id: int) -> List[Account]:
|
||||
"""Get all accounts for a family"""
|
||||
return (
|
||||
self.session.query(Account)
|
||||
.filter(Account.family_id == family_id, Account.is_active == True)
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_user_accounts(self, user_id: int) -> List[Account]:
|
||||
"""Get all accounts owned by user"""
|
||||
return (
|
||||
self.session.query(Account)
|
||||
.filter(Account.owner_id == user_id, Account.is_active == True)
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_account_if_accessible(self, account_id: int, family_id: int) -> Optional[Account]:
|
||||
"""Get account only if it belongs to family"""
|
||||
return (
|
||||
self.session.query(Account)
|
||||
.filter(
|
||||
Account.id == account_id,
|
||||
Account.family_id == family_id,
|
||||
Account.is_active == True
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def update_balance(self, account_id: int, amount: float) -> Optional[Account]:
|
||||
"""Update account balance by delta"""
|
||||
account = self.get_by_id(account_id)
|
||||
if account:
|
||||
account.balance += amount
|
||||
self.session.commit()
|
||||
self.session.refresh(account)
|
||||
return account
|
||||
|
||||
def archive_account(self, account_id: int) -> Optional[Account]:
|
||||
"""Archive account"""
|
||||
return self.update(account_id, is_archived=True)
|
||||
64
app/db/repositories/base.py
Normal file
64
app/db/repositories/base.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Base repository with generic CRUD operations"""
|
||||
|
||||
from typing import TypeVar, Generic, Type, List, Optional, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
from app.db.database import Base as SQLAlchemyBase
|
||||
|
||||
T = TypeVar("T", bound=SQLAlchemyBase)
|
||||
|
||||
|
||||
class BaseRepository(Generic[T]):
|
||||
"""Generic repository for CRUD operations"""
|
||||
|
||||
def __init__(self, session: Session, model: Type[T]):
|
||||
self.session = session
|
||||
self.model = model
|
||||
|
||||
def create(self, **kwargs) -> T:
|
||||
"""Create and return new instance"""
|
||||
instance = self.model(**kwargs)
|
||||
self.session.add(instance)
|
||||
self.session.commit()
|
||||
self.session.refresh(instance)
|
||||
return instance
|
||||
|
||||
def get_by_id(self, id: Any) -> Optional[T]:
|
||||
"""Get instance by primary key"""
|
||||
return self.session.query(self.model).filter(self.model.id == id).first()
|
||||
|
||||
def get_all(self, skip: int = 0, limit: int = 100) -> List[T]:
|
||||
"""Get all instances with pagination"""
|
||||
return (
|
||||
self.session.query(self.model)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
def update(self, id: Any, **kwargs) -> Optional[T]:
|
||||
"""Update instance by id"""
|
||||
instance = self.get_by_id(id)
|
||||
if instance:
|
||||
for key, value in kwargs.items():
|
||||
setattr(instance, key, value)
|
||||
self.session.commit()
|
||||
self.session.refresh(instance)
|
||||
return instance
|
||||
|
||||
def delete(self, id: Any) -> bool:
|
||||
"""Delete instance by id"""
|
||||
instance = self.get_by_id(id)
|
||||
if instance:
|
||||
self.session.delete(instance)
|
||||
self.session.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
def exists(self, **kwargs) -> bool:
|
||||
"""Check if instance exists with given filters"""
|
||||
return self.session.query(self.model).filter_by(**kwargs).first() is not None
|
||||
|
||||
def count(self, **kwargs) -> int:
|
||||
"""Count instances with given filters"""
|
||||
return self.session.query(self.model).filter_by(**kwargs).count()
|
||||
54
app/db/repositories/budget.py
Normal file
54
app/db/repositories/budget.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Budget repository"""
|
||||
|
||||
from typing import List, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.models import Budget
|
||||
from app.db.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class BudgetRepository(BaseRepository[Budget]):
|
||||
"""Budget data access operations"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, Budget)
|
||||
|
||||
def get_family_budgets(self, family_id: int) -> List[Budget]:
|
||||
"""Get all active budgets for family"""
|
||||
return (
|
||||
self.session.query(Budget)
|
||||
.filter(Budget.family_id == family_id, Budget.is_active == True)
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_category_budget(self, family_id: int, category_id: int) -> Optional[Budget]:
|
||||
"""Get budget for specific category"""
|
||||
return (
|
||||
self.session.query(Budget)
|
||||
.filter(
|
||||
Budget.family_id == family_id,
|
||||
Budget.category_id == category_id,
|
||||
Budget.is_active == True
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_general_budget(self, family_id: int) -> Optional[Budget]:
|
||||
"""Get general budget (no category)"""
|
||||
return (
|
||||
self.session.query(Budget)
|
||||
.filter(
|
||||
Budget.family_id == family_id,
|
||||
Budget.category_id == None,
|
||||
Budget.is_active == True
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def update_spent_amount(self, budget_id: int, amount: float) -> Optional[Budget]:
|
||||
"""Update spent amount for budget"""
|
||||
budget = self.get_by_id(budget_id)
|
||||
if budget:
|
||||
budget.spent_amount += amount
|
||||
self.session.commit()
|
||||
self.session.refresh(budget)
|
||||
return budget
|
||||
50
app/db/repositories/category.py
Normal file
50
app/db/repositories/category.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Category repository"""
|
||||
|
||||
from typing import List, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.models import Category, CategoryType
|
||||
from app.db.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class CategoryRepository(BaseRepository[Category]):
|
||||
"""Category data access operations"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, Category)
|
||||
|
||||
def get_family_categories(
|
||||
self, family_id: int, category_type: Optional[CategoryType] = None
|
||||
) -> List[Category]:
|
||||
"""Get categories for family, optionally filtered by type"""
|
||||
query = self.session.query(Category).filter(
|
||||
Category.family_id == family_id,
|
||||
Category.is_active == True
|
||||
)
|
||||
if category_type:
|
||||
query = query.filter(Category.category_type == category_type)
|
||||
return query.order_by(Category.order).all()
|
||||
|
||||
def get_by_name(self, family_id: int, name: str) -> Optional[Category]:
|
||||
"""Get category by name"""
|
||||
return (
|
||||
self.session.query(Category)
|
||||
.filter(
|
||||
Category.family_id == family_id,
|
||||
Category.name == name,
|
||||
Category.is_active == True
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def get_default_categories(self, family_id: int, category_type: CategoryType) -> List[Category]:
|
||||
"""Get default categories of type"""
|
||||
return (
|
||||
self.session.query(Category)
|
||||
.filter(
|
||||
Category.family_id == family_id,
|
||||
Category.category_type == category_type,
|
||||
Category.is_default == True,
|
||||
Category.is_active == True
|
||||
)
|
||||
.all()
|
||||
)
|
||||
69
app/db/repositories/family.py
Normal file
69
app/db/repositories/family.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Family repository"""
|
||||
|
||||
from typing import Optional, List
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.models import Family, FamilyMember, FamilyInvite
|
||||
from app.db.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class FamilyRepository(BaseRepository[Family]):
|
||||
"""Family data access operations"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, Family)
|
||||
|
||||
def get_by_invite_code(self, invite_code: str) -> Optional[Family]:
|
||||
"""Get family by invite code"""
|
||||
return self.session.query(Family).filter(Family.invite_code == invite_code).first()
|
||||
|
||||
def get_user_families(self, user_id: int) -> List[Family]:
|
||||
"""Get all families for a user"""
|
||||
return (
|
||||
self.session.query(Family)
|
||||
.join(FamilyMember)
|
||||
.filter(FamilyMember.user_id == user_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
def is_member(self, family_id: int, user_id: int) -> bool:
|
||||
"""Check if user is member of family"""
|
||||
return (
|
||||
self.session.query(FamilyMember)
|
||||
.filter(
|
||||
FamilyMember.family_id == family_id,
|
||||
FamilyMember.user_id == user_id
|
||||
)
|
||||
.first() is not None
|
||||
)
|
||||
|
||||
def add_member(self, family_id: int, user_id: int, role: str = "member") -> FamilyMember:
|
||||
"""Add user to family"""
|
||||
member = FamilyMember(family_id=family_id, user_id=user_id, role=role)
|
||||
self.session.add(member)
|
||||
self.session.commit()
|
||||
self.session.refresh(member)
|
||||
return member
|
||||
|
||||
def remove_member(self, family_id: int, user_id: int) -> bool:
|
||||
"""Remove user from family"""
|
||||
member = (
|
||||
self.session.query(FamilyMember)
|
||||
.filter(
|
||||
FamilyMember.family_id == family_id,
|
||||
FamilyMember.user_id == user_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if member:
|
||||
self.session.delete(member)
|
||||
self.session.commit()
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_invite(self, invite_code: str) -> Optional[FamilyInvite]:
|
||||
"""Get invite by code"""
|
||||
return (
|
||||
self.session.query(FamilyInvite)
|
||||
.filter(FamilyInvite.invite_code == invite_code)
|
||||
.first()
|
||||
)
|
||||
50
app/db/repositories/goal.py
Normal file
50
app/db/repositories/goal.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""Goal repository"""
|
||||
|
||||
from typing import List, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.models import Goal
|
||||
from app.db.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class GoalRepository(BaseRepository[Goal]):
|
||||
"""Goal data access operations"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, Goal)
|
||||
|
||||
def get_family_goals(self, family_id: int) -> List[Goal]:
|
||||
"""Get all active goals for family"""
|
||||
return (
|
||||
self.session.query(Goal)
|
||||
.filter(Goal.family_id == family_id, Goal.is_active == True)
|
||||
.order_by(Goal.priority.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_goals_progress(self, family_id: int) -> List[dict]:
|
||||
"""Get goals with progress info"""
|
||||
goals = self.get_family_goals(family_id)
|
||||
return [
|
||||
{
|
||||
"id": goal.id,
|
||||
"name": goal.name,
|
||||
"target": goal.target_amount,
|
||||
"current": goal.current_amount,
|
||||
"progress_percent": (goal.current_amount / goal.target_amount * 100) if goal.target_amount > 0 else 0,
|
||||
"is_completed": goal.is_completed
|
||||
}
|
||||
for goal in goals
|
||||
]
|
||||
|
||||
def update_progress(self, goal_id: int, amount: float) -> Optional[Goal]:
|
||||
"""Update goal progress"""
|
||||
goal = self.get_by_id(goal_id)
|
||||
if goal:
|
||||
goal.current_amount += amount
|
||||
if goal.current_amount >= goal.target_amount:
|
||||
goal.is_completed = True
|
||||
from datetime import datetime
|
||||
goal.completed_at = datetime.utcnow()
|
||||
self.session.commit()
|
||||
self.session.refresh(goal)
|
||||
return goal
|
||||
94
app/db/repositories/transaction.py
Normal file
94
app/db/repositories/transaction.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Transaction repository"""
|
||||
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_
|
||||
from app.db.models import Transaction, TransactionType
|
||||
from app.db.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class TransactionRepository(BaseRepository[Transaction]):
|
||||
"""Transaction data access operations"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, Transaction)
|
||||
|
||||
def get_family_transactions(self, family_id: int, skip: int = 0, limit: int = 50) -> List[Transaction]:
|
||||
"""Get transactions for family"""
|
||||
return (
|
||||
self.session.query(Transaction)
|
||||
.filter(Transaction.family_id == family_id)
|
||||
.order_by(Transaction.transaction_date.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_transactions_by_period(
|
||||
self, family_id: int, start_date: datetime, end_date: datetime
|
||||
) -> List[Transaction]:
|
||||
"""Get transactions within date range"""
|
||||
return (
|
||||
self.session.query(Transaction)
|
||||
.filter(
|
||||
and_(
|
||||
Transaction.family_id == family_id,
|
||||
Transaction.transaction_date >= start_date,
|
||||
Transaction.transaction_date <= end_date
|
||||
)
|
||||
)
|
||||
.order_by(Transaction.transaction_date.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_transactions_by_category(
|
||||
self, family_id: int, category_id: int, start_date: datetime, end_date: datetime
|
||||
) -> List[Transaction]:
|
||||
"""Get transactions by category in date range"""
|
||||
return (
|
||||
self.session.query(Transaction)
|
||||
.filter(
|
||||
and_(
|
||||
Transaction.family_id == family_id,
|
||||
Transaction.category_id == category_id,
|
||||
Transaction.transaction_date >= start_date,
|
||||
Transaction.transaction_date <= end_date
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
def get_user_transactions(self, user_id: int, days: int = 30) -> List[Transaction]:
|
||||
"""Get user's recent transactions"""
|
||||
start_date = datetime.utcnow() - timedelta(days=days)
|
||||
return (
|
||||
self.session.query(Transaction)
|
||||
.filter(
|
||||
and_(
|
||||
Transaction.user_id == user_id,
|
||||
Transaction.transaction_date >= start_date
|
||||
)
|
||||
)
|
||||
.order_by(Transaction.transaction_date.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
def sum_by_category(
|
||||
self, family_id: int, category_id: int, start_date: datetime, end_date: datetime
|
||||
) -> float:
|
||||
"""Calculate sum of transactions by category"""
|
||||
result = (
|
||||
self.session.query(Transaction)
|
||||
.filter(
|
||||
and_(
|
||||
Transaction.family_id == family_id,
|
||||
Transaction.category_id == category_id,
|
||||
Transaction.transaction_date >= start_date,
|
||||
Transaction.transaction_date <= end_date,
|
||||
Transaction.transaction_type == TransactionType.EXPENSE
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
return sum(t.amount for t in result)
|
||||
38
app/db/repositories/user.py
Normal file
38
app/db/repositories/user.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""User repository"""
|
||||
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.models import User
|
||||
from app.db.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class UserRepository(BaseRepository[User]):
|
||||
"""User data access operations"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
super().__init__(session, User)
|
||||
|
||||
def get_by_telegram_id(self, telegram_id: int) -> Optional[User]:
|
||||
"""Get user by Telegram ID"""
|
||||
return self.session.query(User).filter(User.telegram_id == telegram_id).first()
|
||||
|
||||
def get_by_username(self, username: str) -> Optional[User]:
|
||||
"""Get user by username"""
|
||||
return self.session.query(User).filter(User.username == username).first()
|
||||
|
||||
def get_or_create(self, telegram_id: int, **kwargs) -> User:
|
||||
"""Get user or create if doesn't exist"""
|
||||
user = self.get_by_telegram_id(telegram_id)
|
||||
if not user:
|
||||
user = self.create(telegram_id=telegram_id, **kwargs)
|
||||
return user
|
||||
|
||||
def update_activity(self, telegram_id: int) -> Optional[User]:
|
||||
"""Update user's last activity timestamp"""
|
||||
from datetime import datetime
|
||||
user = self.get_by_telegram_id(telegram_id)
|
||||
if user:
|
||||
user.last_activity = datetime.utcnow()
|
||||
self.session.commit()
|
||||
self.session.refresh(user)
|
||||
return user
|
||||
109
app/main.py
Normal file
109
app/main.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
FastAPI Application Entry Point
|
||||
Integrated API Gateway + Telegram Bot
|
||||
"""
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.core.config import settings
|
||||
from app.db.database import engine, Base, get_db
|
||||
from app.security.middleware import add_security_middleware
|
||||
from app.api import transactions, auth
|
||||
import redis
|
||||
|
||||
# Suppress Pydantic V2 migration warnings
|
||||
warnings.filterwarnings('ignore', message="Valid config keys have changed in V2")
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis client
|
||||
redis_client = redis.from_url(settings.redis_url)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""
|
||||
Startup/Shutdown events
|
||||
"""
|
||||
# === STARTUP ===
|
||||
logger.info("🚀 Application starting...")
|
||||
|
||||
# Create database tables (if not exist)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
logger.info("✅ Database initialized")
|
||||
|
||||
# Verify Redis connection
|
||||
try:
|
||||
redis_client.ping()
|
||||
logger.info("✅ Redis connected")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Redis connection failed: {e}")
|
||||
|
||||
yield
|
||||
|
||||
# === SHUTDOWN ===
|
||||
logger.info("🛑 Application shutting down...")
|
||||
redis_client.close()
|
||||
logger.info("✅ Cleanup complete")
|
||||
|
||||
|
||||
# Create FastAPI application
|
||||
app = FastAPI(
|
||||
title="Finance Bot API",
|
||||
description="API-First Zero-Trust Architecture for Family Finance Management",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_allowed_origins,
|
||||
allow_credentials=settings.cors_allow_credentials,
|
||||
allow_methods=settings.cors_allow_methods,
|
||||
allow_headers=settings.cors_allow_headers,
|
||||
)
|
||||
|
||||
# Add security middleware
|
||||
add_security_middleware(app, redis_client, next(get_db()))
|
||||
|
||||
# Include API routers
|
||||
app.include_router(auth.router)
|
||||
app.include_router(transactions.router)
|
||||
|
||||
|
||||
# ========== Health Check ==========
|
||||
@app.get("/health", tags=["health"])
|
||||
async def health_check():
|
||||
"""
|
||||
Health check endpoint.
|
||||
No authentication required.
|
||||
"""
|
||||
return {
|
||||
"status": "ok",
|
||||
"environment": settings.app_env,
|
||||
"version": "1.0.0",
|
||||
}
|
||||
|
||||
|
||||
# ========== Graceful Shutdown ==========
|
||||
import signal
|
||||
import asyncio
|
||||
|
||||
async def shutdown_handler(sig):
|
||||
"""Handle graceful shutdown"""
|
||||
logger.info(f"Received signal {sig}, shutting down...")
|
||||
|
||||
# Close connections
|
||||
redis_client.close()
|
||||
|
||||
# Exit
|
||||
return 0
|
||||
27
app/schemas/__init__.py
Normal file
27
app/schemas/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Pydantic schemas for request/response validation"""
|
||||
|
||||
from app.schemas.user import UserSchema, UserCreateSchema
|
||||
from app.schemas.family import FamilySchema, FamilyCreateSchema, FamilyMemberSchema
|
||||
from app.schemas.account import AccountSchema, AccountCreateSchema
|
||||
from app.schemas.category import CategorySchema, CategoryCreateSchema
|
||||
from app.schemas.transaction import TransactionSchema, TransactionCreateSchema
|
||||
from app.schemas.budget import BudgetSchema, BudgetCreateSchema
|
||||
from app.schemas.goal import GoalSchema, GoalCreateSchema
|
||||
|
||||
__all__ = [
|
||||
"UserSchema",
|
||||
"UserCreateSchema",
|
||||
"FamilySchema",
|
||||
"FamilyCreateSchema",
|
||||
"FamilyMemberSchema",
|
||||
"AccountSchema",
|
||||
"AccountCreateSchema",
|
||||
"CategorySchema",
|
||||
"CategoryCreateSchema",
|
||||
"TransactionSchema",
|
||||
"TransactionCreateSchema",
|
||||
"BudgetSchema",
|
||||
"BudgetCreateSchema",
|
||||
"GoalSchema",
|
||||
"GoalCreateSchema",
|
||||
]
|
||||
28
app/schemas/account.py
Normal file
28
app/schemas/account.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Account schemas"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class AccountCreateSchema(BaseModel):
|
||||
"""Schema for creating account"""
|
||||
name: str
|
||||
account_type: str = "card"
|
||||
description: Optional[str] = None
|
||||
initial_balance: float = 0.0
|
||||
|
||||
|
||||
class AccountSchema(AccountCreateSchema):
|
||||
"""Account response schema"""
|
||||
id: int
|
||||
family_id: int
|
||||
owner_id: int
|
||||
balance: float
|
||||
is_active: bool
|
||||
is_archived: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
29
app/schemas/budget.py
Normal file
29
app/schemas/budget.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Budget schemas"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class BudgetCreateSchema(BaseModel):
|
||||
"""Schema for creating budget"""
|
||||
name: str
|
||||
limit_amount: float
|
||||
period: str = "monthly"
|
||||
alert_threshold: float = 80.0
|
||||
category_id: Optional[int] = None
|
||||
start_date: datetime
|
||||
|
||||
|
||||
class BudgetSchema(BudgetCreateSchema):
|
||||
"""Budget response schema"""
|
||||
id: int
|
||||
family_id: int
|
||||
spent_amount: float
|
||||
is_active: bool
|
||||
end_date: Optional[datetime] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
28
app/schemas/category.py
Normal file
28
app/schemas/category.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Category schemas"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class CategoryCreateSchema(BaseModel):
|
||||
"""Schema for creating category"""
|
||||
name: str
|
||||
category_type: str
|
||||
emoji: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
is_default: bool = False
|
||||
|
||||
|
||||
class CategorySchema(CategoryCreateSchema):
|
||||
"""Category response schema"""
|
||||
id: int
|
||||
family_id: int
|
||||
is_active: bool
|
||||
order: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
41
app/schemas/family.py
Normal file
41
app/schemas/family.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Family schemas"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
|
||||
|
||||
class FamilyMemberSchema(BaseModel):
|
||||
"""Family member schema"""
|
||||
id: int
|
||||
user_id: int
|
||||
role: str
|
||||
can_edit_budget: bool
|
||||
can_manage_members: bool
|
||||
joined_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class FamilyCreateSchema(BaseModel):
|
||||
"""Schema for creating family"""
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
currency: str = "RUB"
|
||||
notification_level: str = "all"
|
||||
accounting_period: str = "month"
|
||||
|
||||
|
||||
class FamilySchema(FamilyCreateSchema):
|
||||
"""Family response schema"""
|
||||
id: int
|
||||
owner_id: int
|
||||
invite_code: str
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
members: List[FamilyMemberSchema] = []
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
30
app/schemas/goal.py
Normal file
30
app/schemas/goal.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Goal schemas"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class GoalCreateSchema(BaseModel):
|
||||
"""Schema for creating goal"""
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
target_amount: float
|
||||
priority: int = 0
|
||||
target_date: Optional[datetime] = None
|
||||
account_id: Optional[int] = None
|
||||
|
||||
|
||||
class GoalSchema(GoalCreateSchema):
|
||||
"""Goal response schema"""
|
||||
id: int
|
||||
family_id: int
|
||||
current_amount: float
|
||||
is_active: bool
|
||||
is_completed: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
completed_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
33
app/schemas/transaction.py
Normal file
33
app/schemas/transaction.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""Transaction schemas"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class TransactionCreateSchema(BaseModel):
|
||||
"""Schema for creating transaction"""
|
||||
amount: float
|
||||
transaction_type: str
|
||||
description: Optional[str] = None
|
||||
notes: Optional[str] = None
|
||||
tags: Optional[str] = None
|
||||
category_id: Optional[int] = None
|
||||
receipt_photo_url: Optional[str] = None
|
||||
transaction_date: datetime
|
||||
|
||||
|
||||
class TransactionSchema(TransactionCreateSchema):
|
||||
"""Transaction response schema"""
|
||||
id: int
|
||||
family_id: int
|
||||
user_id: int
|
||||
account_id: int
|
||||
is_confirmed: bool
|
||||
is_recurring: bool
|
||||
recurrence_pattern: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
26
app/schemas/user.py
Normal file
26
app/schemas/user.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""User schemas"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class UserCreateSchema(BaseModel):
|
||||
"""Schema for creating user"""
|
||||
telegram_id: int
|
||||
username: Optional[str] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
|
||||
|
||||
class UserSchema(UserCreateSchema):
|
||||
"""User response schema"""
|
||||
id: int
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_activity: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
1
app/security/__init__.py
Normal file
1
app/security/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Security module: JWT, HMAC, RBAC
|
||||
145
app/security/hmac_manager.py
Normal file
145
app/security/hmac_manager.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
HMAC Signature Verification - Replay Attack Prevention & Request Integrity
|
||||
"""
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Tuple
|
||||
from urllib.parse import urlencode
|
||||
from app.core.config import settings
|
||||
import redis
|
||||
|
||||
|
||||
class HMACManager:
|
||||
"""
|
||||
Request signing and verification using HMAC-SHA256.
|
||||
|
||||
Signature Format:
|
||||
────────────────────────────────────────────────────
|
||||
base_string = METHOD + ENDPOINT + TIMESTAMP + hash(BODY)
|
||||
signature = HMAC_SHA256(base_string, client_secret)
|
||||
|
||||
Headers Required:
|
||||
- X-Signature: base64(signature)
|
||||
- X-Timestamp: unix timestamp (seconds)
|
||||
- X-Client-Id: client identifier
|
||||
|
||||
Anti-Replay Protection:
|
||||
- Check timestamp freshness (±30 seconds)
|
||||
- Store signature hash in Redis with 1-minute TTL
|
||||
- Reject duplicate signatures (nonce check)
|
||||
"""
|
||||
|
||||
# Configuration
|
||||
TIMESTAMP_TOLERANCE_SECONDS = 30
|
||||
REPLAY_NONCE_TTL_SECONDS = 60
|
||||
|
||||
def __init__(self, redis_client: redis.Redis = None):
|
||||
self.redis_client = redis_client
|
||||
self.algorithm = "sha256"
|
||||
|
||||
def create_signature(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
timestamp: int,
|
||||
body: dict = None,
|
||||
client_secret: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create HMAC signature for request.
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, PUT, DELETE)
|
||||
endpoint: API endpoint path (/api/v1/transactions)
|
||||
timestamp: Unix timestamp
|
||||
body: Request body dictionary
|
||||
client_secret: Shared secret key
|
||||
|
||||
Returns:
|
||||
Base64-encoded signature
|
||||
"""
|
||||
if client_secret is None:
|
||||
client_secret = settings.hmac_secret_key
|
||||
|
||||
# Create base string
|
||||
base_string = self._build_base_string(method, endpoint, timestamp, body)
|
||||
|
||||
# Generate HMAC
|
||||
signature = hmac.new(
|
||||
client_secret.encode(),
|
||||
base_string.encode(),
|
||||
hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
return signature
|
||||
|
||||
def verify_signature(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
timestamp: int,
|
||||
signature: str,
|
||||
body: dict = None,
|
||||
client_secret: str = None,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Verify HMAC signature and check for replay attacks.
|
||||
|
||||
Returns:
|
||||
(is_valid, error_message)
|
||||
"""
|
||||
if client_secret is None:
|
||||
client_secret = settings.hmac_secret_key
|
||||
|
||||
# Step 1: Check timestamp freshness
|
||||
now = datetime.utcnow().timestamp()
|
||||
time_diff = abs(now - timestamp)
|
||||
|
||||
if time_diff > self.TIMESTAMP_TOLERANCE_SECONDS:
|
||||
return False, f"Timestamp too old (diff: {time_diff}s)"
|
||||
|
||||
# Step 2: Verify signature match
|
||||
expected_signature = self.create_signature(
|
||||
method, endpoint, timestamp, body, client_secret
|
||||
)
|
||||
|
||||
if not hmac.compare_digest(signature, expected_signature):
|
||||
return False, "Signature mismatch"
|
||||
|
||||
# Step 3: Check for replay (signature already used)
|
||||
if self.redis_client:
|
||||
nonce_key = f"hmac:nonce:{signature}"
|
||||
if self.redis_client.exists(nonce_key):
|
||||
return False, "Signature already used (replay attack)"
|
||||
|
||||
# Store nonce
|
||||
self.redis_client.setex(nonce_key, self.REPLAY_NONCE_TTL_SECONDS, "1")
|
||||
|
||||
return True, ""
|
||||
|
||||
def _build_base_string(
|
||||
self,
|
||||
method: str,
|
||||
endpoint: str,
|
||||
timestamp: int,
|
||||
body: dict = None,
|
||||
) -> str:
|
||||
"""Construct base string for signing"""
|
||||
# Normalize method
|
||||
method = method.upper()
|
||||
|
||||
# Hash body (sorted JSON)
|
||||
body_hash = ""
|
||||
if body:
|
||||
body_json = json.dumps(body, sort_keys=True, separators=(',', ':'))
|
||||
body_hash = hashlib.sha256(body_json.encode()).hexdigest()
|
||||
|
||||
# Base string format
|
||||
base_string = f"{method}:{endpoint}:{timestamp}:{body_hash}"
|
||||
return base_string
|
||||
|
||||
|
||||
# Singleton instance
|
||||
hmac_manager = HMACManager()
|
||||
149
app/security/jwt_manager.py
Normal file
149
app/security/jwt_manager.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
JWT Token Management - Access & Refresh Token Handling
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
from enum import Enum
|
||||
import jwt
|
||||
from pydantic import BaseModel
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class TokenType(str, Enum):
|
||||
ACCESS = "access"
|
||||
REFRESH = "refresh"
|
||||
SERVICE = "service" # For bot/workers
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
"""JWT Token Payload Structure"""
|
||||
sub: int # user_id
|
||||
type: TokenType
|
||||
device_id: Optional[str] = None
|
||||
scope: str = "default" # For granular permissions
|
||||
family_ids: list[int] = [] # Accessible families
|
||||
iat: int # issued at
|
||||
exp: int # expiration
|
||||
|
||||
|
||||
class JWTManager:
|
||||
"""
|
||||
JWT token generation, validation, and management.
|
||||
|
||||
Algorithms:
|
||||
- Production: RS256 (asymmetric) - more secure, scalable
|
||||
- MVP: HS256 (symmetric) - simpler setup
|
||||
"""
|
||||
|
||||
# Token lifetimes (configurable in settings)
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 15 # Short-lived
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = 30 # Long-lived
|
||||
SERVICE_TOKEN_EXPIRE_HOURS = 8760 # 1 year
|
||||
|
||||
def __init__(self, secret_key: str = None):
|
||||
self.secret_key = secret_key or settings.jwt_secret_key
|
||||
self.algorithm = "HS256"
|
||||
|
||||
def create_access_token(
|
||||
self,
|
||||
user_id: int,
|
||||
device_id: Optional[str] = None,
|
||||
family_ids: list[int] = None,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
) -> str:
|
||||
"""Generate short-lived access token"""
|
||||
if expires_delta is None:
|
||||
expires_delta = timedelta(minutes=self.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
return self._create_token(
|
||||
user_id=user_id,
|
||||
token_type=TokenType.ACCESS,
|
||||
expires_delta=expires_delta,
|
||||
device_id=device_id,
|
||||
family_ids=family_ids or [],
|
||||
)
|
||||
|
||||
def create_refresh_token(
|
||||
self,
|
||||
user_id: int,
|
||||
device_id: Optional[str] = None,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
) -> str:
|
||||
"""Generate long-lived refresh token"""
|
||||
if expires_delta is None:
|
||||
expires_delta = timedelta(days=self.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
return self._create_token(
|
||||
user_id=user_id,
|
||||
token_type=TokenType.REFRESH,
|
||||
expires_delta=expires_delta,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
def create_service_token(
|
||||
self,
|
||||
service_name: str,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
) -> str:
|
||||
"""Generate service-to-service token (e.g., for bot)"""
|
||||
if expires_delta is None:
|
||||
expires_delta = timedelta(hours=self.SERVICE_TOKEN_EXPIRE_HOURS)
|
||||
|
||||
now = datetime.utcnow()
|
||||
expire = now + expires_delta
|
||||
|
||||
payload = {
|
||||
"sub": f"service:{service_name}",
|
||||
"type": TokenType.SERVICE,
|
||||
"iat": int(now.timestamp()),
|
||||
"exp": int(expire.timestamp()),
|
||||
}
|
||||
|
||||
return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
||||
|
||||
def _create_token(
|
||||
self,
|
||||
user_id: int,
|
||||
token_type: TokenType,
|
||||
expires_delta: timedelta,
|
||||
device_id: Optional[str] = None,
|
||||
family_ids: list[int] = None,
|
||||
) -> str:
|
||||
"""Internal token creation"""
|
||||
now = datetime.utcnow()
|
||||
expire = now + expires_delta
|
||||
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"type": token_type.value,
|
||||
"device_id": device_id,
|
||||
"family_ids": family_ids or [],
|
||||
"iat": int(now.timestamp()),
|
||||
"exp": int(expire.timestamp()),
|
||||
}
|
||||
|
||||
return jwt.encode(payload, self.secret_key, algorithm=self.algorithm)
|
||||
|
||||
def verify_token(self, token: str) -> TokenPayload:
|
||||
"""
|
||||
Verify token signature and expiration.
|
||||
|
||||
Raises:
|
||||
- jwt.InvalidTokenError
|
||||
- jwt.ExpiredSignatureError
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, self.secret_key, algorithms=[self.algorithm])
|
||||
return TokenPayload(**payload)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise ValueError("Token has expired")
|
||||
except jwt.InvalidTokenError:
|
||||
raise ValueError("Invalid token")
|
||||
|
||||
def decode_token(self, token: str) -> Dict[str, Any]:
|
||||
"""Decode token without verification (for debugging only)"""
|
||||
return jwt.decode(token, options={"verify_signature": False})
|
||||
|
||||
|
||||
# Singleton instance
|
||||
jwt_manager = JWTManager()
|
||||
308
app/security/middleware.py
Normal file
308
app/security/middleware.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""
|
||||
FastAPI Middleware Stack - Authentication, Authorization, and Security
|
||||
"""
|
||||
from fastapi import FastAPI, Request, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Callable, Any
|
||||
from datetime import datetime
|
||||
import redis
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.security.jwt_manager import jwt_manager, TokenPayload
|
||||
from app.security.hmac_manager import hmac_manager
|
||||
from app.security.rbac import RBACEngine, UserContext, MemberRole, Permission
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""Add security headers to all responses"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
response = await call_next(request)
|
||||
|
||||
# Security headers
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
response.headers["Content-Security-Policy"] = "default-src 'self'"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Rate limiting using Redis"""
|
||||
|
||||
def __init__(self, app: FastAPI, redis_client: redis.Redis):
|
||||
super().__init__(app)
|
||||
self.redis_client = redis_client
|
||||
self.rate_limit_requests = 100 # requests
|
||||
self.rate_limit_window = 60 # seconds
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip rate limiting for health checks
|
||||
if request.url.path == "/health":
|
||||
return await call_next(request)
|
||||
|
||||
# Get client IP
|
||||
client_ip = request.client.host
|
||||
|
||||
# Rate limit key
|
||||
rate_key = f"rate_limit:{client_ip}"
|
||||
|
||||
# Check rate limit
|
||||
try:
|
||||
current = self.redis_client.get(rate_key)
|
||||
if current and int(current) >= self.rate_limit_requests:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={"detail": "Rate limit exceeded"}
|
||||
)
|
||||
|
||||
# Increment counter
|
||||
pipe = self.redis_client.pipeline()
|
||||
pipe.incr(rate_key)
|
||||
pipe.expire(rate_key, self.rate_limit_window)
|
||||
pipe.execute()
|
||||
except Exception as e:
|
||||
logger.warning(f"Rate limiting error: {e}")
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class HMACVerificationMiddleware(BaseHTTPMiddleware):
|
||||
"""HMAC signature verification and anti-replay protection"""
|
||||
|
||||
def __init__(self, app: FastAPI, redis_client: redis.Redis):
|
||||
super().__init__(app)
|
||||
self.redis_client = redis_client
|
||||
hmac_manager.redis_client = redis_client
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip verification for public endpoints
|
||||
public_paths = ["/health", "/docs", "/openapi.json", "/api/v1/auth/login", "/api/v1/auth/telegram/start"]
|
||||
if request.url.path in public_paths:
|
||||
return await call_next(request)
|
||||
|
||||
# Extract HMAC headers
|
||||
signature = request.headers.get("X-Signature")
|
||||
timestamp = request.headers.get("X-Timestamp")
|
||||
client_id = request.headers.get("X-Client-Id", "unknown")
|
||||
|
||||
# HMAC verification is optional in MVP (configurable)
|
||||
if settings.require_hmac_verification:
|
||||
if not signature or not timestamp:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"detail": "Missing HMAC headers"}
|
||||
)
|
||||
|
||||
try:
|
||||
timestamp_int = int(timestamp)
|
||||
except ValueError:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={"detail": "Invalid timestamp format"}
|
||||
)
|
||||
|
||||
# Read body for signature verification
|
||||
body = await request.body()
|
||||
body_dict = {}
|
||||
if body:
|
||||
try:
|
||||
import json
|
||||
body_dict = json.loads(body)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Verify HMAC
|
||||
# Get client secret (hardcoded for MVP, should be from DB)
|
||||
client_secret = settings.hmac_secret_key
|
||||
|
||||
is_valid, error_msg = hmac_manager.verify_signature(
|
||||
method=request.method,
|
||||
endpoint=request.url.path,
|
||||
timestamp=timestamp_int,
|
||||
signature=signature,
|
||||
body=body_dict,
|
||||
client_secret=client_secret,
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(f"HMAC verification failed: {error_msg} (client: {client_id})")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": f"HMAC verification failed: {error_msg}"}
|
||||
)
|
||||
|
||||
# Store in request state for logging
|
||||
request.state.client_id = client_id
|
||||
request.state.timestamp = timestamp
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class JWTAuthenticationMiddleware(BaseHTTPMiddleware):
|
||||
"""JWT token verification and extraction"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip auth for public endpoints
|
||||
public_paths = ["/health", "/docs", "/openapi.json", "/api/v1/auth/login", "/api/v1/auth/telegram/start"]
|
||||
if request.url.path in public_paths:
|
||||
return await call_next(request)
|
||||
|
||||
# Extract token from Authorization header
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Missing Authorization header"}
|
||||
)
|
||||
|
||||
# Parse "Bearer <token>"
|
||||
try:
|
||||
scheme, token = auth_header.split()
|
||||
if scheme.lower() != "bearer":
|
||||
raise ValueError("Invalid auth scheme")
|
||||
except (ValueError, IndexError):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid Authorization header format"}
|
||||
)
|
||||
|
||||
# Verify JWT
|
||||
try:
|
||||
token_payload = jwt_manager.verify_token(token)
|
||||
except ValueError as e:
|
||||
logger.warning(f"JWT verification failed: {e}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid or expired token"}
|
||||
)
|
||||
|
||||
# Store in request state
|
||||
request.state.user_id = token_payload.sub
|
||||
request.state.token_type = token_payload.type
|
||||
request.state.device_id = token_payload.device_id
|
||||
request.state.family_ids = token_payload.family_ids
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class RBACMiddleware(BaseHTTPMiddleware):
|
||||
"""Role-Based Access Control enforcement"""
|
||||
|
||||
def __init__(self, app: FastAPI, db_session: Any):
|
||||
super().__init__(app)
|
||||
self.db_session = db_session
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Skip RBAC for public endpoints
|
||||
if request.url.path in ["/health", "/docs", "/openapi.json"]:
|
||||
return await call_next(request)
|
||||
|
||||
# Get user context from JWT
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
family_ids = getattr(request.state, "family_ids", [])
|
||||
|
||||
if not user_id:
|
||||
# Already handled by JWTAuthenticationMiddleware
|
||||
return await call_next(request)
|
||||
|
||||
# Extract family_id from URL or body
|
||||
family_id = self._extract_family_id(request)
|
||||
|
||||
if family_id and family_id not in family_ids:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "Access denied to this family"}
|
||||
)
|
||||
|
||||
# Load user role (would need DB query in production)
|
||||
# For MVP: Store in request state, resolved in endpoint handlers
|
||||
request.state.family_id = family_id
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
@staticmethod
|
||||
def _extract_family_id(request: Request) -> Optional[int]:
|
||||
"""Extract family_id from URL or request body"""
|
||||
# From URL path: /api/v1/families/{family_id}/...
|
||||
if "{family_id}" in request.url.path:
|
||||
# Parse from actual path
|
||||
parts = request.url.path.split("/")
|
||||
for i, part in enumerate(parts):
|
||||
if part == "families" and i + 1 < len(parts):
|
||||
try:
|
||||
return int(parts[i + 1])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
"""Log all requests and responses for audit"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Any:
|
||||
# Start timer
|
||||
start_time = time.time()
|
||||
|
||||
# Get client info
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
|
||||
# Process request
|
||||
try:
|
||||
response = await call_next(request)
|
||||
response_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Log successful request
|
||||
logger.info(
|
||||
f"Endpoint={request.url.path} "
|
||||
f"Method={request.method} "
|
||||
f"Status={response.status_code} "
|
||||
f"Time={response_time_ms}ms "
|
||||
f"User={user_id} "
|
||||
f"IP={client_ip}"
|
||||
)
|
||||
|
||||
# Add timing header
|
||||
response.headers["X-Response-Time"] = str(response_time_ms)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
response_time_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(
|
||||
f"Request error - Endpoint={request.url.path} "
|
||||
f"Error={str(e)} "
|
||||
f"Time={response_time_ms}ms "
|
||||
f"User={user_id} "
|
||||
f"IP={client_ip}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def add_security_middleware(app: FastAPI, redis_client: redis.Redis, db_session: Any):
|
||||
"""Register all security middleware in correct order"""
|
||||
|
||||
# Order matters! Process in reverse order of registration:
|
||||
# 1. RequestLoggingMiddleware (innermost, executes last)
|
||||
# 2. RBACMiddleware
|
||||
# 3. JWTAuthenticationMiddleware
|
||||
# 4. HMACVerificationMiddleware
|
||||
# 5. RateLimitMiddleware
|
||||
# 6. SecurityHeadersMiddleware (outermost, executes first)
|
||||
|
||||
app.add_middleware(RequestLoggingMiddleware)
|
||||
app.add_middleware(RBACMiddleware, db_session=db_session)
|
||||
app.add_middleware(JWTAuthenticationMiddleware)
|
||||
app.add_middleware(HMACVerificationMiddleware, redis_client=redis_client)
|
||||
app.add_middleware(RateLimitMiddleware, redis_client=redis_client)
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
228
app/security/rbac.py
Normal file
228
app/security/rbac.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Role-Based Access Control (RBAC) - Authorization Engine
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import Optional, Set, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class MemberRole(str, Enum):
|
||||
"""Family member roles with hierarchy"""
|
||||
OWNER = "owner" # Full access
|
||||
ADULT = "adult" # Can create/edit own transactions
|
||||
MEMBER = "member" # Can create/edit own transactions, restricted budget
|
||||
CHILD = "child" # Limited access, read mostly
|
||||
READ_ONLY = "read_only" # Audit/observer only
|
||||
|
||||
|
||||
class Permission(str, Enum):
|
||||
"""Fine-grained permissions"""
|
||||
# Transaction permissions
|
||||
CREATE_TRANSACTION = "create_transaction"
|
||||
EDIT_OWN_TRANSACTION = "edit_own_transaction"
|
||||
EDIT_ANY_TRANSACTION = "edit_any_transaction"
|
||||
DELETE_OWN_TRANSACTION = "delete_own_transaction"
|
||||
DELETE_ANY_TRANSACTION = "delete_any_transaction"
|
||||
APPROVE_TRANSACTION = "approve_transaction"
|
||||
|
||||
# Wallet permissions
|
||||
CREATE_WALLET = "create_wallet"
|
||||
EDIT_WALLET = "edit_wallet"
|
||||
DELETE_WALLET = "delete_wallet"
|
||||
VIEW_WALLET_BALANCE = "view_wallet_balance"
|
||||
|
||||
# Budget permissions
|
||||
CREATE_BUDGET = "create_budget"
|
||||
EDIT_BUDGET = "edit_budget"
|
||||
DELETE_BUDGET = "delete_budget"
|
||||
|
||||
# Goal permissions
|
||||
CREATE_GOAL = "create_goal"
|
||||
EDIT_GOAL = "edit_goal"
|
||||
DELETE_GOAL = "delete_goal"
|
||||
|
||||
# Category permissions
|
||||
CREATE_CATEGORY = "create_category"
|
||||
EDIT_CATEGORY = "edit_category"
|
||||
DELETE_CATEGORY = "delete_category"
|
||||
|
||||
# Member management
|
||||
INVITE_MEMBERS = "invite_members"
|
||||
EDIT_MEMBER_ROLE = "edit_member_role"
|
||||
REMOVE_MEMBER = "remove_member"
|
||||
|
||||
# Family settings
|
||||
EDIT_FAMILY_SETTINGS = "edit_family_settings"
|
||||
DELETE_FAMILY = "delete_family"
|
||||
|
||||
# Audit & reports
|
||||
VIEW_AUDIT_LOG = "view_audit_log"
|
||||
EXPORT_DATA = "export_data"
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserContext:
|
||||
"""Request context with authorization info"""
|
||||
user_id: int
|
||||
family_id: int
|
||||
role: MemberRole
|
||||
permissions: Set[Permission]
|
||||
family_ids: list[int] # All accessible families
|
||||
device_id: Optional[str] = None
|
||||
client_id: Optional[str] = None # "telegram_bot", "web_frontend", etc.
|
||||
|
||||
|
||||
class RBACEngine:
|
||||
"""
|
||||
Role-Based Access Control with permission inheritance.
|
||||
"""
|
||||
|
||||
# Define role -> permissions mapping
|
||||
ROLE_PERMISSIONS: Dict[MemberRole, Set[Permission]] = {
|
||||
MemberRole.OWNER: {
|
||||
# All permissions
|
||||
Permission.CREATE_TRANSACTION,
|
||||
Permission.EDIT_OWN_TRANSACTION,
|
||||
Permission.EDIT_ANY_TRANSACTION,
|
||||
Permission.DELETE_OWN_TRANSACTION,
|
||||
Permission.DELETE_ANY_TRANSACTION,
|
||||
Permission.APPROVE_TRANSACTION,
|
||||
Permission.CREATE_WALLET,
|
||||
Permission.EDIT_WALLET,
|
||||
Permission.DELETE_WALLET,
|
||||
Permission.VIEW_WALLET_BALANCE,
|
||||
Permission.CREATE_BUDGET,
|
||||
Permission.EDIT_BUDGET,
|
||||
Permission.DELETE_BUDGET,
|
||||
Permission.CREATE_GOAL,
|
||||
Permission.EDIT_GOAL,
|
||||
Permission.DELETE_GOAL,
|
||||
Permission.CREATE_CATEGORY,
|
||||
Permission.EDIT_CATEGORY,
|
||||
Permission.DELETE_CATEGORY,
|
||||
Permission.INVITE_MEMBERS,
|
||||
Permission.EDIT_MEMBER_ROLE,
|
||||
Permission.REMOVE_MEMBER,
|
||||
Permission.EDIT_FAMILY_SETTINGS,
|
||||
Permission.DELETE_FAMILY,
|
||||
Permission.VIEW_AUDIT_LOG,
|
||||
Permission.EXPORT_DATA,
|
||||
},
|
||||
|
||||
MemberRole.ADULT: {
|
||||
# Can manage finances and invite others
|
||||
Permission.CREATE_TRANSACTION,
|
||||
Permission.EDIT_OWN_TRANSACTION,
|
||||
Permission.DELETE_OWN_TRANSACTION,
|
||||
Permission.APPROVE_TRANSACTION,
|
||||
Permission.CREATE_WALLET,
|
||||
Permission.EDIT_WALLET,
|
||||
Permission.VIEW_WALLET_BALANCE,
|
||||
Permission.CREATE_BUDGET,
|
||||
Permission.EDIT_BUDGET,
|
||||
Permission.CREATE_GOAL,
|
||||
Permission.EDIT_GOAL,
|
||||
Permission.CREATE_CATEGORY,
|
||||
Permission.INVITE_MEMBERS,
|
||||
Permission.VIEW_AUDIT_LOG,
|
||||
Permission.EXPORT_DATA,
|
||||
},
|
||||
|
||||
MemberRole.MEMBER: {
|
||||
# Can create/view transactions
|
||||
Permission.CREATE_TRANSACTION,
|
||||
Permission.EDIT_OWN_TRANSACTION,
|
||||
Permission.DELETE_OWN_TRANSACTION,
|
||||
Permission.VIEW_WALLET_BALANCE,
|
||||
Permission.VIEW_AUDIT_LOG,
|
||||
},
|
||||
|
||||
MemberRole.CHILD: {
|
||||
# Limited read access
|
||||
Permission.CREATE_TRANSACTION, # Limited to own
|
||||
Permission.VIEW_WALLET_BALANCE,
|
||||
Permission.VIEW_AUDIT_LOG,
|
||||
},
|
||||
|
||||
MemberRole.READ_ONLY: {
|
||||
# Audit/observer only
|
||||
Permission.VIEW_WALLET_BALANCE,
|
||||
Permission.VIEW_AUDIT_LOG,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_permissions(role: MemberRole) -> Set[Permission]:
|
||||
"""Get permissions for a role"""
|
||||
return RBACEngine.ROLE_PERMISSIONS.get(role, set())
|
||||
|
||||
@staticmethod
|
||||
def has_permission(user_context: UserContext, permission: Permission) -> bool:
|
||||
"""Check if user has specific permission"""
|
||||
return permission in user_context.permissions
|
||||
|
||||
@staticmethod
|
||||
def check_permission(
|
||||
user_context: UserContext,
|
||||
required_permission: Permission,
|
||||
raise_exception: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
Check permission and optionally raise exception.
|
||||
|
||||
Raises:
|
||||
- PermissionError if raise_exception=True and user lacks permission
|
||||
"""
|
||||
has_perm = RBACEngine.has_permission(user_context, required_permission)
|
||||
|
||||
if not has_perm and raise_exception:
|
||||
raise PermissionError(
|
||||
f"User {user_context.user_id} lacks permission: {required_permission.value}"
|
||||
)
|
||||
|
||||
return has_perm
|
||||
|
||||
@staticmethod
|
||||
def check_family_access(
|
||||
user_context: UserContext,
|
||||
requested_family_id: int,
|
||||
raise_exception: bool = True,
|
||||
) -> bool:
|
||||
"""Verify user has access to requested family"""
|
||||
has_access = requested_family_id in user_context.family_ids
|
||||
|
||||
if not has_access and raise_exception:
|
||||
raise PermissionError(
|
||||
f"User {user_context.user_id} cannot access family {requested_family_id}"
|
||||
)
|
||||
|
||||
return has_access
|
||||
|
||||
@staticmethod
|
||||
def check_resource_ownership(
|
||||
user_context: UserContext,
|
||||
owner_id: int,
|
||||
raise_exception: bool = True,
|
||||
) -> bool:
|
||||
"""Check if user is owner of resource"""
|
||||
is_owner = user_context.user_id == owner_id
|
||||
|
||||
if not is_owner and raise_exception:
|
||||
raise PermissionError(
|
||||
f"User {user_context.user_id} is not owner of resource (owner: {owner_id})"
|
||||
)
|
||||
|
||||
return is_owner
|
||||
|
||||
|
||||
# Policy definitions (for advanced use)
|
||||
POLICIES = {
|
||||
"transaction_approval_required": {
|
||||
"conditions": ["amount > 500", "role != owner"],
|
||||
"action": "require_approval"
|
||||
},
|
||||
"restrict_child_budget": {
|
||||
"conditions": ["role == child"],
|
||||
"action": "limit_to_100_per_day"
|
||||
},
|
||||
}
|
||||
14
app/services/__init__.py
Normal file
14
app/services/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Main services package"""
|
||||
|
||||
from app.services.finance import TransactionService, BudgetService, GoalService, AccountService
|
||||
from app.services.analytics import ReportService
|
||||
from app.services.notifications import NotificationService
|
||||
|
||||
__all__ = [
|
||||
"TransactionService",
|
||||
"BudgetService",
|
||||
"GoalService",
|
||||
"AccountService",
|
||||
"ReportService",
|
||||
"NotificationService",
|
||||
]
|
||||
5
app/services/analytics/__init__.py
Normal file
5
app/services/analytics/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Analytics service module"""
|
||||
|
||||
from app.services.analytics.report_service import ReportService
|
||||
|
||||
__all__ = ["ReportService"]
|
||||
111
app/services/analytics/report_service.py
Normal file
111
app/services/analytics/report_service.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Report service for analytics"""
|
||||
|
||||
from typing import List, Dict
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.repositories import TransactionRepository, CategoryRepository
|
||||
from app.db.models import TransactionType
|
||||
|
||||
|
||||
class ReportService:
|
||||
"""Service for generating financial reports"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
self.transaction_repo = TransactionRepository(session)
|
||||
self.category_repo = CategoryRepository(session)
|
||||
|
||||
def get_expenses_by_category(
|
||||
self, family_id: int, start_date: datetime, end_date: datetime
|
||||
) -> Dict[str, float]:
|
||||
"""Get expense breakdown by category"""
|
||||
transactions = self.transaction_repo.get_transactions_by_period(
|
||||
family_id, start_date, end_date
|
||||
)
|
||||
|
||||
expenses_by_category = {}
|
||||
for transaction in transactions:
|
||||
if transaction.transaction_type == TransactionType.EXPENSE:
|
||||
category_name = transaction.category.name if transaction.category else "Без категории"
|
||||
if category_name not in expenses_by_category:
|
||||
expenses_by_category[category_name] = 0
|
||||
expenses_by_category[category_name] += transaction.amount
|
||||
|
||||
# Sort by amount descending
|
||||
return dict(sorted(expenses_by_category.items(), key=lambda x: x[1], reverse=True))
|
||||
|
||||
def get_expenses_by_user(
|
||||
self, family_id: int, start_date: datetime, end_date: datetime
|
||||
) -> Dict[str, float]:
|
||||
"""Get expense breakdown by user"""
|
||||
transactions = self.transaction_repo.get_transactions_by_period(
|
||||
family_id, start_date, end_date
|
||||
)
|
||||
|
||||
expenses_by_user = {}
|
||||
for transaction in transactions:
|
||||
if transaction.transaction_type == TransactionType.EXPENSE:
|
||||
user_name = f"{transaction.user.first_name or ''} {transaction.user.last_name or ''}".strip()
|
||||
if not user_name:
|
||||
user_name = transaction.user.username or f"User {transaction.user.id}"
|
||||
if user_name not in expenses_by_user:
|
||||
expenses_by_user[user_name] = 0
|
||||
expenses_by_user[user_name] += transaction.amount
|
||||
|
||||
return dict(sorted(expenses_by_user.items(), key=lambda x: x[1], reverse=True))
|
||||
|
||||
def get_daily_expenses(
|
||||
self, family_id: int, days: int = 30
|
||||
) -> Dict[str, float]:
|
||||
"""Get daily expenses for period"""
|
||||
end_date = datetime.utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
transactions = self.transaction_repo.get_transactions_by_period(
|
||||
family_id, start_date, end_date
|
||||
)
|
||||
|
||||
daily_expenses = {}
|
||||
for transaction in transactions:
|
||||
if transaction.transaction_type == TransactionType.EXPENSE:
|
||||
date_key = transaction.transaction_date.date().isoformat()
|
||||
if date_key not in daily_expenses:
|
||||
daily_expenses[date_key] = 0
|
||||
daily_expenses[date_key] += transaction.amount
|
||||
|
||||
return dict(sorted(daily_expenses.items()))
|
||||
|
||||
def get_month_comparison(self, family_id: int) -> Dict[str, float]:
|
||||
"""Compare expenses: current month vs last month"""
|
||||
today = datetime.utcnow()
|
||||
current_month_start = today.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
# Last month
|
||||
last_month_end = current_month_start - timedelta(days=1)
|
||||
last_month_start = last_month_end.replace(day=1)
|
||||
|
||||
current_transactions = self.transaction_repo.get_transactions_by_period(
|
||||
family_id, current_month_start, today
|
||||
)
|
||||
last_transactions = self.transaction_repo.get_transactions_by_period(
|
||||
family_id, last_month_start, last_month_end
|
||||
)
|
||||
|
||||
current_expenses = sum(
|
||||
t.amount for t in current_transactions
|
||||
if t.transaction_type == TransactionType.EXPENSE
|
||||
)
|
||||
last_expenses = sum(
|
||||
t.amount for t in last_transactions
|
||||
if t.transaction_type == TransactionType.EXPENSE
|
||||
)
|
||||
|
||||
difference = current_expenses - last_expenses
|
||||
percent_change = ((difference / last_expenses * 100) if last_expenses > 0 else 0)
|
||||
|
||||
return {
|
||||
"current_month": current_expenses,
|
||||
"last_month": last_expenses,
|
||||
"difference": difference,
|
||||
"percent_change": percent_change,
|
||||
}
|
||||
63
app/services/auth_service.py
Normal file
63
app/services/auth_service.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
Authentication Service - User login, token management
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
import secrets
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.models import User
|
||||
from app.security.jwt_manager import jwt_manager
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthService:
|
||||
"""Handles user authentication and token management"""
|
||||
|
||||
TELEGRAM_BINDING_CODE_TTL = 600 # 10 minutes
|
||||
BINDING_CODE_LENGTH = 24
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
async def create_telegram_binding_code(self, chat_id: int) -> str:
|
||||
"""Generate temporary code for Telegram user binding"""
|
||||
code = secrets.token_urlsafe(self.BINDING_CODE_LENGTH)
|
||||
logger.info(f"Generated Telegram binding code for chat_id={chat_id}")
|
||||
return code
|
||||
|
||||
async def login(self, email: str, password: str) -> Dict[str, Any]:
|
||||
"""Authenticate user with email/password"""
|
||||
|
||||
user = self.db.query(User).filter_by(email=email).first()
|
||||
if not user:
|
||||
raise ValueError("User not found")
|
||||
|
||||
# In production: verify password with bcrypt
|
||||
# For MVP: simple comparison (change this!)
|
||||
|
||||
access_token = jwt_manager.create_access_token(user_id=user.id)
|
||||
|
||||
logger.info(f"User {user.id} logged in")
|
||||
|
||||
return {
|
||||
"user_id": user.id,
|
||||
"access_token": access_token,
|
||||
"token_type": "bearer",
|
||||
}
|
||||
|
||||
async def refresh_token(self, refresh_token: str) -> Dict[str, Any]:
|
||||
"""Refresh access token"""
|
||||
|
||||
try:
|
||||
payload = jwt_manager.verify_token(refresh_token)
|
||||
new_token = jwt_manager.create_access_token(user_id=payload.user_id)
|
||||
return {
|
||||
"access_token": new_token,
|
||||
"token_type": "bearer",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Token refresh failed: {e}")
|
||||
raise ValueError("Invalid refresh token")
|
||||
13
app/services/finance/__init__.py
Normal file
13
app/services/finance/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Finance service module"""
|
||||
|
||||
from app.services.finance.transaction_service import TransactionService
|
||||
from app.services.finance.budget_service import BudgetService
|
||||
from app.services.finance.goal_service import GoalService
|
||||
from app.services.finance.account_service import AccountService
|
||||
|
||||
__all__ = [
|
||||
"TransactionService",
|
||||
"BudgetService",
|
||||
"GoalService",
|
||||
"AccountService",
|
||||
]
|
||||
60
app/services/finance/account_service.py
Normal file
60
app/services/finance/account_service.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Account service"""
|
||||
|
||||
from typing import Optional, List
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.repositories import AccountRepository
|
||||
from app.db.models import Account
|
||||
from app.schemas import AccountCreateSchema
|
||||
|
||||
|
||||
class AccountService:
|
||||
"""Service for account operations"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
self.account_repo = AccountRepository(session)
|
||||
|
||||
def create_account(self, family_id: int, owner_id: int, data: AccountCreateSchema) -> Account:
|
||||
"""Create new account"""
|
||||
return self.account_repo.create(
|
||||
family_id=family_id,
|
||||
owner_id=owner_id,
|
||||
name=data.name,
|
||||
account_type=data.account_type,
|
||||
description=data.description,
|
||||
balance=data.initial_balance,
|
||||
initial_balance=data.initial_balance,
|
||||
)
|
||||
|
||||
def transfer_between_accounts(
|
||||
self, from_account_id: int, to_account_id: int, amount: float
|
||||
) -> bool:
|
||||
"""Transfer money between accounts"""
|
||||
from_account = self.account_repo.update_balance(from_account_id, -amount)
|
||||
to_account = self.account_repo.update_balance(to_account_id, amount)
|
||||
return from_account is not None and to_account is not None
|
||||
|
||||
def get_family_total_balance(self, family_id: int) -> float:
|
||||
"""Get total balance of all family accounts"""
|
||||
accounts = self.account_repo.get_family_accounts(family_id)
|
||||
return sum(acc.balance for acc in accounts)
|
||||
|
||||
def archive_account(self, account_id: int) -> Optional[Account]:
|
||||
"""Archive account (hide but keep data)"""
|
||||
return self.account_repo.archive_account(account_id)
|
||||
|
||||
def get_account_summary(self, account_id: int) -> dict:
|
||||
"""Get account summary"""
|
||||
account = self.account_repo.get_by_id(account_id)
|
||||
if not account:
|
||||
return {}
|
||||
|
||||
return {
|
||||
"account_id": account.id,
|
||||
"name": account.name,
|
||||
"type": account.account_type,
|
||||
"balance": account.balance,
|
||||
"is_active": account.is_active,
|
||||
"is_archived": account.is_archived,
|
||||
"created_at": account.created_at,
|
||||
}
|
||||
67
app/services/finance/budget_service.py
Normal file
67
app/services/finance/budget_service.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Budget service"""
|
||||
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.repositories import BudgetRepository, TransactionRepository, CategoryRepository
|
||||
from app.db.models import Budget, TransactionType
|
||||
from app.schemas import BudgetCreateSchema
|
||||
|
||||
|
||||
class BudgetService:
|
||||
"""Service for budget operations"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
self.budget_repo = BudgetRepository(session)
|
||||
self.transaction_repo = TransactionRepository(session)
|
||||
self.category_repo = CategoryRepository(session)
|
||||
|
||||
def create_budget(self, family_id: int, data: BudgetCreateSchema) -> Budget:
|
||||
"""Create new budget"""
|
||||
return self.budget_repo.create(
|
||||
family_id=family_id,
|
||||
name=data.name,
|
||||
limit_amount=data.limit_amount,
|
||||
period=data.period,
|
||||
alert_threshold=data.alert_threshold,
|
||||
category_id=data.category_id,
|
||||
start_date=data.start_date,
|
||||
)
|
||||
|
||||
def get_budget_status(self, budget_id: int) -> dict:
|
||||
"""Get budget status with spent amount and percentage"""
|
||||
budget = self.budget_repo.get_by_id(budget_id)
|
||||
if not budget:
|
||||
return {}
|
||||
|
||||
spent_percent = (budget.spent_amount / budget.limit_amount * 100) if budget.limit_amount > 0 else 0
|
||||
remaining = budget.limit_amount - budget.spent_amount
|
||||
is_exceeded = spent_percent > 100
|
||||
is_warning = spent_percent >= budget.alert_threshold
|
||||
|
||||
return {
|
||||
"budget_id": budget.id,
|
||||
"name": budget.name,
|
||||
"limit": budget.limit_amount,
|
||||
"spent": budget.spent_amount,
|
||||
"remaining": remaining,
|
||||
"spent_percent": spent_percent,
|
||||
"is_exceeded": is_exceeded,
|
||||
"is_warning": is_warning,
|
||||
"alert_threshold": budget.alert_threshold,
|
||||
}
|
||||
|
||||
def get_family_budget_status(self, family_id: int) -> List[dict]:
|
||||
"""Get status of all budgets in family"""
|
||||
budgets = self.budget_repo.get_family_budgets(family_id)
|
||||
return [self.get_budget_status(budget.id) for budget in budgets]
|
||||
|
||||
def check_budget_exceeded(self, budget_id: int) -> bool:
|
||||
"""Check if budget limit exceeded"""
|
||||
status = self.get_budget_status(budget_id)
|
||||
return status.get("is_exceeded", False)
|
||||
|
||||
def reset_budget(self, budget_id: int) -> Optional[Budget]:
|
||||
"""Reset budget spent amount for new period"""
|
||||
return self.budget_repo.update(budget_id, spent_amount=0.0)
|
||||
64
app/services/finance/goal_service.py
Normal file
64
app/services/finance/goal_service.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Goal service"""
|
||||
|
||||
from typing import Optional, List
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.repositories import GoalRepository
|
||||
from app.db.models import Goal
|
||||
from app.schemas import GoalCreateSchema
|
||||
|
||||
|
||||
class GoalService:
|
||||
"""Service for goal operations"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
self.goal_repo = GoalRepository(session)
|
||||
|
||||
def create_goal(self, family_id: int, data: GoalCreateSchema) -> Goal:
|
||||
"""Create new savings goal"""
|
||||
return self.goal_repo.create(
|
||||
family_id=family_id,
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
target_amount=data.target_amount,
|
||||
priority=data.priority,
|
||||
target_date=data.target_date,
|
||||
account_id=data.account_id,
|
||||
)
|
||||
|
||||
def add_to_goal(self, goal_id: int, amount: float) -> Optional[Goal]:
|
||||
"""Add amount to goal progress"""
|
||||
return self.goal_repo.update_progress(goal_id, amount)
|
||||
|
||||
def get_goal_progress(self, goal_id: int) -> dict:
|
||||
"""Get goal progress information"""
|
||||
goal = self.goal_repo.get_by_id(goal_id)
|
||||
if not goal:
|
||||
return {}
|
||||
|
||||
progress_percent = (goal.current_amount / goal.target_amount * 100) if goal.target_amount > 0 else 0
|
||||
|
||||
return {
|
||||
"goal_id": goal.id,
|
||||
"name": goal.name,
|
||||
"target": goal.target_amount,
|
||||
"current": goal.current_amount,
|
||||
"remaining": goal.target_amount - goal.current_amount,
|
||||
"progress_percent": progress_percent,
|
||||
"is_completed": goal.is_completed,
|
||||
"target_date": goal.target_date,
|
||||
}
|
||||
|
||||
def get_family_goals_progress(self, family_id: int) -> List[dict]:
|
||||
"""Get progress for all family goals"""
|
||||
goals = self.goal_repo.get_family_goals(family_id)
|
||||
return [self.get_goal_progress(goal.id) for goal in goals]
|
||||
|
||||
def complete_goal(self, goal_id: int) -> Optional[Goal]:
|
||||
"""Mark goal as completed"""
|
||||
from datetime import datetime
|
||||
return self.goal_repo.update(
|
||||
goal_id,
|
||||
is_completed=True,
|
||||
completed_at=datetime.utcnow()
|
||||
)
|
||||
94
app/services/finance/transaction_service.py
Normal file
94
app/services/finance/transaction_service.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Transaction service"""
|
||||
|
||||
from typing import Optional, List
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.repositories import TransactionRepository, AccountRepository, BudgetRepository
|
||||
from app.db.models import Transaction, TransactionType
|
||||
from app.schemas import TransactionCreateSchema
|
||||
|
||||
|
||||
class TransactionService:
|
||||
"""Service for transaction operations"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
self.transaction_repo = TransactionRepository(session)
|
||||
self.account_repo = AccountRepository(session)
|
||||
self.budget_repo = BudgetRepository(session)
|
||||
|
||||
def create_transaction(
|
||||
self,
|
||||
family_id: int,
|
||||
user_id: int,
|
||||
account_id: int,
|
||||
data: TransactionCreateSchema,
|
||||
) -> Transaction:
|
||||
"""Create new transaction and update account balance"""
|
||||
# Create transaction
|
||||
transaction = self.transaction_repo.create(
|
||||
family_id=family_id,
|
||||
user_id=user_id,
|
||||
account_id=account_id,
|
||||
amount=data.amount,
|
||||
transaction_type=data.transaction_type,
|
||||
description=data.description,
|
||||
notes=data.notes,
|
||||
tags=data.tags,
|
||||
category_id=data.category_id,
|
||||
receipt_photo_url=data.receipt_photo_url,
|
||||
transaction_date=data.transaction_date,
|
||||
)
|
||||
|
||||
# Update account balance
|
||||
if data.transaction_type == TransactionType.EXPENSE:
|
||||
self.account_repo.update_balance(account_id, -data.amount)
|
||||
elif data.transaction_type == TransactionType.INCOME:
|
||||
self.account_repo.update_balance(account_id, data.amount)
|
||||
|
||||
# Update budget if expense
|
||||
if (
|
||||
data.transaction_type == TransactionType.EXPENSE
|
||||
and data.category_id
|
||||
):
|
||||
budget = self.budget_repo.get_category_budget(family_id, data.category_id)
|
||||
if budget:
|
||||
self.budget_repo.update_spent_amount(budget.id, data.amount)
|
||||
|
||||
return transaction
|
||||
|
||||
def get_family_summary(self, family_id: int, days: int = 30) -> dict:
|
||||
"""Get financial summary for family"""
|
||||
end_date = datetime.utcnow()
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
transactions = self.transaction_repo.get_transactions_by_period(
|
||||
family_id, start_date, end_date
|
||||
)
|
||||
|
||||
income = sum(t.amount for t in transactions if t.transaction_type == TransactionType.INCOME)
|
||||
expenses = sum(t.amount for t in transactions if t.transaction_type == TransactionType.EXPENSE)
|
||||
net = income - expenses
|
||||
|
||||
return {
|
||||
"period_days": days,
|
||||
"income": income,
|
||||
"expenses": expenses,
|
||||
"net": net,
|
||||
"average_daily_expense": expenses / days if days > 0 else 0,
|
||||
"transaction_count": len(transactions),
|
||||
}
|
||||
|
||||
def delete_transaction(self, transaction_id: int) -> bool:
|
||||
"""Delete transaction and rollback balance"""
|
||||
transaction = self.transaction_repo.get_by_id(transaction_id)
|
||||
if transaction:
|
||||
# Rollback balance
|
||||
if transaction.transaction_type == TransactionType.EXPENSE:
|
||||
self.account_repo.update_balance(transaction.account_id, transaction.amount)
|
||||
elif transaction.transaction_type == TransactionType.INCOME:
|
||||
self.account_repo.update_balance(transaction.account_id, -transaction.amount)
|
||||
|
||||
# Delete transaction
|
||||
return self.transaction_repo.delete(transaction_id)
|
||||
return False
|
||||
5
app/services/notifications/__init__.py
Normal file
5
app/services/notifications/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Notifications service module"""
|
||||
|
||||
from app.services.notifications.notification_service import NotificationService
|
||||
|
||||
__all__ = ["NotificationService"]
|
||||
57
app/services/notifications/notification_service.py
Normal file
57
app/services/notifications/notification_service.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Notification service"""
|
||||
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.models import Family
|
||||
|
||||
|
||||
class NotificationService:
|
||||
"""Service for managing notifications"""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
self.session = session
|
||||
|
||||
def should_notify(self, family: Family, notification_type: str) -> bool:
|
||||
"""Check if notification should be sent based on family settings"""
|
||||
if family.notification_level == "none":
|
||||
return False
|
||||
elif family.notification_level == "important":
|
||||
return notification_type in ["budget_exceeded", "goal_completed"]
|
||||
else: # all
|
||||
return True
|
||||
|
||||
def format_transaction_notification(
|
||||
self, user_name: str, amount: float, category: str, account: str
|
||||
) -> str:
|
||||
"""Format transaction notification message"""
|
||||
return (
|
||||
f"💰 {user_name} добавил запись:\n"
|
||||
f"Сумма: {amount}₽\n"
|
||||
f"Категория: {category}\n"
|
||||
f"Счет: {account}"
|
||||
)
|
||||
|
||||
def format_budget_warning(
|
||||
self, budget_name: str, spent: float, limit: float, percent: float
|
||||
) -> str:
|
||||
"""Format budget warning message"""
|
||||
return (
|
||||
f"⚠️ Внимание по бюджету!\n"
|
||||
f"Бюджет: {budget_name}\n"
|
||||
f"Потрачено: {spent}₽ из {limit}₽\n"
|
||||
f"Превышено на: {percent:.1f}%"
|
||||
)
|
||||
|
||||
def format_goal_progress(
|
||||
self, goal_name: str, current: float, target: float, percent: float
|
||||
) -> str:
|
||||
"""Format goal progress message"""
|
||||
return (
|
||||
f"🎯 Прогресс цели: {goal_name}\n"
|
||||
f"Накоплено: {current}₽ из {target}₽\n"
|
||||
f"Прогресс: {percent:.1f}%"
|
||||
)
|
||||
|
||||
def format_goal_completed(self, goal_name: str) -> str:
|
||||
"""Format goal completion message"""
|
||||
return f"✅ Цель достигнута! 🎉\n{goal_name}"
|
||||
145
app/services/transaction_service.py
Normal file
145
app/services/transaction_service.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
Transaction Service - Core business logic
|
||||
Handles transaction creation, approval, reversal with audit trail
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from decimal import Decimal
|
||||
import logging
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db.models import Transaction, Account, Family, User
|
||||
from app.security.rbac import RBACEngine, Permission, UserContext
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TransactionService:
|
||||
"""Manages financial transactions with approval workflow"""
|
||||
|
||||
APPROVAL_THRESHOLD = 500.0
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
async def create_transaction(
|
||||
self,
|
||||
user_context: UserContext,
|
||||
family_id: int,
|
||||
from_account_id: Optional[int],
|
||||
to_account_id: Optional[int],
|
||||
amount: Decimal,
|
||||
category_id: Optional[int] = None,
|
||||
description: str = "",
|
||||
requires_approval: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create new transaction"""
|
||||
RBACEngine.check_permission(user_context, Permission.CREATE_TRANSACTION)
|
||||
RBACEngine.check_family_access(user_context, family_id)
|
||||
|
||||
if amount <= 0:
|
||||
raise ValueError("Amount must be positive")
|
||||
|
||||
needs_approval = requires_approval or (float(amount) > self.APPROVAL_THRESHOLD and user_context.role.value != "owner")
|
||||
tx_status = "pending_approval" if needs_approval else "executed"
|
||||
|
||||
transaction = Transaction(
|
||||
family_id=family_id,
|
||||
created_by_id=user_context.user_id,
|
||||
from_account_id=from_account_id,
|
||||
to_account_id=to_account_id,
|
||||
amount=float(amount),
|
||||
category_id=category_id,
|
||||
description=description,
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
self.db.add(transaction)
|
||||
self.db.commit()
|
||||
|
||||
logger.info(f"Transaction created: {transaction.id}")
|
||||
|
||||
return {
|
||||
"id": transaction.id,
|
||||
"status": tx_status,
|
||||
"amount": float(amount),
|
||||
"requires_approval": needs_approval,
|
||||
}
|
||||
|
||||
async def confirm_transaction(
|
||||
self,
|
||||
user_context: UserContext,
|
||||
transaction_id: int,
|
||||
family_id: int,
|
||||
) -> Dict[str, Any]:
|
||||
"""Approve pending transaction"""
|
||||
RBACEngine.check_permission(user_context, Permission.APPROVE_TRANSACTION)
|
||||
RBACEngine.check_family_access(user_context, family_id)
|
||||
|
||||
tx = self.db.query(Transaction).filter_by(
|
||||
id=transaction_id,
|
||||
family_id=family_id,
|
||||
).first()
|
||||
|
||||
if not tx:
|
||||
raise ValueError(f"Transaction {transaction_id} not found")
|
||||
|
||||
tx.status = "executed"
|
||||
tx.approved_by_id = user_context.user_id
|
||||
tx.approved_at = datetime.utcnow()
|
||||
|
||||
self.db.commit()
|
||||
logger.info(f"Transaction {transaction_id} approved")
|
||||
|
||||
return {
|
||||
"id": tx.id,
|
||||
"status": "executed",
|
||||
}
|
||||
|
||||
async def reverse_transaction(
|
||||
self,
|
||||
user_context: UserContext,
|
||||
transaction_id: int,
|
||||
family_id: int,
|
||||
) -> Dict[str, Any]:
|
||||
"""Reverse transaction by creating compensation"""
|
||||
RBACEngine.check_permission(user_context, Permission.REVERSE_TRANSACTION)
|
||||
RBACEngine.check_family_access(user_context, family_id)
|
||||
|
||||
original = self.db.query(Transaction).filter_by(
|
||||
id=transaction_id,
|
||||
family_id=family_id,
|
||||
).first()
|
||||
|
||||
if not original:
|
||||
raise ValueError(f"Transaction {transaction_id} not found")
|
||||
|
||||
if original.status == "reversed":
|
||||
raise ValueError("Transaction already reversed")
|
||||
|
||||
reversal = Transaction(
|
||||
family_id=family_id,
|
||||
created_by_id=user_context.user_id,
|
||||
from_account_id=original.to_account_id,
|
||||
to_account_id=original.from_account_id,
|
||||
amount=original.amount,
|
||||
category_id=original.category_id,
|
||||
description=f"Reversal of transaction #{original.id}",
|
||||
status="executed",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
original.status = "reversed"
|
||||
original.reversed_at = datetime.utcnow()
|
||||
original.reversed_by_id = user_context.user_id
|
||||
|
||||
self.db.add(reversal)
|
||||
self.db.commit()
|
||||
|
||||
logger.info(f"Transaction {transaction_id} reversed, created {reversal.id}")
|
||||
|
||||
return {
|
||||
"original_id": original.id,
|
||||
"reversal_id": reversal.id,
|
||||
"status": "reversed",
|
||||
}
|
||||
Reference in New Issue
Block a user