diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/Makefile | 30 | ||||
| -rw-r--r-- | src/addr.c | 74 | ||||
| -rw-r--r-- | src/addr.h | 32 | ||||
| -rw-r--r-- | src/authdb.c | 364 | ||||
| -rw-r--r-- | src/authdb.h | 62 | ||||
| -rw-r--r-- | src/blob.c | 426 | ||||
| -rw-r--r-- | src/blob.h | 63 | ||||
| -rw-r--r-- | src/filterdb.c | 157 | ||||
| -rw-r--r-- | src/filterdb.h | 59 | ||||
| -rw-r--r-- | src/lua-squarkdb.c | 343 | ||||
| -rwxr-xr-x | src/sqdb-build.lua | 335 | ||||
| -rw-r--r-- | src/squark-auth-ip.c | 218 | ||||
| -rw-r--r-- | src/squark-auth-snmp.c | 1152 | ||||
| -rw-r--r-- | src/squark-filter.c | 431 | 
14 files changed, 3746 insertions, 0 deletions
| diff --git a/src/Makefile b/src/Makefile new file mode 100644 index 0000000..db683a2 --- /dev/null +++ b/src/Makefile @@ -0,0 +1,30 @@ +TARGETS=squark-auth-snmp squark-auth-ip squark-filter squarkdb.so + +NETSNMP_CFLAGS:=$(shell net-snmp-config --cflags) +NETSNMP_LIBS:=$(shell net-snmp-config --libs) +LUA_CFLAGS:=$(shell pkg-config --cflags lua5.1) +LUA_LIBS:=$(shell pkg-config --libs lua5.1) +CMPH_CFLAGS:=$(shell pkg-config --cflags cmph) +CMPH_LIBS:=$(shell pkg-config --libs cmph) + +CC=gcc +CFLAGS=-g -I. $(NETSNMP_CFLAGS) $(LUA_CFLAGS) $(CMPH_CFLAGS) -std=gnu99 -D_GNU_SOURCE -Wall +LIBS+=-lrt + +all: $(TARGETS) + +squark-auth-snmp: squark-auth-snmp.o filterdb.o authdb.o blob.o addr.o +	$(CC) -o $@ $^ $(NETSNMP_LIBS) $(LIBS) + +squark-auth-ip: squark-auth-ip.o filterdb.o authdb.o blob.o addr.o +	$(CC) -o $@ $^ $(LIBS) + +squark-filter: squark-filter.o filterdb.o authdb.o blob.o addr.o +	$(CC) -o $@ $^ $(CMPH_LIBS) $(LIBS) + +squarkdb.so: lua-squarkdb.o filterdb.o blob.o +	$(CC) -shared -o $@ $^ $(LUA_LIBS) $(CMPH_LIBS) $(LIBS) + +clean: +	rm $(OBJS1) $(TARGETS) + diff --git a/src/addr.c b/src/addr.c new file mode 100644 index 0000000..47013f2 --- /dev/null +++ b/src/addr.c @@ -0,0 +1,74 @@ +#include <stdio.h> +#include <string.h> + +#include "addr.h" + +int addr_len(const sockaddr_any *addr) +{ +	switch (addr->any.sa_family) { +	case AF_INET: +		return sizeof(struct sockaddr_in); +	default: +		return 0; +	} +} + +sockaddr_any *addr_parse(blob_t b, sockaddr_any *addr) +{ +	memset(addr, 0, sizeof(*addr)); +	addr->ipv4.sin_family = AF_INET; +	addr->ipv4.sin_addr.s_addr = blob_inet_addr(b); +	if (addr->ipv4.sin_addr.s_addr == -1) +		return NULL; +	return addr; +} + +unsigned long addr_hash(const sockaddr_any *addr) +{ +	switch (addr->any.sa_family) { +	case AF_INET: +		return htonl(addr->ipv4.sin_addr.s_addr); +	default: +		return 0; +	} +} + +const char *addr_print(const sockaddr_any *addr) +{ +	switch (addr->any.sa_family) { +	case AF_INET: +		return inet_ntoa(addr->ipv4.sin_addr); +	default: +		return "unknown"; +	} +} + +blob_t addr_get_hostaddr_blob(const sockaddr_any *addr) +{ +	switch (addr->any.sa_family) { +	case AF_INET: +		return BLOB_BUF(&addr->ipv4.sin_addr); +	default: +		return BLOB_NULL; +	} +} + +void addr_push_hostaddr(blob_t *b, const sockaddr_any *addr) +{ +	char buf[64]; +	blob_t f; +	unsigned int t; + +	switch (addr->any.sa_family) { +	case AF_INET: +		t = ntohl(addr->ipv4.sin_addr.s_addr); +		f.ptr = buf; +		f.len = sprintf(buf, "%d.%d.%d.%d", +				(t      ) & 0xff, (t >>  8) & 0xff, +				(t >> 16) & 0xff, (t >> 24) & 0xff); +		break; +	default: +		return; +	} +	blob_push(b, f); +} diff --git a/src/addr.h b/src/addr.h new file mode 100644 index 0000000..452d14b --- /dev/null +++ b/src/addr.h @@ -0,0 +1,32 @@ +#ifndef ADDR_H +#define ADDR_H + +#include <arpa/inet.h> +#include "blob.h" + +typedef union { +	struct sockaddr		any; +	struct sockaddr_in	ipv4; +} sockaddr_any; + +int addr_len(const sockaddr_any *addr); +sockaddr_any *addr_parse(blob_t text, sockaddr_any *addr); +unsigned long addr_hash(const sockaddr_any *addr); +const char *addr_print(const sockaddr_any *addr); +blob_t addr_get_hostaddr_blob(const sockaddr_any *addr); +void addr_push_hostaddr(blob_t *b, const sockaddr_any *addr); + +static inline void addr_copy(sockaddr_any *dst, const sockaddr_any *src) +{ +	memcpy(dst, src, addr_len(src)); +} + +static inline int addr_cmp(const sockaddr_any *a, const sockaddr_any *b) +{ +	if (a->any.sa_family != b->any.sa_family) +		return -1; +	return memcmp(a, b, addr_len(a)); +} + +#endif + diff --git a/src/authdb.c b/src/authdb.c new file mode 100644 index 0000000..e6e71c4 --- /dev/null +++ b/src/authdb.c @@ -0,0 +1,364 @@ +#include <sys/mman.h> +#include <sys/stat.h> +#include <unistd.h> +#include <malloc.h> +#include <sched.h> +#include <fcntl.h> +#include <ctype.h> +#include <time.h> +#include <grp.h> + +#include "authdb.h" +#include "filterdb.h" +#include "addr.h" +#include "blob.h" + +#define ALIGN(s,a)		(((s) + a - 1) & ~(a - 1)) + +#define AUTHDB_IP_PER_ME		256 +#define AUTHDB_LOGOFF_PERIOD		(15*60)		/* 15 mins */ +#define AUTHDB_SHM_SIZE			ALIGN(sizeof(struct authdb_entry[AUTHDB_IP_PER_ME]), 4096) + +static struct authdb_map_entry *authdb_me_open(sockaddr_any *addr, int create) +{ +	int oflag, fd; +	char name[64], buf[256]; +	blob_t b = BLOB_BUF(name); +	void *base; +	struct authdb_map_entry *me; +	struct group grp, *res; + +	blob_push(&b, BLOB_STR("squark-auth-")); +	blob_push_hexdump(&b, addr_get_hostaddr_blob(addr)); +	blob_push_byte(&b, 0); + +	oflag = O_RDWR; +	if (create) +		oflag |= O_CREAT; + +	fd = shm_open(name, oflag, 0660); +	if (fd < 0) +		return NULL; + +	if (ftruncate(fd, AUTHDB_SHM_SIZE) < 0) { +		close(fd); +		return NULL; +	} + +	getgrnam_r("squark", &grp, buf, sizeof(buf), &res); +	if (res != NULL) { +		fchown(fd, -1, res->gr_gid); +		fchmod(fd, 0660); +	} + +	base = mmap(NULL, AUTHDB_SHM_SIZE, +		    PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); +	close(fd); + +	if (base == MAP_FAILED) +		return NULL; + +	me = malloc(sizeof(*me)); +	if (me == NULL) { +		munmap(base, AUTHDB_SHM_SIZE); +		return NULL; +	} + +	me->next = NULL; +	me->baseaddr = *addr; +	me->entries = base; + +	return me; +} + +static void authdb_me_free(struct authdb_map_entry *me) +{ +	munmap(me->entries, AUTHDB_SHM_SIZE); +	free(me); +} + +int authdb_open(struct authdb *adb, struct authdb_config *cfg, struct sqdb *db) +{ +	memset(adb, 0, sizeof(*adb)); +	memset(cfg, 0, sizeof(*cfg)); +	cfg->db = db; +	return adbc_refresh(cfg, time(NULL)); +} + +void authdb_close(struct authdb *adb) +{ +	struct authdb_map_entry *c, *n; +	int i; + +	for (i = 0; i < ARRAY_SIZE(adb->hash_bucket); i++) { +		for (c = adb->hash_bucket[i]; c != NULL; c = n) { +			n = c->next; +			authdb_me_free(c); +		} +	} +} + +static unsigned int MurmurHash2(const void * key, int len, unsigned int seed) +{ +	// 'm' and 'r' are mixing constants generated offline. +	// They're not really 'magic', they just happen to work well. +	const unsigned int m = 0x5bd1e995; +	const int r = 24; +	unsigned int h = seed ^ len; +	const unsigned char * data = (const unsigned char *)key; + +	while(len >= 4) +	{ +		unsigned int k = *(unsigned int *)data; + +		k *= m; +		k ^= k >> r; +		k *= m; + +		h *= m; +		h ^= k; + +		data += 4; +		len -= 4; +	} + +	switch(len) +	{ +	case 3:	h ^= data[2] << 16; +	case 2:	h ^= data[1] << 8; +	case 1:	h ^= data[0]; +		h *= m; +	} + +	h ^= h >> 13; +	h *= m; +	h ^= h >> 15; + +	return h; +} + +static uint32_t authdb_entry_checksum(struct authdb_entry *entry) +{ +	return MurmurHash2(&entry->p, sizeof(entry->p), 0); +} + +void *authdb_get(struct authdb *adb, sockaddr_any *addr, struct authdb_entry *entry, int create) +{ +	struct authdb_map_entry *me; +	unsigned int hash, e, i; +	sockaddr_any baseaddr; +	blob_t b; + +	baseaddr = *addr; +	b = addr_get_hostaddr_blob(&baseaddr); +	if (b.len < 4) +		return NULL; + +	e = (unsigned char) b.ptr[0]; +	b.ptr[0] = 0x00; + +	hash = b.ptr[1] + b.ptr[2] + b.ptr[3]; +	hash %= ARRAY_SIZE(adb->hash_bucket); + +	for (me = adb->hash_bucket[hash]; me != NULL; me = me->next) { +		if (addr_cmp(&baseaddr, &me->baseaddr) == 0) +			break; +	} +	if (me == NULL) { +		me = authdb_me_open(&baseaddr, create); +		if (me == NULL) +			return NULL; +		me->next = adb->hash_bucket[hash]; +		adb->hash_bucket[hash] = me; +	} + +	for (i = 0; i < 3; i++) { +		memcpy(entry, &me->entries[e], sizeof(struct authdb_entry)); +		if (entry->checksum == 0 && entry->p.login_time == 0) +			return &me->entries[e]; +		if (entry->checksum == authdb_entry_checksum(entry)) +			return &me->entries[e]; +		sched_yield(); +	} + +	authdb_clear_entry(entry); + +	return &me->entries[e]; +} + +int authdb_set(void *token, struct authdb_entry *entry) +{ +	struct authdb_entry *mme = token; +	uint32_t checksum = entry->checksum; + +	entry->checksum = authdb_entry_checksum(entry); +	if (mme->checksum != checksum) +		return 0; + +	mme->checksum = ~0; +	memcpy(mme, entry, sizeof(*entry)); + +	return 1; +} + +int authdb_check_login(void *token, struct authdb_entry *e, blob_t username, time_t now) +{ +	struct authdb_entry *mme = token; + +	/* check username */ +	if (!blob_is_null(username) && +	    blob_cmp(username, BLOB_STRLEN(e->p.login_name)) != 0) +		return 0; + +	/* and dates */ +	if (now > e->last_activity_time + AUTHDB_LOGOFF_PERIOD) +		return 0; + +	/* and that no one clobbered the entry */ +	if (mme->checksum != e->checksum) +		return 0; + +	/* refresh last activity */ +	mme->last_activity_time = now; + +	return 1; +} + +void authdb_clear_entry(struct authdb_entry *entry) +{ +	uint32_t checksum = entry->checksum; + +	memset(entry, 0, sizeof(*entry)); +	entry->checksum = checksum; +} + +void authdb_commit_login(void *token, struct authdb_entry *e, time_t now, struct authdb_config *cfg) +{ +	e->p.block_categories = cfg->block_categories; +	e->p.hard_block_categories = cfg->hard_block_categories; +	e->p.login_time = now; +	e->last_activity_time = now; +	e->override_time = 0; + +	authdb_set(token, e); +} + +void authdb_commit_logout(void *token) +{ +	memset(token, 0, sizeof(struct authdb_entry)); +} + +void authdb_commit_override(void *token, struct authdb_entry *e, time_t now) +{ +	struct authdb_entry *mme = token; + +	mme->override_time = now; +} + +static blob_t read_word(FILE *in, int *lineno, blob_t b) +{ +	int ch, i, comment = 0; +	blob_t r; + +	ch = fgetc(in); +	while (1) { +		if (ch == EOF) +			return BLOB_NULL; +		if (ch == '#') +			comment = 1; +		if (!comment && !isspace(ch)) +			break; +		if (ch == '\n') { +			(*lineno)++; +			comment = 0; +		} +		ch = fgetc(in); +	} + +	r.ptr = b.ptr; +	r.len = 0; +	for (i = 0; i < b.len-1 && !isspace(ch); i++, r.len++) { +		r.ptr[i] = ch; +		ch = fgetc(in); +		if (ch == EOF) +			break; +		if (ch == '\n') +			(*lineno)++; +	} + +	return r; +} + +static int find_category_id(struct sqdb *db, blob_t cat) +{ +	uint32_t size, *ptr; +	int i; + +	ptr = sqdb_section_get(db, SQDB_SECTION_CATEGORIES, &size); +	if (ptr == NULL) +		return -1; + +	size /= sizeof(uint32_t); +	for (i = 0; i < size; i++) +		if (blob_cmp(cat, sqdb_get_string_literal(db, ptr[i])) == 0) +			return i; + +	return -1; +} + +static inline uint64_t to_category(struct sqdb *db, blob_t c) +{ +	int category; + +	category = find_category_id(db, c); +	if (category >= 0) +		return 1ULL << category; + +	fprintf(stderr, "WARNING: unknown category '%.*s'\n", +		c.len, c.ptr); +	return 0; +} + +int adbc_refresh(struct authdb_config *cfg, time_t now) +{ +	FILE *in; +	int lineno = 1; +	char word1[64], word2[64]; +	blob_t b, p; +	struct stat st; + +	if (cfg->last_check != 0 && cfg->last_check + 2*60 > now) +		return 0; + +	if (stat("/etc/squark/filter.conf", &st) != 0) +		return -1; + +	if (cfg->last_change == st.st_ctime) +		return 0; + +	/* check timestamp */ + +	in = fopen("/etc/squark/filter.conf", "r"); +	if (in == NULL) +		return -1; + +	cfg->block_categories = 0; +	cfg->hard_block_categories = 0; +	while (1) { +		b = read_word(in, &lineno, BLOB_BUF(word1)); +		if (blob_is_null(b)) +			break; + +		p = read_word(in, &lineno, BLOB_BUF(word2)); +		if (blob_cmp(b, BLOB_STR("redirect_path")) == 0) { +			cfg->redirect_url_base = blob_dup(p); +		} else if (blob_cmp(b, BLOB_STR("forbid")) == 0) { +			cfg->hard_block_categories |= to_category(cfg->db, p); +		} else if (blob_cmp(b, BLOB_STR("warn")) == 0) { +			cfg->block_categories |= to_category(cfg->db, p); +		} +	} +	cfg->block_categories |= cfg->hard_block_categories; + +	fclose(in); +} diff --git a/src/authdb.h b/src/authdb.h new file mode 100644 index 0000000..7bfa2f4 --- /dev/null +++ b/src/authdb.h @@ -0,0 +1,62 @@ +#ifndef AUTHDB_H +#define AUTHDB_H + +#include <stddef.h> +#include <stdint.h> +#include "blob.h" +#include "addr.h" + +#define AUTHDB_IP_HASH_SIZE		64 + +struct sqdb; +struct authdb_map_entry; + +struct authdb_config { +	struct sqdb *db; +	time_t last_check; +	time_t last_change; +	uint64_t block_categories; +	uint64_t hard_block_categories; +	blob_t redirect_url_base; +}; + +struct authdb { +	struct authdb_map_entry	*hash_bucket[AUTHDB_IP_HASH_SIZE]; +}; + +struct authdb_entry { +	struct { +		char		login_name[44]; +		char		mac_address[6]; +		uint16_t	switch_port; +		sockaddr_any	switch_ip; +		uint64_t	block_categories; +		uint64_t	hard_block_categories; +		uint32_t	login_time; +	} p; +	uint32_t		last_activity_time; +	uint32_t		override_time; +	uint32_t		checksum; +}; + +struct authdb_map_entry { +	struct authdb_map_entry	*next; +	sockaddr_any		baseaddr; +	struct authdb_entry *	entries; +}; + +int authdb_open(struct authdb *adb, struct authdb_config *adbc, struct sqdb *db); +void authdb_close(struct authdb *adb); + +void *authdb_get(struct authdb *adb, sockaddr_any *addr, struct authdb_entry *entry, int create); + +void authdb_clear_entry(struct authdb_entry *entry); +int authdb_set(void *token, struct authdb_entry *entry); +int authdb_check_login(void *token, struct authdb_entry *e, blob_t username, time_t now); +void authdb_commit_login(void *token, struct authdb_entry *e, time_t now, struct authdb_config *cfg); +void authdb_commit_logout(void *token); +void authdb_commit_override(void *token, struct authdb_entry *entry, time_t now); + +int adbc_refresh(struct authdb_config *cfg, time_t now); + +#endif diff --git a/src/blob.c b/src/blob.c new file mode 100644 index 0000000..1604308 --- /dev/null +++ b/src/blob.c @@ -0,0 +1,426 @@ +#include <time.h> +#include <ctype.h> +#include <string.h> + +#include "blob.h" + +/* RFC 3986 section 2.3 Unreserved Characters (January 2005) */ +#define CTYPE_UNRESERVED	1 + +static const char *xd = "0123456789abcdefghijklmnopqrstuvwxyz"; + +static const unsigned char chartype[128] = { +	['a'] = CTYPE_UNRESERVED, +	['b'] = CTYPE_UNRESERVED, +	['c'] = CTYPE_UNRESERVED, +	['d'] = CTYPE_UNRESERVED, +	['e'] = CTYPE_UNRESERVED, +	['f'] = CTYPE_UNRESERVED, +	['g'] = CTYPE_UNRESERVED, +	['h'] = CTYPE_UNRESERVED, +	['i'] = CTYPE_UNRESERVED, +	['j'] = CTYPE_UNRESERVED, +	['k'] = CTYPE_UNRESERVED, +	['l'] = CTYPE_UNRESERVED, +	['m'] = CTYPE_UNRESERVED, +	['n'] = CTYPE_UNRESERVED, +	['o'] = CTYPE_UNRESERVED, +	['p'] = CTYPE_UNRESERVED, +	['q'] = CTYPE_UNRESERVED, +	['r'] = CTYPE_UNRESERVED, +	['s'] = CTYPE_UNRESERVED, +	['t'] = CTYPE_UNRESERVED, +	['u'] = CTYPE_UNRESERVED, +	['v'] = CTYPE_UNRESERVED, +	['w'] = CTYPE_UNRESERVED, +	['x'] = CTYPE_UNRESERVED, +	['y'] = CTYPE_UNRESERVED, +	['z'] = CTYPE_UNRESERVED, +	['A'] = CTYPE_UNRESERVED, +	['B'] = CTYPE_UNRESERVED, +	['C'] = CTYPE_UNRESERVED, +	['D'] = CTYPE_UNRESERVED, +	['E'] = CTYPE_UNRESERVED, +	['F'] = CTYPE_UNRESERVED, +	['G'] = CTYPE_UNRESERVED, +	['H'] = CTYPE_UNRESERVED, +	['I'] = CTYPE_UNRESERVED, +	['J'] = CTYPE_UNRESERVED, +	['K'] = CTYPE_UNRESERVED, +	['L'] = CTYPE_UNRESERVED, +	['M'] = CTYPE_UNRESERVED, +	['N'] = CTYPE_UNRESERVED, +	['O'] = CTYPE_UNRESERVED, +	['P'] = CTYPE_UNRESERVED, +	['Q'] = CTYPE_UNRESERVED, +	['R'] = CTYPE_UNRESERVED, +	['S'] = CTYPE_UNRESERVED, +	['T'] = CTYPE_UNRESERVED, +	['U'] = CTYPE_UNRESERVED, +	['V'] = CTYPE_UNRESERVED, +	['W'] = CTYPE_UNRESERVED, +	['X'] = CTYPE_UNRESERVED, +	['Y'] = CTYPE_UNRESERVED, +	['Z'] = CTYPE_UNRESERVED, + +	['0'] = CTYPE_UNRESERVED, +	['1'] = CTYPE_UNRESERVED, +	['2'] = CTYPE_UNRESERVED, +	['3'] = CTYPE_UNRESERVED, +	['4'] = CTYPE_UNRESERVED, +	['5'] = CTYPE_UNRESERVED, +	['6'] = CTYPE_UNRESERVED, +	['7'] = CTYPE_UNRESERVED, +	['8'] = CTYPE_UNRESERVED, +	['9'] = CTYPE_UNRESERVED, + +	['-'] = CTYPE_UNRESERVED, +	['_'] = CTYPE_UNRESERVED, +	['.'] = CTYPE_UNRESERVED, +	['~'] = CTYPE_UNRESERVED, +}; + +static inline int dx(int c) +{ +	if (likely(c >= '0' && c <= '9')) +		return c - '0'; +	if (likely(c >= 'a' && c <= 'f')) +		return c - 'a' + 0xa; +	if (c >= 'A' && c <= 'F') +		return c - 'A' + 0xa; +	return -1; +} + +char *blob_cstr_dup(blob_t b) +{ +	char *p; + +	if (blob_is_null(b)) +		return NULL; + +	p = malloc(b.len+1); +	if (p != NULL) { +		memcpy(p, b.ptr, b.len); +		p[b.len] = 0; +	} +	return p; +} + +blob_t blob_dup(blob_t b) +{ +	blob_t p; + +	if (blob_is_null(b)) +		return BLOB_NULL; + +	p.ptr = malloc(b.len); +	if (p.ptr != NULL) { +		memcpy(p.ptr, b.ptr, b.len); +		p.len = b.len; +	} else { +		p.len = 0; +	} +	return p; +} + +int blob_cmp(blob_t a, blob_t b) +{ +	if (a.len != b.len) +		return a.len - b.len; +	return memcmp(a.ptr, b.ptr, a.len); +} + +unsigned long blob_inet_addr(blob_t b) +{ +	unsigned long ip = 0; +	int i; + +	for (i = 0; i < 3; i++) { +		ip += blob_pull_uint(&b, 10); +		ip <<= 8; +		if (!blob_pull_matching(&b, BLOB_STR("."))) +			return 0; +	} +	ip += blob_pull_uint(&b, 10); +	if (b.len != 0) +		return 0; +	return htonl(ip); +} + + +blob_t blob_pushed(blob_t buffer, blob_t left) +{ +	if (buffer.ptr + buffer.len != left.ptr + left.len) +		return BLOB_NULL; +	return BLOB_PTR_LEN(buffer.ptr, left.ptr - buffer.ptr); +} + +void blob_push(blob_t *b, blob_t d) +{ +	if (b->len >= d.len) { +		memcpy(b->ptr, d.ptr, d.len); +		b->ptr += d.len; +		b->len -= d.len; +	} else { +		*b = BLOB_NULL; +	} +} + +void blob_push_lower(blob_t *b, blob_t d) +{ +	int i; + +	if (b->len < d.len) { +		*b = BLOB_NULL; +		return; +	} +	for (i = 0; i < d.len; i++) +		b->ptr[i] = tolower(d.ptr[i]); +	b->ptr += d.len; +	b->len -= d.len; +} + +void blob_push_byte(blob_t *b, unsigned char byte) +{ +	if (b->len) { +		b->ptr[0] = byte; +		b->ptr ++; +		b->len --; +	} else { +		*b = BLOB_NULL; +	} +} + +void blob_push_uint(blob_t *to, unsigned int value, int radix) +{ +	char buf[64]; +	char *ptr = &buf[sizeof(buf)-1]; + +	if (value == 0) { +		blob_push_byte(to, '0'); +		return; +	} + +	while (value != 0) { +		*(ptr--) = xd[value % radix]; +		value /= radix; +	} + +	blob_push(to, BLOB_PTR_PTR(ptr+1, &buf[sizeof(buf)-1])); +} + +void blob_push_ctime(blob_t *to, time_t t) +{ +	char buf[128]; +	blob_t b; + +	ctime_r(&t, buf); +	b = BLOB_STRLEN(buf); +	b.len--; +	blob_push(to, b); +} + +void blob_push_hexdump(blob_t *to, blob_t binary) +{ +	char *d; +	int i; + +	if (blob_is_null(*to)) +		return; + +	if (to->len < binary.len * 2) { +		*to = BLOB_NULL; +		return; +	} + +	for (i = 0, d = to->ptr; i < binary.len; i++) { +		*(d++) = xd[(binary.ptr[i] >> 4) & 0xf]; +		*(d++) = xd[binary.ptr[i] & 0xf]; +	} +	to->ptr = d; +	to->len -= binary.len * 2; +} + +void blob_push_urldecode(blob_t *to, blob_t url) +{ +	blob_t b, orig = *to; + +	do { +		blob_pull_matching(&url, BLOB_STR("/")); +		b = blob_pull_cspn(&url, BLOB_STR("/")); +		if (blob_is_null(b) || blob_cmp(b, BLOB_STR(".")) == 0) { +			/* skip '.' or two consecutive / */ +		} else if (blob_cmp(b, BLOB_STR("..")) == 0) { +			/* go up one path component */ +			blob_shrink_tail(to, blob_pushed(orig, b), '/'); +		} else { +			/* copy decoded; FIXME decode percent encoding */ +			blob_push_byte(to, '/'); +			blob_push(to, b); +		} +	} while (!blob_is_null(url)); +} + +void blob_push_urlencode(blob_t *to, blob_t url) +{ +	unsigned char c; +	int i; + +	for (i = 0; i < url.len; i++) { +		c = url.ptr[i]; + +		if (c <= 127 && (chartype[c] & CTYPE_UNRESERVED)) { +			blob_push_byte(to, c); +		} else { +			blob_push_byte(to, '%'); +			blob_push_uint(to, c, 16); +		} +	} +} + +blob_t blob_pull(blob_t *b, int len) +{ +	blob_t r; + +	if (b->len >= len) { +		r = BLOB_PTR_LEN(b->ptr, len); +		b->ptr += len; +		b->len -= len; +		return r; +	} +	*b = BLOB_NULL; +	return BLOB_NULL; +} + +void blob_pull_skip(blob_t *b, int len) +{ +	if (b->len >= len) { +		b->ptr += len; +		b->len -= len; +	} else { +		*b = BLOB_NULL; +	} +} + +int blob_pull_matching(blob_t *b, blob_t e) +{ +	if (b->len < e.len) +		return 0; +	if (memcmp(b->ptr, e.ptr, e.len) != 0) +		return 0; +	b->ptr += e.len; +	b->len -= e.len; +	return 1; +} + +unsigned int blob_pull_uint(blob_t *b, int radix) +{ +	unsigned int val; +	int ch; + +	val = 0; +	while (b->len && b->ptr[0] != 0) { +		ch = dx(b->ptr[0]); +		if (ch < 0 || ch >= radix) +			break; +		val *= radix; +		val += ch; + +		b->ptr++; +		b->len--; +	} + +	return val; +} + +blob_t blob_pull_spn(blob_t *b, const blob_t reject) +{ +	blob_t t = *b; +	int i; + +	for (i = 0; i < t.len; i++) { +		if (memchr(reject.ptr, t.ptr[i], reject.len) == NULL) { +			*b = BLOB_PTR_LEN(t.ptr + i, t.len - i); +			return BLOB_PTR_LEN(t.ptr, i); +		} +	} + +	*b = BLOB_NULL; +	return t; +} + +blob_t blob_pull_cspn(blob_t *b, const blob_t reject) +{ +	blob_t t = *b; +	int i; + +	for (i = 0; i < t.len; i++) { +		if (memchr(reject.ptr, t.ptr[i], reject.len) != NULL) { +			*b = BLOB_PTR_LEN(t.ptr + i, t.len - i); +			return BLOB_PTR_LEN(t.ptr, i); +		} +	} + +	*b = BLOB_NULL; +	return t; +} + +blob_t blob_expand_head(blob_t *b, blob_t limits, unsigned char sep) +{ +	blob_t t = *b; +	blob_t r; + +	if (t.ptr < limits.ptr || t.ptr+t.len > limits.ptr+limits.len) +		return BLOB_NULL; +	while (t.ptr > limits.ptr && t.ptr[-1] == sep) +		t.ptr--, t.len++; + +	r.ptr = t.ptr; +	r.len = 0; +	while (t.ptr > limits.ptr && t.ptr[-1] != sep) { +		t.ptr--, t.len++; +		r.ptr--, r.len++; +	} +	*b = t; +	return r; +} + +blob_t blob_expand_tail(blob_t *b, blob_t limits, unsigned char sep) +{ +	blob_t t = *b; +	blob_t r; + +	if (t.ptr < limits.ptr || t.ptr+t.len > limits.ptr+limits.len) +		return BLOB_NULL; +	while (t.ptr + t.len < limits.ptr + limits.len && t.ptr[t.len] == sep) +		t.len++; + +	r.ptr = t.ptr + t.len; +	r.len = 0; +	while (t.ptr + t.len < limits.ptr + limits.len && t.ptr[t.len] != sep) { +		t.len++; +		r.len++; +	} +	*b = t; +	return r; +} + +blob_t blob_shrink_tail(blob_t *b, blob_t limits, unsigned char sep) +{ +	blob_t t = *b; +	blob_t r; + +	if (t.ptr <= limits.ptr || t.ptr+t.len > limits.ptr+limits.len) +		return BLOB_NULL; +	while (t.len && t.ptr[t.len-1] == sep) +		t.len--; + +	r.ptr = t.ptr; +	r.len = 0; +	while (t.len && t.ptr[t.len-1] != sep) { +		t.len--; +		r.ptr--, r.len++; +	} +	*b = t; +	return r; +} diff --git a/src/blob.h b/src/blob.h new file mode 100644 index 0000000..76afed7 --- /dev/null +++ b/src/blob.h @@ -0,0 +1,63 @@ +#ifndef BLOB_H +#define BLOB_H + +#include <string.h> + +#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0])) + +#if defined __GNUC__ && __GNUC__ == 2 && __GNUC_MINOR__ < 96 +#define __builtin_expect(x, expected_value) (x) +#endif + +#ifndef likely +#define likely(x) __builtin_expect((!!(x)),1) +#endif + +#ifndef unlikely +#define unlikely(x) __builtin_expect((!!(x)),0) +#endif + +typedef struct blob { +	char *ptr; +	unsigned int len; +} blob_t; + +#define BLOB_PTR_LEN(ptr,len)	(blob_t){(void*)(ptr), (len)} +#define BLOB_PTR_PTR(beg,end)	BLOB_PTR_LEN((beg),(end)-(beg)+1) +#define BLOB_BUF(buf)		(blob_t){(void*)(buf), sizeof(buf)} +#define BLOB_STRLEN(str)	(blob_t){(str), strlen(str)} +#define BLOB_STR_INIT(str)	{(str), sizeof(str)-1} +#define BLOB_STR(str)		(blob_t) BLOB_STR_INIT(str) +#define BLOB_NULL		(blob_t){NULL, 0} + +static inline int blob_is_null(blob_t b) +{ +	return b.len == 0; +} + +char *blob_cstr_dup(blob_t b); +blob_t blob_dup(blob_t b); +int blob_cmp(blob_t a, blob_t b); +unsigned long blob_inet_addr(blob_t buf); + +blob_t blob_pushed(blob_t buffer, blob_t left); +void blob_push(blob_t *b, blob_t d); +void blob_push_lower(blob_t *b, blob_t d); +void blob_push_byte(blob_t *b, unsigned char byte); +void blob_push_uint(blob_t *to, unsigned int value, int radix); +void blob_push_ctime(blob_t *to, time_t t); +void blob_push_hexdump(blob_t *to, blob_t binary); +void blob_push_urldecode(blob_t *to, blob_t url); +void blob_push_urlencode(blob_t *to, blob_t url); +blob_t blob_pull(blob_t *b, int len); +void blob_pull_skip(blob_t *b, int len); +int blob_pull_matching(blob_t *b, blob_t e); +unsigned int blob_pull_uint(blob_t *b, int radix); +blob_t blob_pull_spn(blob_t *b, const blob_t spn); +blob_t blob_pull_cspn(blob_t *b, const blob_t cspn); + +blob_t blob_expand_head(blob_t *b, blob_t limits, unsigned char sep); +blob_t blob_expand_tail(blob_t *b, blob_t limits, unsigned char sep); +blob_t blob_shrink_tail(blob_t *b, blob_t limits, unsigned char sep); + +#endif diff --git a/src/filterdb.c b/src/filterdb.c new file mode 100644 index 0000000..d3f4c6a --- /dev/null +++ b/src/filterdb.c @@ -0,0 +1,157 @@ +#include <fcntl.h> +#include <unistd.h> +#include <string.h> +#include <sys/mman.h> +#include <sys/stat.h> + +#include "filterdb.h" + +#define PAGE_SIZE		4096 +#define ALIGN(s,a)		(((s) + a - 1) & ~(a - 1)) + +const char *sqdb_section_names[SQDB_SECTION_MAX] = { +	[SQDB_SECTION_STRINGS]		= "strings", +	[SQDB_SECTION_CATEGORIES]	= "categories", +	[SQDB_SECTION_INDEX]		= "index", +	[SQDB_SECTION_INDEX_MPH]	= "index_mph", +	[SQDB_SECTION_KEYWORD]		= "keyword", +	[SQDB_SECTION_KEYWORD_MPH]	= "keyword_mph", +}; + +static int sqdb_allocate(struct sqdb *db, size_t s, int wr) +{ +	size_t old_size, new_size; +	void *base; +	int prot = PROT_READ; + +	old_size = db->file_length; +	new_size = ALIGN(db->file_length + s, PAGE_SIZE); + +	if (new_size == ALIGN(db->file_length, PAGE_SIZE)) { +		db->file_length += s; +		return old_size; +	} + +	if (wr && ftruncate(db->fd, new_size) < 0) +		return -1; + +	if (db->mmap_base == NULL) { +		if (wr) +			prot |= PROT_WRITE; +		base = mmap(NULL, new_size, prot, MAP_SHARED, db->fd, 0); +	} else { +		base = mremap(db->mmap_base, ALIGN(old_size, PAGE_SIZE), +			      new_size, MREMAP_MAYMOVE); +	} +	if (base == MAP_FAILED) +		return -1; + +	db->mmap_base = base; +	db->file_length += ALIGN(s, 16); + +	return old_size; +} + +int sqdb_open(struct sqdb *db, const char *fn) +{ +	struct stat st; + +	db->fd = open(fn, O_RDONLY); +	if (db->fd < 0) +		return -1; + +	fstat(db->fd, &st); + +	db->file_length = 0; +	db->mmap_base = NULL; +	sqdb_allocate(db, st.st_size, 0); + +	return 0; +} + +int sqdb_create(struct sqdb *db, const char *fn) +{ +	struct sqdb_header *hdr; +	int rc; + +	db->fd = open(fn, O_CREAT | O_TRUNC | O_RDWR, 0666); +	if (db->fd < 0) +		return -1; + +	db->file_length = 0; +	db->mmap_base = NULL; + +	rc = sqdb_allocate(db, sizeof(struct sqdb_header), 1); +	if (rc < 0) { +		close(db->fd); +		return rc; +	} + +	hdr = db->mmap_base; +	strcpy(hdr->description, "Squark Filtering Database"); +	hdr->version = 1; +	hdr->magic = 0xdbdbdbdb; +	hdr->num_sections = SQDB_SECTION_MAX; + +	return 0; +} + +int sqdb_open(struct sqdb *db, const char *fn); + +void sqdb_close(struct sqdb *db) +{ +	if (db->mmap_base) +		munmap(db->mmap_base, ALIGN(db->file_length, PAGE_SIZE)); +	close(db->fd); +} + +void *sqdb_section_create(struct sqdb *db, int id, uint32_t size) +{ +	struct sqdb_header *hdr; +	size_t pos; + +	hdr = db->mmap_base; +	if (hdr->section[id].offset || hdr->section[id].length) +		return NULL; + +	pos = sqdb_allocate(db, size, 1); +	if (pos < 0) +		return NULL; + +	/* sqdb_allocate can remap mmap_base */ +	hdr = db->mmap_base; +	hdr->section[id].offset = pos; +	hdr->section[id].length = size; + +	return db->mmap_base + pos; +} + +void *sqdb_section_get(struct sqdb *db, int id, uint32_t *size) +{ +	struct sqdb_header *hdr = db->mmap_base; + +	if (hdr->section[id].offset == 0) +		return NULL; + +	if (size) +		*size = hdr->section[id].length; + +	return db->mmap_base + hdr->section[id].offset; +} + +blob_t sqdb_get_string_literal(struct sqdb *db, uint32_t encoded_ptr) +{ +	unsigned char *ptr; +	unsigned int len, off; + +	ptr = sqdb_section_get(db, SQDB_SECTION_STRINGS, NULL); +	if (ptr == NULL) +		return BLOB_NULL; + +	off = encoded_ptr >> SQDB_LENGTH_BITS; +	len = encoded_ptr & ((1 << SQDB_LENGTH_BITS) - 1); +	if (len == 0) +		len = ptr[off++]; + +	return BLOB_PTR_LEN(ptr + off, len); +} diff --git a/src/filterdb.h b/src/filterdb.h new file mode 100644 index 0000000..2d16572 --- /dev/null +++ b/src/filterdb.h @@ -0,0 +1,59 @@ +#ifndef FILTERDB_H +#define FILTERDB_H + +#include <stddef.h> +#include <stdint.h> +#include "blob.h" + +#define SQDB_LENGTH_BITS		5 + +#define SQDB_SECTION_STRINGS		0 +#define SQDB_SECTION_CATEGORIES		1 +#define SQDB_SECTION_INDEX		2 +#define SQDB_SECTION_INDEX_MPH		3 +#define SQDB_SECTION_KEYWORD		4 +#define SQDB_SECTION_KEYWORD_MPH	5 +#define SQDB_SECTION_MAX		8 + +struct sqdb { +	int			fd; +	void *			mmap_base; +	size_t			file_length; +}; + +struct sqdb_section { +	uint32_t		offset; +	uint32_t		length; +}; + +struct sqdb_header { +	char			description[116]; +	uint32_t		num_sections; +	uint32_t		version; +	uint32_t		magic; +	struct sqdb_section	section[SQDB_SECTION_MAX]; +}; + +#define SQDB_PARENT_ROOT	0xffffff +#define SQDB_PARENT_IPV4	0xfffffe + +struct sqdb_index_entry { +	uint32_t		has_subdomains : 1; +	uint32_t		has_paths : 1; +	uint32_t		category : 6; +	uint32_t		parent : 24; +	uint32_t		component; +}; + + +const char *sqdb_section_names[SQDB_SECTION_MAX]; + +int sqdb_create(struct sqdb *db, const char *fn); +int sqdb_open(struct sqdb *db, const char *fn); +void sqdb_close(struct sqdb *db); + +void *sqdb_section_create(struct sqdb *db, int id, uint32_t size); +void *sqdb_section_get(struct sqdb *db, int id, uint32_t *size); +blob_t sqdb_get_string_literal(struct sqdb *db, uint32_t encoded_ptr); + +#endif diff --git a/src/lua-squarkdb.c b/src/lua-squarkdb.c new file mode 100644 index 0000000..5a30848 --- /dev/null +++ b/src/lua-squarkdb.c @@ -0,0 +1,343 @@ +#include <string.h> + +#include <lua.h> +#include <lualib.h> +#include <lauxlib.h> + +#include <cmph.h> + +#include "filterdb.h" + +#define SQUARKDB_META	"squarkdb" + +struct sqdb *Lsqdb_checkarg(lua_State *L, int index) +{ +	struct sqdb *db; + +	luaL_checktype(L, index, LUA_TUSERDATA); +	db = (struct sqdb *) luaL_checkudata(L, index, SQUARKDB_META); +	if (db == NULL) +		luaL_typerror(L, index, SQUARKDB_META); +	return db; +} + +static int Lsqdb_new(lua_State *L) +{ +	struct sqdb *db; +	const char *fn; + +	fn = luaL_checklstring(L, 1, NULL); + +	db = (struct sqdb *) lua_newuserdata(L, sizeof(struct sqdb)); +	luaL_getmetatable(L, SQUARKDB_META); +	lua_setmetatable(L, -2); + +	if (sqdb_create(db, fn) < 0) +		luaL_error(L, "Failed to create SquarkDB file '%s'", fn); + +	return 1; +} + +static int Lsqdb_destroy(lua_State *L) +{ +	struct sqdb *db; + +	db = Lsqdb_checkarg(L, 1); +	sqdb_close(db); + +	return 1; +} + + +struct ioa_data { +	lua_State *main; +	lua_State *thread; +}; + +static void ioa_rewind(void *data) +{ +	struct ioa_data *ioa = (struct ioa_data *) data; + +	/* pop previous thread */ +	lua_pop(ioa->main, 1); + +	/* create a new lua thread */ +	ioa->thread = lua_newthread(ioa->main); +	lua_pushvalue(ioa->main, -2);		/* copy function to top */ +	lua_xmove(ioa->main, ioa->thread, 1);	/* move function from L to NL */ +} + +static cmph_uint32 ioa_count_keys(void *data) +{ +	struct ioa_data *ioa = (struct ioa_data *) data; +	lua_State *NL; +	cmph_uint32 cnt = 0; + +	NL = lua_newthread(ioa->main); +	lua_pushvalue(ioa->main, -2);	/* copy function to top */ +	lua_xmove(ioa->main, NL, 1);	/* move function from L to NL */ + +	do { +		cnt++; +		lua_settop(NL, 1); +	} while (lua_resume(NL, 0) == LUA_YIELD); +	ioa_rewind(data); + +	return cnt - 1; +} + +static int ioa_read(void *data, char **key, cmph_uint32 *len) +{ +	struct ioa_data *ioa = (struct ioa_data *) data; +	lua_State *L = ioa->thread; +	size_t l; + +	/* get next key from lua thread */ +	lua_settop(L, 1); +	if (lua_resume(L, 0) != LUA_YIELD || +	    !lua_isstring(L, 1)) { +		*key = NULL; +		*len = 0; +		return -1; +	} + +	*key = (char *) lua_tolstring(L, 1, &l); +	*len = l; + +	return l; +} + +static void ioa_dispose(void *data, char *key, cmph_uint32 len) +{ +	/* LUA takes care of garbage collection */ +} + +static int Lsqdb_hash(lua_State *L) +{ +	struct sqdb *db; +	void *ptr; +	cmph_uint32 hash; +	const char *key; +	size_t keylen; + +	db = Lsqdb_checkarg(L, 1); +	key = luaL_checklstring(L, 2, &keylen); + +	ptr = sqdb_section_get(db, SQDB_SECTION_INDEX_MPH, NULL); +	hash = cmph_search_packed(ptr, key, keylen); + +	lua_pushinteger(L, hash); + +	return 1; +} + +static int Lsqdb_generate_hash(lua_State *L) +{ +	struct sqdb *db; +	struct ioa_data ioa; +	cmph_config_t *cfg; +	cmph_t *cmph; +	cmph_io_adapter_t io; + +	char *ptr; +	cmph_uint32 packed; + +	db = Lsqdb_checkarg(L, 1); +	luaL_argcheck(L, lua_isfunction(L, 2) && !lua_iscfunction(L, 2), +		      2, "Lua function expected"); + +	ioa.main	= L; +	io.data		= &ioa; +	io.nkeys	= ioa_count_keys(io.data); +	io.read		= ioa_read; +	io.dispose	= ioa_dispose; +	io.rewind	= ioa_rewind; + +	cfg = cmph_config_new(&io); +	if (cfg == NULL) +		luaL_error(L, "Failed to create CMPH config"); + +	cmph_config_set_algo(cfg, CMPH_CHD); +	cmph = cmph_new(cfg); +	cmph_config_destroy(cfg); + +	if (cmph == NULL) +		luaL_error(L, "Failed to create minimal perfect hash"); + +	packed = cmph_packed_size(cmph); +	ptr = sqdb_section_create(db, SQDB_SECTION_INDEX_MPH, packed); +	if (ptr == NULL) { +		cmph_destroy(cmph); +		luaL_error(L, "Unable to allocation MPH section from SquarkDB"); +	} + +	cmph_pack(cmph, ptr); +	cmph_destroy(cmph); + +	lua_pushinteger(L, io.nkeys); +	lua_pushinteger(L, packed); + +	return 2; +} + +static int Lsqdb_create_index(lua_State *L) +{ +	struct sqdb *db; +	lua_Integer num_entries; +	void *ptr; + +	db = Lsqdb_checkarg(L, 1); +	num_entries = luaL_checkinteger(L, 2); + +	ptr = sqdb_section_create(db, SQDB_SECTION_INDEX, sizeof(struct sqdb_index_entry) * num_entries); +	if (ptr == NULL) +		luaL_error(L, "Failed to create INDEX section"); + +	return 0; +} + +static int Lsqdb_assign_index(lua_State *L) +{ +	struct sqdb *db; +	size_t size; +	lua_Integer idx; +	struct sqdb_index_entry *ptr; + +	db = Lsqdb_checkarg(L, 1); +	idx = luaL_checkinteger(L, 2); + +	ptr = sqdb_section_get(db, SQDB_SECTION_INDEX, &size); +	if (size <= 0 || idx * sizeof(struct sqdb_index_entry) >= size) +		luaL_error(L, "Bad index assignment (idx=%d, section size=%d)", idx, size); + +	ptr += idx; +	if (ptr->component != 0) +		luaL_error(L, "Index entry %d has been already assigned", idx); + +	ptr->category = luaL_checkinteger(L, 3); +	ptr->has_subdomains = lua_toboolean(L, 4); +	ptr->has_paths = lua_toboolean(L, 5); +	ptr->component = luaL_checkinteger(L, 6); +	ptr->parent = luaL_checkinteger(L, 7); + +	return 0; +} + +static int Lsqdb_map_strings(lua_State *L) +{ +	struct sqdb *db; +	const char *str; +	unsigned char *ptr; +	size_t len, total, pos; + +	db = Lsqdb_checkarg(L, 1); +	luaL_checktype(L, 2, LUA_TTABLE); + +	/* go through the table and count total amount of data */ +	total = 0; +	lua_pushnil(L); +	while (lua_next(L, 2) != 0) { +		str = luaL_checklstring(L, -2, &len); +		total += len; +		if (len >= (1 << SQDB_LENGTH_BITS)) +			total++; +		lua_pop(L, 1); +	} + +	/* create string literal section */ +	ptr = sqdb_section_create(db, SQDB_SECTION_STRINGS, total); +	if (ptr == NULL) +		luaL_error(L, "Failed to create string literal section (%d bytes)", total); + +	/* populate string literals and return their indices */ +	pos = 0; +	lua_pushnil(L); +	while (lua_next(L, 2) != 0) { +		str = lua_tolstring(L, -2, &len); +		lua_pop(L, 1); + +		/* table[key] = encoded_string_pointer */ +		lua_pushvalue(L, -1); +		if (len >= (1 << SQDB_LENGTH_BITS)) { +			lua_pushinteger(L, pos << SQDB_LENGTH_BITS); +			ptr[pos++] = len; +		} else { +			lua_pushinteger(L, (pos << SQDB_LENGTH_BITS) + len); +		} +		memcpy(&ptr[pos], str, len); +		pos += len; + +		lua_rawset(L, 2); +	} + +	return 0; +} + +static int Lsqdb_write_section(lua_State *L) +{ +	struct sqdb *db; +	uint32_t *ptr; +	const char *section; +	int i, tbllen, si = -1; + +	db = Lsqdb_checkarg(L, 1); +	section = luaL_checkstring(L, 2); +	luaL_checktype(L, 3, LUA_TTABLE); +	tbllen = lua_objlen(L, 3); + +	for (i = 0; sqdb_section_names[i] && i < SQDB_SECTION_MAX; i++) { +		if (strcmp(sqdb_section_names[i], section) == 0) { +			si = 0; +			break; +		} +	} +	if (si < 0) +		luaL_error(L, "Section name '%s' is invalid", section); + +	ptr = sqdb_section_create(db, i, sizeof(uint32_t) * tbllen); +	if (ptr == NULL) +		luaL_error(L, "Failed to create section '%s'", section); + +	for (i = 0; i < tbllen; i++) { +		lua_rawgeti(L, 3, i + 1); +		ptr[i] = lua_tointeger(L, -1); +		lua_pop(L, 1); +	} + +	return 0; +} + +static const luaL_reg sqdb_meta_methods[] = { +	{ "__gc",		Lsqdb_destroy }, +	{ NULL,			NULL } +}; + +static const luaL_reg squarkdb_methods[] = { +	{ "new",		Lsqdb_new }, +	{ "hash",		Lsqdb_hash }, +	{ "generate_hash",	Lsqdb_generate_hash }, +	{ "create_index",	Lsqdb_create_index }, +	{ "assign_index",	Lsqdb_assign_index }, +	{ "map_strings",	Lsqdb_map_strings }, +	{ "write_section",	Lsqdb_write_section }, +	{ NULL,			NULL } +}; + +LUALIB_API int luaopen_squarkdb(lua_State *L) +{ +	/* Register squarkdb library */ +	luaI_openlib(L, "squarkdb", squarkdb_methods, 0); + +	/* And metatable for it */ +	luaL_newmetatable(L, SQUARKDB_META); +	luaI_openlib(L, NULL, sqdb_meta_methods, 0); +	lua_pushliteral(L, "__index"); +	lua_pushvalue(L, -3); +	lua_rawset(L, -3); +	lua_pushliteral(L, "__metatable"); +	lua_pushvalue(L, -3); +	lua_rawset(L, -3); +	lua_pop(L, 1); + +	return 1; +} diff --git a/src/sqdb-build.lua b/src/sqdb-build.lua new file mode 100755 index 0000000..cd039e2 --- /dev/null +++ b/src/sqdb-build.lua @@ -0,0 +1,335 @@ +#!/usr/bin/lua + +require("squarkdb") + +local all_strings = {} +local all_domains = {} +local all_ips = {} + +local all_categories = {} +local all_categories_by_id = {} +local num_categories = 0 + +local strfind = string.find +local strsub = string.sub +local tinsert = table.insert + +local function strsplit(delimiter, text) +	local list = {} +	local pos = 1 +	--if strfind("", delimiter, 1) then -- this would result in endless loops +	--	error("delimiter matches empty string!") +	--end +	while 1 do +		local first, last = strfind(text, delimiter, pos) +		if first then -- found? +			tinsert(list, strsub(text, pos, first-1)) +			pos = last+1 +		else +			tinsert(list, strsub(text, pos)) +			break +		end +	end +	return list +end + +local function account_string(s) +	all_strings[s] = true +end + +local function get_category(category_text) +	local cat + +	cat = all_categories[category_text] +	if cat ~= nil then return cat end + +	-- start category ID's from zero +	cat = { desc=category_text, id=num_categories } +	all_categories[category_text] = cat +	num_categories = num_categories + 1 + +	-- but index them from one +	all_categories_by_id[num_categories] = cat + +	account_string(category_text) + +	return cat +end + +local function get_domain(domain, locked) +	local parts, entry, idx, p, child + +	parts = strsplit("[.]", domain) +	entry = all_domains +	for idx=#parts,1,-1 do +		p = parts[idx] +		if entry.children == nil then +			entry.children = {} +		end +		child = entry.children[p] +		if child == nil then +			child = {} +			entry.children[p] = child +		end +		if child.locked and not locked then +			return nil +		end +		entry = child +	end +	return child +end + +local function get_path(domain_entry, path, locked) +	local entry, p, n, component + +	entry = domain_entry +	for n,component in pairs(strsplit("/", path)) do +		if entry.paths == nil then +			entry.paths = {} +		end +		p = entry.paths[component] +		if p == nil then +			p = {} +			entry.paths[component] = p +		end +		if p.locked and not locked then +			return nil +		end +		entry = p +	end +	return p +end + +local function is_ip_addr(s) +	return s:match("^%d+\.%d+\.%d+\.%d+$") +end + +local function read_urls(filename, category, locked) +	local fd, url, domain, path, d + +	fd = io.open(filename) +	if fd == nil then +		print("WARNING: File " .. filename .. " does not exist") +		return +	end +	print("Reading " .. filename) +	for url in fd:lines() do +		url = url:gsub("#.*", "") +		url = url:gsub(" *^", "") +		url = url:lower() +		url = url:gsub("^(www%d*[.])([^.]*[.])", "%2") +		domain, path = url:match("([^/]*)/?(.*)") +		domain = domain:gsub(":.*", "") +		domain = domain:gsub("[.]$", "")	-- trailing dot +		if domain == "" then +			d = nil +		elseif not is_ip_addr(domain) then +			d = get_domain(domain, locked) +		else +			d = all_ips[domain] +			if d == nil then +				d = {} +				all_ips[domain] = d +			end +		end +		if d == nil then +			--if url ~= "" then +			--	print(url .. " ignored due to locked record") +			--end +		elseif path ~= "" then +			if d.category ~= category and #path < 100 and path:match("([?;&])") == nil then +				path = path:gsub("^/", "") +				path = path:gsub("/$", "") +				p = get_path(d, path, locked) +				if p ~= nil then +					p.category = category +					if locked then +						p.locked = true +					end +				end +			end +		else +			if d.category == nil then +				d.category = category +				if locked then +					d.locked = true +				end +			end +		end + 	end + 	fd:close() +end + +local function enum_paths(cb, category, path, data) +	local fpath, cpath, cdata, cat + +	for cpath, cdata in pairs(data) do +		fpath = path .. "/" .. cpath +		cat = cdata.category or category +		cb(fpath, path, cpath, cat, false, cdata.paths) +		if cdata.paths then +			enum_paths(cb, cat, fpath, cdata.paths) +		end +	end +end + +local function enum_tree(cb, category, dns, data) +	local cdns, cdata, fdns, cat + +	if data.paths ~= nil then +		enum_paths(cb, category, dns, data.paths) +	end +	if data.children ~= nil then +		for cdns, cdata in pairs(data.children) do +			if dns ~= nil then +				fdns = cdns .. "." .. dns +			else +				fdns = cdns +			end +			cat = cdata.category or category +			cb(fdns, dns, cdns, cat, cdata.children, cdata.paths) +			enum_tree(cb, cat, fdns, cdata) +		end +	end +end + +function iptonumber(str) +	local num = 0 +	for elem in str:gmatch("%d+") do +		num = num * 256 + assert(tonumber(elem)) +	end +	return num +end + +local function enum_all(cb) +	local ipaddr, data, category + +	-- enumerate all domains +	enum_tree(cb, nil, nil, all_domains) + +	-- all IP addresses +	for ipaddr, data in pairs(all_ips) do +		if data.paths ~= nil then +			enum_paths(cb, data.category, ipaddr, data.paths) +		end +		cb(ipaddr, nil, iptonumber(ipaddr), data.category, nil, data.paths) +	end +end + +local function prune_paths(paths, category) +	local path, pdata, cat +	local num_paths = 0 + +	for path, pdata in pairs(paths) do +		local sub_paths = 0 + +		cat = pdata.category or category +		if pdata.paths ~= nil then +			sub_paths = prune_paths(pdata.paths, cat) +			if sub_paths == 0 then +				pdata.paths = nil +			end +		end +		if cat == category and sub_paths == 0 then +			paths[path] = nil +		else +			num_paths = num_paths + 1 +			account_string(path) +		end +	end +	return num_paths +end + +local function prune_tree(d, pcategory) +	local num_childs = 0 +	local num_paths = 0 +	local cat + +	cat = d.category or pcategory +	if d.children ~= nil then +		for n, child in pairs(d.children) do +			if prune_tree(child, cat, n) then +				d.children[n] = nil +			else +				num_childs = num_childs + 1 +				account_string(n) +			end +		end +		if num_childs == 0 then +			d.children = nil +		end +	end +	--print(name, d.category, category, d.num_paths, num_childs) +	if d.paths ~= nil then +		num_paths = prune_paths(d.paths, cat) +		if num_paths == 0 then +			d.paths = nil +		end +	end +	if d.category == pcategory and num_paths == 0 and num_childs == 0 then +		--num_pruned_leafs = num_pruned_leafs + 1 +		return true +	end +	return false +end + +local function load_lists(conffile, part) +	local line, fields, cat + +	for line in io.lines(conffile) do +		line = line:gsub("#(.*)", "") +		fields = strsplit("[\t ]", line) +		if fields[1] == "STOP" then +			break +		end +		if fields[3] then +			read_urls("lists/" .. fields[2] .. "list/" .. fields[3] .. "/" .. part, +				  get_category(fields[1]), +				  fields[4] == "LOCK") +		end +	end +end + +-- start by reading in all classification data +get_category("unknown") +load_lists("lists.conf", "domains") +prune_tree(all_domains, nil) +load_lists("lists.conf", "urls") +prune_tree(all_domains, nil) + +-- generate database +local db = squarkdb.new("squark.db") +num_entries = db:generate_hash(function() enum_all(coroutine.yield) end) + +-- write string literals +db:map_strings(all_strings) + +-- map category names and write the category section out +for id, cdata in ipairs(all_categories_by_id) do +	all_categories_by_id[id] = all_strings[cdata.desc] +end +db:write_section("categories", all_categories_by_id) + +-- create master index +db:create_index(num_entries) +enum_all( +	function(uri, parent_uri, component, category, childs, paths) +		if parent_uri == nil and type(component) == "number" then +			-- Embedded IPv4 address +			db:assign_index(db:hash(uri), +					category and category.id or 0, +					childs and true or false, +					paths and true or false, +					component, +					-2) +		else +			-- Regular entry +			db:assign_index(db:hash(uri), +					category and category.id or 0, +					childs and true or false, +					paths and true or false, +					all_strings[component] or 0, +					parent_uri and db:hash(parent_uri) or -1) +		end +	end +) diff --git a/src/squark-auth-ip.c b/src/squark-auth-ip.c new file mode 100644 index 0000000..3cdea0b --- /dev/null +++ b/src/squark-auth-ip.c @@ -0,0 +1,218 @@ +/* squark-auth-ip.c - Squid User Authentication and Rating Kit + *   An external acl helper for Squid which collects authentication + *   information about an IP-address from local shared memory database. + * + * Copyright (C) 2010 Timo Teräs <timo.teras@iki.fi> + * All rights reserved. + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 as published + * by the Free Software Foundation. See http://www.gnu.org/ for details. + */ + +#include <time.h> +#include <stdio.h> +#include <unistd.h> + +#include "blob.h" +#include "authdb.h" +#include "filterdb.h" + +#define DO_LOGIN	-1 +#define DO_OVERRIDE	-2 +#define DO_PRINT	-3 +#define DO_LOGOUT	-4 + +static int running = 1; +static struct sqdb db; +static struct authdb adb; +static struct authdb_config adbc; +static blob_t space = BLOB_STR_INIT(" "); +static blob_t lf = BLOB_STR_INIT("\n"); +static time_t now; + +static void handle_line(blob_t line) +{ +	char reply[128]; +	blob_t b, id, ipaddr; +	struct authdb_entry entry; +	sockaddr_any addr; +	void *token; +	int auth_ok = 0; + +	id = blob_pull_cspn(&line, space); +	blob_pull_spn(&line, space); +	ipaddr = blob_pull_cspn(&line, space); + +	if (addr_parse(ipaddr, &addr)) { +		token = authdb_get(&adb, &addr, &entry, 1); + +		if (authdb_check_login(token, &entry, BLOB_NULL, now)) +			auth_ok = 1; +	} + +	b = BLOB_BUF(reply); +	blob_push(&b, id); +	if (auth_ok) { +		blob_push(&b, BLOB_STR(" OK user=")); +		blob_push(&b, BLOB_STRLEN(entry.p.login_name)); +		blob_push(&b, BLOB_PTR_LEN("\n", 1)); +	} else { +		blob_push(&b, BLOB_STR(" ERR\n")); +	} + +	b = blob_pushed(BLOB_BUF(reply), b); +	write(STDOUT_FILENO, b.ptr, b.len); +} + +static void read_input(void) +{ +	static char buffer[256]; +	static blob_t left; + +	blob_t b, line; +	int r; + +	if (blob_is_null(left)) +		left = BLOB_BUF(buffer); + +	r = read(STDIN_FILENO, left.ptr, left.len); +	if (r < 0) +		return; +	if (r == 0) { +		running = 0; +		return; +	} +	left.ptr += r; +	left.len -= r; + +	now = time(NULL); + +	b = blob_pushed(BLOB_BUF(buffer), left); +	do { +		line = blob_pull_cspn(&b, lf); +		if (!blob_pull_matching(&b, lf)) +			return; + +		handle_line(line); + +		if (b.len) { +			memcpy(buffer, b.ptr, b.len); +			b.ptr = buffer; +		} +		left = BLOB_PTR_LEN(buffer + b.len, sizeof(buffer) - b.len); +	} while (b.len); +} + +#define DUMPPAR(b, name, fn) \ +	do {							\ +		blob_push(b, BLOB_STR("squark_" name "='"));	\ +		fn;						\ +		blob_push(b, BLOB_STR("'; "));			\ +	} while (0) + +int main(int argc, char **argv) +{ +	int opt; +	sockaddr_any ipaddr = { .any.sa_family = AF_UNSPEC }; +	blob_t ip = BLOB_NULL, username = BLOB_NULL; + +	while ((opt = getopt(argc, argv, "i:u:olpL")) != -1) { +		switch (opt) { +		case 'i': +			ip = BLOB_STRLEN(optarg); +			if (!addr_parse(ip, &ipaddr)) { +				fprintf(stderr, "'%s' does not look like IP-address\n", +					optarg); +				return 1; +			} +			break; +		case 'u': +			username = BLOB_STRLEN(optarg); +			break; +		case 'o': +			running = DO_OVERRIDE; +			break; +		case 'l': +			running = DO_LOGIN; +			break; +		case 'p': +			running = DO_PRINT; +			break; +		case 'L': +			running = DO_LOGOUT; +			break; +		} +	} + +	now = time(NULL); +	sqdb_open(&db, "/var/lib/squark/squark.db"); +	authdb_open(&adb, &adbc, &db); + +	if (running < 0) { +		struct authdb_entry entry; +		void *token; + +		if (ipaddr.any.sa_family == AF_UNSPEC) { +			fprintf(stderr, "IP-address not specified\n"); +			return 2; +		} + +		token = authdb_get(&adb, &ipaddr, &entry, 1); +		if (token == NULL) { +			fprintf(stderr, "Failed to get authdb record\n"); +			return 3; +		} + +		switch (running) { +		case DO_LOGIN: +			if (blob_is_null(username)) { +				fprintf(stderr, "Username not specified\n"); +				return 2; +			} +			authdb_clear_entry(&entry); +			memcpy(entry.p.login_name, username.ptr, username.len); +			authdb_commit_login(token, &entry, now, &adbc); +			break; +		case DO_OVERRIDE: +			if (authdb_check_login(token, &entry, username, now)) +				authdb_commit_override(token, &entry, now); +			break; +		case DO_PRINT: { +			char buf[512]; +			blob_t b = BLOB_BUF(buf); + +			DUMPPAR(&b, "ip_address", +				addr_push_hostaddr(&b, &ipaddr)); +			DUMPPAR(&b, "username", +				blob_push(&b, BLOB_BUF(entry.p.login_name))); +			DUMPPAR(&b, "mac_address", +				blob_push_hexdump(&b, BLOB_BUF(entry.p.mac_address))); +			DUMPPAR(&b, "login_time", +				blob_push_ctime(&b, entry.p.login_time)); +			DUMPPAR(&b, "activity_time", +				blob_push_ctime(&b, entry.last_activity_time)); +			DUMPPAR(&b, "override_time", +				blob_push_ctime(&b, entry.override_time)); +			DUMPPAR(&b, "block_categories", +				blob_push_hexdump(&b, BLOB_BUF(&entry.p.block_categories))); +			DUMPPAR(&b, "hard_block_categories", +				blob_push_hexdump(&b, BLOB_BUF(&entry.p.hard_block_categories))); +			blob_push(&b, BLOB_STR("\n")); +			b = blob_pushed(BLOB_BUF(buf), b); +			fwrite(b.ptr, b.len, 1, stdout); +			break; +		} +		case DO_LOGOUT: +			if (authdb_check_login(token, &entry, username, now)) +				authdb_commit_logout(token); +			break; +		} +	} else { +		while (running) +			read_input(); +	} + +	authdb_close(&adb); +	sqdb_close(&db); +} diff --git a/src/squark-auth-snmp.c b/src/squark-auth-snmp.c new file mode 100644 index 0000000..81b846d --- /dev/null +++ b/src/squark-auth-snmp.c @@ -0,0 +1,1152 @@ +/* squark-auth-snmp.c - Squid User Authentication and Rating Kit + *   An external acl helper for Squid which collects authentication + *   information about an IP-address from switches via SNMP. + * + * Copyright (C) 2010 Timo Teräs <timo.teras@iki.fi> + * All rights reserved. + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 as published + * by the Free Software Foundation. See http://www.gnu.org/ for details. + */ + +/* TODO: + * - implement Q-BRIDGE-MIB query + * - map vlan names to vlan index + * - print some usage information + * - poll lldpStatsRemTablesLastChangeTime when doing switch update + *   to figure out if lldp info is valid or not + */ + +#include <fcntl.h> +#include <stdio.h> +#include <string.h> +#include <unistd.h> + +#include <net-snmp/net-snmp-config.h> +#include <net-snmp/net-snmp-includes.h> + +#include "blob.h" +#include "addr.h" +#include "authdb.h" +#include "filterdb.h" + +/* Compile time configurables */ +#define SWITCH_HASH_SIZE		128 +#define PORT_HASH_SIZE			128 +#define CACHE_TIME			120 /* seconds */ + +/* Some helpers */ +#define ARRAY_SIZE(x)	(sizeof(x) / sizeof((x)[0])) +#define MAC_LEN		6 + +#define oid_const(oid)	(oid), ARRAY_SIZE(oid) +#define oid_blob(b)	((oid *) (b).ptr), ((b).len / sizeof(oid)) + +/* Format specifiers for username type */ +#define FORMAT_CLIENT_IP		0x01 /* %I */ +#define FORMAT_CLIENT_MAC		0x02 /* %M */ +#define FORMAT_SWITCH_NAME		0x04 /* %N */ +#define FORMAT_SWITCH_LOCATION		0x08 /* %L */ +#define FORMAT_PORT_INDEX		0x10 /* %i */ +#define FORMAT_PORT_NAME		0x20 /* %n */ +#define FORMAT_PORT_DESCR		0x40 /* %d */ +#define FORMAT_PORT_WEBAUTH		0x80 /* %w */ + +/* Some info about the switch which we need */ +#define SWITCHF_NO_LLDP			0x01 +#define SWITCHF_BRIDGE_MIB_HAS_VLAN	0x02 + +/* IANA-AddressFamilyNumbers */ +#define IANA_AFN_OTHER			0 +#define IANA_AFN_IPV4			1 +#define IANA_AFN_IPV6			2 + +/* OIDs used by the program */ +static const oid SNMPv2_MIB_sysObjectID[] = +	{ SNMP_OID_MIB2, 1, 2, 0 }; +static const oid SNMPv2_MIB_sysName[] = +	{ SNMP_OID_MIB2, 1, 5, 0 }; +static const oid SNMPv2_MIB_sysLocation[] = +	{ SNMP_OID_MIB2, 1, 6, 0 }; +static const oid IF_MIB_ifDescr[] = +	{ SNMP_OID_MIB2, 2, 2, 1, 2 }; +static const oid IF_MIB_ifName[] = +	{ SNMP_OID_MIB2, 31, 1, 1, 1, 1 }; +static const oid IF_MIB_ifStackStatus[] = +	{ SNMP_OID_MIB2, 31, 1, 2, 1, 3 }; +static const oid IP_MIB_ipNetToPhysicalPhysAddress[] = +	{ SNMP_OID_MIB2, 4, 35, 1, 4 }; +static const oid BRIDGE_MIB_dot1dTpFdbPort[] = +	{ SNMP_OID_MIB2, 17, 4, 3, 1, 2 }; +static const oid LLDP_lldpLocSysName[] = +	{ 1, 0, 8802, 1, 1, 2, 1, 3, 3, 0 }; +static const oid LLDP_lldpRemManAddrIfSubtype[] = +	{ 1, 0, 8802, 1, 1, 2, 1, 4, 2, 1, 3 }; +static const oid HP_hpicfUsrAuthWebAuthSessionName[] =  +	{ SNMP_OID_ENTERPRISES, 11, 2, 14, 11, 5, 1, 19, 5, 1, 1, 2 }; +static const oid HP_hpicfUsrAuthPortReauthenticate[] = +	{ SNMP_OID_ENTERPRISES, 11, 2, 14, 11, 5, 1, 19, 2, 1, 1, 4 }; +static const oid SEMI_MIB_hpHttpMgVersion[] =  +	{ SNMP_OID_ENTERPRISES, 11, 2, 36, 1, 1, 2, 6, 0 }; + +/* ----------------------------------------------------------------- */ + +struct switch_info; + +static int num_queries = 0; +static int running = TRUE; +static int kick_out = FALSE; + +static struct sqdb db; +static struct authdb adb; +static struct authdb_config adbc; +static const char *snmp_community = NULL; +static const char *username_format = "%w"; +static struct switch_info *all_switches[SWITCH_HASH_SIZE]; +static struct switch_info *l3_root_dev, *l2_root_dev; +static int l3_if_ndx, l2_vlan_ndx; +static time_t current_time; +static int username_format_flags; + +static const blob_t space = BLOB_STR_INIT(" "); +static const blob_t lf = BLOB_STR_INIT("\n"); + +/* ----------------------------------------------------------------- */ + +#define BLOB_OID(objid)		BLOB_BUF(objid) +#define BLOB_OID_DYN(objid,len)	BLOB_PTR_LEN(objid, (len) * sizeof(oid)) + +static inline void blob_push_oid(blob_t *b, oid objid) +{ +	if (b->len >= sizeof(objid)) { +		*((oid*) b->ptr) = objid; +		b->ptr += sizeof(oid); +		b->len -= sizeof(oid); +	} else { +		*b = BLOB_NULL; +	} +} + +static inline oid blob_pull_oid(blob_t *b) +{ +	oid objid; + +	if (b->len >= sizeof(objid)) { +		objid = *((oid*) b->ptr); +		b->ptr += sizeof(oid); +		b->len -= sizeof(oid); +	} else { +		*b = BLOB_NULL; +		objid = -1; +	} +	return objid; +} + +static inline void blob_push_oid_dump(blob_t *b, blob_t d) +{ +	int i; + +	if (b->len >= d.len * sizeof(oid)) { +		for (i = 0; i < d.len; i++) { +			*((oid*) b->ptr) = (unsigned char) d.ptr[i]; +			b->ptr += sizeof(oid); +			b->len -= sizeof(oid); +		} +	} else { +		*b = BLOB_NULL; +	} +} + +static inline void blob_pull_oid_dump(blob_t *b, blob_t d) +{ +	int i; + +	if (b->len >= d.len * sizeof(oid)) { +		for (i = 0; i < d.len; i++) { +			d.ptr[i] = (unsigned char) *((oid*) b->ptr); +			b->ptr += sizeof(oid); +			b->len -= sizeof(oid); +		} +	} else { +		*b = BLOB_NULL; +	} +} + +/* ----------------------------------------------------------------- */ + +void blob_push_iana_afn(blob_t *b, sockaddr_any *addr) +{ +	unsigned char *ptr; +	int type=0, len=0; + +	switch (addr->any.sa_family) { +	case AF_INET: +		type = IANA_AFN_IPV4; +		len = 4; +		ptr = (unsigned char*) &addr->ipv4.sin_addr; +		break; +	} +	if (type == 0 || b->len < len) { +		*b = BLOB_NULL; +		return; +	} +	blob_push_oid(b, type); +	blob_push_oid(b, len); +	blob_push_oid_dump(b, BLOB_PTR_LEN(ptr, len)); +} + +sockaddr_any *blob_pull_iana_afn(blob_t *b, sockaddr_any *addr) +{ +	unsigned char *ptr = NULL; +	int type, len; + +	memset(addr, 0, sizeof(*addr)); +	type = blob_pull_oid(b); +	len  = blob_pull_oid(b); +	if (type == IANA_AFN_IPV4 && len == 4) { +		addr->ipv4.sin_family = AF_INET; +		ptr = (unsigned char*) &addr->ipv4.sin_addr; +	} +	if (ptr == NULL) { +		blob_pull_skip(b, len); +		return NULL; +	} +	blob_pull_oid_dump(b, BLOB_PTR_LEN(ptr, len)); +	return addr; +} + +/* ----------------------------------------------------------------- */ + +static void safe_free(void *ptr) +{ +	void **pptr = ptr; +	if (*pptr != NULL) { +		free(*pptr); +		*pptr = NULL; +	} +} + +struct cache_control { +	time_t				update_time; +	struct auth_context *		sleepers; +}; + +struct switch_port_info { +	struct switch_port_info *	next; +	int				port; +	struct cache_control		cache_control; +	sockaddr_any			link_partner; +}; + +struct switch_info { +	struct switch_info *		next; +	sockaddr_any			addr; +	netsnmp_session *		session; + +	struct cache_control		cache_control; +	int				flags; +	int				info_available; +	char *				system_name; +	char *				system_location; +	char *				system_version; +	blob_t				system_oid; + +	struct switch_port_info *	all_ports[PORT_HASH_SIZE]; +}; + +struct auth_context { +	char *			token; +	sockaddr_any		addr; +	unsigned char		mac[MAC_LEN]; +	int			info_available; +	struct switch_info *	current_switch; +	struct switch_port_info *spi; +	int			local_port; +	int			lldp_port[8]; +	int			num_lldp_ports; +	char *			port_name; +	char *			port_descr; +	char *			webauth_name; + +	void			(*pending_operation)(struct auth_context *); +	struct auth_context *	next_sleeper; +}; + +static void cache_update_time(void) +{ +	current_time = time(NULL); +	adbc_refresh(&adbc, current_time); +} + +static int cache_refresh( +		struct cache_control *cc, struct auth_context *auth, +		void (*callback)(struct auth_context *auth)) +{ +	int ret; + +	if (cc->update_time == -1 || +	    cc->update_time + CACHE_TIME >= current_time) { +		callback(auth); +		return 0; +	} + +	auth->pending_operation = callback; + +	ret = (cc->sleepers == NULL); +	auth->next_sleeper = cc->sleepers; +	cc->sleepers = auth; + +	return ret; +} + +static void cache_update(struct cache_control *cc) +{ +	struct auth_context *auth, *next; + +	cc->update_time = current_time; +	auth = cc->sleepers; +	cc->sleepers = NULL; +	for (; auth; auth = next) { +		next = auth->next_sleeper; +		auth->pending_operation(auth); +	} +} + +static void cache_update_manual(struct cache_control *cc) +{ +	cache_update(cc); +	cc->update_time = -1; +} + +static void switch_info_free(struct switch_info *si) +{ +	safe_free(&si->system_name); +	safe_free(&si->system_location); +	safe_free(&si->system_version); +	safe_free(&si->system_oid.ptr); +	si->info_available = 0; +	si->flags = 0; +} + +struct switch_info *get_switch(sockaddr_any *addr) +{ +	struct snmp_session config; +	struct switch_info *si; +	unsigned int bucket = addr_hash(addr) % ARRAY_SIZE(all_switches); + +	for (si = all_switches[bucket]; si != NULL; si = si->next) +		if (addr_cmp(&si->addr, addr) == 0) +			return si; + +	si = calloc(1, sizeof(*si)); +	if (si == NULL) +		return NULL; + +	addr_copy(&si->addr, addr); + +	snmp_sess_init(&config); +	if (snmp_community != NULL) { +		config.version = SNMP_VERSION_2c; +		config.community = (unsigned char *) snmp_community; +		config.community_len = strlen(snmp_community); +	} +	config.peername = (char *) addr_print(addr); +	si->session = snmp_open(&config); + +	si->next = all_switches[bucket]; +	all_switches[bucket] = si; + +	return si; +} + +struct switch_port_info *get_switch_port(struct switch_info *si, int port) +{ +	unsigned int bucket = port % ARRAY_SIZE(si->all_ports); +	struct switch_port_info *spi; + +	if (si == NULL) +		return NULL; + +	for (spi = si->all_ports[bucket]; spi != NULL; spi = spi->next) +		if (spi->port == port) +			return spi; + +	spi = calloc(1, sizeof(*spi)); +	if (spi == NULL) +		return NULL; + +	spi->port = port; +	spi->next = si->all_ports[bucket]; +	si->all_ports[bucket] = spi; + +	return spi; +} + +void link_switch(const char *a, int ap, const char *b, int bp) +{ +	struct switch_info *sia, *sib; +	struct switch_port_info *spia, *spib; +	sockaddr_any addr; + +	sia = get_switch(addr_parse(BLOB_STRLEN(a), &addr)); +	spia = get_switch_port(sia, ap); + +	sib = get_switch(addr_parse(BLOB_STRLEN(b), &addr)); +	spib = get_switch_port(sib, bp); + +	addr_copy(&spia->link_partner, &sib->addr); +	addr_copy(&spib->link_partner, &sia->addr); + +	cache_update_manual(&spia->cache_control); +	cache_update_manual(&spib->cache_control); +} + +static void auth_query_switch_info(struct auth_context *auth); +static void auth_query_lldp(struct auth_context *auth, int root_query); + +static void auth_free(struct auth_context *auth) +{ +	safe_free(&auth->token); +	safe_free(&auth->port_name); +	safe_free(&auth->port_descr); +	safe_free(&auth->webauth_name); +	free(auth); +} + +int resolve_ifName2ifIndex(struct switch_info *si, blob_t ifName) +{ +	netsnmp_pdu *pdu, *response = NULL; +	netsnmp_variable_list *vars, *lastvar = NULL; +	int rc = -1; + +	pdu = snmp_pdu_create(SNMP_MSG_GETBULK); +	pdu->non_repeaters = 0; +	pdu->max_repetitions = 10; +	snmp_add_null_var(pdu, oid_const(IF_MIB_ifName)); + +	do { +		if (snmp_synch_response(si->session, pdu, &response) != 0) +			return -1; +		if (response->errstat != SNMP_ERR_NOERROR) +			goto done; + +		for (vars = response->variables; vars; vars = vars->next_variable) { +			lastvar = vars; + +			if (vars->name_length < ARRAY_SIZE(IF_MIB_ifName) || +			    memcmp(vars->name, IF_MIB_ifName, sizeof(IF_MIB_ifName)) != 0) +				goto done; + +			if (vars->type != ASN_OCTET_STR) +				continue; + +			if (blob_cmp(ifName, BLOB_PTR_LEN(vars->val.string, vars->val_len)) != 0) +				continue; + +			rc = vars->name[vars->name_length - 1]; +			goto done; +		} + +		pdu = snmp_pdu_create(SNMP_MSG_GETBULK); +		pdu->non_repeaters = 0; +		pdu->max_repetitions = 10; +		snmp_add_null_var(pdu, lastvar->name, lastvar->name_length); + +		snmp_free_pdu(response); +		response = NULL; +	} while (1); + +done: +	if (response) +		snmp_free_pdu(response); +	return rc; +} + + +static int parse_format(const char *fmt) +{ +	int flags = 0; +	const char *p = fmt; + +	while ((p = strchr(p, '%')) != NULL) { +		switch (p[1]) { +		case 'I': +			flags |= FORMAT_CLIENT_IP; +			break; +		case 'M': +			flags |= FORMAT_CLIENT_MAC; +			break; +		case 'N': +			flags |= FORMAT_SWITCH_NAME; +			break; +		case 'L': +			flags |= FORMAT_SWITCH_LOCATION; +			break; +		case 'i': +			flags |= FORMAT_PORT_INDEX; +			break; +		case 'n': +			flags |= FORMAT_PORT_NAME; +			break; +		case 'd': +			flags |= FORMAT_PORT_DESCR; +			break; +		case 'w': +			flags |= FORMAT_PORT_WEBAUTH; +			break; +		} +		p++; +	} +	return flags; +} + +static void blob_push_formatted_username( +	blob_t *b, const char *fmt, struct auth_context *auth) +{ +	const char *o = fmt, *p = fmt; +	struct switch_info *si = auth->current_switch; + +	while ((p = strchr(p, '%')) != NULL) { +		blob_push(b, BLOB_PTR_LEN(o, p - o)); +		switch (p[1]) { +		case 'I': +			blob_push(b, BLOB_STRLEN((char*) addr_print(&auth->addr))); +			break; +		case 'M': +			blob_push_hexdump(b, BLOB_BUF(auth->mac)); +			break; +		case 'N': +			blob_push(b, BLOB_STRLEN(si->system_name)); +			break; +		case 'L': +			blob_push(b, BLOB_STRLEN(si->system_location)); +			break; +		case 'i': +			blob_push_uint(b, auth->local_port, 10); +			break; +		case 'n': +			blob_push(b, BLOB_STRLEN(auth->port_name)); +			break; +		case 'd': +			blob_push(b, BLOB_STRLEN(auth->port_descr)); +			break; +		case 'w': +			blob_push(b, BLOB_STRLEN(auth->webauth_name)); +			break; +		default: +			o = p; +			p++; +			continue; +		} +		p += 2; +		o = p; +	} +	blob_push(b, BLOB_STRLEN((char*) o)); +} + +static int auth_ok(struct auth_context *auth) +{ +	return (auth->info_available & username_format_flags) == username_format_flags; +} + +static void auth_completed(struct auth_context *auth) +{ +	char tmp[256]; +	void *token; +	struct authdb_entry entry; +	blob_t b = BLOB_BUF(tmp), un; + +	token = authdb_get(&adb, &auth->addr, &entry, 1); +	authdb_clear_entry(&entry); + +	blob_push(&b, BLOB_STRLEN(auth->token)); +	if (auth_ok(auth)) { +		if (token != NULL) { +			un = BLOB_BUF(entry.p.login_name); +			blob_push_formatted_username(&un, username_format, auth); +			memcpy(entry.p.mac_address, auth->mac, MAC_LEN); +			entry.p.switch_ip = auth->current_switch->addr; +			entry.p.switch_port = auth->local_port; +			authdb_commit_login(token, &entry, current_time, &adbc); +		} + +		blob_push(&b, BLOB_STR(" OK user=")); +		blob_push_formatted_username(&b, username_format, auth); +		blob_push(&b, BLOB_PTR_LEN("\n", 1)); +	} else { +		if (token != NULL) +			authdb_commit_logout(token); +		blob_push(&b, BLOB_STR(" ERR\n")); +	} +	b = blob_pushed(BLOB_BUF(tmp), b); +	write(STDOUT_FILENO, b.ptr, b.len); + +	auth_free(auth); +	num_queries--; +} + +static void auth_talk_snmp(struct auth_context *auth, netsnmp_session *s, netsnmp_pdu *pdu, netsnmp_callback callback) +{ +	if (snmp_async_send(s, pdu, callback, auth) == 0) { +		snmp_free_pdu(pdu); +		auth_completed(auth); +	} +} + +static void cache_talk_snmp(struct cache_control *cc, netsnmp_session *s, netsnmp_pdu *pdu, netsnmp_callback callback, struct auth_context *auth) +{ +	if (snmp_async_send(s, pdu, callback, auth) == 0) { +		snmp_free_pdu(pdu); +		cache_update(cc); +	} +} + +static blob_t var_parse_type(netsnmp_variable_list **varptr, int asn_tag) +{ +	netsnmp_variable_list *var = *varptr; +	if (var == NULL) +		return BLOB_NULL; + +	*varptr = var->next_variable; +	if (var->type != asn_tag) +		return BLOB_NULL; + +	return BLOB_PTR_LEN(var->val.string, var->val_len); +} + +static void auth_force_reauthentication(struct auth_context *auth) +{ +	struct switch_info *si = auth->current_switch; +	netsnmp_pdu *pdu; +	oid query_oids[ARRAY_SIZE(HP_hpicfUsrAuthPortReauthenticate)+1]; +	blob_t b = BLOB_BUF(query_oids); +	long one = 1; + +	pdu = snmp_pdu_create(SNMP_MSG_SET); +	blob_push(&b, BLOB_OID(HP_hpicfUsrAuthPortReauthenticate)); +	blob_push_oid(&b, auth->local_port); +	b = blob_pushed(BLOB_OID(query_oids), b); + +	snmp_pdu_add_variable(pdu, oid_blob(b), ASN_INTEGER, +			      (u_char *) &one, sizeof(one)); + +	/* Send asynchornously - ignore response */ +	if (snmp_send(si->session, pdu) == 0) +		snmp_free_pdu(pdu); +} + +static int auth_handle_portinfo_reply(int oper, netsnmp_session *s, int reqid, netsnmp_pdu *resp, void *data) +{ +	struct auth_context *auth = data; +	netsnmp_variable_list *var; + +	if (oper != NETSNMP_CALLBACK_OP_RECEIVED_MESSAGE) +		goto done; + +	var = resp->variables; +	if (username_format_flags & FORMAT_PORT_NAME) +		auth->port_name = blob_cstr_dup(var_parse_type(&var, ASN_OCTET_STR)); +	if (auth->port_name) +		auth->info_available |= FORMAT_PORT_NAME; +	if (username_format_flags & FORMAT_PORT_DESCR) +		auth->port_descr = blob_cstr_dup(var_parse_type(&var, ASN_OCTET_STR)); +	if (auth->port_descr) +		auth->info_available |= FORMAT_PORT_DESCR; +	if (username_format_flags & FORMAT_PORT_WEBAUTH) +		auth->webauth_name = blob_cstr_dup(var_parse_type(&var, ASN_OCTET_STR)); +	if (auth->webauth_name) +		auth->info_available |= FORMAT_PORT_WEBAUTH; + +done: +	if (kick_out && auth_ok(auth)) +		auth_force_reauthentication(auth); + +	auth_completed(auth); +	return 1; +} + +static void auth_query_port_info(struct auth_context *auth) +{ +	struct switch_info *si = auth->current_switch; +	netsnmp_pdu *pdu; +	oid query_oids[MAX_OID_LEN]; +	blob_t query; + +	if (auth_ok(auth)) { +		auth_completed(auth); +		return; +	} + +	pdu = snmp_pdu_create(SNMP_MSG_GET); +	if (username_format_flags & FORMAT_PORT_NAME) { +		query = BLOB_OID(query_oids); +		blob_push(&query, BLOB_OID(IF_MIB_ifName)); +		blob_push_oid(&query, auth->local_port); +		query = blob_pushed(BLOB_OID(query_oids), query); +		snmp_add_null_var(pdu, oid_blob(query)); +	} +	if (username_format_flags & FORMAT_PORT_DESCR) { +		query = BLOB_OID(query_oids); +		blob_push(&query, BLOB_OID(IF_MIB_ifDescr)); +		blob_push_oid(&query, auth->local_port); +		query = blob_pushed(BLOB_OID(query_oids), query); +		snmp_add_null_var(pdu, oid_blob(query)); +	} +	if (username_format_flags & FORMAT_PORT_WEBAUTH) { +		query = BLOB_OID(query_oids); +		blob_push(&query, BLOB_OID(HP_hpicfUsrAuthWebAuthSessionName)); +		blob_push_oid(&query, auth->local_port); +		blob_push_oid_dump(&query, BLOB_BUF(auth->mac)); +		query = blob_pushed(BLOB_OID(query_oids), query); +		snmp_add_null_var(pdu, oid_blob(query)); +	} +	auth_talk_snmp(auth, si->session, pdu, auth_handle_portinfo_reply); +} + +static int auth_handle_lldp_reply(int oper, netsnmp_session *s, int reqid, netsnmp_pdu *resp, void *data) +{ +	struct auth_context *auth = data; +	struct switch_port_info *spi = auth->spi; +	netsnmp_variable_list *var = resp->variables; +	blob_t res; +	int i; + +	if (oper != NETSNMP_CALLBACK_OP_RECEIVED_MESSAGE) +		goto fail; + +	/* print_variable(var->name, var->name_length, var); */ + +	for (i = 0; i < auth->num_lldp_ports; i++) { +		if (var == NULL) +			goto fail; +		if (var->type != ASN_INTEGER) +			continue; +		/* INDEX: TimeFilter, Port, Idx, Family, Addr */ +		res = BLOB_OID_DYN(var->name, var->name_length); +		if (blob_pull_matching(&res, BLOB_OID(LLDP_lldpRemManAddrIfSubtype)) && +		    blob_pull_oid(&res) == 0 && +		    blob_pull_oid(&res) == auth->lldp_port[i]) { +			/* We have mathing LLDP neighbor */ +			blob_pull_oid(&res); +			blob_pull_iana_afn(&res, &spi->link_partner); +			cache_update(&spi->cache_control); +			return 1; +		} + +		var = var->next_variable; +	} +	auth->num_lldp_ports = 0; +	for (; var; var = var->next_variable) { +		if (var->type != ASN_INTEGER) +			break; +		/* print_variable(var->name, var->name_length, var); */ +		res = BLOB_OID_DYN(var->name, var->name_length); +		if (!blob_pull_matching(&res, BLOB_OID(IF_MIB_ifStackStatus))) +			break; +		if (blob_pull_oid(&res) != auth->local_port) +			break; +		auth->lldp_port[auth->num_lldp_ports++] = blob_pull_oid(&res); +		if (auth->num_lldp_ports >= ARRAY_SIZE(auth->lldp_port)) +			break; +	} +	if (auth->num_lldp_ports) { +		auth_query_lldp(auth, FALSE); +		return 1; +	} +fail: +	cache_update(&spi->cache_control); +	return 1; +} + +static void auth_query_lldp(struct auth_context *auth, int root_query) +{ +	struct switch_info *si = auth->current_switch; +	struct switch_port_info *spi = auth->spi; +	netsnmp_pdu *pdu; +	oid query_oids[MAX_OID_LEN]; +	blob_t query; +	int i; + +	/* printf("Query LLDP info for %s:%d\n", addr_print(&si->addr), spi->port); */ + +	if (si->flags & SWITCHF_NO_LLDP) { +		memset(&spi->link_partner, 0, sizeof(spi->link_partner)); +		cache_update(&spi->cache_control); +		return; +	} + +	if (root_query) { +		auth->num_lldp_ports = 1; +		auth->lldp_port[0] = auth->local_port; +	} + +	pdu = snmp_pdu_create(SNMP_MSG_GETBULK); +	pdu->non_repeaters = auth->num_lldp_ports; +	pdu->max_repetitions = 8; + +	for (i = 0; i < auth->num_lldp_ports; i++) { +		/* Query LLDP neighbor. lldpRemManAddrTable is INDEXed with +		 * [TimeFilter, LocalPort, Index, AddrSubType, Addr] */ +		query = BLOB_OID(query_oids); +		blob_push(&query, BLOB_OID(LLDP_lldpRemManAddrIfSubtype)); +		blob_push_oid(&query, 0); +		blob_push_oid(&query, auth->lldp_port[i]); +		query = blob_pushed(BLOB_OID(query_oids), query); +		snmp_add_null_var(pdu, oid_blob(query)); +	} + +	if (root_query) { +		/* Query interface stacking in case this is aggregated trunk: +		 * IF-MIB::ifStackStatus.<ifUpperIndex>.<ifLowerIndex> */ +		query = BLOB_OID(query_oids); +		blob_push(&query, BLOB_OID(IF_MIB_ifStackStatus)); +		blob_push_oid(&query, auth->local_port); +		query = blob_pushed(BLOB_OID(query_oids), query); +		snmp_add_null_var(pdu, oid_blob(query)); +	} + +	cache_talk_snmp(&spi->cache_control, si->session, pdu, auth_handle_lldp_reply, auth); +} + +static void auth_check_spi(struct auth_context *auth) +{ +	struct switch_port_info *spi = auth->spi; + +	if (addr_len(&spi->link_partner) != 0) { +		auth->current_switch = get_switch(&spi->link_partner); +		auth_query_switch_info(auth); +	} else { +		auth_query_port_info(auth); +	} +} + +static int auth_handle_fib_reply(int oper, netsnmp_session *s, int reqid, netsnmp_pdu *resp, void *data) +{ +	struct auth_context *auth = data; +	struct switch_info *si = auth->current_switch; +	struct switch_port_info *spi; +	netsnmp_variable_list *var; + +	if (oper != NETSNMP_CALLBACK_OP_RECEIVED_MESSAGE) +		goto failed; + +	var = resp->variables; +	/* print_variable(var->name, var->name_length, var); */ + +	if (var->type != ASN_INTEGER) +		goto failed; + +	auth->local_port = *var->val.integer; +	auth->info_available |= FORMAT_PORT_INDEX; +	auth->spi = spi = get_switch_port(si, auth->local_port); +	if (cache_refresh(&spi->cache_control, auth, auth_check_spi)) +		auth_query_lldp(auth, TRUE); +	return 1; + +	/* No further info available */ +failed: +	auth_completed(auth); +	return 1; +} + +static void auth_query_fib(struct auth_context *auth) +{ +	oid query_oids[MAX_OID_LEN]; +	blob_t query; +	struct switch_info *si = auth->current_switch; +	netsnmp_pdu *pdu; + +	auth->info_available |= si->info_available; + +	/* printf("Probing switch %s\n", addr_print(&si->addr)); */ + +	pdu = snmp_pdu_create(SNMP_MSG_GET); + +	/* FIXME: Implement Q-BRIDGE-MIB query too. */ + +	/* BRIDGE-MIB::dot1dTpFdbPort.<MAC> = INTEGER: port */ +	query = BLOB_OID(query_oids); +	blob_push(&query, BLOB_OID(BRIDGE_MIB_dot1dTpFdbPort)); +	if (si->flags & SWITCHF_BRIDGE_MIB_HAS_VLAN) +		blob_push_oid(&query, l2_vlan_ndx); +	blob_push_oid_dump(&query, BLOB_BUF(auth->mac)); +	query = blob_pushed(BLOB_OID(query_oids), query); +	snmp_add_null_var(pdu, oid_blob(query)); + +	auth_talk_snmp(auth, si->session, pdu, auth_handle_fib_reply); +} + +static int auth_handle_switch_info_reply(int oper, netsnmp_session *s, int reqid, netsnmp_pdu *resp, void *data) +{ +	static const oid HP_ICF_OID_hpEtherSwitch[] = +		{ SNMP_OID_ENTERPRISES, 11, 2, 3, 7, 11 }; +	struct auth_context *auth = data; +	struct switch_info *si = auth->current_switch; +	netsnmp_variable_list *var; +	blob_t b; + +	switch_info_free(si); +	var = resp->variables; +	si->system_name		= blob_cstr_dup(var_parse_type(&var, ASN_OCTET_STR)); +	si->system_location	= blob_cstr_dup(var_parse_type(&var, ASN_OCTET_STR)); +	si->system_oid		= blob_dup(var_parse_type(&var, ASN_OBJECT_ID)); +	si->system_version	= blob_cstr_dup(var_parse_type(&var, ASN_OCTET_STR)); +	if (blob_is_null(var_parse_type(&var, ASN_OCTET_STR))) +		si->flags |= SWITCHF_NO_LLDP; +	if (si->system_name) +		si->info_available |= FORMAT_SWITCH_NAME; +	if (si->system_location) +		si->info_available |= FORMAT_SWITCH_LOCATION; +	b = si->system_oid; +	if (blob_pull_matching(&b, BLOB_OID(HP_ICF_OID_hpEtherSwitch))) { +		/* Hewlett-Packard ProCurve Switches */ +		switch (blob_pull_oid(&b)) { +		case 104: /* 1810G-24 with <P1.18 is useless +			   *                P1.19-P1.20 has BRIDGE-MIB bug +			   *                P2.x+ seem to work more or less */ +			if (si->system_version && +			    si->system_version[0] == 'P' && +			    si->system_version[1] == '.' && +			    si->system_version[2] <= '1') +			    si->flags |= SWITCHF_BRIDGE_MIB_HAS_VLAN; +			break; +		} +	} +	cache_update(&si->cache_control); +	return 1; +} + +static void auth_query_switch_info(struct auth_context *auth) +{ +	struct switch_info *si = auth->current_switch; +	netsnmp_pdu *pdu; + +	auth->info_available &= +		~(FORMAT_SWITCH_NAME | FORMAT_SWITCH_LOCATION | +		  FORMAT_PORT_INDEX); + +	if (!cache_refresh(&si->cache_control, auth, auth_query_fib)) +		return; + +	pdu = snmp_pdu_create(SNMP_MSG_GET); +	snmp_add_null_var(pdu, oid_const(SNMPv2_MIB_sysName)); +	snmp_add_null_var(pdu, oid_const(SNMPv2_MIB_sysLocation)); +	snmp_add_null_var(pdu, oid_const(SNMPv2_MIB_sysObjectID)); +	snmp_add_null_var(pdu, oid_const(SEMI_MIB_hpHttpMgVersion)); +	snmp_add_null_var(pdu, oid_const(LLDP_lldpLocSysName)); +	cache_talk_snmp(&si->cache_control, si->session, pdu, auth_handle_switch_info_reply, auth); +} + +static int auth_handle_arp_reply(int oper, netsnmp_session *s, int reqid, netsnmp_pdu *resp, void *data) +{ +	struct auth_context *auth = data; +	netsnmp_variable_list *var = resp->variables; + +	if (oper == NETSNMP_CALLBACK_OP_RECEIVED_MESSAGE && +	    var->type == ASN_OCTET_STR && +	    var->val_len == MAC_LEN) { +		memcpy(auth->mac, var->val.string, MAC_LEN); +		auth->info_available |= FORMAT_CLIENT_MAC; +		if (!auth_ok(auth)) { +			auth->current_switch = l2_root_dev; +			auth_query_switch_info(auth); +			return 1; +		} +	} + +	auth_completed(auth); +	return 1; +} + +void start_authentication(blob_t token, blob_t ip) +{ +	struct auth_context *auth; +	oid query_oids[MAX_OID_LEN]; +	blob_t query; +	netsnmp_pdu *pdu; + +	num_queries++; + +	auth = calloc(1, sizeof(*auth)); +	auth->token = blob_cstr_dup(token); +	if (addr_parse(ip, &auth->addr) == NULL) { +		auth_completed(auth); +		return; +	} +	auth->info_available = FORMAT_CLIENT_IP; + +	/* IP-MIB::ipNetToPhysicalPhysAddress.<ifIndex>.ipv4."1.2.3.4" +	 *  = STRING: 01:12:34:56:78:9a */ +	pdu = snmp_pdu_create(SNMP_MSG_GET); + +	query = BLOB_OID(query_oids); +	blob_push(&query, BLOB_OID(IP_MIB_ipNetToPhysicalPhysAddress)); +	blob_push_oid(&query, l3_if_ndx); +	blob_push_iana_afn(&query, &auth->addr); +	query = blob_pushed(BLOB_OID(query_oids), query); +	snmp_add_null_var(pdu, oid_blob(query)); + +	auth_talk_snmp(auth, l3_root_dev->session, pdu, auth_handle_arp_reply); +} + +static void handle_line(blob_t line) +{ +	blob_t id, ipaddr; + +	id = blob_pull_cspn(&line, space); +	blob_pull_spn(&line, space); +	ipaddr = blob_pull_cspn(&line, space); + +	start_authentication(id, ipaddr); +} + +static void read_input(void) +{ +	static char buffer[256]; +	static blob_t left; + +	blob_t b, line; +	int r; + +	if (blob_is_null(left)) +		left = BLOB_BUF(buffer); + +	r = read(STDIN_FILENO, left.ptr, left.len); +	if (r < 0) +		return; +	if (r == 0) { +		running = 0; +		return; +	} +	left.ptr += r; +	left.len -= r; + +	b = blob_pushed(BLOB_BUF(buffer), left); +	do { +		line = blob_pull_cspn(&b, lf); +		if (!blob_pull_matching(&b, lf)) +			return; + +		handle_line(line); + +		if (b.len) { +			memcpy(buffer, b.ptr, b.len); +			b.ptr = buffer; +		} +		left = BLOB_PTR_LEN(buffer + b.len, sizeof(buffer) - b.len); +	} while (b.len); +} + +void load_topology(const char *file) +{ +	char a_ip[64], b_ip[64]; +	int a_port, b_port; +	FILE *in; + +	in = fopen(file, "r"); +	if (in == NULL) +		return; + +	while (!feof(in)) { +		if (fscanf(in, "%s %d %s %d\n", +			   a_ip, &a_port, b_ip, &b_port) == 4) +			link_switch(a_ip, a_port, b_ip, b_port); +	} +	fclose(in); +} + +int main(int argc, char **argv) +{ +	const char *l3_root = NULL, *l3_ifname = NULL; +	const char *l2_root = NULL, *l2_vlan = NULL; +	struct timeval timeout; +	sockaddr_any addr; +	fd_set fdset; +	int opt, fds, block, i; + +	setenv("MIBS", "", 1); +	init_snmp("squark-auth"); + +	while ((opt = getopt(argc, argv, "c:r:i:R:v:f:T:K")) != -1) { +		switch (opt) { +		case 'c': +			snmp_community = optarg; +			break; +		case 'r': +			l3_root = optarg; +			break; +		case 'i': +			l3_ifname = optarg; +			break; +		case 'R': +			l2_root = optarg; +			break; +		case 'v': +			l2_vlan = optarg; +			break; +		case 'f': +			username_format = optarg; +			break; +		case 'T': +			load_topology(optarg); +			break; +		case 'K': +			kick_out = TRUE; +			break; +		} +	} +	argc -= optind; +	argv += optind; + +	if (l3_root == NULL || l3_ifname == NULL || l2_vlan == NULL) { +		printf("Mandatory information missing\n"); +		return 1; +	} + +	sqdb_open(&db, "/var/lib/squark/squark.db"); +	authdb_open(&adb, &adbc, &db); + +	if (l2_root == NULL) +		l2_root = l3_root; + +	l3_root_dev = get_switch(addr_parse(BLOB_STRLEN(l3_root), &addr)); +	l3_if_ndx   = resolve_ifName2ifIndex(l3_root_dev, BLOB_STRLEN((char *) l3_ifname)); +	l2_root_dev = get_switch(addr_parse(BLOB_STRLEN(l2_root), &addr)); +	l2_vlan_ndx = atoi(l2_vlan); +	username_format_flags = parse_format(username_format); + +	if (kick_out) +		username_format_flags |= FORMAT_PORT_WEBAUTH; + +	for (i = 0; i < argc; i++) { +		blob_t b = BLOB_STRLEN(argv[i]); +		start_authentication(b, b); +		running = FALSE; +	} + +	fcntl(STDIN_FILENO, F_SETFL, O_NONBLOCK); +	while (num_queries || running) { +		fds = 0; +		block = 1; + +		FD_ZERO(&fdset); +		if (running) { +			FD_SET(STDIN_FILENO, &fdset); +			fds = STDIN_FILENO + 1; +		} +		snmp_select_info(&fds, &fdset, &timeout, &block); +		fds = select(fds, &fdset, NULL, NULL, block ? NULL : &timeout); +		cache_update_time(); +		if (fds) { +			if (FD_ISSET(STDIN_FILENO, &fdset)) +				read_input(); +			snmp_read(&fdset); +		} else +			snmp_timeout(); +	} +	authdb_close(&adb); +	sqdb_close(&db); + +	return 0; +} diff --git a/src/squark-filter.c b/src/squark-filter.c new file mode 100644 index 0000000..995da40 --- /dev/null +++ b/src/squark-filter.c @@ -0,0 +1,431 @@ +/* squark-filter.c - Squid User Authentication and Rating Kit + *   An external redirector for Squid which analyzes the URL according + *   to a database and can redirect to a block page. + * + * Copyright (C) 2010 Timo Teräs <timo.teras@iki.fi> + * All rights reserved. + * + * This program is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 as published + * by the Free Software Foundation. See http://www.gnu.org/ for details. + */ + +#include <time.h> +#include <stdio.h> +#include <ctype.h> +#include <string.h> +#include <unistd.h> + +#include <cmph.h> + +#include "blob.h" +#include "addr.h" +#include "filterdb.h" +#include "authdb.h" + +#define FILTER_OVERRIDE_TIMEOUT		(15*60) + +static struct sqdb db; +static struct authdb adb; +static struct authdb_config adbc; + +static int running = 1; +static const blob_t dash = BLOB_STR_INIT("-"); +static const blob_t space = BLOB_STR_INIT(" "); +static const blob_t slash = BLOB_STR_INIT("/"); +static const blob_t lf = BLOB_STR_INIT("\n"); +static struct authdb adb; +static time_t now; + +struct url_info { +	blob_t protocol; +	blob_t username; +	blob_t password; +	blob_t host; +	blob_t significant_host; +	blob_t path; +	blob_t query; +	blob_t fragment; +	int port; +	int is_ipv4; +	int num_dots; +}; + +struct url_dns_part_data { +	blob_t word; +	int num_dots; +	int numeric; +}; + +void blob_pull_url_dns_part(blob_t *b, struct url_dns_part_data *udp) +{ +	blob_t t = *b; +	int c, i, dots = 0, numeric = 1; + +	for (i = 0; i < t.len; i++) { +		c = (unsigned char) t.ptr[i]; +		switch (c) { +		case '.': +			dots++; +			break; +		case ':': case '@': case '/': case '?': +			*b = BLOB_PTR_LEN(t.ptr + i, t.len - i); +			udp->word = BLOB_PTR_LEN(t.ptr, i); +			udp->num_dots = dots; +			udp->numeric = numeric; +			return; +		case '0': case '1': case '2': case '3': case '4': +		case '5': case '6': case '7': case '8': case '9': +			break; +		default: +			numeric = 0; +			break; +		} +	} + +	*b = BLOB_NULL; +	udp->word = t; +	udp->num_dots = dots; +	udp->numeric = numeric; +} + +/* URI is generalized as: + * [proto://][user[:password]@]domain.name[:port][/[path/to][?p=a&q=b;r=c][#fragment]] + * Character literals used as separators are: + *  : / @ ? & ; # + * Also URI escaping says to treat %XX as encoded hex value. + */ + +static int url_parse(blob_t uri, struct url_info *nfo) +{ +	struct url_dns_part_data prev, cur; + +	memset(&prev, 0, sizeof(prev)); +	memset(nfo, 0, sizeof(*nfo)); + +	/* parse protocol, username/password and domain name/port */ +	do { +		blob_pull_url_dns_part(&uri, &cur); + +		switch (uri.len ? uri.ptr[0] : '/') { +		case ':': +			blob_pull_skip(&uri, 1); +			if (blob_is_null(nfo->protocol) && +			    blob_pull_matching(&uri, BLOB_STR("//"))) +				nfo->protocol = cur.word; +			else +				prev = cur; +			break; +		case '@': +			blob_pull_skip(&uri, 1); +			if (!blob_is_null(nfo->username) || +			    !blob_is_null(nfo->password)) +				goto error; +			if (!blob_is_null(prev.word)) { +				nfo->username = prev.word; +				nfo->password = cur.word; +			} else +				nfo->username = cur.word; +			memset(&prev, 0, sizeof(prev)); +			break; +		case '/': +		case '?': +			if (!blob_is_null(prev.word)) { +				nfo->host = prev.word; +				nfo->num_dots = prev.num_dots; +				nfo->is_ipv4 = prev.numeric && prev.num_dots == 3; +				nfo->port = blob_pull_uint(&cur.word, 10); +			} else { +				nfo->host = cur.word; +				nfo->num_dots = cur.num_dots; +				nfo->is_ipv4 = cur.numeric && cur.num_dots == 3; +			} +			if (blob_is_null(nfo->host)) +				nfo->host = BLOB_STR("localhost"); +			break; +		} +	} while (blob_is_null(nfo->host) && !blob_is_null(uri)); + +	/* rest of the components */ +	nfo->path = blob_pull_cspn(&uri, BLOB_STR("?&;#")); +	nfo->query = blob_pull_cspn(&uri, BLOB_STR("#")); +	nfo->fragment = uri; + +	/* fill in defaults if needed */ +	if (blob_is_null(nfo->protocol)) { +		if (nfo->port == 443) +			nfo->protocol = BLOB_STR("https"); +		else +			nfo->protocol = BLOB_STR("http"); +		if (nfo->port == 0) +			nfo->port = 80; +	} else if (nfo->port == 0) { +		if (blob_cmp(nfo->protocol, BLOB_STR("https")) == 0) +			nfo->port = 443; +		else +			nfo->port = 80; +	} +	if (blob_is_null(nfo->path)) +		nfo->path = BLOB_STR("/"); + +	/* significant host name */ +	nfo->significant_host = nfo->host; +	if (nfo->num_dots > 1) { +		blob_t b = nfo->significant_host; +		if (blob_pull_matching(&b, BLOB_STR("www")) && +		    (blob_pull_uint(&b, 10), 1) && +		    blob_pull_matching(&b, BLOB_STR("."))) +			nfo->significant_host = b; +	} +	return 1; + +error: +	return 0; +} + +static void url_print(struct url_info *nfo) +{ +#define print_field(nfo, x) if (!blob_is_null(nfo->x)) printf(" %s{%.*s}", #x, nfo->x.len, nfo->x.ptr) +	print_field(nfo, protocol); +	print_field(nfo, username); +	print_field(nfo, password); +	print_field(nfo, host); +	printf(" port{%d}", nfo->port); +	print_field(nfo, path); +	print_field(nfo, query); +	print_field(nfo, fragment); +#undef print_field +	printf("\n"); +	fflush(stdout); +} + +static int url_classify(struct url_info *url, struct sqdb *db) +{ +	unsigned char buffer[512]; +	blob_t key, got, tld, keybuf, keylimits; +	void *cmph; +	struct sqdb_index_entry *indx; +	cmph_uint32 i = SQDB_PARENT_ROOT, previ = SQDB_PARENT_ROOT; +	int dots_done = 1; + +	cmph = sqdb_section_get(db, SQDB_SECTION_INDEX_MPH, NULL); +	indx = sqdb_section_get(db, SQDB_SECTION_INDEX, NULL); + +	keybuf = BLOB_BUF(buffer); +	blob_push_lower(&keybuf, url->significant_host); +	key = keylimits = blob_pushed(BLOB_BUF(buffer), keybuf); + +	/* search for most qualified domain match; do first lookup +	 * with two domain components */ +	if (url->is_ipv4) { +		i = cmph_search_packed(cmph, key.ptr, key.len); + +		if (indx[i].parent != SQDB_PARENT_IPV4 || +		    indx[i].component != blob_inet_addr(url->host)) { +			i = previ; +			goto parent_dns_match; +		} +	} else { +		key = BLOB_PTR_LEN(key.ptr + key.len, 0); +		tld = blob_expand_head(&key, keylimits, '.'); + +		do { +			/* add one more domain component */ +			got = blob_expand_head(&key, keylimits, '.'); +			if (blob_is_null(got)) +				break; + +			previ = i; +			i = cmph_search_packed(cmph, key.ptr, key.len); +			if (!blob_is_null(tld)) { +				int p = indx[i].parent; + +				if (p == SQDB_PARENT_ROOT || +				    p == SQDB_PARENT_IPV4 || +				    indx[p].parent != SQDB_PARENT_ROOT || +				    blob_cmp(tld, sqdb_get_string_literal(db, indx[p].component)) != 0) { +					/* top level domain did not match */ +					i = -1; +					goto parent_dns_match; +				} +				tld = BLOB_NULL; +				previ = p; +			} +			if (indx[i].parent != previ || +			    blob_cmp(got, sqdb_get_string_literal(db, indx[i].component)) != 0) { +				/* the subdomain did no longer match, use +				 * parents classification */ +				i = previ; +				goto parent_dns_match; +			} +			dots_done++; +		} while (indx[i].has_subdomains); +	} + +	/* No paths to match for */ +	if (i == SQDB_PARENT_ROOT || !indx[i].has_paths || key.ptr != keylimits.ptr) +		goto parent_dns_match; + +	/* and then search for path matches -- construct hashing +	 * string of url decoded path */ +	blob_push_urldecode(&keybuf, url->path); +	key = keylimits = blob_pushed(BLOB_BUF(buffer), keybuf); + +	while (indx[i].has_paths) { +		/* add one more path component */ +		got = blob_expand_tail(&key, keylimits, '/'); +		if (blob_is_null(got)) +			break; +		previ = i; +		i = cmph_search_packed(cmph, key.ptr, key.len); +		tld = sqdb_get_string_literal(db, indx[i].component); +		if (blob_cmp(got, sqdb_get_string_literal(db, indx[i].component)) != 0) { +			/* the subdomain did no longer match, use  +			* parents classification */ +			i = previ; +			goto parent_dns_match; +		} +	} + +parent_dns_match: +	if (i == SQDB_PARENT_ROOT) +		return 0; /* no category */ + +	return indx[i].category; +} + +static blob_t get_category_name(struct sqdb *db, int id) +{ +	uint32_t *c, clen; + +	c = sqdb_section_get(db, SQDB_SECTION_CATEGORIES, &clen); +	if (c == NULL || id < 0 || id * sizeof(uint32_t) >= clen) +		return BLOB_NULL; + +	return sqdb_get_string_literal(db, c[id]); +} + +static void send_ok(blob_t tag) +{ +	static char buffer[64]; +	blob_t b = BLOB_BUF(buffer); + +	blob_push(&b, tag); +	blob_push(&b, lf); +	b = blob_pushed(BLOB_BUF(buffer), b); + +	write(STDOUT_FILENO, b.ptr, b.len); +} + +static void send_redirect(blob_t redirect_page, blob_t tag, blob_t url, blob_t categ, blob_t username) +{ +	static char buffer[8*1024]; +	blob_t b = BLOB_BUF(buffer); + +	blob_push(&b, tag); +	blob_push(&b, BLOB_STR(" 302:")); +	blob_push(&b, adbc.redirect_url_base); +	blob_push(&b, redirect_page); +	blob_push(&b, BLOB_STR("?REASON=")); +	blob_push_urlencode(&b, categ); +	blob_push(&b, BLOB_STR("&USER=")); +	blob_push_urlencode(&b, username); +	blob_push(&b, BLOB_STR("&DENIEDURL=")); +	blob_push_urlencode(&b, url); +	blob_push(&b, lf); +	b = blob_pushed(BLOB_BUF(buffer), b); + +	write(STDOUT_FILENO, b.ptr, b.len); +} + +static void read_input(struct sqdb *db) +{ +	static char buffer[8 * 1024]; +	static blob_t left; + +	blob_t b, line, id, ipaddr, url, username; +	struct url_info nfo; +	int r, category, auth_ok; +	sockaddr_any addr; +	struct authdb_entry entry; +	void *token; + +	if (blob_is_null(left)) +		left = BLOB_BUF(buffer); + +	r = read(STDIN_FILENO, left.ptr, left.len); +	if (r < 0) +		return; +	if (r == 0) { +		running = 0; +		return; +	} +	left.ptr += r; +	left.len -= r; + +	now = time(NULL); + +	b = blob_pushed(BLOB_BUF(buffer), left); +	do { +		line = blob_pull_cspn(&b, lf); +		if (!blob_pull_matching(&b, lf)) +			return; + +		id = blob_pull_cspn(&line, space); +		blob_pull_spn(&line, space); +		url = blob_pull_cspn(&line, space); +		blob_pull_spn(&line, space); +		ipaddr = blob_pull_cspn(&line, slash); /* client addr */ +		blob_pull_cspn(&line, space); /* fqdn */ +		blob_pull_spn(&line, space); +		username = blob_pull_cspn(&line, space); +		/* http method */ +		/* urlgroup */ +		/* myaddr=xxx myport=xxx etc */ + +		if (!blob_is_null(url) && +		    addr_parse(ipaddr, &addr)) { +			/* valid request, handle it */ +			if (url_parse(url, &nfo)) +				category = url_classify(&nfo, db); +			else +				category = 0; + +			token = authdb_get(&adb, &addr, &entry, 1); +			if (authdb_check_login(token, &entry, username, now)) { +				auth_ok = 1; +				username = BLOB_STRLEN(entry.p.login_name); +			} else { +				auth_ok = 0; +			} + +			if (!auth_ok) { +				send_redirect(BLOB_STR("login.cgi"), id, url, BLOB_STR("auth"), username); +			} else if (((1ULL << category) & entry.p.block_categories) && +				   (now < entry.override_time || +				    now > entry.override_time + FILTER_OVERRIDE_TIMEOUT || +				   ((1ULL << category) & entry.p.hard_block_categories))) { +				send_redirect(BLOB_STR("warning.cgi"), id, url, get_category_name(db, category), username); +			} else +				send_ok(id); +		} + +		if (b.len) { +			memcpy(buffer, b.ptr, b.len); +			b.ptr = buffer; +		} +		left = BLOB_PTR_LEN(buffer + b.len, sizeof(buffer) - b.len); +	} while (b.len); +} + +int main(int argc, char **argv) +{ +	sqdb_open(&db, "/var/lib/squark/squark.db"); +	authdb_open(&adb, &adbc, &db); + +	while (running) +		read_input(&db); + +	sqdb_close(&db); +	authdb_close(&adb); +} | 
