Files
gw/backend/app/services/experiments.py
T
ponzischeme89 6d44e05de4 v1
2026-04-18 07:23:55 +12:00

252 lines
9.4 KiB
Python

from datetime import datetime, timezone
from decimal import Decimal
from sqlalchemy import case, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.experiments.registry import EXPERIMENT_REGISTRY
from app.models.experiment import Experiment, ExperimentEvent, ExperimentVariant
from app.schemas.experiments import (
ExperimentConversionCreate,
ExperimentDefinitionResponse,
ExperimentDefinitionUpdate,
ExperimentEventCreate,
ExperimentImpressionCreate,
ExperimentResult,
ExperimentVariantResult,
)
def experiment_exists(experiment_key: str, variant_key: str) -> bool:
    """Return True when ``variant_key`` is declared for ``experiment_key``.

    The lookup is made against ``EXPERIMENT_REGISTRY``; an experiment key that
    is missing from the registry (or maps to a falsy entry) yields False.
    """
    registered = EXPERIMENT_REGISTRY.get(experiment_key)
    if not registered:
        return False
    for entry in registered["variants"]:
        if entry["variant_key"] == variant_key:
            return True
    return False
async def sync_experiment_registry(db: AsyncSession) -> None:
    """Reconcile the database with the in-code ``EXPERIMENT_REGISTRY``.

    Experiments that are not yet stored are inserted with every field taken
    from their registry definition. Experiments that already exist keep their
    stored experiment-level fields untouched; only their variants are
    reconciled:

    - variants missing from the database are inserted;
    - existing ones get ``label``, ``allocation`` and ``is_control`` overwritten
      from the registry.

    Stored variants that the registry no longer lists are left in place.
    Changes are flushed, not committed.
    """
    loaded = await db.execute(
        select(Experiment).options(selectinload(Experiment.variants))
    )
    by_key = {row.experiment_key: row for row in loaded.scalars().all()}

    for spec in EXPERIMENT_REGISTRY.values():
        stored = by_key.get(spec["experiment_key"])
        if stored is not None:
            current_variants: dict[str, ExperimentVariant] = {
                v.variant_key: v for v in stored.variants
            }
        else:
            stored = Experiment(
                experiment_key=spec["experiment_key"],
                cookie_name=spec["cookie_name"],
                name=spec["name"],
                description=spec.get("description"),
                enabled=spec["enabled"],
                eligible_routes=spec["eligible_routes"],
            )
            db.add(stored)
            # Flush before creating variants: they reference ``stored.id``,
            # which is presumably assigned by the database on insert.
            await db.flush()
            current_variants = {}

        for variant_spec in spec["variants"]:
            found = current_variants.get(variant_spec["variant_key"])
            if found is not None:
                found.label = variant_spec["label"]
                found.allocation = variant_spec["allocation"]
                found.is_control = variant_spec["is_control"]
            else:
                db.add(
                    ExperimentVariant(
                        experiment_id=stored.id,
                        variant_key=variant_spec["variant_key"],
                        label=variant_spec["label"],
                        allocation=variant_spec["allocation"],
                        is_control=variant_spec["is_control"],
                    )
                )

    await db.flush()
async def list_experiment_definitions(db: AsyncSession) -> list[ExperimentDefinitionResponse]:
    """Return every stored experiment as a response model, ordered by key.

    Variants are eagerly loaded and serialized as plain dicts with the keys
    ``variant_key``, ``label``, ``allocation`` and ``is_control``.
    """
    query = (
        select(Experiment)
        .options(selectinload(Experiment.variants))
        .order_by(Experiment.experiment_key)
    )
    rows = (await db.execute(query)).scalars().all()

    responses: list[ExperimentDefinitionResponse] = []
    for row in rows:
        variant_payloads = [
            {
                "variant_key": v.variant_key,
                "label": v.label,
                "allocation": v.allocation,
                "is_control": v.is_control,
            }
            for v in row.variants
        ]
        responses.append(
            ExperimentDefinitionResponse(
                experiment_key=row.experiment_key,
                cookie_name=row.cookie_name,
                name=row.name,
                description=row.description,
                enabled=row.enabled,
                eligible_routes=row.eligible_routes,
                variants=variant_payloads,
            )
        )
    return responses
