Source code for laue_dials.command_line.compute_rmsds
#!/usr/bin/env python
"""
This script computes and plots RMSDs for a pair of DIALS experiment/reflection files.
"""
import logging
import sys
import gemmi
import libtbx.phil
import numpy as np
import pandas as pd
import reciprocalspaceship as rs
from cctbx import sgtbx
from dials.util import show_mail_handle_errors
from dials.util.options import (ArgumentParser,
reflections_and_experiments_from_files)
from matplotlib import pyplot as plt
from laue_dials.utils.version import laue_version
# Print laue-dials + DIALS versions (runs at import time)
laue_version()

# Module-level logger; handlers and level are attached inside run().
logger = logging.getLogger("laue-dials.command_line.compute_rmsds")

# Text passed to the argument parser as its epilog and shown with the help.
help_message = """
This program computes and plots the RMSDs between observed and predicted centroids in a reflection table.
Examples::
laue.compute_rmsds [options] filename.expt filename.refl
"""

# Set the phil scope.
# NOTE: `ymax` is a float because per-image RMSDs are fractional pixel
# values, so a useful upper axis limit is frequently non-integer. `dotsize`
# is forwarded unchanged to matplotlib's scatter `s` argument, so leaving it
# unset yields matplotlib's own default marker area.
phil_scope = libtbx.phil.parse(
    """
show = True
.type = bool
.help = "Show the plot of centroid RMSDs per image."
save = False
.type = bool
.help = "Save the plot of centroid RMSDs per image to a PNG file."
csv = None
.type = str
.help = "Save a CSV of the RMSDs per image with this filename."
output = "residuals.png"
.type = str
.help = "The filename for the generated plot."
refined_only = False
.type = bool
.help = "Only compute refined spot RMSDs."
log = 'laue.compute_rmsds.log'
.type = str
.help = "The log filename."
ymax = None
.type = float
.help = "Desired ymax for plot. Defaults to maximum of data."
dotsize = None
.type = int
.help = "Desired dot size for plot in points**2. Defaults to the matplotlib default (rcParams lines.markersize squared)."
""",
    process_includes=True,
)

# Fully-resolved default parameters, used as the default phil for run().
working_phil = phil_scope.fetch(sources=[phil_scope])
[docs]
@show_mail_handle_errors()
def run(args=None, *, phil=working_phil):
    """
    Compute and plot RMSDs for a pair of DIALS experiment/reflection files.

    The per-image RMSD is the square root of the mean squared pixel distance
    between predicted (``xyzcal.px``) and observed (``xyzobs.px.value``)
    centroids, grouped by the reflection table's ``id`` column (reported
    1-based as the image number).

    Args:
        args (list): Command-line arguments. ``None`` is forwarded to the
            parser unchanged.
        phil: The phil scope used to build the argument parser. Defaults to
            the module-level ``working_phil``.

    Returns:
        None

    Raises:
        SystemExit: If no experiment or reflection file is supplied (after
            printing the help text).
    """
    # Parse arguments
    usage = "laue.compute_rmsds [options] filename.expt filename.refl"
    parser = ArgumentParser(
        usage=usage,
        phil=phil,  # honour the caller-supplied scope instead of the global
        read_reflections=True,
        read_experiments=True,
        check_format=False,
        epilog=help_message,
    )
    params, _ = parser.parse_args(args=args, show_diff_phil=False)

    # Print help if no input. Done before any log handler is created so that
    # a help-only invocation does not leave an empty log file behind.
    if not params.input.experiments or not params.input.reflections:
        parser.print_help()
        sys.exit()

    # Configure logging: every logger below writes to both the log file and
    # stdout. Python warnings are captured and routed through the same
    # handlers, but the level of the "py.warnings" logger is left untouched.
    console = logging.StreamHandler(sys.stdout)
    fh = logging.FileHandler(params.log, mode="w", encoding="utf-8")
    loglevel = logging.INFO
    logging.captureWarnings(True)
    leveled_loggers = (
        logger,
        logging.getLogger("dials"),
        logging.getLogger("dxtbx"),
        logging.getLogger("xfel"),
    )
    for target in (*leveled_loggers, logging.getLogger("py.warnings")):
        target.addHandler(fh)
        target.addHandler(console)
    for target in leveled_loggers:
        target.setLevel(loglevel)
    fh.setLevel(loglevel)

    # Log diff phil
    diff_phil = parser.diff_phil.as_str()
    if diff_phil != "":
        logger.info("The following parameters have been modified:\n")
        logger.info(diff_phil)

    # Load data (only the first reflection table is used)
    refls, expts = reflections_and_experiments_from_files(
        params.input.reflections, params.input.experiments
    )
    refls = refls[0]
    if params.refined_only:
        refls = refls.select(refls.get_flags(refls.flags.used_in_refinement))

    # Nothing to compute or plot for an empty table, filtered or not.
    if len(refls) == 0:
        logger.info("No reflections in table after filtering.")
        return

    # Get data from reflection table; convert each flex column to numpy once.
    hkl = refls["miller_index"].as_vec3_double().as_numpy_array()
    xyz_obs = refls["xyzobs.px.value"].as_numpy_array()
    xyz_cal = refls["xyzcal.px"].as_numpy_array()

    # Average the unit-cell parameters over all crystals; the space group is
    # taken from the first crystal only.
    crystals = expts.crystals()
    cell = np.zeros(6)
    for crystal in crystals:
        cell += np.array(crystal.get_unit_cell().parameters()) / len(crystals)
    cell = gemmi.UnitCell(*cell)
    sginfo = crystals[0].get_space_group().info()
    symbol = sgtbx.space_group_symbols(sginfo.symbol_and_number().split("(")[0])
    spacegroup = gemmi.SpaceGroup(symbol.universal_hermann_mauguin())

    # Generate rs.DataSet
    data = rs.DataSet(
        {
            "H": hkl[:, 0].astype(np.int32),
            "K": hkl[:, 1].astype(np.int32),
            "L": hkl[:, 2].astype(np.int32),
            "image": refls["id"].as_numpy_array() + 1,
            "xobs": xyz_obs[:, 0],
            "yobs": xyz_obs[:, 1],
            "xcal": xyz_cal[:, 0],
            "ycal": xyz_cal[:, 1],
        },
        cell=cell,
        spacegroup=spacegroup,
    ).infer_mtz_dtypes()
    logger.info(f"Total Number of Spots: {len(data)}.")

    # Calculate image residuals
    images = np.unique(data["image"])
    x_resids = data["xcal"] - data["xobs"]
    y_resids = data["ycal"] - data["yobs"]
    sqr_resids = x_resids**2 + y_resids**2
    mean_resids = np.zeros(len(images))
    for i, img_num in enumerate(images):
        sel = data["image"] == img_num
        mean_resids[i] = np.mean(sqr_resids[sel])
    rmsds = np.sqrt(mean_resids)

    resid_data = pd.DataFrame({"Image": images, "RMSD (px)": rmsds})
    # Show every row when the table is logged (this is a global pandas option).
    pd.set_option("display.max_rows", None)
    logger.info(f"RMSDs per image: \n{resid_data}")
    if params.csv is not None:
        resid_data.to_csv(params.csv, index=False)

    # NOTE: the pixel size (assumed square) is not needed here; should it ever
    # be, it is available as
    # expts.detectors()[0].to_dict()["panels"][0]["pixel_size"][0]

    # Plot residuals
    fig = plt.figure()
    plt.scatter(images, rmsds, color="k", s=params.dotsize)
    if params.ymax is None:
        plt.ylim(bottom=0)
    else:
        plt.ylim(bottom=0, top=params.ymax)
    plt.title("Image RMSDs")
    plt.xlabel("Image #")
    plt.ylabel("RMSD (px)")
    if params.save:
        fig.savefig(params.output, format="png")
    if params.show:
        plt.show()
# Script entry point: execute the program only when this file is run
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run()