Commit 01c843e7 authored by yonghong-song's avatar yonghong-song Committed by GitHub

Merge pull request #1724 from pchaigno/detect-ext-ptr-from-ctx

Detect external pointers from context argument
parents cf442785 a8b4cee4
...@@ -156,8 +156,9 @@ ProbeVisitor::ProbeVisitor(ASTContext &C, Rewriter &rewriter, set<Decl *> &m) : ...@@ -156,8 +156,9 @@ ProbeVisitor::ProbeVisitor(ASTContext &C, Rewriter &rewriter, set<Decl *> &m) :
bool ProbeVisitor::VisitVarDecl(VarDecl *Decl) { bool ProbeVisitor::VisitVarDecl(VarDecl *Decl) {
if (Expr *E = Decl->getInit()) { if (Expr *E = Decl->getInit()) {
if (ProbeChecker(E, ptregs_).is_transitive()) if (ProbeChecker(E, ptregs_).is_transitive() || IsContextMemberExpr(E)) {
set_ptreg(Decl); set_ptreg(Decl);
}
} }
return true; return true;
} }
...@@ -185,7 +186,7 @@ bool ProbeVisitor::VisitBinaryOperator(BinaryOperator *E) { ...@@ -185,7 +186,7 @@ bool ProbeVisitor::VisitBinaryOperator(BinaryOperator *E) {
if (ProbeChecker(E->getRHS(), ptregs_).is_transitive()) { if (ProbeChecker(E->getRHS(), ptregs_).is_transitive()) {
ProbeSetter setter(&ptregs_); ProbeSetter setter(&ptregs_);
setter.TraverseStmt(E->getLHS()); setter.TraverseStmt(E->getLHS());
} else if (E->isAssignmentOp() && E->getRHS()->getStmtClass() == Stmt::CallExprClass) { } else if (E->getRHS()->getStmtClass() == Stmt::CallExprClass) {
CallExpr *Call = dyn_cast<CallExpr>(E->getRHS()); CallExpr *Call = dyn_cast<CallExpr>(E->getRHS());
if (MemberExpr *Memb = dyn_cast<MemberExpr>(Call->getCallee()->IgnoreImplicit())) { if (MemberExpr *Memb = dyn_cast<MemberExpr>(Call->getCallee()->IgnoreImplicit())) {
StringRef memb_name = Memb->getMemberDecl()->getName(); StringRef memb_name = Memb->getMemberDecl()->getName();
...@@ -204,6 +205,9 @@ bool ProbeVisitor::VisitBinaryOperator(BinaryOperator *E) { ...@@ -204,6 +205,9 @@ bool ProbeVisitor::VisitBinaryOperator(BinaryOperator *E) {
} }
} }
} }
} else if (IsContextMemberExpr(E->getRHS())) {
ProbeSetter setter(&ptregs_);
setter.TraverseStmt(E->getLHS());
} }
return true; return true;
} }
...@@ -262,6 +266,40 @@ bool ProbeVisitor::VisitMemberExpr(MemberExpr *E) { ...@@ -262,6 +266,40 @@ bool ProbeVisitor::VisitMemberExpr(MemberExpr *E) {
return true; return true;
} }
bool ProbeVisitor::IsContextMemberExpr(Expr *E) {
if (!E->getType()->isPointerType())
return false;
MemberExpr *Memb = dyn_cast<MemberExpr>(E->IgnoreParenCasts());
Expr *base;
SourceLocation rhs_start, member;
bool found = false;
MemberExpr *M;
for (M = Memb; M; M = dyn_cast<MemberExpr>(M->getBase())) {
memb_visited_.insert(M);
rhs_start = M->getLocEnd();
base = M->getBase();
member = M->getMemberLoc();
if (M->isArrow()) {
found = true;
break;
}
}
if (!found) {
return false;
}
if (member.isInvalid()) {
return false;
}
if (DeclRefExpr *base_expr = dyn_cast<DeclRefExpr>(base->IgnoreImplicit())) {
if (base_expr->getDecl() == ctx_) {
return true;
}
}
return false;
}
SourceRange SourceRange
ProbeVisitor::expansionRange(SourceRange range) { ProbeVisitor::expansionRange(SourceRange range) {
return rewriter_.getSourceMgr().getExpansionRange(range); return rewriter_.getSourceMgr().getExpansionRange(range);
...@@ -862,8 +900,11 @@ void BTypeConsumer::HandleTranslationUnit(ASTContext &Context) { ...@@ -862,8 +900,11 @@ void BTypeConsumer::HandleTranslationUnit(ASTContext &Context) {
if (FunctionDecl *F = dyn_cast<FunctionDecl>(D)) { if (FunctionDecl *F = dyn_cast<FunctionDecl>(D)) {
if (fe_.is_rewritable_ext_func(F)) { if (fe_.is_rewritable_ext_func(F)) {
for (auto arg : F->parameters()) { for (auto arg : F->parameters()) {
if (arg != F->getParamDecl(0) && !arg->getType()->isFundamentalType()) if (arg == F->getParamDecl(0)) {
probe_visitor_.set_ctx(arg);
} else if (!arg->getType()->isFundamentalType()) {
probe_visitor_.set_ptreg(arg); probe_visitor_.set_ptreg(arg);
}
} }
probe_visitor_.TraverseDecl(D); probe_visitor_.TraverseDecl(D);
} }
......
...@@ -95,7 +95,9 @@ class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> { ...@@ -95,7 +95,9 @@ class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> {
bool VisitUnaryOperator(clang::UnaryOperator *E); bool VisitUnaryOperator(clang::UnaryOperator *E);
bool VisitMemberExpr(clang::MemberExpr *E); bool VisitMemberExpr(clang::MemberExpr *E);
void set_ptreg(clang::Decl *D) { ptregs_.insert(D); } void set_ptreg(clang::Decl *D) { ptregs_.insert(D); }
void set_ctx(clang::Decl *D) { ctx_ = D; }
private: private:
bool IsContextMemberExpr(clang::Expr *E);
clang::SourceRange expansionRange(clang::SourceRange range); clang::SourceRange expansionRange(clang::SourceRange range);
template <unsigned N> template <unsigned N>
clang::DiagnosticBuilder error(clang::SourceLocation loc, const char (&fmt)[N]); clang::DiagnosticBuilder error(clang::SourceLocation loc, const char (&fmt)[N]);
...@@ -106,6 +108,7 @@ class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> { ...@@ -106,6 +108,7 @@ class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> {
std::set<clang::Expr *> memb_visited_; std::set<clang::Expr *> memb_visited_;
std::set<clang::Decl *> ptregs_; std::set<clang::Decl *> ptregs_;
std::set<clang::Decl *> &m_; std::set<clang::Decl *> &m_;
clang::Decl *ctx_;
}; };
// A helper class to the frontend action, walks the decls // A helper class to the frontend action, walks the decls
......
...@@ -701,6 +701,30 @@ BPF_HASH(table1, struct key_t, struct value_t); ...@@ -701,6 +701,30 @@ BPF_HASH(table1, struct key_t, struct value_t);
self.assertEqual(ct.sizeof(table.Key), 96) self.assertEqual(ct.sizeof(table.Key), 96)
self.assertEqual(ct.sizeof(table.Leaf), 16) self.assertEqual(ct.sizeof(table.Leaf), 16)
@skipUnless(kernel_version_ge(4,7), "requires kernel >= 4.7")
def test_probe_read_tracepoint_context(self):
text = """
#include <linux/netdevice.h>
TRACEPOINT_PROBE(skb, kfree_skb) {
struct sk_buff *skb = (struct sk_buff *)args->skbaddr;
return skb->protocol;
}
"""
b = BPF(text=text)
def test_probe_read_kprobe_ctx(self):
text = """
#include <linux/sched.h>
#include <net/inet_sock.h>
int test(struct pt_regs *ctx) {
struct sock *sk;
sk = (struct sock *)ctx->di;
return sk->sk_dport;
}
"""
b = BPF(text=text)
fn = b.load_func("test", BPF.KPROBE)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment