from datetime import datetime, timezone
from decimal import Decimal

from sqlalchemy import case, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.experiments.registry import EXPERIMENT_REGISTRY
from app.models.experiment import Experiment, ExperimentEvent, ExperimentVariant
from app.schemas.experiments import (
    ExperimentConversionCreate,
    ExperimentDefinitionResponse,
    ExperimentDefinitionUpdate,
    ExperimentEventCreate,
    ExperimentImpressionCreate,
    ExperimentResult,
    ExperimentVariantResult,
)


def experiment_exists(experiment_key: str, variant_key: str) -> bool:
    """Return whether ``variant_key`` is a declared variant of ``experiment_key``.

    Only the in-memory ``EXPERIMENT_REGISTRY`` is consulted; the database is not.
    Returns False when the experiment key is unknown.
    """
    definition = EXPERIMENT_REGISTRY.get(experiment_key)
    if not definition:
        return False
    return any(variant["variant_key"] == variant_key for variant in definition["variants"])


async def sync_experiment_registry(db: AsyncSession) -> None:
    """Create or update database rows for every experiment in ``EXPERIMENT_REGISTRY``.

    - An experiment missing from the database is inserted with all fields
      taken from its registry definition.
    - For an experiment that already exists, only its variants are synced;
      experiment-level fields (cookie_name, name, description, enabled,
      eligible_routes) are left exactly as stored.
    - Each registry variant is inserted if missing; otherwise its label,
      allocation and is_control are overwritten from the registry. Database
      variants that no longer appear in the registry are NOT deleted.

    Changes are flushed but not committed; the caller owns the transaction.
    """
    result = await db.execute(
        select(Experiment).options(selectinload(Experiment.variants))
    )
    existing = {experiment.experiment_key: experiment for experiment in result.scalars().all()}

    for definition in EXPERIMENT_REGISTRY.values():
        experiment = existing.get(definition["experiment_key"])
        existing_variants: dict[str, ExperimentVariant] = {}

        if experiment is None:
            experiment = Experiment(
                experiment_key=definition["experiment_key"],
                cookie_name=definition["cookie_name"],
                name=definition["name"],
                description=definition.get("description"),
                enabled=definition["enabled"],
                eligible_routes=definition["eligible_routes"],
            )
            db.add(experiment)
            # Flush so the new experiment has a primary key for its variants' FK.
            await db.flush()
        else:
            existing_variants = {variant.variant_key: variant for variant in experiment.variants}

        for variant_definition in definition["variants"]:
            variant = existing_variants.get(variant_definition["variant_key"])
            if variant is None:
                db.add(
                    ExperimentVariant(
                        experiment_id=experiment.id,
                        variant_key=variant_definition["variant_key"],
                        label=variant_definition["label"],
                        allocation=variant_definition["allocation"],
                        is_control=variant_definition["is_control"],
                    )
                )
                continue
            variant.label = variant_definition["label"]
            variant.allocation = variant_definition["allocation"]
            variant.is_control = variant_definition["is_control"]

    await db.flush()


async def list_experiment_definitions(db: AsyncSession) -> list[ExperimentDefinitionResponse]:
    """Return every stored experiment, ordered by ``experiment_key``, with its variants.

    Variants are listed in the order the ``Experiment.variants`` relationship
    yields them (no explicit ordering is applied here).
    """
    result = await db.execute(
        select(Experiment).options(selectinload(Experiment.variants)).order_by(Experiment.experiment_key)
    )
    experiments = result.scalars().all()
    return [
        ExperimentDefinitionResponse(
            experiment_key=experiment.experiment_key,
            cookie_name=experiment.cookie_name,
            name=experiment.name,
            description=experiment.description,
            enabled=experiment.enabled,
            eligible_routes=experiment.eligible_routes,
            variants=[
                {
                    "variant_key": variant.variant_key,
                    "label": variant.label,
                    "allocation": variant.allocation,
                    "is_control": variant.is_control,
                }
                for variant in experiment.variants
            ],
        )
        for experiment in experiments
    ]


async def record_experiment_event(
    db: AsyncSession,
    payload: ExperimentImpressionCreate | ExperimentEventCreate | ExperimentConversionCreate,
) -> ExperimentEvent:
    """Persist one experiment event and return the flushed, refreshed row.

    - ``payload.event_name`` is stored as ``event_type``. This assumes every
      accepted payload schema exposes ``event_name``; TODO confirm against
      ``app.schemas.experiments``.
    - ``conversion_value`` is read only if the payload has that attribute;
      otherwise it is stored as None.
    - A naive ``payload.timestamp`` is interpreted as UTC. The stored
      ``created_at`` is always converted to UTC and made naive.

    The event is flushed but not committed; the caller owns the transaction.
    """
    conversion_value = getattr(payload, "conversion_value", None)
    timestamp = payload.timestamp
    if timestamp.tzinfo is None:
        timestamp = timestamp.replace(tzinfo=timezone.utc)

    event = ExperimentEvent(
        experiment_key=payload.experiment_key,
        variant_key=payload.variant_key,
        session_id=payload.session_id,
        user_id=payload.user_id,
        path=payload.path,
        event_type=payload.event_name,
        conversion_value=conversion_value,
        metadata_=payload.metadata,
        created_at=timestamp.astimezone(timezone.utc).replace(tzinfo=None),
    )
    db.add(event)
    await db.flush()
    # Reload server-generated columns (e.g. primary key / defaults) onto the instance.
    await db.refresh(event)
    return event


def _event_type_count(event_type: str, label: str):
    """Build ``SUM(CASE WHEN event_type = <event_type> THEN 1 ELSE 0 END)`` labeled ``label``."""
    return func.sum(case((ExperimentEvent.event_type == event_type, 1), else_=0)).label(label)


async def get_experiment_results(db: AsyncSession, experiment_key: str | None = None) -> list[ExperimentResult]:
    """Aggregate recorded events into per-variant funnel metrics.

    Args:
        db: Async session used to run the aggregate query.
        experiment_key: When truthy, only this experiment is reported. When
            None or empty, all experiments that have events are reported.

    Returns:
        One ``ExperimentResult`` per experiment that has at least one event,
        ordered by experiment_key with variants ordered by variant_key.
        Experiments or variants with no recorded events do not appear.

    ``conversion_rate`` is ``conversions / impressions``, rounded to 4 decimals
    (0.0 when there are no impressions). All results share one ``generated_at``
    UTC timestamp.
    """
    stmt = select(
        ExperimentEvent.experiment_key,
        ExperimentEvent.variant_key,
        _event_type_count("impression", "impressions"),
        _event_type_count("cta_click", "cta_clicks"),
        _event_type_count("form_start", "form_starts"),
        _event_type_count("form_submit", "form_submits"),
        _event_type_count("conversion", "conversions"),
        func.count(func.distinct(ExperimentEvent.session_id)).label("unique_sessions"),
        func.coalesce(func.sum(ExperimentEvent.conversion_value), Decimal("0")).label("conversion_value_total"),
    ).group_by(ExperimentEvent.experiment_key, ExperimentEvent.variant_key).order_by(
        ExperimentEvent.experiment_key,
        ExperimentEvent.variant_key,
    )

    if experiment_key:
        stmt = stmt.where(ExperimentEvent.experiment_key == experiment_key)

    result = await db.execute(stmt)
    rows = result.all()

    grouped: dict[str, list[ExperimentVariantResult]] = {}
    for row in rows:
        impressions = int(row.impressions or 0)
        conversions = int(row.conversions or 0)
        conversion_rate = conversions / impressions if impressions else 0.0
        grouped.setdefault(row.experiment_key, []).append(
            ExperimentVariantResult(
                variant_key=row.variant_key,
                impressions=impressions,
                cta_clicks=int(row.cta_clicks or 0),
                form_starts=int(row.form_starts or 0),
                form_submits=int(row.form_submits or 0),
                conversions=conversions,
                unique_sessions=int(row.unique_sessions or 0),
                conversion_rate=round(conversion_rate, 4),
                conversion_value_total=float(row.conversion_value_total or 0),
            )
        )

    generated_at = datetime.now(timezone.utc)
    return [
        ExperimentResult(
            experiment_key=key,
            generated_at=generated_at,
            variants=variants,
        )
        for key, variants in grouped.items()
    ]


async def get_experiment_definition(
    db: AsyncSession,
    experiment_key: str,
    *,
    populate_existing: bool = False,
) -> Experiment | None:
    """Load one experiment (with its variants) by key, or None if absent.

    Args:
        db: Async session to query.
        experiment_key: Key of the experiment to load.
        populate_existing: When True, overwrite the state of an ``Experiment``
            already in the session's identity map, including its ``variants``
            collection, with freshly loaded database rows. Without it,
            SQLAlchemy leaves an already-loaded collection untouched. That
            collection can be stale after variants were added or deleted
            through separate objects in the same session.
    """
    stmt = (
        select(Experiment)
        .options(selectinload(Experiment.variants))
        .where(Experiment.experiment_key == experiment_key)
    )
    if populate_existing:
        stmt = stmt.execution_options(populate_existing=True)
    result = await db.execute(stmt)
    return result.scalars().first()


async def upsert_experiment_definition(
    db: AsyncSession,
    experiment_key: str,
    payload: ExperimentDefinitionUpdate,
) -> Experiment:
    """Create or fully replace an experiment definition and its variant set.

    - Experiment fields are set from ``payload``.
    - Stored variants whose key is not in ``payload.variants`` are deleted.
    - Matching variants are updated; new keys are inserted.

    Changes are flushed but not committed; the caller owns the transaction.

    Returns:
        The experiment reloaded from the database, so ``variants`` reflects the
        post-upsert state.

    Raises:
        ValueError: if ``payload.cookie_name`` is already used by a different
            experiment.
        RuntimeError: if the experiment cannot be reloaded after flushing
            (should not happen).
    """
    experiment = await get_experiment_definition(db, experiment_key)
    duplicate_cookie = await db.execute(
        select(Experiment).where(
            Experiment.cookie_name == payload.cookie_name,
            Experiment.experiment_key != experiment_key,
        )
    )
    if duplicate_cookie.scalars().first():
        raise ValueError("cookie_name is already used by another experiment")

    if experiment is None:
        experiment = Experiment(
            experiment_key=experiment_key,
            cookie_name=payload.cookie_name,
            name=payload.name,
            description=payload.description,
            enabled=payload.enabled,
            eligible_routes=payload.eligible_routes,
        )
        db.add(experiment)
        # Flush so the new experiment has a primary key for its variants' FK.
        await db.flush()
        existing_variants: dict[str, ExperimentVariant] = {}
    else:
        experiment.cookie_name = payload.cookie_name
        experiment.name = payload.name
        experiment.description = payload.description
        experiment.enabled = payload.enabled
        experiment.eligible_routes = payload.eligible_routes
        existing_variants = {variant.variant_key: variant for variant in experiment.variants}

    incoming_keys = {variant.variant_key for variant in payload.variants}
    for variant in list(existing_variants.values()):
        if variant.variant_key not in incoming_keys:
            await db.delete(variant)

    for variant_payload in payload.variants:
        variant = existing_variants.get(variant_payload.variant_key)
        if variant is None:
            db.add(
                ExperimentVariant(
                    experiment_id=experiment.id,
                    variant_key=variant_payload.variant_key,
                    label=variant_payload.label,
                    allocation=variant_payload.allocation,
                    is_control=variant_payload.is_control,
                )
            )
            continue
        variant.label = variant_payload.label
        variant.allocation = variant_payload.allocation
        variant.is_control = variant_payload.is_control

    await db.flush()
    # The experiment is still in the identity map with the variants collection
    # loaded by the first lookup. Deleted variants would remain in it, and
    # variants inserted via experiment_id would be missing. Force a refresh.
    refreshed = await get_experiment_definition(db, experiment_key, populate_existing=True)
    if refreshed is None:
        raise RuntimeError(f"experiment {experiment_key!r} vanished after upsert flush")
    return refreshed