"""
GPAO Service for AidePlus
Handles GPAO (Production Management) operations
"""

import logging
from peewee import JOIN, fn
from flask import session
from datetime import datetime

from median.models import Gpao, Patient, Product, ListeValide, ItemValide, Service
from median.constant import TypeMouvementGpao, TypeEtatGpao, MEDIANWEB_POSTE, MYSQL_ZERO_DATE, CONFIG_WEB_CLE
from common.status import HTTP_200_OK, HTTP_500_INTERNAL_SERVER_ERROR, HTTP_400_BAD_REQUEST
from common.models import WebLogActions
from median.database import mysql_db
from median.views import RawConfig

from .stock_service import StockService

# Module-level logger; records are emitted under the "median.gpao" logger name.
logger = logging.getLogger("median.gpao")

# Configuration entry "k_web_gpao_management", read once when this module is imported.
# GPAOService.update_gpao_status() tests its truthiness and its ``value`` attribute to
# choose the state (DONE vs DRAFT2) of the extra GPAO lines it inserts.
gpao_management = RawConfig(MEDIANWEB_POSTE, CONFIG_WEB_CLE).read("k_web_gpao_management")


class GPAOService:
    """Service for handling GPAO operations.

    Both public methods return a ``(payload, http_status)`` tuple and never let an
    exception escape: failures are logged and reported as an ``{"error": ...}`` payload.
    """

    def __init__(self, stock_service):
        """Keep a reference to the stock service used to enrich the fetched GPAO lines."""
        self.stock_service: StockService = stock_service

    def fetch_gpao_data(self, ward_filter=None, id_chargement=None, chrono=None, gpao_type=None):
        """Return the pending GPAO lines of one movement type, grouped per product and patient stay.

        :param ward_filter: accepted for interface compatibility, not used by this method
        :param id_chargement: loading identifier the validated lists are filtered on
        :param chrono: accepted for interface compatibility, not used by this method
        :param gpao_type: movement type; must be the COMPLEMENT or the OUTPUT value of
            ``TypeMouvementGpao``
        :return: ``({"gpao": [...]}, HTTP_200_OK)`` on success,
            ``({"error": ...}, HTTP_400_BAD_REQUEST)`` for an unsupported ``gpao_type``,
            ``({"error": ...}, HTTP_500_INTERNAL_SERVER_ERROR)`` on any other failure
        """
        try:
            if gpao_type not in [TypeMouvementGpao.COMPLEMENT.value, TypeMouvementGpao.OUTPUT.value]:
                logger.error(f"Invalid gpaoType: {gpao_type}")
                return {"error": "Invalid gpaoType"}, HTTP_400_BAD_REQUEST

            query = (
                ItemValide.select(
                    Gpao,
                    Patient,
                    Product,
                    ListeValide,
                    ItemValide,
                    Service,
                    fn.SUM(Gpao.qte).alias("calc_quantite_serv"),
                    fn.MAX(Gpao.solde).alias("solde"),
                    fn.GROUP_CONCAT(Gpao.pk.distinct()).alias("gpao_pks"),
                    fn.GROUP_CONCAT(Gpao.item_wms.distinct()).alias("item_wms_ids"),
                )
                .join(ListeValide, JOIN.INNER, on=(ItemValide.liste_pk == ListeValide.pk))
                .join_from(
                    ItemValide,
                    Gpao,
                    JOIN.INNER,
                    on=(
                        (ItemValide.item_wms == Gpao.item_wms)
                        # TODO: this _2 works but it is not solid. find something better
                        # The source is that some products are not available at the start of the picking
                        # and a gpao line with D will be created. But then while picking, a second list can
                        # appear, with this _2. And then we have a misalignment between the gpao and the actual
                        # item_valide line
                        & ((ListeValide.liste == Gpao.liste) | (ListeValide.liste == Gpao.liste + "_2"))
                        & (Gpao.item == ItemValide.item)
                        & (Gpao.ref == ItemValide.reference)
                        & (Gpao.fraction == ItemValide.fraction)
                    ),
                )
                .join_from(ItemValide, Product, JOIN.INNER, on=(Product.reference == ItemValide.reference))
                .join_from(ListeValide, Patient, JOIN.LEFT_OUTER, on=(Patient.ipp == ListeValide.num_ipp))
                .join_from(ListeValide, Service, JOIN.INNER, on=(Service.code == ListeValide.service))
                .where(
                    (
                        # NOTE(review): Gpao is INNER-joined above, so the is_null() branch looks
                        # unreachable — confirm before removing it.
                        Gpao.pk.is_null()
                        | (
                            (Gpao.type_mvt == gpao_type)
                            & (Gpao.etat << [TypeEtatGpao.DRAFT.value, TypeEtatGpao.DRAFT2.value])
                        )
                    )
                )
            )

            if gpao_type != TypeMouvementGpao.COMPLEMENT.value:
                # Output lines: same pillbox on both sides and restricted to the requested loading
                query = query.where(
                    (Gpao.id_pilulier == ItemValide.id_pilulier) & (ListeValide.id_chargement == id_chargement)
                )
            else:
                # Complement lines: lists without a loading are accepted as well
                query = query.where(
                    (ListeValide.id_chargement.is_null()) | (ListeValide.id_chargement == id_chargement)
                )

            query = query.group_by(
                ItemValide.reference, ItemValide.fraction, ListeValide.num_ipp, ListeValide.num_sej
            ).order_by(+Gpao.ref, +Gpao.ipp)

            gpao_data = []
            for item in query:
                stock_data = self.stock_service.get_stock_info(item.reference, item.fraction)

                # Total requested quantity of this reference for the loading and the pillbox of the row
                qte_dem_query = (
                    ItemValide.select(fn.SUM(ItemValide.quantite_dem))
                    .join(ListeValide, JOIN.INNER, on=(ItemValide.liste_pk == ListeValide.pk))
                    .where(
                        (ItemValide.reference == item.reference) &
                        (ListeValide.id_chargement == id_chargement) &
                        (ItemValide.id_pilulier == item.id_pilulier)
                    )
                )
                qte_dem_result = qte_dem_query.scalar()
                qte_dem = qte_dem_result if qte_dem_result is not None else 0

                liste = item.listevalide
                # NOTE(review): Patient comes from a LEFT OUTER join; peewee may leave the attribute
                # unset when no patient row matched. getattr() keeps that case on the "no patient"
                # path instead of raising AttributeError — confirm against the model.
                patient = getattr(liste, "patient", None)

                gpao_data.append(
                    {
                        "id": len(gpao_data),
                        "gpao_ids": str(item.gpao_pks).split(","),  # An array of all the GPAO IDs used in the group
                        "item_wms_ids": str(item.item_wms_ids).split(
                            ","
                        ),  # An array of all the item WMS used in the group
                        "reference": item.reference,
                        "designation": item.product.designation,
                        "product_pk": item.product.pk,
                        "reap_mode": item.product.reap_mode,
                        "quantity": item.calc_quantite_serv,
                        "quantity_dem": qte_dem,
                        "fraction": item.fraction,
                        "ward": liste.service.code if liste.service else "",
                        "stock_data": stock_data,
                        "state": item.gpao.etat,
                        "type": item.gpao.type_mvt,
                        "solde": item.solde,
                        "patient": {
                            "ipp": liste.num_ipp,
                            "first_name": patient.prenom if patient else "",
                            "last_name": patient.nom if patient else "",
                            "maiden_name": patient.nom_jeune_fille if patient else "",
                            "pillbox": item.id_pilulier if item.id_pilulier else "",
                            "birthdate": (
                                patient.date_naissance.isoformat()
                                if patient and patient.date_naissance
                                else None
                            ),
                            "stay": liste.num_sej,
                        },
                    }
                )

            return {"gpao": gpao_data}, HTTP_200_OK

        except Exception as e:
            # logger.exception keeps the traceback next to the message
            logger.exception("Error fetching GPAO data: %s", e)
            return {"error": f"Failed to fetch GPAO data: {str(e)}"}, HTTP_500_INTERNAL_SERVER_ERROR

    def update_gpao_status(self, products, mvt_type):
        """Move the GPAO lines of several products to the Draft2 state inside one transaction.

        A product carrying ``gpao_ids`` (output) has every listed line set to Draft2 with
        ``solde = 1``; one extra zero-quantity line, copied from the first listed line, is then
        inserted per entry of ``item_wms_ids``. A product without ``gpao_ids`` (WMS request) is a
        single line identified by ``product["id"]``.

        :param products: iterable of dict-like products (keys read: ``gpao_ids``,
            ``item_wms_ids``, ``id``, ``ward``)
        :param mvt_type: movement type, only used in the audit log message
        :return: ``({"message": ..., "updated_count": n}, HTTP_200_OK)`` on success,
            ``({"error": ...}, HTTP_500_INTERNAL_SERVER_ERROR)`` after a rollback on failure
        """
        with mysql_db.atomic() as transaction:
            try:
                updated_count = 0
                # Accumulated over every product so the audit entry reports the whole request
                completed_quantity = 0
                last_product = None

                # The state of the inserted lines only depends on the configuration: compute it once
                if gpao_management and gpao_management.value:
                    # If management is enabled, we set the state to DONE
                    # Some clients (ie: Brugge) handles the sending of gpao items without using x_solde = 1
                    followup_state = TypeEtatGpao.DONE.value
                else:
                    followup_state = TypeEtatGpao.DRAFT2.value

                for product in products:
                    last_product = product

                    if product.get("gpao_ids"):
                        # This comes from the output, it can have several gpao lines

                        gpao_ids = product.get("gpao_ids", [])
                        for pk in gpao_ids:
                            gpao_item: Gpao = Gpao.get(Gpao.pk == pk)
                            gpao_item.etat = TypeEtatGpao.DRAFT2.value
                            gpao_item.solde = 1
                            gpao_item.save()
                            completed_quantity += gpao_item.qte
                            updated_count += 1

                        # Template for the inserted lines: the first line of the group, as saved above
                        last_gpao: Gpao = Gpao.get(Gpao.pk == gpao_ids[0])

                        item_wms_ids = product.get("item_wms_ids", [])
                        for item_wms in item_wms_ids:
                            gpao_data = dict(last_gpao.__data__)
                            gpao_data.pop("pk", None)  # Remove primary key

                            gpao_data.update(
                                {
                                    "etat": followup_state,
                                    "item_wms": item_wms,
                                    "lot": "",
                                    "solde": 1,
                                    "qte": 0,
                                    "qte_dem": 0,
                                    "serial": 0,  # TODO: Mirth won't send anything if there is no serial.
                                    # Review and maybe use gpao_management to handle this
                                    "user": session["username"],
                                    "poste": MEDIANWEB_POSTE,
                                    "chrono": datetime.now(),
                                    "tperemp": MYSQL_ZERO_DATE,
                                    "tenvoi": last_gpao.tenvoi if last_gpao.tenvoi else MYSQL_ZERO_DATE,
                                    "tentree": last_gpao.tentree if last_gpao.tentree else MYSQL_ZERO_DATE,
                                }
                            )
                            new_gpao = Gpao(**gpao_data)
                            new_gpao.save(force_insert=True)

                    else:
                        # This is an WMS product request, it is only one item
                        gpao_item: Gpao = Gpao.get(Gpao.pk == product["id"])
                        gpao_item.etat = TypeEtatGpao.DRAFT2.value
                        gpao_item.save()
                        completed_quantity += gpao_item.qte
                        updated_count += 1

                if last_product is not None:
                    # Nothing to audit for an empty request. A missing ward must not roll back
                    # the status update just because of the log message, hence the default.
                    ward = last_product.get("ward", "")
                    log_completion(
                        session["username"],
                        "update",
                        f"Updated GPAO items to status Draft2 (B): {completed_quantity}, "
                        f"for ward {ward}, type {mvt_type}",
                    )

                logger.info(f"Updated {updated_count} GPAO items to status Draft2 (B)")
                return {"message": f"Updated {updated_count} GPAO items", "updated_count": updated_count}, HTTP_200_OK

            except Exception as e:
                # The exception is swallowed here, so the transaction has to be rolled back explicitly
                transaction.rollback()
                logger.exception("Error updating GPAO status: %s", e)
                return {"error": f"Failed to update GPAO status: {str(e)}"}, HTTP_500_INTERNAL_SERVER_ERROR


def log_completion(username: str, action: str, message: str):
    """
    Record a completion event (aideplus) in the module log and in the web action log

    :param username: user who performed the action
    :param action: short label of the action, stored with the entry
    :param message: text describing what was done
    """
    log_line = "Completion[%s](%s)): %s" % (action, username, message)
    logger.info(log_line)

    entry = WebLogActions()
    # Same fields, same assignment order as a plain attribute-by-attribute fill
    values = {
        "chrono": datetime.now(),
        "username": username,
        "equipement_type": "",
        "action": action,
        "message": message,
    }
    for attribute, value in values.items():
        setattr(entry, attribute, value)
    entry.save()
