Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(core/reflect): handle missing values in slice with multiple elements #3762

Merged
merged 5 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/core/arg_file_content.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func loadArgsFileContent(cmd *Command, cmdArgs interface{}) error {
}

fieldName := strcase.ToPublicGoName(argSpec.Name)
fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
fieldValues, err := GetValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
if err != nil {
continue
}
Expand Down
29 changes: 18 additions & 11 deletions internal/core/reflect.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package core

import (
"errors"
"fmt"
"reflect"
"sort"
Expand Down Expand Up @@ -34,26 +35,33 @@ func newObjectWithForcedJSONTags(t reflect.Type) interface{} {
return reflect.New(reflect.StructOf(structFieldsCopy)).Interface()
}

// getValuesForFieldByName recursively search for fields in a cmdArgs' value and returns its values if they exist.
// GetValuesForFieldByName recursively search for fields in a cmdArgs' value and returns its values if they exist.
// The search is based on the name of the field.
func getValuesForFieldByName(value reflect.Value, parts []string) (values []reflect.Value, err error) {
func GetValuesForFieldByName(value reflect.Value, parts []string) (values []reflect.Value, err error) {
if len(parts) == 0 {
return []reflect.Value{value}, nil
}

switch value.Kind() {
case reflect.Ptr:
return getValuesForFieldByName(value.Elem(), parts)
return GetValuesForFieldByName(value.Elem(), parts)

case reflect.Slice:
values := []reflect.Value(nil)
errs := []error(nil)

for i := 0; i < value.Len(); i++ {
newValues, err := getValuesForFieldByName(value.Index(i), parts[1:])
newValues, err := GetValuesForFieldByName(value.Index(i), parts[1:])
if err != nil {
return nil, err
errs = append(errs, err)
} else {
values = append(values, newValues...)
}
values = append(values, newValues...)
}

if len(values) == 0 && len(errs) != 0 {
return nil, errors.Join(errs...)
}

return values, nil

case reflect.Map:
Expand All @@ -70,7 +78,7 @@ func getValuesForFieldByName(value reflect.Value, parts []string) (values []refl

for _, mapKey := range mapKeys {
mapValue := value.MapIndex(mapKey)
newValues, err := getValuesForFieldByName(mapValue, parts[1:])
newValues, err := GetValuesForFieldByName(mapValue, parts[1:])
if err != nil {
return nil, err
}
Expand All @@ -93,19 +101,18 @@ func getValuesForFieldByName(value reflect.Value, parts []string) (values []refl

fieldName := strcase.ToPublicGoName(parts[0])
if fieldIndex, exist := fieldIndexByName[fieldName]; exist {
return getValuesForFieldByName(value.Field(fieldIndex), parts[1:])
return GetValuesForFieldByName(value.Field(fieldIndex), parts[1:])
}

// If it does not exist we try to find it in nested anonymous field
for _, fieldIndex := range anonymousFieldIndexes {
newValues, err := getValuesForFieldByName(value.Field(fieldIndex), parts)
newValues, err := GetValuesForFieldByName(value.Field(fieldIndex), parts)
if err == nil {
return newValues, nil
}
}

return nil, fmt.Errorf("field %v does not exist for %v", fieldName, value.Type().Name())
}

return nil, fmt.Errorf("case is not handled")
}
181 changes: 181 additions & 0 deletions internal/core/reflect_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
package core_test

import (
"net"
"reflect"
"strings"
"testing"

"github.com/alecthomas/assert"
"github.com/scaleway/scaleway-cli/v2/internal/core"
"github.com/scaleway/scaleway-sdk-go/scw"
)

type RequestEmbedding struct {
EmbeddingField1 string
EmbeddingField2 int
}

type CreateRequest struct {
*RequestEmbedding
CreateField1 string
CreateField2 int
}

type ExtendedRequest struct {
*CreateRequest
ExtendedField1 string
ExtendedField2 int
}

type ArrowRequest struct {
PrivateNetwork *PrivateNetwork
}

type SpecialRequest struct {
*RequestEmbedding
TabRequest []*ArrowRequest
}

type EndpointSpecPrivateNetwork struct {
PrivateNetworkID string
ServiceIP *scw.IPNet
}

type PrivateNetwork struct {
*EndpointSpecPrivateNetwork
OtherValue string
}

func Test_getValuesForFieldByName(t *testing.T) {
type TestCase struct {
cmdArgs interface{}
fieldName string
expectedError string
expectedValues []reflect.Value
}

expectedServiceIP := &scw.IPNet{
IPNet: net.IPNet{
IP: net.ParseIP("192.0.2.1"),
Mask: net.CIDRMask(24, 32),
},
}

tests := []struct {
name string
testCase TestCase
testFunc func(*testing.T, TestCase)
}{
{
name: "Simple test",
testCase: TestCase{
cmdArgs: &ExtendedRequest{
CreateRequest: &CreateRequest{
RequestEmbedding: &RequestEmbedding{
EmbeddingField1: "value1",
EmbeddingField2: 2,
},
CreateField1: "value3",
CreateField2: 4,
},
ExtendedField1: "value5",
ExtendedField2: 6,
},
fieldName: "EmbeddingField1",
expectedError: "",
expectedValues: []reflect.Value{reflect.ValueOf("value1")},
},
testFunc: func(t *testing.T, tc TestCase) {
values, err := core.GetValuesForFieldByName(reflect.ValueOf(tc.cmdArgs), strings.Split(tc.fieldName, "."))
if err != nil {
assert.Equal(t, tc.expectedError, err.Error())
} else {
if tc.expectedValues != nil && !reflect.DeepEqual(tc.expectedValues[0].Interface(), values[0].Interface()) {
t.Errorf("Expected %v, got %v", tc.expectedValues[0].Interface(), values[0].Interface())
}
}
},
},
{
name: "Error test",
testCase: TestCase{
cmdArgs: &ExtendedRequest{
CreateRequest: &CreateRequest{
RequestEmbedding: &RequestEmbedding{
EmbeddingField1: "value1",
EmbeddingField2: 2,
},
CreateField1: "value3",
CreateField2: 4,
},
ExtendedField1: "value5",
ExtendedField2: 6,
},
fieldName: "NotExist",
expectedError: "field NotExist does not exist for ExtendedRequest",
expectedValues: []reflect.Value{reflect.ValueOf("value1")},
},
testFunc: func(t *testing.T, tc TestCase) {
values, err := core.GetValuesForFieldByName(reflect.ValueOf(tc.cmdArgs), strings.Split(tc.fieldName, "."))
if err != nil {
assert.Equal(t, tc.expectedError, err.Error())
} else {
if tc.expectedValues != nil && !reflect.DeepEqual(tc.expectedValues[0].Interface(), values[0].Interface()) {
t.Errorf("Expected %v, got %v", tc.expectedValues[0].Interface(), values[0].Interface())
}
}
},
},
{

name: "Special test",
testCase: TestCase{
cmdArgs: &SpecialRequest{
RequestEmbedding: &RequestEmbedding{
EmbeddingField1: "value1",
EmbeddingField2: 2,
},
TabRequest: []*ArrowRequest{
{
PrivateNetwork: &PrivateNetwork{
EndpointSpecPrivateNetwork: &EndpointSpecPrivateNetwork{
ServiceIP: &scw.IPNet{
IPNet: net.IPNet{
IP: net.ParseIP("192.0.2.1"),
Mask: net.CIDRMask(24, 32),
},
},
},
},
},
{
PrivateNetwork: &PrivateNetwork{
OtherValue: "hello",
},
},
},
},
fieldName: "tabRequest.{index}.privateNetwork.serviceIP",
expectedError: "",
expectedValues: []reflect.Value{reflect.ValueOf(expectedServiceIP)},
},
testFunc: func(t *testing.T, tc TestCase) {
values, err := core.GetValuesForFieldByName(reflect.ValueOf(tc.cmdArgs), strings.Split(tc.fieldName, "."))
if err != nil {
assert.Equal(t, nil, err.Error())
} else {
if tc.expectedValues != nil && !reflect.DeepEqual(tc.expectedValues[0].Interface(), values[0].Interface()) {
t.Errorf("Expected %v, got %v", tc.expectedValues[0].Interface(), values[0].Interface())
}
}
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.testFunc(t, tt.testCase)
})
}
}
6 changes: 3 additions & 3 deletions internal/core/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func DefaultCommandValidateFunc() CommandValidateFunc {
func validateArgValues(cmd *Command, cmdArgs interface{}) error {
for _, argSpec := range cmd.ArgSpecs {
fieldName := strcase.ToPublicGoName(argSpec.Name)
fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
fieldValues, err := GetValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
if err != nil {
logger.Infof("could not validate arg value for '%v': invalid fieldName: %v: %v", argSpec.Name, fieldName, err.Error())
continue
Expand Down Expand Up @@ -75,7 +75,7 @@ func validateRequiredArgs(cmd *Command, cmdArgs interface{}, rawArgs args.RawArg
}

fieldName := strcase.ToPublicGoName(arg.Name)
fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
fieldValues, err := GetValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
if err != nil {
validationErr := fmt.Errorf("could not validate arg value for '%v': invalid field name '%v': %v", arg.Name, fieldName, err.Error())
if !arg.Required {
Expand Down Expand Up @@ -117,7 +117,7 @@ func validateDeprecated(ctx context.Context, cmd *Command, cmdArgs interface{},
deprecatedArgs := cmd.ArgSpecs.GetDeprecated(true)
for _, arg := range deprecatedArgs {
fieldName := strcase.ToPublicGoName(arg.Name)
fieldValues, err := getValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
fieldValues, err := GetValuesForFieldByName(reflect.ValueOf(cmdArgs), strings.Split(fieldName, "."))
if err != nil {
validationErr := fmt.Errorf("could not validate arg value for '%v': invalid field name '%v': %v", arg.Name, fieldName, err.Error())
if !arg.Required {
Expand Down
Loading