v1.4 - Login fixes, etc
This commit is contained in:
@@ -2,6 +2,20 @@ import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
DEFAULT_CORS_ALLOW_ORIGIN_REGEX = (
|
||||
r"^https?://("
|
||||
r"localhost|127\.0\.0\.1|"
|
||||
r"10\.\d{1,3}\.\d{1,3}\.\d{1,3}|"
|
||||
r"192\.168\.\d{1,3}\.\d{1,3}|"
|
||||
r"172\.(1[6-9]|2\d|3[0-1])\.\d{1,3}\.\d{1,3}"
|
||||
r")(:\d+)?$"
|
||||
)
|
||||
|
||||
|
||||
def _parse_csv_env(value: str) -> tuple[str, ...]:
|
||||
return tuple(part.strip() for part in value.split(",") if part.strip())
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Settings:
|
||||
app_name: str
|
||||
@@ -14,6 +28,8 @@ class Settings:
|
||||
admin_email: str
|
||||
admin_password: str
|
||||
auth_secret: str
|
||||
cors_allow_origins: tuple[str, ...]
|
||||
cors_allow_origin_regex: str
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "Settings":
|
||||
@@ -28,6 +44,13 @@ class Settings:
|
||||
admin_email=os.getenv("ADMIN_EMAIL", "admin@lean101.local"),
|
||||
admin_password=os.getenv("ADMIN_PASSWORD", "lean101-admin"),
|
||||
auth_secret=os.getenv("AUTH_SECRET", "lean-101-local-dev-secret"),
|
||||
cors_allow_origins=_parse_csv_env(
|
||||
os.getenv(
|
||||
"CORS_ALLOW_ORIGINS",
|
||||
"http://localhost:5173,http://localhost:5174,http://127.0.0.1:5173,http://127.0.0.1:5174",
|
||||
)
|
||||
),
|
||||
cors_allow_origin_regex=os.getenv("CORS_ALLOW_ORIGIN_REGEX", DEFAULT_CORS_ALLOW_ORIGIN_REGEX),
|
||||
)
|
||||
|
||||
|
||||
|
||||
+106
-21
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import inspect, text
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from sqlalchemy import MetaData, inspect, text
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
|
||||
@@ -20,6 +22,27 @@ TENANT_TABLES = {
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MigrationReport:
|
||||
created_tables: tuple[str, ...] = ()
|
||||
added_columns: tuple[str, ...] = ()
|
||||
synced_tenant_rows: dict[str, int] = field(default_factory=dict)
|
||||
|
||||
def has_changes(self) -> bool:
|
||||
return bool(self.created_tables or self.added_columns or self.synced_tenant_rows)
|
||||
|
||||
def summary(self) -> str:
|
||||
parts: list[str] = []
|
||||
if self.created_tables:
|
||||
parts.append(f"created tables: {', '.join(self.created_tables)}")
|
||||
if self.added_columns:
|
||||
parts.append(f"patched columns: {', '.join(self.added_columns)}")
|
||||
if self.synced_tenant_rows:
|
||||
counts = ", ".join(f"{table}={count}" for table, count in sorted(self.synced_tenant_rows.items()))
|
||||
parts.append(f"synced tenant rows: {counts}")
|
||||
return "; ".join(parts) if parts else "schema already up to date"
|
||||
|
||||
|
||||
def _has_column(engine: Engine, table_name: str, column_name: str) -> bool:
|
||||
inspector = inspect(engine)
|
||||
try:
|
||||
@@ -40,25 +63,41 @@ def _table_exists(engine: Engine, table_name: str) -> bool:
|
||||
return inspect(engine).has_table(table_name)
|
||||
|
||||
|
||||
def ensure_tenant_columns(engine: Engine) -> None:
|
||||
def ensure_metadata_tables(engine: Engine, metadata: MetaData) -> tuple[str, ...]:
|
||||
missing_tables = tuple(table.name for table in metadata.sorted_tables if not _table_exists(engine, table.name))
|
||||
if missing_tables:
|
||||
metadata.create_all(bind=engine)
|
||||
return missing_tables
|
||||
|
||||
|
||||
def ensure_tenant_columns(engine: Engine) -> tuple[str, ...]:
|
||||
added_columns: list[str] = []
|
||||
for table_name in TENANT_TABLES:
|
||||
if _table_exists(engine, table_name):
|
||||
_add_tenant_column(engine, table_name)
|
||||
if not _has_column(engine, table_name, "tenant_id"):
|
||||
_add_tenant_column(engine, table_name)
|
||||
added_columns.append(f"{table_name}.tenant_id")
|
||||
return tuple(added_columns)
|
||||
|
||||
|
||||
def sync_tenant_ids(engine: Engine) -> None:
|
||||
if not _table_exists(engine, "client_accounts"):
|
||||
return
|
||||
def sync_tenant_ids(engine: Engine) -> dict[str, int]:
|
||||
existing_tables = set(inspect(engine).get_table_names())
|
||||
|
||||
if "client_accounts" not in existing_tables:
|
||||
return {}
|
||||
|
||||
synced_rows: dict[str, int] = {}
|
||||
with engine.begin() as connection:
|
||||
default_tenant = connection.execute(
|
||||
text("SELECT tenant_id FROM client_accounts ORDER BY id LIMIT 1")
|
||||
).scalar_one_or_none()
|
||||
if not default_tenant:
|
||||
return
|
||||
return {}
|
||||
|
||||
statements = [
|
||||
text(
|
||||
(
|
||||
"client_users",
|
||||
text(
|
||||
"""
|
||||
UPDATE client_users
|
||||
SET tenant_id = (
|
||||
@@ -68,8 +107,11 @@ def sync_tenant_ids(engine: Engine) -> None:
|
||||
)
|
||||
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
|
||||
"""
|
||||
),
|
||||
),
|
||||
text(
|
||||
(
|
||||
"client_feature_access",
|
||||
text(
|
||||
"""
|
||||
UPDATE client_feature_access
|
||||
SET tenant_id = (
|
||||
@@ -79,15 +121,21 @@ def sync_tenant_ids(engine: Engine) -> None:
|
||||
)
|
||||
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
|
||||
"""
|
||||
),
|
||||
),
|
||||
text(
|
||||
(
|
||||
"raw_materials",
|
||||
text(
|
||||
"""
|
||||
UPDATE raw_materials
|
||||
SET tenant_id = :default_tenant
|
||||
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
|
||||
"""
|
||||
),
|
||||
),
|
||||
text(
|
||||
(
|
||||
"raw_material_price_versions",
|
||||
text(
|
||||
"""
|
||||
UPDATE raw_material_price_versions
|
||||
SET tenant_id = (
|
||||
@@ -97,8 +145,11 @@ def sync_tenant_ids(engine: Engine) -> None:
|
||||
)
|
||||
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
|
||||
"""
|
||||
),
|
||||
),
|
||||
text(
|
||||
(
|
||||
"mixes",
|
||||
text(
|
||||
"""
|
||||
UPDATE mixes
|
||||
SET tenant_id = COALESCE(
|
||||
@@ -111,8 +162,11 @@ def sync_tenant_ids(engine: Engine) -> None:
|
||||
)
|
||||
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
|
||||
"""
|
||||
),
|
||||
),
|
||||
text(
|
||||
(
|
||||
"mix_ingredients",
|
||||
text(
|
||||
"""
|
||||
UPDATE mix_ingredients
|
||||
SET tenant_id = (
|
||||
@@ -122,8 +176,11 @@ def sync_tenant_ids(engine: Engine) -> None:
|
||||
)
|
||||
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
|
||||
"""
|
||||
),
|
||||
),
|
||||
text(
|
||||
(
|
||||
"products",
|
||||
text(
|
||||
"""
|
||||
UPDATE products
|
||||
SET tenant_id = COALESCE(
|
||||
@@ -141,15 +198,21 @@ def sync_tenant_ids(engine: Engine) -> None:
|
||||
)
|
||||
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
|
||||
"""
|
||||
),
|
||||
),
|
||||
text(
|
||||
(
|
||||
"scenarios",
|
||||
text(
|
||||
"""
|
||||
UPDATE scenarios
|
||||
SET tenant_id = :default_tenant
|
||||
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
|
||||
"""
|
||||
),
|
||||
),
|
||||
text(
|
||||
(
|
||||
"costing_results",
|
||||
text(
|
||||
"""
|
||||
UPDATE costing_results
|
||||
SET tenant_id = COALESCE(
|
||||
@@ -167,29 +230,51 @@ def sync_tenant_ids(engine: Engine) -> None:
|
||||
)
|
||||
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
|
||||
"""
|
||||
),
|
||||
),
|
||||
text(
|
||||
(
|
||||
"process_cost_rules",
|
||||
text(
|
||||
"""
|
||||
UPDATE process_cost_rules
|
||||
SET tenant_id = :default_tenant
|
||||
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
|
||||
"""
|
||||
),
|
||||
),
|
||||
text(
|
||||
(
|
||||
"packaging_cost_rules",
|
||||
text(
|
||||
"""
|
||||
UPDATE packaging_cost_rules
|
||||
SET tenant_id = :default_tenant
|
||||
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
|
||||
"""
|
||||
),
|
||||
),
|
||||
text(
|
||||
(
|
||||
"freight_cost_rules",
|
||||
text(
|
||||
"""
|
||||
UPDATE freight_cost_rules
|
||||
SET tenant_id = :default_tenant
|
||||
WHERE tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'
|
||||
"""
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
for statement in statements:
|
||||
connection.execute(statement, {"default_tenant": default_tenant})
|
||||
for table_name, statement in statements:
|
||||
if table_name not in existing_tables:
|
||||
continue
|
||||
result = connection.execute(statement, {"default_tenant": default_tenant})
|
||||
if result.rowcount and result.rowcount > 0:
|
||||
synced_rows[table_name] = result.rowcount
|
||||
|
||||
return synced_rows
|
||||
|
||||
|
||||
def bootstrap_schema(engine: Engine, metadata: MetaData) -> MigrationReport:
|
||||
created_tables = ensure_metadata_tables(engine, metadata)
|
||||
added_columns = ensure_tenant_columns(engine)
|
||||
return MigrationReport(created_tables=created_tables, added_columns=added_columns)
|
||||
|
||||
+36
-6
@@ -1,7 +1,9 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
if __package__ in {None, ""}:
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||
@@ -19,16 +21,41 @@ from app.api.raw_materials import router as raw_materials_router
|
||||
from app.api.scenarios import router as scenarios_router
|
||||
from app.core.config import settings
|
||||
from app.db.session import Base, engine
|
||||
from app.db.migrations import ensure_tenant_columns, sync_tenant_ids
|
||||
from app.db.migrations import MigrationReport, bootstrap_schema, sync_tenant_ids
|
||||
from app.seed import seed_if_empty
|
||||
|
||||
logger = logging.getLogger("data_entry_app.startup")
|
||||
_database_ready = False
|
||||
_database_ready_lock = Lock()
|
||||
|
||||
|
||||
def ensure_database_ready() -> MigrationReport:
|
||||
global _database_ready
|
||||
|
||||
if _database_ready:
|
||||
return MigrationReport()
|
||||
|
||||
with _database_ready_lock:
|
||||
if _database_ready:
|
||||
return MigrationReport()
|
||||
|
||||
schema_report = bootstrap_schema(engine, Base.metadata)
|
||||
seed_if_empty()
|
||||
tenant_sync_report = sync_tenant_ids(engine)
|
||||
|
||||
report = MigrationReport(
|
||||
created_tables=schema_report.created_tables,
|
||||
added_columns=schema_report.added_columns,
|
||||
synced_tenant_rows=tenant_sync_report,
|
||||
)
|
||||
logger.info("Database startup checks complete: %s", report.summary())
|
||||
_database_ready = True
|
||||
return report
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_: FastAPI):
|
||||
Base.metadata.create_all(bind=engine)
|
||||
ensure_tenant_columns(engine)
|
||||
seed_if_empty()
|
||||
sync_tenant_ids(engine)
|
||||
ensure_database_ready()
|
||||
yield
|
||||
|
||||
|
||||
@@ -36,7 +63,8 @@ app = FastAPI(title=settings.app_name, lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:5173", "http://localhost:5174"],
|
||||
allow_origins=list(settings.cors_allow_origins),
|
||||
allow_origin_regex=settings.cors_allow_origin_regex,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
@@ -81,6 +109,8 @@ def healthcheck():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
report = ensure_database_ready()
|
||||
print(f"Database startup checks complete: {report.summary()}")
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=os.getenv("HOST", "0.0.0.0"),
|
||||
|
||||
@@ -56,6 +56,8 @@ class ProductRead(BaseModel):
|
||||
class ProductCostBreakdown(BaseModel):
|
||||
product_id: int
|
||||
product_name: str
|
||||
client_name: str
|
||||
mix_name: str
|
||||
cleaned_product_cost: float
|
||||
grading_cost: float
|
||||
bagging_cost: float
|
||||
@@ -67,4 +69,3 @@ class ProductCostBreakdown(BaseModel):
|
||||
wholesale_price: float | None
|
||||
warnings: list[str]
|
||||
inputs: dict[str, object]
|
||||
|
||||
|
||||
@@ -224,6 +224,8 @@ def calculate_product_cost(db: Session, product_id: int, overrides: dict | None
|
||||
return {
|
||||
"product_id": product.id,
|
||||
"product_name": product.name,
|
||||
"client_name": product.client_name,
|
||||
"mix_name": product.mix.name if product.mix else "",
|
||||
"cleaned_product_cost": round(cleaned_product_cost, 4),
|
||||
"grading_cost": round(grading_cost, 4),
|
||||
"bagging_cost": round(bagging_cost, 4),
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from datetime import date
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import create_engine, inspect, text
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db.migrations import bootstrap_schema, sync_tenant_ids
|
||||
from app.db.session import Base
|
||||
from app.main import app
|
||||
from app.models.assumption import FreightCostRule, PackagingCostRule, ProcessCostRule
|
||||
@@ -77,6 +79,8 @@ def test_mix_and_product_cost_breakdown():
|
||||
|
||||
assert mix_result["total_mix_kg"] == 280
|
||||
assert mix_result["mix_cost_per_kg"] == 0.5114
|
||||
assert product_result["client_name"] == "Specialty Feeds"
|
||||
assert product_result["mix_name"] == "Pigeon Mix"
|
||||
assert product_result["finished_product_delivered"] == 14.208
|
||||
assert product_result["distributor_price"] == 18.3329
|
||||
assert product_result["wholesale_price"] == 17.3268
|
||||
@@ -181,3 +185,74 @@ def test_client_access_endpoints():
|
||||
export_response = client.get("/api/powerbi/client-access", headers=headers)
|
||||
assert export_response.status_code == 200
|
||||
assert "client_rows" in export_response.json()
|
||||
|
||||
|
||||
def test_bootstrap_schema_creates_missing_tables_and_patches_legacy_tenant_columns():
|
||||
engine = create_engine(
|
||||
"sqlite://",
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE client_accounts (
|
||||
id INTEGER PRIMARY KEY,
|
||||
tenant_id VARCHAR(64),
|
||||
name VARCHAR(255),
|
||||
client_code VARCHAR(64),
|
||||
status VARCHAR(32),
|
||||
powerbi_workspace VARCHAR(128),
|
||||
notes TEXT,
|
||||
created_at DATETIME
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE raw_materials (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name VARCHAR(255),
|
||||
supplier VARCHAR(255),
|
||||
unit_of_measure VARCHAR(64),
|
||||
kg_per_unit FLOAT,
|
||||
status VARCHAR(32),
|
||||
notes TEXT,
|
||||
created_at DATETIME
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO client_accounts (id, tenant_id, name, client_code, status)
|
||||
VALUES (1, 'specialty-feeds', 'Specialty Feeds', 'SPEC', 'active')
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO raw_materials (id, name, supplier, unit_of_measure, kg_per_unit, status)
|
||||
VALUES (1, 'Maize', 'Example Supplier', 'tonne', 1000, 'active')
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
report = bootstrap_schema(engine, Base.metadata)
|
||||
synced_rows = sync_tenant_ids(engine)
|
||||
|
||||
assert "products" in report.created_tables
|
||||
assert "raw_materials.tenant_id" in report.added_columns
|
||||
assert "tenant_id" in {column["name"] for column in inspect(engine).get_columns("raw_materials")}
|
||||
assert synced_rows["raw_materials"] == 1
|
||||
|
||||
with engine.begin() as connection:
|
||||
tenant_id = connection.execute(text("SELECT tenant_id FROM raw_materials WHERE id = 1")).scalar_one()
|
||||
|
||||
assert tenant_id == "specialty-feeds"
|
||||
|
||||
Reference in New Issue
Block a user