async def record_experiment_event(
    db: AsyncSession,
    payload: ExperimentImpressionCreate | ExperimentEventCreate | ExperimentConversionCreate,
) -> ExperimentEvent:
    """Persist a single experiment event and return it refreshed from the DB.

    ``payload.timestamp`` is treated as UTC when it has no tzinfo. In every
    case it is converted to UTC and stored as a naive datetime in
    ``created_at``. ``payload.event_name`` is stored as ``event_type``.
    Payload types without a ``conversion_value`` attribute store ``None``
    there. The event is flushed, not committed.
    """
    occurred_at = payload.timestamp
    if occurred_at.tzinfo is None:
        # Naive timestamps are assumed to already be UTC.
        occurred_at = occurred_at.replace(tzinfo=timezone.utc)
    created_at_utc = occurred_at.astimezone(timezone.utc).replace(tzinfo=None)

    event = ExperimentEvent(
        experiment_key=payload.experiment_key,
        variant_key=payload.variant_key,
        session_id=payload.session_id,
        user_id=payload.user_id,
        path=payload.path,
        event_type=payload.event_name,
        conversion_value=getattr(payload, "conversion_value", None),
        metadata_=payload.metadata,
        created_at=created_at_utc,
    )
    db.add(event)
    await db.flush()
    await db.refresh(event)
    return event
async def get_experiment_results(db: AsyncSession, experiment_key: str | None = None) -> list[ExperimentResult]:
    """Aggregate stored events into per-variant funnel results.

    Events are grouped by ``(experiment_key, variant_key)``. For each group the
    query computes:

    - counts of events whose ``event_type`` is ``impression``, ``cta_click``,
      ``form_start``, ``form_submit`` or ``conversion``;
    - the number of distinct ``session_id`` values;
    - the sum of ``conversion_value`` (0 when there is none).

    ``conversion_rate`` is conversions / impressions rounded to 4 decimal
    places, or 0.0 when the group has no impressions.

    Args:
        db: Session used to run the aggregate query.
        experiment_key: When truthy, only this experiment is reported. ``None``
            or an empty string reports every experiment that has events.

    Returns:
        One ``ExperimentResult`` per experiment key with events, in key order,
        all stamped with the same UTC ``generated_at``.
    """

    def tally(event_type: str, label: str):
        # SUM over a 0/1 CASE yields a per-group count of one event type.
        return func.sum(case((ExperimentEvent.event_type == event_type, 1), else_=0)).label(label)

    query = (
        select(
            ExperimentEvent.experiment_key,
            ExperimentEvent.variant_key,
            tally("impression", "impressions"),
            tally("cta_click", "cta_clicks"),
            tally("form_start", "form_starts"),
            tally("form_submit", "form_submits"),
            tally("conversion", "conversions"),
            func.count(func.distinct(ExperimentEvent.session_id)).label("unique_sessions"),
            func.coalesce(func.sum(ExperimentEvent.conversion_value), Decimal("0")).label("conversion_value_total"),
        )
        .group_by(ExperimentEvent.experiment_key, ExperimentEvent.variant_key)
        .order_by(ExperimentEvent.experiment_key, ExperimentEvent.variant_key)
    )
    if experiment_key:
        query = query.where(ExperimentEvent.experiment_key == experiment_key)

    rows = (await db.execute(query)).all()

    variants_by_experiment: dict[str, list[ExperimentVariantResult]] = {}
    for row in rows:
        impression_count = int(row.impressions or 0)
        conversion_count = int(row.conversions or 0)
        rate = conversion_count / impression_count if impression_count else 0.0
        variant_result = ExperimentVariantResult(
            variant_key=row.variant_key,
            impressions=impression_count,
            cta_clicks=int(row.cta_clicks or 0),
            form_starts=int(row.form_starts or 0),
            form_submits=int(row.form_submits or 0),
            conversions=conversion_count,
            unique_sessions=int(row.unique_sessions or 0),
            conversion_rate=round(rate, 4),
            conversion_value_total=float(row.conversion_value_total or 0),
        )
        if row.experiment_key not in variants_by_experiment:
            variants_by_experiment[row.experiment_key] = []
        variants_by_experiment[row.experiment_key].append(variant_result)

    stamped_at = datetime.now(timezone.utc)
    results: list[ExperimentResult] = []
    for key, variant_results in variants_by_experiment.items():
        results.append(
            ExperimentResult(
                experiment_key=key,
                generated_at=stamped_at,
                variants=variant_results,
            )
        )
    return results
async def get_experiment_definition(db: AsyncSession, experiment_key: str) -> Experiment | None:
    """Fetch the experiment stored under ``experiment_key``, or ``None``.

    The ``variants`` relationship is requested via ``selectinload`` so it can
    be read without a lazy load in the async session.
    """
    query = (
        select(Experiment)
        .where(Experiment.experiment_key == experiment_key)
        .options(selectinload(Experiment.variants))
    )
    rows = await db.execute(query)
    return rows.scalars().first()
async def upsert_experiment_definition(
    db: AsyncSession,
    experiment_key: str,
    payload: ExperimentDefinitionUpdate,
) -> Experiment:
    """Create or replace the definition of ``experiment_key`` from ``payload``.

    Experiment-level fields (``cookie_name``, ``name``, ``description``,
    ``enabled``, ``eligible_routes``) are written from ``payload``. Variants
    are reconciled against ``payload.variants``:

    - stored variants whose key is absent from the payload are deleted;
    - matching keys are updated in place;
    - new keys are inserted.

    Changes are flushed but not committed.

    Returns:
        The experiment re-read from the session with its ``variants``
        collection reloaded, so it reflects the deletes and inserts made here.

    Raises:
        ValueError: if ``payload.cookie_name`` is already used by a different
            experiment.
    """
    experiment = await get_experiment_definition(db, experiment_key)

    # A cookie name may belong to at most one experiment.
    duplicate_cookie = await db.execute(
        select(Experiment).where(
            Experiment.cookie_name == payload.cookie_name,
            Experiment.experiment_key != experiment_key,
        )
    )
    if duplicate_cookie.scalars().first():
        raise ValueError("cookie_name is already used by another experiment")

    if experiment is None:
        experiment = Experiment(
            experiment_key=experiment_key,
            cookie_name=payload.cookie_name,
            name=payload.name,
            description=payload.description,
            enabled=payload.enabled,
            eligible_routes=payload.eligible_routes,
        )
        db.add(experiment)
        # Flush so ``experiment.id`` is available for the new variants below.
        await db.flush()
        existing_variants: dict[str, ExperimentVariant] = {}
    else:
        experiment.cookie_name = payload.cookie_name
        experiment.name = payload.name
        experiment.description = payload.description
        experiment.enabled = payload.enabled
        experiment.eligible_routes = payload.eligible_routes
        existing_variants = {variant.variant_key: variant for variant in experiment.variants}

    incoming_keys = {variant.variant_key for variant in payload.variants}
    for variant in list(existing_variants.values()):
        if variant.variant_key not in incoming_keys:
            await db.delete(variant)

    for variant_payload in payload.variants:
        variant = existing_variants.get(variant_payload.variant_key)
        if variant is None:
            db.add(
                ExperimentVariant(
                    experiment_id=experiment.id,
                    variant_key=variant_payload.variant_key,
                    label=variant_payload.label,
                    allocation=variant_payload.allocation,
                    is_control=variant_payload.is_control,
                )
            )
            continue
        variant.label = variant_payload.label
        variant.allocation = variant_payload.allocation
        variant.is_control = variant_payload.is_control

    await db.flush()

    # Flushing never edits an already-loaded collection, so ``experiment.variants``
    # would still hold deleted variants and miss ones added via ``experiment_id``.
    # The re-query returns the same identity-mapped object, and its selectinload
    # skips attributes that are already loaded. Expiring the collection forces
    # the re-query to reload it from the database.
    db.expire(experiment, ["variants"])
    refreshed = await get_experiment_definition(db, experiment_key)
    assert refreshed is not None
    return refreshed