# Author: Carsten Sachse 11-Aug-2011
# Copyright: EMBL (2010 - 2018), Forschungszentrum Juelich (2019 - 2021)
# License: see license.txt for details
from spring.csinfrastr.csdatabase import SpringDataBase, base
from spring.csinfrastr.csfeatures import Features
from spring.csinfrastr.csproductivity import OpenMpi
from spring.micprgs.micctfdetermine import MicCtfDetermine, MicCtfDeterminePar
from tabulate import tabulate
import os
import shutil
class ScanMpi(object):
    """Mixin that distributes a series of micrographs over MPI ranks.

    The host class must provide ``log``, ``feature_set`` (with ``logfile``),
    ``temppath``, ``micrograph_files``, ``outfile`` and ``infile`` before
    ``startup_scan_mpi_programs`` is called.
    """

    @staticmethod
    def _micrograph_count_error(micrograph_count, cpu_count, infile):
        """Return an error message if the micrographs cannot be spread over the CPUs.

        :param micrograph_count: number of micrographs to be processed
        :param cpu_count: number of MPI ranks
        :param infile: input specification, only used in the error message
        :return: error message string, or None when every rank gets at least one micrograph
        """
        if micrograph_count == 0:
            return 'No micrographs found in {0}.'.format(infile)
        if micrograph_count < cpu_count:
            return ('You requested a larger number of CPUs than micrographs. To make optimal use of the '
                    'resources, please reduce the number of requested CPUs to '
                    '{0} for {0} micrographs.'.format(micrograph_count))
        return None

    def startup_scan_mpi_programs(self):
        """Set up MPI and logging, then scatter micrographs and output names over the ranks.

        Afterwards ``self.micrograph_files`` and ``self.outfiles`` only hold the share of
        the current rank.

        :raises ValueError: on every rank, if rank 0 finds no micrographs or fewer
            micrographs than ranks
        """
        self.comm, self.rank, self.size, self.log, self.tempdir = \
            OpenMpi().setup_mpi_and_simultaneous_logging(self.log, self.feature_set.logfile, self.temppath)
        self.outfiles = Features().rename_series_of_output_files(self.micrograph_files, self.outfile)
        self.log.fcttolog()

        # Validate on rank 0 but share the verdict with all ranks: if only rank 0 raised,
        # the other ranks would keep waiting in the collective scatter below.
        error_msg = None
        if self.rank == 0:
            error_msg = self._micrograph_count_error(len(self.micrograph_files), self.size, self.infile)
        error_msg = self.comm.bcast(error_msg, root=0)
        if error_msg is not None:
            raise ValueError(error_msg)

        if self.rank == 0:
            self.micrograph_files = OpenMpi().split_sequence_evenly(self.micrograph_files, self.size)
            self.outfiles = OpenMpi().split_sequence_evenly(self.outfiles, self.size)
            table_data = [(rank_id, ', '.join(fileset))
                          for rank_id, fileset in enumerate(self.micrograph_files)]
            msg = tabulate(table_data, ['node_id', 'micrographs'])
            log_str = 'The following nodes will handle the following micrographs:\n{0}'.format(msg)
            self.log.ilog(log_str)
        else:
            # scatter ignores the send object on non-root ranks
            self.micrograph_files = None
        self.micrograph_files = self.comm.scatter(self.micrograph_files, root=0)
        self.outfiles = self.comm.scatter(self.outfiles, root=0)

    def end_scan_mpi_programs(self):
        """Synchronise all ranks; rank 0 then closes the log and removes ``self.tempdir``."""
        self.comm.barrier()
        if self.rank == 0:
            self.log.endlog(self.feature_set)
            os.rmdir(self.tempdir)
class MicCtfDetermineMpi(MicCtfDetermine, ScanMpi):
    """MPI-parallel version of ``MicCtfDetermine``.

    ``ScanMpi`` spreads the micrographs over the MPI ranks. Every rank determines the
    CTFFIND/CTFTILT parameters of its own share. The results are gathered on rank 0,
    which is the only rank that enters them in the database.
    """

    def _gather_namedtuples_on_root(self, params_list, make_namedtuple_type):
        """Gather every rank's list of namedtuples on rank 0.

        Collective call: all ranks must call it at the same point and in the same order.

        :param params_list: this rank's list of namedtuples
        :param make_namedtuple_type: callable returning the namedtuple type used to rebuild
            the merged entries; it is only called on rank 0
        :return: on rank 0 the merged list of namedtuples from all ranks, elsewhere None
        """
        as_lists = OpenMpi().convert_list_of_namedtuples_to_list_of_lists(params_list)
        gathered = self.comm.gather(as_lists, root=0)
        if self.rank != 0:
            return None
        merged = OpenMpi().merge_sequence_of_sequences(gathered)
        return OpenMpi().convert_list_of_lists_to_list_of_provided_namedtuple(merged, make_namedtuple_type())

    def gather_ctf_and_enter_in_database(self, ctffind_params_list, ctftilt_params_list):
        """Collect all ranks' CTF results on rank 0 and enter them in the database there.

        Collective call: every rank must call it.

        :param ctffind_params_list: this rank's ctffind parameter namedtuples
        :param ctftilt_params_list: this rank's ctftilt parameter namedtuples
        """
        ctffind_params_combined_list = self._gather_namedtuples_on_root(
            ctffind_params_list, self.make_ctffind_parameters_named_tuple)
        ctftilt_params_combined_list = self._gather_namedtuples_on_root(
            ctftilt_params_list, self.make_ctftilt_parameters_named_tuple)
        micrograph_files = self.comm.gather(self.micrograph_files, root=0)
        if self.rank == 0:
            micrograph_files = OpenMpi().merge_sequence_of_sequences(micrograph_files)
            self.enter_ctffind_and_ctftilt_values_in_database(micrograph_files, ctffind_params_combined_list,
                                                              ctftilt_params_combined_list)

    def fill_micrographs_list_with_dummy(self, micrograph_files, max_micrograph_count):
        """Return ``micrograph_files`` padded with 'place_holder' up to ``max_micrograph_count``.

        The input sequence is not modified; a new list is always returned.

        >>> from spring.micprgs.micctfdetermine_mpi import MicCtfDetermineMpi
        >>> MicCtfDetermineMpi().fill_micrographs_list_with_dummy(['dim', 'dum'], 4)
        ['dim', 'dum', 'place_holder', 'place_holder']
        >>> MicCtfDetermineMpi().fill_micrographs_list_with_dummy(['dim', 'dum'], 2)
        ['dim', 'dum']
        """
        missing_count = max(0, max_micrograph_count - len(micrograph_files))
        return list(micrograph_files) + ['place_holder'] * missing_count

    def insure_that_every_node_has_the_same_number_of_micrographs(self, micrograph_files):
        """Pad this rank's micrograph list to the longest list found on any rank.

        The per-micrograph loop issues collective calls (barrier/gather) in every
        iteration, so all ranks must run the same number of iterations.

        :param micrograph_files: micrographs of this rank
        :return: new list padded with 'place_holder' entries
        """
        max_micrograph_count = max(self.comm.allgather(len(micrograph_files)))
        return self.fill_micrographs_list_with_dummy(micrograph_files, max_micrograph_count)

    def run_ctffind_and_ctftilt_for_given_micrographs(self, micrograph_files, outfiles):
        """Determine CTF parameters for this rank's micrographs and store them via rank 0.

        Rank 0 first provides ``spring.db``: it is copied from ``self.spring_path`` when
        ``self.spring_db_option`` is set; otherwise ``SpringDataBase().setup_sqlite_db(base)``
        is called. In every loop iteration each rank processes one micrograph, or a
        place holder. The results are gathered on rank 0 and entered into the database.

        :param micrograph_files: micrographs assigned to this rank
        :param outfiles: output names for this rank; passed on together with the
            micrograph index (presumably parallel to ``micrograph_files`` — verify)
        """
        if self.rank == 0:
            if self.spring_db_option:
                shutil.copy(self.spring_path, 'spring.db')
            else:
                SpringDataBase().setup_sqlite_db(base)
        # no rank may set up its database session before rank 0 has provided spring.db
        self.comm.barrier()
        session, ctf_parameters, micrograph_files = self.setup_database_and_ctfinfo(micrograph_files)
        micrograph_files = self.insure_that_every_node_has_the_same_number_of_micrographs(micrograph_files)
        self.log.plog(10)
        ctftilt_parameters = None
        for each_micrograph_index, each_micrograph_file in enumerate(micrograph_files):
            if each_micrograph_file != 'place_holder':
                ctffind_parameters, ctftilt_parameters = self.run_ctffind_and_ctftilt_for_each_micrograph(
                    micrograph_files, outfiles, each_micrograph_index, each_micrograph_file)
            else:
                # place holders only keep this rank in step with the collective calls below
                ctffind_type = self.make_ctffind_parameters_named_tuple()
                ctffind_parameters = ctffind_type._make([None] * len(ctffind_type._fields))
                ctftilt_parameters = self.make_empty_ctftilt_parameters()
            self.comm.barrier()
            those_micrographs = self.comm.gather(each_micrograph_file, root=0)
            ctffind_params_combined_list = self._gather_namedtuples_on_root(
                [ctffind_parameters], self.make_ctffind_parameters_named_tuple)
            ctftilt_params_combined_list = self._gather_namedtuples_on_root(
                [ctftilt_parameters], self.make_ctftilt_parameters_named_tuple)
            if self.rank == 0:
                zipped_info = zip(those_micrographs, ctffind_params_combined_list, ctftilt_params_combined_list)
                for that_micrograph_file, each_ctffind, each_ctftilt in zipped_info:
                    # entries without a defocus1 value (e.g. the place holders) are not stored
                    if each_ctffind.defocus1 is not None:
                        session = self.enter_ctffind_values_in_database(session, that_micrograph_file,
                                                                        self.ori_pixelsize, ctf_parameters,
                                                                        each_ctffind)
                    if each_ctftilt.defocus1 is not None:
                        session = self.enter_ctftilt_values_in_database(session, that_micrograph_file,
                                                                        ctf_parameters.pixelsize, each_ctftilt)
                session.commit()
        # without any iteration there are no new parameters; avoid an UnboundLocalError
        if micrograph_files:
            self.ctftilt_parameters = ctftilt_parameters

    def determine_ctf(self):
        """Entry point: distribute micrographs, determine their CTF on all ranks, clean up."""
        self.startup_scan_mpi_programs()
        # NOTE(review): run_ctffind_and_ctftilt_for_given_micrographs contains collective
        # calls, so skipping it relies on every rank having received micrographs (the
        # CPU-count check at start-up is meant to guarantee that).
        if self.micrograph_files != []:
            self.run_ctffind_and_ctftilt_for_given_micrographs(self.micrograph_files, self.outfiles)
        self.comm.barrier()
        self.end_scan_mpi_programs()
def main():
    """Run the MPI-parallel CTF determination program.

    A ``MicCtfDeterminePar`` parameter set is built first. It is handed to
    ``OpenMpi().start_main_mpi``, whose result configures ``MicCtfDetermineMpi``.
    """
    full_parameters = MicCtfDeterminePar()
    program_parameters = OpenMpi().start_main_mpi(full_parameters)

    # run the program with the parameters returned by the MPI start-up
    MicCtfDetermineMpi(program_parameters).determine_ctf()


if __name__ == '__main__':
    main()