package db

import (
	"database/sql"
	"errors"
	"fmt"
	"strings"
)

// BundleCreate is the create-shape: a name + the cable IDs to include.
// Auto=true means the solver created the bundle; user-created bundles
// stay auto=0 and survive a re-solve.
type BundleCreate struct {
	Name     string
	CableIDs []int64
	Auto     bool
}

// BundleUpdate is the partial-update shape accepted by UpdateBundle.
// A nil field leaves the stored value unchanged; a non-nil CableIDs
// replaces the bundle's whole cable set (an empty slice clears it).
type BundleUpdate struct {
	Name     *string
	CableIDs *[]int64
}

// CreateBundle inserts a bundle + its bundle_cables rows in one tx and
// returns the stored bundle as re-read by GetBundle.
func (s *Store) CreateBundle(projectID int64, b BundleCreate) (*Bundle, error) {
	return s.createBundle(s.db, projectID, b, true)
}

// createBundle is the shared implementation behind CreateBundle.
//
// With ownTx=true it begins, commits (or rolls back) its own transaction
// on s.db and returns the bundle re-read via GetBundle. With ownTx=false,
// ex is the caller's transaction: every statement runs through it, nothing
// is committed here, and the returned Bundle is assembled locally
// (CreatedAt/UpdatedAt are left unset) because GetBundle reads through
// s.db. The caller can re-fetch via GetBundle after it commits.
//
// Errors: ErrInvalidInput (wrapped) for an empty name or a cable ID that
// is not in the project; write errors pass through mapWriteErr.
func (s *Store) createBundle(ex execer, projectID int64, b BundleCreate, ownTx bool) (*Bundle, error) {
	name := strings.TrimSpace(b.Name)
	if name == "" {
		return nil, fmt.Errorf("%w: name is required", ErrInvalidInput)
	}

	useEx := ex
	var tx *sql.Tx
	if ownTx {
		var err error
		tx, err = s.db.Begin()
		if err != nil {
			return nil, err
		}
		defer tx.Rollback()
		useEx = tx
	}

	// Validate through useEx, never through Store methods that hit s.db:
	// when the caller already holds a tx (ownTx=false), going through s.db
	// would deadlock against the connection that tx is holding under
	// MaxOpenConns(1). In ownTx mode this also keeps the cable check and
	// the inserts inside the same transaction.
	for _, cid := range b.CableIDs {
		if _, err := s.getCableTx(useEx, projectID, cid); err != nil {
			if errors.Is(err, ErrNotFound) {
				return nil, fmt.Errorf("%w: cable_id %d not in project", ErrInvalidInput, cid)
			}
			return nil, err
		}
	}

	autoInt := 0
	if b.Auto {
		autoInt = 1
	}

	res, err := useEx.Exec(
		`INSERT INTO bundles (project_id, name, auto) VALUES (?, ?, ?)`,
		projectID, name, autoInt,
	)
	if err != nil {
		return nil, mapWriteErr(err)
	}
	id, err := res.LastInsertId()
	if err != nil {
		return nil, err
	}

	for _, cid := range b.CableIDs {
		if _, err := useEx.Exec(
			`INSERT INTO bundle_cables (bundle_id, cable_id) VALUES (?, ?)`,
			id, cid,
		); err != nil {
			return nil, mapWriteErr(err)
		}
	}

	if ownTx {
		if err := tx.Commit(); err != nil {
			return nil, err
		}
		return s.GetBundle(projectID, id)
	}

	// Copy into a non-nil slice so an empty cable set looks the same as
	// what GetBundle/ListBundles return (never nil), and the caller's
	// slice is not aliased.
	cableIDs := make([]int64, len(b.CableIDs))
	copy(cableIDs, b.CableIDs)
	out := &Bundle{
		ID:        id,
		ProjectID: projectID,
		Name:      name,
		Auto:      b.Auto,
		CableIDs:  cableIDs,
	}
	return out, nil
}

// GetBundle returns bundle id scoped to projectID, with its cable IDs in
// ascending order. It returns ErrNotFound when no such bundle exists in
// the project.
func (s *Store) GetBundle(projectID, id int64) (*Bundle, error) {
	var b Bundle
	var autoInt int
	err := s.db.QueryRow(
		`SELECT id, project_id, name, auto, created_at, updated_at FROM bundles WHERE id = ? AND project_id = ?`,
		id, projectID,
	).Scan(&b.ID, &b.ProjectID, &b.Name, &autoInt, &b.CreatedAt, &b.UpdatedAt)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	if err != nil {
		return nil, err
	}
	b.Auto = autoInt != 0

	ids, err := s.bundleCableIDs(id)
	if err != nil {
		return nil, err
	}
	b.CableIDs = ids
	return &b, nil
}

// bundleCableIDs returns the cable IDs linked to bundleID in ascending
// order. The result is a non-nil (possibly empty) slice.
func (s *Store) bundleCableIDs(bundleID int64) ([]int64, error) {
	rows, err := s.db.Query(
		`SELECT cable_id FROM bundle_cables WHERE bundle_id = ?
		 ORDER BY cable_id`,
		bundleID,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	out := []int64{}
	for rows.Next() {
		var v int64
		if err := rows.Scan(&v); err != nil {
			return nil, err
		}
		out = append(out, v)
	}
	return out, rows.Err()
}

// ListBundles returns every bundle in a project, ordered by id. Each
// bundle's CableIDs is in ascending order and is an empty (non-nil) slice
// when the bundle has no cables.
func (s *Store) ListBundles(projectID int64) ([]Bundle, error) {
	rows, err := s.db.Query(
		`SELECT id, project_id, name, auto, created_at, updated_at FROM bundles WHERE project_id = ? ORDER BY id`,
		projectID,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	out := []Bundle{}
	for rows.Next() {
		var b Bundle
		var autoInt int
		if err := rows.Scan(&b.ID, &b.ProjectID, &b.Name, &autoInt, &b.CreatedAt, &b.UpdatedAt); err != nil {
			return nil, err
		}
		b.Auto = autoInt != 0
		b.CableIDs = []int64{}
		out = append(out, b)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	if len(out) == 0 {
		return out, nil
	}

	// Load every cable link for the project in one query rather than one
	// query per bundle. Ordering by (bundle_id, cable_id) keeps each
	// bundle's list ascending, matching bundleCableIDs.
	pos := make(map[int64]int, len(out))
	for i := range out {
		pos[out[i].ID] = i
	}
	links, err := s.db.Query(
		`SELECT bc.bundle_id, bc.cable_id FROM bundle_cables bc JOIN bundles b ON b.id = bc.bundle_id WHERE b.project_id = ? ORDER BY bc.bundle_id, bc.cable_id`,
		projectID,
	)
	if err != nil {
		return nil, err
	}
	defer links.Close()

	for links.Next() {
		var bundleID, cableID int64
		if err := links.Scan(&bundleID, &cableID); err != nil {
			return nil, err
		}
		// A bundle inserted after the first query is not in pos; skip it
		// so the result only describes bundles we actually listed.
		if i, ok := pos[bundleID]; ok {
			out[i].CableIDs = append(out[i].CableIDs, cableID)
		}
	}
	if err := links.Err(); err != nil {
		return nil, err
	}
	return out, nil
}

// UpdateBundle applies a partial update: name + cable set are mutable.
// Replacing the cable set wipes bundle_cables for the bundle and
// re-inserts the new set in the same tx as the rename.
func (s *Store) UpdateBundle(projectID, id int64, u BundleUpdate) (*Bundle, error) { cur, err := s.GetBundle(projectID, id) if err != nil { return nil, err } if u.Name != nil { v := strings.TrimSpace(*u.Name) if v == "" { return nil, fmt.Errorf("%w: name cannot be empty", ErrInvalidInput) } cur.Name = v } tx, err := s.db.Begin() if err != nil { return nil, err } defer tx.Rollback() if _, err := tx.Exec( `UPDATE bundles SET name = ?, updated_at = datetime('now') WHERE id = ?`, cur.Name, id, ); err != nil { return nil, mapWriteErr(err) } if u.CableIDs != nil { if _, err := tx.Exec(`DELETE FROM bundle_cables WHERE bundle_id = ?`, id); err != nil { return nil, err } for _, cid := range *u.CableIDs { if _, err := s.getCableTx(tx, projectID, cid); err != nil { return nil, fmt.Errorf("%w: cable_id %d not in project", ErrInvalidInput, cid) } if _, err := tx.Exec( `INSERT INTO bundle_cables (bundle_id, cable_id) VALUES (?, ?)`, id, cid, ); err != nil { return nil, mapWriteErr(err) } } } if err := tx.Commit(); err != nil { return nil, err } return s.GetBundle(projectID, id) } func (s *Store) DeleteBundle(projectID, id int64) error { if _, err := s.GetBundle(projectID, id); err != nil { return err } if _, err := s.db.Exec( `DELETE FROM bundles WHERE id = ? AND project_id = ?`, id, projectID, ); err != nil { return err } return nil }