diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c
index 2642f317af8f655ea03b99b1b6a10471b3adbfa6..515dc7b6daea4396acbf96ae45730c311c2c4a7f 100644
--- a/net/ipv4/inet_diag.c
+++ b/net/ipv4/inet_diag.c
@@ -51,18 +51,6 @@ static struct sock *sdiagnl;
 #define INET_DIAG_PUT(skb, attrtype, attrlen) \
 	RTA_DATA(__RTA_PUT(skb, attrtype, attrlen))
 
-static inline int inet_diag_type2proto(int type)
-{
-	switch (type) {
-	case TCPDIAG_GETSOCK:
-		return IPPROTO_TCP;
-	case DCCPDIAG_GETSOCK:
-		return IPPROTO_DCCP;
-	default:
-		return 0;
-	}
-}
-
 static DEFINE_MUTEX(inet_diag_table_mutex);
 
 static const struct inet_diag_handler *inet_diag_lock_handler(int proto)
@@ -85,8 +73,8 @@ static inline void inet_diag_unlock_handler(
 }
 
 static int inet_csk_diag_fill(struct sock *sk,
-			      struct sk_buff *skb,
-			      int ext, u32 pid, u32 seq, u16 nlmsg_flags,
+			      struct sk_buff *skb, struct inet_diag_req *req,
+			      u32 pid, u32 seq, u16 nlmsg_flags,
 			      const struct nlmsghdr *unlh)
 {
 	const struct inet_sock *inet = inet_sk(sk);
@@ -97,8 +85,9 @@ static int inet_csk_diag_fill(struct sock *sk,
 	struct inet_diag_meminfo  *minfo = NULL;
 	unsigned char	 *b = skb_tail_pointer(skb);
 	const struct inet_diag_handler *handler;
+	int ext = req->idiag_ext;
 
-	handler = inet_diag_table[inet_diag_type2proto(unlh->nlmsg_type)];
+	handler = inet_diag_table[req->sdiag_protocol];
 	BUG_ON(handler == NULL);
 
 	nlh = NLMSG_PUT(skb, pid, seq, unlh->nlmsg_type, sizeof(*r));
@@ -198,8 +187,8 @@ static int inet_csk_diag_fill(struct sock *sk,
 }
 
 static int inet_twsk_diag_fill(struct inet_timewait_sock *tw,
-			       struct sk_buff *skb, int ext, u32 pid,
-			       u32 seq, u16 nlmsg_flags,
+			       struct sk_buff *skb, struct inet_diag_req *req,
+			       u32 pid, u32 seq, u16 nlmsg_flags,
 			       const struct nlmsghdr *unlh)
 {
 	long tmo;
@@ -250,14 +239,14 @@ static int inet_twsk_diag_fill(struct inet_timewait_sock *tw,
 }
 
 static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
-			int ext, u32 pid, u32 seq, u16 nlmsg_flags,
+			struct inet_diag_req *r, u32 pid, u32 seq, u16 nlmsg_flags,
 			const struct nlmsghdr *unlh)
 {
 	if (sk->sk_state == TCP_TIME_WAIT)
 		return inet_twsk_diag_fill((struct inet_timewait_sock *)sk,
-					   skb, ext, pid, seq, nlmsg_flags,
+					   skb, r, pid, seq, nlmsg_flags,
 					   unlh);
-	return inet_csk_diag_fill(sk, skb, ext, pid, seq, nlmsg_flags, unlh);
+	return inet_csk_diag_fill(sk, skb, r, pid, seq, nlmsg_flags, unlh);
 }
 
 static int inet_diag_get_exact(struct sk_buff *in_skb,
@@ -317,7 +306,7 @@ static int inet_diag_get_exact(struct sk_buff *in_skb,
 	if (!rep)
 		goto out;
 
-	err = sk_diag_fill(sk, rep, req->idiag_ext,
+	err = sk_diag_fill(sk, rep, req,
 			   NETLINK_CB(in_skb).pid,
 			   nlh->nlmsg_seq, 0, nlh);
 	if (err < 0) {
@@ -530,7 +519,7 @@ static int inet_csk_diag_dump(struct sock *sk,
 			return 0;
 	}
 
-	return inet_csk_diag_fill(sk, skb, r->idiag_ext,
+	return inet_csk_diag_fill(sk, skb, r,
 				  NETLINK_CB(cb->skb).pid,
 				  cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
 }
@@ -565,7 +554,7 @@ static int inet_twsk_diag_dump(struct inet_timewait_sock *tw,
 			return 0;
 	}
 
-	return inet_twsk_diag_fill(tw, skb, r->idiag_ext,
+	return inet_twsk_diag_fill(tw, skb, r,
 				   NETLINK_CB(cb->skb).pid,
 				   cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
 }
@@ -876,6 +865,18 @@ static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
 	return __inet_diag_dump(skb, cb, (struct inet_diag_req *)NLMSG_DATA(cb->nlh), bc);
 }
 
+static inline int inet_diag_type2proto(int type)
+{
+	switch (type) {
+	case TCPDIAG_GETSOCK:
+		return IPPROTO_TCP;
+	case DCCPDIAG_GETSOCK:
+		return IPPROTO_DCCP;
+	default:
+		return 0;
+	}
+}
+
 static int inet_diag_dump_compat(struct sk_buff *skb, struct netlink_callback *cb)
 {
 	struct inet_diag_req_compat *rc = NLMSG_DATA(cb->nlh);