Commit 4e9bdfef authored by YOU's avatar YOU Committed by Dylan Trotter

Implement classmethod

parent f2582500
...@@ -86,6 +86,7 @@ var builtinTypes = map[*Type]*builtinTypeInfo{ ...@@ -86,6 +86,7 @@ var builtinTypes = map[*Type]*builtinTypeInfo{
BoolType: {init: initBoolType, global: true}, BoolType: {init: initBoolType, global: true},
BytesWarningType: {global: true}, BytesWarningType: {global: true},
CodeType: {}, CodeType: {},
ClassMethodType: {init: initClassMethodType, global: true},
DeprecationWarningType: {global: true}, DeprecationWarningType: {global: true},
dictItemIteratorType: {init: initDictItemIteratorType}, dictItemIteratorType: {init: initDictItemIteratorType},
dictKeyIteratorType: {init: initDictKeyIteratorType}, dictKeyIteratorType: {init: initDictKeyIteratorType},
......
...@@ -25,6 +25,9 @@ var ( ...@@ -25,6 +25,9 @@ var (
// StaticMethodType is the object representing the Python // StaticMethodType is the object representing the Python
// 'staticmethod' type. // 'staticmethod' type.
StaticMethodType = newBasisType("staticmethod", reflect.TypeOf(staticMethod{}), toStaticMethodUnsafe, ObjectType) StaticMethodType = newBasisType("staticmethod", reflect.TypeOf(staticMethod{}), toStaticMethodUnsafe, ObjectType)
// ClassMethodType is the object representing the Python
// 'classmethod' type.
ClassMethodType = newBasisType("classmethod", reflect.TypeOf(classMethod{}), toClassMethodUnsafe, ObjectType)
) )
// Args represent positional parameters in a call to a Python function. // Args represent positional parameters in a call to a Python function.
...@@ -182,3 +185,43 @@ func initStaticMethodType(map[string]*Object) { ...@@ -182,3 +185,43 @@ func initStaticMethodType(map[string]*Object) {
StaticMethodType.slots.Get = &getSlot{staticMethodGet} StaticMethodType.slots.Get = &getSlot{staticMethodGet}
StaticMethodType.slots.Init = &initSlot{staticMethodInit} StaticMethodType.slots.Init = &initSlot{staticMethodInit}
} }
// classMethod represents Python 'classmethod' objects.
type classMethod struct {
Object
callable *Object
}
func newClassMethod(callable *Object) *classMethod {
return &classMethod{Object{typ: ClassMethodType}, callable}
}
func toClassMethodUnsafe(o *Object) *classMethod {
return (*classMethod)(o.toPointer())
}
// ToObject upcasts f to an Object.
func (m *classMethod) ToObject() *Object {
return &m.Object
}
func classMethodGet(f *Frame, desc, _ *Object, owner *Type) (*Object, *BaseException) {
m := toClassMethodUnsafe(desc)
if m.callable == nil {
return nil, f.RaiseType(RuntimeErrorType, "uninitialized classmethod object")
}
return NewMethod(toFunctionUnsafe(m.callable), owner.ToObject(), owner).ToObject(), nil
}
func classMethodInit(f *Frame, o *Object, args Args, _ KWArgs) (*Object, *BaseException) {
if raised := checkFunctionArgs(f, "__init__", args, ObjectType); raised != nil {
return nil, raised
}
toClassMethodUnsafe(o).callable = args[0]
return None, nil
}
func initClassMethodType(map[string]*Object) {
ClassMethodType.slots.Get = &getSlot{classMethodGet}
ClassMethodType.slots.Init = &initSlot{classMethodInit}
}
...@@ -124,3 +124,38 @@ func TestStaticMethodInit(t *testing.T) { ...@@ -124,3 +124,38 @@ func TestStaticMethodInit(t *testing.T) {
} }
} }
} }
func TestClassMethodGet(t *testing.T) {
cases := []invokeTestCase{
// {args: wrapArgs(newClassMethod(NewStr("abc").ToObject()), 123, IntType), want: NewStr("abc").ToObject()},
{args: wrapArgs(newClassMethod(nil), 123, IntType), wantExc: mustCreateException(RuntimeErrorType, "uninitialized classmethod object")},
}
for _, cas := range cases {
if err := runInvokeMethodTestCase(ClassMethodType, "__get__", &cas); err != "" {
t.Error(err)
}
}
}
func TestClassMethodInit(t *testing.T) {
fun := wrapFuncForTest(func(f *Frame, args ...*Object) (*Object, *BaseException) {
m, raised := ClassMethodType.Call(f, args, nil)
if raised != nil {
return nil, raised
}
get, raised := GetAttr(f, m, NewStr("__get__"), nil)
if raised != nil {
return nil, raised
}
return get.Call(f, wrapArgs(123, IntType), nil)
})
cases := []invokeTestCase{
// {args: wrapArgs(3.14), want: NewFloat(3.14).ToObject()},
{wantExc: mustCreateException(TypeErrorType, "'__init__' requires 1 arguments")},
}
for _, cas := range cases {
if err := runInvokeTestCase(fun, &cas); err != "" {
t.Error(err)
}
}
}
...@@ -102,12 +102,13 @@ class UserDict(object): ...@@ -102,12 +102,13 @@ class UserDict(object):
return self.data.popitem() return self.data.popitem()
def __contains__(self, key): def __contains__(self, key):
return key in self.data return key in self.data
@classmethod
def fromkeys(cls, iterable, value=None): def fromkeys(cls, iterable, value=None):
d = cls() d = cls()
for key in iterable: for key in iterable:
d[key] = value d[key] = value
return d return d
# TODO: Make this a decorator once they're implemented.
fromkeys = classmethod(fromkeys)
class IterableUserDict(UserDict): class IterableUserDict(UserDict):
def __iter__(self): def __iter__(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment