From f13895851fde13cefaa484cd9f90a9f0ad41d78e Mon Sep 17 00:00:00 2001 From: Jelmer Vernooij Date: Wed, 11 Feb 2009 17:54:58 +0100 Subject: Cancel transactions when exceptions are raised. --- source4/scripting/python/samba/samdb.py | 175 ++++++++++++++------------ source4/scripting/python/samba/tests/samdb.py | 1 + 2 files changed, 93 insertions(+), 83 deletions(-) (limited to 'source4') diff --git a/source4/scripting/python/samba/samdb.py b/source4/scripting/python/samba/samdb.py index 92b0bd7b89..947c46079f 100644 --- a/source4/scripting/python/samba/samdb.py +++ b/source4/scripting/python/samba/samdb.py @@ -104,41 +104,43 @@ userAccountControl: %u """ # connect to the sam self.transaction_start() - - domain_dn = self.domain_dn() - assert(domain_dn is not None) - user_dn = "CN=%s,CN=Users,%s" % (username, domain_dn) - - # - # the new user record. note the reliance on the samdb module to fill - # in a sid, guid etc - # - # now the real work - self.add({"dn": user_dn, - "sAMAccountName": username, - "userPassword": password, - "objectClass": "user"}) - - res = self.search(user_dn, scope=ldb.SCOPE_BASE, - expression="objectclass=*", - attrs=["objectSid"]) - assert(len(res) == 1) - user_sid = self.schema_format_value("objectSid", res[0]["objectSid"][0]) - - try: - idmap = IDmapDB(lp=self.lp) - - user = pwd.getpwnam(unixname) - # setup ID mapping for this UID + domain_dn = self.domain_dn() + assert(domain_dn is not None) + user_dn = "CN=%s,CN=Users,%s" % (username, domain_dn) + + # + # the new user record. note the reliance on the samdb module to + # fill in a sid, guid etc + # + # now the real work + self.add({"dn": user_dn, + "sAMAccountName": username, + "userPassword": password, + "objectClass": "user"}) + + res = self.search(user_dn, scope=ldb.SCOPE_BASE, + expression="objectclass=*", + attrs=["objectSid"]) + assert len(res) == 1 + user_sid = self.schema_format_value("objectSid", res[0]["objectSid"][0]) - idmap.setup_name_mapping(user_sid, idmap.TYPE_UID, user[2]) - - except KeyError: - pass - - # modify the userAccountControl to remove the disabled bit - self.enable_account(user_dn) + try: + idmap = IDmapDB(lp=self.lp) + + user = pwd.getpwnam(unixname) + # setup ID mapping for this UID + + idmap.setup_name_mapping(user_sid, idmap.TYPE_UID, user[2]) + + except KeyError: + pass + + # modify the userAccountControl to remove the disabled bit + self.enable_account(user_dn) + except: + self.transaction_cancel() + raise self.transaction_commit() def setpassword(self, filter, password): @@ -149,32 +151,35 @@ userAccountControl: %u """ # connect to the sam self.transaction_start() - - # find the DNs for the domain - res = self.search("", scope=ldb.SCOPE_BASE, - expression="(defaultNamingContext=*)", - attrs=["defaultNamingContext"]) - assert(len(res) == 1 and res[0]["defaultNamingContext"] is not None) - domain_dn = res[0]["defaultNamingContext"][0] - assert(domain_dn is not None) - - res = self.search(domain_dn, scope=ldb.SCOPE_SUBTREE, - expression=filter, - attrs=[]) - assert(len(res) == 1) - user_dn = res[0].dn - - setpw = """ -dn: %s -changetype: modify -replace: userPassword -userPassword: %s -""" % (user_dn, password) - - self.modify_ldif(setpw) - - # modify the userAccountControl to remove the disabled bit - self.enable_account(user_dn) + try: + # find the DNs for the domain + res = self.search("", scope=ldb.SCOPE_BASE, + expression="(defaultNamingContext=*)", + attrs=["defaultNamingContext"]) + assert(len(res) == 1 and res[0]["defaultNamingContext"] is not None) + domain_dn = res[0]["defaultNamingContext"][0] + assert(domain_dn is not None) + + res = self.search(domain_dn, scope=ldb.SCOPE_SUBTREE, + expression=filter, + attrs=[]) + assert(len(res) == 1) + user_dn = res[0].dn + + setpw = """ + dn: %s + changetype: modify + replace: userPassword + userPassword: %s + """ % (user_dn, password) + + self.modify_ldif(setpw) + + # modify the userAccountControl to remove the disabled bit + self.enable_account(user_dn) + except: + self.transaction_cancel() + raise self.transaction_commit() def set_domain_sid(self, sid): @@ -200,28 +205,32 @@ userPassword: %s :param expiry_seconds: expiry time from now in seconds :param noexpiry: if set, then don't expire password """ - self.transaction_start(); - res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE, - expression=("(samAccountName=%s)" % user), - attrs=["userAccountControl", "accountExpires"]) - assert len(res) == 1 - userAccountControl = int(res[0]["userAccountControl"][0]) - accountExpires = int(res[0]["accountExpires"][0]) - if noexpiry: - userAccountControl = userAccountControl | 0x10000 - accountExpires = 0 - else: - userAccountControl = userAccountControl & ~0x10000 - accountExpires = glue.unix2nttime(expiry_seconds + int(time.time())) - - mod = """ -dn: %s -changetype: modify -replace: userAccountControl -userAccountControl: %u -replace: accountExpires -accountExpires: %u -""" % (res[0].dn, userAccountControl, accountExpires) - # now change the database - self.modify_ldif(mod) + self.transaction_start() + try: + res = self.search(base=self.domain_dn(), scope=ldb.SCOPE_SUBTREE, + expression=("(samAccountName=%s)" % user), + attrs=["userAccountControl", "accountExpires"]) + assert len(res) == 1 + userAccountControl = int(res[0]["userAccountControl"][0]) + accountExpires = int(res[0]["accountExpires"][0]) + if noexpiry: + userAccountControl = userAccountControl | 0x10000 + accountExpires = 0 + else: + userAccountControl = userAccountControl & ~0x10000 + accountExpires = glue.unix2nttime(expiry_seconds + int(time.time())) + + mod = """ + dn: %s + changetype: modify + replace: userAccountControl + userAccountControl: %u + replace: accountExpires + accountExpires: %u + """ % (res[0].dn, userAccountControl, accountExpires) + # now change the database + self.modify_ldif(mod) + except: + self.transaction_cancel() + raise self.transaction_commit(); diff --git a/source4/scripting/python/samba/tests/samdb.py b/source4/scripting/python/samba/tests/samdb.py index cce6ea84d3..161f9f4f65 100644 --- a/source4/scripting/python/samba/tests/samdb.py +++ b/source4/scripting/python/samba/tests/samdb.py @@ -28,6 +28,7 @@ import uuid from samba import param class SamDBTestCase(TestCaseInTempDir): + def setUp(self): super(SamDBTestCase, self).setUp() invocationid = str(uuid.uuid4()) -- cgit