// Copyright (c) 2014 Dropbox, Inc.
// 
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// 
//    http://www.apache.org/licenses/LICENSE-2.0
// 
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <sstream>
#include <unordered_map>

#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"

#include "core/common.h"
#include "core/stats.h"

#include "core/util.h"

#include "codegen/codegen.h"
#include "codegen/llvm_interpreter.h"
#include "codegen/irgen/hooks.h"
#include "codegen/irgen/util.h"

namespace pyston {

union Val {
    bool b;
    int64_t n;
    double d;
    Box* o;

    Val(bool b) : b(b) {}
    Val(int64_t n) : n(n) {}
    Val(double d) : d(d) {}
    Val(Box* o) : o(o) {}
};

typedef std::unordered_map<llvm::Value*, Val> SymMap;

int width(llvm::Type *t, const llvm::DataLayout &dl) {
    return dl.getTypeSizeInBits(t) / 8;
    //if (t == g.i1) return 1;
    //if (t == g.i64) return 8;
    //if (t->isPointerTy()) return 8;
//
    //t->dump();
    //RELEASE_ASSERT(0, "");
}

int width(llvm::Value *v, const llvm::DataLayout &dl) {
    return width(v->getType(), dl);
}

//#undef VERBOSITY
//#define VERBOSITY(x) 2
#define TIME_INTERPRETS

Val fetch(llvm::Value* v, const llvm::DataLayout &dl, const SymMap &symbols) {
    assert(v);

    int opcode = v->getValueID();

    //std::ostringstream os("");
    //os << "fetch_" << opcode;
    //int statid = Stats::getStatId(os.str());
    //Stats::log(statid);

    if (opcode >= llvm::Value::InstructionVal) {
        assert(symbols.count(v));
        return symbols.find(v)->second;
    }

    switch(opcode) {
        case llvm::Value::ArgumentVal: {
            assert(symbols.count(v));
            return symbols.find(v)->second;
        }
        case llvm::Value::ConstantIntVal: {
            if (v->getType() == g.i1)
                return (int64_t)llvm::cast<llvm::ConstantInt>(v)->getZExtValue();
            if (v->getType() == g.i64 || v->getType() == g.i32)
                return llvm::cast<llvm::ConstantInt>(v)->getSExtValue();
            v->dump();
            RELEASE_ASSERT(0, "");
        }
        case llvm::Value::ConstantFPVal: {
            return llvm::cast<llvm::ConstantFP>(v)->getValueAPF().convertToDouble();
        }
        case llvm::Value::ConstantExprVal: {
            llvm::ConstantExpr *ce = llvm::cast<llvm::ConstantExpr>(v);
            if (ce->isCast()) {
                assert(width(ce->getOperand(0), dl) == 8 && width(ce, dl) == 8);

                Val o = fetch(ce->getOperand(0), dl, symbols);
                return o;
            } else if (ce->getOpcode() == llvm::Instruction::GetElementPtr) {
                int64_t base = (int64_t)fetch(ce->getOperand(0), dl, symbols).o;
                llvm::Type *t = ce->getOperand(0)->getType();

                llvm::User::value_op_iterator begin = ce->value_op_begin();
                ++begin;
                std::vector<llvm::Value*> indices(begin, ce->value_op_end());

                int64_t offset = dl.getIndexedOffset(t, indices);

                /*if (VERBOSITY()) {
                    ce->dump();
                    ce->getOperand(0)->dump();
                    for (int i = 0; i < indices.size() ;i++) {
                        indices[i]->dump();
                    }
                    printf("resulting offset: %ld\n", offset);
                }*/

                return base + offset;
            } else {
                v->dump();
                RELEASE_ASSERT(0, "");
            }
        }
        /*case llvm::Value::FunctionVal: {
            llvm::Function* f = llvm::cast<llvm::Function>(v);
            if (f->getName() == "printf") {
                return (int64_t)printf;
            } else if (f->getName() == "reoptCompiledFunc") {
                return (int64_t)reoptCompiledFunc;
            } else if (f->getName() == "compilePartialFunc") {
                return (int64_t)compilePartialFunc;
            } else if (startswith(f->getName(), "runtimeCall")) {
                return (int64_t)g.func_registry.getFunctionAddress("runtimeCall");
            } else {
                return (int64_t)g.func_registry.getFunctionAddress(f->getName());
            }
        }*/
        case llvm::Value::GlobalVariableVal: {
            llvm::GlobalVariable* gv = llvm::cast<llvm::GlobalVariable>(v);
            if (!gv->isDeclaration() && gv->getLinkage() == llvm::GlobalVariable::InternalLinkage) {
                static std::unordered_map<llvm::GlobalVariable*, void*> made;

                void* &r = made[gv];
                if (r == NULL) {
                    llvm::Type *t = gv->getType()->getElementType();
                    r = (void*)malloc(width(t, dl));
                    if (gv->hasInitializer()) {
                        llvm::Constant* init = gv->getInitializer();
                        assert(init->getType() == t);
                        if (t == g.i64) {
                            llvm::ConstantInt *ci = llvm::cast<llvm::ConstantInt>(init);
                            *(int64_t*)r = ci->getSExtValue();
                        } else {
                            gv->dump();
                            RELEASE_ASSERT(0, "");
                        }
                    }
                }

                //gv->getType()->dump();
                //gv->dump();
                //printf("%p\n", r);
                //RELEASE_ASSERT(0, "");
                return (int64_t)r;
            }

            gv->dump();
            RELEASE_ASSERT(0, "");
        }
        case llvm::Value::UndefValueVal:
            return (int64_t)-1337;
        default:
            v->dump();
            RELEASE_ASSERT(0, "%d", v->getValueID());
    }
}

static void set(SymMap &symbols, const llvm::BasicBlock::iterator &it, Val v) {
    if (VERBOSITY() >= 2) {
        printf("Setting to %lx / %f: ", v.n, v.d);
        fflush(stdout);
        it->dump();
    }

    SymMap::iterator f = symbols.find(it);
    if (f != symbols.end())
        f->second = v;
    else
        symbols.insert(std::make_pair(static_cast<llvm::Value*>(&(*it)), v));
//#define SET(v) symbols.insert(std::make_pair(static_cast<llvm::Value*>(&(*it)), Val(v)))
}

static std::unordered_map<void*, const SymMap*> interpreter_roots;
void gatherInterpreterRootsForFrame(GCVisitor *visitor, void* frame_ptr) {
    auto it = interpreter_roots.find(frame_ptr);
    if (it == interpreter_roots.end()) {
        printf("%p is not an interpreter frame; they are", frame_ptr);
        for (auto it2 : interpreter_roots) {
            printf(" %p", it2.first);
        }
        printf("\n");
        abort();
    }

    //printf("Gathering roots for frame %p\n", frame_ptr);
    const SymMap* symbols = it->second;

    for (auto it2 : *symbols) {
        visitor->visitPotential(it2.second.o);
    }
}

class UnregisterHelper {
    private:
        void* frame_ptr;

