// Copyright Epic Games, Inc. All Rights Reserved.

#include "llvm/Transforms/Instrumentation/CustomMemoryInstrumentation.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CaptureTracking.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Demangle/Demangle.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/DiagnosticPrinter.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/InitializePasses.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/DynamicLibrary.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/ModRef.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Transforms/Instrumentation.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/EscapeEnumerator.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include <cctype>
#include <string>
#include <utility>

namespace llvm {

// Atomic base (generic shared implementation).
Regex MSVCStdAtomicLoadRegex(".*std::_Atomic_storage.*::load.*"); Regex MSVCStdAtomicImplicitLoadRegex( ".*std::atomic<.*>::operator .*\\(void\\).*"); Regex MSVCStdAtomicStoreRegex(".*std::_Atomic_storage.*::store.*"); Regex MSVCStdAtomicExchangeRegex(".*std::_Atomic_storage.*::exchange.*"); Regex MSVCStdAtomicCompareExchangeRegex( ".*std::_Atomic_storage.*::compare_exchange_.*"); // Atomic integrals. Regex MSVCStdAtomicFetchAddRegex(".*std::_Atomic_integral.*::fetch_add.*"); Regex MSVCStdAtomicFetchSubRegex(".*std::_Atomic_integral.*::fetch_sub.*"); Regex MSVCStdAtomicFetchAndRegex(".*std::_Atomic_integral.*::fetch_and.*"); Regex MSVCStdAtomicFetchOrRegex(".*std::_Atomic_integral.*::fetch_or.*"); Regex MSVCStdAtomicFetchXorRegex(".*std::_Atomic_integral.*::fetch_xor.*"); // Atomic pointers. Regex MSVCStdAtomicPointerFetchAddRegex(".*std::_Atomic_pointer.*::fetch_add.*"); Regex MSVCStdAtomicPointerFetchSubRegex(".*std::_Atomic_pointer.*::fetch_sub.*"); SmallVector> MSVCAtomicCallSites = { {&MSVCStdAtomicLoadRegex, AtomicCallSite::LoadSite(0, 1)}, {&MSVCStdAtomicImplicitLoadRegex, AtomicCallSite::LoadSite(0)}, {&MSVCStdAtomicStoreRegex, AtomicCallSite::StoreSite(0, 1, 1, 2)}, {&MSVCStdAtomicExchangeRegex, AtomicCallSite::ExchangeSite(0, 1, 1, 2)}, {&MSVCStdAtomicCompareExchangeRegex, AtomicCallSite::CompareExchangeSite(0, 2, 1, 2, 3, 4)}, {&MSVCStdAtomicFetchAddRegex, AtomicCallSite::RMWSite(AtomicRMWInst::BinOp::Add, 0, 1, 1, 2)}, {&MSVCStdAtomicFetchSubRegex, AtomicCallSite::RMWSite(AtomicRMWInst::BinOp::Sub, 0, 1, 1, 2)}, {&MSVCStdAtomicFetchAndRegex, AtomicCallSite::RMWSite(AtomicRMWInst::BinOp::And, 0, 1, 1, 2)}, {&MSVCStdAtomicFetchOrRegex, AtomicCallSite::RMWSite(AtomicRMWInst::BinOp::Or, 0, 1, 1, 2)}, {&MSVCStdAtomicFetchXorRegex, AtomicCallSite::RMWSite(AtomicRMWInst::BinOp::Xor, 0, 1, 1, 2)}, // Atomic pointers FetchAdd and FetchSub require pointer arithmetic. 
{&MSVCStdAtomicPointerFetchAddRegex, AtomicCallSite::RMWSite(AtomicRMWInst::BinOp::Add, 0, 1, 1, 2, true)}, {&MSVCStdAtomicPointerFetchSubRegex, AtomicCallSite::RMWSite(AtomicRMWInst::BinOp::Sub, 0, 1, 1, 2, true)}, }; uint32_t GetRealNumCallOperands(CallInst *Call) { return Call->getNumOperands() - 1; } bool IsRMWOpHandled(RMWBinOp Op) { switch (Op) { case RMWBinOp::Add: case RMWBinOp::Sub: case RMWBinOp::And: case RMWBinOp::Or: case RMWBinOp::Xor: return true; default: return false; } } struct MemoryAccessFlags { uint8_t bAtomic : 1; }; enum FAtomicMemoryOrder : int8_t { MEMORY_ORDER_RELAXED, MEMORY_ORDER_CONSUME, MEMORY_ORDER_ACQUIRE, MEMORY_ORDER_RELEASE, MEMORY_ORDER_ACQ_REL, MEMORY_ORDER_SEQ_CST }; FAtomicMemoryOrder MemoryOrderFromLLVMOrdering(const AtomicOrdering &Ordering) { switch (Ordering) { case AtomicOrdering::Acquire: return MEMORY_ORDER_ACQUIRE; case AtomicOrdering::Release: return MEMORY_ORDER_RELEASE; case AtomicOrdering::AcquireRelease: return MEMORY_ORDER_ACQ_REL; case AtomicOrdering::Unordered: case AtomicOrdering::Monotonic: return MEMORY_ORDER_RELAXED; case AtomicOrdering::SequentiallyConsistent: return MEMORY_ORDER_SEQ_CST; case AtomicOrdering::NotAtomic: default: assert(false); } llvm_unreachable("Should have a memory order."); } FAtomicMemoryOrder MemoryOrderFromInst(Instruction *Inst) { AtomicOrdering Ordering = AtomicOrdering::NotAtomic; if (StoreInst *Store = dyn_cast(Inst)) { Ordering = Store->getOrdering(); } else if (LoadInst *Load = dyn_cast(Inst)) { Ordering = Load->getOrdering(); } else if (AtomicRMWInst *RMW = dyn_cast(Inst)) { Ordering = RMW->getOrdering(); } return MemoryOrderFromLLVMOrdering(Ordering); } static_assert(sizeof(MemoryAccessFlags) == sizeof(uint8_t)); MemoryAccessFlags GetMemoryAccessFlags(Instruction *Inst) { MemoryAccessFlags Flags; Flags.bAtomic = false; if (StoreInst *Store = dyn_cast(Inst)) { if (Store->isAtomic()) { Flags.bAtomic = (getAtomicSyncScopeID(Inst) != SyncScope::SingleThread); } } else if 
(LoadInst *Store = dyn_cast(Inst)) { if (Store->isAtomic()) { Flags.bAtomic = (getAtomicSyncScopeID(Inst) != SyncScope::SingleThread); } } return Flags; } std::string RMWOpName(RMWBinOp Op) { std::string Str = AtomicRMWInst::getOperationName(Op).str(); Str[0] = std::toupper(Str[0]); return Str; } Value *CustomMemoryInstrumentationPass::CreateCast(IRBuilder<> &Builder, Value *Val, Type *DesiredType) { if (Val->getType() == DesiredType) { return Val; } if (Val->getType() == Builder.getInt1Ty() && DesiredType == Builder.getInt8Ty()) { return Builder.CreateIntCast(Val, DesiredType, false); } if (Val->getType() == Builder.getInt8Ty() && DesiredType == Builder.getInt1Ty()) { return Builder.CreateIntCast(Val, DesiredType, false); } uint32_t Size = CurrentModule->getDataLayout().getTypeStoreSize(Val->getType()); uint32_t DesiredSize = CurrentModule->getDataLayout().getTypeStoreSize(DesiredType); if (Size == DesiredSize) { return Builder.CreateBitOrPointerCast(Val, DesiredType); } errs() << "Cast not supported\n"; assert(false && "Cast not supported"); return nullptr; } Type *GetSretType(CallInst *Call) { Type *Typ = nullptr; Function *Func = Call->getCalledFunction(); if (Func->hasStructRetAttr()) { Typ = Func->getParamStructRetType(0); if (!Typ) { Typ = Func->getParamStructRetType(1); } } return Typ; } uint64_t GetPointeeSizeFromMSVCAtomicPointerFetchAddCall(Function &MSVCFetchAdd) { // Find the 'mul' instruction that contains the pointee size. for (auto &BasicBlock : MSVCFetchAdd) { for (auto &Instruction : BasicBlock) { if (BinaryOperator *BinOp = dyn_cast(&Instruction)) { if (BinOp->getOpcode() == Instruction::Mul) { Value *Op = BinOp->getOperand(1); if (ConstantInt *ConstOp = dyn_cast(Op)) { return ConstOp->getSExtValue(); } } } } } return 0; } uint64_t GetPointeeSizeFromMSVCAtomicPointerFetchSubCall(Function &MSVCFetchSub) { // MSVC's fetch_sub ends up calling fetch_add. Find the call to fetch_add. 
for (auto &BasicBlock : MSVCFetchSub) { for (auto &Instruction : BasicBlock) { if (CallInst *Call = dyn_cast(&Instruction)) { std::string FunctionName = demangle(Call->getCalledFunction()->getName().str()); if (MSVCStdAtomicPointerFetchAddRegex.match(FunctionName)) { return GetPointeeSizeFromMSVCAtomicPointerFetchAddCall( *Call->getCalledFunction()); } } } } return 0; } uint64_t CustomMemoryInstrumentationPass::CacheOrGetPointeeSizeForMSVCAtomicPointerRMW( CallInst *MSVCCall, RMWBinOp RMWOp) { Function *MSVCFunction = MSVCCall->getCalledFunction(); auto CachedPointeeSize = AtomicPointeeSizeCache.find(MSVCFunction); if (CachedPointeeSize != AtomicPointeeSizeCache.end()) { return CachedPointeeSize->second; } uint64_t PointeeSize = 0; if (RMWOp == RMWBinOp::Add) { PointeeSize = GetPointeeSizeFromMSVCAtomicPointerFetchAddCall(*MSVCFunction); } else { PointeeSize = GetPointeeSizeFromMSVCAtomicPointerFetchSubCall(*MSVCFunction); } AtomicPointeeSizeCache[MSVCFunction] = PointeeSize; return PointeeSize; } CustomMemoryInstrumentationPass::CustomMemoryInstrumentationPass( bool MSVCStandardLibPrepass) : MSVCStandardLibPrepass(MSVCStandardLibPrepass) { for (const auto &ExcludeRegex : Options.ExcludedFunctionNameRegexes.keys()) { CachedExcludedFunctionRegexes.push_back(Regex(ExcludeRegex)); } } CustomMemoryInstrumentationPass::CustomMemoryInstrumentationPass( const CustomMemoryInstrumentationOptions &Options, bool MSVCStandardLibPrepass) : Options(Options), MSVCStandardLibPrepass(MSVCStandardLibPrepass) {} PreservedAnalyses CustomMemoryInstrumentationPass::run(Module &M, ModuleAnalysisManager &AM) { bool Instrumented = false; if (shouldInstrumentModule(M)) { CurrentModule = &M; cacheInstrumentationFunctions(M); if (MSVCStandardLibPrepass) { Instrumented |= instrumentMSVCStandardLib(M); } else { Instrumented = instrumentModule(M); } if (Instrumented && verifyModule(M, &errs())) { errs() << "Broken module\n" //<< M ; exit(1); } CurrentModule = nullptr; } return Instrumented ? 
PreservedAnalyses::none() : PreservedAnalyses::all(); } bool CustomMemoryInstrumentationPass::shouldInstrumentModule(Module &M) { std::string Filename = M.getSourceFileName(); SmallString<256> CurrentModuleFilename = StringRef(Filename); sys::fs::make_absolute(CurrentModuleFilename); bool Included = Options.IncludedModulesRegexes.empty(); for (const auto &IncludeRegex : Options.IncludedModulesRegexes.keys()) { Regex Reg(IncludeRegex); if (Reg.match(CurrentModuleFilename)) { Included = true; break; } } if (!Included) { return false; } for (const auto &ExcludeRegex : Options.FurtherExcludedModulesRegexes.keys()) { Regex Reg(ExcludeRegex); if (Reg.match(CurrentModuleFilename)) { Included = false; break; } } return Included; } bool CustomMemoryInstrumentationPass::shouldInstrumentFunction(Function &F) { for (auto &InstrumentFunction : InstrumentFunctions) { if (F.getName() == InstrumentFunction->getCallee()->getName()) { return false; } } std::string DemangledName = demangle(F.getName().str()); for (const auto &Reg : CachedExcludedFunctionRegexes) { if (Reg.match(DemangledName)) { return false; } } if (F.hasFnAttribute(Attribute::Naked)) { return false; } if (F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation)) { return false; } // Apply the SanitizeThread attribute to any function we instrument // to prevent SimpleCFG to speculate some instructions and cause // race condition that wouldn't exist otherwise. // See llvm::mustSuppressSpeculation if (!F.hasFnAttribute(Attribute::SanitizeThread)) { F.addFnAttr(Attribute::SanitizeThread); } return true; } void CustomMemoryInstrumentationPass::cacheInstrumentationFunctions(Module &M) { IRBuilder Builder(M.getContext()); AttributeList Attrs; Attrs = Attrs.addFnAttribute(M.getContext(), Attribute::NoUnwind); Attrs = Attrs.addFnAttribute(M.getContext(), Attribute::NoInline); Attrs = Attrs.addFnAttribute(M.getContext(), Attribute::DisableSanitizerInstrumentation); // Function entry/exit. 
FuncEntryInstrumentFunction = M.getOrInsertFunction("__Instrument_FuncEntry", Attrs, Builder.getVoidTy(), Builder.getPtrTy()); FuncExitInstrumentFunction = M.getOrInsertFunction( "__Instrument_FuncExit", Attrs, Builder.getVoidTy()); InstrumentFunctions.push_back(&FuncEntryInstrumentFunction); InstrumentFunctions.push_back(&FuncExitInstrumentFunction); // Virtual Ptr Load and Store StoreVPtrInstrumentFunction = M.getOrInsertFunction("__Instrument_VPtr_Store", Attrs, Builder.getVoidTy(), Builder.getPtrTy(), Builder.getPtrTy()); LoadVPtrInstrumentFunction = M.getOrInsertFunction("__Instrument_VPtr_Load", Attrs, Builder.getVoidTy(), Builder.getPtrTy()); // Non-atomic loads/stores. StoreInstrumentFunction = M.getOrInsertFunction("__Instrument_Store", Attrs, Builder.getVoidTy(), Builder.getInt64Ty(), Builder.getInt32Ty()); LoadInstrumentFunction = M.getOrInsertFunction("__Instrument_Load", Attrs, Builder.getVoidTy(), Builder.getInt64Ty(), Builder.getInt32Ty()); StoreRangeInstrumentFunction = M.getOrInsertFunction( "__Instrument_StoreRange", Attrs, Builder.getVoidTy(), Builder.getInt64Ty(), Builder.getInt32Ty()); LoadRangeInstrumentFunction = M.getOrInsertFunction( "__Instrument_LoadRange", Attrs, Builder.getVoidTy(), Builder.getInt64Ty(), Builder.getInt32Ty()); InstrumentFunctions.push_back(&StoreInstrumentFunction); InstrumentFunctions.push_back(&LoadInstrumentFunction); InstrumentFunctions.push_back(&StoreRangeInstrumentFunction); InstrumentFunctions.push_back(&LoadRangeInstrumentFunction); // Atomic operations. for (size_t i = 1; i <= MAX_ATOMIC_SIZE; i *= 2) { SmallString<64> FuncName("__Instrument_AtomicStore_int" + utostr(i * 8)); AtomicStoreInstrumentFunctions[FunctionIndexFromSize(i)] = M.getOrInsertFunction( FuncName, Attrs, Builder.getVoidTy(), // Return void. Builder.getIntNTy(i * 8)->getPointerTo(), // Atomic pointer. Builder.getIntNTy(i * 8), // Value to store. Builder.getInt8Ty() // Memory order. 
); } for (size_t i = 1; i <= MAX_ATOMIC_SIZE; i *= 2) { SmallString<64> FuncName("__Instrument_AtomicLoad_int" + utostr(i * 8)); AtomicLoadInstrumentFunctions[FunctionIndexFromSize(i)] = M.getOrInsertFunction( FuncName, Attrs, Builder.getIntNTy(i * 8), // Return loaded value. Builder.getIntNTy(i * 8)->getPointerTo(), // Atomic pointer. Builder.getInt8Ty() // Memory order. ); } for (size_t i = 1; i <= MAX_ATOMIC_SIZE; i *= 2) { SmallString<64> FuncName("__Instrument_AtomicExchange_int" + utostr(i * 8)); AtomicExchangeInstrumentFunctions[FunctionIndexFromSize(i)] = M.getOrInsertFunction( FuncName, Attrs, Builder.getIntNTy(i * 8), // Return previous value. Builder.getIntNTy(i * 8)->getPointerTo(), // Atomic pointer. Builder.getIntNTy(i * 8), // Value to store. Builder.getInt8Ty() // Memory order. ); } for (size_t i = 1; i <= MAX_ATOMIC_SIZE; i *= 2) { SmallString<64> FuncName("__Instrument_AtomicCompareExchange_int" + utostr(i * 8)); AtomicCompareExchangeInstrumentFunctions[FunctionIndexFromSize(i)] = M.getOrInsertFunction( FuncName, Attrs, Builder.getIntNTy(i * 8), // Return previous value. Builder.getIntNTy(i * 8)->getPointerTo(), // Atomic pointer. Builder.getIntNTy(i * 8)->getPointerTo(), // Expected pointer. Builder.getIntNTy(i * 8), // Value to store. Builder.getInt8Ty(), // Success memory order. Builder.getInt8Ty() // Failure memory order. ); } for (size_t i = 1; i <= MAX_ATOMIC_SIZE; i *= 2) { for (int b = 0; b < RMWBinOp::LAST_BINOP; ++b) { if (!IsRMWOpHandled((RMWBinOp)b)) { continue; } std::string OpName = RMWOpName((RMWBinOp)b); SmallString<64> FuncName("__Instrument_AtomicFetch"); FuncName.append(OpName); FuncName.append("_int"); FuncName.append(utostr(i * 8)); AtomicRMWInstrumentFunctions[b][FunctionIndexFromSize(i)] = M.getOrInsertFunction( FuncName, Attrs, Builder.getIntNTy(i * 8), // Return previous value. Builder.getIntNTy(i * 8)->getPointerTo(), // Atomic pointer. Builder.getIntNTy(i * 8), // Value to add. Builder.getInt8Ty() // Memory order. 
); } AtomicRMWInstrumentFunctions[RMWBinOp::Xchg][FunctionIndexFromSize(i)] = AtomicExchangeInstrumentFunctions[FunctionIndexFromSize(i)]; } for (size_t i = 0; i < NUM_ATOMIC_FUNCS; ++i) { InstrumentFunctions.push_back(&AtomicStoreInstrumentFunctions[i]); InstrumentFunctions.push_back(&AtomicLoadInstrumentFunctions[i]); InstrumentFunctions.push_back(&AtomicExchangeInstrumentFunctions[i]); InstrumentFunctions.push_back(&AtomicCompareExchangeInstrumentFunctions[i]); for (size_t b = 0; b < RMWBinOp::LAST_BINOP; ++b) { if (AtomicRMWInstrumentFunctions[b][i].getCallee() != nullptr) { InstrumentFunctions.push_back(&AtomicRMWInstrumentFunctions[b][i]); } } } } bool CustomMemoryInstrumentationPass::instrumentMSVCStandardLib(Module &M) { bool AnyInstrumented = false; SmallVector> Insts; for (auto &Function : M) { if (!shouldInstrumentFunction(Function)) { continue; } // Find calls. for (auto &BasicBlock : Function) { for (auto &Instruction : BasicBlock) { if (CallInst *Call = dyn_cast(&Instruction)) { if (!Call->getCalledFunction()) { continue; } std::string DemangledName = demangle(Call->getCalledFunction()->getName().str()); // errs() << "Call: " << DemangledName << "\n"; for (auto &[FunctionNameRegex, CallSite] : MSVCAtomicCallSites) { if (FunctionNameRegex->match(DemangledName)) { Insts.push_back({&CallSite, Call}); } } } } } } // First, if there is any call to instrument that requires // pointer arithmetic, figure out the pointee sizes before // any instrumentation can interfere with that process. for (auto &[CallSite, Inst] : Insts) { if (CallSite->RequiresPointerArithmetic) { CacheOrGetPointeeSizeForMSVCAtomicPointerRMW(Inst, CallSite->RMWOp); } } // Instrument calls. 
for (auto &[CallSite, Inst] : Insts) { AnyInstrumented |= instrumentMSVCAtomicCallSite(Inst, *CallSite); } return AnyInstrumented; } bool CustomMemoryInstrumentationPass::instrumentModule(Module &M) { bool AnyInstrumented = false; for (auto &Function : M) { if (!shouldInstrumentFunction(Function)) { continue; } bool SkipNonAtomics = Function.hasFnAttribute("no_sanitize_thread"); bool ContainsCalls = false; bool FunctionInstrumented = false; for (auto &BasicBlock : Function) { SmallVector Stores; SmallVector Loads; SmallVector CompareExchanges; SmallVector RMWs; SmallVector MemTransfers; SmallVector MemSets; for (auto &Instruction : BasicBlock) { if (StoreInst *Store = dyn_cast(&Instruction)) { Stores.push_back(Store); } else if (LoadInst *Load = dyn_cast(&Instruction)) { Loads.push_back(Load); } else if (AtomicCmpXchgInst *Cmp = dyn_cast(&Instruction)) { CompareExchanges.push_back(Cmp); } else if (AtomicRMWInst *RMW = dyn_cast(&Instruction)) { RMWs.push_back(RMW); } else if (MemTransferInst *MemCpy = dyn_cast(&Instruction)) { MemTransfers.push_back(MemCpy); } else if (MemSetInst *MemSet = dyn_cast(&Instruction)) { MemSets.push_back(MemSet); } else if (CallInst *Call = dyn_cast(&Instruction)) { ContainsCalls = true; } } for (auto *Inst : Stores) { FunctionInstrumented |= instrumentStore(Inst, SkipNonAtomics); } for (auto *Inst : Loads) { FunctionInstrumented |= instrumentLoad(Inst, SkipNonAtomics); } for (auto *Inst : CompareExchanges) { FunctionInstrumented |= instrumentCompareExchange(Inst); } for (auto *Inst : RMWs) { FunctionInstrumented |= instrumentRMW(Inst); } if (!SkipNonAtomics) { for (auto *Inst : MemTransfers) { FunctionInstrumented |= instrumentMemTransfer(Inst); } for (auto *Inst : MemSets) { FunctionInstrumented |= instrumentMemSet(Inst); } } // errs() << "Instrumented function " << Function.getName() << "\n"; // if (verifyFunction(Function, &errs())) { // errs() << "Broken function" << Function << "\n "; // return true; //} } if (FunctionInstrumented 
|| ContainsCalls) { FunctionInstrumented |= instrumentFunctionEntry(Function); FunctionInstrumented |= instrumentFunctionExit(Function); } AnyInstrumented |= FunctionInstrumented; } return AnyInstrumented; } bool CustomMemoryInstrumentationPass::instrumentFunctionEntry(Function &F) { InstrumentationIRBuilder Builder(F.getEntryBlock().getFirstNonPHI()); Value *ReturnAddress = Builder.CreateCall( Intrinsic::getDeclaration(F.getParent(), Intrinsic::returnaddress), Builder.getInt32(0)); Builder.CreateCall(FuncEntryInstrumentFunction, ReturnAddress); return true; } bool CustomMemoryInstrumentationPass::instrumentFunctionExit(Function &F) { EscapeEnumerator EE(F, "instrumentation_cleanup", /* HandleExceptions = */ false); while (IRBuilder<> *Builder = EE.Next()) { InstrumentationIRBuilder::ensureDebugInfo(*Builder, F); Builder->CreateCall(FuncExitInstrumentFunction, {}); } return true; } FunctionCallee & CustomMemoryInstrumentationPass::getInstrumentFunctionForMSVCAtomicCallSite( uint32_t Size, AtomicCallSite &CallSite) { switch (CallSite.Type) { case CALL_SITE_LOAD: return AtomicLoadInstrumentFunctions[FunctionIndexFromSize(Size)]; case CALL_SITE_STORE: return AtomicStoreInstrumentFunctions[FunctionIndexFromSize(Size)]; case CALL_SITE_EXCHANGE: return AtomicExchangeInstrumentFunctions[FunctionIndexFromSize(Size)]; case CALL_SITE_COMPARE_EXCHANGE: return AtomicCompareExchangeInstrumentFunctions[FunctionIndexFromSize( Size)]; case CALL_SITE_RMW: return AtomicRMWInstrumentFunctions[CallSite.RMWOp] [FunctionIndexFromSize(Size)]; default: llvm_unreachable("Should be handled"); } } bool CustomMemoryInstrumentationPass::instrumentMSVCAtomicCallSite( CallInst *Inst, AtomicCallSite CallSite) { Type *SretType = CallSite.AdjustCallSiteForSret(Inst); InstrumentationIRBuilder Builder(Inst); uint32_t Size = 0; if (SretType) { Size = CurrentModule->getDataLayout().getTypeStoreSize(SretType); } else if (CallSite.SizeTypeOperand == -1) { if (Inst->getType()->isVoidTy()) { errs() << 
"Void type: " << *Inst << "\n"; errs() << demangle(Inst->getCalledFunction()->getName().str()) << "\n"; errs() << demangle(Inst->getFunction()->getName().str()) << "\n"; } Size = CurrentModule->getDataLayout().getTypeStoreSize(Inst->getType()); } else { Size = CurrentModule->getDataLayout().getTypeStoreSize( Inst->getArgOperand(CallSite.SizeTypeOperand)->getType()); } if (Size > MAX_ATOMIC_SIZE) { return false; } Value *Ptr = Builder.CreatePointerCast(Inst->getArgOperand(CallSite.PtrOperand), Builder.getIntNTy(Size * 8)->getPointerTo()); Value *Val = nullptr; if (CallSite.StoreValueOperand) { Val = Inst->getArgOperand(*CallSite.StoreValueOperand); // If we're doing pointer arithmetic, we need to know the pointee's size // to multiply the value with. if (CallSite.RequiresPointerArithmetic) { uint64_t PointeeSize = CacheOrGetPointeeSizeForMSVCAtomicPointerRMW(Inst, CallSite.RMWOp); if (PointeeSize == 0) { errs() << "Failed to determine pointee size for atomic pointer RMW: " << *Inst << "\n"; report_fatal_error(make_error( "Failed to determine pointee size for atomic pointer RMW", inconvertibleErrorCode())); } Val = Builder.CreateMul( Val, ConstantInt::get(Builder.getInt64Ty(), PointeeSize)); } } Value *Expected = nullptr; if (CallSite.ExpectedOperand) { Expected = Builder.CreatePointerCast( Inst->getArgOperand(*CallSite.ExpectedOperand), Builder.getIntNTy(Size * 8)->getPointerTo()); } Value *MemoryOrder = nullptr; if (CallSite.AtomicOrderOperand && GetRealNumCallOperands(Inst) > *CallSite.AtomicOrderOperand) { Value *StdMemoryOrder = Inst->getArgOperand(*CallSite.AtomicOrderOperand); MemoryOrder = Builder.CreateIntCast(StdMemoryOrder, Builder.getInt8Ty(), true); } else { MemoryOrder = ConstantInt::get(Builder.getInt8Ty(), FAtomicMemoryOrder::MEMORY_ORDER_SEQ_CST); } Value *FailureMemoryOrder = MemoryOrder; if (CallSite.FailureAtomicOrderOperand && GetRealNumCallOperands(Inst) > *CallSite.FailureAtomicOrderOperand) { Value *StdMemoryOrder = 
Inst->getArgOperand(*CallSite.FailureAtomicOrderOperand); FailureMemoryOrder = Builder.CreateIntCast(StdMemoryOrder, Builder.getInt8Ty(), true); } Value *Sret = nullptr; if (CallSite.SretOperand) { Sret = Inst->getArgOperand(*CallSite.SretOperand); } FunctionCallee &InstrumentFunction = getInstrumentFunctionForMSVCAtomicCallSite(Size, CallSite); if (!InstrumentFunction.getCallee()) { return false; } if (CallSite.Type == CALL_SITE_COMPARE_EXCHANGE) { assert(!Sret); return instrumentAtomicCompareExchangeMemoryInst( Builder, Inst, Ptr, Expected, Val, MemoryOrder, FailureMemoryOrder, InstrumentFunction, true /* return a single boolean value */); } return instrumentAtomicMemoryInst(Builder, Inst, Ptr, Val, FailureMemoryOrder, InstrumentFunction, Sret); } bool CustomMemoryInstrumentationPass::shouldInstrumentAddr(Value *Addr) { // if the variable is on stack and is never captured, we don't need to instrument it. if (isa(getUnderlyingObject(Addr)) && !PointerMayBeCaptured(Addr, true, true)) { return false; } return true; } bool CustomMemoryInstrumentationPass::instrumentStore(StoreInst *Inst, bool SkipNonAtomics) { InstrumentationIRBuilder Builder(Inst); Value *Addr = Inst->getPointerOperand(); if (!shouldInstrumentAddr(Addr)) { return false; } // Special case for virtual table pointer updates. 
if (MDNode *Metadata = Inst->getMetadata(LLVMContext::MD_tbaa)) { if (Metadata->isTBAAVtableAccess()) { Value *ValueOperand = Inst->getValueOperand(); if (isa(ValueOperand->getType())) ValueOperand = Builder.CreateExtractElement( ValueOperand, ConstantInt::get(Builder.getInt32Ty(), 0)); if (ValueOperand->getType()->isIntegerTy()) ValueOperand = Builder.CreateIntToPtr(ValueOperand, Builder.getPtrTy()); Builder.CreateCall(StoreVPtrInstrumentFunction, {Addr, ValueOperand}); return true; } } Value *Ptr = Builder.CreateCast(Instruction::CastOps::PtrToInt, Inst->getPointerOperand(), Builder.getInt64Ty()); uint32_t Size = CurrentModule->getDataLayout().getTypeStoreSize( Inst->getValueOperand()->getType()); if (Inst->isAtomic()) { assert(Size <= MAX_ATOMIC_SIZE); Value *MemoryOrder = ConstantInt::get(Builder.getInt8Ty(), MemoryOrderFromInst(Inst)); return instrumentAtomicMemoryInst( Builder, Inst, Inst->getPointerOperand(), Inst->getValueOperand(), MemoryOrder, AtomicStoreInstrumentFunctions[FunctionIndexFromSize(Size)], nullptr); } else if (SkipNonAtomics) { return false; } return instrumentMemoryInst(Builder, Inst->getDebugLoc(), Ptr, Size, StoreInstrumentFunction); } bool CustomMemoryInstrumentationPass::instrumentLoad(LoadInst *Inst, bool SkipNonAtomics) { InstrumentationIRBuilder Builder(Inst); Value* Addr = Inst->getPointerOperand(); if (!shouldInstrumentAddr(Addr)) { return false; } // Special case for virtual table pointer reads. 
if (MDNode *Metadata = Inst->getMetadata(LLVMContext::MD_tbaa)) { if (Metadata->isTBAAVtableAccess()) { Builder.CreateCall(LoadVPtrInstrumentFunction, Addr); return true; } } Value *Ptr = Builder.CreateCast(Instruction::CastOps::PtrToInt, Addr, Builder.getInt64Ty()); uint32_t Size = CurrentModule->getDataLayout().getTypeStoreSize(Inst->getType()); if (Inst->isAtomic()) { assert(Size <= MAX_ATOMIC_SIZE); Value *MemoryOrder = ConstantInt::get(Builder.getInt8Ty(), MemoryOrderFromInst(Inst)); return instrumentAtomicMemoryInst( Builder, Inst, Addr, nullptr /* value */, MemoryOrder, AtomicLoadInstrumentFunctions[FunctionIndexFromSize(Size)], nullptr); } else if (SkipNonAtomics) { return false; } return instrumentMemoryInst(Builder, Inst->getDebugLoc(), Ptr, Size, LoadInstrumentFunction); } bool CustomMemoryInstrumentationPass::instrumentCompareExchange( AtomicCmpXchgInst *Inst) { InstrumentationIRBuilder Builder(Inst); uint32_t Size = CurrentModule->getDataLayout().getTypeStoreSize( Inst->getNewValOperand()->getType()); assert(Size <= MAX_ATOMIC_SIZE); Value *SuccessMemoryOrder = ConstantInt::get(Builder.getInt8Ty(), MemoryOrderFromLLVMOrdering(Inst->getSuccessOrdering())); Value *FailureMemoryOrder = ConstantInt::get(Builder.getInt8Ty(), MemoryOrderFromLLVMOrdering(Inst->getFailureOrdering())); return instrumentAtomicCompareExchangeMemoryInst( Builder, Inst, Inst->getPointerOperand(), Inst->getCompareOperand(), Inst->getNewValOperand(), SuccessMemoryOrder, FailureMemoryOrder, AtomicCompareExchangeInstrumentFunctions[FunctionIndexFromSize(Size)], false /* return both old val and success bool*/); } bool CustomMemoryInstrumentationPass::instrumentRMW(AtomicRMWInst *Inst) { InstrumentationIRBuilder Builder(Inst); uint32_t Size = CurrentModule->getDataLayout().getTypeStoreSize( Inst->getValOperand()->getType()); assert(Size <= MAX_ATOMIC_SIZE); Value *MemoryOrder = ConstantInt::get(Builder.getInt8Ty(), MemoryOrderFromInst(Inst)); FunctionCallee &InstrumentFunction = 
AtomicRMWInstrumentFunctions[Inst->getOperation()] [FunctionIndexFromSize(Size)]; if (InstrumentFunction.getCallee()) { return instrumentAtomicMemoryInst( Builder, Inst, Inst->getPointerOperand(), Inst->getValOperand(), MemoryOrder, AtomicRMWInstrumentFunctions[Inst->getOperation()] [FunctionIndexFromSize(Size)], nullptr); } return false; } bool CustomMemoryInstrumentationPass::instrumentMemoryInst( InstrumentationIRBuilder &Builder, const DebugLoc &DebugLoc, Value *Ptr, uint32_t Size, FunctionCallee &InstrumentFunction) { Builder.CreateCall(InstrumentFunction, {Ptr, ConstantInt::get(Builder.getInt32Ty(), Size)}); return true; } bool CustomMemoryInstrumentationPass::instrumentAtomicMemoryInst( InstrumentationIRBuilder &Builder, Instruction *Inst, Value *Ptr, Value *ValIfStore, Value *MemoryOrder, FunctionCallee &InstrumentFunction, Value *Sret) { const DebugLoc &DebugLoc = Inst->getDebugLoc(); CallInst *CallInstruction = nullptr; Value *Ret = nullptr; if (ValIfStore) { Value *Val = CreateCast( Builder, ValIfStore, InstrumentFunction.getFunctionType()->getFunctionParamType(1)); CallInstruction = Builder.CreateCall(InstrumentFunction, {Ptr, Val, MemoryOrder}); CallInstruction->setDebugLoc(DebugLoc); } else { CallInstruction = Builder.CreateCall(InstrumentFunction, {Ptr, MemoryOrder}); CallInstruction->setDebugLoc(DebugLoc); } if (Sret) { Ret = Builder.CreateStore(CallInstruction, Sret); } else { Ret = CreateCast(Builder, CallInstruction, Inst->getType()); } Inst->replaceAllUsesWith(Ret); Ret->takeName(Inst); Inst->eraseFromParent(); return true; } bool CustomMemoryInstrumentationPass::instrumentAtomicCompareExchangeMemoryInst( InstrumentationIRBuilder &Builder, Instruction *Inst, Value *Ptr, Value *Expected, Value *Val, Value *SuccessMemoryOrder, Value *FailureMemoryOrder, FunctionCallee &InstrumentFunction, bool ReturnOnlyBool) { const DebugLoc &DebugLoc = Inst->getDebugLoc(); Value *ExpectedVal = nullptr; Value *ExpectedPtr = nullptr; if 
(Expected->getType()->isPointerTy()) { ExpectedPtr = Expected; ExpectedVal = Builder.CreateLoad(Val->getType(), Expected); } else { ExpectedVal = Expected; // Insert alloca at the beginning of the function. auto CurrentInsertPoint = Builder.GetInsertPoint(); Builder.SetInsertPoint( &*Inst->getFunction()->getEntryBlock().getFirstInsertionPt()); AllocaInst *ExpectedPtrAlloca = Builder.CreateAlloca(Val->getType()); ExpectedPtrAlloca->setAlignment(llvm::Align(MAX_ATOMIC_SIZE)); ExpectedPtr = ExpectedPtrAlloca; Builder.SetInsertPoint(Inst->getParent(), CurrentInsertPoint); Builder.CreateStore(ExpectedVal, ExpectedPtrAlloca); } Value *Ret = nullptr; Value *StoreVal = CreateCast(Builder, Val, InstrumentFunction.getFunctionType()->getFunctionParamType(2)); Value *PrevVal = Builder.CreateCall( InstrumentFunction, {Ptr, ExpectedPtr, StoreVal, SuccessMemoryOrder, FailureMemoryOrder}); dyn_cast(PrevVal)->setDebugLoc(DebugLoc); // Compare bytes (reinterpret value as integer bytes). Value *Success = Builder.CreateICmpEQ( PrevVal, CreateCast(Builder, ExpectedVal, PrevVal->getType())); // Handle return value. 
if (ReturnOnlyBool) { Ret = Success; Ret = CreateCast(Builder, Success, Inst->getType()); } else { AtomicCmpXchgInst *CmpXchg = dyn_cast(Inst); assert(CmpXchg); Type *PrevValType = CmpXchg->getNewValOperand()->getType(); PrevVal = CreateCast(Builder, PrevVal, PrevValType); Ret = Builder.CreateInsertValue(PoisonValue::get(Inst->getType()), PrevVal, 0); Ret = Builder.CreateInsertValue(Ret, Success, 1); } Inst->replaceAllUsesWith(Ret); Ret->takeName(Inst); Inst->eraseFromParent(); return true; } bool CustomMemoryInstrumentationPass::instrumentMemTransfer( MemTransferInst *Inst) { InstrumentationIRBuilder Builder(Inst); instrumentMemoryInstRange(Builder, Inst->getDebugLoc(), Inst->getSource(), Inst->getLength(), LoadRangeInstrumentFunction); instrumentMemoryInstRange(Builder, Inst->getDebugLoc(), Inst->getDest(), Inst->getLength(), StoreRangeInstrumentFunction); return true; } bool CustomMemoryInstrumentationPass::instrumentMemSet(MemSetInst *Inst) { InstrumentationIRBuilder Builder(Inst); return instrumentMemoryInstRange(Builder, Inst->getDebugLoc(), Inst->getDest(), Inst->getLength(), StoreRangeInstrumentFunction); } bool CustomMemoryInstrumentationPass::instrumentMemoryInstRange( InstrumentationIRBuilder &Builder, const DebugLoc &DebugLoc, Value *Ptr, Value *Length, FunctionCallee &InstrumentFunction) { Value *Addr = Builder.CreatePtrToInt(Ptr, Builder.getInt64Ty()); Value *Size = Builder.CreateIntCast(Length, Builder.getInt32Ty(), false); CallInst *Call = Builder.CreateCall(InstrumentFunction, {Addr, Size}); Call->setDebugLoc(DebugLoc); return true; } } // namespace llvm