Skip to content

Commit

Permalink
Add ChanUpgradeOpen core handler.
Browse files Browse the repository at this point in the history
  • Loading branch information
DimitrisJim committed Jun 19, 2023
1 parent 1625adb commit 6b507a9
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 0 deletions.
88 changes: 88 additions & 0 deletions modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,94 @@ func (k Keeper) startFlushUpgradeHandshake(
return nil
}

// ChanUpgradeOpen is called by a module to complete the channel upgrade handshake and move the channel back to an OPEN state.
// This method should only be called after both channels have flushed any in-flight packets.
func (k Keeper) ChanUpgradeOpen(
ctx sdk.Context,
portID,
channelID string,
counterpartyChannelState types.State,
proofChannel []byte,
proofHeight clienttypes.Height,
) error {
channel, found := k.GetChannel(ctx, portID, channelID)
if !found {
return errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}

if k.hasInflightPackets(ctx, portID, channelID) {
return errorsmod.Wrapf(types.ErrPendingInflightPackets, "port ID (%s) channel ID (%s)", portID, channelID)
}

if !collections.Contains(channel.State, []types.State{types.TRYUPGRADE, types.ACKUPGRADE}) {
return errorsmod.Wrapf(types.ErrInvalidChannelState, "expected one of [%s, %s], got %s", types.TRYUPGRADE, types.ACKUPGRADE, channel.State)
}

if channel.FlushStatus != types.FLUSHCOMPLETE {
return errorsmod.Wrapf(types.ErrInvalidFlushStatus, "expected %s, got %s", types.FLUSHCOMPLETE, channel.FlushStatus)
}

connection, err := k.GetConnection(ctx, channel.ConnectionHops[0])
if err != nil {
return errorsmod.Wrap(err, "failed to retrieve connection using the channel connection hops")
}

if connection.GetState() != int32(connectiontypes.OPEN) {
return errorsmod.Wrapf(connectiontypes.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectiontypes.State(connection.GetState()).String())
}

var counterpartyChannel types.Channel
counterpartyHops := []string{connection.GetCounterparty().GetConnectionID()}
switch counterpartyChannelState {
case types.OPEN:
upgrade, found := k.GetUpgrade(ctx, portID, channelID)
if !found {
return errorsmod.Wrapf(types.ErrUpgradeNotFound, "failed to retrieve channel upgrade: port ID (%s) channel ID (%s)", portID, channelID)
}
counterpartyChannel = types.Channel{
State: types.OPEN,
Ordering: channel.Ordering,
ConnectionHops: upgrade.Fields.ConnectionHops,
Counterparty: types.NewCounterparty(portID, channelID),
Version: upgrade.Fields.GetVersion(),
UpgradeSequence: channel.UpgradeSequence,
FlushStatus: types.NOTINFLUSH,
}

case types.TRYUPGRADE:
// If the counterparty is in TRYUPGRADE, then we must have gone through the ACKUPGRADE step.
if channel.State != types.ACKUPGRADE {
return errorsmod.Wrapf(types.ErrInvalidChannelState, "expected %s, got %s", types.ACKUPGRADE, channel.State)
}
counterpartyChannel = types.Channel{
State: types.TRYUPGRADE,
Ordering: channel.Ordering,
ConnectionHops: counterpartyHops,
Counterparty: types.NewCounterparty(portID, channelID),
Version: channel.Version,
UpgradeSequence: channel.UpgradeSequence,
FlushStatus: types.FLUSHCOMPLETE,
}

case types.ACKUPGRADE:
counterpartyChannel = types.Channel{
State: types.ACKUPGRADE,
Ordering: channel.Ordering,
ConnectionHops: counterpartyHops,
Counterparty: types.NewCounterparty(portID, channelID),
Version: channel.Version,
UpgradeSequence: channel.UpgradeSequence,
FlushStatus: types.FLUSHCOMPLETE,
}
default:
panic(fmt.Sprintf("counterparty channel state should be in one of [%s, %s, %s]; got %s", types.TRYUPGRADE, types.ACKUPGRADE, types.OPEN, counterpartyChannelState))
}

return k.connectionKeeper.VerifyChannelState(
ctx, connection, proofHeight, proofChannel, portID, channelID, counterpartyChannel,
)
}

// WriteUpgradeOpenChannel writes the agreed upon upgrade fields to the channel, sets the channel flush status to NOTINFLUSH and sets the channel state back to OPEN. This can be called in one of two cases:
// - In the UpgradeAck step of the handshake if both sides have already flushed all in-flight packets.
// - In the UpgradeOpen step of the handshake.
Expand Down
101 changes: 101 additions & 0 deletions modules/core/04-channel/keeper/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,107 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() {
}
}

func (suite *KeeperTestSuite) TestChanUpgradeOpen() {
var (
path *ibctesting.Path
)
testCases := []struct {
name string
malleate func()
expError error
}{
{
"success, counterparty in TRYUPGRADE",
func() {},
nil,
},
// TODO: Rest of combinations for counterparty state.
{
"channel not found",
func() {
path.EndpointA.ChannelID = ibctesting.InvalidID
path.EndpointA.ChannelConfig.PortID = ibctesting.InvalidID
},
types.ErrChannelNotFound,
},
{
"in-flight packets still exist",
func() {
// TODO:
},
types.ErrPendingInflightPackets,
},
{
"flush status is not FLUSHCOMPLETE",
func() {
channel := path.EndpointB.GetChannel()
channel.FlushStatus = types.FLUSHING
path.EndpointB.SetChannel(channel)
},
types.ErrInvalidFlushStatus,
},
{
"connection not found",
func() {
channel := path.EndpointA.GetChannel()
channel.ConnectionHops = []string{"connection-100"}
path.EndpointA.SetChannel(channel)
},
connectiontypes.ErrConnectionNotFound,
},
{
"invalid connection state",
func() {
connectionEnd := path.EndpointA.GetConnection()
connectionEnd.State = connectiontypes.UNINITIALIZED
path.EndpointA.SetConnection(connectionEnd)
},
connectiontypes.ErrInvalidConnectionState,
},
}

for _, tc := range testCases {
tc := tc
suite.Run(tc.name, func() {
expPass := tc.expError == nil
suite.SetupTest()

path = ibctesting.NewPath(suite.chainA, suite.chainB)
suite.coordinator.Setup(path)

path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion
path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion

err := path.EndpointA.ChanUpgradeInit()
suite.Require().NoError(err)

err = path.EndpointB.ChanUpgradeTry()
suite.Require().NoError(err)

// Pending implementation of ack on msg_server to move channel state for A.
// Currently fails on validation.
err = path.EndpointA.ChanUpgradeAck()
suite.Require().NoError(err)

tc.malleate()

// ensure clients are up to date to receive valid proofs
suite.Require().NoError(path.EndpointA.UpdateClient())
proofCounterpartyChannel, _, proofHeight := path.EndpointB.QueryChannelUpgradeProof()

err = suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeOpen(
suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID,
path.EndpointB.GetChannel().State, proofCounterpartyChannel, proofHeight,
)
if expPass {
suite.Require().NoError(err)
} else {
suite.Require().ErrorIs(err, tc.expError)
}
})
}
}

func (suite *KeeperTestSuite) TestChanUpgradeTimeout() {
var (
path *ibctesting.Path
Expand Down
1 change: 1 addition & 0 deletions modules/core/04-channel/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,5 @@ var (
ErrUpgradeRestoreFailed = errorsmod.Register(SubModuleName, 34, "restore failed")
ErrUpgradeTimeout = errorsmod.Register(SubModuleName, 35, "upgrade timed-out")
ErrInvalidUpgradeTimeout = errorsmod.Register(SubModuleName, 36, "upgrade timeout is invalid")
ErrPendingInflightPackets = errorsmod.Register(SubModuleName, 37, "pending inflight packets exist")
)

0 comments on commit 6b507a9

Please sign in to comment.