Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions api/dbv1/parallel.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ func (q *Queries) Parallel(ctx context.Context, arg ParallelParams) (*ParallelRe
var err error
trackMap, err = q.TracksKeyed(ctx, TracksParams{
GetTracksParams: GetTracksParams{
Ids: arg.TrackIds,
MyID: arg.MyID,
AuthedWallet: arg.AuthedWallet,
Ids: arg.TrackIds,
MyID: arg.MyID,
AuthedWallet: arg.AuthedWallet,
},
})
return err
Expand Down
8 changes: 4 additions & 4 deletions api/dbv1/tracks.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (

type TracksParams struct {
GetTracksParams
SolanaWallet string
}

// Track is the standard track type containing all track data
Expand All @@ -35,7 +34,7 @@ type Track struct {
}

func (q *Queries) TracksKeyed(ctx context.Context, arg TracksParams) (map[int32]Track, error) {
rawTracks, err := q.GetTracks(ctx, GetTracksParams(arg.GetTracksParams))
rawTracks, err := q.GetTracks(ctx, arg.GetTracksParams)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -84,8 +83,9 @@ func (q *Queries) TracksKeyed(ctx context.Context, arg TracksParams) (map[int32]
userPtrMap[id] = &userCopy
}

// Get bulk access for all tracks
accessMap, err := q.GetBulkTrackAccess(ctx, arg.MyID.(int32), trackPtrs, userPtrMap, arg.SolanaWallet)
// Read solana wallet from context (set by middleware) for token gate checks
solanaWallet, _ := ctx.Value("solanaWallet").(string)
accessMap, err := q.GetBulkTrackAccess(ctx, arg.MyID.(int32), trackPtrs, userPtrMap, solanaWallet)
if err != nil {
return nil, err
}
Expand Down
18 changes: 7 additions & 11 deletions api/solana_wallet_middleware.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
package api

import (
"context"
"crypto/ed25519"

"github.com/gofiber/fiber/v2"
"github.com/mr-tron/base58"
"go.uber.org/zap"
)

// SolanaWalletCtxKey is the context key used to pass a verified Solana wallet
// from the HTTP middleware to the database layer.
const SolanaWalletCtxKey = "solanaWallet"

// solanaWalletMiddleware verifies Solana wallet signatures from request headers.
// If the X-Solana-Wallet, X-Solana-Message, and X-Solana-Signature headers are
// present and valid, the verified wallet public key is stored in c.Locals("solanaWallet").
// If valid, the verified wallet public key is set on the Go context via SolanaWalletCtxKey.
// If headers are absent, the middleware is a no-op. If present but invalid, returns 401.
func (app *ApiServer) solanaWalletMiddleware(c *fiber.Ctx) error {
wallet := c.Get("X-Solana-Wallet")
Expand Down Expand Up @@ -48,14 +52,6 @@ func (app *ApiServer) solanaWalletMiddleware(c *fiber.Ctx) error {
}

app.logger.Debug("solanaWalletMiddleware: verified", zap.String("wallet", wallet))
c.Locals("solanaWallet", wallet)
c.SetUserContext(context.WithValue(c.UserContext(), SolanaWalletCtxKey, wallet))
return c.Next()
}

// tryGetSolanaWallet returns the verified Solana wallet from context, or "" if not set.
func (app *ApiServer) tryGetSolanaWallet(c *fiber.Ctx) string {
if w, ok := c.Locals("solanaWallet").(string); ok {
return w
}
return ""
}
4 changes: 3 additions & 1 deletion api/solana_wallet_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ func TestSolanaWalletMiddleware(t *testing.T) {
var capturedWallet string
testApp := fiber.New()
testApp.Get("/", app.solanaWalletMiddleware, func(c *fiber.Ctx) error {
capturedWallet = app.tryGetSolanaWallet(c)
if w, ok := c.UserContext().Value(SolanaWalletCtxKey).(string); ok {
capturedWallet = w
}
return c.SendStatus(fiber.StatusOK)
})

Expand Down
6 changes: 0 additions & 6 deletions api/v1_track_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@ func (app *ApiServer) v1TrackStream(c *fiber.Ctx) error {
},
}

// If a verified Solana wallet is present, pass it through so
// GetBulkTrackAccess can check token gate balances for it.
if solWallet := app.tryGetSolanaWallet(c); solWallet != "" {
params.SolanaWallet = solWallet
}

tracks, err := app.queries.Tracks(c.Context(), params)
if err != nil {
return err
Expand Down
Loading