From 33a4c0b5a1052026193dcbb800e2bccb1b832730 Mon Sep 17 00:00:00 2001 From: Jeremy Allison Date: Mon, 13 Jun 2005 22:26:08 +0000 Subject: r7554: Refactor very messy code in util_sock.c Remove write_socket_data/read_socket_data as they do nothing that write_socket/read_socket don't do. Add a more useful error message when read_socket/write_socket error out on the main client fd for a process (ie. try and list the IP of the client that errored). Jeremy. (This used to be commit cbd7578e7c226e6a8002542141b914ed4c7a8269) --- source3/lib/util_sock.c | 228 ++++++++++++++++++++---------------------------- 1 file changed, 95 insertions(+), 133 deletions(-) (limited to 'source3/lib') diff --git a/source3/lib/util_sock.c b/source3/lib/util_sock.c index 6107e5abed..e5e16f1c48 100644 --- a/source3/lib/util_sock.c +++ b/source3/lib/util_sock.c @@ -3,6 +3,7 @@ Samba utility functions Copyright (C) Andrew Tridgell 1992-1998 Copyright (C) Tim Potter 2000-2001 + Copyright (C) Jeremy Allison 1992-2005 This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by @@ -21,13 +22,15 @@ #include "includes.h" -/* the last IP received from */ -struct in_addr lastip; - -/* the last port received from */ -int lastport=0; +/* the following 3 client_*() functions are nasty ways of allowing + some generic functions to get info that really should be hidden in + particular modules */ +static int client_fd = -1; -int smb_read_error = 0; +void client_setfd(int fd) +{ + client_fd = fd; +} static char *get_socket_addr(int fd) { @@ -69,6 +72,47 @@ static int get_socket_port(int fd) return ntohs(sockin->sin_port); } +char *client_name(void) +{ + return get_peer_name(client_fd,False); +} + +char *client_addr(void) +{ + return get_peer_addr(client_fd); +} + +char *client_socket_addr(void) +{ + return get_socket_addr(client_fd); +} + +int client_socket_port(void) +{ + return get_socket_port(client_fd); +} + +struct in_addr *client_inaddr(struct sockaddr *sa) +{ + struct sockaddr_in *sockin = (struct sockaddr_in *) (sa); + socklen_t length = sizeof(*sa); + + if (getpeername(client_fd, sa, &length) < 0) { + DEBUG(0,("getpeername failed. Error was %s\n", strerror(errno) )); + return NULL; + } + + return &sockin->sin_addr; +} + +/* the last IP received from */ +struct in_addr lastip; + +/* the last port received from */ +int lastport=0; + +int smb_read_error = 0; + /**************************************************************************** Determine if a file descriptor is in fact a socket. ****************************************************************************/ @@ -356,8 +400,10 @@ ssize_t read_socket_with_timeout(int fd,char *buf,size_t mincnt,size_t maxcnt,un smb_read_error = 0; /* Blocking read */ - if (time_out <= 0) { - if (mincnt == 0) mincnt = maxcnt; + if (time_out == 0) { + if (mincnt == 0) { + mincnt = maxcnt; + } while (nread < mincnt) { readret = sys_read(fd, buf + nread, maxcnt - nread); @@ -369,7 +415,13 @@ ssize_t read_socket_with_timeout(int fd,char *buf,size_t mincnt,size_t maxcnt,un } if (readret == -1) { - DEBUG(0,("read_socket_with_timeout: read error = %s.\n", strerror(errno) )); + if (fd == client_fd) { + /* Try and give an error message saying what client failed. */ + DEBUG(0,("read_socket_with_timeout: client %s read error = %s.\n", + client_addr(), strerror(errno) )); + } else { + DEBUG(0,("read_socket_with_timeout: read error = %s.\n", strerror(errno) )); + } smb_read_error = READ_ERROR; return -1; } @@ -397,7 +449,13 @@ ssize_t read_socket_with_timeout(int fd,char *buf,size_t mincnt,size_t maxcnt,un /* Check if error */ if (selrtn == -1) { /* something is wrong. Maybe the socket is dead? */ - DEBUG(0,("read_socket_with_timeout: timeout read. select error = %s.\n", strerror(errno) )); + if (fd == client_fd) { + /* Try and give an error message saying what client failed. */ + DEBUG(0,("read_socket_with_timeout: timeout read for client %s. select error = %s.\n", + client_addr(), strerror(errno) )); + } else { + DEBUG(0,("read_socket_with_timeout: timeout read. select error = %s.\n", strerror(errno) )); + } smb_read_error = READ_ERROR; return -1; } @@ -420,7 +478,13 @@ ssize_t read_socket_with_timeout(int fd,char *buf,size_t mincnt,size_t maxcnt,un if (readret == -1) { /* the descriptor is probably dead */ - DEBUG(0,("read_socket_with_timeout: timeout read. read error = %s.\n", strerror(errno) )); + if (fd == client_fd) { + /* Try and give an error message saying what client failed. */ + DEBUG(0,("read_socket_with_timeout: timeout read to client %s. read error = %s.\n", + client_addr(), strerror(errno) )); + } else { + DEBUG(0,("read_socket_with_timeout: timeout read. read error = %s.\n", strerror(errno) )); + } smb_read_error = READ_ERROR; return -1; } @@ -453,37 +517,13 @@ ssize_t read_data(int fd,char *buffer,size_t N) } if (ret == -1) { - DEBUG(0,("read_data: read failure for %d. Error = %s\n", (int)(N - total), strerror(errno) )); - smb_read_error = READ_ERROR; - return -1; - } - total += ret; - } - return (ssize_t)total; -} - -/**************************************************************************** - Read data from a socket, reading exactly N bytes. -****************************************************************************/ - -static ssize_t read_socket_data(int fd,char *buffer,size_t N) -{ - ssize_t ret; - size_t total=0; - - smb_read_error = 0; - - while (total < N) { - ret = sys_read(fd,buffer + total,N - total); - - if (ret == 0) { - DEBUG(10,("read_socket_data: recv of %d returned 0. Error = %s\n", (int)(N - total), strerror(errno) )); - smb_read_error = READ_EOF; - return 0; - } - - if (ret == -1) { - DEBUG(0,("read_socket_data: recv failure for %d. Error = %s\n", (int)(N - total), strerror(errno) )); + if (fd == client_fd) { + /* Try and give an error message saying what client failed. */ + DEBUG(0,("read_data: read failure for %d bytes to client %s. Error = %s\n", + (int)(N - total), client_addr(), strerror(errno) )); + } else { + DEBUG(0,("read_data: read failure for %d. Error = %s\n", (int)(N - total), strerror(errno) )); + } smb_read_error = READ_ERROR; return -1; } @@ -505,60 +545,25 @@ ssize_t write_data(int fd, const char *buffer, size_t N) ret = sys_write(fd,buffer + total,N - total); if (ret == -1) { - DEBUG(0,("write_data: write failure. Error = %s\n", strerror(errno) )); + if (fd == client_fd) { + /* Try and give an error message saying what client failed. */ + DEBUG(0,("write_data: write failure in writing to client %s. Error %s\n", + client_addr(), strerror(errno) )); + } else { + DEBUG(0,("write_data: write failure. Error = %s\n", strerror(errno) )); + } return -1; } - if (ret == 0) - return total; - - total += ret; - } - return (ssize_t)total; -} - -/**************************************************************************** - Write data to a socket - use send rather than write. -****************************************************************************/ - -static ssize_t write_socket_data(int fd, const char *buffer, size_t N) -{ - size_t total=0; - ssize_t ret; - - while (total < N) { - ret = sys_send(fd,buffer + total,N - total,0); - if (ret == -1) { - DEBUG(0,("write_socket_data: write failure. Error = %s\n", strerror(errno) )); - return -1; - } - if (ret == 0) + if (ret == 0) { return total; + } total += ret; } return (ssize_t)total; } -/**************************************************************************** - Write to a socket. -****************************************************************************/ - -ssize_t write_socket(int fd, const char *buf, size_t len) -{ - ssize_t ret=0; - - DEBUG(6,("write_socket(%d,%d)\n",fd,(int)len)); - ret = write_socket_data(fd,buf,len); - - DEBUG(6,("write_socket(%d,%d) wrote %d\n",fd,(int)len,(int)ret)); - if(ret <= 0) - DEBUG(0,("write_socket: Error writing %d bytes to socket %d: ERRNO = %s\n", - (int)len, fd, strerror(errno) )); - - return(ret); -} - /**************************************************************************** Send a keepalive packet (rfc1002). ****************************************************************************/ @@ -570,7 +575,7 @@ BOOL send_keepalive(int client) buf[0] = SMBkeepalive; buf[1] = buf[2] = buf[3] = 0; - return(write_socket_data(client,(char *)buf,4) == 4); + return(write_data(client,(char *)buf,4) == 4); } @@ -592,7 +597,7 @@ static ssize_t read_smb_length_return_keepalive(int fd, char *inbuf, unsigned in if (timeout > 0) ok = (read_socket_with_timeout(fd,inbuf,4,4,timeout) == 4); else - ok = (read_socket_data(fd,inbuf,4) == 4); + ok = (read_data(fd,inbuf,4) == 4); if (!ok) return(-1); @@ -693,7 +698,7 @@ BOOL receive_smb_raw(int fd, char *buffer, unsigned int timeout) if (timeout > 0) { ret = read_socket_with_timeout(fd,buffer+4,len,len,timeout); } else { - ret = read_socket_data(fd,buffer+4,len); + ret = read_data(fd,buffer+4,len); } if (ret != len) { @@ -748,7 +753,7 @@ BOOL send_smb(int fd, char *buffer) len = smb_len(buffer) + 4; while (nwritten < len) { - ret = write_socket(fd,buffer+nwritten,len - nwritten); + ret = write_data(fd,buffer+nwritten,len - nwritten); if (ret <= 0) { DEBUG(0,("Error writing %d bytes to client. %d. (%s)\n", (int)len,(int)ret, strerror(errno) )); @@ -1086,49 +1091,6 @@ int open_udp_socket(const char *host, int port) } -/* the following 3 client_*() functions are nasty ways of allowing - some generic functions to get info that really should be hidden in - particular modules */ -static int client_fd = -1; - -void client_setfd(int fd) -{ - client_fd = fd; -} - -char *client_name(void) -{ - return get_peer_name(client_fd,False); -} - -char *client_addr(void) -{ - return get_peer_addr(client_fd); -} - -char *client_socket_addr(void) -{ - return get_socket_addr(client_fd); -} - -int client_socket_port(void) -{ - return get_socket_port(client_fd); -} - -struct in_addr *client_inaddr(struct sockaddr *sa) -{ - struct sockaddr_in *sockin = (struct sockaddr_in *) (sa); - socklen_t length = sizeof(*sa); - - if (getpeername(client_fd, sa, &length) < 0) { - DEBUG(0,("getpeername failed. Error was %s\n", strerror(errno) )); - return NULL; - } - - return &sockin->sin_addr; -} - /******************************************************************* Matchname - determine if host name matches IP address. Used to confirm a hostname lookup to prevent spoof attacks. -- cgit