/* 
   Unix SMB/CIFS implementation.
   
   Copyright (C) Andrew Tridgell              2003
   
   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.
   
   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.
   
   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/

/*
  this file implements functions for manipulating the 'struct request_context' structure in smbd
*/

#include "includes.h"

/* we over allocate the data buffer to prevent too many realloc calls:
   every (re)allocation of req->out.buffer in this file reserves this many
   spare bytes beyond req->out.size */
#define REQ_OVER_ALLOCATION 256

/*
  destroy a request structure.

  the request and everything hanging off it live in req->mem_ctx, so a
  single talloc_destroy() releases the lot. Requests flagged with
  REQ_CONTROL_PROTECTED (the SMBecho code does this, for example) are
  left untouched so their owner can keep using them.
*/
void req_destroy(struct request_context *req)
{
	if (!(req->control_flags & REQ_CONTROL_PROTECTED)) {
		talloc_destroy(req->mem_ctx);
	}
}

/****************************************************************************
construct a basic request packet, mostly used to construct async packets
such as change notify and oplock break requests

returns a zeroed request_context with only ->smb and ->mem_ctx filled in,
or NULL if either the talloc context or the request itself could not be
allocated
****************************************************************************/
struct request_context *init_smb_request(struct server_context *smb)
{
	struct request_context *req;
	TALLOC_CTX *mem_ctx;

	/* each request gets its own talloc context. The request
	   structure itself is also allocated inside this context, so
	   we need to allocate it before we construct the request
	*/
	mem_ctx = talloc_init("request_context[%d]", smb->socket.pkt_count);
	if (!mem_ctx) {
		return NULL;
	}

	smb->socket.pkt_count++;

	req = talloc(mem_ctx, sizeof(*req));
	if (!req) {
		/* without this check the assignments below would dereference
		   NULL, and the context would be leaked */
		talloc_destroy(mem_ctx);
		return NULL;
	}
	ZERO_STRUCTP(req);

	/* setup the request context */
	req->smb = smb;
	req->mem_ctx = mem_ctx;
	
	return req;
}


/*
  setup a chained reply in req->out with the given word count and initial data buffer size. 
*/
static void req_setup_chain_reply(struct request_context *req, unsigned wct, unsigned buflen)
{
	uint32 chain_base_size = req->out.size;

	/* we need room for the wct value, the words, the buffer length and the buffer */
	req->out.size += 1 + VWV(wct) + 2 + buflen;

	/* over allocate by a small amount */
	req->out.allocated = req->out.size + REQ_OVER_ALLOCATION; 

	req->out.buffer = talloc_realloc(req->mem_ctx, req->out.buffer, req->out.allocated);
	if (!req->out.buffer) {
		exit_server(req->smb, "allocation failed");
	}

	req->out.hdr = req->out.buffer + NBT_HDR_SIZE;
	req->out.vwv = req->out.buffer + chain_base_size + 1;
	req->out.wct = wct;
	req->out.data = req->out.vwv + VWV(wct) + 2;
	req->out.data_size = buflen;
	req->out.ptr = req->out.data;

	SCVAL(req->out.buffer, chain_base_size, wct);
	SSVAL(req->out.vwv, VWV(wct), buflen);
}


/*
  setup a reply in req->out with the given word count and initial data buffer size. 
  the caller will then fill in the command words and data before calling req_send_reply() to 
  send the reply on its way

  for a chained request (req->chain_count != 0) the reply is appended to
  the existing out buffer by req_setup_chain_reply() instead.

  exits the server if the packet buffer cannot be allocated.
*/
void req_setup_reply(struct request_context *req, unsigned wct, unsigned buflen)
{
	if (req->chain_count != 0) {
		req_setup_chain_reply(req, wct, buflen);
		return;
	}

	req->out.size = NBT_HDR_SIZE + MIN_SMB_SIZE + wct*2 + buflen;

	/* over allocate by a small amount */
	req->out.allocated = req->out.size + REQ_OVER_ALLOCATION; 

	req->out.buffer = talloc(req->mem_ctx, req->out.allocated);
	if (!req->out.buffer) {
		exit_server(req->smb, "allocation failed");
	}

	req->out.hdr = req->out.buffer + NBT_HDR_SIZE;
	req->out.vwv = req->out.hdr + HDR_VWV;
	req->out.wct = wct;
	/* the +2 skips the byte count field that follows the words */
	req->out.data = req->out.vwv + VWV(wct) + 2;
	req->out.data_size = buflen;
	req->out.ptr = req->out.data;

	/* clear the 4-byte error/status area starting at HDR_RCLS */
	SIVAL(req->out.hdr, HDR_RCLS, 0);

	SCVAL(req->out.hdr, HDR_WCT, wct);
	SSVAL(req->out.vwv, VWV(wct), buflen);

	memcpy(req->out.hdr, "\377SMB", 4);
	SCVAL(req->out.hdr,HDR_FLG, FLAG_REPLY | FLAG_CASELESS_PATHNAMES); 
	SSVAL(req->out.hdr,HDR_FLG2, 
	      (req->flags2 & FLAGS2_UNICODE_STRINGS) |
	      FLAGS2_LONG_PATH_COMPONENTS | FLAGS2_32_BIT_ERROR_CODES | FLAGS2_EXTENDED_SECURITY);

	SSVAL(req->out.hdr,HDR_PIDHIGH,0);
	memset(req->out.hdr + HDR_SS_FIELD, 0, 10);

	if (req->in.hdr) {
		/* copy the cmd, tid, pid, uid and mid from the request */
		SCVAL(req->out.hdr,HDR_COM,CVAL(req->in.hdr,HDR_COM));	
		SSVAL(req->out.hdr,HDR_TID,SVAL(req->in.hdr,HDR_TID));
		SSVAL(req->out.hdr,HDR_PID,SVAL(req->in.hdr,HDR_PID));
		SSVAL(req->out.hdr,HDR_UID,SVAL(req->in.hdr,HDR_UID));
		SSVAL(req->out.hdr,HDR_MID,SVAL(req->in.hdr,HDR_MID));
	} else {
		/* no request to copy from (async packets built via
		   init_smb_request()). talloc() is not relied upon to zero
		   the buffer, so set the command byte explicitly rather than
		   sending whatever was in memory; callers that need a real
		   command overwrite it afterwards */
		SCVAL(req->out.hdr,HDR_COM,0);
		SSVAL(req->out.hdr,HDR_TID,0);
		SSVAL(req->out.hdr,HDR_PID,0);
		SSVAL(req->out.hdr,HDR_UID,0);
		SSVAL(req->out.hdr,HDR_MID,0);
	}
}

/*
  return how many data bytes this reply may carry: the negotiated
  max_send minus the space between the start of the SMB header and the
  start of the data area. The basic reply packet must already be set up.

  the result is a (deliberately) signed int and is never negative.
*/
int req_max_data(struct request_context *req)
{
	int room = req->smb->negotiate.max_send;

	/* header, words and byte count all come out of max_send */
	room -= PTR_DIFF(req->out.data, req->out.hdr);

	return room > 0 ? room : 0;
}


/*
  make sure the packet buffer has room for a data portion of new_size
  bytes. If the current allocation is too small, req->out.buffer is
  reallocated with REQ_OVER_ALLOCATION bytes of slack. req->out.size and
  req->out.data_size are NOT changed here.

  Note that as this can reallocate the packet buffer this invalidates
  any local pointers into the packet. To cope with this req->out.hdr,
  req->out.vwv, req->out.data and req->out.ptr are updated to point at
  the same offsets in the new buffer. A failed reallocation is fatal
  (smb_panic).
*/
static void req_grow_allocation(struct request_context *req, unsigned new_size)
{
	int delta;
	char *buf2;
	size_t hdr_ofs, vwv_ofs, data_ofs, ptr_ofs;

	delta = new_size - req->out.data_size;
	if (delta + req->out.size <= req->out.allocated) {
		/* it fits in the preallocation */
		return;
	}

	/* record the offsets while the old buffer is still live. Once
	   talloc_realloc() has moved the buffer the old pointer values are
	   indeterminate, and even subtracting them is undefined behaviour */
	hdr_ofs  = PTR_DIFF(req->out.hdr,  req->out.buffer);
	vwv_ofs  = PTR_DIFF(req->out.vwv,  req->out.buffer);
	data_ofs = PTR_DIFF(req->out.data, req->out.buffer);
	ptr_ofs  = PTR_DIFF(req->out.ptr,  req->out.buffer);

	/* we need to realloc */
	req->out.allocated = req->out.size + delta + REQ_OVER_ALLOCATION;
	buf2 = talloc_realloc(req->mem_ctx, req->out.buffer, req->out.allocated);
	if (buf2 == NULL) {
		smb_panic("out of memory in req_grow_allocation");
	}

	/* rebase the pointers into the packet; harmless if the buffer
	   did not move */
	req->out.buffer = buf2;
	req->out.hdr  = buf2 + hdr_ofs;
	req->out.vwv  = buf2 + vwv_ofs;
	req->out.data = buf2 + data_ofs;
	req->out.ptr  = buf2 + ptr_ofs;
}


/*
  grow the data buffer portion of a reply packet. Note that as this
  can reallocate the packet buffer this invalidates any local pointers
  into the packet. 

  To cope with this req->out.ptr is supplied. This will be updated to
  point at the same offset into the packet as before this call
*/
void req_grow_data(struct request_context *req, unsigned new_size)
{
	int delta;

	if (!(req->control_flags & REQ_CONTROL_LARGE) && new_size > req_max_data(req)) {
		smb_panic("reply buffer too large!");
	}

	req_grow_allocation(req, new_size);

	delta = new_size - req->out.data_size;

	req->out.size += delta;
	req->out.data_size += delta;

	/* set the BCC to the new data size */
	SSVAL(req->out.vwv, VWV(req->out.wct), new_size);
}

/*
  send the reply in req->out, then destroy the request via req_destroy()
  (which leaves protected requests alone).

  only req->out.buffer and req->out.size are used, so manually
  constructed packets can be sent as well. The NBT length is filled in
  whenever the packet extends past the NBT header. Failing to write the
  whole packet is fatal (smb_panic).
*/
void req_send_reply(struct request_context *req)
{
	ssize_t written;

	if (req->out.size > NBT_HDR_SIZE) {
		_smb_setlen(req->out.buffer, req->out.size - NBT_HDR_SIZE);
	}

	written = write_data(req->smb->socket.fd, req->out.buffer, req->out.size);
	if (written != req->out.size) {
		smb_panic("failed to send reply\n");
	}

	req_destroy(req);
}



/* 
   construct and send an error packet carrying a forced DOS error
   (class + code) rather than an NT status. This is needed to match
   win2000 behaviour for some parts of the protocol.

   a basic reply is set up first if none exists yet. The 32-bit error
   code flag is cleared so the client reads the DOS fields.
*/
void req_reply_dos_error(struct request_context *req, uint8 eclass, uint16 ecode)
{
	uint16 flags2;

	if (!req->out.buffer) {
		req_setup_reply(req, 0, 0);
	}

	SCVAL(req->out.hdr, HDR_RCLS, eclass);
	SSVAL(req->out.hdr, HDR_ERR, ecode);

	flags2 = SVAL(req->out.hdr, HDR_FLG2);
	flags2 &= ~FLAGS2_32_BIT_ERROR_CODES;
	SSVAL(req->out.hdr, HDR_FLG2, flags2);

	req_send_reply(req);
}

/* 
   construct and send an error packet (with no data), then destroy the
   request.

   the status is sent as a 32-bit NT status only when NT status support
   is enabled (lp_nt_status_support()) and the client advertised
   CAP_STATUS32. Otherwise it is converted with ntstatus_to_dos() and sent
   through req_reply_dos_error().
*/
void req_reply_error(struct request_context *req, NTSTATUS status)
{
	BOOL send_nt_status;

	req_setup_reply(req, 0, 0);

	/* error returns never have any data */
	req_grow_data(req, 0);

	send_nt_status = lp_nt_status_support() &&
		(req->smb->negotiate.client_caps & CAP_STATUS32);

	if (send_nt_status) {
		SIVAL(req->out.hdr, HDR_RCLS, NT_STATUS_V(status));
		SSVAL(req->out.hdr, HDR_FLG2, SVAL(req->out.hdr, HDR_FLG2) | FLAGS2_32_BIT_ERROR_CODES);
		req_send_reply(req);
	} else {
		uint8 eclass;
		uint32 ecode;

		ntstatus_to_dos(status, &eclass, &ecode);
		req_reply_dos_error(req, eclass, ecode);
	}
}


/*
  push a string into the data portion of the request packet, growing it if necessary
  this gets quite tricky - please be very careful to cover all cases when modifying this

  if dest is NULL, then put the string at the end of the data portion of the packet
  (dest must otherwise point inside req->out.buffer)

  if dest_len is -1 then no limit applies. In that case room is reserved
  for (strlen(str)+2) * 3 bytes before the string is converted.

  if neither STR_ASCII nor STR_UNICODE is in flags, the encoding follows
  FLAGS2_UNICODE_STRINGS in req->flags2.

  returns the number of bytes push_string() wrote. The data portion (and
  its byte count) is grown to cover the string if it ends past the
  current data size.
*/
size_t req_push_str(struct request_context *req, char *dest, const char *str, int dest_len, unsigned flags)
{
	size_t len;
	size_t dest_ofs;
	unsigned grow_size;
	const int max_bytes_per_char = 3;

	if (!(flags & (STR_ASCII|STR_UNICODE))) {
		flags |= (req->flags2 & FLAGS2_UNICODE_STRINGS) ? STR_UNICODE : STR_ASCII;
	}

	if (dest == NULL) {
		dest = req->out.data + req->out.data_size;
	}

	if (dest_len != -1) {
		len = dest_len;
	} else {
		len = (strlen(str)+2) * max_bytes_per_char;
	}

	grow_size = len + PTR_DIFF(dest, req->out.data);

	/* req_grow_allocation() may move the buffer. Remember where dest
	   sits relative to the buffer start *before* that happens; the old
	   pointer must not be used afterwards, even in arithmetic */
	dest_ofs = PTR_DIFF(dest, req->out.buffer);

	req_grow_allocation(req, grow_size);

	dest = req->out.buffer + dest_ofs;

	len = push_string(req->out.hdr, dest, str, len, flags);

	grow_size = len + PTR_DIFF(dest, req->out.data);

	if (grow_size > req->out.data_size) {
		req_grow_data(req, grow_size);
	}

	return len;
}


/*
  pull a UCS2 string from a request packet, returning a talloced unix string

  the string length is limited by the 3 things:
   - the data size in the request (end of packet), unless STR_NO_RANGE_CHECK
     is in flags
   - the passed 'byte_len' if it is not -1
   - the end of string (null termination)

  Note that 'byte_len' is the number of bytes in the packet, including
  any alignment pad byte skipped before the string

  on failure zero is returned and *dest is set to NULL, otherwise the number
  of bytes consumed in the packet is returned (pad byte, string, and the
  terminator when it lies inside the limit)
*/
static size_t req_pull_ucs2(struct request_context *req, const char **dest, const char *src, int byte_len, unsigned flags)
{
	int src_len, src_len2, alignment=0;
	ssize_t ret;

	/* skip a pad byte when ucs2_align() reports one, charging it to
	   byte_len. A byte_len of 0 has no room for a pad byte: decrementing
	   it would give -1, which means "no limit" and would let the pull run
	   on to the end of the packet. In that case no pad is skipped and
	   zero bytes are pulled */
	if (byte_len != 0 && !(flags & STR_NOALIGN) && ucs2_align(req->in.buffer, src, flags)) {
		src++;
		alignment=1;
		if (byte_len != -1) {
			byte_len--;
		}
	}

	if (flags & STR_NO_RANGE_CHECK) {
		src_len = byte_len;
	} else {
		src_len = req->in.data_size - PTR_DIFF(src, req->in.data);
		if (src_len < 0) {
			*dest = NULL;
			return 0;
		}

		if (byte_len != -1 && src_len > byte_len) {
			src_len = byte_len;
		}
	}

	src_len2 = strnlen_w((const smb_ucs2_t *)src, src_len/2) * 2;

	if (src_len2 <= src_len - 2) {
		/* include the termination if we didn't reach the end of the packet */
		src_len2 += 2;
	}

	ret = convert_string_talloc(req->mem_ctx, CH_UCS2, CH_UNIX, src, src_len2, (const void **)dest);

	if (ret == -1) {
		*dest = NULL;
		return 0;
	}

	return src_len2 + alignment;
}

/*
  pull an ascii (DOS charset) string from a request packet, returning a
  talloced unix string

  the number of bytes examined is bounded by:
   - the end of the request data, unless STR_NO_RANGE_CHECK is in flags
   - 'byte_len', the number of packet bytes for the string, unless it is -1
   - the first null byte

  returns 0 with *dest set to NULL on failure. Otherwise it returns the
  number of packet bytes consumed, which includes the null terminator
  when one was found inside the limit.
*/
static size_t req_pull_ascii(struct request_context *req, const char **dest, const char *src, int byte_len, unsigned flags)
{
	int limit = byte_len;
	int consumed;
	ssize_t ret;

	if (!(flags & STR_NO_RANGE_CHECK)) {
		int remaining = req->in.data_size - PTR_DIFF(src, req->in.data);

		if (remaining < 0) {
			*dest = NULL;
			return 0;
		}
		/* the packet bound applies unless byte_len is tighter */
		if (byte_len == -1 || remaining <= byte_len) {
			limit = remaining;
		}
	}

	consumed = strnlen(src, limit);

	/* a terminator found before the limit is consumed as well */
	if (consumed < limit) {
		consumed++;
	}

	ret = convert_string_talloc(req->mem_ctx, CH_DOS, CH_UNIX, src, consumed, (const void **)dest);
	if (ret == -1) {
		*dest = NULL;
		return 0;
	}

	return consumed;
}

/*
  pull a string from a request packet, returning a talloced string

  the string length is limited by the 3 things:
   - the data size in the request (end of packet)
   - the passed 'byte_len' if it is not -1
   - the end of string (null termination)

  Note that 'byte_len' is the number of bytes in the packet

  on failure zero is returned and *dest is set to NULL, otherwise the number
  of bytes consumed in the packet is returned
*/
size_t req_pull_string(struct request_context *req, const char **dest, const char *src, int byte_len, unsigned flags)
{
	if (!(flags & STR_ASCII) && 
	    (((flags & STR_UNICODE) || (req->flags2 & FLAGS2_UNICODE_STRINGS)))) {
		return req_pull_ucs2(req, dest, src, byte_len, flags);
	}

	return req_pull_ascii(req, dest, src, byte_len, flags);
}


/*
  pull a ASCII4 string buffer from a request packet, returning a talloced string
  
  an ASCII4 buffer is a null terminated string that has a prefix
  of the character 0x4. It tends to be used in older parts of the protocol.

  on failure *dest is set to the zero length string. This seems to
  match win2000 behaviour
*/
size_t req_pull_ascii4(struct request_context *req, const char **dest, const char *src, unsigned flags)
{
	ssize_t ret;

	if (PTR_DIFF(src, req->in.data) + 1 > req->in.data_size) {
		/* win2000 treats this as the NULL string! */
		(*dest) = talloc_strdup(req->mem_ctx, "");
		return 0;
	}

	/* this consumes the 0x4 byte. We don't check whether the byte
	   is actually 0x4 or not. This matches win2000 server
	   behaviour */
	src++;

	ret = req_pull_string(req, dest, src, -1, flags);
	if (ret == -1) {
		(*dest) = talloc_strdup(req->mem_ctx, "");
		return 1;
	}
	
	return ret + 1;
}

/*
  pull a DATA_BLOB of 'len' bytes at src from a request packet, returning
  a blob talloced on the request's context

  a zero-length blob is always accepted. Otherwise False is returned
  (and *blob left untouched) if any part lies outside the data portion
  of the packet.
*/
BOOL req_pull_blob(struct request_context *req, const char *src, int len, DATA_BLOB *blob)
{
	if (len == 0 || !req_data_oob(req, src, len)) {
		(*blob) = data_blob_talloc(req->mem_ctx, src, len);
		return True;
	}

	return False;
}

/* check that a lump of data in a request is within the bounds of the data section of
   the packet.

   returns True if [ptr, ptr+count) is not entirely inside
   [req->in.data, req->in.data + req->in.data_size), False otherwise.
   A zero count is never out of bounds. */
BOOL req_data_oob(struct request_context *req, const char *ptr, uint32 count)
{
	if (count == 0) {
		return False;
	}

	/* ptr itself must lie inside the data section */
	if (ptr < req->in.data ||
	    ptr >= req->in.data + req->in.data_size) {
		return True;
	}

	/* be careful with wraparound! compare count with the bytes that
	   remain after ptr instead of forming ptr + count, which could point
	   beyond the buffer (undefined behaviour) or wrap around */
	if (count > (uint32)PTR_DIFF(req->in.data + req->in.data_size, ptr)) {
		return True;
	}

	return False;
}


/* 
   return the file handle for this request: req->chained_fnum when it is
   set (not -1), otherwise the 16-bit value stored at base + offset
*/
uint16 req_fnum(struct request_context *req, const char *base, unsigned offset)
{
	return (req->chained_fnum != -1) ? req->chained_fnum : SVAL(base, offset);
}