summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--source4/lib/stream/packet.c45
-rw-r--r--source4/lib/stream/packet.h3
-rw-r--r--source4/librpc/rpc/dcerpc_sock.c32
3 files changed, 66 insertions, 14 deletions
diff --git a/source4/lib/stream/packet.c b/source4/lib/stream/packet.c
index f367368def..7aed43a943 100644
--- a/source4/lib/stream/packet.c
+++ b/source4/lib/stream/packet.c
@@ -36,6 +36,7 @@ struct packet_context {
DATA_BLOB partial;
uint32_t initial_read_size;
uint32_t num_read;
+ uint32_t initial_read;
struct tls_context *tls;
struct socket_context *sock;
struct event_context *ev;
@@ -44,6 +45,7 @@ struct packet_context {
struct fd_event *fde;
BOOL serialise;
BOOL processing;
+ BOOL recv_disable;
struct send_element {
struct send_element *next, *prev;
@@ -134,6 +136,15 @@ void packet_set_serialise(struct packet_context *pc, struct fd_event *fde)
pc->fde = fde;
}
+/*
+ tell the packet layer how much to read when starting a new packet
+ this ensures it doesn't overread
+*/
+void packet_set_initial_read(struct packet_context *pc, uint32_t initial_read)
+{
+ pc->initial_read = initial_read;
+}
+
/*
tell the caller we have an error
@@ -172,7 +183,9 @@ static void packet_next_event(struct event_context *ev, struct timed_event *te,
struct timeval t, void *private)
{
struct packet_context *pc = talloc_get_type(private, struct packet_context);
- packet_recv(pc);
+ if (pc->num_read != 0 && pc->packet_size >= pc->num_read) {
+ packet_recv(pc);
+ }
}
/*
@@ -190,6 +203,11 @@ void packet_recv(struct packet_context *pc)
return;
}
+ if (pc->recv_disable) {
+ EVENT_FD_NOT_READABLE(pc->fde);
+ return;
+ }
+
if (pc->packet_size != 0 && pc->num_read >= pc->packet_size) {
goto next_partial;
}
@@ -198,6 +216,8 @@ void packet_recv(struct packet_context *pc)
/* we've already worked out how long this next packet is, so skip the
socket_pending() call */
npending = pc->packet_size - pc->num_read;
+ } else if (pc->initial_read != 0) {
+ npending = pc->initial_read - pc->num_read;
} else {
if (pc->tls) {
status = tls_socket_pending(pc->tls, &npending);
@@ -306,7 +326,7 @@ next_partial:
if (pc->partial.length == 0) {
return;
}
-
+
/* we got multiple packets in one tcp read */
if (pc->ev == NULL) {
goto next_partial;
@@ -330,6 +350,27 @@ next_partial:
/*
+ temporarily disable receiving
+*/
+void packet_recv_disable(struct packet_context *pc)
+{
+ EVENT_FD_NOT_READABLE(pc->fde);
+ pc->recv_disable = True;
+}
+
+/*
+ re-enable receiving
+*/
+void packet_recv_enable(struct packet_context *pc)
+{
+ EVENT_FD_READABLE(pc->fde);
+ pc->recv_disable = False;
+ if (pc->num_read != 0 && pc->packet_size >= pc->num_read) {
+ event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
+ }
+}
+
+/*
trigger a run of the send queue
*/
void packet_queue_run(struct packet_context *pc)
diff --git a/source4/lib/stream/packet.h b/source4/lib/stream/packet.h
index bba8a1940f..a8db89853c 100644
--- a/source4/lib/stream/packet.h
+++ b/source4/lib/stream/packet.h
@@ -39,7 +39,10 @@ void packet_set_tls(struct packet_context *pc, struct tls_context *tls);
void packet_set_socket(struct packet_context *pc, struct socket_context *sock);
void packet_set_event_context(struct packet_context *pc, struct event_context *ev);
void packet_set_serialise(struct packet_context *pc, struct fd_event *fde);
+void packet_set_initial_read(struct packet_context *pc, uint32_t initial_read);
void packet_recv(struct packet_context *pc);
+void packet_recv_disable(struct packet_context *pc);
+void packet_recv_enable(struct packet_context *pc);
NTSTATUS packet_send(struct packet_context *pc, DATA_BLOB blob);
void packet_queue_run(struct packet_context *pc);
diff --git a/source4/librpc/rpc/dcerpc_sock.c b/source4/librpc/rpc/dcerpc_sock.c
index eb2e7f8f66..2ecf9f1530 100644
--- a/source4/librpc/rpc/dcerpc_sock.c
+++ b/source4/librpc/rpc/dcerpc_sock.c
@@ -35,6 +35,7 @@ struct sock_private {
char *server_name;
struct packet_context *packet;
+ uint32_t pending_reads;
};
@@ -83,22 +84,17 @@ static NTSTATUS sock_complete_packet(void *private, DATA_BLOB blob, size_t *size
}
/*
- process send requests
-*/
-static void sock_process_send(struct dcerpc_connection *p)
-{
- struct sock_private *sock = p->transport.private;
- packet_queue_run(sock->packet);
-}
-
-
-/*
process recv requests
*/
static NTSTATUS sock_process_recv(void *private, DATA_BLOB blob)
{
struct dcerpc_connection *p = talloc_get_type(private,
struct dcerpc_connection);
+ struct sock_private *sock = p->transport.private;
+ sock->pending_reads--;
+ if (sock->pending_reads == 0) {
+ packet_recv_disable(sock->packet);
+ }
p->transport.recv_data(p, &blob, NT_STATUS_OK);
return NT_STATUS_OK;
}
@@ -114,7 +110,7 @@ static void sock_io_handler(struct event_context *ev, struct fd_event *fde,
struct sock_private *sock = p->transport.private;
if (flags & EVENT_FD_WRITE) {
- sock_process_send(p);
+ packet_queue_run(sock->packet);
return;
}
@@ -132,6 +128,11 @@ static void sock_io_handler(struct event_context *ev, struct fd_event *fde,
*/
static NTSTATUS sock_send_read(struct dcerpc_connection *p)
{
+ struct sock_private *sock = p->transport.private;
+ sock->pending_reads++;
+ if (sock->pending_reads == 1) {
+ packet_recv_enable(sock->packet);
+ }
return NT_STATUS_OK;
}
@@ -159,6 +160,10 @@ static NTSTATUS sock_send_request(struct dcerpc_connection *p, DATA_BLOB *data,
return status;
}
+ if (trigger_read) {
+ sock_send_read(p);
+ }
+
return NT_STATUS_OK;
}
@@ -230,10 +235,11 @@ static NTSTATUS dcerpc_pipe_open_socket(struct dcerpc_connection *c,
c->transport.peer_name = sock_peer_name;
sock->sock = socket_ctx;
+ sock->pending_reads = 0;
sock->server_name = strupper_talloc(sock, server);
sock->fde = event_add_fd(c->event_ctx, sock->sock, socket_get_fd(sock->sock),
- EVENT_FD_READ, sock_io_handler, c);
+ 0, sock_io_handler, c);
c->transport.private = sock;
@@ -249,6 +255,8 @@ static NTSTATUS dcerpc_pipe_open_socket(struct dcerpc_connection *c,
packet_set_error_handler(sock->packet, sock_error_handler);
packet_set_event_context(sock->packet, c->event_ctx);
packet_set_serialise(sock->packet, sock->fde);
+ packet_recv_disable(sock->packet);
+ packet_set_initial_read(sock->packet, 16);
/* ensure we don't get SIGPIPE */
BlockSignals(True,SIGPIPE);