windows-nt/Source/XPSP1/NT/net/tcpip/services/tftp/tftp_lib.c

4472 lines
117 KiB
C
Raw Permalink Normal View History

2020-09-26 03:20:57 -05:00
/*++
Copyright (c) 1992-1996 Microsoft Corporation
Module Name:
tftpd.c
Abstract:
This implements an RFC 783 tftp daemon. The tftp daemon listens on its
well-known port waiting for requests. When a valid request comes in, it
spawns a thread to process the request.
Functions Defined:
TftpdErrorPacket - sends an error reply.
TftpdDoRead - read from file and convert.
TftpdHandleRead - incoming read file request.
read file => sendto.
TftpdDoWrite - convert and write to file.
TftpdHandleWrite - incoming write file request, calls TftpdDoWrite().
Author: Sam Patton (sampa) 08-apr-1992
Revision History:
MohsinA, 02-Dec-96.
--*/
#include "tftpd.h"
#if defined(REMOTE_BOOT_SECURITY)
#include <ipsec.h>
#endif // defined(REMOTE_BOOT_SECURITY)
//
// Shared daemon state (work list, its lock, timer queue) defined in
// another module of the daemon.
//
extern TFTP_GLOBALS Globals;
//
// Default message text for each TFTP error code, indexed by the code.
// TftpdErrorPacket() uses entry 0 for any code >= NUM_TFTP_ERROR_CODES
// when the caller supplies no message of its own. Codes 0-7 follow
// RFC 1350; code 8 is the option-negotiation failure of RFC 2347.
//
char * ErrorString[NUM_TFTP_ERROR_CODES] =
{
"Error undefined",
"File not found",
"Access violation",
"Disk full or allocation exceeded",
"Illegal TFTP operation",
"Unknown transfer ID",
"File already exists",
"No such user",
"Option negotiation failure"
};
#if defined(REMOTE_BOOT_SECURITY)
//
// These routines manage the security info structures for clients
// that have logged in. All code that depends on how the structures
// are actually stored is kept in these functions, so the storage
// can change without affecting callers.
//
// SecurityArray           - growable array of per-client entries; an entry
//                           whose Validation is 0 is free.
// SecurityArrayLength     - number of entries allocated in SecurityArray.
// SecurityValidation      - value stamped into the next allocated entry's
//                           Validation; callers later present it back to
//                           prove they hold a current handle.
// SecurityCriticalSection - held by the routines below while they read or
//                           modify the three variables above.
//
PTFTPD_SECURITY SecurityArray = NULL;
USHORT SecurityArrayLength = 0;
USHORT SecurityValidation;
CRITICAL_SECTION SecurityCriticalSection;
#define INITIAL_SECURITY_ARRAY_SIZE 8
//
// Decode the two hexadecimal characters at HexDigit[0] (high nibble) and
// HexDigit[1] (low nibble) into one byte. Exactly two characters are read.
// Upper- and lower-case digits are accepted; any other character
// contributes a nibble of 0.
//
UCHAR
TftpdHexDigitToChar(
    PUCHAR HexDigit
    )
{
    UCHAR Result = 0;
    PUCHAR Cursor;

    for (Cursor = HexDigit; Cursor < HexDigit + 2; Cursor++) {
        UCHAR Digit = *Cursor;
        UCHAR Value;

        if (Digit >= '0' && Digit <= '9') {
            Value = (UCHAR)(Digit - '0');
        } else if (Digit >= 'a' && Digit <= 'f') {
            Value = (UCHAR)(Digit - 'a' + 10);
        } else if (Digit >= 'A' && Digit <= 'F') {
            Value = (UCHAR)(Digit - 'A' + 10);
        } else {
            Value = 0;
        }
        Result = (UCHAR)((Result << 4) | Value);
    }
    return Result;
}
//
// Allocate the initial security array (all entries free), pick a starting
// validation value and initialize the critical section that guards the
// array. Returns FALSE if the array cannot be allocated.
//
BOOL
TftpdInitSecurityArray(
    VOID
    )
{
    int i;

    SecurityArray = malloc(sizeof(*SecurityArray) * INITIAL_SECURITY_ARRAY_SIZE);
    if (SecurityArray == NULL) {
        DbgPrint("TftpdInitSecurityArray: cannot allocate security array\n");
        return FALSE;
    }
    for (i = 0; i < INITIAL_SECURITY_ARRAY_SIZE; i++) {
        SecurityArray[i].Validation = 0; // means this entry is free
        SecurityArray[i].LastFileRead[0] = '\0';
    }
    SecurityArrayLength = INITIAL_SECURITY_ARRAY_SIZE;
    srand((unsigned)time(NULL));
    //
    // A Validation of 0 marks a free entry, so the first value handed out
    // must never be 0. Use the same 1..10000 range that
    // TftpdAllocateSecurityEntry() cycles through afterwards.
    //
    SecurityValidation = (USHORT)((rand() % 10000) + 1);
    RtlInitializeCriticalSection(&SecurityCriticalSection);
    return TRUE;
}
VOID
TftpdUninitSecurityArray(
VOID
)
{
free(SecurityArray);
}
//
// Reserve a free security entry, growing the array if none is free.
// The entry is stamped with the current SecurityValidation, its handle and
// key flags and its retransmission cache (LastFileRead) are reset, and a
// copy is returned in *Security with its index in *Index.
// Returns FALSE if no entry can be provided (allocation failure, or the
// array is already as large as a USHORT index allows).
//
BOOL
TftpdAllocateSecurityEntry(
    PUSHORT Index,
    PTFTPD_SECURITY Security
    )
{
    USHORT i, j;

    RtlEnterCriticalSection (&SecurityCriticalSection);
    for (i = 0; i < SecurityArrayLength; i++) {
        if (SecurityArray[i].Validation == 0) {
            break;
        }
    }
    if (i == SecurityArrayLength) {
        PTFTPD_SECURITY NewSecurity;
        USHORT NewSecurityLength;

        //
        // Could not find a spot, grow the array. If it is already at the
        // USHORT limit it cannot grow, and i would be out of range.
        //
        if (SecurityArrayLength == 0xffff) {
            RtlLeaveCriticalSection (&SecurityCriticalSection);
            return FALSE;
        }
        if (SecurityArrayLength == 0) {
            NewSecurityLength = INITIAL_SECURITY_ARRAY_SIZE;
        } else if (SecurityArrayLength < 0x8000) {
            NewSecurityLength = (USHORT)(SecurityArrayLength * 2);
        } else {
            NewSecurityLength = 0xffff;
        }
        //
        // realloc keeps the existing entries and releases the old block
        // (the previous malloc+memcpy leaked it). On failure the old array
        // is still valid and untouched.
        //
        NewSecurity = realloc(SecurityArray, sizeof(*NewSecurity) * NewSecurityLength);
        if (NewSecurity == NULL) {
            RtlLeaveCriticalSection (&SecurityCriticalSection);
            return FALSE;
        }
        for (j = SecurityArrayLength; j < NewSecurityLength; j++) {
            NewSecurity[j].Validation = 0; // means this entry is free
            NewSecurity[j].LastFileRead[0] = '\0';
        }
        i = SecurityArrayLength;
        SecurityArray = NewSecurity;
        SecurityArrayLength = NewSecurityLength;
    }
    SecurityArray[i].Validation = SecurityValidation;
    SecurityArray[i].CredentialsHandleValid = FALSE;
    SecurityArray[i].ServerContextHandleValid = FALSE;
    SecurityArray[i].GeneratedKey = FALSE;
    //
    // A reused entry must not inherit the previous client's cached
    // filename/signature, or TftpdVerifyFileSignature() could accept a
    // replayed signature without calling VerifySignature.
    //
    SecurityArray[i].LastFileRead[0] = '\0';
    SecurityValidation = (SecurityValidation % 10000) + 1;
    *Security = SecurityArray[i];
    RtlLeaveCriticalSection (&SecurityCriticalSection);
    *Index = i;
    return TRUE;
}
//
// Mark the security entry at Index free. Its SSPI context and credential
// handles (when valid) are released after the lock has been dropped,
// using a copy of the entry taken while the lock was held.
// Out-of-range indexes are ignored.
//
VOID
TftpdFreeSecurityEntry(
    USHORT Index
    )
{
    TFTPD_SECURITY Released;
    BOOL InRange;

    RtlEnterCriticalSection (&SecurityCriticalSection);
    InRange = (Index < SecurityArrayLength);
    if (InRange) {
        Released = SecurityArray[Index];
        SecurityArray[Index].Validation = 0; // this marks it as free
    }
    RtlLeaveCriticalSection (&SecurityCriticalSection);

    if (!InRange) {
        return;
    }
    if (Released.ServerContextHandleValid) {
        DeleteSecurityContext(&Released.ServerContextHandle);
    }
    if (Released.CredentialsHandleValid) {
        FreeCredentialsHandle(&Released.CredentialsHandle);
    }
}
//
// Copy the security entry at Index into *Security, or zero *Security if
// Index is out of range.
//
VOID
TftpdGetSecurityEntry(
    USHORT Index,
    PTFTPD_SECURITY Security
    )
{
    RtlEnterCriticalSection (&SecurityCriticalSection);
    if (Index >= SecurityArrayLength) {
        memset(Security, 0, sizeof(TFTPD_SECURITY));
    } else {
        memcpy(Security, &SecurityArray[Index], sizeof(TFTPD_SECURITY));
    }
    RtlLeaveCriticalSection (&SecurityCriticalSection);
}
//
// Overwrite the security entry at Index with *Security. Out-of-range
// indexes are silently ignored.
//
VOID
TftpdStoreSecurityEntry(
    USHORT Index,
    PTFTPD_SECURITY Security
    )
{
    BOOL InRange;

    RtlEnterCriticalSection (&SecurityCriticalSection);
    InRange = (Index < SecurityArrayLength);
    if (InRange) {
        memcpy(&SecurityArray[Index], Security, sizeof(TFTPD_SECURITY));
    }
    RtlLeaveCriticalSection (&SecurityCriticalSection);
}
//
// Make sure the security entry at Index has a generated key. The first
// time this succeeds for an entry, a key is derived from the system time,
// stored in Key and SignedKey, and SignedKey is sealed in place with the
// entry's server context (signature written to Sign). GeneratedKey is set
// only if SealMessage succeeds, so a failed seal is retried on the next
// call. A copy of the entry is returned in *Security (zeroed if Index is
// out of range).
//
VOID
TftpdGenerateKeyForSecurityEntry(
    USHORT Index,
    PTFTPD_SECURITY Security
    )
{
    SecBufferDesc SignMessage;
    SecBuffer SigBuffers[2];
    SECURITY_STATUS SecStatus;
    LARGE_INTEGER SystemTime;
    PTFTPD_SECURITY Entry;
    ULONG Divisor;
    ULONG Key;

    RtlEnterCriticalSection (&SecurityCriticalSection);
    if (Index >= SecurityArrayLength) {
        memset(Security, 0, sizeof(TFTPD_SECURITY));
        RtlLeaveCriticalSection (&SecurityCriticalSection);
        return;
    }
    Entry = &SecurityArray[Index];
    if (!Entry->GeneratedKey) {
        //
        // Generate and sign a key: the system time modulo the client
        // address. A zero address would be a division by zero, so fall
        // back to the raw time in that case.
        //
        NtQuerySystemTime(&SystemTime);
        Divisor = (ULONG)Entry->ForeignAddress.sin_addr.s_addr;
        if (Divisor != 0) {
            Key = (ULONG)(SystemTime.QuadPart % Divisor);
        } else {
            Key = (ULONG)SystemTime.QuadPart;
        }
        Entry->Key = Key;
        // memcpy avoids an unaligned / type-punned ULONG store.
        memcpy(Entry->SignedKey, &Key, sizeof(Key));
        SigBuffers[0].pvBuffer = Entry->SignedKey;
        SigBuffers[0].cbBuffer = sizeof(Entry->SignedKey);
        SigBuffers[0].BufferType = SECBUFFER_DATA;
        SigBuffers[1].pvBuffer = Entry->Sign;
        SigBuffers[1].cbBuffer = NTLMSSP_MESSAGE_SIGNATURE_SIZE;
        SigBuffers[1].BufferType = SECBUFFER_TOKEN;
        SignMessage.pBuffers = SigBuffers;
        SignMessage.cBuffers = 2;
        SignMessage.ulVersion = 0;
        SecStatus = SealMessage(
                        &(Entry->ServerContextHandle),
                        0,
                        &SignMessage,
                        0 );
        if (SecStatus == SEC_E_OK) {
            Entry->GeneratedKey = TRUE;
        }
    }
    *Security = *Entry;
    RtlLeaveCriticalSection (&SecurityCriticalSection);
}
SECURITY_STATUS
TftpdVerifyFileSignature(
    USHORT Index,
    USHORT Validation,
    PTFTPD_SECURITY Security,
    char * FileName,
    char * Sign,
    USHORT ClientPort
    )
/*++
Routine Description:
    Checks the client-supplied signature Sign over FileName using the
    security entry identified by (Index, Validation).
    If the tail of FileName (the part that fits in LastFileRead) and
    ClientPort match the previous request on this entry, the request is
    treated as a retransmission: Sign is compared with the saved signature
    instead of calling VerifySignature again, which would get the server
    out of step with the client's MakeSignature sequence. Otherwise the
    filename tail, signature and port are saved and VerifySignature is
    called.
Arguments:
    Index, Validation - identify the entry; Validation must be non-zero
                        (0 marks a free entry) and match the entry.
    Security          - receives a copy of the entry; zeroed when the
                        entry is not found.
    FileName          - requested file name (NUL-terminated).
    Sign              - NTLMSSP_MESSAGE_SIGNATURE_SIZE signature bytes.
    ClientPort        - client port the request arrived from.
Return Value:
    VerifySignature's status, SEC_E_OK / SEC_E_MESSAGE_ALTERED for a
    retransmission, or STATUS_INVALID_HANDLE if the entry is not found.
--*/
{
    unsigned long FileNameLength;
    char * CompareFileName;
    SecBufferDesc SignMessage;
    SecBuffer SigBuffers[2];
    SECURITY_STATUS SecStatus;
    PTFTPD_SECURITY TmpSecurity; // points to the real location in the array

    //
    // Only the tail of the filename that fits in LastFileRead (including
    // its NUL) is saved, so compare just that tail.
    //
    FileNameLength = (unsigned long)strlen(FileName);
    if (FileNameLength < sizeof(Security->LastFileRead)) {
        CompareFileName = FileName;
    } else {
        CompareFileName = FileName + (FileNameLength + 1 - sizeof(Security->LastFileRead));
    }

    RtlEnterCriticalSection (&SecurityCriticalSection);
    if ((Validation == 0) ||
        (Index >= SecurityArrayLength) ||
        (SecurityArray[Index].Validation != Validation)) {
        //
        // The critical section must be released on this path too;
        // returning with it held would block every other caller.
        //
        RtlLeaveCriticalSection (&SecurityCriticalSection);
        memset(Security, 0, sizeof(TFTPD_SECURITY));
        return (SECURITY_STATUS)STATUS_INVALID_HANDLE;
    }
    TmpSecurity = &SecurityArray[Index];

    if ((strcmp(CompareFileName, TmpSecurity->LastFileRead) == 0) &&
        (ClientPort == TmpSecurity->LastFileReadPort)) {
        //
        // Retransmission: compare with the saved sign, and fake a
        // security error if they don't match.
        //
        if (memcmp(TmpSecurity->LastFileSign, Sign, NTLMSSP_MESSAGE_SIGNATURE_SIZE) == 0) {
            SecStatus = SEC_E_OK;
        } else {
            SecStatus = SEC_E_MESSAGE_ALTERED;
        }
    } else {
        //
        // Save the values in case this request is resent.
        //
        strcpy(TmpSecurity->LastFileRead, CompareFileName);
        memcpy(TmpSecurity->LastFileSign, Sign, NTLMSSP_MESSAGE_SIGNATURE_SIZE);
        TmpSecurity->LastFileReadPort = ClientPort;
        //
        // Now make sure the signature is correct.
        //
        SigBuffers[1].pvBuffer = Sign;
        SigBuffers[1].cbBuffer = NTLMSSP_MESSAGE_SIGNATURE_SIZE;
        SigBuffers[1].BufferType = SECBUFFER_TOKEN;
        SigBuffers[0].pvBuffer = FileName;
        SigBuffers[0].cbBuffer = FileNameLength;
        SigBuffers[0].BufferType = SECBUFFER_DATA | SECBUFFER_READONLY;
        SignMessage.pBuffers = SigBuffers;
        SignMessage.cBuffers = 2;
        SignMessage.ulVersion = 0;
        SecStatus = VerifySignature(
                        &TmpSecurity->ServerContextHandle,
                        &SignMessage,
                        0,
                        0 );
    }
    *Security = *TmpSecurity;
    RtlLeaveCriticalSection (&SecurityCriticalSection);
    return SecStatus;
}
#endif // defined(REMOTE_BOOT_SECURITY)
// ========================================================================
VOID
TftpdErrorPacket(
    struct sockaddr * PeerAddress,
    char * RequestPacket,
    SOCKET LocalSocket,
    unsigned short ErrorCode,
    char * ErrorMessage OPTIONAL
    )
/*++
Routine Description:
    Sends a TFTP ERROR packet (opcode, error code, NUL-terminated message)
    to PeerAddress from LocalSocket.
Arguments:
    PeerAddress   - the remote address (sent as a sockaddr_in).
    RequestPacket - the packet that provoked the error; currently unused.
    LocalSocket   - socket to send the error from.
    ErrorCode     - TFTP error code to report.
    ErrorMessage  - optional message text. When NULL, the default text for
                    ErrorCode is used, and a code outside the known range
                    is reported as 0 ("Error undefined"). Messages too long
                    for one datagram are truncated.
Return Value:
    None. A sendto failure is only logged.
--*/
{
    char ErrorPacket[MAX_TFTP_DATAGRAM];
    const char *message;
    size_t messageLength;
    // Room left after the opcode, the error code and the terminating NUL.
    size_t maxMessageLength = sizeof(ErrorPacket) - 5;
    int err;

    DbgPrint("TftpdError: Sending error packet Error Code: %d\n",ErrorCode);
    if ( ErrorMessage != NULL ) {
        message = ErrorMessage;
    } else {
        if (ErrorCode >= NUM_TFTP_ERROR_CODES) {
            DbgPrint("TftpdErrorPacket: Unknown ErrorCode=%d.\n",
                     ErrorCode );
            ErrorCode = 0;
        }
        message = ErrorString[ErrorCode];
    }

    //
    // Build the header only after ErrorCode has been normalized, so the
    // code on the wire matches the message text.
    //
    ((unsigned short *) ErrorPacket)[0] = htons(TFTPD_ERROR);
    ((unsigned short *) ErrorPacket)[1] = htons(ErrorCode);

    messageLength = strlen(message);
    if (messageLength > maxMessageLength) {
        messageLength = maxMessageLength;
    }
    memcpy(&ErrorPacket[4], message, messageLength);
    ErrorPacket[4 + messageLength] = '\0';

    err = sendto(
              LocalSocket,
              ErrorPacket,
              (int)(5 + messageLength),
              0,
              PeerAddress,
              sizeof(struct sockaddr_in));
    if( SOCKET_ERROR == err ){
        DbgPrint("TftpdErrorPacket: sendto failed=%d\n",
                 WSAGetLastError() );
    }
    return;
}
#if defined(REMOTE_BOOT_SECURITY)
//
// First pass over the request options: picks up the "security" handle and
// the hex-encoded "sign" value for read requests. Other options are
// skipped. Parsing stops at the first empty string, as in phase 2.
// Always returns 0.
//
int
TftpdProcessOptionsPhase1(
    PTFTP_REQUEST Request,
    PUCHAR Options,
    int Opcode
    )
{
    int i;
    UCHAR hex[2 * NTLMSSP_MESSAGE_SIGNATURE_SIZE];
    size_t valueLength;

    //
    // Assume default values.
    //
    Request->SecurityHandle = 0;

    //
    // Walk through the remainder of the request packet. In each iteration
    // Options is advanced to the option's value; the shared tail below
    // then skips that value.
    //
    while ( *Options != 0 ) {
        if ( _stricmp(Options, "security") == 0 ) {
            Options += sizeof("security");
            if ( Opcode == TFTPD_RRQ ) {
                Request->SecurityHandle = atoi(Options);
            }
        } else if ( _stricmp(Options, "sign") == 0 ) {
            Options += sizeof("sign");
            if ( Opcode == TFTPD_RRQ ) {
                //
                // The value comes from the client and may be shorter than
                // the 2*NTLMSSP_MESSAGE_SIGNATURE_SIZE hex digits decoded
                // below. Decode from a zero-padded copy so we never read
                // past the end of the option string.
                //
                valueLength = strlen(Options);
                if (valueLength > sizeof(hex)) {
                    valueLength = sizeof(hex);
                }
                memset(hex, 0, sizeof(hex));
                memcpy(hex, Options, valueLength);
                for (i = 0; i < NTLMSSP_MESSAGE_SIGNATURE_SIZE; i++) {
                    Request->Sign[i] = TftpdHexDigitToChar(&hex[i*2]);
                }
            }
        } else {
            //
            // Unrecognized option. Skip the option ID; the value is
            // skipped below.
            //
            Options += strlen(Options) + 1;
        }
        //
        // Skip the value, unless it is empty: an empty string is where
        // parsing stops, so stepping over it could run past the end of
        // the option list.
        //
        if ( *Options != 0 ) {
            Options += strlen(Options) + 1;
        }
    }
    return 0;
}
#endif // defined(REMOTE_BOOT_SECURITY)
//
// Append the NUL-terminated string Text (including its NUL) at *Cursor and
// advance *Cursor past it. Writes nothing and returns FALSE if the string
// would not fit before Limit.
//
static BOOL
TftpdOackAppend(
    PUCHAR *Cursor,
    PUCHAR Limit,
    const char *Text
    )
{
    size_t needed = strlen(Text) + 1;

    if ( (size_t)(Limit - *Cursor) < needed ) {
        return FALSE;
    }
    memcpy( *Cursor, Text, needed );
    *Cursor += needed;
    return TRUE;
}

//
// Send an option-negotiation-failed error for Request and return -1, the
// failure value of TftpdProcessOptionsPhase2().
//
static int
TftpdOptionNegotiationFailed(
    PTFTP_REQUEST Request
    )
{
    TftpdErrorPacket(
        (struct sockaddr *)&Request->ForeignAddress,
        Request->Packet2,
        Request->TftpdPort,
        TFTPD_ERROR_OPTION_NEGOT_FAILED,
        NULL);
    return -1;
}

int
TftpdProcessOptionsPhase2(
    PTFTP_REQUEST Request,
    PUCHAR Options,
    int Opcode,
    int *OackLength,
    char *PacketBuffer,
    BOOL *ReceivedTimeoutOption
    )
/*++
Routine Description:
    Parses the option/value string pairs of a RRQ or WRQ (stopping at the
    first empty string), applies the ones we understand to Request and
    builds the OACK reply in PacketBuffer. Recognized options: blksize,
    timeout, tsize, and security when REMOTE_BOOT_SECURITY is defined.
    Unrecognized options are skipped and not acknowledged.
Arguments:
    Request      - request being processed; BlockSize and Timeout are set
                   (defaults MAX_OACK_PACKET_LENGTH - 4 and 10).
    Options      - first option name in the request packet.
    Opcode       - TFTPD_RRQ or TFTPD_WRQ.
    OackLength   - receives the OACK length in bytes, or 0 if nothing was
                   acknowledged (no OACK should be sent).
    PacketBuffer - buffer of at least MAX_OACK_PACKET_LENGTH bytes that
                   receives the OACK.
    ReceivedTimeoutOption - set TRUE if a valid timeout option was seen.
Return Value:
    0 on success; -1 after an option-negotiation error has been sent
    (invalid blksize/timeout, or an OACK that would not fit in
    MAX_OACK_PACKET_LENGTH bytes).
--*/
{
    PUCHAR oack;
    PUCHAR oackLimit;
    PUCHAR optionStart;
    char number[34];    // large enough for any _itoa result

    //
    // Assume default values.
    //
    Request->BlockSize = MAX_OACK_PACKET_LENGTH - 4; // Save an alloc mem by default to the current packet size.
    Request->Timeout = 10;
    *ReceivedTimeoutOption = FALSE;

    //
    // Build the OACK header. Every write after it goes through
    // TftpdOackAppend() so client-controlled option text (long values,
    // repeated options) cannot overflow PacketBuffer.
    //
    memset( PacketBuffer, 0, MAX_OACK_PACKET_LENGTH );
    ((unsigned short *)PacketBuffer)[0] = htons(TFTPD_OACK);
    oack = (PUCHAR)&PacketBuffer[2];
    oackLimit = (PUCHAR)PacketBuffer + MAX_OACK_PACKET_LENGTH;

    //
    // Walk through the remainder of the request packet, looking for options
    // that we understand.
    //
    while ( *Options != 0 ) {
        if ( _stricmp(Options, "blksize") == 0 ) {
            optionStart = oack;
            if ( !TftpdOackAppend(&oack, oackLimit, (const char *)Options) ) {
                DbgPrint("TftpdProcessOptionsPhase2: OACK too long\n");
                return TftpdOptionNegotiationFailed(Request);
            }
            Options += sizeof("blksize");
            Request->BlockSize = atoi(Options);
            if ( (Request->BlockSize < 8) ||
                 (Request->BlockSize > 65464) ) {
                DbgPrint("TftpdProcessOptionsPhase2: invalid blksize=%s\n", Options );
                return TftpdOptionNegotiationFailed(Request);
            }
            //
            // Workaround for problem in .98 version of ROM, which
            // doesn't like our OACK response. If the requested blksize is
            // 1456, pretend that the option wasn't specified. In the case
            // of the ROM's TFTP layer, this is the only option specified,
            // so ignoring it will mean that we don't send an OACK, and the
            // ROM will deign to talk to us. Note that our TFTP code uses
            // a blksize of 1432, so this workaround won't affect us.
            //
            if ( Request->BlockSize == 1456 ) {
                Request->BlockSize = MAX_OACK_PACKET_LENGTH - 4;
                oack = optionStart;     // withdraw the option name
                Options += strlen(Options) + 1;
                continue;
            }
            if ( Request->BlockSize > MAX_TFTP_DATA ) {
                Request->BlockSize = MAX_TFTP_DATA;
            }
            _itoa( Request->BlockSize, number, 10 );
            if ( !TftpdOackAppend(&oack, oackLimit, number) ) {
                DbgPrint("TftpdProcessOptionsPhase2: OACK too long\n");
                return TftpdOptionNegotiationFailed(Request);
            }
            Options += strlen(Options) + 1;
        } else if ( _stricmp(Options, "timeout") == 0 ) {
            if ( !TftpdOackAppend(&oack, oackLimit, (const char *)Options) ) {
                DbgPrint("TftpdProcessOptionsPhase2: OACK too long\n");
                return TftpdOptionNegotiationFailed(Request);
            }
            Options += sizeof("timeout");
            Request->Timeout = atoi(Options);
            if ( (Request->Timeout < 1) ||
                 (Request->Timeout > 255) ) {
                DbgPrint("TftpdProcessOptionsPhase2: invalid timeout=%s\n", Options );
                return TftpdOptionNegotiationFailed(Request);
            }
            *ReceivedTimeoutOption = TRUE;
            if ( !TftpdOackAppend(&oack, oackLimit, (const char *)Options) ) {
                DbgPrint("TftpdProcessOptionsPhase2: OACK too long\n");
                return TftpdOptionNegotiationFailed(Request);
            }
            Options += strlen(Options) + 1;
        } else if ( _stricmp(Options, "tsize") == 0 ) {
            if ( !TftpdOackAppend(&oack, oackLimit, (const char *)Options) ) {
                DbgPrint("TftpdProcessOptionsPhase2: OACK too long\n");
                return TftpdOptionNegotiationFailed(Request);
            }
            Options += sizeof("tsize");
            if ( Opcode == TFTPD_WRQ ) {
                // Echo the size the client announced.
                if ( !TftpdOackAppend(&oack, oackLimit, (const char *)Options) ) {
                    DbgPrint("TftpdProcessOptionsPhase2: OACK too long\n");
                    return TftpdOptionNegotiationFailed(Request);
                }
            } else {
                // Report the size of the file being read.
                _itoa( Request->FileSize, number, 10 );
                if ( !TftpdOackAppend(&oack, oackLimit, number) ) {
                    DbgPrint("TftpdProcessOptionsPhase2: OACK too long\n");
                    return TftpdOptionNegotiationFailed(Request);
                }
            }
            // An empty value is where parsing stops; don't step past it.
            if ( *Options != 0 ) {
                Options += strlen(Options) + 1;
            }
#if defined(REMOTE_BOOT_SECURITY)
        } else if ( _stricmp(Options, "security") == 0 ) {
            //
            // We process this just so that we can copy it to the OACK.
            // Only read requests use it (see phase 1), so for writes the
            // option is omitted entirely rather than echoing a name with
            // no value.
            //
            // Should really copy over Request->Security, in case
            // it has since become 0, to show the client we reject the
            // security option for some reason.
            //
            if ( Opcode == TFTPD_RRQ ) {
                if ( !TftpdOackAppend(&oack, oackLimit, (const char *)Options) ) {
                    DbgPrint("TftpdProcessOptionsPhase2: OACK too long\n");
                    return TftpdOptionNegotiationFailed(Request);
                }
                Options += sizeof("security");
                if ( !TftpdOackAppend(&oack, oackLimit, (const char *)Options) ) {
                    DbgPrint("TftpdProcessOptionsPhase2: OACK too long\n");
                    return TftpdOptionNegotiationFailed(Request);
                }
            } else {
                Options += sizeof("security");
            }
            if ( *Options != 0 ) {
                Options += strlen(Options) + 1;
            }
#endif //defined (REMOTE_BOOT_SECURITY)
        } else {
            //
            // Unrecognized option. Skip the option ID and the value.
            //
            Options += strlen(Options) + 1;
            if ( *Options != 0 ) {
                Options += strlen(Options) + 1;
            }
        }
    }

    *OackLength = (int)(oack - (PUCHAR)PacketBuffer);
    if ( *OackLength == 2 ) {
        // Only the opcode: nothing was acknowledged, so send no OACK.
        *OackLength = 0;
    }
    return 0;
}
#define IS_SEPARATOR(c) (((c) == '\\') || ((c) == '/'))
BOOL
TftpdCanonicalizeFileName(
    IN OUT PUCHAR FileName
    )
/*++
Routine Description:
    Canonicalizes a client-supplied relative path in place so that it
    cannot refer to anything above the TFTP root once a root prefix is
    prepended.
    - '\' and '/' are both separators. Runs of separators are collapsed,
      leading and trailing separators are dropped, and the output always
      uses '\'.
    - Trailing blanks are removed from every component. A component that
      is then empty is dropped.
    - "." components are dropped. A ".." component removes the component
      before it.
    - The name is rejected if a ".." would climb above the root, or if a
      component consists only of dots and blanks but is not exactly "."
      or "..". Examples are "...", ". ." and ".. " once its blank is
      trimmed. Win32 path handling may trim such names into "." or "..",
      which would otherwise allow escaping the root.
Arguments:
    FileName - NUL-terminated path; rewritten in place (never grows).
Return Value:
    TRUE if FileName now holds the canonical path (possibly empty).
    FALSE if the path was rejected; FileName may then be partially
    rewritten.
--*/
{
    PUCHAR source = FileName;       // next unread input character
    PUCHAR destination = FileName;  // end of the canonical output so far

    for (;;) {
        PUCHAR componentStart;
        PUCHAR componentEnd;
        PUCHAR scan;
        size_t componentLength;
        BOOL   dotsAndBlanksOnly;

        //
        // Skip the separator(s) in front of the next component. This also
        // strips leading separators and collapses "dir\\\\file".
        //
        while ( IS_SEPARATOR(*source) ) {
            source++;
        }
        if ( *source == 0 ) {
            break;
        }

        //
        // Find the end of the component, then back up over trailing blanks.
        //
        componentStart = source;
        while ( (*source != 0) && !IS_SEPARATOR(*source) ) {
            source++;
        }
        componentEnd = source;
        while ( (componentEnd > componentStart) && (componentEnd[-1] == ' ') ) {
            componentEnd--;
        }
        componentLength = (size_t)(componentEnd - componentStart);
        if ( componentLength == 0 ) {
            continue;           // component was nothing but blanks
        }

        dotsAndBlanksOnly = TRUE;
        for ( scan = componentStart; scan < componentEnd; scan++ ) {
            if ( (*scan != '.') && (*scan != ' ') ) {
                dotsAndBlanksOnly = FALSE;
                break;
            }
        }

        if ( dotsAndBlanksOnly ) {
            //
            // Trailing blanks were trimmed, so the last character is '.'.
            //
            if ( componentLength == 1 ) {
                continue;       // "." - the current directory
            }
            if ( (componentLength == 2) && (componentStart[0] == '.') ) {
                //
                // ".." - remove the last component already written. If
                // there is none, the client is trying to go above the
                // TFTPD root.
                //
                if ( destination == FileName ) {
                    return FALSE;
                }
                while ( (destination > FileName) && (destination[-1] != '\\') ) {
                    destination--;
                }
                if ( destination > FileName ) {
                    destination--;  // also drop the separator in front of it
                }
                continue;
            }
            return FALSE;       // ambiguous: "...", ". .", " .", ...
        }

        //
        // Ordinary component: append it, preceded by a separator unless it
        // is the first one. destination never passes source, so the copy
        // only moves text toward the front; memmove handles the overlap.
        //
        if ( destination != FileName ) {
            *destination++ = '\\';
        }
        memmove( destination, componentStart, componentLength );
        destination += componentLength;
    }

    //
    // Terminate the destination string.
    //
    *destination = '\0';
    return TRUE;
}
BOOL
TftpdPrependStringToFileName(
    IN OUT PUCHAR FileName,
    IN ULONG FileNameLength,
    IN PCHAR Prefix
    )
/*++
Routine Description:
    Prepends Prefix to the NUL-terminated string in FileName, in place.
    Exactly one '\' ends up between them: one is inserted when neither
    side has it, and a doubled one is collapsed.
Arguments:
    FileName       - buffer holding the current name; receives the result.
    FileNameLength - total size of the FileName buffer in bytes, including
                     room for the terminating NUL.
    Prefix         - string to put in front.
Return Value:
    TRUE on success (an empty Prefix leaves FileName unchanged).
    FALSE if the result would not fit in FileNameLength bytes.
--*/
{
    BOOL prefixHasSeparator;
    BOOL currentFileNameHasSeparator;
    ULONG prefixLength;
    ULONG separatorLength;
    ULONG currentFileNameLength;

    //
    // A zero-sized buffer cannot hold even the terminator, and would make
    // "FileNameLength - 1" below wrap around.
    //
    if ( FileNameLength == 0 ) {
        return FALSE;
    }

    prefixLength = (ULONG)strlen( Prefix );
    currentFileNameLength = (ULONG)strlen( FileName );

    //
    // Nothing to prepend. This also avoids reading Prefix[-1] below.
    //
    if ( prefixLength == 0 ) {
        return TRUE;
    }

    prefixHasSeparator = (BOOL)(Prefix[prefixLength - 1] == '\\');
    currentFileNameHasSeparator = (BOOL)(FileName[0] == '\\');
    if ( prefixHasSeparator || currentFileNameHasSeparator ) {
        separatorLength = 0;
        if ( prefixHasSeparator && currentFileNameHasSeparator ) {
            prefixLength--;     // drop one of the two backslashes
        }
    } else {
        separatorLength = 1;
    }
    if ( (prefixLength + separatorLength + currentFileNameLength) > (FileNameLength - 1) ) {
        return FALSE;
    }

    //
    // Move the existing string (with its NUL) down to make room for the
    // prefix, then copy the prefix into place.
    //
    memmove( FileName + prefixLength + separatorLength, FileName, currentFileNameLength + 1 );
    memcpy( FileName, Prefix, prefixLength );

    //
    // If necessary, insert a backslash between the prefix and the file name.
    //
    if ( separatorLength != 0 ) {
        FileName[prefixLength] = '\\';
    }
    return TRUE;
}
BOOL
TftpdGetNextReadPacket(
PTFTP_READ_CONTEXT Context,
PTFTP_REQUEST Request
)
/*++
Routine Description:
    Prepares the next packet of a read transfer in Context->Packet and sets
    Context->packetLength.
    - If an OACK is pending (Context->oackLength != 0), it is used as the
      "block 0" packet. This assumes the caller already built the OACK in
      Context->Packet (TODO confirm against the caller).
    - Otherwise a DATA header for Context->BlockNumber is written and up to
      Request->BlockSize payload bytes are added. With REMOTE_BOOT_SECURITY
      and a security handle, the whole file is read and sealed into
      EncryptFileBuffer before the first data packet, and then sent from
      there in BlockSize slices. Without it, the bytes are read from
      Context->fd.
Arguments:
    Context - read transfer context (packet buffer, fd, block number).
    Request - request parameters (BlockSize, FileSize, peer address).
Return Value:
    TRUE  - Context->Packet / packetLength hold the next packet.
    FALSE - a read or encryption failure. In the security path an ERROR
            packet is also sent via TftpdErrorPacket(). In the plain path
            nothing is sent; errno is logged and passed to SetLastError().
--*/
{
#if defined(REMOTE_BOOT_SECURITY)
SECURITY_STATUS SecStatus;
#endif //defined(REMOTE_BOOT_SECURITY)
if ( Context->oackLength != 0 ) {
//
// The first "data packet" sent will really be the OACK.
//
Context->packetLength = Context->oackLength;
Context->oackLength = 0;
Context->BlockNumber = 0;
Context->BytesRead = Request->BlockSize; // to prevent exit condition from being true
} else {
((unsigned short *) Context->Packet)[0] = htons(TFTPD_DATA);
((unsigned short *) Context->Packet)[1] = htons(Context->BlockNumber);
#if defined(REMOTE_BOOT_SECURITY)
if (Request->SecurityHandle) {
if (Context->EncryptBytesSent == 0) {
//
// Read the file before sending the first data packet. The payload goes
// after a NTLMSSP_MESSAGE_SIGNATURE_SIZE-byte slot that SealMessage
// fills with the signature, so signature + sealed file are sent as one
// contiguous stream.
//
Context->BytesRead = _read(
Context->fd,
Context->EncryptFileBuffer + NTLMSSP_MESSAGE_SIGNATURE_SIZE,
Request->FileSize);
if (Context->BytesRead != Request->FileSize) {
DbgPrint("TftpdHandleRead: Could not read EncryptFileBuffer=%d.\n", errno);
TftpdErrorPacket(
(struct sockaddr *) &Request->ForeignAddress,
Request->Packet2,
Request->TftpdPort,
TFTPD_ERROR_UNDEFINED,
"Insufficient resources");
goto cleanup;
}
//
// We have it in memory, so encrypt it.
//
Context->SigBuffers[0].pvBuffer = Context->EncryptFileBuffer + NTLMSSP_MESSAGE_SIGNATURE_SIZE;
Context->SigBuffers[0].cbBuffer = Request->FileSize;
Context->SigBuffers[0].BufferType = SECBUFFER_DATA;
Context->SigBuffers[1].pvBuffer = Context->EncryptFileBuffer;
Context->SigBuffers[1].cbBuffer = NTLMSSP_MESSAGE_SIGNATURE_SIZE;
Context->SigBuffers[1].BufferType = SECBUFFER_TOKEN;
Context->SignMessage.pBuffers = Context->SigBuffers;
Context->SignMessage.cBuffers = 2;
Context->SignMessage.ulVersion = 0;
SecStatus = SealMessage(
&Context->Security.ServerContextHandle,
0,
&Context->SignMessage,
0 );
if (SecStatus != STATUS_SUCCESS) {
DbgPrint("TftpdHandleRead: Could not seal message=%d.\n", SecStatus);
TftpdErrorPacket(
(struct sockaddr *) &Request->ForeignAddress,
Request->Packet2,
Request->TftpdPort,
TFTPD_ERROR_UNDEFINED,
"Encryption error");
goto cleanup;
}
}
//
// Send the next BlockSize slice of (signature + sealed file), or
// whatever remains of it.
//
if ((Context->EncryptBytesSent + Request->BlockSize) <= (int)(Request->FileSize + NTLMSSP_MESSAGE_SIGNATURE_SIZE)) {
Context->BytesRead = Request->BlockSize;
} else {
Context->BytesRead = (Request->FileSize + NTLMSSP_MESSAGE_SIGNATURE_SIZE) - Context->EncryptBytesSent;
}
memcpy(
&Context->Packet[4],
Context->EncryptFileBuffer + Context->EncryptBytesSent,
Context->BytesRead);
Context->EncryptBytesSent += Context->BytesRead;
} else
#endif //defined(REMOTE_BOOT_SECURITY)
{
//
// read BlockSize bytes (or whatever's left). A short read marks the
// last block; the caller compares BytesRead with BlockSize.
//
Context->BytesRead = _read(
Context->fd,
&Context->Packet[4],
Request->BlockSize);
if( Context->BytesRead == -1 ){
DbgPrint("TftpdHandleRead: read failed=%d\n", errno );
// NOTE(review): errno is a CRT error number, not a Win32 error
// code, so GetLastError() callers will see a mismatched value.
SetLastError( errno );
goto cleanup;
}
if (Context->BytesRead != Request->BlockSize) {
DbgPrint("GetNextReadPacket read %d bytes\n",Context->BytesRead);
}
}
// 4-byte DATA header (opcode + block number) plus the payload.
Context->packetLength = 4 + Context->BytesRead;
}
return TRUE;
cleanup:
return FALSE;
}
DWORD
TftpdAddContextToList(PLIST_ENTRY pEntry)
/*++
Routine Description:
    Inserts a context's ContextLinkage entry at the head of
    Globals.WorkList while holding Globals.Lock.
Arguments:
    pEntry - the context's ContextLinkage list entry.
Return Value:
    Always TRUE.
--*/
{
    EnterCriticalSection(&Globals.Lock);
    InsertHeadList(&Globals.WorkList, pEntry);
    // %p: pointers are 64 bits on Win64, so %X would misprint them.
    DbgPrint("Adding %p to global list\n", pEntry);
    LeaveCriticalSection(&Globals.Lock);
    return TRUE;
}
PVOID
TftpdFindContextInList(SOCKET Sock)
/*++
Routine Description:
    Finds the context whose socket is Sock on Globals.WorkList.
    On success the context's RefCount has been incremented and its Lock is
    still held; the caller must call TftpdReleaseContextLock() when done.
    A context that is already Closing is treated as not found.
    This is a linear list walk; a hash table would scale better.
Arguments:
    Sock - socket descriptor to look for.
Return Value:
    The locked, referenced context, or NULL.
--*/
{
    PLIST_ENTRY Link;
    PTFTP_CONTEXT_HEADER Found = NULL;

    EnterCriticalSection(&Globals.Lock);
    for (Link = Globals.WorkList.Flink; Link != &Globals.WorkList; Link = Link->Flink) {
        PTFTP_CONTEXT_HEADER Candidate =
            CONTAINING_RECORD(Link, TFTP_CONTEXT_HEADER, ContextLinkage);

        if (Candidate->Sock != Sock) {
            continue;
        }
        EnterCriticalSection(&Candidate->Lock);
        if (Candidate->Closing) {
            LeaveCriticalSection(&Candidate->Lock);
        } else {
            Candidate->RefCount++;
            Found = Candidate;      // returned with Candidate->Lock held
        }
        break;
    }
    LeaveCriticalSection(&Globals.Lock);
    return Found;
}
void
TftpdReleaseContextLock(
    PTFTP_CONTEXT_HEADER Context
    )
/*++
Routine Description:
    Drops the reference and the context lock obtained from
    TftpdFindContextInList(). If that was the last reference to a context
    marked Closing, the context is then unlinked and freed through
    TftpdRemoveContextFromList(), after its lock has been released.
Arguments:
    Context - context returned by TftpdFindContextInList(); its Lock must
              be held by the caller.
Return Value:
    None.
--*/
{
    BOOL FreeIt;

    assert(Context->RefCount > 0);
    Context->RefCount--;

    FreeIt = (Context->Closing && (Context->RefCount == 0)) ? TRUE : FALSE;
    if (FreeIt) {
        Context->IdleCount = 0;
    }
    LeaveCriticalSection(&Context->Lock);

    if (FreeIt) {
        TftpdRemoveContextFromList(Context);
    }
}
VOID
TftpdRemoveContextFromList(PTFTP_CONTEXT_HEADER Context)
/*++
Routine Description:
    If Context is on Globals.WorkList, unlinks it under Globals.Lock and
    then frees it with TftpdFreeContext() after the lock is released.
    The context must already be Closing with no outstanding references
    (asserted). If Context is not on the list, nothing happens.
    This is a linear list walk; a hash table would scale better.
Arguments:
    Context - context to remove and free.
Return Value:
    None.
--*/
{
    PLIST_ENTRY pEntry;

    EnterCriticalSection(&Globals.Lock);
    for (pEntry = Globals.WorkList.Flink;
         pEntry != &Globals.WorkList;
         pEntry = pEntry->Flink) {
        if (CONTAINING_RECORD(pEntry, TFTP_CONTEXT_HEADER, ContextLinkage) == Context) {
            // Found it
            assert(Context->Closing);
            assert(Context->RefCount == 0);
            RemoveEntryList(pEntry);
            DbgPrint("Removing %p from global list\n", pEntry);
            LeaveCriticalSection(&Globals.Lock);
            DbgPrint("Removing connection to port %d\n", ntohs(Context->ForeignAddress.sin_port));
            TftpdFreeContext(Context);
            return;
        }
    }
    LeaveCriticalSection(&Globals.Lock);
}
VOID
TftpdFreeGeneralContextFields(PTFTP_CONTEXT_HEADER Context)
{
    //
    // Release what every context type carries in its header: the socket,
    // the retransmission timer, the packet buffer, the registered wait,
    // the socket event handle and the per-context critical section.
    // The context memory itself is freed by the type-specific routine.
    //
    SOCKET Sock = Context->Sock;

    if (Sock != INVALID_SOCKET) {
        closesocket(Sock);
        DbgPrint("TftpdFreeGeneralContextFields: Close Socket %d\n",Sock);
    }

    if (Context->TimerHandle) {
        RtlDeleteTimer(Globals.TimerQueueHandle, Context->TimerHandle, NULL);
    }
    Context->TimerHandle = 0;

    if (Context->Packet) {
        free(Context->Packet);
        Context->Packet = NULL;
    }

    //
    // NOTE(review): with a NULL completion event, RtlDeregisterWaitEx (like
    // UnregisterWaitEx) returns without waiting for a callback that is
    // already running, so such a callback could still touch Context after
    // the caller frees it. Not changed here because this path may itself be
    // reached from a callback, where waiting would deadlock - confirm.
    //
    RtlDeregisterWaitEx(Context->WaitEvent, NULL);
    CloseHandle(Context->SocketEvent);
    DeleteCriticalSection(&Context->Lock);
}
VOID
TftpdFreeReadContext(PTFTP_READ_CONTEXT Context)
{
    //
    // Type-specific teardown for a read context: the encryption buffer
    // (security builds only), the open file (fd == -1 means none), and
    // finally the context allocation itself.
    //
    int fd;

#if defined(REMOTE_BOOT_SECURITY)
    if (Context->EncryptFileBuffer != NULL) {
        free(Context->EncryptFileBuffer);
    }
#endif //defined(REMOTE_BOOT_SECURITY)
    fd = Context->fd;
    if (fd != -1) {
        _close(fd);
    }
    free(Context);
}
VOID
TftpdFreeWriteContext(PTFTP_WRITE_CONTEXT Context)
{
    //
    // Type-specific teardown for a write context: close the file being
    // written (fd == -1 means none), then free the context allocation.
    //
    int fd = Context->fd;

    if (fd != -1) {
        _close(fd);
    }
    free(Context);
}
//
// Type-specific teardown for a LOGIN_CONTEXT, called by TftpdFreeContext()
// after the shared header fields have been released. It currently does
// nothing.
// NOTE(review): unlike TftpdFreeReadContext()/TftpdFreeWriteContext() this
// does not free(Context). Confirm how login contexts are allocated; if
// they come from malloc, this leaks them.
//
VOID
TftpdFreeLoginContext(PTFTP_LOGIN_CONTEXT Context)
{
}
//
// Type-specific teardown for a KEY_CONTEXT, called by TftpdFreeContext()
// after the shared header fields have been released. It currently does
// nothing.
// NOTE(review): unlike TftpdFreeReadContext()/TftpdFreeWriteContext() this
// does not free(Context). Confirm how key contexts are allocated; if
// they come from malloc, this leaks them.
//
VOID
TftpdFreeKeyContext(PTFTP_KEY_CONTEXT Context)
{
}
VOID
TftpdFreeContext(PTFTP_CONTEXT_HEADER Context)
{
    //
    // Release a context: first the header fields every context shares,
    // then the type-specific part selected by ContextType. Unknown types
    // get no further teardown.
    //
    if (Context == NULL) {
        DbgPrint("TftpdFreeContext: Called with Null context");
        return;
    }

    TftpdFreeGeneralContextFields(Context);

    if (Context->ContextType == READ_CONTEXT) {
        TftpdFreeReadContext((PTFTP_READ_CONTEXT)Context);
    } else if (Context->ContextType == WRITE_CONTEXT) {
        TftpdFreeWriteContext((PTFTP_WRITE_CONTEXT)Context);
    } else if (Context->ContextType == LOGIN_CONTEXT) {
        TftpdFreeLoginContext((PTFTP_LOGIN_CONTEXT)Context);
    } else if (Context->ContextType == KEY_CONTEXT) {
        TftpdFreeKeyContext((PTFTP_KEY_CONTEXT)Context);
    }
}
VOID
TftpdReaper(PVOID ReaperContext,
            BOOLEAN Flag)
/*++
Routine Description:
    Periodic sweep of Globals.WorkList. Every context's IdleCount is
    incremented. A context that has reached DEAD_CONTEXT_COUNT, is not
    already Closing and has no outstanding references is marked Closing,
    unlinked and freed. TftpdCleanHeap() is called at the end.

    Dead contexts are unlinked onto a private list while Globals.Lock is
    held, and freed only after the whole walk. The previous version
    dropped Globals.Lock in the middle of the walk and then kept using
    the saved next pointer. Another thread could free that entry in the
    meantime (e.g. TftpdReleaseContextLock -> TftpdRemoveContextFromList).
Arguments:
    ReaperContext, Flag - timer callback parameters; unused.
Return Value:
    None.
--*/
{
    PLIST_ENTRY pEntry;
    PLIST_ENTRY pNextEntry;
    PTFTP_CONTEXT_HEADER Context;
    LIST_ENTRY DeadList;

    InitializeListHead(&DeadList);

    EnterCriticalSection(&Globals.Lock);
    pEntry = Globals.WorkList.Flink;
    while (pEntry != &Globals.WorkList) {
        Context = CONTAINING_RECORD(pEntry, TFTP_CONTEXT_HEADER, ContextLinkage);
        pNextEntry = pEntry->Flink;
        EnterCriticalSection(&Context->Lock);
        Context->IdleCount++;
        if ((Context->IdleCount >= DEAD_CONTEXT_COUNT) &&
            !Context->Closing &&
            (Context->RefCount == 0)) {
            //
            // Context is dead. With RefCount == 0 and Globals.Lock held,
            // nobody else holds it, and once unlinked it can no longer be
            // found through TftpdFindContextInList().
            //
            Context->Closing = TRUE;
            DbgPrint("Reaping connection to port %d\n", ntohs(Context->ForeignAddress.sin_port));
            LeaveCriticalSection(&Context->Lock);
            RemoveEntryList(pEntry);
            InsertTailList(&DeadList, pEntry);
        } else {
            LeaveCriticalSection(&Context->Lock);
        }
        pEntry = pNextEntry;
    }
    LeaveCriticalSection(&Globals.Lock);

    //
    // Free the reaped contexts outside Globals.Lock, as
    // TftpdRemoveContextFromList() does.
    //
    while (!IsListEmpty(&DeadList)) {
        pEntry = RemoveHeadList(&DeadList);
        Context = CONTAINING_RECORD(pEntry, TFTP_CONTEXT_HEADER, ContextLinkage);
        DbgPrint("Removing connection to port %d\n", ntohs(Context->ForeignAddress.sin_port));
        TftpdFreeContext(Context);
    }

    TftpdCleanHeap();
}
VOID
TftpdRetransmit(PVOID RetransContext,
                BOOLEAN Flag)
/*++
Routine Description:
    Retransmission timer callback for a read/write transfer.
    RetransContext is the transfer's socket, used to look up the context
    (taking a reference and its lock).
    - While fewer than MAX_TFTPD_RETRIES resends have been made, the last
      packet is sent again. Unless the context uses a fixed timer, the
      timeout is doubled, capped at TFTPD_MAX_TIMEOUT seconds.
    - After that, a "Timeout" error is sent and the context is marked
      Closing.
Arguments:
    RetransContext - the transfer socket, cast to PVOID.
    Flag           - timer callback parameter; unused.
Return Value:
    None.
--*/
{
    PTFTP_READ_WRITE_CONTEXT_HEADER Context;
    int SendStatus;
    NTSTATUS ntStatus;

    Context = (PTFTP_READ_WRITE_CONTEXT_HEADER)TftpdFindContextInList((SOCKET)RetransContext);
    if (Context == NULL) {
        DbgPrint("TftpdRetransmit: Unable to find context\n");
        return;
    }

    if (Context->RetransmissionCount < MAX_TFTPD_RETRIES) {
        if (Context->RetransmissionCount > 5) {
            SYSTEMTIME _st;
            GetLocalTime(&_st);
            DbgPrint("%2d-%02d: %02d:%02d:%02d TftpdRetransmit: Socket %d DstPort %d Count %d BlkNum %d\n",
                     _st.wMonth,_st.wDay,_st.wHour,_st.wMinute,_st.wSecond,
                     (DWORD)((DWORD_PTR)RetransContext),
                     ntohs(Context->ForeignAddress.sin_port),
                     Context->RetransmissionCount,
                     ntohs(((unsigned short*)(Context->Packet))[1]));
        }
        SendStatus = sendto(
                         Context->Sock,
                         Context->Packet,
                         Context->packetLength,
                         0,
                         (struct sockaddr *) &Context->ForeignAddress,
                         sizeof(struct sockaddr_in));
        if( SOCKET_ERROR == SendStatus ){
            DbgPrint("TftpdRetransmit: sendto failed=%d\n",
                     WSAGetLastError() );
        }
        Context->RetransmissionCount++;
        Context->IdleCount = 0; // don't accidently reap this connection during retransmit tries.
        if (Context->TimerHandle) {
            if (!Context->FixedTimer) {
                // Exponential back-off, capped (DueTime is in milliseconds).
                Context->DueTime *= 2;
                if (Context->DueTime > (TFTPD_MAX_TIMEOUT * 1000)) {
                    Context->DueTime = (TFTPD_MAX_TIMEOUT * 1000);
                }
            }
            ntStatus = RtlUpdateTimer(Globals.TimerQueueHandle,
                                      Context->TimerHandle,
                                      Context->DueTime,
                                      Context->DueTime);
            if (!NT_SUCCESS(ntStatus)) {
                // The failure is reported in the NTSTATUS, not GetLastError().
                DbgPrint("TftpdRetransmit: UpdateTimerFailed 0x%x\n", ntStatus);
            }
        }
    } else {
        //Send timeout
        TftpdErrorPacket((struct sockaddr *) &Context->ForeignAddress,
                         NULL,
                         Context->Sock,
                         TFTPD_ERROR_UNDEFINED,
                         "Timeout"
                         );
        Context->Closing = TRUE;
    }
    TftpdReleaseContextLock((PTFTP_CONTEXT_HEADER)Context);
}
DWORD
TftpdResumeRead(
    PTFTP_READ_CONTEXT Context,
    PTFTP_REQUEST Request
    )
/*++
Routine Description:
    Continues an existing read (RRQ) transfer after a datagram arrives on
    the transfer socket. The context lock is held when this is called.
    - If the datagram ACKs the current block, the next block is built with
      TftpdGetNextReadPacket() and sent, and the retransmission timer is
      re-armed.
    - Once a short block (fewer than BlockSize bytes) has been built, the
      transfer is marked Done; the ACK for that block marks the context
      Closing.
    - Any other datagram, including a duplicate ACK for the previous
      block, is ignored and nothing is sent.
Arguments:
    Context - read context for the transfer.
    Request - the received datagram (Packet1) and the sender's address.
Return Value:
    0 - datagram processed (including ignored datagrams and a failure to
        build the next block).
    1 - sendto failed.
--*/
{
    BOOL Acked = FALSE;
    BOOL Status = FALSE;
    int SendStatus = 0;
    NTSTATUS Stat;

    //
    // Parse the request
    //
    DbgPrint("TftpdResumeRead BlockNum %d\n",Context->BlockNumber);
    Request->BlockSize = Context->BlockSize;
    if (CHECK_ACK(Request->Packet1, TFTPD_ACK, Context->BlockNumber)) {
        Acked = TRUE;
        Context->RetransmissionCount = 0;
    } else {
        DbgPrint("Ack failed: Expect Blk %d Received Blk %d OpCode %d",
                 Context->BlockNumber,
                 ntohs((((unsigned short *)Request->Packet1)[1])),
                 ntohs(*((unsigned short *) (Request->Packet1))));
        //
        // A duplicate ACK for the previous block is not answered; lost
        // DATA packets are presumably covered by the retransmission timer
        // (TftpdRetransmit). Answering duplicates would risk the
        // "Sorcerer's Apprentice" problem (RFC 1123, 4.2.3.1).
        //
    }

    if (Acked) {
        if (Context->Done) {
            Context->Closing = TRUE;
            return 0;
        }
        if (++Context->BlockNumber == 0) {
            Context->BlockNumber = 1; // 32 MB file roll-over.
        }
        Status = TftpdGetNextReadPacket(Context, Request);
        if (!Status) {
            DbgPrint("GetNextPacketFailed %d",ntohs(Request->ForeignAddress.sin_port));
            return 0;
        }
        Context->RetransmissionCount = 0;
        Context->IdleCount = 0;
        if (!Context->FixedTimer) {
            // received new packet, reset timer
            Context->DueTime = TFTPD_INITIAL_TIMEOUT*1000;
        }
        if (Context->TimerHandle) {
            Stat = RtlUpdateTimer(Globals.TimerQueueHandle,
                                  Context->TimerHandle,
                                  Context->DueTime,
                                  Context->DueTime);
            if (!NT_SUCCESS(Stat)) {
                DbgPrint("Failed to Update Timer");
            }
        }
        //
        // If we've sent the whole file, exit the loop. Note that we
        // don't send an error packet if there is a timeout on the last
        // data packet, because the receiver might have only sent the
        // ACK once, then forgotten about this transfer.
        //
        if (Context->BytesRead < Request->BlockSize) {
            DbgPrint("We're done with %d\n",ntohs(Request->ForeignAddress.sin_port));
            Context->Done = TRUE;
        }
    }

    if (Status) {
        // Got a valid packet to send
        DbgPrint("TftpdResumeRead: Sending data BlkNumber %d Socket %d PeerPort %d Size %d\n",Context->BlockNumber, Context->Sock, ntohs(Request->ForeignAddress.sin_port), Context->packetLength);
        SendStatus = sendto(
                         Context->Sock,
                         Context->Packet,
                         Context->packetLength,
                         0,
                         (struct sockaddr *) &Request->ForeignAddress,
                         sizeof(struct sockaddr_in));
        if( SOCKET_ERROR == SendStatus ){
            DbgPrint("TftpdResumeRead: sendto failed=%d\n",
                     WSAGetLastError() );
            return 1;
        }
    }
    return 0;
}
DWORD
TftpdResumeWrite(
    PTFTP_WRITE_CONTEXT Context,
    PTFTP_REQUEST Request
    )
/*++
Routine Description:
    Resumes processing of an existing write (WRQ) transfer when a datagram
    arrives on the transfer's private socket.  The context lock is held by
    the caller on entry and is not released here.

    - DATA for the next block (BlockNumber + 1) whose payload fits in the
      negotiated block size is written with TftpdDoWrite and ACKed, and the
      retransmission timer is reset.  A payload shorter than the block size
      ends the transfer and flags the context Closing.
    - DATA repeating the current block gets its ACK resent.
    - Anything else (runt, oversized or unexpected packet) is ignored.
    - A write failure sends a DISK_FULL error and flags the context Closing.

Arguments:
    Context - write context for this transfer.
    Request - the incoming datagram (Packet1/DataSize) and peer address.

Return Value:
    0 == datagram processed (block written and ACKed, or ACK resent)
    1 == datagram rejected or write failed
--*/
{
    int DataLength;
    char State;
    int BytesWritten;
    int Status;
    NTSTATUS Stat;

    DbgPrint("Request Blocksize %d Context Blocksize %d\n",Request->BlockSize,Context->BlockSize);
    Request->BlockSize=Context->BlockSize;
    DbgPrint("TftpdResumeWrite: PktNum %d DataSize %d\n",Context->BlockNumber+1,Request->DataSize-4);

    //
    // A DATA packet carries a 4-byte header (opcode + block number).  Without
    // this check CHECK_ACK could match stale header bytes and the payload
    // length passed to TftpdDoWrite would be negative.
    //
    if (Request->DataSize < 4) {
        DbgPrint("TftpdResumeWrite: runt datagram, %d bytes\n", Request->DataSize);
        return 1;
    }
    DataLength = (int)(Request->DataSize - 4);

    if (!CHECK_ACK(Request->Packet1, TFTPD_DATA, Context->BlockNumber+1)) {
        if (CHECK_ACK(Request->Packet1, TFTPD_DATA, Context->BlockNumber)) {
            // Peer resent the current block: our ACK was probably lost, resend it.
            ((unsigned short *) Context->Packet)[0] = htons(TFTPD_ACK);
            ((unsigned short *) Context->Packet)[1] = htons(Context->BlockNumber);
            Status =
                sendto(
                    Context->Sock,
                    Context->Packet,
                    4,
                    0,
                    (struct sockaddr *) &Request->ForeignAddress,
                    sizeof(struct sockaddr_in));
            DbgPrint("TftpdResumeWrite: Resending Ack %d\n",Context->BlockNumber);
            if( SOCKET_ERROR == Status ){
                DbgPrint("TftpdResumeWrite: sendto failed=%d\n",
                         WSAGetLastError() );
            }
            return 0;
        }
        DbgPrint("TftpdResumeWrite: unexpected packet, expecting block %d\n",
                 Context->BlockNumber+1);
        return 1;
    }

    //
    // Reject a block larger than negotiated *before* touching the file or
    // advancing BlockNumber.  Previously such data was written but never ACKed.
    //
    if (DataLength > Request->BlockSize) {
        DbgPrint("TftpdResumeWrite: block %d too large (%d bytes)\n",
                 Context->BlockNumber+1, DataLength);
        return 1;
    }

    Context->BlockNumber++;
    Context->IdleCount=0;

    //
    // NOTE(review): State is reset for every block.  If TftpdDoWrite uses it
    // to carry netascii CR/LF state between calls, that state is lost at
    // block boundaries.  Confirm against TftpdDoWrite.
    //
    State = '\0';
    BytesWritten =
        TftpdDoWrite(Context->fd, &Request->Packet1[4], DataLength, Context->FileMode, &State);
    if (BytesWritten < 0) {
        DbgPrint("TftpdHandleWrite: disk full?\n");
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_DISK_FULL,
            NULL);
        // The transfer cannot continue.  Tear it down now instead of letting
        // the retransmit timer keep ACKing a peer that was told to stop.
        Context->Closing = TRUE;
        return 1;
    }

    //
    // Ack the block just written.
    //
    ((unsigned short *) Context->Packet)[0] = htons(TFTPD_ACK);
    ((unsigned short *) Context->Packet)[1] = htons(Context->BlockNumber);
    Status =
        sendto(
            Context->Sock,
            Context->Packet,
            4,
            0,
            (struct sockaddr *) &Request->ForeignAddress,
            sizeof(struct sockaddr_in));
    DbgPrint("TftpdResumeWrite: Sending Ack %d\n",Context->BlockNumber);

    Context->RetransmissionCount=0;
    if (Context->TimerHandle) {
        if (!Context->FixedTimer) {
            // received new packet, reset timer
            Context->DueTime=TFTPD_INITIAL_TIMEOUT*1000;
        }
        Stat=RtlUpdateTimer(Globals.TimerQueueHandle,
                            Context->TimerHandle,
                            Context->DueTime,
                            Context->DueTime);
        if (!NT_SUCCESS(Stat)) {
            DbgPrint("Failed to Update Timer");
        }
    }
    if( SOCKET_ERROR == Status ){
        DbgPrint("TftpdResumeWrite: sendto failed=%d\n",
                 WSAGetLastError() );
    }
    if (DataLength < Request->BlockSize) {
        // Short block: the transfer is complete.  Flag for speedy cleanup.
        Context->Closing = TRUE;
    }
    return 0;
}
DWORD
TftpdResumeLogin(
    PTFTP_LOGIN_CONTEXT Context,
    PTFTP_REQUEST Request
    )
/*++
Routine Description:
    Placeholder for resuming an existing login exchange.  It currently does
    nothing and ignores the datagram.  The context lock is held by the caller
    on entry; this routine does not release it (the caller does).
Arguments:
    Context - login context for this exchange (unused).
    Request - the incoming datagram (unused).
Return Value:
    Always 0.
--*/
{
    (void)Context;
    (void)Request;
    return 0;
}
DWORD
TftpdResumeKey(
    PTFTP_KEY_CONTEXT Context,
    PTFTP_REQUEST Request
    )
/*++
Routine Description:
    Placeholder for resuming an existing key exchange.  It currently does
    nothing and ignores the datagram.  The context lock is held by the caller
    on entry; this routine does not release it (the caller does).
Arguments:
    Context - key context for this exchange (unused).
    Request - the incoming datagram (unused).
Return Value:
    Always 0.
--*/
{
    (void)Context;
    (void)Request;
    return 0;
}
VOID
TftpdResumeProcessing(PVOID Argument)
/*++
Routine Description:
    Dispatches a datagram that arrived on a transfer's private socket to the
    handler for that transfer's context type.

    The context is looked up by Request->TftpdPort.  The lookup returns with
    the context lock held; it is released here on every path.  Datagrams
    whose source address/port differ from the peer recorded in the context
    are dropped.

Arguments:
    Argument - PTFTP_REQUEST holding the incoming datagram.
Return Value:
    None.
--*/
{
    PTFTP_REQUEST Request=(PTFTP_REQUEST)Argument;
    PTFTP_CONTEXT_HEADER Context;

    Context=(PTFTP_CONTEXT_HEADER)TftpdFindContextInList(Request->TftpdPort);
    if (Context == NULL) {
        DbgPrint("Invalid request on port %d", Request->TftpdPort);
        return;
    }

    // Only the peer that started the transfer may drive it.
    if ((Context->ForeignAddress.sin_family != Request->ForeignAddress.sin_family) ||
        (Context->ForeignAddress.sin_addr.s_addr != Request->ForeignAddress.sin_addr.s_addr) ||
        (Context->ForeignAddress.sin_port != Request->ForeignAddress.sin_port)) {
        TftpdReleaseContextLock(Context);
        DbgPrint("Request from unexpected peer on port %d", Request->TftpdPort);
        return;
    }

    switch (Context->ContextType) {
    case READ_CONTEXT:
        TftpdResumeRead((PTFTP_READ_CONTEXT)Context, Request);
        break;
    case WRITE_CONTEXT:
        TftpdResumeWrite((PTFTP_WRITE_CONTEXT)Context, Request);
        break;
    case LOGIN_CONTEXT:
        TftpdResumeLogin((PTFTP_LOGIN_CONTEXT)Context, Request);
        break;
    case KEY_CONTEXT:
        TftpdResumeKey((PTFTP_KEY_CONTEXT)Context, Request);
        break;
    default:
        DbgPrint("TftpdResumeProcessing: unknown context type %d on port %d\n",
                 Context->ContextType, Request->TftpdPort);
        break;
    }
    TftpdReleaseContextLock(Context);
}
/*
    IsFileNameValid - reports whether a NUL terminator occurs within the
    first MaxLen bytes of FileName.  A TRUE result means FileName can be
    handed to strlen/strcpy without reading beyond those MaxLen bytes.
    Returns FALSE when MaxLen is 0 or no NUL is found in range.
*/
BOOL IsFileNameValid(char* FileName, DWORD MaxLen)
{
    DWORD Remaining = MaxLen;
    const char *Scan = FileName;

    while (Remaining-- != 0) {
        if (*Scan++ == '\0') {
            return TRUE;
        }
    }
    return FALSE;
}
// ========================================================================
DWORD
TftpdHandleRead(
    PVOID Argument
    )
/*++
Routine Description:
    This handles an incoming read file request (RRQ) received on the
    well-known port.  It performs these steps:
      - validates the file name and transfer mode;
      - canonicalizes the name and checks it against ValidClients /
        ValidMasters / ValidReadFiles;
      - opens the file under StartDirectory;
      - binds a new socket for the transfer;
      - registers a TFTP_READ_CONTEXT in the context list;
      - arms the retransmission timer;
      - sends the first packet prepared by TftpdGetNextReadPacket.
    Later datagrams on the new socket are handled by TftpdResumeRead.

    Until the context has been added to the context list this routine owns
    everything it acquired and releases it on failure.  After that, those
    resources stay attached to the context (presumably torn down by whoever
    reaps contexts; confirm).

Arguments:
    Argument - PTFTP_REQUEST holding the RRQ datagram in Packet1.

Return Value:
    Always 0.  Failures are reported to the client with an error packet
    where possible.
--*/
{
    int                 FileMode;
    char *              CharPtr;
    char *              FileName;
    char *              ReadMode;
    char *              NewPacket;
    struct sockaddr_in  ReadAddress;
    SOCKET              ReadPort = INVALID_SOCKET;
    PTFTP_REQUEST       Request;
    int                 Status, err;
    char *              client_ipaddr;
    short               client_port;
    BOOL                LockHeld = FALSE;
    BOOL                AddedContext = FALSE;
    BOOL                LockInitialized = FALSE;
    int                 length;
#if defined(REMOTE_BOOT_SECURITY)
    SECURITY_STATUS     SecStatus;
#endif // defined(REMOTE_BOOT_SECURITY)
    PTFTP_READ_CONTEXT  Context = NULL;
    NTSTATUS            ntStatus;

    //
    // Parse the request: opcode (2 bytes), file name, NUL, mode, NUL, options.
    //
    DbgPrint("Entered Handle read\n");
    Request = (PTFTP_REQUEST) Argument;
    FileName = &Request->Packet1[2];
    if (!IsFileNameValid(FileName,MAX_TFTP_DATAGRAM-2)) {
        DbgPrint("TftpdHandleRead: invalid FileName\n");
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_ILLEGAL_OPERATION,
            NULL);
        goto cleanup;
    }
    length = strlen(FileName);
    ReadMode = FileName + length + 1;
    // ReadMode begins 2 + length + 1 bytes into the datagram; bound the scan
    // by what remains so it stays within the first MAX_TFTP_DATAGRAM bytes.
    if (!IsFileNameValid(ReadMode, MAX_TFTP_DATAGRAM - 2 - (length + 1))) {
        DbgPrint("TftpdHandleRead: invalid ReadMode\n");
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_ILLEGAL_OPERATION,
            NULL);
        goto cleanup;
    }

    // Set up context.
    Context=(PTFTP_READ_CONTEXT)malloc(sizeof(TFTP_READ_CONTEXT));
    if (Context == NULL) {
        goto cleanup;
    }
    memset(Context,0,sizeof(TFTP_READ_CONTEXT));
    // memset leaves fd == 0, which is a valid descriptor; mark it "not open"
    // so the cleanup path can tell whether _open succeeded.
    Context->fd = -1;
    Context->Packet = (char *)malloc(MAX_OACK_PACKET_LENGTH);
    if (Context->Packet == NULL) {
        goto cleanup;
    }

    //
    // Profile data.
    //
    client_ipaddr = inet_ntoa( Request->ForeignAddress.sin_addr );
    if (client_ipaddr == NULL)
        client_ipaddr = "";
    client_port = ntohs( Request->ForeignAddress.sin_port );

    DbgPrint("TftpdHandleRead: FileName=%s, ReadMode=%s, from=%s:%d.\n",
             FileName, ReadMode,
             client_ipaddr,
             client_port );

    //
    // Convert the mode to all lower case for comparison.  The cast avoids
    // undefined behaviour in tolower for bytes >= 0x80.  CharPtr is left
    // on the mode's NUL; the options start at CharPtr + 1.
    //
    for (CharPtr = ReadMode; *CharPtr; CharPtr++) {
        *CharPtr = (char)tolower((unsigned char)*CharPtr);
    }

    // NOTE(review): FileMode is computed but never stored in the context or
    // passed on.  Unless TftpdGetNextReadPacket derives the mode some other
    // way, netascii reads are sent unconverted.  Confirm.
    if (strcmp(ReadMode, "netascii") == 0) {
        FileMode = O_TEXT;
        DbgPrint("TftpdHandleRead: netascii mode.\n");
    } else if (strcmp(ReadMode, "octet") == 0) {
        FileMode = O_BINARY;
        DbgPrint("TftpdHandleRead: binary mode.\n");
    } else {
        DbgPrint("TftpdHandleRead: invalid ReadMode=%s?\n", ReadMode );
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_ILLEGAL_OPERATION,
            NULL);
        goto cleanup;
    }

#if defined(REMOTE_BOOT_SECURITY)
    err = TftpdProcessOptionsPhase1( Request, CharPtr + 1, TFTPD_RRQ );
    if ( err != 0 ) {
        goto cleanup;
    }

    if (Request->SecurityHandle) {
        //
        // This returns TRUE (and the security entry) if the sign
        // for this file is valid.
        //
        SecStatus = TftpdVerifyFileSignature(
                        (USHORT)(Request->SecurityHandle >> 16),     // index
                        (USHORT)(Request->SecurityHandle & 0xffff),  // validation
                        &Context->Security,
                        FileName,
                        Request->Sign,
                        client_port);
        //
        // This error code is known to mean an invalid security handle.
        //
        if ( SecStatus == (SECURITY_STATUS)STATUS_INVALID_HANDLE ) {
            DbgPrint("TftpdHandleRead: SecurityHandle %x is invalid.\n",
                     Request->SecurityHandle);
            TftpdErrorPacket(
                (struct sockaddr *) &Request->ForeignAddress,
                Request->Packet2,
                Request->TftpdPort,
                TFTPD_ERROR_UNDEFINED,
                "Invalid security handle");
            goto cleanup;
        } else if ( SecStatus != SEC_E_OK ) {
            DbgPrint("TftpdHandleRead: sign is invalid.\n");
            TftpdErrorPacket(
                (struct sockaddr *) &Request->ForeignAddress,
                Request->Packet2,
                Request->TftpdPort,
                TFTPD_ERROR_UNDEFINED,
                "Invalid sign");
            goto cleanup;
        }
    }
#endif // defined(REMOTE_BOOT_SECURITY)

    //
    // Canonicalize the file name.
    //
    DbgPrint("TftpdHandleRead: Canonicalizing name.\n");
    strcpy( Request->Packet3, FileName );
    if ( !TftpdCanonicalizeFileName(Request->Packet3) ) {
        DbgPrint("TftpdHandleRead: invalid FileName=%s\n", FileName );
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_UNDEFINED,
            "Malformed file name");
        goto cleanup;
    }

    //
    // Check whether this access is permitted.
    //
    if( !( match( ValidClients, client_ipaddr )
           || match( ValidMasters, client_ipaddr ) )
        || !match( ValidReadFiles, Request->Packet3 )
    ){
        DbgPrint("TftpdHandleRead: cannot open file=%s, errno=%d.\n"
                 "  client %s:%d,\n"
                 "  ValidReadFiles=%s, ValidClients=%s, ValidMasters=%s,\n"
                 ,
                 Request->Packet3, errno,
                 client_ipaddr, client_port,
                 ValidReadFiles, ValidClients, ValidMasters
                 );
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_ACCESS_VIOLATION,
            NULL);
        goto cleanup;
    }

    //
    // Prepend the start directory to the file name.
    //
    DbgPrint("TftpdHandleRead: Prepending directory name.\n");
    if ( !TftpdPrependStringToFileName(
              Request->Packet3,
              sizeof(Request->Packet3),
              StartDirectory) ) {
        DbgPrint("TftpdHandleRead: too long FileName=%s\n", FileName );
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_UNDEFINED,
            "File name too long");
        goto cleanup;
    }

    //
    // Open the file and record its size.
    //
    DbgPrint("TftpdHandleRead: opening file <%s>\n", Request->Packet3 );
    Context->fd = _open(Request->Packet3, O_RDONLY | O_BINARY);
    if (Context->fd == -1) {
        SetLastError( errno );
        DbgPrint("TftpdHandleRead: cannot open file %s, errno=%d.\n", Request->Packet3, errno );
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_FILE_NOT_FOUND,
            NULL);
        goto cleanup;
    }

    err = _lseek(Context->fd, 0, SEEK_END);
    if ( err != -1 ) {
        Request->FileSize = err;
        err = _lseek(Context->fd, 0, SEEK_SET);
    }
    if( err == -1 ){
        DbgPrint("TftpdHandleRead: lseek failed, errno=%d\n",
                 errno );
        SetLastError( errno );
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_UNDEFINED,
            "Insufficient resources");
        goto cleanup;
    }

    //
    // Open a new socket for this request
    //
    ReadPort =
        socket(
            AF_INET,
            SOCK_DGRAM,
            0);
    DbgPrint("TftpdHandleRead: New Socket %d\n",ReadPort);
    if (ReadPort == INVALID_SOCKET) {
        DbgPrint("TftpdHandleRead: cannot open socket, Error=%d\n",
                 WSAGetLastError() );
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_UNDEFINED,
            "Insufficient resources");
        goto cleanup;
    }

    //
    // Bind to a random port on the address the request arrived on.
    //
    ReadAddress.sin_family = AF_INET;
    ReadAddress.sin_port = 0;
    ReadAddress.sin_addr.s_addr = Request->MyAddr;
    Status =
        bind(
            ReadPort,
            (struct sockaddr *) &ReadAddress,
            sizeof(ReadAddress)
            );
    if (Status) {
        DbgPrint("TftpdHandleRead: cannot bind socket, error=%d.\n",
                 WSAGetLastError() );
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_UNDEFINED,
            "Insufficient resources");
        goto cleanup;
    }
    Request->TftpdPort = ReadPort;

    // Enter Context into list now that we know the port
    InitializeCriticalSection(&Context->Lock);
    LockInitialized = TRUE;
    Context->Sock=ReadPort;
    memcpy(&Context->ForeignAddress,&Request->ForeignAddress,sizeof(struct sockaddr_in));
    Context->SocketEvent=CreateEvent(NULL,FALSE,FALSE,NULL);
    if (Context->SocketEvent == NULL) {
        DbgPrint("Failed to create socket event %d",GetLastError());
        goto cleanup;
    }
    Context->WaitEvent=RegisterSocket(ReadPort,Context->SocketEvent,REG_CONTINUE_SOCKET);
    if (Context->WaitEvent == NULL) {
        DbgPrint("Failed to create socket event %d",GetLastError());
        goto cleanup;
    }

    // Insert Context.  From here on the context list holds the resources.
    TftpdAddContextToList(&Context->ContextLinkage);
    AddedContext=TRUE;
    Context=(PTFTP_READ_CONTEXT)TftpdFindContextInList(ReadPort);
    if (Context == NULL) {
        DbgPrint("Failed to Lookup ReadContext");
        goto cleanup;
    }
    LockHeld=TRUE;

    err = TftpdProcessOptionsPhase2( Request, CharPtr + 1, TFTPD_RRQ, &Context->oackLength,Context->Packet,
                                     &Context->FixedTimer);
    if ( err != 0 ) {
        goto cleanup;
    }

    // Start retransmission timer
    if (Context->FixedTimer) {
        Context->DueTime=Request->Timeout*1000;
    } else {
        Context->DueTime=TFTPD_INITIAL_TIMEOUT*1000;
    }
    DbgPrint("TftpdHandleRead: Timer Interval %d msecs",Context->DueTime);
    ntStatus=RtlCreateTimer(Globals.TimerQueueHandle,
                            &Context->TimerHandle,
                            TftpdRetransmit,
                            (PVOID)Context->Sock,
                            Context->DueTime,
                            Context->DueTime,
                            0);
    if (!NT_SUCCESS(ntStatus)) {
        DbgPrint("Failed to Arm Timer %d",ntStatus);
    }

    Context->ContextType=READ_CONTEXT;
    Context->BlockSize=Request->BlockSize;
    if (Context->BlockSize > MAX_OACK_PACKET_LENGTH - 4) {
        DbgPrint("TftpdHandleRead: Reallocating packet.\n");
        // On failure the original buffer is kept in Context->Packet.
        NewPacket = (char *)realloc(Context->Packet, Context->BlockSize + 4);
        if (NewPacket == NULL) {
            goto cleanup;
        }
        Context->Packet = NewPacket;
    }

#if defined(REMOTE_BOOT_SECURITY)
    if (Request->SecurityHandle) {
        //
        // For secure mode, we read the whole file in at once so
        // we can encrypt it. For large files like ntoskrnl.exe,
        // will this work? If we get errors here, we could
        // just change the oack to say "security 0" and then send
        // down the unencrypted file -- maybe we should also do this
        // for files beyond a certain size.
        //
        Context->EncryptFileBuffer = malloc(Request->FileSize + NTLMSSP_MESSAGE_SIGNATURE_SIZE);
        if (Context->EncryptFileBuffer == NULL) {
            DbgPrint("TftpdHandleRead: Could not allocate EncryptFileBuffer length %d.\n",
                     Request->FileSize + NTLMSSP_MESSAGE_SIGNATURE_SIZE);
            TftpdErrorPacket(
                (struct sockaddr *) &Request->ForeignAddress,
                Request->Packet2,
                Request->TftpdPort,
                TFTPD_ERROR_UNDEFINED,
                "Insufficient resources");
            goto cleanup;
        }
        //
        // We don't actually read/seal until later -- this is so we can
        // send the OACK out right now to prevent more threads from
        // being spawned if he resends the initial request.
        //
        Context->EncryptBytesSent = 0;
    }
#endif //defined(REMOTE_BOOT_SECURITY)

    //
    // Ready to read and send file in blocks.
    //
    Context->BlockNumber = 1;
    Status=TftpdGetNextReadPacket(Context,Request);
    Context->RetransmissionCount=0;
    if (Status) {
        // Got a valid packet to send
        Status = sendto(
                    ReadPort,
                    Context->Packet,
                    Context->packetLength,
                    0,
                    (struct sockaddr *) &Request->ForeignAddress,
                    sizeof(struct sockaddr_in));
        if (Context->BytesRead < Request->BlockSize) {
            Context->Done=TRUE;
        }
    } else {
        // send error packet (expected to have been built in Packet2)
        Status = sendto(
                    ReadPort,
                    Request->Packet2,
                    Context->packetLength,
                    0,
                    (struct sockaddr *) &Request->ForeignAddress,
                    sizeof(struct sockaddr_in));
    }
    if( SOCKET_ERROR == Status ){
        DbgPrint("TftpdHandleRead: sendto failed=%d\n",
                 WSAGetLastError() );
        goto cleanup;
    }

cleanup:
    if (Context != NULL) {
        if (LockHeld) {
            TftpdReleaseContextLock((PTFTP_CONTEXT_HEADER)Context);
        }
        if (!AddedContext) {
            //
            // The context never reached the context list and no timer or
            // socket registration refers to it, so release everything that
            // was acquired above.
            //
            if (LockInitialized) {
                DeleteCriticalSection(&Context->Lock);
            }
            if (Context->SocketEvent != NULL) {
                CloseHandle(Context->SocketEvent);
            }
            if (Context->fd != -1) {
                _close(Context->fd);
            }
            free(Context->Packet);
            free(Context);
        }
    }
    if (!AddedContext && ReadPort != INVALID_SOCKET) {
        closesocket(ReadPort);
    }
    return 0;
}
// ========================================================================
/*++
Routine Description:
    This handles an incoming write file request (WRQ) received on the
    well-known port.  It performs these steps:
      - validates the file name and transfer mode;
      - canonicalizes the name and checks it against ValidMasters /
        ValidWriteFiles;
      - creates (truncating) the file under StartDirectory;
      - binds a new socket for the transfer;
      - registers a TFTP_WRITE_CONTEXT in the context list;
      - arms the retransmission timer;
      - sends the OACK built by TftpdProcessOptionsPhase2, or ACK 0.
    Later DATA packets on the new socket are handled by TftpdResumeWrite.

    Until the context has been added to the context list this routine owns
    everything it acquired and releases it on failure.  After that, those
    resources stay attached to the context (presumably torn down by whoever
    reaps contexts; confirm).

Arguments:
    Argument - PTFTP_REQUEST holding the WRQ datagram in Packet1.

Return Value:
    Always 0.  Failures are reported to the client with an error packet
    where possible.
--*/
DWORD
TftpdHandleWrite(
    PVOID Argument
    )
{
    char *              CharPtr;
    char *              FileName;
    char *              NewPacket;
    struct sockaddr_in  ReadAddress;
    SOCKET              ReadPort = INVALID_SOCKET;
    int                 Status, err;
    char *              WriteMode;
    PTFTP_REQUEST       Request;
    char *              client_ipaddr;
    short               client_port;
    BOOL                LockHeld = FALSE;
    BOOL                AddedContext = FALSE;
    BOOL                LockInitialized = FALSE;
    int                 length;
    PTFTP_WRITE_CONTEXT Context = NULL;
    NTSTATUS            ntStatus;

    // Set up context.
    Context=(PTFTP_WRITE_CONTEXT)malloc(sizeof(TFTP_WRITE_CONTEXT));
    if (Context == NULL) {
        goto cleanup;
    }
    memset(Context,0,sizeof(TFTP_WRITE_CONTEXT));
    // memset leaves fd == 0, which is a valid descriptor; mark it "not open"
    // so the cleanup path can tell whether _open succeeded.
    Context->fd = -1;
    Context->Packet = (char *)malloc(MAX_OACK_PACKET_LENGTH);
    if (Context->Packet == NULL) {
        goto cleanup;
    }

    //
    // Parse the request: opcode (2 bytes), file name, NUL, mode, NUL, options.
    //
    Request = (PTFTP_REQUEST) Argument;
    FileName = &Request->Packet1[2];
    if (!IsFileNameValid(FileName,MAX_TFTP_DATAGRAM-2)) {
        DbgPrint("TftpdHandleWrite: invalid FileName\n");
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_ILLEGAL_OPERATION,
            NULL);
        goto cleanup;
    }
    length = strlen(FileName);
    WriteMode = FileName + length + 1;
    // WriteMode begins 2 + length + 1 bytes into the datagram; bound the
    // scan by what remains so it stays within the first MAX_TFTP_DATAGRAM bytes.
    if (!IsFileNameValid(WriteMode, MAX_TFTP_DATAGRAM - 2 - (length + 1))) {
        DbgPrint("TftpdHandleWrite: invalid WriteMode\n");
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_ILLEGAL_OPERATION,
            NULL);
        goto cleanup;
    }

    //
    // Profile data.
    //
    client_ipaddr = inet_ntoa( Request->ForeignAddress.sin_addr );
    if (client_ipaddr == NULL)
        client_ipaddr = "";
    client_port = ntohs( Request->ForeignAddress.sin_port );

    DbgPrint("TftpdHandleWrite: FileName=%s, WriteMode=%s, from=%s:%d.\n",
             FileName, WriteMode,
             client_ipaddr, client_port
             );

    //
    // Lower-case the mode.  The cast avoids undefined behaviour in tolower
    // for bytes >= 0x80.  CharPtr is left on the mode's NUL; the options
    // start at CharPtr + 1.
    //
    for (CharPtr = WriteMode; *CharPtr; CharPtr ++) {
        *CharPtr = (char)tolower((unsigned char)*CharPtr);
    }

    if (strcmp(WriteMode, "netascii") == 0) {
        Context->FileMode = O_TEXT;
    } else if (strcmp(WriteMode, "octet") == 0) {
        Context->FileMode = O_BINARY;
    } else {
        DbgPrint("TftpdHandleWrite: invalid WriteMode=%s\n", WriteMode );
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_ILLEGAL_OPERATION,
            NULL);
        goto cleanup;
    }

#if defined(REMOTE_BOOT_SECURITY)
    err = TftpdProcessOptionsPhase1( Request, CharPtr + 1, TFTPD_WRQ );
    if ( err != 0 ) {
        goto cleanup;
    }
#endif //defined(REMOTE_BOOT_SECURITY)

    //
    // Canonicalize the file name.
    //
    strcpy( Request->Packet3, FileName );
    if ( !TftpdCanonicalizeFileName(Request->Packet3) ) {
        DbgPrint("TftpdHandleWrite: invalid FileName=%s\n", FileName );
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_UNDEFINED,
            "Malformed file name");
        goto cleanup;
    }

    //
    // Check whether this access is permitted.  Match the canonicalized name
    // (as the read path does), i.e. the name that will actually be opened,
    // rather than the raw name from the wire.
    //
    if( !match( ValidMasters, client_ipaddr )
        || !match( ValidWriteFiles, Request->Packet3 )
    ){
        DbgPrint("TftpdHandleWrite: cannot open file=%s, errno=%d.\n"
                 "  client %s:%d,\n"
                 "  ValidWriteFiles=%s, ValidClients=%s, ValidMasters=%s,\n"
                 ,
                 Request->Packet3, errno,
                 client_ipaddr, client_port,
                 ValidWriteFiles, ValidClients, ValidMasters
                 );
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_ACCESS_VIOLATION,
            NULL);
        goto cleanup;
    }

    //
    // Prepend the start directory to the file name.
    //
    if ( !TftpdPrependStringToFileName(
              Request->Packet3,
              sizeof(Request->Packet3),
              StartDirectory) ) {
        DbgPrint("TftpdHandleWrite: too long FileName=%s\n", FileName );
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_UNDEFINED,
            "File name too long");
        goto cleanup;
    }

    //
    // Open the file.  NOTE(review): the file is created/truncated here,
    // before the socket is set up, so a later failure leaves it truncated.
    //
    DbgPrint("TftpdHandleWrite: opening file <%s>\n", Request->Packet3 );
    Context->fd = _open(Request->Packet3, _O_WRONLY | _O_CREAT | _O_BINARY | _O_TRUNC,
                        _S_IREAD | _S_IWRITE);
    if (Context->fd == -1) {
        DbgPrint("TftpdHandleWrite: cannot open file=%s, errno=%d.\n"
                 "  client %s:%d,\n"
                 "  ValidWriteFiles=%s, ValidClients=%s, ValidMasters=%s,\n"
                 ,
                 FileName, errno,
                 client_ipaddr, client_port,
                 ValidWriteFiles, ValidClients, ValidMasters
                 );
        SetLastError( errno );
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_ACCESS_VIOLATION,
            NULL);
        goto cleanup;
    }

    //
    // Open a new socket for this request
    //
    ReadPort =
        socket(
            AF_INET,
            SOCK_DGRAM,
            0);
    if( ReadPort == INVALID_SOCKET ){
        DbgPrint("TftpdHandleWrite: cannot open socket, Error=%d.\n",
                 WSAGetLastError() );
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_UNDEFINED,
            "Insufficient resources");
        goto cleanup;
    }

    //
    // Bind to a random port on the address the request arrived on.
    //
    ReadAddress.sin_family = AF_INET;
    ReadAddress.sin_port = 0;
    ReadAddress.sin_addr.s_addr = Request->MyAddr;
    Status = bind(
                ReadPort,
                (struct sockaddr *) &ReadAddress,
                sizeof(ReadAddress));
    if (Status) {
        DbgPrint("TftpdHandleWrite: cannot bind socket, Error=%d.\n",
                 WSAGetLastError() );
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_UNDEFINED,
            "Insufficient resources");
        goto cleanup;
    }
    Request->TftpdPort = ReadPort;

    err = TftpdProcessOptionsPhase2( Request, CharPtr + 1, TFTPD_WRQ, &Context->oackLength,Context->Packet,
                                     &Context->FixedTimer);
    if ( err != 0 ) {
        goto cleanup;
    }

    // Enter Context into list now that we know the port
    InitializeCriticalSection(&Context->Lock);
    LockInitialized = TRUE;
    Context->Sock=ReadPort;
    memcpy(&Context->ForeignAddress,&Request->ForeignAddress,sizeof(struct sockaddr_in));
    Context->SocketEvent=CreateEvent(NULL,FALSE,FALSE,NULL);
    if (Context->SocketEvent == NULL) {
        DbgPrint("Failed to create socket event %d",GetLastError());
        goto cleanup;
    }
    Context->WaitEvent=RegisterSocket(ReadPort,Context->SocketEvent,REG_CONTINUE_SOCKET);
    if (Context->WaitEvent == NULL) {
        DbgPrint("Failed to create socket event %d",GetLastError());
        goto cleanup;
    }

    // Insert Context.  From here on the context list holds the resources.
    TftpdAddContextToList(&Context->ContextLinkage);
    AddedContext=TRUE;
    Context=(PTFTP_WRITE_CONTEXT)TftpdFindContextInList(ReadPort);
    if (Context == NULL) {
        DbgPrint("Failed to Lookup WriteContext");
        goto cleanup;
    }
    LockHeld=TRUE;

    // Start retransmission timer
    if (Context->FixedTimer) {
        Context->DueTime=Request->Timeout*1000;
    } else {
        Context->DueTime=TFTPD_INITIAL_TIMEOUT*1000;
    }
    DbgPrint("TftpdHandleWrite: Timer Interval %d msecs\n",Context->DueTime);
    ntStatus=RtlCreateTimer(Globals.TimerQueueHandle,
                            &Context->TimerHandle,
                            TftpdRetransmit,
                            (PVOID)Context->Sock,
                            Context->DueTime,
                            Context->DueTime,
                            0);
    if (!NT_SUCCESS(ntStatus)) {
        DbgPrint("Failed to Arm Timer %d",ntStatus);
    }

    Context->ContextType=WRITE_CONTEXT;
    Context->BlockSize=Request->BlockSize;
    if (Context->BlockSize > MAX_OACK_PACKET_LENGTH - 4) {
        // On failure the original buffer is kept in Context->Packet.
        NewPacket = (char *)realloc(Context->Packet, Context->BlockSize + 4);
        if (NewPacket == NULL) {
            goto cleanup;
        }
        Context->Packet = NewPacket;
    }

    //
    // First reply: the OACK if options were negotiated, otherwise ACK of
    // block 0 (BlockNumber is still 0 from the memset above).
    //
    if ( Context->oackLength != 0 ) {
        Context->packetLength = Context->oackLength;
        Context->oackLength = 0;
    } else {
        ((unsigned short *) Context->Packet)[0] = htons(TFTPD_ACK);
        ((unsigned short *) Context->Packet)[1] = htons(Context->BlockNumber);
        Context->packetLength = 4;
    }
    Status =
        sendto(
            ReadPort,
            Context->Packet,
            Context->packetLength,
            0,
            (struct sockaddr *) &Request->ForeignAddress,
            sizeof(struct sockaddr_in)
            );
    if( SOCKET_ERROR == Status ){
        DbgPrint("TftpdHandleWrite: sendto failed=%d\n",
                 WSAGetLastError() );
        goto cleanup;
    }

cleanup:
    if (Context != NULL) {
        if (LockHeld) {
            TftpdReleaseContextLock((PTFTP_CONTEXT_HEADER)Context);
        }
        if (!AddedContext) {
            //
            // The context never reached the context list and no timer or
            // socket registration refers to it, so release everything that
            // was acquired above.
            //
            if (LockInitialized) {
                DeleteCriticalSection(&Context->Lock);
            }
            if (Context->SocketEvent != NULL) {
                CloseHandle(Context->SocketEvent);
            }
            if (Context->fd != -1) {
                _close(Context->fd);
            }
            free(Context->Packet);
            free(Context);
        }
    }
    if (!AddedContext && ReadPort != INVALID_SOCKET) {
        closesocket(ReadPort);
    }
    return 0;
}
// End function TftpdHandleWrite.
// ========================================================================
#if defined(REMOTE_BOOT_SECURITY)
DWORD
TftpdHandleLogin(
PVOID Argument
)
/*++
Routine Description:
This handles an incoming login request.
Arguments:
Argument - buffer containing the read request datagram
freed when done.
Return Value:
Exit status
0 == success
1 == failure
N >0 failure
--*/
{
int packetLength;
int Retry;
int Status, err;
struct timeval timeval;
struct sockaddr_in LoginAddress;
BOOL Acked;
int AddressLength;
char * CharPtr;
PTFTP_REQUEST Request;
char * OperationType;
char * PackageName;
char * SecurityString;
char * Options;
SECURITY_STATUS SecStatus;
PSecPkgInfo PackageInfo = NULL;
USHORT Index = -1;
ULONG MaxToken;
TFTPD_SECURITY Security;
ULONG SecurityHandle;
USHORT LastMessageSequence; // sequence number of the last message sent
SOCKET LoginPort = INVALID_SOCKET;
char * IncomingMessage;
SecBufferDesc IncomingDesc;
SecBuffer IncomingBuffer;
SecBufferDesc OutgoingDesc;
SecBuffer OutgoingBuffer;
BOOL FirstChallenge;
struct fd_set exceptfds;
struct fd_set loginfds;
TimeStamp Lifetime;
int BytesAck;
//
// Parse the request. The initial request should always be a
// "login".
//
Request = (PTFTP_REQUEST) Argument;
OperationType = &Request->Packet1[2];
//
// Convert the operation to all lower case for comparison
//
for (CharPtr = OperationType; *CharPtr; CharPtr ++) {
*CharPtr = (char)tolower(*CharPtr);
}
if (strcmp(OperationType, "login") == 0) {
PackageName = OperationType + strlen(OperationType) + 1;
//
// Profile data.
//
DbgPrint("TftpdHandleLogin: OperationType=%s, Package=%s, from=%s.\n",
OperationType, PackageName,
inet_ntoa( Request->ForeignAddress.sin_addr )
);
//
// Check that the security package is known.
//
SecStatus = QuerySecurityPackageInfoA( PackageName, &PackageInfo );
if (SecStatus != STATUS_SUCCESS) {
DbgPrint("TftpdHandleLogin: invalid PackageName=%s?\n", PackageName );
TftpdErrorPacket(
(struct sockaddr *) &Request->ForeignAddress,
Request->Packet2,
Request->TftpdPort,
TFTPD_ERROR_ILLEGAL_OPERATION,
"invalid security package");
goto cleanup;
}
MaxToken = PackageInfo->cbMaxToken;
FreeContextBuffer(PackageInfo);
//
// Things look OK so far, so let's find a spot in the array of
// security information to store this client.
//
if (!TftpdAllocateSecurityEntry(&Index, &Security)) {
DbgPrint("TftpdHandleLogin: could not allocate security entry\n" );
TftpdErrorPacket(
(struct sockaddr *) &Request->ForeignAddress,
Request->Packet2,
Request->TftpdPort,
TFTPD_ERROR_UNDEFINED,
"insufficient resources");
goto cleanup;
}
//
// Acquire a credential handle for the server side.
//
SecStatus = AcquireCredentialsHandleA(
NULL, // New principal
PackageName, // Package Name
SECPKG_CRED_INBOUND,
NULL,
NULL,
NULL,
NULL,
&(Security.CredentialsHandle),
&Lifetime );
if ( SecStatus != STATUS_SUCCESS ) {
DbgPrint("TftpdHandleLogin: AcquireCredentialsHandle failed %x\n", SecStatus );
TftpdErrorPacket(
(struct sockaddr *) &Request->ForeignAddress,
Request->Packet2,
Request->TftpdPort,
TFTPD_ERROR_UNDEFINED,
"insufficient resources");
goto cleanup;
}
Security.CredentialsHandleValid = TRUE;
//
// Open a new socket for this request
//
LoginPort =
socket(
AF_INET,
SOCK_DGRAM,
0);
if (LoginPort == INVALID_SOCKET) {
DbgPrint("TftpdHandleLogin: cannot open socket, Error=%d\n",
WSAGetLastError() );
TftpdErrorPacket(
(struct sockaddr *) &Request->ForeignAddress,
Request->Packet2,
Request->TftpdPort,
TFTPD_ERROR_UNDEFINED,
"Insufficient resources");
goto cleanup;
}
//
// Bind to a random address
//
LoginAddress.sin_family = AF_INET;
LoginAddress.sin_port = 0;
LoginAddress.sin_addr.s_addr = INADDR_ANY;
Status =
bind(
LoginPort,
(struct sockaddr *) &LoginAddress,
sizeof(LoginAddress)
);
if (Status) {
DbgPrint("TftpdHandleLogin: cannot bind socket, error=%d.\n",
WSAGetLastError() );
TftpdErrorPacket(
(struct sockaddr *) &Request->ForeignAddress,
Request->Packet2,
Request->TftpdPort,
TFTPD_ERROR_UNDEFINED,
"Insufficient resources");
goto cleanup;
}
Request->TftpdPort = LoginPort;
Request->Timeout = 10; // let client set this in login packet?
//
// Ready to do exchanges until login is complete.
//
LastMessageSequence = (USHORT)-1;
FirstChallenge = TRUE;
IncomingMessage = PackageName + strlen(PackageName) + 1;
while (1) {
IncomingDesc.ulVersion = 0;
IncomingDesc.cBuffers = 1;
IncomingDesc.pBuffers = &IncomingBuffer;
IncomingBuffer.cbBuffer = (ntohl)(((unsigned long UNALIGNED *)IncomingMessage)[0]);
IncomingBuffer.BufferType = SECBUFFER_TOKEN | SECBUFFER_READONLY;
IncomingBuffer.pvBuffer = IncomingMessage + 4;
OutgoingDesc.ulVersion = 0;
OutgoingDesc.cBuffers = 1;
OutgoingDesc.pBuffers = &OutgoingBuffer;
OutgoingBuffer.cbBuffer = MaxToken;
OutgoingBuffer.BufferType = SECBUFFER_TOKEN;
OutgoingBuffer.pvBuffer = Request->Packet2 + 8;
//
// Pass the client buffer to the security system -- the first time
// we don't have a valid SecurityContextHandle, so we pass the
// CredentialsHandle instead.
//
SecStatus = AcceptSecurityContext(
FirstChallenge ? &(Security.CredentialsHandle) : NULL,
FirstChallenge ? NULL : &(Security.ServerContextHandle),
&IncomingDesc,
FirstChallenge ?
ISC_REQ_SEQUENCE_DETECT | ASC_REQ_ALLOW_NON_USER_LOGONS :
ASC_REQ_ALLOW_NON_USER_LOGONS,
SECURITY_NATIVE_DREP,
&(Security.ServerContextHandle),
&OutgoingDesc,
&(Security.ContextAttributes),
&Lifetime );
if (FirstChallenge) {
Security.ServerContextHandleValid = TRUE;
}
FirstChallenge = FALSE;
if (SecStatus != SEC_I_CONTINUE_NEEDED) {
//
// The login has been accepted or rejected.
//
((unsigned short *) Request->Packet2)[0] = htons(TFTPD_LOGIN);
((unsigned short *) Request->Packet2)[1] = htons((USHORT)-1);
if (SecStatus == STATUS_SUCCESS) {
sprintf(Request->Packet2+4, "status %u handle %d ",
SecStatus, (Index << 16) + Security.Validation);
} else {
sprintf(Request->Packet2+4, "status %u ", SecStatus);
}
packetLength = 4 + strlen(Request->Packet2+4);
for (CharPtr = Request->Packet2+4; *CharPtr; CharPtr ++) {
if (*CharPtr == ' ') {
*CharPtr = '\0';
}
}
Security.LoginComplete = TRUE;
Security.LoginStatus = SecStatus;
Security.ForeignAddress = Request->ForeignAddress;
TftpdStoreSecurityEntry(Index, &Security);
LastMessageSequence = (USHORT)-1;
} else if (SecStatus == SEC_I_CONTINUE_NEEDED) {
//
// Need to exchange with the client. Note that the response
// message has already been stored at Request->Packet2 + 8.
//
++LastMessageSequence;
((unsigned short *) Request->Packet2)[0] = htons(TFTPD_LOGIN);
((unsigned short *) Request->Packet2)[1] = htons(LastMessageSequence);
((unsigned long UNALIGNED *) Request->Packet2)[1] = htonl(OutgoingBuffer.cbBuffer);
packetLength = 8 + OutgoingBuffer.cbBuffer;
}
Acked = FALSE;
Retry = 0;
while (!Acked && (Retry < MAX_TFTPD_RETRIES) ){
//
// send the data
//
Status = sendto(
LoginPort,
Request->Packet2,
packetLength,
0,
(struct sockaddr *) &Request->ForeignAddress,
sizeof(struct sockaddr_in));
if( SOCKET_ERROR == Status ){
DbgPrint("TftpdHandleLogin: sendto failed=%d\n",
WSAGetLastError() );
goto cleanup;
}
//
// wait for the ack
//
FD_ZERO( &loginfds );
FD_ZERO( &exceptfds );
FD_SET( LoginPort, &loginfds );
FD_SET( LoginPort, &exceptfds );
timeval.tv_sec = Request->Timeout;
timeval.tv_usec = 0;
Status = select(0, &loginfds, NULL, &exceptfds, &timeval);
if( SOCKET_ERROR == Status ){
DbgPrint("TftpdHandleLogin: select failed=%d\n",
WSAGetLastError() );
goto cleanup;
}
if ((Status > 0) && (FD_ISSET(LoginPort, &loginfds))) {
//
// Got response, maybe
//
AddressLength = sizeof(LoginAddress);
BytesAck =
recvfrom(
LoginPort,
Request->Packet1,
sizeof(Request->Packet1),
0,
(struct sockaddr *) &LoginAddress,
&AddressLength);
if( SOCKET_ERROR == BytesAck ){
DbgPrint("TftpdHandleLogin: recvfrom failed=%d\n",
WSAGetLastError() );
goto cleanup;
}
if (CHECK_ACK(Request->Packet1, TFTPD_LOGIN, LastMessageSequence)) {
Acked = TRUE;
}
}
Retry ++;
} // end while.
if (!Acked) {
DbgPrint("TftpdHandleLogin: Timed out waiting for ack\n");
TftpdErrorPacket(
(struct sockaddr *) &Request->ForeignAddress,
Request->Packet2,
Request->TftpdPort,
TFTPD_ERROR_UNDEFINED,
"Timeout");
goto cleanup;
}
if (LastMessageSequence == (USHORT)-1) {
//
// If we got an ack for the last sequence number, then
// break.
//
break;
} else {
//
// Loop back and process this message.
//
IncomingMessage = Request->Packet1 + 4;
}
} // end while 1.
} else if (strcmp(OperationType, "logoff") == 0) {
PackageName = OperationType + strlen(OperationType) + 1;
//
// Don't bother checking the package name.
//
SecurityString = PackageName + strlen(PackageName) + 1;
for (CharPtr = SecurityString; *CharPtr; CharPtr ++) {
*CharPtr = (char)tolower(*CharPtr);
}
if (strcmp(SecurityString, "security") != 0) {
DbgPrint("TftpdHandleLogin: invalid logoff handle %s\n", SecurityString );
TftpdErrorPacket(
(struct sockaddr *) &Request->ForeignAddress,
Request->Packet2,
Request->TftpdPort,
TFTPD_ERROR_ILLEGAL_OPERATION,
"invalid security handle");
goto cleanup;
}
//
// Open a new socket for this request
//
LoginPort =
socket(
AF_INET,
SOCK_DGRAM,
0);
if (LoginPort == INVALID_SOCKET) {
DbgPrint("TftpdHandleLogin: cannot open socket, Error=%d\n",
WSAGetLastError() );
TftpdErrorPacket(
(struct sockaddr *) &Request->ForeignAddress,
Request->Packet2,
Request->TftpdPort,
TFTPD_ERROR_UNDEFINED,
"Insufficient resources");
goto cleanup;
}
//
// Bind to a random address
//
LoginAddress.sin_family = AF_INET;
LoginAddress.sin_port = 0;
LoginAddress.sin_addr.s_addr = INADDR_ANY;
Status =
bind(
LoginPort,
(struct sockaddr *) &LoginAddress,
sizeof(LoginAddress)
);
if (Status) {
DbgPrint("TftpdHandleLogin: cannot bind socket, error=%d.\n",
WSAGetLastError() );
TftpdErrorPacket(
(struct sockaddr *) &Request->ForeignAddress,
Request->Packet2,
Request->TftpdPort,
TFTPD_ERROR_UNDEFINED,
"Insufficient resources");
goto cleanup;
}
//
// Start to prepare the response.
//
((unsigned short *) Request->Packet2)[0] = htons(TFTPD_LOGIN);
((unsigned short *) Request->Packet2)[1] = htons((USHORT)-1);
//
// Now get the handle and delete the security entry if it is valid.
//
Options = SecurityString + strlen(SecurityString) + 1;
SecurityHandle = atoi(Options);
TftpdGetSecurityEntry((USHORT)(SecurityHandle >> 16), &Security);
if (Security.Validation == ((SecurityHandle) & 0xffff)) {
TftpdFreeSecurityEntry((USHORT)(SecurityHandle >> 16));
sprintf(Request->Packet2+4, "status %u ", 0);
} else {
sprintf(Request->Packet2+4, "status %u ", STATUS_INVALID_HANDLE);
}
packetLength = 4 + strlen(Request->Packet2+4);
for (CharPtr = Request->Packet2+4; *CharPtr; CharPtr ++) {
if (*CharPtr == ' ') {
*CharPtr = '\0';
}
}
//
// Wait for his ack, but not for too long.
//
Acked = FALSE;
Retry = 0;
while (!Acked && (Retry < 3) ){
//
// send the data
//
Status = sendto(
LoginPort,
Request->Packet2,
packetLength,
0,
(struct sockaddr *) &Request->ForeignAddress,
sizeof(struct sockaddr_in));
if( SOCKET_ERROR == Status ){
DbgPrint("TftpdHandleLogin: sendto failed=%d\n",
WSAGetLastError() );
goto cleanup;
}
//
// wait for the ack
//
FD_ZERO( &loginfds );
FD_ZERO( &exceptfds );
FD_SET( LoginPort, &loginfds );
FD_SET( LoginPort, &exceptfds );
timeval.tv_sec = 2;
timeval.tv_usec = 0;
Status = select(0, &loginfds, NULL, &exceptfds, &timeval);
if( SOCKET_ERROR == Status ){
DbgPrint("TftpdHandleLogin: select failed=%d\n",
WSAGetLastError() );
goto cleanup;
}
if ((Status > 0) && (FD_ISSET(LoginPort, &loginfds))) {
//
// Got response, maybe
//
AddressLength = sizeof(LoginAddress);
BytesAck =
recvfrom(
LoginPort,
Request->Packet1,
sizeof(Request->Packet1),
0,
(struct sockaddr *) &LoginAddress,
&AddressLength);
if( SOCKET_ERROR == BytesAck ){
DbgPrint("TftpdHandleLogin: recvfrom failed=%d\n",
WSAGetLastError() );
goto cleanup;
}
if (CHECK_ACK(Request->Packet1, TFTPD_LOGIN, (USHORT)-1)) {
Acked = TRUE;
}
}
Retry ++;
} // end while.
//
// If the ack timed out, don't worry about it.
//
} else {
DbgPrint("TftpdHandleLogin: invalid OperationType=%s?\n", OperationType );
TftpdErrorPacket(
(struct sockaddr *) &Request->ForeignAddress,
Request->Packet2,
Request->TftpdPort,
TFTPD_ERROR_ILLEGAL_OPERATION,
NULL);
goto cleanup;
}
cleanup:
if (LoginPort != INVALID_SOCKET) {
closesocket(LoginPort);
}
// free(Request);
return 0;
}
#endif //defined(REMOTE_BOOT_SECURITY)
#if defined(REMOTE_BOOT_SECURITY)
DWORD
TftpdHandleKey(
    PVOID Argument
    )
/*++
Routine Description:
    This handles an incoming key request of the form (tokens separated by
    NULs after the 4-byte header):
        spi <client-spi> [security <handle>]
    It removes any existing IPSEC policies involving the client, installs
    an outbound/inbound TCP filter pair for this host <-> client, adds an
    outbound security association using the client's SPI and an inbound
    one using an SPI obtained from the IPSEC driver, and replies from a
    freshly bound UDP socket with the inbound SPI and the key.
    When "security" is present the key comes from the login security entry
    named by <handle> and the reply carries the signature and signed key;
    otherwise a key is generated here and sent in the clear.
Arguments:
    Argument - the PTFTP_REQUEST holding the request datagram. It is not
               freed by this routine.
Return Value:
    Always 0. Failures are reported to the client with an error packet.
--*/
{
    int Status;
    int packetLength;
    struct sockaddr_in LoginAddress;
    char * CharPtr;
    PTFTP_REQUEST Request;
    char * OperationType;
    char * SpiString;
    char * SecurityString;
    ULONG SpiValue;
    ULONG KeyValue;
    ULONG SecurityHandle;
    BOOL KeyIsSecure = FALSE;   // TRUE only when the client asked for "security"
    SOCKET LoginPort = INVALID_SOCKET;
    HANDLE IpsecHandle = INVALID_HANDLE_VALUE;
    BOOL IOStatus;
    char PolicyBuffer[sizeof(IPSEC_SET_POLICY) + sizeof(IPSEC_POLICY_INFO)];
    PIPSEC_SET_POLICY SetPolicy = (PIPSEC_SET_POLICY)PolicyBuffer;
    IPSEC_FILTER OutboundFilter;
    IPSEC_FILTER InboundFilter;
    IPSEC_GET_SPI GetSpi;
    char SaBuffer[sizeof(IPSEC_ADD_UPDATE_SA) + (6 * sizeof(ULONG))];
    PIPSEC_ADD_UPDATE_SA AddUpdateSa;
    IPSEC_DELETE_POLICY DeletePolicy;
    //
    // Intended to hold only the IPSEC_ENUM_POLICY header (room for no
    // pInfo entries); the first enumeration just learns the policy count.
    //
    char EnumPolicyBuffer[2 * sizeof(DWORD)];
    PIPSEC_ENUM_POLICY EnumPolicy;
    DWORD EnumPolicySize;
    char MyName[80];
    PHOSTENT Host;
    DWORD LocalAddress;
    DWORD BytesReturned;
    DWORD i;
    LARGE_INTEGER SystemTime;
    TFTPD_SECURITY Security;
    //
    // Parse the request. The initial request should always be a
    // "spi".
    //
    Request = (PTFTP_REQUEST) Argument;
    OperationType = &Request->Packet1[4];
    //
    // Convert the operation to all lower case for comparison
    //
    for (CharPtr = OperationType; *CharPtr; CharPtr ++) {
        *CharPtr = (char)tolower((unsigned char)*CharPtr);
    }
    if (strcmp(OperationType, "spi") == 0) {
        SpiString = OperationType + sizeof("spi");
        SpiValue = atoi(SpiString);
        OperationType = SpiString + strlen(SpiString) + 1;
        //
        // See if the client request encryption of the key.
        //
        if (strcmp(OperationType, "security") == 0) {
            SecurityString = OperationType + sizeof("security");
            SecurityHandle = atoi(SecurityString);
            KeyIsSecure = TRUE;
            //
            // High 16 bits of handle is index, low 16 bits is validation.
            // Both halves must come from the same (parsed) handle.
            //
            memset(&Security, 0, sizeof(Security));
            TftpdGenerateKeyForSecurityEntry((USHORT)(SecurityHandle >> 16), &Security);
            if ((Security.Validation != (SecurityHandle & 0xffff)) ||
                (!Security.GeneratedKey)) {
                DbgPrint("TftpdHandleKey: SecurityHandle %x is invalid.\n",
                    SecurityHandle);
                TftpdErrorPacket(
                    (struct sockaddr *) &Request->ForeignAddress,
                    Request->Packet2,
                    Request->TftpdPort,
                    TFTPD_ERROR_UNDEFINED,
                    "Invalid security handle");
                goto cleanup;
            }
            KeyValue = Security.Key;
            DbgPrint("TftpdHandleKey: SPI %lx, retrieved secure key %lx\n", SpiValue, KeyValue);
        } else {
            ULONG Divisor = Request->ForeignAddress.sin_addr.s_addr;
            NtQuerySystemTime(&SystemTime);
            //
            // NOTE(review): this key is derived from the time and the peer
            // address and is sent in the clear; it is not a secret.
            // A zero peer address would make the modulo undefined, so fall
            // back to the low part of the time in that case.
            //
            if (Divisor != 0) {
                KeyValue = (ULONG)(SystemTime.QuadPart % Divisor);
            } else {
                KeyValue = SystemTime.LowPart;
            }
            SecurityHandle = 0;
            DbgPrint("TftpdHandleKey: SPI %lx, generated key %lx\n", SpiValue, KeyValue);
        }
        //
        // Open IPSEC so we can send down IOCTLS.
        //
        IpsecHandle = CreateFileW(
                          DD_IPSEC_DOS_NAME,              // IPSEC device name
                          GENERIC_READ | GENERIC_WRITE,   // access (read-write) mode
                          0,                              // share mode
                          NULL,                           // pointer to security attributes
                          OPEN_EXISTING,                  // how to create
                          0,                              // file attributes
                          NULL);                          // handle to file with attributes to copy
        if (IpsecHandle == INVALID_HANDLE_VALUE) {
            DbgPrint("TftpdHandleKey: Could not open <%ws>\n", DD_IPSEC_DOS_NAME);
            TftpdErrorPacket(
                (struct sockaddr *) &Request->ForeignAddress,
                Request->Packet2,
                Request->TftpdPort,
                TFTPD_ERROR_UNDEFINED,
                "Insufficient resources");
            goto cleanup;
        }
        //
        // See how many policies are defined. First send down a buffer
        // with room for no policies, to see how many there are.
        //
        EnumPolicy = (PIPSEC_ENUM_POLICY)EnumPolicyBuffer;
        memset(EnumPolicy, 0, sizeof(EnumPolicyBuffer));
        IOStatus = DeviceIoControl(
                       IpsecHandle,                     // Driver handle
                       IOCTL_IPSEC_ENUM_POLICIES,       // Control code
                       EnumPolicy,                      // Input buffer
                       sizeof(EnumPolicyBuffer),        // Input buffer size
                       EnumPolicy,                      // Output buffer
                       sizeof(EnumPolicyBuffer),        // Output buffer size
                       &BytesReturned,
                       NULL);
        if (!IOStatus) {
            if (GetLastError() != ERROR_MORE_DATA) {
                DbgPrint("TftpdHandleKey: IOCTL_IPSEC_ENUM_POLICY #1 failed %x\n", GetLastError());
                TftpdErrorPacket(
                    (struct sockaddr *) &Request->ForeignAddress,
                    Request->Packet2,
                    Request->TftpdPort,
                    TFTPD_ERROR_UNDEFINED,
                    "IOCTL_IPSEC_ENUM_POLICIES");
                goto cleanup;
            }
            EnumPolicySize = FIELD_OFFSET(IPSEC_ENUM_POLICY, pInfo[0]) +
                                sizeof(IPSEC_POLICY_INFO) * EnumPolicy->NumEntriesPresent;
            EnumPolicy = malloc(EnumPolicySize);
            if (EnumPolicy == NULL) {
                DbgPrint("TftpdHandleKey: alloc ENUM_POLICIES buffer failed %x\n", GetLastError());
                TftpdErrorPacket(
                    (struct sockaddr *) &Request->ForeignAddress,
                    Request->Packet2,
                    Request->TftpdPort,
                    TFTPD_ERROR_UNDEFINED,
                    "IOCTL_IPSEC_ENUM_POLICIES");
                goto cleanup;
            }
            //
            // Re-submit the IOCTL.
            //
            memset(EnumPolicy, 0, EnumPolicySize);
            IOStatus = DeviceIoControl(
                           IpsecHandle,                     // Driver handle
                           IOCTL_IPSEC_ENUM_POLICIES,       // Control code
                           EnumPolicy,                      // Input buffer
                           EnumPolicySize,                  // Input buffer size
                           EnumPolicy,                      // Output buffer
                           EnumPolicySize,                  // Output buffer size
                           &BytesReturned,
                           NULL);
            //
            // We may get MORE_DATA if someone just added a policy, but
            // that is OK since it won't be for this remote.
            //
            if (!IOStatus && (GetLastError() != ERROR_MORE_DATA)) {
                DbgPrint("TftpdHandleKey: IOCTL_IPSEC_ENUM_POLICY #2 failed %x\n", GetLastError());
                TftpdErrorPacket(
                    (struct sockaddr *) &Request->ForeignAddress,
                    Request->Packet2,
                    Request->TftpdPort,
                    TFTPD_ERROR_UNDEFINED,
                    "IOCTL_IPSEC_ENUM_POLICIES");
                free(EnumPolicy);
                goto cleanup;
            }
            //
            // Delete any policies involving the remote machine.
            //
            for (i = 0; i < EnumPolicy->NumEntriesPresent; i++) {
                if ((EnumPolicy->pInfo[i].AssociatedFilter.SrcAddr ==
                        Request->ForeignAddress.sin_addr.s_addr) ||
                    (EnumPolicy->pInfo[i].AssociatedFilter.DestAddr ==
                        Request->ForeignAddress.sin_addr.s_addr)) {
                    DeletePolicy.NumEntries = 1;
                    DeletePolicy.pInfo[0] = EnumPolicy->pInfo[i];
                    IOStatus = DeviceIoControl(
                                   IpsecHandle,                     // Driver handle
                                   IOCTL_IPSEC_DELETE_POLICY,       // Control code
                                   &DeletePolicy,                   // Input buffer
                                   sizeof(IPSEC_DELETE_POLICY),     // Input buffer size
                                   NULL,                            // Output buffer
                                   0,                               // Output buffer size
                                   &BytesReturned,
                                   NULL);
                    if (!IOStatus) {
                        DbgPrint("TftpdHandleKey: IOCTL_IPSEC_DELETE_POLICY(%lx, %lx) failed %x\n",
                            EnumPolicy->pInfo[i].AssociatedFilter.SrcAddr,
                            EnumPolicy->pInfo[i].AssociatedFilter.DestAddr,
                            GetLastError());
                    }
                }
            }
            free(EnumPolicy);
        } else {
            //
            // If the call succeeds, we don't need to do anything, since
            // there should have been 0 policies returned.
            //
        }
        //
        // Get our local IP address. gethostbyname may fail (or the name
        // may be truncated), so check before dereferencing the result.
        //
        Host = NULL;
        if (gethostname(MyName, sizeof(MyName)) == 0) {
            Host = gethostbyname(MyName);
        }
        if ((Host == NULL) || (Host->h_addr == NULL)) {
            DbgPrint("TftpdHandleKey: cannot resolve local address, error=%d\n",
                     WSAGetLastError() );
            TftpdErrorPacket(
                (struct sockaddr *) &Request->ForeignAddress,
                Request->Packet2,
                Request->TftpdPort,
                TFTPD_ERROR_UNDEFINED,
                "Insufficient resources");
            goto cleanup;
        }
        LocalAddress = *(DWORD *)Host->h_addr;
        //
        // Set the policy. We need two filters, one for outbound and
        // one for inbound.
        //
        memset(&OutboundFilter, 0, sizeof(IPSEC_FILTER));
        memset(&InboundFilter, 0, sizeof(IPSEC_FILTER));
        OutboundFilter.SrcAddr = LocalAddress;
        OutboundFilter.SrcMask = 0xffffffff;
        OutboundFilter.DestAddr = Request->ForeignAddress.sin_addr.s_addr;
        OutboundFilter.DestMask = 0xffffffff;
        OutboundFilter.Protocol = 0x6;      // TCP
        InboundFilter.SrcAddr = Request->ForeignAddress.sin_addr.s_addr;
        InboundFilter.SrcMask = 0xffffffff;
        InboundFilter.DestAddr = LocalAddress;
        InboundFilter.DestMask = 0xffffffff;
        InboundFilter.Protocol = 0x6;      // TCP
        memset(SetPolicy, 0, sizeof(PolicyBuffer));
        SetPolicy->NumEntries = 2;
        SetPolicy->pInfo[0].Index = 1;
        SetPolicy->pInfo[0].AssociatedFilter = OutboundFilter;
        SetPolicy->pInfo[1].Index = 2;
        SetPolicy->pInfo[1].AssociatedFilter = InboundFilter;
        IOStatus = DeviceIoControl(
                       IpsecHandle,                     // Driver handle
                       IOCTL_IPSEC_SET_POLICY,          // Control code
                       SetPolicy,                       // Input buffer
                       sizeof(PolicyBuffer),            // Input buffer size
                       NULL,                            // Output buffer
                       0,                               // Output buffer size
                       &BytesReturned,
                       NULL);
        if (!IOStatus) {
            DbgPrint("TftpdHandleKey: IOCTL_IPSEC_SET_POLICY failed %x\n", GetLastError());
            TftpdErrorPacket(
                (struct sockaddr *) &Request->ForeignAddress,
                Request->Packet2,
                Request->TftpdPort,
                TFTPD_ERROR_UNDEFINED,
                "IOCTL_IPSEC_SET_POLICY");
            goto cleanup;
        }
        //
        // Now get an SPI to give to the remote.
        //
        GetSpi.Context = 0;
        GetSpi.InstantiatedFilter = InboundFilter;
        IOStatus = DeviceIoControl(
                       IpsecHandle,                     // Driver handle
                       IOCTL_IPSEC_GET_SPI,             // Control code
                       &GetSpi,                         // Input buffer
                       sizeof(IPSEC_GET_SPI),           // Input buffer size
                       &GetSpi,                         // Output buffer
                       sizeof(IPSEC_GET_SPI),           // Output buffer size
                       &BytesReturned,
                       NULL);
        if (!IOStatus) {
            DbgPrint("TftpdHandleKey: IOCTL_IPSEC_GET_SPI failed %x\n", GetLastError());
            TftpdErrorPacket(
                (struct sockaddr *) &Request->ForeignAddress,
                Request->Packet2,
                Request->TftpdPort,
                TFTPD_ERROR_UNDEFINED,
                "IOCTL_IPSEC_GET_SPI");
            goto cleanup;
        }
        //
        // Set up the security association for the outbound
        // connection. The 32-bit key is replicated to fill the six
        // ULONGs of key material (MD5 integrity + DES confidentiality).
        //
        AddUpdateSa = (PIPSEC_ADD_UPDATE_SA)SaBuffer;
        memset(AddUpdateSa, 0, sizeof(SaBuffer));
        AddUpdateSa->SAInfo.Context = GetSpi.Context;
        AddUpdateSa->SAInfo.NumSAs = 1;
        AddUpdateSa->SAInfo.InstantiatedFilter = OutboundFilter;
        AddUpdateSa->SAInfo.SecAssoc[0].Operation = Encrypt;
        AddUpdateSa->SAInfo.SecAssoc[0].SPI = SpiValue;
        AddUpdateSa->SAInfo.SecAssoc[0].IntegrityAlgo.algoIdentifier = IPSEC_AH_MD5;
        AddUpdateSa->SAInfo.SecAssoc[0].IntegrityAlgo.algoKeylen = 4 * sizeof(ULONG);
        AddUpdateSa->SAInfo.SecAssoc[0].ConfAlgo.algoIdentifier = IPSEC_ESP_DES;
        AddUpdateSa->SAInfo.SecAssoc[0].ConfAlgo.algoKeylen = 2 * sizeof(ULONG);
        AddUpdateSa->SAInfo.KeyLen = 6 * sizeof(ULONG);
        memcpy(AddUpdateSa->SAInfo.KeyMat, &KeyValue, sizeof(ULONG));
        memcpy(AddUpdateSa->SAInfo.KeyMat+sizeof(ULONG), &KeyValue, sizeof(ULONG));
        memcpy(AddUpdateSa->SAInfo.KeyMat+(2*sizeof(ULONG)), &KeyValue, sizeof(ULONG));
        memcpy(AddUpdateSa->SAInfo.KeyMat+(3*sizeof(ULONG)), &KeyValue, sizeof(ULONG));
        memcpy(AddUpdateSa->SAInfo.KeyMat+(4*sizeof(ULONG)), &KeyValue, sizeof(ULONG));
        memcpy(AddUpdateSa->SAInfo.KeyMat+(5*sizeof(ULONG)), &KeyValue, sizeof(ULONG));
        IOStatus = DeviceIoControl(
                       IpsecHandle,                     // Driver handle
                       IOCTL_IPSEC_ADD_SA,              // Control code
                       AddUpdateSa,                     // Input buffer
                       FIELD_OFFSET(IPSEC_ADD_UPDATE_SA, SAInfo.KeyMat[0]) +
                           AddUpdateSa->SAInfo.KeyLen,  // Input buffer size
                       NULL,                            // Output buffer
                       0,                               // Output buffer size
                       &BytesReturned,
                       NULL);
        if (!IOStatus) {
            DbgPrint("TftpdHandleKey: IOCTL_IPSEC_ADD_SA failed %x\n", GetLastError());
            TftpdErrorPacket(
                (struct sockaddr *) &Request->ForeignAddress,
                Request->Packet2,
                Request->TftpdPort,
                TFTPD_ERROR_UNDEFINED,
                "IOCTL_IPSEC_ADD_SA");
            goto cleanup;
        }
        //
        // Set up the security association for the inbound connection.
        // If our Operation is "None", then IPSEC does this for us.
        //
        if (AddUpdateSa->SAInfo.SecAssoc[0].Operation != None) {
            AddUpdateSa->SAInfo.SecAssoc[0].SPI = GetSpi.SPI;
            AddUpdateSa->SAInfo.InstantiatedFilter = InboundFilter;
            IOStatus = DeviceIoControl(
                           IpsecHandle,                     // Driver handle
                           IOCTL_IPSEC_UPDATE_SA,           // Control code
                           AddUpdateSa,                     // Input buffer
                           FIELD_OFFSET(IPSEC_ADD_UPDATE_SA, SAInfo.KeyMat[0]) +
                               AddUpdateSa->SAInfo.KeyLen,  // Input buffer size
                           NULL,                            // Output buffer
                           0,                               // Output buffer size
                           &BytesReturned,
                           NULL);
            if (!IOStatus) {
                DbgPrint("TftpdHandleKey: IOCTL_IPSEC_UPDATE_SA failed %x\n", GetLastError());
                TftpdErrorPacket(
                    (struct sockaddr *) &Request->ForeignAddress,
                    Request->Packet2,
                    Request->TftpdPort,
                    TFTPD_ERROR_UNDEFINED,
                    "IOCTL_IPSEC_UPDATE_SA");
                goto cleanup;
            }
        }
        //
        // Open a new socket for this request
        //
        LoginPort =
            socket(
                AF_INET,
                SOCK_DGRAM,
                0);
        if (LoginPort == INVALID_SOCKET) {
            DbgPrint("TftpdHandleKey: cannot open socket, Error=%d\n",
                     WSAGetLastError() );
            TftpdErrorPacket(
                (struct sockaddr *) &Request->ForeignAddress,
                Request->Packet2,
                Request->TftpdPort,
                TFTPD_ERROR_UNDEFINED,
                "Insufficient resources");
            goto cleanup;
        }
        //
        // Bind to a random address
        //
        LoginAddress.sin_family = AF_INET;
        LoginAddress.sin_port = 0;
        LoginAddress.sin_addr.s_addr = INADDR_ANY;
        Status =
            bind(
                LoginPort,
                (struct sockaddr *) &LoginAddress,
                sizeof(LoginAddress)
                );
        if (Status) {
            DbgPrint("TftpdHandleKey: cannot bind socket, error=%d.\n",
                     WSAGetLastError() );
            TftpdErrorPacket(
                (struct sockaddr *) &Request->ForeignAddress,
                Request->Packet2,
                Request->TftpdPort,
                TFTPD_ERROR_UNDEFINED,
                "Insufficient resources");
            goto cleanup;
        }
        //
        // Generate the response for the client.
        //
        ((unsigned short *) Request->Packet2)[0] = htons(TFTPD_KEY);
        Request->Packet2[2] = Request->Packet1[2];  // copy sequence number
        Request->Packet2[3] = Request->Packet1[3];
        //
        // The key is sent as hex digits since it might be longer
        // than four bytes. Whether it goes out in the clear depends only
        // on whether the client asked for "security", never on the value
        // of the handle it supplied.
        //
        if (!KeyIsSecure) {
            //
            // No security, send key in the clear.
            //
            sprintf(Request->Packet2+4, "spi %d key %2.2x%2.2x%2.2x%2.2x",
                GetSpi.SPI,
                ((PUCHAR)(&KeyValue))[0],
                ((PUCHAR)(&KeyValue))[1],
                ((PUCHAR)(&KeyValue))[2],
                ((PUCHAR)(&KeyValue))[3]);
            packetLength = 4 + strlen(Request->Packet2+4);
        } else {
            PCHAR SignLoc;
            //
            // Security requested, so send the encrypted key and the sign.
            //
            sprintf(Request->Packet2+4, "spi %d security %d sign ",
                GetSpi.SPI,
                SecurityHandle);
            packetLength = 4 + strlen(Request->Packet2+4);
            SignLoc = Request->Packet2 + packetLength;
            for (i = 0; i < NTLMSSP_MESSAGE_SIGNATURE_SIZE; i++) {
                sprintf(SignLoc, "%2.2x", Security.Sign[i]);
                SignLoc += 2;
                packetLength += 2;
            }
            sprintf(Request->Packet2+packetLength, " key %2.2x%2.2x%2.2x%2.2x",
                Security.SignedKey[0],
                Security.SignedKey[1],
                Security.SignedKey[2],
                Security.SignedKey[3]);
            //
            // Count what sprintf actually wrote rather than assuming a
            // length from sizeof(Security.SignedKey).
            //
            packetLength += strlen(Request->Packet2+packetLength);
        }
        //
        // Turn the space-separated reply into NUL-separated tokens.
        //
        for (CharPtr = Request->Packet2+4; *CharPtr; CharPtr ++) {
            if (*CharPtr == ' ') {
                *CharPtr = '\0';
            }
        }
        //
        // Send the response back to the client.
        //
        Status = sendto(
                     LoginPort,
                     Request->Packet2,
                     packetLength,
                     0,
                     (struct sockaddr *) &Request->ForeignAddress,
                     sizeof(struct sockaddr_in));
        if( SOCKET_ERROR == Status ){
            DbgPrint("TftpdHandleKey: sendto failed=%d\n",
                     WSAGetLastError() );
            goto cleanup;
        }
    } else {
        DbgPrint("TftpdHandleKey: invalid OperationType=%s?\n", OperationType );
        TftpdErrorPacket(
            (struct sockaddr *) &Request->ForeignAddress,
            Request->Packet2,
            Request->TftpdPort,
            TFTPD_ERROR_ILLEGAL_OPERATION,
            NULL);
        goto cleanup;
    }
cleanup:
    if (IpsecHandle != INVALID_HANDLE_VALUE) {
        CloseHandle(IpsecHandle);
    }
    if (LoginPort != INVALID_SOCKET) {
        closesocket(LoginPort);
    }
    return 0;
}
#endif //defined(REMOTE_BOOT_SECURITY)
// ========================================================================
int
TftpdDoRead(
int ReadFd,
char * Buffer,
int BufferSize,
int ReadMode)
/*++
Routine Description:
This does a read with the appropriate conversions for netascii or octet
modes.
Arguments:
ReadFd - file to read from
Buffer - Buffer to read into
BufferSize - size of buffer
ReadMode - O_TEXT or O_BINARY
O_TEXT means the netascii conversions must be done
O_BINARY means octet mode
Return Value:
BytesRead
Error?
--*/
{
int BytesRead;
int BytesWritten;
int BytesUsed;
char NextChar;
char State;
char LocalBuffer[MAX_TFTP_DATA];
int err;
if (ReadMode == O_BINARY) {
BytesRead = _read(ReadFd, Buffer, BufferSize);
if( BytesRead == -1 ){
DbgPrint("TftpdDoRead: read failed, errno=%d\n", errno );
SetLastError( errno );
}
return(BytesRead);
} else {
//
// Do those cr/lf conversions. A \r not followed by a \n must
// be followed by a \0.
//
BytesWritten = 0;
BytesUsed = 0;
State = '\0';
BytesRead = _read(ReadFd, LocalBuffer, sizeof(LocalBuffer));
if( BytesRead == -1 ){
DbgPrint("TftpdDoRead: read failed, errno=%d\n", errno );
SetLastError( errno );
return -1;
}
while ((BytesUsed < BytesRead) && (BytesWritten < BufferSize)) {
NextChar = LocalBuffer[BytesUsed++];
if (State == '\r') {
if (NextChar == '\n') {
Buffer[BytesWritten++] = NextChar;
State = '\0';
} else {
Buffer[BytesWritten++] = '\0';
Buffer[BytesWritten++] = NextChar;
State = '\0';
}
} else {
Buffer[BytesWritten++] = NextChar;
State = '\0';
}
if (NextChar == '\r') {
State = '\r';
}
}
err = _lseek(ReadFd, BytesUsed - BytesRead, SEEK_CUR);
if( err == -1 ){
DbgPrint("TftpdDoRead: lseek failed, errno=%d\n",
errno );
SetLastError( errno );
return -1;
}
return(BytesWritten);
}
}
// End function TftpdDoRead.
// ========================================================================
int
TftpdDoWrite(
    int WriteFd,
    char * Buffer,
    int BufferSize,
    int WriteMode,
    char * State)
