summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJelmer Vernooij <jelmer@samba.org>2009-09-14 17:03:30 +0200
committerJelmer Vernooij <jelmer@samba.org>2009-09-14 17:03:30 +0200
commit667b825d183f6b438b2329aef32686c20e55b7d3 (patch)
treeb15042402c3d577b8286ce2737504513d48d0f46
parentd106e728fb0c59900c289055c97f424e4f5d3c75 (diff)
downloadsamba-667b825d183f6b438b2329aef32686c20e55b7d3.tar.gz
samba-667b825d183f6b438b2329aef32686c20e55b7d3.tar.bz2
samba-667b825d183f6b438b2329aef32686c20e55b7d3.zip
pyldb: Don't segfault when invalid type is specified to Dn.get().
(#6722)
-rw-r--r--source4/lib/ldb/pyldb.c12
-rwxr-xr-xsource4/lib/ldb/tests/python/api.py4
2 files changed, 14 insertions, 2 deletions
diff --git a/source4/lib/ldb/pyldb.c b/source4/lib/ldb/pyldb.c
index 3f7fa2f395..b4f03dc538 100644
--- a/source4/lib/ldb/pyldb.c
+++ b/source4/lib/ldb/pyldb.c
@@ -1758,8 +1758,13 @@ static PyObject *py_ldb_msg_keys(PyLdbMessageObject *self)
static PyObject *py_ldb_msg_getitem_helper(PyLdbMessageObject *self, PyObject *py_name)
{
struct ldb_message_element *el;
- char *name = PyString_AsString(py_name);
+ char *name;
struct ldb_message *msg = PyLdbMessage_AsMessage(self);
+ if (!PyString_Check(py_name)) {
+ PyErr_SetNone(PyExc_TypeError);
+ return NULL;
+ }
+ name = PyString_AsString(py_name);
if (!strcmp(name, "dn"))
return PyLdbDn_FromDn(msg->dn);
el = ldb_msg_find_element(msg, name);
@@ -1786,8 +1791,11 @@ static PyObject *py_ldb_msg_get(PyLdbMessageObject *self, PyObject *args)
return NULL;
ret = py_ldb_msg_getitem_helper(self, name);
- if (ret == NULL)
+ if (ret == NULL) {
+ if (PyErr_Occurred())
+ return NULL;
Py_RETURN_NONE;
+ }
return ret;
}
diff --git a/source4/lib/ldb/tests/python/api.py b/source4/lib/ldb/tests/python/api.py
index 88983ac738..133bd180c1 100755
--- a/source4/lib/ldb/tests/python/api.py
+++ b/source4/lib/ldb/tests/python/api.py
@@ -480,6 +480,10 @@ class LdbMsgTests(unittest.TestCase):
self.msg.dn = ldb.Dn(ldb.Ldb("foo.tdb"), "@BASEINFO")
self.assertEquals("@BASEINFO", self.msg.get("dn").__str__())
+ def test_get_invalid(self):
+ self.msg.dn = ldb.Dn(ldb.Ldb("foo.tdb"), "@BASEINFO")
+ self.assertRaises(TypeError, self.msg.get, 42)
+
def test_get_other(self):
self.msg["foo"] = ["bar"]
self.assertEquals("bar", self.msg.get("foo")[0])