Commit cb9773a9 authored by 4ast's avatar 4ast Committed by GitHub

Merge pull request #1850 from pchaigno/array-access-to-probe-reads

Rewrite array accesses
parents b2d18a78 4ba5c09e
......@@ -25,6 +25,7 @@
#include <clang/Frontend/CompilerInstance.h>
#include <clang/Frontend/MultiplexConsumer.h>
#include <clang/Rewrite/Core/Rewriter.h>
#include <clang/Lex/Lexer.h>
#include "b_frontend_action.h"
#include "bpf_module.h"
......@@ -487,7 +488,64 @@ bool ProbeVisitor::VisitMemberExpr(MemberExpr *E) {
rewriter_.ReplaceText(expansionRange(SourceRange(member, E->getLocEnd())), post);
return true;
}
bool ProbeVisitor::VisitArraySubscriptExpr(ArraySubscriptExpr *E) {
if (memb_visited_.find(E) != memb_visited_.end()) return true;
if (!ProbeChecker(E, ptregs_, track_helpers_).needs_probe())
return true;
// Parent expr has addrof, skip the rewrite.
if (is_addrof_)
return true;
if (!rewriter_.isRewritable(E->getLocStart()))
return true;
Expr *base = E->getBase();
Expr *idx = E->getIdx();
memb_visited_.insert(E);
string pre, lbracket, rbracket;
LangOptions opts;
SourceLocation lbracket_start, lbracket_end;
SourceRange lbracket_range;
pre = "({ typeof(" + E->getType().getAsString() + ") _val; __builtin_memset(&_val, 0, sizeof(_val));";
pre += " bpf_probe_read(&_val, sizeof(_val), (u64)(";
if (isMemberDereference(base)) {
pre += "&";
// If the base of the array subscript is a member dereference, we'll rewrite
// both at the same time.
addrof_stmt_ = base;
is_addrof_ = true;
}
rewriter_.InsertText(expansionLoc(base->getLocStart()), pre);
/* Replace left bracket and any space around it. Since Clang doesn't provide
* a method to retrieve the left bracket, replace everything from the end of
* the base to the start of the index. */
lbracket = ") + (";
lbracket_start = Lexer::getLocForEndOfToken(base->getLocEnd(), 1,
rewriter_.getSourceMgr(),
opts).getLocWithOffset(1);
lbracket_end = idx->getLocStart().getLocWithOffset(-1);
lbracket_range = expansionRange(SourceRange(lbracket_start, lbracket_end));
rewriter_.ReplaceText(lbracket_range, lbracket);
rbracket = ")); _val; })";
rewriter_.ReplaceText(expansionLoc(E->getRBracketLoc()), 1, rbracket);
return true;
}
bool ProbeVisitor::isMemberDereference(Expr *E) {
if (E->IgnoreParenCasts()->getStmtClass() != Stmt::MemberExprClass)
return false;
for (MemberExpr *M = dyn_cast<MemberExpr>(E->IgnoreParenCasts()); M;
M = dyn_cast<MemberExpr>(M->getBase()->IgnoreParenCasts())) {
if (M->isArrow())
return true;
}
return false;
}
bool ProbeVisitor::IsContextMemberExpr(Expr *E) {
if (!E->getType()->isPointerType())
return false;
......
......@@ -102,11 +102,13 @@ class ProbeVisitor : public clang::RecursiveASTVisitor<ProbeVisitor> {
bool VisitBinaryOperator(clang::BinaryOperator *E);
bool VisitUnaryOperator(clang::UnaryOperator *E);
bool VisitMemberExpr(clang::MemberExpr *E);
bool VisitArraySubscriptExpr(clang::ArraySubscriptExpr *E);
void set_ptreg(std::tuple<clang::Decl *, int> &pt) { ptregs_.insert(pt); }
void set_ctx(clang::Decl *D) { ctx_ = D; }
std::set<std::tuple<clang::Decl *, int>> get_ptregs() { return ptregs_; }
private:
bool assignsExtPtr(clang::Expr *E, int *nbAddrOf);
bool isMemberDereference(clang::Expr *E);
bool IsContextMemberExpr(clang::Expr *E);
clang::SourceRange expansionRange(clang::SourceRange range);
clang::SourceLocation expansionLoc(clang::SourceLocation loc);
......
......@@ -1126,6 +1126,91 @@ int test(struct pt_regs *ctx) {
b = BPF(text=text)
fn = b.load_func("test", BPF.KPROBE)
def test_probe_read_array_accesses1(self):
text = """
#include <linux/ptrace.h>
#include <linux/dcache.h>
int test(struct pt_regs *ctx, const struct qstr *name) {
return name->name[1];
}
"""
b = BPF(text=text)
fn = b.load_func("test", BPF.KPROBE)
def test_probe_read_array_accesses2(self):
text = """
#include <linux/ptrace.h>
#include <linux/dcache.h>
int test(struct pt_regs *ctx, const struct qstr *name) {
return name->name [ 1];
}
"""
b = BPF(text=text)
fn = b.load_func("test", BPF.KPROBE)
def test_probe_read_array_accesses3(self):
text = """
#include <linux/ptrace.h>
#include <linux/dcache.h>
int test(struct pt_regs *ctx, const struct qstr *name) {
return (name->name)[1];
}
"""
b = BPF(text=text)
fn = b.load_func("test", BPF.KPROBE)
def test_probe_read_array_accesses4(self):
text = """
#include <linux/ptrace.h>
int test(struct pt_regs *ctx, char *name) {
return name[1];
}
"""
b = BPF(text=text)
fn = b.load_func("test", BPF.KPROBE)
def test_probe_read_array_accesses5(self):
text = """
#include <linux/ptrace.h>
int test(struct pt_regs *ctx, char **name) {
return (*name)[1];
}
"""
b = BPF(text=text)
fn = b.load_func("test", BPF.KPROBE)
def test_probe_read_array_accesses6(self):
text = """
#include <linux/ptrace.h>
struct test_t {
int tab[5];
};
int test(struct pt_regs *ctx, struct test_t *t) {
return *(&t->tab[1]);
}
"""
b = BPF(text=text)
fn = b.load_func("test", BPF.KPROBE)
def test_probe_read_array_accesses7(self):
text = """
#include <net/inet_sock.h>
int test(struct pt_regs *ctx, struct sock *sk) {
return sk->__sk_common.skc_v6_rcv_saddr.in6_u.u6_addr32[0];
}
"""
b = BPF(text=text)
fn = b.load_func("test", BPF.KPROBE)
def test_probe_read_array_accesses8(self):
text = """
#include <linux/mm_types.h>
int test(struct pt_regs *ctx, struct mm_struct *mm) {
return mm->rss_stat.count[MM_ANONPAGES].counter;
}
"""
b = BPF(text=text)
fn = b.load_func("test", BPF.KPROBE)
if __name__ == "__main__":
main()
......@@ -104,12 +104,11 @@ int trace_tcp_drop(struct pt_regs *ctx, struct sock *sk, struct sk_buff *skb)
u16 family = sk->__sk_common.skc_family;
char state = sk->__sk_common.skc_state;
u16 sport = 0, dport = 0;
u8 tcpflags = 0;
struct tcphdr *tcp = skb_to_tcphdr(skb);
struct iphdr *ip = skb_to_iphdr(skb);
u8 tcpflags = ((u_int8_t *)tcp)[13];
sport = tcp->source;
dport = tcp->dest;
bpf_probe_read(&tcpflags, sizeof(tcpflags), &tcp_flag_byte(tcp));
sport = ntohs(sport);
dport = ntohs(dport);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment