from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Query from typing import List, Dict import asyncio import logging import json from sqlalchemy.orm import Session from app.database import SessionLocal from app.api.auth import get_current_user_ws from app.schemas.auth import CurrentUser from app.services.proxmox_service import proxmox_service from app.models.vm import VMAccess from app.schemas.vm import VMInfo router = APIRouter() logger = logging.getLogger(__name__) class ConnectionManager: def __init__(self): # Active connections: List of (WebSocket, User) tuples self.active_connections: List[tuple[WebSocket, CurrentUser]] = [] async def connect(self, websocket: WebSocket, user: CurrentUser): await websocket.accept() self.active_connections.append((websocket, user)) logger.info(f"WebSocket connected: {user.username}") def disconnect(self, websocket: WebSocket, user: CurrentUser): if (websocket, user) in self.active_connections: self.active_connections.remove((websocket, user)) logger.info(f"WebSocket disconnected: {user.username}") async def broadcast(self, all_resources: List[Dict]): """ 모든 연결된 클라이언트에게 권한에 맞는 VM 상태를 전송 """ # DB 세션을 매번 새로 생성하는 것은 비효율적일 수 있으나, # Background Task에서 실행되므로 안전하게 처리 try: db = SessionLocal() # 모든 활성 VM Access 정보 미리 로딩 all_accesses = db.query(VMAccess).filter(VMAccess.is_active == True).all() # User ID별 Access Map 생성 user_access_map = {} for access in all_accesses: if access.user_id not in user_access_map: user_access_map[access.user_id] = {} user_access_map[access.user_id][access.vm_id] = access for connection, user in self.active_connections: try: # 해당 유저의 권한 맵 가져오기 access_map = user_access_map.get(user.id, {}) user_vm_list = [] for res in all_resources: vm_id = res.get("vmid") # VMID가 없거나 권한 정보가 없으면 패스 if not vm_id or vm_id not in access_map: continue access = access_map[vm_id] # VMInfo 스키마에 맞춰 데이터 구성 vm_info = { "vm_id": vm_id, "node": res.get("node"), "type": res.get("type", "qemu"), "name": res.get("name", "Unknown"), "status": res.get("status", "unknown"), "ip_address": access.static_ip, "cpus": res.get("maxcpu", 0), "memory": res.get("maxmem", 0) // (1024 * 1024), "memory_usage": res.get("mem", 0) // (1024 * 1024) if res.get("mem") else 0, "cpu_usage": res.get("cpu", 0), # 권한 정보 "can_start": access.can_start, "can_stop": access.can_stop, "can_reboot": access.can_reboot, "can_connect": access.can_connect, # RDP 정보는 보안상 웹소켓 브로드캐스트에서는 제외하거나 필요시 포함 # (여기서는 제외하고 REST API 상세조회 사용 권장하지만, 목록 뷰를 위해 포함) "rdp_username": access.rdp_username, "rdp_port": access.rdp_port or 3389 } user_vm_list.append(vm_info) # 전송 await connection.send_json({ "type": "update", "data": user_vm_list }) except Exception as e: logger.error(f"Error sending update to {user.username}: {e}") # 에러 발생 시 연결 끊기 고려? db.close() except Exception as e: logger.error(f"Broadcast error: {e}") manager = ConnectionManager() @router.websocket("/status") async def websocket_endpoint( websocket: WebSocket, user: CurrentUser = Depends(get_current_user_ws) ): try: await manager.connect(websocket, user) try: while True: # 클라이언트로부터 메시지를 받을 일이 딱히 없지만 연결 유지를 위해 대기 await websocket.receive_text() except WebSocketDisconnect: manager.disconnect(websocket, user) except Exception: await websocket.close(code=4001) async def start_monitoring_task(): """백그라운드 모니터링 태스크""" logger.info("Starting background monitoring task...") try: while True: # 활성 연결이 있을 때만 조회 (리소스 절약) if manager.active_connections: resources = await proxmox_service.get_all_vms() await manager.broadcast(resources) await asyncio.sleep(2) # 2초마다 갱신 except asyncio.CancelledError: logger.info("Monitoring task cancelled")