    public:
        constexpr UnregisterHelper(void* frame_ptr) : frame_ptr(frame_ptr) {}

        ~UnregisterHelper() {
            assert(interpreter_roots.count(frame_ptr));
            interpreter_roots.erase(frame_ptr);
        }
};

Box* interpretFunction(llvm::Function *f, int nargs, Box* arg1, Box* arg2, Box* arg3, Box* *args) {
    assert(f);

#ifdef TIME_INTERPRETS
    Timer _t("to interpret", 1000000);
    long this_us = 0;
#endif

    static StatCounter interpreted_runs("interpreted_runs");
    interpreted_runs.log();

    llvm::DataLayout dl(f->getParent());

    //f->dump();
    //assert(nargs == f->getArgumentList().size());

    SymMap symbols;

    void* frame_ptr = __builtin_frame_address(0);
    interpreter_roots[frame_ptr] = &symbols;
    UnregisterHelper helper(frame_ptr);

    int i = 0;
    for (llvm::Function::arg_iterator AI = f->arg_begin(), end = f->arg_end(); AI != end; AI++, i++) {
        if (i == 0) symbols.insert(std::make_pair(static_cast<llvm::Value*>(&(*AI)), Val(arg1)));
        else if (i == 1) symbols.insert(std::make_pair(static_cast<llvm::Value*>(&(*AI)), Val(arg2)));
        else if (i == 2) symbols.insert(std::make_pair(static_cast<llvm::Value*>(&(*AI)), Val(arg3)));
        else {
            assert(i == 3);
            assert(f->getArgumentList().size() == 4);
            assert(f->getArgumentList().back().getType() == g.llvm_value_type_ptr->getPointerTo());
            symbols.insert(std::make_pair(static_cast<llvm::Value*>(&(*AI)), Val((int64_t)args)));
            //printf("loading %%4 with %p\n", (void*)args);
            break;
        }
    }

    llvm::BasicBlock *prevblock = NULL;
    llvm::BasicBlock *curblock = &f->getEntryBlock();


    while (true) {
        for (llvm::BasicBlock::iterator it = curblock->begin(), end = curblock->end(); it != end; ++it) {
            if (VERBOSITY("interpreter") >= 2) {
                printf("executing in %s: ", f->getName().data());
                fflush(stdout);
                it->dump();
                //f->dump();
            }

#define SET(v) set(symbols, it, (v))
            if (llvm::LoadInst *li = llvm::dyn_cast<llvm::LoadInst>(it)) {
                llvm::Value *ptr = li->getOperand(0);
                Val v = fetch(ptr, dl, symbols);
                //printf("loading from %p\n", v.o);

                if (width(li, dl) == 1) {
                    Val r = Val(*(bool*)v.o);
                    SET(r);
                    continue;
                } else if (width(li, dl) == 8) {
                    Val r = Val(*(int64_t*)v.o);
                    SET(r);
                    continue;
                } else {
                    li->dump();
                    RELEASE_ASSERT(0, "");
                }
            } else if (llvm::StoreInst *si = llvm::dyn_cast<llvm::StoreInst>(it)) {
                llvm::Value *val = si->getOperand(0);
                llvm::Value *ptr = si->getOperand(1);
                Val v = fetch(val, dl, symbols);
                Val p = fetch(ptr, dl, symbols);

                //printf("storing %lx at %lx\n", v.n, p.n);

                if (width(val, dl) == 1) {
                    *(bool*)p.o = v.b;
                    continue;
                } else if (width(val, dl) == 8) {
                    *(int64_t*)p.o = v.n;
                    continue;
                } else {
                    si->dump();
                    RELEASE_ASSERT(0, "");
                }
            } else if (llvm::CmpInst *ci = llvm::dyn_cast<llvm::CmpInst>(it)) {
                assert(ci->getType() == g.i1);

                Val a0 = fetch(ci->getOperand(0), dl, symbols);
                Val a1 = fetch(ci->getOperand(1), dl, symbols);
                llvm::CmpInst::Predicate pred = ci->getPredicate();
                switch (pred) {
                    case llvm::CmpInst::ICMP_EQ:
                        SET(a0.n == a1.n);
                        continue;
                    case llvm::CmpInst::ICMP_NE:
                        SET(a0.n != a1.n);
                        continue;
                    case llvm::CmpInst::ICMP_SLT:
                        SET(a0.n < a1.n);
                        continue;
                    case llvm::CmpInst::ICMP_SLE:
                        SET(a0.n <= a1.n);
                        continue;
                    case llvm::CmpInst::ICMP_SGT:
                        SET(a0.n > a1.n);
                        continue;
                    case llvm::CmpInst::ICMP_SGE:
                        SET(a0.n >= a1.n);
                        continue;
                    case llvm::CmpInst::FCMP_OEQ:
                        SET(a0.d == a1.d);
                        continue;
                    case llvm::CmpInst::FCMP_UNE:
                        SET(a0.d != a1.d);
                        continue;
                    case llvm::CmpInst::FCMP_OLT:
                        SET(a0.d < a1.d);
                        continue;
                    case llvm::CmpInst::FCMP_OLE:
                        SET(a0.d <= a1.d);
                        continue;
                    case llvm::CmpInst::FCMP_OGT:
                        SET(a0.d > a1.d);
                        continue;
                    case llvm::CmpInst::FCMP_OGE:
                        SET(a0.d >= a1.d);
                        continue;
                    default:
                        ci->dump();
                        RELEASE_ASSERT(0, "");
                }
                continue;
            } else if (llvm::BinaryOperator *bo = llvm::dyn_cast<llvm::BinaryOperator>(it)) {
                if (bo->getOperand(0)->getType() == g.i64 || bo->getOperand(0)->getType() == g.i1) {
                    //assert(bo->getOperand(0)->getType() == g.i64);
                    //assert(bo->getOperand(1)->getType() == g.i64);

                    Val a0 = fetch(bo->getOperand(0), dl, symbols);
                    Val a1 = fetch(bo->getOperand(1), dl, symbols);
                    llvm::Instruction::BinaryOps opcode = bo->getOpcode();
                    switch (opcode) {
                        case llvm::Instruction::Add:
                            SET(a0.n + a1.n);
                            continue;
                        case llvm::Instruction::And:
                            SET(a0.n & a1.n);
                            continue;
                        case llvm::Instruction::AShr:
                            SET(a0.n >> a1.n);
                            continue;
                        case llvm::Instruction::Mul:
                            SET(a0.n * a1.n);
                            continue;
                        case llvm::Instruction::Or:
                            SET(a0.n | a1.n);
                            continue;
                        case llvm::Instruction::Shl:
                            SET(a0.n << a1.n);
                            continue;
                        case llvm::Instruction::Sub:
                            SET(a0.n - a1.n);
                            continue;
                        case llvm::Instruction::Xor:
                            SET(a0.n ^ a1.n);
                            continue;
                        default:
                            bo->dump();
                            RELEASE_ASSERT(0, "");
                    }
                    continue;
                } else if (bo->getOperand(0)->getType() == g.double_) {
                    //assert(bo->getOperand(0)->getType() == g.i64);
                    //assert(bo->getOperand(1)->getType() == g.i64);

                    double lhs = fetch(bo->getOperand(0), dl, symbols).d;
                    double rhs = fetch(bo->getOperand(1), dl, symbols).d;
                    llvm::Instruction::BinaryOps opcode = bo->getOpcode();
                    switch (opcode) {
                        case llvm::Instruction::FAdd:
                            SET(lhs + rhs);
                            continue;
                        case llvm::Instruction::FMul:
                            SET(lhs * rhs);
                            continue;
                        case llvm::Instruction::FSub:
                            SET(lhs - rhs);
                            continue;
                        default:
                            bo->dump();
                            RELEASE_ASSERT(0, "");
                    }
                    continue;
                } else {
                    bo->dump();
                    RELEASE_ASSERT(0, "");
                }
            } else if (llvm::GetElementPtrInst *gep = llvm::dyn_cast<llvm::GetElementPtrInst>(it)) {
                int64_t base = fetch(gep->getPointerOperand(), dl, symbols).n;

                llvm::User::value_op_iterator begin = gep->value_op_begin();
                ++begin;
                std::vector<llvm::Value*> indices(begin, gep->value_op_end());

                int64_t offset = dl.getIndexedOffset(gep->getPointerOperandType(), indices);
                //gep->dump();
                //printf("offset for inst: %ld (base is %lx)\n", offset, base);
                SET(base + offset);
                continue;
            } else if (llvm::AllocaInst *al = llvm::dyn_cast<llvm::AllocaInst>(it)) {
                int size = fetch(al->getArraySize(), dl, symbols).n * width(al->getAllocatedType(), dl);
                void* ptr = alloca(size);
                //void* ptr = malloc(size);
                //printf("alloca()'d at %p\n", ptr);
                SET((int64_t)ptr);
                continue;
            } else if (llvm::SIToFPInst *si = llvm::dyn_cast<llvm::SIToFPInst>(it)) {
                assert(width(si->getOperand(0), dl) == 8);
                SET((double)fetch(si->getOperand(0), dl, symbols).n);
                continue;
            } else if (llvm::BitCastInst *bc = llvm::dyn_cast<llvm::BitCastInst>(it)) {
                assert(width(bc->getOperand(0), dl) == 8);
                SET(fetch(bc->getOperand(0), dl, symbols));
                continue;
            } else if (llvm::IntToPtrInst *bc = llvm::dyn_cast<llvm::IntToPtrInst>(it)) {
                assert(width(bc->getOperand(0), dl) == 8);
                SET(fetch(bc->getOperand(0), dl, symbols));
                continue;
            } else if (llvm::CallInst *ci = llvm::dyn_cast<llvm::CallInst>(it)) {
                void* f;
                int arg_start;
                if (ci->getCalledFunction() && (ci->getCalledFunction()->getName() == "llvm.experimental.patchpoint.void" || ci->getCalledFunction()->getName() == "llvm.experimental.patchpoint.i64")) {
                    //ci->dump();
                    f = (void*)fetch(ci->getArgOperand(2), dl, symbols).n;
                    arg_start = 4;
                } else {
                    f = (void*)fetch(ci->getCalledValue(), dl, symbols).n;
                    arg_start = 0;
                }

                if (VERBOSITY("interpreter") >= 2) printf("calling %s\n", g.func_addr_registry.getFuncNameAtAddress(f, true).c_str());

                std::vector<Val> args;
                int nargs = ci->getNumArgOperands();
                for (int i = arg_start; i < nargs; i++) {
                    //ci->getArgOperand(i)->dump();
                    args.push_back(fetch(ci->getArgOperand(i), dl, symbols));
                }

                int npassed_args = nargs - arg_start;
                //printf("%d %d %d\n", nargs, arg_start, npassed_args);

#ifdef TIME_INTERPRETS
                this_us += _t.end();
#endif
                // This is dumb but I don't know how else to do it:

                int mask = 1;
                if (ci->getType() == g.double_)
                    mask = 3;
                else
                    mask = 2;

                for (int i = 0; i < npassed_args; i++) {
                    mask <<= 1;
                    if (ci->getOperand(i)->getType() == g.double_)
                        mask |= 1;
                }

                Val r((int64_t)0);
                switch (mask) {
                    case 0b10:
                        r = reinterpret_cast<int64_t (*)()>(f)();
                        break;
                    case 0b11:
                        r = reinterpret_cast<double (*)()>(f)();
                        break;
                    case 0b100:
                        r = reinterpret_cast<int64_t (*)(int64_t)>(f)(args[0].n);
                        break;
                    case 0b101:
                        r = reinterpret_cast<int64_t (*)(double)>(f)(args[0].d);
                        break;
                    case 0b110:
                        r = reinterpret_cast<double (*)(int64_t)>(f)(args[0].n);
                        break;
                    case 0b1000:
                        r = reinterpret_cast<int64_t (*)(int64_t, int64_t)>(f)(args[0].n, args[1].n);
                        break;
                    case 0b1001:
                        r = reinterpret_cast<int64_t (*)(int64_t, double)>(f)(args[0].n, args[1].d);
                        break;
                    case 0b1011:
                        r = reinterpret_cast<int64_t (*)(double, double)>(f)(args[0].d, args[1].d);
                        break;
                    case 0b1111:
                        r = reinterpret_cast<double (*)(double, double)>(f)(args[0].d, args[1].d);
                        break;
                    case 0b10000:
                        r = reinterpret_cast<int64_t (*)(int64_t, int64_t, int64_t)>(f)(args[0].n, args[1].n, args[2].n);
                        break;
                    case 0b10001:
                        r = reinterpret_cast<int64_t (*)(int64_t, int64_t, double)>(f)(args[0].n, args[1].n, args[2].d);
                        break;
                    case 0b10011:
                        r = reinterpret_cast<int64_t (*)(int64_t, double, double)>(f)(args[0].n, args[1].d, args[2].d);
                        break;
                    case 0b100000:
                        r = reinterpret_cast<int64_t (*)(int64_t, int64_t, int64_t, int64_t)>(f)(args[0].n, args[1].n, args[2].n, args[3].n);
                        break;
                    case 0b100001:
                        r = reinterpret_cast<int64_t (*)(int64_t, int64_t, int64_t, double)>(f)(args[0].n, args[1].n, args[2].n, args[3].d);
                        break;
                    case 0b100110:
                        r = reinterpret_cast<int64_t (*)(int64_t, double, double, int64_t)>(f)(args[0].n, args[1].d, args[2].d, args[3].n);
                        break;
                    case 0b101010:
                        r = reinterpret_cast<int64_t (*)(double, int, double, int64_t)>(f)(args[0].d, args[1].n, args[2].d, args[3].n);
                        break;
                    case 0b1000000:
                        r = reinterpret_cast<int64_t (*)(int64_t, int64_t, int64_t, int64_t, int64_t)>(f)(args[0].n, args[1].n, args[2].n, args[3].n, args[4].n);
                        break;
                    case 0b10000000:
                        r = reinterpret_cast<int64_t (*)(int64_t, int64_t, int64_t, int64_t, int64_t, int64_t)>(f)(args[0].n, args[1].n, args[2].n, args[3].n, args[4].n, args[5].n);
                        break;
                    case 0b100000000:
                        r = reinterpret_cast<int64_t (*)(int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t)>(f)(args[0].n, args[1].n, args[2].n, args[3].n, args[4].n, args[5].n, args[6].n);
                        break;
                    case 0b1000000000:
                        r = reinterpret_cast<int64_t (*)(int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t, int64_t)>(f)(args[0].n, args[1].n, args[2].n, args[3].n, args[4].n, args[5].n, args[6].n, args[7].n);
                        break;
                    default:
                        it->dump();
                        RELEASE_ASSERT(0, "%d", mask);
                        break;
                }
                if (ci->getType() != g.void_)
                    SET(r);


#ifdef TIME_INTERPRETS
                _t.restart("to interpret", 10000000);
#endif
                continue;
            } else if (llvm::SelectInst *si = llvm::dyn_cast<llvm::SelectInst>(it)) {
                Val test = fetch(si->getCondition(), dl, symbols);
                Val vt = fetch(si->getTrueValue(), dl, symbols);
                Val vf = fetch(si->getFalseValue(), dl, symbols);
                if (test.b)
                    SET(vt);
                else
                    SET(vf);
                continue;
            } else if (llvm::PHINode *phi = llvm::dyn_cast<llvm::PHINode>(it)) {
                assert(prevblock);
                SET(fetch(phi->getIncomingValueForBlock(prevblock), dl, symbols));
                continue;
            } else if (llvm::BranchInst *br = llvm::dyn_cast<llvm::BranchInst>(it)) {
                prevblock = curblock;
                if (br->isConditional()) {
                    Val t = fetch(br->getCondition(), dl, symbols);
                    if (t.b) {
                        curblock = br->getSuccessor(0);
                    } else {
                        curblock = br->getSuccessor(1);
                    }
                } else {
                    curblock = br->getSuccessor(0);
                }
                //if (VERBOSITY()) {
                    //printf("jumped to %s\n", curblock->getName().data());
                //}
                break;
            } else if (llvm::ReturnInst *ret = llvm::dyn_cast<llvm::ReturnInst>(it)) {
                llvm::Value* r = ret->getReturnValue();

#ifdef TIME_INTERPRETS
                this_us += _t.end();
                static StatCounter us_interpreting("us_interpreting");
                us_interpreting.log(this_us);
#endif

                if (!r)
                    return NULL;
                Val t = fetch(r, dl, symbols);
                return t.o;
            }


            it->dump();
            RELEASE_ASSERT(0, "");
        }
    }

    RELEASE_ASSERT(0, "");
}

}