Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 27 additions & 25 deletions pkg/common/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"github.com/adrg/xdg"
"github.com/mitchellh/go-homedir"
"github.com/spf13/viper"
"path"
"path/filepath"
)

const (
Expand All @@ -33,47 +33,49 @@ func InitConfig(kubeconfig string, sources string) (*Config, error) {
return nil, err
}

viper.SetDefault(keyKubeconfig, defaultKubeconfig)
viper.SetDefault(keySources, defaultSources)
v := viper.New()
v.SetDefault(keyKubeconfig, defaultKubeconfig)
v.SetDefault(keySources, defaultSources)

viper.SetConfigName(configFileName)
viper.SetConfigType(configFileExtension)
viper.AddConfigPath(path.Dir(cfgLocation))
v.SetConfigName(configFileName)
v.SetConfigType(configFileExtension)
v.AddConfigPath(filepath.Dir(cfgLocation))

shouldWriteConfig := false
err = v.ReadInConfig()
if err != nil {
var configFileNotFoundError viper.ConfigFileNotFoundError
if errors.As(err, &configFileNotFoundError) {
shouldWriteConfig = true
fmt.Println("No kuse configuration found, no sweat, I'll create one with defaults at", cfgLocation)
} else {
return nil, err
}
}

if kubeconfig != "" {
viper.Set(keyKubeconfig, kubeconfig)
v.Set(keyKubeconfig, kubeconfig)
shouldWriteConfig = true
}

if sources != "" {
viper.Set(keySources, sources)
v.Set(keySources, sources)
shouldWriteConfig = true
}

if kubeconfig != "" || sources != "" {
err := viper.WriteConfigAs(cfgLocation)
if shouldWriteConfig {
err := v.WriteConfigAs(cfgLocation)
if err != nil {
return nil, err
}
}

err = viper.ReadInConfig()
if err != nil {
var configFileNotFoundError viper.ConfigFileNotFoundError
if errors.As(err, &configFileNotFoundError) {
fmt.Println("No kuse configuration found, no sweat, I'll create one with defaults at", cfgLocation)
err := viper.WriteConfigAs(cfgLocation)
if err != nil {
fmt.Println(err)
return nil, err
}
}
}

expandedKubeconfig, err := homedir.Expand(viper.GetString(keyKubeconfig))
expandedKubeconfig, err := homedir.Expand(v.GetString(keyKubeconfig))
if err != nil {
return nil, err
}

expandedSources, err := homedir.Expand(viper.GetString(keySources))
expandedSources, err := homedir.Expand(v.GetString(keySources))
if err != nil {
return nil, err
}
Expand Down
80 changes: 80 additions & 0 deletions pkg/common/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package common

import (
"github.com/adrg/xdg"
"os"
"path/filepath"
"testing"
)

func setupConfigTestEnv(t *testing.T) (string, string) {
t.Helper()

xdgConfigHome := t.TempDir()
homeDir := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", xdgConfigHome)
t.Setenv("HOME", homeDir)
originalConfigHome := xdg.ConfigHome
xdg.ConfigHome = xdgConfigHome
t.Cleanup(func() {
xdg.ConfigHome = originalConfigHome
})

return xdgConfigHome, homeDir
}

func TestInitConfig_PartialUpdatePreservesExistingValues(t *testing.T) {
_, homeDir := setupConfigTestEnv(t)

initialKubeconfig := filepath.Join(homeDir, ".kube", "config")
updatedKubeconfig := filepath.Join(homeDir, ".kube", "config-updated")
customSources := filepath.Join(t.TempDir(), "kubeconfigs")

if _, err := InitConfig(initialKubeconfig, customSources); err != nil {
t.Fatalf("InitConfig (initial write) returned error: %v", err)
}

updatedConfig, err := InitConfig(updatedKubeconfig, "")
if err != nil {
t.Fatalf("InitConfig (partial update) returned error: %v", err)
}

if updatedConfig.Kubeconfig != updatedKubeconfig {
t.Fatalf("Kubeconfig = %q, want %q", updatedConfig.Kubeconfig, updatedKubeconfig)
}
if updatedConfig.Sources != customSources {
t.Fatalf("Sources = %q, want %q", updatedConfig.Sources, customSources)
}

reloadedConfig, err := InitConfig("", "")
if err != nil {
t.Fatalf("InitConfig (reload) returned error: %v", err)
}

if reloadedConfig.Kubeconfig != updatedKubeconfig {
t.Fatalf("reloaded Kubeconfig = %q, want %q", reloadedConfig.Kubeconfig, updatedKubeconfig)
}
if reloadedConfig.Sources != customSources {
t.Fatalf("reloaded Sources = %q, want %q", reloadedConfig.Sources, customSources)
}
}

func TestInitConfig_ReturnsErrorForInvalidConfigFile(t *testing.T) {
setupConfigTestEnv(t)

configPath, err := xdg.ConfigFile(configFileLocation)
if err != nil {
t.Fatalf("ConfigFile returned error: %v", err)
}
if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
t.Fatalf("MkdirAll returned error: %v", err)
}

if err := os.WriteFile(configPath, []byte("kubeconfig: \"unterminated\n"), 0o600); err != nil {
t.Fatalf("WriteFile returned error: %v", err)
}

if _, err := InitConfig("", ""); err == nil {
t.Fatal("InitConfig returned nil error for invalid config")
}
}
8 changes: 4 additions & 4 deletions pkg/common/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"errors"
"fmt"
"os"
"path"
"path/filepath"
"strings"
)

Expand Down Expand Up @@ -41,8 +41,8 @@ func (s *State) loadTargets() error {
s.targets = make([]Link, 0)
for _, file := range files {
if isYaml(file.Name()) {
filepath := path.Join(s.config.Sources, file.Name())
s.targets = append(s.targets, fileToLink(filepath))
targetPath := filepath.Join(s.config.Sources, file.Name())
s.targets = append(s.targets, fileToLink(targetPath))
}
}

Expand Down Expand Up @@ -116,7 +116,7 @@ func (s *State) SetTarget(target string) error {
}

if !valid {
return errors.New(fmt.Sprintf("invalid target: %s", target))
return fmt.Errorf("invalid target: %s", target)
}

return s.switchLink(filename)
Expand Down
54 changes: 54 additions & 0 deletions pkg/common/state_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package common

import (
"os"
"path/filepath"
"testing"
)

func TestSetTarget_ReplacesBrokenKubeconfigSymlink(t *testing.T) {
tempDir := t.TempDir()
sourcesDir := filepath.Join(tempDir, "sources")
if err := os.MkdirAll(sourcesDir, 0o755); err != nil {
t.Fatalf("MkdirAll sources returned error: %v", err)
}

targetFile := filepath.Join(sourcesDir, "dev.yaml")
if err := os.WriteFile(targetFile, []byte("apiVersion: v1\n"), 0o644); err != nil {
t.Fatalf("WriteFile target returned error: %v", err)
}

kubeconfigPath := filepath.Join(tempDir, ".kube", "config")
if err := os.MkdirAll(filepath.Dir(kubeconfigPath), 0o755); err != nil {
t.Fatalf("MkdirAll kubeconfig parent returned error: %v", err)
}

brokenTarget := filepath.Join(sourcesDir, "missing.yaml")
if err := os.Symlink(brokenTarget, kubeconfigPath); err != nil {
t.Fatalf("Symlink returned error: %v", err)
}

state, err := LoadState(&Config{
Kubeconfig: kubeconfigPath,
Sources: sourcesDir,
})
if err != nil {
t.Fatalf("LoadState returned error: %v", err)
}

if state.current.Name != "missing" {
t.Fatalf("current.Name = %q, want %q", state.current.Name, "missing")
}

if err := state.SetTarget("dev"); err != nil {
t.Fatalf("SetTarget returned error: %v", err)
}

resolvedLink, err := os.Readlink(kubeconfigPath)
if err != nil {
t.Fatalf("Readlink returned error: %v", err)
}
if resolvedLink != targetFile {
t.Fatalf("kubeconfig symlink target = %q, want %q", resolvedLink, targetFile)
}
}
12 changes: 5 additions & 7 deletions pkg/common/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package common
import (
"errors"
"os"
"path"
"path/filepath"
"strings"
)

Expand Down Expand Up @@ -34,16 +34,14 @@ func isSymlink(filename string) bool {
}

func exists(filename string) bool {
if _, err := os.Stat(filename); !errors.Is(err, os.ErrNotExist) {
return true
}
return false
_, err := os.Lstat(filename)
return !errors.Is(err, os.ErrNotExist)
}

func fileToLink(filename string) Link {
return Link{
Name: trimYamlSuffix(path.Base(filename)),
Name: trimYamlSuffix(filepath.Base(filename)),
File: filename,
Extension: path.Ext(filename),
Extension: filepath.Ext(filename),
}
}
21 changes: 21 additions & 0 deletions pkg/common/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package common

import (
"os"
"path/filepath"
"testing"
)

func TestExists_ReturnsTrueForBrokenSymlink(t *testing.T) {
tempDir := t.TempDir()
brokenLink := filepath.Join(tempDir, "broken-link")
missingTarget := filepath.Join(tempDir, "missing.yaml")

if err := os.Symlink(missingTarget, brokenLink); err != nil {
t.Fatalf("Symlink returned error: %v", err)
}

if !exists(brokenLink) {
t.Fatalf("exists(%q) = false, want true", brokenLink)
}
}