Commit 47b74fe0 authored by Paul Chaignon's avatar Paul Chaignon

Fix bpf_dins_pkt rewrite in BinaryOperator

Binary operator expressions where the left hand-side expression is a
reference to the packet are replaced by a call to the bpf_dins_pkt
helper. When replacing text, the Clang Rewriter tries to maintain a
list of offsets between the original and the new position of tokens.

Replacing the whole binary operator expression with the call to
bpf_dins_pkt confuses the Rewriter and it is unable to track the new
position of the right hand-side expression. Rewriting the binary
operator expression in two times without rewriting the right
hand-side expression itself solves the issue.
parent 0a2a46e2
...@@ -521,7 +521,6 @@ bool BTypeVisitor::VisitBinaryOperator(BinaryOperator *E) { ...@@ -521,7 +521,6 @@ bool BTypeVisitor::VisitBinaryOperator(BinaryOperator *E) {
if (!E->isAssignmentOp()) if (!E->isAssignmentOp())
return true; return true;
Expr *LHS = E->getLHS()->IgnoreImplicit(); Expr *LHS = E->getLHS()->IgnoreImplicit();
Expr *RHS = E->getRHS()->IgnoreImplicit();
if (MemberExpr *Memb = dyn_cast<MemberExpr>(LHS)) { if (MemberExpr *Memb = dyn_cast<MemberExpr>(LHS)) {
if (DeclRefExpr *Base = dyn_cast<DeclRefExpr>(Memb->getBase()->IgnoreImplicit())) { if (DeclRefExpr *Base = dyn_cast<DeclRefExpr>(Memb->getBase()->IgnoreImplicit())) {
if (DeprecatedAttr *A = Base->getDecl()->getAttr<DeprecatedAttr>()) { if (DeprecatedAttr *A = Base->getDecl()->getAttr<DeprecatedAttr>()) {
...@@ -534,10 +533,10 @@ bool BTypeVisitor::VisitBinaryOperator(BinaryOperator *E) { ...@@ -534,10 +533,10 @@ bool BTypeVisitor::VisitBinaryOperator(BinaryOperator *E) {
uint64_t ofs = C.getFieldOffset(F); uint64_t ofs = C.getFieldOffset(F);
uint64_t sz = F->isBitField() ? F->getBitWidthValue(C) : C.getTypeSize(F->getType()); uint64_t sz = F->isBitField() ? F->getBitWidthValue(C) : C.getTypeSize(F->getType());
string base = rewriter_.getRewrittenText(expansionRange(Base->getSourceRange())); string base = rewriter_.getRewrittenText(expansionRange(Base->getSourceRange()));
string rhs = rewriter_.getRewrittenText(expansionRange(RHS->getSourceRange()));
string text = "bpf_dins_pkt(" + fn_args_[0]->getName().str() + ", (u64)" + base + "+" + to_string(ofs >> 3) string text = "bpf_dins_pkt(" + fn_args_[0]->getName().str() + ", (u64)" + base + "+" + to_string(ofs >> 3)
+ ", " + to_string(ofs & 0x7) + ", " + to_string(sz) + ", " + rhs + ")"; + ", " + to_string(ofs & 0x7) + ", " + to_string(sz) + ",";
rewriter_.ReplaceText(expansionRange(E->getSourceRange()), text); rewriter_.ReplaceText(expansionRange(SourceRange(E->getLocStart(), E->getOperatorLoc())), text);
rewriter_.InsertTextAfterToken(E->getLocEnd(), ")");
} }
} }
} }
......
...@@ -392,5 +392,24 @@ int process(struct xdp_md *ctx) { ...@@ -392,5 +392,24 @@ int process(struct xdp_md *ctx) {
t = b["act"] t = b["act"]
self.assertEquals(len(t), 32); self.assertEquals(len(t), 32);
def test_bpf_dins_pkt_rewrite(self):
text = """
#include <bcc/proto.h>
int dns_test(struct __sk_buff *skb) {
u8 *cursor = 0;
struct ethernet_t *ethernet = cursor_advance(cursor, sizeof(*ethernet));
if(ethernet->type == ETH_P_IP) {
struct ip_t *ip = cursor_advance(cursor, sizeof(*ip));
ip->src = ip->dst;
return 0;
}
return -1;
}
"""
b = BPF(text=text)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment