#!/usr/bin/env python
"""
This script predicts reflections for integration
"""
import logging
import sys
import time
from itertools import repeat
from multiprocessing import Pool
import gemmi
import libtbx.phil
import numpy as np
from dials.algorithms.spot_prediction import ray_intersection
from dials.array_family import flex
from dials.array_family.flex import reflection_table
from dials.util import show_mail_handle_errors
from dials.util.options import (
    ArgumentParser,
    reflections_and_experiments_from_files,
)
from dxtbx.model import ExperimentList
from laue_dials.algorithms.outliers import gen_kde
from laue_dials.utils.version import laue_version
# Print laue-dials + DIALS versions
laue_version()
logger = logging.getLogger("laue-dials.command_line.predict")
help_message = """
This script predicts reflections for integration using a refined geometry experiment and reflection file.
This program takes a refined geometry experiment and reflection file, builds a
DIALS experiment list and reflection table, and predicts the feasible set of
reflections on the detector for integration using the refined geometry.
The output is a predicted reflection table (predicted.refl), which contains
the necessary information for the integrator to locate predicted spots
and integrate them.
Examples::
laue.predict [options] poly_refined.expt poly_refined.refl
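
The following example (with illustrative values) also sets the required
wavelength limits and resolution cutoff, and runs on 8 parallel processes::

    laue.predict poly_refined.expt poly_refined.refl wavelengths.lam_min=0.95 wavelengths.lam_max=1.15 reciprocal_grid.d_min=1.4 nproc=8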
"""
# Set the phil scope
phil_scope = libtbx.phil.parse(
    """
output {
  reflections = 'predicted.refl'
    .type = str
    .help = "The output reflection table filename."

  log = 'laue.predict.log'
    .type = str
    .help = "The log filename."
}

nproc = 1
  .type = int
  .help = "Number of parallel processes to run."

wavelengths {
  lam_min = None
    .type = float(value_min=0.1)
    .help = "Minimum wavelength for beam spectrum."

  lam_max = None
    .type = float(value_min=0.2)
    .help = "Maximum wavelength for beam spectrum."
}

reciprocal_grid {
  d_min = None
    .type = float(value_min=0.1)
    .help = "Minimum d-spacing for reflecting planes."
}

cutoff_log_probability = 0.
  .type = float
  .help = "The cutoff threshold for removing unlikely reflections."
""",
    process_includes=True,
)
working_phil = phil_scope.fetch(sources=[phil_scope])


def predict_spots(lam_min, lam_max, d_min, refls, expts):
    """
    Predict spots given a geometry.

    Args:
        lam_min (float): Minimum wavelength for the beam spectrum.
        lam_max (float): Maximum wavelength for the beam spectrum.
        d_min (float): Minimum d-spacing for reflecting planes.
        refls (dials.array_family.flex.reflection_table): The reflection table.
        expts (dxtbx.model.experiment_list.ExperimentList): The experiment list.

    Returns:
        final_preds (dials.array_family.flex.reflection_table): Predicted reflection table.
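
    Example:
        A minimal standalone call (the wavelength and resolution values here
        are illustrative only)::

            preds = predict_spots(0.95, 1.15, 1.4, refls, expts)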
"""
from laue_dials.algorithms.laue import LauePredictor
img_num = refls["id"][0]
# Remove outliers
refls = refls.select(refls.get_flags(refls.flags.used_in_refinement))
# Set up reflection table to store valid predictions
final_preds = reflection_table()
try:
# Get experiment data from experiment objects
experiment = expts[0]
cryst = experiment.crystal
spacegroup = gemmi.SpaceGroup(
cryst.get_space_group().type().universal_hermann_mauguin_symbol()
)
# Get mask
mask = experiment.imageset.get_mask(0)[0]
# Get beam vector
s0 = np.array(experiment.beam.get_s0())
# Get unit cell params
cell_params = cryst.get_unit_cell().parameters()
cell = gemmi.UnitCell(*cell_params)
# Get U matrix
U = np.asarray(cryst.get_U()).reshape(3, 3)
# Get observed centroids
sub_refls = refls.select(refls["id"] == img_num)
# Generate predictor object
logger.info(f"Predicting spots on image {img_num}.")
la = LauePredictor(
s0,
cell,
U,
lam_min,
lam_max,
d_min,
spacegroup=spacegroup,
)
# Predict spots
s1, new_lams, q_vecs, millers = la.predict_s1()
# Build new reflection table for predictions
preds = reflection_table.empty_standard(len(s1))
del preds["intensity.prf.value"]
del preds["intensity.prf.variance"]
del preds["lp"]
del preds["profile_correlation"]
# Populate needed columns
preds["id"] = flex.int([int(experiment.identifier)] * len(preds))
preds["imageset_id"] = flex.int([sub_refls[0]["imageset_id"]] * len(preds))
preds["s1"] = flex.vec3_double(s1)
preds["phi"] = flex.double(np.zeros(len(s1))) # Data are stills
preds["wavelength"] = flex.double(new_lams)
preds["rlp"] = flex.vec3_double(q_vecs)
preds["miller_index"] = flex.miller_index(millers.astype("int").tolist())
# Get which reflections intersect detector
intersects = ray_intersection(experiment.detector, preds)
preds = preds.select(intersects)
new_lams = new_lams[intersects]
# Get predicted centroids
x, y, _ = preds["xyzcal.mm"].parts()
# Convert to pixel units
px_size = experiment.detector.to_dict()["panels"][0]["pixel_size"]
x = x / px_size[0]
y = y / px_size[1]
# Convert centroids to integer pixels
x = np.asarray(flex.floor(x).iround())
y = np.asarray(flex.floor(y).iround())
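
        # Mask check: the first panel's mask is treated as a flat array, and
        # any prediction whose integer pixel position falls on a masked (False)
        # entry is dropped. This assumes a single-panel detector.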
        # Remove predictions in masked areas
        img_row_size = experiment.detector.to_dict()["panels"][0]["image_size"][1]
        sel = np.full(len(x), True)
        for i in range(len(preds)):
            if not mask[x[i] + img_row_size * y[i]]:
                sel[i] = False
        preds = preds.select(flex.bool(sel))
        new_lams = new_lams[sel]
    except Exception:
        logger.warning(
            f"WARNING: Could not predict reflections for experiment {img_num}. Image skipped."
        )
        return reflection_table()  # Return empty on failure

    # Append image predictions to dataset
    final_preds.extend(preds)

    # Return predicted refls
    return final_preds


@show_mail_handle_errors()
def run(args=None, *, phil=working_phil):
    """
    Run the prediction script.

    Args:
        args (list): Command-line arguments.
        phil: Working phil scope.
    """
    # Parse arguments
    usage = "laue.predict [options] poly_refined.expt poly_refined.refl"

    parser = ArgumentParser(
        usage=usage,
        phil=phil,
        read_reflections=True,
        read_experiments=True,
        check_format=True,
        epilog=help_message,
    )

    params, options = parser.parse_args(args=args, show_diff_phil=False)

    # Configure logging
    console = logging.StreamHandler(sys.stdout)
    fh = logging.FileHandler(params.output.log, mode="w", encoding="utf-8")
    loglevel = logging.INFO

    logger.addHandler(fh)
    logger.addHandler(console)
    logging.captureWarnings(True)
    warning_logger = logging.getLogger("py.warnings")
    warning_logger.addHandler(fh)
    warning_logger.addHandler(console)
    dials_logger = logging.getLogger("dials")
    dials_logger.addHandler(fh)
    dials_logger.addHandler(console)
    dxtbx_logger = logging.getLogger("dxtbx")
    dxtbx_logger.addHandler(fh)
    dxtbx_logger.addHandler(console)
    xfel_logger = logging.getLogger("xfel")
    xfel_logger.addHandler(fh)
    xfel_logger.addHandler(console)

    logger.setLevel(loglevel)
    dials_logger.setLevel(loglevel)
    dxtbx_logger.setLevel(loglevel)
    xfel_logger.setLevel(loglevel)
    fh.setLevel(loglevel)

    # Log diff phil
    diff_phil = parser.diff_phil.as_str()
    if diff_phil != "":
        logger.info("The following parameters have been modified:\n")
        logger.info(diff_phil)

    # Print help if no input
    if not params.input.experiments or not params.input.reflections:
        parser.print_help()
        return

    # Check for valid parameter values
    if params.reciprocal_grid.d_min is None:
        logger.info("Please provide a d_min.")
        return
    elif params.wavelengths.lam_min is None or params.wavelengths.lam_max is None:
        logger.info(
            "Please provide upper and lower boundaries for the wavelength spectrum."
        )
        return
    elif params.wavelengths.lam_min > params.wavelengths.lam_max:
        logger.info("Minimum wavelength cannot be greater than maximum wavelength.")
        return
    # Get initial time for process
    start_time = time.time()

    # Load data
    reflections, experiments = reflections_and_experiments_from_files(
        params.input.reflections, params.input.experiments
    )
    reflections = reflections[0]  # Get table out of list
    reflections = reflections.select(
        reflections.get_flags(reflections.flags.used_in_refinement)
    )

    # Sanity checks
    if len(experiments) == 0:
        parser.print_help()
        return

    # Prepare parallel input
    ids = list(np.unique(reflections["id"]).astype(np.int32))
    expts_arr = []
    refls_arr = []
    for i in range(len(ids)):  # Split DIALS objects into lists
        expts_arr.append(ExperimentList([experiments[i]]))
        refls_arr.append(reflections.select(reflections["id"] == ids[i]))
    inputs = list(
        zip(
            repeat(params.wavelengths.lam_min),
            repeat(params.wavelengths.lam_max),
            repeat(params.reciprocal_grid.d_min),
            refls_arr,
            expts_arr,
        )
    )
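    # Each input tuple holds the shared wavelength/resolution limits plus the
    # reflections and experiment for a single image, matching the argument
    # order of predict_spots.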

    # Predict reflections
    logger.info("Predicting reflections")
    num_processes = params.nproc
    with Pool(processes=num_processes) as pool:
        output = pool.starmap(predict_spots, inputs, chunksize=1)
    logger.info("Finished predicting feasible spots.")

    # Convert output to single reflection table
    predicted_reflections = reflection_table()
    for table in output:
        predicted_reflections.extend(table)

    # Generate a KDE
    logger.info("Training KDE for resolution-dependent bandwidth.")
    _, _, kde = gen_kde(experiments, reflections)
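    # The KDE is evaluated below on (wavelength, |rlp|^2) pairs, so the same
    # two features are extracted from the predicted reflections.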

    # Get probability densities for predictions
    logger.info("Calculating prediction probabilities.")
    rlps = predicted_reflections["rlp"].as_numpy_array()
    norms = (np.linalg.norm(rlps, axis=1)) ** 2
    lams = predicted_reflections["wavelength"].as_numpy_array()
    pred_data = np.vstack([lams, norms])

    # Split array into chunks
    inputs = np.array_split(pred_data, num_processes, axis=1)

    # Multiprocess PDF estimation
    with Pool(processes=num_processes) as pool:
        prob_list = pool.map(kde.pdf, inputs)
    probs = np.concatenate(prob_list)

    # Cut off using log probabilities
    logger.info("Removing improbable reflections.")
    cutoff_log = params.cutoff_log_probability
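    # With the default cutoff_log_probability of 0, only predictions whose KDE
    # density is at least 1 (natural log >= 0) survive the cut.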
    sel = np.log(probs) >= cutoff_log
    final_predictions = predicted_reflections.select(flex.bool(sel))

    # Mark strong spots
    logger.info("Marking strong predictions")
    idpred, idstrong = final_predictions.match_by_hkle(reflections)
    strongs = np.zeros(len(final_predictions), dtype=int)
    strongs[idpred] = 1
    final_predictions["strong"] = flex.int(strongs)

    logger.info("Assigning intensities")
    for i in range(len(idstrong)):
        final_predictions["intensity.sum.value"][idpred[i]] = reflections[
            "intensity.sum.value"
        ][idstrong[i]]
        final_predictions["intensity.sum.variance"][idpred[i]] = reflections[
            "intensity.sum.variance"
        ][idstrong[i]]
        final_predictions["xyzobs.mm.value"][idpred[i]] = reflections[
            "xyzobs.mm.value"
        ][idstrong[i]]
        final_predictions["xyzobs.mm.variance"][idpred[i]] = reflections[
            "xyzobs.mm.variance"
        ][idstrong[i]]
        final_predictions["xyzobs.px.value"][idpred[i]] = reflections[
            "xyzobs.px.value"
        ][idstrong[i]]
        final_predictions["xyzobs.px.variance"][idpred[i]] = reflections[
            "xyzobs.px.variance"
        ][idstrong[i]]

    # Populate 'px' variety of predicted centroids
    # Based on flat rectangular detector
    x, y, z = final_predictions["xyzcal.mm"].parts()
    expt = experiments[0]  # assuming shared detector models
    x = x / expt.detector.to_dict()["panels"][0]["pixel_size"][0]
    y = y / expt.detector.to_dict()["panels"][0]["pixel_size"][1]
    final_predictions["xyzcal.px"] = flex.vec3_double(x, y, z)

    # Save reflections
    logger.info("Saving predicted reflections to %s", params.output.reflections)
    final_predictions.as_file(filename=params.output.reflections)

    # Final logs
    logger.info("")
    logger.info(
        "Time Taken for Total Processing = %f seconds", time.time() - start_time
    )


if __name__ == "__main__":
    run()