Commit e5f633f2 authored by Gaspar's avatar Gaspar Committed by Alastair Robertson

if-else statements (#114)

fixes: https://github.com/iovisor/bpftrace/issues/8
parent 7aa1bb17
......@@ -399,6 +399,32 @@ open path: retrans_time_ms
Some kernel headers needed to be included to understand the `path` and `dentry` structs.
## 5. `? :`: ternary operators
Example:
```
# bpftrace -e 'tracepoint:syscalls:sys_exit_read { @error[args->ret < 0 ? - args->ret : 0] = count(); }'
Attaching 1 probe...
^C
@error[11]: 24
@error[0]: 78
```
## 6. `if () {...} else {...}`: if-else statements
Example:
```
# bpftrace -e 'tracepoint:syscalls:sys_enter_read { @reads = count(); if (args->count > 1024) { @large = count(); } }'
Attaching 1 probe...
^C
@large: 72
@reads: 80
```
# Probes
- `kprobe` - kernel function start
......
......@@ -68,6 +68,10 @@ void AttachPoint::accept(Visitor &v) {
v.visit(*this);
}
void If::accept(Visitor &v) {
v.visit(*this);
}
void Probe::accept(Visitor &v) {
v.visit(*this);
}
......
......@@ -155,6 +155,18 @@ public:
void accept(Visitor &v) override;
};
class If : public Statement {
public:
If(Expression *cond, StatementList *stmts) : cond(cond), stmts(stmts) { }
If(Expression *cond, StatementList *stmts, StatementList *else_stmts)
: cond(cond), stmts(stmts), else_stmts(else_stmts) { }
Expression *cond;
StatementList *stmts = nullptr;
StatementList *else_stmts = nullptr;
void accept(Visitor &v) override;
};
class Predicate : public Node {
public:
explicit Predicate(Expression *expr) : expr(expr) { }
......@@ -249,6 +261,7 @@ public:
virtual void visit(ExprStatement &expr) = 0;
virtual void visit(AssignMapStatement &assignment) = 0;
virtual void visit(AssignVarStatement &assignment) = 0;
virtual void visit(If &if_block) = 0;
virtual void visit(Predicate &pred) = 0;
virtual void visit(AttachPoint &ap) = 0;
virtual void visit(Probe &probe) = 0;
......
......@@ -604,7 +604,14 @@ void CodegenLLVM::visit(Map &map)
void CodegenLLVM::visit(Variable &var)
{
expr_ = variables_[var.ident];
if (!var.type.IsArray())
{
expr_ = b_.CreateLoad(variables_[var.ident]);
}
else
{
expr_ = variables_[var.ident];
}
}
void CodegenLLVM::visit(Binop &binop)
......@@ -900,7 +907,59 @@ void CodegenLLVM::visit(AssignVarStatement &assignment)
Variable &var = *assignment.var;
assignment.expr->accept(*this);
variables_[var.ident] = expr_;
if (variables_.find(var.ident) == variables_.end())
{
AllocaInst *val = b_.CreateAllocaBPFInit(var.type, var.ident);
variables_[var.ident] = val;
}
if (!var.type.IsArray())
{
b_.CreateStore(expr_, variables_[var.ident]);
}
else
{
b_.CreateMemCpy(variables_[var.ident], expr_, var.type.size, 1);
}
}
void CodegenLLVM::visit(If &if_block)
{
Function *parent = b_.GetInsertBlock()->getParent();
BasicBlock *if_true = BasicBlock::Create(module_->getContext(), "if_stmt", parent);
BasicBlock *if_false = BasicBlock::Create(module_->getContext(), "else_stmt", parent);
if_block.cond->accept(*this);
Value *cond = expr_;
b_.CreateCondBr(b_.CreateICmpNE(cond, b_.getInt64(0), "true_cond"), if_true, if_false);
b_.SetInsertPoint(if_true);
for (Statement *stmt : *if_block.stmts)
{
stmt->accept(*this);
}
if (if_block.else_stmts)
{
BasicBlock *done = BasicBlock::Create(module_->getContext(), "done", parent);
b_.CreateBr(done);
b_.SetInsertPoint(if_false);
for (Statement *stmt : *if_block.else_stmts)
{
stmt->accept(*this);
}
b_.CreateBr(done);
b_.SetInsertPoint(done);
}
else
{
b_.CreateBr(if_false);
b_.SetInsertPoint(if_false);
}
}
void CodegenLLVM::visit(Predicate &pred)
......@@ -965,6 +1024,7 @@ void CodegenLLVM::visit(Probe &probe)
if (probe.pred) {
probe.pred->accept(*this);
}
for (Statement *stmt : *probe.stmts) {
stmt->accept(*this);
}
......
......@@ -41,6 +41,7 @@ public:
void visit(ExprStatement &expr) override;
void visit(AssignMapStatement &assignment) override;
void visit(AssignVarStatement &assignment) override;
void visit(If &if_block) override;
void visit(Predicate &pred) override;
void visit(AttachPoint &ap) override;
void visit(Probe &probe) override;
......@@ -70,7 +71,7 @@ private:
std::string probefull_;
std::map<std::string, int> next_probe_index_;
std::map<std::string, Value *> variables_;
std::map<std::string, AllocaInst *> variables_;
int printf_id_ = 0;
int time_id_ = 0;
int system_id_ = 0;
......
......@@ -58,6 +58,35 @@ AllocaInst *IRBuilderBPF::CreateAllocaBPF(const SizedType &stype, const std::str
return CreateAllocaBPF(ty, nullptr, name);
}
AllocaInst *IRBuilderBPF::CreateAllocaBPFInit(const SizedType &stype, const std::string &name)
{
Function *parent = GetInsertBlock()->getParent();
BasicBlock &entry_block = parent->getEntryBlock();
auto ip = saveIP();
if (entry_block.empty())
SetInsertPoint(&entry_block);
else
SetInsertPoint(&entry_block.front());
llvm::Type *ty = GetType(stype);
AllocaInst *alloca = CreateAllocaBPF(ty, nullptr, name);
if (!stype.IsArray())
{
CreateStore(getInt64(0), alloca);
}
else
{
CreateMemSet(alloca, getInt64(0), stype.size, 1);
}
restoreIP(ip);
CreateLifetimeStart(alloca);
return alloca;
}
AllocaInst *IRBuilderBPF::CreateAllocaBPF(const SizedType &stype, llvm::Value *arraysize, const std::string &name)
{
llvm::Type *ty = GetType(stype);
......
......@@ -21,6 +21,7 @@ public:
AllocaInst *CreateAllocaBPF(llvm::Type *ty, const std::string &name="");
AllocaInst *CreateAllocaBPF(const SizedType &stype, const std::string &name="");
AllocaInst *CreateAllocaBPFInit(const SizedType &stype, const std::string &name);
AllocaInst *CreateAllocaBPF(llvm::Type *ty, llvm::Value *arraysize, const std::string &name="");
AllocaInst *CreateAllocaBPF(const SizedType &stype, llvm::Value *arraysize, const std::string &name="");
AllocaInst *CreateAllocaBPF(int bytes, const std::string &name="");
......
......@@ -150,6 +150,31 @@ void Printer::visit(AssignVarStatement &assignment)
--depth_;
}
void Printer::visit(If &if_block)
{
std::string indent(depth_, ' ');
out_ << indent << "if" << std::endl;
++depth_;
if_block.cond->accept(*this);
++depth_;
out_ << indent << " then" << std::endl;
for (Statement *stmt : *if_block.stmts) {
stmt->accept(*this);
}
if (if_block.else_stmts) {
out_ << indent << " else" << std::endl;
for (Statement *stmt : *if_block.else_stmts) {
stmt->accept(*this);
}
}
depth_ -= 2;
}
void Printer::visit(Predicate &pred)
{
std::string indent(depth_, ' ');
......
......@@ -24,6 +24,7 @@ public:
void visit(ExprStatement &expr) override;
void visit(AssignMapStatement &assignment) override;
void visit(AssignVarStatement &assignment) override;
void visit(If &if_block) override;
void visit(Predicate &pred) override;
void visit(AttachPoint &ap) override;
void visit(Probe &probe) override;
......
......@@ -480,6 +480,21 @@ void SemanticAnalyser::visit(Ternary &ternary)
}
}
void SemanticAnalyser::visit(If &if_block)
{
if_block.cond->accept(*this);
for (Statement *stmt : *if_block.stmts) {
stmt->accept(*this);
}
if (if_block.else_stmts) {
for (Statement *stmt : *if_block.else_stmts) {
stmt->accept(*this);
}
}
}
void SemanticAnalyser::visit(FieldAccess &acc)
{
acc.expr->accept(*this);
......@@ -596,6 +611,7 @@ void SemanticAnalyser::visit(AssignVarStatement &assignment)
std::string var_ident = assignment.var->ident;
auto search = variable_val_.find(var_ident);
assignment.var->type = assignment.expr->type;
if (search != variable_val_.end()) {
if (search->second.type == Type::none) {
if (is_final_pass()) {
......@@ -605,7 +621,7 @@ void SemanticAnalyser::visit(AssignVarStatement &assignment)
search->second = assignment.expr->type;
}
}
else if (search->second.type != assignment.expr->type.type) {
else if (search->second.type != assignment.expr->type.type || search->second.size != assignment.expr->type.size) {
err_ << "Type mismatch for " << var_ident << ": ";
err_ << "trying to assign value of type '" << assignment.expr->type;
err_ << "'\n\twhen variable already contains a value of type '";
......
......@@ -32,6 +32,7 @@ public:
void visit(ExprStatement &expr) override;
void visit(AssignMapStatement &assignment) override;
void visit(AssignVarStatement &assignment) override;
void visit(If &if_block) override;
void visit(Predicate &pred) override;
void visit(AttachPoint &ap) override;
void visit(Probe &probe) override;
......
......@@ -88,6 +88,8 @@ pid|tid|uid|gid|nsecs|cpu|comm|stack|ustack|arg[0-9]|retval|func|name|curtask|ra
"." { return Parser::make_DOT(loc); }
"->" { return Parser::make_PTR(loc); }
"#".* { return Parser::make_CPREPROC(yytext, loc); }
"if" { return Parser::make_IF(yytext, loc); }
"else" { return Parser::make_ELSE(yytext, loc); }
"?" { return Parser::make_QUES(loc); }
\" { BEGIN(STR); buffer.clear(); }
......
......@@ -80,6 +80,8 @@ void yyerror(bpftrace::Driver &driver, const char *s);
%token <std::string> STRING "string"
%token <std::string> MAP "map"
%token <std::string> VAR "variable"
%nonassoc <std::string> IF "if"
%nonassoc <std::string> ELSE "else"
%token <long> INT "integer"
%type <std::string> c_definitions
......@@ -168,6 +170,8 @@ stmts : stmts ";" stmt { $$ = $1; $1->push_back($3); }
stmt : expr { $$ = new ast::ExprStatement($1); }
| map "=" expr { $$ = new ast::AssignMapStatement($1, $3); }
| var "=" expr { $$ = new ast::AssignVarStatement($1, $3); }
| IF "(" expr ")" block { $$ = new ast::If($3, $5); }
| IF "(" expr ")" block ELSE block { $$ = new ast::If($3, $5, $7); }
;
expr : INT { $$ = new ast::Integer($1); }
......
This diff is collapsed.
......@@ -296,6 +296,59 @@ TEST(Parser, ternary_int)
" int: 2\n");
}
TEST(Parser, if_block)
{
test("kprobe:sys_open { if (pid > 10000) { printf(\"%d is high\\n\", pid); } }",
"Program\n"
" kprobe:sys_open\n"
" if\n"
" >\n"
" builtin: pid\n"
" int: 10000\n"
" then\n"
" call: printf\n"
" string: %d is high\\n\n"
" builtin: pid\n");
}
TEST(Parser, if_block_variable)
{
test("kprobe:sys_open { if (pid > 10000) { $s = 10; } }",
"Program\n"
" kprobe:sys_open\n"
" if\n"
" >\n"
" builtin: pid\n"
" int: 10000\n"
" then\n"
" =\n"
" variable: $s\n"
" int: 10\n");
}
TEST(Parser, if_else)
{
test("kprobe:sys_open { if (pid > 10000) { $s = \"a\"; } else { $s= \"b\"; }; printf(\"%d is high\\n\", pid, $s); }",
"Program\n"
" kprobe:sys_open\n"
" if\n"
" >\n"
" builtin: pid\n"
" int: 10000\n"
" then\n"
" =\n"
" variable: $s\n"
" string: a\n"
" else\n"
" =\n"
" variable: $s\n"
" string: b\n"
" call: printf\n"
" string: %d is high\\n\n"
" builtin: pid\n"
" variable: $s\n");
}
TEST(Parser, ternary_str)
{
test("kprobe:sys_open { @x = pid < 10000 ? \"lo\" : \"high\" }",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment