mAi: #10 - multi-model backend expansion (workflow templates + compare harness)
Path 1 architecture: one comfyui adapter, workflows as data.
- workflow_template.go: embed.FS + token substitution with type-preserving
whole-value placeholders. ${prompt} → string, ${seed} → int64,
${cfg} → float64 — no JSON round-tripping. Partial matches ignored.
- comfyui.go: refactored to load workflow from embedded FS or filesystem
path. Back-compat preserved: workflow: defaults to flux1-schnell.
- workflows/{flux1-schnell,flux2-klein,sd35-medium}.json — bundled
templates. flux1-schnell migrated from hardcoded with identical node IDs.
- compare.go: new `imagen compare` subcommand. Sequential N-backend run
(one GPU on mRock — parallel would OOM), per-backend PNG, sidecar JSON
with per-model metadata + errors, composite contact sheet via Go image
package (no ImageMagick dep).
- Sample config gains flux2-klein-local + sd35-medium-local instances.
- docs/backends.md: architecture rationale + per-model HF download paths
+ how to add a new bundled workflow + compare-harness reference.
Live smoke verified: compare mock + flux-schnell-local at 768×768 →
both PNGs written, sidecar JSON has workflow="flux1-schnell" + full
metadata, contact sheet renders. Worker contract (Request → Generate)
unchanged, so flexsiebels /imagine UI API surface preserved.
Tests: 11 existing comfyui + 6 new workflow_template + 5 new compare
tests, all green.
Adding a new model is now yaml + JSON, never Go.
This commit is contained in:
386
cmd/imagen/compare.go
Normal file
386
cmd/imagen/compare.go
Normal file
@@ -0,0 +1,386 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/draw"
|
||||
"image/png"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/image/font"
|
||||
"golang.org/x/image/font/basicfont"
|
||||
"golang.org/x/image/math/fixed"
|
||||
|
||||
"mgit.msbls.de/m/ImaGen/internal/backend"
|
||||
"mgit.msbls.de/m/ImaGen/internal/config"
|
||||
"mgit.msbls.de/m/ImaGen/internal/output"
|
||||
"mgit.msbls.de/m/ImaGen/internal/prompt"
|
||||
)
|
||||
|
||||
// runCompare implements `imagen compare "<prompt>" --models a,b,c --output <dir>`.
|
||||
//
|
||||
// Each backend in --models runs sequentially against the same prompt (mRock
|
||||
// has a single GPU; parallelising would just OOM). Each generation lands as
|
||||
// a backend-suffixed file in the output dir; a contact sheet stitches them
|
||||
// together into one PNG with the backend name overlaid on each cell. A
|
||||
// sidecar JSON next to the contact sheet lists every generation with its
|
||||
// per-model metadata (latency, seed, model file, VRAM peak).
|
||||
func runCompare(ctx context.Context, args []string) error {
|
||||
fs := flag.NewFlagSet("compare", flag.ContinueOnError)
|
||||
var (
|
||||
modelsCSV string
|
||||
size string
|
||||
outDir string
|
||||
style string
|
||||
negative string
|
||||
seed int64
|
||||
steps int
|
||||
configPath string
|
||||
noContact bool
|
||||
)
|
||||
fs.StringVar(&modelsCSV, "models", "", "comma-separated backend instance names (required)")
|
||||
fs.StringVar(&size, "size", "1024x1024", "WxH for every backend")
|
||||
fs.StringVar(&outDir, "output", "", "directory to write the images + contact sheet (default: ~/Pictures/imagen/compare)")
|
||||
fs.StringVar(&style, "style", "", "style preset applied to the prompt before dispatching to each backend")
|
||||
fs.StringVar(&negative, "negative", "", "negative prompt (forwarded to every backend that supports it)")
|
||||
fs.Int64Var(&seed, "seed", 0, "deterministic seed for every backend (0 = each backend rolls its own)")
|
||||
fs.IntVar(&steps, "steps", 0, "diffusion steps (0 = each backend's default)")
|
||||
fs.StringVar(&configPath, "config", "", "config file path (default: ~/.config/imagen.yaml)")
|
||||
fs.BoolVar(&noContact, "no-contact-sheet", false, "skip the composite PNG; only write per-backend images + sidecar")
|
||||
fs.Usage = func() {
|
||||
fmt.Fprintln(fs.Output(), `Usage: imagen compare "<prompt>" --models a,b,c [flags]`)
|
||||
fs.PrintDefaults()
|
||||
}
|
||||
leadingPositional, flagArgs := splitLeadingPositional(args)
|
||||
if err := fs.Parse(flagArgs); err != nil {
|
||||
return err
|
||||
}
|
||||
positional := append(leadingPositional, fs.Args()...)
|
||||
if len(positional) == 0 {
|
||||
fs.Usage()
|
||||
return userErr("missing prompt")
|
||||
}
|
||||
rawPrompt := strings.Join(positional, " ")
|
||||
modelNames := splitCSV(modelsCSV)
|
||||
if len(modelNames) == 0 {
|
||||
return userErr("--models is required (comma-separated backend instance names)")
|
||||
}
|
||||
|
||||
w, h, err := parseSize(size)
|
||||
if err != nil {
|
||||
return userErr("bad --size: %v", err)
|
||||
}
|
||||
|
||||
cfg, cfgErr := config.Load(configPath)
|
||||
if cfgErr != nil && !os.IsNotExist(cfgErr) {
|
||||
return cfgErr
|
||||
}
|
||||
|
||||
if outDir == "" {
|
||||
home, _ := os.UserHomeDir()
|
||||
outDir = filepath.Join(home, "Pictures", "imagen", "compare")
|
||||
}
|
||||
outDir = config.ExpandPath(outDir)
|
||||
|
||||
finalPrompt, err := prompt.Apply(rawPrompt, style)
|
||||
if err != nil {
|
||||
return userErr("%v", err)
|
||||
}
|
||||
|
||||
runID := time.Now().Format("20060102-150405")
|
||||
runDir := filepath.Join(outDir, runID+"-"+output.Slug(rawPrompt))
|
||||
if err := os.MkdirAll(runDir, 0o755); err != nil {
|
||||
return fmt.Errorf("mkdir %s: %w", runDir, err)
|
||||
}
|
||||
|
||||
results := make([]compareResult, 0, len(modelNames))
|
||||
for i, name := range modelNames {
|
||||
fmt.Fprintf(os.Stderr, "[%d/%d] %s ...\n", i+1, len(modelNames), name)
|
||||
res, err := generateOne(ctx, cfg, name, finalPrompt, negative, w, h, seed, steps, runDir, rawPrompt)
|
||||
if err != nil {
|
||||
// Don't abort the whole run on a single backend failure — record
|
||||
// the error and continue. flexsiebels-style consumers want to
|
||||
// see N-1 results rather than zero when one model is offline.
|
||||
fmt.Fprintf(os.Stderr, " failed: %v\n", err)
|
||||
results = append(results, compareResult{Backend: name, Error: err.Error()})
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, " %s (%d ms)\n", res.ImagePath, res.LatencyMs)
|
||||
results = append(results, res)
|
||||
}
|
||||
|
||||
// Sidecar JSON beside the run dir captures every attempt.
|
||||
sidecar := filepath.Join(runDir, "compare.json")
|
||||
if err := writeCompareSidecar(sidecar, rawPrompt, style, negative, w, h, seed, steps, results); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Fprintln(os.Stderr, "sidecar:", sidecar)
|
||||
|
||||
// Contact sheet stitches the successful results together. If every
|
||||
// backend failed there's nothing to draw, so skip silently.
|
||||
if !noContact {
|
||||
successes := successfulResults(results)
|
||||
if len(successes) > 0 {
|
||||
sheet := filepath.Join(runDir, "contact-sheet.png")
|
||||
if err := writeContactSheet(sheet, rawPrompt, successes); err != nil {
|
||||
return fmt.Errorf("contact sheet: %w", err)
|
||||
}
|
||||
fmt.Println(sheet)
|
||||
} else {
|
||||
fmt.Fprintln(os.Stderr, "imagen compare: all backends failed; no contact sheet written")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// compareResult is one backend's output in a comparison run. Error is set
|
||||
// when Generate failed for this backend; ImagePath + Metadata are empty in
|
||||
// that case.
|
||||
type compareResult struct {
|
||||
Backend string `json:"backend"`
|
||||
ImagePath string `json:"image_path,omitempty"`
|
||||
Seed int64 `json:"seed"`
|
||||
LatencyMs int64 `json:"latency_ms,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
VRAMUsedMiB int64 `json:"vram_used_mib,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func generateOne(ctx context.Context, cfg *config.Config, name, finalPrompt, negative string, w, h int, seed int64, steps int, runDir, rawPrompt string) (compareResult, error) {
|
||||
be, err := buildBackend(cfg, name)
|
||||
if err != nil {
|
||||
return compareResult{Backend: name}, err
|
||||
}
|
||||
attachUsageSink(be)
|
||||
|
||||
req := backend.Request{
|
||||
Prompt: finalPrompt,
|
||||
NegativePrompt: negative,
|
||||
Width: w,
|
||||
Height: h,
|
||||
Steps: steps,
|
||||
Seed: seed,
|
||||
}
|
||||
res, err := be.Generate(ctx, req)
|
||||
if err != nil {
|
||||
return compareResult{Backend: name}, err
|
||||
}
|
||||
defer res.ImageReader.Close()
|
||||
|
||||
imgBytes, err := io.ReadAll(res.ImageReader)
|
||||
if err != nil {
|
||||
return compareResult{Backend: name}, fmt.Errorf("read image: %w", err)
|
||||
}
|
||||
|
||||
imgPath := filepath.Join(runDir, output.Slug(rawPrompt)+"--"+output.Slug(name)+"."+extFromMime(res.MimeType))
|
||||
if err := os.WriteFile(imgPath, imgBytes, 0o644); err != nil {
|
||||
return compareResult{Backend: name}, fmt.Errorf("write %s: %w", imgPath, err)
|
||||
}
|
||||
|
||||
cr := compareResult{
|
||||
Backend: name,
|
||||
ImagePath: imgPath,
|
||||
Seed: seedFromMetadata(res.Metadata, seed),
|
||||
LatencyMs: metaInt64(res.Metadata, "latency_ms"),
|
||||
Model: metaString(res.Metadata, "model"),
|
||||
Metadata: res.Metadata,
|
||||
}
|
||||
if v, ok := res.Metadata["vram_used_mib"].(int64); ok {
|
||||
cr.VRAMUsedMiB = v
|
||||
}
|
||||
return cr, nil
|
||||
}
|
||||
|
||||
func successfulResults(rs []compareResult) []compareResult {
|
||||
out := make([]compareResult, 0, len(rs))
|
||||
for _, r := range rs {
|
||||
if r.Error == "" && r.ImagePath != "" {
|
||||
out = append(out, r)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func writeCompareSidecar(path, rawPrompt, style, negative string, w, h int, seed int64, steps int, results []compareResult) error {
|
||||
body := map[string]any{
|
||||
"timestamp": time.Now().UTC().Format(time.RFC3339),
|
||||
"prompt": rawPrompt,
|
||||
"style": style,
|
||||
"negative": negative,
|
||||
"width": w,
|
||||
"height": h,
|
||||
"seed": seed,
|
||||
"steps": steps,
|
||||
"results": results,
|
||||
"backends": backendNames(results),
|
||||
"successful": len(successfulResults(results)),
|
||||
"total": len(results),
|
||||
}
|
||||
data, err := json.MarshalIndent(body, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal sidecar: %w", err)
|
||||
}
|
||||
return os.WriteFile(path, append(data, '\n'), 0o644)
|
||||
}
|
||||
|
||||
func backendNames(rs []compareResult) []string {
|
||||
out := make([]string, len(rs))
|
||||
for i, r := range rs {
|
||||
out[i] = r.Backend
|
||||
}
|
||||
sort.Strings(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// writeContactSheet stitches a grid of (image, label) cells into one PNG.
|
||||
// Cells are sized to fit in a target width of ~2400px while keeping each
|
||||
// individual image full-resolution (no downscale) up to the column limit;
|
||||
// past that, images sit at their native size and we just lay them out.
|
||||
//
|
||||
// The grid is a simple horizontal row when N <= 4; otherwise N/2 rows of 2.
|
||||
// This is a contact sheet, not a fancy gallery — readability for side-by-
|
||||
// side eyeballing is the goal.
|
||||
func writeContactSheet(path, prompt string, results []compareResult) error {
|
||||
if len(results) == 0 {
|
||||
return fmt.Errorf("no successful results to lay out")
|
||||
}
|
||||
cells := make([]contactCell, 0, len(results))
|
||||
for _, r := range results {
|
||||
img, err := readPNG(r.ImagePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", r.ImagePath, err)
|
||||
}
|
||||
cells = append(cells, contactCell{
|
||||
Image: img,
|
||||
Label: r.Backend,
|
||||
SubLabel: fmt.Sprintf("%dms · seed %d", r.LatencyMs, r.Seed),
|
||||
})
|
||||
}
|
||||
|
||||
cols := len(cells)
|
||||
if cols > 4 {
|
||||
cols = 2
|
||||
}
|
||||
rows := (len(cells) + cols - 1) / cols
|
||||
|
||||
const labelH = 64
|
||||
const pad = 16
|
||||
|
||||
cellW := cells[0].Image.Bounds().Dx()
|
||||
cellH := cells[0].Image.Bounds().Dy()
|
||||
for _, c := range cells {
|
||||
if w := c.Image.Bounds().Dx(); w > cellW {
|
||||
cellW = w
|
||||
}
|
||||
if h := c.Image.Bounds().Dy(); h > cellH {
|
||||
cellH = h
|
||||
}
|
||||
}
|
||||
|
||||
totalW := cols*cellW + (cols+1)*pad
|
||||
totalH := rows*(cellH+labelH) + (rows+1)*pad + 48 // header band
|
||||
|
||||
canvas := image.NewRGBA(image.Rect(0, 0, totalW, totalH))
|
||||
draw.Draw(canvas, canvas.Bounds(), &image.Uniform{C: color.RGBA{R: 30, G: 30, B: 35, A: 255}}, image.Point{}, draw.Src)
|
||||
|
||||
// Header: show the truncated prompt.
|
||||
headerText := "imagen compare — " + truncate(prompt, 100)
|
||||
drawText(canvas, headerText, pad, 30, color.RGBA{R: 240, G: 240, B: 245, A: 255})
|
||||
|
||||
for i, c := range cells {
|
||||
col := i % cols
|
||||
row := i / cols
|
||||
x0 := pad + col*(cellW+pad)
|
||||
y0 := 48 + pad + row*(cellH+labelH+pad)
|
||||
// Center the image inside the cell when smaller than the max cell size.
|
||||
iw := c.Image.Bounds().Dx()
|
||||
ih := c.Image.Bounds().Dy()
|
||||
offX := (cellW - iw) / 2
|
||||
offY := (cellH - ih) / 2
|
||||
dstRect := image.Rect(x0+offX, y0+offY, x0+offX+iw, y0+offY+ih)
|
||||
draw.Draw(canvas, dstRect, c.Image, c.Image.Bounds().Min, draw.Src)
|
||||
|
||||
// Label band underneath.
|
||||
labelY := y0 + cellH + 20
|
||||
drawText(canvas, c.Label, x0+8, labelY, color.RGBA{R: 250, G: 250, B: 250, A: 255})
|
||||
drawText(canvas, c.SubLabel, x0+8, labelY+22, color.RGBA{R: 180, G: 180, B: 190, A: 255})
|
||||
}
|
||||
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create %s: %w", path, err)
|
||||
}
|
||||
defer f.Close()
|
||||
return png.Encode(f, canvas)
|
||||
}
|
||||
|
||||
type contactCell struct {
|
||||
Image image.Image
|
||||
Label string
|
||||
SubLabel string
|
||||
}
|
||||
|
||||
func readPNG(path string) (image.Image, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
img, _, err := image.Decode(f)
|
||||
return img, err
|
||||
}
|
||||
|
||||
func drawText(dst *image.RGBA, s string, x, y int, c color.Color) {
|
||||
drawer := &font.Drawer{
|
||||
Dst: dst,
|
||||
Src: &image.Uniform{C: c},
|
||||
Face: basicfont.Face7x13,
|
||||
Dot: fixed.Point26_6{X: fixed.I(x), Y: fixed.I(y)},
|
||||
}
|
||||
drawer.DrawString(s)
|
||||
}
|
||||
|
||||
func truncate(s string, max int) string {
|
||||
if len(s) <= max {
|
||||
return s
|
||||
}
|
||||
return s[:max-1] + "…"
|
||||
}
|
||||
|
||||
func splitCSV(s string) []string {
|
||||
parts := strings.Split(s, ",")
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
out = append(out, p)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func metaInt64(m map[string]any, key string) int64 {
|
||||
v, ok := m[key]
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
switch n := v.(type) {
|
||||
case int64:
|
||||
return n
|
||||
case int:
|
||||
return int64(n)
|
||||
case float64:
|
||||
return int64(n)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
Reference in New Issue
Block a user