From 2f27d0c762c8f5be416ed38e00150a8ba58e63ad Mon Sep 17 00:00:00 2001 From: Jelmer Vernooij Date: Wed, 17 Jun 2009 18:25:21 +0200 Subject: pyldb: Support getting the parent of special DNs without segfaulting. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Found by: Андрей Григорьев --- source4/lib/ldb/pyldb.c | 10 +++++++++- source4/lib/ldb/tests/python/api.py | 11 +++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) (limited to 'source4/lib/ldb') diff --git a/source4/lib/ldb/pyldb.c b/source4/lib/ldb/pyldb.c index 52d8530439..ab2a1215b8 100644 --- a/source4/lib/ldb/pyldb.c +++ b/source4/lib/ldb/pyldb.c @@ -206,7 +206,15 @@ static int py_ldb_dn_compare(PyLdbDnObject *dn1, PyLdbDnObject *dn2) static PyObject *py_ldb_dn_get_parent(PyLdbDnObject *self) { struct ldb_dn *dn = PyLdbDn_AsDn((PyObject *)self); - return PyLdbDn_FromDn(ldb_dn_get_parent(NULL, dn)); + struct ldb_dn *parent; + + parent = ldb_dn_get_parent(NULL, dn); + + if (parent == NULL) { + Py_RETURN_NONE; + } else { + return PyLdbDn_FromDn(parent); + } } #define dn_ldb_ctx(dn) ((struct ldb_context *)dn) diff --git a/source4/lib/ldb/tests/python/api.py b/source4/lib/ldb/tests/python/api.py index 07500e2372..177e2e9864 100755 --- a/source4/lib/ldb/tests/python/api.py +++ b/source4/lib/ldb/tests/python/api.py @@ -14,6 +14,7 @@ def filename(): return os.tempnam() class NoContextTests(unittest.TestCase): + def test_valid_attr_name(self): self.assertTrue(ldb.valid_attr_name("foo")) self.assertFalse(ldb.valid_attr_name("24foo")) @@ -28,6 +29,7 @@ class NoContextTests(unittest.TestCase): class SimpleLdb(unittest.TestCase): + def test_connect(self): ldb.Ldb(filename()) @@ -273,6 +275,7 @@ class SimpleLdb(unittest.TestCase): class DnTests(unittest.TestCase): + def setUp(self): self.ldb = ldb.Ldb(filename()) @@ -301,6 +304,10 @@ class DnTests(unittest.TestCase): x = ldb.Dn(self.ldb, "dc=foo,bar=bloe") self.assertEquals("bar=bloe", x.parent().__str__()) + def test_parent_nonexistant(self): + x = ldb.Dn(self.ldb, "@BLA") + self.assertEquals(None, x.parent()) + def test_compare(self): x = ldb.Dn(self.ldb, "dc=foo,bar=bloe") y = ldb.Dn(self.ldb, "dc=foo,bar=bloe") @@ -373,6 +380,7 @@ class DnTests(unittest.TestCase): class LdbMsgTests(unittest.TestCase): + def setUp(self): self.msg = ldb.Message() @@ -439,6 +447,7 @@ class LdbMsgTests(unittest.TestCase): class MessageElementTests(unittest.TestCase): + def test_cmp_element(self): x = ldb.MessageElement(["foo"]) y = ldb.MessageElement(["foo"]) @@ -479,6 +488,7 @@ class MessageElementTests(unittest.TestCase): class ModuleTests(unittest.TestCase): + def test_register_module(self): class ExampleModule: name = "example" @@ -505,6 +515,7 @@ class ModuleTests(unittest.TestCase): l = ldb.Ldb("usemodule.ldb") self.assertEquals(["init"], ops) + if __name__ == '__main__': import unittest unittest.TestProgram() -- cgit