#!/usr/bin/env python

"""
This script installs UEFI certificates for XCP-ng hosts.
"""
import argparse
import atexit
import base64
import hashlib
import logging
import os
import shutil
import struct
import subprocess
import sys
import tarfile
import tempfile
import urllib2

from datetime import datetime
from subprocess import Popen
from io import BytesIO

import XenAPI

__author__ = "Bobby Eshleman"
__copyright__ = "Copyright 2021, Vates SAS"
__license__ = "GPLv2"
__version__ = "1.0.0"
__maintainer__ = "Bobby Eshleman"
__email__ = "bobbyeshleman@gmail.com"
__status__ = "Production"


TEMPDIRS = []
DEFAULT_USER_AGENT = "Mozilla/5.0 - Open source hypervisor"


def cd_tempdir():
    """Create a fresh temporary directory and make it the working directory.

    Usage:

    ```
        prevdir = cd_tempdir()

        # Do stuff inside temporary directory
        ...

        # Return to previous directory
        os.chdir(prevdir)
    ```

    The temporary directory is registered for removal at program exit.

    Return the absolute path of the directory that was current *before* the
    call (not the new temporary directory), so the caller can chdir back.
    """
    prevdir = os.path.abspath(os.curdir)
    tempdir = tempfile.mkdtemp()
    os.chdir(tempdir)

    # Cleanup on program exit.  ignore_errors keeps interpreter shutdown quiet
    # if the directory has already been removed by the time the hook runs.
    atexit.register(shutil.rmtree, tempdir, ignore_errors=True)

    return prevdir


class Actions:
    """Names of the sub-commands accepted on the command line."""

    CLEAR = "clear"
    EXTRACT = "extract"
    INSTALL = "install"
    REPORT = "report"


class Urls:
    """Remote locations this script downloads from."""

    # Fetched when the dbx argument is given as 'latest'.
    dbx = (
        "https://github.com/microsoft/secureboot_objects/raw/refs/heads/main/"
        "PostSignedObjects/DBX/amd64/DBXUpdate.bin"
    )


