Commit adae1e93 authored by Andres Beltran's avatar Andres Beltran Committed by Wei Liu

Drivers: hv: vmbus: Copy packets sent by Hyper-V out of the ring buffer

Pointers to ring-buffer packets sent by Hyper-V are used within the
guest VM. Hyper-V can send packets with erroneous values or modify
packet fields after they are processed by the guest. To defend
against these scenarios, return a copy of the incoming VMBus packet
after validating its length and offset fields in hv_pkt_iter_first().
In this way, the packet can no longer be modified by the host.
Signed-off-by: default avatarAndres Beltran <lkmlabelt@gmail.com>
Co-developed-by: default avatarAndrea Parri (Microsoft) <parri.andrea@gmail.com>
Signed-off-by: default avatarAndrea Parri (Microsoft) <parri.andrea@gmail.com>
Reviewed-by: default avatarMichael Kelley <mikelley@microsoft.com>
Link: https://lore.kernel.org/r/20210408161439.341988-1-parri.andrea@gmail.comSigned-off-by: default avatarWei Liu <wei.liu@kernel.org>
parent 03b30cc3
...@@ -662,12 +662,15 @@ static int __vmbus_open(struct vmbus_channel *newchannel, ...@@ -662,12 +662,15 @@ static int __vmbus_open(struct vmbus_channel *newchannel,
newchannel->onchannel_callback = onchannelcallback; newchannel->onchannel_callback = onchannelcallback;
newchannel->channel_callback_context = context; newchannel->channel_callback_context = context;
err = hv_ringbuffer_init(&newchannel->outbound, page, send_pages); if (!newchannel->max_pkt_size)
newchannel->max_pkt_size = VMBUS_DEFAULT_MAX_PKT_SIZE;
err = hv_ringbuffer_init(&newchannel->outbound, page, send_pages, 0);
if (err) if (err)
goto error_clean_ring; goto error_clean_ring;
err = hv_ringbuffer_init(&newchannel->inbound, err = hv_ringbuffer_init(&newchannel->inbound, &page[send_pages],
&page[send_pages], recv_pages); recv_pages, newchannel->max_pkt_size);
if (err) if (err)
goto error_clean_ring; goto error_clean_ring;
......
...@@ -349,6 +349,7 @@ int hv_fcopy_init(struct hv_util_service *srv) ...@@ -349,6 +349,7 @@ int hv_fcopy_init(struct hv_util_service *srv)
{ {
recv_buffer = srv->recv_buffer; recv_buffer = srv->recv_buffer;
fcopy_transaction.recv_channel = srv->channel; fcopy_transaction.recv_channel = srv->channel;
fcopy_transaction.recv_channel->max_pkt_size = HV_HYP_PAGE_SIZE * 2;
/* /*
* When this driver loads, the user level daemon that * When this driver loads, the user level daemon that
......
...@@ -757,6 +757,7 @@ hv_kvp_init(struct hv_util_service *srv) ...@@ -757,6 +757,7 @@ hv_kvp_init(struct hv_util_service *srv)
{ {
recv_buffer = srv->recv_buffer; recv_buffer = srv->recv_buffer;
kvp_transaction.recv_channel = srv->channel; kvp_transaction.recv_channel = srv->channel;
kvp_transaction.recv_channel->max_pkt_size = HV_HYP_PAGE_SIZE * 4;
/* /*
* When this driver loads, the user level daemon that * When this driver loads, the user level daemon that
......
...@@ -174,7 +174,7 @@ extern int hv_synic_cleanup(unsigned int cpu); ...@@ -174,7 +174,7 @@ extern int hv_synic_cleanup(unsigned int cpu);
void hv_ringbuffer_pre_init(struct vmbus_channel *channel); void hv_ringbuffer_pre_init(struct vmbus_channel *channel);
int hv_ringbuffer_init(struct hv_ring_buffer_info *ring_info, int hv_ringbuffer_init(struct hv_ring_buffer_info *ring_info,
struct page *pages, u32 pagecnt); struct page *pages, u32 pagecnt, u32 max_pkt_size);
void hv_ringbuffer_cleanup(struct hv_ring_buffer_info *ring_info); void hv_ringbuffer_cleanup(struct hv_ring_buffer_info *ring_info);
......
...@@ -181,7 +181,7 @@ void hv_ringbuffer_pre_init(struct vmbus_channel *channel) ...@@ -181,7 +181,7 @@ void hv_ringbuffer_pre_init(struct vmbus_channel *channel)
/* Initialize the ring buffer. */ /* Initialize the ring buffer. */
int hv_ringbuffer_init(struct hv_ring_buffer_info *ring_info, int hv_ringbuffer_init(struct hv_ring_buffer_info *ring_info,
struct page *pages, u32 page_cnt) struct page *pages, u32 page_cnt, u32 max_pkt_size)
{ {
int i; int i;
struct page **pages_wraparound; struct page **pages_wraparound;
...@@ -223,6 +223,14 @@ int hv_ringbuffer_init(struct hv_ring_buffer_info *ring_info, ...@@ -223,6 +223,14 @@ int hv_ringbuffer_init(struct hv_ring_buffer_info *ring_info,
sizeof(struct hv_ring_buffer); sizeof(struct hv_ring_buffer);
ring_info->priv_read_index = 0; ring_info->priv_read_index = 0;
/* Initialize buffer that holds copies of incoming packets */
if (max_pkt_size) {
ring_info->pkt_buffer = kzalloc(max_pkt_size, GFP_KERNEL);
if (!ring_info->pkt_buffer)
return -ENOMEM;
ring_info->pkt_buffer_size = max_pkt_size;
}
spin_lock_init(&ring_info->ring_lock); spin_lock_init(&ring_info->ring_lock);
return 0; return 0;
...@@ -235,6 +243,9 @@ void hv_ringbuffer_cleanup(struct hv_ring_buffer_info *ring_info) ...@@ -235,6 +243,9 @@ void hv_ringbuffer_cleanup(struct hv_ring_buffer_info *ring_info)
vunmap(ring_info->ring_buffer); vunmap(ring_info->ring_buffer);
ring_info->ring_buffer = NULL; ring_info->ring_buffer = NULL;
mutex_unlock(&ring_info->ring_buffer_mutex); mutex_unlock(&ring_info->ring_buffer_mutex);
kfree(ring_info->pkt_buffer);
ring_info->pkt_buffer_size = 0;
} }
/* Write to the ring buffer. */ /* Write to the ring buffer. */
...@@ -375,7 +386,7 @@ int hv_ringbuffer_read(struct vmbus_channel *channel, ...@@ -375,7 +386,7 @@ int hv_ringbuffer_read(struct vmbus_channel *channel,
memcpy(buffer, (const char *)desc + offset, packetlen); memcpy(buffer, (const char *)desc + offset, packetlen);
/* Advance ring index to next packet descriptor */ /* Advance ring index to next packet descriptor */
__hv_pkt_iter_next(channel, desc); __hv_pkt_iter_next(channel, desc, true);
/* Notify host of update */ /* Notify host of update */
hv_pkt_iter_close(channel); hv_pkt_iter_close(channel);
...@@ -401,6 +412,22 @@ static u32 hv_pkt_iter_avail(const struct hv_ring_buffer_info *rbi) ...@@ -401,6 +412,22 @@ static u32 hv_pkt_iter_avail(const struct hv_ring_buffer_info *rbi)
return (rbi->ring_datasize - priv_read_loc) + write_loc; return (rbi->ring_datasize - priv_read_loc) + write_loc;
} }
/*
* Get first vmbus packet without copying it out of the ring buffer
*/
struct vmpacket_descriptor *hv_pkt_iter_first_raw(struct vmbus_channel *channel)
{
struct hv_ring_buffer_info *rbi = &channel->inbound;
hv_debug_delay_test(channel, MESSAGE_DELAY);
if (hv_pkt_iter_avail(rbi) < sizeof(struct vmpacket_descriptor))
return NULL;
return (struct vmpacket_descriptor *)(hv_get_ring_buffer(rbi) + rbi->priv_read_index);
}
EXPORT_SYMBOL_GPL(hv_pkt_iter_first_raw);
/* /*
* Get first vmbus packet from ring buffer after read_index * Get first vmbus packet from ring buffer after read_index
* *
...@@ -409,17 +436,49 @@ static u32 hv_pkt_iter_avail(const struct hv_ring_buffer_info *rbi) ...@@ -409,17 +436,49 @@ static u32 hv_pkt_iter_avail(const struct hv_ring_buffer_info *rbi)
struct vmpacket_descriptor *hv_pkt_iter_first(struct vmbus_channel *channel) struct vmpacket_descriptor *hv_pkt_iter_first(struct vmbus_channel *channel)
{ {
struct hv_ring_buffer_info *rbi = &channel->inbound; struct hv_ring_buffer_info *rbi = &channel->inbound;
struct vmpacket_descriptor *desc; struct vmpacket_descriptor *desc, *desc_copy;
u32 bytes_avail, pkt_len, pkt_offset;
hv_debug_delay_test(channel, MESSAGE_DELAY); desc = hv_pkt_iter_first_raw(channel);
if (hv_pkt_iter_avail(rbi) < sizeof(struct vmpacket_descriptor)) if (!desc)
return NULL; return NULL;
desc = hv_get_ring_buffer(rbi) + rbi->priv_read_index; bytes_avail = min(rbi->pkt_buffer_size, hv_pkt_iter_avail(rbi));
if (desc)
prefetch((char *)desc + (desc->len8 << 3)); /*
* Ensure the compiler does not use references to incoming Hyper-V values (which
* could change at any moment) when reading local variables later in the code
*/
pkt_len = READ_ONCE(desc->len8) << 3;
pkt_offset = READ_ONCE(desc->offset8) << 3;
/*
* If pkt_len is invalid, set it to the smaller of hv_pkt_iter_avail() and
* rbi->pkt_buffer_size
*/
if (pkt_len < sizeof(struct vmpacket_descriptor) || pkt_len > bytes_avail)
pkt_len = bytes_avail;
/*
* If pkt_offset is invalid, arbitrarily set it to
* the size of vmpacket_descriptor
*/
if (pkt_offset < sizeof(struct vmpacket_descriptor) || pkt_offset > pkt_len)
pkt_offset = sizeof(struct vmpacket_descriptor);
/* Copy the Hyper-V packet out of the ring buffer */
desc_copy = (struct vmpacket_descriptor *)rbi->pkt_buffer;
memcpy(desc_copy, desc, pkt_len);
/*
* Hyper-V could still change len8 and offset8 after the earlier read.
* Ensure that desc_copy has legal values for len8 and offset8 that
* are consistent with the copy we just made
*/
desc_copy->len8 = pkt_len >> 3;
desc_copy->offset8 = pkt_offset >> 3;
return desc; return desc_copy;
} }
EXPORT_SYMBOL_GPL(hv_pkt_iter_first); EXPORT_SYMBOL_GPL(hv_pkt_iter_first);
...@@ -431,7 +490,8 @@ EXPORT_SYMBOL_GPL(hv_pkt_iter_first); ...@@ -431,7 +490,8 @@ EXPORT_SYMBOL_GPL(hv_pkt_iter_first);
*/ */
struct vmpacket_descriptor * struct vmpacket_descriptor *
__hv_pkt_iter_next(struct vmbus_channel *channel, __hv_pkt_iter_next(struct vmbus_channel *channel,
const struct vmpacket_descriptor *desc) const struct vmpacket_descriptor *desc,
bool copy)
{ {
struct hv_ring_buffer_info *rbi = &channel->inbound; struct hv_ring_buffer_info *rbi = &channel->inbound;
u32 packetlen = desc->len8 << 3; u32 packetlen = desc->len8 << 3;
...@@ -444,7 +504,7 @@ __hv_pkt_iter_next(struct vmbus_channel *channel, ...@@ -444,7 +504,7 @@ __hv_pkt_iter_next(struct vmbus_channel *channel,
rbi->priv_read_index -= dsize; rbi->priv_read_index -= dsize;
/* more data? */ /* more data? */
return hv_pkt_iter_first(channel); return copy ? hv_pkt_iter_first(channel) : hv_pkt_iter_first_raw(channel);
} }
EXPORT_SYMBOL_GPL(__hv_pkt_iter_next); EXPORT_SYMBOL_GPL(__hv_pkt_iter_next);
......
...@@ -895,9 +895,16 @@ static inline u32 netvsc_rqstor_size(unsigned long ringbytes) ...@@ -895,9 +895,16 @@ static inline u32 netvsc_rqstor_size(unsigned long ringbytes)
ringbytes / NETVSC_MIN_IN_MSG_SIZE; ringbytes / NETVSC_MIN_IN_MSG_SIZE;
} }
/* XFER PAGE packets can specify a maximum of 375 ranges for NDIS >= 6.0
* and a maximum of 64 ranges for NDIS < 6.0 with no RSC; with RSC, this
* limit is raised to 562 (= NVSP_RSC_MAX).
*/
#define NETVSC_MAX_XFER_PAGE_RANGES NVSP_RSC_MAX
#define NETVSC_XFER_HEADER_SIZE(rng_cnt) \ #define NETVSC_XFER_HEADER_SIZE(rng_cnt) \
(offsetof(struct vmtransfer_page_packet_header, ranges) + \ (offsetof(struct vmtransfer_page_packet_header, ranges) + \
(rng_cnt) * sizeof(struct vmtransfer_page_range)) (rng_cnt) * sizeof(struct vmtransfer_page_range))
#define NETVSC_MAX_PKT_SIZE (NETVSC_XFER_HEADER_SIZE(NETVSC_MAX_XFER_PAGE_RANGES) + \
sizeof(struct nvsp_message) + (sizeof(u32) * VRSS_SEND_TAB_SIZE))
struct multi_send_data { struct multi_send_data {
struct sk_buff *skb; /* skb containing the pkt */ struct sk_buff *skb; /* skb containing the pkt */
......
...@@ -1650,6 +1650,8 @@ struct netvsc_device *netvsc_device_add(struct hv_device *device, ...@@ -1650,6 +1650,8 @@ struct netvsc_device *netvsc_device_add(struct hv_device *device,
/* Open the channel */ /* Open the channel */
device->channel->rqstor_size = netvsc_rqstor_size(netvsc_ring_bytes); device->channel->rqstor_size = netvsc_rqstor_size(netvsc_ring_bytes);
device->channel->max_pkt_size = NETVSC_MAX_PKT_SIZE;
ret = vmbus_open(device->channel, netvsc_ring_bytes, ret = vmbus_open(device->channel, netvsc_ring_bytes,
netvsc_ring_bytes, NULL, 0, netvsc_ring_bytes, NULL, 0,
netvsc_channel_cb, net_device->chan_table); netvsc_channel_cb, net_device->chan_table);
......
...@@ -1260,6 +1260,8 @@ static void netvsc_sc_open(struct vmbus_channel *new_sc) ...@@ -1260,6 +1260,8 @@ static void netvsc_sc_open(struct vmbus_channel *new_sc)
nvchan->channel = new_sc; nvchan->channel = new_sc;
new_sc->rqstor_size = netvsc_rqstor_size(netvsc_ring_bytes); new_sc->rqstor_size = netvsc_rqstor_size(netvsc_ring_bytes);
new_sc->max_pkt_size = NETVSC_MAX_PKT_SIZE;
ret = vmbus_open(new_sc, netvsc_ring_bytes, ret = vmbus_open(new_sc, netvsc_ring_bytes,
netvsc_ring_bytes, NULL, 0, netvsc_ring_bytes, NULL, 0,
netvsc_channel_cb, nvchan); netvsc_channel_cb, nvchan);
......
...@@ -406,6 +406,14 @@ static void storvsc_on_channel_callback(void *context); ...@@ -406,6 +406,14 @@ static void storvsc_on_channel_callback(void *context);
#define STORVSC_IDE_MAX_TARGETS 1 #define STORVSC_IDE_MAX_TARGETS 1
#define STORVSC_IDE_MAX_CHANNELS 1 #define STORVSC_IDE_MAX_CHANNELS 1
/*
* Upper bound on the size of a storvsc packet. vmscsi_size_delta is not
* included in the calculation because it is set after STORVSC_MAX_PKT_SIZE
* is used in storvsc_connect_to_vsp
*/
#define STORVSC_MAX_PKT_SIZE (sizeof(struct vmpacket_descriptor) +\
sizeof(struct vstor_packet))
struct storvsc_cmd_request { struct storvsc_cmd_request {
struct scsi_cmnd *cmd; struct scsi_cmnd *cmd;
...@@ -701,6 +709,7 @@ static void handle_sc_creation(struct vmbus_channel *new_sc) ...@@ -701,6 +709,7 @@ static void handle_sc_creation(struct vmbus_channel *new_sc)
return; return;
memset(&props, 0, sizeof(struct vmstorage_channel_properties)); memset(&props, 0, sizeof(struct vmstorage_channel_properties));
new_sc->max_pkt_size = STORVSC_MAX_PKT_SIZE;
/* /*
* The size of vmbus_requestor is an upper bound on the number of requests * The size of vmbus_requestor is an upper bound on the number of requests
...@@ -1294,6 +1303,7 @@ static int storvsc_connect_to_vsp(struct hv_device *device, u32 ring_size, ...@@ -1294,6 +1303,7 @@ static int storvsc_connect_to_vsp(struct hv_device *device, u32 ring_size,
memset(&props, 0, sizeof(struct vmstorage_channel_properties)); memset(&props, 0, sizeof(struct vmstorage_channel_properties));
device->channel->max_pkt_size = STORVSC_MAX_PKT_SIZE;
/* /*
* The size of vmbus_requestor is an upper bound on the number of requests * The size of vmbus_requestor is an upper bound on the number of requests
* that can be in-progress at any one time across all channels. * that can be in-progress at any one time across all channels.
......
...@@ -181,6 +181,10 @@ struct hv_ring_buffer_info { ...@@ -181,6 +181,10 @@ struct hv_ring_buffer_info {
* being freed while the ring buffer is being accessed. * being freed while the ring buffer is being accessed.
*/ */
struct mutex ring_buffer_mutex; struct mutex ring_buffer_mutex;
/* Buffer that holds a copy of an incoming host packet */
void *pkt_buffer;
u32 pkt_buffer_size;
}; };
...@@ -799,6 +803,8 @@ struct vmbus_device { ...@@ -799,6 +803,8 @@ struct vmbus_device {
bool allowed_in_isolated; bool allowed_in_isolated;
}; };
#define VMBUS_DEFAULT_MAX_PKT_SIZE 4096
struct vmbus_channel { struct vmbus_channel {
struct list_head listentry; struct list_head listentry;
...@@ -1021,6 +1027,9 @@ struct vmbus_channel { ...@@ -1021,6 +1027,9 @@ struct vmbus_channel {
/* request/transaction ids for VMBus */ /* request/transaction ids for VMBus */
struct vmbus_requestor requestor; struct vmbus_requestor requestor;
u32 rqstor_size; u32 rqstor_size;
/* The max size of a packet on this channel */
u32 max_pkt_size;
}; };
u64 vmbus_next_request_id(struct vmbus_requestor *rqstor, u64 rqst_addr); u64 vmbus_next_request_id(struct vmbus_requestor *rqstor, u64 rqst_addr);
...@@ -1662,32 +1671,55 @@ static inline u32 hv_pkt_datalen(const struct vmpacket_descriptor *desc) ...@@ -1662,32 +1671,55 @@ static inline u32 hv_pkt_datalen(const struct vmpacket_descriptor *desc)
} }
struct vmpacket_descriptor *
hv_pkt_iter_first_raw(struct vmbus_channel *channel);
struct vmpacket_descriptor * struct vmpacket_descriptor *
hv_pkt_iter_first(struct vmbus_channel *channel); hv_pkt_iter_first(struct vmbus_channel *channel);
struct vmpacket_descriptor * struct vmpacket_descriptor *
__hv_pkt_iter_next(struct vmbus_channel *channel, __hv_pkt_iter_next(struct vmbus_channel *channel,
const struct vmpacket_descriptor *pkt); const struct vmpacket_descriptor *pkt,
bool copy);
void hv_pkt_iter_close(struct vmbus_channel *channel); void hv_pkt_iter_close(struct vmbus_channel *channel);
/*
* Get next packet descriptor from iterator
* If at end of list, return NULL and update host.
*/
static inline struct vmpacket_descriptor * static inline struct vmpacket_descriptor *
hv_pkt_iter_next(struct vmbus_channel *channel, hv_pkt_iter_next_pkt(struct vmbus_channel *channel,
const struct vmpacket_descriptor *pkt) const struct vmpacket_descriptor *pkt,
bool copy)
{ {
struct vmpacket_descriptor *nxt; struct vmpacket_descriptor *nxt;
nxt = __hv_pkt_iter_next(channel, pkt); nxt = __hv_pkt_iter_next(channel, pkt, copy);
if (!nxt) if (!nxt)
hv_pkt_iter_close(channel); hv_pkt_iter_close(channel);
return nxt; return nxt;
} }
/*
* Get next packet descriptor without copying it out of the ring buffer
* If at end of list, return NULL and update host.
*/
static inline struct vmpacket_descriptor *
hv_pkt_iter_next_raw(struct vmbus_channel *channel,
const struct vmpacket_descriptor *pkt)
{
return hv_pkt_iter_next_pkt(channel, pkt, false);
}
/*
* Get next packet descriptor from iterator
* If at end of list, return NULL and update host.
*/
static inline struct vmpacket_descriptor *
hv_pkt_iter_next(struct vmbus_channel *channel,
const struct vmpacket_descriptor *pkt)
{
return hv_pkt_iter_next_pkt(channel, pkt, true);
}
#define foreach_vmbus_pkt(pkt, channel) \ #define foreach_vmbus_pkt(pkt, channel) \
for (pkt = hv_pkt_iter_first(channel); pkt; \ for (pkt = hv_pkt_iter_first(channel); pkt; \
pkt = hv_pkt_iter_next(channel, pkt)) pkt = hv_pkt_iter_next(channel, pkt))
......
...@@ -596,7 +596,7 @@ static ssize_t hvs_stream_dequeue(struct vsock_sock *vsk, struct msghdr *msg, ...@@ -596,7 +596,7 @@ static ssize_t hvs_stream_dequeue(struct vsock_sock *vsk, struct msghdr *msg,
return -EOPNOTSUPP; return -EOPNOTSUPP;
if (need_refill) { if (need_refill) {
hvs->recv_desc = hv_pkt_iter_first(hvs->chan); hvs->recv_desc = hv_pkt_iter_first_raw(hvs->chan);
ret = hvs_update_recv_data(hvs); ret = hvs_update_recv_data(hvs);
if (ret) if (ret)
return ret; return ret;
...@@ -610,7 +610,7 @@ static ssize_t hvs_stream_dequeue(struct vsock_sock *vsk, struct msghdr *msg, ...@@ -610,7 +610,7 @@ static ssize_t hvs_stream_dequeue(struct vsock_sock *vsk, struct msghdr *msg,
hvs->recv_data_len -= to_read; hvs->recv_data_len -= to_read;
if (hvs->recv_data_len == 0) { if (hvs->recv_data_len == 0) {
hvs->recv_desc = hv_pkt_iter_next(hvs->chan, hvs->recv_desc); hvs->recv_desc = hv_pkt_iter_next_raw(hvs->chan, hvs->recv_desc);
if (hvs->recv_desc) { if (hvs->recv_desc) {
ret = hvs_update_recv_data(hvs); ret = hvs_update_recv_data(hvs);
if (ret) if (ret)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment