wann_main Subroutine

public subroutine wann_main()

Uses

  • wann_main uses the modules: w90_utility, w90_constants, w90_hamiltonian, w90_io, w90_sitesym, w90_parameters
  • Module dependencies of the used modules:
      - w90_utility → w90_constants
      - w90_hamiltonian → w90_constants, w90_comms
      - w90_io → w90_constants
      - w90_sitesym → w90_constants, w90_io
      - w90_parameters → w90_constants, w90_io
      - w90_comms → w90_constants, w90_io

Calculate the Unitary Rotations to give Maximally Localised Wannier Functions

Arguments

None

Calls

  • wann_main calls: wann_omega, hamiltonian_get_hr, comms_array_split, io_error, wann_spread_copy, hamiltonian_setup, wann_check_unitarity, io_wallclocktime, wann_domega, wann_phases, io_file_unit
  • Calls made by the called procedures:
      - wann_omega → comms_allreduce (generic interface)
      - comms_allreduce → comms_allreduce_real, comms_allreduce_cmplx
      - hamiltonian_setup → hamiltonian_wigner_seitz
      - hamiltonian_wigner_seitz → io_error
      - wann_check_unitarity → io_error
      - wann_domega → sitesym_symmetrize_gradient
      - wann_phases → utility_inv3

Called by

  • wann_main is called by:
      - program wannier
      - subroutine wannier_run

Contents

Source Code


Source Code

  subroutine wann_main
    !==================================================================!
    !                                                                  !
    !! Calculate the Unitary Rotations to give Maximally Localised Wannier Functions
    !                                                                  !
    !===================================================================
    use w90_constants, only: dp, cmplx_1, cmplx_0, eps2, eps5, eps8
    use w90_io, only: stdout, io_error, io_wallclocktime, io_stopwatch &
      , io_file_unit
    use w90_parameters, only: num_wann, num_cg_steps, num_iter, nnlist, &
      nntot, wbtot, u_matrix, m_matrix, num_kpts, iprint, num_print_cycles, &
      num_dump_cycles, omega_invariant, param_write_chkpt, length_unit, &
      lenconfac, proj_site, real_lattice, write_r2mn, guiding_centres, &
      num_guide_cycles, num_no_guide_iter, timing_level, trial_step, precond, spinors, &
      fixed_step, lfixstep, write_proj, have_disentangled, conv_tol, num_proj, &
      conv_window, conv_noise_amp, conv_noise_num, wannier_centres, write_xyz, &
      wannier_spreads, omega_total, omega_tilde, optimisation, write_vdw_data, &
      write_hr_diag, kpt_latt, bk, ccentres_cart, slwf_num, selective_loc, &
      slwf_constrain, slwf_lambda
    use w90_utility, only: utility_frac_to_cart, utility_zgemm
    use w90_parameters, only: lsitesymmetry                !RS:
    use w90_sitesym, only: sitesym_symmetrize_gradient  !RS:

    !ivo
    use w90_hamiltonian, only: hamiltonian_setup, hamiltonian_get_hr, ham_r, &
      rpt_origin, irvec, nrpts, ndegen

    implicit none

    type(localisation_vars) :: old_spread
    type(localisation_vars) :: wann_spread
    type(localisation_vars) :: trial_spread

    ! guiding centres
    real(kind=dp), allocatable :: rguide(:, :)
    integer :: irguide

    ! local arrays used and passed in subroutines
    complex(kind=dp), allocatable :: csheet(:, :, :)
    complex(kind=dp), allocatable :: cdodq(:, :, :)
    complex(kind=dp), allocatable :: cdodq_r(:, :, :)
    complex(kind=dp), allocatable :: k_to_r(:, :)
    complex(kind=dp), allocatable :: cdodq_precond(:, :, :)
    complex(kind=dp), allocatable :: cdodq_precond_loc(:, :, :)
    real(kind=dp), allocatable :: sheet(:, :, :)
    real(kind=dp), allocatable :: rave(:, :), r2ave(:), rave2(:)
    real(kind=dp), dimension(3) :: rvec_cart

    !local arrays not passed into subroutines
    complex(kind=dp), allocatable  :: cwschur1(:), cwschur2(:)
    complex(kind=dp), allocatable  :: cwschur3(:), cwschur4(:)
    complex(kind=dp), allocatable  :: cdq(:, :, :)!,cdqkeep(:,:,:)
    ! cdqkeep is replaced by cdqkeep_loc
    complex(kind=dp), allocatable  :: cdqkeep_loc(:, :, :)
    complex(kind=dp), allocatable  :: cz(:, :)
    complex(kind=dp), allocatable  :: cmtmp(:, :), tmp_cdq(:, :)
    ! complex(kind=dp), allocatable  :: m0(:,:,:,:),u0(:,:,:)
    ! m0 and u0 are replaced by m0_loc and u0_loc
    complex(kind=dp), allocatable  :: m0_loc(:, :, :, :), u0_loc(:, :, :)
    complex(kind=dp), allocatable  :: cwork(:)
    real(kind=dp), allocatable  :: evals(:)
    real(kind=dp), allocatable  :: rwork(:)

    real(kind=dp) :: doda0
    real(kind=dp) :: falphamin, alphamin
    real(kind=dp) :: gcfac, gcnorm1, gcnorm0
    integer       :: i, n, iter, ind, ierr, iw, ncg, info, nkp, nkp_loc, nn
    logical       :: lprint, ldump, lquad
    real(kind=dp), allocatable :: history(:)
    real(kind=dp)              :: save_spread
    logical                    :: lconverged, lrandom, lfirst
    integer                    :: conv_count, noise_count, page_unit
    complex(kind=dp) :: fac, rdotk
    real(kind=dp) :: alpha_precond
    integer :: irpt, loop_kpt
    logical :: cconverged
    real(kind=dp) :: glpar, cvalue_new
    real(kind=dp), allocatable :: rnr0n2(:)

    if (timing_level > 0 .and. on_root) call io_stopwatch('wann: main', 1)

    first_pass = .true.

    ! Allocate stuff

    allocate (history(conv_window), stat=ierr)
    if (ierr /= 0) call io_error('Error allocating history in wann_main')

    ! module data
!    if(optimisation>0) then
!       allocate(  m0 (num_wann, num_wann, nntot, num_kpts),stat=ierr)
!    end if
!    if (ierr/=0) call io_error('Error in allocating m0 in wann_main')
!    allocate(  u0 (num_wann, num_wann, num_kpts),stat=ierr)
!    if (ierr/=0) call io_error('Error in allocating u0 in wann_main')
    allocate (rnkb(num_wann, nntot, num_kpts), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating rnkb in wann_main')
    allocate (ln_tmp(num_wann, nntot, num_kpts), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating ln_tmp in wann_main')
    if (selective_loc) then
      allocate (rnr0n2(slwf_num), stat=ierr)
      if (ierr /= 0) call io_error('Error in allocating rnr0n2 in wann_main')
    end if

    rnkb = 0.0_dp

    ! sub vars passed into other subs
    allocate (csheet(num_wann, nntot, num_kpts), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating csheet in wann_main')
    allocate (cdodq(num_wann, num_wann, num_kpts), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating cdodq in wann_main')
    allocate (sheet(num_wann, nntot, num_kpts), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating sheet in wann_main')
    allocate (rave(3, num_wann), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating rave in wann_main')
    allocate (r2ave(num_wann), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating r2ave in wann_main')
    allocate (rave2(num_wann), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating rave2 in wann_main')
    allocate (rguide(3, num_wann))
    if (ierr /= 0) call io_error('Error in allocating rguide in wann_main')

    if (precond) then
      call hamiltonian_setup()
      allocate (cdodq_r(num_wann, num_wann, nrpts), stat=ierr)
      if (ierr /= 0) call io_error('Error in allocating cdodq_r in wann_main')
      allocate (cdodq_precond(num_wann, num_wann, num_kpts), stat=ierr)
      if (ierr /= 0) call io_error('Error in allocating cdodq_precond in wann_main')

      ! this method of computing the preconditioning is much more efficient, but requires more RAM
      if (optimisation >= 3) then
        allocate (k_to_r(num_kpts, nrpts), stat=ierr)
        if (ierr /= 0) call io_error('Error in allocating k_to_r in wann_main')

        do irpt = 1, nrpts
          do loop_kpt = 1, num_kpts
            rdotk = twopi*dot_product(kpt_latt(:, loop_kpt), real(irvec(:, irpt), dp))
            k_to_r(loop_kpt, irpt) = exp(-cmplx_i*rdotk)
          enddo
        enddo
      end if
    end if

    csheet = cmplx_1; cdodq = cmplx_0
    sheet = 0.0_dp; rave = 0.0_dp; r2ave = 0.0_dp; rave2 = 0.0_dp; rguide = 0.0_dp

    ! sub vars not passed into other subs
    allocate (cwschur1(num_wann), cwschur2(10*num_wann), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating cwshur1 in wann_main')
    allocate (cwschur3(num_wann), cwschur4(num_wann), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating cwshur3 in wann_main')
    allocate (cdq(num_wann, num_wann, num_kpts), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating cdq in wann_main')

    ! for MPI
    if (allocated(counts)) deallocate (counts)
    allocate (counts(0:num_nodes - 1), stat=ierr)
    if (ierr /= 0) then
      call io_error('Error in allocating counts in wann_main')
    end if

    if (allocated(displs)) deallocate (displs)
    allocate (displs(0:num_nodes - 1), stat=ierr)
    if (ierr /= 0) then
      call io_error('Error in allocating displs in wann_main')
    end if
    call comms_array_split(num_kpts, counts, displs)
    allocate (rnkb_loc(num_wann, nntot, max(1, counts(my_node_id))), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating rnkb_loc in wann_main')
    allocate (ln_tmp_loc(num_wann, nntot, max(1, counts(my_node_id))), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating ln_tmp_loc in wann_main')
    allocate (u_matrix_loc(num_wann, num_wann, max(1, counts(my_node_id))), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating u_matrix_loc in wann_main')
    allocate (m_matrix_loc(num_wann, num_wann, nntot, max(1, counts(my_node_id))), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating m_matrix_loc in wann_main')
!    allocate( m_matrix_1b  (num_wann, num_wann, num_kpts),stat=ierr )
!    if (ierr/=0) call io_error('Error in allocating m_matrix_1b in wann_main')
!    allocate( m_matrix_1b_loc  (num_wann, num_wann, max(1,counts(my_node_id))),stat=ierr )
!    if (ierr/=0) call io_error('Error in allocating m_matrix_1b_loc in wann_main')
    if (precond) then
      allocate (cdodq_precond_loc(num_wann, num_wann, max(1, counts(my_node_id))), stat=ierr)
      if (ierr /= 0) call io_error('Error in allocating cdodq_precond_loc in wann_main')
    end if
    ! initialize local u and m matrices with global ones
    do nkp_loc = 1, counts(my_node_id)
      nkp = nkp_loc + displs(my_node_id)
!       m_matrix_loc (:,:,:, nkp_loc) = &
!           m_matrix (:,:,:, nkp)
      u_matrix_loc(:, :, nkp_loc) = &
        u_matrix(:, :, nkp)
    end do
    call comms_scatterv(m_matrix_loc, num_wann*num_wann*nntot*counts(my_node_id), &
                        m_matrix, num_wann*num_wann*nntot*counts, num_wann*num_wann*nntot*displs)

    allocate (cdq_loc(num_wann, num_wann, max(1, counts(my_node_id))), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating cdq_loc in wann_main')
    allocate (cdodq_loc(num_wann, num_wann, max(1, counts(my_node_id))), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating cdodq_loc in wann_main')
    allocate (cdqkeep_loc(num_wann, num_wann, max(1, counts(my_node_id))), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating cdqkeep_loc in wann_main')
    if (optimisation > 0) then
      allocate (m0_loc(num_wann, num_wann, nntot, max(1, counts(my_node_id))), stat=ierr)
    end if
    if (ierr /= 0) call io_error('Error in allocating m0_loc in wann_main')
    allocate (u0_loc(num_wann, num_wann, max(1, counts(my_node_id))), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating u0_loc in wann_main')

    allocate (cz(num_wann, num_wann), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating cz in wann_main')
    allocate (cmtmp(num_wann, num_wann), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating cmtmp in wann_main')
    allocate (tmp_cdq(num_wann, num_wann), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating tmp_cdq in wann_main')
    allocate (evals(num_wann), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating evals in wann_main')
    allocate (cwork(4*num_wann), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating cwork in wann_main')
    allocate (rwork(3*num_wann - 2), stat=ierr)
    if (ierr /= 0) call io_error('Error in allocating rwork in wann_main')

    cwschur1 = cmplx_0; cwschur2 = cmplx_0; cwschur3 = cmplx_0; cwschur4 = cmplx_0
    cdq = cmplx_0; cz = cmplx_0; cmtmp = cmplx_0; cdqkeep_loc = cmplx_0; cdq_loc = cmplx_0; ! buff=cmplx_0;

    gcnorm1 = 0.0_dp; gcnorm0 = 0.0_dp

    ! initialise rguide to projection centres (Cartesians in units of Ang)
    if (guiding_centres) then
      do n = 1, num_proj
        call utility_frac_to_cart(proj_site(:, n), rguide(:, n), real_lattice)
      enddo
!       if(spinors) then ! not needed with new changes to spinor proj 2013 JRY
!          do n=1,num_proj
!             call utility_frac_to_cart(proj_site(:,n),rguide(:,n+num_proj),real_lattice)
!          enddo
!       end if
    end if

    if (on_root) then
      write (stdout, *)
      write (stdout, '(1x,a)') '*------------------------------- WANNIERISE ---------------------------------*'
      write (stdout, '(1x,a)') '+--------------------------------------------------------------------+<-- CONV'
      if (lenconfac .eq. 1.0_dp) then
        write (stdout, '(1x,a)') '| Iter  Delta Spread     RMS Gradient      Spread (Ang^2)      Time  |<-- CONV'
      else
        write (stdout, '(1x,a)') '| Iter  Delta Spread     RMS Gradient      Spread (Bohr^2)     Time  |<-- CONV'
      endif
      write (stdout, '(1x,a)') '+--------------------------------------------------------------------+<-- CONV'
      write (stdout, *)
    endif

    irguide = 0
    if (guiding_centres .and. (num_no_guide_iter .le. 0)) then
      call wann_phases(csheet, sheet, rguide, irguide)
      irguide = 1
    endif

    ! constrained centres part
    lambda_loc = 0.0_dp
    if (selective_loc .and. slwf_constrain) then
      lambda_loc = slwf_lambda
    end if

    ! calculate initial centers and spread
    call wann_omega(csheet, sheet, rave, r2ave, rave2, wann_spread)

    ! public variables
    if (.not. selective_loc) then
      omega_total = wann_spread%om_tot
      omega_invariant = wann_spread%om_i
      omega_tilde = wann_spread%om_d + wann_spread%om_od
    else
      omega_total = wann_spread%om_tot
      ! omega_invariant = wann_spread%om_iod
      ! omega_tilde = wann_spread%om_d + wann_spread%om_nu
    end if

    ! public arrays of Wannier centres and spreads
    wannier_centres = rave
    wannier_spreads = r2ave - rave2

    if (lfixstep) lquad = .false.
    ncg = 0
    iter = 0
    old_spread%om_tot = 0.0_dp

    ! print initial state
    if (on_root) then
      write (stdout, '(1x,a78)') repeat('-', 78)
      write (stdout, '(1x,a)') 'Initial State'
      do iw = 1, num_wann
        write (stdout, 1000) iw, (rave(ind, iw)*lenconfac, ind=1, 3), &
          (r2ave(iw) - rave2(iw))*lenconfac**2
      end do
      write (stdout, 1001) (sum(rave(ind, :))*lenconfac, ind=1, 3), (sum(r2ave) - sum(rave2))*lenconfac**2
      write (stdout, *)
      if (selective_loc .and. slwf_constrain) then
        write (stdout, '(1x,i6,2x,E12.3,2x,F15.10,2x,F18.10,3x,F8.2,2x,a)') &
          iter, (wann_spread%om_tot - old_spread%om_tot)*lenconfac**2, sqrt(abs(gcnorm1))*lenconfac, &
          wann_spread%om_tot*lenconfac**2, io_wallclocktime(), '<-- CONV'
        write (stdout, '(7x,a,F15.7,a,F15.7,a,F15.7,a,F15.7,a)') &
          'O_D=', wann_spread%om_d*lenconfac**2, &
          ' O_IOD=', (wann_spread%om_iod + wann_spread%om_nu)*lenconfac**2, &
          ' O_TOT=', wann_spread%om_tot*lenconfac**2, ' <-- SPRD'
        write (stdout, '(1x,a78)') repeat('-', 78)
      elseif (selective_loc .and. .not. slwf_constrain) then
        write (stdout, '(1x,i6,2x,E12.3,2x,F15.10,2x,F18.10,3x,F8.2,2x,a)') &
          iter, (wann_spread%om_tot - old_spread%om_tot)*lenconfac**2, sqrt(abs(gcnorm1))*lenconfac, &
          wann_spread%om_tot*lenconfac**2, io_wallclocktime(), '<-- CONV'
        write (stdout, '(7x,a,F15.7,a,F15.7,a,F15.7,a)') &
          'O_D=', wann_spread%om_d*lenconfac**2, &
          ' O_IOD=', wann_spread%om_iod*lenconfac**2, &
          ' O_TOT=', wann_spread%om_tot*lenconfac**2, ' <-- SPRD'
        write (stdout, '(1x,a78)') repeat('-', 78)
      else
        write (stdout, '(1x,i6,2x,E12.3,2x,F15.10,2x,F18.10,3x,F8.2,2x,a)') &
          iter, (wann_spread%om_tot - old_spread%om_tot)*lenconfac**2, sqrt(abs(gcnorm1))*lenconfac, &
          wann_spread%om_tot*lenconfac**2, io_wallclocktime(), '<-- CONV'
        write (stdout, '(8x,a,F15.7,a,F15.7,a,F15.7,a)') &
          'O_D=', wann_spread%om_d*lenconfac**2, ' O_OD=', wann_spread%om_od*lenconfac**2, &
          ' O_TOT=', wann_spread%om_tot*lenconfac**2, ' <-- SPRD'
        write (stdout, '(1x,a78)') repeat('-', 78)
      end if
    endif

    lconverged = .false.; lfirst = .true.; lrandom = .false.
    conv_count = 0; noise_count = 0

    if (.not. lfixstep .and. optimisation <= 0) then
      page_unit = io_file_unit()
      open (unit=page_unit, status='scratch', form='unformatted')
    endif

    ! main iteration loop
    do iter = 1, num_iter

      lprint = .false.
      if ((mod(iter, num_print_cycles) .eq. 0) .or. (iter .eq. 1) &
          .or. (iter .eq. num_iter)) lprint = .true.

      ldump = .false.
      if ((num_dump_cycles .gt. 0) .and. (mod(iter, num_dump_cycles) .eq. 0)) ldump = .true.

      if (lprint .and. on_root) write (stdout, '(1x,a,i6)') 'Cycle: ', iter

      if (guiding_centres .and. (iter .gt. num_no_guide_iter) &
          .and. (mod(iter, num_guide_cycles) .eq. 0)) then
        call wann_phases(csheet, sheet, rguide, irguide)
        irguide = 1
      endif

      ! calculate gradient of omega

      if (lsitesymmetry .or. precond) then
        call wann_domega(csheet, sheet, rave, cdodq)
      else
        call wann_domega(csheet, sheet, rave)!,cdodq)  fills only cdodq_loc
      endif

      if (lprint .and. iprint > 2 .and. on_root) &
        write (stdout, *) ' LINE --> Iteration                     :', iter

      ! calculate search direction (cdq)
      call internal_search_direction()
      if (lsitesymmetry) call sitesym_symmetrize_gradient(2, cdq) !RS:

      ! save search direction
      cdqkeep_loc(:, :, :) = cdq_loc(:, :, :)

      ! check whether we're doing fixed step lengths
      if (lfixstep) then

        alphamin = fixed_step

        ! or a parabolic line search
      else

        ! take trial step
        cdq_loc(:, :, :) = cdqkeep_loc(:, :, :)*(trial_step/(4.0_dp*wbtot))

        ! store original U and M before rotating
        u0_loc = u_matrix_loc

        if (optimisation <= 0) then
!             write(page_unit)   m_matrix
          write (page_unit) m_matrix_loc
          rewind (page_unit)
        else
          m0_loc = m_matrix_loc
        endif

        ! update U and M
        call internal_new_u_and_m()

        ! calculate spread at trial step
        call wann_omega(csheet, sheet, rave, r2ave, rave2, trial_spread)

        ! Calculate optimal step (alphamin)
        call internal_optimal_step()

      endif

      ! print line search information
      if (lprint .and. iprint > 2 .and. on_root) then
        write (stdout, *) ' LINE --> Spread at initial point       :', wann_spread%om_tot*lenconfac**2
        if (.not. lfixstep) &
          write (stdout, *) ' LINE --> Spread at trial step          :', trial_spread%om_tot*lenconfac**2
        write (stdout, *) ' LINE --> Slope along search direction  :', doda0*lenconfac**2
        write (stdout, *) ' LINE --> ||SD gradient||^2             :', gcnorm1*lenconfac**2
        if (.not. lfixstep) then
          write (stdout, *) ' LINE --> Trial step length             :', trial_step
          if (lquad) then
            write (stdout, *) ' LINE --> Optimal parabolic step length :', alphamin
            write (stdout, *) ' LINE --> Spread at predicted minimum   :', falphamin*lenconfac**2
          endif
        else
          write (stdout, *) ' LINE --> Fixed step length             :', fixed_step
        endif
        write (stdout, *) ' LINE --> CG coefficient                :', gcfac
      endif

      ! if taking a fixed step or if parabolic line search was successful
      if (lfixstep .or. lquad) then

        ! take optimal step
        cdq_loc(:, :, :) = cdqkeep_loc(:, :, :)*(alphamin/(4.0_dp*wbtot))

        ! if doing a line search then restore original U and M before rotating
        if (.not. lfixstep) then
          u_matrix_loc = u0_loc
          if (optimisation <= 0) then
!                read(page_unit)  m_matrix
            read (page_unit) m_matrix_loc
            rewind (page_unit)
          else
            m_matrix_loc = m0_loc
          endif
        endif

        ! update U and M
        call internal_new_u_and_m()

        call wann_spread_copy(wann_spread, old_spread)

        ! calculate the new centers and spread
        call wann_omega(csheet, sheet, rave, r2ave, rave2, wann_spread)

        ! parabolic line search was unsuccessful, use trial step already taken
      else

        call wann_spread_copy(wann_spread, old_spread)
        call wann_spread_copy(trial_spread, wann_spread)

      endif

      ! print the new centers and spreads
      if (lprint .and. on_root) then
        do iw = 1, num_wann
          write (stdout, 1000) iw, (rave(ind, iw)*lenconfac, ind=1, 3), &
            (r2ave(iw) - rave2(iw))*lenconfac**2
        end do
        write (stdout, 1001) (sum(rave(ind, :))*lenconfac, ind=1, 3), &
          (sum(r2ave) - sum(rave2))*lenconfac**2
        write (stdout, *)
        if (selective_loc .and. slwf_constrain) then
          write (stdout, '(1x,i6,2x,E12.3,2x,F15.10,2x,F18.10,3x,F8.2,2x,a)') &
            iter, (wann_spread%om_tot - old_spread%om_tot)*lenconfac**2, &
            sqrt(abs(gcnorm1))*lenconfac, &
            wann_spread%om_tot*lenconfac**2, io_wallclocktime(), '<-- CONV'
          write (stdout, '(7x,a,F15.7,a,F15.7,a,F15.7,a)') &
            'O_IOD=', (wann_spread%om_iod + wann_spread%om_nu)*lenconfac**2, &
            ' O_D=', wann_spread%om_d*lenconfac**2, &
            ' O_TOT=', wann_spread%om_tot*lenconfac**2, ' <-- SPRD'
          write (stdout, '(a,E15.7,a,E15.7,a,E15.7,a)') &
            'Delta: O_IOD=', ((wann_spread%om_iod + wann_spread%om_nu) - &
                              (old_spread%om_iod + wann_spread%om_nu))*lenconfac**2, &
            ' O_D=', (wann_spread%om_d - old_spread%om_d)*lenconfac**2, &
            ' O_TOT=', (wann_spread%om_tot - old_spread%om_tot)*lenconfac**2, ' <-- DLTA'
          write (stdout, '(1x,a78)') repeat('-', 78)
        elseif (selective_loc .and. .not. slwf_constrain) then
          write (stdout, '(1x,i6,2x,E12.3,2x,F15.10,2x,F18.10,3x,F8.2,2x,a)') &
            iter, (wann_spread%om_tot - old_spread%om_tot)*lenconfac**2, &
            sqrt(abs(gcnorm1))*lenconfac, &
            wann_spread%om_tot*lenconfac**2, io_wallclocktime(), '<-- CONV'
          write (stdout, '(7x,a,F15.7,a,F15.7,a,F15.7,a)') &
            'O_IOD=', wann_spread%om_iod*lenconfac**2, &
            ' O_D=', wann_spread%om_d*lenconfac**2, &
            ' O_TOT=', wann_spread%om_tot*lenconfac**2, ' <-- SPRD'
          write (stdout, '(a,E15.7,a,E15.7,a,E15.7,a)') &
            'Delta: O_IOD=', (wann_spread%om_iod - old_spread%om_iod)*lenconfac**2, &
            ' O_D=', (wann_spread%om_d - old_spread%om_d)*lenconfac**2, &
            ' O_TOT=', (wann_spread%om_tot - old_spread%om_tot)*lenconfac**2, ' <-- DLTA'
          write (stdout, '(1x,a78)') repeat('-', 78)
        else
          write (stdout, '(1x,i6,2x,E12.3,2x,F15.10,2x,F18.10,3x,F8.2,2x,a)') &
            iter, (wann_spread%om_tot - old_spread%om_tot)*lenconfac**2, &
            sqrt(abs(gcnorm1))*lenconfac, &
            wann_spread%om_tot*lenconfac**2, io_wallclocktime(), '<-- CONV'
          write (stdout, '(8x,a,F15.7,a,F15.7,a,F15.7,a)') &
            'O_D=', wann_spread%om_d*lenconfac**2, &
            ' O_OD=', wann_spread%om_od*lenconfac**2, &
            ' O_TOT=', wann_spread%om_tot*lenconfac**2, ' <-- SPRD'
          write (stdout, '(1x,a,E15.7,a,E15.7,a,E15.7,a)') &
            'Delta: O_D=', (wann_spread%om_d - old_spread%om_d)*lenconfac**2, &
            ' O_OD=', (wann_spread%om_od - old_spread%om_od)*lenconfac**2, &
            ' O_TOT=', (wann_spread%om_tot - old_spread%om_tot)*lenconfac**2, ' <-- DLTA'
          write (stdout, '(1x,a78)') repeat('-', 78)
        end if
      end if

      ! Public array of Wannier centres and spreads
      wannier_centres = rave
      wannier_spreads = r2ave - rave2

      ! Public variables
      if (.not. selective_loc) then
        omega_total = wann_spread%om_tot
        omega_tilde = wann_spread%om_d + wann_spread%om_od
      else
        omega_total = wann_spread%om_tot
        !omega_tilde = wann_spread%om_d + wann_spread%om_nu
      end if

      if (ldump) then
        ! Before calling param_write_chkpt, I need to gather on the root node
        ! the u_matrix from the u_matrix_loc. No need to broadcast it since
        ! it's printed by the root node only
        call comms_gatherv(u_matrix_loc, num_wann*num_wann*counts(my_node_id), &
                           u_matrix, num_wann*num_wann*counts, num_wann*num_wann*displs)
        ! I also transfer the M matrix
        call comms_gatherv(m_matrix_loc, num_wann*num_wann*nntot*counts(my_node_id), &
                           m_matrix, num_wann*num_wann*nntot*counts, num_wann*num_wann*nntot*displs)
        if (on_root) call param_write_chkpt('postdis')
      endif

      if (conv_window .gt. 1) call internal_test_convergence()

      if (lconverged) then
        write (stdout, '(/13x,a,es10.3,a,i2,a)') &
          '<<<     Delta <', conv_tol, &
          '  over ', conv_window, ' iterations     >>>'
        write (stdout, '(13x,a/)') '<<< Wannierisation convergence criteria satisfied >>>'
        exit
      endif

    enddo
    ! end of the minimization loop

    ! the m matrix is sent by piece to avoid huge arrays
    ! But, I want to reduce the memory usage as much as possible.
!    do nn = 1, nntot
!      m_matrix_1b_loc=m_matrix_loc(:,:,nn,:)
!      call comms_gatherv(m_matrix_1b_loc,num_wann*num_wann*counts(my_node_id),&
!                 m_matrix_1b,num_wann*num_wann*counts,num_wann*num_wann*displs)
!      call comms_bcast(m_matrix_1b(1,1,1),num_wann*num_wann*num_kpts)
!      m_matrix(:,:,nn,:)=m_matrix_1b(:,:,:)
!    end do!nn
    call comms_gatherv(m_matrix_loc, num_wann*num_wann*nntot*counts(my_node_id), &
                       m_matrix, num_wann*num_wann*nntot*counts, num_wann*num_wann*nntot*displs)

    ! send u matrix
    call comms_gatherv(u_matrix_loc, num_wann*num_wann*counts(my_node_id), &
                       u_matrix, num_wann*num_wann*counts, num_wann*num_wann*displs)
    call comms_bcast(u_matrix(1, 1, 1), num_wann*num_wann*num_kpts)

    ! Evaluate the penalty functional
    if (selective_loc .and. slwf_constrain) then
      rnr0n2 = 0.0_dp
      do iw = 1, slwf_num
        rnr0n2(iw) = (wannier_centres(1, iw) - ccentres_cart(iw, 1))**2 &
                     + (wannier_centres(2, iw) - ccentres_cart(iw, 2))**2 &
                     + (wannier_centres(3, iw) - ccentres_cart(iw, 3))**2
      end do
    end if

    if (on_root) then
      write (stdout, '(1x,a)') 'Final State'
      do iw = 1, num_wann
        write (stdout, 1000) iw, (rave(ind, iw)*lenconfac, ind=1, 3), &
          (r2ave(iw) - rave2(iw))*lenconfac**2
      end do
      write (stdout, 1001) (sum(rave(ind, :))*lenconfac, ind=1, 3), &
        (sum(r2ave) - sum(rave2))*lenconfac**2
      write (stdout, *)
      if (selective_loc .and. slwf_constrain) then
        write (stdout, '(3x,a21,a,f15.9)') '     Spreads ('//trim(length_unit)//'^2)', &
          '       Omega IOD_C   = ', (wann_spread%om_iod + wann_spread%om_nu)*lenconfac**2
        write (stdout, '(3x,a,f15.9)') '     ================       Omega D       = ', &
          wann_spread%om_d*lenconfac**2
        write (stdout, '(3x,a,f15.9)') '                            Omega Rest    = ', &
          (sum(r2ave) - sum(rave2) + wann_spread%om_tot)*lenconfac**2
        write (stdout, '(3x,a,f15.9)') '                            Penalty func  = ', &
          sum(rnr0n2(:))
        write (stdout, '(3x,a21,a,f15.9)') 'Final Spread ('//trim(length_unit)//'^2)', &
          '       Omega Total_C = ', wann_spread%om_tot*lenconfac**2
        write (stdout, '(1x,a78)') repeat('-', 78)
      else if (selective_loc .and. .not. slwf_constrain) then
        write (stdout, '(3x,a21,a,f15.9)') '     Spreads ('//trim(length_unit)//'^2)', &
          '       Omega IOD    = ', wann_spread%om_iod*lenconfac**2
        write (stdout, '(3x,a,f15.9)') '     ================       Omega D      = ', &
          wann_spread%om_d*lenconfac**2
        write (stdout, '(3x,a,f15.9)') '                            Omega Rest   = ', &
          (sum(r2ave) - sum(rave2) + wann_spread%om_tot)*lenconfac**2
        write (stdout, '(3x,a21,a,f15.9)') 'Final Spread ('//trim(length_unit)//'^2)', &
          '       Omega Total  = ', wann_spread%om_tot*lenconfac**2
        write (stdout, '(1x,a78)') repeat('-', 78)
      else
        write (stdout, '(3x,a21,a,f15.9)') '     Spreads ('//trim(length_unit)//'^2)', &
          '       Omega I      = ', wann_spread%om_i*lenconfac**2
        write (stdout, '(3x,a,f15.9)') '     ================       Omega D      = ', &
          wann_spread%om_d*lenconfac**2
        write (stdout, '(3x,a,f15.9)') '                            Omega OD     = ', &
          wann_spread%om_od*lenconfac**2
        write (stdout, '(3x,a21,a,f15.9)') 'Final Spread ('//trim(length_unit)//'^2)', &
          '       Omega Total  = ', wann_spread%om_tot*lenconfac**2
        write (stdout, '(1x,a78)') repeat('-', 78)
      end if
    endif

    if (write_xyz .and. on_root) call wann_write_xyz()

    if (write_hr_diag) then
      call hamiltonian_setup()
      call hamiltonian_get_hr()
      if (on_root) then
        write (stdout, *)
        write (stdout, '(1x,a)') 'On-site Hamiltonian matrix elements'
        write (stdout, '(3x,a)') '  n        <0n|H|0n> (eV)'
        write (stdout, '(3x,a)') '-------------------------'
        do i = 1, num_wann
          write (stdout, '(3x,i3,5x,f12.6)') i, real(ham_r(i, i, rpt_origin), kind=dp)
        enddo
        write (stdout, *)
      endif
    endif

    if (guiding_centres) call wann_phases(csheet, sheet, rguide, irguide)

    ! unitarity is checked
!~    call internal_check_unitarity()
    call wann_check_unitarity()

    ! write extra info regarding omega_invariant
!~    if (iprint>2) call internal_svd_omega_i()
!    if (iprint>2) call wann_svd_omega_i()
    if (iprint > 2 .and. on_root) call wann_svd_omega_i()

    ! write matrix elements <m|r^2|n> to file
!~    if (write_r2mn) call internal_write_r2mn()
!    if (write_r2mn) call wann_write_r2mn()
    if (write_r2mn .and. on_root) call wann_write_r2mn()

    ! calculate and write projection of WFs on original bands in outer window
    if (have_disentangled .and. write_proj) call wann_calc_projection()

    ! aam: write data required for vdW utility
    if (write_vdw_data .and. on_root) call wann_write_vdw_data()

    ! deallocate sub vars not passed into other subs
    deallocate (rwork, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating rwork in wann_main')
    deallocate (cwork, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating cwork in wann_main')
    deallocate (evals, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating evals in wann_main')
    deallocate (tmp_cdq, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating tmp_cdq in wann_main')
    deallocate (cmtmp, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating cmtmp in wann_main')
    deallocate (cz, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating cz in wann_main')
    deallocate (cdq, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating cdq in wann_main')

    ! for MPI
    deallocate (ln_tmp_loc, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating ln_tmp_loc in wann_main')
    deallocate (rnkb_loc, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating rnkb_loc in wann_main')
    deallocate (u_matrix_loc, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating u_matrix_loc in wann_main')
    deallocate (m_matrix_loc, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating m_matrix_loc in wann_main')
!    deallocate(m_matrix_1b,stat=ierr)
!    if (ierr/=0) call io_error('Error in deallocating m_matrix_1b in wann_main')
!    deallocate(m_matrix_1b_loc,stat=ierr)
!    if (ierr/=0) call io_error('Error in deallocating m_matrix_1b_loc in wann_main')
    deallocate (cdq_loc, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating cdq_loc in wann_main')
    deallocate (cdodq_loc, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating cdodq_loc in wann_main')
    deallocate (cdqkeep_loc, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating cdqkeep_loc in wann_main')

    deallocate (cwschur3, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating cwschur3 in wann_main')
    deallocate (cwschur1, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating cwschur1 in wann_main')
    if (precond) then
      if (optimisation >= 3) then
        deallocate (k_to_r, stat=ierr)
        if (ierr /= 0) call io_error('Error in deallocating k_to_r in wann_main')
      end if
      deallocate (cdodq_r, stat=ierr)
      if (ierr /= 0) call io_error('Error in deallocating cdodq_r in wann_main')
      deallocate (cdodq_precond, stat=ierr)
      if (ierr /= 0) call io_error('Error in deallocating cdodq_precond in wann_main')
      deallocate (cdodq_precond_loc, stat=ierr)
      if (ierr /= 0) call io_error('Error in deallocating cdodq_precond_loc in wann_main')
    end if

    ! deallocate sub vars passed into other subs
    deallocate (rguide, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating rguide in wann_main')
    deallocate (rave2, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating rave2 in wann_main')
    deallocate (rave, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating rave in wann_main')
    deallocate (sheet, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating sheet in wann_main')
    deallocate (cdodq, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating cdodq in wann_main')
    deallocate (csheet, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating csheet in wann_main')
    if (selective_loc) then
      deallocate (rnr0n2, stat=ierr)
      if (ierr /= 0) call io_error('Error in deallocating rnr0n2 in wann_main')
    end if
    ! deallocate module data
    deallocate (ln_tmp, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating ln_tmp in wann_main')
    deallocate (rnkb, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating rnkb in wann_main')

    deallocate (u0_loc, stat=ierr)
    if (ierr /= 0) call io_error('Error in deallocating u0_loc in wann_main')
    if (optimisation > 0) then
      deallocate (m0_loc, stat=ierr)
      if (ierr /= 0) call io_error('Error in deallocating m0_loc in wann_main')
    end if

    if (allocated(counts)) deallocate (counts)
    if (allocated(displs)) deallocate (displs)

    deallocate (history, stat=ierr)
    if (ierr /= 0) call io_error('Error deallocating history in wann_main')

    if (timing_level > 0 .and. on_root) call io_stopwatch('wann: main', 2)

    return

1000 format(2x, 'WF centre and spread', &
&       i5, 2x, '(', f10.6, ',', f10.6, ',', f10.6, ' )', f15.8)

1001 format(2x, 'Sum of centres and spreads', &
&       1x, '(', f10.6, ',', f10.6, ',', f10.6, ' )', f15.8)

  contains

    !===============================================!
    subroutine internal_test_convergence()
      !===============================================!
      !                                               !
      !! Determine whether minimisation of non-gauge
      !! invariant spread is converged
      !!
      !! The change in total spread of the current
      !! iteration is recorded in the sliding window
      !! `history` (length `conv_window`). Once at
      !! least `conv_window` changes have been seen
      !! since the last reset and all of them are
      !! within `conv_tol`, either `lconverged` is set
      !! or, when noise is enabled
      !! (`conv_noise_amp > 0` and fewer than
      !! `conv_noise_num` kicks so far), a random kick
      !! is requested through `lrandom` and the
      !! counter `conv_count` is restarted.
      !                                               !
      !===============================================!

      implicit none

      real(kind=dp) :: delta_omega

      delta_omega = wann_spread%om_tot - old_spread%om_tot

      ! Fill the window during the first conv_window iterations; afterwards
      ! drop the oldest entry and append the newest. The right-hand side is
      ! fully evaluated before the assignment, so no scratch copy is needed.
      if (iter <= conv_window) then
        history(iter) = delta_omega
      else
        history = eoshift(history, 1, delta_omega)
      endif

      conv_count = conv_count + 1

      ! Not enough samples since the last (re)start of the window
      if (conv_count < conv_window) return

      ! Any recent change above tolerance means we are not converged yet
      if (any(abs(history(1:conv_window)) > conv_tol)) return

      if ((conv_noise_amp > 0.0_dp) .and. (noise_count < conv_noise_num)) then
        if (lfirst) then
          ! First plateau: remember the spread and request a random kick
          lfirst = .false.
          save_spread = wann_spread%om_tot
          lrandom = .true.
          conv_count = 0
        else
          if (abs(save_spread - wann_spread%om_tot) < conv_tol) then
            ! The kick led back to the same spread: accept as converged
            lconverged = .true.
            return
          else
            ! Reached a different plateau: remember it and kick again
            save_spread = wann_spread%om_tot
            lrandom = .true.
            conv_count = 0
          endif
        endif
      else
        lconverged = .true.
      endif

      if (lrandom) noise_count = noise_count + 1

      return

    end subroutine internal_test_convergence

    !===============================================!
    subroutine internal_random_noise()
      !===============================================!
      !                                               !
      !! Add some random noise to the search direction
      !! to help escape from local minima
      !!
      !! For every k-point held by this process a
      !! random anti-Hermitian num_wann x num_wann
      !! matrix, scaled by `conv_noise_amp`, is added
      !! to `cdq_loc`.
      !                                               !
      !===============================================!

      implicit none

      integer :: ikp, iw, jw
      real(kind=dp), allocatable :: noise_real(:, :), noise_imag(:, :)
      complex(kind=dp), allocatable :: cnoise(:, :)

      ! Allocate
      allocate (noise_real(num_wann, num_wann), stat=ierr)
      if (ierr /= 0) call io_error('Error allocating noise_real in wann_main')
      allocate (noise_imag(num_wann, num_wann), stat=ierr)
      if (ierr /= 0) call io_error('Error allocating noise_imag in wann_main')
      allocate (cnoise(num_wann, num_wann), stat=ierr)
      if (ierr /= 0) call io_error('Error allocating cnoise in wann_main')

      ! Initialise
      cnoise = cmplx_0; noise_real = 0.0_dp; noise_imag = 0.0_dp

      ! Re-seed the generator once per call. Re-seeding before every draw is
      ! unnecessary, and because the argument-less seed is processor dependent
      ! (a fixed default state or a clock-derived value on some compilers) it
      ! can make successive draws repeat the same numbers.
      call random_seed()

      ! cdq is a num_wann x num_wann x num_kpts anti-hermitian array
      ! to which we add a random anti-hermitian matrix

      do ikp = 1, counts(my_node_id)
        ! Independent real and imaginary parts from the running stream
        call random_number(noise_real)
        call random_number(noise_imag)
        ! Fill the upper triangle and mirror it so that cnoise = -cnoise^dagger;
        ! the diagonal is purely imaginary
        do jw = 1, num_wann
          do iw = 1, jw
            if (iw .eq. jw) then
              cnoise(iw, jw) = cmplx(0.0_dp, noise_imag(iw, jw), dp)
            else
              cnoise(iw, jw) = cmplx(noise_real(iw, jw), noise_imag(iw, jw), dp)
            endif
            cnoise(jw, iw) = -conjg(cnoise(iw, jw))
          enddo
        enddo
        ! Add noise to search direction
        cdq_loc(:, :, ikp) = cdq_loc(:, :, ikp) + conv_noise_amp*cnoise(:, :)
      enddo

      ! Deallocate
      deallocate (cnoise, stat=ierr)
      if (ierr /= 0) call io_error('Error deallocating cnoise in wann_main')
      deallocate (noise_imag, stat=ierr)
      if (ierr /= 0) call io_error('Error deallocating noise_imag in wann_main')
      deallocate (noise_real, stat=ierr)
      if (ierr /= 0) call io_error('Error deallocating noise_real in wann_main')

      return

    end subroutine internal_random_noise

    !===============================================!
    subroutine internal_search_direction()
      !===============================================!
      !                                               !
      !! Calculate the conjugate gradients search
      !! direction using the Fletcher-Reeves formula:
      !!
      !!     cg_coeff = [g(i).g(i)]/[g(i-1).g(i-1)]
      !!
      !! On exit `cdq_loc` holds the search direction
      !! for the k-points of this process, `doda0` the
      !! derivative of the spread along it, `gcnorm0`
      !! the gradient norm saved for the next call,
      !! and `lrandom` has been reset to false.
      !! When `precond` is set the gradient is first
      !! filtered in real space by 1/(1+R^2/alpha).
      !! The routine performs collective
      !! `comms_allreduce` calls, so the statement
      !! order matters across processes.
      !                                               !
      !===============================================!

      implicit none

      ! BLAS conjugated dot product, used as an external function below
      complex(kind=dp) :: zdotc

      if (timing_level > 1 .and. on_root) call io_stopwatch('wann: main: search_direction', 1)

      ! gcnorm1 = Tr[gradient . gradient] -- NB gradient is anti-Hermitian
      ! gcnorm1 = real(zdotc(num_kpts*num_wann*num_wann,cdodq,1,cdodq,1),dp)

      if (precond) then
        ! compute cdodq_precond

        cdodq_r(:, :, :) = 0 ! intermediary gradient in R space
        cdodq_precond(:, :, :) = 0
        cdodq_precond_loc(:, :, :) = 0
!         cdodq_precond(:,:,:) = complx_0

        ! convert to real space in cdodq_r
        ! Two algorithms: either double loop or GEMM. GEMM is much more efficient but requires more RAM
        ! Ideally, we should implement FFT-based filtering here
        if (optimisation >= 3) then
          ! Single matrix product with the k_to_r table, then normalise by num_kpts
          call zgemm('N', 'N', num_wann*num_wann, nrpts, num_kpts, cmplx_1, &
               & cdodq, num_wann*num_wann, k_to_r, num_kpts, cmplx_0, cdodq_r, num_wann*num_wann)
          cdodq_r = cdodq_r/real(num_kpts, dp)
        else
          ! Explicit sum over k with phase exp(-i 2pi k.R)/num_kpts
          do irpt = 1, nrpts
            do loop_kpt = 1, num_kpts
              rdotk = twopi*dot_product(kpt_latt(:, loop_kpt), real(irvec(:, irpt), dp))
              fac = exp(-cmplx_i*rdotk)/real(num_kpts, dp)
              cdodq_r(:, :, irpt) = cdodq_r(:, :, irpt) + fac*cdodq(:, :, loop_kpt)
            enddo
          enddo
        end if

        ! filter cdodq_r in real space by 1/(1+R^2/alpha)

        ! this alpha coefficient is more or less arbitrary, and could
        ! be tweaked further: the point is to have something that has
        ! the right units, and is not too small (or the filtering is
        ! too severe) or too high (or the filtering does nothing).
        !
        ! the descent direction produced has a different magnitude
        ! than the one without preconditionner, so the values of
        ! trial_step are not consistent
        ! NOTE(review): alpha_precond is zero if wann_spread%om_tot is zero,
        ! which would make the division below ill-defined -- confirm this
        ! cannot happen in practice.
        alpha_precond = 10*wann_spread%om_tot/num_wann
        do irpt = 1, nrpts
          rvec_cart = matmul(real_lattice(:, :), real(irvec(:, irpt), dp))
          cdodq_r(:, :, irpt) = cdodq_r(:, :, irpt)*1/(1 + dot_product(rvec_cart, rvec_cart)/alpha_precond)
        end do

        ! go back to k space
        if (optimisation >= 3) then
          ! Weight each R-point by 1/ndegen before the back-transform
          do irpt = 1, nrpts
            cdodq_r(:, :, irpt) = cdodq_r(:, :, irpt)/real(ndegen(irpt), dp)
          end do
          call zgemm('N', 'C', num_wann*num_wann, num_kpts, nrpts, cmplx_1, &
               & cdodq_r, num_wann*num_wann, k_to_r, num_kpts, cmplx_0, cdodq_precond, num_wann*num_wann)
        else
          ! Explicit sum over R with phase exp(+i 2pi k.R)/ndegen(R)
          do irpt = 1, nrpts
            do loop_kpt = 1, num_kpts
              rdotk = twopi*dot_product(kpt_latt(:, loop_kpt), real(irvec(:, irpt), dp))
              fac = exp(cmplx_i*rdotk)/real(ndegen(irpt), dp)
              cdodq_precond(:, :, loop_kpt) = cdodq_precond(:, :, loop_kpt) + fac*cdodq_r(:, :, irpt)
            enddo
          enddo
        end if
        ! Keep only the slice of k-points owned by this process
        cdodq_precond_loc(:, :, 1:counts(my_node_id)) = &
          cdodq_precond(:, :, 1 + displs(my_node_id):displs(my_node_id) + counts(my_node_id))

      end if

      ! gcnorm1 = Tr[gradient . gradient] -- NB gradient is anti-Hermitian
      ! Local contribution only; summed over all processes just below.
      if (precond) then
!         gcnorm1 = real(zdotc(num_kpts*num_wann*num_wann,cdodq_precond,1,cdodq,1),dp)
        gcnorm1 = real(zdotc(counts(my_node_id)*num_wann*num_wann, cdodq_precond_loc, 1, cdodq_loc, 1), dp)
      else
        gcnorm1 = real(zdotc(counts(my_node_id)*num_wann*num_wann, cdodq_loc, 1, cdodq_loc, 1), dp)
      end if
      call comms_allreduce(gcnorm1, 1, 'SUM')

      ! calculate cg_coefficient
      ! Steepest descent is forced on the first iteration and after
      ! num_cg_steps consecutive CG steps.
      if ((iter .eq. 1) .or. (ncg .ge. num_cg_steps)) then
        gcfac = 0.0_dp                 ! Steepest descents
        ncg = 0
      else
        if (gcnorm0 .gt. epsilon(1.0_dp)) then
          gcfac = gcnorm1/gcnorm0     ! Fletcher-Reeves CG coefficient
          ! prevent CG coefficient from getting too large
          if (gcfac .gt. 3.0_dp) then
            if (lprint .and. iprint > 2 .and. on_root) &
              write (stdout, *) ' LINE --> CG coeff too large. Resetting :', gcfac
            gcfac = 0.0_dp
            ncg = 0
          else
            ncg = ncg + 1
          endif
        else
          ! Previous gradient norm is (numerically) zero: avoid dividing by it
          gcfac = 0.0_dp
          ncg = 0
        endif
      endif

      ! save for next iteration
      gcnorm0 = gcnorm1

      ! calculate search direction

      if (precond) then
        cdq_loc(:, :, :) = cdodq_precond_loc(:, :, :) + cdqkeep_loc(:, :, :)*gcfac !! JRY not MPI
      else
        cdq_loc(:, :, :) = cdodq_loc(:, :, :) + cdqkeep_loc(:, :, :)*gcfac
      end if

      ! add some random noise to search direction, if required
      if (lrandom) then
        if (on_root) write (stdout, '(a,i3,a,i3,a)') &
          ' [ Adding random noise to search direction. Time ', noise_count, ' / ', conv_noise_num, ' ]'
        call internal_random_noise()
      endif
      ! calculate gradient along search direction - Tr[gradient . search direction]
      ! NB gradient is anti-hermitian
      doda0 = -real(zdotc(counts(my_node_id)*num_wann*num_wann, cdodq_loc, 1, cdq_loc, 1), dp)

      call comms_allreduce(doda0, 1, 'SUM')

      doda0 = doda0/(4.0_dp*wbtot)

      ! check search direction is not uphill
      if (doda0 .gt. 0.0_dp) then
        ! if doing a CG step then reset CG
        if (ncg .gt. 0) then
          if (lprint .and. iprint > 2 .and. on_root) &
            write (stdout, *) ' LINE --> Search direction uphill: resetting CG'
          ! NOTE(review): the reset uses the raw gradient cdodq_loc even when
          ! precond is set (not cdodq_precond_loc) -- confirm this is intended.
          cdq_loc(:, :, :) = cdodq_loc(:, :, :)
          if (lrandom) call internal_random_noise()
          ncg = 0
          gcfac = 0.0_dp
          ! re-calculate gradient along search direction
          doda0 = -real(zdotc(counts(my_node_id)*num_wann*num_wann, cdodq_loc, 1, cdq_loc, 1), dp)

          call comms_allreduce(doda0, 1, 'SUM')

          doda0 = doda0/(4.0_dp*wbtot)
          ! if search direction still uphill then reverse search direction
          if (doda0 .gt. 0.0_dp) then
            if (lprint .and. iprint > 2 .and. on_root) &
              write (stdout, *) ' LINE --> Search direction still uphill: reversing'
            cdq_loc(:, :, :) = -cdq_loc(:, :, :)
            doda0 = -doda0
          endif
          ! if doing a SD step then reverse search direction
        else
          if (lprint .and. iprint > 2 .and. on_root) &
            write (stdout, *) ' LINE --> Search direction uphill: reversing'
          cdq_loc(:, :, :) = -cdq_loc(:, :, :)
          doda0 = -doda0
        endif
      endif

      !~     ! calculate search direction
      !~     cdq(:,:,:) = cdodq(:,:,:) + cdqkeep(:,:,:) * gcfac

      if (timing_level > 1 .and. on_root) call io_stopwatch('wann: main: search_direction', 2)

      ! The noise request has been consumed for this iteration
      lrandom = .false.

      return

    end subroutine internal_search_direction

    !===============================================!
    subroutine internal_optimal_step()
      !===============================================!
      !                                               !
      !! Calculate the optimal step length based on a
      !! parabolic line search
      !!
      !! Uses the current spread, the spread after the
      !! trial step and the slope `doda0` to fit a
      !! parabola. Sets `alphamin` (step length),
      !! `falphamin` (predicted spread) and `lquad`
      !! (whether the parabolic estimate was used);
      !! falls back to the trial step otherwise.
      !                                               !
      !===============================================!

      implicit none

      real(kind=dp) :: inv_gap, offset, slope, curv
      logical :: verbose

      if (timing_level > 1 .and. on_root) call io_stopwatch('wann: main: optimal_step', 1)

      ! Diagnostics are printed under the same condition everywhere below
      verbose = lprint .and. iprint > 2 .and. on_root

      ! Work in units of the spread change of the trial step; if that change
      ! is vanishingly small use a fixed large scale factor instead.
      inv_gap = trial_spread%om_tot - wann_spread%om_tot
      if (abs(inv_gap) > tiny(1.0_dp)) then
        inv_gap = 1.0_dp/inv_gap
        offset = 1.0_dp
      else
        inv_gap = 1.0e6_dp
        offset = inv_gap*trial_spread%om_tot - inv_gap*wann_spread%om_tot
      end if

      ! Scaled linear and quadratic coefficients of the fitted parabola
      slope = inv_gap*doda0
      curv = offset - slope*trial_step

      ! The parabolic minimum is trusted only if the quadratic term is
      ! significant relative to the spread itself.
      lquad = abs(curv/(inv_gap*wann_spread%om_tot)) > epsilon(1.0_dp)
      if (lquad) then
        alphamin = -0.5_dp*slope/curv*(trial_step**2)
        falphamin = wann_spread%om_tot &
                    - 0.25_dp*slope*slope/(inv_gap*curv)*(trial_step**2)
      else
        if (verbose) write (stdout, *) &
          ' LINE --> Parabolic line search unstable: using trial step'
        alphamin = trial_step
        falphamin = trial_spread%om_tot
      end if

      ! Reject a step whose sign agrees with the slope and use the trial step
      if (doda0*alphamin > 0.0_dp) then
        if (verbose) write (stdout, *) &
          ' LINE --> Line search unstable : using trial step'
        lquad = .false.
        alphamin = trial_step
        falphamin = trial_spread%om_tot
      end if

      if (timing_level > 1 .and. on_root) call io_stopwatch('wann: main: optimal_step', 2)

      return

    end subroutine internal_optimal_step

    !===============================================!
    subroutine internal_new_u_and_m()
      !===============================================!
      !                                               !
      !! Update U and M matrices after a trial step
      !!
      !! For each k-point of this process the
      !! anti-Hermitian step `cdq_loc(:,:,k)` is
      !! replaced by its matrix exponential, the
      !! result is gathered into `cdq` and broadcast,
      !! and then
      !!   U(k)   <- U(k) . cdq(k)
      !!   M(k,b) <- cdq(k)^dagger . M(k,b) . cdq(k+b)
      !! are applied to `u_matrix_loc` and
      !! `m_matrix_loc`.
      !                                               !
      !===============================================!
      use w90_sitesym, only: sitesym_symmetrize_rotation, & !RS:
        ir2ik, ik2ir !YN: RS:

      implicit none

      integer :: nkp, nn, nkp2, nsdim, nkp_loc
      logical :: ltmp

      if (timing_level > 1 .and. on_root) call io_stopwatch('wann: main: u_and_m', 1)

      do nkp_loc = 1, counts(my_node_id)
        nkp = nkp_loc + displs(my_node_id)
        ! With site symmetry only k-points that map onto themselves are
        ! exponentiated here; the others are rebuilt by
        ! sitesym_symmetrize_rotation after the gather below.
        if (lsitesymmetry) then                !YN: RS:
          if (ir2ik(ik2ir(nkp)) .ne. nkp) cycle !YN: RS:
        end if                                 !YN: RS:
        ! cdq(nkp) is anti-Hermitian; tmp_cdq = i*cdq  is Hermitian
        tmp_cdq(:, :) = cmplx_i*cdq_loc(:, :, nkp_loc)
        ! Hermitian matrix eigen-solver
        call zheev('V', 'U', num_wann, tmp_cdq, num_wann, evals, cwork, 4*num_wann, rwork, info)
        if (info .ne. 0) then
          if (on_root) write (stdout, *) &
            'wann_main: ZHEEV in internal_new_u_and_m failed, info= ', info
          if (on_root) write (stdout, *) '           trying Schur decomposition instead'
          ! Fallback: Schur form of cdq itself, cdq = cz . T . cz^{dagger},
          ! exponentiating the eigenvalues returned in cwschur1.
          ! (An earlier version of this routine used this Schur route for
          ! every k-point after the gather instead of ZHEEV.)
          tmp_cdq(:, :) = cdq_loc(:, :, nkp_loc)
          call zgees('V', 'N', ltmp, num_wann, tmp_cdq, num_wann, nsdim, &
                     cwschur1, cz, num_wann, cwschur2, 10*num_wann, cwschur3, &
                     cwschur4, info)
          if (info .ne. 0) then
            if (on_root) write (stdout, *) 'wann_main: SCHUR failed, info= ', info
            call io_error('wann_main: problem computing schur form 1')
          endif
          do i = 1, num_wann
            tmp_cdq(:, i) = cz(:, i)*exp(cwschur1(i))
          enddo
          ! cmtmp   = tmp_cdq . cz^{dagger}
          call utility_zgemm(cmtmp, tmp_cdq, 'N', cz, 'C', num_wann)
          cdq_loc(:, :, nkp_loc) = cmtmp(:, :)
        else
          ! exp(cdq) = V . exp(-i*evals) . V^{dagger}, V = eigenvectors of i*cdq
          do i = 1, num_wann
            cmtmp(:, i) = tmp_cdq(:, i)*exp(-cmplx_i*evals(i))
          enddo
          ! cdq(nkp)   = cmtmp . tmp_cdq^{dagger}
          call utility_zgemm(cdq_loc(:, :, nkp_loc), cmtmp, 'N', tmp_cdq, 'C', num_wann)
        endif
      enddo

      ! each process communicates its result to other processes
      ! it would be enough to copy only next neighbors
      call comms_gatherv(cdq_loc, num_wann*num_wann*counts(my_node_id), &
                         cdq, num_wann*num_wann*counts, num_wann*num_wann*displs)
      call comms_bcast(cdq(1, 1, 1), num_wann*num_wann*num_kpts)

      if (lsitesymmetry) then
        call sitesym_symmetrize_rotation(cdq) !RS: calculate cdq(Rk) from k
        ! refresh the local slice from the symmetrised global array
        cdq_loc(:, :, 1:counts(my_node_id)) = cdq(:, :, 1 + displs(my_node_id):displs(my_node_id) + counts(my_node_id))
      endif

      ! the orbitals are rotated
      do nkp_loc = 1, counts(my_node_id)
        ! cmtmp = U(k) . cdq(k)
        call utility_zgemm(cmtmp, u_matrix_loc(:, :, nkp_loc), 'N', cdq_loc(:, :, nkp_loc), 'N', num_wann)
        u_matrix_loc(:, :, nkp_loc) = cmtmp(:, :)
      enddo

      ! and the M_ij are updated
      do nkp_loc = 1, counts(my_node_id)
        nkp = nkp_loc + displs(my_node_id)
        do nn = 1, nntot
          ! neighbour k-point may live on another process, hence the global cdq
          nkp2 = nnlist(nkp, nn)
          ! tmp_cdq = cdq^{dagger} . M
          call utility_zgemm(tmp_cdq, cdq(:, :, nkp), 'C', m_matrix_loc(:, :, nn, nkp_loc), 'N', num_wann)
          ! cmtmp = tmp_cdq . cdq
          call utility_zgemm(cmtmp, tmp_cdq, 'N', cdq(:, :, nkp2), 'N', num_wann)
          m_matrix_loc(:, :, nn, nkp_loc) = cmtmp(:, :)
        enddo
      enddo

      ! Stop the timer under the same condition it was started with, so that
      ! non-root processes do not stop a stopwatch they never started.
      if (timing_level > 1 .and. on_root) call io_stopwatch('wann: main: u_and_m', 2)

      return

    end subroutine internal_new_u_and_m

!~    !========================================!
!~    subroutine internal_check_unitarity()
!~    !========================================!
!~
!~      implicit none
!~
!~      integer :: nkp,i,j,m
!~      complex(kind=dp) :: ctmp1,ctmp2
!~
!~      if (timing_level>1) call io_stopwatch('wann: main: check_unitarity',1)
!~
!~      do nkp = 1, num_kpts
!~         do i = 1, num_wann
!~            do j = 1, num_wann
!~               ctmp1 = cmplx_0
!~               ctmp2 = cmplx_0
!~               do m = 1, num_wann
!~                  ctmp1 = ctmp1 + u_matrix (i, m, nkp) * conjg (u_matrix (j, m, nkp) )
!~                  ctmp2 = ctmp2 + u_matrix (m, j, nkp) * conjg (u_matrix (m, i, nkp) )
!~               enddo
!~               if ( (i.eq.j) .and. (abs (ctmp1 - cmplx_1 ) .gt. eps5) ) &
!~                    then
!~                  write ( stdout , * ) ' ERROR: unitariety of final U', nkp, i, j, &
!~                       ctmp1
!~                  call io_error('wann_main: unitariety error 1')
!~               endif
!~               if ( (i.eq.j) .and. (abs (ctmp2 - cmplx_1 ) .gt. eps5) ) &
!~                    then
!~                  write ( stdout , * ) ' ERROR: unitariety of final U', nkp, i, j, &
!~                       ctmp2
!~                  call io_error('wann_main: unitariety error 2')
!~               endif
!~               if ( (i.ne.j) .and. (abs (ctmp1) .gt. eps5) ) then
!~                  write ( stdout , * ) ' ERROR: unitariety of final U', nkp, i, j, &
!~                       ctmp1
!~                  call io_error('wann_main: unitariety error 3')
!~               endif
!~               if ( (i.ne.j) .and. (abs (ctmp2) .gt. eps5) ) then
!~                  write ( stdout , * ) ' ERROR: unitariety of final U', nkp, i, j, &
!~                       ctmp2
!~                  call io_error('wann_main: unitariety error 4')
!~               endif
!~            enddo
!~         enddo
!~      enddo
!~
!~      if (timing_level>1) call io_stopwatch('wann: main: check_unitarity',2)
!~
!~      return
!~
!~    end subroutine internal_check_unitarity

!~    !========================================!
!~    subroutine internal_write_r2mn()
!~    !========================================!
!~    !                                        !
!~    ! Write seedname.r2mn file               !
!~    !                                        !
!~    !========================================!
!~      use w90_io, only: seedname,io_file_unit,io_error
!~
!~      implicit none
!~
!~      integer :: r2mnunit,nw1,nw2,nkp,nn
!~      real(kind=dp) :: r2ave_mn,delta
!~
!~      ! note that here I use formulas analogue to Eq. 23, and not to the
!~      ! shift-invariant Eq. 32 .
!~      r2mnunit=io_file_unit()
!~      open(r2mnunit,file=trim(seedname)//'.r2mn',form='formatted',err=158)
!~      do nw1 = 1, num_wann
!~         do nw2 = 1, num_wann
!~            r2ave_mn = 0.0_dp
!~            delta = 0.0_dp
!~            if (nw1.eq.nw2) delta = 1.0_dp
!~            do nkp = 1, num_kpts
!~               do nn = 1, nntot
!~                  r2ave_mn = r2ave_mn + wb(nn) * &
!~                       ! [GP-begin, Apr13, 2012: corrected sign inside "real"]
!~                       ( 2.0_dp * delta - real(m_matrix(nw1,nw2,nn,nkp) + &
!~                       conjg(m_matrix(nw2,nw1,nn,nkp)),kind=dp) )
!~                       ! [GP-end]
!~               enddo
!~            enddo
!~            r2ave_mn = r2ave_mn / real(num_kpts,dp)
!~            write (r2mnunit, '(2i6,f20.12)') nw1, nw2, r2ave_mn
!~         enddo
!~      enddo
!~      close(r2mnunit)
!~
!~      return
!~
!~158   call io_error('Error opening file '//trim(seedname)//'.r2mn in wann_main')
!~
!~    end subroutine internal_write_r2mn

!~    !========================================!
!~    subroutine internal_svd_omega_i()
!~    !========================================!
!~
!~      implicit none
!~
!~      complex(kind=dp), allocatable  :: cv1(:,:),cv2(:,:)
!~      complex(kind=dp), allocatable  :: cw1(:),cw2(:)
!~      complex(kind=dp), allocatable  :: cpad1 (:)
!~      real(kind=dp),    allocatable  :: singvd (:)
!~
!~      integer :: nkp,nn,nb,na,ind
!~      real(kind=dp) :: omt1,omt2,omt3
!~
!~      if (timing_level>1) call io_stopwatch('wann: main: svd_omega_i',1)
!~
!~      allocate( cw1 (10 * num_wann),stat=ierr  )
!~      if (ierr/=0) call io_error('Error in allocating cw1 in wann_main')
!~      allocate( cw2 (10 * num_wann),stat=ierr  )
!~      if (ierr/=0) call io_error('Error in allocating cw2 in wann_main')
!~      allocate( cv1 (num_wann, num_wann),stat=ierr  )
!~      if (ierr/=0) call io_error('Error in allocating cv1 in wann_main')
!~      allocate( cv2 (num_wann, num_wann),stat=ierr  )
!~      if (ierr/=0) call io_error('Error in allocating cv2 in wann_main')
!~      allocate( singvd (num_wann),stat=ierr  )
!~      if (ierr/=0) call io_error('Error in allocating singvd in wann_main')
!~      allocate( cpad1 (num_wann * num_wann),stat=ierr  )
!~      if (ierr/=0) call io_error('Error in allocating cpad1 in wann_main')
!~
!~      cw1=cmplx_0; cw2=cmplx_0; cv1=cmplx_0; cv2=cmplx_0; cpad1=cmplx_0
!~      singvd=0.0_dp
!~
!~      ! singular value decomposition
!~      omt1 = 0.0_dp ; omt2 = 0.0_dp ; omt3 = 0.0_dp
!~      do nkp = 1, num_kpts
!~         do nn = 1, nntot
!~            ind = 1
!~            do nb = 1, num_wann
!~               do na = 1, num_wann
!~                  cpad1 (ind) = m_matrix (na, nb, nn, nkp)
!~                  ind = ind+1
!~               enddo
!~            enddo
!~            call zgesvd ('A', 'A', num_wann, num_wann, cpad1, num_wann, singvd, cv1, &
!~                 num_wann, cv2, num_wann, cw1, 10 * num_wann, cw2, info)
!~            if (info.ne.0) then
!~               call io_error('ERROR: Singular value decomp. zgesvd failed')
!~            endif
!~
!~            do nb = 1, num_wann
!~               omt1 = omt1 + wb(nn) * (1.0_dp - singvd (nb) **2)
!~               omt2 = omt2 - wb(nn) * (2.0_dp * log (singvd (nb) ) )
!~               omt3 = omt3 + wb(nn) * (acos (singvd (nb) ) **2)
!~            enddo
!~         enddo
!~      enddo
!~      omt1 = omt1 / real(num_kpts,dp)
!~      omt2 = omt2 / real(num_kpts,dp)
!~      omt3 = omt3 / real(num_kpts,dp)
!~      write ( stdout , * ) ' '
!~      write(stdout,'(2x,a,f15.9,1x,a)') 'Omega Invariant:   1-s^2 = ',&
!~           omt1*lenconfac**2,'('//trim(length_unit)//'^2)'
!~      write(stdout,'(2x,a,f15.9,1x,a)') '                 -2log s = ',&
!~           omt2*lenconfac**2,'('//trim(length_unit)//'^2)'
!~      write(stdout,'(2x,a,f15.9,1x,a)') '                  acos^2 = ',&
!~           omt3*lenconfac**2,'('//trim(length_unit)//'^2)'
!~
!~      deallocate(cpad1,stat=ierr)
!~      if (ierr/=0) call io_error('Error in deallocating cpad1 in wann_main')
!~      deallocate(singvd,stat=ierr)
!~      if (ierr/=0) call io_error('Error in deallocating singvd in wann_main')
!~      deallocate(cv2,stat=ierr)
!~      if (ierr/=0) call io_error('Error in deallocating cv2 in wann_main')
!~      deallocate(cv1,stat=ierr)
!~      if (ierr/=0) call io_error('Error in deallocating cv1 in wann_main')
!~      deallocate(cw2,stat=ierr)
!~      if (ierr/=0) call io_error('Error in deallocating cw2 in wann_main')
!~      deallocate(cw1,stat=ierr)
!~      if (ierr/=0) call io_error('Error in deallocating cw1 in wann_main')
!~
!~      if (timing_level>1) call io_stopwatch('wann: main: svd_omega_i',2)
!~
!~      return
!~
!~    end subroutine internal_svd_omega_i

  end subroutine wann_main