#!/usr/bin/env python
#
# Unix SMB/CIFS implementation.
# Copyright (C) Amitay Isaacs <amitay@gmail.com> 2012
#
# Upgrade DNS provision from BIND9_FLATFILE to BIND9_DLZ or SAMBA_INTERNAL
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import sys
import os
import optparse
import logging
import grp
from base64 import b64encode
import shlex

sys.path.insert(0, "bin/python")

import ldb
import samba
from samba import param
from samba.auth import system_session
from samba.ndr import (
    ndr_pack,
    ndr_unpack )
import samba.getopt as options
from samba.upgradehelpers import (
    get_paths,
    get_ldbs )
from samba.dsdb import DS_DOMAIN_FUNCTION_2003
from samba.provision import (
    find_provision_key_parameters,
    interface_ips_v4,
    interface_ips_v6 )
from samba.provision.common import (
    setup_path,
    setup_add_ldif )
from samba.provision.sambadns import (
    ARecord,
    AAAARecord,
    CNameRecord,
    NSRecord,
    SOARecord,
    SRVRecord,
    TXTRecord,
    get_dnsadmins_sid,
    add_dns_accounts,
    create_dns_partitions,
    fill_dns_data_partitions,
    create_dns_dir,
    secretsdb_setup_dns,
    create_samdb_copy,
    create_named_conf,
    create_named_txt )
from samba.dcerpc import security

samba.ensure_external_module("dns", "dnspython")
import dns.zone, dns.rdatatype

__docformat__ = 'restructuredText'


def find_bind_gid():
    """Return the numeric id of the system group used by bind9.

    The group names "bind" and "named" are looked up in that order and the
    gid of the first one that exists is returned.  None is returned when
    neither group is known to the system.
    """
    gid = None
    for group_name in ("bind", "named"):
        try:
            gid = grp.getgrnam(group_name).gr_gid
        except KeyError:
            # No such group on this system, try the next candidate
            continue
        break
    return gid


def fix_names(pnames):
    """Flatten the multi-valued fields of a provision names object

    rootdn, domaindn, configdn, schemadn and wheel_gid are each replaced by
    their first element, and serverdn is replaced by its str() form.  The
    object is modified in place and the very same object is returned.
    """
    for attr in ("rootdn", "domaindn", "configdn", "schemadn", "wheel_gid"):
        first_value = getattr(pnames, attr)[0]
        setattr(pnames, attr, first_value)
    pnames.serverdn = str(pnames.serverdn)
    return pnames


def convert_dns_rdata(rdata, serial=1):
    """Build the dnsRecord-format object matching a dnspython rdata

    :param rdata: resource record data; its rdtype selects the record class
    :param serial: serial passed to every record class except SOA, which
        takes the serial carried by the rdata itself
    :return: an ARecord, AAAARecord, CNameRecord, NSRecord, SRVRecord,
        TXTRecord or SOARecord, or None when the record type is not one
        of A, AAAA, CNAME, NS, SRV, TXT or SOA
    """
    rdtype = rdata.rdtype

    if rdtype == dns.rdatatype.A:
        return ARecord(rdata.address, serial=serial)

    if rdtype == dns.rdatatype.AAAA:
        return AAAARecord(rdata.address, serial=serial)

    if rdtype == dns.rdatatype.CNAME:
        return CNameRecord(rdata.target.to_text(), serial=serial)

    if rdtype == dns.rdatatype.NS:
        return NSRecord(rdata.target.to_text(), serial=serial)

    if rdtype == dns.rdatatype.SRV:
        target = rdata.target.to_text()
        return SRVRecord(target, int(rdata.port),
                         priority=int(rdata.priority),
                         weight=int(rdata.weight),
                         serial=serial)

    if rdtype == dns.rdatatype.TXT:
        # The textual form is split with shell-like quoting rules to get
        # the individual strings of the record
        strings = shlex.split(rdata.to_text())
        return TXTRecord(strings, serial=serial)

    if rdtype == dns.rdatatype.SOA:
        return SOARecord(rdata.mname.to_text(), rdata.rname.to_text(),
                         serial=int(rdata.serial),
                         refresh=int(rdata.refresh),
                         retry=int(rdata.retry),
                         expire=int(rdata.expire),
                         minimum=int(rdata.minimum))

    # Any other record type is not handled
    return None


def _add_dns_node(samdb, logger, dn, dns_rec, what):
    """Add one dnsNode object holding the packed records in dns_rec

    :param samdb: database the node is added to
    :param logger: logger used for the error/debug messages
    :param dn: DN (as string) of the dnsNode to create
    :param dns_rec: list of packed dnsRecord values
    :param what: short description of the node used in the log messages
    :raise: whatever samdb.add() raises; the failure is logged first
    """
    msg = ldb.Message(ldb.Dn(samdb, dn))
    msg["objectClass"] = ["top", "dnsNode"]
    msg["dnsRecord"] = ldb.MessageElement(dns_rec, ldb.FLAG_MOD_ADD,
                                          "dnsRecord")
    try:
        samdb.add(msg)
    except Exception:
        logger.error("Failed to add %s" % what)
        raise
    logger.debug("Added %s" % what)


def import_zone_data(samdb, logger, zone, serial, domaindn, forestdn,
                     dnsdomain, dnsforest):
    """Insert zone data in DNS partitions

    The apex (@) node of the zone is removed from `zone` and written as the
    @ record of both the domain zone (SOA, NS and the remaining apex
    records) and the _msdcs forest zone (SOA and NS only).  Every other node
    is written below the forest zone when its name lies under
    _msdcs.<dnsforest>, and below the domain zone otherwise.

    :param samdb: database the dnsNode objects are added to
    :param logger: logger for progress and error messages
    :param zone: parsed zone with absolute (non-relativized) names; it is
        modified, the apex node is deleted from it
    :param serial: serial given to the records of the non-apex nodes; apex
        records are converted with convert_dns_rdata's default serial
    :param domaindn: DN of the domain
    :param forestdn: DN of the forest root
    :param dnsdomain: DNS name of the domain
    :param dnsforest: DNS name of the forest
    :raise: any exception raised while adding an object is re-raised after
        being logged
    """
    domain_root = dns.name.Name(dnsdomain.split('.') + [''])
    domain_prefix = "DC=%s,CN=MicrosoftDNS,DC=DomainDnsZones,%s" % (dnsdomain,
                                                                    domaindn)

    dnsmsdcs = "_msdcs.%s" % dnsforest
    forest_root = dns.name.Name(dnsmsdcs.split('.') + [''])
    forest_prefix = "DC=%s,CN=MicrosoftDNS,DC=ForestDnsZones,%s" % (dnsmsdcs,
                                                                    forestdn)

    # Extract @ record
    at_record = zone.get_node(domain_root)
    zone.delete_node(domain_root)

    # SOA record
    rdset = at_record.get_rdataset(dns.rdataclass.IN, dns.rdatatype.SOA)
    soa_rec = ndr_pack(convert_dns_rdata(rdset[0]))
    at_record.delete_rdataset(dns.rdataclass.IN, dns.rdatatype.SOA)

    # NS record
    rdset = at_record.get_rdataset(dns.rdataclass.IN, dns.rdatatype.NS)
    ns_rec = ndr_pack(convert_dns_rdata(rdset[0]))
    at_record.delete_rdataset(dns.rdataclass.IN, dns.rdatatype.NS)

    # Remaining apex records (A/AAAA, ...).  Record types that cannot be
    # converted are skipped instead of being handed to ndr_pack() as None.
    ip_recs = []
    for rdset in at_record:
        for r in rdset:
            rec = convert_dns_rdata(r)
            if rec is None:
                logger.warn("Unsupported record type (%s) for %s, ignoring",
                            dns.rdatatype.to_text(r.rdtype), "@")
            else:
                ip_recs.append(ndr_pack(rec))

    # Add @ record for domain
    _add_dns_node(samdb, logger, 'DC=@,%s' % domain_prefix,
                  [soa_rec, ns_rec] + ip_recs, "@ record for domain")

    # Add @ record for forest
    _add_dns_node(samdb, logger, 'DC=@,%s' % forest_prefix,
                  [soa_rec, ns_rec], "@ record for forest")

    # Add remaining records in domain and forest
    for node in zone.nodes:
        # relativize() leaves the name untouched when it is not below the
        # given origin, which is how domain names are told from forest ones
        name = node.relativize(forest_root).to_text()
        if name == node.to_text():
            name = node.relativize(domain_root).to_text()
            dn = "DC=%s,%s" % (name, domain_prefix)
            fqdn = "%s.%s" % (name, dnsdomain)
        else:
            dn = "DC=%s,%s" % (name, forest_prefix)
            fqdn = "%s.%s" % (name, dnsmsdcs)

        dns_rec = []
        for rdataset in zone.nodes[node]:
            for rdata in rdataset:
                rec = convert_dns_rdata(rdata, serial)
                if rec is None:
                    # Both values are passed as logging arguments; the
                    # record type lives in the rdata's rdtype attribute
                    logger.warn("Unsupported record type (%s) for %s, ignoring",
                                dns.rdatatype.to_text(rdata.rdtype), name)
                else:
                    dns_rec.append(ndr_pack(rec))

        if not dns_rec:
            # Nothing convertible on this node: skip it rather than trying
            # to add a dnsNode with an empty dnsRecord attribute
            logger.warn("No supported records for %s, skipping", fqdn)
            continue

        _add_dns_node(samdb, logger, dn, dns_rec, "DNS record %s" % (fqdn))


# dnsprovision creates the application partitions for AD-based DNS. This is mainly
# needed when the existing provision was created with an earlier snapshot of samba4
# that did not have support for DNS partitions.

if __name__ == '__main__':

    # Setup command line parser
    parser = optparse.OptionParser("upgradedns [options]")
    sambaopts = options.SambaOptions(parser)
    credopts = options.CredentialsOptions(parser)

    parser.add_option_group(options.VersionOptions(parser))
    parser.add_option_group(sambaopts)
    parser.add_option_group(credopts)

    parser.add_option("--dns-backend", type="choice", metavar="<BIND9_DLZ|SAMBA_INTERNAL>",
                      choices=["SAMBA_INTERNAL", "BIND9_DLZ"], default="SAMBA_INTERNAL",
                      help="The DNS server backend, default SAMBA_INTERNAL")
    parser.add_option("--migrate", type="choice", metavar="<yes|no>",
                      choices=["yes","no"], default="yes",
                      help="Migrate existing zone data, default yes")
    parser.add_option("--verbose", help="Be verbose", action="store_true")

    opts = parser.parse_args()[0]

    if opts.dns_backend is None:
        opts.dns_backend = 'SAMBA_INTERNAL'

    # --migrate is a "yes"/"no" choice string and both values are truthy,
    # so it has to be compared explicitly.  Records are auto-generated
    # (autofill) only when migration of the zone file is not wanted.
    autofill = (opts.migrate != "yes")

    # Set up logger
    logger = logging.getLogger("upgradedns")
    logger.addHandler(logging.StreamHandler(sys.stdout))
    logger.setLevel(logging.INFO)
    if opts.verbose:
        logger.setLevel(logging.DEBUG)

    lp = sambaopts.get_loadparm()
    lp.load(lp.configfile)
    creds = credopts.get_credentials(lp)

    logger.info("Reading domain information")
    paths = get_paths(param, smbconf=lp.configfile)
    paths.bind_gid = find_bind_gid()
    ldbs = get_ldbs(paths, creds, system_session(), lp)
    pnames = find_provision_key_parameters(ldbs.sam, ldbs.secrets, ldbs.idmap,
                                           paths, lp.configfile, lp)
    names = fix_names(pnames)

    if names.domainlevel < DS_DOMAIN_FUNCTION_2003:
        logger.error("Cannot create AD based DNS for OS level < 2003")
        sys.exit(1)

    logger.info("Looking up IPv4 addresses")
    hostip = interface_ips_v4(lp)
    try:
        hostip.remove('127.0.0.1')
    except ValueError:
        pass
    if not hostip:
        logger.error("No IPv4 addresses found")
        sys.exit(1)
    else:
        # Only the first address found is used
        hostip = hostip[0]
        logger.debug("IPv4 addresses: %s" % hostip)

    logger.info("Looking up IPv6 addresses")
    hostip6 = interface_ips_v6(lp, linklocal=False)
    if not hostip6:
        hostip6 = None
    else:
        # Only the first address found is used
        hostip6 = hostip6[0]
        logger.debug("IPv6 addresses: %s" % hostip6)

    domaindn = names.domaindn
    forestdn = names.rootdn

    dnsdomain = names.dnsdomain.lower()
    dnsforest = dnsdomain

    site = names.sitename
    hostname = names.hostname
    dnsname = '%s.%s' % (hostname, dnsdomain)

    domainsid = names.domainsid
    domainguid = names.domainguid
    ntdsguid = names.ntdsguid

    # Check for DNS accounts and create them if required.  Any failure to
    # find the DnsAdmins SID is taken to mean the accounts are missing.
    try:
        msg = ldbs.sam.search(base=domaindn, scope=ldb.SCOPE_DEFAULT,
                              expression='(sAMAccountName=DnsAdmins)',
                              attrs=['objectSid'])
        dnsadmins_sid = ndr_unpack(security.dom_sid, msg[0]['objectSid'][0])
    except Exception:
        logger.info("Adding DNS accounts")
        add_dns_accounts(ldbs.sam, domaindn)
        dnsadmins_sid = get_dnsadmins_sid(ldbs.sam, domaindn)
    else:
        logger.info("DNS accounts already exist")

    # Import dns records from zone file.  zone and serial are only used
    # further down when autofill is still False, i.e. when parsing worked.
    if os.path.exists(paths.dns):
        logger.info("Reading records from zone file %s" % paths.dns)
        try:
            zone = dns.zone.from_file(paths.dns, relativize=False)
            rrset = zone.get_rdataset("%s." % dnsdomain, dns.rdatatype.SOA)
            serial = int(rrset[0].serial)
        except Exception as e:
            logger.warn("Error parsing DNS data from '%s' (%s)" % (paths.dns, str(e)))
            logger.warn("DNS records will be automatically created")
            autofill = True
    else:
        logger.info("No zone file %s" % paths.dns)
        logger.warn("DNS records will be automatically created")
        autofill = True

    # Create DNS partitions if missing and fill DNS information.  A failed
    # lookup of the partition (including an empty search result) is taken
    # to mean that the partitions do not exist yet.
    try:
        expression = '(|(dnsRoot=DomainDnsZones.%s)(dnsRoot=ForestDnsZones.%s))' % \
                     (dnsdomain, dnsforest)
        msg = ldbs.sam.search(base=names.configdn, scope=ldb.SCOPE_DEFAULT,
                              expression=expression, attrs=['nCName'])
        ncname = msg[0]['nCName'][0]
    except Exception:
        logger.info("Creating DNS partitions")
        create_dns_partitions(ldbs.sam, domainsid, names, domaindn, forestdn,
                              dnsadmins_sid)

        logger.info("Populating DNS partitions")
        fill_dns_data_partitions(ldbs.sam, domainsid, site, domaindn, forestdn,
                                 dnsdomain, dnsforest, hostname, hostip, hostip6,
                                 domainguid, ntdsguid, dnsadmins_sid,
                                 autofill=autofill)

        if not autofill:
            logger.info("Importing records from zone file")
            import_zone_data(ldbs.sam, logger, zone, serial, domaindn, forestdn,
                             dnsdomain, dnsforest)
    else:
        logger.info("DNS partitions already exist")

    # Mark that we are hosting DNS partitions: both DNS naming contexts
    # must be listed in msDS-hasMasterNCs and must not be listed in
    # hasPartialReplicaNCs.  Errors here are not handled and abort the run.
    dns_nclist = [ 'DC=DomainDnsZones,%s' % domaindn,
                   'DC=ForestDnsZones,%s' % forestdn ]

    msgs = ldbs.sam.search(base=names.serverdn, scope=ldb.SCOPE_DEFAULT,
                           expression='(objectclass=nTDSDSa)',
                           attrs=['hasPartialReplicaNCs',
                                  'msDS-hasMasterNCs'])
    msg = msgs[0]

    master_ncs = msg.get("msDS-hasMasterNCs")
    master_nclist = list(master_ncs) if master_ncs else []

    partial_ncs = msg.get("hasPartialReplicaNCs")
    partial_nclist = list(partial_ncs) if partial_ncs else []

    modified_master = False
    modified_partial = False
    for nc in dns_nclist:
        if nc not in master_nclist:
            master_nclist.append(nc)
            modified_master = True
        if nc in partial_nclist:
            partial_nclist.remove(nc)
            modified_partial = True

    if modified_master or modified_partial:
        logger.debug("Updating msDS-hasMasterNCs and hasPartialReplicaNCs attributes")
        m = ldb.Message()
        m.dn = msg.dn
        if modified_master:
            m["msDS-hasMasterNCs"] = ldb.MessageElement(master_nclist,
                                                        ldb.FLAG_MOD_REPLACE,
                                                        "msDS-hasMasterNCs")
        if modified_partial:
            if partial_nclist:
                m["hasPartialReplicaNCs"] = ldb.MessageElement(partial_nclist,
                                                               ldb.FLAG_MOD_REPLACE,
                                                               "hasPartialReplicaNCs")
            else:
                # Nothing is left in the list: delete the attribute using
                # the values originally read from the directory
                m["hasPartialReplicaNCs"] = ldb.MessageElement(partial_ncs,
                                                               ldb.FLAG_MOD_DELETE,
                                                               "hasPartialReplicaNCs")
        ldbs.sam.modify(m)

    # Special stuff for DLZ backend
    if opts.dns_backend == "BIND9_DLZ":
        # Check if dns-HOSTNAME account exists and create it if required
        try:
            dn = 'samAccountName=dns-%s,CN=Principals' % hostname
            msg = ldbs.secrets.search(expression='(dn=%s)' % dn, attrs=['secret'])
            dnssecret = msg[0]['secret'][0]
        except Exception:
            logger.info("Adding dns-%s account" % hostname)

            # Best effort: remove a dns-HOSTNAME account that exists in
            # sam.ldb without a matching secret, so it can be re-created
            try:
                msg = ldbs.sam.search(base=domaindn, scope=ldb.SCOPE_DEFAULT,
                                      expression='(sAMAccountName=dns-%s)' % (hostname),
                                      attrs=['clearTextPassword'])
                dn = msg[0].dn
                ldbs.sam.delete(dn)
            except Exception:
                pass

            dnspass = samba.generate_random_password(128, 255)
            setup_add_ldif(ldbs.sam, setup_path("provision_dns_add_samba.ldif"), {
                    "DNSDOMAIN": dnsdomain,
                    "DOMAINDN": domaindn,
                    "DNSPASS_B64": b64encode(dnspass.encode('utf-16-le')),
                    "HOSTNAME" : hostname,
                    "DNSNAME" : dnsname }
                           )

            secretsdb_setup_dns(ldbs.secrets, names,
                                paths.private_dir, realm=names.realm,
                                dnsdomain=names.dnsdomain,
                                dns_keytab_path=paths.dns_keytab, dnspass=dnspass)
        else:
            logger.info("dns-%s account already exists" % hostname)

        # This forces a re-creation of dns directory and all the files within
        # It's an overkill, but it's easier to re-create a samdb copy, rather
        # than trying to fix a broken copy.
        create_dns_dir(logger, paths)

        # Setup a copy of SAM for BIND9
        create_samdb_copy(ldbs.sam, logger, paths, names, domainsid,
                          domainguid)

        create_named_conf(paths, names.realm, dnsdomain, opts.dns_backend)

        create_named_txt(paths.namedtxt, names.realm, dnsdomain, dnsname,
                         paths.private_dir, paths.dns_keytab)
        logger.info("See %s for an example configuration include file for BIND", paths.namedconf)
        logger.info("and %s for further documentation required for secure DNS "
                    "updates", paths.namedtxt)
    elif opts.dns_backend == "SAMBA_INTERNAL":
        # Check if dns-HOSTNAME account exists and delete it if required
        try:
            dn_str = 'samAccountName=dns-%s,CN=Principals' % hostname
            msg = ldbs.secrets.search(expression='(dn=%s)' % dn_str, attrs=['secret'])
            dn = msg[0].dn
        except Exception:
            dn = None

        if dn is not None:
            try:
                ldbs.secrets.delete(dn)
            except Exception:
                logger.info("Failed to delete %s from secrets.ldb" % dn)

        try:
            msg = ldbs.sam.search(base=domaindn, scope=ldb.SCOPE_DEFAULT,
                                  expression='(sAMAccountName=dns-%s)' % (hostname),
                                  attrs=['clearTextPassword'])
            dn = msg[0].dn
        except Exception:
            dn = None

        if dn is not None:
            try:
                ldbs.sam.delete(dn)
            except Exception:
                logger.info("Failed to delete %s from sam.ldb" % dn)

    logger.info("Finished upgrading DNS")