diff options
Diffstat (limited to 'scripts/tls_test.c')
-rw-r--r-- | scripts/tls_test.c | 326 |
1 files changed, 326 insertions, 0 deletions
diff --git a/scripts/tls_test.c b/scripts/tls_test.c new file mode 100644 index 000000000..202f63bb2 --- /dev/null +++ b/scripts/tls_test.c @@ -0,0 +1,326 @@ +/* + * Copyright (C) 2010 Martin Willi + * Copyright (C) 2010 revosec AG + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License as published by the + * Free Software Foundation; either version 2 of the License, or (at your + * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>. + * + * This program is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY + * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * for more details. + */ + +#include <unistd.h> +#include <stdio.h> +#include <sys/types.h> +#include <sys/socket.h> +#include <getopt.h> + +#include <library.h> +#include <debug.h> +#include <tls_socket.h> +#include <utils/host.h> +#include <credentials/sets/mem_cred.h> + +/** + * Print usage information + */ +static void usage(FILE *out, char *cmd) +{ + fprintf(out, "usage:\n"); + fprintf(out, " %s --connect <address> --port <port> [--cert <file>]+\n", cmd); + fprintf(out, " %s --listen <address> --port <port> --key <key> [--cert <file>]+ --oneshot\n", cmd); +} + +/** + * Stream between stdio and TLS socket + */ +static int stream(int fd, tls_socket_t *tls) +{ + while (TRUE) + { + fd_set set; + chunk_t data; + + FD_ZERO(&set); + FD_SET(fd, &set); + FD_SET(0, &set); + + if (select(fd + 1, &set, NULL, NULL, NULL) == -1) + { + return 1; + } + if (FD_ISSET(fd, &set)) + { + if (!tls->read(tls, &data)) + { + fprintf(stderr, "TLS read error\n"); + return 1; + } + if (data.len) + { + ignore_result(write(0, data.ptr, data.len)); + free(data.ptr); + } + } + if (FD_ISSET(0, &set)) + { + char buf[1024]; + ssize_t len; + + len = read(0, buf, sizeof(buf)); + if (len == 0) + { + return 0; + } + if (len > 0) + { + if (!tls->write(tls, chunk_create(buf, len))) + { + fprintf(stderr, "TLS write error\n"); + return 1; + } + } + } + } +} + +/** + * Client routine + */ +static int client(int fd, host_t *host, identification_t *server) +{ + tls_socket_t *tls; + int res; + + if (connect(fd, host->get_sockaddr(host), + *host->get_sockaddr_len(host)) == -1) + { + fprintf(stderr, "connecting to %#H failed: %m\n", host); + return 1; + } + tls = tls_socket_create(FALSE, server, NULL, fd); + if (!tls) + { + return 1; + } + res = stream(fd, tls); + tls->destroy(tls); + return res; +} + +/** + * Server routine + */ +static int serve(int fd, host_t *host, identification_t *server, bool oneshot) +{ + tls_socket_t *tls; + int cfd; + + if (bind(fd, host->get_sockaddr(host), + *host->get_sockaddr_len(host)) == -1) + { + fprintf(stderr, "binding to %#H failed: %m\n", host); + return 1; + } + if (listen(fd, 1) == -1) + { + fprintf(stderr, "listen to %#H failed: %m\n", host); + return 1; + } + + do + { + cfd = accept(fd, host->get_sockaddr(host), host->get_sockaddr_len(host)); + if (cfd == -1) + { + fprintf(stderr, "accept failed: %m\n"); + return 1; + } + fprintf(stderr, "%#H connected\n", host); + + tls = tls_socket_create(TRUE, server, NULL, cfd); + if (!tls) + { + return 1; + } + stream(cfd, tls); + fprintf(stderr, "%#H disconnected\n", host); + tls->destroy(tls); + } + while (!oneshot); + + return 0; +} + +/** + * In-Memory credential set + */ +static mem_cred_t *creds; + +/** + * Load certificate from file + */ +static bool load_certificate(char *filename) +{ + certificate_t *cert; + + cert = lib->creds->create(lib->creds, CRED_CERTIFICATE, CERT_X509, + BUILD_FROM_FILE, filename, BUILD_END); + if (!cert) + { + fprintf(stderr, "loading certificate from '%s' failed\n", filename); + return FALSE; + } + creds->add_cert(creds, TRUE, cert); + return TRUE; +} + +/** + * Load private key from file + */ +static bool load_key(char *filename) +{ + private_key_t *key; + + key = lib->creds->create(lib->creds, CRED_PRIVATE_KEY, KEY_RSA, + BUILD_FROM_FILE, filename, BUILD_END); + if (!key) + { + fprintf(stderr, "loading key from '%s' failed\n", filename); + return FALSE; + } + creds->add_key(creds, key); + return TRUE; +} + +/** + * Cleanup + */ +static void cleanup() +{ + lib->credmgr->remove_set(lib->credmgr, &creds->set); + creds->destroy(creds); + library_deinit(); +} + +/** + * Initialize library + */ +static void init() +{ + library_init(NULL); + lib->plugins->load(lib->plugins, NULL, PLUGINS); + + creds = mem_cred_create(); + lib->credmgr->add_set(lib->credmgr, &creds->set); + + atexit(cleanup); +} + +int main(int argc, char *argv[]) +{ + char *address = NULL; + bool listen = FALSE, oneshot = FALSE; + int port = 0, fd, res; + identification_t *server; + host_t *host; + + init(); + + while (TRUE) + { + struct option long_opts[] = { + {"help", no_argument, NULL, 'h' }, + {"connect", required_argument, NULL, 'c' }, + {"listen", required_argument, NULL, 'l' }, + {"port", required_argument, NULL, 'p' }, + {"cert", required_argument, NULL, 'x' }, + {"key", required_argument, NULL, 'k' }, + {"oneshot", no_argument, NULL, 'o' }, + {0,0,0,0 } + }; + switch (getopt_long(argc, argv, "", long_opts, NULL)) + { + case EOF: + break; + case 'h': + usage(stdout, argv[0]); + return 0; + case 'x': + if (!load_certificate(optarg)) + { + return 1; + } + continue; + case 'k': + if (!load_key(optarg)) + { + return 1; + } + continue; + case 'l': + listen = TRUE; + /* fall */ + case 'c': + if (address) + { + usage(stderr, argv[0]); + return 1; + } + address = optarg; + continue; + case 'p': + port = atoi(optarg); + continue; + case 'o': + oneshot = TRUE; + continue; + default: + usage(stderr, argv[0]); + return 1; + } + break; + } + if (!port || !address) + { + usage(stderr, argv[0]); + return 1; + } + if (oneshot && !listen) + { + usage(stderr, argv[0]); + return 1; + } + + fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd == -1) + { + fprintf(stderr, "opening socket failed: %m\n"); + return 1; + } + host = host_create_from_dns(address, 0, port); + if (!host) + { + fprintf(stderr, "resolving hostname %s failed\n", address); + close(fd); + return 1; + } + server = identification_create_from_string(address); + if (listen) + { + res = serve(fd, host, server, oneshot); + } + else + { + res = client(fd, host, server); + } + close(fd); + host->destroy(host); + server->destroy(server); + return res; +} + |