"""Cost calculations for raw materials, mixes and finished products.

Costs are derived from each raw material's active price version, the mix
recipe, and the process / packaging / freight cost rules (filtered by tenant
when one is known).  The calculators accept an optional ``overrides`` mapping
whose values take precedence over the stored prices, rules and margins for
that calculation only; nothing in this module writes to the database.
"""

from __future__ import annotations

import re
from dataclasses import asdict, dataclass

from sqlalchemy import select
from sqlalchemy.orm import Session, selectinload

from app.models.assumption import FreightCostRule, PackagingCostRule, ProcessCostRule
from app.models.mix import Mix, MixIngredient
from app.models.product import Product
from app.models.raw_material import RawMaterial, RawMaterialPriceVersion

# First number immediately followed by "kg" in a unit label, e.g. "25kg bag",
# "2.5 kg" or ".5kg".  The leading-digit part is optional so ".5kg" parses as
# 0.5 rather than matching only the "5".
_KG_QUANTITY_PATTERN = re.compile(r"(\d*\.?\d+)\s*kg")


@dataclass
class PriceComputation:
    """Derived costs for one raw-material price, each rounded to 4 decimals."""

    loss_cost: float  # market_value * waste_percentage
    cost_per_unit: float  # market_value + loss_cost
    cost_per_kg: float  # cost_per_unit / kg_per_unit


def _price_breakdown(market_value: float, waste_percentage: float, kg_per_unit: float) -> PriceComputation:
    """Compute loss, per-unit and per-kg cost from raw price inputs.

    ``waste_percentage`` is applied as a fraction of ``market_value`` (0.05 adds
    5% to the unit cost).  Intermediate values keep full precision; only the
    returned figures are rounded to 4 decimals.

    Raises:
        ZeroDivisionError: if ``kg_per_unit`` is 0.
    """
    loss_cost = market_value * waste_percentage
    cost_per_unit = market_value + loss_cost
    cost_per_kg = cost_per_unit / kg_per_unit
    return PriceComputation(
        loss_cost=round(loss_cost, 4),
        cost_per_unit=round(cost_per_unit, 4),
        cost_per_kg=round(cost_per_kg, 4),
    )


def calculate_raw_material_cost(raw_material: RawMaterial, price: RawMaterialPriceVersion) -> PriceComputation:
    """Return the cost breakdown of ``price`` for ``raw_material``.

    Only ``price.market_value``, ``price.waste_percentage`` and
    ``raw_material.kg_per_unit`` are read.
    """
    return _price_breakdown(price.market_value, price.waste_percentage, raw_material.kg_per_unit)


def get_active_price(raw_material: RawMaterial) -> RawMaterialPriceVersion | None:
    """Return the ``"active"`` price version with the latest ``effective_date``.

    Returns ``None`` when the raw material has no active price version.  On a
    tie, the version listed first in ``raw_material.price_versions`` wins.
    """
    return max(
        (price for price in raw_material.price_versions if price.status == "active"),
        key=lambda price: price.effective_date,
        default=None,
    )


def calculate_mix_cost(db: Session, mix_id: int, overrides: dict | None = None) -> dict:
    """Cost a mix from its ingredients' active raw-material prices.

    Args:
        db: Session used to load the mix together with its ingredients, their
            raw materials and those raw materials' price versions.
        mix_id: Primary key of the mix to cost.
        overrides: Optional replacement inputs.  Reads
            ``"raw_material_market_values"`` and
            ``"raw_material_waste_percentages"``, each a mapping from
            ``str(raw_material.id)`` to the value to use instead of the one on
            the active price version.

    Returns:
        A dict with the mix's descriptive fields, one entry per ingredient under
        ``"ingredients"`` (``cost_per_kg``/``line_cost`` are ``None`` when the
        raw material has no active price), ``total_mix_kg``,
        ``total_mix_cost``, ``mix_cost_per_kg`` (``None`` when the total
        weight is zero) and a list of human-readable ``warnings``.

    Raises:
        ValueError: if no mix with ``mix_id`` exists.
    """
    overrides = overrides or {}
    mix = db.scalar(
        select(Mix)
        .where(Mix.id == mix_id)
        .options(
            selectinload(Mix.ingredients)
            .selectinload(MixIngredient.raw_material)
            .selectinload(RawMaterial.price_versions)
        )
    )
    if mix is None:
        raise ValueError(f"Mix {mix_id} not found")

    market_value_overrides = overrides.get("raw_material_market_values", {})
    waste_overrides = overrides.get("raw_material_waste_percentages", {})

    total_mix_kg = 0.0
    total_mix_cost = 0.0
    warnings: list[str] = []
    lines: list[dict] = []

    for ingredient in mix.ingredients:
        raw_material = ingredient.raw_material
        active_price = get_active_price(raw_material)

        # Unpriced ingredients still count toward the total weight but add no
        # cost, which dilutes mix_cost_per_kg; the warning flags that case.
        total_mix_kg += ingredient.quantity_kg

        if active_price is None:
            warnings.append(f"{raw_material.name} has no active price")
            cost_per_kg = None
            line_cost = None
        else:
            override_key = str(raw_material.id)
            price_comp = _price_breakdown(
                market_value_overrides.get(override_key, active_price.market_value),
                waste_overrides.get(override_key, active_price.waste_percentage),
                raw_material.kg_per_unit,
            )
            cost_per_kg = price_comp.cost_per_kg
            line_cost = round(ingredient.quantity_kg * cost_per_kg, 4)
            total_mix_cost += line_cost

        lines.append(
            {
                "id": ingredient.id,
                "raw_material_id": raw_material.id,
                "raw_material_name": raw_material.name,
                "quantity_kg": ingredient.quantity_kg,
                "cost_per_kg": cost_per_kg,
                "line_cost": line_cost,
                "notes": ingredient.notes,
            }
        )

    if total_mix_kg == 0:
        warnings.append("Mix total kg is zero")
        mix_cost_per_kg = None
    else:
        mix_cost_per_kg = round(total_mix_cost / total_mix_kg, 4)
    if not mix.ingredients:
        warnings.append("Mix has no ingredients")

    return {
        "id": mix.id,
        "tenant_id": mix.tenant_id,
        "client_name": mix.client_name,
        "name": mix.name,
        "status": mix.status,
        "version": mix.version,
        "notes": mix.notes,
        "created_at": mix.created_at,
        "ingredients": lines,
        "total_mix_kg": round(total_mix_kg, 4),
        "total_mix_cost": round(total_mix_cost, 4),
        "mix_cost_per_kg": mix_cost_per_kg,
        "warnings": warnings,
    }


def _get_process_costs(db: Session, process_name: str | None, overrides: dict) -> tuple[float, float, float, list[str]]:
    """Look up grading, bagging and cracking costs for ``process_name``.

    The first ``ProcessCostRule`` with that name is used, restricted to
    ``overrides["tenant_id"]`` when it is truthy.  Values under
    ``overrides["process_costs"][process_name]`` replace the rule's fields.

    Returns:
        ``(grading, bagging, cracking, warnings)``; all costs are 0.0 with a
        warning when ``process_name`` is empty or no rule matches.
    """
    if not process_name:
        return 0.0, 0.0, 0.0, ["Missing bagging process"]

    tenant_id = overrides.get("tenant_id")
    query = select(ProcessCostRule).where(ProcessCostRule.process_name == process_name)
    if tenant_id:
        query = query.where(ProcessCostRule.tenant_id == tenant_id)
    rule = db.scalar(query)
    if rule is None:
        return 0.0, 0.0, 0.0, [f"Process rule not found for {process_name}"]

    override_costs = overrides.get("process_costs", {}).get(process_name, {})
    grading = override_costs.get("grading_cost", rule.grading_cost)
    bagging = override_costs.get("bagging_cost", rule.bagging_cost)
    cracking = override_costs.get("cracking_cost", rule.cracking_cost)
    return grading, bagging, cracking, []


def _get_packaging_cost(db: Session, product: Product, overrides: dict) -> tuple[float, list[str]]:
    """Return ``(bag_cost, warnings)`` for ``product``.

    The cost is 0.0 when ``product.own_bag`` is truthy.  Otherwise the first
    ``PackagingCostRule`` matching the product's sale type, unit of measure,
    ``own_bag`` flag and (if set) tenant is used; its ``bag_cost`` can be
    replaced via ``overrides["packaging_costs"][str(rule.id)]``.
    """
    if product.own_bag:
        return 0.0, []

    query = select(PackagingCostRule).where(
        PackagingCostRule.sale_type == product.sale_type,
        PackagingCostRule.unit_of_measure == product.unit_of_measure,
        PackagingCostRule.own_bag == product.own_bag,
    )
    if product.tenant_id:
        query = query.where(PackagingCostRule.tenant_id == product.tenant_id)
    rule = db.scalar(query)
    if rule is None:
        return 0.0, ["Packaging rule not found"]
    return overrides.get("packaging_costs", {}).get(str(rule.id), rule.bag_cost), []


def _get_freight_cost(db: Session, product: Product, overrides: dict) -> tuple[float, list[str]]:
    """Return ``(freight_cost, warnings)`` for ``product``.

    Uses the first ``FreightCostRule`` matching the product's sale type, unit of
    measure and (if set) tenant; its ``cost_per_unit`` can be replaced via
    ``overrides["freight_costs"][str(rule.id)]``.
    """
    query = select(FreightCostRule).where(
        FreightCostRule.sale_type == product.sale_type,
        FreightCostRule.unit_of_measure == product.unit_of_measure,
    )
    if product.tenant_id:
        query = query.where(FreightCostRule.tenant_id == product.tenant_id)
    rule = db.scalar(query)
    if rule is None:
        return 0.0, ["Freight rule not found"]
    return overrides.get("freight_costs", {}).get(str(rule.id), rule.cost_per_unit), []


def _apply_margin(cost: float, margin: float | None) -> float | None:
    """Return the selling price that yields ``margin`` as a fraction of price.

    Computes ``cost / (1 - margin)`` rounded to 4 decimals, or ``None`` when
    ``margin`` is ``None``.

    Raises:
        ValueError: if ``margin`` is 1 or more (the price would be infinite or
            negative).
    """
    if margin is None:
        return None
    if margin >= 1:
        raise ValueError("Margin must be lower than 1")
    return round(cost / (1 - margin), 4)


def _extract_unit_quantity_kg(unit_of_measure: str) -> float:
    """Return how many kg one sale unit described by ``unit_of_measure`` holds.

    ``"tonne"`` -> 1000, ``"kg"`` -> 1, otherwise the first ``<number> kg``
    found in the label (e.g. ``"25kg bag"`` -> 25).  Unrecognised labels fall
    back to 1.0.
    """
    normalized = unit_of_measure.strip().lower()
    if normalized == "tonne":
        return 1000.0
    if normalized == "kg":
        return 1.0
    match = _KG_QUANTITY_PATTERN.search(normalized)
    if match:
        return float(match.group(1))
    # NOTE(review): labels such as "tonnes" or "500g" land here and are
    # treated as 1 kg — confirm no such units exist in product data.
    return 1.0


def calculate_product_cost(db: Session, product_id: int, overrides: dict | None = None) -> dict:
    """Build the full delivered cost and margin prices for one product.

    The product's mix is costed via ``calculate_mix_cost``.  The mix cost per
    kg, grading, bagging and cracking costs are scaled by the kg in one sale
    unit.  Bag and freight costs are then added unscaled, and the distributor
    and wholesale margins are applied to the delivered total.

    Args:
        db: Session used for all lookups.
        product_id: Primary key of the product to cost.
        overrides: Optional replacement inputs, forwarded to the mix and rule
            lookups.  ``overrides["product_margins"][str(product.id)]`` may
            carry ``distributor_margin`` / ``wholesale_margin``.  Any caller
            ``tenant_id`` is replaced by the product's own tenant.  The
            caller's dict is not modified.

    Returns:
        A dict with every cost component, the delivered cost, both margin
        prices, accumulated ``warnings`` and the ``inputs`` used.

    Raises:
        ValueError: if the product or its mix does not exist, or a margin is
            >= 1.
    """
    product = db.scalar(select(Product).where(Product.id == product_id).options(selectinload(Product.mix)))
    if product is None:
        raise ValueError(f"Product {product_id} not found")

    # Work on a copy so the caller's dict is never mutated; rule lookups are
    # always scoped to the product's own tenant.
    overrides = {**(overrides or {}), "tenant_id": product.tenant_id}

    mix_result = calculate_mix_cost(db, product.mix_id, overrides=overrides)
    warnings = list(mix_result["warnings"])

    sale_unit_kg = _extract_unit_quantity_kg(product.unit_of_measure)
    mix_cost_per_kg = mix_result["mix_cost_per_kg"] or 0.0
    cleaned_product_cost = round(mix_cost_per_kg * sale_unit_kg, 4)

    grading_cost, bagging_cost, cracking_cost, process_warnings = _get_process_costs(
        db, product.bagging_process, overrides
    )
    warnings.extend(process_warnings)
    # Process rule costs are scaled by the kg in one sale unit.
    grading_cost = round(grading_cost * sale_unit_kg, 4)
    bagging_cost = round(bagging_cost * sale_unit_kg, 4)
    cracking_cost = round(cracking_cost * sale_unit_kg, 4)

    bag_cost, packaging_warnings = _get_packaging_cost(db, product, overrides)
    warnings.extend(packaging_warnings)
    freight_cost, freight_warnings = _get_freight_cost(db, product, overrides)
    warnings.extend(freight_warnings)

    finished_product_delivered = round(
        cleaned_product_cost + grading_cost + bagging_cost + cracking_cost + bag_cost + freight_cost,
        4,
    )

    margin_overrides = overrides.get("product_margins", {}).get(str(product.id), {})
    distributor_margin = margin_overrides.get("distributor_margin", product.distributor_margin)
    wholesale_margin = margin_overrides.get("wholesale_margin", product.wholesale_margin)

    return {
        "product_id": product.id,
        "product_name": product.name,
        "cleaned_product_cost": round(cleaned_product_cost, 4),
        "grading_cost": round(grading_cost, 4),
        "bagging_cost": round(bagging_cost, 4),
        "cracking_cost": round(cracking_cost, 4),
        "bag_cost": round(bag_cost, 4),
        "freight_cost": round(freight_cost, 4),
        "finished_product_delivered": finished_product_delivered,
        "distributor_price": _apply_margin(finished_product_delivered, distributor_margin),
        "wholesale_price": _apply_margin(finished_product_delivered, wholesale_margin),
        "warnings": warnings,
        "inputs": {
            "mix": {
                "mix_id": product.mix_id,
                "mix_name": product.mix.name,
                "total_mix_kg": mix_result["total_mix_kg"],
                "total_mix_cost": mix_result["total_mix_cost"],
                "mix_cost_per_kg": mix_result["mix_cost_per_kg"],
                "sale_unit_kg": sale_unit_kg,
            },
            "product": {
                "sale_type": product.sale_type,
                "own_bag": product.own_bag,
                "unit_of_measure": product.unit_of_measure,
                "items_per_pallet": product.items_per_pallet,
                "bagging_process": product.bagging_process,
                "distributor_margin": distributor_margin,
                "wholesale_margin": wholesale_margin,
            },
        },
    }


def serialize_raw_material(raw_material: RawMaterial) -> dict:
    """Serialize ``raw_material`` with its active price and derived costs.

    ``current_price`` is ``None`` when there is no active price version.
    Otherwise it holds the price version's fields merged with the
    ``PriceComputation`` fields (``loss_cost``, ``cost_per_unit``,
    ``cost_per_kg``).
    """
    active_price = get_active_price(raw_material)
    current_price = None
    if active_price is not None:
        price_comp = calculate_raw_material_cost(raw_material, active_price)
        current_price = {
            "id": active_price.id,
            "market_value": active_price.market_value,
            "waste_percentage": active_price.waste_percentage,
            "effective_date": active_price.effective_date,
            "status": active_price.status,
            "notes": active_price.notes,
            "created_at": active_price.created_at,
            **asdict(price_comp),
        }
    return {
        "id": raw_material.id,
        "tenant_id": raw_material.tenant_id,
        "name": raw_material.name,
        "supplier": raw_material.supplier,
        "unit_of_measure": raw_material.unit_of_measure,
        "kg_per_unit": raw_material.kg_per_unit,
        "status": raw_material.status,
        "notes": raw_material.notes,
        "created_at": raw_material.created_at,
        "current_price": current_price,
    }