Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 82 additions & 62 deletions internal/merkle/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ var (
overflowMask = new(big.Int).Sub(overflowValue, one)
)

var (
ErrBadInput = errors.New("merkle: invalid input")
ErrInvariant = errors.New("merkle: internal invariant violated")
)

// MerkleProof: dave/common-rs/merkle/src/tree.rs
type Proof struct {
Pos *big.Int
Expand Down Expand Up @@ -54,7 +59,7 @@ func (proof *Proof) BuildRoot() common.Hash {
func (proof *Proof) BuildRootChildren() (common.Hash, common.Hash, error) {
if len(proof.Siblings) == 0 {
zero := common.Hash{}
return zero, zero, errors.New("Siblings array is empty")
return zero, zero, fmt.Errorf("siblings array is empty: %w", ErrBadInput)
}
two := big.NewInt(2)
height := len(proof.Siblings)
Expand Down Expand Up @@ -120,15 +125,15 @@ func (inner *InnerNode) Valid() bool {
return (isPair || isIterated) && !(isPair && isIterated) // xor
}

func (inner *InnerNode) Children() (*Tree, *Tree) {
func (inner *InnerNode) Children() (*Tree, *Tree, error) {
if !inner.Valid() {
panic(fmt.Sprintf("invalid InnerNode state: %v\n", inner))
return nil, nil, fmt.Errorf("invalid InnerNode state: %+v: %w", inner, ErrInvariant)
}

if inner.Child != nil {
return inner.Child, inner.Child
return inner.Child, inner.Child, nil
} else {
return inner.LHS, inner.RHS
return inner.LHS, inner.RHS, nil
}
Comment thread
mpolitzer marked this conversation as resolved.
}

Expand All @@ -144,33 +149,36 @@ func (tree *Tree) GetRootHash() common.Hash {
return tree.RootHash
}

func (tree *Tree) FindChildByHash(hash common.Hash) *Tree {
func (tree *Tree) FindChildByHash(hash common.Hash) (*Tree, error) {
if tree.RootHash == hash {
return tree
return tree, nil
}
if inner := tree.Subtrees; inner != nil {
if !inner.Valid() {
panic(fmt.Sprintf("invalid InnerNode state: %v\n", inner))
lhs, rhs, err := inner.Children()
if err != nil {
return nil, err
}

if inner.Child != nil {
child := inner.Child.FindChildByHash(hash)
if child != nil {
return child
}
} else {
lhs := inner.LHS.FindChildByHash(hash)
if lhs != nil {
return lhs
}
child, err := lhs.FindChildByHash(hash)
if err != nil {
return nil, err
}
if child != nil {
return child, nil
}

rhs := inner.RHS.FindChildByHash(hash)
if rhs != nil {
return rhs
// For iterated nodes lhs == rhs, so the right-hand search is redundant.
if lhs != rhs {
child, err = rhs.FindChildByHash(hash)
if err != nil {
return nil, err
}
if child != nil {
return child, nil
}
}
}
return nil // not found
return nil, nil // not found
}

func (tree *Tree) Join(other *Tree) *Tree {
Expand Down Expand Up @@ -198,11 +206,11 @@ func (tree *Tree) Iterated(rep uint64) *Tree {
return root
}

func (tree *Tree) ProveLeaf(index *big.Int) *Proof {
func (tree *Tree) ProveLeaf(index *big.Int) (*Proof, error) {
return tree.ProveLeafRec(index)
}

func (tree *Tree) ProveLast() *Proof {
func (tree *Tree) ProveLast() (*Proof, error) {
// index = (1 << height) - 1
index := new(big.Int).Sub(
new(big.Int).Lsh(
Expand All @@ -214,48 +222,56 @@ func (tree *Tree) ProveLast() *Proof {
return tree.ProveLeaf(index)
}

func (tree *Tree) ProveLeafRec(index *big.Int) *Proof {
func (tree *Tree) ProveLeafRec(index *big.Int) (*Proof, error) {
numLeafs := new(big.Int).Lsh(one, uint(tree.Height))
if numLeafs.Cmp(index) <= 0 {
panic(fmt.Sprintf("index out of bounds: %v, %v", numLeafs, index))
return nil, fmt.Errorf("index out of bounds: %v, %v: %w", numLeafs, index, ErrBadInput)
}

subtree := tree.Subtrees
if subtree == nil {
if index.Cmp(zero) != 0 {
panic(fmt.Sprintf("invalid Tree state: %v", tree))
return nil, fmt.Errorf("invalid Tree state: %v: %w", tree, ErrInvariant)
}
if tree.Height != 0 {
panic(fmt.Sprintf("invalid Tree state: %v", tree))
return nil, fmt.Errorf("invalid Tree state: %v: %w", tree, ErrInvariant)
}
return Leaf(tree.RootHash, index)
return Leaf(tree.RootHash, index), nil
}

shiftAmount := uint(tree.Height - 1)
isLeftLeaf := new(big.Int).Rsh(index, shiftAmount).Cmp(zero) == 0

// innerIndex = index & !(1 << shiftAmount)
innerIndex := new(big.Int).And(
// innerIndex = index & ~(1 << shiftAmount)
innerIndex := new(big.Int).AndNot(
index,
new(big.Int).Not(
new(big.Int).Lsh(
one,
shiftAmount,
),
new(big.Int).Lsh(
one,
shiftAmount,
),
)

lhs, rhs := subtree.Children()
lhs, rhs, err := subtree.Children()
if err != nil {
return nil, err
}

if isLeftLeaf {
proof := lhs.ProveLeafRec(innerIndex)
proof, err := lhs.ProveLeafRec(innerIndex)
if err != nil {
return nil, err
}
proof.PushHash(rhs.RootHash)
proof.Pos = index
return proof
return proof, nil
} else {
proof := rhs.ProveLeafRec(innerIndex)
proof, err := rhs.ProveLeafRec(innerIndex)
if err != nil {
return nil, err
}
proof.PushHash(lhs.RootHash)
proof.Pos = index
return proof
return proof, nil
}
}

Expand Down Expand Up @@ -295,60 +311,64 @@ func (b *Builder) CanBuild() bool {
return isPow2(b.Trees[n-1].AccumulatedCount)
}

func (b *Builder) Append(leaf *Tree) {
b.AppendRepeated(leaf, big.NewInt(1))
func (b *Builder) Append(leaf *Tree) error {
return b.AppendRepeated(leaf, big.NewInt(1))
}

func (b *Builder) AppendRepeatedUint64(leaf *Tree, reps uint64) {
b.AppendRepeated(leaf, new(big.Int).SetUint64(reps))
func (b *Builder) AppendRepeatedUint64(leaf *Tree, reps uint64) error {
return b.AppendRepeated(leaf, new(big.Int).SetUint64(reps))
}

func (b *Builder) AppendRepeated(leaf *Tree, reps *big.Int) {
func (b *Builder) AppendRepeated(leaf *Tree, reps *big.Int) error {
if leaf == nil || reps == nil {
return fmt.Errorf("invalid parameter: %w", ErrBadInput)
}
if reps.Cmp(zero) <= 0 {
panic("invalid repetitions")
return fmt.Errorf("invalid repetitions: %v: %w", reps, ErrBadInput)
}

accumulatedCount, err := b.calculateAccumulatedCount(reps)
if err != nil {
return err
}

accumulatedCount := b.CalculateAccumulatedCount(reps)
if height, ok := b.Height(); ok {
if height != leaf.Height {
panic("mismatched tree size")
return fmt.Errorf("mismatched tree sizes, height: %v and leaf height: %v: %w", height, leaf.Height, ErrBadInput)
}
}
Comment thread
mpolitzer marked this conversation as resolved.
b.Trees = append(b.Trees, Node{
Tree: leaf,
AccumulatedCount: accumulatedCount,
})
return nil
}

func (b *Builder) Build() *Tree {
func (b *Builder) Build() (*Tree, error) {
if count, ok := b.Count(); ok {
if !isCountPow2(count) {
panic(fmt.Sprintf("builder has %v leafs, which is not a power of two", count))
return nil, fmt.Errorf("builder has %v leafs, which is not a power of two: %w", count, ErrBadInput)
}
log2Size := countTrailingZeroes(count)
return buildMerkle(b.Trees, log2Size, big.NewInt(0))
return buildMerkle(b.Trees, log2Size, big.NewInt(0)), nil
} else {
panic("no leafs in the merkle builder")
return nil, fmt.Errorf("empty merkle builder: %w", ErrBadInput)
}
}

func (b *Builder) CalculateAccumulatedCount(reps *big.Int) *big.Int {
func (b *Builder) calculateAccumulatedCount(reps *big.Int) (*big.Int, error) {
n := len(b.Trees)
if n != 0 {
if reps.Cmp(zero) == 0 {
panic("merkle builder is full")
}

accumulatedCount := new(big.Int).And(
new(big.Int).Add(reps, b.Trees[n-1].AccumulatedCount),
overflowMask,
)
if reps.Cmp(accumulatedCount) >= 0 {
panic("merkle tree overflow")
return nil, fmt.Errorf("merkle tree overflow: %w", ErrBadInput)
}
return accumulatedCount
return accumulatedCount, nil
} else {
return reps
return reps, nil
}
}

Expand Down
Loading
Loading