Commit f22a1e09 authored by 4ast's avatar 4ast

Merge pull request #227 from iovisor/bblanco_dev

Add support for static helper functions
parents 9ada11d1 8ed57a23
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <clang/AST/ASTContext.h> #include <clang/AST/ASTContext.h>
#include <clang/AST/RecordLayout.h> #include <clang/AST/RecordLayout.h>
#include <clang/Frontend/CompilerInstance.h> #include <clang/Frontend/CompilerInstance.h>
#include <clang/Frontend/MultiplexConsumer.h>
#include <clang/Rewrite/Core/Rewriter.h> #include <clang/Rewrite/Core/Rewriter.h>
#include "b_frontend_action.h" #include "b_frontend_action.h"
...@@ -36,6 +37,7 @@ const char *calling_conv_regs_x86[] = { ...@@ -36,6 +37,7 @@ const char *calling_conv_regs_x86[] = {
const char **calling_conv_regs = calling_conv_regs_x86; const char **calling_conv_regs = calling_conv_regs_x86;
using std::map; using std::map;
using std::set;
using std::string; using std::string;
using std::to_string; using std::to_string;
using std::unique_ptr; using std::unique_ptr;
...@@ -90,27 +92,107 @@ bool BMapDeclVisitor::VisitBuiltinType(const BuiltinType *T) { ...@@ -90,27 +92,107 @@ bool BMapDeclVisitor::VisitBuiltinType(const BuiltinType *T) {
return true; return true;
} }
class BProbeChecker : public clang::RecursiveASTVisitor<BProbeChecker> { class ProbeChecker : public clang::RecursiveASTVisitor<ProbeChecker> {
public: public:
explicit ProbeChecker(Expr *arg, const set<Decl *> &ptregs)
: needs_probe_(false), ptregs_(ptregs) {
if (arg)
TraverseStmt(arg);
}
bool VisitDeclRefExpr(clang::DeclRefExpr *E) { bool VisitDeclRefExpr(clang::DeclRefExpr *E) {
if (E->getDecl()->hasAttr<UnavailableAttr>()) if (ptregs_.find(E->getDecl()) != ptregs_.end())
return false; needs_probe_ = true;
return true; return true;
} }
bool needs_probe() const { return needs_probe_; }
private:
bool needs_probe_;
const set<Decl *> &ptregs_;
}; };
// Visit a piece of the AST and mark it as needing probe reads // Visit a piece of the AST and mark it as needing probe reads
class BProbeSetter : public clang::RecursiveASTVisitor<BProbeSetter> { class ProbeSetter : public clang::RecursiveASTVisitor<ProbeSetter> {
public: public:
explicit BProbeSetter(ASTContext &C) : C(C) {} explicit ProbeSetter(set<Decl *> *ptregs) : ptregs_(ptregs) {}
bool VisitDeclRefExpr(clang::DeclRefExpr *E) { bool VisitDeclRefExpr(clang::DeclRefExpr *E) {
E->getDecl()->addAttr(UnavailableAttr::CreateImplicit(C, "ptregs")); ptregs_->insert(E->getDecl());
return true; return true;
} }
private: private:
ASTContext &C; set<Decl *> *ptregs_;
}; };
ProbeVisitor::ProbeVisitor(Rewriter &rewriter) : rewriter_(rewriter) {}
bool ProbeVisitor::VisitVarDecl(VarDecl *Decl) {
if (Expr *E = Decl->getInit()) {
if (ProbeChecker(E, ptregs_).needs_probe())
set_ptreg(Decl);
}
return true;
}
bool ProbeVisitor::VisitCallExpr(CallExpr *Call) {
if (FunctionDecl *F = dyn_cast<FunctionDecl>(Call->getCalleeDecl())) {
if (F->hasBody()) {
unsigned i = 0;
for (auto arg : Call->arguments()) {
if (ProbeChecker(arg, ptregs_).needs_probe())
ptregs_.insert(F->getParamDecl(i));
++i;
}
if (fn_visited_.find(F) == fn_visited_.end()) {
fn_visited_.insert(F);
TraverseDecl(F);
}
}
}
return true;
}
bool ProbeVisitor::VisitBinaryOperator(BinaryOperator *E) {
if (!E->isAssignmentOp())
return true;
// copy probe attribute from RHS to LHS if present
if (ProbeChecker(E->getRHS(), ptregs_).needs_probe()) {
ProbeSetter setter(&ptregs_);
setter.TraverseStmt(E->getLHS());
}
return true;
}
bool ProbeVisitor::VisitMemberExpr(MemberExpr *E) {
if (memb_visited_.find(E) != memb_visited_.end()) return true;
// Checks to see if the expression references something that needs to be run
// through bpf_probe_read.
if (!ProbeChecker(E, ptregs_).needs_probe())
return true;
Expr *base;
SourceLocation rhs_start, op;
bool found = false;
for (MemberExpr *M = E; M; M = dyn_cast<MemberExpr>(M->getBase())) {
memb_visited_.insert(M);
rhs_start = M->getLocEnd();
base = M->getBase();
op = M->getOperatorLoc();
if (M->isArrow()) {
found = true;
break;
}
}
if (!found)
return true;
string rhs = rewriter_.getRewrittenText(SourceRange(rhs_start, E->getLocEnd()));
string base_type = base->getType()->getPointeeType().getAsString();
string pre, post;
pre = "({ typeof(" + E->getType().getAsString() + ") _val; memset(&_val, 0, sizeof(_val));";
pre += " bpf_probe_read(&_val, sizeof(_val), (u64)";
post = " + offsetof(" + base_type + ", " + rhs + ")";
post += "); _val; })";
rewriter_.InsertText(E->getLocStart(), pre);
rewriter_.ReplaceText(SourceRange(op, E->getLocEnd()), post);
return true;
}
BTypeVisitor::BTypeVisitor(ASTContext &C, Rewriter &rewriter, vector<TableDesc> &tables) BTypeVisitor::BTypeVisitor(ASTContext &C, Rewriter &rewriter, vector<TableDesc> &tables)
: C(C), rewriter_(rewriter), out_(llvm::errs()), tables_(tables) { : C(C), rewriter_(rewriter), out_(llvm::errs()), tables_(tables) {
} }
...@@ -141,6 +223,11 @@ bool BTypeVisitor::VisitFunctionDecl(FunctionDecl *D) { ...@@ -141,6 +223,11 @@ bool BTypeVisitor::VisitFunctionDecl(FunctionDecl *D) {
// for each trace argument, convert the variable from ptregs to something on stack // for each trace argument, convert the variable from ptregs to something on stack
if (CompoundStmt *S = dyn_cast<CompoundStmt>(D->getBody())) if (CompoundStmt *S = dyn_cast<CompoundStmt>(D->getBody()))
rewriter_.ReplaceText(S->getLBracLoc(), 1, preamble); rewriter_.ReplaceText(S->getLBracLoc(), 1, preamble);
} else if (D->hasBody() &&
rewriter_.getSourceMgr().getFileID(D->getLocStart())
== rewriter_.getSourceMgr().getMainFileID()) {
// rewritable functions that are static should be always treated as helper
rewriter_.InsertText(D->getLocStart(), "__attribute__((always_inline))\n");
} }
return true; return true;
} }
...@@ -282,37 +369,6 @@ bool BTypeVisitor::VisitCallExpr(CallExpr *Call) { ...@@ -282,37 +369,6 @@ bool BTypeVisitor::VisitCallExpr(CallExpr *Call) {
return true; return true;
} }
bool BTypeVisitor::VisitMemberExpr(MemberExpr *E) {
if (visited_.find(E) != visited_.end()) return true;
// Checks to see if the expression references something that needs to be run
// through bpf_probe_read.
BProbeChecker checker;
if (checker.TraverseStmt(E))
return true;
Expr *base;
SourceLocation rhs_start, op;
for (MemberExpr *M = E; M; M = dyn_cast<MemberExpr>(M->getBase())) {
visited_.insert(M);
rhs_start = M->getLocEnd();
base = M->getBase();
op = M->getOperatorLoc();
if (M->isArrow())
break;
}
string rhs = rewriter_.getRewrittenText(SourceRange(rhs_start, E->getLocEnd()));
string base_type = base->getType()->getPointeeType().getAsString();
string pre, post;
pre = "({ typeof(" + E->getType().getAsString() + ") _val; memset(&_val, 0, sizeof(_val));";
pre += " bpf_probe_read(&_val, sizeof(_val), (u64)";
post = " + offsetof(" + base_type + ", " + rhs + ")";
post += "); _val; })";
rewriter_.InsertText(E->getLocStart(), pre);
rewriter_.ReplaceText(SourceRange(op, E->getLocEnd()), post);
return true;
}
bool BTypeVisitor::VisitBinaryOperator(BinaryOperator *E) { bool BTypeVisitor::VisitBinaryOperator(BinaryOperator *E) {
if (!E->isAssignmentOp()) if (!E->isAssignmentOp())
return true; return true;
...@@ -340,12 +396,6 @@ bool BTypeVisitor::VisitBinaryOperator(BinaryOperator *E) { ...@@ -340,12 +396,6 @@ bool BTypeVisitor::VisitBinaryOperator(BinaryOperator *E) {
} }
} }
} }
// copy probe attribute from RHS to LHS if present
BProbeChecker checker;
if (!checker.TraverseStmt(E->getRHS())) {
BProbeSetter setter(C);
setter.TraverseStmt(E->getLHS());
}
return true; return true;
} }
bool BTypeVisitor::VisitImplicitCastExpr(ImplicitCastExpr *E) { bool BTypeVisitor::VisitImplicitCastExpr(ImplicitCastExpr *E) {
...@@ -453,11 +503,6 @@ bool BTypeVisitor::VisitVarDecl(VarDecl *Decl) { ...@@ -453,11 +503,6 @@ bool BTypeVisitor::VisitVarDecl(VarDecl *Decl) {
} }
} }
} }
if (Expr *E = Decl->getInit()) {
BProbeChecker checker;
if (!checker.TraverseStmt(E))
Decl->addAttr(UnavailableAttr::CreateImplicit(C, "ptregs"));
}
return true; return true;
} }
...@@ -465,9 +510,27 @@ BTypeConsumer::BTypeConsumer(ASTContext &C, Rewriter &rewriter, vector<TableDesc ...@@ -465,9 +510,27 @@ BTypeConsumer::BTypeConsumer(ASTContext &C, Rewriter &rewriter, vector<TableDesc
: visitor_(C, rewriter, tables) { : visitor_(C, rewriter, tables) {
} }
bool BTypeConsumer::HandleTopLevelDecl(DeclGroupRef D) { bool BTypeConsumer::HandleTopLevelDecl(DeclGroupRef Group) {
for (auto it : D) for (auto D : Group)
visitor_.TraverseDecl(it); visitor_.TraverseDecl(D);
return true;
}
ProbeConsumer::ProbeConsumer(clang::ASTContext &C, Rewriter &rewriter)
: visitor_(rewriter) {}
bool ProbeConsumer::HandleTopLevelDecl(clang::DeclGroupRef Group) {
for (auto D : Group) {
if (FunctionDecl *F = dyn_cast<FunctionDecl>(D)) {
if (F->isExternallyVisible() && F->hasBody()) {
for (auto arg : F->parameters()) {
if (arg != F->getParamDecl(0))
visitor_.set_ptreg(arg);
}
visitor_.TraverseDecl(D);
}
}
}
return true; return true;
} }
...@@ -476,7 +539,6 @@ BFrontendAction::BFrontendAction(llvm::raw_ostream &os, unsigned flags) ...@@ -476,7 +539,6 @@ BFrontendAction::BFrontendAction(llvm::raw_ostream &os, unsigned flags)
} }
void BFrontendAction::EndSourceFileAction() { void BFrontendAction::EndSourceFileAction() {
// uncomment to see rewritten source
if (flags_ & 0x4) if (flags_ & 0x4)
rewriter_->getEditBuffer(rewriter_->getSourceMgr().getMainFileID()).write(llvm::errs()); rewriter_->getEditBuffer(rewriter_->getSourceMgr().getMainFileID()).write(llvm::errs());
rewriter_->getEditBuffer(rewriter_->getSourceMgr().getMainFileID()).write(os_); rewriter_->getEditBuffer(rewriter_->getSourceMgr().getMainFileID()).write(os_);
...@@ -485,7 +547,10 @@ void BFrontendAction::EndSourceFileAction() { ...@@ -485,7 +547,10 @@ void BFrontendAction::EndSourceFileAction() {
unique_ptr<ASTConsumer> BFrontendAction::CreateASTConsumer(CompilerInstance &Compiler, llvm::StringRef InFile) { unique_ptr<ASTConsumer> BFrontendAction::CreateASTConsumer(CompilerInstance &Compiler, llvm::StringRef InFile) {
rewriter_->setSourceMgr(Compiler.getSourceManager(), Compiler.getLangOpts()); rewriter_->setSourceMgr(Compiler.getSourceManager(), Compiler.getLangOpts());
return unique_ptr<ASTConsumer>(new BTypeConsumer(Compiler.getASTContext(), *rewriter_, *tables_)); vector<unique_ptr<ASTConsumer>> consumers;
consumers.push_back(unique_ptr<ASTConsumer>(new ProbeConsumer(Compiler.getASTContext(), *rewriter_)));
consumers.push_back(unique_ptr<ASTConsumer>(new BTypeConsumer(Compiler.getASTContext(), *rewriter_, *tables_)));
return unique_ptr<ASTConsumer>(new MultiplexConsumer(move(consumers)));
} }
} }
...@@ -66,7 +66,6 @@ class BTypeVisitor : public clang::RecursiveASTVisitor<BTypeVisitor> { ...@@ -66,7 +66,6 @@ class BTypeVisitor : public clang::RecursiveASTVisitor<BTypeVisitor> {
bool VisitFunctionDecl(clang::FunctionDecl *D); bool VisitFunctionDecl(clang::FunctionDecl *D);
bool VisitCallExpr(clang::CallExpr *Call); bool VisitCallExpr(clang::CallExpr *Call);
bool VisitVarDecl(clang::VarDecl *Decl); bool VisitVarDecl(clang::VarDecl *Decl);
bool VisitMemberExpr(clang::MemberExpr *E);
bool VisitBinaryOperator(clang::BinaryOperator *E); bool VisitBinaryOperator(clang::BinaryOperator *E);
bool VisitImplicitCastExpr(clang::ImplicitCastExpr *E); bool VisitImplicitCastExpr(clang::ImplicitCastExpr *E);
...@@ -79,16 +78,41 @@ class BTypeVisitor : public clang::RecursiveASTVisitor<BTypeVisitor> { ...@@ -79,16 +78,41 @@ class BTypeVisitor : public clang::RecursiveASTVisitor<BTypeVisitor> {
std::set<clang::Expr *> visited_; std::set<clang::Expr *> visited_;
}; };
// Do a depth-first search to rewrite all pointers that need to be probed
class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> {
public:
explicit ProbeVisitor(clang::Rewriter &rewriter);
bool VisitVarDecl(clang::VarDecl *Decl);
bool VisitCallExpr(clang::CallExpr *Call);
bool VisitBinaryOperator(clang::BinaryOperator *E);
bool VisitMemberExpr(clang::MemberExpr *E);
void set_ptreg(clang::Decl *D) { ptregs_.insert(D); }
private:
clang::Rewriter &rewriter_;
std::set<clang::Decl *> fn_visited_;
std::set<clang::Expr *> memb_visited_;
std::set<clang::Decl *> ptregs_;
};
// A helper class to the frontend action, walks the decls // A helper class to the frontend action, walks the decls
class BTypeConsumer : public clang::ASTConsumer { class BTypeConsumer : public clang::ASTConsumer {
public: public:
explicit BTypeConsumer(clang::ASTContext &C, clang::Rewriter &rewriter, explicit BTypeConsumer(clang::ASTContext &C, clang::Rewriter &rewriter,
std::vector<TableDesc> &tables); std::vector<TableDesc> &tables);
bool HandleTopLevelDecl(clang::DeclGroupRef D) override; bool HandleTopLevelDecl(clang::DeclGroupRef Group) override;
private: private:
BTypeVisitor visitor_; BTypeVisitor visitor_;
}; };
// A helper class to the frontend action, walks the decls
class ProbeConsumer : public clang::ASTConsumer {
public:
ProbeConsumer(clang::ASTContext &C, clang::Rewriter &rewriter);
bool HandleTopLevelDecl(clang::DeclGroupRef Group) override;
private:
ProbeVisitor visitor_;
};
// Create a B program in 2 phases (everything else is normal C frontend): // Create a B program in 2 phases (everything else is normal C frontend):
// 1. Catch the map declarations and open the fd's // 1. Catch the map declarations and open the fd's
// 2. Capture the IR // 2. Capture the IR
......
...@@ -104,7 +104,6 @@ int pem(struct __sk_buff *skb) { ...@@ -104,7 +104,6 @@ int pem(struct __sk_buff *skb) {
return 1; return 1;
} }
static int br_common(struct __sk_buff *skb, int which_br) __attribute__((always_inline));
static int br_common(struct __sk_buff *skb, int which_br) { static int br_common(struct __sk_buff *skb, int which_br) {
u8 *cursor = 0; u8 *cursor = 0;
u16 proto; u16 proto;
......
...@@ -172,5 +172,40 @@ int kprobe__blk_update_request(struct pt_regs *ctx, struct request *req) { ...@@ -172,5 +172,40 @@ int kprobe__blk_update_request(struct pt_regs *ctx, struct request *req) {
return 0; return 0;
}""") }""")
def test_probe_read_helper(self):
b = BPF(text="""
#include <linux/fs.h>
static void print_file_name(struct file *file) {
if (!file) return;
const char *name = file->f_path.dentry->d_name.name;
bpf_trace_printk("%s\\n", name);
}
int trace_entry(struct pt_regs *ctx, struct file *file) {
print_file_name(file);
return 0;
}
""")
fn = b.load_func("trace_entry", BPF.KPROBE)
def test_probe_struct_assign(self):
b = BPF(text = """
#include <uapi/linux/ptrace.h>
struct args_t {
const char *filename;
int flags;
int mode;
};
int kprobe__sys_open(struct pt_regs *ctx, const char *filename,
int flags, int mode) {
struct args_t args = {};
args.filename = filename;
args.flags = flags;
args.mode = mode;
bpf_trace_printk("%s\\n", args.filename);
return 0;
};
""")
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment