Building Your First LLVM Optimization Pass


Published:
Tags: llvm, C++

LLVM is a powerful toolchain used by many modern programming languages such as Rust and Swift, but even older languages like C/C++ can make effective use of it. It not only translates intermediate code into machine-specific instructions but also optimizes it through a series of analysis and transformation passes.

This blog post aims to demystify LLVM optimization passes and guide you through writing your first one.

The complete code examples discussed in this blog post are available on GitHub.

II. Understanding LLVM Passes

III. Writing Your First LLVM Pass

Dependencies:

Setting up the template

We set up a very simple file structure:

const_split/
+- build/
+- CMakeLists.txt
\- const_split.cpp

The contents of CMakeLists.txt:

cmake_minimum_required(VERSION 3.14 FATAL_ERROR)
project(ConstSplit VERSION 0.1.0 LANGUAGES CXX C)

# Locate an installed LLVM through its LLVMConfig.cmake
# (select a specific installation with -DLLVM_DIR=...).
find_package(LLVM REQUIRED CONFIG)

# Make LLVM's own CMake helpers available; AddLLVM provides add_llvm_pass_plugin.
list(APPEND CMAKE_MODULE_PATH ${LLVM_CMAKE_DIR})
include(AddLLVM)

# Build const_split.cpp as a loadable pass plugin (opt -load-pass-plugin ...).
add_llvm_pass_plugin(ConstSplit const_split.cpp)

target_include_directories(ConstSplit PUBLIC ${LLVM_INCLUDE_DIRS})
# Link directories must point at LLVM's libraries, not at its headers.
target_link_directories(ConstSplit PUBLIC ${LLVM_LIBRARY_DIRS})

# LLVM_DEFINITIONS is one space-separated string; split it into a CMake list
# so that every -D flag becomes its own compile definition.
separate_arguments(LLVM_DEFINITIONS_LIST NATIVE_COMMAND ${LLVM_DEFINITIONS})
target_compile_definitions(ConstSplit PUBLIC ${LLVM_DEFINITIONS_LIST})

# The plugin has to match LLVM's RTTI setting. Note: `$(VAR)` is not CMake
# variable syntax, so the variable is tested by name here.
if (NOT LLVM_ENABLE_RTTI)
  target_compile_options(ConstSplit PUBLIC "-fno-rtti")
endif()

And in const_split.cpp we start with a function pass skeleton:

#include <llvm/Passes/OptimizationLevel.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/DerivedTypes.h>
#include <llvm/Pass.h>
#include <llvm/Passes/PassBuilder.h>
#include <llvm/Passes/PassPlugin.h>

#include <llvm/Support/Debug.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Support/RandomNumberGenerator.h>

using namespace llvm;

namespace {

/// Function pass that will split integer constants into the sum of two other
/// constants. In this skeleton, `run` does not modify the IR yet.
struct ConstantSplit : public PassInfoMixin<ConstantSplit> {
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM);
};
} // namespace

/// Called by the new pass manager for every function the pass is scheduled on.
/// The skeleton leaves the function untouched, so every analysis result stays
/// valid. Returning PreservedAnalyses::none() here would needlessly invalidate
/// all cached analyses of the function.
PreservedAnalyses ConstantSplit::run(Function &F, FunctionAnalysisManager &FAM) {
  return PreservedAnalyses::all();
}

/// Plugin entry point: opt looks this symbol up after -load-pass-plugin.
/// It registers ConstantSplit both at the scalar-optimizer-late extension point
/// and under the textual pipeline name "constant-split".
extern "C" LLVM_ATTRIBUTE_WEAK PassPluginLibraryInfo llvmGetPassPluginInfo() {
  auto registerCallbacks = [](PassBuilder &PB) {
    // Schedule the pass automatically late in the scalar optimization pipeline.
    PB.registerScalarOptimizerLateEPCallback(
        [](FunctionPassManager &FPM, OptimizationLevel) {
          FPM.addPass(ConstantSplit());
        });

    // Allow explicit use via: opt -passes "constant-split".
    PB.registerPipelineParsingCallback(
        [](StringRef name, FunctionPassManager &FPM,
           ArrayRef<llvm::PassBuilder::PipelineElement>) {
          if (name != "constant-split")
            return false;
          FPM.addPass(ConstantSplit());
          return true;
        });
  };

  return {LLVM_PLUGIN_API_VERSION, "constant-split-oot", "0.1", registerCallbacks};
}

In the build directory we do:

# Configure the project from inside the build/ directory.
cmake ..

If you want to use an LLVM installation other than your system installation, you can run:

# Point find_package(LLVM) at a specific LLVM installation.
cmake -DLLVM_DIR=/path/to/installation/lib/cmake/llvm ..

You can also set -DCMAKE_INSTALL_PREFIX to choose the directory where the plugin is installed by make install. Passing -DCMAKE_EXPORT_COMPILE_COMMANDS=On generates a compile_commands.json, which enables clangd-based linting in many IDEs.

Implementing the Transformation

We first implement constant splitting in a naive, care-free way:

// First, deliberately naive version of the transformation: every ConstantInt
// operand C of every instruction is replaced by the expression R + (C - R),
// where R is a random value of the same bit width. The random numbers come
// from a module-provided RNG salted with "obf.ConstSplit." plus the function name.
//
// Caveat: IRBuilder<> uses a constant folder by default. An add of two
// constants is folded straight back into a single constant, so `nins` is
// just the original constant again and no instruction is created.
//
// The pass conservatively reports that no analyses are preserved.
PreservedAnalyses ConstantSplit::run(Function &F, FunctionAnalysisManager &FAM) {
  SmallString<64> rngstr = {"obf.ConstSplit."};
  rngstr += F.getName();

  auto rng = F.getParent()->createRNG(rngstr);
  for (auto &bb : F) {
    for (auto &ins : bb) {
      for (auto &op : ins.operands()) {
        Value *v = op.get();
        if (!isa<ConstantInt>(v))
          continue;

        ConstantInt *opval = cast<ConstantInt>(v);
        APInt rval(opval->getValue().getBitWidth(), (*rng)(), false, true);

        {
          IRBuilder<> Builder(&bb, ins.getIterator());
          Value *nins = Builder.CreateAdd(ConstantInt::get(opval->getType(), rval), ConstantInt::get(opval->getType(), opval->getValue() - rval));
          llvm::outs() << "replacing " << *opval << " in " << ins << " with " << *nins << "\n";
          op = nins;
        }
      }
    }
  }
  return PreservedAnalyses::none();
}

IV. Running the pass

Creating an LLVM IR file

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Iteratively computes the i-th Fibonacci number (fib(0) = 0, fib(1) = 1).
 * All arithmetic is on unsigned long long, so results wrap modulo 2^64 for
 * large i. Declared pure: the result depends only on the argument and the
 * function has no side effects. */
__attribute__((pure))
unsigned long long fib(unsigned long long i) {
  unsigned long long a = 0, b = 1, t;
  /* Invariant: after k iterations a == fib(k) and b == fib(k - 1),
   * with fib(-1) taken as 1. */
  for (;i > 0; --i) {
    t = b;
    b = a;
    a += t;
  }
  return a;
}

/* Usage: fib N. Prints fib(N). N is parsed by strtoull with base 0, so
 * decimal, 0x-prefixed hexadecimal and 0-prefixed octal are accepted. No
 * further validation of N is performed. Returns 1 if exactly one argument
 * is not given. */
int main(int argc, char **argv) {
  if (argc != 2)
    return 1;
  unsigned long long a = strtoull(argv[1], NULL, 0);
  printf("fib(%llu) = %llu\n", a, fib(a));
  return 0;
}
# Emit textual LLVM IR at -O0. -disable-O0-optnone keeps clang from marking
# the functions optnone, so opt can still run passes on them afterwards.
clang -Xclang -emit-llvm -Xclang -disable-O0-optnone -O0 fib.c -S -o fib.ll

Applying function passes to LLVM IR files

# Load the plugin and run mem2reg (promotes stack slots to SSA values and
# introduces phi nodes), then loop-rotate, then our constant-split pass.
opt -load-pass-plugin /absolute/path/to/plugin/ConstSplit.so -passes "mem2reg,loop-rotate,constant-split" fib.ll -S -o fib_cs.ll

which will output:

replacing i64 0 in   %cmp1 = icmp ugt i64 %i, 0 with i64 0
replacing i64 0 in   %a.03 = phi i64 [ 0, %for.body.lr.ph ], [ %add, %for.inc ] with i64 0
replacing i64 1 in   %b.02 = phi i64 [ 1, %for.body.lr.ph ], [ %a.03, %for.inc ] with i64 1
replacing i64 -1 in   %dec = add i64 %i.addr.04, -1 with i64 -1
replacing i64 0 in   %cmp = icmp ugt i64 %dec, 0 with i64 0
replacing i64 0 in   %a.0.lcssa = phi i64 [ %split, %for.cond.for.end_crit_edge ], [ 0, %entry ] with i64 0
replacing i32 2 in   %cmp = icmp ne i32 %argc, 2 with i32 2
replacing i64 1 in   %arrayidx = getelementptr inbounds ptr, ptr %argv, i64 1 with i64 1
replacing i32 0 in   %call = call i64 @strtoull(ptr noundef %0, ptr noundef null, i32 noundef 0) #4 with i32 0
replacing i32 1 in   %retval.0 = phi i32 [ 1, %if.then ], [ 0, %if.end ] with i32 1
replacing i32 0 in   %retval.0 = phi i32 [ 1, %if.then ], [ 0, %if.end ] with i32 0

V. Making the pass work

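Forcing LLVM to emit instructions

The output above shows that every constant was "replaced" by itself. IRBuilder constant-folds an add of two constants straight back into a single constant, so no instruction is ever created. To avoid this folding, we create the add instruction directly with BinaryOperator::CreateAdd: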

--- a/const_split.cpp
+++ b/const_split.cpp
@@ -3,6 +3,7 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/InstrTypes.h"
 #include <llvm/Passes/OptimizationLevel.h>
 #include <llvm/IR/Module.h>
 #include <llvm/IR/PassManager.h>
@@ -41,8 +42,7 @@ PreservedAnalyses ConstantSplit::run(Function &F, FunctionAnalysisManager &FAM)
        APInt rval(opval->getValue().getBitWidth(), (*rng)(), false, true);
 
        {
-         IRBuilder<> Builder(&bb, ins.getIterator());
-         Value *nins = Builder.CreateAdd(ConstantInt::get(opval->getType(), rval), ConstantInt::get(opval->getType(), opval->getValue() - rval));
+         Value *nins = BinaryOperator::CreateAdd(ConstantInt::get(opval->getType(), rval), ConstantInt::get(opval->getType(), opval->getValue() - rval), "split", ins.getIterator());
          llvm::outs() << "replacing " << *opval << " in " << ins << " with " << *nins << "\n";
          op = nins;
        }

Now we can run our pass again and the output looks like:

replacing i64 0 in   %cmp1 = icmp ugt i64 %i, 0 with   %split5 = add i64 3889012652787861261, -3889012652787861261
replacing i64 0 in   %a.03 = phi i64 [ 0, %for.body.lr.ph ], [ %add, %for.inc ] with   %split6 = add i64 4642127197507871265, -4642127197507871265
replacing i64 1 in   %b.02 = phi i64 [ 1, %for.body.lr.ph ], [ %a.03, %for.inc ] with   %split7 = add i64 -7696377703097262296, 7696377703097262297
replacing i64 -1 in   %dec = add i64 %i.addr.04, -1 with   %split8 = add i64 5757122681941097448, -5757122681941097449
replacing i64 0 in   %cmp = icmp ugt i64 %dec, 0 with   %split9 = add i64 -5547331450338019100, 5547331450338019100
replacing i64 0 in   %a.0.lcssa = phi i64 [ %split, %for.cond.for.end_crit_edge ], [ 0, %entry ] with   %split10 = add i64 649514231163583877, -649514231163583877
replacing i32 2 in   %cmp = icmp ne i32 %argc, 2 with   %split = add i32 1123537705, -1123537703
replacing i64 1 in   %arrayidx = getelementptr inbounds ptr, ptr %argv, i64 1 with   %split1 = add i64 6844434560377074114, -6844434560377074113
replacing i32 0 in   %call = call i64 @strtoull(ptr noundef %0, ptr noundef null, i32 noundef 0) #4 with   %split2 = add i32 -1244807385, 1244807385
replacing i32 1 in   %retval.0 = phi i32 [ 1, %if.then ], [ 0, %if.end ] with   %split3 = add i32 -1990629377, 1990629378
replacing i32 0 in   %retval.0 = phi i32 [ %split3, %if.then ], [ 0, %if.end ] with   %split4 = add i32 35660392, -35660392
PHI nodes not grouped at top of basic block!
  %a.03 = phi i64 [ %split6, %for.body.lr.ph ], [ %add, %for.inc ]
label %for.body
PHI nodes not grouped at top of basic block!
  %b.02 = phi i64 [ %split7, %for.body.lr.ph ], [ %a.03, %for.inc ]
label %for.body
PHI nodes not grouped at top of basic block!
  %a.0.lcssa = phi i64 [ %split, %for.cond.for.end_crit_edge ], [ %split10, %entry ]
label %for.end
PHI nodes not grouped at top of basic block!
  %retval.0 = phi i32 [ %split3, %if.then ], [ %split4, %if.end ]
label %return
LLVM ERROR: Broken module found, compilation aborted!

Oh no, what happened?

PHI Nodes!

Handling Phi-Nodes

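PHI nodes must be grouped at the top of their basic block, so we cannot insert our add instruction right before a PHI node. In addition, an incoming value of a PHI node only has to be available at the end of the corresponding predecessor block. Let’s look at an example to illustrate these points: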

int main(int argc, char **argv) {
  if (argc > 1)
    return argc - 1;
  return 0;
}

The corresponding LLVM IR looks like this:

; Function Attrs: noinline nounwind uwtable
define dso_local i32 @main(i32 noundef %argc, ptr noundef %argv) #0 {
entry:
  %cmp = icmp sgt i32 %argc, 1
  br i1 %cmp, label %if.then, label %if.end

if.then:                                          ; preds = %entry
  %sub = sub nsw i32 %argc, 1
  br label %return

if.end:                                           ; preds = %entry
  br label %return

return:                                           ; preds = %if.end, %if.then
  %retval.0 = phi i32 [ %sub, %if.then ], [ 0, %if.end ]
  ret i32 %retval.0
}

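We account for PHI nodes by modifying our pass. When the instruction is a PHI node, we insert the new add instruction right before the terminator of the incoming block that the operand belongs to, instead of before the PHI node itself: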

--- a/const_split.cpp
+++ b/const_split.cpp
@@ -4,6 +4,7 @@
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instructions.h"
 #include <llvm/Passes/OptimizationLevel.h>
 #include <llvm/IR/Module.h>
 #include <llvm/IR/PassManager.h>
@@ -38,11 +39,18 @@ PreservedAnalyses ConstantSplit::run(Function &F, FunctionAnalysisManager &FAM)
        if (!isa<ConstantInt>(v))
          continue;
 
+       auto insertion_point = ins.getIterator();
+       if (isa<PHINode>(ins)) {
+         PHINode &phi = cast<PHINode>(ins);
+         auto *bb = phi.getIncomingBlock(op);
+         insertion_point = bb->getTerminator()->getIterator();
+       }
+
        ConstantInt *opval = cast<ConstantInt>(v);
        APInt rval(opval->getValue().getBitWidth(), (*rng)(), false, true);
 
        {
-         Value *nins = BinaryOperator::CreateAdd(ConstantInt::get(opval->getType(), rval), ConstantInt::get(opval->getType(), opval->getValue() - rval), "split", ins.getIterator());
+         Value *nins = BinaryOperator::CreateAdd(ConstantInt::get(opval->getType(), rval), ConstantInt::get(opval->getType(), opval->getValue() - rval), "split", insertion_point);
          llvm::outs() << "replacing " << *opval << " in " << ins << " with " << *nins << "\n";
          op = nins;
        }

Running the pass now yields:

replacing i64 0 in   %cmp1 = icmp ugt i64 %i, 0 with   %split5 = add i64 3889012652787861261, -3889012652787861261
replacing i64 0 in   %a.03 = phi i64 [ 0, %for.body.lr.ph ], [ %add, %for.inc ] with   %split6 = add i64 4642127197507871265, -4642127197507871265
replacing i64 1 in   %b.02 = phi i64 [ 1, %for.body.lr.ph ], [ %a.03, %for.inc ] with   %split7 = add i64 -7696377703097262296, 7696377703097262297
replacing i64 -1 in   %dec = add i64 %i.addr.04, -1 with   %split8 = add i64 5757122681941097448, -5757122681941097449
replacing i64 0 in   %cmp = icmp ugt i64 %dec, 0 with   %split9 = add i64 -5547331450338019100, 5547331450338019100
replacing i64 0 in   %a.0.lcssa = phi i64 [ %split, %for.cond.for.end_crit_edge ], [ 0, %entry ] with   %split10 = add i64 649514231163583877, -649514231163583877
replacing i32 2 in   %cmp = icmp ne i32 %argc, 2 with   %split = add i32 1123537705, -1123537703
replacing i64 1 in   %arrayidx = getelementptr inbounds ptr, ptr %argv, i64 1 with   %split1 = add i64 6844434560377074114, -6844434560377074113
replacing i32 0 in   %call = call i64 @strtoull(ptr noundef %0, ptr noundef null, i32 noundef 0) #4 with   %split2 = add i32 -1244807385, 1244807385
replacing i32 1 in   %retval.0 = phi i32 [ 1, %if.then ], [ 0, %if.end ] with   %split3 = add i32 -1990629377, 1990629378
replacing i32 0 in   %retval.0 = phi i32 [ %split3, %if.then ], [ 0, %if.end ] with   %split4 = add i32 35660392, -35660392

To compare the results, we also create an optimized version of fib.ll by running the same pipeline without our constant-splitting pass (i.e. -passes "mem2reg,loop-rotate"). Here is the result without constant splitting:

; Function Attrs: noinline nounwind willreturn memory(read) uwtable
define dso_local i64 @fib(i64 noundef %i) #0 {
entry:
  %cmp1 = icmp ugt i64 %i, 0
  br i1 %cmp1, label %for.body.lr.ph, label %for.end

for.body.lr.ph:                                   ; preds = %entry
  br label %for.body

for.body:                                         ; preds = %for.body.lr.ph, %for.inc
  %i.addr.04 = phi i64 [ %i, %for.body.lr.ph ], [ %dec, %for.inc ]
  %a.03 = phi i64 [ 0, %for.body.lr.ph ], [ %add, %for.inc ]
  %b.02 = phi i64 [ 1, %for.body.lr.ph ], [ %a.03, %for.inc ]
  %add = add i64 %a.03, %b.02
  br label %for.inc

for.inc:                                          ; preds = %for.body
  %dec = add i64 %i.addr.04, -1
  %cmp = icmp ugt i64 %dec, 0
  br i1 %cmp, label %for.body, label %for.cond.for.end_crit_edge, !llvm.loop !6

for.cond.for.end_crit_edge:                       ; preds = %for.inc
  %split = phi i64 [ %add, %for.inc ]
  br label %for.end

for.end:                                          ; preds = %for.cond.for.end_crit_edge, %entry
  %a.0.lcssa = phi i64 [ %split, %for.cond.for.end_crit_edge ], [ 0, %entry ]
  ret i64 %a.0.lcssa
}

And here it is with constant splitting:

; Function Attrs: noinline nounwind willreturn memory(read) uwtable
define dso_local i64 @fib(i64 noundef %i) #0 {
entry:
  %split5 = add i64 3889012652787861261, -3889012652787861261
  %cmp1 = icmp ugt i64 %i, %split5
  %split10 = add i64 649514231163583877, -649514231163583877
  br i1 %cmp1, label %for.body.lr.ph, label %for.end

for.body.lr.ph:                                   ; preds = %entry
  %split6 = add i64 4642127197507871265, -4642127197507871265
  %split7 = add i64 -7696377703097262296, 7696377703097262297
  br label %for.body

for.body:                                         ; preds = %for.body.lr.ph, %for.inc
  %i.addr.04 = phi i64 [ %i, %for.body.lr.ph ], [ %dec, %for.inc ]
  %a.03 = phi i64 [ %split6, %for.body.lr.ph ], [ %add, %for.inc ]
  %b.02 = phi i64 [ %split7, %for.body.lr.ph ], [ %a.03, %for.inc ]
  %add = add i64 %a.03, %b.02
  br label %for.inc

for.inc:                                          ; preds = %for.body
  %split8 = add i64 5757122681941097448, -5757122681941097449
  %dec = add i64 %i.addr.04, %split8
  %split9 = add i64 -5547331450338019100, 5547331450338019100
  %cmp = icmp ugt i64 %dec, %split9
  br i1 %cmp, label %for.body, label %for.cond.for.end_crit_edge, !llvm.loop !6

for.cond.for.end_crit_edge:                       ; preds = %for.inc
  %split = phi i64 [ %add, %for.inc ]
  br label %for.end

for.end:                                          ; preds = %for.cond.for.end_crit_edge, %entry
  %a.0.lcssa = phi i64 [ %split, %for.cond.for.end_crit_edge ], [ %split10, %entry ]
  ret i64 %a.0.lcssa
}

What else to watch out for?

VI. Conclusion

VII. References