diff options
author | Jelmer Vernooij <jelmer@samba.org> | 2009-09-14 17:03:30 +0200 |
---|---|---|
committer | Jelmer Vernooij <jelmer@samba.org> | 2009-09-14 17:03:30 +0200 |
commit | 667b825d183f6b438b2329aef32686c20e55b7d3 (patch) | |
tree | b15042402c3d577b8286ce2737504513d48d0f46 | |
parent | d106e728fb0c59900c289055c97f424e4f5d3c75 (diff) | |
download | samba-667b825d183f6b438b2329aef32686c20e55b7d3.tar.gz samba-667b825d183f6b438b2329aef32686c20e55b7d3.tar.bz2 samba-667b825d183f6b438b2329aef32686c20e55b7d3.zip |
pyldb: Don't segfault when invalid type is specified to Dn.get().
(#6722)
-rw-r--r-- | source4/lib/ldb/pyldb.c | 12 | ||||
-rwxr-xr-x | source4/lib/ldb/tests/python/api.py | 4 |
2 files changed, 14 insertions, 2 deletions
diff --git a/source4/lib/ldb/pyldb.c b/source4/lib/ldb/pyldb.c index 3f7fa2f395..b4f03dc538 100644 --- a/source4/lib/ldb/pyldb.c +++ b/source4/lib/ldb/pyldb.c @@ -1758,8 +1758,13 @@ static PyObject *py_ldb_msg_keys(PyLdbMessageObject *self) static PyObject *py_ldb_msg_getitem_helper(PyLdbMessageObject *self, PyObject *py_name) { struct ldb_message_element *el; - char *name = PyString_AsString(py_name); + char *name; struct ldb_message *msg = PyLdbMessage_AsMessage(self); + if (!PyString_Check(py_name)) { + PyErr_SetNone(PyExc_TypeError); + return NULL; + } + name = PyString_AsString(py_name); if (!strcmp(name, "dn")) return PyLdbDn_FromDn(msg->dn); el = ldb_msg_find_element(msg, name); @@ -1786,8 +1791,11 @@ static PyObject *py_ldb_msg_get(PyLdbMessageObject *self, PyObject *args) return NULL; ret = py_ldb_msg_getitem_helper(self, name); - if (ret == NULL) + if (ret == NULL) { + if (PyErr_Occurred()) + return NULL; Py_RETURN_NONE; + } return ret; } diff --git a/source4/lib/ldb/tests/python/api.py b/source4/lib/ldb/tests/python/api.py index 88983ac738..133bd180c1 100755 --- a/source4/lib/ldb/tests/python/api.py +++ b/source4/lib/ldb/tests/python/api.py @@ -480,6 +480,10 @@ class LdbMsgTests(unittest.TestCase): self.msg.dn = ldb.Dn(ldb.Ldb("foo.tdb"), "@BASEINFO") self.assertEquals("@BASEINFO", self.msg.get("dn").__str__()) + def test_get_invalid(self): + self.msg.dn = ldb.Dn(ldb.Ldb("foo.tdb"), "@BASEINFO") + self.assertRaises(TypeError, self.msg.get, 42) + def test_get_other(self): self.msg["foo"] = ["bar"] self.assertEquals("bar", self.msg.get("foo")[0]) |