# -*- coding: utf-8 -*-

#  SELF Platform: A distributed web application for collaborative
#  production of learning materials employing open standards.

#  Copyright (C) 2007, 2008 Free Software Foundation Europe e.V.

#  This file is part of the SELF Project, a project administered by the
#  SELF Consortium, for which FSFE acts as copyright holder.

#  The SELF Consortium are:
#    Internet Society Nederland
#    Universitat Oberta de Catalunya
#    Free Software Foundation Europe
#    University of Gothenburg
#    Internet Society Bulgaria
#    Fundacion Via Libre
#    Homi Bhabha Centre for Science Education

#  A complete list of authors can be found in the file AUTHORS.

#  This program is Free Software; you can redistribute it and/or
#  modify it under the terms of the GNU General Public License as
#  published by the Free Software Foundation; either version 2 of the
#  License, or (at your option) any later version.

#  This program is distributed in the hope that it will be useful, but
#  WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
#  General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program; if not, write to the Free Software
#  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
#  02110-1301, USA.

#  The licensor of SELF Platform is the Free Software Foundation
#  Europe (FSFE), Fiduciary Program, Sumatrastrasse 25, 8006 Zurich,
#  Switzerland, email:ftf@fsfeurope.org.

__docformat__ = 'plaintext'

import sys
import os
import psycopg2
import psycopg2.extensions
import psycopg2.extras
from storageSpec import *
from pgtable import *
from datatypes import *


class tblBase:
    """
    Common base for the gb* table access classes.

    Holds the database cursor and a debug flag, and provides helpers for
    primary-key retrieval and for rendering PostgreSQL literals.
    """

    # Sentinel so debug_print can distinguish a one-argument call from an
    # explicit val=None.
    _NO_VALUE = object()

    # Types treated as text when building SQL literals ( unicode only exists
    # on Python 2 ).
    try:
        _string_types = ( str, unicode )
    except NameError:
        _string_types = ( str, )

    def __init__( self, cur, debug_mode ):
        self.cur = cur
        self.debug_mode = debug_mode

    def debug_print( self, clsname, val=_NO_VALUE ):
        """
        Print "Debug[clsname]: val" when debug_mode is non-zero.

        May also be called as debug_print( val ); the label then defaults to
        the name of the instance's class.
        """
        if val is tblBase._NO_VALUE:
            clsname, val = self.__class__.__name__, clsname
        if self.debug_mode != 0:
            print( "Debug[%s]: %s " % ( clsname, val ) )

    def set_intersect( self, set1, set2 ):
        """Return the number of elements common to set1 and set2."""
        return( len( set1.intersection( set2 ) ) )

    def get_pkey( self, tblname, pkname ):
        """
        Return currval() of the serial sequence behind tblname.pkname, i.e.
        the key most recently generated for it in this session.
        """
        # Both names are passed as text arguments to pg_get_serial_sequence,
        # so they can be bound as parameters instead of interpolated.
        self.cur.execute( "SELECT currval( pg_get_serial_sequence( %s, %s ) );",
                          ( tblname, pkname ) )
        return self.cur.fetchone()[0]

    def _sql_array_body( self, seq ):
        """
        Render seq as the bracketed body of a PostgreSQL ARRAY constructor.

        Strings are single-quoted with embedded quotes doubled; nested lists
        become nested brackets; other values use repr().
        """
        parts = []
        for elem in seq:
            if isinstance( elem, list ):
                parts.append( self._sql_array_body( elem ) )
            elif isinstance( elem, self._string_types ):
                parts.append( "'%s'" % elem.replace( "'", "''" ) )
            else:
                parts.append( repr( elem ) )
        return "[%s]" % ", ".join( parts )

    def pg_sqlize_value( self, fieldDef, value ):
        """
        Render value as a SQL literal cast to the type in fieldDef[1].

        Requires the fielddef to be from storageSpec class!!
        Lists become ARRAY[...]; scalars become '...' with single quotes
        doubled.
        """
        name = self.__class__.__name__
        # Wrap in 1-tuples: fieldDef / value may themselves be tuples.
        self.debug_print( name, "Fields Def: %s" % ( fieldDef, ) )
        self.debug_print( name, "Field val : %s" % ( value, ) )

        if isinstance( value, list ):
            selValue = "ARRAY%s" % self._sql_array_body( value )
        else:
            if isinstance( value, self._string_types ):
                value = value.replace( "'", "''" )
            selValue = "'%s'" % ( value, )
        return "%s::%s" % ( selValue, fieldDef[1] )

    def does_value_exist( self, tblname, pkcolname, colname, value ):
        """
        Return the first pkcolname of tblname whose colname equals value,
        or 0 when there is none.

        value is inserted verbatim, so it must already be a SQL literal
        ( e.g. the output of pg_sqlize_value ).
        """
        name = self.__class__.__name__
        query = "SELECT %s FROM %s WHERE %s=%s GROUP BY %s;" % ( pkcolname, tblname, colname, value, pkcolname )
        self.debug_print( name, "Query: %s" % query )
        self.cur.execute( query )
        res = self.cur.fetchall()
        self.debug_print( name, "Result: %s" % ( res, ) )

        if not res:
            return 0
        return res[0][0]

    def set_in_ft( self, tblname, valuetype, values, dictFieldDefs ):
        """
        Intended to insert values of valuetype into the field table tblname
        ( valuetype "" -> normal value, "ARRAY" -> array value ).

        TODO: not implemented -- the INSERT is only built and logged, never
        executed, and the method returns None.
        """
        print( "Got values ( %s, %s )" % ( tblname, values ) )
        if isinstance( values, list ):
            query = "INSERT INTO %s VALUES( DEFAULT, ARRAY%s );" % ( tblname, values )
        else:
            query = "INSERT INTO %s VALUES( DEFAULT, '%s' );" % ( tblname, values )
        self.debug_print( self.__class__.__name__, "Query ( not executed ): %s" % query )
        return None

    def get_from_ft( self, tblname, fidfld, fid ):
        """
        Return all rows of tblname where fidfld == fid, ordered by fidfld
        descending, or 0 if the query fails.
        """
        # Identifiers cannot be bound; the fid value is passed as a parameter.
        query = "SELECT * FROM %s WHERE %s=%%s ORDER BY %s DESC;" % ( tblname, fidfld, fidfld )

        try:
            self.cur.execute( query, ( fid, ) )
            return self.cur.fetchall()
        except Exception as err:
            print( "Error:  %s" % err )
            return 0


class tbl_nodetype( tblBase ):
    """
    Access class for node types ( gbnodetypes, view_nidinidssid ).
    """
    def __init__( self, cursor ):
        # tblBase stores the cursor as self.cur; debug output is disabled.
        tblBase.__init__( self, cursor, 0 )

    def _first_value( self, query, params ):
        """Run query with bound params; return column 0 of the first row, or 0."""
        self.cur.execute( query, params )
        row = self.cur.fetchone()
        if row is None:
            return 0
        return row[0]

    def getntid( self, nodetype ):
        """Return the ntid for the node type name nodetype, or 0 if unknown."""
        return self._first_value( "SELECT ntid FROM gbnodetypes WHERE nodename=%s;",
                                  ( nodetype, ) )

    def getNodeTypeFromSSID( self, ssid ):
        """Return the node type name owning ssid, or 0 if not found."""
        return self._first_value( "SELECT nodename FROM view_nidinidssid WHERE ssid=%s;",
                                  ( ssid, ) )

    def addNodeType( self, nodetype ):
        # TODO:not yet implemented
        pass

class tbl_datatypes( tblBase ):
    """
    Access class for the gbdatatypes table ( datatype name <-> datatypeid ).
    """
    def __init__( self, cursor ):
        # tblBase stores the cursor as self.cur; debug output is disabled.
        tblBase.__init__( self, cursor, 0 )

    def _first_value( self, query, params ):
        """Run query with bound params; return column 0 of the first row, or 0."""
        self.cur.execute( query, params )
        row = self.cur.fetchone()
        if row is None:
            return 0
        return row[0]

    def getdtnamefromid( self, dtid ):
        """Return the datatype name for datatypeid dtid, or 0 if unknown."""
        return self._first_value( "SELECT datatypename FROM gbdatatypes WHERE datatypeid=%s;",
                                  ( dtid, ) )

    def getdtid( self, datatype ):
        """Return the datatypeid for the type name datatype ( e.g. 'varchar[]' ), or 0."""
        return self._first_value( "SELECT datatypeid FROM gbdatatypes WHERE datatypename=%s;",
                                  ( datatype, ) )

    def addDataType( self, datatype ):
        # TODO:not yet implemented
        pass


class tbl_nidinid( tblBase ):
    """
    Access class for the gbnidinid table, which assigns an inid to each
    ( nid, nodetype ) pair. The nid text lives in datatypes_varchar and
    gbnidinid.nid holds its vid.
    """
    def __init__( self, cursor ):
        # tblBase stores the cursor as self.cur; debug output is disabled.
        tblBase.__init__( self, cursor, 0 )

    def setval( self, nid, nodetype ):
        """
        inserts ( nid, nodetype ) in the gbnidinid table

        Returns the generated inid, or 0 when the pair already exists or the
        INSERT fails.
        """
        # Store ( or find ) the nid text and resolve the node type id.
        vid = tbl_values( self.cur, '', 'varchar' ).insert( nid )
        ntid = tbl_nodetype( self.cur ).getntid( nodetype )

        # It does exist! Sorry, no duplicates allowed!
        if self.isvalid_nid( nid, nodetype ) != 0:
            print( "Sorry, ( %s, %s ) already exists!" % ( nid, nodetype ) )
            return 0

        try:
            self.cur.execute( "INSERT INTO gbnidinid ( nid, inid, ntid ) VALUES ( %s, DEFAULT, %s )",
                              ( vid, ntid ) )
            return self.get_pkey( 'gbnidinid', 'inid' )
        except Exception as err:
            print( "Error:  %s" % err )
            return 0

    def isvalid_nid( self, gbnid, gbnt ):
        """
        Return how many gbnidinid rows match nid text gbnid and node type
        name gbnt ( 0 means the pair is not registered ).
        """
        # Bound parameters: nids may contain quotes ( tbl_values stores them
        # escaped, so the lookup must be escaped the same way ).
        query = ( "SELECT count( gbnidinid.nid ) FROM gbnidinid, datatypes_varchar, gbnodetypes "
                  "WHERE gbnidinid.nid = datatypes_varchar.vid AND gbnidinid.ntid = gbnodetypes.ntid "
                  "AND datatypes_varchar.value = %s AND gbnodetypes.nodename = %s;" )

        self.debug_print( "tbl_nidinid", "QUERY : %s PARAMS : %s" % ( query, ( gbnid, gbnt ) ) )
        self.cur.execute( query, ( gbnid, gbnt ) )
        return self.cur.fetchone()[0]

    def getinid( self, nid, nodetype ):
        """
        get the inid, given the ( nid, nodetype ); 0 if not found
        """
        query = ( "SELECT gbnidinid.inid FROM gbnidinid, datatypes_varchar, gbnodetypes "
                  "WHERE gbnidinid.nid = datatypes_varchar.vid AND gbnidinid.ntid = gbnodetypes.ntid "
                  "AND datatypes_varchar.value = %s AND gbnodetypes.nodename = %s;" )

        self.cur.execute( query, ( nid, nodetype ) )
        row = self.cur.fetchone()
        if row is None:
            return 0
        return row[0]



class tbl_inidssid( tblBase ):
    """
    Access class for gbinidssid table
    """
    def __init__( self, cursor ):
        """
        Wrap cursor for gbinidssid access; debug output is disabled.
        """
        tblBase.__init__( self, cursor, 0 )

    def _insert_ssid( self, inid, ntid ):
        """Insert a gbinidssid row for ( inid, ntid ); return the generated ssid."""
        query = "INSERT INTO gbinidssid ( inid, ssid, ntid ) VALUES( %s, DEFAULT, %s );"
        self.debug_print( "tbl_inidssid", "QUERY : %s PARAMS : %s" % ( query, ( inid, ntid ) ) )
        self.cur.execute( query, ( inid, ntid ) )
        return self.get_pkey( 'gbinidssid', 'ssid' )

    def setval_using_nid( self, nid, nodetype ):
        """
        get a new ssid given the ( nid, nodetype )

        Returns 0 when ( nid, nodetype ) has no inid.
        """
        inid = tbl_nidinid( self.cur ).getinid( nid, nodetype )
        if inid == 0:
            return 0

        ntid = tbl_nodetype( self.cur ).getntid( nodetype )
        return self._insert_ssid( inid, ntid )

    def setval_using_inid( self, inid, nodetype ):
        """
        get a new ssid given an existing ( inid, nodetype )
        """
        ntid = tbl_nodetype( self.cur ).getntid( nodetype )
        return self._insert_ssid( inid, ntid )
    
#     def get_lastssid_from_nid( self, nid, nodetype ):
#         """
#         get the inid, given the ( nid, nodetype )
#         """
#         t = tbl_nidinid( self.cur )
#         inid = t.getinid( nid, nodetype )
        
#         if inid != 0:
#             query = "SELECT MAX( ssid ) FROM gbinidssid WHERE inid='%s' AND nodetype='%s';" % ( inid, nodetype )
#             self.cur.execute( query )
#             rs = self.cur.fetchall()

#             print rs

#             if len( rs ) == 0:
#                 return 0
#             else:
#                 return( rs[0][0] )
        
#     def get_lastssid_from_inid( self, inid, nodetype ):
#         """
#         get the inid, given the ( inid, nodetype )
#         """
#         query = "SELECT MAX( ssid ) FROM gbinidssid WHERE inid='%s' AND nodetype='%s';" % ( inid, nodetype )
#         self.cur.execute( query )
#         rs = self.cur.fetchall()

#         print rs

#         if len( rs ) == 0:
#             return 0
#         else:
#             return( rs[0][0] )

#     def get_allssid_from_nid( self, nid, nodetype ):
#         """
#         get all ssids given the ( nid, nodetype ) order descending
#         """
#         t = tbl_nidinid( self.cur )
#         inid = t.getinid( nid, nodetype )
        
#         if inid != 0:
#             query = "SELECT ssid FROM gbinidssid WHERE inid='%s' AND nodetype='%s' ORDER BY ssid DESC;" % ( inid, nodetype )
#             self.cur.execute( query )
#             rs = self.cur.fetchall()
            
#             print rs

#             if len( rs ) == 0:
#                 return 0
#             else:
#                 return( rs )
            
#     def get_allssid_from_inid( self, inid, nodetype ):
#         """
#         get all ssids given the ( nid, nodetype ) order descending
#         """
#         query = "SELECT ssid FROM gbinidssid WHERE inid='%s' AND nodetype='%s' ORDER BY ssid DESC;" % ( inid, nodetype )
#         self.cur.execute( query )
#         rs = self.cur.fetchall()

#         print rs

#         if len( rs ) == 0:
#             return 0
#         else:
#             return( rs )


def test_nid_table( cur ):
    # Manual smoke test for tbl_nidinid: register a nid, validate a known and
    # an unknown nid, then look the inid back up.
    nodetype = "gbobjecttypes"
    nid = "fan"
    table = tbl_nidinid( cur )

    print( "Inserting nid = %s in nodetype = %s" % ( nid, nodetype ) )
    print( "Generated inid = %s" % table.setval( nid, nodetype ) )

    for candidate in ( nid, 'asdkjsd' ):
        print( "testing if nid = %s is valid -> result: %s" % ( candidate, table.isvalid_nid( candidate, nodetype ) ) )

    print( "getting inid for nid = %s, nodetype = %s -> result: %s" % ( nid, nodetype, table.getinid( nid, nodetype ) ) )

def test_inid_table( cur ):
    # Manual smoke test for tbl_inidssid: create an ssid from a nid, then
    # another one from the inid looked up via tbl_nidinid.
    nodetype = "gbobjecttypes"
    nid = "fan"
    ssid_table = tbl_inidssid( cur )

    print( "Inserting nid = %s in nodetype = %s" % ( nid, nodetype ) )
    print( "Generated ssid = %s" % ssid_table.setval_using_nid( nid, nodetype ) )

    inid = tbl_nidinid( cur ).getinid( nid, nodetype )
    print( "Got inid = %s" % inid )

    print( "Generated ssid = %s" % ssid_table.setval_using_inid( inid, nodetype ) )


def test_field_table( cur ):
    # Manual smoke test for tbl_field: store a varchar[] title for gbmetatypes.
    title_def = storageSpec().dictTNamesFDefs['gbmetatypes']['title']
    ntid = tbl_nodetype( cur ).getntid( 'gbmetatypes' )
    dtid = tbl_datatypes( cur ).getdtid( 'varchar[]' )

    field = tbl_field( cur, title_def )
    row = { 'ntid': ntid, 'datatypeid': dtid, 'value': [ 'one', 'ONE' ] }
    print( "Fid generated was %s" % field.insert( row ) )

def test_value_table( cur ):
    # Manual smoke test for tbl_values with text containing single quotes.
    values = tbl_values( cur, '', 'varchar' )
    text = "'Saving Recovery Information Automatically Every 'n' Minutes'"
    print( "vid for %s is %s" % ( text, values.insert( text ) ) )

#     vt = tbl_values( cur, '', 'int8' )

#     num = 23
#     vid = vt.insert( num )
#     print "vid for %s is %s" % ( num, vid )

#     print "inserting %s again" % num
#     vid = vt.insert( num )
#     print "vid for %s is %s" % ( num, vid )
    

def test_retr_ntid( cur ):
    # Print the ntid registered for each known node type name.
    node_names = (
        "gbobjects", "gbrelations", "gbattributes",
        "gbmetatypes", "gbobjecttypes", "gbrelationtypes",
        "gbattributetypes", "gbusertypes", "gbusers",
    )
    lookup = tbl_nodetype( cur )
    for name in node_names:
        print( "ntid for %s is %s" % ( name, lookup.getntid( name ) ) )


def test_retr_dtid( cur ):
    # Print the datatypeid registered for each PostgreSQL type name; each
    # line lists a scalar type followed by its array form when one exists.
    type_names = (
        "int8", "int8[]",
        "bit", "bit[]",
        "varbit", "varbit[]",
        "boolean", "boolean[]",
        "box", "box[]",
        "bytea", "bytea[]",
        "varchar", "varchar[]",
        "char", "char[]",
        "cidr", "cidr[]",
        "circle", "circle[]",
        "date", "date[]",
        "float8", "float8[]",
        "inet", "inet[]",
        "int4", "int4[]",
        "interval", "interval[]",
        "line", "line[]",
        "lseg", "lseg[]",
        "macaddr", "macaddr[]",
        "money", "money[]",
        "numeric", "numeric[]",
        "path", "path[]",
        "point", "point[]",
        "polygon", "polygon[]",
        "float4", "float4[]",
        "int2", "int2[]",
        "text", "text[]",
        "time", "time[]",
        "timestamptz", "timestamptz[]",
        "abstime", "abstime[]",
        "aclitem", "aclitem[]",
        "bpchar", "bpchar[]",
        "cid", "cid[]",
        "oid", "oid[]",
        "refcursor", "refcursor[]",
        "regclass", "regclass[]",
        "regoper", "regoper[]",
        "regoperator", "regoperator[]",
        "regproc", "regproc[]",
        "regprocedure", "regprocedure[]",
        "regtype", "regtype[]",
        "reltime", "reltime[]",
        "smgr",
        "tid", "tid[]",
        "timetz", "timetz[]",
        "tinterval", "tinterval[]",
        "unknown",
        "xid", "xid[]",
        "int2vector", "int2vector[]",
        "name", "name[]",
        "oidvector", "oidvector[]",
        "serial", "serial[]",
        "serial8", "serial8[]",
    )

    lookup = tbl_datatypes( cur )
    for type_name in type_names:
        print( "dtid for %s is %s" % ( type_name, lookup.getdtid( type_name ) ) )

# class tbl_values:
#     def __init__( self, cur, tblName, datatype ):
#         tmp = gnowsysDatatypes( cur )
#         self.tblName = tmp.getDataTypeTableName( datatype )

#         self.cur = cur
#         #self.tblName = tblName
#         self.datatype = datatype
#         self.debug_mode = 1

#     def debug_print( self, val ):
#         if self.debug_mode != 0:
#             print "Debug[%s]: %s " % ( self.tblName, val )

#     def process_str_list( self, strList ):
#         print "process_str_list called: %s" % strList
#         if isinstance( strList, list ):
#             i=0
#             tmpList = []
#             for s in strList:
#                 # handle both types of strings, unicode as well as ascii
#                 if isinstance( s, str ) or isinstance( s, unicode ):
#                     tmps = s.replace( "'", "''" )
# #                     print tmps
#                     tmpList.append( tmps )
#                 else:
#                     tmpList.append( s )
#                 i=i+1
#             return self.make_array_into_str( tmpList )
#         else:
#             return strList

#     def make_array_into_str( self, arr ):
#         myStr = "[%s]"
#         insideStr = ", " . join( [ "'%s'" % e for e in arr ] )
#         return myStr % insideStr
        

#     def process_value( self, value ):
#         if isinstance( value, list ):
#             return "ARRAY%s::%s" % ( self.process_str_list( value ), self.datatype )
#         else:
#             # handle both types of strings, unicode as well as ascii
#             if isinstance( value, str ) or isinstance( value, unicode ):
#                 value = value.replace( "'", "''" )

#             return "'%s'::%s" % ( value, self.datatype )

#     def get_pkey( self ):
#         query = "SELECT currval( pg_get_serial_sequence( '%s', 'vid' ) );" % ( self.tblName )
#         self.debug_print( "QUERY: %s" % query )
#         self.cur.execute( query )
#         rs = self.cur.fetchall()
#         return rs[0][0]

#     def exists( self, value ):
#         self.debug_print( "exists called: %s" % value )
#         query = "SELECT vid FROM %s WHERE value=%s" % ( self.tblName, self.process_value( value ) )
#         self.debug_print( "QUERY: %s" % query )
#         self.cur.execute( query )
#         rs = self.cur.fetchall()
        
#         if( len( rs ) > 0 ):
#             return rs[0][0]
#         else:
#             return 0

#     def insert( self, value ):
#         """
#         logic:
#         check if value exists
#         if not, add and return vid
#         else, return vid
#         """
#         vid = self.exists( value )
#         if( vid == 0 ):
#             # value doesn't exist, insert it
#             self.debug_print( "Value %s doesn't exist, inserting" % value )

#             query = "INSERT INTO %s ( vid, value ) VALUES ( DEFAULT, %s )" % ( self.tblName, self.process_value( value ) )
#             self.debug_print( "INSERT QUERY: %s " % query )
#             self.cur.execute( query )
#             vid = self.get_pkey()
#             self.debug_print( "Generated vid is : %s" % vid )
#             # get the just inserted value
#             return vid
#         else:
#             self.debug_print( "Value %s already exists, vid is %s" % ( value, vid ) )
#             return vid

class tbl_values:
    """
    Access class for a per-datatype value table with columns ( vid, value ).

    insert() de-duplicates: it returns the existing vid when the value is
    already stored.
    """

    # Types treated as text and quote-escaped before being embedded in SQL
    # ( unicode only exists on Python 2 ).
    try:
        _string_types = ( str, unicode )
    except NameError:
        _string_types = ( str, )

    def __init__( self, cur, tblName, datatype ):
        # NOTE: tblName is ignored; the table name is derived from datatype.
        tmp = gnowsysDatatypes( cur )
        self.tblName = tmp.getDataTypeTableName( datatype )

        self.cur = cur
        self.datatype = datatype
        self.debug_mode = 1

    def debug_print( self, val ):
        if self.debug_mode != 0:
            print( "Debug[%s]: %s " % ( self.tblName, val ) )

    def _escape_element( self, elem ):
        """Double single quotes in strings, recursing into nested lists."""
        if isinstance( elem, list ):
            return [ self._escape_element( e ) for e in elem ]
        if isinstance( elem, self._string_types ):
            return elem.replace( "'", "''" )
        return elem

    def process_str_list( self, strList ):
        """
        Render a list as a bracketed, quote-escaped array body ( see
        make_array_into_str ); non-list values are returned unchanged.
        """
        if isinstance( strList, list ):
            # Escape at every nesting level, not only the top one.
            return self.make_array_into_str( self._escape_element( strList ) )
        return strList

    def make_array_into_str( self, arr ):
        """
        Render arr as "['a', 'b']"; nested lists become "ARRAY[...]".
        Elements are NOT escaped here -- process_str_list does that.
        """
        processedElements = []
        for e in arr:
            if isinstance( e, list ):
                processedElements.append( "ARRAY%s" % self.make_array_into_str( e ) )
            else:
                processedElements.append( "'%s'" % e )
        return "[%s]" % ", ".join( processedElements )

    def process_value( self, value ):
        """Render value as a SQL literal cast to self.datatype."""
        if isinstance( value, list ):
            return "ARRAY%s::%s" % ( self.process_str_list( value ), self.datatype )
        if isinstance( value, self._string_types ):
            value = value.replace( "'", "''" )
        return "'%s'::%s" % ( value, self.datatype )

    def get_pkey( self ):
        """Return the vid most recently generated for this table in this session."""
        query = "SELECT currval( pg_get_serial_sequence( %s, 'vid' ) );"
        self.debug_print( "QUERY: %s PARAMS: %s" % ( query, self.tblName ) )
        self.cur.execute( query, ( self.tblName, ) )
        return self.cur.fetchone()[0]

    def exists( self, value ):
        """Return the vid holding value, or 0 if it is not stored."""
        query = "SELECT vid FROM %s WHERE value=%s" % ( self.tblName, self.process_value( value ) )
        self.debug_print( "QUERY: %s" % query )
        self.cur.execute( query )
        row = self.cur.fetchone()
        if row is None:
            return 0
        return row[0]

    def insert( self, value ):
        """
        logic:
        check if value exists
        if not, add and return vid
        else, return vid
        """
        vid = self.exists( value )
        if vid != 0:
            self.debug_print( "Value %s already exists, vid is %s" % ( value, vid ) )
            return vid

        # value doesn't exist, insert it
        self.debug_print( "Value %s doesn't exist, inserting" % ( value, ) )
        query = "INSERT INTO %s ( vid, value ) VALUES ( DEFAULT, %s )" % ( self.tblName, self.process_value( value ) )
        self.debug_print( "INSERT QUERY: %s " % query )
        self.cur.execute( query )
        vid = self.get_pkey()
        self.debug_print( "Generated vid is : %s" % vid )
        return vid


class tbl_field:
    """
    Access class for a field table "field_<fieldname>" with columns
    ( fid, ntid, vid, datatypeid ); vid refers to the value table for the
    field's datatype ( see tbl_values ).
    """
    def __init__( self, cur, fldDef ):
        # fldDef[0] is the field name, fldDef[1] its SQL datatype.
        self.tblName = "field_%s" % fldDef[0]
        self.datatype = fldDef[1]
        self.valuetable = tbl_values( cur, '', self.datatype )
        self.cur = cur
        self.debug_mode = 1

    def debug_print( self, val ):
        if self.debug_mode != 0:
            print( "Debug[%s]: %s " % ( self.tblName, val ) )

    def get_pkey( self ):
        """Return the fid most recently generated for this table in this session."""
        self.cur.execute( "SELECT currval( pg_get_serial_sequence( %s, 'fid' ) );",
                          ( self.tblName, ) )
        return self.cur.fetchone()[0]

    def _fid_for_vid( self, vid ):
        """Return the first fid referencing vid, or 0 if none."""
        # NOTE(review): matches on vid only, not ntid, so a value first stored
        # for another node type is shared -- confirm this is intended.
        query = "SELECT fid FROM %s WHERE vid=%%s" % self.tblName
        self.debug_print( "QUERY: %s PARAMS: %s" % ( query, vid ) )
        self.cur.execute( query, ( vid, ) )
        row = self.cur.fetchone()
        if row is None:
            return 0
        return row[0]

    def _insert_row( self, ntid, vid, dtid ):
        """Insert a field row for ( ntid, vid, dtid ) and return its new fid."""
        query = "INSERT INTO %s ( fid, ntid, vid, datatypeid ) VALUES( DEFAULT, %%s, %%s, %%s )" % self.tblName
        self.debug_print( "QUERY: %s PARAMS: %s" % ( query, ( ntid, vid, dtid ) ) )
        self.cur.execute( query, ( ntid, vid, dtid ) )
        fid = self.get_pkey()
        self.debug_print( "Generated fid is: %s" % fid )
        return fid

    def exists( self, value ):
        """Return the fid whose vid stores value, or 0 if there is none."""
        vid = self.valuetable.exists( value )
        if vid == 0:
            return 0
        return self._fid_for_vid( vid )

    def insert( self, dictValues ):
        """
        dictValues { 'ntid': '', 'datatypeid': '', 'value': '' }
        logic:
        check if value exists
        if not, add and return fid
        else, return existing fid
        """
        value = dictValues[ 'value' ]
        ntid = dictValues[ 'ntid' ]
        dtid = dictValues[ 'datatypeid' ]

        vid = self.valuetable.exists( value )
        if vid == 0:
            # value doesn't exist, insert it; no field row can reference a
            # brand-new value, so a new field row is always needed.
            self.debug_print( "Value %s doesn't exist, inserting" % ( value, ) )
            vid = self.valuetable.insert( value )
            self.debug_print( "Generated vid is: %s " % vid )
            return self._insert_row( ntid, vid, dtid )

        # Value already stored: reuse its field row if one exists.
        fid = self._fid_for_vid( vid )
        if fid != 0:
            return fid
        return self._insert_row( ntid, vid, dtid )
                
#             self.debug_print( "Value %s already exists, fid is %s" % ( value, vid ) )
#             return vid



# class tbl_field:
#     def __init__( self, cur, fldDef ):
#         self.tblName = "field_%s" % fldDef[0]
#         self.datatype = fldDef[1]
#         self.valuetable = tbl_values( cur, '', self.datatype )
#         self.cur = cur
#         self.debug_mode = 1

#     def debug_print( self, val ):
#         if self.debug_mode != 0:
#             print "Debug[%s]: %s " % ( self.tblName, val )

#     def get_pkey( self ):
#         query = "SELECT currval( pg_get_serial_sequence( '%s', 'fid' ) );" % ( self.tblName )
#         self.cur.execute( query )
#         rs = self.cur.fetchall()
#         return rs[0][0]

#     def exists( self, value ):
#         vid = self.valuetable.exists( value )
#         if vid != 0:
#             query = "SELECT fid FROM %s WHERE vid=%s" % ( self.tblName, vid )
#             self.debug_print( "QUERY: %s" % query )
#             self.cur.execute( query )
            
#             rs = self.cur.fetchall()
        
#             if( len( rs ) > 0 ):
#                 return rs[0][0]
#             else:
#                 return 0
#         else:
#             return 0


#     def insert( self, dictValues ):
#         """
#         dictValues { 'ntid': '', 'datatypeid': '', 'value': '' }
#         logic:
#         check if value exists
#         if not, add and return fid
#         else, return existing fid
#         """
#         value = dictValues[ 'value' ]

#         vid = self.valuetable.exists( value )
#         ntid = dictValues[ 'ntid' ]
#         dtid = dictValues[ 'datatypeid' ]

#         if( vid == 0 ):
#             # value doesn't exist, insert it
#             self.debug_print( "Value %s doesn't exist, inserting" % value )
#             vid = self.valuetable.insert( value )

#             self.debug_print( "Generated vid is: %s " % vid )
#             query = "INSERT INTO %s ( fid, ntid, vid, datatypeid ) VALUES( DEFAULT, '%s', '%s', '%s' )" % ( self.tblName, ntid, vid, dtid )
#             self.debug_print( "QUERY: %s" % query )
#             self.cur.execute( query )
            
#             # get the just inserted value
#             fid = self.get_pkey()
#             self.debug_print( "Generated fid is: %s" % fid )
#             return fid
#         else:
#             fid = self.exists( value )
#             if fid != 0:
#                 return fid
#             else:
#                 query = "INSERT INTO %s ( fid, ntid, vid, datatypeid ) VALUES( DEFAULT, '%s', '%s', '%s' )" % ( self.tblName, ntid, vid, dtid )
#                 self.debug_print( "QUERY: %s" % query )
#                 self.cur.execute( query )

#                 # get the just inserted value
#                 fid = self.get_pkey()
#                 self.debug_print( "Generated fid is: %s" % fid )
#                 return fid
                
# #             self.debug_print( "Value %s already exists, fid is %s" % ( value, vid ) )
# #             return vid
    
        

