{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Gaussians\n\nThis example illustrates ABX discriminability on the simplest possible classes: samples drawn from\nGaussians. We start in 2D for visual intuition, then move to 1D where the ABX score admits a closed\nform and we can check ``fastabx`` against the theoretical value.\n\nThroughout, we use the Euclidean distance.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import math\n\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nfrom fastabx import Dataset, Score, Task"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Two 2D Gaussians\n\nWe draw two clusters from $\\mathcal{N}(\\mu_A, \\Sigma)$ and $\\mathcal{N}(\\mu_B, \\Sigma)$\nwith a shared (correlated) covariance and a fixed diagonal shift between the means. The reported\nABX score is the probability that a probe drawn from class $A$ ends up closer to another\nclass-$A$ sample than to a class-$B$ sample. The code below reports it as\n``1 - score.collapse()``, that is, it treats the collapsed ``fastabx`` score as an error rate.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "n = 100\ndiagonal_shift = 4\nmean = np.zeros(2)\ncov = np.array([[4, -2], [-2, 3]])\n\nrng = np.random.default_rng(seed=0)\nfirst = rng.multivariate_normal(mean, cov, n)\nsecond = rng.multivariate_normal(mean + np.ones(2) * diagonal_shift, cov, n)\n\ndataset = Dataset.from_numpy(np.vstack([first, second]), {\"label\": [0] * n + [1] * n})\ntask = Task(dataset, on=\"label\")\nscore = Score(task, \"euclidean\")\n\nplt.scatter(*first.T, alpha=0.5)\nplt.scatter(*second.T, alpha=0.5)\nplt.axis(\"equal\")\nplt.grid()\nplt.title(f\"ABX: {1 - score.collapse():.3%}\")\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Two 2D Gaussians with increasing shift\n\nNow we keep the same covariance for both classes and sweep the displacement between their means\nalong the diagonal. The score climbs from chance level (fully overlapping clouds) up toward\n$1$ as the clusters separate.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "n = 100\nshift = np.ones(2)  # per-panel displacement of the second cloud along the diagonal\nmean = np.zeros(2)\ncov = np.array([[4, -2], [-2, 3]])\n\n# Both clouds start from the same distribution, so the first panel sits at chance level.\nrng = np.random.default_rng(seed=0)\nfirst = rng.multivariate_normal(mean, cov, n)\nsecond = rng.multivariate_normal(mean, cov, n)\n\nfig, axes = plt.subplots(figsize=(10, 8), nrows=3, ncols=3, sharex=True, sharey=True)\nfor ax in axes.flatten():\n    dataset = Dataset.from_numpy(np.vstack([first, second]), {\"label\": [0] * n + [1] * n})\n    task = Task(dataset, on=\"label\")\n    score = Score(task, \"euclidean\")\n\n    ax.scatter(*first.T, s=10, alpha=0.5)\n    ax.scatter(*second.T, s=10, alpha=0.5)\n    ax.grid()\n    ax.set_title(f\"ABX: {1 - score.collapse():.3%}\")\n    # Move the second cloud one step further along the diagonal for the next panel.\n    second += shift\n\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Closed-form ABX for two 1D Gaussians\n\nIn 1D with a shared variance, the ABX score can be written in closed form, which makes it a good\nsanity check for the implementation. Let $A = \\mathcal{N}(\\mu_a, \\sigma^2)$ and\n$B = \\mathcal{N}(\\mu_b, \\sigma^2)$, and write the normalized separation\n$t = (\\mu_a - \\mu_b) / \\sigma$. Then\n\n\\begin{align}\\mathrm{ABX}(A, B) \\;=\\; \\mathbb{P}\\bigl(|x-a| < |x-b|\\bigr) \\;=\\; \\frac{1}{2}\n    + \\frac{1}{2}\\,\\operatorname{erf}\\!\\left(\\frac{t}{2}\\right)\\operatorname{erf}\\!\\left(\\frac{t}{2\\sqrt{3}}\\right),\\end{align}\n\nwhere $a \\sim A$, $x \\sim A$, $b \\sim B$ are mutually independent. The result\ndepends only on $t$: it equals $\\tfrac{1}{2}$ at $t = 0$, tends to $1$ as\n$|t| \\to \\infty$, and is symmetric under $\\mu_a \\leftrightarrow \\mu_b$.\n\n.. dropdown:: Derivation\n\n    **Step 1: reduce the event to a product of two Gaussians.** Both distances are nonnegative, so\n    squaring preserves the inequality:\n\n    .. math::\n\n        |x-a| < |x-b| \\iff (x-a)^2 < (x-b)^2.\n\n    Expanding, cancelling $x^2$, and factoring gives\n\n    .. math::\n\n        (a-b)(a+b-2x) < 0.\n\n    Introducing $U = a-b$ and $V = a+b-2x$,\n\n    .. math::\n\n        \\mathbb{P}\\bigl(|x-a|<|x-b|\\bigr) = \\mathbb{P}(UV < 0).\n\n    **Step 2: joint distribution of** $U$ **and** $V$ **.** Both are linear\n    combinations of independent Gaussians, hence jointly Gaussian. With $m = \\mu_a - \\mu_b$,\n\n    .. math::\n\n        \\mathbb{E}[U] = m, \\qquad \\mathbb{E}[V] = -m,\n\n    .. math::\n\n        \\operatorname{Var}(U) = 2\\sigma^2, \\qquad \\operatorname{Var}(V) = 6\\sigma^2,\n\n    .. math::\n\n        \\operatorname{Cov}(U,V) = \\operatorname{Cov}(a,a) - \\operatorname{Cov}(b,b) = 0.\n\n    Zero covariance for jointly Gaussian variables implies independence:\n\n    .. 
math::\n\n        U \\sim \\mathcal{N}(m,\\,2\\sigma^2), \\qquad V \\sim \\mathcal{N}(-m,\\,6\\sigma^2), \\qquad U \\perp V.\n\n    **Step 3: factor the probability.** For independent $U, V$,\n\n    .. math::\n\n        \\mathbb{P}(UV < 0) = \\mathbb{P}(U>0)\\mathbb{P}(V<0) + \\mathbb{P}(U<0)\\mathbb{P}(V>0).\n\n    With $p = \\mathbb{P}(U>0)$ and $q = \\mathbb{P}(V>0)$, this rearranges to\n\n    .. math::\n\n        \\mathbb{P}(UV<0) = p + q - 2pq = \\tfrac{1}{2} - 2\\bigl(p-\\tfrac{1}{2}\\bigr)\\bigl(q-\\tfrac{1}{2}\\bigr).\n\n    **Step 4: evaluate.** For $W \\sim \\mathcal{N}(\\mu_W, \\sigma_W^2)$,\n\n    .. math::\n\n        \\mathbb{P}(W>0) - \\tfrac{1}{2} = \\tfrac{1}{2}\\,\\operatorname{erf}\\!\\left(\\frac{\\mu_W}{\\sqrt{2}\\,\\sigma_W}\\right).\n\n    Applied to $U$ ($\\sigma_U = \\sqrt{2}\\,\\sigma$) and $V$\n    ($\\sigma_V = \\sqrt{6}\\,\\sigma$, $\\mu_V = -m$),\n\n    .. math::\n\n        p - \\tfrac{1}{2} = \\tfrac{1}{2}\\,\\operatorname{erf}\\!\\left(\\frac{m}{2\\sigma}\\right), \\qquad\n        q - \\tfrac{1}{2} = -\\tfrac{1}{2}\\,\\operatorname{erf}\\!\\left(\\frac{m}{2\\sqrt{3}\\,\\sigma}\\right),\n\n    using $\\sqrt{2}\\cdot\\sqrt{6} = 2\\sqrt{3}$ and\n    $\\operatorname{erf}(-z) = -\\operatorname{erf}(z)$. Substituting back,\n\n    .. math::\n\n        \\mathbb{P}(UV<0) = \\tfrac{1}{2} + \\tfrac{1}{2}\\,\\operatorname{erf}\\!\\left(\\frac{m}{2\\sigma}\\right)\\operatorname{erf}\\!\\left(\\frac{m}{2\\sqrt{3}\\,\\sigma}\\right).\n\n"
      ]
    },
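    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough hand check of the closed form (an illustration added here, with $\\operatorname{erf}$ values rounded to four decimals), take $\\mu_a = 0$, $\\mu_b = 2$, $\\sigma = 1$, so that $t = -2$. Since $\\operatorname{erf}$ is odd, the two sign flips cancel:\n\n\\begin{align}\\mathrm{ABX} = \\tfrac{1}{2} + \\tfrac{1}{2}\\,\\operatorname{erf}(1)\\,\\operatorname{erf}\\!\\left(\\tfrac{1}{\\sqrt{3}}\\right) \\approx \\tfrac{1}{2} + \\tfrac{1}{2}\\,(0.8427)(0.5858) \\approx 0.747.\\end{align}\n\nNear chance level, the first-order expansion $\\operatorname{erf}(z) \\approx 2z/\\sqrt{\\pi}$ suggests that the score grows quadratically in the separation:\n\n\\begin{align}\\mathrm{ABX} \\approx \\tfrac{1}{2} + \\frac{t^2}{2\\sqrt{3}\\,\\pi} \\qquad (|t| \\ll 1).\\end{align}\n\n"
      ]
    },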
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Empirical vs theoretical ABX in 1D\n\nWe can now check the formula above against ``fastabx``. The helpers below sample $n = 500$\npoints from each class, compute the empirical ABX with ``Score(Task(...))``, and compare it to\n``theoretical_abx``. Each panel overlays the true densities and the sample histograms, with the\ntwo scores shown in the title. We then sweep one parameter at a time: $\\mu_b$ at fixed\n$\\sigma$, then $\\sigma$ at fixed $\\mu_b$.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def theoretical_abx(mu_a: float, mu_b: float, sigma: float) -> float:\n    \"\"\"Closed-form ABX score for two 1D Gaussians with shared variance.\"\"\"\n    t = (mu_a - mu_b) / sigma\n    return 0.5 + 0.5 * math.erf(t / 2) * math.erf(t / (2 * math.sqrt(3)))\n\n\ndef empirical_abx(a: np.ndarray, b: np.ndarray) -> float:\n    \"\"\"Empirical ABX score on two 1D samples computed with ``fastabx``.\"\"\"\n    features = np.concatenate([a, b]).reshape(-1, 1)\n    labels = {\"label\": [0] * len(a) + [1] * len(b)}\n    dataset = Dataset.from_numpy(features, labels)\n    return 1.0 - Score(Task(dataset, on=\"label\"), \"euclidean\").collapse()\n\n\ndef gaussian_pdf(x: np.ndarray, mu: float, sigma: float) -> np.ndarray:\n    \"\"\"Density of the normal distribution with mean ``mu`` and standard deviation ``sigma``.\"\"\"\n    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))\n\n\ndef plot_panel(\n    ax: plt.Axes,\n    mu_a: float,\n    mu_b: float,\n    sigma: float,\n    x_range: tuple[float, float] | None,\n    n: int,\n    seed: int,\n) -> None:\n    \"\"\"Draw one panel comparing the theoretical and empirical ABX scores.\"\"\"\n    rng = np.random.default_rng(seed)\n    a = rng.normal(mu_a, sigma, n)\n    b = rng.normal(mu_b, sigma, n)\n    if x_range is None:\n        pad = 3.5 * sigma\n        lo, hi = min(mu_a, mu_b) - pad, max(mu_a, mu_b) + pad\n    else:\n        lo, hi = x_range\n    grid = np.linspace(lo, hi, 400)\n    bins = np.linspace(lo, hi, 40).tolist()\n\n    ax.hist(a, bins=bins, density=True, alpha=0.35, color=\"C0\")\n    ax.hist(b, bins=bins, density=True, alpha=0.35, color=\"C1\")\n    ax.plot(grid, gaussian_pdf(grid, mu_a, sigma), color=\"C0\", lw=2)\n    ax.plot(grid, gaussian_pdf(grid, mu_b, sigma), color=\"C1\", lw=2)\n    ax.set_xlim(lo, hi)\n    peak = 1.0 / (sigma * math.sqrt(2 * math.pi))\n    ax.set_ylim(0, 1.3 * peak)\n    ax.grid(alpha=0.3)\n\n    theory = theoretical_abx(mu_a, mu_b, sigma)\n    
empirical = empirical_abx(a, b)\n    ax.set_title(\n        rf\"$\\mu_b={mu_b:g},\\ \\sigma={sigma:g}$\" + \"\\n\" + f\"theory: {theory:.3f}   fastabx: {empirical:.3f}\",\n        fontsize=10,\n    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Varying the mean separation (fixed $\\sigma = 1$, $\\mu_a = 0$)\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "n = 500\nseed = 0\nmu_a = 0.0\nsigma = 1.0\nmu_bs = [0.25, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 4.0]\nx_range = (-4.0, 8.0)\n\nfig, axes = plt.subplots(figsize=(13, 7), nrows=2, ncols=4, sharex=True, sharey=True)\nfor ax, mu_b in zip(axes.flatten(), mu_bs, strict=True):\n    plot_panel(ax, mu_a, mu_b, sigma, x_range, n, seed)\nfig.suptitle(rf\"Varying $\\mu_b$ at $\\sigma={sigma:g}$\")\nfig.tight_layout()\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Varying the standard deviation (fixed $\\mu_a = 0$, $\\mu_b = 2$)\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "n = 500\nseed = 0\nmu_a = 0.0\nmu_b = 2.0\nsigmas = [0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 3.0, 4.0]\n\nfig, axes = plt.subplots(figsize=(13, 7), nrows=2, ncols=4)\nfor ax, sigma in zip(axes.flatten(), sigmas, strict=True):\n    plot_panel(ax, mu_a, mu_b, sigma, None, n, seed)\nfig.suptitle(rf\"Varying $\\sigma$ at $\\mu_b={mu_b:g}$\")\nfig.tight_layout()\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.14.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}