{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Merging with a Robust Error Model \n", "\n", "In the previous example, we computed the common merging statistics $CC_{1/2}$ and $CC_{anom}$ to explore a dataset with significant anomalous signal from native sulfur atoms. One assumption that we made while merging is that the scaled reflection observations are normally distributed about the mean. This assumption is consistent with the merging strategy used by [AIMLESS](https://doi.org/10.1107/S0907444913000061), which was used to scale the data in the first place. In this example, we will explore whether the scaled reflection observations are normally distributed, and whether we can improve the anomalous signal ($CC_{anom}$) by using a different error model.\n", "\n", "_Note:_ See [pytorch.org](https://pytorch.org/) for customizable PyTorch installation instructions." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import seaborn as sns\n", "sns.set_context(\"notebook\", font_scale=1.3)\n", "from tqdm.notebook import tqdm" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: torch in /Users/jgreisman/miniconda3/envs/test/lib/python3.9/site-packages (1.10.1)\r\n", "Requirement already satisfied: typing-extensions in /Users/jgreisman/miniconda3/envs/test/lib/python3.9/site-packages (from torch) (4.0.1)\r\n" ] } ], "source": [ "# Install PyTorch to running kernel\n", "import sys\n", "!{sys.executable} -m pip install torch" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import reciprocalspaceship as rs" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "0.9.18\n" ] } ], "source": [ "print(rs.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "### Normal Error Model\n", "\n", "It is common in merging to assume that scaled intensities are normally distributed about the true mean. We can assess the validity of this assumption by looking at the residuals between scaled intensities and the maximum likelihood esimate of their true intensity." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "ds = rs.read_mtz(\"data/HEWL_unmerged.mtz\")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "merged_normal = rs.algorithms.merge(ds)\n", "merged_normal = merged_normal.stack_anomalous()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "ds.hkl_to_asu(anomalous=True, inplace=True)\n", "ds[\"IML\"] = merged_normal.loc[ds.index, \"I\"]" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Expected residuals for normally distributed data\n", "x = np.linspace(-6., 6., 50)\n", "bin_width = x[1] - x[0]\n", "normal = bin_width*ds.shape[0]*(1/np.sqrt(2*np.pi))*np.exp(-0.5*x**2)\n", "\n", "# Histogram Residuals\n", "fig = plt.figure(figsize=(8, 5))\n", "sns.histplot((ds.I - ds.IML)/ds.SIGI, bins=x, color=\"k\", label=\"Residuals\")\n", "plt.plot(x, normal, 'r-', lw=2, label=\"Expected\")\n", "plt.ylabel(\"Log Count\")\n", "plt.yscale(\"log\")\n", "plt.ylim(1e0, 3e5)\n", "plt.xlabel(r\"$\\frac{I_{h,i} - I^{ML}_{h}}{\\sigma_{I_{h,i}}}$\", fontsize=24)\n", "plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These residuals are not symmetric about $0$, which would have been expected for truly normally distributed data. In addition, the tails are much \"heavier\" than for a normal distribution. This suggests that it may be possible to do a better job merging these scaled observations by using a different error model that is more robust to outliers." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "### Generalized Merging with Flexible Error Model\n", "\n", "The inverse-variance weighting scheme implemented in `rs.algorithms.merge()` is the maximum likelihood estimator for the true mean if we assume the observations are normally distributed. \n", "However, we can write a more general form for the maximum likelihood estimator for the mean of each intensity distribution, $\\mu \\in\\mathbb{R}^{|\\mathbf{H}|}$, without assuming a specific distribution for the error model. 
\n", "Therefore, we will maximize the probability of the data given the model\n", "\\begin{align*}\n", "P(data | model) &= \\prod_{h,i}P(I_{h,i}|\\mu_h, \\sigma_{I_{h,i}}) \n", "\\end{align*}\n", "by minimizing the negative log likelihood\n", "\\begin{align*}\n", "\\mathcal{L} &\\triangleq -\\log P(data | model) \\\\\n", "&= -\\sum_{h,i}\\log P(I_{h,i}|\\mu_h, \\sigma_{I_{h,i}}) \\\\\n", "I^{ML} &= \\underset{\\mu}{\\mathrm{argmin}}\\ -\\sum_{h,i} \\log P(I_{h,i} | \\mu_h, \\sigma_{I_{h,i}}) \n", "\\end{align*}\n", "with respect to $\\mu$. With this formulation, it is possible to supply any parametric form for the error model belonging to the [location-scale family](https://en.wikipedia.org/wiki/Location%E2%80%93scale_family) of distributions. This maximum likelihood estimator is implemented in the function below." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [
"def merge_mle(ds, distribution, *args, lr=.5, progress_bar=True, return_loss=False, **kwargs):\n",
"    \"\"\"\n",
"    Merge observations using the provided distribution as an error model.\n",
"    Additional arguments or keyword arguments will be passed to the PyTorch\n",
"    distribution constructor.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    ds : rs.DataSet\n",
"        Scaled, unmerged observations\n",
"    distribution : torch.distributions.Distribution\n",
"        PyTorch distribution to use as error model\n",
"    lr : float\n",
"        Learning rate for Adam optimizer\n",
"    progress_bar : bool\n",
"        Whether to display a progress bar for optimization\n",
"    return_loss : bool\n",
"        Whether to return the loss function values during optimization\n",
"\n",
"    Returns\n",
"    -------\n",
"    rs.DataSet or (rs.DataSet, losses)\n",
"        Merged DataSet with or without list of loss function values\n",
"    \"\"\"\n",
"    # Compute MLE with normal error model\n",
"    ds = ds.copy()\n",
"    mle_norm = rs.algorithms.merge(ds).stack_anomalous()\n",
"    ds[\"IML\"] = mle_norm.loc[ds.index, \"I\"]\n",
"\n",
"    # Observed intensities and error estimates\n",
"    groupby = ds.groupby(['H', 'K', 'L'])\n",
"    idx = groupby.ngroup().to_numpy()\n",
"    I = torch.as_tensor(ds.I.to_numpy())\n",
"    SigI = torch.as_tensor(ds.SIGI.to_numpy())\n",
"\n",
"    # Initialize optimization at MLE with normal error model\n",
"    mle = groupby.first()[\"IML\"].to_numpy()\n",
"    mean = torch.tensor(mle, requires_grad=True)\n",
"\n",
"    # Define loss function (negative log-likelihood summed over all observations)\n",
"    def _evaluate_loss():\n",
"        return -torch.sum(distribution(*args, loc=mean[idx], scale=SigI, **kwargs).log_prob(I))\n",
"\n",
"    # Setup and fit model\n",
"    losses = []\n",
"    opt = torch.optim.Adam([mean], lr=lr)\n",
"    for _ in tqdm(range(300), disable=not progress_bar):\n",
"        opt.zero_grad()\n",
"        loss = _evaluate_loss()\n",
"        losses.append(loss.detach())\n",
"        loss.backward()\n",
"        opt.step()\n",
"\n",
"    # Second derivatives of the loss at the optimum, used for the error estimates\n",
"    grad = torch.autograd.grad(_evaluate_loss(), mean, create_graph=True)[0]\n",
"    hess = torch.autograd.grad(grad.sum(), mean, create_graph=True)[0]\n",
"\n",
"    # Package results\n",
"    results = rs.DataSet({'I': mean.detach().numpy(),\n",
"                          'SIGI': np.sqrt(1./hess.detach().numpy())},\n",
"                         index=groupby.first().index,\n",
"                         spacegroup=ds.spacegroup,\n",
"                         cell=ds.cell,\n",
"                         merged=True)\n",
"    results.infer_mtz_dtypes(inplace=True)\n",
"\n",
"    if return_loss:\n",
"        return results.unstack_anomalous(), losses\n",
"    return results.unstack_anomalous()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This function must be passed a `torch.distributions.Distribution` class to use as the error model. It then uses the [Adam optimizer](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam) to minimize the negative log-likelihood and fit the merged intensities, $I^{ML}$. For stability, the model is initialized using the mean intensity values from `rs.algorithms.merge()`. The following cell fits the model using a [Student's _t_-distributed error model](https://pytorch.org/docs/stable/distributions.html#studentt) with `df=4.0`." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6cd5820834fa44d5b66d870ad7bd24e1", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/300 [00:00" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "result1, loss1 = merge_mle(ds, torch.distributions.StudentT, 4.0, \n", " progress_bar=True, return_loss=True)\n", "\n", "# Plot loss function\n", "plt.plot(loss1)\n", "plt.xlabel(\"Iteration\")\n", "plt.ylabel(\"Loss\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th></th><th></th><th>I(+)</th><th>SIGI(+)</th><th>I(-)</th><th>SIGI(-)</th></tr>\n",
"    <tr><th>H</th><th>K</th><th>L</th><th></th><th></th><th></th><th></th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>18</th><th>14</th><th>9</th><td>13.383095</td><td>0.59039783</td><td>12.244447</td><td>0.5472484</td></tr>\n",
"    <tr><th>25</th><th>24</th><th>1</th><td>481.62515</td><td>16.697405</td><td>487.13596</td><td>16.214811</td></tr>\n",
"    <tr><th>39</th><th>13</th><th>7</th><td>62.35997</td><td>3.6021707</td><td>61.64068</td><td>3.1317017</td></tr>\n",
"    <tr><th>29</th><th>29</th><th>5</th><td>235.86664</td><td>8.579554</td><td>235.86664</td><td>8.579554</td></tr>\n",
"    <tr><th>30</th><th>14</th><th>5</th><td>174.9836</td><td>5.340581</td><td>189.3732</td><td>4.8809543</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " I(+) SIGI(+) I(-) SIGI(-)\n", "H K L \n", "18 14 9 13.383095 0.59039783 12.244447 0.5472484\n", "25 24 1 481.62515 16.697405 487.13596 16.214811\n", "39 13 7 62.35997 3.6021707 61.64068 3.1317017\n", "29 29 5 235.86664 8.579554 235.86664 8.579554\n", "30 14 5 174.9836 5.340581 189.3732 4.8809543" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "result1.sample(5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a sanity check we can test this implementation by passing the normal distribution as the error model. In this case, the initial values should match the maximum likelihood estimate and optimization should not change the estimates. " ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "df1d4a7dda514a0d84aef6c962136263", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/300 [00:00" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "result2, loss2 = merge_mle(ds, torch.distributions.Normal, lr=1e-3,\n", " progress_bar=True, return_loss=True)\n", "result2 = result2[[\"I(+)\", \"SIGI(+)\", \"I(-)\", \"SIGI(-)\"]].stack_anomalous()\n", "\n", "# Plot loss function\n", "fig, ax = plt.subplots(ncols=2, figsize=(12, 5.5))\n", "ax[0].plot(loss2)\n", "ax[0].set_ylim(4.15e6, max(loss1))\n", "ax[0].set_xlabel(\"Iteration\")\n", "ax[0].set_ylabel(\"Loss\")\n", "ax[0].set_title(\"Loss Function\")\n", "ax[1].loglog(merged_normal[\"I\"].to_numpy(), result2.loc[merged_normal.index, \"I\"].to_numpy(), \"k.\", alpha=0.5)\n", "ax[1].set_xlabel(r\"$I^{ML}$ (AIMLESS)\")\n", "ax[1].set_ylabel(r\"$I^{ML}$ (PyTorch)\")\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see here that the loss function does not change during the optimization, and that the final maximum likelihood estimates 
for the intensities match the input. This confirms that the merging function works as expected when given a normal error model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "### Assess Student's _t_-Distributed Error Model\n", "\n", "A Student's _t_-distribution is useful for modeling data that contain outliers. It places more density in its tails than a normal distribution, which makes maximum likelihood estimates more robust to outlying measurements. It is parameterized by the degrees of freedom, $\\nu \\in (0, \\infty)$, which control the heaviness of the tails: smaller values of $\\nu$ give heavier tails. As $\\nu\\to\\infty$, the probability density function approaches a normal distribution.\n", "\n", "Below, we will set up a few helper functions for merging our data within randomly partitioned half-datasets. We will do this for the Student's _t_-distribution, scanning several degrees of freedom, and for the normal distribution so that we can compare the results. Based on the previous example, we will compute $CC_{anom}$ using a Spearman correlation coefficient and repeated 2-fold cross-validation to compare the different error models." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [
"def sample_halfdatasets(data):\n",
"    \"\"\"Randomly split DataSet into two equal halves by BATCH\"\"\"\n",
"    batch = data.BATCH.unique().to_numpy(dtype=int)\n",
"    np.random.shuffle(batch)\n",
"    halfbatch1, halfbatch2 = np.array_split(batch, 2)\n",
"    half1 = data.loc[data.BATCH.isin(halfbatch1)]\n",
"    half2 = data.loc[data.BATCH.isin(halfbatch2)]\n",
"    return half1, half2" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [
"def merge_dataset(dataset, nsamples, distribution, *args, **kwargs):\n",
"    \"\"\"\n",
"    Merge dataset with repeated 2-fold cross-validation using `distribution`\n",
"    as error model.\n",
"    \"\"\"\n",
"    dataset = dataset.copy()\n",
"    samples = []\n",
"    for n in tqdm(range(nsamples)):\n",
"        half1, half2 = sample_halfdatasets(dataset)\n",
"        mergedhalf1 = merge_mle(half1, distribution, *args, progress_bar=False, **kwargs)\n",
"        mergedhalf2 = merge_mle(half2, distribution, *args, progress_bar=False, **kwargs)\n",
"        result = mergedhalf1.merge(mergedhalf2, on=[\"H\", \"K\", \"L\"], suffixes=(1, 2))\n",
"        result[\"sample\"] = n\n",
"        samples.append(result)\n",
"    return rs.concat(samples).sort_index()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [
"def merge_dataset_normal(dataset, nsamples):\n",
"    \"\"\"\n",
"    Merge dataset with repeated 2-fold cross-validation using normal distribution\n",
"    as error model.\n",
"    \"\"\"\n",
"    dataset = dataset.copy()\n",
"    samples = []\n",
"    for n in tqdm(range(nsamples)):\n",
"        half1, half2 = sample_halfdatasets(dataset)\n",
"        mergedhalf1 = rs.algorithms.merge(half1)\n",
"        mergedhalf2 = rs.algorithms.merge(half2)\n",
"        result = mergedhalf1.merge(mergedhalf2, on=[\"H\", \"K\", \"L\"], suffixes=(1, 2))\n",
"        result[\"sample\"] = n\n",
"        samples.append(result)\n",
"    return rs.concat(samples).sort_index()" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "*Note:* Using these settings, the following will take ~15 min to complete. For the pre-print, this was run with `nsamples=15`. It is possible to run into occasional numerical instabilities when computing the Hessian for `df=4.0`, but these will only impact $\\sigma_I^{ML}$ estimates which are not used here. " ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e90b8157e9544b92993a1bf9fdbd16de", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/5 [00:00" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def plot(results, label, ax, color=None):\n", " ax.errorbar(results.index, results[\"mean\"], yerr=results[\"std\"], color=color, label=label)\n", " return\n", "\n", "fig = plt.figure(figsize=(10, 6))\n", "ax = fig.gca()\n", "with sns.color_palette(\"viridis\", 5) as palette:\n", " plot(results4, r\"Student-T ($d.f.=4$)\", ax, color=palette[4])\n", " plot(results8, r\"Student-T ($d.f.=8$)\", ax, color=palette[3])\n", " plot(results16, r\"Student-T ($d.f.=16$)\", ax, color=palette[2])\n", " plot(results32, r\"Student-T ($d.f.=32$)\", ax, color=palette[1])\n", " plot(results64, r\"Student-T ($d.f.=64$)\", ax, color=palette[0])\n", "plot(resultsinf,\"Normal\", ax, color=\"k\")\n", "plt.ylabel(r\"$CC_{anom}$ (Spearman)\")\n", "plt.xlabel(r\"Resolution Bin ($\\AA$)\")\n", "plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), title=\"Error Model\")\n", "plt.ylim(0, 0.72)\n", "\n", "plt.xticks(resultsinf.index, labels, rotation=45, ha='right', rotation_mode='anchor')\n", "plt.grid(axis=\"y\", linestyle='--')\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "### Summary \n", "\n", "This dataset was scaled and merged in AIMLESS, which involves rounds of outlier rejection and implicitly assumes that the intensities are 
normally distributed about the true mean. In the [first section](3_mergingerrormodel.ipynb#Normal-Error-Model), we observed that the residuals from merging are not normally distributed, suggesting that the normal error model may have been suboptimal. We implemented a more general maximum likelihood-based approach for merging data using different probability distributions and used it to evaluate the performance of Student's _t_-distributed error models. This seemed like a reasonable starting point because the _t_-distribution is more robust to outliers than a normal distribution. \n", "\n", "Using repeated 2-fold cross-validation, we saw that a Student's _t_-distribution with few degrees of freedom ($\\nu=4$) outperforms the normally distributed error model when assessed using $CC_{anom}$. Furthermore, the performance seems to approach that of the normal distribution as the degrees of freedom are increased, which is expected since the _t_-distribution approaches a normal distribution as $\\nu\\to\\infty$.\n", "\n", "Although this dataset is quite high quality and was used to phase and refine a model from the [native sulfur SAD signal](http://doi.org/10.2210/pdb7L84/pdb), it still shows that there can be an incremental improvement in $CC_{anom}$ from revisiting assumptions regarding error models during merging. This model is implemented in ~40 lines of code using `PyTorch`, and can be quickly applied to any dataset of interest. By lowering the barrier to implementing such models, `reciprocalspaceship` makes it easy to try new analyses and to revisit some of the assumptions made in crystallographic data reduction." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 4 }