summaryrefslogtreecommitdiff
path: root/source4/lib/stream
diff options
context:
space:
mode:
Diffstat (limited to 'source4/lib/stream')
-rw-r--r--source4/lib/stream/packet.c45
-rw-r--r--source4/lib/stream/packet.h3
2 files changed, 46 insertions, 2 deletions
diff --git a/source4/lib/stream/packet.c b/source4/lib/stream/packet.c
index f367368def..7aed43a943 100644
--- a/source4/lib/stream/packet.c
+++ b/source4/lib/stream/packet.c
@@ -36,6 +36,7 @@ struct packet_context {
DATA_BLOB partial;
uint32_t initial_read_size;
uint32_t num_read;
+ uint32_t initial_read;
struct tls_context *tls;
struct socket_context *sock;
struct event_context *ev;
@@ -44,6 +45,7 @@ struct packet_context {
struct fd_event *fde;
BOOL serialise;
BOOL processing;
+ BOOL recv_disable;
struct send_element {
struct send_element *next, *prev;
@@ -134,6 +136,15 @@ void packet_set_serialise(struct packet_context *pc, struct fd_event *fde)
pc->fde = fde;
}
+/*
+ tell the packet layer how much to read when starting a new packet
+ this ensures it doesn't overread
+*/
+void packet_set_initial_read(struct packet_context *pc, uint32_t initial_read)
+{
+ pc->initial_read = initial_read;
+}
+
/*
tell the caller we have an error
@@ -172,7 +183,9 @@ static void packet_next_event(struct event_context *ev, struct timed_event *te,
struct timeval t, void *private)
{
struct packet_context *pc = talloc_get_type(private, struct packet_context);
- packet_recv(pc);
+ if (pc->num_read != 0 && pc->packet_size >= pc->num_read) {
+ packet_recv(pc);
+ }
}
/*
@@ -190,6 +203,11 @@ void packet_recv(struct packet_context *pc)
return;
}
+ if (pc->recv_disable) {
+ EVENT_FD_NOT_READABLE(pc->fde);
+ return;
+ }
+
if (pc->packet_size != 0 && pc->num_read >= pc->packet_size) {
goto next_partial;
}
@@ -198,6 +216,8 @@ void packet_recv(struct packet_context *pc)
/* we've already worked out how long this next packet is, so skip the
socket_pending() call */
npending = pc->packet_size - pc->num_read;
+ } else if (pc->initial_read != 0) {
+ npending = pc->initial_read - pc->num_read;
} else {
if (pc->tls) {
status = tls_socket_pending(pc->tls, &npending);
@@ -306,7 +326,7 @@ next_partial:
if (pc->partial.length == 0) {
return;
}
-
+
/* we got multiple packets in one tcp read */
if (pc->ev == NULL) {
goto next_partial;
@@ -330,6 +350,27 @@ next_partial:
/*
+ temporarily disable receiving
+*/
+void packet_recv_disable(struct packet_context *pc)
+{
+ EVENT_FD_NOT_READABLE(pc->fde);
+ pc->recv_disable = True;
+}
+
+/*
+ re-enable receiving
+*/
+void packet_recv_enable(struct packet_context *pc)
+{
+ EVENT_FD_READABLE(pc->fde);
+ pc->recv_disable = False;
+ if (pc->num_read != 0 && pc->packet_size >= pc->num_read) {
+ event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc);
+ }
+}
+
+/*
trigger a run of the send queue
*/
void packet_queue_run(struct packet_context *pc)
diff --git a/source4/lib/stream/packet.h b/source4/lib/stream/packet.h
index bba8a1940f..a8db89853c 100644
--- a/source4/lib/stream/packet.h
+++ b/source4/lib/stream/packet.h
@@ -39,7 +39,10 @@ void packet_set_tls(struct packet_context *pc, struct tls_context *tls);
void packet_set_socket(struct packet_context *pc, struct socket_context *sock);
void packet_set_event_context(struct packet_context *pc, struct event_context *ev);
void packet_set_serialise(struct packet_context *pc, struct fd_event *fde);
+void packet_set_initial_read(struct packet_context *pc, uint32_t initial_read);
void packet_recv(struct packet_context *pc);
+void packet_recv_disable(struct packet_context *pc);
+void packet_recv_enable(struct packet_context *pc);
NTSTATUS packet_send(struct packet_context *pc, DATA_BLOB blob);
void packet_queue_run(struct packet_context *pc);