def hashfile(path):
    """Return the hexadecimal MD5 digest of the file at path.

    The file is opened in binary mode (the inputs are binary .auth files) and
    hashed in fixed-size chunks so the whole file is never held in memory.
    The digest is used for display only, not for any security decision.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(65536), b""):
            digest.update(chunk)
    return digest.hexdigest()


def download(url, fname=None, tempdir=False, user_agent=DEFAULT_USER_AGENT):
    """Download a file.

    url:   the url to the remote file.
    fname: the name to rename the file to upon download.  Defaults to the
           last path component of url.
    tempdir: If True, place file in a temporary directory.
             Otherwise, place in current directory.
    user_agent: value sent in the User-Agent request header.

    Returns absolute path to downloaded file.

    On urllib2.URLError / urllib2.HTTPError an explanatory message is printed
    and the program exits with status 1.  The working directory is restored
    before returning or exiting when tempdir is True.
    """
    if fname is None:
        fname = os.path.basename(url)

    # Only switch directories when asked to; None means "nothing to restore".
    prevdir = cd_tempdir() if tempdir else None

    dest = None
    try:
        print("Downloading %s..." % url)

        req = urllib2.Request(url)
        # For an unknown reason, microsoft.com reliably rejects the urllib2 User
        # Agent with error 403 (but oddly doesn't block the python-requests User
        # Agent). To avoid issues, just use the well-known Mozilla User Agent.
        req.add_header("User-Agent", user_agent)

        # These two headers are simply the defaults used by the requests library,
        # which is known to work.  There is no deeper rationale for these exact
        # headers.
        req.add_header("Accept", "*/*")
        req.add_header("Connection", "keep-alive")

        resp = urllib2.urlopen(req)
        try:
            data = resp.read()
        finally:
            # Release the connection even if the read fails part-way.
            resp.close()

        with open(fname, "wb") as f:
            f.write(data)

        # Get abspath in temp dir before returning to original directory
        # (only matters if tempdir == True, but also correct if False)
        dest = os.path.abspath(fname)
    except (urllib2.URLError, urllib2.HTTPError) as e:
        print(
            (
                "error: unable to retrieve certificate from URL: %s. "
                "Error message: %s.\n\nIf the download was blocked with a 403 "
                "HTTP error, you may retry with a different user agent:\n"
                "secureboot-certs install --user-agent=\"Mozilla/5.0 "
                "My custom user agent\"\n\n"
                "If this still doesn't work, you can download and install the "
                "certificates manually:\n"
                "https://xcp-ng.org/docs/guides.html#install-the-default-uefi-certificates-manually"
            )
            % (url, e)
        )
        sys.exit(1)
    finally:
        if prevdir is not None:
            os.chdir(prevdir)
    return dest


def convert_der_to_pem(der):
    """Convert a DER-encoded X509 certificate into a new PEM file.

    The PEM file is written into a fresh temporary directory (removed at
    program exit); the DER file is left untouched.  The working directory is
    restored before returning.

    Arguments:
        der - path to the DER certificate.

    Returns the absolute path of the PEM file.

    Raises subprocess.CalledProcessError if openssl fails.
    """
    der = os.path.abspath(der)
    prevdir = cd_tempdir()

    # foo.der -> foo.crt.  Only a trailing ".der" extension is dropped; any
    # other occurrence of ".der" inside the file name is left alone.
    stem = os.path.basename(der)
    if stem.endswith(".der"):
        stem = stem[: -len(".der")]
    pem = os.path.abspath(stem + ".crt")

    try:
        # openssl's output is discarded.  It goes to the null device rather
        # than a pipe because check_call never drains pipes, which can block
        # the child once the pipe buffer fills.
        with open(os.devnull, "w") as devnull:
            subprocess.check_call(
                [
                    "openssl",
                    "x509",
                    "-in",
                    der,
                    "-inform",
                    "DER",
                    "-outform",
                    "PEM",
                    "-out",
                    pem,
                ],
                stdout=devnull,
                stderr=devnull,
            )
    finally:
        os.chdir(prevdir)
    return pem


def create_auth(signing_key, signing_cert, var, *certs):
    """Build "<var>.auth" in the current directory with the create-auth tool.

    Arguments:
        signing_key  - private key passed to create-auth with -k, or None.
        signing_cert - certificate passed to create-auth with -c, or None.
        var          - name of the UEFI variable; also names the output file.
        certs        - certificate files appended to the command line.

    signing_key and signing_cert must be given together or not at all.

    Returns the absolute path of the generated auth file.

    Raises RuntimeError when only one of signing_key / signing_cert is given,
    and subprocess.CalledProcessError when create-auth fails.
    """
    auth = var + ".auth"

    if bool(signing_key) != bool(signing_cert):
        raise RuntimeError(
            (
                "signing_key and signing_cert must either both "
                "be None or both be defined"
            )
        )

    cmd = ["/opt/xensource/libexec/create-auth"]
    if signing_key and signing_cert:
        cmd += ["-k", signing_key, "-c", signing_cert]
    cmd += [var, auth]
    cmd.extend(certs)

    # The tool's output is discarded.  It goes to the null device rather than
    # a pipe because check_call never drains pipes, which can block the child.
    with open(os.devnull, "w") as devnull:
        subprocess.check_call(cmd, stdout=devnull, stderr=devnull)
    return os.path.abspath(auth)


def session_init():
    """Open a local XAPI session, log in to it, and return the session."""
    xapi = XenAPI.xapi_local()
    xapi.xenapi.login_with_password("", "", "0.1", "secureboot-certificates.py")
    return xapi


def clear(session):
    """Set the custom certificates of every pool to empty and report each one."""
    pools = Pool.get_all(session)
    for current in pools:
        current.set_custom_certs(b"")
        message = "Cleared certificates from XAPI DB for pool %s." % current.uuid
        print(message)


def create_tarball(paths):
    """Pack auth files into an in-memory, uncompressed tar archive.

    Arguments:
        paths - dict mapping a variable name to the path of its auth file.
                Each file is stored in the archive as "<name>.auth".

    Returns the BytesIO object holding the archive.

    Raises RuntimeError if any path fails the is_auth() check.
    """
    buf = BytesIO()
    with tarfile.open(mode="w", fileobj=buf) as archive:
        for var, auth_path in paths.items():
            if is_auth(auth_path):
                archive.add(auth_path, arcname="%s.auth" % var)
                continue
            raise RuntimeError("error: %s is not an auth file" % auth_path)
    return buf


def getdefault(name, user_agent=DEFAULT_USER_AGENT):
    """Return the path of the default auth file for a UEFI variable.

    Arguments:
        name       - "PK", "KEK", "db" or "dbx".
        user_agent - accepted for interface compatibility; not used here.

    Returns the path under /usr/share/varstored, or None for any other name.
    """
    defaults = (
        ("PK", "/usr/share/varstored/PK.auth"),
        ("KEK", "/usr/share/varstored/KEK.auth"),
        ("db", "/usr/share/varstored/db.auth"),
        ("dbx", "/usr/share/varstored/dbx.auth"),
    )
    for var, auth_path in defaults:
        if name == var:
            return auth_path
    return None


def getpath(args, name):
    """Resolve the command-line value given for a variable to a file path.

    Arguments:
        args - parsed arguments; must carry an attribute called `name` and a
               `user_agent` attribute.
        name - "PK", "KEK", "db" or "dbx".

    Resolution order:
        1. an existing file: its absolute path (None if the file is empty);
        2. "default": the result of getdefault();
        3. for dbx only, "latest": the dbx is downloaded; "none": None.

    Anything else prints an error and exits with status 1.
    """
    val = getattr(args, name)

    user_agent = getattr(args, "user_agent")
    if user_agent is None:
        user_agent = DEFAULT_USER_AGENT

    if os.path.exists(val):
        if os.stat(val).st_size > 0:
            logging.debug("using file %s for %s" % (val, name))
            return os.path.abspath(val)
        logging.debug("file %s is empty, skipping..." % val)
        return None

    if val == "default":
        logging.debug("%s for %s" % (val, name))
        return getdefault(name, user_agent=user_agent)

    if name == "dbx":
        if val == "latest":
            return download(Urls.dbx, "dbx.auth", tempdir=True, user_agent=user_agent)
        if val == "none":
            logging.debug("No path for dbx, set dbx to 'none'")
            return None

    print("error: file %s does not exist, and is not option 'default'" % val)
    sys.exit(1)


def validate_args(args):
    """Check the install arguments, exiting with status 1 on the first problem.

    Each of PK, KEK, db and dbx must be either one of its accepted keywords
    or the path of an existing file.  In addition, a PK given as an existing
    file that is not already an auth file requires --pk-priv.
    """
    keywords = {
        "PK": ["default"],
        "KEK": ["default"],
        "db": ["default"],
        "dbx": ["default", "latest", "none"],
    }

    for var in ("PK", "KEK", "db", "dbx"):
        given = getattr(args, var)
        if given in keywords[var] or os.path.exists(given):
            continue
        print("error: file %s does not exist." % given)
        sys.exit(1)

    # A PK supplied as a raw certificate (not an auth file) has to be
    # self-signed, which needs the private key.
    raw_pk = os.path.exists(args.PK) and not is_auth(args.PK)
    if raw_pk and not getattr(args, "pk_priv", False):
        print(
            "error: setting a custom PK requires supplying its private half "
            "to --pk-priv."
        )
        sys.exit(1)


def extract(session, args):
    """Copy one certificate (.auth file) of the pool to args.filename.

    Arguments:
        session - logged-in XAPI session.
        args    - parsed arguments carrying `cert` (the variable name) and
                  `filename` (the destination path).

    Exits with status 1 if no pool is returned by XAPI or the requested
    certificate is not among the pool's active certificates.
    """
    pools = Pool.get_all(session)
    if not pools:
        print("Could not retrieve pool from XAPI")
        sys.exit(1)

    # Compare the file name exactly.  A substring test on the whole path would
    # let "db" match "dbx.auth", or match letters that happen to appear in the
    # randomly named temporary directory.
    wanted = "%s.auth" % args.cert
    cert = None
    for path in pools[0].save_certs_to_disk():
        if os.path.basename(path) == wanted:
            cert = path
            break

    if not cert:
        print("error: cert %s does not exist in XAPI pool DB." % args.cert)
        sys.exit(1)

    shutil.copy(cert, args.filename)


def install(session, args):
    """Install the requested PK/KEK/db/dbx as the pool's custom certificates.

    Each argument is validated, resolved to a file, converted to an auth file
    where needed, packed into a tarball, base64-encoded, and stored on the
    first pool returned by XAPI.  Variables that resolve to no file are left
    out of the tarball.

    Exits with status 1 if the arguments are invalid or XAPI returns no pool.
    """
    validate_args(args)

    paths = dict()
    for name in ["PK", "KEK", "db", "dbx"]:
        p = getpath(args, name)
        if not p:
            continue
        # The private key is only relevant for self-signing a custom PK.
        if name == "PK" and getattr(args, "pk_priv", False):
            priv = os.path.abspath(args.pk_priv)
        else:
            priv = None
        paths[name] = convert_to_auth(name, p, priv)

    tarball = create_tarball(paths)
    try:
        data = base64.b64encode(tarball.getvalue())
    finally:
        tarball.close()

    # Test the list itself: indexing an empty list would raise IndexError
    # before any "no pool" check on the element could run.
    pools = Pool.get_all(session)
    if not pools:
        print("Could not retrieve pool from XAPI")
        sys.exit(1)
    pools[0].set_custom_certs(data)
    print("Successfully installed custom certificates to the XAPI DB for pool.")


def convert_to_auth(var, path, priv=None):
    """Return an auth file created from an X509 cert or auth file.

    If path points to an auth file already, its path will be returned without
    modification.

    If it is a DER X509, it will be converted into a new PEM file prior to
    building the auth file (create-auth requires PEM certs).  The original DER
    file will be unaffected.

    If it is already a PEM, no conversion will be required.

    If it is none of these, an error is printed and the program exits with
    status 1.

    Arguments:
        var - the name of the UEFI variable
        path - a path to an auth, X509 DER, or X509 PEM file.
        priv - the private half of the cert.  Only used for self-signing PK.
    """
    if is_auth(path):
        logging.debug("Using auth directly: %s" % path)
        return path
    elif is_der(path):
        pem = convert_der_to_pem(path)
        logging.debug("Creating auth %s from DER %s" % (var, path))
    elif is_pem(path):
        pem = path
        logging.debug("Creating auth %s from PEM %s" % (var, path))
    else:
        print("file %s is not a valid auth file or x509 certificate" % path)
        sys.exit(1)

    # The auth file is generated inside a scratch directory.  The finally
    # clause restores the working directory even when create_auth raises.
    prevdir = cd_tempdir()
    try:
        if priv:
            # priv is only used for self-signing the PK as required
            # by the spec and varstored.
            auth = create_auth(priv, pem, var, pem)
        else:
            auth = create_auth(None, None, var, pem)
    finally:
        os.chdir(prevdir)
    return auth


def is_auth(path):
    """Return True if path is an EFI auth file, otherwise returns False.

    Only the first 0x18 bytes are inspected: a 16-byte timestamp followed by
    a 4-byte length (ignored), a 2-byte revision and a 2-byte certificate
    type.  Files shorter than that are reported as not being auth files.
    This is a sanity check against passing the wrong file, not a full
    validation.
    """
    header_len = 0x18
    with open(path, "rb") as f:
        data = f.read(header_len)

    # Every field read below must be present; a shorter file cannot be an
    # auth file and would otherwise make the unpacking fail.
    if len(data) < header_len:
        return False

    # Timestamp: u16 year, six u8 fields (month, day, hour, minute, second,
    # pad1), u32 nanosecond, u16 timezone, u8 daylight, u8 pad2.
    (
        year,
        month,
        day,
        hour,
        minute,
        seconds,
        pad1,
        nanosecond,
        tz,
        daylight,
        pad2,
    ) = struct.unpack_from("<HBBBBBBIHBB", data, 0)

    # Validate the timestamp
    try:
        datetime(year, month, day, hour, minute, seconds)
    except ValueError:
        return False

    if pad1 != 0 or nanosecond != 0 or tz != 0 or daylight != 0 or pad2 != 0:
        return False

    # Skip dwLength.  Someday it should be used to verify the data length.
    revision, certificate_type = struct.unpack_from("<HH", data, 0x14)
    if revision != 0x200:
        return False

    # wCertificateType
    if certificate_type != 0x0EF1:
        return False

    # ... at this point there *is* further verification we can do (verify
    # lengths, etc...) but that level of verification is probably unnecessary
    # for the use case here, which is to simply stop the user from accidentally
    # passing in a wrong file.  For that reason, if we get to this point, we
    # consider the file minimally valid and return True

    return True


def is_der(path):
    """Report whether the file at path is a DER-encoded X509 certificate."""
    encoding = "DER"
    return is_cert_type(path, encoding)


def is_pem(path):
    """Report whether the file at path is a PEM-format X509 certificate."""
    encoding = "PEM"
    return is_cert_type(path, encoding)


def is_cert_type(path, t):
    """
    Return True if path is a cert of type t, otherwise return False.

    The file contents are fed to `openssl x509 -inform <t> -noout` on stdin;
    the file is considered a certificate of that type when openssl exits
    with status 0.

    Arguments:
        path: the file to test
        t: must be "DER" or "PEM"

    Raises RuntimeError if t is neither "DER" nor "PEM".
    """
    if t not in ("DER", "PEM"):
        raise RuntimeError("arg %s is not DER or PEM" % t)

    with open(path, "rb") as f:
        data = f.read()

    p = Popen(
        ["openssl", "x509", "-inform", t, "-noout"],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    # communicate() writes the data, closes stdin so openssl sees end-of-file,
    # drains both output pipes and waits for exit.  Writing by hand without
    # closing stdin can leave openssl waiting for more input indefinitely.
    p.communicate(data)

    return p.returncode == 0


def find(strings, substr):
    """Return the first item of strings that contains substr, or None."""
    matches = (candidate for candidate in strings if substr in candidate)
    return next(matches, None)


class Pool(object):
    """
    This class represents a XAPI Pool.

    It wraps an opaque pool reference together with the session used to
    reach it, and caches the certificate blobs fetched from XAPI.
    """

    def __init__(self, opaque_ref, session):
        self.session = session
        # The XAPI class proxy is looked up by the lower-cased Python class
        # name, i.e. session.xenapi.pool.
        attrname = type(self).__name__.lower()
        self.xapi_class = getattr(session.xenapi, attrname)
        self.opaque_ref = opaque_ref
        # Lazily fetched certificate blobs; None means "not fetched yet".
        self.__default_certs = None
        self.__custom_certs = None

    def __repr__(self):
        return "%s(%r)" % (type(self).__name__, self.opaque_ref)

    @property
    def uuid(self):
        """The pool's UUID as reported by XAPI (queried on every access)."""
        return self.xapi_class.get_uuid(self.opaque_ref)

    @classmethod
    def get_all(cls, session):
        """Return one instance per reference returned by XAPI's get_all()."""
        attrname = cls.__name__.lower()
        xapi_class = getattr(session.xenapi, attrname)
        refs = xapi_class.get_all()
        logging.debug("XAPI Request: session.xenapi.%s.get_all()" % attrname)
        return [cls(ref, session) for ref in refs]

    def get_default_certs(self):
        """Return the pool's uefi_certificates value, fetching it once."""
        if self.__default_certs is None:
            self.__default_certs = self.xapi_class.get_uefi_certificates(self.opaque_ref)
        return self.__default_certs

    def get_custom_certs(self):
        """Return the pool's custom_uefi_certificates value, fetching it once."""
        if self.__custom_certs is None:
            self.__custom_certs = self.xapi_class.get_custom_uefi_certificates(self.opaque_ref)
        return self.__custom_certs

    def get_active_certs(self):
        """Return the custom certificates if set, else the default ones."""
        active_certs = self.get_custom_certs()
        if not active_certs:
            active_certs = self.get_default_certs()
        return active_certs

    def set_custom_certs(self, data):
        """Store data as the pool's custom certificates."""
        self.xapi_class.set_custom_uefi_certificates(self.opaque_ref, data)
        # Drop the cached copy so later reads see what was just written.
        self.__custom_certs = None

    def save_certs_to_disk(self):
        """Unpack this pool's active certificates into a temporary directory.

        The active certificates are base64-decoded and extracted with `tar`
        inside a fresh temporary directory (removed at program exit).  The
        working directory is always restored before returning.

        Returns the absolute paths of the extracted PK.auth, KEK.auth,
        db.auth and dbx.auth files, in that order, skipping any that are
        missing.  Returns an empty list when there are no certificates or
        the data is not a valid tarball.
        """
        decoded = base64.b64decode(self.get_active_certs())
        if not decoded:
            return []

        paths = []
        prevdir = cd_tempdir()
        try:
            fd, fname = tempfile.mkstemp()
            try:
                # fdopen takes ownership of the descriptor returned by
                # mkstemp, so it is closed when the with-block ends.
                with os.fdopen(fd, "wb") as f:
                    f.write(decoded)

                # tar's output is discarded.  It goes to the null device
                # rather than a pipe because check_call never drains pipes.
                with open(os.devnull, "w") as devnull:
                    subprocess.check_call(
                        ["tar", "xvf", fname], stdout=devnull, stderr=devnull
                    )
            except subprocess.CalledProcessError:
                print("Certs for %s is not valid tarball" % self)
                return []
            finally:
                os.remove(fname)

            for dirpath, _, filenames in os.walk(os.curdir):
                for f in filenames:
                    path = os.path.join(dirpath, f)
                    if path.endswith(".auth"):
                        paths.append(os.path.abspath(path))
        finally:
            os.chdir(prevdir)

        ret = []
        for name in ["PK.auth", "KEK.auth", "db.auth", "dbx.auth"]:
            p = find(paths, name)
            if p:
                ret.append(p)

        return ret


def print_cert(path, uuid, verbose=False):
    """Print a short description of one auth file.

    Arguments:
        path    - the auth file to describe.
        uuid    - pool UUID, shown only in verbose mode.
        verbose - when True, also print the pool UUID and a hexdump of the
                  file.
    """
    print("\tCertificate: %s" % os.path.basename(path))
    if verbose:
        print("\tPool: %s" % uuid)
    print("\tAuth file md5: %s" % hashfile(path))

    if not verbose:
        return

    print("\tData:")
    dump = subprocess.check_output(["hexdump", "-Cv", path])
    for row in dump.splitlines():
        print("\t\t%s" % row)
    print("")


def report(session, verbose=False):
    """Print a report of the active certificates of the first pool.

    Arguments:
        session - logged-in XAPI session.
        verbose - passed through to print_cert() for each certificate.

    If writing the report raises IOError (for example when the output is
    piped into a program that exits early), stdout is redirected to the null
    device and the program exits with status 1.
    """
    try:
        print("\n{} -- Report".format(os.path.basename(sys.argv[0])))
        pool = Pool.get_all(session)[0]
        paths = pool.save_certs_to_disk()
        # Query the UUID once; each access is a separate XAPI request.
        uuid = pool.uuid
        print("Certificate Info (pool: %s):" % uuid)
        s = "\tCertificates (%s): " % len(paths)
        s += ", ".join(os.path.basename(p) for p in paths)
        s += "\n"
        print(s)
        for path in paths:
            print_cert(path, uuid, verbose=verbose)
    except IOError:
        # This technique taken from: https://docs.python.org/3/library/signal.html#note-on-sigpipe
        # Redirect further stdout flushing (like the broken pipe err message) to /dev/null
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, sys.stdout.fileno())
        sys.exit(1)


