diff --git a/internal/merkle/builder.go b/internal/merkle/builder.go index e517f0fe0..bca0bf82a 100644 --- a/internal/merkle/builder.go +++ b/internal/merkle/builder.go @@ -20,6 +20,11 @@ var ( overflowMask = new(big.Int).Sub(overflowValue, one) ) +var ( + ErrBadInput = errors.New("merkle: invalid input") + ErrInvariant = errors.New("merkle: internal invariant violated") +) + // MerkleProof: dave/common-rs/merkle/src/tree.rs type Proof struct { Pos *big.Int @@ -54,7 +59,7 @@ func (proof *Proof) BuildRoot() common.Hash { func (proof *Proof) BuildRootChildren() (common.Hash, common.Hash, error) { if len(proof.Siblings) == 0 { zero := common.Hash{} - return zero, zero, errors.New("Siblings array is empty") + return zero, zero, fmt.Errorf("siblings array is empty: %w", ErrBadInput) } two := big.NewInt(2) height := len(proof.Siblings) @@ -120,15 +125,15 @@ func (inner *InnerNode) Valid() bool { return (isPair || isIterated) && !(isPair && isIterated) // xor } -func (inner *InnerNode) Children() (*Tree, *Tree) { +func (inner *InnerNode) Children() (*Tree, *Tree, error) { if !inner.Valid() { - panic(fmt.Sprintf("invalid InnerNode state: %v\n", inner)) + return nil, nil, fmt.Errorf("invalid InnerNode state: %+v: %w", inner, ErrInvariant) } if inner.Child != nil { - return inner.Child, inner.Child + return inner.Child, inner.Child, nil } else { - return inner.LHS, inner.RHS + return inner.LHS, inner.RHS, nil } } @@ -144,33 +149,36 @@ func (tree *Tree) GetRootHash() common.Hash { return tree.RootHash } -func (tree *Tree) FindChildByHash(hash common.Hash) *Tree { +func (tree *Tree) FindChildByHash(hash common.Hash) (*Tree, error) { if tree.RootHash == hash { - return tree + return tree, nil } if inner := tree.Subtrees; inner != nil { - if !inner.Valid() { - panic(fmt.Sprintf("invalid InnerNode state: %v\n", inner)) + lhs, rhs, err := inner.Children() + if err != nil { + return nil, err } - if inner.Child != nil { - child := inner.Child.FindChildByHash(hash) - if child != nil { - return child - } - } else { - lhs := inner.LHS.FindChildByHash(hash) - if lhs != nil { - return lhs - } + child, err := lhs.FindChildByHash(hash) + if err != nil { + return nil, err + } + if child != nil { + return child, nil + } - rhs := inner.RHS.FindChildByHash(hash) - if rhs != nil { - return rhs + // For iterated nodes lhs == rhs, so the right-hand search is redundant. + if lhs != rhs { + child, err = rhs.FindChildByHash(hash) + if err != nil { + return nil, err + } + if child != nil { + return child, nil } } } - return nil // not found + return nil, nil // not found } func (tree *Tree) Join(other *Tree) *Tree { @@ -198,11 +206,11 @@ func (tree *Tree) Iterated(rep uint64) *Tree { return root } -func (tree *Tree) ProveLeaf(index *big.Int) *Proof { +func (tree *Tree) ProveLeaf(index *big.Int) (*Proof, error) { return tree.ProveLeafRec(index) } -func (tree *Tree) ProveLast() *Proof { +func (tree *Tree) ProveLast() (*Proof, error) { // index = (1 << height) - 1 index := new(big.Int).Sub( new(big.Int).Lsh( @@ -214,48 +222,56 @@ func (tree *Tree) ProveLast() *Proof { return tree.ProveLeaf(index) } -func (tree *Tree) ProveLeafRec(index *big.Int) *Proof { +func (tree *Tree) ProveLeafRec(index *big.Int) (*Proof, error) { numLeafs := new(big.Int).Lsh(one, uint(tree.Height)) if numLeafs.Cmp(index) <= 0 { - panic(fmt.Sprintf("index out of bounds: %v, %v", numLeafs, index)) + return nil, fmt.Errorf("index out of bounds: %v, %v: %w", numLeafs, index, ErrBadInput) } subtree := tree.Subtrees if subtree == nil { if index.Cmp(zero) != 0 { - panic(fmt.Sprintf("invalid Tree state: %v", tree)) + return nil, fmt.Errorf("invalid Tree state: %v: %w", tree, ErrInvariant) } if tree.Height != 0 { - panic(fmt.Sprintf("invalid Tree state: %v", tree)) + return nil, fmt.Errorf("invalid Tree state: %v: %w", tree, ErrInvariant) } - return Leaf(tree.RootHash, index) + return Leaf(tree.RootHash, index), nil } shiftAmount := uint(tree.Height - 1) isLeftLeaf := new(big.Int).Rsh(index, shiftAmount).Cmp(zero) == 0 - // innerIndex = index & !(1 << shiftAmount) - innerIndex := new(big.Int).And( + // innerIndex = index & ~(1 << shiftAmount) + innerIndex := new(big.Int).AndNot( index, - new(big.Int).Not( - new(big.Int).Lsh( - one, - shiftAmount, - ), + new(big.Int).Lsh( + one, + shiftAmount, ), ) - lhs, rhs := subtree.Children() + lhs, rhs, err := subtree.Children() + if err != nil { + return nil, err + } + if isLeftLeaf { - proof := lhs.ProveLeafRec(innerIndex) + proof, err := lhs.ProveLeafRec(innerIndex) + if err != nil { + return nil, err + } proof.PushHash(rhs.RootHash) proof.Pos = index - return proof + return proof, nil } else { - proof := rhs.ProveLeafRec(innerIndex) + proof, err := rhs.ProveLeafRec(innerIndex) + if err != nil { + return nil, err + } proof.PushHash(lhs.RootHash) proof.Pos = index - return proof + return proof, nil } } @@ -295,60 +311,64 @@ func (b *Builder) CanBuild() bool { return isPow2(b.Trees[n-1].AccumulatedCount) } -func (b *Builder) Append(leaf *Tree) { - b.AppendRepeated(leaf, big.NewInt(1)) +func (b *Builder) Append(leaf *Tree) error { + return b.AppendRepeated(leaf, big.NewInt(1)) } -func (b *Builder) AppendRepeatedUint64(leaf *Tree, reps uint64) { - b.AppendRepeated(leaf, new(big.Int).SetUint64(reps)) +func (b *Builder) AppendRepeatedUint64(leaf *Tree, reps uint64) error { + return b.AppendRepeated(leaf, new(big.Int).SetUint64(reps)) } -func (b *Builder) AppendRepeated(leaf *Tree, reps *big.Int) { +func (b *Builder) AppendRepeated(leaf *Tree, reps *big.Int) error { + if leaf == nil || reps == nil { + return fmt.Errorf("invalid parameter: %w", ErrBadInput) + } if reps.Cmp(zero) <= 0 { - panic("invalid repetitions") + return fmt.Errorf("invalid repetitions: %v: %w", reps, ErrBadInput) + } + + accumulatedCount, err := b.calculateAccumulatedCount(reps) + if err != nil { + return err } - accumulatedCount := b.CalculateAccumulatedCount(reps) if height, ok := b.Height(); ok { if height != leaf.Height { - panic("mismatched tree size") + return fmt.Errorf("mismatched tree sizes, height: %v and leaf height: %v: %w", height, leaf.Height, ErrBadInput) } } b.Trees = append(b.Trees, Node{ Tree: leaf, AccumulatedCount: accumulatedCount, }) + return nil } -func (b *Builder) Build() *Tree { +func (b *Builder) Build() (*Tree, error) { if count, ok := b.Count(); ok { if !isCountPow2(count) { - panic(fmt.Sprintf("builder has %v leafs, which is not a power of two", count)) + return nil, fmt.Errorf("builder has %v leafs, which is not a power of two: %w", count, ErrBadInput) } log2Size := countTrailingZeroes(count) - return buildMerkle(b.Trees, log2Size, big.NewInt(0)) + return buildMerkle(b.Trees, log2Size, big.NewInt(0)), nil } else { - panic("no leafs in the merkle builder") + return nil, fmt.Errorf("empty merkle builder: %w", ErrBadInput) } } -func (b *Builder) CalculateAccumulatedCount(reps *big.Int) *big.Int { +func (b *Builder) calculateAccumulatedCount(reps *big.Int) (*big.Int, error) { n := len(b.Trees) if n != 0 { - if reps.Cmp(zero) == 0 { - panic("merkle builder is full") - } - accumulatedCount := new(big.Int).And( new(big.Int).Add(reps, b.Trees[n-1].AccumulatedCount), overflowMask, ) if reps.Cmp(accumulatedCount) >= 0 { - panic("merkle tree overflow") + return nil, fmt.Errorf("merkle tree overflow: %w", ErrBadInput) } - return accumulatedCount + return accumulatedCount, nil } else { - return reps + return reps, nil } } diff --git a/internal/merkle/builder_test.go b/internal/merkle/builder_test.go index d533ba2b7..1ef5da419 100644 --- a/internal/merkle/builder_test.go +++ b/internal/merkle/builder_test.go @@ -4,12 +4,14 @@ package merkle import ( + "errors" "math/big" "testing" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var ( @@ -26,68 +28,91 @@ func TestIsCountPow2(t *testing.T) { assert.False(t, isCountPow2(big.NewInt(5))) } -// repanicked -//func TestRepeatZero(t *testing.T) { -// defer recover() -// -// builder := Builder{} -// builder.AppendRepeatedUint64(TreeLeaf(zeroHash), 0) -//} +func TestRepeatZero(t *testing.T) { + zeroHash := common.Hash{} + + builder := Builder{} + assert.ErrorIs(t, builder.AppendRepeatedUint64(TreeLeaf(zeroHash), 0), ErrBadInput) +} + +func TestBuilderErrorSentinels(t *testing.T) { + builder := Builder{} + assert.ErrorIs(t, builder.Append(nil), ErrBadInput) + assert.ErrorIs(t, builder.AppendRepeated(TreeLeaf(zeroDigest), nil), ErrBadInput) + + _, err := builder.Build() + assert.ErrorIs(t, err, ErrBadInput) + + invalidTree := &Tree{RootHash: zeroDigest, Height: 1} + _, err = invalidTree.ProveLeaf(big.NewInt(0)) + assert.ErrorIs(t, err, ErrInvariant) + + proof := Proof{} + _, _, err = proof.BuildRootChildren() + assert.ErrorIs(t, err, ErrBadInput) + + assert.False(t, errors.Is(ErrBadInput, ErrInvariant)) +} func TestSimple0(t *testing.T) { builder := Builder{} - builder.Append(TreeLeaf(oneDigest)) - treeRoot := builder.Build().RootHash + require.NoError(t, builder.Append(TreeLeaf(oneDigest))) + treeRoot, err := builder.Build() + require.NoError(t, err) expected := oneDigest - assert.Equal(t, expected, treeRoot) + assert.Equal(t, expected, treeRoot.RootHash) } func TestSimple1(t *testing.T) { builder := Builder{} - builder.Append(TreeLeaf(zeroDigest)) - builder.Append(TreeLeaf(oneDigest)) - treeRoot := builder.Build().RootHash + require.NoError(t, builder.Append(TreeLeaf(zeroDigest))) + require.NoError(t, builder.Append(TreeLeaf(oneDigest))) + treeRoot, err := builder.Build() + require.NoError(t, err) expected := TreeLeaf(zeroDigest).Join(TreeLeaf(oneDigest)).RootHash - assert.Equal(t, expected, treeRoot) + assert.Equal(t, expected, treeRoot.RootHash) } func TestSimple2(t *testing.T) { builder := Builder{} - builder.AppendRepeatedUint64(TreeLeaf(oneDigest), 2) - builder.AppendRepeatedUint64(TreeLeaf(zeroDigest), 2) - treeRoot := builder.Build().RootHash + require.NoError(t, builder.AppendRepeatedUint64(TreeLeaf(oneDigest), 2)) + require.NoError(t, builder.AppendRepeatedUint64(TreeLeaf(zeroDigest), 2)) + treeRoot, err := builder.Build() + require.NoError(t, err) lhs := TreeLeaf(oneDigest).Join(TreeLeaf(oneDigest)) rhs := TreeLeaf(zeroDigest).Join(TreeLeaf(zeroDigest)) expected := lhs.Join(rhs).RootHash - assert.Equal(t, expected, treeRoot) + assert.Equal(t, expected, treeRoot.RootHash) } func TestSimple3(t *testing.T) { builder := Builder{} - builder.Append(TreeLeaf(zeroDigest)) - builder.AppendRepeatedUint64(TreeLeaf(oneDigest), 2) - builder.Append(TreeLeaf(zeroDigest)) - treeRoot := builder.Build().RootHash + require.NoError(t, builder.Append(TreeLeaf(zeroDigest))) + require.NoError(t, builder.AppendRepeatedUint64(TreeLeaf(oneDigest), 2)) + require.NoError(t, builder.Append(TreeLeaf(zeroDigest))) + treeRoot, err := builder.Build() + require.NoError(t, err) lhs := TreeLeaf(zeroDigest).Join(TreeLeaf(oneDigest)) rhs := TreeLeaf(oneDigest).Join(TreeLeaf(zeroDigest)) expected := lhs.Join(rhs).RootHash - assert.Equal(t, expected, treeRoot) + assert.Equal(t, expected, treeRoot.RootHash) } func TestMerkleBuilder8(t *testing.T) { builder := Builder{} - builder.AppendRepeatedUint64(TreeLeaf(zeroDigest), 2) - builder.AppendRepeatedUint64(TreeLeaf(zeroDigest), 6) + require.NoError(t, builder.AppendRepeatedUint64(TreeLeaf(zeroDigest), 2)) + require.NoError(t, builder.AppendRepeatedUint64(TreeLeaf(zeroDigest), 6)) assert.True(t, builder.CanBuild()) - merkle := builder.Build() + merkle, err := builder.Build() + require.NoError(t, err) assert.Equal(t, merkle.RootHash, TreeLeaf(zeroDigest).Iterated(3).RootHash) } @@ -97,11 +122,12 @@ func TestMerkleBuilder64(t *testing.T) { reps := new(big.Int).Sub(new(big.Int).Lsh(one, 64), two) builder := Builder{} - builder.AppendRepeatedUint64(TreeLeaf(zeroDigest), 2) - builder.AppendRepeated(TreeLeaf(zeroDigest), reps) + require.NoError(t, builder.AppendRepeatedUint64(TreeLeaf(zeroDigest), 2)) + require.NoError(t, builder.AppendRepeated(TreeLeaf(zeroDigest), reps)) assert.True(t, builder.CanBuild()) - merkle := builder.Build() + merkle, err := builder.Build() + require.NoError(t, err) assert.Equal(t, merkle.RootHash, TreeLeaf(zeroDigest).Iterated(64).RootHash) } @@ -110,22 +136,25 @@ func TestMerkleBuilder256(t *testing.T) { reps := new(big.Int).Lsh(one, 256) builder := Builder{} - builder.AppendRepeated(TreeLeaf(zeroDigest), reps) + require.NoError(t, builder.AppendRepeated(TreeLeaf(zeroDigest), reps)) assert.True(t, builder.CanBuild()) - merkle := builder.Build() + merkle, err := builder.Build() + require.NoError(t, err) assert.Equal(t, merkle.RootHash, TreeLeaf(zeroDigest).Iterated(256).RootHash) } func TestAppendAndRepeated(t *testing.T) { builder := Builder{} - builder.Append(TreeLeaf(zeroDigest)) + require.NoError(t, builder.Append(TreeLeaf(zeroDigest))) assert.True(t, builder.CanBuild()) - tree1 := builder.Build() + tree1, err := builder.Build() + require.NoError(t, err) builder = Builder{} - builder.AppendRepeatedUint64(TreeLeaf(zeroDigest), 1) - tree2 := builder.Build() + require.NoError(t, builder.AppendRepeatedUint64(TreeLeaf(zeroDigest), 1)) + tree2, err := builder.Build() + require.NoError(t, err) assert.Equal(t, tree1, tree2) } @@ -141,7 +170,7 @@ func TestBuildRootChildren1(t *testing.T) { rootHash := p.BuildRoot() lhs, rhs, err := p.BuildRootChildren() - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, rootHash, crypto.Keccak256Hash(lhs[:], rhs[:])) } @@ -157,19 +186,25 @@ func TestBuildRootChildren2(t *testing.T) { rootHash := p.BuildRoot() lhs, rhs, err := p.BuildRootChildren() - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, rootHash, crypto.Keccak256Hash(lhs[:], rhs[:])) } func TestBuildRootChildrenAgainstBuilder(t *testing.T) { builder := Builder{} - builder.AppendRepeatedUint64(TreeLeaf(common.HexToHash("0x976dc34e226f0c9803d556f26426aaa82ba7b5f96a5ed094f4f150c3c27aeaf5")), 16777216) - builder.AppendRepeatedUint64(TreeLeaf(common.HexToHash("0xfffeb0e2d6fc065fdcf03c25e23e9730528ca7b890308765b0e6b07586db9c6e")), 16777216) - builder.AppendRepeatedUint64(TreeLeaf(common.HexToHash("0x1588d343bd73f167bf4886b8ab7694b4d83b60087ddbdb445c427c16f26d2644")), 16777216) - builder.AppendRepeatedUint64(TreeLeaf(common.HexToHash("0x1588d343bd73f167bf4886b8ab7694b4d83b60087ddbdb445c427c16f26d2644")), 281474926379008) + hash0 := common.HexToHash("0x976dc34e226f0c9803d556f26426aaa82ba7b5f96a5ed094f4f150c3c27aeaf5") + hash1 := common.HexToHash("0xfffeb0e2d6fc065fdcf03c25e23e9730528ca7b890308765b0e6b07586db9c6e") + hash2 := common.HexToHash("0x1588d343bd73f167bf4886b8ab7694b4d83b60087ddbdb445c427c16f26d2644") + require.NoError(t, builder.AppendRepeatedUint64(TreeLeaf(hash0), 16777216)) + require.NoError(t, builder.AppendRepeatedUint64(TreeLeaf(hash1), 16777216)) + require.NoError(t, builder.AppendRepeatedUint64(TreeLeaf(hash2), 16777216)) + require.NoError(t, builder.AppendRepeatedUint64(TreeLeaf(hash2), 281474926379008)) - builderTree := builder.Build() - proofBuilder := builderTree.ProveLast() + builderTree, err := builder.Build() + require.NoError(t, err) + + proofBuilder, err := builderTree.ProveLast() + require.NoError(t, err) rootHashBuilder := proofBuilder.BuildRoot() lhsBuilder, rhsBuilder, err := proofBuilder.BuildRootChildren() @@ -232,7 +267,7 @@ func TestBuildRootChildrenAgainstBuilder(t *testing.T) { rootHashProof := proofSiblings.BuildRoot() lhsProof, rhsProof, err := proofSiblings.BuildRootChildren() - assert.Nil(t, err) + require.NoError(t, err) assert.Equal(t, rootHashBuilder, rootHashProof) assert.Equal(t, lhsProof, lhsBuilder) assert.Equal(t, rhsProof, rhsBuilder) @@ -245,11 +280,13 @@ func TestFindChildByHash(t *testing.T) { builder := Builder{} leaf1 := TreeLeaf(common.HexToHash("0x1")) leaf2 := TreeLeaf(common.HexToHash("0x2")) - builder.Append(leaf1) - builder.Append(leaf2) - tree := builder.Build() + require.NoError(t, builder.Append(leaf1)) + require.NoError(t, builder.Append(leaf2)) + tree, err := builder.Build() + require.NoError(t, err) - child1 := tree.FindChildByHash(leaf1.RootHash) + child1, err := tree.FindChildByHash(leaf1.RootHash) + require.NoError(t, err) assert.NotNil(t, child1) assert.Equal(t, leaf1.RootHash, child1.RootHash) }) @@ -257,10 +294,12 @@ func TestFindChildByHash(t *testing.T) { t.Run("repetitions", func(t *testing.T) { builder := Builder{} leaf1 := TreeLeaf(common.HexToHash("0x1")) - builder.AppendRepeatedUint64(leaf1, 1024) - tree := builder.Build() + require.NoError(t, builder.AppendRepeatedUint64(leaf1, 1024)) + tree, err := builder.Build() + require.NoError(t, err) - child1 := tree.FindChildByHash(leaf1.RootHash) + child1, err := tree.FindChildByHash(leaf1.RootHash) + require.NoError(t, err) assert.NotNil(t, child1) assert.Equal(t, leaf1.RootHash, child1.RootHash) }) @@ -271,42 +310,45 @@ func TestFindChildByHash(t *testing.T) { leaf2 := TreeLeaf(common.HexToHash("0x2")) leaf3 := TreeLeaf(common.HexToHash("0x3")) leaf4 := TreeLeaf(common.HexToHash("0x4")) - builder.Append(leaf1) - builder.Append(leaf2) - builder.Append(leaf3) - builder.Append(leaf4) - tree := builder.Build() - - child1 := tree.FindChildByHash(leaf1.RootHash) + require.NoError(t, builder.Append(leaf1)) + require.NoError(t, builder.Append(leaf2)) + require.NoError(t, builder.Append(leaf3)) + require.NoError(t, builder.Append(leaf4)) + tree, err := builder.Build() + require.NoError(t, err) + + child1, err := tree.FindChildByHash(leaf1.RootHash) + require.NoError(t, err) assert.NotNil(t, child1) assert.Equal(t, leaf1.RootHash, child1.RootHash) - child2 := tree.FindChildByHash(leaf2.RootHash) + child2, err := tree.FindChildByHash(leaf2.RootHash) + require.NoError(t, err) assert.NotNil(t, child2) assert.Equal(t, child2.RootHash, leaf2.RootHash) }) t.Run("notfound", func(t *testing.T) { builder := Builder{} - builder.Append(TreeLeaf(zeroDigest)) - builder.Append(TreeLeaf(oneDigest)) - tree := builder.Build() + require.NoError(t, builder.Append(TreeLeaf(zeroDigest))) + require.NoError(t, builder.Append(TreeLeaf(oneDigest))) + tree, err := builder.Build() + require.NoError(t, err) missing := common.HexToHash("0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef") - child := tree.FindChildByHash(missing) + child, err := tree.FindChildByHash(missing) + require.NoError(t, err) assert.Nil(t, child) }) } -// repanicked -//func TestBuildNotPow2(t *testing.T) { -// defer recover() -// -// builder := Builder{} -// builder.Append(TreeLeaf(zeroDigest)) -// builder.Append(TreeLeaf(zeroDigest)) -// builder.Append(TreeLeaf(zeroDigest)) -// assert.False(t, builder.CanBuild()) -// -// builder.Build() -//} +func TestBuildNotPow2(t *testing.T) { + builder := Builder{} + require.NoError(t, builder.Append(TreeLeaf(zeroDigest))) + require.NoError(t, builder.Append(TreeLeaf(zeroDigest))) + require.NoError(t, builder.Append(TreeLeaf(zeroDigest))) + assert.False(t, builder.CanBuild()) + + _, err := builder.Build() + assert.Error(t, err) +} diff --git a/internal/validator/validator.go b/internal/validator/validator.go index 217f386b7..0578d9fff 100644 --- a/internal/validator/validator.go +++ b/internal/validator/validator.go @@ -108,10 +108,6 @@ func (s *Service) String() string { return s.Name } -// The maximum height for the Merkle tree of all outputs produced -// by an application -const MAX_OUTPUT_TREE_HEIGHT = merkle.TREE_DEPTH //nolint: revive - type ValidatorRepository interface { ListApplications(ctx context.Context, f repository.ApplicationFilter, p repository.Pagination, descending bool) ([]*Application, uint64, error) UpdateApplicationStatus(ctx context.Context, appID int64, status ApplicationStatus, reason *string) error @@ -146,6 +142,7 @@ func (s *Service) setApplicationCorrupted(ctx context.Context, app *Application, // validateApplication calculates, validates and stores the claim and/or proofs // for each processed epoch of the application. +// Epochs are iterated in ascending virtual-index order as returned by getProcessedEpochs. func (s *Service) validateApplication(ctx context.Context, app *Application) error { s.Logger.Debug("Starting validation", "application", app.Name) appAddress := app.IApplicationAddress.String() @@ -158,6 +155,9 @@ func (s *Service) validateApplication(ctx context.Context, app *Application) err } for _, epoch := range processedEpochs { + if err := ctx.Err(); err != nil { + return err + } if app.ForecloseBlock != 0 && epoch.LastBlock >= app.ForecloseBlock { s.Logger.Info("Skipping foreclosed epoch that cannot be accepted", "application", appAddress, @@ -172,6 +172,16 @@ func (s *Service) validateApplication(ctx context.Context, app *Application) err "epoch_index", epoch.Index, "last_block", epoch.LastBlock, ) + + if epoch.MachineHash == nil { + return s.setApplicationCorrupted(ctx, app, + "epoch %v (%v) has no machine hash", epoch.Index, epoch.VirtualIndex) + } + if epoch.OutputsMerkleRoot == nil { + return s.setApplicationCorrupted(ctx, app, + "epoch %v (%v) has no outputs merkle root", epoch.Index, epoch.VirtualIndex) + } + merkleRoot, outputs, err := s.computeMerkleTreeAndProofs(ctx, app, epoch) if err != nil { // Don't log shutdown-cancellation at ERR — every in-flight DB @@ -209,8 +219,7 @@ func (s *Service) validateApplication(ctx context.Context, app *Application) err ) } - // DaveConsensus can have empty epochs. Authority and Quorum don't. - if !app.IsDaveConsensus() || input != nil { + if input != nil { if input.OutputsHash == nil { return s.setApplicationCorrupted(ctx, app, "inconsistent state: epoch %v last input (%v) outputs merkle root is not defined", @@ -224,12 +233,23 @@ func (s *Service) validateApplication(ctx context.Context, app *Application) err epoch.Index, input.Index, *input.OutputsHash, *epoch.OutputsMerkleRoot) } + if input.MachineHash == nil { + return s.setApplicationCorrupted(ctx, app, + "inconsistent state: epoch %v last input (%v) machine hash is not defined", + epoch.Index, input.Index) + } if *epoch.MachineHash != *input.MachineHash { return s.setApplicationCorrupted(ctx, app, "epoch %v machine hash does not match epoch last input (%v) machine hash. Expected: %v, Got %v", epoch.Index, input.Index, *input.MachineHash, *epoch.MachineHash) } } else { // empty epochs + // DaveConsensus can have empty epochs. Authority and Quorum can't. + if !app.IsDaveConsensus() { + return s.setApplicationCorrupted(ctx, app, + "epoch %v (%v) has no inputs while not DaveConsensus", epoch.Index, epoch.VirtualIndex) + } + if epoch.VirtualIndex > 0 { previousEpoch, err := s.repository.GetEpochByVirtualIndex(ctx, appAddress, epoch.VirtualIndex-1) if err != nil { @@ -238,11 +258,26 @@ func (s *Service) validateApplication(ctx context.Context, app *Application) err epoch.Index, epoch.VirtualIndex, appAddress, err, ) } + + // nil is a valid result from GetEpochByVirtualIndex. + // However it indicates some kind of corruption when combined with the epoch.VirtualIndex > 0 check. + if previousEpoch == nil { + return s.setApplicationCorrupted(ctx, app, + "epoch %v (%v) has no previous epoch", epoch.Index, epoch.VirtualIndex) + } + if previousEpoch.MachineHash == nil { + return s.setApplicationCorrupted(ctx, app, + "previous epoch %v (%v) machine hash is not defined", previousEpoch.Index, previousEpoch.VirtualIndex) + } if *epoch.MachineHash != *previousEpoch.MachineHash { return s.setApplicationCorrupted(ctx, app, "epoch %v machine hash does not match previous epoch %v machine hash. Expected: %v, Got %v", epoch.Index, previousEpoch.Index, *previousEpoch.MachineHash, *epoch.MachineHash) } + if previousEpoch.OutputsMerkleRoot == nil { + return s.setApplicationCorrupted(ctx, app, + "previous epoch %v (%v) outputs merkle root is not defined", previousEpoch.Index, previousEpoch.VirtualIndex) + } if *epoch.OutputsMerkleRoot != *previousEpoch.OutputsMerkleRoot { return s.setApplicationCorrupted(ctx, app, "epoch %v outputs merkle root does not match previous epoch %v one. Expected: %v, Got %v", @@ -305,8 +340,23 @@ func (s *Service) buildCommitment(ctx context.Context, app *Application, epoch * "application", app.Name, "epoch", epoch.Index) + if epoch.InputIndexLowerBound > epoch.InputIndexUpperBound { + return nil, nil, s.setApplicationCorrupted(ctx, app, + "invalid epoch %v (%v): lower bound (%v) > upper bound (%v)", + epoch.Index, epoch.VirtualIndex, epoch.InputIndexLowerBound, epoch.InputIndexUpperBound) + } + if epoch.MachineHash == nil { + return nil, nil, s.setApplicationCorrupted(ctx, app, + "epoch %v (%v) machine hash is not defined", epoch.Index, epoch.VirtualIndex) + } builder := merkle.Builder{} inputCount := epoch.InputIndexUpperBound - epoch.InputIndexLowerBound + if inputCount > pkgm.InputsPerEpoch { + return nil, nil, s.setApplicationCorrupted(ctx, app, + "input count is too large for epoch %v of application %v: max %v, got %v", + epoch.Index, app.Name, pkgm.InputsPerEpoch, inputCount) + } + if inputCount > 0 { statesHashes, total, err := s.repository.ListStateHashes(ctx, app.IApplicationAddress.String(), repository.StateHashFilter{EpochIndex: &epoch.Index}, repository.Pagination{}, false) @@ -315,26 +365,59 @@ func (s *Service) buildCommitment(ctx context.Context, app *Application, epoch * epoch.Index, app.Name, err) } if total < inputCount { - return nil, nil, fmt.Errorf("not enough state hashes for epoch %d of application %s: expected at least %d, got %d", + return nil, nil, s.setApplicationCorrupted(ctx, app, + "not enough state hashes for epoch %d of application %s: expected at least %d, got %d", epoch.Index, app.Name, inputCount, total) } if uint64(len(statesHashes)) != total { - return nil, nil, fmt.Errorf("inconsistent number of state hashes for epoch %d of application %s: expected %d, got %d", epoch.Index, app.Name, total, len(statesHashes)) + return nil, nil, s.setApplicationCorrupted(ctx, app, + "inconsistent number of state hashes for epoch %d of application %s: expected %d, got %d", + epoch.Index, app.Name, total, len(statesHashes)) } for _, stateHash := range statesHashes { - builder.AppendRepeatedUint64(merkle.TreeLeaf(stateHash.MachineHash), stateHash.Repetitions) + if err := builder.AppendRepeatedUint64(merkle.TreeLeaf(stateHash.MachineHash), stateHash.Repetitions); err != nil { + return nil, nil, s.setApplicationCorrupted(ctx, app, + "failed to append state hash to builder for epoch %d of application %s with error: %v", epoch.Index, app.Name, err) + } } } remainingInputs := pkgm.InputsPerEpoch - inputCount + // Safe: inputCount ≤ InputsPerEpoch enforced above, so remainingInputs << Log2StridesPerInput won't overflow. remainingStrides := remainingInputs << pkgm.Log2StridesPerInput if remainingStrides > 0 { - builder.AppendRepeatedUint64(merkle.TreeLeaf(*epoch.MachineHash), remainingStrides) + if err := builder.AppendRepeatedUint64(merkle.TreeLeaf(*epoch.MachineHash), remainingStrides); err != nil { + return nil, nil, s.setApplicationCorrupted(ctx, app, + "failed to append state hash to builder for epoch %d of application %s with error: %v", epoch.Index, app.Name, err) + } + } + + epochCommitmentTree, err := builder.Build() + if err != nil { + return nil, nil, s.setApplicationCorrupted(ctx, app, + "failed to build commitment for epoch %d of application %s with error: %v", epoch.Index, app.Name, err) + } + // The commitment geometry is fixed: 2²⁴ inputs × 2²⁴ strides ⇒ height 48. + const expectedHeight = pkgm.Log2InputSpanToEpoch + pkgm.Log2StridesPerInput // 48 + if uint64(epochCommitmentTree.Height) != expectedHeight { + return nil, nil, s.setApplicationCorrupted(ctx, app, + "epoch %v commitment tree height %v, expected %v — state hash repetitions are inconsistent", + epoch.Index, epochCommitmentTree.Height, expectedHeight) } - epochCommitmentTree := builder.Build() commitment := epochCommitmentTree.GetRootHash() - proof := epochCommitmentTree.ProveLast() + proof, err := epochCommitmentTree.ProveLast() + if err != nil { + return nil, nil, s.setApplicationCorrupted(ctx, app, + "failed to retrieve commitment proof for epoch %d of application %s with error: %v", epoch.Index, app.Name, err) + } + // PRT reconstructs the root children from (epoch.MachineHash, proof). + // The tree's last leaf must therefore be the epoch's final machine hash. + if proof.Node != *epoch.MachineHash { + return nil, nil, s.setApplicationCorrupted(ctx, app, + "epoch %v commitment last leaf %v does not match machine hash %v", + epoch.Index, proof.Node, *epoch.MachineHash) + } s.Logger.Info("DaveConsensus epoch commitment built", "application", app.Name, "epoch", epoch.Index, @@ -385,8 +468,15 @@ func (s *Service) computeMerkleTreeAndProofs( // if there are no outputs if len(epochOutputs) == 0 { + // and there should be a previous epoch (guard against DB corruption) + if epoch.VirtualIndex != 0 && previousEpoch == nil { + return nil, nil, s.setApplicationCorrupted(ctx, app, + "epoch %v (%v): previous epoch with virtual index %v is missing", + epoch.Index, epoch.VirtualIndex, epoch.VirtualIndex-1) + } + // and there is no previous epoch - if previousEpoch == nil { + if epoch.VirtualIndex == 0 { // this is the first epoch, return the pristine claim return &s.pristineRootHash, nil, nil } diff --git a/internal/validator/validator_test.go b/internal/validator/validator_test.go index 13ffd4fd3..d866b584a 100644 --- a/internal/validator/validator_test.go +++ b/internal/validator/validator_test.go @@ -6,11 +6,13 @@ package validator import ( "context" "fmt" + "strings" "testing" "github.com/cartesi/rollups-node/internal/merkle" . "github.com/cartesi/rollups-node/internal/model" "github.com/cartesi/rollups-node/internal/repository" + pkgm "github.com/cartesi/rollups-node/pkg/machine" "github.com/cartesi/rollups-node/pkg/service" "github.com/ethereum/go-ethereum/common" "github.com/stretchr/testify/mock" @@ -32,6 +34,15 @@ var ( dummyOutputs []Output ) +func expectCorrupted(app *Application, reasonSubstring string) { + repo.On("UpdateApplicationStatus", + mock.Anything, app.ID, ApplicationStatus_Corrupted, + mock.MatchedBy(func(reason *string) bool { + return reason != nil && strings.Contains(*reason, reasonSubstring) + }), + ).Return(nil).Once() +} + func (s *ValidatorSuite) SetupSubTest() { repo = newMockrepo() postContext := merkle.CreatePostContext() @@ -136,19 +147,22 @@ func (s *ValidatorSuite) TestCreateClaimAndProofSuccess() { } s.Run("FirstEpochNoOutputs", func() { + ctx := context.Background() repo.On("ListOutputs", mock.Anything, mock.Anything, mock.Anything, mock.Anything, false, ).Return([]*Output{}, uint64(0), nil) - claimHash, _, err := validator.computeMerkleTreeAndProofs(nil, &app, &dummyEpochs[0]) + claimHash, _, err := validator.computeMerkleTreeAndProofs(ctx, &app, &dummyEpochs[0]) + s.NoError(err) claimHashRef, _, err := merkle.CreateProofs(nil, merkle.TREE_DEPTH) - s.ErrorIs(nil, err) + s.NoError(err) s.NotNil(claimHash) s.Equal(claimHashRef, *claimHash) repo.AssertExpectations(s.T()) }) s.Run("FirstEpochOneOutput", func() { + ctx := context.Background() output := Output{ RawData: common.Hash{}.Bytes(), } @@ -157,13 +171,14 @@ func (s *ValidatorSuite) TestCreateClaimAndProofSuccess() { mock.Anything, mock.Anything, mock.Anything, mock.Anything, false, ).Return([]*Output{&output}, uint64(1), nil) - claimHash, _, err := validator.computeMerkleTreeAndProofs(nil, &app, &dummyEpochs[0]) - s.ErrorIs(nil, err) + claimHash, _, err := validator.computeMerkleTreeAndProofs(ctx, &app, &dummyEpochs[0]) + s.NoError(err) s.NotNil(claimHash) repo.AssertExpectations(s.T()) }) s.Run("SecondEpochNoOutputs", func() { + ctx := context.Background() repo.On("ListOutputs", mock.Anything, mock.Anything, mock.Anything, mock.Anything, false, ).Return([]*Output{}, uint64(0), nil).Once() @@ -172,13 +187,14 @@ func (s *ValidatorSuite) TestCreateClaimAndProofSuccess() { mock.Anything, mock.Anything, mock.Anything, ).Return(&dummyEpochs[0], nil).Once() - claimHash, _, err := validator.computeMerkleTreeAndProofs(nil, &app, &dummyEpochs[1]) - s.ErrorIs(nil, err) + claimHash, _, err := validator.computeMerkleTreeAndProofs(ctx, &app, &dummyEpochs[1]) + s.NoError(err) s.Equal(dummyEpochs[0].OutputsMerkleRoot, claimHash) repo.AssertExpectations(s.T()) }) s.Run("SecondEpochTwoOutputs", func() { + ctx := context.Background() newOutput0 := Output{ Index: 1, RawData: common.Hash{}.Bytes(), @@ -199,8 +215,8 @@ func (s *ValidatorSuite) TestCreateClaimAndProofSuccess() { mock.Anything, mock.Anything, mock.Anything, ).Return(&dummyOutputs[0], nil).Once() - _, _, err := validator.computeMerkleTreeAndProofs(nil, &app, &dummyEpochs[1]) - s.ErrorIs(nil, err) + _, _, err := validator.computeMerkleTreeAndProofs(ctx, &app, &dummyEpochs[1]) + s.NoError(err) repo.AssertExpectations(s.T()) }) } @@ -250,9 +266,7 @@ func (s *ValidatorSuite) TestCreateClaimAndProofFailures() { mock.Anything, mock.Anything, mock.Anything, ).Return(&invalidEpoch, nil).Once() - repo.On("UpdateApplicationStatus", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(nil).Once() + expectCorrupted(&app, "Previous epoch has no claim") _, _, err := validator.computeMerkleTreeAndProofs(ctx, &app, &dummyEpochs[1]) s.NotNil(err) @@ -292,9 +306,7 @@ func (s *ValidatorSuite) TestCreateClaimAndProofFailures() { mock.Anything, mock.Anything, mock.Anything, ).Return(&Output{}, nil).Once() - repo.On("UpdateApplicationStatus", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(nil).Once() + expectCorrupted(&app, "Output (0) preceding epoch 1 is missing or has invalid hash siblings") _, _, err := validator.computeMerkleTreeAndProofs(ctx, &app, &dummyEpochs[1]) s.NotNil(err) @@ -315,9 +327,7 @@ func (s *ValidatorSuite) TestCreateClaimAndProofFailures() { mock.Anything, mock.Anything, mock.Anything, ).Return((*Output)(nil), nil).Once() - repo.On("UpdateApplicationStatus", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(nil).Once() + expectCorrupted(&app, "Output (0) preceding epoch 1 is missing or has invalid hash siblings") _, _, err := validator.computeMerkleTreeAndProofs(ctx, &app, &dummyEpochs[1]) s.NotNil(err) @@ -337,7 +347,7 @@ func (s *ValidatorSuite) TestValidateApplicationSuccess() { ).Return(([]*Epoch)(nil), uint64(0), nil).Once() err := validator.validateApplication(ctx, &app) - s.ErrorIs(nil, err) + s.NoError(err) repo.AssertExpectations(s.T()) }) @@ -365,7 +375,7 @@ func (s *ValidatorSuite) TestValidateApplicationSuccess() { ).Return(nil).Once() err := validator.validateApplication(ctx, &app) - s.ErrorIs(nil, err) + s.NoError(err) repo.AssertExpectations(s.T()) }) @@ -380,7 +390,7 @@ func (s *ValidatorSuite) TestValidateApplicationSuccess() { ).Return([]*Epoch{&unacceptableEpoch}, uint64(1), nil).Once() err := validator.validateApplication(ctx, &foreclosedApp) - s.ErrorIs(nil, err) + s.NoError(err) repo.AssertExpectations(s.T()) }) } @@ -475,11 +485,35 @@ func (s *ValidatorSuite) TestValidateApplicationFailure() { mock.Anything, app.IApplicationAddress.String(), dummyEpochs[0].Index, ).Return(&input, nil).Once() - repo.On("UpdateApplicationStatus", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(nil).Once() + expectCorrupted(&app, "outputs merkle root is not defined") + + err := validator.validateApplication(ctx, &app) + s.NotNil(err) + repo.AssertExpectations(s.T()) + }) + + s.Run("NilInputMachineHash", func() { + input := Input{ + EpochApplicationID: app.ID, + OutputsHash: &validator.pristineRootHash, + MachineHash: nil, // <- trigger nil guard + } + + repo.On("ListEpochs", + mock.Anything, app.IApplicationAddress.String(), mock.Anything, mock.Anything, false, + ).Return([]*Epoch{&dummyEpochs[0]}, uint64(1), nil).Once() - err := validator.validateApplication(nil, &app) + repo.On("ListOutputs", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, false, + ).Return([]*Output{}, uint64(0), nil).Once() + + repo.On("GetLastInput", + mock.Anything, app.IApplicationAddress.String(), dummyEpochs[0].Index, + ).Return(&input, nil).Once() + + expectCorrupted(&app, "machine hash is not defined") + + err := validator.validateApplication(ctx, &app) s.NotNil(err) repo.AssertExpectations(s.T()) }) @@ -503,9 +537,7 @@ func (s *ValidatorSuite) TestValidateApplicationFailure() { mock.Anything, app.IApplicationAddress.String(), dummyEpochs[0].Index, ).Return(&input, nil).Once() - repo.On("UpdateApplicationStatus", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(nil).Once() + expectCorrupted(&app, "computed outputs merkle root does not match") err := validator.validateApplication(ctx, &app) s.NotNil(err) @@ -539,6 +571,489 @@ func (s *ValidatorSuite) TestValidateApplicationFailure() { s.ErrorIs(err, xerror) repo.AssertExpectations(s.T()) }) + + // --- nil pointer access guards added in fix/validator-hardening --- + + s.Run("NilEpochMachineHash", func() { + epoch := Epoch{ + Index: 0, + VirtualIndex: 0, + FirstBlock: 0, + LastBlock: 9, + OutputsMerkleRoot: &validator.pristineRootHash, + MachineHash: nil, // <- nil triggers the new guard + } + repo.On("ListEpochs", + mock.Anything, app.IApplicationAddress.String(), mock.Anything, mock.Anything, false, + ).Return([]*Epoch{&epoch}, uint64(1), nil).Once() + + expectCorrupted(&app, "has no machine hash") + + err := validator.validateApplication(ctx, &app) + s.NotNil(err) + repo.AssertExpectations(s.T()) + }) + + s.Run("NilEpochOutputsMerkleRoot", func() { + epoch := Epoch{ + Index: 0, + VirtualIndex: 0, + FirstBlock: 0, + LastBlock: 9, + MachineHash: &validator.pristineRootHash, + OutputsMerkleRoot: nil, // <- nil triggers the new guard + } + repo.On("ListEpochs", + mock.Anything, app.IApplicationAddress.String(), mock.Anything, mock.Anything, false, + ).Return([]*Epoch{&epoch}, uint64(1), nil).Once() + + expectCorrupted(&app, "has no outputs merkle root") + + err := validator.validateApplication(ctx, &app) + s.NotNil(err) + repo.AssertExpectations(s.T()) + }) + + s.Run("EmptyEpochInNonDaveConsensusApp", func() { + app := Application{ + Name: "dummy-application-name", + ConsensusType: Consensus_Authority, // non-DaveConsensus + } + epoch := Epoch{ + Index: 0, + VirtualIndex: 0, + FirstBlock: 0, + LastBlock: 9, + MachineHash: &validator.pristineRootHash, + OutputsMerkleRoot: dummyEpochs[0].OutputsMerkleRoot, + } + repo.On("ListEpochs", + mock.Anything, app.IApplicationAddress.String(), mock.Anything, mock.Anything, false, + ).Return([]*Epoch{&epoch}, uint64(1), nil).Once() + + // computeMerkleTreeAndProofs: no outputs, no previous epoch -> pristine claim + repo.On("ListOutputs", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, false, + ).Return([]*Output{}, uint64(0), nil).Once() + + // empty epoch (GetLastInput returns nil) + repo.On("GetLastInput", + mock.Anything, app.IApplicationAddress.String(), epoch.Index, + ).Return((*Input)(nil), nil).Once() + + expectCorrupted(&app, "has no inputs while not DaveConsensus") + + err := validator.validateApplication(ctx, &app) + s.NotNil(err) + repo.AssertExpectations(s.T()) + }) + + s.Run("EmptyEpochPreviousEpochIsNil", func() { + app := Application{ + Name: "dummy-application-name", + ConsensusType: Consensus_PRT, // DaveConsensus so !IsDaveConsensus guard passes + } + epoch := Epoch{ + Index: 1, + VirtualIndex: 1, + FirstBlock: 10, + LastBlock: 19, + MachineHash: &validator.pristineRootHash, + OutputsMerkleRoot: dummyEpochs[0].OutputsMerkleRoot, + } + repo.On("ListEpochs", + mock.Anything, app.IApplicationAddress.String(), mock.Anything, mock.Anything, false, + ).Return([]*Epoch{&epoch}, uint64(1), nil).Once() + + // computeMerkleTreeAndProofs: no outputs -> looks up previous epoch + repo.On("ListOutputs", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, false, + ).Return([]*Output{}, uint64(0), nil).Once() + + // 1st GetEpochByVirtualIndex (inside computeMerkleTreeAndProofs) succeeds + repo.On("GetEpochByVirtualIndex", + mock.Anything, app.IApplicationAddress.String(), uint64(0), + ).Return(&dummyEpochs[0], nil).Once() + + // empty epoch (GetLastInput returns nil) + repo.On("GetLastInput", + mock.Anything, app.IApplicationAddress.String(), epoch.Index, + ).Return((*Input)(nil), nil).Once() + + // 2nd GetEpochByVirtualIndex (inside validateApplication else branch) returns nil + repo.On("GetEpochByVirtualIndex", + mock.Anything, app.IApplicationAddress.String(), uint64(0), + ).Return((*Epoch)(nil), nil).Once() + + expectCorrupted(&app, "has no previous epoch") + + err := validator.validateApplication(ctx, &app) + s.NotNil(err) + repo.AssertExpectations(s.T()) + }) + + s.Run("EmptyEpochPreviousEpochMachineHashNil", func() { + app := Application{ + Name: "dummy-application-name", + ConsensusType: Consensus_PRT, + } + epoch := Epoch{ + Index: 1, + VirtualIndex: 1, + FirstBlock: 10, + LastBlock: 19, + MachineHash: &validator.pristineRootHash, + OutputsMerkleRoot: dummyEpochs[0].OutputsMerkleRoot, + } + repo.On("ListEpochs", + mock.Anything, app.IApplicationAddress.String(), mock.Anything, mock.Anything, false, + ).Return([]*Epoch{&epoch}, uint64(1), nil).Once() + + repo.On("ListOutputs", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, false, + ).Return([]*Output{}, uint64(0), nil).Once() + + // 1st GetEpochByVirtualIndex (inside computeMerkleTreeAndProofs) succeeds + repo.On("GetEpochByVirtualIndex", + mock.Anything, app.IApplicationAddress.String(), uint64(0), + ).Return(&dummyEpochs[0], nil).Once() + + repo.On("GetLastInput", + mock.Anything, app.IApplicationAddress.String(), epoch.Index, + ).Return((*Input)(nil), nil).Once() + + // 2nd GetEpochByVirtualIndex: returns epoch with nil MachineHash + prev := Epoch{ + Index: 0, + VirtualIndex: 0, + MachineHash: nil, // <- nil triggers the guard + OutputsMerkleRoot: &validator.pristineRootHash, + } + repo.On("GetEpochByVirtualIndex", + mock.Anything, app.IApplicationAddress.String(), uint64(0), + ).Return(&prev, nil).Once() + + expectCorrupted(&app, "machine hash is not defined") + + err := validator.validateApplication(ctx, &app) + s.NotNil(err) + repo.AssertExpectations(s.T()) + }) + + s.Run("EmptyEpochPreviousEpochOutputsMerkleRootNil", func() { + app := Application{ + Name: "dummy-application-name", + ConsensusType: Consensus_PRT, + } + epoch := Epoch{ + Index: 1, + VirtualIndex: 1, + FirstBlock: 10, + LastBlock: 19, + MachineHash: &validator.pristineRootHash, + OutputsMerkleRoot: dummyEpochs[0].OutputsMerkleRoot, + } + repo.On("ListEpochs", + mock.Anything, app.IApplicationAddress.String(), mock.Anything, mock.Anything, false, + ).Return([]*Epoch{&epoch}, uint64(1), nil).Once() + + repo.On("ListOutputs", + mock.Anything, mock.Anything, mock.Anything, mock.Anything, false, + ).Return([]*Output{}, uint64(0), nil).Once() + + // 1st GetEpochByVirtualIndex (inside computeMerkleTreeAndProofs) succeeds + repo.On("GetEpochByVirtualIndex", + mock.Anything, app.IApplicationAddress.String(), uint64(0), + ).Return(&dummyEpochs[0], nil).Once() + + repo.On("GetLastInput", + mock.Anything, app.IApplicationAddress.String(), epoch.Index, + ).Return((*Input)(nil), nil).Once() + + // 2nd GetEpochByVirtualIndex: MachineHash matches but OutputsMerkleRoot is nil + prev := Epoch{ + Index: 0, + VirtualIndex: 0, + MachineHash: &validator.pristineRootHash, // matches epoch.MachineHash + OutputsMerkleRoot: nil, // <- nil triggers the guard + } + repo.On("GetEpochByVirtualIndex", + mock.Anything, app.IApplicationAddress.String(), uint64(0), + ).Return(&prev, nil).Once() + + expectCorrupted(&app, "outputs merkle root is not defined") + + err := validator.validateApplication(ctx, &app) + s.NotNil(err) + repo.AssertExpectations(s.T()) + }) + + s.Run("ContextCancellation", func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + repo.On("ListEpochs", + mock.Anything, app.IApplicationAddress.String(), mock.Anything, mock.Anything, false, + ).Return([]*Epoch{&dummyEpochs[0]}, uint64(1), nil).Once() + + err := validator.validateApplication(ctx, &app) + s.ErrorIs(err, context.Canceled) + repo.AssertExpectations(s.T()) + }) +} + +// cover the integer overflow guards. +func (s *ValidatorSuite) TestBuildCommitment() { + ctx := context.Background() + const testAppName = "test-app" + + s.Run("ValidEpochWithInputs", func() { + app := &Application{ + Name: testAppName, + ConsensusType: Consensus_PRT, + } + epoch := &Epoch{ + Index: 0, + VirtualIndex: 0, + InputIndexLowerBound: 0, + InputIndexUpperBound: 5, + MachineHash: &validator.pristineRootHash, + OutputsMerkleRoot: &validator.pristineRootHash, + } + + // 5 inputs, each with one state hash covering the full + // strides-per-input count (1<