diff options
-rw-r--r-- | source4/lib/tls/tls.c | 131 | ||||
-rw-r--r-- | source4/lib/tls/tls.h | 7 |
2 files changed, 126 insertions, 12 deletions
diff --git a/source4/lib/tls/tls.c b/source4/lib/tls/tls.c index 86a2ca0f0b..49f7b758c0 100644 --- a/source4/lib/tls/tls.c +++ b/source4/lib/tls/tls.c @@ -39,7 +39,6 @@ struct tls_params { /* hold per connection tls data */ struct tls_context { - struct tls_params *params; struct socket_context *socket; struct fd_event *fde; gnutls_session session; @@ -50,8 +49,19 @@ struct tls_context { BOOL tls_detect; const char *plain_chars; BOOL output_pending; + gnutls_certificate_credentials xcred; + BOOL interrupted; }; +#define TLSCHECK(call) do { \ + ret = call; \ + if (ret < 0) { \ + DEBUG(0,("TLS %s - %s\n", #call, gnutls_strerror(ret))); \ + goto failed; \ + } \ +} while (0) + + /* callback for reading from a socket @@ -80,7 +90,6 @@ static ssize_t tls_pull(gnutls_transport_ptr ptr, void *buf, size_t size) } if (!NT_STATUS_IS_OK(status)) { EVENT_FD_READABLE(tls->fde); - EVENT_FD_NOT_WRITEABLE(tls->fde); errno = EAGAIN; return -1; } @@ -153,6 +162,9 @@ static NTSTATUS tls_handshake(struct tls_context *tls) ret = gnutls_handshake(tls->session); if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) { + if (gnutls_record_get_direction(tls->session) == 1) { + EVENT_FD_WRITEABLE(tls->fde); + } return STATUS_MORE_ENTRIES; } if (ret < 0) { @@ -164,6 +176,28 @@ static NTSTATUS tls_handshake(struct tls_context *tls) } /* + possibly continue an interrupted operation +*/ +static NTSTATUS tls_interrupted(struct tls_context *tls) +{ + int ret; + + if (!tls->interrupted) { + return NT_STATUS_OK; + } + if (gnutls_record_get_direction(tls->session) == 1) { + ret = gnutls_record_send(tls->session, NULL, 0); + } else { + ret = gnutls_record_recv(tls->session, NULL, 0); + } + if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) { + return STATUS_MORE_ENTRIES; + } + tls->interrupted = False; + return NT_STATUS_OK; +} + +/* see how many bytes are pending on the connection */ NTSTATUS tls_socket_pending(struct tls_context *tls, size_t *npending) @@ -173,7 +207,12 @@ NTSTATUS tls_socket_pending(struct tls_context *tls, size_t *npending) } *npending = gnutls_record_check_pending(tls->session); if (*npending == 0) { - return socket_pending(tls->socket, npending); + NTSTATUS status = socket_pending(tls->socket, npending); + if (*npending == 0) { + /* seems to be a gnutls bug */ + (*npending) = 100; + } + return status; } return NT_STATUS_OK; } @@ -208,8 +247,15 @@ NTSTATUS tls_socket_recv(struct tls_context *tls, void *buf, size_t wantlen, status = tls_handshake(tls); NT_STATUS_NOT_OK_RETURN(status); + status = tls_interrupted(tls); + NT_STATUS_NOT_OK_RETURN(status); + ret = gnutls_record_recv(tls->session, buf, wantlen); if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) { + if (gnutls_record_get_direction(tls->session) == 1) { + EVENT_FD_WRITEABLE(tls->fde); + } + tls->interrupted = True; return STATUS_MORE_ENTRIES; } if (ret < 0) { @@ -235,8 +281,15 @@ NTSTATUS tls_socket_send(struct tls_context *tls, const DATA_BLOB *blob, size_t status = tls_handshake(tls); NT_STATUS_NOT_OK_RETURN(status); + status = tls_interrupted(tls); + NT_STATUS_NOT_OK_RETURN(status); + ret = gnutls_record_send(tls->session, blob->data, blob->length); if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) { + if (gnutls_record_get_direction(tls->session) == 1) { + EVENT_FD_WRITEABLE(tls->fde); + } + tls->interrupted = True; return STATUS_MORE_ENTRIES; } if (ret < 0) { @@ -317,6 +370,7 @@ struct tls_params *tls_initialise(TALLOC_CTX *mem_ctx) gnutls_certificate_set_dh_params(params->x509_cred, params->dh_params); params->tls_enabled = True; + return params; init_failed: @@ -349,14 +403,6 @@ struct tls_context *tls_init_server(struct tls_params *params, return tls; } -#define TLSCHECK(call) do { \ - ret = call; \ - if (ret < 0) { \ - DEBUG(0,("TLS %s - %s\n", #call, gnutls_strerror(ret))); \ - goto failed; \ - } \ -} while (0) - TLSCHECK(gnutls_init(&tls->session, GNUTLS_SERVER)); talloc_set_destructor(tls, tls_destructor); @@ -379,10 +425,10 @@ struct tls_context *tls_init_server(struct tls_params *params, } tls->output_pending = False; - tls->params = params; tls->done_handshake = False; tls->have_first_byte = False; tls->tls_enabled = True; + tls->interrupted = False; return tls; @@ -393,6 +439,60 @@ failed: return tls; } + +/* + setup for a new client connection +*/ +struct tls_context *tls_init_client(struct socket_context *socket, + struct fd_event *fde, + BOOL tls_enable) +{ + struct tls_context *tls; + int ret; + const int cert_type_priority[] = { GNUTLS_CRT_X509, GNUTLS_CRT_OPENPGP, 0 }; + tls = talloc(socket, struct tls_context); + if (tls == NULL) return NULL; + + tls->socket = socket; + tls->fde = fde; + tls->tls_enabled = tls_enable; + + if (!tls->tls_enabled) { + return tls; + } + + gnutls_global_init(); + + gnutls_certificate_allocate_credentials(&tls->xcred); + gnutls_certificate_set_x509_trust_file(tls->xcred, lp_tls_cafile(), + GNUTLS_X509_FMT_PEM); + TLSCHECK(gnutls_init(&tls->session, GNUTLS_CLIENT)); + TLSCHECK(gnutls_set_default_priority(tls->session)); + gnutls_certificate_type_set_priority(tls->session, cert_type_priority); + TLSCHECK(gnutls_credentials_set(tls->session, GNUTLS_CRD_CERTIFICATE, tls->xcred)); + + talloc_set_destructor(tls, tls_destructor); + + gnutls_transport_set_ptr(tls->session, (gnutls_transport_ptr)tls); + gnutls_transport_set_pull_function(tls->session, (gnutls_pull_func)tls_pull); + gnutls_transport_set_push_function(tls->session, (gnutls_push_func)tls_push); + gnutls_transport_set_lowat(tls->session, 0); + tls->tls_detect = False; + + tls->output_pending = False; + tls->done_handshake = False; + tls->have_first_byte = False; + tls->tls_enabled = True; + tls->interrupted = False; + + return tls; + +failed: + DEBUG(0,("TLS init connection failed - %s\n", gnutls_strerror(ret))); + tls->tls_enabled = False; + return tls; +} + BOOL tls_enabled(struct tls_context *tls) { return tls->tls_enabled; @@ -423,6 +523,13 @@ struct tls_context *tls_init_server(struct tls_params *params, return (struct tls_context *)sock; } +struct tls_context *tls_init_client(struct socket_context *sock, + struct fd_event *fde, + BOOL tls_enable) +{ + return (struct tls_context *)sock; +} + NTSTATUS tls_socket_recv(struct tls_context *tls, void *buf, size_t wantlen, size_t *nread) diff --git a/source4/lib/tls/tls.h b/source4/lib/tls/tls.h index 3046e35a1c..a046b91637 100644 --- a/source4/lib/tls/tls.h +++ b/source4/lib/tls/tls.h @@ -41,6 +41,13 @@ struct tls_context *tls_init_server(struct tls_params *parms, BOOL tls_enable); /* + call tls_init_client() on each new client connection +*/ +struct tls_context *tls_init_client(struct socket_context *sock, + struct fd_event *fde, + BOOL tls_enable); + +/* call these to send and receive data. They behave like socket_send() and socket_recv() */ NTSTATUS tls_socket_recv(struct tls_context *tls, void *buf, size_t wantlen, |