diff --git a/.cargo-husky/hooks/pre-commit b/.cargo-husky/hooks/pre-commit new file mode 100755 index 0000000..62fae86 --- /dev/null +++ b/.cargo-husky/hooks/pre-commit @@ -0,0 +1,6 @@ +#!/bin/sh + +set -e + +echo '+cargo fmt --all -- --check' +cargo fmt --all -- --check diff --git a/Cargo.lock b/Cargo.lock index 78f06fb..c43e7a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -94,6 +94,17 @@ version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "atoi" version = "2.0.0" @@ -238,6 +249,12 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +[[package]] +name = "cargo-husky" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b02b629252fe8ef6460461409564e2c21d0c8e77e0944f3d189ff06c4e932ad" + [[package]] name = "cc" version = "1.2.57" @@ -245,6 +262,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423" dependencies = [ "find-msvc-tools", + "jobserver", + "libc", "shlex", ] @@ -490,6 +509,7 @@ name = "dry_run_cli" version = "0.6.0" dependencies = [ "anyhow", + "cargo-husky", "chrono", "clap", "dry_run_core", @@ -498,16 +518,19 @@ dependencies = [ "schemars", "serde", "serde_json", + "tempfile", "thiserror 2.0.18", "tokio", "tracing", "tracing-subscriber", + "zstd", ] [[package]] name = "dry_run_core" version = "0.6.0" dependencies = [ + "async-trait", "chrono", "pg_query", "regex", @@ -521,6 +544,7 @@ dependencies = [ "tokio", "toml", "tracing", + "zstd", ] [[package]] @@ -1185,6 +1209,16 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + [[package]] name = "js-sys" version = "0.3.91" @@ -3659,3 +3693,31 @@ name = "zmij" version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.16+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index c6a8e63..1ff91d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ 
dbg_macro = "deny" [workspace.dependencies] dry_run_core = { path = "crates/dry_run_core" } +async-trait = "0.1" chrono = { version = "0.4", features = ["serde"] } clap = { version = "4", features = ["derive", "env"] } pg_query = "6.1" @@ -26,6 +27,7 @@ thiserror = "2" tokio = { version = "1", features = ["full"] } toml = "0.8" tracing = "0.1" +zstd = "0.13" reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false } rmcp = { version = "0.8", features = ["server", "transport-io", "transport-sse-server", "macros"] } schemars = "1" diff --git a/README.md b/README.md index 5bb26f5..c7a50e9 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,6 @@ LLM/AI coding assistants are very good in writing code/SQL queries. But they are Some PostgreSQL MCP server ask you for the database connection. And to perform the administrative tasks you might need SUPERUSER permission. But that's like asking for problem. - We've already seen where this leads: [production databases wiped by AI agents](https://fortune.com/2025/07/23/ai-coding-tool-replit-wiped-database-called-it-a-catastrophic-failure/), and [SQL injection in MCP servers](https://securitylabs.datadoghq.com/articles/mcp-vulnerability-case-study-SQL-injection-in-the-postgresql-mcp-server/) that were supposed to be read-only. The model doesn't need to *query* your database. It needs to *understand* your schema: the structure, constraints, statistics, and version-specific behavior. That knowledge is structural. It changes when you deploy a migration, not between queries. @@ -107,7 +106,7 @@ If you can connect to a PostgreSQL instance (local, dev, or production), one com dryrun init --db "$DATABASE_URL" ``` -This creates `dryrun.toml`, the `.dryrun/` data directory, and introspects the database into `.dryrun/schema.json`. You're ready to go. +This creates `dryrun.toml` (with `[project] id` and default profile), the `.dryrun/` data directory, and introspects the database into `.dryrun/schema.json`. Snapshots are keyed by `(project_id, database_id)`; set `database_id` per profile when a project has multiple databases (e.g. `auth`, `billing`). See [`docs/dryrun-toml.md`](docs/dryrun-toml.md) for the full config reference. ### Option B: Someone else has database access @@ -134,6 +133,44 @@ dryrun lint All commands work offline from the schema file. Each project has its own `dryrun.toml` and `.dryrun/`, there is no global state. Add `.dryrun/` to your `.gitignore`. +### Multiple databases per project + +`dryrun snapshot take` keys snapshots by `(project_id, database_id)`. The defaults work — `project_id` is your folder name, `database_id` is the actual database name from `current_database()`: + +```sh +dryrun init --db "$AUTH_DB" # captures auth +dryrun snapshot take --db "$BILLING_DB" # captures billing into its own stream +dryrun snapshot list --db "$AUTH_DB" # only auth snapshots +``` + +For stable refs (and so `list` / `diff` can run without retyping URLs), declare profiles in `dryrun.toml`: + +```toml +[project] +id = "myapp" + +[profiles.auth] +db_url = "${AUTH_DATABASE_URL}" +database_id = "auth" + +[profiles.billing] +db_url = "${BILLING_DATABASE_URL}" +database_id = "billing" +``` + +Then: + +```sh +dryrun --profile billing snapshot list +dryrun --profile billing snapshot diff --latest +``` + +See [`docs/dryrun-toml.md`](docs/dryrun-toml.md) for all profile options. 
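+
+Profiles are not limited to `snapshot` subcommands. An illustrative session (hypothetical invocations, reusing the `auth` and `billing` profiles defined above):
+
+```sh
+dryrun --profile auth lint       # lints offline against auth's schema file
+dryrun --profile billing drift   # compares the live billing database to its snapshot
+```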
+
+Every DB-related command (`init`, `import`, `probe`, `dump-schema`, `lint`, `drift`, `stats apply`, all `snapshot` subcommands) accepts `--profile` and falls back to the resolved profile's `db_url` and `schema_file` when the corresponding CLI flag is not provided.
+
+> **Note:** the MCP server is currently single-database: it uses the default profile. Alternatively, run one `dryrun mcp-serve` process per database. Native multi-database support inside one MCP process is tracked in [#4](https://github.com/boringSQL/dryrun/issues/4).
+
 ## MCP server
 
 Add `dryrun` to your AI assistant. If you installed via Homebrew, `dryrun` is already on your PATH:
@@ -150,6 +187,8 @@ claude mcp add dryrun -- /path/to/dryrun mcp-serve
 
 That's it. The server auto-discovers `.dryrun/schema.json` in the current project. No database credentials needed, your AI assistant gets full schema intelligence from the offline snapshot.
 
+For projects with multiple databases, run one `dryrun mcp-serve` per database and add an entry per server in your client config. Native multi-database serving inside one MCP process is tracked in [#4](https://github.com/boringSQL/dryrun/issues/4).
+
 See the [Tutorial](TUTORIAL.md) for live database setup, SSE transport, and Claude Desktop configuration.
 
 ## More
diff --git a/crates/dry_run_cli/Cargo.toml b/crates/dry_run_cli/Cargo.toml
index ca3530a..31ceffc 100644
--- a/crates/dry_run_cli/Cargo.toml
+++ b/crates/dry_run_cli/Cargo.toml
@@ -22,3 +22,8 @@ serde_json = { workspace = true }
 tokio = { workspace = true }
 tracing = { workspace = true }
 tracing-subscriber = { workspace = true }
+zstd = { workspace = true }
+
+[dev-dependencies]
+cargo-husky = { version = "1", default-features = false, features = ["user-hooks"] }
+tempfile = "3"
diff --git a/crates/dry_run_cli/src/main.rs b/crates/dry_run_cli/src/main.rs
index 0808ec4..8020d12 100644
--- a/crates/dry_run_cli/src/main.rs
+++ b/crates/dry_run_cli/src/main.rs
@@ -4,6 +4,9 @@ mod pgmustard;
 use std::path::PathBuf;
 
 use clap::{Parser, Subcommand};
+use dry_run_core::history::{
+    DatabaseId, PutOutcome, SnapshotKey, SnapshotRef, SnapshotStore, TimeRange,
+};
 use dry_run_core::schema::{NodeColumnStats, NodeIndexStats, NodeStats, NodeTableStats};
 use dry_run_core::{DryRun, HistoryStore, ProjectConfig};
 use rmcp::ServiceExt;
@@ -134,6 +137,12 @@ enum SnapshotAction {
         #[arg(long)]
         pretty: bool,
     },
+    Export {
+        #[arg(long)]
+        out: Option<PathBuf>,
+        #[arg(long)]
+        history_db: Option<PathBuf>,
+    },
 }
 
 #[derive(Subcommand)]
@@ -162,31 +171,58 @@ async fn main() {
 
 async fn run(cli: Cli) -> anyhow::Result<()> {
     match cli.command {
-        Command::Probe { ref db } => cmd_probe(db.as_deref()).await,
-        Command::DumpSchema { ref source, pretty, ref output, stats_only, ref name } => {
-            cmd_dump_schema(source.as_deref(), pretty, output.clone(), stats_only, name.clone()).await
+        Command::Probe { ref db } => cmd_probe(&cli, db.as_deref()).await,
+        Command::DumpSchema {
+            ref source,
+            pretty,
+            ref output,
+            stats_only,
+            ref name,
+        } => {
+            cmd_dump_schema(
+                &cli,
+                source.as_deref(),
+                pretty,
+                output.clone(),
+                stats_only,
+                name.clone(),
+            )
+            .await
         }
         Command::Init { ref db } => cmd_init(db.as_deref()).await,
-        Command::Import { ref file, ref stats } => cmd_import(file, stats).await,
+        Command::Import {
+            ref file,
+            ref stats,
+        } => cmd_import(&cli, file, stats).await,
         Command::Lint {
             ref schema_name,
             pretty,
             json,
         } => cmd_lint(&cli, schema_name.as_deref(), pretty, json).await,
-        Command::Snapshot { ref action } => cmd_snapshot(action).await,
+        Command::Snapshot {
ref action } => cmd_snapshot(&cli, action).await, Command::Profile { ref action } => cmd_profile(&cli, action), - Command::Stats { ref action } => cmd_stats(action).await, - Command::Drift { ref db, ref against, pretty, json } => { - cmd_drift(db.as_deref(), against.as_deref(), pretty, json).await - } - Command::McpServe { ref db, ref schema_file, ref transport, port } => { - cmd_mcp_serve(&cli, db.as_deref(), schema_file.as_deref(), transport, port).await - } + Command::Stats { ref action } => cmd_stats(&cli, action).await, + Command::Drift { + ref db, + ref against, + pretty, + json, + } => cmd_drift(&cli, db.as_deref(), against.as_deref(), pretty, json).await, + Command::McpServe { + ref db, + ref schema_file, + ref transport, + port, + } => cmd_mcp_serve(&cli, db.as_deref(), schema_file.as_deref(), transport, port).await, } } -async fn cmd_probe(db: Option<&str>) -> anyhow::Result<()> { - let db_url = require_db_url(db)?; +async fn cmd_probe(cli: &Cli, db: Option<&str>) -> anyhow::Result<()> { + let resolved = active_resolved_profile(cli, db, None)?; + let db_url = resolved + .db_url + .as_deref() + .ok_or_else(|| anyhow::anyhow!("--db or a profile with db_url is required"))?; let ctx = DryRun::connect(db_url).await?; let result = ctx.probe().await?; @@ -195,26 +231,48 @@ async fn cmd_probe(db: Option<&str>) -> anyhow::Result<()> { let report = ctx.check_privileges().await?; println!("Privileges:"); - println!(" pg_catalog: {}", if report.pg_catalog { "ok" } else { "DENIED" }); - println!(" information_schema: {}", if report.information_schema { "ok" } else { "DENIED" }); - println!(" pg_stat_user_tables: {}", if report.pg_stat_user_tables { "ok" } else { "DENIED" }); + println!( + " pg_catalog: {}", + if report.pg_catalog { "ok" } else { "DENIED" } + ); + println!( + " information_schema: {}", + if report.information_schema { + "ok" + } else { + "DENIED" + } + ); + println!( + " pg_stat_user_tables: {}", + if report.pg_stat_user_tables { + "ok" + } else { + "DENIED" + } + ); Ok(()) } async fn cmd_dump_schema( + cli: &Cli, source: Option<&str>, pretty: bool, output: Option, stats_only: bool, name: Option, ) -> anyhow::Result<()> { - let db_url = require_db_url(source)?; + let resolved = active_resolved_profile(cli, source, None)?; + let db_url = resolved + .db_url + .as_deref() + .ok_or_else(|| anyhow::anyhow!("--source or a profile with db_url is required"))?; + let name = name.or_else(|| resolved.database_id.as_ref().map(|d| d.0.clone())); let ctx = DryRun::connect(db_url).await?; if stats_only { - let source = name.ok_or_else(|| { - anyhow::anyhow!("--name is required when using --stats-only") - })?; + let source = + name.ok_or_else(|| anyhow::anyhow!("--name is required when using --stats-only"))?; let node_stats = ctx.introspect_stats_only(&source).await?; let json = if pretty { @@ -304,30 +362,39 @@ async fn cmd_dump_schema( async fn cmd_init(db: Option<&str>) -> anyhow::Result<()> { let config_path = PathBuf::from("dryrun.toml"); + let cwd = std::env::current_dir().unwrap_or_default(); // scaffold config file if !config_path.exists() { - let cwd = std::env::current_dir().unwrap_or_default(); - let profile_name = cwd + let project_id = cwd .file_name() .and_then(|n| n.to_str()) .unwrap_or("default"); + let profile_name = project_id; let content = format!( - r#"[default] + r#"[project] +id = "{project_id}" + +[default] profile = "{profile_name}" [profiles.{profile_name}] schema_file = ".dryrun/schema.json" +# database_id = "{profile_name}" # defaults to profile name; override to 
e.g. "auth", "billing" # [profiles.dev] # db_url = "${{DATABASE_URL}}" +# database_id = "dev" # [conventions] # See: https://boringsql.com/dryrun/docs/dryrun-toml "# ); std::fs::write(&config_path, &content)?; - eprintln!("Created {} (profile \"{profile_name}\")", config_path.display()); + eprintln!( + "Created {} (profile \"{profile_name}\")", + config_path.display() + ); } else { eprintln!("{} already exists, skipping", config_path.display()); } @@ -346,9 +413,12 @@ schema_file = ".dryrun/schema.json" std::fs::write(&schema_path, &json)?; let store = open_history_store(None)?; - if let Err(e) = store.save_snapshot(db_url, &snapshot) { - eprintln!("warning: could not save snapshot: {e}"); - } + let config = ProjectConfig::discover(&cwd) + .map(|(_, c)| Ok(c)) + .unwrap_or_else(|| ProjectConfig::parse(""))?; + let resolved = config.resolve_profile(Some(db_url), None, None, &cwd)?; + let key = complete_key(&resolved, &snapshot.database); + store.put(&key, &snapshot).await?; eprintln!( "Captured schema: {} tables, {} views, {} functions", @@ -357,6 +427,10 @@ schema_file = ".dryrun/schema.json" snapshot.functions.len() ); eprintln!(" Schema: {}", schema_path.display()); + eprintln!( + " project={} database={}", + key.project_id.0, key.database_id.0 + ); } else { eprintln!("Run 'dryrun init --db ' to capture a schema snapshot"); } @@ -399,7 +473,10 @@ async fn cmd_lint( println!("{output}"); } else { if report.violations.is_empty() { - println!("No lint violations found ({} tables checked).", report.tables_checked); + println!( + "No lint violations found ({} tables checked).", + report.tables_checked + ); } else { for v in &report.violations { let location = if let Some(col) = &v.column { @@ -433,7 +510,8 @@ async fn cmd_lint( Ok(()) } -async fn cmd_snapshot(action: &SnapshotAction) -> anyhow::Result<()> { +async fn cmd_snapshot(cli: &Cli, action: &SnapshotAction) -> anyhow::Result<()> { + let profile = cli.profile.as_deref(); match action { SnapshotAction::Take { db, history_db } => { let db_url = require_db_url(db.as_deref())?; @@ -441,29 +519,49 @@ async fn cmd_snapshot(action: &SnapshotAction) -> anyhow::Result<()> { let store = open_history_store(history_db.as_deref())?; let snapshot = ctx.introspect_schema().await?; - match store.save_snapshot(db_url, &snapshot)? { - true => { + let cwd = std::env::current_dir().unwrap_or_default(); + let config = ProjectConfig::discover(&cwd) + .map(|(_, c)| Ok(c)) + .unwrap_or_else(|| ProjectConfig::parse(""))?; + let resolved = config.resolve_profile(Some(db_url), None, profile, &cwd)?; + let key = complete_key(&resolved, &snapshot.database); + + match store.put(&key, &snapshot).await? 
{ + PutOutcome::Inserted => { println!("Snapshot saved: {}", snapshot.content_hash); println!( " {} tables, {} views, {} functions", - snapshot.tables.len(), snapshot.views.len(), snapshot.functions.len() + snapshot.tables.len(), + snapshot.views.len(), + snapshot.functions.len() + ); + println!( + " project={} database={}", + key.project_id.0, key.database_id.0 ); } - false => { + PutOutcome::Deduped => { println!("Schema unchanged (hash: {})", snapshot.content_hash); + println!( + " project={} database={}", + key.project_id.0, key.database_id.0 + ); } } Ok(()) } SnapshotAction::List { db, history_db } => { - let db_url = require_db_url(db.as_deref())?; let store = open_history_store(history_db.as_deref())?; - let snapshots = store.list_snapshots(db_url)?; + let key = resolve_read_key(db.as_deref(), profile).await?; + let rows = store.list(&key, TimeRange::default()).await?; - if snapshots.is_empty() { - println!("No snapshots found for this database."); + if rows.is_empty() { + println!( + "No snapshots found (project={} database={})", + key.project_id.0, key.database_id.0 + ); } else { - for s in &snapshots { + for s in &rows { println!( "{} {} {}", s.timestamp.format("%Y-%m-%d %H:%M:%S"), @@ -471,30 +569,38 @@ async fn cmd_snapshot(action: &SnapshotAction) -> anyhow::Result<()> { s.database, ); } - println!("\n{} snapshot(s) total", snapshots.len()); + println!( + "\n{} snapshot(s) total (project={} database={})", + rows.len(), + key.project_id.0, + key.database_id.0 + ); } Ok(()) } SnapshotAction::Diff { - db, from, to, latest, history_db, pretty, + db, + from, + to, + latest, + history_db, + pretty, } => { let db_url = require_db_url(db.as_deref())?; let ctx = DryRun::connect(db_url).await?; let store = open_history_store(history_db.as_deref())?; + let key = resolve_read_key(Some(db_url), profile).await?; let from_snapshot = if let Some(hash) = &from { - store.load_snapshot(hash)? - .ok_or_else(|| anyhow::anyhow!("snapshot with hash '{hash}' not found"))? + store.get(&key, SnapshotRef::Hash(hash.clone())).await? } else if *latest { - store.latest_snapshot(db_url)? - .ok_or_else(|| anyhow::anyhow!("no saved snapshots found for this database"))? + store.get(&key, SnapshotRef::Latest).await? } else { anyhow::bail!("specify --from or --latest"); }; let to_snapshot = if let Some(hash) = &to { - store.load_snapshot(hash)? - .ok_or_else(|| anyhow::anyhow!("snapshot with hash '{hash}' not found"))? + store.get(&key, SnapshotRef::Hash(hash.clone())).await? } else { ctx.introspect_schema().await? 
}; @@ -508,6 +614,33 @@ async fn cmd_snapshot(action: &SnapshotAction) -> anyhow::Result<()> { println!("{json}"); Ok(()) } + SnapshotAction::Export { out, history_db } => { + let store = open_history_store(history_db.as_deref())?; + let out_root = out.clone().unwrap_or_else(|| { + dry_run_core::history::default_data_dir() + .map(|d| d.join("snapshots")) + .unwrap_or_else(|_| PathBuf::from(".dryrun/snapshots")) + }); + + let keys = store.list_keys()?; + let mut written = 0usize; + for key in &keys { + let summaries = store.list(key, TimeRange::default()).await?; + for s in &summaries { + let snap = store + .get(key, SnapshotRef::Hash(s.content_hash.clone())) + .await?; + write_snapshot_export(&out_root, key, &snap)?; + written += 1; + } + } + println!( + "Exported {written} snapshot(s) from {} stream(s) to {}", + keys.len(), + out_root.display(), + ); + Ok(()) + } } } @@ -517,17 +650,17 @@ fn cmd_profile(cli: &Cli, action: &ProfileAction) -> anyhow::Result<()> { let config = ProjectConfig::load(config_path)?; (config_path.clone(), config) } else { - ProjectConfig::discover(&cwd) - .ok_or_else(|| anyhow::anyhow!("no dryrun.toml found"))? + ProjectConfig::discover(&cwd).ok_or_else(|| anyhow::anyhow!("no dryrun.toml found"))? }; match action { ProfileAction::List => { println!("Config: {}", config_path.display()); if let Some(default) = &config.default - && let Some(profile) = &default.profile { - println!("Default profile: {profile}"); - } + && let Some(profile) = &default.profile + { + println!("Default profile: {profile}"); + } println!(); if config.profiles.is_empty() { @@ -546,7 +679,9 @@ fn cmd_profile(cli: &Cli, action: &ProfileAction) -> anyhow::Result<()> { } } ProfileAction::Show { name } => { - let profile = config.profiles.get(name) + let profile = config + .profiles + .get(name) .ok_or_else(|| anyhow::anyhow!("profile '{name}' not found"))?; println!("Profile: {name}"); if let Some(url) = &profile.db_url { @@ -560,7 +695,11 @@ fn cmd_profile(cli: &Cli, action: &ProfileAction) -> anyhow::Result<()> { Ok(()) } -async fn cmd_import(file: &std::path::Path, stats_files: &[PathBuf]) -> anyhow::Result<()> { +async fn cmd_import( + cli: &Cli, + file: &std::path::Path, + stats_files: &[PathBuf], +) -> anyhow::Result<()> { let json = std::fs::read_to_string(file)?; let mut snapshot: dry_run_core::SchemaSnapshot = serde_json::from_str(&json) .map_err(|e| anyhow::anyhow!("invalid schema JSON in '{}': {e}", file.display()))?; @@ -568,11 +707,10 @@ async fn cmd_import(file: &std::path::Path, stats_files: &[PathBuf]) -> anyhow:: if !stats_files.is_empty() { for stats_path in stats_files { let stats_json = std::fs::read_to_string(stats_path)?; - let node_stats: dry_run_core::NodeStats = serde_json::from_str(&stats_json) - .map_err(|e| anyhow::anyhow!( - "invalid stats JSON in '{}': {e}", - stats_path.display() - ))?; + let node_stats: dry_run_core::NodeStats = + serde_json::from_str(&stats_json).map_err(|e| { + anyhow::anyhow!("invalid stats JSON in '{}': {e}", stats_path.display()) + })?; eprintln!( " merging stats from '{}' ({} tables, {} indexes)", node_stats.source, @@ -586,7 +724,15 @@ async fn cmd_import(file: &std::path::Path, stats_files: &[PathBuf]) -> anyhow:: let data_dir = dry_run_core::history::default_data_dir()?; std::fs::create_dir_all(&data_dir)?; - let out_path = data_dir.join("schema.json"); + // route to the resolved profile's schema_file when one is configured; + // fall back to .dryrun/schema.json + let out_path = active_resolved_profile(cli, None, None) + .ok() + 
.and_then(|r| r.schema_file) + .unwrap_or_else(|| data_dir.join("schema.json")); + if let Some(parent) = out_path.parent() { + std::fs::create_dir_all(parent)?; + } let out_json = serde_json::to_string_pretty(&snapshot)?; std::fs::write(&out_path, &out_json)?; @@ -603,21 +749,28 @@ async fn cmd_import(file: &std::path::Path, stats_files: &[PathBuf]) -> anyhow:: Ok(()) } -async fn cmd_stats(action: &StatsAction) -> anyhow::Result<()> { +async fn cmd_stats(cli: &Cli, action: &StatsAction) -> anyhow::Result<()> { match action { - StatsAction::Apply { db, schema_file, node } => { - let db_url = require_db_url(db.as_deref())?; - - let snapshot = resolve_schema(schema_file.as_deref(), None, None)?; + StatsAction::Apply { + db, + schema_file, + node, + } => { + let resolved = active_resolved_profile(cli, db.as_deref(), schema_file.as_deref())?; + let db_url = resolved + .db_url + .as_deref() + .ok_or_else(|| anyhow::anyhow!("--db or a profile with db_url is required"))?; + + let snapshot = match resolved.schema_file.as_deref() { + Some(path) => load_schema_file(path)?, + None => resolve_schema(schema_file.as_deref(), None, None)?, + }; let ctx = DryRun::connect(db_url).await?; - let result = dry_run_core::schema::apply_stats( - ctx.pool(), - &snapshot, - node.as_deref(), - ) - .await?; + let result = + dry_run_core::schema::apply_stats(ctx.pool(), &snapshot, node.as_deref()).await?; // pg_regresql warning if !result.regresql_loaded { @@ -651,13 +804,22 @@ async fn cmd_stats(action: &StatsAction) -> anyhow::Result<()> { } async fn cmd_drift( + cli: &Cli, db: Option<&str>, against: Option<&std::path::Path>, pretty: bool, json: bool, ) -> anyhow::Result<()> { - let db_url = require_db_url(db)?; - let prod_snapshot = resolve_schema(against, None, None)?; + let resolved = active_resolved_profile(cli, db, against)?; + let db_url = resolved + .db_url + .as_deref() + .ok_or_else(|| anyhow::anyhow!("--db or a profile with db_url is required"))?; + + let prod_snapshot = match resolved.schema_file.as_deref() { + Some(path) => load_schema_file(path)?, + None => resolve_schema(against, None, None)?, + }; let ctx = DryRun::connect(db_url).await?; let local_snapshot = ctx.introspect_schema().await?; @@ -681,10 +843,13 @@ async fn cmd_drift( dry_run_core::diff::DriftDirection::Behind => "BEHIND", dry_run_core::diff::DriftDirection::Diverged => "DIVERGED", }; - let location = entry.change.schema.as_deref().map_or( - entry.change.name.clone(), - |s| format!("{s}.{}", entry.change.name), - ); + let location = entry + .change + .schema + .as_deref() + .map_or(entry.change.name.clone(), |s| { + format!("{s}.{}", entry.change.name) + }); println!("[{arrow:>8}] {}: {location}", entry.change.object_type); for detail in &entry.change.details { println!(" {detail}"); @@ -709,6 +874,18 @@ fn require_db_url(db: Option<&str>) -> anyhow::Result<&str> { db.ok_or_else(|| anyhow::anyhow!("--db or DATABASE_URL is required")) } +fn active_resolved_profile( + cli: &Cli, + cli_db: Option<&str>, + cli_schema: Option<&std::path::Path>, +) -> anyhow::Result { + let cwd = std::env::current_dir().unwrap_or_default(); + let config = ProjectConfig::discover(&cwd) + .map(|(_, c)| Ok(c)) + .unwrap_or_else(|| ProjectConfig::parse(""))?; + Ok(config.resolve_profile(cli_db, cli_schema, cli.profile.as_deref(), &cwd)?) 
+}
+
 fn load_project_config(cli: &Cli, cwd: &std::path::Path) -> Option<ProjectConfig> {
     if let Some(config_path) = &cli.config {
         ProjectConfig::load(config_path).ok()
@@ -734,9 +911,10 @@ fn schema_candidate_paths(
 
     if let Some(config) = project_config
         && let Ok(resolved) = config.resolve_profile(None, None, profile, &cwd)
-        && let Some(sf) = resolved.schema_file {
-        candidates.push(sf);
-    }
+        && let Some(sf) = resolved.schema_file
+    {
+        candidates.push(sf);
+    }
 
     if let Ok(data_dir) = dry_run_core::history::default_data_dir() {
         candidates.push(data_dir.join("schema.json"));
@@ -753,7 +931,9 @@ fn resolve_schema_path(
     schema_candidate_paths(schema_file, project_config, profile)
         .into_iter()
         .find(|p| p.exists())
-        .ok_or_else(|| anyhow::anyhow!("no schema found — run dump-schema first or pass --schema-file"))
+        .ok_or_else(|| {
+            anyhow::anyhow!("no schema found — run dump-schema first or pass --schema-file")
+        })
 }
 
 fn resolve_schema(
@@ -779,6 +959,68 @@ fn open_history_store(path: Option<&std::path::Path>) -> anyhow::Result<HistoryStore> {
 
+/// Write one snapshot to <out_root>/<project_id>/<database_id>/<timestamp>-<hash>.json.zst.
+fn write_snapshot_export(
+    out_root: &std::path::Path,
+    key: &SnapshotKey,
+    snap: &dry_run_core::SchemaSnapshot,
+) -> anyhow::Result<PathBuf> {
+    let path = out_root
+        .join(&key.project_id.0)
+        .join(&key.database_id.0)
+        .join(format!(
+            "{}-{}.json.zst",
+            snap.timestamp.format("%Y%m%dT%H%M%SZ"),
+            snap.content_hash,
+        ));
+    if let Some(parent) = path.parent() {
+        std::fs::create_dir_all(parent)?;
+    }
+    let json = serde_json::to_vec(snap)?;
+    let compressed = zstd::encode_all(json.as_slice(), 3)?;
+    std::fs::write(&path, compressed)?;
+    Ok(path)
+}
+
+fn complete_key(resolved: &dry_run_core::ResolvedProfile, snapshot_database: &str) -> SnapshotKey {
+    SnapshotKey {
+        project_id: resolved.project_id.clone(),
+        database_id: resolved
+            .database_id
+            .clone()
+            .unwrap_or_else(|| DatabaseId(snapshot_database.to_string())),
+    }
+}
+
+async fn resolve_read_key(
+    db_url: Option<&str>,
+    profile: Option<&str>,
+) -> anyhow::Result<SnapshotKey> {
+    let cwd = std::env::current_dir().unwrap_or_default();
+    let config = ProjectConfig::discover(&cwd)
+        .map(|(_, c)| Ok(c))
+        .unwrap_or_else(|| ProjectConfig::parse(""))?;
+    let resolved = config.resolve_profile(db_url, None, profile, &cwd)?;
+
+    if let Some(database_id) = resolved.database_id {
+        return Ok(SnapshotKey {
+            project_id: resolved.project_id,
+            database_id,
+        });
+    }
+
+    let url = resolved
+        .db_url
+        .ok_or_else(|| anyhow::anyhow!("no profile and no --db; cannot determine snapshot key"))?;
+    let ctx = DryRun::connect(&url).await?;
+    let dbname = ctx.current_database().await?;
+    Ok(SnapshotKey {
+        project_id: resolved.project_id,
+        database_id: DatabaseId(dbname),
+    })
+}
+
 async fn cmd_mcp_serve(
     cli: &Cli,
     db: Option<&str>,
@@ -794,19 +1036,15 @@ async fn cmd_mcp_serve(
         .map(|c| c.lint_config())
         .unwrap_or_default();
 
-    let pgmustard_api_key = project_config
-        .as_ref()
-        .and_then(|c| c.pgmustard_api_key());
+    let pgmustard_api_key = project_config.as_ref().and_then(|c| c.pgmustard_api_key());
 
-    let candidates = schema_candidate_paths(
-        schema_path, project_config.as_ref(), cli.profile.as_deref(),
-    );
+    let candidates =
+        schema_candidate_paths(schema_path, project_config.as_ref(), cli.profile.as_deref());
 
     // try to load schema — if missing, start in uninitialized mode;
     // if file exists but is broken, propagate the error
-    let schema_path_result = resolve_schema_path(
-        schema_path, project_config.as_ref(), cli.profile.as_deref(),
-    );
+    let schema_path_result =
+        resolve_schema_path(schema_path, project_config.as_ref(),
cli.profile.as_deref()); let server = match schema_path_result { Ok(schema_file) => { @@ -814,15 +1052,18 @@ async fn cmd_mcp_serve( let snapshot: dry_run_core::SchemaSnapshot = serde_json::from_str(&json)?; eprintln!( "dryrun: loaded schema from {} ({} tables)", - schema_file.display(), snapshot.tables.len() + schema_file.display(), + snapshot.tables.len() ); // optional --db enables live tools (explain_query, refresh_schema) let effective_db = db.map(|s| s.to_string()).or_else(|| { if let Some(ref config) = project_config - && let Ok(resolved) = config.resolve_profile(None, None, cli.profile.as_deref(), &cwd) { - return resolved.db_url; - } + && let Ok(resolved) = + config.resolve_profile(None, None, cli.profile.as_deref(), &cwd) + { + return resolved.db_url; + } None }); @@ -836,7 +1077,11 @@ async fn cmd_mcp_serve( }; mcp::DryRunServer::from_snapshot_with_db( - snapshot, db_connection, lint_config, pgmustard_api_key, get_version(), + snapshot, + db_connection, + lint_config, + pgmustard_api_key, + get_version(), candidates, ) } @@ -870,3 +1115,179 @@ async fn cmd_mcp_serve( Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use chrono::{TimeZone, Utc}; + use dry_run_core::history::{DatabaseId, ProjectId}; + use dry_run_core::{ResolvedProfile, SchemaSnapshot}; + use tempfile::TempDir; + + fn make_snap(hash: &str, database: &str) -> SchemaSnapshot { + SchemaSnapshot { + pg_version: "PostgreSQL 17.0".into(), + database: database.into(), + timestamp: Utc.with_ymd_and_hms(2026, 4, 30, 14, 22, 11).unwrap(), + content_hash: hash.into(), + source: None, + tables: vec![], + enums: vec![], + domains: vec![], + composites: vec![], + views: vec![], + functions: vec![], + extensions: vec![], + gucs: vec![], + node_stats: vec![], + } + } + + fn key(proj: &str, db: &str) -> SnapshotKey { + SnapshotKey { + project_id: ProjectId(proj.into()), + database_id: DatabaseId(db.into()), + } + } + + #[test] + fn complete_key_uses_resolved_database_id_when_set() { + let resolved = ResolvedProfile { + name: "prod".into(), + db_url: None, + schema_file: None, + project_id: ProjectId("clusterity".into()), + database_id: Some(DatabaseId("auth".into())), + }; + let key = complete_key(&resolved, "fallback_db"); + assert_eq!(key.project_id.0, "clusterity"); + assert_eq!(key.database_id.0, "auth"); + } + + #[test] + fn complete_key_falls_back_to_snapshot_database() { + let resolved = ResolvedProfile { + name: "".into(), + db_url: None, + schema_file: None, + project_id: ProjectId("myproj".into()), + database_id: None, + }; + let key = complete_key(&resolved, "actual_db"); + assert_eq!(key.project_id.0, "myproj"); + assert_eq!(key.database_id.0, "actual_db"); + } + + #[test] + fn write_snapshot_export_roundtrips() { + let dir = TempDir::new().unwrap(); + let k = key("myproj", "auth"); + let snap = make_snap("abc123def456", "auth"); + + let path = write_snapshot_export(dir.path(), &k, &snap).unwrap(); + + // path layout + let expected = dir + .path() + .join("myproj") + .join("auth") + .join("20260430T142211Z-abc123def456.json.zst"); + assert_eq!(path, expected); + assert!(path.exists()); + + // round-trip: decompress and parse + let bytes = std::fs::read(&path).unwrap(); + let json = zstd::decode_all(bytes.as_slice()).unwrap(); + let restored: SchemaSnapshot = serde_json::from_slice(&json).unwrap(); + assert_eq!(restored.content_hash, "abc123def456"); + assert_eq!(restored.database, "auth"); + } + + #[test] + fn schema_candidate_paths_explicit_first_then_profile_then_default() { + // explicit --schema-file path 
goes first; then resolved profile's path; + // the default-data-dir fallback is appended last + let toml = r#" +[profiles.dev] +schema_file = "from-profile.json" +"#; + let config = ProjectConfig::parse(toml).unwrap(); + let explicit = PathBuf::from("/tmp/explicit.json"); + let candidates = schema_candidate_paths(Some(&explicit), Some(&config), Some("dev")); + assert!(candidates.len() >= 2); + assert_eq!(candidates[0], explicit); + // second candidate is the resolved profile path (relative to cwd) + let cwd = std::env::current_dir().unwrap_or_default(); + assert_eq!(candidates[1], cwd.join("from-profile.json")); + } + + #[test] + fn schema_candidate_paths_no_inputs_still_includes_default_dir() { + let candidates = schema_candidate_paths(None, None, None); + // expect at least the default data-dir fallback + assert!(!candidates.is_empty()); + assert!(candidates.last().unwrap().ends_with(".dryrun/schema.json")); + } + + #[test] + fn resolve_schema_path_picks_first_existing() { + let dir = TempDir::new().unwrap(); + let missing = dir.path().join("missing.json"); + let present = dir.path().join("present.json"); + std::fs::write(&present, "{}").unwrap(); + + // explicit path that doesn't exist; profile-resolved path that does + let toml = format!("[profiles.dev]\nschema_file = \"{}\"\n", present.display()); + let config = ProjectConfig::parse(&toml).unwrap(); + let resolved = resolve_schema_path(Some(&missing), Some(&config), Some("dev")).unwrap(); + assert_eq!(resolved, present); + } + + #[test] + fn resolve_schema_path_errors_when_nothing_exists() { + let dir = TempDir::new().unwrap(); + let missing = dir.path().join("nope.json"); + let result = resolve_schema_path(Some(&missing), None, None); + assert!(result.is_err()); + } + + #[test] + fn load_schema_file_round_trips() { + let dir = TempDir::new().unwrap(); + let snap = make_snap("h1", "auth"); + let path = dir.path().join("schema.json"); + std::fs::write(&path, serde_json::to_string(&snap).unwrap()).unwrap(); + let restored = load_schema_file(&path).unwrap(); + assert_eq!(restored.content_hash, "h1"); + assert_eq!(restored.database, "auth"); + } + + #[test] + fn load_schema_file_errors_on_invalid_json() { + let dir = TempDir::new().unwrap(); + let path = dir.path().join("broken.json"); + std::fs::write(&path, "{not json").unwrap(); + assert!(load_schema_file(&path).is_err()); + } + + #[test] + fn write_snapshot_export_isolates_streams() { + let dir = TempDir::new().unwrap(); + let auth = key("p", "auth"); + let billing = key("p", "billing"); + + write_snapshot_export(dir.path(), &auth, &make_snap("h1", "auth")).unwrap(); + write_snapshot_export(dir.path(), &billing, &make_snap("h2", "billing")).unwrap(); + + assert!(dir.path().join("p/auth").is_dir()); + assert!(dir.path().join("p/billing").is_dir()); + let auth_files: Vec<_> = std::fs::read_dir(dir.path().join("p/auth")) + .unwrap() + .collect(); + let billing_files: Vec<_> = std::fs::read_dir(dir.path().join("p/billing")) + .unwrap() + .collect(); + assert_eq!(auth_files.len(), 1); + assert_eq!(billing_files.len(), 1); + } +} diff --git a/crates/dry_run_cli/src/mcp/helpers.rs b/crates/dry_run_cli/src/mcp/helpers.rs index a13c3ad..707227a 100644 --- a/crates/dry_run_cli/src/mcp/helpers.rs +++ b/crates/dry_run_cli/src/mcp/helpers.rs @@ -23,7 +23,11 @@ pub fn format_number(n: i64) -> String { result.chars().rev().collect() } -pub fn format_node_table_breakdown(node_stats: &[NodeStats], schema: &str, table: &str) -> Option { +pub fn format_node_table_breakdown( + node_stats: &[NodeStats], + 
schema: &str,
+    table: &str,
+) -> Option<String> {
     if node_stats.is_empty() {
         return None;
     }
@@ -50,8 +54,7 @@ pub fn format_node_table_breakdown(node_stats: &[NodeStats], schema: &str, table
     if let Some(ts) = ts {
         let size_mb = ts.stats.table_size / (1024 * 1024);
         let collected = ns.timestamp.format("%Y-%m-%d %H:%M");
-        let stale = stale_threshold
-            .is_some_and(|threshold| ns.timestamp < threshold);
+        let stale = stale_threshold.is_some_and(|threshold| ns.timestamp < threshold);
         lines.push(format!(
             "{:<16} {:>12} {:>10} {:>10} {:>10} {:>9} MB {}{}",
             ns.source,
diff --git a/crates/dry_run_cli/src/mcp/params.rs b/crates/dry_run_cli/src/mcp/params.rs
index c926d55..3565a01 100644
--- a/crates/dry_run_cli/src/mcp/params.rs
+++ b/crates/dry_run_cli/src/mcp/params.rs
@@ -23,7 +23,9 @@ pub struct DescribeTableParams {
     #[schemars(description = "Schema filter (default: all schemas).")]
     pub schema: Option<String>,
     #[serde(default)]
-    #[schemars(description = "Detail level: 'summary' (default, compact with profiles), 'full' (all raw stats), 'stats' (only profiles and stats).")]
+    #[schemars(
+        description = "Detail level: 'summary' (default, compact with profiles), 'full' (all raw stats), 'stats' (only profiles and stats)."
+    )]
     pub detail: Option<String>,
 }
 
@@ -50,10 +52,14 @@ pub struct FindRelatedParams {
 #[derive(Debug, Deserialize, schemars::JsonSchema)]
 pub struct SchemaDiffParams {
     #[serde(default)]
-    #[schemars(description = "Content hash of the base snapshot. Omit to use the latest saved snapshot.")]
+    #[schemars(
+        description = "Content hash of the base snapshot. Omit to use the latest saved snapshot."
+    )]
     pub from: Option<String>,
     #[serde(default)]
-    #[schemars(description = "Content hash of the target snapshot. Omit to compare against current live schema.")]
+    #[schemars(
+        description = "Content hash of the target snapshot. Omit to compare against current live schema."
+    )]
     pub to: Option<String>,
 }
 
@@ -87,7 +93,9 @@ fn default_true() -> Option<bool> {
 
 #[derive(Debug, Deserialize, schemars::JsonSchema)]
 pub struct CheckMigrationParams {
-    #[schemars(description = "DDL statement(s) to check for migration safety (e.g. ALTER TABLE, CREATE INDEX).")]
+    #[schemars(
+        description = "DDL statement(s) to check for migration safety (e.g. ALTER TABLE, CREATE INDEX)."
+    )]
     pub ddl: String,
 }
 
@@ -100,14 +108,18 @@ pub struct LintSchemaParams {
     #[schemars(description = "Table filter (default: all tables).")]
     pub table: Option<String>,
     #[serde(default)]
-    #[schemars(description = "Scope: 'conventions' (lint only), 'audit' (audit only), or 'all' (default, both).")]
+    #[schemars(
+        description = "Scope: 'conventions' (lint only), 'audit' (audit only), or 'all' (default, both)."
+    )]
     pub scope: Option<String>,
 }
 
 #[derive(Debug, Deserialize, schemars::JsonSchema)]
 pub struct DetectParams {
     #[serde(default)]
-    #[schemars(description = "Detection kind: stale_stats, unused_indexes, bloated_indexes, or all (default).")]
+    #[schemars(
+        description = "Detection kind: stale_stats, unused_indexes, bloated_indexes, or all (default)."
+    )]
     pub kind: Option<String>,
     #[serde(default)]
     #[schemars(description = "Bloat ratio threshold (default 1.5).")]
diff --git a/crates/dry_run_cli/src/mcp/server.rs b/crates/dry_run_cli/src/mcp/server.rs
index 66f8d14..bcd6221 100644
--- a/crates/dry_run_cli/src/mcp/server.rs
+++ b/crates/dry_run_cli/src/mcp/server.rs
@@ -2,18 +2,20 @@ use std::path::PathBuf;
 use std::sync::Arc;
 
 use rmcp::{
+    ErrorData as McpError, ServerHandler,
     handler::server::{router::tool::ToolRouter, wrapper::Parameters},
     model::*,
-    tool, tool_handler, tool_router, ErrorData as McpError, ServerHandler,
+    tool, tool_handler, tool_router,
 };
 use tokio::sync::RwLock;
 use tracing::info;
 
 use dry_run_core::audit::AuditConfig;
+use dry_run_core::history::{SnapshotKey, SnapshotRef, SnapshotStore};
 use dry_run_core::lint::LintConfig;
 use dry_run_core::schema::{
-    ConstraintKind, detect_seq_scan_imbalance, detect_stale_stats,
-    detect_unused_indexes, effective_table_stats,
+    ConstraintKind, detect_seq_scan_imbalance, detect_stale_stats, detect_unused_indexes,
+    effective_table_stats,
 };
 use dry_run_core::{DryRun, HistoryStore, SchemaSnapshot};
 
@@ -25,12 +27,12 @@ use super::params::*;
 #[derive(Clone)]
 pub struct DryRunServer {
     ctx: Option<Arc<DryRun>>,
-    db_url: String,
     app_version: String,
     pg_version_display: String,
     database_name: String,
     schema: Arc<RwLock<Option<SchemaSnapshot>>>,
-    history: Option<Arc<Mutex<HistoryStore>>>,
+    history: Option<Arc<HistoryStore>>,
+    snapshot_key: Option<SnapshotKey>,
     lint_config: LintConfig,
     audit_config: AuditConfig,
     pgmustard: Option,
@@ -47,14 +49,12 @@ impl DryRunServer {
         app_version: &str,
         schema_candidates: Vec<PathBuf>,
     ) -> Self {
-        let (ctx, db_url) = match db {
-            Some((url, ctx)) => (Some(Arc::new(ctx)), url.to_string()),
-            None => (None, String::new()),
-        };
+        let ctx = db.map(|(_url, ctx)| Arc::new(ctx));
 
-        let pg_version_display = dry_run_core::PgVersion::parse_from_version_string(&snapshot.pg_version)
-            .map(|v| format!("{}.{}.{}", v.major, v.minor, v.patch))
-            .unwrap_or_default();
+        let pg_version_display =
+            dry_run_core::PgVersion::parse_from_version_string(&snapshot.pg_version)
+                .map(|v| format!("{}.{}.{}", v.major, v.minor, v.patch))
+                .unwrap_or_default();
         let database_name = snapshot.database.clone();
 
         info!(
@@ -66,12 +66,12 @@
 
         Self {
             ctx,
-            db_url,
            app_version: app_version.to_string(),
             pg_version_display,
             database_name,
             schema: Arc::new(RwLock::new(Some(snapshot))),
             history: None,
+            snapshot_key: None,
             lint_config,
             audit_config: AuditConfig::default(),
             pgmustard: Self::resolve_pgmustard(pgmustard_api_key),
@@ -97,12 +97,12 @@
     ) -> Self {
         Self {
             ctx: None,
-            db_url: String::new(),
             app_version: app_version.to_string(),
             pg_version_display: String::new(),
             database_name: String::new(),
             schema: Arc::new(RwLock::new(None)),
             history: None,
+            snapshot_key: None,
             lint_config,
             audit_config: AuditConfig::default(),
             pgmustard: None,
@@ -111,6 +111,13 @@
         }
     }
 
+    #[allow(dead_code)]
+    pub fn with_history(mut self, store: HistoryStore, key: Option<SnapshotKey>) -> Self {
+        self.history = Some(Arc::new(store));
+        self.snapshot_key = key;
+        self
+    }
+
     async fn get_schema(&self) -> Result<SchemaSnapshot, McpError> {
         let guard = self.schema.read().await;
         guard.clone().ok_or_else(|| {
@@ -137,11 +144,20 @@
     }
 
     fn mode_str(&self) -> &'static str {
-        if self.ctx.is_some() { "live" } else { "offline" }
+        if self.ctx.is_some() {
+            "live"
+        } else {
+            "offline"
+        }
     }
 
     fn wrap_text(&self, body: &str, hint: Option<&str>) -> String {
-        let header = format!("PostgreSQL {} | {} | {}\n", self.pg_version_display, self.database_name, self.mode_str());
+        let header = format!(
"PostgreSQL {} | {} | {}\n", + self.pg_version_display, + self.database_name, + self.mode_str() + ); if let Some(h) = hint { format!("{header}{body}\n\n> {h}") } else { @@ -150,7 +166,9 @@ impl DryRunServer { } fn inject_meta(&self, val: &mut serde_json::Value, hint: Option<&str>) { - let obj = val.as_object_mut().expect("inject_meta expects a JSON object"); + let obj = val + .as_object_mut() + .expect("inject_meta expects a JSON object"); let mut meta = serde_json::json!({ "pg_version": self.pg_version_display, "database": self.database_name, @@ -189,7 +207,11 @@ impl DryRunServer { .iter() .filter(|t| params.schema.as_ref().is_none_or(|s| &t.schema == s)) .map(|t| { - let node_count = if snapshot.node_stats.is_empty() { 0 } else { snapshot.node_stats.len() }; + let node_count = if snapshot.node_stats.is_empty() { + 0 + } else { + snapshot.node_stats.len() + }; let stats = effective_table_stats(t, &snapshot); let rows = stats.as_ref().map(|s| s.reltuples).unwrap_or(0.0); let size = stats.as_ref().map(|s| s.table_size).unwrap_or(0); @@ -202,24 +224,47 @@ impl DryRunServer { } else { String::new() }; - let partition = t.partition_info.as_ref() - .map(|pi| format!(" [partitioned: {} on '{}', {} children]", pi.strategy, pi.key, pi.children.len())) + let partition = t + .partition_info + .as_ref() + .map(|pi| { + format!( + " [partitioned: {} on '{}', {} children]", + pi.strategy, + pi.key, + pi.children.len() + ) + }) + .unwrap_or_default(); + let comment = t + .comment + .as_ref() + .map(|c| format!(" — {c}")) .unwrap_or_default(); - let comment = t.comment.as_ref().map(|c| format!(" — {c}")).unwrap_or_default(); let name = format!("{}.{}", t.schema, t.name); let line = format!("{name}{row_est}{partition}{comment}"); - TableEntry { line, name, rows, size } + TableEntry { + line, + name, + rows, + size, + } }) .collect(); match sort_by { - "rows" => entries.sort_by(|a, b| b.rows.partial_cmp(&a.rows).unwrap_or(std::cmp::Ordering::Equal)), + "rows" => entries.sort_by(|a, b| { + b.rows + .partial_cmp(&a.rows) + .unwrap_or(std::cmp::Ordering::Equal) + }), "size" => entries.sort_by_key(|b| std::cmp::Reverse(b.size)), _ => entries.sort_by(|a, b| a.name.cmp(&b.name)), } let total = entries.len(); - let paginated: Vec<&str> = entries.iter() + let paginated: Vec<&str> = entries + .iter() .skip(offset) .take(limit) .map(|e| e.line.as_str()) @@ -230,7 +275,10 @@ impl DryRunServer { } else if offset > 0 || paginated.len() < total { format!( "Showing {}-{} of {} table(s):\n{}", - offset + 1, offset + paginated.len(), total, paginated.join("\n") + offset + 1, + offset + paginated.len(), + total, + paginated.join("\n") ) } else { format!("{} table(s):\n{}", total, paginated.join("\n")) @@ -240,7 +288,9 @@ impl DryRunServer { Ok(CallToolResult::success(vec![Content::text(text)])) } - #[tool(description = "Table columns, types, constraints, indexes and stats. Per-node stats when present.")] + #[tool( + description = "Table columns, types, constraints, indexes and stats. Per-node stats when present." 
+ )] async fn describe_table( &self, Parameters(params): Parameters, @@ -265,7 +315,9 @@ impl DryRunServer { .unwrap_or(0.0); // build column profiles - let profiles: Vec = table.columns.iter() + let profiles: Vec = table + .columns + .iter() .filter_map(|col| { dry_run_core::schema::profile_column(col, table_rows).map(|p| { serde_json::json!({ @@ -278,12 +330,14 @@ impl DryRunServer { let mut json_val = match detail { "full" => { - let mut v = serde_json::to_value(table) - .map_err(|e| McpError::internal_error(format!("serialization error: {e}"), None))?; + let mut v = serde_json::to_value(table).map_err(|e| { + McpError::internal_error(format!("serialization error: {e}"), None) + })?; if let Some(obj) = v.as_object_mut() - && !profiles.is_empty() { - obj.insert("column_profiles".into(), serde_json::Value::Array(profiles)); - } + && !profiles.is_empty() + { + obj.insert("column_profiles".into(), serde_json::Value::Array(profiles)); + } v } "stats" => { @@ -293,41 +347,50 @@ impl DryRunServer { "stats": table.stats, }); if let Some(obj) = result.as_object_mut() - && !profiles.is_empty() { - obj.insert("column_profiles".into(), serde_json::Value::Array(profiles)); - } + && !profiles.is_empty() + { + obj.insert("column_profiles".into(), serde_json::Value::Array(profiles)); + } result } _ => { // summary: compact columns without raw stats - let compact_cols: Vec = table.columns.iter().map(|c| { - let mut col = serde_json::json!({ - "name": c.name, - "ordinal": c.ordinal, - "type_name": c.type_name, - "nullable": c.nullable, - "default": c.default, - "identity": c.identity, - "generated": c.generated, - "comment": c.comment, - }); - if let Some(target) = c.statistics_target { - col["statistics_target"] = serde_json::json!(target); - } - col - }).collect(); - let compact_idxs: Vec = table.indexes.iter().map(|i| { - serde_json::json!({ - "name": i.name, - "columns": i.columns, - "index_type": i.index_type, - "is_unique": i.is_unique, - "is_primary": i.is_primary, - "predicate": i.predicate, - "definition": i.definition, - "is_valid": i.is_valid, + let compact_cols: Vec = table + .columns + .iter() + .map(|c| { + let mut col = serde_json::json!({ + "name": c.name, + "ordinal": c.ordinal, + "type_name": c.type_name, + "nullable": c.nullable, + "default": c.default, + "identity": c.identity, + "generated": c.generated, + "comment": c.comment, + }); + if let Some(target) = c.statistics_target { + col["statistics_target"] = serde_json::json!(target); + } + col }) - }).collect(); + .collect(); + let compact_idxs: Vec = table + .indexes + .iter() + .map(|i| { + serde_json::json!({ + "name": i.name, + "columns": i.columns, + "index_type": i.index_type, + "is_unique": i.is_unique, + "is_primary": i.is_primary, + "predicate": i.predicate, + "definition": i.definition, + "is_valid": i.is_valid, + }) + }) + .collect(); let mut result = serde_json::json!({ "schema": table.schema, "name": table.name, @@ -339,16 +402,22 @@ impl DryRunServer { "partition_info": table.partition_info, }); if let Some(obj) = result.as_object_mut() - && !profiles.is_empty() { - obj.insert("column_profiles".into(), serde_json::Value::Array(profiles)); - } + && !profiles.is_empty() + { + obj.insert("column_profiles".into(), serde_json::Value::Array(profiles)); + } result } }; - let has_fks = table.constraints.iter().any(|c| c.kind == ConstraintKind::ForeignKey); + let has_fks = table + .constraints + .iter() + .any(|c| c.kind == ConstraintKind::ForeignKey); let hint = if has_fks { - Some("This table has foreign keys — use 
find_related for JOIN patterns with related tables.") + Some( + "This table has foreign keys — use find_related for JOIN patterns with related tables.", + ) } else { None }; @@ -357,14 +426,18 @@ impl DryRunServer { let mut text = serde_json::to_string_pretty(&json_val) .map_err(|e| McpError::internal_error(format!("serialization error: {e}"), None))?; - if let Some(breakdown) = format_node_table_breakdown(&snapshot.node_stats, schema_name, ¶ms.table) { + if let Some(breakdown) = + format_node_table_breakdown(&snapshot.node_stats, schema_name, ¶ms.table) + { text.push_str(&breakdown); } Ok(CallToolResult::success(vec![Content::text(text)])) } - #[tool(description = "Substring search over tables, columns, views, functions, enums, indexes, comments.")] + #[tool( + description = "Substring search over tables, columns, views, functions, enums, indexes, comments." + )] async fn search_schema( &self, Parameters(params): Parameters, @@ -377,34 +450,53 @@ impl DryRunServer { let qualified = format!("{}.{}", table.schema, table.name); if table.name.to_lowercase().contains(&query) { - let comment = table.comment.as_ref().map(|c| format!(" — {c}")).unwrap_or_default(); + let comment = table + .comment + .as_ref() + .map(|c| format!(" — {c}")) + .unwrap_or_default(); results.push(format!("TABLE {qualified}{comment}")); } for col in &table.columns { if col.name.to_lowercase().contains(&query) { - results.push(format!("COLUMN {qualified}.{} ({})", col.name, col.type_name)); + results.push(format!( + "COLUMN {qualified}.{} ({})", + col.name, col.type_name + )); } if let Some(comment) = &col.comment - && comment.to_lowercase().contains(&query) { - results.push(format!("COLUMN COMMENT {qualified}.{}: {comment}", col.name)); - } + && comment.to_lowercase().contains(&query) + { + results.push(format!( + "COLUMN COMMENT {qualified}.{}: {comment}", + col.name + )); + } } if let Some(comment) = &table.comment - && comment.to_lowercase().contains(&query) && !table.name.to_lowercase().contains(&query) { - results.push(format!("TABLE COMMENT {qualified}: {comment}")); - } + && comment.to_lowercase().contains(&query) + && !table.name.to_lowercase().contains(&query) + { + results.push(format!("TABLE COMMENT {qualified}: {comment}")); + } for con in &table.constraints { if let Some(def) = &con.definition - && def.to_lowercase().contains(&query) { - results.push(format!("CONSTRAINT {qualified}.{} ({:?}): {def}", con.name, con.kind)); - } + && def.to_lowercase().contains(&query) + { + results.push(format!( + "CONSTRAINT {qualified}.{} ({:?}): {def}", + con.name, con.kind + )); + } } for idx in &table.indexes { - if idx.name.to_lowercase().contains(&query) || idx.definition.to_lowercase().contains(&query) { + if idx.name.to_lowercase().contains(&query) + || idx.definition.to_lowercase().contains(&query) + { results.push(format!("INDEX {qualified}: {}", idx.definition)); } } @@ -412,27 +504,42 @@ impl DryRunServer { for view in &snapshot.views { if view.name.to_lowercase().contains(&query) { - let kind = if view.is_materialized { "MATERIALIZED VIEW" } else { "VIEW" }; + let kind = if view.is_materialized { + "MATERIALIZED VIEW" + } else { + "VIEW" + }; results.push(format!("{kind} {}.{}", view.schema, view.name)); } } for func in &snapshot.functions { if func.name.to_lowercase().contains(&query) { - results.push(format!("FUNCTION {}.{}({})", func.schema, func.name, func.identity_args)); + results.push(format!( + "FUNCTION {}.{}({})", + func.schema, func.name, func.identity_args + )); } } for e in &snapshot.enums { - if 
e.name.to_lowercase().contains(&query) || e.labels.iter().any(|l| l.to_lowercase().contains(&query)) { - results.push(format!("ENUM {}.{}: [{}]", e.schema, e.name, e.labels.join(", "))); + if e.name.to_lowercase().contains(&query) + || e.labels.iter().any(|l| l.to_lowercase().contains(&query)) + { + results.push(format!( + "ENUM {}.{}: [{}]", + e.schema, + e.name, + e.labels.join(", ") + )); } } let limit = params.limit.unwrap_or(30); let offset = params.offset.unwrap_or(0); let total = results.len(); - let paginated: Vec<&str> = results.iter() + let paginated: Vec<&str> = results + .iter() .skip(offset) .take(limit) .map(|s| s.as_str()) @@ -443,10 +550,19 @@ impl DryRunServer { } else if offset > 0 || paginated.len() < total { format!( "Showing {}-{} of {} match(es) for '{}':\n{}", - offset + 1, offset + paginated.len(), total, params.query, paginated.join("\n") + offset + 1, + offset + paginated.len(), + total, + params.query, + paginated.join("\n") ) } else { - format!("{} match(es) for '{}':\n{}", total, params.query, paginated.join("\n")) + format!( + "{} match(es) for '{}':\n{}", + total, + params.query, + paginated.join("\n") + ) }; let text = self.wrap_text(&body, None); @@ -466,12 +582,18 @@ impl DryRunServer { .tables .iter() .find(|t| t.name == params.table && t.schema == schema_name) - .ok_or_else(|| McpError::invalid_params(format!("table '{qualified}' not found"), None))?; + .ok_or_else(|| { + McpError::invalid_params(format!("table '{qualified}' not found"), None) + })?; let mut lines: Vec = Vec::new(); lines.push(format!("Relationships for {qualified}:\n")); - let outgoing: Vec<_> = table.constraints.iter().filter(|c| c.kind == ConstraintKind::ForeignKey).collect(); + let outgoing: Vec<_> = table + .constraints + .iter() + .filter(|c| c.kind == ConstraintKind::ForeignKey) + .collect(); if outgoing.is_empty() { lines.push("Outgoing FKs: none".into()); @@ -481,7 +603,9 @@ impl DryRunServer { let ref_table = fk.fk_table.as_deref().unwrap_or("?"); let local_cols = fk.columns.join(", "); let ref_cols = fk.fk_columns.join(", "); - lines.push(format!(" {qualified}({local_cols}) -> {ref_table}({ref_cols})")); + lines.push(format!( + " {qualified}({local_cols}) -> {ref_table}({ref_cols})" + )); lines.push(format!(" JOIN: SELECT * FROM {qualified} JOIN {ref_table} ON {}.{local_cols} = {ref_table}.{ref_cols}", params.table)); } } @@ -489,15 +613,20 @@ impl DryRunServer { let mut incoming: Vec = Vec::new(); for other in &snapshot.tables { for fk in &other.constraints { - if fk.kind != ConstraintKind::ForeignKey { continue; } + if fk.kind != ConstraintKind::ForeignKey { + continue; + } if let Some(ref_table) = &fk.fk_table - && ref_table == &qualified { - let other_qualified = format!("{}.{}", other.schema, other.name); - let local_cols = fk.columns.join(", "); - let ref_cols = fk.fk_columns.join(", "); - incoming.push(format!(" {other_qualified}({local_cols}) -> {qualified}({ref_cols})")); - incoming.push(format!(" JOIN: SELECT * FROM {qualified} JOIN {other_qualified} ON {qualified}.{ref_cols} = {other_qualified}.{local_cols}")); - } + && ref_table == &qualified + { + let other_qualified = format!("{}.{}", other.schema, other.name); + let local_cols = fk.columns.join(", "); + let ref_cols = fk.fk_columns.join(", "); + incoming.push(format!( + " {other_qualified}({local_cols}) -> {qualified}({ref_cols})" + )); + incoming.push(format!(" JOIN: SELECT * FROM {qualified} JOIN {other_qualified} ON {qualified}.{ref_cols} = {other_qualified}.{local_cols}")); + } } } @@ -514,37 +643,40 @@ impl 
DryRunServer { Ok(CallToolResult::success(vec![Content::text(text)])) } - #[tool(description = "Diff two snapshots, or the latest snapshot against the live schema. Needs --history.")] + #[tool( + description = "Diff two snapshots, or the latest snapshot against the live schema. Needs --history." + )] async fn schema_diff( &self, Parameters(params): Parameters, ) -> Result<CallToolResult, McpError> { - let history_arc = self.history.as_ref() + let store = self + .history + .as_ref() .ok_or_else(|| McpError::internal_error("history store not configured", None))?; + let key = self.snapshot_key.as_ref().ok_or_else(|| { + McpError::internal_error( + "schema_diff needs a snapshot key — pass --db or set [default].profile", + None, + ) + })?; - let (from_snapshot, to_hash) = { - let history = history_arc.lock().map_err(|e| McpError::internal_error(format!("history lock poisoned: {e}"), None))?; - - let from = if let Some(hash) = &params.from { - history.load_snapshot(hash).map_err(to_mcp_err)? - .ok_or_else(|| McpError::invalid_params(format!("snapshot '{hash}' not found"), None))? - } else { - history.latest_snapshot(&self.db_url).map_err(to_mcp_err)? - .ok_or_else(|| McpError::invalid_params("no saved snapshots found — run snapshot first", None))? - }; - - let to = if let Some(hash) = &params.to { - Some(history.load_snapshot(hash).map_err(to_mcp_err)? - .ok_or_else(|| McpError::invalid_params(format!("snapshot '{hash}' not found"), None))?) - } else { - None - }; - - (from, to) + let from_snapshot = match &params.from { + Some(hash) => store + .get(key, SnapshotRef::Hash(hash.clone())) + .await + .map_err(to_mcp_err)?, + None => store + .get(key, SnapshotRef::Latest) + .await + .map_err(to_mcp_err)?, }; - let to_snapshot = match to_hash { - Some(s) => s, + let to_snapshot = match &params.to { + Some(hash) => store + .get(key, SnapshotRef::Hash(hash.clone())) + .await + .map_err(to_mcp_err)?, None => self.get_schema().await?, }; @@ -564,7 +696,9 @@ impl DryRunServer { Ok(CallToolResult::success(vec![Content::text(json)])) } - #[tool(description = "Parse SQL and check it against the schema. Flags missing tables or columns and common anti-patterns. Offline.")] + #[tool( + description = "Parse SQL and check it against the schema. Flags missing tables or columns and common anti-patterns. Offline." + )] async fn validate_query( &self, Parameters(params): Parameters, ) -> Result<CallToolResult, McpError> { @@ -574,7 +708,9 @@ impl DryRunServer { .map_err(|e| McpError::invalid_params(format!("SQL parse error: {e}"), None))?; let hint = if result.valid && !result.warnings.is_empty() { - Some("Query is valid but has warnings. Use advise for index suggestions and plan analysis.") + Some( + "Query is valid but has warnings. Use advise for index suggestions and plan analysis.", + ) } else if result.valid { Some("Query is valid. Use advise if you need optimization suggestions.") } else { @@ -591,7 +727,9 @@ impl DryRunServer { Ok(CallToolResult::success(vec![Content::text(json)])) } - #[tool(description = "Run EXPLAIN on a query. Pass analyze=true to run EXPLAIN ANALYZE. Needs live DB.")] + #[tool( + description = "Run EXPLAIN on a query. Pass analyze=true to run EXPLAIN ANALYZE. Needs live DB."
+ )] async fn explain_query( &self, Parameters(params): Parameters, ) -> Result<CallToolResult, McpError> { @@ -600,11 +738,18 @@ impl DryRunServer { let ctx = self.require_live_db()?; let result = dry_run_core::query::explain_query( - ctx.pool(), &params.sql, params.analyze.unwrap_or(false), schema.as_ref(), - ).await.map_err(|e| McpError::invalid_params(format!("EXPLAIN failed: {e}"), None))?; + ctx.pool(), + &params.sql, + params.analyze.unwrap_or(false), + schema.as_ref(), + ) + .await + .map_err(|e| McpError::invalid_params(format!("EXPLAIN failed: {e}"), None))?; let hint = if !result.warnings.is_empty() { - Some("Warnings detected. Use advise for index suggestions and actionable recommendations.") + Some( + "Warnings detected. Use advise for index suggestions and actionable recommendations.", + ) } else { None }; @@ -619,19 +764,27 @@ impl DryRunServer { Ok(CallToolResult::success(vec![Content::text(json)])) } - #[tool(description = "Plan analysis, anti-pattern checks and index suggestions for a query. Uses EXPLAIN when a live DB is available, static analysis otherwise.")] + #[tool( + description = "Plan analysis, anti-pattern checks and index suggestions for a query. Uses EXPLAIN when a live DB is available, static analysis otherwise." + )] async fn advise( &self, Parameters(params): Parameters, ) -> Result<CallToolResult, McpError> { let schema = self.get_schema().await?; - let pg_version = dry_run_core::PgVersion::parse_from_version_string(&schema.pg_version).ok(); + let pg_version = + dry_run_core::PgVersion::parse_from_version_string(&schema.pg_version).ok(); let include_idx = params.include_index_suggestions.unwrap_or(true); let explain_result = if let Some(ctx) = &self.ctx { dry_run_core::query::explain_query( - ctx.pool(), &params.sql, params.analyze.unwrap_or(false), Some(&schema), - ).await.ok() + ctx.pool(), + &params.sql, + params.analyze.unwrap_or(false), + Some(&schema), + ) + .await + .ok() } else { None }; @@ -642,11 +795,14 @@ impl DryRunServer { &schema, pg_version.as_ref(), include_idx, - ).map_err(|e| McpError::invalid_params(format!("analysis failed: {e}"), None))?; + ) + .map_err(|e| McpError::invalid_params(format!("analysis failed: {e}"), None))?; let has_ddl_suggestions = !advise_result.index_suggestions.is_empty(); let hint = if has_ddl_suggestions { - Some("Index suggestions contain DDL. Run each through check_migration before applying — it checks lock safety and duration.") + Some( + "Index suggestions contain DDL. Run each through check_migration before applying — it checks lock safety and duration.", + ) } else { None }; @@ -678,7 +834,9 @@ impl DryRunServer { Ok(CallToolResult::success(vec![Content::text(json)])) } - #[tool(description = "Analyze an existing EXPLAIN plan (JSON) against the schema. Returns warnings, index and safety hints. Offline.")] + #[tool( + description = "Analyze an existing EXPLAIN plan (JSON) against the schema. Returns warnings, index and safety hints. Offline." + )] async fn analyze_plan( &self, Parameters(params): Parameters, ) -> Result<CallToolResult, McpError> { @@ -689,11 +847,9 @@ impl DryRunServer { // Parse the plan JSON — supports both wrapped [{"Plan": ...}] and bare {"Plan": ...} let plan_value = if let Some(arr) = params.plan_json.as_array() { - arr.first() - .and_then(|obj| obj.get("Plan")) - .ok_or_else(|| { - McpError::invalid_params("plan_json must contain a Plan key", None) - })? + arr.first().and_then(|obj| obj.get("Plan")).ok_or_else(|| { + McpError::invalid_params("plan_json must contain a Plan key", None) + })?
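+ // Illustrative inputs (field names as emitted by PostgreSQL's EXPLAIN (FORMAT JSON)):
+ //   wrapped: [{"Plan": {"Node Type": "Seq Scan", "Relation Name": "orders"}}]
+ //   bare:    {"Plan": {"Node Type": "Seq Scan", "Relation Name": "orders"}}
+ // Both shapes resolve to the same inner Plan node here.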
} else { params.plan_json.get("Plan").ok_or_else(|| { McpError::invalid_params("plan_json must contain a Plan key", None) })? @@ -740,7 +896,9 @@ impl DryRunServer { let has_ddl_suggestions = !advise_result.index_suggestions.is_empty(); let hint = if has_ddl_suggestions { - Some("Index suggestions contain DDL. Run each through check_migration before applying — it checks lock safety and duration.") + Some( + "Index suggestions contain DDL. Run each through check_migration before applying — it checks lock safety and duration.", + ) } else { None }; @@ -777,16 +935,20 @@ impl DryRunServer { Ok(CallToolResult::success(vec![Content::text(json)])) } - #[tool(description = "Check a DDL statement for lock level, duration, table-size impact, and suggest safer alternatives.")] + #[tool( + description = "Check a DDL statement for lock level, duration, table-size impact, and suggest safer alternatives." + )] async fn check_migration( &self, Parameters(params): Parameters, ) -> Result<CallToolResult, McpError> { let schema = self.get_schema().await?; - let pg_version = dry_run_core::PgVersion::parse_from_version_string(&schema.pg_version).ok(); + let pg_version = + dry_run_core::PgVersion::parse_from_version_string(&schema.pg_version).ok(); - let checks = dry_run_core::query::check_migration(&params.ddl, &schema, pg_version.as_ref()) - .map_err(|e| McpError::invalid_params(format!("DDL parse error: {e}"), None))?; + let checks = + dry_run_core::query::check_migration(&params.ddl, &schema, pg_version.as_ref()) + .map_err(|e| McpError::invalid_params(format!("DDL parse error: {e}"), None))?; if checks.is_empty() { return Ok(CallToolResult::success(vec![Content::text( @@ -796,9 +958,13 @@ impl DryRunServer { )])); } - let has_dangerous = checks.iter().any(|c| c.safety == dry_run_core::query::SafetyRating::Dangerous); + let has_dangerous = checks + .iter() + .any(|c| c.safety == dry_run_core::query::SafetyRating::Dangerous); let hint = if has_dangerous { - Some("DANGEROUS operations detected. Check the recommendation and rollback_ddl fields for safe alternatives.") + Some( + "DANGEROUS operations detected. Check the recommendation and rollback_ddl fields for safe alternatives.", + ) } else { None }; @@ -812,7 +978,9 @@ impl DryRunServer { Ok(CallToolResult::success(vec![Content::text(json)])) } - #[tool(description = "Schema quality checks. scope=conventions, audit, or all (default). Offline.")] + #[tool( + description = "Schema quality checks. scope=conventions, audit, or all (default). Offline." + )] async fn lint_schema( &self, Parameters(params): Parameters, ) -> Result<CallToolResult, McpError> { @@ -855,7 +1023,9 @@ impl DryRunServer { }; let hint = if has_ddl_fixes { - Some("Some findings include ddl_fix fields. Run those through check_migration before applying to verify lock safety.") + Some( + "Some findings include ddl_fix fields. Run those through check_migration before applying to verify lock safety.", + ) } else { None }; @@ -869,7 +1039,9 @@ impl DryRunServer { Ok(CallToolResult::success(vec![Content::text(json)])) } - #[tool(description = "Autovacuum status with thresholds, dead tuples and tuning hints. Offline.")] + #[tool( + description = "Autovacuum status with thresholds, dead tuples and tuning hints. Offline." + )] async fn vacuum_health( &self, Parameters(params): Parameters, ) -> Result<CallToolResult, McpError> { @@ -899,7 +1071,9 @@ impl DryRunServer { Ok(CallToolResult::success(vec![Content::text(json)])) } - #[tool(description = "Health checks. kind=stale_stats, unused_indexes, anomalies, bloated_indexes, or all (default). Offline.")] + #[tool( + description = "Health checks.
kind=stale_stats, unused_indexes, anomalies, bloated_indexes, or all (default). Offline." + )] async fn detect( &self, Parameters(params): Parameters, ) -> Result<CallToolResult, McpError> { @@ -937,22 +1111,28 @@ impl DryRunServer { if run_stale { let stale = detect_stale_stats(&snapshot.node_stats, 7); found_stale = !stale.is_empty(); - result.insert("stale_stats".into(), serde_json::to_value(&stale) - .unwrap_or(serde_json::Value::Null)); + result.insert( + "stale_stats".into(), + serde_json::to_value(&stale).unwrap_or(serde_json::Value::Null), + ); } if run_unused { let unused = detect_unused_indexes(&snapshot.node_stats, &snapshot.tables); found_unused = !unused.is_empty(); - result.insert("unused_indexes".into(), serde_json::to_value(&unused) - .unwrap_or(serde_json::Value::Null)); + result.insert( + "unused_indexes".into(), + serde_json::to_value(&unused).unwrap_or(serde_json::Value::Null), + ); } if run_anomalies { let mut anomalies = Vec::new(); for table in &snapshot.tables { let schema_name = &table.schema; - if let Some(imb) = detect_seq_scan_imbalance(&snapshot.node_stats, schema_name, &table.name) { + if let Some(imb) = + detect_seq_scan_imbalance(&snapshot.node_stats, schema_name, &table.name) + { anomalies.push(serde_json::json!({ "table": format!("{}.{}", schema_name, table.name), "type": "seq_scan_imbalance", @@ -967,14 +1147,22 @@ impl DryRunServer { if run_bloated { let threshold = params.threshold.unwrap_or(1.5); let bloated = dry_run_core::schema::detect_bloated_indexes(&snapshot.tables, threshold); - result.insert("bloated_indexes".into(), serde_json::to_value(&bloated) - .unwrap_or(serde_json::Value::Null)); + result.insert( + "bloated_indexes".into(), + serde_json::to_value(&bloated).unwrap_or(serde_json::Value::Null), + ); } let hint = match (found_stale, found_unused) { - (true, true) => Some("Stale stats may cause bad plans — run ANALYZE. Unused indexes add write overhead — verify with compare_nodes before dropping."), - (true, false) => Some("Stale stats may cause bad query plans — consider running ANALYZE."), - (false, true) => Some("Unused indexes add write overhead. Use compare_nodes to verify across all replicas before dropping."), + (true, true) => Some( + "Stale stats may cause bad plans — run ANALYZE. Unused indexes add write overhead — verify with compare_nodes before dropping.", + ), + (true, false) => { + Some("Stale stats may cause bad query plans — consider running ANALYZE.") + } + (false, true) => Some( + "Unused indexes add write overhead. Use compare_nodes to verify across all replicas before dropping.", + ), (false, false) => None, }; @@ -986,7 +1174,9 @@ impl DryRunServer { Ok(CallToolResult::success(vec![Content::text(json)])) } - #[tool(description = "Per-node stats for a table. Shows reltuples, relpages, scans, size and per-index numbers. Offline.")] + #[tool( + description = "Per-node stats for a table. Shows reltuples, relpages, scans, size and per-index numbers. Offline."
+ )] async fn compare_nodes( &self, Parameters(params): Parameters, ) -> Result<CallToolResult, McpError> { @@ -1004,14 +1194,21 @@ impl DryRunServer { } let mut lines: Vec<String> = Vec::new(); - lines.push(format!("Stats for {qualified} across {} node(s):", snapshot.node_stats.len())); - - if let Some(breakdown) = format_node_table_breakdown(&snapshot.node_stats, schema_name, &params.table) { + lines.push(format!( + "Stats for {qualified} across {} node(s):", + snapshot.node_stats.len() + )); + + if let Some(breakdown) = + format_node_table_breakdown(&snapshot.node_stats, schema_name, &params.table) + { lines.push(breakdown); } // anomaly detection: seq_scan imbalance - if let Some(imb) = detect_seq_scan_imbalance(&snapshot.node_stats, schema_name, &params.table) { + if let Some(imb) = + detect_seq_scan_imbalance(&snapshot.node_stats, schema_name, &params.table) + { lines.push(String::new()); lines.push(format!( "⚠ {} has {}x more seq_scans than the lowest node — \ @@ -1063,11 +1260,15 @@ impl DryRunServer { Ok(CallToolResult::success(vec![Content::text(text)])) } - #[tool(description = "Compare the live local DB against the loaded production snapshot. Each diff is tagged ahead, behind or diverged. Needs live DB.")] + #[tool( + description = "Compare the live local DB against the loaded production snapshot. Each diff is tagged ahead, behind or diverged. Needs live DB." + )] async fn check_drift(&self) -> Result<CallToolResult, McpError> { let ctx = self.require_live_db()?; let prod_snapshot = self.get_schema().await?; - let local_snapshot = ctx.introspect_schema().await + let local_snapshot = ctx + .introspect_schema() + .await .map_err(|e| McpError::internal_error(format!("introspection failed: {e}"), None))?; let report = dry_run_core::diff::classify_drift(&prod_snapshot, &local_snapshot); @@ -1085,12 +1286,16 @@ impl DryRunServer { #[tool(description = "Force re-introspection of the database schema (requires live DB)")] async fn refresh_schema(&self) -> Result<CallToolResult, McpError> { let ctx = self.require_live_db()?; - let snapshot = ctx.introspect_schema().await + let snapshot = ctx + .introspect_schema() + .await .map_err(|e| McpError::internal_error(format!("introspection failed: {e}"), None))?; let body = format!( "Schema refreshed: {} tables, {} views, {} functions (hash: {})", - snapshot.tables.len(), snapshot.views.len(), snapshot.functions.len(), + snapshot.tables.len(), + snapshot.views.len(), + snapshot.functions.len(), &snapshot.content_hash[..16], ); @@ -1100,16 +1305,26 @@ impl DryRunServer { Ok(CallToolResult::success(vec![Content::text(text)])) } - #[tool(description = "Reload the on-disk schema without restarting. Run after `dryrun dump-schema`.")] + #[tool( + description = "Reload the on-disk schema without restarting. Run after `dryrun dump-schema`."
+ )] async fn reload_schema(&self) -> Result<CallToolResult, McpError> { for candidate in &self.schema_candidates { if !candidate.exists() { continue; } - let json = std::fs::read_to_string(candidate) - .map_err(|e| McpError::internal_error(format!("failed to read {}: {e}", candidate.display()), None))?; - let snapshot: SchemaSnapshot = serde_json::from_str(&json) - .map_err(|e| McpError::internal_error(format!("failed to parse {}: {e}", candidate.display()), None))?; + let json = std::fs::read_to_string(candidate).map_err(|e| { + McpError::internal_error( + format!("failed to read {}: {e}", candidate.display()), + None, + ) + })?; + let snapshot: SchemaSnapshot = serde_json::from_str(&json).map_err(|e| { + McpError::internal_error( + format!("failed to parse {}: {e}", candidate.display()), + None, + ) + })?; let body = format!( "Schema loaded from {}: {} tables, {} views, {} functions", @@ -1125,7 +1340,11 @@ impl DryRunServer { return Ok(CallToolResult::success(vec![Content::text(text)])); } - let paths: Vec<_> = self.schema_candidates.iter().map(|p| format!(" - {}", p.display())).collect(); + let paths: Vec<_> = self + .schema_candidates + .iter() + .map(|p| format!(" - {}", p.display())) + .collect(); Err(McpError::internal_error( format!( "no schema file found at any expected location:\n{}\n\n\ @@ -1235,20 +1454,42 @@ mod tests { #[tokio::test] async fn list_tables_includes_pg_version() { let snapshot = test_snapshot(); - let server = DryRunServer::from_snapshot_with_db(snapshot, None, LintConfig::default(), None, "test", vec![]); + let server = DryRunServer::from_snapshot_with_db( + snapshot, + None, + LintConfig::default(), + None, + "test", + vec![], + ); let result = server - .list_tables(Parameters(ListTablesParams { schema: None, sort: None, limit: None, offset: None })) + .list_tables(Parameters(ListTablesParams { + schema: None, + sort: None, + limit: None, + offset: None, + })) .await .unwrap(); let text = result.content.first().unwrap(); let text_str = format!("{text:?}"); - assert!(text_str.contains("PostgreSQL 18.3.0"), "list_tables output should contain PG version"); + assert!( + text_str.contains("PostgreSQL 18.3.0"), + "list_tables output should contain PG version" + ); } #[tokio::test] async fn describe_table_includes_pg_version() { let snapshot = test_snapshot(); - let server = DryRunServer::from_snapshot_with_db(snapshot, None, LintConfig::default(), None, "test", vec![]); + let server = DryRunServer::from_snapshot_with_db( + snapshot, + None, + LintConfig::default(), + None, + "test", + vec![], + ); let result = server .describe_table(Parameters(DescribeTableParams { table: "orders".into(), @@ -1259,7 +1500,10 @@ mod tests { .unwrap(); let text = result.content.first().unwrap(); let text_str = format!("{text:?}"); - assert!(text_str.contains("pg_version"), "describe_table output should contain pg_version field"); + assert!( + text_str.contains("pg_version"), + "describe_table output should contain pg_version field" + ); } fn test_snapshot() -> dry_run_core::SchemaSnapshot { @@ -1271,15 +1515,49 @@ mod tests { content_hash: "abc123".into(), source: None, tables: vec![Table { - oid: 1, schema: "public".into(), name: "orders".into(), - columns: vec![ - Column { name: "id".into(), ordinal: 1, type_name: "bigint".into(), nullable: false, default: None, identity: None, generated: None, comment: None, statistics_target: None, stats: None }, - ], - constraints: vec![], indexes: vec![], comment: None, - stats: Some(TableStats { reltuples: 50000.0, relpages: 625, dead_tuples: 0, last_vacuum: None,
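+ // Fixture sanity note: table_size 5_000_000 bytes is roughly 610 8 kB pages,
+ // which lines up with relpages = 625 and reltuples = 50000.0 (~80 rows per page).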
last_autovacuum: None, last_analyze: None, last_autoanalyze: None, seq_scan: 0, idx_scan: 0, table_size: 5000000 }), - partition_info: None, policies: vec![], triggers: vec![], reloptions: vec![], rls_enabled: false, + oid: 1, + schema: "public".into(), + name: "orders".into(), + columns: vec![Column { + name: "id".into(), + ordinal: 1, + type_name: "bigint".into(), + nullable: false, + default: None, + identity: None, + generated: None, + comment: None, + statistics_target: None, + stats: None, + }], + constraints: vec![], + indexes: vec![], + comment: None, + stats: Some(TableStats { + reltuples: 50000.0, + relpages: 625, + dead_tuples: 0, + last_vacuum: None, + last_autovacuum: None, + last_analyze: None, + last_autoanalyze: None, + seq_scan: 0, + idx_scan: 0, + table_size: 5000000, + }), + partition_info: None, + policies: vec![], + triggers: vec![], + reloptions: vec![], + rls_enabled: false, }], - enums: vec![], domains: vec![], composites: vec![], views: vec![], functions: vec![], extensions: vec![], gucs: vec![], + enums: vec![], + domains: vec![], + composites: vec![], + views: vec![], + functions: vec![], + extensions: vec![], + gucs: vec![], node_stats: vec![], } } @@ -1332,7 +1610,10 @@ impl ServerHandler for DryRunServer { self.app_version, self.pg_version_display, self.database_name ) } else { - format!("dryrun {} PostgreSQL schema advisor. No schema loaded yet.\n\n", self.app_version) + format!( + "dryrun {} PostgreSQL schema advisor. No schema loaded yet.\n\n", + self.app_version + ) }; let online_note = if self.ctx.is_some() { @@ -1342,13 +1623,13 @@ impl ServerHandler for DryRunServer { }; ServerInfo { - instructions: Some( - format!("{version_header}\ + instructions: Some(format!( + "{version_header}\ {online_note}\n\n\ Start with list_tables or search_schema to explore. Use advise for query help. \ Use check_migration before applying DDL. Each tool response includes a _meta.hint \ - field with contextual next-step guidance."), - ), + field with contextual next-step guidance." 
+ )), capabilities: ServerCapabilities::builder().enable_tools().build(), ..Default::default() } diff --git a/crates/dry_run_core/Cargo.toml b/crates/dry_run_core/Cargo.toml index 21303e2..6a3330f 100644 --- a/crates/dry_run_core/Cargo.toml +++ b/crates/dry_run_core/Cargo.toml @@ -4,6 +4,7 @@ version.workspace = true edition.workspace = true [dependencies] +async-trait = { workspace = true } chrono = { workspace = true } pg_query = { workspace = true } regex = { workspace = true } @@ -13,9 +14,10 @@ serde_json = { workspace = true } sha2 = { workspace = true } sqlx = { workspace = true } thiserror = { workspace = true } -tokio = { version = "1", features = ["macros"] } +tokio = { version = "1", features = ["macros", "rt"] } toml = { workspace = true } tracing = { workspace = true } +zstd = { workspace = true } [dev-dependencies] tokio = { workspace = true } diff --git a/crates/dry_run_core/src/audit/rules/fk_graph.rs b/crates/dry_run_core/src/audit/rules/fk_graph.rs index 8ac1805..78eb056 100644 --- a/crates/dry_run_core/src/audit/rules/fk_graph.rs +++ b/crates/dry_run_core/src/audit/rules/fk_graph.rs @@ -22,10 +22,14 @@ impl FkGraph { for constraint in &table.constraints { if constraint.kind == ConstraintKind::ForeignKey - && let Some(ref target) = constraint.fk_table { - nodes.insert(target.clone()); - edges.entry(source.clone()).or_default().insert(target.clone()); - } + && let Some(ref target) = constraint.fk_table + { + nodes.insert(target.clone()); + edges + .entry(source.clone()) + .or_default() + .insert(target.clone()); + } } } @@ -241,44 +245,78 @@ mod tests { fn make_col(name: &str, type_name: &str) -> Column { Column { - name: name.into(), ordinal: 0, type_name: type_name.into(), - nullable: false, default: None, identity: None, generated: None, comment: None, statistics_target: None, stats: None, + name: name.into(), + ordinal: 0, + type_name: type_name.into(), + nullable: false, + default: None, + identity: None, + generated: None, + comment: None, + statistics_target: None, + stats: None, } } fn make_pk(name: &str, columns: &[&str]) -> Constraint { Constraint { - name: name.into(), kind: ConstraintKind::PrimaryKey, + name: name.into(), + kind: ConstraintKind::PrimaryKey, columns: columns.iter().map(|s| s.to_string()).collect(), - definition: None, fk_table: None, fk_columns: vec![], backing_index: None, comment: None, + definition: None, + fk_table: None, + fk_columns: vec![], + backing_index: None, + comment: None, } } fn make_fk(name: &str, columns: &[&str], fk_table: &str, fk_columns: &[&str]) -> Constraint { Constraint { - name: name.into(), kind: ConstraintKind::ForeignKey, + name: name.into(), + kind: ConstraintKind::ForeignKey, columns: columns.iter().map(|s| s.to_string()).collect(), - definition: None, fk_table: Some(fk_table.into()), + definition: None, + fk_table: Some(fk_table.into()), fk_columns: fk_columns.iter().map(|s| s.to_string()).collect(), - backing_index: None, comment: None, + backing_index: None, + comment: None, } } fn make_table(name: &str, columns: Vec<Column>, constraints: Vec<Constraint>) -> Table { Table { - oid: 0, schema: "public".into(), name: name.into(), - columns, constraints, indexes: vec![], - comment: None, stats: None, partition_info: None, - policies: vec![], triggers: vec![], reloptions: vec![], rls_enabled: false, + oid: 0, + schema: "public".into(), + name: name.into(), + columns, + constraints, + indexes: vec![], + comment: None, + stats: None, + partition_info: None, + policies: vec![], + triggers: vec![], + reloptions: vec![], + rls_enabled: false,
} } fn schema_with(tables: Vec<Table>) -> SchemaSnapshot { SchemaSnapshot { - pg_version: "PostgreSQL 17.0".into(), database: "test".into(), - timestamp: Utc::now(), content_hash: "abc".into(), source: None, - tables, enums: vec![], domains: vec![], composites: vec![], - views: vec![], functions: vec![], extensions: vec![], gucs: vec![], + pg_version: "PostgreSQL 17.0".into(), + database: "test".into(), + timestamp: Utc::now(), + content_hash: "abc".into(), + source: None, + tables, + enums: vec![], + domains: vec![], + composites: vec![], + views: vec![], + functions: vec![], + extensions: vec![], + gucs: vec![], node_stats: vec![], } } @@ -310,11 +348,7 @@ mod tests { #[test] fn no_cycle_in_linear_chain() { let schema = schema_with(vec![ - make_table( - "a", - vec![make_col("id", "bigint")], - vec![], - ), + make_table("a", vec![make_col("id", "bigint")], vec![]), make_table( "b", vec![make_col("id", "bigint"), make_col("a_id", "bigint")], @@ -333,15 +367,16 @@ mod tests { #[test] fn detects_orphan_table() { let schema = schema_with(vec![ - make_table( - "users", - vec![make_col("id", "bigint")], - vec![], - ), + make_table("users", vec![make_col("id", "bigint")], vec![]), make_table( "orders", vec![make_col("id", "bigint"), make_col("user_id", "bigint")], - vec![make_fk("fk_orders_users", &["user_id"], "public.users", &["id"])], + vec![make_fk( + "fk_orders_users", + &["user_id"], + "public.users", + &["id"], + )], ), make_table( "config", @@ -365,7 +400,12 @@ mod tests { make_table( "orders", vec![make_col("id", "bigint"), make_col("user_id", "integer")], - vec![make_fk("fk_orders_user", &["user_id"], "public.users", &["user_id"])], + vec![make_fk( + "fk_orders_user", + &["user_id"], + "public.users", + &["user_id"], + )], ), ]); let findings = check_fk_type_mismatch(&schema); @@ -384,10 +424,18 @@ mod tests { make_table( "orders", vec![make_col("id", "bigint"), make_col("user_id", "integer")], - vec![make_fk("fk_orders_user", &["user_id"], "public.users", &["user_id"])], + vec![make_fk( + "fk_orders_user", + &["user_id"], + "public.users", + &["user_id"], + )], ), ]); let findings = check_fk_type_mismatch(&schema); - assert!(findings.is_empty(), "int4 and integer should be treated as equivalent"); + assert!( + findings.is_empty(), + "int4 and integer should be treated as equivalent" + ); } } diff --git a/crates/dry_run_core/src/audit/rules/indexes.rs b/crates/dry_run_core/src/audit/rules/indexes.rs index a9e4871..f111571 100644 --- a/crates/dry_run_core/src/audit/rules/indexes.rs +++ b/crates/dry_run_core/src/audit/rules/indexes.rs @@ -9,11 +9,7 @@ pub fn check_duplicate_indexes(schema: &SchemaSnapshot) -> Vec<AuditFinding> { for table in &schema.tables { let qualified = format!("{}.{}", table.schema, table.name); - let non_primary: Vec<_> = table - .indexes - .iter() - .filter(|idx| !idx.is_primary) - .collect(); + let non_primary: Vec<_> = table.indexes.iter().filter(|idx| !idx.is_primary).collect(); for (i, a) in non_primary.iter().enumerate() { for b in non_primary.iter().skip(i + 1) { @@ -70,7 +66,11 @@ pub fn check_duplicate_indexes(schema: &SchemaSnapshot) -> Vec<AuditFinding> { "Drop '{}' — '{}'{}", to_drop.name, to_keep.name, - if to_keep.backs_constraint { " backs a constraint" } else { " is sufficient" }, + if to_keep.backs_constraint { + " backs a constraint" + } else { + " is sufficient" + }, ), ddl_fix: Some(format!("DROP INDEX {};", to_drop.name)), min_pg_version: None, @@ -238,8 +238,9 @@ pub fn check_bloated_indexes(schema: &SchemaSnapshot) -> Vec<AuditFinding> { let qualified = format!("{}.{}", table.schema,
table.name); for idx in &table.indexes { if let Some(est) = crate::schema::bloat::estimate_index_bloat(idx, table) - && est.bloat_ratio > DEFAULT_BLOAT_THRESHOLD { - findings.push(AuditFinding { + && est.bloat_ratio > DEFAULT_BLOAT_THRESHOLD + { + findings.push(AuditFinding { rule: "indexes/bloated".into(), category: AuditCategory::Storage, severity: Severity::Warning, @@ -252,7 +253,7 @@ pub fn check_bloated_indexes(schema: &SchemaSnapshot) -> Vec<AuditFinding> { ddl_fix: Some(format!("REINDEX INDEX CONCURRENTLY {};", idx.name)), min_pg_version: None, }); - } + } } } @@ -267,8 +268,16 @@ mod tests { fn make_col(name: &str, type_name: &str) -> Column { Column { - name: name.into(), ordinal: 0, type_name: type_name.into(), - nullable: false, default: None, identity: None, generated: None, comment: None, statistics_target: None, stats: None, + name: name.into(), + ordinal: 0, + type_name: type_name.into(), + nullable: false, + default: None, + identity: None, + generated: None, + comment: None, + statistics_target: None, + stats: None, } } @@ -276,8 +285,11 @@ mod tests { Index { name: name.into(), columns: columns.iter().map(|s| s.to_string()).collect(), - include_columns: vec![], index_type: "btree".into(), - is_unique: false, is_primary: false, predicate: None, + include_columns: vec![], + index_type: "btree".into(), + is_unique: false, + is_primary: false, + predicate: None, definition: format!("CREATE INDEX {name} ON ..."), is_valid: true, backs_constraint: false, @@ -285,25 +297,39 @@ mod tests { } } - fn make_table_with( - name: &str, - columns: Vec<Column>, - indexes: Vec<Index>, - ) -> Table { + fn make_table_with(name: &str, columns: Vec<Column>, indexes: Vec<Index>) -> Table { Table { - oid: 0, schema: "public".into(), name: name.into(), - columns, constraints: vec![], indexes, - comment: None, stats: None, partition_info: None, - policies: vec![], triggers: vec![], reloptions: vec![], rls_enabled: false, + oid: 0, + schema: "public".into(), + name: name.into(), + columns, + constraints: vec![], + indexes, + comment: None, + stats: None, + partition_info: None, + policies: vec![], + triggers: vec![], + reloptions: vec![], + rls_enabled: false, } } fn schema_with(tables: Vec<Table>
) -> SchemaSnapshot { SchemaSnapshot { - pg_version: "PostgreSQL 17.0".into(), database: "test".into(), - timestamp: Utc::now(), content_hash: "abc".into(), source: None, - tables, enums: vec![], domains: vec![], composites: vec![], - views: vec![], functions: vec![], extensions: vec![], gucs: vec![], + pg_version: "PostgreSQL 17.0".into(), + database: "test".into(), + timestamp: Utc::now(), + content_hash: "abc".into(), + source: None, + tables, + enums: vec![], + domains: vec![], + composites: vec![], + views: vec![], + functions: vec![], + extensions: vec![], + gucs: vec![], node_stats: vec![], } } @@ -435,10 +461,7 @@ mod tests { let schema = schema_with(vec![make_table_with( "orders", vec![make_col("user_id", "bigint"), make_col("status", "text")], - vec![ - make_index("idx_user_all", &["user_id"]), - partial, - ], + vec![make_index("idx_user_all", &["user_id"]), partial], )]); let findings = check_duplicate_indexes(&schema); assert!(findings.is_empty()); @@ -451,17 +474,21 @@ mod tests { let schema = schema_with(vec![make_table_with( "orders", vec![make_col("user_id", "bigint")], - vec![ - make_index("idx_user_plain", &["user_id"]), - unique, - ], + vec![make_index("idx_user_plain", &["user_id"]), unique], )]); let findings = check_duplicate_indexes(&schema); assert_eq!(findings.len(), 1); assert_eq!(findings[0].severity, Severity::Warning); - assert!(findings[0].message.contains("Non-unique index 'idx_user_plain'")); + assert!( + findings[0] + .message + .contains("Non-unique index 'idx_user_plain'") + ); assert!(findings[0].message.contains("unique index 'idx_user_uniq'")); - assert_eq!(findings[0].ddl_fix.as_deref(), Some("DROP INDEX idx_user_plain;")); + assert_eq!( + findings[0].ddl_fix.as_deref(), + Some("DROP INDEX idx_user_plain;") + ); } #[test] @@ -484,7 +511,10 @@ mod tests { let findings = check_duplicate_indexes(&schema); assert_eq!(findings.len(), 1); assert_eq!(findings[0].severity, Severity::Warning); - assert_eq!(findings[0].ddl_fix.as_deref(), Some("DROP INDEX idx_workspace_name;")); + assert_eq!( + findings[0].ddl_fix.as_deref(), + Some("DROP INDEX idx_workspace_name;") + ); } #[test] @@ -494,10 +524,7 @@ mod tests { let schema = schema_with(vec![make_table_with( "orders", vec![make_col("user_id", "bigint"), make_col("status", "text")], - vec![ - make_index("idx_user_plain", &["user_id"]), - covering, - ], + vec![make_index("idx_user_plain", &["user_id"]), covering], )]); let findings = check_duplicate_indexes(&schema); assert!(findings.is_empty()); @@ -533,11 +560,13 @@ mod tests { fn both_back_constraints_warns_without_ddl_fix() { // one index owns a UNIQUE constraint, the other is used by a FK — // neither can be simply dropped, needs FK drop+recreate - let mut constraint_idx = make_index("unique_status_id_workspace_id", &["workspace_id", "id"]); + let mut constraint_idx = + make_index("unique_status_id_workspace_id", &["workspace_id", "id"]); constraint_idx.is_unique = true; constraint_idx.backs_constraint = true; - let mut fk_used_idx = make_index("idx_unique_status_id_workspace_id", &["workspace_id", "id"]); + let mut fk_used_idx = + make_index("idx_unique_status_id_workspace_id", &["workspace_id", "id"]); fk_used_idx.is_unique = true; fk_used_idx.backs_constraint = true; @@ -550,7 +579,10 @@ mod tests { let findings = check_duplicate_indexes(&schema); assert_eq!(findings.len(), 1); assert_eq!(findings[0].severity, Severity::Warning); - assert!(findings[0].ddl_fix.is_none(), "no simple DDL fix when both back constraints"); + assert!( + 
findings[0].ddl_fix.is_none(), + "no simple DDL fix when both back constraints" + ); } #[test] diff --git a/crates/dry_run_core/src/audit/rules/mod.rs b/crates/dry_run_core/src/audit/rules/mod.rs index fd9660f..72d6afa 100644 --- a/crates/dry_run_core/src/audit/rules/mod.rs +++ b/crates/dry_run_core/src/audit/rules/mod.rs @@ -7,10 +7,7 @@ use crate::schema::SchemaSnapshot; // Runs all audit rules and returns findings, skipping disabled ones #[must_use] -pub fn run_all_audit_rules( - snapshot: &SchemaSnapshot, - config: &AuditConfig, -) -> Vec<AuditFinding> { +pub fn run_all_audit_rules(snapshot: &SchemaSnapshot, config: &AuditConfig) -> Vec<AuditFinding> { let mut findings = Vec::new(); let disabled = &config.disabled_rules; @@ -23,19 +20,37 @@ pub fn run_all_audit_rules( } // index rules - run_rule!("indexes/duplicate", indexes::check_duplicate_indexes(snapshot)); - run_rule!("indexes/redundant", indexes::check_redundant_indexes(snapshot)); - run_rule!("indexes/too_many", indexes::check_too_many_indexes(snapshot, config)); - run_rule!("indexes/wide_columns", indexes::check_wide_column_indexes(snapshot)); + run_rule!( + "indexes/duplicate", + indexes::check_duplicate_indexes(snapshot) + ); + run_rule!( + "indexes/redundant", + indexes::check_redundant_indexes(snapshot) + ); + run_rule!( + "indexes/too_many", + indexes::check_too_many_indexes(snapshot, config) + ); + run_rule!( + "indexes/wide_columns", + indexes::check_wide_column_indexes(snapshot) + ); run_rule!("indexes/bloated", indexes::check_bloated_indexes(snapshot)); // FK rules - run_rule!("fk/type_mismatch", fk_graph::check_fk_type_mismatch(snapshot)); + run_rule!( + "fk/type_mismatch", + fk_graph::check_fk_type_mismatch(snapshot) + ); run_rule!("fk/circular", fk_graph::check_circular_fks(snapshot)); run_rule!("fk/orphan", fk_graph::check_orphan_tables(snapshot)); // PK rules - run_rule!("pk/non_sequential", schema::check_pk_non_sequential(snapshot)); + run_rule!( + "pk/non_sequential", + schema::check_pk_non_sequential(snapshot) + ); // naming rules run_rule!("naming/bool_prefix", schema::check_bool_prefix(snapshot)); @@ -43,10 +58,16 @@ run_rule!("naming/id_mismatch", schema::check_id_mismatch(snapshot)); // documentation rules - run_rule!("docs/no_comment", schema::check_no_comment(snapshot, config)); + run_rule!( + "docs/no_comment", + schema::check_no_comment(snapshot, config) + ); // storage rules - run_rule!("vacuum/large_table_defaults", schema::check_vacuum_large_table_defaults(snapshot)); + run_rule!( + "vacuum/large_table_defaults", + schema::check_vacuum_large_table_defaults(snapshot) + ); findings } @@ -60,10 +81,19 @@ mod tests { fn empty_schema() -> SchemaSnapshot { SchemaSnapshot { - pg_version: "PostgreSQL 17.0".into(), database: "test".into(), - timestamp: Utc::now(), content_hash: "abc".into(), source: None, - tables: vec![], enums: vec![], domains: vec![], composites: vec![], - views: vec![], functions: vec![], extensions: vec![], gucs: vec![], + pg_version: "PostgreSQL 17.0".into(), + database: "test".into(), + timestamp: Utc::now(), + content_hash: "abc".into(), + source: None, + tables: vec![], + enums: vec![], + domains: vec![], + composites: vec![], + views: vec![], + functions: vec![], + extensions: vec![], + gucs: vec![], node_stats: vec![], } } @@ -79,14 +109,30 @@ mod tests { fn disabled_rules_are_skipped() { let schema = SchemaSnapshot { tables: vec![Table { - oid: 0, schema: "public".into(), name: "user".into(), + oid: 0, + schema: "public".into(), + name: "user".into(), columns: vec![Column { -
name: "id".into(), ordinal: 0, type_name: "bigint".into(), - nullable: false, default: None, identity: None, generated: None, comment: None, statistics_target: None, stats: None, + name: "id".into(), + ordinal: 0, + type_name: "bigint".into(), + nullable: false, + default: None, + identity: None, + generated: None, + comment: None, + statistics_target: None, + stats: None, }], - constraints: vec![], indexes: vec![], - comment: None, stats: None, partition_info: None, - policies: vec![], triggers: vec![], reloptions: vec![], rls_enabled: false, + constraints: vec![], + indexes: vec![], + comment: None, + stats: None, + partition_info: None, + policies: vec![], + triggers: vec![], + reloptions: vec![], + rls_enabled: false, }], ..empty_schema() }; diff --git a/crates/dry_run_core/src/audit/rules/schema.rs b/crates/dry_run_core/src/audit/rules/schema.rs index 10f920f..ead2f11 100644 --- a/crates/dry_run_core/src/audit/rules/schema.rs +++ b/crates/dry_run_core/src/audit/rules/schema.rs @@ -10,13 +10,74 @@ const BOOL_PREFIXES: &[&str] = &["is_", "has_", "can_", "should_", "was_", "will // Top ~50 most problematic SQL reserved words const RESERVED_WORDS: &[&str] = &[ - "all", "alter", "and", "any", "as", "asc", "between", "by", "case", "check", "column", - "constraint", "create", "cross", "current", "default", "delete", "desc", "distinct", "drop", - "else", "end", "exists", "false", "fetch", "for", "foreign", "from", "full", "grant", "group", - "having", "in", "index", "inner", "insert", "into", "is", "join", "key", "left", "like", - "limit", "not", "null", "offset", "on", "or", "order", "outer", "primary", "references", - "right", "select", "set", "table", "then", "to", "true", "union", "unique", "update", "user", - "using", "values", "when", "where", "with", + "all", + "alter", + "and", + "any", + "as", + "asc", + "between", + "by", + "case", + "check", + "column", + "constraint", + "create", + "cross", + "current", + "default", + "delete", + "desc", + "distinct", + "drop", + "else", + "end", + "exists", + "false", + "fetch", + "for", + "foreign", + "from", + "full", + "grant", + "group", + "having", + "in", + "index", + "inner", + "insert", + "into", + "is", + "join", + "key", + "left", + "like", + "limit", + "not", + "null", + "offset", + "on", + "or", + "order", + "outer", + "primary", + "references", + "right", + "select", + "set", + "table", + "then", + "to", + "true", + "union", + "unique", + "update", + "user", + "using", + "values", + "when", + "where", + "with", ]; #[must_use] @@ -79,10 +140,7 @@ pub fn check_bool_prefix(schema: &SchemaSnapshot) -> Vec { "Boolean column '{}' missing prefix (is_, has_, can_, ...)", col.name, ), - recommendation: format!( - "Rename to 'is_{}' or similar for clarity", - col.name, - ), + recommendation: format!("Rename to 'is_{}' or similar for clarity", col.name,), ddl_fix: Some(format!( "ALTER TABLE {} RENAME COLUMN {} TO is_{};", qualified, col.name, col.name, @@ -131,10 +189,7 @@ pub fn check_reserved_words(schema: &SchemaSnapshot) -> Vec { "Column '{}' in table '{}' is a SQL reserved word", col.name, table.name, ), - recommendation: format!( - "Rename column '{}' to avoid quoting hell", - col.name, - ), + recommendation: format!("Rename column '{}' to avoid quoting hell", col.name,), ddl_fix: None, min_pg_version: None, }); @@ -198,7 +253,11 @@ pub fn check_id_mismatch(schema: &SchemaSnapshot) -> Vec { message: format!( "Table '{}' referenced inconsistently: {} used as FK column names", target_table, - names.iter().map(|n| 
format!("'{n}'")).collect::>().join(", "), + names + .iter() + .map(|n| format!("'{n}'")) + .collect::>() + .join(", "), ), recommendation: "Standardize FK column naming for consistency".into(), ddl_fix: None, @@ -233,10 +292,7 @@ pub fn check_no_comment(schema: &SchemaSnapshot, config: &AuditConfig) -> Vec Vec Vec Column { Column { - name: name.into(), ordinal: 0, type_name: type_name.into(), - nullable: false, default: None, identity: None, generated: None, comment: None, statistics_target: None, stats: None, + name: name.into(), + ordinal: 0, + type_name: type_name.into(), + nullable: false, + default: None, + identity: None, + generated: None, + comment: None, + statistics_target: None, + stats: None, } } fn make_col_with_comment(name: &str, type_name: &str, comment: &str) -> Column { Column { - name: name.into(), ordinal: 0, type_name: type_name.into(), - nullable: false, default: None, identity: None, generated: None, - comment: Some(comment.into()), statistics_target: None, stats: None, + name: name.into(), + ordinal: 0, + type_name: type_name.into(), + nullable: false, + default: None, + identity: None, + generated: None, + comment: Some(comment.into()), + statistics_target: None, + stats: None, } } fn make_pk(name: &str, columns: &[&str]) -> Constraint { Constraint { - name: name.into(), kind: ConstraintKind::PrimaryKey, + name: name.into(), + kind: ConstraintKind::PrimaryKey, columns: columns.iter().map(|s| s.to_string()).collect(), - definition: None, fk_table: None, fk_columns: vec![], backing_index: None, comment: None, + definition: None, + fk_table: None, + fk_columns: vec![], + backing_index: None, + comment: None, } } fn make_fk(name: &str, columns: &[&str], fk_table: &str, fk_columns: &[&str]) -> Constraint { Constraint { - name: name.into(), kind: ConstraintKind::ForeignKey, + name: name.into(), + kind: ConstraintKind::ForeignKey, columns: columns.iter().map(|s| s.to_string()).collect(), - definition: None, fk_table: Some(fk_table.into()), + definition: None, + fk_table: Some(fk_table.into()), fk_columns: fk_columns.iter().map(|s| s.to_string()).collect(), - backing_index: None, comment: None, + backing_index: None, + comment: None, } } fn make_table(name: &str, columns: Vec, constraints: Vec) -> Table { Table { - oid: 0, schema: "public".into(), name: name.into(), - columns, constraints, indexes: vec![], - comment: None, stats: None, partition_info: None, - policies: vec![], triggers: vec![], reloptions: vec![], rls_enabled: false, + oid: 0, + schema: "public".into(), + name: name.into(), + columns, + constraints, + indexes: vec![], + comment: None, + stats: None, + partition_info: None, + policies: vec![], + triggers: vec![], + reloptions: vec![], + rls_enabled: false, } } fn schema_with(tables: Vec
) -> SchemaSnapshot { SchemaSnapshot { - pg_version: "PostgreSQL 17.0".into(), database: "test".into(), - timestamp: Utc::now(), content_hash: "abc".into(), source: None, - tables, enums: vec![], domains: vec![], composites: vec![], - views: vec![], functions: vec![], extensions: vec![], gucs: vec![], + pg_version: "PostgreSQL 17.0".into(), + database: "test".into(), + timestamp: Utc::now(), + content_hash: "abc".into(), + source: None, + tables, + enums: vec![], + domains: vec![], + composites: vec![], + views: vec![], + functions: vec![], + extensions: vec![], + gucs: vec![], node_stats: vec![], } } @@ -472,20 +574,26 @@ mod tests { #[test] fn detects_inconsistent_fk_naming() { let schema = schema_with(vec![ - make_table( - "users", - vec![make_col("user_id", "bigint")], - vec![], - ), + make_table("users", vec![make_col("user_id", "bigint")], vec![]), make_table( "orders", vec![make_col("id", "bigint"), make_col("user_id", "bigint")], - vec![make_fk("fk_orders_user", &["user_id"], "public.users", &["user_id"])], + vec![make_fk( + "fk_orders_user", + &["user_id"], + "public.users", + &["user_id"], + )], ), make_table( "comments", vec![make_col("id", "bigint"), make_col("uid", "bigint")], - vec![make_fk("fk_comments_user", &["uid"], "public.users", &["user_id"])], + vec![make_fk( + "fk_comments_user", + &["uid"], + "public.users", + &["user_id"], + )], ), ]); let findings = check_id_mismatch(&schema); @@ -535,15 +643,15 @@ mod tests { fn skips_small_tables_for_comments() { let schema = schema_with(vec![make_table( "config", - vec![ - make_col("key", "text"), - make_col("value", "text"), - ], + vec![make_col("key", "text"), make_col("value", "text")], vec![], )]); let config = AuditConfig::default(); let findings = check_no_comment(&schema, &config); - assert!(findings.is_empty(), "tables with < 5 columns should be skipped"); + assert!( + findings.is_empty(), + "tables with < 5 columns should be skipped" + ); } #[test] diff --git a/crates/dry_run_core/src/config.rs b/crates/dry_run_core/src/config.rs index a010b4b..dfc56b4 100644 --- a/crates/dry_run_core/src/config.rs +++ b/crates/dry_run_core/src/config.rs @@ -4,6 +4,7 @@ use std::path::{Path, PathBuf}; use serde::{Deserialize, Serialize}; use crate::error::{Error, Result}; +use crate::history::{DatabaseId, ProjectId}; use crate::lint::{LintConfig, Severity}; #[derive(Debug, Clone)] @@ -23,6 +24,9 @@ impl ConnectionConfig { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ProjectConfig { + #[serde(default)] + pub project: Option<ProjectMeta>, + #[serde(default)] pub default: Option<DefaultConfig>, @@ -36,6 +40,12 @@ pub struct ProjectConfig { pub services: Option, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProjectMeta { + #[serde(default)] + pub id: Option<String>, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DefaultConfig { pub profile: Option<String>, } @@ -45,6 +55,8 @@ pub struct ProfileConfig { pub db_url: Option<String>, pub schema_file: Option<PathBuf>, + #[serde(default)] + pub database_id: Option<String>, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -88,6 +100,8 @@ pub struct ResolvedProfile { pub name: String, pub db_url: Option<String>, pub schema_file: Option<PathBuf>, + pub project_id: ProjectId, + pub database_id: Option<DatabaseId>, } impl ProjectConfig { @@ -106,9 +120,10 @@ impl ProjectConfig { loop { let candidate = dir.join("dryrun.toml"); if candidate.is_file() - && let Ok(config) = Self::load(&candidate) { - return Some((candidate, config)); - } + && let Ok(config) = Self::load(&candidate) + { + return Some((candidate, config)); + }
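+ // Illustrative walk (hypothetical paths): starting from /repo/crates/api this
+ // checks /repo/crates/api/dryrun.toml, then /repo/crates, then /repo; the
+ // .git check below stops the search at the repository root, so a dryrun.toml
+ // that sits above the repository is never picked up.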
if dir.join(".git").exists() { return None; } @@ -119,11 +134,14 @@ impl ProjectConfig { } // resolution order: - // 1. explicit cli_db or cli_schema (CLI flags) - // 2. cli_profile flag (--profile) - // 3. PROFILE env var - // 4. [default].profile in toml - // 5. auto-discovery of .dryrun/schema.json + // 1. cli_profile flag (--profile) + // 2. PROFILE env var + // 3. [default].profile in toml + // 4. auto-discovery of .dryrun/schema.json + // + // CLI flags (cli_db, cli_schema) override the resolved profile's matching + // fields for the current invocation. So `--profile billing --db $OTHER` + // connects to $OTHER but keeps billing's database_id for snapshot keying. pub fn resolve_profile( &self, cli_db: Option<&str>, @@ -131,11 +149,42 @@ impl ProjectConfig { cli_profile: Option<&str>, project_root: &Path, ) -> Result { + let project_id = self.project_id(project_root); + + let explicit_profile = cli_profile + .map(|s| s.to_string()) + .or_else(|| std::env::var("PROFILE").ok()); + let default_profile = self.default.as_ref().and_then(|d| d.profile.clone()); + let profile_name = explicit_profile.clone().or(default_profile); + + if let Some(name) = profile_name { + if let Some(profile) = self.profiles.get(&name) { + let mut resolved = resolve_profile_config(&name, profile, project_root, project_id); + if let Some(db) = cli_db { + resolved.db_url = Some(expand_env_vars(db)); + } + if let Some(schema) = cli_schema { + resolved.schema_file = Some(schema.to_path_buf()); + } + return Ok(resolved); + } + + // Missing profile causes error. + if explicit_profile.is_some() || (cli_db.is_none() && cli_schema.is_none()) { + return Err(Error::Config(format!( + "profile '{name}' not found in dryrun.toml" + ))); + } + } + + // No profile resolved: fall back to or . 
if let Some(db) = cli_db { return Ok(ResolvedProfile { name: "".into(), db_url: Some(expand_env_vars(db)), schema_file: None, + project_id, + database_id: None, }); } if let Some(schema) = cli_schema { @@ -143,27 +192,19 @@ impl ProjectConfig { name: "".into(), db_url: None, schema_file: Some(schema.to_path_buf()), + project_id, + database_id: None, }); } - let profile_name = cli_profile - .map(|s| s.to_string()) - .or_else(|| std::env::var("PROFILE").ok()) - .or_else(|| self.default.as_ref().and_then(|d| d.profile.clone())); - - if let Some(name) = profile_name { - let profile = self.profiles.get(&name).ok_or_else(|| { - Error::Config(format!("profile '{name}' not found in dryrun.toml")) - })?; - return Ok(resolve_profile_config(&name, profile, project_root)); - } - let auto_schema = project_root.join(".dryrun/schema.json"); if auto_schema.is_file() { return Ok(ResolvedProfile { name: "".into(), db_url: None, schema_file: Some(auto_schema), + project_id, + database_id: None, }); } @@ -175,6 +216,16 @@ impl ProjectConfig { )) } + pub fn project_id(&self, project_root: &Path) -> ProjectId { + if let Some(meta) = &self.project + && let Some(id) = &meta.id + && !id.is_empty() + { + return ProjectId(id.clone()); + } + default_project_id(project_root) + } + pub fn pgmustard_api_key(&self) -> Option<String> { self.services .as_ref() @@ -242,6 +293,7 @@ fn resolve_profile_config( name: &str, profile: &ProfileConfig, project_root: &Path, + project_id: ProjectId, ) -> ResolvedProfile { let db_url = profile.db_url.as_ref().map(|u| expand_env_vars(u)); let schema_file = profile.schema_file.as_ref().map(|p| { @@ -252,14 +304,30 @@ fn resolve_profile_config( project_root.join(path) } }); + let database_id = Some(DatabaseId( + profile + .database_id + .clone() + .filter(|s| !s.is_empty()) + .unwrap_or_else(|| name.to_string()), + )); ResolvedProfile { name: name.to_string(), db_url, schema_file, + project_id, + database_id, } } +fn default_project_id(project_root: &Path) -> ProjectId { + project_root + .file_name() + .map(|n| ProjectId(n.to_string_lossy().into_owned())) + .unwrap_or_else(|| ProjectId("default".into())) +} + pub fn expand_env_vars(input: &str) -> String { let mut result = input.to_string(); while let Some(start) = result.find("${") { @@ -308,7 +376,10 @@ table_name_regex = "^[a-z][a-z0-9_]*$" "#; let config = ProjectConfig::parse(toml).unwrap(); - assert_eq!(config.default.as_ref().unwrap().profile.as_deref(), Some("production")); + assert_eq!( + config.default.as_ref().unwrap().profile.as_deref(), + Some("production") + ); assert_eq!(config.profiles.len(), 3); assert!(config.profiles.contains_key("development")); assert!(config.profiles.contains_key("staging")); @@ -389,10 +460,18 @@ rules = ["pk/exists"] fn resolve_profile_cli_db_wins() { let config = ProjectConfig::parse("[default]\nprofile = \"prod\"").unwrap(); let resolved = config - .resolve_profile(Some("postgres://localhost/test"), None, None, Path::new("/tmp")) + .resolve_profile( + Some("postgres://localhost/test"), + None, + None, + Path::new("/tmp"), + ) .unwrap(); assert_eq!(resolved.name, ""); - assert_eq!(resolved.db_url.as_deref(), Some("postgres://localhost/test")); + assert_eq!( + resolved.db_url.as_deref(), + Some("postgres://localhost/test") + ); } #[test] @@ -406,7 +485,10 @@ schema_file = ".dryrun/staging.json" .resolve_profile(None, None, Some("staging"), Path::new("/project")) .unwrap(); assert_eq!(resolved.name, "staging"); - assert_eq!(resolved.schema_file.unwrap(), PathBuf::from("/project/.dryrun/staging.json")); +
assert_eq!( + resolved.schema_file.unwrap(), + PathBuf::from("/project/.dryrun/staging.json") + ); } #[test] @@ -414,4 +496,378 @@ schema_file = ".dryrun/staging.json" let result = ProjectConfig::discover(Path::new("/nonexistent/path/that/doesnt/exist")); assert!(result.is_none()); } + + #[test] + fn parse_with_project_section() { + let toml = r#" +[project] +id = "myapp" + +[profiles.dev] +schema_file = ".dryrun/schema.json" +"#; + let config = ProjectConfig::parse(toml).unwrap(); + assert_eq!(config.project.unwrap().id.as_deref(), Some("myapp")); + } + + #[test] + fn parse_with_database_id_per_profile() { + let toml = r#" +[profiles.prod-auth] +schema_file = ".dryrun/auth.json" +database_id = "auth" + +[profiles.prod-billing] +schema_file = ".dryrun/billing.json" +"#; + let config = ProjectConfig::parse(toml).unwrap(); + assert_eq!( + config.profiles["prod-auth"].database_id.as_deref(), + Some("auth") + ); + assert!(config.profiles["prod-billing"].database_id.is_none()); + } + + #[test] + fn resolve_profile_uses_configured_project_id() { + let toml = r#" +[project] +id = "myapp" + +[profiles.dev] +schema_file = ".dryrun/schema.json" +"#; + let config = ProjectConfig::parse(toml).unwrap(); + let resolved = config + .resolve_profile(None, None, Some("dev"), Path::new("/tmp/some-folder")) + .unwrap(); + assert_eq!(resolved.project_id.0, "myapp"); + } + + #[test] + fn resolve_profile_falls_back_to_cwd_basename() { + let toml = r#" +[profiles.dev] +schema_file = ".dryrun/schema.json" +"#; + let config = ProjectConfig::parse(toml).unwrap(); + let resolved = config + .resolve_profile(None, None, Some("dev"), Path::new("/tmp/test-myapp")) + .unwrap(); + assert_eq!(resolved.project_id.0, "test-myapp"); + } + + #[test] + fn resolve_profile_database_id_defaults_to_profile_name() { + let toml = r#" +[profiles.staging] +schema_file = ".dryrun/staging.json" +"#; + let config = ProjectConfig::parse(toml).unwrap(); + let resolved = config + .resolve_profile(None, None, Some("staging"), Path::new("/project")) + .unwrap(); + assert_eq!( + resolved.database_id.as_ref().map(|d| d.0.as_str()), + Some("staging") + ); + } + + #[test] + fn resolve_profile_database_id_from_config() { + let toml = r#" +[profiles.prod-auth] +schema_file = ".dryrun/auth.json" +database_id = "auth" +"#; + let config = ProjectConfig::parse(toml).unwrap(); + let resolved = config + .resolve_profile(None, None, Some("prod-auth"), Path::new("/project")) + .unwrap(); + assert_eq!( + resolved.database_id.as_ref().map(|d| d.0.as_str()), + Some("auth") + ); + } + + #[test] + fn cli_profile_has_no_database_id() { + let config = ProjectConfig::parse("").unwrap(); + let resolved = config + .resolve_profile( + Some("postgres://localhost/test"), + None, + None, + Path::new("/tmp/myproj"), + ) + .unwrap(); + assert_eq!(resolved.name, ""); + assert!(resolved.database_id.is_none()); + assert_eq!(resolved.project_id.0, "myproj"); + } + + #[test] + fn cli_db_overrides_profile_db_url_keeps_database_id() { + let toml = r#" +[profiles.billing] +db_url = "postgres://prod/billing" +database_id = "billing" +"#; + let config = ProjectConfig::parse(toml).unwrap(); + let resolved = config + .resolve_profile( + Some("postgres://localhost/other"), + None, + Some("billing"), + Path::new("/project"), + ) + .unwrap(); + assert_eq!(resolved.name, "billing"); + assert_eq!( + resolved.db_url.as_deref(), + Some("postgres://localhost/other") + ); + assert_eq!( + resolved.database_id.as_ref().map(|d| d.0.as_str()), + Some("billing") + ); + } + + #[test] + fn 
cli_schema_overrides_profile_schema_file_keeps_database_id() { + let toml = r#" +[profiles.staging] +schema_file = ".dryrun/staging.json" +database_id = "stg" +"#; + let config = ProjectConfig::parse(toml).unwrap(); + let override_path = PathBuf::from("/tmp/other-schema.json"); + let resolved = config + .resolve_profile( + None, + Some(&override_path), + Some("staging"), + Path::new("/project"), + ) + .unwrap(); + assert_eq!(resolved.name, "staging"); + assert_eq!( + resolved.schema_file.as_deref(), + Some(override_path.as_path()) + ); + assert_eq!( + resolved.database_id.as_ref().map(|d| d.0.as_str()), + Some("stg") + ); + } + + #[test] + fn explicit_profile_missing_errors() { + let config = ProjectConfig::parse("").unwrap(); + let result = config.resolve_profile(None, None, Some("nope"), Path::new("/tmp")); + let err = result.unwrap_err().to_string(); + assert!(err.contains("'nope'"), "got: {err}"); + } + + #[test] + fn default_profile_missing_with_cli_db_falls_back_to_cli() { + let config = ProjectConfig::parse("[default]\nprofile = \"prod\"").unwrap(); + let resolved = config + .resolve_profile( + Some("postgres://localhost/x"), + None, + None, + Path::new("/tmp"), + ) + .unwrap(); + assert_eq!(resolved.name, ""); + assert!(resolved.database_id.is_none()); + } + + #[test] + fn default_profile_missing_without_cli_args_errors() { + let config = ProjectConfig::parse("[default]\nprofile = \"missing\"").unwrap(); + let result = config.resolve_profile(None, None, None, Path::new("/tmp")); + let err = result.unwrap_err().to_string(); + assert!(err.contains("'missing'"), "got: {err}"); + } + + #[test] + fn project_id_falls_back_to_default_for_root_path() { + let config = ProjectConfig::parse("").unwrap(); + // root path has no file_name; falls back to "default" + assert_eq!(config.project_id(Path::new("/")).0, "default"); + } + + #[test] + fn explicit_profile_overrides_default_profile() { + let toml = r#" +[default] +profile = "prod" + +[profiles.prod] +schema_file = "prod.json" + +[profiles.dev] +schema_file = "dev.json" +"#; + let config = ProjectConfig::parse(toml).unwrap(); + let resolved = config + .resolve_profile(None, None, Some("dev"), Path::new("/p")) + .unwrap(); + assert_eq!(resolved.name, "dev"); + assert_eq!(resolved.schema_file.unwrap(), PathBuf::from("/p/dev.json")); + } + + #[test] + fn resolve_profile_absolute_schema_path_kept_as_is() { + let toml = r#" +[profiles.dev] +schema_file = "/abs/schema.json" +"#; + let config = ProjectConfig::parse(toml).unwrap(); + let resolved = config + .resolve_profile(None, None, Some("dev"), Path::new("/project")) + .unwrap(); + assert_eq!( + resolved.schema_file.unwrap(), + PathBuf::from("/abs/schema.json") + ); + } + + #[test] + fn resolve_profile_empty_database_id_falls_back_to_profile_name() { + let toml = r#" +[profiles.staging] +schema_file = "x.json" +database_id = "" +"#; + let config = ProjectConfig::parse(toml).unwrap(); + let resolved = config + .resolve_profile(None, None, Some("staging"), Path::new("/p")) + .unwrap(); + assert_eq!( + resolved.database_id.as_ref().map(|d| d.0.as_str()), + Some("staging") + ); + } + + #[test] + fn resolve_profile_auto_discovers_schema_json() { + let dir = tempfile::TempDir::new().unwrap(); + let dryrun_dir = dir.path().join(".dryrun"); + std::fs::create_dir_all(&dryrun_dir).unwrap(); + std::fs::write(dryrun_dir.join("schema.json"), "{}").unwrap(); + + let config = ProjectConfig::parse("").unwrap(); + let resolved = config + .resolve_profile(None, None, None, dir.path()) + .unwrap(); + 
assert_eq!(resolved.name, ""); + assert!(resolved.database_id.is_none()); + assert_eq!( + resolved.schema_file.unwrap(), + dir.path().join(".dryrun/schema.json") + ); + } + + #[test] + fn resolve_profile_cli_schema_without_profile_falls_back() { + let config = ProjectConfig::parse("").unwrap(); + let p = PathBuf::from("/some/where.json"); + let resolved = config + .resolve_profile(None, Some(&p), None, Path::new("/p")) + .unwrap(); + assert_eq!(resolved.name, ""); + assert_eq!(resolved.schema_file.as_deref(), Some(p.as_path())); + assert!(resolved.db_url.is_none()); + } + + #[test] + fn resolve_profile_no_profile_no_schema_no_cli_errors() { + let dir = tempfile::TempDir::new().unwrap(); + let config = ProjectConfig::parse("").unwrap(); + let result = config.resolve_profile(None, None, None, dir.path()); + assert!(result.is_err()); + } + + #[test] + fn expand_env_vars_multiple_in_one_string() { + // SAFETY: test-only, single-threaded test runner + unsafe { + std::env::set_var("DRYRUN_A", "alpha"); + std::env::set_var("DRYRUN_B", "beta"); + } + assert_eq!(expand_env_vars("${DRYRUN_A}-${DRYRUN_B}"), "alpha-beta"); + unsafe { + std::env::remove_var("DRYRUN_A"); + std::env::remove_var("DRYRUN_B"); + } + } + + #[test] + fn expand_env_vars_unterminated_brace_left_alone() { + // no closing brace — should not loop forever, return as-is + assert_eq!(expand_env_vars("foo ${UNCLOSED bar"), "foo ${UNCLOSED bar"); + } + + #[test] + fn discover_finds_config_in_parent() { + let dir = tempfile::TempDir::new().unwrap(); + // simulate repo root + std::fs::create_dir(dir.path().join(".git")).unwrap(); + std::fs::write( + dir.path().join("dryrun.toml"), + "[profiles.dev]\nschema_file = \"x.json\"\n", + ) + .unwrap(); + + let nested = dir.path().join("a").join("b"); + std::fs::create_dir_all(&nested).unwrap(); + let (path, config) = ProjectConfig::discover(&nested).unwrap(); + assert_eq!(path, dir.path().join("dryrun.toml")); + assert!(config.profiles.contains_key("dev")); + } + + #[test] + fn discover_stops_at_git_root() { + let dir = tempfile::TempDir::new().unwrap(); + // .git in inner dir, dryrun.toml only above it — discovery must NOT cross the boundary + std::fs::create_dir(dir.path().join(".git")).unwrap(); + std::fs::write( + dir.path().parent().unwrap().join("dryrun.toml"), + "[profiles.dev]\n", + ) + .ok(); + // discovery from the git root should not find the parent's dryrun.toml + assert!(ProjectConfig::discover(dir.path()).is_none()); + } + + #[test] + fn pgmustard_api_key_from_config_expands_env() { + // SAFETY: test-only, single-threaded test runner + unsafe { std::env::set_var("DRYRUN_PGM_KEY", "sk-test-123") }; + let toml = r#" +[services] +pgmustard_api_key = "${DRYRUN_PGM_KEY}" +"#; + let config = ProjectConfig::parse(toml).unwrap(); + assert_eq!(config.pgmustard_api_key().as_deref(), Some("sk-test-123")); + unsafe { std::env::remove_var("DRYRUN_PGM_KEY") }; + } + + #[test] + fn pgmustard_api_key_empty_after_expansion_falls_through() { + // SAFETY: test-only, single-threaded test runner + unsafe { + std::env::remove_var("DRYRUN_PGM_MISSING"); + std::env::remove_var("PGMUSTARD_API_KEY"); + } + let toml = r#" +[services] +pgmustard_api_key = "${DRYRUN_PGM_MISSING}" +"#; + let config = ProjectConfig::parse(toml).unwrap(); + assert!(config.pgmustard_api_key().is_none()); + } } diff --git a/crates/dry_run_core/src/connection.rs b/crates/dry_run_core/src/connection.rs index 409f08c..f4f300e 100644 --- a/crates/dry_run_core/src/connection.rs +++ b/crates/dry_run_core/src/connection.rs @@ -1,8 +1,8 @@ 
 use std::time::Duration;
 
 use serde::{Deserialize, Serialize};
-use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
 use sqlx::PgPool;
+use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
 use tracing::{debug, info};
 
 use crate::error::{Error, Result};
@@ -89,6 +89,13 @@ impl DryRun {
         crate::schema::fetch_is_standby(&self.pool).await
     }
 
+    pub async fn current_database(&self) -> Result<String> {
+        let dbname: String = sqlx::query_scalar("SELECT current_database()")
+            .fetch_one(&self.pool)
+            .await?;
+        Ok(dbname)
+    }
+
     pub fn pool(&self) -> &PgPool {
         &self.pool
     }
diff --git a/crates/dry_run_core/src/diff/changeset.rs b/crates/dry_run_core/src/diff/changeset.rs
index 10d8598..326ed3e 100644
--- a/crates/dry_run_core/src/diff/changeset.rs
+++ b/crates/dry_run_core/src/diff/changeset.rs
@@ -300,15 +300,16 @@ fn diff_views(from: &[View], to: &[View], changes: &mut Vec<Change>) {
     }
     for (key, old) in &from_map {
         if let Some(new) = to_map.get(key)
-            && old.definition != new.definition {
-            changes.push(Change {
-                kind: ChangeKind::Modified,
-                object_type: "view".into(),
-                schema: Some(old.schema.clone()),
-                name: old.name.clone(),
-                details: vec!["definition changed".into()],
-            });
-        }
+            && old.definition != new.definition
+        {
+            changes.push(Change {
+                kind: ChangeKind::Modified,
+                object_type: "view".into(),
+                schema: Some(old.schema.clone()),
+                name: old.name.clone(),
+                details: vec!["definition changed".into()],
+            });
+        }
     }
 }
 
@@ -416,14 +417,15 @@ fn diff_named(
     }
     for (key, old) in &from_map {
         if let Some(new) = to_map.get(key)
-            && old != new {
-            changes.push(Change {
-                kind: ChangeKind::Modified,
-                object_type: object_type.into(),
-                schema: None,
-                name: key.clone(),
-                details: vec!["definition changed".into()],
-            });
-        }
+            && old != new
+        {
+            changes.push(Change {
+                kind: ChangeKind::Modified,
+                object_type: object_type.into(),
+                schema: None,
+                name: key.clone(),
+                details: vec!["definition changed".into()],
+            });
+        }
     }
 }
diff --git a/crates/dry_run_core/src/diff/mod.rs b/crates/dry_run_core/src/diff/mod.rs
index 1b2088d..af4674e 100644
--- a/crates/dry_run_core/src/diff/mod.rs
+++ b/crates/dry_run_core/src/diff/mod.rs
@@ -158,8 +158,8 @@ mod tests {
         local.tables.push(shared);
 
         let report = classify_drift(&prod, &local);
-        assert_eq!(report.summary.ahead, 1); // local_only
-        assert_eq!(report.summary.behind, 1); // prod_only
+        assert_eq!(report.summary.ahead, 1); // local_only
+        assert_eq!(report.summary.behind, 1); // prod_only
         assert_eq!(report.summary.diverged, 1); // shared (modified)
         assert_eq!(report.entries.len(), 3);
     }
diff --git a/crates/dry_run_core/src/error.rs b/crates/dry_run_core/src/error.rs
index bf2ea2f..0f7425c 100644
--- a/crates/dry_run_core/src/error.rs
+++ b/crates/dry_run_core/src/error.rs
@@ -28,4 +28,10 @@ pub enum Error {
     Database(#[from] sqlx::Error),
 }
 
+impl From<rusqlite::Error> for Error {
+    fn from(e: rusqlite::Error) -> Self {
+        Error::History(e.to_string())
+    }
+}
+
 pub type Result<T> = std::result::Result<T, Error>;
diff --git a/crates/dry_run_core/src/history/mod.rs b/crates/dry_run_core/src/history/mod.rs
index e7a2624..b37282b 100644
--- a/crates/dry_run_core/src/history/mod.rs
+++ b/crates/dry_run_core/src/history/mod.rs
@@ -1,3 +1,7 @@
+mod snapshot_store;
 mod store;
 
-pub use store::{default_data_dir, HistoryStore};
+pub use snapshot_store::{
+    DatabaseId, ProjectId, PutOutcome, SnapshotKey, SnapshotRef, SnapshotStore, TimeRange,
+};
+pub use store::{HistoryStore, SnapshotSummary, default_data_dir};
diff --git a/crates/dry_run_core/src/history/snapshot_store.rs b/crates/dry_run_core/src/history/snapshot_store.rs
new file mode 100644
index 0000000..4da5b7c
--- /dev/null
+++ b/crates/dry_run_core/src/history/snapshot_store.rs
@@ -0,0 +1,48 @@
+use async_trait::async_trait;
+use chrono::{DateTime, Utc};
+use serde::{Deserialize, Serialize};
+
+use crate::error::Result;
+use crate::schema::SchemaSnapshot;
+
+pub use super::store::SnapshotSummary;
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub struct ProjectId(pub String);
+
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub struct DatabaseId(pub String);
+
+#[derive(Debug, Clone)]
+pub struct SnapshotKey {
+    pub project_id: ProjectId,
+    pub database_id: DatabaseId,
+}
+
+#[derive(Debug, Clone)]
+pub enum SnapshotRef {
+    Latest,
+    At(DateTime<Utc>),
+    Hash(String),
+}
+
+#[derive(Debug, Clone, Default)]
+pub struct TimeRange {
+    pub from: Option<DateTime<Utc>>,
+    pub to: Option<DateTime<Utc>>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum PutOutcome {
+    Inserted,
+    Deduped,
+}
+
+#[async_trait]
+pub trait SnapshotStore: Send + Sync {
+    async fn put(&self, key: &SnapshotKey, snap: &SchemaSnapshot) -> Result<PutOutcome>;
+    async fn get(&self, key: &SnapshotKey, at: SnapshotRef) -> Result<SchemaSnapshot>;
+    async fn list(&self, key: &SnapshotKey, range: TimeRange) -> Result<Vec<SnapshotSummary>>;
+    async fn latest(&self, key: &SnapshotKey) -> Result<Option<SnapshotSummary>>;
+    async fn delete_before(&self, key: &SnapshotKey, cutoff: DateTime<Utc>) -> Result<usize>;
+}
diff --git a/crates/dry_run_core/src/history/store.rs b/crates/dry_run_core/src/history/store.rs
index b37090f..ef90152 100644
--- a/crates/dry_run_core/src/history/store.rs
+++ b/crates/dry_run_core/src/history/store.rs
@@ -1,24 +1,29 @@
 use std::path::{Path, PathBuf};
+use std::sync::{Arc, Mutex};
 
+use async_trait::async_trait;
 use chrono::{DateTime, Utc};
-use rusqlite::{params, Connection};
-use sha2::{Digest, Sha256};
+use rusqlite::{Connection, params};
 use tracing::{debug, info};
 
 use crate::error::{Error, Result};
+use crate::history::snapshot_store::{
+    PutOutcome, SnapshotKey, SnapshotRef, SnapshotStore, TimeRange,
+};
 use crate::schema::SchemaSnapshot;
 
 pub struct HistoryStore {
-    conn: Connection,
+    conn: Arc<Mutex<Connection>>,
 }
 
 #[derive(Debug, Clone)]
 pub struct SnapshotSummary {
     pub id: i64,
-    pub db_url_hash: String,
     pub timestamp: DateTime<Utc>,
     pub content_hash: String,
     pub database: String,
+    pub project_id: Option<String>,
+    pub database_id: Option<String>,
 }
 
 impl HistoryStore {
@@ -31,7 +36,9 @@ impl HistoryStore {
         let conn = Connection::open(path)
             .map_err(|e| Error::History(format!("cannot open history db: {e}")))?;
 
-        let store = Self { conn };
+        let store = Self {
+            conn: Arc::new(Mutex::new(conn)),
+        };
         store.migrate()?;
 
         debug!(path = %path.display(), "history store opened");
@@ -43,179 +50,41 @@ impl HistoryStore {
         Self::open(&path)
     }
 
-    // saves snapshot, returns false if content_hash unchanged from latest
-    pub fn save_snapshot(&self, db_url: &str, snapshot: &SchemaSnapshot) -> Result<bool> {
-        let db_url_hash = hash_url(db_url);
-
-        let latest_hash: Option<String> = self
-            .conn
-            .query_row(
-                "SELECT content_hash FROM snapshots
-                 WHERE db_url_hash = ?1
-                 ORDER BY timestamp DESC LIMIT 1",
-                params![db_url_hash],
-                |row| row.get(0),
-            )
-            .ok();
-
-        if latest_hash.as_deref() == Some(&snapshot.content_hash) {
-            debug!(hash = %snapshot.content_hash, "schema unchanged, skipping save");
-            return Ok(false);
-        }
-
-        let json = serde_json::to_string(snapshot)
-            .map_err(|e| Error::History(format!("cannot serialize snapshot: {e}")))?;
-
-        self.conn
-            .execute(
-                "INSERT INTO snapshots (db_url_hash, timestamp, content_hash, database_name, snapshot_json)
-                 VALUES (?1, ?2, ?3, ?4, ?5)",
-                params![
-                    db_url_hash,
-                    snapshot.timestamp.to_rfc3339(),
-                    snapshot.content_hash,
-                    snapshot.database,
-                    json,
-                ],
-            )
-            .map_err(|e| Error::History(format!("cannot save snapshot: {e}")))?;
-
-        info!(
-            hash = %snapshot.content_hash,
-            database = %snapshot.database,
-            "snapshot saved"
-        );
-        Ok(true)
-    }
-
-    pub fn load_snapshot(&self, content_hash: &str) -> Result<Option<SchemaSnapshot>> {
-        let json: Option<String> = self
-            .conn
-            .query_row(
-                "SELECT snapshot_json FROM snapshots WHERE content_hash = ?1 LIMIT 1",
-                params![content_hash],
-                |row| row.get(0),
-            )
-            .ok();
-
-        match json {
-            Some(j) => {
-                let snapshot: SchemaSnapshot = serde_json::from_str(&j)
-                    .map_err(|e| Error::History(format!("corrupt snapshot JSON: {e}")))?;
-                Ok(Some(snapshot))
-            }
-            None => Ok(None),
-        }
-    }
-
-    pub fn list_snapshots(&self, db_url: &str) -> Result<Vec<SnapshotSummary>> {
-        let db_url_hash = hash_url(db_url);
-
-        let mut stmt = self
-            .conn
-            .prepare(
-                "SELECT id, db_url_hash, timestamp, content_hash, database_name
-                 FROM snapshots
-                 WHERE db_url_hash = ?1
-                 ORDER BY timestamp DESC",
-            )
-            .map_err(|e| Error::History(e.to_string()))?;
-
-        let rows = stmt
-            .query_map(params![db_url_hash], |row| {
-                let ts_str: String = row.get(2)?;
-                Ok(SnapshotSummary {
-                    id: row.get(0)?,
-                    db_url_hash: row.get(1)?,
-                    timestamp: DateTime::parse_from_rfc3339(&ts_str)
-                        .map(|dt| dt.with_timezone(&Utc))
-                        .unwrap_or_default(),
-                    content_hash: row.get(3)?,
-                    database: row.get(4)?,
-                })
+    pub fn list_keys(&self) -> Result<Vec<SnapshotKey>> {
+        let conn = lock_conn(&self.conn)?;
+        let mut stmt = conn.prepare(
+            "SELECT DISTINCT project_id, database_id
+             FROM snapshots
+             WHERE project_id IS NOT NULL AND database_id IS NOT NULL
+             ORDER BY project_id, database_id",
+        )?;
+        let rows = stmt.query_map([], |row| {
+            let pid: String = row.get(0)?;
+            let did: String = row.get(1)?;
+            Ok(SnapshotKey {
+                project_id: crate::history::ProjectId(pid),
+                database_id: crate::history::DatabaseId(did),
             })
-            .map_err(|e| Error::History(e.to_string()))?;
-
-        let mut summaries = Vec::new();
-        for row in rows {
-            summaries.push(row.map_err(|e| Error::History(e.to_string()))?);
-        }
-        Ok(summaries)
-    }
-
-    pub fn snapshots_since(
-        &self,
-        db_url: &str,
-        since: DateTime<Utc>,
-    ) -> Result<Vec<SchemaSnapshot>> {
-        let db_url_hash = hash_url(db_url);
-
-        let mut stmt = self
-            .conn
-            .prepare(
-                "SELECT snapshot_json FROM snapshots
-                 WHERE db_url_hash = ?1 AND timestamp >= ?2
-                 ORDER BY timestamp ASC",
-            )
-            .map_err(|e| Error::History(e.to_string()))?;
-
-        let rows = stmt
-            .query_map(params![db_url_hash, since.to_rfc3339()], |row| {
-                row.get::<_, String>(0)
-            })
-            .map_err(|e| Error::History(e.to_string()))?;
-
-        let mut snapshots = Vec::new();
-        for row in rows {
-            let json = row.map_err(|e| Error::History(e.to_string()))?;
-            let snapshot: SchemaSnapshot = serde_json::from_str(&json)
-                .map_err(|e| Error::History(format!("corrupt snapshot JSON: {e}")))?;
-            snapshots.push(snapshot);
-        }
-        Ok(snapshots)
-    }
-
-    pub fn latest_snapshot(&self, db_url: &str) -> Result<Option<SchemaSnapshot>> {
-        let db_url_hash = hash_url(db_url);
-
-        let json: Option<String> = self
-            .conn
-            .query_row(
-                "SELECT snapshot_json FROM snapshots
-                 WHERE db_url_hash = ?1
-                 ORDER BY timestamp DESC LIMIT 1",
-                params![db_url_hash],
-                |row| row.get(0),
-            )
-            .ok();
-
-        match json {
-            Some(j) => {
-                let snapshot: SchemaSnapshot = serde_json::from_str(&j)
-                    .map_err(|e| Error::History(format!("corrupt snapshot JSON: {e}")))?;
-                Ok(Some(snapshot))
-            }
-            None => Ok(None),
-        }
+        })?;
+        rows.map(|r| r.map_err(Error::from)).collect()
     }
 
     fn migrate(&self) -> Result<()> {
-        self.conn
-            .execute_batch(
-                "CREATE TABLE IF NOT EXISTS snapshots (
+        let conn = lock_conn(&self.conn)?;
+        conn.execute_batch(
+            "CREATE TABLE IF NOT EXISTS snapshots (
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
-                db_url_hash TEXT NOT NULL,
                 timestamp TEXT NOT NULL,
                 content_hash TEXT NOT NULL,
                 database_name TEXT NOT NULL,
-                snapshot_json TEXT NOT NULL
+                snapshot_json TEXT NOT NULL,
+                project_id TEXT,
+                database_id TEXT
             );
-            CREATE INDEX IF NOT EXISTS idx_snapshots_db_url_hash
-                ON snapshots(db_url_hash, timestamp DESC);
             CREATE INDEX IF NOT EXISTS idx_snapshots_content_hash
                 ON snapshots(content_hash);",
-            )
-            .map_err(|e| Error::History(format!("migration failed: {e}")))?;
+        )
+        .map_err(|e| Error::History(format!("migration failed: {e}")))?;
         Ok(())
     }
 }
@@ -231,37 +100,221 @@ pub fn default_data_dir() -> Result<PathBuf> {
     Ok(cwd.join(".dryrun"))
 }
 
-fn hash_url(url: &str) -> String {
-    let digest = Sha256::digest(url.as_bytes());
-    let hex: String = digest.iter().fold(String::new(), |mut s, b| {
-        use std::fmt::Write;
-        write!(s, "{b:02x}").expect("write to String cannot fail");
-        s
-    });
-    hex[..16].to_string()
+fn lock_conn(conn: &Mutex<Connection>) -> Result<std::sync::MutexGuard<'_, Connection>> {
+    conn.lock()
+        .map_err(|e| Error::History(format!("lock poisoned: {e}")))
+}
+
+fn row_to_summary(row: &rusqlite::Row<'_>) -> rusqlite::Result<SnapshotSummary> {
+    let ts_str: String = row.get(1)?;
+    Ok(SnapshotSummary {
+        id: row.get(0)?,
+        timestamp: DateTime::parse_from_rfc3339(&ts_str)
+            .map(|dt| dt.with_timezone(&Utc))
+            .unwrap_or_default(),
+        content_hash: row.get(2)?,
+        database: row.get(3)?,
+        project_id: row.get(4)?,
+        database_id: row.get(5)?,
+    })
+}
+
+async fn run_blocking<F, T>(conn: &Arc<Mutex<Connection>>, f: F) -> Result<T>
+where
+    F: FnOnce(&Connection) -> Result<T> + Send + 'static,
+    T: Send + 'static,
+{
+    let conn = conn.clone();
+    tokio::task::spawn_blocking(move || -> Result<T> {
+        let conn = conn
+            .lock()
+            .map_err(|e| Error::History(format!("lock poisoned: {e}")))?;
+        f(&conn)
+    })
+    .await
+    .map_err(|e| Error::History(format!("blocking task failed: {e}")))?
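+    // Design note: a single rusqlite Connection behind a Mutex serializes every
+    // store operation, which is presumably fine for a local snapshot database;
+    // spawn_blocking keeps the synchronous SQLite I/O off the async workers.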
+}
+
+#[async_trait]
+impl SnapshotStore for HistoryStore {
+    async fn put(&self, key: &SnapshotKey, snap: &SchemaSnapshot) -> Result<PutOutcome> {
+        let key = key.clone();
+        let snap = snap.clone();
+        run_blocking(&self.conn, move |conn| {
+            let pid = &key.project_id.0;
+            let did = &key.database_id.0;
+
+            let latest: Option<String> = conn
+                .query_row(
+                    "SELECT content_hash FROM snapshots
+                     WHERE project_id = ?1 AND database_id = ?2
+                     ORDER BY timestamp DESC LIMIT 1",
+                    params![pid, did],
+                    |row| row.get(0),
+                )
+                .ok();
+
+            if latest.as_deref() == Some(snap.content_hash.as_str()) {
+                debug!(hash = %snap.content_hash, "schema unchanged, skipping put");
+                return Ok(PutOutcome::Deduped);
+            }
+
+            let json = serde_json::to_string(&snap)
+                .map_err(|e| Error::History(format!("cannot serialize snapshot: {e}")))?;
+
+            conn.execute(
+                "INSERT INTO snapshots (timestamp, content_hash, database_name,
+                                        snapshot_json, project_id, database_id)
+                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
+                params![
+                    snap.timestamp.to_rfc3339(),
+                    snap.content_hash,
+                    snap.database,
+                    json,
+                    pid,
+                    did,
+                ],
+            )?;
+
+            info!(hash = %snap.content_hash, project = %pid, database = %did, "snapshot put");
+            Ok(PutOutcome::Inserted)
+        })
+        .await
+    }
+
+    async fn get(&self, key: &SnapshotKey, at: SnapshotRef) -> Result<SchemaSnapshot> {
+        let pid = key.project_id.0.clone();
+        let did = key.database_id.0.clone();
+        run_blocking(&self.conn, move |conn| {
+            let row = match &at {
+                SnapshotRef::Latest => conn.query_row(
+                    "SELECT snapshot_json FROM snapshots
+                     WHERE project_id = ?1 AND database_id = ?2
+                     ORDER BY timestamp DESC LIMIT 1",
+                    params![pid, did],
+                    |r| r.get::<_, String>(0),
+                ),
+                SnapshotRef::At(ts) => conn.query_row(
+                    "SELECT snapshot_json FROM snapshots
+                     WHERE project_id = ?1 AND database_id = ?2 AND timestamp <= ?3
+                     ORDER BY timestamp DESC LIMIT 1",
+                    params![pid, did, ts.to_rfc3339()],
+                    |r| r.get::<_, String>(0),
+                ),
+                SnapshotRef::Hash(h) => conn.query_row(
+                    "SELECT snapshot_json FROM snapshots
+                     WHERE project_id = ?1 AND database_id = ?2 AND content_hash = ?3
+                     LIMIT 1",
+                    params![pid, did, h],
+                    |r| r.get::<_, String>(0),
+                ),
+            };
+
+            let json = match row {
+                Ok(j) => j,
+                Err(rusqlite::Error::QueryReturnedNoRows) => {
+                    let detail = match at {
+                        SnapshotRef::Latest => "latest".to_string(),
+                        SnapshotRef::At(ts) => format!("at-or-before {ts}"),
+                        SnapshotRef::Hash(h) => format!("hash {h}"),
+                    };
+                    return Err(Error::History(format!("snapshot not found ({detail})")));
+                }
+                Err(e) => return Err(e.into()),
+            };
+
+            serde_json::from_str(&json)
+                .map_err(|e| Error::History(format!("corrupt snapshot JSON: {e}")))
+        })
+        .await
+    }
+
+    async fn list(&self, key: &SnapshotKey, range: TimeRange) -> Result<Vec<SnapshotSummary>> {
+        let pid = key.project_id.0.clone();
+        let did = key.database_id.0.clone();
+        run_blocking(&self.conn, move |conn| {
+            let mut sql = String::from(
+                "SELECT id, timestamp, content_hash, database_name,
+                        project_id, database_id
+                 FROM snapshots
+                 WHERE project_id = ?1 AND database_id = ?2",
+            );
+            let mut bound: Vec<Box<dyn rusqlite::ToSql>> = vec![Box::new(pid), Box::new(did)];
+            if let Some(from) = range.from {
+                sql += &format!(" AND timestamp >= ?{}", bound.len() + 1);
+                bound.push(Box::new(from.to_rfc3339()));
+            }
+            if let Some(to) = range.to {
+                sql += &format!(" AND timestamp < ?{}", bound.len() + 1);
+                bound.push(Box::new(to.to_rfc3339()));
+            }
+            sql += " ORDER BY timestamp DESC";
+
+            let mut stmt = conn.prepare(&sql)?;
+            let params: Vec<&dyn rusqlite::ToSql> = bound.iter().map(|b| b.as_ref()).collect();
+            stmt.query_map(params.as_slice(), row_to_summary)?
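+                // Rows come back as rusqlite::Result<SnapshotSummary>, so lift
+                // each per-row error into crate::Error before collecting.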
+                .map(|r| r.map_err(Error::from))
+                .collect()
+        })
+        .await
+    }
+
+    async fn latest(&self, key: &SnapshotKey) -> Result<Option<SnapshotSummary>> {
+        Ok(self
+            .list(key, TimeRange::default())
+            .await?
+            .into_iter()
+            .next())
+    }
+
+    async fn delete_before(&self, key: &SnapshotKey, cutoff: DateTime<Utc>) -> Result<usize> {
+        let pid = key.project_id.0.clone();
+        let did = key.database_id.0.clone();
+        run_blocking(&self.conn, move |conn| {
+            Ok(conn.execute(
+                "DELETE FROM snapshots
+                 WHERE project_id = ?1 AND database_id = ?2 AND timestamp < ?3",
+                params![pid, did, cutoff.to_rfc3339()],
+            )?)
+        })
+        .await
+    }
 }
 
 #[cfg(test)]
-mod tests {
-    use chrono::Utc;
+mod trait_tests {
+    use chrono::Duration;
     use tempfile::TempDir;
 
     use super::*;
-    use crate::schema::SchemaSnapshot;
+    use crate::history::snapshot_store::{DatabaseId, ProjectId};
 
-    fn make_snapshot(hash: &str, database: &str) -> SchemaSnapshot {
+    fn make_snap(hash: &str, database: &str) -> SchemaSnapshot {
         SchemaSnapshot {
             pg_version: "PostgreSQL 17.0".into(),
             database: database.into(),
             timestamp: Utc::now(),
             content_hash: hash.into(),
             source: None,
-            tables: vec![], enums: vec![], domains: vec![], composites: vec![],
-            views: vec![], functions: vec![], extensions: vec![], gucs: vec![],
+            tables: vec![],
+            enums: vec![],
+            domains: vec![],
+            composites: vec![],
+            views: vec![],
+            functions: vec![],
+            extensions: vec![],
+            gucs: vec![],
             node_stats: vec![],
         }
     }
 
+    fn key(proj: &str, db: &str) -> SnapshotKey {
+        SnapshotKey {
+            project_id: ProjectId(proj.into()),
+            database_id: DatabaseId(db.into()),
+        }
+    }
+
     fn temp_store() -> (TempDir, HistoryStore) {
         let dir = TempDir::new().unwrap();
         let path = dir.path().join("test_history.db");
@@ -269,78 +322,293 @@ mod tests {
         (dir, store)
     }
 
-    #[test]
-    fn save_and_load() {
+    #[tokio::test]
+    async fn put_inserts_then_dedupes() {
+        let (_dir, store) = temp_store();
+        let k = key("p", "auth");
+        let snap = make_snap("h1", "auth");
+
+        assert_eq!(store.put(&k, &snap).await.unwrap(), PutOutcome::Inserted);
+        assert_eq!(store.put(&k, &snap).await.unwrap(), PutOutcome::Deduped);
+    }
+
+    #[tokio::test]
+    async fn put_isolates_across_databases() {
         let (_dir, store) = temp_store();
-        let snap = make_snapshot("abc123", "mydb");
-        let url = "postgres://user@host/mydb";
+        let auth = key("p", "auth");
+        let billing = key("p", "billing");
 
-        assert!(store.save_snapshot(url, &snap).unwrap());
+        // same content_hash under different database_id should not dedupe
+        assert_eq!(
+            store.put(&auth, &make_snap("same", "auth")).await.unwrap(),
+            PutOutcome::Inserted
+        );
+        assert_eq!(
+            store
+                .put(&billing, &make_snap("same", "billing"))
+                .await
+                .unwrap(),
+            PutOutcome::Inserted
+        );
 
-        let loaded = store.load_snapshot("abc123").unwrap();
-        assert!(loaded.is_some());
-        assert_eq!(loaded.unwrap().content_hash, "abc123");
+        let auth_rows = store.list(&auth, TimeRange::default()).await.unwrap();
+        let billing_rows = store.list(&billing, TimeRange::default()).await.unwrap();
+        assert_eq!(auth_rows.len(), 1);
+        assert_eq!(billing_rows.len(), 1);
+        assert_eq!(auth_rows[0].database_id.as_deref(), Some("auth"));
+        assert_eq!(billing_rows[0].database_id.as_deref(), Some("billing"));
     }
 
-    #[test]
-    fn skip_duplicate_hash() {
+    #[tokio::test]
+    async fn put_isolates_across_projects() {
         let (_dir, store) = temp_store();
-        let url = "postgres://user@host/mydb";
+        let a = key("a", "x");
+        let b = key("b", "x");
+        store.put(&a, &make_snap("h", "x")).await.unwrap();
+        store.put(&b, &make_snap("h", "x")).await.unwrap();
+
+        let a_rows = store.list(&a,
TimeRange::default()).await.unwrap(); + let b_rows = store.list(&b, TimeRange::default()).await.unwrap(); + assert_eq!(a_rows.len(), 1); + assert_eq!(b_rows.len(), 1); + assert_eq!(a_rows[0].project_id.as_deref(), Some("a")); + assert_eq!(b_rows[0].project_id.as_deref(), Some("b")); + } - assert!(store.save_snapshot(url, &make_snapshot("same_hash", "mydb")).unwrap()); - assert!(!store.save_snapshot(url, &make_snapshot("same_hash", "mydb")).unwrap()); + #[tokio::test] + async fn list_orders_newest_first() { + let (_dir, store) = temp_store(); + let k = key("p", "x"); + let mut s1 = make_snap("h1", "x"); + s1.timestamp = Utc::now() - Duration::hours(2); + let mut s2 = make_snap("h2", "x"); + s2.timestamp = Utc::now() - Duration::hours(1); + store.put(&k, &s1).await.unwrap(); + store.put(&k, &s2).await.unwrap(); + + let rows = store.list(&k, TimeRange::default()).await.unwrap(); + assert_eq!(rows.len(), 2); + assert_eq!(rows[0].content_hash, "h2"); + assert_eq!(rows[1].content_hash, "h1"); } - #[test] - fn list_snapshots_order() { + #[tokio::test] + async fn list_filters_by_time_range() { let (_dir, store) = temp_store(); - let url = "postgres://user@host/mydb"; + let k = key("p", "x"); + let now = Utc::now(); + for (i, hash) in ["h0", "h1", "h2"].iter().enumerate() { + let mut s = make_snap(hash, "x"); + s.timestamp = now - Duration::hours(2 - i as i64); + store.put(&k, &s).await.unwrap(); + } + + // from = -90min: h0 at -2h is excluded, h1 at -1h and h2 at 0 included + let rows = store + .list( + &k, + TimeRange { + from: Some(now - Duration::minutes(90)), + to: None, + }, + ) + .await + .unwrap(); + assert_eq!(rows.len(), 2); + assert_eq!(rows[0].content_hash, "h2"); + assert_eq!(rows[1].content_hash, "h1"); + + // to = -30min (exclusive): h2 at 0 excluded, h0 and h1 included + let rows = store + .list( + &k, + TimeRange { + from: None, + to: Some(now - Duration::minutes(30)), + }, + ) + .await + .unwrap(); + assert_eq!(rows.len(), 2); + assert_eq!(rows[0].content_hash, "h1"); + assert_eq!(rows[1].content_hash, "h0"); + } - let mut s1 = make_snapshot("hash1", "mydb"); - s1.timestamp = Utc::now() - chrono::Duration::hours(2); - store.save_snapshot(url, &s1).unwrap(); + #[tokio::test] + async fn latest_returns_most_recent_or_none() { + let (_dir, store) = temp_store(); + let k = key("p", "x"); + assert!(store.latest(&k).await.unwrap().is_none()); - let mut s2 = make_snapshot("hash2", "mydb"); - s2.timestamp = Utc::now() - chrono::Duration::hours(1); - store.save_snapshot(url, &s2).unwrap(); + let mut s1 = make_snap("old", "x"); + s1.timestamp = Utc::now() - Duration::hours(1); + let s2 = make_snap("new", "x"); + store.put(&k, &s1).await.unwrap(); + store.put(&k, &s2).await.unwrap(); - let list = store.list_snapshots(url).unwrap(); - assert_eq!(list.len(), 2); - assert_eq!(list[0].content_hash, "hash2"); // newest first - assert_eq!(list[1].content_hash, "hash1"); + let latest = store.latest(&k).await.unwrap().unwrap(); + assert_eq!(latest.content_hash, "new"); } - #[test] - fn latest_snapshot() { + #[tokio::test] + async fn get_latest_returns_most_recent() { let (_dir, store) = temp_store(); - let url = "postgres://user@host/mydb"; + let k = key("p", "x"); + let mut s1 = make_snap("old", "x"); + s1.timestamp = Utc::now() - Duration::hours(1); + let s2 = make_snap("new", "x"); + store.put(&k, &s1).await.unwrap(); + store.put(&k, &s2).await.unwrap(); + + let got = store.get(&k, SnapshotRef::Latest).await.unwrap(); + assert_eq!(got.content_hash, "new"); + } - let mut s1 = make_snapshot("old", 
"mydb"); - s1.timestamp = Utc::now() - chrono::Duration::hours(1); - store.save_snapshot(url, &s1).unwrap(); + #[tokio::test] + async fn get_at_returns_at_or_before() { + let (_dir, store) = temp_store(); + let k = key("p", "x"); + let now = Utc::now(); + let mut s1 = make_snap("h1", "x"); + s1.timestamp = now - Duration::hours(2); + let mut s2 = make_snap("h2", "x"); + s2.timestamp = now; + store.put(&k, &s1).await.unwrap(); + store.put(&k, &s2).await.unwrap(); + + // at -1h: h2 is in the future, only h1 qualifies + let got = store + .get(&k, SnapshotRef::At(now - Duration::hours(1))) + .await + .unwrap(); + assert_eq!(got.content_hash, "h1"); + } - let s2 = make_snapshot("new", "mydb"); - store.save_snapshot(url, &s2).unwrap(); + #[tokio::test] + async fn get_hash_returns_matching_scoped_to_key() { + let (_dir, store) = temp_store(); + let a = key("p", "auth"); + let b = key("p", "billing"); + store.put(&a, &make_snap("shared", "auth")).await.unwrap(); + + // direct lookup under correct key works + let got = store + .get(&a, SnapshotRef::Hash("shared".into())) + .await + .unwrap(); + assert_eq!(got.content_hash, "shared"); + + // same hash under different key fails — content_hash lookup is key-scoped + let result = store.get(&b, SnapshotRef::Hash("shared".into())).await; + assert!(result.is_err()); + } - let latest = store.latest_snapshot(url).unwrap().unwrap(); - assert_eq!(latest.content_hash, "new"); + #[tokio::test] + async fn get_missing_returns_error() { + let (_dir, store) = temp_store(); + let k = key("p", "x"); + assert!(store.get(&k, SnapshotRef::Latest).await.is_err()); + assert!( + store + .get(&k, SnapshotRef::Hash("nope".into())) + .await + .is_err() + ); + assert!(store.get(&k, SnapshotRef::At(Utc::now())).await.is_err()); } - #[test] - fn different_urls_isolated() { + #[tokio::test] + async fn delete_before_returns_count_and_removes_old() { let (_dir, store) = temp_store(); - let url1 = "postgres://user@host/db1"; - let url2 = "postgres://user@host/db2"; + let k = key("p", "x"); + let now = Utc::now(); + for (i, hash) in ["h0", "h1", "h2", "h3"].iter().enumerate() { + let mut s = make_snap(hash, "x"); + s.timestamp = now - Duration::hours(3 - i as i64); + store.put(&k, &s).await.unwrap(); + } + + let deleted = store + .delete_before(&k, now - Duration::minutes(90)) + .await + .unwrap(); + assert_eq!(deleted, 2); // h0 (-3h) and h1 (-2h) - store.save_snapshot(url1, &make_snapshot("h1", "db1")).unwrap(); - store.save_snapshot(url2, &make_snapshot("h2", "db2")).unwrap(); + let remaining = store.list(&k, TimeRange::default()).await.unwrap(); + assert_eq!(remaining.len(), 2); + assert_eq!(remaining[0].content_hash, "h3"); + assert_eq!(remaining[1].content_hash, "h2"); + } - let list1 = store.list_snapshots(url1).unwrap(); - assert_eq!(list1.len(), 1); - assert_eq!(list1[0].content_hash, "h1"); + #[tokio::test] + async fn delete_before_scoped_to_key() { + let (_dir, store) = temp_store(); + let a = key("p", "auth"); + let b = key("p", "billing"); + let mut s = make_snap("h", "auth"); + s.timestamp = Utc::now() - Duration::hours(2); + store.put(&a, &s).await.unwrap(); + let mut s = make_snap("h", "billing"); + s.timestamp = Utc::now() - Duration::hours(2); + store.put(&b, &s).await.unwrap(); + + // delete in `a` should not touch `b` + let deleted = store + .delete_before(&a, Utc::now() - Duration::hours(1)) + .await + .unwrap(); + assert_eq!(deleted, 1); + assert_eq!(store.list(&a, TimeRange::default()).await.unwrap().len(), 0); + assert_eq!(store.list(&b, 
TimeRange::default()).await.unwrap().len(), 1); + } - let list2 = store.list_snapshots(url2).unwrap(); - assert_eq!(list2.len(), 1); - assert_eq!(list2[0].content_hash, "h2"); + #[tokio::test] + async fn list_keys_returns_distinct_streams_ordered() { + let (_dir, store) = temp_store(); + // empty store + assert!(store.list_keys().unwrap().is_empty()); + + // put under three streams, with one stream getting two snapshots + store + .put(&key("p", "billing"), &make_snap("h1", "billing")) + .await + .unwrap(); + store + .put(&key("p", "auth"), &make_snap("h2", "auth")) + .await + .unwrap(); + store + .put(&key("p", "auth"), &make_snap("h3", "auth")) + .await + .unwrap(); + store + .put(&key("other", "auth"), &make_snap("h4", "auth")) + .await + .unwrap(); + + let keys = store.list_keys().unwrap(); + // three distinct (project, database) pairs, ordered by project then database + assert_eq!(keys.len(), 3); + assert_eq!( + ( + keys[0].project_id.0.as_str(), + keys[0].database_id.0.as_str() + ), + ("other", "auth") + ); + assert_eq!( + ( + keys[1].project_id.0.as_str(), + keys[1].database_id.0.as_str() + ), + ("p", "auth") + ); + assert_eq!( + ( + keys[2].project_id.0.as_str(), + keys[2].database_id.0.as_str() + ), + ("p", "billing") + ); } } diff --git a/crates/dry_run_core/src/jit.rs b/crates/dry_run_core/src/jit.rs index 81546d4..f0ab058 100644 --- a/crates/dry_run_core/src/jit.rs +++ b/crates/dry_run_core/src/jit.rs @@ -67,12 +67,7 @@ pub fn add_column_volatile_default( } } -pub fn add_column_pre_pg11( - table: &str, - col: &str, - col_type: &str, - default_expr: &str, -) -> Entry { +pub fn add_column_pre_pg11(table: &str, col: &str, col_type: &str, default_expr: &str) -> Entry { Entry { status: "unsafe".into(), reason: format!( @@ -142,12 +137,7 @@ pub fn set_not_null(table: &str, col: &str, pg_major: u32) -> Entry { } } -pub fn add_foreign_key_unsafe( - table: &str, - col: &str, - ref_table: &str, - ref_col: &str, -) -> Entry { +pub fn add_foreign_key_unsafe(table: &str, col: &str, ref_table: &str, ref_col: &str) -> Entry { Entry { status: "unsafe".into(), reason: format!( @@ -177,12 +167,7 @@ pub fn add_check_constraint_unsafe(table: &str, constraint_expr: &str) -> Entry } } -pub fn create_index_blocking( - table: &str, - idx_name: &str, - method: &str, - columns: &str, -) -> Entry { +pub fn create_index_blocking(table: &str, idx_name: &str, method: &str, columns: &str) -> Entry { Entry { status: "unsafe".into(), reason: format!( @@ -278,9 +263,7 @@ pub fn suggest_gin(table: &str, col: &str, col_type: &str) -> Entry { reason: format!( "Column `{table}.{col}` ({col_type}) would benefit from a GIN index for containment and existence queries." ), - fix: format!( - "CREATE INDEX CONCURRENTLY ON {table} USING gin ({col});" - ), + fix: format!("CREATE INDEX CONCURRENTLY ON {table} USING gin ({col});"), note: Some("GIN indexes are ideal for JSONB, arrays, and full-text search columns.".into()), } } @@ -291,10 +274,10 @@ pub fn suggest_gist(table: &str, col: &str, col_type: &str) -> Entry { reason: format!( "Column `{table}.{col}` ({col_type}) would benefit from a GiST index for range or spatial queries." 
         ),
-        fix: format!(
-            "CREATE INDEX CONCURRENTLY ON {table} USING gist ({col});"
+        fix: format!("CREATE INDEX CONCURRENTLY ON {table} USING gist ({col});"),
+        note: Some(
+            "GiST indexes are ideal for range types, geometric types, and inet/cidr.".into(),
         ),
-        note: Some("GiST indexes are ideal for range types, geometric types, and inet/cidr.".into()),
     }
 }
 
@@ -304,9 +287,7 @@ pub fn suggest_partial_index(table: &str, col: &str, predicate: &str) -> Entry {
         reason: format!(
             "Column `{table}.{col}` is mostly filtered with `{predicate}`. A partial index avoids indexing irrelevant rows."
         ),
-        fix: format!(
-            "CREATE INDEX CONCURRENTLY ON {table} ({col}) WHERE {predicate};"
-        ),
+        fix: format!("CREATE INDEX CONCURRENTLY ON {table} ({col}) WHERE {predicate};"),
         note: None,
     }
 }
@@ -402,9 +383,7 @@ pub fn partition_no_default(parent: &str) -> Entry {
         reason: format!(
             "Partitioned table `{parent}` has no DEFAULT partition. Rows that don't match any partition boundary will be rejected."
         ),
-        fix: format!(
-            "CREATE TABLE {parent}_default PARTITION OF {parent} DEFAULT;"
-        ),
+        fix: format!("CREATE TABLE {parent}_default PARTITION OF {parent} DEFAULT;"),
         note: None,
     }
 }
diff --git a/crates/dry_run_core/src/lib.rs b/crates/dry_run_core/src/lib.rs
index 4526a9b..131bd0e 100644
--- a/crates/dry_run_core/src/lib.rs
+++ b/crates/dry_run_core/src/lib.rs
@@ -11,7 +11,7 @@ pub mod schema;
 pub mod version;
 
 pub use audit::AuditConfig;
-pub use config::{ConnectionConfig, ProjectConfig};
+pub use config::{ConnectionConfig, ProjectConfig, ResolvedProfile};
 pub use connection::{DryRun, PrivilegeReport, ProbeResult};
 pub use diff::SchemaChangeset;
 pub use error::{Error, Result};
diff --git a/crates/dry_run_core/src/lint/mod.rs b/crates/dry_run_core/src/lint/mod.rs
index 03ca660..be2e3f4 100644
--- a/crates/dry_run_core/src/lint/mod.rs
+++ b/crates/dry_run_core/src/lint/mod.rs
@@ -72,8 +72,7 @@ pub fn compact_report(report: &LintReport, max_examples: usize) -> LintReportCom
             .then(b.count.cmp(&a.count))
     });
 
-    let total_violations =
-        report.summary.errors + report.summary.warnings + report.summary.info;
+    let total_violations = report.summary.errors + report.summary.warnings + report.summary.info;
 
     LintReportCompact {
         tables_checked: report.tables_checked,
diff --git a/crates/dry_run_core/src/lint/rules/constraints.rs b/crates/dry_run_core/src/lint/rules/constraints.rs
index e92c856..3651066 100644
--- a/crates/dry_run_core/src/lint/rules/constraints.rs
+++ b/crates/dry_run_core/src/lint/rules/constraints.rs
@@ -54,7 +54,11 @@ pub fn check_fk_has_index(
     }
 }
 
-pub fn check_unnamed_constraints(table: &Table, qualified: &str, violations: &mut Vec<LintViolation>) {
+pub fn check_unnamed_constraints(
+    table: &Table,
+    qualified: &str,
+    violations: &mut Vec<LintViolation>,
+) {
     for constraint in &table.constraints {
         let name = &constraint.name;
         let is_auto = name.ends_with("_pkey")
@@ -72,7 +76,7 @@ pub fn check_unnamed_constraints(table: &Table, qualified: &str, violations: &mu
             message: format!("constraint '{}' appears to be auto-generated", name),
             recommendation: "name constraints explicitly for readable error messages".into(),
             ddl_fix: None,
-            convention_doc: "constraints".into(),
+            convention_doc: "constraints".into(),
         });
     }
 }
diff --git a/crates/dry_run_core/src/lint/rules/mod.rs b/crates/dry_run_core/src/lint/rules/mod.rs
index 1a8f2e0..196738c 100644
--- a/crates/dry_run_core/src/lint/rules/mod.rs
+++ b/crates/dry_run_core/src/lint/rules/mod.rs
@@ -212,17 +212,29 @@ mod tests {
     fn make_col(name: &str, type_name: &str) -> Column {
         Column {
-            name: name.into(), ordinal: 0, type_name: type_name.into(),
-            nullable: false, default: None, identity: None, generated: None, comment: None, statistics_target: None, stats: None,
+            name: name.into(),
+            ordinal: 0,
+            type_name: type_name.into(),
+            nullable: false,
+            default: None,
+            identity: None,
+            generated: None,
+            comment: None,
+            statistics_target: None,
+            stats: None,
         }
     }
 
     fn make_fk(name: &str, columns: &[&str], fk_table: &str) -> Constraint {
         Constraint {
-            name: name.into(), kind: ConstraintKind::ForeignKey,
+            name: name.into(),
+            kind: ConstraintKind::ForeignKey,
             columns: columns.iter().map(|s| s.to_string()).collect(),
-            definition: None, fk_table: Some(fk_table.into()),
-            fk_columns: vec!["id".into()], backing_index: None, comment: None,
+            definition: None,
+            fk_table: Some(fk_table.into()),
+            fk_columns: vec!["id".into()],
+            backing_index: None,
+            comment: None,
         }
     }
 
@@ -230,8 +242,11 @@ mod tests {
         Index {
             name: name.into(),
             columns: columns.iter().map(|s| s.to_string()).collect(),
-            include_columns: vec![], index_type: "btree".into(),
-            is_unique: false, is_primary: false, predicate: None,
+            include_columns: vec![],
+            index_type: "btree".into(),
+            is_unique: false,
+            is_primary: false,
+            predicate: None,
             definition: format!("CREATE INDEX {name} ON ..."),
             is_valid: true,
             backs_constraint: false,
@@ -246,19 +261,37 @@ mod tests {
         indexes: Vec<Index>,
     ) -> Table {
         Table {
-            oid: 0, schema: "public".into(), name: name.into(),
-            columns, constraints, indexes,
-            comment: None, stats: None, partition_info: None,
-            policies: vec![], triggers: vec![], reloptions: vec![], rls_enabled: false,
+            oid: 0,
+            schema: "public".into(),
+            name: name.into(),
+            columns,
+            constraints,
+            indexes,
+            comment: None,
+            stats: None,
+            partition_info: None,
+            policies: vec![],
+            triggers: vec![],
+            reloptions: vec![],
+            rls_enabled: false,
         }
     }
 
     fn schema_with(tables: Vec<Table>
) -> SchemaSnapshot { SchemaSnapshot { - pg_version: "PostgreSQL 17.0".into(), database: "test".into(), - timestamp: Utc::now(), content_hash: "abc".into(), source: None, - tables, enums: vec![], domains: vec![], composites: vec![], - views: vec![], functions: vec![], extensions: vec![], gucs: vec![], + pg_version: "PostgreSQL 17.0".into(), + database: "test".into(), + timestamp: Utc::now(), + content_hash: "abc".into(), + source: None, + tables, + enums: vec![], + domains: vec![], + composites: vec![], + views: vec![], + functions: vec![], + extensions: vec![], + gucs: vec![], node_stats: vec![], } } @@ -268,13 +301,19 @@ mod tests { config.min_severity = Severity::Info; // disable everything except fk_has_index to isolate the test config.disabled_rules = vec![ - "naming/table_style".into(), "naming/column_style".into(), - "naming/fk_pattern".into(), "naming/index_pattern".into(), - "pk/exists".into(), "pk/bigint_identity".into(), - "types/text_over_varchar".into(), "types/timestamptz".into(), - "types/no_serial".into(), "types/bigint_pk_fk".into(), + "naming/table_style".into(), + "naming/column_style".into(), + "naming/fk_pattern".into(), + "naming/index_pattern".into(), + "pk/exists".into(), + "pk/bigint_identity".into(), + "types/text_over_varchar".into(), + "types/timestamptz".into(), + "types/no_serial".into(), + "types/bigint_pk_fk".into(), "constraints/unnamed".into(), - "timestamps/has_created_at".into(), "timestamps/has_updated_at".into(), + "timestamps/has_created_at".into(), + "timestamps/has_updated_at".into(), "timestamps/correct_type".into(), ]; config @@ -285,13 +324,28 @@ mod tests { // FK (order_id, product_id) covered by index (order_id, product_id, status) let schema = schema_with(vec![make_table_with( "line_item", - vec![make_col("order_id", "bigint"), make_col("product_id", "bigint"), make_col("status", "text")], - vec![make_fk("fk_line_item_order_product", &["order_id", "product_id"], "public.order")], - vec![make_index("idx_line_item_composite", &["order_id", "product_id", "status"])], + vec![ + make_col("order_id", "bigint"), + make_col("product_id", "bigint"), + make_col("status", "text"), + ], + vec![make_fk( + "fk_line_item_order_product", + &["order_id", "product_id"], + "public.order", + )], + vec![make_index( + "idx_line_item_composite", + &["order_id", "product_id", "status"], + )], )]); let violations = run_all_rules(&schema, &only_fk_rules()); - assert!(!violations.iter().any(|v| v.rule == "constraints/fk_has_index"), - "3-col index covering 2-col FK as prefix should pass"); + assert!( + !violations + .iter() + .any(|v| v.rule == "constraints/fk_has_index"), + "3-col index covering 2-col FK as prefix should pass" + ); } #[test] @@ -299,13 +353,27 @@ mod tests { // FK (order_id, product_id) but index is (product_id, order_id) — wrong prefix order let schema = schema_with(vec![make_table_with( "line_item", - vec![make_col("order_id", "bigint"), make_col("product_id", "bigint")], - vec![make_fk("fk_line_item_order_product", &["order_id", "product_id"], "public.order")], - vec![make_index("idx_line_item_wrong_order", &["product_id", "order_id"])], + vec![ + make_col("order_id", "bigint"), + make_col("product_id", "bigint"), + ], + vec![make_fk( + "fk_line_item_order_product", + &["order_id", "product_id"], + "public.order", + )], + vec![make_index( + "idx_line_item_wrong_order", + &["product_id", "order_id"], + )], )]); let violations = run_all_rules(&schema, &only_fk_rules()); - assert!(violations.iter().any(|v| v.rule == "constraints/fk_has_index"), - 
"index with swapped column order should NOT satisfy the FK"); + assert!( + violations + .iter() + .any(|v| v.rule == "constraints/fk_has_index"), + "index with swapped column order should NOT satisfy the FK" + ); } #[test] @@ -313,13 +381,24 @@ mod tests { // FK (order_id, product_id) but index only on (order_id) — not enough columns let schema = schema_with(vec![make_table_with( "line_item", - vec![make_col("order_id", "bigint"), make_col("product_id", "bigint")], - vec![make_fk("fk_line_item_order_product", &["order_id", "product_id"], "public.order")], + vec![ + make_col("order_id", "bigint"), + make_col("product_id", "bigint"), + ], + vec![make_fk( + "fk_line_item_order_product", + &["order_id", "product_id"], + "public.order", + )], vec![make_index("idx_line_item_order_id", &["order_id"])], )]); let violations = run_all_rules(&schema, &only_fk_rules()); - assert!(violations.iter().any(|v| v.rule == "constraints/fk_has_index"), - "single-col index should NOT satisfy 2-col FK"); + assert!( + violations + .iter() + .any(|v| v.rule == "constraints/fk_has_index"), + "single-col index should NOT satisfy 2-col FK" + ); } #[test] @@ -327,13 +406,27 @@ mod tests { // FK (order_id, product_id) with index (order_id, product_id) — exact match let schema = schema_with(vec![make_table_with( "line_item", - vec![make_col("order_id", "bigint"), make_col("product_id", "bigint")], - vec![make_fk("fk_line_item_order_product", &["order_id", "product_id"], "public.order")], - vec![make_index("idx_line_item_order_product", &["order_id", "product_id"])], + vec![ + make_col("order_id", "bigint"), + make_col("product_id", "bigint"), + ], + vec![make_fk( + "fk_line_item_order_product", + &["order_id", "product_id"], + "public.order", + )], + vec![make_index( + "idx_line_item_order_product", + &["order_id", "product_id"], + )], )]); let violations = run_all_rules(&schema, &only_fk_rules()); - assert!(!violations.iter().any(|v| v.rule == "constraints/fk_has_index"), - "exact match index should satisfy the FK"); + assert!( + !violations + .iter() + .any(|v| v.rule == "constraints/fk_has_index"), + "exact match index should satisfy the FK" + ); } // --- partition dedup helpers --- @@ -348,28 +441,47 @@ mod tests { fn make_partitioned_table(name: &str, children: Vec) -> Table { Table { - oid: 0, schema: "public".into(), name: name.into(), + oid: 0, + schema: "public".into(), + name: name.into(), columns: vec![make_col("id", "integer")], - constraints: vec![], indexes: vec![], - comment: None, stats: None, + constraints: vec![], + indexes: vec![], + comment: None, + stats: None, partition_info: Some(PartitionInfo { strategy: PartitionStrategy::Range, key: "created_at".into(), children, }), - policies: vec![], triggers: vec![], reloptions: vec![], rls_enabled: false, + policies: vec![], + triggers: vec![], + reloptions: vec![], + rls_enabled: false, } } /// Config that only enables the given rules fn config_with_only(rules: &[&str]) -> LintConfig { let all_rules = [ - "naming/table_style", "naming/column_style", "naming/fk_pattern", - "naming/index_pattern", "pk/exists", "pk/bigint_identity", - "types/text_over_varchar", "types/timestamptz", "types/no_serial", - "types/bigint_pk_fk", "constraints/fk_has_index", "constraints/unnamed", - "timestamps/has_created_at", "timestamps/has_updated_at", "timestamps/correct_type", - "partition/too_many_children", "partition/range_gaps", "partition/no_default", + "naming/table_style", + "naming/column_style", + "naming/fk_pattern", + "naming/index_pattern", + "pk/exists", + 
"pk/bigint_identity", + "types/text_over_varchar", + "types/timestamptz", + "types/no_serial", + "types/bigint_pk_fk", + "constraints/fk_has_index", + "constraints/unnamed", + "timestamps/has_created_at", + "timestamps/has_updated_at", + "timestamps/correct_type", + "partition/too_many_children", + "partition/range_gaps", + "partition/no_default", "partition/gucs", ]; let mut config = LintConfig::default(); @@ -384,18 +496,29 @@ mod tests { fn make_pk(name: &str, columns: &[&str]) -> Constraint { Constraint { - name: name.into(), kind: ConstraintKind::PrimaryKey, + name: name.into(), + kind: ConstraintKind::PrimaryKey, columns: columns.iter().map(|s| s.to_string()).collect(), - definition: None, fk_table: None, - fk_columns: vec![], backing_index: None, comment: None, + definition: None, + fk_table: None, + fk_columns: vec![], + backing_index: None, + comment: None, } } fn make_col_with_default(name: &str, type_name: &str, default: &str) -> Column { Column { - name: name.into(), ordinal: 0, type_name: type_name.into(), - nullable: false, default: Some(default.into()), identity: None, generated: None, - comment: None, statistics_target: None, stats: None, + name: name.into(), + ordinal: 0, + type_name: type_name.into(), + nullable: false, + default: Some(default.into()), + identity: None, + generated: None, + comment: None, + statistics_target: None, + stats: None, } } @@ -403,14 +526,32 @@ mod tests { #[test] fn partition_parent_with_three_children_only_parent_violations() { - let parent = make_partitioned_table("event", vec![ - make_partition_child("event_2024_01"), - make_partition_child("event_2024_02"), - make_partition_child("event_2024_03"), - ]); - let child1 = make_table_with("event_2024_01", vec![make_col("id", "integer")], vec![], vec![]); - let child2 = make_table_with("event_2024_02", vec![make_col("id", "integer")], vec![], vec![]); - let child3 = make_table_with("event_2024_03", vec![make_col("id", "integer")], vec![], vec![]); + let parent = make_partitioned_table( + "event", + vec![ + make_partition_child("event_2024_01"), + make_partition_child("event_2024_02"), + make_partition_child("event_2024_03"), + ], + ); + let child1 = make_table_with( + "event_2024_01", + vec![make_col("id", "integer")], + vec![], + vec![], + ); + let child2 = make_table_with( + "event_2024_02", + vec![make_col("id", "integer")], + vec![], + vec![], + ); + let child3 = make_table_with( + "event_2024_03", + vec![make_col("id", "integer")], + vec![], + vec![], + ); let schema = schema_with(vec![parent, child1, child2, child3]); let config = config_with_only(&["pk/exists"]); @@ -423,23 +564,31 @@ mod tests { #[test] fn nested_partitions_grandchild_also_skipped() { - let parent = make_partitioned_table("event", vec![ - make_partition_child("event_2024_01"), - ]); + let parent = make_partitioned_table("event", vec![make_partition_child("event_2024_01")]); let mid = Table { - oid: 0, schema: "public".into(), name: "event_2024_01".into(), + oid: 0, + schema: "public".into(), + name: "event_2024_01".into(), columns: vec![make_col("id", "integer")], - constraints: vec![], indexes: vec![], - comment: None, stats: None, + constraints: vec![], + indexes: vec![], + comment: None, + stats: None, partition_info: Some(PartitionInfo { strategy: PartitionStrategy::Hash, key: "id".into(), children: vec![make_partition_child("event_2024_01_h0")], }), - policies: vec![], triggers: vec![], reloptions: vec![], rls_enabled: false, + policies: vec![], + triggers: vec![], + reloptions: vec![], + rls_enabled: false, }; let 
grandchild = make_table_with(
-            "event_2024_01_h0", vec![make_col("id", "integer")], vec![], vec![],
+            "event_2024_01_h0",
+            vec![make_col("id", "integer")],
+            vec![],
+            vec![],
         );
         let schema = schema_with(vec![parent, mid, grandchild]);
 
@@ -458,17 +607,22 @@ mod tests {
         let table = make_table_with(
             "user",
             vec![make_col("created_at", "timestamp without time zone")],
-            vec![], vec![],
+            vec![],
+            vec![],
         );
         let schema = schema_with(vec![table]);
         let config = config_with_only(&["timestamps/correct_type", "types/timestamptz"]);
         let violations = run_all_rules(&schema, &config);
 
         let rules: Vec<&str> = violations.iter().map(|v| v.rule.as_str()).collect();
-        assert!(rules.contains(&"timestamps/correct_type"),
-            "winner rule should fire");
-        assert!(!rules.contains(&"types/timestamptz"),
-            "loser rule should be suppressed");
+        assert!(
+            rules.contains(&"timestamps/correct_type"),
+            "winner rule should fire"
+        );
+        assert!(
+            !rules.contains(&"types/timestamptz"),
+            "loser rule should be suppressed"
+        );
     }
 
     #[test]
@@ -476,7 +630,11 @@ mod tests {
         // integer PK with serial default should fire pk/bigint_identity but NOT types/no_serial
         let table = make_table_with(
             "user",
-            vec![make_col_with_default("id", "integer", "nextval('user_id_seq')")],
+            vec![make_col_with_default(
+                "id",
+                "integer",
+                "nextval('user_id_seq')",
+            )],
             vec![make_pk("user_pkey", &["id"])],
             vec![],
         );
@@ -485,10 +643,14 @@ mod tests {
         let violations = run_all_rules(&schema, &config);
 
         let rules: Vec<&str> = violations.iter().map(|v| v.rule.as_str()).collect();
-        assert!(rules.contains(&"pk/bigint_identity"),
-            "winner rule should fire");
-        assert!(!rules.contains(&"types/no_serial"),
-            "loser rule should be suppressed");
+        assert!(
+            rules.contains(&"pk/bigint_identity"),
+            "winner rule should fire"
+        );
+        assert!(
+            !rules.contains(&"types/no_serial"),
+            "loser rule should be suppressed"
+        );
     }
 
     #[test]
@@ -497,14 +659,17 @@ mod tests {
         let table = make_table_with(
             "user",
             vec![make_col("created_at", "timestamp without time zone")],
-            vec![], vec![],
+            vec![],
+            vec![],
         );
         let schema = schema_with(vec![table]);
         let config = config_with_only(&["types/timestamptz"]);
         let violations = run_all_rules(&schema, &config);
-        assert!(violations.iter().any(|v| v.rule == "types/timestamptz"),
-            "loser should fire when winner is disabled");
+        assert!(
+            violations.iter().any(|v| v.rule == "types/timestamptz"),
+            "loser should fire when winner is disabled"
+        );
     }
 
     // --- Change 3: auto-detect table name style tests ---
@@ -547,9 +712,11 @@ mod tests {
         let violations = run_all_rules(&schema, &config);
 
         // snake_plural doesn't check for plural (just snake_case), so no violations expected
-        assert!(violations.is_empty(),
+        assert!(
+            violations.is_empty(),
             "auto-resolved to snake_plural should accept all snake_case names, got: {:?}",
-            violations.iter().map(|v| &v.table).collect::<Vec<_>>());
+            violations.iter().map(|v| &v.table).collect::<Vec<_>>()
+        );
     }
 
     // --- partition lint rules ---
@@ -592,10 +759,14 @@ mod tests {
     #[test]
     fn partition_range_gaps_detected() {
         let table = Table {
-            oid: 0, schema: "public".into(), name: "events".into(),
+            oid: 0,
+            schema: "public".into(),
+            name: "events".into(),
             columns: vec![make_col("id", "integer")],
-            constraints: vec![], indexes: vec![],
-            comment: None, stats: None,
+            constraints: vec![],
+            indexes: vec![],
+            comment: None,
+            stats: None,
             partition_info: Some(PartitionInfo {
                 strategy: PartitionStrategy::Range,
                 key: "created_at".into(),
@@ -613,7 +784,10 @@ mod tests {
                 },
                 ],
             }),
-            policies: vec![], triggers:
vec![], reloptions: vec![], rls_enabled: false, + policies: vec![], + triggers: vec![], + reloptions: vec![], + rls_enabled: false, }; let schema = schema_with(vec![table]); let config = config_with_only(&["partition/range_gaps"]); @@ -624,13 +798,14 @@ mod tests { #[test] fn partition_no_default_warns() { - let table = make_partitioned_table("orders", vec![ - PartitionChild { + let table = make_partitioned_table( + "orders", + vec![PartitionChild { schema: "public".into(), name: "orders_q1".into(), bound: "FOR VALUES FROM ('2024-01-01') TO ('2024-04-01')".into(), - }, - ]); + }], + ); let schema = schema_with(vec![table]); let config = config_with_only(&["partition/no_default"]); let violations = run_all_rules(&schema, &config); @@ -640,18 +815,21 @@ mod tests { #[test] fn partition_no_default_skips_when_default_exists() { - let table = make_partitioned_table("orders", vec![ - PartitionChild { - schema: "public".into(), - name: "orders_q1".into(), - bound: "FOR VALUES FROM ('2024-01-01') TO ('2024-04-01')".into(), - }, - PartitionChild { - schema: "public".into(), - name: "orders_default".into(), - bound: "DEFAULT".into(), - }, - ]); + let table = make_partitioned_table( + "orders", + vec![ + PartitionChild { + schema: "public".into(), + name: "orders_q1".into(), + bound: "FOR VALUES FROM ('2024-01-01') TO ('2024-04-01')".into(), + }, + PartitionChild { + schema: "public".into(), + name: "orders_default".into(), + bound: "DEFAULT".into(), + }, + ], + ); let schema = schema_with(vec![table]); let config = config_with_only(&["partition/no_default"]); let violations = run_all_rules(&schema, &config); @@ -669,6 +847,10 @@ mod tests { }); let config = config_with_only(&["partition/gucs"]); let violations = run_all_rules(&schema, &config); - assert!(violations.iter().any(|v| v.message.contains("enable_partition_pruning"))); + assert!( + violations + .iter() + .any(|v| v.message.contains("enable_partition_pruning")) + ); } } diff --git a/crates/dry_run_core/src/lint/rules/naming.rs b/crates/dry_run_core/src/lint/rules/naming.rs index eccabc8..8d95871 100644 --- a/crates/dry_run_core/src/lint/rules/naming.rs +++ b/crates/dry_run_core/src/lint/rules/naming.rs @@ -86,7 +86,7 @@ pub fn check_column_name_style( ), recommendation: format!("rename to match {} convention", config.column_name_style), ddl_fix: None, - convention_doc: "naming".into(), + convention_doc: "naming".into(), }); } } @@ -119,7 +119,7 @@ pub fn check_fk_naming( ), recommendation: format!("rename constraint to '{expected}'"), ddl_fix: None, - convention_doc: "naming".into(), + convention_doc: "naming".into(), }); } } @@ -152,7 +152,7 @@ pub fn check_index_naming( ), recommendation: format!("rename index to '{expected}'"), ddl_fix: None, - convention_doc: "naming".into(), + convention_doc: "naming".into(), }); } } diff --git a/crates/dry_run_core/src/lint/rules/partitions.rs b/crates/dry_run_core/src/lint/rules/partitions.rs index 40f25ca..2d7fe8d 100644 --- a/crates/dry_run_core/src/lint/rules/partitions.rs +++ b/crates/dry_run_core/src/lint/rules/partitions.rs @@ -26,9 +26,7 @@ pub fn check_partition_too_many_children( severity: Severity::Warning, table: qualified.into(), column: None, - message: format!( - "table has {n} partitions; planning overhead may be significant" - ), + message: format!("table has {n} partitions; planning overhead may be significant"), recommendation: rec, ddl_fix: None, convention_doc: "partitioning".into(), @@ -55,9 +53,8 @@ pub fn check_partition_range_gaps( .children .iter() .filter_map(|c| { - 
re.captures(&c.bound).map(|cap| {
-                (cap[1].to_string(), cap[2].to_string())
-            })
+            re.captures(&c.bound)
+                .map(|cap| (cap[1].to_string(), cap[2].to_string()))
         })
         .collect();
 
@@ -77,7 +74,7 @@
             ),
             recommendation: e.reason,
             ddl_fix: Some(e.fix),
-            convention_doc: "partitioning".into(),
+            convention_doc: "partitioning".into(),
         });
     }
 }
@@ -113,10 +110,7 @@
 }
 
 pub fn check_partition_gucs(schema: &SchemaSnapshot, violations: &mut Vec<LintViolation>) {
-    let has_partitioned = schema
-        .tables
-        .iter()
-        .any(|t| t.partition_info.is_some());
+    let has_partitioned = schema.tables.iter().any(|t| t.partition_info.is_some());
 
     if !has_partitioned {
         return;
@@ -142,7 +136,7 @@ pub fn check_partition_gucs(schema: &SchemaSnapshot, violations: &mut Vec<LintViolation>) {
diff --git a/crates/dry_run_core/src/lint/rules/types.rs b/crates/dry_run_core/src/lint/rules/types.rs
     for col in &table.columns {
         if let Some(default) = &col.default
-            && default.to_lowercase().contains("nextval(") {
-            violations.push(LintViolation {
-                rule: "types/no_serial".into(),
-                severity: Severity::Warning,
-                table: qualified.into(),
-                column: Some(col.name.clone()),
-                message: format!(
-                    "column '{}' uses serial/sequence default ({})",
-                    col.name, default
-                ),
-                recommendation: "use bigint GENERATED ALWAYS AS IDENTITY instead of serial"
-                    .into(),
-                ddl_fix: None,
-                convention_doc: "types".into(),
-            });
-        }
+            && default.to_lowercase().contains("nextval(")
+        {
+            violations.push(LintViolation {
+                rule: "types/no_serial".into(),
+                severity: Severity::Warning,
+                table: qualified.into(),
+                column: Some(col.name.clone()),
+                message: format!(
+                    "column '{}' uses serial/sequence default ({})",
+                    col.name, default
+                ),
+                recommendation: "use bigint GENERATED ALWAYS AS IDENTITY instead of serial".into(),
+                ddl_fix: None,
+                convention_doc: "types".into(),
+            });
+        }
     }
 }
 
-pub fn check_bigint_pk_fk(table: &Table, qualified: &str, config: &LintConfig, violations: &mut Vec<LintViolation>) {
+pub fn check_bigint_pk_fk(
+    table: &Table,
+    qualified: &str,
+    config: &LintConfig,
+    violations: &mut Vec<LintViolation>,
+) {
     let pk_cols: Vec<&str> = table
         .constraints
         .iter()
@@ -104,10 +109,7 @@ pub fn check_bigint_pk_fk(table: &Table, qualified: &str, config: &LintConfig, v
         if is_int && config.pk_type == "int_identity" {
             continue;
         }
-        if is_int
-            || type_lower == "smallint"
-            || type_lower == "int2"
-        {
+        if is_int || type_lower == "smallint" || type_lower == "int2" {
             violations.push(LintViolation {
                 rule: "types/bigint_pk_fk".into(),
                 severity: Severity::Warning,
@@ -119,7 +121,7 @@ pub fn check_bigint_pk_fk(table: &Table, qualified: &str, config: &LintConfig, v
                 ),
                 recommendation: "use bigint for PK and FK columns".into(),
                 ddl_fix: None,
-                convention_doc: "types".into(),
+                convention_doc: "types".into(),
             });
         }
     }
diff --git a/crates/dry_run_core/src/query/advise.rs b/crates/dry_run_core/src/query/advise.rs
index 20b614f..186952d 100644
--- a/crates/dry_run_core/src/query/advise.rs
+++ b/crates/dry_run_core/src/query/advise.rs
@@ -141,15 +141,17 @@ fn advise_seq_scan(
     // stats-aware refinements
     if let Some(col) = col_obj
-        && col.stats.is_some() {
-        let mut table_rows = node.plan_rows;
-        if let Some(t) = table
-            && let Some(s) = &t.stats
-            && s.reltuples > table_rows {
-            table_rows = s.reltuples;
-        }
-        recommendation.push_str(&stats_aware_advice(col, filter_col_name, table_rows));
+        && col.stats.is_some()
+    {
+        let mut table_rows = node.plan_rows;
+        if let Some(t) = table
+            && let Some(s) = &t.stats
+            && s.reltuples > table_rows
+        {
+            table_rows = s.reltuples;
         }
+        recommendation.push_str(&stats_aware_advice(col, filter_col_name, table_rows));
+
} let idx_name = format!("idx_{table_name}_{filter_col_name}"); @@ -339,7 +341,9 @@ fn stats_aware_advice(col: &Column, filter_col: &str, table_rows: f64) -> String } else if nd > 0.0 && nd <= 20.0 { parts.push(format!( "\nColumn '{}' has {} distinct values (selectivity ~{:.1}%).", - filter_col, nd as i64, sel * 100.0 + filter_col, + nd as i64, + sel * 100.0 )); } } @@ -354,22 +358,26 @@ fn stats_aware_advice(col: &Column, filter_col: &str, table_rows: f64) -> String // high null fraction if let Some(nf) = stats.null_frac - && nf > 0.5 { - let null_rows = (nf * table_rows) as i64; - parts.push(format!( + && nf > 0.5 + { + let null_rows = (nf * table_rows) as i64; + parts.push(format!( "Column is {:.0}% NULL (~{} rows). Use a partial index WHERE {} IS NOT NULL to index only the non-null rows.", nf * 100.0, null_rows, filter_col )); - } + } // correlation warning for range scans if let Some(c) = stats.correlation - && c > -0.3 && c < 0.3 && table_rows > 10_000.0 { - parts.push(format!( + && c > -0.3 + && c < 0.3 + && table_rows > 10_000.0 + { + parts.push(format!( "Physical ordering is random (correlation: {:.2}); index range scans will cause random I/O.", c )); - } + } parts.join(" ") } @@ -420,8 +428,12 @@ fn suggest_index_type(table: &str, col_type: &str, col_name: &str) -> (&'static }; return ("gin", rec); } - if ct.contains("geometry") || ct.contains("geography") || ct.contains("range") - || ct == "tsrange" || ct == "daterange" || ct == "int4range" + if ct.contains("geometry") + || ct.contains("geography") + || ct.contains("range") + || ct == "tsrange" + || ct == "daterange" + || ct == "int4range" { let e = jit::suggest_gist(table, col_name, col_type); return ("gist", e.reason); @@ -470,28 +482,96 @@ mod tests { content_hash: "test".into(), source: None, tables: vec![Table { - oid: 1, schema: "public".into(), name: "orders".into(), + oid: 1, + schema: "public".into(), + name: "orders".into(), columns: vec![ - Column { name: "id".into(), ordinal: 1, type_name: "bigint".into(), nullable: false, default: None, identity: None, generated: None, comment: None, statistics_target: None, stats: None }, - Column { name: "customer_id".into(), ordinal: 2, type_name: "bigint".into(), nullable: false, default: None, identity: None, generated: None, comment: None, statistics_target: None, stats: None }, - Column { name: "data".into(), ordinal: 3, type_name: "jsonb".into(), nullable: true, default: None, identity: None, generated: None, comment: None, statistics_target: None, stats: None }, + Column { + name: "id".into(), + ordinal: 1, + type_name: "bigint".into(), + nullable: false, + default: None, + identity: None, + generated: None, + comment: None, + statistics_target: None, + stats: None, + }, + Column { + name: "customer_id".into(), + ordinal: 2, + type_name: "bigint".into(), + nullable: false, + default: None, + identity: None, + generated: None, + comment: None, + statistics_target: None, + stats: None, + }, + Column { + name: "data".into(), + ordinal: 3, + type_name: "jsonb".into(), + nullable: true, + default: None, + identity: None, + generated: None, + comment: None, + statistics_target: None, + stats: None, + }, ], - constraints: vec![], indexes: vec![], comment: None, stats: None, - partition_info: None, policies: vec![], triggers: vec![], reloptions: vec![], rls_enabled: false, + constraints: vec![], + indexes: vec![], + comment: None, + stats: None, + partition_info: None, + policies: vec![], + triggers: vec![], + reloptions: vec![], + rls_enabled: false, }], - enums: vec![], domains: 
vec![], composites: vec![], views: vec![], functions: vec![], extensions: vec![], gucs: vec![], + enums: vec![], + domains: vec![], + composites: vec![], + views: vec![], + functions: vec![], + extensions: vec![], + gucs: vec![], node_stats: vec![], } } fn make_seq_scan(table: &str, rows: f64, filter: Option<&str>) -> PlanNode { PlanNode { - node_type: "Seq Scan".into(), relation_name: Some(table.into()), schema: Some("public".into()), - alias: None, startup_cost: 0.0, total_cost: rows * 0.01, plan_rows: rows, plan_width: 64, - actual_rows: None, actual_loops: None, actual_startup_time: None, actual_total_time: None, - shared_hit_blocks: None, shared_read_blocks: None, index_name: None, index_cond: None, - filter: filter.map(String::from), rows_removed_by_filter: None, - sort_key: None, sort_method: None, hash_cond: None, join_type: None, subplans_removed: None, cte_name: None, parent_relationship: None, children: vec![], + node_type: "Seq Scan".into(), + relation_name: Some(table.into()), + schema: Some("public".into()), + alias: None, + startup_cost: 0.0, + total_cost: rows * 0.01, + plan_rows: rows, + plan_width: 64, + actual_rows: None, + actual_loops: None, + actual_startup_time: None, + actual_total_time: None, + shared_hit_blocks: None, + shared_read_blocks: None, + index_name: None, + index_cond: None, + filter: filter.map(String::from), + rows_removed_by_filter: None, + sort_key: None, + sort_method: None, + hash_cond: None, + join_type: None, + subplans_removed: None, + cte_name: None, + parent_relationship: None, + children: vec![], } } @@ -527,7 +607,11 @@ mod tests { fn advise_includes_version_note() { let schema = empty_schema(); let plan = make_seq_scan("orders", 100_000.0, Some("(customer_id = 42)")); - let pg14 = PgVersion { major: 14, minor: 0, patch: 0 }; + let pg14 = PgVersion { + major: 14, + minor: 0, + patch: 0, + }; let advice = advise(&plan, &schema, Some(&pg14)); assert!(!advice.is_empty()); assert!(advice[0].version_note.is_some()); @@ -545,10 +629,16 @@ mod tests { schema: "public".into(), table: "orders".into(), stats: TableStats { - reltuples: 100_000.0, relpages: 1250, dead_tuples: 0, - last_vacuum: None, last_autovacuum: None, - last_analyze: None, last_autoanalyze: None, - seq_scan: 100, idx_scan: 5000, table_size: 10_000_000, + reltuples: 100_000.0, + relpages: 1250, + dead_tuples: 0, + last_vacuum: None, + last_autovacuum: None, + last_analyze: None, + last_autoanalyze: None, + seq_scan: 100, + idx_scan: 5000, + table_size: 10_000_000, }, }], index_stats: vec![], @@ -562,10 +652,16 @@ mod tests { schema: "public".into(), table: "orders".into(), stats: TableStats { - reltuples: 100_000.0, relpages: 1250, dead_tuples: 0, - last_vacuum: None, last_autovacuum: None, - last_analyze: None, last_autoanalyze: None, - seq_scan: 42000, idx_scan: 1000, table_size: 10_000_000, + reltuples: 100_000.0, + relpages: 1250, + dead_tuples: 0, + last_vacuum: None, + last_autovacuum: None, + last_analyze: None, + last_autoanalyze: None, + seq_scan: 42000, + idx_scan: 1000, + table_size: 10_000_000, }, }], index_stats: vec![], @@ -582,8 +678,17 @@ mod tests { #[test] fn extract_column_simple() { - assert_eq!(extract_column_from_filter("(customer_id = 42)"), Some("customer_id".into())); - assert_eq!(extract_column_from_filter("(status IS NOT NULL)"), Some("status".into())); - assert_eq!(extract_column_from_filter("(t.name = 'foo')"), Some("name".into())); + assert_eq!( + extract_column_from_filter("(customer_id = 42)"), + Some("customer_id".into()) + ); + assert_eq!( + 
extract_column_from_filter("(status IS NOT NULL)"), + Some("status".into()) + ); + assert_eq!( + extract_column_from_filter("(t.name = 'foo')"), + Some("name".into()) + ); } } diff --git a/crates/dry_run_core/src/query/antipatterns.rs b/crates/dry_run_core/src/query/antipatterns.rs index 36029dc..8e1afbd 100644 --- a/crates/dry_run_core/src/query/antipatterns.rs +++ b/crates/dry_run_core/src/query/antipatterns.rs @@ -50,16 +50,17 @@ fn detect_unbounded_query( let reltuples = effective_table_stats(table, schema).map(|s| s.reltuples); if let Some(rows) = reltuples - && rows > LARGE_TABLE_THRESHOLD { - warnings.push(ValidationWarning { - severity: WarningSeverity::Warning, - message: format!( - "unbounded query on {}.{} (~{} rows) with no WHERE or LIMIT — \ consider adding a filter or LIMIT clause", - table.schema, table.name, rows as i64 - ), - }); - } + && rows > LARGE_TABLE_THRESHOLD + { + warnings.push(ValidationWarning { + severity: WarningSeverity::Warning, + message: format!( + "unbounded query on {}.{} (~{} rows) with no WHERE or LIMIT — \ consider adding a filter or LIMIT clause", + table.schema, table.name, rows as i64 + ), + }); + } } } } @@ -217,9 +218,9 @@ fn detect_partition_key_update( fn func_wrap_rewrite_hint(func_name: &str, col: &str) -> String { match func_name { - "extract" | "::date" | "to_char" => format!( - "Rewrite as: WHERE {col} >= '2025-01-01' AND {col} < '2026-01-01'" - ), + "extract" | "::date" | "to_char" => { + format!("Rewrite as: WHERE {col} >= '2025-01-01' AND {col} < '2026-01-01'") + } "date_trunc" => format!( "Rewrite as: WHERE {col} >= date_trunc('month', target) \ AND {col} < date_trunc('month', target) + interval '1 month'" @@ -238,10 +239,8 @@ fn parse_partition_key_columns(key: &str) -> Vec<String> { #[cfg(test)] mod tests { use super::*; - use crate::schema::{ - PartitionChild, PartitionInfo, PartitionStrategy, Table, - }; use crate::query::{QueryInfo, ReferencedTable}; + use crate::schema::{PartitionChild, PartitionInfo, PartitionStrategy, Table}; fn partitioned_snapshot() -> SchemaSnapshot { SchemaSnapshot { @@ -317,7 +316,11 @@ mod tests { let mut warnings = Vec::new(); detect_partition_key_antipatterns(&parsed, &snap, &mut warnings); assert_eq!(warnings.len(), 1); - assert!(warnings[0].message.contains("does not filter on partition key")); + assert!( + warnings[0] + .message + .contains("does not filter on partition key") + ); } #[test] @@ -350,7 +353,11 @@ mod tests { let mut warnings = Vec::new(); detect_partition_key_antipatterns(&parsed, &snap, &mut warnings); // should have a func-wrap warning (partition key is in filter_columns so no missing-key warning) - assert!(warnings.iter().any(|w| w.message.contains("wrapped in extract"))); + assert!( + warnings + .iter() + .any(|w| w.message.contains("wrapped in extract")) + ); assert!(warnings.iter().any(|w| w.message.contains("Rewrite as"))); } diff --git a/crates/dry_run_core/src/query/explain.rs b/crates/dry_run_core/src/query/explain.rs index 9008139..ef20afa 100644 --- a/crates/dry_run_core/src/query/explain.rs +++ b/crates/dry_run_core/src/query/explain.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use sqlx::PgPool; -use super::plan::{parse_plan_json, PlanNode}; +use super::plan::{PlanNode, parse_plan_json}; use super::plan_warnings::detect_plan_warnings; use crate::error::{Error, Result}; use crate::schema::SchemaSnapshot; diff --git a/crates/dry_run_core/src/query/migration.rs b/crates/dry_run_core/src/query/migration.rs index 26768ca..62d085d 100644 --- a/crates/dry_run_core/src/query/migration.rs +++
b/crates/dry_run_core/src/query/migration.rs @@ -45,9 +45,9 @@ pub fn check_migration( if let Some(pg_query::protobuf::node::Node::AlterTableCmd(cmd)) = &cmd_node.node && let Some(check) = analyze_alter_table_cmd(cmd, &result, schema, pg_version) - { - checks.push(check); - } + { + checks.push(check); + } } } NodeRef::IndexStmt(idx) => { @@ -61,9 +61,10 @@ pub fn check_migration( } if checks.is_empty() - && let Some(check) = fallback_keyword_check(ddl, schema, pg_version) { - checks.push(check); - } + && let Some(check) = fallback_keyword_check(ddl, schema, pg_version) + { + checks.push(check); + } Ok(checks) } @@ -282,7 +283,9 @@ fn analyze_add_constraint( ) } else { let e = match operation { - "ADD FOREIGN KEY" => jit::add_foreign_key_unsafe(table_name, "", "", ""), + "ADD FOREIGN KEY" => { + jit::add_foreign_key_unsafe(table_name, "", "", "") + } "ADD CHECK CONSTRAINT" => jit::add_check_constraint_unsafe(table_name, ""), _ => jit::add_check_constraint_unsafe(table_name, ""), }; @@ -339,7 +342,11 @@ fn analyze_create_index( "SHARE UPDATE EXCLUSIVE".to_string(), ) } else { - let idx_method = if idx.access_method.is_empty() { "btree" } else { &idx.access_method }; + let idx_method = if idx.access_method.is_empty() { + "btree" + } else { + &idx.access_method + }; let e = jit::create_index_blocking(&table_name, &idx.idxname, idx_method, ""); ( SafetyRating::Dangerous, @@ -419,7 +426,11 @@ fn fallback_keyword_check( None } -fn find_column<'a>(schema: &'a SchemaSnapshot, table_name: &str, col_name: &str) -> Option<&'a crate::schema::Column> { +fn find_column<'a>( + schema: &'a SchemaSnapshot, + table_name: &str, + col_name: &str, +) -> Option<&'a crate::schema::Column> { let (schema_part, name_part) = if let Some((s, n)) = table_name.rsplit_once('.') { (s, n) } else { @@ -487,30 +498,58 @@ mod tests { content_hash: "test".into(), source: None, tables: vec![Table { - oid: 1, schema: "public".into(), name: "orders".into(), - columns: vec![], constraints: vec![], indexes: vec![], + oid: 1, + schema: "public".into(), + name: "orders".into(), + columns: vec![], + constraints: vec![], + indexes: vec![], comment: None, stats: Some(TableStats { - reltuples: 5_000_000.0, relpages: 262144, dead_tuples: 0, - last_vacuum: None, last_autovacuum: None, - last_analyze: None, last_autoanalyze: None, - seq_scan: 0, idx_scan: 0, table_size: 2_147_483_648, + reltuples: 5_000_000.0, + relpages: 262144, + dead_tuples: 0, + last_vacuum: None, + last_autovacuum: None, + last_analyze: None, + last_autoanalyze: None, + seq_scan: 0, + idx_scan: 0, + table_size: 2_147_483_648, }), - partition_info: None, policies: vec![], triggers: vec![], reloptions: vec![], rls_enabled: false, + partition_info: None, + policies: vec![], + triggers: vec![], + reloptions: vec![], + rls_enabled: false, }], - enums: vec![], domains: vec![], composites: vec![], views: vec![], - functions: vec![], extensions: vec![], gucs: vec![], + enums: vec![], + domains: vec![], + composites: vec![], + views: vec![], + functions: vec![], + extensions: vec![], + gucs: vec![], node_stats: vec![], } } fn pg17() -> PgVersion { - PgVersion { major: 17, minor: 0, patch: 0 } + PgVersion { + major: 17, + minor: 0, + patch: 0, + } } #[test] fn add_column_no_default_safe() { - let checks = check_migration("ALTER TABLE orders ADD COLUMN notes text", &empty_schema(), Some(&pg17())).unwrap(); + let checks = check_migration( + "ALTER TABLE orders ADD COLUMN notes text", + &empty_schema(), + Some(&pg17()), + ) + .unwrap(); assert_eq!(checks.len(), 1); 
assert_eq!(checks[0].operation, "ADD COLUMN"); assert_eq!(checks[0].safety, SafetyRating::Safe); @@ -518,7 +557,12 @@ mod tests { #[test] fn add_column_with_default() { - let checks = check_migration("ALTER TABLE orders ADD COLUMN status text DEFAULT 'pending'", &empty_schema(), Some(&pg17())).unwrap(); + let checks = check_migration( + "ALTER TABLE orders ADD COLUMN status text DEFAULT 'pending'", + &empty_schema(), + Some(&pg17()), + ) + .unwrap(); assert_eq!(checks.len(), 1); assert_eq!(checks[0].safety, SafetyRating::Caution); assert!(checks[0].recommendation.contains("immutable")); @@ -526,7 +570,12 @@ mod tests { #[test] fn create_index_without_concurrently() { - let checks = check_migration("CREATE INDEX idx_orders_status ON orders(status)", &empty_schema(), Some(&pg17())).unwrap(); + let checks = check_migration( + "CREATE INDEX idx_orders_status ON orders(status)", + &empty_schema(), + Some(&pg17()), + ) + .unwrap(); assert_eq!(checks.len(), 1); assert_eq!(checks[0].safety, SafetyRating::Dangerous); assert!(checks[0].recommendation.contains("CONCURRENTLY")); @@ -534,15 +583,29 @@ mod tests { #[test] fn create_index_concurrently_safe() { - let checks = check_migration("CREATE INDEX CONCURRENTLY idx_orders_status ON orders(status)", &empty_schema(), Some(&pg17())).unwrap(); + let checks = check_migration( + "CREATE INDEX CONCURRENTLY idx_orders_status ON orders(status)", + &empty_schema(), + Some(&pg17()), + ) + .unwrap(); assert_eq!(checks.len(), 1); assert_eq!(checks[0].safety, SafetyRating::Safe); } #[test] fn set_not_null_caution_pg12() { - let pg12 = PgVersion { major: 12, minor: 0, patch: 0 }; - let checks = check_migration("ALTER TABLE orders ALTER COLUMN status SET NOT NULL", &empty_schema(), Some(&pg12)).unwrap(); + let pg12 = PgVersion { + major: 12, + minor: 0, + patch: 0, + }; + let checks = check_migration( + "ALTER TABLE orders ALTER COLUMN status SET NOT NULL", + &empty_schema(), + Some(&pg12), + ) + .unwrap(); assert_eq!(checks.len(), 1); assert_eq!(checks[0].operation, "SET NOT NULL"); assert_eq!(checks[0].safety, SafetyRating::Caution); @@ -551,21 +614,36 @@ mod tests { #[test] fn alter_column_type_dangerous() { - let checks = check_migration("ALTER TABLE orders ALTER COLUMN id TYPE bigint", &empty_schema(), Some(&pg17())).unwrap(); + let checks = check_migration( + "ALTER TABLE orders ALTER COLUMN id TYPE bigint", + &empty_schema(), + Some(&pg17()), + ) + .unwrap(); assert_eq!(checks.len(), 1); assert_eq!(checks[0].safety, SafetyRating::Dangerous); } #[test] fn drop_column_safe() { - let checks = check_migration("ALTER TABLE orders DROP COLUMN legacy", &empty_schema(), Some(&pg17())).unwrap(); + let checks = check_migration( + "ALTER TABLE orders DROP COLUMN legacy", + &empty_schema(), + Some(&pg17()), + ) + .unwrap(); assert_eq!(checks.len(), 1); assert_eq!(checks[0].safety, SafetyRating::Safe); } #[test] fn includes_table_size() { - let checks = check_migration("ALTER TABLE orders ADD COLUMN x text", &empty_schema(), Some(&pg17())).unwrap(); + let checks = check_migration( + "ALTER TABLE orders ADD COLUMN x text", + &empty_schema(), + Some(&pg17()), + ) + .unwrap(); assert!(checks[0].table_size.as_ref().unwrap().contains("GB")); assert_eq!(checks[0].row_estimate, Some(5_000_000.0)); } diff --git a/crates/dry_run_core/src/query/mod.rs b/crates/dry_run_core/src/query/mod.rs index 22bc23a..36da730 100644 --- a/crates/dry_run_core/src/query/mod.rs +++ b/crates/dry_run_core/src/query/mod.rs @@ -8,11 +8,11 @@ mod plan_warnings; mod suggest; mod validate; -pub use 
advise::{advise, advise_with_index_suggestions, Advice, AdviseResult}; -pub use explain::{explain_query, ExplainResult, PlanWarning}; -pub use migration::{check_migration, MigrationCheck, SafetyRating}; +pub use advise::{Advice, AdviseResult, advise, advise_with_index_suggestions}; +pub use explain::{ExplainResult, PlanWarning, explain_query}; +pub use migration::{MigrationCheck, SafetyRating, check_migration}; pub use parse::{FuncWrappedColumn, ParsedQuery, QueryInfo, ReferencedTable}; pub use plan::{PlanNode, parse_plan_json}; pub use plan_warnings::detect_plan_warnings; pub use suggest::IndexSuggestion; -pub use validate::{validate_query, ValidationResult, ValidationWarning}; +pub use validate::{ValidationResult, ValidationWarning, validate_query}; diff --git a/crates/dry_run_core/src/query/parse.rs b/crates/dry_run_core/src/query/parse.rs index 66facf2..69c9053 100644 --- a/crates/dry_run_core/src/query/parse.rs +++ b/crates/dry_run_core/src/query/parse.rs @@ -97,21 +97,19 @@ pub fn parse_sql(sql: &str) -> Result<ParsedQuery> { for target in &s.target_list { if let Some(pg_query::protobuf::node::Node::ResTarget(rt)) = &target.node && let Some(val) = &rt.val - && let Some(pg_query::protobuf::node::Node::ColumnRef(cr)) = &val.node { - for field in &cr.fields { - if let Some(pg_query::protobuf::node::Node::AStar(_)) = - &field.node - { - has_select_star = true; - } - } + && let Some(pg_query::protobuf::node::Node::ColumnRef(cr)) = &val.node + { + for field in &cr.fields { + if let Some(pg_query::protobuf::node::Node::AStar(_)) = &field.node { + has_select_star = true; } + } + } } } - NodeRef::InsertStmt(_) - if statement_type.is_empty() => { - statement_type = "INSERT".into(); - } + NodeRef::InsertStmt(_) if statement_type.is_empty() => { + statement_type = "INSERT".into(); + } NodeRef::UpdateStmt(u) => { if statement_type.is_empty() { statement_type = "UPDATE".into(); @@ -125,9 +123,10 @@ pub fn parse_sql(sql: &str) -> Result<ParsedQuery> { } for tl in &u.target_list { if let Some(pg_query::protobuf::node::Node::ResTarget(rt)) = &tl.node - && !rt.name.is_empty() { - update_targets.push(rt.name.clone()); - } + && !rt.name.is_empty() + { + update_targets.push(rt.name.clone()); + } } } NodeRef::DeleteStmt(d) => { diff --git a/crates/dry_run_core/src/query/plan.rs b/crates/dry_run_core/src/query/plan.rs index 4d2c26d..92f4ac3 100644 --- a/crates/dry_run_core/src/query/plan.rs +++ b/crates/dry_run_core/src/query/plan.rs @@ -145,7 +145,13 @@ mod tests { "Plan Width": 48 } }]); - let plan_value = json.as_array().unwrap().first().unwrap().get("Plan").unwrap(); + let plan_value = json + .as_array() + .unwrap() + .first() + .unwrap() + .get("Plan") + .unwrap(); let plan = parse_plan_json(plan_value).unwrap(); assert_eq!(plan.node_type, "Seq Scan"); assert_eq!(plan.relation_name.as_deref(), Some("orders")); diff --git a/crates/dry_run_core/src/query/plan_warnings.rs b/crates/dry_run_core/src/query/plan_warnings.rs index c3d9162..9be8b5a 100644 --- a/crates/dry_run_core/src/query/plan_warnings.rs +++ b/crates/dry_run_core/src/query/plan_warnings.rs @@ -72,9 +72,11 @@ fn detect_nested_loop_seq_scan(node: &PlanNode, warnings: &mut Vec<PlanWarning>) } if let Some(inner) = node.children.get(1) - && inner.node_type == "Seq Scan" && inner.plan_rows > 100.0 { - let table_name = inner.relation_name.as_deref().unwrap_or("unknown"); - warnings.push(PlanWarning { + && inner.node_type == "Seq Scan" + && inner.plan_rows > 100.0 + { + let table_name = inner.relation_name.as_deref().unwrap_or("unknown"); + warnings.push(PlanWarning { severity: 
"warning".into(), message: format!( "nested loop with sequential scan on inner side '{}' (~{} rows) — this executes once per outer row", @@ -84,7 +86,7 @@ fn detect_nested_loop_seq_scan(node: &PlanNode, warnings: &mut Vec) node_type: "Nested Loop".into(), detail: None, }); - } + } } fn detect_sort_without_index(node: &PlanNode, warnings: &mut Vec) { @@ -113,8 +115,11 @@ fn detect_sort_without_index(node: &PlanNode, warnings: &mut Vec) { fn detect_high_rows_removed(node: &PlanNode, warnings: &mut Vec) { if let Some(removed) = node.rows_removed_by_filter && let Some(actual) = node.actual_rows - && removed > 0.0 && actual > 0.0 && removed / (removed + actual) > 0.9 { - warnings.push(PlanWarning { + && removed > 0.0 + && actual > 0.0 + && removed / (removed + actual) > 0.9 + { + warnings.push(PlanWarning { severity: "warning".into(), message: format!( "'{}' filter removed {:.0} rows, kept {:.0} — index on the filter column would help", @@ -123,7 +128,7 @@ fn detect_high_rows_removed(node: &PlanNode, warnings: &mut Vec) { node_type: node.node_type.clone(), detail: node.filter.clone(), }); - } + } } fn detect_partition_pruning_issues( @@ -217,11 +222,12 @@ fn detect_cte_materialized( if child.node_type == "Append" || child.node_type == "Merge Append" { for grandchild in &child.children { if let Some(rel) = &grandchild.relation_name - && let Some(p) = find_partition_parent(rel, schema) { - let qualified = format!("{}.{}", p.schema, p.name); - e = jit::cte_over_partitioned_table(cte_name, &qualified); - break; - } + && let Some(p) = find_partition_parent(rel, schema) + { + let qualified = format!("{}.{}", p.schema, p.name); + e = jit::cte_over_partitioned_table(cte_name, &qualified); + break; + } } } } @@ -240,9 +246,9 @@ fn find_partition_parent<'a>( schema: &'a SchemaSnapshot, ) -> Option<&'a crate::schema::Table> { schema.tables.iter().find(|t| { - t.partition_info.as_ref().is_some_and(|pi| { - pi.children.iter().any(|c| c.name == child_table_name) - }) + t.partition_info + .as_ref() + .is_some_and(|pi| pi.children.iter().any(|c| c.name == child_table_name)) }) } @@ -285,14 +291,22 @@ mod tests { fn seq_scan_large_table() { let plan = make_seq_scan("users", 100_000.0); let warnings = detect_plan_warnings(&plan, None); - assert!(warnings.iter().any(|w| w.message.contains("sequential scan"))); + assert!( + warnings + .iter() + .any(|w| w.message.contains("sequential scan")) + ); } #[test] fn seq_scan_small_table_no_warning() { let plan = make_seq_scan("config", 10.0); let warnings = detect_plan_warnings(&plan, None); - assert!(!warnings.iter().any(|w| w.message.contains("sequential scan"))); + assert!( + !warnings + .iter() + .any(|w| w.message.contains("sequential scan")) + ); } #[test] @@ -346,17 +360,41 @@ mod tests { strategy: PartitionStrategy::Range, key: "created_at".into(), children: vec![ - PartitionChild { schema: "public".into(), name: "orders_q1".into(), bound: "FOR VALUES FROM ('2024-01-01') TO ('2024-04-01')".into() }, - PartitionChild { schema: "public".into(), name: "orders_q2".into(), bound: "FOR VALUES FROM ('2024-04-01') TO ('2024-07-01')".into() }, - PartitionChild { schema: "public".into(), name: "orders_q3".into(), bound: "FOR VALUES FROM ('2024-07-01') TO ('2024-10-01')".into() }, - PartitionChild { schema: "public".into(), name: "orders_q4".into(), bound: "FOR VALUES FROM ('2024-10-01') TO ('2025-01-01')".into() }, + PartitionChild { + schema: "public".into(), + name: "orders_q1".into(), + bound: "FOR VALUES FROM ('2024-01-01') TO ('2024-04-01')".into(), + }, + 
PartitionChild { + schema: "public".into(), + name: "orders_q2".into(), + bound: "FOR VALUES FROM ('2024-04-01') TO ('2024-07-01')".into(), + }, + PartitionChild { + schema: "public".into(), + name: "orders_q3".into(), + bound: "FOR VALUES FROM ('2024-07-01') TO ('2024-10-01')".into(), + }, + PartitionChild { + schema: "public".into(), + name: "orders_q4".into(), + bound: "FOR VALUES FROM ('2024-10-01') TO ('2025-01-01')".into(), + }, ], }), - policies: vec![], triggers: vec![], reloptions: vec![], rls_enabled: false, + policies: vec![], + triggers: vec![], + reloptions: vec![], + rls_enabled: false, }], - enums: vec![], domains: vec![], composites: vec![], - views: vec![], functions: vec![], extensions: vec![], - gucs: vec![], node_stats: vec![], + enums: vec![], + domains: vec![], + composites: vec![], + views: vec![], + functions: vec![], + extensions: vec![], + gucs: vec![], + node_stats: vec![], } } @@ -375,8 +413,11 @@ mod tests { ..make_seq_scan("", 0.0) }; let warnings = detect_plan_warnings(&plan, Some(&schema)); - assert!(warnings.iter().any(|w| - w.message.contains("no pruning") && w.message.contains("4/4"))); + assert!( + warnings + .iter() + .any(|w| w.message.contains("no pruning") && w.message.contains("4/4")) + ); } #[test] @@ -390,7 +431,11 @@ mod tests { ..make_seq_scan("", 0.0) }; let warnings = detect_plan_warnings(&plan, Some(&schema)); - assert!(!warnings.iter().any(|w| w.message.contains("partition pruning"))); + assert!( + !warnings + .iter() + .any(|w| w.message.contains("partition pruning")) + ); } #[test] @@ -408,6 +453,10 @@ mod tests { ..make_seq_scan("", 0.0) }; let warnings = detect_plan_warnings(&plan, Some(&schema)); - assert!(warnings.iter().any(|w| w.message.contains("partial pruning"))); + assert!( + warnings + .iter() + .any(|w| w.message.contains("partial pruning")) + ); } } diff --git a/crates/dry_run_core/src/query/suggest.rs b/crates/dry_run_core/src/query/suggest.rs index 3940069..7f06c38 100644 --- a/crates/dry_run_core/src/query/suggest.rs +++ b/crates/dry_run_core/src/query/suggest.rs @@ -44,68 +44,71 @@ fn suggest_from_plan( schema: &SchemaSnapshot, suggestions: &mut Vec<IndexSuggestion>, ) { - if node.node_type == "Seq Scan" && node.plan_rows >= 1000.0 - && let Some(table_name) = &node.relation_name { - let schema_name = node.schema.as_deref().unwrap_or("public"); - let table = schema - .tables - .iter() - .find(|t| t.name == *table_name && t.schema == schema_name); - - if let Some(filter) = &node.filter - && let Some(col) = extract_filter_column(filter) - && !has_leading_index(table, &col) { - let idx_type = choose_index_type(table, &col); - let qualified = format!("{schema_name}.{table_name}"); - let idx_name = format!("idx_{table_name}_{col}"); - suggestions.push(IndexSuggestion { - table: qualified.clone(), - index_type: idx_type.to_string(), - columns: vec![col.clone()], - include_columns: vec![], - partial_predicate: None, - ddl: format!( - "CREATE INDEX CONCURRENTLY {idx_name} ON {qualified} USING {idx_type}({col});" - ), - rationale: format!( - "Seq scan on '{qualified}' filtering on '{col}' (~{} rows)", - node.plan_rows as i64 - ), - estimated_impact: estimate_impact(node.plan_rows), - }); - } + if node.node_type == "Seq Scan" + && node.plan_rows >= 1000.0 + && let Some(table_name) = &node.relation_name + { + let schema_name = node.schema.as_deref().unwrap_or("public"); + let table = schema + .tables + .iter() + .find(|t| t.name == *table_name && t.schema == schema_name); + + if let Some(filter) = &node.filter + && let Some(col) = 
extract_filter_column(filter) + && !has_leading_index(table, &col) + { + let idx_type = choose_index_type(table, &col); + let qualified = format!("{schema_name}.{table_name}"); + let idx_name = format!("idx_{table_name}_{col}"); + suggestions.push(IndexSuggestion { + table: qualified.clone(), + index_type: idx_type.to_string(), + columns: vec![col.clone()], + include_columns: vec![], + partial_predicate: None, + ddl: format!( + "CREATE INDEX CONCURRENTLY {idx_name} ON {qualified} USING {idx_type}({col});" + ), + rationale: format!( + "Seq scan on '{qualified}' filtering on '{col}' (~{} rows)", + node.plan_rows as i64 + ), + estimated_impact: estimate_impact(node.plan_rows), + }); } + } - if node.node_type == "Sort" && node.plan_rows >= 5000.0 + if node.node_type == "Sort" + && node.plan_rows >= 5000.0 && let Some(sort_keys) = &node.sort_key - && let Some((schema_name, table_name)) = find_table_in_subtree(node) { - let cols: Vec<String> = sort_keys - .iter() - .map(|k| k.split_whitespace().next().unwrap_or(k).to_string()) - .collect(); - let qualified = format!("{schema_name}.{table_name}"); - let col_list = cols.join(", "); - let idx_name = format!( - "idx_{table_name}_{}", - cols.first().unwrap_or(&"sort".into()) - ); - - suggestions.push(IndexSuggestion { - table: qualified.clone(), - index_type: "btree".into(), - columns: cols, - include_columns: vec![], - partial_predicate: None, - ddl: format!( - "CREATE INDEX CONCURRENTLY {idx_name} ON {qualified}({col_list});" - ), - rationale: format!( - "Sort on ~{} rows could be avoided with an index on ({})", - node.plan_rows as i64, col_list - ), - estimated_impact: "eliminates sort step".into(), - }); - } + && let Some((schema_name, table_name)) = find_table_in_subtree(node) + { + let cols: Vec<String> = sort_keys + .iter() + .map(|k| k.split_whitespace().next().unwrap_or(k).to_string()) + .collect(); + let qualified = format!("{schema_name}.{table_name}"); + let col_list = cols.join(", "); + let idx_name = format!( + "idx_{table_name}_{}", + cols.first().unwrap_or(&"sort".into()) + ); + + suggestions.push(IndexSuggestion { + table: qualified.clone(), + index_type: "btree".into(), + columns: cols, + include_columns: vec![], + partial_predicate: None, + ddl: format!("CREATE INDEX CONCURRENTLY {idx_name} ON {qualified}({col_list});"), + rationale: format!( + "Sort on ~{} rows could be avoided with an index on ({})", + node.plan_rows as i64, col_list + ), + estimated_impact: "eliminates sort step".into(), + }); + } for child in &node.children { suggest_from_plan(child, schema, suggestions); @@ -149,10 +152,7 @@ fn suggest_from_query_structure( let idx_type = choose_index_type(Some(table), col_name); let qualified = format!("{}.{}", table.schema, table.name); let idx_name = format!("idx_{}_{col_name}", table.name); - let reltuples = effective_stats - .as_ref() - .map(|s| s.reltuples) - .unwrap_or(0.0); + let reltuples = effective_stats.as_ref().map(|s| s.reltuples).unwrap_or(0.0); suggestions.push(IndexSuggestion { table: qualified.clone(), @@ -198,15 +198,16 @@ fn has_leading_index(table: Option<&Table>, col: &str) -> bool { fn choose_index_type<'a>(table: Option<&Table>, col: &str) -> &'a str { if let Some(table) = table - && let Some(column) = table.columns.iter().find(|c| c.name == 
col) + { + let ct = column.type_name.to_lowercase(); + if ct == "jsonb" || ct == "tsvector" { + return "gin"; + } + if ct.contains("geometry") || ct.contains("geography") || ct.contains("range") { + return "gist"; } + } "btree" } @@ -259,21 +260,71 @@ mod tests { schema: "public".into(), name: "users".into(), columns: vec![ - Column { name: "id".into(), ordinal: 1, type_name: "bigint".into(), nullable: false, default: None, identity: None, generated: None, comment: None, statistics_target: None, stats: None }, - Column { name: "email".into(), ordinal: 2, type_name: "text".into(), nullable: false, default: None, identity: None, generated: None, comment: None, statistics_target: None, stats: None }, - Column { name: "data".into(), ordinal: 3, type_name: "jsonb".into(), nullable: true, default: None, identity: None, generated: None, comment: None, statistics_target: None, stats: None }, + Column { + name: "id".into(), + ordinal: 1, + type_name: "bigint".into(), + nullable: false, + default: None, + identity: None, + generated: None, + comment: None, + statistics_target: None, + stats: None, + }, + Column { + name: "email".into(), + ordinal: 2, + type_name: "text".into(), + nullable: false, + default: None, + identity: None, + generated: None, + comment: None, + statistics_target: None, + stats: None, + }, + Column { + name: "data".into(), + ordinal: 3, + type_name: "jsonb".into(), + nullable: true, + default: None, + identity: None, + generated: None, + comment: None, + statistics_target: None, + stats: None, + }, ], constraints: vec![], indexes: vec![], comment: None, - stats: Some(TableStats { reltuples: 500_000.0, relpages: 6250, dead_tuples: 0, last_vacuum: None, last_autovacuum: None, last_analyze: None, last_autoanalyze: None, seq_scan: 0, idx_scan: 0, table_size: 50_000_000 }), + stats: Some(TableStats { + reltuples: 500_000.0, + relpages: 6250, + dead_tuples: 0, + last_vacuum: None, + last_autovacuum: None, + last_analyze: None, + last_autoanalyze: None, + seq_scan: 0, + idx_scan: 0, + table_size: 50_000_000, + }), partition_info: None, policies: vec![], triggers: vec![], reloptions: vec![], rls_enabled: false, }], - enums: vec![], domains: vec![], composites: vec![], views: vec![], functions: vec![], extensions: vec![], gucs: vec![], + enums: vec![], + domains: vec![], + composites: vec![], + views: vec![], + functions: vec![], + extensions: vec![], + gucs: vec![], node_stats: vec![], } } @@ -281,7 +332,13 @@ mod tests { #[test] fn suggest_from_where_clause() { let schema = test_schema(); - let suggestions = suggest_index("SELECT * FROM users WHERE email = 'test@example.com'", &schema, None, None).unwrap(); + let suggestions = suggest_index( + "SELECT * FROM users WHERE email = 'test@example.com'", + &schema, + None, + None, + ) + .unwrap(); assert!(!suggestions.is_empty()); assert_eq!(suggestions[0].table, "public.users"); assert!(suggestions[0].columns.contains(&"email".to_string())); @@ -292,8 +349,16 @@ mod tests { #[test] fn suggest_gin_for_jsonb() { let schema = test_schema(); - let suggestions = suggest_index("SELECT * FROM users u WHERE u.data = '{}'", &schema, None, None).unwrap(); - let jsonb = suggestions.iter().find(|s| s.columns.contains(&"data".to_string())); + let suggestions = suggest_index( + "SELECT * FROM users u WHERE u.data = '{}'", + &schema, + None, + None, + ) + .unwrap(); + let jsonb = suggestions + .iter() + .find(|s| s.columns.contains(&"data".to_string())); assert!(jsonb.is_some()); assert_eq!(jsonb.unwrap().index_type, "gin"); } @@ -302,7 +367,8 @@ mod 
tests { fn no_suggestion_for_small_table() { let mut schema = test_schema(); schema.tables[0].stats.as_mut().unwrap().reltuples = 50.0; - let suggestions = suggest_index("SELECT * FROM users WHERE email = 'x'", &schema, None, None).unwrap(); + let suggestions = + suggest_index("SELECT * FROM users WHERE email = 'x'", &schema, None, None).unwrap(); assert!(suggestions.is_empty()); } @@ -310,15 +376,44 @@ mod tests { fn no_duplicate_suggestions() { let schema = test_schema(); let plan = PlanNode { - node_type: "Seq Scan".into(), relation_name: Some("users".into()), schema: Some("public".into()), - alias: None, startup_cost: 0.0, total_cost: 500.0, plan_rows: 100_000.0, plan_width: 64, - actual_rows: None, actual_loops: None, actual_startup_time: None, actual_total_time: None, - shared_hit_blocks: None, shared_read_blocks: None, index_name: None, index_cond: None, - filter: Some("(email = 'test@example.com')".into()), rows_removed_by_filter: None, - sort_key: None, sort_method: None, hash_cond: None, join_type: None, subplans_removed: None, cte_name: None, parent_relationship: None, children: vec![], + node_type: "Seq Scan".into(), + relation_name: Some("users".into()), + schema: Some("public".into()), + alias: None, + startup_cost: 0.0, + total_cost: 500.0, + plan_rows: 100_000.0, + plan_width: 64, + actual_rows: None, + actual_loops: None, + actual_startup_time: None, + actual_total_time: None, + shared_hit_blocks: None, + shared_read_blocks: None, + index_name: None, + index_cond: None, + filter: Some("(email = 'test@example.com')".into()), + rows_removed_by_filter: None, + sort_key: None, + sort_method: None, + hash_cond: None, + join_type: None, + subplans_removed: None, + cte_name: None, + parent_relationship: None, + children: vec![], }; - let suggestions = suggest_index("SELECT * FROM users WHERE email = 'test@example.com'", &schema, Some(&plan), None).unwrap(); - let email_count = suggestions.iter().filter(|s| s.columns.contains(&"email".to_string())).count(); + let suggestions = suggest_index( + "SELECT * FROM users WHERE email = 'test@example.com'", + &schema, + Some(&plan), + None, + ) + .unwrap(); + let email_count = suggestions + .iter() + .filter(|s| s.columns.contains(&"email".to_string())) + .count(); assert_eq!(email_count, 1, "should deduplicate"); } } diff --git a/crates/dry_run_core/src/query/validate.rs b/crates/dry_run_core/src/query/validate.rs index d176434..6401e38 100644 --- a/crates/dry_run_core/src/query/validate.rs +++ b/crates/dry_run_core/src/query/validate.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use super::antipatterns::detect_antipatterns; -use super::parse::{parse_sql, ParsedQuery, ReferencedTable}; +use super::parse::{ParsedQuery, ReferencedTable, parse_sql}; use crate::error::Result; use crate::schema::SchemaSnapshot; @@ -115,12 +115,13 @@ fn validate_filter_columns( .tables .iter() .find(|t| t.name == table_ref.name && t.schema == schema_name) - && !table.columns.iter().any(|c| c.name == *col_name) { - errors.push(format!( - "column '{col_name}' does not exist on table '{}.{}'", - table.schema, table.name - )); - } + && !table.columns.iter().any(|c| c.name == *col_name) + { + errors.push(format!( + "column '{col_name}' does not exist on table '{}.{}'", + table.schema, table.name + )); + } } } } diff --git a/crates/dry_run_core/src/schema/bloat.rs b/crates/dry_run_core/src/schema/bloat.rs index 7151594..ab67179 100644 --- a/crates/dry_run_core/src/schema/bloat.rs +++ b/crates/dry_run_core/src/schema/bloat.rs @@ -96,7 +96,10 @@ fn 
lookup_type_width(type_name: &str) -> usize { "double precision" | "float8" => 8, "boolean" | "bool" => 1, "date" => 4, - "timestamp without time zone" | "timestamp" | "timestamp with time zone" | "timestamptz" => 8, + "timestamp without time zone" + | "timestamp" + | "timestamp with time zone" + | "timestamptz" => 8, "uuid" => 16, "inet" | "cidr" => 19, "macaddr" => 6, @@ -149,13 +152,7 @@ mod tests { #[test] fn estimate_bloat_ratio() { let table = make_table_with_cols(vec![("id", "bigint"), ("name", "text")]); - let est = estimate_index_bloat_from_stats( - 100_000.0, - 1000, - &["id".into()], - &table, - "btree", - ); + let est = estimate_index_bloat_from_stats(100_000.0, 1000, &["id".into()], &table, "btree"); let est = est.unwrap(); assert!(est.bloat_ratio > 0.0); assert_eq!(est.actual_pages, 1000); @@ -165,13 +162,7 @@ mod tests { #[test] fn non_btree_returns_none() { let table = make_table_with_cols(vec![("data", "jsonb")]); - let est = estimate_index_bloat_from_stats( - 100_000.0, - 500, - &["data".into()], - &table, - "gin", - ); + let est = estimate_index_bloat_from_stats(100_000.0, 500, &["data".into()], &table, "gin"); assert!(est.is_none()); } diff --git a/crates/dry_run_core/src/schema/inject.rs b/crates/dry_run_core/src/schema/inject.rs index 793322c..b12accb 100644 --- a/crates/dry_run_core/src/schema/inject.rs +++ b/crates/dry_run_core/src/schema/inject.rs @@ -2,9 +2,7 @@ use sqlx::PgPool; use tracing::info; use crate::error::{Error, Result}; -use crate::schema::types::{ - ColumnStats, IndexStats, NodeStats, SchemaSnapshot, TableStats, -}; +use crate::schema::types::{ColumnStats, IndexStats, NodeStats, SchemaSnapshot, TableStats}; #[derive(Debug)] pub struct ApplyResult { @@ -53,16 +51,19 @@ pub async fn apply_stats( regresql_loaded, }; - let mut tx = pool.begin().await.map_err(|e| { - Error::StatsInjection(format!("failed to begin transaction: {e}")) - })?; + let mut tx = pool + .begin() + .await + .map_err(|e| Error::StatsInjection(format!("failed to begin transaction: {e}")))?; // phase 1: pg_class for tables for (schema, table, stats) in &resolved.tables { match update_pg_class(&mut tx, schema, table, "r", stats.reltuples, stats.relpages).await { Ok(true) => result.tables_updated += 1, Ok(false) => { - result.skipped.push(format!("{schema}.{table}: not found on target")); + result + .skipped + .push(format!("{schema}.{table}: not found on target")); } Err(e) => { result.skipped.push(format!("{schema}.{table}: {e}")); @@ -72,15 +73,26 @@ pub async fn apply_stats( // phase 2: pg_class for indexes for (schema, _table, index_name, stats) in &resolved.indexes { - match update_pg_class(&mut tx, schema, index_name, "i", stats.reltuples, stats.relpages) - .await + match update_pg_class( + &mut tx, + schema, + index_name, + "i", + stats.reltuples, + stats.relpages, + ) + .await { Ok(true) => result.indexes_updated += 1, Ok(false) => { - result.skipped.push(format!("index {schema}.{index_name}: not found on target")); + result + .skipped + .push(format!("index {schema}.{index_name}: not found on target")); } Err(e) => { - result.skipped.push(format!("index {schema}.{index_name}: {e}")); + result + .skipped + .push(format!("index {schema}.{index_name}: {e}")); } } } @@ -90,9 +102,9 @@ pub async fn apply_stats( let meta = match lookup_column_meta(&mut tx, schema, table, column).await { Ok(Some(m)) => m, Ok(None) => { - result - .skipped - .push(format!("{schema}.{table}.{column}: column not found on target")); + result.skipped.push(format!( + "{schema}.{table}.{column}: column not 
found on target" + )); continue; } Err(e) => { @@ -113,9 +125,9 @@ pub async fn apply_stats( continue; } Err(e) => { - result - .skipped - .push(format!("{schema}.{table}.{column}: type validation failed: {e}")); + result.skipped.push(format!( + "{schema}.{table}.{column}: type validation failed: {e}" + )); continue; } }; @@ -140,9 +152,9 @@ pub async fn apply_stats( } } - tx.commit().await.map_err(|e| { - Error::StatsInjection(format!("failed to commit: {e}")) - })?; + tx.commit() + .await + .map_err(|e| Error::StatsInjection(format!("failed to commit: {e}")))?; info!( tables = result.tables_updated, @@ -210,7 +222,11 @@ fn resolve_stats(snapshot: &SchemaSnapshot, node: Option<&str>) -> Result = snapshot.node_stats.iter().map(|n| n.source.as_str()).collect(); + let available: Vec<&str> = snapshot + .node_stats + .iter() + .map(|n| n.source.as_str()) + .collect(); Error::StatsInjection(format!( "node '{}' not found. Available: {}", node_name, @@ -229,7 +245,11 @@ fn resolve_stats(snapshot: &SchemaSnapshot, node: Option<&str>) -> Result = snapshot.node_stats.iter().map(|n| n.source.as_str()).collect(); + let available: Vec<&str> = snapshot + .node_stats + .iter() + .map(|n| n.source.as_str()) + .collect(); return Err(Error::StatsInjection(format!( "multiple node stats found ({}). Use --node to select one: {}", snapshot.node_stats.len(), @@ -323,7 +343,12 @@ fn resolve_from_inline(snapshot: &SchemaSnapshot) -> ResolvedStats { } } -fn find_column_type(snapshot: &SchemaSnapshot, schema: &str, table: &str, column: &str) -> Option { +fn find_column_type( + snapshot: &SchemaSnapshot, + schema: &str, + table: &str, + column: &str, +) -> Option { snapshot .tables .iter() @@ -416,13 +441,15 @@ async fn lookup_column_meta( .fetch_optional(&mut **tx) .await?; - Ok(row.map(|(attrelid, attnum, _atttypid, eq_opr, lt_opr)| ColumnMeta { - attrelid, - attnum, - type_name: String::new(), // filled in by caller after validate_type_name - eq_opr, - lt_opr, - })) + Ok( + row.map(|(attrelid, attnum, _atttypid, eq_opr, lt_opr)| ColumnMeta { + attrelid, + attnum, + type_name: String::new(), // filled in by caller after validate_type_name + eq_opr, + lt_opr, + }), + ) } /// Validate a type name against the target database, returning normalized form. 
@@ -481,33 +508,35 @@ async fn inject_column_stats( let mut slot_idx = 0; // MCV slot (stakind = 1) - if let (Some(mcv_vals), Some(mcv_freqs)) = - (&stats.most_common_vals, &stats.most_common_freqs) - && let Some(eq_op) = meta.eq_opr { - slot_kinds[slot_idx] = 1; - slot_ops[slot_idx] = eq_op; - slot_numbers[slot_idx] = Some(mcv_freqs.clone()); - slot_values[slot_idx] = Some(mcv_vals.clone()); - slot_idx += 1; - } + if let (Some(mcv_vals), Some(mcv_freqs)) = (&stats.most_common_vals, &stats.most_common_freqs) + && let Some(eq_op) = meta.eq_opr + { + slot_kinds[slot_idx] = 1; + slot_ops[slot_idx] = eq_op; + slot_numbers[slot_idx] = Some(mcv_freqs.clone()); + slot_values[slot_idx] = Some(mcv_vals.clone()); + slot_idx += 1; + } // Histogram slot (stakind = 2) if let Some(ref hist) = stats.histogram_bounds - && let Some(lt_op) = meta.lt_opr { - slot_kinds[slot_idx] = 2; - slot_ops[slot_idx] = lt_op; - slot_values[slot_idx] = Some(hist.clone()); - slot_idx += 1; - } + && let Some(lt_op) = meta.lt_opr + { + slot_kinds[slot_idx] = 2; + slot_ops[slot_idx] = lt_op; + slot_values[slot_idx] = Some(hist.clone()); + slot_idx += 1; + } // Correlation slot (stakind = 3) if let Some(corr) = stats.correlation - && let Some(lt_op) = meta.lt_opr { - slot_kinds[slot_idx] = 3; - slot_ops[slot_idx] = lt_op; - slot_numbers[slot_idx] = Some(format!("{{{corr}}}")); - // no stavalues for correlation - } + && let Some(lt_op) = meta.lt_opr + { + slot_kinds[slot_idx] = 3; + slot_ops[slot_idx] = lt_op; + slot_numbers[slot_idx] = Some(format!("{{{corr}}}")); + // no stavalues for correlation + } // Build dynamic INSERT — we need dynamic SQL because stavalues is anyarray // and we need to cast to the actual column type diff --git a/crates/dry_run_core/src/schema/introspect/mod.rs b/crates/dry_run_core/src/schema/introspect/mod.rs index dc56053..49716c7 100644 --- a/crates/dry_run_core/src/schema/introspect/mod.rs +++ b/crates/dry_run_core/src/schema/introspect/mod.rs @@ -15,7 +15,7 @@ use sqlx::postgres::PgRow; use sqlx::{PgPool, Row}; use tracing::info; -use super::hash::{compute_content_hash, HashInput}; +use super::hash::{HashInput, compute_content_hash}; use super::types::*; use crate::error::Result; @@ -60,19 +60,21 @@ pub async fn introspect_schema(pool: &PgPool) -> Result<SchemaSnapshot> { )?; // Group 2: top-level objects. 
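A short aside on the `// Build dynamic INSERT` comment in the `inject_column_stats` hunk above: `pg_statistic.stavalues` is `anyarray`, so the array literal has to be cast through the concrete column type inside the SQL text itself; it cannot be bound as a parameter of a fixed Rust type. A sketch of the cast fragment only (the shape is an assumption; the crate's real statement carries the full `pg_statistic` row):

```rust
// Sketch, not the crate's code: a text literal like '{1,5,9}' only becomes a
// value pg_statistic will accept once it is parsed with the real element type.
fn stavalues_cast_expr(placeholder: &str, col_type: &str) -> String {
    format!("{placeholder}::text::{col_type}[]")
}

fn main() {
    // Produces "$5::text::integer[]": '{1,5,9}' is parsed as an integer array.
    assert_eq!(stavalues_cast_expr("$5", "integer"), "$5::text::integer[]");
}
```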
- let (enums, domains, composites, views, functions, extensions, gucs, is_standby) = - tokio::try_join!( - catalog::fetch_enums(pool), - catalog::fetch_domains(pool), - catalog::fetch_composites(pool), - objects::fetch_views(pool), - objects::fetch_functions(pool), - objects::fetch_extensions(pool), - objects::fetch_gucs(pool), - fetch_is_standby(pool), - )?; - - let with_vacuum = raw_table_stats.iter().filter(|s| s.last_autovacuum.is_some()).count(); + let (enums, domains, composites, views, functions, extensions, gucs, is_standby) = tokio::try_join!( + catalog::fetch_enums(pool), + catalog::fetch_domains(pool), + catalog::fetch_composites(pool), + objects::fetch_views(pool), + objects::fetch_functions(pool), + objects::fetch_extensions(pool), + objects::fetch_gucs(pool), + fetch_is_standby(pool), + )?; + + let with_vacuum = raw_table_stats + .iter() + .filter(|s| s.last_autovacuum.is_some()) + .count(); if with_vacuum == 0 && !raw_table_stats.is_empty() { if is_standby { info!("all vacuum timestamps are null;expected on standby"); diff --git a/crates/dry_run_core/src/schema/mod.rs b/crates/dry_run_core/src/schema/mod.rs index 75c3ab5..fdaecf8 100644 --- a/crates/dry_run_core/src/schema/mod.rs +++ b/crates/dry_run_core/src/schema/mod.rs @@ -1,13 +1,13 @@ pub mod bloat; mod hash; -pub mod profile; -pub mod vacuum; pub mod inject; mod introspect; +pub mod profile; mod types; +pub mod vacuum; -pub use hash::{compute_content_hash, HashInput}; -pub use inject::{apply_stats, ApplyResult}; +pub use hash::{HashInput, compute_content_hash}; +pub use inject::{ApplyResult, apply_stats}; pub use introspect::{fetch_is_standby, fetch_stats_only, introspect_schema}; pub use profile::*; pub use types::*; diff --git a/crates/dry_run_core/src/schema/profile.rs b/crates/dry_run_core/src/schema/profile.rs index 8eee693..c115d62 100644 --- a/crates/dry_run_core/src/schema/profile.rs +++ b/crates/dry_run_core/src/schema/profile.rs @@ -59,18 +59,16 @@ pub fn column_selectivity(col: &Column, table_rows: f64) -> f64 { /// Returns Some((dominant_value, frequency)) when a single value exceeds the /// given frequency threshold. 
-pub fn has_skewed_distribution( - stats: &ColumnStats, - threshold: f64, -) -> Option<(String, f64)> { +pub fn has_skewed_distribution(stats: &ColumnStats, threshold: f64) -> Option<(String, f64)> { let vals = stats.most_common_vals.as_deref().map(parse_pg_array)?; let freqs = stats.most_common_freqs.as_deref().map(parse_pg_array)?; for (v, f_str) in vals.iter().zip(freqs.iter()) { if let Ok(f) = f_str.parse::<f64>() - && f > threshold { - return Some((v.clone(), f)); - } + && f > threshold + { + return Some((v.clone(), f)); + } } None } @@ -211,29 +209,33 @@ fn parse_top_values(s: &ColumnStats, limit: usize) -> Vec<String> { fn profile_note(col: &Column, s: &ColumnStats, table_rows: f64) -> Option<String> { // low-cardinality text column -> suggest enum if let Some(nd) = s.n_distinct - && nd > 0.0 && nd <= 10.0 { - let t = col.type_name.to_lowercase(); - if t.contains("text") || t.contains("varchar") || t.contains("character varying") { - return Some("Consider using an enum type".to_string()); - } + && nd > 0.0 + && nd <= 10.0 + { + let t = col.type_name.to_lowercase(); + if t.contains("text") || t.contains("varchar") || t.contains("character varying") { + return Some("Consider using an enum type".to_string()); } + } // very high null ratio if let Some(nf) = s.null_frac - && nf > 0.8 { - return Some( - "Very high null ratio; partial index WHERE col IS NOT NULL recommended" - .to_string(), - ); - } + && nf > 0.8 + { + return Some( + "Very high null ratio; partial index WHERE col IS NOT NULL recommended".to_string(), + ); + } // low physical correlation on large table if let Some(corr) = s.correlation - && corr.abs() < 0.3 && table_rows > 100_000.0 { - return Some( - "Low physical correlation; BRIN index will be ineffective, use btree".to_string(), - ); - } + && corr.abs() < 0.3 + && table_rows > 100_000.0 + { + return Some( + "Low physical correlation; BRIN index will be ineffective, use btree".to_string(), + ); + } None } @@ -458,7 +460,10 @@ mod tests { histogram_bounds: None, correlation: Some(0.95), }; - assert_eq!(profile_correlation(&stats), Some("well ordered".to_string())); + assert_eq!( + profile_correlation(&stats), + Some("well ordered".to_string()) + ); } #[test] diff --git a/crates/dry_run_core/src/schema/types.rs b/crates/dry_run_core/src/schema/types.rs index 47b20bb..b2a2744 100644 --- a/crates/dry_run_core/src/schema/types.rs +++ b/crates/dry_run_core/src/schema/types.rs @@ -47,7 +47,11 @@ pub struct Table { pub policies: Vec, #[serde(default, deserialize_with = "null_as_empty_vec")] pub triggers: Vec, - #[serde(default, deserialize_with = "null_as_empty_vec", skip_serializing_if = "Vec::is_empty")] + #[serde( + default, + deserialize_with = "null_as_empty_vec", + skip_serializing_if = "Vec::is_empty" + )] pub reloptions: Vec<String>, pub rls_enabled: bool, } @@ -346,7 +350,10 @@ pub fn aggregate_table_stats( let last_vacuum = primary_stats.iter().filter_map(|s| s.last_vacuum).max(); let last_autovacuum = primary_stats.iter().filter_map(|s| s.last_autovacuum).max(); let last_analyze = primary_stats.iter().filter_map(|s| s.last_analyze).max(); - let last_autoanalyze = primary_stats.iter().filter_map(|s| s.last_autoanalyze).max(); + let last_autoanalyze = primary_stats + .iter() + .filter_map(|s| s.last_autoanalyze) + .max(); Some(TableStats { reltuples, @@ -418,7 +425,9 @@ pub fn summarize_table_stats(node_stats: &[NodeStats]) -> Vec<TableSummary> { }); entry.total_seq_scan += ts.stats.seq_scan; entry.total_idx_scan += ts.stats.idx_scan; - entry.per_node_seq.push((ns.source.clone(), ts.stats.seq_scan)); + entry
.per_node_seq + .push((ns.source.clone(), ts.stats.seq_scan)); } } @@ -426,10 +435,7 @@ pub fn summarize_table_stats(node_stats: &[NodeStats]) -> Vec<TableSummary> { } // Compute anomaly flags for a single table summary. -pub fn detect_table_flags( - summary: &TableSummary, - node_stats: &[NodeStats], -) -> Vec { +pub fn detect_table_flags(summary: &TableSummary, node_stats: &[NodeStats]) -> Vec { let mut flags = Vec::new(); if summary.total_seq_scan > 100 && summary.total_idx_scan > 0 { @@ -510,7 +516,11 @@ pub fn detect_seq_scan_imbalance( } let min = nonzero.iter().map(|(_, v)| *v).min().unwrap_or(1); - let (hot_node, max) = nonzero.iter().max_by_key(|(_, v)| *v).copied().unwrap_or(("", 1)); + let (hot_node, max) = nonzero + .iter() + .max_by_key(|(_, v)| *v) + .copied() + .unwrap_or(("", 1)); if min > 0 && max / min >= 5 { Some(NodeImbalanceInfo { @@ -551,21 +561,26 @@ pub fn detect_bloated_indexes(tables: &[Table], threshold: f64) -> Vec<BloatedIndexEntry> - && est.bloat_ratio > threshold { - entries.push(BloatedIndexEntry { - schema: table.schema.clone(), - table: table.name.clone(), - index_name: idx.name.clone(), - bloat_ratio: est.bloat_ratio, - actual_pages: est.actual_pages, - expected_pages: est.expected_pages, - definition: idx.definition.clone(), - }); - } + && est.bloat_ratio > threshold + { + entries.push(BloatedIndexEntry { + schema: table.schema.clone(), + table: table.name.clone(), + index_name: idx.name.clone(), + bloat_ratio: est.bloat_ratio, + actual_pages: est.actual_pages, + expected_pages: est.expected_pages, + definition: idx.definition.clone(), + }); + } } } - entries.sort_by(|a, b| b.bloat_ratio.partial_cmp(&a.bloat_ratio).unwrap_or(std::cmp::Ordering::Equal)); + entries.sort_by(|a, b| { + b.bloat_ratio + .partial_cmp(&a.bloat_ratio) + .unwrap_or(std::cmp::Ordering::Equal) + }); entries } @@ -585,17 +600,18 @@ pub fn detect_unused_indexes(node_stats: &[NodeStats], tables: &[Table]) -> Vec< continue; } if let Some(ref stats) = idx.stats - && stats.idx_scan == 0 { - entries.push(UnusedIndexEntry { - schema: t.schema.clone(), - table: t.name.clone(), - index_name: idx.name.clone(), - total_idx_scan: 0, - total_size_bytes: stats.size, - is_unique: idx.is_unique, - definition: idx.definition.clone(), - }); - } + && stats.idx_scan == 0 + { + entries.push(UnusedIndexEntry { + schema: t.schema.clone(), + table: t.name.clone(), + index_name: idx.name.clone(), + total_idx_scan: 0, + total_size_bytes: stats.size, + is_unique: idx.is_unique, + definition: idx.definition.clone(), + }); + } } } } else { @@ -675,7 +691,12 @@ mod tests { } } - fn make_index(name: &str, is_primary: bool, is_unique: bool, stats: Option<IndexStats>) -> Index { + fn make_index( + name: &str, + is_primary: bool, + is_unique: bool, + stats: Option<IndexStats>, + ) -> Index { Index { name: name.into(), columns: vec!["col".into()], @@ -724,9 +745,15 @@ mod tests { #[test] fn test_single_node_unused_index_detected() { - let tables = vec![make_table("orders", vec![ - make_index("idx_unused", false, false, Some(make_index_stats(0, 8192))), - ])]; + let tables = vec![make_table( + "orders", + vec![make_index( + "idx_unused", + false, + false, + Some(make_index_stats(0, 8192)), + )], + )]; let result = detect_unused_indexes(&[], &tables); assert_eq!(result.len(), 1); @@ -736,9 +763,15 @@ mod tests { #[test] fn test_single_node_used_index_not_reported() { - let tables = vec![make_table("orders", vec![ - make_index("idx_used", false, false, Some(make_index_stats(42, 8192))), - ])]; + let tables = vec![make_table( + "orders", + vec![make_index( + "idx_used", + false, + false, + 
Some(make_index_stats(42, 8192)), + )], + )]; let result = detect_unused_indexes(&[], &tables); assert!(result.is_empty()); @@ -746,9 +779,15 @@ mod tests { #[test] fn test_single_node_primary_key_skipped() { - let tables = vec![make_table("orders", vec![ - make_index("orders_pkey", true, true, Some(make_index_stats(0, 8192))), - ])]; + let tables = vec![make_table( + "orders", + vec![make_index( + "orders_pkey", + true, + true, + Some(make_index_stats(0, 8192)), + )], + )]; let result = detect_unused_indexes(&[], &tables); assert!(result.is_empty()); @@ -756,9 +795,10 @@ mod tests { #[test] fn test_single_node_no_stats_skipped() { - let tables = vec![make_table("orders", vec![ - make_index("idx_no_stats", false, false, None), - ])]; + let tables = vec![make_table( + "orders", + vec![make_index("idx_no_stats", false, false, None)], + )]; let result = detect_unused_indexes(&[], &tables); assert!(result.is_empty()); @@ -766,9 +806,15 @@ mod tests { #[test] fn test_single_node_unique_flag_preserved() { - let tables = vec![make_table("orders", vec![ - make_index("idx_unique_unused", false, true, Some(make_index_stats(0, 4096))), - ])]; + let tables = vec![make_table( + "orders", + vec![make_index( + "idx_unique_unused", + false, + true, + Some(make_index_stats(0, 4096)), + )], + )]; let result = detect_unused_indexes(&[], &tables); assert_eq!(result.len(), 1); @@ -779,21 +825,30 @@ mod tests { #[test] fn test_multi_node_unused_across_all_nodes() { - let tables = vec![make_table("orders", vec![ - make_index("idx_unused", false, false, None), - ])]; + let tables = vec![make_table( + "orders", + vec![make_index("idx_unused", false, false, None)], + )]; let node_stats = vec![ - make_node_stats("node1", vec![NodeIndexStats { - schema: "public".into(), table: "orders".into(), - index_name: "idx_unused".into(), - stats: make_index_stats(0, 8192), - }]), - make_node_stats("node2", vec![NodeIndexStats { - schema: "public".into(), table: "orders".into(), - index_name: "idx_unused".into(), - stats: make_index_stats(0, 16384), - }]), + make_node_stats( + "node1", + vec![NodeIndexStats { + schema: "public".into(), + table: "orders".into(), + index_name: "idx_unused".into(), + stats: make_index_stats(0, 8192), + }], + ), + make_node_stats( + "node2", + vec![NodeIndexStats { + schema: "public".into(), + table: "orders".into(), + index_name: "idx_unused".into(), + stats: make_index_stats(0, 16384), + }], + ), ]; let result = detect_unused_indexes(&node_stats, &tables); @@ -805,21 +860,30 @@ mod tests { #[test] fn test_multi_node_used_on_one_node_not_reported() { - let tables = vec![make_table("orders", vec![ - make_index("idx_partial_use", false, false, None), - ])]; + let tables = vec![make_table( + "orders", + vec![make_index("idx_partial_use", false, false, None)], + )]; let node_stats = vec![ - make_node_stats("node1", vec![NodeIndexStats { - schema: "public".into(), table: "orders".into(), - index_name: "idx_partial_use".into(), - stats: make_index_stats(0, 8192), - }]), - make_node_stats("node2", vec![NodeIndexStats { - schema: "public".into(), table: "orders".into(), - index_name: "idx_partial_use".into(), - stats: make_index_stats(5, 8192), - }]), + make_node_stats( + "node1", + vec![NodeIndexStats { + schema: "public".into(), + table: "orders".into(), + index_name: "idx_partial_use".into(), + stats: make_index_stats(0, 8192), + }], + ), + make_node_stats( + "node2", + vec![NodeIndexStats { + schema: "public".into(), + table: "orders".into(), + index_name: "idx_partial_use".into(), + stats: 
+                    stats: make_index_stats(5, 8192),
+                }],
+            ),
         ];
 
         let result = detect_unused_indexes(&node_stats, &tables);
         assert!(result.is_empty());
@@ -828,17 +892,20 @@ mod tests {
 
     #[test]
     fn test_multi_node_primary_key_skipped() {
-        let tables = vec![make_table("orders", vec![
-            make_index("orders_pkey", true, true, None),
-        ])];
-
-        let node_stats = vec![
-            make_node_stats("node1", vec![NodeIndexStats {
-                schema: "public".into(), table: "orders".into(),
+        let tables = vec![make_table(
+            "orders",
+            vec![make_index("orders_pkey", true, true, None)],
+        )];
+
+        let node_stats = vec![make_node_stats(
+            "node1",
+            vec![NodeIndexStats {
+                schema: "public".into(),
+                table: "orders".into(),
                 index_name: "orders_pkey".into(),
                 stats: make_index_stats(0, 8192),
-            }]),
-        ];
+            }],
+        )];
 
         let result = detect_unused_indexes(&node_stats, &tables);
         assert!(result.is_empty());
@@ -846,25 +913,31 @@ mod tests {
 
     #[test]
     fn test_multi_node_sorted_by_size_desc() {
-        let tables = vec![make_table("orders", vec![
-            make_index("idx_small", false, false, None),
-            make_index("idx_big", false, false, None),
-        ])];
-
-        let node_stats = vec![
-            make_node_stats("node1", vec![
+        let tables = vec![make_table(
+            "orders",
+            vec![
+                make_index("idx_small", false, false, None),
+                make_index("idx_big", false, false, None),
+            ],
+        )];
+
+        let node_stats = vec![make_node_stats(
+            "node1",
+            vec![
                 NodeIndexStats {
-                    schema: "public".into(), table: "orders".into(),
+                    schema: "public".into(),
+                    table: "orders".into(),
                     index_name: "idx_small".into(),
                     stats: make_index_stats(0, 1024),
                 },
                 NodeIndexStats {
-                    schema: "public".into(), table: "orders".into(),
+                    schema: "public".into(),
+                    table: "orders".into(),
                     index_name: "idx_big".into(),
                     stats: make_index_stats(0, 999_999),
                 },
-            ]),
-        ];
+            ],
+        )];
 
         let result = detect_unused_indexes(&node_stats, &tables);
         assert_eq!(result.len(), 2);
@@ -877,13 +950,15 @@ mod tests {
         // index in node_stats but not in tables — should still appear with defaults
        let tables: Vec<Table> = vec![];
 
-        let node_stats = vec![
-            make_node_stats("node1", vec![NodeIndexStats {
-                schema: "public".into(), table: "orders".into(),
+        let node_stats = vec![make_node_stats(
+            "node1",
+            vec![NodeIndexStats {
+                schema: "public".into(),
+                table: "orders".into(),
                 index_name: "idx_ghost".into(),
                 stats: make_index_stats(0, 4096),
-            }]),
-        ];
+            }],
+        )];
 
         let result = detect_unused_indexes(&node_stats, &tables);
         assert_eq!(result.len(), 1);
@@ -902,9 +977,10 @@ mod tests {
 // use aggregated multi-node stats over table-level stats
 pub fn effective_table_stats(table: &Table, schema: &SchemaSnapshot) -> Option {
     if !schema.node_stats.is_empty()
-        && let Some(agg) = aggregate_table_stats(&schema.node_stats, &table.schema, &table.name) {
-        return Some(agg);
-    }
+        && let Some(agg) = aggregate_table_stats(&schema.node_stats, &table.schema, &table.name)
+    {
+        return Some(agg);
+    }
     table.stats.clone()
 }
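A note for readers tracing where these numbers originate: `idx_scan` and per-index size come from PostgreSQL's cumulative statistics. The crate's actual introspection SQL is not part of this diff, so treat the query below as an assumed rough equivalent of the raw inputs `detect_unused_indexes` works from:

```sh
# Illustrative only — dryrun's real introspection queries are not shown in this diff.
psql "$DATABASE_URL" -c "
  SELECT schemaname, relname, indexrelname, idx_scan,
         pg_relation_size(indexrelid) AS size_bytes
    FROM pg_stat_user_indexes
   ORDER BY idx_scan ASC, size_bytes DESC;"
```

The tests above pin down the conservative multi-node rule: an index is reported as unused only when `idx_scan` is zero on every node.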
diff --git a/crates/dry_run_core/src/schema/vacuum.rs b/crates/dry_run_core/src/schema/vacuum.rs
index a8bc129..3fa6951 100644
--- a/crates/dry_run_core/src/schema/vacuum.rs
+++ b/crates/dry_run_core/src/schema/vacuum.rs
@@ -98,21 +98,25 @@ pub fn analyze_vacuum_health(snap: &SchemaSnapshot) -> Vec {
         let mut av_enabled = defaults.enabled;
 
         if let Some(v) = opts.get("autovacuum_vacuum_threshold")
-            && let Ok(parsed) = v.parse::<i64>() {
-            threshold = parsed;
-        }
+            && let Ok(parsed) = v.parse::<i64>()
+        {
+            threshold = parsed;
+        }
         if let Some(v) = opts.get("autovacuum_vacuum_scale_factor")
-            && let Ok(parsed) = v.parse::<f64>() {
-            scale_factor = parsed;
-        }
+            && let Ok(parsed) = v.parse::<f64>()
+        {
+            scale_factor = parsed;
+        }
         if let Some(v) = opts.get("autovacuum_analyze_threshold")
-            && let Ok(parsed) = v.parse::<i64>() {
-            analyze_threshold = parsed;
-        }
+            && let Ok(parsed) = v.parse::<i64>()
+        {
+            analyze_threshold = parsed;
+        }
         if let Some(v) = opts.get("autovacuum_analyze_scale_factor")
-            && let Ok(parsed) = v.parse::<f64>() {
-            analyze_scale_factor = parsed;
-        }
+            && let Ok(parsed) = v.parse::<f64>()
+        {
+            analyze_scale_factor = parsed;
+        }
         if let Some(v) = opts.get("autovacuum_enabled") {
             av_enabled = v == "on" || v == "true";
         }
@@ -156,9 +160,7 @@ pub fn analyze_vacuum_health(snap: &SchemaSnapshot) -> Vec {
         ));
     }
 
-    if stats.reltuples > 0.0
-        && stats.dead_tuples as f64 / stats.reltuples > 0.10
-    {
+    if stats.reltuples > 0.0 && stats.dead_tuples as f64 / stats.reltuples > 0.10 {
         recommendations.push(format!(
             "high dead tuple ratio: {} dead / {}k live ({:.1}%)",
             stats.dead_tuples,
@@ -265,7 +267,12 @@ mod tests {
         let snap = make_snap(vec![make_table_with_stats("big", 5_000_000.0, 100)]);
         let results = analyze_vacuum_health(&snap);
         assert_eq!(results.len(), 1);
-        assert!(results[0].recommendations.iter().any(|r| r.contains("large table")));
+        assert!(
+            results[0]
+                .recommendations
+                .iter()
+                .any(|r| r.contains("large table"))
+        );
     }
 
     #[test]
@@ -273,7 +280,12 @@ mod tests {
         let snap = make_snap(vec![make_table_with_stats("dirty", 100_000.0, 20_000)]);
         let results = analyze_vacuum_health(&snap);
         assert_eq!(results.len(), 1);
-        assert!(results[0].recommendations.iter().any(|r| r.contains("high dead tuple")));
+        assert!(
+            results[0]
+                .recommendations
+                .iter()
+                .any(|r| r.contains("high dead tuple"))
+        );
     }
 
     #[test]
@@ -283,17 +295,38 @@ mod tests {
         let snap = make_snap(vec![table]);
         let results = analyze_vacuum_health(&snap);
         assert_eq!(results.len(), 1);
-        assert!(results[0].recommendations.iter().any(|r| r.contains("disabled")));
+        assert!(
+            results[0]
+                .recommendations
+                .iter()
+                .any(|r| r.contains("disabled"))
+        );
         assert!(!results[0].autovacuum_enabled);
     }
 
     #[test]
     fn parses_defaults_from_gucs() {
         let gucs = vec![
-            GucSetting { name: "autovacuum_vacuum_threshold".into(), setting: "100".into(), unit: None },
-            GucSetting { name: "autovacuum_vacuum_scale_factor".into(), setting: "0.05".into(), unit: None },
-            GucSetting { name: "autovacuum_analyze_threshold".into(), setting: "200".into(), unit: None },
-            GucSetting { name: "autovacuum_analyze_scale_factor".into(), setting: "0.02".into(), unit: None },
+            GucSetting {
+                name: "autovacuum_vacuum_threshold".into(),
+                setting: "100".into(),
+                unit: None,
+            },
+            GucSetting {
+                name: "autovacuum_vacuum_scale_factor".into(),
+                setting: "0.05".into(),
+                unit: None,
+            },
+            GucSetting {
+                name: "autovacuum_analyze_threshold".into(),
+                setting: "200".into(),
+                unit: None,
+            },
+            GucSetting {
+                name: "autovacuum_analyze_scale_factor".into(),
+                setting: "0.02".into(),
+                unit: None,
+            },
         ];
         let d = parse_autovacuum_defaults(&gucs);
         assert_eq!(d.vacuum_threshold, 100);
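For the arithmetic behind the settings parsed above: PostgreSQL triggers autovacuum once dead tuples exceed `autovacuum_vacuum_threshold + autovacuum_vacuum_scale_factor * reltuples`. A worked example with invented numbers, plus the kind of per-table reloptions that surface in the `opts` map the code reads (table name and connection URL are assumptions):

```sh
# With stock defaults (threshold = 50, scale_factor = 0.2), a table with
# 1,000,000 live tuples vacuums after 50 + 0.2 * 1,000,000 = 200,050 dead tuples.
# Per-table overrides like these are what the opts.get(...) calls read back:
psql "$DATABASE_URL" -c \
  "ALTER TABLE orders
     SET (autovacuum_vacuum_threshold = 1000,
          autovacuum_vacuum_scale_factor = 0.01);"
```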
diff --git a/crates/dry_run_core/src/version.rs b/crates/dry_run_core/src/version.rs
index 3ee5c4f..2b262d7 100644
--- a/crates/dry_run_core/src/version.rs
+++ b/crates/dry_run_core/src/version.rs
@@ -85,7 +85,14 @@ mod tests {
             "PostgreSQL 17.2 on x86_64-pc-linux-gnu, compiled by gcc 12.2.0, 64-bit",
         )
         .unwrap();
-        assert_eq!(v, PgVersion { major: 17, minor: 2, patch: 0 });
+        assert_eq!(
+            v,
+            PgVersion {
+                major: 17,
+                minor: 2,
+                patch: 0
+            }
+        );
     }
 
     #[test]
@@ -94,19 +101,40 @@ mod tests {
             "PostgreSQL 16.1.3 (Debian 16.1.3-1) on aarch64-unknown-linux-gnu",
         )
         .unwrap();
-        assert_eq!(v, PgVersion { major: 16, minor: 1, patch: 3 });
+        assert_eq!(
+            v,
+            PgVersion {
+                major: 16,
+                minor: 1,
+                patch: 3
+            }
+        );
     }
 
     #[test]
     fn parse_pg14_beta() {
         let v = PgVersion::parse_from_version_string("PostgreSQL 14.0beta1 on x86_64").unwrap();
-        assert_eq!(v, PgVersion { major: 14, minor: 0, patch: 0 });
+        assert_eq!(
+            v,
+            PgVersion {
+                major: 14,
+                minor: 0,
+                patch: 0
+            }
+        );
     }
 
     #[test]
     fn parse_pg12_minor_only() {
         let v = PgVersion::parse_from_version_string("PostgreSQL 12.18 on aarch64").unwrap();
-        assert_eq!(v, PgVersion { major: 12, minor: 18, patch: 0 });
+        assert_eq!(
+            v,
+            PgVersion {
+                major: 12,
+                minor: 18,
+                patch: 0
+            }
+        );
     }
 
     #[test]
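The fixture strings in these tests are verbatim `SELECT version()` output, which is what `parse_from_version_string` receives; to see your own server's string (connection URL assumed):

```sh
psql "$DATABASE_URL" -tAc "SELECT version();"
# e.g. PostgreSQL 17.2 on x86_64-pc-linux-gnu, compiled by gcc 12.2.0, 64-bit
```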
diff --git a/docs/dryrun-toml.md b/docs/dryrun-toml.md
index 0386b55..b1b71b8 100644
--- a/docs/dryrun-toml.md
+++ b/docs/dryrun-toml.md
@@ -5,6 +5,9 @@ Project configuration. dryrun finds this file by walking up from the current dir
 ## Minimal example
 
 ```toml
+[project]
+id = "myapp"
+
 [default]
 profile = "offline"
 
@@ -14,6 +17,15 @@
 schema_file = ".dryrun/schema.json"
 ```
 
 That's it. Everything else has sensible defaults.
 
+## Project
+
+```toml
+[project]
+id = "myapp"
+```
+
+Identifies the project. Snapshots are keyed by `(project_id, database_id)` so a single store can hold history for multiple projects without collisions. Defaults to the cwd basename if absent.
+
 ## Profiles
 
 A profile points dryrun at a schema source, either an offline JSON snapshot or a live database connection. Most projects have two or three: one for offline work, one for local dev, maybe one for staging. Each profile has a name and exactly one source.
@@ -27,8 +39,14 @@ db_url = "postgresql://dev:dev@localhost:5432/myapp"
 
 [profiles.staging]
 db_url = "${STAGING_DATABASE_URL}"  # environment variables work
+
+[profiles.prod-auth]
+db_url = "${PROD_AUTH_DATABASE_URL}"
+database_id = "auth"  # set when a project has multiple databases
 ```
 
+`database_id` defaults to the profile name. Override it when you want the snapshot stream named differently from the profile (e.g. profile `prod-auth` → stream `auth`).
+
 Pick one with `--profile`, or set a default:
 
 ```toml
@@ -38,12 +56,16 @@ profile = "offline"
 ```
 
 ### Resolution order
 
-1. `--db` flag (CLI only, bypasses profiles entirely)
-2. `--schema-file` flag (CLI only)
-3. `--profile` flag
-4. `PROFILE` environment variable
-5. `[default].profile` in dryrun.toml
-6. Auto-discovery of `.dryrun/schema.json`
+A profile is selected from:
+
+1. `--profile` flag
+2. `PROFILE` environment variable
+3. `[default].profile` in dryrun.toml
+4. Auto-discovery of `.dryrun/schema.json` (no profile, just a schema)
+
+CLI flags `--db` and `--schema-file` override the resolved profile's matching fields for that invocation; they don't bypass the profile, so `database_id` and `project_id` are still taken from it. `--profile billing --db $OTHER` connects to `$OTHER` but keys snapshots under billing's `database_id`.
+
+Every DB command (`init`, `import`, `probe`, `dump-schema`, `lint`, `drift`, `stats apply`, all `snapshot` subcommands) accepts `--profile` and falls back to the resolved profile's `db_url` / `schema_file` when the corresponding CLI flag is omitted.
 
 Relative paths in `schema_file` are resolved from the project root (the directory containing `dryrun.toml`). Absolute paths work too.
diff --git a/examples/demo/dryrun.toml b/examples/demo/dryrun.toml
index f0bf2cf..f995369 100644
--- a/examples/demo/dryrun.toml
+++ b/examples/demo/dryrun.toml
@@ -1,3 +1,6 @@
+[project]
+id = "demo"
+
 [default]
 profile = "offline"
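Taken together, the intended multi-database flow looks roughly like this — the profile name and replica URL are assumptions carried over from the examples above, not part of the diff:

```sh
# Assumes a [profiles.billing] block as sketched in docs/dryrun-toml.md.
dryrun --profile billing snapshot take                      # uses billing's db_url
dryrun --profile billing snapshot take --db "$REPLICA_URL"  # different connection,
                                                            # same snapshot key
dryrun --profile billing snapshot diff --latest
```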