package db

import (
	"database/sql"
	"errors"
	"fmt"
	"strings"
)

// CableEndpoint identifies one side of a cable. Exactly one of PortID /
// DeviceID / IOID must be non-nil; the store enforces this whenever an
// endpoint is written (see validateEndpointEx), and also checks that the
// referenced row belongs to the cable's project.
type CableEndpoint struct {
	PortID   *int64 // references ports.id
	DeviceID *int64 // references devices.id
	IOID     *int64 // references io_markers.id
}

// CableCreate is the create-shape for /api/projects/:pid/cables.
// auto=false (default) marks the cable as manually drawn; the solver writes
// auto=true when it places its rows.
type CableCreate struct {
	TypeID int64         // must reference an existing cable_types row
	Label  string        // optional; "" is stored as NULL
	From   CableEndpoint // validated against the target project
	To     CableEndpoint // validated against the target project
	Auto   bool          // persisted as 1/0 in cables.auto
}

// CableUpdate is a partial update: nil fields leave the stored value as-is.
// PATCHing endpoint or type on an auto=1 cable should promote it to manual;
// handler logic does that (see slice 6 §5b.3).
type CableUpdate struct {
	TypeID *int64
	Label  *string        // whitespace-only clears the label (stored as NULL)
	From   *CableEndpoint // when set, replaces the whole "from" endpoint
	To     *CableEndpoint // when set, replaces the whole "to" endpoint
	Auto   *bool
}

// CreateCable inserts a cable into projectID and returns the row as stored.
// It validates that the endpoints exist in the same project, that exactly
// one of (port/device/io) is set per side, and that the cable type exists;
// validation failures wrap ErrInvalidInput. It runs directly on s.db.
func (s *Store) CreateCable(projectID int64, c CableCreate) (*Cable, error) {
	return s.createCable(s.db, projectID, c)
}

// createCable is CreateCable against any executor (*sql.DB or *sql.Tx);
// the solver uses the tx form so all of its reads and writes share one
// connection.
func (s *Store) createCable(ex execer, projectID int64, c CableCreate) (*Cable, error) { if err := s.validateEndpointEx(ex, projectID, "from", c.From); err != nil { return nil, err } if err := s.validateEndpointEx(ex, projectID, "to", c.To); err != nil { return nil, err } if err := s.assertCableTypeEx(ex, c.TypeID); err != nil { return nil, err } autoInt := 0 if c.Auto { autoInt = 1 } res, err := ex.Exec( `INSERT INTO cables (project_id, type_id, label, from_port_id, from_device_id, from_io_id, to_port_id, to_device_id, to_io_id, auto) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, projectID, c.TypeID, nullableString(c.Label), nullableInt64(c.From.PortID), nullableInt64(c.From.DeviceID), nullableInt64(c.From.IOID), nullableInt64(c.To.PortID), nullableInt64(c.To.DeviceID), nullableInt64(c.To.IOID), autoInt, ) if err != nil { return nil, mapWriteErr(err) } id, _ := res.LastInsertId() return s.getCableTx(ex, projectID, id) } // validateEndpoint is the s.db variant for public CRUD callers. func (s *Store) validateEndpoint(projectID int64, label string, e CableEndpoint) error { return s.validateEndpointEx(s.db, projectID, label, e) } // validateEndpointEx runs the same checks against any executor so the // solver can call createCable inside its tx without deadlocking on the // MaxOpenConns(1) connection that the tx holds. 
func (s *Store) validateEndpointEx(ex execer, projectID int64, label string, e CableEndpoint) error { count := 0 if e.PortID != nil { count++ } if e.DeviceID != nil { count++ } if e.IOID != nil { count++ } if count != 1 { return fmt.Errorf("%w: %s must specify exactly one of port/device/io", ErrInvalidInput, label) } if e.PortID != nil { var pid int64 err := ex.QueryRow(`SELECT project_id FROM ports WHERE id = ?`, *e.PortID).Scan(&pid) if errors.Is(err, sql.ErrNoRows) { return fmt.Errorf("%w: %s port_id %d not found", ErrInvalidInput, label, *e.PortID) } if err != nil { return err } if pid != projectID { return fmt.Errorf("%w: %s port_id %d is in another project", ErrInvalidInput, label, *e.PortID) } } if e.DeviceID != nil { var pid int64 err := ex.QueryRow(`SELECT project_id FROM devices WHERE id = ?`, *e.DeviceID).Scan(&pid) if errors.Is(err, sql.ErrNoRows) || (err == nil && pid != projectID) { return fmt.Errorf("%w: %s device_id %d not in project", ErrInvalidInput, label, *e.DeviceID) } if err != nil { return err } } if e.IOID != nil { var pid int64 err := ex.QueryRow(`SELECT project_id FROM io_markers WHERE id = ?`, *e.IOID).Scan(&pid) if errors.Is(err, sql.ErrNoRows) || (err == nil && pid != projectID) { return fmt.Errorf("%w: %s io_id %d not in project", ErrInvalidInput, label, *e.IOID) } if err != nil { return err } } return nil } // assertCableTypeEx is a lightweight existence check against any executor. 
func (s *Store) assertCableTypeEx(ex execer, id int64) error { var dummy int64 err := ex.QueryRow(`SELECT id FROM cable_types WHERE id = ?`, id).Scan(&dummy) if errors.Is(err, sql.ErrNoRows) { return fmt.Errorf("%w: cable type %d not found", ErrInvalidInput, id) } return err } func (s *Store) GetCable(projectID, id int64) (*Cable, error) { return s.getCableTx(s.db, projectID, id) } func (s *Store) getCableTx(ex execer, projectID, id int64) (*Cable, error) { var c Cable var fp, fd, fio, tp, td, tio sql.NullInt64 var label, ex2 sql.NullString var autoInt int err := ex.QueryRow( `SELECT id, project_id, type_id, label, from_port_id, from_device_id, from_io_id, to_port_id, to_device_id, to_io_id, auto, excalidraw_id, created_at, updated_at FROM cables WHERE id = ? AND project_id = ?`, id, projectID, ).Scan(&c.ID, &c.ProjectID, &c.TypeID, &label, &fp, &fd, &fio, &tp, &td, &tio, &autoInt, &ex2, &c.CreatedAt, &c.UpdatedAt) if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } if err != nil { return nil, err } if label.Valid { v := label.String c.Label = &v } if fp.Valid { v := fp.Int64 c.FromPortID = &v } if fd.Valid { v := fd.Int64 c.FromDeviceID = &v } if fio.Valid { v := fio.Int64 c.FromIOID = &v } if tp.Valid { v := tp.Int64 c.ToPortID = &v } if td.Valid { v := td.Int64 c.ToDeviceID = &v } if tio.Valid { v := tio.Int64 c.ToIOID = &v } c.Auto = autoInt != 0 if ex2.Valid { c.ExcalidrawID = &ex2.String } return &c, nil } // ListCables returns every cable in a project. func (s *Store) ListCables(projectID int64) ([]Cable, error) { return s.listCablesTx(s.db, projectID) } func (s *Store) listCablesTx(ex execer, projectID int64) ([]Cable, error) { rows, err := ex.Query( `SELECT id, project_id, type_id, label, from_port_id, from_device_id, from_io_id, to_port_id, to_device_id, to_io_id, auto, excalidraw_id, created_at, updated_at FROM cables WHERE project_id = ? 
ORDER BY id`, projectID, ) if err != nil { return nil, err } defer rows.Close() out := []Cable{} for rows.Next() { var c Cable var fp, fd, fio, tp, td, tio sql.NullInt64 var label, ex2 sql.NullString var autoInt int if err := rows.Scan(&c.ID, &c.ProjectID, &c.TypeID, &label, &fp, &fd, &fio, &tp, &td, &tio, &autoInt, &ex2, &c.CreatedAt, &c.UpdatedAt); err != nil { return nil, err } if label.Valid { v := label.String c.Label = &v } if fp.Valid { v := fp.Int64 c.FromPortID = &v } if fd.Valid { v := fd.Int64 c.FromDeviceID = &v } if fio.Valid { v := fio.Int64 c.FromIOID = &v } if tp.Valid { v := tp.Int64 c.ToPortID = &v } if td.Valid { v := td.Int64 c.ToDeviceID = &v } if tio.Valid { v := tio.Int64 c.ToIOID = &v } c.Auto = autoInt != 0 if ex2.Valid { c.ExcalidrawID = &ex2.String } out = append(out, c) } return out, rows.Err() } // UpdateCable applies a partial update. Caller-controlled — promote-to- // manual semantics live at the handler level (§5b.3: any PATCH touching // type/endpoint promotes auto→0). 
func (s *Store) UpdateCable(projectID, id int64, u CableUpdate) (*Cable, error) { cur, err := s.GetCable(projectID, id) if err != nil { return nil, err } if u.TypeID != nil { if _, err := s.GetCableType(*u.TypeID); err != nil { if errors.Is(err, ErrNotFound) { return nil, fmt.Errorf("%w: cable type %d not found", ErrInvalidInput, *u.TypeID) } return nil, err } cur.TypeID = *u.TypeID } if u.Label != nil { v := strings.TrimSpace(*u.Label) if v == "" { cur.Label = nil } else { cur.Label = &v } } if u.From != nil { if err := s.validateEndpoint(projectID, "from", *u.From); err != nil { return nil, err } cur.FromPortID = u.From.PortID cur.FromDeviceID = u.From.DeviceID cur.FromIOID = u.From.IOID } if u.To != nil { if err := s.validateEndpoint(projectID, "to", *u.To); err != nil { return nil, err } cur.ToPortID = u.To.PortID cur.ToDeviceID = u.To.DeviceID cur.ToIOID = u.To.IOID } if u.Auto != nil { cur.Auto = *u.Auto } autoInt := 0 if cur.Auto { autoInt = 1 } if _, err := s.db.Exec( `UPDATE cables SET type_id = ?, label = ?, from_port_id = ?, from_device_id = ?, from_io_id = ?, to_port_id = ?, to_device_id = ?, to_io_id = ?, auto = ?, updated_at = datetime('now') WHERE id = ? AND project_id = ?`, cur.TypeID, nullableStringPtr(cur.Label), nullableInt64(cur.FromPortID), nullableInt64(cur.FromDeviceID), nullableInt64(cur.FromIOID), nullableInt64(cur.ToPortID), nullableInt64(cur.ToDeviceID), nullableInt64(cur.ToIOID), autoInt, id, projectID, ); err != nil { return nil, mapWriteErr(err) } return s.GetCable(projectID, id) } // DeleteCable removes a cable from a project. func (s *Store) DeleteCable(projectID, id int64) error { if _, err := s.GetCable(projectID, id); err != nil { return err } if _, err := s.db.Exec( `DELETE FROM cables WHERE id = ? AND project_id = ?`, id, projectID, ); err != nil { return err } return nil } // nullableString → for label-style strings: "" → SQL NULL. 
func nullableString(s string) any {
	if s != "" {
		return s
	}
	return nil
}

// nullableStringPtr converts an optional string into a SQL argument: a nil
// pointer becomes NULL, otherwise the pointed-to value is passed through
// unchanged (an empty string stays "", it is not turned into NULL).
func nullableStringPtr(p *string) any {
	if p != nil {
		return *p
	}
	return nil
}

// execer is the method set shared by *sql.DB and *sql.Tx that the store
// helpers rely on. Taking an execer lets one helper run either directly on
// the database (public API) or inside a caller's transaction (e.g. the solver).
type execer interface {
	Exec(query string, args ...any) (sql.Result, error)
	Query(query string, args ...any) (*sql.Rows, error)
	QueryRow(query string, args ...any) *sql.Row
}