diff --git a/src/bitn/bit64.lua b/src/bitn/bit64.lua index 01d5dc8..066d3be 100644 --- a/src/bitn/bit64.lua +++ b/src/bitn/bit64.lua @@ -9,9 +9,31 @@ local bit64 = {} local bit32 = require("bitn.bit32") +-- Private metatable for Int64 type identification +local Int64Meta = { __name = "Int64" } + -- Type definitions --- @alias Int64HighLow [integer, integer] Array with [1]=high 32 bits, [2]=low 32 bits +-------------------------------------------------------------------------------- +-- Constructor and type checking +-------------------------------------------------------------------------------- + +--- Create a new Int64 value with metatable marker. +--- @param high? integer Upper 32 bits (default: 0) +--- @param low? integer Lower 32 bits (default: 0) +--- @return Int64HighLow value Int64 value with metatable marker +function bit64.new(high, low) + return setmetatable({ high or 0, low or 0 }, Int64Meta) +end + +--- Check if a value is an Int64 (created by bit64 functions). +--- @param value any Value to check +--- @return boolean isInt64 True if value is an Int64 +function bit64.isInt64(value) + return type(value) == "table" and getmetatable(value) == Int64Meta +end + -------------------------------------------------------------------------------- -- Bitwise operations -------------------------------------------------------------------------------- @@ -21,10 +43,7 @@ local bit32 = require("bitn.bit32") --- @param b Int64HighLow Second operand {high, low} --- @return Int64HighLow result {high, low} AND result function bit64.band(a, b) - return { - bit32.band(a[1], b[1]), - bit32.band(a[2], b[2]), - } + return bit64.new(bit32.band(a[1], b[1]), bit32.band(a[2], b[2])) end --- Bitwise OR operation. @@ -32,10 +51,7 @@ end --- @param b Int64HighLow Second operand {high, low} --- @return Int64HighLow result {high, low} OR result function bit64.bor(a, b) - return { - bit32.bor(a[1], b[1]), - bit32.bor(a[2], b[2]), - } + return bit64.new(bit32.bor(a[1], b[1]), bit32.bor(a[2], b[2])) end --- Bitwise XOR operation. @@ -43,20 +59,14 @@ end --- @param b Int64HighLow Second operand {high, low} --- @return Int64HighLow result {high, low} XOR result function bit64.bxor(a, b) - return { - bit32.bxor(a[1], b[1]), - bit32.bxor(a[2], b[2]), - } + return bit64.new(bit32.bxor(a[1], b[1]), bit32.bxor(a[2], b[2])) end --- Bitwise NOT operation. --- @param a Int64HighLow Operand {high, low} --- @return Int64HighLow result {high, low} NOT result function bit64.bnot(a) - return { - bit32.bnot(a[1]), - bit32.bnot(a[2]), - } + return bit64.new(bit32.bnot(a[1]), bit32.bnot(a[2])) end -------------------------------------------------------------------------------- @@ -69,17 +79,17 @@ end --- @return Int64HighLow result {high, low} shifted value function bit64.lshift(x, n) if n == 0 then - return { x[1], x[2] } + return bit64.new(x[1], x[2]) elseif n >= 64 then - return { 0, 0 } + return bit64.new(0, 0) elseif n >= 32 then -- Shift by 32 or more: low becomes 0, high gets bits from low - return { bit32.lshift(x[2], n - 32), 0 } + return bit64.new(bit32.lshift(x[2], n - 32), 0) else -- Shift by less than 32 local new_high = bit32.bor(bit32.lshift(x[1], n), bit32.rshift(x[2], 32 - n)) local new_low = bit32.lshift(x[2], n) - return { new_high, new_low } + return bit64.new(new_high, new_low) end end @@ -89,17 +99,17 @@ end --- @return Int64HighLow result {high, low} shifted value function bit64.rshift(x, n) if n == 0 then - return { x[1], x[2] } + return bit64.new(x[1], x[2]) elseif n >= 64 then - return { 0, 0 } + return bit64.new(0, 0) elseif n >= 32 then -- Shift by 32 or more: high becomes 0, low gets bits from high - return { 0, bit32.rshift(x[1], n - 32) } + return bit64.new(0, bit32.rshift(x[1], n - 32)) else -- Shift by less than 32 local new_low = bit32.bor(bit32.rshift(x[2], n), bit32.lshift(x[1], 32 - n)) local new_high = bit32.rshift(x[1], n) - return { new_high, new_low } + return bit64.new(new_high, new_low) end end @@ -109,7 +119,7 @@ end --- @return Int64HighLow result {high, low} shifted value function bit64.arshift(x, n) if n == 0 then - return { x[1], x[2] } + return bit64.new(x[1], x[2]) end -- Check sign bit (bit 31 of high word) @@ -118,20 +128,20 @@ function bit64.arshift(x, n) if n >= 64 then -- All bits shift out, result is all 1s if negative, all 0s if positive if is_negative then - return { 0xFFFFFFFF, 0xFFFFFFFF } + return bit64.new(0xFFFFFFFF, 0xFFFFFFFF) else - return { 0, 0 } + return bit64.new(0, 0) end elseif n >= 32 then -- High word shifts into low, high fills with sign local new_low = bit32.arshift(x[1], n - 32) local new_high = is_negative and 0xFFFFFFFF or 0 - return { new_high, new_low } + return bit64.new(new_high, new_low) else -- Shift by less than 32 local new_low = bit32.bor(bit32.rshift(x[2], n), bit32.lshift(x[1], 32 - n)) local new_high = bit32.arshift(x[1], n) - return { new_high, new_low } + return bit64.new(new_high, new_low) end end @@ -146,25 +156,25 @@ end function bit64.rol(x, n) n = n % 64 if n == 0 then - return { x[1], x[2] } + return bit64.new(x[1], x[2]) end local high, low = x[1], x[2] if n == 32 then -- Special case: swap high and low - return { low, high } + return bit64.new(low, high) elseif n < 32 then -- Rotate within 32-bit boundaries local new_high = bit32.bor(bit32.lshift(high, n), bit32.rshift(low, 32 - n)) local new_low = bit32.bor(bit32.lshift(low, n), bit32.rshift(high, 32 - n)) - return { new_high, new_low } + return bit64.new(new_high, new_low) else -- n > 32: rotate by (n - 32) after swapping n = n - 32 local new_high = bit32.bor(bit32.lshift(low, n), bit32.rshift(high, 32 - n)) local new_low = bit32.bor(bit32.lshift(high, n), bit32.rshift(low, 32 - n)) - return { new_high, new_low } + return bit64.new(new_high, new_low) end end @@ -175,25 +185,25 @@ end function bit64.ror(x, n) n = n % 64 if n == 0 then - return { x[1], x[2] } + return bit64.new(x[1], x[2]) end local high, low = x[1], x[2] if n == 32 then -- Special case: swap high and low - return { low, high } + return bit64.new(low, high) elseif n < 32 then -- Rotate within 32-bit boundaries local new_low = bit32.bor(bit32.rshift(low, n), bit32.lshift(high, 32 - n)) local new_high = bit32.bor(bit32.rshift(high, n), bit32.lshift(low, 32 - n)) - return { new_high, new_low } + return bit64.new(new_high, new_low) else -- n > 32: rotate by (n - 32) after swapping n = n - 32 local new_low = bit32.bor(bit32.rshift(high, n), bit32.lshift(low, 32 - n)) local new_high = bit32.bor(bit32.rshift(low, n), bit32.lshift(high, 32 - n)) - return { new_high, new_low } + return bit64.new(new_high, new_low) end end @@ -218,7 +228,7 @@ function bit64.add(a, b) -- Keep high within 32 bits high = high % 0x100000000 - return { high, low } + return bit64.new(high, low) end -------------------------------------------------------------------------------- @@ -248,7 +258,7 @@ function bit64.be_bytes_to_u64(str, offset) assert(#str >= offset + 7, "Insufficient bytes for u64") local high = bit32.be_bytes_to_u32(str, offset) local low = bit32.be_bytes_to_u32(str, offset + 4) - return { high, low } + return bit64.new(high, low) end --- Convert 8 bytes to 64-bit value (little-endian). @@ -260,7 +270,7 @@ function bit64.le_bytes_to_u64(str, offset) assert(#str >= offset + 7, "Insufficient bytes for u64") local low = bit32.le_bytes_to_u32(str, offset) local high = bit32.le_bytes_to_u32(str, offset + 4) - return { high, low } + return bit64.new(high, low) end -------------------------------------------------------------------------------- @@ -614,6 +624,135 @@ function bit64.selftest() end end + -- Int64 type identification tests + print("\nRunning Int64 type identification tests...") + + -- Test bit64.new() creates Int64 values + total = total + 1 + local new_val = bit64.new(0x12345678, 0x9ABCDEF0) + if bit64.isInt64(new_val) and new_val[1] == 0x12345678 and new_val[2] == 0x9ABCDEF0 then + print(" PASS: new() creates Int64 with correct values") + passed = passed + 1 + else + print(" FAIL: new() creates Int64 with correct values") + end + + -- Test bit64.new() with defaults + total = total + 1 + local zero_val = bit64.new() + if bit64.isInt64(zero_val) and zero_val[1] == 0 and zero_val[2] == 0 then + print(" PASS: new() with no args creates {0, 0}") + passed = passed + 1 + else + print(" FAIL: new() with no args creates {0, 0}") + end + + -- Test isInt64() returns false for regular tables + total = total + 1 + local plain_table = { 0x12345678, 0x9ABCDEF0 } + if not bit64.isInt64(plain_table) then + print(" PASS: isInt64() returns false for plain table") + passed = passed + 1 + else + print(" FAIL: isInt64() returns false for plain table") + end + + -- Test isInt64() returns false for non-tables + total = total + 1 + if not bit64.isInt64(123) and not bit64.isInt64("string") and not bit64.isInt64(nil) then + print(" PASS: isInt64() returns false for non-tables") + passed = passed + 1 + else + print(" FAIL: isInt64() returns false for non-tables") + end + + -- Test all operations return Int64 values + local ops_returning_int64 = { + { + name = "band", + fn = function() + return bit64.band({ 1, 2 }, { 3, 4 }) + end, + }, + { + name = "bor", + fn = function() + return bit64.bor({ 1, 2 }, { 3, 4 }) + end, + }, + { + name = "bxor", + fn = function() + return bit64.bxor({ 1, 2 }, { 3, 4 }) + end, + }, + { + name = "bnot", + fn = function() + return bit64.bnot({ 1, 2 }) + end, + }, + { + name = "lshift", + fn = function() + return bit64.lshift({ 1, 2 }, 1) + end, + }, + { + name = "rshift", + fn = function() + return bit64.rshift({ 1, 2 }, 1) + end, + }, + { + name = "arshift", + fn = function() + return bit64.arshift({ 1, 2 }, 1) + end, + }, + { + name = "rol", + fn = function() + return bit64.rol({ 1, 2 }, 1) + end, + }, + { + name = "ror", + fn = function() + return bit64.ror({ 1, 2 }, 1) + end, + }, + { + name = "add", + fn = function() + return bit64.add({ 1, 2 }, { 3, 4 }) + end, + }, + { + name = "be_bytes_to_u64", + fn = function() + return bit64.be_bytes_to_u64("\0\0\0\1\0\0\0\2") + end, + }, + { + name = "le_bytes_to_u64", + fn = function() + return bit64.le_bytes_to_u64("\2\0\0\0\1\0\0\0") + end, + }, + } + + for _, op in ipairs(ops_returning_int64) do + total = total + 1 + local result = op.fn() + if bit64.isInt64(result) then + print(" PASS: " .. op.name .. "() returns Int64") + passed = passed + 1 + else + print(" FAIL: " .. op.name .. "() returns Int64") + end + end + print(string.format("\n64-bit operations: %d/%d tests passed\n", passed, total)) return passed == total end