//+---------------------------------------------------------------------------- // File: array.cxx // // Synopsis: // //----------------------------------------------------------------------------- // Includes ------------------------------------------------------------------- #include #include //+---------------------------------------------------------------------------- // // Member: AddClass // // Synopsis: // // Arguments: // // Returns: // //----------------------------------------------------------------------------- HRESULT CLicenseManager::AddClass( REFCLSID rclsid, int * piLic) { IClassFactory2 * pcf2 = NULL; LICINFO licinfo; BSTR bstrLic; int iLic; HRESULT hr; Assert(piLic); Assert(!FindClass(rclsid, &iLic)); // Get the class factory for the CLSID hr = ::CoGetClassObject(rclsid, CLSCTX_INPROC_SERVER | CLSCTX_INPROC_HANDLER | CLSCTX_LOCAL_SERVER, NULL, IID_IClassFactory2, (void **)&pcf2); if (hr) goto Cleanup; Assert(pcf2); // Determine if the object supports the creation of runtime licenses licinfo.cbLicInfo = sizeof(LICINFO); hr = pcf2->GetLicInfo(&licinfo); if (hr) goto Cleanup; if (!licinfo.fRuntimeKeyAvail || !licinfo.fLicVerified) { hr = CLASS_E_NOTLICENSED; goto Cleanup; } // Obtain the object's runtime license hr = pcf2->RequestLicKey(0, &bstrLic); if (hr) goto Cleanup; Assert(bstrLic); // Add the object and its runtime license to the array of CLSID-License pairs // (The class is added in ascending order based upon the first DWORD of the CLSID) hr = _aryLic.SetSize(_aryLic.Size()+1); if (hr) goto Cleanup; for (iLic = 0; iLic < (_aryLic.Size()-1); iLic++) { if (rclsid.Data1 < _aryLic[iLic].clsid.Data1) break; } if (iLic < (_aryLic.Size()-1)) { ::memmove(&_aryLic[iLic+1], &_aryLic[iLic], sizeof(_aryLic[0])*(_aryLic.Size()-iLic-1)); } _aryLic[iLic].clsid = rclsid; _aryLic[iLic].bstrLic = bstrLic; _aryLic[iLic].pcf2 = pcf2; pcf2 = NULL; *piLic = iLic; _fDirty = TRUE; Cleanup: ::SRelease(pcf2); return hr; } //+---------------------------------------------------------------------------- // // Member: FindClass // // Synopsis: // // Arguments: // // Returns: // //----------------------------------------------------------------------------- BOOL CLicenseManager::FindClass( REFCLSID rclsid, int * piLic) { int iLic; Assert(piLic); // BUGBUG: Consider using a more efficient search if the number of classes is large for (iLic=0; iLic < _aryLic.Size(); iLic++) { if (_aryLic[iLic].clsid.Data1 == rclsid.Data1 && _aryLic[iLic].clsid == rclsid) break; } if (iLic < _aryLic.Size()) { *piLic = iLic; } return (iLic < _aryLic.Size()); } //+---------------------------------------------------------------------------- // // Member: OnChangeInRequiredClasses // // Synopsis: // // Arguments: // // Returns: // //----------------------------------------------------------------------------- STDMETHODIMP CLicenseManager::OnChangeInRequiredClasses( IRequireClasses * pRequireClasses) { ULONG cClasses; ULONG iClass; int cLic; int iLic; CLSID clsid; BOOL fClassUsed; BOOL fClassNotLicensed = FALSE; HRESULT hr; if (!pRequireClasses) return E_INVALIDARG; // Determine the number of required classes hr = pRequireClasses->CountRequiredClasses(&cClasses); if (hr) goto Cleanup; // Add new classes to the array of required classes // NOTE: During this pass, all required classes are also marked as "in use" // Because of this, the second loop must also alway run, even when errors occur, // to remove these marks; that is, this loop cannot "goto Cleanup" for (iClass = 0; iClass < cClasses; iClass++) { // Get the CLSID of the required class hr = pRequireClasses->GetRequiredClasses(iClass, &clsid); if (hr) break; // Check if the class is already known; if not, add it // (Ignore "false" errors which occur during adding the class and treat it as unlicensed) fClassUsed = TRUE; // Assume the class will be used if (!FindClass(clsid, &iLic)) { hr = AddClass(clsid, &iLic); if (hr) { if (hr == E_OUTOFMEMORY) break; fClassUsed = FALSE; // Class was not found nor added fClassNotLicensed = TRUE; hr = S_OK; } } // Mark the class as "in use" by setting the high-order bit of the factory address if (fClassUsed) { Assert((ULONG)(_aryLic[iLic].pcf2) < (ULONG_PTR)ADDRESS_TAG_BIT); _aryLic[iLic].pcf2 = (IClassFactory2 *)((ULONG_PTR)(_aryLic[iLic].pcf2) | ADDRESS_TAG_BIT); } } // Remove from the array classes no longer required // NOTE: If hr is not S_OK, then this loop should still execute, but only to clear // the mark bits on the IClassFactory2 interface pointers, no other changes // should occur // Also, early exits from this loop (using "break" for example) must not occur for (cLic = iLic = 0; iLic < _aryLic.Size(); iLic++) { // If the class is "in use", clear the mark bit if ((ULONG_PTR)(_aryLic[iLic].pcf2) & ADDRESS_TAG_BIT) { _aryLic[iLic].pcf2 = (IClassFactory2 *)((ULONG_PTR)(_aryLic[iLic].pcf2) & (ADDRESS_TAG_BIT-1)); // If classes have been removed, shift this class down to the first open slot if (!hr && iLic > cLic) { _aryLic[cLic] = _aryLic[iLic]; ::memset(&(_aryLic[iLic]), 0, sizeof(_aryLic[iLic])); } } // Otherwise, free the class and remove it from the array else if (!hr) { ::SysFreeString(_aryLic[iLic].bstrLic); ::SRelease(_aryLic[iLic].pcf2); ::memset(&(_aryLic[iLic]), 0, sizeof(_aryLic[iLic])); _fDirty = TRUE; } // As long as it points at a valid class, increment the class counter if (_aryLic[cLic].clsid != CLSID_NULL) { cLic++; } } Implies(hr, cLic == _aryLic.Size()); Implies(!hr, (ULONG)cLic <= cClasses); Verify(SUCCEEDED(_aryLic.SetSize(cLic))); Cleanup: // If a real error occurred, return it // Otherwise return CLASS_E_NOTLICENSED if any un-licensed objects were encountered return (hr ? hr : (fClassNotLicensed ? CLASS_E_NOTLICENSED : S_OK)); } //+---------------------------------------------------------------------------- // // Member: CreateInstance // // Synopsis: // // Arguments: // // Returns: // //----------------------------------------------------------------------------- STDMETHODIMP CLicenseManager::CreateInstance( CLSID clsid, IUnknown * pUnkOuter, REFIID riid, DWORD dwClsCtx, void ** ppvObj) { int iLic; HRESULT hr; // If there is a runtime license for the class, create it using IClassFactory2 if (FindClass(clsid, &iLic)) { if (!_aryLic[iLic].pcf2) { // // The following code calls CoGetClassObject for an IClassFactory // then QIs for an IClassFactory2. This is because of an apparent // bug in ole32.dll. On a win95 system if the call to // CoGetClassObject is remoted and you ask for IClassFactory2 the // process hangs. // IClassFactory *pIClassFactory; hr = ::CoGetClassObject(clsid, dwClsCtx, NULL, IID_IClassFactory, (void **)&(pIClassFactory)); if (SUCCEEDED(hr)) { hr = pIClassFactory->QueryInterface(IID_IClassFactory2, (void **)&(_aryLic[iLic].pcf2)); pIClassFactory->Release(); } if (hr) goto Cleanup; } Assert(_aryLic[iLic].pcf2); Assert(_aryLic[iLic].bstrLic != NULL); hr = _aryLic[iLic].pcf2->CreateInstanceLic(pUnkOuter, NULL, riid, _aryLic[iLic].bstrLic, ppvObj); } // Otherwise, use the standard COM mechanisms else { hr = ::CoCreateInstance(clsid, pUnkOuter, dwClsCtx, riid, ppvObj); } Cleanup: return hr; } //+---------------------------------------------------------------------------- // // Member: GetTypeLibOfClsid // // Synopsis: // // Arguments: // // Returns: // //----------------------------------------------------------------------------- STDMETHODIMP CLicenseManager::GetTypeLibOfClsid( CLSID clsid, ITypeLib ** ptlib) { UNREF(clsid); UNREF(ptlib); return E_NOTIMPL; } //+---------------------------------------------------------------------------- // // Member: GetClassObjectOfClsid // // Synopsis: // // Arguments: // // Returns: // //----------------------------------------------------------------------------- STDMETHODIMP CLicenseManager::GetClassObjectOfClsid( REFCLSID rclsid, DWORD dwClsCtx, LPVOID lpReserved, REFIID riid, void ** ppcClassObject) { // Load the class object return ::CoGetClassObject(rclsid, dwClsCtx, lpReserved, riid, ppcClassObject); } //+---------------------------------------------------------------------------- // // Member: CountRequiredClasses // // Synopsis: // // Arguments: // // Returns: // //----------------------------------------------------------------------------- STDMETHODIMP CLicenseManager::CountRequiredClasses( ULONG * pcClasses) { if (!pcClasses) return E_INVALIDARG; // Return the current number of classes *pcClasses = _aryLic.Size(); return S_OK; } //+---------------------------------------------------------------------------- // // Member: GetRequiredClasses // // Synopsis: // // Arguments: // // Returns: // //----------------------------------------------------------------------------- STDMETHODIMP CLicenseManager::GetRequiredClasses( ULONG iClass, CLSID * pclsid) { if (!pclsid || iClass >= (ULONG)_aryLic.Size()) return E_INVALIDARG; // Return the requested CLSID *pclsid = _aryLic[iClass].clsid; return S_OK; }