Commit fe779f31 authored by Paul Chaignon, committed by yonghong-song

Trace external pointers through function returns (#1821)

* Trace external pointers through function returns

Surprisingly, the rewriter wasn't able to trace external pointers
returned by inlined functions until now.  This commit fixes it by
adding functions that return an external pointer to ProbeVisitor's
set of external pointers, along with the levels of indirection.

This change requires reversing a few traversals to visit called
functions before they are called.  Then, we check the presence of an
external pointer on return statements and retrieve that information
at the call expression.

* Test dereferences of external pointers returned by inlined functions

* tcpdrop: remove unnecessary bpf_probe_read calls

e783567a makes these calls unnecessary.
parent f86f7e84
...@@ -107,6 +107,23 @@ class ProbeChecker : public RecursiveASTVisitor<ProbeChecker> { ...@@ -107,6 +107,23 @@ class ProbeChecker : public RecursiveASTVisitor<ProbeChecker> {
: ProbeChecker(arg, ptregs, is_transitive, false) {} : ProbeChecker(arg, ptregs, is_transitive, false) {}
bool VisitCallExpr(CallExpr *E) { bool VisitCallExpr(CallExpr *E) {
needs_probe_ = false; needs_probe_ = false;
if (is_assign_) {
// We're looking for a function that returns an external pointer,
// regardless of the number of dereferences.
for(auto p : ptregs_) {
if (std::get<0>(p) == E->getDirectCallee()) {
needs_probe_ = true;
nb_derefs_ += std::get<1>(p);
return false;
}
}
} else {
tuple<Decl *, int> pt = make_tuple(E->getDirectCallee(), nb_derefs_);
if (ptregs_.find(pt) != ptregs_.end())
needs_probe_ = true;
}
if (!track_helpers_) if (!track_helpers_)
return false; return false;
if (VarDecl *V = dyn_cast<VarDecl>(E->getCalleeDecl())) if (VarDecl *V = dyn_cast<VarDecl>(E->getCalleeDecl()))
...@@ -220,6 +237,12 @@ bool ProbeVisitor::assignsExtPtr(Expr *E, int *nbAddrOf) { ...@@ -220,6 +237,12 @@ bool ProbeVisitor::assignsExtPtr(Expr *E, int *nbAddrOf) {
return true; return true;
} }
/* If the expression contains a call to another function, we need to visit
* that function first to know if a rewrite is necessary (i.e., if the
* function returns an external pointer). */
if (!TraverseStmt(E))
return false;
ProbeChecker checker = ProbeChecker(E, ptregs_, track_helpers_, ProbeChecker checker = ProbeChecker(E, ptregs_, track_helpers_,
true); true);
if (checker.is_transitive()) { if (checker.is_transitive()) {
...@@ -231,8 +254,8 @@ bool ProbeVisitor::assignsExtPtr(Expr *E, int *nbAddrOf) { ...@@ -231,8 +254,8 @@ bool ProbeVisitor::assignsExtPtr(Expr *E, int *nbAddrOf) {
return true; return true;
} }
if (E->getStmtClass() == Stmt::CallExprClass) { if (E->IgnoreParenCasts()->getStmtClass() == Stmt::CallExprClass) {
CallExpr *Call = dyn_cast<CallExpr>(E); CallExpr *Call = dyn_cast<CallExpr>(E->IgnoreParenCasts());
if (MemberExpr *Memb = dyn_cast<MemberExpr>(Call->getCallee()->IgnoreImplicit())) { if (MemberExpr *Memb = dyn_cast<MemberExpr>(Call->getCallee()->IgnoreImplicit())) {
StringRef memb_name = Memb->getMemberDecl()->getName(); StringRef memb_name = Memb->getMemberDecl()->getName();
if (DeclRefExpr *Ref = dyn_cast<DeclRefExpr>(Memb->getBase())) { if (DeclRefExpr *Ref = dyn_cast<DeclRefExpr>(Memb->getBase())) {
...@@ -301,8 +324,45 @@ bool ProbeVisitor::VisitCallExpr(CallExpr *Call) { ...@@ -301,8 +324,45 @@ bool ProbeVisitor::VisitCallExpr(CallExpr *Call) {
} }
if (fn_visited_.find(F) == fn_visited_.end()) { if (fn_visited_.find(F) == fn_visited_.end()) {
fn_visited_.insert(F); fn_visited_.insert(F);
/* Maintains a stack of the number of dereferences for the external
* pointers returned by each function in the call stack or -1 if the
* function didn't return an external pointer. */
ptregs_returned_.push_back(-1);
TraverseDecl(F); TraverseDecl(F);
int nb_derefs = ptregs_returned_.back();
ptregs_returned_.pop_back();
if (nb_derefs != -1) {
tuple<Decl *, int> pt = make_tuple(F, nb_derefs);
ptregs_.insert(pt);
}
}
}
} }
return true;
}
bool ProbeVisitor::VisitReturnStmt(ReturnStmt *R) {
/* If this function wasn't called by another, there's no need to check the
* return statement for external pointers. */
if (ptregs_returned_.size() == 0)
return true;
/* Reverse order of traversals. This is needed if, in the return statement,
* we're calling a function that's returning an external pointer: we need to
* know what the function is returning to decide what this function is
* returning. */
if (!TraverseStmt(R->getRetValue()))
return false;
ProbeChecker checker = ProbeChecker(R->getRetValue(), ptregs_,
track_helpers_, true);
if (checker.needs_probe()) {
int curr_nb_derefs = ptregs_returned_.back();
/* If the function returns external pointers with different levels of
* indirection, we handle the case with the highest level of indirection
* and leave it to the user to manually handle other cases. */
if (checker.get_nb_derefs() > curr_nb_derefs) {
ptregs_returned_.pop_back();
ptregs_returned_.push_back(checker.get_nb_derefs());
} }
} }
return true; return true;
...@@ -359,6 +419,15 @@ bool ProbeVisitor::VisitMemberExpr(MemberExpr *E) { ...@@ -359,6 +419,15 @@ bool ProbeVisitor::VisitMemberExpr(MemberExpr *E) {
return false; return false;
} }
/* If the base of the dereference is a call to another function, we need to
* visit that function first to know if a rewrite is necessary (i.e., if the
* function returns an external pointer). */
if (base->IgnoreParenCasts()->getStmtClass() == Stmt::CallExprClass) {
CallExpr *Call = dyn_cast<CallExpr>(base->IgnoreParenCasts());
if (!TraverseStmt(Call))
return false;
}
// Checks to see if the expression references something that needs to be run // Checks to see if the expression references something that needs to be run
// through bpf_probe_read. // through bpf_probe_read.
if (!ProbeChecker(base, ptregs_, track_helpers_).needs_probe()) if (!ProbeChecker(base, ptregs_, track_helpers_).needs_probe())
......
...@@ -98,6 +98,7 @@ class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> { ...@@ -98,6 +98,7 @@ class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> {
bool VisitVarDecl(clang::VarDecl *Decl); bool VisitVarDecl(clang::VarDecl *Decl);
bool TraverseStmt(clang::Stmt *S); bool TraverseStmt(clang::Stmt *S);
bool VisitCallExpr(clang::CallExpr *Call); bool VisitCallExpr(clang::CallExpr *Call);
bool VisitReturnStmt(clang::ReturnStmt *R);
bool VisitBinaryOperator(clang::BinaryOperator *E); bool VisitBinaryOperator(clang::BinaryOperator *E);
bool VisitUnaryOperator(clang::UnaryOperator *E); bool VisitUnaryOperator(clang::UnaryOperator *E);
bool VisitMemberExpr(clang::MemberExpr *E); bool VisitMemberExpr(clang::MemberExpr *E);
...@@ -120,6 +121,7 @@ class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> { ...@@ -120,6 +121,7 @@ class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> {
std::set<clang::Decl *> &m_; std::set<clang::Decl *> &m_;
clang::Decl *ctx_; clang::Decl *ctx_;
bool track_helpers_; bool track_helpers_;
std::list<int> ptregs_returned_;
}; };
// A helper class to the frontend action, walks the decls // A helper class to the frontend action, walks the decls
......
...@@ -940,6 +940,74 @@ int test(struct __sk_buff *ctx) { ...@@ -940,6 +940,74 @@ int test(struct __sk_buff *ctx) {
b = BPF(text=text) b = BPF(text=text)
fn = b.load_func("test", BPF.SCHED_CLS) fn = b.load_func("test", BPF.SCHED_CLS)
def test_probe_read_return(self):
    # The inlined helper returns a pointer computed from skb fields
    # (skb->head + skb->transport_header). The entry point casts that
    # result, stores it in a local, and dereferences it (th->seq).
    # The test only checks that the program builds and that "test"
    # can be loaded as a kprobe.
    bpf_source = """
#define KBUILD_MODNAME "foo"
#include <uapi/linux/ptrace.h>
#include <linux/tcp.h>
static inline unsigned char *my_skb_transport_header(struct sk_buff *skb) {
return skb->head + skb->transport_header;
}
int test(struct pt_regs *ctx, struct sock *sk, struct sk_buff *skb) {
struct tcphdr *th = (struct tcphdr *)my_skb_transport_header(skb);
return th->seq;
}
"""
    program = BPF(text=bpf_source)
    program.load_func("test", BPF.KPROBE)
def test_probe_read_multiple_return(self):
    # Same scenario as the single-return case, except the inlined helper
    # has two return statements: one returns the skb-derived pointer and
    # the other returns the (cast) result of error_function(), which
    # itself returns 0. The entry point still dereferences the helper's
    # result. The test only checks that the program builds and that
    # "test" can be loaded as a kprobe.
    bpf_source = """
#define KBUILD_MODNAME "foo"
#include <uapi/linux/ptrace.h>
#include <linux/tcp.h>
static inline u64 error_function() {
return 0;
}
static inline unsigned char *my_skb_transport_header(struct sk_buff *skb) {
if (skb)
return skb->head + skb->transport_header;
return (unsigned char *)error_function();
}
int test(struct pt_regs *ctx, struct sock *sk, struct sk_buff *skb) {
struct tcphdr *th = (struct tcphdr *)my_skb_transport_header(skb);
return th->seq;
}
"""
    program = BPF(text=bpf_source)
    program.load_func("test", BPF.KPROBE)
def test_probe_read_return_expr(self):
    # Here the helper's returned pointer is not assigned directly: the
    # call sits inside a larger expression (a cast followed by an
    # offsetof() addition) whose value initialises `seq`, which is then
    # dereferenced. The test only checks that the program builds and
    # that "test" can be loaded as a kprobe.
    bpf_source = """
#define KBUILD_MODNAME "foo"
#include <uapi/linux/ptrace.h>
#include <linux/tcp.h>
static inline unsigned char *my_skb_transport_header(struct sk_buff *skb) {
return skb->head + skb->transport_header;
}
int test(struct pt_regs *ctx, struct sock *sk, struct sk_buff *skb) {
u32 *seq = (u32 *)my_skb_transport_header(skb) + offsetof(struct tcphdr, seq);
return *seq;
}
"""
    program = BPF(text=bpf_source)
    program.load_func("test", BPF.KPROBE)
def test_probe_read_return_call(self):
    # The member access is applied directly to the call expression
    # (my_skb_transport_header(skb)->seq), with no intermediate local
    # holding the returned pointer. The test only checks that the
    # program builds and that "test" can be loaded as a kprobe.
    bpf_source = """
#define KBUILD_MODNAME "foo"
#include <uapi/linux/ptrace.h>
#include <linux/tcp.h>
static inline struct tcphdr *my_skb_transport_header(struct sk_buff *skb) {
return (struct tcphdr *)skb->head + skb->transport_header;
}
int test(struct pt_regs *ctx, struct sock *sk, struct sk_buff *skb) {
return my_skb_transport_header(skb)->seq;
}
"""
    program = BPF(text=bpf_source)
    program.load_func("test", BPF.KPROBE)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -107,16 +107,16 @@ int trace_tcp_drop(struct pt_regs *ctx, struct sock *sk, struct sk_buff *skb) ...@@ -107,16 +107,16 @@ int trace_tcp_drop(struct pt_regs *ctx, struct sock *sk, struct sk_buff *skb)
u8 tcpflags = 0; u8 tcpflags = 0;
struct tcphdr *tcp = skb_to_tcphdr(skb); struct tcphdr *tcp = skb_to_tcphdr(skb);
struct iphdr *ip = skb_to_iphdr(skb); struct iphdr *ip = skb_to_iphdr(skb);
bpf_probe_read(&sport, sizeof(sport), &tcp->source); sport = tcp->source;
bpf_probe_read(&dport, sizeof(dport), &tcp->dest); dport = tcp->dest;
bpf_probe_read(&tcpflags, sizeof(tcpflags), &tcp_flag_byte(tcp)); bpf_probe_read(&tcpflags, sizeof(tcpflags), &tcp_flag_byte(tcp));
sport = ntohs(sport); sport = ntohs(sport);
dport = ntohs(dport); dport = ntohs(dport);
if (family == AF_INET) { if (family == AF_INET) {
struct ipv4_data_t data4 = {.pid = pid, .ip = 4}; struct ipv4_data_t data4 = {.pid = pid, .ip = 4};
bpf_probe_read(&data4.saddr, sizeof(u32), &ip->saddr); data4.saddr = ip->saddr;
bpf_probe_read(&data4.daddr, sizeof(u32), &ip->daddr); data4.daddr = ip->daddr;
data4.dport = dport; data4.dport = dport;
data4.sport = sport; data4.sport = sport;
data4.state = state; data4.state = state;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment