Commit 9c38f4d1 authored by Matheus Marchini's avatar Matheus Marchini

fix string comparison

This patch fixes strcmp by reimplementing it as a IRBuilder method and
unrolling the loop based on the literal string size instead of
STRING_SIZE.

Fixes: https://github.com/iovisor/bpftrace/issues/14
parent 0c89b18e
......@@ -617,33 +617,41 @@ void CodegenLLVM::visit(Binop &binop)
return;
}
Value *lhs, *rhs;
binop.left->accept(*this);
lhs = expr_;
binop.right->accept(*this);
rhs = expr_;
Type &type = binop.left->type.type;
if (type == Type::string)
{
Function *strcmp_func = module_->getFunction("strcmp");
Value *val;
std::string string_literal("");
if (binop.right->is_literal) {
binop.left->accept(*this);
val = expr_;
string_literal = reinterpret_cast<String*>(binop.right)->str;
} else {
binop.right->accept(*this);
val = expr_;
string_literal = reinterpret_cast<String*>(binop.left)->str;
}
switch (binop.op) {
case bpftrace::Parser::token::EQ:
expr_ = b_.CreateCall(strcmp_func, {lhs, rhs}, "strcmp");
expr_ = b_.CreateStrcmp(val, string_literal);
break;
case bpftrace::Parser::token::NE:
expr_ = b_.CreateNot(b_.CreateCall(strcmp_func, {lhs, rhs}, "strcmp"));
expr_ = b_.CreateStrcmp(val, string_literal, true);
break;
default:
abort();
}
if (!binop.left->is_variable)
b_.CreateLifetimeEnd(lhs);
if (!binop.right->is_variable)
b_.CreateLifetimeEnd(rhs);
b_.CreateLifetimeEnd(val);
}
else
{
Value *lhs, *rhs;
binop.left->accept(*this);
lhs = expr_;
binop.right->accept(*this);
rhs = expr_;
switch (binop.op) {
case bpftrace::Parser::token::EQ: expr_ = b_.CreateICmpEQ (lhs, rhs); break;
case bpftrace::Parser::token::NE: expr_ = b_.CreateICmpNE (lhs, rhs); break;
......@@ -1276,52 +1284,10 @@ void CodegenLLVM::createLinearFunction()
b_.CreateRet(b_.CreateLoad(result_alloc));
}
void CodegenLLVM::createStrcmpFunction()
{
// Returns 1 if strings match, 0 otherwise
// i1 strcmp(const char *s1, const char *s2)
// {
// for (int i=0; i<STRING_SIZE; i++)
// {
// if (s1[i] != s2[i]) return 0;
// }
// return 1;
// }
FunctionType *strcmp_func_type = FunctionType::get(b_.getInt1Ty(), {b_.getInt8PtrTy(), b_.getInt8PtrTy()}, false);
Function *strcmp_func = Function::Create(strcmp_func_type, Function::InternalLinkage, "strcmp", module_.get());
strcmp_func->addFnAttr(Attribute::AlwaysInline);
strcmp_func->setSection("helpers");
BasicBlock *entry = BasicBlock::Create(module_->getContext(), "strcmp.entry", strcmp_func);
BasicBlock *not_equal_block = BasicBlock::Create(module_->getContext(), "strcmp.not_equal", strcmp_func);
b_.SetInsertPoint(entry);
Value *s1 = strcmp_func->arg_begin();
Value *s2 = strcmp_func->arg_begin()+1;
for (int i=0; i<STRING_SIZE; i++)
{
Value *s1_char = b_.CreateGEP(s1, {b_.getInt64(i)});
Value *s2_char = b_.CreateGEP(s2, {b_.getInt64(i)});
BasicBlock *continue_block = BasicBlock::Create(module_->getContext(), "strcmp.continue", strcmp_func);
Value *cmp = b_.CreateICmpNE(b_.CreateLoad(s1_char), b_.CreateLoad(s2_char));
b_.CreateCondBr(cmp, not_equal_block, continue_block);
b_.SetInsertPoint(continue_block);
}
b_.CreateRet(b_.getInt1(1));
b_.SetInsertPoint(not_equal_block);
b_.CreateRet(b_.getInt1(0));
}
std::unique_ptr<BpfOrc> CodegenLLVM::compile(DebugLevel debug, std::ostream &out)
{
createLog2Function();
createLinearFunction();
createStrcmpFunction();
root_->accept(*this);
LLVMInitializeBPFTargetInfo();
......
......@@ -54,7 +54,6 @@ public:
void createLog2Function();
void createLinearFunction();
void createStrcmpFunction();
std::unique_ptr<BpfOrc> compile(DebugLevel debug=DebugLevel::kNone, std::ostream &out=std::cerr);
private:
......
......@@ -323,6 +323,39 @@ Value *IRBuilderBPF::CreateUSDTReadArgument(Value *ctx, AttachPoint *attach_poin
return result;
}
Value *IRBuilderBPF::CreateStrcmp(Value* val, std::string str, bool inverse) {
Function *parent = GetInsertBlock()->getParent();
BasicBlock *str_ne = BasicBlock::Create(module_.getContext(), "strcmp.false", parent);
AllocaInst *store = CreateAllocaBPF(getInt8Ty(), "strcmp.result");
CreateStore(getInt1(inverse), store);
const char *c_str = str.c_str();
for (int i = 0; i < strlen(c_str) + 1; i++)
{
BasicBlock *char_eq = BasicBlock::Create(module_.getContext(), "strcmp.loop", parent);
AllocaInst *val_char = CreateAllocaBPF(getInt8Ty(), "strcmp.char");
Value *ptr = CreateAdd(
val,
getInt64(i));
CreateProbeRead(val_char, 8, ptr);
Value *l = CreateLoad(getInt8Ty(), val_char);
CreateLifetimeEnd(store);
Value *r = getInt8(c_str[i]);
Value *cmp = CreateICmpNE(l, r, "strcmp.cmp");
CreateCondBr(cmp, str_ne, char_eq);
SetInsertPoint(char_eq);
}
CreateStore(getInt1(!inverse), store);
CreateBr(str_ne);
SetInsertPoint(str_ne);
Value *result = CreateLoad(store);
CreateLifetimeEnd(store);
return result;
}
CallInst *IRBuilderBPF::CreateGetNs()
{
// u64 ktime_get_ns()
......
......@@ -34,6 +34,7 @@ public:
CallInst *CreateProbeReadStr(AllocaInst *dst, size_t size, Value *src);
CallInst *CreateProbeReadStr(Value *dst, size_t size, Value *src);
Value *CreateUSDTReadArgument(Value *ctx, AttachPoint *attach_point, int arg_name, Builtin &builtin);
Value *CreateStrcmp(Value* val, std::string str, bool inverse=false);
CallInst *CreateGetNs();
CallInst *CreateGetPidTgid();
CallInst *CreateGetUidGid();
......
......@@ -382,6 +382,10 @@ void SemanticAnalyser::visit(Binop &binop)
err_ << "comparing '" << lhs << "' ";
err_ << "with '" << rhs << "'" << std::endl;
}
else if (lhs == Type::string && !(binop.left->is_literal || binop.right->is_literal)) {
err_ << "Comparison between two variables of ";
err_ << "type string is not allowed" << std::endl;
}
else if (lhs != Type::integer &&
binop.op != Parser::token::EQ &&
binop.op != Parser::token::NE) {
......
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment