#!/usr/bin/env python3

import time
import subprocess
import logging
import os.path
import signal
import sys
import re
import XenAPI
import threading
from enum import Enum, auto
from typing import Tuple, List, Optional, Dict, Any
import traceback

# Logging setup: every record is sent both to a default StreamHandler and to
# the daemon log file, rendered as "<timestamp> - <level> - <message>".
log_level = logging.INFO
log_format = '%(asctime)s - %(levelname)s - %(message)s'

logging.basicConfig(
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('/var/log/daemon.log'),
    ],
    format=log_format,
    level=log_level,
)

# Module-level logger used by the code below.
logger = logging.getLogger(__name__)

# Constants
class SshState(Enum):
    """State of the sshd service as classified by the monitor."""

    # Values are spelled out; they are the same 1, 2, 3 that auto() assigns.
    DOWN = 1
    ACTIVE = 2
    UNKNOWN = 3

# Anchored at the start of the text: only strings that begin with the
# INSTALLATION_UUID key match.
INSTALLATION_UUID_REGEX = re.compile("^INSTALLATION_UUID")


def match_host_id(s):
    """Return the match object if *s* starts with ``INSTALLATION_UUID``.

    Returns ``None`` otherwise, so the function can be used directly as a
    ``filter`` predicate over the lines of the inventory file.
    """
    # The pattern is already anchored with "^", so matching at position 0
    # is the same as searching from position 0.
    found = INSTALLATION_UUID_REGEX.match(s)
    return found

class XapiMonitor:
    """Keep SSH access to the host in step with the health of the local XAPI.

    Each monitoring cycle runs the XAPI health-check script and queries the
    state of the ``sshd`` unit:

    * XAPI healthy and sshd active: ask XAPI to disable SSH on this host.
    * XAPI unhealthy and sshd not active: allow SSH through firewalld (when
      firewalld is installed), then enable and start sshd directly.

    Creating an instance reads the host UUID from the inventory file and
    installs SIGTERM/SIGINT/SIGHUP handlers that stop the loop. Python only
    allows ``signal.signal`` on the main thread, so instances must be
    created there.
    """

    XAPI_HEALTH_CHECK = '/opt/xensource/libexec/xapi-health-check'

    # Timeouts and intervals, all in seconds.
    HEALTH_CHECK_TIMEOUT = 120
    CHECK_INTERVAL = 60
    ERROR_RETRY_INTERVAL = 5

    # Retry policy used when disabling SSH through XAPI.
    SSH_DISABLE_MAX_RETRIES = 3
    SSH_DISABLE_RETRY_INTERVAL = 5

    def __init__(self):
        """Read the local host UUID and install the termination handlers.

        Raises:
            RuntimeError: if the host UUID cannot be read from the inventory
                file (propagated from ``get_localhost_uuid``).
        """
        self.logger = logging.getLogger(__name__)
        self.running = True
        self.session = None
        self.localhost_uuid = self.get_localhost_uuid()
        # Set by the signal handler so every wait in the loop wakes up at once.
        self.exit_event = threading.Event()
        for signum in (signal.SIGTERM, signal.SIGINT, signal.SIGHUP):
            signal.signal(signum, self._handle_signal)

    def _handle_signal(self, signum, frame):
        """Handle a termination signal by requesting a graceful stop."""
        try:
            signal_name = signal.Signals(signum).name
        except ValueError:
            # Not a signal number known to this platform's Signals enum.
            signal_name = f"Signal {signum}"
        self.logger.info(f"Received {signal_name}, preparing to exit...")
        self.running = False
        # Set event to interrupt any waiting
        self.exit_event.set()

    def _create_session(self) -> Optional[Any]:
        """Create a session with the local XAPI.

        Returns:
            The logged-in session, or ``None`` if creating it or logging in
            failed (the failure is logged, not raised).
        """
        try:
            session = XenAPI.xapi_local()
            session.login_with_password("", "")
            return session
        except Exception as e:
            self.logger.error(f"Create XAPI session failed: {e}")
            return None

    def _logout_session(self) -> None:
        """Log out of the current XAPI session, if any, and forget it.

        Logout errors are logged and swallowed. ``self.session`` is always
        reset to ``None`` afterwards so a logged-out (or broken) handle can
        never be reused by a later call.
        """
        if self.session is None:
            return
        try:
            self.session.logout()
            self.logger.debug("XAPI session logged out")
        except Exception as e:
            self.logger.warning(f"Error during session logout: {e}")
        finally:
            self.session = None

    @staticmethod
    def get_localhost_uuid() -> str:
        """Get the UUID of the local host from the inventory file.

        The first line starting with ``INSTALLATION_UUID`` is used and the
        value is taken from between its first pair of single quotes.

        Returns:
            The UUID string.

        Raises:
            RuntimeError: if the file cannot be read, contains no
                ``INSTALLATION_UUID`` line, or that line has no quoted value.
        """
        filename = '/etc/xensource-inventory'
        log = logging.getLogger(__name__)

        # Only file access is guarded here, so that parse problems below are
        # not misreported as "unable to open".
        try:
            with open(filename, 'r') as f:
                uuid_line = next(filter(match_host_id, f), None)
        except (OSError, UnicodeDecodeError) as e:
            error_msg = f"Unable to open inventory file [{filename}]: {e}"
            log.error(error_msg)
            raise RuntimeError(error_msg) from e

        if uuid_line is None:
            error_msg = f"Could not find INSTALLATION_UUID in {filename}"
            log.error(error_msg)
            raise RuntimeError(error_msg)

        fields = uuid_line.split("'")
        if len(fields) < 2:
            error_msg = (
                f"Malformed INSTALLATION_UUID entry in {filename}: "
                f"{uuid_line.strip()!r}"
            )
            log.error(error_msg)
            raise RuntimeError(error_msg)
        return fields[1]

    def _run_command(self, command: List[str], timeout: int = 10) -> Tuple[int, str, str]:
        """Execute a command and return its result.

        Args:
            command: Command to execute as a list of strings.
            timeout: Command execution timeout in seconds (default: 10).

        Returns:
            Tuple of (return_code, stdout, stderr). When the command times
            out the result is ``(-1, "", "Timeout")``; when it cannot be
            started at all the result is ``(-1, "", <error text>)``.
        """
        self.logger.debug(f"Running command: {' '.join(command)}")
        try:
            process = subprocess.Popen(
                command,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                universal_newlines=True
            )
            try:
                stdout, stderr = process.communicate(timeout=timeout)
                self.logger.debug(f"Command returned: {process.returncode}")
                return process.returncode, stdout, stderr
            except subprocess.TimeoutExpired:
                # Kill the child and reap it so no zombie or open pipe is left.
                process.kill()
                process.communicate()
                self.logger.error(f"Command execution timeout after {timeout}s: {' '.join(command)}")
                return -1, "", "Timeout"
        except Exception as e:
            self.logger.error(f"Error executing command: {e}")
            return -1, "", str(e)

    def _check_xapi_health(self) -> bool:
        """Run the XAPI health-check script with an extended timeout.

        Returns:
            True if the script exits with status 0, False otherwise
            (including timeout or failure to launch).
        """
        self.logger.debug("Performing XAPI health check")
        returncode, stdout, stderr = self._run_command(
            [self.XAPI_HEALTH_CHECK], timeout=self.HEALTH_CHECK_TIMEOUT
        )

        if returncode != 0:
            self.logger.warning(f"XAPI health check failed: {stderr}")

        return returncode == 0

    def _get_ssh_state(self) -> "SshState":
        """Classify the output of ``systemctl is-active sshd``.

        Returns:
            ``SshState.ACTIVE`` for "active", ``SshState.DOWN`` for
            "inactive", "failed" or "unknown", and ``SshState.UNKNOWN`` for
            anything else (which is also logged as a warning).
        """
        returncode, stdout, stderr = self._run_command(['systemctl', 'is-active', 'sshd'])
        status = stdout.strip()

        if status == 'active':
            return SshState.ACTIVE
        if status in ('inactive', 'failed', 'unknown'):
            return SshState.DOWN

        self.logger.warning(f"Unexpected SSH status: {status}, stderr: {stderr}")
        return SshState.UNKNOWN

    def _control_ssh_service(self, enable: bool) -> bool:
        """Start or stop the SSH service directly, bypassing XAPI.

        Enabling allows ssh through firewalld (only when firewalld is
        installed), then enables and starts sshd. Disabling performs the
        reverse steps in the reverse order. Every step is attempted even if
        an earlier one failed.

        Args:
            enable: True to start SSH, False to stop it.

        Returns:
            bool: True if every step succeeded, False otherwise.
        """
        action = "starting" if enable else "stopping"
        try:
            firewall_cmd = '/usr/bin/firewall-cmd'
            use_firewalld = os.path.exists(firewall_cmd)

            def update_firewall(option):
                # Without firewalld there is nothing to change: report success.
                if not use_firewalld:
                    return 0, "n/a"
                ret, _, stderr = self._run_command([firewall_cmd, option, 'ssh'])
                return ret, stderr

            if enable:
                ret0, stderr0 = update_firewall('--add-service')
                ret1, _, stderr1 = self._run_command(['systemctl', 'enable', 'sshd'])
                ret2, _, stderr2 = self._run_command(['systemctl', 'start', 'sshd'])
            else:
                ret2, _, stderr2 = self._run_command(['systemctl', 'stop', 'sshd'])
                ret1, _, stderr1 = self._run_command(['systemctl', 'disable', 'sshd'])
                ret0, stderr0 = update_firewall('--remove-service')

            success = (ret0 == 0 and ret1 == 0 and ret2 == 0)
            if success:
                self.logger.info(f"SSH service {action} successful")
            else:
                err_msg = f"""SSH service {action} failed: enable/disable firewalld service stderr: {stderr0},
                    enable/disable sshd stderr: {stderr1}, start/stop sshd stderr: {stderr2} """
                self.logger.error(err_msg)

            return success
        except Exception as e:
            self.logger.error(f"SSH service {action} failed with exception: {e}")
            self.logger.debug(traceback.format_exc())
            return False

    def _disable_ssh_via_api(self) -> bool:
        """Disable SSH on this host through XAPI, with bounded retries.

        Up to ``SSH_DISABLE_MAX_RETRIES`` attempts are made. Between
        attempts the method waits ``SSH_DISABLE_RETRY_INTERVAL`` seconds
        (interruptible by a stop request) and replaces the session with a
        fresh one.

        Returns:
            True if XAPI accepted the request; False if no session could be
            created up front, all attempts failed, or a stop was requested.
        """
        if self.session is None:
            self.session = self._create_session()
            if self.session is None:
                return False

        max_retries = self.SSH_DISABLE_MAX_RETRIES
        retry_count = 0

        while retry_count < max_retries and self.running:
            try:
                if self.session is None:
                    # Re-creation after the previous failure did not work;
                    # count it as a failed attempt with a meaningful message
                    # instead of dereferencing None.
                    raise RuntimeError("XAPI session could not be re-created")
                host = self.session.xenapi.host.get_by_uuid(self.localhost_uuid)
                self.session.xenapi.host.disable_ssh(host)
                self.logger.info("Successfully disabled SSH via XAPI")
                return True
            except Exception as e:
                retry_count += 1
                self.logger.warning(f"Disable SSH via API failed ({retry_count}/{max_retries}): {e}")
                if retry_count < max_retries and self.running:
                    # Use interruptible sleep
                    if self.exit_event.wait(self.SSH_DISABLE_RETRY_INTERVAL):
                        return False
                    self._logout_session()
                    self.session = self._create_session()

        if not self.running:
            # Stopping: this is not a retry exhaustion, so no error is logged.
            return False

        self.logger.error(f"Disable SSH via API failed, max retries reached ({max_retries})")
        return False

    def run(self):
        """Run the monitoring loop until a termination signal arrives.

        A normal cycle is followed by a ``CHECK_INTERVAL`` wait. A cycle that
        raises drops the XAPI session and is followed by the shorter
        ``ERROR_RETRY_INTERVAL`` wait. Both waits end early when a stop is
        requested.
        """
        self.logger.info("Starting XAPI and SSH service monitoring...")

        self.session = self._create_session()
        if self.session is None:
            self.logger.warning("Initial session creation failed, will retry later")

        while self.running:
            wait_seconds = self.CHECK_INTERVAL
            try:
                # Check XAPI health - always perform the check
                xapi_healthy = self._check_xapi_health()

                # Get current SSH state
                current_ssh_state = self._get_ssh_state()
                self.logger.debug(f"Current SSH state: {current_ssh_state}")

                if xapi_healthy:
                    if current_ssh_state == SshState.ACTIVE:
                        self.logger.info("XAPI healthy: Stopping SSH service")
                        if not self._disable_ssh_via_api():
                            self.logger.warning("Disable SSH via API failed, keeping SSH service running")
                else:
                    if current_ssh_state != SshState.ACTIVE:
                        self.logger.info("XAPI unhealthy: Starting SSH service")
                        self._control_ssh_service(True)

            except Exception as e:
                self.logger.error(f"Runtime error: {e}")
                self.logger.debug(traceback.format_exc())

                # Drop the session; it is re-created lazily when next needed.
                self._logout_session()
                wait_seconds = self.ERROR_RETRY_INTERVAL

            # Interruptible sleep: returns True as soon as a stop is requested.
            if self.exit_event.wait(wait_seconds):
                break

        self._logout_session()

        self.logger.info("Monitoring service stopped")

def main():
    """Run the SSH control service and exit with its status code.

    Exits with status 0 after the monitor loop ends normally, or with
    status 1 if creating or running the monitor raised an exception (the
    error and its traceback are logged at CRITICAL level first).
    """
    logger.info(f"SSH Control Service starting (PID: {os.getpid()})")

    exit_code = 0
    try:
        XapiMonitor().run()
    except Exception as e:
        logger.critical(f"Fatal error in main process: {e}")
        logger.critical(traceback.format_exc())
        exit_code = 1
    else:
        logger.info("SSH Control Service exited normally")

    sys.exit(exit_code)

# Start the service only when this file is executed as a script; importing
# the module does not call main().
if __name__ == '__main__':
    main()
