--[[
Copyright (c) 2012-2015 Kaarle Ritvanen
See LICENSE file for license details
--]]

local M = {}

local err = require('aconf.error')

local object = require('aconf.object')
local class = object.class

local address = require('aconf.path.address')
local util = require('aconf.util')

-- TODO each transaction backend (i.e. persistence manager or
-- transaction proper) should be implemented as a thread or have its
-- internal state stored in shared storage (with appropriate locking)


-- Module-wide counter. Its values are used both as modification
-- timestamps and as transaction start markers, so the two can be
-- compared with each other.
local generation = 0

--- Increment the generation counter and return its new value.
local function gen_number()
   generation = generation + 1
   return generation
end


--- Base class for transaction backends.
-- This class calls self:get(path), which must return a value and its
-- timestamp, and self:_set_multiple(mods); neither is defined here, so
-- subclasses have to supply them.
M.TransactionBackend = class()

--- Initialize the per-path modification timestamp table.
function M.TransactionBackend:init() self.mod_time = {} end

--- Fetch a value, failing if it was modified after the given timestamp.
-- @param path path of the value to fetch
-- @param timestamp generation number to compare against
-- @return the value and its timestamp as returned by self:get
-- Raises a 'conflict' error through err.raise when the timestamp of
-- the value is greater than the given one.
function M.TransactionBackend:get_if_older(path, timestamp)
   local value, ts = self:get(path)
   if ts > timestamp then err.raise('conflict', path) end
   return value, ts
end

--- Set a single value; a nil value is passed on as an absent value.
-- @param path path of the value
-- @param value new value
function M.TransactionBackend:set(path, value)
   self:set_multiple{{path, value}}
end

--- Apply a batch of modifications.
-- All effective modifications of one batch share a freshly generated
-- timestamp. A modification is effective when the new value is a
-- table, the current value is a table, or the two values differ.
-- Ineffective modifications are dropped before calling _set_multiple.
-- @param mods list of {path, value} pairs
function M.TransactionBackend:set_multiple(mods)
   -- TODO delegate to PM backends?
   local timestamp = gen_number()
   local effective = {}

   for _, mod in ipairs(mods) do
      local path, value = table.unpack(mod)

      -- The current value is looked up only when the new value is not
      -- a table, and then only once.
      local changed = type(value) == 'table'
      if not changed then
         local current = self:get(path)
         changed = type(current) == 'table' or current ~= value
      end

      if changed then
         table.insert(effective, mod)
         self.mod_time[path] = timestamp
      end
   end

   self:_set_multiple(effective)
end

-- TODO should be atomic, mutex with set_multiple
--- Check the accessed paths for conflicts, then apply modifications.
-- get_if_older is invoked for every accessed path through an
-- err.ErrorDict collector, whose raise method is called before any
-- modification is applied (presumably it raises the collected
-- failures, if any - confirm against aconf.error).
-- @param accessed table mapping paths to the timestamps seen on access
-- @param mods list of {path, value} pairs
function M.TransactionBackend:comp_and_setm(accessed, mods)
   local errors = err.ErrorDict()
   for path, timestamp in pairs(accessed) do
      errors:collect(self.get_if_older, self, path, timestamp)
   end
   errors:raise()

   self:set_multiple(mods)
end


--- Remove the first list element whose string form equals that of value.
local function remove_list_value(list, value)
   value = tostring(value)

   for i, v in ipairs(list) do
      if tostring(v) == value then
         table.remove(list, i)
         return
      end
   end
end


--- Transaction recording its changes on top of another backend.
M.Transaction = class(M.TransactionBackend)

--- Create a transaction on top of the given backend.
-- @param backend object providing get_if_older, get and comp_and_setm
function M.Transaction:init(backend)
   object.super(self, M.Transaction):init()
   self.backend = backend
   self:reset()
end

--- Discard all recorded state and take a new start marker.
function M.Transaction:reset()
   self.started = gen_number()
   self.access_time = {}

   self.added = {}
   self.modified = {}
   self.deleted = {}
end

--- Fetch a value as seen by this transaction.
-- Paths deleted in this transaction yield nil. Paths added or modified
-- in this transaction yield a copy of the recorded value. Other paths
-- are read from the backend with get_if_older against the start marker
-- of the transaction, and the returned timestamp is remembered in
-- access_time.
-- @param path path of the value
-- @return value and timestamp
function M.Transaction:get(path)
   if self.deleted[path] then return nil, self.mod_time[path] end
   for _, tbl in ipairs{self.added, self.modified} do
      if tbl[path] ~= nil then
         return util.copy(tbl[path]), self.mod_time[path]
      end
   end

   local value, timestamp = self.backend:get_if_older(path, self.started)
   self.access_time[path] = timestamp
   return value, timestamp
end

--- Expand wildcard components of a path.
-- At the first wildcard component, the children of the path built so
-- far are fetched and sorted, and the expansion recurses once per
-- child with the remaining components appended.
-- @param path path possibly containing address.wildcard components
-- @return list of paths; a single-element list when the path has no
-- wildcard component
function M.Transaction:expand(path)
   local prefix = {}
   path = address.split(path)

   while path[1] do
      local comp = path[1]
      table.remove(path, 1)

      if comp == address.wildcard then
         local p = address.join('/', table.unpack(prefix))
         local res = {}

         -- note: sorts the table returned by get in place
         local children = self:get(p) or {}
         table.sort(children)
         for _, child in ipairs(children) do
            util.extend(
               res,
               self:expand(address.join(p, child, table.unpack(path)))
            )
         end

         return res
      end

      table.insert(prefix, comp)
   end

   return {address.join('/', table.unpack(prefix))}
end

--- Record a batch of modifications in the transaction.
-- Every modification is processed independently of the others.
-- Assigning a table to a node which already holds a table is a no-op
-- for that modification only; the rest of the batch is still applied.
-- @param mods list of {path, value} pairs; an absent value deletes
-- Raises an error when a non-table value is assigned to a node whose
-- current value is a non-empty table.
function M.Transaction:_set_multiple(mods)

   -- Store the value in the appropriate bookkeeping table. The entry
   -- goes to added when the path is already there, or when it is new
   -- and not marked deleted; otherwise it goes to modified, with the
   -- deleted flag updated to tell whether the value is nil.
   local function record(path, value, new)
      local delete = value == nil

      if self.added[path] == nil and (not new or self.deleted[path]) then
         self.modified[path] = value
         self.deleted[path] = delete
      else self.added[path] = value end
   end

   -- Process one modification. Returning from here skips only this
   -- modification, not the whole batch.
   local function apply(path, value)
      local ppath = address.parent(path)
      local parent = self:get(ppath)
      if parent == nil then
         -- create the missing parent node first
         parent = {}
         self:set(ppath, parent)
      end

      local name = address.name(path)
      local old = self:get(path)

      local is_table = type(value) == 'table'
      local delete = value == nil

      if delete then self:check_deleted(path) end

      if type(old) == 'table' then
         if delete then
            -- delete the children before the node itself
            for _, child in ipairs(old) do
               self:set(address.join(path, child))
            end
         elseif is_table then
            -- the node is a table already, nothing to do
            return
         elseif #old > 0 then
            error('Cannot assign a primitive value to non-leaf node '..path)
         end
      end

      -- a table node always starts out with an empty child list
      if is_table then value = {} end
      record(path, value, old == nil)

      local function set_parent()
         record(ppath, parent)
         self.mod_time[ppath] = self.mod_time[path]
      end

      -- keep the child list of the parent in sync
      if old == nil and not delete then
         table.insert(parent, name)
         set_parent()
      elseif old ~= nil and delete then
         remove_list_value(parent, name)
         set_parent()
      end
   end

   for _, mod in ipairs(mods) do
      apply(table.unpack(mod))
   end
end

--- Hook called with the path before a deletion is recorded.
-- Does nothing in this class.
function M.Transaction:check_deleted(path) end

--- Commit the recorded changes to the backend and reset the transaction.
-- The modification list is built in three passes: deletions (children
-- before their parent), modified non-table values, and additions
-- (added parents before their children). It is handed to the
-- comp_and_setm method of the backend together with the recorded
-- access timestamps.
function M.Transaction:commit()
   local mods = {}
   local handled = {}

   -- append a modification, each path at most once
   local function insert(path, value)
      assert(not handled[path])
      table.insert(mods, {path, value})
      handled[path] = true
   end

   local function insert_add(path)
      if not handled[path] then
         local pp = address.parent(path)
         if self.added[pp] then insert_add(pp) end
         insert(path, self.added[path])
      end
   end

   local function insert_del(path)
      if not handled[path] then
         local value = self.backend:get(path)
         if type(value) == 'table' then
            for _, child in ipairs(value) do
               local cp = address.join(path, child)
               -- every child known to the backend must be deleted too
               assert(self.deleted[cp])
               insert_del(cp)
            end
         end
         insert(path)
      end
   end

   for path, deleted in pairs(self.deleted) do
      if deleted then insert_del(path) end
   end

   for path, value in pairs(self.modified) do
      if type(value) ~= 'table' then insert(path, value) end
   end

   for path, _ in pairs(self.added) do insert_add(path) end

   self.backend:comp_and_setm(self.access_time, mods)

   self:reset()
end


return M