diff options
Diffstat (limited to 'source4/lib/stream')
-rw-r--r-- | source4/lib/stream/packet.c | 45 | ||||
-rw-r--r-- | source4/lib/stream/packet.h | 3 |
2 files changed, 46 insertions, 2 deletions
diff --git a/source4/lib/stream/packet.c b/source4/lib/stream/packet.c index f367368def..7aed43a943 100644 --- a/source4/lib/stream/packet.c +++ b/source4/lib/stream/packet.c @@ -36,6 +36,7 @@ struct packet_context { DATA_BLOB partial; uint32_t initial_read_size; uint32_t num_read; + uint32_t initial_read; struct tls_context *tls; struct socket_context *sock; struct event_context *ev; @@ -44,6 +45,7 @@ struct packet_context { struct fd_event *fde; BOOL serialise; BOOL processing; + BOOL recv_disable; struct send_element { struct send_element *next, *prev; @@ -134,6 +136,15 @@ void packet_set_serialise(struct packet_context *pc, struct fd_event *fde) pc->fde = fde; } +/* + tell the packet layer how much to read when starting a new packet + this ensures it doesn't overread +*/ +void packet_set_initial_read(struct packet_context *pc, uint32_t initial_read) +{ + pc->initial_read = initial_read; +} + /* tell the caller we have an error @@ -172,7 +183,9 @@ static void packet_next_event(struct event_context *ev, struct timed_event *te, struct timeval t, void *private) { struct packet_context *pc = talloc_get_type(private, struct packet_context); - packet_recv(pc); + if (pc->num_read != 0 && pc->packet_size >= pc->num_read) { + packet_recv(pc); + } } /* @@ -190,6 +203,11 @@ void packet_recv(struct packet_context *pc) return; } + if (pc->recv_disable) { + EVENT_FD_NOT_READABLE(pc->fde); + return; + } + if (pc->packet_size != 0 && pc->num_read >= pc->packet_size) { goto next_partial; } @@ -198,6 +216,8 @@ void packet_recv(struct packet_context *pc) /* we've already worked out how long this next packet is, so skip the socket_pending() call */ npending = pc->packet_size - pc->num_read; + } else if (pc->initial_read != 0) { + npending = pc->initial_read - pc->num_read; } else { if (pc->tls) { status = tls_socket_pending(pc->tls, &npending); @@ -306,7 +326,7 @@ next_partial: if (pc->partial.length == 0) { return; } - + /* we got multiple packets in one tcp read */ if (pc->ev == NULL) { goto next_partial; @@ -330,6 +350,27 @@ next_partial: /* + temporarily disable receiving +*/ +void packet_recv_disable(struct packet_context *pc) +{ + EVENT_FD_NOT_READABLE(pc->fde); + pc->recv_disable = True; +} + +/* + re-enable receiving +*/ +void packet_recv_enable(struct packet_context *pc) +{ + EVENT_FD_READABLE(pc->fde); + pc->recv_disable = False; + if (pc->num_read != 0 && pc->packet_size >= pc->num_read) { + event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc); + } +} + +/* trigger a run of the send queue */ void packet_queue_run(struct packet_context *pc) diff --git a/source4/lib/stream/packet.h b/source4/lib/stream/packet.h index bba8a1940f..a8db89853c 100644 --- a/source4/lib/stream/packet.h +++ b/source4/lib/stream/packet.h @@ -39,7 +39,10 @@ void packet_set_tls(struct packet_context *pc, struct tls_context *tls); void packet_set_socket(struct packet_context *pc, struct socket_context *sock); void packet_set_event_context(struct packet_context *pc, struct event_context *ev); void packet_set_serialise(struct packet_context *pc, struct fd_event *fde); +void packet_set_initial_read(struct packet_context *pc, uint32_t initial_read); void packet_recv(struct packet_context *pc); +void packet_recv_disable(struct packet_context *pc); +void packet_recv_enable(struct packet_context *pc); NTSTATUS packet_send(struct packet_context *pc, DATA_BLOB blob); void packet_queue_run(struct packet_context *pc); |