"""Product costing: per-item delivered cost and price calculation.

Combines tenant-level costing assumptions (grading, cracking, bagging process,
bag and freight costs, client margins) with a cleaned mix/product cost per kg
obtained from the costing engine, and writes the results onto
``ProductCostItem`` rows.
"""

from __future__ import annotations

from dataclasses import dataclass
import json
import math

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.models.product import Product
from app.models.product_costing import (
    ProductCostBagInput,
    ProductCostBaseInput,
    ProductCostClientInput,
    ProductCostFreightInput,
    ProductCostItem,
    ProductCostProcessInput,
)
from app.services.costing_engine import calculate_product_cost

# Unit types the calculator knows how to price. Any other value is flagged with
# an "Invalid unit type" warning and receives no delivered cost or prices.
UNIT_TYPES = ("Standard", "Bulka", "1.5 kg", "Per Unit")
# Recognised ``own_bag`` values: "Yes" applies the own-bag credit; "No Bag" waives
# the bag cost (Peckish and 1.5 kg items still pay their own bag/bagging cost).
OWN_BAG_VALUES = ("Yes", "No Bag")
# Client categories that are never charged the grading cost.
ZERO_GRADING_CLIENTS = {"PHF Horse Mixes", "Peckish", "Hay & Straw"}
# Known bagging process names (keys of ``ProductCostAssumptions.process_costs``).
# NOTE(review): not referenced by the calculation in this module — presumably used
# by callers for validation or seeding; confirm.
PROCESS_NAMES = ("Bagging + Grading", "Standard Bagging", "PHF Horse Mixes", "Peckish", "Hay & Straw")
# Keys looked up in ``ProductCostAssumptions.bag_costs`` -> display labels.
BAG_INPUTS = {
    "20kg_bag": "20kg bag",
    "bulka_bag": "Bulka bag",
    "own_bag_credit": "Own bag credit",
    "1_5kg_bagging": "1.5kg bagging",
    "peckish_bag": "Peckish bag",
}
# Keys looked up in ``ProductCostAssumptions.freight_costs`` -> display labels.
FREIGHT_INPUTS = {
    "freight_per_pallet": "Freight per pallet",
    "peckish_freight_per_pallet": "Peckish freight per pallet",
    "hay_straw_freight_per_pallet": "Hay & Straw freight per pallet",
}


@dataclass(frozen=True)
class ProductCostInputItem:
    """DB-independent snapshot of the ``ProductCostItem`` fields the calculator reads."""

    client_category: str
    product_name: str  # a name containing "cracked" (case-insensitive) adds the cracking cost
    mix_product_name: str
    unit_type: str  # one of UNIT_TYPES; an empty value is treated as "Standard"
    own_bag: str | None  # one of OWN_BAG_VALUES, or None
    unit_kg: float | None
    items_per_pallet: int | None
    bagging_process: str | None  # key into ProductCostAssumptions.process_costs
    manual_distributor_margin: float | None  # overrides the client default when not None
    manual_wholesale_margin: float | None  # overrides the client default when not None


@dataclass(frozen=True)
class ProductCostAssumptions:
    """Tenant-wide costing inputs consumed by ``calculate_product_cost_item``."""

    grading_per_kg: float
    cracking_per_kg: float
    process_costs: dict[str, float]  # bagging process name -> cost per kg
    # client category -> {"distributor_margin": ..., "wholesale_margin": ...}
    client_margins: dict[str, dict[str, float | None]]
    bag_costs: dict[str, float]  # BAG_INPUTS key -> cost
    freight_costs: dict[str, float]  # FREIGHT_INPUTS key -> cost


@dataclass(frozen=True)
class ProductCostCalculation:
    """Costing result for one item; values are rounded to 4 dp and None when not computable."""

    cleaned_product_cost_per_kg: float | None
    grading_cost_per_kg: float | None
    bagging_cost_per_kg: float | None
    cracking_cost_per_kg: float | None
    bag_cost_per_unit: float | None
    freight_cost_per_unit: float | None
    finished_product_delivered_cost: float | None
    distributor_price: float | None
    wholesale_price: float | None
    warnings: list[str]  # human-readable problems found while costing


def _round4(value: float | None) -> float | None:
    """Round to 4 decimal places, passing None through."""
    return None if value is None else round(value, 4)


def _ceil_to(value: float, digits: int) -> float:
    """Round ``value`` up to ``digits`` decimal places.

    The 1e-9 nudge stops values that already sit on a step but carry float
    noise (e.g. 12.300000000000001) from being bumped up a whole step.
    """
    factor = 10**digits
    return math.ceil((value * factor) - 1e-9) / factor


def _is_positive(value: float | None) -> bool:
    """Return True when ``value`` is present and strictly positive."""
    return value is not None and value > 0


def _spread(cost: float, quantity: float | None) -> float | None:
    """Divide ``cost`` by ``quantity``, or return None when the quantity is missing/non-positive."""
    if quantity is None or quantity <= 0:
        return None
    return cost / quantity


def _valid_margin(value: float | None, label: str, warnings: list[str]) -> float | None:
    """Return ``value`` if it is a margin in [0, 1); otherwise record a warning and return None.

    A None ``value`` is returned as None without a warning.
    """
    if value is None:
        return None
    if value < 0 or value >= 1:
        warnings.append(f"Invalid {label} margin")
        return None
    return value


def _resolve_margin(
    manual: float | None,
    client_margin: dict[str, float | None],
    key: str,
    label: str,
    warnings: list[str],
) -> float | None:
    """Use the manual override when set, else the client default for ``key``, then validate it."""
    value = manual if manual is not None else client_margin.get(key)
    return _valid_margin(value, label, warnings)


def _price_with_margin(cost: float | None, margin: float | None) -> float | None:
    """Return ``cost / (1 - margin)`` (margin as a share of price), or None if either is missing."""
    if cost is None or margin is None:
        return None
    return cost / (1 - margin)


def _bag_cost_per_unit(
    item: ProductCostInputItem,
    unit_type: str,
    unit_kg: float | None,
    bag_costs: dict[str, float],
) -> float | None:
    """Bag cost attributed to one unit, net of the own-bag credit when ``own_bag == "Yes"``.

    Branch precedence: Peckish client, then 1.5 kg units, then "No Bag", then the
    Standard / Bulka bag. Bulka bag cost is divided by ``unit_kg`` (None when that
    is missing). Other unit types carry no bag cost.
    """
    cost: float | None
    if item.client_category == "Peckish":
        cost = bag_costs.get("peckish_bag", 0.0)
    elif unit_type == "1.5 kg":
        cost = bag_costs.get("1_5kg_bagging", 0.0)
    elif item.own_bag == "No Bag":
        cost = 0.0
    elif unit_type == "Standard":
        cost = bag_costs.get("20kg_bag", 0.0)
    elif unit_type == "Bulka":
        cost = _spread(bag_costs.get("bulka_bag", 0.0), unit_kg)
    else:
        cost = 0.0
    if cost is not None and item.own_bag == "Yes":
        # NOTE(review): for Bulka the bag cost above is per kg while the credit is
        # subtracted unscaled — confirm this matches the pricing sheet.
        cost -= bag_costs.get("own_bag_credit", 0.0)
    return cost


def _freight_cost_per_unit(
    client_category: str,
    unit_type: str,
    unit_kg: float | None,
    items_per_pallet: int | None,
    freight_costs: dict[str, float],
) -> float | None:
    """Freight attributed to one unit, or None when a needed quantity is missing.

    Peckish and Hay & Straw use their own per-pallet rate spread over
    ``items_per_pallet`` regardless of unit type. Otherwise the general
    per-pallet rate is:

    - spread over ``items_per_pallet`` for Standard / Per Unit;
    - divided by ``unit_kg`` for Bulka;
    - scaled as ``rate / 1000 * unit_kg`` for 1.5 kg.

    Unknown unit types are not priced (None).
    """
    if client_category == "Peckish":
        return _spread(freight_costs.get("peckish_freight_per_pallet", 0.0), items_per_pallet)
    if client_category == "Hay & Straw":
        return _spread(freight_costs.get("hay_straw_freight_per_pallet", 0.0), items_per_pallet)
    pallet_rate = freight_costs.get("freight_per_pallet", 0.0)
    if unit_type in {"Standard", "Per Unit"}:
        return _spread(pallet_rate, items_per_pallet)
    if unit_type == "Bulka":
        return _spread(pallet_rate, unit_kg)
    if unit_type == "1.5 kg":
        if unit_kg is None or unit_kg <= 0:
            return None
        return pallet_rate / 1000 * unit_kg
    return None


def _finished_cost(
    unit_type: str,
    unit_kg: float,
    per_kg_cost: float,
    bag_cost_per_unit: float,
    freight_cost_per_unit: float,
) -> float | None:
    """Delivered cost for one sold unit of ``unit_type``, or None for unknown unit types.

    - Standard: per-kg costs x ``unit_kg``, plus bag and freight.
    - Bulka / Per Unit: per-kg costs plus bag and freight, not scaled by
      ``unit_kg``. For Bulka, the bag and freight figures are already divided by
      ``unit_kg``. NOTE(review): for Per Unit this adds a per-kg figure to
      per-item bag/freight figures — confirm intended.
    - 1.5 kg: the Standard formula multiplied by 8. NOTE(review): presumably
      8 packs per sold unit — confirm.
    """
    if unit_type == "Standard":
        return per_kg_cost * unit_kg + bag_cost_per_unit + freight_cost_per_unit
    if unit_type in {"Bulka", "Per Unit"}:
        return per_kg_cost + bag_cost_per_unit + freight_cost_per_unit
    if unit_type == "1.5 kg":
        return (per_kg_cost * unit_kg + bag_cost_per_unit + freight_cost_per_unit) * 8
    return None


def calculate_product_cost_item(
    item: ProductCostInputItem,
    assumptions: ProductCostAssumptions,
    cleaned_product_cost_per_kg: float | None,
) -> ProductCostCalculation:
    """Compute cost components, delivered cost and prices for one product cost item.

    Args:
        item: The item's inputs.
        assumptions: Tenant-wide costing inputs.
        cleaned_product_cost_per_kg: Mix/product cost per kg, or None if the
            lookup failed.

    Returns:
        A ``ProductCostCalculation`` whose values are rounded to 4 dp.

    Missing or invalid inputs never raise. They add entries to ``warnings``
    and leave the dependent values as None. An unrecognised unit type yields
    no delivered cost or prices.
    """
    warnings: list[str] = []
    unit_type = item.unit_type or "Standard"
    unit_kg = item.unit_kg
    items_per_pallet = item.items_per_pallet

    if unit_type not in UNIT_TYPES:
        warnings.append("Invalid unit type")
    if cleaned_product_cost_per_kg is None:
        warnings.append("Missing mix/product cost lookup")
    if not _is_positive(unit_kg):
        warnings.append("Missing unit kg")
    if not _is_positive(items_per_pallet):
        warnings.append("Missing pallet quantity")

    # Grading is charged only for bagged items of non-exempt client categories.
    grading_cost_per_kg = 0.0
    if item.client_category not in ZERO_GRADING_CLIENTS and item.bagging_process:
        grading_cost_per_kg = assumptions.grading_per_kg

    bagging_cost_per_kg = assumptions.process_costs.get(item.bagging_process or "", 0.0)
    if item.bagging_process and item.bagging_process not in assumptions.process_costs:
        warnings.append("Missing bagging process cost")

    # Guard against a NULL product name coming from the DB row.
    is_cracked = "cracked" in (item.product_name or "").lower()
    cracking_cost_per_kg = assumptions.cracking_per_kg if is_cracked else 0.0

    bag_cost_per_unit = _bag_cost_per_unit(item, unit_type, unit_kg, assumptions.bag_costs)
    freight_cost_per_unit = _freight_cost_per_unit(
        item.client_category, unit_type, unit_kg, items_per_pallet, assumptions.freight_costs
    )

    finished_cost: float | None = None
    components = (
        cleaned_product_cost_per_kg,
        grading_cost_per_kg,
        bagging_cost_per_kg,
        cracking_cost_per_kg,
        bag_cost_per_unit,
        freight_cost_per_unit,
    )
    if all(value is not None for value in components) and unit_kg is not None and unit_kg > 0:
        per_kg_cost = cleaned_product_cost_per_kg + grading_cost_per_kg + bagging_cost_per_kg + cracking_cost_per_kg  # type: ignore[operator]
        finished_cost = _finished_cost(
            unit_type, unit_kg, per_kg_cost, bag_cost_per_unit, freight_cost_per_unit  # type: ignore[arg-type]
        )

    client_margin = assumptions.client_margins.get(item.client_category, {})
    distributor_margin = _resolve_margin(
        item.manual_distributor_margin, client_margin, "distributor_margin", "distributor", warnings
    )
    wholesale_margin = _resolve_margin(
        item.manual_wholesale_margin, client_margin, "wholesale_margin", "wholesale", warnings
    )

    distributor_price = _price_with_margin(finished_cost, distributor_margin)
    wholesale_price = _price_with_margin(finished_cost, wholesale_margin)
    if wholesale_price is not None:
        # Wholesale prices are rounded up: to 2 dp for Straight Grain bulk, otherwise to 1 dp.
        digits = 2 if item.client_category == "Straight Grain" and unit_type == "Bulka" else 1
        wholesale_price = _ceil_to(wholesale_price, digits)

    return ProductCostCalculation(
        cleaned_product_cost_per_kg=_round4(cleaned_product_cost_per_kg),
        grading_cost_per_kg=_round4(grading_cost_per_kg),
        bagging_cost_per_kg=_round4(bagging_cost_per_kg),
        cracking_cost_per_kg=_round4(cracking_cost_per_kg),
        bag_cost_per_unit=_round4(bag_cost_per_unit),
        freight_cost_per_unit=_round4(freight_cost_per_unit),
        finished_product_delivered_cost=_round4(finished_cost),
        distributor_price=_round4(distributor_price),
        wholesale_price=_round4(wholesale_price),
        warnings=warnings,
    )


def _item_input(item: ProductCostItem) -> ProductCostInputItem:
    """Copy the calculator-relevant fields of an ORM item into a plain input record."""
    return ProductCostInputItem(
        client_category=item.client_category,
        product_name=item.product_name,
        mix_product_name=item.mix_product_name,
        unit_type=item.unit_type,
        own_bag=item.own_bag,
        unit_kg=item.unit_kg,
        items_per_pallet=item.items_per_pallet,
        bagging_process=item.bagging_process,
        manual_distributor_margin=item.manual_distributor_margin,
        manual_wholesale_margin=item.manual_wholesale_margin,
    )


def _tenant_rows(db: Session, model, tenant_id: str) -> list:
    """Return every row of ``model`` whose ``tenant_id`` matches."""
    return list(db.scalars(select(model).where(model.tenant_id == tenant_id)).all())


def get_product_costing_assumptions(db: Session, tenant_id: str) -> ProductCostAssumptions:
    """Load the tenant's costing inputs into a ``ProductCostAssumptions``.

    Side effect: if the tenant has no ``ProductCostBaseInput`` row, a new one is
    added to the session and flushed (this function does not commit).
    """
    base = db.scalar(select(ProductCostBaseInput).where(ProductCostBaseInput.tenant_id == tenant_id))
    if base is None:
        base = ProductCostBaseInput(tenant_id=tenant_id)
        db.add(base)
        db.flush()

    process_costs = {
        row.process_name: row.cost_per_kg for row in _tenant_rows(db, ProductCostProcessInput, tenant_id)
    }
    client_margins = {
        row.client_category: {
            "distributor_margin": row.distributor_margin,
            "wholesale_margin": row.wholesale_margin,
        }
        for row in _tenant_rows(db, ProductCostClientInput, tenant_id)
    }
    bag_costs = {row.input_key: row.cost for row in _tenant_rows(db, ProductCostBagInput, tenant_id)}
    freight_costs = {row.input_key: row.cost for row in _tenant_rows(db, ProductCostFreightInput, tenant_id)}

    return ProductCostAssumptions(
        # A falsy per-kg figure (None or 0) falls back to the per-tonne figure / 1000.
        grading_per_kg=base.grading_per_kg or ((base.grading_per_tonne or 0.0) / 1000),
        cracking_per_kg=base.cracking_per_kg or ((base.cracking_per_tonne or 0.0) / 1000),
        process_costs=process_costs,
        client_margins=client_margins,
        bag_costs=bag_costs,
        freight_costs=freight_costs,
    )


def lookup_cleaned_product_cost_per_kg(db: Session, item: ProductCostItem) -> float | None:
    """Return the mix cost per kg of the product behind ``item``, or None if unavailable.

    The product is matched within the item's tenant and client category,
    first by ``mix_product_name`` and then by ``product_name``.

    Returns None when:
    - no product matches;
    - ``calculate_product_cost`` raises ``ValueError``;
    - its result has no ``inputs.mix.mix_cost_per_kg``.
    """
    product = None
    # dict.fromkeys keeps order and drops a duplicate name, so it is queried once.
    for name in dict.fromkeys((item.mix_product_name, item.product_name)):
        product = db.scalar(
            select(Product)
            .where(
                Product.tenant_id == item.tenant_id,
                Product.client_name == item.client_category,
                Product.name == name,
            )
            .limit(1)
        )
        if product is not None:
            break
    if product is None:
        return None

    try:
        result = calculate_product_cost(db, product.id)
    except ValueError:
        return None
    mix = (result.get("inputs") or {}).get("mix") or {}
    return mix.get("mix_cost_per_kg")


