/*  Difference of Gaussians - Difference of Gaussians Method for nucleus center detection in 3D data 
 *  Written in 2015 by BioEmergences CNRS USR bioemergences@inaf.cnrs-gif.fr
 *  Barbara Rizzi brizzi08@gmail.com
 *  Paul Bourgine paul.bourgine@polytechnique.edu
 *  
 *  To the extent possible under law, the author(s) have dedicated all copyright and related and neighboring rights to this software to the public domain worldwide. This software is distributed without any warranty.
 *  You should have received a copy of the CC BY-NC-SA 4.0 Dedication along with this software. If not, see <https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode>.
*/


#include "centers_detection.h"
#include <assert.h>
#include <cstddef>
#include <vector>
#include "timer.h"

#include <itksys/SystemTools.hxx>

#define IND (i+(xSize+2)*(j+k*(ySize+2)))

/*****************************************************************************************/

template < class TInputImage >
typename TInputImage::Pointer ReadingPreprocessingData< TInputImage >::ReadingPreprocessing( std::string nameNuclei, std::string nameMembranes, float weight) 
{
    // Builds the image used for center detection from the nuclei channel and,
    // optionally, the membrane channel (both rescaled to [0,1] by ReadingRescaling).
    //
    //  nameNuclei    : path of the nuclei-channel image.
    //  nameMembranes : path of the membrane-channel image; empty string -> nuclei only.
    //  weight        : factor applied to the nuclei intensity before it is compared
    //                  with the membrane intensity.
    //
    // Returns the rescaled nuclei image when no membrane image is given. Otherwise
    // returns a new image that keeps the nuclei intensity where
    // nuclei*weight - membranes > 0 and is 0 elsewhere.
    // Exits the program if the two channels do not have the same size.
    typedef itk::ImageRegionIterator< TInputImage > IteratorType;
    typedef itk::ImageRegionConstIterator< TInputImage > ConstIteratorType;
    typedef typename TInputImage::PixelType DataType;

    // Nuclei (ReadingRescaling already returns a fresh image, no pre-allocation needed)
    typename TInputImage::Pointer imageOutNuclei = ReadingPreprocessingData::ReadingRescaling( nameNuclei );

    if( nameMembranes.empty() )
    {
        cout<<"Only nuclei"<<endl;
        return imageOutNuclei;
    }

    // Membranes
    cout<<"Combining nuclei and membranes signals"<<endl;
    typename TInputImage::Pointer imageOutMembranes = ReadingPreprocessingData::ReadingRescaling( nameMembranes );

    // The loop below walks both channels in lock-step. With different sizes it
    // would stop early and leave part of the (uninitialised) output unwritten,
    // and the assert after the loop disappears in release builds, so check up front.
    if( imageOutNuclei->GetRequestedRegion().GetSize() != imageOutMembranes->GetRequestedRegion().GetSize() )
    {
        std::cerr << "Nuclei and membranes images have different sizes!" << std::endl;
        exit(EXIT_FAILURE);
    }

    typename TInputImage::Pointer image = TInputImage::New();
    image->SetRegions( imageOutNuclei->GetRequestedRegion() );
    image->CopyInformation( imageOutNuclei );
    image->Allocate();

    ConstIteratorType ItN(imageOutNuclei, imageOutNuclei->GetRequestedRegion());
    ConstIteratorType ItM(imageOutMembranes, imageOutMembranes->GetRequestedRegion());
    IteratorType outIt(image, image->GetRequestedRegion());

    for( ItN.GoToBegin(), ItM.GoToBegin(), outIt.GoToBegin(); !ItN.IsAtEnd() && !ItM.IsAtEnd() && !outIt.IsAtEnd(); ++ItN, ++ItM, ++outIt )
    {
        // Keep the nuclei voxel only where the weighted nuclei signal dominates.
        DataType value=ItN.Get()*weight-ItM.Get();
        if ( value > 0 ) outIt.Set( ItN.Get() );
        else outIt.Set( 0 );
    }

    assert(outIt.IsAtEnd() && ItM.IsAtEnd() && ItN.IsAtEnd());

    return image;
}


template < class TInputImage >
typename TInputImage::Pointer ReadingPreprocessingData< TInputImage >::ReadingRescaling( std::string name )
{
    // Loads the image stored at `name` and linearly maps its intensities onto
    // [0, 1]. If ITK cannot read the file, an error is printed to stderr and
    // the program exits with EXIT_FAILURE.
    typedef itk::ImageFileReader< TInputImage >                          FileReaderType;
    typedef itk::RescaleIntensityImageFilter< TInputImage, TInputImage > IntensityRescalerType;

    typename FileReaderType::Pointer fileReader = FileReaderType::New();
    fileReader->SetFileName( name );

    try
    {
        fileReader->Update();
    }
    catch( itk::ExceptionObject & err )
    {
        std::cerr << "Reading DATA failure!" << std::endl;
        std::cerr << "\nError : " << err;

        exit(EXIT_FAILURE);
    }

    // Normalise the intensity range to [0, 1].
    typename IntensityRescalerType::Pointer normaliser = IntensityRescalerType::New();
    normaliser->SetInput( fileReader->GetOutput() );
    normaliser->SetOutputMinimum( 0 );
    normaliser->SetOutputMaximum( 1 );
    normaliser->Update();

    return normaliser->GetOutput();
}

/*****************************************************************************************/

template < class TInputImage >
typename TInputImage::Pointer DifferenceOfGaussian< TInputImage >::PerformDifferenceOfGaussian( typename TInputImage::Pointer Image, float stdSmall, float stdBig, float threshold, bool fileExists1, bool fileExists2, std::string nameStdSmall, std::string nameStdBig )
{
    // Difference of Gaussians: smooth `Image` with sigma stdSmall and stdBig,
    // subtract (small - big), keep differences >= threshold (others -> 0), then
    // rescale the result to [0, 1].
    //
    //  fileExists1/2           : when true, the corresponding smoothed image is read
    //                            from nameStdSmall/nameStdBig instead of being computed.
    //  nameStdSmall/nameStdBig : cache paths; freshly computed convolutions are
    //                            written there.
    //
    // Side effect: the pixel buffer of `Image` is overwritten with the thresholded
    // difference. The rescaled copy is returned.
    // Exits the program if a smoothed image does not match the size of `Image`.
    typedef itk::ImageFileReader< TInputImage > ReaderType;
    typedef itk::ImageFileWriter< TInputImage > WriterType;
    Timer timer;

    // First convolution - stdSmall
    typename TInputImage::Pointer convImageSmall;
    if(fileExists1) // Reading the file if it exists already
    {
        typename ReaderType::Pointer readerStdSmall = ReaderType::New();
        readerStdSmall->SetFileName( nameStdSmall );
        readerStdSmall->Update();
        convImageSmall=readerStdSmall->GetOutput();
        cout<<"file 1 exists already"<<endl;
    }
    else // otherwise, performing the convolution and saving the convolved image
    {
        convImageSmall = DifferenceOfGaussian::GaussianConvolution(Image, stdSmall);
        typename WriterType::Pointer writerStdSmall = WriterType::New();
        writerStdSmall->SetInput(convImageSmall);
        writerStdSmall->SetFileName( nameStdSmall );
        writerStdSmall->Update();
        cout<<"Ok first convolution"<<endl;
    }
    printf("\n First convolution with std %e time secs : %e\n",stdSmall, timer.timePassed());

    timer.reset();

    // Second convolution - stdBig
    typename TInputImage::Pointer convImageBig;
    if(fileExists2) // Reading the file if it exists already
    {
        typename ReaderType::Pointer readerStdBig = ReaderType::New();
        readerStdBig->SetFileName( nameStdBig );
        readerStdBig->Update();
        convImageBig=readerStdBig->GetOutput();
        cout<<"file 2 exixts already"<<endl;
    }
    else // otherwise, performing the convolution and saving the convolved image
    {
        convImageBig = DifferenceOfGaussian::GaussianConvolution(Image, stdBig);
        typename WriterType::Pointer writerStdBig = WriterType::New();
        writerStdBig->SetFileName( nameStdBig );
        writerStdBig->SetInput(convImageBig);
        writerStdBig->Update();
        cout<<"Ok second convolution"<<endl;
    }
    printf("\n Second convolution with std %e time secs : %e\n",stdBig, timer.timePassed());
    timer.reset();

    // A cached file left over from another dataset would make the lock-step loop
    // below stop early. The uncovered voxels of `Image` would then keep their
    // raw input values, with no error reported.
    const typename TInputImage::SizeType inputSize = Image->GetRequestedRegion().GetSize();
    if( convImageSmall->GetRequestedRegion().GetSize() != inputSize ||
        convImageBig->GetRequestedRegion().GetSize() != inputSize )
    {
        std::cerr << "Convolved images and input image have different sizes (stale cached file?)" << std::endl;
        exit(EXIT_FAILURE);
    }

    // Difference and thresholding
    typedef itk::ImageRegionIterator< TInputImage > IteratorType;
    typedef itk::ImageRegionConstIterator< TInputImage > ConstIteratorType;

    ConstIteratorType ItSmall(convImageSmall, convImageSmall->GetRequestedRegion());
    ConstIteratorType ItBig(convImageBig, convImageBig->GetRequestedRegion());
    IteratorType outIt(Image, Image->GetRequestedRegion());

    float value, value1, value2;
    for( ItSmall.GoToBegin(), ItBig.GoToBegin(), outIt.GoToBegin();
         !ItSmall.IsAtEnd() && !ItBig.IsAtEnd() && !outIt.IsAtEnd(); ++ItSmall, ++ItBig, ++outIt )
    {
        value1=ItSmall.Get();
        value2=ItBig.Get();
        value = value1-value2;
        if ( value>=threshold ) outIt.Set( value );
        else outIt.Set( 0 );
    }

    assert(ItSmall.IsAtEnd() && ItBig.IsAtEnd() && outIt.IsAtEnd());
    cout<<"Ok thresholding of difference"<<endl;

    typedef itk::RescaleIntensityImageFilter< TInputImage, TInputImage > RescaleFilterType;
    typename RescaleFilterType::Pointer rescaler = RescaleFilterType::New();
    rescaler->SetInput( Image );
    rescaler->SetOutputMinimum( 0 );
    rescaler->SetOutputMaximum( 1 );
    rescaler->Update();

    printf("\n thresholding of difference and rescaling time secs : %e\n", timer.timePassed());
    return rescaler->GetOutput();
}

template < class TInputImage >
typename TInputImage::Pointer DifferenceOfGaussian< TInputImage >::GaussianConvolution( typename TInputImage::Pointer Image, float std )
{
    // Separable Gaussian smoothing of `Image` with standard deviation `std`.
    // It applies one zero-order RecursiveGaussianImageFilter pass per image
    // axis, chained one after another, with scale normalisation disabled.
    // ITK interprets the sigma in physical (spacing-scaled) units.
    //
    // Looping over TInputImage::ImageDimension (instead of hard-coding axes
    // 0, 1, 2) lets the template work for any image dimension. For 3D images
    // the passes are identical to the previous X -> Y -> Z chain.
    //
    // NOTE: the parameter is named `std`, which hides namespace std inside this
    // body, so no std:: names are used here.
    typedef itk::RecursiveGaussianImageFilter< TInputImage, TInputImage >  GaussianFilterType;

    typename TInputImage::Pointer smoothed = Image;
    for( unsigned int dir = 0; dir < TInputImage::ImageDimension; ++dir )
    {
        typename GaussianFilterType::Pointer filter = GaussianFilterType::New();
        filter->SetDirection( dir );
        filter->SetOrder( GaussianFilterType::ZeroOrder );
        filter->SetNormalizeAcrossScale( false );
        filter->SetSigma( std );
        filter->SetInput( smoothed );
        filter->Update();

        // Once `filter` goes out of scope, the previous intermediate is no
        // longer referenced and is released.
        smoothed = filter->GetOutput();
    }

    return smoothed;
}

/*****************************************************************************************/

template < class TInputImage >
Centers< TInputImage  >::Centers( typename TInputImage::Pointer image )
{
    // Records the voxel-grid extent and the voxel spacing of `image` along its
    // first three axes. These drive the scans and the index-to-physical scaling
    // in FindLocalMaxima / RefineCenters.
    const typename TInputImage::SizeType    extent  = image->GetLargestPossibleRegion().GetSize();
    const typename TInputImage::SpacingType spacing = image->GetSpacing();

    xSize = extent[0];
    ySize = extent[1];
    zSize = extent[2];

    xSpacing = spacing[0];
    ySpacing = spacing[1];
    zSpacing = spacing[2];
}


template < class TInputImage > 
vtkSmartPointer< vtkPolyData > Centers< TInputImage >::GetCenters( typename  TInputImage::Pointer image  )
{
    // Full detection: mark candidate maxima in `image` (modified in place), then
    // group them and return one point per group.
    return RefineCenters( FindLocalMaxima( image ) );
}

template < class TInputImage >
typename TInputImage::Pointer Centers< TInputImage  >::FindLocalMaxima(typename TInputImage::Pointer image )
{
    // Marks candidate nucleus centers in `image` (modified in place and returned).
    //
    // Every non-zero voxel is first reset to 0. A voxel whose original value was
    // non-zero is then set to -255 when both of these hold:
    //  * along each of x, y and z, its forward and backward differences have
    //    opposite signs, or its two neighbours on that axis are equal; and
    //  * the 3x3 finite-difference Hessian at that voxel has three negative
    //    eigenvalues (local maximum).
    // Voxels outside the volume are treated as 0.
    DataIndexType index;

    // Dimensions of a working copy padded with one zero voxel on every side, so
    // the neighbour stencils below never leave the buffer.
    const std::size_t nx    = static_cast<std::size_t>(xSize) + 2;
    const std::size_t ny    = static_cast<std::size_t>(ySize) + 2;
    const std::size_t nz    = static_cast<std::size_t>(zSize) + 2;
    const std::size_t slice = nx*ny;   // stride between consecutive z planes

    // std::vector instead of calloc/free has three benefits:
    //  * it is zero-initialised like calloc, and allocation failure throws
    //    instead of returning an unchecked NULL;
    //  * it is released on every exit path;
    //  * size_t indexing cannot overflow for large volumes, unlike the old int
    //    arithmetic.
    std::vector<DataType> buf( slice*nz, DataType(0) );

    // Copy the volume into the padded buffer and clear `image`, so only the
    // maxima marked below remain non-zero in it.
    for(std::size_t k=1; k<nz-1; k++)
        for(std::size_t j=1; j<ny-1; j++)
            for(std::size_t i=1; i<nx-1; i++)
            {
                index[0]=i-1;
                index[1]=j-1;
                index[2]=k-1;
                const DataType value = image->GetPixel( index );
                if (value == 0) continue;
                buf[i + nx*(j + k*ny)] = value;
                image->SetPixel(index, 0);
            }

    // Looking for local maxima
    cout<<"Searching for maxima....."<<endl;
    for(std::size_t k=1; k<nz-1; k++)
        for(std::size_t j=1; j<ny-1; j++)
            for(std::size_t i=1; i<nx-1; i++)
            {
                const std::size_t iC = i + nx*(j + k*ny);
                const DataType value = buf[iC];
                if(value==0) continue;

                // Face neighbours: west/east (x), south/north (y), bottom/top (z).
                const std::size_t iW = iC-1;
                const std::size_t iE = iC+1;
                const std::size_t iS = iC-nx;
                const std::size_t iN = iC+nx;
                const std::size_t iB = iC-slice;
                const std::size_t iT = iC+slice;

                // Central (I*) and forward/backward (I*_a / I*_i) differences.
                const DataType Ix   = (buf[iE] - buf[iW]);
                const DataType Ix_a = (buf[iE] - buf[iC]);
                const DataType Ix_i = (buf[iC] - buf[iW]);

                const DataType Iy   = (buf[iN] - buf[iS]);
                const DataType Iy_a = (buf[iN] - buf[iC]);
                const DataType Iy_i = (buf[iC] - buf[iS]);

                const DataType Iz   = (buf[iT] - buf[iB]);
                const DataType Iz_a = (buf[iT] - buf[iC]);
                const DataType Iz_i = (buf[iC] - buf[iB]);

                // Opposite-sign one-sided differences: strict extremum on that axis.
                // A zero central difference (equal neighbours) is also accepted.
                const DataType crossx = Ix_a*Ix_i;
                const DataType crossy = Iy_a*Iy_i;
                const DataType crossz = Iz_a*Iz_i;

                if( !( (crossx<0 || Ix==0) && (crossy<0 || Iy==0) && (crossz<0 || Iz==0) ) ) continue;

                // Hessian by central finite differences.
                const DataType Ixx = (buf[iE] - 2*buf[iC] + buf[iW]);
                const DataType Iyy = (buf[iN] - 2*buf[iC] + buf[iS]);
                const DataType Izz = (buf[iT] - 2*buf[iC] + buf[iB]);

                const DataType Ixy = ((buf[iN+1] - buf[iN-1]) - (buf[iS+1] - buf[iS-1]))*0.25;
                const DataType Ixz = ((buf[iT+1] - buf[iT-1]) - (buf[iB+1] - buf[iB-1]))*0.25;
                const DataType Iyz = ((buf[iT+nx] - buf[iT-nx]) - (buf[iB+nx] - buf[iB-nx]))*0.25;

                DataType Sdata[9] = {
                    Ixx,   Ixy,   Ixz,
                    Ixy,   Iyy,   Iyz,
                    Ixz,   Iyz,   Izz
                };

                vnl_matrix<DataType> S(Sdata, 3, 3);
                vnl_symmetric_eigensystem<DataType> eig(S);

                const DataType lambda1 = eig.get_eigenvalue(2);
                const DataType lambda2 = eig.get_eigenvalue(1);
                const DataType lambda3 = eig.get_eigenvalue(0);

                // All eigenvalues negative -> negative-definite Hessian -> maximum.
                if( lambda1<0 && lambda2<0 && lambda3<0 )
                {
                    index[0]=i-1;
                    index[1]=j-1;
                    index[2]=k-1;
                    image->SetPixel(index, -255);
                }
            }

    cout<<"\nFound centers"<<endl;
    return image;
}

template < class TInputImage >
vtkSmartPointer< vtkPolyData > Centers< TInputImage >::RefineCenters(typename TInputImage::Pointer image)
{
    // Turns the marked voxels (every non-zero pixel of `image`) into one center
    // per group of touching voxels.
    //
    // Pass 1 relabels `image` in place, in raster order. A marked voxel takes the
    //        label of a positive (already labelled) voxel in its 3x3x3 window;
    //        if there is none, it opens a new label 1, 2, ...
    // Pass 2 collects the voxel indices of each label into the `regions` member.
    // Pass 3 emits, per label, the mean voxel index multiplied by the voxel
    //        spacing. The image origin and direction are not applied.
    //
    // Returns a vtkPolyData holding those centers as points (no cells are created).
    // On return, regions[r-1] holds the voxels of label r.
    //
    // NOTE(review): pass 1 never merges labels. A group whose branches only meet
    // late in raster order can therefore yield several centers. It also assumes
    // the non-zero input voxels are negative markers, as FindLocalMaxima produces.
    DataIndexType index, index1;

    vtkSmartPointer< vtkDoubleArray > pcoords = vtkSmartPointer< vtkDoubleArray >::New();
    pcoords->SetNumberOfComponents(3);

    vtkSmartPointer< vtkPoints > points = vtkSmartPointer< vtkPoints >::New();
    points->SetData(pcoords);

    vtkSmartPointer< vtkPolyData > polyOutput = vtkSmartPointer< vtkPolyData >::New();
    polyOutput->SetPoints( points );

    // Pass 1: label the marked voxels.
    int count = 0;
    for(int k=0; k<zSize; k++)
        for(int j=0; j<ySize; j++)
            for(int i=0; i<xSize; i++)
            {
                index[0]=i;
                index[1]=j;
                index[2]=k;
                if( image->GetPixel(index) == 0 ) continue;

                // Clip the 3x3x3 window to the volume. The lower and upper bounds
                // are clipped independently. The previous if/else-if form left the
                // upper bound open on an axis of size 1 and read past the volume.
                const int kLo = ( k > 0 )       ? k-1 : k;
                const int kHi = ( k < zSize-1 ) ? k+1 : k;
                const int jLo = ( j > 0 )       ? j-1 : j;
                const int jHi = ( j < ySize-1 ) ? j+1 : j;
                const int iLo = ( i > 0 )       ? i-1 : i;
                const int iHi = ( i < xSize-1 ) ? i+1 : i;

                // The whole window is scanned, so the last positive label met in
                // raster order is the one kept.
                bool found = false;
                for(int k2=kLo; k2<=kHi; k2++)
                    for(int j2=jLo; j2<=jHi; j2++)
                        for(int i2=iLo; i2<=iHi; i2++)
                        {
                            index1[0]=i2;
                            index1[1]=j2;
                            index1[2]=k2;

                            const DataType value1 = image->GetPixel(index1);
                            if( value1 > 0 )
                            {
                                image->SetPixel(index, value1);
                                found = true;
                            }
                        }

                if( !found )
                {
                    count += 1;
                    image->SetPixel(index, count);
                }
            }
    // One modification stamp for the whole pass (SetPixel does not update it).
    image->Modified();

    cout<<"Found "<<count<<" nuclei"<<endl;

    // Pass 2: one bucket per label. Bucket 0 is a placeholder, removed at the end.
    // Clearing first stops a second call on the same object from mixing in the
    // voxels collected by an earlier call.
    regions.clear();
    for( int r=0; r<=count; r++)
    {
        Neighb tmp;
        regions.push_back(tmp);
    }

    Position pos;
    for(int k=0; k<zSize; k++)
        for(int j=0; j<ySize; j++)
            for(int i=0; i<xSize; i++)
            {
                index[0]=i;
                index[1]=j;
                index[2]=k;
                const DataType value = image->GetPixel(index);
                if( value > 0 )
                {
                    pos.x=i;
                    pos.y=j;
                    pos.z=k;
                    regions[ static_cast<std::size_t>(value) ].connected.push_back(pos);
                }
            }

    // Pass 3: centroid of each label, scaled by the voxel spacing. A single-voxel
    // label yields exactly that voxel's position, so no special case is needed.
    for( int r=1; r<=count; r++)
    {
        const int dim = regions[r].connected.size();
        DataType mean_x=0;
        DataType mean_y=0;
        DataType mean_z=0;
        for( int p=0; p<dim; p++ )
        {
            mean_x+=regions[r].connected[p].x;
            mean_y+=regions[r].connected[p].y;
            mean_z+=regions[r].connected[p].z;
        }
        mean_x=mean_x/dim;
        mean_y=mean_y/dim;
        mean_z=mean_z/dim;

        double center_temp[3];
        center_temp[0]=mean_x*xSpacing;
        center_temp[1]=mean_y*ySpacing;
        center_temp[2]=mean_z*zSpacing;
        pcoords->InsertNextTuple( center_temp );
    }

    // Drop the placeholder so that regions[r-1] corresponds to label r.
    regions.erase(regions.begin());

    return polyOutput;
}