diff --git a/stable_baselines/__init__.py b/stable_baselines/__init__.py index f6aa41cfe8..7a2162ee23 100644 --- a/stable_baselines/__init__.py +++ b/stable_baselines/__init__.py @@ -1,3 +1,5 @@ +import gym + from stable_baselines.a2c import A2C from stable_baselines.acer import ACER from stable_baselines.acktr import ACKTR @@ -9,3 +11,15 @@ from stable_baselines.trpo_mpi import TRPO __version__ = "2.0.0" + +if not hasattr(gym.spaces.MultiBinary, '__eq__'): + def _eq(self, other): + return self.n == other.n + + gym.spaces.MultiBinary.__eq__ = _eq + +if not hasattr(gym.spaces.MultiDiscrete, '__eq__'): + def _eq(self, other): + return self.nvec == other.nvec + + gym.spaces.MultiDiscrete.__eq__ = _eq