diff --git a/cli/testdata/deploy/neo-go.yml b/cli/testdata/deploy/neo-go.yml index e69de29bb2..c38e0355b6 100644 --- a/cli/testdata/deploy/neo-go.yml +++ b/cli/testdata/deploy/neo-go.yml @@ -0,0 +1 @@ +name: Test deploy diff --git a/pkg/core/blockchain.go b/pkg/core/blockchain.go index 3765fe4ac5..4fd1a58570 100644 --- a/pkg/core/blockchain.go +++ b/pkg/core/blockchain.go @@ -1380,6 +1380,7 @@ var ( ErrTxSmallNetworkFee = errors.New("too small network fee") ErrTxTooBig = errors.New("too big transaction") ErrMemPoolConflict = errors.New("invalid transaction due to conflicts with the memory pool") + ErrInvalidScript = errors.New("invalid script") ErrTxInvalidWitnessNum = errors.New("number of signers doesn't match witnesses") ErrInvalidAttribute = errors.New("invalid attribute") ) @@ -1387,6 +1388,13 @@ var ( // verifyAndPoolTx verifies whether a transaction is bonafide or not and tries // to add it to the mempool given. func (bc *Blockchain) verifyAndPoolTx(t *transaction.Transaction, pool *mempool.Pool, feer mempool.Feer, data ...interface{}) error { + // This code can technically be moved out of here, because it doesn't + // really require a chain lock. + err := vm.IsScriptCorrect(t.Script, nil) + if err != nil { + return fmt.Errorf("%w: %v", ErrInvalidScript, err) + } + height := bc.BlockHeight() isPartialTx := data != nil if t.ValidUntilBlock <= height || !isPartialTx && t.ValidUntilBlock > height+transaction.MaxValidUntilBlockIncrement { @@ -1424,7 +1432,7 @@ func (bc *Blockchain) verifyAndPoolTx(t *transaction.Transaction, pool *mempool. return err } } - err := bc.verifyTxWitnesses(t, nil, isPartialTx) + err = bc.verifyTxWitnesses(t, nil, isPartialTx) if err != nil { return err } @@ -1728,7 +1736,9 @@ var ( ErrWitnessHashMismatch = errors.New("witness hash mismatch") ErrNativeContractWitness = errors.New("native contract witness must have empty verification script") ErrVerificationFailed = errors.New("signature check failed") + ErrInvalidInvocation = errors.New("invalid invocation script") ErrInvalidSignature = fmt.Errorf("%w: invalid signature", ErrVerificationFailed) + ErrInvalidVerification = errors.New("invalid verification script") ErrUnknownVerificationContract = errors.New("unknown verification contract") ErrInvalidVerificationContract = errors.New("verification contract is missing `verify` method") ) @@ -1744,6 +1754,10 @@ func (bc *Blockchain) initVerificationVM(ic *interop.Context, hash util.Uint160, if bc.contracts.ByHash(hash) != nil { return ErrNativeContractWitness } + err := vm.IsScriptCorrect(witness.VerificationScript, nil) + if err != nil { + return fmt.Errorf("%w: %v", ErrInvalidVerification, err) + } v.LoadScriptWithFlags(witness.VerificationScript, callflag.ReadStates) } else { cs, err := ic.GetContract(hash) @@ -1765,6 +1779,10 @@ func (bc *Blockchain) initVerificationVM(ic *interop.Context, hash util.Uint160, } } if len(witness.InvocationScript) != 0 { + err := vm.IsScriptCorrect(witness.InvocationScript, nil) + if err != nil { + return fmt.Errorf("%w: %v", ErrInvalidInvocation, err) + } v.LoadScript(witness.InvocationScript) if isNative { if err := v.StepOut(); err != nil { diff --git a/pkg/core/blockchain_test.go b/pkg/core/blockchain_test.go index 7d72223c1b..bb6d2b4dfb 100644 --- a/pkg/core/blockchain_test.go +++ b/pkg/core/blockchain_test.go @@ -390,6 +390,42 @@ func TestVerifyTx(t *testing.T) { require.Equal(t, expectedNetFee, bc.FeePerByte()*int64(actualSize)+gasConsumed) }) }) + t.Run("InvalidTxScript", func(t *testing.T) { + tx := bc.newTestTx(h, testScript) + tx.Script = append(tx.Script, 0xff) + require.NoError(t, accs[0].SignTx(tx)) + checkErr(t, ErrInvalidScript, tx) + }) + t.Run("InvalidVerificationScript", func(t *testing.T) { + tx := bc.newTestTx(h, testScript) + verif := []byte{byte(opcode.JMP), 3, 0xff, byte(opcode.PUSHT)} + tx.Signers = append(tx.Signers, transaction.Signer{ + Account: hash.Hash160(verif), + Scopes: transaction.Global, + }) + tx.NetworkFee += 1000000 + require.NoError(t, accs[0].SignTx(tx)) + tx.Scripts = append(tx.Scripts, transaction.Witness{ + InvocationScript: []byte{}, + VerificationScript: verif, + }) + checkErr(t, ErrInvalidVerification, tx) + }) + t.Run("InvalidInvocationScript", func(t *testing.T) { + tx := bc.newTestTx(h, testScript) + verif := []byte{byte(opcode.PUSHT)} + tx.Signers = append(tx.Signers, transaction.Signer{ + Account: hash.Hash160(verif), + Scopes: transaction.Global, + }) + tx.NetworkFee += 1000000 + require.NoError(t, accs[0].SignTx(tx)) + tx.Scripts = append(tx.Scripts, transaction.Witness{ + InvocationScript: []byte{byte(opcode.JMP), 3, 0xff}, + VerificationScript: verif, + }) + checkErr(t, ErrInvalidInvocation, tx) + }) t.Run("Conflict", func(t *testing.T) { balance := bc.GetUtilityTokenBalance(h).Int64() tx := bc.newTestTx(h, testScript) @@ -583,7 +619,7 @@ func TestVerifyTx(t *testing.T) { }) t.Run("InvalidScript", func(t *testing.T) { tx := getOracleTx(t) - tx.Script[0] = ^tx.Script[0] + tx.Script = append(tx.Script, byte(opcode.NOP)) require.NoError(t, oracleAcc.SignTx(tx)) checkErr(t, ErrInvalidAttribute, tx) }) diff --git a/pkg/core/interop_system_test.go b/pkg/core/interop_system_test.go index 7bbe96958c..c7e8934671 100644 --- a/pkg/core/interop_system_test.go +++ b/pkg/core/interop_system_test.go @@ -910,7 +910,7 @@ func TestRuntimeCheckWitness(t *testing.T) { Hash: contractScriptHash, NEF: *ne, Manifest: manifest.Manifest{ - Groups: []manifest.Group{{PublicKey: pk.PublicKey()}}, + Groups: []manifest.Group{{PublicKey: pk.PublicKey(), Signature: make([]byte, keys.SignatureLen)}}, }, } require.NoError(t, bc.contracts.Management.PutContractState(ic.DAO, contractState)) diff --git a/pkg/core/native/management.go b/pkg/core/native/management.go index 217168afcc..6aded95129 100644 --- a/pkg/core/native/management.go +++ b/pkg/core/native/management.go @@ -7,6 +7,7 @@ import ( "math" "math/big" "sync" + "unicode/utf8" "github.com/nspcc-dev/neo-go/pkg/core/dao" "github.com/nspcc-dev/neo-go/pkg/core/interop" @@ -21,6 +22,8 @@ import ( "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/nef" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/util/bitfield" + "github.com/nspcc-dev/neo-go/pkg/vm" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -206,6 +209,9 @@ func (m *Management) getNefAndManifestFromItems(ic *interop.Context, args []stac resNef = &nf } if manifestBytes != nil { + if !utf8.Valid(manifestBytes) { + return nil, nil, errors.New("manifest is not UTF-8 compliant") + } resManifest = new(manifest.Manifest) err := json.Unmarshal(manifestBytes, resManifest) if err != nil { @@ -265,8 +271,13 @@ func (m *Management) Deploy(d dao.DAO, sender util.Uint160, neff *nef.File, mani if err != nil { return nil, err } - if !manif.IsValid(h) { - return nil, errors.New("invalid manifest for this contract") + err = manif.IsValid(h) + if err != nil { + return nil, fmt.Errorf("invalid manifest: %w", err) + } + err = checkScriptAndMethods(neff.Script, manif.ABI.Methods) + if err != nil { + return nil, err } newcontract := &state.Contract{ ID: id, @@ -322,12 +333,17 @@ func (m *Management) Update(d dao.DAO, hash util.Uint160, neff *nef.File, manif if manif.Name != contract.Manifest.Name { return nil, errors.New("contract name can't be changed") } - if !manif.IsValid(contract.Hash) { - return nil, errors.New("invalid manifest for this contract") + err = manif.IsValid(contract.Hash) + if err != nil { + return nil, fmt.Errorf("invalid manifest: %w", err) } m.markUpdated(hash) contract.Manifest = *manif } + err = checkScriptAndMethods(contract.NEF.Script, contract.Manifest.ABI.Methods) + if err != nil { + return nil, err + } contract.UpdateCounter++ err = m.PutContractState(d, contract) if err != nil { @@ -545,3 +561,15 @@ func (m *Management) emitNotification(ic *interop.Context, name string, hash uti } ic.Notifications = append(ic.Notifications, ne) } + +func checkScriptAndMethods(script []byte, methods []manifest.Method) error { + l := len(script) + offsets := bitfield.New(l) + for i := range methods { + if methods[i].Offset >= l { + return errors.New("out of bounds method offset") + } + offsets.Set(methods[i].Offset) + } + return vm.IsScriptCorrect(script, offsets) +} diff --git a/pkg/core/native/management_test.go b/pkg/core/native/management_test.go index 61fbd8b95e..7dfa4c0542 100644 --- a/pkg/core/native/management_test.go +++ b/pkg/core/native/management_test.go @@ -8,9 +8,11 @@ import ( "github.com/nspcc-dev/neo-go/pkg/core/interop" "github.com/nspcc-dev/neo-go/pkg/core/state" "github.com/nspcc-dev/neo-go/pkg/core/storage" + "github.com/nspcc-dev/neo-go/pkg/smartcontract" "github.com/nspcc-dev/neo-go/pkg/smartcontract/manifest" "github.com/nspcc-dev/neo-go/pkg/smartcontract/nef" "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/stretchr/testify/require" ) @@ -18,11 +20,16 @@ func TestDeployGetUpdateDestroyContract(t *testing.T) { mgmt := newManagement() d := dao.NewCached(dao.NewSimple(storage.NewMemoryStore(), netmode.UnitTestNet, false)) mgmt.Initialize(&interop.Context{DAO: d}) - script := []byte{1} + script := []byte{byte(opcode.RET)} sender := util.Uint160{1, 2, 3} ne, err := nef.NewFile(script) require.NoError(t, err) manif := manifest.NewManifest("Test") + manif.ABI.Methods = append(manif.ABI.Methods, manifest.Method{ + Name: "dummy", + ReturnType: smartcontract.VoidType, + Parameters: []manifest.Parameter{}, + }) h := state.CreateContractHash(sender, ne.Checksum, manif.Name) diff --git a/pkg/core/native_management_test.go b/pkg/core/native_management_test.go index 0408cac003..b0ba0082aa 100644 --- a/pkg/core/native_management_test.go +++ b/pkg/core/native_management_test.go @@ -1,6 +1,7 @@ package core import ( + "bytes" "encoding/json" "math/big" "testing" @@ -162,6 +163,17 @@ func TestContractDeploy(t *testing.T) { require.NoError(t, err) checkFAULTState(t, res) }) + t.Run("bad script in NEF", func(t *testing.T) { + nf, err := nef.FileFromBytes(nef1b) // make a full copy + require.NoError(t, err) + nf.Script[0] = 0xff + nf.CalculateChecksum() + nefbad, err := nf.Bytes() + require.NoError(t, err) + res, err := invokeContractMethod(bc, 11_00000000, mgmtHash, "deploy", nefbad, manif1) + require.NoError(t, err) + checkFAULTState(t, res) + }) t.Run("int for manifest", func(t *testing.T) { res, err := invokeContractMethod(bc, 11_00000000, mgmtHash, "deploy", nef1b, int64(1)) require.NoError(t, err) @@ -177,6 +189,13 @@ func TestContractDeploy(t *testing.T) { require.NoError(t, err) checkFAULTState(t, res) }) + t.Run("non-utf8 manifest", func(t *testing.T) { + manifB := bytes.Replace(manif1, []byte("TestMain"), []byte("\xff\xfe\xfd"), 1) // Replace name. + + res, err := invokeContractMethod(bc, 11_00000000, mgmtHash, "deploy", nef1b, manifB) + require.NoError(t, err) + checkFAULTState(t, res) + }) t.Run("invalid manifest", func(t *testing.T) { pkey, err := keys.NewPrivateKey() require.NoError(t, err) @@ -190,6 +209,32 @@ func TestContractDeploy(t *testing.T) { require.NoError(t, err) checkFAULTState(t, res) }) + t.Run("bad methods in manifest 1", func(t *testing.T) { + var badManifest = cs1.Manifest + badManifest.ABI.Methods = make([]manifest.Method, len(cs1.Manifest.ABI.Methods)) + copy(badManifest.ABI.Methods, cs1.Manifest.ABI.Methods) + badManifest.ABI.Methods[0].Offset = 100500 // out of bounds + + manifB, err := json.Marshal(badManifest) + require.NoError(t, err) + res, err := invokeContractMethod(bc, 11_00000000, mgmtHash, "deploy", nef1b, manifB) + require.NoError(t, err) + checkFAULTState(t, res) + }) + + t.Run("bad methods in manifest 2", func(t *testing.T) { + var badManifest = cs1.Manifest + badManifest.ABI.Methods = make([]manifest.Method, len(cs1.Manifest.ABI.Methods)) + copy(badManifest.ABI.Methods, cs1.Manifest.ABI.Methods) + badManifest.ABI.Methods[0].Offset = len(cs1.NEF.Script) - 2 // Ends with `CALLT(X,X);RET`. + + manifB, err := json.Marshal(badManifest) + require.NoError(t, err) + res, err := invokeContractMethod(bc, 11_00000000, mgmtHash, "deploy", nef1b, manifB) + require.NoError(t, err) + checkFAULTState(t, res) + }) + t.Run("not enough GAS", func(t *testing.T) { res, err := invokeContractMethod(bc, 1_00000000, mgmtHash, "deploy", nef1b, manif1) require.NoError(t, err) @@ -374,6 +419,19 @@ func TestContractUpdate(t *testing.T) { require.NoError(t, err) checkFAULTState(t, res) }) + t.Run("manifest and script mismatch", func(t *testing.T) { + nf, err := nef.FileFromBytes(nef1b) // Make a full copy. + require.NoError(t, err) + nf.Script = append(nf.Script, byte(opcode.RET)) + copy(nf.Script[1:], nf.Script) // Now all method offsets are wrong. + nf.Script[0] = byte(opcode.RET) // Even though the script is correct. + nf.CalculateChecksum() + nefnew, err := nf.Bytes() + require.NoError(t, err) + res, err := invokeContractMethod(bc, 10_00000000, cs1.Hash, "update", nefnew, manif1) + require.NoError(t, err) + checkFAULTState(t, res) + }) t.Run("change name", func(t *testing.T) { var badManifest = cs1.Manifest diff --git a/pkg/crypto/keys/publickey.go b/pkg/crypto/keys/publickey.go index 7eac8931fb..eb941fd88b 100644 --- a/pkg/crypto/keys/publickey.go +++ b/pkg/crypto/keys/publickey.go @@ -24,6 +24,9 @@ import ( // coordLen is the number of bytes in serialized X or Y coordinate. const coordLen = 32 +// SignatureLen is the length of standard signature for 256-bit EC key. +const SignatureLen = 64 + // PublicKeys is a list of public keys. type PublicKeys []*PublicKey @@ -333,7 +336,7 @@ func (p *PublicKey) Address() string { // Verify returns true if the signature is valid and corresponds // to the hash and public key. func (p *PublicKey) Verify(signature []byte, hash []byte) bool { - if p.X == nil || p.Y == nil || len(signature) != 64 { + if p.X == nil || p.Y == nil || len(signature) != SignatureLen { return false } rBytes := new(big.Int).SetBytes(signature[0:32]) diff --git a/pkg/smartcontract/manifest/abi.go b/pkg/smartcontract/manifest/abi.go new file mode 100644 index 0000000000..8d8f97c760 --- /dev/null +++ b/pkg/smartcontract/manifest/abi.go @@ -0,0 +1,160 @@ +package manifest + +import ( + "errors" + "sort" + + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" +) + +const ( + // MethodInit is a name for default initialization method. + MethodInit = "_initialize" + + // MethodDeploy is a name for default method called during contract deployment. + MethodDeploy = "_deploy" + + // MethodVerify is a name for default verification method. + MethodVerify = "verify" + + // MethodOnNEP17Payment is name of the method which is called when contract receives NEP-17 tokens. + MethodOnNEP17Payment = "onNEP17Payment" + + // MethodOnNEP11Payment is the name of the method which is called when contract receives NEP-11 tokens. + MethodOnNEP11Payment = "onNEP11Payment" +) + +// ABI represents a contract application binary interface. +type ABI struct { + Methods []Method `json:"methods"` + Events []Event `json:"events"` +} + +// GetMethod returns methods with the specified name. +func (a *ABI) GetMethod(name string, paramCount int) *Method { + for i := range a.Methods { + if a.Methods[i].Name == name && (paramCount == -1 || len(a.Methods[i].Parameters) == paramCount) { + return &a.Methods[i] + } + } + return nil +} + +// GetEvent returns event with the specified name. +func (a *ABI) GetEvent(name string) *Event { + for i := range a.Events { + if a.Events[i].Name == name { + return &a.Events[i] + } + } + return nil +} + +// IsValid checks ABI consistency and correctness. +func (a *ABI) IsValid() error { + if len(a.Methods) == 0 { + return errors.New("ABI contains no methods") + } + for i := range a.Methods { + err := a.Methods[i].IsValid() + if err != nil { + return err + } + } + if len(a.Methods) > 1 { + methods := make([]struct { + name string + params int + }, len(a.Methods)) + for i := range methods { + methods[i].name = a.Methods[i].Name + methods[i].params = len(a.Methods[i].Parameters) + } + sort.Slice(methods, func(i, j int) bool { + if methods[i].name < methods[j].name { + return true + } + if methods[i].name == methods[j].name { + return methods[i].params < methods[j].params + } + return false + }) + for i := range methods { + if i == 0 { + continue + } + if methods[i].name == methods[i-1].name && + methods[i].params == methods[i-1].params { + return errors.New("duplicate method specifications") + } + } + } + for i := range a.Events { + err := a.Events[i].IsValid() + if err != nil { + return err + } + } + if len(a.Events) > 1 { + names := make([]string, len(a.Events)) + for i := range a.Events { + names[i] = a.Events[i].Name + } + if stringsHaveDups(names) { + return errors.New("duplicate event names") + } + } + return nil +} + +// ToStackItem converts ABI to stackitem.Item. +func (a *ABI) ToStackItem() stackitem.Item { + methods := make([]stackitem.Item, len(a.Methods)) + for i := range a.Methods { + methods[i] = a.Methods[i].ToStackItem() + } + events := make([]stackitem.Item, len(a.Events)) + for i := range a.Events { + events[i] = a.Events[i].ToStackItem() + } + return stackitem.NewStruct([]stackitem.Item{ + stackitem.Make(methods), + stackitem.Make(events), + }) +} + +// FromStackItem converts stackitem.Item to ABI. +func (a *ABI) FromStackItem(item stackitem.Item) error { + if item.Type() != stackitem.StructT { + return errors.New("invalid ABI stackitem type") + } + str := item.Value().([]stackitem.Item) + if len(str) != 2 { + return errors.New("invalid ABI stackitem length") + } + if str[0].Type() != stackitem.ArrayT { + return errors.New("invalid Methods stackitem type") + } + methods := str[0].Value().([]stackitem.Item) + a.Methods = make([]Method, len(methods)) + for i := range methods { + m := new(Method) + if err := m.FromStackItem(methods[i]); err != nil { + return err + } + a.Methods[i] = *m + } + if str[1].Type() != stackitem.ArrayT { + return errors.New("invalid Events stackitem type") + } + events := str[1].Value().([]stackitem.Item) + a.Events = make([]Event, len(events)) + for i := range events { + e := new(Event) + if err := e.FromStackItem(events[i]); err != nil { + return err + } + a.Events[i] = *e + } + return nil +} diff --git a/pkg/smartcontract/manifest/abi_test.go b/pkg/smartcontract/manifest/abi_test.go new file mode 100644 index 0000000000..2b9a54f0e0 --- /dev/null +++ b/pkg/smartcontract/manifest/abi_test.go @@ -0,0 +1,41 @@ +package manifest + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/stretchr/testify/require" +) + +func TestABIIsValid(t *testing.T) { + a := &ABI{} + require.Error(t, a.IsValid()) // No methods. + + a.Methods = append(a.Methods, Method{Name: "qwe"}) + require.NoError(t, a.IsValid()) + + a.Methods = append(a.Methods, Method{Name: "qaz"}) + require.NoError(t, a.IsValid()) + + a.Methods = append(a.Methods, Method{Name: "qaz", Offset: -42}) + require.Error(t, a.IsValid()) + + a.Methods = append(a.Methods[:len(a.Methods)-1], Method{Name: "qwe", Parameters: []Parameter{NewParameter("param", smartcontract.BoolType)}}) + require.NoError(t, a.IsValid()) + + a.Methods = append(a.Methods, Method{Name: "qwe"}) + require.Error(t, a.IsValid()) + a.Methods = a.Methods[:len(a.Methods)-1] + + a.Events = append(a.Events, Event{Name: "wsx"}) + require.NoError(t, a.IsValid()) + + a.Events = append(a.Events, Event{}) + require.Error(t, a.IsValid()) + + a.Events = append(a.Events[:len(a.Events)-1], Event{Name: "edc"}) + require.NoError(t, a.IsValid()) + + a.Events = append(a.Events, Event{Name: "wsx"}) + require.Error(t, a.IsValid()) +} diff --git a/pkg/smartcontract/manifest/event.go b/pkg/smartcontract/manifest/event.go new file mode 100644 index 0000000000..00a87e988f --- /dev/null +++ b/pkg/smartcontract/manifest/event.go @@ -0,0 +1,62 @@ +package manifest + +import ( + "errors" + + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" +) + +// Event is a description of a single event. +type Event struct { + Name string `json:"name"` + Parameters []Parameter `json:"parameters"` +} + +// IsValid checks Event consistency and correctness. +func (e *Event) IsValid() error { + if e.Name == "" { + return errors.New("empty or absent name") + } + return Parameters(e.Parameters).AreValid() +} + +// ToStackItem converts Event to stackitem.Item. +func (e *Event) ToStackItem() stackitem.Item { + params := make([]stackitem.Item, len(e.Parameters)) + for i := range e.Parameters { + params[i] = e.Parameters[i].ToStackItem() + } + return stackitem.NewStruct([]stackitem.Item{ + stackitem.Make(e.Name), + stackitem.Make(params), + }) +} + +// FromStackItem converts stackitem.Item to Event. +func (e *Event) FromStackItem(item stackitem.Item) error { + var err error + if item.Type() != stackitem.StructT { + return errors.New("invalid Event stackitem type") + } + event := item.Value().([]stackitem.Item) + if len(event) != 2 { + return errors.New("invalid Event stackitem length") + } + e.Name, err = stackitem.ToString(event[0]) + if err != nil { + return err + } + if event[1].Type() != stackitem.ArrayT { + return errors.New("invalid Params stackitem type") + } + params := event[1].Value().([]stackitem.Item) + e.Parameters = make([]Parameter, len(params)) + for i := range params { + p := new(Parameter) + if err := p.FromStackItem(params[i]); err != nil { + return err + } + e.Parameters[i] = *p + } + return nil +} diff --git a/pkg/smartcontract/manifest/event_test.go b/pkg/smartcontract/manifest/event_test.go new file mode 100644 index 0000000000..2e50bdd46a --- /dev/null +++ b/pkg/smartcontract/manifest/event_test.go @@ -0,0 +1,66 @@ +package manifest + +import ( + "math/big" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/stretchr/testify/require" +) + +func TestEventIsValid(t *testing.T) { + e := Event{} + require.Error(t, e.IsValid()) + + e.Name = "some" + require.NoError(t, e.IsValid()) + + e.Parameters = make([]Parameter, 0) + require.NoError(t, e.IsValid()) + + e.Parameters = append(e.Parameters, NewParameter("p1", smartcontract.BoolType)) + require.NoError(t, e.IsValid()) + + e.Parameters = append(e.Parameters, NewParameter("p2", smartcontract.IntegerType)) + require.NoError(t, e.IsValid()) + + e.Parameters = append(e.Parameters, NewParameter("p3", smartcontract.IntegerType)) + require.NoError(t, e.IsValid()) + + e.Parameters = append(e.Parameters, NewParameter("p1", smartcontract.IntegerType)) + require.Error(t, e.IsValid()) +} + +func TestEvent_ToStackItemFromStackItem(t *testing.T) { + m := &Event{ + Name: "mur", + Parameters: []Parameter{{Name: "p1", Type: smartcontract.BoolType}}, + } + expected := stackitem.NewStruct([]stackitem.Item{ + stackitem.NewByteArray([]byte(m.Name)), + stackitem.NewArray([]stackitem.Item{ + stackitem.NewStruct([]stackitem.Item{ + stackitem.NewByteArray([]byte(m.Parameters[0].Name)), + stackitem.NewBigInteger(big.NewInt(int64(m.Parameters[0].Type))), + }), + }), + }) + CheckToFromStackItem(t, m, expected) +} + +func TestEvent_FromStackItemErrors(t *testing.T) { + errCases := map[string]stackitem.Item{ + "not a struct": stackitem.NewArray([]stackitem.Item{}), + "invalid length": stackitem.NewStruct([]stackitem.Item{}), + "invalid name type": stackitem.NewStruct([]stackitem.Item{stackitem.NewInterop(nil), stackitem.Null{}}), + "invalid parameters type": stackitem.NewStruct([]stackitem.Item{stackitem.NewByteArray([]byte{}), stackitem.Null{}}), + "invalid parameter": stackitem.NewStruct([]stackitem.Item{stackitem.NewByteArray([]byte{}), stackitem.NewArray([]stackitem.Item{stackitem.NewStruct([]stackitem.Item{})})}), + } + for name, errCase := range errCases { + t.Run(name, func(t *testing.T) { + p := new(Event) + require.Error(t, p.FromStackItem(errCase)) + }) + } +} diff --git a/pkg/smartcontract/manifest/group.go b/pkg/smartcontract/manifest/group.go new file mode 100644 index 0000000000..920f17d38b --- /dev/null +++ b/pkg/smartcontract/manifest/group.go @@ -0,0 +1,132 @@ +package manifest + +import ( + "crypto/elliptic" + "encoding/hex" + "encoding/json" + "errors" + "sort" + + "github.com/nspcc-dev/neo-go/pkg/crypto/hash" + "github.com/nspcc-dev/neo-go/pkg/crypto/keys" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" +) + +// Group represents a group of smartcontracts identified by a public key. +// Every SC in a group must provide signature of it's hash to prove +// it belongs to a group. +type Group struct { + PublicKey *keys.PublicKey `json:"pubkey"` + Signature []byte `json:"signature"` +} + +// Groups is just an array of Group. +type Groups []Group + +type groupAux struct { + PublicKey string `json:"pubkey"` + Signature []byte `json:"signature"` +} + +// IsValid checks whether group's signature corresponds to the given hash. +func (g *Group) IsValid(h util.Uint160) error { + if !g.PublicKey.Verify(g.Signature, hash.Sha256(h.BytesBE()).BytesBE()) { + return errors.New("incorrect group signature") + } + return nil +} + +// AreValid checks for groups correctness and uniqueness. +func (g Groups) AreValid(h util.Uint160) error { + for i := range g { + err := g[i].IsValid(h) + if err != nil { + return err + } + } + if len(g) < 2 { + return nil + } + pkeys := make(keys.PublicKeys, len(g)) + for i := range g { + pkeys[i] = g[i].PublicKey + } + sort.Sort(pkeys) + for i := range pkeys { + if i == 0 { + continue + } + if pkeys[i].Cmp(pkeys[i-1]) == 0 { + return errors.New("duplicate group keys") + } + } + return nil +} + +// MarshalJSON implements json.Marshaler interface. +func (g *Group) MarshalJSON() ([]byte, error) { + aux := &groupAux{ + PublicKey: hex.EncodeToString(g.PublicKey.Bytes()), + Signature: g.Signature, + } + return json.Marshal(aux) +} + +// UnmarshalJSON implements json.Unmarshaler interface. +func (g *Group) UnmarshalJSON(data []byte) error { + aux := new(groupAux) + if err := json.Unmarshal(data, aux); err != nil { + return err + } + b, err := hex.DecodeString(aux.PublicKey) + if err != nil { + return err + } + pub := new(keys.PublicKey) + if err := pub.DecodeBytes(b); err != nil { + return err + } + g.PublicKey = pub + if len(aux.Signature) != keys.SignatureLen { + return errors.New("wrong signature length") + } + g.Signature = aux.Signature + return nil +} + +// ToStackItem converts Group to stackitem.Item. +func (g *Group) ToStackItem() stackitem.Item { + return stackitem.NewStruct([]stackitem.Item{ + stackitem.NewByteArray(g.PublicKey.Bytes()), + stackitem.NewByteArray(g.Signature), + }) +} + +// FromStackItem converts stackitem.Item to Group. +func (g *Group) FromStackItem(item stackitem.Item) error { + if item.Type() != stackitem.StructT { + return errors.New("invalid Group stackitem type") + } + group := item.Value().([]stackitem.Item) + if len(group) != 2 { + return errors.New("invalid Group stackitem length") + } + pKey, err := group[0].TryBytes() + if err != nil { + return err + } + g.PublicKey, err = keys.NewPublicKeyFromBytes(pKey, elliptic.P256()) + if err != nil { + return err + } + sig, err := group[1].TryBytes() + if err != nil { + return err + } + if len(sig) != keys.SignatureLen { + return errors.New("wrong signature length") + } + g.Signature = sig + return nil +} diff --git a/pkg/smartcontract/manifest/group_test.go b/pkg/smartcontract/manifest/group_test.go new file mode 100644 index 0000000000..464ada1300 --- /dev/null +++ b/pkg/smartcontract/manifest/group_test.go @@ -0,0 +1,43 @@ +package manifest + +import ( + "testing" + + "github.com/nspcc-dev/neo-go/internal/testserdes" + "github.com/nspcc-dev/neo-go/pkg/crypto/keys" + "github.com/nspcc-dev/neo-go/pkg/util" + "github.com/stretchr/testify/require" +) + +func TestGroupJSONInOut(t *testing.T) { + priv, err := keys.NewPrivateKey() + require.NoError(t, err) + pub := priv.PublicKey() + sig := make([]byte, keys.SignatureLen) + g := Group{pub, sig} + testserdes.MarshalUnmarshalJSON(t, &g, new(Group)) +} + +func TestGroupsAreValid(t *testing.T) { + h := util.Uint160{42, 42, 42} + priv, err := keys.NewPrivateKey() + require.NoError(t, err) + priv2, err := keys.NewPrivateKey() + require.NoError(t, err) + pub := priv.PublicKey() + pub2 := priv2.PublicKey() + gcorrect := Group{pub, priv.Sign(h.BytesBE())} + gcorrect2 := Group{pub2, priv2.Sign(h.BytesBE())} + gincorrect := Group{pub, priv.Sign(h.BytesLE())} + gps := Groups{gcorrect} + require.NoError(t, gps.AreValid(h)) + + gps = Groups{gincorrect} + require.Error(t, gps.AreValid(h)) + + gps = Groups{gcorrect, gcorrect2} + require.NoError(t, gps.AreValid(h)) + + gps = Groups{gcorrect, gcorrect} + require.Error(t, gps.AreValid(h)) +} diff --git a/pkg/smartcontract/manifest/manifest.go b/pkg/smartcontract/manifest/manifest.go index a2c41c1fe0..c7bbd0fe78 100644 --- a/pkg/smartcontract/manifest/manifest.go +++ b/pkg/smartcontract/manifest/manifest.go @@ -4,8 +4,8 @@ import ( "encoding/json" "errors" "math" + "sort" - "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -14,33 +14,12 @@ const ( // MaxManifestSize is a max length for a valid contract manifest. MaxManifestSize = math.MaxUint16 - // MethodInit is a name for default initialization method. - MethodInit = "_initialize" - - // MethodDeploy is a name for default method called during contract deployment. - MethodDeploy = "_deploy" - - // MethodVerify is a name for default verification method. - MethodVerify = "verify" - - // MethodOnNEP17Payment is name of the method which is called when contract receives NEP-17 tokens. - MethodOnNEP17Payment = "onNEP17Payment" - - // MethodOnNEP11Payment is the name of the method which is called when contract receives NEP-11 tokens. - MethodOnNEP11Payment = "onNEP11Payment" - // NEP10StandardName represents the name of NEP10 smartcontract standard. NEP10StandardName = "NEP-10" // NEP17StandardName represents the name of NEP17 smartcontract standard. NEP17StandardName = "NEP-17" ) -// ABI represents a contract application binary interface. -type ABI struct { - Methods []Method `json:"methods"` - Events []Event `json:"events"` -} - // Manifest represens contract metadata. type Manifest struct { // Name is a contract's name. @@ -81,26 +60,6 @@ func DefaultManifest(name string) *Manifest { return m } -// GetMethod returns methods with the specified name. -func (a *ABI) GetMethod(name string, paramCount int) *Method { - for i := range a.Methods { - if a.Methods[i].Name == name && (paramCount == -1 || len(a.Methods[i].Parameters) == paramCount) { - return &a.Methods[i] - } - } - return nil -} - -// GetEvent returns event with the specified name. -func (a *ABI) GetEvent(name string) *Event { - for i := range a.Events { - if a.Events[i].Name == name { - return &a.Events[i] - } - } - return nil -} - // CanCall returns true is current contract is allowed to call // method of another contract with specified hash. func (m *Manifest) CanCall(hash util.Uint160, toCall *Manifest, method string) bool { @@ -112,34 +71,51 @@ func (m *Manifest) CanCall(hash util.Uint160, toCall *Manifest, method string) b return false } -// IsValid checks whether the hash given is correct wrt manifest's groups. -func (m *Manifest) IsValid(hash util.Uint160) bool { - for _, g := range m.Groups { - if !g.IsValid(hash) { - return false - } +// IsValid checks manifest internal consistency and correctness, one of the +// checks is for group signature correctness, contract hash is passed for it. +func (m *Manifest) IsValid(hash util.Uint160) error { + var err error + + if m.Name == "" { + return errors.New("no name") } - return true -} -// EncodeBinary implements io.Serializable. -func (m *Manifest) EncodeBinary(w *io.BinWriter) { - data, err := json.Marshal(m) + for i := range m.SupportedStandards { + if m.SupportedStandards[i] == "" { + return errors.New("invalid nameless supported standard") + } + } + if len(m.SupportedStandards) > 1 { + names := make([]string, len(m.SupportedStandards)) + copy(names, m.SupportedStandards) + if stringsHaveDups(names) { + return errors.New("duplicate supported standards") + } + } + err = m.ABI.IsValid() if err != nil { - w.Err = err - return + return err } - w.WriteVarBytes(data) -} - -// DecodeBinary implements io.Serializable. -func (m *Manifest) DecodeBinary(r *io.BinReader) { - data := r.ReadVarBytes(MaxManifestSize) - if r.Err != nil { - return - } else if err := json.Unmarshal(data, m); err != nil { - r.Err = err + err = Groups(m.Groups).AreValid(hash) + if err != nil { + return err } + if len(m.Trusts.Value) > 1 { + hashes := make([]util.Uint160, len(m.Trusts.Value)) + copy(hashes, m.Trusts.Value) + sort.Slice(hashes, func(i, j int) bool { + return hashes[i].Less(hashes[j]) + }) + for i := range hashes { + if i == 0 { + continue + } + if hashes[i] == hashes[i-1] { + return errors.New("duplicate trusted contracts") + } + } + } + return Permissions(m.Permissions).AreValid() } // ToStackItem converts Manifest to stackitem.Item. @@ -267,55 +243,3 @@ func (m *Manifest) FromStackItem(item stackitem.Item) error { } return json.Unmarshal(extra, &m.Extra) } - -// ToStackItem converts ABI to stackitem.Item. -func (a *ABI) ToStackItem() stackitem.Item { - methods := make([]stackitem.Item, len(a.Methods)) - for i := range a.Methods { - methods[i] = a.Methods[i].ToStackItem() - } - events := make([]stackitem.Item, len(a.Events)) - for i := range a.Events { - events[i] = a.Events[i].ToStackItem() - } - return stackitem.NewStruct([]stackitem.Item{ - stackitem.Make(methods), - stackitem.Make(events), - }) -} - -// FromStackItem converts stackitem.Item to ABI. -func (a *ABI) FromStackItem(item stackitem.Item) error { - if item.Type() != stackitem.StructT { - return errors.New("invalid ABI stackitem type") - } - str := item.Value().([]stackitem.Item) - if len(str) != 2 { - return errors.New("invalid ABI stackitem length") - } - if str[0].Type() != stackitem.ArrayT { - return errors.New("invalid Methods stackitem type") - } - methods := str[0].Value().([]stackitem.Item) - a.Methods = make([]Method, len(methods)) - for i := range methods { - m := new(Method) - if err := m.FromStackItem(methods[i]); err != nil { - return err - } - a.Methods[i] = *m - } - if str[1].Type() != stackitem.ArrayT { - return errors.New("invalid Events stackitem type") - } - events := str[1].Value().([]stackitem.Item) - a.Events = make([]Event, len(events)) - for i := range events { - e := new(Event) - if err := e.FromStackItem(events[i]); err != nil { - return err - } - a.Events[i] = *e - } - return nil -} diff --git a/pkg/smartcontract/manifest/manifest_test.go b/pkg/smartcontract/manifest/manifest_test.go index 629bea1d9f..60a9172528 100644 --- a/pkg/smartcontract/manifest/manifest_test.go +++ b/pkg/smartcontract/manifest/manifest_test.go @@ -109,12 +109,94 @@ func TestPermission_IsAllowed(t *testing.T) { func TestIsValid(t *testing.T) { contractHash := util.Uint160{1, 2, 3} - m := NewManifest("Test") + m := &Manifest{} - t.Run("valid, no groups", func(t *testing.T) { - require.True(t, m.IsValid(contractHash)) + t.Run("invalid, no name", func(t *testing.T) { + require.Error(t, m.IsValid(contractHash)) }) + m = NewManifest("Test") + + t.Run("invalid, no ABI methods", func(t *testing.T) { + require.Error(t, m.IsValid(contractHash)) + }) + + m.ABI.Methods = append(m.ABI.Methods, Method{ + Name: "dummy", + ReturnType: smartcontract.VoidType, + Parameters: []Parameter{}, + }) + + t.Run("valid, no groups/events", func(t *testing.T) { + require.NoError(t, m.IsValid(contractHash)) + }) + + m.ABI.Events = append(m.ABI.Events, Event{ + Name: "itHappened", + Parameters: []Parameter{}, + }) + + t.Run("valid, with events", func(t *testing.T) { + require.NoError(t, m.IsValid(contractHash)) + }) + + m.ABI.Events = append(m.ABI.Events, Event{ + Name: "itHappened", + Parameters: []Parameter{ + NewParameter("qwerty", smartcontract.IntegerType), + NewParameter("qwerty", smartcontract.IntegerType), + }, + }) + + t.Run("invalid, bad event", func(t *testing.T) { + require.Error(t, m.IsValid(contractHash)) + }) + m.ABI.Events = m.ABI.Events[:1] + + m.Permissions = append(m.Permissions, *NewPermission(PermissionHash, util.Uint160{1, 2, 3})) + t.Run("valid, with permissions", func(t *testing.T) { + require.NoError(t, m.IsValid(contractHash)) + }) + + m.Permissions = append(m.Permissions, *NewPermission(PermissionHash, util.Uint160{1, 2, 3})) + t.Run("invalid, with permissions", func(t *testing.T) { + require.Error(t, m.IsValid(contractHash)) + }) + m.Permissions = m.Permissions[:1] + + m.SupportedStandards = append(m.SupportedStandards, "NEP-17") + t.Run("valid, with standards", func(t *testing.T) { + require.NoError(t, m.IsValid(contractHash)) + }) + + m.SupportedStandards = append(m.SupportedStandards, "") + t.Run("invalid, with nameless standard", func(t *testing.T) { + require.Error(t, m.IsValid(contractHash)) + }) + m.SupportedStandards = m.SupportedStandards[:1] + + m.SupportedStandards = append(m.SupportedStandards, "NEP-17") + t.Run("invalid, with duplicate standards", func(t *testing.T) { + require.Error(t, m.IsValid(contractHash)) + }) + m.SupportedStandards = m.SupportedStandards[:1] + + m.Trusts.Add(util.Uint160{1, 2, 3}) + t.Run("valid, with trust", func(t *testing.T) { + require.NoError(t, m.IsValid(contractHash)) + }) + + m.Trusts.Add(util.Uint160{3, 2, 1}) + t.Run("valid, with trusts", func(t *testing.T) { + require.NoError(t, m.IsValid(contractHash)) + }) + + m.Trusts.Add(util.Uint160{1, 2, 3}) + t.Run("invalid, with trusts", func(t *testing.T) { + require.Error(t, m.IsValid(contractHash)) + }) + m.Trusts.Restrict() + t.Run("with groups", func(t *testing.T) { m.Groups = make([]Group, 3) pks := make([]*keys.PrivateKey, 3) @@ -129,11 +211,11 @@ func TestIsValid(t *testing.T) { } t.Run("valid", func(t *testing.T) { - require.True(t, m.IsValid(contractHash)) + require.NoError(t, m.IsValid(contractHash)) }) t.Run("invalid, wrong contract hash", func(t *testing.T) { - require.False(t, m.IsValid(util.Uint160{4, 5, 6})) + require.Error(t, m.IsValid(util.Uint160{4, 5, 6})) }) t.Run("invalid, wrong group signature", func(t *testing.T) { @@ -145,7 +227,7 @@ func TestIsValid(t *testing.T) { // of the contract hash. Signature: pk.Sign([]byte{1, 2, 3}), }) - require.False(t, m.IsValid(contractHash)) + require.Error(t, m.IsValid(contractHash)) }) }) } @@ -189,7 +271,7 @@ func TestManifestToStackItem(t *testing.T) { }, Groups: []Group{{ PublicKey: pk.PublicKey(), - Signature: []byte{1, 2, 3}, + Signature: make([]byte, keys.SignatureLen), }}, Permissions: []Permission{*NewPermission(PermissionWildcard)}, SupportedStandards: []string{"NEP-17"}, diff --git a/pkg/smartcontract/manifest/method.go b/pkg/smartcontract/manifest/method.go index b3ef30ab9c..85f512fd0c 100644 --- a/pkg/smartcontract/manifest/method.go +++ b/pkg/smartcontract/manifest/method.go @@ -1,43 +1,12 @@ package manifest import ( - "crypto/elliptic" - "encoding/hex" - "encoding/json" "errors" - "github.com/nspcc-dev/neo-go/pkg/crypto/hash" - "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/smartcontract" - "github.com/nspcc-dev/neo-go/pkg/util" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) -// Parameter represents smartcontract's parameter's definition. -type Parameter struct { - Name string `json:"name"` - Type smartcontract.ParamType `json:"type"` -} - -// Event is a description of a single event. -type Event struct { - Name string `json:"name"` - Parameters []Parameter `json:"parameters"` -} - -// Group represents a group of smartcontracts identified by a public key. -// Every SC in a group must provide signature of it's hash to prove -// it belongs to a group. -type Group struct { - PublicKey *keys.PublicKey `json:"pubkey"` - Signature []byte `json:"signature"` -} - -type groupAux struct { - PublicKey string `json:"pubkey"` - Signature []byte `json:"signature"` -} - // Method represents method's metadata. type Method struct { Name string `json:"name"` @@ -47,78 +16,19 @@ type Method struct { Safe bool `json:"safe"` } -// NewParameter returns new parameter of specified name and type. -func NewParameter(name string, typ smartcontract.ParamType) Parameter { - return Parameter{ - Name: name, - Type: typ, +// IsValid checks Method consistency and correctness. +func (m *Method) IsValid() error { + if m.Name == "" { + return errors.New("empty or absent name") } -} - -// IsValid checks whether group's signature corresponds to the given hash. -func (g *Group) IsValid(h util.Uint160) bool { - return g.PublicKey.Verify(g.Signature, hash.Sha256(h.BytesBE()).BytesBE()) -} - -// MarshalJSON implements json.Marshaler interface. -func (g *Group) MarshalJSON() ([]byte, error) { - aux := &groupAux{ - PublicKey: hex.EncodeToString(g.PublicKey.Bytes()), - Signature: g.Signature, - } - return json.Marshal(aux) -} - -// UnmarshalJSON implements json.Unmarshaler interface. -func (g *Group) UnmarshalJSON(data []byte) error { - aux := new(groupAux) - if err := json.Unmarshal(data, aux); err != nil { - return err + if m.Offset < 0 { + return errors.New("negative offset") } - b, err := hex.DecodeString(aux.PublicKey) + _, err := smartcontract.ConvertToParamType(int(m.ReturnType)) if err != nil { return err } - pub := new(keys.PublicKey) - if err := pub.DecodeBytes(b); err != nil { - return err - } - g.PublicKey = pub - g.Signature = aux.Signature - return nil -} - -// ToStackItem converts Group to stackitem.Item. -func (g *Group) ToStackItem() stackitem.Item { - return stackitem.NewStruct([]stackitem.Item{ - stackitem.NewByteArray(g.PublicKey.Bytes()), - stackitem.NewByteArray(g.Signature), - }) -} - -// FromStackItem converts stackitem.Item to Group. -func (g *Group) FromStackItem(item stackitem.Item) error { - if item.Type() != stackitem.StructT { - return errors.New("invalid Group stackitem type") - } - group := item.Value().([]stackitem.Item) - if len(group) != 2 { - return errors.New("invalid Group stackitem length") - } - pKey, err := group[0].TryBytes() - if err != nil { - return err - } - g.PublicKey, err = keys.NewPublicKeyFromBytes(pKey, elliptic.P256()) - if err != nil { - return err - } - sig, err := group[1].TryBytes() - if err != nil { - return err - } - g.Signature = sig - return nil + return Parameters(m.Parameters).AreValid() } // ToStackItem converts Method to stackitem.Item. @@ -182,77 +92,3 @@ func (m *Method) FromStackItem(item stackitem.Item) error { m.Safe = safe return nil } - -// ToStackItem converts Parameter to stackitem.Item. -func (p *Parameter) ToStackItem() stackitem.Item { - return stackitem.NewStruct([]stackitem.Item{ - stackitem.Make(p.Name), - stackitem.Make(int(p.Type)), - }) -} - -// FromStackItem converts stackitem.Item to Parameter. -func (p *Parameter) FromStackItem(item stackitem.Item) error { - var err error - if item.Type() != stackitem.StructT { - return errors.New("invalid Parameter stackitem type") - } - param := item.Value().([]stackitem.Item) - if len(param) != 2 { - return errors.New("invalid Parameter stackitem length") - } - p.Name, err = stackitem.ToString(param[0]) - if err != nil { - return err - } - typ, err := param[1].TryInteger() - if err != nil { - return err - } - p.Type, err = smartcontract.ConvertToParamType(int(typ.Int64())) - if err != nil { - return err - } - return nil -} - -// ToStackItem converts Event to stackitem.Item. -func (e *Event) ToStackItem() stackitem.Item { - params := make([]stackitem.Item, len(e.Parameters)) - for i := range e.Parameters { - params[i] = e.Parameters[i].ToStackItem() - } - return stackitem.NewStruct([]stackitem.Item{ - stackitem.Make(e.Name), - stackitem.Make(params), - }) -} - -// FromStackItem converts stackitem.Item to Event. -func (e *Event) FromStackItem(item stackitem.Item) error { - var err error - if item.Type() != stackitem.StructT { - return errors.New("invalid Event stackitem type") - } - event := item.Value().([]stackitem.Item) - if len(event) != 2 { - return errors.New("invalid Event stackitem length") - } - e.Name, err = stackitem.ToString(event[0]) - if err != nil { - return err - } - if event[1].Type() != stackitem.ArrayT { - return errors.New("invalid Params stackitem type") - } - params := event[1].Value().([]stackitem.Item) - e.Parameters = make([]Parameter, len(params)) - for i := range params { - p := new(Parameter) - if err := p.FromStackItem(params[i]); err != nil { - return err - } - e.Parameters[i] = *p - } - return nil -} diff --git a/pkg/smartcontract/manifest/method_test.go b/pkg/smartcontract/manifest/method_test.go index 658b2894a7..2f48eb7972 100644 --- a/pkg/smartcontract/manifest/method_test.go +++ b/pkg/smartcontract/manifest/method_test.go @@ -10,6 +10,27 @@ import ( "github.com/stretchr/testify/require" ) +func TestMethodIsValid(t *testing.T) { + m := &Method{} + require.Error(t, m.IsValid()) // No name. + + m.Name = "qwerty" + require.NoError(t, m.IsValid()) + + m.Offset = -100 + require.Error(t, m.IsValid()) + + m.Offset = 100 + m.ReturnType = 0x42 // Invalid type. + require.Error(t, m.IsValid()) + + m.ReturnType = smartcontract.BoolType + require.NoError(t, m.IsValid()) + + m.Parameters = append(m.Parameters, NewParameter("param", smartcontract.BoolType), NewParameter("param", smartcontract.BoolType)) + require.Error(t, m.IsValid()) +} + func TestMethod_ToStackItemFromStackItem(t *testing.T) { m := &Method{ Name: "mur", @@ -52,76 +73,15 @@ func TestMethod_FromStackItemErrors(t *testing.T) { } } -func TestParameter_ToStackItemFromStackItem(t *testing.T) { - p := &Parameter{ - Name: "param", - Type: smartcontract.StringType, - } - expected := stackitem.NewStruct([]stackitem.Item{ - stackitem.NewByteArray([]byte(p.Name)), - stackitem.NewBigInteger(big.NewInt(int64(p.Type))), - }) - CheckToFromStackItem(t, p, expected) -} - -func TestParameter_FromStackItemErrors(t *testing.T) { - errCases := map[string]stackitem.Item{ - "not a struct": stackitem.NewArray([]stackitem.Item{}), - "invalid length": stackitem.NewStruct([]stackitem.Item{}), - "invalid name type": stackitem.NewStruct([]stackitem.Item{stackitem.NewInterop(nil), stackitem.Null{}}), - "invalid type type": stackitem.NewStruct([]stackitem.Item{stackitem.NewByteArray([]byte{}), stackitem.Null{}}), - "invalid type value": stackitem.NewStruct([]stackitem.Item{stackitem.NewByteArray([]byte{}), stackitem.NewBigInteger(big.NewInt(-100500))}), - } - for name, errCase := range errCases { - t.Run(name, func(t *testing.T) { - p := new(Parameter) - require.Error(t, p.FromStackItem(errCase)) - }) - } -} - -func TestEvent_ToStackItemFromStackItem(t *testing.T) { - m := &Event{ - Name: "mur", - Parameters: []Parameter{{Name: "p1", Type: smartcontract.BoolType}}, - } - expected := stackitem.NewStruct([]stackitem.Item{ - stackitem.NewByteArray([]byte(m.Name)), - stackitem.NewArray([]stackitem.Item{ - stackitem.NewStruct([]stackitem.Item{ - stackitem.NewByteArray([]byte(m.Parameters[0].Name)), - stackitem.NewBigInteger(big.NewInt(int64(m.Parameters[0].Type))), - }), - }), - }) - CheckToFromStackItem(t, m, expected) -} - -func TestEvent_FromStackItemErrors(t *testing.T) { - errCases := map[string]stackitem.Item{ - "not a struct": stackitem.NewArray([]stackitem.Item{}), - "invalid length": stackitem.NewStruct([]stackitem.Item{}), - "invalid name type": stackitem.NewStruct([]stackitem.Item{stackitem.NewInterop(nil), stackitem.Null{}}), - "invalid parameters type": stackitem.NewStruct([]stackitem.Item{stackitem.NewByteArray([]byte{}), stackitem.Null{}}), - "invalid parameter": stackitem.NewStruct([]stackitem.Item{stackitem.NewByteArray([]byte{}), stackitem.NewArray([]stackitem.Item{stackitem.NewStruct([]stackitem.Item{})})}), - } - for name, errCase := range errCases { - t.Run(name, func(t *testing.T) { - p := new(Event) - require.Error(t, p.FromStackItem(errCase)) - }) - } -} - func TestGroup_ToStackItemFromStackItem(t *testing.T) { pk, _ := keys.NewPrivateKey() g := &Group{ PublicKey: pk.PublicKey(), - Signature: []byte{1, 2, 3}, + Signature: make([]byte, keys.SignatureLen), } expected := stackitem.NewStruct([]stackitem.Item{ stackitem.NewByteArray(pk.PublicKey().Bytes()), - stackitem.NewByteArray([]byte{1, 2, 3}), + stackitem.NewByteArray(make([]byte, keys.SignatureLen)), }) CheckToFromStackItem(t, g, expected) } @@ -134,6 +94,7 @@ func TestGroup_FromStackItemErrors(t *testing.T) { "invalid pub type": stackitem.NewStruct([]stackitem.Item{stackitem.NewInterop(nil), stackitem.Null{}}), "invalid pub bytes": stackitem.NewStruct([]stackitem.Item{stackitem.NewByteArray([]byte{1}), stackitem.Null{}}), "invalid sig type": stackitem.NewStruct([]stackitem.Item{stackitem.NewByteArray(pk.Bytes()), stackitem.NewInterop(nil)}), + "invalid sig len": stackitem.NewStruct([]stackitem.Item{stackitem.NewByteArray(pk.Bytes()), stackitem.NewByteArray([]byte{1})}), } for name, errCase := range errCases { t.Run(name, func(t *testing.T) { diff --git a/pkg/smartcontract/manifest/parameter.go b/pkg/smartcontract/manifest/parameter.go new file mode 100644 index 0000000000..e739a52f5a --- /dev/null +++ b/pkg/smartcontract/manifest/parameter.go @@ -0,0 +1,106 @@ +package manifest + +import ( + "errors" + "sort" + + "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" +) + +// Parameter represents smartcontract's parameter's definition. +type Parameter struct { + Name string `json:"name"` + Type smartcontract.ParamType `json:"type"` +} + +// Parameters is just an array of Parameter. +type Parameters []Parameter + +// NewParameter returns new parameter of specified name and type. +func NewParameter(name string, typ smartcontract.ParamType) Parameter { + return Parameter{ + Name: name, + Type: typ, + } +} + +// IsValid checks Parameter consistency and correctness. +func (p *Parameter) IsValid() error { + if p.Name == "" { + return errors.New("empty or absent name") + } + if p.Type == smartcontract.VoidType { + return errors.New("void parameter") + } + _, err := smartcontract.ConvertToParamType(int(p.Type)) + return err +} + +// ToStackItem converts Parameter to stackitem.Item. +func (p *Parameter) ToStackItem() stackitem.Item { + return stackitem.NewStruct([]stackitem.Item{ + stackitem.Make(p.Name), + stackitem.Make(int(p.Type)), + }) +} + +// FromStackItem converts stackitem.Item to Parameter. +func (p *Parameter) FromStackItem(item stackitem.Item) error { + var err error + if item.Type() != stackitem.StructT { + return errors.New("invalid Parameter stackitem type") + } + param := item.Value().([]stackitem.Item) + if len(param) != 2 { + return errors.New("invalid Parameter stackitem length") + } + p.Name, err = stackitem.ToString(param[0]) + if err != nil { + return err + } + typ, err := param[1].TryInteger() + if err != nil { + return err + } + p.Type, err = smartcontract.ConvertToParamType(int(typ.Int64())) + if err != nil { + return err + } + return nil +} + +// AreValid checks all parameters for validity and consistency. +func (p Parameters) AreValid() error { + for i := range p { + err := p[i].IsValid() + if err != nil { + return err + } + } + if len(p) < 2 { + return nil + } + names := make([]string, len(p)) + for i := range p { + names[i] = p[i].Name + } + if stringsHaveDups(names) { + return errors.New("duplicate parameter name") + } + return nil +} + +// stringsHaveDups checks given set of strings for duplicates. It modifies the slice given! +func stringsHaveDups(strings []string) bool { + sort.Strings(strings) + for i := range strings { + if i == 0 { + continue + } + if strings[i] == strings[i-1] { + return true + } + } + return false +} diff --git a/pkg/smartcontract/manifest/parameter_test.go b/pkg/smartcontract/manifest/parameter_test.go new file mode 100644 index 0000000000..9826e5f2a0 --- /dev/null +++ b/pkg/smartcontract/manifest/parameter_test.go @@ -0,0 +1,61 @@ +package manifest + +import ( + "math/big" + "testing" + + "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" + "github.com/stretchr/testify/require" +) + +func TestParametersAreValid(t *testing.T) { + ps := Parameters{} + require.NoError(t, ps.AreValid()) // No parameters. + + ps = append(ps, Parameter{}) + require.Error(t, ps.AreValid()) + + ps[0].Name = "qwerty" + require.NoError(t, ps.AreValid()) + + ps[0].Type = 0x42 // Invalid type. + require.Error(t, ps.AreValid()) + + ps[0].Type = smartcontract.VoidType + require.Error(t, ps.AreValid()) + + ps[0].Type = smartcontract.BoolType + require.NoError(t, ps.AreValid()) + + ps = append(ps, Parameter{Name: "qwerty"}) + require.Error(t, ps.AreValid()) +} + +func TestParameter_ToStackItemFromStackItem(t *testing.T) { + p := &Parameter{ + Name: "param", + Type: smartcontract.StringType, + } + expected := stackitem.NewStruct([]stackitem.Item{ + stackitem.NewByteArray([]byte(p.Name)), + stackitem.NewBigInteger(big.NewInt(int64(p.Type))), + }) + CheckToFromStackItem(t, p, expected) +} + +func TestParameter_FromStackItemErrors(t *testing.T) { + errCases := map[string]stackitem.Item{ + "not a struct": stackitem.NewArray([]stackitem.Item{}), + "invalid length": stackitem.NewStruct([]stackitem.Item{}), + "invalid name type": stackitem.NewStruct([]stackitem.Item{stackitem.NewInterop(nil), stackitem.Null{}}), + "invalid type type": stackitem.NewStruct([]stackitem.Item{stackitem.NewByteArray([]byte{}), stackitem.Null{}}), + "invalid type value": stackitem.NewStruct([]stackitem.Item{stackitem.NewByteArray([]byte{}), stackitem.NewBigInteger(big.NewInt(-100500))}), + } + for name, errCase := range errCases { + t.Run(name, func(t *testing.T) { + p := new(Parameter) + require.Error(t, p.FromStackItem(errCase)) + }) + } +} diff --git a/pkg/smartcontract/manifest/permission.go b/pkg/smartcontract/manifest/permission.go index 831229cf65..e1c7cd6b19 100644 --- a/pkg/smartcontract/manifest/permission.go +++ b/pkg/smartcontract/manifest/permission.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "sort" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" "github.com/nspcc-dev/neo-go/pkg/util" @@ -36,6 +37,9 @@ type Permission struct { Methods WildStrings `json:"methods"` } +// Permissions is just an array of Permission. +type Permissions []Permission + type permissionAux struct { Contract PermissionDesc `json:"contract"` Methods WildStrings `json:"methods"` @@ -85,6 +89,82 @@ func (d *PermissionDesc) Group() *keys.PublicKey { return d.Value.(*keys.PublicKey) } +// IsValid checks if Permission is correct. +func (p *Permission) IsValid() error { + for i := range p.Methods.Value { + if p.Methods.Value[i] == "" { + return errors.New("empty method name") + } + } + if len(p.Methods.Value) < 2 { + return nil + } + names := make([]string, len(p.Methods.Value)) + copy(names, p.Methods.Value) + if stringsHaveDups(names) { + return errors.New("duplicate method names") + } + return nil +} + +// AreValid checks each Permission and ensures there are no duplicates. +func (ps Permissions) AreValid() error { + for i := range ps { + err := ps[i].IsValid() + if err != nil { + return err + } + } + if len(ps) < 2 { + return nil + } + contracts := make([]PermissionDesc, 0, len(ps)) + for i := range ps { + contracts = append(contracts, ps[i].Contract) + } + sort.Slice(contracts, func(i, j int) bool { + if contracts[i].Type < contracts[j].Type { + return true + } + if contracts[i].Type != contracts[j].Type { + return false + } + switch contracts[i].Type { + case PermissionHash: + return contracts[i].Hash().Less(contracts[j].Hash()) + case PermissionGroup: + return contracts[i].Group().Cmp(contracts[j].Group()) < 0 + } + return false + }) + for i := range contracts { + if i == 0 { + continue + } + j := i - 1 + if contracts[i].Type != contracts[j].Type { + continue + } + var bad bool + switch contracts[i].Type { + case PermissionWildcard: + bad = true + case PermissionHash: + if contracts[i].Hash() == contracts[j].Hash() { + bad = true + } + case PermissionGroup: + if contracts[i].Group().Cmp(contracts[j].Group()) == 0 { + bad = true + } + } + if bad { + return errors.New("duplicate contracts") + } + } + return nil +} + // IsAllowed checks if method is allowed to be executed. func (p *Permission) IsAllowed(hash util.Uint160, m *Manifest, method string) bool { switch p.Contract.Type { diff --git a/pkg/smartcontract/manifest/permission_test.go b/pkg/smartcontract/manifest/permission_test.go index 97eebb0073..1016bb542f 100644 --- a/pkg/smartcontract/manifest/permission_test.go +++ b/pkg/smartcontract/manifest/permission_test.go @@ -21,6 +21,62 @@ func TestNewPermission(t *testing.T) { require.Panics(t, func() { NewPermission(PermissionGroup, util.Uint160{}) }) } +func TestPermissionIsValid(t *testing.T) { + p := Permission{} + require.NoError(t, p.IsValid()) + + p.Methods.Add("") + require.Error(t, p.IsValid()) + + p.Methods.Value = nil + p.Methods.Add("qwerty") + require.NoError(t, p.IsValid()) + + p.Methods.Add("poiuyt") + require.NoError(t, p.IsValid()) + + p.Methods.Add("qwerty") + require.Error(t, p.IsValid()) +} + +func TestPermissionsAreValid(t *testing.T) { + p := Permissions{} + require.NoError(t, p.AreValid()) + + p = append(p, Permission{Methods: WildStrings{Value: []string{""}}}) + require.Error(t, p.AreValid()) + + p = p[:0] + p = append(p, *NewPermission(PermissionHash, util.Uint160{1, 2, 3})) + require.NoError(t, p.AreValid()) + + priv0, err := keys.NewPrivateKey() + require.NoError(t, err) + priv1, err := keys.NewPrivateKey() + require.NoError(t, err) + + p = append(p, *NewPermission(PermissionGroup, priv0.PublicKey())) + require.NoError(t, p.AreValid()) + + p = append(p, *NewPermission(PermissionGroup, priv1.PublicKey())) + require.NoError(t, p.AreValid()) + + p = append(p, *NewPermission(PermissionWildcard)) + require.NoError(t, p.AreValid()) + + p = append(p, *NewPermission(PermissionHash, util.Uint160{3, 2, 1})) + require.NoError(t, p.AreValid()) + + p = append(p, *NewPermission(PermissionWildcard)) + require.Error(t, p.AreValid()) + + p = append(p[:len(p)-1], *NewPermission(PermissionHash, util.Uint160{1, 2, 3})) + require.Error(t, p.AreValid()) + + p = append(p[:len(p)-1], *NewPermission(PermissionGroup, priv0.PublicKey())) + require.Error(t, p.AreValid()) +} + func TestPermission_MarshalJSON(t *testing.T) { t.Run("wildcard", func(t *testing.T) { expected := NewPermission(PermissionWildcard) diff --git a/pkg/util/bitfield/bitfield.go b/pkg/util/bitfield/bitfield.go new file mode 100644 index 0000000000..8766c8d60f --- /dev/null +++ b/pkg/util/bitfield/bitfield.go @@ -0,0 +1,79 @@ +/* +Package bitfield provides a simple and efficient arbitrary size bit field implementation. +It doesn't attempt to cover everything that could be done with bit fields, +providing only things used by neo-go. +*/ +package bitfield + +// Field is a bit field represented as a slice of uint64 values. +type Field []uint64 + +// Bits and bytes count in a basic element of Field. +const elemBits = 64 +const elemBytes = 8 + +// New creates a new bit field of specified length. Actual field length +// can be rounded to the next multiple of 64, so it's a responsibility +// of the user to deal with that. +func New(n int) Field { + return make(Field, 1+(n-1)/elemBits) +} + +// Set sets one bit at specified offset. No bounds checking is done. +func (f Field) Set(i int) { + addr, offset := (i / elemBits), (i % elemBits) + f[addr] |= (1 << offset) +} + +// IsSet returns true if the bit with specified offset is set. +func (f Field) IsSet(i int) bool { + addr, offset := (i / elemBits), (i % elemBits) + return (f[addr] & (1 << offset)) != 0 +} + +// Copy makes a copy of current Field. +func (f Field) Copy() Field { + fn := make(Field, len(f)) + copy(fn, f) + return fn +} + +// And implements logical AND between f's and m's bits saving the result into f. +func (f Field) And(m Field) { + l := len(m) + for i := range f { + if i >= l { + f[i] = 0 + continue + } + f[i] &= m[i] + } +} + +// Equals compares two Fields and returns true if they're equal. +func (f Field) Equals(o Field) bool { + if len(f) != len(o) { + return false + } + for i := range f { + if f[i] != o[i] { + return false + } + } + return true +} + +// IsSubset returns true when f is a subset of o (only has bits set that are +// set in o). +func (f Field) IsSubset(o Field) bool { + if len(f) > len(o) { + return false + } + for i := range f { + r := f[i] & o[i] + if r != f[i] { + return false + } + } + return true +} diff --git a/pkg/util/bitfield/bitfield_test.go b/pkg/util/bitfield/bitfield_test.go new file mode 100644 index 0000000000..f71bd4e71e --- /dev/null +++ b/pkg/util/bitfield/bitfield_test.go @@ -0,0 +1,46 @@ +package bitfield + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFields(t *testing.T) { + a := New(128) + b := New(128) + a.Set(10) + b.Set(10) + a.Set(42) + b.Set(42) + a.Set(100) + b.Set(100) + require.True(t, a.IsSet(42)) + require.False(t, b.IsSet(43)) + require.True(t, a.IsSubset(b)) + + v := uint64(1<<10 | 1<<42) + require.Equal(t, v, a[0]) + require.Equal(t, v, b[0]) + + require.True(t, a.Equals(b)) + + c := a.Copy() + require.True(t, c.Equals(b)) + + z := New(128) + require.True(t, z.IsSubset(c)) + c.And(a) + require.True(t, c.Equals(b)) + c.And(z) + require.True(t, c.Equals(z)) + + c = New(64) + require.False(t, z.IsSubset(c)) + c[0] = a[0] + require.False(t, c.Equals(a)) + require.True(t, c.IsSubset(a)) + + b.And(c) + require.False(t, b.Equals(a)) +} diff --git a/pkg/vm/context.go b/pkg/vm/context.go index dcbee69abf..af20090cc0 100644 --- a/pkg/vm/context.go +++ b/pkg/vm/context.go @@ -3,6 +3,7 @@ package vm import ( "encoding/binary" "errors" + "fmt" "math/big" "github.com/nspcc-dev/neo-go/pkg/crypto/hash" @@ -109,6 +110,9 @@ func (c *Context) Next() (opcode.Opcode, []byte, error) { var instrbyte = c.prog[c.ip] instr := opcode.Opcode(instrbyte) + if !opcode.IsValid(instr) { + return instr, nil, fmt.Errorf("incorrect opcode %s", instr.String()) + } c.nextip++ var numtoread int diff --git a/pkg/vm/contract_checks.go b/pkg/vm/contract_checks.go index 1e8905663b..00b4250323 100644 --- a/pkg/vm/contract_checks.go +++ b/pkg/vm/contract_checks.go @@ -2,9 +2,12 @@ package vm import ( "encoding/binary" + "errors" + "fmt" "github.com/nspcc-dev/neo-go/pkg/core/interop/interopnames" "github.com/nspcc-dev/neo-go/pkg/encoding/bigint" + "github.com/nspcc-dev/neo-go/pkg/util/bitfield" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" ) @@ -127,3 +130,62 @@ func IsSignatureContract(script []byte) bool { func IsStandardContract(script []byte) bool { return IsSignatureContract(script) || IsMultiSigContract(script) } + +// IsScriptCorrect checks script for errors and mask provided for correctness wrt +// instruction boundaries. Normally it returns nil, but can return some specific +// error if there is any. +func IsScriptCorrect(script []byte, methods bitfield.Field) error { + var ( + l = len(script) + instrs = bitfield.New(l) + jumps = bitfield.New(l) + ) + ctx := NewContext(script) + for ctx.nextip < l { + op, param, err := ctx.Next() + if err != nil { + return err + } + instrs.Set(ctx.ip) + switch op { + case opcode.JMP, opcode.JMPIF, opcode.JMPIFNOT, opcode.JMPEQ, opcode.JMPNE, + opcode.JMPGT, opcode.JMPGE, opcode.JMPLT, opcode.JMPLE, + opcode.CALL, opcode.ENDTRY, opcode.JMPL, opcode.JMPIFL, + opcode.JMPIFNOTL, opcode.JMPEQL, opcode.JMPNEL, + opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLTL, opcode.JMPLEL, + opcode.ENDTRYL, opcode.CALLL, opcode.PUSHA: + off, _, err := calcJumpOffset(ctx, param) // It does bounds checking. + if err != nil { + return err + } + jumps.Set(off) + case opcode.TRY, opcode.TRYL: + catchP, finallyP := getTryParams(op, param) + off, _, err := calcJumpOffset(ctx, catchP) + if err != nil { + return err + } + jumps.Set(off) + off, _, err = calcJumpOffset(ctx, finallyP) + if err != nil { + return err + } + jumps.Set(off) + case opcode.NEWARRAYT, opcode.ISTYPE, opcode.CONVERT: + typ := stackitem.Type(param[0]) + if !typ.IsValid() { + return fmt.Errorf("invalid type specification at offset %d", ctx.ip) + } + if typ == stackitem.AnyT && op != opcode.NEWARRAYT { + return fmt.Errorf("using type ANY is incorrect at offset %d", ctx.ip) + } + } + } + if !jumps.IsSubset(instrs) { + return errors.New("some jumps are done to wrong offsets (not to instruction boundary)") + } + if methods != nil && !methods.IsSubset(instrs) { + return errors.New("some methods point to wrong offsets (not to instruction boundary)") + } + return nil +} diff --git a/pkg/vm/contract_checks_test.go b/pkg/vm/contract_checks_test.go index 4f5f835303..88eb1aeab5 100644 --- a/pkg/vm/contract_checks_test.go +++ b/pkg/vm/contract_checks_test.go @@ -5,8 +5,12 @@ import ( "testing" "github.com/nspcc-dev/neo-go/pkg/crypto/keys" + "github.com/nspcc-dev/neo-go/pkg/io" "github.com/nspcc-dev/neo-go/pkg/smartcontract" + "github.com/nspcc-dev/neo-go/pkg/util/bitfield" + "github.com/nspcc-dev/neo-go/pkg/vm/emit" "github.com/nspcc-dev/neo-go/pkg/vm/opcode" + "github.com/nspcc-dev/neo-go/pkg/vm/stackitem" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -115,3 +119,152 @@ func TestIsMultiSigContract(t *testing.T) { assert.False(t, IsMultiSigContract(prog)) }) } + +func TestIsScriptCorrect(t *testing.T) { + w := io.NewBufBinWriter() + emit.String(w.BinWriter, "something") + + jmpOff := w.Len() + emit.Opcodes(w.BinWriter, opcode.JMP, opcode.Opcode(-jmpOff)) + + retOff := w.Len() + emit.Opcodes(w.BinWriter, opcode.RET) + + jmplOff := w.Len() + emit.Opcodes(w.BinWriter, opcode.JMPL, opcode.Opcode(0xff), opcode.Opcode(0xff), opcode.Opcode(0xff), opcode.Opcode(0xff)) + + tryOff := w.Len() + emit.Opcodes(w.BinWriter, opcode.TRY, opcode.Opcode(3), opcode.Opcode(0xfb)) // -5 + + trylOff := w.Len() + emit.Opcodes(w.BinWriter, opcode.TRYL, opcode.Opcode(0xfd), opcode.Opcode(0xff), opcode.Opcode(0xff), opcode.Opcode(0xff), + opcode.Opcode(9), opcode.Opcode(0), opcode.Opcode(0), opcode.Opcode(0)) + + istypeOff := w.Len() + emit.Opcodes(w.BinWriter, opcode.ISTYPE, opcode.Opcode(stackitem.IntegerT)) + + pushOff := w.Len() + emit.String(w.BinWriter, "else") + + good := w.Bytes() + + getScript := func() []byte { + s := make([]byte, len(good)) + copy(s, good) + return s + } + + t.Run("good", func(t *testing.T) { + require.NoError(t, IsScriptCorrect(good, nil)) + }) + + t.Run("bad instruction", func(t *testing.T) { + bad := getScript() + bad[retOff] = 0xff + require.Error(t, IsScriptCorrect(bad, nil)) + }) + + t.Run("out of bounds JMP 1", func(t *testing.T) { + bad := getScript() + bad[jmpOff+1] = 0x80 // -128 + require.Error(t, IsScriptCorrect(bad, nil)) + }) + + t.Run("out of bounds JMP 2", func(t *testing.T) { + bad := getScript() + bad[jmpOff+1] = 0x7f + require.Error(t, IsScriptCorrect(bad, nil)) + }) + + t.Run("bad JMP offset 1", func(t *testing.T) { + bad := getScript() + bad[jmpOff+1] = 0xff // into "something" + require.Error(t, IsScriptCorrect(bad, nil)) + }) + + t.Run("bad JMP offset 2", func(t *testing.T) { + bad := getScript() + bad[jmpOff+1] = byte(pushOff - jmpOff + 1) + require.Error(t, IsScriptCorrect(bad, nil)) + }) + + t.Run("out of bounds JMPL 1", func(t *testing.T) { + bad := getScript() + bad[jmplOff+1] = byte(-jmplOff - 1) + require.Error(t, IsScriptCorrect(bad, nil)) + }) + + t.Run("out of bounds JMPL 1", func(t *testing.T) { + bad := getScript() + bad[jmplOff+1] = byte(len(bad) - jmplOff) + bad[jmplOff+2] = 0 + bad[jmplOff+3] = 0 + bad[jmplOff+4] = 0 + require.Error(t, IsScriptCorrect(bad, nil)) + }) + + t.Run("bad JMPL offset", func(t *testing.T) { + bad := getScript() + bad[jmplOff+1] = 0xfe // into JMP + require.Error(t, IsScriptCorrect(bad, nil)) + }) + + t.Run("out of bounds TRY 1", func(t *testing.T) { + bad := getScript() + bad[tryOff+1] = byte(-tryOff - 1) + require.Error(t, IsScriptCorrect(bad, nil)) + }) + + t.Run("out of bounds TRY 2", func(t *testing.T) { + bad := getScript() + bad[tryOff+2] = byte(len(bad) - tryOff) + require.Error(t, IsScriptCorrect(bad, nil)) + }) + + t.Run("bad TRYL offset 1", func(t *testing.T) { + bad := getScript() + bad[trylOff+1] = byte(-(trylOff - jmpOff) - 1) // into "something" + require.Error(t, IsScriptCorrect(bad, nil)) + }) + + t.Run("bad TRYL offset 2", func(t *testing.T) { + bad := getScript() + bad[trylOff+5] = byte(len(bad) - trylOff - 1) + require.Error(t, IsScriptCorrect(bad, nil)) + }) + + t.Run("bad ISTYPE type", func(t *testing.T) { + bad := getScript() + bad[istypeOff+1] = byte(0xff) + require.Error(t, IsScriptCorrect(bad, nil)) + }) + + t.Run("bad ISTYPE type (Any)", func(t *testing.T) { + bad := getScript() + bad[istypeOff+1] = byte(stackitem.AnyT) + require.Error(t, IsScriptCorrect(bad, nil)) + }) + + t.Run("good NEWARRAY_T type", func(t *testing.T) { + bad := getScript() + bad[istypeOff] = byte(opcode.NEWARRAYT) + bad[istypeOff+1] = byte(stackitem.AnyT) + require.NoError(t, IsScriptCorrect(bad, nil)) + }) + + t.Run("good methods", func(t *testing.T) { + methods := bitfield.New(len(good)) + methods.Set(retOff) + methods.Set(tryOff) + methods.Set(pushOff) + require.NoError(t, IsScriptCorrect(good, methods)) + }) + + t.Run("bad methods", func(t *testing.T) { + methods := bitfield.New(len(good)) + methods.Set(retOff) + methods.Set(tryOff) + methods.Set(pushOff + 1) + require.Error(t, IsScriptCorrect(good, methods)) + }) +} diff --git a/pkg/vm/opcode/opcode.go b/pkg/vm/opcode/opcode.go index f2fcca417d..25e759a31c 100644 --- a/pkg/vm/opcode/opcode.go +++ b/pkg/vm/opcode/opcode.go @@ -219,3 +219,9 @@ const ( ISTYPE Opcode = 0xD9 CONVERT Opcode = 0xDB ) + +// IsValid returns true if the opcode passed is valid (defined in the VM). +func IsValid(op Opcode) bool { + _, ok := _Opcode_map[op] // We rely on stringer here, it has a map anyway. + return ok +} diff --git a/pkg/vm/opcode/opcode_test.go b/pkg/vm/opcode/opcode_test.go index f0f5a04515..89514532af 100644 --- a/pkg/vm/opcode/opcode_test.go +++ b/pkg/vm/opcode/opcode_test.go @@ -28,3 +28,10 @@ func TestFromString(t *testing.T) { require.NoError(t, err) require.Equal(t, MUL, op) } + +func TestIsValid(t *testing.T) { + require.True(t, IsValid(ADD)) + require.True(t, IsValid(CONVERT)) + require.False(t, IsValid(0xff)) + require.False(t, IsValid(0xa5)) +} diff --git a/pkg/vm/vm.go b/pkg/vm/vm.go index 67f3b1f0c7..3320da5ebe 100644 --- a/pkg/vm/vm.go +++ b/pkg/vm/vm.go @@ -185,11 +185,11 @@ func (v *VM) PrintOps(out io.Writer) { opcode.JMPEQL, opcode.JMPNEL, opcode.JMPGTL, opcode.JMPGEL, opcode.JMPLEL, opcode.JMPLTL, opcode.PUSHA, opcode.ENDTRY, opcode.ENDTRYL: - desc = v.getOffsetDesc(ctx, parameter) + desc = getOffsetDesc(ctx, parameter) case opcode.TRY, opcode.TRYL: catchP, finallyP := getTryParams(instr, parameter) desc = fmt.Sprintf("catch %s, finally %s", - v.getOffsetDesc(ctx, catchP), v.getOffsetDesc(ctx, finallyP)) + getOffsetDesc(ctx, catchP), getOffsetDesc(ctx, finallyP)) case opcode.INITSSLOT: desc = fmt.Sprint(parameter[0]) case opcode.CONVERT, opcode.ISTYPE: @@ -226,8 +226,8 @@ func (v *VM) PrintOps(out io.Writer) { w.Flush() } -func (v *VM) getOffsetDesc(ctx *Context, parameter []byte) string { - offset, rOffset, err := v.calcJumpOffset(ctx, parameter) +func getOffsetDesc(ctx *Context, parameter []byte) string { + offset, rOffset, err := calcJumpOffset(ctx, parameter) if err != nil { return fmt.Sprintf("ERROR: %v", err) } @@ -552,7 +552,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro v.estack.PushVal(parameter) case opcode.PUSHA: - n := v.getJumpOffset(ctx, parameter) + n := getJumpOffset(ctx, parameter) ptr := stackitem.NewPointerWithHash(n, ctx.prog, ctx.ScriptHash()) v.estack.PushVal(ptr) @@ -1249,7 +1249,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro opcode.JMPEQ, opcode.JMPEQL, opcode.JMPNE, opcode.JMPNEL, opcode.JMPGT, opcode.JMPGTL, opcode.JMPGE, opcode.JMPGEL, opcode.JMPLT, opcode.JMPLTL, opcode.JMPLE, opcode.JMPLEL: - offset := v.getJumpOffset(ctx, parameter) + offset := getJumpOffset(ctx, parameter) cond := true switch op { case opcode.JMP, opcode.JMPL: @@ -1268,7 +1268,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro case opcode.CALL, opcode.CALLL: // Note: jump offset must be calculated regarding to new context, // but it is cloned and thus has the same script and instruction pointer. - v.call(ctx, v.getJumpOffset(ctx, parameter)) + v.call(ctx, getJumpOffset(ctx, parameter)) case opcode.CALLA: ptr := v.estack.Pop().Item().(*stackitem.Pointer) @@ -1406,8 +1406,8 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro if ctx.tryStack.Len() >= MaxTryNestingDepth { panic("maximum TRY depth exceeded") } - cOffset := v.getJumpOffset(ctx, catchP) - fOffset := v.getJumpOffset(ctx, finallyP) + cOffset := getJumpOffset(ctx, catchP) + fOffset := getJumpOffset(ctx, finallyP) if cOffset == ctx.ip && fOffset == ctx.ip { panic("invalid offset for TRY*") } else if cOffset == ctx.ip { @@ -1423,7 +1423,7 @@ func (v *VM) execute(ctx *Context, op opcode.Opcode, parameter []byte) (err erro if eCtx.State == eFinally { panic("invalid exception handling state during ENDTRY*") } - eOffset := v.getJumpOffset(ctx, parameter) + eOffset := getJumpOffset(ctx, parameter) if eCtx.HasFinally() { eCtx.State = eFinally eCtx.EndOffset = eOffset @@ -1527,8 +1527,8 @@ func (v *VM) call(ctx *Context, offset int) { // to a which JMP should be performed. // parameter should have length either 1 or 4 and // is interpreted as little-endian. -func (v *VM) getJumpOffset(ctx *Context, parameter []byte) int { - offset, _, err := v.calcJumpOffset(ctx, parameter) +func getJumpOffset(ctx *Context, parameter []byte) int { + offset, _, err := calcJumpOffset(ctx, parameter) if err != nil { panic(err) } @@ -1537,7 +1537,7 @@ func (v *VM) getJumpOffset(ctx *Context, parameter []byte) int { // calcJumpOffset returns absolute and relative offset of JMP/CALL/TRY instructions // either in short (1-byte) or long (4-byte) form. -func (v *VM) calcJumpOffset(ctx *Context, parameter []byte) (int, int, error) { +func calcJumpOffset(ctx *Context, parameter []byte) (int, int, error) { var rOffset int32 switch l := len(parameter); l { case 1: