Commit 0ac30f82 authored by Walter Dörwald's avatar Walter Dörwald

Enhance the punycode decoder so that it can decode

unicode objects.

Fix the idna codec and the tests.
parent 1f05a3b7
...@@ -7,7 +7,8 @@ from unicodedata import ucd_3_2_0 as unicodedata ...@@ -7,7 +7,8 @@ from unicodedata import ucd_3_2_0 as unicodedata
dots = re.compile("[\u002E\u3002\uFF0E\uFF61]") dots = re.compile("[\u002E\u3002\uFF0E\uFF61]")
# IDNA section 5 # IDNA section 5
ace_prefix = "xn--" ace_prefix = b"xn--"
sace_prefix = "xn--"
# This assumes query strings, so AllowUnassigned is true # This assumes query strings, so AllowUnassigned is true
def nameprep(label): def nameprep(label):
...@@ -87,7 +88,7 @@ def ToASCII(label): ...@@ -87,7 +88,7 @@ def ToASCII(label):
raise UnicodeError("label empty or too long") raise UnicodeError("label empty or too long")
# Step 5: Check ACE prefix # Step 5: Check ACE prefix
if label.startswith(ace_prefix): if label.startswith(sace_prefix):
raise UnicodeError("Label starts with ACE prefix") raise UnicodeError("Label starts with ACE prefix")
# Step 6: Encode with PUNYCODE # Step 6: Encode with PUNYCODE
...@@ -134,7 +135,7 @@ def ToUnicode(label): ...@@ -134,7 +135,7 @@ def ToUnicode(label):
# Step 7: Compare the result of step 6 with the one of step 3 # Step 7: Compare the result of step 6 with the one of step 3
# label2 will already be in lower case. # label2 will already be in lower case.
if label.lower() != label2: if str(label, "ascii").lower() != str(label2, "ascii"):
raise UnicodeError("IDNA does not round-trip", label, label2) raise UnicodeError("IDNA does not round-trip", label, label2)
# Step 8: return the result of step 5 # Step 8: return the result of step 5
...@@ -143,7 +144,7 @@ def ToUnicode(label): ...@@ -143,7 +144,7 @@ def ToUnicode(label):
### Codec APIs ### Codec APIs
class Codec(codecs.Codec): class Codec(codecs.Codec):
def encode(self,input,errors='strict'): def encode(self, input, errors='strict'):
if errors != 'strict': if errors != 'strict':
# IDNA is quite clear that implementations must be strict # IDNA is quite clear that implementations must be strict
...@@ -152,19 +153,21 @@ class Codec(codecs.Codec): ...@@ -152,19 +153,21 @@ class Codec(codecs.Codec):
if not input: if not input:
return b"", 0 return b"", 0
result = [] result = b""
labels = dots.split(input) labels = dots.split(input)
if labels and len(labels[-1])==0: if labels and not labels[-1]:
trailing_dot = b'.' trailing_dot = b'.'
del labels[-1] del labels[-1]
else: else:
trailing_dot = b'' trailing_dot = b''
for label in labels: for label in labels:
result.append(ToASCII(label)) if result:
# Join with U+002E # Join with U+002E
return b".".join(result)+trailing_dot, len(input) result.extend(b'.')
result.extend(ToASCII(label))
return result+trailing_dot, len(input)
def decode(self,input,errors='strict'): def decode(self, input, errors='strict'):
if errors != 'strict': if errors != 'strict':
raise UnicodeError("Unsupported error handling "+errors) raise UnicodeError("Unsupported error handling "+errors)
...@@ -199,30 +202,31 @@ class IncrementalEncoder(codecs.BufferedIncrementalEncoder): ...@@ -199,30 +202,31 @@ class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
raise UnicodeError("unsupported error handling "+errors) raise UnicodeError("unsupported error handling "+errors)
if not input: if not input:
return ("", 0) return (b'', 0)
labels = dots.split(input) labels = dots.split(input)
trailing_dot = '' trailing_dot = b''
if labels: if labels:
if not labels[-1]: if not labels[-1]:
trailing_dot = '.' trailing_dot = b'.'
del labels[-1] del labels[-1]
elif not final: elif not final:
# Keep potentially unfinished label until the next call # Keep potentially unfinished label until the next call
del labels[-1] del labels[-1]
if labels: if labels:
trailing_dot = '.' trailing_dot = b'.'
result = [] result = b""
size = 0 size = 0
for label in labels: for label in labels:
result.append(ToASCII(label))
if size: if size:
# Join with U+002E
result.extend(b'.')
size += 1 size += 1
result.extend(ToASCII(label))
size += len(label) size += len(label)
# Join with U+002E result += trailing_dot
result = ".".join(result) + trailing_dot
size += len(trailing_dot) size += len(trailing_dot)
return (result, size) return (result, size)
...@@ -239,8 +243,7 @@ class IncrementalDecoder(codecs.BufferedIncrementalDecoder): ...@@ -239,8 +243,7 @@ class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
labels = dots.split(input) labels = dots.split(input)
else: else:
# Must be ASCII string # Must be ASCII string
input = str(input) input = str(input, "ascii")
str(input, "ascii")
labels = input.split(".") labels = input.split(".")
trailing_dot = '' trailing_dot = ''
......
...@@ -181,6 +181,8 @@ def insertion_sort(base, extended, errors): ...@@ -181,6 +181,8 @@ def insertion_sort(base, extended, errors):
return base return base
def punycode_decode(text, errors): def punycode_decode(text, errors):
if isinstance(text, str):
text = text.encode("ascii")
pos = text.rfind(b"-") pos = text.rfind(b"-")
if pos == -1: if pos == -1:
base = "" base = ""
...@@ -194,11 +196,11 @@ def punycode_decode(text, errors): ...@@ -194,11 +196,11 @@ def punycode_decode(text, errors):
class Codec(codecs.Codec): class Codec(codecs.Codec):
def encode(self,input,errors='strict'): def encode(self, input, errors='strict'):
res = punycode_encode(input) res = punycode_encode(input)
return res, len(input) return res, len(input)
def decode(self,input,errors='strict'): def decode(self, input, errors='strict'):
if errors not in ('strict', 'replace', 'ignore'): if errors not in ('strict', 'replace', 'ignore'):
raise UnicodeError, "Unsupported error handling "+errors raise UnicodeError, "Unsupported error handling "+errors
res = punycode_decode(input, errors) res = punycode_decode(input, errors)
......
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment