# Source code for laue_dials.command_line.refine

#!/usr/bin/env python
"""
This script handles polychromatic geometry refinement.
"""

import logging
import sys
import time
from itertools import repeat
from multiprocessing import Pool

import libtbx.phil
import numpy as np
from dials.array_family import flex
from dials.array_family.flex import reflection_table
from dials.command_line.refine import run_dials_refine
from dials.util import show_mail_handle_errors
from dials.util.options import ArgumentParser
from dxtbx.model import ExperimentList
from dxtbx.model.experiment_list import ExperimentListFactory

from laue_dials.algorithms.laue import (gen_beam_models, remove_beam_models,
                                        store_wavelengths)
from laue_dials.utils.version import laue_version

# Module-level logger for this command-line script.
logger = logging.getLogger("laue-dials.command_line.refine")

# Report the laue-dials + DIALS versions as soon as this module is imported.
laue_version()

# Long help text; run() passes it to the ArgumentParser as the epilog.
help_message = """
This script handles polychromatic geometry refinement.

This program takes an indexed DIALS experiment list and reflection table
(with wavelengths) and refines the experiment geometry. The outputs are a pair of
files (poly_refined.expt, poly_refined.refl) that contain an experiment geometry
suitable for prediction and integration.

Examples:

    laue.refine [options] optimized.expt optimized.refl
"""

# Set the phil scope
# main_phil: the complete dials.refine working scope (pulled in with
# ``include scope``) plus this script's own ``nproc`` option. run() passes
# ``nproc`` to multiprocessing.Pool as the number of worker processes.
main_phil = libtbx.phil.parse(
    """
include scope dials.command_line.refine.working_phil

nproc = 1
  .type = int
  .help = Number of parallel processes to run
""",
    process_includes=True,
)

# refiner_phil: laue.refine's defaults for the included dials.refine
# parameters. These include:
#   - the SparseLevMar engine and stills weighting;
#   - per-image MCD outlier rejection;
#   - fixed beam (both spindle-plane choices selected with ``*``), fixed
#     detector distance and fixed real_space_a;
#   - spherical relp model;
#   - the poly_refined.* output file names.
refiner_phil = libtbx.phil.parse(
    """

refinement {
  refinery {
    engine = SparseLevMar
  }

  reflections {
    weighting_strategy {
      override = stills
    }

    outlier {
      nproc = 1

      minimum_number_of_reflections = 1

      algorithm = mcd

      separate_images = True
    }
  }

  parameterisation {
    beam {
      fix = *in_spindle_plane *out_spindle_plane
    }

    crystal {
      unit_cell {
        fix_list = real_space_a
      }
    }

    detector {
      fix = distance
    }

    auto_reduction {
      action = fix

      min_nref_per_parameter = 1
    }

    spherical_relp_model = True
  }
}

output {
  experiments = poly_refined.expt

  reflections = poly_refined.refl

  log = laue.poly_refined.log
}
"""
)

# Apply refiner_phil's values on top of main_phil. The result is the default
# ``phil`` argument of run().
working_phil = main_phil.fetch(sources=[refiner_phil])


def correct_identifiers(expts, refls):
    """
    Renumber experiments and their reflections with contiguous ids.

    Images that could not be refined are dropped upstream, which leaves gaps
    in the experiment identifiers. The surviving experiments are renumbered
    0, 1, 2, ... in their current order, and each experiment's reflections
    are relabelled to the same new id.

    Each experiment's identifier is assumed to be the integer id used in
    ``refls["id"]`` for that image's reflections. The previous implementation
    relied on the same assumption. Experiments are modified in place (their
    ``identifier`` is overwritten).

    Args:
        expts (dxtbx.model.ExperimentList): Experiment list.
        refls (flex.reflection_table): Reflection table.

    Returns:
        corrected_expts (dxtbx.model.experiment_list.ExperimentList): The
            corrected experiment list with identifiers "0", "1", ....
        corrected_refls (dials.array_family.flex.reflection_table): The
            reflections of those experiments with updated "id" values.
    """
    # Initialize arrays
    corrected_expts = ExperimentList()
    corrected_refls = reflection_table()

    for new_id, expt in enumerate(expts):
        # Select on the experiment's own (original) id. The previous approach
        # inferred the old id by counting skips one at a time. That approach
        # picked the wrong reflections whenever two or more consecutive
        # images had been dropped.
        old_id = int(expt.identifier)
        img_refls = refls.select(refls["id"] == old_id)

        # Relabel the experiment and its reflections with the contiguous id
        expt.identifier = str(new_id)
        img_refls["id"] = flex.int(len(img_refls), new_id)

        # Add data to array
        corrected_expts.append(expt)
        corrected_refls.extend(img_refls)

    return corrected_expts, corrected_refls
def refine_image(params, expts, refls):
    """
    Refine image given parameters, experiments, and reflections.

    The reflections are relabelled as experiment 0 (their "id" and
    "imageset_id" columns are replaced on the passed-in table). Beam models
    are generated with ``gen_beam_models`` and refined with
    ``run_dials_refine``. The beam models are then removed with
    ``remove_beam_models``, and the original reflection ids are restored on
    the refined reflections.

    Args:
        params (libtbx.phil.scope_extract): Refinement parameters.
        expts (dxtbx.model.ExperimentList): Experiment list.
        refls (flex.reflection_table): Reflection table. Must be non-empty,
            since the image number is read from its first row.

    Returns:
        refined_expts (dxtbx.model.experiment_list.ExperimentList): The refined experiment list with updated geometry.
        refined_refls (dials.array_family.flex.reflection_table): The refined reflection table with updated wavelength and centroid data.
        If refinement raises, a warning is logged and an empty
        ExperimentList and an empty reflection_table are returned instead.
    """
    img_num = refls["id"][0]
    # Keep the caller's ids so they can be put back after refinement
    original_ids = refls["id"]
    refls["id"] = flex.int([0] * len(refls))
    refls["imageset_id"] = flex.int([0] * len(refls))

    # Generate beam models
    multi_expts, multi_refls = gen_beam_models(expts, refls)

    # Perform refinement. A failure on one image is not fatal: the image is
    # skipped. Catch Exception rather than using a bare except, so that
    # KeyboardInterrupt/SystemExit still propagate.
    try:
        refined_expts, refined_refls, _, _ = run_dials_refine(
            multi_expts, multi_refls, params
        )
    except Exception:
        logger.warning(
            f"WARNING: Experiment {img_num} could not be refined. Skipping image."
        )
        # Keep the cause available for debugging without cluttering INFO logs
        logger.debug("Refinement failure for experiment %s", img_num, exc_info=True)
        return ExperimentList(), reflection_table()  # Return empty

    # Write wavelengths and centroid data
    refined_refls = store_wavelengths(refined_expts, refined_refls)
    refined_refls.map_centroids_to_reciprocal_space(refined_expts)

    # Strip beam objects and reset reflection IDs
    refined_expts = remove_beam_models(refined_expts, original_ids[0])
    refined_refls["id"] = original_ids

    # Return refined data
    return refined_expts, refined_refls
def _configure_logging(log_filename):
    """
    Send log output from this script, DIALS, dxtbx and xfel to stdout and a file.

    Args:
        log_filename (str): Path of the log file (opened with mode "w").
    """
    loglevel = logging.INFO

    console = logging.StreamHandler(sys.stdout)
    fh = logging.FileHandler(log_filename, mode="w", encoding="utf-8")
    fh.setLevel(loglevel)

    # Route warnings.warn() messages through the "py.warnings" logger so they
    # reach the same handlers.
    logging.captureWarnings(True)
    warning_logger = logging.getLogger("py.warnings")

    # These loggers are set to INFO. The warnings logger keeps its default
    # level.
    leveled_loggers = [
        logger,
        logging.getLogger("dials"),
        logging.getLogger("dxtbx"),
        logging.getLogger("xfel"),
    ]
    for target in [warning_logger, *leveled_loggers]:
        target.addHandler(fh)
        target.addHandler(console)
    for target in leveled_loggers:
        target.setLevel(loglevel)


@show_mail_handle_errors()
def run(args=None, *, phil=working_phil):
    """
    Run the refinement script.

    Steps:
      1. Load the first experiment list and reflection table given on the
         command line.
      2. Drop reflections whose wavelength is 0.
      3. Refine every image independently, in parallel over ``nproc``
         processes.
      4. Renumber the surviving experiments with ``correct_identifiers``.
      5. Write the results to the configured output files.

    Args:
        args (list): Command-line arguments.
        phil: Working phil scope.
    """
    # Parse arguments
    usage = "laue.refine [options] optimized.expt optimized.refl"
    parser = ArgumentParser(
        usage=usage,
        phil=phil,
        read_reflections=True,
        read_experiments=True,
        check_format=False,
        epilog=help_message,
    )
    params, _ = parser.parse_args(args=args, show_diff_phil=False)

    # Configure logging
    _configure_logging(params.output.log)

    # Log diff phil
    diff_phil = parser.diff_phil.as_str()
    if diff_phil:
        logger.info("The following parameters have been modified:\n")
        logger.info(diff_phil)

    # Print help if no input
    if not params.input.experiments or not params.input.reflections:
        parser.print_help()
        return

    # Load files
    input_expts = ExperimentListFactory.from_json_file(
        params.input.experiments[0].filename, check_format=False
    )
    input_refls = reflection_table.from_file(params.input.reflections[0].filename)
    input_refls = input_refls.select(
        input_refls["wavelength"] != 0
    )  # Remove unindexed reflections

    # Remove duplicate expt + refl data (params is shipped to every worker)
    params.input.experiments = None
    params.input.reflections = None

    # Get initial time for process
    start_time = time.time()

    # Prepare parallel input: one (params, experiment, reflections) task per
    # reflection id. Plain ints keep flex comparisons returning flex.bool.
    ids = [int(i) for i in np.unique(input_refls["id"])]
    expts_arr = [ExperimentList([input_expts[i]]) for i in ids]
    refls_arr = [input_refls.select(input_refls["id"] == i) for i in ids]
    inputs = list(zip(repeat(params), expts_arr, refls_arr))

    # Refine data
    with Pool(processes=params.nproc) as pool:
        output = pool.starmap(refine_image, inputs)

    # Initialize arrays for final results
    total_refined_expts = ExperimentList()
    total_refined_refls = reflection_table()

    # starmap returns one result per task, in task order. Iterate the results
    # directly: indexing them by reflection id broke whenever the ids were
    # not exactly 0..len(ids)-1.
    for refined_expts, refined_refls in output:
        total_refined_expts.extend(refined_expts)
        total_refined_refls.extend(refined_refls)

    # Correct any mismatching identifiers
    final_expts, final_refls = correct_identifiers(
        total_refined_expts, total_refined_refls
    )

    # Save data
    logger.info("Saving refined experiments to %s", params.output.experiments)
    final_expts.as_file(params.output.experiments)
    logger.info("Saving refined reflections to %s", params.output.reflections)
    final_refls.as_file(filename=params.output.reflections)

    # Final logs
    logger.info("")
    logger.info("Time Taken Refining = %f seconds", time.time() - start_time)
# Entry point when executed as a script rather than imported.
if __name__ == "__main__":
    run()