// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
// Mobius Forensic Toolkit
// Copyright (C) 2008,2009,2010,2011,2012,2013,2014,2015,2016,2017,2018 Eduardo Aguiar
//
// This program is free software; you can redistribute it and/or modify it
// under the terms of the GNU General Public License as published by the
// Free Software Foundation; either version 2, or (at your option) any later
// version.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
// Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
#include "connection_pool.h"
#include "connection.h"
#include "database.h"
#include <mobius/exception.inc>
#include <chrono>
#include <mutex>
#include <stdexcept>
#include <thread>
#include <unordered_map>

namespace mobius
{
namespace database
{
namespace
{
// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
//! \brief main thread ID
//! \details Captured during dynamic initialization of this translation unit.
//!          NOTE(review): assumes static initialization runs on the process
//!          main thread (not true if the library is loaded from a worker
//!          thread, e.g. via dlopen) — confirm.
// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
// Anonymous namespace already gives internal linkage, so no 'static' needed;
// the value is only ever read, so it is const.
const std::thread::id main_thread_id = std::this_thread::get_id ();
} // namespace

// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
//! \brief implementation data structure for connection_pool
// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
struct connection_pool::impl
{
  //! \brief path of the database file opened for each connection
  std::string path;

  //! \brief configured connection limit
  //! NOTE(review): not read by any function visible in this file — confirm
  //! whether it is meant to be enforced.
  unsigned int max {32};

  //! \brief guards every member of this structure
  std::mutex mutex;

  //! \brief database object owned by each thread holding a connection
  std::unordered_map <std::thread::id, database> pool;

  //! \brief implicit connection created for the main thread on first use
  //! (declared after 'pool', so it is destroyed before 'pool')
  mobius::database::connection main_thread_connection;
};

// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
//! \brief default constructor
// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
connection_pool::connection_pool ()
  : impl_ (std::make_shared <impl> ())
{
}

// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
//! \brief create connection_pool object
//! \param path database file path
//! \param max maximum number of connections opened
// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
connection_pool::connection_pool (
  const std::string& path,
  unsigned int max)
  : impl_ (std::make_shared <impl> ())
{
  impl_->path = path;
  impl_->max = max;
}

// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
//! \brief set database path
//! \param path database file path
// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
void
connection_pool::set_path (const std::string& path)
{
  // path is read under this mutex by acquire () and get_database (), so it
  // must also be read and written under it here to avoid a data race
  std::lock_guard <std::mutex> lock (impl_->mutex);

  // path may be set only once
  if (!impl_->path.empty ())
    throw std::runtime_error (MOBIUS_EXCEPTION_MSG ("Database path cannot be changed"));

  impl_->path = path;
}

// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
//! \brief acquire a connection
//! \return a connection for this thread
// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
connection
connection_pool::acquire ()
{
  const auto thread_id = std::this_thread::get_id ();
  std::lock_guard <std::mutex> lock (impl_->mutex);

  // each thread may hold at most one connection from this pool
  if (impl_->pool.find (thread_id) != impl_->pool.end ())
    throw std::runtime_error (MOBIUS_EXCEPTION_MSG ("Connection has already been acquired"));

  // Open the database first. If this throws, no connection object has been
  // created yet (so none is destroyed while the mutex is held) and no entry
  // has been inserted into the pool (pre-C++17, 'pool[id] = database (...)'
  // could default-insert the entry before the open failed).
  database db (impl_->path);

  // NOTE(review): impl_->max is not checked here — confirm whether the
  // connection limit is meant to be enforced.
  connection conn (*this);
  impl_->pool.emplace (thread_id, db);

  return conn;
}

// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
//! \brief get database
//! \return database object
// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
database
connection_pool::get_database () const
{
  const auto thread_id = std::this_thread::get_id ();
  std::lock_guard <std::mutex> lock (impl_->mutex);

  auto iter = impl_->pool.find (thread_id);

  // The main thread gets an implicit connection on first use, but only when
  // it has not acquired one explicitly. Otherwise its database would be
  // silently replaced by a newly opened one.
  if (iter == impl_->pool.end () &&
      thread_id == main_thread_id &&
      !impl_->main_thread_connection)
    {
      // Open the database before recording the implicit connection. If the
      // open throws, main_thread_connection stays empty and the next call
      // can retry.
      database db (impl_->path);

      // NOTE(review): main_thread_connection lives inside impl_ and is built
      // from *this; if connection keeps a copy of the pool (and therefore of
      // impl_), this forms a shared_ptr cycle — confirm.
      impl_->main_thread_connection = connection (const_cast <connection_pool&> (*this));
      iter = impl_->pool.emplace (thread_id, db).first;
    }

  if (iter == impl_->pool.end ())
    throw std::runtime_error (MOBIUS_EXCEPTION_MSG ("No acquired connection found"));

  // return a copy of the calling thread's database object
  return iter->second;
}

// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
//! \brief release connection
// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
void
connection_pool::release ()
{
  // drop the calling thread's database entry, if any (no-op otherwise)
  std::lock_guard <std::mutex> guard (impl_->mutex);
  impl_->pool.erase (std::this_thread::get_id ());
}

} // namespace database
} // namespace mobius
