#include "aicombataction.hpp"

#include <components/esm3/loadench.hpp>
#include <components/esm3/loadmgef.hpp>

#include "../mwbase/environment.hpp"
#include "../mwbase/mechanicsmanager.hpp"
#include "../mwbase/world.hpp"

#include "../mwworld/actionequip.hpp"
#include "../mwworld/cellstore.hpp"
#include "../mwworld/class.hpp"
#include "../mwworld/esmstore.hpp"
#include "../mwworld/inventorystore.hpp"

#include "actorutil.hpp"
#include "combat.hpp"
#include "npcstats.hpp"
#include "spellpriority.hpp"
#include "weaponpriority.hpp"
#include "weapontype.hpp"

namespace MWMechanics
{
    /// Suggest a combat range for magic with the given range types.
    /// @param rangeTypes Bitmask that is tested against RangeTypes::Touch.
    /// @return The fCombatDistance game setting when the Touch bit is set; otherwise four times
    ///         fCombatDistance * max(2, fHandToHandReach).
    /// @note Both game settings are looked up on the first call only and cached in function-local statics.
    float suggestCombatRange(int rangeTypes)
    {
        static const float touchRange = MWBase::Environment::get()
                                            .getESMStore()
                                            ->get<ESM::GameSetting>()
                                            .find("fCombatDistance")
                                            ->mValue.getFloat();

        // This distance is a possible distance of melee attack
        static const float meleeDistance = [] {
            const float handToHandReach = MWBase::Environment::get()
                                              .getESMStore()
                                              ->get<ESM::GameSetting>()
                                              .find("fHandToHandReach")
                                              ->mValue.getFloat();
            return touchRange * std::max(2.f, handToHandReach);
        }();

        const bool hasTouchRange = (rangeTypes & RangeTypes::Touch) != 0;
        return hasTouchRange ? touchRange : meleeDistance * 4;
    }

    /// Get @p actor ready to cast this action's spell: select the spell, switch the actor to the
    /// spell draw state, clear the selected enchanted item (when the actor has an inventory store),
    /// and ask the world to preload the spell's effects.
    void ActionSpell::prepare(const MWWorld::Ptr& actor)
    {
        CreatureStats& stats = actor.getClass().getCreatureStats(actor);
        stats.getSpells().setSelectedSpell(mSpellId);
        stats.setDrawState(DrawState::Spell);

        if (actor.getClass().hasInventoryStore(actor))
        {
            // end() stands for "no enchanted item selected".
            MWWorld::InventoryStore& inventory = actor.getClass().getInventoryStore(actor);
            inventory.setSelectedEnchantItem(inventory.end());
        }

        const auto& spellStore = MWBase::Environment::get().getESMStore()->get<ESM::Spell>();
        const ESM::Spell* selected = spellStore.find(mSpellId);
        MWBase::Environment::get().getWorld()->preloadEffects(&selected->mEffects);
    }

    float ActionSpell::getCombatRange(bool& isRanged) const
    {
        const ESM::Spell* spell = MWBase::Environment::get().getESMStore()->get<ESM::Spell>().find(mSpellId);
        int types = getRangeTypes(spell->mEffects);

        isRanged = (types & RangeTypes::Target) | (types & RangeTypes::Self);
        return suggestCombatRange(types);
    }

    /// Get @p actor ready to use this action's enchanted item: clear the selected spell,
    /// select the item in the actor's inventory store, and switch to the spell draw state.
    void ActionEnchantedItem::prepare(const MWWorld::Ptr& actor)
    {
        const auto& actorClass = actor.getClass();

        // An empty id deselects whatever spell was selected before.
        actorClass.getCreatureStats(actor).getSpells().setSelectedSpell(ESM::RefId());
        actorClass.getInventoryStore(actor).setSelectedEnchantItem(mItem);
        actorClass.getCreatureStats(actor).setDrawState(DrawState::Spell);
    }

    float ActionEnchantedItem::getCombatRange(bool& isRanged) const
    {
        const ESM::Enchantment* enchantment = MWBase::Environment::get().getESMStore()->get<ESM::Enchantment>().find(
            mItem->getClass().getEnchantment(*mItem));
        int types = getRangeTypes(enchantment->mEffects);

        isRanged = (types & RangeTypes::Target) | (types & RangeTypes::Self);
        return suggestCombatRange(types);
    }

    /// Combat range to keep while drinking a potion.
    /// @param[out] isRanged Always set to false.
    /// @return A fixed range of 600.
    float ActionPotion::getCombatRange(bool& isRanged) const
    {
        // Assign the out-parameter explicitly, as ActionWeapon::getCombatRange does, so the result never
        // depends on the value the caller happened to pass in.
        // If we want to back away slightly to avoid enemy hits, we should set isRanged to "true"
        isRanged = false;

        // Distance doesn't matter since this action has no animation
        return 600.f;
    }

    /// Have @p actor consume this action's potion through the actor's class.
    void ActionPotion::prepare(const MWWorld::Ptr& actor)
    {
        const auto& actorClass = actor.getClass();
        actorClass.consume(mPotion, actor);
    }

    /// Get @p actor ready to attack with this action's weapon.
    /// When the actor has an inventory store, the weapon is equipped (or the right-hand slot is
    /// unequipped when no weapon is set) and the ammunition, if any, is equipped as well.
    /// The actor is switched to the weapon draw state in every case.
    void ActionWeapon::prepare(const MWWorld::Ptr& actor)
    {
        const auto& actorClass = actor.getClass();

        if (actorClass.hasInventoryStore(actor))
        {
            if (!mWeapon.isEmpty())
                MWWorld::ActionEquip(mWeapon).execute(actor);
            else
                actorClass.getInventoryStore(actor).unequipSlot(MWWorld::InventoryStore::Slot_CarriedRight);

            if (!mAmmunition.isEmpty())
                MWWorld::ActionEquip(mAmmunition).execute(actor);
        }

        actorClass.getCreatureStats(actor).setDrawState(DrawState::Weapon);
    }

    float ActionWeapon::getCombatRange(bool& isRanged) const
    {
        isRanged = false;

        static const float fCombatDistance = MWBase::Environment::get()
                                                 .getESMStore()
                                                 ->get<ESM::GameSetting>()
                                                 .find("fCombatDistance")
                                                 ->mValue.getFloat();
        static const float fProjectileMaxSpeed = MWBase::Environment::get()
                                                     .getESMStore()
                                                     ->get<ESM::GameSetting>()
                                                     .find("fProjectileMaxSpeed")
                                                     ->mValue.getFloat();

        if (mWeapon.isEmpty())
        {
            static float fHandToHandReach = MWBase::Environment::get()
                                                .getESMStore()
                                                ->get<ESM::GameSetting>()
                                                .find("fHandToHandReach")
                                                ->mValue.getFloat();
            return fHandToHandReach * fCombatDistance;
        }

        const ESM::Weapon* weapon = mWeapon.get<ESM::Weapon>()->mBase;
        if (MWMechanics::getWeaponType(weapon->mData.mType)->mWeaponClass != ESM::WeaponType::Melee)
        {
            isRanged = true;
            return fProjectileMaxSpeed;
        }
        else
            return weapon->mData.mReach * fCombatDistance;
    }

    /// @return The base weapon record of this action's weapon, or nullptr when no weapon is set.
    const ESM::Weapon* ActionWeapon::getWeapon() const
    {
        return mWeapon.isEmpty() ? nullptr : mWeapon.get<ESM::Weapon>()->mBase;
    }

    /// Choose the highest-rated combat action for @p actor against @p enemy, prepare it on the actor
    /// and return it.
    ///
    /// Candidates are rated in this order: potions, enchanted items and weapons from the inventory store
    /// (when the actor has one), then the actor's spells. A candidate only replaces the current choice
    /// when its rating is strictly higher, so ties keep the earlier candidate. The starting choice is an
    /// ActionWeapon built from an empty Ptr with rating 0. Werewolf NPCs always get that starting choice.
    /// Finally makeFleeDecision() may replace the choice with ActionFlee.
    std::unique_ptr<Action> prepareNextAction(const MWWorld::Ptr& actor, const MWWorld::Ptr& enemy)
    {
        const auto& actorClass = actor.getClass();
        Spells& knownSpells = actorClass.getCreatureStats(actor).getSpells();

        float chosenRating = 0.f;
        // Rating handed to makeFleeDecision() for the chosen action.
        float antiFlee = 0.f;
        // Default to hand-to-hand combat
        std::unique_ptr<Action> chosen = std::make_unique<ActionWeapon>(MWWorld::Ptr());

        const bool isWerewolf = actorClass.isNpc() && actorClass.getNpcStats(actor).isWerewolf();
        if (isWerewolf)
        {
            chosen->prepare(actor);
            return chosen;
        }

        if (actorClass.hasInventoryStore(actor))
        {
            MWWorld::InventoryStore& inventory = actorClass.getInventoryStore(actor);

            // Potions
            for (auto item = inventory.begin(); item != inventory.end(); ++item)
            {
                const float rating = ratePotion(*item, actor);
                if (rating <= chosenRating)
                    continue;

                chosenRating = rating;
                chosen = std::make_unique<ActionPotion>(*item);
                antiFlee = std::numeric_limits<float>::max();
            }

            // Enchanted items (the action keeps the iterator itself)
            for (auto item = inventory.begin(); item != inventory.end(); ++item)
            {
                const float rating = rateMagicItem(*item, actor, enemy);
                if (rating <= chosenRating)
                    continue;

                chosenRating = rating;
                chosen = std::make_unique<ActionEnchantedItem>(item);
                antiFlee = std::numeric_limits<float>::max();
            }

            // Best ammunition of each kind, needed to rate the weapons below
            MWWorld::Ptr bestArrow;
            const float bestArrowRating = rateAmmo(actor, enemy, bestArrow, ESM::Weapon::Arrow);

            MWWorld::Ptr bestBolt;
            const float bestBoltRating = rateAmmo(actor, enemy, bestBolt, ESM::Weapon::Bolt);

            // Weapons
            for (auto item = inventory.begin(); item != inventory.end(); ++item)
            {
                const float rating = rateWeapon(*item, actor, enemy, -1, bestArrowRating, bestBoltRating);
                if (rating <= chosenRating)
                    continue;

                const ESM::Weapon* record = item->get<ESM::Weapon>()->mBase;
                const int ammoType = getWeaponType(record->mData.mType)->mAmmoType;

                // Pair the weapon with the matching best ammunition; stays empty for any other ammo type.
                MWWorld::Ptr ammo;
                if (ammoType == ESM::Weapon::Arrow)
                    ammo = bestArrow;
                else if (ammoType == ESM::Weapon::Bolt)
                    ammo = bestBolt;

                chosenRating = rating;
                chosen = std::make_unique<ActionWeapon>(*item, ammo);
                antiFlee = vanillaRateWeaponAndAmmo(*item, ammo, actor, enemy);
            }
        }

        // Spells
        for (const ESM::Spell* candidate : knownSpells)
        {
            const float rating = rateSpell(candidate, actor, enemy);
            if (rating <= chosenRating)
                continue;

            chosenRating = rating;
            chosen = std::make_unique<ActionSpell>(candidate->mId);
            antiFlee = vanillaRateSpell(candidate, actor, enemy);
        }

        if (makeFleeDecision(actor, enemy, antiFlee))
            chosen = std::make_unique<ActionFlee>();

        if (chosen)
            chosen->prepare(actor);

        return chosen;
    }

    /// Highest rating among @p actor's enchanted items, weapons and spells against @p enemy.
    /// Nothing is selected or prepared. Potions are not rated here.
    /// @return 0 for werewolf NPCs and when no candidate rates above 0; otherwise the best rating found.
    float getBestActionRating(const MWWorld::Ptr& actor, const MWWorld::Ptr& enemy)
    {
        const auto& actorClass = actor.getClass();
        Spells& knownSpells = actorClass.getCreatureStats(actor).getSpells();

        float best = 0.f;
        const auto keepBest = [&best](float rating) {
            if (rating > best)
                best = rating;
        };

        if (actorClass.isNpc() && actorClass.getNpcStats(actor).isWerewolf())
            return best;

        if (actorClass.hasInventoryStore(actor))
        {
            MWWorld::InventoryStore& inventory = actorClass.getInventoryStore(actor);

            for (auto item = inventory.begin(); item != inventory.end(); ++item)
                keepBest(rateMagicItem(*item, actor, enemy));

            const float bestArrowRating = rateAmmo(actor, enemy, ESM::Weapon::Arrow);
            const float bestBoltRating = rateAmmo(actor, enemy, ESM::Weapon::Bolt);

            for (auto item = inventory.begin(); item != inventory.end(); ++item)
                keepBest(rateWeapon(*item, actor, enemy, -1, bestArrowRating, bestBoltRating));
        }

        for (const ESM::Spell* candidate : knownSpells)
            keepBest(rateSpell(candidate, actor, enemy));

        return best;
    }

    /// Distance between two actors' positions, reduced by the y component of each actor's half extents.
    /// @param minusZDist When true, the absolute z difference of the two positions is also subtracted
    ///                   from the straight-line distance.
    /// @return The reduced distance; it may be zero or negative.
    float getDistanceMinusHalfExtents(const MWWorld::Ptr& actor1, const MWWorld::Ptr& actor2, bool minusZDist)
    {
        const osg::Vec3f firstPos = actor1.getRefData().getPosition().asVec3();
        const osg::Vec3f secondPos = actor2.getRefData().getPosition().asVec3();

        float separation = (firstPos - secondPos).length();
        if (minusZDist)
            separation -= std::abs(firstPos.z() - secondPos.z());

        separation -= MWBase::Environment::get().getWorld()->getHalfExtents(actor1).y();
        separation -= MWBase::Environment::get().getWorld()->getHalfExtents(actor2).y();
        return separation;
    }

    /// Maximum attack distance of @p actor, derived from its equipped weapon and ammunition, its selected
    /// spell or enchanted item, its draw state and several game settings.
    ///
    /// A base value is picked first:
    ///  - fHandToHandReach in the first branch (see the review note there);
    ///  - in the spell draw state, the speed of the first Target-range effect of the selected spell or
    ///    enchanted item (1 when there is none), multiplied by max(1000, fTargetSpellMaxSpeed);
    ///  - for a non-melee weapon, fProjectileMaxSpeed multiplied by the ammunition's speed when ammunition
    ///    is equipped;
    ///  - for a melee weapon, its reach when that is greater than 1;
    ///  - 1 otherwise.
    /// A base value that is not positive is replaced by 1. If the base value is below the combat distance
    /// (fCombatDistance, scaled by fCombatDistanceWerewolfMod + 1 for werewolf NPCs) it is multiplied by
    /// that combat distance.
    float getMaxAttackDistance(const MWWorld::Ptr& actor)
    {
        const CreatureStats& stats = actor.getClass().getCreatureStats(actor);
        const MWWorld::Store<ESM::GameSetting>& gmst
            = MWBase::Environment::get().getESMStore()->get<ESM::GameSetting>();

        const ESM::RefId& selectedSpellId = stats.getSpells().getSelectedSpell();

        MWWorld::Ptr activeWeapon;
        MWWorld::Ptr activeAmmo;
        MWWorld::Ptr selectedEnchItem;
        if (actor.getClass().hasInventoryStore(actor))
        {
            MWWorld::InventoryStore& invStore = actor.getClass().getInventoryStore(actor);

            MWWorld::ContainerStoreIterator rightHand = invStore.getSlot(MWWorld::InventoryStore::Slot_CarriedRight);
            if (rightHand != invStore.end() && rightHand.getType() == MWWorld::ContainerStore::Type_Weapon)
                activeWeapon = *rightHand;

            MWWorld::ContainerStoreIterator ammoSlot = invStore.getSlot(MWWorld::InventoryStore::Slot_Ammunition);
            if (ammoSlot != invStore.end() && ammoSlot.getType() == MWWorld::ContainerStore::Type_Weapon)
                activeAmmo = *ammoSlot;

            if (invStore.getSelectedEnchantItem() != invStore.end())
                selectedEnchItem = *invStore.getSelectedEnchantItem();
        }

        // Speed of the first Target-range effect in the list, or 1 when the list has none.
        const auto firstTargetEffectSpeed = [](const std::vector<ESM::ENAMstruct>& effects) -> float {
            for (const ESM::ENAMstruct& entry : effects)
            {
                if (entry.mRange != ESM::RT_Target)
                    continue;

                const ESM::MagicEffect* effect
                    = MWBase::Environment::get().getESMStore()->get<ESM::MagicEffect>().find(entry.mEffectID);
                return effect->mData.mSpeed;
            }
            return 1.0f;
        };

        float dist = 1.0f;
        // NOTE(review): this branch requires BOTH a selected spell and a selected enchanted item, so it is
        // rarely taken; it looks like the intent may have been "nothing selected" - confirm before changing.
        if (activeWeapon.isEmpty() && !selectedSpellId.empty() && !selectedEnchItem.isEmpty())
        {
            static const float fHandToHandReach = gmst.find("fHandToHandReach")->mValue.getFloat();
            dist = fHandToHandReach;
        }
        else if (stats.getDrawState() == MWMechanics::DrawState::Spell)
        {
            if (!selectedSpellId.empty())
            {
                const ESM::Spell* spell
                    = MWBase::Environment::get().getESMStore()->get<ESM::Spell>().find(selectedSpellId);
                dist = firstTargetEffectSpeed(spell->mEffects.mList);
            }
            else if (!selectedEnchItem.isEmpty())
            {
                const ESM::RefId& enchId = selectedEnchItem.getClass().getEnchantment(selectedEnchItem);
                if (!enchId.empty())
                {
                    const ESM::Enchantment* ench
                        = MWBase::Environment::get().getESMStore()->get<ESM::Enchantment>().find(enchId);
                    dist = firstTargetEffectSpeed(ench->mEffects.mList);
                }
            }

            static const float fTargetSpellMaxSpeed = gmst.find("fTargetSpellMaxSpeed")->mValue.getFloat();
            dist *= std::max(1000.0f, fTargetSpellMaxSpeed);
        }
        else if (!activeWeapon.isEmpty())
        {
            const ESM::Weapon* esmWeap = activeWeapon.get<ESM::Weapon>()->mBase;
            const bool isMelee
                = MWMechanics::getWeaponType(esmWeap->mData.mType)->mWeaponClass == ESM::WeaponType::Melee;
            if (!isMelee)
            {
                static const float fProjectileMaxSpeed = gmst.find("fProjectileMaxSpeed")->mValue.getFloat();
                dist = fProjectileMaxSpeed;
                if (!activeAmmo.isEmpty())
                    dist *= activeAmmo.get<ESM::Weapon>()->mBase->mData.mSpeed;
            }
            else if (esmWeap->mData.mReach > 1)
            {
                dist = esmWeap->mData.mReach;
            }
        }

        // Fall back to 1 for anything that is not strictly positive.
        if (!(dist > 0.f))
            dist = 1.0f;

        static const float fCombatDistance = gmst.find("fCombatDistance")->mValue.getFloat();
        static const float fCombatDistanceWerewolfMod = gmst.find("fCombatDistanceWerewolfMod")->mValue.getFloat();

        float combatDistance = fCombatDistance;
        const bool isWerewolf = actor.getClass().isNpc() && actor.getClass().getNpcStats(actor).isWerewolf();
        if (isWerewolf)
            combatDistance *= (fCombatDistanceWerewolfMod + 1.0f);

        // Values below the combat distance are scaled by it rather than used directly.
        if (dist < combatDistance)
            dist *= combatDistance;

        return dist;
    }

    /// Decide whether @p actor is able to fight @p enemy right now.
    ///
    /// The checks run in this order and the first one that matches decides:
    ///  1. false - the enemy is magically hidden and the awareness check fails;
    ///  2. false - the actor is a pure water creature and the enemy is not wading;
    ///  3. true  - the enemy is within the actor's maximum attack distance (both by distance minus half
    ///             extents and by height difference) and there is line of sight;
    ///  4. false - the actor is a pure land creature and the enemy walks on water;
    ///  5. false - the actor is a pure flying or pure land creature and the enemy swims;
    ///  6. false - the actor is bipedal or cannot fly, the enemy levitates in an exterior cell, and the
    ///             enemy is further above both the actor and the terrain than the attack distance;
    ///  7. true  - the actor can neither walk nor is bipedal, levitates, or swims;
    ///  8. false - the distance minus half extents and minus the height difference is not positive;
    ///  9. true otherwise.
    ///
    /// The maximum attack distance is computed once and reused; nothing in between changes the actor.
    bool canFight(const MWWorld::Ptr& actor, const MWWorld::Ptr& enemy)
    {
        auto* const world = MWBase::Environment::get().getWorld();
        const auto& actorClass = actor.getClass();

        const ESM::Position actorPos = actor.getRefData().getPosition();
        const ESM::Position enemyPos = enemy.getRefData().getPosition();

        // True when the given actor has a Levitate effect with a magnitude above zero.
        const auto isLevitating = [](const MWWorld::Ptr& ptr) {
            return ptr.getClass()
                       .getCreatureStats(ptr)
                       .getMagicEffects()
                       .getOrDefault(ESM::MagicEffect::Levitate)
                       .getMagnitude()
                > 0;
        };

        if (isTargetMagicallyHidden(enemy)
            && !MWBase::Environment::get().getMechanicsManager()->awarenessCheck(enemy, actor))
        {
            return false;
        }

        if (actorClass.isPureWaterCreature(actor) && !world->isWading(enemy))
            return false;

        const float attackDistance = getMaxAttackDistance(actor);
        if (attackDistance > getDistanceMinusHalfExtents(actor, enemy)
            && attackDistance > std::abs(actorPos.pos[2] - enemyPos.pos[2]) && world->getLOS(actor, enemy))
        {
            return true;
        }

        const bool isPureLand = actorClass.isPureLandCreature(actor);
        if (isPureLand && world->isWalkingOnWater(enemy))
            return false;

        if ((actorClass.isPureFlyingCreature(actor) || isPureLand) && world->isSwimming(enemy))
            return false;

        const bool isBipedal = actorClass.isBipedal(actor);
        if ((isBipedal || !actorClass.canFly(actor)) && isLevitating(enemy))
        {
            // The enemy is higher above the actor than the attack distance reaches.
            const bool aboveReach = (attackDistance + actorPos.pos[2]) < enemyPos.pos[2];
            if (aboveReach && enemy.getCell()->isExterior())
            {
                const auto terrainHeight
                    = world->getTerrainHeightAt(enemyPos.asVec3(), enemy.getCell()->getCell()->getWorldSpace());
                if (attackDistance < (enemyPos.pos[2] - terrainHeight))
                    return false;
            }
        }

        if (!actorClass.canWalk(actor) && !isBipedal)
            return true;

        if (isLevitating(actor))
            return true;

        if (world->isSwimming(actor))
            return true;

        // Written as a negated "<=" so that a NaN distance still yields true, as before.
        return !(getDistanceMinusHalfExtents(actor, enemy, true) <= 0.0f);
    }

    /// Rate how strongly @p actor wants to flee from @p enemy.
    /// @return The actor's modified Flee AI setting when that is 100 or more. Otherwise
    ///         (1 - health ratio) * fAIFleeHealthMult + flee * fAIFleeFleeMult, replaced by iWereWolfFleeMod
    ///         when both are NPCs, the enemy is a werewolf and the actor's level is below
    ///         iWereWolfLevelToAttack. A non-zero rating additionally gets getFightDistanceBias() added.
    float vanillaRateFlee(const MWWorld::Ptr& actor, const MWWorld::Ptr& enemy)
    {
        const CreatureStats& actorStats = actor.getClass().getCreatureStats(actor);
        const MWWorld::Store<ESM::GameSetting>& settings
            = MWBase::Environment::get().getESMStore()->get<ESM::GameSetting>();

        const int fleeSetting = actorStats.getAiSetting(AiSetting::Flee).getModified();
        if (fleeSetting >= 100)
            return static_cast<float>(fleeSetting);

        static const float fAIFleeHealthMult = settings.find("fAIFleeHealthMult")->mValue.getFloat();
        static const float fAIFleeFleeMult = settings.find("fAIFleeFleeMult")->mValue.getFloat();

        const float healthRatio = actorStats.getHealth().getRatio(false);
        float rating = (1.0f - healthRatio) * fAIFleeHealthMult + fleeSetting * fAIFleeFleeMult;

        static const int iWereWolfLevelToAttack = settings.find("iWereWolfLevelToAttack")->mValue.getInteger();

        const bool bothNpcs = actor.getClass().isNpc() && enemy.getClass().isNpc();
        if (bothNpcs && enemy.getClass().getNpcStats(enemy).isWerewolf()
            && actorStats.getLevel() < iWereWolfLevelToAttack)
        {
            // Low-level NPCs facing a werewolf use the fixed werewolf flee rating instead.
            static const int iWereWolfFleeMod = settings.find("iWereWolfFleeMod")->mValue.getInteger();
            rating = static_cast<float>(iWereWolfFleeMod);
        }

        // The distance bias is only applied to a non-zero rating.
        if (rating != 0.0f)
            rating += getFightDistanceBias(actor, enemy);

        return rating;
    }

    /// Decide whether @p actor should flee from @p enemy.
    /// @param antiFleeRating Rating of the action the actor would otherwise take.
    /// @return true when the flee rating (counted as 0 while below 100) exceeds @p antiFleeRating, or when
    ///         @p antiFleeRating is exactly 0 and the actor has at least one summoned creature.
    bool makeFleeDecision(const MWWorld::Ptr& actor, const MWWorld::Ptr& enemy, float antiFleeRating)
    {
        const float rawRating = vanillaRateFlee(actor, enemy);
        // Ratings below 100 do not count towards fleeing.
        const float fleeRating = (rawRating < 100.0f) ? 0.0f : rawRating;

        if (fleeRating > antiFleeRating)
            return true;

        // Run away after summoning a creature if we have nothing to use but fists.
        const bool nothingButFists = (antiFleeRating == 0.0f);
        return nothingButFists && !actor.getClass().getCreatureStats(actor).getSummonedCreatureMap().empty();
    }
}
