/*
   Unix SMB/CIFS implementation.
   Implementation of a reliable server_exists()
   Copyright (C) Volker Lendecke 2010

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 3 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "includes.h"
#include "serverid.h"

/*
 * Key of a record in serverid.tdb. The raw bytes of this struct are used
 * as the tdb key, so it is always zeroed before being filled (see
 * serverid_fill_key) to keep any padding bytes deterministic.
 */
struct serverid_key {
	pid_t pid;		/* copied from server_id.pid */
#ifdef CLUSTER_SUPPORT
	uint32_t vnn;		/* copied from server_id.vnn, cluster builds only */
#endif
};

/*
 * Value of a record in serverid.tdb. The raw bytes of this struct are
 * stored as the record data.
 *
 * unique_id must remain the first member: server_exists_parse() reads it
 * by copying the leading sizeof(uint64_t) bytes of the stored value.
 */
struct serverid_data {
	uint64_t unique_id;	/* copied from server_id.unique_id */
	uint32_t msg_flags;	/* flags passed to serverid_register() */
};

/*
 * Return the handle for serverid.tdb, opening it on first use.
 *
 * The handle is kept in a function-local static and reused by later
 * calls. If db_open() yields NULL nothing is cached, so the next call
 * attempts the open again. Returns NULL when no handle is available.
 */
static struct db_context *serverid_db(void)
{
	static struct db_context *cached_db;

	if (cached_db == NULL) {
		cached_db = db_open(talloc_autofree_context(),
				    lock_path("serverid.tdb"), 0,
				    TDB_DEFAULT|TDB_CLEAR_IF_FIRST,
				    O_RDWR|O_CREAT, 0644);
	}
	return cached_db;
}

/*
 * Build the serverid.tdb key for the given server_id.
 *
 * The whole struct is cleared first because its raw bytes are used as
 * the database key; only afterwards are the identifying fields copied
 * from *id.
 */
static void serverid_fill_key(const struct server_id *id,
			      struct serverid_key *key)
{
	/* Clear everything, including any padding, before filling. */
	ZERO_STRUCTP(key);

#ifdef CLUSTER_SUPPORT
	key->vnn = id->vnn;
#endif
	key->pid = id->pid;
}

/*
 * Store an entry for *id in serverid.tdb.
 *
 * The record is keyed by the fields copied in serverid_fill_key() and
 * carries id->unique_id together with msg_flags as its value.
 *
 * Returns true when the record was stored, false if the database could
 * not be opened, the record could not be fetched, or the store failed.
 */
bool serverid_register(const struct server_id *id, uint32_t msg_flags)
{
	struct db_context *db = serverid_db();
	struct serverid_key key;
	struct serverid_data data;
	struct db_record *rec;
	TDB_DATA k, v;
	NTSTATUS status;
	bool stored;

	if (db == NULL) {
		return false;
	}

	serverid_fill_key(id, &key);
	k = make_tdb_data((uint8_t *)&key, sizeof(key));

	rec = db->fetch_locked(db, talloc_tos(), k);
	if (rec == NULL) {
		DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
		return false;
	}

	/* Zero first so the stored bytes do not contain stack garbage. */
	ZERO_STRUCT(data);
	data.msg_flags = msg_flags;
	data.unique_id = id->unique_id;
	v = make_tdb_data((uint8_t *)&data, sizeof(data));

	status = rec->store(rec, v, 0);
	stored = NT_STATUS_IS_OK(status);
	if (!stored) {
		DEBUG(1, ("Storing serverid.tdb record failed: %s\n",
			  nt_errstr(status)));
	}

	/* The record is freed on both the success and the failure path. */
	TALLOC_FREE(rec);
	return stored;
}

bool serverid_register_self(uint32_t msg_flags)
{
	struct server_id pid;

	pid = procid_self();
	return serverid_register(&pid, msg_flags);
}

/*
 * Delete the serverid.tdb entry belonging to *id.
 *
 * Returns true when the record was deleted, false if the database could
 * not be opened, the record could not be fetched, or the delete failed.
 */
bool serverid_deregister(const struct server_id *id)
{
	struct db_context *db = serverid_db();
	struct serverid_key key;
	struct db_record *rec;
	TDB_DATA k;
	NTSTATUS status;
	bool deleted;

	if (db == NULL) {
		return false;
	}

	serverid_fill_key(id, &key);
	k = make_tdb_data((uint8_t *)&key, sizeof(key));

	rec = db->fetch_locked(db, talloc_tos(), k);
	if (rec == NULL) {
		DEBUG(1, ("Could not fetch_lock serverid.tdb record\n"));
		return false;
	}

	status = rec->delete_rec(rec);
	deleted = NT_STATUS_IS_OK(status);
	if (!deleted) {
		DEBUG(1, ("Deleting serverid.tdb record failed: %s\n",
			  nt_errstr(status)));
	}

	/* The record is freed on both the success and the failure path. */
	TALLOC_FREE(rec);
	return deleted;
}

bool serverid_deregister_self(void)
{
	struct server_id pid;

	pid = procid_self();
	return serverid_deregister(&pid);
}

/*
 * State handed to server_exists_parse() through the parse_record
 * callback's private pointer.
 */
struct serverid_exists_state {
	const struct server_id *id;	/* id being looked up (input) */
	bool exists;			/* set by the callback (output) */
};

/*
 * parse_record callback used by serverid_exists().
 *
 * priv points to a struct serverid_exists_state. When the stored value
 * has the size of struct serverid_data, its leading uint64_t (the
 * unique_id member) is compared against the unique_id of the id being
 * looked up and the outcome is written to state->exists.
 *
 * Returns 0 when the value was examined, -1 when its size is unexpected
 * (state->exists is left untouched in that case). The key argument is
 * not used.
 */
static int server_exists_parse(TDB_DATA key, TDB_DATA data, void *priv)
{
	struct serverid_exists_state *state =
		(struct serverid_exists_state *)priv;
	uint64_t stored_id;

	if (data.dsize == sizeof(struct serverid_data)) {
		/* Copy out instead of casting: data.dptr may be unaligned. */
		memcpy(&stored_id, data.dptr, sizeof(stored_id));
		state->exists = (stored_id == state->id->unique_id);
		return 0;
	}

	return -1;
}

/*
 * Check whether *id has a matching entry in serverid.tdb.
 *
 * A match requires a record under the key built by serverid_fill_key()
 * whose stored unique_id equals id->unique_id.
 *
 * Returns false if the database is unavailable, if parse_record reports
 * -1, or if the stored unique_id differs.
 */
bool serverid_exists(const struct server_id *id)
{
	struct db_context *db = serverid_db();
	struct serverid_exists_state state;
	struct serverid_key key;
	TDB_DATA k;

	if (db == NULL) {
		return false;
	}

	state.exists = false;
	state.id = id;

	serverid_fill_key(id, &key);
	k = make_tdb_data((uint8_t *)&key, sizeof(key));

	/* state.exists is only consulted when the lookup did not fail. */
	return (db->parse_record(db, k, server_exists_parse, &state) != -1)
		&& state.exists;
}

static bool serverid_rec_parse(const struct db_record *rec,
			       struct server_id *id, uint32_t *msg_flags)
{
	struct serverid_key key;
	struct serverid_data data;

	if (rec->key.dsize != sizeof(key)) {
		DEBUG(1, ("Found invalid key length %d in serverid.tdb\n",
			  (int)rec->key.dsize));
		return false;
	}
	if (rec->value.dsize != sizeof(data)) {
		DEBUG(1, ("Found invalid value length %d in serverid.tdb\n",
			  (int)rec->value.dsize));
		return false;
	}

	memcpy(&key, rec->key.dptr, sizeof(key));
	memcpy(&data, rec->value.dptr, sizeof(data));

	id->pid = key.pid;
#ifdef CLUSTER_SUPPORT
	id->vnn = key.vnn;
#endif
	id->unique_id = data.unique_id;
	*msg_flags = data.msg_flags;
	return true;
}

/*
 * State handed to serverid_traverse_read_fn(): the caller's callback
 * and the opaque pointer to pass back to it.
 */
struct serverid_traverse_read_state {
	/* Caller's callback, invoked once per successfully parsed record. */
	int (*fn)(const struct server_id *id, uint32_t msg_flags,
		  void *private_data);
	/* Passed through unchanged as fn's last argument. */
	void *private_data;
};

/*
 * Per-record adapter for serverid_traverse_read().
 *
 * Decodes the record and forwards the result to the caller's callback,
 * returning the callback's return value. Records that fail to decode
 * are skipped by returning 0 without invoking the callback.
 */
static int serverid_traverse_read_fn(struct db_record *rec, void *private_data)
{
	struct serverid_traverse_read_state *state =
		(struct serverid_traverse_read_state *)private_data;
	struct server_id id;
	uint32_t msg_flags;

	if (serverid_rec_parse(rec, &id, &msg_flags)) {
		return state->fn(&id, msg_flags, state->private_data);
	}

	return 0;
}

/*
 * Walk all entries of serverid.tdb read-only, calling fn for every
 * record that decodes successfully.
 *
 * fn receives the decoded server_id, the stored msg_flags and
 * private_data; its return value is handed back to the underlying
 * traverse (see serverid_traverse_read_fn).
 *
 * Returns false if the database is unavailable or the traverse reports
 * an error, true otherwise.
 *
 * NOTE(review): this assumes db->traverse_read follows the tdb
 * convention of returning a non-negative value on success and -1 on
 * error. The previous code converted that int straight to bool, which
 * turned -1 (error) into true and 0 (e.g. an empty database) into false.
 */
bool serverid_traverse_read(int (*fn)(const struct server_id *id,
				      uint32_t msg_flags, void *private_data),
			    void *private_data)
{
	struct db_context *db;
	struct serverid_traverse_read_state state;
	int ret;

	db = serverid_db();
	if (db == NULL) {
		return false;
	}
	state.fn = fn;
	state.private_data = private_data;

	ret = db->traverse_read(db, serverid_traverse_read_fn, &state);
	return (ret >= 0);
}

/*
 * State handed to serverid_traverse_fn(): the caller's callback and the
 * opaque pointer to pass back to it. Unlike the read-only variant, this
 * callback also receives the db_record itself.
 */
struct serverid_traverse_state {
	/* Caller's callback, invoked once per successfully parsed record. */
	int (*fn)(struct db_record *rec, const struct server_id *id,
		  uint32_t msg_flags, void *private_data);
	/* Passed through unchanged as fn's last argument. */
	void *private_data;
};

/*
 * Per-record adapter for serverid_traverse().
 *
 * Decodes the record and forwards it, together with the decoded
 * server_id and msg_flags, to the caller's callback, returning the
 * callback's return value. Records that fail to decode are skipped by
 * returning 0 without invoking the callback.
 */
static int serverid_traverse_fn(struct db_record *rec, void *private_data)
{
	struct serverid_traverse_state *state =
		(struct serverid_traverse_state *)private_data;
	struct server_id id;
	uint32_t msg_flags;

	if (serverid_rec_parse(rec, &id, &msg_flags)) {
		return state->fn(rec, &id, msg_flags, state->private_data);
	}

	return 0;
}

/*
 * Walk all entries of serverid.tdb, calling fn for every record that
 * decodes successfully.
 *
 * fn receives the db_record, the decoded server_id, the stored
 * msg_flags and private_data; its return value is handed back to the
 * underlying traverse (see serverid_traverse_fn).
 *
 * Returns false if the database is unavailable or the traverse reports
 * an error, true otherwise.
 *
 * NOTE(review): this assumes db->traverse follows the tdb convention of
 * returning a non-negative value on success and -1 on error. The
 * previous code converted that int straight to bool, which turned -1
 * (error) into true and 0 (e.g. an empty database) into false.
 */
bool serverid_traverse(int (*fn)(struct db_record *rec,
				 const struct server_id *id,
				 uint32_t msg_flags, void *private_data),
		       void *private_data)
{
	struct db_context *db;
	struct serverid_traverse_state state;
	int ret;

	db = serverid_db();
	if (db == NULL) {
		return false;
	}
	state.fn = fn;
	state.private_data = private_data;

	ret = db->traverse(db, serverid_traverse_fn, &state);
	return (ret >= 0);
}