Commit 9198d61f authored by Brendan Gregg's avatar Brendan Gregg

add ternary operator

parent b57abde0
...@@ -36,6 +36,10 @@ void Unop::accept(Visitor &v) { ...@@ -36,6 +36,10 @@ void Unop::accept(Visitor &v) {
v.visit(*this); v.visit(*this);
} }
void Ternary::accept(Visitor &v) {
v.visit(*this);
}
void FieldAccess::accept(Visitor &v) { void FieldAccess::accept(Visitor &v) {
v.visit(*this); v.visit(*this);
} }
......
...@@ -160,6 +160,14 @@ public: ...@@ -160,6 +160,14 @@ public:
void accept(Visitor &v) override; void accept(Visitor &v) override;
}; };
class Ternary : public Expression {
public:
Ternary(Expression *cond, Expression *left, Expression *right) : cond(cond), left(left), right(right) { }
Expression *cond, *left, *right;
void accept(Visitor &v) override;
};
class AttachPoint : public Node { class AttachPoint : public Node {
public: public:
explicit AttachPoint(const std::string &provider) explicit AttachPoint(const std::string &provider)
...@@ -233,6 +241,7 @@ public: ...@@ -233,6 +241,7 @@ public:
virtual void visit(Variable &var) = 0; virtual void visit(Variable &var) = 0;
virtual void visit(Binop &binop) = 0; virtual void visit(Binop &binop) = 0;
virtual void visit(Unop &unop) = 0; virtual void visit(Unop &unop) = 0;
virtual void visit(Ternary &ternary) = 0;
virtual void visit(FieldAccess &acc) = 0; virtual void visit(FieldAccess &acc) = 0;
virtual void visit(Cast &cast) = 0; virtual void visit(Cast &cast) = 0;
virtual void visit(ExprStatement &expr) = 0; virtual void visit(ExprStatement &expr) = 0;
......
...@@ -583,6 +583,53 @@ void CodegenLLVM::visit(Unop &unop) ...@@ -583,6 +583,53 @@ void CodegenLLVM::visit(Unop &unop)
} }
} }
void CodegenLLVM::visit(Ternary &ternary)
{
Function *parent = b_.GetInsertBlock()->getParent();
BasicBlock *left_block = BasicBlock::Create(module_->getContext(), "left", parent);
BasicBlock *right_block = BasicBlock::Create(module_->getContext(), "right", parent);
BasicBlock *done = BasicBlock::Create(module_->getContext(), "done", parent);
// ordering of all the following statements is important
Value *result = b_.CreateAllocaBPF(ternary.type, "result");
AllocaInst *buf = b_.CreateAllocaBPF(ternary.type, "buf");
Value *cond;
ternary.cond->accept(*this);
cond = expr_;
b_.CreateCondBr(b_.CreateICmpNE(cond, b_.getInt64(0), "true_cond"),
left_block, right_block);
if (ternary.type.type == Type::integer) {
// fetch selected integer via CreateStore
b_.SetInsertPoint(left_block);
ternary.left->accept(*this);
b_.CreateStore(expr_, result);
b_.CreateBr(done);
b_.SetInsertPoint(right_block);
ternary.right->accept(*this);
b_.CreateStore(expr_, result);
b_.CreateBr(done);
b_.SetInsertPoint(done);
expr_ = b_.CreateLoad(result);
} else {
// copy selected string via CreateMemCpy
b_.SetInsertPoint(left_block);
ternary.left->accept(*this);
b_.CreateMemCpy(buf, expr_, ternary.type.size, 1);
b_.CreateBr(done);
b_.SetInsertPoint(right_block);
ternary.right->accept(*this);
b_.CreateMemCpy(buf, expr_, ternary.type.size, 1);
b_.CreateBr(done);
b_.SetInsertPoint(done);
expr_ = buf;
}
}
void CodegenLLVM::visit(FieldAccess &acc) void CodegenLLVM::visit(FieldAccess &acc)
{ {
// TODO // TODO
......
...@@ -35,6 +35,7 @@ public: ...@@ -35,6 +35,7 @@ public:
void visit(Variable &var) override; void visit(Variable &var) override;
void visit(Binop &binop) override; void visit(Binop &binop) override;
void visit(Unop &unop) override; void visit(Unop &unop) override;
void visit(Ternary &ternary) override;
void visit(FieldAccess &acc) override; void visit(FieldAccess &acc) override;
void visit(Cast &cast) override; void visit(Cast &cast) override;
void visit(ExprStatement &expr) override; void visit(ExprStatement &expr) override;
......
...@@ -86,6 +86,18 @@ void Printer::visit(Unop &unop) ...@@ -86,6 +86,18 @@ void Printer::visit(Unop &unop)
--depth_; --depth_;
} }
void Printer::visit(Ternary &ternary)
{
std::string indent(depth_, ' ');
out_ << indent << "?:" << std::endl;
++depth_;
ternary.cond->accept(*this);
ternary.left->accept(*this);
ternary.right->accept(*this);
--depth_;
}
void Printer::visit(FieldAccess &acc) void Printer::visit(FieldAccess &acc)
{ {
std::string indent(depth_, ' '); std::string indent(depth_, ' ');
......
...@@ -18,6 +18,7 @@ public: ...@@ -18,6 +18,7 @@ public:
void visit(Variable &var) override; void visit(Variable &var) override;
void visit(Binop &binop) override; void visit(Binop &binop) override;
void visit(Unop &unop) override; void visit(Unop &unop) override;
void visit(Ternary &ternary) override;
void visit(FieldAccess &acc) override; void visit(FieldAccess &acc) override;
void visit(Cast &cast) override; void visit(Cast &cast) override;
void visit(ExprStatement &expr) override; void visit(ExprStatement &expr) override;
......
...@@ -390,6 +390,29 @@ void SemanticAnalyser::visit(Unop &unop) ...@@ -390,6 +390,29 @@ void SemanticAnalyser::visit(Unop &unop)
} }
} }
void SemanticAnalyser::visit(Ternary &ternary)
{
ternary.cond->accept(*this);
ternary.left->accept(*this);
ternary.right->accept(*this);
Type &lhs = ternary.left->type.type;
Type &rhs = ternary.right->type.type;
if (is_final_pass()) {
if (lhs != rhs) {
err_ << "Ternary operator must return the same type: ";
err_ << "have '" << lhs << "' ";
err_ << "and '" << rhs << "'" << std::endl;
}
}
if (lhs == Type::string)
ternary.type = SizedType(lhs, STRING_SIZE);
else if (lhs == Type::integer)
ternary.type = SizedType(lhs, 8);
else {
err_ << "Ternary return type unsupported " << lhs << std::endl;
}
}
void SemanticAnalyser::visit(FieldAccess &acc) void SemanticAnalyser::visit(FieldAccess &acc)
{ {
acc.expr->accept(*this); acc.expr->accept(*this);
......
...@@ -26,6 +26,7 @@ public: ...@@ -26,6 +26,7 @@ public:
void visit(Variable &var) override; void visit(Variable &var) override;
void visit(Binop &binop) override; void visit(Binop &binop) override;
void visit(Unop &unop) override; void visit(Unop &unop) override;
void visit(Ternary &ternary) override;
void visit(FieldAccess &acc) override; void visit(FieldAccess &acc) override;
void visit(Cast &cast) override; void visit(Cast &cast) override;
void visit(ExprStatement &expr) override; void visit(ExprStatement &expr) override;
......
...@@ -78,6 +78,7 @@ pid|tid|uid|gid|nsecs|cpu|comm|stack|ustack|arg[0-9]|retval|func|name { ...@@ -78,6 +78,7 @@ pid|tid|uid|gid|nsecs|cpu|comm|stack|ustack|arg[0-9]|retval|func|name {
"#include" { return Parser::make_INCLUDE(loc); } "#include" { return Parser::make_INCLUDE(loc); }
"." { return Parser::make_DOT(loc); } "." { return Parser::make_DOT(loc); }
"->" { return Parser::make_PTR(loc); } "->" { return Parser::make_PTR(loc); }
"?" { return Parser::make_QUES(loc); }
\" { BEGIN(STR); string_buffer.clear(); } \" { BEGIN(STR); string_buffer.clear(); }
<STR>\" { BEGIN(INITIAL); return Parser::make_STRING(string_buffer, loc); } <STR>\" { BEGIN(INITIAL); return Parser::make_STRING(string_buffer, loc); }
......
...@@ -44,6 +44,7 @@ void yyerror(bpftrace::Driver &driver, const char *s); ...@@ -44,6 +44,7 @@ void yyerror(bpftrace::Driver &driver, const char *s);
RBRACKET "]" RBRACKET "]"
LPAREN "(" LPAREN "("
RPAREN ")" RPAREN ")"
QUES "?"
ENDPRED "end predicate" ENDPRED "end predicate"
COMMA "," COMMA ","
ASSIGN "=" ASSIGN "="
...@@ -84,6 +85,7 @@ void yyerror(bpftrace::Driver &driver, const char *s); ...@@ -84,6 +85,7 @@ void yyerror(bpftrace::Driver &driver, const char *s);
%type <ast::ProbeList *> probes %type <ast::ProbeList *> probes
%type <ast::Probe *> probe %type <ast::Probe *> probe
%type <ast::Predicate *> pred %type <ast::Predicate *> pred
%type <ast::Ternary *> ternary
%type <ast::StatementList *> block stmts %type <ast::StatementList *> block stmts
%type <ast::Statement *> stmt %type <ast::Statement *> stmt
%type <ast::Expression *> expr %type <ast::Expression *> expr
...@@ -98,6 +100,7 @@ void yyerror(bpftrace::Driver &driver, const char *s); ...@@ -98,6 +100,7 @@ void yyerror(bpftrace::Driver &driver, const char *s);
%type <std::string> ident %type <std::string> ident
%right ASSIGN %right ASSIGN
%left QUES COLON
%left LOR %left LOR
%left LAND %left LAND
%left BOR %left BOR
...@@ -153,6 +156,9 @@ pred : DIV expr ENDPRED { $$ = new ast::Predicate($2); } ...@@ -153,6 +156,9 @@ pred : DIV expr ENDPRED { $$ = new ast::Predicate($2); }
| { $$ = nullptr; } | { $$ = nullptr; }
; ;
ternary : expr QUES expr COLON expr { $$ = new ast::Ternary($1, $3, $5); }
;
block : "{" stmts "}" { $$ = $2; } block : "{" stmts "}" { $$ = $2; }
| "{" stmts ";" "}" { $$ = $2; } | "{" stmts ";" "}" { $$ = $2; }
; ;
...@@ -169,6 +175,7 @@ stmt : expr { $$ = new ast::ExprStatement($1); } ...@@ -169,6 +175,7 @@ stmt : expr { $$ = new ast::ExprStatement($1); }
expr : INT { $$ = new ast::Integer($1); } expr : INT { $$ = new ast::Integer($1); }
| STRING { $$ = new ast::String($1); } | STRING { $$ = new ast::String($1); }
| BUILTIN { $$ = new ast::Builtin($1); } | BUILTIN { $$ = new ast::Builtin($1); }
| ternary { $$ = $1; }
| map { $$ = $1; } | map { $$ = $1; }
| var { $$ = $1; } | var { $$ = $1; }
| call { $$ = $1; } | call { $$ = $1; }
......
...@@ -2017,6 +2017,104 @@ attributes #1 = { argmemonly nounwind } ...@@ -2017,6 +2017,104 @@ attributes #1 = { argmemonly nounwind }
)EXPECTED"); )EXPECTED");
} }
TEST(codegen, ternary_int)
{
test("kprobe:f { @x = pid < 10000 ? 1 : 2; }",
R"EXPECTED(; Function Attrs: nounwind
declare i64 @llvm.bpf.pseudo(i64, i64) #0
; Function Attrs: argmemonly nounwind
declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture) #1
define i64 @"kprobe:f"(i8* nocapture readnone) local_unnamed_addr section "s_kprobe:f" {
entry:
%"@x_val" = alloca i64, align 8
%"@x_key" = alloca i64, align 8
%get_pid_tgid = tail call i64 inttoptr (i64 14 to i64 ()*)()
%1 = icmp ult i64 %get_pid_tgid, 42949672960000
%. = select i1 %1, i64 1, i64 2
%2 = bitcast i64* %"@x_key" to i8*
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %2)
store i64 0, i64* %"@x_key", align 8
%3 = bitcast i64* %"@x_val" to i8*
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %3)
store i64 %., i64* %"@x_val", align 8
%pseudo = tail call i64 @llvm.bpf.pseudo(i64 1, i64 1)
%update_elem = call i64 inttoptr (i64 2 to i64 (i8*, i8*, i8*, i64)*)(i64 %pseudo, i64* nonnull %"@x_key", i64* nonnull %"@x_val", i64 0)
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %2)
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %3)
ret i64 0
}
; Function Attrs: argmemonly nounwind
declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture) #1
attributes #0 = { nounwind }
attributes #1 = { argmemonly nounwind }
)EXPECTED");
}
TEST(codegen, ternary_str)
{
test("kprobe:f { @x = pid < 10000 ? \"lo\" : \"hi\"; }",
R"EXPECTED(; Function Attrs: nounwind
declare i64 @llvm.bpf.pseudo(i64, i64) #0
; Function Attrs: argmemonly nounwind
declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture) #1
define i64 @"kprobe:f"(i8* nocapture readnone) local_unnamed_addr section "s_kprobe:f" {
entry:
%"@x_key" = alloca i64, align 8
%buf = alloca [64 x i8], align 1
%1 = getelementptr inbounds [64 x i8], [64 x i8]* %buf, i64 0, i64 0
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %1)
%get_pid_tgid = tail call i64 inttoptr (i64 14 to i64 ()*)()
%2 = icmp ult i64 %get_pid_tgid, 42949672960000
br i1 %2, label %left, label %right
left: ; preds = %entry
store i8 108, i8* %1, align 1
%str.sroa.3.0..sroa_idx = getelementptr inbounds [64 x i8], [64 x i8]* %buf, i64 0, i64 1
store i8 111, i8* %str.sroa.3.0..sroa_idx, align 1
%str.sroa.4.0..sroa_idx = getelementptr inbounds [64 x i8], [64 x i8]* %buf, i64 0, i64 2
call void @llvm.memset.p0i8.i64(i8* nonnull %str.sroa.4.0..sroa_idx, i8 0, i64 61, i32 1, i1 false)
br label %done
right: ; preds = %entry
store i8 104, i8* %1, align 1
%str1.sroa.3.0..sroa_idx = getelementptr inbounds [64 x i8], [64 x i8]* %buf, i64 0, i64 1
store i8 105, i8* %str1.sroa.3.0..sroa_idx, align 1
%str1.sroa.4.0..sroa_idx = getelementptr inbounds [64 x i8], [64 x i8]* %buf, i64 0, i64 2
call void @llvm.memset.p0i8.i64(i8* nonnull %str1.sroa.4.0..sroa_idx, i8 0, i64 61, i32 1, i1 false)
br label %done
done: ; preds = %right, %left
%3 = getelementptr inbounds [64 x i8], [64 x i8]* %buf, i64 0, i64 63
store i8 0, i8* %3, align 1
%4 = bitcast i64* %"@x_key" to i8*
call void @llvm.lifetime.start.p0i8(i64 -1, i8* nonnull %4)
store i64 0, i64* %"@x_key", align 8
%pseudo = tail call i64 @llvm.bpf.pseudo(i64 1, i64 1)
%update_elem = call i64 inttoptr (i64 2 to i64 (i8*, i8*, i8*, i64)*)(i64 %pseudo, i64* nonnull %"@x_key", [64 x i8]* nonnull %buf, i64 0)
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %4)
call void @llvm.lifetime.end.p0i8(i64 -1, i8* nonnull %1)
ret i64 0
}
; Function Attrs: argmemonly nounwind
declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture) #1
; Function Attrs: argmemonly nounwind
declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i32, i1) #1
attributes #0 = { nounwind }
attributes #1 = { argmemonly nounwind }
)EXPECTED");
}
} // namespace codegen } // namespace codegen
} // namespace test } // namespace test
} // namespace bpftrace } // namespace bpftrace
...@@ -220,6 +220,56 @@ TEST(Parser, expressions) ...@@ -220,6 +220,56 @@ TEST(Parser, expressions)
" int: 1\n"); " int: 1\n");
} }
TEST(Parser, ternary_int)
{
test("kprobe:sys_open { @x = pid < 10000 ? 1 : 2 }",
"Program\n"
" kprobe:sys_open\n"
" =\n"
" map: @x\n"
" ?:\n"
" <\n"
" builtin: pid\n"
" int: 10000\n"
" int: 1\n"
" int: 2\n");
}
TEST(Parser, ternary_str)
{
test("kprobe:sys_open { @x = pid < 10000 ? \"lo\" : \"high\" }",
"Program\n"
" kprobe:sys_open\n"
" =\n"
" map: @x\n"
" ?:\n"
" <\n"
" builtin: pid\n"
" int: 10000\n"
" string: lo\n"
" string: high\n");
}
TEST(Parser, ternary_nested)
{
test("kprobe:sys_open { @x = pid < 10000 ? pid < 5000 ? 1 : 2 : 3 }",
"Program\n"
" kprobe:sys_open\n"
" =\n"
" map: @x\n"
" ?:\n"
" <\n"
" builtin: pid\n"
" int: 10000\n"
" ?:\n"
" <\n"
" builtin: pid\n"
" int: 5000\n"
" int: 1\n"
" int: 2\n"
" int: 3\n");
}
TEST(Parser, call) TEST(Parser, call)
{ {
test("kprobe:sys_open { @x = count(); @y = hist(1,2,3); delete(@x); }", test("kprobe:sys_open { @x = count(); @y = hist(1,2,3); delete(@x); }",
......
...@@ -126,6 +126,14 @@ TEST(semantic_analyser, predicate_expressions) ...@@ -126,6 +126,14 @@ TEST(semantic_analyser, predicate_expressions)
test("kprobe:f / @mymap / { @mymap = \"str\" }", 10); test("kprobe:f / @mymap / { @mymap = \"str\" }", 10);
} }
TEST(semantic_analyser, ternary_experssions)
{
test("kprobe:f { @x = pid < 10000 ? 1 : 2 }", 0);
test("kprobe:f { @x = pid < 10000 ? \"lo\" : \"high\" }", 0);
test("kprobe:f { @x = pid < 10000 ? 1 : \"high\" }", 10);
test("kprobe:f { @x = pid < 10000 ? \"lo\" : 2 }", 10);
}
TEST(semantic_analyser, mismatched_call_types) TEST(semantic_analyser, mismatched_call_types)
{ {
test("kprobe:f { @x = 1; @x = count(); }", 1); test("kprobe:f { @x = 1; @x = count(); }", 1);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment