Commit 03ac7b78 authored by Kevin Modzelewski's avatar Kevin Modzelewski

list.sort(cmp)

But not both cmp and key at the same time.
We might want to just switch to CPython's sort implementation for that.
parent d5989608
...@@ -630,53 +630,79 @@ extern "C" int PyList_Reverse(PyObject* v) noexcept { ...@@ -630,53 +630,79 @@ extern "C" int PyList_Reverse(PyObject* v) noexcept {
return 0; return 0;
} }
class PyCmpComparer {
private:
Box* cmp;
public:
PyCmpComparer(Box* cmp) : cmp(cmp) {}
bool operator()(Box* lhs, Box* rhs) {
Box* r = runtimeCallInternal(cmp, NULL, ArgPassSpec(2), lhs, rhs, NULL, NULL, NULL);
if (!isSubclass(r->cls, int_cls))
raiseExcHelper(TypeError, "comparison function must return int, not %.200s", r->cls->tp_name);
return static_cast<BoxedInt*>(r)->n < 0;
}
};
void listSort(BoxedList* self, Box* cmp, Box* key, Box* reverse) { void listSort(BoxedList* self, Box* cmp, Box* key, Box* reverse) {
LOCK_REGION(self->lock.asWrite()); LOCK_REGION(self->lock.asWrite());
assert(isSubclass(self->cls, list_cls)); assert(isSubclass(self->cls, list_cls));
RELEASE_ASSERT(cmp == None, "The 'cmp' keyword is currently not supported"); if (cmp == None)
cmp = NULL;
if (key == None) if (key == None)
key = NULL; key = NULL;
int num_keys_added = 0; RELEASE_ASSERT(!cmp || !key, "Specifying both the 'cmp' and 'key' keywords is currently not supported");
auto remove_keys = [&]() {
for (int i = 0; i < num_keys_added; i++) {
Box** obj_loc = &self->elts->elts[i];
assert((*obj_loc)->cls == tuple_cls);
*obj_loc = static_cast<BoxedTuple*>(*obj_loc)->elts[2];
}
};
try {
if (key) {
for (int i = 0; i < self->size; i++) {
Box** obj_loc = &self->elts->elts[i];
Box* key_val = runtimeCall(key, ArgPassSpec(1), *obj_loc, NULL, NULL, NULL, NULL); // TODO(kmod): maybe we should just switch to CPython's sort. not sure how the algorithms compare,
// Add the index as part of the new tuple so that the comparison never hits the // but they specifically try to support cases where __lt__ or the cmp function might end up inspecting
// original object. // the current list being sorted.
// TODO we could potentially make this faster by copying the CPython approach of // I also don't know if std::stable_sort is exception-safe.
// creating special sortwrapper objects that compare only based on the key.
Box* new_obj = BoxedTuple::create({ key_val, boxInt(i), *obj_loc });
*obj_loc = new_obj; if (cmp) {
num_keys_added++; std::stable_sort<Box**, PyCmpComparer>(self->elts->elts, self->elts->elts + self->size, PyCmpComparer(cmp));
} else {
int num_keys_added = 0;
auto remove_keys = [&]() {
for (int i = 0; i < num_keys_added; i++) {
Box** obj_loc = &self->elts->elts[i];
assert((*obj_loc)->cls == tuple_cls);
*obj_loc = static_cast<BoxedTuple*>(*obj_loc)->elts[2];
}
};
try {
if (key) {
for (int i = 0; i < self->size; i++) {
Box** obj_loc = &self->elts->elts[i];
Box* key_val = runtimeCall(key, ArgPassSpec(1), *obj_loc, NULL, NULL, NULL, NULL);
// Add the index as part of the new tuple so that the comparison never hits the
// original object.
// TODO we could potentially make this faster by copying the CPython approach of
// creating special sortwrapper objects that compare only based on the key.
Box* new_obj = BoxedTuple::create({ key_val, boxInt(i), *obj_loc });
*obj_loc = new_obj;
num_keys_added++;
}
} }
// We don't need to do a stable sort if there's a keyfunc, since we explicitly added the index
// as part of the sort key.
// But we might want to get rid of that approach? CPython doesn't do that (they create special
// wrapper objects that compare only based on the key).
std::stable_sort<Box**, PyLt>(self->elts->elts, self->elts->elts + self->size, PyLt());
} catch (ExcInfo e) {
remove_keys();
raiseRaw(e);
} }
// We don't need to do a stable sort if there's a keyfunc, since we explicitly added the index
// as part of the sort key.
// But we might want to get rid of that approach? CPython doesn't do that (they create special
// wrapper objects that compare only based on the key).
std::stable_sort<Box**, PyLt>(self->elts->elts, self->elts->elts + self->size, PyLt());
} catch (ExcInfo e) {
remove_keys(); remove_keys();
throw e;
} }
remove_keys();
if (nonzero(reverse)) { if (nonzero(reverse)) {
listReverse(self); listReverse(self);
} }
......
...@@ -99,8 +99,6 @@ void REWRITE_ABORTED(const char* reason) { ...@@ -99,8 +99,6 @@ void REWRITE_ABORTED(const char* reason) {
#define REWRITE_ABORTED(reason) ((void)(reason)) #define REWRITE_ABORTED(reason) ((void)(reason))
#endif #endif
Box* runtimeCallInternal(Box* obj, CallRewriteArgs* rewrite_args, ArgPassSpec argspec, Box* arg1, Box* arg2, Box* arg3,
Box** args, const std::vector<const std::string*>* keyword_names);
static Box* (*runtimeCallInternal0)(Box*, CallRewriteArgs*, ArgPassSpec) static Box* (*runtimeCallInternal0)(Box*, CallRewriteArgs*, ArgPassSpec)
= (Box * (*)(Box*, CallRewriteArgs*, ArgPassSpec))runtimeCallInternal; = (Box * (*)(Box*, CallRewriteArgs*, ArgPassSpec))runtimeCallInternal;
static Box* (*runtimeCallInternal1)(Box*, CallRewriteArgs*, ArgPassSpec, Box*) static Box* (*runtimeCallInternal1)(Box*, CallRewriteArgs*, ArgPassSpec, Box*)
......
...@@ -103,6 +103,9 @@ struct BinopRewriteArgs; ...@@ -103,6 +103,9 @@ struct BinopRewriteArgs;
extern "C" Box* binopInternal(Box* lhs, Box* rhs, int op_type, bool inplace, BinopRewriteArgs* rewrite_args); extern "C" Box* binopInternal(Box* lhs, Box* rhs, int op_type, bool inplace, BinopRewriteArgs* rewrite_args);
struct CallRewriteArgs; struct CallRewriteArgs;
Box* runtimeCallInternal(Box* obj, CallRewriteArgs* rewrite_args, ArgPassSpec argspec, Box* arg1, Box* arg2, Box* arg3,
Box** args, const std::vector<const std::string*>* keyword_names);
Box* lenCallInternal(BoxedFunctionBase* f, CallRewriteArgs* rewrite_args, ArgPassSpec argspec, Box* arg1, Box* arg2, Box* lenCallInternal(BoxedFunctionBase* f, CallRewriteArgs* rewrite_args, ArgPassSpec argspec, Box* arg1, Box* arg2,
Box* arg3, Box** args, const std::vector<const std::string*>* keyword_names); Box* arg3, Box** args, const std::vector<const std::string*>* keyword_names);
Box* typeCallInternal(BoxedFunctionBase* f, CallRewriteArgs* rewrite_args, ArgPassSpec argspec, Box* arg1, Box* arg2, Box* typeCallInternal(BoxedFunctionBase* f, CallRewriteArgs* rewrite_args, ArgPassSpec argspec, Box* arg1, Box* arg2,
......
...@@ -2232,7 +2232,7 @@ Box* strEncode(BoxedString* self, Box* encoding, Box* error) { ...@@ -2232,7 +2232,7 @@ Box* strEncode(BoxedString* self, Box* encoding, Box* error) {
raiseExcHelper(TypeError, "encode() argument 2 must be string, not '%s'", getTypeName(error_str)); raiseExcHelper(TypeError, "encode() argument 2 must be string, not '%s'", getTypeName(error_str));
Box* result = PyString_AsEncodedObject(self, encoding_str ? encoding_str->data() : PyUnicode_GetDefaultEncoding(), Box* result = PyString_AsEncodedObject(self, encoding_str ? encoding_str->data() : PyUnicode_GetDefaultEncoding(),
error_str ? error_str->data() : NULL); error_str ? error_str->data() : NULL);
checkAndThrowCAPIException(); checkAndThrowCAPIException();
return result; return result;
} }
......
...@@ -171,3 +171,26 @@ for i in xrange(3): ...@@ -171,3 +171,26 @@ for i in xrange(3):
l1 = [i] l1 = [i]
l2 = [j, k] l2 = [j, k]
print l1 < l2, l1 <= l2, l1 > l2, l1 >= l2 print l1 < l2, l1 <= l2, l1 > l2, l1 >= l2
def mycmp(k1, k2):
types_seen.add((type(k1), type(k2)))
if k1 == k2:
return 0
if k1 < k2:
return -1
return 1
types_seen = set()
l = ["%d" for i in xrange(20)]
l.sort(cmp=mycmp)
print types_seen
print l
"""
types_seen = set()
l = range(20)
l.sort(cmp=mycmp, key=str)
print types_seen
print l
"""
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment