aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--awall/init.lua59
-rw-r--r--awall/policy.lua109
2 files changed, 111 insertions, 57 deletions
diff --git a/awall/init.lua b/awall/init.lua
index 1d3e47a..a7f80af 100644
--- a/awall/init.lua
+++ b/awall/init.lua
@@ -6,7 +6,6 @@ Licensed under the terms of GPL2
module(..., package.seeall)
-require 'json'
require 'lfs'
require 'stringy'
@@ -14,6 +13,7 @@ require 'awall.ipset'
require 'awall.iptables'
require 'awall.model'
require 'awall.object'
+require 'awall.policy'
require 'awall.util'
@@ -55,64 +55,9 @@ Config = object.class(object.Object)
function Config:init(confdirs, importdirs)
- self.input = {}
+ self.input = policy.PolicySet.new(confdirs, importdirs):load()
self.iptables = iptables.IPTables.new()
- local required = {}
- local imported = {}
-
- local function import(name, fname)
- local file
- if fname then
- file = io.open(fname)
- else
- for i, dir in ipairs(importdirs or {'/usr/share/awall/optional'}) do
- file = io.open(dir..'/'..name..'.json')
- if file then break end
- end
- end
- if not file then error('Import failed: '..name) end
-
- local data = ''
- for line in file:lines() do data = data..line end
- file:close()
- data = json.decode(data)
-
- table.insert(required, name)
- for i, iname in util.listpairs(data.import) do
- if not util.contains(imported, iname) then
- if util.contains(required, iname) then
- error('Circular import: ' + iname)
- end
- import(iname)
- end
- end
- table.insert(imported, name)
-
- for cls, objs in pairs(data) do
- if cls ~= 'import' then
- if not self.input[cls] then self.input[cls] = objs
- elseif objs[1] then util.extend(self.input[cls], objs)
- else
- for k, v in pairs(objs) do self.input[cls][k] = v end
- end
- end
- end
- end
-
- for i, dir in ipairs(confdirs or
- {'/usr/share/awall/mandatory', '/etc/awall'}) do
- local names = {}
- for fname in lfs.dir(dir) do
- local si, ei, name = string.find(fname, '^([%w-]+)%.json$')
- if name then table.insert(names, name) end
- end
- table.sort(names)
-
- for i, name in ipairs(names) do import(name, dir..'/'..name..'.json') end
- end
-
-
local function expandvars(obj)
for k, v in pairs(obj) do
if type(v) == 'table' then
diff --git a/awall/policy.lua b/awall/policy.lua
new file mode 100644
index 0000000..9e45a21
--- /dev/null
+++ b/awall/policy.lua
@@ -0,0 +1,109 @@
+--[[
+Policy file handling for Alpine Wall
+Copyright (C) 2012 Kaarle Ritvanen
+Licensed under the terms of GPL2
+]]--
+
+module(..., package.seeall)
+
+require 'json'
+require 'lfs'
+require 'lpc'
+
+require 'awall.object'
+require 'awall.util'
+
+local util = awall.util
+
+
+local function open(name, dirs)
+ if not string.match(name, '^[%w-]+$') then
+ error('Invalid characters in policy name: '..name)
+ end
+ for i, dir in ipairs(dirs) do
+ local path = dir..'/'..name..'.json'
+ file = io.open(path)
+ if file then return file, path end
+ end
+end
+
+local function list(dirs)
+ local allnames = {}
+ local res = {}
+
+ for i, dir in ipairs(dirs) do
+ local names = {}
+ local paths = {}
+
+ for fname in lfs.dir(dir) do
+ local si, ei, name = string.find(fname, '^([%w-]+)%.json$')
+ if name then
+ if util.contains(allnames, name) then
+ error('Duplicate policy name: '..name)
+ end
+ table.insert(allnames, name)
+
+ table.insert(names, name)
+ paths[name] = dir..'/'..fname
+ end
+ end
+
+ table.sort(names)
+ for i, name in ipairs(names) do
+ table.insert(res, {name, paths[name]})
+ end
+ end
+
+ return res
+end
+
+
+PolicySet = awall.object.class(awall.object.Object)
+
+function PolicySet:init(confdirs, importdirs)
+ self.autodirs = confdirs or {'/usr/share/awall/mandatory', '/etc/awall'}
+ self.confdir = self.autodirs[#self.autodirs]
+ self.importdirs = importdirs or {'/usr/share/awall/optional'}
+end
+
+
+function PolicySet:load()
+
+ local input = {}
+ local required = {}
+ local imported = {}
+
+ local function import(name, fname)
+
+ if util.contains(imported, name) then return end
+ if util.contains(required, name) then
+ error('Circular import: '..name)
+ end
+
+ local file = fname and io.open(fname) or open(name, self.importdirs)
+ if not file then error('Import failed: '..name) end
+
+ local data = ''
+ for line in file:lines() do data = data..line end
+ file:close()
+ data = json.decode(data)
+
+ table.insert(required, name)
+ for i, iname in util.listpairs(data.import) do import(iname) end
+ table.insert(imported, name)
+
+ for cls, objs in pairs(data) do
+ if cls ~= 'import' then
+ if not input[cls] then input[cls] = objs
+ elseif objs[1] then util.extend(input[cls], objs)
+ else
+ for k, v in pairs(objs) do input[cls][k] = v end
+ end
+ end
+ end
+ end
+
+ for i, pol in ipairs(list(self.autodirs)) do import(unpack(pol)) end
+
+ return input, imported
+end