class genericTable:
    """
    Accessor for one node table of the SELF storage schema.

    The constructor splits the table's field definitions into "regular"
    fields, whose values are written inline into the row, and "field
    table" fields, whose values are stored through tbl_field and referenced
    from the row by the fid that tbl_field.insert() returns.  The remaining
    methods read rows back from tables/views, keyed by ssid or nid.

    Read queries hand their values to the cursor as bound parameters.
    Identifiers (table, view and column names) cannot be bound and are
    interpolated directly, so they must come from trusted code only.
    """

    # Text types that get SQL-escaped.  Python 2 has both str and unicode;
    # Python 3 has only str.
    try:
        _string_types = ( str, unicode )
    except NameError:
        _string_types = ( str, )

    def __init__( self, cur, tblName, tblDef='', dictTblDef='', debugmode=0 ):
        """
        cur        -- open DB-API cursor used for every query.
        tblName    -- table name; must be a key of
                      storageSpec().dictTableNamesAndDefs.
        tblDef     -- iterable of field definitions; v[0] is the field name
                      and a non-empty v[3] marks a field-table field.
        dictTblDef -- dict { fldname: flddef }; flddef[1] is used as the SQL
                      datatype name.
        debugmode  -- non-zero enables debug_print() output.
        """
        s = storageSpec()

        self.cur = cur
        self.tblName = tblName
        self.debug_mode = debugmode
        # NOTE(review): self.tblDef is looked up in storageSpec, while the
        # tblDef argument is what the loop below classifies -- confirm the
        # two are meant to differ.
        self.tblDef = s.dictTableNamesAndDefs[ tblName ]
        self.dictTblDef = dictTblDef    # dict: { fldname: flddef }

        self.lstRegFlds = []
        self.lstFldTbls = []
        self.lstRegFldNames = []
        self.lstFldTblNames = []
        self.lstFields  = []

        nt = tbl_nodetype( cur )
        self.ntid = nt.getntid( tblName )

        for v in tblDef:
            # Keep every field name, in definition order.
            self.lstFields.append( v[0] )
            if v[3] != "":
                self.lstFldTbls.append( v )
                self.lstFldTblNames.append( v[0] )
                # ( v, ): if v is a tuple it would otherwise be unpacked as
                # format arguments and raise TypeError.
                self.debug_print( "Field table   : %s" % ( v, ) )
            else:
                self.lstRegFlds.append( v )
                self.lstRegFldNames.append( v[0] )
                self.debug_print( "Regular field : %s" % ( v, ) )

    def debug_print( self, val ):
        """Print val, tagged with the table name, when debug mode is on."""
        if self.debug_mode != 0:
            print( "Debug[%s]: %s " % ( self.tblName, val ) )

    def process_string( self, regStr ):
        """
        Wrap a text value in PostgreSQL dollar quotes ( $$...$$ ); return
        any other value unchanged.

        NOTE: a string that itself contains "$$" ends the quoting early.
        """
        if isinstance( regStr, self._string_types ):
            return '$$%s$$' % regStr
        return regStr

    def process_str_list( self, strList ):
        """
        Turn a list into an SQL array body "['a', 'b']" (each element quoted,
        single quotes in text elements doubled).  Non-list values are
        returned unchanged.
        """
        if not isinstance( strList, list ):
            return strList

        escaped = [ self.make_string_safe( s ) for s in strList ]
        self.debug_print( "Escaped list: %s" % ( escaped, ) )
        return self.make_array_into_str( escaped )

    def make_array_into_str( self, arr ):
        """Return "['e1', 'e2', ...]"; elements are quoted but NOT escaped."""
        return "[%s]" % ", ".join( [ "'%s'" % e for e in arr ] )

    def process_value( self, value, datatype ):
        """
        Render value as an SQL literal cast to datatype:
        lists become ARRAY[...]::datatype, everything else 'value'::datatype
        (single quotes in text doubled).
        """
        if isinstance( value, list ):
            return "ARRAY%s::%s" % ( self.process_str_list( value ), datatype )
        return "'%s'::%s" % ( self.make_string_safe( value ), datatype )

    def make_string_safe( self, value ):
        """Double single quotes in text values; return other values unchanged."""
        if isinstance( value, self._string_types ):
            return value.replace( "'", "''" )
        return value

    def insert( self, dictFlds ):
        """
        Insert one row into self.tblName.

        dictFlds -- { fieldname: value }.  Keys that are neither known
                    regular fields nor known field-table fields are ignored.
                    Regular values are written inline via process_value();
                    field-table values are stored with tbl_field.insert()
                    and the returned fid is written into the row instead.

        Field names are sorted so the generated SQL is deterministic.
        """
        lstPassedRegFlds = sorted( set( self.lstRegFldNames ).intersection( dictFlds ) )
        lstPassedFldTbls = sorted( set( self.lstFldTblNames ).intersection( dictFlds ) )

        self.debug_print( "Passed Regular Fields: %s" % ( lstPassedRegFlds, ) )
        self.debug_print( "Passed Table Fields: %s" % ( lstPassedFldTbls, ) )
        self.debug_print( self.tblDef )

        # Values are collected in the same order as the column names below.
        lstVals = []
        for f in lstPassedRegFlds:
            lstVals.append( self.process_value( dictFlds[ f ], self.dictTblDef[ f ][1] ) )

        if lstPassedFldTbls:
            dtt = tbl_datatypes( self.cur )
            for f in lstPassedFldTbls:
                ft = tbl_field( self.cur, self.dictTblDef[ f ] )
                dtid = dtt.getdtid( self.dictTblDef[ f ][1] )
                fid = ft.insert( { 'ntid' : self.ntid, 'datatypeid' : dtid, 'value' : dictFlds[ f ] } )
                lstVals.append( fid )

        # One join over both lists: avoids a leading ", " when only
        # field-table fields were passed.
        strFlds = ", ".join( lstPassedRegFlds + lstPassedFldTbls )
        strVals = ", ".join( [ "%s" % ( v, ) for v in lstVals ] )

        strQuery = "INSERT INTO %s ( %s ) VALUES ( %s );" % ( self.tblName, strFlds, strVals )
        self.debug_print( strQuery )
        self.cur.execute( strQuery )

    def _rows_to_dicts( self, lstColNames, rows ):
        """Return one { column: value } dict per row, in row order."""
        return [ dict( zip( lstColNames, r ) ) for r in rows ]

    def _get_view_columns( self, viewName ):
        """Return the column names of table/view viewName, in column order."""
        query = "SELECT column_name FROM information_schema.columns WHERE table_name=%s ORDER BY ordinal_position;"
        self.debug_print( query )
        self.cur.execute( query, ( viewName, ) )
        return [ fld[0] for fld in self.cur.fetchall() ]

    def getAllBySSIDCols( self, viewName, colNames, lstSSID, nodeType ):
        """
        Get all SSIDs but only the columns in the list provided.

        Returns { ssid: { column: value } } for the rows of viewName whose
        ssid is in lstSSID, selecting colNames plus 'ssid'.  colNames is
        not modified.  nodeType is unused.
        """
        return self.getAllIdsFromTableCols( 'ssid', viewName, colNames, lstSSID )

    def getAllIdsFromTableCols( self, idCol, tblName, colNames, lstids ):
        """
        Get all specified ids from specified table and only the columns in
        the list provided.

        Returns { id: { column: value } } for the rows of tblName whose
        idCol is in lstids, selecting colNames plus idCol.  Returns {}
        without querying when lstids is empty.  colNames is not modified.
        """
        ids = tuple( lstids )
        if not ids:
            return {}

        # Copy, so the caller's list is not extended with idCol.
        lstColNames = list( colNames ) + [ idCol ]

        # %%s survives this formatting as the psycopg2 placeholder; a tuple
        # parameter is rendered as "( v1, v2, ... )" for IN.
        selectQuery = 'SELECT %s FROM %s WHERE %s IN %%s ORDER BY %s DESC;' % ( ", ".join( lstColNames ), tblName, idCol, idCol )
        self.debug_print( selectQuery )
        self.cur.execute( selectQuery, ( ids, ) )

        resDict = {}
        for row in self._rows_to_dicts( lstColNames, self.cur.fetchall() ):
            resDict[ row[ idCol ] ] = row
        return resDict

    def getAllBySSID( self, viewName, lstSSID, nodeType ):
        """
        Fetch every column of viewName for the rows whose ssid is in lstSSID.

        Returns { ssid: { column: value } }; {} without querying when
        lstSSID is empty.  nodeType is unused.
        """
        ssids = tuple( lstSSID )
        if not ssids:
            return {}

        lstColNames = self._get_view_columns( viewName )

        selectQuery = 'SELECT %s FROM %s WHERE ssid IN %%s ORDER BY ssid DESC;' % ( ", ".join( lstColNames ), viewName )
        self.debug_print( selectQuery )
        self.cur.execute( selectQuery, ( ssids, ) )

        resDict = {}
        for row in self._rows_to_dicts( lstColNames, self.cur.fetchall() ):
            resDict[ row[ 'ssid' ] ] = row
        return resDict

    def getLatestSSIDFromNid( self, lstNids ):
        """
        Return { nid: highest ssid } from view_nidinidssid for each nid in
        lstNids that has a row there; {} without querying when lstNids is
        empty.
        """
        nids = tuple( lstNids )
        if not nids:
            return {}

        query = "SELECT MAX( ssid ) AS ssid, nid FROM view_nidinidssid WHERE nid IN %s GROUP BY nid;"
        self.cur.execute( query, ( nids, ) )

        resDict = {}
        for ssid, nid in self.cur.fetchall():
            resDict[ nid ] = ssid

        self.debug_print( "nid to ssid: %s" % ( resDict, ) )
        return resDict

    def getLatestSSIDFromNidNT( self, nid, nodetype ):
        """
        Accepts nid, nodetype
        Returns the latest SSID of the ( nid, nodetype ) pair if found
        else returns 0
        """
        query = "SELECT ssid FROM view_nidinidssid WHERE ssid=( SELECT MAX( ssid ) FROM view_nidinidssid WHERE nid=%s AND nodename=%s );"
        self.cur.execute( query, ( nid, nodetype ) )
        res = self.cur.fetchall()

        if res:
            return res[0][0]
        return 0

    def getFromView( self, viewname, ssid=0, nid='' ):
        """
        Return all columns of viewname as a list of { column: value } dicts,
        ordered by ssid descending.  Rows are filtered on ssid when
        ssid != 0 and on nid when nid != ''.
        """
        lstClause = []
        lstParams = []
        if ssid != 0:
            lstClause.append( 'ssid=%s' )
            lstParams.append( ssid )
        if nid != '':
            lstClause.append( 'nid=%s' )
            lstParams.append( nid )

        strClause = ''
        if lstClause:
            strClause = "WHERE " + " AND ".join( lstClause )

        lstColNames = self._get_view_columns( viewname )

        selectQuery = 'SELECT %s FROM %s %s ORDER BY ssid DESC;' % ( ", ".join( lstColNames ), viewname, strClause )
        self.debug_print( selectQuery )
        if lstParams:
            self.cur.execute( selectQuery, tuple( lstParams ) )
        else:
            self.cur.execute( selectQuery )

        lstRows = self._rows_to_dicts( lstColNames, self.cur.fetchall() )
        self.debug_print( lstRows )
        return lstRows

    def get_pkey( self, pkcolname ):
        """Return the current value of the serial sequence behind self.tblName.pkcolname."""
        query = "SELECT currval( pg_get_serial_sequence( %s, %s ) );"
        self.cur.execute( query, ( self.tblName, pkcolname ) )
        rs = self.cur.fetchall()
        return rs[0][0]

def test_nested_array_insertion( cur ):
    """
    Smoke test: insert the same nested int8[] value into a tbl_values table
    twice and print the vid returned each time, so the two can be compared.
    """
    values = tbl_values( cur, '', 'int8[]' )
    nested = [ [1,2], [2,3], [3,5], [7,8] ]

    first_vid = values.insert( nested )
    print( "vid for %s is %s" % ( nested, first_vid ) )

    print( "inserting %s again" % nested )
    second_vid = values.insert( nested )
    print( "vid for %s is %s" % ( nested, second_vid ) )


if __name__ == "__main__":
    # Ad-hoc developer harness.  Connection settings can be overridden from
    # the environment; the defaults are the previously hard-coded values.
    dictConn = {
        'dbname': os.environ.get( 'SELF_DB_NAME', 'self2' ),
        'username': os.environ.get( 'SELF_DB_USER', 'akula' ),
        'password': os.environ.get( 'SELF_DB_PASSWORD', 'akula' ),
        'host': os.environ.get( 'SELF_DB_HOST', 'localhost' ),
        }

    # Keyword arguments instead of a formatted DSN string, so values that
    # contain spaces or quotes do not break the connection string.
    conn = psycopg2.connect( database=dictConn[ 'dbname' ],
                             user=dictConn[ 'username' ],
                             password=dictConn[ 'password' ],
                             host=dictConn[ 'host' ] )
    try:
        cur = conn.cursor()

        # Each sys.exit() below ends the run (the finally clause still
        # closes the connection).  Remove the earlier ones to reach the
        # later sections.
        gt = genericTable( cur, 'gbobjects' )
        print( gt.getLatestSSIDFromNidNT( 'armenian', 'gbobjects' ) )
        sys.exit()

        test_nested_array_insertion( cur )

#         test_value_table( cur )
        sys.exit()

        gt = genericTable( cur, 'gbobjects' )
        print( gt.getAllBySSID( 'djview_o', [303, 304, 305], 'gbobjects' ) )
        sys.exit()

#         # test the ntid retrieval
#         test_retr_ntid( cur )

#         # test the ntid retrieval
#         test_retr_dtid( cur )

        # NOTE(review): these test helpers are not defined in the visible
        # part of this module -- confirm they exist before enabling.
        test_nid_table( cur )
        test_inid_table( cur )
        test_value_table( cur )
        test_field_table( cur )

#        sys.exit()

        nid = "fan"
        nt = "gbobjecttypes"

        nidtbl = tbl_nidinid( cur )
        inid = nidtbl.setval( nid, nt )
        if inid == 0:
            inid = nidtbl.getinid( nid, nt )

        inidtbl = tbl_inidssid( cur )
        ssid = inidtbl.setval_using_nid( nid, nt )

        testgbmetatypes = {
                'status':"s't'a't'u's'text13",
                'content':"te'x't11",
                'inid': inid, 
                'subtypes':[15, 17, 12],
                'ssid': ssid, 
                'noofcommits':'12',
                'subtypeof':[13, 16, 16],
                'changetype':[1,1,1,1],
                'title':["tasdsad'asdas'asd'as'd'''''ext11", 'text118'],
                'uri':'text13',
                'relations':[19, 17, 18],
                'noofchangesaftercommit':'10',
                'instances':[18, 19, 11],
                'noofchanges':'17',
                'description':'text15',
                'attributes':[19, 11, 12],
                'relationtypes':[13, 16, 17],
                'history':[16, 10, 13],
                'fieldschanged':['text1114', 'text14'],
                'attributetypes':[15, 20, 14],
                'uid':'10',
        }

        s = storageSpec()
        gt = genericTable( cur, nt, s.dictTableNamesAndDefs[ nt ], s.dictTNamesFDefs[ nt ], 1 )
        gt.insert( testgbmetatypes )

        conn.commit()
    finally:
        # Always release the connection, including on the sys.exit() paths.
        conn.close()
