!==============================================================================!
!                                                                              |
!              ||===\\                                                         | 
!              ||    \\                                                        |
!              ||     ||   //==\\   ||  ||   //==||  ||/==\\                   |
!              ||     ||  ||    ||  ||  ||  ||   ||  ||    ||                  |
!              ||    //   ||    ||  ||  ||  ||   ||  ||                        |
!              ||===//     \\==//    \\==\\  \\==\\  ||                        |
!                                                                              |
!==============================================================================!
!==============================================================================!
!                                                                              |
!              WSMP_SETUP      Oct. 2008                                       |
!              WSMP_CLEANUP    Oct. 2008                                       |
!                                                                              |
!==============================================================================!
!==============================================================================!

subroutine wsmp_setup (ndof,vo,wsmp,osolve,params,threadinfo,istep,iter)
!------------------------------------------------------------------------------!
! Builds the sparsity pattern of the global matrix and the column-distributed
! index arrays (ia,ja) handed to the solver, and allocates the value and
! right-hand-side storage held in wsmp.
!
! Arguments
!   ndof        number of degrees of freedom per node; only 1 and 3 are
!               accepted, any other value stops the run
!   vo          read: nnode, leaf(:), face(:), ftr(:)
!   wsmp        filled by this routine (see field list below)
!   osolve      read: nleaves, icon(:,:), nface, iface(:,:)
!   params      read: mpe, visualise_matrix
!   threadinfo  forwarded to heap for memory book-keeping
!   istep,iter  only forwarded to visualise_matrix
!
! Fields of wsmp set here
!   wsmp%n           size of the system/matrix
!   wsmp%nz          total number of non-zero terms in the global matrix
!   wsmp%n_iproc     number of columns owned by this rank (never negative)
!   wsmp%n_iproc_st  first column owned by this rank, counted from 1
!   wsmp%n_iproc_end last column owned by this rank, at most n
!   wsmp%nz_loc      number of non-zero terms in the columns of this rank
!   wsmp%iproc_col   logical mask of size n, true for the columns of this rank
!   wsmp%ia          integer array of size n_iproc+1 (see wsmp manual)
!   wsmp%ja          integer array of size nz_loc
!   wsmp%avals       double precision array of size nz_loc (zeroed here)
!   wsmp%b           double precision array (ldb,nrhs) holding the rhs (zeroed)
!   wsmp%rmisc       allocated only when ndof==1
!   wsmp%perm,invp   allocated only when ndof==3
!
! wsmp%irn, wsmp%jcn and wsmp%idg are scratch: they are allocated and released
! inside this routine (irn/jcn feed visualise_matrix; idg is only zeroed).
!------------------------------------------------------------------------------!
use definitions
use threads
implicit none

integer ndof
type (void) vo
type (topology),dimension(:),allocatable::tpl
type (wsmp_data) wsmp
type (octreesolve) osolve
type (parameters) params
type (thread) threadinfo
integer istep,iter

! upper bound used for the first index of osolve%iface in the face loop
integer, parameter :: nnode_face=9

integer kk,nn
integer ie,k1,k2,j,i2,i1,idof1,idof2,k,err,i,iloc
integer nheight0,ntpl_heap
integer iproc,nproc,ierr
character(len=72) :: shift

INCLUDE 'mpif.h'

call mpi_comm_size (mpi_comm_world,nproc,ierr)
call mpi_comm_rank (mpi_comm_world,iproc,ierr)

shift=' '

!==============================================================================!
!=====[find matrix size]=======================================================!
!==============================================================================!

wsmp%n = vo%nnode * ndof

!==============================================================================!
!=====[allocate memory for topology]===========================================!
!==============================================================================!
! nheight0 is the initial capacity of every column list; columns that need
! more room are grown on demand in add_nonzero.

select case (ndof)
case (3)
   nheight0=17*ndof
case (1)
   nheight0=27*ndof
case default
   stop 'pb in wsmp_Setup'
end select

allocate (tpl(wsmp%n),stat=err)
if (err.ne.0) call stop_run ('Error alloc tpl in wsmp_setup$')
tpl%nheight=0
tpl%nheightmax=nheight0

do i=1,wsmp%n
   allocate (tpl(i)%icol(nheight0),stat=err)
   if (err.ne.0) call stop_run ('Error alloc tpl(i)%icol in wsmp_setup$')
   tpl(i)%icol(:)=0
enddo

! The amount registered with heap is kept in a scalar so that exactly the
! same amount is released at the end of the routine. Later growth of
! individual columns is not reported to heap.
ntpl_heap=wsmp%n*nheight0
call heap (threadinfo,'tpl(:)%icol(:)','wsmp_setup',ntpl_heap,'int',+1)

!==============================================================================!
!=====[find_connectivity_dimension]============================================!
!==============================================================================!
! the list of non-zero terms is computed for every column
! i1 is a column index
! i2 is a row index
! for symmetric systems, we only build and store the lower triangular part
! of the matrix, i.e. i2.ge.i1 (the filter is applied in add_nonzero whenever
! ndof.gt.1)

! contributions of the leaves flagged with vo%leaf==0
do ie=1,osolve%nleaves
   if (vo%leaf(ie).ne.0) cycle
   do k1=1,params%mpe
      do idof1=1,ndof
         i1=(vo%ftr(osolve%icon(k1,ie))-1)*ndof+idof1
         do k2=1,params%mpe
            do idof2=1,ndof
               i2=(vo%ftr(osolve%icon(k2,ie))-1)*ndof+idof2
               call add_nonzero (i1,i2)
            enddo
         enddo
      enddo
   enddo
enddo

! contributions of the faces flagged with vo%face==0
do ie=1,osolve%nface
   if (vo%face(ie).ne.0) cycle
   do k1=1,nnode_face
      do idof1=1,ndof
         i1=(vo%ftr(osolve%iface(k1,ie))-1)*ndof+idof1
         do k2=1,nnode_face
            do idof2=1,ndof
               i2=(vo%ftr(osolve%iface(k2,ie))-1)*ndof+idof2
               call add_nonzero (i1,i2)
            enddo
         enddo
      enddo
   enddo
enddo

!print *,iproc,'min/max tpl%nheight',minval(tpl(:)%nheight),maxval(tpl(:)%nheight)
!print *,iproc,'avrg tpl%nheight',sum(tpl(:)%nheight)/dble(wsmp%n)

!==============================================================================!
!=====[computing total number of non-zero terms]===============================!
!==============================================================================!
! the total number of non-zero terms is the sum of the height of each column

wsmp%nz=sum(tpl(1:wsmp%n)%nheight)

if (iproc.eq.0) write(*,'(a,i9)') shift//'n =',wsmp%n
if (iproc.eq.0) write(*,'(a,i9)') shift//'nz=',wsmp%nz

!==============================================================================!
!=====[sorting icol's]=========================================================!
!==============================================================================!
! wsmp requires the indices of the row contained in each column to be sorted

do i=1,wsmp%n
   call iqsort_s (tpl(i)%icol,tpl(i)%nheight)
end do

!==============================================================================!
!=====[allocating memory]======================================================!
!==============================================================================!

allocate(wsmp%irn(wsmp%nz))
call heap (threadinfo,'wsmp%irn','wsmp_setup',size(wsmp%irn),'int',+1)
allocate(wsmp%jcn(wsmp%nz))
call heap (threadinfo,'wsmp%jcn','wsmp_setup',size(wsmp%jcn),'int',+1)
allocate(wsmp%idg(wsmp%n+1))
call heap (threadinfo,'wsmp%idg','wsmp_setup',size(wsmp%idg),'int',+1)
allocate(wsmp%iproc_col(wsmp%n))
call heap (threadinfo,'wsmp%iproc_col','wsmp_setup',size(wsmp%iproc_col),'bool',+1)

wsmp%irn=0
wsmp%jcn=0
wsmp%idg=0

!==============================================================================!
!=====[find_connectivity]======================================================!
!==============================================================================!
! the coordinates (irn,jcn) of the non zero terms in the global matrix are
! computed, column after column.

kk=0
do i=1,wsmp%n
   do j=1,tpl(i)%nheight
      kk=kk+1
      if (kk.gt.wsmp%nz) call stop_run ('error in find_connectivity$')
      wsmp%irn(kk)=tpl(i)%icol(j)
      wsmp%jcn(kk)=i
   enddo
enddo

!==============================================================================!
!=====[find_processors]========================================================!
!==============================================================================!
! columns are dealt out in contiguous blocks of nn columns per rank. When
! there are more ranks than blocks the trailing ranks own no column:
! n_iproc_st is then larger than n_iproc_end and n_iproc is clamped to zero so
! that the array sizes derived from it stay valid.

nn  =(wsmp%n-1)/nproc+1
wsmp%n_iproc_st     = nn*iproc+1
wsmp%n_iproc_end    = min(wsmp%n_iproc_st+nn-1,wsmp%n)
wsmp%n_iproc        = max(wsmp%n_iproc_end - wsmp%n_iproc_st +1,0)

wsmp%iproc_col=.false.
wsmp%iproc_col(wsmp%n_iproc_st:wsmp%n_iproc_end)=.true.

wsmp%nz_loc=sum(tpl(:)%nheight,mask=wsmp%iproc_col)

!==============================================================================!
!=====[allocating memory]======================================================!
!==============================================================================!

allocate(wsmp%ja(wsmp%nz_loc))
call heap (threadinfo,'wsmp%ja','wsmp_setup',size(wsmp%ja),'int',+1)
allocate(wsmp%ia(wsmp%n_iproc+1))
call heap (threadinfo,'wsmp%ia','wsmp_setup',size(wsmp%ia),'int',+1)

!==============================================================================!
!=====[build ia,ja]============================================================!
!==============================================================================!
! iloc is the position in ja of the first entry of the current local column.
! ia(k) is that position for the k-th local column and the closing entry
! ia(n_iproc+1) is one past the last entry. A rank without columns gets
! ia(1)=1 and an empty ja.

iloc=1
k=0
do i=wsmp%n_iproc_st,wsmp%n_iproc_end
   k=k+1
   wsmp%ia(k)=iloc
   do j=1,tpl(i)%nheight
      wsmp%ja(iloc+j-1)=tpl(i)%icol(j)
   enddo
   iloc=iloc+tpl(i)%nheight
enddo
wsmp%ia(k+1)=iloc

!if (ndof==1) then
!write(*,*) iproc,'n , nz',wsmp%n,wsmp%nz
!write(*,*) iproc,'col st/end ',wsmp%n_iproc_st,wsmp%n_iproc_end
!write(*,*) iproc,'n_iproc',wsmp%n_iproc
!write(*,*) iproc,'nz_loc',wsmp%nz_loc
!write(*,*) iproc,'min/max ia',minval(wsmp%ia),maxval(wsmp%ia)
!write(*,*) iproc,'min/max ja',minval(wsmp%ja),maxval(wsmp%ja)
!end if

!==============================================================================!
!=====[outputs visual representation of matrix]================================!
!==============================================================================!

call visualise_matrix (params%visualise_matrix,wsmp%nz,wsmp%irn,wsmp%jcn,wsmp%n,istep,iter,ndof)

!==============================================================================!
!=====[allocate memory for matrix values]======================================!
!==============================================================================!

allocate(wsmp%avals(wsmp%nz_loc))
call heap (threadinfo,'wsmp%avals','wsmp_setup',size(wsmp%avals),'dp',+1)
wsmp%avals=0.d0

wsmp%nrhs = 1
wsmp%ldb  = wsmp%n_iproc
wsmp%naux = 0

wsmp%mrp=0

allocate(wsmp%b    (wsmp%ldb,wsmp%nrhs))
call heap (threadinfo,'wsmp%b','wsmp_setup',size(wsmp%b),'dp',+1)
wsmp%b=0.d0

if (ndof==1) then
   allocate(wsmp%rmisc(wsmp%ldb,wsmp%nrhs))
   call heap (threadinfo,'wsmp%rmisc','wsmp_setup',size(wsmp%rmisc),'dp',+1)
   wsmp%rmisc=0.d0
end if

if (ndof==3) then
   allocate(wsmp%perm (wsmp%n))
   call heap (threadinfo,'wsmp%perm','wsmp_setup',size(wsmp%perm),'int',+1)
   allocate(wsmp%invp (wsmp%n))
   call heap (threadinfo,'wsmp%invp','wsmp_setup',size(wsmp%invp),'int',+1)
   wsmp%perm=0
   wsmp%invp=0
end if

!==============================================================================!
!=====[deallocate memory]======================================================!
!==============================================================================!

call heap (threadinfo,'wsmp%irn','wsmp_setup',size(wsmp%irn),'int',-1)
deallocate (wsmp%irn)
call heap (threadinfo,'wsmp%jcn','wsmp_setup',size(wsmp%jcn),'int',-1)
deallocate (wsmp%jcn)
call heap (threadinfo,'wsmp%idg','wsmp_setup',size(wsmp%idg),'int',-1)
deallocate (wsmp%idg)

! release exactly what was registered for the column lists
call heap (threadinfo,'tpl(:)%icol(:)','wsmp_setup',ntpl_heap,'int',-1)

do i=1,wsmp%n
   deallocate (tpl(i)%icol)
enddo
deallocate (tpl)

contains

   !---------------------------------------------------------------------------!
   ! Records row irow in the list of column jcol unless it is already there.
   ! When ndof.gt.1 entries with jcol.gt.irow are skipped, so that only rows
   ! with irow.ge.jcol are stored. A full list is enlarged by ndof slots; the
   ! existing entries are preserved through a temporary copy sized to the list.
   !---------------------------------------------------------------------------!
   subroutine add_nonzero (jcol,irow)
   integer, intent(in) :: jcol,irow
   integer, dimension(:), allocatable :: saved
   integer :: nh,istat

   if (jcol.gt.irow .and. ndof.gt.1) return

   nh=tpl(jcol)%nheight
   if (any(tpl(jcol)%icol(1:nh).eq.irow)) return

   if (nh.eq.tpl(jcol)%nheightmax) then
      allocate (saved(nh),stat=istat)
      if (istat.ne.0) call stop_run ('Error alloc tpl(i1)%icol$')
      saved(:)=tpl(jcol)%icol(1:nh)
      deallocate (tpl(jcol)%icol)
      tpl(jcol)%nheightmax=tpl(jcol)%nheightmax+ndof
      allocate (tpl(jcol)%icol(tpl(jcol)%nheightmax),stat=istat)
      ! the status is tested before the new storage is touched
      if (istat.ne.0) call stop_run ('Error alloc tpl(i1)%icol$')
      tpl(jcol)%icol(:)=0
      tpl(jcol)%icol(1:nh)=saved(:)
      deallocate (saved)
   endif

   tpl(jcol)%nheight=nh+1
   tpl(jcol)%icol(nh+1)=irow

   end subroutine add_nonzero

end subroutine wsmp_setup


!==============================================================================!
!==============================================================================!
!                                                                              |
!              WSMP_CLEANUP    Oct. 2008                                       |
!                                                                              |
!==============================================================================!
!==============================================================================!

subroutine wsmp_cleanup (wsmp,threadinfo)
!------------------------------------------------------------------------------!
! Releases the arrays held in wsmp. Every release is preceded by a call to
! heap with flag -1 and the current size of the array.
!
!   wsmp        structure whose array components are deallocated
!   threadinfo  forwarded to heap
!
! ia, ja, iproc_col, avals and b are released unconditionally; perm, invp and
! rmisc are released only if they are currently allocated.
!------------------------------------------------------------------------------!
use definitions
use threads
implicit none

type (wsmp_data) :: wsmp
type (thread)    :: threadinfo

!-----[arrays released unconditionally]-----------------------------------------

call heap (threadinfo,'wsmp%ia','wsmp_cleanup',size(wsmp%ia),'int',-1)
deallocate (wsmp%ia)

call heap (threadinfo,'wsmp%ja','wsmp_cleanup',size(wsmp%ja),'int',-1)
deallocate (wsmp%ja)

call heap (threadinfo,'wsmp%iproc_col','wsmp_cleanup',size(wsmp%iproc_col),'bool',-1)
deallocate (wsmp%iproc_col)

call heap (threadinfo,'wsmp%avals','wsmp_cleanup',size(wsmp%avals),'dp',-1)
deallocate (wsmp%avals)

call heap (threadinfo,'wsmp%b','wsmp_cleanup',size(wsmp%b),'dp',-1)
deallocate (wsmp%b)

!-----[arrays released only when allocated]-------------------------------------

if (allocated(wsmp%perm)) then
   call heap (threadinfo,'wsmp%perm','wsmp_cleanup',size(wsmp%perm),'int',-1)
   deallocate (wsmp%perm)
end if

if (allocated(wsmp%invp)) then
   call heap (threadinfo,'wsmp%invp','wsmp_cleanup',size(wsmp%invp),'int',-1)
   deallocate (wsmp%invp)
end if

if (allocated(wsmp%rmisc)) then
   call heap (threadinfo,'wsmp%rmisc','wsmp_cleanup',size(wsmp%rmisc),'dp',-1)
   deallocate (wsmp%rmisc)
end if

!call pwsmp_clear()

end subroutine wsmp_cleanup

!==============================================================================!
!==============================================================================!