From 571eaa0daeb6884936fbf126e44d763ce664474f Mon Sep 17 00:00:00 2001 From: Kaarle Ritvanen Date: Fri, 15 Nov 2013 23:11:24 +0200 Subject: transaction: improve validation robustness validate parents before their subordinates validate all non-deleted TreeNodes, stored or not --- acf2/transaction/init.lua | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) (limited to 'acf2/transaction') diff --git a/acf2/transaction/init.lua b/acf2/transaction/init.lua index 8f11495..395059a 100644 --- a/acf2/transaction/init.lua +++ b/acf2/transaction/init.lua @@ -5,7 +5,10 @@ See LICENSE file for license details local ErrorDict = require('acf2.error').ErrorDict local root = require('acf2.model.root') + local object = require('acf2.object') +local super = object.super + local pth = require('acf2.path') local be_mod = require('acf2.transaction.backend') @@ -28,7 +31,7 @@ end local Transaction = object.class(be_mod.TransactionBackend) function Transaction:init(backend, validate) - object.super(self, Transaction):init() + super(self, Transaction):init() self.backend = backend @@ -41,6 +44,7 @@ function Transaction:init(backend, validate) self.validate = validate self.validable = {} + self.commit_val = {} self.root = root.RootModel(self) end @@ -64,6 +68,20 @@ function Transaction:get(path) return value, timestamp end +function Transaction:set_multiple(mods) + super(self, Transaction):set_multiple(mods) + for _, mod in ipairs(mods) do + local addr, value = unpack(mod) + if value == nil then + for _, val in ipairs{self.validable, self.commit_val} do + for path, a in pairs(val) do + if a == addr then val[path] = nil end + end + end + end + end +end + function Transaction:_set_multiple(mods) local function set(path, value, new) @@ -140,12 +158,17 @@ function Transaction:commit() self:check() if self.validate then + self.commit_val = copy(self.validable) local errors = ErrorDict() - for path, addr in pairs(copy(self.validable)) do - if self:get(addr) ~= nil then - errors:collect(getmetatable(self:fetch(path)).validate) - end + + local function validate(path) + if path > '/' then validate(pth.parent(path)) end + if not self.commit_val[path] then return end + errors:collect(getmetatable(self:fetch(path)).validate) + self.commit_val[path] = nil end + + while next(self.commit_val) do validate(next(self.commit_val)) end errors:raise() end -- cgit v1.2.3