Commit ad2d0d9f authored by Paul Chaignon's avatar Paul Chaignon Committed by yonghong-song

Trace all external pointers passed through a first map (#1737)

* Trace all external pointers going through a first map

Currently, MapVisitor only detects maps with external pointers as
values if the value was directly passed from a function's argument.
For example, in the following, the rewriter is currently unable to
detect currsock has an external pointer as value because an
intermediate variable is used instead of passing directly sk as the
map's value.

    int test(struct pt_regs *ctx, struct sock *sk) {
        u32 pid = bpf_get_current_pid_tgid();
        struct sock **skp = &sk;
        currsock.update(&pid, skp);
        return 0;
    };

With this commit, MapVisitor is able to trace any external pointer
derived from the function's argument and used as a map value. This
commit breaks the ProbeVisitor traversal in two distinct traversals.
The first rewrites dereferences of external pointers originating
from function's arguments and helpers, while the second rewrites only
dereferences of external pointers passed through maps.
Maps with external pointers as values are identified between the two
ProbeVisitor traversals.

* New tests for external pointers passed through maps

test_ext_ptr_maps_reverse ensures dereferences are correctly replaced
even if the update happens after the lookup (in the order of
MapVisitor traversal).
test_ext_ptr_maps_indirect ensures the rewriter is able to trace
external pointers used as map values even if using an intermediate
variable.
parent 42da08aa
......@@ -90,8 +90,9 @@ using namespace clang;
class ProbeChecker : public RecursiveASTVisitor<ProbeChecker> {
public:
explicit ProbeChecker(Expr *arg, const set<Decl *> &ptregs)
: needs_probe_(false), is_transitive_(false), ptregs_(ptregs) {
explicit ProbeChecker(Expr *arg, const set<Decl *> &ptregs, bool track_helpers)
: needs_probe_(false), is_transitive_(false), ptregs_(ptregs),
track_helpers_(track_helpers) {
if (arg) {
TraverseStmt(arg);
if (arg->getType()->isPointerType())
......@@ -100,9 +101,10 @@ class ProbeChecker : public RecursiveASTVisitor<ProbeChecker> {
}
bool VisitCallExpr(CallExpr *E) {
needs_probe_ = false;
if (VarDecl *V = dyn_cast<VarDecl>(E->getCalleeDecl())) {
if (!track_helpers_)
return false;
if (VarDecl *V = dyn_cast<VarDecl>(E->getCalleeDecl()))
needs_probe_ = V->getName() == "bpf_get_current_task";
}
return false;
}
bool VisitMemberExpr(MemberExpr *M) {
......@@ -123,6 +125,7 @@ class ProbeChecker : public RecursiveASTVisitor<ProbeChecker> {
bool needs_probe_;
bool is_transitive_;
const set<Decl *> &ptregs_;
bool track_helpers_;
};
// Visit a piece of the AST and mark it as needing probe reads
......@@ -152,7 +155,7 @@ bool MapVisitor::VisitCallExpr(CallExpr *Call) {
return true;
if (memb_name == "update" || memb_name == "insert") {
if (ProbeChecker(Call->getArg(1), ptregs_).needs_probe()) {
if (ProbeChecker(Call->getArg(1), ptregs_, true).needs_probe()) {
m_.insert(Ref->getDecl());
}
}
......@@ -162,12 +165,12 @@ bool MapVisitor::VisitCallExpr(CallExpr *Call) {
return true;
}
ProbeVisitor::ProbeVisitor(ASTContext &C, Rewriter &rewriter, set<Decl *> &m) :
C(C), rewriter_(rewriter), m_(m) {}
ProbeVisitor::ProbeVisitor(ASTContext &C, Rewriter &rewriter, set<Decl *> &m, bool track_helpers) :
C(C), rewriter_(rewriter), m_(m), track_helpers_(track_helpers) {}
bool ProbeVisitor::VisitVarDecl(VarDecl *Decl) {
if (Expr *E = Decl->getInit()) {
if (ProbeChecker(E, ptregs_).is_transitive() || IsContextMemberExpr(E)) {
if (ProbeChecker(E, ptregs_, track_helpers_).is_transitive() || IsContextMemberExpr(E)) {
set_ptreg(Decl);
}
}
......@@ -178,7 +181,7 @@ bool ProbeVisitor::VisitCallExpr(CallExpr *Call) {
if (F->hasBody()) {
unsigned i = 0;
for (auto arg : Call->arguments()) {
if (ProbeChecker(arg, ptregs_).needs_probe())
if (ProbeChecker(arg, ptregs_, track_helpers_).needs_probe())
ptregs_.insert(F->getParamDecl(i));
++i;
}
......@@ -194,7 +197,7 @@ bool ProbeVisitor::VisitBinaryOperator(BinaryOperator *E) {
if (!E->isAssignmentOp())
return true;
// copy probe attribute from RHS to LHS if present
if (ProbeChecker(E->getRHS(), ptregs_).is_transitive()) {
if (ProbeChecker(E->getRHS(), ptregs_, track_helpers_).is_transitive()) {
ProbeSetter setter(&ptregs_);
setter.TraverseStmt(E->getLHS());
} else if (E->getRHS()->getStmtClass() == Stmt::CallExprClass) {
......@@ -227,7 +230,7 @@ bool ProbeVisitor::VisitUnaryOperator(UnaryOperator *E) {
return true;
if (memb_visited_.find(E) != memb_visited_.end())
return true;
if (!ProbeChecker(E, ptregs_).needs_probe())
if (!ProbeChecker(E, ptregs_, track_helpers_).needs_probe())
return true;
memb_visited_.insert(E);
Expr *sub = E->getSubExpr();
......@@ -264,7 +267,7 @@ bool ProbeVisitor::VisitMemberExpr(MemberExpr *E) {
// Checks to see if the expression references something that needs to be run
// through bpf_probe_read.
if (!ProbeChecker(base, ptregs_).needs_probe())
if (!ProbeChecker(base, ptregs_, track_helpers_).needs_probe())
return true;
string rhs = rewriter_.getRewrittenText(expansionRange(SourceRange(rhs_start, E->getLocEnd())));
......@@ -889,44 +892,71 @@ BTypeConsumer::BTypeConsumer(ASTContext &C, BFrontendAction &fe,
: fe_(fe),
map_visitor_(m),
btype_visitor_(C, fe),
probe_visitor_(C, rewriter, m) {}
probe_visitor1_(C, rewriter, m, true),
probe_visitor2_(C, rewriter, m, false) {}
bool BTypeConsumer::HandleTopLevelDecl(DeclGroupRef Group) {
for (auto D : Group) {
void BTypeConsumer::HandleTranslationUnit(ASTContext &Context) {
DeclContext::decl_iterator it;
DeclContext *DC = TranslationUnitDecl::castToDeclContext(Context.getTranslationUnitDecl());
/**
* In a first traversal, ProbeVisitor tracks external pointers identified
* through each function's arguments and replaces their dereferences with
* calls to bpf_probe_read. It also passes all identified pointers to
* external addresses to MapVisitor.
*/
for (it = DC->decls_begin(); it != DC->decls_end(); it++) {
Decl *D = *it;
if (FunctionDecl *F = dyn_cast<FunctionDecl>(D)) {
if (fe_.is_rewritable_ext_func(F)) {
for (auto arg : F->parameters()) {
if (arg != F->getParamDecl(0) && !arg->getType()->isFundamentalType()) {
map_visitor_.set_ptreg(arg);
if (arg == F->getParamDecl(0)) {
probe_visitor1_.set_ctx(arg);
} else if (!arg->getType()->isFundamentalType()) {
probe_visitor1_.set_ptreg(arg);
}
}
map_visitor_.TraverseDecl(D);
probe_visitor1_.TraverseDecl(D);
for (auto ptreg : probe_visitor1_.get_ptregs()) {
map_visitor_.set_ptreg(ptreg);
}
}
}
}
return true;
}
void BTypeConsumer::HandleTranslationUnit(ASTContext &Context) {
DeclContext::decl_iterator it;
DeclContext *DC = TranslationUnitDecl::castToDeclContext(Context.getTranslationUnitDecl());
/**
* MapVisitor uses external pointers identified by the first ProbeVisitor
* traversal to identify all maps with external pointers as values.
* MapVisitor runs only after ProbeVisitor finished its traversal of the
* whole translation unit to clearly separate the role of each ProbeVisitor's
* traversal: the first tracks external pointers from function arguments,
* whereas the second tracks external pointers from maps. Without this clear
* separation, ProbeVisitor might attempt to replace several times the same
* dereferences.
*/
for (it = DC->decls_begin(); it != DC->decls_end(); it++) {
Decl *D = *it;
if (FunctionDecl *F = dyn_cast<FunctionDecl>(D)) {
if (fe_.is_rewritable_ext_func(F)) {
map_visitor_.TraverseDecl(D);
}
}
}
/**
* ProbeVisitor's traversal runs after an entire translation unit has been parsed.
* to make sure maps with external pointers have been identified.
* In a second traversal, ProbeVisitor tracks pointers passed through the
* maps identified by MapVisitor and replaces their dereferences with calls
* to bpf_probe_read.
* This last traversal runs after MapVisitor went through an entire
* translation unit, to ensure maps with external pointers have all been
* identified.
*/
for (it = DC->decls_begin(); it != DC->decls_end(); it++) {
Decl *D = *it;
if (FunctionDecl *F = dyn_cast<FunctionDecl>(D)) {
if (fe_.is_rewritable_ext_func(F)) {
for (auto arg : F->parameters()) {
if (arg == F->getParamDecl(0)) {
probe_visitor_.set_ctx(arg);
} else if (!arg->getType()->isFundamentalType()) {
probe_visitor_.set_ptreg(arg);
}
}
probe_visitor_.TraverseDecl(D);
probe_visitor2_.TraverseDecl(D);
}
}
......
......@@ -88,7 +88,8 @@ class BTypeVisitor : public clang::RecursiveASTVisitor<BTypeVisitor> {
// Do a depth-first search to rewrite all pointers that need to be probed
class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> {
public:
explicit ProbeVisitor(clang::ASTContext &C, clang::Rewriter &rewriter, std::set<clang::Decl *> &m);
explicit ProbeVisitor(clang::ASTContext &C, clang::Rewriter &rewriter,
std::set<clang::Decl *> &m, bool track_helpers);
bool VisitVarDecl(clang::VarDecl *Decl);
bool VisitCallExpr(clang::CallExpr *Call);
bool VisitBinaryOperator(clang::BinaryOperator *E);
......@@ -96,6 +97,7 @@ class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> {
bool VisitMemberExpr(clang::MemberExpr *E);
void set_ptreg(clang::Decl *D) { ptregs_.insert(D); }
void set_ctx(clang::Decl *D) { ctx_ = D; }
std::set<clang::Decl *> get_ptregs() { return ptregs_; }
private:
bool IsContextMemberExpr(clang::Expr *E);
clang::SourceRange expansionRange(clang::SourceRange range);
......@@ -109,19 +111,20 @@ class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> {
std::set<clang::Decl *> ptregs_;
std::set<clang::Decl *> &m_;
clang::Decl *ctx_;
bool track_helpers_;
};
// A helper class to the frontend action, walks the decls
class BTypeConsumer : public clang::ASTConsumer {
public:
explicit BTypeConsumer(clang::ASTContext &C, BFrontendAction &fe, clang::Rewriter &rewriter, std::set<clang::Decl *> &map);
bool HandleTopLevelDecl(clang::DeclGroupRef Group) override;
void HandleTranslationUnit(clang::ASTContext &Context) override;
private:
BFrontendAction &fe_;
MapVisitor map_visitor_;
BTypeVisitor btype_visitor_;
ProbeVisitor probe_visitor_;
ProbeVisitor probe_visitor1_;
ProbeVisitor probe_visitor2_;
};
// Create a B program in 2 phases (everything else is normal C frontend):
......
......@@ -507,6 +507,65 @@ int trace_entry(struct pt_regs *ctx, struct sock *sk,
return 0;
};
int trace_exit(struct pt_regs *ctx) {
u32 pid = bpf_get_current_pid_tgid();
struct sock **skpp;
skpp = currsock.lookup(&pid);
if (skpp) {
struct sock *skp = *skpp;
return skp->__sk_common.skc_dport;
}
return 0;
}
"""
b = BPF(text=bpf_text)
b.load_func("trace_entry", BPF.KPROBE)
b.load_func("trace_exit", BPF.KPROBE)
def test_ext_ptr_maps_reverse(self):
bpf_text = """
#include <uapi/linux/ptrace.h>
#include <net/sock.h>
#include <bcc/proto.h>
BPF_HASH(currsock, u32, struct sock *);
int trace_exit(struct pt_regs *ctx) {
u32 pid = bpf_get_current_pid_tgid();
struct sock **skpp;
skpp = currsock.lookup(&pid);
if (skpp) {
struct sock *skp = *skpp;
return skp->__sk_common.skc_dport;
}
return 0;
}
int trace_entry(struct pt_regs *ctx, struct sock *sk) {
u32 pid = bpf_get_current_pid_tgid();
currsock.update(&pid, &sk);
return 0;
};
"""
b = BPF(text=bpf_text)
b.load_func("trace_entry", BPF.KPROBE)
b.load_func("trace_exit", BPF.KPROBE)
def test_ext_ptr_maps_indirect(self):
bpf_text = """
#include <uapi/linux/ptrace.h>
#include <net/sock.h>
#include <bcc/proto.h>
BPF_HASH(currsock, u32, struct sock *);
int trace_entry(struct pt_regs *ctx, struct sock *sk) {
u32 pid = bpf_get_current_pid_tgid();
struct sock **skp = &sk;
currsock.update(&pid, skp);
return 0;
};
int trace_exit(struct pt_regs *ctx) {
u32 pid = bpf_get_current_pid_tgid();
struct sock **skpp;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment