#!/usr/bin/env python3

# Common functions for managing statically-attached (i.e. attached on boot, without xapi) VDIs


import json
import os
import os.path
import subprocess
import sys
import time
import urllib.parse
import xmlrpc.client
from typing import TYPE_CHECKING, cast

import XenAPI

import inventory

if TYPE_CHECKING:
    from typing import Any, Dict

# Directory holding one sub-directory per static VDI configuration
main_dir = "/etc/xensource/static-vdis"

# Base directory under which the volume and datapath plugin executables are looked up
xapi_storage_script = "/usr/libexec/xapi-storage-script"
# Name of the per-VDI file holding the JSON-encoded PBD device-config (written by add())
smapiv3_config = "device-config"

# Flag file: add() records whether it exists, attach() touches it when that record is true
MULTIPATH_FLAG = "/var/run/nonpersistent/multipath_enabled"

def call_volume_plugin(name, command, args):
    """Run a volume plugin command and return its JSON output, decoded.

    The executable is
    <xapi_storage_script>/volume/org.xen.xapi.storage.<name>/<command>; it is
    invoked with "static-vdis" as its first argument, followed by `args`
    (which may be None or empty).

    Raises subprocess.CalledProcessError if the plugin exits non-zero.
    """
    executable = "/".join(
        [xapi_storage_script, "volume", "org.xen.xapi.storage." + name, command]
    )
    command_line = [executable, "static-vdis"]
    command_line.extend(args or [])
    # NOTE: no timeout is applied, so a hung plugin blocks the caller.
    raw = subprocess.check_output(command_line, universal_newlines=True)
    return json.loads(raw)

def call_datapath_plugin(name, command, args):
    """Run a datapath plugin command and return its JSON output, decoded.

    The executable is <xapi_storage_script>/datapath/<name>/<command>; it is
    invoked with "static-vdis" as its first argument, followed by the list
    `args`.

    Raises subprocess.CalledProcessError if the plugin exits non-zero.
    """
    executable = "/".join([xapi_storage_script, "datapath", name, command])
    command_line = [executable, "static-vdis"] + args
    raw = subprocess.check_output(command_line, universal_newlines=True)
    return json.loads(raw)
        
def read_whole_file(filename):
    """Return the UTF-8 text of `filename` with surrounding whitespace stripped."""
    with open(filename, 'r', encoding='utf-8') as stream:
        text = stream.read()
    return text.strip()
        
def write_whole_file(filename, contents):
    """Create or truncate `filename` and write the string `contents` as UTF-8."""
    with open(filename, "w", encoding='utf-8') as stream:
        stream.write(contents)

def load(name):
    """Return a dictionary describing a single static VDI.

    `name` is the sub-directory of main_dir holding the configuration. The
    result always contains:
      "id"               -- `name` itself
      "delete-next-boot" -- the string "True" or "False", depending on
                            whether the delete-next-boot marker file exists
    plus one entry for every configuration file that is present, and "disk"
    (the target of the "disk" symlink) when that symlink exists and its
    target can be stat'ed.
    """
    d = {"id": name}
    for key in ["vdi-uuid", "config", smapiv3_config, "volume-plugin",
                "volume-uri", "volume-key", "reason"]:
        path = "%s/%s/%s" % (main_dir, name, key)
        if os.path.exists(path):
            d[key] = read_whole_file(path)

    disk = "%s/%s/disk" % (main_dir, name)
    try:
        # stat follows the symlink, so a missing or dangling link is skipped
        os.stat(disk)
        d["disk"] = os.readlink(disk)
    except OSError:
        # Not attached: leave "disk" out of the dictionary
        pass

    marker = "%s/%s/delete-next-boot" % (main_dir, name)
    d["delete-next-boot"] = "True" if os.path.exists(marker) else "False"
    return d

def wait_for_corosync_quorum():
    """Block until `xcli diagnostics --json static-vdis` reports is_quorate.

    The command is re-run once per second until the "is_quorate" field of
    its JSON output is truthy. There is no upper bound on the wait.
    """
    query = ['xcli', 'diagnostics', '--json', 'static-vdis']
    while True:
        report = json.loads(
            subprocess.check_output(query, universal_newlines=True))
        if report.get('is_quorate'):
            return
        time.sleep(1)

def check_clusterstack(ty):
    """Wait for corosync quorum if the volume plugin `ty` requires corosync.

    The requirement is read from the "required_cluster_stack" field of the
    plugin's Plugin.Query reply; a missing field means no requirement.
    """
    plugin_info = call_volume_plugin(ty, "Plugin.Query", None)
    required_stack = plugin_info.get('required_cluster_stack', '')
    if 'corosync' in required_stack:
        wait_for_corosync_quorum()

def sr_attach(ty, device_config):
    # type: (str, Dict[str, object]) -> str
    """Attach an SR through the volume plugin `ty` and return the plugin's reply.

    Every device_config entry is passed to SR.attach as
    `--configuration <key> <value>`. The cluster stack is checked first.
    """
    check_clusterstack(ty)

    plugin_args = []
    for key, value in device_config.items():
        plugin_args.extend(["--configuration", key, value])
    return call_volume_plugin(ty, "SR.attach", plugin_args)

def list_vdis():
    """Return the load() dictionary of every entry found in main_dir.

    If main_dir cannot be listed, the result is an empty list.
    """
    try:
        entries = os.listdir(main_dir)
    except OSError:  # All possible errors are subclasses of OSError
        entries = []
    return [load(entry) for entry in entries]

def fresh_name():
    """Return a unique name for a new static VDI configuration directory"""
    try:
        taken = os.listdir(main_dir)
    except OSError:  # All possible errors are subclasses of OSError
        # Directory doesn't exist
        os.mkdir(main_dir)
        return "0"

    # With N entries, at least one of the N + 1 candidates "0".."N" is free
    for candidate in map(str, range(len(taken) + 1)):
        if candidate not in taken:
            return candidate
    return "0"  # Always return a string, keeps type checkers from seeing None
        

def to_string_list(d):
    """Return the human-readable lines describing one static VDI.

    `d` is a dictionary as produced by load(); it must contain "vdi-uuid",
    "reason" and "delete-next-boot". The "vdi-uuid" label is left-aligned,
    all other labels are right-aligned, and every label is padded to the
    width of the longest one. "currently-attached" is "True" (followed by a
    "path" line) when `d` has a "disk" entry, otherwise "False".
    """
    keys = ["vdi-uuid", "reason", "currently-attached", "delete-next-boot", "path"]
    # Column width: the longest label
    width = max(len(key) for key in keys)

    def left(key, value):
        return key.ljust(width) + ": " + value

    def right(key, value):
        return key.rjust(width) + ": " + value

    lines = [left("vdi-uuid", d["vdi-uuid"]), right("reason", d["reason"])]
    lines.append(right("delete-next-boot", d["delete-next-boot"]))
    if "disk" in d:
        lines.append(right("currently-attached", "True"))
        lines.append(right("path", d['disk']))
    else:
        lines.append(right("currently-attached", "False"))
    return lines

def add(session, vdi_uuid, reason):
    """Create the static configuration for the VDI `vdi_uuid`.

    `session` is a logged-in XenAPI session; `reason` is free text stored in
    the configuration.

    If a configuration for the VDI already exists and is flagged
    delete-next-boot, the flag is removed and the reason replaced; the rest
    of the existing configuration is kept as is.

    Otherwise a new directory is created under main_dir with one file per
    configuration key. When the SM record's features include
    VDI_ATTACH_OFFLINE, the volume plugin is used: the SR is attached and
    the volume's uri and key are recorded. If not, the SMAPIv1 driver
    filename and the generated VDI config are recorded.

    Raises Exception if a configuration already exists and is not flagged
    delete-next-boot, if no SM record matches the SR type, if no
    device-config was found for an offline-attach SR, or if XenAPI fails
    while generating the SMAPIv1 config.
    """
    for existing in list_vdis():
        if existing['vdi-uuid'] == vdi_uuid:
            if existing['delete-next-boot'] == "True":
                # Undo the 'delete-next-boot' flag to reinstitute
                path = main_dir + "/" + existing['id']
                os.unlink(path + "/delete-next-boot")
                os.unlink(path + "/reason")
                write_whole_file(path + "/reason", reason)
                # Assume config is still valid
                return
            raise Exception("Static configuration for VDI already exists")

    vdi = session.xenapi.VDI.get_by_uuid(vdi_uuid)
    host = session.xenapi.host.get_by_uuid(inventory.get_localhost_uuid())
    sr = session.xenapi.VDI.get_SR(vdi)
    ty = session.xenapi.SR.get_type(sr)

    # This host's device-config will have info needed to attach SMAPIv3
    # SRs.
    device_config = None
    for p in session.xenapi.host.get_PBDs(host):
        p_rec = session.xenapi.PBD.get_record(p)
        if p_rec['SR'] == sr:
            device_config = p_rec['device_config']

    # Find the SM record (and its reference) for this SR type
    sm = None
    sm_ref = ""
    for ref, record in session.xenapi.SM.get_all_records().items():
        if record['type'] == ty:
            sm_ref, sm = ref, record
            break

    if sm is None:
        raise Exception("Unable to discover SM plugin")

    data = {
        "vdi-uuid": vdi_uuid,
        "reason": reason
    }

    # If the SM supports offline attach then we use it
    if "VDI_ATTACH_OFFLINE" in sm["features"]:
        # Explicit check rather than `assert`, which is removed under -O
        if not device_config:
            raise Exception(
                "No device-config found on this host for the SR of VDI %s"
                % vdi_uuid)
        data["volume-plugin"] = ty
        data[smapiv3_config] = json.dumps(device_config)
        attached_sr = sr_attach(ty, device_config)
        location = session.xenapi.VDI.get_location(vdi)
        stat = call_volume_plugin(ty, "Volume.stat", [attached_sr, location])
        data["volume-uri"] = stat["uri"][0]
        data["volume-key"] = stat["key"]
        data["multipath"] = json.dumps(os.path.exists(MULTIPATH_FLAG))
    else:
        # SMAPIv1
        try:
            data["driver"] = session.xenapi.SM.get_driver_filename(sm_ref)
            data["config"] = session.xenapi.VDI.generate_config(host, vdi)
        except XenAPI.Failure as e:
            raise Exception(
                "Failed generating static config file: %s" % (str(e))) from e

    # Make a fresh directory in main_dir to store the configuration. Note
    # there is no locking so please run this script serially.
    fresh = fresh_name()
    path = main_dir + "/" + fresh
    os.mkdir(path)
    for key, value in data.items():
        write_whole_file(path + "/" + key, value)

def delete(vdi_uuid):
    """Flag the static configuration of `vdi_uuid` for deletion.

    Every matching configuration gets an empty delete-next-boot marker file.
    If the VDI is not currently attached (no "disk" entry), its whole
    configuration directory is removed straight away.

    Raises Exception("Disk configuration not found") if nothing matches.
    """
    found = False
    for existing in list_vdis():
        if existing['vdi-uuid'] != vdi_uuid:
            continue
        found = True
        path = main_dir + "/" + existing['id']
        # Create (or truncate) the empty marker file
        with open(path + "/delete-next-boot", "w"):
            pass
        # If not currently attached then zap the whole tree
        if "disk" not in existing:
            # argv list, no shell: the path is never word-split or interpreted
            subprocess.call(["/bin/rm", "-rf", path])
    if not found:
        raise Exception("Disk configuration not found")

# Copied by util.py
def doexec(args, inputtext=None):
    """Run `args` as a subprocess and return (return code, stdout, stderr).

    `inputtext`, if given, is sent to the child's stdin. All three streams
    are pipes in text mode.
    """
    pipe = subprocess.PIPE
    child = subprocess.Popen(
        args,
        stdin=pipe,
        stdout=pipe,
        stderr=pipe,
        close_fds=True,
        universal_newlines=True,
    )
    out, err = child.communicate(inputtext)
    return (child.returncode, out, err)


def connect_smapiv1_nbd(params_nbd):
    """Connect the NBD export named by the URI `params_nbd`.

    Returns whatever connect_nbd() returns for the parsed socket path and
    export name.
    """
    socket_path, export_name = parse_nbd_uri(params_nbd)
    return connect_nbd(socket_path, export_name)


def call_backend_attach(driver, config):
    """Run the SMAPIv1 backend `driver` with `config` and return the disk path.

    `driver` is the backend executable and `config` its single argument. The
    driver's stdout is parsed as an XML-RPC message. If the first value of
    that message contains 'params_nbd', the NBD export is connected and the
    result of that connection is returned. Otherwise the 'params' entry is
    returned, or the value itself when it has no such entry.

    Raises Exception("SM_BACKEND_FAILURE(...)") if the driver exits non-zero.
    """
    args = [arg.encode('utf-8') for arg in [driver, config]]
    xml = doexec(args)
    if xml[0] != 0:
        raise Exception("SM_BACKEND_FAILURE(%d, %s, %s)" % xml)
    xml_rpc = xmlrpc.client.loads(xml[1])  # type: Any # pragma: no cover
    result = xml_rpc[0][0]

    if 'params_nbd' in result:
        # Prefer NBD if available
        return connect_smapiv1_nbd(result['params_nbd'])

    try:
        return result['params']
    except (KeyError, TypeError):
        # No 'params' entry (KeyError) or not a mapping, e.g. a plain
        # string (TypeError): the value itself is the path
        return result

def call_backend_detach(driver, config):
    """Run the SMAPIv1 backend `driver` to detach a VDI described by `config`.

    `config` is an XML-RPC message; its first value is re-sent to the driver
    with the command (and XML-RPC method name) set to
    'vdi_detach_from_config'. Returns the 'params' entry of the driver's
    reply, or the reply value itself if that lookup fails.

    Raises Exception("SM_BACKEND_FAILURE(...)") if the driver exits non-zero.
    """
    detach_command = 'vdi_detach_from_config'
    params = xmlrpc.client.loads(config)[0][0]  # type: Any
    params['command'] = detach_command
    request = xmlrpc.client.dumps((params,), params['command'])

    outcome = doexec([driver, request])
    if outcome[0] != 0:
        raise Exception("SM_BACKEND_FAILURE(%d, %s, %s)" % outcome)

    reply = xmlrpc.client.loads(outcome[1])
    try:
        result = cast(dict, reply[0][0])['params']  # pragma: no cover
    except Exception:
        result = reply[0][0]
    return result

def connect_nbd(path, exportname):
    """Run `nbd_client_manager.py connect` and return its stripped output.

    `path` and `exportname` are passed as --path and --exportname.
    Raises subprocess.CalledProcessError if the helper exits non-zero.
    """
    command = [
        '/opt/xensource/libexec/nbd_client_manager.py', 'connect',
        '--path', path,
        '--exportname', exportname,
    ]
    output = subprocess.check_output(command, universal_newlines=True)
    return output.strip()

def disconnect_nbd_device(nbd_device):
    """Run `nbd_client_manager.py disconnect --device <nbd_device>`.

    Raises subprocess.CalledProcessError if the helper exits non-zero.
    """
    command = [
        '/opt/xensource/libexec/nbd_client_manager.py', 'disconnect',
        '--device', nbd_device,
    ]
    subprocess.check_call(command)


def parse_nbd_uri(uri):
    """Split an NBD URI into (socket_path, exportname).

    The only accepted form is `nbd:unix:<socket path>:exportname=<name>`,
    i.e. exactly four colon-separated fields. Anything else raises
    Exception.
    """
    prefix = 'exportname='
    fields = uri.split(':')
    well_formed = (
        len(fields) == 4
        and fields[:2] == ['nbd', 'unix']
        and fields[3].startswith(prefix)
    )
    if not well_formed:
        raise Exception('Got invalid NBD URI for SM backend: ' + uri)
    _, _, socket_path, export_field = fields
    return (socket_path, export_field[len(prefix):])


def attach(vdi_uuid):
    """Attach the static VDI `vdi_uuid` and return the path of its "disk" symlink.

    Any old "disk" symlink in the configuration directory is removed first.
    If the directory has no device-config file, the SMAPIv1 driver is run
    with the stored config. Otherwise the SR is attached through the volume
    plugin and the volume is attached and activated through the datapath
    plugin for the volume URI's scheme. The first implementation offered
    among BlockDevice, File and Nbd is used; an Nbd implementation is also
    connected. The resulting path becomes the target of the new symlink.

    Raises:
        Exception: no configuration exists for `vdi_uuid`, or the datapath
            plugin offers none of the supported implementations.
        TypeError: the backend did not produce a path.
    """
    supported = ('BlockDevice', 'File', 'Nbd')
    found = False
    for existing in list_vdis():
        if existing['vdi-uuid'] != vdi_uuid:
            continue
        found = True
        # NOTE(review): load() stores the symlink target under "disk" and
        # never sets "path", so this check never skips an entry. Confirm
        # whether "disk" was intended before changing it.
        if 'path' in existing:
            continue
        d = main_dir + "/" + existing['id']
        # Delete any old symlink; it is fine if there is none
        try:
            os.unlink(d + "/disk")
        except OSError:
            pass
        path = None  # Raise TypeError if path is not set at the end
        if not os.path.exists(d + "/" + smapiv3_config):
            # SMAPIv1
            config = read_whole_file(d + "/config")
            driver = read_whole_file(d + "/driver")
            path = call_backend_attach(driver, config)
        else:
            volume_plugin = read_whole_file(d + "/volume-plugin")
            configuration = json.loads(read_whole_file(d + "/" + smapiv3_config))
            # The key is not used below; reading it still fails early if
            # the file is missing
            read_whole_file(d + "/volume-key")
            vol_uri = read_whole_file(d + "/volume-uri")
            multipath = json.loads(read_whole_file(d + "/multipath"))
            scheme = urllib.parse.urlparse(vol_uri).scheme
            # Set the multipath flag if required
            if multipath:
                with open(MULTIPATH_FLAG, 'a'):
                    os.utime(MULTIPATH_FLAG, None)
            sr_attach(volume_plugin, configuration)
            attached = call_datapath_plugin(scheme, "Datapath.attach", [vol_uri, "0"])
            # First implementation we know how to handle, if any
            chosen = next(
                ((name, impl)
                 for (name, impl) in attached['implementations']
                 if name in supported),
                None)
            if chosen is None:
                raise Exception(
                    "static-vdis: attach(): no supported datapath "
                    "implementation (BlockDevice, File or Nbd) for %s" % vol_uri)
            (name, implementation) = chosen
            call_datapath_plugin(scheme, "Datapath.activate", [vol_uri, "0"])
            if name in ('BlockDevice', 'File'):
                path = implementation['path']
            elif name == 'Nbd':
                (socket_path, exportname) = parse_nbd_uri(implementation['uri'])
                path = connect_nbd(path=socket_path, exportname=exportname)

        if path is None:
            raise TypeError("static-vdis: attach(): path was not set")
        os.symlink(path, d + "/disk")
        return d + "/disk"
    if not found:
        raise Exception("Disk configuration not found")
    return None

def detach(vdi_uuid):
    """Detach the static VDI `vdi_uuid` if it is currently attached.

    Only the first matching configuration is handled. If it has no "disk"
    entry, nothing is done. Otherwise an NBD device (a disk path starting
    with /dev/nbd) is disconnected, the backend is asked to detach (the
    SMAPIv1 driver when there is no device-config file, else the datapath
    plugin's deactivate and detach), and the "disk" symlink is removed.

    Raises Exception("Disk configuration not found") if nothing matches.
    """
    for entry in list_vdis():
        if entry['vdi-uuid'] != vdi_uuid:
            continue
        if 'disk' not in entry:
            # Not attached: nothing to do
            return

        config_dir = main_dir + "/" + entry['id']
        device = entry['disk']
        if device.startswith('/dev/nbd'):
            disconnect_nbd_device(device)

        if os.path.exists(config_dir + "/" + smapiv3_config):
            # The plugin name and volume key are read but not used further
            volume_plugin = read_whole_file(config_dir + "/volume-plugin")
            vol_key = read_whole_file(config_dir + "/volume-key")
            vol_uri = read_whole_file(config_dir + "/volume-uri")
            scheme = urllib.parse.urlparse(vol_uri).scheme
            for operation in ("Datapath.deactivate", "Datapath.detach"):
                call_datapath_plugin(scheme, operation, [vol_uri, "0"])
        else:
            # SMAPIv1
            config = read_whole_file(config_dir + "/config")
            driver = read_whole_file(config_dir + "/driver")
            call_backend_detach(driver, config)

        os.unlink(config_dir + "/disk")
        return

    raise Exception("Disk configuration not found")

    
def usage():
    """Print the command-line synopsis to stdout and exit with status 1."""
    program = sys.argv[0]
    synopsis = (
        " %s list                 -- print a list of VDIs which will be attached on host boot",
        " %s add <uuid> <reason>  -- make the VDI <uuid> available on host boot",
        " %s del <uuid>           -- cease making the VDI <uuid> available on host boot",
        " %s attach <uuid>        -- attach the VDI immediately",
        " %s detach <uuid>        -- detach the VDI immediately",
    )
    print("Usage:")
    for template in synopsis:
        print(template % program)
    sys.exit(1)


def main():
    """Dispatch the sub-command given on the command line.

    Recognised forms are `list`, `add <uuid> <reason>`, `del <uuid>`,
    `attach <uuid>` and `detach <uuid>`; anything else prints the usage
    text, which exits with status 1. `add` runs inside a local XenAPI
    session that is always logged out afterwards. `attach` prints the value
    returned by attach().
    """
    argv = sys.argv
    if len(argv) < 2:
        usage()

    command = argv[1]
    operand_count = len(argv) - 2

    if command == "list" and operand_count == 0:
        for vdi in list_vdis():
            for line in to_string_list(vdi):
                print(line)
            print()
    elif command == "add" and operand_count == 2:
        session = XenAPI.xapi_local()
        session.xenapi.login_with_password("root", "", "1.0", "xen-api-scripts-static-vdis")
        try:
            add(session, argv[2], argv[3])
        finally:
            session.xenapi.logout()
    elif command == "del" and operand_count == 1:
        delete(argv[2])
    elif command == "attach" and operand_count == 1:
        print(attach(argv[2]))
    elif command == "detach" and operand_count == 1:
        detach(argv[2])
    else:
        usage()


# Run the command-line interface only when executed as a script, not on import
if  __name__ == "__main__":  # pragma: no cover
    main()
