--[[
Copyright (c) 2012-2013 Kaarle Ritvanen
See LICENSE file for license details
--]]

local ErrorDict = require('acf.error').ErrorDict
local root = require('acf.model.root')
local object = require('acf.object')
local pth = require('acf.path')

local be_mod = require('acf.transaction.backend')

local util = require('acf.util')
local copy = util.copy

-- Lua 5.1/LuaJIT expose unpack as a global, later versions as table.unpack.
local unpack = table.unpack or unpack


--- Remove the first element of `list` whose string form equals that of
-- `value`. The comparison goes through tostring() on both sides, so a
-- numeric 1 matches the string '1'.
-- Raises an error if no element matches.
local function remove_list_value(list, value)
    value = tostring(value)

    for i, v in ipairs(list) do
        if tostring(v) == value then
            table.remove(list, i)
            return
        end
    end

    assert(false, 'Value not found in list: '..value)
end


--- Transaction layered on top of another backend. Changes are buffered in
-- the `added`, `modified` and `deleted` tables (keyed by path) until
-- commit() hands them over to the underlying backend.
local Transaction = object.class(be_mod.TransactionBackend)

--- Initialize the transaction.
-- @param backend object the changes are eventually committed to
-- @param validate when true, commit() runs the validators registered in
-- `self.validable`; otherwise they are passed on to the backend
-- NOTE(review): self.mod_time is used below but not created here;
-- presumably the superclass initializer sets it up - confirm.
function Transaction:init(backend, validate)
    object.super(self, Transaction):init()

    self.backend = backend

    self.started = be_mod.gen_number()
    -- path -> timestamp observed when the path was read from the backend
    self.access_time = {}

    self.added = {}
    self.modified = {}
    self.deleted = {}

    self.validate = validate
    self.validable = {}

    self.root = root.RootModel(self)
end

--- Raise an error if the transaction has already been committed
-- (commit() clears `self.backend`).
function Transaction:check()
    if not self.backend then error('Transaction already committed') end
end

--- Look up the value stored at `path` as seen by this transaction.
-- Paths deleted in this transaction yield nil; paths added or modified in
-- this transaction yield a copy of the buffered value. Anything else is
-- fetched from the backend with get_if_older(), and the returned timestamp
-- is recorded in `self.access_time`.
-- @return value, timestamp
function Transaction:get(path)
    self:check()

    if self.deleted[path] then return nil, self.mod_time[path] end
    for _, tbl in ipairs{self.added, self.modified} do
        if tbl[path] ~= nil then
            return copy(tbl[path]), self.mod_time[path]
        end
    end

    local value, timestamp = self.backend:get_if_older(path, self.started)
    self.access_time[path] = timestamp
    return value, timestamp
end

--- Apply a list of modifications to the transaction's buffers.
-- @param mods list of {path, value} pairs; a nil value deletes the path
-- and a table value creates an (initially empty) non-leaf node
-- Raises an error when a primitive value is assigned to a non-leaf node
-- that still has children, or when reference checks on deletion fail.
-- NOTE(review): self:set() is not defined in this file; presumably it is
-- inherited from TransactionBackend - confirm.
function Transaction:_set_multiple(mods)

    -- Record `value` for `path` in the appropriate buffer. `new` tells
    -- whether the path was absent before this modification.
    local function set(path, value, new)
        local delete = value == nil

        if self.added[path] == nil and (not new or self.deleted[path]) then
            self.modified[path] = value
            self.deleted[path] = delete
        else self.added[path] = value end
    end

    -- Handle a single modification. Kept as a separate function so that
    -- an early return skips only this modification instead of abandoning
    -- the remaining entries of `mods`.
    local function apply(path, value)
        local ppath = pth.parent(path)
        local parent = self:get(ppath)
        if parent == nil then
            -- create the missing parent node first
            parent = {}
            self:set(ppath, parent)
        end

        local name = pth.name(path)
        local old = self:get(path)

        local is_table = type(value) == 'table'
        local delete = value == nil

        if delete then
            -- assume one-level refs for now
            local top = root.topology(ppath)
            if top then
                local errors = ErrorDict()
                for _, refs in ipairs(top.referrers) do
                    for _, ref in ipairs(self.root:search_refs(refs)) do
                        errors:collect(ref.deleted, ref, path)
                    end
                end
                errors:raise()
            end
        end

        if type(old) == 'table' then
            if delete then
                -- delete the children before the node itself
                for _, child in ipairs(old) do
                    self:set(pth.join(path, child))
                end
            elseif is_table then
                -- node is already a table: nothing to do for this mod
                return
            elseif #old > 0 then
                error('Cannot assign a primitive value to non-leaf node '..path)
            end
        end

        -- table nodes are stored as a fresh list, not the caller's table
        if is_table then value = {} end
        set(path, value, old == nil)

        local function set_parent()
            set(ppath, parent)
            self.mod_time[ppath] = self.mod_time[path]
        end

        -- keep the parent's list of child names in sync
        if old == nil and not delete then
            table.insert(parent, name)
            set_parent()
        elseif old ~= nil and delete then
            remove_list_value(parent, name)
            set_parent()
        end
    end

    for _, mod in ipairs(mods) do
        local path, value = unpack(mod)
        apply(path, value)
    end
end

--- Delegate to the root model's fetch().
function Transaction:fetch(path) return self.root:fetch(path) end

--- Delegate to the root model's meta().
function Transaction:meta(path) return self.root:meta(path) end

--- Commit the buffered changes to the backend.
-- When validation is enabled, every entry of `self.validable` whose
-- address still resolves to a value is validated first and collected
-- errors are raised. The changes are then ordered (deletions children
-- first, additions parents first) and passed to the backend's
-- comp_and_setm() together with the recorded access times. Afterwards the
-- transaction is marked as committed by clearing `self.backend`.
function Transaction:commit()
    self:check()

    if self.validate then
        local errors = ErrorDict()
        -- iterate over a copy of the validable table
        for path, addr in pairs(copy(self.validable)) do
            if self:get(addr) ~= nil then
                errors:collect(getmetatable(self:fetch(path)).validate)
            end
        end
        errors:raise()
    end

    local mods = {}
    local handled = {}

    -- Append a modification; each path may be inserted only once.
    local function insert(path, value)
        assert(not handled[path])
        table.insert(mods, {path, value})
        handled[path] = true
    end

    -- Insert an addition, preceded by its added ancestors.
    local function insert_add(path)
        if not handled[path] then
            local pp = pth.parent(path)
            if self.added[pp] then insert_add(pp) end
            insert(path, self.added[path])
        end
    end

    -- Insert a deletion, preceded by the deletions of the children the
    -- backend currently lists for the path.
    local function insert_del(path)
        if not handled[path] then
            local value = self.backend:get(path)
            if type(value) == 'table' then
                for _, child in ipairs(value) do
                    local cp = pth.join(path, child)
                    assert(self.deleted[cp])
                    insert_del(cp)
                end
            end
            insert(path)
        end
    end

    for path, deleted in pairs(self.deleted) do
        if deleted then insert_del(path) end
    end

    -- modified table values are not inserted here; only primitives are
    for path, value in pairs(self.modified) do
        if type(value) ~= 'table' then insert(path, value) end
    end

    for path, _ in pairs(self.added) do insert_add(path) end

    self.backend:comp_and_setm(self.access_time, mods)

    -- validation deferred: hand the pending validators over to the backend
    if not self.validate then
        util.update(self.backend.validable, self.validable)
    end

    self.backend = nil
end


local store = require('acf.persistence')

--- Create a transaction.
-- @param txn optional backend to build on; when omitted the
-- 'acf.persistence' module is used as the backend
-- @param defer_validation when truthy and `txn` is given, the new
-- transaction does not validate on commit
return function(txn, defer_validation)
    return Transaction(txn or store, not (txn and defer_validation))
end