diff --git a/src/cmd/auth.go b/src/cmd/auth.go index 2ed2e712..fe4c8366 100644 --- a/src/cmd/auth.go +++ b/src/cmd/auth.go @@ -29,6 +29,8 @@ func init() { authCmd.Flags().Int("ttl", 0, "Token TTL in seconds (for external identity modes)") authCmd.Flags().Bool("no-store", false, "Print token to stdout instead of storing (for external identity modes)") authCmd.Flags().String("azure-resource", "", "Azure AD resource/audience for token request (for azure mode, default: https://management.azure.com/)") + authCmd.Flags().String("token-name", "", "Name for the personal access token (webauth mode, default: username@hostname)") + authCmd.Flags().String("token-lifetime", "", "Lifetime for the personal access token, e.g. 7d, 12h, 30m (webauth mode, default: never expires)") rootCmd.AddCommand(authCmd) } diff --git a/src/cmd/auth_webauth.go b/src/cmd/auth_webauth.go index 780e72ec..d0890790 100644 --- a/src/cmd/auth_webauth.go +++ b/src/cmd/auth_webauth.go @@ -21,6 +21,39 @@ import ( "github.com/spf13/cobra" ) +// webAuthPayload is the request payload the Console webauth page parses. It is sent +// as base64(JSON) in the webauth URL. Lifetime is the requested token lifetime in +// seconds; when 0 it is omitted and the token never expires. +type webAuthPayload struct { + Port int `json:"port"` + PublicKey string `json:"publicKey"` + Name string `json:"name"` + Lifetime int64 `json:"lifetime,omitempty"` +} + +// resolveTokenName returns the requested token name: the trimmed flag value when +// set, otherwise the default username@hostname. +func resolveTokenName(flagValue, username, hostname string) string { + if name := strings.TrimSpace(flagValue); name != "" { + return name + } + return fmt.Sprintf("%s@%s", username, hostname) +} + +// encodeWebAuthPayload serializes the webauth request payload as base64(JSON). +func encodeWebAuthPayload(port int, pubKeyHex, name string, lifetimeSeconds int64) (string, error) { + rawData, err := json.Marshal(webAuthPayload{ + Port: port, + PublicKey: pubKeyHex, + Name: name, + Lifetime: lifetimeSeconds, + }) + if err != nil { + return "", fmt.Errorf("failed to encode webauth payload: %w", err) + } + return base64.StdEncoding.EncodeToString(rawData), nil +} + func runWebAuth(cmd *cobra.Command, host string) error { // Pick random port port := 8002 + rand.Intn(12001) @@ -34,17 +67,27 @@ func runWebAuth(cmd *cobra.Command, host string) error { pubKeyHex := hex.EncodeToString(kp.PublicKey[:]) privKeyHex := hex.EncodeToString(kp.SecretKey[:]) - // Build PAT name + // Build PAT name (default username@hostname, overridable via --token-name) username := "unknown" if u, err := user.Current(); err == nil { username = u.Username } hostname, _ := os.Hostname() - patName := fmt.Sprintf("%s@%s", username, hostname) + tokenNameFlag, _ := cmd.Flags().GetString("token-name") + patName := resolveTokenName(tokenNameFlag, username, hostname) + + // Parse the requested token lifetime (default: never expires) + lifetimeStr, _ := cmd.Flags().GetString("token-lifetime") + lifetimeSeconds, err := util.ParseTokenLifetime(lifetimeStr) + if err != nil { + return err + } - // Encode payload - rawData := fmt.Sprintf("%d-%s-%s", port, pubKeyHex, patName) - encoded := base64.StdEncoding.EncodeToString([]byte(rawData)) + // Encode payload as base64(JSON): { port, publicKey, name, lifetime? } + encoded, err := encodeWebAuthPayload(port, pubKeyHex, patName, lifetimeSeconds) + if err != nil { + return err + } // Channel to receive auth data type authData struct { diff --git a/src/cmd/auth_webauth_test.go b/src/cmd/auth_webauth_test.go new file mode 100644 index 00000000..4b44dffe --- /dev/null +++ b/src/cmd/auth_webauth_test.go @@ -0,0 +1,62 @@ +package cmd + +import ( + "encoding/base64" + "encoding/json" + "testing" +) + +func TestResolveTokenName(t *testing.T) { + if got := resolveTokenName("", "alice", "laptop"); got != "alice@laptop" { + t.Fatalf("empty flag: got %q, want %q", got, "alice@laptop") + } + if got := resolveTokenName(" ", "alice", "laptop"); got != "alice@laptop" { + t.Fatalf("whitespace flag: got %q, want %q", got, "alice@laptop") + } + if got := resolveTokenName(" ci-prod-api ", "alice", "laptop"); got != "ci-prod-api" { + t.Fatalf("set flag: got %q, want %q", got, "ci-prod-api") + } +} + +func TestEncodeWebAuthPayload(t *testing.T) { + // With a lifetime: all fields present, name with hyphens preserved. + encoded, err := encodeWebAuthPayload(8002, "abc123", "ci-prod-api", 604800) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + raw, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + t.Fatalf("payload is not valid base64: %v", err) + } + + var fields map[string]any + if err := json.Unmarshal(raw, &fields); err != nil { + t.Fatalf("payload is not valid JSON: %v", err) + } + if fields["port"].(float64) != 8002 { + t.Fatalf("port: got %v, want 8002", fields["port"]) + } + if fields["publicKey"] != "abc123" { + t.Fatalf("publicKey: got %v, want abc123", fields["publicKey"]) + } + if fields["name"] != "ci-prod-api" { + t.Fatalf("name: got %v, want ci-prod-api", fields["name"]) + } + if fields["lifetime"].(float64) != 604800 { + t.Fatalf("lifetime: got %v, want 604800", fields["lifetime"]) + } + + // Without a lifetime (0): the field is omitted so the token never expires. + encoded, err = encodeWebAuthPayload(8002, "abc123", "alice@laptop", 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + raw, _ = base64.StdEncoding.DecodeString(encoded) + fields = map[string]any{} + if err := json.Unmarshal(raw, &fields); err != nil { + t.Fatalf("payload is not valid JSON: %v", err) + } + if _, ok := fields["lifetime"]; ok { + t.Fatalf("lifetime should be omitted when zero, got %v", fields["lifetime"]) + } +} diff --git a/src/pkg/util/misc.go b/src/pkg/util/misc.go index 45303a69..8b8e8a22 100644 --- a/src/pkg/util/misc.go +++ b/src/pkg/util/misc.go @@ -7,6 +7,7 @@ import ( "os" "os/exec" "runtime" + "strconv" "strings" sdk "github.com/phasehq/golang-sdk/v2/phase" @@ -96,6 +97,45 @@ func GetShellCommand(shellType string) ([]string, error) { return []string{path}, nil } +// ParseTokenLifetime parses a token lifetime string such as "7d", "12h", "30m", "60s" +// or "2w" into a number of seconds. Supported units are s (seconds), m (minutes), +// h (hours), d (days) and w (weeks). An empty string returns 0, meaning the token +// never expires. +func ParseTokenLifetime(lifetime string) (int64, error) { + lifetime = strings.TrimSpace(strings.ToLower(lifetime)) + if lifetime == "" { + return 0, nil + } + + invalid := fmt.Errorf("invalid token lifetime %q (expected a number and a unit, e.g. 7d, 12h, 30m, 60s, 2w)", lifetime) + if len(lifetime) < 2 { + return 0, invalid + } + + value, err := strconv.ParseInt(lifetime[:len(lifetime)-1], 10, 64) + if err != nil || value < 0 { + return 0, invalid + } + + var perUnit int64 + switch lifetime[len(lifetime)-1] { + case 's': + perUnit = 1 + case 'm': + perUnit = 60 + case 'h': + perUnit = 3600 + case 'd': + perUnit = 86400 + case 'w': + perUnit = 604800 + default: + return 0, invalid + } + + return value * perUnit, nil +} + // ValidateURL checks that a URL has both a scheme (e.g. https) and a host (e.g. example.com). func ValidateURL(rawURL string) bool { parsed, err := url.Parse(rawURL) diff --git a/src/pkg/util/misc_test.go b/src/pkg/util/misc_test.go index 28c97327..1eeea9f4 100644 --- a/src/pkg/util/misc_test.go +++ b/src/pkg/util/misc_test.go @@ -50,6 +50,35 @@ func TestParseBoolFlag(t *testing.T) { } } +func TestParseTokenLifetime(t *testing.T) { + valid := map[string]int64{ + "": 0, + "60s": 60, + "30m": 1800, + "12h": 43200, + "7d": 604800, + "2w": 1209600, + " 7D ": 604800, // trimmed and case-insensitive + "0d": 0, + } + for in, want := range valid { + got, err := ParseTokenLifetime(in) + if err != nil { + t.Fatalf("unexpected error for %q: %v", in, err) + } + if got != want { + t.Fatalf("ParseTokenLifetime(%q) = %d, want %d", in, got, want) + } + } + + invalid := []string{"7", "d", "7x", "-7d", "abc", "1.5h", "7 d"} + for _, in := range invalid { + if _, err := ParseTokenLifetime(in); err == nil { + t.Fatalf("expected error for %q", in) + } + } +} + func TestParseEnvFile(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, ".env")