From 25593b5e6fea76ed7c08db586924032c0810c27e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20Ter=C3=A4s?= Date: Sun, 7 Nov 2010 00:47:39 +0200 Subject: squark: reorganize sources to src directory --- src/Makefile | 30 ++ src/addr.c | 74 ++++ src/addr.h | 32 ++ src/authdb.c | 364 +++++++++++++++ src/authdb.h | 62 +++ src/blob.c | 426 ++++++++++++++++++ src/blob.h | 63 +++ src/filterdb.c | 157 +++++++ src/filterdb.h | 59 +++ src/lua-squarkdb.c | 343 ++++++++++++++ src/sqdb-build.lua | 335 ++++++++++++++ src/squark-auth-ip.c | 218 +++++++++ src/squark-auth-snmp.c | 1152 ++++++++++++++++++++++++++++++++++++++++++++++++ src/squark-filter.c | 431 ++++++++++++++++++ 14 files changed, 3746 insertions(+) create mode 100644 src/Makefile create mode 100644 src/addr.c create mode 100644 src/addr.h create mode 100644 src/authdb.c create mode 100644 src/authdb.h create mode 100644 src/blob.c create mode 100644 src/blob.h create mode 100644 src/filterdb.c create mode 100644 src/filterdb.h create mode 100644 src/lua-squarkdb.c create mode 100755 src/sqdb-build.lua create mode 100644 src/squark-auth-ip.c create mode 100644 src/squark-auth-snmp.c create mode 100644 src/squark-filter.c (limited to 'src') diff --git a/src/Makefile b/src/Makefile new file mode 100644 index 0000000..db683a2 --- /dev/null +++ b/src/Makefile @@ -0,0 +1,30 @@ +TARGETS=squark-auth-snmp squark-auth-ip squark-filter squarkdb.so + +NETSNMP_CFLAGS:=$(shell net-snmp-config --cflags) +NETSNMP_LIBS:=$(shell net-snmp-config --libs) +LUA_CFLAGS:=$(shell pkg-config --cflags lua5.1) +LUA_LIBS:=$(shell pkg-config --libs lua5.1) +CMPH_CFLAGS:=$(shell pkg-config --cflags cmph) +CMPH_LIBS:=$(shell pkg-config --libs cmph) + +CC=gcc +CFLAGS=-g -I. $(NETSNMP_CFLAGS) $(LUA_CFLAGS) $(CMPH_CFLAGS) -std=gnu99 -D_GNU_SOURCE -Wall +LIBS+=-lrt + +all: $(TARGETS) + +squark-auth-snmp: squark-auth-snmp.o filterdb.o authdb.o blob.o addr.o + $(CC) -o $@ $^ $(NETSNMP_LIBS) $(LIBS) + +squark-auth-ip: squark-auth-ip.o filterdb.o authdb.o blob.o addr.o + $(CC) -o $@ $^ $(LIBS) + +squark-filter: squark-filter.o filterdb.o authdb.o blob.o addr.o + $(CC) -o $@ $^ $(CMPH_LIBS) $(LIBS) + +squarkdb.so: lua-squarkdb.o filterdb.o blob.o + $(CC) -shared -o $@ $^ $(LUA_LIBS) $(CMPH_LIBS) $(LIBS) + +clean: + rm $(OBJS1) $(TARGETS) + diff --git a/src/addr.c b/src/addr.c new file mode 100644 index 0000000..47013f2 --- /dev/null +++ b/src/addr.c @@ -0,0 +1,74 @@ +#include +#include + +#include "addr.h" + +int addr_len(const sockaddr_any *addr) +{ + switch (addr->any.sa_family) { + case AF_INET: + return sizeof(struct sockaddr_in); + default: + return 0; + } +} + +sockaddr_any *addr_parse(blob_t b, sockaddr_any *addr) +{ + memset(addr, 0, sizeof(*addr)); + addr->ipv4.sin_family = AF_INET; + addr->ipv4.sin_addr.s_addr = blob_inet_addr(b); + if (addr->ipv4.sin_addr.s_addr == -1) + return NULL; + return addr; +} + +unsigned long addr_hash(const sockaddr_any *addr) +{ + switch (addr->any.sa_family) { + case AF_INET: + return htonl(addr->ipv4.sin_addr.s_addr); + default: + return 0; + } +} + +const char *addr_print(const sockaddr_any *addr) +{ + switch (addr->any.sa_family) { + case AF_INET: + return inet_ntoa(addr->ipv4.sin_addr); + default: + return "unknown"; + } +} + +blob_t addr_get_hostaddr_blob(const sockaddr_any *addr) +{ + switch (addr->any.sa_family) { + case AF_INET: + return BLOB_BUF(&addr->ipv4.sin_addr); + default: + return BLOB_NULL; + } +} + +void addr_push_hostaddr(blob_t *b, const sockaddr_any *addr) +{ + char buf[64]; + blob_t f; + unsigned int t; + + switch (addr->any.sa_family) { + case AF_INET: + t = ntohl(addr->ipv4.sin_addr.s_addr); + f.ptr = buf; + f.len = sprintf(buf, "%d.%d.%d.%d", + (t ) & 0xff, (t >> 8) & 0xff, + (t >> 16) & 0xff, (t >> 24) & 0xff); + break; + default: + return; + } + blob_push(b, f); +} diff --git a/src/addr.h b/src/addr.h new file mode 100644 index 0000000..452d14b --- /dev/null +++ b/src/addr.h @@ -0,0 +1,32 @@ +#ifndef ADDR_H +#define ADDR_H + +#include +#include "blob.h" + +typedef union { + struct sockaddr any; + struct sockaddr_in ipv4; +} sockaddr_any; + +int addr_len(const sockaddr_any *addr); +sockaddr_any *addr_parse(blob_t text, sockaddr_any *addr); +unsigned long addr_hash(const sockaddr_any *addr); +const char *addr_print(const sockaddr_any *addr); +blob_t addr_get_hostaddr_blob(const sockaddr_any *addr); +void addr_push_hostaddr(blob_t *b, const sockaddr_any *addr); + +static inline void addr_copy(sockaddr_any *dst, const sockaddr_any *src) +{ + memcpy(dst, src, addr_len(src)); +} + +static inline int addr_cmp(const sockaddr_any *a, const sockaddr_any *b) +{ + if (a->any.sa_family != b->any.sa_family) + return -1; + return memcmp(a, b, addr_len(a)); +} + +#endif + diff --git a/src/authdb.c b/src/authdb.c new file mode 100644 index 0000000..e6e71c4 --- /dev/null +++ b/src/authdb.c @@ -0,0 +1,364 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "authdb.h" +#include "filterdb.h" +#include "addr.h" +#include "blob.h" + +#define ALIGN(s,a) (((s) + a - 1) & ~(a - 1)) + +#define AUTHDB_IP_PER_ME 256 +#define AUTHDB_LOGOFF_PERIOD (15*60) /* 15 mins */ +#define AUTHDB_SHM_SIZE ALIGN(sizeof(struct authdb_entry[AUTHDB_IP_PER_ME]), 4096) + +static struct authdb_map_entry *authdb_me_open(sockaddr_any *addr, int create) +{ + int oflag, fd; + char name[64], buf[256]; + blob_t b = BLOB_BUF(name); + void *base; + struct authdb_map_entry *me; + struct group grp, *res; + + blob_push(&b, BLOB_STR("squark-auth-")); + blob_push_hexdump(&b, addr_get_hostaddr_blob(addr)); + blob_push_byte(&b, 0); + + oflag = O_RDWR; + if (create) + oflag |= O_CREAT; + + fd = shm_open(name, oflag, 0660); + if (fd < 0) + return NULL; + + if (ftruncate(fd, AUTHDB_SHM_SIZE) < 0) { + close(fd); + return NULL; + } + + getgrnam_r("squark", &grp, buf, sizeof(buf), &res); + if (res != NULL) { + fchown(fd, -1, res->gr_gid); + fchmod(fd, 0660); + } + + base = mmap(NULL, AUTHDB_SHM_SIZE, + PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + close(fd); + + if (base == MAP_FAILED) + return NULL; + + me = malloc(sizeof(*me)); + if (me == NULL) { + munmap(base, AUTHDB_SHM_SIZE); + return NULL; + } + + me->next = NULL; + me->baseaddr = *addr; + me->entries = base; + + return me; +} + +static void authdb_me_free(struct authdb_map_entry *me) +{ + munmap(me->entries, AUTHDB_SHM_SIZE); + free(me); +} + +int authdb_open(struct authdb *adb, struct authdb_config *cfg, struct sqdb *db) +{ + memset(adb, 0, sizeof(*adb)); + memset(cfg, 0, sizeof(*cfg)); + cfg->db = db; + return adbc_refresh(cfg, time(NULL)); +} + +void authdb_close(struct authdb *adb) +{ + struct authdb_map_entry *c, *n; + int i; + + for (i = 0; i < ARRAY_SIZE(adb->hash_bucket); i++) { + for (c = adb->hash_bucket[i]; c != NULL; c = n) { + n = c->next; + authdb_me_free(c); + } + } +} + +static unsigned int MurmurHash2(const void * key, int len, unsigned int seed) +{ + // 'm' and 'r' are mixing constants generated offline. + // They're not really 'magic', they just happen to work well. + const unsigned int m = 0x5bd1e995; + const int r = 24; + unsigned int h = seed ^ len; + const unsigned char * data = (const unsigned char *)key; + + while(len >= 4) + { + unsigned int k = *(unsigned int *)data; + + k *= m; + k ^= k >> r; + k *= m; + + h *= m; + h ^= k; + + data += 4; + len -= 4; + } + + switch(len) + { + case 3: h ^= data[2] << 16; + case 2: h ^= data[1] << 8; + case 1: h ^= data[0]; + h *= m; + } + + h ^= h >> 13; + h *= m; + h ^= h >> 15; + + return h; +} + +static uint32_t authdb_entry_checksum(struct authdb_entry *entry) +{ + return MurmurHash2(&entry->p, sizeof(entry->p), 0); +} + +void *authdb_get(struct authdb *adb, sockaddr_any *addr, struct authdb_entry *entry, int create) +{ + struct authdb_map_entry *me; + unsigned int hash, e, i; + sockaddr_any baseaddr; + blob_t b; + + baseaddr = *addr; + b = addr_get_hostaddr_blob(&baseaddr); + if (b.len < 4) + return NULL; + + e = (unsigned char) b.ptr[0]; + b.ptr[0] = 0x00; + + hash = b.ptr[1] + b.ptr[2] + b.ptr[3]; + hash %= ARRAY_SIZE(adb->hash_bucket); + + for (me = adb->hash_bucket[hash]; me != NULL; me = me->next) { + if (addr_cmp(&baseaddr, &me->baseaddr) == 0) + break; + } + if (me == NULL) { + me = authdb_me_open(&baseaddr, create); + if (me == NULL) + return NULL; + me->next = adb->hash_bucket[hash]; + adb->hash_bucket[hash] = me; + } + + for (i = 0; i < 3; i++) { + memcpy(entry, &me->entries[e], sizeof(struct authdb_entry)); + if (entry->checksum == 0 && entry->p.login_time == 0) + return &me->entries[e]; + if (entry->checksum == authdb_entry_checksum(entry)) + return &me->entries[e]; + sched_yield(); + } + + authdb_clear_entry(entry); + + return &me->entries[e]; +} + +int authdb_set(void *token, struct authdb_entry *entry) +{ + struct authdb_entry *mme = token; + uint32_t checksum = entry->checksum; + + entry->checksum = authdb_entry_checksum(entry); + if (mme->checksum != checksum) + return 0; + + mme->checksum = ~0; + memcpy(mme, entry, sizeof(*entry)); + + return 1; +} + +int authdb_check_login(void *token, struct authdb_entry *e, blob_t username, time_t now) +{ + struct authdb_entry *mme = token; + + /* check username */ + if (!blob_is_null(username) && + blob_cmp(username, BLOB_STRLEN(e->p.login_name)) != 0) + return 0; + + /* and dates */ + if (now > e->last_activity_time + AUTHDB_LOGOFF_PERIOD) + return 0; + + /* and that no one clobbered the entry */ + if (mme->checksum != e->checksum) + return 0; + + /* refresh last activity */ + mme->last_activity_time = now; + + return 1; +} + +void authdb_clear_entry(struct authdb_entry *entry) +{ + uint32_t checksum = entry->checksum; + + memset(entry, 0, sizeof(*entry)); + entry->checksum = checksum; +} + +void authdb_commit_login(void *token, struct authdb_entry *e, time_t now, struct authdb_config *cfg) +{ + e->p.block_categories = cfg->block_categories; + e->p.hard_block_categories = cfg->hard_block_categories; + e->p.login_time = now; + e->last_activity_time = now; + e->override_time = 0; + + authdb_set(token, e); +} + +void authdb_commit_logout(void *token) +{ + memset(token, 0, sizeof(struct authdb_entry)); +} + +void authdb_commit_override(void *token, struct authdb_entry *e, time_t now) +{ + struct authdb_entry *mme = token; + + mme->override_time = now; +} + +static blob_t read_word(FILE *in, int *lineno, blob_t b) +{ + int ch, i, comment = 0; + blob_t r; + + ch = fgetc(in); + while (1) { + if (ch == EOF) + return BLOB_NULL; + if (ch == '#') + comment = 1; + if (!comment && !isspace(ch)) + break; + if (ch == '\n') { + (*lineno)++; + comment = 0; + } + ch = fgetc(in); + } + + r.ptr = b.ptr; + r.len = 0; + for (i = 0; i < b.len-1 && !isspace(ch); i++, r.len++) { + r.ptr[i] = ch; + ch = fgetc(in); + if (ch == EOF) + break; + if (ch == '\n') + (*lineno)++; + } + + return r; +} + +static int find_category_id(struct sqdb *db, blob_t cat) +{ + uint32_t size, *ptr; + int i; + + ptr = sqdb_section_get(db, SQDB_SECTION_CATEGORIES, &size); + if (ptr == NULL) + return -1; + + size /= sizeof(uint32_t); + for (i = 0; i < size; i++) + if (blob_cmp(cat, sqdb_get_string_literal(db, ptr[i])) == 0) + return i; + + return -1; +} + +static inline uint64_t to_category(struct sqdb *db, blob_t c) +{ + int category; + + category = find_category_id(db, c); + if (category >= 0) + return 1ULL << category; + + fprintf(stderr, "WARNING: unknown category '%.*s'\n", + c.len, c.ptr); + return 0; +} + +int adbc_refresh(struct authdb_config *cfg, time_t now) +{ + FILE *in; + int lineno = 1; + char word1[64], word2[64]; + blob_t b, p; + struct stat st; + + if (cfg->last_check != 0 && cfg->last_check + 2*60 > now) + return 0; + + if (stat("/etc/squark/filter.conf", &st) != 0) + return -1; + + if (cfg->last_change == st.st_ctime) + return 0; + + /* check timestamp */ + + in = fopen("/etc/squark/filter.conf", "r"); + if (in == NULL) + return -1; + + cfg->block_categories = 0; + cfg->hard_block_categories = 0; + while (1) { + b = read_word(in, &lineno, BLOB_BUF(word1)); + if (blob_is_null(b)) + break; + + p = read_word(in, &lineno, BLOB_BUF(word2)); + if (blob_cmp(b, BLOB_STR("redirect_path")) == 0) { + cfg->redirect_url_base = blob_dup(p); + } else if (blob_cmp(b, BLOB_STR("forbid")) == 0) { + cfg->hard_block_categories |= to_category(cfg->db, p); + } else if (blob_cmp(b, BLOB_STR("warn")) == 0) { + cfg->block_categories |= to_category(cfg->db, p); + } + } + cfg->block_categories |= cfg->hard_block_categories; + + fclose(in); +} diff --git a/src/authdb.h b/src/authdb.h new file mode 100644 index 0000000..7bfa2f4 --- /dev/null +++ b/src/authdb.h @@ -0,0 +1,62 @@ +#ifndef AUTHDB_H +#define AUTHDB_H + +#include +#include +#include "blob.h" +#include "addr.h" + +#define AUTHDB_IP_HASH_SIZE 64 + +struct sqdb; +struct authdb_map_entry; + +struct authdb_config { + struct sqdb *db; + time_t last_check; + time_t last_change; + uint64_t block_categories; + uint64_t hard_block_categories; + blob_t redirect_url_base; +}; + +struct authdb { + struct authdb_map_entry *hash_bucket[AUTHDB_IP_HASH_SIZE]; +}; + +struct authdb_entry { + struct { + char login_name[44]; + char mac_address[6]; + uint16_t switch_port; + sockaddr_any switch_ip; + uint64_t block_categories; + uint64_t hard_block_categories; + uint32_t login_time; + } p; + uint32_t last_activity_time; + uint32_t override_time; + uint32_t checksum; +}; + +struct authdb_map_entry { + struct authdb_map_entry *next; + sockaddr_any baseaddr; + struct authdb_entry * entries; +}; + +int authdb_open(struct authdb *adb, struct authdb_config *adbc, struct sqdb *db); +void authdb_close(struct authdb *adb); + +void *authdb_get(struct authdb *adb, sockaddr_any *addr, struct authdb_entry *entry, int create); + +void authdb_clear_entry(struct authdb_entry *entry); +int authdb_set(void *token, struct authdb_entry *entry); +int authdb_check_login(void *token, struct authdb_entry *e, blob_t username, time_t now); +void authdb_commit_login(void *token, struct authdb_entry *e, time_t now, struct authdb_config *cfg); +void authdb_commit_logout(void *token); +void authdb_commit_override(void *token, struct authdb_entry *entry, time_t now); + +int adbc_refresh(struct authdb_config *cfg, time_t now); + +#endif diff --git a/src/blob.c b/src/blob.c new file mode 100644 index 0000000..1604308 --- /dev/null +++ b/src/blob.c @@ -0,0 +1,426 @@ +#include +#include +#include + +#include "blob.h" + +/* RFC 3986 section 2.3 Unreserved Characters (January 2005) */ +#define CTYPE_UNRESERVED 1 + +static const char *xd = "0123456789abcdefghijklmnopqrstuvwxyz"; + +static const unsigned char chartype[128] = { + ['a'] = CTYPE_UNRESERVED, + ['b'] = CTYPE_UNRESERVED, + ['c'] = CTYPE_UNRESERVED, + ['d'] = CTYPE_UNRESERVED, + ['e'] = CTYPE_UNRESERVED, + ['f'] = CTYPE_UNRESERVED, + ['g'] = CTYPE_UNRESERVED, + ['h'] = CTYPE_UNRESERVED, + ['i'] = CTYPE_UNRESERVED, + ['j'] = CTYPE_UNRESERVED, + ['k'] = CTYPE_UNRESERVED, + ['l'] = CTYPE_UNRESERVED, + ['m'] = CTYPE_UNRESERVED, + ['n'] = CTYPE_UNRESERVED, + ['o'] = CTYPE_UNRESERVED, + ['p'] = CTYPE_UNRESERVED, + ['q'] = CTYPE_UNRESERVED, + ['r'] = CTYPE_UNRESERVED, + ['s'] = CTYPE_UNRESERVED, + ['t'] = CTYPE_UNRESERVED, + ['u'] = CTYPE_UNRESERVED, + ['v'] = CTYPE_UNRESERVED, + ['w'] = CTYPE_UNRESERVED, + ['x'] = CTYPE_UNRESERVED, + ['y'] = CTYPE_UNRESERVED, + ['z'] = CTYPE_UNRESERVED, + ['A'] = CTYPE_UNRESERVED, + ['B'] = CTYPE_UNRESERVED, + ['C'] = CTYPE_UNRESERVED, + ['D'] = CTYPE_UNRESERVED, + ['E'] = CTYPE_UNRESERVED, + ['F'] = CTYPE_UNRESERVED, + ['G'] = CTYPE_UNRESERVED, + ['H'] = CTYPE_UNRESERVED, + ['I'] = CTYPE_UNRESERVED, + ['J'] = CTYPE_UNRESERVED, + ['K'] = CTYPE_UNRESERVED, + ['L'] = CTYPE_UNRESERVED, + ['M'] = CTYPE_UNRESERVED, + ['N'] = CTYPE_UNRESERVED, + ['O'] = CTYPE_UNRESERVED, + ['P'] = CTYPE_UNRESERVED, + ['Q'] = CTYPE_UNRESERVED, + ['R'] = CTYPE_UNRESERVED, + ['S'] = CTYPE_UNRESERVED, + ['T'] = CTYPE_UNRESERVED, + ['U'] = CTYPE_UNRESERVED, + ['V'] = CTYPE_UNRESERVED, + ['W'] = CTYPE_UNRESERVED, + ['X'] = CTYPE_UNRESERVED, + ['Y'] = CTYPE_UNRESERVED, + ['Z'] = CTYPE_UNRESERVED, + + ['0'] = CTYPE_UNRESERVED, + ['1'] = CTYPE_UNRESERVED, + ['2'] = CTYPE_UNRESERVED, + ['3'] = CTYPE_UNRESERVED, + ['4'] = CTYPE_UNRESERVED, + ['5'] = CTYPE_UNRESERVED, + ['6'] = CTYPE_UNRESERVED, + ['7'] = CTYPE_UNRESERVED, + ['8'] = CTYPE_UNRESERVED, + ['9'] = CTYPE_UNRESERVED, + + ['-'] = CTYPE_UNRESERVED, + ['_'] = CTYPE_UNRESERVED, + ['.'] = CTYPE_UNRESERVED, + ['~'] = CTYPE_UNRESERVED, +}; + +static inline int dx(int c) +{ + if (likely(c >= '0' && c <= '9')) + return c - '0'; + if (likely(c >= 'a' && c <= 'f')) + return c - 'a' + 0xa; + if (c >= 'A' && c <= 'F') + return c - 'A' + 0xa; + return -1; +} + +char *blob_cstr_dup(blob_t b) +{ + char *p; + + if (blob_is_null(b)) + return NULL; + + p = malloc(b.len+1); + if (p != NULL) { + memcpy(p, b.ptr, b.len); + p[b.len] = 0; + } + return p; +} + +blob_t blob_dup(blob_t b) +{ + blob_t p; + + if (blob_is_null(b)) + return BLOB_NULL; + + p.ptr = malloc(b.len); + if (p.ptr != NULL) { + memcpy(p.ptr, b.ptr, b.len); + p.len = b.len; + } else { + p.len = 0; + } + return p; +} + +int blob_cmp(blob_t a, blob_t b) +{ + if (a.len != b.len) + return a.len - b.len; + return memcmp(a.ptr, b.ptr, a.len); +} + +unsigned long blob_inet_addr(blob_t b) +{ + unsigned long ip = 0; + int i; + + for (i = 0; i < 3; i++) { + ip += blob_pull_uint(&b, 10); + ip <<= 8; + if (!blob_pull_matching(&b, BLOB_STR("."))) + return 0; + } + ip += blob_pull_uint(&b, 10); + if (b.len != 0) + return 0; + return htonl(ip); +} + + +blob_t blob_pushed(blob_t buffer, blob_t left) +{ + if (buffer.ptr + buffer.len != left.ptr + left.len) + return BLOB_NULL; + return BLOB_PTR_LEN(buffer.ptr, left.ptr - buffer.ptr); +} + +void blob_push(blob_t *b, blob_t d) +{ + if (b->len >= d.len) { + memcpy(b->ptr, d.ptr, d.len); + b->ptr += d.len; + b->len -= d.len; + } else { + *b = BLOB_NULL; + } +} + +void blob_push_lower(blob_t *b, blob_t d) +{ + int i; + + if (b->len < d.len) { + *b = BLOB_NULL; + return; + } + for (i = 0; i < d.len; i++) + b->ptr[i] = tolower(d.ptr[i]); + b->ptr += d.len; + b->len -= d.len; +} + +void blob_push_byte(blob_t *b, unsigned char byte) +{ + if (b->len) { + b->ptr[0] = byte; + b->ptr ++; + b->len --; + } else { + *b = BLOB_NULL; + } +} + +void blob_push_uint(blob_t *to, unsigned int value, int radix) +{ + char buf[64]; + char *ptr = &buf[sizeof(buf)-1]; + + if (value == 0) { + blob_push_byte(to, '0'); + return; + } + + while (value != 0) { + *(ptr--) = xd[value % radix]; + value /= radix; + } + + blob_push(to, BLOB_PTR_PTR(ptr+1, &buf[sizeof(buf)-1])); +} + +void blob_push_ctime(blob_t *to, time_t t) +{ + char buf[128]; + blob_t b; + + ctime_r(&t, buf); + b = BLOB_STRLEN(buf); + b.len--; + blob_push(to, b); +} + +void blob_push_hexdump(blob_t *to, blob_t binary) +{ + char *d; + int i; + + if (blob_is_null(*to)) + return; + + if (to->len < binary.len * 2) { + *to = BLOB_NULL; + return; + } + + for (i = 0, d = to->ptr; i < binary.len; i++) { + *(d++) = xd[(binary.ptr[i] >> 4) & 0xf]; + *(d++) = xd[binary.ptr[i] & 0xf]; + } + to->ptr = d; + to->len -= binary.len * 2; +} + +void blob_push_urldecode(blob_t *to, blob_t url) +{ + blob_t b, orig = *to; + + do { + blob_pull_matching(&url, BLOB_STR("/")); + b = blob_pull_cspn(&url, BLOB_STR("/")); + if (blob_is_null(b) || blob_cmp(b, BLOB_STR(".")) == 0) { + /* skip '.' or two consecutive / */ + } else if (blob_cmp(b, BLOB_STR("..")) == 0) { + /* go up one path component */ + blob_shrink_tail(to, blob_pushed(orig, b), '/'); + } else { + /* copy decoded; FIXME decode percent encoding */ + blob_push_byte(to, '/'); + blob_push(to, b); + } + } while (!blob_is_null(url)); +} + +void blob_push_urlencode(blob_t *to, blob_t url) +{ + unsigned char c; + int i; + + for (i = 0; i < url.len; i++) { + c = url.ptr[i]; + + if (c <= 127 && (chartype[c] & CTYPE_UNRESERVED)) { + blob_push_byte(to, c); + } else { + blob_push_byte(to, '%'); + blob_push_uint(to, c, 16); + } + } +} + +blob_t blob_pull(blob_t *b, int len) +{ + blob_t r; + + if (b->len >= len) { + r = BLOB_PTR_LEN(b->ptr, len); + b->ptr += len; + b->len -= len; + return r; + } + *b = BLOB_NULL; + return BLOB_NULL; +} + +void blob_pull_skip(blob_t *b, int len) +{ + if (b->len >= len) { + b->ptr += len; + b->len -= len; + } else { + *b = BLOB_NULL; + } +} + +int blob_pull_matching(blob_t *b, blob_t e) +{ + if (b->len < e.len) + return 0; + if (memcmp(b->ptr, e.ptr, e.len) != 0) + return 0; + b->ptr += e.len; + b->len -= e.len; + return 1; +} + +unsigned int blob_pull_uint(blob_t *b, int radix) +{ + unsigned int val; + int ch; + + val = 0; + while (b->len && b->ptr[0] != 0) { + ch = dx(b->ptr[0]); + if (ch < 0 || ch >= radix) + break; + val *= radix; + val += ch; + + b->ptr++; + b->len--; + } + + return val; +} + +blob_t blob_pull_spn(blob_t *b, const blob_t reject) +{ + blob_t t = *b; + int i; + + for (i = 0; i < t.len; i++) { + if (memchr(reject.ptr, t.ptr[i], reject.len) == NULL) { + *b = BLOB_PTR_LEN(t.ptr + i, t.len - i); + return BLOB_PTR_LEN(t.ptr, i); + } + } + + *b = BLOB_NULL; + return t; +} + +blob_t blob_pull_cspn(blob_t *b, const blob_t reject) +{ + blob_t t = *b; + int i; + + for (i = 0; i < t.len; i++) { + if (memchr(reject.ptr, t.ptr[i], reject.len) != NULL) { + *b = BLOB_PTR_LEN(t.ptr + i, t.len - i); + return BLOB_PTR_LEN(t.ptr, i); + } + } + + *b = BLOB_NULL; + return t; +} + +blob_t blob_expand_head(blob_t *b, blob_t limits, unsigned char sep) +{ + blob_t t = *b; + blob_t r; + + if (t.ptr < limits.ptr || t.ptr+t.len > limits.ptr+limits.len) + return BLOB_NULL; + while (t.ptr > limits.ptr && t.ptr[-1] == sep) + t.ptr--, t.len++; + + r.ptr = t.ptr; + r.len = 0; + while (t.ptr > limits.ptr && t.ptr[-1] != sep) { + t.ptr--, t.len++; + r.ptr--, r.len++; + } + *b = t; + return r; +} + +blob_t blob_expand_tail(blob_t *b, blob_t limits, unsigned char sep) +{ + blob_t t = *b; + blob_t r; + + if (t.ptr < limits.ptr || t.ptr+t.len > limits.ptr+limits.len) + return BLOB_NULL; + while (t.ptr + t.len < limits.ptr + limits.len && t.ptr[t.len] == sep) + t.len++; + + r.ptr = t.ptr + t.len; + r.len = 0; + while (t.ptr + t.len < limits.ptr + limits.len && t.ptr[t.len] != sep) { + t.len++; + r.len++; + } + *b = t; + return r; +} + +blob_t blob_shrink_tail(blob_t *b, blob_t limits, unsigned char sep) +{ + blob_t t = *b; + blob_t r; + + if (t.ptr <= limits.ptr || t.ptr+t.len > limits.ptr+limits.len) + return BLOB_NULL; + while (t.len && t.ptr[t.len-1] == sep) + t.len--; + + r.ptr = t.ptr; + r.len = 0; + while (t.len && t.ptr[t.len-1] != sep) { + t.len--; + r.ptr--, r.len++; + } + *b = t; + return r; +} diff --git a/src/blob.h b/src/blob.h new file mode 100644 index 0000000..76afed7 --- /dev/null +++ b/src/blob.h @@ -0,0 +1,63 @@ +#ifndef BLOB_H +#define BLOB_H + +#include + +#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0])) + +#if defined __GNUC__ && __GNUC__ == 2 && __GNUC_MINOR__ < 96 +#define __builtin_expect(x, expected_value) (x) +#endif + +#ifndef likely +#define likely(x) __builtin_expect((!!(x)),1) +#endif + +#ifndef unlikely +#define unlikely(x) __builtin_expect((!!(x)),0) +#endif + +typedef struct blob { + char *ptr; + unsigned int len; +} blob_t; + +#define BLOB_PTR_LEN(ptr,len) (blob_t){(void*)(ptr), (len)} +#define BLOB_PTR_PTR(beg,end) BLOB_PTR_LEN((beg),(end)-(beg)+1) +#define BLOB_BUF(buf) (blob_t){(void*)(buf), sizeof(buf)} +#define BLOB_STRLEN(str) (blob_t){(str), strlen(str)} +#define BLOB_STR_INIT(str) {(str), sizeof(str)-1} +#define BLOB_STR(str) (blob_t) BLOB_STR_INIT(str) +#define BLOB_NULL (blob_t){NULL, 0} + +static inline int blob_is_null(blob_t b) +{ + return b.len == 0; +} + +char *blob_cstr_dup(blob_t b); +blob_t blob_dup(blob_t b); +int blob_cmp(blob_t a, blob_t b); +unsigned long blob_inet_addr(blob_t buf); + +blob_t blob_pushed(blob_t buffer, blob_t left); +void blob_push(blob_t *b, blob_t d); +void blob_push_lower(blob_t *b, blob_t d); +void blob_push_byte(blob_t *b, unsigned char byte); +void blob_push_uint(blob_t *to, unsigned int value, int radix); +void blob_push_ctime(blob_t *to, time_t t); +void blob_push_hexdump(blob_t *to, blob_t binary); +void blob_push_urldecode(blob_t *to, blob_t url); +void blob_push_urlencode(blob_t *to, blob_t url); +blob_t blob_pull(blob_t *b, int len); +void blob_pull_skip(blob_t *b, int len); +int blob_pull_matching(blob_t *b, blob_t e); +unsigned int blob_pull_uint(blob_t *b, int radix); +blob_t blob_pull_spn(blob_t *b, const blob_t spn); +blob_t blob_pull_cspn(blob_t *b, const blob_t cspn); + +blob_t blob_expand_head(blob_t *b, blob_t limits, unsigned char sep); +blob_t blob_expand_tail(blob_t *b, blob_t limits, unsigned char sep); +blob_t blob_shrink_tail(blob_t *b, blob_t limits, unsigned char sep); + +#endif diff --git a/src/filterdb.c b/src/filterdb.c new file mode 100644 index 0000000..d3f4c6a --- /dev/null +++ b/src/filterdb.c @@ -0,0 +1,157 @@ +#include +#include +#include +#include +#include + +#include "filterdb.h" + +#define PAGE_SIZE 4096 +#define ALIGN(s,a) (((s) + a - 1) & ~(a - 1)) + +const char *sqdb_section_names[SQDB_SECTION_MAX] = { + [SQDB_SECTION_STRINGS] = "strings", + [SQDB_SECTION_CATEGORIES] = "categories", + [SQDB_SECTION_INDEX] = "index", + [SQDB_SECTION_INDEX_MPH] = "index_mph", + [SQDB_SECTION_KEYWORD] = "keyword", + [SQDB_SECTION_KEYWORD_MPH] = "keyword_mph", +}; + +static int sqdb_allocate(struct sqdb *db, size_t s, int wr) +{ + size_t old_size, new_size; + void *base; + int prot = PROT_READ; + + old_size = db->file_length; + new_size = ALIGN(db->file_length + s, PAGE_SIZE); + + if (new_size == ALIGN(db->file_length, PAGE_SIZE)) { + db->file_length += s; + return old_size; + } + + if (wr && ftruncate(db->fd, new_size) < 0) + return -1; + + if (db->mmap_base == NULL) { + if (wr) + prot |= PROT_WRITE; + base = mmap(NULL, new_size, prot, MAP_SHARED, db->fd, 0); + } else { + base = mremap(db->mmap_base, ALIGN(old_size, PAGE_SIZE), + new_size, MREMAP_MAYMOVE); + } + if (base == MAP_FAILED) + return -1; + + db->mmap_base = base; + db->file_length += ALIGN(s, 16); + + return old_size; +} + +int sqdb_open(struct sqdb *db, const char *fn) +{ + struct stat st; + + db->fd = open(fn, O_RDONLY); + if (db->fd < 0) + return -1; + + fstat(db->fd, &st); + + db->file_length = 0; + db->mmap_base = NULL; + sqdb_allocate(db, st.st_size, 0); + + return 0; +} + +int sqdb_create(struct sqdb *db, const char *fn) +{ + struct sqdb_header *hdr; + int rc; + + db->fd = open(fn, O_CREAT | O_TRUNC | O_RDWR, 0666); + if (db->fd < 0) + return -1; + + db->file_length = 0; + db->mmap_base = NULL; + + rc = sqdb_allocate(db, sizeof(struct sqdb_header), 1); + if (rc < 0) { + close(db->fd); + return rc; + } + + hdr = db->mmap_base; + strcpy(hdr->description, "Squark Filtering Database"); + hdr->version = 1; + hdr->magic = 0xdbdbdbdb; + hdr->num_sections = SQDB_SECTION_MAX; + + return 0; +} + +int sqdb_open(struct sqdb *db, const char *fn); + +void sqdb_close(struct sqdb *db) +{ + if (db->mmap_base) + munmap(db->mmap_base, ALIGN(db->file_length, PAGE_SIZE)); + close(db->fd); +} + +void *sqdb_section_create(struct sqdb *db, int id, uint32_t size) +{ + struct sqdb_header *hdr; + size_t pos; + + hdr = db->mmap_base; + if (hdr->section[id].offset || hdr->section[id].length) + return NULL; + + pos = sqdb_allocate(db, size, 1); + if (pos < 0) + return NULL; + + /* sqdb_allocate can remap mmap_base */ + hdr = db->mmap_base; + hdr->section[id].offset = pos; + hdr->section[id].length = size; + + return db->mmap_base + pos; +} + +void *sqdb_section_get(struct sqdb *db, int id, uint32_t *size) +{ + struct sqdb_header *hdr = db->mmap_base; + + if (hdr->section[id].offset == 0) + return NULL; + + if (size) + *size = hdr->section[id].length; + + return db->mmap_base + hdr->section[id].offset; +} + +blob_t sqdb_get_string_literal(struct sqdb *db, uint32_t encoded_ptr) +{ + unsigned char *ptr; + unsigned int len, off; + + ptr = sqdb_section_get(db, SQDB_SECTION_STRINGS, NULL); + if (ptr == NULL) + return BLOB_NULL; + + off = encoded_ptr >> SQDB_LENGTH_BITS; + len = encoded_ptr & ((1 << SQDB_LENGTH_BITS) - 1); + if (len == 0) + len = ptr[off++]; + + return BLOB_PTR_LEN(ptr + off, len); +} diff --git a/src/filterdb.h b/src/filterdb.h new file mode 100644 index 0000000..2d16572 --- /dev/null +++ b/src/filterdb.h @@ -0,0 +1,59 @@ +#ifndef FILTERDB_H +#define FILTERDB_H + +#include +#include +#include "blob.h" + +#define SQDB_LENGTH_BITS 5 + +#define SQDB_SECTION_STRINGS 0 +#define SQDB_SECTION_CATEGORIES 1 +#define SQDB_SECTION_INDEX 2 +#define SQDB_SECTION_INDEX_MPH 3 +#define SQDB_SECTION_KEYWORD 4 +#define SQDB_SECTION_KEYWORD_MPH 5 +#define SQDB_SECTION_MAX 8 + +struct sqdb { + int fd; + void * mmap_base; + size_t file_length; +}; + +struct sqdb_section { + uint32_t offset; + uint32_t length; +}; + +struct sqdb_header { + char description[116]; + uint32_t num_sections; + uint32_t version; + uint32_t magic; + struct sqdb_section section[SQDB_SECTION_MAX]; +}; + +#define SQDB_PARENT_ROOT 0xffffff +#define SQDB_PARENT_IPV4 0xfffffe + +struct sqdb_index_entry { + uint32_t has_subdomains : 1; + uint32_t has_paths : 1; + uint32_t category : 6; + uint32_t parent : 24; + uint32_t component; +}; + + +const char *sqdb_section_names[SQDB_SECTION_MAX]; + +int sqdb_create(struct sqdb *db, const char *fn); +int sqdb_open(struct sqdb *db, const char *fn); +void sqdb_close(struct sqdb *db); + +void *sqdb_section_create(struct sqdb *db, int id, uint32_t size); +void *sqdb_section_get(struct sqdb *db, int id, uint32_t *size); +blob_t sqdb_get_string_literal(struct sqdb *db, uint32_t encoded_ptr); + +#endif diff --git a/src/lua-squarkdb.c b/src/lua-squarkdb.c new file mode 100644 index 0000000..5a30848 --- /dev/null +++ b/src/lua-squarkdb.c @@ -0,0 +1,343 @@ +#include + +#include +#include +#include + +#include + +#include "filterdb.h" + +#define SQUARKDB_META "squarkdb" + +struct sqdb *Lsqdb_checkarg(lua_State *L, int index) +{ + struct sqdb *db; + + luaL_checktype(L, index, LUA_TUSERDATA); + db = (struct sqdb *) luaL_checkudata(L, index, SQUARKDB_META); + if (db == NULL) + luaL_typerror(L, index, SQUARKDB_META); + return db; +} + +static int Lsqdb_new(lua_State *L) +{ + struct sqdb *db; + const char *fn; + + fn = luaL_checklstring(L, 1, NULL); + + db = (struct sqdb *) lua_newuserdata(L, sizeof(struct sqdb)); + luaL_getmetatable(L, SQUARKDB_META); + lua_setmetatable(L, -2); + + if (sqdb_create(db, fn) < 0) + luaL_error(L, "Failed to create SquarkDB file '%s'", fn); + + return 1; +} + +static int Lsqdb_destroy(lua_State *L) +{ + struct sqdb *db; + + db = Lsqdb_checkarg(L, 1); + sqdb_close(db); + + return 1; +} + + +struct ioa_data { + lua_State *main; + lua_State *thread; +}; + +static void ioa_rewind(void *data) +{ + struct ioa_data *ioa = (struct ioa_data *) data; + + /* pop previous thread */ + lua_pop(ioa->main, 1); + + /* create a new lua thread */ + ioa->thread = lua_newthread(ioa->main); + lua_pushvalue(ioa->main, -2); /* copy function to top */ + lua_xmove(ioa->main, ioa->thread, 1); /* move function from L to NL */ +} + +static cmph_uint32 ioa_count_keys(void *data) +{ + struct ioa_data *ioa = (struct ioa_data *) data; + lua_State *NL; + cmph_uint32 cnt = 0; + + NL = lua_newthread(ioa->main); + lua_pushvalue(ioa->main, -2); /* copy function to top */ + lua_xmove(ioa->main, NL, 1); /* move function from L to NL */ + + do { + cnt++; + lua_settop(NL, 1); + } while (lua_resume(NL, 0) == LUA_YIELD); + ioa_rewind(data); + + return cnt - 1; +} + +static int ioa_read(void *data, char **key, cmph_uint32 *len) +{ + struct ioa_data *ioa = (struct ioa_data *) data; + lua_State *L = ioa->thread; + size_t l; + + /* get next key from lua thread */ + lua_settop(L, 1); + if (lua_resume(L, 0) != LUA_YIELD || + !lua_isstring(L, 1)) { + *key = NULL; + *len = 0; + return -1; + } + + *key = (char *) lua_tolstring(L, 1, &l); + *len = l; + + return l; +} + +static void ioa_dispose(void *data, char *key, cmph_uint32 len) +{ + /* LUA takes care of garbage collection */ +} + +static int Lsqdb_hash(lua_State *L) +{ + struct sqdb *db; + void *ptr; + cmph_uint32 hash; + const char *key; + size_t keylen; + + db = Lsqdb_checkarg(L, 1); + key = luaL_checklstring(L, 2, &keylen); + + ptr = sqdb_section_get(db, SQDB_SECTION_INDEX_MPH, NULL); + hash = cmph_search_packed(ptr, key, keylen); + + lua_pushinteger(L, hash); + + return 1; +} + +static int Lsqdb_generate_hash(lua_State *L) +{ + struct sqdb *db; + struct ioa_data ioa; + cmph_config_t *cfg; + cmph_t *cmph; + cmph_io_adapter_t io; + + char *ptr; + cmph_uint32 packed; + + db = Lsqdb_checkarg(L, 1); + luaL_argcheck(L, lua_isfunction(L, 2) && !lua_iscfunction(L, 2), + 2, "Lua function expected"); + + ioa.main = L; + io.data = &ioa; + io.nkeys = ioa_count_keys(io.data); + io.read = ioa_read; + io.dispose = ioa_dispose; + io.rewind = ioa_rewind; + + cfg = cmph_config_new(&io); + if (cfg == NULL) + luaL_error(L, "Failed to create CMPH config"); + + cmph_config_set_algo(cfg, CMPH_CHD); + cmph = cmph_new(cfg); + cmph_config_destroy(cfg); + + if (cmph == NULL) + luaL_error(L, "Failed to create minimal perfect hash"); + + packed = cmph_packed_size(cmph); + ptr = sqdb_section_create(db, SQDB_SECTION_INDEX_MPH, packed); + if (ptr == NULL) { + cmph_destroy(cmph); + luaL_error(L, "Unable to allocation MPH section from SquarkDB"); + } + + cmph_pack(cmph, ptr); + cmph_destroy(cmph); + + lua_pushinteger(L, io.nkeys); + lua_pushinteger(L, packed); + + return 2; +} + +static int Lsqdb_create_index(lua_State *L) +{ + struct sqdb *db; + lua_Integer num_entries; + void *ptr; + + db = Lsqdb_checkarg(L, 1); + num_entries = luaL_checkinteger(L, 2); + + ptr = sqdb_section_create(db, SQDB_SECTION_INDEX, sizeof(struct sqdb_index_entry) * num_entries); + if (ptr == NULL) + luaL_error(L, "Failed to create INDEX section"); + + return 0; +} + +static int Lsqdb_assign_index(lua_State *L) +{ + struct sqdb *db; + size_t size; + lua_Integer idx; + struct sqdb_index_entry *ptr; + + db = Lsqdb_checkarg(L, 1); + idx = luaL_checkinteger(L, 2); + + ptr = sqdb_section_get(db, SQDB_SECTION_INDEX, &size); + if (size <= 0 || idx * sizeof(struct sqdb_index_entry) >= size) + luaL_error(L, "Bad index assignment (idx=%d, section size=%d)", idx, size); + + ptr += idx; + if (ptr->component != 0) + luaL_error(L, "Index entry %d has been already assigned", idx); + + ptr->category = luaL_checkinteger(L, 3); + ptr->has_subdomains = lua_toboolean(L, 4); + ptr->has_paths = lua_toboolean(L, 5); + ptr->component = luaL_checkinteger(L, 6); + ptr->parent = luaL_checkinteger(L, 7); + + return 0; +} + +static int Lsqdb_map_strings(lua_State *L) +{ + struct sqdb *db; + const char *str; + unsigned char *ptr; + size_t len, total, pos; + + db = Lsqdb_checkarg(L, 1); + luaL_checktype(L, 2, LUA_TTABLE); + + /* go through the table and count total amount of data */ + total = 0; + lua_pushnil(L); + while (lua_next(L, 2) != 0) { + str = luaL_checklstring(L, -2, &len); + total += len; + if (len >= (1 << SQDB_LENGTH_BITS)) + total++; + lua_pop(L, 1); + } + + /* create string literal section */ + ptr = sqdb_section_create(db, SQDB_SECTION_STRINGS, total); + if (ptr == NULL) + luaL_error(L, "Failed to create string literal section (%d bytes)", total); + + /* populate string literals and return their indices */ + pos = 0; + lua_pushnil(L); + while (lua_next(L, 2) != 0) { + str = lua_tolstring(L, -2, &len); + lua_pop(L, 1); + + /* table[key] = encoded_string_pointer */ + lua_pushvalue(L, -1); + if (len >= (1 << SQDB_LENGTH_BITS)) { + lua_pushinteger(L, pos << SQDB_LENGTH_BITS); + ptr[pos++] = len; + } else { + lua_pushinteger(L, (pos << SQDB_LENGTH_BITS) + len); + } + memcpy(&ptr[pos], str, len); + pos += len; + + lua_rawset(L, 2); + } + + return 0; +} + +static int Lsqdb_write_section(lua_State *L) +{ + struct sqdb *db; + uint32_t *ptr; + const char *section; + int i, tbllen, si = -1; + + db = Lsqdb_checkarg(L, 1); + section = luaL_checkstring(L, 2); + luaL_checktype(L, 3, LUA_TTABLE); + tbllen = lua_objlen(L, 3); + + for (i = 0; sqdb_section_names[i] && i < SQDB_SECTION_MAX; i++) { + if (strcmp(sqdb_section_names[i], section) == 0) { + si = 0; + break; + } + } + if (si < 0) + luaL_error(L, "Section name '%s' is invalid", section); + + ptr = sqdb_section_create(db, i, sizeof(uint32_t) * tbllen); + if (ptr == NULL) + luaL_error(L, "Failed to create section '%s'", section); + + for (i = 0; i < tbllen; i++) { + lua_rawgeti(L, 3, i + 1); + ptr[i] = lua_tointeger(L, -1); + lua_pop(L, 1); + } + + return 0; +} + +static const luaL_reg sqdb_meta_methods[] = { + { "__gc", Lsqdb_destroy }, + { NULL, NULL } +}; + +static const luaL_reg squarkdb_methods[] = { + { "new", Lsqdb_new }, + { "hash", Lsqdb_hash }, + { "generate_hash", Lsqdb_generate_hash }, + { "create_index", Lsqdb_create_index }, + { "assign_index", Lsqdb_assign_index }, + { "map_strings", Lsqdb_map_strings }, + { "write_section", Lsqdb_write_section }, + { NULL, NULL } +}; + +LUALIB_API int luaopen_squarkdb(lua_State *L) +{ + /* Register squarkdb library */ + luaI_openlib(L, "squarkdb", squarkdb_methods, 0); + + /* And metatable for it */ + luaL_newmetatable(L, SQUARKDB_META); + luaI_openlib(L, NULL, sqdb_meta_methods, 0); + lua_pushliteral(L, "__index"); + lua_pushvalue(L, -3); + lua_rawset(L, -3); + lua_pushliteral(L, "__metatable"); + lua_pushvalue(L, -3); + lua_rawset(L, -3); + lua_pop(L, 1); + + return 1; +} diff --git a/src/sqdb-build.lua b/src/sqdb-build.lua new file mode 100755 index 0000000..cd039e2 --- /dev/null +++ b/src/sqdb-build.lua @@ -0,0 +1,335 @@ +#!/usr/bin/lua + +require("squarkdb") + +local all_strings = {} +local all_domains = {} +local all_ips = {} + +local all_categories = {} +local all_categories_by_id = {} +local num_categories = 0 + +local strfind = string.find +local strsub = string.sub +local tinsert = table.insert + +local function strsplit(delimiter, text) + local list = {} + local pos = 1 + --if strfind("", delimiter, 1) then -- this would result in endless loops + -- error("delimiter matches empty string!") + --end + while 1 do + local first, last = strfind(text, delimiter, pos) + if first then -- found? + tinsert(list, strsub(text, pos, first-1)) + pos = last+1 + else + tinsert(list, strsub(text, pos)) + break + end + end + return list +end + +local function account_string(s) + all_strings[s] = true +end + +local function get_category(category_text) + local cat + + cat = all_categories[category_text] + if cat ~= nil then return cat end + + -- start category ID's from zero + cat = { desc=category_text, id=num_categories } + all_categories[category_text] = cat + num_categories = num_categories + 1 + + -- but index them from one + all_categories_by_id[num_categories] = cat + + account_string(category_text) + + return cat +end + +local function get_domain(domain, locked) + local parts, entry, idx, p, child + + parts = strsplit("[.]", domain) + entry = all_domains + for idx=#parts,1,-1 do + p = parts[idx] + if entry.children == nil then + entry.children = {} + end + child = entry.children[p] + if child == nil then + child = {} + entry.children[p] = child + end + if child.locked and not locked then + return nil + end + entry = child + end + return child +end + +local function get_path(domain_entry, path, locked) + local entry, p, n, component + + entry = domain_entry + for n,component in pairs(strsplit("/", path)) do + if entry.paths == nil then + entry.paths = {} + end + p = entry.paths[component] + if p == nil then + p = {} + entry.paths[component] = p + end + if p.locked and not locked then + return nil + end + entry = p + end + return p +end + +local function is_ip_addr(s) + return s:match("^%d+\.%d+\.%d+\.%d+$") +end + +local function read_urls(filename, category, locked) + local fd, url, domain, path, d + + fd = io.open(filename) + if fd == nil then + print("WARNING: File " .. filename .. " does not exist") + return + end + print("Reading " .. filename) + for url in fd:lines() do + url = url:gsub("#.*", "") + url = url:gsub(" *^", "") + url = url:lower() + url = url:gsub("^(www%d*[.])([^.]*[.])", "%2") + domain, path = url:match("([^/]*)/?(.*)") + domain = domain:gsub(":.*", "") + domain = domain:gsub("[.]$", "") -- trailing dot + if domain == "" then + d = nil + elseif not is_ip_addr(domain) then + d = get_domain(domain, locked) + else + d = all_ips[domain] + if d == nil then + d = {} + all_ips[domain] = d + end + end + if d == nil then + --if url ~= "" then + -- print(url .. " ignored due to locked record") + --end + elseif path ~= "" then + if d.category ~= category and #path < 100 and path:match("([?;&])") == nil then + path = path:gsub("^/", "") + path = path:gsub("/$", "") + p = get_path(d, path, locked) + if p ~= nil then + p.category = category + if locked then + p.locked = true + end + end + end + else + if d.category == nil then + d.category = category + if locked then + d.locked = true + end + end + end + end + fd:close() +end + +local function enum_paths(cb, category, path, data) + local fpath, cpath, cdata, cat + + for cpath, cdata in pairs(data) do + fpath = path .. "/" .. cpath + cat = cdata.category or category + cb(fpath, path, cpath, cat, false, cdata.paths) + if cdata.paths then + enum_paths(cb, cat, fpath, cdata.paths) + end + end +end + +local function enum_tree(cb, category, dns, data) + local cdns, cdata, fdns, cat + + if data.paths ~= nil then + enum_paths(cb, category, dns, data.paths) + end + if data.children ~= nil then + for cdns, cdata in pairs(data.children) do + if dns ~= nil then + fdns = cdns .. "." .. dns + else + fdns = cdns + end + cat = cdata.category or category + cb(fdns, dns, cdns, cat, cdata.children, cdata.paths) + enum_tree(cb, cat, fdns, cdata) + end + end +end + +function iptonumber(str) + local num = 0 + for elem in str:gmatch("%d+") do + num = num * 256 + assert(tonumber(elem)) + end + return num +end + +local function enum_all(cb) + local ipaddr, data, category + + -- enumerate all domains + enum_tree(cb, nil, nil, all_domains) + + -- all IP addresses + for ipaddr, data in pairs(all_ips) do + if data.paths ~= nil then + enum_paths(cb, data.category, ipaddr, data.paths) + end + cb(ipaddr, nil, iptonumber(ipaddr), data.category, nil, data.paths) + end +end + +local function prune_paths(paths, category) + local path, pdata, cat + local num_paths = 0 + + for path, pdata in pairs(paths) do + local sub_paths = 0 + + cat = pdata.category or category + if pdata.paths ~= nil then + sub_paths = prune_paths(pdata.paths, cat) + if sub_paths == 0 then + pdata.paths = nil + end + end + if cat == category and sub_paths == 0 then + paths[path] = nil + else + num_paths = num_paths + 1 + account_string(path) + end + end + return num_paths +end + +local function prune_tree(d, pcategory) + local num_childs = 0 + local num_paths = 0 + local cat + + cat = d.category or pcategory + if d.children ~= nil then + for n, child in pairs(d.children) do + if prune_tree(child, cat, n) then + d.children[n] = nil + else + num_childs = num_childs + 1 + account_string(n) + end + end + if num_childs == 0 then + d.children = nil + end + end + --print(name, d.category, category, d.num_paths, num_childs) + if d.paths ~= nil then + num_paths = prune_paths(d.paths, cat) + if num_paths == 0 then + d.paths = nil + end + end + if d.category == pcategory and num_paths == 0 and num_childs == 0 then + --num_pruned_leafs = num_pruned_leafs + 1 + return true + end + return false +end + +local function load_lists(conffile, part) + local line, fields, cat + + for line in io.lines(conffile) do + line = line:gsub("#(.*)", "") + fields = strsplit("[\t ]", line) + if fields[1] == "STOP" then + break + end + if fields[3] then + read_urls("lists/" .. fields[2] .. "list/" .. fields[3] .. "/" .. part, + get_category(fields[1]), + fields[4] == "LOCK") + end + end +end + +-- start by reading in all classification data +get_category("unknown") +load_lists("lists.conf", "domains") +prune_tree(all_domains, nil) +load_lists("lists.conf", "urls") +prune_tree(all_domains, nil) + +-- generate database +local db = squarkdb.new("squark.db") +num_entries = db:generate_hash(function() enum_all(coroutine.yield) end) + +-- write string literals +db:map_strings(all_strings) + +-- map category names and write the category section out +for id, cdata in ipairs(all_categories_by_id) do + all_categories_by_id[id] = all_strings[cdata.desc] +end +db:write_section("categories", all_categories_by_id) + +-- create master index +db:create_index(num_entries) +enum_all( + function(uri, parent_uri, component, category, childs, paths) + if parent_uri == nil and type(component) == "number" then + -- Embedded IPv4 address + db:assign_index(db:hash(uri), + category and category.id or 0, + childs and true or false, + paths and true or false, + component, + -2) + else + -- Regular entry + db:assign_index(db:hash(uri), + category and category.id or 0, + childs and true or false, + paths and true or false, + all_strings[component] or 0, + parent_uri and db:hash(parent_uri) or -1) + end + end +) diff --git a/src/squark-auth-ip.c b/src/squark-auth-ip.c new file mode 100644 index 0000000..3cdea0b --- /dev/null +++ b/src/squark-auth-ip.c @@ -0,0 +1,218 @@ +/* squark-auth-ip.c - Squid User Authentication and Rating Kit + * An external acl helper for Squid which collects authentication + * information about an IP-address from local shared memory database. + * + * Copyright (C) 2010 Timo Teräs + * All rights reserved. + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 as published + * by the Free Software Foundation. See http://www.gnu.org/ for details. + */ + +#include +#include +#include + +#include "blob.h" +#include "authdb.h" +#include "filterdb.h" + +#define DO_LOGIN -1 +#define DO_OVERRIDE -2 +#define DO_PRINT -3 +#define DO_LOGOUT -4 + +static int running = 1; +static struct sqdb db; +static struct authdb adb; +static struct authdb_config adbc; +static blob_t space = BLOB_STR_INIT(" "); +static blob_t lf = BLOB_STR_INIT("\n"); +static time_t now; + +static void handle_line(blob_t line) +{ + char reply[128]; + blob_t b, id, ipaddr; + struct authdb_entry entry; + sockaddr_any addr; + void *token; + int auth_ok = 0; + + id = blob_pull_cspn(&line, space); + blob_pull_spn(&line, space); + ipaddr = blob_pull_cspn(&line, space); + + if (addr_parse(ipaddr, &addr)) { + token = authdb_get(&adb, &addr, &entry, 1); + + if (authdb_check_login(token, &entry, BLOB_NULL, now)) + auth_ok = 1; + } + + b = BLOB_BUF(reply); + blob_push(&b, id); + if (auth_ok) { + blob_push(&b, BLOB_STR(" OK user=")); + blob_push(&b, BLOB_STRLEN(entry.p.login_name)); + blob_push(&b, BLOB_PTR_LEN("\n", 1)); + } else { + blob_push(&b, BLOB_STR(" ERR\n")); + } + + b = blob_pushed(BLOB_BUF(reply), b); + write(STDOUT_FILENO, b.ptr, b.len); +} + +static void read_input(void) +{ + static char buffer[256]; + static blob_t left; + + blob_t b, line; + int r; + + if (blob_is_null(left)) + left = BLOB_BUF(buffer); + + r = read(STDIN_FILENO, left.ptr, left.len); + if (r < 0) + return; + if (r == 0) { + running = 0; + return; + } + left.ptr += r; + left.len -= r; + + now = time(NULL); + + b = blob_pushed(BLOB_BUF(buffer), left); + do { + line = blob_pull_cspn(&b, lf); + if (!blob_pull_matching(&b, lf)) + return; + + handle_line(line); + + if (b.len) { + memcpy(buffer, b.ptr, b.len); + b.ptr = buffer; + } + left = BLOB_PTR_LEN(buffer + b.len, sizeof(buffer) - b.len); + } while (b.len); +} + +#define DUMPPAR(b, name, fn) \ + do { \ + blob_push(b, BLOB_STR("squark_" name "='")); \ + fn; \ + blob_push(b, BLOB_STR("'; ")); \ + } while (0) + +int main(int argc, char **argv) +{ + int opt; + sockaddr_any ipaddr = { .any.sa_family = AF_UNSPEC }; + blob_t ip = BLOB_NULL, username = BLOB_NULL; + + while ((opt = getopt(argc, argv, "i:u:olpL")) != -1) { + switch (opt) { + case 'i': + ip = BLOB_STRLEN(optarg); + if (!addr_parse(ip, &ipaddr)) { + fprintf(stderr, "'%s' does not look like IP-address\n", + optarg); + return 1; + } + break; + case 'u': + username = BLOB_STRLEN(optarg); + break; + case 'o': + running = DO_OVERRIDE; + break; + case 'l': + running = DO_LOGIN; + break; + case 'p': + running = DO_PRINT; + break; + case 'L': + running = DO_LOGOUT; + break; + } + } + + now = time(NULL); + sqdb_open(&db, "/var/lib/squark/squark.db"); + authdb_open(&adb, &adbc, &db); + + if (running < 0) { + struct authdb_entry entry; + void *token; + + if (ipaddr.any.sa_family == AF_UNSPEC) { + fprintf(stderr, "IP-address not specified\n"); + return 2; + } + + token = authdb_get(&adb, &ipaddr, &entry, 1); + if (token == NULL) { + fprintf(stderr, "Failed to get authdb record\n"); + return 3; + } + + switch (running) { + case DO_LOGIN: + if (blob_is_null(username)) { + fprintf(stderr, "Username not specified\n"); + return 2; + } + authdb_clear_entry(&entry); + memcpy(entry.p.login_name, username.ptr, username.len); + authdb_commit_login(token, &entry, now, &adbc); + break; + case DO_OVERRIDE: + if (authdb_check_login(token, &entry, username, now)) + authdb_commit_override(token, &entry, now); + break; + case DO_PRINT: { + char buf[512]; + blob_t b = BLOB_BUF(buf); + + DUMPPAR(&b, "ip_address", + addr_push_hostaddr(&b, &ipaddr)); + DUMPPAR(&b, "username", + blob_push(&b, BLOB_BUF(entry.p.login_name))); + DUMPPAR(&b, "mac_address", + blob_push_hexdump(&b, BLOB_BUF(entry.p.mac_address))); + DUMPPAR(&b, "login_time", + blob_push_ctime(&b, entry.p.login_time)); + DUMPPAR(&b, "activity_time", + blob_push_ctime(&b, entry.last_activity_time)); + DUMPPAR(&b, "override_time", + blob_push_ctime(&b, entry.override_time)); + DUMPPAR(&b, "block_categories", + blob_push_hexdump(&b, BLOB_BUF(&entry.p.block_categories))); + DUMPPAR(&b, "hard_block_categories", + blob_push_hexdump(&b, BLOB_BUF(&entry.p.hard_block_categories))); + blob_push(&b, BLOB_STR("\n")); + b = blob_pushed(BLOB_BUF(buf), b); + fwrite(b.ptr, b.len, 1, stdout); + break; + } + case DO_LOGOUT: + if (authdb_check_login(token, &entry, username, now)) + authdb_commit_logout(token); + break; + } + } else { + while (running) + read_input(); + } + + authdb_close(&adb); + sqdb_close(&db); +} diff --git a/src/squark-auth-snmp.c b/src/squark-auth-snmp.c new file mode 100644 index 0000000..81b846d --- /dev/null +++ b/src/squark-auth-snmp.c @@ -0,0 +1,1152 @@ +/* squark-auth-snmp.c - Squid User Authentication and Rating Kit + * An external acl helper for Squid which collects authentication + * information about an IP-address from switches via SNMP. + * + * Copyright (C) 2010 Timo Teräs + * All rights reserved. + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 as published + * by the Free Software Foundation. See http://www.gnu.org/ for details. + */ + +/* TODO: + * - implement Q-BRIDGE-MIB query + * - map vlan names to vlan index + * - print some usage information + * - poll lldpStatsRemTablesLastChangeTime when doing switch update + * to figure out if lldp info is valid or not + */ + +#include +#include +#include +#include + +#include +#include + +#include "blob.h" +#include "addr.h" +#include "authdb.h" +#include "filterdb.h" + +/* Compile time configurables */ +#define SWITCH_HASH_SIZE 128 +#define PORT_HASH_SIZE 128 +#define CACHE_TIME 120 /* seconds */ + +/* Some helpers */ +#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0])) +#define MAC_LEN 6 + +#define oid_const(oid) (oid), ARRAY_SIZE(oid) +#define oid_blob(b) ((oid *) (b).ptr), ((b).len / sizeof(oid)) + +/* Format specifiers for username type */ +#define FORMAT_CLIENT_IP 0x01 /* %I */ +#define FORMAT_CLIENT_MAC 0x02 /* %M */ +#define FORMAT_SWITCH_NAME 0x04 /* %N */ +#define FORMAT_SWITCH_LOCATION 0x08 /* %L */ +#define FORMAT_PORT_INDEX 0x10 /* %i */ +#define FORMAT_PORT_NAME 0x20 /* %n */ +#define FORMAT_PORT_DESCR 0x40 /* %d */ +#define FORMAT_PORT_WEBAUTH 0x80 /* %w */ + +/* Some info about the switch which we need */ +#define SWITCHF_NO_LLDP 0x01 +#define SWITCHF_BRIDGE_MIB_HAS_VLAN 0x02 + +/* IANA-AddressFamilyNumbers */ +#define IANA_AFN_OTHER 0 +#define IANA_AFN_IPV4 1 +#define IANA_AFN_IPV6 2 + +/* OIDs used by the program */ +static const oid SNMPv2_MIB_sysObjectID[] = + { SNMP_OID_MIB2, 1, 2, 0 }; +static const oid SNMPv2_MIB_sysName[] = + { SNMP_OID_MIB2, 1, 5, 0 }; +static const oid SNMPv2_MIB_sysLocation[] = + { SNMP_OID_MIB2, 1, 6, 0 }; +static const oid IF_MIB_ifDescr[] = + { SNMP_OID_MIB2, 2, 2, 1, 2 }; +static const oid IF_MIB_ifName[] = + { SNMP_OID_MIB2, 31, 1, 1, 1, 1 }; +static const oid IF_MIB_ifStackStatus[] = + { SNMP_OID_MIB2, 31, 1, 2, 1, 3 }; +static const oid IP_MIB_ipNetToPhysicalPhysAddress[] = + { SNMP_OID_MIB2, 4, 35, 1, 4 }; +static const oid BRIDGE_MIB_dot1dTpFdbPort[] = + { SNMP_OID_MIB2, 17, 4, 3, 1, 2 }; +static const oid LLDP_lldpLocSysName[] = + { 1, 0, 8802, 1, 1, 2, 1, 3, 3, 0 }; +static const oid LLDP_lldpRemManAddrIfSubtype[] = + { 1, 0, 8802, 1, 1, 2, 1, 4, 2, 1, 3 }; +static const oid HP_hpicfUsrAuthWebAuthSessionName[] = + { SNMP_OID_ENTERPRISES, 11, 2, 14, 11, 5, 1, 19, 5, 1, 1, 2 }; +static const oid HP_hpicfUsrAuthPortReauthenticate[] = + { SNMP_OID_ENTERPRISES, 11, 2, 14, 11, 5, 1, 19, 2, 1, 1, 4 }; +static const oid SEMI_MIB_hpHttpMgVersion[] = + { SNMP_OID_ENTERPRISES, 11, 2, 36, 1, 1, 2, 6, 0 }; + +/* ----------------------------------------------------------------- */ + +struct switch_info; + +static int num_queries = 0; +static int running = TRUE; +static int kick_out = FALSE; + +static struct sqdb db; +static struct authdb adb; +static struct authdb_config adbc; +static const char *snmp_community = NULL; +static const char *username_format = "%w"; +static struct switch_info *all_switches[SWITCH_HASH_SIZE]; +static struct switch_info *l3_root_dev, *l2_root_dev; +static int l3_if_ndx, l2_vlan_ndx; +static time_t current_time; +static int username_format_flags; + +static const blob_t space = BLOB_STR_INIT(" "); +static const blob_t lf = BLOB_STR_INIT("\n"); + +/* ----------------------------------------------------------------- */ + +#define BLOB_OID(objid) BLOB_BUF(objid) +#define BLOB_OID_DYN(objid,len) BLOB_PTR_LEN(objid, (len) * sizeof(oid)) + +static inline void blob_push_oid(blob_t *b, oid objid) +{ + if (b->len >= sizeof(objid)) { + *((oid*) b->ptr) = objid; + b->ptr += sizeof(oid); + b->len -= sizeof(oid); + } else { + *b = BLOB_NULL; + } +} + +static inline oid blob_pull_oid(blob_t *b) +{ + oid objid; + + if (b->len >= sizeof(objid)) { + objid = *((oid*) b->ptr); + b->ptr += sizeof(oid); + b->len -= sizeof(oid); + } else { + *b = BLOB_NULL; + objid = -1; + } + return objid; +} + +static inline void blob_push_oid_dump(blob_t *b, blob_t d) +{ + int i; + + if (b->len >= d.len * sizeof(oid)) { + for (i = 0; i < d.len; i++) { + *((oid*) b->ptr) = (unsigned char) d.ptr[i]; + b->ptr += sizeof(oid); + b->len -= sizeof(oid); + } + } else { + *b = BLOB_NULL; + } +} + +static inline void blob_pull_oid_dump(blob_t *b, blob_t d) +{ + int i; + + if (b->len >= d.len * sizeof(oid)) { + for (i = 0; i < d.len; i++) { + d.ptr[i] = (unsigned char) *((oid*) b->ptr); + b->ptr += sizeof(oid); + b->len -= sizeof(oid); + } + } else { + *b = BLOB_NULL; + } +} + +/* ----------------------------------------------------------------- */ + +void blob_push_iana_afn(blob_t *b, sockaddr_any *addr) +{ + unsigned char *ptr; + int type=0, len=0; + + switch (addr->any.sa_family) { + case AF_INET: + type = IANA_AFN_IPV4; + len = 4; + ptr = (unsigned char*) &addr->ipv4.sin_addr; + break; + } + if (type == 0 || b->len < len) { + *b = BLOB_NULL; + return; + } + blob_push_oid(b, type); + blob_push_oid(b, len); + blob_push_oid_dump(b, BLOB_PTR_LEN(ptr, len)); +} + +sockaddr_any *blob_pull_iana_afn(blob_t *b, sockaddr_any *addr) +{ + unsigned char *ptr = NULL; + int type, len; + + memset(addr, 0, sizeof(*addr)); + type = blob_pull_oid(b); + len = blob_pull_oid(b); + if (type == IANA_AFN_IPV4 && len == 4) { + addr->ipv4.sin_family = AF_INET; + ptr = (unsigned char*) &addr->ipv4.sin_addr; + } + if (ptr == NULL) { + blob_pull_skip(b, len); + return NULL; + } + blob_pull_oid_dump(b, BLOB_PTR_LEN(ptr, len)); + return addr; +} + +/* ----------------------------------------------------------------- */ + +static void safe_free(void *ptr) +{ + void **pptr = ptr; + if (*pptr != NULL) { + free(*pptr); + *pptr = NULL; + } +} + +struct cache_control { + time_t update_time; + struct auth_context * sleepers; +}; + +struct switch_port_info { + struct switch_port_info * next; + int port; + struct cache_control cache_control; + sockaddr_any link_partner; +}; + +struct switch_info { + struct switch_info * next; + sockaddr_any addr; + netsnmp_session * session; + + struct cache_control cache_control; + int flags; + int info_available; + char * system_name; + char * system_location; + char * system_version; + blob_t system_oid; + + struct switch_port_info * all_ports[PORT_HASH_SIZE]; +}; + +struct auth_context { + char * token; + sockaddr_any addr; + unsigned char mac[MAC_LEN]; + int info_available; + struct switch_info * current_switch; + struct switch_port_info *spi; + int local_port; + int lldp_port[8]; + int num_lldp_ports; + char * port_name; + char * port_descr; + char * webauth_name; + + void (*pending_operation)(struct auth_context *); + struct auth_context * next_sleeper; +}; + +static void cache_update_time(void) +{ + current_time = time(NULL); + adbc_refresh(&adbc, current_time); +} + +static int cache_refresh( + struct cache_control *cc, struct auth_context *auth, + void (*callback)(struct auth_context *auth)) +{ + int ret; + + if (cc->update_time == -1 || + cc->update_time + CACHE_TIME >= current_time) { + callback(auth); + return 0; + } + + auth->pending_operation = callback; + + ret = (cc->sleepers == NULL); + auth->next_sleeper = cc->sleepers; + cc->sleepers = auth; + + return ret; +} + +static void cache_update(struct cache_control *cc) +{ + struct auth_context *auth, *next; + + cc->update_time = current_time; + auth = cc->sleepers; + cc->sleepers = NULL; + for (; auth; auth = next) { + next = auth->next_sleeper; + auth->pending_operation(auth); + } +} + +static void cache_update_manual(struct cache_control *cc) +{ + cache_update(cc); + cc->update_time = -1; +} + +static void switch_info_free(struct switch_info *si) +{ + safe_free(&si->system_name); + safe_free(&si->system_location); + safe_free(&si->system_version); + safe_free(&si->system_oid.ptr); + si->info_available = 0; + si->flags = 0; +} + +struct switch_info *get_switch(sockaddr_any *addr) +{ + struct snmp_session config; + struct switch_info *si; + unsigned int bucket = addr_hash(addr) % ARRAY_SIZE(all_switches); + + for (si = all_switches[bucket]; si != NULL; si = si->next) + if (addr_cmp(&si->addr, addr) == 0) + return si; + + si = calloc(1, sizeof(*si)); + if (si == NULL) + return NULL; + + addr_copy(&si->addr, addr); + + snmp_sess_init(&config); + if (snmp_community != NULL) { + config.version = SNMP_VERSION_2c; + config.community = (unsigned char *) snmp_community; + config.community_len = strlen(snmp_community); + } + config.peername = (char *) addr_print(addr); + si->session = snmp_open(&config); + + si->next = all_switches[bucket]; + all_switches[bucket] = si; + + return si; +} + +struct switch_port_info *get_switch_port(struct switch_info *si, int port) +{ + unsigned int bucket = port % ARRAY_SIZE(si->all_ports); + struct switch_port_info *spi; + + if (si == NULL) + return NULL; + + for (spi = si->all_ports[bucket]; spi != NULL; spi = spi->next) + if (spi->port == port) + return spi; + + spi = calloc(1, sizeof(*spi)); + if (spi == NULL) + return NULL; + + spi->port = port; + spi->next = si->all_ports[bucket]; + si->all_ports[bucket] = spi; + + return spi; +} + +void link_switch(const char *a, int ap, const char *b, int bp) +{ + struct switch_info *sia, *sib; + struct switch_port_info *spia, *spib; + sockaddr_any addr; + + sia = get_switch(addr_parse(BLOB_STRLEN(a), &addr)); + spia = get_switch_port(sia, ap); + + sib = get_switch(addr_parse(BLOB_STRLEN(b), &addr)); + spib = get_switch_port(sib, bp); + + addr_copy(&spia->link_partner, &sib->addr); + addr_copy(&spib->link_partner, &sia->addr); + + cache_update_manual(&spia->cache_control); + cache_update_manual(&spib->cache_control); +} + +static void auth_query_switch_info(struct auth_context *auth); +static void auth_query_lldp(struct auth_context *auth, int root_query); + +static void auth_free(struct auth_context *auth) +{ + safe_free(&auth->token); + safe_free(&auth->port_name); + safe_free(&auth->port_descr); + safe_free(&auth->webauth_name); + free(auth); +} + +int resolve_ifName2ifIndex(struct switch_info *si, blob_t ifName) +{ + netsnmp_pdu *pdu, *response = NULL; + netsnmp_variable_list *vars, *lastvar = NULL; + int rc = -1; + + pdu = snmp_pdu_create(SNMP_MSG_GETBULK); + pdu->non_repeaters = 0; + pdu->max_repetitions = 10; + snmp_add_null_var(pdu, oid_const(IF_MIB_ifName)); + + do { + if (snmp_synch_response(si->session, pdu, &response) != 0) + return -1; + if (response->errstat != SNMP_ERR_NOERROR) + goto done; + + for (vars = response->variables; vars; vars = vars->next_variable) { + lastvar = vars; + + if (vars->name_length < ARRAY_SIZE(IF_MIB_ifName) || + memcmp(vars->name, IF_MIB_ifName, sizeof(IF_MIB_ifName)) != 0) + goto done; + + if (vars->type != ASN_OCTET_STR) + continue; + + if (blob_cmp(ifName, BLOB_PTR_LEN(vars->val.string, vars->val_len)) != 0) + continue; + + rc = vars->name[vars->name_length - 1]; + goto done; + } + + pdu = snmp_pdu_create(SNMP_MSG_GETBULK); + pdu->non_repeaters = 0; + pdu->max_repetitions = 10; + snmp_add_null_var(pdu, lastvar->name, lastvar->name_length); + + snmp_free_pdu(response); + response = NULL; + } while (1); + +done: + if (response) + snmp_free_pdu(response); + return rc; +} + + +static int parse_format(const char *fmt) +{ + int flags = 0; + const char *p = fmt; + + while ((p = strchr(p, '%')) != NULL) { + switch (p[1]) { + case 'I': + flags |= FORMAT_CLIENT_IP; + break; + case 'M': + flags |= FORMAT_CLIENT_MAC; + break; + case 'N': + flags |= FORMAT_SWITCH_NAME; + break; + case 'L': + flags |= FORMAT_SWITCH_LOCATION; + break; + case 'i': + flags |= FORMAT_PORT_INDEX; + break; + case 'n': + flags |= FORMAT_PORT_NAME; + break; + case 'd': + flags |= FORMAT_PORT_DESCR; + break; + case 'w': + flags |= FORMAT_PORT_WEBAUTH; + break; + } + p++; + } + return flags; +} + +static void blob_push_formatted_username( + blob_t *b, const char *fmt, struct auth_context *auth) +{ + const char *o = fmt, *p = fmt; + struct switch_info *si = auth->current_switch; + + while ((p = strchr(p, '%')) != NULL) { + blob_push(b, BLOB_PTR_LEN(o, p - o)); + switch (p[1]) { + case 'I': + blob_push(b, BLOB_STRLEN((char*) addr_print(&auth->addr))); + break; + case 'M': + blob_push_hexdump(b, BLOB_BUF(auth->mac)); + break; + case 'N': + blob_push(b, BLOB_STRLEN(si->system_name)); + break; + case 'L': + blob_push(b, BLOB_STRLEN(si->system_location)); + break; + case 'i': + blob_push_uint(b, auth->local_port, 10); + break; + case 'n': + blob_push(b, BLOB_STRLEN(auth->port_name)); + break; + case 'd': + blob_push(b, BLOB_STRLEN(auth->port_descr)); + break; + case 'w': + blob_push(b, BLOB_STRLEN(auth->webauth_name)); + break; + default: + o = p; + p++; + continue; + } + p += 2; + o = p; + } + blob_push(b, BLOB_STRLEN((char*) o)); +} + +static int auth_ok(struct auth_context *auth) +{ + return (auth->info_available & username_format_flags) == username_format_flags; +} + +static void auth_completed(struct auth_context *auth) +{ + char tmp[256]; + void *token; + struct authdb_entry entry; + blob_t b = BLOB_BUF(tmp), un; + + token = authdb_get(&adb, &auth->addr, &entry, 1); + authdb_clear_entry(&entry); + + blob_push(&b, BLOB_STRLEN(auth->token)); + if (auth_ok(auth)) { + if (token != NULL) { + un = BLOB_BUF(entry.p.login_name); + blob_push_formatted_username(&un, username_format, auth); + memcpy(entry.p.mac_address, auth->mac, MAC_LEN); + entry.p.switch_ip = auth->current_switch->addr; + entry.p.switch_port = auth->local_port; + authdb_commit_login(token, &entry, current_time, &adbc); + } + + blob_push(&b, BLOB_STR(" OK user=")); + blob_push_formatted_username(&b, username_format, auth); + blob_push(&b, BLOB_PTR_LEN("\n", 1)); + } else { + if (token != NULL) + authdb_commit_logout(token); + blob_push(&b, BLOB_STR(" ERR\n")); + } + b = blob_pushed(BLOB_BUF(tmp), b); + write(STDOUT_FILENO, b.ptr, b.len); + + auth_free(auth); + num_queries--; +} + +static void auth_talk_snmp(struct auth_context *auth, netsnmp_session *s, netsnmp_pdu *pdu, netsnmp_callback callback) +{ + if (snmp_async_send(s, pdu, callback, auth) == 0) { + snmp_free_pdu(pdu); + auth_completed(auth); + } +} + +static void cache_talk_snmp(struct cache_control *cc, netsnmp_session *s, netsnmp_pdu *pdu, netsnmp_callback callback, struct auth_context *auth) +{ + if (snmp_async_send(s, pdu, callback, auth) == 0) { + snmp_free_pdu(pdu); + cache_update(cc); + } +} + +static blob_t var_parse_type(netsnmp_variable_list **varptr, int asn_tag) +{ + netsnmp_variable_list *var = *varptr; + if (var == NULL) + return BLOB_NULL; + + *varptr = var->next_variable; + if (var->type != asn_tag) + return BLOB_NULL; + + return BLOB_PTR_LEN(var->val.string, var->val_len); +} + +static void auth_force_reauthentication(struct auth_context *auth) +{ + struct switch_info *si = auth->current_switch; + netsnmp_pdu *pdu; + oid query_oids[ARRAY_SIZE(HP_hpicfUsrAuthPortReauthenticate)+1]; + blob_t b = BLOB_BUF(query_oids); + long one = 1; + + pdu = snmp_pdu_create(SNMP_MSG_SET); + blob_push(&b, BLOB_OID(HP_hpicfUsrAuthPortReauthenticate)); + blob_push_oid(&b, auth->local_port); + b = blob_pushed(BLOB_OID(query_oids), b); + + snmp_pdu_add_variable(pdu, oid_blob(b), ASN_INTEGER, + (u_char *) &one, sizeof(one)); + + /* Send asynchornously - ignore response */ + if (snmp_send(si->session, pdu) == 0) + snmp_free_pdu(pdu); +} + +static int auth_handle_portinfo_reply(int oper, netsnmp_session *s, int reqid, netsnmp_pdu *resp, void *data) +{ + struct auth_context *auth = data; + netsnmp_variable_list *var; + + if (oper != NETSNMP_CALLBACK_OP_RECEIVED_MESSAGE) + goto done; + + var = resp->variables; + if (username_format_flags & FORMAT_PORT_NAME) + auth->port_name = blob_cstr_dup(var_parse_type(&var, ASN_OCTET_STR)); + if (auth->port_name) + auth->info_available |= FORMAT_PORT_NAME; + if (username_format_flags & FORMAT_PORT_DESCR) + auth->port_descr = blob_cstr_dup(var_parse_type(&var, ASN_OCTET_STR)); + if (auth->port_descr) + auth->info_available |= FORMAT_PORT_DESCR; + if (username_format_flags & FORMAT_PORT_WEBAUTH) + auth->webauth_name = blob_cstr_dup(var_parse_type(&var, ASN_OCTET_STR)); + if (auth->webauth_name) + auth->info_available |= FORMAT_PORT_WEBAUTH; + +done: + if (kick_out && auth_ok(auth)) + auth_force_reauthentication(auth); + + auth_completed(auth); + return 1; +} + +static void auth_query_port_info(struct auth_context *auth) +{ + struct switch_info *si = auth->current_switch; + netsnmp_pdu *pdu; + oid query_oids[MAX_OID_LEN]; + blob_t query; + + if (auth_ok(auth)) { + auth_completed(auth); + return; + } + + pdu = snmp_pdu_create(SNMP_MSG_GET); + if (username_format_flags & FORMAT_PORT_NAME) { + query = BLOB_OID(query_oids); + blob_push(&query, BLOB_OID(IF_MIB_ifName)); + blob_push_oid(&query, auth->local_port); + query = blob_pushed(BLOB_OID(query_oids), query); + snmp_add_null_var(pdu, oid_blob(query)); + } + if (username_format_flags & FORMAT_PORT_DESCR) { + query = BLOB_OID(query_oids); + blob_push(&query, BLOB_OID(IF_MIB_ifDescr)); + blob_push_oid(&query, auth->local_port); + query = blob_pushed(BLOB_OID(query_oids), query); + snmp_add_null_var(pdu, oid_blob(query)); + } + if (username_format_flags & FORMAT_PORT_WEBAUTH) { + query = BLOB_OID(query_oids); + blob_push(&query, BLOB_OID(HP_hpicfUsrAuthWebAuthSessionName)); + blob_push_oid(&query, auth->local_port); + blob_push_oid_dump(&query, BLOB_BUF(auth->mac)); + query = blob_pushed(BLOB_OID(query_oids), query); + snmp_add_null_var(pdu, oid_blob(query)); + } + auth_talk_snmp(auth, si->session, pdu, auth_handle_portinfo_reply); +} + +static int auth_handle_lldp_reply(int oper, netsnmp_session *s, int reqid, netsnmp_pdu *resp, void *data) +{ + struct auth_context *auth = data; + struct switch_port_info *spi = auth->spi; + netsnmp_variable_list *var = resp->variables; + blob_t res; + int i; + + if (oper != NETSNMP_CALLBACK_OP_RECEIVED_MESSAGE) + goto fail; + + /* print_variable(var->name, var->name_length, var); */ + + for (i = 0; i < auth->num_lldp_ports; i++) { + if (var == NULL) + goto fail; + if (var->type != ASN_INTEGER) + continue; + /* INDEX: TimeFilter, Port, Idx, Family, Addr */ + res = BLOB_OID_DYN(var->name, var->name_length); + if (blob_pull_matching(&res, BLOB_OID(LLDP_lldpRemManAddrIfSubtype)) && + blob_pull_oid(&res) == 0 && + blob_pull_oid(&res) == auth->lldp_port[i]) { + /* We have mathing LLDP neighbor */ + blob_pull_oid(&res); + blob_pull_iana_afn(&res, &spi->link_partner); + cache_update(&spi->cache_control); + return 1; + } + + var = var->next_variable; + } + auth->num_lldp_ports = 0; + for (; var; var = var->next_variable) { + if (var->type != ASN_INTEGER) + break; + /* print_variable(var->name, var->name_length, var); */ + res = BLOB_OID_DYN(var->name, var->name_length); + if (!blob_pull_matching(&res, BLOB_OID(IF_MIB_ifStackStatus))) + break; + if (blob_pull_oid(&res) != auth->local_port) + break; + auth->lldp_port[auth->num_lldp_ports++] = blob_pull_oid(&res); + if (auth->num_lldp_ports >= ARRAY_SIZE(auth->lldp_port)) + break; + } + if (auth->num_lldp_ports) { + auth_query_lldp(auth, FALSE); + return 1; + } +fail: + cache_update(&spi->cache_control); + return 1; +} + +static void auth_query_lldp(struct auth_context *auth, int root_query) +{ + struct switch_info *si = auth->current_switch; + struct switch_port_info *spi = auth->spi; + netsnmp_pdu *pdu; + oid query_oids[MAX_OID_LEN]; + blob_t query; + int i; + + /* printf("Query LLDP info for %s:%d\n", addr_print(&si->addr), spi->port); */ + + if (si->flags & SWITCHF_NO_LLDP) { + memset(&spi->link_partner, 0, sizeof(spi->link_partner)); + cache_update(&spi->cache_control); + return; + } + + if (root_query) { + auth->num_lldp_ports = 1; + auth->lldp_port[0] = auth->local_port; + } + + pdu = snmp_pdu_create(SNMP_MSG_GETBULK); + pdu->non_repeaters = auth->num_lldp_ports; + pdu->max_repetitions = 8; + + for (i = 0; i < auth->num_lldp_ports; i++) { + /* Query LLDP neighbor. lldpRemManAddrTable is INDEXed with + * [TimeFilter, LocalPort, Index, AddrSubType, Addr] */ + query = BLOB_OID(query_oids); + blob_push(&query, BLOB_OID(LLDP_lldpRemManAddrIfSubtype)); + blob_push_oid(&query, 0); + blob_push_oid(&query, auth->lldp_port[i]); + query = blob_pushed(BLOB_OID(query_oids), query); + snmp_add_null_var(pdu, oid_blob(query)); + } + + if (root_query) { + /* Query interface stacking in case this is aggregated trunk: + * IF-MIB::ifStackStatus.. */ + query = BLOB_OID(query_oids); + blob_push(&query, BLOB_OID(IF_MIB_ifStackStatus)); + blob_push_oid(&query, auth->local_port); + query = blob_pushed(BLOB_OID(query_oids), query); + snmp_add_null_var(pdu, oid_blob(query)); + } + + cache_talk_snmp(&spi->cache_control, si->session, pdu, auth_handle_lldp_reply, auth); +} + +static void auth_check_spi(struct auth_context *auth) +{ + struct switch_port_info *spi = auth->spi; + + if (addr_len(&spi->link_partner) != 0) { + auth->current_switch = get_switch(&spi->link_partner); + auth_query_switch_info(auth); + } else { + auth_query_port_info(auth); + } +} + +static int auth_handle_fib_reply(int oper, netsnmp_session *s, int reqid, netsnmp_pdu *resp, void *data) +{ + struct auth_context *auth = data; + struct switch_info *si = auth->current_switch; + struct switch_port_info *spi; + netsnmp_variable_list *var; + + if (oper != NETSNMP_CALLBACK_OP_RECEIVED_MESSAGE) + goto failed; + + var = resp->variables; + /* print_variable(var->name, var->name_length, var); */ + + if (var->type != ASN_INTEGER) + goto failed; + + auth->local_port = *var->val.integer; + auth->info_available |= FORMAT_PORT_INDEX; + auth->spi = spi = get_switch_port(si, auth->local_port); + if (cache_refresh(&spi->cache_control, auth, auth_check_spi)) + auth_query_lldp(auth, TRUE); + return 1; + + /* No further info available */ +failed: + auth_completed(auth); + return 1; +} + +static void auth_query_fib(struct auth_context *auth) +{ + oid query_oids[MAX_OID_LEN]; + blob_t query; + struct switch_info *si = auth->current_switch; + netsnmp_pdu *pdu; + + auth->info_available |= si->info_available; + + /* printf("Probing switch %s\n", addr_print(&si->addr)); */ + + pdu = snmp_pdu_create(SNMP_MSG_GET); + + /* FIXME: Implement Q-BRIDGE-MIB query too. */ + + /* BRIDGE-MIB::dot1dTpFdbPort. = INTEGER: port */ + query = BLOB_OID(query_oids); + blob_push(&query, BLOB_OID(BRIDGE_MIB_dot1dTpFdbPort)); + if (si->flags & SWITCHF_BRIDGE_MIB_HAS_VLAN) + blob_push_oid(&query, l2_vlan_ndx); + blob_push_oid_dump(&query, BLOB_BUF(auth->mac)); + query = blob_pushed(BLOB_OID(query_oids), query); + snmp_add_null_var(pdu, oid_blob(query)); + + auth_talk_snmp(auth, si->session, pdu, auth_handle_fib_reply); +} + +static int auth_handle_switch_info_reply(int oper, netsnmp_session *s, int reqid, netsnmp_pdu *resp, void *data) +{ + static const oid HP_ICF_OID_hpEtherSwitch[] = + { SNMP_OID_ENTERPRISES, 11, 2, 3, 7, 11 }; + struct auth_context *auth = data; + struct switch_info *si = auth->current_switch; + netsnmp_variable_list *var; + blob_t b; + + switch_info_free(si); + var = resp->variables; + si->system_name = blob_cstr_dup(var_parse_type(&var, ASN_OCTET_STR)); + si->system_location = blob_cstr_dup(var_parse_type(&var, ASN_OCTET_STR)); + si->system_oid = blob_dup(var_parse_type(&var, ASN_OBJECT_ID)); + si->system_version = blob_cstr_dup(var_parse_type(&var, ASN_OCTET_STR)); + if (blob_is_null(var_parse_type(&var, ASN_OCTET_STR))) + si->flags |= SWITCHF_NO_LLDP; + if (si->system_name) + si->info_available |= FORMAT_SWITCH_NAME; + if (si->system_location) + si->info_available |= FORMAT_SWITCH_LOCATION; + b = si->system_oid; + if (blob_pull_matching(&b, BLOB_OID(HP_ICF_OID_hpEtherSwitch))) { + /* Hewlett-Packard ProCurve Switches */ + switch (blob_pull_oid(&b)) { + case 104: /* 1810G-24 with system_version && + si->system_version[0] == 'P' && + si->system_version[1] == '.' && + si->system_version[2] <= '1') + si->flags |= SWITCHF_BRIDGE_MIB_HAS_VLAN; + break; + } + } + cache_update(&si->cache_control); + return 1; +} + +static void auth_query_switch_info(struct auth_context *auth) +{ + struct switch_info *si = auth->current_switch; + netsnmp_pdu *pdu; + + auth->info_available &= + ~(FORMAT_SWITCH_NAME | FORMAT_SWITCH_LOCATION | + FORMAT_PORT_INDEX); + + if (!cache_refresh(&si->cache_control, auth, auth_query_fib)) + return; + + pdu = snmp_pdu_create(SNMP_MSG_GET); + snmp_add_null_var(pdu, oid_const(SNMPv2_MIB_sysName)); + snmp_add_null_var(pdu, oid_const(SNMPv2_MIB_sysLocation)); + snmp_add_null_var(pdu, oid_const(SNMPv2_MIB_sysObjectID)); + snmp_add_null_var(pdu, oid_const(SEMI_MIB_hpHttpMgVersion)); + snmp_add_null_var(pdu, oid_const(LLDP_lldpLocSysName)); + cache_talk_snmp(&si->cache_control, si->session, pdu, auth_handle_switch_info_reply, auth); +} + +static int auth_handle_arp_reply(int oper, netsnmp_session *s, int reqid, netsnmp_pdu *resp, void *data) +{ + struct auth_context *auth = data; + netsnmp_variable_list *var = resp->variables; + + if (oper == NETSNMP_CALLBACK_OP_RECEIVED_MESSAGE && + var->type == ASN_OCTET_STR && + var->val_len == MAC_LEN) { + memcpy(auth->mac, var->val.string, MAC_LEN); + auth->info_available |= FORMAT_CLIENT_MAC; + if (!auth_ok(auth)) { + auth->current_switch = l2_root_dev; + auth_query_switch_info(auth); + return 1; + } + } + + auth_completed(auth); + return 1; +} + +void start_authentication(blob_t token, blob_t ip) +{ + struct auth_context *auth; + oid query_oids[MAX_OID_LEN]; + blob_t query; + netsnmp_pdu *pdu; + + num_queries++; + + auth = calloc(1, sizeof(*auth)); + auth->token = blob_cstr_dup(token); + if (addr_parse(ip, &auth->addr) == NULL) { + auth_completed(auth); + return; + } + auth->info_available = FORMAT_CLIENT_IP; + + /* IP-MIB::ipNetToPhysicalPhysAddress..ipv4."1.2.3.4" + * = STRING: 01:12:34:56:78:9a */ + pdu = snmp_pdu_create(SNMP_MSG_GET); + + query = BLOB_OID(query_oids); + blob_push(&query, BLOB_OID(IP_MIB_ipNetToPhysicalPhysAddress)); + blob_push_oid(&query, l3_if_ndx); + blob_push_iana_afn(&query, &auth->addr); + query = blob_pushed(BLOB_OID(query_oids), query); + snmp_add_null_var(pdu, oid_blob(query)); + + auth_talk_snmp(auth, l3_root_dev->session, pdu, auth_handle_arp_reply); +} + +static void handle_line(blob_t line) +{ + blob_t id, ipaddr; + + id = blob_pull_cspn(&line, space); + blob_pull_spn(&line, space); + ipaddr = blob_pull_cspn(&line, space); + + start_authentication(id, ipaddr); +} + +static void read_input(void) +{ + static char buffer[256]; + static blob_t left; + + blob_t b, line; + int r; + + if (blob_is_null(left)) + left = BLOB_BUF(buffer); + + r = read(STDIN_FILENO, left.ptr, left.len); + if (r < 0) + return; + if (r == 0) { + running = 0; + return; + } + left.ptr += r; + left.len -= r; + + b = blob_pushed(BLOB_BUF(buffer), left); + do { + line = blob_pull_cspn(&b, lf); + if (!blob_pull_matching(&b, lf)) + return; + + handle_line(line); + + if (b.len) { + memcpy(buffer, b.ptr, b.len); + b.ptr = buffer; + } + left = BLOB_PTR_LEN(buffer + b.len, sizeof(buffer) - b.len); + } while (b.len); +} + +void load_topology(const char *file) +{ + char a_ip[64], b_ip[64]; + int a_port, b_port; + FILE *in; + + in = fopen(file, "r"); + if (in == NULL) + return; + + while (!feof(in)) { + if (fscanf(in, "%s %d %s %d\n", + a_ip, &a_port, b_ip, &b_port) == 4) + link_switch(a_ip, a_port, b_ip, b_port); + } + fclose(in); +} + +int main(int argc, char **argv) +{ + const char *l3_root = NULL, *l3_ifname = NULL; + const char *l2_root = NULL, *l2_vlan = NULL; + struct timeval timeout; + sockaddr_any addr; + fd_set fdset; + int opt, fds, block, i; + + setenv("MIBS", "", 1); + init_snmp("squark-auth"); + + while ((opt = getopt(argc, argv, "c:r:i:R:v:f:T:K")) != -1) { + switch (opt) { + case 'c': + snmp_community = optarg; + break; + case 'r': + l3_root = optarg; + break; + case 'i': + l3_ifname = optarg; + break; + case 'R': + l2_root = optarg; + break; + case 'v': + l2_vlan = optarg; + break; + case 'f': + username_format = optarg; + break; + case 'T': + load_topology(optarg); + break; + case 'K': + kick_out = TRUE; + break; + } + } + argc -= optind; + argv += optind; + + if (l3_root == NULL || l3_ifname == NULL || l2_vlan == NULL) { + printf("Mandatory information missing\n"); + return 1; + } + + sqdb_open(&db, "/var/lib/squark/squark.db"); + authdb_open(&adb, &adbc, &db); + + if (l2_root == NULL) + l2_root = l3_root; + + l3_root_dev = get_switch(addr_parse(BLOB_STRLEN(l3_root), &addr)); + l3_if_ndx = resolve_ifName2ifIndex(l3_root_dev, BLOB_STRLEN((char *) l3_ifname)); + l2_root_dev = get_switch(addr_parse(BLOB_STRLEN(l2_root), &addr)); + l2_vlan_ndx = atoi(l2_vlan); + username_format_flags = parse_format(username_format); + + if (kick_out) + username_format_flags |= FORMAT_PORT_WEBAUTH; + + for (i = 0; i < argc; i++) { + blob_t b = BLOB_STRLEN(argv[i]); + start_authentication(b, b); + running = FALSE; + } + + fcntl(STDIN_FILENO, F_SETFL, O_NONBLOCK); + while (num_queries || running) { + fds = 0; + block = 1; + + FD_ZERO(&fdset); + if (running) { + FD_SET(STDIN_FILENO, &fdset); + fds = STDIN_FILENO + 1; + } + snmp_select_info(&fds, &fdset, &timeout, &block); + fds = select(fds, &fdset, NULL, NULL, block ? NULL : &timeout); + cache_update_time(); + if (fds) { + if (FD_ISSET(STDIN_FILENO, &fdset)) + read_input(); + snmp_read(&fdset); + } else + snmp_timeout(); + } + authdb_close(&adb); + sqdb_close(&db); + + return 0; +} diff --git a/src/squark-filter.c b/src/squark-filter.c new file mode 100644 index 0000000..995da40 --- /dev/null +++ b/src/squark-filter.c @@ -0,0 +1,431 @@ +/* squark-filter.c - Squid User Authentication and Rating Kit + * An external redirector for Squid which analyzes the URL according + * to a database and can redirect to a block page. + * + * Copyright (C) 2010 Timo Teräs + * All rights reserved. + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 as published + * by the Free Software Foundation. See http://www.gnu.org/ for details. + */ + +#include +#include +#include +#include +#include + +#include + +#include "blob.h" +#include "addr.h" +#include "filterdb.h" +#include "authdb.h" + +#define FILTER_OVERRIDE_TIMEOUT (15*60) + +static struct sqdb db; +static struct authdb adb; +static struct authdb_config adbc; + +static int running = 1; +static const blob_t dash = BLOB_STR_INIT("-"); +static const blob_t space = BLOB_STR_INIT(" "); +static const blob_t slash = BLOB_STR_INIT("/"); +static const blob_t lf = BLOB_STR_INIT("\n"); +static struct authdb adb; +static time_t now; + +struct url_info { + blob_t protocol; + blob_t username; + blob_t password; + blob_t host; + blob_t significant_host; + blob_t path; + blob_t query; + blob_t fragment; + int port; + int is_ipv4; + int num_dots; +}; + +struct url_dns_part_data { + blob_t word; + int num_dots; + int numeric; +}; + +void blob_pull_url_dns_part(blob_t *b, struct url_dns_part_data *udp) +{ + blob_t t = *b; + int c, i, dots = 0, numeric = 1; + + for (i = 0; i < t.len; i++) { + c = (unsigned char) t.ptr[i]; + switch (c) { + case '.': + dots++; + break; + case ':': case '@': case '/': case '?': + *b = BLOB_PTR_LEN(t.ptr + i, t.len - i); + udp->word = BLOB_PTR_LEN(t.ptr, i); + udp->num_dots = dots; + udp->numeric = numeric; + return; + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + break; + default: + numeric = 0; + break; + } + } + + *b = BLOB_NULL; + udp->word = t; + udp->num_dots = dots; + udp->numeric = numeric; +} + +/* URI is generalized as: + * [proto://][user[:password]@]domain.name[:port][/[path/to][?p=a&q=b;r=c][#fragment]] + * Character literals used as separators are: + * : / @ ? & ; # + * Also URI escaping says to treat %XX as encoded hex value. + */ + +static int url_parse(blob_t uri, struct url_info *nfo) +{ + struct url_dns_part_data prev, cur; + + memset(&prev, 0, sizeof(prev)); + memset(nfo, 0, sizeof(*nfo)); + + /* parse protocol, username/password and domain name/port */ + do { + blob_pull_url_dns_part(&uri, &cur); + + switch (uri.len ? uri.ptr[0] : '/') { + case ':': + blob_pull_skip(&uri, 1); + if (blob_is_null(nfo->protocol) && + blob_pull_matching(&uri, BLOB_STR("//"))) + nfo->protocol = cur.word; + else + prev = cur; + break; + case '@': + blob_pull_skip(&uri, 1); + if (!blob_is_null(nfo->username) || + !blob_is_null(nfo->password)) + goto error; + if (!blob_is_null(prev.word)) { + nfo->username = prev.word; + nfo->password = cur.word; + } else + nfo->username = cur.word; + memset(&prev, 0, sizeof(prev)); + break; + case '/': + case '?': + if (!blob_is_null(prev.word)) { + nfo->host = prev.word; + nfo->num_dots = prev.num_dots; + nfo->is_ipv4 = prev.numeric && prev.num_dots == 3; + nfo->port = blob_pull_uint(&cur.word, 10); + } else { + nfo->host = cur.word; + nfo->num_dots = cur.num_dots; + nfo->is_ipv4 = cur.numeric && cur.num_dots == 3; + } + if (blob_is_null(nfo->host)) + nfo->host = BLOB_STR("localhost"); + break; + } + } while (blob_is_null(nfo->host) && !blob_is_null(uri)); + + /* rest of the components */ + nfo->path = blob_pull_cspn(&uri, BLOB_STR("?&;#")); + nfo->query = blob_pull_cspn(&uri, BLOB_STR("#")); + nfo->fragment = uri; + + /* fill in defaults if needed */ + if (blob_is_null(nfo->protocol)) { + if (nfo->port == 443) + nfo->protocol = BLOB_STR("https"); + else + nfo->protocol = BLOB_STR("http"); + if (nfo->port == 0) + nfo->port = 80; + } else if (nfo->port == 0) { + if (blob_cmp(nfo->protocol, BLOB_STR("https")) == 0) + nfo->port = 443; + else + nfo->port = 80; + } + if (blob_is_null(nfo->path)) + nfo->path = BLOB_STR("/"); + + /* significant host name */ + nfo->significant_host = nfo->host; + if (nfo->num_dots > 1) { + blob_t b = nfo->significant_host; + if (blob_pull_matching(&b, BLOB_STR("www")) && + (blob_pull_uint(&b, 10), 1) && + blob_pull_matching(&b, BLOB_STR("."))) + nfo->significant_host = b; + } + return 1; + +error: + return 0; +} + +static void url_print(struct url_info *nfo) +{ +#define print_field(nfo, x) if (!blob_is_null(nfo->x)) printf(" %s{%.*s}", #x, nfo->x.len, nfo->x.ptr) + print_field(nfo, protocol); + print_field(nfo, username); + print_field(nfo, password); + print_field(nfo, host); + printf(" port{%d}", nfo->port); + print_field(nfo, path); + print_field(nfo, query); + print_field(nfo, fragment); +#undef print_field + printf("\n"); + fflush(stdout); +} + +static int url_classify(struct url_info *url, struct sqdb *db) +{ + unsigned char buffer[512]; + blob_t key, got, tld, keybuf, keylimits; + void *cmph; + struct sqdb_index_entry *indx; + cmph_uint32 i = SQDB_PARENT_ROOT, previ = SQDB_PARENT_ROOT; + int dots_done = 1; + + cmph = sqdb_section_get(db, SQDB_SECTION_INDEX_MPH, NULL); + indx = sqdb_section_get(db, SQDB_SECTION_INDEX, NULL); + + keybuf = BLOB_BUF(buffer); + blob_push_lower(&keybuf, url->significant_host); + key = keylimits = blob_pushed(BLOB_BUF(buffer), keybuf); + + /* search for most qualified domain match; do first lookup + * with two domain components */ + if (url->is_ipv4) { + i = cmph_search_packed(cmph, key.ptr, key.len); + + if (indx[i].parent != SQDB_PARENT_IPV4 || + indx[i].component != blob_inet_addr(url->host)) { + i = previ; + goto parent_dns_match; + } + } else { + key = BLOB_PTR_LEN(key.ptr + key.len, 0); + tld = blob_expand_head(&key, keylimits, '.'); + + do { + /* add one more domain component */ + got = blob_expand_head(&key, keylimits, '.'); + if (blob_is_null(got)) + break; + + previ = i; + i = cmph_search_packed(cmph, key.ptr, key.len); + if (!blob_is_null(tld)) { + int p = indx[i].parent; + + if (p == SQDB_PARENT_ROOT || + p == SQDB_PARENT_IPV4 || + indx[p].parent != SQDB_PARENT_ROOT || + blob_cmp(tld, sqdb_get_string_literal(db, indx[p].component)) != 0) { + /* top level domain did not match */ + i = -1; + goto parent_dns_match; + } + tld = BLOB_NULL; + previ = p; + } + if (indx[i].parent != previ || + blob_cmp(got, sqdb_get_string_literal(db, indx[i].component)) != 0) { + /* the subdomain did no longer match, use + * parents classification */ + i = previ; + goto parent_dns_match; + } + dots_done++; + } while (indx[i].has_subdomains); + } + + /* No paths to match for */ + if (i == SQDB_PARENT_ROOT || !indx[i].has_paths || key.ptr != keylimits.ptr) + goto parent_dns_match; + + /* and then search for path matches -- construct hashing + * string of url decoded path */ + blob_push_urldecode(&keybuf, url->path); + key = keylimits = blob_pushed(BLOB_BUF(buffer), keybuf); + + while (indx[i].has_paths) { + /* add one more path component */ + got = blob_expand_tail(&key, keylimits, '/'); + if (blob_is_null(got)) + break; + previ = i; + i = cmph_search_packed(cmph, key.ptr, key.len); + tld = sqdb_get_string_literal(db, indx[i].component); + if (blob_cmp(got, sqdb_get_string_literal(db, indx[i].component)) != 0) { + /* the subdomain did no longer match, use + * parents classification */ + i = previ; + goto parent_dns_match; + } + } + +parent_dns_match: + if (i == SQDB_PARENT_ROOT) + return 0; /* no category */ + + return indx[i].category; +} + +static blob_t get_category_name(struct sqdb *db, int id) +{ + uint32_t *c, clen; + + c = sqdb_section_get(db, SQDB_SECTION_CATEGORIES, &clen); + if (c == NULL || id < 0 || id * sizeof(uint32_t) >= clen) + return BLOB_NULL; + + return sqdb_get_string_literal(db, c[id]); +} + +static void send_ok(blob_t tag) +{ + static char buffer[64]; + blob_t b = BLOB_BUF(buffer); + + blob_push(&b, tag); + blob_push(&b, lf); + b = blob_pushed(BLOB_BUF(buffer), b); + + write(STDOUT_FILENO, b.ptr, b.len); +} + +static void send_redirect(blob_t redirect_page, blob_t tag, blob_t url, blob_t categ, blob_t username) +{ + static char buffer[8*1024]; + blob_t b = BLOB_BUF(buffer); + + blob_push(&b, tag); + blob_push(&b, BLOB_STR(" 302:")); + blob_push(&b, adbc.redirect_url_base); + blob_push(&b, redirect_page); + blob_push(&b, BLOB_STR("?REASON=")); + blob_push_urlencode(&b, categ); + blob_push(&b, BLOB_STR("&USER=")); + blob_push_urlencode(&b, username); + blob_push(&b, BLOB_STR("&DENIEDURL=")); + blob_push_urlencode(&b, url); + blob_push(&b, lf); + b = blob_pushed(BLOB_BUF(buffer), b); + + write(STDOUT_FILENO, b.ptr, b.len); +} + +static void read_input(struct sqdb *db) +{ + static char buffer[8 * 1024]; + static blob_t left; + + blob_t b, line, id, ipaddr, url, username; + struct url_info nfo; + int r, category, auth_ok; + sockaddr_any addr; + struct authdb_entry entry; + void *token; + + if (blob_is_null(left)) + left = BLOB_BUF(buffer); + + r = read(STDIN_FILENO, left.ptr, left.len); + if (r < 0) + return; + if (r == 0) { + running = 0; + return; + } + left.ptr += r; + left.len -= r; + + now = time(NULL); + + b = blob_pushed(BLOB_BUF(buffer), left); + do { + line = blob_pull_cspn(&b, lf); + if (!blob_pull_matching(&b, lf)) + return; + + id = blob_pull_cspn(&line, space); + blob_pull_spn(&line, space); + url = blob_pull_cspn(&line, space); + blob_pull_spn(&line, space); + ipaddr = blob_pull_cspn(&line, slash); /* client addr */ + blob_pull_cspn(&line, space); /* fqdn */ + blob_pull_spn(&line, space); + username = blob_pull_cspn(&line, space); + /* http method */ + /* urlgroup */ + /* myaddr=xxx myport=xxx etc */ + + if (!blob_is_null(url) && + addr_parse(ipaddr, &addr)) { + /* valid request, handle it */ + if (url_parse(url, &nfo)) + category = url_classify(&nfo, db); + else + category = 0; + + token = authdb_get(&adb, &addr, &entry, 1); + if (authdb_check_login(token, &entry, username, now)) { + auth_ok = 1; + username = BLOB_STRLEN(entry.p.login_name); + } else { + auth_ok = 0; + } + + if (!auth_ok) { + send_redirect(BLOB_STR("login.cgi"), id, url, BLOB_STR("auth"), username); + } else if (((1ULL << category) & entry.p.block_categories) && + (now < entry.override_time || + now > entry.override_time + FILTER_OVERRIDE_TIMEOUT || + ((1ULL << category) & entry.p.hard_block_categories))) { + send_redirect(BLOB_STR("warning.cgi"), id, url, get_category_name(db, category), username); + } else + send_ok(id); + } + + if (b.len) { + memcpy(buffer, b.ptr, b.len); + b.ptr = buffer; + } + left = BLOB_PTR_LEN(buffer + b.len, sizeof(buffer) - b.len); + } while (b.len); +} + +int main(int argc, char **argv) +{ + sqdb_open(&db, "/var/lib/squark/squark.db"); + authdb_open(&adb, &adbc, &db); + + while (running) + read_input(&db); + + sqdb_close(&db); + authdb_close(&adb); +} -- cgit v1.2.3