import datetime
import json
import logging

from common.status import HTTP_200_OK, HTTP_500_INTERNAL_SERVER_ERROR, HTTP_503_SERVICE_UNAVAILABLE, HTTP_404_NOT_FOUND
from flask import Blueprint, request, session
from flask_jwt_extended import jwt_required
from median.constant import HistoryType, EtatAdresse, MEDIANWEB_POSTE, CONFIG_WEB_CLE
from median.models import Stock, Adresse, CodeBlocage, Product, Cip, Ucd, Historique, Magasin, Gpao, Service, UnitDose
from median.views import RawConfig
from peewee import JOIN, fn, DoesNotExist
from common.util import get_counter
from median.database import mysql_db

# Flask blueprint carrying every stock endpoint defined in this module.
stock_blueprint = Blueprint("stock", __name__)

logger = logging.getLogger("median")

# Read once, at import time, from the web configuration of station MEDIANWEB_POSTE (key CONFIG_WEB_CLE).
# NOTE(review): not referenced in the visible part of this module — confirm whether it is still needed.
gtin_management = RawConfig(MEDIANWEB_POSTE, CONFIG_WEB_CLE).read("k_web_gtin_management")


def _padAddressField(adr):
    """Right-justify every dot-separated component of ``adr`` to width 3.

    ``"1.2.10"`` becomes ``"  1.  2. 10"``; components of three or more
    characters are kept as they are. The number of components is not checked,
    so store-level addresses such as ``"SHV"`` pass through unchanged.
    """
    return ".".join(segment.rjust(3) for segment in adr.split("."))


def _moveContainerForward(adr):
    """Move the container stored behind front slot ``adr`` into that slot.

    ``adr`` is a padded address (see ``_padAddressField``) whose last
    component is the depth position. Only a front slot (position ``1`` or
    lower) is handled: for a deeper position ``False`` is returned and nothing
    is changed; otherwise the function returns ``None``.

    The two rear slots are ``adr`` with its last component replaced by
    ``"  2"`` and ``"  3"``. The container code of the nearest rear slot row
    (``"  2"`` before ``"  3"``) is written on ``adr`` with state ``"O"`` (an
    empty code when neither rear slot has an ``Adresse`` row), both rear slots
    are freed (state ``"L"``, empty container) and every stock line stored in
    them is re-addressed to ``adr``.
    """
    parts = adr.split(".")
    if int(parts[-1].strip()) > 1:
        return False

    # Rear slots of the same lane; the "  N" literals match the width-3 padding.
    rear_slots = [".".join(parts[:-1] + [depth]) for depth in ("  2", "  3")]

    # The three updates depend on each other: run them in one transaction so a
    # failure cannot leave the container recorded in two slots.
    with mysql_db.atomic():
        # Without ORDER BY the database may return either rear slot; prefer the
        # nearest one so the result is deterministic.
        rear = (
            Adresse.select(Adresse.contenant)
            .where(Adresse.adresse.in_(rear_slots))
            .order_by(Adresse.adresse)
            .first()
        )
        container = rear.contenant if rear is not None else ""

        Adresse.update({Adresse.contenant: container, Adresse.etat: "O"}).where(Adresse.adresse == adr).execute()
        Adresse.update({Adresse.contenant: "", Adresse.etat: "L"}).where(Adresse.adresse.in_(rear_slots)).execute()
        Stock.update({Stock.adresse: adr}).where(Stock.adresse.in_(rear_slots)).execute()


def _convertUserFriendlyDateToDBDateTime(ufDate):
    """Turn a ``YYYY-MM-DD`` or ``DD-MM-YYYY`` date into ``"YYYY-MM-DD 00:00:00"``.

    The input is treated as year-first when its first dash-separated part is
    four characters long; otherwise its first three parts are reordered from
    day-month-year into year-month-day.
    """
    pieces = ufDate.split("-")
    if len(pieces[0]) != 4:
        # Day-first input: rebuild it in ISO (year-month-day) order.
        ufDate = "-".join((pieces[2], pieces[1], pieces[0]))
    return f"{ufDate} 00:00:00"


@stock_blueprint.route("<string:ref>", methods=["PATCH"])
@jwt_required()
def get_all(ref):
    """List the stock lines of one product, optionally restricted to some stores.

    Despite its name, ``ref`` is matched against ``Product.pk`` (the product
    primary key), not against the product reference code.

    The JSON request body may contain ``filterByMagasin``: a comma-separated
    list of store codes (``Stock.magasin``). When it is present and non-empty,
    only stock lines in those stores are returned.

    Returns ``({"data": [...]}, 200)`` with one dict per stock line, ordered by
    address then expiry date, or ``({"message": error.args}, 500)`` when the
    query fails.
    """
    # WARNING : This talks about refs but it uses PK ! <<<<<<<<<<<
    args = json.loads(request.data)
    v_filter_by_magasin = args.get("filterByMagasin", None)

    try:
        # The full Stock row is selected in every case: the serializer below
        # reads every stock column. Left joins keep stock lines whose address
        # has no Adresse row (hence the "is still a str" checks below) or no
        # lock code.
        filtered_stocks_query = (
            Stock.select(Stock, Adresse, CodeBlocage.libelle, CodeBlocage.valeur)
            .switch(Stock)
            .join(Adresse, JOIN.LEFT_OUTER, on=Adresse.adresse == Stock.adresse)
            .alias("adr")
            .switch(Stock)
            .join(CodeBlocage, JOIN.LEFT_OUTER, on=CodeBlocage.valeur == Adresse.bloque)
            .join(Product, on=Stock.reference == Product.reference)
            .where(Product.pk == ref)
        )
        if v_filter_by_magasin:
            # A second .where() is ANDed with the product filter.
            filtered_stocks_query = filtered_stocks_query.where(Stock.magasin << v_filter_by_magasin.split(","))

        filtered_stocks = list(filtered_stocks_query.order_by(Stock.adresse, Stock.date_peremption))

        logger.debug("Lines : %s.", len(filtered_stocks))

        return {
            "data": [
                {
                    "pk": s.pk,
                    "reference": s.reference,
                    "format": s.adresse.format if type(s.adresse) is not str else "",
                    "lock_code": s.codeblocage.valeur if hasattr(s, "codeblocage") else 0,
                    "lock_custom_msg": s.adresse.bloque_message if type(s.adresse) is not str else "",
                    "lock_msg": s.codeblocage.libelle if hasattr(s, "codeblocage") else "",
                    "bloque": ("OUI" if s.bloque else "NON"),
                    "emplacement": s.adresse.adresse if type(s.adresse) is not str else s.adresse,
                    "quantite": s.quantite,
                    "ucd": s.ucd,
                    "gtin": s.cip,
                    "date_sortie": str(s.date_sortie or "-"),
                    "date_entree": str(s.date_entree or "-"),
                    "lot": s.lot,
                    "date_peremption": str(s.date_peremption).split(" ")[0],
                    "contenant": s.contenant,
                    "fraction": s.fraction,
                    "qte_blister_bac": s.qte_blister_bac,
                    "nb_dose_cut": s.nb_dose_cut,
                    "du": s.DU,
                    "qte_prod_blister": s.qte_prod_blister,
                    "magasin": s.magasin,
                    "serial": s.serial,
                }
                for s in filtered_stocks
            ]
        }, HTTP_200_OK

    except Exception as error:
        # The message needs a placeholder, otherwise logging cannot format the args.
        logger.exception("Get stock Datatables raised an exception: %s", error.args)
        return {"message": error.args}, HTTP_500_INTERNAL_SERVER_ERROR


