Skip to content

Commit 4952f07

Browse files
authored
Refactoring of check category connections (#905)
#### Description This PR refactors the replication check category so that all of its checks share a single PostgreSQL connection instead of each opening its own. The new `postgres.LazyConn` is not thread-safe; it is intended for checks that are executed sequentially. Once all checks have finished, the cleanup function must be called to close the shared connection. #### Type of Change Please select the relevant option(s): - [ ] 🐛 Bug fix (non-breaking change that fixes an issue) - [ ] ✨ New feature (non-breaking change that adds functionality) - [ ] 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] 📚 Documentation update - [x] 🔧 Refactoring (no functional changes) - [ ] ⚡ Performance improvement - [ ] 🧪 Test coverage improvement - [ ] 🔨 Build/CI changes - [ ] 🧹 Code cleanup #### Testing - [ ] Unit tests added/updated - [ ] Integration tests added/updated - [ ] Manual testing performed - [x] All existing tests pass #### Checklist - [x] Code follows project style guidelines - [x] Self-review completed - [x] Code is well-commented - [ ] Documentation updated where necessary
1 parent 2cf967b commit 4952f07

5 files changed

Lines changed: 118 additions & 39 deletions

File tree

cmd/check_cmd.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,19 @@ var checkCmd = &cobra.Command{
3636
RunE: func(cmd *cobra.Command, args []string) error {
3737
sp, _ := pterm.DefaultSpinner.WithText("running pgstream checks...").Start()
3838

39-
err := func() error {
39+
err := func() (retErr error) {
4040
streamConfig, err := config.ParseStreamConfig()
4141
if err != nil {
4242
return fmt.Errorf("parsing stream config: %w", err)
4343
}
4444

45-
checks := preflight.BuildChecks(streamConfig, selectedCategories(cmd))
45+
checks, cleanup := preflight.BuildChecks(streamConfig, selectedCategories(cmd))
46+
defer func() {
47+
if cerr := cleanup(context.Background()); cerr != nil && retErr == nil {
48+
retErr = fmt.Errorf("releasing check resources: %w", cerr)
49+
}
50+
}()
51+
4652
if len(checks) == 0 {
4753
sp.Success("no checks to run")
4854
return nil

internal/postgres/pg_lazy_conn.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
package postgres
4+
5+
import "context"
6+
7+
// AcquireFunc lazily yields a Postgres connection. Useful when a set of
8+
// related callers want to share a single TCP connection without each one
9+
// having to think about lifecycle.
10+
type AcquireFunc func(ctx context.Context) (Querier, error)
11+
12+
// LazyConn memoises a single *Conn (or its dial error) for a URL. Cheap to
13+
// construct: nothing is opened until Acquire is called for the first time.
14+
// Not safe for concurrent use — designed for the sequential case (e.g. a
15+
// preflight check engine that runs checks one after another).
16+
type LazyConn struct {
17+
url string
18+
conn *Conn
19+
err error
20+
}
21+
22+
// NewLazyConn returns a LazyConn that will open a connection to url on first
23+
// Acquire.
24+
func NewLazyConn(url string) *LazyConn {
25+
return &LazyConn{url: url}
26+
}
27+
28+
// Acquire returns the cached conn, opening it on the first call. A dial
29+
// failure is cached too — subsequent calls return the same error without
30+
// retrying.
31+
func (l *LazyConn) Acquire(ctx context.Context) (Querier, error) {
32+
if l.conn != nil || l.err != nil {
33+
return l.conn, l.err
34+
}
35+
l.conn, l.err = NewConn(ctx, l.url)
36+
return l.conn, l.err
37+
}
38+
39+
// Close releases the underlying connection if one was opened.
40+
func (l *LazyConn) Close(ctx context.Context) error {
41+
if l.conn == nil {
42+
return nil
43+
}
44+
c := l.conn
45+
l.conn = nil
46+
return c.Close(ctx)
47+
}

pkg/stream/preflight/CLAUDE.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ Guidance for Claude Code when working inside `pkg/stream/preflight`. The planned
66

77
- `preflight.go` — `Check` interface (`Name()` + `Run(ctx) ([]Finding, error)`), `Finding`, `CheckResult`, `Report`, `Run(ctx, []Check, ...RunOption)` engine.
88
- `printer.go` — `ReportPrinter{Report}` is the only thing that formats reports. The `Report` struct itself stays pure data.
9-
- `builder.go` — `Builder` struct, `Builders` registry slice, per-category builder functions (`BuildConnectivityChecks`, …), `BuildChecks(cfg, selected)`.
9+
- `builder.go` — `Builder` struct (its `Build` func returns `[]Check` + an optional cleanup), `Builders` registry slice, per-category builder functions (`BuildConnectivityChecks`, …), `BuildChecks(cfg, selected)`.
1010
- One file per category of concrete checks (`connectivity.go`, `replication.go`, …).
1111

12+
The shared-connection primitive lives one layer down, in `internal/postgres` (`postgres.LazyConn`), so other callers can reuse it.
13+
1214
## Adding a new check
1315

1416
Adding a check is meant to be a small, mechanical edit. Keep it that way.
@@ -21,6 +23,7 @@ Adding a check is meant to be a small, mechanical edit. Keep it that way.
2123
- **Return `error` only when the check itself couldn't run** (timeout, internal bug, malformed input). A detected problem is a `Finding`, not an error.
2224
- **Put remediation in `Finding.Message`** — the user should be able to act on it without reading source.
2325
3. **Materialise instances in the category builder** (e.g. `BuildConnectivityChecks`). The builder is the applicability gate: it reads `*stream.Config` and decides which instances are relevant. Inapplicable checks are silently omitted today; an explicit "skipped: <reason>" mechanism is deferred (see `docs/migration_preflight_issue.md` "Architecture decisions" #6).
26+
- **If checks in the category share a Postgres connection**, call `postgres.NewLazyConn(url)` in the builder, hand `src.Acquire` (a `postgres.AcquireFunc`) to every check, and return `src.Close` as the cleanup. See `BuildReplicationChecks` for the pattern. The engine runs sequentially, so the first check to call `Source(ctx)` opens the conn and the rest reuse it. A failed dial is memoised too — only one connection attempt happens, and every subsequent check gets that same dial error back and reports it as its own check error.
2427
4. **Tests.** Unit-test the check directly against mocked dependencies (`internal/postgres/mocks` has the postgres conn mock). For new categories, exercise the builder selection path through the cmd layer too.
2528

2629
## Do not

pkg/stream/preflight/builder.go

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,25 @@
22

33
package preflight
44

5-
import "github.com/xataio/pgstream/pkg/stream"
5+
import (
6+
"context"
67

7-
// Builder turns a stream.Config into the concrete checks for a category. Each
8-
// new category adds an entry to Builders and a matching CLI flag in
9-
// cmd/root_cmd.go.
8+
"github.com/xataio/pgstream/internal/postgres"
9+
"github.com/xataio/pgstream/pkg/stream"
10+
)
11+
12+
// CleanupFunc releases any resources a builder set up (e.g. a shared Postgres
13+
// connection). Builders return nil when there's nothing to clean up.
14+
type CleanupFunc func(context.Context) error
15+
16+
// Builder turns a stream.Config into the concrete checks for a category, plus
17+
// an optional cleanup function that releases resources the checks share (e.g.
18+
// a Postgres connection). Each new category adds an entry to Builders and a
19+
// matching CLI flag in cmd/root_cmd.go.
1020
type Builder struct {
1121
Category Category
1222
Flag string
13-
Build func(*stream.Config) []Check
23+
Build func(*stream.Config) ([]Check, CleanupFunc)
1424
}
1525

1626
// Builders is the registry of category builders. Adding a new category = one
@@ -22,8 +32,9 @@ var Builders = []Builder{
2232

2333
// BuildConnectivityChecks returns the connectivity checks applicable to cfg.
2434
// A source check is added when a source postgres URL is configured; a target
25-
// check is added when a postgres target is configured.
26-
func BuildConnectivityChecks(cfg *stream.Config) []Check {
35+
// check is added when a postgres target is configured. Each check opens its
36+
// own conn (to its own URL), so no shared cleanup is needed.
37+
func BuildConnectivityChecks(cfg *stream.Config) ([]Check, CleanupFunc) {
2738
checks := []Check{}
2839
if url := cfg.SourcePostgresURL(); url != "" {
2940
checks = append(checks, &ConnectivityCheck{Label: "source", URL: url})
@@ -33,42 +44,58 @@ func BuildConnectivityChecks(cfg *stream.Config) []Check {
3344
checks = append(checks, &ConnectivityCheck{Label: "target", URL: url})
3445
}
3546
}
36-
return checks
47+
return checks, nil
3748
}
3849

3950
// BuildReplicationChecks returns the replication-preflight checks applicable
40-
// to cfg. Replication checks only apply when the source is configured with a
41-
// replication slot (i.e. the run is doing logical replication, not a one-shot
42-
// snapshot).
43-
func BuildReplicationChecks(cfg *stream.Config) []Check {
51+
// to cfg, plus a cleanup function that closes the shared source connection.
52+
// Replication checks only apply when the source is configured with a
53+
// replication slot.
54+
func BuildReplicationChecks(cfg *stream.Config) ([]Check, CleanupFunc) {
4455
if cfg.PostgresReplicationSlot() == "" {
45-
return nil
56+
return nil, nil
4657
}
4758
url := cfg.SourcePostgresURL()
4859
if url == "" {
49-
return nil
60+
return nil, nil
5061
}
62+
src := postgres.NewLazyConn(url)
5163
return []Check{
52-
&WALLevelCheck{URL: url},
53-
&WAL2JSONCheck{URL: url},
54-
&ReplicationSlotHeadroomCheck{URL: url},
55-
&ReplicationRoleAttrCheck{URL: url},
56-
}
64+
&WALLevelCheck{Source: src.Acquire},
65+
&WAL2JSONCheck{Source: src.Acquire},
66+
&ReplicationSlotHeadroomCheck{Source: src.Acquire},
67+
&ReplicationRoleAttrCheck{Source: src.Acquire},
68+
}, src.Close
5769
}
5870

5971
// BuildChecks returns the concrete checks for the selected categories,
60-
// preserving the registration order in Builders. An empty selection runs every
61-
// registered category.
62-
func BuildChecks(cfg *stream.Config, selected []Category) []Check {
72+
// preserving the registration order in Builders, plus a single cleanup
73+
// function that releases every category's resources. The returned cleanup is
74+
// always non-nil; callers can defer it unconditionally. An empty selection
75+
// runs every registered category.
76+
func BuildChecks(cfg *stream.Config, selected []Category) ([]Check, CleanupFunc) {
6377
want := make(map[Category]bool, len(selected))
6478
for _, c := range selected {
6579
want[c] = true
6680
}
6781
checks := []Check{}
82+
cleanups := []CleanupFunc{}
6883
for _, b := range Builders {
6984
if len(want) == 0 || want[b.Category] {
70-
checks = append(checks, b.Build(cfg)...)
85+
cs, cleanup := b.Build(cfg)
86+
checks = append(checks, cs...)
87+
if cleanup != nil {
88+
cleanups = append(cleanups, cleanup)
89+
}
90+
}
91+
}
92+
return checks, func(ctx context.Context) error {
93+
var firstErr error
94+
for _, c := range cleanups {
95+
if err := c(ctx); err != nil && firstErr == nil {
96+
firstErr = err
97+
}
7198
}
99+
return firstErr
72100
}
73-
return checks
74101
}

pkg/stream/preflight/replication.go

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,16 @@ import (
1212
// WALLevelCheck verifies the source Postgres has `wal_level=logical`, which
1313
// pgstream's replication path requires.
1414
type WALLevelCheck struct {
15-
URL string
15+
Source postgres.AcquireFunc
1616
}
1717

1818
func (c *WALLevelCheck) Name() string { return "wal_level" }
1919

2020
func (c *WALLevelCheck) Run(ctx context.Context) ([]Finding, error) {
21-
conn, err := postgres.NewConn(ctx, c.URL)
21+
conn, err := c.Source(ctx)
2222
if err != nil {
2323
return nil, fmt.Errorf("connecting to source: %w", err)
2424
}
25-
defer conn.Close(ctx)
2625

2726
var level string
2827
if err := conn.QueryRow(ctx, []any{&level}, "SHOW wal_level"); err != nil {
@@ -39,17 +38,16 @@ func (c *WALLevelCheck) Run(ctx context.Context) ([]Finding, error) {
3938
// WAL2JSONCheck verifies that the wal2json output plugin is installed and
4039
// loadable on the source. pgstream decodes WAL through wal2json.
4140
type WAL2JSONCheck struct {
42-
URL string
41+
Source postgres.AcquireFunc
4342
}
4443

4544
func (c *WAL2JSONCheck) Name() string { return "wal2json" }
4645

4746
func (c *WAL2JSONCheck) Run(ctx context.Context) ([]Finding, error) {
48-
conn, err := postgres.NewConn(ctx, c.URL)
47+
conn, err := c.Source(ctx)
4948
if err != nil {
5049
return nil, fmt.Errorf("connecting to source: %w", err)
5150
}
52-
defer conn.Close(ctx)
5351

5452
var present int
5553
if err := conn.QueryRow(ctx, []any{&present}, "SELECT count(*)::int FROM pg_available_extensions WHERE name = 'wal2json'"); err != nil {
@@ -66,17 +64,16 @@ func (c *WAL2JSONCheck) Run(ctx context.Context) ([]Finding, error) {
6664
// ReplicationSlotHeadroomCheck reports whether the source has at least one
6765
// slot still available before max_replication_slots is reached.
6866
type ReplicationSlotHeadroomCheck struct {
69-
URL string
67+
Source postgres.AcquireFunc
7068
}
7169

7270
func (c *ReplicationSlotHeadroomCheck) Name() string { return "replication_slot_headroom" }
7371

7472
func (c *ReplicationSlotHeadroomCheck) Run(ctx context.Context) ([]Finding, error) {
75-
conn, err := postgres.NewConn(ctx, c.URL)
73+
conn, err := c.Source(ctx)
7674
if err != nil {
7775
return nil, fmt.Errorf("connecting to source: %w", err)
7876
}
79-
defer conn.Close(ctx)
8077

8178
var maxSlots, usedSlots int
8279
err = conn.QueryRow(ctx, []any{&maxSlots, &usedSlots}, `
@@ -98,17 +95,16 @@ func (c *ReplicationSlotHeadroomCheck) Run(ctx context.Context) ([]Finding, erro
9895
// ReplicationRoleAttrCheck verifies the current source role has the
9996
// REPLICATION attribute, which is required to open a logical replication slot.
10097
type ReplicationRoleAttrCheck struct {
101-
URL string
98+
Source postgres.AcquireFunc
10299
}
103100

104101
func (c *ReplicationRoleAttrCheck) Name() string { return "replication_role_attr" }
105102

106103
func (c *ReplicationRoleAttrCheck) Run(ctx context.Context) ([]Finding, error) {
107-
conn, err := postgres.NewConn(ctx, c.URL)
104+
conn, err := c.Source(ctx)
108105
if err != nil {
109106
return nil, fmt.Errorf("connecting to source: %w", err)
110107
}
111-
defer conn.Close(ctx)
112108

113109
var roleName string
114110
var hasReplication bool

0 commit comments

Comments
 (0)