// portforward.cpp : trivial port forwarding for windows
// originally by QuantumG (qg@biodome.org)
// logging added by Scott Dale Robison (scott@mediaport.com) on 22 Oct 2007

#include <winsock2.h>
#include <stdio.h>
#include <windows.h>

// Log file opened by main from argv[4]; xprintf mirrors everything it prints into it.
FILE *log_file = NULL;
// Number of payload bytes shown on each hex-dump row written by dump().
const int bytes_per_line = 0x10;
// Serializes console/log output between the forwarding threads.
CRITICAL_SECTION critical_section;

// printf-style output written both to stdout and to log_file.
// Output is produced while holding critical_section so that messages from
// concurrent threads are not interleaved.
// A va_list may only be traversed once, so the argument list is restarted
// (va_start/va_end) separately for each sink; reusing the first one for
// vfprintf would be undefined behavior.
void xprintf(const char* format, ...)
{
    EnterCriticalSection(&critical_section);

    va_list args;

    va_start(args, format);
    vprintf(format, args);
    va_end(args);

    // log_file is only set once main has opened it; skip it until then.
    if (log_file != 0)
    {
        va_start(args, format);
        vfprintf(log_file, format, args);
        va_end(args);
    }

    LeaveCriticalSection(&critical_section);
}

void dump(const char* label, char* buf, int n)
{
    EnterCriticalSection(&critical_section);

    for (int i = 0; i < n; i += bytes_per_line)
    {
        int bytes = n - i;

        if (bytes > bytes_per_line)
            bytes = bytes_per_line;

        xprintf("%s", label);

        int j;

        for (j = 0; j < bytes; ++j)
            xprintf(" %02X", (unsigned char)buf[i+j]);

        for (; j <= bytes_per_line; ++j)
            xprintf("   ");

        for (j = 0; j < bytes; ++j)
        {
            char c = ' ';
            if ((buf[i+j] > 32) && (buf[i+j] < 127))
                c = buf[i+j];
            xprintf("%c", c);
        }

        xprintf("\n");
    }

    LeaveCriticalSection(&critical_section);
}

DWORD WINAPI reader(LPVOID lpParameter)
{
    SOCKET *socks = (SOCKET*)lpParameter;

    char buf[65536];
    int n;
    while ((n = recv(socks[0], buf, sizeof(buf), 0)) > 0) {
        dump("c -> s:", buf, n);
        send(socks[1], buf, n, 0);
    }

    closesocket(socks[0]);
    closesocket(socks[1]);

    return 0;
}

DWORD WINAPI writer(LPVOID lpParameter)
{
    SOCKET *socks = (SOCKET*)lpParameter;

    char buf[65536];
    int n;
    while ((n = recv(socks[1], buf, sizeof(buf), 0)) > 0) {
        dump("s -> c:", buf, n);
        send(socks[0], buf, n, 0);
    }

    closesocket(socks[0]);
    closesocket(socks[1]);

    return 0;
}

int main(int argc, char* argv[])
{
    if (argc < 5) {
        printf("usage: portforward [port to listen on] [ip of host to connect to] [port to connect to] [log file].\n");
        return 1;
    }

    InitializeCriticalSection(&critical_section);

    log_file = fopen(argv[4], "wt");
    if (log_file == 0)
    {
        printf("cannot create/truncate file %s", argv[4]);
        return 1;
    }

    WORD wVersionRequested;
    WSADATA wsaData;
 
    wVersionRequested = MAKEWORD( 2, 2 );
    WSAStartup( wVersionRequested, &wsaData );

    SOCKET s = socket(AF_INET, SOCK_STREAM, 0);

    sockaddr_in sin;
    sin.sin_addr.S_un.S_addr = INADDR_ANY;
    sin.sin_family = AF_INET;
    sin.sin_port = htons(atoi(argv[1]));
    if (bind(s, (sockaddr*)&sin, sizeof(sin)) != 0) {
        xprintf("cannot bind to port %i\n", atoi(argv[1]));
        return 1;
    }

    listen(s, 5);
    int ss = sizeof(sin);
    SOCKET n;
    while ((n = accept(s, (sockaddr*)&sin, &ss)) != -1) {
        xprintf("received connection from %i.%i.%i.%i\n", sin.sin_addr.S_un.S_un_b.s_b1, sin.sin_addr.S_un.S_un_b.s_b2, sin.sin_addr.S_un.S_un_b.s_b3, sin.sin_addr.S_un.S_un_b.s_b4);
        SOCKET d = socket(AF_INET, SOCK_STREAM, 0);
        sin.sin_family = AF_INET;
        sin.sin_addr.S_un.S_addr = inet_addr(argv[2]);
        sin.sin_port = htons(atoi(argv[3]));
        if (connect(d, (sockaddr*)&sin, sizeof(sin)) != 0) {
            xprintf("received a connection but can't connect to %s:%i\n", argv[2], atoi(argv[3]));
            closesocket(n);
        } else {
            xprintf("connection to %s:%i established\n", argv[2], atoi(argv[3]));
            SOCKET *socks = new SOCKET[2];
            socks[0] = n;
            socks[1] = d;
            DWORD id;
            CreateThread(NULL, 0, reader, socks, 0, &id);
            CreateThread(NULL, 0, writer, socks, 0, &id);
        }
    }
    closesocket(s);

    return 0;
}


