Commit 28ba934d authored by David S. Miller's avatar David S. Miller

Merge git://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf

Alexei Starovoitov says:

====================
pull-request: bpf 2019-07-25

The following pull-request contains BPF updates for your *net* tree.

The main changes are:

1) fix segfault in libbpf, from Andrii.

2) fix gso_segs access, from Eric.

3) tls/sockmap fixes, from Jakub and John.
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 47d858d0 cb8ffde5
...@@ -513,3 +513,9 @@ Redirects leak clear text ...@@ -513,3 +513,9 @@ Redirects leak clear text
In the RX direction, if segment has already been decrypted by the device In the RX direction, if segment has already been decrypted by the device
and it gets redirected or mirrored - clear text will be transmitted out. and it gets redirected or mirrored - clear text will be transmitted out.
shutdown() doesn't clear TLS state
----------------------------------
shutdown() system call allows for a TLS socket to be reused as a different
connection. Offload doesn't currently handle that.
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#include <net/sch_generic.h> #include <net/sch_generic.h>
#include <asm/byteorder.h>
#include <uapi/linux/filter.h> #include <uapi/linux/filter.h>
#include <uapi/linux/bpf.h> #include <uapi/linux/bpf.h>
...@@ -747,6 +748,18 @@ bpf_ctx_narrow_access_ok(u32 off, u32 size, u32 size_default) ...@@ -747,6 +748,18 @@ bpf_ctx_narrow_access_ok(u32 off, u32 size, u32 size_default)
return size <= size_default && (size & (size - 1)) == 0; return size <= size_default && (size & (size - 1)) == 0;
} }
static inline u8
bpf_ctx_narrow_load_shift(u32 off, u32 size, u32 size_default)
{
u8 load_off = off & (size_default - 1);
#ifdef __LITTLE_ENDIAN
return load_off * 8;
#else
return (size_default - (load_off + size)) * 8;
#endif
}
#define bpf_ctx_wide_access_ok(off, size, type, field) \ #define bpf_ctx_wide_access_ok(off, size, type, field) \
(size == sizeof(__u64) && \ (size == sizeof(__u64) && \
off >= offsetof(type, field) && \ off >= offsetof(type, field) && \
......
...@@ -354,6 +354,12 @@ static inline void sk_psock_restore_proto(struct sock *sk, ...@@ -354,6 +354,12 @@ static inline void sk_psock_restore_proto(struct sock *sk,
sk->sk_write_space = psock->saved_write_space; sk->sk_write_space = psock->saved_write_space;
if (psock->sk_proto) { if (psock->sk_proto) {
struct inet_connection_sock *icsk = inet_csk(sk);
bool has_ulp = !!icsk->icsk_ulp_data;
if (has_ulp)
tcp_update_ulp(sk, psock->sk_proto);
else
sk->sk_prot = psock->sk_proto; sk->sk_prot = psock->sk_proto;
psock->sk_proto = NULL; psock->sk_proto = NULL;
} }
......
...@@ -2108,6 +2108,8 @@ struct tcp_ulp_ops { ...@@ -2108,6 +2108,8 @@ struct tcp_ulp_ops {
/* initialize ulp */ /* initialize ulp */
int (*init)(struct sock *sk); int (*init)(struct sock *sk);
/* update ulp */
void (*update)(struct sock *sk, struct proto *p);
/* cleanup ulp */ /* cleanup ulp */
void (*release)(struct sock *sk); void (*release)(struct sock *sk);
...@@ -2119,6 +2121,7 @@ void tcp_unregister_ulp(struct tcp_ulp_ops *type); ...@@ -2119,6 +2121,7 @@ void tcp_unregister_ulp(struct tcp_ulp_ops *type);
int tcp_set_ulp(struct sock *sk, const char *name); int tcp_set_ulp(struct sock *sk, const char *name);
void tcp_get_available_ulp(char *buf, size_t len); void tcp_get_available_ulp(char *buf, size_t len);
void tcp_cleanup_ulp(struct sock *sk); void tcp_cleanup_ulp(struct sock *sk);
void tcp_update_ulp(struct sock *sk, struct proto *p);
#define MODULE_ALIAS_TCP_ULP(name) \ #define MODULE_ALIAS_TCP_ULP(name) \
__MODULE_INFO(alias, alias_userspace, name); \ __MODULE_INFO(alias, alias_userspace, name); \
......
...@@ -107,9 +107,7 @@ struct tls_device { ...@@ -107,9 +107,7 @@ struct tls_device {
enum { enum {
TLS_BASE, TLS_BASE,
TLS_SW, TLS_SW,
#ifdef CONFIG_TLS_DEVICE
TLS_HW, TLS_HW,
#endif
TLS_HW_RECORD, TLS_HW_RECORD,
TLS_NUM_CONFIG, TLS_NUM_CONFIG,
}; };
...@@ -162,6 +160,7 @@ struct tls_sw_context_tx { ...@@ -162,6 +160,7 @@ struct tls_sw_context_tx {
int async_capable; int async_capable;
#define BIT_TX_SCHEDULED 0 #define BIT_TX_SCHEDULED 0
#define BIT_TX_CLOSING 1
unsigned long tx_bitmask; unsigned long tx_bitmask;
}; };
...@@ -272,6 +271,8 @@ struct tls_context { ...@@ -272,6 +271,8 @@ struct tls_context {
unsigned long flags; unsigned long flags;
/* cache cold stuff */ /* cache cold stuff */
struct proto *sk_proto;
void (*sk_destruct)(struct sock *sk); void (*sk_destruct)(struct sock *sk);
void (*sk_proto_close)(struct sock *sk, long timeout); void (*sk_proto_close)(struct sock *sk, long timeout);
...@@ -289,6 +290,8 @@ struct tls_context { ...@@ -289,6 +290,8 @@ struct tls_context {
struct list_head list; struct list_head list;
refcount_t refcount; refcount_t refcount;
struct work_struct gc;
}; };
enum tls_offload_ctx_dir { enum tls_offload_ctx_dir {
...@@ -355,13 +358,17 @@ int tls_sk_attach(struct sock *sk, int optname, char __user *optval, ...@@ -355,13 +358,17 @@ int tls_sk_attach(struct sock *sk, int optname, char __user *optval,
unsigned int optlen); unsigned int optlen);
int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx); int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx);
void tls_sw_strparser_arm(struct sock *sk, struct tls_context *ctx);
void tls_sw_strparser_done(struct tls_context *tls_ctx);
int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
int tls_sw_sendpage(struct sock *sk, struct page *page, int tls_sw_sendpage(struct sock *sk, struct page *page,
int offset, size_t size, int flags); int offset, size_t size, int flags);
void tls_sw_close(struct sock *sk, long timeout); void tls_sw_cancel_work_tx(struct tls_context *tls_ctx);
void tls_sw_free_resources_tx(struct sock *sk); void tls_sw_release_resources_tx(struct sock *sk);
void tls_sw_free_ctx_tx(struct tls_context *tls_ctx);
void tls_sw_free_resources_rx(struct sock *sk); void tls_sw_free_resources_rx(struct sock *sk);
void tls_sw_release_resources_rx(struct sock *sk); void tls_sw_release_resources_rx(struct sock *sk);
void tls_sw_free_ctx_rx(struct tls_context *tls_ctx);
int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int nonblock, int flags, int *addr_len); int nonblock, int flags, int *addr_len);
bool tls_sw_stream_read(const struct sock *sk); bool tls_sw_stream_read(const struct sock *sk);
......
...@@ -8616,8 +8616,8 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env) ...@@ -8616,8 +8616,8 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
} }
if (is_narrower_load && size < target_size) { if (is_narrower_load && size < target_size) {
u8 shift = (off & (size_default - 1)) * 8; u8 shift = bpf_ctx_narrow_load_shift(off, size,
size_default);
if (ctx_field_size <= 4) { if (ctx_field_size <= 4) {
if (shift) if (shift)
insn_buf[cnt++] = BPF_ALU32_IMM(BPF_RSH, insn_buf[cnt++] = BPF_ALU32_IMM(BPF_RSH,
......
...@@ -7455,12 +7455,12 @@ static u32 bpf_convert_ctx_access(enum bpf_access_type type, ...@@ -7455,12 +7455,12 @@ static u32 bpf_convert_ctx_access(enum bpf_access_type type,
case offsetof(struct __sk_buff, gso_segs): case offsetof(struct __sk_buff, gso_segs):
/* si->dst_reg = skb_shinfo(SKB); */ /* si->dst_reg = skb_shinfo(SKB); */
#ifdef NET_SKBUFF_DATA_USES_OFFSET #ifdef NET_SKBUFF_DATA_USES_OFFSET
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, head),
si->dst_reg, si->src_reg,
offsetof(struct sk_buff, head));
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, end), *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, end),
BPF_REG_AX, si->src_reg, BPF_REG_AX, si->src_reg,
offsetof(struct sk_buff, end)); offsetof(struct sk_buff, end));
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, head),
si->dst_reg, si->src_reg,
offsetof(struct sk_buff, head));
*insn++ = BPF_ALU64_REG(BPF_ADD, si->dst_reg, BPF_REG_AX); *insn++ = BPF_ALU64_REG(BPF_ADD, si->dst_reg, BPF_REG_AX);
#else #else
*insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, end), *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, end),
......
...@@ -585,12 +585,12 @@ EXPORT_SYMBOL_GPL(sk_psock_destroy); ...@@ -585,12 +585,12 @@ EXPORT_SYMBOL_GPL(sk_psock_destroy);
void sk_psock_drop(struct sock *sk, struct sk_psock *psock) void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{ {
rcu_assign_sk_user_data(sk, NULL);
sk_psock_cork_free(psock); sk_psock_cork_free(psock);
sk_psock_zap_ingress(psock); sk_psock_zap_ingress(psock);
sk_psock_restore_proto(sk, psock);
write_lock_bh(&sk->sk_callback_lock); write_lock_bh(&sk->sk_callback_lock);
sk_psock_restore_proto(sk, psock);
rcu_assign_sk_user_data(sk, NULL);
if (psock->progs.skb_parser) if (psock->progs.skb_parser)
sk_psock_stop_strp(sk, psock); sk_psock_stop_strp(sk, psock);
write_unlock_bh(&sk->sk_callback_lock); write_unlock_bh(&sk->sk_callback_lock);
......
...@@ -247,6 +247,8 @@ static void sock_map_free(struct bpf_map *map) ...@@ -247,6 +247,8 @@ static void sock_map_free(struct bpf_map *map)
raw_spin_unlock_bh(&stab->lock); raw_spin_unlock_bh(&stab->lock);
rcu_read_unlock(); rcu_read_unlock();
synchronize_rcu();
bpf_map_area_free(stab->sks); bpf_map_area_free(stab->sks);
kfree(stab); kfree(stab);
} }
...@@ -276,16 +278,20 @@ static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test, ...@@ -276,16 +278,20 @@ static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
struct sock **psk) struct sock **psk)
{ {
struct sock *sk; struct sock *sk;
int err = 0;
raw_spin_lock_bh(&stab->lock); raw_spin_lock_bh(&stab->lock);
sk = *psk; sk = *psk;
if (!sk_test || sk_test == sk) if (!sk_test || sk_test == sk)
*psk = NULL; sk = xchg(psk, NULL);
raw_spin_unlock_bh(&stab->lock);
if (unlikely(!sk)) if (likely(sk))
return -EINVAL;
sock_map_unref(sk, psk); sock_map_unref(sk, psk);
return 0; else
err = -EINVAL;
raw_spin_unlock_bh(&stab->lock);
return err;
} }
static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk, static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
...@@ -328,6 +334,7 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx, ...@@ -328,6 +334,7 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
struct sock *sk, u64 flags) struct sock *sk, u64 flags)
{ {
struct bpf_stab *stab = container_of(map, struct bpf_stab, map); struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
struct inet_connection_sock *icsk = inet_csk(sk);
struct sk_psock_link *link; struct sk_psock_link *link;
struct sk_psock *psock; struct sk_psock *psock;
struct sock *osk; struct sock *osk;
...@@ -338,6 +345,8 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx, ...@@ -338,6 +345,8 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
return -EINVAL; return -EINVAL;
if (unlikely(idx >= map->max_entries)) if (unlikely(idx >= map->max_entries))
return -E2BIG; return -E2BIG;
if (unlikely(icsk->icsk_ulp_data))
return -EINVAL;
link = sk_psock_init_link(); link = sk_psock_init_link();
if (!link) if (!link)
......
...@@ -96,6 +96,19 @@ void tcp_get_available_ulp(char *buf, size_t maxlen) ...@@ -96,6 +96,19 @@ void tcp_get_available_ulp(char *buf, size_t maxlen)
rcu_read_unlock(); rcu_read_unlock();
} }
void tcp_update_ulp(struct sock *sk, struct proto *proto)
{
struct inet_connection_sock *icsk = inet_csk(sk);
if (!icsk->icsk_ulp_ops) {
sk->sk_prot = proto;
return;
}
if (icsk->icsk_ulp_ops->update)
icsk->icsk_ulp_ops->update(sk, proto);
}
void tcp_cleanup_ulp(struct sock *sk) void tcp_cleanup_ulp(struct sock *sk)
{ {
struct inet_connection_sock *icsk = inet_csk(sk); struct inet_connection_sock *icsk = inet_csk(sk);
......
...@@ -261,24 +261,36 @@ void tls_ctx_free(struct tls_context *ctx) ...@@ -261,24 +261,36 @@ void tls_ctx_free(struct tls_context *ctx)
kfree(ctx); kfree(ctx);
} }
static void tls_sk_proto_close(struct sock *sk, long timeout) static void tls_ctx_free_deferred(struct work_struct *gc)
{ {
struct tls_context *ctx = tls_get_ctx(sk); struct tls_context *ctx = container_of(gc, struct tls_context, gc);
long timeo = sock_sndtimeo(sk, 0);
void (*sk_proto_close)(struct sock *sk, long timeout);
bool free_ctx = false;
lock_sock(sk);
sk_proto_close = ctx->sk_proto_close;
if (ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD) /* Ensure any remaining work items are completed. The sk will
goto skip_tx_cleanup; * already have lost its tls_ctx reference by the time we get
* here so no xmit operation will actually be performed.
*/
if (ctx->tx_conf == TLS_SW) {
tls_sw_cancel_work_tx(ctx);
tls_sw_free_ctx_tx(ctx);
}
if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE) { if (ctx->rx_conf == TLS_SW) {
free_ctx = true; tls_sw_strparser_done(ctx);
goto skip_tx_cleanup; tls_sw_free_ctx_rx(ctx);
} }
tls_ctx_free(ctx);
}
static void tls_ctx_free_wq(struct tls_context *ctx)
{
INIT_WORK(&ctx->gc, tls_ctx_free_deferred);
schedule_work(&ctx->gc);
}
static void tls_sk_proto_cleanup(struct sock *sk,
struct tls_context *ctx, long timeo)
{
if (unlikely(sk->sk_write_pending) && if (unlikely(sk->sk_write_pending) &&
!wait_on_pending_writer(sk, &timeo)) !wait_on_pending_writer(sk, &timeo))
tls_handle_open_record(sk, 0); tls_handle_open_record(sk, 0);
...@@ -287,7 +299,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) ...@@ -287,7 +299,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
if (ctx->tx_conf == TLS_SW) { if (ctx->tx_conf == TLS_SW) {
kfree(ctx->tx.rec_seq); kfree(ctx->tx.rec_seq);
kfree(ctx->tx.iv); kfree(ctx->tx.iv);
tls_sw_free_resources_tx(sk); tls_sw_release_resources_tx(sk);
#ifdef CONFIG_TLS_DEVICE #ifdef CONFIG_TLS_DEVICE
} else if (ctx->tx_conf == TLS_HW) { } else if (ctx->tx_conf == TLS_HW) {
tls_device_free_resources_tx(sk); tls_device_free_resources_tx(sk);
...@@ -295,26 +307,67 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) ...@@ -295,26 +307,67 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
} }
if (ctx->rx_conf == TLS_SW) if (ctx->rx_conf == TLS_SW)
tls_sw_free_resources_rx(sk); tls_sw_release_resources_rx(sk);
#ifdef CONFIG_TLS_DEVICE #ifdef CONFIG_TLS_DEVICE
if (ctx->rx_conf == TLS_HW) if (ctx->rx_conf == TLS_HW)
tls_device_offload_cleanup_rx(sk); tls_device_offload_cleanup_rx(sk);
if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW) {
#else
{
#endif #endif
tls_ctx_free(ctx); }
ctx = NULL;
static void tls_sk_proto_unhash(struct sock *sk)
{
struct inet_connection_sock *icsk = inet_csk(sk);
long timeo = sock_sndtimeo(sk, 0);
struct tls_context *ctx;
if (unlikely(!icsk->icsk_ulp_data)) {
if (sk->sk_prot->unhash)
sk->sk_prot->unhash(sk);
} }
skip_tx_cleanup: ctx = tls_get_ctx(sk);
tls_sk_proto_cleanup(sk, ctx, timeo);
write_lock_bh(&sk->sk_callback_lock);
icsk->icsk_ulp_data = NULL;
sk->sk_prot = ctx->sk_proto;
write_unlock_bh(&sk->sk_callback_lock);
if (ctx->sk_proto->unhash)
ctx->sk_proto->unhash(sk);
tls_ctx_free_wq(ctx);
}
static void tls_sk_proto_close(struct sock *sk, long timeout)
{
struct inet_connection_sock *icsk = inet_csk(sk);
struct tls_context *ctx = tls_get_ctx(sk);
long timeo = sock_sndtimeo(sk, 0);
bool free_ctx;
if (ctx->tx_conf == TLS_SW)
tls_sw_cancel_work_tx(ctx);
lock_sock(sk);
free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;
if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
tls_sk_proto_cleanup(sk, ctx, timeo);
write_lock_bh(&sk->sk_callback_lock);
if (free_ctx)
icsk->icsk_ulp_data = NULL;
sk->sk_prot = ctx->sk_proto;
write_unlock_bh(&sk->sk_callback_lock);
release_sock(sk); release_sock(sk);
sk_proto_close(sk, timeout); if (ctx->tx_conf == TLS_SW)
/* free ctx for TLS_HW_RECORD, used by tcp_set_state tls_sw_free_ctx_tx(ctx);
* for sk->sk_prot->unhash [tls_hw_unhash] if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
*/ tls_sw_strparser_done(ctx);
if (ctx->rx_conf == TLS_SW)
tls_sw_free_ctx_rx(ctx);
ctx->sk_proto_close(sk, timeout);
if (free_ctx) if (free_ctx)
tls_ctx_free(ctx); tls_ctx_free(ctx);
} }
...@@ -526,6 +579,8 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, ...@@ -526,6 +579,8 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
{ {
#endif #endif
rc = tls_set_sw_offload(sk, ctx, 1); rc = tls_set_sw_offload(sk, ctx, 1);
if (rc)
goto err_crypto_info;
conf = TLS_SW; conf = TLS_SW;
} }
} else { } else {
...@@ -537,13 +592,13 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval, ...@@ -537,13 +592,13 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
{ {
#endif #endif
rc = tls_set_sw_offload(sk, ctx, 0); rc = tls_set_sw_offload(sk, ctx, 0);
if (rc)
goto err_crypto_info;
conf = TLS_SW; conf = TLS_SW;
} }
tls_sw_strparser_arm(sk, ctx);
} }
if (rc)
goto err_crypto_info;
if (tx) if (tx)
ctx->tx_conf = conf; ctx->tx_conf = conf;
else else
...@@ -607,6 +662,7 @@ static struct tls_context *create_ctx(struct sock *sk) ...@@ -607,6 +662,7 @@ static struct tls_context *create_ctx(struct sock *sk)
ctx->setsockopt = sk->sk_prot->setsockopt; ctx->setsockopt = sk->sk_prot->setsockopt;
ctx->getsockopt = sk->sk_prot->getsockopt; ctx->getsockopt = sk->sk_prot->getsockopt;
ctx->sk_proto_close = sk->sk_prot->close; ctx->sk_proto_close = sk->sk_prot->close;
ctx->unhash = sk->sk_prot->unhash;
return ctx; return ctx;
} }
...@@ -730,6 +786,7 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], ...@@ -730,6 +786,7 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt; prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt;
prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt; prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt;
prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close; prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close;
prot[TLS_BASE][TLS_BASE].unhash = tls_sk_proto_unhash;
prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg; prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg;
...@@ -747,16 +804,20 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], ...@@ -747,16 +804,20 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
#ifdef CONFIG_TLS_DEVICE #ifdef CONFIG_TLS_DEVICE
prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
prot[TLS_HW][TLS_BASE].unhash = base->unhash;
prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg; prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg;
prot[TLS_HW][TLS_BASE].sendpage = tls_device_sendpage; prot[TLS_HW][TLS_BASE].sendpage = tls_device_sendpage;
prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW]; prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
prot[TLS_HW][TLS_SW].unhash = base->unhash;
prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg; prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg;
prot[TLS_HW][TLS_SW].sendpage = tls_device_sendpage; prot[TLS_HW][TLS_SW].sendpage = tls_device_sendpage;
prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW]; prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];
prot[TLS_BASE][TLS_HW].unhash = base->unhash;
prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW]; prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];
prot[TLS_SW][TLS_HW].unhash = base->unhash;
prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW]; prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
#endif #endif
...@@ -764,7 +825,6 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], ...@@ -764,7 +825,6 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base; prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_hw_hash; prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_hw_hash;
prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_hw_unhash; prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_hw_unhash;
prot[TLS_HW_RECORD][TLS_HW_RECORD].close = tls_sk_proto_close;
} }
static int tls_init(struct sock *sk) static int tls_init(struct sock *sk)
...@@ -773,7 +833,7 @@ static int tls_init(struct sock *sk) ...@@ -773,7 +833,7 @@ static int tls_init(struct sock *sk)
int rc = 0; int rc = 0;
if (tls_hw_prot(sk)) if (tls_hw_prot(sk))
goto out; return 0;
/* The TLS ulp is currently supported only for TCP sockets /* The TLS ulp is currently supported only for TCP sockets
* in ESTABLISHED state. * in ESTABLISHED state.
...@@ -784,21 +844,38 @@ static int tls_init(struct sock *sk) ...@@ -784,21 +844,38 @@ static int tls_init(struct sock *sk)
if (sk->sk_state != TCP_ESTABLISHED) if (sk->sk_state != TCP_ESTABLISHED)
return -ENOTSUPP; return -ENOTSUPP;
tls_build_proto(sk);
/* allocate tls context */ /* allocate tls context */
write_lock_bh(&sk->sk_callback_lock);
ctx = create_ctx(sk); ctx = create_ctx(sk);
if (!ctx) { if (!ctx) {
rc = -ENOMEM; rc = -ENOMEM;
goto out; goto out;
} }
tls_build_proto(sk);
ctx->tx_conf = TLS_BASE; ctx->tx_conf = TLS_BASE;
ctx->rx_conf = TLS_BASE; ctx->rx_conf = TLS_BASE;
ctx->sk_proto = sk->sk_prot;
update_sk_prot(sk, ctx); update_sk_prot(sk, ctx);
out: out:
write_unlock_bh(&sk->sk_callback_lock);
return rc; return rc;
} }
static void tls_update(struct sock *sk, struct proto *p)
{
struct tls_context *ctx;
ctx = tls_get_ctx(sk);
if (likely(ctx)) {
ctx->sk_proto_close = p->close;
ctx->sk_proto = p;
} else {
sk->sk_prot = p;
}
}
void tls_register_device(struct tls_device *device) void tls_register_device(struct tls_device *device)
{ {
spin_lock_bh(&device_spinlock); spin_lock_bh(&device_spinlock);
...@@ -819,6 +896,7 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { ...@@ -819,6 +896,7 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
.name = "tls", .name = "tls",
.owner = THIS_MODULE, .owner = THIS_MODULE,
.init = tls_init, .init = tls_init,
.update = tls_update,
}; };
static int __init tls_register(void) static int __init tls_register(void)
......
...@@ -2054,7 +2054,16 @@ static void tls_data_ready(struct sock *sk) ...@@ -2054,7 +2054,16 @@ static void tls_data_ready(struct sock *sk)
} }
} }
void tls_sw_free_resources_tx(struct sock *sk) void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
{
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
cancel_delayed_work_sync(&ctx->tx_work.work);
}
void tls_sw_release_resources_tx(struct sock *sk)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
...@@ -2065,11 +2074,6 @@ void tls_sw_free_resources_tx(struct sock *sk) ...@@ -2065,11 +2074,6 @@ void tls_sw_free_resources_tx(struct sock *sk)
if (atomic_read(&ctx->encrypt_pending)) if (atomic_read(&ctx->encrypt_pending))
crypto_wait_req(-EINPROGRESS, &ctx->async_wait); crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
release_sock(sk);
cancel_delayed_work_sync(&ctx->tx_work.work);
lock_sock(sk);
/* Tx whatever records we can transmit and abandon the rest */
tls_tx_records(sk, -1); tls_tx_records(sk, -1);
/* Free up un-sent records in tx_list. First, free /* Free up un-sent records in tx_list. First, free
...@@ -2092,6 +2096,11 @@ void tls_sw_free_resources_tx(struct sock *sk) ...@@ -2092,6 +2096,11 @@ void tls_sw_free_resources_tx(struct sock *sk)
crypto_free_aead(ctx->aead_send); crypto_free_aead(ctx->aead_send);
tls_free_open_rec(sk); tls_free_open_rec(sk);
}
void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
{
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
kfree(ctx); kfree(ctx);
} }
...@@ -2110,25 +2119,40 @@ void tls_sw_release_resources_rx(struct sock *sk) ...@@ -2110,25 +2119,40 @@ void tls_sw_release_resources_rx(struct sock *sk)
skb_queue_purge(&ctx->rx_list); skb_queue_purge(&ctx->rx_list);
crypto_free_aead(ctx->aead_recv); crypto_free_aead(ctx->aead_recv);
strp_stop(&ctx->strp); strp_stop(&ctx->strp);
/* If tls_sw_strparser_arm() was not called (cleanup paths)
* we still want to strp_stop(), but sk->sk_data_ready was
* never swapped.
*/
if (ctx->saved_data_ready) {
write_lock_bh(&sk->sk_callback_lock); write_lock_bh(&sk->sk_callback_lock);
sk->sk_data_ready = ctx->saved_data_ready; sk->sk_data_ready = ctx->saved_data_ready;
write_unlock_bh(&sk->sk_callback_lock); write_unlock_bh(&sk->sk_callback_lock);
release_sock(sk); }
strp_done(&ctx->strp);
lock_sock(sk);
} }
} }
void tls_sw_free_resources_rx(struct sock *sk) void tls_sw_strparser_done(struct tls_context *tls_ctx)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
tls_sw_release_resources_rx(sk); strp_done(&ctx->strp);
}
void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
{
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
kfree(ctx); kfree(ctx);
} }
void tls_sw_free_resources_rx(struct sock *sk)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
tls_sw_release_resources_rx(sk);
tls_sw_free_ctx_rx(tls_ctx);
}
/* The work handler to transmitt the encrypted records in tx_list */ /* The work handler to transmitt the encrypted records in tx_list */
static void tx_work_handler(struct work_struct *work) static void tx_work_handler(struct work_struct *work)
{ {
...@@ -2137,11 +2161,17 @@ static void tx_work_handler(struct work_struct *work) ...@@ -2137,11 +2161,17 @@ static void tx_work_handler(struct work_struct *work)
struct tx_work, work); struct tx_work, work);
struct sock *sk = tx_work->sk; struct sock *sk = tx_work->sk;
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); struct tls_sw_context_tx *ctx;
if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) if (unlikely(!tls_ctx))
return;
ctx = tls_sw_ctx_tx(tls_ctx);
if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
return; return;
if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
return;
lock_sock(sk); lock_sock(sk);
tls_tx_records(sk, -1); tls_tx_records(sk, -1);
release_sock(sk); release_sock(sk);
...@@ -2160,6 +2190,18 @@ void tls_sw_write_space(struct sock *sk, struct tls_context *ctx) ...@@ -2160,6 +2190,18 @@ void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
} }
} }
void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
{
struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
write_lock_bh(&sk->sk_callback_lock);
rx_ctx->saved_data_ready = sk->sk_data_ready;
sk->sk_data_ready = tls_data_ready;
write_unlock_bh(&sk->sk_callback_lock);
strp_check_rcv(&rx_ctx->strp);
}
int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
{ {
struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_context *tls_ctx = tls_get_ctx(sk);
...@@ -2357,13 +2399,6 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) ...@@ -2357,13 +2399,6 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
cb.parse_msg = tls_read_size; cb.parse_msg = tls_read_size;
strp_init(&sw_ctx_rx->strp, sk, &cb); strp_init(&sw_ctx_rx->strp, sk, &cb);
write_lock_bh(&sk->sk_callback_lock);
sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
sk->sk_data_ready = tls_data_ready;
write_unlock_bh(&sk->sk_callback_lock);
strp_check_rcv(&sw_ctx_rx->strp);
} }
goto out; goto out;
......
// SPDX-License-Identifier: (LGPL-2.1 OR BSD-2-Clause) // SPDX-License-Identifier: (LGPL-2.1 OR BSD-2-Clause)
/* Copyright (c) 2018 Facebook */ /* Copyright (c) 2018 Facebook */
#include <endian.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
...@@ -419,9 +420,9 @@ struct btf *btf__new(__u8 *data, __u32 size) ...@@ -419,9 +420,9 @@ struct btf *btf__new(__u8 *data, __u32 size)
static bool btf_check_endianness(const GElf_Ehdr *ehdr) static bool btf_check_endianness(const GElf_Ehdr *ehdr)
{ {
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ #if __BYTE_ORDER == __LITTLE_ENDIAN
return ehdr->e_ident[EI_DATA] == ELFDATA2LSB; return ehdr->e_ident[EI_DATA] == ELFDATA2LSB;
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ #elif __BYTE_ORDER == __BIG_ENDIAN
return ehdr->e_ident[EI_DATA] == ELFDATA2MSB; return ehdr->e_ident[EI_DATA] == ELFDATA2MSB;
#else #else
# error "Unrecognized __BYTE_ORDER__" # error "Unrecognized __BYTE_ORDER__"
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <inttypes.h> #include <inttypes.h>
#include <string.h> #include <string.h>
#include <unistd.h> #include <unistd.h>
#include <endian.h>
#include <fcntl.h> #include <fcntl.h>
#include <errno.h> #include <errno.h>
#include <asm/unistd.h> #include <asm/unistd.h>
...@@ -612,10 +613,10 @@ static int bpf_object__elf_init(struct bpf_object *obj) ...@@ -612,10 +613,10 @@ static int bpf_object__elf_init(struct bpf_object *obj)
static int bpf_object__check_endianness(struct bpf_object *obj) static int bpf_object__check_endianness(struct bpf_object *obj)
{ {
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ #if __BYTE_ORDER == __LITTLE_ENDIAN
if (obj->efile.ehdr.e_ident[EI_DATA] == ELFDATA2LSB) if (obj->efile.ehdr.e_ident[EI_DATA] == ELFDATA2LSB)
return 0; return 0;
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ #elif __BYTE_ORDER == __BIG_ENDIAN
if (obj->efile.ehdr.e_ident[EI_DATA] == ELFDATA2MSB) if (obj->efile.ehdr.e_ident[EI_DATA] == ELFDATA2MSB)
return 0; return 0;
#else #else
...@@ -1377,8 +1378,13 @@ static void bpf_object__sanitize_btf(struct bpf_object *obj) ...@@ -1377,8 +1378,13 @@ static void bpf_object__sanitize_btf(struct bpf_object *obj)
if (!has_datasec && kind == BTF_KIND_VAR) { if (!has_datasec && kind == BTF_KIND_VAR) {
/* replace VAR with INT */ /* replace VAR with INT */
t->info = BTF_INFO_ENC(BTF_KIND_INT, 0, 0); t->info = BTF_INFO_ENC(BTF_KIND_INT, 0, 0);
t->size = sizeof(int); /*
*(int *)(t+1) = BTF_INT_ENC(0, 0, 32); * using size = 1 is the safest choice, 4 will be too
* big and cause kernel BTF validation failure if
* original variable took less than 4 bytes
*/
t->size = 1;
*(int *)(t+1) = BTF_INT_ENC(0, 0, 8);
} else if (!has_datasec && kind == BTF_KIND_DATASEC) { } else if (!has_datasec && kind == BTF_KIND_DATASEC) {
/* replace DATASEC with STRUCT */ /* replace DATASEC with STRUCT */
struct btf_var_secinfo *v = (void *)(t + 1); struct btf_var_secinfo *v = (void *)(t + 1);
...@@ -1500,6 +1506,12 @@ static int bpf_object__sanitize_and_load_btf(struct bpf_object *obj) ...@@ -1500,6 +1506,12 @@ static int bpf_object__sanitize_and_load_btf(struct bpf_object *obj)
BTF_ELF_SEC, err); BTF_ELF_SEC, err);
btf__free(obj->btf); btf__free(obj->btf);
obj->btf = NULL; obj->btf = NULL;
/* btf_ext can't exist without btf, so free it as well */
if (obj->btf_ext) {
btf_ext__free(obj->btf_ext);
obj->btf_ext = NULL;
}
if (bpf_object__is_btf_mandatory(obj)) if (bpf_object__is_btf_mandatory(obj))
return err; return err;
} }
...@@ -4507,13 +4519,13 @@ struct perf_buffer *perf_buffer__new(int map_fd, size_t page_cnt, ...@@ -4507,13 +4519,13 @@ struct perf_buffer *perf_buffer__new(int map_fd, size_t page_cnt,
const struct perf_buffer_opts *opts) const struct perf_buffer_opts *opts)
{ {
struct perf_buffer_params p = {}; struct perf_buffer_params p = {};
struct perf_event_attr attr = { struct perf_event_attr attr = { 0, };
.config = PERF_COUNT_SW_BPF_OUTPUT,
.type = PERF_TYPE_SOFTWARE, attr.config = PERF_COUNT_SW_BPF_OUTPUT,
.sample_type = PERF_SAMPLE_RAW, attr.type = PERF_TYPE_SOFTWARE;
.sample_period = 1, attr.sample_type = PERF_SAMPLE_RAW;
.wakeup_events = 1, attr.sample_period = 1;
}; attr.wakeup_events = 1;
p.attr = &attr; p.attr = &attr;
p.sample_cb = opts ? opts->sample_cb : NULL; p.sample_cb = opts ? opts->sample_cb : NULL;
......
...@@ -317,17 +317,16 @@ static int xsk_load_xdp_prog(struct xsk_socket *xsk) ...@@ -317,17 +317,16 @@ static int xsk_load_xdp_prog(struct xsk_socket *xsk)
static int xsk_get_max_queues(struct xsk_socket *xsk) static int xsk_get_max_queues(struct xsk_socket *xsk)
{ {
struct ethtool_channels channels; struct ethtool_channels channels = { .cmd = ETHTOOL_GCHANNELS };
struct ifreq ifr; struct ifreq ifr = {};
int fd, err, ret; int fd, err, ret;
fd = socket(AF_INET, SOCK_DGRAM, 0); fd = socket(AF_INET, SOCK_DGRAM, 0);
if (fd < 0) if (fd < 0)
return -errno; return -errno;
channels.cmd = ETHTOOL_GCHANNELS;
ifr.ifr_data = (void *)&channels; ifr.ifr_data = (void *)&channels;
strncpy(ifr.ifr_name, xsk->ifname, IFNAMSIZ - 1); memcpy(ifr.ifr_name, xsk->ifname, IFNAMSIZ - 1);
ifr.ifr_name[IFNAMSIZ - 1] = '\0'; ifr.ifr_name[IFNAMSIZ - 1] = '\0';
err = ioctl(fd, SIOCETHTOOL, &ifr); err = ioctl(fd, SIOCETHTOOL, &ifr);
if (err && errno != EOPNOTSUPP) { if (err && errno != EOPNOTSUPP) {
...@@ -335,7 +334,7 @@ static int xsk_get_max_queues(struct xsk_socket *xsk) ...@@ -335,7 +334,7 @@ static int xsk_get_max_queues(struct xsk_socket *xsk)
goto out; goto out;
} }
if (channels.max_combined == 0 || errno == EOPNOTSUPP) if (err || channels.max_combined == 0)
/* If the device says it has no channels, then all traffic /* If the device says it has no channels, then all traffic
* is sent to a single stream, so max queues = 1. * is sent to a single stream, so max queues = 1.
*/ */
...@@ -517,7 +516,7 @@ int xsk_socket__create(struct xsk_socket **xsk_ptr, const char *ifname, ...@@ -517,7 +516,7 @@ int xsk_socket__create(struct xsk_socket **xsk_ptr, const char *ifname,
err = -errno; err = -errno;
goto out_socket; goto out_socket;
} }
strncpy(xsk->ifname, ifname, IFNAMSIZ - 1); memcpy(xsk->ifname, ifname, IFNAMSIZ - 1);
xsk->ifname[IFNAMSIZ - 1] = '\0'; xsk->ifname[IFNAMSIZ - 1] = '\0';
err = xsk_set_xdp_socket_config(&xsk->config, usr_config); err = xsk_set_xdp_socket_config(&xsk->config, usr_config);
......
...@@ -41,8 +41,7 @@ int sendmsg_v6_prog(struct bpf_sock_addr *ctx) ...@@ -41,8 +41,7 @@ int sendmsg_v6_prog(struct bpf_sock_addr *ctx)
} }
/* Rewrite destination. */ /* Rewrite destination. */
if ((ctx->user_ip6[0] & 0xFFFF) == bpf_htons(0xFACE) && if (ctx->user_ip6[0] == bpf_htonl(0xFACEB00C)) {
ctx->user_ip6[0] >> 16 == bpf_htons(0xB00C)) {
ctx->user_ip6[0] = bpf_htonl(DST_REWRITE_IP6_0); ctx->user_ip6[0] = bpf_htonl(DST_REWRITE_IP6_0);
ctx->user_ip6[1] = bpf_htonl(DST_REWRITE_IP6_1); ctx->user_ip6[1] = bpf_htonl(DST_REWRITE_IP6_1);
ctx->user_ip6[2] = bpf_htonl(DST_REWRITE_IP6_2); ctx->user_ip6[2] = bpf_htonl(DST_REWRITE_IP6_2);
......
...@@ -974,6 +974,17 @@ ...@@ -974,6 +974,17 @@
.result = ACCEPT, .result = ACCEPT,
.prog_type = BPF_PROG_TYPE_CGROUP_SKB, .prog_type = BPF_PROG_TYPE_CGROUP_SKB,
}, },
{
"read gso_segs from CGROUP_SKB",
.insns = {
BPF_LDX_MEM(BPF_W, BPF_REG_1, BPF_REG_1,
offsetof(struct __sk_buff, gso_segs)),
BPF_MOV64_IMM(BPF_REG_0, 0),
BPF_EXIT_INSN(),
},
.result = ACCEPT,
.prog_type = BPF_PROG_TYPE_CGROUP_SKB,
},
{ {
"write gso_segs from CGROUP_SKB", "write gso_segs from CGROUP_SKB",
.insns = { .insns = {
......
...@@ -25,6 +25,80 @@ ...@@ -25,6 +25,80 @@
#define TLS_PAYLOAD_MAX_LEN 16384 #define TLS_PAYLOAD_MAX_LEN 16384
#define SOL_TLS 282 #define SOL_TLS 282
#ifndef ENOTSUPP
#define ENOTSUPP 524
#endif
FIXTURE(tls_basic)
{
int fd, cfd;
bool notls;
};
FIXTURE_SETUP(tls_basic)
{
struct sockaddr_in addr;
socklen_t len;
int sfd, ret;
self->notls = false;
len = sizeof(addr);
addr.sin_family = AF_INET;
addr.sin_addr.s_addr = htonl(INADDR_ANY);
addr.sin_port = 0;
self->fd = socket(AF_INET, SOCK_STREAM, 0);
sfd = socket(AF_INET, SOCK_STREAM, 0);
ret = bind(sfd, &addr, sizeof(addr));
ASSERT_EQ(ret, 0);
ret = listen(sfd, 10);
ASSERT_EQ(ret, 0);
ret = getsockname(sfd, &addr, &len);
ASSERT_EQ(ret, 0);
ret = connect(self->fd, &addr, sizeof(addr));
ASSERT_EQ(ret, 0);
self->cfd = accept(sfd, &addr, &len);
ASSERT_GE(self->cfd, 0);
close(sfd);
ret = setsockopt(self->fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
if (ret != 0) {
ASSERT_EQ(errno, ENOTSUPP);
self->notls = true;
printf("Failure setting TCP_ULP, testing without tls\n");
return;
}
ret = setsockopt(self->cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
ASSERT_EQ(ret, 0);
}
FIXTURE_TEARDOWN(tls_basic)
{
close(self->fd);
close(self->cfd);
}
/* Send some data through with ULP but no keys */
TEST_F(tls_basic, base_base)
{
char const *test_str = "test_read";
int send_len = 10;
char buf[10];
ASSERT_EQ(strlen(test_str) + 1, send_len);
EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
};
FIXTURE(tls) FIXTURE(tls)
{ {
int fd, cfd; int fd, cfd;
...@@ -165,6 +239,16 @@ TEST_F(tls, msg_more) ...@@ -165,6 +239,16 @@ TEST_F(tls, msg_more)
EXPECT_EQ(memcmp(buf, test_str, send_len), 0); EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
} }
TEST_F(tls, msg_more_unsent)
{
char const *test_str = "test_read";
int send_len = 10;
char buf[10];
EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
}
TEST_F(tls, sendmsg_single) TEST_F(tls, sendmsg_single)
{ {
struct msghdr msg; struct msghdr msg;
...@@ -610,6 +694,37 @@ TEST_F(tls, recv_lowat) ...@@ -610,6 +694,37 @@ TEST_F(tls, recv_lowat)
EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0); EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
} }
TEST_F(tls, bidir)
{
struct tls12_crypto_info_aes_gcm_128 tls12;
char const *test_str = "test_read";
int send_len = 10;
char buf[10];
int ret;
memset(&tls12, 0, sizeof(tls12));
tls12.info.version = TLS_1_3_VERSION;
tls12.info.cipher_type = TLS_CIPHER_AES_GCM_128;
ret = setsockopt(self->fd, SOL_TLS, TLS_RX, &tls12, sizeof(tls12));
ASSERT_EQ(ret, 0);
ret = setsockopt(self->cfd, SOL_TLS, TLS_TX, &tls12, sizeof(tls12));
ASSERT_EQ(ret, 0);
ASSERT_EQ(strlen(test_str) + 1, send_len);
EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
memset(buf, 0, sizeof(buf));
EXPECT_EQ(send(self->cfd, test_str, send_len, 0), send_len);
EXPECT_NE(recv(self->fd, buf, send_len, 0), -1);
EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
};
TEST_F(tls, pollin) TEST_F(tls, pollin)
{ {
char const *test_str = "test_poll"; char const *test_str = "test_poll";
...@@ -837,6 +952,85 @@ TEST_F(tls, control_msg) ...@@ -837,6 +952,85 @@ TEST_F(tls, control_msg)
EXPECT_EQ(memcmp(buf, test_str, send_len), 0); EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
} }
TEST_F(tls, shutdown)
{
char const *test_str = "test_read";
int send_len = 10;
char buf[10];
ASSERT_EQ(strlen(test_str) + 1, send_len);
EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
shutdown(self->fd, SHUT_RDWR);
shutdown(self->cfd, SHUT_RDWR);
}
TEST_F(tls, shutdown_unsent)
{
char const *test_str = "test_read";
int send_len = 10;
EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
shutdown(self->fd, SHUT_RDWR);
shutdown(self->cfd, SHUT_RDWR);
}
TEST(non_established) {
struct tls12_crypto_info_aes_gcm_256 tls12;
struct sockaddr_in addr;
int sfd, ret, fd;
socklen_t len;
len = sizeof(addr);
memset(&tls12, 0, sizeof(tls12));
tls12.info.version = TLS_1_2_VERSION;
tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
addr.sin_family = AF_INET;
addr.sin_addr.s_addr = htonl(INADDR_ANY);
addr.sin_port = 0;
fd = socket(AF_INET, SOCK_STREAM, 0);
sfd = socket(AF_INET, SOCK_STREAM, 0);
ret = bind(sfd, &addr, sizeof(addr));
ASSERT_EQ(ret, 0);
ret = listen(sfd, 10);
ASSERT_EQ(ret, 0);
ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
EXPECT_EQ(ret, -1);
/* TLS ULP not supported */
if (errno == ENOENT)
return;
EXPECT_EQ(errno, ENOTSUPP);
ret = setsockopt(sfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
EXPECT_EQ(ret, -1);
EXPECT_EQ(errno, ENOTSUPP);
ret = getsockname(sfd, &addr, &len);
ASSERT_EQ(ret, 0);
ret = connect(fd, &addr, sizeof(addr));
ASSERT_EQ(ret, 0);
ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
ASSERT_EQ(ret, 0);
ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
EXPECT_EQ(ret, -1);
EXPECT_EQ(errno, EEXIST);
close(fd);
close(sfd);
}
TEST(keysizes) { TEST(keysizes) {
struct tls12_crypto_info_aes_gcm_256 tls12; struct tls12_crypto_info_aes_gcm_256 tls12;
struct sockaddr_in addr; struct sockaddr_in addr;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment