diff --git a/.mise.toml b/.mise.toml new file mode 100644 index 00000000..966f88a7 --- /dev/null +++ b/.mise.toml @@ -0,0 +1,2 @@ +[tools] +go = "1.26.2" diff --git a/README.md b/README.md index 4ab46e96..7f5acf1d 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ A cross-platform CLI to interact with an OpenFGA server - [Building from Source](#building-from-source) - [Usage](#usage) - [Configuration](#configuration) + - [Custom Headers](#custom-headers) - [Commands](#commands) - [Stores](#stores) - [List All Stores](#list-stores) @@ -151,6 +152,7 @@ For any command that interacts with an OpenFGA server, these configuration value | Token Audience | `--api-audience` | `FGA_API_AUDIENCE` | `api-audience` | | Store ID | `--store-id` | `FGA_STORE_ID` | `store-id` | | Authorization Model ID | `--model-id` | `FGA_MODEL_ID` | `model-id` | +| Custom Headers | `--custom-headers` | `FGA_CUSTOM_HEADERS` | `custom-headers` | If you are authenticating with a shared secret, you should specify the API Token value. If you are authenticating using OAuth, you should specify the Client ID, Client Secret, API Audience and Token Issuer. For example: @@ -164,6 +166,37 @@ api-token-issuer: auth.fga.dev store-id: 01H0H015178Y2V4CX10C2KGHF4 ``` +#### Custom Headers + +You can add custom HTTP headers to all requests sent to the API using the `--custom-headers` flag. Headers are specified in `: ` format, and the flag can be repeated to add multiple headers. + +##### Flag +```shell +--custom-headers "Header-Name: header-value" +``` + +##### Example +```shell +fga store list --custom-headers "X-Custom-Header: value1" --custom-headers "X-Request-ID: abc123" +``` + +##### Configuration + +Custom headers can also be configured via the CLI environment variable or the configuration file: + +| Name | Flag | CLI | ~/.fga.yaml | +|----------------|----------------------|------------------------|---------------------| +| Custom Headers | `--custom-headers` | `FGA_CUSTOM_HEADERS` | `custom-headers` | + +Example `~/.fga.yaml`: +```yaml +api-url: https://api.fga.example +store-id: 01H0H015178Y2V4CX10C2KGHF4 +custom-headers: + - "X-Custom-Header: value1" + - "X-Request-ID: abc123" +``` + ### Commands #### Stores diff --git a/cmd/root.go b/cmd/root.go index 46fa3cd1..d00a65e5 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -65,9 +65,10 @@ func init() { rootCmd.PersistentFlags().String("api-token", "", "API Token. Will be sent in as a Bearer in the Authorization header") rootCmd.PersistentFlags().String("api-token-issuer", "", "API Token Issuer. API responsible for issuing the API Token. Used in the Client Credentials flow") //nolint:lll rootCmd.PersistentFlags().String("api-audience", "", "API Audience. Used when performing the Client Credentials flow") - rootCmd.PersistentFlags().String("client-id", "", "Client ID. Sent to the Token Issuer during the Client Credentials flow") //nolint:lll - rootCmd.PersistentFlags().String("client-secret", "", "Client Secret. Sent to the Token Issuer during the Client Credentials flow") //nolint:lll - rootCmd.PersistentFlags().StringArray("api-scopes", []string{}, "API Scopes (repeat option for multiple values). Used in the Client Credentials flow") //nolint:lll + rootCmd.PersistentFlags().String("client-id", "", "Client ID. Sent to the Token Issuer during the Client Credentials flow") //nolint:lll + rootCmd.PersistentFlags().String("client-secret", "", "Client Secret. Sent to the Token Issuer during the Client Credentials flow") //nolint:lll + rootCmd.PersistentFlags().StringArray("api-scopes", []string{}, "API Scopes (repeat option for multiple values). Used in the Client Credentials flow") //nolint:lll + rootCmd.PersistentFlags().StringArray("custom-headers", []string{}, "Custom HTTP headers in 'Header: value' format (repeat option for multiple values)") //nolint:lll rootCmd.PersistentFlags().Bool("debug", false, "Enable debug mode - can print more detailed information for debugging") _ = rootCmd.Flags().MarkHidden("debug") diff --git a/go.mod b/go.mod index a3572449..a4df06a4 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,7 @@ module github.com/openfga/cli go 1.25.0 -toolchain go1.26.1 +toolchain go1.26.2 require ( github.com/gocarina/gocsv v0.0.0-20240520201108-78e41c74b4b1 diff --git a/internal/build/build.go b/internal/build/build.go index 786f84a2..19bc4a1c 100644 --- a/internal/build/build.go +++ b/internal/build/build.go @@ -16,8 +16,6 @@ limitations under the License. // Package build provides build information that is linked into the application. Other // packages within this project can use this information in logs etc.. - -//nolint:revive // package name conflicts with stdlib is acceptable here package build var ( diff --git a/internal/cmdutils/bind-viper-to-flags.go b/internal/cmdutils/bind-viper-to-flags.go index 8f032851..8e24093e 100644 --- a/internal/cmdutils/bind-viper-to-flags.go +++ b/internal/cmdutils/bind-viper-to-flags.go @@ -18,6 +18,7 @@ package cmdutils import ( "fmt" + "reflect" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -31,8 +32,9 @@ func BindViperToFlags(cmd *cobra.Command, viperInstance *viper.Viper) { if !flag.Changed && viperInstance.IsSet(configName) { value := viperInstance.Get(configName) - err := cmd.Flags().Set(flag.Name, fmt.Sprintf("%v", value)) - cobra.CheckErr(err) + for _, strVal := range viperValueToStrings(value) { + cobra.CheckErr(cmd.Flags().Set(flag.Name, strVal)) + } } }) @@ -40,3 +42,21 @@ func BindViperToFlags(cmd *cobra.Command, viperInstance *viper.Viper) { BindViperToFlags(subcmd, viperInstance) } } + +// viperValueToStrings converts a Viper config value to a slice of strings +// suitable for pflag.Set calls. Slice values (from YAML lists) produce one +// string per element; scalar values produce a single-element slice. +func viperValueToStrings(value any) []string { + reflectValue := reflect.ValueOf(value) + + if reflectValue.Kind() != reflect.Slice && reflectValue.Kind() != reflect.Array { + return []string{fmt.Sprintf("%v", value)} + } + + result := make([]string, 0, reflectValue.Len()) + for i := range reflectValue.Len() { + result = append(result, fmt.Sprintf("%v", reflectValue.Index(i).Interface())) + } + + return result +} diff --git a/internal/cmdutils/bind-viper-to-flags_test.go b/internal/cmdutils/bind-viper-to-flags_test.go new file mode 100644 index 00000000..bf5225b4 --- /dev/null +++ b/internal/cmdutils/bind-viper-to-flags_test.go @@ -0,0 +1,86 @@ +/* +Copyright © 2023 OpenFGA + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cmdutils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestViperValueToStrings(t *testing.T) { + t.Parallel() + + testcases := []struct { + name string + value any + expected []string + }{ + { + name: "slice value produces one string per element", + value: []any{ + "X-Custom-Header: value1", + "X-Request-ID: abc123", + }, + expected: []string{"X-Custom-Header: value1", "X-Request-ID: abc123"}, + }, + { + name: "single element slice", + value: []any{"X-Custom-Header: value1"}, + expected: []string{"X-Custom-Header: value1"}, + }, + { + name: "empty slice", + value: []any{}, + expected: []string{}, + }, + { + name: "typed string slice produces one string per element", + value: []string{"X-Custom-Header: value1", "X-Request-ID: abc123"}, + expected: []string{"X-Custom-Header: value1", "X-Request-ID: abc123"}, + }, + { + name: "typed int slice produces one string per element", + value: []int{1, 2, 3}, + expected: []string{"1", "2", "3"}, + }, + { + name: "scalar string produces single-element slice", + value: "https://api.fga.example", + expected: []string{"https://api.fga.example"}, + }, + { + name: "boolean value is stringified", + value: true, + expected: []string{"true"}, + }, + { + name: "integer value is stringified", + value: 42, + expected: []string{"42"}, + }, + } + + for _, test := range testcases { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + result := viperValueToStrings(test.value) + assert.Equal(t, test.expected, result) + }) + } +} diff --git a/internal/cmdutils/get-client-config.go b/internal/cmdutils/get-client-config.go index f6af6032..94bb1498 100644 --- a/internal/cmdutils/get-client-config.go +++ b/internal/cmdutils/get-client-config.go @@ -44,6 +44,7 @@ func GetClientConfig(cmd *cobra.Command) fga.ClientConfig { clientCredentialsClientID, _ := cmd.Flags().GetString("client-id") clientCredentialsClientSecret, _ := cmd.Flags().GetString("client-secret") clientCredentialsScopes, _ := cmd.Flags().GetStringArray("api-scopes") + customHeaders, _ := cmd.Flags().GetStringArray("custom-headers") debug, _ := cmd.Flags().GetBool("debug") return fga.ClientConfig{ @@ -56,6 +57,7 @@ func GetClientConfig(cmd *cobra.Command) fga.ClientConfig { ClientID: clientCredentialsClientID, ClientSecret: clientCredentialsClientSecret, APIScopes: clientCredentialsScopes, + CustomHeaders: customHeaders, Debug: debug, } } diff --git a/internal/fga/fga.go b/internal/fga/fga.go index efa7a130..07ba66bd 100644 --- a/internal/fga/fga.go +++ b/internal/fga/fga.go @@ -18,6 +18,8 @@ limitations under the License. package fga import ( + "errors" + "fmt" "strings" openfga "github.com/openfga/go-sdk" @@ -32,23 +34,33 @@ const ( MinSdkWaitInMs = 500 ) -var userAgent = "openfga-cli/" + build.Version +var ( + userAgent = "openfga-cli/" + build.Version + + ErrInvalidHeaderFormat = errors.New("expected format \"Header-Name: value\"") +) type ClientConfig struct { ApiUrl string `json:"api_url,omitempty"` //nolint:revive,stylecheck StoreID string `json:"store_id,omitempty"` AuthorizationModelID string `json:"authorization_model_id,omitempty"` - APIToken string `json:"api_token,omitempty"` //nolint:gosec + APIToken string `json:"api_token,omitempty"` APITokenIssuer string `json:"api_token_issuer,omitempty"` APIAudience string `json:"api_audience,omitempty"` APIScopes []string `json:"api_scopes,omitempty"` ClientID string `json:"client_id,omitempty"` - ClientSecret string `json:"client_secret,omitempty"` //nolint:gosec + ClientSecret string `json:"client_secret,omitempty"` + CustomHeaders []string `json:"custom_headers,omitempty"` Debug bool `json:"debug,omitempty"` } func (c ClientConfig) GetFgaClient() (*client.OpenFgaClient, error) { - fgaClient, err := client.NewSdkClient(c.getClientConfig()) + clientConfig, err := c.getClientConfig() + if err != nil { + return nil, err + } + + fgaClient, err := client.NewSdkClient(clientConfig) if err != nil { return nil, err //nolint:wrapcheck } @@ -84,7 +96,12 @@ func (c ClientConfig) getCredentials() *credentials.Credentials { } } -func (c ClientConfig) getClientConfig() *client.ClientConfiguration { +func (c ClientConfig) getClientConfig() (*client.ClientConfiguration, error) { + customHeaders, err := c.getCustomHeaders() + if err != nil { + return nil, fmt.Errorf("invalid custom headers configuration: %w", err) + } + return &client.ClientConfiguration{ ApiUrl: c.ApiUrl, StoreId: c.StoreID, @@ -95,6 +112,24 @@ func (c ClientConfig) getClientConfig() *client.ClientConfiguration { MaxRetry: MaxSdkRetry, MinWaitInMs: MinSdkWaitInMs, }, - Debug: c.Debug, + Debug: c.Debug, + DefaultHeaders: customHeaders, + }, nil +} + +func (c ClientConfig) getCustomHeaders() (map[string]string, error) { + headers := make(map[string]string, len(c.CustomHeaders)) + + for _, header := range c.CustomHeaders { + name, value, _ := strings.Cut(header, ":") + + name, value = strings.TrimSpace(name), strings.TrimSpace(value) + if name == "" { + return nil, fmt.Errorf("invalid custom header %q: %w", header, ErrInvalidHeaderFormat) + } + + headers[name] = value } + + return headers, nil } diff --git a/internal/fga/fga_test.go b/internal/fga/fga_test.go new file mode 100644 index 00000000..ee90e26a --- /dev/null +++ b/internal/fga/fga_test.go @@ -0,0 +1,186 @@ +/* +Copyright © 2023 OpenFGA + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package fga + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/openfga/go-sdk/client" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetCustomHeaders(t *testing.T) { + t.Parallel() + + testcases := []struct { + name string + headers []string + expected map[string]string + err error + }{ + { + name: "no headers", + headers: []string{}, + expected: map[string]string{}, + }, + { + name: "single valid header", + headers: []string{"X-Custom: value1"}, + expected: map[string]string{ + "X-Custom": "value1", + }, + }, + { + name: "multiple valid headers", + headers: []string{"X-Custom: value1", "X-Request-ID: abc123"}, + expected: map[string]string{ + "X-Custom": "value1", + "X-Request-ID": "abc123", + }, + }, + { + name: "colon in value is preserved", + headers: []string{"X-Custom: host:port"}, + expected: map[string]string{ + "X-Custom": "host:port", + }, + }, + { + name: "whitespace is trimmed", + headers: []string{" X-Custom : value1 "}, + expected: map[string]string{ + "X-Custom": "value1", + }, + }, + { + name: "empty value is valid", + headers: []string{"X-Custom: "}, + expected: map[string]string{ + "X-Custom": "", + }, + }, + { + name: "no colon allowed", + headers: []string{"X-Custom"}, + expected: map[string]string{ + "X-Custom": "", + }, + }, + { + name: "empty string returns error", + headers: []string{""}, + err: ErrInvalidHeaderFormat, + }, + { + name: "empty header name returns error", + headers: []string{": value"}, + err: ErrInvalidHeaderFormat, + }, + { + name: "valid header before invalid stops at first error", + headers: []string{"X-Good: ok", ""}, + err: ErrInvalidHeaderFormat, + }, + } + + for _, test := range testcases { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + cfg := ClientConfig{CustomHeaders: test.headers} + result, err := cfg.getCustomHeaders() + + if test.err != nil { + require.Error(t, err) + assert.ErrorIs(t, err, test.err) + } else { + require.NoError(t, err) + assert.Equal(t, test.expected, result) + } + }) + } +} + +func TestCustomHeadersSentInRequest(t *testing.T) { + t.Parallel() + + testcases := []struct { + name string + customHeaders []string + expectedHeaders map[string]string + }{ + { + name: "single header is sent", + customHeaders: []string{"X-Custom-Header: value1"}, + expectedHeaders: map[string]string{"X-Custom-Header": "value1"}, + }, + { + name: "multiple headers are sent", + customHeaders: []string{"X-Custom-Header: value1", "X-Request-ID: abc123"}, + expectedHeaders: map[string]string{ + "X-Custom-Header": "value1", + "X-Request-ID": "abc123", + }, + }, + { + name: "no custom headers", + customHeaders: []string{}, + expectedHeaders: map[string]string{}, + }, + } + + for _, test := range testcases { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + headersCh := make(chan http.Header, 1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headersCh <- r.Header.Clone() + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"stores": []}`)) + })) + defer server.Close() + + cfg := ClientConfig{ + ApiUrl: server.URL, + StoreID: "01H0H015178Y2V4CX10C2KGHF4", + CustomHeaders: test.customHeaders, + } + + fgaClient, err := cfg.GetFgaClient() + require.NoError(t, err) + + _, err = fgaClient.ListStores(context.Background()). + Options(client.ClientListStoresOptions{}). + Execute() + require.NoError(t, err) + + capturedHeaders := <-headersCh + for name, value := range test.expectedHeaders { + assert.Equal(t, value, capturedHeaders.Get(name), + "expected header %s to have value %q", name, value) + } + }) + } +} diff --git a/internal/requests/rampup.go b/internal/requests/rampup.go index b72210c6..9f453411 100644 --- a/internal/requests/rampup.go +++ b/internal/requests/rampup.go @@ -30,7 +30,7 @@ func RampUpAPIRequests( //nolint:cyclop semaphore = make(chan struct{}, maxInFlight) waitGroup sync.WaitGroup ticker = time.NewTicker(rampupPeriodDuration) - requestIndex int32 + requestIndex atomic.Int32 ) // if the ramp up period is 0, go to max rps directly @@ -65,7 +65,7 @@ func RampUpAPIRequests( //nolint:cyclop } for i := 0; i < int(limiter.Limit()); i++ { //nolint:intrange - idx := atomic.AddInt32(&requestIndex, 1) - 1 + idx := requestIndex.Add(1) - 1 if idx >= requestsLen { waitGroup.Wait() @@ -102,7 +102,7 @@ func RampUpAPIRequests( //nolint:cyclop } for i := 0; i < int(limiter.Limit()); i++ { //nolint:intrange - idx := atomic.AddInt32(&requestIndex, 1) - 1 + idx := requestIndex.Add(1) - 1 if idx >= requestsLen { waitGroup.Wait() diff --git a/internal/requests/rampup_test.go b/internal/requests/rampup_test.go index ce74ad2e..c065c050 100644 --- a/internal/requests/rampup_test.go +++ b/internal/requests/rampup_test.go @@ -47,7 +47,7 @@ func TestRampUpAPIRequests_RampUpRate(t *testing.T) { defer cancel() var ( - callCount int32 + callCount atomic.Int32 mutex sync.Mutex ) @@ -55,7 +55,7 @@ func TestRampUpAPIRequests_RampUpRate(t *testing.T) { for i := range requestsList { requestsList[i] = func() error { mutex.Lock() - atomic.AddInt32(&callCount, 1) + callCount.Add(1) mutex.Unlock() return nil diff --git a/internal/utils/context.go b/internal/utils/context.go index 98277fb1..766dcd5b 100644 --- a/internal/utils/context.go +++ b/internal/utils/context.go @@ -1,4 +1,4 @@ -package utils //nolint:revive +package utils import "context"