diff --git a/src/lib/db/index.ts b/src/lib/db/index.ts index e876ca9..c1baad9 100644 --- a/src/lib/db/index.ts +++ b/src/lib/db/index.ts @@ -28,13 +28,18 @@ export async function withTenantDb( ): Promise { const client: PoolClient = await pool.connect(); try { - // Set tenant context for RLS — uses parameterized SET to prevent SQL injection - await client.query(`SET LOCAL app.tenant_id = $1`, [tenantId]); + await client.query('BEGIN'); + // Set tenant context for RLS — set_config supports parameterized queries + // Third arg `true` = LOCAL (scoped to current transaction only) + await client.query(`SELECT set_config('app.tenant_id', $1, true)`, [tenantId]); const tenantDb = drizzle(client, { schema }); - return await callback(tenantDb); + const result = await callback(tenantDb); + await client.query('COMMIT'); + return result; + } catch (error) { + await client.query('ROLLBACK'); + throw error; } finally { - // RESET ensures no tenant context leaks to the next user of this connection - await client.query(`RESET app.tenant_id`); client.release(); } } diff --git a/src/lib/db/tenant.ts b/src/lib/db/tenant.ts index 0cee838..eb719f7 100644 --- a/src/lib/db/tenant.ts +++ b/src/lib/db/tenant.ts @@ -14,8 +14,14 @@ export async function withTenant( ): Promise { const client = await pool.connect(); try { - await client.query(`SET LOCAL app.tenant_id = '${tenantId}'`); - return await fn(client); + await client.query('BEGIN'); + await client.query(`SELECT set_config('app.tenant_id', $1, true)`, [tenantId]); + const result = await fn(client); + await client.query('COMMIT'); + return result; + } catch (error) { + await client.query('ROLLBACK'); + throw error; } finally { client.release(); }