v1.4 - Login fixes, etc

This commit is contained in:
2026-04-27 21:53:36 +12:00
parent 8cf9bfb441
commit c9580ac2eb
33 changed files with 2283 additions and 202 deletions
+23
View File
@@ -2,6 +2,20 @@ import os
from dataclasses import dataclass
DEFAULT_CORS_ALLOW_ORIGIN_REGEX = (
r"^https?://("
r"localhost|127\.0\.0\.1|"
r"10\.\d{1,3}\.\d{1,3}\.\d{1,3}|"
r"192\.168\.\d{1,3}\.\d{1,3}|"
r"172\.(1[6-9]|2\d|3[0-1])\.\d{1,3}\.\d{1,3}"
r")(:\d+)?$"
)
def _parse_csv_env(value: str) -> tuple[str, ...]:
return tuple(part.strip() for part in value.split(",") if part.strip())
@dataclass(frozen=True)
class Settings:
app_name: str
@@ -14,6 +28,8 @@ class Settings:
admin_email: str
admin_password: str
auth_secret: str
cors_allow_origins: tuple[str, ...]
cors_allow_origin_regex: str
@classmethod
def from_env(cls) -> "Settings":
@@ -28,6 +44,13 @@ class Settings:
admin_email=os.getenv("ADMIN_EMAIL", "admin@lean101.local"),
admin_password=os.getenv("ADMIN_PASSWORD", "lean101-admin"),
auth_secret=os.getenv("AUTH_SECRET", "lean-101-local-dev-secret"),
cors_allow_origins=_parse_csv_env(
os.getenv(
"CORS_ALLOW_ORIGINS",
"http://localhost:5173,http://localhost:5174,http://127.0.0.1:5173,http://127.0.0.1:5174",
)
),
cors_allow_origin_regex=os.getenv("CORS_ALLOW_ORIGIN_REGEX", DEFAULT_CORS_ALLOW_ORIGIN_REGEX),
)
+106 -21
View File
@@ -1,6 +1,8 @@
from __future__ import annotations
from sqlalchemy import inspect, text
from dataclasses import dataclass, field
from sqlalchemy import MetaData, inspect, text
from sqlalchemy.engine import Engine
@@ -20,6 +22,27 @@ TENANT_TABLES = {
}
@dataclass(frozen=True)
class MigrationReport:
created_tables: tuple[str, ...] = ()
added_columns: tuple[str, ...] = ()
synced_tenant_rows: dict[str, int] = field(default_factory=dict)
def has_changes(self) -> bool:
return bool(self.created_tables or self.added_columns or self.synced_tenant_rows)
def summary(self) -> str:
parts: list[str] = []
if self.created_tables:
parts.append(f"created tables: {', '.join(self.created_tables)}")
if self.added_columns:
parts.append(f"patched columns: {', '.join(self.added_columns)}")
if self.synced_tenant_rows:
counts = ", ".join(f"{table}={count}" for table, count in sorted(self.synced_tenant_rows.items()))
parts.append(f"synced tenant rows: {counts}")
return "; ".join(parts) if parts else "schema already up to date"
def _has_column(engine: Engine, table_name: str, column_name: str) -> bool:
inspector = inspect(engine)
try:
@@ -40,25 +63,41 @@ def _table_exists(engine: Engine, table_name: str) -> bool:
return inspect(engine).has_table(table_name)
def ensure_tenant_columns(engine: Engine) -> None:
def ensure_metadata_tables(engine: Engine, metadata: MetaData) -> tuple[str, ...]:
missing_tables = tuple(table.name for table in metadata.sorted_tables if not _table_exists(engine, table.name))
if missing_tables:
metadata.create_all(bind=engine)
return missing_tables
def ensure_tenant_columns(engine: Engine) -> tuple[str, ...]:
added_columns: list[str] = []
for table_name in TENANT_TABLES:
if _table_exists(engine, table_name):
_add_tenant_column(engine, table_name)
if not _has_column(engine, table_name, "tenant_id"):
_add_tenant_column(engine, table_name)
added_columns.append(f"{table_name}.tenant_id")
return tuple(added_columns)
def sync_tenant_ids(engine: Engine) -> None:
if not _table_exists(engine, "client_accounts"):
return
def sync_tenant_ids(engine: Engine) -> dict[str, int]:
existing_tables = set(inspect(engine).get_table_names())
if "client_accounts" not in existing_tables:
return {}
synced_rows: dict[str, int] = {}
with engine.begin() as connection:
default_tenant = connection.execute(
text("SELECT tenant_id FROM client_accounts ORDER BY id LIMIT 1")
).scalar_one_or_none()
if not default_tenant:
return
return {}
statements = [
text(
(
"client_users",
text(
"""
UPDATE client_users
SET tenant_id = (
@@ -68,8 +107,11 @@ def sync_tenant_ids(engine: Engine) -> None:
)
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
"""
),
),
text(
(
"client_feature_access",
text(
"""
UPDATE client_feature_access
SET tenant_id = (
@@ -79,15 +121,21 @@ def sync_tenant_ids(engine: Engine) -> None:
)
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
"""
),
),
text(
(
"raw_materials",
text(
"""
UPDATE raw_materials
SET tenant_id = :default_tenant
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
"""
),
),
text(
(
"raw_material_price_versions",
text(
"""
UPDATE raw_material_price_versions
SET tenant_id = (
@@ -97,8 +145,11 @@ def sync_tenant_ids(engine: Engine) -> None:
)
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
"""
),
),
text(
(
"mixes",
text(
"""
UPDATE mixes
SET tenant_id = COALESCE(
@@ -111,8 +162,11 @@ def sync_tenant_ids(engine: Engine) -> None:
)
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
"""
),
),
text(
(
"mix_ingredients",
text(
"""
UPDATE mix_ingredients
SET tenant_id = (
@@ -122,8 +176,11 @@ def sync_tenant_ids(engine: Engine) -> None:
)
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
"""
),
),
text(
(
"products",
text(
"""
UPDATE products
SET tenant_id = COALESCE(
@@ -141,15 +198,21 @@ def sync_tenant_ids(engine: Engine) -> None:
)
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
"""
),
),
text(
(
"scenarios",
text(
"""
UPDATE scenarios
SET tenant_id = :default_tenant
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
"""
),
),
text(
(
"costing_results",
text(
"""
UPDATE costing_results
SET tenant_id = COALESCE(
@@ -167,29 +230,51 @@ def sync_tenant_ids(engine: Engine) -> None:
)
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
"""
),
),
text(
(
"process_cost_rules",
text(
"""
UPDATE process_cost_rules
SET tenant_id = :default_tenant
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
"""
),
),
text(
(
"packaging_cost_rules",
text(
"""
UPDATE packaging_cost_rules
SET tenant_id = :default_tenant
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
"""
),
),
text(
(
"freight_cost_rules",
text(
"""
UPDATE freight_cost_rules
SET tenant_id = :default_tenant
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
"""
),
),
]
for statement in statements:
connection.execute(statement, {"default_tenant": default_tenant})
for table_name, statement in statements:
if table_name not in existing_tables:
continue
result = connection.execute(statement, {"default_tenant": default_tenant})
if result.rowcount and result.rowcount > 0:
synced_rows[table_name] = result.rowcount
return synced_rows
def bootstrap_schema(engine: Engine, metadata: MetaData) -> MigrationReport:
created_tables = ensure_metadata_tables(engine, metadata)
added_columns = ensure_tenant_columns(engine)
return MigrationReport(created_tables=created_tables, added_columns=added_columns)
+36 -6
View File
@@ -1,7 +1,9 @@
import logging
import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from threading import Lock
if __package__ in {None, ""}:
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
@@ -19,16 +21,41 @@ from app.api.raw_materials import router as raw_materials_router
from app.api.scenarios import router as scenarios_router
from app.core.config import settings
from app.db.session import Base, engine
from app.db.migrations import ensure_tenant_columns, sync_tenant_ids
from app.db.migrations import MigrationReport, bootstrap_schema, sync_tenant_ids
from app.seed import seed_if_empty
logger = logging.getLogger("data_entry_app.startup")
_database_ready = False
_database_ready_lock = Lock()
def ensure_database_ready() -> MigrationReport:
global _database_ready
if _database_ready:
return MigrationReport()
with _database_ready_lock:
if _database_ready:
return MigrationReport()
schema_report = bootstrap_schema(engine, Base.metadata)
seed_if_empty()
tenant_sync_report = sync_tenant_ids(engine)
report = MigrationReport(
created_tables=schema_report.created_tables,
added_columns=schema_report.added_columns,
synced_tenant_rows=tenant_sync_report,
)
logger.info("Database startup checks complete: %s", report.summary())
_database_ready = True
return report
@asynccontextmanager
async def lifespan(_: FastAPI):
Base.metadata.create_all(bind=engine)
ensure_tenant_columns(engine)
seed_if_empty()
sync_tenant_ids(engine)
ensure_database_ready()
yield
@@ -36,7 +63,8 @@ app = FastAPI(title=settings.app_name, lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:5173", "http://localhost:5174"],
allow_origins=list(settings.cors_allow_origins),
allow_origin_regex=settings.cors_allow_origin_regex,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
@@ -81,6 +109,8 @@ def healthcheck():
if __name__ == "__main__":
report = ensure_database_ready()
print(f"Database startup checks complete: {report.summary()}")
uvicorn.run(
app,
host=os.getenv("HOST", "0.0.0.0"),
+2 -1
View File
@@ -56,6 +56,8 @@ class ProductRead(BaseModel):
class ProductCostBreakdown(BaseModel):
product_id: int
product_name: str
client_name: str
mix_name: str
cleaned_product_cost: float
grading_cost: float
bagging_cost: float
@@ -67,4 +69,3 @@ class ProductCostBreakdown(BaseModel):
wholesale_price: float | None
warnings: list[str]
inputs: dict[str, object]
+2
View File
@@ -224,6 +224,8 @@ def calculate_product_cost(db: Session, product_id: int, overrides: dict | None
return {
"product_id": product.id,
"product_name": product.name,
"client_name": product.client_name,
"mix_name": product.mix.name if product.mix else "",
"cleaned_product_cost": round(cleaned_product_cost, 4),
"grading_cost": round(grading_cost, 4),
"bagging_cost": round(bagging_cost, 4),
+76 -1
View File
@@ -1,10 +1,12 @@
from datetime import date
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy import create_engine, inspect, text
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool
from app.core.config import settings
from app.db.migrations import bootstrap_schema, sync_tenant_ids
from app.db.session import Base
from app.main import app
from app.models.assumption import FreightCostRule, PackagingCostRule, ProcessCostRule
@@ -77,6 +79,8 @@ def test_mix_and_product_cost_breakdown():
assert mix_result["total_mix_kg"] == 280
assert mix_result["mix_cost_per_kg"] == 0.5114
assert product_result["client_name"] == "Specialty Feeds"
assert product_result["mix_name"] == "Pigeon Mix"
assert product_result["finished_product_delivered"] == 14.208
assert product_result["distributor_price"] == 18.3329
assert product_result["wholesale_price"] == 17.3268
@@ -181,3 +185,74 @@ def test_client_access_endpoints():
export_response = client.get("/api/powerbi/client-access", headers=headers)
assert export_response.status_code == 200
assert "client_rows" in export_response.json()
def test_bootstrap_schema_creates_missing_tables_and_patches_legacy_tenant_columns():
engine = create_engine(
"sqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
with engine.begin() as connection:
connection.execute(
text(
"""
CREATE TABLE client_accounts (
id INTEGER PRIMARY KEY,
tenant_id VARCHAR(64),
name VARCHAR(255),
client_code VARCHAR(64),
status VARCHAR(32),
powerbi_workspace VARCHAR(128),
notes TEXT,
created_at DATETIME
)
"""
)
)
connection.execute(
text(
"""
CREATE TABLE raw_materials (
id INTEGER PRIMARY KEY,
name VARCHAR(255),
supplier VARCHAR(255),
unit_of_measure VARCHAR(64),
kg_per_unit FLOAT,
status VARCHAR(32),
notes TEXT,
created_at DATETIME
)
"""
)
)
connection.execute(
text(
"""
INSERT INTO client_accounts (id, tenant_id, name, client_code, status)
VALUES (1, 'specialty-feeds', 'Specialty Feeds', 'SPEC', 'active')
"""
)
)
connection.execute(
text(
"""
INSERT INTO raw_materials (id, name, supplier, unit_of_measure, kg_per_unit, status)
VALUES (1, 'Maize', 'Example Supplier', 'tonne', 1000, 'active')
"""
)
)
report = bootstrap_schema(engine, Base.metadata)
synced_rows = sync_tenant_ids(engine)
assert "products" in report.created_tables
assert "raw_materials.tenant_id" in report.added_columns
assert "tenant_id" in {column["name"] for column in inspect(engine).get_columns("raw_materials")}
assert synced_rows["raw_materials"] == 1
with engine.begin() as connection:
tenant_id = connection.execute(text("SELECT tenant_id FROM raw_materials WHERE id = 1")).scalar_one()
assert tenant_id == "specialty-feeds"