def apply_calculation(item: ProductCostItem, calculation: ProductCostCalculation) -> ProductCostItem:
    """Copy ``calculation`` onto ``item`` (warnings JSON-encoded) and return the same item.

    Only sets attributes; it does not flush or commit the session.
    """
    item.cleaned_product_cost_per_kg = calculation.cleaned_product_cost_per_kg
    item.grading_cost_per_kg = calculation.grading_cost_per_kg
    item.bagging_cost_per_kg = calculation.bagging_cost_per_kg
    item.cracking_cost_per_kg = calculation.cracking_cost_per_kg
    item.bag_cost_per_unit = calculation.bag_cost_per_unit
    item.freight_cost_per_unit = calculation.freight_cost_per_unit
    item.finished_product_delivered_cost = calculation.finished_product_delivered_cost
    item.distributor_price = calculation.distributor_price
    item.wholesale_price = calculation.wholesale_price
    item.warnings = json.dumps(calculation.warnings)
    return item


def recalculate_product_cost_item(
    db: Session,
    item: ProductCostItem,
    *,
    assumptions: ProductCostAssumptions | None = None,
) -> ProductCostItem:
    """Recompute and store the costing fields of ``item``.

    Args:
        db: Session used for the assumption and product lookups.
        item: The item to update in place.
        assumptions: Preloaded assumptions for ``item.tenant_id``. When omitted,
            they are loaded (and the base-input row created if missing).

    Returns:
        The same ``item``, updated in place.
    """
    if assumptions is None:
        assumptions = get_product_costing_assumptions(db, item.tenant_id)
    cleaned_cost = lookup_cleaned_product_cost_per_kg(db, item)
    calculation = calculate_product_cost_item(_item_input(item), assumptions, cleaned_cost)
    return apply_calculation(item, calculation)


def recalculate_all_product_cost_items(db: Session, tenant_id: str) -> int:
    """Recompute every product cost item of ``tenant_id`` and return how many were updated."""
    items = db.scalars(select(ProductCostItem).where(ProductCostItem.tenant_id == tenant_id)).all()
    if not items:
        return 0
    # The assumptions are tenant-wide, so load them once instead of once per item.
    assumptions = get_product_costing_assumptions(db, tenant_id)
    for item in items:
        recalculate_product_cost_item(db, item, assumptions=assumptions)
    return len(items)


def serialize_product_cost_item(item: ProductCostItem) -> dict:
    """Convert a ``ProductCostItem`` into a plain dict for API responses.

    ``warnings`` is stored as JSON text and is decoded to a list here. Text that
    is not valid JSON, or JSON that is not a list, is returned as a one-element
    list containing the raw stored text.
    """
    warnings: list = []
    if item.warnings:
        try:
            decoded = json.loads(item.warnings)
        except json.JSONDecodeError:
            decoded = None
        warnings = decoded if isinstance(decoded, list) else [item.warnings]
    return {
        "id": item.id,
        "tenant_id": item.tenant_id,
        "client_category": item.client_category,
        "item_id": item.item_id,
        "product_name": item.product_name,
        "mix_product_name": item.mix_product_name,
        "unit_type": item.unit_type,
        "own_bag": item.own_bag,
        "unit_kg": item.unit_kg,
        "items_per_pallet": item.items_per_pallet,
        "bagging_process": item.bagging_process,
        "manual_distributor_margin": item.manual_distributor_margin,
        "manual_wholesale_margin": item.manual_wholesale_margin,
        "cleaned_product_cost_per_kg": item.cleaned_product_cost_per_kg,
        "grading_cost_per_kg": item.grading_cost_per_kg,
        "bagging_cost_per_kg": item.bagging_cost_per_kg,
        "cracking_cost_per_kg": item.cracking_cost_per_kg,
        "bag_cost_per_unit": item.bag_cost_per_unit,
        "freight_cost_per_unit": item.freight_cost_per_unit,
        "finished_product_delivered_cost": item.finished_product_delivered_cost,
        "distributor_price": item.distributor_price,
        "wholesale_price": item.wholesale_price,
        "warnings": warnings,
        "created_at": item.created_at,
        "updated_at": item.updated_at,
    }