Commit 03ac7b78 authored by Kevin Modzelewski's avatar Kevin Modzelewski

list.sort(cmp)

But not both cmp and key at the same time.
We might want to just switch to CPython's sort implementation for that.
parent d5989608
...@@ -630,15 +630,40 @@ extern "C" int PyList_Reverse(PyObject* v) noexcept { ...@@ -630,15 +630,40 @@ extern "C" int PyList_Reverse(PyObject* v) noexcept {
return 0; return 0;
} }
class PyCmpComparer {
private:
Box* cmp;
public:
PyCmpComparer(Box* cmp) : cmp(cmp) {}
bool operator()(Box* lhs, Box* rhs) {
Box* r = runtimeCallInternal(cmp, NULL, ArgPassSpec(2), lhs, rhs, NULL, NULL, NULL);
if (!isSubclass(r->cls, int_cls))
raiseExcHelper(TypeError, "comparison function must return int, not %.200s", r->cls->tp_name);
return static_cast<BoxedInt*>(r)->n < 0;
}
};
void listSort(BoxedList* self, Box* cmp, Box* key, Box* reverse) { void listSort(BoxedList* self, Box* cmp, Box* key, Box* reverse) {
LOCK_REGION(self->lock.asWrite()); LOCK_REGION(self->lock.asWrite());
assert(isSubclass(self->cls, list_cls)); assert(isSubclass(self->cls, list_cls));
RELEASE_ASSERT(cmp == None, "The 'cmp' keyword is currently not supported"); if (cmp == None)
cmp = NULL;
if (key == None) if (key == None)
key = NULL; key = NULL;
RELEASE_ASSERT(!cmp || !key, "Specifying both the 'cmp' and 'key' keywords is currently not supported");
// TODO(kmod): maybe we should just switch to CPython's sort. not sure how the algorithms compare,
// but they specifically try to support cases where __lt__ or the cmp function might end up inspecting
// the current list being sorted.
// I also don't know if std::stable_sort is exception-safe.
if (cmp) {
std::stable_sort<Box**, PyCmpComparer>(self->elts->elts, self->elts->elts + self->size, PyCmpComparer(cmp));
} else {
int num_keys_added = 0; int num_keys_added = 0;
auto remove_keys = [&]() { auto remove_keys = [&]() {
for (int i = 0; i < num_keys_added; i++) { for (int i = 0; i < num_keys_added; i++) {
...@@ -672,10 +697,11 @@ void listSort(BoxedList* self, Box* cmp, Box* key, Box* reverse) { ...@@ -672,10 +697,11 @@ void listSort(BoxedList* self, Box* cmp, Box* key, Box* reverse) {
std::stable_sort<Box**, PyLt>(self->elts->elts, self->elts->elts + self->size, PyLt()); std::stable_sort<Box**, PyLt>(self->elts->elts, self->elts->elts + self->size, PyLt());
} catch (ExcInfo e) { } catch (ExcInfo e) {
remove_keys(); remove_keys();
throw e; raiseRaw(e);
} }
remove_keys(); remove_keys();
}
if (nonzero(reverse)) { if (nonzero(reverse)) {
listReverse(self); listReverse(self);
......
...@@ -99,8 +99,6 @@ void REWRITE_ABORTED(const char* reason) { ...@@ -99,8 +99,6 @@ void REWRITE_ABORTED(const char* reason) {
#define REWRITE_ABORTED(reason) ((void)(reason)) #define REWRITE_ABORTED(reason) ((void)(reason))
#endif #endif
Box* runtimeCallInternal(Box* obj, CallRewriteArgs* rewrite_args, ArgPassSpec argspec, Box* arg1, Box* arg2, Box* arg3,
Box** args, const std::vector<const std::string*>* keyword_names);
static Box* (*runtimeCallInternal0)(Box*, CallRewriteArgs*, ArgPassSpec) static Box* (*runtimeCallInternal0)(Box*, CallRewriteArgs*, ArgPassSpec)
= (Box * (*)(Box*, CallRewriteArgs*, ArgPassSpec))runtimeCallInternal; = (Box * (*)(Box*, CallRewriteArgs*, ArgPassSpec))runtimeCallInternal;
static Box* (*runtimeCallInternal1)(Box*, CallRewriteArgs*, ArgPassSpec, Box*) static Box* (*runtimeCallInternal1)(Box*, CallRewriteArgs*, ArgPassSpec, Box*)
......
...@@ -103,6 +103,9 @@ struct BinopRewriteArgs; ...@@ -103,6 +103,9 @@ struct BinopRewriteArgs;
extern "C" Box* binopInternal(Box* lhs, Box* rhs, int op_type, bool inplace, BinopRewriteArgs* rewrite_args); extern "C" Box* binopInternal(Box* lhs, Box* rhs, int op_type, bool inplace, BinopRewriteArgs* rewrite_args);
struct CallRewriteArgs; struct CallRewriteArgs;
Box* runtimeCallInternal(Box* obj, CallRewriteArgs* rewrite_args, ArgPassSpec argspec, Box* arg1, Box* arg2, Box* arg3,
Box** args, const std::vector<const std::string*>* keyword_names);
Box* lenCallInternal(BoxedFunctionBase* f, CallRewriteArgs* rewrite_args, ArgPassSpec argspec, Box* arg1, Box* arg2, Box* lenCallInternal(BoxedFunctionBase* f, CallRewriteArgs* rewrite_args, ArgPassSpec argspec, Box* arg1, Box* arg2,
Box* arg3, Box** args, const std::vector<const std::string*>* keyword_names); Box* arg3, Box** args, const std::vector<const std::string*>* keyword_names);
Box* typeCallInternal(BoxedFunctionBase* f, CallRewriteArgs* rewrite_args, ArgPassSpec argspec, Box* arg1, Box* arg2, Box* typeCallInternal(BoxedFunctionBase* f, CallRewriteArgs* rewrite_args, ArgPassSpec argspec, Box* arg1, Box* arg2,
......
...@@ -171,3 +171,26 @@ for i in xrange(3): ...@@ -171,3 +171,26 @@ for i in xrange(3):
l1 = [i] l1 = [i]
l2 = [j, k] l2 = [j, k]
print l1 < l2, l1 <= l2, l1 > l2, l1 >= l2 print l1 < l2, l1 <= l2, l1 > l2, l1 >= l2
def mycmp(k1, k2):
types_seen.add((type(k1), type(k2)))
if k1 == k2:
return 0
if k1 < k2:
return -1
return 1
types_seen = set()
l = ["%d" for i in xrange(20)]
l.sort(cmp=mycmp)
print types_seen
print l
"""
types_seen = set()
l = range(20)
l.sort(cmp=mycmp, key=str)
print types_seen
print l
"""
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment