summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--source3/smbd/connection.c88
-rw-r--r--source3/tdb/tdb.c14
2 files changed, 91 insertions, 11 deletions
diff --git a/source3/smbd/connection.c b/source3/smbd/connection.c
index 43e89abfbf..6c401264d9 100644
--- a/source3/smbd/connection.c
+++ b/source3/smbd/connection.c
@@ -37,8 +37,9 @@ TDB_CONTEXT *conn_tdb_ctx(void)
}
/****************************************************************************
-delete a connection record
+ Delete a connection record.
****************************************************************************/
+
BOOL yield_connection(connection_struct *conn,char *name,int max_connections)
{
struct connections_key key;
@@ -62,21 +63,92 @@ BOOL yield_connection(connection_struct *conn,char *name,int max_connections)
return(True);
}
+struct count_stat {
+ pid_t mypid;
+ int curr_connections;
+ char *name;
+};
+
+/****************************************************************************
+ Count the entries belonging to a service in the connection db.
+****************************************************************************/
+
+static int count_fn( TDB_CONTEXT *the_tdb, TDB_DATA kbuf, TDB_DATA dbuf, void *udp)
+{
+ struct connections_data crec;
+ struct count_stat *cs = (struct count_stat *)udp;
+
+ memcpy(&crec, dbuf.dptr, sizeof(crec));
+
+ if (crec.cnum == -1)
+ return 0;
+
+ /* if the pid was not found delete the entry from connections.tdb */
+ if (!process_exists(crec.pid) && (errno == ESRCH)) {
+ DEBUG(2,("pid %u doesn't exist - deleting connections %d [%s]\n",
+ (unsigned int)crec.pid, crec.cnum, crec.name));
+ tdb_delete(the_tdb, kbuf);
+ return 0;
+ }
+
+ if (strequal(crec.name, cs->name))
+ cs->curr_connections++;
+
+ return 0;
+}
/****************************************************************************
-claim an entry in the connections database
+ Claim an entry in the connections database.
****************************************************************************/
+
BOOL claim_connection(connection_struct *conn,char *name,int max_connections,BOOL Clear)
{
struct connections_key key;
struct connections_data crec;
TDB_DATA kbuf, dbuf;
+ BOOL db_locked = False;
+ BOOL ret = True;
if (!tdb) {
tdb = tdb_open(lock_path("connections.tdb"), 0, TDB_CLEAR_IF_FIRST,
O_RDWR | O_CREAT, 0644);
}
- if (!tdb) return False;
+ if (!tdb)
+ return False;
+
+ /*
+ * Enforce the max connections parameter.
+ */
+
+ if (max_connections > 0) {
+ struct count_stat cs;
+
+ cs.mypid = sys_getpid();
+ cs.curr_connections = 0;
+ cs.name = lp_servicename(SNUM(conn));
+
+ /*
+ * Go through and count the connections with the db locked. This is
+ * slow but essentially what 2.0.x did. JRA.
+ */
+
+ if (tdb_lockall(tdb))
+ return False;
+
+ db_locked = True;
+
+ if (tdb_traverse(tdb, count_fn, &cs)) {
+ ret = False;
+ goto out;
+ }
+
+ if (cs.curr_connections >= max_connections) {
+ DEBUG(1,("claim_connection: Max connections (%d) exceeded for %s\n",
+ max_connections, name ));
+ ret = False;
+ goto out;
+ }
+ }
DEBUG(5,("claiming %s %d\n",name,max_connections));
@@ -108,8 +180,14 @@ BOOL claim_connection(connection_struct *conn,char *name,int max_connections,BOO
dbuf.dptr = (char *)&crec;
dbuf.dsize = sizeof(crec);
- if (tdb_store(tdb, kbuf, dbuf, TDB_REPLACE) != 0) return False;
+ if (tdb_store(tdb, kbuf, dbuf, TDB_REPLACE) != 0)
+ ret = False;
+
+ out:
+
+ if (db_locked)
+ tdb_unlockall(tdb);
- return True;
+ return ret;
}
diff --git a/source3/tdb/tdb.c b/source3/tdb/tdb.c
index 25f458ac22..5fcea52d5e 100644
--- a/source3/tdb/tdb.c
+++ b/source3/tdb/tdb.c
@@ -872,10 +872,10 @@ int tdb_traverse(TDB_CONTEXT *tdb, tdb_traverse_func fn, void *state)
struct tdb_traverse_lock tl = { NULL, 0, 0 };
int ret, count = 0;
- /* This was in the initializaton, above, but the IRIX compiler
- * did not like it. crh
- */
- tl.next = tdb->travlocks.next;
+ /* This was in the initializaton, above, but the IRIX compiler
+ * did not like it. crh
+ */
+ tl.next = tdb->travlocks.next;
/* fcntl locks don't stack: beware traverse inside traverse */
tdb->travlocks.next = &tl;
@@ -908,8 +908,10 @@ int tdb_traverse(TDB_CONTEXT *tdb, tdb_traverse_func fn, void *state)
free(key.dptr);
}
tdb->travlocks.next = tl.next;
- if (ret < 0) return -1;
- else return count;
+ if (ret < 0)
+ return -1;
+ else
+ return count;
}
/* find the first entry in the database and return its key */