Skip to content

Commit 16b024c

Browse files
committed
refactor(server): use cobra to validate required flags in server firewall
Also adds helper to execute cobra command and capture its output and errors to help, for example, checking that cobra validates given args correctly.
1 parent f95cd5d commit 16b024c

5 files changed

Lines changed: 62 additions & 38 deletions

File tree

internal/commands/serverfirewall/create.go

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -76,38 +76,18 @@ Use --position parameter or create default rule after other rules.`
7676
flagSet.StringVar(&s.comment, "comment", "", "Freeform comment that can include 0-250 characters.")
7777

7878
s.AddFlags(flagSet)
79+
s.Cobra().MarkFlagRequired("direction") //nolint:errcheck
80+
s.Cobra().MarkFlagRequired("action") //nolint:errcheck
81+
s.Cobra().MarkFlagsRequiredTogether("destination-port-start", "destination-port-end")
82+
s.Cobra().MarkFlagsRequiredTogether("source-port-start", "source-port-end")
7983
}
8084

8185
// Execute implements commands.MultipleArgumentCommand
8286
func (s *createCommand) Execute(exec commands.Executor, arg string) (output.Output, error) {
83-
if s.direction == "" {
84-
return nil, fmt.Errorf("direction is required")
85-
}
86-
87-
if s.action == "" {
88-
return nil, fmt.Errorf("action is required")
89-
}
90-
9187
if s.family != "" && s.family != "IPv4" && s.family != "IPv6" {
9288
return nil, fmt.Errorf("invalid family, use either IPv4 or IPv6")
9389
}
9490

95-
if s.destinationPortStart == "" && s.destinationPortEnd != "" {
96-
return nil, fmt.Errorf("destination-port-start is required if destination-port-end is set")
97-
}
98-
99-
if s.destinationPortEnd == "" && s.destinationPortStart != "" {
100-
return nil, fmt.Errorf("destination-port-end is required if destination-port-start is set")
101-
}
102-
103-
if s.sourcePortStart == "" && s.sourcePortEnd != "" {
104-
return nil, fmt.Errorf("source-port-start is required if source-port-end is set")
105-
}
106-
107-
if s.sourcePortEnd == "" && s.sourcePortStart != "" {
108-
return nil, fmt.Errorf("source-port-end is required if source-port-start is set")
109-
}
110-
11191
var (
11292
destinationNetwork *cidr.ParsedCIDR
11393
sourceNetwork *cidr.ParsedCIDR

internal/commands/serverfirewall/create_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ import (
66
"github.com/UpCloudLtd/upcloud-cli/internal/commands"
77
"github.com/UpCloudLtd/upcloud-cli/internal/config"
88
smock "github.com/UpCloudLtd/upcloud-cli/internal/mock"
9+
"github.com/UpCloudLtd/upcloud-cli/internal/mockexecute"
910

1011
"github.com/UpCloudLtd/upcloud-go-api/v4/upcloud"
1112
"github.com/UpCloudLtd/upcloud-go-api/v4/upcloud/request"
12-
"github.com/gemalto/flume"
1313
"github.com/stretchr/testify/assert"
1414
"github.com/stretchr/testify/mock"
1515
)
@@ -40,7 +40,7 @@ func TestCreateFirewallRuleCommand(t *testing.T) {
4040
name: "Empty info",
4141
flags: []string{},
4242
arg: Server1.UUID,
43-
error: "direction is required",
43+
error: `required flag(s) "direction", "action" not set`,
4444
},
4545
{
4646
name: "Action is required",
@@ -49,7 +49,7 @@ func TestCreateFirewallRuleCommand(t *testing.T) {
4949
"--direction", "in",
5050
},
5151
arg: Server1.UUID,
52-
error: "action is required",
52+
error: `required flag(s) "action" not set`,
5353
},
5454
{
5555
name: "FirewallRule, drop incoming by default",
@@ -96,10 +96,10 @@ func TestCreateFirewallRuleCommand(t *testing.T) {
9696

9797
conf := config.New()
9898
cc := commands.BuildCommand(CreateCommand(), nil, conf)
99-
err := cc.Cobra().Flags().Parse(test.flags)
100-
assert.NoError(t, err)
10199

102-
_, err = cc.(commands.MultipleArgumentCommand).Execute(commands.NewExecutor(conf, &mService, flume.New("test")), test.arg)
100+
cc.Cobra().SetArgs(append(test.flags, test.arg))
101+
_, err := mockexecute.MockExecute(cc, &mService, conf)
102+
103103
if test.error != "" {
104104
assert.Error(t, err)
105105
assert.Equal(t, test.error, err.Error())

internal/commands/serverfirewall/delete.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,12 @@ func (s *deleteCommand) InitCommand() {
3434
flagSet := &pflag.FlagSet{}
3535
flagSet.IntVar(&s.rulePosition, "position", 0, "Rule position. Available: 1-1000")
3636
s.AddFlags(flagSet)
37+
38+
s.Cobra().MarkFlagRequired("position") //nolint:errcheck
3739
}
3840

3941
// Execute implements commands.MultipleArgumentCommand
4042
func (s *deleteCommand) Execute(exec commands.Executor, arg string) (output.Output, error) {
41-
if s.rulePosition == 0 {
42-
return nil, fmt.Errorf("position is required")
43-
}
4443
if s.rulePosition < 1 || s.rulePosition > 1000 {
4544
return nil, fmt.Errorf("invalid position (1-1000 allowed)")
4645
}

internal/commands/serverfirewall/delete_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ import (
77
"github.com/UpCloudLtd/upcloud-cli/internal/commands"
88
"github.com/UpCloudLtd/upcloud-cli/internal/config"
99
smock "github.com/UpCloudLtd/upcloud-cli/internal/mock"
10+
"github.com/UpCloudLtd/upcloud-cli/internal/mockexecute"
1011

1112
"github.com/UpCloudLtd/upcloud-go-api/v4/upcloud"
12-
"github.com/gemalto/flume"
1313
"github.com/stretchr/testify/assert"
1414
"github.com/stretchr/testify/mock"
1515
)
@@ -37,7 +37,7 @@ func TestDeleteServerFirewallRuleCommand(t *testing.T) {
3737
{
3838
name: "no position",
3939
flags: []string{},
40-
error: "position is required",
40+
error: `required flag(s) "position" not set`,
4141
},
4242
{
4343
name: "position 1",
@@ -56,10 +56,10 @@ func TestDeleteServerFirewallRuleCommand(t *testing.T) {
5656

5757
conf := config.New()
5858
cc := commands.BuildCommand(DeleteCommand(), nil, conf)
59-
err := cc.Cobra().Flags().Parse(test.flags)
60-
assert.NoError(t, err)
6159

62-
_, err = cc.(commands.MultipleArgumentCommand).Execute(commands.NewExecutor(conf, mService, flume.New("test")), Server1.UUID)
60+
cc.Cobra().SetArgs(append(test.flags, Server1.UUID))
61+
_, err := mockexecute.MockExecute(cc, mService, conf)
62+
6363
if test.error != "" {
6464
fmt.Println("ERROR", test.error, "==", err)
6565
assert.EqualError(t, err, test.error)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package mockexecute
2+
3+
import (
4+
"bytes"
5+
6+
"github.com/gemalto/flume"
7+
"github.com/spf13/cobra"
8+
9+
"github.com/UpCloudLtd/upcloud-cli/internal/commands"
10+
"github.com/UpCloudLtd/upcloud-cli/internal/config"
11+
"github.com/UpCloudLtd/upcloud-cli/internal/output"
12+
"github.com/UpCloudLtd/upcloud-cli/internal/service"
13+
)
14+
15+
func MockExecute(command commands.Command, service service.AllServices, config *config.Config) (string, error) {
16+
buf := bytes.NewBuffer(nil)
17+
command.Cobra().SetErr(buf)
18+
command.Cobra().SetOut(buf)
19+
20+
command.Cobra().RunE = func(_ *cobra.Command, args []string) error {
21+
return mockRunE(command, service, config, args)
22+
}
23+
err := command.Cobra().Execute()
24+
25+
return buf.String(), err
26+
}
27+
28+
func mockRunE(command commands.Command, service service.AllServices, config *config.Config, args []string) error {
29+
executor := commands.NewExecutor(config, service, flume.New("test"))
30+
31+
var err error
32+
var out output.Output
33+
switch typedCommand := command.(type) {
34+
case commands.NoArgumentCommand:
35+
out, err = typedCommand.ExecuteWithoutArguments(executor)
36+
case commands.SingleArgumentCommand:
37+
out, err = typedCommand.ExecuteSingleArgument(executor, args[0])
38+
case commands.MultipleArgumentCommand:
39+
out, err = typedCommand.Execute(executor, args[0])
40+
}
41+
42+
_ = output.Render(command.Cobra().OutOrStdout(), config, out)
43+
44+
return err
45+
}

0 commit comments

Comments
 (0)