/*
 * Source path: windows-nt/Source/XPSP1/NT/base/cluster/resdll/ndquorum/srv.c
 * (720 lines, 15 KiB, C; repository snapshot dated 2020-09-26 03:20:57 -05:00)
 */
/*++
Copyright (c) 2000 Microsoft Corporation
Module Name:
srv.c
Abstract:
Implements initialization and socket interface for smb server
Author:
Ahmed Mohamed (ahmedm) 1-Feb-2000
Revision History:
--*/
#include "srv.h"
#include <process.h> // for _beginthreadex
#include <mswsock.h>
// Socket type handed to socket() when ListenSocket creates the NetBIOS listener.
#define PROTOCOL_TYPE SOCK_SEQPACKET
// NOTE(review): not referenced in the portion of this file shown here;
// presumably consumed by a header or another module -- confirm before removing.
#define PLUS_CLUSTER 1
// Return type + calling convention of the thread entry points passed to _beginthreadex.
#define THREADAPI unsigned int WINAPI
// Forward declaration: PacketFree calls SrvCloseEndpoint before it is defined.
void
SrvCloseEndpoint(EndPoint_t *endpoint);
/*
 * PacketReset
 *
 * Rebuilds ctx->freelist from scratch out of the two preallocated pools.
 * Each of the MAX_PACKETS descriptors in ctx->packet_pool is given two
 * consecutive SRV_PACKET_SIZE slices of ctx->buffer_pool (first the input
 * buffer, then the output buffer), has its overlapped event cleared, and is
 * pushed onto the head of the free list.
 *
 * No locking is done here.
 */
void
PacketReset(SrvCtx_t *ctx)
{
    Packet_t *pkt = (Packet_t *) ctx->packet_pool;
    char *cursor = (char *) ctx->buffer_pool;
    int idx;

    ctx->freelist = NULL;

    for (idx = 0; idx < MAX_PACKETS; idx++, pkt++) {
        // carve the receive buffer, then the reply buffer right behind it
        pkt->buffer = (LPVOID) cursor;
        cursor += SRV_PACKET_SIZE;
        pkt->outbuf = (LPVOID) cursor;
        cursor += SRV_PACKET_SIZE;

        pkt->ov.hEvent = NULL;

        // push on the head of the free list
        pkt->next = ctx->freelist;
        ctx->freelist = pkt;
    }
}
/*
 * PacketInit
 *
 * Allocates the packet descriptor pool (MAX_PACKETS entries) and the buffer
 * pool (two SRV_PACKET_SIZE buffers per packet), then threads everything
 * onto the free list via PacketReset.
 *
 * Returns TRUE on success. On failure returns FALSE with neither pool
 * allocated; ctx->packet_pool is reset to NULL after being released so that
 * SrvExit, which frees any non-NULL pool pointer, cannot free it twice.
 */
BOOL
PacketInit(SrvCtx_t *ctx)
{
    int npackets, nbufs;

    // Allocate 2 buffers for each packet
    npackets = MAX_PACKETS;
    nbufs = npackets * 2;

    ctx->packet_pool = xmalloc(sizeof(Packet_t) * npackets);
    if (ctx->packet_pool == NULL) {
        SrvLogError(("Unable to allocate packet pool!\n"));
        return FALSE;
    }

    ctx->buffer_pool = xmalloc(SRV_PACKET_SIZE * nbufs);
    if (ctx->buffer_pool == NULL) {
        xfree(ctx->packet_pool);
        // do not leave a dangling pointer behind for later teardown
        ctx->packet_pool = NULL;
        SrvLogError(("Unable to allocate buffer pool!\n"));
        return FALSE;
    }

    PacketReset(ctx);
    return TRUE;
}
/*
 * PacketAlloc
 *
 * Takes a packet off the server-wide free list and links it at the head of
 * ep->PacketList, all under ctx->cs. When the free list is empty the caller
 * registers itself as a waiter, drops the lock, blocks on ctx->event and
 * then tries again.
 *
 * Returns the packet, or NULL if the server is no longer running or the
 * wait on ctx->event did not return WAIT_OBJECT_0.
 */
Packet_t *
PacketAlloc(EndPoint_t *ep)
{
    SrvCtx_t *srv;
    Packet_t *pkt;

    ASSERT(ep);
    srv = ep->SrvCtx;

    for (;;) {
        EnterCriticalSection(&srv->cs);

        if (srv->running == FALSE) {
            LeaveCriticalSection(&srv->cs);
            return NULL;
        }

        pkt = srv->freelist;
        if (pkt != NULL) {
            // got one; leave the loop with the lock still held
            break;
        }

        // nothing available: note that someone is waiting and sleep
        srv->waiters++;
        LeaveCriticalSection(&srv->cs);
        if (WaitForSingleObject(srv->event, INFINITE) != WAIT_OBJECT_0) {
            return NULL;
        }
    }

    // unlink from the free list
    srv->freelist = pkt->next;

    // Insert into per endpoint packet list
    pkt->endpoint = ep;
    pkt->next = ep->PacketList;
    ep->PacketList = pkt;

    LeaveCriticalSection(&srv->cs);
    return pkt;
}
/*
 * PacketRelease
 *
 * Pushes p back on the head of ctx->freelist and, if any allocator is
 * recorded as waiting, consumes one waiter count and signals ctx->event.
 *
 * Performs no locking of its own; PacketFree invokes it with ctx->cs held.
 */
void
PacketRelease(SrvCtx_t *ctx, Packet_t *p)
{
    p->next = ctx->freelist;
    ctx->freelist = p;

    if (ctx->waiters <= 0) {
        return;
    }

    ctx->waiters--;
    SetEvent(ctx->event);
}
/*
 * PacketFree
 *
 * Returns a packet to the server free list. Under ctx->cs the packet is
 * unlinked from its endpoint's PacketList (if present) and released through
 * PacketRelease. If that leaves the endpoint with no packets at all, the
 * endpoint itself is torn down with SrvCloseEndpoint.
 */
void
PacketFree(Packet_t *p)
{
    EndPoint_t *owner = p->endpoint;
    SrvCtx_t *srv;
    Packet_t **link;

    ASSERT(owner);
    srv = owner->SrvCtx;
    ASSERT(srv);

    EnterCriticalSection(&srv->cs);

    // Walk the endpoint list by link pointer so the head needs no special case
    for (link = &owner->PacketList; *link != NULL; link = &(*link)->next) {
        if (*link == p) {
            *link = p->next;
            break;
        }
    }

    // back on the free list; wakes one waiter if there is any
    PacketRelease(srv, p);

    if (owner->PacketList == NULL) {
        // last packet gone: free this endpoint
        SrvCloseEndpoint(owner);
    }

    LeaveCriticalSection(&srv->cs);
}
/*
 * ProcessPacket
 *
 * Handles one received buffer on endpoint ep. Buffers that IsSmb rejects
 * are ignored. Otherwise the packet's in/out SMB views are set up, the
 * request is dispatched through SrvDispatch, and:
 *   - if the dispatcher reports ERROR_IO_PENDING, that value is returned and
 *     nothing is sent here;
 *   - if the dispatcher succeeds, the reply in p->outbuf is sent; a failed or
 *     short send closes the socket;
 *   - if the dispatcher fails, the socket is closed.
 *
 * Returns ERROR_IO_PENDING in the first case, ERROR_SUCCESS in all others.
 */
int
ProcessPacket(EndPoint_t *ep, Packet_t *p)
{
    BOOL handled;
    char *reply;
    int replyLen;
    int sent;

    if (!IsSmb(p->buffer, p->len)) {
        return ERROR_SUCCESS;
    }

    // request view
    p->in.smb = (PNT_SMB_HEADER)p->buffer;
    p->in.size = p->len;
    p->in.offset = sizeof(NT_SMB_HEADER);
    p->in.command = p->in.smb->Command;

    // reply view
    p->out.smb = (PNT_SMB_HEADER)p->outbuf;
    p->out.size = SRV_PACKET_SIZE;
    p->out.valid = sizeof(NT_SMB_HEADER);

    InitSmbHeader(p);
    DumpSmb(p->buffer, p->len, TRUE);
    SrvLog(("dispatching Tid:%d Uid:%d Mid:%d Flags:%x Cmd:%d...\n",
            p->in.smb->Tid, p->in.smb->Uid, p->in.smb->Mid,
            p->in.smb->Flags2, p->in.command));

    p->tag = 0;
    handled = SrvDispatch(p);
    if (handled == ERROR_IO_PENDING) {
        return ERROR_IO_PENDING;
    }

    if (!handled) {
        SrvLog(("dispatch failed!\n"));
        // did not understand...hangup on virtual circuit...
        SrvLog(("hangup! -- disp failed on sock %s\n", ep->ClientId));
        closesocket(ep->Sock);
        return ERROR_SUCCESS;
    }

    // handled ok: ship the reply
    reply = (char *)p->out.smb;
    replyLen = (int) p->out.valid;
    DumpSmb(reply, replyLen, FALSE);
    SrvLog(("sending...len %d\n", replyLen));

    sent = send(ep->Sock, reply, replyLen, 0);
    if (sent == SOCKET_ERROR || sent != replyLen) {
        SrvLog(("Send clnt failed %d\n", WSAGetLastError()));
        closesocket(ep->Sock);
    }

    return ERROR_SUCCESS;
}
/*
 * CompletionThread
 *
 * Worker loop servicing the server's I/O completion port (ctx->comport).
 * arg is the SrvCtx_t. Each iteration dequeues one completion, treats the
 * OVERLAPPED pointer as the owning Packet_t (assumes `ov` is the first
 * member of Packet_t -- confirm in srv.h) and the completion key as the
 * EndPoint_t, runs ProcessPacket and, unless the request went asynchronous
 * (ERROR_IO_PENDING), posts the next read on the same packet.
 *
 * The thread exits when ctx->running is cleared, or when a completion
 * arrives with no OVERLAPPED pointer: either the kill packet posted by
 * SrvOffline (call succeeded) or a failed dequeue (call failed).
 *
 * The original code tested `!b && !lpo` only after already returning for
 * `p == NULL`, so that branch was unreachable, and had it run it would have
 * passed NULL to PacketFree. The two cases are now told apart inside the
 * NULL-overlapped path and no packet is touched.
 *
 * Always returns 0.
 */
THREADAPI
CompletionThread(LPVOID arg)
{
    Packet_t* p;
    DWORD len;
    ULONG_PTR id;
    LPOVERLAPPED lpo;
    SrvCtx_t *ctx = (SrvCtx_t *) arg;
    HANDLE port = ctx->comport;
    EndPoint_t *endpoint;
    HANDLE ev;

    // Each thread needs its own event, msg to use
    ev = CreateEvent(NULL, FALSE, FALSE, NULL);

    while (ctx->running) {
        BOOL b;

        b = GetQueuedCompletionStatus(port, &len, &id, &lpo, INFINITE);

        if (lpo == NULL) {
            // No packet was dequeued: nothing to free, just leave.
            if (!b) {
                SrvLog(("Getqueued failed %d\n", GetLastError()));
            } else {
                SrvLog(("SrvThread exiting, %x...\n", id));
            }
            CloseHandle(ev);
            return 0;
        }

        // A completion with a packet attached. A failed I/O (b == FALSE) is
        // deliberately not special-cased here: it flows through the normal
        // path exactly as before.
        p = (Packet_t *) lpo;

        // todo: when socket is closed, I need to free this endpoint.
        // I need to tag the endpoint with how many packets got scheduled
        // on it, when the refcnt reaches zero, I free it.
        endpoint = (EndPoint_t *) id;
        ASSERT(p->endpoint == endpoint);

        p->ev = ev;
        p->len = len;

        if (ProcessPacket(endpoint, p) != ERROR_IO_PENDING) {
            // schedule next read
            b = ReadFile((HANDLE)endpoint->Sock,
                         p->buffer,
                         SRV_PACKET_SIZE,
                         &len,
                         &p->ov);
            if (!b && GetLastError () != ERROR_IO_PENDING) {
                SrvLog(("SrvThread read ep 0x%x failed %d\n", endpoint, GetLastError()));
                // Return packet to queue
                PacketFree(p);
            }
        }
    }

    CloseHandle(ev);
    SrvLog(("SrvThread exiting, not running...\n"));
    return 0;
}
/*
 * SrvFinalize
 *
 * Completes a request whose dispatch was deferred (p->tag must be
 * ERROR_IO_PENDING on entry). Clears the tag, sends the reply held in the
 * packet's output buffer, and posts the next read on the packet. A send
 * failure is only logged; if the read cannot be posted the packet is handed
 * back through PacketFree.
 */
void
SrvFinalize(Packet_t *p)
{
    EndPoint_t *ep = p->endpoint;
    char *reply;
    DWORD nbytes, status;

    ASSERT(p->tag == ERROR_IO_PENDING);
    p->tag = 0;

    reply = (char *)p->out.smb;
    nbytes = (DWORD) p->out.valid;
    DumpSmb(reply, nbytes, FALSE);
    SrvLog(("sending...len %d\n", nbytes));

    status = send(ep->Sock, reply, nbytes, 0);
    if (status == SOCKET_ERROR || status != nbytes) {
        SrvLog(("Finalize Send clnt failed <%d>\n", WSAGetLastError()));
    }

    // re-arm the packet for the next request
    status = ReadFile((HANDLE)ep->Sock,
                      p->buffer,
                      SRV_PACKET_SIZE,
                      &nbytes,
                      &p->ov);
    if (status || GetLastError () == ERROR_IO_PENDING) {
        return;
    }

    // Return packet to queue
    PacketFree(p);
}
/*
 * SrvCloseEndpoint
 *
 * Destroys an endpoint: every packet still on its PacketList goes back to
 * the free list, the endpoint is unlinked from the server's EndPointList,
 * its socket is closed, the filesystem layer is told the user is gone via
 * FsLogoffUser, and the endpoint memory is freed.
 *
 * The server lock (ctx->cs) must be held by the caller.
 */
void
SrvCloseEndpoint(EndPoint_t *endpoint)
{
    SrvCtx_t *srv = endpoint->SrvCtx;
    EndPoint_t **link;
    Packet_t *pkt;

    // return all packets to the free list now
    while ((pkt = endpoint->PacketList) != NULL) {
        endpoint->PacketList = pkt->next;
        PacketRelease(srv, pkt);
    }

    // remove from ctx list
    for (link = &srv->EndPointList; *link != NULL; link = &(*link)->Next) {
        if (*link == endpoint) {
            *link = endpoint->Next;
            break;
        }
    }

    closesocket(endpoint->Sock);

    // We need to inform filesystem that this
    // tree is gone.
    FsLogoffUser(srv->FsCtx, endpoint->LogonId);

    free(endpoint);
}
/*
 * ListenSocket
 *
 * Creates the NetBIOS listening socket for ctx->nb_local_name on the given
 * nic (passed to socket() as the negated protocol argument), binds it to
 * the unique NetBIOS name and starts listening with a backlog of 5.
 *
 * On success stores the socket in ctx->listen_socket and returns
 * ERROR_SUCCESS. On failure returns the Winsock error code; any socket
 * created along the way is closed.
 */
DWORD
ListenSocket(SrvCtx_t *ctx, int nic)
{
    unsigned char *srvname = ctx->nb_local_name;
    struct sockaddr_nb local;
    SOCKET s;
    DWORD err;

    SET_NETBIOS_SOCKADDR(&local, NETBIOS_UNIQUE_NAME, srvname, ' ');

    s = socket(AF_NETBIOS, PROTOCOL_TYPE, -nic);
    if (s == INVALID_SOCKET) {
        err = WSAGetLastError();
        SrvLogError(("socket() '%s' nic %d failed with error %d\n",
                     srvname, nic, err));
        return err;
    }

    // bind socket
    if (bind(s, (struct sockaddr*)&local, sizeof(local)) == SOCKET_ERROR) {
        err = WSAGetLastError();
        SrvLogError(("srv nic %d bind() failed with error %d\n",nic, err));
        goto fail;
    }

    // issue listen
    if (listen(s, 5) == SOCKET_ERROR) {
        err = WSAGetLastError();
        SrvLogError(("listen() failed with error %d\n", err));
        goto fail;
    }

    // all is well.
    ctx->listen_socket = s;
    return ERROR_SUCCESS;

fail:
    closesocket(s);
    return err;
}
/*
 * ListenThread
 *
 * Accept loop for the server's listen socket; arg is the SrvCtx_t.
 * For every accepted connection the caller's NetBIOS name is trimmed at the
 * first blank and compared (case-insensitively) with the local host name;
 * any other caller is disconnected. An accepted caller gets an EndPoint_t
 * which is published on ctx->EndPointList, bound to the completion port
 * with the endpoint as completion key, and primed with up to
 * SRV_NUM_WORKERS outstanding reads.
 *
 * The loop ends when ctx->running is cleared or accept() fails.
 * Always returns 0.
 *
 * Fixes relative to the original:
 *   - the endpoint is fully initialised (Sock, SrvCtx, ClientId) BEFORE it
 *     is linked into ctx->EndPointList, so SrvOffline walking the list
 *     never sees a half-built entry;
 *   - SrvCloseEndpoint is called with ctx->cs held, as its contract
 *     requires;
 *   - a failing gethostname() no longer leaves localname uninitialised.
 */
THREADAPI
ListenThread(LPVOID arg)
{
    SOCKET listen_socket, msgsock;
    struct sockaddr_nb from;
    int fromlen;
    HANDLE comport;
    SrvCtx_t *ctx = (SrvCtx_t *) arg;
    EndPoint_t *endpoint;
    char localname[64];

    if (gethostname(localname, sizeof(localname)) == SOCKET_ERROR) {
        SrvLogError(("gethostname() failed with error %d\n", WSAGetLastError()));
        // With an empty local name every named caller is fenced off below.
        localname[0] = '\0';
    }

    listen_socket = ctx->listen_socket;
    comport = ctx->comport;

    while (ctx->running) {
        int i;

        fromlen = sizeof(from);
        msgsock = accept(listen_socket, (struct sockaddr*)&from, &fromlen);
        if (msgsock == INVALID_SOCKET) {
            if (ctx->running)
                SrvLogError(("accept() error %d\n",WSAGetLastError()));
            break;
        }

        // terminate the caller's name and cut it at the first pad blank
        from.snb_name[NETBIOS_NAME_LENGTH-1] = '\0';
        {
            char *s = strchr(from.snb_name, ' ');
            if (s != NULL) *s = '\0';
        }
        SrvLog(("Received call from '%s'\n", from.snb_name));

        // Fence off all nodes except cluster nodes. We ask
        // our resource to check for us. For now we fence off all nodes but this node
        if (_stricmp(localname, from.snb_name)) {
            // sorry, we just close the connection now
            closesocket(msgsock);
            continue;
        }

        // allocate a new endpoint
        endpoint = (EndPoint_t *) malloc(sizeof(*endpoint));
        if (endpoint == NULL) {
            SrvLogError(("Failed allocate failed %d\n", GetLastError()));
            closesocket(msgsock);
            continue;
        }
        memset(endpoint, 0, sizeof(*endpoint));

        // fill the endpoint in completely before anyone else can see it
        endpoint->Sock = msgsock;
        endpoint->SrvCtx = ctx;
        // NOTE(review): copies sizeof(ClientId) bytes out of from.snb_name;
        // assumes ClientId is no larger than snb_name -- confirm in srv.h.
        memcpy(endpoint->ClientId, from.snb_name, sizeof(endpoint->ClientId));

        // add endpoint now
        EnterCriticalSection(&ctx->cs);
        endpoint->Next = ctx->EndPointList;
        ctx->EndPointList = endpoint;
        LeaveCriticalSection(&ctx->cs);

        comport = CreateIoCompletionPort((HANDLE)msgsock, comport,
                                         (ULONG_PTR)endpoint, 8);
        if (!comport) {
            SrvLogError(("CompletionPort bind Failed %d\n", GetLastError()));
            // SrvCloseEndpoint edits the shared endpoint list: take the lock
            EnterCriticalSection(&ctx->cs);
            SrvCloseEndpoint(endpoint);
            LeaveCriticalSection(&ctx->cs);
            comport = ctx->comport;
            continue;
        }

        // prime the connection with outstanding reads
        for (i = 0; i < SRV_NUM_WORKERS; i++) {
            Packet_t *p;
            BOOL b;
            DWORD nbytes;

            p = PacketAlloc(endpoint);
            if (p == NULL) {
                SrvLog(("Listen thread got null packet, exiting posted...\n"));
                break;
            }

            b = ReadFile((HANDLE) msgsock,
                         p->buffer,
                         SRV_PACKET_SIZE,
                         &nbytes,
                         &p->ov);
            if (!b && GetLastError () != ERROR_IO_PENDING) {
                SrvLog(("Srv ReadFile Failed %d\n",
                        GetLastError()));
                // Return packet to queue. PacketFree may free the endpoint
                // if this was its only packet, so stop using it here.
                PacketFree(p);
                break;
            }
        }
    }

    return (0);
}
/*
 * SrvInit
 *
 * Allocates and initialises a server context: records the resource and
 * filesystem handles, initialises LSA, starts Winsock 2.2, creates the
 * context lock and packet-wait event, builds the packet pools and runs
 * SrvUtilInit.
 *
 *   resHdl - opaque resource handle, stored in the context
 *   fsHdl  - opaque filesystem handle, stored in the context
 *   Hdl    - receives the new context on success; untouched on failure
 *
 * Returns ERROR_SUCCESS, ERROR_NOT_ENOUGH_MEMORY, ERROR_NO_SYSTEM_RESOURCES
 * (packet pools could not be allocated), or the error reported by LsaInit,
 * WSAStartup or CreateEvent.
 *
 * Fixes relative to the original:
 *   - WSAStartup reports failure through its return value (0 on success),
 *     not SOCKET_ERROR / WSAGetLastError, so failures were never detected;
 *   - the CreateEvent result is checked;
 *   - the later failure paths release the critical section, the event,
 *     the Winsock reference and the context instead of leaking them.
 *
 * NOTE(review): the LSA state obtained by LsaInit is not released on the
 * failure paths because no matching teardown routine is visible from here.
 */
DWORD
SrvInit(PVOID resHdl, PVOID fsHdl, PVOID *Hdl)
{
    SrvCtx_t *ctx;
    DWORD err;

    ctx = (SrvCtx_t *) malloc(sizeof(*ctx));
    if (ctx == NULL) {
        return ERROR_NOT_ENOUGH_MEMORY;
    }
    memset(ctx, 0, sizeof(*ctx));

    ctx->FsCtx = fsHdl;
    ctx->resHdl = resHdl;

    // init lsa now
    err = LsaInit(&ctx->LsaHandle, &ctx->LsaPack);
    if (err != ERROR_SUCCESS) {
        SrvLogError(("LsaInit failed with error %x\n", err));
        free(ctx);
        return err;
    }

    // init winsock now; the return value IS the error code
    err = (DWORD) WSAStartup(0x202, &ctx->wsaData);
    if (err != 0) {
        SrvLogError(("WSAStartup failed with error %d\n", err));
        free(ctx);
        return err;
    }

    InitializeCriticalSection(&ctx->cs);
    ctx->running = FALSE;
    ctx->waiters = 0;

    // auto-reset event used by PacketAlloc to wait for a free packet
    ctx->event = CreateEvent(NULL, FALSE, FALSE, NULL);
    if (ctx->event == NULL) {
        err = GetLastError();
        SrvLogError(("CreateEvent failed with error %d\n", err));
        goto fail;
    }

    if (PacketInit(ctx) != TRUE) {
        err = ERROR_NO_SYSTEM_RESOURCES;
        CloseHandle(ctx->event);
        goto fail;
    }

    SrvUtilInit(ctx);

    *Hdl = (PVOID) ctx;
    return ERROR_SUCCESS;

fail:
    DeleteCriticalSection(&ctx->cs);
    WSACleanup();
    free(ctx);
    return err;
}
/*
 * SrvOnline
 *
 * Brings the server online: fixes the NetBIOS server name, creates the I/O
 * completion port and the listen socket, and starts one ListenThread plus
 * SRV_NUM_SENDERS CompletionThreads.
 *
 *   Hdl  - server context returned by SrvInit
 *   name - server name to register; if NULL the local host name with
 *          SRV_NAME_EXTENSION appended is used
 *   nic  - 1-based adapter number (0 is left as 0); stored 0-based
 *
 * Returns ERROR_SUCCESS (also when already running),
 * ERROR_INVALID_PARAMETER for a NULL handle or an unconvertible name,
 * ERROR_NOT_ENOUGH_MEMORY, ERROR_NO_SYSTEM_RESOURCES if a thread could not
 * be created, or the error from gethostname / CreateIoCompletionPort /
 * ListenSocket.
 *
 * Fixes relative to the original:
 *   - CreateIoCompletionPort signals failure with NULL, not
 *     INVALID_HANDLE_VALUE, so port creation failures went unnoticed;
 *   - wcstombs can return (size_t)-1, which was used as an array index;
 *   - gethostname is checked and the extension is appended only if it fits;
 *   - the stored name is always NUL terminated (only the first
 *     NETBIOS_NAME_LENGTH-1 characters are kept);
 *   - toupper is given an unsigned char value;
 *   - the completion port, listen socket and thread array are released on
 *     later failures, and thread creation failures are detected.
 *
 * NOTE(review): as in the original, failures from the completion-port stage
 * onward call WSACleanup although WSAStartup was issued by SrvInit; kept
 * for compatibility -- confirm whether that pairing is intended.
 */
DWORD
SrvOnline(PVOID Hdl, LPWSTR name, DWORD nic)
{
    SrvCtx_t *ctx = (SrvCtx_t *) Hdl;
    DWORD err;
    int i;
    int created;
    int nFixedThreads = 1;
    char localname[128];
    SYSTEM_INFO sysinfo;

    if (ctx == NULL) {
        return ERROR_INVALID_PARAMETER;
    }

    if (nic > 0)
        nic--;

    if (ctx->running == TRUE)
        return ERROR_SUCCESS;

    // save name to use
    if (name != NULL) {
        // we need to translate name to ascii
        size_t n = wcstombs(localname, name, NETBIOS_NAME_LENGTH-1);
        if (n == (size_t) -1) {
            SrvLogError(("Unable to convert server name to ascii\n"));
            return ERROR_INVALID_PARAMETER;
        }
        localname[n] = '\0';
    } else {
        // use local name and append our -crs extension
        if (gethostname(localname, sizeof(localname)) == SOCKET_ERROR) {
            err = WSAGetLastError();
            SrvLogError(("gethostname() failed with error %d\n", err));
            return err;
        }
        localname[sizeof(localname) - 1] = '\0';
        if (strlen(localname) + strlen(SRV_NAME_EXTENSION) < sizeof(localname)) {
            strcat(localname, SRV_NAME_EXTENSION);
        }
    }
    strncpy(ctx->nb_local_name, localname, NETBIOS_NAME_LENGTH);
    // strncpy does not terminate when the source fills the whole field
    ctx->nb_local_name[NETBIOS_NAME_LENGTH - 1] = '\0';

    for (i = 0; i < NETBIOS_NAME_LENGTH; i++) {
        ctx->nb_local_name[i] =
            (char) toupper((unsigned char) ctx->nb_local_name[i]);
    }

    // create completion port
    GetSystemInfo(&sysinfo);
    ctx->comport = CreateIoCompletionPort(INVALID_HANDLE_VALUE, 0, 0,
                                          sysinfo.dwNumberOfProcessors*8);
    if (ctx->comport == NULL) {
        err = GetLastError();
        SrvLogError(("Unable to create completion port %d\n", err));
        goto fail;
    }

    // create listen socket
    ctx->nic = nic;
    err = ListenSocket(ctx, nic);
    if (err != ERROR_SUCCESS) {
        goto fail_port;
    }

    // start up 1 listener/receiver, a few workers, a few senders....
    ctx->nThreads = nFixedThreads + SRV_NUM_SENDERS;
    ctx->hThreads = (HANDLE *) malloc(sizeof(HANDLE) * ctx->nThreads);
    if (ctx->hThreads == NULL) {
        err = ERROR_NOT_ENOUGH_MEMORY;
        goto fail_socket;
    }

    // Start up threads in suspended mode
    for (i = 0; i < ctx->nThreads; i++) {
        if (i < nFixedThreads) {
            ctx->hThreads[i] = (HANDLE)
                _beginthreadex(NULL, 0, &ListenThread, (LPVOID)ctx, CREATE_SUSPENDED, NULL);
        } else {
            ctx->hThreads[i] = (HANDLE)
                _beginthreadex(NULL, 0, &CompletionThread, (LPVOID)ctx, CREATE_SUSPENDED, NULL);
        }
        if (ctx->hThreads[i] == NULL) {
            break;
        }
    }

    if (i < ctx->nThreads) {
        created = i;
        SrvLogError(("Unable to create server thread %d\n", created));
        // ctx->running is still FALSE, so each thread started so far falls
        // straight out of its loop once resumed; let them finish and reap.
        for (i = 0; i < created; i++) {
            ResumeThread(ctx->hThreads[i]);
            WaitForSingleObject(ctx->hThreads[i], INFINITE);
            CloseHandle(ctx->hThreads[i]);
        }
        err = ERROR_NO_SYSTEM_RESOURCES;
        goto fail_threads;
    }

    ctx->running = TRUE;
    for (i = 0; i < ctx->nThreads; i++)
        ResumeThread(ctx->hThreads[i]);

    return ERROR_SUCCESS;

fail_threads:
    free(ctx->hThreads);
    ctx->hThreads = NULL;
fail_socket:
    closesocket(ctx->listen_socket);
    ctx->listen_socket = INVALID_SOCKET;
fail_port:
    CloseHandle(ctx->comport);
    ctx->comport = NULL;
fail:
    WSACleanup();
    return err;
}
/*
 * SrvOffline
 *
 * Takes a running server offline. In order: clears ctx->running, closes the
 * listen socket and every endpoint socket, posts one NULL completion per
 * thread as a kill packet, waits for all threads if every post succeeded,
 * closes the thread handles and the completion port, frees the handle
 * array, and finally destroys all remaining endpoints.
 *
 * Returns ERROR_INVALID_PARAMETER for a NULL handle, otherwise
 * ERROR_SUCCESS (including when the server was not running).
 */
DWORD
SrvOffline(PVOID Hdl)
{
    SrvCtx_t *ctx = (SrvCtx_t *) Hdl;
    EndPoint_t *ep;
    int n;

    if (ctx == NULL) {
        return ERROR_INVALID_PARAMETER;
    }

    if (!ctx->running) {
        return ERROR_SUCCESS;
    }

    // we shutdown all threads in the completion port
    // we close all currently open sockets
    // we free all memory
    ctx->running = FALSE;
    closesocket(ctx->listen_socket);

    EnterCriticalSection(&ctx->cs);
    ep = ctx->EndPointList;
    while (ep != NULL) {
        closesocket(ep->Sock);
        ep = ep->Next;
    }
    LeaveCriticalSection(&ctx->cs);

    SrvLog(("waiting for threads to die off...\n"));

    // send a kill packet to all threads on the completion port
    for (n = 0; n < ctx->nThreads; n++) {
        if (!PostQueuedCompletionStatus(ctx->comport, 0, 0, NULL)) {
            SrvLog(("Port queued port failed %d\n", GetLastError()));
            break;
        }
    }

    // only wait if every kill packet actually got queued
    if (n == ctx->nThreads) {
        // wait for them to die of natural causes before we kill them...
        WaitForMultipleObjects(ctx->nThreads, ctx->hThreads, TRUE, INFINITE);
    }

    // close handles
    for (n = 0; n < ctx->nThreads; n++) {
        CloseHandle(ctx->hThreads[n]);
    }
    CloseHandle(ctx->comport);
    free((char *)ctx->hThreads);

    // free endpoints; each call unlinks the current list head
    EnterCriticalSection(&ctx->cs);
    while ((ep = ctx->EndPointList) != NULL) {
        SrvCloseEndpoint(ep);
    }
    LeaveCriticalSection(&ctx->cs);

    return ERROR_SUCCESS;
}
/*
 * SrvExit
 *
 * Final teardown of a server context: runs SrvUtilExit, releases the packet
 * and buffer pools if they were allocated, and frees the context itself.
 * A NULL handle is ignored.
 */
void
SrvExit(PVOID Hdl)
{
    SrvCtx_t *ctx = (SrvCtx_t *) Hdl;

    if (ctx == NULL) {
        return;
    }

    SrvUtilExit(ctx);

    // must do this last!
    if (ctx->packet_pool != NULL) {
        xfree(ctx->packet_pool);
    }
    if (ctx->buffer_pool != NULL) {
        xfree(ctx->buffer_pool);
    }
    free(ctx);
}