#!/usr/bin/python3
#
# Copyright (c) 2011 Citrix Systems, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation; version 2.1 only. with the special
# exception on linking described in file LICENSE.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#


import fcntl
import os
import os.path
import sys
import time
import common
from common import ON_ERROR_FAIL, send_to_syslog, doexec, run

# Colon-separated list of the standard system binary directories.
# NOTE(review): no function visible in this file reads this global (several
# functions use an unrelated local variable also called `path`) - confirm
# whether anything still depends on it before removing.
path = "/sbin:/usr/sbin:/bin:/usr/bin"

# Name this script was invoked as, and its basename (used in the usage text).
command_name = sys.argv[0]
short_command_name = os.path.basename(command_name)

# ebtables is referred to by bare command name; acquire_lock() is called with
# ebtables_lock_path before any ebtables invocation in handle_bridge().
ebtables = "ebtables"
ebtables_lock_path = "/var/lock/ebtables.lock"

# Open vSwitch command-line tools, referred to by bare command name.
vsctl = "ovs-vsctl"
ofctl = "ovs-ofctl"

# Absolute paths of the external tools run through doexec().
ip = "/sbin/ip"
brctl = "/usr/sbin/brctl"
xs_write = "/usr/bin/xenstore-write"
xs_read = "/usr/bin/xenstore-read"
xs_rm = "/usr/bin/xenstore-rm"

def get_host_network_mode(vif_name, domuuid, devid):
    """Return the mode reported by the VIF object for this interface.

    A common.Interface is built from the three identifiers, its VIF object is
    fetched, and that object's get_mode() result is returned.
    """
    return common.Interface(vif_name, domuuid, devid).get_vif().get_mode()

def get_vif(vif_name, domuuid, devid):
    """Return the VIF object of the common.Interface built from the arguments."""
    interface = common.Interface(vif_name, domuuid, devid)
    vif = interface.get_vif()
    return vif

def ip_link_set(device, direction):
    """Run `ip link set <device> <direction>`; callers pass "up" or "down"."""
    argv = [ip, "link", "set"]
    argv.extend((device, direction))
    doexec(argv)

def xs_hotplug_path(vif):
    """Return the xenstore hotplug path for a VIF.

    The domain id is taken from the VIF name: the first three characters are
    dropped and whatever precedes the first '.' in the remainder is used.
    """
    name_suffix = vif.vif_name[3:]
    domid, _, _ = name_suffix.partition('.')
    fields = (vif.vm_uuid, domid, vif.devid)
    return '/xapi/%s/hotplug/%s/vif/%s' % fields

def xs_private_path(vif):
    """Return the xenstore private path for a VIF, built from its VM UUID and devid."""
    fields = (vif.vm_uuid, vif.devid)
    return '/xapi/%s/private/vif/%s' % fields

def setup_pvs_proxy_rules(vif, command):
    """Run the PVS proxy rule script for a VIF, if one is configured.

    Nothing is done unless the 'pvs-interface' node exists under the VIF's
    private xenstore path.  When it does, the script path is read from the
    'setup-pvs-proxy-rules' node and executed as:

        <script> <command> <type> <vif name> <private path> <hotplug path>

    where <type> is the first three characters of the VIF name.

    vif:     VIF object providing vif_name, vm_uuid and devid.
    command: passed through to the script; callers in this file use
             "add" and "remove".
    """
    private = xs_private_path(vif)
    # PVS proxy rules are only setup if the pvs-interface xenstore node exists
    rc, _, _ = doexec([xs_read, '%s/pvs-interface' % private])
    if rc != 0:
        return

    rc, stdout, _ = doexec([xs_read, '%s/setup-pvs-proxy-rules' % private])
    script = stdout.readline().strip() if rc == 0 else ''
    if not script:
        # Without this check an empty string would be executed as the command.
        send_to_syslog("%s/setup-pvs-proxy-rules could not be read; "
                       "skipping PVS proxy rules (%s)" % (private, command))
        return

    ty = vif.vif_name[:3]
    hotplug = xs_hotplug_path(vif)
    doexec([script, command, ty, vif.vif_name, private, hotplug])

###############################################################################
# Creation of ebtables rules in the case of bridge.
###############################################################################

def add_bridge_port(bridge_name, vif_name):
    """Set the bridge's forward delay to 0 and then attach vif_name to it (brctl)."""
    steps = (("setfd", '0'), ("addif", vif_name))
    for subcommand, operand in steps:
        doexec([brctl, subcommand, bridge_name, operand])

def remove_bridge_port(vif_name):
    """Detach vif_name from the Linux bridge it is currently a port of.

    The owning bridge is found by resolving the interface's
    /sys/class/net/<vif_name>/brport/bridge symlink.  This is best effort:
    if the link cannot be read (the interface is not a bridge port, or does
    not exist) or the command cannot be run, the OSError is ignored.
    Other exceptions (KeyboardInterrupt, programming errors, ...) propagate.
    """
    try:
        bridge_link = os.readlink('/sys/class/net/%s/brport/bridge' % vif_name)
        bridge_name = os.path.basename(bridge_link)
        doexec([brctl, "delif", bridge_name, vif_name])
    except OSError:
        # the interface was not on the bridge: ignore
        pass

def get_chain_name(vif_name):
    """Return the per-VIF ebtables chain name: "FORWARD_" followed by the VIF name."""
    chain_name = "FORWARD_%s" % vif_name
    return chain_name

def do_chain_action(executable, action, chain_name, args=None):
    """Run `<executable> <action> <chain_name> [args...]` and return doexec's result.

    args is an optional list of extra command-line words; None means no
    extra words.
    """
    extra_words = [] if args is None else args
    command = [executable, action, chain_name]
    return doexec(command + extra_words)

def chain_exists(executable, chain_name):
    """Return True when listing the chain (-L) finishes with return code 0."""
    status, _out, _err = do_chain_action(executable, "-L", chain_name)
    listed_ok = (status == 0)
    return listed_ok

def clear_bridge_rules(vif_name):
    """Remove the VIF's ebtables chain and the FORWARD jumps into it, if the chain exists."""
    chain = get_chain_name(vif_name)
    if not chain_exists(ebtables, chain):
        return
    # Stop forwarding to this VIF's chain (inbound jump first, then outbound).
    for direction in ("-i", "-o"):
        do_chain_action(ebtables, "-D", "FORWARD", [direction, vif_name, "-j", chain])
    # Flush and delete the VIF's chain.
    for flag in ("-F", "-X"):
        do_chain_action(ebtables, flag, chain)

def create_bridge_rules(vif_name, config):
    """Create the ebtables chain that restricts traffic from a locked VIF.

    config is a mapping with the keys "mac", "ipv4_allowed" and
    "ipv6_allowed".  Only IPv4 is filtered per address; if any IPv6 address
    is associated with the VIF, all IPv6 traffic is accepted.
    """
    chain = get_chain_name(vif_name)

    def append_rule(*words):
        # Append one rule to the VIF's own chain.
        do_chain_action(ebtables, "-A", chain, list(words))

    # Forward all traffic on this VIF to a new chain, with default policy DROP.
    do_chain_action(ebtables, "-N", chain)
    for direction in ("-i", "-o"):
        do_chain_action(ebtables, "-A", "FORWARD", [direction, vif_name, "-j", chain])
    do_chain_action(ebtables, "-P", chain, ["DROP"])

    # From here on, add the rules that let valid traffic through.
    mac = config["mac"]
    allowed_v4 = config["ipv4_allowed"]
    allowed_v6 = config["ipv6_allowed"]

    # Accept all traffic going to the VM.
    append_rule("-o", vif_name, "-j", "ACCEPT")
    # Drop everything not coming from the correct MAC.
    append_rule("-s", "!", mac, "-i", vif_name, "-j", "DROP")
    # Accept DHCP.
    append_rule("-p", "IPv4", "-i", vif_name, "--ip-protocol", "UDP",
                "--ip-dport", "67", "-j", "ACCEPT")

    for address in allowed_v4:
        # Accept ARP travelling from known IP addresses, also filtering ARP replies by MAC.
        append_rule("-p", "ARP", "-i", vif_name, "--arp-opcode", "Request",
                    "--arp-ip-src", address, "-j", "ACCEPT")
        append_rule("-p", "ARP", "-i", vif_name, "--arp-opcode", "Reply",
                    "--arp-ip-src", address, "--arp-mac-src", mac, "-j", "ACCEPT")
        # Accept IP travelling from known IP addresses.
        append_rule("-p", "IPv4", "-i", vif_name, "--ip-src", address, "-j", "ACCEPT")

    if allowed_v6 != []:
        # Accept all IPv6 traffic.
        append_rule("-p", "IPv6", "-i", vif_name, "-j", "ACCEPT")

def acquire_lock(path):
    """Open path for writing and block until an exclusive lock on it is held.

    The lock is requested in non-blocking mode; each failed attempt is
    logged and retried after one second.  Returns the open file object,
    which the caller must keep alive for as long as the lock is needed.
    """
    handle = open(path, 'w')
    exclusive_nonblocking = fcntl.LOCK_EX | fcntl.LOCK_NB
    while True:
        send_to_syslog("attempting to acquire lock %s" % path)
        try:
            fcntl.lockf(handle, exclusive_nonblocking)
            send_to_syslog("acquired lock %s" % path)
            return handle
        except IOError as error:
            failure = "failed to acquire lock %s because '%s' - waiting 1 second" % (path, error)
            send_to_syslog(failure)
            time.sleep(1)

def handle_bridge(vif_name, domuuid, devid, action):
    """Apply the requested action to a VIF when the host uses Linux bridging.

    For "clear" the interface is taken down, its ebtables rules are removed,
    it is detached from its bridge and brought back up.  For "filter" the
    same cleanup is done first, then the VIF is re-attached to its bridge
    (unless its locking mode is "disabled") and filtering rules are created
    (only if its locking mode is "locked").  Any other action is ignored.

    The ebtables lock is held for the whole operation and released on exit,
    including when an exception is raised.
    """
    if action not in ("clear", "filter"):
        return

    # ebtables fails if called concurrently, so acquire a lock before starting to call it.
    ebtables_lock_file = acquire_lock(ebtables_lock_path)
    try:
        # Start from a clean slate
        ip_link_set(vif_name, "down")
        clear_bridge_rules(vif_name)
        remove_bridge_port(vif_name)
        if action == "filter":
            # Handle the requested locking mode
            vif = get_vif(vif_name, domuuid, devid)
            config = vif.get_locking_mode()
            bridge_name = vif.get_bridge()
            locking_mode = config["locking_mode"]
            # The VIF must be on the bridge, unless it is in 'disabled' mode
            if locking_mode != "disabled":
                add_bridge_port(bridge_name, vif_name)
            # Only apply rules if the VIF is 'locked'
            if locking_mode == "locked":
                create_bridge_rules(vif_name, config)
        # Finally, bring up the netback device
        ip_link_set(vif_name, "up")
    finally:
        # Closing the file drops the lock explicitly instead of relying on
        # garbage collection or process exit.
        ebtables_lock_file.close()

###############################################################################
# Creation of openflow rules in the case of openvswitch.
###############################################################################

def get_vswitch_port(vif_name):
    """Return the first line printed by `ovs-vsctl get interface <vif_name> ofport`, stripped.

    The command's return code is not checked.
    """
    _rc, output, _errors = doexec([vsctl, "get", "interface", vif_name, "ofport"])
    first_line = output.readline()
    return first_line.strip()

def xs_ofport_path(vif):
    """Return the xenstore node holding the VIF's OpenFlow port number.

    The node is '<type>-ofport' under the VIF's hotplug path, where <type>
    is the first three characters of the VIF name.
    """
    prefix = vif.vif_name[:3]
    base = xs_hotplug_path(vif)
    return '%s/%s-ofport' % (base, prefix)

def set_xs_ofport_path(vif):
    """Write the VIF's current vswitch port number to its ofport xenstore node."""
    node = xs_ofport_path(vif)
    ofport = get_vswitch_port(vif.vif_name)
    write_command = [xs_write, node, ofport]
    doexec(write_command)

def clear_xs_ofport_path(vif):
    """Remove the VIF's ofport node from xenstore."""
    node = xs_ofport_path(vif)
    remove_command = [xs_rm, node]
    doexec(remove_command)

def make_vswitch_external_ids(vif):
    """Build the ovs-vsctl argument words that set the VIF's external ids.

    For every (key, value) pair returned by vif.get_external_ids() the words
    `-- set interface <vif name> external-ids:"<key>"="<value>"` are emitted,
    in iteration order.  Returns an empty list when there are no pairs.
    """
    return [
        word
        for key, value in vif.get_external_ids().items()
        for word in (
            "--", "set", "interface", vif.vif_name,
            'external-ids:"%s"="%s"' % (key, value),
        )
    ]

def add_vswitch_port(bridge_name, vif):
    """Add the VIF as a port on the given vswitch bridge.

    The add-port command also sets the VIF's external ids and is executed
    through run() with ON_ERROR_FAIL.  Afterwards the port number is
    recorded in xenstore and the PVS proxy rules are set up with "add".
    """
    add_port = [vsctl, "--timeout=30", "add-port", bridge_name, vif.vif_name]
    command = add_port + make_vswitch_external_ids(vif)
    run(ON_ERROR_FAIL, command)
    set_xs_ofport_path(vif)
    setup_pvs_proxy_rules(vif, "add")

def remove_vswitch_port(vif):
    """Undo add_vswitch_port: PVS proxy rules, then the xenstore ofport node, then the port.

    The port is deleted with --if-exists, so a missing port is not an error
    for ovs-vsctl.
    """
    setup_pvs_proxy_rules(vif, "remove")
    clear_xs_ofport_path(vif)
    delete_port = [vsctl, "--timeout=30", "--", "--if-exists", "del-port"]
    delete_port.append(vif.vif_name)
    doexec(delete_port)

def get_current_bridge_name_vswitch(vif_name):
    '''Return the bridge that the VIF currently belongs to.

    This is the stripped first line printed by `ovs-vsctl iface-to-br`;
    the command's return code is not checked.
    '''
    _rc, output, _errors = doexec([vsctl, "iface-to-br", vif_name])
    first_line = output.readline()
    return first_line.strip()

def get_parent_bridge(bridge_name):
    '''Get the bridge parent, in case we were given a fake bridge device.

    Will return the same name if it is already a real bridge.  The value is
    the stripped first line printed by `ovs-vsctl br-to-parent`; the
    command's return code is not checked.
    '''
    _rc, output, _errors = doexec([vsctl, "br-to-parent", bridge_name])
    first_line = output.readline()
    return first_line.strip()

def clear_vswitch_rules(vif_name):
    '''Clear OVS rules, if any.

    Nothing is done when no port number can be read for the VIF.  Otherwise
    the flows with cookie 1 that match the VIF's port are deleted from the
    parent of the bridge the VIF is currently on.
    '''
    port = get_vswitch_port(vif_name)
    if port == '':
        return
    current_bridge = get_current_bridge_name_vswitch(vif_name)
    real_bridge = get_parent_bridge(current_bridge)
    match = "in_port=%s,cookie=1/-1" % port
    doexec([ofctl, "del-flows", real_bridge, match])

def add_flow(bridge_name, args):
    """Add one flow to bridge_name with ovs-ofctl, tagging it with cookie 1.

    args is the flow description string; "cookie=1," is prepended to it.
    """
    flow = "cookie=1," + args
    doexec([ofctl, "add-flow", bridge_name, flow])

def create_vswitch_rules(bridge_name, config, vif_name=None):
    """Install the OpenFlow rules that restrict traffic from a locked VIF.

    bridge_name: bridge the VIF is attached to; the flows are added to its
                 parent bridge (see get_parent_bridge).
    config:      mapping with the keys "mac", "ipv4_allowed" and
                 "ipv6_allowed".
    vif_name:    name of the VIF whose vswitch port the flows match on.
                 Optional only for backward compatibility: when omitted, the
                 module-level `vif_name` set by the script entry point is
                 used, which is what this function used to rely on
                 implicitly.

    Raises ValueError when no VIF name is passed and none is available at
    module level.
    """
    if vif_name is None:
        # Legacy two-argument call: fall back to the script-level global.
        vif_name = globals().get("vif_name")
        if vif_name is None:
            raise ValueError("create_vswitch_rules: vif_name was not supplied")

    port = get_vswitch_port(vif_name)
    bridge_name = get_parent_bridge(bridge_name)
    mac = config["mac"]
    ipv4_allowed = config["ipv4_allowed"]
    ipv6_allowed = config["ipv6_allowed"]
    # Allow DHCP traffic (outgoing UDP on port 67).
    add_flow(bridge_name, "in_port=%s,priority=8000,dl_type=0x0800,nw_proto=0x11,"
                          "tp_dst=67,dl_src=%s,idle_timeout=0,action=normal" % (port, mac))
    # Filter ARP requests.
    add_flow(bridge_name, "in_port=%s,priority=7000,dl_type=0x0806,dl_src=%s,arp_sha=%s,"
                          "nw_src=0.0.0.0,idle_timeout=0,action=normal" % (port, mac, mac))
    for ipv4 in ipv4_allowed:
        # Filter ARP responses.
        add_flow(bridge_name, "in_port=%s,priority=7000,dl_type=0x0806,dl_src=%s,arp_sha=%s,"
                              "nw_src=%s,idle_timeout=0,action=normal" % (port, mac, mac, ipv4))
        # Allow traffic from specified ipv4 addresses.
        add_flow(bridge_name, "in_port=%s,priority=6000,dl_type=0x0800,nw_src=%s,"
                              "dl_src=%s,idle_timeout=0,action=normal" % (port, ipv4, mac))
    for ipv6 in ipv6_allowed:
        # Neighbour solicitation.
        add_flow(bridge_name, "in_port=%s,priority=8000,dl_src=%s,icmp6,ipv6_src=%s,"
                              "icmp_type=135,nd_sll=%s,idle_timeout=0,action=normal" % (port, mac, ipv6, mac))
        # Neighbour advertisement.
        add_flow(bridge_name, "in_port=%s,priority=8000,dl_src=%s,icmp6,ipv6_src=%s,"
                              "icmp_type=136,nd_target=%s,idle_timeout=0,action=normal" % (port, mac, ipv6, ipv6))
        # Allow traffic from specified ipv6 addresses.
        for protocol in ("icmp6", "tcp6", "udp6"):
            add_flow(bridge_name, "in_port=%s,priority=5000,dl_src=%s,ipv6_src=%s,%s,action=normal"
                                  % (port, mac, ipv6, protocol))
    # Drop all other neighbour discovery (solicitation 135, advertisement 136).
    for icmp_type in (135, 136):
        add_flow(bridge_name, "in_port=%s,priority=7000,icmp6,icmp_type=%s,action=drop" % (port, icmp_type))
    # Drop other specific ICMPv6 types:
    #   134 router advertisement          137 redirect gateway
    #   146 mobile prefix solicitation    147 mobile prefix advertisement
    #   151 multicast router advertisement
    #   152 multicast router solicitation
    #   153 multicast router termination
    for icmp_type in (134, 137, 146, 147, 151, 152, 153):
        add_flow(bridge_name, "in_port=%s,priority=6000,icmp6,icmp_type=%s,action=drop" % (port, icmp_type))
    # Drop everything else.
    add_flow(bridge_name, "in_port=%s,priority=4000,idle_timeout=0,action=drop" % port)

def handle_vswitch(vif_name, domuuid, devid, action):
    """Apply the requested action to a VIF when the host uses Open vSwitch.

    For "clear" the interface is taken down, its flows are removed, its
    vswitch port is deleted and it is brought back up.  For "filter" the
    same cleanup is done first, then the VIF is re-added to its bridge
    (unless its locking mode is "disabled") and filtering flows are created
    (only if its locking mode is "locked").  Any other action is ignored.
    """
    if (action == "clear") or (action == "filter"):
        # Start from a clean slate
        vif = get_vif(vif_name, domuuid, devid)
        ip_link_set(vif_name, "down")
        clear_vswitch_rules(vif_name)
        remove_vswitch_port(vif)
        if action == "filter":
            # Handle the requested locking mode
            config = vif.get_locking_mode()
            bridge_name = vif.get_bridge()
            locking_mode = config["locking_mode"]
            # The VIF must be on the bridge, unless it is in 'disabled' mode
            if locking_mode != "disabled":
                add_vswitch_port(bridge_name, vif)
            # Only apply rules if the VIF is 'locked'
            if locking_mode == "locked":
                # Pass the VIF name explicitly rather than relying on a global.
                create_vswitch_rules(bridge_name, config, vif_name)
        # Finally, bring up the netback device
        ip_link_set(vif_name, "up")

###############################################################################
# Executable entry point.
###############################################################################

def main(vif_name, domuuid, devid, network_mode, action):
    """Dispatch to the handler for the host's network mode.

    "bridge" is handled by handle_bridge and "openvswitch" by
    handle_vswitch; any other mode does nothing.
    """
    if network_mode == "bridge":
        handler = handle_bridge
    elif network_mode == "openvswitch":
        handler = handle_vswitch
    else:
        return
    handler(vif_name, domuuid, devid, action)

def usage():
    """Print the command-line help text to standard output, one line per print call."""
    help_lines = (
        "Usage:",
        "%s xenopsd_backend interface uuid devid action" % short_command_name,
        "",
        "interface:",
        "    The name of the network interface connecting to the guest.",
        "uuid:",
        "    The UUID of the guest.",
        "devid:",
        "    The number of the network to which the interface is connected.",
        "action: [clear|filter]",
        "    Specifies whether to create filtering rules for the interface, or to clear all rules associated with the interface.",
    )
    for line in help_lines:
        print(line)

if __name__ == "__main__":
    # Exactly five arguments are required after the program name.
    if len(sys.argv) == 6:
        # These names are intentionally bound at module level; other code in
        # this file may read them as globals (vif_name in particular).
        xenopsd_backend, vif_name, domuuid, devid, action = sys.argv[1:6]
        common.set_xenopsd_backend(xenopsd_backend)
        network_mode = get_host_network_mode(vif_name, domuuid, devid)

        call_details = (vif_name, domuuid, devid, network_mode, action)
        send_to_syslog("Called with interface=%s, uuid=%s, devid=%s, network_mode=%s, action=%s" %
            call_details)
        main(vif_name, domuuid, devid, network_mode, action)
    else:
        usage()
        sys.exit(1)
