// // proxy.c - Generic application level proxy for IPv6/IPv4 // // This program accepts TCP connections on one socket and port, and // forwards data between in and another socket to a given address // (default loopback) and port (default same as listening port). // // For example, it can make an unmodified IPv4 server look like an IPv6 server. // Typically, the proxy will run on the same machine as // the server it is fronting, but that doesn't have to be the case. // // Copyright (C) Microsoft Corporation. // All rights reserved. // // History: // Original code by Brian Zill. // Made into a service by Dave Thaler. // #include "precomp.h" #pragma hdrstop // // Configuration parameters. // #define BUFFER_SIZE (4 * 1024) typedef enum { Connect, Accept, Receive, Send } OPERATION; CONST CHAR *OperationName[]={ "Connect", "Accept", "Receive", "Send" }; typedef enum { Inbound = 0, // Receive from client, send to server. Outbound, // Receive from server, send to client. NumDirections } DIRECTION; typedef enum { Client = 0, Server, NumSides } SIDE; // // Information we keep for each port we're proxying on. // #define ADDR_BUFF_LEN (16+sizeof(SOCKADDR_IN6)) typedef struct _PORT_INFO { LIST_ENTRY Link; ULONG ReferenceCount; SOCKET ListenSocket; SOCKET AcceptSocket; BYTE AcceptBuffer[ADDR_BUFF_LEN*2]; WSAOVERLAPPED Overlapped; OPERATION Operation; SOCKADDR_STORAGE LocalAddress; ULONG LocalAddressLength; SOCKADDR_STORAGE RemoteAddress; ULONG RemoteAddressLength; // // A lock protects the connection list for this port. // CRITICAL_SECTION Lock; LIST_ENTRY ConnectionHead; } PORT_INFO, *PPORT_INFO; // // Information we keep for each direction of a bi-directional connection. // typedef struct _DIRECTION_INFO { WSABUF Buffer; WSAOVERLAPPED Overlapped; OPERATION Operation; struct _CONNECTION_INFO *Connection; DIRECTION Direction; } DIRECTION_INFO, *PDIRECTION_INFO; // // Information we keep for each client connection. // typedef struct _CONNECTION_INFO { LIST_ENTRY Link; ULONG ReferenceCount; PPORT_INFO Port; BOOL HalfOpen; // Has one side or the other stopped sending? LONG Closing; SOCKET Socket[NumSides]; DIRECTION_INFO DirectionInfo[NumDirections]; } CONNECTION_INFO, *PCONNECTION_INFO; // // Global variables. // LIST_ENTRY g_GlobalPortList; LPFN_CONNECTEX ConnectEx = NULL; // // Function prototypes. // VOID ProcessReceiveError( IN ULONG NumBytes, IN LPOVERLAPPED Overlapped, IN ULONG Status ); VOID ProcessSendError( IN ULONG NumBytes, IN LPOVERLAPPED Overlapped, IN ULONG Status ); VOID ProcessAcceptError( IN ULONG NumBytes, IN LPOVERLAPPED Overlapped, IN ULONG Status ); VOID ProcessConnectError( IN ULONG NumBytes, IN LPOVERLAPPED Overlapped, IN ULONG Status ); VOID APIENTRY TpProcessWorkItem( IN ULONG Status, IN ULONG NumBytes, IN LPOVERLAPPED Overlapped ); // // Inline functions. // __inline ReferenceConnection( IN PCONNECTION_INFO Connection ) { InterlockedIncrement(&Connection->ReferenceCount); } __inline DereferenceConnection( IN OUT PCONNECTION_INFO *ConnectionPtr ) { ULONG Value; Value = InterlockedDecrement(&(*ConnectionPtr)->ReferenceCount); if (Value == 0) { FREE(*ConnectionPtr); *ConnectionPtr = NULL; } } __inline VOID ReferencePort( IN PPORT_INFO Port ) { InterlockedIncrement(&Port->ReferenceCount); } __inline VOID DereferencePort( IN OUT PPORT_INFO *PortPtr ) { ULONG Value; Value = InterlockedDecrement(&(*PortPtr)->ReferenceCount); if (Value == 0) { DeleteCriticalSection(&(*PortPtr)->Lock); FREE(*PortPtr); *PortPtr = NULL; } } // // Allocate and initialize state for a new client connection. // PCONNECTION_INFO NewConnection( IN SOCKET ClientSocket, IN ULONG ConnectFamily ) { PCONNECTION_INFO Connection; // // Allocate space for a CONNECTION_INFO structure and two buffers. // Connection = (CONNECTION_INFO *)MALLOC(sizeof(*Connection) + (2 * BUFFER_SIZE)); if (Connection == NULL) { return NULL; } // // Fill everything in. // Connection->HalfOpen = FALSE; Connection->Closing = FALSE; Connection->Socket[Client] = ClientSocket; Connection->DirectionInfo[Inbound].Direction = Inbound; Connection->DirectionInfo[Inbound].Operation = Receive; // Start out receiving. Connection->DirectionInfo[Inbound].Buffer.len = BUFFER_SIZE; Connection->DirectionInfo[Inbound].Buffer.buf = (char *)(Connection + 1); Connection->DirectionInfo[Inbound].Connection = Connection; Connection->Socket[Server] = socket(ConnectFamily, SOCK_STREAM, 0); Connection->DirectionInfo[Outbound].Direction = Outbound; Connection->DirectionInfo[Outbound].Operation = Receive; // Start out receiving. Connection->DirectionInfo[Outbound].Buffer.len = BUFFER_SIZE; Connection->DirectionInfo[Outbound].Buffer.buf = Connection->DirectionInfo[Inbound].Buffer.buf + BUFFER_SIZE; Connection->DirectionInfo[Outbound].Connection = Connection; Connection->ReferenceCount = 0; ReferenceConnection(Connection); Trace2(FSM, _T("R++ %d %x NewConnection"), Connection->ReferenceCount, Connection); return Connection; } // // Create state information for a client. // CONNECTION_INFO * CreateConnectionState( IN SOCKET ClientSocket, IN SOCKET ServerSocket ) { CONNECTION_INFO *Conn; // // Allocate space for a CONNECTION_INFO structure and two buffers. // Conn = (CONNECTION_INFO *)MALLOC(sizeof(*Conn) + (2 * BUFFER_SIZE)); if (Conn == NULL) { return NULL; } // // Fill everything in. // Conn->HalfOpen = FALSE; // // Start out in the receiving state in both directions. // Conn->Socket[Client] = ClientSocket; Conn->DirectionInfo[Inbound].Direction = Inbound; Conn->DirectionInfo[Inbound].Operation = Receive; Conn->DirectionInfo[Inbound].Buffer.len = BUFFER_SIZE; Conn->DirectionInfo[Inbound].Buffer.buf = (char *)(Conn + 1); Conn->DirectionInfo[Inbound].Connection = Conn; Conn->Socket[Server] = ServerSocket; Conn->DirectionInfo[Outbound].Direction = Outbound; Conn->DirectionInfo[Outbound].Operation = Receive; Conn->DirectionInfo[Outbound].Buffer.len = BUFFER_SIZE; Conn->DirectionInfo[Outbound].Buffer.buf = Conn->DirectionInfo[Inbound].Buffer.buf + BUFFER_SIZE; Conn->DirectionInfo[Outbound].Connection = Conn; return Conn; } // // Start an asynchronous accept. // // Assumes caller holds a reference on Port. // DWORD StartAccept( IN PPORT_INFO Port ) { ULONG Status, Junk; ASSERT(Port->ReferenceCount > 0); // // Count another reference for the operation. // ReferencePort(Port); Port->AcceptSocket = socket(Port->LocalAddress.ss_family, SOCK_STREAM, 0); if (Port->AcceptSocket == INVALID_SOCKET) { Status = WSAGetLastError(); ProcessAcceptError(0, &Port->Overlapped, Status); return Status; } Trace2(SOCKET, _T("Starting an accept with new socket %x ovl %p"), Port->AcceptSocket, &Port->Overlapped); Port->Overlapped.hEvent = NULL; Port->Operation = Accept; if (!AcceptEx(Port->ListenSocket, Port->AcceptSocket, Port->AcceptBuffer, // only used to hold addresses 0, ADDR_BUFF_LEN, ADDR_BUFF_LEN, &Junk, &Port->Overlapped)) { Status = WSAGetLastError(); if (Status != ERROR_IO_PENDING) { ProcessAcceptError(0, &Port->Overlapped, Status); return Status; } } return NO_ERROR; } // // Start an asynchronous connect. // // Assumes caller holds a reference on Connection. // DWORD StartConnect( IN PCONNECTION_INFO Connection, IN PPORT_INFO Port ) { ULONG Status, Junk; SOCKADDR_STORAGE LocalAddress; // // Count a reference for the operation. // ReferenceConnection(Connection); Trace2(FSM, _T("R++ %d %x StartConnect"), Connection->ReferenceCount, Connection); Connection->Socket[Server] = socket(Port->RemoteAddress.ss_family, SOCK_STREAM, 0); if (Connection->Socket[Server] == INVALID_SOCKET) { Status = WSAGetLastError(); ProcessConnectError(0, &Connection->DirectionInfo[Inbound].Overlapped, Status); return Status; } Connection->DirectionInfo[Inbound].Overlapped.hEvent = NULL; Connection->DirectionInfo[Outbound].Overlapped.hEvent = NULL; ZeroMemory(&LocalAddress, Port->RemoteAddressLength); LocalAddress.ss_family = Port->RemoteAddress.ss_family; if (bind(Connection->Socket[Server], (LPSOCKADDR)&LocalAddress, Port->RemoteAddressLength) == SOCKET_ERROR) { Status = WSAGetLastError(); ProcessConnectError(0, &Connection->DirectionInfo[Inbound].Overlapped, Status); return Status; } if (!BindIoCompletionCallback((HANDLE)Connection->Socket[Server], TpProcessWorkItem, 0)) { Status = GetLastError(); ProcessConnectError(0, &Connection->DirectionInfo[Inbound].Overlapped, Status); return Status; } if (ConnectEx == NULL) { GUID Guid = WSAID_CONNECTEX; if (WSAIoctl(Connection->Socket[Server], SIO_GET_EXTENSION_FUNCTION_POINTER, &Guid, sizeof(Guid), &ConnectEx, sizeof(ConnectEx), &Junk, NULL, NULL) == SOCKET_ERROR) { ProcessConnectError(0, &Connection->DirectionInfo[Inbound].Overlapped, WSAGetLastError()); } } Trace2(SOCKET, _T("Starting a connect with socket %x ovl %p"), Connection->Socket[Server], &Connection->DirectionInfo[Inbound].Overlapped); Connection->DirectionInfo[Inbound].Operation = Connect; if (!ConnectEx(Connection->Socket[Server], (LPSOCKADDR)&Port->RemoteAddress, Port->RemoteAddressLength, NULL, 0, &Junk, &Connection->DirectionInfo[Inbound].Overlapped)) { Status = WSAGetLastError(); if (Status != ERROR_IO_PENDING) { ProcessConnectError(0, &Connection->DirectionInfo[Inbound].Overlapped, Status); return Status; } } return NO_ERROR; } // // Start an asynchronous receive. // // Assumes caller holds a reference on DirectionInfo. // VOID StartReceive( IN PDIRECTION_INFO DirectionInfo ) { ULONG BytesRcvd, Status; PCONNECTION_INFO Connection = CONTAINING_RECORD(DirectionInfo, CONNECTION_INFO, DirectionInfo[DirectionInfo->Direction]); Trace3(SOCKET, _T("starting ReadFile on socket %x with Dir %p ovl %p"), Connection->Socket[DirectionInfo->Direction], DirectionInfo, &DirectionInfo->Overlapped); // // Count a reference for the operation. // ReferenceConnection(Connection); Trace2(FSM, _T("R++ %d %x StartReceive"), Connection->ReferenceCount, Connection); ASSERT(DirectionInfo->Overlapped.hEvent == NULL); ASSERT(DirectionInfo->Buffer.len > 0); ASSERT(DirectionInfo->Buffer.buf != NULL); DirectionInfo->Operation = Receive; Trace5(SOCKET, _T("ReadFile %x %p %d %p %p"), Connection->Socket[DirectionInfo->Direction], &DirectionInfo->Buffer.buf, DirectionInfo->Buffer.len, &BytesRcvd, &DirectionInfo->Overlapped); // // Post receive buffer. // if (!ReadFile((HANDLE)Connection->Socket[DirectionInfo->Direction], DirectionInfo->Buffer.buf, DirectionInfo->Buffer.len, &BytesRcvd, &DirectionInfo->Overlapped)) { Status = GetLastError(); if (Status != ERROR_IO_PENDING) { ProcessReceiveError(0, &DirectionInfo->Overlapped, Status); return; } } } // // Start an asynchronous send. // // Assumes caller holds a reference on DirectionInfo. // VOID StartSend( IN PDIRECTION_INFO DirectionInfo, IN ULONG NumBytes ) { ULONG BytesSent, Status; PCONNECTION_INFO Connection = CONTAINING_RECORD(DirectionInfo, CONNECTION_INFO, DirectionInfo[DirectionInfo->Direction]); Trace3(SOCKET, _T("starting WriteFile on socket %x with Dir %p ovl %p"), Connection->Socket[1 - DirectionInfo->Direction], DirectionInfo, &DirectionInfo->Overlapped); // // Count a reference for the operation. // ReferenceConnection(Connection); Trace2(FSM, _T("R++ %d %x StartSend"), Connection->ReferenceCount, Connection); DirectionInfo->Operation = Send; // // Post send buffer. // if (!WriteFile((HANDLE)Connection->Socket[1 - DirectionInfo->Direction], DirectionInfo->Buffer.buf, NumBytes, &BytesSent, &DirectionInfo->Overlapped)) { Status = GetLastError(); if (Status != ERROR_IO_PENDING) { Trace1(ERR, _T("WriteFile 1 failed %d"), Status); ProcessSendError(0, &DirectionInfo->Overlapped, Status); return; } } } // // This gets called when we want to start proxying for a new port. // DWORD StartUpPort( IN PPORT_INFO Port ) { ULONG Status = NO_ERROR; CHAR LocalBuffer[256]; CHAR RemoteBuffer[256]; ULONG Length; // // Add an initial reference. // ReferencePort(Port); InitializeCriticalSection(&Port->Lock); InitializeListHead(&Port->ConnectionHead); Port->ListenSocket = socket(Port->LocalAddress.ss_family, SOCK_STREAM, 0); if (Port->ListenSocket == INVALID_SOCKET) { Status = WSAGetLastError(); Trace1(ERR, _T("socket() failed with error %u"), Status); return Status; } if (bind(Port->ListenSocket, (LPSOCKADDR)&Port->LocalAddress, Port->LocalAddressLength) == SOCKET_ERROR) { Trace1(ERR, _T("bind() failed with error %u"), WSAGetLastError()); goto Fail; } if (listen(Port->ListenSocket, 5) == SOCKET_ERROR) { Trace1(ERR, _T("listen() failed with error %u"), WSAGetLastError()); goto Fail; } if (!BindIoCompletionCallback((HANDLE)Port->ListenSocket, TpProcessWorkItem, 0)) { Trace1(ERR, _T("BindIoCompletionCallback() failed with error %u"), GetLastError()); goto Fail; } Length = sizeof(LocalBuffer); LocalBuffer[0] = '\0'; WSAAddressToStringA((LPSOCKADDR)&Port->LocalAddress, Port->LocalAddressLength, NULL, LocalBuffer, &Length); Length = sizeof(RemoteBuffer); RemoteBuffer[0] = '\0'; WSAAddressToStringA((LPSOCKADDR)&Port->RemoteAddress, Port->RemoteAddressLength, NULL, RemoteBuffer, &Length); Trace2(FSM, _T("Proxying %hs to %hs"), LocalBuffer, RemoteBuffer); // // Start an asynchronous accept // return StartAccept(Port); Fail: closesocket(Port->ListenSocket); Port->ListenSocket = INVALID_SOCKET; return WSAGetLastError(); } VOID CloseConnection( IN OUT PCONNECTION_INFO *ConnectionPtr ) { PCONNECTION_INFO Connection = (*ConnectionPtr); PPORT_INFO Port = Connection->Port; if (InterlockedExchange(&Connection->Closing, TRUE) != FALSE) { // // Nothing to do. // return; } Trace2(SOCKET, _T("Closing client socket %x and server socket %x"), Connection->Socket[Client], Connection->Socket[Server]); closesocket(Connection->Socket[Client]); closesocket(Connection->Socket[Server]); EnterCriticalSection(&Port->Lock); { RemoveEntryList(&Connection->Link); } LeaveCriticalSection(&Port->Lock); Trace2(FSM, _T("R-- %d %x CloseConnection"), Connection->ReferenceCount, Connection); DereferenceConnection(ConnectionPtr); } // // This gets called when we want to stop proxying for a given port. // VOID ShutDownPort( IN PPORT_INFO *PortPtr ) { PCONNECTION_INFO Connection; PPORT_INFO Port = *PortPtr; // // Close any connections. // EnterCriticalSection(&Port->Lock); while (!IsListEmpty(&Port->ConnectionHead)) { Connection = CONTAINING_RECORD(Port->ConnectionHead.Flink, CONNECTION_INFO, Link); CloseConnection(&Connection); } LeaveCriticalSection(&Port->Lock); closesocket(Port->ListenSocket); Port->ListenSocket = INVALID_SOCKET; Trace1(FSM, _T("Shut down port %u"), RtlUshortByteSwap(SS_PORT(&Port->RemoteAddress))); // // Release the reference added by StartUpPort. // DereferencePort(PortPtr); } typedef enum { V4TOV4, V4TOV6, V6TOV4, V6TOV6 } PPTYPE, *PPPTYPE; typedef struct { ULONG ListenFamily; ULONG ConnectFamily; PWCHAR KeyString; } PPTYPEINFO, *PPPTYPEINFO; #define KEY_V4TOV4 L"v4tov4" #define KEY_V4TOV6 L"v4tov6" #define KEY_V6TOV4 L"v6tov4" #define KEY_V6TOV6 L"v6tov6" #define KEY_PORTS L"System\\CurrentControlSet\\Services\\PortProxy" PPTYPEINFO PpTypeInfo[] = { { AF_INET, AF_INET, KEY_V4TOV4 }, { AF_INET, AF_INET6, KEY_V4TOV6 }, { AF_INET6, AF_INET, KEY_V6TOV4 }, { AF_INET6, AF_INET6, KEY_V6TOV6 }, }; // // Given new configuration data, make any changes needed. // VOID ApplyNewPortList( IN OUT PLIST_ENTRY pNewList ) { PPORT_INFO Port, pCurr; PLIST_ENTRY pleCurr, plePort, pleNext; // // Compare against current port list. // for (pleCurr = g_GlobalPortList.Flink; pleCurr != &g_GlobalPortList; pleCurr = pleNext) { pleNext = pleCurr->Flink; pCurr = CONTAINING_RECORD(pleCurr, PORT_INFO, Link); for (plePort = pNewList->Flink; plePort != pNewList; plePort = plePort->Flink) { Port = CONTAINING_RECORD(plePort, PORT_INFO, Link); if (SS_PORT(&Port->RemoteAddress) == SS_PORT(&pCurr->RemoteAddress)) { break; } } if (plePort == pNewList) { // // Shut down an old proxy port. // RemoveEntryList(pleCurr); ShutDownPort(&pCurr); } } for (plePort = pNewList->Flink; plePort != pNewList; plePort = pleNext) { pleNext = plePort->Flink; Port = CONTAINING_RECORD(plePort, PORT_INFO, Link); for (pleCurr = g_GlobalPortList.Flink; pleCurr != &g_GlobalPortList; pleCurr = pleCurr->Flink) { pCurr = CONTAINING_RECORD(pleCurr, PORT_INFO, Link); if (SS_PORT(&Port->LocalAddress) == SS_PORT(&pCurr->LocalAddress)) { // // Update remote address. // pCurr->RemoteAddress = Port->RemoteAddress; pCurr->RemoteAddressLength = Port->RemoteAddressLength; break; } } if (pleCurr == &g_GlobalPortList) { // // Start up a new proxy port. // RemoveEntryList(plePort); InsertTailList(&g_GlobalPortList, plePort); StartUpPort(Port); } } } // // Reads from the registry one type of proxying (e.g., v6-to-v4). // VOID AppendType( IN PLIST_ENTRY Head, IN HKEY hPorts, IN PPTYPE Type ) { ADDRINFO ListenHints, ConnectHints; ADDRINFO *LocalAi, *RemoteAi; ULONG ListenChars, dwType, ConnectBytes, i; WCHAR ListenBuffer[256], *ListenAddress, *ListenPort; WCHAR ConnectAddress[256], *ConnectPort; PPORT_INFO Port; ULONG Status; HKEY hType, hProto; ZeroMemory(&ListenHints, sizeof(ListenHints)); ListenHints.ai_family = PpTypeInfo[Type].ListenFamily; ListenHints.ai_socktype = SOCK_STREAM; ListenHints.ai_flags = AI_PASSIVE; ZeroMemory(&ConnectHints, sizeof(ConnectHints)); ConnectHints.ai_family = PpTypeInfo[Type].ConnectFamily; ConnectHints.ai_socktype = SOCK_STREAM; Status = RegOpenKeyExW(hPorts, PpTypeInfo[Type].KeyString, 0, KEY_QUERY_VALUE, &hType); if (Status != NO_ERROR) { return; } Status = RegOpenKeyExW(hType, L"tcp", 0, KEY_QUERY_VALUE, &hProto); if (Status != NO_ERROR) { RegCloseKey(hType); return; } for (i=0; ; i++) { ListenChars = sizeof(ListenBuffer)/sizeof(WCHAR); ConnectBytes = sizeof(ConnectAddress); Status = RegEnumValueW(hProto, i, ListenBuffer, &ListenChars, NULL, &dwType, (PVOID)ConnectAddress, &ConnectBytes); if (Status != NO_ERROR) { break; } if (dwType != REG_SZ) { continue; } ListenPort = wcschr(ListenBuffer, L'/'); if (ListenPort) { // // Replace slash with NULL, so we have 2 strings to pass // to getaddrinfo. // if (ListenBuffer[0] == '*') { ListenAddress = NULL; } else { ListenAddress = ListenBuffer; } *ListenPort++ = '\0'; } else { // // If the address data didn't include a connect address // use NULL. // ListenAddress = NULL; ListenPort = ListenBuffer; } ConnectPort = wcschr(ConnectAddress, '/'); if (ConnectPort) { // // Replace slash with NULL, so we have 2 strings to pass // to getaddrinfo. // *ConnectPort++ = '\0'; } else { // // If the address data didn't include a remote port number, // use the same port as the local port number. // ConnectPort = ListenPort; } Status = GetAddrInfoW(ConnectAddress, ConnectPort, &ConnectHints, &RemoteAi); if (Status != NO_ERROR) { continue; } Status = GetAddrInfoW(ListenAddress, ListenPort, &ListenHints, &LocalAi); if (Status != NO_ERROR) { freeaddrinfo(RemoteAi); continue; } Port = MALLOC(sizeof(PORT_INFO)); if (Port) { ZeroMemory(Port, sizeof(PORT_INFO)); InsertTailList(Head, &Port->Link); memcpy(&Port->RemoteAddress, RemoteAi->ai_addr, RemoteAi->ai_addrlen); Port->RemoteAddressLength = (ULONG)RemoteAi->ai_addrlen; memcpy(&Port->LocalAddress, LocalAi->ai_addr, LocalAi->ai_addrlen); Port->LocalAddressLength = (ULONG)LocalAi->ai_addrlen; } freeaddrinfo(RemoteAi); freeaddrinfo(LocalAi); } RegCloseKey(hProto); RegCloseKey(hType); } // // Read new configuration data from the registry and see what's changed. // VOID UpdateGlobalPortState( IN PVOID Unused ) { LIST_ENTRY PortHead; HKEY hPorts; ULONG Status = NO_ERROR; PLIST_ENTRY ple; InitializeListHead(&PortHead); // // Read new port list from registry and initialize per-port proxy state. // Status = RegOpenKeyExW(HKEY_LOCAL_MACHINE, KEY_PORTS, 0, KEY_QUERY_VALUE, &hPorts); AppendType(&PortHead, hPorts, V4TOV4); AppendType(&PortHead, hPorts, V4TOV6); AppendType(&PortHead, hPorts, V6TOV4); AppendType(&PortHead, hPorts, V6TOV6); RegCloseKey(hPorts); ApplyNewPortList(&PortHead); // // Free new port list. // while (!IsListEmpty(&PortHead)) { ple = PortHead.Flink; RemoveEntryList(ple); FREE(ple); } } // // Force UpdateGlobalPortState to be executed in a persistent thread, // since we need to make sure that the asynchronous IO routines are // started in a thread that won't go away before the operation completes. // BOOL QueueUpdateGlobalPortState( IN PVOID Unused ) { NTSTATUS nts = QueueUserWorkItem( (LPTHREAD_START_ROUTINE)UpdateGlobalPortState, (PVOID)Unused, WT_EXECUTEINPERSISTENTTHREAD); return NT_SUCCESS(nts); } VOID InitializePorts( VOID ) { InitializeListHead(&g_GlobalPortList); } VOID UninitializePorts( VOID ) { LIST_ENTRY Empty; // Check if ports got initialized to begin with. if (g_GlobalPortList.Flink == NULL) return; InitializeListHead(&Empty); ApplyNewPortList(&Empty); } ////////////////////////////////////////////////////////////////////////////// // Event handlers ////////////////////////////////////////////////////////////////////////////// // // This is called when an asynchronous accept completes successfully. // VOID ProcessAccept( IN ULONG NumBytes, IN LPOVERLAPPED Overlapped ) { PPORT_INFO Port = CONTAINING_RECORD(Overlapped, PORT_INFO, Overlapped); SOCKADDR_IN6 *psinLocal, *psinRemote; int iLocalLen, iRemoteLen; PCONNECTION_INFO Connection; ULONG Status; // // Accept incoming connection. // GetAcceptExSockaddrs(Port->AcceptBuffer, 0, ADDR_BUFF_LEN, ADDR_BUFF_LEN, (LPSOCKADDR*)&psinLocal, &iLocalLen, (LPSOCKADDR*)&psinRemote, &iRemoteLen ); if (!BindIoCompletionCallback((HANDLE)Port->AcceptSocket, TpProcessWorkItem, 0)) { Status = GetLastError(); Trace2(SOCKET, _T("BindIoCompletionCallback failed on socket %x with error %u"), Port->AcceptSocket, Status); ProcessAcceptError(NumBytes, Overlapped, Status); return; } // // Call SO_UPDATE_ACCEPT_CONTEXT so that the AcceptSocket will be valid // in other winsock calls like shutdown(). // if (setsockopt(Port->AcceptSocket, SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, (char *)&Port->ListenSocket, sizeof(Port->ListenSocket)) == SOCKET_ERROR) { Status = WSAGetLastError(); Trace2(SOCKET, _T("SO_UPDATE_ACCEPT_CONTEXT failed on socket %x with error %u"), Port->AcceptSocket, Status); ProcessAcceptError(NumBytes, Overlapped, Status); return; } // // Create connection state. // Connection = NewConnection(Port->AcceptSocket, Port->RemoteAddress.ss_family); Port->AcceptSocket = INVALID_SOCKET; // // Add connection to port's list. // EnterCriticalSection(&Port->Lock); { Connection->Port = Port; InsertTailList(&Port->ConnectionHead, &Connection->Link); } LeaveCriticalSection(&Port->Lock); // // Connect to real server on client's behalf. // StartConnect(Connection, Port); // // Start next accept. // StartAccept(Port); // // Release the reference from the original accept. // DereferencePort(&Port); } // // This is called when an asynchronous accept completes with an error. // VOID ProcessAcceptError( IN ULONG NumBytes, IN LPOVERLAPPED Overlapped, IN ULONG Status ) { PPORT_INFO Port = CONTAINING_RECORD(Overlapped, PORT_INFO, Overlapped); if (Status == ERROR_MORE_DATA) { ProcessAccept(NumBytes, Overlapped); return; } else { // // This happens at shutdown time when the accept // socket gets closed. // Trace3(ERR, _T("Accept failed with port=%p nb=%d err=%x"), Port, NumBytes, Status); } // // Release the reference from the accept. // DereferencePort(&Port); } // // This is called when an asynchronous connect completes successfully. // VOID ProcessConnect( IN ULONG NumBytes, IN LPOVERLAPPED Overlapped ) { PDIRECTION_INFO pInbound = CONTAINING_RECORD(Overlapped, DIRECTION_INFO, Overlapped); PCONNECTION_INFO Connection = CONTAINING_RECORD(pInbound, CONNECTION_INFO, DirectionInfo[Inbound]); ULONG Status; Trace3(SOCKET, _T("Connect succeeded with %d bytes with ovl %p socket %x"), NumBytes, Overlapped, Connection->Socket[Server]); // // Call SO_UPDATE_CONNECT_CONTEXT so that the socket will be valid // in other winsock calls like shutdown(). // if (setsockopt(Connection->Socket[Server], SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0) == SOCKET_ERROR) { Status = WSAGetLastError(); Trace2(SOCKET, _T("SO_UPDATE_CONNECT_CONTEXT failed on socket %x with error %u"), Connection->Socket[Server], Status); ProcessConnectError(NumBytes, Overlapped, Status); return; } StartReceive(&Connection->DirectionInfo[Inbound]); StartReceive(&Connection->DirectionInfo[Outbound]); // // Release the reference from the connect. // Trace2(FSM, _T("R-- %d %x ProcessConnect"), Connection->ReferenceCount, Connection); DereferenceConnection(&Connection); } // // This is called when an asynchronous connect completes with an error. // VOID ProcessConnectError( IN ULONG NumBytes, IN LPOVERLAPPED Overlapped, IN ULONG Status ) { PDIRECTION_INFO pInbound = CONTAINING_RECORD(Overlapped, DIRECTION_INFO, Overlapped); PCONNECTION_INFO Connection = CONTAINING_RECORD(pInbound, CONNECTION_INFO, DirectionInfo[Inbound]); Trace1(ERR, _T("ProcessConnectError saw error %x"), Status); CloseConnection(&Connection); // // Release the reference from the connect. // Trace2(FSM, _T("R-- %d %x ProcessConnectError"), Connection->ReferenceCount, Connection); DereferenceConnection(&Connection); } // // This is called when an asynchronous send completes successfully. // VOID ProcessSend( IN ULONG NumBytes, IN LPOVERLAPPED Overlapped ) { PDIRECTION_INFO DirectionInfo = CONTAINING_RECORD( Overlapped, DIRECTION_INFO, Overlapped); PCONNECTION_INFO Connection = CONTAINING_RECORD( DirectionInfo, CONNECTION_INFO, DirectionInfo[DirectionInfo->Direction]); // // Post another recv request since we but live to serve. // StartReceive(DirectionInfo); // // Release the reference from the send. // Trace2(FSM, _T("R-- %d %x ProcessSend"), Connection->ReferenceCount, Connection); DereferenceConnection(&Connection); } // // This is called when an asynchronous send completes with an error. // VOID ProcessSendError( IN ULONG NumBytes, IN LPOVERLAPPED Overlapped, IN ULONG Status ) { PDIRECTION_INFO DirectionInfo = CONTAINING_RECORD(Overlapped, DIRECTION_INFO, Overlapped); PCONNECTION_INFO Connection = CONTAINING_RECORD(DirectionInfo, CONNECTION_INFO, DirectionInfo[DirectionInfo->Direction]); Trace3(FSM, _T("WriteFile on ovl %p failed with error %u = 0x%x"), Overlapped, Status, Status); if (Status == ERROR_NETNAME_DELETED) { struct linger Linger; Trace2(FSM, _T("Connection %p %hs was reset"), Connection, (DirectionInfo->Direction == Inbound)? "inbound" : "outbound"); // // Prepare to forward the reset, if we can. // ZeroMemory(&Linger, sizeof(Linger)); setsockopt(Connection->Socket[DirectionInfo->Direction], SOL_SOCKET, SO_LINGER, (char*)&Linger, sizeof(Linger)); } else { Trace1(ERR, _T("Send failed with error %u"), Status); } if (Connection->HalfOpen == FALSE) { // // Other side is still around, tell it to quit. // Trace1(SOCKET, _T("Starting a shutdown on socket %x"), Connection->Socket[DirectionInfo->Direction]); if (shutdown(Connection->Socket[DirectionInfo->Direction], SD_RECEIVE) == SOCKET_ERROR) { Status = WSAGetLastError(); Trace2(SOCKET, _T("shutdown failed with error %u = 0x%x"), Status, Status); CloseConnection(&Connection); } else { Connection->HalfOpen = TRUE; } } else { CloseConnection(&Connection); } // // Release the reference from the send. // Trace2(FSM, _T("R-- %d %x ProcessSendError"), Connection->ReferenceCount, Connection); DereferenceConnection(&Connection); } // // This is called when an asynchronous receive completes successfully. // VOID ProcessReceive( IN ULONG NumBytes, IN LPOVERLAPPED Overlapped ) { PDIRECTION_INFO DirectionInfo; PCONNECTION_INFO Connection; if (NumBytes == 0) { // // Other side initiated a close. // ProcessReceiveError(0, Overlapped, ERROR_NETNAME_DELETED); return; } DirectionInfo = CONTAINING_RECORD(Overlapped, DIRECTION_INFO, Overlapped); Connection = CONTAINING_RECORD(DirectionInfo, CONNECTION_INFO, DirectionInfo[DirectionInfo->Direction]); DirectionInfo->Buffer.buf[NumBytes] = 0; Trace3(SOCKET, _T("Dir %d got %d bytes: !%hs!"), DirectionInfo->Direction, NumBytes, DirectionInfo->Buffer.buf); // // Connection is still active, and we received some data. // Post a send request to forward it onward. // StartSend(DirectionInfo, NumBytes); // // Release the reference from the receive. // Trace2(FSM, _T("R-- %d %x ProcessReceive"), Connection->ReferenceCount, Connection); DereferenceConnection(&Connection); } // // This is called when an asynchronous receive completes with an error. // VOID ProcessReceiveError( IN ULONG NumBytes, IN LPOVERLAPPED Overlapped, IN ULONG Status ) { PDIRECTION_INFO DirectionInfo = CONTAINING_RECORD(Overlapped, DIRECTION_INFO, Overlapped); PCONNECTION_INFO Connection = CONTAINING_RECORD( DirectionInfo, CONNECTION_INFO, DirectionInfo[DirectionInfo->Direction]); Trace3(ERR, _T("ReadFile on ovl %p failed with error %u = 0x%x"), Overlapped, Status, Status); if (Status == ERROR_NETNAME_DELETED) { struct linger Linger; Trace2(FSM, _T("Connection %p %hs was reset"), Connection, (DirectionInfo->Direction == Inbound)? "inbound" : "outbound"); // // Prepare to forward the reset, if we can. // ZeroMemory(&Linger, sizeof(Linger)); setsockopt(Connection->Socket[1 - DirectionInfo->Direction], SOL_SOCKET, SO_LINGER, (char*)&Linger, sizeof(Linger)); } else { Trace1(ERR, _T("Receive failed with error %u"), Status); } if (Connection->HalfOpen == FALSE) { // // Other side is still around, tell it to quit. // Trace1(SOCKET, _T("Starting a shutdown on socket %x"), Connection->Socket[1 - DirectionInfo->Direction]); if (shutdown(Connection->Socket[1 - DirectionInfo->Direction], SD_SEND) == SOCKET_ERROR) { Status = WSAGetLastError(); Trace2(SOCKET, _T("shutdown failed with error %u = 0x%x"), Status, Status); CloseConnection(&Connection); } else { Connection->HalfOpen = TRUE; } } else { CloseConnection(&Connection); } // // Release the reference from the receive. // Trace2(FSM, _T("R-- %d %x ProcessReceiveError"), Connection->ReferenceCount, Connection); DereferenceConnection(&Connection); } // // Main dispatch routine // VOID APIENTRY TpProcessWorkItem( IN ULONG Status, IN ULONG NumBytes, IN LPOVERLAPPED Overlapped ) { OPERATION Operation; Operation = *(OPERATION*)(Overlapped+1); Trace4(SOCKET, _T("TpProcessWorkItem got err %x operation=%hs ovl %p bytes=%d"), Status, OperationName[Operation], Overlapped, NumBytes); if (Status == NO_ERROR) { switch(Operation) { case Accept: ProcessAccept(NumBytes, Overlapped); break; case Connect: ProcessConnect(NumBytes, Overlapped); break; case Receive: ProcessReceive(NumBytes, Overlapped); break; case Send: ProcessSend(NumBytes, Overlapped); break; } } else if (Overlapped) { switch(Operation) { case Accept: ProcessAcceptError(NumBytes, Overlapped, Status); break; case Connect: ProcessConnectError(NumBytes, Overlapped, Status); break; case Receive: ProcessReceiveError(NumBytes, Overlapped, Status); break; case Send: ProcessSendError(NumBytes, Overlapped, Status); break; } } }