#include <netdb.h>
#include <errno.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <netinet/in.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/types.h>
#include <time.h>
#include <unistd.h>

#ifdef SOCK_SSL
#include <ssl/ssl.h>
#endif

#ifndef __GNUC__
#include <errno.h>
#endif

#include "misc.h"

/* Value of the DEBUG_SOCK environment variable (sampled in socks_init());
   when non-NULL, smesg() trace output is enabled. */
static const char *DEBUG_SOCK;

/* Debug trace: forwards its arguments to err() only when DEBUG_SOCK is set.
   Wrapped in do { } while (0) so the macro is a single statement and a bare
   `if` inside it can never capture an `else` written after a call site. */
#define smesg(a,b...) do { if (DEBUG_SOCK) err(a,##b); } while (0)

#ifdef SOCK_SSL

/* Shared SSL context for every secure socket.  set_ssl_ctx() creates it on
   first use and registers ssl_free_ctx() with atexit() to release it. */
static SSL_CTX *ctx;

/* Free the shared SSL context if one exists, then clear the pointer so that
   a later set_ssl_ctx() call would create a fresh context. */
static __inline__ void
ssl_free_ctx(void) {

  if (ctx != NULL) {
    SSL_CTX_free(ctx);
    ctx = NULL;
  }

}

/* Lazily create the shared SSL context used by all secure sockets.
   The first successful call loads the OpenSSL algorithm tables and error
   strings, creates a context from SSLv23_method() and registers
   ssl_free_ctx() with atexit().  Later calls return immediately while the
   context exists.
   Returns 0 on success; failures are reported through errret() (defined in
   misc.h; presumably it returns non-zero from this function). */
static __inline__ int
set_ssl_ctx(void) {

  if (ctx)
    return 0;

  SSLeay_add_ssl_algorithms();
  SSL_load_error_strings();

  /* SSL_CTX_new() returns NULL on failure.  Without this check SSL_new()
     would be handed a NULL context, and atexit() would be registered again
     on the next call. */
  if (!(ctx = SSL_CTX_new(SSLv23_method())))
    errret("Can't create SSL context\n");

  if (atexit(ssl_free_ctx))
    errret("Can't setup ssl_free_ctx on exit\n");

  return 0;

}

#endif


/* Close fd without disturbing errno, so a following errret() still sees the
   error that caused the failure. */
static __inline__ void
close_keep_errno(int fd) {

  int saved=errno;

  close(fd);
  errno=saved;

}

/* Open a TCP connection for s to s->ip1:s->port (SO_KEEPALIVE enabled).
   Plain sockets are put into non-blocking mode before connect(), so the
   connect may still be in progress (EINPROGRESS) on return.  SSL sockets
   stay blocking, and the SSL session is set up on them here.
   On success s->s holds the new descriptor and 0 is returned.  On failure
   everything acquired here is released, s->s is left 0, and the error is
   reported through errret() (misc.h). */
static __inline__ int
set_sock(Sock *s) {

  struct sockaddr_in sin;
  int on=1,sock,flags;

  s->s=0;

  /* Zero the whole structure so sin_zero is not left uninitialized. */
  memset(&sin,0,sizeof(sin));
  sin.sin_family=AF_INET;
  sin.sin_port=htons(s->port);
  if (!inet_aton(s->ip1,&sin.sin_addr))
    errret("Invalid address %s\n",s->ip1);

  if ((sock=socket(AF_INET,SOCK_STREAM,0))==-1)
    errret("Can't make socket\n");

  if (setsockopt(sock,SOL_SOCKET,SO_KEEPALIVE,&on,sizeof(on))) {
    close_keep_errno(sock);
    errret("Set socket\n");
  }

  /* Read the current flags of the new socket itself; s->s is 0 here, so
     querying it would read the flags of descriptor 0 instead. */
  if (!s->do_ssl
      && ((flags=fcntl(sock,F_GETFL))==-1
          || fcntl(sock,F_SETFL,flags|O_NONBLOCK)==-1)) {
    close_keep_errno(sock);
    errret("Can't set socket to non-blocking\n");
  }

  /* A non-blocking connect reports EINPROGRESS; its final outcome is read
     later via SO_ERROR (see write_socks()). */
  if (connect(sock,(struct sockaddr *) &sin,sizeof(sin))==-1
      && (s->do_ssl || errno!=EINPROGRESS)) {
    close_keep_errno(sock);
    errret("Can't connect socket!\n");
  }

#ifdef SOCK_SSL

  if (s->do_ssl) {
    if (set_ssl_ctx() || !(s->ssl=SSL_new(ctx))) {
      close_keep_errno(sock);
      errret("Can't create SSL session\n");
    }
    /* SSL_set_fd() returns 0 on failure.  SSL_connect() returns 1 on
       success; both 0 and negative values mean the handshake failed. */
    if (!SSL_set_fd(s->ssl,sock)
        || (SSL_connect(s->ssl)<=0 && errno!=EINPROGRESS)) {
      int saved=errno;
      SSL_free(s->ssl);
      s->ssl=NULL;
      close(sock);
      errno=saved;
      errret("Can't ssl connect socket\n");
    }
  }

#endif

  s->s=sock;

  return 0;

}

/* Release the connection held by s.  Its SSL object (if any) is freed, the
   descriptor is closed, and s->s is reset to 0, the "no socket" marker that
   socks_init() and socks_process() test.
   The SSL object is freed and s->s cleared even if close() fails: after a
   failed close() the descriptor state is unspecified, and retrying the
   close is unsafe.
   Returns 0 on success; a close() failure is reported through errret(). */
static __inline__ int
close_sock(Sock *s) {

  int rc;

#ifdef SOCK_SSL
  /* SSL_set_fd() attaches the descriptor without taking ownership of it,
     so freeing the SSL object first does not close s->s. */
  if (s->do_ssl) {
    SSL_free(s->ssl);
    s->ssl=NULL;
  }
#endif

  rc=close(s->s);
  s->s=0;

  if (rc)
    errret("Can't close socket\n");

  return 0;

}
  

/* Write the pending range [s->vi, s->vz) of every socket in [s1, se),
   multiplexing with select().
   - For plain sockets, SO_ERROR is checked each time the socket becomes
     writable.  This detects a failed non-blocking connect() started by
     set_sock().  Such sockets are closed and dropped, and the call fails
     once the remaining sockets have finished.
   - If no pending socket becomes writable within 15 s, every socket still
     pending is closed and reopened with set_sock(), and its write restarts
     from s->vi.
   Returns 0 when every socket was written completely; failures are
   reported through errret(). */
static __inline__ int
write_socks(Sock *s1,Sock *se) {

  Sock *s;
  int pending=0,maxfd=0,bad_sock=0,rc,so_err;
  socklen_t so_errlen;
  struct timeval ttv,tv={15,0};
  ssize_t n;
  fd_set fdsw,fdsw1;

  FD_ZERO(&fdsw);
  for (s=s1;s<se;s++) {
    FD_SET(s->s,&fdsw);
    if (s->s+1>maxfd)
      maxfd=s->s+1;
    pending++;
    s->v=s->vi;
  }

  while (pending) {

    /* select() may modify both arguments, so work on copies. */
    ttv=tv;
    fdsw1=fdsw;

    rc=select(maxfd,NULL,&fdsw1,NULL,&ttv);

    if (rc==-1) {
      if (errno==EINTR)
        continue;
      errret("select failed while writing sockets\n");
    }

    if (!rc) {

      /* NOTE(review): during socks_process()'s firing pass s->vi points at
         the final request byte, so a socket reopened here sends only that
         byte on its new connection. Confirm this is acceptable. */
      smesg("\nTimed out on writing sockets, reopening ...");
      for (s=s1;s<se;s++) {
        if (!FD_ISSET(s->s,&fdsw))
          continue;
        /* The reopened socket may get a different descriptor, so remove
           the old one from the set and add the new one. */
        FD_CLR(s->s,&fdsw);
        close_sock(s);
        if (set_sock(s))
          errret("Can't reopen socket %d\n",(int)(s-s1));
        FD_SET(s->s,&fdsw);
        if (s->s+1>maxfd)
          maxfd=s->s+1;
        smesg("%d ",(int)(s-s1));
        s->v=s->vi;
      }
      smesg("done\n");
      continue;

    }

    for (s=s1;s<se;s++) {

      if (!FD_ISSET(s->s,&fdsw1))
        continue;

      if (!s->do_ssl) {
        so_err=0;
        so_errlen=sizeof(so_err);
        if (getsockopt(s->s,SOL_SOCKET,SO_ERROR,&so_err,&so_errlen))
          errret("Cannot run getsockopt\n");
        if (so_err) {
          /* Drop the descriptor from the set before closing it, so later
             select() calls never see a closed fd. */
          FD_CLR(s->s,&fdsw);
          close_sock(s);
          pending--;
          bad_sock=1;
          continue;
        }
      }

#ifdef SOCK_SSL
      if (s->do_ssl)
        n=SSL_write(s->ssl,s->v,s->vz-s->v);
      else
#endif
        n=write(s->s,s->v,s->vz-s->v);

      if (n<0) {
        /* A non-blocking socket can still return EAGAIN after select(),
           and a signal can interrupt the call; both mean "retry later". */
        if (!s->do_ssl
            && (errno==EAGAIN || errno==EWOULDBLOCK || errno==EINTR))
          continue;
        errret("Socket write error\n");
      }

      /* n==0 is no progress; the socket stays pending. */
      s->v+=n;
      if (s->v==s->vz) {
        FD_CLR(s->s,&fdsw);
        pending--;
        smesg("%d ",(int)(s-s1));
      }

    }
  }

  if (bad_sock)
    errret("Sockets not connected properly\n");

  return 0;

}


/* Read the response of every socket in [s1, se) into [s->vi, s->ve),
   multiplexing with select().  Each select() wait times out after 180 s.
   A socket is finished when either
   - the peer closes the connection (read returns 0; the socket is then
     closed with close_sock()), or
   - its optional s->done() callback returns non-zero after new data
     arrives.
   At that point s->vz marks the end of the received data, and a NUL byte
   is stored there.
   Returns 0 once every socket has finished; failures are reported through
   errret(). */
static __inline__ int
read_socks(Sock *s1,Sock *se) {

  Sock *s;
  int pending=0,maxfd=0,rc;
  struct timeval ttv,tv={180,0};
  ssize_t n;
  fd_set fdsr,fdsr1;

  FD_ZERO(&fdsr);
  for (s=s1;s<se;s++) {
    FD_SET(s->s,&fdsr);
    if (s->s+1>maxfd)
      maxfd=s->s+1;
    pending++;
    s->v=s->vi;
  }

  while (pending) {

    /* select() may modify both arguments, so work on copies. */
    ttv=tv;
    fdsr1=fdsr;

    rc=select(maxfd,&fdsr1,NULL,NULL,&ttv);
    if (rc==-1) {
      if (errno==EINTR)
        continue;
      errret("select failed while reading sockets\n");
    }
    if (!rc)
      errret("Timed out on reading sockets\n");

    for (s=s1;s<se;s++) {

      if (!FD_ISSET(s->s,&fdsr1))
        continue;

#ifdef SOCK_SSL
      if (s->do_ssl)
        n=SSL_read(s->ssl,s->v,s->ve-s->v);
      else
#endif
        n=read(s->s,s->v,s->ve-s->v);

      if (n<0) {
        /* A spurious wakeup (EAGAIN) or a signal is not a real error. */
        if (!s->do_ssl
            && (errno==EAGAIN || errno==EWOULDBLOCK || errno==EINTR))
          continue;
        errret("Socket read error\n");
      }

      if (n>0) {
        s->v+=n;
        /* NOTE(review): buffer full -> grown via r_mem() from misc.h.
           Its semantics are not visible here: it is given the cursor s->v
           (not the base s->v1), and s->ve is not updated afterwards.
           Confirm this really enlarges the buffer. */
        if (s->v==s->ve)
          r_mem(s->v,s->ve-s->v1+256);
        /* Keep reading unless the done() callback reports completion. */
        if (!s->done || !s->done(s))
          continue;
      }

      /* Finished: end of stream (n==0) or done() returned non-zero. */
      s->vz=s->v;
      *(char *)s->v++=0;
      FD_CLR(s->s,&fdsr);
      pending--;
      smesg("%d ",(int)(s-s1));
      if (!n)
        close_sock(s);

    }
  }

  return 0;

}

/* Resolve host name hn and (re)initialise every Sock in [s, se).  For each
   socket this
   - closes any socket still open,
   - stores the resolved IPv4 address as a dotted quad via l_str(),
   - sets port and the done callback,
   - enables SSL for ports 443, 563 and 636, and
   - opens the connection with set_sock().
   The DEBUG_SOCK environment variable is sampled here to enable smesg()
   tracing.
   Returns 0 on success; failures are reported through errret(). */
int
socks_init(Sock *s,Sock *se,const char *hn,unsigned short port,
           int (*done)(Sock *)) {

  struct hostent *h;
  const unsigned char *u;

  DEBUG_SOCK=getenv("DEBUG_SOCK");

  /* Passing NULL to %s is undefined behaviour, so substitute a placeholder. */
  if (!hn || !(h=gethostbyname(hn)))
    errret("Hostname lookup error %p %s\n",(const void *)hn,
           hn ? hn : "(null)");

  /* Only the first four address bytes are used below, so the lookup must
     have produced an IPv4 address. */
  if (h->h_addrtype!=AF_INET)
    errret("Hostname %s does not resolve to an IPv4 address\n",hn);

  if (h->h_length<4)
    errret("Hostname length %d < 4\n",h->h_length);

  u=(const unsigned char *)h->h_addr_list[0];
  for (;s<se;s++) {

    if (s->s && close_sock(s))
      errret("failed closing socket\n");

    s->ip=s->ip1;
    l_str(s->ip,"%u.%u.%u.%u",u[0],u[1],u[2],u[3]);
    s->port=port;
    s->done=done;

    switch(s->port) {
    case 443: case 563: case 636:
#ifdef SOCK_SSL
      s->do_ssl=1;
#else
      errret("Cannot process secure socket, SSL support not compiled in\n");
#endif
      break;
    default:
      s->do_ssl=0;
      break;
    }

    if (set_sock(s))
      errret("Failed setting socket\n");

  }

  return 0;

}
  

/* Send each socket's request and collect the responses.
   For every Sock in [s1, se), a closed socket (s->s == 0) is reopened, and
   the request is the data up to the first NUL found in [s->v, s->ve).
   The request goes out in two write passes:
   - "loading": everything except the final byte;
   - "firing": the final byte (presumably so that all requests complete as
     close together in time as possible).
   The responses are then read into each buffer starting at s->v1.
   Returns 0 on success; failures are reported through errret(). */
int
socks_process(Sock *s1,Sock *se) {

  Sock *s;

  for (s=s1;s<se;s++) {

    if (!s->s && set_sock(s))
      errret("Can't run set_sock\n");

    if (!s->v1 || s->ve<=s->v1 || !(s->vz=memchr(s->v,0,s->ve-s->v)))
      errret("Can't find null termination in last %d bytes for socket %d\n",
             (int)(s->ve-s->v),(int)(s-s1));

    /* The final byte is held back for the firing pass, so at least one byte
       must precede the NUL.  Otherwise s->vz-- below would point before the
       buffer and the loading write length would be negative. */
    if (s->vz<=s->v1)
      errret("Empty request for socket %d\n",(int)(s-s1));

    s->vi=s->v1;
    s->vz--;                    /* now points at the final request byte */

  }

  smesg("Loading sockets...");
  if (write_socks(s1,se))
    errret("Failed writing to sockets\n");
  smesg("done\n");

  /* Firing pass: write exactly the one byte held back above. */
  for (s=s1;s<se;s++)
    s->vi=s->vz++;

  smesg("Firing sockets...");
  if (write_socks(s1,se))
    errret("Failed firing sockets\n");
  smesg("done\n");

  /* Responses overwrite each buffer from its start. */
  for (s=s1;s<se;s++)
    s->vi=s->v1;

  smesg("Reading sockets...");
  if (read_socks(s1,se))
    errret("Failed reading to sockets\n");
  smesg("done\n");

  return 0;

}

