summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTed Trask <ttrask01@yahoo.com>2013-09-24 02:15:57 +0000
committerTed Trask <ttrask01@yahoo.com>2013-09-24 02:15:57 +0000
commite1f4c7588e1eb6253c3c1b0ee82205426fde19ec (patch)
tree9dd415cec1b636469055f8f1d46a0891b22c8d37
parent035eda4e6ab7e2f794fca8e51a61273f0a5f9171 (diff)
downloadacf-lib-e1f4c7588e1eb6253c3c1b0ee82205426fde19ec.tar.bz2
acf-lib-e1f4c7588e1eb6253c3c1b0ee82205426fde19ec.tar.xz
First cut at new db.lua library to simplify database access
-rw-r--r--db.lua154
1 files changed, 154 insertions, 0 deletions
diff --git a/db.lua b/db.lua
new file mode 100644
index 0000000..049d423
--- /dev/null
+++ b/db.lua
@@ -0,0 +1,154 @@
+module(..., package.seeall)
+
+-- ################################################################################
+-- PRIVATE DATABASE FUNCTIONS
+
+local function assert (v, m)
+ if not v then
+ m = m or "Assertion failed!"
+ error(m, 0)
+ end
+ return v, m
+end
+
+-- Escape special characters in sql statements
+local escape = function(dbobject, sql)
+ sql = sql or ""
+ return dbobject.con:escape(sql)
+end
+
+local databaseconnect = function(dbobject)
+ if not dbobject.con then
+ -- create environment object
+ if dbobject.engine == engine.postgresql then
+ require("luasql.postgres")
+ dbobject.env = assert (luasql.postgres())
+ else
+ error("Unknown database engine "..tostring(dbobject.engine))
+ end
+
+ -- connect to data source
+ dbobject.con = assert(dbobject.env:connect(dbobject.name, dbobject.user, dbobject.password, dbobject.host, dbobject.port))
+ return true
+ end
+ return false
+end
+
+local databasedisconnect = function(dbobject)
+ if dbobject.env then
+ dbobject.env:close()
+ dbobject.env = nil
+ end
+ if dbobject.con then
+ dbobject.con:close()
+ dbobject.con = nil
+ end
+end
+
+local runscript = function(dbobject, script, transaction)
+ for i,scr in ipairs(script) do
+ dbobject.runsqlcommand(scr, transaction)
+ end
+end
+
+local runsqlcommand = function(dbobject, sql, transaction)
+ if transaction then assert(dbobject.con:execute("SAVEPOINT before_command")) end
+ local res, err = dbobject.con:execute(sql)
+ if not res and err then
+ -- Catch the error to see if it's caused by lack of table
+ local table = string.match(err, "relation \"(%S+)\" does not exist")
+ if table and dbobject.table_creation_scripts[table] then
+ if transaction then assert(dbobject.con:execute("ROLLBACK TO before_command")) end
+ dbobject.runscript(dbobject.table_creation_scripts[table])
+ dbobject.runsqlcommand(sql)
+ else
+ assert(res, err)
+ end
+ else
+ if transaction then
+ assert(dbobject.con:execute("RELEASE SAVEPOINT before_command"))
+ end
+ if type(res) == userdata then
+ res:close()
+ end
+ end
+end
+
+local getselectresponse = function(dbobject, sql, transaction)
+ local retval = {}
+ if transaction then assert(dbobject.con:execute("SAVEPOINT before_select")) end
+ local res, err = pcall(function()
+ local cur = assert (dbobject.con:execute(sql))
+ local row = cur:fetch ({}, "a")
+ while row do
+ local tmp = {}
+ for name,val in pairs(row) do
+ tmp[name] = val
+ end
+ retval[#retval + 1] = tmp
+ row = cur:fetch (row, "a")
+ end
+ cur:close()
+ end)
+ if not res and err then
+ -- Catch the error to see if it's caused by lack of table
+ local table = string.match(err, "relation \"(%S+)\" does not exist")
+ if table and dbobject.table_creation_scripts[table] then
+ if transaction then assert(con:execute("ROLLBACK TO before_select")) end
+ dbobject.runscript(dbobject.table_creation_scripts[table])
+ return dbobject.getselectresponse(sql)
+ else
+ assert(res, err)
+ end
+ elseif transaction then
+ assert(dbobject.con:execute("RELEASE SAVEPOINT before_select"))
+ end
+ return retval
+end
+
+local listtables = function(dbobject)
+ local result = {}
+ if dbobject.engine == engine.postgresql then
+ local tab = dbobject.getselectresponse("SELECT tablename FROM pg_tables WHERE tablename !~* 'pg_*' ORDER BY tablename ASC")
+ for i,t in ipairs(tab) do
+ result[#result+1] = t.tablename
+ end
+ else
+ -- untested
+ result = dbobject.con:tables()
+ end
+ return result
+end
+
+local listcolumns = function(dbobject, table)
+ local result = {}
+ if dbobject.engine == engine.postgresql then
+ local col = dbobject.getselectresponse("SELECT a.attname AS field FROM pg_class c, pg_attribute a, pg_type t WHERE c.relname = '"..escape(dbobject, table).."' AND a.attnum > 0 AND a.attrelid = c.oid AND a.atttypid = t.oid ORDER BY a.attnum")
+ for i,c in ipairs(col) do
+ result[#result+1] = c.field
+ end
+ end
+ return result
+end
+
+-- ################################################################################
+-- PUBLIC FUNCTIONS / DEFINITIONS
+
+engine = {
+["postgresql"] = 1,
+}
+
+create = function(engine, name, user, password, host, port)
+ local dbobject = {engine=engine, name=name, user=user, password=password, host=host, port=port}
+ dbobject.escape = function(...) return escape(dbobject, ...) end
+ dbobject.databaseconnect = function(...) return databaseconnect(dbobject, ...) end
+ dbobject.databasedisconnect = function(...) return databasedisconnect(dbobject, ...) end
+ dbobject.runscript = function(...) return runscript(dbobject, ...) end
+ dbobject.runsqlcommand = function(...) return runsqlcommand(dbobject, ...) end
+ dbobject.getselectresponse = function(...) return getselectresponse(dbobject, ...) end
+ dbobject.listtables = function(...) return listtables(dbobject, ...) end
+ dbobject.listcolumns = function(...) return listcolumns(dbobject, ...) end
+ dbobject.isconnected = function() return dbobject.con ~= nil end
+ dbobject.table_creation_scripts = {}
+ return dbobject
+end