__author__ = 'hvishwanath'

import logging
import os
import json

from network_utils import *
from appfw.utils.utils import Utils

# Module-level logger used by PortRegistry; records go to the "pdservices"
# logger name rather than this module's __name__.
log = logging.getLogger("pdservices")

class PortRegistry(object):
    """
    Registry that :
    - Keeps track of the mapped higher order ports for NAT based networks
    - Knows which app->interface->port is associated to which external port
    - Knows what higher order ports are available vs used up

    State is persisted as JSON in the file named by PORT_REGISTRY_FILE inside
    ``repodir``.  Used external ports are tracked per protocol and per network
    name, e.g.::

        ALLOCATED_TCP_PORTS = {"iox-nat0": set([4006, 4007]), "iox-nat1": set([7006])}
        ALLOCATED_UDP_PORTS = {"iox-nat0": set([8000, 8005])}

    PORT_REGISTRY maps appid -> interface_name -> {"mappings": {...},
    "network_type": ..., "network_name": ..., "auto_port_map": ...}.

    Every ``PortRegistry(...)`` call returns the same object (see __new__);
    Python still runs __init__ on it for each such call.
    """
    __singleton = None  # the one, true Singleton

    # Class-level defaults only: _load_data() installs fresh per-instance
    # containers, so allocation state is never shared through the class object.
    PORT_REGISTRY = {}
    ALLOCATED_TCP_PORTS = {}
    ALLOCATED_UDP_PORTS = {}

    MODE_AUTO = "auto"
    MODE_1to1 = "1to1"

    _to_serialize = ("PORT_REGISTRY",)

    # Text types accepted for blocked-port specs; `unicode` only exists on Python 2.
    try:
        _STRING_TYPES = (str, unicode)
    except NameError:
        _STRING_TYPES = (str,)

    def __new__(cls, *args, **kwargs):
        """
        Return the single shared instance of ``cls``, creating it on first use.

        Class types are compared (instead of just testing for None) so that
        subclasses create their own singleton objects.
        """
        if cls != type(cls.__singleton):
            # object.__new__ accepts no extra arguments (TypeError on Python 3,
            # DeprecationWarning on Python 2); __init__ consumes them instead.
            cls.__singleton = super(PortRegistry, cls).__new__(cls)
        return cls.__singleton

    def __init__(self, repodir, tcp_pat_range, udp_pat_range, tcp_blocked_ports=None, udp_blocked_ports=None):
        """
        :param repodir: directory that holds the persisted registry file
        :param tcp_pat_range: "start-end" range of external TCP ports for auto assignment
        :param udp_pat_range: "start-end" range of external UDP ports for auto assignment
        :param tcp_blocked_ports: ints and/or port strings (single port or range)
            that must never be mapped for TCP; None means no blocked ports
        :param udp_blocked_ports: same as tcp_blocked_ports, for UDP
        :raises ValueError: if a PAT range is malformed or its end is not above its start
        """
        self.repodir = repodir
        self.datafile = os.path.join(self.repodir, PORT_REGISTRY_FILE)
        self.tcp_pat_range = tcp_pat_range
        self.udp_pat_range = udp_pat_range
        # Port ranges in the lists are expanded into individual port numbers
        self.tcp_blocked_ports = self._expand_port_list(tcp_blocked_ports)
        self.udp_blocked_ports = self._expand_port_list(udp_blocked_ports)

        # Also block the TCP port(s) configured in the "api" config section
        api_port = Utils.getSystemConfigValue("api", "port", "")
        if api_port and int(api_port) not in self.tcp_blocked_ports:
            self.tcp_blocked_ports.append(int(api_port))

        if Utils.getSystemConfigValue("api", "enable_secondary_server", False, "bool"):
            sec_api_port = Utils.getSystemConfigValue("api", "secondary_server_port", "")
            if sec_api_port and int(sec_api_port) not in self.tcp_blocked_ports:
                self.tcp_blocked_ports.append(int(sec_api_port))

        self.tcp_start_port, self.tcp_end_port = self._parse_pat_range(self.tcp_pat_range, "tcp")
        self.udp_start_port, self.udp_end_port = self._parse_pat_range(self.udp_pat_range, "udp")

        self._load_data()

        if not self.PORT_REGISTRY:
            self.PORT_REGISTRY = dict()

    def _expand_port_list(self, ports):
        """
        Flatten a list of ints / port strings into a list of int ports.

        String entries ("7000" or a range) are expanded with interpret_port_string;
        entries of any other type are ignored.  None is treated as an empty list.
        """
        expanded = []
        for port in ports or []:
            if isinstance(port, int):
                expanded.append(port)
            elif isinstance(port, self._STRING_TYPES):
                start, length = interpret_port_string(port)
                expanded.extend(range(start, start + length))
        return expanded

    @staticmethod
    def _parse_pat_range(pat_range, port_type):
        """
        Parse a "start-end" PAT range string into an (int start, int end) tuple.

        :raises ValueError: if the string is malformed or end <= start
        """
        start, end = pat_range.split("-")
        start, end = int(start), int(end)
        # Explicit check instead of `assert`, which is stripped under `python -O`
        if end <= start:
            raise ValueError("Invalid %s PAT range %s: end port must be greater than start port"
                             % (port_type, pat_range))
        return start, end

    def _get_used_ports_set(self, port_type, network_name, network_type):
        """
        Return the live set of used external ports for port_type on network_name.

        If no set exists yet for network_name, an empty one is created and
        registered, so callers may mutate the returned set in place.
        network_type is accepted for interface symmetry but is not used.
        Example:
        "ALLOCATED_TCP_PORTS": {"iox-nat0": set([4006, 4007]), "iox-nat1": set([7006, 7007])}
        "ALLOCATED_UDP_PORTS": {"iox-nat0": set([8000, 8005]), "iox-nat1": set([5000, 5005])}

        :raises ValueError: if port_type is neither "tcp" nor "udp"
        """
        if port_type == "tcp":
            allocated = self.ALLOCATED_TCP_PORTS
        elif port_type == "udp":
            allocated = self.ALLOCATED_UDP_PORTS
        else:
            raise ValueError("Unexpected port type %s" % port_type)

        used_ports_set = allocated.get(network_name)
        if used_ports_set is None:
            used_ports_set = set()
            allocated[network_name] = used_ports_set
        return used_ports_set

    def _set_port_available(self, port, length, port_type, network_name, network_type):
        """
        Remove range(port, port+length) from the used-ports set of network_name.
        This makes this range of ports available to use.

        :raises ValueError: if length < 1 or port_type is unknown
        """
        if length < 1:
            raise ValueError("Unexpected input length")
        used_ports_set = self._get_used_ports_set(port_type, network_name, network_type)

        if length == 1:
            log.debug("Setting %s ephemeral port %s to : AVAILABLE", port_type.upper(), port)
        else:
            log.debug("Setting %s ephemeral ports %s-%s to : AVAILABLE",
                      port_type.upper(), port, port + length - 1)

        # In place, so the registry's own set object is updated
        used_ports_set.difference_update(range(port, port + length))

    def _set_ports_occupied(self, port, length, port_type, network_name, network_type):
        """
        Add range(port, port+length) to the used-ports set of network_name.

        Counterpart of _set_port_available; no conflict check is performed.
        :raises ValueError: if length < 1 or port_type is unknown
        """
        if length < 1:
            raise ValueError("Unexpected input length")
        used_ports_set = self._get_used_ports_set(port_type, network_name, network_type)
        log.debug("Setting %s ephemeral port(s) %s-%s to : OCCUPIED",
                  port_type.upper(), port, port + length - 1)
        used_ports_set.update(range(port, port + length))

    def _is_tcp_port_available(self, app_port, network_name, network_type):
        """
        Return True if none of the TCP port(s) in app_port are in use on network_name
        and, for "nat" networks, none of them are blocked.
        """
        port_start, num_ports = interpret_port_string(app_port)
        cand_set = set(range(port_start, port_start + num_ports))
        if not cand_set.isdisjoint(self._get_used_ports_set("tcp", network_name, network_type)):
            return False
        # NOTE(review): blocked ports are only considered for "nat" here, whereas the
        # UDP variant checks them for every network type - confirm which is intended.
        if network_type == "nat" and not cand_set.isdisjoint(self.tcp_blocked_ports):
            return False
        return True

    def _is_udp_port_available(self, app_port, network_name, network_type):
        """
        Return True if none of the UDP port(s) in app_port are in use on network_name
        or blocked (blocked ports are checked regardless of network_type).
        """
        port_start, num_ports = interpret_port_string(app_port)
        cand_set = set(range(port_start, port_start + num_ports))
        if not cand_set.isdisjoint(self._get_used_ports_set("udp", network_name, network_type)):
            return False
        if not cand_set.isdisjoint(self.udp_blocked_ports):
            return False
        return True

    def _get_auto_port_range(self, app_port, used_ports_set, port_type):
        """
        app_port can be a single port "7000" or a range of ports "7000-7005"
        Search for the lowest unused, non-blocked external port(s) in the tcp/udp
        PAT range.  Used ports are tracked by used_ports_set.
        Example:
        app_port is "7000" => return unused external port 40000
        app_port is "7000-7005" => return unused external port range "40000-40005"

        :return: (int port or "start-end" string, set of the chosen ports)
        :raises ValueError: for an unknown port_type or when no free block fits
        """
        if port_type == "tcp":
            interval_start, interval_end = self.tcp_start_port, self.tcp_end_port
            blocked_ports = self.tcp_blocked_ports
        elif port_type == "udp":
            interval_start, interval_end = self.udp_start_port, self.udp_end_port
            blocked_ports = self.udp_blocked_ports
        else:
            raise ValueError("Unexpected port type %s" % port_type)

        port_start, num_ports = interpret_port_string(app_port)

        # Blocked ports are rejected for custom mappings (_get_custom_port_range),
        # so they must never be handed out automatically either.
        unavailable = used_ports_set | set(blocked_ports)

        # Last candidate start keeps the whole block inside [interval_start, interval_end]
        for candidate in range(interval_start, interval_end - num_ports + 2):
            cand_set = set(range(candidate, candidate + num_ports))
            if cand_set.isdisjoint(unavailable):
                break
        else:
            raise ValueError("No free %s ports available on the system in the range %s-%s"
                             % (port_type, interval_start, interval_end))

        if num_ports > 1:
            return "%s-%s" % (candidate, candidate + num_ports - 1), cand_set
        return candidate, cand_set

    def validate_ports(self, req_ports, port_type):
        """
        Validates if the req_ports are not blocked and not already open on the host.

        :raises ValueError: if any requested port is blocked or already open
        """
        port_start, num_ports = interpret_port_string(req_ports)
        cand_set = set(range(port_start, port_start + num_ports))
        if port_type == "tcp" and not cand_set.isdisjoint(self.tcp_blocked_ports):
            raise ValueError("Port(s) %s overlaps with the tcp blocked ports" % req_ports)
        if port_type == "udp" and not cand_set.isdisjoint(self.udp_blocked_ports):
            raise ValueError("Port(s) %s overlaps with the udp blocked ports" % req_ports)
        open_ports = Utils.get_open_port_list()
        log.debug("open ports: %s", str(open_ports))
        if not cand_set.isdisjoint(open_ports):
            raise ValueError("Port(s) %s overlaps with the already open ports" % req_ports)

    def _get_custom_port_range(self, mapped_port, used_ports_set, port_type):
        """
        mapped_port can be a single port "7000" or a range of ports "7000-7005"
        Check that the requested mapped ports are neither in use nor blocked.

        :return: (mapped_port, set of its ports); a single port is returned as int
        :raises ValueError: on overlap with used or blocked ports
        """
        port_start, num_ports = interpret_port_string(mapped_port)
        cand_set = set(range(port_start, port_start + num_ports))
        if not cand_set.isdisjoint(used_ports_set):
            raise ValueError("Port(s) %s overlaps with other ports already in use." % mapped_port)

        if port_type == "tcp" and not cand_set.isdisjoint(self.tcp_blocked_ports):
            raise ValueError("Port(s) %s overlaps with the tcp blocked ports" % mapped_port)

        if port_type == "udp" and not cand_set.isdisjoint(self.udp_blocked_ports):
            raise ValueError("Port(s) %s overlaps with the udp blocked ports" % mapped_port)

        if num_ports == 1:
            mapped_port = int(mapped_port)
        return mapped_port, cand_set

    def _port_key_matches(self, key, app_port, port_start, num_ports):
        """
        Return True if the port_map key `key` denotes exactly the port(s) of app_port.

        Keys are compared by value, not by substring, so app port 80 does not
        match a key such as "8080".  port_start/num_ports are app_port parsed.
        """
        if str(key) == str(app_port):
            return True
        try:
            key_start, key_num = interpret_port_string(key)
        except Exception:
            # The parser's failure modes are not visible here; a key it cannot
            # interpret simply does not match, as unmatched keys are ignored.
            return False
        return (key_start, key_num) == (port_start, num_ports)

    def _get_pat_port(self, app_port, network_type, req_port_map, port_type, network_name, port_map_bridge=False):
        """
        Allocate the external port(s) for one app port and mark them used.

        app_port can be a single port "7000" or a range of ports "7000-7005"
        For networks other than "nat"/"nat_docker" (unless port_map_bridge is set)
        app_port is returned unchanged and nothing is allocated.
        req_port_map is the requested port map from activate payload
        Example Activate payload:
        {
            "resources": {
            "profile": "custom",
            "cpu": "50",
            "memory": "50",
            "disk": "100",
            "network": [{"interface-name": "eth0", "network-name": "iox-nat0", "port_map":{ } }]
            }
        }
        Example port_maps in the payload:
        "port_map": {"mode": "auto"}  <--- all ports auto
        "port_map": {"mode": "1to1"}  <--- all port 1to1
        "port_map": {
        "mode": "auto",
        "tcp": {
        "9000": "15000",
        "10100:10200": "20100:20200"}
        }
        "port_map": {
        "mode": "1to1",
        "udp": {
        "7000": "15000",
        "30100:30105": "40100:40105"}
        }

        :raises ValueError: on mismatched mapping sizes, unknown mode, overlap,
            or exhaustion of the PAT range
        """
        # No host based PAT needed for non nat network
        if network_type not in ("nat", "nat_docker") and not port_map_bridge:
            return app_port

        port_start, num_ports = interpret_port_string(app_port)
        mapped_port = None

        # If port_map is missing in the payload, default to Auto port assignment
        port_map_mode = self.MODE_AUTO
        if isinstance(req_port_map, dict):
            port_map_mode = req_port_map.get("mode", self.MODE_AUTO)
            port_map_for_type = req_port_map.get(port_type)
            if isinstance(port_map_for_type, dict):
                for key, value in port_map_for_type.items():
                    if not self._port_key_matches(key, app_port, port_start, num_ports):
                        continue
                    mapped_port = value
                    _, mapped_num_ports = interpret_port_string(mapped_port)
                    if num_ports != mapped_num_ports:
                        raise ValueError("Invalid port mapping. Number of ports do not match "
                                         "between %s and %s" % (str(app_port), mapped_port))
                    break

        used_ports_set = self._get_used_ports_set(port_type, network_name, network_type)

        if port_map_mode == self.MODE_AUTO:
            if mapped_port:
                port_ret, cand_set = self._get_custom_port_range(mapped_port, used_ports_set, port_type)
            else:
                # Auto-assign external ports from tcp/udp port range
                port_ret, cand_set = self._get_auto_port_range(app_port, used_ports_set, port_type)
        elif port_map_mode == self.MODE_1to1:
            if not mapped_port:
                mapped_port = str(app_port)
            port_ret, cand_set = self._get_custom_port_range(mapped_port, used_ports_set, port_type)
        else:
            raise ValueError("Unknown mode for port mapping: %s" % port_map_mode)

        log.debug("Setting %s ephemeral port(s) %s to : OCCUPIED", port_type, port_ret)
        used_ports_set.update(cand_set)
        return port_ret

    def _is_ports_reallocate(self, network_name, ifmap, req_port_map):
        """
        Make a decision whether to use port mapping from PORT_REGISTRY or
        to reallocate the ports. If ports need to be reallocated return True.
        """
        # If reactivating on different network, always reallocate
        if ifmap.get("network_name") != network_name:
            return True

        # Reactivating on same network
        # If no port map is provided, default is to use mapping from PORT_REGISTRY
        if not req_port_map:
            return False

        # Previous mapping all auto and requested mapping also all auto: keep ports.
        # .get(): entries written by set_port_mapping carry no "auto_port_map" flag.
        if ifmap.get("auto_port_map") and self._is_auto_port_map(req_port_map):
            return False

        return True

    def _is_auto_port_map(self, req_port_map):
        """
        If requested port map asks auto mapping for all ports, return True.

        A missing or non-dict port map counts as auto, matching _get_pat_port,
        which falls back to auto assignment for anything that is not a dict.
        """
        if not req_port_map or not isinstance(req_port_map, dict):
            return True

        port_map_mode = req_port_map.get("mode")
        if port_map_mode and port_map_mode != self.MODE_AUTO:
            return False
        # Mode is auto, but some ports may be custom mapped
        if req_port_map.get("tcp") or req_port_map.get("udp"):
            return False
        return True

    def get_mapping(self, appid, interface_name, network_type, ports, req_port_map, network_name, port_map_bridge=False):
        """
        Main method that sets up PAT for requested ports.

        :param appid: ID of the requesting app
        :param interface_name: app interface the ports belong to
        :param network_type: nat or bridge.
        :param ports: ports requested by app (in the same structure as in the descriptor)
        :param req_port_map: requested port map in activate payload
        :param network_name: Name of the network in activate payload
        :param port_map_bridge: allocate external ports even for non-NAT networks
        :return: Returns the same ports dict, but now with mapping info.
        Ex. Input: ports: tcp : [9000, 6000], Output: ports:tcp:[[9000, 40000, None], [6000, 40001, None]]
        :raises ValueError: on an invalid port type or any allocation failure; in
            that case all allocation state is rolled back to the last saved file
        """
        appmap = self.PORT_REGISTRY.get(appid)
        if appmap is None:
            appmap = dict()

        if interface_name in appmap:
            # Mapping for this particular interface for the given app exists.
            old_ifmap = appmap[interface_name]
            if not self._is_ports_reallocate(network_name, old_ifmap, req_port_map):
                # Reallocate not needed, use mapping from PORT_REGISTRY
                return old_ifmap["mappings"]
            self.remove_mapping(appid, interface_name,
                                old_ifmap.get("network_type"), old_ifmap.get("network_name"))

        # Create an empty entry here and fill it up below
        ifmap = {"mappings": dict()}
        appmap[interface_name] = ifmap

        if ports:
            log.debug("Network type: %s", network_type)
            log.debug("Input ports: %s", ports)
            log.debug("Interface name: %s", interface_name)

            try:
                # Map all of the ports at once; if any one fails, roll back all of them
                for port_type, type_ports in ports.items():
                    if port_type not in ("tcp", "udp"):
                        raise ValueError("Invalid port type %s. Cannot proceed!" % port_type)

                    type_mapping_list = []
                    for p in type_ports:
                        req_ports = p
                        description = None
                        if isinstance(p, dict):
                            req_ports = p.get('port')
                            description = p.get('description')
                        pat_port = self._get_pat_port(req_ports, network_type, req_port_map,
                                                      port_type, network_name, port_map_bridge)
                        type_mapping_list.append([req_ports, pat_port, description])

                    ifmap["mappings"][port_type] = type_mapping_list
            except BaseException:
                # An error occurred, possibly an overlapping port.  Reload the last
                # saved state; _load_data() rebuilds every container (empty when
                # no file exists), discarding the partial allocations above.
                self._load_data()
                raise

            ifmap["network_type"] = network_type
            ifmap["network_name"] = network_name
            ifmap["auto_port_map"] = self._is_auto_port_map(req_port_map)

        self.PORT_REGISTRY[appid] = appmap

        # Save this data
        self._save_data()

        mappings = ifmap["mappings"]
        log.debug("Mapped ports : %s", mappings)
        return mappings

    def set_port_mapping(self, appid, interface_name, network_type=None, port_mapping={}, network_name=None):
        """
        Record a given port mapping for appid/interface_name and persist it.

        Any existing entry for the interface is replaced.  When network_type is
        set, the external port(s) of every mapping are marked as used on
        network_name.

        :param port_mapping: {"tcp"|"udp": [[internal_port, external_port, ...], ...]};
            the {} default is never mutated, so sharing it is harmless
        :param network_name: network whose used-port sets receive the external
            ports (optional, defaults to None).  NOTE: JSON object keys are
            strings, so a None network name is persisted as "null".
        :raises ValueError: on an unknown port type (checked before any state
            changes), or when no port mapping is given unless port_mapping is {}
            and network_type is None
        """
        log.debug("Setting the port mapping, app %s, interface %s, network type %s and ports %s" %(appid, interface_name, network_type, port_mapping))
        appmap = self.PORT_REGISTRY.get(appid)
        if appmap is None:
            appmap = dict()

        if port_mapping:
            # Validate every key first so a bad entry leaves no ports half-marked
            for port_type in port_mapping:
                if port_type not in ("tcp", "udp"):
                    raise ValueError("Invalid port type %s. Cannot proceed!" % port_type)
            log.debug("Network type: %s", network_type)
            log.debug("Input ports: %s", port_mapping)
            log.debug("Interface name: %s", interface_name)
        #For an app which only asked for the interfaces, didn't ask for ports
        elif port_mapping == {} and network_type is None:
            log.debug("As port mapping and network type are not provided so setting the empty port mapping")
        else:
            raise ValueError("Port mapping needs to be provided")

        ifmap = {"mappings": dict()}
        appmap[interface_name] = ifmap

        if port_mapping:
            for port_type, type_ports in port_mapping.items():
                if network_type:
                    for port_map in type_ports:
                        # port_map is [internal_port, external_port, ...]; only the
                        # external side occupies a host port
                        ext_start, ext_count = interpret_port_string(port_map[1])
                        self._set_ports_occupied(ext_start, ext_count, port_type,
                                                 network_name, network_type)
                ifmap["mappings"][port_type] = type_ports
                ifmap["network_type"] = network_type
                ifmap["network_name"] = network_name

        self.PORT_REGISTRY[appid] = appmap
        self._save_data()

    def remove_mapping(self, appid, interface_name, network_type, network_name):
        """
        Remove all port mappings of one interface and free up the corresponding ephemeral ports.

        Ports are only freed when network_type is truthy.  A missing app or
        interface entry is logged and ignored.
        :param appid: ID of the app
        :param interface_name: interface whose entry is removed
        :param network_type: network type recorded for the interface
        :param network_name: network whose used-port sets the ports are freed from
        :return: None
        """
        appmap = self.PORT_REGISTRY.get(appid)
        if appmap is None:
            log.error("No registry entry found for app %s", appid)
            return

        ifacemap = appmap.get(interface_name)
        if ifacemap is None:
            log.error("No registry entries found for app %s, interface %s", appid, interface_name)
            return

        if network_type:
            log.debug("Freeing up ephemeral ports occupied by app %s, interface %s", appid, interface_name)
            mappings = ifacemap.get("mappings") or {}
            for port_type, type_mapping_list in mappings.items():
                for mapping in type_mapping_list:
                    # mapping is [requested_port, external_port, ...]
                    port_start, num_ports = interpret_port_string(mapping[1])
                    self._set_port_available(port_start, num_ports, port_type, network_name, network_type)
        else:
            log.debug("Non network type. Nothing to do..")

        log.debug("Removing registry entry for app %s", appid)
        # Remove entry from mapping registry
        appmap.pop(interface_name)

        self._save_data()

    def clear_app_port_entry(self, appid):
        """
        Remove all port mapping entries for this app, freeing their ports.
        :param appid: ID of the app
        :return: None
        """
        appmap = self.PORT_REGISTRY.get(appid)
        if appmap is None:
            log.debug("No registry entry found for app %s", appid)
            return

        # Iterate over a snapshot: remove_mapping() pops entries from appmap
        for iface in list(appmap):
            network_type = appmap[iface].get("network_type")
            network_name = appmap[iface].get("network_name")
            log.debug("Clearing entries for appid %s, interface %s, network type %s", appid, iface, network_type)
            self.remove_mapping(appid, iface, network_type, network_name)

        # Remove the appid from the registry
        self.PORT_REGISTRY.pop(appid, None)

        self._save_data()

    @classmethod
    def getInstance(cls, repodir, tcp_pat_range, udp_pat_range, tcp_blocked_ports, udp_blocked_ports):
        '''
        Returns a singleton instance of the class, constructing it on first use.

        If construction fails the exception propagates and nothing stays cached,
        so a later call can retry.
        '''
        if not cls.__singleton:
            try:
                cls.__singleton = PortRegistry(repodir, tcp_pat_range, udp_pat_range,
                                               tcp_blocked_ports, udp_blocked_ports)
            except Exception:
                # __new__ already cached the object before __init__ failed;
                # drop it so the half-initialised instance is never handed out.
                cls.__singleton = None
                raise
        return cls.__singleton

    def _save_data(self):
        """
        Persist PORT_REGISTRY and the per-network used-port sets to self.datafile as JSON.

        Sets are written as lists because JSON has no set type.
        """
        # Simply overwrite the file. Should be okay, since this operation will not be done frequently
        # and the entries will not exceed more than a few.
        d = {
            "PORT_REGISTRY": self.PORT_REGISTRY,
            "ALLOCATED_TCP_PORTS": dict((nw_name, list(used_set))
                                        for nw_name, used_set in self.ALLOCATED_TCP_PORTS.items()),
            "ALLOCATED_UDP_PORTS": dict((nw_name, list(used_set))
                                        for nw_name, used_set in self.ALLOCATED_UDP_PORTS.items()),
        }
        with open(self.datafile, "w") as f:
            json.dump(d, f)

    @staticmethod
    def _sets_from_lists(raw):
        """Convert persisted {network_name: [ports]} into {network_name: set(ports)}; non-dicts give {}."""
        if not isinstance(raw, dict):
            return {}
        return dict((nw_name, set(used_list)) for nw_name, used_list in raw.items())

    def _load_data(self):
        """
        (Re)load all registry state from self.datafile, replacing in-memory state.

        A missing or unreadable file yields empty state.  Every container is
        rebuilt from scratch, so allocations made since the last save are
        discarded, including those on networks absent from the file.
        get_mapping() relies on this to roll back a failed allocation.
        """
        d = dict()
        try:
            if os.path.isfile(self.datafile):
                with open(self.datafile, "r") as f:
                    d = json.load(f)
                log.debug("Loaded portregistry data from %s", self.datafile)
        except Exception as ex:
            log.exception("Error loading port registry data from %s:%s" % (self.datafile, str(ex)))

        if not isinstance(d, dict):
            d = dict()

        self.PORT_REGISTRY = d.get("PORT_REGISTRY") or {}
        self.ALLOCATED_TCP_PORTS = self._sets_from_lists(d.get("ALLOCATED_TCP_PORTS"))
        self.ALLOCATED_UDP_PORTS = self._sets_from_lists(d.get("ALLOCATED_UDP_PORTS"))

    def serialize(self):
        """Return a dict of the attributes named in _to_serialize that exist on this instance."""
        return dict((k, getattr(self, k)) for k in self._to_serialize if hasattr(self, k))

    def __str__(self):
        return str(self.serialize())

    def __repr__(self):
        return str(self.serialize())

# NOTE(review): disabled manual test harness kept inside a string literal, so it
# never executes. It is stale against the class above: get_mapping() requires
# the req_port_map and network_name arguments, and remove_mapping() requires
# network_name. The calls below pass neither, so they would raise TypeError if
# re-enabled. It also uses Python 2 print statements.
'''
if "__main__" == __name__:

    logging.basicConfig(
         level=logging.DEBUG,
         datefmt='%H:%M:%S'
    )

    # set up logging to console
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    # set a format which is simpler for console use
    # formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
    # console.setFormatter(formatter)
    # add the handler to the root logger
    logging.getLogger().addHandler(console)

    # container_info = {
    #     "libvirt": {
    #         "connection_str": "lxc:///"
    #     }
    # }

    m = PortRegistry("/tmp", "40000-41000", "45000-46000")
    x = m.get_mapping("nt", "eth1", "bridge", {"tcp": [9000, 90001, 90002, 9003], "udp": [8888, 2222, 3333]})

    x = m.get_mapping("nt", "eth2", "nat", {"tcp": [9000, 90001, 90002, 9003], "udp": [8888, 2222, 3333]})

    print str(m)

    m.remove_mapping("nt", "eth2","nat")

    print str(m)
    '''