@stock_blueprint.route("serial/<string:serial>", methods=["POST"])
@jwt_required()
def get_stock_by_serial(serial):
    """Return the stock stored in every container that held unit dose ``serial``.

    All ``UnitDose`` rows with this serial are collected, most recent first.
    The stock lines stored in their containers are returned under ``"data"``
    (ordered by address then expiry date). Each distinct container code is
    returned under ``"containers"`` with its total stocked quantity.

    Returns 404 ``{"message": "ui.stock.unknown"}`` when the serial is unknown
    or none of its containers holds stock, and 500 with the error text on any
    other failure.
    """
    try:
        unit_doses = list(UnitDose.select().where(UnitDose.serial == serial).order_by(-UnitDose.pk))

        if not unit_doses:
            return {"message": "ui.stock.unknown"}, HTTP_404_NOT_FOUND

        # Distinct container codes, keeping most-recent-first order: several unit
        # doses can share a container, and each container is reported only once.
        container_codes = list(dict.fromkeys(ud.contenant for ud in unit_doses))

        stock = list(
            Stock.select(Stock, Adresse, Magasin.eco_type)
            .join(Magasin, JOIN.INNER, on=(Magasin.mag == Stock.magasin).alias("magasinmodel"))
            # The ON clause must be an expression; an uncalled `.alias` would hand
            # peewee a bound method. Without an alias the matched Adresse row is
            # attached on `s.adresse`; the serializer also copes with `s.adresse`
            # still being the address string.
            .join_from(Stock, Adresse, JOIN.LEFT_OUTER, on=(Adresse.adresse == Stock.adresse))
            .where(Stock.contenant.in_(container_codes))
            .order_by(Stock.adresse, Stock.date_peremption)
        )

        if not stock:
            return {"message": "ui.stock.unknown"}, HTTP_404_NOT_FOUND

        data = [
            {
                "pk": s.pk,
                "reference": s.reference,
                "format": s.adresse.format if type(s.adresse) is not str else "",
                "lock_code": s.codeblocage.valeur if hasattr(s, "codeblocage") else 0,
                "lock_custom_msg": s.adresse.bloque_message if type(s.adresse) is not str else "",
                "lock_msg": s.codeblocage.libelle if hasattr(s, "codeblocage") else "",
                "bloque": ("OUI" if s.bloque else "NON"),
                "emplacement": s.adresse.adresse if type(s.adresse) is not str else s.adresse,
                "quantite": s.quantite,
                "ucd": s.ucd,
                "cip": s.cip,
                "date_sortie": str(s.date_sortie or "-"),
                "date_entree": str(s.date_entree or "-"),
                "lot": s.lot,
                "date_peremption": str(s.date_peremption).split(" ")[0],
                "contenant": s.contenant,
                "fraction": s.fraction,
                "qte_blister_bac": s.qte_blister_bac,
                "nb_dose_cut": s.nb_dose_cut,
                "du": s.DU,
                "qte_prod_blister": s.qte_prod_blister,
                "magasin": s.magasin,
                "serial": s.serial,
                # NOTE(review): the joined store object may be absent when eco_type
                # is NULL; report None rather than failing the whole request.
                "eco_type": s.magasinmodel.eco_type if hasattr(s, "magasinmodel") else None,
                "adr_state": s.adresse.etat if hasattr(s.adresse, "etat") else "",
            }
            for s in stock
        ]

        # SUM() is NULL when a container holds no stock; report 0 instead.
        containers = [
            {"code": c, "quantity": Stock.select(fn.SUM(Stock.quantite)).where(Stock.contenant == c).scalar() or 0}
            for c in container_codes
        ]

    except Exception as error:
        logger.exception(f"Error retrieving stock by serial: {error}")
        return {"message": str(error)}, HTTP_500_INTERNAL_SERVER_ERROR

    return {"data": data, "containers": containers}, HTTP_200_OK


@stock_blueprint.route("<string:ref>", methods=["POST"])
@jwt_required()
def create_stock(ref):
    """Create a stock line for product reference ``ref`` from form data.

    Expected form fields: ``adresse``, ``quantite``, ``lot``,
    ``date_peremption`` (``YYYY-MM-DD`` or ``DD-MM-YYYY``), ``containerCode``,
    ``magasin``, ``ucd``, ``fraction`` and ``service``.

    Besides the ``Stock`` row, this marks the address occupied (``"O"``) with
    the container code, writes an ``"ENT"`` movement to ``Historique`` and an
    ``"E"`` movement to ``Gpao``. All writes share one transaction, so a
    failure part-way leaves nothing behind.

    Returns ``"Success"``, or ``({"message": ...}, 503)`` when the address
    already holds another reference or any step fails.
    """
    logger.info("Création d'une ligne de stock...")

    args = request.form
    _dp = _convertUserFriendlyDateToDBDateTime(args["date_peremption"])
    _ucd = args["ucd"].strip()

    try:
        with mysql_db.atomic():
            # Next FIFO rank for this reference (1 when it has no stock yet).
            _calc_fifo = Stock.select((fn.IFNULL(fn.MAX(Stock.id_fifo), 0) + 1).alias("new_fifo")).where(
                Stock.reference == ref
            )
            new_fifo = _calc_fifo[0].new_fifo
            logger.info('Compute FIFO : "%s"', new_fifo)

            # Smallest pass-box quantity declared for this UCD; 30 when none is set.
            # Compare with the variable, not with the literal string "_ucd".
            _calc_capa = Cip.select(fn.MIN(Cip.qt_pass).alias("qt_pass")).where(Cip.ucd == _ucd)
            _cal = _calc_capa[0].qt_pass or 30
            logger.info('COmpute PASSbox quantity (qt_pass) : "%s"', _cal)

            # Refuse an address that already holds stock of another reference.
            if Stock.select().where((Stock.adresse == args["adresse"]) & (Stock.reference != ref)).exists():
                logger.warning("This location is actually occupied")
                return {"message": "This location is actually occupied"}, HTTP_503_SERVICE_UNAVAILABLE

            logger.debug("Ajout ligne de stock, dans la table f_stock")
            Stock.create(
                adresse=args["adresse"],
                reference=ref,
                quantite=args["quantite"],
                id_fifo=new_fifo,
                lot=args["lot"],
                date_peremption=_dp,
                date_entree=str(datetime.datetime.now()).split(".")[0],
                contenant=args["containerCode"],
                magasin=args["magasin"],
                ucd=_ucd,
                fraction=args["fraction"],
                capa=_cal,
            )

            logger.debug(
                'Mise à jour adresse "%s", état O & code contenant: "%s"', args["adresse"], args["containerCode"]
            )
            (
                Adresse.update({Adresse.etat: "O", Adresse.contenant: args["containerCode"]})
                .where(Adresse.adresse == args["adresse"])
                .execute()
            )

            # Total stock of the reference, including the line just created.
            qte_tot = Stock.select(fn.SUM(Stock.quantite)).where(Stock.reference == ref).scalar()

            logger.debug("Ecriture dans la table historique")
            Historique.create(
                chrono=datetime.datetime.now(),
                reference=ref,
                adresse=args["adresse"],
                magasin=args["adresse"][:3],
                quantite_mouvement=args["quantite"],
                quantite_totale=qte_tot,
                quantite_picking=qte_tot,
                service=args["service"],
                type_mouvement="ENT",
                # The whole batch number (indexing [0] kept only its first character).
                lot=args["lot"],
                pmp=0,
                date_peremption=_dp,
                contenant=args["containerCode"],
                poste=MEDIANWEB_POSTE,
                ucd=_ucd,
                fraction=args["fraction"],
                utilisateur=session["username"],
            )

            # Robot / zone of the store; an unknown store raises and rolls everything back.
            _calc_id_robot = Magasin.select(Magasin.id_robot, Magasin.id_zone).where(Magasin.mag == args["magasin"])
            store = _calc_id_robot[0]

            logger.info("GPAO : Ajout d'un mouvement de type ENTREE")
            Gpao.create(
                chrono=datetime.datetime.now(),
                poste="MEDIANWEB",
                etat="A",
                ref=ref,
                qte=args["quantite"],
                lot=args["lot"],
                type_mvt="E",
                dest=args["service"],
                tperemp=_dp,
                fraction=args["fraction"],
                id_robot=store.id_robot or 1,
                id_zone=store.id_zone or 1,
                user=session["username"],
                magasin=args["adresse"][:3],
            )

    except Exception as error:
        logger.exception("Create stock line failed: %s", error.args)
        return {"message": error.args}, HTTP_503_SERVICE_UNAVAILABLE

    logger.info("Create a stock line: success")
    return "Success"


