#!/usr/bin/lua5.2

--[[
Certificate Authority tool for Dynamic Multipoint VPN
Copyright (C) 2014-2017 Kaarle Ritvanen
Copyright (C) 2015 Timo Teräs
Copyright (C) 2017 Natanael Copa

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
]]--

-- Configuration is read from config.yaml in the current directory. An empty
-- file yields an empty table so that the defaults below still apply.
local config
do
  local conf_file = assert(io.open('config.yaml'))
  config = require('lyaml').load(conf_file:read('*a')) or {}
  conf_file:close()
end

local asn1 = require('asn1')
local rfc3779 = require('asn1.rfc3779')
local rfc5280 = require('asn1.rfc5280')
local pkcs12 = require('openssl.pkcs12')
local pkey = require('openssl.pkey')
local x509 = require('openssl.x509')
local x509an = require('openssl.x509.altname')
local x509ch = require('openssl.x509.chain')
local x509crl = require('openssl.x509.crl')
local x509ext = require('openssl.x509.extension')
local x509n = require('openssl.x509.name')
local stat = require('posix.sys.stat')
local unistd = require('posix.unistd')
local stringy = require('stringy')

-- Recursively copy entries of `defaults` into `config` where `config` has no
-- (truthy) value; nested default tables are merged into existing tables.
local function set_config_defaults(config, defaults)
  for k, v in pairs(defaults) do
    if not config[k] then
      config[k] = v
    elseif type(v) == 'table' then
      set_config_defaults(config[k], v)
    end
  end
end

set_config_defaults(
  config,
  {
    ['db-file']='dmvpn-ca.sqlite3',
    cert={
      lifetime=365 * 24 * 60 * 60,
      ['hash-alg']='SHA256',
      key={type='EC', curve='secp384r1'}
    },
    ca={dn='O=DMVPN'},
    hub={
      dn='$ROOT,OU=Hubs,CN=$NAME',
      ['default-name']='Hub-$ID',
      subnets={'10.0.0.0/8'}
    },
    spoke={
      dn='$ROOT,OU=$SITE,CN=$NAME',
      ['default-name']='VPNc-$ID'
    },
    crl={},
    ['password-length']=16,
    ['default-asn']=65000
  }
)

-- Each certificate type inherits the generic `cert` settings.
for _, ct in ipairs{'ca', 'hub', 'spoke', 'crl'} do
  set_config_defaults(config[ct], config.cert)
end

-- Timestamp used consistently for validity checks and revocations in this run.
local now = os.time()

-- Files created by this tool (database, PKCS#12 exports) get no group/other
-- permissions.
stat.umask(bit32.bor(stat.S_IRWXG, stat.S_IRWXO))


-- Database helpers ----------------------------------------------------------

local sql, conn

-- Lazily open the SQLite database named by config['db-file'] with autocommit
-- disabled; the connection is committed at the very end of the script.
local function connection()
  if not sql then
    sql = require('luasql.sqlite3').sqlite3()
    conn = assert(sql:connect(config['db-file']))
    conn:setautocommit(false)
  end
  return conn
end

-- Execute an SQL statement, raising the driver's message on failure.
local function execute(statement)
  local res, err = connection():execute(statement)
  if res then return res end
  error(err)
end

-- Render a Lua value as an SQL literal: nil/false -> NULL, strings are quoted
-- and escaped, everything else is passed through unchanged.
local function escape(s)
  if not s then return 'NULL' end
  return type(s) == 'string' and "'"..connection():escape(s).."'" or s
end

-- Turn {column=value, ...} into a list of "column = literal" strings.
local function value_list(values)
  local res = {}
  for k, v in pairs(values) do
    table.insert(res, k..' = '..escape(v))
  end
  return res
end

-- INSERT a row given as {column=value, ...} into table `tbl`.
local function insert(tbl, values)
  local columns = {}
  local escaped = {}
  for k, v in pairs(values) do
    table.insert(columns, k)
    table.insert(escaped, escape(v))
  end
  execute(
    'INSERT INTO '..tbl..'('..table.concat(columns, ',')..
    ') VALUES ('..table.concat(escaped, ',')..')'
  )
end

-- Build a WHERE clause. `filter` may be nil (no clause), a raw SQL condition
-- string, or a {column=value} table whose entries are ANDed together.
local function where(filter)
  if not filter then return '' end
  if type(filter) == 'table' then
    filter = value_list(filter)
    if not next(filter) then return '' end
    filter = table.concat(filter, ' AND ')
  end
  return ' WHERE '..filter
end

local function update(tbl, values, filter)
  execute(
    'UPDATE '..tbl..' SET '..
    table.concat(value_list(values), ', ')..where(filter)
  )
end

local function delete(tbl, filter)
  execute('DELETE FROM '..tbl..where(filter))
end

-- Return an iterator over the rows of a SELECT. With `mode` ('a' or 'n') each
-- row is a fresh table; without it only the first column value is yielded.
-- The cursor is closed once the iterator is exhausted.
local function select_many(what, from, filter, mode)
  local cur = execute('SELECT '..what..' FROM '..from..where(filter))
  return function()
    local row = cur:fetch(mode and {}, mode)
    if row == nil then cur:close() end
    return row
  end
end

local function select_certs(filter)
  return select_many('*', 'certificate', filter, 'a')
end

-- Like select_many but returns the single matching row (or nil); asserts
-- that the query does not match more than one row.
local function select_one(...)
  local res
  for row in select_many(...) do
    assert(res == nil)
    res = row
  end
  return res
end

-- Tell whether a row matching `filter` exists in `tbl`.
-- NOTE: when `active_only` is set, `active = '1'` is added to the caller's
-- `filter` table in place.
local function exists(tbl, filter, active_only)
  if active_only then filter.active = '1' end
  return select_one('0', tbl, filter) and true or false
end

-- Next free value of integer column `key` (MAX + 1, or 1 for no rows).
local function next_key(tbl, key, filter)
  return (select_one('MAX('..key..')', tbl, filter) or 0) + 1
end


-- Validation helpers --------------------------------------------------------

-- Convert `s` to an integer within [min, max] (`max` optional). On failure
-- raise "Invalid <desc>: <s>" if `desc` is given, otherwise return nil.
local function toint(s, min, max, desc)
  local i = tonumber(s)
  if i and i == math.floor(i) and i >= min and (not max or i <= max) then
    return i
  end
  if desc then error('Invalid '..desc..': '..tostring(s)) end
end

local function toid(s) return toint(s, 1, nil, 'ID') end

-- Return the address family identifier used in the database: 2 if `s`
-- contains a colon, otherwise 1. The address is first added to a throw-away
-- altname object, presumably so that the library rejects malformed input.
local function detect_afi(s)
  x509an.new():add('IP', s)
  return s:find(':') and 2 or 1
end

-- Validate an "address/prefix-length" string and return its address family.
local function detect_prefix_afi(s)
  local function fail() error('Invalid network address: '..s) end
  local comps = stringy.split(s, '/')
  if #comps ~= 2 then fail() end
  local afi = detect_afi(comps[1])
  if not (afi and toint(comps[2], 0, ({32, 128})[afi])) then fail() end
  return afi
end


-- Certificate helpers -------------------------------------------------------

local ca_cert, ca_key

-- Load (once) and return the root certificate and its private key, stored in
-- the certificate table under serial 0.
local function load_ca()
  if not ca_cert then
    local row = select_one(
      'data, privateKey', 'certificate', {serial=0}, 'n'
    )
    if not row then error('Root certificate not found') end
    ca_cert = x509.new(row[1], 'PEM')
    ca_key = pkey.new(row[2], 'PEM', 'private')
  end
  return ca_cert, ca_key
end

-- Set issuer and authorityKeyIdentifier on `object` (certificate or CRL) and
-- sign it. Without an explicit `cert`/`key` pair the root CA is used.
local function sign(object, hash_alg, cert, key)
  if not cert then cert, key = load_ca() end
  object:setIssuer(cert:getSubject())
  object:addExtension(
    x509ext.new(
      'authorityKeyIdentifier',
      'DER',
      rfc5280.AuthorityKeyIdentifier.encode{
        keyIdentifier=cert:getPublicKeyDigest()
      }
    )
  )
  return object:sign(key, hash_alg)
end

-- Generate a key pair and certificate described by `attrs` (serial, dn,
-- params, plus database columns), store it in the certificate table and
-- return the updated `attrs`. Serial 0 denotes the self-signed root
-- certificate. `func(cert, attrs)`, if given, may add extensions before
-- signing.
local function issue_cert(attrs, func)
  local ca = attrs.serial == 0
  local key = pkey.new(attrs.params.key)

  local cert = x509.new()
  cert:setVersion(3)
  cert:setSerial(attrs.serial)
  cert:setPublicKey(key)

  local dn = x509n.new()
  for _, rdn in ipairs(stringy.split(attrs.dn, ',')) do
    local attr, value = string.match(rdn, '^%s*(%a+)=(.*)$')
    if not attr then error('Invalid RDN: '..rdn) end
    dn:add(attr, value)
  end
  cert:setSubject(dn)

  cert:setBasicConstraints{CA=ca}
  cert:setBasicConstraintsCritical(true)

  -- only the first value returned by getLifetime is used as the issue time
  local issued = cert:getLifetime()
  local expires = issued + attrs.params.lifetime
  cert:setLifetime(issued, expires)
  attrs.issued = issued
  attrs.expires = expires
  attrs.privateKey = key:toPEM('private')

  cert:addExtension(
    x509ext.new(
      'subjectKeyIdentifier',
      'DER',
      rfc5280.KeyIdentifier.encode(cert:getPublicKeyDigest())
    )
  )

  local crl_dp = config.crl['dist-point']
  if crl_dp then
    cert:addExtension(
      x509ext.new(
        'crlDistributionPoints',
        'DER',
        rfc5280.CRLDistributionPoints.encode{
          {
            distributionPoint={
              fullName={{uniformResourceIdentifier=crl_dp}}
            }
          }
        }
      )
    )
  end

  if func then func(cert, attrs) end

  -- the root certificate signs itself; `false` makes sign() load the CA
  sign(cert, attrs.params['hash-alg'], ca and cert, ca and key)

  attrs.data = tostring(cert)
  attrs.dn = tostring(dn)
  -- drop fields that are not columns of the certificate table
  attrs.params = nil
  attrs.sname = nil
  attrs.vname = nil
  insert('certificate', attrs)
  return attrs
end

local PASSWORD_ALPHABET =
  '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'

-- Generate a password of `length` characters from [0-9A-Za-z], drawing each
-- character directly from /dev/urandom. Bytes >= 248 (= 4 * 62) are rejected
-- so that every character is equally likely.
local function random_password(length)
  local dev = assert(io.open('/dev/urandom', 'rb'))
  local chars = {}
  while #chars < length do
    local byte = assert(dev:read(1), 'Cannot read /dev/urandom'):byte()
    if byte < 248 then
      local i = byte % 62 + 1
      chars[#chars + 1] = PASSWORD_ALPHABET:sub(i, i)
    end
  end
  dev:close()
  return table.concat(chars)
end

-- Write certificate row `cert`, its private key and the CA certificate to a
-- password-protected PKCS#12 file in the current directory. The generated
-- password is part of the file name: <site>_<vpnc>.<password>.pfx
local function export_cert(cert)
  local password = random_password(config['password-length'])

  local ca_cert = load_ca()
  local chain = x509ch.new()
  chain:add(ca_cert)
  chain:add(x509.new(cert.data, 'PEM'))

  local file = assert(
    io.open(
      ('%s_%s.%s.pfx'):format(cert.site, cert.vpnc, password), 'w'
    )
  )
  file:write(
    tostring(
      pkcs12.new{
        key=pkey.new(cert.privateKey, 'PEM', 'private'),
        certs=chain,
        password=password
      }
    )
  )
  file:close()
end


-- Output helpers ------------------------------------------------------------

-- Prepend a header row to a non-empty list of rows; empty lists stay empty.
local function add_header(tbl, columns)
  if tbl[1] then table.insert(tbl, 1, columns) end
  return tbl
end

-- Convert certificate rows to printable rows (with header if non-empty).
local function format_cert_info(certs)
  local res = {}
  local function format_ts(timestamp) return os.date('%x %X', timestamp) end
  for _, cert in ipairs(certs) do
    local exp_info
    if cert.revoked then
      exp_info = 'Revoked '..format_ts(cert.revoked)
    else
      exp_info = 'Expire'..
        (cert.expires < now and 'd' or 's')..' '..
        format_ts(cert.expires)
    end
    table.insert(
      res,
      {cert.serial, cert.dn, format_ts(cert.issued), exp_info}
    )
  end
  return add_header(res, {'serial', 'dn', 'issued', 'validity'})
end

local function print_cert(cert)
  print(x509.new(cert.data, 'PEM'):text{'ext_parse'})
end

local function is_valid(cert)
  return not cert.revoked and now < cert.expires
end

-- Mark all currently valid certificates matching `filter` as revoked and
-- return them.
local function revoke(filter)
  local revoked = {}
  for cert in select_certs(filter) do
    if is_valid(cert) then
      update(
        'certificate', {revoked=now}, {serial=cert.serial}
      )
      cert.revoked = now
      table.insert(revoked, cert)
    end
  end
  return revoked
end


-- Command line scanning -----------------------------------------------------

-- Consume and return the next command line argument. When none is left,
-- raise "Missing <desc>" if `desc` is truthy, otherwise return nothing.
local function scan_next(desc)
  if #arg == 0 then
    if desc then error('Missing '..desc) end
    return
  end
  local res = arg[1]
  table.remove(arg, 1)
  return res
end

local function scan_addr()
  local addr = scan_next('IP address')
  detect_afi(addr)
  return addr
end

local function scan_subnet()
  local addr = scan_next('network address')
  detect_prefix_afi(addr)
  return addr
end

-- Consume the next argument and return the entry of `choices` keyed by it.
local function scan_choice(choices, desc)
  local token = scan_next(desc)
  if not token then return end
  local choice = choices[token]
  if choice ~= nil then return choice end
  error('Invalid '..desc..': '..token)
end

-- Consume a "<keyword> <value>" pair where the keyword must be one of
-- `choices` (a string or a list of strings). Returns keyword and value, or
-- nothing when `optional` is set and no arguments remain.
local function scan_param(choices, desc, optional)
  if type(choices) == 'string' then choices = {choices} end
  local k = scan_next(not optional and desc)
  if not k then return end
  for _, v in ipairs(choices) do
    if k == v then return k, scan_next(k..' specification') end
  end
  error('Invalid '..desc..': '..k)
end

local function scan_site_code() return scan_next('site code'):upper() end

-- Return the upper-cased site code if such a site exists (and is active,
-- unless `retired` is set); raise an error otherwise.
local function validate_site(code, retired)
  code = code:upper()
  if exists('site', {code=code}, not retired) then return code end
  error('Invalid site code: '..code)
end

local function scan_site(retired)
  return validate_site(scan_site_code(), retired)
end

-- Scan "site <code>" (keyword overridable via options.param). The selector
-- is mandatory unless options.required is false.
local function scan_site_selector(options)
  options = options or {}
  local _, code = scan_param(
    options.param or 'site', 'selector', options.required == false
  )
  if code then return validate_site(code, options.retired) end
end

-- Optional variant of scan_site_selector returning {code=<code>} or nil.
local function scan_site_filter(options)
  options = options or {}
  options.required = false
  local code = scan_site_selector(options)
  if code then return {code=code} end
end

local function site_filter(filter, column)
  if not filter then return end
  if not column then column = 'site' end
  return {[column]=filter.code}
end

-- Scan a VPNc selector: "hub <n>", "site <code> [vpnc <n>]" or, when
-- options.multiple is set, the literal "hubs" or nothing at all. Returns
-- {site=..., id=...} (hubs have site ''), or nil when nothing was selected.
local function scan_vpnc(options)
  options = options or {}

  if options.multiple and arg[1] == 'hubs' then return {site=''} end

  local k, v = scan_param({'hub', 'site'}, 'selector', options.multiple)
  if not k then return end

  local res, obj
  if k == 'hub' then
    res = {site='', id=v}
    obj = 'hub'
  else
    local _, vpnc = scan_param(
      options.id_attr or 'vpnc', 'VPNc number', true
    )
    -- a single-VPNc selector defaults to VPNc number 1
    if not (vpnc or options.multiple) then vpnc = 1 end
    res = {site=validate_site(v, options.retired), id=vpnc}
    obj = 'VPNc'
  end

  if res.id and not exists('vpnc', res, not options.retired) then
    error('Invalid '..obj..' number: '..res.id)
  end
  return res
end

local function vpnc_filter(filter, site_column, id_column)
  if not filter then return end
  if not site_column then site_column = 'site' end
  if not id_column then id_column = 'vpnc' end
  return {[site_column]=filter.site, [id_column]=filter.id}
end

-- Raise an error if unconsumed command line arguments remain.
local function scan_finished()
  if #arg > 0 then error('Invalid parameter: '..scan_next()) end
end


-- Listing helpers -----------------------------------------------------------

local function display_active(value)
  return ({['0']='no', ['1']='yes'})[value]
end

-- Return the selected `columns` of `tbl` as printable rows preceded by a
-- header row; NULL values are shown as empty strings.
local function show(tbl, columns, filter, mangle)
  local output = {columns}
  for row in select_many(table.concat(columns, ','), tbl, filter, 'a') do
    if row.active then row.active = display_active(row.active) end
    if mangle then mangle(row) end
    local line = {}
    for _, column in ipairs(columns) do
      table.insert(line, row[column] or '')
    end
    table.insert(output, line)
  end
  return output
end

local function show_site(filter)
  return show('site', {'code', 'name', 'asn', 'active'}, filter)
end

-- List VPNcs with their GRE addresses. `site` is a {code=...} filter (code ''
-- selects the hubs) or nil for all non-hub sites; `vpnc` optionally limits
-- the listing to one id and is only used together with `site`.
local function show_vpnc(site, vpnc)
  local filter = site and site_filter(site, 'v.site') or "v.site > ''"
  if vpnc then filter['v.id'] = vpnc end

  local columns = {
    'v.id', 'v.name', 'v.active', 'a1.address', 'a2.address'
  }
  local titles = {'vpnc', 'name', 'active', 'ipv4Address', 'ipv6Address'}

  -- `i` is the position of the name column; active follows at i + 1
  local i = 2
  if not site then
    table.insert(columns, 1, 'v.site')
    table.insert(titles, 1, 'site')
    i = 3
  elseif site.code == '' then
    titles[1] = 'hub'
  end

  local output = {}
  for row in select_many(
    table.concat(columns, ', '),
    [[
      vpnc v
      LEFT JOIN greAddress a1 ON
        v.site = a1.site AND v.id = a1.vpnc AND a1.family = 1
      LEFT JOIN greAddress a2 ON
        v.site = a2.site AND v.id = a2.vpnc AND a2.family = 2
    ]],
    filter,
    'n'
  ) do
    row[i + 1] = display_active(row[i + 1])
    -- NULL columns (missing name or addresses) would leave holes in the
    -- row, which print_table cannot handle; show them as empty cells
    for idx = 1, #titles do
      if row[idx] == nil then row[idx] = '' end
    end
    table.insert(output, row)
  end
  return add_header(output, titles)
end


-- Object manipulation -------------------------------------------------------

local function add_subnet(site, addr)
  insert(
    'subnet',
    {site=site, family=detect_prefix_afi(addr), address=addr}
  )
end

-- Deactivate the given VPNc (or all VPNcs of `site` when `id` is nil),
-- delete its GRE addresses and revoke its valid certificates. Returns the
-- revoked certificates as printable rows.
local function retire_vpnc(site, id)
  local filter = {site=site, id=id}
  local vf = vpnc_filter(filter)
  local revoked = revoke(vf)
  delete('greAddress', vf)
  update('vpnc', {active='0'}, filter)
  return format_cert_info(revoked)
end

local function add_site(attrs)
  if not attrs.asn then attrs.asn = config['default-asn'] end
  insert('site', attrs)
end

-- Fetch a certificate row by serial number; expects the command line to be
-- fully consumed.
local function get_cert_by_serial(serial)
  scan_finished()
  serial = toint(serial, 0, nil, 'serial number')
  local cert = select_one('*', 'certificate', {serial=serial}, 'a')
  if not cert then error('Invalid serial number') end
  return cert
end

-- Return a generic-for iterator over the certificates selected on the
-- command line: "serial <n>" (yields the certificate and the selector name
-- 'serial'), a VPNc selector, or all non-root certificates.
local function get_certs()
  if arg[1] == 'serial' then
    local _, serial = scan_param('serial', 'serial number')
    -- stateless iterator: the certificate is the invariant state and is
    -- yielded exactly once
    return function(cert, v)
      if not v then return cert, 'serial' end
    end, get_cert_by_serial(serial)
  end

  return select_certs(
    vpnc_filter(scan_vpnc{multiple=true, retired=true}) or 'serial > 0'
  )
end

-- Build and sign a CRL listing revoked, not yet expired certificates, and
-- increment the stored CRL number.
local function generate_crl()
  local crl = x509crl.new()
  crl:setVersion(2)

  local filter = {name='next-crl-number'}
  local serial = select_one('value', 'counter', filter)
  update('counter', {value=serial + 1}, filter)
  crl:addExtension(
    x509ext.new(
      'crlNumber', 'DER', rfc5280.CRLNumber.encode(serial)
    )
  )

  local timestamp = crl:getLastUpdate()
  crl:setNextUpdate(timestamp + config.crl.lifetime)

  for cert in select_certs() do
    if cert.expires > timestamp and cert.revoked then
      crl:add(cert.serial, cert.revoked)
    end
  end

  sign(crl, config.crl['hash-alg'])
  return crl
end

-- Print a list of rows as a table with space-padded, left-aligned columns.
local function print_table(tbl)
  local colwidth = {}
  for _, row in ipairs(tbl) do
    for i, col in ipairs(row) do
      colwidth[i] = math.max(colwidth[i] or 0, string.len(col))
    end
  end

  for _, row in ipairs(tbl) do
    for i = 1, #row do
      if i > 1 then io.write('  ') end
      io.write(row[i])
      -- the last column is not padded
      if i < #row then
        io.write(string.rep(' ', colwidth[i] - string.len(row[i])))
      end
    end
    io.write('\n')
  end
end

-- When running interactively and `tbl` is non-empty, show the affected
-- objects and exit unless the user answers 'y'.
local function confirm(object, action, tbl)
  if not (unistd.isatty(0) and tbl[1]) then return end
  io.write('The following '..object..' will be '..action..':\n\n')
  print_table(tbl)
  io.write('\nContinue (y/n)? ')
  io.stdout:flush()
  local input = io.read()
  if not input or input:lower() ~= 'y' then os.exit() end
  io.write('\n')
end


-- Command dispatch ----------------------------------------------------------

-- Database schema created by "root-cert generate".
local schema = {
  [[
    CREATE TABLE counter (
      name VARCHAR(16) NOT NULL PRIMARY KEY,
      value INTEGER NOT NULL DEFAULT 1
    )
  ]],
  "INSERT INTO counter (name) VALUES ('next-crl-number')",
  [[
    CREATE TABLE site (
      code VARCHAR(16) NOT NULL PRIMARY KEY,
      asn INTEGER NOT NULL,
      name VARCHAR(32),
      active CHAR(1) DEFAULT '1'
    )
  ]],
  [[
    CREATE TABLE subnet (
      site VARCHAR(16) NOT NULL REFERENCES site(code),
      family INTEGER NOT NULL,
      address CHAR(14) NOT NULL,
      PRIMARY KEY(site, family, address)
    )
  ]],
  [[
    CREATE TABLE vpnc (
      site VARCHAR(16) NOT NULL REFERENCES site(code),
      id INTEGER NOT NULL,
      name VARCHAR(32),
      active CHAR(1) DEFAULT '1',
      PRIMARY KEY(site, id)
    )
  ]],
  [[
    CREATE TABLE greAddress (
      site VARCHAR(16) NOT NULL,
      vpnc INTEGER NOT NULL,
      family INTEGER NOT NULL,
      address CHAR(14) NOT NULL,
      PRIMARY KEY(site, vpnc, family),
      FOREIGN KEY(site, vpnc) REFERENCES vpnc(site, id),
      UNIQUE(family, address)
    )
  ]],
  [[
    CREATE TABLE certificate (
      serial INTEGER NOT NULL PRIMARY KEY,
      site VARCHAR(16),
      vpnc INTEGER,
      dn VARCHAR(128) NOT NULL,
      issued DATETIME NOT NULL,
      expires DATETIME NOT NULL,
      revoked DATETIME,
      privateKey TEXT NOT NULL,
      data TEXT NOT NULL,
      FOREIGN KEY(site, vpnc) REFERENCES vpnc(site, id)
    )
  ]]
}

-- commands[<object type>][<action>] is a handler that scans its own
-- remaining arguments and optionally returns a table to print.
local commands = {
  site={
    add=function()
      local site = scan_site_code()
      scan_finished()
      add_site{code=site}
    end,
    set=function()
      local k, v = scan_param({'asn', 'name'}, 'attribute')
      local site = scan_site()
      scan_finished()
      if k == 'asn' then
        v = toint(v, 0, 2 ^ 16 - 1, 'AS number')
      end
      update('site', {[k]=v}, {code=site})
    end,
    unset=function()
      local k = scan_choice({name='name'}, 'attribute')
      local site = scan_site()
      scan_finished()
      -- false is rendered as NULL by escape()
      update('site', {[k]=false}, {code=site})
    end,
    show=function()
      local site = scan_site_filter{param='code', retired=true}
      scan_finished()
      return show_site(site or "code > ''")
    end,
    retire=function()
      local site = scan_site()
      scan_finished()
      confirm('site', 'retired', show_site{code=site})
      delete('subnet', {site=site})
      local output = retire_vpnc(site)
      update('site', {active='0'}, {code=site})
      return output
    end
  },
  subnet={
    add=function()
      local addr = scan_subnet()
      local site = scan_site_selector()
      scan_finished()
      add_subnet(site, addr)
    end,
    show=function()
      local site = scan_site_filter()
      scan_finished()
      return show(
        'subnet',
        {'site', 'address'},
        site and site_filter(site) or "site > ''"
      )
    end,
    del=function()
      local addr = scan_subnet()
      local filter = site_filter(scan_site_filter())
      scan_finished()
      if not filter then filter = {} end
      filter.address = addr
      if not exists('subnet', filter) then
        error('Address does not exist: '..addr)
      end
      delete('subnet', filter)
    end
  },
  vpnc={
    create=function()
      local site = scan_site_selector()
      local _, id = scan_param('id', 'VPNc ID', true)
      scan_finished()
      insert(
        'vpnc',
        {
          site=site,
          id=id and toid(id) or next_key('vpnc', 'id', {site=site})
        }
      )
    end,
    set=function()
      local _, name = scan_param('name', 'attribute')
      local vpnc = scan_vpnc{id_attr='id'}
      scan_finished()
      update('vpnc', {name=name}, vpnc)
    end,
    unset=function()
      scan_choice({name=true}, 'attribute')
      local vpnc = scan_vpnc{id_attr='id'}
      scan_finished()
      update('vpnc', {name=false}, vpnc)
    end,
    show=function()
      local site = scan_site_filter{retired=true}
      scan_finished()
      return show_vpnc(site)
    end,
    retire=function()
      local id = toid(scan_next('VPNc number'))
      local site = scan_site_selector()
      scan_finished()
      confirm('VPNc', 'retired', show_vpnc({code=site}, id))
      return retire_vpnc(site, id)
    end
  },
  ['gre-addr']={
    add=function()
      local addr = scan_addr()
      local vpnc = scan_vpnc()
      scan_finished()
      insert(
        'greAddress',
        {
          site=vpnc.site,
          vpnc=vpnc.id,
          family=detect_afi(addr),
          address=addr
        }
      )
    end,
    del=function()
      local addr = scan_addr()
      scan_finished()
      delete('greAddress', {address=addr})
    end
  },
  hub={
    create=function()
      local id = scan_next()
      scan_finished()
      insert(
        'vpnc',
        {
          site='',
          id=id and toid(id) or next_key('vpnc', 'id', {site=''})
        }
      )
    end,
    set=function()
      local _, name = scan_param('name', 'attribute')
      local id = toid(scan_next('hub number'))
      scan_finished()
      update('vpnc', {name=name}, {site='', id=id})
    end,
    unset=function()
      scan_choice({name=true}, 'attribute')
      local id = toid(scan_next('hub number'))
      scan_finished()
      update('vpnc', {name=false}, {site='', id=id})
    end,
    show=function()
      scan_finished()
      return show_vpnc{code=''}
    end,
    retire=function()
      local id = toid(scan_next('hub number'))
      scan_finished()
      confirm('hub', 'retired', show_vpnc({code=''}, id))
      return retire_vpnc('', id)
    end
  },
  ['root-cert']={
    generate=function()
      scan_finished()
      -- start from scratch: any existing database is discarded
      os.remove(config['db-file'])
      for _, statement in ipairs(schema) do execute(statement) end

      -- hubs belong to the pseudo-site with the empty code
      add_site{code=''}
      for _, subnet in ipairs(config.hub.subnets) do
        add_subnet('', subnet)
      end

      issue_cert{dn=config.ca.dn, serial=0, params=config.ca}
    end,
    show=function() print_cert(get_cert_by_serial(0)) end
  },
  cert={
    generate=function()
      local vpnc = vpnc_filter(
        scan_vpnc{multiple=true}, 's.code', 'v.id'
      )
      local filter = vpnc or {}
      filter['s.active'] = '1'
      filter['v.active'] = '1'

      local subjects = {}
      local dns = {}

      for row in select_many(
        's.code, v.id, s.asn, s.name, v.name',
        'site s INNER JOIN vpnc v ON s.code = v.site',
        filter,
        'n'
      ) do
        local attrs = {
          site=row[1],
          vpnc=row[2],
          asn=row[3],
          sname=row[4],
          vname=row[5]
        }

        -- queue this VPNc for issuance, expanding $ROOT, $SITE and $NAME
        -- in the configured DN template
        local function add_subject()
          attrs.params = config[attrs.site == '' and 'hub' or 'spoke']
          local default_name = (
            attrs.params['default-name']:gsub('$ID', attrs.vpnc)
          )
          attrs.dn = attrs.params.dn:gsub(
            '%$(%u+)',
            {
              ROOT=config.ca.dn,
              SITE=attrs.sname or attrs.site,
              NAME=attrs.vname or default_name
            }
          )
          table.insert(subjects, attrs)
          table.insert(dns, {tostring(attrs.dn)})
        end

        if vpnc then
          add_subject()
        else
          -- without a selector, only VPNcs lacking a valid certificate
          -- get a new one
          local valid
          for cert in select_certs{site=row[1], vpnc=row[2]} do
            if is_valid(cert) then valid = true end
          end
          if not valid then add_subject() end
        end
      end

      add_header(dns, {'dn'})
      confirm('certificates', 'issued', dns)

      local issued = {}
      for _, attrs in ipairs(subjects) do
        attrs.serial = next_key('certificate', 'serial')
        -- asn is not a column of the certificate table
        local asn = attrs.asn
        attrs.asn = nil

        local cert = issue_cert(
          attrs,
          function(cert, attrs)
            local function get_subnets()
              return select_many(
                'family, address',
                'subnet',
                {site=attrs.site},
                'a'
              )
            end
            local function get_gre_addrs()
              return select_many(
                'family, address',
                'greAddress',
                {site=attrs.site, vpnc=attrs.vpnc},
                'a'
              )
            end

            -- private extension: boolean telling whether this is a hub
            cert:addExtension(
              x509ext.new(
                '1.3.6.1.4.1.31536.1.1',
                'critical,DER',
                asn1.boolean.encode(attrs.site == '')
              )
            )

            -- group the site's subnets by address family
            local net_config = {}
            local pr_config = {}
            for subnet in get_subnets() do
              local f = subnet.family
              if not pr_config[f] then
                pr_config[f] = {}
                table.insert(
                  net_config,
                  {
                    addressFamily={afi=f},
                    ipAddressChoice={
                      addressesOrRanges=pr_config[f]
                    }
                  }
                )
              end
              table.insert(
                pr_config[f], {addressPrefix=subnet.address}
              )
            end
            if net_config[1] then
              cert:addExtension(
                x509ext.new(
                  'sbgp-ipAddrBlock',
                  'critical,DER',
                  rfc3779.IPAddrBlocks.encode(net_config)
                )
              )
            end

            local san
            for ga in get_gre_addrs() do
              if not san then san = x509an.new() end
              san:add('IP', ga.address)
            end
            if san then cert:setSubjectAlt(san) end

            cert:addExtension(
              x509ext.new(
                'sbgp-autonomousSysNum',
                'critical,DER',
                rfc3779.ASIdentifiers.encode{
                  asnum={asIdsOrRanges={{id=asn}}}
                }
              )
            )
          end
        )

        export_cert(cert)
        table.insert(issued, cert)
      end

      return format_cert_info(issued)
    end,
    list=function()
      local certs = {}
      for cert in get_certs() do table.insert(certs, cert) end
      return format_cert_info(certs)
    end,
    show=function()
      for cert in get_certs() do print_cert(cert) end
    end,
    revoke=function()
      local certs = {}
      for cert, selector in get_certs() do
        local valid = is_valid(cert)
        -- an explicitly selected certificate must still be revocable
        if selector == 'serial' and not valid then
          error('Certificate already expired or revoked')
        end
        if valid then table.insert(certs, cert) end
      end

      confirm('certificates', 'revoked', format_cert_info(certs))

      local revoked = {}
      for _, cert in ipairs(certs) do
        table.insert(revoked, revoke{serial=cert.serial}[1])
      end
      return format_cert_info(revoked)
    end,
    export=function()
      local _, serial = scan_param('serial', 'selector')
      local cert = get_cert_by_serial(serial)
      if cert.serial == 0 then
        error('Cannot export root certificate')
      end
      export_cert(cert)
    end
  },
  crl={
    show=function()
      scan_finished()
      io.write(generate_crl():text())
    end,
    export=function()
      scan_finished()
      io.write(tostring(generate_crl()))
    end
  }
}

-- Usage: <object type> <action> [parameters...]
local output = scan_choice(
  scan_choice(commands, 'object type'), 'action'
)()
if output then print_table(output) end

-- Commit only after the whole command has succeeded.
if sql then
  conn:commit()
  conn:close()
  sql:close()
end