Updates
This commit is contained in:
@@ -16,18 +16,16 @@ from __future__ import annotations
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.core.security import verify_token
|
||||
from app.core.http import CLIENT_AUTH_COOKIE, get_bearer_or_cookie_token
|
||||
from app.core.security_logging import log_security_event
|
||||
from app.db.session import get_db
|
||||
from app.models.access import Permission, Role, User
|
||||
|
||||
|
||||
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
# Subject claim used by tokens issued for internal Hunter Stock Feeds users.
|
||||
# Distinct from the existing client-portal/admin tokens so the two systems
|
||||
# cannot impersonate each other.
|
||||
@@ -103,7 +101,7 @@ def _load_user(db: Session, user_id: int) -> User | None:
|
||||
|
||||
|
||||
def get_current_user(
|
||||
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
) -> User:
|
||||
"""Resolve the current internal user from the bearer token.
|
||||
@@ -111,10 +109,11 @@ def get_current_user(
|
||||
Raises 401 for missing/invalid tokens or unknown users, 403 for inactive
|
||||
users.
|
||||
"""
|
||||
if credentials is None:
|
||||
token = get_bearer_or_cookie_token(request, cookie_name=CLIENT_AUTH_COOKIE.name)
|
||||
if token is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")
|
||||
|
||||
payload = verify_token(credentials.credentials)
|
||||
payload = verify_token(token)
|
||||
if payload.get("sub") != INTERNAL_USER_SUBJECT:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication token")
|
||||
|
||||
@@ -136,6 +135,7 @@ def require_permission(permission_key: str):
|
||||
|
||||
def dependency(user: User = Depends(get_current_user)) -> User:
|
||||
if not user_has_permission(user, permission_key):
|
||||
log_security_event("authz.denied", role=user.role.name if user.role else None, permission=permission_key)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Missing required permission: {permission_key}",
|
||||
@@ -152,6 +152,7 @@ def require_any_permission(permission_keys: Iterable[str]):
|
||||
def dependency(user: User = Depends(get_current_user)) -> User:
|
||||
granted = get_user_permissions(user)
|
||||
if not any(key in granted for key in keys):
|
||||
log_security_event("authz.denied", role=user.role.name if user.role else None, permissions=list(keys))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Requires any of: {list(keys)}",
|
||||
@@ -169,6 +170,7 @@ def require_all_permissions(permission_keys: Iterable[str]):
|
||||
granted = get_user_permissions(user)
|
||||
missing = [key for key in keys if key not in granted]
|
||||
if missing:
|
||||
log_security_event("authz.denied", role=user.role.name if user.role else None, permissions=missing)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Missing required permissions: {missing}",
|
||||
|
||||
@@ -16,9 +16,21 @@ def _parse_csv_env(value: str) -> tuple[str, ...]:
|
||||
return tuple(part.strip() for part in value.split(",") if part.strip())
|
||||
|
||||
|
||||
def _env_flag(name: str, default: bool = False) -> bool:
|
||||
value = os.getenv(name)
|
||||
if value is None:
|
||||
return default
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Settings:
|
||||
app_name: str
|
||||
app_env: str
|
||||
host: str
|
||||
port: int
|
||||
log_level: str
|
||||
log_verbose: bool
|
||||
database_url: str
|
||||
client_name: str
|
||||
client_email: str
|
||||
@@ -30,11 +42,27 @@ class Settings:
|
||||
auth_secret: str
|
||||
cors_allow_origins: tuple[str, ...]
|
||||
cors_allow_origin_regex: str
|
||||
session_ttl_seconds: int
|
||||
session_cookie_name: str
|
||||
admin_session_cookie_name: str
|
||||
session_cookie_secure: bool
|
||||
session_cookie_samesite: str
|
||||
session_cookie_domain: str | None
|
||||
request_body_max_bytes: int
|
||||
login_rate_limit_attempts: int
|
||||
login_rate_limit_window_seconds: int
|
||||
trusted_hosts: tuple[str, ...]
|
||||
docs_enabled: bool
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "Settings":
|
||||
return cls(
|
||||
settings = cls(
|
||||
app_name=os.getenv("APP_NAME", "Data Entry App API"),
|
||||
app_env=os.getenv("APP_ENV", os.getenv("ENVIRONMENT", "development")),
|
||||
host=os.getenv("HOST", "0.0.0.0"),
|
||||
port=int(os.getenv("PORT", "8000")),
|
||||
log_level=os.getenv("LOG_LEVEL", "DEBUG" if os.getenv("LOG_VERBOSE") in {"1", "true", "TRUE", "yes", "on"} else "INFO"),
|
||||
log_verbose=_env_flag("LOG_VERBOSE"),
|
||||
database_url=os.getenv("DATABASE_URL", "sqlite:///./data_entry_app.db"),
|
||||
client_name=os.getenv("CLIENT_NAME", "Hunter Premium Produce"),
|
||||
client_email=os.getenv("CLIENT_EMAIL", "operator@example.com"),
|
||||
@@ -51,7 +79,47 @@ class Settings:
|
||||
)
|
||||
),
|
||||
cors_allow_origin_regex=os.getenv("CORS_ALLOW_ORIGIN_REGEX", DEFAULT_CORS_ALLOW_ORIGIN_REGEX),
|
||||
session_ttl_seconds=int(os.getenv("SESSION_TTL_SECONDS", str(60 * 60 * 12))),
|
||||
session_cookie_name=os.getenv("SESSION_COOKIE_NAME", "client_session"),
|
||||
admin_session_cookie_name=os.getenv("ADMIN_SESSION_COOKIE_NAME", "admin_session"),
|
||||
session_cookie_secure=_env_flag("SESSION_COOKIE_SECURE"),
|
||||
session_cookie_samesite=os.getenv("SESSION_COOKIE_SAMESITE", "lax").lower(),
|
||||
session_cookie_domain=os.getenv("SESSION_COOKIE_DOMAIN", "").strip() or None,
|
||||
request_body_max_bytes=int(os.getenv("REQUEST_BODY_MAX_BYTES", str(1024 * 1024))),
|
||||
login_rate_limit_attempts=int(os.getenv("LOGIN_RATE_LIMIT_ATTEMPTS", "8")),
|
||||
login_rate_limit_window_seconds=int(os.getenv("LOGIN_RATE_LIMIT_WINDOW_SECONDS", "300")),
|
||||
trusted_hosts=_parse_csv_env(os.getenv("TRUSTED_HOSTS", "localhost,127.0.0.1,testserver")),
|
||||
docs_enabled=_env_flag("DOCS_ENABLED", default=os.getenv("APP_ENV", os.getenv("ENVIRONMENT", "development")).lower() != "production"),
|
||||
)
|
||||
settings._validate()
|
||||
return settings
|
||||
|
||||
def _validate(self) -> None:
|
||||
if self.session_cookie_samesite not in {"lax", "strict", "none"}:
|
||||
raise ValueError("SESSION_COOKIE_SAMESITE must be one of: lax, strict, none")
|
||||
|
||||
is_production = self.app_env.lower() == "production"
|
||||
if not is_production:
|
||||
return
|
||||
|
||||
if self.client_password in {"changeme", "", "replace-with-strong-password"}:
|
||||
raise ValueError("CLIENT_PASSWORD must be set to a non-default value in production")
|
||||
if self.admin_password in {"lean101-admin", "", "replace-with-strong-password"}:
|
||||
raise ValueError("ADMIN_PASSWORD must be set to a non-default value in production")
|
||||
if self.auth_secret in {"lean-101-local-dev-secret", "change-me-in-production", "", "replace-with-a-long-random-secret"}:
|
||||
raise ValueError("AUTH_SECRET must be set to a strong production secret")
|
||||
if len(self.auth_secret) < 32:
|
||||
raise ValueError("AUTH_SECRET must be at least 32 characters in production")
|
||||
if not self.session_cookie_secure:
|
||||
raise ValueError("SESSION_COOKIE_SECURE must be enabled in production")
|
||||
if not self.cors_allow_origins:
|
||||
raise ValueError("CORS_ALLOW_ORIGINS must explicitly list production origins")
|
||||
if "localhost" in ",".join(self.cors_allow_origins).lower():
|
||||
raise ValueError("CORS_ALLOW_ORIGINS cannot include localhost in production")
|
||||
if self.cors_allow_origin_regex == DEFAULT_CORS_ALLOW_ORIGIN_REGEX:
|
||||
raise ValueError("CORS_ALLOW_ORIGIN_REGEX must be overridden or blank in production")
|
||||
if self.docs_enabled:
|
||||
raise ValueError("DOCS_ENABLED must be false in production")
|
||||
|
||||
|
||||
settings = Settings.from_env()
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Final
|
||||
|
||||
from fastapi import Request, Response
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
COOKIE_PATH: Final[str] = "/"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AuthCookie:
|
||||
name: str
|
||||
|
||||
def apply(self, response: Response, token: str) -> None:
|
||||
response.set_cookie(
|
||||
key=self.name,
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=settings.session_cookie_secure,
|
||||
samesite=settings.session_cookie_samesite,
|
||||
domain=settings.session_cookie_domain,
|
||||
path=COOKIE_PATH,
|
||||
max_age=settings.session_ttl_seconds,
|
||||
)
|
||||
|
||||
def clear(self, response: Response) -> None:
|
||||
response.delete_cookie(
|
||||
key=self.name,
|
||||
domain=settings.session_cookie_domain,
|
||||
path=COOKIE_PATH,
|
||||
)
|
||||
|
||||
|
||||
CLIENT_AUTH_COOKIE = AuthCookie(settings.session_cookie_name)
|
||||
ADMIN_AUTH_COOKIE = AuthCookie(settings.admin_session_cookie_name)
|
||||
|
||||
|
||||
def get_bearer_or_cookie_token(request: Request, *, cookie_name: str) -> str | None:
|
||||
authorization = request.headers.get("authorization", "").strip()
|
||||
if authorization.lower().startswith("bearer "):
|
||||
token = authorization[7:].strip()
|
||||
if token:
|
||||
return token
|
||||
cookie_value = request.cookies.get(cookie_name)
|
||||
if cookie_value:
|
||||
return cookie_value
|
||||
return None
|
||||
@@ -0,0 +1,372 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import Iterable
|
||||
|
||||
try:
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
except ImportError: # pragma: no cover - exercised only before dependency install
|
||||
Console = None
|
||||
RichHandler = None
|
||||
Table = None
|
||||
Text = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LoggingSettings:
|
||||
app_name: str
|
||||
app_env: str
|
||||
host: str
|
||||
port: int
|
||||
log_level: str
|
||||
log_verbose: bool
|
||||
database_url: str
|
||||
version: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StartupStatus:
|
||||
app_name: str
|
||||
version: str
|
||||
environment: str
|
||||
host: str
|
||||
port: int
|
||||
database: str
|
||||
mode: str
|
||||
started_at: str
|
||||
local_url: str
|
||||
network_url: str
|
||||
|
||||
|
||||
class PlainFormatter(logging.Formatter):
|
||||
default_time_format = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
if not hasattr(record, "component"):
|
||||
record.component = record.name.rsplit(".", 1)[-1]
|
||||
return super().format(record)
|
||||
|
||||
|
||||
def _allow_color() -> bool:
|
||||
if RichHandler is None or Console is None:
|
||||
return False
|
||||
if os.getenv("NO_COLOR"):
|
||||
return False
|
||||
if os.getenv("TERM") == "dumb":
|
||||
return False
|
||||
return hasattr(sys.stderr, "isatty") and sys.stderr.isatty()
|
||||
|
||||
|
||||
def _allow_unicode() -> bool:
|
||||
encoding = (getattr(sys.stdout, "encoding", None) or "").lower()
|
||||
if not encoding:
|
||||
return False
|
||||
return "utf" in encoding
|
||||
|
||||
|
||||
def _console() -> Console:
|
||||
return Console(
|
||||
stderr=True,
|
||||
soft_wrap=False,
|
||||
highlight=False,
|
||||
force_terminal=_allow_color(),
|
||||
no_color=not _allow_color(),
|
||||
emoji=False,
|
||||
)
|
||||
|
||||
|
||||
def _rich_handler(level: str) -> RichHandler:
|
||||
return RichHandler(
|
||||
level=level,
|
||||
console=_console(),
|
||||
show_time=True,
|
||||
show_level=True,
|
||||
show_path=False,
|
||||
omit_repeated_times=False,
|
||||
markup=True,
|
||||
rich_tracebacks=True,
|
||||
tracebacks_show_locals=False,
|
||||
log_time_format="%H:%M:%S",
|
||||
)
|
||||
|
||||
|
||||
def _plain_handler(level: str) -> logging.StreamHandler:
|
||||
handler = logging.StreamHandler()
|
||||
handler.setLevel(level)
|
||||
handler.setFormatter(
|
||||
PlainFormatter("%(asctime)s | %(levelname)-7s | %(component)-10s | %(message)s")
|
||||
)
|
||||
return handler
|
||||
|
||||
|
||||
def _handler(level: str) -> logging.Handler:
|
||||
return _rich_handler(level) if _allow_color() else _plain_handler(level)
|
||||
|
||||
|
||||
def configure_logging(settings: LoggingSettings) -> None:
|
||||
level = settings.log_level.upper()
|
||||
root = logging.getLogger()
|
||||
root.handlers.clear()
|
||||
root.setLevel(level)
|
||||
root.addHandler(_handler(level))
|
||||
|
||||
for name in ("uvicorn", "uvicorn.error", "fastapi"):
|
||||
logger = logging.getLogger(name)
|
||||
logger.handlers.clear()
|
||||
logger.setLevel(level)
|
||||
logger.propagate = True
|
||||
|
||||
access_logger = logging.getLogger("uvicorn.access")
|
||||
access_logger.handlers.clear()
|
||||
access_logger.propagate = False
|
||||
access_logger.disabled = True
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.LoggerAdapter[logging.Logger]:
|
||||
component = name.rsplit(".", 1)[-1]
|
||||
return logging.LoggerAdapter(logging.getLogger(name), {"component": component})
|
||||
|
||||
|
||||
def _icon(name: str) -> str:
|
||||
ascii_icons = {
|
||||
"app": "#",
|
||||
"info": "i",
|
||||
"success": "+",
|
||||
"warning": "!",
|
||||
"error": "x",
|
||||
"debug": ".",
|
||||
"section": "=",
|
||||
"url": ">",
|
||||
"shutdown": "-",
|
||||
}
|
||||
unicode_icons = {
|
||||
"app": "◆",
|
||||
"info": "ℹ",
|
||||
"success": "✓",
|
||||
"warning": "▲",
|
||||
"error": "✖",
|
||||
"debug": "•",
|
||||
"section": "─",
|
||||
"url": "↳",
|
||||
"shutdown": "◦",
|
||||
}
|
||||
icons = unicode_icons if _allow_unicode() else ascii_icons
|
||||
return icons[name]
|
||||
|
||||
|
||||
def _style(name: str) -> str:
|
||||
return {
|
||||
"info": "bold cyan",
|
||||
"success": "bold green",
|
||||
"warning": "bold yellow",
|
||||
"error": "bold red",
|
||||
"debug": "dim",
|
||||
"section": "bold bright_blue",
|
||||
"muted": "grey62",
|
||||
}[name]
|
||||
|
||||
|
||||
def section_heading(title: str) -> None:
|
||||
logger = get_logger("data_entry_app.section")
|
||||
if _allow_color():
|
||||
_console().rule(Text(f" {title.upper()} ", style=_style("section")))
|
||||
return
|
||||
logger.info("%s %s %s", _icon("section") * 10, title.upper(), _icon("section") * 10)
|
||||
|
||||
|
||||
def startup_banner(status: StartupStatus) -> None:
|
||||
logger = get_logger("data_entry_app.startup")
|
||||
if _allow_color():
|
||||
console = _console()
|
||||
table = Table.grid(expand=False)
|
||||
table.add_column(style="bold white", justify="left")
|
||||
table.add_column(style="white", justify="left")
|
||||
table.add_row("Environment", status.environment)
|
||||
table.add_row("Version", status.version)
|
||||
table.add_row("Host", status.host)
|
||||
table.add_row("Port", str(status.port))
|
||||
table.add_row("Database", status.database)
|
||||
table.add_row("Mode", status.mode)
|
||||
table.add_row("Started", status.started_at)
|
||||
|
||||
console.rule(Text(f" {status.app_name} ", style="bold white"))
|
||||
console.print(Text("Clean startup. Clear status. Ready.", style="italic cyan"))
|
||||
console.print(table)
|
||||
console.print()
|
||||
console.print(Text("App is running at:", style="bold white"))
|
||||
console.print(Text(f" Local: {status.local_url}", style="cyan"))
|
||||
console.print(Text(f" Network: {status.network_url}", style="cyan"))
|
||||
console.print()
|
||||
return
|
||||
|
||||
logger.info("%s %s", _icon("app"), "Startup banner")
|
||||
logger.info("App : %s", status.app_name)
|
||||
logger.info("Environment : %s", status.environment)
|
||||
logger.info("Version : %s", status.version)
|
||||
logger.info("Host : %s", status.host)
|
||||
logger.info("Port : %s", status.port)
|
||||
logger.info("Database : %s", status.database)
|
||||
logger.info("Mode : %s", status.mode)
|
||||
logger.info("Started : %s", status.started_at)
|
||||
logger.info("Local : %s", status.local_url)
|
||||
logger.info("Network : %s", status.network_url)
|
||||
|
||||
|
||||
def status_message(level: str, message: str, *args: object, logger_name: str = "data_entry_app.status") -> None:
|
||||
palette = {
|
||||
"debug": logging.DEBUG,
|
||||
"info": logging.INFO,
|
||||
"success": logging.INFO,
|
||||
"warning": logging.WARNING,
|
||||
"error": logging.ERROR,
|
||||
}
|
||||
labels = {
|
||||
"debug": f"[{_icon('debug')}]",
|
||||
"info": f"[{_icon('info')}]",
|
||||
"success": f"[{_icon('success')}]",
|
||||
"warning": f"[{_icon('warning')}]",
|
||||
"error": f"[{_icon('error')}]",
|
||||
}
|
||||
styles = {
|
||||
"debug": _style("debug"),
|
||||
"info": _style("info"),
|
||||
"success": _style("success"),
|
||||
"warning": _style("warning"),
|
||||
"error": _style("error"),
|
||||
}
|
||||
logger = get_logger(logger_name)
|
||||
rendered = message % args if args else message
|
||||
if _allow_color():
|
||||
logger.log(palette[level], f"[{styles[level]}]{labels[level]}[/] {rendered}")
|
||||
else:
|
||||
logger.log(palette[level], "%s %s", labels[level], rendered)
|
||||
|
||||
|
||||
def success(message: str, *args: object, logger_name: str = "data_entry_app.status") -> None:
|
||||
status_message("success", message, *args, logger_name=logger_name)
|
||||
|
||||
|
||||
def warning(message: str, *args: object, logger_name: str = "data_entry_app.status") -> None:
|
||||
status_message("warning", message, *args, logger_name=logger_name)
|
||||
|
||||
|
||||
def info(message: str, *args: object, logger_name: str = "data_entry_app.status") -> None:
|
||||
status_message("info", message, *args, logger_name=logger_name)
|
||||
|
||||
|
||||
def debug(message: str, *args: object, logger_name: str = "data_entry_app.status") -> None:
|
||||
status_message("debug", message, *args, logger_name=logger_name)
|
||||
|
||||
|
||||
def fatal(message: str, *args: object, exc_info: bool = False, logger_name: str = "data_entry_app.status") -> None:
|
||||
logger = get_logger(logger_name)
|
||||
rendered = message % args if args else message
|
||||
if _allow_color():
|
||||
logger.error(f"[{_style('error')}][{_icon('error')}][/] {rendered}", exc_info=exc_info)
|
||||
else:
|
||||
logger.error("[%s] %s", _icon("error"), rendered, exc_info=exc_info)
|
||||
|
||||
|
||||
def shutdown_summary(*, uptime_seconds: float, requests_served: int, host: str, port: int) -> None:
|
||||
section_heading("Shutdown")
|
||||
logger = get_logger("data_entry_app.shutdown")
|
||||
summary = f"Uptime {uptime_seconds:.1f}s | Requests {requests_served} | Endpoint http://{host}:{port}"
|
||||
if _allow_color():
|
||||
logger.info(f"[{_style('debug')}]{_icon('shutdown')}[/] {summary}")
|
||||
else:
|
||||
logger.info("%s %s", _icon("shutdown"), summary)
|
||||
|
||||
|
||||
def describe_database(url: str) -> str:
|
||||
if url.startswith("sqlite"):
|
||||
return "sqlite"
|
||||
if "postgresql" in url:
|
||||
return "postgresql"
|
||||
if "mysql" in url:
|
||||
return "mysql"
|
||||
return url.split(":", 1)[0]
|
||||
|
||||
|
||||
def sanitize_database_target(url: str) -> str:
|
||||
if url.startswith("sqlite:///"):
|
||||
return url.removeprefix("sqlite:///")
|
||||
if "@" in url:
|
||||
return url.split("@", 1)[1]
|
||||
return url
|
||||
|
||||
|
||||
def startup_status(settings: LoggingSettings) -> StartupStatus:
|
||||
host = settings.host
|
||||
local_host = "localhost" if host in {"0.0.0.0", "::"} else host
|
||||
timestamp = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M:%S %Z")
|
||||
return StartupStatus(
|
||||
app_name=settings.app_name,
|
||||
version=settings.version,
|
||||
environment=settings.app_env,
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
database=f"{describe_database(settings.database_url)} ({sanitize_database_target(settings.database_url)})",
|
||||
mode="verbose" if settings.log_verbose else "normal",
|
||||
started_at=timestamp,
|
||||
local_url=f"http://{local_host}:{settings.port}",
|
||||
network_url=f"http://{host}:{settings.port}",
|
||||
)
|
||||
|
||||
|
||||
def route_summary(routes: Iterable[object]) -> tuple[int, list[str]]:
|
||||
lines: list[str] = []
|
||||
count = 0
|
||||
for route in routes:
|
||||
path = getattr(route, "path", None)
|
||||
methods = getattr(route, "methods", None)
|
||||
if not path or not methods:
|
||||
continue
|
||||
filtered_methods = sorted(method for method in methods if method not in {"HEAD", "OPTIONS"})
|
||||
if not filtered_methods:
|
||||
continue
|
||||
count += 1
|
||||
lines.append(f"{','.join(filtered_methods):<7} {path}")
|
||||
return count, lines
|
||||
|
||||
|
||||
def log_request(
|
||||
*,
|
||||
method: str,
|
||||
path: str,
|
||||
status_code: int,
|
||||
duration_ms: float,
|
||||
client: str,
|
||||
content_length: str | None,
|
||||
) -> None:
|
||||
level = "info"
|
||||
if status_code >= 500:
|
||||
level = "error"
|
||||
elif status_code >= 400:
|
||||
level = "warning"
|
||||
elif path == "/health":
|
||||
level = "debug"
|
||||
|
||||
message = (
|
||||
f"{method:<6} {status_code:>3} {duration_ms:>7.1f}ms "
|
||||
f"{path:<36} client={client}"
|
||||
)
|
||||
if content_length:
|
||||
message += f" bytes={content_length}"
|
||||
status_message(level, message, logger_name="data_entry_app.http")
|
||||
|
||||
|
||||
class RequestTimer:
|
||||
def __init__(self) -> None:
|
||||
self.started = time.perf_counter()
|
||||
|
||||
@property
|
||||
def elapsed_ms(self) -> float:
|
||||
return (time.perf_counter() - self.started) * 1000
|
||||
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from threading import Lock
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlidingWindowRateLimiter:
|
||||
limit: int
|
||||
window_seconds: int
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self._events: dict[str, deque[float]] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def hit(self, key: str) -> None:
|
||||
now = time.time()
|
||||
floor = now - self.window_seconds
|
||||
|
||||
with self._lock:
|
||||
bucket = self._events.setdefault(key, deque())
|
||||
while bucket and bucket[0] <= floor:
|
||||
bucket.popleft()
|
||||
if len(bucket) >= self.limit:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
detail="Too many requests. Please try again later.",
|
||||
)
|
||||
bucket.append(now)
|
||||
|
||||
|
||||
def request_client_key(request: Request, *, suffix: str = "") -> str:
|
||||
forwarded_for = request.headers.get("x-forwarded-for", "")
|
||||
client_ip = forwarded_for.split(",", 1)[0].strip() if forwarded_for else (request.client.host if request.client else "unknown")
|
||||
return f"{client_ip}:{suffix}" if suffix else client_ip
|
||||
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger("data_entry_app.security")
|
||||
|
||||
|
||||
def log_security_event(event: str, **fields: object) -> None:
|
||||
safe_fields = {
|
||||
key: value
|
||||
for key, value in fields.items()
|
||||
if key not in {"password", "token", "cookie", "authorization"}
|
||||
}
|
||||
logger.info("%s | %s", event, safe_fields)
|
||||
Reference in New Issue
Block a user