diff options
-rw-r--r-- | source4/lib/socket/socket.c | 8 | ||||
-rw-r--r-- | source4/lib/socket/socket.h | 4 | ||||
-rw-r--r-- | source4/lib/socket/socket_ipv4.c | 44 | ||||
-rw-r--r-- | source4/lib/socket/socket_ipv6.c | 39 | ||||
-rw-r--r-- | source4/lib/socket/socket_unix.c | 39 |
5 files changed, 104 insertions, 30 deletions
diff --git a/source4/lib/socket/socket.c b/source4/lib/socket/socket.c index 9be6faf084..cc43348e79 100644 --- a/source4/lib/socket/socket.c +++ b/source4/lib/socket/socket.c @@ -107,6 +107,14 @@ NTSTATUS socket_connect(struct socket_context *sock, return sock->ops->fn_connect(sock, my_address, my_port, server_address, server_port, flags); } +NTSTATUS socket_connect_complete(struct socket_context *sock, uint32_t flags) +{ + if (!sock->ops->fn_connect_complete) { + return NT_STATUS_NOT_IMPLEMENTED; + } + return sock->ops->fn_connect_complete(sock, flags); +} + NTSTATUS socket_listen(struct socket_context *sock, const char *my_address, int port, int queue_size, uint32_t flags) { if (sock->type != SOCKET_TYPE_STREAM) { diff --git a/source4/lib/socket/socket.h b/source4/lib/socket/socket.h index 7a8d335962..7dd8c0ae17 100644 --- a/source4/lib/socket/socket.h +++ b/source4/lib/socket/socket.h @@ -39,6 +39,10 @@ struct socket_ops { const char *server_address, int server_port, uint32_t flags); + /* complete a non-blocking connect */ + NTSTATUS (*fn_connect_complete)(struct socket_context *sock, + uint32_t flags); + /* server ops */ NTSTATUS (*fn_listen)(struct socket_context *sock, const char *my_address, int port, int queue_size, uint32_t flags); diff --git a/source4/lib/socket/socket_ipv4.c b/source4/lib/socket/socket_ipv4.c index 88570512a4..7cf2b73e4e 100644 --- a/source4/lib/socket/socket_ipv4.c +++ b/source4/lib/socket/socket_ipv4.c @@ -1,7 +1,10 @@ /* Unix SMB/CIFS implementation. + Socket IPv4 functions + Copyright (C) Stefan Metzmacher 2004 + Copyright (C) Andrew Tridgell 2004-2005 This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -36,6 +39,34 @@ static void ipv4_tcp_close(struct socket_context *sock) close(sock->fd); } +static NTSTATUS ipv4_tcp_connect_complete(struct socket_context *sock, uint32_t flags) +{ + int error=0, ret; + socklen_t len = sizeof(error); + + /* check for any errors that may have occurred - this is needed + for non-blocking connect */ + ret = getsockopt(sock->fd, SOL_SOCKET, SO_ERROR, &error, &len); + if (ret == -1) { + return map_nt_error_from_unix(errno); + } + if (error != 0) { + return map_nt_error_from_unix(error); + } + + if (!(flags & SOCKET_FLAG_BLOCK)) { + ret = set_blocking(sock->fd, False); + if (ret == -1) { + return map_nt_error_from_unix(errno); + } + } + + sock->state = SOCKET_STATE_CLIENT_CONNECTED; + + return NT_STATUS_OK; +} + + static NTSTATUS ipv4_tcp_connect(struct socket_context *sock, const char *my_address, int my_port, const char *srv_address, int srv_port, @@ -79,18 +110,10 @@ static NTSTATUS ipv4_tcp_connect(struct socket_context *sock, return map_nt_error_from_unix(errno); } - if (!(flags & SOCKET_FLAG_BLOCK)) { - ret = set_blocking(sock->fd, False); - if (ret == -1) { - return map_nt_error_from_unix(errno); - } - } - - sock->state = SOCKET_STATE_CLIENT_CONNECTED; - - return NT_STATUS_OK; + return ipv4_tcp_connect_complete(sock, flags); } + static NTSTATUS ipv4_tcp_listen(struct socket_context *sock, const char *my_address, int port, int queue_size, uint32_t flags) @@ -315,6 +338,7 @@ static const struct socket_ops ipv4_tcp_ops = { .fn_init = ipv4_tcp_init, .fn_connect = ipv4_tcp_connect, + .fn_connect_complete = ipv4_tcp_connect_complete, .fn_listen = ipv4_tcp_listen, .fn_accept = ipv4_tcp_accept, .fn_recv = ipv4_tcp_recv, diff --git a/source4/lib/socket/socket_ipv6.c b/source4/lib/socket/socket_ipv6.c index 1685f17572..35b4037ff4 100644 --- a/source4/lib/socket/socket_ipv6.c +++ b/source4/lib/socket/socket_ipv6.c @@ -50,6 +50,33 @@ static void ipv6_tcp_close(struct socket_context *sock) close(sock->fd); } +static NTSTATUS ipv6_tcp_connect_complete(struct socket_context *sock, uint32_t flags) +{ + int error=0, ret; + socklen_t len = sizeof(error); + + /* check for any errors that may have occurred - this is needed + for non-blocking connect */ + ret = getsockopt(sock->fd, SOL_SOCKET, SO_ERROR, &error, &len); + if (ret == -1) { + return map_nt_error_from_unix(errno); + } + if (error != 0) { + return map_nt_error_from_unix(error); + } + + if (!(flags & SOCKET_FLAG_BLOCK)) { + ret = set_blocking(sock->fd, False); + if (ret == -1) { + return map_nt_error_from_unix(errno); + } + } + + sock->state = SOCKET_STATE_CLIENT_CONNECTED; + + return NT_STATUS_OK; +} + static NTSTATUS ipv6_tcp_connect(struct socket_context *sock, const char *my_address, int my_port, const char *srv_address, int srv_port, @@ -87,16 +114,7 @@ static NTSTATUS ipv6_tcp_connect(struct socket_context *sock, return map_nt_error_from_unix(errno); } - if (!(flags & SOCKET_FLAG_BLOCK)) { - ret = set_blocking(sock->fd, False); - if (ret == -1) { - return map_nt_error_from_unix(errno); - } - } - - sock->state = SOCKET_STATE_CLIENT_CONNECTED; - - return NT_STATUS_OK; + return ipv6_tcp_connect_complete(sock, flags); } static NTSTATUS ipv6_tcp_listen(struct socket_context *sock, @@ -333,6 +351,7 @@ static const struct socket_ops ipv6_tcp_ops = { .fn_init = ipv6_tcp_init, .fn_connect = ipv6_tcp_connect, + .fn_connect_complete = ipv6_tcp_connect_complete, .fn_listen = ipv6_tcp_listen, .fn_accept = ipv6_tcp_accept, .fn_recv = ipv6_tcp_recv, diff --git a/source4/lib/socket/socket_unix.c b/source4/lib/socket/socket_unix.c index e35453e6e0..60a4b9ec48 100644 --- a/source4/lib/socket/socket_unix.c +++ b/source4/lib/socket/socket_unix.c @@ -50,6 +50,33 @@ static void unixdom_close(struct socket_context *sock) close(sock->fd); } +static NTSTATUS unixdom_connect_complete(struct socket_context *sock, uint32_t flags) +{ + int error=0, ret; + socklen_t len = sizeof(error); + + /* check for any errors that may have occurred - this is needed + for non-blocking connect */ + ret = getsockopt(sock->fd, SOL_SOCKET, SO_ERROR, &error, &len); + if (ret == -1) { + return map_nt_error_from_unix(errno); + } + if (error != 0) { + return map_nt_error_from_unix(error); + } + + if (!(flags & SOCKET_FLAG_BLOCK)) { + ret = set_blocking(sock->fd, False); + if (ret == -1) { + return map_nt_error_from_unix(errno); + } + } + + sock->state = SOCKET_STATE_CLIENT_CONNECTED; + + return NT_STATUS_OK; +} + static NTSTATUS unixdom_connect(struct socket_context *sock, const char *my_address, int my_port, const char *srv_address, int srv_port, @@ -66,21 +93,12 @@ static NTSTATUS unixdom_connect(struct socket_context *sock, srv_addr.sun_family = AF_UNIX; strncpy(srv_addr.sun_path, srv_address, sizeof(srv_addr.sun_path)); - if (!(flags & SOCKET_FLAG_BLOCK)) { - ret = set_blocking(sock->fd, False); - if (ret == -1) { - return NT_STATUS_INVALID_PARAMETER; - } - } - ret = connect(sock->fd, (const struct sockaddr *)&srv_addr, sizeof(srv_addr)); if (ret == -1) { return unixdom_error(errno); } - sock->state = SOCKET_STATE_CLIENT_CONNECTED; - - return NT_STATUS_OK; + return unixdom_connect_complete(sock, flags); } static NTSTATUS unixdom_listen(struct socket_context *sock, @@ -252,6 +270,7 @@ static const struct socket_ops unixdom_ops = { .fn_init = unixdom_init, .fn_connect = unixdom_connect, + .fn_connect_complete = unixdom_connect_complete, .fn_listen = unixdom_listen, .fn_accept = unixdom_accept, .fn_recv = unixdom_recv, |