/*++
Routine Description:
    This does a write with the appropriate conversions for netascii or octet
    modes.
    In netascii (O_TEXT) mode a '\0' that immediately follows a '\r' is
    dropped ("\r\0" on the wire is a bare '\r' locally). The converted data
    goes through a local buffer, which is flushed whenever it fills.
Arguments:
    WriteFd - file to write to
    Buffer - data received from the network; it is not modified
    BufferSize - number of bytes in Buffer
    WriteMode - O_TEXT or O_BINARY
        O_TEXT means the netascii conversions must be done
        O_BINARY means octet mode
    State - pointer to the current output state, carried across calls
        because a "\r\0" pair may span two packets. On return it is '\r' if
        the last byte processed was a '\r', otherwise '\0'.
Return Value:
    Number of bytes written to the file (after conversion), or -1 on
    failure with errno passed to SetLastError. A value smaller than the
    converted length indicates a short write.
--*/
{
    int BytesWritten;
    int Chunk;
    int i;
    char OutputBuffer[MAX_TFTP_DATA*2];
    int OutputPointer;
    if (WriteMode == O_BINARY) {
        BytesWritten = _write(WriteFd, Buffer, BufferSize);
        if( BytesWritten == -1 ){
            DbgPrint("TftpdDoWrite: write failed=%d\n", errno );
            SetLastError( errno );
        }
        return(BytesWritten);
    }
    //
    // Do those cr/lf conversions. If a '\r' is followed by a '\0', the
    // '\0' is stripped.
    //
    BytesWritten = 0;
    OutputPointer = 0;
    for (i=0; i<BufferSize; i++) {
        if ((*State == '\r') && (Buffer[i] == '\0')) {
            *State = '\0';
            continue;
        }
        OutputBuffer[OutputPointer ++] = Buffer[i];
        //
        // Only a '\r' as the most recent byte makes a following '\0'
        // special; any other byte clears the state.
        //
        *State = (char)((Buffer[i] == '\r') ? '\r' : '\0');
        if (OutputPointer == (int)sizeof(OutputBuffer)) {
            //
            // Local buffer full: flush it so large inputs cannot overflow it.
            //
            Chunk = _write(WriteFd, OutputBuffer, OutputPointer);
            if( Chunk == -1 ){
                DbgPrint("TftpdDoWrite: write failed=%d\n", errno );
                SetLastError( errno );
                return -1;
            }
            BytesWritten += Chunk;
            if (Chunk < OutputPointer) {
                return(BytesWritten);   // short write (e.g. disk full)
            }
            OutputPointer = 0;
        }
    }
    //
    // Write the converted data, not the raw network buffer.
    //
    Chunk = _write(WriteFd, OutputBuffer, OutputPointer);
    if( Chunk == -1 ){
        DbgPrint("TftpdDoWrite: write failed=%d\n", errno );
        SetLastError( errno );
        return -1;
    }
    return(BytesWritten + Chunk);
}
// End function TftpdDoWrite.
// ========================================================================
// EOF.