Commit 0ed33b7a authored by Steve Kowalik's avatar Steve Kowalik

Shift requirement parsing inside Requirement

parent a46fd832
...@@ -2693,15 +2693,11 @@ class DistInfoDistribution(Distribution): ...@@ -2693,15 +2693,11 @@ class DistInfoDistribution(Distribution):
reqs = [] reqs = []
# Including any condition expressions # Including any condition expressions
for req in self._parsed_pkg_info.get_all('Requires-Dist') or []: for req in self._parsed_pkg_info.get_all('Requires-Dist') or []:
current_req = packaging.requirements.Requirement(req) reqs.extend(parse_requirements(req))
specs = _parse_requirement_specs(current_req)
parsed = Requirement(current_req.name, specs, current_req.extras)
parsed._marker = current_req.marker
reqs.append(parsed)
def reqs_for_extra(extra): def reqs_for_extra(extra):
for req in reqs: for req in reqs:
if not req._marker or req._marker.evaluate({'extra': extra}): if not req.marker or req.marker.evaluate({'extra': extra}):
yield req yield req
common = frozenset(reqs_for_extra(None)) common = frozenset(reqs_for_extra(None))
...@@ -2739,10 +2735,6 @@ class RequirementParseError(ValueError): ...@@ -2739,10 +2735,6 @@ class RequirementParseError(ValueError):
return ' '.join(self.args) return ' '.join(self.args)
def _parse_requirement_specs(req):
return [(spec.operator, spec.version) for spec in req.specifier]
def parse_requirements(strs): def parse_requirements(strs):
"""Yield ``Requirement`` objects for each specification in `strs` """Yield ``Requirement`` objects for each specification in `strs`
...@@ -2759,33 +2751,35 @@ def parse_requirements(strs): ...@@ -2759,33 +2751,35 @@ def parse_requirements(strs):
if line.endswith('\\'): if line.endswith('\\'):
line = line[:-2].strip() line = line[:-2].strip()
line += next(lines) line += next(lines)
req = packaging.requirements.Requirement(line) yield Requirement(line)
specs = _parse_requirement_specs(req)
yield Requirement(req.name, specs, req.extras)
class Requirement: class Requirement:
def __init__(self, project_name, specs, extras): def __init__(self, requirement_string):
"""DO NOT CALL THIS UNDOCUMENTED METHOD; use Requirement.parse()!""" """DO NOT CALL THIS UNDOCUMENTED METHOD; use Requirement.parse()!"""
self.unsafe_name, project_name = project_name, safe_name(project_name) try:
self.req = packaging.requirements.Requirement(requirement_string)
except packaging.requirements.InvalidRequirement as e:
raise RequirementParseError(str(e))
self.unsafe_name = self.req.name
project_name = safe_name(self.req.name)
self.project_name, self.key = project_name, project_name.lower() self.project_name, self.key = project_name, project_name.lower()
self.specifier = packaging.specifiers.SpecifierSet( self.specifier = self.req.specifier
",".join(["".join([x, y]) for x, y in specs]) self.specs = [
) (spec.operator, spec.version) for spec in self.req.specifier]
self.specs = specs self.extras = tuple(map(safe_extra, self.req.extras))
self.extras = tuple(map(safe_extra, extras)) self.marker = self.req.marker
self.url = self.req.url
self.hashCmp = ( self.hashCmp = (
self.key, self.key,
self.specifier, self.specifier,
frozenset(self.extras), frozenset(self.extras),
str(self.marker)
) )
self.__hash = hash(self.hashCmp) self.__hash = hash(self.hashCmp)
def __str__(self): def __str__(self):
extras = ','.join(self.extras) return str(self.req)
if extras:
extras = '[%s]' % extras
return '%s%s%s' % (self.project_name, extras, self.specifier)
def __eq__(self, other): def __eq__(self, other):
return ( return (
......
...@@ -353,22 +353,22 @@ class TestRequirements: ...@@ -353,22 +353,22 @@ class TestRequirements:
r = Requirement.parse("Twisted>=1.2") r = Requirement.parse("Twisted>=1.2")
assert str(r) == "Twisted>=1.2" assert str(r) == "Twisted>=1.2"
assert repr(r) == "Requirement.parse('Twisted>=1.2')" assert repr(r) == "Requirement.parse('Twisted>=1.2')"
assert r == Requirement("Twisted", [('>=','1.2')], ()) assert r == Requirement("Twisted>=1.2")
assert r == Requirement("twisTed", [('>=','1.2')], ()) assert r == Requirement("twisTed>=1.2")
assert r != Requirement("Twisted", [('>=','2.0')], ()) assert r != Requirement("Twisted>=2.0")
assert r != Requirement("Zope", [('>=','1.2')], ()) assert r != Requirement("Zope>=1.2")
assert r != Requirement("Zope", [('>=','3.0')], ()) assert r != Requirement("Zope>=3.0")
assert r != Requirement.parse("Twisted[extras]>=1.2") assert r != Requirement("Twisted[extras]>=1.2")
def testOrdering(self): def testOrdering(self):
r1 = Requirement("Twisted", [('==','1.2c1'),('>=','1.2')], ()) r1 = Requirement("Twisted==1.2c1,>=1.2")
r2 = Requirement("Twisted", [('>=','1.2'),('==','1.2c1')], ()) r2 = Requirement("Twisted>=1.2,==1.2c1")
assert r1 == r2 assert r1 == r2
assert str(r1) == str(r2) assert str(r1) == str(r2)
assert str(r2) == "Twisted==1.2c1,>=1.2" assert str(r2) == "Twisted==1.2c1,>=1.2"
def testBasicContains(self): def testBasicContains(self):
r = Requirement("Twisted", [('>=','1.2')], ()) r = Requirement("Twisted>=1.2")
foo_dist = Distribution.from_filename("FooPkg-1.3_1.egg") foo_dist = Distribution.from_filename("FooPkg-1.3_1.egg")
twist11 = Distribution.from_filename("Twisted-1.1.egg") twist11 = Distribution.from_filename("Twisted-1.1.egg")
twist12 = Distribution.from_filename("Twisted-1.2.egg") twist12 = Distribution.from_filename("Twisted-1.2.egg")
...@@ -394,6 +394,7 @@ class TestRequirements: ...@@ -394,6 +394,7 @@ class TestRequirements:
"twisted", "twisted",
packaging.specifiers.SpecifierSet(">=1.2"), packaging.specifiers.SpecifierSet(">=1.2"),
frozenset(["foo","bar"]), frozenset(["foo","bar"]),
'None'
)) ))
) )
...@@ -485,17 +486,17 @@ class TestParsing: ...@@ -485,17 +486,17 @@ class TestParsing:
assert ( assert (
list(parse_requirements('Twis-Ted>=1.2-1')) list(parse_requirements('Twis-Ted>=1.2-1'))
== ==
[Requirement('Twis-Ted',[('>=','1.2-1')], ())] [Requirement('Twis-Ted>=1.2-1')]
) )
assert ( assert (
list(parse_requirements('Twisted >=1.2, \ # more\n<2.0')) list(parse_requirements('Twisted >=1.2, \ # more\n<2.0'))
== ==
[Requirement('Twisted',[('>=','1.2'),('<','2.0')], ())] [Requirement('Twisted>=1.2,<2.0')]
) )
assert ( assert (
Requirement.parse("FooBar==1.99a3") Requirement.parse("FooBar==1.99a3")
== ==
Requirement("FooBar", [('==','1.99a3')], ()) Requirement("FooBar==1.99a3")
) )
with pytest.raises(ValueError): with pytest.raises(ValueError):
Requirement.parse(">=2.3") Requirement.parse(">=2.3")
......
...@@ -710,10 +710,7 @@ class easy_install(Command): ...@@ -710,10 +710,7 @@ class easy_install(Command):
elif requirement is None or dist not in requirement: elif requirement is None or dist not in requirement:
# if we wound up with a different version, resolve what we've got # if we wound up with a different version, resolve what we've got
distreq = dist.as_requirement() distreq = dist.as_requirement()
requirement = requirement or distreq requirement = Requirement(str(distreq.req))
requirement = Requirement(
distreq.project_name, distreq.specs, requirement.extras
)
log.info("Processing dependencies for %s", requirement) log.info("Processing dependencies for %s", requirement)
try: try:
distros = WorkingSet([]).resolve( distros = WorkingSet([]).resolve(
......
...@@ -34,7 +34,9 @@ class TestDistInfo: ...@@ -34,7 +34,9 @@ class TestDistInfo:
for d in pkg_resources.find_distributions(self.tmpdir): for d in pkg_resources.find_distributions(self.tmpdir):
assert d.requires() == requires[:1] assert d.requires() == requires[:1]
assert d.requires(extras=('baz',)) == requires assert d.requires(extras=('baz',)) == [
requires[0],
pkg_resources.Requirement.parse('quux>=1.1;extra=="baz"')]
assert d.extras == ['baz'] assert d.extras == ['baz']
metadata_template = DALS(""" metadata_template = DALS("""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment