"""
MQTT Handler Module - Handles MQTT functionality for the application
"""

import json
import logging
import re
from typing import Optional, Dict, Any

from flask import Flask
from flask_mqtt import Mqtt

from core.websocket_handler import websocket_handler

logger = logging.getLogger("median.mqtt_handler")


class MQTTHandler:
    """
    MQTT Handler class to centralize MQTT functionality.

    Wraps a Flask-MQTT client, tracks subscribed topic filters so they can be
    restored after a reconnect, and dispatches incoming messages to callbacks
    registered per topic filter (exact topics or MQTT wildcard filters using
    the single-level ``+`` and multi-level ``#`` wildcards).
    """

    # Topic filter used by the device-message bridge to WebSockets.
    _DEVICE_TOPIC_FILTER = "devices/#"

    def __init__(self):
        """Initialize MQTT handler"""
        self.mqtt = None
        # Topic filters subscribed so far; re-subscribed on every successful connect.
        self.subscribed_topics = []
        # topic filter -> callable(topic, payload); one callback per filter.
        self.message_callbacks = {}
        # True once the blueprint's default topics were subscribed successfully.
        self._defaults_subscribed = False

    def init_app(self, app: "Flask") -> "Mqtt":
        """
        Create the Flask-MQTT client for ``app`` and register the handlers.

        Args:
            app: The Flask application whose config drives the MQTT client.

        Returns:
            The ``flask_mqtt.Mqtt`` instance.
        """
        logger.info("Initializing MQTT handler")

        self.mqtt = Mqtt(app)
        # Flask-MQTT keeps only the most recently decorated on_connect handler, so
        # all connect-time work (re-subscribing, default topics, device bridge)
        # lives in the single handler registered by _register_callbacks().
        self._register_callbacks()

        return self.mqtt

    def get_mqtt(self) -> Optional["Mqtt"]:
        """Return the Flask-MQTT client, or None before init_app() was called."""
        return self.mqtt

    @staticmethod
    def _topic_matches(subscription: str, topic: str) -> bool:
        """
        Return True if ``topic`` matches the MQTT topic filter ``subscription``.

        ``+`` matches exactly one topic level and ``#`` matches the parent level
        and any number of levels below it. Per the MQTT spec, topics starting
        with ``$`` are not matched by a filter beginning with a wildcard.
        """
        if subscription == topic:
            return True

        sub_levels = subscription.split("/")
        topic_levels = topic.split("/")

        if topic.startswith("$") and sub_levels[0] in ("+", "#"):
            return False

        for index, sub_level in enumerate(sub_levels):
            if sub_level == "#":
                return True
            if index >= len(topic_levels):
                return False
            if sub_level != "+" and sub_level != topic_levels[index]:
                return False
        return len(sub_levels) == len(topic_levels)

    def _register_callbacks(self):
        """Register the connect and message handlers on the Flask-MQTT client."""
        if self.mqtt is None:
            logger.error("Cannot register MQTT callbacks: MQTT not initialized")
            return

        @self.mqtt.on_connect()
        def handle_connect(client, userdata, flags, rc):
            """Handle MQTT connection event: restore and set up subscriptions."""
            logger.info("MQTT Connected with result code: %s", rc)
            if rc != 0:
                # The broker refused the connection; subscribing would fail.
                logger.warning("MQTT connection refused (result code %s); skipping subscriptions", rc)
                return

            # Re-subscribe to topics tracked before this connect (previous
            # connections, or registrations made before the first connect).
            for topic in list(self.subscribed_topics):
                self.mqtt.subscribe(topic)

            # One-time setup; retried on the next connect until it succeeds. Once
            # done, the topics are tracked and restored by the loop above.
            if not self._defaults_subscribed:
                self._defaults_subscribed = self._subscribe_to_default_topics()
            if self._DEVICE_TOPIC_FILTER not in self.message_callbacks:
                # Set up device message handling (including heartbeats)
                self.handle_device_messages()

        @self.mqtt.on_message()
        def handle_message(client, userdata, message):
            """Dispatch an incoming MQTT message to every matching callback."""
            topic = message.topic
            try:
                payload = message.payload.decode("utf-8")
            except UnicodeDecodeError:
                # Skip undecodable payloads instead of raising out of the handler.
                logger.warning("Ignoring MQTT message on topic %s: payload is not valid UTF-8", topic)
                return

            # Iterate a snapshot so a concurrent register_callback() cannot change
            # the dict while we loop over it.
            for callback_topic, callback_func in list(self.message_callbacks.items()):
                if not self._topic_matches(callback_topic, topic):
                    continue
                try:
                    callback_func(topic, payload)
                except Exception:
                    logger.exception("Error processing MQTT message in callback for topic %s", topic)

            # Device topics (including heartbeats) are handled by the callback
            # registered in handle_device_messages().

    def subscribe(self, topic: str):
        """
        Subscribe to ``topic`` and remember it for re-subscription on reconnect.

        Returns:
            False if MQTT is not initialized, True otherwise.
        """
        if self.mqtt is None:
            logger.error(f"Cannot subscribe to topic {topic}: MQTT not initialized")
            return False

        logger.info("Subscribing to MQTT topic: %s", topic)
        self.mqtt.subscribe(topic)
        if topic not in self.subscribed_topics:
            self.subscribed_topics.append(topic)
        return True

    def publish(self, topic: str, payload: Optional[Dict[str, Any]] = None, qos: int = 0, retain: bool = False) -> bool:
        """
        Publish a message to an MQTT topic

        Args:
            topic: MQTT topic to publish to
            payload: Dictionary to be JSON serialized and published. ``None``
                publishes an empty message; an empty dict is sent as ``{}``.
            qos: Quality of Service level
            retain: Whether to retain the message

        Returns:
            True if the client accepted the message, False otherwise
        """
        if self.mqtt is None:
            logger.error(f"Cannot publish to topic {topic}: MQTT not initialized")
            return False

        try:
            # Only None means "no body": falsy payloads such as {} must still serialize.
            message = "" if payload is None else json.dumps(payload)
            logger.debug("Publishing MQTT message to %s: %s", topic, message)
            # Flask-MQTT returns (result_code, message_id); 0 means success.
            result, _mid = self.mqtt.publish(topic, message, qos, retain)
        except Exception:
            logger.exception("Error publishing MQTT message")
            return False

        if result != 0:
            logger.error("Error publishing MQTT message to %s: result code %s", topic, result)
            return False
        return True

    def register_callback(self, topic: str, callback):
        """
        Register a callback function for a specific topic

        Args:
            topic: MQTT topic or topic filter (``+`` / ``#`` wildcards allowed)
                to receive messages from; replaces any callback already
                registered for the same filter
            callback: Function to call when a message is received (topic, payload)
        """
        self.message_callbacks[topic] = callback
        # Make sure we're subscribed to the topic
        self.subscribe(topic)

    def _subscribe_to_default_topics(self) -> bool:
        """
        Subscribe to the default topics defined in the aideplus blueprint.

        Returns:
            True if both default topics were subscribed, False otherwise.
        """
        try:
            # Imported inside the method rather than at module level
            # (presumably to avoid a circular import at load time - confirm).
            from ressources.aideplus.aideplus_blueprint import SERVER_EVENTS_TOPIC, DEVICE_WILDCARD_TOPIC
        except ImportError:
            logger.warning("Could not import default topics from aideplus_blueprint")
            return False

        try:
            results = [self.subscribe(SERVER_EVENTS_TOPIC), self.subscribe(DEVICE_WILDCARD_TOPIC)]
        except Exception:
            logger.exception("Error subscribing to default topics")
            return False

        if not all(results):
            return False
        logger.info("Subscribed to default topics: %s, %s", SERVER_EVENTS_TOPIC, DEVICE_WILDCARD_TOPIC)
        return True

    def handle_device_messages(self):
        """Forward every ``devices/<device>/<action>`` message to the aideplus WebSockets."""
        # Greedy groups: for deeper topics the device part absorbs the extra
        # levels, e.g. "devices/a/b/c" -> device "a/b", action "c".
        device_topic_re = re.compile(r"^devices/(.+)/(.+)$")

        def on_device_message(topic, payload):
            match = device_topic_re.match(topic)
            if match is None:
                # e.g. "devices" or "devices/<name>" without an action level.
                return
            device_name, action = match.groups()

            try:
                data = json.loads(payload) if payload else {}
            except json.JSONDecodeError as e:
                logger.error("Error handling device message on topic %s: invalid JSON payload (%s)", topic, e)
                return

            message_data = {"device": device_name, "action": action, "topic": topic, "payload": data}
            try:
                # Forward the message to WebSockets
                websocket_handler.broadcast_to_aideplus("mqtt", message_data)
            except Exception:
                logger.exception("Error handling device message on topic %s", topic)

        # Register the callback for all device topics
        self.register_callback(self._DEVICE_TOPIC_FILTER, on_device_message)
        logger.info("Registered handler for device messages")


# Module-level singleton shared by the application. It holds no MQTT client
# until init_mqtt(app) (equivalently mqtt_handler.init_app(app)) is called.
mqtt_handler = MQTTHandler()


def init_mqtt(app: Flask) -> Mqtt:
    """
    Attach the shared MQTT handler to ``app`` and hand back its client.

    Convenience wrapper around ``mqtt_handler.init_app``.

    Args:
        app: The Flask application to bind the MQTT client to.

    Returns:
        The Flask-MQTT client created by the shared handler.
    """
    handler = mqtt_handler
    client = handler.init_app(app)
    return client
