Commit be7bbea1 authored by Jakub Kicinski's avatar Jakub Kicinski Committed by David S. Miller

net/tls: use the full sk_proto pointer

Since we already have the pointer to the full original sk_proto
stored use that instead of storing all individual callback
pointers as well.
Signed-off-by: default avatarJakub Kicinski <jakub.kicinski@netronome.com>
Reviewed-by: default avatarJohn Hurley <john.hurley@netronome.com>
Reviewed-by: default avatarDirk van der Merwe <dirk.vandermerwe@netronome.com>
Acked-by: default avatarJohn Fastabend <john.fastabend@gmail.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 842841ec
...@@ -474,7 +474,8 @@ static int chtls_getsockopt(struct sock *sk, int level, int optname, ...@@ -474,7 +474,8 @@ static int chtls_getsockopt(struct sock *sk, int level, int optname,
struct tls_context *ctx = tls_get_ctx(sk); struct tls_context *ctx = tls_get_ctx(sk);
if (level != SOL_TLS) if (level != SOL_TLS)
return ctx->getsockopt(sk, level, optname, optval, optlen); return ctx->sk_proto->getsockopt(sk, level,
optname, optval, optlen);
return do_chtls_getsockopt(sk, optval, optlen); return do_chtls_getsockopt(sk, optval, optlen);
} }
...@@ -541,7 +542,8 @@ static int chtls_setsockopt(struct sock *sk, int level, int optname, ...@@ -541,7 +542,8 @@ static int chtls_setsockopt(struct sock *sk, int level, int optname,
struct tls_context *ctx = tls_get_ctx(sk); struct tls_context *ctx = tls_get_ctx(sk);
if (level != SOL_TLS) if (level != SOL_TLS)
return ctx->setsockopt(sk, level, optname, optval, optlen); return ctx->sk_proto->setsockopt(sk, level,
optname, optval, optlen);
return do_chtls_setsockopt(sk, optname, optval, optlen); return do_chtls_setsockopt(sk, optname, optval, optlen);
} }
......
...@@ -275,16 +275,6 @@ struct tls_context { ...@@ -275,16 +275,6 @@ struct tls_context {
struct proto *sk_proto; struct proto *sk_proto;
void (*sk_destruct)(struct sock *sk); void (*sk_destruct)(struct sock *sk);
void (*sk_proto_close)(struct sock *sk, long timeout);
int (*setsockopt)(struct sock *sk, int level,
int optname, char __user *optval,
unsigned int optlen);
int (*getsockopt)(struct sock *sk, int level,
int optname, char __user *optval,
int __user *optlen);
int (*hash)(struct sock *sk);
void (*unhash)(struct sock *sk);
union tls_crypto_context crypto_send; union tls_crypto_context crypto_send;
union tls_crypto_context crypto_recv; union tls_crypto_context crypto_recv;
......
...@@ -331,7 +331,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) ...@@ -331,7 +331,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
tls_sw_strparser_done(ctx); tls_sw_strparser_done(ctx);
if (ctx->rx_conf == TLS_SW) if (ctx->rx_conf == TLS_SW)
tls_sw_free_ctx_rx(ctx); tls_sw_free_ctx_rx(ctx);
ctx->sk_proto_close(sk, timeout); ctx->sk_proto->close(sk, timeout);
if (free_ctx) if (free_ctx)
tls_ctx_free(sk, ctx); tls_ctx_free(sk, ctx);
...@@ -451,7 +451,8 @@ static int tls_getsockopt(struct sock *sk, int level, int optname, ...@@ -451,7 +451,8 @@ static int tls_getsockopt(struct sock *sk, int level, int optname,
struct tls_context *ctx = tls_get_ctx(sk); struct tls_context *ctx = tls_get_ctx(sk);
if (level != SOL_TLS) if (level != SOL_TLS)
return ctx->getsockopt(sk, level, optname, optval, optlen); return ctx->sk_proto->getsockopt(sk, level,
optname, optval, optlen);
return do_tls_getsockopt(sk, optname, optval, optlen); return do_tls_getsockopt(sk, optname, optval, optlen);
} }
...@@ -609,7 +610,8 @@ static int tls_setsockopt(struct sock *sk, int level, int optname, ...@@ -609,7 +610,8 @@ static int tls_setsockopt(struct sock *sk, int level, int optname,
struct tls_context *ctx = tls_get_ctx(sk); struct tls_context *ctx = tls_get_ctx(sk);
if (level != SOL_TLS) if (level != SOL_TLS)
return ctx->setsockopt(sk, level, optname, optval, optlen); return ctx->sk_proto->setsockopt(sk, level, optname, optval,
optlen);
return do_tls_setsockopt(sk, optname, optval, optlen); return do_tls_setsockopt(sk, optname, optval, optlen);
} }
...@@ -624,10 +626,7 @@ static struct tls_context *create_ctx(struct sock *sk) ...@@ -624,10 +626,7 @@ static struct tls_context *create_ctx(struct sock *sk)
return NULL; return NULL;
rcu_assign_pointer(icsk->icsk_ulp_data, ctx); rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
ctx->setsockopt = sk->sk_prot->setsockopt; ctx->sk_proto = sk->sk_prot;
ctx->getsockopt = sk->sk_prot->getsockopt;
ctx->sk_proto_close = sk->sk_prot->close;
ctx->unhash = sk->sk_prot->unhash;
return ctx; return ctx;
} }
...@@ -683,9 +682,6 @@ static int tls_hw_prot(struct sock *sk) ...@@ -683,9 +682,6 @@ static int tls_hw_prot(struct sock *sk)
spin_unlock_bh(&device_spinlock); spin_unlock_bh(&device_spinlock);
tls_build_proto(sk); tls_build_proto(sk);
ctx->hash = sk->sk_prot->hash;
ctx->unhash = sk->sk_prot->unhash;
ctx->sk_proto_close = sk->sk_prot->close;
ctx->sk_destruct = sk->sk_destruct; ctx->sk_destruct = sk->sk_destruct;
sk->sk_destruct = tls_hw_sk_destruct; sk->sk_destruct = tls_hw_sk_destruct;
ctx->rx_conf = TLS_HW_RECORD; ctx->rx_conf = TLS_HW_RECORD;
...@@ -717,7 +713,7 @@ static void tls_hw_unhash(struct sock *sk) ...@@ -717,7 +713,7 @@ static void tls_hw_unhash(struct sock *sk)
} }
} }
spin_unlock_bh(&device_spinlock); spin_unlock_bh(&device_spinlock);
ctx->unhash(sk); ctx->sk_proto->unhash(sk);
} }
static int tls_hw_hash(struct sock *sk) static int tls_hw_hash(struct sock *sk)
...@@ -726,7 +722,7 @@ static int tls_hw_hash(struct sock *sk) ...@@ -726,7 +722,7 @@ static int tls_hw_hash(struct sock *sk)
struct tls_device *dev; struct tls_device *dev;
int err; int err;
err = ctx->hash(sk); err = ctx->sk_proto->hash(sk);
spin_lock_bh(&device_spinlock); spin_lock_bh(&device_spinlock);
list_for_each_entry(dev, &device_list, dev_list) { list_for_each_entry(dev, &device_list, dev_list) {
if (dev->hash) { if (dev->hash) {
...@@ -816,7 +812,6 @@ static int tls_init(struct sock *sk) ...@@ -816,7 +812,6 @@ static int tls_init(struct sock *sk)
ctx->tx_conf = TLS_BASE; ctx->tx_conf = TLS_BASE;
ctx->rx_conf = TLS_BASE; ctx->rx_conf = TLS_BASE;
ctx->sk_proto = sk->sk_prot;
update_sk_prot(sk, ctx); update_sk_prot(sk, ctx);
out: out:
write_unlock_bh(&sk->sk_callback_lock); write_unlock_bh(&sk->sk_callback_lock);
...@@ -828,12 +823,10 @@ static void tls_update(struct sock *sk, struct proto *p) ...@@ -828,12 +823,10 @@ static void tls_update(struct sock *sk, struct proto *p)
struct tls_context *ctx; struct tls_context *ctx;
ctx = tls_get_ctx(sk); ctx = tls_get_ctx(sk);
if (likely(ctx)) { if (likely(ctx))
ctx->sk_proto_close = p->close;
ctx->sk_proto = p; ctx->sk_proto = p;
} else { else
sk->sk_prot = p; sk->sk_prot = p;
}
} }
static int tls_get_info(const struct sock *sk, struct sk_buff *skb) static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment