Files
gw/backend/app/services/experiments.py
T

252 lines
9.4 KiB
Python
Raw Normal View History

2026-04-18 07:23:55 +12:00
from datetime import datetime, timezone
from decimal import Decimal
from sqlalchemy import case, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.experiments.registry import EXPERIMENT_REGISTRY
from app.models.experiment import Experiment, ExperimentEvent, ExperimentVariant
from app.schemas.experiments import (
ExperimentConversionCreate,
ExperimentDefinitionResponse,
ExperimentDefinitionUpdate,
ExperimentEventCreate,
ExperimentImpressionCreate,
ExperimentResult,
ExperimentVariantResult,
)
def experiment_exists(experiment_key: str, variant_key: str) -> bool:
    """Report whether the registry defines the given experiment/variant pair.

    Returns ``False`` when ``experiment_key`` has no (truthy) entry in
    ``EXPERIMENT_REGISTRY``. Otherwise returns whether any item of that
    entry's ``"variants"`` has a ``"variant_key"`` equal to ``variant_key``.
    """
    entry = EXPERIMENT_REGISTRY.get(experiment_key)
    if entry:
        for candidate in entry["variants"]:
            if candidate["variant_key"] == variant_key:
                return True
    return False
async def sync_experiment_registry(db: AsyncSession) -> None:
    """Bring stored experiments in line with ``EXPERIMENT_REGISTRY``.

    Experiments that are in the registry but not in the database are inserted
    together with their variants. For experiments that already exist only the
    variants are touched: missing ones are added and existing ones have their
    ``label``, ``allocation`` and ``is_control`` overwritten from the
    registry. Experiment-level columns of existing rows are left as stored,
    and nothing is deleted.

    Changes are flushed to ``db``; this function does not commit.
    """
    query = select(Experiment).options(selectinload(Experiment.variants))
    stored = (await db.execute(query)).scalars().all()
    by_key = {row.experiment_key: row for row in stored}

    for spec in EXPERIMENT_REGISTRY.values():
        record = by_key.get(spec["experiment_key"])
        if record is not None:
            known = {item.variant_key: item for item in record.variants}
        else:
            known: dict[str, ExperimentVariant] = {}
            record = Experiment(
                experiment_key=spec["experiment_key"],
                cookie_name=spec["cookie_name"],
                name=spec["name"],
                description=spec.get("description"),
                enabled=spec["enabled"],
                eligible_routes=spec["eligible_routes"],
            )
            db.add(record)
            # Flush before the variants below read ``record.id``
            # (presumably a database-generated key -- confirm on the model).
            await db.flush()

        for variant_spec in spec["variants"]:
            current = known.get(variant_spec["variant_key"])
            if current is not None:
                current.label = variant_spec["label"]
                current.allocation = variant_spec["allocation"]
                current.is_control = variant_spec["is_control"]
            else:
                new_variant = ExperimentVariant(
                    experiment_id=record.id,
                    variant_key=variant_spec["variant_key"],
                    label=variant_spec["label"],
                    allocation=variant_spec["allocation"],
                    is_control=variant_spec["is_control"],
                )
                db.add(new_variant)

    await db.flush()
async def list_experiment_definitions(db: AsyncSession) -> list[ExperimentDefinitionResponse]:
    """Return every stored experiment, ordered by ``experiment_key``.

    Each experiment is converted to an ``ExperimentDefinitionResponse`` whose
    ``variants`` is a list of plain dicts with the keys ``variant_key``,
    ``label``, ``allocation`` and ``is_control``.
    """
    query = (
        select(Experiment)
        .options(selectinload(Experiment.variants))
        .order_by(Experiment.experiment_key)
    )
    rows = (await db.execute(query)).scalars().all()

    definitions: list[ExperimentDefinitionResponse] = []
    for row in rows:
        variant_dicts = []
        for item in row.variants:
            variant_dicts.append(
                {
                    "variant_key": item.variant_key,
                    "label": item.label,
                    "allocation": item.allocation,
                    "is_control": item.is_control,
                }
            )
        definitions.append(
            ExperimentDefinitionResponse(
                experiment_key=row.experiment_key,
                cookie_name=row.cookie_name,
                name=row.name,
                description=row.description,
                enabled=row.enabled,
                eligible_routes=row.eligible_routes,
                variants=variant_dicts,
            )
        )
    return definitions
async def record_experiment_event(
    db: AsyncSession,
    payload: ExperimentImpressionCreate | ExperimentEventCreate | ExperimentConversionCreate,
) -> ExperimentEvent:
    """Store one experiment event built from ``payload`` and return it.

    ``conversion_value`` is taken from the payload when it has that
    attribute, otherwise ``None``. The payload timestamp is treated as UTC
    when it carries no tzinfo, converted to UTC, and stored without tzinfo in
    ``created_at``.

    The event is added to ``db``, flushed and refreshed; this function does
    not commit.
    """
    value = getattr(payload, "conversion_value", None)

    moment = payload.timestamp
    if moment.tzinfo is None:
        # Naive timestamps are interpreted as already being in UTC.
        moment = moment.replace(tzinfo=timezone.utc)
    utc_moment = moment.astimezone(timezone.utc)
    naive_utc = utc_moment.replace(tzinfo=None)

    record = ExperimentEvent(
        experiment_key=payload.experiment_key,
        variant_key=payload.variant_key,
        session_id=payload.session_id,
        user_id=payload.user_id,
        path=payload.path,
        event_type=payload.event_name,
        conversion_value=value,
        metadata_=payload.metadata,
        created_at=naive_utc,
    )
    db.add(record)
    await db.flush()
    await db.refresh(record)
    return record
async def get_experiment_results(db: AsyncSession, experiment_key: str | None = None) -> list[ExperimentResult]:
    """Aggregate stored events into per-variant results, grouped by experiment.

    For every (experiment_key, variant_key) pair the query counts events of
    type ``impression``, ``cta_click``, ``form_start``, ``form_submit`` and
    ``conversion``, the number of distinct ``session_id`` values, and the sum
    of ``conversion_value`` (0 when there is none).

    ``conversion_rate`` is conversions divided by impressions, rounded to four
    decimals, or ``0.0`` when there are no impressions.

    Args:
        db: Session used to run the aggregate query.
        experiment_key: When truthy, restrict results to this experiment.

    Returns:
        One ``ExperimentResult`` per experiment key found, all sharing the
        same ``generated_at`` timestamp (current UTC time).
    """

    def tally(event_type: str, label: str):
        # SUM(CASE WHEN event_type = <event_type> THEN 1 ELSE 0 END) AS <label>
        matches = case((ExperimentEvent.event_type == event_type, 1), else_=0)
        return func.sum(matches).label(label)

    query = select(
        ExperimentEvent.experiment_key,
        ExperimentEvent.variant_key,
        tally("impression", "impressions"),
        tally("cta_click", "cta_clicks"),
        tally("form_start", "form_starts"),
        tally("form_submit", "form_submits"),
        tally("conversion", "conversions"),
        func.count(func.distinct(ExperimentEvent.session_id)).label("unique_sessions"),
        func.coalesce(func.sum(ExperimentEvent.conversion_value), Decimal("0")).label("conversion_value_total"),
    )
    if experiment_key:
        query = query.where(ExperimentEvent.experiment_key == experiment_key)
    query = query.group_by(
        ExperimentEvent.experiment_key,
        ExperimentEvent.variant_key,
    ).order_by(
        ExperimentEvent.experiment_key,
        ExperimentEvent.variant_key,
    )

    rows = (await db.execute(query)).all()

    per_experiment: dict[str, list[ExperimentVariantResult]] = {}
    for row in rows:
        shown = int(row.impressions or 0)
        converted = int(row.conversions or 0)
        if shown:
            rate = converted / shown
        else:
            rate = 0.0

        bucket = per_experiment.get(row.experiment_key)
        if bucket is None:
            bucket = []
            per_experiment[row.experiment_key] = bucket
        bucket.append(
            ExperimentVariantResult(
                variant_key=row.variant_key,
                impressions=shown,
                cta_clicks=int(row.cta_clicks or 0),
                form_starts=int(row.form_starts or 0),
                form_submits=int(row.form_submits or 0),
                conversions=converted,
                unique_sessions=int(row.unique_sessions or 0),
                conversion_rate=round(rate, 4),
                conversion_value_total=float(row.conversion_value_total or 0),
            )
        )

    stamp = datetime.now(timezone.utc)
    results: list[ExperimentResult] = []
    for key, variant_results in per_experiment.items():
        results.append(
            ExperimentResult(
                experiment_key=key,
                generated_at=stamp,
                variants=variant_results,
            )
        )
    return results
async def get_experiment_definition(db: AsyncSession, experiment_key: str) -> Experiment | None:
    """Fetch the experiment with the given key, with its variants loaded.

    Returns the first matching ``Experiment`` or ``None`` when no row has
    that ``experiment_key``.
    """
    query = select(Experiment).where(Experiment.experiment_key == experiment_key)
    query = query.options(selectinload(Experiment.variants))
    outcome = await db.execute(query)
    return outcome.scalars().first()
async def upsert_experiment_definition(
    db: AsyncSession,
    experiment_key: str,
    payload: ExperimentDefinitionUpdate,
) -> Experiment:
    """Create or update an experiment and make its variants match ``payload``.

    The experiment identified by ``experiment_key`` is created when missing,
    otherwise its ``cookie_name``, ``name``, ``description``, ``enabled`` and
    ``eligible_routes`` are overwritten. Stored variants whose key is absent
    from ``payload.variants`` are deleted, variants with a matching key are
    updated in place, and the rest are inserted.

    Changes are flushed to ``db``; this function does not commit.

    Args:
        db: Session used for all reads and writes.
        experiment_key: Key of the experiment to create or update.
        payload: Desired experiment fields and variant list.

    Returns:
        The experiment re-read from the database after the flush, with a
        freshly loaded ``variants`` collection.

    Raises:
        ValueError: If another experiment already uses ``payload.cookie_name``.
        RuntimeError: If the experiment cannot be read back after the flush.
    """
    experiment = await get_experiment_definition(db, experiment_key)

    duplicate_cookie = await db.execute(
        select(Experiment).where(
            Experiment.cookie_name == payload.cookie_name,
            Experiment.experiment_key != experiment_key,
        )
    )
    if duplicate_cookie.scalars().first():
        raise ValueError("cookie_name is already used by another experiment")

    existing_variants: dict[str, ExperimentVariant] = {}
    if experiment is None:
        experiment = Experiment(
            experiment_key=experiment_key,
            cookie_name=payload.cookie_name,
            name=payload.name,
            description=payload.description,
            enabled=payload.enabled,
            eligible_routes=payload.eligible_routes,
        )
        db.add(experiment)
        # Flush before the new variants below read ``experiment.id``.
        await db.flush()
    else:
        experiment.cookie_name = payload.cookie_name
        experiment.name = payload.name
        experiment.description = payload.description
        experiment.enabled = payload.enabled
        experiment.eligible_routes = payload.eligible_routes
        existing_variants = {variant.variant_key: variant for variant in experiment.variants}

    # Drop stored variants that the payload no longer lists.
    incoming_keys = {variant.variant_key for variant in payload.variants}
    for variant_key, variant in existing_variants.items():
        if variant_key not in incoming_keys:
            await db.delete(variant)

    for variant_payload in payload.variants:
        variant = existing_variants.get(variant_payload.variant_key)
        if variant is None:
            db.add(
                ExperimentVariant(
                    experiment_id=experiment.id,
                    variant_key=variant_payload.variant_key,
                    label=variant_payload.label,
                    allocation=variant_payload.allocation,
                    is_control=variant_payload.is_control,
                )
            )
            continue
        variant.label = variant_payload.label
        variant.allocation = variant_payload.allocation
        variant.is_control = variant_payload.is_control

    await db.flush()

    # The experiment is already in the session's identity map with a loaded
    # ``variants`` collection, and a plain re-query would hand back that same
    # object without replacing the collection. ``populate_existing`` forces
    # the row and its eager-loaded variants to be overwritten from the
    # database, so deletions and insertions made above are reflected.
    reloaded = await db.execute(
        select(Experiment)
        .options(selectinload(Experiment.variants))
        .where(Experiment.experiment_key == experiment_key)
        .execution_options(populate_existing=True)
    )
    refreshed = reloaded.scalars().first()
    if refreshed is None:
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        raise RuntimeError(
            f"experiment {experiment_key!r} could not be reloaded after upsert"
        )
    return refreshed