#!/usr/bin/env python
"""
This script optimizes Miller indices and wavelengths jointly.
"""
import logging
import sys
import time
from itertools import repeat
from multiprocessing import Pool
import libtbx.phil
import numpy as np
from dials.array_family import flex
from dials.array_family.flex import reflection_table
from dials.util import show_mail_handle_errors
from dials.util.options import (ArgumentParser,
                                reflections_and_experiments_from_files)
from dxtbx.model import ExperimentList
from laue_dials.utils.matching import split_stills_by_image
from laue_dials.utils.version import laue_version
# Module-level logger for this command-line program; run() attaches the
# file and console handlers to it.
logger = logging.getLogger("laue-dials.command_line.optimize_indexing")
# Text passed to the argument parser as the epilog of the command-line help.
# The description states that the crystal models in the experiment list are
# updated, because index_image() overwrites each crystal's U, A and B
# matrices, space group and unit cell before the experiments are saved.
help_message = """
This script optimizes Miller indices and wavelengths jointly.
This program takes initial monochromatic estimates of the geometry in a
DIALS experiment list and reflection table, and optimizes the indexed
solution by allowing the wavelength to vary. In the reflection table,
Miller indices, s1 vectors, and wavelengths for scattered reflections
are overwritten by this program. In the experiment list, only the
crystal models are updated - all other geometry variables remain
constant here.
The outputs are a pair of files (optimized.expt, optimized.refl), which
mimic the input files but with updated Miller indices, s1 vectors, and
wavelengths in the reflection table, and with updated crystal models in
the experiment list.
Examples:
    laue.optimize_indexing [options] monochromatic.expt monochromatic.refl
"""
# Set the phil scope
# Master PHIL definition of the program parameters: output filenames, an
# optional target unit cell, the filter_spectrum / keep_unindexed switches,
# the number of parallel processes and macrocycles, the wavelength limits
# (lam_min, lam_max) and the d_min limit. lam_min, lam_max and d_min default
# to None here; run() refuses to proceed until all three are provided.
phil_scope = libtbx.phil.parse(
    """
output {
  experiments = 'optimized.expt'
    .type = str
    .help = "The output experiment list filename."
  reflections = 'optimized.refl'
    .type = str
    .help = "The output reflection table filename."
  log = 'laue.optimize_indexing.log'
    .type = str
    .help = "The log filename."
  }
geometry {
  unit_cell = None
    .type = floats(size=6)
    .help = "Target unit cell for indexing."
}
filter_spectrum = True
  .type = bool
  .help = Whether to remove reflections outside of the provided wavelength limits.
keep_unindexed = False
  .type = bool
  .help = Whether to keep unindexed reflections.
nproc = 1
  .type = int
  .help = Number of parallel processes to run.
n_macrocycles = 3
  .type = int(value_min=0)
  .help = "Number of macrocycles of index optimization to perform. A value of 0 performs one round of assignment without outlier rejection or crystal rotation."
wavelengths {
  lam_min = None
    .type = float(value_min=0.1)
    .help = "Minimum wavelength for beam spectrum."
  lam_max = None
    .type = float(value_min=0.2)
    .help = "Maximum wavelength for beam spectrum."
  }
reciprocal_grid {
  d_min = None
    .type = float(value_min=0.1)
    .help = "Minimum d-spacing for reflecting planes."
  }
""",
    process_includes=True,
)
# The master scope fetched against itself; this object is the default value
# of the `phil` keyword argument of run().
working_phil = phil_scope.fetch(sources=[phil_scope])
def index_image(params, refls, expts):
    """
    Reindex the reflections of a single image and update the crystal model.

    Each experiment in `expts` is passed through a `LaueAssigner`, which
    assigns Miller indices and wavelengths to the reflections over
    `params.n_macrocycles` macrocycles of rotation update and outlier
    rejection. The "s1", "miller_index", "harmonics" and "wavelength"
    columns of `refls` are overwritten and the crystal model of every
    experiment is updated in place. Afterwards the reflections are
    optionally filtered by wavelength range and by indexing status, and the
    "id" of the indexed reflections is set to the image identifier.

    Args:
        params (libtbx.phil.scope_extract): Program parameters.
        refls (dials.array_family.flex.reflection_table): Reflection table for a single image.
        expts (dxtbx.model.ExperimentList): List of experiment objects. A
            single experiment object is wrapped into an ExperimentList.

    Returns:
        expts (dxtbx.model.experiment_list.ExperimentList): The optimized experiment list with updated crystal rotations.
        refls (dials.array_family.flex.reflection_table): The optimized reflection table with updated wavelengths.
    """
    import gemmi
    from cctbx.sgtbx import space_group
    from cctbx.uctbx import unit_cell
    from dials.array_family import flex
    from dxtbx.model import ExperimentList
    from laue_dials.algorithms.laue import LaueAssigner

    # Accept a bare experiment as well as an ExperimentList (or a subclass).
    if not isinstance(expts, ExperimentList):
        expts = ExperimentList([expts])

    s0 = np.array(expts[0].beam.get_s0())

    # Image identifier: the first experiment id that is not -1. It is read
    # once here, before any reflections are filtered out below.
    exp_ids = np.unique(refls["id"].as_numpy_array())
    image_id = exp_ids[exp_ids != -1][0]  # Should only be a single identifier

    # Initialise the columns that are written by the assignment below
    n_refls = len(refls)
    refls["wavelength"] = flex.double(n_refls)
    refls["miller_index"] = flex.miller_index(n_refls)
    refls["harmonics"] = flex.bool([False] * n_refls)

    for i, experiment in enumerate(expts):
        # Get experiment data from experiment objects
        cryst = experiment.crystal
        spacegroup = gemmi.SpaceGroup(
            cryst.get_space_group().type().universal_hermann_mauguin_symbol()
        )

        # Prefer the user-provided unit cell, but only when it is compatible
        # with the crystal space group; otherwise use the crystal's own cell.
        cell = None
        if params.geometry.unit_cell is not None:
            cell = gemmi.UnitCell(*params.geometry.unit_cell)
            if not cell.is_compatible_with_spacegroup(spacegroup):
                logger.warning(
                    f"WARNING: User-provided unit cell is incompatible with crystal space group on image {i}. Using crystal unit cell instead."
                )
                cell = None
        if cell is None:
            cell = gemmi.UnitCell(*cryst.get_unit_cell().parameters())

        # Generate s vectors
        s1 = refls["s1"].as_numpy_array()

        # Get U matrix
        U = np.asarray(cryst.get_U()).reshape(3, 3)

        # Generate assigner object
        logger.info(f"Reindexing experiment {image_id}.")
        la = LaueAssigner(
            s0,
            s1,
            cell,
            U,
            params.wavelengths.lam_min,
            params.wavelengths.lam_max,
            experiment.beam.get_wavelength(),
            params.reciprocal_grid.d_min,
            spacegroup,
        )

        # Optimize Miller indices: one initial assignment, then the requested
        # number of macrocycles of rotation update and outlier rejection.
        la.assign()
        for _ in range(params.n_macrocycles):
            la.reset_inliers()
            la.update_rotation()
            la.assign()
            la.reject_outliers()
            la.update_rotation()
            la.assign()

        # Update s1 based on new wavelengths (inlier reflections only)
        s1[la._inliers] = la.s1 / la.wav[:, None]

        # Reset crystal parameters based on new geometry
        cryst.set_U(la.R.flatten())
        cryst.set_A(la.RB.flatten())
        cryst.set_B(la.B.flatten())
        cryst.set_space_group(space_group(la.spacegroup.hall))
        cryst.set_unit_cell(unit_cell(la.cell.parameters))

        # Get wavelengths
        spot_wavelengths = np.asarray(la._wav.tolist())

        # Write data to reflections
        refls["s1"] = flex.vec3_double(s1)
        refls["miller_index"] = flex.miller_index(la._H.astype("int").tolist())
        refls["harmonics"] = flex.bool(la._harmonics.tolist())
        refls["wavelength"] = flex.double(spot_wavelengths)

    # Remove reflections outside of the provided wavelength limits.
    # NOTE(review): unindexed reflections carry a wavelength of 0, so this
    # filter presumably also drops them whenever lam_min > 0, regardless of
    # keep_unindexed - confirm that this is intended.
    if params.filter_spectrum:
        all_wavelengths = refls["wavelength"].as_numpy_array()
        keep = np.logical_and(
            all_wavelengths >= params.wavelengths.lam_min,
            all_wavelengths <= params.wavelengths.lam_max,
        )
        refls = refls.select(flex.bool(keep))

    # Unindexed reflections are assigned a wavelength of 0
    indexed = refls["wavelength"].as_numpy_array() > 0
    if not params.keep_unindexed:
        # Remove unindexed reflections; everything left belongs to the image
        refls = refls.select(flex.bool(indexed))
        refls["id"] = flex.int(np.full(len(refls), image_id))
    else:
        # Keep every reflection but tag only the indexed ones with the image
        # id. The mask is wrapped in flex.bool like the other selections.
        refls["id"].set_selected(flex.bool(indexed), int(image_id))

    # Return reindexed expts, refls
    return expts, refls
def _configure_logging(log_filename):
    """Attach a log-file handler and a stdout handler to the loggers used by this program."""
    console = logging.StreamHandler(sys.stdout)
    fh = logging.FileHandler(log_filename, mode="w", encoding="utf-8")
    loglevel = logging.INFO

    # Route Python warnings through the "py.warnings" logger
    logging.captureWarnings(True)

    # These loggers receive both handlers and have their level set; the
    # warnings logger only receives the handlers.
    leveled_loggers = [logger] + [
        logging.getLogger(name) for name in ("dials", "dxtbx", "xfel")
    ]
    for target in leveled_loggers + [logging.getLogger("py.warnings")]:
        target.addHandler(fh)
        target.addHandler(console)
    for target in leveled_loggers:
        target.setLevel(loglevel)
    fh.setLevel(loglevel)


def _parameter_error(params):
    """Return a message describing the first invalid parameter value, or None if all are valid."""
    if params.reciprocal_grid.d_min is None:
        return "Please provide a d_min."
    lam_min = params.wavelengths.lam_min
    lam_max = params.wavelengths.lam_max
    if lam_min is None or lam_max is None:
        return "Please provide upper and lower boundaries for the wavelength spectrum."
    if lam_min > lam_max:
        return "Minimum wavelength cannot be greater than maximum wavelength."
    return None


@show_mail_handle_errors()
def run(args=None, *, phil=working_phil):
    """
    Run the script to optimize Miller indices and wavelengths jointly.

    The input experiments and reflections are split by image, every image is
    reindexed with index_image() (serially when nproc is 1, otherwise in a
    multiprocessing pool), and the combined results are written to the
    output experiment list and reflection table files.

    Args:
        args (list): Command-line arguments.
        phil: The phil scope for the program.

    Returns:
        None
    """
    # Parse arguments
    usage = "laue.optimize_indexing [options] monochromatic.expt monochromatic.refl"
    parser = ArgumentParser(
        usage=usage,
        phil=phil,
        read_reflections=True,
        read_experiments=True,
        check_format=False,
        epilog=help_message,
    )
    params, _ = parser.parse_args(args=args, show_diff_phil=False)

    # Configure logging
    _configure_logging(params.output.log)

    # Print version information
    logger.info(laue_version())

    # Log diff phil
    diff_phil = parser.diff_phil.as_str()
    if diff_phil != "":
        logger.info("The following parameters have been modified:\n")
        logger.info(diff_phil)

    # Print help if no input
    if not params.input.experiments or not params.input.reflections:
        parser.print_help()
        return

    reflections, experiments = reflections_and_experiments_from_files(
        params.input.reflections, params.input.experiments
    )
    reflections = reflections[0]  # Get reflection table out of list

    # Sanity checks
    if len(experiments) == 0:
        parser.print_help()
        return

    # Check for valid parameter values
    problem = _parameter_error(params)
    if problem is not None:
        logger.info(problem)
        return

    # Remove duplicate expt + refl data
    # (presumably so that the params object handed to every worker does not
    # carry the full input data along with it)
    params.input.experiments = None
    params.input.reflections = None

    # Get initial time for process
    start_time = time.time()

    # This will populate ['s1'] & ['rlp'] columns
    reflections.centroid_px_to_mm(experiments)
    reflections.map_centroids_to_reciprocal_space(experiments)

    # Prepare parallel input
    ids = list(np.unique(reflections["id"]).astype(np.int32))
    if -1 in ids:
        ids.remove(-1)
    expts_arr, refls_arr = split_stills_by_image(experiments, reflections)
    inputs = list(zip(repeat(params), refls_arr, expts_arr))

    # Reindex data
    num_processes = params.nproc
    logger.info("Reindexing images.")
    if num_processes == 1:
        output = [index_image(*i) for i in inputs]
    else:
        with Pool(processes=num_processes) as pool:
            output = pool.starmap(index_image, inputs)
    logger.info("All images reindexed.")

    # Convert reindexed data to DIALS objects
    # NOTE(review): `output` is indexed by experiment id here, which assumes
    # the ids are exactly 0..len(output)-1 in the order returned by
    # split_stills_by_image - confirm against that helper.
    total_experiments = ExperimentList()
    total_reflections = reflection_table()
    for i in ids:
        total_experiments.extend(output[i][0])
        total_reflections.extend(output[i][1])

    # Give all unindexed reflections (Miller index 0, 0, 0) wavelength 0
    hkls = np.asarray(total_reflections["miller_index"])
    sel = np.all(hkls == [0, 0, 0], axis=1)
    lams = np.asarray(total_reflections["wavelength"])
    lams[sel] = 0.0
    total_reflections["wavelength"] = flex.double(lams)

    # Save experiments
    logger.info("Saving optimized experiments to %s", params.output.experiments)
    total_experiments.as_file(params.output.experiments)

    # Save reflections
    logger.info("Saving optimized reflections to %s", params.output.reflections)
    total_reflections.as_file(filename=params.output.reflections)

    # Final logs
    logger.info("")
    logger.info(
        "Time Taken for Total Processing = %f seconds", time.time() - start_time
    )
# Run the command-line program only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run()