windows-nt/Source/XPSP1/NT/net/tcpip/tpipv6/6to4svc/proxy.c

1407 lines
38 KiB
C
Raw Normal View History

2020-09-26 03:20:57 -05:00
//
// proxy.c - Generic application level proxy for IPv6/IPv4
//
// This program accepts TCP connections on one socket and port, and
// forwards data between it and another socket to a given address
// (default loopback) and port (default same as listening port).
//
// For example, it can make an unmodified IPv4 server look like an IPv6 server.
// Typically, the proxy will run on the same machine as
// the server it is fronting, but that doesn't have to be the case.
//
// Copyright (C) Microsoft Corporation.
// All rights reserved.
//
// History:
// Original code by Brian Zill.
// Made into a service by Dave Thaler.
//
#include "precomp.h"
#pragma hdrstop
//
// Configuration parameters.
//
// Size in bytes of each per-direction data buffer. Two of these are
// allocated directly behind every CONNECTION_INFO.
#define BUFFER_SIZE (4 * 1024)
//
// Kind of asynchronous operation currently posted on an OVERLAPPED.
// The values are used as indices into OperationName[], so the order of
// the enumerators and of that table must stay in sync.
//
typedef enum {
Connect,
Accept,
Receive,
Send
} OPERATION;
//
// Printable names for tracing, indexed by OPERATION.
//
CONST CHAR *OperationName[]={
"Connect",
"Accept",
"Receive",
"Send"
};
//
// Direction of data flow through the proxy. A DIRECTION value is also
// used to index CONNECTION_INFO.Socket[] (the receive side is
// Socket[Direction], the send side is Socket[1 - Direction]), so
// Inbound must equal Client and Outbound must equal Server.
//
typedef enum {
Inbound = 0, // Receive from client, send to server.
Outbound, // Receive from server, send to client.
NumDirections
} DIRECTION;
//
// Index into CONNECTION_INFO.Socket[]: the accepted client socket and
// the socket connected to the real server.
//
typedef enum {
Client = 0,
Server,
NumSides
} SIDE;
//
// Information we keep for each port we're proxying on.
//
//
// Space reserved for one address in the AcceptEx output buffer: a
// SOCKADDR_IN6 plus 16 bytes (the Winsock AcceptEx documentation asks
// for 16 bytes more than the largest address of the transport).
//
#define ADDR_BUFF_LEN (16+sizeof(SOCKADDR_IN6))
typedef struct _PORT_INFO {
// Links the port into g_GlobalPortList, or into the temporary list
// built while reading new configuration.
LIST_ENTRY Link;
// One reference from StartUpPort plus one per outstanding accept.
ULONG ReferenceCount;
// Socket we listen on for clients.
SOCKET ListenSocket;
// Socket handed to the currently posted AcceptEx.
SOCKET AcceptSocket;
// Receives the local and remote addresses from AcceptEx.
BYTE AcceptBuffer[ADDR_BUFF_LEN*2];
// TpProcessWorkItem reads the OPERATION stored immediately after the
// OVERLAPPED it is given, so Operation must directly follow Overlapped.
WSAOVERLAPPED Overlapped;
OPERATION Operation;
// Address and port we listen on.
SOCKADDR_STORAGE LocalAddress;
ULONG LocalAddressLength;
// Address and port of the real server we connect to.
SOCKADDR_STORAGE RemoteAddress;
ULONG RemoteAddressLength;
//
// A lock protects the connection list for this port.
//
CRITICAL_SECTION Lock;
LIST_ENTRY ConnectionHead;
} PORT_INFO, *PPORT_INFO;
//
// Information we keep for each direction of a bi-directional connection.
//
typedef struct _DIRECTION_INFO {
// Data buffer for this direction (buf/len are set when the connection
// is created).
WSABUF Buffer;
// TpProcessWorkItem reads the OPERATION stored immediately after the
// OVERLAPPED it is given, so Operation must directly follow Overlapped.
WSAOVERLAPPED Overlapped;
OPERATION Operation;
// Back pointer to the connection that embeds this structure.
struct _CONNECTION_INFO *Connection;
// Which slot of Connection->DirectionInfo[] this is.
DIRECTION Direction;
} DIRECTION_INFO, *PDIRECTION_INFO;
//
// Information we keep for each client connection.
//
typedef struct _CONNECTION_INFO {
// Links the connection into Port->ConnectionHead.
LIST_ENTRY Link;
// One reference taken in NewConnection plus one per outstanding
// connect/receive/send operation.
ULONG ReferenceCount;
// Port this connection was accepted on.
PPORT_INFO Port;
BOOL HalfOpen; // Has one side or the other stopped sending?
// Set to TRUE (via InterlockedExchange) by the first CloseConnection
// call so that the teardown runs only once.
LONG Closing;
// Indexed by SIDE.
SOCKET Socket[NumSides];
// Indexed by DIRECTION.
DIRECTION_INFO DirectionInfo[NumDirections];
} CONNECTION_INFO, *PCONNECTION_INFO;
//
// Global variables.
//
// List of PORT_INFO structures for the ports currently being proxied.
LIST_ENTRY g_GlobalPortList;
// ConnectEx extension function; looked up through WSAIoctl the first
// time StartConnect needs it.
LPFN_CONNECTEX ConnectEx = NULL;
//
// Function prototypes.
//
// Forward declarations: the Start* routines call the error handlers
// directly when an operation fails synchronously, and the completion
// callback is bound to sockets before it is defined.
VOID
ProcessReceiveError(
IN ULONG NumBytes,
IN LPOVERLAPPED Overlapped,
IN ULONG Status
);
VOID
ProcessSendError(
IN ULONG NumBytes,
IN LPOVERLAPPED Overlapped,
IN ULONG Status
);
VOID
ProcessAcceptError(
IN ULONG NumBytes,
IN LPOVERLAPPED Overlapped,
IN ULONG Status
);
VOID
ProcessConnectError(
IN ULONG NumBytes,
IN LPOVERLAPPED Overlapped,
IN ULONG Status
);
VOID APIENTRY
TpProcessWorkItem(
IN ULONG Status,
IN ULONG NumBytes,
IN LPOVERLAPPED Overlapped
);
//
// Inline functions.
//
__inline
ReferenceConnection(
IN PCONNECTION_INFO Connection
)
{
InterlockedIncrement(&Connection->ReferenceCount);
}
__inline
DereferenceConnection(
IN OUT PCONNECTION_INFO *ConnectionPtr
)
{
ULONG Value;
Value = InterlockedDecrement(&(*ConnectionPtr)->ReferenceCount);
if (Value == 0) {
FREE(*ConnectionPtr);
*ConnectionPtr = NULL;
}
}
//
// Take one reference on a port.
//
__inline VOID
ReferencePort(IN PPORT_INFO Port)
{
    (VOID)InterlockedIncrement(&Port->ReferenceCount);
}
//
// Drop one reference on a port. Whoever releases the last reference
// deletes the port's lock, frees the port, and gets its pointer set
// to NULL.
//
__inline VOID
DereferencePort(IN OUT PPORT_INFO *PortPtr)
{
    ULONG Remaining = InterlockedDecrement(&(*PortPtr)->ReferenceCount);

    if (Remaining != 0) {
        return;
    }

    DeleteCriticalSection(&(*PortPtr)->Lock);
    FREE(*PortPtr);
    *PortPtr = NULL;
}
//
// Allocate and initialize state for a new client connection.
//
//
// Allocate and initialize the state for a freshly accepted client.
//
// ClientSocket  - the accepted client socket; stored in Socket[Client].
// ConnectFamily - kept for interface compatibility only. The server
//                 socket is created by StartConnect, so creating one
//                 here as well (as the original did) leaked a socket
//                 for every connection.
//
// Returns the new connection holding one reference, or NULL if the
// allocation fails.
//
PCONNECTION_INFO
NewConnection(
    IN SOCKET ClientSocket,
    IN ULONG ConnectFamily
    )
{
    PCONNECTION_INFO Connection;
    PDIRECTION_INFO In, Out;

    (void)ConnectFamily;

    //
    // Allocate space for a CONNECTION_INFO structure and two buffers.
    //
    Connection = (CONNECTION_INFO *)MALLOC(sizeof(*Connection) + (2 * BUFFER_SIZE));
    if (Connection == NULL) {
        return NULL;
    }

    //
    // MALLOC does not clear memory. Zero the header so that the
    // OVERLAPPED structures, Link and Port start in a known state.
    //
    ZeroMemory(Connection, sizeof(*Connection));

    Connection->HalfOpen = FALSE;
    Connection->Closing = FALSE;
    Connection->Socket[Client] = ClientSocket;
    Connection->Socket[Server] = INVALID_SOCKET;   // Created by StartConnect.

    //
    // Both directions start out receiving. The inbound buffer sits right
    // behind the structure and the outbound buffer follows it.
    //
    In = &Connection->DirectionInfo[Inbound];
    In->Direction = Inbound;
    In->Operation = Receive;
    In->Buffer.len = BUFFER_SIZE;
    In->Buffer.buf = (char *)(Connection + 1);
    In->Connection = Connection;

    Out = &Connection->DirectionInfo[Outbound];
    Out->Direction = Outbound;
    Out->Operation = Receive;
    Out->Buffer.len = BUFFER_SIZE;
    Out->Buffer.buf = In->Buffer.buf + BUFFER_SIZE;
    Out->Connection = Connection;

    //
    // The caller owns the initial reference.
    //
    Connection->ReferenceCount = 0;
    ReferenceConnection(Connection);
    Trace2(FSM, _T("R++ %d %x NewConnection"), Connection->ReferenceCount, Connection);
    return Connection;
}
//
// Create state information for a client.
//
//
// Build connection state around an existing client/server socket pair.
//
// Returns the new CONNECTION_INFO (followed in memory by the two data
// buffers), or NULL if the allocation fails. Only HalfOpen, the two
// sockets and the two DIRECTION_INFO entries are filled in here;
// Closing, ReferenceCount, Port, Link and the OVERLAPPEDs are left as
// allocated.
//
CONNECTION_INFO *
CreateConnectionState(
    IN SOCKET ClientSocket,
    IN SOCKET ServerSocket
    )
{
    CONNECTION_INFO *State;
    DIRECTION_INFO *In, *Out;
    char *Buffers;

    State = (CONNECTION_INFO *)MALLOC(sizeof(*State) + (2 * BUFFER_SIZE));
    if (State == NULL) {
        return NULL;
    }

    // The two buffers live directly behind the structure.
    Buffers = (char *)(State + 1);
    In = &State->DirectionInfo[Inbound];
    Out = &State->DirectionInfo[Outbound];

    State->HalfOpen = FALSE;
    State->Socket[Client] = ClientSocket;
    State->Socket[Server] = ServerSocket;

    // Both directions begin in the receiving state.
    In->Direction = Inbound;
    In->Operation = Receive;
    In->Buffer.len = BUFFER_SIZE;
    In->Buffer.buf = Buffers;
    In->Connection = State;

    Out->Direction = Outbound;
    Out->Operation = Receive;
    Out->Buffer.len = BUFFER_SIZE;
    Out->Buffer.buf = Buffers + BUFFER_SIZE;
    Out->Connection = State;

    return State;
}
//
// Start an asynchronous accept.
//
// Assumes caller holds a reference on Port.
//
//
// Post an asynchronous AcceptEx on the port's listening socket.
//
// A port reference is taken for the operation; it is released by
// ProcessAccept or ProcessAcceptError. Returns NO_ERROR when the
// accept completed or is pending, otherwise the Winsock error (after
// ProcessAcceptError has already been run for it).
//
DWORD
StartAccept(
    IN PPORT_INFO Port
    )
{
    ULONG Error, BytesIgnored;
    BOOL Completed;

    ASSERT(Port->ReferenceCount > 0);

    // Reference held by the outstanding accept.
    ReferencePort(Port);

    Port->AcceptSocket = socket(Port->LocalAddress.ss_family, SOCK_STREAM, 0);
    if (Port->AcceptSocket == INVALID_SOCKET) {
        Error = WSAGetLastError();
        ProcessAcceptError(0, &Port->Overlapped, Error);
        return Error;
    }

    Trace2(SOCKET, _T("Starting an accept with new socket %x ovl %p"),
           Port->AcceptSocket, &Port->Overlapped);

    Port->Overlapped.hEvent = NULL;
    Port->Operation = Accept;

    // No receive data is requested; the buffer only receives addresses.
    Completed = AcceptEx(Port->ListenSocket,
                         Port->AcceptSocket,
                         Port->AcceptBuffer,
                         0,
                         ADDR_BUFF_LEN,
                         ADDR_BUFF_LEN,
                         &BytesIgnored,
                         &Port->Overlapped);
    if (Completed) {
        return NO_ERROR;
    }

    Error = WSAGetLastError();
    if (Error == ERROR_IO_PENDING) {
        return NO_ERROR;
    }

    ProcessAcceptError(0, &Port->Overlapped, Error);
    return Error;
}
//
// Start an asynchronous connect.
//
// Assumes caller holds a reference on Connection.
//
//
// Create the server-side socket and post an asynchronous ConnectEx to
// the port's remote address.
//
// A connection reference is taken for the operation; it is released by
// ProcessConnect or ProcessConnectError. Returns NO_ERROR when the
// connect completed or is pending, otherwise the error code (after
// ProcessConnectError has already been run for it).
//
// Fix: a failure to look up ConnectEx used to call ProcessConnectError
// and then fall through, calling the NULL ConnectEx pointer and
// releasing the operation's reference a second time. Every failure now
// leaves through a single exit.
//
DWORD
StartConnect(
    IN PCONNECTION_INFO Connection,
    IN PPORT_INFO Port
    )
{
    ULONG Status, Junk;
    SOCKADDR_STORAGE LocalAddress;

    //
    // Count a reference for the operation.
    //
    ReferenceConnection(Connection);
    Trace2(FSM, _T("R++ %d %x StartConnect"), Connection->ReferenceCount, Connection);

    Connection->Socket[Server] = socket(Port->RemoteAddress.ss_family,
                                        SOCK_STREAM, 0);
    if (Connection->Socket[Server] == INVALID_SOCKET) {
        Status = WSAGetLastError();
        goto Fail;
    }

    Connection->DirectionInfo[Inbound].Overlapped.hEvent = NULL;
    Connection->DirectionInfo[Outbound].Overlapped.hEvent = NULL;

    //
    // ConnectEx requires a bound socket. Bind to the wildcard address
    // of the remote family. The whole structure is cleared rather than
    // only RemoteAddressLength bytes.
    //
    ZeroMemory(&LocalAddress, sizeof(LocalAddress));
    LocalAddress.ss_family = Port->RemoteAddress.ss_family;
    if (bind(Connection->Socket[Server], (LPSOCKADDR)&LocalAddress,
             Port->RemoteAddressLength) == SOCKET_ERROR) {
        Status = WSAGetLastError();
        goto Fail;
    }

    if (!BindIoCompletionCallback((HANDLE)Connection->Socket[Server],
                                  TpProcessWorkItem,
                                  0)) {
        Status = GetLastError();
        goto Fail;
    }

    //
    // Look up the ConnectEx extension function the first time through.
    //
    if (ConnectEx == NULL) {
        GUID Guid = WSAID_CONNECTEX;

        if (WSAIoctl(Connection->Socket[Server],
                     SIO_GET_EXTENSION_FUNCTION_POINTER,
                     &Guid,
                     sizeof(Guid),
                     &ConnectEx,
                     sizeof(ConnectEx),
                     &Junk,
                     NULL, NULL) == SOCKET_ERROR) {
            Status = WSAGetLastError();
            goto Fail;
        }
    }

    Trace2(SOCKET,
           _T("Starting a connect with socket %x ovl %p"),
           Connection->Socket[Server],
           &Connection->DirectionInfo[Inbound].Overlapped);

    //
    // The connect borrows the inbound direction's OVERLAPPED.
    //
    Connection->DirectionInfo[Inbound].Operation = Connect;
    if (!ConnectEx(Connection->Socket[Server],
                   (LPSOCKADDR)&Port->RemoteAddress,
                   Port->RemoteAddressLength,
                   NULL, 0,
                   &Junk,
                   &Connection->DirectionInfo[Inbound].Overlapped)) {
        Status = WSAGetLastError();
        if (Status != ERROR_IO_PENDING) {
            goto Fail;
        }
    }
    return NO_ERROR;

Fail:
    //
    // Closes the connection and releases the operation's reference.
    //
    ProcessConnectError(0, &Connection->DirectionInfo[Inbound].Overlapped,
                        Status);
    return Status;
}
//
// Start an asynchronous receive.
//
// Assumes caller holds a reference on the connection that owns DirectionInfo.
//
//
// Post an asynchronous read on the socket that feeds this direction
// (Socket[Direction]) into the direction's buffer.
//
// A connection reference is taken for the operation. If the read fails
// synchronously, ProcessReceiveError is run for it right away.
//
VOID
StartReceive(
    IN PDIRECTION_INFO DirectionInfo
    )
{
    ULONG BytesRcvd, Status;
    PCONNECTION_INFO Connection;
    BOOL Completed;

    Connection = CONTAINING_RECORD(DirectionInfo, CONNECTION_INFO,
                                   DirectionInfo[DirectionInfo->Direction]);

    Trace3(SOCKET, _T("starting ReadFile on socket %x with Dir %p ovl %p"),
           Connection->Socket[DirectionInfo->Direction], DirectionInfo,
           &DirectionInfo->Overlapped);

    // Reference held by the outstanding receive.
    ReferenceConnection(Connection);
    Trace2(FSM, _T("R++ %d %x StartReceive"), Connection->ReferenceCount, Connection);

    ASSERT(DirectionInfo->Overlapped.hEvent == NULL);
    ASSERT(DirectionInfo->Buffer.len > 0);
    ASSERT(DirectionInfo->Buffer.buf != NULL);

    DirectionInfo->Operation = Receive;

    Trace5(SOCKET, _T("ReadFile %x %p %d %p %p"),
           Connection->Socket[DirectionInfo->Direction],
           &DirectionInfo->Buffer.buf,
           DirectionInfo->Buffer.len,
           &BytesRcvd,
           &DirectionInfo->Overlapped);

    Completed = ReadFile((HANDLE)Connection->Socket[DirectionInfo->Direction],
                         DirectionInfo->Buffer.buf,
                         DirectionInfo->Buffer.len,
                         &BytesRcvd,
                         &DirectionInfo->Overlapped);
    if (Completed) {
        return;
    }

    // A pending read is the normal case; anything else is a failure.
    Status = GetLastError();
    if (Status != ERROR_IO_PENDING) {
        ProcessReceiveError(0, &DirectionInfo->Overlapped, Status);
    }
}
//
// Start an asynchronous send.
//
// Assumes caller holds a reference on the connection that owns DirectionInfo.
//
//
// Post an asynchronous write of the first NumBytes bytes of this
// direction's buffer to the opposite socket (Socket[1 - Direction]).
//
// A connection reference is taken for the operation. If the write
// fails synchronously, ProcessSendError is run for it right away.
//
VOID
StartSend(
    IN PDIRECTION_INFO DirectionInfo,
    IN ULONG NumBytes
    )
{
    ULONG BytesSent, Status;
    PCONNECTION_INFO Connection;
    BOOL Completed;

    Connection = CONTAINING_RECORD(DirectionInfo, CONNECTION_INFO,
                                   DirectionInfo[DirectionInfo->Direction]);

    Trace3(SOCKET, _T("starting WriteFile on socket %x with Dir %p ovl %p"),
           Connection->Socket[1 - DirectionInfo->Direction], DirectionInfo,
           &DirectionInfo->Overlapped);

    // Reference held by the outstanding send.
    ReferenceConnection(Connection);
    Trace2(FSM, _T("R++ %d %x StartSend"), Connection->ReferenceCount, Connection);

    DirectionInfo->Operation = Send;

    Completed = WriteFile((HANDLE)Connection->Socket[1 - DirectionInfo->Direction],
                          DirectionInfo->Buffer.buf,
                          NumBytes,
                          &BytesSent,
                          &DirectionInfo->Overlapped);
    if (Completed) {
        return;
    }

    // A pending write is the normal case; anything else is a failure.
    Status = GetLastError();
    if (Status != ERROR_IO_PENDING) {
        Trace1(ERR, _T("WriteFile 1 failed %d"), Status);
        ProcessSendError(0, &DirectionInfo->Overlapped, Status);
    }
}
//
// This gets called when we want to start proxying for a new port.
//
//
// Begin proxying on a port: take the port's initial reference, set up
// its lock and connection list, create and bind the listening socket,
// and post the first accept.
//
// Returns NO_ERROR on success or the error code of the step that
// failed. On failure the listening socket is closed and set to
// INVALID_SOCKET; the initial reference stays with the port and is
// released by ShutDownPort.
//
// Fix: the failure path used to return WSAGetLastError() after
// closesocket() had already run, so the original error could be lost
// (and a failure reported as 0). The error is now captured at the
// point of failure, using GetLastError for BindIoCompletionCallback.
//
DWORD
StartUpPort(
    IN PPORT_INFO Port
    )
{
    ULONG Status = NO_ERROR;
    CHAR LocalBuffer[256];
    CHAR RemoteBuffer[256];
    ULONG Length;

    //
    // Add an initial reference.
    //
    ReferencePort(Port);
    InitializeCriticalSection(&Port->Lock);
    InitializeListHead(&Port->ConnectionHead);

    Port->ListenSocket = socket(Port->LocalAddress.ss_family, SOCK_STREAM, 0);
    if (Port->ListenSocket == INVALID_SOCKET) {
        Status = WSAGetLastError();
        Trace1(ERR, _T("socket() failed with error %u"), Status);
        return Status;
    }

    if (bind(Port->ListenSocket, (LPSOCKADDR)&Port->LocalAddress,
             Port->LocalAddressLength) == SOCKET_ERROR) {
        Status = WSAGetLastError();
        Trace1(ERR, _T("bind() failed with error %u"), Status);
        goto Fail;
    }

    if (listen(Port->ListenSocket, 5) == SOCKET_ERROR) {
        Status = WSAGetLastError();
        Trace1(ERR, _T("listen() failed with error %u"), Status);
        goto Fail;
    }

    if (!BindIoCompletionCallback((HANDLE)Port->ListenSocket,
                                  TpProcessWorkItem,
                                  0)) {
        Status = GetLastError();
        Trace1(ERR, _T("BindIoCompletionCallback() failed with error %u"),
               Status);
        goto Fail;
    }

    //
    // Trace the mapping. The buffers are pre-terminated so a failed
    // conversion just prints an empty string.
    //
    Length = sizeof(LocalBuffer);
    LocalBuffer[0] = '\0';
    WSAAddressToStringA((LPSOCKADDR)&Port->LocalAddress,
                        Port->LocalAddressLength, NULL, LocalBuffer, &Length);
    Length = sizeof(RemoteBuffer);
    RemoteBuffer[0] = '\0';
    WSAAddressToStringA((LPSOCKADDR)&Port->RemoteAddress,
                        Port->RemoteAddressLength, NULL, RemoteBuffer, &Length);
    Trace2(FSM, _T("Proxying %hs to %hs"), LocalBuffer, RemoteBuffer);

    //
    // Start an asynchronous accept
    //
    return StartAccept(Port);

Fail:
    closesocket(Port->ListenSocket);
    Port->ListenSocket = INVALID_SOCKET;
    return Status;
}
//
// Tear down a connection: close both sockets, unlink the connection
// from its port's list (under the port lock) and drop one reference.
//
// Only the first caller does the work; the Closing flag is swapped
// atomically and later callers return immediately. The caller's
// pointer is set to NULL if the reference dropped here was the last.
//
VOID
CloseConnection(
    IN OUT PCONNECTION_INFO *ConnectionPtr
    )
{
    PCONNECTION_INFO Conn = *ConnectionPtr;
    PPORT_INFO OwningPort = Conn->Port;
    LONG AlreadyClosing;

    AlreadyClosing = InterlockedExchange(&Conn->Closing, TRUE);
    if (AlreadyClosing == FALSE) {
        Trace2(SOCKET, _T("Closing client socket %x and server socket %x"),
               Conn->Socket[Client],
               Conn->Socket[Server]);

        closesocket(Conn->Socket[Client]);
        closesocket(Conn->Socket[Server]);

        EnterCriticalSection(&OwningPort->Lock);
        RemoveEntryList(&Conn->Link);
        LeaveCriticalSection(&OwningPort->Lock);

        Trace2(FSM, _T("R-- %d %x CloseConnection"),
               Conn->ReferenceCount, Conn);
        DereferenceConnection(ConnectionPtr);
    }
}
//
// This gets called when we want to stop proxying for a given port.
//
//
// Stop proxying on a port: close every connection on its list, close
// the listening socket, and release the reference StartUpPort took.
// The caller's pointer is set to NULL if that was the last reference.
//
// NOTE(review): CloseConnection returns without unlinking when the
// connection's Closing flag is already set. If another thread has set
// the flag but is still waiting for Port->Lock, the loop below would
// keep seeing the same head entry - confirm whether that can happen.
//
VOID
ShutDownPort(
    IN PPORT_INFO *PortPtr
    )
{
    PPORT_INFO Port = *PortPtr;
    PCONNECTION_INFO Victim;
    PLIST_ENTRY First;

    EnterCriticalSection(&Port->Lock);
    for (;;) {
        if (IsListEmpty(&Port->ConnectionHead)) {
            break;
        }
        // CloseConnection unlinks the entry, so always take the head.
        First = Port->ConnectionHead.Flink;
        Victim = CONTAINING_RECORD(First, CONNECTION_INFO, Link);
        CloseConnection(&Victim);
    }
    LeaveCriticalSection(&Port->Lock);

    closesocket(Port->ListenSocket);
    Port->ListenSocket = INVALID_SOCKET;

    Trace1(FSM, _T("Shut down port %u"), RtlUshortByteSwap(SS_PORT(&Port->RemoteAddress)));

    // Drop the reference that StartUpPort added.
    DereferencePort(PortPtr);
}
//
// The four kinds of proxying, named <listen family>TO<connect family>.
// The values index PpTypeInfo[], so the enumerators and the table rows
// must stay in the same order.
//
typedef enum {
V4TOV4,
V4TOV6,
V6TOV4,
V6TOV6
} PPTYPE, *PPPTYPE;
//
// Per-type settings: the address family to listen on, the address
// family to connect with, and the name of the registry subkey (under
// KEY_PORTS) that holds the entries of this type.
//
typedef struct {
ULONG ListenFamily;
ULONG ConnectFamily;
PWCHAR KeyString;
} PPTYPEINFO, *PPPTYPEINFO;
#define KEY_V4TOV4 L"v4tov4"
#define KEY_V4TOV6 L"v4tov6"
#define KEY_V6TOV4 L"v6tov4"
#define KEY_V6TOV6 L"v6tov6"
// Registry key (opened under HKEY_LOCAL_MACHINE) holding the configuration.
#define KEY_PORTS L"System\\CurrentControlSet\\Services\\PortProxy"
// Indexed by PPTYPE.
PPTYPEINFO PpTypeInfo[] = {
{ AF_INET, AF_INET, KEY_V4TOV4 },
{ AF_INET, AF_INET6, KEY_V4TOV6 },
{ AF_INET6, AF_INET, KEY_V6TOV4 },
{ AF_INET6, AF_INET6, KEY_V6TOV6 },
};
//
// Given new configuration data, make any changes needed.
//
//
// Reconcile the running set of proxy ports (g_GlobalPortList) with a
// freshly read configuration.
//
// pNewList - list of PORT_INFO entries describing the wanted state.
//            Entries for ports that are not yet running are moved from
//            this list onto g_GlobalPortList and started; entries that
//            match a running port are left on the list for the caller
//            to free.
//
// A port is identified by its local (listening) port number.
//
// Fix: the first pass used to compare the *remote* port numbers, while
// the second pass compares the local ones. A changed listen port with
// an unchanged connect port therefore left the old listener running,
// and a changed connect port needlessly tore down a listener that the
// second pass could have updated in place. Both passes now use the
// local port.
//
VOID
ApplyNewPortList(
    IN OUT PLIST_ENTRY pNewList
    )
{
    PPORT_INFO Port, pCurr;
    PLIST_ENTRY pleCurr, plePort, pleNext;

    //
    // Pass 1: shut down running ports that are no longer configured.
    // The next link is saved first because the entry may be removed.
    //
    for (pleCurr = g_GlobalPortList.Flink; pleCurr != &g_GlobalPortList; pleCurr = pleNext) {
        pleNext = pleCurr->Flink;
        pCurr = CONTAINING_RECORD(pleCurr, PORT_INFO, Link);

        for (plePort = pNewList->Flink; plePort != pNewList; plePort = plePort->Flink) {
            Port = CONTAINING_RECORD(plePort, PORT_INFO, Link);
            if (SS_PORT(&Port->LocalAddress) == SS_PORT(&pCurr->LocalAddress)) {
                break;
            }
        }

        if (plePort == pNewList) {
            //
            // Shut down an old proxy port.
            //
            RemoveEntryList(pleCurr);
            ShutDownPort(&pCurr);
        }
    }

    //
    // Pass 2: update the remote address of ports that are already
    // running and start the ones that are new.
    //
    for (plePort = pNewList->Flink; plePort != pNewList; plePort = pleNext) {
        pleNext = plePort->Flink;
        Port = CONTAINING_RECORD(plePort, PORT_INFO, Link);

        for (pleCurr = g_GlobalPortList.Flink; pleCurr != &g_GlobalPortList; pleCurr = pleCurr->Flink) {
            pCurr = CONTAINING_RECORD(pleCurr, PORT_INFO, Link);
            if (SS_PORT(&Port->LocalAddress) == SS_PORT(&pCurr->LocalAddress)) {
                //
                // Update remote address.
                //
                pCurr->RemoteAddress = Port->RemoteAddress;
                pCurr->RemoteAddressLength = Port->RemoteAddressLength;
                break;
            }
        }

        if (pleCurr == &g_GlobalPortList) {
            //
            // Start up a new proxy port. Ownership of the entry moves
            // to the global list.
            //
            RemoveEntryList(plePort);
            InsertTailList(&g_GlobalPortList, plePort);
            StartUpPort(Port);
        }
    }
}
//
// Reads from the registry one type of proxying (e.g., v6-to-v4).
//
//
// Read one type of proxying (e.g., v6-to-v4) from the registry and
// append a PORT_INFO for every usable entry to Head.
//
// Head   - list that receives the new, zero-initialized PORT_INFO
//          entries (only the addresses and lengths are filled in).
// hPorts - open handle to the KEY_PORTS registry key.
// Type   - which PpTypeInfo row (subkey and address families) to read.
//
// Each value under <type>\tcp has the listen "address/port" (or just
// "port", or "*/port") as its name and the connect "address/port" (or
// just "address") as its REG_SZ data. Entries that cannot be resolved
// are skipped.
//
// Fixes:
//  - Registry string data is not guaranteed to be NUL-terminated. One
//    WCHAR is now held back from the data buffer and the string is
//    terminated explicitly before it is searched.
//  - A single value too large for the buffers (ERROR_MORE_DATA) used
//    to end the enumeration and hide every later entry; it is now
//    skipped.
//
VOID
AppendType(
    IN PLIST_ENTRY Head,
    IN HKEY hPorts,
    IN PPTYPE Type
    )
{
    ADDRINFO ListenHints, ConnectHints;
    ADDRINFO *LocalAi, *RemoteAi;
    ULONG ListenChars, dwType, ConnectBytes, i;
    WCHAR ListenBuffer[256], *ListenAddress, *ListenPort;
    WCHAR ConnectAddress[256], *ConnectPort;
    PPORT_INFO Port;
    ULONG Status;
    HKEY hType, hProto;

    ZeroMemory(&ListenHints, sizeof(ListenHints));
    ListenHints.ai_family = PpTypeInfo[Type].ListenFamily;
    ListenHints.ai_socktype = SOCK_STREAM;
    ListenHints.ai_flags = AI_PASSIVE;

    ZeroMemory(&ConnectHints, sizeof(ConnectHints));
    ConnectHints.ai_family = PpTypeInfo[Type].ConnectFamily;
    ConnectHints.ai_socktype = SOCK_STREAM;

    Status = RegOpenKeyExW(hPorts, PpTypeInfo[Type].KeyString, 0,
                           KEY_QUERY_VALUE, &hType);
    if (Status != NO_ERROR) {
        return;
    }

    Status = RegOpenKeyExW(hType, L"tcp", 0, KEY_QUERY_VALUE, &hProto);
    if (Status != NO_ERROR) {
        RegCloseKey(hType);
        return;
    }

    for (i=0; ; i++) {
        ListenChars = sizeof(ListenBuffer)/sizeof(WCHAR);

        //
        // Hold back one WCHAR so there is always room for a terminator.
        //
        ConnectBytes = sizeof(ConnectAddress) - sizeof(WCHAR);

        Status = RegEnumValueW(hProto, i, ListenBuffer, &ListenChars, NULL,
                               &dwType, (PVOID)ConnectAddress, &ConnectBytes);
        if (Status == ERROR_MORE_DATA) {
            //
            // This entry does not fit; skip it but keep enumerating.
            //
            continue;
        }
        if (Status != NO_ERROR) {
            break;
        }
        if (dwType != REG_SZ) {
            continue;
        }
        ConnectAddress[ConnectBytes / sizeof(WCHAR)] = L'\0';

        ListenPort = wcschr(ListenBuffer, L'/');
        if (ListenPort) {
            //
            // Replace slash with NULL, so we have 2 strings to pass
            // to getaddrinfo. A leading '*' means "any address".
            //
            if (ListenBuffer[0] == '*') {
                ListenAddress = NULL;
            } else {
                ListenAddress = ListenBuffer;
            }
            *ListenPort++ = '\0';
        } else {
            //
            // If the value name didn't include a listen address,
            // use NULL (any address); the whole name is the port.
            //
            ListenAddress = NULL;
            ListenPort = ListenBuffer;
        }

        ConnectPort = wcschr(ConnectAddress, '/');
        if (ConnectPort) {
            //
            // Replace slash with NULL, so we have 2 strings to pass
            // to getaddrinfo.
            //
            *ConnectPort++ = '\0';
        } else {
            //
            // If the address data didn't include a remote port number,
            // use the same port as the local port number.
            //
            ConnectPort = ListenPort;
        }

        Status = GetAddrInfoW(ConnectAddress, ConnectPort, &ConnectHints,
                              &RemoteAi);
        if (Status != NO_ERROR) {
            continue;
        }

        Status = GetAddrInfoW(ListenAddress, ListenPort, &ListenHints,
                              &LocalAi);
        if (Status != NO_ERROR) {
            freeaddrinfo(RemoteAi);
            continue;
        }

        //
        // Record the first address of each result. An allocation
        // failure simply drops this entry.
        //
        Port = MALLOC(sizeof(PORT_INFO));
        if (Port) {
            ZeroMemory(Port, sizeof(PORT_INFO));
            InsertTailList(Head, &Port->Link);
            memcpy(&Port->RemoteAddress, RemoteAi->ai_addr, RemoteAi->ai_addrlen);
            Port->RemoteAddressLength = (ULONG)RemoteAi->ai_addrlen;
            memcpy(&Port->LocalAddress, LocalAi->ai_addr, LocalAi->ai_addrlen);
            Port->LocalAddressLength = (ULONG)LocalAi->ai_addrlen;
        }

        freeaddrinfo(RemoteAi);
        freeaddrinfo(LocalAi);
    }

    RegCloseKey(hProto);
    RegCloseKey(hType);
}
//
// Read new configuration data from the registry and see what's changed.
//
//
// Read the proxy configuration from the registry and apply whatever
// has changed relative to the ports that are currently running.
//
// Unused - ignored; present so the routine can be queued as a work item.
//
// Fix: the result of opening KEY_PORTS used to be ignored, so when the
// key was missing an uninitialized handle was handed to AppendType and
// then to RegCloseKey. A failed open is now treated as an empty
// configuration, which shuts down any running ports.
//
VOID
UpdateGlobalPortState(
    IN PVOID Unused
    )
{
    LIST_ENTRY PortHead;
    HKEY hPorts;
    ULONG Status = NO_ERROR;
    PLIST_ENTRY ple;

    InitializeListHead(&PortHead);

    //
    // Read new port list from registry and initialize per-port proxy state.
    //
    Status = RegOpenKeyExW(HKEY_LOCAL_MACHINE, KEY_PORTS, 0, KEY_QUERY_VALUE,
                           &hPorts);
    if (Status == NO_ERROR) {
        AppendType(&PortHead, hPorts, V4TOV4);
        AppendType(&PortHead, hPorts, V4TOV6);
        AppendType(&PortHead, hPorts, V6TOV4);
        AppendType(&PortHead, hPorts, V6TOV6);
        RegCloseKey(hPorts);
    }

    ApplyNewPortList(&PortHead);

    //
    // Free new port list. Whatever is still on it was not taken over
    // by ApplyNewPortList and was never started.
    //
    while (!IsListEmpty(&PortHead)) {
        ple = PortHead.Flink;
        RemoveEntryList(ple);
        FREE(CONTAINING_RECORD(ple, PORT_INFO, Link));
    }
}
//
// Force UpdateGlobalPortState to be executed in a persistent thread,
// since we need to make sure that the asynchronous IO routines are
// started in a thread that won't go away before the operation completes.
//
//
// Queue UpdateGlobalPortState to run on a persistent thread-pool thread.
//
// Unused - passed through to UpdateGlobalPortState, which ignores it.
//
// Returns TRUE if the work item was queued, FALSE otherwise.
//
// Fix: QueueUserWorkItem returns a BOOL, not an NTSTATUS. Running the
// result through NT_SUCCESS (status >= 0) turned FALSE into success, so
// a failure to queue was never reported.
//
BOOL
QueueUpdateGlobalPortState(
    IN PVOID Unused
    )
{
    return QueueUserWorkItem(
               (LPTHREAD_START_ROUTINE)UpdateGlobalPortState,
               (PVOID)Unused,
               WT_EXECUTEINPERSISTENTTHREAD);
}
//
// Set the global port list to the empty state (head linked to itself).
//
VOID
InitializePorts(VOID)
{
    g_GlobalPortList.Flink = g_GlobalPortList.Blink = &g_GlobalPortList;
}
//
// Shut down every running proxy port by applying an empty
// configuration. Does nothing if the global list was never initialized.
//
VOID
UninitializePorts(VOID)
{
    LIST_ENTRY NoPorts;

    // A NULL forward link means InitializePorts has not run.
    if (g_GlobalPortList.Flink != NULL) {
        InitializeListHead(&NoPorts);
        ApplyNewPortList(&NoPorts);
    }
}
//////////////////////////////////////////////////////////////////////////////
// Event handlers
//////////////////////////////////////////////////////////////////////////////
//
// This is called when an asynchronous accept completes successfully.
//
//
// Completion handler for a successful asynchronous accept.
//
// Prepares the accepted socket, creates connection state for it, links
// the connection onto the port, starts the connect to the real server,
// posts the next accept, and finally releases the port reference that
// belonged to the completed accept.
//
// NumBytes   - byte count reported with the completion.
// Overlapped - the OVERLAPPED embedded in the PORT_INFO.
//
// Fix: the result of NewConnection was used without a check, so an
// allocation failure dereferenced NULL. The client is now dropped and
// the port keeps accepting.
//
VOID
ProcessAccept(
    IN ULONG NumBytes,
    IN LPOVERLAPPED Overlapped
    )
{
    PPORT_INFO Port = CONTAINING_RECORD(Overlapped, PORT_INFO, Overlapped);
    SOCKADDR_IN6 *psinLocal, *psinRemote;
    int iLocalLen, iRemoteLen;
    PCONNECTION_INFO Connection;
    ULONG Status;

    //
    // Accept incoming connection.
    //
    GetAcceptExSockaddrs(Port->AcceptBuffer,
                         0,
                         ADDR_BUFF_LEN,
                         ADDR_BUFF_LEN,
                         (LPSOCKADDR*)&psinLocal,
                         &iLocalLen,
                         (LPSOCKADDR*)&psinRemote,
                         &iRemoteLen );

    if (!BindIoCompletionCallback((HANDLE)Port->AcceptSocket,
                                  TpProcessWorkItem,
                                  0)) {
        Status = GetLastError();
        Trace2(SOCKET,
               _T("BindIoCompletionCallback failed on socket %x with error %u"),
               Port->AcceptSocket, Status);
        ProcessAcceptError(NumBytes, Overlapped, Status);
        return;
    }

    //
    // Call SO_UPDATE_ACCEPT_CONTEXT so that the AcceptSocket will be valid
    // in other winsock calls like shutdown().
    //
    if (setsockopt(Port->AcceptSocket,
                   SOL_SOCKET,
                   SO_UPDATE_ACCEPT_CONTEXT,
                   (char *)&Port->ListenSocket,
                   sizeof(Port->ListenSocket)) == SOCKET_ERROR) {
        Status = WSAGetLastError();
        Trace2(SOCKET,
               _T("SO_UPDATE_ACCEPT_CONTEXT failed on socket %x with error %u"),
               Port->AcceptSocket, Status);
        ProcessAcceptError(NumBytes, Overlapped, Status);
        return;
    }

    //
    // Create connection state.
    //
    Connection = NewConnection(Port->AcceptSocket,
                               Port->RemoteAddress.ss_family);
    if (Connection == NULL) {
        //
        // Out of memory: drop this client, keep listening, and release
        // the reference from the completed accept.
        //
        Trace1(ERR, _T("No memory for connection on socket %x"),
               Port->AcceptSocket);
        closesocket(Port->AcceptSocket);
        Port->AcceptSocket = INVALID_SOCKET;
        StartAccept(Port);
        DereferencePort(&Port);
        return;
    }

    //
    // The connection owns the accepted socket from here on.
    //
    Port->AcceptSocket = INVALID_SOCKET;

    //
    // Add connection to port's list.
    //
    EnterCriticalSection(&Port->Lock);
    {
        Connection->Port = Port;
        InsertTailList(&Port->ConnectionHead, &Connection->Link);
    }
    LeaveCriticalSection(&Port->Lock);

    //
    // Connect to real server on client's behalf. A failure is handled
    // inside StartConnect (it closes the connection).
    //
    StartConnect(Connection, Port);

    //
    // Start next accept.
    //
    StartAccept(Port);

    //
    // Release the reference from the original accept.
    //
    DereferencePort(&Port);
}
//
// This is called when an asynchronous accept completes with an error.
//
//
// Handler for an accept that failed, either synchronously or at
// completion.
//
// NumBytes   - byte count reported with the completion (0 when called
//              for a synchronous failure).
// Overlapped - the OVERLAPPED embedded in the PORT_INFO.
// Status     - the error code.
//
// ERROR_MORE_DATA is treated as a successful accept. Any other error
// releases the port reference that belonged to the accept.
//
// Fix: the socket that had been created for the failed accept was
// never closed, leaking a handle on every failure and at shutdown. It
// is now closed before the reference is released.
//
VOID
ProcessAcceptError(
    IN ULONG NumBytes,
    IN LPOVERLAPPED Overlapped,
    IN ULONG Status
    )
{
    PPORT_INFO Port = CONTAINING_RECORD(Overlapped, PORT_INFO, Overlapped);

    if (Status == ERROR_MORE_DATA) {
        ProcessAccept(NumBytes, Overlapped);
        return;
    }

    //
    // This happens at shutdown time when the accept
    // socket gets closed.
    //
    Trace3(ERR, _T("Accept failed with port=%p nb=%d err=%x"),
           Port, NumBytes, Status);

    //
    // No connection took ownership of the accept socket, so close it
    // here. This must happen before the dereference, which may free
    // the port.
    //
    if (Port->AcceptSocket != INVALID_SOCKET) {
        closesocket(Port->AcceptSocket);
        Port->AcceptSocket = INVALID_SOCKET;
    }

    //
    // Release the reference from the accept.
    //
    DereferencePort(&Port);
}
//
// This is called when an asynchronous connect completes successfully.
//
//
// Completion handler for a successful asynchronous connect to the real
// server. Finishes setting up the server socket, posts the first
// receive in each direction, and releases the connect's reference.
//
// NumBytes   - byte count reported with the completion.
// Overlapped - the inbound direction's OVERLAPPED, which the connect
//              borrowed.
//
VOID
ProcessConnect(
    IN ULONG NumBytes,
    IN LPOVERLAPPED Overlapped
    )
{
    PDIRECTION_INFO InboundInfo;
    PCONNECTION_INFO Conn;
    ULONG Error;

    InboundInfo = CONTAINING_RECORD(Overlapped, DIRECTION_INFO, Overlapped);
    Conn = CONTAINING_RECORD(InboundInfo, CONNECTION_INFO,
                             DirectionInfo[Inbound]);

    Trace3(SOCKET, _T("Connect succeeded with %d bytes with ovl %p socket %x"),
           NumBytes, Overlapped, Conn->Socket[Server]);

    //
    // SO_UPDATE_CONNECT_CONTEXT makes the socket usable with other
    // winsock calls such as shutdown().
    //
    if (setsockopt(Conn->Socket[Server],
                   SOL_SOCKET,
                   SO_UPDATE_CONNECT_CONTEXT,
                   NULL,
                   0) != SOCKET_ERROR) {
        StartReceive(&Conn->DirectionInfo[Inbound]);
        StartReceive(&Conn->DirectionInfo[Outbound]);

        // Drop the reference held by the connect.
        Trace2(FSM, _T("R-- %d %x ProcessConnect"), Conn->ReferenceCount, Conn);
        DereferenceConnection(&Conn);
        return;
    }

    Error = WSAGetLastError();
    Trace2(SOCKET,
           _T("SO_UPDATE_CONNECT_CONTEXT failed on socket %x with error %u"),
           Conn->Socket[Server], Error);
    ProcessConnectError(NumBytes, Overlapped, Error);
}
//
// This is called when an asynchronous connect completes with an error.
//
//
// Handler for a connect that failed, either synchronously or at
// completion. Closes the connection and releases the reference that
// belonged to the connect.
//
// NumBytes   - not used.
// Overlapped - the inbound direction's OVERLAPPED.
// Status     - the error code (traced only).
//
VOID
ProcessConnectError(
    IN ULONG NumBytes,
    IN LPOVERLAPPED Overlapped,
    IN ULONG Status
    )
{
    PDIRECTION_INFO InboundInfo;
    PCONNECTION_INFO Conn;

    InboundInfo = CONTAINING_RECORD(Overlapped, DIRECTION_INFO, Overlapped);
    Conn = CONTAINING_RECORD(InboundInfo, CONNECTION_INFO,
                             DirectionInfo[Inbound]);

    Trace1(ERR, _T("ProcessConnectError saw error %x"), Status);

    CloseConnection(&Conn);

    // Drop the reference held by the connect.
    Trace2(FSM, _T("R-- %d %x ProcessConnectError"), Conn->ReferenceCount, Conn);
    DereferenceConnection(&Conn);
}
//
// This is called when an asynchronous send completes successfully.
//
//
// Completion handler for a successful asynchronous send. Posts the
// next receive for the same direction and releases the send's
// reference.
//
// NumBytes   - not used.
// Overlapped - the OVERLAPPED of the direction that was sending.
//
VOID
ProcessSend(
    IN ULONG NumBytes,
    IN LPOVERLAPPED Overlapped
    )
{
    PDIRECTION_INFO Dir;
    PCONNECTION_INFO Conn;

    Dir = CONTAINING_RECORD(Overlapped, DIRECTION_INFO, Overlapped);
    Conn = CONTAINING_RECORD(Dir, CONNECTION_INFO,
                             DirectionInfo[Dir->Direction]);

    // The buffer is free again; go back to receiving.
    StartReceive(Dir);

    // Drop the reference held by the send.
    Trace2(FSM, _T("R-- %d %x ProcessSend"), Conn->ReferenceCount, Conn);
    DereferenceConnection(&Conn);
}
//
// This is called when an asynchronous send completes with an error.
//
//
// Handler for a send that failed, either synchronously or at
// completion.
//
// The socket this direction receives from (Socket[Direction]) is told
// to stop: if the connection is not yet half open, its receive side is
// shut down and the connection becomes half open; otherwise, or if the
// shutdown fails, the whole connection is closed. On
// ERROR_NETNAME_DELETED that socket's linger settings are zeroed first
// so the reset can be passed along. Finally the send's reference is
// released.
//
// NumBytes   - not used.
// Overlapped - the OVERLAPPED of the direction that was sending.
// Status     - the error code.
//
VOID
ProcessSendError(
    IN ULONG NumBytes,
    IN LPOVERLAPPED Overlapped,
    IN ULONG Status
    )
{
    PDIRECTION_INFO Dir;
    PCONNECTION_INFO Conn;
    BOOL MustClose = TRUE;

    Dir = CONTAINING_RECORD(Overlapped, DIRECTION_INFO, Overlapped);
    Conn = CONTAINING_RECORD(Dir, CONNECTION_INFO,
                             DirectionInfo[Dir->Direction]);

    Trace3(FSM, _T("WriteFile on ovl %p failed with error %u = 0x%x"),
           Overlapped, Status, Status);

    if (Status != ERROR_NETNAME_DELETED) {
        Trace1(ERR, _T("Send failed with error %u"), Status);
    } else {
        struct linger Linger;

        Trace2(FSM, _T("Connection %p %hs was reset"), Conn,
               (Dir->Direction == Inbound)? "inbound" : "outbound");

        // Prepare to forward the reset, if we can.
        ZeroMemory(&Linger, sizeof(Linger));
        setsockopt(Conn->Socket[Dir->Direction],
                   SOL_SOCKET, SO_LINGER, (char*)&Linger,
                   sizeof(Linger));
    }

    if (Conn->HalfOpen == FALSE) {
        // Other side is still around, tell it to quit.
        Trace1(SOCKET, _T("Starting a shutdown on socket %x"),
               Conn->Socket[Dir->Direction]);
        if (shutdown(Conn->Socket[Dir->Direction], SD_RECEIVE)
            != SOCKET_ERROR) {
            Conn->HalfOpen = TRUE;
            MustClose = FALSE;
        } else {
            Status = WSAGetLastError();
            Trace2(SOCKET, _T("shutdown failed with error %u = 0x%x"),
                   Status, Status);
        }
    }

    if (MustClose) {
        CloseConnection(&Conn);
    }

    // Drop the reference held by the send.
    Trace2(FSM, _T("R-- %d %x ProcessSendError"), Conn->ReferenceCount, Conn);
    DereferenceConnection(&Conn);
}
//
// This is called when an asynchronous receive completes successfully.
//
//
// Completion handler for a successful asynchronous receive.
//
// A zero-byte receive means the peer closed; it is handed to
// ProcessReceiveError as ERROR_NETNAME_DELETED. Otherwise the received
// bytes are forwarded with StartSend and the receive's reference is
// released.
//
// NumBytes   - number of bytes received into the direction's buffer.
// Overlapped - the OVERLAPPED of the direction that was receiving.
//
// Fix: the data was NUL-terminated for tracing with
// buf[NumBytes] = 0, which writes one byte past the buffer whenever a
// full BUFFER_SIZE read completes. For the inbound buffer that byte
// is the first byte of the outbound buffer; for the outbound buffer it
// lies beyond the allocation. The terminator is now written (and the
// contents traced) only when there is room for it.
//
VOID
ProcessReceive(
    IN ULONG NumBytes,
    IN LPOVERLAPPED Overlapped
    )
{
    PDIRECTION_INFO DirectionInfo;
    PCONNECTION_INFO Connection;

    if (NumBytes == 0) {
        //
        // Other side initiated a close.
        //
        ProcessReceiveError(0, Overlapped, ERROR_NETNAME_DELETED);
        return;
    }

    DirectionInfo = CONTAINING_RECORD(Overlapped, DIRECTION_INFO, Overlapped);
    Connection = CONTAINING_RECORD(DirectionInfo, CONNECTION_INFO,
                                   DirectionInfo[DirectionInfo->Direction]);

    if (NumBytes < DirectionInfo->Buffer.len) {
        //
        // There is spare room behind the data, so it can be terminated
        // and traced as a string.
        //
        DirectionInfo->Buffer.buf[NumBytes] = 0;
        Trace3(SOCKET, _T("Dir %d got %d bytes: !%hs!"), DirectionInfo->Direction,
               NumBytes, DirectionInfo->Buffer.buf);
    } else {
        //
        // The buffer is completely full; trace the size only.
        //
        Trace2(SOCKET, _T("Dir %d got %d bytes"), DirectionInfo->Direction,
               NumBytes);
    }

    //
    // Connection is still active, and we received some data.
    // Post a send request to forward it onward.
    //
    StartSend(DirectionInfo, NumBytes);

    //
    // Release the reference from the receive.
    //
    Trace2(FSM, _T("R-- %d %x ProcessReceive"), Connection->ReferenceCount, Connection);
    DereferenceConnection(&Connection);
}
//
// This is called when an asynchronous receive completes with an error.
//
//
// Handler for a receive that failed, either synchronously or at
// completion (ProcessReceive also routes a zero-byte receive here).
//
// The socket this direction sends to (Socket[1 - Direction]) is told
// that no more data will come: if the connection is not yet half open,
// its send side is shut down and the connection becomes half open;
// otherwise, or if the shutdown fails, the whole connection is closed.
// On ERROR_NETNAME_DELETED that socket's linger settings are zeroed
// first so the reset can be passed along. Finally the receive's
// reference is released.
//
// NumBytes   - not used.
// Overlapped - the OVERLAPPED of the direction that was receiving.
// Status     - the error code.
//
VOID
ProcessReceiveError(
    IN ULONG NumBytes,
    IN LPOVERLAPPED Overlapped,
    IN ULONG Status
    )
{
    PDIRECTION_INFO Dir;
    PCONNECTION_INFO Conn;
    BOOL MustClose = TRUE;

    Dir = CONTAINING_RECORD(Overlapped, DIRECTION_INFO, Overlapped);
    Conn = CONTAINING_RECORD(Dir, CONNECTION_INFO,
                             DirectionInfo[Dir->Direction]);

    Trace3(ERR, _T("ReadFile on ovl %p failed with error %u = 0x%x"),
           Overlapped, Status, Status);

    if (Status != ERROR_NETNAME_DELETED) {
        Trace1(ERR, _T("Receive failed with error %u"), Status);
    } else {
        struct linger Linger;

        Trace2(FSM, _T("Connection %p %hs was reset"), Conn,
               (Dir->Direction == Inbound)? "inbound" : "outbound");

        // Prepare to forward the reset, if we can.
        ZeroMemory(&Linger, sizeof(Linger));
        setsockopt(Conn->Socket[1 - Dir->Direction],
                   SOL_SOCKET, SO_LINGER, (char*)&Linger,
                   sizeof(Linger));
    }

    if (Conn->HalfOpen == FALSE) {
        // Other side is still around, tell it to quit.
        Trace1(SOCKET, _T("Starting a shutdown on socket %x"),
               Conn->Socket[1 - Dir->Direction]);
        if (shutdown(Conn->Socket[1 - Dir->Direction], SD_SEND)
            != SOCKET_ERROR) {
            Conn->HalfOpen = TRUE;
            MustClose = FALSE;
        } else {
            Status = WSAGetLastError();
            Trace2(SOCKET, _T("shutdown failed with error %u = 0x%x"),
                   Status, Status);
        }
    }

    if (MustClose) {
        CloseConnection(&Conn);
    }

    // Drop the reference held by the receive.
    Trace2(FSM, _T("R-- %d %x ProcessReceiveError"), Conn->ReferenceCount, Conn);
    DereferenceConnection(&Conn);
}
//
// Main dispatch routine
//
//
// Main dispatch routine: the I/O completion callback bound to every
// socket with BindIoCompletionCallback.
//
// Status     - completion status of the operation (NO_ERROR on success).
// NumBytes   - number of bytes transferred.
// Overlapped - the OVERLAPPED the operation was posted with.
//
// The kind of operation is read from the OPERATION field that directly
// follows the OVERLAPPED in both PORT_INFO and DIRECTION_INFO, and the
// completion is routed to the matching success or error handler.
//
// Fix: the operation used to be read through Overlapped before the
// pointer was checked, although the error branch allowed for it being
// NULL. A NULL Overlapped is now rejected up front.
//
VOID APIENTRY
TpProcessWorkItem(
    IN ULONG Status,
    IN ULONG NumBytes,
    IN LPOVERLAPPED Overlapped
    )
{
    OPERATION Operation;

    if (Overlapped == NULL) {
        //
        // Nothing identifies the operation, so there is nothing to
        // dispatch to.
        //
        Trace2(ERR, _T("TpProcessWorkItem got err %x with no overlapped, bytes=%d"),
               Status, NumBytes);
        return;
    }

    Operation = *(OPERATION*)(Overlapped+1);

    Trace4(SOCKET,
           _T("TpProcessWorkItem got err %x operation=%hs ovl %p bytes=%d"),
           Status, OperationName[Operation], Overlapped, NumBytes);

    if (Status == NO_ERROR) {
        switch(Operation) {
        case Accept:
            ProcessAccept(NumBytes, Overlapped);
            break;
        case Connect:
            ProcessConnect(NumBytes, Overlapped);
            break;
        case Receive:
            ProcessReceive(NumBytes, Overlapped);
            break;
        case Send:
            ProcessSend(NumBytes, Overlapped);
            break;
        }
    } else {
        switch(Operation) {
        case Accept:
            ProcessAcceptError(NumBytes, Overlapped, Status);
            break;
        case Connect:
            ProcessConnectError(NumBytes, Overlapped, Status);
            break;
        case Receive:
            ProcessReceiveError(NumBytes, Overlapped, Status);
            break;
        case Send:
            ProcessSendError(NumBytes, Overlapped, Status);
            break;
        }
    }
}