summaryrefslogtreecommitdiff
path: root/src/lua/lua_database.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/lua/lua_database.cpp')
-rw-r--r--src/lua/lua_database.cpp72
1 files changed, 45 insertions, 27 deletions
diff --git a/src/lua/lua_database.cpp b/src/lua/lua_database.cpp
index 4a7c82a5..d41f4794 100644
--- a/src/lua/lua_database.cpp
+++ b/src/lua/lua_database.cpp
@@ -32,7 +32,7 @@ namespace lua {
static constexpr char kDbIndexMetatable[] = "db_index";
static constexpr char kDbRecordMetatable[] = "db_record";
-static constexpr char kDbIteratorMetatable[] = "db_record";
+static constexpr char kDbIteratorMetatable[] = "db_iterator";
static auto indexes(lua_State* state) -> int {
Bridge* instance = Bridge::Get(state);
@@ -94,13 +94,37 @@ static auto push_lua_record(lua_State* L, const database::IndexRecord& r)
std::memcpy(record->text, text.data(), text.size());
}
+auto db_check_iterator(lua_State* L, int stack_pos) -> database::Iterator* {
+ database::Iterator* it = *reinterpret_cast<database::Iterator**>(
+ luaL_checkudata(L, stack_pos, kDbIteratorMetatable));
+ return it;
+}
+
+static auto push_iterator(lua_State* state,
+ std::variant<database::Iterator*,
+ database::Continuation,
+ database::IndexInfo> val) -> void {
+ Bridge* instance = Bridge::Get(state);
+ database::Iterator** data = reinterpret_cast<database::Iterator**>(
+ lua_newuserdata(state, sizeof(uintptr_t)));
+ std::visit(
+ [&](auto&& arg) {
+ using T = std::decay_t<decltype(arg)>;
+ if constexpr (std::is_same_v<T, database::Iterator*>) {
+ *data = new database::Iterator(*arg);
+ } else {
+ *data = new database::Iterator(instance->services().database(), arg);
+ }
+ },
+ val);
+ luaL_setmetatable(state, kDbIteratorMetatable);
+}
+
static auto db_iterate(lua_State* state) -> int {
- luaL_checktype(state, 1, LUA_TFUNCTION);
+ database::Iterator* it = db_check_iterator(state, 1);
+ luaL_checktype(state, 2, LUA_TFUNCTION);
int callback_ref = luaL_ref(state, LUA_REGISTRYINDEX);
- database::Iterator* it = *reinterpret_cast<database::Iterator**>(
- lua_touserdata(state, lua_upvalueindex(1)));
-
it->Next([=](std::optional<database::IndexRecord> res) {
events::Ui().RunOnTask([=]() {
lua_rawgeti(state, LUA_REGISTRYINDEX, callback_ref);
@@ -116,29 +140,22 @@ static auto db_iterate(lua_State* state) -> int {
return 0;
}
+static auto db_iterator_clone(lua_State* state) -> int {
+ database::Iterator* it = db_check_iterator(state, 1);
+ push_iterator(state, it);
+ return 1;
+}
+
static auto db_iterator_gc(lua_State* state) -> int {
- database::Iterator** it = reinterpret_cast<database::Iterator**>(
- luaL_checkudata(state, 1, kDbIteratorMetatable));
- if (it != NULL) {
- delete *it;
- }
+ database::Iterator* it = db_check_iterator(state, 1);
+ delete it;
return 0;
}
-static auto push_iterator(
- lua_State* state,
- std::variant<database::Continuation, database::IndexInfo> val) -> void {
- Bridge* instance = Bridge::Get(state);
- database::Iterator** data = reinterpret_cast<database::Iterator**>(
- lua_newuserdata(state, sizeof(uintptr_t)));
- std::visit(
- [&](auto&& arg) {
- *data = new database::Iterator(instance->services().database(), arg);
- },
- val);
- luaL_setmetatable(state, kDbIteratorMetatable);
- lua_pushcclosure(state, db_iterate, 1);
-}
+static const struct luaL_Reg kDbIteratorFuncs[] = {{"next", db_iterate},
+ {"clone", db_iterator_clone},
+ {"__gc", db_iterator_gc},
+ {NULL, NULL}};
static auto record_text(lua_State* state) -> int {
LuaRecord* data = reinterpret_cast<LuaRecord*>(
@@ -219,9 +236,10 @@ static auto lua_database(lua_State* state) -> int {
luaL_setfuncs(state, kDbIndexFuncs, 0);
luaL_newmetatable(state, kDbIteratorMetatable);
- lua_pushliteral(state, "__gc");
- lua_pushcfunction(state, db_iterator_gc);
- lua_settable(state, -3);
+ lua_pushliteral(state, "__index");
+ lua_pushvalue(state, -2);
+ lua_settable(state, -3); // metatable.__index = metatable
+ luaL_setfuncs(state, kDbIteratorFuncs, 0);
luaL_newmetatable(state, kDbRecordMetatable);
lua_pushliteral(state, "__index");