Sanad/dashboard/routes/_safe_io.py

82 lines
2.5 KiB
Python

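"""Shared filesystem safety helpers for dashboard routes.

Provides:
- MAX_UPLOAD_BYTES: upload size limit in bytes, read from the dashboard config
- safe_filename: strip directory components; reject empty, dot-only and NUL/CR/LF names
- safe_path_under: ensure the resolved path stays inside a base dir
- check_upload_size: raise HTTP 413 when a payload exceeds the size limit
- atomic_write_bytes: write to a temp file, then os.replace onto the target
- atomic_write_text: encode text and write it atomically
- atomic_write_json: serialize to JSON and write it atomically
"""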
from __future__ import annotations
import json
import os
import tempfile
from pathlib import Path
from typing import Any
from fastapi import HTTPException
from Project.Sanad.core.config_loader import section as _cfg_section
# Upload size limit in bytes. The single source of truth is the
# "max_upload_bytes" key in dashboard.api_input; 8 MiB is the fallback
# handed to .get() for when that key is not set.
_api_input_cfg = _cfg_section("dashboard", "api_input")
MAX_UPLOAD_BYTES = _api_input_cfg.get("max_upload_bytes", 8 * 1024 * 1024)
def safe_filename(name: str | None) -> str:
"""Strip directory components and reject obviously unsafe names."""
if not name:
raise HTTPException(400, "Filename required.")
cleaned = os.path.basename(name).strip()
if not cleaned or cleaned in {".", ".."}:
raise HTTPException(400, "Invalid filename.")
if any(c in cleaned for c in ("\x00", "\n", "\r")):
raise HTTPException(400, "Invalid characters in filename.")
return cleaned
def safe_path_under(base: Path, name: str) -> Path:
    """Return the resolved path of ``name`` inside ``base``.

    ``name`` is first validated with ``safe_filename``; the joined path is
    then resolved and must be located under the resolved ``base``.

    Raises:
        HTTPException: 400 when the name is rejected by ``safe_filename``
            or the resolved path falls outside ``base``.
    """
    leaf = safe_filename(name)
    root = base.resolve()
    target = base.joinpath(leaf).resolve()
    try:
        # relative_to raises ValueError when target is not under root.
        target.relative_to(root)
    except ValueError:
        raise HTTPException(400, "Path traversal denied.")
    else:
        return target
def check_upload_size(content: bytes, max_bytes: int = MAX_UPLOAD_BYTES) -> None:
    """Reject payloads larger than ``max_bytes``.

    Raises:
        HTTPException: 413 when ``len(content)`` exceeds ``max_bytes``.
    """
    size = len(content)
    if size <= max_bytes:
        return
    raise HTTPException(
        413,
        f"Upload too large: {size} bytes (max {max_bytes}).",
    )
def atomic_write_bytes(path: Path, data: bytes) -> None:
    """Write ``data`` to ``path`` atomically via a temp file + ``os.replace``.

    The parent directory is created if missing. The data is written to a
    temp file in that same directory, flushed and fsynced, and then renamed
    over ``path``.

    Args:
        path: Destination file; replaced if it already exists.
        data: Bytes to write.

    Raises:
        OSError: If creating, writing or renaming the file fails. The temp
            file is removed before the error propagates.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Temp file lives next to the target so os.replace stays on one filesystem.
    fd, tmp = tempfile.mkstemp(prefix=f".{path.name}.", suffix=".tmp", dir=str(path.parent))
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
            # Push the bytes to disk before the rename; otherwise a crash
            # could persist the rename but not the file contents.
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, path)
    except BaseException:
        # BaseException (not just Exception) so KeyboardInterrupt/SystemExit
        # do not leave a stray temp file behind. Cleanup is best-effort.
        try:
            os.unlink(tmp)
        except OSError:
            pass
        raise
def atomic_write_text(path: Path, text: str, encoding: str = "utf-8") -> None:
    """Encode ``text`` with ``encoding`` and hand it to ``atomic_write_bytes``."""
    encoded = text.encode(encoding)
    atomic_write_bytes(path, encoded)
def atomic_write_json(path: Path, payload: Any, indent: int = 2) -> None:
    """Serialize ``payload`` as JSON and write it via ``atomic_write_text``.

    Non-ASCII characters are emitted as-is (``ensure_ascii=False``).
    """
    serialized = json.dumps(payload, indent=indent, ensure_ascii=False)
    atomic_write_text(path, serialized)