//
// Source listing: windows-nt/Source/XPSP1/NT/base/ntdll/lpcsvr.c
// Snapshot dated 2020-09-26 16:20:57 +08:00 (744 lines, 18 KiB, C)
//
//+---------------------------------------------------------------------------
//
// Microsoft Windows
// Copyright (C) Microsoft Corporation, 1992 - 1997.
//
// File: lpcsvr.c
//
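// Contents:    LPC server support built on the RTL wait / worker thread pool
//
// Classes:
//
// Functions:   RtlCreateLpcServer, RtlShutdownLpcServer,
//              RtlImpersonateLpcClient, RtlCallbackLpcClient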
//
// History: 12-12-97 RichardW Created
//
//----------------------------------------------------------------------------
#include <ntos.h>
#include <nt.h>
#include <ntrtl.h>
#include <nturtl.h>
#include "lpcsvr.h"
//
// Server lock helpers.  The argument is parenthesized and the expansion has
// no trailing semicolon, so each use parses as exactly one statement (the
// caller supplies the ';').
//
#define RtlpLpcLockServer( s )      RtlEnterCriticalSection( &(s)->Lock )
#define RtlpLpcUnlockServer( s )    RtlLeaveCriticalSection( &(s)->Lock )
//
// Map the PrivateContext pointer handed to server callbacks back to the
// LPCSVR_CONTEXT record that contains it.
//
#define RtlpLpcContextFromClient( p ) ( CONTAINING_RECORD( p, LPCSVR_CONTEXT, PrivateContext ) )
//+---------------------------------------------------------------------------
//
//  Function:   RtlpLpcDerefContext
//
//  Synopsis:   Drop one reference on a context and dispose of the message
//              buffer that was used to process it.
//
//  Arguments:  [Context] -- context whose reference is released.  It may be
//                           freed by this call; callers must not touch it
//                           afterwards.
//              [Message] -- message buffer to recycle, or NULL (the server
//                           shutdown path passes NULL).
//
//  History:    2-06-98   RichardW   Created
//
//  Notes:      The reference count rests at 0 while the context is idle and
//              the context is destroyed when it drops below zero.
//
//              The message is returned to the server's pool (bounded by
//              MessagePoolLimit) unless the server is shutting down, in
//              which case it is freed directly because the pool is being
//              drained.
//
//----------------------------------------------------------------------------
VOID
RtlpLpcDerefContext(
    PLPCSVR_CONTEXT Context,
    PLPCSVR_MESSAGE Message
    )
{
    PLPCSVR_SERVER Server ;

    //
    // Capture the server first: Context may be freed below.
    //
    Server = Context->Server ;

    if ( InterlockedDecrement( &Context->RefCount ) < 0 )
    {
        //
        // All gone, time to clean up.  A context already unlinked by
        // shutdown has List.Flink == NULL and must not be removed again.
        //
        RtlpLpcLockServer( Server );
        if ( Context->List.Flink )
        {
            RemoveEntryList( &Context->List );
            Server->ContextCount -- ;
        }
        RtlpLpcUnlockServer( Server );

        if ( Context->CommPort )
        {
            NtClose( Context->CommPort );
        }

        RtlFreeHeap( RtlProcessHeap(),
                     0,
                     Context );
    }

    //
    // The shutdown path has no message to recycle.
    //
    if ( Message == NULL )
    {
        return ;
    }

    RtlpLpcLockServer( Server );

    Server->MessagePoolSize++ ;

    if ( ( ( Server->Flags & LPCSVR_SHUTDOWN_PENDING ) == 0 ) &&
         ( Server->MessagePoolSize < Server->MessagePoolLimit ) )
    {
        Message->Header.Next = Server->MessagePool ;
        Server->MessagePool = Message ;
    }
    else
    {
        //
        // Pool is full, or it is being torn down by shutdown: free it.
        //
        Server->MessagePoolSize-- ;
        RtlFreeHeap( RtlProcessHeap(),
                     0,
                     Message );
    }

    RtlpLpcUnlockServer( Server );
}
//+---------------------------------------------------------------------------
//
//  Function:   RtlpLpcWorkerThread
//
//  Synopsis:   Dispatch one received LPC message to the server's callbacks,
//              then release the message and its context reference.
//
//  Arguments:  [Parameter] -- the PLPCSVR_MESSAGE to process; its
//                             Header.Context names the owning context.
//
//  History:    2-06-98   RichardW   Created
//
//  Notes:      Runs either inline on the receive thread or as a queued work
//              item.  Setting Context->RefCount to 0 before the final
//              RtlpLpcDerefContext call causes the context to be destroyed.
//
//----------------------------------------------------------------------------
VOID
RtlpLpcWorkerThread(
    PVOID Parameter
    )
{
    PLPCSVR_MESSAGE Message ;
    PLPCSVR_CONTEXT Context ;
    NTSTATUS Status ;
    BOOLEAN Accept ;

    Message = (PLPCSVR_MESSAGE) Parameter ;
    Context = Message->Header.Context ;

    switch ( Message->Message.u2.s2.Type & 0xF )
    {
        case LPC_REQUEST:
        case LPC_DATAGRAM:

            DbgPrint("Calling Server's Request function\n");

            Status = Context->Server->Init.RequestFn(
                            &Context->PrivateContext,
                            &Message->Message,
                            &Message->Message
                            );

            if ( NT_SUCCESS( Status ) )
            {
                //
                // Reply in place.  A failure here typically means the client
                // has already gone away; there is nothing further to do.
                //
                Status = NtReplyPort( Context->CommPort,
                                      &Message->Message );
            }
            break;

        case LPC_CONNECTION_REQUEST:

            DbgPrint("Calling Server's Connect function\n");

            //
            // Default to rejecting in case ConnectFn does not set Accept.
            //
            Accept = FALSE ;

            Status = Context->Server->Init.ConnectFn(
                            &Context->PrivateContext,
                            &Message->Message,
                            &Accept
                            );

            if ( !NT_SUCCESS( Status ) )
            {
                //
                // The server refused outright: reject the connection and
                // discard the context.
                //
                Status = NtAcceptConnectPort(
                                &Context->CommPort,
                                NULL,
                                &Message->Message,
                                FALSE,
                                NULL,
                                NULL );

                Context->RefCount = 0 ;
                break;
            }

            //
            // A non-NULL comm port means the server already called
            // RtlAcceptConnectPort() explicitly (to set up a view).
            //
            if ( Context->CommPort != NULL )
            {
                break;
            }

            Status = NtAcceptConnectPort(
                            &Context->CommPort,
                            Context,
                            &Message->Message,
                            Accept,
                            NULL,
                            NULL );

            if ( !Accept || !NT_SUCCESS( Status ) )
            {
                //
                // Rejected, or the accept itself failed: no connection
                // exists, so yank the worthless context.
                //
                Context->RefCount = 0 ;
                break;
            }

            Status = NtCompleteConnectPort( Context->CommPort );
            break;

        case LPC_CLIENT_DIED:

            DbgPrint( "Calling Server's Rundown function\n" );

            Status = Context->Server->Init.RundownFn(
                            &Context->PrivateContext,
                            &Message->Message
                            );

            //
            // Drop the connection's own reference; the deref below then
            // destroys the context.
            //
            InterlockedDecrement( &Context->RefCount );
            break;

        default:
            //
            // An unexpected message came through.  Normal LPC servers
            // don't handle the other types of messages.  Drop it.
            //
            break;
    }

    RtlpLpcDerefContext( Context, Message );
}
VOID
RtlpLpcServerCallback(
PVOID Parameter,
BOOLEAN TimedOut
)
{
PLPCSVR_SERVER Server ;
NTSTATUS Status ;
PLPCSVR_MESSAGE Message ;
PLPCSVR_CONTEXT Context ;
PLARGE_INTEGER RealTimeout ;
LPCSVR_FILTER_RESULT FilterResult ;
Server = (PLPCSVR_SERVER) Parameter ;
if ( Server->WaitHandle )
{
Server->WaitHandle = NULL ;
}
while ( 1 )
{
DbgPrint("Entering LPC server\n" );
RtlpLpcLockServer( Server );
if ( Server->Flags & LPCSVR_SHUTDOWN_PENDING )
{
break;
}
if ( Server->MessagePool )
{
Message = Server->MessagePool ;
Server->MessagePool = Message->Header.Next ;
}
else
{
Message = RtlAllocateHeap( RtlProcessHeap(),
0,
Server->MessageSize );
}
RtlpLpcUnlockServer( Server );
if ( !Message )
{
LARGE_INTEGER SleepInterval ;
SleepInterval.QuadPart = 125 * 10000 ;
NtDelayExecution( FALSE, &SleepInterval );
continue;
}
if ( Server->Timeout.QuadPart )
{
RealTimeout = &Server->Timeout ;
}
else
{
RealTimeout = NULL ;
}
Status = NtReplyWaitReceivePortEx(
Server->Port,
&Context,
NULL,
&Message->Message,
RealTimeout );
DbgPrint("Server: NtReplyWaitReceivePort completed with %x\n", Status );
if ( NT_SUCCESS( Status ) )
{
//
// If we timed out, nobody was waiting for us:
//
if ( Status == STATUS_TIMEOUT )
{
//
// Set up a general wait that will call back to this function
// when ready.
//
RtlpLpcLockServer( Server );
if ( ( Server->Flags & LPCSVR_SHUTDOWN_PENDING ) == 0 )
{
Status = RtlRegisterWait( &Server->WaitHandle,
Server->Port,
RtlpLpcServerCallback,
Server,
0xFFFFFFFF,
WT_EXECUTEONLYONCE );
}
RtlpLpcUnlockServer( Server );
break;
}
if ( Status == STATUS_SUCCESS )
{
if ( Context )
{
InterlockedIncrement( &Context->RefCount );
}
else
{
//
// New connection. Create a new context record
//
Context = RtlAllocateHeap( RtlProcessHeap(),
0,
sizeof( LPCSVR_CONTEXT ) +
Server->Init.ContextSize );
if ( !Context )
{
HANDLE Bogus ;
Status = NtAcceptConnectPort(
&Bogus,
NULL,
&Message->Message,
FALSE,
NULL,
NULL );
RtlpLpcLockServer( Server );
Message->Header.Next = Server->MessagePool ;
Server->MessagePool = Message ;
RtlpLpcUnlockServer( Server );
continue;
}
Context->Server = Server ;
Context->RefCount = 1 ;
Context->CommPort = NULL ;
RtlpLpcLockServer( Server );
InsertTailList( &Server->ContextList, &Context->List );
Server->ContextCount++ ;
RtlpLpcUnlockServer( Server );
}
Message->Header.Context = Context ;
FilterResult = LpcFilterAsync ;
if ( Server->Init.FilterFn )
{
FilterResult = Server->Init.FilterFn( Context, &Message->Message );
if (FilterResult == LpcFilterDrop )
{
RtlpLpcDerefContext( Context, Message );
continue;
}
}
if ( (Server->Flags & LPCSVR_SYNCHRONOUS) ||
(FilterResult == LpcFilterSync) )
{
RtlpLpcWorkerThread( Message );
}
else
{
RtlQueueWorkItem( RtlpLpcWorkerThread,
Message,
0 );
}
}
}
else
{
//
// Error? Better shut down...
//
break;
}
}
}
//
// Create an LPC server: allocate the server block, create the waitable port
// named by PortName, and register a one-shot wait that starts the receive
// loop (RtlpLpcServerCallback) when the port is signalled.
//
// IdleTimeout may be NULL, meaning no idle timeout.  MessageSize is the
// maximum PORT_MESSAGE size; pooled buffers carry an LPCSVR_MESSAGE header in
// addition.  On success *LpcServer receives the server handle.  On failure
// *LpcServer is NULL and all partially created state is released.
//
NTSTATUS
RtlCreateLpcServer(
    POBJECT_ATTRIBUTES PortName,
    PLPCSVR_INITIALIZE Init,
    PLARGE_INTEGER IdleTimeout,
    ULONG MessageSize,
    ULONG Options,
    PVOID * LpcServer
    )
{
    PLPCSVR_SERVER Server ;
    NTSTATUS Status ;

    *LpcServer = NULL ;

    Server = RtlAllocateHeap( RtlProcessHeap(),
                              0,
                              sizeof( LPCSVR_SERVER ) );

    if ( !Server ) {
        return STATUS_INSUFFICIENT_RESOURCES;
    }

    Status = RtlInitializeCriticalSectionAndSpinCount( &Server->Lock,
                                                       1000 );

    if ( !NT_SUCCESS( Status ) ) {
        RtlFreeHeap( RtlProcessHeap(), 0, Server );
        return Status;
    }

    InitializeListHead( &Server->ContextList );
    Server->ContextCount = 0;

    Server->Init = *Init;

    if ( !IdleTimeout ) {
        Server->Timeout.QuadPart = 0;
    } else {
        Server->Timeout = *IdleTimeout;
    }

    Server->MessageSize = MessageSize + sizeof( LPCSVR_MESSAGE ) -
                            sizeof( PORT_MESSAGE );

    Server->MessagePool = NULL;
    Server->MessagePoolSize = 0;
    Server->MessagePoolLimit = 4;

    Server->Flags = Options;

    //
    // The heap block is not zeroed.  Initialize the fields that the receive
    // loop and RtlShutdownLpcServer read before anything else sets them.
    //
    Server->Port = NULL;
    Server->WaitHandle = NULL;
    Server->ReceiveThreads = 0;
    Server->ShutdownEvent = NULL;

    //
    // Create the LPC port:
    //
    Status = NtCreateWaitablePort(
                    &Server->Port,
                    PortName,
                    MessageSize,
                    MessageSize,
                    MessageSize * 4
                    );

    if ( !NT_SUCCESS( Status ) )
    {
        RtlDeleteCriticalSection( &Server->Lock );
        RtlFreeHeap( RtlProcessHeap(), 0, Server );
        return Status;
    }

    //
    // Now, post the handle over to a wait queue
    //
    Status = RtlRegisterWait(
                    &Server->WaitHandle,
                    Server->Port,
                    RtlpLpcServerCallback,
                    Server,
                    0xFFFFFFFF,
                    WT_EXECUTEONLYONCE
                    );

    if ( !NT_SUCCESS( Status ) ) {
        NtClose( Server->Port );
        RtlDeleteCriticalSection( &Server->Lock );
        RtlFreeHeap( RtlProcessHeap(), 0, Server );
        return Status;
    }

    *LpcServer = Server;

    return Status;
}
//
// Shut down an LPC server created by RtlCreateLpcServer.
//
// Returns STATUS_PENDING if a shutdown is already in progress.  Returns
// STATUS_NOT_IMPLEMENTED, without changing any state, for servers created
// without an idle timeout.  Otherwise it runs down every client context
// (RundownFn receives a NULL message to mark a server-initiated shutdown),
// frees the pooled messages, and returns STATUS_SUCCESS.
//
NTSTATUS
RtlShutdownLpcServer(
    PVOID LpcServer
    )
{
    PLPCSVR_SERVER Server ;
    OBJECT_ATTRIBUTES ObjA ;
    PLIST_ENTRY Scan ;
    PLPCSVR_CONTEXT Context ;
    PLPCSVR_MESSAGE Message ;
    NTSTATUS Status ;

    Server = (PLPCSVR_SERVER) LpcServer ;

    RtlpLpcLockServer( Server );

    if ( Server->Flags & LPCSVR_SHUTDOWN_PENDING )
    {
        RtlpLpcUnlockServer( Server );
        return STATUS_PENDING ;
    }

    //
    // Without an idle timeout the receive loop waits with no timeout, and
    // there is no bounded way to synchronize with it.  Refuse before touching
    // any state, so the server keeps running.
    //
    if ( Server->Timeout.QuadPart == 0 )
    {
        RtlpLpcUnlockServer( Server );
        return STATUS_NOT_IMPLEMENTED ;
    }

    if ( Server->WaitHandle )
    {
        RtlDeregisterWait( Server->WaitHandle );
        Server->WaitHandle = NULL ;
    }

    //
    // If there are receives still pending, we have to sync
    // with those threads.  To do so, we will tag the shutdown
    // flag, and then wait the timeout amount.
    //
    if ( Server->ReceiveThreads != 0 )
    {
        InitializeObjectAttributes( &ObjA,
                                    NULL,
                                    0,
                                    0,
                                    0 );

        Status = NtCreateEvent( &Server->ShutdownEvent,
                                EVENT_ALL_ACCESS,
                                &ObjA,
                                NotificationEvent,
                                FALSE );

        if ( !NT_SUCCESS( Status ) )
        {
            RtlpLpcUnlockServer( Server );
            return Status ;
        }

        Server->Flags |= LPCSVR_SHUTDOWN_PENDING ;

        RtlpLpcUnlockServer( Server );

        Status = NtWaitForSingleObject(
                        Server->ShutdownEvent,
                        FALSE,
                        &Server->Timeout );

        if ( Status == STATUS_TIMEOUT )
        {
            //
            // Hmm, the LPC server thread is hung somewhere,
            // press on
            //
        }

        RtlpLpcLockServer( Server );

        NtClose( Server->ShutdownEvent );
        Server->ShutdownEvent = NULL ;
    }
    else
    {
        Server->Flags |= LPCSVR_SHUTDOWN_PENDING ;
    }

    //
    // The server object is locked, and there are no receives pending (or
    // they appear hung).  Unlink each context and call the server's rundown
    // with a NULL message.  This is the same &PrivateContext pointer that the
    // other callbacks receive.  A NULL Flink marks the context as unlinked
    // for RtlpLpcDerefContext.
    //
    while ( ! IsListEmpty( &Server->ContextList ) )
    {
        Scan = RemoveHeadList( &Server->ContextList );
        Context = CONTAINING_RECORD( Scan, LPCSVR_CONTEXT, List );

        Status = Server->Init.RundownFn(
                        &Context->PrivateContext,
                        NULL );

        Context->List.Flink = NULL ;
        Server->ContextCount-- ;

        RtlpLpcDerefContext( Context, NULL );
    }

    //
    // All contexts have been deleted: clean up the pooled messages.
    //
    while ( Server->MessagePool )
    {
        Message = Server->MessagePool ;
        Server->MessagePool = Message->Header.Next ;

        RtlFreeHeap( RtlProcessHeap(),
                     0,
                     Message );
    }

    Server->MessagePoolSize = 0 ;

    RtlpLpcUnlockServer( Server );

    //
    // NOTE(review): the port handle, the lock and the server block itself are
    // not released here.  Worker threads may still reference them.
    //
    return STATUS_SUCCESS ;
}
//
// Impersonate the client that sent Message.  Context is the PrivateContext
// pointer that was handed to the server callback; it is mapped back to its
// LPCSVR_CONTEXT to find the client's communication port.  Returns the
// status of NtImpersonateClientOfPort.
//
NTSTATUS
RtlImpersonateLpcClient(
    PVOID Context,
    PPORT_MESSAGE Message
    )
{
    return NtImpersonateClientOfPort(
                RtlpLpcContextFromClient( Context )->CommPort,
                Message );
}
//
// Issue a callback request to the client while servicing Request, and wait
// for the client's reply in Callback.
//
// Context is the PrivateContext pointer given to the server callback.  When
// Callback is a different buffer from Request, the request's ClientId and
// MessageId are copied into it before sending.  Returns the status of
// NtRequestWaitReplyPort.
//
NTSTATUS
RtlCallbackLpcClient(
    PVOID Context,
    PPORT_MESSAGE Request,
    PPORT_MESSAGE Callback
    )
{
    PLPCSVR_CONTEXT ServerContext = RtlpLpcContextFromClient( Context );

    if ( Callback != Request )
    {
        Callback->MessageId = Request->MessageId ;
        Callback->ClientId = Request->ClientId ;
    }

    return NtRequestWaitReplyPort( ServerContext->CommPort,
                                   Callback,
                                   Callback );
}