def _record_stock_exit(stk, reference, address, service, qte_tot, magasin_row):
    """Write the Historique and Gpao exit rows for one deleted stock line.

    ``stk`` is the stock line, ``reference`` the product reference recorded on
    both rows, ``address`` the padded address (its first three characters are
    the store code), ``service`` the destination ward and ``qte_tot`` the
    remaining total quantity for the reference. ``magasin_row`` is the store's
    ``Magasin`` row, or ``None``, in which case robot and zone ids default to 1.
    """
    store_code = address[0:3]

    logger.info("HISTO: Add delete movement")
    Historique.create(
        chrono=datetime.datetime.now(),
        reference=reference,
        adresse=address,
        magasin=store_code,
        quantite_mouvement=stk.quantite,
        quantite_totale=qte_tot,
        service=service,
        type_mouvement=HistoryType.Sortie.value,
        lot=stk.lot,
        pmp=0,
        date_peremption=stk.date_peremption,
        contenant=stk.contenant,
        poste=MEDIANWEB_POSTE,
        ucd=stk.ucd.strip(),
        fraction=stk.fraction,
        serial=stk.serial,
        quantite_picking=qte_tot,
        utilisateur=session["username"],
    )

    logger.info("GPAO: add delete movement.")
    Gpao.create(
        poste=MEDIANWEB_POSTE,
        chrono=datetime.datetime.now(),
        etat="A",
        ref=reference,
        qte=stk.quantite,
        lot=stk.lot.strip(),
        type_mvt="S",
        dest=service,
        tperemp=stk.date_peremption,
        fraction=stk.fraction,
        serial=stk.serial,
        ucd=stk.ucd.strip(),
        magasin=store_code,
        contenant=stk.contenant,
        user=session["username"],
        id_zone=magasin_row.id_zone if magasin_row else 1,
        id_robot=magasin_row.id_robot if magasin_row else 1,
    )


@stock_blueprint.route("<string:ref_pk>", methods=["DELETE"])
@jwt_required()
def delete_stock(ref_pk):
    """Delete stock at an address and record the matching exit movements.

    ``ref_pk`` is the product primary key. Query-string arguments:

    * ``pk``: the selected stock line.
    * ``service``: the destination ward code.
    * ``adresse``: the stock address.
    * ``is_empty_container``: ``"false"`` to pull the next pass-box forward once
      the address is emptied.

    The address is marked free (``"L"``) unless its state is Multiple. For a
    Multiple address, or a stock line stored at store level (address equal to
    its store code), only the selected line is deleted. Otherwise every stock
    line at the address is deleted. Each deleted line gets ``Historique`` and
    ``Gpao`` exit rows, and all writes share one transaction.

    Returns ``"Success"``, ``({"alertMessage": ...}, 503)`` for unknown
    stock/product/ward, or ``({"message": error.args}, 503)`` on failure.
    """
    args = request.args
    _pk = args["pk"]
    _service = args["service"]

    try:
        current_stk = Stock.select().where(Stock.pk == _pk).get()
    except DoesNotExist:
        return {"alertMessage": "ui.stock.unknown"}, HTTP_503_SERVICE_UNAVAILABLE

    try:
        reference = Product.select(Product.reference).where(Product.pk == ref_pk).get()
    except DoesNotExist:
        return {"alertMessage": "ui.reference.unknown"}, HTTP_503_SERVICE_UNAVAILABLE

    logger.info('Delete stock line, pk: "%s"', _pk)

    try:
        Service.select().where(Service.code == _service).get()
    except DoesNotExist:
        return {"alertMessage": "ui.ward.unknown"}, HTTP_503_SERVICE_UNAVAILABLE

    _padded_adr = _padAddressField(args["adresse"])
    if not _padded_adr:
        logger.warning("Error when formating address")
        return "Error when formating address", HTTP_503_SERVICE_UNAVAILABLE

    _adrs = Stock.select().where(Stock.adresse == _padded_adr)

    try:
        with mysql_db.atomic():
            # Read the state BEFORE freeing the address: once it is set to "L"
            # the Multiple check below could never match.
            current_addr = Adresse.get_or_none(Adresse.adresse == _padded_adr)
            is_multiple = current_addr is not None and current_addr.etat == EtatAdresse.Multiple.value

            if not is_multiple:
                logger.debug('Update address to become free: "%s"', _padded_adr)
                Adresse.update({Adresse.etat: "L"}).where(Adresse.adresse == _padded_adr).execute()

            # Store of the address, for the GPAO robot/zone ids (None when unknown).
            magasin_row = Magasin.get_or_none(mag=_padded_adr[0:3])

            if is_multiple or current_stk.adresse == current_stk.magasin:
                logger.info(f"Delete the stock line for {_padded_adr}")
                # Remaining total for the reference once this line is gone.
                qte_tot = Stock.select(fn.SUM(Stock.quantite)).where(Stock.reference == current_stk.reference).scalar()
                qte_tot = (qte_tot or 0) - (current_stk.quantite or 0)
                logger.debug(f"Total quantity {qte_tot}")
                _record_stock_exit(current_stk, current_stk.reference, _padded_adr, _service, qte_tot, magasin_row)
                current_stk.delete_instance()
            else:
                logger.info(f"Delete stock lines at address {_padded_adr}...")
                deleted = list(_adrs)
                for st in deleted:
                    st.delete_instance()
                    qte_tot = (
                        Stock.select(fn.SUM(Stock.quantite)).where(Stock.reference == reference.reference).scalar()
                    )
                    _record_stock_exit(st, reference.reference, _padded_adr, _service, qte_tot, magasin_row)

                # Once the address is empty, pull the next container of the lane forward.
                if deleted and _adrs.count() == 0 and args["is_empty_container"] == "false":
                    logger.info("Move passbox from .3 to .1...")
                    _moveContainerForward(_padded_adr)

    except Exception as error:
        logger.exception("Delete stock line failed: %s", error.args)
        return {"message": error.args}, HTTP_503_SERVICE_UNAVAILABLE

    logger.info('Delete stock line with success, pk: "%s"', _pk)
    return "Success"


@stock_blueprint.route("fix_by_serial", methods=["POST"])
@jwt_required()
def fix_stock_by_serial():
    """Recreate a stock line in store "SHV" for a serial number (repair tool).

    WARNING: DO NOT USE UNLESS YOU KNOW WHAT YOU ARE DOING. The stock entry is
    rebuilt with an estimated quantity: the CIP pass-box quantity, or the
    number of serials recorded for the container. Serials of doses after the
    first one are not known, so they will be missing in ``UnitDose``.

    Expected JSON body:

    * ``serial``
    * ``gtin``: numeric; its digits are matched as a substring of ``Cip.ucd``.
    * ``lot``
    * ``perem``: a 6-character expiry date, see ``_parse_dates``.

    Everything runs in one transaction that is rolled back on any error.
    Returns 200 on success, or 500 with an ``error`` message otherwise.
    """
    try:
        data = request.json
        serial = data["serial"]
        gtin = data["gtin"]
        batch = data["lot"]
        perem = data["perem"]
    except (TypeError, KeyError, ValueError):
        # TypeError: no JSON object in the body; KeyError: a required field is missing.
        logger.error("Invalid JSON data")
        return {"error": "Invalid JSON data"}, HTTP_500_INTERNAL_SERVER_ERROR

    with mysql_db.atomic() as transaction:
        try:
            # Validate the expiry date before any row or counter is touched.
            perem_date = _parse_dates(perem)
            current_year = datetime.datetime.now().year

            perem_year = perem_date.year
            if perem_year < current_year:
                raise ValueError(f"Parsed expiry year {perem_year} is older than current year {current_year}")
            if perem_year > current_year + 20:
                raise ValueError(f"Parsed expiry year {perem_year} is more than 20 years in the future")

            # Resolve the product: the CIP with the highest dossier whose UCD
            # contains the GTIN digits (leading zeros dropped by int()).
            cip: Cip = Cip.select().where(Cip.ucd.contains(str(int(gtin)))).order_by(-Cip.dossier).get()
            ucd: Ucd = Ucd.get(Ucd.ucd == cip.ucd)

            # Reuse the container of the latest unit dose with this serial, or
            # register the serial in a brand-new container.
            original_container: UnitDose = (
                UnitDose.select().where(UnitDose.serial == serial).order_by(-UnitDose.pk).get_or_none()
            )

            if original_container is None:
                logger.debug(f"Serial {serial} not found in UnitDose, creating a new one.")
                new_container_code = get_counter("CONTENANT_PASS")
                new_container_code = str(new_container_code).zfill(9)
                newUD: UnitDose = UnitDose()
                newUD.serial = serial
                newUD.contenant = new_container_code
                newUD.chrono = datetime.datetime.now()
                newUD.save()
                container_code = newUD.contenant
            else:
                container_code = original_container.contenant

            max_fifo = Stock.select(fn.MAX(Stock.id_fifo)).where(Stock.reference == ucd.reference).scalar()

            new_stock: Stock = Stock()
            new_stock.date_peremption = perem_date.strftime("%Y-%m-%d 00:00:00")
            new_stock.adresse = "SHV"
            new_stock.reference = ucd.reference
            new_stock.quantite = cip.qt_pass
            new_stock.id_fifo = max_fifo + 1 if max_fifo else 1
            new_stock.lot = batch
            new_stock.date_entree = datetime.datetime.now()
            new_stock.contenant = container_code
            new_stock.magasin = "SHV"
            new_stock.ucd = ucd.ucd
            new_stock.cip = cip.cip
            new_stock.serial = serial
            new_stock.fraction = 100
            new_stock.capa = cip.qt_pass
            new_stock.save()

            # Analyse the history and the unitdose table to find a quantity
            last_entry_histo: Historique = (
                Historique.select()
                .where(
                    (Historique.contenant == container_code)
                    & (Historique.type_mouvement == HistoryType.Entree.value)
                    & (Historique.info == "CREATION BAC TAMPON")
                )
                .order_by(-Historique.pk)
                .get_or_none()
            )

            if last_entry_histo:
                try:
                    tot_qty_mvt = (
                        Historique.select(fn.SUM(Historique.quantite_mouvement).alias("tot_qte_mvt"))
                        .where(
                            (Historique.lot == last_entry_histo.lot)
                            & (Historique.pk_item == last_entry_histo.pk_item)
                            & (Historique.type_mouvement == HistoryType.Entree.value)
                            & (Historique.contenant == container_code)
                        )
                        .group_by(Historique.contenant)
                        .get()
                        .tot_qte_mvt
                    )

                    number_of_serials = UnitDose.select().where((UnitDose.contenant == container_code)).count()

                    # SUM() can be NULL; treat it as 0 rather than failing the comparison.
                    if number_of_serials > 0 and (tot_qty_mvt or 0) > 0:
                        # We prefer using the number of serials. Any movement could have happened after the last entry.
                        # The serial list should be more accurate.
                        new_stock.quantite = number_of_serials
                        new_stock.save()
                except DoesNotExist:
                    logger.warning(
                        f"Failed to calculate total quantity for lot {last_entry_histo.lot} "
                        f"and item {last_entry_histo.pk_item}. Using default cip.qt_pass value."
                    )

            reqTotal = (
                Stock.select(fn.IFNULL(fn.SUM(Stock.quantite), 0).alias("total"))
                .where(Stock.reference == ucd.reference)
                .get()
            )

            # Histo
            histo: Historique = Historique()
            histo.chrono = datetime.datetime.now()
            histo.reference = ucd.reference
            histo.adresse = new_stock.adresse
            histo.quantite_mouvement = new_stock.quantite
            histo.quantite_totale = reqTotal.total
            histo.service = "TRAN"
            histo.type_mouvement = HistoryType.Inventaire.value
            histo.lot = batch
            histo.date_peremption = new_stock.date_peremption
            histo.contenant = new_stock.contenant
            histo.poste = MEDIANWEB_POSTE
            histo.ucd = ucd.ucd
            histo.magasin = new_stock.magasin
            histo.fraction = new_stock.fraction
            histo.serial = serial
            histo.info = "Création par inventaire"
            histo.utilisateur = session["username"]
            histo.save()

            logger.info(f"Stock created with ID: {new_stock.pk} and Serial: {serial}")

            return {"message": "Stock updated successfully"}, HTTP_200_OK

        except Exception as e:
            transaction.rollback()
            logger.exception(e)
            return {"error": str(e)}, HTTP_500_INTERNAL_SERVER_ERROR


def _parse_dates(date_str) -> datetime.datetime:
    """Parse a 6-character expiry date, trying ``YYMMDD`` first, then ``YYDDMM``.

    A ``-`` or ``=`` found in the month characters (positions 2-3) is removed
    and the remainder is left-padded with ``0``, so ``"24-531"`` is read as
    ``"240531"``.

    Raises ``ValueError`` when the input is not 6 characters long or matches
    neither format.
    """
    if len(date_str) != 6:
        raise ValueError(f"Invalid date format: {date_str}")

    middle = date_str[2:4]
    if "-" in middle or "=" in middle:
        cleaned = middle.replace("-", "").replace("=", "").zfill(2)
        normalized = date_str[:2] + cleaned + date_str[4:]
        logger.info(f"Converted date string from {date_str} to {normalized}")
        date_str = normalized

    for pattern in ("%y%m%d", "%y%d%m"):
        try:
            return datetime.datetime.strptime(date_str, pattern)
        except ValueError:
            pass

    raise ValueError(f"Invalid date format: {date_str}")
