//*****************************************************************************
//
// Program:     vbridge -- Programming Project 2
//
// Author:      Travis Dillon
//
// Class:       CS 444
// Email:       tdillon@ace.cs.ohiou.edu
//
// Description: This is the bridge program that forwards packets between
//              the virtual LANs.
//
// Date:        January 17, 2005
//
//*****************************************************************************


#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <limits.h>
#include <iostream.h>
#include <vrouter.h>
#include <time.h>
#include <vector>

// Command-line defaults.
#define DEFAULT_GROUP 6            // group used when -g is absent or invalid
#define DEFAULT_TIMEOUT 15         // seconds a learned address stays valid (-t)

// Sizes and limits.
#define BUFFER_SIZE 512            // receive buffer handed to readpacket()
#define INTERFACE_NAME_LENGTH 100  // not referenced by the code shown here
#define MAX_NUM_INTERFACES 4       // at most this many lans are opened
#define MAX_TABLE_SIZE 100         // table size that triggers purging expired entries

using namespace std;

// One row of the bridge's address-to-interface binding table.
struct _Table
{
   char hw_address[7];     // hardware address of a sender (6 bytes used)
   char hw_interface[7];   // leading characters of the lan name it was seen on
   time_t timestamp;       // time(NULL) when the binding was last refreshed
};
typedef struct _Table Table;

// Ethernet header layout.  NOTE(review): not referenced by the code visible in
// this file; the arrays presumably hold a 6-byte address / 2-byte type plus
// one spare byte each -- confirm before relying on its size.
struct _Ethernet
{
   unsigned char destination[7];
   unsigned char source[7];
   unsigned char type[3];
};
typedef struct _Ethernet Ethernet;

// Command-line handlers.  option_h and option_l print and then exit;
// option_g, option_t and option_e may consume the argument after their flag
// by advancing i.
void option_h();
void option_l(lanlist* all_lans);
void option_g(int argc, char** argv, unsigned& group, int& i);
void option_t(int argc, char** argv, int& timeout, int& i);
void option_e(int argc, char** argv, int& num_packets, int& i, bool& op_e_on);

// Maintains the binding table and forwards one received packet.
void forwardit(char buf[],
               const int& packet_size,
               lanlist* all_lans,
               const size_t& j,
               vector<Table>& binding_table,
               INTERFACE** ifaces,
               const int& timeout);


int main(int argc, char *argv[])
{
   int i, packet_size(0), timeout(DEFAULT_TIMEOUT), num_packets(1);
   char buf[BUFFER_SIZE];
   unsigned group = DEFAULT_GROUP;
   bool debug_on = false, option_e_on(false);
   lanlist* all_lans = whichlans();
   //The extra slot guarantees a NULL terminator even when all
   //MAX_NUM_INTERFACES interfaces are open; the loops below stop at the first
   //NULL entry (and blockforpacket presumably does too -- confirm in vrouter).
   INTERFACE* ifaces[MAX_NUM_INTERFACES + 1];
   vector <Table> binding_table;

   if(all_lans == NULL)  //must be checked before all_lans is dereferenced
   {
      cerr << "whichlans() failed and returned NULL\n";
      exit(EXIT_FAILURE);
   }
   _lan* lan_ptr = all_lans -> lan;

   for(i=1; i < argc; ++i)  //parse the command line
   {
      if(argv[i][0] != '-')  //at this point an argument must have a '-'
      {
         cout << "Invalid option - \""  << argv[i] << "\".\n";
         option_h();
      }
      switch(argv[i][1])
      {
         case 'd': debug_on = true; break;
         case 'l': option_l(all_lans); break;
         case 'g': option_g(argc, argv, group, i); break;
         case 't': option_t(argc, argv, timeout, i); break;
         case 'h': option_h(); break;
         case 'e': option_e(argc, argv, num_packets, i, option_e_on); break;
         default:
         {
            cout << "Invalid option - \""  << argv[i] << "\".\n";
            option_h();
            break;
         }
      }
   }

   //open one interface per lan (at most MAX_NUM_INTERFACES); unused slots
   //and the terminator slot stay NULL
   for(size_t j(0); j <= MAX_NUM_INTERFACES; ++j) ifaces[j] = NULL;
   for(size_t j(0); j < MAX_NUM_INTERFACES && lan_ptr != NULL; ++j)
   {
      ifaces[j] = openinterface(lan_ptr -> lname, group);
      if(ifaces[j] == NULL)  //NOTE(review): assumes NULL signals failure
      {  //a NULL here would silently cut off every later interface
         cerr << "openinterface failed for " << lan_ptr -> lname << "\n";
         exit(EXIT_FAILURE);
      }
      lan_ptr = lan_ptr -> next;
   }


   if(debug_on)
   {
      if(option_e_on) cout << "\nnum_packets = " << num_packets;
      cout << "\ngroup = " << group << "\ntimeout = " << timeout << endl;
      for(lan_ptr = all_lans -> lan; lan_ptr != NULL; lan_ptr = lan_ptr -> next)
         cout << "lan_ptr -> lname = " << lan_ptr -> lname << endl;
   }

   //This block of code is modeled after the code Dr. Ostermann sent us in the
   //Vbridge options email.
   while(num_packets > 0)  //let the bridge run
   {
      blockforpacket(ifaces);  //wait until we have a packet somewhere
      for(size_t j(0); ifaces[j] != NULL; ++j)  //look for the packet
      {
         if(checkforpacket(ifaces[j]))  //check all interfaces for packets
         {
            packet_size = readpacket(ifaces[j], buf, BUFFER_SIZE);  //get packet
            if(packet_size == -1)
            {
               cerr << "readpacket failed.\n";
               exit(EXIT_FAILURE);
            }
            forwardit(buf, packet_size, all_lans, j,
                      binding_table, ifaces, timeout);  //process the packet
         }
      }
      //NOTE: this counts passes through blockforpacket, not individual
      //packets; one pass may forward packets from several interfaces.
      if(option_e_on) --num_packets;
   }

   for(size_t j(0); ifaces[j] != NULL; ++j) closeinterface(ifaces[j]);

   return EXIT_SUCCESS;
}


//****************************************************************************
//
// Function:   option_h
//
// Purpose:    If the -h flag is given in the command line or someone enters
//             an invalid command argument, this will print out options.
//
// Parameters: none
//
// Calls:      exit
//
//****************************************************************************


void option_h()
{
   cout << "-d          To print debugging information.\n"
        << "-l          List all interfaces and exit.\n"
        << "-g NUM      Use group NUM. If NUM is ommitted then "
        << DEFAULT_GROUP << " will be used.\n"
        << "-t SECONDS  Number of seconds an ETHERNET address can stay.\n"
        << "-e NUM      The brige will only send NUM packets.\n"
        << "-h          Print a quick summary of the command line arguments"
        << " and exit.\n";

   exit(EXIT_SUCCESS);
}


//****************************************************************************
//
// Function:   option_l
//
// Purpose:    If the -l flag is given in the command line this prints the
//             name and address of every lan in all_lans, then exits.
//
// Parameters: all_lans - linked list that holds all network addresses and
//                        names of them on the computer
//
// Calls:      exit
//
//****************************************************************************


void option_l(lanlist* all_lans)
{
   //tolerate a NULL list: print nothing rather than crash
   _lan* lan_ptr = (all_lans != NULL) ? all_lans -> lan : NULL;
   const int count = (all_lans != NULL) ? all_lans -> cnt : 0;

   //stop at the end of the list even if cnt claims more entries, instead of
   //following a NULL next pointer
   for(int j(0); j < count && lan_ptr != NULL; ++j)
   {
      cout << "name = " << lan_ptr -> lname << endl;
      //NOTE(review): if hwaddr holds raw address bytes rather than text this
      //prints unreadable output -- confirm its type in vrouter.h
      cout << "addr = " << lan_ptr -> hwaddr << endl << endl;
      lan_ptr = lan_ptr -> next;
   }
   exit(EXIT_SUCCESS);
}


//****************************************************************************
//
// Function:   option_g
//
// Purpose:    If the -g NUM flag is given in the command line this will
//             change the group to NUM as long as NUM is a decimal number in
//             [0, 63]; any other value prints a warning and selects
//             DEFAULT_GROUP.
//
// Parameters: argc - number of command line arguments
//             argv - the command line arguments
//             group - receives the selected group number
//             i - position in the command line arguments; advanced past NUM
//                 when NUM is consumed
//
// Calls:      none
//
//****************************************************************************


void option_g(int argc, char** argv, unsigned& group, int& i)
{
   const unsigned MAX_GROUP = 63;  //valid range of group numbers is 0 - 63

   if(i < argc - 1 && argv[i + 1][0] != '-')
   {  //true if -g isn't last command and is followed by a value
      const char* text = argv[++i];
      unsigned value = 0;
      bool valid = (text[0] != '\0');  //an empty argument is not a number

      //accept only decimal digits; stop as soon as the value leaves the
      //valid range, which also keeps the accumulator from overflowing
      for(const char* p = text; valid && *p != '\0'; ++p)
      {
         if(*p < '0' || *p > '9') valid = false;
         else
         {
            value = value * 10 + unsigned(*p - '0');
            if(value > MAX_GROUP) valid = false;
         }
      }

      if(valid) group = value;
      else
      {
         std::cout << "INVALID group number. Defaulting to "
                   << DEFAULT_GROUP << ".\n";
         group = DEFAULT_GROUP;
      }
   }
   if(group > MAX_GROUP)  //guards an out-of-range group passed in by caller
   {
      std::cout << "INVALID group number. Defaulting to " << DEFAULT_GROUP
                << ".\n";
      group = DEFAULT_GROUP;
   }
}


//****************************************************************************
//
// Function:   option_t
//
// Purpose:    If the -t SECONDS flag is given in the command line this will
//             change the timeout from DEFAULT_TIMEOUT to SECONDS.  A value
//             that is not a non-negative decimal integer fitting in an int
//             is reported and ignored, leaving timeout unchanged.
//
// Parameters: argc - number of command line arguments
//             argv - the command line arguments
//             timeout - variable to track the timeout
//             i - position in the command line arguments; advanced past
//                 SECONDS when SECONDS is consumed
//
// Calls:      strtol
//
//****************************************************************************


void option_t(int argc, char** argv, int& timeout, int& i)
{
   if(i < argc - 1 && argv[i + 1][0] != '-')
   {  //true if -t isn't last command and is followed by a value
      const char* text = argv[++i];
      char* end = NULL;
      errno = 0;
      const long value = strtol(text, &end, 10);

      //reject empty text, trailing junk, overflow and negative values
      if(end == text || *end != '\0' || errno == ERANGE ||
         value < 0 || value > INT_MAX)
      {
         std::cout << "INVALID timeout \"" << text << "\". Keeping "
                   << timeout << " seconds.\n";
      }
      else timeout = int(value);
   }
}


//****************************************************************************
//
// Function:   option_e
//
// Purpose:    If the -e NUM flag is given in the command line this will
//             change the number of packets the bridge will send from
//             infinite to NUM packets.  A value that is not a non-negative
//             decimal integer fitting in an int is reported and ignored,
//             leaving the option off.
//
// Parameters: argc - number of command line arguments
//             argv - the command line arguments
//             num_packets - variable to track the number of packets to send
//             i - position in the command line arguments; advanced past NUM
//                 when NUM is consumed
//             op_e_on - set to true only when a valid NUM was given
//
// Calls:      strtol
//
//****************************************************************************


void option_e(int argc, char** argv, int& num_packets, int& i, bool& op_e_on)
{
   if(i < argc - 1 && argv[i + 1][0] != '-')
   {  //true if -e isn't last command and is followed by a value
      const char* text = argv[++i];
      char* end = NULL;
      errno = 0;
      const long value = strtol(text, &end, 10);

      //reject empty text, trailing junk, overflow and negative counts
      if(end == text || *end != '\0' || errno == ERANGE ||
         value < 0 || value > INT_MAX)
      {
         std::cout << "INVALID packet count \"" << text
                   << "\". Option -e ignored.\n";
      }
      else
      {
         num_packets = int(value);
         op_e_on = true;
      }
   }
}


//****************************************************************************
//
// Function:   forwardit
//
// Purpose:    This routine maintains the table, and does the smart routing
//             of the packet through the virtual network.
//
// Parameters: buf - holds the packet
//             packet_size - number of bytes in the packet
//             all_lans - linked list that holds all network addresses and
//                        names of them on the computer
//             j - position in the command line arguments
//             binding_table - vector holding the address-to-interface binding
//                             table
//             ifaces - array of interface handles
//             timeout - amount of seconds an address can be in the table
//
// Calls:      strncmp, begin, end, erase, time, size, push_back, writepacket
//
//****************************************************************************


void forwardit(char buf[], const int& packet_size, lanlist* all_lans,
               const size_t& j, vector <Table>& binding_table,
               INTERFACE** ifaces, const int& timeout)
{
   bool not_found = true;
   _lan* temp_lan_ptr = all_lans -> lan;
   Table temp;
   vector<Table>::iterator temp_iter;

   if(binding_table.size() >= MAX_TABLE_SIZE)  //gets rid of expired entries
   {
      for(temp_iter = binding_table.begin(); temp_iter != binding_table.end();
          ++temp_iter)
      {
         if(temp_iter -> timestamp < time(NULL) - timeout)
            binding_table.erase(temp_iter);
      }
   }

   //this makes temp_lan_ptr point to the current interface
   for(size_t w(0); w < j; ++w) temp_lan_ptr = temp_lan_ptr -> next;

   for(size_t l(0); l < 6; ++l)
   {
      temp.hw_address[l] = buf[l + 6];  //hw address of the source of packet
      temp.hw_interface[l] = temp_lan_ptr -> lname[l];  //interface that the
   }                                                   //packet was received on
   temp.timestamp = time(NULL);  //create timestamp

   for(size_t l(0); l < binding_table.size(); ++l)
   {
      if(strncmp(temp.hw_address, binding_table[l].hw_address, 6) == 0)
      {  //true if the sender is already in the table
         binding_table[l].timestamp = time(NULL);  //update timestamp
         not_found = false;
         break;
      }
   }
   if(not_found) binding_table.push_back(temp);  //add the sender to the table

   not_found = true;
   for(size_t l(0); l < 6; ++l)   //need to check if the destination in buf
   {                             //is in binding_table
      temp.hw_address[l] = buf[l];  //the destination address
   }

   for(temp_iter = binding_table.begin(); temp_iter != binding_table.end();
       ++temp_iter)
   {
      if(strncmp(temp.hw_address, temp_iter -> hw_address, 6) == 0)
      {  //true if the destination is found in the table
         if(temp_iter -> timestamp < time(NULL) - timeout)
         {  //destinations timestamp has expired, have to start over and
            binding_table.erase(temp_iter);            //send to everyone again
            goto here;
         }
         for(size_t m(0); m < 6; ++m)  //get interface to put in table
            temp.hw_interface[m] = temp_iter -> hw_interface[m];
         temp_iter -> timestamp = time(NULL);  //update the timestamp
         not_found = false;
         goto here;
      }
   }
   here:
   if(not_found)  //if destination wasn't in the table, write to all interfaces
   {
      temp_lan_ptr = all_lans -> lan;
      for(size_t n(0); n < MAX_NUM_INTERFACES; ++n)
      {
         if((ifaces[n] != NULL) && (n != j))
            writepacket(ifaces[n], buf, packet_size);
         temp_lan_ptr = temp_lan_ptr -> next;
      }
   }
   else  //if destination was in the table, just write to correct interface
   {
      size_t n(0);
      for(temp_lan_ptr = all_lans -> lan;
          strncmp(temp.hw_interface, temp_lan_ptr -> lname, 6) != 0;
          temp_lan_ptr = temp_lan_ptr -> next) ++n;
      writepacket(ifaces[n], buf, packet_size);
   }
}