aboutsummaryrefslogtreecommitdiff
path: root/src/tls/tls_drv.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/tls/tls_drv.c')
-rw-r--r--src/tls/tls_drv.c277
1 files changed, 249 insertions, 28 deletions
diff --git a/src/tls/tls_drv.c b/src/tls/tls_drv.c
index 02f4a26a7..811293406 100644
--- a/src/tls/tls_drv.c
+++ b/src/tls/tls_drv.c
@@ -1,5 +1,5 @@
/*
- * ejabberd, Copyright (C) 2002-2008 Process-one
+ * ejabberd, Copyright (C) 2002-2009 ProcessOne
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License as
@@ -10,7 +10,7 @@
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
- *
+ *
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
@@ -23,24 +23,187 @@
#include <erl_driver.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
-
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <stdint.h>
#define BUF_SIZE 1024
typedef struct {
ErlDrvPort port;
- SSL_CTX *ctx;
BIO *bio_read;
BIO *bio_write;
SSL *ssl;
} tls_data;
+#ifdef _WIN32
+typedef unsigned __int32 uint32_t;
+#endif
+
+#ifndef SSL_OP_NO_TICKET
+#define SSL_OP_NO_TICKET 0
+#endif
+
+/*
+ * str_hash is based on the public domain code from
+ * http://www.burtleburtle.net/bob/hash/doobs.html
+ */
+static uint32_t str_hash(char *s)
+{
+ unsigned char *key = (unsigned char *)s;
+ uint32_t hash = 0;
+ size_t i;
+
+ for (i = 0; key[i] != 0; i++) {
+ hash += key[i];
+ hash += (hash << 10);
+ hash ^= (hash >> 6);
+ }
+ hash += (hash << 3);
+ hash ^= (hash >> 11);
+ hash += (hash << 15);
+ return hash;
+}
+
+/* Linear hashing */
+
+#define MIN_LEVEL 8
+#define MAX_LEVEL 20
+
+struct bucket {
+ uint32_t hash;
+ char *key_file;
+ time_t mtime;
+ SSL_CTX *ssl_ctx;
+ struct bucket *next;
+};
+
+struct hash_table {
+ int split;
+ int level;
+ struct bucket **buckets;
+ int size;
+};
+
+struct hash_table ht;
+
+static void init_hash_table()
+{
+ size_t size = 1 << (MIN_LEVEL + 1);
+ size_t i;
+ ht.buckets = (struct bucket **)driver_alloc(sizeof(struct bucket *) * size);
+ ht.split = 0;
+ ht.level = MIN_LEVEL;
+ for (i = 0; i < size; i++)
+ ht.buckets[i] = NULL;
+
+}
+
+static void hash_table_insert(char *key_file, time_t mtime,
+ SSL_CTX *ssl_ctx)
+{
+ int level, split;
+ uint32_t hash = str_hash(key_file);
+ size_t bucket;
+ int do_split = 0;
+ struct bucket *el;
+ struct bucket *new_bucket_el;
+
+ split = ht.split;
+ level = ht.level;
+
+ bucket = hash & ((1 << level) - 1);
+ if (bucket < split)
+ bucket = hash & ((1 << (level + 1)) - 1);
+
+ el = ht.buckets[bucket];
+ while (el != NULL) {
+ if (el->hash == hash && strcmp(el->key_file, key_file) == 0) {
+ el->mtime = mtime;
+ if (el->ssl_ctx != NULL)
+ SSL_CTX_free(el->ssl_ctx);
+ el->ssl_ctx = ssl_ctx;
+ break;
+ }
+ el = el->next;
+ }
+
+ if (el == NULL) {
+ if (ht.buckets[bucket] != NULL)
+ do_split = !0;
+
+ new_bucket_el = (struct bucket *)driver_alloc(sizeof(struct bucket));
+ new_bucket_el->hash = hash;
+ new_bucket_el->key_file = (char *)driver_alloc(strlen(key_file) + 1);
+ strcpy(new_bucket_el->key_file, key_file);
+ new_bucket_el->mtime = mtime;
+ new_bucket_el->ssl_ctx = ssl_ctx;
+ new_bucket_el->next = ht.buckets[bucket];
+ ht.buckets[bucket] = new_bucket_el;
+ }
+
+ if (do_split) {
+ struct bucket **el_ptr = &ht.buckets[split];
+ size_t new_bucket = split + (1 << level);
+ while (*el_ptr != NULL) {
+ uint32_t hash = (*el_ptr)->hash;
+ if ((hash & ((1 << (level + 1)) - 1)) == new_bucket) {
+ struct bucket *moved_el = *el_ptr;
+ *el_ptr = (*el_ptr)->next;
+ moved_el->next = ht.buckets[new_bucket];
+ ht.buckets[new_bucket] = moved_el;
+ } else
+ el_ptr = &(*el_ptr)->next;
+ }
+ split++;
+ if (split == 1 << level) {
+ size_t size;
+ size_t i;
+ split = 0;
+ level++;
+ size = 1 << (level + 1);
+ ht.split = split;
+ ht.level = level;
+ ht.buckets = (struct bucket **)
+ driver_realloc(ht.buckets, sizeof(struct bucket *) * size);
+ for (i = 1 << level; i < size; i++)
+ ht.buckets[i] = NULL;
+ } else
+ ht.split = split;
+ }
+}
+
+static SSL_CTX *hash_table_lookup(char *key_file, time_t *pmtime)
+{
+ int level, split;
+ uint32_t hash = str_hash(key_file);
+ size_t bucket;
+ struct bucket *el;
+
+ split = ht.split;
+ level = ht.level;
+
+ bucket = hash & ((1 << level) - 1);
+ if (bucket < split)
+ bucket = hash & ((1 << (level + 1)) - 1);
+
+ el = ht.buckets[bucket];
+ while (el != NULL) {
+ if (el->hash == hash && strcmp(el->key_file, key_file) == 0) {
+ *pmtime = el->mtime;
+ return el->ssl_ctx;
+ }
+ el = el->next;
+ }
+
+ return NULL;
+}
+
static ErlDrvData tls_drv_start(ErlDrvPort port, char *buff)
{
tls_data *d = (tls_data *)driver_alloc(sizeof(tls_data));
d->port = port;
- d->ctx = NULL;
d->bio_read = NULL;
d->bio_write = NULL;
d->ssl = NULL;
@@ -57,12 +220,46 @@ static void tls_drv_stop(ErlDrvData handle)
if (d->ssl != NULL)
SSL_free(d->ssl);
- if (d->ctx != NULL)
- SSL_CTX_free(d->ctx);
-
driver_free((char *)handle);
}
+static void tls_drv_finish()
+{
+ int level;
+ struct bucket *el;
+ int i;
+
+ level = ht.level;
+ for (i = 0; i < 1 << (level + 1); i++) {
+ el = ht.buckets[i];
+ while (el != NULL) {
+ if (el->ssl_ctx != NULL)
+ SSL_CTX_free(el->ssl_ctx);
+ driver_free(el->key_file);
+ el = el->next;
+ }
+ }
+
+ driver_free(ht.buckets);
+}
+
+static int is_key_file_modified(char *file, time_t *key_file_mtime)
+{
+ struct stat file_stat;
+
+ if (stat(file, &file_stat))
+ {
+ *key_file_mtime = 0;
+ return 1;
+ } else {
+ if (*key_file_mtime != file_stat.st_mtime)
+ {
+ *key_file_mtime = file_stat.st_mtime;
+ return 1;
+ } else
+ return 0;
+ }
+}
static int verify_callback(int preverify_ok, X509_STORE_CTX *ctx)
{
@@ -77,6 +274,7 @@ static int verify_callback(int preverify_ok, X509_STORE_CTX *ctx)
#define GET_DECRYPTED_INPUT 6
#define GET_PEER_CERTIFICATE 7
#define GET_VERIFY_RESULT 8
+#define VERIFY_NONE 0x10000
#define die_unless(cond, errstr) \
@@ -117,36 +315,55 @@ static int tls_drv_control(ErlDrvData handle,
int size;
ErlDrvBinary *b;
X509 *cert;
+ unsigned int flags = command;
+
+ command &= 0xffff;
ERR_clear_error();
switch (command)
{
case SET_CERTIFICATE_FILE_ACCEPT:
- case SET_CERTIFICATE_FILE_CONNECT:
- d->ctx = SSL_CTX_new(SSLv23_method());
- die_unless(d->ctx, "SSL_CTX_new failed");
+ case SET_CERTIFICATE_FILE_CONNECT: {
+ time_t mtime = 0;
+ SSL_CTX *ssl_ctx = hash_table_lookup(buf, &mtime);
+ if (is_key_file_modified(buf, &mtime) || ssl_ctx == NULL)
+ {
+ SSL_CTX *ctx;
- res = SSL_CTX_use_certificate_chain_file(d->ctx, buf);
- die_unless(res > 0, "SSL_CTX_use_certificate_file failed");
+ hash_table_insert(buf, mtime, NULL);
- res = SSL_CTX_use_PrivateKey_file(d->ctx, buf, SSL_FILETYPE_PEM);
- die_unless(res > 0, "SSL_CTX_use_PrivateKey_file failed");
+ ctx = SSL_CTX_new(SSLv23_method());
+ die_unless(ctx, "SSL_CTX_new failed");
- res = SSL_CTX_check_private_key(d->ctx);
- die_unless(res > 0, "SSL_CTX_check_private_key failed");
+ res = SSL_CTX_use_certificate_chain_file(ctx, buf);
+ die_unless(res > 0, "SSL_CTX_use_certificate_file failed");
- SSL_CTX_set_default_verify_paths(d->ctx);
+ res = SSL_CTX_use_PrivateKey_file(ctx, buf, SSL_FILETYPE_PEM);
+ die_unless(res > 0, "SSL_CTX_use_PrivateKey_file failed");
- if (command == SET_CERTIFICATE_FILE_ACCEPT)
- {
- SSL_CTX_set_verify(d->ctx,
- SSL_VERIFY_PEER|SSL_VERIFY_CLIENT_ONCE,
- verify_callback);
+ res = SSL_CTX_check_private_key(ctx);
+ die_unless(res > 0, "SSL_CTX_check_private_key failed");
+
+ SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_OFF);
+ SSL_CTX_set_default_verify_paths(ctx);
+
+ if (command == SET_CERTIFICATE_FILE_ACCEPT)
+ {
+ SSL_CTX_set_verify(ctx,
+ SSL_VERIFY_PEER|SSL_VERIFY_CLIENT_ONCE,
+ verify_callback);
+ }
+
+ ssl_ctx = ctx;
+ hash_table_insert(buf, mtime, ssl_ctx);
}
-
- d->ssl = SSL_new(d->ctx);
+
+ d->ssl = SSL_new(ssl_ctx);
die_unless(d->ssl, "SSL_new failed");
+ if (flags & VERIFY_NONE)
+ SSL_set_verify(d->ssl, SSL_VERIFY_NONE, verify_callback);
+
d->bio_read = BIO_new(BIO_s_mem());
d->bio_write = BIO_new(BIO_s_mem());
@@ -154,9 +371,12 @@ static int tls_drv_control(ErlDrvData handle,
if (command == SET_CERTIFICATE_FILE_ACCEPT)
SSL_set_accept_state(d->ssl);
- else
+ else {
+ SSL_set_options(d->ssl, SSL_OP_NO_SSLv2|SSL_OP_NO_TICKET);
SSL_set_connect_state(d->ssl);
+ }
break;
+ }
case SET_ENCRYPTED_INPUT:
die_unless(d->ssl, "SSL not initialized");
BIO_write(d->bio_read, buf, len);
@@ -250,7 +470,7 @@ static int tls_drv_control(ErlDrvData handle,
rlen++;
b = driver_alloc_binary(rlen);
b->orig_bytes[0] = 0;
- tmp_buf = &b->orig_bytes[1];
+ tmp_buf = (unsigned char *)&b->orig_bytes[1];
i2d_X509(cert, &tmp_buf);
X509_free(cert);
*rbuf = (char *)b;
@@ -282,7 +502,7 @@ ErlDrvEntry tls_driver_entry = {
NULL, /* F_PTR ready_input, called when input descriptor ready */
NULL, /* F_PTR ready_output, called when output descriptor ready */
"tls_drv", /* char *driver_name, the argument to open_port */
- NULL, /* F_PTR finish, called when unloaded */
+ tls_drv_finish, /* F_PTR finish, called when unloaded */
NULL, /* handle */
tls_drv_control, /* F_PTR control, port_command callback */
NULL, /* F_PTR timeout, reserved */
@@ -293,6 +513,7 @@ DRIVER_INIT(tls_drv) /* must match name in driver_entry */
{
OpenSSL_add_ssl_algorithms();
SSL_load_error_strings();
+ init_hash_table();
return &tls_driver_entry;
}