diff options
-rw-r--r-- | source4/lib/stream/packet.c | 45 | ||||
-rw-r--r-- | source4/lib/stream/packet.h | 3 | ||||
-rw-r--r-- | source4/librpc/rpc/dcerpc_sock.c | 32 |
3 files changed, 66 insertions, 14 deletions
diff --git a/source4/lib/stream/packet.c b/source4/lib/stream/packet.c index f367368def..7aed43a943 100644 --- a/source4/lib/stream/packet.c +++ b/source4/lib/stream/packet.c @@ -36,6 +36,7 @@ struct packet_context { DATA_BLOB partial; uint32_t initial_read_size; uint32_t num_read; + uint32_t initial_read; struct tls_context *tls; struct socket_context *sock; struct event_context *ev; @@ -44,6 +45,7 @@ struct packet_context { struct fd_event *fde; BOOL serialise; BOOL processing; + BOOL recv_disable; struct send_element { struct send_element *next, *prev; @@ -134,6 +136,15 @@ void packet_set_serialise(struct packet_context *pc, struct fd_event *fde) pc->fde = fde; } +/* + tell the packet layer how much to read when starting a new packet + this ensures it doesn't overread +*/ +void packet_set_initial_read(struct packet_context *pc, uint32_t initial_read) +{ + pc->initial_read = initial_read; +} + /* tell the caller we have an error @@ -172,7 +183,9 @@ static void packet_next_event(struct event_context *ev, struct timed_event *te, struct timeval t, void *private) { struct packet_context *pc = talloc_get_type(private, struct packet_context); - packet_recv(pc); + if (pc->num_read != 0 && pc->packet_size >= pc->num_read) { + packet_recv(pc); + } } /* @@ -190,6 +203,11 @@ void packet_recv(struct packet_context *pc) return; } + if (pc->recv_disable) { + EVENT_FD_NOT_READABLE(pc->fde); + return; + } + if (pc->packet_size != 0 && pc->num_read >= pc->packet_size) { goto next_partial; } @@ -198,6 +216,8 @@ void packet_recv(struct packet_context *pc) /* we've already worked out how long this next packet is, so skip the socket_pending() call */ npending = pc->packet_size - pc->num_read; + } else if (pc->initial_read != 0) { + npending = pc->initial_read - pc->num_read; } else { if (pc->tls) { status = tls_socket_pending(pc->tls, &npending); @@ -306,7 +326,7 @@ next_partial: if (pc->partial.length == 0) { return; } - + /* we got multiple packets in one tcp read */ if (pc->ev == NULL) { goto next_partial; @@ -330,6 +350,27 @@ next_partial: /* + temporarily disable receiving +*/ +void packet_recv_disable(struct packet_context *pc) +{ + EVENT_FD_NOT_READABLE(pc->fde); + pc->recv_disable = True; +} + +/* + re-enable receiving +*/ +void packet_recv_enable(struct packet_context *pc) +{ + EVENT_FD_READABLE(pc->fde); + pc->recv_disable = False; + if (pc->num_read != 0 && pc->packet_size >= pc->num_read) { + event_add_timed(pc->ev, pc, timeval_zero(), packet_next_event, pc); + } +} + +/* trigger a run of the send queue */ void packet_queue_run(struct packet_context *pc) diff --git a/source4/lib/stream/packet.h b/source4/lib/stream/packet.h index bba8a1940f..a8db89853c 100644 --- a/source4/lib/stream/packet.h +++ b/source4/lib/stream/packet.h @@ -39,7 +39,10 @@ void packet_set_tls(struct packet_context *pc, struct tls_context *tls); void packet_set_socket(struct packet_context *pc, struct socket_context *sock); void packet_set_event_context(struct packet_context *pc, struct event_context *ev); void packet_set_serialise(struct packet_context *pc, struct fd_event *fde); +void packet_set_initial_read(struct packet_context *pc, uint32_t initial_read); void packet_recv(struct packet_context *pc); +void packet_recv_disable(struct packet_context *pc); +void packet_recv_enable(struct packet_context *pc); NTSTATUS packet_send(struct packet_context *pc, DATA_BLOB blob); void packet_queue_run(struct packet_context *pc); diff --git a/source4/librpc/rpc/dcerpc_sock.c b/source4/librpc/rpc/dcerpc_sock.c index eb2e7f8f66..2ecf9f1530 100644 --- a/source4/librpc/rpc/dcerpc_sock.c +++ b/source4/librpc/rpc/dcerpc_sock.c @@ -35,6 +35,7 @@ struct sock_private { char *server_name; struct packet_context *packet; + uint32_t pending_reads; }; @@ -83,22 +84,17 @@ static NTSTATUS sock_complete_packet(void *private, DATA_BLOB blob, size_t *size } /* - process send requests -*/ -static void sock_process_send(struct dcerpc_connection *p) -{ - struct sock_private *sock = p->transport.private; - packet_queue_run(sock->packet); -} - - -/* process recv requests */ static NTSTATUS sock_process_recv(void *private, DATA_BLOB blob) { struct dcerpc_connection *p = talloc_get_type(private, struct dcerpc_connection); + struct sock_private *sock = p->transport.private; + sock->pending_reads--; + if (sock->pending_reads == 0) { + packet_recv_disable(sock->packet); + } p->transport.recv_data(p, &blob, NT_STATUS_OK); return NT_STATUS_OK; } @@ -114,7 +110,7 @@ static void sock_io_handler(struct event_context *ev, struct fd_event *fde, struct sock_private *sock = p->transport.private; if (flags & EVENT_FD_WRITE) { - sock_process_send(p); + packet_queue_run(sock->packet); return; } @@ -132,6 +128,11 @@ static void sock_io_handler(struct event_context *ev, struct fd_event *fde, */ static NTSTATUS sock_send_read(struct dcerpc_connection *p) { + struct sock_private *sock = p->transport.private; + sock->pending_reads++; + if (sock->pending_reads == 1) { + packet_recv_enable(sock->packet); + } return NT_STATUS_OK; } @@ -159,6 +160,10 @@ static NTSTATUS sock_send_request(struct dcerpc_connection *p, DATA_BLOB *data, return status; } + if (trigger_read) { + sock_send_read(p); + } + return NT_STATUS_OK; } @@ -230,10 +235,11 @@ static NTSTATUS dcerpc_pipe_open_socket(struct dcerpc_connection *c, c->transport.peer_name = sock_peer_name; sock->sock = socket_ctx; + sock->pending_reads = 0; sock->server_name = strupper_talloc(sock, server); sock->fde = event_add_fd(c->event_ctx, sock->sock, socket_get_fd(sock->sock), - EVENT_FD_READ, sock_io_handler, c); + 0, sock_io_handler, c); c->transport.private = sock; @@ -249,6 +255,8 @@ static NTSTATUS dcerpc_pipe_open_socket(struct dcerpc_connection *c, packet_set_error_handler(sock->packet, sock_error_handler); packet_set_event_context(sock->packet, c->event_ctx); packet_set_serialise(sock->packet, sock->fde); + packet_recv_disable(sock->packet); + packet_set_initial_read(sock->packet, 16); /* ensure we don't get SIGPIPE */ BlockSignals(True,SIGPIPE); |