"""Schema bootstrap helpers: create missing tables, add ``tenant_id`` columns,
and back-fill tenant ids on rows that do not have a real one yet."""

from __future__ import annotations

from dataclasses import dataclass, field

from sqlalchemy import MetaData, inspect, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import NoSuchTableError

# Tables that are expected to carry a ``tenant_id`` column. The helpers in this
# module only iterate the keys; the ``None`` values are not read here.
TENANT_TABLES = {
    "client_users": None,
    "client_feature_access": None,
    "raw_materials": None,
    "raw_material_price_versions": None,
    "mixes": None,
    "mix_ingredients": None,
    "products": None,
    "scenarios": None,
    "costing_results": None,
    "process_cost_rules": None,
    "packaging_cost_rules": None,
    "freight_cost_rules": None,
}

# Predicate selecting rows whose tenant has not been assigned yet: NULL, the
# empty string, or the placeholder value 'default'.
_UNASSIGNED_TENANT_FILTER = (
    "tenant_id IS NULL OR tenant_id = '' OR tenant_id = 'default'"
)

# SQL expression that assigns the fallback tenant bound as ``:default_tenant``.
_DEFAULT_TENANT_EXPRESSION = ":default_tenant"

# (table, SQL expression assigned to tenant_id), in execution order. The order
# matters: child tables read the tenant_id that an earlier statement wrote on
# their parent (raw_materials -> raw_material_price_versions,
# mixes -> mix_ingredients / products, products / scenarios -> costing_results).
_TENANT_BACKFILL_EXPRESSIONS: tuple[tuple[str, str], ...] = (
    (
        "client_users",
        """(
            SELECT client_accounts.tenant_id
            FROM client_accounts
            WHERE client_accounts.id = client_users.client_account_id
        )""",
    ),
    (
        "client_feature_access",
        """(
            SELECT client_accounts.tenant_id
            FROM client_accounts
            WHERE client_accounts.id = client_feature_access.client_account_id
        )""",
    ),
    ("raw_materials", _DEFAULT_TENANT_EXPRESSION),
    (
        "raw_material_price_versions",
        """(
            SELECT raw_materials.tenant_id
            FROM raw_materials
            WHERE raw_materials.id = raw_material_price_versions.raw_material_id
        )""",
    ),
    (
        "mixes",
        """COALESCE(
            (
                SELECT client_accounts.tenant_id
                FROM client_accounts
                WHERE client_accounts.name = mixes.client_name
            ),
            :default_tenant
        )""",
    ),
    (
        "mix_ingredients",
        """(
            SELECT mixes.tenant_id
            FROM mixes
            WHERE mixes.id = mix_ingredients.mix_id
        )""",
    ),
    (
        "products",
        """COALESCE(
            (
                SELECT client_accounts.tenant_id
                FROM client_accounts
                WHERE client_accounts.name = products.client_name
            ),
            (
                SELECT mixes.tenant_id
                FROM mixes
                WHERE mixes.id = products.mix_id
            ),
            :default_tenant
        )""",
    ),
    ("scenarios", _DEFAULT_TENANT_EXPRESSION),
    (
        "costing_results",
        """COALESCE(
            (
                SELECT products.tenant_id
                FROM products
                WHERE products.id = costing_results.product_id
            ),
            (
                SELECT scenarios.tenant_id
                FROM scenarios
                WHERE scenarios.id = costing_results.scenario_id
            ),
            :default_tenant
        )""",
    ),
    ("process_cost_rules", _DEFAULT_TENANT_EXPRESSION),
    ("packaging_cost_rules", _DEFAULT_TENANT_EXPRESSION),
    ("freight_cost_rules", _DEFAULT_TENANT_EXPRESSION),
)


@dataclass(frozen=True)
class MigrationReport:
    """Immutable record of what a schema bootstrap run changed."""

    # Names of tables that were missing before the run.
    created_tables: tuple[str, ...] = ()
    # Entries of the form "<table>.tenant_id" for every column that was added.
    added_columns: tuple[str, ...] = ()
    # Table name -> number of rows whose tenant_id was back-filled.
    synced_tenant_rows: dict[str, int] = field(default_factory=dict)

    def has_changes(self) -> bool:
        """Return True when any of the three change categories is non-empty."""
        return bool(
            self.created_tables or self.added_columns or self.synced_tenant_rows
        )

    def summary(self) -> str:
        """Return a one-line, human-readable description of the changes.

        Sections are joined with "; ". When nothing changed the fixed string
        "schema already up to date" is returned.
        """
        parts: list[str] = []
        if self.created_tables:
            parts.append(f"created tables: {', '.join(self.created_tables)}")
        if self.added_columns:
            parts.append(f"patched columns: {', '.join(self.added_columns)}")
        if self.synced_tenant_rows:
            # Sorted by table name so the summary is stable across runs.
            counts = ", ".join(
                f"{table}={count}"
                for table, count in sorted(self.synced_tenant_rows.items())
            )
            parts.append(f"synced tenant rows: {counts}")
        return "; ".join(parts) if parts else "schema already up to date"


def _has_column(engine: Engine, table_name: str, column_name: str) -> bool:
    """Return True when ``table_name`` has a column named ``column_name``.

    A table that does not exist is reported as not having the column. Any
    other failure (for example a lost connection) propagates: treating it as
    "column missing" would send the caller into an ``ALTER TABLE`` on a table
    whose real state is unknown.
    """
    inspector = inspect(engine)
    try:
        columns = inspector.get_columns(table_name)
    except NoSuchTableError:
        return False
    return any(column["name"] == column_name for column in columns)


def _add_tenant_column(engine: Engine, table_name: str) -> None:
    """Add a ``tenant_id VARCHAR(64)`` column to ``table_name`` if it is absent.

    Does nothing when the column is already present, so it is safe to call
    repeatedly.
    """
    if _has_column(engine, table_name, "tenant_id"):
        return
    # DDL cannot take the table name as a bind parameter, so quote it through
    # the dialect instead of interpolating it raw. Plain lower-case names such
    # as the ones in TENANT_TABLES are emitted unchanged.
    quoted_table = engine.dialect.identifier_preparer.quote(table_name)
    with engine.begin() as connection:
        connection.execute(
            text(f"ALTER TABLE {quoted_table} ADD COLUMN tenant_id VARCHAR(64)")
        )


def _table_exists(engine: Engine, table_name: str) -> bool:
    """Return True when the database reports a table named ``table_name``."""
    return inspect(engine).has_table(table_name)


def ensure_metadata_tables(engine: Engine, metadata: MetaData) -> tuple[str, ...]:
    """Create the tables declared in ``metadata`` that are missing.

    Returns the names of the tables that were missing before the call, in
    ``metadata.sorted_tables`` order. ``create_all`` is only invoked when at
    least one table is missing.
    """
    # One inspector for the whole pass instead of a new one per table; the
    # table's own schema is passed so schema-qualified tables are looked up
    # where they actually live.
    inspector = inspect(engine)
    missing_tables = tuple(
        table.name
        for table in metadata.sorted_tables
        if not inspector.has_table(table.name, schema=table.schema)
    )
    if missing_tables:
        metadata.create_all(bind=engine)
    return missing_tables


def ensure_tenant_columns(engine: Engine) -> tuple[str, ...]:
    """Add ``tenant_id`` to every existing table in ``TENANT_TABLES`` lacking it.

    Tables that do not exist are skipped. Returns one "<table>.tenant_id"
    entry per column that was added.
    """
    added_columns: list[str] = []
    for table_name in TENANT_TABLES:
        if not _table_exists(engine, table_name):
            continue
        if _has_column(engine, table_name, "tenant_id"):
            continue
        _add_tenant_column(engine, table_name)
        added_columns.append(f"{table_name}.tenant_id")
    return tuple(added_columns)


def sync_tenant_ids(engine: Engine) -> dict[str, int]:
    """Back-fill ``tenant_id`` on rows that are NULL, empty or 'default'.

    The fallback tenant is the ``tenant_id`` of the first ``client_accounts``
    row ordered by ``id``. All updates run in a single transaction, in the
    order given by ``_TENANT_BACKFILL_EXPRESSIONS``; tables that do not exist
    are skipped.

    Returns a mapping of table name to the number of rows updated, containing
    only tables with a positive row count. Returns an empty dict when
    ``client_accounts`` does not exist or no fallback tenant can be read.

    NOTE(review): assumes every existing target table already has a
    ``tenant_id`` column (see ``ensure_tenant_columns``) - confirm call order.
    NOTE(review): for tables resolved purely through a parent lookup, a row
    without a matching parent is assigned NULL again and is therefore counted
    on every run - confirm whether that is intended.
    """
    existing_tables = set(inspect(engine).get_table_names())
    if "client_accounts" not in existing_tables:
        return {}

    synced_rows: dict[str, int] = {}
    with engine.begin() as connection:
        default_tenant = connection.execute(
            text("SELECT tenant_id FROM client_accounts ORDER BY id LIMIT 1")
        ).scalar_one_or_none()
        if not default_tenant:
            return {}

        # The same parameter dict is sent with every statement, including the
        # ones that do not reference :default_tenant.
        parameters = {"default_tenant": default_tenant}
        for table_name, tenant_expression in _TENANT_BACKFILL_EXPRESSIONS:
            if table_name not in existing_tables:
                continue
            statement = text(
                f"UPDATE {table_name} "
                f"SET tenant_id = {tenant_expression} "
                f"WHERE {_UNASSIGNED_TENANT_FILTER}"
            )
            row_count = connection.execute(statement, parameters).rowcount
            # Guard against drivers reporting None or -1 for "unknown".
            if row_count and row_count > 0:
                synced_rows[table_name] = row_count
    return synced_rows


def bootstrap_schema(engine: Engine, metadata: MetaData) -> MigrationReport:
    """Create missing tables, then add missing ``tenant_id`` columns.

    Returns a ``MigrationReport`` with ``created_tables`` and
    ``added_columns`` populated. This function does not call
    ``sync_tenant_ids``, so ``synced_tenant_rows`` is always empty in the
    returned report.
    """
    created_tables = ensure_metadata_tables(engine, metadata)
    added_columns = ensure_tenant_columns(engine)
    return MigrationReport(created_tables=created_tables, added_columns=added_columns)