diff options
Diffstat (limited to 'source4/lib/tls/tls_tstream.c')
-rw-r--r-- | source4/lib/tls/tls_tstream.c | 1249 |
1 files changed, 1249 insertions, 0 deletions
diff --git a/source4/lib/tls/tls_tstream.c b/source4/lib/tls/tls_tstream.c new file mode 100644 index 0000000000..96e6f6b998 --- /dev/null +++ b/source4/lib/tls/tls_tstream.c @@ -0,0 +1,1249 @@ +/* + Unix SMB/CIFS implementation. + + Copyright (C) Stefan Metzmacher 2010 + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +#include "includes.h" +#include "system/network.h" +#include "../util/tevent_unix.h" +#include "../lib/tsocket/tsocket.h" +#include "../lib/tsocket/tsocket_internal.h" +#include "lib/tls/tls.h" + +#if ENABLE_GNUTLS +#include "gnutls/gnutls.h" + +#define DH_BITS 1024 + +#if defined(HAVE_GNUTLS_DATUM) && !defined(HAVE_GNUTLS_DATUM_T) +typedef gnutls_datum gnutls_datum_t; +#endif + +#endif /* ENABLE_GNUTLS */ + +static const struct tstream_context_ops tstream_tls_ops; + +struct tstream_tls { + struct tstream_context *plain_stream; + int error; + +#if ENABLE_GNUTLS + gnutls_session tls_session; +#endif /* ENABLE_GNUTLS */ + + struct tevent_context *current_ev; + + struct tevent_immediate *im; + + struct { + uint8_t buffer[1024]; + struct iovec iov; + struct tevent_req *subreq; + } push, pull; + + struct { + struct tevent_req *req; + } handshake; + + struct { + off_t ofs; + size_t left; + uint8_t buffer[1024]; + struct tevent_req *req; + } write; + + struct { + off_t ofs; + size_t left; + uint8_t buffer[1024]; + struct tevent_req *req; + } read; + + struct { + struct tevent_req *req; + } disconnect; +}; + +static void tstream_tls_retry_handshake(struct tstream_context *stream); +static void tstream_tls_retry_read(struct tstream_context *stream); +static void tstream_tls_retry_write(struct tstream_context *stream); +static void tstream_tls_retry_disconnect(struct tstream_context *stream); +static void tstream_tls_retry_trigger(struct tevent_context *ctx, + struct tevent_immediate *im, + void *private_data); + +static void tstream_tls_retry(struct tstream_context *stream, bool deferred) +{ + + struct tstream_tls *tlss = + tstream_context_data(stream, + struct tstream_tls); + + if (tlss->disconnect.req) { + tstream_tls_retry_disconnect(stream); + return; + } + + if (tlss->handshake.req) { + tstream_tls_retry_handshake(stream); + return; + } + + if (tlss->write.req && tlss->read.req && !deferred) { + tevent_schedule_immediate(tlss->im, tlss->current_ev, + tstream_tls_retry_trigger, + stream); + } + + if (tlss->write.req) { + tstream_tls_retry_write(stream); + return; + } + + if (tlss->read.req) { + tstream_tls_retry_read(stream); + return; + } +} + +static void tstream_tls_retry_trigger(struct tevent_context *ctx, + struct tevent_immediate *im, + void *private_data) +{ + struct tstream_context *stream = + talloc_get_type_abort(private_data, + struct tstream_context); + + tstream_tls_retry(stream, true); +} + +#if ENABLE_GNUTLS +static void tstream_tls_push_done(struct tevent_req *subreq); + +static ssize_t tstream_tls_push_function(gnutls_transport_ptr ptr, + const void *buf, size_t size) +{ + struct tstream_context *stream = + talloc_get_type_abort(ptr, + struct tstream_context); + struct tstream_tls *tlss = + tstream_context_data(stream, + struct tstream_tls); + struct tevent_req *subreq; + + if (tlss->error != 0) { + errno = tlss->error; + return -1; + } + + if (tlss->push.subreq) { + errno = EAGAIN; + return -1; + } + + tlss->push.iov.iov_base = tlss->push.buffer; + tlss->push.iov.iov_len = MIN(size, sizeof(tlss->push.buffer)); + + memcpy(tlss->push.buffer, buf, tlss->push.iov.iov_len); + + subreq = tstream_writev_send(tlss, + tlss->current_ev, + tlss->plain_stream, + &tlss->push.iov, 1); + if (subreq == NULL) { + errno = ENOMEM; + return -1; + } + tevent_req_set_callback(subreq, tstream_tls_push_done, stream); + + tlss->push.subreq = subreq; + + return tlss->push.iov.iov_len; +} + +static void tstream_tls_push_done(struct tevent_req *subreq) +{ + struct tstream_context *stream = + tevent_req_callback_data(subreq, + struct tstream_context); + struct tstream_tls *tlss = + tstream_context_data(stream, + struct tstream_tls); + int ret; + int sys_errno; + + tlss->push.subreq = NULL; + ZERO_STRUCT(tlss->push.iov); + + ret = tstream_writev_recv(subreq, &sys_errno); + TALLOC_FREE(subreq); + if (ret == -1) { + tlss->error = sys_errno; + tstream_tls_retry(stream, false); + return; + } + + tstream_tls_retry(stream, false); +} + +static void tstream_tls_pull_done(struct tevent_req *subreq); + +static ssize_t tstream_tls_pull_function(gnutls_transport_ptr ptr, + void *buf, size_t size) +{ + struct tstream_context *stream = + talloc_get_type_abort(ptr, + struct tstream_context); + struct tstream_tls *tlss = + tstream_context_data(stream, + struct tstream_tls); + struct tevent_req *subreq; + + if (tlss->error != 0) { + errno = tlss->error; + return -1; + } + + if (tlss->pull.subreq) { + errno = EAGAIN; + return -1; + } + + if (tlss->pull.iov.iov_base) { + size_t n; + + n = MIN(tlss->pull.iov.iov_len, size); + memcpy(buf, tlss->pull.iov.iov_base, n); + + tlss->pull.iov.iov_len -= n; + if (tlss->pull.iov.iov_len == 0) { + tlss->pull.iov.iov_base = NULL; + } + + return n; + } + + if (size == 0) { + return 0; + } + + tlss->pull.iov.iov_base = tlss->pull.buffer; + tlss->pull.iov.iov_len = MIN(size, sizeof(tlss->pull.buffer)); + + subreq = tstream_readv_send(tlss, + tlss->current_ev, + tlss->plain_stream, + &tlss->pull.iov, 1); + if (subreq == NULL) { + errno = ENOMEM; + return -1; + } + tevent_req_set_callback(subreq, tstream_tls_pull_done, stream); + + tlss->pull.subreq = subreq; + errno = EAGAIN; + return -1; +} + +static void tstream_tls_pull_done(struct tevent_req *subreq) +{ + struct tstream_context *stream = + tevent_req_callback_data(subreq, + struct tstream_context); + struct tstream_tls *tlss = + tstream_context_data(stream, + struct tstream_tls); + int ret; + int sys_errno; + + tlss->pull.subreq = NULL; + + ret = tstream_readv_recv(subreq, &sys_errno); + TALLOC_FREE(subreq); + if (ret == -1) { + tlss->error = sys_errno; + tstream_tls_retry(stream, false); + return; + } + + tstream_tls_retry(stream, false); +} +#endif /* ENABLE_GNUTLS */ + +static int tstream_tls_destructor(struct tstream_tls *tlss) +{ +#if ENABLE_GNUTLS + if (tlss->tls_session) { + gnutls_deinit(tlss->tls_session); + tlss->tls_session = NULL; + } +#endif /* ENABLE_GNUTLS */ + return 0; +} + +static ssize_t tstream_tls_pending_bytes(struct tstream_context *stream) +{ + struct tstream_tls *tlss = + tstream_context_data(stream, + struct tstream_tls); + size_t ret; + + if (tlss->error != 0) { + errno = tlss->error; + return -1; + } + +#if ENABLE_GNUTLS + ret = gnutls_record_check_pending(tlss->tls_session); + ret += tlss->read.left; +#else /* ENABLE_GNUTLS */ + errno = ENOSYS; + ret = -1; +#endif /* ENABLE_GNUTLS */ + return ret; +} + +struct tstream_tls_readv_state { + struct tstream_context *stream; + + struct iovec *vector; + int count; + + int ret; +}; + +static void tstream_tls_readv_crypt_next(struct tevent_req *req); + +static struct tevent_req *tstream_tls_readv_send(TALLOC_CTX *mem_ctx, + struct tevent_context *ev, + struct tstream_context *stream, + struct iovec *vector, + size_t count) +{ + struct tstream_tls *tlss = + tstream_context_data(stream, + struct tstream_tls); + struct tevent_req *req; + struct tstream_tls_readv_state *state; + + tlss->read.req = NULL; + tlss->current_ev = ev; + + req = tevent_req_create(mem_ctx, &state, + struct tstream_tls_readv_state); + if (req == NULL) { + return NULL; + } + + state->stream = stream; + state->ret = 0; + + if (tlss->error != 0) { + tevent_req_error(req, tlss->error); + return tevent_req_post(req, ev); + } + + /* + * we make a copy of the vector so we can change the structure + */ + state->vector = talloc_array(state, struct iovec, count); + if (tevent_req_nomem(state->vector, req)) { + return tevent_req_post(req, ev); + } + memcpy(state->vector, vector, sizeof(struct iovec) * count); + state->count = count; + + tstream_tls_readv_crypt_next(req); + if (!tevent_req_is_in_progress(req)) { + return tevent_req_post(req, ev); + } + + return req; +} + +static void tstream_tls_readv_crypt_next(struct tevent_req *req) +{ + struct tstream_tls_readv_state *state = + tevent_req_data(req, + struct tstream_tls_readv_state); + struct tstream_tls *tlss = + tstream_context_data(state->stream, + struct tstream_tls); + + /* + * copy the pending buffer first + */ + while (tlss->read.left > 0 && state->count > 0) { + uint8_t *base = (uint8_t *)state->vector[0].iov_base; + size_t len = MIN(tlss->read.left, state->vector[0].iov_len); + + memcpy(base, tlss->read.buffer + tlss->read.ofs, len); + + base += len; + state->vector[0].iov_base = base; + state->vector[0].iov_len -= len; + + tlss->read.ofs += len; + tlss->read.left -= len; + + if (state->vector[0].iov_len == 0) { + state->vector += 1; + state->count -= 1; + } + + state->ret += len; + } + + if (state->count == 0) { + tevent_req_done(req); + return; + } + + tlss->read.req = req; + tstream_tls_retry_read(state->stream); +} + +static void tstream_tls_retry_read(struct tstream_context *stream) +{ + struct tstream_tls *tlss = + tstream_context_data(stream, + struct tstream_tls); + struct tevent_req *req = tlss->read.req; +#if ENABLE_GNUTLS + int ret; + + if (tlss->error != 0) { + tevent_req_error(req, tlss->error); + return; + } + + tlss->read.left = 0; + tlss->read.ofs = 0; + + ret = gnutls_record_recv(tlss->tls_session, + tlss->read.buffer, + sizeof(tlss->read.buffer)); + if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) { + return; + } + + tlss->read.req = NULL; + + if (gnutls_error_is_fatal(ret) != 0) { + DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret))); + tlss->error = EIO; + tevent_req_error(req, tlss->error); + return; + } + + if (ret == 0) { + tlss->error = EPIPE; + tevent_req_error(req, tlss->error); + return; + } + + tlss->read.left = ret; + tstream_tls_readv_crypt_next(req); +#else /* ENABLE_GNUTLS */ + tevent_req_error(req, ENOSYS); +#endif /* ENABLE_GNUTLS */ +} + +static int tstream_tls_readv_recv(struct tevent_req *req, + int *perrno) +{ + struct tstream_tls_readv_state *state = + tevent_req_data(req, + struct tstream_tls_readv_state); + struct tstream_tls *tlss = + tstream_context_data(state->stream, + struct tstream_tls); + int ret; + + tlss->read.req = NULL; + + ret = tsocket_simple_int_recv(req, perrno); + if (ret == 0) { + ret = state->ret; + } + + tevent_req_received(req); + return ret; +} + +struct tstream_tls_writev_state { + struct tstream_context *stream; + + struct iovec *vector; + int count; + + int ret; +}; + +static void tstream_tls_writev_crypt_next(struct tevent_req *req); + +static struct tevent_req *tstream_tls_writev_send(TALLOC_CTX *mem_ctx, + struct tevent_context *ev, + struct tstream_context *stream, + const struct iovec *vector, + size_t count) +{ + struct tstream_tls *tlss = + tstream_context_data(stream, + struct tstream_tls); + struct tevent_req *req; + struct tstream_tls_writev_state *state; + + tlss->write.req = NULL; + tlss->current_ev = ev; + + req = tevent_req_create(mem_ctx, &state, + struct tstream_tls_writev_state); + if (req == NULL) { + return NULL; + } + + state->stream = stream; + state->ret = 0; + + if (tlss->error != 0) { + tevent_req_error(req, tlss->error); + return tevent_req_post(req, ev); + } + + /* + * we make a copy of the vector so we can change the structure + */ + state->vector = talloc_array(state, struct iovec, count); + if (tevent_req_nomem(state->vector, req)) { + return tevent_req_post(req, ev); + } + memcpy(state->vector, vector, sizeof(struct iovec) * count); + state->count = count; + + tstream_tls_writev_crypt_next(req); + if (!tevent_req_is_in_progress(req)) { + return tevent_req_post(req, ev); + } + + return req; +} + +static void tstream_tls_writev_crypt_next(struct tevent_req *req) +{ + struct tstream_tls_writev_state *state = + tevent_req_data(req, + struct tstream_tls_writev_state); + struct tstream_tls *tlss = + tstream_context_data(state->stream, + struct tstream_tls); + + tlss->write.left = sizeof(tlss->write.buffer); + tlss->write.ofs = 0; + + /* + * first fill our buffer + */ + while (tlss->write.left > 0 && state->count > 0) { + uint8_t *base = (uint8_t *)state->vector[0].iov_base; + size_t len = MIN(tlss->write.left, state->vector[0].iov_len); + + memcpy(tlss->write.buffer + tlss->write.ofs, base, len); + + base += len; + state->vector[0].iov_base = base; + state->vector[0].iov_len -= len; + + tlss->write.ofs += len; + tlss->write.left -= len; + + if (state->vector[0].iov_len == 0) { + state->vector += 1; + state->count -= 1; + } + + state->ret += len; + } + + if (tlss->write.ofs == 0) { + tevent_req_done(req); + return; + } + + tlss->write.left = tlss->write.ofs; + tlss->write.ofs = 0; + + tlss->write.req = req; + tstream_tls_retry_write(state->stream); +} + +static void tstream_tls_retry_write(struct tstream_context *stream) +{ + struct tstream_tls *tlss = + tstream_context_data(stream, + struct tstream_tls); + struct tevent_req *req = tlss->write.req; +#if ENABLE_GNUTLS + int ret; + + if (tlss->error != 0) { + tevent_req_error(req, tlss->error); + return; + } + + ret = gnutls_record_send(tlss->tls_session, + tlss->write.buffer + tlss->write.ofs, + tlss->write.left); + if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) { + return; + } + + tlss->write.req = NULL; + + if (gnutls_error_is_fatal(ret) != 0) { + DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret))); + tlss->error = EIO; + tevent_req_error(req, tlss->error); + return; + } + + if (ret == 0) { + tlss->error = EPIPE; + tevent_req_error(req, tlss->error); + return; + } + + tlss->write.ofs += ret; + tlss->write.left -= ret; + + if (tlss->write.left > 0) { + tlss->write.req = req; + tstream_tls_retry_write(stream); + return; + } + + tstream_tls_writev_crypt_next(req); +#else /* ENABLE_GNUTLS */ + tevent_req_error(req, ENOSYS); +#endif /* ENABLE_GNUTLS */ +} + +static int tstream_tls_writev_recv(struct tevent_req *req, + int *perrno) +{ + struct tstream_tls_writev_state *state = + tevent_req_data(req, + struct tstream_tls_writev_state); + struct tstream_tls *tlss = + tstream_context_data(state->stream, + struct tstream_tls); + int ret; + + tlss->write.req = NULL; + + ret = tsocket_simple_int_recv(req, perrno); + if (ret == 0) { + ret = state->ret; + } + + tevent_req_received(req); + return ret; +} + +struct tstream_tls_disconnect_state { + uint8_t _dummy; +}; + +static struct tevent_req *tstream_tls_disconnect_send(TALLOC_CTX *mem_ctx, + struct tevent_context *ev, + struct tstream_context *stream) +{ + struct tstream_tls *tlss = + tstream_context_data(stream, + struct tstream_tls); + struct tevent_req *req; + struct tstream_tls_disconnect_state *state; + + tlss->disconnect.req = NULL; + tlss->current_ev = ev; + + req = tevent_req_create(mem_ctx, &state, + struct tstream_tls_disconnect_state); + if (req == NULL) { + return NULL; + } + + if (tlss->error != 0) { + tevent_req_error(req, tlss->error); + return tevent_req_post(req, ev); + } + + tlss->disconnect.req = req; + tstream_tls_retry_disconnect(stream); + if (!tevent_req_is_in_progress(req)) { + return tevent_req_post(req, ev); + } + + return req; +} + +static void tstream_tls_retry_disconnect(struct tstream_context *stream) +{ + struct tstream_tls *tlss = + tstream_context_data(stream, + struct tstream_tls); + struct tevent_req *req = tlss->disconnect.req; +#if ENABLE_GNUTLS + int ret; + + if (tlss->error != 0) { + tevent_req_error(req, tlss->error); + return; + } + + ret = gnutls_bye(tlss->tls_session, GNUTLS_SHUT_WR); + if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) { + return; + } + + tlss->disconnect.req = NULL; + + if (gnutls_error_is_fatal(ret) != 0) { + DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret))); + tlss->error = EIO; + tevent_req_error(req, tlss->error); + return; + } + + if (ret != GNUTLS_E_SUCCESS) { + DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret))); + tlss->error = EIO; + tevent_req_error(req, tlss->error); + return; + } + + tevent_req_done(req); +#else /* ENABLE_GNUTLS */ + tevent_req_error(req, ENOSYS); +#endif /* ENABLE_GNUTLS */ +} + +static int tstream_tls_disconnect_recv(struct tevent_req *req, + int *perrno) +{ + int ret; + + ret = tsocket_simple_int_recv(req, perrno); + + tevent_req_received(req); + return ret; +} + +static const struct tstream_context_ops tstream_tls_ops = { + .name = "tls", + + .pending_bytes = tstream_tls_pending_bytes, + + .readv_send = tstream_tls_readv_send, + .readv_recv = tstream_tls_readv_recv, + + .writev_send = tstream_tls_writev_send, + .writev_recv = tstream_tls_writev_recv, + + .disconnect_send = tstream_tls_disconnect_send, + .disconnect_recv = tstream_tls_disconnect_recv, +}; + +struct tstream_tls_params { +#if ENABLE_GNUTLS + gnutls_certificate_credentials x509_cred; + gnutls_dh_params dh_params; +#endif /* ENABLE_GNUTLS */ + bool tls_enabled; +}; + +static int tstream_tls_params_destructor(struct tstream_tls_params *tlsp) +{ +#if ENABLE_GNUTLS + if (tlsp->x509_cred) { + gnutls_certificate_free_credentials(tlsp->x509_cred); + tlsp->x509_cred = NULL; + } + if (tlsp->dh_params) { + gnutls_dh_params_deinit(tlsp->dh_params); + tlsp->dh_params = NULL; + } +#endif /* ENABLE_GNUTLS */ + return 0; +} + +bool tstream_tls_params_enabled(struct tstream_tls_params *tlsp) +{ + return tlsp->tls_enabled; +} + +NTSTATUS tstream_tls_params_client(TALLOC_CTX *mem_ctx, + const char *ca_file, + const char *crl_file, + struct tstream_tls_params **_tlsp) +{ +#if ENABLE_GNUTLS + struct tstream_tls_params *tlsp; + int ret; + + ret = gnutls_global_init(); + if (ret != GNUTLS_E_SUCCESS) { + DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret))); + return NT_STATUS_NOT_SUPPORTED; + } + + tlsp = talloc_zero(mem_ctx, struct tstream_tls_params); + NT_STATUS_HAVE_NO_MEMORY(tlsp); + + talloc_set_destructor(tlsp, tstream_tls_params_destructor); + + ret = gnutls_certificate_allocate_credentials(&tlsp->x509_cred); + if (ret != GNUTLS_E_SUCCESS) { + DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret))); + talloc_free(tlsp); + return NT_STATUS_NO_MEMORY; + } + + if (ca_file && *ca_file) { + ret = gnutls_certificate_set_x509_trust_file(tlsp->x509_cred, + ca_file, + GNUTLS_X509_FMT_PEM); + if (ret < 0) { + DEBUG(0,("TLS failed to initialise cafile %s - %s\n", + ca_file, gnutls_strerror(ret))); + talloc_free(tlsp); + return NT_STATUS_CANT_ACCESS_DOMAIN_INFO; + } + } + + if (crl_file && *crl_file) { + ret = gnutls_certificate_set_x509_crl_file(tlsp->x509_cred, + crl_file, + GNUTLS_X509_FMT_PEM); + if (ret < 0) { + DEBUG(0,("TLS failed to initialise crlfile %s - %s\n", + crl_file, gnutls_strerror(ret))); + talloc_free(tlsp); + return NT_STATUS_CANT_ACCESS_DOMAIN_INFO; + } + } + + tlsp->tls_enabled = true; + + *_tlsp = tlsp; + return NT_STATUS_OK; +#else /* ENABLE_GNUTLS */ + return NT_STATUS_NOT_IMPLEMENTED; +#endif /* ENABLE_GNUTLS */ +} + +struct tstream_tls_connect_state { + struct tstream_context *tls_stream; +}; + +struct tevent_req *_tstream_tls_connect_send(TALLOC_CTX *mem_ctx, + struct tevent_context *ev, + struct tstream_context *plain_stream, + struct tstream_tls_params *tls_params, + const char *location) +{ + struct tevent_req *req; + struct tstream_tls_connect_state *state; +#if ENABLE_GNUTLS + struct tstream_tls *tlss; + int ret; + static const int cert_type_priority[] = { + GNUTLS_CRT_X509, + GNUTLS_CRT_OPENPGP, + 0 + }; +#endif /* ENABLE_GNUTLS */ + + req = tevent_req_create(mem_ctx, &state, + struct tstream_tls_connect_state); + if (req == NULL) { + return NULL; + } + +#if ENABLE_GNUTLS + state->tls_stream = tstream_context_create(state, + &tstream_tls_ops, + &tlss, + struct tstream_tls, + location); + if (tevent_req_nomem(state->tls_stream, req)) { + return tevent_req_post(req, ev); + } + ZERO_STRUCTP(tlss); + talloc_set_destructor(tlss, tstream_tls_destructor); + + tlss->plain_stream = plain_stream; + + tlss->current_ev = ev; + tlss->im = tevent_create_immediate(tlss); + if (tevent_req_nomem(tlss->im, req)) { + return tevent_req_post(req, ev); + } + + ret = gnutls_init(&tlss->tls_session, GNUTLS_CLIENT); + if (ret != GNUTLS_E_SUCCESS) { + DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret))); + tevent_req_error(req, EINVAL); + return tevent_req_post(req, ev); + } + + ret = gnutls_set_default_priority(tlss->tls_session); + if (ret != GNUTLS_E_SUCCESS) { + DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret))); + tevent_req_error(req, EINVAL); + return tevent_req_post(req, ev); + } + + gnutls_certificate_type_set_priority(tlss->tls_session, cert_type_priority); + + ret = gnutls_credentials_set(tlss->tls_session, + GNUTLS_CRD_CERTIFICATE, + tls_params->x509_cred); + if (ret != GNUTLS_E_SUCCESS) { + DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret))); + tevent_req_error(req, EINVAL); + return tevent_req_post(req, ev); + } + + gnutls_transport_set_ptr(tlss->tls_session, (gnutls_transport_ptr)state->tls_stream); + gnutls_transport_set_pull_function(tlss->tls_session, + (gnutls_pull_func)tstream_tls_pull_function); + gnutls_transport_set_push_function(tlss->tls_session, + (gnutls_push_func)tstream_tls_push_function); + gnutls_transport_set_lowat(tlss->tls_session, 0); + + tlss->handshake.req = req; + tstream_tls_retry_handshake(state->tls_stream); + if (!tevent_req_is_in_progress(req)) { + return tevent_req_post(req, ev); + } + + return req; +#else /* ENABLE_GNUTLS */ + tevent_req_error(req, ENOSYS); + return tevent_req_post(req, ev); +#endif /* ENABLE_GNUTLS */ +} + +int tstream_tls_connect_recv(struct tevent_req *req, + int *perrno, + TALLOC_CTX *mem_ctx, + struct tstream_context **tls_stream) +{ + struct tstream_tls_connect_state *state = + tevent_req_data(req, + struct tstream_tls_connect_state); + + if (tevent_req_is_unix_error(req, perrno)) { + tevent_req_received(req); + return -1; + } + + *tls_stream = talloc_move(mem_ctx, &state->tls_stream); + tevent_req_received(req); + return 0; +} + +extern void tls_cert_generate(TALLOC_CTX *, const char *, const char *, const char *, const char *); + +/* + initialise global tls state +*/ +NTSTATUS tstream_tls_params_server(TALLOC_CTX *mem_ctx, + const char *dns_host_name, + bool disable, + const char *key_file, + const char *cert_file, + const char *ca_file, + const char *crl_file, + const char *dhp_file, + struct tstream_tls_params **_tlsp) +{ + struct tstream_tls_params *tlsp; +#if ENABLE_GNUTLS + int ret; + + ret = gnutls_global_init(); + if (ret != GNUTLS_E_SUCCESS) { + DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret))); + return NT_STATUS_NOT_SUPPORTED; + } + + tlsp = talloc_zero(mem_ctx, struct tstream_tls_params); + NT_STATUS_HAVE_NO_MEMORY(tlsp); + + talloc_set_destructor(tlsp, tstream_tls_params_destructor); + + if (!file_exist(ca_file)) { + tls_cert_generate(tlsp, dns_host_name, + key_file, cert_file, ca_file); + } + + ret = gnutls_certificate_allocate_credentials(&tlsp->x509_cred); + if (ret != GNUTLS_E_SUCCESS) { + DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret))); + talloc_free(tlsp); + return NT_STATUS_NO_MEMORY; + } + + if (ca_file && *ca_file) { + ret = gnutls_certificate_set_x509_trust_file(tlsp->x509_cred, + ca_file, + GNUTLS_X509_FMT_PEM); + if (ret < 0) { + DEBUG(0,("TLS failed to initialise cafile %s - %s\n", + ca_file, gnutls_strerror(ret))); + talloc_free(tlsp); + return NT_STATUS_CANT_ACCESS_DOMAIN_INFO; + } + } + + if (crl_file && *crl_file) { + ret = gnutls_certificate_set_x509_crl_file(tlsp->x509_cred, + crl_file, + GNUTLS_X509_FMT_PEM); + if (ret < 0) { + DEBUG(0,("TLS failed to initialise crlfile %s - %s\n", + crl_file, gnutls_strerror(ret))); + talloc_free(tlsp); + return NT_STATUS_CANT_ACCESS_DOMAIN_INFO; + } + } + + ret = gnutls_certificate_set_x509_key_file(tlsp->x509_cred, + cert_file, key_file, + GNUTLS_X509_FMT_PEM); + if (ret != GNUTLS_E_SUCCESS) { + DEBUG(0,("TLS failed to initialise certfile %s and keyfile %s - %s\n", + cert_file, key_file, gnutls_strerror(ret))); + talloc_free(tlsp); + return NT_STATUS_CANT_ACCESS_DOMAIN_INFO; + } + + ret = gnutls_dh_params_init(&tlsp->dh_params); + if (ret != GNUTLS_E_SUCCESS) { + DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret))); + talloc_free(tlsp); + return NT_STATUS_NO_MEMORY; + } + + if (dhp_file && *dhp_file) { + gnutls_datum_t dhparms; + size_t size; + + dhparms.data = (uint8_t *)file_load(dhp_file, &size, 0, tlsp); + + if (!dhparms.data) { + DEBUG(0,("TLS failed to read DH Parms from %s - %d:%s\n", + dhp_file, errno, strerror(errno))); + talloc_free(tlsp); + return NT_STATUS_CANT_ACCESS_DOMAIN_INFO; + } + dhparms.size = size; + + ret = gnutls_dh_params_import_pkcs3(tlsp->dh_params, + &dhparms, + GNUTLS_X509_FMT_PEM); + if (ret != GNUTLS_E_SUCCESS) { + DEBUG(0,("TLS failed to import pkcs3 %s - %s\n", + dhp_file, gnutls_strerror(ret))); + talloc_free(tlsp); + return NT_STATUS_CANT_ACCESS_DOMAIN_INFO; + } + } else { + ret = gnutls_dh_params_generate2(tlsp->dh_params, DH_BITS); + if (ret != GNUTLS_E_SUCCESS) { + DEBUG(0,("TLS failed to generate dh_params - %s\n", + gnutls_strerror(ret))); + talloc_free(tlsp); + return NT_STATUS_INTERNAL_ERROR; + } + } + + gnutls_certificate_set_dh_params(tlsp->x509_cred, tlsp->dh_params); + + tlsp->tls_enabled = true; + +#else /* ENABLE_GNUTLS */ + tlsp = talloc_zero(mem_ctx, struct tstream_tls_params); + NT_STATUS_HAVE_NO_MEMORY(tlsp); + talloc_set_destructor(tlsp, tstream_tls_params_destructor); + tlsp->tls_enabled = false; +#endif /* ENABLE_GNUTLS */ + + *_tlsp = tlsp; + return NT_STATUS_OK; +} + +struct tstream_tls_accept_state { + struct tstream_context *tls_stream; +}; + +struct tevent_req *_tstream_tls_accept_send(TALLOC_CTX *mem_ctx, + struct tevent_context *ev, + struct tstream_context *plain_stream, + struct tstream_tls_params *tlsp, + const char *location) +{ + struct tevent_req *req; + struct tstream_tls_accept_state *state; + struct tstream_tls *tlss; +#if ENABLE_GNUTLS + int ret; +#endif /* ENABLE_GNUTLS */ + + req = tevent_req_create(mem_ctx, &state, + struct tstream_tls_accept_state); + if (req == NULL) { + return NULL; + } + + state->tls_stream = tstream_context_create(state, + &tstream_tls_ops, + &tlss, + struct tstream_tls, + location); + if (tevent_req_nomem(state->tls_stream, req)) { + return tevent_req_post(req, ev); + } + ZERO_STRUCTP(tlss); + talloc_set_destructor(tlss, tstream_tls_destructor); + +#if ENABLE_GNUTLS + tlss->plain_stream = plain_stream; + + tlss->current_ev = ev; + tlss->im = tevent_create_immediate(tlss); + if (tevent_req_nomem(tlss->im, req)) { + return tevent_req_post(req, ev); + } + + ret = gnutls_init(&tlss->tls_session, GNUTLS_SERVER); + if (ret != GNUTLS_E_SUCCESS) { + DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret))); + tevent_req_error(req, EINVAL); + return tevent_req_post(req, ev); + } + + ret = gnutls_set_default_priority(tlss->tls_session); + if (ret != GNUTLS_E_SUCCESS) { + DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret))); + tevent_req_error(req, EINVAL); + return tevent_req_post(req, ev); + } + + ret = gnutls_credentials_set(tlss->tls_session, GNUTLS_CRD_CERTIFICATE, + tlsp->x509_cred); + if (ret != GNUTLS_E_SUCCESS) { + DEBUG(0,("TLS %s - %s\n", __location__, gnutls_strerror(ret))); + tevent_req_error(req, EINVAL); + return tevent_req_post(req, ev); + } + + gnutls_certificate_server_set_request(tlss->tls_session, + GNUTLS_CERT_REQUEST); + gnutls_dh_set_prime_bits(tlss->tls_session, DH_BITS); + + gnutls_transport_set_ptr(tlss->tls_session, (gnutls_transport_ptr)state->tls_stream); + gnutls_transport_set_pull_function(tlss->tls_session, + (gnutls_pull_func)tstream_tls_pull_function); + gnutls_transport_set_push_function(tlss->tls_session, + (gnutls_push_func)tstream_tls_push_function); + gnutls_transport_set_lowat(tlss->tls_session, 0); + + tlss->handshake.req = req; + tstream_tls_retry_handshake(state->tls_stream); + if (!tevent_req_is_in_progress(req)) { + return tevent_req_post(req, ev); + } + + return req; +#else /* ENABLE_GNUTLS */ + tevent_req_error(req, ENOSYS); + return tevent_req_post(req, ev); +#endif /* ENABLE_GNUTLS */ +} + +static void tstream_tls_retry_handshake(struct tstream_context *stream) +{ + struct tstream_tls *tlss = + tstream_context_data(stream, + struct tstream_tls); + struct tevent_req *req = tlss->handshake.req; +#if ENABLE_GNUTLS + int ret; + + if (tlss->error != 0) { + tevent_req_error(req, tlss->error); + return; + } + + ret = gnutls_handshake(tlss->tls_session); + if (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN) { + return; + } + + tlss->handshake.req = NULL; + + if (gnutls_error_is_fatal(ret) != 0) { + DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret))); + tlss->error = EIO; + tevent_req_error(req, tlss->error); + return; + } + + if (ret != GNUTLS_E_SUCCESS) { + DEBUG(1,("TLS %s - %s\n", __location__, gnutls_strerror(ret))); + tlss->error = EIO; + tevent_req_error(req, tlss->error); + return; + } + + tevent_req_done(req); +#else /* ENABLE_GNUTLS */ + tevent_req_error(req, ENOSYS); +#endif /* ENABLE_GNUTLS */ +} + +int tstream_tls_accept_recv(struct tevent_req *req, + int *perrno, + TALLOC_CTX *mem_ctx, + struct tstream_context **tls_stream) +{ + struct tstream_tls_accept_state *state = + tevent_req_data(req, + struct tstream_tls_accept_state); + + if (tevent_req_is_unix_error(req, perrno)) { + tevent_req_received(req); + return -1; + } + + *tls_stream = talloc_move(mem_ctx, &state->tls_stream); + tevent_req_received(req); + return 0; +} |