Commit 98b90974 authored by Huapeng Zhou's avatar Huapeng Zhou

support macro in call arguments

parent 0107455c
...@@ -336,6 +336,12 @@ bool BTypeVisitor::TraverseCallExpr(CallExpr *Call) { ...@@ -336,6 +336,12 @@ bool BTypeVisitor::TraverseCallExpr(CallExpr *Call) {
// to: // to:
// bpf_table_foo_elem(bpf_pseudo_fd(table), &key [,&leaf]) // bpf_table_foo_elem(bpf_pseudo_fd(table), &key [,&leaf])
bool BTypeVisitor::VisitCallExpr(CallExpr *Call) { bool BTypeVisitor::VisitCallExpr(CallExpr *Call) {
// Get rewritten text given a source range, w/ expansion range applied
auto getRewrittenText = [this] (SourceRange R) {
auto r = rewriter_.getSourceMgr().getExpansionRange(R);
return rewriter_.getRewrittenText(r);
};
// make sure node is a reference to a bpf table, which is assured by the // make sure node is a reference to a bpf table, which is assured by the
// presence of the section("maps/<typename>") GNU __attribute__ // presence of the section("maps/<typename>") GNU __attribute__
if (MemberExpr *Memb = dyn_cast<MemberExpr>(Call->getCallee()->IgnoreImplicit())) { if (MemberExpr *Memb = dyn_cast<MemberExpr>(Call->getCallee()->IgnoreImplicit())) {
...@@ -345,9 +351,8 @@ bool BTypeVisitor::VisitCallExpr(CallExpr *Call) { ...@@ -345,9 +351,8 @@ bool BTypeVisitor::VisitCallExpr(CallExpr *Call) {
if (!A->getName().startswith("maps")) if (!A->getName().startswith("maps"))
return true; return true;
SourceRange argRange(Call->getArg(0)->getLocStart(), string args = getRewrittenText(SourceRange(Call->getArg(0)->getLocStart(),
Call->getArg(Call->getNumArgs()-1)->getLocEnd()); Call->getArg(Call->getNumArgs() - 1)->getLocEnd()));
string args = rewriter_.getRewrittenText(argRange);
// find the table fd, which was opened at declaration time // find the table fd, which was opened at declaration time
auto table_it = tables_.begin(); auto table_it = tables_.begin();
...@@ -366,10 +371,8 @@ bool BTypeVisitor::VisitCallExpr(CallExpr *Call) { ...@@ -366,10 +371,8 @@ bool BTypeVisitor::VisitCallExpr(CallExpr *Call) {
if (memb_name == "lookup_or_init") { if (memb_name == "lookup_or_init") {
map_update_policy = "BPF_NOEXIST"; map_update_policy = "BPF_NOEXIST";
string name = Ref->getDecl()->getName(); string name = Ref->getDecl()->getName();
string arg0 = rewriter_.getRewrittenText(SourceRange(Call->getArg(0)->getLocStart(), string arg0 = getRewrittenText(Call->getArg(0)->getSourceRange());
Call->getArg(0)->getLocEnd())); string arg1 = getRewrittenText(Call->getArg(1)->getSourceRange());
string arg1 = rewriter_.getRewrittenText(SourceRange(Call->getArg(1)->getLocStart(),
Call->getArg(1)->getLocEnd()));
string lookup = "bpf_map_lookup_elem_(bpf_pseudo_fd(1, " + fd + ")"; string lookup = "bpf_map_lookup_elem_(bpf_pseudo_fd(1, " + fd + ")";
string update = "bpf_map_update_elem_(bpf_pseudo_fd(1, " + fd + ")"; string update = "bpf_map_update_elem_(bpf_pseudo_fd(1, " + fd + ")";
txt = "({typeof(" + name + ".leaf) *leaf = " + lookup + ", " + arg0 + "); "; txt = "({typeof(" + name + ".leaf) *leaf = " + lookup + ", " + arg0 + "); ";
...@@ -381,8 +384,7 @@ bool BTypeVisitor::VisitCallExpr(CallExpr *Call) { ...@@ -381,8 +384,7 @@ bool BTypeVisitor::VisitCallExpr(CallExpr *Call) {
txt += "leaf;})"; txt += "leaf;})";
} else if (memb_name == "increment") { } else if (memb_name == "increment") {
string name = Ref->getDecl()->getName(); string name = Ref->getDecl()->getName();
string arg0 = rewriter_.getRewrittenText(SourceRange(Call->getArg(0)->getLocStart(), string arg0 = getRewrittenText(Call->getArg(0)->getSourceRange());
Call->getArg(0)->getLocEnd()));
string lookup = "bpf_map_lookup_elem_(bpf_pseudo_fd(1, " + fd + ")"; string lookup = "bpf_map_lookup_elem_(bpf_pseudo_fd(1, " + fd + ")";
string update = "bpf_map_update_elem_(bpf_pseudo_fd(1, " + fd + ")"; string update = "bpf_map_update_elem_(bpf_pseudo_fd(1, " + fd + ")";
txt = "({ typeof(" + name + ".key) _key = " + arg0 + "; "; txt = "({ typeof(" + name + ".key) _key = " + arg0 + "; ";
...@@ -394,21 +396,16 @@ bool BTypeVisitor::VisitCallExpr(CallExpr *Call) { ...@@ -394,21 +396,16 @@ bool BTypeVisitor::VisitCallExpr(CallExpr *Call) {
txt += "if (_leaf) (*_leaf)++; })"; txt += "if (_leaf) (*_leaf)++; })";
} else if (memb_name == "perf_submit") { } else if (memb_name == "perf_submit") {
string name = Ref->getDecl()->getName(); string name = Ref->getDecl()->getName();
string arg0 = rewriter_.getRewrittenText(SourceRange(Call->getArg(0)->getLocStart(), string arg0 = getRewrittenText(Call->getArg(0)->getSourceRange());
Call->getArg(0)->getLocEnd())); string args_other = getRewrittenText(SourceRange(Call->getArg(1)->getLocStart(),
string args_other = rewriter_.getRewrittenText(SourceRange(Call->getArg(1)->getLocStart(), Call->getArg(2)->getLocEnd()));
Call->getArg(2)->getLocEnd()));
txt = "bpf_perf_event_output(" + arg0 + ", bpf_pseudo_fd(1, " + fd + ")"; txt = "bpf_perf_event_output(" + arg0 + ", bpf_pseudo_fd(1, " + fd + ")";
txt += ", bpf_get_smp_processor_id(), " + args_other + ")"; txt += ", bpf_get_smp_processor_id(), " + args_other + ")";
} else if (memb_name == "perf_submit_skb") { } else if (memb_name == "perf_submit_skb") {
string skb = rewriter_.getRewrittenText(SourceRange(Call->getArg(0)->getLocStart(), string skb = getRewrittenText(Call->getArg(0)->getSourceRange());
Call->getArg(0)->getLocEnd())); string skb_len = getRewrittenText(Call->getArg(1)->getSourceRange());
string skb_len = rewriter_.getRewrittenText(SourceRange(Call->getArg(1)->getLocStart(), string meta = getRewrittenText(Call->getArg(2)->getSourceRange());
Call->getArg(1)->getLocEnd())); string meta_len = getRewrittenText(Call->getArg(3)->getSourceRange());
string meta = rewriter_.getRewrittenText(SourceRange(Call->getArg(2)->getLocStart(),
Call->getArg(2)->getLocEnd()));
string meta_len = rewriter_.getRewrittenText(SourceRange(Call->getArg(3)->getLocStart(),
Call->getArg(3)->getLocEnd()));
txt = "bpf_perf_event_output(" + txt = "bpf_perf_event_output(" +
skb + ", " + skb + ", " +
"bpf_pseudo_fd(1, " + fd + "), " + "bpf_pseudo_fd(1, " + fd + "), " +
...@@ -417,8 +414,7 @@ bool BTypeVisitor::VisitCallExpr(CallExpr *Call) { ...@@ -417,8 +414,7 @@ bool BTypeVisitor::VisitCallExpr(CallExpr *Call) {
meta_len + ");"; meta_len + ");";
} else if (memb_name == "get_stackid") { } else if (memb_name == "get_stackid") {
if (table_it->type == BPF_MAP_TYPE_STACK_TRACE) { if (table_it->type == BPF_MAP_TYPE_STACK_TRACE) {
string arg0 = rewriter_.getRewrittenText(SourceRange(Call->getArg(0)->getLocStart(), string arg0 = getRewrittenText(Call->getArg(0)->getSourceRange());
Call->getArg(0)->getLocEnd()));
txt = "bpf_get_stackid("; txt = "bpf_get_stackid(";
txt += "bpf_pseudo_fd(1, " + fd + "), " + arg0; txt += "bpf_pseudo_fd(1, " + fd + "), " + arg0;
rewrite_end = Call->getArg(0)->getLocEnd(); rewrite_end = Call->getArg(0)->getLocEnd();
...@@ -474,7 +470,7 @@ bool BTypeVisitor::VisitCallExpr(CallExpr *Call) { ...@@ -474,7 +470,7 @@ bool BTypeVisitor::VisitCallExpr(CallExpr *Call) {
vector<string> args; vector<string> args;
for (auto arg : Call->arguments()) for (auto arg : Call->arguments())
args.push_back(rewriter_.getRewrittenText(SourceRange(arg->getLocStart(), arg->getLocEnd()))); args.push_back(getRewrittenText(arg->getSourceRange()));
string text; string text;
if (Decl->getName() == "incr_cksum_l3") { if (Decl->getName() == "incr_cksum_l3") {
......
...@@ -352,5 +352,45 @@ int many(struct pt_regs *ctx, int a, int b, int c, int d, int e, int f, int g) { ...@@ -352,5 +352,45 @@ int many(struct pt_regs *ctx, int a, int b, int c, int d, int e, int f, int g) {
with self.assertRaises(Exception): with self.assertRaises(Exception):
b = BPF(text=text) b = BPF(text=text)
def test_call_macro_arg(self):
text = """
BPF_TABLE("prog", u32, u32, jmp, 32);
#define JMP_IDX_PIPE (1U << 1)
enum action {
ACTION_PASS
};
int process(struct xdp_md *ctx) {
jmp.call((void *)ctx, ACTION_PASS);
jmp.call((void *)ctx, JMP_IDX_PIPE);
return XDP_PASS;
}
"""
b = BPF(text=text)
t = b["jmp"]
self.assertEquals(len(t), 32);
def test_update_macro_arg(self):
text = """
BPF_TABLE("array", u32, u32, act, 32);
#define JMP_IDX_PIPE (1U << 1)
enum action {
ACTION_PASS
};
int process(struct xdp_md *ctx) {
act.increment(ACTION_PASS);
act.increment(JMP_IDX_PIPE);
return XDP_PASS;
}
"""
b = BPF(text=text)
t = b["act"]
self.assertEquals(len(t), 32);
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment