/* * Copyright (c) 2014 Eran Pe'er. * * This program is made available under the terms of the MIT License. * * Created on Mar 10, 2014 */ #pragma once #include "mockutils/VTUtils.hpp" namespace fakeit { typedef unsigned long DWORD; struct TypeDescriptor { TypeDescriptor() : ptrToVTable(0), spare(0) { // ptrToVTable should contain the pointer to the virtual table of the type type_info!!! int **tiVFTPtr = (int **) (&typeid(void)); int *i = (int *) tiVFTPtr[0]; int type_info_vft_ptr = (int)i; ptrToVTable = type_info_vft_ptr; } DWORD ptrToVTable; DWORD spare; char name[8]; }; struct PMD { /************************************************************************/ /* member displacement. /* For a simple inheritance structure the member displacement is always 0. /* since since the first member is placed at 0. /* In the case of multiple inheritance, this value may have a positive value. /************************************************************************/ int mdisp; int pdisp; // vtable displacement int vdisp; //displacement inside vtable PMD() : mdisp(0), pdisp(-1), vdisp(0) { } }; struct RTTIBaseClassDescriptor { RTTIBaseClassDescriptor() : pTypeDescriptor(nullptr), numContainedBases(0), attributes(0) { } const std::type_info *pTypeDescriptor; //type descriptor of the class DWORD numContainedBases; //number of nested classes following in the Base Class Array struct PMD where; //pointer-to-member displacement info DWORD attributes; //flags, usually 0 }; template struct RTTIClassHierarchyDescriptor { RTTIClassHierarchyDescriptor() : signature(0), attributes(0), numBaseClasses(0), pBaseClassArray(nullptr) { pBaseClassArray = new RTTIBaseClassDescriptor *[1 + sizeof...(baseclasses)]; addBaseClass < C, baseclasses...>(); } ~RTTIClassHierarchyDescriptor() { for (int i = 0; i < 1 + sizeof...(baseclasses); i++) { RTTIBaseClassDescriptor *desc = pBaseClassArray[i]; delete desc; } delete[] pBaseClassArray; } DWORD signature; //always zero? DWORD attributes; //bit 0 set = multiple inheritance, bit 1 set = virtual inheritance DWORD numBaseClasses; //number of classes in pBaseClassArray RTTIBaseClassDescriptor **pBaseClassArray; template void addBaseClass() { static_assert(std::is_base_of::value, "C must be a derived class of BaseType"); RTTIBaseClassDescriptor *desc = new RTTIBaseClassDescriptor(); desc->pTypeDescriptor = &typeid(BaseType); pBaseClassArray[numBaseClasses] = desc; for (unsigned int i = 0; i < numBaseClasses; i++) { pBaseClassArray[i]->numContainedBases++; } numBaseClasses++; } template void addBaseClass() { static_assert(std::is_base_of::value, "invalid inheritance list"); addBaseClass(); addBaseClass(); } }; template struct RTTICompleteObjectLocator { RTTICompleteObjectLocator(const std::type_info &info) : signature(0), offset(0), cdOffset(0), pTypeDescriptor(&info), pClassDescriptor(new RTTIClassHierarchyDescriptor()) { } ~RTTICompleteObjectLocator() { delete pClassDescriptor; } DWORD signature; //always zero ? DWORD offset; //offset of this vtable in the complete class DWORD cdOffset; //constructor displacement offset const std::type_info *pTypeDescriptor; //TypeDescriptor of the complete class struct RTTIClassHierarchyDescriptor *pClassDescriptor; //describes inheritance hierarchy }; struct VirtualTableBase { static VirtualTableBase &getVTable(void *instance) { fakeit::VirtualTableBase *vt = (fakeit::VirtualTableBase *) (instance); return *vt; } VirtualTableBase(void **firstMethod) : _firstMethod(firstMethod) { } void *getCookie(int index) { return _firstMethod[-2 - index]; } void setCookie(int index, void *value) { _firstMethod[-2 - index] = value; } void *getMethod(unsigned int index) const { return _firstMethod[index]; } void setMethod(unsigned int index, void *method) { _firstMethod[index] = method; } protected: void **_firstMethod; }; template struct VirtualTable : public VirtualTableBase { class Handle { friend struct VirtualTable; void **firstMethod; Handle(void **firstMethod) : firstMethod(firstMethod) { } public: VirtualTable &restore() { VirtualTable *vt = (VirtualTable *) this; return *vt; } }; static VirtualTable &getVTable(C &instance) { fakeit::VirtualTable *vt = (fakeit::VirtualTable *) (&instance); return *vt; } void copyFrom(VirtualTable &from) { unsigned int size = VTUtils::getVTSize(); for (unsigned int i = 0; i < size; i++) { _firstMethod[i] = from.getMethod(i); } } VirtualTable() : VirtualTable(buildVTArray()) { } ~VirtualTable() { } void dispose() { _firstMethod--; // skip objectLocator RTTICompleteObjectLocator *locator = (RTTICompleteObjectLocator *) _firstMethod[0]; delete locator; _firstMethod -= numOfCookies; // skip cookies delete[] _firstMethod; } // the dtor VC++ must of the format: int dtor(int) unsigned int dtor(int) { C *c = (C *) this; C &cRef = *c; auto vt = VirtualTable::getVTable(cRef); void *dtorPtr = vt.getCookie(numOfCookies - 1); // read the last cookie void(*method)(C *) = reinterpret_cast(dtorPtr); method(c); return 0; } void setDtor(void *method) { // the dtor VC++ must of the format: int dtor(int). // the method passed by the user is: void dtor(). // store the user method in a cookie and put the // correct format method in the virtual table. // the method stored in the vt will call the method in the cookie when invoked. void *dtorPtr = union_cast(&VirtualTable::dtor); unsigned int index = VTUtils::getDestructorOffset(); _firstMethod[index] = dtorPtr; setCookie(numOfCookies - 1, method); // use the last cookie } unsigned int getSize() { return VTUtils::getVTSize(); } void initAll(void *value) { auto size = getSize(); for (unsigned int i = 0; i < size; i++) { setMethod(i, value); } } Handle createHandle() { Handle h(_firstMethod); return h; } private: class SimpleType { }; static_assert(sizeof(unsigned int (SimpleType::*)()) == sizeof(unsigned int (C::*)()), "Can't mock a type with multiple inheritance"); static const unsigned int numOfCookies = 3; static void **buildVTArray() { int vtSize = VTUtils::getVTSize(); auto array = new void *[vtSize + numOfCookies + 1]{}; RTTICompleteObjectLocator *objectLocator = new RTTICompleteObjectLocator( typeid(C)); array += numOfCookies; // skip cookies array[0] = objectLocator; // initialize RTTICompleteObjectLocator pointer array++; // skip object locator return array; } VirtualTable(void **firstMethod) : VirtualTableBase(firstMethod) { } }; }