#include "mp3.h" /* mp3 header defs */
#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/select.h>
#include <sys/time.h> /* time unit */
#include <netinet/in.h> /* sockaddr_in structure */
#include <arpa/inet.h>
#include <netdb.h> /* /etc/hosts table entries */

#define BUFFERSIZE 1300 /* size of inputBuffer and headerBuffer; main receives at most this many bytes per datagram */
#define DATABUFFERSIZE 4000 /* dataBuffer is allocated as 10 * DATABUFFERSIZE bytes */

   /* inputBuffer: datagram currently being processed (filled by recvfrom/recv in main).
      headerBuffer / dataBuffer: frame headers and frame data that doWork has
      buffered but not yet written to mp3file. */
   char inputBuffer[BUFFERSIZE], headerBuffer[BUFFERSIZE], dataBuffer[10*DATABUFFERSIZE];
   /* frameHeader: header decoded by parseHeader and written by doWork.
      dummyHeader, lastHeader: template headers doWork uses to stand in for lost frames. */
   mp3header frameHeader, dummyHeader, lastHeader;
   RTPheader packetHeader; /* header at the start of the current datagram */
   MSheader specificHeader; /* header following packetHeader (numFrames, FFN, lastFrameOffset) */
   /* Set by parseHeader from the last header decoded; frameSize is the frame
      length in bytes including the header.
      NOTE(review): timeInterval is not used anywhere in the code visible here. */
   int bitRate, samplingFrequency, padding, dataOffset, frameSize, timeInterval;
   int currentFrame=0; /* frames accounted for so far, including synthesized filler frames */
   int currentPacket=-1; /* SN of the last accepted packet; -1 until the first packet is accepted */
   FILE *mp3file; /* output file opened by main */
   int mySocket; /* client socket descriptor (0 until main assigns it) */
   int dataBufferOffset = 0; /* bytes currently held in dataBuffer */
   int headerBufferOffset=0; /* bytes currently held in headerBuffer */
   int lastFrameOffset = 0; /* specificHeader.lastFrameOffset of the last accepted packet */

void myExit(int code)
  {
  close(mySocket);
  fclose(mp3file);
  printf("\nExit called, code %d.\n",code);
  exit(code);
  }

/* init the client */
/* Print `prefix` followed by the current errno value, then perror's text.
   errno is saved and restored around printf, which may modify it. */
static void printSocketError(const char *prefix)
   {
   int savedErrno = errno;
   printf("%s%d\n", prefix, savedErrno);
   errno = savedErrno;
   perror("Error message"); /* print error message */
   }

/* Create a UDP socket bound to `myPort` on all local interfaces and print the
   port number the socket ended up with.
   myPort: port number already in network byte order (main converts it with
           ntohs, which performs the same byte swap as htons); 0 lets the
           system pick a port.
   Returns the socket descriptor. On failure the error is printed and the
   process is terminated through myExit(): code 1 for socket, 2 for bind,
   3 for getsockname. */
int initClient(int myPort)
   {
   int sock;
   struct sockaddr_in mySockAddr;  /* Internet socket name */
   socklen_t adrlen; /* sockaddr length */

   if ((sock = socket(AF_INET, SOCK_DGRAM, 0)) < 0)
      {
      printSocketError("Socket create failure, error code ");
      myExit(1); /* abnormal termination */
      return -1; /* not reached */
      }

   memset(&mySockAddr, 0, sizeof(mySockAddr));
   mySockAddr.sin_family = AF_INET; /* net socket */
   mySockAddr.sin_port = (in_port_t) myPort; /* already in network order */
   mySockAddr.sin_addr.s_addr = htonl(INADDR_ANY);  /* wildcard */
   if (bind(sock, (struct sockaddr *) &mySockAddr, sizeof(mySockAddr)) < 0)
      {
      printSocketError("Socket bind failure, error code ");
      close(sock);  /* dispose of it, since bind failed */
      myExit(2); /* abnormal termination */
      return -1; /* not reached */
      }

   /* Ask which port was actually bound (relevant when myPort is 0). */
   adrlen = sizeof(mySockAddr);
   if (getsockname(sock, (struct sockaddr *) &mySockAddr, &adrlen) < 0)
      {
      printSocketError("Function getsockname returned error code ");
      close(sock); /* dispose of it, since getname failed */
      myExit(3); /* abnormal termination */
      return -1; /* not reached */
      }
   printf("\n");
   printf("Client was assigned port number: %d\n", ntohs(mySockAddr.sin_port));
   return sock;
   }

/* parses a frame header */
/* Decode the global frameHeader into bitRate, samplingFrequency, padding and
   dataOffset, and compute frameSize.
   frameSize is the frame length in bytes including the header; doWork
   subtracts sizeof(frameHeader) to get the data part.
   The formula 1152/fs * bitrate/8 + padding is the MPEG-1 Layer III frame
   length. It assumes bitRates[] holds bit/s and samplingFrequencies[] holds
   Hz — confirm against mp3.h.
   When the sampling-frequency table entry is not positive, frameSize is set
   to 0 instead of dividing by zero. */
void parseHeader(void)
  {
  bitRate = bitRates[frameHeader.BRI];
  samplingFrequency = samplingFrequencies[frameHeader.SF];
  padding = frameHeader.PaB;
  dataOffset = frameHeader.offset;
  if (samplingFrequency > 0)
    frameSize = ((1152./samplingFrequency)*(bitRate/8.))+padding;
  else
    frameSize = 0; /* size cannot be computed from this header */
  }

void doWork(int count)
  {
  int i;
  int bytesToFill; /* # of data bytes to fill with zeros */
  int frameCount;
  int headerPacketOffset; /* where headers are */
  int dataPacketOffset; /* where data is */
  int framesToFill; /* # of frames to fill */
  int lastFrameSize; /* size of last frame read in last packet */
  int dataWriteOffset = 0; /* counter for written data bytes */
  int writtenHeadersLength = 0; /* counter for written header bytes */
  int availableFrameHeaders; /* # of headers in buffer */
  memcpy(&packetHeader, inputBuffer, sizeof(packetHeader)); /* get the packet header */
  if ((packetHeader.SN > currentPacket)||(currentPacket==-1)) /* ignore duplicate or late packets */
    {
    if (packetHeader.P == 1) count -= inputBuffer[count-1]; /* subtract padding bytes from size */
    memcpy(&specificHeader, inputBuffer + sizeof(packetHeader), sizeof(specificHeader)); /* extract specific header */
    frameCount = specificHeader.numFrames; /* how many frames? */
    headerPacketOffset = sizeof(packetHeader) + sizeof(specificHeader); /* where headers are */
    dataPacketOffset = headerPacketOffset + sizeof(frameHeader) * frameCount; /* where data is */
    framesToFill = specificHeader.FFN - currentFrame; /* # of frames to fill */
    if (framesToFill > 0)
      { /* put in a bunch of zeroes */
      memcpy(&dummyHeader, inputBuffer + headerPacketOffset, sizeof(dummyHeader)); /* get the first header read */
      if (currentPacket!=-1)memcpy(&lastHeader, headerBuffer + headerBufferOffset - sizeof(frameHeader), sizeof(lastHeader)); /* get the last header in buffer */
       else
        { /* don't have anything in the buffer */
        memcpy(&lastHeader, &dummyHeader, sizeof(lastHeader));/* so just duplicate what we have */
        lastHeader.offset = 0; /* set offset to zero, this is the first frame */
        memcpy(headerBuffer, &lastHeader, sizeof(lastHeader)); /* copy it in the header buffer */
        headerBufferOffset = sizeof(lastHeader); /* got one header in the buffer now */
        }
      bitRate = bitRates[lastHeader.BRI];
      samplingFrequency = samplingFrequencies[lastHeader.SF];
      padding = lastHeader.PaB;
      lastFrameSize = ((1152./samplingFrequency)*(bitRate/8.))+padding; /* compute the size */
      bytesToFill = framesToFill * (lastFrameSize - sizeof(frameHeader)) + lastFrameOffset - dummyHeader.offset; /* how many bytes missing? */
      for (i = 0; i < bytesToFill; i++) dataBuffer[dataBufferOffset++] = 0; /* put the zeroes */
      dummyHeader.offset = lastFrameOffset; /* set zero length */
      dummyHeader.PaB = lastHeader.PaB; /* set padding same as last frame */
      while (currentFrame < specificHeader.FFN) /* loop to generate headers */
        {
        memcpy(headerBuffer + headerBufferOffset, &dummyHeader, sizeof(dummyHeader)); /* just copy */
        headerBufferOffset += sizeof(dummyHeader); /* adjust offset */
        dummyHeader.offset += (lastFrameSize - sizeof(frameHeader)); /* make all frames zero length, except last */
        currentFrame++; /* got a brand new frame here, empty! */
        }
      }   
    memcpy(headerBuffer + headerBufferOffset, inputBuffer + headerPacketOffset, sizeof(frameHeader) * frameCount); /* copy the headers */
    headerBufferOffset += sizeof(frameHeader) * frameCount; /* how many frames? */
    currentFrame += frameCount;
    memcpy(dataBuffer + dataBufferOffset, inputBuffer + dataPacketOffset, count - dataPacketOffset); /* copy the data */
    dataBufferOffset += (count - dataPacketOffset); /* add the length of the data */
    i=0; /* counter */
    availableFrameHeaders = headerBufferOffset / sizeof(frameHeader);
    memcpy(&frameHeader, headerBuffer, sizeof(frameHeader)); /* set up the loop */
    parseHeader(); /* set up some global vars and compute frameSize */
    while((i < availableFrameHeaders)&&(dataWriteOffset + frameSize - sizeof(frameHeader) < dataBufferOffset)) /* last frame is incomplete */
      {
      memcpy(&frameHeader, headerBuffer + writtenHeadersLength, sizeof(frameHeader)); /* get the header */
      parseHeader(); /* adjust the global vars and compute frameSize */
      fwrite(&frameHeader, sizeof(frameHeader), 1, mp3file); /* write the header */
      fwrite(dataBuffer + dataWriteOffset, frameSize - sizeof(frameHeader), 1, mp3file); /* write the data */
      dataWriteOffset += (frameSize - sizeof(frameHeader)); /* how many bytes belong here */
      writtenHeadersLength += sizeof(frameHeader); /* another header down the drain */
      i++;
      }
    memcpy(headerBuffer, headerBuffer + writtenHeadersLength, headerBufferOffset - writtenHeadersLength); /* copy unwritten headers */ 
    headerBufferOffset = headerBufferOffset - writtenHeadersLength; /* next time start at position x */
    lastFrameSize = dataBufferOffset - dataWriteOffset; /* the last frame has this much data - not enough to write to the file though */
    memcpy(dataBuffer, dataBuffer + dataWriteOffset, lastFrameSize); /* copy it to the beginning */
    dataBufferOffset = lastFrameSize; /* adjust the offset */
    currentPacket = packetHeader.SN; /* set the sequence number */
    lastFrameOffset = specificHeader.lastFrameOffset; /* set it for next iteration */
    }
  }

/* Entry point: mp3_client -dataport <port> -output <output_file>
   Only the positional values argv[2] and argv[4] are used; the flag spellings
   are not checked.
   Steps:
     1. Open the output file.
     2. Bind a UDP socket on <port>.
     3. Loop forever: wait for a sender, connect the socket to it, and hand
        every non-empty datagram to doWork() until a receive returns 0 or
        fails.
   After each session the UDP association is dissolved so that any sender is
   accepted again.
   The file is opened before the socket and the signal handlers are installed
   last, so myExit() always has an open file to close. */
int main (int argc, char *argv[])
  {
  fd_set mySockets; /* incoming sockets set */
  struct sockaddr_in myFromSockAddr;  /* Internet socket name */
  struct sockaddr unspecAddr; /* used to dissolve the UDP association */
  socklen_t myAddrSize;
  int count = 0;
  long portValue;
  char *end;
  const char *fileName;

  if (argc != 5)
    {
    printf("Usage: mp3_client -dataport <port> -output <output_file>\n");
    exit(EXIT_FAILURE);
    }
  errno = 0;
  portValue = strtol(argv[2], &end, 10);
  if (end == argv[2] || *end != '\0' || errno == ERANGE || portValue < 0 || portValue > 65535)
    {
    printf("Invalid port number: %s\n", argv[2]);
    exit(EXIT_FAILURE);
    }
  printf("\nmp3 Client PID %d\n", (int) getpid());

  fileName = argv[4]; /* used in place: no fixed-size copy to overflow */
  if (!(mp3file = fopen(fileName, "wb"))) /* mp3 data is binary */
    {
    printf("\nFile %s could not be opened for writing!\n", fileName);
    exit(1);
    }
  mySocket = initClient(htons((unsigned short) portValue)); /* create socket */
  signal(SIGINT, myExit);
  signal(SIGPIPE, myExit);

  memset(&unspecAddr, 0, sizeof(unspecAddr));
  unspecAddr.sa_family = AF_UNSPEC;

  while (1) /* infinite loop */
    {
    do
      {
      FD_ZERO(&mySockets);
      FD_SET(mySocket, &mySockets);
      count = select(mySocket + 1, &mySockets, NULL, NULL, NULL);
      }
    while (count <= 0); /* wait for some data */
    printf("\nConnected.\n");

    myAddrSize = sizeof(myFromSockAddr); /* value-result argument: reset before every call */
    count = (int) recvfrom(mySocket, inputBuffer, BUFFERSIZE, 0,
                           (struct sockaddr *) &myFromSockAddr, &myAddrSize);
    if (count > 0)
      {
      doWork(count);
      /* Connect once so that recv only returns datagrams from this sender. */
      if (connect(mySocket, (struct sockaddr *) &myFromSockAddr, myAddrSize) == -1)
        {
        perror("connect");
        count = -1;
        }
      }

    while (count > 0)
      {
      count = (int) recv(mySocket, inputBuffer, BUFFERSIZE, 0);
      if (count > 0)
        {
        doWork(count);
        putchar('*'); /* to signal it's getting something */
        fflush(stdout);
        }
      }

    /* Dissolve the association; otherwise the socket stays bound to the old
       sender and datagrams from any new sender are discarded. */
    (void) connect(mySocket, &unspecAddr, sizeof(unspecAddr));
    printf("\nDisconnected.\n");
    }
  }
