This commit is contained in:
2026-05-10 09:46:07 +12:00
parent cfc193b713
commit 2f2466ecac
81 changed files with 2571 additions and 413 deletions
+10 -8
View File
@@ -16,18 +16,16 @@ from __future__ import annotations
from typing import Iterable
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi import Depends, HTTPException, Request, status
from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload
from app.core.security import verify_token
from app.core.http import CLIENT_AUTH_COOKIE, get_bearer_or_cookie_token
from app.core.security_logging import log_security_event
from app.db.session import get_db
from app.models.access import Permission, Role, User
bearer_scheme = HTTPBearer(auto_error=False)
# Subject claim used by tokens issued for internal Hunter Stock Feeds users.
# Distinct from the existing client-portal/admin tokens so the two systems
# cannot impersonate each other.
@@ -103,7 +101,7 @@ def _load_user(db: Session, user_id: int) -> User | None:
def get_current_user(
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
request: Request,
db: Session = Depends(get_db),
) -> User:
"""Resolve the current internal user from the bearer token.
@@ -111,10 +109,11 @@ def get_current_user(
Raises 401 for missing/invalid tokens or unknown users, 403 for inactive
users.
"""
if credentials is None:
token = get_bearer_or_cookie_token(request, cookie_name=CLIENT_AUTH_COOKIE.name)
if token is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")
payload = verify_token(credentials.credentials)
payload = verify_token(token)
if payload.get("sub") != INTERNAL_USER_SUBJECT:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication token")
@@ -136,6 +135,7 @@ def require_permission(permission_key: str):
def dependency(user: User = Depends(get_current_user)) -> User:
if not user_has_permission(user, permission_key):
log_security_event("authz.denied", role=user.role.name if user.role else None, permission=permission_key)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing required permission: {permission_key}",
@@ -152,6 +152,7 @@ def require_any_permission(permission_keys: Iterable[str]):
def dependency(user: User = Depends(get_current_user)) -> User:
granted = get_user_permissions(user)
if not any(key in granted for key in keys):
log_security_event("authz.denied", role=user.role.name if user.role else None, permissions=list(keys))
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Requires any of: {list(keys)}",
@@ -169,6 +170,7 @@ def require_all_permissions(permission_keys: Iterable[str]):
granted = get_user_permissions(user)
missing = [key for key in keys if key not in granted]
if missing:
log_security_event("authz.denied", role=user.role.name if user.role else None, permissions=missing)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing required permissions: {missing}",
+69 -1
View File
@@ -16,9 +16,21 @@ def _parse_csv_env(value: str) -> tuple[str, ...]:
return tuple(part.strip() for part in value.split(",") if part.strip())
def _env_flag(name: str, default: bool = False) -> bool:
value = os.getenv(name)
if value is None:
return default
return value.strip().lower() in {"1", "true", "yes", "on"}
@dataclass(frozen=True)
class Settings:
app_name: str
app_env: str
host: str
port: int
log_level: str
log_verbose: bool
database_url: str
client_name: str
client_email: str
@@ -30,11 +42,27 @@ class Settings:
auth_secret: str
cors_allow_origins: tuple[str, ...]
cors_allow_origin_regex: str
session_ttl_seconds: int
session_cookie_name: str
admin_session_cookie_name: str
session_cookie_secure: bool
session_cookie_samesite: str
session_cookie_domain: str | None
request_body_max_bytes: int
login_rate_limit_attempts: int
login_rate_limit_window_seconds: int
trusted_hosts: tuple[str, ...]
docs_enabled: bool
@classmethod
def from_env(cls) -> "Settings":
return cls(
settings = cls(
app_name=os.getenv("APP_NAME", "Data Entry App API"),
app_env=os.getenv("APP_ENV", os.getenv("ENVIRONMENT", "development")),
host=os.getenv("HOST", "0.0.0.0"),
port=int(os.getenv("PORT", "8000")),
log_level=os.getenv("LOG_LEVEL", "DEBUG" if os.getenv("LOG_VERBOSE") in {"1", "true", "TRUE", "yes", "on"} else "INFO"),
log_verbose=_env_flag("LOG_VERBOSE"),
database_url=os.getenv("DATABASE_URL", "sqlite:///./data_entry_app.db"),
client_name=os.getenv("CLIENT_NAME", "Hunter Premium Produce"),
client_email=os.getenv("CLIENT_EMAIL", "operator@example.com"),
@@ -51,7 +79,47 @@ class Settings:
)
),
cors_allow_origin_regex=os.getenv("CORS_ALLOW_ORIGIN_REGEX", DEFAULT_CORS_ALLOW_ORIGIN_REGEX),
session_ttl_seconds=int(os.getenv("SESSION_TTL_SECONDS", str(60 * 60 * 12))),
session_cookie_name=os.getenv("SESSION_COOKIE_NAME", "client_session"),
admin_session_cookie_name=os.getenv("ADMIN_SESSION_COOKIE_NAME", "admin_session"),
session_cookie_secure=_env_flag("SESSION_COOKIE_SECURE"),
session_cookie_samesite=os.getenv("SESSION_COOKIE_SAMESITE", "lax").lower(),
session_cookie_domain=os.getenv("SESSION_COOKIE_DOMAIN", "").strip() or None,
request_body_max_bytes=int(os.getenv("REQUEST_BODY_MAX_BYTES", str(1024 * 1024))),
login_rate_limit_attempts=int(os.getenv("LOGIN_RATE_LIMIT_ATTEMPTS", "8")),
login_rate_limit_window_seconds=int(os.getenv("LOGIN_RATE_LIMIT_WINDOW_SECONDS", "300")),
trusted_hosts=_parse_csv_env(os.getenv("TRUSTED_HOSTS", "localhost,127.0.0.1,testserver")),
docs_enabled=_env_flag("DOCS_ENABLED", default=os.getenv("APP_ENV", os.getenv("ENVIRONMENT", "development")).lower() != "production"),
)
settings._validate()
return settings
def _validate(self) -> None:
if self.session_cookie_samesite not in {"lax", "strict", "none"}:
raise ValueError("SESSION_COOKIE_SAMESITE must be one of: lax, strict, none")
is_production = self.app_env.lower() == "production"
if not is_production:
return
if self.client_password in {"changeme", "", "replace-with-strong-password"}:
raise ValueError("CLIENT_PASSWORD must be set to a non-default value in production")
if self.admin_password in {"lean101-admin", "", "replace-with-strong-password"}:
raise ValueError("ADMIN_PASSWORD must be set to a non-default value in production")
if self.auth_secret in {"lean-101-local-dev-secret", "change-me-in-production", "", "replace-with-a-long-random-secret"}:
raise ValueError("AUTH_SECRET must be set to a strong production secret")
if len(self.auth_secret) < 32:
raise ValueError("AUTH_SECRET must be at least 32 characters in production")
if not self.session_cookie_secure:
raise ValueError("SESSION_COOKIE_SECURE must be enabled in production")
if not self.cors_allow_origins:
raise ValueError("CORS_ALLOW_ORIGINS must explicitly list production origins")
if "localhost" in ",".join(self.cors_allow_origins).lower():
raise ValueError("CORS_ALLOW_ORIGINS cannot include localhost in production")
if self.cors_allow_origin_regex == DEFAULT_CORS_ALLOW_ORIGIN_REGEX:
raise ValueError("CORS_ALLOW_ORIGIN_REGEX must be overridden or blank in production")
if self.docs_enabled:
raise ValueError("DOCS_ENABLED must be false in production")
settings = Settings.from_env()
+51
View File
@@ -0,0 +1,51 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Final
from fastapi import Request, Response
from app.core.config import settings
COOKIE_PATH: Final[str] = "/"
@dataclass(frozen=True)
class AuthCookie:
name: str
def apply(self, response: Response, token: str) -> None:
response.set_cookie(
key=self.name,
value=token,
httponly=True,
secure=settings.session_cookie_secure,
samesite=settings.session_cookie_samesite,
domain=settings.session_cookie_domain,
path=COOKIE_PATH,
max_age=settings.session_ttl_seconds,
)
def clear(self, response: Response) -> None:
response.delete_cookie(
key=self.name,
domain=settings.session_cookie_domain,
path=COOKIE_PATH,
)
CLIENT_AUTH_COOKIE = AuthCookie(settings.session_cookie_name)
ADMIN_AUTH_COOKIE = AuthCookie(settings.admin_session_cookie_name)
def get_bearer_or_cookie_token(request: Request, *, cookie_name: str) -> str | None:
authorization = request.headers.get("authorization", "").strip()
if authorization.lower().startswith("bearer "):
token = authorization[7:].strip()
if token:
return token
cookie_value = request.cookies.get(cookie_name)
if cookie_value:
return cookie_value
return None
+372
View File
@@ -0,0 +1,372 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
import logging
import os
import sys
import time
from typing import Iterable
try:
from rich.console import Console
from rich.logging import RichHandler
from rich.table import Table
from rich.text import Text
except ImportError: # pragma: no cover - exercised only before dependency install
Console = None
RichHandler = None
Table = None
Text = None
@dataclass(frozen=True)
class LoggingSettings:
app_name: str
app_env: str
host: str
port: int
log_level: str
log_verbose: bool
database_url: str
version: str
@dataclass(frozen=True)
class StartupStatus:
app_name: str
version: str
environment: str
host: str
port: int
database: str
mode: str
started_at: str
local_url: str
network_url: str
class PlainFormatter(logging.Formatter):
default_time_format = "%Y-%m-%d %H:%M:%S"
def format(self, record: logging.LogRecord) -> str:
if not hasattr(record, "component"):
record.component = record.name.rsplit(".", 1)[-1]
return super().format(record)
def _allow_color() -> bool:
if RichHandler is None or Console is None:
return False
if os.getenv("NO_COLOR"):
return False
if os.getenv("TERM") == "dumb":
return False
return hasattr(sys.stderr, "isatty") and sys.stderr.isatty()
def _allow_unicode() -> bool:
encoding = (getattr(sys.stdout, "encoding", None) or "").lower()
if not encoding:
return False
return "utf" in encoding
def _console() -> Console:
return Console(
stderr=True,
soft_wrap=False,
highlight=False,
force_terminal=_allow_color(),
no_color=not _allow_color(),
emoji=False,
)
def _rich_handler(level: str) -> RichHandler:
return RichHandler(
level=level,
console=_console(),
show_time=True,
show_level=True,
show_path=False,
omit_repeated_times=False,
markup=True,
rich_tracebacks=True,
tracebacks_show_locals=False,
log_time_format="%H:%M:%S",
)
def _plain_handler(level: str) -> logging.StreamHandler:
handler = logging.StreamHandler()
handler.setLevel(level)
handler.setFormatter(
PlainFormatter("%(asctime)s | %(levelname)-7s | %(component)-10s | %(message)s")
)
return handler
def _handler(level: str) -> logging.Handler:
return _rich_handler(level) if _allow_color() else _plain_handler(level)
def configure_logging(settings: LoggingSettings) -> None:
level = settings.log_level.upper()
root = logging.getLogger()
root.handlers.clear()
root.setLevel(level)
root.addHandler(_handler(level))
for name in ("uvicorn", "uvicorn.error", "fastapi"):
logger = logging.getLogger(name)
logger.handlers.clear()
logger.setLevel(level)
logger.propagate = True
access_logger = logging.getLogger("uvicorn.access")
access_logger.handlers.clear()
access_logger.propagate = False
access_logger.disabled = True
def get_logger(name: str) -> logging.LoggerAdapter[logging.Logger]:
component = name.rsplit(".", 1)[-1]
return logging.LoggerAdapter(logging.getLogger(name), {"component": component})
def _icon(name: str) -> str:
ascii_icons = {
"app": "#",
"info": "i",
"success": "+",
"warning": "!",
"error": "x",
"debug": ".",
"section": "=",
"url": ">",
"shutdown": "-",
}
unicode_icons = {
"app": "",
"info": "",
"success": "",
"warning": "",
"error": "",
"debug": "",
"section": "",
"url": "",
"shutdown": "",
}
icons = unicode_icons if _allow_unicode() else ascii_icons
return icons[name]
def _style(name: str) -> str:
return {
"info": "bold cyan",
"success": "bold green",
"warning": "bold yellow",
"error": "bold red",
"debug": "dim",
"section": "bold bright_blue",
"muted": "grey62",
}[name]
def section_heading(title: str) -> None:
logger = get_logger("data_entry_app.section")
if _allow_color():
_console().rule(Text(f" {title.upper()} ", style=_style("section")))
return
logger.info("%s %s %s", _icon("section") * 10, title.upper(), _icon("section") * 10)
def startup_banner(status: StartupStatus) -> None:
logger = get_logger("data_entry_app.startup")
if _allow_color():
console = _console()
table = Table.grid(expand=False)
table.add_column(style="bold white", justify="left")
table.add_column(style="white", justify="left")
table.add_row("Environment", status.environment)
table.add_row("Version", status.version)
table.add_row("Host", status.host)
table.add_row("Port", str(status.port))
table.add_row("Database", status.database)
table.add_row("Mode", status.mode)
table.add_row("Started", status.started_at)
console.rule(Text(f" {status.app_name} ", style="bold white"))
console.print(Text("Clean startup. Clear status. Ready.", style="italic cyan"))
console.print(table)
console.print()
console.print(Text("App is running at:", style="bold white"))
console.print(Text(f" Local: {status.local_url}", style="cyan"))
console.print(Text(f" Network: {status.network_url}", style="cyan"))
console.print()
return
logger.info("%s %s", _icon("app"), "Startup banner")
logger.info("App : %s", status.app_name)
logger.info("Environment : %s", status.environment)
logger.info("Version : %s", status.version)
logger.info("Host : %s", status.host)
logger.info("Port : %s", status.port)
logger.info("Database : %s", status.database)
logger.info("Mode : %s", status.mode)
logger.info("Started : %s", status.started_at)
logger.info("Local : %s", status.local_url)
logger.info("Network : %s", status.network_url)
def status_message(level: str, message: str, *args: object, logger_name: str = "data_entry_app.status") -> None:
palette = {
"debug": logging.DEBUG,
"info": logging.INFO,
"success": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
}
labels = {
"debug": f"[{_icon('debug')}]",
"info": f"[{_icon('info')}]",
"success": f"[{_icon('success')}]",
"warning": f"[{_icon('warning')}]",
"error": f"[{_icon('error')}]",
}
styles = {
"debug": _style("debug"),
"info": _style("info"),
"success": _style("success"),
"warning": _style("warning"),
"error": _style("error"),
}
logger = get_logger(logger_name)
rendered = message % args if args else message
if _allow_color():
logger.log(palette[level], f"[{styles[level]}]{labels[level]}[/] {rendered}")
else:
logger.log(palette[level], "%s %s", labels[level], rendered)
def success(message: str, *args: object, logger_name: str = "data_entry_app.status") -> None:
status_message("success", message, *args, logger_name=logger_name)
def warning(message: str, *args: object, logger_name: str = "data_entry_app.status") -> None:
status_message("warning", message, *args, logger_name=logger_name)
def info(message: str, *args: object, logger_name: str = "data_entry_app.status") -> None:
status_message("info", message, *args, logger_name=logger_name)
def debug(message: str, *args: object, logger_name: str = "data_entry_app.status") -> None:
status_message("debug", message, *args, logger_name=logger_name)
def fatal(message: str, *args: object, exc_info: bool = False, logger_name: str = "data_entry_app.status") -> None:
logger = get_logger(logger_name)
rendered = message % args if args else message
if _allow_color():
logger.error(f"[{_style('error')}][{_icon('error')}][/] {rendered}", exc_info=exc_info)
else:
logger.error("[%s] %s", _icon("error"), rendered, exc_info=exc_info)
def shutdown_summary(*, uptime_seconds: float, requests_served: int, host: str, port: int) -> None:
section_heading("Shutdown")
logger = get_logger("data_entry_app.shutdown")
summary = f"Uptime {uptime_seconds:.1f}s | Requests {requests_served} | Endpoint http://{host}:{port}"
if _allow_color():
logger.info(f"[{_style('debug')}]{_icon('shutdown')}[/] {summary}")
else:
logger.info("%s %s", _icon("shutdown"), summary)
def describe_database(url: str) -> str:
if url.startswith("sqlite"):
return "sqlite"
if "postgresql" in url:
return "postgresql"
if "mysql" in url:
return "mysql"
return url.split(":", 1)[0]
def sanitize_database_target(url: str) -> str:
if url.startswith("sqlite:///"):
return url.removeprefix("sqlite:///")
if "@" in url:
return url.split("@", 1)[1]
return url
def startup_status(settings: LoggingSettings) -> StartupStatus:
host = settings.host
local_host = "localhost" if host in {"0.0.0.0", "::"} else host
timestamp = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M:%S %Z")
return StartupStatus(
app_name=settings.app_name,
version=settings.version,
environment=settings.app_env,
host=settings.host,
port=settings.port,
database=f"{describe_database(settings.database_url)} ({sanitize_database_target(settings.database_url)})",
mode="verbose" if settings.log_verbose else "normal",
started_at=timestamp,
local_url=f"http://{local_host}:{settings.port}",
network_url=f"http://{host}:{settings.port}",
)
def route_summary(routes: Iterable[object]) -> tuple[int, list[str]]:
lines: list[str] = []
count = 0
for route in routes:
path = getattr(route, "path", None)
methods = getattr(route, "methods", None)
if not path or not methods:
continue
filtered_methods = sorted(method for method in methods if method not in {"HEAD", "OPTIONS"})
if not filtered_methods:
continue
count += 1
lines.append(f"{','.join(filtered_methods):<7} {path}")
return count, lines
def log_request(
*,
method: str,
path: str,
status_code: int,
duration_ms: float,
client: str,
content_length: str | None,
) -> None:
level = "info"
if status_code >= 500:
level = "error"
elif status_code >= 400:
level = "warning"
elif path == "/health":
level = "debug"
message = (
f"{method:<6} {status_code:>3} {duration_ms:>7.1f}ms "
f"{path:<36} client={client}"
)
if content_length:
message += f" bytes={content_length}"
status_message(level, message, logger_name="data_entry_app.http")
class RequestTimer:
def __init__(self) -> None:
self.started = time.perf_counter()
@property
def elapsed_ms(self) -> float:
return (time.perf_counter() - self.started) * 1000
+39
View File
@@ -0,0 +1,39 @@
from __future__ import annotations
import time
from collections import deque
from dataclasses import dataclass
from threading import Lock
from fastapi import HTTPException, Request, status
@dataclass
class SlidingWindowRateLimiter:
limit: int
window_seconds: int
def __post_init__(self) -> None:
self._events: dict[str, deque[float]] = {}
self._lock = Lock()
def hit(self, key: str) -> None:
now = time.time()
floor = now - self.window_seconds
with self._lock:
bucket = self._events.setdefault(key, deque())
while bucket and bucket[0] <= floor:
bucket.popleft()
if len(bucket) >= self.limit:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Too many requests. Please try again later.",
)
bucket.append(now)
def request_client_key(request: Request, *, suffix: str = "") -> str:
forwarded_for = request.headers.get("x-forwarded-for", "")
client_ip = forwarded_for.split(",", 1)[0].strip() if forwarded_for else (request.client.host if request.client else "unknown")
return f"{client_ip}:{suffix}" if suffix else client_ip
+15
View File
@@ -0,0 +1,15 @@
from __future__ import annotations
import logging
logger = logging.getLogger("data_entry_app.security")
def log_security_event(event: str, **fields: object) -> None:
safe_fields = {
key: value
for key, value in fields.items()
if key not in {"password", "token", "cookie", "authorization"}
}
logger.info("%s | %s", event, safe_fields)