Skip to content

Commit

Permalink
Add figure object to return statement of draw method of EGraph (#55)
Browse files Browse the repository at this point in the history
* viz.plot_Z_err_cond adjust xlim centralpeak plot

* black formatting viz.py

* remove 1 variable for viz.plot_Z_err_cond

* using GKPFormatter for x-ticks

* removed fancy formatter and use existing gkp.to_pi_string

* removed unused import FormatStrFormatter

* removed unnecessary newxlabels

* add return fig to plot_frac and more consistent code

* mess from cherry-pick :see-no-evil:

* Revert "removed fancy formatter and use existing gkp.to_pi_string"

This reverts commit fe59c23.

* using GKPFormatter for x-ticks

* removed fancy formatter and use existing gkp.to_pi_string

* removed unused import FormatStrFormatter

* removed unnecessary newxlabels

* add return fig to plot_frac and more consistent code

* changes to viz.py for adjusted xlim

* changes to gkp.py for adjusted xlim

* using GKPFormatter for x-ticks

* removed fancy formatter and use existing gkp.to_pi_string

* removed unused import FormatStrFormatter

* removed unnecessary newxlabels

* add return fig to plot_frac and more consistent code

* removed definition GKPFormatter

* found the culprit - the formatter *is* needed

* also use GKP_Formatter for yaxis frac plot

* rename GKP_Formatter to GKPFormatter

* use matplotlib Formatter class

* docstring for GKP Formatter class

* docstring init GKPFormatter

* Update gkp.py

* consistent style in string formatting float

Co-authored-by: ilan-tz <57886357+ilan-tz@users.noreply.github.com>

* More precise language in docstring

Co-authored-by: ilan-tz <57886357+ilan-tz@users.noreply.github.com>

* Renaming GKPFormatter to PiFormatter

* Moving PiFormatter to viz.py

* moving import math package

* moving test_to_pi_string to new test_viz file

* adding test to_pi_string for tex=False

* docformatter for viz.py

* Minor changes

* have draw_decoding return fig, ax

* added return statements in draw methods of stabilizer and matching graph

* adding return for draw_decoding

* forgot to add fig2, ax2 from G_match.draw

* luis for the rescue

* black

* Added changelog details of PR #33

* add return to `egraph.draw`

* adjust example file surface code to new return statement

* added test draw EGraph

* assert that f,a have right types

* black

* adjust test_draw with issubclass

* call EGraph

* added random_graph as def param

* added docstring test

* remove random graph

* draw method does not work on empty graph

* Update .github/CHANGELOG.md

Co-authored-by: ilan-tz <57886357+ilan-tz@users.noreply.github.com>

* moved draw test to test_test

* remove matplotlib import

* group EGraph draw tests in class

* change encoding from CRLF to LF

* Create .gitattributes

* Update .gitattributes

* Delete .gitattributes

* Update tests/unit/test_viz.py

Co-authored-by: ilan-tz <57886357+ilan-tz@users.noreply.github.com>

* Finalizing PR

Co-authored-by: soosub <joost.bus@resident.xanadu.ai>
Co-authored-by: ilan-tz <57886357+ilan-tz@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 7, 2022
1 parent 489d50b commit d160088
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 34 deletions.
2 changes: 1 addition & 1 deletion .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
### Bug fixes
* Small fix in `viz.draw_EGraph` that raised an error whenever a graph state with non-integer coordinates was plotted. [#68](https://github.com/XanaduAI/flamingpy/pull/68)


### Improvements
* Pylint is pinned to stable version `pylint==2.14.0` and added to `dev_requirements.txt`. [#76](https://github.com/XanaduAI/flamingpy/pull/76)
* pylint no-self-use tags are removed as this check has been removed from pylint (see [here](https://github.com/PyCQA/pylint/issues/5502)).
* Added tests for `EGraph` plots. [#60](https://github.com/XanaduAI/flamingpy/pull/60)
* Added `.gitattributes` to the repository, so git automatically handles consistent `eol`'s for all commits and contributors across different operating systems. [#78](https://github.com/XanaduAI/flamingpy/pull/78)
* Added `fig, ax` returns for the draw methods in `utils/viz.py` and some additional tests. [#55](https://github.com/XanaduAI/flamingpy/pull/55)

### Documentation changes
* Mention the new graph state functions from `flamingpy.utils.graph_states` in the `run_graph_states.py` tutorial. [#68](https://github.com/XanaduAI/flamingpy/pull/68)
Expand Down
2 changes: 1 addition & 1 deletion flamingpy/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
"""Version number (major.minor.patch[label])"""


__version__ = "0.8.2a5.dev4"
__version__ = "0.8.2a5.dev5"
4 changes: 2 additions & 2 deletions flamingpy/codes/graphs/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ def draw(self, **kwargs):
"""
from flamingpy.utils.viz import draw_EGraph

_, ax = draw_EGraph(self, **kwargs)
return ax
fig, ax = draw_EGraph(self, **kwargs)
return fig, ax

def draw_adj(self, **kwargs):
"""Draw the adjacency matrix with matplotlib.
Expand Down
4 changes: 2 additions & 2 deletions flamingpy/examples/surface_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def illustrate_surface_code(d, boundaries, err, polarity, stabilizer_inds=None,
# Instantiate a surface code.
RHG_code = SurfaceCode(d, ec=err, boundaries=boundaries, polarity=polarity)
RHG_lattice = RHG_code.graph
RHG_fig = RHG_code.draw()
RHG_fig, RHG_ax = RHG_code.draw()

# Check edges between boundaries for periodic boundary conditions.
if boundaries == "periodic":
Expand All @@ -67,7 +67,7 @@ def illustrate_surface_code(d, boundaries, err, polarity, stabilizer_inds=None,
color = np.random.rand(3)
for point in stabilizer.egraph:
x, z, y = point
RHG_fig.scatter(x, z, y, color=color, s=40)
RHG_ax.scatter(x, z, y, color=color, s=40)

if show:
plt.show()
Expand Down
3 changes: 0 additions & 3 deletions tests/unit/test_graphstates.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ def test_index(self, random_graph):
# def test_slice_coords(self):
# pass

# def test_draw(self):
# pass


class TestCVHelpers:
"""Tests for CVLayer helper functions."""
Expand Down
60 changes: 35 additions & 25 deletions tests/unit/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
import numpy as np
from numpy.random import default_rng as rng
import pytest
import matplotlib
import matplotlib.pyplot as plt


from flamingpy.utils.viz import to_pi_string, draw_EGraph
from flamingpy.codes.graphs import EGraph
from flamingpy.codes import SurfaceCode
Expand Down Expand Up @@ -53,35 +53,45 @@ def test_to_pi_string():
assert to_pi_string(-np.sqrt(np.pi) / 2, tex=False) == "-\\sqrt{\\pi}/2"


def test_draw_egraph_bell():
"""Test for the draw method of EGraph of Bell state."""
# Bell state EGraph
edge = [(0, 0, 0), (0, 0, 1)]
bell_state = EGraph()
bell_state.add_edge(*edge, color="MidnightBlue")
class TestDrawEGraph:
"""Tests for visualizing EGraphs."""

def test_draw_egraph_bell(self):
"""Test for the draw method of EGraph of Bell state."""
# Bell state EGraph
edge = [(0, 0, 0), (0, 0, 1)]
bell_state = EGraph()
bell_state.add_edge(*edge, color="MidnightBlue")

# Test for drawing the EGraph
_, a = draw_EGraph(bell_state)
plt.close()
# Test for drawing the EGraph
_, a = draw_EGraph(bell_state)
plt.close()

assert len(a.get_xticks()) == 1
assert a.get_xlim() == (-1, 1)
assert len(a.get_xticks()) == 1
assert a.get_xlim() == (-1, 1)

def test_wrapper_draw_egraph(self):
"""Tests the returned object of EGraph.draw of EGraph with one node."""
E = EGraph()
E.add_node((0, 0, 0))
f, a = E.draw()
assert issubclass(type(f), matplotlib.figure.Figure)
assert issubclass(type(a), matplotlib.axes.Axes)

@pytest.mark.parametrize("d", (2, 3))
def test_draw_egraph_rhg(d):
"""Test for the draw method of EGraph of RHG lattice."""
# Bell state EGraph
RHG = SurfaceCode(d).graph
@pytest.mark.parametrize("d", (2, 3))
def test_draw_egraph_rhg(self, d):
"""Test for the draw method of EGraph of RHG lattice."""
# Bell state EGraph
RHG = SurfaceCode(d).graph

# Test for drawing the EGraph
_, a = draw_EGraph(RHG)
plt.close()
# Test for drawing the EGraph
_, a = draw_EGraph(RHG)
plt.close()

n_ticks = 2 * d - 1
n_ticks = 2 * d - 1

ticks = (a.get_xticks(), a.get_yticks(), a.get_zticks())
assert [len(tick) for tick in ticks] == [n_ticks] * 3
ticks = (a.get_xticks(), a.get_yticks(), a.get_zticks())
assert [len(tick) for tick in ticks] == [n_ticks] * 3

actual_lims = (a.get_xlim(), a.get_ylim(), a.get_zlim())
assert actual_lims == ((0, n_ticks - 1), (1, n_ticks), (1, n_ticks))
actual_lims = (a.get_xlim(), a.get_ylim(), a.get_zlim())
assert actual_lims == ((0, n_ticks - 1), (1, n_ticks), (1, n_ticks))

0 comments on commit d160088

Please sign in to comment.