{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2D Gaussians\n",
    "\n",
    "## Two gaussians\n",
    "\n",
    "Two 2D point clouds are sampled from Gaussians that share a covariance matrix; the second one is offset along the diagonal. The plot title reports the collapsed ABX score computed with the Euclidean distance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from fastabx import Dataset, Score, Task\n",
    "\n",
    "n = 100  # number of points drawn for each cloud\n",
    "diagonal_shift = 4  # offset added to both coordinates of the second cloud's mean\n",
    "mean = np.zeros(2)\n",
    "cov = np.array([[4, -2], [-2, 3]])\n",
    "\n",
    "# Seeded generator so the figure is reproducible on a fresh kernel.\n",
    "rng = np.random.default_rng(seed=0)\n",
    "first = rng.multivariate_normal(mean, cov, n)\n",
    "second = rng.multivariate_normal(mean + np.ones(2) * diagonal_shift, cov, n)\n",
    "\n",
    "# Stack both clouds into one dataset; `label` is 0 for the first cloud and 1 for the second.\n",
    "dataset = Dataset.from_numpy(np.vstack([first, second]), {\"label\": [0] * n + [1] * n})\n",
    "task = Task(dataset, on=\"label\")\n",
    "score = Score(task, \"euclidean\")\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(*first.T, alpha=0.5)\n",
    "ax.scatter(*second.T, alpha=0.5)\n",
    "ax.axis(\"equal\")\n",
    "ax.grid()\n",
    "ax.set_title(f\"ABX: {score.collapse():.3%}\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Two gaussians with increasing shift\n",
    "\n",
    "Same setup, but the second cloud starts from the same distribution as the first and is moved one unit further along the diagonal in each successive panel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from fastabx import Dataset, Score, Task\n",
    "\n",
    "n = 100  # number of points drawn for each cloud\n",
    "shift = np.ones(2)  # per-panel displacement of the second cloud, one unit on each axis\n",
    "mean = np.zeros(2)\n",
    "cov = np.array([[4, -2], [-2, 3]])\n",
    "\n",
    "# Seeded generator so the figure is reproducible on a fresh kernel.\n",
    "rng = np.random.default_rng(seed=0)\n",
    "first = rng.multivariate_normal(mean, cov, n)\n",
    "second = rng.multivariate_normal(mean, cov, n)\n",
    "\n",
    "fig, axes = plt.subplots(figsize=(10, 8), nrows=3, ncols=3, sharex=True, sharey=True)\n",
    "for step, ax in enumerate(axes.flatten()):\n",
    "    # Panel number `step` shows the second cloud displaced by `step * shift`.\n",
    "    # The offset is applied to a fresh array so `second` keeps the original sample.\n",
    "    shifted = second + step * shift\n",
    "\n",
    "    dataset = Dataset.from_numpy(np.vstack([first, shifted]), {\"label\": [0] * n + [1] * n})\n",
    "    task = Task(dataset, on=\"label\")\n",
    "    score = Score(task, \"euclidean\")\n",
    "\n",
    "    ax.scatter(*first.T, s=10, alpha=0.5)\n",
    "    ax.scatter(*shifted.T, s=10, alpha=0.5)\n",
    "    ax.grid()\n",
    "    ax.set_title(f\"ABX: {score.collapse():.3%}\")\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}