Skip to content

Commit

Permalink
BUGFIX+FEATURE: Trace loaded from file is fully functional class now.…
Browse files Browse the repository at this point in the history
… It was just a file object earlier
  • Loading branch information
vikashplus committed Mar 23, 2023
1 parent 2bb6fdd commit 0e52724
Showing 1 changed file with 65 additions and 29 deletions.
94 changes: 65 additions & 29 deletions robohive/logger/grouped_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sys import platform
import skvideo.io
import os
import enum

# Trace_name: {
# grp1: {dataset{k1:v1}, dataset{k2:v2}, ...}
Expand All @@ -15,13 +16,25 @@

# ToDo
# access pattern for pickle and h5 backbone post load isn't the same
# Should we get rid of pickle support and double down on h5?
# - Should we get rid of pickle support and double down on h5?
# - other way would to make the default container (trace.trace) h5 container instead of a dict
# Should we explicitely keep tract if the trace has been flattened/ stacked/ closed etc?


class TraceType(enum.Enum):
"""Trace types."""
UNSET = -1
ROBOHIVE = 0
ROBOSET = 1


class Trace:
def __init__(self, name):
def __init__(self, name, trace_type=TraceType.ROBOHIVE):
self.name = name
self.root = {name: {}}
self.trace = self.root[name]
self.index = 0
self.type = trace_type

# Create a group in your logs
def create_group(self, name):
Expand Down Expand Up @@ -218,30 +231,38 @@ def __len__(self) -> str:
# Display data
def __repr__(self) -> str:
disp = "Trace_name: {}\n".format(self.root.keys())
for grp_k, grp_v in self.trace.items():
disp += "{"+grp_k+": \n"
for dst_k, dst_v in grp_v.items():

# raw
if type(dst_v) == list:
datum = dst_v[0]
try:
ll = datum.shape
except:
ll = ()
disp += "\t{}:[{}_{}]_{}\n".format(dst_k, str(type(dst_v[0])), ll, len(dst_v))

# flattened
elif type(dst_v) == dict:
datum = dst_v
disp += "\t{}: {}\n".format(dst_k, str(type(datum)))

# numpified
else:
datum = dst_v
disp += "\t{}: {}, shape{}, type({})\n".format(dst_k, str(type(datum)), datum.shape, datum.dtype)
if isinstance(self.trace, h5py.File):
# Trace (when reloaded from h5)
for k, v in self.trace.items():
disp += v.__repr__()+"\n"
for kk,vv in v.items():
disp += "\t"+vv.__repr__()+"\n"

disp += "}\n"
else:
# Trace (while open)
for grp_k, grp_v in self.trace.items():
disp += "{"+grp_k+": \n"
for dst_k, dst_v in grp_v.items():
# raw
if type(dst_v) == list:
datum = dst_v[0]
try:
ll = datum.shape
except:
ll = ()
disp += "\t{}:[{}_{}]_{}\n".format(dst_k, str(type(dst_v[0])), ll, len(dst_v))

# flattened
elif type(dst_v) == dict:
datum = dst_v
disp += "\t{}: {}\n".format(dst_k, str(type(datum)))

# numpified
else:
datum = dst_v
disp += "\t{}: {}, shape{}, type({})\n".format(dst_k, str(type(datum)), datum.shape, datum.dtype)
disp += "}\n"
return disp


Expand Down Expand Up @@ -318,13 +339,28 @@ def save(self,


# load trace from disk
def load(trace_path):
trace_format = trace_path.split('.')[-1]
print("Reading: ", trace_path)
@staticmethod
def load(trace_path, trace_type=TraceType.UNSET):
"""
trace_path: Load the trace using the provided path
trace_type: Provide the trace type of the path; UNSET will be used if not provided
Note:
Loaded trace has some difference with the original trace
- h5 vs dict format
- flattend schema
"""
trace_name, trace_format = trace_path.split('.')
print("Reading:", trace_path)
if trace_format == "h5":
trace = h5py.File(trace_path, "r")
trace = Trace(name=trace_name, trace_type=trace_type)
file_data = h5py.File(trace_path, "r")
trace.trace = file_data # load data
trace.root[trace.name] = trace.trace # build root
else:
trace = pickle.load(open(trace_path, 'rb'))
file_data = pickle.load(open(trace_path, 'rb'))
trace = Trace(name=list(file_data.keys())[0], trace_type=trace_type)
trace.trace = file_data[trace.name] # load data
trace.root = file_data # build root
return trace


Expand Down

0 comments on commit 0e52724

Please sign in to comment.