Commit de9f8696 authored by Jann Horn's avatar Jann Horn Committed by Linus Torvalds

x86/insn-eval: Fix use-after-free access to LDT entry

get_desc() computes a pointer into the LDT while holding a lock that
protects the LDT from being freed, but then drops the lock and returns the
(now potentially dangling) pointer to its caller.

Fix it by giving the caller a copy of the LDT entry instead.

Fixes: 670f928b ("x86/insn-eval: Add utility function to get segment descriptor")
Cc: stable@vger.kernel.org
Signed-off-by: default avatarJann Horn <jannh@google.com>
Signed-off-by: default avatarLinus Torvalds <torvalds@linux-foundation.org>
parent 1e1d9263
...@@ -557,7 +557,8 @@ static int get_reg_offset_16(struct insn *insn, struct pt_regs *regs, ...@@ -557,7 +557,8 @@ static int get_reg_offset_16(struct insn *insn, struct pt_regs *regs,
} }
/** /**
* get_desc() - Obtain pointer to a segment descriptor * get_desc() - Obtain contents of a segment descriptor
* @out: Segment descriptor contents on success
* @sel: Segment selector * @sel: Segment selector
* *
* Given a segment selector, obtain a pointer to the segment descriptor. * Given a segment selector, obtain a pointer to the segment descriptor.
...@@ -565,18 +566,18 @@ static int get_reg_offset_16(struct insn *insn, struct pt_regs *regs, ...@@ -565,18 +566,18 @@ static int get_reg_offset_16(struct insn *insn, struct pt_regs *regs,
* *
* Returns: * Returns:
* *
* Pointer to segment descriptor on success. * True on success, false on failure.
* *
* NULL on error. * NULL on error.
*/ */
static struct desc_struct *get_desc(unsigned short sel) static bool get_desc(struct desc_struct *out, unsigned short sel)
{ {
struct desc_ptr gdt_desc = {0, 0}; struct desc_ptr gdt_desc = {0, 0};
unsigned long desc_base; unsigned long desc_base;
#ifdef CONFIG_MODIFY_LDT_SYSCALL #ifdef CONFIG_MODIFY_LDT_SYSCALL
if ((sel & SEGMENT_TI_MASK) == SEGMENT_LDT) { if ((sel & SEGMENT_TI_MASK) == SEGMENT_LDT) {
struct desc_struct *desc = NULL; bool success = false;
struct ldt_struct *ldt; struct ldt_struct *ldt;
/* Bits [15:3] contain the index of the desired entry. */ /* Bits [15:3] contain the index of the desired entry. */
...@@ -584,12 +585,14 @@ static struct desc_struct *get_desc(unsigned short sel) ...@@ -584,12 +585,14 @@ static struct desc_struct *get_desc(unsigned short sel)
mutex_lock(&current->active_mm->context.lock); mutex_lock(&current->active_mm->context.lock);
ldt = current->active_mm->context.ldt; ldt = current->active_mm->context.ldt;
if (ldt && sel < ldt->nr_entries) if (ldt && sel < ldt->nr_entries) {
desc = &ldt->entries[sel]; *out = ldt->entries[sel];
success = true;
}
mutex_unlock(&current->active_mm->context.lock); mutex_unlock(&current->active_mm->context.lock);
return desc; return success;
} }
#endif #endif
native_store_gdt(&gdt_desc); native_store_gdt(&gdt_desc);
...@@ -604,9 +607,10 @@ static struct desc_struct *get_desc(unsigned short sel) ...@@ -604,9 +607,10 @@ static struct desc_struct *get_desc(unsigned short sel)
desc_base = sel & ~(SEGMENT_RPL_MASK | SEGMENT_TI_MASK); desc_base = sel & ~(SEGMENT_RPL_MASK | SEGMENT_TI_MASK);
if (desc_base > gdt_desc.size) if (desc_base > gdt_desc.size)
return NULL; return false;
return (struct desc_struct *)(gdt_desc.address + desc_base); *out = *(struct desc_struct *)(gdt_desc.address + desc_base);
return true;
} }
/** /**
...@@ -628,7 +632,7 @@ static struct desc_struct *get_desc(unsigned short sel) ...@@ -628,7 +632,7 @@ static struct desc_struct *get_desc(unsigned short sel)
*/ */
unsigned long insn_get_seg_base(struct pt_regs *regs, int seg_reg_idx) unsigned long insn_get_seg_base(struct pt_regs *regs, int seg_reg_idx)
{ {
struct desc_struct *desc; struct desc_struct desc;
short sel; short sel;
sel = get_segment_selector(regs, seg_reg_idx); sel = get_segment_selector(regs, seg_reg_idx);
...@@ -666,11 +670,10 @@ unsigned long insn_get_seg_base(struct pt_regs *regs, int seg_reg_idx) ...@@ -666,11 +670,10 @@ unsigned long insn_get_seg_base(struct pt_regs *regs, int seg_reg_idx)
if (!sel) if (!sel)
return -1L; return -1L;
desc = get_desc(sel); if (!get_desc(&desc, sel))
if (!desc)
return -1L; return -1L;
return get_desc_base(desc); return get_desc_base(&desc);
} }
/** /**
...@@ -692,7 +695,7 @@ unsigned long insn_get_seg_base(struct pt_regs *regs, int seg_reg_idx) ...@@ -692,7 +695,7 @@ unsigned long insn_get_seg_base(struct pt_regs *regs, int seg_reg_idx)
*/ */
static unsigned long get_seg_limit(struct pt_regs *regs, int seg_reg_idx) static unsigned long get_seg_limit(struct pt_regs *regs, int seg_reg_idx)
{ {
struct desc_struct *desc; struct desc_struct desc;
unsigned long limit; unsigned long limit;
short sel; short sel;
...@@ -706,8 +709,7 @@ static unsigned long get_seg_limit(struct pt_regs *regs, int seg_reg_idx) ...@@ -706,8 +709,7 @@ static unsigned long get_seg_limit(struct pt_regs *regs, int seg_reg_idx)
if (!sel) if (!sel)
return 0; return 0;
desc = get_desc(sel); if (!get_desc(&desc, sel))
if (!desc)
return 0; return 0;
/* /*
...@@ -716,8 +718,8 @@ static unsigned long get_seg_limit(struct pt_regs *regs, int seg_reg_idx) ...@@ -716,8 +718,8 @@ static unsigned long get_seg_limit(struct pt_regs *regs, int seg_reg_idx)
* not tested when checking the segment limits. In practice, * not tested when checking the segment limits. In practice,
* this means that the segment ends in (limit << 12) + 0xfff. * this means that the segment ends in (limit << 12) + 0xfff.
*/ */
limit = get_desc_limit(desc); limit = get_desc_limit(&desc);
if (desc->g) if (desc.g)
limit = (limit << 12) + 0xfff; limit = (limit << 12) + 0xfff;
return limit; return limit;
...@@ -741,7 +743,7 @@ static unsigned long get_seg_limit(struct pt_regs *regs, int seg_reg_idx) ...@@ -741,7 +743,7 @@ static unsigned long get_seg_limit(struct pt_regs *regs, int seg_reg_idx)
*/ */
int insn_get_code_seg_params(struct pt_regs *regs) int insn_get_code_seg_params(struct pt_regs *regs)
{ {
struct desc_struct *desc; struct desc_struct desc;
short sel; short sel;
if (v8086_mode(regs)) if (v8086_mode(regs))
...@@ -752,8 +754,7 @@ int insn_get_code_seg_params(struct pt_regs *regs) ...@@ -752,8 +754,7 @@ int insn_get_code_seg_params(struct pt_regs *regs)
if (sel < 0) if (sel < 0)
return sel; return sel;
desc = get_desc(sel); if (!get_desc(&desc, sel))
if (!desc)
return -EINVAL; return -EINVAL;
/* /*
...@@ -761,10 +762,10 @@ int insn_get_code_seg_params(struct pt_regs *regs) ...@@ -761,10 +762,10 @@ int insn_get_code_seg_params(struct pt_regs *regs)
* determines whether a segment contains data or code. If this is a data * determines whether a segment contains data or code. If this is a data
* segment, return error. * segment, return error.
*/ */
if (!(desc->type & BIT(3))) if (!(desc.type & BIT(3)))
return -EINVAL; return -EINVAL;
switch ((desc->l << 1) | desc->d) { switch ((desc.l << 1) | desc.d) {
case 0: /* case 0: /*
* Legacy mode. CS.L=0, CS.D=0. Address and operand size are * Legacy mode. CS.L=0, CS.D=0. Address and operand size are
* both 16-bit. * both 16-bit.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment