// -*- C++ -*-
//
// Copyright (C) 1998, 1999, 2000, 2002  Los Alamos National Laboratory,
// Copyright (C) 1998, 1999, 2000, 2002  CodeSourcery, LLC
//
// This file is part of FreePOOMA.
//
// FreePOOMA is free software; you can redistribute it and/or modify it
// under the terms of the Expat license.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the Expat
// license for more details.
//
// You should have received a copy of the Expat license along with
// FreePOOMA; see the file LICENSE.
//
//-----------------------------------------------------------------------------
// Class declaration of a Conjugate-Gradient Poisson solver.
// Solves -Laplace(x)=f,
//     f(z) = 1, 1/3<z1<2/3, 1/3<z2<2/3
//     f(z) = 0 otherwise
//     x = 0 on z1=0,z1=1,z2=0,z2=1
//-----------------------------------------------------------------------------

#ifndef POOMA_BENCHMARKS_SOLVERS_KRYLOV_CGAINCPPTRAN_H
#define POOMA_BENCHMARKS_SOLVERS_KRYLOV_CGAINCPPTRAN_H

// include files

#include "Pooma/Arrays.h"
#include "Utilities/Benchmark.h"

//-----------------------------------------------------------------------------
// CgAInCppTran is a CppTran implementation of the Conjugate gradient solver
//-----------------------------------------------------------------------------

class CgAInCppTran : public Implementation {
public:

  typedef Array<2, double, Brick> Array2D;

  //---------------------------------------------------------------------------
  // We are a CppTran implementation.

  const char* type() const { return CppTranType(); }

  //---------------------------------------------------------------------------
  // We need to initialize the problem for a specific size.

  void initialize(int n) {
    Interval<1> N(1, n);
    Interval<2> newDomain(N, N);

    // reset size of Arrays
    f_m.initialize(newDomain);
    x_m.initialize(newDomain);
    d_m.initialize(newDomain);
    q_m.initialize(newDomain);
    r_m.initialize(newDomain);

    // save problem size    
    n_m = n;
  }

  //---------------------------------------------------------------------------
  // Runs the benchmark.

  void run() {
    int i, j;
    
    setInitialConditions();
    
    int h2i = (n_m-1)*(n_m-1);  // dx^-2
    
    double resid,r2,r2new;
    double alpha, beta;

    double normb = dot(f_m,f_m);

    // r = f-Lx
    for (j = 2; j <= n_m - 1; j++) {
      for (i = 2; i <= n_m - 1; i++) {
	r_m(i,j) = f_m(i,j) - h2i *
                   (4*x_m(i,j)-x_m(i-1,j)-x_m(i+1,j)-x_m(i,j-1)-x_m(i,j+1));
	d_m(i,j) = r_m(i,j);
	q_m(i,j) = 0.0;
      }
    }

    if (normb == 0.0) 
      normb = 1;
  
    r2 = dot(r_m,r_m);
    resid = r2 / normb;

    double tol=1.0E-10;
    int max_iter=10000;

    int ii;

    for (ii = 1; (ii <= max_iter)&&(resid>tol); ii++) {
      for (j = 2; j <= n_m - 1; j++) {
	for (i = 2; i <= n_m - 1; i++) {
	  q_m(i,j) = h2i *
                     (4*d_m(i,j)-d_m(i-1,j)-d_m(i+1,j)-d_m(i,j-1)-d_m(i,j+1));
	}
      }

      alpha = r2 / dot(d_m, q_m);
      for (j = 1; j <= n_m; j++) {
	for (i = 1; i <= n_m; i++) {
	  x_m(i,j) += alpha * d_m(i,j);
	  r_m(i,j) -= alpha * q_m(i,j);
	}
      }

      r2new = dot(r_m, r_m);
      beta = r2new/r2;
      for (j = 1; j <= n_m; j++) {
	for (i = 1; i <= n_m; i++) {
	  d_m(i,j) = r_m(i,j) + beta * d_m(i,j);
	}
      }

      r2 = r2new;
      resid = r2 / normb;

    }

    // save result for checking
    check_m = x_m(n_m/2, n_m/2);
    // save number of iterations
    iters_m = ii;
    // tally up the flops
    flops_m = 11 * ((double)n_m - 2) * ((double)n_m - 2) +
              iters_m * (10 * ((double)n_m - 2) * ((double)n_m - 2) 
	      + 6 * (double)n_m * (double)n_m);
  }    

  //---------------------------------------------------------------------------
  // Just run the setup

  void runSetup() {
    setInitialConditions();
  }

  //---------------------------------------------------------------------------
  // Prints out the check value for this case.

  double resultCheck() const { return check_m; }
  
  //---------------------------------------------------------------------------
  // Returns the number of flops.

  double opCount() const { return flops_m; }

private:

  double dot(Array2D a, Array2D b) {
    int i, j;
    double result;

    result = 0.0;
    for (j = 2; j <= n_m-1; j++) {
      for (i = 2; i <= n_m-1; i++) {
        result += a(i,j) * b(i,j);
      }
    }
    return result;
  }

  //---------------------------------------------------------------------------
  // Initialize our arrays.

  void setInitialConditions() {
    int i, j;

    for (j = 1; j <= n_m; j++)
    {
      for (i = 1; i <= n_m; i++)
      {
        if ((i>=(2+n_m)/3) && (i<=(2*n_m+1)/3) &&
	    (j>=(2+n_m)/3) && (j<=(2*n_m+1)/3))
	{
          f_m(i,j) = 1.0;
        }
        else {
          f_m(i,j) = 0.0;
        }
	x_m(i,j) = 0.0;
	d_m(i,j) = 0.0;
	q_m(i,j) = 0.0;
	r_m(i,j) = 0.0;
      }
    }
  }
  
  //---------------------------------------------------------------------------
  // Arrays
  
  Array2D f_m, x_m, d_m, q_m, r_m;
  
  //---------------------------------------------------------------------------
  // Problem size.
  
  int n_m;
  
  //---------------------------------------------------------------------------
  // Check value.
  
  double check_m;

  //---------------------------------------------------------------------------
  // Iteration count.
  
  int iters_m;

  //---------------------------------------------------------------------------
  // Flop count.
  
  double flops_m;
};

#endif // POOMA_BENCHMARKS_SOLVERS_KRYLOV_CGAINCPPTRAN_H

// ACL:rcsinfo
// ----------------------------------------------------------------------
// $RCSfile: CGAInCppTran.h,v $   $Author: richard $
// $Revision: 1.17 $   $Date: 2004/11/01 18:15:18 $
// ----------------------------------------------------------------------
// ACL:rcsinfo