# Command-line entry point: build the argument parser (one sub-command per
# Actions value), parse the command line, open a XAPI session and dispatch.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=("Configure guest UEFI certificates for an XCP-ng system.")
    )

    # --version is declared here so it shows up in --help; it is actually
    # handled by inspecting sys.argv further down, before parse_args() runs.
    parser.add_argument(
        "--version",
        "-V",
        action="store_true",
        help="Print the version number",
    )

    # Switches the logging level from WARNING to DEBUG (see basicConfig below).
    parser.add_argument(
        "--debug",
        "-d",
        action="store_true",
        help="Debug output.",
    )

    # Each sub-parser records its own name in args.action via set_defaults();
    # the dispatch at the bottom of the file switches on that value.
    action_parsers = parser.add_subparsers()

    # --- install -----------------------------------------------------------
    install_parser = action_parsers.add_parser(
        Actions.INSTALL,
        help="Install UEFI certificates to the pool.",
        description=(
            "Install UEFI certificates to the pool.\n\n"
            "If no arguments are passed to '{} {}', then the default PK, KEK, "
            "db and dbx will be installed.".format(
                os.path.basename(sys.argv[0]), Actions.INSTALL
            )
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""

Important Note: dbx variable contents are downloaded automatically from
the Internet, and therefore require network access.

Certificate / auth file URLs:
dbx: {}
""".format(
            Urls.dbx
        ),
    )
    install_parser.set_defaults(action=Actions.INSTALL)

    # nargs='?' lets the flag appear without a value, in which case argparse
    # stores None; getpath() falls back to DEFAULT_USER_AGENT for None.
    install_parser.add_argument(
        "--user-agent",
        help="Sets a custom user agent to download default certificates from the Internet.",
        default=DEFAULT_USER_AGENT,
        nargs='?'
    )
    # The four positionals below are all optional (nargs='?') and default to
    # 'default'.  They are matched in registration order: PK, KEK, db, dbx.
    install_parser.add_argument(
        "PK",
        metavar="PK",
        help=(
            "'default' for the default XCP-ng PK or a path to a custom auth file. "
            "If a custom file it must be an EFI .auth file, "
            "a DER-encoded X509 certificate, or a PEM X509 certificate."
        ),
        default='default',
        nargs='?'
    )
    # Checked by validate_args() when PK is a file that is not an auth file.
    install_parser.add_argument(
        "--pk-priv",
        help=(
            "The private half of the PK certificate.  "
            "Required for custom PK certificates."
        ),
    )
    install_parser.add_argument(
        "KEK",
        metavar="KEK",
        help=(
            "'default' for the default Microsoft certs or a path to a custom auth file. "
            "If a custom file it must be an EFI .auth file, "
            "a DER-encoded X509 certificate, or a PEM X509 certificate."
        ),
        default='default',
        nargs='?'
    )
    install_parser.add_argument(
        "db",
        metavar="db",
        help=(
            "'default' for the default Microsoft certs or a path to a custom auth file. "
            "If a custom file it must be an EFI .auth file, "
            "a DER-encoded X509 certificate, or a PEM X509 certificate."
        ),
        default='default',
        nargs='?'
    )

    # dbx additionally accepts the keywords 'latest' and 'none'
    # (see validate_args() and getpath()).
    install_parser.add_argument(
        "dbx",
        metavar="dbx",
        help="""
'default' for the default dbx, 'latest' for the most recent UEFI dbx, a path to
a custom auth file, or 'none' for no dbx.

If a custom file, it must be an EFI .auth file, a DER-encoded X509 certificate,
or a PEM X509 certificate.

The 'default' and 'latest' dbx revoke any software that hasn't implemented the
most recent security fixes, which may include some OS distributions (even if
they're totally updated, depending how recently the vulnerability was
discovered). Because it varies per distribution, check if your guest
distributions are updated to pass the most recent UEFI revocation before
installing the latest dbx.

Note: Choosing 'none' allows attackers to simply load vulnerable binaries that
were previously signed but later revoked, and therefore bypass Secure Boot
protection.

Guests may extend, replace, or modify the dbx for the VM in which they run if
the default KEK is used, even if dbx is set to 'none'.

For older dbx files, see: https://uefi.org/revocationlistfile/archive. They may
be passed to {} as custom auth files.
""".format(
            os.path.basename(sys.argv[0])
        ),
        default='default',
        nargs='?'
    )

    # --- clear -------------------------------------------------------------
    clear_parser = action_parsers.add_parser(
        Actions.CLEAR,
        help=(
            "Remove all user-installed UEFI certificates from the pool. "
            "The pool will use the default certificates found in "
            "/usr/share/varstored, if they are present. "
        ),
    )
    clear_parser.set_defaults(action=Actions.CLEAR)

    # --- report ------------------------------------------------------------
    report_parser = action_parsers.add_parser(
        Actions.REPORT,
        help=(
            "View a report containing information about the active UEFI "
            "certificates for the pool."
        ),
    )
    # Stored as args.report_verbose and passed to report() below.
    report_parser.add_argument(
        "--verbose",
        "-v",
        dest="report_verbose",
        action="store_true",
        help="Verbose report output.",
    )
    report_parser.set_defaults(action=Actions.REPORT)

    # --- extract -----------------------------------------------------------
    extract_parser = action_parsers.add_parser(
        Actions.EXTRACT,
        help=(
            "Extract a certificate (.auth file) from XAPI and save it to disk."
        ),
    )
    extract_parser.set_defaults(action=Actions.EXTRACT)
    extract_parser.add_argument(
        "cert",
        choices=["PK", "KEK", "db", "dbx"],
        help="The certificate (.auth file) to be extracted from XAPI.",
    )
    extract_parser.add_argument(
        "filename",
        help="The output file name.",
    )

    # The version flag is looked for directly in sys.argv and short-circuits
    # everything else: no argument parsing and no XAPI session.
    # NOTE(review): presumably done so --version works without naming a
    # sub-command -- confirm.
    if "--version" in sys.argv or "-V" in sys.argv:
        print(__version__)
        sys.exit(0)
    else:
        args = parser.parse_args()

        logging.basicConfig(level=logging.DEBUG if args.debug else logging.WARNING)

        # Every action needs a logged-in XAPI session.
        session = session_init()
        if args.action == Actions.CLEAR:
            clear(session)
        elif args.action == Actions.INSTALL:
            install(session, args)
        elif args.action == Actions.REPORT:
            report(session, args.report_verbose)
        elif args.action == Actions.EXTRACT:
            extract(session, args)
        else:
            # Unknown action value: exit with a failure status.
            sys.exit(1)
