Commit ebf3e7ee authored by Marius Wachtler's avatar Marius Wachtler

Implement vararg map()

parent f58caf8c
...@@ -450,11 +450,9 @@ Value ASTInterpreter::visit_jump(AST_Jump* node) { ...@@ -450,11 +450,9 @@ Value ASTInterpreter::visit_jump(AST_Jump* node) {
} }
CompiledFunction* partial_func = compilePartialFuncInternal(&exit); CompiledFunction* partial_func = compilePartialFuncInternal(&exit);
Box* arg1 = arg_array.size() >= 1 ? arg_array[0] : 0; auto arg_tuple = getTupleFromArgsArray(&arg_array[0], arg_array.size());
Box* arg2 = arg_array.size() >= 2 ? arg_array[1] : 0; return partial_func->call(std::get<0>(arg_tuple), std::get<1>(arg_tuple), std::get<2>(arg_tuple),
Box* arg3 = arg_array.size() >= 3 ? arg_array[2] : 0; std::get<3>(arg_tuple));
Box** args = arg_array.size() >= 4 ? &arg_array[3] : 0;
return partial_func->call(arg1, arg2, arg3, args);
} }
} }
......
...@@ -559,6 +559,54 @@ Box* map2(Box* f, Box* container) { ...@@ -559,6 +559,54 @@ Box* map2(Box* f, Box* container) {
return rtn; return rtn;
} }
Box* map(Box* f, BoxedTuple* args) {
assert(args->cls == tuple_cls);
auto num_iterable = args->elts.size();
if (num_iterable < 1)
raiseExcHelper(TypeError, "map() requires at least two args");
// performance optimization for the case where we only have one iterable
if (num_iterable == 1)
return map2(f, args->elts[0]);
std::vector<BoxIterator> args_it;
std::vector<BoxIterator> args_end;
for (auto& e : args->elts) {
auto range = e->pyElements();
args_it.emplace_back(range.begin());
args_end.emplace_back(range.end());
}
assert(args_it.size() == num_iterable);
assert(args_end.size() == num_iterable);
Box* rtn = new BoxedList();
std::vector<Box*> current_val(num_iterable);
while (true) {
int num_done = 0;
for (int i = 0; i < num_iterable; ++i) {
if (args_it[i] == args_end[i]) {
++num_done;
current_val[i] = None;
} else {
current_val[i] = *args_it[i];
}
}
if (num_done == num_iterable)
break;
auto v = getTupleFromArgsArray(&current_val[0], num_iterable);
listAppendInternal(rtn, runtimeCall(f, ArgPassSpec(num_iterable), std::get<0>(v), std::get<1>(v),
std::get<2>(v), std::get<3>(v), NULL));
for (int i = 0; i < num_iterable; ++i) {
if (args_it[i] != args_end[i])
++args_it[i];
}
}
return rtn;
}
Box* reduce(Box* f, Box* container, Box* initial) { Box* reduce(Box* f, Box* container, Box* initial) {
Box* current = initial; Box* current = initial;
...@@ -1042,7 +1090,8 @@ void setupBuiltins() { ...@@ -1042,7 +1090,8 @@ void setupBuiltins() {
builtins_module->giveAttr("execfile", builtins_module->giveAttr("execfile",
new BoxedBuiltinFunctionOrMethod(boxRTFunction((void*)execfile, UNKNOWN, 1), "execfile")); new BoxedBuiltinFunctionOrMethod(boxRTFunction((void*)execfile, UNKNOWN, 1), "execfile"));
builtins_module->giveAttr("map", new BoxedBuiltinFunctionOrMethod(boxRTFunction((void*)map2, LIST, 2), "map")); builtins_module->giveAttr(
"map", new BoxedBuiltinFunctionOrMethod(boxRTFunction((void*)map, LIST, 1, 0, true, false), "map"));
builtins_module->giveAttr( builtins_module->giveAttr(
"reduce", new BoxedBuiltinFunctionOrMethod(boxRTFunction((void*)reduce, UNKNOWN, 3, 1, false, false), "reduce", "reduce", new BoxedBuiltinFunctionOrMethod(boxRTFunction((void*)reduce, UNKNOWN, 3, 1, false, false), "reduce",
{ NULL })); { NULL }));
......
...@@ -148,5 +148,13 @@ static const char* objectNewParameterTypeErrorMsg() { ...@@ -148,5 +148,13 @@ static const char* objectNewParameterTypeErrorMsg() {
} }
bool exceptionMatches(const ExcInfo& e, BoxedClass* cls); bool exceptionMatches(const ExcInfo& e, BoxedClass* cls);
inline std::tuple<Box*, Box*, Box*, Box**> getTupleFromArgsArray(Box** args, int num_args) {
Box* arg1 = num_args >= 1 ? args[0] : nullptr;
Box* arg2 = num_args >= 2 ? args[1] : nullptr;
Box* arg3 = num_args >= 3 ? args[2] : nullptr;
Box** argtuple = num_args >= 4 ? &args[3] : nullptr;
return std::make_tuple(arg1, arg2, arg3, argtuple);
}
} }
#endif #endif
def f(x): def f(*args):
print "f(%s)" % x print "f(",
return -x for a in args:
print a,
print ")"
s = -1
try:
s = sum(args)
except:
pass
return s
print map(f, range(10)) print map(f, range(10))
print map(f, range(9), range(20, 31), range(30, 40), range(40, 50), range(60, 70))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment