Commit 9c38f4d1 authored by Matheus Marchini's avatar Matheus Marchini

fix string comparison

This patch fixes strcmp by reimplementing it as an IRBuilder method and
unrolling the loop based on the literal string size instead of
STRING_SIZE.

Fixes: https://github.com/iovisor/bpftrace/issues/14
parent 0c89b18e
......@@ -617,33 +617,41 @@ void CodegenLLVM::visit(Binop &binop)
return;
}
Value *lhs, *rhs;
binop.left->accept(*this);
lhs = expr_;
binop.right->accept(*this);
rhs = expr_;
Type &type = binop.left->type.type;
if (type == Type::string)
{
Function *strcmp_func = module_->getFunction("strcmp");
Value *val;
std::string string_literal("");
if (binop.right->is_literal) {
binop.left->accept(*this);
val = expr_;
string_literal = reinterpret_cast<String*>(binop.right)->str;
} else {
binop.right->accept(*this);
val = expr_;
string_literal = reinterpret_cast<String*>(binop.left)->str;
}
switch (binop.op) {
case bpftrace::Parser::token::EQ:
expr_ = b_.CreateCall(strcmp_func, {lhs, rhs}, "strcmp");
expr_ = b_.CreateStrcmp(val, string_literal);
break;
case bpftrace::Parser::token::NE:
expr_ = b_.CreateNot(b_.CreateCall(strcmp_func, {lhs, rhs}, "strcmp"));
expr_ = b_.CreateStrcmp(val, string_literal, true);
break;
default:
abort();
}
if (!binop.left->is_variable)
b_.CreateLifetimeEnd(lhs);
if (!binop.right->is_variable)
b_.CreateLifetimeEnd(rhs);
b_.CreateLifetimeEnd(val);
}
else
{
Value *lhs, *rhs;
binop.left->accept(*this);
lhs = expr_;
binop.right->accept(*this);
rhs = expr_;
switch (binop.op) {
case bpftrace::Parser::token::EQ: expr_ = b_.CreateICmpEQ (lhs, rhs); break;
case bpftrace::Parser::token::NE: expr_ = b_.CreateICmpNE (lhs, rhs); break;
......@@ -1276,52 +1284,10 @@ void CodegenLLVM::createLinearFunction()
b_.CreateRet(b_.CreateLoad(result_alloc));
}
// Defines the module-level "strcmp" helper: an internal-linkage function,
// marked AlwaysInline and placed in the "helpers" section, that reports
// whether two byte buffers are equal.
//
// The comparison is emitted fully unrolled at codegen time: one
// load/compare/branch triple per index for every i in [0, STRING_SIZE).
// The only early exit is a mismatching byte; a NUL terminator does not stop
// the comparison.
void CodegenLLVM::createStrcmpFunction()
{
// Returns 1 if strings match, 0 otherwise
// i1 strcmp(const char *s1, const char *s2)
// {
// for (int i=0; i<STRING_SIZE; i++)
// {
// if (s1[i] != s2[i]) return 0;
// }
// return 1;
// }
// Signature: i1 (i8*, i8*), non-variadic.
FunctionType *strcmp_func_type = FunctionType::get(b_.getInt1Ty(), {b_.getInt8PtrTy(), b_.getInt8PtrTy()}, false);
Function *strcmp_func = Function::Create(strcmp_func_type, Function::InternalLinkage, "strcmp", module_.get());
strcmp_func->addFnAttr(Attribute::AlwaysInline);
strcmp_func->setSection("helpers");
// not_equal_block is the single shared exit taken on the first mismatch.
BasicBlock *entry = BasicBlock::Create(module_->getContext(), "strcmp.entry", strcmp_func);
BasicBlock *not_equal_block = BasicBlock::Create(module_->getContext(), "strcmp.not_equal", strcmp_func);
b_.SetInsertPoint(entry);
Value *s1 = strcmp_func->arg_begin();
Value *s2 = strcmp_func->arg_begin()+1;
// This is a C++ loop, not an IR loop: each iteration appends the code for one
// byte position and then moves the insert point into a fresh continue block.
for (int i=0; i<STRING_SIZE; i++)
{
Value *s1_char = b_.CreateGEP(s1, {b_.getInt64(i)});
Value *s2_char = b_.CreateGEP(s2, {b_.getInt64(i)});
BasicBlock *continue_block = BasicBlock::Create(module_->getContext(), "strcmp.continue", strcmp_func);
Value *cmp = b_.CreateICmpNE(b_.CreateLoad(s1_char), b_.CreateLoad(s2_char));
b_.CreateCondBr(cmp, not_equal_block, continue_block);
b_.SetInsertPoint(continue_block);
}
// Reached only when every compared byte was equal.
b_.CreateRet(b_.getInt1(1));
// Mismatch exit.
b_.SetInsertPoint(not_equal_block);
b_.CreateRet(b_.getInt1(0));
}
std::unique_ptr<BpfOrc> CodegenLLVM::compile(DebugLevel debug, std::ostream &out)
{
createLog2Function();
createLinearFunction();
createStrcmpFunction();
root_->accept(*this);
LLVMInitializeBPFTargetInfo();
......
......@@ -54,7 +54,6 @@ public:
void createLog2Function();
void createLinearFunction();
void createStrcmpFunction();
std::unique_ptr<BpfOrc> compile(DebugLevel debug=DebugLevel::kNone, std::ostream &out=std::cerr);
private:
......
......@@ -323,6 +323,39 @@ Value *IRBuilderBPF::CreateUSDTReadArgument(Value *ctx, AttachPoint *attach_poin
return result;
}
// Emits an unrolled, byte-by-byte comparison of the string located at `val`
// against the literal `str`, including the literal's NUL terminator.
//
// val      value holding the address of the string to compare; each byte is
//          fetched with CreateProbeRead
// str      literal to compare against; only the bytes up to its first NUL
//          (plus that NUL) take part in the comparison
// inverse  when true the result is negated, i.e. the value is true when the
//          strings differ
//
// Returns the i8 value loaded from the result slot: !inverse if every byte
// matched, inverse otherwise.
Value *IRBuilderBPF::CreateStrcmp(Value* val, std::string str, bool inverse) {
  Function *parent = GetInsertBlock()->getParent();
  // Join block: reached from the first mismatching byte, or by falling out of
  // the last comparison.
  BasicBlock *str_ne = BasicBlock::Create(module_.getContext(), "strcmp.false", parent);

  // Result slot, pre-loaded with the "strings differ" answer. The slot is an
  // i8, so i8 constants are stored (the i1 constants used previously did not
  // match the slot's element type).
  AllocaInst *store = CreateAllocaBPF(getInt8Ty(), "strcmp.result");
  CreateStore(getInt8(inverse ? 1 : 0), store);

  const char *c_str = str.c_str();
  // +1 so the terminating NUL is compared too; otherwise a longer string that
  // merely starts with the literal would match.
  const size_t len = strlen(c_str) + 1;
  for (size_t i = 0; i < len; i++)
  {
    BasicBlock *char_eq = BasicBlock::Create(module_.getContext(), "strcmp.loop", parent);
    AllocaInst *val_char = CreateAllocaBPF(getInt8Ty(), "strcmp.char");
    // NOTE(review): `val` is offset with an integer add; when `val` is a
    // pointer a GEP would be the well-typed form. Left as is because the
    // expected IR in the codegen tests pins the `add`.
    Value *ptr = CreateAdd(
        val,
        getInt64(i));
    // NOTE(review): val_char is a single i8 slot but 8 is passed as the read
    // size. If CreateProbeRead takes a size in bytes this reads past the
    // slot; confirm against its definition. Left unchanged because the
    // expected IR in the codegen tests pins `i64 8`.
    CreateProbeRead(val_char, 8, ptr);
    Value *l = CreateLoad(getInt8Ty(), val_char);
    // The result slot's lifetime is deliberately NOT ended here: it is still
    // written after the loop and read in the join block.
    Value *r = getInt8(c_str[i]);
    Value *cmp = CreateICmpNE(l, r, "strcmp.cmp");
    CreateCondBr(cmp, str_ne, char_eq);
    SetInsertPoint(char_eq);
  }
  // Every byte, including the terminator, matched: flip the result.
  CreateStore(getInt8(inverse ? 0 : 1), store);
  CreateBr(str_ne);

  SetInsertPoint(str_ne);
  Value *result = CreateLoad(store);
  CreateLifetimeEnd(store);
  return result;
}
CallInst *IRBuilderBPF::CreateGetNs()
{
// u64 ktime_get_ns()
......
......@@ -34,6 +34,7 @@ public:
CallInst *CreateProbeReadStr(AllocaInst *dst, size_t size, Value *src);
CallInst *CreateProbeReadStr(Value *dst, size_t size, Value *src);
Value *CreateUSDTReadArgument(Value *ctx, AttachPoint *attach_point, int arg_name, Builtin &builtin);
// Emits an unrolled byte-by-byte comparison of the string at `val` against
// the literal `str` (including its NUL terminator). The produced value is
// true when the strings match, or when they differ if `inverse` is set.
Value *CreateStrcmp(Value* val, std::string str, bool inverse=false);
CallInst *CreateGetNs();
CallInst *CreateGetPidTgid();
CallInst *CreateGetUidGid();
......
......@@ -382,6 +382,10 @@ void SemanticAnalyser::visit(Binop &binop)
err_ << "comparing '" << lhs << "' ";
err_ << "with '" << rhs << "'" << std::endl;
}
else if (lhs == Type::string && !(binop.left->is_literal || binop.right->is_literal)) {
err_ << "Comparison between two variables of ";
err_ << "type string is not allowed" << std::endl;
}
else if (lhs != Type::integer &&
binop.op != Parser::token::EQ &&
binop.op != Parser::token::NE) {
......
......@@ -2152,6 +2152,260 @@ attributes #1 = { argmemonly nounwind }
)EXPECTED");
}
// Checks the IR generated for a `comm == "sshd"` predicate. The comparison is
// expected to unroll into five single-character checks ('s', 's', 'h', 'd'
// and the terminating NUL), each fetching one character through probe_read
// into its own strcmp.char slot. Any mismatch leads to pred_false; only the
// final NUL check can reach pred_true.
// NOTE(review): the expected IR pins a probe_read size argument of `i64 8`
// for each one-byte strcmp.char slot and an `add` on a `[16 x i8]*` operand.
TEST(codegen, string_equal_comparison)
{
test("kretprobe:vfs_read /comm == \"sshd\"/ { @[comm] = count(); }",
R"EXPECTED(; Function Attrs: nounwind
declare i64 @llvm.bpf.pseudo(i64, i64) #0
; Function Attrs: argmemonly nounwind
declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture) #1
define i64 @"kretprobe:vfs_read"(i8* nocapture readnone) local_unnamed_addr section "s_kretprobe:vfs_read_1" {
entry:
%"@_val" = alloca i64, align 8
%comm17 = alloca [16 x i8], align 1
%"@_key" = alloca [16 x i8], align 1
%strcmp.char14 = alloca i8, align 1
%strcmp.char10 = alloca i8, align 1
%strcmp.char6 = alloca i8, align 1
%strcmp.char2 = alloca i8, align 1
%strcmp.char = alloca i8, align 1
%comm = alloca [16 x i8], align 1
%1 = getelementptr inbounds [16 x i8], [16 x i8]* %comm, i64 0, i64 0
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %1)
call void @llvm.memset.p0i8.i64(i8* nonnull %1, i8 0, i64 16, i32 1, i1 false)
%get_comm = call i64 inttoptr (i64 16 to i64 (i8*, i64)*)([16 x i8]* nonnull %comm, i64 16)
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %strcmp.char)
%probe_read = call i64 inttoptr (i64 4 to i64 (i8*, i64, i8*)*)(i8* nonnull %strcmp.char, i64 8, [16 x i8]* nonnull %comm)
%2 = load i8, i8* %strcmp.char, align 1
%strcmp.cmp = icmp eq i8 %2, 115
br i1 %strcmp.cmp, label %strcmp.loop, label %pred_false.critedge
pred_false.critedge: ; preds = %entry
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %1)
br label %pred_false
pred_false.critedge20: ; preds = %strcmp.loop
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %1)
br label %pred_false
pred_false.critedge21: ; preds = %strcmp.loop1
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %1)
br label %pred_false
pred_false.critedge22: ; preds = %strcmp.loop5
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %1)
br label %pred_false
pred_false: ; preds = %strcmp.loop9, %pred_false.critedge22, %pred_false.critedge21, %pred_false.critedge20, %pred_false.critedge
ret i64 0
pred_true: ; preds = %strcmp.loop9
%3 = getelementptr inbounds [16 x i8], [16 x i8]* %"@_key", i64 0, i64 0
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %3)
%4 = getelementptr inbounds [16 x i8], [16 x i8]* %comm17, i64 0, i64 0
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %4)
call void @llvm.memset.p0i8.i64(i8* nonnull %4, i8 0, i64 16, i32 1, i1 false)
%get_comm18 = call i64 inttoptr (i64 16 to i64 (i8*, i64)*)([16 x i8]* nonnull %comm17, i64 16)
call void @llvm.memcpy.p0i8.p0i8.i64(i8* nonnull %3, i8* nonnull %4, i64 16, i32 1, i1 false)
%pseudo = call i64 @llvm.bpf.pseudo(i64 1, i64 1)
%lookup_elem = call i8* inttoptr (i64 1 to i8* (i8*, i8*)*)(i64 %pseudo, [16 x i8]* nonnull %"@_key")
%map_lookup_cond = icmp eq i8* %lookup_elem, null
br i1 %map_lookup_cond, label %lookup_merge, label %lookup_success
strcmp.loop: ; preds = %entry
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %strcmp.char2)
%5 = add [16 x i8]* %comm, i64 1
%probe_read3 = call i64 inttoptr (i64 4 to i64 (i8*, i64, i8*)*)(i8* nonnull %strcmp.char2, i64 8, [16 x i8]* %5)
%6 = load i8, i8* %strcmp.char2, align 1
%strcmp.cmp4 = icmp eq i8 %6, 115
br i1 %strcmp.cmp4, label %strcmp.loop1, label %pred_false.critedge20
strcmp.loop1: ; preds = %strcmp.loop
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %strcmp.char6)
%7 = add [16 x i8]* %comm, i64 2
%probe_read7 = call i64 inttoptr (i64 4 to i64 (i8*, i64, i8*)*)(i8* nonnull %strcmp.char6, i64 8, [16 x i8]* %7)
%8 = load i8, i8* %strcmp.char6, align 1
%strcmp.cmp8 = icmp eq i8 %8, 104
br i1 %strcmp.cmp8, label %strcmp.loop5, label %pred_false.critedge21
strcmp.loop5: ; preds = %strcmp.loop1
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %strcmp.char10)
%9 = add [16 x i8]* %comm, i64 3
%probe_read11 = call i64 inttoptr (i64 4 to i64 (i8*, i64, i8*)*)(i8* nonnull %strcmp.char10, i64 8, [16 x i8]* %9)
%10 = load i8, i8* %strcmp.char10, align 1
%strcmp.cmp12 = icmp eq i8 %10, 100
br i1 %strcmp.cmp12, label %strcmp.loop9, label %pred_false.critedge22
strcmp.loop9: ; preds = %strcmp.loop5
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %strcmp.char14)
%11 = add [16 x i8]* %comm, i64 4
%probe_read15 = call i64 inttoptr (i64 4 to i64 (i8*, i64, i8*)*)(i8* nonnull %strcmp.char14, i64 8, [16 x i8]* %11)
%12 = load i8, i8* %strcmp.char14, align 1
%strcmp.cmp16 = icmp eq i8 %12, 0
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %1)
br i1 %strcmp.cmp16, label %pred_true, label %pred_false
lookup_success: ; preds = %pred_true
%13 = load i64, i8* %lookup_elem, align 8
%phitmp = add i64 %13, 1
br label %lookup_merge
lookup_merge: ; preds = %pred_true, %lookup_success
%lookup_elem_val.0 = phi i64 [ %phitmp, %lookup_success ], [ 1, %pred_true ]
%14 = bitcast i64* %"@_val" to i8*
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %14)
store i64 %lookup_elem_val.0, i64* %"@_val", align 8
%pseudo19 = call i64 @llvm.bpf.pseudo(i64 1, i64 1)
%update_elem = call i64 inttoptr (i64 2 to i64 (i8*, i8*, i8*, i64)*)(i64 %pseudo19, [16 x i8]* nonnull %"@_key", i64* nonnull %"@_val", i64 0)
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %3)
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %14)
ret i64 0
}
; Function Attrs: argmemonly nounwind
declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i32, i1) #1
; Function Attrs: argmemonly nounwind
declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture) #1
; Function Attrs: argmemonly nounwind
declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i32, i1) #1
attributes #0 = { nounwind }
attributes #1 = { argmemonly nounwind }
)EXPECTED");
}
// Checks the IR generated for a `comm != "sshd"` predicate. The same five
// single-character checks ('s', 's', 'h', 'd' and the terminating NUL) are
// expected as for the equality case, but with the outcome inverted: any
// mismatch leads to pred_true, and only a match on every character,
// including the NUL, reaches pred_false.
// NOTE(review): the expected IR pins a probe_read size argument of `i64 8`
// for each one-byte strcmp.char slot and an `add` on a `[16 x i8]*` operand.
TEST(codegen, string_not_equal_comparison)
{
test("kretprobe:vfs_read /comm != \"sshd\"/ { @[comm] = count(); }",
R"EXPECTED(; Function Attrs: nounwind
declare i64 @llvm.bpf.pseudo(i64, i64) #0
; Function Attrs: argmemonly nounwind
declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture) #1
define i64 @"kretprobe:vfs_read"(i8* nocapture readnone) local_unnamed_addr section "s_kretprobe:vfs_read_1" {
entry:
%"@_val" = alloca i64, align 8
%comm17 = alloca [16 x i8], align 1
%"@_key" = alloca [16 x i8], align 1
%strcmp.char14 = alloca i8, align 1
%strcmp.char10 = alloca i8, align 1
%strcmp.char6 = alloca i8, align 1
%strcmp.char2 = alloca i8, align 1
%strcmp.char = alloca i8, align 1
%comm = alloca [16 x i8], align 1
%1 = getelementptr inbounds [16 x i8], [16 x i8]* %comm, i64 0, i64 0
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %1)
call void @llvm.memset.p0i8.i64(i8* nonnull %1, i8 0, i64 16, i32 1, i1 false)
%get_comm = call i64 inttoptr (i64 16 to i64 (i8*, i64)*)([16 x i8]* nonnull %comm, i64 16)
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %strcmp.char)
%probe_read = call i64 inttoptr (i64 4 to i64 (i8*, i64, i8*)*)(i8* nonnull %strcmp.char, i64 8, [16 x i8]* nonnull %comm)
%2 = load i8, i8* %strcmp.char, align 1
%strcmp.cmp = icmp eq i8 %2, 115
br i1 %strcmp.cmp, label %strcmp.loop, label %pred_true.critedge
pred_false: ; preds = %strcmp.loop9
ret i64 0
pred_true.critedge: ; preds = %entry
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %1)
br label %pred_true
pred_true.critedge20: ; preds = %strcmp.loop
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %1)
br label %pred_true
pred_true.critedge21: ; preds = %strcmp.loop1
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %1)
br label %pred_true
pred_true.critedge22: ; preds = %strcmp.loop5
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %1)
br label %pred_true
pred_true: ; preds = %strcmp.loop9, %pred_true.critedge22, %pred_true.critedge21, %pred_true.critedge20, %pred_true.critedge
%3 = getelementptr inbounds [16 x i8], [16 x i8]* %"@_key", i64 0, i64 0
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %3)
%4 = getelementptr inbounds [16 x i8], [16 x i8]* %comm17, i64 0, i64 0
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %4)
call void @llvm.memset.p0i8.i64(i8* nonnull %4, i8 0, i64 16, i32 1, i1 false)
%get_comm18 = call i64 inttoptr (i64 16 to i64 (i8*, i64)*)([16 x i8]* nonnull %comm17, i64 16)
call void @llvm.memcpy.p0i8.p0i8.i64(i8* nonnull %3, i8* nonnull %4, i64 16, i32 1, i1 false)
%pseudo = call i64 @llvm.bpf.pseudo(i64 1, i64 1)
%lookup_elem = call i8* inttoptr (i64 1 to i8* (i8*, i8*)*)(i64 %pseudo, [16 x i8]* nonnull %"@_key")
%map_lookup_cond = icmp eq i8* %lookup_elem, null
br i1 %map_lookup_cond, label %lookup_merge, label %lookup_success
strcmp.loop: ; preds = %entry
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %strcmp.char2)
%5 = add [16 x i8]* %comm, i64 1
%probe_read3 = call i64 inttoptr (i64 4 to i64 (i8*, i64, i8*)*)(i8* nonnull %strcmp.char2, i64 8, [16 x i8]* %5)
%6 = load i8, i8* %strcmp.char2, align 1
%strcmp.cmp4 = icmp eq i8 %6, 115
br i1 %strcmp.cmp4, label %strcmp.loop1, label %pred_true.critedge20
strcmp.loop1: ; preds = %strcmp.loop
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %strcmp.char6)
%7 = add [16 x i8]* %comm, i64 2
%probe_read7 = call i64 inttoptr (i64 4 to i64 (i8*, i64, i8*)*)(i8* nonnull %strcmp.char6, i64 8, [16 x i8]* %7)
%8 = load i8, i8* %strcmp.char6, align 1
%strcmp.cmp8 = icmp eq i8 %8, 104
br i1 %strcmp.cmp8, label %strcmp.loop5, label %pred_true.critedge21
strcmp.loop5: ; preds = %strcmp.loop1
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %strcmp.char10)
%9 = add [16 x i8]* %comm, i64 3
%probe_read11 = call i64 inttoptr (i64 4 to i64 (i8*, i64, i8*)*)(i8* nonnull %strcmp.char10, i64 8, [16 x i8]* %9)
%10 = load i8, i8* %strcmp.char10, align 1
%strcmp.cmp12 = icmp eq i8 %10, 100
br i1 %strcmp.cmp12, label %strcmp.loop9, label %pred_true.critedge22
strcmp.loop9: ; preds = %strcmp.loop5
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %strcmp.char14)
%11 = add [16 x i8]* %comm, i64 4
%probe_read15 = call i64 inttoptr (i64 4 to i64 (i8*, i64, i8*)*)(i8* nonnull %strcmp.char14, i64 8, [16 x i8]* %11)
%12 = load i8, i8* %strcmp.char14, align 1
%strcmp.cmp16 = icmp eq i8 %12, 0
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %1)
br i1 %strcmp.cmp16, label %pred_false, label %pred_true
lookup_success: ; preds = %pred_true
%13 = load i64, i8* %lookup_elem, align 8
%phitmp = add i64 %13, 1
br label %lookup_merge
lookup_merge: ; preds = %pred_true, %lookup_success
%lookup_elem_val.0 = phi i64 [ %phitmp, %lookup_success ], [ 1, %pred_true ]
%14 = bitcast i64* %"@_val" to i8*
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %14)
store i64 %lookup_elem_val.0, i64* %"@_val", align 8
%pseudo19 = call i64 @llvm.bpf.pseudo(i64 1, i64 1)
%update_elem = call i64 inttoptr (i64 2 to i64 (i8*, i8*, i8*, i64)*)(i64 %pseudo19, [16 x i8]* nonnull %"@_key", i64* nonnull %"@_val", i64 0)
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %3)
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %14)
ret i64 0
}
; Function Attrs: argmemonly nounwind
declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i32, i1) #1
; Function Attrs: argmemonly nounwind
declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture) #1
; Function Attrs: argmemonly nounwind
declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i32, i1) #1
attributes #0 = { nounwind }
attributes #1 = { argmemonly nounwind }
)EXPECTED");
}
TEST(codegen, pred_binop)
{
test("kprobe:f / pid == 1234 / { @x = 1 }",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment