Commit 2d8cb0bf authored by Jakub Kicinski's avatar Jakub Kicinski

Merge branch 'tools-ynl-fix-enum-as-flags-in-the-generic-cli'

Jakub Kicinski says:

====================
tools: ynl: fix enum-as-flags in the generic CLI

The CLI needs to use proper classes when looking at Enum definitions
rather than interpreting the YAML spec ad-hoc, because we have more
than on format of the definition supported.
====================

Link: https://lore.kernel.org/r/20230308003923.445268-1-kuba@kernel.orgSigned-off-by: default avatarJakub Kicinski <kuba@kernel.org>
parents 649c15c7 c311aaa7
# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause # SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
from .nlspec import SpecAttr, SpecAttrSet, SpecFamily, SpecOperation from .nlspec import SpecAttr, SpecAttrSet, SpecEnumEntry, SpecEnumSet, \
SpecFamily, SpecOperation
from .ynl import YnlFamily from .ynl import YnlFamily
__all__ = ["SpecAttr", "SpecAttrSet", "SpecFamily", "SpecOperation", __all__ = ["SpecAttr", "SpecAttrSet", "SpecEnumEntry", "SpecEnumSet",
"YnlFamily"] "SpecFamily", "SpecOperation", "YnlFamily"]
...@@ -57,6 +57,94 @@ class SpecElement: ...@@ -57,6 +57,94 @@ class SpecElement:
pass pass
class SpecEnumEntry(SpecElement):
""" Entry within an enum declared in the Netlink spec.
Attributes:
doc documentation string
enum_set back reference to the enum
value numerical value of this enum (use accessors in most situations!)
Methods:
raw_value raw value, i.e. the id in the enum, unlike user value which is a mask for flags
user_value user value, same as raw value for enums, for flags it's the mask
"""
def __init__(self, enum_set, yaml, prev, value_start):
if isinstance(yaml, str):
yaml = {'name': yaml}
super().__init__(enum_set.family, yaml)
self.doc = yaml.get('doc', '')
self.enum_set = enum_set
if 'value' in yaml:
self.value = yaml['value']
elif prev:
self.value = prev.value + 1
else:
self.value = value_start
def has_doc(self):
return bool(self.doc)
def raw_value(self):
return self.value
def user_value(self):
if self.enum_set['type'] == 'flags':
return 1 << self.value
else:
return self.value
class SpecEnumSet(SpecElement):
""" Enum type
Represents an enumeration (list of numerical constants)
as declared in the "definitions" section of the spec.
Attributes:
type enum or flags
entries entries by name
entries_by_val entries by value
Methods:
get_mask for flags compute the mask of all defined values
"""
def __init__(self, family, yaml):
super().__init__(family, yaml)
self.type = yaml['type']
prev_entry = None
value_start = self.yaml.get('value-start', 0)
self.entries = dict()
self.entries_by_val = dict()
for entry in self.yaml['entries']:
e = self.new_entry(entry, prev_entry, value_start)
self.entries[e.name] = e
self.entries_by_val[e.raw_value()] = e
prev_entry = e
def new_entry(self, entry, prev_entry, value_start):
return SpecEnumEntry(self, entry, prev_entry, value_start)
def has_doc(self):
if 'doc' in self.yaml:
return True
for entry in self.entries.values():
if entry.has_doc():
return True
return False
def get_mask(self):
mask = 0
idx = self.yaml.get('value-start', 0)
for _ in self.entries.values():
mask |= 1 << idx
idx += 1
return mask
class SpecAttr(SpecElement): class SpecAttr(SpecElement):
""" Single Netlink atttribute type """ Single Netlink atttribute type
...@@ -193,6 +281,7 @@ class SpecFamily(SpecElement): ...@@ -193,6 +281,7 @@ class SpecFamily(SpecElement):
msgs dict of all messages (index by name) msgs dict of all messages (index by name)
msgs_by_value dict of all messages (indexed by name) msgs_by_value dict of all messages (indexed by name)
ops dict of all valid requests / responses ops dict of all valid requests / responses
consts dict of all constants/enums
""" """
def __init__(self, spec_path, schema_path=None): def __init__(self, spec_path, schema_path=None):
with open(spec_path, "r") as stream: with open(spec_path, "r") as stream:
...@@ -222,6 +311,7 @@ class SpecFamily(SpecElement): ...@@ -222,6 +311,7 @@ class SpecFamily(SpecElement):
self.req_by_value = collections.OrderedDict() self.req_by_value = collections.OrderedDict()
self.rsp_by_value = collections.OrderedDict() self.rsp_by_value = collections.OrderedDict()
self.ops = collections.OrderedDict() self.ops = collections.OrderedDict()
self.consts = collections.OrderedDict()
last_exception = None last_exception = None
while len(self._resolution_list) > 0: while len(self._resolution_list) > 0:
...@@ -242,6 +332,9 @@ class SpecFamily(SpecElement): ...@@ -242,6 +332,9 @@ class SpecFamily(SpecElement):
if len(resolved) == 0: if len(resolved) == 0:
raise last_exception raise last_exception
def new_enum(self, elem):
return SpecEnumSet(self, elem)
def new_attr_set(self, elem): def new_attr_set(self, elem):
return SpecAttrSet(self, elem) return SpecAttrSet(self, elem)
...@@ -296,6 +389,12 @@ class SpecFamily(SpecElement): ...@@ -296,6 +389,12 @@ class SpecFamily(SpecElement):
def resolve(self): def resolve(self):
self.resolve_up(super()) self.resolve_up(super())
for elem in self.yaml['definitions']:
if elem['type'] == 'enum' or elem['type'] == 'flags':
self.consts[elem['name']] = self.new_enum(elem)
else:
self.consts[elem['name']] = elem
for elem in self.yaml['attribute-sets']: for elem in self.yaml['attribute-sets']:
attr_set = self.new_attr_set(elem) attr_set = self.new_attr_set(elem)
self.attr_sets[elem['name']] = attr_set self.attr_sets[elem['name']] = attr_set
......
...@@ -303,11 +303,6 @@ class YnlFamily(SpecFamily): ...@@ -303,11 +303,6 @@ class YnlFamily(SpecFamily):
self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1) self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)
self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1) self.sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_EXT_ACK, 1)
self._types = dict()
for elem in self.yaml.get('definitions', []):
self._types[elem['name']] = elem
self.async_msg_ids = set() self.async_msg_ids = set()
self.async_msg_queue = [] self.async_msg_queue = []
...@@ -353,13 +348,13 @@ class YnlFamily(SpecFamily): ...@@ -353,13 +348,13 @@ class YnlFamily(SpecFamily):
def _decode_enum(self, rsp, attr_spec): def _decode_enum(self, rsp, attr_spec):
raw = rsp[attr_spec['name']] raw = rsp[attr_spec['name']]
enum = self._types[attr_spec['enum']] enum = self.consts[attr_spec['enum']]
i = attr_spec.get('value-start', 0) i = attr_spec.get('value-start', 0)
if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']: if 'enum-as-flags' in attr_spec and attr_spec['enum-as-flags']:
value = set() value = set()
while raw: while raw:
if raw & 1: if raw & 1:
value.add(enum['entries'][i]) value.add(enum.entries_by_val[i].name)
raw >>= 1 raw >>= 1
i += 1 i += 1
else: else:
......
...@@ -6,7 +6,7 @@ import collections ...@@ -6,7 +6,7 @@ import collections
import os import os
import yaml import yaml
from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation from lib import SpecFamily, SpecAttrSet, SpecAttr, SpecOperation, SpecEnumSet, SpecEnumEntry
def c_upper(name): def c_upper(name):
...@@ -567,97 +567,37 @@ class Struct: ...@@ -567,97 +567,37 @@ class Struct:
self.inherited = [c_lower(x) for x in sorted(self._inherited)] self.inherited = [c_lower(x) for x in sorted(self._inherited)]
class EnumEntry: class EnumEntry(SpecEnumEntry):
def __init__(self, enum_set, yaml, prev, value_start): def __init__(self, enum_set, yaml, prev, value_start):
if isinstance(yaml, str): super().__init__(enum_set, yaml, prev, value_start)
self.name = yaml
yaml = {}
self.doc = ''
else:
self.name = yaml['name']
self.doc = yaml.get('doc', '')
self.yaml = yaml
self.enum_set = enum_set
self.c_name = c_upper(enum_set.value_pfx + self.name)
if 'value' in yaml:
self.value = yaml['value']
if prev: if prev:
self.value_change = (self.value != prev.value + 1) self.value_change = (self.value != prev.value + 1)
elif prev:
self.value_change = False
self.value = prev.value + 1
else: else:
self.value = value_start
self.value_change = (self.value != 0) self.value_change = (self.value != 0)
self.value_change = self.value_change or self.enum_set['type'] == 'flags' self.value_change = self.value_change or self.enum_set['type'] == 'flags'
def __getitem__(self, key): # Added by resolve:
return self.yaml[key] self.c_name = None
delattr(self, "c_name")
def __contains__(self, key):
return key in self.yaml
def has_doc(self):
return bool(self.doc)
# raw value, i.e. the id in the enum, unlike user value which is a mask for flags def resolve(self):
def raw_value(self): self.resolve_up(super())
return self.value
# user value, same as raw value for enums, for flags it's the mask self.c_name = c_upper(self.enum_set.value_pfx + self.name)
def user_value(self):
if self.enum_set['type'] == 'flags':
return 1 << self.value
else:
return self.value
class EnumSet: class EnumSet(SpecEnumSet):
def __init__(self, family, yaml): def __init__(self, family, yaml):
self.yaml = yaml
self.family = family
self.render_name = c_lower(family.name + '-' + yaml['name']) self.render_name = c_lower(family.name + '-' + yaml['name'])
self.enum_name = 'enum ' + self.render_name self.enum_name = 'enum ' + self.render_name
self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-") self.value_pfx = yaml.get('name-prefix', f"{family.name}-{yaml['name']}-")
self.type = yaml['type'] super().__init__(family, yaml)
prev_entry = None
value_start = self.yaml.get('value-start', 0)
self.entries = {}
self.entry_list = []
for entry in self.yaml['entries']:
e = EnumEntry(self, entry, prev_entry, value_start)
self.entries[e.name] = e
self.entry_list.append(e)
prev_entry = e
def __getitem__(self, key):
return self.yaml[key]
def __contains__(self, key):
return key in self.yaml
def has_doc(self):
if 'doc' in self.yaml:
return True
for entry in self.entry_list:
if entry.has_doc():
return True
return False
def get_mask(self): def new_entry(self, entry, prev_entry, value_start):
mask = 0 return EnumEntry(self, entry, prev_entry, value_start)
idx = self.yaml.get('value-start', 0)
for _ in self.entry_list:
mask |= 1 << idx
idx += 1
return mask
class AttrSet(SpecAttrSet): class AttrSet(SpecAttrSet):
...@@ -792,8 +732,6 @@ class Family(SpecFamily): ...@@ -792,8 +732,6 @@ class Family(SpecFamily):
self.mcgrps = self.yaml.get('mcast-groups', {'list': []}) self.mcgrps = self.yaml.get('mcast-groups', {'list': []})
self.consts = dict()
self.hooks = dict() self.hooks = dict()
for when in ['pre', 'post']: for when in ['pre', 'post']:
self.hooks[when] = dict() self.hooks[when] = dict()
...@@ -820,6 +758,9 @@ class Family(SpecFamily): ...@@ -820,6 +758,9 @@ class Family(SpecFamily):
if self.kernel_policy == 'global': if self.kernel_policy == 'global':
self._load_global_policy() self._load_global_policy()
def new_enum(self, elem):
return EnumSet(self, elem)
def new_attr_set(self, elem): def new_attr_set(self, elem):
return AttrSet(self, elem) return AttrSet(self, elem)
...@@ -837,12 +778,6 @@ class Family(SpecFamily): ...@@ -837,12 +778,6 @@ class Family(SpecFamily):
} }
def _dictify(self): def _dictify(self):
for elem in self.yaml['definitions']:
if elem['type'] == 'enum' or elem['type'] == 'flags':
self.consts[elem['name']] = EnumSet(self, elem)
else:
self.consts[elem['name']] = elem
ntf = [] ntf = []
for msg in self.msgs.values(): for msg in self.msgs.values():
if 'notify' in msg: if 'notify' in msg:
...@@ -1980,7 +1915,7 @@ def render_uapi(family, cw): ...@@ -1980,7 +1915,7 @@ def render_uapi(family, cw):
if 'doc' in enum: if 'doc' in enum:
doc = ' - ' + enum['doc'] doc = ' - ' + enum['doc']
cw.write_doc_line(enum.enum_name + doc) cw.write_doc_line(enum.enum_name + doc)
for entry in enum.entry_list: for entry in enum.entries.values():
if entry.has_doc(): if entry.has_doc():
doc = '@' + entry.c_name + ': ' + entry['doc'] doc = '@' + entry.c_name + ': ' + entry['doc']
cw.write_doc_line(doc) cw.write_doc_line(doc)
...@@ -1988,7 +1923,7 @@ def render_uapi(family, cw): ...@@ -1988,7 +1923,7 @@ def render_uapi(family, cw):
uapi_enum_start(family, cw, const, 'name') uapi_enum_start(family, cw, const, 'name')
name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-") name_pfx = const.get('name-prefix', f"{family.name}-{const['name']}-")
for entry in enum.entry_list: for entry in enum.entries.values():
suffix = ',' suffix = ','
if entry.value_change: if entry.value_change:
suffix = f" = {entry.user_value()}" + suffix suffix = f" = {entry.user_value()}" + suffix
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment