diff --git a/infisical/client.go b/infisical/client.go new file mode 100644 index 00000000..08705461 --- /dev/null +++ b/infisical/client.go @@ -0,0 +1,211 @@ +package infisical + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +var ErrInsecureSiteURL = errors.New("infisical-kms: INFISICAL_SITE_URL must use https://") + +const tokenExpiryBuffer = 5 * time.Second + +type kmsEncryptDecrypter interface { + encrypt(plaintext string) (string, error) + decrypt(ciphertext string) (string, error) +} + +type kmsClient struct { + httpClient *http.Client + baseURL string + kmsKeyID string + clientID string + clientSecret string + + mu sync.RWMutex + token string + expiresAt time.Time +} + +func newKmsClient(siteURL, kmsKeyID, clientID, clientSecret string) (*kmsClient, error) { + base := strings.TrimRight(siteURL, "/") + u, err := url.Parse(base) + if err != nil || u.Host == "" { + return nil, fmt.Errorf("infisical-kms: invalid INFISICAL_SITE_URL %q: %w", siteURL, err) + } + if !strings.EqualFold(u.Scheme, "https") { + return nil, ErrInsecureSiteURL + } + if !strings.HasSuffix(base, "/api") { + base += "/api" + } + return &kmsClient{ + httpClient: &http.Client{Timeout: 30 * time.Second}, + baseURL: base, + kmsKeyID: kmsKeyID, + clientID: clientID, + clientSecret: clientSecret, + }, nil +} + +type loginRequest struct { + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` +} + +type loginResponse struct { + AccessToken string `json:"accessToken"` + ExpiresIn int64 `json:"expiresIn"` +} + +func (c *kmsClient) login() error { + body, err := json.Marshal(loginRequest{ + ClientID: c.clientID, + ClientSecret: c.clientSecret, + }) + if err != nil { + return fmt.Errorf("infisical-kms: failed to marshal login request: %w", err) + } + + resp, err := c.httpClient.Post( + c.baseURL+"/v1/auth/universal-auth/login", + "application/json", + bytes.NewReader(body), + ) + if err != nil { + return fmt.Errorf("infisical-kms: login request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + msg, _ := ioutil.ReadAll(resp.Body) + return fmt.Errorf("infisical-kms: login returned %d: %s", resp.StatusCode, msg) + } + + var result loginResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return fmt.Errorf("infisical-kms: failed to decode login response: %w", err) + } + + c.mu.Lock() + c.token = result.AccessToken + c.expiresAt = time.Now().Add(time.Duration(result.ExpiresIn)*time.Second - tokenExpiryBuffer) + c.mu.Unlock() + + return nil +} + +func (c *kmsClient) ensureToken() error { + c.mu.RLock() + valid := c.token != "" && time.Now().Before(c.expiresAt) + c.mu.RUnlock() + if valid { + return nil + } + return c.login() +} + +func (c *kmsClient) doKmsRequest(path string, reqBody, respBody interface{}) error { + if err := c.ensureToken(); err != nil { + return err + } + + body, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("infisical-kms: failed to marshal request: %w", err) + } + + retried := false + for { + status, respBytes, err := c.sendKmsRequest(path, body) + if err != nil { + return err + } + if status == http.StatusOK { + if err := json.Unmarshal(respBytes, respBody); err != nil { + return fmt.Errorf("infisical-kms: failed to decode response: %w", err) + } + return nil + } + + if !retried && (status == http.StatusUnauthorized || status == http.StatusForbidden) { + retried = true + if err := c.login(); err != nil { + return fmt.Errorf("infisical-kms: re-authentication failed: %w", err) + } + continue + } + return fmt.Errorf("infisical-kms: request returned %d: %s", status, respBytes) + } +} + +func (c *kmsClient) sendKmsRequest(path string, body []byte) (int, []byte, error) { + req, err := http.NewRequest(http.MethodPost, c.baseURL+path, bytes.NewReader(body)) + if err != nil { + return 0, nil, fmt.Errorf("infisical-kms: failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + c.mu.RLock() + req.Header.Set("Authorization", "Bearer "+c.token) + c.mu.RUnlock() + + resp, err := c.httpClient.Do(req) + if err != nil { + return 0, nil, fmt.Errorf("infisical-kms: request failed: %w", err) + } + defer resp.Body.Close() + + respBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + return resp.StatusCode, nil, fmt.Errorf("infisical-kms: failed to read response: %w", err) + } + return resp.StatusCode, respBytes, nil +} + +type encryptRequest struct { + Plaintext string `json:"plaintext"` +} + +type encryptResponse struct { + Ciphertext string `json:"ciphertext"` +} + +func (c *kmsClient) encrypt(plaintext string) (string, error) { + var resp encryptResponse + path := fmt.Sprintf("/v1/kms/keys/%s/encrypt", c.kmsKeyID) + encoded := base64.StdEncoding.EncodeToString([]byte(plaintext)) + if err := c.doKmsRequest(path, encryptRequest{Plaintext: encoded}, &resp); err != nil { + return "", err + } + return resp.Ciphertext, nil +} + +type decryptRequest struct { + Ciphertext string `json:"ciphertext"` +} + +type decryptResponse struct { + Plaintext string `json:"plaintext"` +} + +func (c *kmsClient) decrypt(ciphertext string) (string, error) { + var resp decryptResponse + path := fmt.Sprintf("/v1/kms/keys/%s/decrypt", c.kmsKeyID) + if err := c.doKmsRequest(path, decryptRequest{Ciphertext: ciphertext}, &resp); err != nil { + return "", err + } + decoded, err := base64.StdEncoding.DecodeString(resp.Plaintext) + if err != nil { + return "", fmt.Errorf("infisical-kms: failed to base64-decode plaintext: %w", err) + } + return string(decoded), nil +} diff --git a/infisical/client_test.go b/infisical/client_test.go new file mode 100644 index 00000000..2c24858d --- /dev/null +++ b/infisical/client_test.go @@ -0,0 +1,180 @@ +package infisical + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func newTestClient(t *testing.T, base string) *kmsClient { + t.Helper() + // newKmsClient enforces https, so build the struct directly for httptest URLs. + return &kmsClient{ + httpClient: http.DefaultClient, + baseURL: strings.TrimRight(base, "/") + "/api", + kmsKeyID: "k-1", + clientID: "id", + clientSecret: "secret", + } +} + +func writeLogin(w http.ResponseWriter, token string) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(loginResponse{AccessToken: token, ExpiresIn: 3600}) +} + +func writeEncrypt(w http.ResponseWriter, ciphertext string) { + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(encryptResponse{Ciphertext: ciphertext}) +} + +func retryServer(t *testing.T, firstStatus int, firstBody string) (*httptest.Server, *int32, *int32) { + t.Helper() + var loginCalls, encryptCalls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.HasSuffix(r.URL.Path, "/v1/auth/universal-auth/login"): + n := atomic.AddInt32(&loginCalls, 1) + writeLogin(w, "token-"+string(rune('0'+n))) + case strings.Contains(r.URL.Path, "/v1/kms/keys/"): + n := atomic.AddInt32(&encryptCalls, 1) + if n == 1 { + w.WriteHeader(firstStatus) + _, _ = w.Write([]byte(firstBody)) + return + } + assert.Equal(t, "Bearer token-2", r.Header.Get("Authorization")) + writeEncrypt(w, "ct-ok") + default: + t.Fatalf("unexpected path: %s", r.URL.Path) + } + })) + return srv, &loginCalls, &encryptCalls +} + +func TestDoKmsRequest_Retries401WithReLogin(t *testing.T) { + srv, loginCalls, encryptCalls := retryServer(t, http.StatusUnauthorized, + `{"statusCode":401,"error":"UnauthorizedError","message":"token revoked"}`) + defer srv.Close() + + c := newTestClient(t, srv.URL) + require.NoError(t, c.login()) + + ct, err := c.encrypt("hello") + require.NoError(t, err) + assert.Equal(t, "ct-ok", ct) + assert.EqualValues(t, 2, atomic.LoadInt32(loginCalls), "should re-login after 401") + assert.EqualValues(t, 2, atomic.LoadInt32(encryptCalls), "should retry the encrypt call once") +} + +func TestDoKmsRequest_Retries403TokenErrorWithReLogin(t *testing.T) { + srv, loginCalls, encryptCalls := retryServer(t, http.StatusForbidden, + `{"statusCode":403,"error":"TokenError","message":"Your token has expired. Please re-authenticate."}`) + defer srv.Close() + + c := newTestClient(t, srv.URL) + require.NoError(t, c.login()) + + ct, err := c.encrypt("hello") + require.NoError(t, err) + assert.Equal(t, "ct-ok", ct) + assert.EqualValues(t, 2, atomic.LoadInt32(loginCalls), "should re-login after 403 TokenError") + assert.EqualValues(t, 2, atomic.LoadInt32(encryptCalls), "should retry once on 403 TokenError") +} + +func TestDoKmsRequest_DoesNotInfiniteLoopOnRepeated403(t *testing.T) { + var loginCalls, encryptCalls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.HasSuffix(r.URL.Path, "/v1/auth/universal-auth/login"): + atomic.AddInt32(&loginCalls, 1) + writeLogin(w, "tkn") + case strings.Contains(r.URL.Path, "/v1/kms/keys/"): + atomic.AddInt32(&encryptCalls, 1) + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"statusCode":403,"error":"PermissionDenied","message":"missing KMS permission"}`)) + default: + t.Fatalf("unexpected path: %s", r.URL.Path) + } + })) + defer srv.Close() + + c := newTestClient(t, srv.URL) + require.NoError(t, c.login()) + + _, err := c.encrypt("hello") + require.Error(t, err) + assert.Contains(t, err.Error(), "403") + assert.EqualValues(t, 2, atomic.LoadInt32(&encryptCalls), "should retry exactly once on persistent 403") + assert.EqualValues(t, 2, atomic.LoadInt32(&loginCalls), "should re-login once before giving up") +} + +func TestDoKmsRequest_DoesNotInfiniteLoopOnRepeated401(t *testing.T) { + var encryptCalls int32 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.HasSuffix(r.URL.Path, "/v1/auth/universal-auth/login"): + writeLogin(w, "any-token") + case strings.Contains(r.URL.Path, "/v1/kms/keys/"): + atomic.AddInt32(&encryptCalls, 1) + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`unauthorized`)) + default: + t.Fatalf("unexpected path: %s", r.URL.Path) + } + })) + defer srv.Close() + + c := newTestClient(t, srv.URL) + require.NoError(t, c.login()) + + _, err := c.encrypt("hello") + require.Error(t, err) + assert.Contains(t, err.Error(), "401") + assert.EqualValues(t, 2, atomic.LoadInt32(&encryptCalls), "should retry exactly once on persistent 401") +} + +func TestDoKmsRequest_HappyPath_NoRetry(t *testing.T) { + var loginCalls, encryptCalls int32 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.HasSuffix(r.URL.Path, "/v1/auth/universal-auth/login"): + atomic.AddInt32(&loginCalls, 1) + writeLogin(w, "good-token") + case strings.Contains(r.URL.Path, "/v1/kms/keys/"): + atomic.AddInt32(&encryptCalls, 1) + writeEncrypt(w, "ct") + default: + t.Fatalf("unexpected path: %s", r.URL.Path) + } + })) + defer srv.Close() + + c := newTestClient(t, srv.URL) + require.NoError(t, c.login()) + + ct, err := c.encrypt("hello") + require.NoError(t, err) + assert.Equal(t, "ct", ct) + assert.EqualValues(t, 1, atomic.LoadInt32(&loginCalls), "no extra login when first call succeeds") + assert.EqualValues(t, 1, atomic.LoadInt32(&encryptCalls)) +} + +func TestNewKmsClient_RejectsHTTP(t *testing.T) { + _, err := newKmsClient("http://infisical.internal", "k", "id", "sec") + assert.ErrorIs(t, err, ErrInsecureSiteURL) +} + +func TestNewKmsClient_AcceptsHTTPS(t *testing.T) { + c, err := newKmsClient("https://app.infisical.com", "k", "id", "sec") + require.NoError(t, err) + assert.Equal(t, "https://app.infisical.com/api", c.baseURL) +} diff --git a/infisical/infisical_kms.go b/infisical/infisical_kms.go new file mode 100644 index 00000000..efa3f02e --- /dev/null +++ b/infisical/infisical_kms.go @@ -0,0 +1,251 @@ +package infisical + +import ( + "encoding/json" + "errors" + "fmt" + "os" + + "github.com/libopenstorage/secrets" + "github.com/libopenstorage/secrets/pkg/store" + "github.com/portworx/kvdb" + "github.com/sirupsen/logrus" +) + +const ( + // Name of the secret store + Name = secrets.TypeInfisical + // SiteURLKey is the base URL of the Infisical instance + SiteURLKey = "INFISICAL_SITE_URL" + // ClientIDKey is the Infisical Universal Auth machine identity client ID + ClientIDKey = "INFISICAL_UNIVERSAL_AUTH_CLIENT_ID" + // ClientSecretKey is the Infisical Universal Auth machine identity client secret + ClientSecretKey = "INFISICAL_UNIVERSAL_AUTH_CLIENT_SECRET" + // KMSKeyIDKey is the ID of the Infisical KMS key used for encrypt/decrypt + KMSKeyIDKey = "INFISICAL_KMS_KEY_ID" + // KvdbKey is used to setup Infisical KMS with kvdb for persistence + KvdbKey = "KMS_KVDB" + defaultSiteURL = "https://app.infisical.com" + kvdbPublicBasePath = "infisical/secrets/public/" + kvdbDataBasePath = "infisical/secrets/data/" +) + +var ( + // ErrKvdbNotProvided is returned when a valid kvdb instance is not provided + ErrKvdbNotProvided = errors.New("a valid kvdb.Kvdb instance must be provided via the KMS_KVDB config key") + // ErrClientIDRequired is returned when INFISICAL_UNIVERSAL_AUTH_CLIENT_ID is not set + ErrClientIDRequired = errors.New("INFISICAL_UNIVERSAL_AUTH_CLIENT_ID is required (config key or env var)") + // ErrClientSecretRequired is returned when INFISICAL_UNIVERSAL_AUTH_CLIENT_SECRET is not set + ErrClientSecretRequired = errors.New("INFISICAL_UNIVERSAL_AUTH_CLIENT_SECRET is required (config key or env var)") + // ErrKMSKeyIDRequired is returned when INFISICAL_KMS_KEY_ID is not set + ErrKMSKeyIDRequired = errors.New("INFISICAL_KMS_KEY_ID is required (config key or env var)") +) + +type infisicalKms struct { + client kmsEncryptDecrypter + ps store.PersistenceStore +} + +func New( + secretConfig map[string]interface{}, +) (secrets.Secrets, error) { + v, ok := secretConfig[KvdbKey] + if !ok { + return nil, ErrKvdbNotProvided + } + kv, ok := v.(kvdb.Kvdb) + if !ok { + return nil, ErrKvdbNotProvided + } + ps := store.NewKvdbPersistenceStore(kv, kvdbPublicBasePath, kvdbDataBasePath) + + siteURL := configString(secretConfig, SiteURLKey, defaultSiteURL) + clientID := configString(secretConfig, ClientIDKey, "") + clientSecret := configString(secretConfig, ClientSecretKey, "") + kmsKeyID := configString(secretConfig, KMSKeyIDKey, "") + + if clientID == "" { + return nil, ErrClientIDRequired + } + if clientSecret == "" { + return nil, ErrClientSecretRequired + } + if kmsKeyID == "" { + return nil, ErrKMSKeyIDRequired + } + + client, err := newKmsClient(siteURL, kmsKeyID, clientID, clientSecret) + if err != nil { + return nil, err + } + if err := client.login(); err != nil { + return nil, fmt.Errorf("infisical-kms: authentication failed: %w", err) + } + + logrus.WithField("site", siteURL).Info("infisical-kms: authenticated successfully") + + return &infisicalKms{ + client: client, + ps: ps, + }, nil +} + +func (k *infisicalKms) String() string { + return Name +} + +func (k *infisicalKms) GetSecret( + secretId string, + keyContext map[string]string, +) (map[string]interface{}, secrets.Version, error) { + if secretId == "" { + return nil, secrets.NoVersion, secrets.ErrEmptySecretId + } + + _, customData := keyContext[secrets.CustomSecretData] + _, publicData := keyContext[secrets.PublicSecretData] + if customData && publicData { + return nil, secrets.NoVersion, &secrets.ErrInvalidKeyContext{ + Reason: "both CustomSecretData and PublicSecretData flags cannot be set", + } + } + + exists, err := k.ps.Exists(secretId) + if err != nil { + return nil, secrets.NoVersion, err + } + if !exists { + return nil, secrets.NoVersion, secrets.ErrInvalidSecretId + } + + ciphertextBytes, err := k.ps.GetPublic(secretId) + if err != nil { + return nil, secrets.NoVersion, err + } + + secretData := make(map[string]interface{}) + if publicData { + secretData[secretId] = ciphertextBytes + return secretData, secrets.NoVersion, nil + } + + plaintext, err := k.client.decrypt(string(ciphertextBytes)) + if err != nil { + return nil, secrets.NoVersion, fmt.Errorf("infisical-kms: decryption failed: %w", err) + } + + if customData { + if err := json.Unmarshal([]byte(plaintext), &secretData); err != nil { + return nil, secrets.NoVersion, fmt.Errorf("infisical-kms: failed to unmarshal decrypted data: %w", err) + } + } else { + secretData[secretId] = plaintext + } + return secretData, secrets.NoVersion, nil +} + +func (k *infisicalKms) PutSecret( + secretId string, + secretData map[string]interface{}, + keyContext map[string]string, +) (secrets.Version, error) { + if secretId == "" { + return secrets.NoVersion, secrets.ErrEmptySecretId + } + + _, override := keyContext[secrets.OverwriteSecretDataInStore] + _, customData := keyContext[secrets.CustomSecretData] + _, publicData := keyContext[secrets.PublicSecretData] + + if err := secrets.KeyContextChecks(keyContext, secretData); err != nil { + return secrets.NoVersion, err + } + + var ciphertext []byte + if publicData && len(secretData) > 0 { + raw, ok := secretData[secretId] + if !ok { + return secrets.NoVersion, secrets.ErrInvalidSecretData + } + ciphertext, ok = raw.([]byte) + if !ok { + return secrets.NoVersion, &secrets.ErrInvalidKeyContext{ + Reason: "secret data when PublicSecretData flag is set should be of the type []byte", + } + } + } else if customData && len(secretData) > 0 { + jsonBytes, err := json.Marshal(secretData) + if err != nil { + return secrets.NoVersion, fmt.Errorf("infisical-kms: failed to marshal secret data: %w", err) + } + encrypted, err := k.client.encrypt(string(jsonBytes)) + if err != nil { + return secrets.NoVersion, fmt.Errorf("infisical-kms: encryption failed: %w", err) + } + ciphertext = []byte(encrypted) + } else { + return secrets.NoVersion, secrets.ErrEmptySecretData + } + + return secrets.NoVersion, k.ps.Set(secretId, ciphertext, nil, nil, override) +} + +func (k *infisicalKms) DeleteSecret( + secretId string, + keyContext map[string]string, +) error { + if secretId == "" { + return secrets.ErrEmptySecretId + } + return k.ps.Delete(secretId) +} + +func (k *infisicalKms) ListSecrets() ([]string, error) { + return k.ps.List() +} + +func (k *infisicalKms) Encrypt( + secretId string, + plaintTextData string, + keyContext map[string]string, +) (string, error) { + return "", secrets.ErrNotSupported +} + +func (k *infisicalKms) Decrypt( + secretId string, + encryptedData string, + keyContext map[string]string, +) (string, error) { + return "", secrets.ErrNotSupported +} + +func (k *infisicalKms) Rencrypt( + originalSecretId string, + newSecretId string, + originalKeyContext map[string]string, + newKeyContext map[string]string, + encryptedData string, +) (string, error) { + return "", secrets.ErrNotSupported +} + +func configString(config map[string]interface{}, key, fallback string) string { + if config != nil { + if v, ok := config[key]; ok { + if s, ok := v.(string); ok && s != "" { + return s + } + } + } + if env := os.Getenv(key); env != "" { + return env + } + return fallback +} + +func init() { + if err := secrets.Register(Name, New); err != nil { + panic(err.Error()) + } +} diff --git a/infisical/infisical_kms_integration_test.go b/infisical/infisical_kms_integration_test.go new file mode 100644 index 00000000..4bed10f9 --- /dev/null +++ b/infisical/infisical_kms_integration_test.go @@ -0,0 +1,178 @@ +//go:build integration + +package infisical + +import ( + "fmt" + "os" + "testing" + + "github.com/libopenstorage/secrets" + memkv "github.com/portworx/kvdb/mem" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func requiredEnv(t *testing.T, keys ...string) { + t.Helper() + for _, k := range keys { + if os.Getenv(k) == "" { + t.Skipf("skipping integration test: env var %s is not set", k) + } + } +} + +func newIntegrationBackend(t *testing.T) secrets.Secrets { + t.Helper() + requiredEnv(t, + SiteURLKey, + ClientIDKey, + ClientSecretKey, + KMSKeyIDKey, + ) + + kv, err := memkv.New("integration-test", nil, nil, nil) + require.NoError(t, err) + + s, err := New(map[string]interface{}{ + KvdbKey: kv, + }) + require.NoError(t, err, "New() should succeed with valid env config") + return s +} + +func uniqueID(t *testing.T, suffix string) string { + return fmt.Sprintf("integration-test-%s-%s", t.Name(), suffix) +} + +func TestIntegration_FullLifecycle(t *testing.T) { + s := newIntegrationBackend(t) + secretID := uniqueID(t, "lifecycle") + + customCtx := map[string]string{secrets.CustomSecretData: "true"} + + original := map[string]interface{}{ + "passphrase": "super-secret-value", + "extra": "metadata", + } + + t.Run("PutSecret", func(t *testing.T) { + ver, err := s.PutSecret(secretID, original, customCtx) + require.NoError(t, err) + assert.Equal(t, secrets.NoVersion, ver) + }) + + t.Run("GetSecret", func(t *testing.T) { + result, ver, err := s.GetSecret(secretID, customCtx) + require.NoError(t, err) + assert.Equal(t, secrets.NoVersion, ver) + assert.Equal(t, original["passphrase"], result["passphrase"]) + assert.Equal(t, original["extra"], result["extra"]) + }) + + t.Run("ListSecrets_Contains", func(t *testing.T) { + ids, err := s.ListSecrets() + require.NoError(t, err) + assert.Contains(t, ids, secretID) + }) + + t.Run("DeleteSecret", func(t *testing.T) { + err := s.DeleteSecret(secretID, nil) + require.NoError(t, err) + }) + + t.Run("GetSecret_AfterDelete", func(t *testing.T) { + _, _, err := s.GetSecret(secretID, customCtx) + assert.ErrorIs(t, err, secrets.ErrInvalidSecretId) + }) + + t.Run("ListSecrets_NotContains_AfterDelete", func(t *testing.T) { + ids, err := s.ListSecrets() + require.NoError(t, err) + assert.NotContains(t, ids, secretID) + }) +} + +func TestIntegration_MultipleSecrets(t *testing.T) { + s := newIntegrationBackend(t) + customCtx := map[string]string{secrets.CustomSecretData: "true"} + + entries := []struct { + id string + data map[string]interface{} + }{ + {uniqueID(t, "s1"), map[string]interface{}{"vol": "disk1"}}, + {uniqueID(t, "s2"), map[string]interface{}{"vol": "disk2"}}, + {uniqueID(t, "s3"), map[string]interface{}{"vol": "disk3"}}, + } + + for _, e := range entries { + _, err := s.PutSecret(e.id, e.data, customCtx) + require.NoError(t, err) + } + + for _, e := range entries { + result, _, err := s.GetSecret(e.id, customCtx) + require.NoError(t, err) + assert.Equal(t, e.data["vol"], result["vol"]) + } + + for _, e := range entries { + require.NoError(t, s.DeleteSecret(e.id, nil)) + } +} + +func TestIntegration_Overwrite(t *testing.T) { + s := newIntegrationBackend(t) + id := uniqueID(t, "overwrite") + customCtx := map[string]string{secrets.CustomSecretData: "true"} + + _, err := s.PutSecret(id, map[string]interface{}{"v": "original"}, customCtx) + require.NoError(t, err) + + _, err = s.PutSecret(id, map[string]interface{}{"v": "new"}, customCtx) + require.Error(t, err) + assert.ErrorContains(t, err, "already exists") + + _, err = s.PutSecret(id, map[string]interface{}{"v": "updated"}, + map[string]string{ + secrets.OverwriteSecretDataInStore: "true", + secrets.CustomSecretData: "true", + }) + require.NoError(t, err) + + result, _, err := s.GetSecret(id, customCtx) + require.NoError(t, err) + assert.Equal(t, "updated", result["v"]) + + _ = s.DeleteSecret(id, nil) +} + +func TestIntegration_PublicData(t *testing.T) { + s := newIntegrationBackend(t) + id := uniqueID(t, "publicdata") + publicCtx := map[string]string{secrets.PublicSecretData: "true"} + + raw := []byte("opaque-bytes-stored-as-is") + _, err := s.PutSecret(id, map[string]interface{}{id: raw}, publicCtx) + require.NoError(t, err) + + result, _, err := s.GetSecret(id, publicCtx) + require.NoError(t, err) + assert.Equal(t, raw, result[id]) + + _ = s.DeleteSecret(id, nil) +} + +func TestIntegration_Unsupported(t *testing.T) { + s := newIntegrationBackend(t) + + _, err := s.Encrypt("id", "plaintext", nil) + assert.ErrorIs(t, err, secrets.ErrNotSupported) + + _, err = s.Decrypt("id", "ciphertext", nil) + assert.ErrorIs(t, err, secrets.ErrNotSupported) + + _, err = s.Rencrypt("a", "b", nil, nil, "data") + assert.ErrorIs(t, err, secrets.ErrNotSupported) +} diff --git a/infisical/infisical_kms_test.go b/infisical/infisical_kms_test.go new file mode 100644 index 00000000..da7fd17a --- /dev/null +++ b/infisical/infisical_kms_test.go @@ -0,0 +1,523 @@ +package infisical + +import ( + "errors" + "testing" + + "github.com/libopenstorage/secrets" + "github.com/portworx/kvdb" + memkv "github.com/portworx/kvdb/mem" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakePersistenceStore is a minimal, in-memory implementation of +// store.PersistenceStore used exclusively in unit tests. +type fakePersistenceStore struct { + data map[string][]byte +} + +func newFakePersistenceStore() *fakePersistenceStore { + return &fakePersistenceStore{data: make(map[string][]byte)} +} + +func (f *fakePersistenceStore) GetPublic(secretId string) ([]byte, error) { + v, ok := f.data[secretId] + if !ok { + return nil, secrets.ErrInvalidSecretId + } + return v, nil +} + +func (f *fakePersistenceStore) GetSecretData(_ string, _ []byte) (map[string]interface{}, error) { + return nil, secrets.ErrNotSupported +} + +func (f *fakePersistenceStore) Exists(secretId string) (bool, error) { + _, ok := f.data[secretId] + return ok, nil +} + +func (f *fakePersistenceStore) Set(secretId string, cipher, _ []byte, _ map[string]interface{}, override bool) error { + if _, exists := f.data[secretId]; exists && !override { + return secrets.ErrSecretExists + } + f.data[secretId] = cipher + return nil +} + +func (f *fakePersistenceStore) Delete(secretId string) error { + delete(f.data, secretId) + return nil +} + +func (f *fakePersistenceStore) Name() string { return "fake" } + +func (f *fakePersistenceStore) List() ([]string, error) { + ids := make([]string, 0, len(f.data)) + for k := range f.data { + ids = append(ids, k) + } + return ids, nil +} + +type fakeClient struct { + encryptFn func(plaintext string) (string, error) + decryptFn func(ciphertext string) (string, error) +} + +func (f *fakeClient) encrypt(plaintext string) (string, error) { + return f.encryptFn(plaintext) +} + +func (f *fakeClient) decrypt(ciphertext string) (string, error) { + return f.decryptFn(ciphertext) +} + +func newTestBackend( + enc func(string) (string, error), + dec func(string) (string, error), +) *infisicalKms { + return &infisicalKms{ + client: &fakeClient{encryptFn: enc, decryptFn: dec}, + ps: newFakePersistenceStore(), + } +} + +func newMemKvdb(t *testing.T) kvdb.Kvdb { + t.Helper() + kv, err := memkv.New("test", nil, nil, nil) + require.NoError(t, err) + return kv +} + +// --------------------------------------------------------------------------- +// New() – error path tests +// --------------------------------------------------------------------------- + +func TestNew_MissingKvdb(t *testing.T) { + _, err := New(map[string]interface{}{ + ClientIDKey: "x", + ClientSecretKey: "x", + KMSKeyIDKey: "x", + }) + assert.ErrorIs(t, err, ErrKvdbNotProvided) +} + +func TestNew_WrongKvdbType(t *testing.T) { + _, err := New(map[string]interface{}{ + KvdbKey: "not-a-kvdb-instance", + ClientIDKey: "x", + ClientSecretKey: "x", + KMSKeyIDKey: "x", + }) + assert.ErrorIs(t, err, ErrKvdbNotProvided) +} + +func TestNew_MissingClientID(t *testing.T) { + t.Setenv(ClientIDKey, "") + _, err := New(map[string]interface{}{ + KvdbKey: newMemKvdb(t), + ClientSecretKey: "x", + KMSKeyIDKey: "x", + }) + assert.ErrorIs(t, err, ErrClientIDRequired) +} + +func TestNew_MissingClientSecret(t *testing.T) { + t.Setenv(ClientSecretKey, "") + _, err := New(map[string]interface{}{ + KvdbKey: newMemKvdb(t), + ClientIDKey: "x", + KMSKeyIDKey: "x", + }) + assert.ErrorIs(t, err, ErrClientSecretRequired) +} + +func TestNew_RejectsHTTPSiteURL(t *testing.T) { + _, err := New(map[string]interface{}{ + KvdbKey: newMemKvdb(t), + SiteURLKey: "http://infisical.internal", + ClientIDKey: "x", + ClientSecretKey: "x", + KMSKeyIDKey: "x", + }) + assert.ErrorIs(t, err, ErrInsecureSiteURL) +} + +func TestNew_RejectsInvalidSiteURL(t *testing.T) { + _, err := New(map[string]interface{}{ + KvdbKey: newMemKvdb(t), + SiteURLKey: "not a url", + ClientIDKey: "x", + ClientSecretKey: "x", + KMSKeyIDKey: "x", + }) + require.Error(t, err) + assert.NotErrorIs(t, err, ErrClientIDRequired) +} + +func TestNew_MissingKMSKeyID(t *testing.T) { + t.Setenv(KMSKeyIDKey, "") + _, err := New(map[string]interface{}{ + KvdbKey: newMemKvdb(t), + ClientIDKey: "x", + ClientSecretKey: "x", + }) + assert.ErrorIs(t, err, ErrKMSKeyIDRequired) +} + +// --------------------------------------------------------------------------- +// PutSecret tests +// --------------------------------------------------------------------------- + +func customCtx() map[string]string { + return map[string]string{secrets.CustomSecretData: "true"} +} + +func publicCtx() map[string]string { + return map[string]string{secrets.PublicSecretData: "true"} +} + +func TestPutSecret_HappyPath(t *testing.T) { + k := newTestBackend( + func(plaintext string) (string, error) { + return "encrypted-blob", nil + }, + nil, + ) + + ver, err := k.PutSecret("my-secret", map[string]interface{}{"password": "hunter2"}, customCtx()) + require.NoError(t, err) + assert.Equal(t, secrets.NoVersion, ver) + + stored, err := k.ps.GetPublic("my-secret") + require.NoError(t, err) + assert.Equal(t, "encrypted-blob", string(stored)) +} + +func TestPutSecret_EmptySecretId(t *testing.T) { + k := newTestBackend(nil, nil) + _, err := k.PutSecret("", map[string]interface{}{"x": "y"}, customCtx()) + assert.ErrorIs(t, err, secrets.ErrEmptySecretId) +} + +func TestPutSecret_NoFlagWithData(t *testing.T) { + k := newTestBackend(nil, nil) + _, err := k.PutSecret("my-secret", map[string]interface{}{"x": "y"}, nil) + var kcErr *secrets.ErrInvalidKeyContext + assert.ErrorAs(t, err, &kcErr) +} + +func TestPutSecret_CustomDataWithoutData(t *testing.T) { + k := newTestBackend(nil, nil) + _, err := k.PutSecret("my-secret", map[string]interface{}{}, customCtx()) + var kcErr *secrets.ErrInvalidKeyContext + assert.ErrorAs(t, err, &kcErr) +} + +func TestPutSecret_BothFlagsSet(t *testing.T) { + k := newTestBackend(nil, nil) + _, err := k.PutSecret("my-secret", map[string]interface{}{"x": "y"}, map[string]string{ + secrets.CustomSecretData: "true", + secrets.PublicSecretData: "true", + }) + var kcErr *secrets.ErrInvalidKeyContext + assert.ErrorAs(t, err, &kcErr) +} + +func TestPutSecret_PublicData(t *testing.T) { + encryptCalled := false + k := newTestBackend( + func(_ string) (string, error) { + encryptCalled = true + return "should-not-be-called", nil + }, + nil, + ) + + raw := []byte("opaque-ciphertext-bytes") + _, err := k.PutSecret("my-secret", + map[string]interface{}{"my-secret": raw}, publicCtx()) + require.NoError(t, err) + assert.False(t, encryptCalled, "PublicSecretData must not invoke encrypt") + + stored, err := k.ps.GetPublic("my-secret") + require.NoError(t, err) + assert.Equal(t, raw, stored) +} + +func TestPutSecret_PublicData_NotByteSlice(t *testing.T) { + k := newTestBackend(nil, nil) + _, err := k.PutSecret("my-secret", + map[string]interface{}{"my-secret": "a-string-not-bytes"}, publicCtx()) + var kcErr *secrets.ErrInvalidKeyContext + assert.ErrorAs(t, err, &kcErr) +} + +func TestPutSecret_PublicData_MissingSecretIdKey(t *testing.T) { + k := newTestBackend(nil, nil) + _, err := k.PutSecret("my-secret", + map[string]interface{}{"other-key": []byte("x")}, publicCtx()) + assert.ErrorIs(t, err, secrets.ErrInvalidSecretData) +} + +func TestPutSecret_EncryptError(t *testing.T) { + encErr := errors.New("kms unavailable") + k := newTestBackend( + func(_ string) (string, error) { return "", encErr }, + nil, + ) + _, err := k.PutSecret("my-secret", map[string]interface{}{"x": "y"}, customCtx()) + require.Error(t, err) + assert.ErrorContains(t, err, "encryption failed") +} + +func TestPutSecret_DuplicateWithoutOverride(t *testing.T) { + k := newTestBackend( + func(_ string) (string, error) { return "blob", nil }, + nil, + ) + _, err := k.PutSecret("my-secret", map[string]interface{}{"x": "y"}, customCtx()) + require.NoError(t, err) + + _, err = k.PutSecret("my-secret", map[string]interface{}{"x": "z"}, customCtx()) + assert.ErrorIs(t, err, secrets.ErrSecretExists) +} + +func TestPutSecret_OverwriteWithOverride(t *testing.T) { + k := newTestBackend( + func(_ string) (string, error) { return "new-blob", nil }, + nil, + ) + _ = k.ps.Set("my-secret", []byte("old-blob"), nil, nil, false) + + _, err := k.PutSecret("my-secret", map[string]interface{}{"x": "z"}, + map[string]string{ + secrets.OverwriteSecretDataInStore: "true", + secrets.CustomSecretData: "true", + }) + require.NoError(t, err) + + stored, _ := k.ps.GetPublic("my-secret") + assert.Equal(t, "new-blob", string(stored)) +} + +// --------------------------------------------------------------------------- +// GetSecret tests +// --------------------------------------------------------------------------- + +func TestGetSecret_HappyPath(t *testing.T) { + k := newTestBackend( + func(_ string) (string, error) { return "ct", nil }, + func(ciphertext string) (string, error) { + assert.Equal(t, "ct", ciphertext) + return `{"password":"hunter2"}`, nil + }, + ) + + _, _ = k.PutSecret("my-secret", map[string]interface{}{"password": "hunter2"}, customCtx()) + + result, ver, err := k.GetSecret("my-secret", customCtx()) + require.NoError(t, err) + assert.Equal(t, secrets.NoVersion, ver) + assert.Equal(t, "hunter2", result["password"]) +} + +func TestGetSecret_EmptySecretId(t *testing.T) { + k := newTestBackend(nil, nil) + _, _, err := k.GetSecret("", nil) + assert.ErrorIs(t, err, secrets.ErrEmptySecretId) +} + +func TestGetSecret_NotFound(t *testing.T) { + k := newTestBackend(nil, nil) + _, _, err := k.GetSecret("nonexistent", customCtx()) + assert.ErrorIs(t, err, secrets.ErrInvalidSecretId) +} + +func TestGetSecret_BothFlagsSet(t *testing.T) { + k := newTestBackend(nil, nil) + _, _, err := k.GetSecret("my-secret", map[string]string{ + secrets.CustomSecretData: "true", + secrets.PublicSecretData: "true", + }) + var kcErr *secrets.ErrInvalidKeyContext + assert.ErrorAs(t, err, &kcErr) +} + +func TestGetSecret_NoFlag_ReturnsPlaintextString(t *testing.T) { + k := newTestBackend( + func(_ string) (string, error) { return "ct", nil }, + func(_ string) (string, error) { return "the-plaintext", nil }, + ) + _, _ = k.PutSecret("my-secret", map[string]interface{}{"x": "y"}, customCtx()) + + result, _, err := k.GetSecret("my-secret", nil) + require.NoError(t, err) + assert.Equal(t, "the-plaintext", result["my-secret"]) +} + +func TestGetSecret_PublicData(t *testing.T) { + decryptCalled := false + k := newTestBackend( + nil, + func(_ string) (string, error) { + decryptCalled = true + return "", nil + }, + ) + raw := []byte("opaque-ciphertext-bytes") + _, err := k.PutSecret("my-secret", map[string]interface{}{"my-secret": raw}, publicCtx()) + require.NoError(t, err) + + result, _, err := k.GetSecret("my-secret", publicCtx()) + require.NoError(t, err) + assert.False(t, decryptCalled, "PublicSecretData must not invoke decrypt") + assert.Equal(t, raw, result["my-secret"]) +} + +func TestGetSecret_DecryptError(t *testing.T) { + decErr := errors.New("kms unavailable") + k := newTestBackend( + func(_ string) (string, error) { return "ct", nil }, + func(_ string) (string, error) { return "", decErr }, + ) + _, _ = k.PutSecret("my-secret", map[string]interface{}{"x": "y"}, customCtx()) + + _, _, err := k.GetSecret("my-secret", customCtx()) + require.Error(t, err) + assert.ErrorContains(t, err, "decryption failed") +} + +func TestGetSecret_RoundTrip(t *testing.T) { + original := map[string]interface{}{ + "key": "my-passphrase", + "number": float64(42), + } + + var capturedPlaintext string + k := newTestBackend( + func(plaintext string) (string, error) { + capturedPlaintext = plaintext + return "ct", nil + }, + func(_ string) (string, error) { + return capturedPlaintext, nil + }, + ) + + _, err := k.PutSecret("rtrip", original, customCtx()) + require.NoError(t, err) + + result, _, err := k.GetSecret("rtrip", customCtx()) + require.NoError(t, err) + assert.Equal(t, original, result) +} + +// --------------------------------------------------------------------------- +// DeleteSecret tests +// --------------------------------------------------------------------------- + +func TestDeleteSecret_HappyPath(t *testing.T) { + k := newTestBackend( + func(_ string) (string, error) { return "ct", nil }, + nil, + ) + _, _ = k.PutSecret("my-secret", map[string]interface{}{"x": "y"}, customCtx()) + + err := k.DeleteSecret("my-secret", nil) + require.NoError(t, err) + + _, _, err = k.GetSecret("my-secret", nil) + assert.ErrorIs(t, err, secrets.ErrInvalidSecretId) +} + +func TestDeleteSecret_EmptySecretId(t *testing.T) { + k := newTestBackend(nil, nil) + err := k.DeleteSecret("", nil) + assert.ErrorIs(t, err, secrets.ErrEmptySecretId) +} + +func TestDeleteSecret_Idempotent(t *testing.T) { + k := newTestBackend(nil, nil) + err := k.DeleteSecret("does-not-exist", nil) + assert.NoError(t, err) +} + +// --------------------------------------------------------------------------- +// ListSecrets tests +// --------------------------------------------------------------------------- + +func TestListSecrets(t *testing.T) { + k := newTestBackend( + func(_ string) (string, error) { return "ct", nil }, + nil, + ) + _, _ = k.PutSecret("secret-a", map[string]interface{}{"x": "y"}, customCtx()) + _, _ = k.PutSecret("secret-b", map[string]interface{}{"x": "y"}, customCtx()) + + ids, err := k.ListSecrets() + require.NoError(t, err) + assert.ElementsMatch(t, []string{"secret-a", "secret-b"}, ids) +} + +func TestListSecrets_Empty(t *testing.T) { + k := newTestBackend(nil, nil) + ids, err := k.ListSecrets() + require.NoError(t, err) + assert.Empty(t, ids) +} + +// --------------------------------------------------------------------------- +// Unsupported methods +// --------------------------------------------------------------------------- + +func TestEncrypt_NotSupported(t *testing.T) { + k := newTestBackend(nil, nil) + _, err := k.Encrypt("id", "data", nil) + assert.ErrorIs(t, err, secrets.ErrNotSupported) +} + +func TestDecrypt_NotSupported(t *testing.T) { + k := newTestBackend(nil, nil) + _, err := k.Decrypt("id", "data", nil) + assert.ErrorIs(t, err, secrets.ErrNotSupported) +} + +func TestRencrypt_NotSupported(t *testing.T) { + k := newTestBackend(nil, nil) + _, err := k.Rencrypt("a", "b", nil, nil, "data") + assert.ErrorIs(t, err, secrets.ErrNotSupported) +} + +// --------------------------------------------------------------------------- +// String() +// --------------------------------------------------------------------------- + +func TestString(t *testing.T) { + k := newTestBackend(nil, nil) + assert.Equal(t, "infisical-kms", k.String()) +} + +// --------------------------------------------------------------------------- +// configString helper +// --------------------------------------------------------------------------- + +func TestConfigString_ConfigTakesPrecedence(t *testing.T) { + t.Setenv("MY_KEY", "from-env") + result := configString(map[string]interface{}{"MY_KEY": "from-config"}, "MY_KEY", "fallback") + assert.Equal(t, "from-config", result) +} + +func TestConfigString_EnvFallback(t *testing.T) { + t.Setenv("MY_KEY", "from-env") + result := configString(nil, "MY_KEY", "fallback") + assert.Equal(t, "from-env", result) +} + +func TestConfigString_Default(t *testing.T) { + t.Setenv("MY_KEY", "") + result := configString(nil, "MY_KEY", "fallback") + assert.Equal(t, "fallback", result) +} diff --git a/secrets.go b/secrets.go index c6e52439..4c7a0d49 100644 --- a/secrets.go +++ b/secrets.go @@ -53,6 +53,7 @@ const ( TypeVault = "vault" TypeVaultTransit = "vault-transit" TypeAWSSecretsManager = "aws-secrets-manager" + TypeInfisical = "infisical-kms" ) const (