Commit d8a51700 authored by Tom Niget's avatar Tom Niget

Update stuff

parent 22a67b3a
...@@ -221,7 +221,7 @@ class ExpressionVisitor(NodeVisitor): ...@@ -221,7 +221,7 @@ class ExpressionVisitor(NodeVisitor):
use_dot = None use_dot = None
if type(node.value.type) == TypeType: if type(node.value.type) == TypeType:
use_dot = "dots" use_dot = "dots"
elif isinstance(node.type, FunctionType) and not isinstance(node.value.type, Promise): elif isinstance(node.type, FunctionType) and node.type.is_method and not isinstance(node.value.type, Promise):
if node.value.type.resolve().is_reference: if node.value.type.resolve().is_reference:
use_dot = "dotp" use_dot = "dotp"
else: else:
......
...@@ -284,4 +284,22 @@ class OutsideLoopError(CompileError): ...@@ -284,4 +284,22 @@ class OutsideLoopError(CompileError):
return f"{highlight('break')} and {highlight('continue')} can only be used inside a loop" return f"{highlight('break')} and {highlight('continue')} can only be used inside a loop"
def detail(self, last_node: ast.AST = None) -> str: def detail(self, last_node: ast.AST = None) -> str:
return "" return ""
\ No newline at end of file
@dataclass
class MissingReturnError(CompileError):
node: ast.FunctionDef
def __str__(self) -> str:
return f"Missing return: not all code paths in {highlight(self.node)} return"
def detail(self, last_node: ast.AST = None) -> str:
return f"""
This indicates that a function is missing a {highlight('return')} statement in one or more of its code paths.
For example:
{highlight('def f(x: int):')}
{highlight(' if x > 0:')}
{highlight(' return 1')}
{highlight(' # if x <= 0, the function returns nothing')}
"""
\ No newline at end of file
...@@ -67,7 +67,7 @@ class ScoperExprVisitor(ScoperVisitor): ...@@ -67,7 +67,7 @@ class ScoperExprVisitor(ScoperVisitor):
assert ftype.kind == PromiseKind.TASK assert ftype.kind == PromiseKind.TASK
ftype.kind = PromiseKind.GENERATOR ftype.kind = PromiseKind.GENERATOR
ftype.return_type.unify(ytype) ftype.return_type.unify(ytype)
self.scope.function.has_return = True self.scope.function.has_yield = True
return TY_NONE return TY_NONE
......
...@@ -54,7 +54,7 @@ class Scope: ...@@ -54,7 +54,7 @@ class Scope:
vars: Dict[str, VarDecl] = field(default_factory=dict) vars: Dict[str, VarDecl] = field(default_factory=dict)
children: List["Scope"] = field(default_factory=list) children: List["Scope"] = field(default_factory=list)
obj_type: Optional[BaseType] = None obj_type: Optional[BaseType] = None
has_return: bool = False diverges: bool = False
class_: Optional["Scope"] = None class_: Optional["Scope"] = None
is_loop: Optional[ast.For | ast.While] = None is_loop: Optional[ast.For | ast.While] = None
......
...@@ -105,6 +105,7 @@ class StdlibVisitor(NodeVisitorSeq): ...@@ -105,6 +105,7 @@ class StdlibVisitor(NodeVisitorSeq):
ty.variadic = True ty.variadic = True
ty.optional_at = 1 + len(node.args.args) - len(node.args.defaults) ty.optional_at = 1 + len(node.args.args) - len(node.args.defaults)
if self.cur_class: if self.cur_class:
ty.is_method = True
assert isinstance(self.cur_class, TypeType) assert isinstance(self.cur_class, TypeType)
if isinstance(self.cur_class.type_object, ABCMeta): if isinstance(self.cur_class.type_object, ABCMeta):
self.cur_class.type_object.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars) self.cur_class.type_object.gen_methods[node.name] = lambda t: ty.gen_sub(t, self.typevars)
......
...@@ -289,12 +289,12 @@ class TypeOperator(BaseType, ABC): ...@@ -289,12 +289,12 @@ class TypeOperator(BaseType, ABC):
vardict = dict(zip(typevars.keys(), this.args)) vardict = dict(zip(typevars.keys(), this.args))
else: else:
vardict = typevars vardict = typevars
for k in dataclasses.fields(self): for k, v in self.__dict__.items():
setattr(res, k.name, getattr(self, k.name)) setattr(res, k, v)
res.args = [arg.resolve().gen_sub(this, vardict, cache) for arg in self.args] res.args = [arg.resolve().gen_sub(this, vardict, cache) for arg in self.args]
res.methods = {k: v.gen_sub(this, vardict, cache) for k, v in self.methods.items()} res.methods = {k: v.gen_sub(this, vardict, cache) for k, v in self.methods.items()}
res.parents = [p.gen_sub(this, vardict, cache) for p in self.parents] res.parents = [p.gen_sub(this, vardict, cache) for p in self.parents]
res.is_protocol = self.is_protocol #res.is_protocol = self.is_protocol
return res return res
def to_list(self) -> List["BaseType"]: def to_list(self) -> List["BaseType"]:
...@@ -308,6 +308,7 @@ class ModuleType(TypeOperator): ...@@ -308,6 +308,7 @@ class ModuleType(TypeOperator):
class FunctionType(TypeOperator): class FunctionType(TypeOperator):
is_python_func: bool = False is_python_func: bool = False
python_func_used: bool = False python_func_used: bool = False
is_method: bool = False
def __iter__(self): def __iter__(self):
x = 5 x = 5
...@@ -331,16 +332,19 @@ class FunctionType(TypeOperator): ...@@ -331,16 +332,19 @@ class FunctionType(TypeOperator):
def __str__(self): def __str__(self):
ret, *args = map(str, self.args) ret, *args = map(str, self.args)
if self.optional_at is not None:
args = args[:self.optional_at] + [f"{x}=..." for x in args[self.optional_at:]]
if self.variadic: if self.variadic:
args.append(f"*args") args.append("*args")
if args: if args:
args = f"({', '.join(args)})" args = f"{', '.join(args)}"
else: else:
args = "()" args = ""
return f"{args} -> {ret}" return f"({args}) -> {ret}"
def remove_self(self): def remove_self(self):
res = FunctionType(self.parameters[1:], self.return_type) res = FunctionType(self.parameters[1:], self.return_type)
res.is_method = self.is_method
res.variadic = self.variadic res.variadic = self.variadic
res.optional_at = self.optional_at - 1 if self.optional_at is not None else None res.optional_at = self.optional_at - 1 if self.optional_at is not None else None
return res return res
...@@ -460,8 +464,12 @@ class Promise(TypeOperator, ABC): ...@@ -460,8 +464,12 @@ class Promise(TypeOperator, ABC):
@kind.setter @kind.setter
def kind(self, value: PromiseKind): def kind(self, value: PromiseKind):
if value == PromiseKind.GENERATOR: if value == PromiseKind.GENERATOR:
self.methods["__iter__"] = FunctionType([], self) f_iter = FunctionType([], self)
self.methods["__next__"] = FunctionType([], self.return_type) f_iter.is_method = True
self.methods["__iter__"] = f_iter
f_next = FunctionType([], self.return_type)
f_next.is_method = True
self.methods["__next__"] = f_next
self.args[1].val = value self.args[1].val = value
def __str__(self): def __str__(self):
...@@ -506,4 +514,10 @@ class UserType(TypeOperator): ...@@ -506,4 +514,10 @@ class UserType(TypeOperator):
class UnionType(TypeOperator): class UnionType(TypeOperator):
def __init__(self, *args: List[BaseType]): def __init__(self, *args: List[BaseType]):
super().__init__(args, "Union") super().__init__(args, "Union")
self.parents.extend(args) self.parents.extend(set(args))
def is_optional(self):
if len(self.args) == 2 and TY_NONE in self.args:
return (set(self.args) - {TY_NONE}).pop()
return False
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment