# Author: Carsten Sachse 08-Jun-2011
# Copyright: EMBL (2010 - 2018), Forschungszentrum Juelich (2019 - 2021)
# License: see license.txt for details
from EMAN2 import EMData, Util, Transform, Vec2f, Reconstructors
from spring.micprgs.scansplit import Micrograph
from spring.segment2d.segment import Segment
from spring.segment3d.segclassreconstruct_prep import SegClassReconstructCylinderMask
from tabulate import tabulate
from utilities import compose_transform2, model_circle, get_sym
import numpy as np
[docs]class SegClassReconstructAssist(SegClassReconstructCylinderMask):                
[docs]    def compute_helical_inplane_rotation_from_Euler_angles(self, phi, theta, psi):
        """
        >>> from spring.segment3d.segclassreconstruct import SegClassReconstruct
        >>> SegClassReconstruct().compute_helical_inplane_rotation_from_Euler_angles(180, 90, 270)
        0.0
        >>> SegClassReconstruct().compute_helical_inplane_rotation_from_Euler_angles(0, 90, 270)
        0.0
        >>> SegClassReconstruct().compute_helical_inplane_rotation_from_Euler_angles(33, 90, 334)
        63.999997560564
        >>> SegClassReconstruct().compute_helical_inplane_rotation_from_Euler_angles(180, 90, 240)
        329.999999554708
        >>> SegClassReconstruct().compute_helical_inplane_rotation_from_Euler_angles(150.0, 94.0, 315)
        44.999997495521825
        >>> SegClassReconstruct().compute_helical_inplane_rotation_from_Euler_angles(150.0, 86.0, 315)
        224.99999749552182
        >>> SegClassReconstruct().compute_helical_inplane_rotation_from_Euler_angles(170.0, 90.0, 315)
        44.999997495521825
        """
        RABTeuler = self.compute_transform_with_aligned_helix_segment(phi, 90.0, psi)
        RABTphi = RABTeuler['phi']
        RABTpsi = RABTeuler['psi']
        
#        delta_inplane_rotation = (RABTpsi + RABTphi) % 360
        if theta >= 90:
            delta_inplane_rotation = (RABTpsi + RABTphi) % 360
        else:
            delta_inplane_rotation = (RABTpsi + RABTphi + 180) % 360
        
        return delta_inplane_rotation 
    
[docs]    def compute_point_where_normal_and_parallel_line_intersect(self, shift_x, shift_y, inplane_angle):
        """
        >>> from spring.segment3d.segclassreconstruct import SegClassReconstruct
        >>> s = SegClassReconstruct()
        >>> s.compute_point_where_normal_and_parallel_line_intersect(3, 1, 0)
        (3.0, -1.8369701987210297e-16)
        >>> s.compute_point_where_normal_and_parallel_line_intersect(3, 1, 45)
        (2.0000000000000004, 2.0)
        >>> s.compute_point_where_normal_and_parallel_line_intersect(3, 1, -45)
        (0.9999999999999998, -0.9999999999999999)
        >>> s.compute_point_where_normal_and_parallel_line_intersect(3, 1, 90)
        (1.7449489701899453e-12, 1.000000000005235)
        """
        if inplane_angle == 90.0 or inplane_angle == 270.0:
            inplane_angle -= 1e-10
            
        if 90.0 + inplane_angle == 0:
            slope = np.tan(np.deg2rad(360.0))
        else:
            slope = np.tan(np.deg2rad(90.0 + inplane_angle))
            
        t = shift_y - slope * shift_x 
        
        point_x = -t / (1/slope + slope)
        point_y = -point_x / slope 
        
        return point_x, point_y 
    
        
[docs]    def compute_distances_to_helical_axis(self, shift_x, shift_y, inplane_angle):
        """
        >>> from spring.segment3d.segclassreconstruct import SegClassReconstruct
        >>> SegClassReconstruct().compute_distances_to_helical_axis(22, 2, 0)
        (-22.0, 2.0000000000000013)
        >>> SegClassReconstruct().compute_distances_to_helical_axis(22, 2, 180)
        (22.0, -2.000000000000004)
        >>> SegClassReconstruct().compute_distances_to_helical_axis(22, 2, 90)
        (-2.000000000038389, -21.99999999999651)
        >>> SegClassReconstruct().compute_distances_to_helical_axis(3, 1, -45)
        (-1.414213562373095, 2.82842712474619)
        >>> SegClassReconstruct().compute_distances_to_helical_axis(3, 1, 45)
        (-2.8284271247461903, -1.4142135623730947)
        >>> SegClassReconstruct().compute_distances_to_helical_axis(22, 2, 20)
        (-21.357277943941323, -5.645057911592895)
        >>> s = SegClassReconstruct()
        >>> hx, hy = s.compute_distances_to_helical_axis(22, 2, 180)
        >>> SegClassReconstruct().compute_distances_to_helical_axis(hx, hy, 180)
        (22.0, 2.0)
        """
        
        point_x, point_y = self.compute_point_where_normal_and_parallel_line_intersect(shift_x, shift_y, inplane_angle)
        dist_x, dst_y = Segment().rotate_coordinates_by_angle(-point_x, -point_y, inplane_angle)
        
        centered_x = shift_x - point_x
        centered_y = shift_y - point_y
        dst_x, dist_y = Segment().rotate_coordinates_by_angle(centered_x, centered_y, inplane_angle)
        
        return dist_x, dist_y 
    
    
[docs]    def prepare_symmetry_computation(self, alignment_parameters, each_symmetry_pair, pixelsize, symmetry_views_count):
        helical_rise = each_symmetry_pair[0] / pixelsize
        helical_rotation = each_symmetry_pair[1]
        symmetry_views = np.arange(-symmetry_views_count / 2, symmetry_views_count / 2, dtype='float64')
        
        x_norm, y_parallel = self.compute_distances_to_helical_axis(alignment_parameters.x_shift,
        alignment_parameters.y_shift, alignment_parameters.inplane_angle)
        
        if helical_rise != 0:
            multiple_off_center = np.round(y_parallel / helical_rise)
        else:
            multiple_off_center = np.float(0)
            
        symmetry_views += multiple_off_center
        out_of_plane_correction = np.sin(np.deg2rad(alignment_parameters.theta))
        
        return symmetry_views, helical_rotation, out_of_plane_correction, helical_rise 
        
[docs]    def compute_straight_phi_and_yshift(self, each_symmetry_view, phi, helical_rotation, out_of_plane_correction,
    helical_rise):
        symmetrized_phi = (phi + each_symmetry_view * helical_rotation) % 360
        straight_yshift = out_of_plane_correction * each_symmetry_view * helical_rise
        
        return straight_yshift, symmetrized_phi 
[docs]    def compute_sx_sy_from_shifts_normal_and_parallel_to_helix_axis_with_inplane_angle(self, helix_shift_x,
    helix_shift_y, inplane_angle):
        """
        >>> from spring.segment3d.segclassreconstruct import SegClassReconstruct
        >>> hx, hy = SegClassReconstruct().compute_distances_to_helical_axis(3, 1, -45)
        >>> s = SegClassReconstruct()
        >>> s.compute_sx_sy_from_shifts_normal_and_parallel_to_helix_axis_with_inplane_angle(hx, hy, -45)
        (2.9999999999999996, 1.0000000000000002)
        >>> s.compute_sx_sy_from_shifts_normal_and_parallel_to_helix_axis_with_inplane_angle(3, 0, -45)
        (-2.121320343559643, 2.1213203435596424)
        >>> hx, hy = SegClassReconstruct().compute_distances_to_helical_axis(22, 2, 20)
        >>> s.compute_sx_sy_from_shifts_normal_and_parallel_to_helix_axis_with_inplane_angle(hx, hy, 20)
        (22.0, 2.0)
        """
        point_x, point_y = Segment().rotate_coordinates_by_angle(helix_shift_x, 0, -inplane_angle)
        helix_y_from_point = helix_shift_y - point_y
        sx, sy = Segment().rotate_coordinates_by_angle(-point_x, helix_y_from_point, -inplane_angle, -point_x, -point_y)
        
        return sx, sy 
        
    
[docs]    def correct_straight_symmetry_parameters_for_in_plane_rotation_and_shifts_old(self, inplane_angle, x_shift, y_shift,
    straight_yshift):
#        x_contribution, y_contribution = self.compute_sx_sy_from_shifts_normal_and_parallel_to_helix_axis_with_inplane_angle(0.0, straight_yshift, inplane_angle)
        new_rot, x_contribution, y_contribution, scale = compose_transform2(0.0, 0, straight_yshift, 1.0, inplane_angle,
        0.0, 0.0, 1.0)
        
        total_symmetrized_xshift = x_contribution + x_shift
        total_symmetrized_yshift = y_contribution + y_shift
        
        return total_symmetrized_xshift, total_symmetrized_yshift 
    
    
[docs]    def correct_straight_symmetry_parameters_for_in_plane_rotation_and_shifts(self, inplane_angle, x_shift, y_shift,
    straight_yshift):
        
        x_norm, y_parallel = self.compute_distances_to_helical_axis(x_shift, y_shift, inplane_angle)
        
        total_symmetrized_xshift, total_symmetrized_yshift = \
        
self.compute_sx_sy_from_shifts_normal_and_parallel_to_helix_axis_with_inplane_angle(x_norm, y_parallel + \
        
straight_yshift, inplane_angle)
        
        return -total_symmetrized_xshift, -total_symmetrized_yshift 
    
    
    
[docs]    def setup_reconstructor(self, imgsize):
        fftvol = EMData()
        weight = EMData()
        #ds.humanize_bytes(s.compute_byte_size_of_image_stack(4 * 200,4 * 200,4 * 200)) = 1.91 GB
        if 0 < imgsize <= 200:
            npad = 4
        # ds.humanize_bytes(s.compute_byte_size_of_image_stack(3 * 267,3 * 267,3 * 267)) = 1.91GB
        elif 200 < imgsize <= 267:
            npad = 3
        #ds.humanize_bytes(s.compute_byte_size_of_image_stack(2 * 400,2 * 400,2 * 400)) = 1.91 GB
        elif 267 < imgsize <= 400:
            npad = 2
        elif imgsize > 400: 
            npad = 1.5
        params = {'size':imgsize, 'npad':npad, 'symmetry':'c1', 'fftvol':fftvol, 'weight':weight}
        r = Reconstructors.get('nn4', params)
        r.setup()
        
        return r, fftvol, weight 
    
    
    
    
        
[docs]    def convert_alignment_parameters_to_symmetry_alignment_parameters(self, each_param):
        sym_align_parameters = np.vstack([[each_param.stack_id, each_param.phi, each_param.theta, each_param.psi,
        each_param.x_shift, each_param.y_shift, each_param.mirror, each_param.seg_ref_id]])
        
        return sym_align_parameters, each_param.inplane_angle 
    
    
    
    
[docs]    def project_from_volume_and_backproject(self, alignment_parameters, trimmed_image_size, xvol,
    pxvol):
        for each_image_parameters in alignment_parameters:
            for (each_stack_id, each_phi, each_theta, each_psi, each_x, each_y, each_ref_id) in each_image_parameters:
                
                transform_projection = \
                
self.setup_transform_with_reduced_alignment_parameters(trimmed_image_size, each_phi,
                each_theta, each_psi, each_x, each_y)
                
                myparams = {'transform':transform_projection, 'anglelist':[each_phi, each_theta, each_psi],
                'radius':(trimmed_image_size/2 - 2)}
                
                data = xvol.project('chao', myparams) 
                pxvol += data.backproject('chao', myparams)
        
        return pxvol