# Author: Carsten Sachse 11-Aug-2011
# Copyright: EMBL (2010 - 2018), Forschungszentrum Juelich (2019 - 2021)
# License: see license.txt for details
from spring.csinfrastr.csdatabase import SpringDataBase, base
from spring.csinfrastr.csfeatures import Features
from spring.csinfrastr.csproductivity import OpenMpi
from spring.micprgs.micctfdetermine import MicCtfDetermine, MicCtfDeterminePar
from tabulate import tabulate
import os
import shutil
class ScanMpi(object):
    """
    Mixin with the MPI start-up and shut-down steps shared by micrograph-scanning programs.

    The methods read attributes that are never assigned in this class and therefore have to be provided by
    the class it is mixed into: ``log``, ``feature_set``, ``temppath``, ``micrograph_files``, ``outfile``
    and ``infile``.
    """

    def startup_scan_mpi_programs(self):
        """
        Set up MPI with logging and hand every rank its share of micrographs and output files.

        Sets ``self.comm``, ``self.rank``, ``self.size``, ``self.log`` and ``self.tempdir`` from the values
        returned by ``OpenMpi().setup_mpi_and_simultaneous_logging``. Rank 0 splits the micrograph and
        output file lists into ``self.size`` parts, logs the assignment table and scatters the parts, so
        that afterwards ``self.micrograph_files`` and ``self.outfiles`` hold what was scattered to this rank.

        Raises:
            ValueError: on rank 0 if no micrographs were found or if there are fewer micrographs than
                requested CPUs (``self.size``).
        """
        self.comm, self.rank, self.size, self.log, self.tempdir = \
            OpenMpi().setup_mpi_and_simultaneous_logging(self.log, self.feature_set.logfile, self.temppath)

        self.outfiles = Features().rename_series_of_output_files(self.micrograph_files, self.outfile)

        self.log.fcttolog()
        if self.rank == 0:
            # NOTE(review): these checks run on rank 0 only. If they raise, the other ranks are presumably
            # left waiting in the scatter below - confirm that the MPI launcher aborts the whole job.
            micrograph_count = len(self.micrograph_files)
            if micrograph_count == 0:
                raise ValueError('No micrographs found in {0}.'.format(self.infile))
            if micrograph_count < self.size:
                msg = ('You requested a larger number of CPUs than micrographs. To optimally make use of the '
                       'resources, please reduce number of requested CPUs to '
                       '{0} for {0} micrographs.'.format(micrograph_count))
                raise ValueError(msg)

            self.micrograph_files = OpenMpi().split_sequence_evenly(self.micrograph_files, self.size)
            self.outfiles = OpenMpi().split_sequence_evenly(self.outfiles, self.size)

            table_data = [(each_rank_id, ', '.join(each_fileset))
                          for each_rank_id, each_fileset in enumerate(self.micrograph_files)]
            msg = tabulate(table_data, ['node_id', 'micrographs'])
            log_str = 'The following nodes will handle the following micrographs:\n{0}'.format(msg)
            self.log.ilog(log_str)
        else:
            # Only the root provides the data to scatter.
            self.micrograph_files = None

        self.micrograph_files = self.comm.scatter(self.micrograph_files, root=0)
        self.outfiles = self.comm.scatter(self.outfiles, root=0)

    def end_scan_mpi_programs(self):
        """
        Wait for all ranks, close the log on rank 0 and remove this rank's temporary directory.

        ``os.rmdir`` only removes empty directories, so ``self.tempdir`` must have been emptied before.
        """
        self.comm.barrier()
        if self.rank == 0:
            self.log.endlog(self.feature_set)
        os.rmdir(self.tempdir)
class MicCtfDetermineMpi(MicCtfDetermine, ScanMpi):
    """
    MPI variant of MicCtfDetermine: every rank runs CTFFIND/CTFTILT on its share of micrographs and the
    results are gathered on rank 0, which is the only rank that enters CTF values into the database.
    """

    # Marker used to pad the per-rank micrograph lists to equal length.
    _PLACE_HOLDER = 'place_holder'

    def _merge_gathered_ctf_parameters(self, gathered_ctffinds, gathered_ctftilts):
        """
        Flatten the per-rank CTFFIND/CTFTILT rows gathered on the root and rebuild the named tuples.

        Intended for rank 0 only, i.e. the rank that received the gathered data.

        Returns:
            tuple ``(ctffind_params_list, ctftilt_params_list)`` of lists of named tuples.
        """
        ctffinds = OpenMpi().merge_sequence_of_sequences(gathered_ctffinds)
        ctftilts = OpenMpi().merge_sequence_of_sequences(gathered_ctftilts)

        ctffind_tuple = self.make_ctffind_parameters_named_tuple()
        ctffind_params_list = OpenMpi().convert_list_of_lists_to_list_of_provided_namedtuple(ctffinds,
                                                                                            ctffind_tuple)

        ctftilt_tuple = self.make_ctftilt_parameters_named_tuple()
        ctftilt_params_list = OpenMpi().convert_list_of_lists_to_list_of_provided_namedtuple(ctftilts,
                                                                                            ctftilt_tuple)

        return ctffind_params_list, ctftilt_params_list

    def gather_ctf_and_enter_in_database(self, ctffind_params_list, ctftilt_params_list):
        """
        Gather the CTF parameters and micrograph names of all ranks on rank 0 and enter them in the database.

        The named tuples are converted to plain lists before the gather and converted back on rank 0.
        Ranks other than 0 only take part in the gather.
        """
        ctffinds = OpenMpi().convert_list_of_namedtuples_to_list_of_lists(ctffind_params_list)
        ctftilts = OpenMpi().convert_list_of_namedtuples_to_list_of_lists(ctftilt_params_list)

        ctffinds = self.comm.gather(ctffinds, root=0)
        ctftilts = self.comm.gather(ctftilts, root=0)
        micrograph_files = self.comm.gather(self.micrograph_files, root=0)
        if self.rank != 0:
            return

        ctffind_params_combined_list, ctftilt_params_combined_list = \
            self._merge_gathered_ctf_parameters(ctffinds, ctftilts)
        micrograph_files = OpenMpi().merge_sequence_of_sequences(micrograph_files)

        self.enter_ctffind_and_ctftilt_values_in_database(micrograph_files, ctffind_params_combined_list,
                                                          ctftilt_params_combined_list)

    def fill_micrographs_list_with_dummy(self, micrograph_files, max_micrograph_count):
        """
        Return a copy of micrograph_files padded with 'place_holder' up to max_micrograph_count entries.

        The input list is left untouched; lists that are already long enough are returned as a plain copy.

        >>> from spring.micprgs.micctfdetermine_mpi import MicCtfDetermineMpi
        >>> MicCtfDetermineMpi().fill_micrographs_list_with_dummy(['dim', 'dum'], 4) 
        ['dim', 'dum', 'place_holder', 'place_holder']
        >>> MicCtfDetermineMpi().fill_micrographs_list_with_dummy(['dim', 'dum'], 2) 
        ['dim', 'dum']
        """
        missing_count = max(0, max_micrograph_count - len(micrograph_files))

        return list(micrograph_files) + missing_count * [self._PLACE_HOLDER]

    def insure_that_every_node_has_the_same_number_of_micrographs(self, micrograph_files):
        """
        Pad this rank's micrograph list to the length of the longest list among all ranks.

        Equal lengths are required because every loop iteration in
        run_ctffind_and_ctftilt_for_given_micrographs performs collective barrier/gather calls.
        """
        micrograph_count = self.comm.gather(len(micrograph_files), root=0)
        # Only rank 0 holds the gathered counts; it determines the maximum and broadcasts it.
        max_micrograph_count = max(micrograph_count) if self.rank == 0 else None
        max_micrograph_count = self.comm.bcast(max_micrograph_count, root=0)

        return self.fill_micrographs_list_with_dummy(micrograph_files, max_micrograph_count)

    def run_ctffind_and_ctftilt_for_given_micrographs(self, micrograph_files, outfiles):
        """
        Run CTFFIND/CTFTILT on this rank's micrographs and let rank 0 enter the results in the database.

        After each micrograph the results of all ranks are gathered on rank 0 and committed. The CTFTILT
        parameters of the last processed entry are stored in ``self.ctftilt_parameters``.
        """
        if self.rank == 0:
            if self.spring_db_option:
                shutil.copy(self.spring_path, 'spring.db')
            else:
                SpringDataBase().setup_sqlite_db(base)

        # All ranks wait until rank 0 has finished the database preparation above.
        self.comm.barrier()
        session, ctf_parameters, micrograph_files = self.setup_database_and_ctfinfo(micrograph_files)

        micrograph_files = self.insure_that_every_node_has_the_same_number_of_micrographs(micrograph_files)

        self.log.plog(10)
        if not micrograph_files:
            # After padding all ranks have the same length, so no rank has anything to process. Without
            # this guard the assignment after the loop would fail on an unbound local variable.
            self.ctftilt_parameters = self.make_empty_ctftilt_parameters()
            return

        for each_micrograph_index, each_micrograph_file in enumerate(micrograph_files):
            if each_micrograph_file != self._PLACE_HOLDER:
                ctffind_parameters, ctftilt_parameters = self.run_ctffind_and_ctftilt_for_each_micrograph(
                    micrograph_files, outfiles, each_micrograph_index, each_micrograph_file)
            else:
                # Padding entry: contribute empty results so that this rank still joins the gathers below.
                np_ctffind = self.make_ctffind_parameters_named_tuple()
                ctffind_parameters = np_ctffind._make(len(np_ctffind._fields) * [None])
                ctftilt_parameters = self.make_empty_ctftilt_parameters()

            self.comm.barrier()

            ctffinds = OpenMpi().convert_list_of_namedtuples_to_list_of_lists([ctffind_parameters])
            ctftilts = OpenMpi().convert_list_of_namedtuples_to_list_of_lists([ctftilt_parameters])

            those_micrographs = self.comm.gather(each_micrograph_file, root=0)
            ctffinds = self.comm.gather(ctffinds, root=0)
            ctftilts = self.comm.gather(ctftilts, root=0)
            if self.rank != 0:
                continue

            ctffind_params_combined_list, ctftilt_params_combined_list = \
                self._merge_gathered_ctf_parameters(ctffinds, ctftilts)

            zipped_info = zip(those_micrographs, ctffind_params_combined_list, ctftilt_params_combined_list)
            for that_micrograph_file, each_ctffind, each_ctftilt in zipped_info:
                # A defocus1 of None marks an empty result that must not be entered.
                if each_ctffind.defocus1 is not None:
                    session = self.enter_ctffind_values_in_database(session, that_micrograph_file,
                                                                    self.ori_pixelsize, ctf_parameters,
                                                                    each_ctffind)

                if each_ctftilt.defocus1 is not None:
                    session = self.enter_ctftilt_values_in_database(session, that_micrograph_file,
                                                                    ctf_parameters.pixelsize, each_ctftilt)

            session.commit()

        self.ctftilt_parameters = ctftilt_parameters

    def determine_ctf(self):
        """
        Program entry point: distribute the micrographs, determine their CTF and shut MPI down cleanly.
        """
        self.startup_scan_mpi_programs()

        if self.micrograph_files != []:
            self.run_ctffind_and_ctftilt_for_given_micrographs(self.micrograph_files, self.outfiles)

        self.comm.barrier()

        self.end_scan_mpi_programs()
def main():
    """
    Command-line entry point: build the parameter set, start MPI and run the CTF determination.
    """
    parset = MicCtfDeterminePar()
    reduced_parset = OpenMpi().start_main_mpi(parset)

    ####### Program
    micrograph = MicCtfDetermineMpi(reduced_parset)
    micrograph.determine_ctf()
# Run the program only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()