Commit 3fe276e2 authored by Dan Xu's avatar Dan Xu Committed by GitHub

Merge pull request #301 from Birch-san/fix-get-arg-values-use-after-free

Fix use-after-free in BPFtrace::get_arg_values
parents a349c28b 01491c63
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "bpforc.h" #include "bpforc.h"
#include "bpftrace.h" #include "bpftrace.h"
#include "attached_probe.h" #include "attached_probe.h"
#include "printf.h"
#include "triggers.h" #include "triggers.h"
#include "resolve_cgroupid.h" #include "resolve_cgroupid.h"
...@@ -263,7 +264,7 @@ void perf_event_printer(void *cb_cookie, void *data, int size) ...@@ -263,7 +264,7 @@ void perf_event_printer(void *cb_cookie, void *data, int size)
auto id = printf_id - asyncactionint(AsyncAction::syscall); auto id = printf_id - asyncactionint(AsyncAction::syscall);
auto fmt = std::get<0>(bpftrace->system_args_[id]).c_str(); auto fmt = std::get<0>(bpftrace->system_args_[id]).c_str();
auto args = std::get<1>(bpftrace->system_args_[id]); auto args = std::get<1>(bpftrace->system_args_[id]);
std::vector<uint64_t> arg_values = bpftrace->get_arg_values(args, arg_data); auto arg_values = bpftrace->get_arg_values(args, arg_data);
char buffer [255]; char buffer [255];
...@@ -273,30 +274,30 @@ void perf_event_printer(void *cb_cookie, void *data, int size) ...@@ -273,30 +274,30 @@ void perf_event_printer(void *cb_cookie, void *data, int size)
system(fmt); system(fmt);
break; break;
case 1: case 1:
snprintf(buffer, 255, fmt, arg_values.at(0)); snprintf(buffer, 255, fmt, arg_values.at(0)->value());
system(buffer); system(buffer);
break; break;
case 2: case 2:
snprintf(buffer, 255, fmt, arg_values.at(0), arg_values.at(1)); snprintf(buffer, 255, fmt, arg_values.at(0)->value(), arg_values.at(1)->value());
system(buffer); system(buffer);
break; break;
case 3: case 3:
snprintf(buffer, 255, fmt, arg_values.at(0), arg_values.at(1), arg_values.at(2)); snprintf(buffer, 255, fmt, arg_values.at(0)->value(), arg_values.at(1)->value(), arg_values.at(2)->value());
system(buffer); system(buffer);
break; break;
case 4: case 4:
snprintf(buffer, 255, fmt, arg_values.at(0), arg_values.at(1), arg_values.at(2), snprintf(buffer, 255, fmt, arg_values.at(0)->value(), arg_values.at(1)->value(), arg_values.at(2)->value(),
arg_values.at(3)); arg_values.at(3)->value());
system(buffer); system(buffer);
break; break;
case 5: case 5:
snprintf(buffer, 255, fmt, arg_values.at(0), arg_values.at(1), arg_values.at(2), snprintf(buffer, 255, fmt, arg_values.at(0)->value(), arg_values.at(1)->value(), arg_values.at(2)->value(),
arg_values.at(3), arg_values.at(4)); arg_values.at(3)->value(), arg_values.at(4)->value());
system(buffer); system(buffer);
break; break;
case 6: case 6:
snprintf(buffer, 255, fmt, arg_values.at(0), arg_values.at(1), arg_values.at(2), snprintf(buffer, 255, fmt, arg_values.at(0)->value(), arg_values.at(1)->value(), arg_values.at(2)->value(),
arg_values.at(3), arg_values.at(4), arg_values.at(5)); arg_values.at(3)->value(), arg_values.at(4)->value(), arg_values.at(5)->value());
system(buffer); system(buffer);
break; break;
default: default:
...@@ -309,7 +310,7 @@ void perf_event_printer(void *cb_cookie, void *data, int size) ...@@ -309,7 +310,7 @@ void perf_event_printer(void *cb_cookie, void *data, int size)
// printf // printf
auto fmt = std::get<0>(bpftrace->printf_args_[printf_id]).c_str(); auto fmt = std::get<0>(bpftrace->printf_args_[printf_id]).c_str();
auto args = std::get<1>(bpftrace->printf_args_[printf_id]); auto args = std::get<1>(bpftrace->printf_args_[printf_id]);
std::vector<uint64_t> arg_values = bpftrace->get_arg_values(args, arg_data); auto arg_values = bpftrace->get_arg_values(args, arg_data);
switch (args.size()) switch (args.size())
{ {
...@@ -317,38 +318,35 @@ void perf_event_printer(void *cb_cookie, void *data, int size) ...@@ -317,38 +318,35 @@ void perf_event_printer(void *cb_cookie, void *data, int size)
printf(fmt); printf(fmt);
break; break;
case 1: case 1:
printf(fmt, arg_values.at(0)); printf(fmt, arg_values.at(0)->value());
break; break;
case 2: case 2:
printf(fmt, arg_values.at(0), arg_values.at(1)); printf(fmt, arg_values.at(0)->value(), arg_values.at(1)->value());
break; break;
case 3: case 3:
printf(fmt, arg_values.at(0), arg_values.at(1), arg_values.at(2)); printf(fmt, arg_values.at(0)->value(), arg_values.at(1)->value(), arg_values.at(2)->value());
break; break;
case 4: case 4:
printf(fmt, arg_values.at(0), arg_values.at(1), arg_values.at(2), printf(fmt, arg_values.at(0)->value(), arg_values.at(1)->value(), arg_values.at(2)->value(),
arg_values.at(3)); arg_values.at(3)->value());
break; break;
case 5: case 5:
printf(fmt, arg_values.at(0), arg_values.at(1), arg_values.at(2), printf(fmt, arg_values.at(0)->value(), arg_values.at(1)->value(), arg_values.at(2)->value(),
arg_values.at(3), arg_values.at(4)); arg_values.at(3)->value(), arg_values.at(4)->value());
break; break;
case 6: case 6:
printf(fmt, arg_values.at(0), arg_values.at(1), arg_values.at(2), printf(fmt, arg_values.at(0)->value(), arg_values.at(1)->value(), arg_values.at(2)->value(),
arg_values.at(3), arg_values.at(4), arg_values.at(5)); arg_values.at(3)->value(), arg_values.at(4)->value(), arg_values.at(5)->value());
break; break;
default: default:
abort(); abort();
} }
} }
std::vector<uint64_t> BPFtrace::get_arg_values(std::vector<Field> args, uint8_t* arg_data) std::vector<std::unique_ptr<IPrintable>> BPFtrace::get_arg_values(const std::vector<Field> &args, uint8_t* arg_data)
{ {
std::vector<uint64_t> arg_values; std::vector<std::unique_ptr<IPrintable>> arg_values;
std::vector<std::unique_ptr<char>> resolved_symbols;
std::vector<std::unique_ptr<char>> resolved_usernames;
char *name;
for (auto arg : args) for (auto arg : args)
{ {
switch (arg.type.type) switch (arg.type.type)
...@@ -357,54 +355,80 @@ std::vector<uint64_t> BPFtrace::get_arg_values(std::vector<Field> args, uint8_t* ...@@ -357,54 +355,80 @@ std::vector<uint64_t> BPFtrace::get_arg_values(std::vector<Field> args, uint8_t*
switch (arg.type.size) switch (arg.type.size)
{ {
case 8: case 8:
arg_values.push_back(*(uint64_t*)(arg_data+arg.offset)); arg_values.push_back(
std::make_unique<PrintableInt>(
*reinterpret_cast<uint64_t*>(arg_data+arg.offset)));
break; break;
case 4: case 4:
arg_values.push_back(*(uint32_t*)(arg_data+arg.offset)); arg_values.push_back(
std::make_unique<PrintableInt>(
*reinterpret_cast<uint32_t*>(arg_data+arg.offset)));
break; break;
case 2: case 2:
arg_values.push_back(*(uint16_t*)(arg_data+arg.offset)); arg_values.push_back(
std::make_unique<PrintableInt>(
*reinterpret_cast<uint16_t*>(arg_data+arg.offset)));
break; break;
case 1: case 1:
arg_values.push_back(*(uint8_t*)(arg_data+arg.offset)); arg_values.push_back(
std::make_unique<PrintableInt>(
*reinterpret_cast<uint8_t*>(arg_data+arg.offset)));
break; break;
default: default:
abort(); abort();
} }
break; break;
case Type::string: case Type::string:
arg_values.push_back((uint64_t)(arg_data+arg.offset)); arg_values.push_back(
std::make_unique<PrintableCString>(
reinterpret_cast<char *>(arg_data+arg.offset)));
break; break;
case Type::sym: case Type::sym:
resolved_symbols.emplace_back(strdup( arg_values.push_back(
resolve_sym(*(uint64_t*)(arg_data+arg.offset)).c_str())); std::make_unique<PrintableString>(
arg_values.push_back((uint64_t)resolved_symbols.back().get()); resolve_sym(*reinterpret_cast<uint64_t*>(arg_data+arg.offset))));
break; break;
case Type::usym: case Type::usym:
resolved_symbols.emplace_back(strdup( arg_values.push_back(
resolve_usym(*(uint64_t*)(arg_data+arg.offset), *(uint64_t*)(arg_data+arg.offset + 8)).c_str())); std::make_unique<PrintableString>(
arg_values.push_back((uint64_t)resolved_symbols.back().get()); resolve_usym(
*reinterpret_cast<uint64_t*>(arg_data+arg.offset),
*reinterpret_cast<uint64_t*>(arg_data+arg.offset + 8))));
break; break;
case Type::inet: case Type::inet:
name = strdup(resolve_inet(*(uint64_t*)(arg_data+arg.offset), *(uint64_t*)(arg_data+arg.offset+8)).c_str()); arg_values.push_back(
arg_values.push_back((uint64_t)name); std::make_unique<PrintableString>(
resolve_inet(
*reinterpret_cast<uint64_t*>(arg_data+arg.offset),
*reinterpret_cast<uint64_t*>(arg_data+arg.offset+8))));
break; break;
case Type::username: case Type::username:
resolved_usernames.emplace_back(strdup( arg_values.push_back(
resolve_uid(*(uint64_t*)(arg_data+arg.offset)).c_str())); std::make_unique<PrintableString>(
arg_values.push_back((uint64_t)resolved_usernames.back().get()); resolve_uid(
*reinterpret_cast<uint64_t*>(arg_data+arg.offset))));
break; break;
case Type::probe: case Type::probe:
name = strdup(resolve_probe(*(uint64_t*)(arg_data+arg.offset)).c_str()); arg_values.push_back(
arg_values.push_back((uint64_t)name); std::make_unique<PrintableString>(
resolve_probe(
*reinterpret_cast<uint64_t*>(arg_data+arg.offset))));
break; break;
case Type::stack: case Type::stack:
name = strdup(get_stack(*(uint64_t*)(arg_data+arg.offset), false, 8).c_str()); arg_values.push_back(
arg_values.push_back((uint64_t)name); std::make_unique<PrintableString>(
get_stack(
*reinterpret_cast<uint64_t*>(arg_data+arg.offset),
false,
8)));
break; break;
case Type::ustack: case Type::ustack:
name = strdup(get_stack(*(uint64_t*)(arg_data+arg.offset), true, 8).c_str()); arg_values.push_back(
arg_values.push_back((uint64_t)name); std::make_unique<PrintableString>(
get_stack(
*reinterpret_cast<uint64_t*>(arg_data+arg.offset),
true,
8)));
break; break;
default: default:
abort(); abort();
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include "ast.h" #include "ast.h"
#include "attached_probe.h" #include "attached_probe.h"
#include "imap.h" #include "imap.h"
#include "printf.h"
#include "struct.h" #include "struct.h"
#include "types.h" #include "types.h"
...@@ -68,7 +69,7 @@ public: ...@@ -68,7 +69,7 @@ public:
std::string extract_func_symbols_from_path(const std::string &path); std::string extract_func_symbols_from_path(const std::string &path);
std::string resolve_probe(uint64_t probe_id); std::string resolve_probe(uint64_t probe_id);
uint64_t resolve_cgroupid(const std::string &path); uint64_t resolve_cgroupid(const std::string &path);
std::vector<uint64_t> get_arg_values(std::vector<Field> args, uint8_t* arg_data); std::vector<std::unique_ptr<IPrintable>> get_arg_values(const std::vector<Field> &args, uint8_t* arg_data);
void add_param(const std::string &param); void add_param(const std::string &param);
bool is_numeric(std::string str); bool is_numeric(std::string str);
std::string get_param(int index); std::string get_param(int index);
......
...@@ -64,4 +64,19 @@ std::string verify_format_string(const std::string &fmt, std::vector<Field> args ...@@ -64,4 +64,19 @@ std::string verify_format_string(const std::string &fmt, std::vector<Field> args
return ""; return "";
} }
uint64_t PrintableString::value()
{
return (uint64_t)value_.c_str();
}
uint64_t PrintableCString::value()
{
return (uint64_t)value_;
}
uint64_t PrintableInt::value()
{
return value_;
}
} // namespace bpftrace } // namespace bpftrace
#pragma once
#include <sstream> #include <sstream>
#include "ast.h" #include "ast.h"
...@@ -9,4 +11,39 @@ struct Field; ...@@ -9,4 +11,39 @@ struct Field;
std::string verify_format_string(const std::string &fmt, std::vector<Field> args); std::string verify_format_string(const std::string &fmt, std::vector<Field> args);
class IPrintable
{
public:
virtual ~IPrintable() { };
virtual uint64_t value() = 0;
};
class PrintableString : public virtual IPrintable
{
public:
PrintableString(std::string value) : value_(std::move(value)) { }
uint64_t value();
private:
std::string value_;
};
class PrintableCString : public virtual IPrintable
{
public:
PrintableCString(char* value) : value_(value) { }
uint64_t value();
private:
char* value_;
};
class PrintableInt : public virtual IPrintable
{
public:
PrintableInt(uint64_t value) : value_(value) { }
uint64_t value();
private:
uint64_t value_;
};
} // namespace bpftrace } // namespace bpftrace
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment