Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Readonly settings #169

Merged
merged 18 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions gufe/protocols/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Optional, Iterable, Any, Union
from openff.units import Quantity

from ..settings import Settings
from ..settings import Settings, SettingsBaseModel
from ..tokenization import GufeTokenizable, GufeKey
from ..chemicalsystem import ChemicalSystem
from ..mapping import ComponentMapping
Expand Down Expand Up @@ -89,13 +89,14 @@ def __init__(self, settings: Settings):
Parameters
----------
settings : Settings
The full settings for this ``Protocol`` instance.
The full settings for this ``Protocol`` instance. Must be passed an instance of Settings or a
subclass which is specialised for a particular Protocol
"""
self._settings = settings
self._settings = settings.frozen_copy()

@property
def settings(self) -> Settings:
"""The full settings for this ``Protocol`` instance."""
"""A read-only view of the settings for this ``Protocol`` instance."""
return self._settings

@classmethod
Expand Down
60 changes: 60 additions & 0 deletions gufe/settings/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,85 @@
Extra,
Field,
PositiveFloat,
PrivateAttr,
validator,
)
except ImportError:
from pydantic import (
Extra,
Field,
PositiveFloat,
PrivateAttr,
validator,
)


class SettingsBaseModel(DefaultModel):
"""Settings and modifications we want for all settings classes."""
_is_frozen: bool = PrivateAttr(default_factory=lambda: False)

class Config:
extra = Extra.forbid
arbitrary_types_allowed = False
smart_union = True

def frozen_copy(self):
"""A copy of this Settings object which cannot be modified

This is intended to be used by Protocols to make their stored Settings
read-only
"""
copied = self.copy(deep=True)

def freeze_model(model):
submodels = (
mod for field in model.__fields__
if isinstance(mod := getattr(model, field), SettingsBaseModel)
)
for mod in submodels:
freeze_model(mod)

if not model._is_frozen:
model._is_frozen = True

freeze_model(copied)
return copied

def unfrozen_copy(self):
"""A copy of this Settings object, which can be modified

Settings objects become frozen when within a Protocol. If you *really*
need to reverse this, this method is how.
"""
copied = self.copy(deep=True)

def unfreeze_model(model):
submodels = (
mod for field in model.__fields__
if isinstance(mod := getattr(model, field), SettingsBaseModel)
)
for mod in submodels:
unfreeze_model(mod)

model._is_frozen = False

unfreeze_model(copied)

return copied

@property
def is_frozen(self):
"""If this Settings object is frozen and cannot be modified"""
return self._is_frozen

def __setattr__(self, name, value):
if name != "_is_frozen" and self._is_frozen:
raise AttributeError(
f"Cannot set '{name}': Settings are immutable once attached"
" to a Protocol and cannot be modified. Modify Settings "
"*before* creating the Protocol.")
return super().__setattr__(name, value)


class ThermoSettings(SettingsBaseModel):
"""Settings for thermodynamic parameters.
Expand Down
8 changes: 4 additions & 4 deletions gufe/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def absolute_transformation(solvated_ligand, solvated_complex):
return gufe.Transformation(
solvated_ligand,
solvated_complex,
protocol=DummyProtocol(settings=None),
protocol=DummyProtocol(settings=DummyProtocol.default_settings()),
mapping=None,
)

Expand All @@ -253,7 +253,7 @@ def absolute_transformation(solvated_ligand, solvated_complex):
def complex_equilibrium(solvated_complex):
return gufe.NonTransformation(
solvated_complex,
protocol=DummyProtocol(settings=None)
protocol=DummyProtocol(settings=DummyProtocol.default_settings())
)


Expand Down Expand Up @@ -292,7 +292,7 @@ def benzene_variants_star_map(
] = gufe.Transformation(
solvated_ligands["benzene"],
solvated_ligands[ligand.name],
protocol=DummyProtocol(settings=None),
protocol=DummyProtocol(settings=DummyProtocol.default_settings()),
mapping=None,
)

Expand All @@ -316,7 +316,7 @@ def benzene_variants_star_map(
] = gufe.Transformation(
solvated_complexes["benzene"],
solvated_complexes[ligand.name],
protocol=DummyProtocol(settings=None),
protocol=DummyProtocol(settings=DummyProtocol.default_settings()),
mapping=None,
)

Expand Down
4 changes: 2 additions & 2 deletions gufe/tests/test_alchemicalnetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
class TestAlchemicalNetwork(GufeTokenizableTestsMixin):

cls = AlchemicalNetwork
key = "AlchemicalNetwork-8c6df17d7ecf5902e2e338984cc11140"
repr = "<AlchemicalNetwork-8c6df17d7ecf5902e2e338984cc11140>"
key = "AlchemicalNetwork-d1035e11493ca60ff7bac5171eddfee3"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just checking: is this changing the gufe key because you've changed the fixture to use non-None settings? (Otherwise, nothing in this PR should affect gufe keys, right?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep it's this

repr = "<AlchemicalNetwork-d1035e11493ca60ff7bac5171eddfee3>"

@pytest.fixture
def instance(self, benzene_variants_star_map):
Expand Down
60 changes: 59 additions & 1 deletion gufe/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
from openff.units import unit
import pytest

from gufe.settings.models import Settings, OpenMMSystemGeneratorFFSettings
from gufe.settings.models import (
OpenMMSystemGeneratorFFSettings,
Settings,
ThermoSettings,
)


def test_model_schema():
Expand Down Expand Up @@ -53,3 +57,57 @@ def test_invalid_constraint(value, good):
else:
with pytest.raises(ValueError):
_ = OpenMMSystemGeneratorFFSettings(constraints=value)


class TestFreezing:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Current tests check that it works we do settings.subsettings.item = .... Could you add a test that does settings.subsettings = ... ? (example: you attempt to replace the entire thermo_settings in one go).

def test_default_not_frozen(self):
s = Settings.get_defaults()
# make a frozen copy to check this doesn't alter the original
s2 = s.frozen_copy()

s.thermo_settings.temperature = 199 * unit.kelvin
assert s.thermo_settings.temperature == 199 * unit.kelvin

def test_freezing(self):
s = Settings.get_defaults()

s2 = s.frozen_copy()

with pytest.raises(AttributeError, match="immutable"):
s2.thermo_settings.temperature = 199 * unit.kelvin

def test_unfreezing(self):
s = Settings.get_defaults()

s2 = s.frozen_copy()

with pytest.raises(AttributeError, match="immutable"):
s2.thermo_settings.temperature = 199 * unit.kelvin

assert s2.is_frozen

s3 = s2.unfrozen_copy()

s3.thermo_settings.temperature = 199 * unit.kelvin
assert s3.thermo_settings.temperature == 199 * unit.kelvin

def test_frozen_equality(self):
# the frozen-ness of Settings doesn't alter its contents
# therefore a frozen/unfrozen Settings which are otherwise identical
# should be considered equal
s = Settings.get_defaults()
s2 = s.frozen_copy()

assert s == s2

def test_set_subsection(self):
# check that attempting to set a subsection of settings still respects
# frozen state of parent object
s = Settings.get_defaults().frozen_copy()

assert s.is_frozen

ts = ThermoSettings(temperature=301 * unit.kelvin)

with pytest.raises(AttributeError, match="immutable"):
s.thermo_settings = ts
24 changes: 22 additions & 2 deletions gufe/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def test_dag_execute_failure(self, protocol_dag_broken):
assert len(succeeded_units) > 0

def test_dag_execute_failure_raise_error(self, solvated_ligand, vacuum_ligand, tmpdir):
protocol = BrokenProtocol(settings=None)
protocol = BrokenProtocol(settings=BrokenProtocol.default_settings())
dag = protocol.create(
stateA=solvated_ligand, stateB=vacuum_ligand, name="a broken dummy run",
mapping=None,
Expand Down Expand Up @@ -507,7 +507,7 @@ def _defaults(cls):

@classmethod
def _default_settings(cls):
return {}
return settings.Settings.get_defaults()

def _create(
self,
Expand Down Expand Up @@ -719,3 +719,23 @@ def test_execute_DAG_bad_nretries(solvated_ligand, vacuum_ligand, tmpdir):
keep_scratch=True,
raise_error=False,
n_retries=-1)


def test_settings_readonly():
# checks that settings aren't editable once inside a Protocol
p = DummyProtocol(DummyProtocol.default_settings())

before = p.settings.n_repeats

with pytest.raises(AttributeError, match="immutable"):
p.settings.n_repeats = before + 1

assert p.settings.n_repeats == before

# also check child settings
before = p.settings.thermo_settings.temperature

with pytest.raises(AttributeError, match="immutable"):
p.settings.thermo_settings.temperature = 400.0 * unit.kelvin

assert p.settings.thermo_settings.temperature == before
9 changes: 6 additions & 3 deletions gufe/tests/test_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,12 @@ def test_equality(self, absolute_transformation, solvated_ligand, solvated_compl
)
assert absolute_transformation != opposite

s = DummyProtocol.default_settings()
s.n_repeats = 99
different_protocol_settings = Transformation(
solvated_ligand,
solvated_complex,
protocol=DummyProtocol(settings={"lol": True}),
protocol=DummyProtocol(settings=s),
)
assert absolute_transformation != different_protocol_settings

Expand Down Expand Up @@ -188,9 +190,10 @@ def test_protocol_extend(self, complex_equilibrium, tmpdir):
assert len(protocolresult.data) == 2

def test_equality(self, complex_equilibrium, solvated_ligand, solvated_complex):

s = DummyProtocol.default_settings()
s.n_repeats = 4031
different_protocol_settings = NonTransformation(
solvated_complex, protocol=DummyProtocol(settings={"lol": True})
solvated_complex, protocol=DummyProtocol(settings=s)
)
assert complex_equilibrium != different_protocol_settings

Expand Down