#!/usr/bin/env python
#
# Unix SMB/CIFS implementation.
# A script to compare differences of objects and attributes between
# two LDAP servers both running at the same time. It generally compares
# one of the three partitions DOMAIN, CONFIGURATION or SCHEMA. The users
# supplied must be able to read objects in any of the
# above partitions.

# Copyright (C) Zahari Zahariev <zahari.zahariev@postpath.com> 2009
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#

import os
import re
import sys
from optparse import OptionParser

sys.path.insert(0, "bin/python")

import samba
import samba.getopt as options
from samba import Ldb
from samba.ndr import ndr_pack, ndr_unpack
from samba.dcerpc import security
from ldb import SCOPE_SUBTREE, SCOPE_ONELEVEL, SCOPE_BASE, ERR_NO_SUCH_OBJECT, LdbError

# Module-level summary dict. The LDAPBundel instances keep their own
# per-bundle summaries; this one is kept only for compatibility.
# (A `global` statement at module level is a no-op, so it is not used here.)
summary = {}

class LDAPBase(object):
    """A connection to one LDAP server plus the domain facts ldapcmp needs.

    On construction it binds to the server and caches the default naming
    context (``base_dn``), the NetBIOS domain name, the names of the domain
    controllers, the dotted domain name derived from the base DN and the
    binary objectSid of the domain object.
    """

    def __init__(self, host, creds, lp):
        """Connect to *host* and read the domain information.

        :param host: host name/IP (``ldap://<host>:389`` is used) or a URL
            that already contains ``://`` (used unchanged).
        :param creds: credentials object passed to ``Ldb``.
        :param lp: loadparm object passed to ``Ldb``.
        """
        # A bare host gets the default LDAP scheme and port; a full URL is
        # used as-is (previously the URL case left the target unset).
        if "://" in host:
            url = host
        else:
            url = "ldap://" + host + ":389"
        self.ldb = Ldb(url, credentials=creds, lp=lp,
                       options=["modules:paged_searches"])
        # Keep the value the user supplied; it is what reports display.
        self.host = host
        self.base_dn = self.find_basedn()
        self.domain_netbios = self.find_netbios()
        self.server_names = self.find_servers()
        # e.g. "DC=samba,DC=org" -> "samba.org"
        self.domain_name = re.sub("[Dd][Cc]=", "", self.base_dn).replace(",", ".")
        self.domain_sid_bin = self.get_object_sid(self.base_dn)

    def find_servers(self):
        """Return the CNs of all computer objects under
        ``OU=Domain Controllers``; at least one must exist."""
        res = self.ldb.search(base="OU=Domain Controllers,%s" % self.base_dn,
                scope=SCOPE_SUBTREE, expression="(objectClass=computer)", attrs=["cn"])
        assert len(res) > 0
        return [x["cn"][0] for x in res]

    def find_netbios(self):
        """Return the first ``nETBIOSName`` found under CN=Partitions, or
        None when no entry carries one."""
        res = self.ldb.search(base="CN=Partitions,CN=Configuration,%s" % self.base_dn,
                scope=SCOPE_SUBTREE, attrs=["nETBIOSName"])
        assert len(res) > 0
        for x in res:
            if "nETBIOSName" in x.keys():
                return x["nETBIOSName"][0]
        return None

    def find_basedn(self):
        """Return ``defaultNamingContext`` read from the rootDSE."""
        res = self.ldb.search(base="", expression="(objectClass=*)", scope=SCOPE_BASE,
                attrs=["defaultNamingContext"])
        assert len(res) == 1
        return res[0]["defaultNamingContext"][0]

    def object_exists(self, object_dn):
        """Return True if exactly one object exists at *object_dn*.

        Only ERR_NO_SUCH_OBJECT means "does not exist"; any other LDAP error
        is propagated instead of being reported as a missing object.
        """
        try:
            res = self.ldb.search(base=object_dn, scope=SCOPE_BASE, expression="(objectClass=*)")
        except LdbError as e:
            if e.args[0] == ERR_NO_SUCH_OBJECT:
                return False
            raise
        return len(res) == 1

    def get_object_sid(self, object_dn):
        """Return the binary ``objectSid`` of *object_dn*.

        :raises Exception: when the object does not exist.
        """
        try:
            res = self.ldb.search(base=object_dn, expression="(objectClass=*)", scope=SCOPE_BASE, attrs=["objectSid"])
        except LdbError as e:
            if e.args[0] != ERR_NO_SUCH_OBJECT:
                raise
            raise Exception("DN sintax is wrong or object does't exist: " + object_dn)
        assert len(res) == 1
        return res[0]["objectSid"][0]

    def delete_force(self, object_dn):
        """Delete *object_dn*, ignoring the case where it does not exist."""
        try:
            self.ldb.delete(object_dn)
        except LdbError as e:
            # `Ldb.LdbError` (used previously) is not the ldb exception type;
            # check the error code instead of matching the message text.
            if e.args[0] != ERR_NO_SUCH_OBJECT:
                raise

    def get_attributes(self, object_dn):
        """Return a dict of all default-visible attributes of *object_dn*,
        mapping attribute name to a list of values (``dn`` is dropped)."""
        res = self.ldb.search(base=object_dn, scope=SCOPE_BASE, attrs=["*"])
        assert len(res) == 1
        res = dict(res[0])
        # 'dn' element is not iterable and we have it as 'distinguishedName'
        del res["dn"]
        for key in list(res.keys()):
            res[key] = list(res[key])
        return res

    def get_descriptor(self, object_dn):
        """Return the raw ``nTSecurityDescriptor`` value of *object_dn*."""
        res = self.ldb.search(base=object_dn, scope=SCOPE_BASE, attrs=["nTSecurityDescriptor"])
        assert len(res) == 1
        return res[0]["nTSecurityDescriptor"][0]


class LDAPObject(object):
    """One directory object read from one server, comparable with ``==``.

    ``obj1 == obj2`` compares attribute names and values of two objects
    (normally the same aliased DN read from two servers). It records the
    differences in the shared ``summary`` dicts and stores a textual report
    in ``screen_output`` on both objects. With ``cmd_opts.two`` set, the
    servers are assumed to be in different domains, so domain- and
    server-specific parts of values are masked before comparing.
    """

    def __init__(self, connection, dn, summary, cmd_opts):
        """Resolve the aliased *dn* and read its attributes.

        :param connection: ``LDAPBase``-like connection providing ``base_dn``,
            ``domain_netbios``, ``server_names``, ``domain_name``, ``host``
            and ``get_attributes()``.
        :param dn: DN that may contain the ``${DOMAIN_DN}``,
            ``CN=${DOMAIN_NETBIOS}`` and ``CN=${SERVERNAME}`` placeholders.
        :param summary: dict with ``unique_attrs`` and ``df_value_attrs``
            lists that comparisons append to.
        :param cmd_opts: parsed options providing ``two``, ``quiet`` and
            ``verbose``.
        """
        self.con = connection
        self.two_domains = cmd_opts.two
        self.quiet = cmd_opts.quiet
        self.verbose = cmd_opts.verbose
        self.summary = summary
        self.dn = dn.replace("${DOMAIN_DN}", self.con.base_dn)
        self.dn = self.dn.replace("CN=${DOMAIN_NETBIOS}", "CN=%s" % self.con.domain_netbios)
        for x in self.con.server_names:
            # After the first replacement the placeholder is gone, so the
            # first server name is the one that ends up in the DN.
            self.dn = self.dn.replace("CN=${SERVERNAME}", "CN=%s" % x)
        self.attributes = self.con.get_attributes(self.dn)
        # Attributes that are considered always to be different e.g based on timestamp etc.
        #
        # One domain - two domain controllers
        self.ignore_attributes =  [
                # Default Naming Context
                "lastLogon", "lastLogoff", "badPwdCount", "logonCount", "badPasswordTime", "modifiedCount",
                "operatingSystemVersion","oEMInformation",
                # Configuration Naming Context
                "repsFrom", "dSCorePropagationData", "msExchServer1HighestUSN",
                "replUpToDateVector", "repsTo", "whenChanged", "uSNChanged", "uSNCreated",
                # Schema Naming Context
                "prefixMap",]
        self.dn_attributes = []
        self.domain_attributes = []
        self.servername_attributes = []
        self.netbios_attributes = []
        self.other_attributes = []
        # Two domains - two domain controllers

        if self.two_domains:
            self.ignore_attributes +=  [
                "objectCategory", "objectGUID", "objectSid", "whenCreated", "pwdLastSet", "uSNCreated", "creationTime",
                "modifiedCount", "priorSetTime", "rIDManagerReference", "gPLink", "ipsecNFAReference",
                "fRSPrimaryMember", "fSMORoleOwner", "masteredBy", "ipsecOwnersReference", "wellKnownObjects",
                "badPwdCount", "ipsecISAKMPReference", "ipsecFilterReference", "msDs-masteredBy", "lastSetTime",
                "ipsecNegotiationPolicyReference", "subRefs", "gPCFileSysPath", "accountExpires", "invocationId",
                # After Exchange preps
                "targetAddress", "msExchMailboxGuid", "siteFolderGUID"]
            #
            # Attributes that contain the unique DN tail part e.g. 'DC=samba,DC=org'
            self.dn_attributes = [
                "distinguishedName", "defaultObjectCategory", "member", "memberOf", "siteList", "nCName",
                "homeMDB", "homeMTA", "interSiteTopologyGenerator", "serverReference",
                "msDS-HasInstantiatedNCs", "hasMasterNCs", "msDS-hasMasterNCs", "msDS-HasDomainNCs", "dMDLocation",
                "msDS-IsDomainFor", "rIDSetReferences", "serverReferenceBL",
                # After Exchange preps
                "msExchHomeRoutingGroup", "msExchResponsibleMTAServer", "siteFolderServer", "msExchRoutingMasterDN",
                "msExchRoutingGroupMembersBL", "homeMDBBL", "msExchHomePublicMDB", "msExchOwningServer", "templateRoots",
                "addressBookRoots", "msExchPolicyRoots", "globalAddressList", "msExchOwningPFTree",
                "msExchResponsibleMTAServerBL", "msExchOwningPFTreeBL",]
            self.dn_attributes = [x.upper() for x in self.dn_attributes]
            #
            # Attributes that contain the Domain name e.g. 'samba.org'
            self.domain_attributes = [
                "proxyAddresses", "mail", "userPrincipalName", "msExchSmtpFullyQualifiedDomainName",
                "dnsHostName", "networkAddress", "dnsRoot", "servicePrincipalName",]
            self.domain_attributes = [x.upper() for x in self.domain_attributes]
            #
            # May contain DOMAIN_NETBIOS and SERVERNAME
            self.servername_attributes = [ "distinguishedName", "name", "CN", "sAMAccountName", "dNSHostName",
                "servicePrincipalName", "rIDSetReferences", "serverReference", "serverReferenceBL",
                "msDS-IsDomainFor", "interSiteTopologyGenerator",]
            self.servername_attributes = [x.upper() for x in self.servername_attributes]
            #
            self.netbios_attributes = [ "servicePrincipalName", "CN", "distinguishedName", "nETBIOSName", "name",]
            self.netbios_attributes = [x.upper() for x in self.netbios_attributes]
            #
            # May be equal to the first label of the domain name, e.g. 'samba'
            self.other_attributes = [ "name", "DC",]
            self.other_attributes = [x.upper() for x in self.other_attributes]
        #
        self.ignore_attributes = [x.upper() for x in self.ignore_attributes]

    def log(self, msg):
        """
        Log on the screen if there is no --quiet option set
        """
        if not self.quiet:
            print(msg)

    def fix_dn(self, s):
        """Replace a trailing base DN (case-insensitive) with ``${DOMAIN_DN}``."""
        res = "%s" % s
        if res.upper().endswith(self.con.base_dn.upper()):
            res = res[:len(res)-len(self.con.base_dn)] + "${DOMAIN_DN}"
        return res

    def fix_domain_name(self, s):
        """Replace the dotted domain name (all-lower or all-upper) with
        ``${DOMAIN_NAME}``."""
        res = "%s" % s
        res = res.replace(self.con.domain_name.lower(), self.con.domain_name.upper())
        res = res.replace(self.con.domain_name.upper(), "${DOMAIN_NAME}")
        return res

    def fix_domain_netbios(self, s):
        """Replace the NetBIOS domain name (all-lower or all-upper) with
        ``${DOMAIN_NETBIOS}``."""
        res = "%s" % s
        res = res.replace(self.con.domain_netbios.lower(), self.con.domain_netbios.upper())
        res = res.replace(self.con.domain_netbios.upper(), "${DOMAIN_NETBIOS}")
        return res

    def fix_server_name(self, s):
        """Upper-case *s* and replace every DC name with ``${SERVERNAME}``.

        The server name is upper-cased too; previously a server name with
        lower-case letters could never match the upper-cased value.
        """
        res = "%s" % s
        for x in self.con.server_names:
            res = res.upper().replace(x.upper(), "${SERVERNAME}")
        return res

    def _mask_domain_label(self, s):
        """Replace a value equal to the first label of the domain name
        (e.g. 'samba' for 'samba.org') with ``${DOMAIN_LABEL}``."""
        res = "%s" % s
        if res == self.con.domain_name.split(".")[0]:
            return "${DOMAIN_LABEL}"
        return res

    def _find_unique_attrs(self, other):
        """Return ``(report_text, names)`` for attributes present on self but
        not on *other*, excluding those *other* ignores."""
        ignored = set(name.upper() for name in other.ignore_attributes)
        names = [x for x in self.attributes.keys()
                 if x not in other.attributes and x.upper() not in ignored]
        if not names:
            return "", names
        lines = [4*" " + "Attributes found only in %s:" % self.con.host]
        lines += [8*" " + x for x in names]
        return "\n".join(lines) + "\n", names

    def _normalize_values(self, other, attr):
        """Apply the two-domain fix-ups relevant to *attr* to both value lists.

        Each applicable stage works on the output of the previous one and
        processing stops as soon as both sides compare equal. Returns
        ``(p, q)``, the normalized values of self and *other*, or
        ``(None, None)`` when no stage applies. (``other_attributes`` and
        ``dn_attributes`` are disjoint, so at most one of the first two
        stages runs.)
        """
        key = attr.upper()
        p = q = None
        stages = (
            (self.other_attributes, "_mask_domain_label"),
            (self.dn_attributes, "fix_dn"),
            (self.domain_attributes, "fix_domain_name"),
            (self.servername_attributes, "fix_server_name"),
            (self.netbios_attributes, "fix_domain_netbios"),
        )
        for names, fixer in stages:
            if key not in names:
                continue
            if p or q:
                src1, src2 = p, q
            else:
                src1, src2 = self.attributes[attr], other.attributes[attr]
            # Each side is normalized with its own connection's data.
            p = [getattr(self, fixer)(v) for v in src1]
            q = [getattr(other, fixer)(v) for v in src2]
            if p == q:
                break
        return p, q

    def __eq__(self, other):
        """Compare attributes with *other*; return True when nothing differs.

        Side effects: sets ``unique_attrs`` and ``df_value_attrs`` on self,
        ``unique_attrs`` on *other*, appends to both ``summary`` dicts,
        stores the report in ``screen_output`` on both objects and sorts
        the compared value lists in both ``attributes`` dicts.
        """
        res = ""
        self.df_value_attrs = []
        text, self.unique_attrs = self._find_unique_attrs(other)
        res += text
        text, other.unique_attrs = other._find_unique_attrs(self)
        res += text
        #
        missing_attrs = set(x.upper() for x in self.unique_attrs + other.unique_attrs)
        ignored = set(self.ignore_attributes)
        title = 4*" " + "Difference in attribute values:"
        for x in self.attributes.keys():
            if x.upper() in ignored or x.upper() in missing_attrs:
                continue
            if isinstance(self.attributes[x], list) and isinstance(other.attributes[x], list):
                self.attributes[x] = sorted(self.attributes[x])
                other.attributes[x] = sorted(other.attributes[x])
            if self.attributes[x] == other.attributes[x]:
                continue
            p, q = self._normalize_values(other, x)
            if p is not None and p == q:
                continue
            if title:
                res += title + "\n"
                title = None
            if p and q:
                res += 8*" " + x + " => \n%s\n%s" % (p, q) + "\n"
            else:
                res += 8*" " + x + " => \n%s\n%s" % (self.attributes[x], other.attributes[x]) + "\n"
            self.df_value_attrs.append(x)
        #
        if self.unique_attrs + other.unique_attrs != []:
            assert self.unique_attrs != other.unique_attrs
        self.summary["unique_attrs"] += self.unique_attrs
        self.summary["df_value_attrs"] += self.df_value_attrs
        other.summary["unique_attrs"] += other.unique_attrs
        other.summary["df_value_attrs"] += self.df_value_attrs # they are the same
        #
        self.screen_output = res[:-1]
        other.screen_output = res[:-1]
        #
        return res == ""


class LDAPBundel(object):
    """The DNs of one naming context on one server, comparable with ``==``.

    DNs are stored with their server-specific parts replaced by the
    ``${DOMAIN_DN}``, ``CN=${DOMAIN_NETBIOS}`` and ``CN=${SERVERNAME}``
    placeholders, so lists from two servers (possibly in two domains) can be
    matched. ``bundle1 == bundle2`` logs the DNs present on only one side and
    then compares every common object with ``LDAPObject``.
    """

    def __init__(self, connection, context, cmd_opts, dn_list=None):
        """
        :param connection: ``LDAPBase``-like connection.
        :param context: "DOMAIN", "CONFIGURATION" or "SCHEMA" (any case);
            only used when *dn_list* is empty or not given.
        :param cmd_opts: parsed options providing ``two``, ``quiet`` and
            ``verbose``.
        :param dn_list: optional list of full DNs below the base DN. It is
            copied; the caller's list is not modified.
        :raises Exception: when neither *dn_list* nor a known *context* is
            supplied.
        """
        self.con = connection
        self.cmd_opts = cmd_opts
        self.two_domains = cmd_opts.two
        self.quiet = cmd_opts.quiet
        self.verbose = cmd_opts.verbose
        self.summary = {
            "unique_attrs": [],
            "df_value_attrs": [],
            "known_ignored_dn": [],
            "abnormal_ignored_dn": [],
        }
        if dn_list:
            dns = list(dn_list)
        elif context.upper() in ["DOMAIN", "CONFIGURATION", "SCHEMA"]:
            self.context = context.upper()
            dns = self.get_dn_list(context)
        else:
            raise Exception("Unknown initialization data for LDAPBundel().")
        # Aliasing can make several DNs identical (e.g. one per DC); keep one.
        self.dn_list = self._sorted_dns(set(self._alias_dn(x) for x in dns))
        self.size = len(self.dn_list)

    def _alias_dn(self, dn):
        """Replace the server-specific parts of *dn* with placeholders."""
        # DNs come from a subtree search below the base DN, so the base DN
        # is assumed to be the tail and is cut off by length.
        tmp = dn[:len(dn)-len(self.con.base_dn)] + "${DOMAIN_DN}"
        tmp = tmp.replace("CN=%s" % self.con.domain_netbios, "CN=${DOMAIN_NETBIOS}")
        for x in self.con.server_names:
            tmp = tmp.replace("CN=%s" % x, "CN=${SERVERNAME}")
        return tmp

    @staticmethod
    def _sorted_dns(dns):
        """Sort DNs case-insensitively (ties broken by the exact DN).

        DNs are matched between servers case-insensitively, so both lists
        must be ordered the same way for index-wise pairing to be correct.
        """
        return sorted(dns, key=lambda dn: (dn.upper(), dn))

    def log(self, msg):
        """
        Log on the screen if there is no --quiet option set
        """
        if not self.quiet:
            print(msg)

    def update_size(self):
        """Refresh ``size`` and re-sort ``dn_list``."""
        self.size = len(self.dn_list)
        self.dn_list = self._sorted_dns(self.dn_list)

    def _drop_unmatched(self, dns, other_dns, host):
        """Log and drop the DNs in *dns* missing from *other_dns*
        (case-insensitive); return the rest in their original order."""
        wanted = set(dn.upper() for dn in other_dns)
        kept = []
        title = "\n* DNs found only in %s:" % host
        for dn in dns:
            if dn.upper() in wanted:
                kept.append(dn)
                continue
            if title:
                self.log( title )
                title = None
            self.log( 4*" " + dn )
        return kept

    def _load_object(self, connection, dn, summary):
        """Return an ``LDAPObject`` for *dn*, or None (logged) when the
        object does not exist. Other LDAP errors are propagated."""
        try:
            return LDAPObject(connection=connection,
                    dn=dn,
                    summary=summary,
                    cmd_opts=self.cmd_opts)
        except LdbError as e:
            if e.args[0] != ERR_NO_SUCH_OBJECT:
                raise
            self.log( "\n!!! Object not found: %s" % dn )
            return None

    def __eq__(self, other):
        """Compare DN lists and then every common object with *other*.

        Side effects: both ``dn_list``s are reduced to their common DNs;
        object-level differences go into the ``summary`` dicts.
        Objects that cannot be found on either side are logged and skipped.
        """
        res = True
        if self.size != other.size:
            self.log( "\n* Lists have different size: %s != %s" % (self.size, other.size) )
            res = False
        #
        self.dn_list = self._drop_unmatched(self.dn_list, other.dn_list, self.con.host)
        other.dn_list = self._drop_unmatched(other.dn_list, self.dn_list, other.con.host)
        #
        self.update_size()
        other.update_size()
        assert self.size == other.size
        assert sorted([x.upper() for x in self.dn_list]) == sorted([x.upper() for x in other.dn_list])
        self.log( "\n* Objets to be compared: %s" % self.size )

        for dn1, dn2 in zip(self.dn_list, other.dn_list):
            # Load both so that a missing object on either side is reported.
            object1 = self._load_object(self.con, dn1, self.summary)
            object2 = self._load_object(other.con, dn2, other.summary)
            if object1 is None or object2 is None:
                continue
            if object1 == object2:
                if self.verbose:
                    self.log( "\nComparing:" )
                    self.log( "'%s' [%s]" % (object1.dn, object1.con.host) )
                    self.log( "'%s' [%s]" % (object2.dn, object2.con.host) )
                    self.log( 4*" " + "OK" )
            else:
                self.log( "\nComparing:" )
                self.log( "'%s' [%s]" % (object1.dn, object1.con.host) )
                self.log( "'%s' [%s]" % (object2.dn, object2.con.host) )
                self.log( object1.screen_output )
                self.log( 4*" " + "FAILED" )
                res = False
        #
        return res

    @staticmethod
    def _is_regular_dn(dn):
        """Return True when *dn* consists of CN=, then OU=, then DC=
        components (in that order) and ends with a ``DC=<word>`` component.

        Anything else (mixed order, escaped commas, other RDN types) is
        considered 'strange' and is ignored by the comparison.
        """
        tail = re.search(r"([Dd][Cc]=[\w]+$)", dn)
        if tail is None:
            return False
        rebuilt = "".join(re.findall(r"[Cc][Nn]=.*?,", dn)) \
                + "".join(re.findall(r"[Oo][Uu]=.*?,", dn)) \
                + "".join(re.findall(r"[Dd][Cc]=.*?,", dn)) + tail.group()
        return rebuilt == dn

    def get_dn_list(self, context):
        """ Query the LDAP server for all DNs of a naming context: Domain
            (or Default), Configuration or Schema.
            DNs that are 'strange' or abnormal are logged and filtered out.

            :raises Exception: for an unknown *context*.
        """
        context = context.upper()
        if context == "DOMAIN":
            search_base = "%s" % self.con.base_dn
        elif context == "CONFIGURATION":
            search_base = "CN=Configuration,%s" % self.con.base_dn
        elif context == "SCHEMA":
            search_base = "CN=Schema,CN=Configuration,%s" % self.con.base_dn
        else:
            raise Exception("Unknown naming context: %s" % context)

        res = self.con.ldb.search(base=search_base, scope=SCOPE_SUBTREE, attrs=["dn"])
        dn_list = [x["dn"].get_linearized() for x in res]
        #
        kept = []
        title = "\n* Ignored (DNS related) DNs in %s:" % self.con.host
        for dn in dn_list:
            if self._is_regular_dn(dn):
                kept.append(dn)
                continue
            if title:
                self.log( title )
                title = None
            self.log( 4*" " + dn )
        return kept

    def print_summary(self):
        """Log the de-duplicated attribute names collected in ``summary``
        and clear the ``df_value_attrs`` list afterwards."""
        self.summary["unique_attrs"] = list(set(self.summary["unique_attrs"]))
        self.summary["df_value_attrs"] = list(set(self.summary["df_value_attrs"]))
        #
        if self.summary["unique_attrs"]:
            self.log( "\nAttributes found only in %s:" % self.con.host )
            self.log( "".join([str("\n" + 4*" " + x) for x in self.summary["unique_attrs"]]) )
        #
        if self.summary["df_value_attrs"]:
            self.log( "\nAttributes with different values:" )
            self.log( "".join([str("\n" + 4*" " + x) for x in self.summary["df_value_attrs"]]) )
            self.summary["df_value_attrs"] = []

