windows-nt/Source/XPSP1/NT/shell/iecontrols/licmgr/mgr/array.cxx
2020-09-26 16:20:57 +08:00

415 lines
11 KiB
C++

//+----------------------------------------------------------------------------
// File: array.cxx
//
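// Synopsis:  CLicenseManager methods that maintain the array of CLSID /
//            runtime-license pairs (_aryLic) and create objects using
//            those licenses.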
//
//-----------------------------------------------------------------------------
// Includes -------------------------------------------------------------------
#include <mgr.hxx>
#include <factory.hxx>
//+----------------------------------------------------------------------------
//
//  Member:     AddClass
//
//  Synopsis:   Obtains a runtime license for rclsid through its IClassFactory2
//              and inserts the CLSID, the license, and the class factory into
//              _aryLic. Entries are kept in ascending order of the first
//              DWORD (Data1) of the CLSID.
//
//  Arguments:  rclsid - CLSID of the class to add (must not already be present)
//              piLic  - Receives the index at which the class was inserted
//
//  Returns:    S_OK on success (and marks the manager dirty)
//              CLASS_E_NOTLICENSED if the class reports no runtime key or an
//                  unverified license
//              Any failure from CoGetClassObject, GetLicInfo, RequestLicKey,
//                  or growing the array
//
//-----------------------------------------------------------------------------
HRESULT
CLicenseManager::AddClass(
    REFCLSID    rclsid,
    int *       piLic)
{
    IClassFactory2 *    pcf2 = NULL;
    LICINFO             licinfo;
    BSTR                bstrLic = NULL;     // Owned here until handed to the array
    int                 iLic;
    HRESULT             hr;

    Assert(piLic);
    Assert(!FindClass(rclsid, &iLic));

    // Get the class factory for the CLSID
    hr = ::CoGetClassObject(rclsid,
                            CLSCTX_INPROC_SERVER | CLSCTX_INPROC_HANDLER | CLSCTX_LOCAL_SERVER,
                            NULL,
                            IID_IClassFactory2, (void **)&pcf2);
    if (hr)
        goto Cleanup;
    Assert(pcf2);

    // Determine if the object supports the creation of runtime licenses
    licinfo.cbLicInfo = sizeof(LICINFO);
    hr = pcf2->GetLicInfo(&licinfo);
    if (hr)
        goto Cleanup;

    if (!licinfo.fRuntimeKeyAvail ||
        !licinfo.fLicVerified)
    {
        hr = CLASS_E_NOTLICENSED;
        goto Cleanup;
    }

    // Obtain the object's runtime license
    hr = pcf2->RequestLicKey(0, &bstrLic);
    if (hr)
        goto Cleanup;
    Assert(bstrLic);

    // Add the object and its runtime license to the array of CLSID-License pairs
    // (The class is added in ascending order based upon the first DWORD of the CLSID)
    hr = _aryLic.SetSize(_aryLic.Size()+1);
    if (hr)
        goto Cleanup;       // bstrLic is still owned here and is freed below

    // Find the first existing entry whose Data1 is greater than the new one
    for (iLic = 0; iLic < (_aryLic.Size()-1); iLic++)
    {
        if (rclsid.Data1 < _aryLic[iLic].clsid.Data1)
            break;
    }

    // Open a slot by shifting the trailing entries up by one
    if (iLic < (_aryLic.Size()-1))
    {
        ::memmove(&_aryLic[iLic+1], &_aryLic[iLic], sizeof(_aryLic[0])*(_aryLic.Size()-iLic-1));
    }

    _aryLic[iLic].clsid   = rclsid;
    _aryLic[iLic].bstrLic = bstrLic;
    _aryLic[iLic].pcf2    = pcf2;

    // The array now owns both the license and the factory reference
    bstrLic = NULL;
    pcf2    = NULL;

    *piLic = iLic;
    _fDirty = TRUE;

Cleanup:
    // Release whatever was not transferred to the array; both calls are
    // no-ops on NULL (SRelease is already called with NULL on success above)
    ::SysFreeString(bstrLic);
    ::SRelease(pcf2);
    return hr;
}
//+----------------------------------------------------------------------------
//
//  Member:     FindClass
//
//  Synopsis:   Searches _aryLic, front to back, for the entry whose CLSID
//              equals rclsid.
//
//  Arguments:  rclsid - CLSID to look for
//              piLic  - Receives the index of the matching entry; left
//                       unmodified when there is no match
//
//  Returns:    TRUE if the class was found, FALSE otherwise
//
//-----------------------------------------------------------------------------
BOOL
CLicenseManager::FindClass(
    REFCLSID    rclsid,
    int *       piLic)
{
    Assert(piLic);

    // BUGBUG: Consider using a more efficient search if the number of classes is large
    int cEntries = _aryLic.Size();

    for (int iEntry = 0; iEntry < cEntries; ++iEntry)
    {
        // Cheap pre-filter on the first DWORD before comparing the whole CLSID
        if (_aryLic[iEntry].clsid.Data1 != rclsid.Data1)
            continue;

        if (_aryLic[iEntry].clsid == rclsid)
        {
            *piLic = iEntry;
            return TRUE;
        }
    }

    return FALSE;
}
//+----------------------------------------------------------------------------
//
//  Member:     OnChangeInRequiredClasses
//
//  Synopsis:   Synchronizes _aryLic with the classes reported by
//              pRequireClasses. Required classes not yet in the array are
//              added through AddClass. Entries that are no longer required are
//              freed and removed, and the survivors are compacted toward the
//              front of the array.
//
//  Arguments:  pRequireClasses - Source of the current list of required CLSIDs
//
//  Returns:    S_OK if every required class is present in the array
//              CLASS_E_NOTLICENSED if one or more required classes could not
//                  be added (AddClass failed with something other than
//                  E_OUTOFMEMORY); the remaining classes are still processed
//              E_INVALIDARG if pRequireClasses is NULL
//              A failure from CountRequiredClasses/GetRequiredClasses, or
//                  E_OUTOFMEMORY from AddClass; in these cases no entries are
//                  removed from the array
//
//-----------------------------------------------------------------------------
STDMETHODIMP
CLicenseManager::OnChangeInRequiredClasses(
    IRequireClasses *   pRequireClasses)
{
    ULONG   cClasses;
    ULONG   iClass;
    int     cLic;
    int     iLic;
    CLSID   clsid;
    BOOL    fClassUsed;
    BOOL    fClassNotLicensed = FALSE;
    HRESULT hr;

    if (!pRequireClasses)
        return E_INVALIDARG;

    // Determine the number of required classes
    hr = pRequireClasses->CountRequiredClasses(&cClasses);
    if (hr)
        goto Cleanup;

    // Add new classes to the array of required classes
    // NOTE: During this pass, all required classes are also marked as "in use"
    //       Because of this, the second loop must also always run, even when errors occur,
    //       to remove these marks; that is, this loop cannot "goto Cleanup"
    for (iClass = 0; iClass < cClasses; iClass++)
    {
        // Get the CLSID of the required class
        hr = pRequireClasses->GetRequiredClasses(iClass, &clsid);
        if (hr)
            break;

        // Check if the class is already known; if not, add it
        // (Ignore "false" errors which occur during adding the class and treat it as unlicensed)
        fClassUsed = TRUE;              // Assume the class will be used
        if (!FindClass(clsid, &iLic))
        {
            hr = AddClass(clsid, &iLic);
            if (hr)
            {
                if (hr == E_OUTOFMEMORY)
                    break;

                fClassUsed = FALSE;     // Class was not found nor added
                fClassNotLicensed = TRUE;
                hr = S_OK;
            }
        }

        // Mark the class as "in use" by setting ADDRESS_TAG_BIT in the factory address.
        // The check compares the full pointer-sized value; casting the pointer
        // to ULONG would truncate it on 64-bit builds.
        if (fClassUsed)
        {
            Assert((ULONG_PTR)(_aryLic[iLic].pcf2) < (ULONG_PTR)ADDRESS_TAG_BIT);
            _aryLic[iLic].pcf2 = (IClassFactory2 *)((ULONG_PTR)(_aryLic[iLic].pcf2) | ADDRESS_TAG_BIT);
        }
    }

    // Remove from the array classes no longer required
    // NOTE: If hr is not S_OK, then this loop should still execute, but only to clear
    //       the mark bits on the IClassFactory2 interface pointers, no other changes
    //       should occur
    //       Also, early exits from this loop (using "break" for example) must not occur
    //       cLic counts the entries kept so far; slot cLic is the next open slot
    for (cLic = iLic = 0; iLic < _aryLic.Size(); iLic++)
    {
        // If the class is "in use", clear the mark bit
        if ((ULONG_PTR)(_aryLic[iLic].pcf2) & ADDRESS_TAG_BIT)
        {
            _aryLic[iLic].pcf2 = (IClassFactory2 *)((ULONG_PTR)(_aryLic[iLic].pcf2) & (ADDRESS_TAG_BIT-1));

            // If classes have been removed, shift this class down to the first open slot
            if (!hr && iLic > cLic)
            {
                _aryLic[cLic] = _aryLic[iLic];
                ::memset(&(_aryLic[iLic]), 0, sizeof(_aryLic[iLic]));
            }
        }

        // Otherwise, free the class and remove it from the array
        else if (!hr)
        {
            ::SysFreeString(_aryLic[iLic].bstrLic);
            ::SRelease(_aryLic[iLic].pcf2);
            ::memset(&(_aryLic[iLic]), 0, sizeof(_aryLic[iLic]));
            _fDirty = TRUE;
        }

        // As long as it points at a valid class, increment the class counter
        // (removed or vacated slots were zeroed, so their clsid is CLSID_NULL)
        if (_aryLic[cLic].clsid != CLSID_NULL)
        {
            cLic++;
        }
    }
    Implies(hr, cLic == _aryLic.Size());
    Implies(!hr, (ULONG)cLic <= cClasses);

    // Shrinking the array to the number of surviving entries
    Verify(SUCCEEDED(_aryLic.SetSize(cLic)));

Cleanup:
    // If a real error occurred, return it
    // Otherwise return CLASS_E_NOTLICENSED if any un-licensed objects were encountered
    return (hr
                ? hr
                : (fClassNotLicensed
                        ? CLASS_E_NOTLICENSED
                        : S_OK));
}
//+----------------------------------------------------------------------------
//
//  Member:     CreateInstance
//
//  Synopsis:   Creates an instance of clsid. If _aryLic holds an entry for the
//              class, the object is created through
//              IClassFactory2::CreateInstanceLic using the stored runtime
//              license. Otherwise CoCreateInstance is used.
//
//  Arguments:  clsid     - Class to instantiate
//              pUnkOuter - Controlling IUnknown passed through for aggregation
//              riid      - Interface requested on the new object
//              dwClsCtx  - CLSCTX flags passed through to COM
//              ppvObj    - Receives the requested interface pointer
//
//  Returns:    Result of CreateInstanceLic or CoCreateInstance, or the failure
//              from obtaining the licensed class's IClassFactory2
//
//-----------------------------------------------------------------------------
STDMETHODIMP
CLicenseManager::CreateInstance(
    CLSID       clsid,
    IUnknown *  pUnkOuter,
    REFIID      riid,
    DWORD       dwClsCtx,
    void **     ppvObj)
{
    int     iLic;
    HRESULT hr;

    // No runtime license is held for this class: use the standard COM path
    if (!FindClass(clsid, &iLic))
        return ::CoCreateInstance(clsid, pUnkOuter, dwClsCtx, riid, ppvObj);

    // The entry has no cached IClassFactory2 yet, so obtain one and cache it
    if (!_aryLic[iLic].pcf2)
    {
        // Ask COM for IClassFactory and QI for IClassFactory2 afterwards,
        // rather than requesting IClassFactory2 directly. This works around
        // an apparent ole32.dll bug: on Win95, a remoted CoGetClassObject
        // call that asks for IClassFactory2 hangs the process.
        IClassFactory * pcf = NULL;

        hr = ::CoGetClassObject(clsid, dwClsCtx, NULL,
                                IID_IClassFactory, (void **)&pcf);
        if (SUCCEEDED(hr))
        {
            hr = pcf->QueryInterface(IID_IClassFactory2,
                                     (void **)&(_aryLic[iLic].pcf2));
            pcf->Release();
        }

        if (hr != S_OK)
            return hr;
    }

    Assert(_aryLic[iLic].pcf2);
    Assert(_aryLic[iLic].bstrLic != NULL);

    return _aryLic[iLic].pcf2->CreateInstanceLic(pUnkOuter, NULL, riid,
                                                 _aryLic[iLic].bstrLic, ppvObj);
}
//+----------------------------------------------------------------------------
//
//  Member:     GetTypeLibOfClsid
//
//  Synopsis:   Not implemented.
//
//  Arguments:  clsid - Ignored
//              ptlib - If non-NULL, set to NULL
//
//  Returns:    E_NOTIMPL
//
//-----------------------------------------------------------------------------
STDMETHODIMP
CLicenseManager::GetTypeLibOfClsid(
    CLSID       clsid,
    ITypeLib ** ptlib)
{
    UNREF(clsid);

    // COM rules require [out] interface pointers to be set to NULL on a
    // failure return, so callers never see an uninitialized pointer
    if (ptlib)
        *ptlib = NULL;

    return E_NOTIMPL;
}
//+----------------------------------------------------------------------------
//
//  Member:     GetClassObjectOfClsid
//
//  Synopsis:   Thin pass-through to CoGetClassObject; no license handling
//              is performed here.
//
//  Arguments:  rclsid         - Class whose class object is requested
//              dwClsCtx       - CLSCTX flags passed through to COM
//              lpReserved     - Passed through unchanged
//              riid           - Interface requested on the class object
//              ppcClassObject - Receives the requested interface pointer
//
//  Returns:    Whatever CoGetClassObject returns
//
//-----------------------------------------------------------------------------
STDMETHODIMP
CLicenseManager::GetClassObjectOfClsid(
    REFCLSID    rclsid,
    DWORD       dwClsCtx,
    LPVOID      lpReserved,
    REFIID      riid,
    void **     ppcClassObject)
{
    HRESULT hr = ::CoGetClassObject(rclsid, dwClsCtx, lpReserved,
                                    riid, ppcClassObject);
    return hr;
}
//+----------------------------------------------------------------------------
//
//  Member:     CountRequiredClasses
//
//  Synopsis:   Reports how many CLSID/license entries _aryLic currently holds.
//
//  Arguments:  pcClasses - Receives the entry count
//
//  Returns:    S_OK, or E_INVALIDARG if pcClasses is NULL
//
//-----------------------------------------------------------------------------
STDMETHODIMP
CLicenseManager::CountRequiredClasses(
    ULONG * pcClasses)
{
    HRESULT hr = E_INVALIDARG;

    if (pcClasses != NULL)
    {
        *pcClasses = (ULONG)_aryLic.Size();
        hr = S_OK;
    }

    return hr;
}
//+----------------------------------------------------------------------------
//
//  Member:     GetRequiredClasses
//
//  Synopsis:   Returns the CLSID stored at position iClass in _aryLic.
//
//  Arguments:  iClass - Zero-based index into the array
//              pclsid - Receives the CLSID
//
//  Returns:    S_OK, or E_INVALIDARG if pclsid is NULL or iClass is out of range
//
//-----------------------------------------------------------------------------
STDMETHODIMP
CLicenseManager::GetRequiredClasses(
    ULONG   iClass,
    CLSID * pclsid)
{
    // Reject a missing output pointer or an index past the end of the array
    if (pclsid == NULL)
        return E_INVALIDARG;

    if (iClass >= (ULONG)_aryLic.Size())
        return E_INVALIDARG;

    *pclsid = _aryLic[iClass].clsid;
    return S_OK;
}