from __future__ import annotations

from datetime import date

from sqlalchemy import func, select
from sqlalchemy.orm import Session, joinedload, selectinload

from app.api.deps import AuthSession
from app.models.mix import Mix, MixIngredient
from app.models.mix_calculator import MixCalculatorSession, MixCalculatorSessionLine
from app.models.product import Product
from app.schemas.mix_calculator import MixCalculatorSessionCreate, MixCalculatorSessionUpdate
from app.services.costing_engine import extract_unit_quantity_kg

# Session columns copied verbatim from a calculation preview. Shared by the
# create and update paths so the set of persisted fields cannot drift apart.
_PREVIEW_FIELDS = (
    "client_name",
    "product_id",
    "product_name",
    "mix_id",
    "mix_name",
    "mix_date",
    "batch_size_kg",
    "total_bags",
    "total_kg",
    "product_unit_of_measure",
    "product_unit_size_kg",
    "prepared_by_name",
    "status",
    "notes",
)


def can_view_all_mix_calculator_sessions(session: AuthSession) -> bool:
    """Return True if *session* may see every mix calculator session in its tenant.

    Only the ``superadmin`` and ``admin`` client roles have this right.
    """
    return session.client_role in {"superadmin", "admin"}


def _build_session_access_query(session: AuthSession):
    """Return a ``select(MixCalculatorSession)`` limited to what *session* may see.

    Rows are always filtered to ``session.tenant_id``; callers without
    view-all rights are further limited to sessions they prepared.
    """
    query = select(MixCalculatorSession).where(MixCalculatorSession.tenant_id == session.tenant_id)
    if can_view_all_mix_calculator_sessions(session):
        return query
    return query.where(MixCalculatorSession.prepared_by_user_id == session.user_id)


def _resolve_tenant_ids(auth_session: AuthSession) -> tuple[str, str]:
    """Return ``(lookup_tenant_id, storage_tenant_id)`` for create/update calls.

    The lookup id scopes the product query behind the preview; the storage id
    is written onto new session and line rows.

    NOTE(review): the two fallbacks disagree when ``auth_session.tenant_id``
    is falsy (products are looked up under ``""`` while rows are stored under
    ``"default"``), and ``_build_session_access_query`` filters on the raw
    value. They are deliberately kept distinct until the intended fallback is
    confirmed, since unifying them could widen tenant access.
    """
    return auth_session.tenant_id or "", auth_session.tenant_id or "default"


def _load_product_for_calculation(db: Session, tenant_id: str, product_id: int) -> Product | None:
    """Fetch a tenant's product with its mix, ingredients and raw materials preloaded.

    Returns None when no product with *product_id* exists in *tenant_id*.
    """
    return db.scalar(
        select(Product)
        .where(Product.id == product_id, Product.tenant_id == tenant_id)
        .options(
            selectinload(Product.mix)
            .selectinload(Mix.ingredients)
            .selectinload(MixIngredient.raw_material)
        )
    )


def _fractional_bag_warning(batch_size_kg: float, total_bags: float, unit_of_measure: str) -> str | None:
    """Return a warning message when *total_bags* is not a whole number, else None.

    Values within 1e-9 of an integer are treated as whole.
    """
    rounded_bags = round(total_bags)
    if abs(total_bags - rounded_bags) < 1e-9:
        return None
    return (
        f"Batch size {batch_size_kg:g}kg produces {total_bags:.2f} bags for {unit_of_measure}. "
        "This is not a whole-bag quantity."
    )


def calculate_mix_calculator_preview(
    db: Session,
    *,
    tenant_id: str,
    payload: MixCalculatorSessionCreate | MixCalculatorSessionUpdate | dict,
):
    """Scale a product's mix recipe to a requested batch size without saving anything.

    *payload* may be a schema instance (dumped with all fields, including
    unset ones) or a plain dict. It must provide ``product_id``,
    ``client_name``, ``batch_size_kg``, ``mix_date`` and ``prepared_by_name``,
    and may provide ``status`` and ``notes``.

    Returns:
        A dict describing the scaled batch: product/mix identifiers, batch
        size and bag count rounded to 4 decimals, warnings, and one ``lines``
        entry per mix ingredient with its required kilograms and its
        percentage of the source mix. ``status`` defaults to ``"saved"`` when
        missing or empty.

    Raises:
        ValueError: if the product is not found in *tenant_id*, belongs to a
            different client than ``client_name``, has no mix, or its mix
            ingredients sum to zero or fewer kilograms.
    """
    values = payload if isinstance(payload, dict) else payload.model_dump(exclude_unset=False)

    product = _load_product_for_calculation(db, tenant_id, int(values["product_id"]))
    if product is None:
        raise ValueError("Product not found")
    if product.client_name != values["client_name"]:
        raise ValueError("Selected product does not belong to the chosen client")
    if product.mix is None:
        raise ValueError("Product mix is not configured")

    ingredients = product.mix.ingredients
    source_total_kg = round(sum(ingredient.quantity_kg for ingredient in ingredients), 4)
    if source_total_kg <= 0:
        raise ValueError("Product mix has no source kilograms to scale")

    batch_size_kg = float(values["batch_size_kg"])
    scale_factor = batch_size_kg / source_total_kg
    unit_size_kg = extract_unit_quantity_kg(product.unit_of_measure)
    # Avoid dividing by a non-positive unit size; such products report 0 bags.
    total_bags = round(batch_size_kg / unit_size_kg, 4) if unit_size_kg > 0 else 0.0

    warnings: list[str] = []
    bag_warning = _fractional_bag_warning(batch_size_kg, total_bags, product.unit_of_measure)
    if bag_warning:
        warnings.append(bag_warning)

    lines = []
    for index, ingredient in enumerate(ingredients, start=1):
        mix_percentage = round((ingredient.quantity_kg / source_total_kg) * 100, 4)
        required_kg = round(ingredient.quantity_kg * scale_factor, 4)
        raw_material = ingredient.raw_material
        # Fall back to the bare FK and a placeholder name when the raw
        # material relationship did not load a row.
        lines.append(
            {
                "raw_material_id": raw_material.id if raw_material is not None else ingredient.raw_material_id,
                "raw_material_name": (
                    raw_material.name if raw_material is not None else f"Raw material {ingredient.raw_material_id}"
                ),
                "required_kg": required_kg,
                "mix_percentage": mix_percentage,
                "unit": raw_material.unit_of_measure if raw_material is not None else "kg",
                "sort_order": index,
            }
        )

    return {
        "client_name": product.client_name,
        "product_id": product.id,
        "product_name": product.name,
        "mix_id": product.mix_id,
        "mix_name": product.mix.name,
        "mix_date": values["mix_date"],
        "batch_size_kg": round(batch_size_kg, 4),
        "total_bags": total_bags,
        "total_kg": round(batch_size_kg, 4),
        "product_unit_of_measure": product.unit_of_measure,
        "product_unit_size_kg": round(unit_size_kg, 4),
        "prepared_by_name": values["prepared_by_name"],
        "status": values.get("status") or "saved",
        "notes": values.get("notes"),
        "warnings": warnings,
        "lines": lines,
    }


def build_mix_calculator_options(db: Session, *, tenant_id: str) -> dict:
    """Return the client list and per-product rows used by the Mix Calculator.

    Products are ordered by client name, then product name. Each row carries
    its mix name (``""`` when the product has no mix), unit size in kg, and
    the summed ingredient kilograms of its mix (0.0 when no total exists).
    """
    # Aggregate mix totals in a single GROUP BY query instead of loading every
    # ingredient row for every product. Streaming the tenant's entire recipe
    # table just to compute one sum per product was the main slow path on the
    # first Mix Calculator open.
    mix_totals_rows = db.execute(
        select(MixIngredient.mix_id, func.coalesce(func.sum(MixIngredient.quantity_kg), 0.0))
        .join(Mix, Mix.id == MixIngredient.mix_id)
        .where(Mix.tenant_id == tenant_id)
        .group_by(MixIngredient.mix_id)
    ).all()
    mix_totals: dict[int, float] = {mix_id: round(total or 0.0, 4) for mix_id, total in mix_totals_rows}

    products = db.scalars(
        select(Product)
        .where(Product.tenant_id == tenant_id)
        .options(joinedload(Product.mix))
        .order_by(Product.client_name, Product.name)
    ).all()

    clients = sorted({product.client_name for product in products})
    product_rows = [
        {
            "product_id": product.id,
            "client_name": product.client_name,
            "product_name": product.name,
            "mix_id": product.mix_id,
            "mix_name": product.mix.name if product.mix else "",
            "unit_of_measure": product.unit_of_measure,
            "unit_size_kg": round(extract_unit_quantity_kg(product.unit_of_measure), 4),
            "mix_total_kg": mix_totals.get(product.mix_id, 0.0),
        }
        for product in products
    ]
    return {"clients": clients, "products": product_rows}


def serialize_mix_calculator_session(session_record: MixCalculatorSession, auth_session: AuthSession) -> dict:
    """Convert a stored session and its lines into a response dict.

    Warnings are recomputed from the stored batch size and bag count, and
    ``is_owner`` is True when the caller prepared the session.
    """
    total_bags = round(session_record.total_bags, 4)
    warnings: list[str] = []
    bag_warning = _fractional_bag_warning(
        session_record.batch_size_kg, total_bags, session_record.product_unit_of_measure
    )
    if bag_warning:
        warnings.append(bag_warning)

    return {
        "id": session_record.id,
        "tenant_id": session_record.tenant_id,
        "session_number": session_record.session_number,
        "client_name": session_record.client_name,
        "product_id": session_record.product_id,
        "product_name": session_record.product_name,
        "mix_id": session_record.mix_id,
        "mix_name": session_record.mix_name,
        "mix_date": session_record.mix_date,
        "batch_size_kg": round(session_record.batch_size_kg, 4),
        "total_bags": total_bags,
        "total_kg": round(session_record.total_kg, 4),
        "product_unit_of_measure": session_record.product_unit_of_measure,
        "product_unit_size_kg": round(session_record.product_unit_size_kg, 4),
        "prepared_by_user_id": session_record.prepared_by_user_id,
        "prepared_by_name": session_record.prepared_by_name,
        "created_by": session_record.created_by,
        "status": session_record.status,
        "notes": session_record.notes,
        "created_at": session_record.created_at,
        "updated_at": session_record.updated_at,
        "warnings": warnings,
        "is_owner": session_record.prepared_by_user_id == auth_session.user_id,
        "lines": [
            {
                "id": line.id,
                "raw_material_id": line.raw_material_id,
                "raw_material_name": line.raw_material_name,
                "required_kg": round(line.required_kg, 4),
                "mix_percentage": round(line.mix_percentage, 4),
                "unit": line.unit,
                "sort_order": line.sort_order,
            }
            for line in session_record.lines
        ],
    }


def list_mix_calculator_sessions(db: Session, *, auth_session: AuthSession) -> list[dict]:
    """Return every session visible to *auth_session*, newest first, serialized."""
    sessions = db.scalars(
        _build_session_access_query(auth_session)
        .options(selectinload(MixCalculatorSession.lines))
        .order_by(MixCalculatorSession.created_at.desc(), MixCalculatorSession.id.desc())
    ).all()
    return [serialize_mix_calculator_session(session_record, auth_session) for session_record in sessions]


def get_mix_calculator_session(db: Session, *, auth_session: AuthSession, session_id: int) -> MixCalculatorSession | None:
    """Return session *session_id* with its lines loaded, or None if it is not visible to the caller."""
    return db.scalar(
        _build_session_access_query(auth_session)
        .where(MixCalculatorSession.id == session_id)
        .options(selectinload(MixCalculatorSession.lines))
    )


def _parse_session_sequence(session_number: str) -> int | None:
    """Return the integer after the last ``-`` in *session_number*, or None if it is not an integer."""
    try:
        return int(session_number.rsplit("-", 1)[-1])
    except ValueError:
        return None


def _next_session_number(db: Session, *, tenant_id: str, mix_date: date) -> str:
    """Return the next ``HPP-YYYYMMDD-NNNN`` number for *tenant_id* on *mix_date*.

    The sequence is one more than the highest existing numeric suffix for that
    tenant and date, starting at 1, and is zero-padded to at least 4 digits.

    NOTE(review): two concurrent creates for the same tenant/date can compute
    the same number; confirm a unique constraint or retry exists upstream.
    """
    prefix = f"HPP-{mix_date.strftime('%Y%m%d')}-"
    existing = db.scalars(
        select(MixCalculatorSession.session_number).where(
            MixCalculatorSession.tenant_id == tenant_id,
            MixCalculatorSession.mix_date == mix_date,
            MixCalculatorSession.session_number.like(f"{prefix}%"),
        )
    ).all()
    # Skip suffixes that are not integers (e.g. hand-edited numbers) instead
    # of letting one malformed row fail every later create for that date.
    sequences = [seq for seq in map(_parse_session_sequence, existing) if seq is not None]
    sequence = max(sequences, default=0) + 1
    return f"{prefix}{sequence:04d}"


def _build_session_lines(tenant_id: str, preview_lines: list[dict]) -> list[MixCalculatorSessionLine]:
    """Create unsaved line rows from the ``lines`` entries of a preview dict."""
    return [
        MixCalculatorSessionLine(
            tenant_id=tenant_id,
            raw_material_id=line["raw_material_id"],
            raw_material_name=line["raw_material_name"],
            required_kg=line["required_kg"],
            mix_percentage=line["mix_percentage"],
            unit=line["unit"],
            sort_order=line["sort_order"],
        )
        for line in preview_lines
    ]


def create_mix_calculator_session(db: Session, *, auth_session: AuthSession, payload: MixCalculatorSessionCreate) -> dict:
    """Calculate, persist, commit and return a new mix calculator session.

    The session is numbered per tenant and mix date and attributed to the
    calling user (``prepared_by_user_id`` and ``created_by``).

    Raises:
        ValueError: propagated from ``calculate_mix_calculator_preview``.
    """
    lookup_tenant_id, storage_tenant_id = _resolve_tenant_ids(auth_session)
    preview = calculate_mix_calculator_preview(db, tenant_id=lookup_tenant_id, payload=payload)

    session_record = MixCalculatorSession(
        tenant_id=storage_tenant_id,
        session_number=_next_session_number(db, tenant_id=storage_tenant_id, mix_date=payload.mix_date),
        prepared_by_user_id=auth_session.user_id,
        created_by=auth_session.email,
        **{field: preview[field] for field in _PREVIEW_FIELDS},
    )
    session_record.lines = _build_session_lines(storage_tenant_id, preview["lines"])

    db.add(session_record)
    db.commit()
    # Reload server-populated columns, then the lines relationship, so the
    # serialized response reflects the committed state.
    db.refresh(session_record)
    db.refresh(session_record, attribute_names=["lines"])
    return serialize_mix_calculator_session(session_record, auth_session)


def update_mix_calculator_session(
    db: Session,
    *,
    auth_session: AuthSession,
    session_record: MixCalculatorSession,
    payload: MixCalculatorSessionUpdate,
) -> dict:
    """Apply a partial update to *session_record*, recalculate it, commit and return it serialized.

    Fields not set in *payload* keep their stored values; the full preview,
    including every line, is recomputed from the merged values. The tenant,
    session number, owner and creator are left untouched.

    Raises:
        ValueError: propagated from ``calculate_mix_calculator_preview``.
    """
    merged_values = {
        "mix_date": session_record.mix_date,
        "client_name": session_record.client_name,
        "product_id": session_record.product_id,
        "batch_size_kg": session_record.batch_size_kg,
        "prepared_by_name": session_record.prepared_by_name,
        "status": session_record.status,
        "notes": session_record.notes,
    }
    merged_values.update(payload.model_dump(exclude_unset=True))

    lookup_tenant_id, storage_tenant_id = _resolve_tenant_ids(auth_session)
    preview = calculate_mix_calculator_preview(db, tenant_id=lookup_tenant_id, payload=merged_values)

    for field in _PREVIEW_FIELDS:
        setattr(session_record, field, preview[field])

    # Swap the whole line collection for the recalculated lines. Whether the
    # old rows are deleted depends on the relationship's cascade settings
    # (presumably delete-orphan) - not visible here.
    session_record.lines.clear()
    session_record.lines.extend(_build_session_lines(storage_tenant_id, preview["lines"]))

    db.commit()
    db.refresh(session_record)
    db.refresh(session_record, attribute_names=["lines"])
    return serialize_mix_calculator_session(session_record, auth_session)