###

if __name__ == "__main__":
    parser = OptionParser("ldapcmp [options] domain|configuration|schema")
    sambaopts = options.SambaOptions(parser)
    # The Samba option group must be registered or its options are ignored.
    parser.add_option_group(sambaopts)
    credopts = options.CredentialsOptionsDouble(parser)
    parser.add_option_group(credopts)

    parser.add_option("", "--host", dest="host",
                              help="IP of the first LDAP server",)
    parser.add_option("", "--host2", dest="host2",
                              help="IP of the second LDAP server",)
    parser.add_option("-w", "--two", dest="two", action="store_true", default=False,
                              help="Hosts are in two different domains",)
    parser.add_option("-q", "--quiet", dest="quiet", action="store_true", default=False,
                              help="Do not print anything but relay on just exit code",)
    parser.add_option("-v", "--verbose", dest="verbose", action="store_true", default=False,
                              help="Print all DN pairs that have been compared",)
    # Named `opts` so the samba.getopt module imported as `options` is not shadowed.
    (opts, args) = parser.parse_args()

    if not (len(args) == 1 and args[0].upper() in ["DOMAIN", "CONFIGURATION", "SCHEMA"]):
        parser.error("Incorrect arguments")

    if opts.verbose and opts.quiet:
        parser.error("You cannot set --verbose and --quiet together")

    if not opts.host or not opts.host2:
        parser.error("You must specify both --host and --host2")

    # Load the configuration and credentials only after parsing, so that
    # command-line Samba/credential options are taken into account.
    lp = sambaopts.get_loadparm()
    creds = credopts.get_credentials(lp)
    creds2 = credopts.get_credentials2(lp)

    con1 = LDAPBase(opts.host, creds, lp)
    assert len(con1.base_dn) > 0

    con2 = LDAPBase(opts.host2, creds2, lp)
    assert len(con2.base_dn) > 0

    b1 = LDAPBundel(con1, context=args[0], cmd_opts=opts)
    b2 = LDAPBundel(con2, context=args[0], cmd_opts=opts)

    if b1 == b2:
        if not opts.quiet:
            print("\n* Final result: SUCCESS")
        status = 0
    else:
        if not opts.quiet:
            print("\n* Final result: FAILURE")
            print("\nSUMMARY")
            print("---------")
        status = -1

    assert len(b1.summary["df_value_attrs"]) == len(b2.summary["df_value_attrs"])
    # Both bundles record the same value differences; print them only once.
    b2.summary["df_value_attrs"] = []

    if not opts.quiet:
        b1.print_summary()
        b2.print_summary()

    sys.exit(status)