/// A SQLite library version split into `major.minor.patch` components.
///
/// The components are stored as `u32` so that every value produced by
/// [`Version::from_sqlite_version_number`] is kept without truncation
/// (a minor or patch component may be as large as 999).
#[derive(Debug, Clone, Copy)]
pub struct Version {
    major: u32,
    minor: u32,
    patch: u32,
}

impl std::fmt::Display for Version {
    /// Writes the version as `major.minor.patch`, for example `3.8.0`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
    }
}

impl Version {
    /// Decodes a packed version number of the form
    /// `major * 1_000_000 + minor * 1_000 + patch`.
    ///
    /// For example `3008000` decodes to `3.8.0`.
    pub fn from_sqlite_version_number(version_number: u32) -> Self {
        Version {
            major: version_number / 1_000_000,
            minor: version_number / 1_000 % 1_000,
            patch: version_number % 1_000,
        }
    }
}
1 for UTF-8, 2 for UTF-16le, 3 for UTF-16be. + */ + pub text_encoding: SqliteTextEncoding, + /** + * The user version. + * sqlite3 does not use this value internally. It is available for use by applications. + */ + pub user_version: u32, + + pub incremental_vacuum_mode: u32, + /** + * The application ID. + * sqlite3 does not use this value internally. It is available for use by applications. + */ + pub application_id: u32, + /** + * The version-valid-for number. + */ + pub version_valid_for: u32, + /** + * The SQLite version number. + */ + pub sqlite_version_number: Version, +} + +impl TryFrom for SqliteFileFormatVersion { + type Error = HeaderParseError; + + fn try_from(value: u8) -> Result { + match value { + 1 => Ok(SqliteFileFormatVersion::Legacy), + 2 => Ok(SqliteFileFormatVersion::WAL), + _ => Err(HeaderParseError::InvalidFileFormatVersion), + } + } +} + +impl TryFrom for SqliteTextEncoding { + type Error = HeaderParseError; + + fn try_from(value: u32) -> Result { + match value { + 1 => Ok(SqliteTextEncoding::Utf8), + 2 => Ok(SqliteTextEncoding::Utf16le), + 3 => Ok(SqliteTextEncoding::Utf16be), + _ => Err(HeaderParseError::InvalidTextEncoding), + } + } +} + +impl SqliteHeader { + fn parse_page_size(page_size: u16) -> Result { + if page_size == 1 { + Ok(65536) + } else { + if page_size < 512 || page_size > 32768 { + Err(HeaderParseError::InvalidPageSize) + } + else { + Ok(page_size as u32) + } + } + } + fn parse_maximum_embedded_payload_fraction(fraction: u8) -> Result { + if (fraction < MIN_EMBEDDED_PAYLOAD_FRACTION) || (fraction > MAX_EMBEDDED_PAYLOAD_FRACTION) { + Err(HeaderParseError::InvalidMaximumEmbeddedPayloadFraction) + } + else { + Ok(fraction) + } + } + + fn parse_minimum_embedded_payload_fraction(fraction: u8) -> Result { + if (fraction < MIN_EMBEDDED_PAYLOAD_FRACTION) || (fraction > MAX_EMBEDDED_PAYLOAD_FRACTION) { + Err(HeaderParseError::InvalidMinimumEmbeddedPayloadFraction) + } + else { + Ok(fraction) + } + } + + fn 
parse_leaf_payload_fraction(fraction: u8) -> Result { + if (fraction < LEAF_PAYLOAD_FRACTION) || (fraction > MAX_EMBEDDED_PAYLOAD_FRACTION) { + Err(HeaderParseError::InvalidLeafPayloadFraction) + } + else { + Ok(fraction) + } + } + + pub fn read_from_reader(reader: &mut R) -> Result { + let mut buffer = [0u8; 100]; + reader.read_exact(&mut buffer).map_err(|_| HeaderParseError::NotEnoughData)?; + Self::read(&buffer) + } + + pub fn read(buffer: &[u8]) -> Result { + let mut reader = buffer; + // check file size + if buffer.len() < 100 { + return Err(HeaderParseError::NotEnoughData); + } + // check magic header + let magic_header = &reader.copy_to_bytes(MAGIC_HEADER.len()); + let magic_header = std::str::from_utf8(&magic_header).map_err( + |_| HeaderParseError::InvalidMagicHeader + )?; + + if magic_header != MAGIC_HEADER { + return Err(HeaderParseError::InvalidMagicHeader); + } + + let page_size = SqliteHeader::parse_page_size(reader.get_u16())?; + let file_format_write_version = SqliteFileFormatVersion::try_from(reader.get_u8())?; + let file_format_read_version = SqliteFileFormatVersion::try_from(reader.get_u8())?; + let reserved_space = reader.get_u8(); + let maximum_embedded_payload_fraction = reader.get_u8(); + SqliteHeader::parse_maximum_embedded_payload_fraction(maximum_embedded_payload_fraction)?; + let minimum_embedded_payload_fraction = reader.get_u8(); + SqliteHeader::parse_minimum_embedded_payload_fraction(minimum_embedded_payload_fraction)?; + let leaf_payload_fraction = reader.get_u8(); + SqliteHeader::parse_leaf_payload_fraction(leaf_payload_fraction)?; + + let file_change_counter = reader.get_u32(); + let database_size = reader.get_u32(); + let first_freelist_trunk_page = reader.get_u32(); + let total_freelist_pages = reader.get_u32(); + let schema_cookie = reader.get_u32(); + let schema_format_number = reader.get_u32(); + let default_page_cache_size = reader.get_u32(); + let largest_root_b_tree_page_number = reader.get_u32(); + let text_encoding = 
SqliteTextEncoding::try_from(reader.get_u32())?; + let user_version = reader.get_u32(); + let incremental_vacuum_mode = reader.get_u32(); + let application_id = reader.get_u32(); + // reserved space + reader.advance(20); + let version_valid_for = reader.get_u32(); + let sqlite_version_number = reader.get_u32(); + let sqlite_version_number = Version::from_sqlite_version_number(sqlite_version_number); + + Ok(Self { + page_size, + file_format_write_version, + file_format_read_version, + reserved_space, + file_change_counter, + database_size, + first_freelist_trunk_page, + total_freelist_pages, + schema_cookie, + schema_format_number, + default_page_cache_size, + largest_root_b_tree_page_number, + text_encoding, + user_version, + incremental_vacuum_mode, + application_id, + version_valid_for, + sqlite_version_number, + }) + } +} + +#[cfg(test)] +mod test { + #[test] + fn test_version_from_sqlite_version_number() { + use super::Version; + assert_eq!(Version::from_sqlite_version_number(3008000).to_string(), "3.8.0"); + assert_eq!(Version::from_sqlite_version_number(3010000).to_string(), "3.10.0"); + assert_eq!(Version::from_sqlite_version_number(3011000).to_string(), "3.11.0"); + } + + #[test] + fn test_parse_sqlite_page_size() { + use super::SqliteHeader; + assert_eq!(SqliteHeader::parse_page_size(1).unwrap(), 65536); + assert_eq!(SqliteHeader::parse_page_size(512).unwrap(), 512); + assert_eq!(SqliteHeader::parse_page_size(32768).unwrap(), 32768); + assert!(SqliteHeader::parse_page_size(511).is_err()); + assert!(SqliteHeader::parse_page_size(32769).is_err()); + } + + #[test] + fn test_parse_maximum_embedded_payload_fraction() { + use super::SqliteHeader; + assert_eq!(SqliteHeader::parse_maximum_embedded_payload_fraction(32).unwrap(), 32); + assert_eq!(SqliteHeader::parse_maximum_embedded_payload_fraction(64).unwrap(), 64); + assert!(SqliteHeader::parse_maximum_embedded_payload_fraction(31).is_err()); + 
assert!(SqliteHeader::parse_maximum_embedded_payload_fraction(65).is_err()); + } + + #[test] + fn test_parse_minimum_embedded_payload_fraction() { + use super::SqliteHeader; + assert_eq!(SqliteHeader::parse_minimum_embedded_payload_fraction(32).unwrap(), 32); + assert_eq!(SqliteHeader::parse_minimum_embedded_payload_fraction(64).unwrap(), 64); + assert!(SqliteHeader::parse_minimum_embedded_payload_fraction(31).is_err()); + assert!(SqliteHeader::parse_minimum_embedded_payload_fraction(65).is_err()); + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..8fa43d9 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,60 @@ +mod header; +mod page; + +use std::{fs::File, io::{Read, Seek}}; +use bytes::Buf; + +use header::SqliteHeader; +use page::{varint_read_from_buf, BTreePage, Value}; + +fn main() -> anyhow::Result<()> { + let mut f = File::options().read(true).write(true).open("db.sqlite")?; + let header = SqliteHeader::read_from_reader(& mut f)?; + println!("{:?}", header); + // print page size + println!("Page size: {}", header.page_size); + + let mut buf = vec![0; header.page_size as usize]; + f.seek(std::io::SeekFrom::Start(0))?; + f.read_exact(&mut buf)?; + let btr = BTreePage::read(&buf, true)?; + // println!("{:?}", btr); + let mut table_count = 0; + match btr { + BTreePage::LeafNode(leaf) => { + println!("{:?}", leaf); + leaf.data_cells.iter().for_each(|cell| { + let buf = &cell.data; + let mut buf = buf.as_slice(); + + let (mut column_size, i) = varint_read_from_buf(&mut buf); + let mut serial_types = vec![]; + column_size -= i as u64; + while column_size > 0 { + let (serial_type, i) = varint_read_from_buf(&mut buf); + serial_types.push(serial_type); + column_size -= i as u64; + } + let values = serial_types.iter().map(|&serial_type| { + let v = Value::from_bytes(&mut buf, serial_type); + println!("{:?}", v); + v + }).collect::>(); + println!("{:?}", values); + if let Some(v) = values.first() { + if let 
/// A freelist trunk page: a link to the next trunk page plus the page
/// numbers of the freelist leaf pages it records.
struct FreeTrunkPage {
    /// The next trunk page in the list of free pages.
    /// If this is the last trunk page in the list, this field is 0.
    next_trunk_page: u32,
    /// Page numbers of the freelist leaf pages recorded in this trunk page.
    free_list: Vec<u32>,
    // A bug in SQLite versions prior to 3.6.0 (2008-07-16) caused the database to be
    // reported as corrupt if any of the last 6 entries in the freelist trunk page array
    // contained non-zero values. Newer versions of SQLite do not have this problem, but
    // they still avoid using the last six entries so that database files they create
    // can be read by older versions of SQLite.
}

impl FreeTrunkPage {
    /// Reads one page of `page_size` bytes from `reader` and decodes it as a
    /// freelist trunk page.
    ///
    /// The page is a sequence of big-endian 32-bit words: the next trunk page
    /// number, the number of leaf entries, then the entries themselves.
    ///
    /// NOTE(review): the generic parameters and return type were missing from
    /// the reviewed text; `R: Read` and `std::io::Result<Self>` are
    /// reconstructed from the body (`read_exact(..)?`) — confirm against the
    /// original signature.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if a full page cannot be read, and an
    /// error of kind `InvalidData` if the page is too small to hold its two
    /// header words or declares more entries than it can contain.
    pub fn read_from_reader<R: Read>(reader: &mut R, page_size: u32) -> std::io::Result<Self> {
        let corrupt = |reason: &'static str| {
            std::io::Error::new(std::io::ErrorKind::InvalidData, reason)
        };

        let mut page = vec![0u8; page_size as usize];
        reader.read_exact(&mut page)?;

        let mut words = page
            .chunks_exact(4)
            .map(|word| u32::from_be_bytes([word[0], word[1], word[2], word[3]]));

        let next_trunk_page = words
            .next()
            .ok_or_else(|| corrupt("freelist trunk page is too small for its header"))?;
        let count = words
            .next()
            .ok_or_else(|| corrupt("freelist trunk page is too small for its header"))?;

        // The count comes straight from disk; never trust it beyond what the
        // page can physically hold.
        if u64::from(count) > words.len() as u64 {
            return Err(corrupt("freelist trunk page declares more entries than fit in the page"));
        }

        // Lossless: `count` was just bounded by the number of remaining words.
        let free_list = words.take(count as usize).collect();

        Ok(FreeTrunkPage {
            next_trunk_page,
            free_list,
        })
    }

    /// Returns the page number of the next trunk page (0 if this is the last one).
    pub fn get_next_trunk_page(&self) -> u32 {
        self.next_trunk_page
    }

    /// Returns the leaf page numbers recorded in this trunk page.
    pub fn get_free_list(&self) -> &Vec<u32> {
        &self.free_list
    }
}
/// A single decoded column value from a record.
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    Null,
    Integer(i64),
    Real(f64),
    Text(String),
    Blob(Vec<u8>),
}

/// Splits `len` bytes off the front of `data` and advances `data` past them.
///
/// Panics if fewer than `len` bytes remain.
fn take_bytes<'a>(data: &mut &'a [u8], len: u64) -> &'a [u8] {
    let slice: &'a [u8] = *data;
    assert!(
        len <= slice.len() as u64,
        "record body truncated: need {} bytes, have {}",
        len,
        slice.len()
    );
    // Lossless: `len` was just bounded by `slice.len()`.
    let (head, tail) = slice.split_at(len as usize);
    *data = tail;
    head
}

/// Reads a `width`-byte big-endian two's-complement integer and sign-extends it.
fn read_be_signed(data: &mut &[u8], width: u32) -> i64 {
    let raw = take_bytes(data, u64::from(width))
        .iter()
        .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte));
    // Move the value's sign bit up to bit 63, then shift back arithmetically
    // so the sign is copied into the unused high bits.
    let unused_bits = 64 - 8 * width;
    ((raw << unused_bits) as i64) >> unused_bits
}

impl Value {
    /// Decodes one column value of the given `serial_type` from the front of
    /// `data`, advancing `data` past the bytes consumed.
    ///
    /// * `0` is NULL; `8` and `9` are the constants 0 and 1 (no bytes read).
    /// * `1`..=`6` are big-endian signed integers of 1, 2, 3, 4, 6 and 8 bytes.
    /// * `7` is a big-endian 64-bit float.
    /// * Even values `>= 12` are blobs of `(serial_type - 12) / 2` bytes.
    /// * Odd values `>= 13` are text of `(serial_type - 13) / 2` bytes. Text
    ///   is always decoded as UTF-8; invalid sequences are replaced with
    ///   U+FFFD rather than aborting.
    ///
    /// # Panics
    ///
    /// Panics if `serial_type` is 10 or 11, or if `data` holds fewer bytes
    /// than the serial type requires.
    pub fn from_bytes(data: &mut &[u8], serial_type: u64) -> Value {
        match serial_type {
            0 => Value::Null,
            1 => Value::Integer(read_be_signed(data, 1)),
            2 => Value::Integer(read_be_signed(data, 2)),
            3 => Value::Integer(read_be_signed(data, 3)),
            4 => Value::Integer(read_be_signed(data, 4)),
            5 => Value::Integer(read_be_signed(data, 6)),
            6 => Value::Integer(read_be_signed(data, 8)),
            7 => {
                let mut raw = [0u8; 8];
                raw.copy_from_slice(take_bytes(data, 8));
                Value::Real(f64::from_be_bytes(raw))
            }
            8 => Value::Integer(0),
            9 => Value::Integer(1),
            // Reserved for internal use.
            10 | 11 => panic!("Invalid serial type: {}", serial_type),
            st => {
                // Integer division makes this the right length for both the
                // even (blob) and odd (text) encodings.
                let payload = take_bytes(data, (st - 12) / 2);
                if st % 2 == 0 {
                    Value::Blob(payload.to_vec())
                } else {
                    Value::Text(String::from_utf8_lossy(payload).into_owned())
                }
            }
        }
    }
}
/// Pops the first byte off `buf`, panicking if the buffer is empty.
fn next_varint_byte(buf: &mut &[u8]) -> u8 {
    let slice: &[u8] = *buf;
    let (&byte, rest) = slice
        .split_first()
        .expect("varint truncated: buffer ended mid-value");
    *buf = rest;
    byte
}

/// Decodes one variable-length integer from the front of `buf`.
///
/// The encoding is big-endian and 1 to 9 bytes long. Each of the first eight
/// bytes contributes its low 7 bits and uses its high bit as a "more bytes
/// follow" flag; a ninth byte, when present, contributes all 8 bits.
///
/// Returns `(value, bytes_consumed)` and advances `buf` past the bytes read.
///
/// # Panics
///
/// Panics if `buf` ends before the varint is complete.
pub fn varint_read_from_buf(buf: &mut &[u8]) -> (u64, u32) {
    let mut value: u64 = 0;
    for consumed in 1..=8u32 {
        let byte = next_varint_byte(buf);
        value = (value << 7) | u64::from(byte & 0x7F);
        if byte & 0x80 == 0 {
            return (value, consumed);
        }
    }
    // All eight leading bytes had the continuation bit set: the ninth byte
    // carries a full 8 bits and has no flag bit of its own.
    let last = next_varint_byte(buf);
    ((value << 8) | u64::from(last), 9)
}
buf); + assert_eq!(value, 0x101); + assert_eq!(i, 2); + } + + #[test] + fn test_varint_read_from_buf_3() { + let buf = [0x81, 0x91, 0xd1, 0xac, 0x78]; + let mut buf = buf.as_ref(); + let (value, i) = super::varint_read_from_buf(&mut buf); + assert_eq!(value, 0x12345678); + assert_eq!(i, 5); + } +} \ No newline at end of file