diff --git a/api/dbv1/parallel.go b/api/dbv1/parallel.go index efa8cd8a..69efd789 100644 --- a/api/dbv1/parallel.go +++ b/api/dbv1/parallel.go @@ -43,9 +43,9 @@ func (q *Queries) Parallel(ctx context.Context, arg ParallelParams) (*ParallelRe var err error trackMap, err = q.TracksKeyed(ctx, TracksParams{ GetTracksParams: GetTracksParams{ - Ids: arg.TrackIds, - MyID: arg.MyID, - AuthedWallet: arg.AuthedWallet, + Ids: arg.TrackIds, + MyID: arg.MyID, + AuthedWallet: arg.AuthedWallet, }, }) return err diff --git a/api/dbv1/tracks.go b/api/dbv1/tracks.go index e7c8bda3..fecb0542 100644 --- a/api/dbv1/tracks.go +++ b/api/dbv1/tracks.go @@ -10,7 +10,6 @@ import ( type TracksParams struct { GetTracksParams - SolanaWallet string } // Track is the standard track type containing all track data @@ -35,7 +34,7 @@ type Track struct { } func (q *Queries) TracksKeyed(ctx context.Context, arg TracksParams) (map[int32]Track, error) { - rawTracks, err := q.GetTracks(ctx, GetTracksParams(arg.GetTracksParams)) + rawTracks, err := q.GetTracks(ctx, arg.GetTracksParams) if err != nil { return nil, err } @@ -84,8 +83,9 @@ func (q *Queries) TracksKeyed(ctx context.Context, arg TracksParams) (map[int32] userPtrMap[id] = &userCopy } - // Get bulk access for all tracks - accessMap, err := q.GetBulkTrackAccess(ctx, arg.MyID.(int32), trackPtrs, userPtrMap, arg.SolanaWallet) + // Read solana wallet from context (set by middleware) for token gate checks + solanaWallet, _ := ctx.Value("solanaWallet").(string) + accessMap, err := q.GetBulkTrackAccess(ctx, arg.MyID.(int32), trackPtrs, userPtrMap, solanaWallet) if err != nil { return nil, err } diff --git a/api/solana_wallet_middleware.go b/api/solana_wallet_middleware.go index 15bd4c57..c178166f 100644 --- a/api/solana_wallet_middleware.go +++ b/api/solana_wallet_middleware.go @@ -1,6 +1,7 @@ package api import ( + "context" "crypto/ed25519" "github.com/gofiber/fiber/v2" @@ -8,9 +9,12 @@ import ( "go.uber.org/zap" ) +// SolanaWalletCtxKey is the context key used to pass a verified Solana wallet +// from the HTTP middleware to the database layer. +const SolanaWalletCtxKey = "solanaWallet" + // solanaWalletMiddleware verifies Solana wallet signatures from request headers. -// If the X-Solana-Wallet, X-Solana-Message, and X-Solana-Signature headers are -// present and valid, the verified wallet public key is stored in c.Locals("solanaWallet"). +// If valid, the verified wallet public key is set on the Go context via SolanaWalletCtxKey. // If headers are absent, the middleware is a no-op. If present but invalid, returns 401. func (app *ApiServer) solanaWalletMiddleware(c *fiber.Ctx) error { wallet := c.Get("X-Solana-Wallet") @@ -48,14 +52,6 @@ func (app *ApiServer) solanaWalletMiddleware(c *fiber.Ctx) error { } app.logger.Debug("solanaWalletMiddleware: verified", zap.String("wallet", wallet)) - c.Locals("solanaWallet", wallet) + c.SetUserContext(context.WithValue(c.UserContext(), SolanaWalletCtxKey, wallet)) return c.Next() } - -// tryGetSolanaWallet returns the verified Solana wallet from context, or "" if not set. -func (app *ApiServer) tryGetSolanaWallet(c *fiber.Ctx) string { - if w, ok := c.Locals("solanaWallet").(string); ok { - return w - } - return "" -} diff --git a/api/solana_wallet_middleware_test.go b/api/solana_wallet_middleware_test.go index 075b84d5..81fc65fd 100644 --- a/api/solana_wallet_middleware_test.go +++ b/api/solana_wallet_middleware_test.go @@ -18,7 +18,9 @@ func TestSolanaWalletMiddleware(t *testing.T) { var capturedWallet string testApp := fiber.New() testApp.Get("/", app.solanaWalletMiddleware, func(c *fiber.Ctx) error { - capturedWallet = app.tryGetSolanaWallet(c) + if w, ok := c.UserContext().Value(SolanaWalletCtxKey).(string); ok { + capturedWallet = w + } return c.SendStatus(fiber.StatusOK) }) diff --git a/api/v1_track_stream.go b/api/v1_track_stream.go index 64a28786..3f17969b 100644 --- a/api/v1_track_stream.go +++ b/api/v1_track_stream.go @@ -18,12 +18,6 @@ func (app *ApiServer) v1TrackStream(c *fiber.Ctx) error { }, } - // If a verified Solana wallet is present, pass it through so - // GetBulkTrackAccess can check token gate balances for it. - if solWallet := app.tryGetSolanaWallet(c); solWallet != "" { - params.SolanaWallet = solWallet - } - tracks, err := app.queries.Tracks(c.Context(), params) if err != nil { return err