Time to accomplish what we set out to do: follow every load and store instruction with a call to traceMemory.
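At the source level, the effect we are after looks roughly like the hand-instrumented C++ sketch below. It assumes a recording stub `traceMemory(void *, uint64_t, int32_t)` that matches the `(i8*, i64, i32)` signature the pass declares later in this section. The `TraceEvent` struct and `Events` log are hypothetical and exist only so the sketch can be inspected. The real pass inserts these calls into the IR, not into source code.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// One recorded call (hypothetical; only used to inspect this sketch).
struct TraceEvent {
    void *Address;
    std::uint64_t Value;
    std::int32_t IsLoad;
};

std::vector<TraceEvent> Events;

// Stub with traceMemory's shape: address, value widened to 64 bits, 1 for a load or 0 for a store.
void traceMemory(void *Address, std::uint64_t Value, std::int32_t IsLoad) {
    assert(IsLoad == 0 || IsLoad == 1);
    Events.push_back({Address, Value, IsLoad});
}

// Roughly what `*Dst = *Src + 1;` turns into: each call comes right after the access it reports.
void copyPlusOne(int *Src, int *Dst) {
    int Loaded = *Src;                                        // load
    traceMemory(Src, static_cast<std::uint32_t>(Loaded), 1);  // zero-extended, as in the pass
    int Result = Loaded + 1;
    *Dst = Result;                                            // store
    traceMemory(Dst, static_cast<std::uint32_t>(Result), 0);
}
```

At `-O0` the real trace would be longer, because clang spills `Src` and `Dst` to stack slots and those loads and stores get instrumented too.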

Let’s start by changing our pass’s run function so that it visits every instruction in the module:

llvm::PreservedAnalyses run(llvm::Module &M,
                            llvm::ModuleAnalysisManager &) {
    Function *main = M.getFunction("main");
    if (main) {
        addGlobalMemoryTraceFP(M);
        addMemoryTraceFPInitialization(M, *main);
        addTraceMemoryFunction(M);
        errs() << "Found main in module " << M.getName() << "\n";
    } else {
        errs() << "Did not find main in " << M.getName() << "\n";
    }

    for (llvm::Function &F : M) {
        if (F.getName() == TraceMemoryFunctionName) {
            continue;
        }

        for (llvm::BasicBlock &BB : F) {
            for (llvm::Instruction &Instruction : BB) {
                if (Instruction.getOpcode() == llvm::Instruction::Load ||
                    Instruction.getOpcode() == llvm::Instruction::Store) {
                    addMemoryTraceToInstruction(M, Instruction);
                }
            }
        }
    }

    return llvm::PreservedAnalyses::none();
}

This is pretty straightforward:

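  1. We iterate over all functions in our compilation module. Importantly, we skip our traceMemory function: its body contains memory accesses of its own, and instrumenting them would make traceMemory call itself on every access, recursing forever
  2. We iterate over each instruction in each function and use llvm::Instruction::getOpcode to select only the memory-access instructions (loads and stores), which we pass to addMemoryTraceToInstruction. The instructions it inserts right after each access are visited by the loop as well, but since they are casts and calls rather than loads or stores, they are left alone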
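The loop's two subtleties (skipping traceMemory, and inserting new instructions while iterating) can be checked with a small self-contained toy model. `Opcode`, `ToyFunction` and `instrument` are hypothetical stand-ins, not LLVM types: each function is a single list of opcodes, and a cast plus a call stand in for the instructions addMemoryTraceToInstruction emits. Like LLVM's instruction list, `std::list` keeps iterators valid across insertion, so inserting right after the current element while looping is safe.

```cpp
#include <cassert>
#include <iterator>
#include <list>
#include <string>
#include <vector>

// Hypothetical, heavily simplified stand-ins for LLVM IR objects.
enum class Opcode { Load, Store, Add, Cast, Call };

struct ToyFunction {
    std::string Name;
    std::list<Opcode> Body;  // a single "basic block" is enough for the model
};

const std::string TraceMemoryFunctionName = "traceMemory";

// Mirrors the pass: skip traceMemory itself, and after every load/store insert
// a cast and a call (standing in for the IRBuilder calls). Returns how many
// accesses were instrumented.
int instrument(std::vector<ToyFunction> &Module) {
    int Instrumented = 0;
    for (ToyFunction &F : Module) {
        if (F.Name == TraceMemoryFunctionName)
            continue;  // instrumenting it would make traceMemory call itself forever
        for (auto It = F.Body.begin(); It != F.Body.end(); ++It) {
            if (*It != Opcode::Load && *It != Opcode::Store)
                continue;
            // Insert after the current element; std::list::insert does not
            // invalidate It. The loop visits the new nodes next, but they are
            // neither loads nor stores, so they are skipped.
            auto Next = std::next(It);
            F.Body.insert(Next, Opcode::Cast);
            F.Body.insert(Next, Opcode::Call);
            ++Instrumented;
        }
    }
    return Instrumented;
}
```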

Adding Memory Trace to an Instruction

This is the last piece of the puzzle. Our addMemoryTraceToInstruction function has three parts:

  1. We will need to fetch the traceMemory function using llvm::Module::getOrInsertFunction, which also declares it if the current module does not contain it yet

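    <aside> 💡 We already called FunctionCallee TraceMemory = M.getOrInsertFunction(TraceMemoryFunctionName, TraceMemoryTy); when creating the function in our main module, so why can’t we just reuse that result?

    That works in the compilation module that contains main, but a FunctionCallee only refers to a function inside the module it came from. Every other compilation module that calls the externally linked traceMemory needs its own declaration of it, which the linker later resolves to the single definition in main’s module. This is why addMemoryTraceToInstruction calls getOrInsertFunction itself: the first call in a module inserts the declaration, and later calls simply return it.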

    </aside>

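  2. We will need to prepare the arguments to be passed into traceMemory. Both kinds of instruction give us the address through their pointer operand, but the value to trace differs: for a load it is the instruction’s own result, and for a store it is the operand being stored. The address and the value then need casting to traceMemory’s parameter types (i8* and i64)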

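  3. Finally, we will need to emit the call to traceMemory immediately after the memory-access instruction, so that a load’s result is already available to pass along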
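For reference, the traceMemory function type, `void (i8*, i64, i32)`, corresponds to the C-level signature below. In this project the pass creates traceMemory itself in main’s module, so this C++ version is only an analogue: the `TraceFile` global and the output format are assumptions for illustration. `extern "C"` is what you would need if you wrote the runtime in C++ instead, because each module’s declaration refers to the plain, unmangled name in TraceMemoryFunctionName, and the linker binds all of those declarations to one definition.

```cpp
#include <cassert>
#include <cinttypes>
#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for the global FILE* the pass sets up in main.
std::FILE *TraceFile = nullptr;

// Matches FunctionType::get(void, {i8*, i64, i32}, /*isVarArg=*/false).
// extern "C" keeps the symbol name exactly "traceMemory", without C++ mangling.
extern "C" void traceMemory(void *Address, std::uint64_t Value, std::int32_t IsLoad) {
    assert(TraceFile != nullptr && "trace file must be opened before the first access");
    // Assumed output format: "<L|S> <address> <value>".
    std::fprintf(TraceFile, "%c %p %" PRIu64 "\n", IsLoad ? 'L' : 'S', Address, Value);
}
```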


Putting it all together, we get:

void addMemoryTraceToInstruction(llvm::Module &M, llvm::Instruction &Instruction) {
    auto &CTX = M.getContext();

    // traceMemory's parameters: (i8 *address, i64 value, i32 isLoad)
    std::vector<llvm::Type*> TraceMemoryArgs{
        PointerType::getUnqual(Type::getInt8Ty(CTX)),
        Type::getInt64Ty(CTX),
        Type::getInt32Ty(CTX)
    };

    FunctionType *TraceMemoryTy = FunctionType::get(Type::getVoidTy(CTX),
                                                    TraceMemoryArgs,
                                                    false);

    // 1. Fetch traceMemory, inserting a declaration if this module lacks one
    FunctionCallee TraceMemory = M.getOrInsertFunction(TraceMemoryFunctionName, TraceMemoryTy);

    // 2. Prepare the arguments: a load traces its own result, a store traces its value operand
    auto *Load = dyn_cast<llvm::LoadInst>(&Instruction);
    auto *Store = dyn_cast<llvm::StoreInst>(&Instruction);
    llvm::Value *Pointer = Load ? Load->getPointerOperand() : Store->getPointerOperand();
    llvm::Value *ValueToPrint = Load ? &Instruction : Store->getValueOperand();
    llvm::Type *ValueTy = ValueToPrint->getType();

    // Vectors, aggregates etc. do not fit into a single i64, so leave them untraced
    if (!ValueTy->isIntegerTy() && !ValueTy->isPointerTy() && !ValueTy->isFloatingPointTy())
        return;

    // Insert after the instruction so that a loaded value already exists
    IRBuilder<> Builder(Instruction.getNextNode());
    llvm::Value *MemoryAddress = Builder.CreatePointerCast(Pointer, TraceMemoryArgs[0], "memoryAddress");

    // CreateIntCast only accepts integers, so reinterpret floating-point bits first (e.g. double -> i64)
    if (ValueTy->isFloatingPointTy())
        ValueToPrint = Builder.CreateBitCast(ValueToPrint, Builder.getIntNTy(ValueTy->getScalarSizeInBits()));

    llvm::Value *CastTo64 = ValueTy->isPointerTy()
        ? Builder.CreatePtrToInt(ValueToPrint, TraceMemoryArgs[1], "castTo64")
        : Builder.CreateIntCast(ValueToPrint, TraceMemoryArgs[1], false, "castTo64");  // zero-extends or truncates

    // 3. Emit the call; the last argument is 1 for loads and 0 for stores
    Builder.CreateCall(TraceMemory, {MemoryAddress, CastTo64, Builder.getInt32(Load != nullptr)});
}
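The second argument is always an i64, so narrower values get widened. Because CreateIntCast is called with isSigned set to false, integers are zero-extended: a stored -1 in an i32 reaches traceMemory as 4294967295, not as -1. Integers wider than 64 bits are truncated, and pointers go through ptrtoint. The C++ sketch below mirrors the zero-extension and ptrtoint rules with ordinary casts; `intToTraceValue` and `ptrToTraceValue` are hypothetical helper names used only for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Mirrors Builder.CreateIntCast(V, i64, /*isSigned=*/false): narrower integers are
// zero-extended, so the bit pattern is kept and read back as unsigned.
template <typename Int>
std::uint64_t intToTraceValue(Int Value) {
    static_assert(std::is_integral<Int>::value && !std::is_same<Int, bool>::value,
                  "non-bool integers only");
    using Unsigned = typename std::make_unsigned<Int>::type;
    return static_cast<std::uint64_t>(static_cast<Unsigned>(Value));
}

// Mirrors Builder.CreatePtrToInt(V, i64): the address itself, as an integer.
inline std::uint64_t ptrToTraceValue(const void *Pointer) {
    return static_cast<std::uint64_t>(reinterpret_cast<std::uintptr_t>(Pointer));
}
```

In practice, a trace entry for a negative int shows up as a large unsigned number. Reading it back as signed requires knowing the original width, which the i64 argument alone does not carry.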

Our pass results in LLVM IR that looks like: