diff --git a/README.md b/README.md index 1b4d967b..6e3e403f 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ Keep your training algorithm the same, just replace the data loader! Look at the ## Installation +### Linux ``` conda create -y -n ffcv python=3.9 cupy pkg-config compilers libjpeg-turbo opencv pytorch torchvision cudatoolkit=11.3 numba -c pytorch -c conda-forge conda activate ffcv @@ -41,6 +42,21 @@ pip install ffcv ``` Troubleshooting note: if the above commands result in a package conflict error, try running ``conda config --env --set channel_priority flexible`` in the environment and rerunning the installation command. +### Windows +* Install opencv4 + * Add `..../opencv/build/x64/vc15/bin` to the PATH environment variable +* Install libjpeg-turbo: download libjpeg-turbo-x.x.x-vc64.exe, not the gcc64 build + * Add `..../libjpeg-turbo64/bin` to the PATH environment variable +* Install pthread: download the latest release.zip + * After unzipping, rename the Pre-built.2 folder to pthread + * Open `pthread/include/pthread.h`, and add the code below to the top of the file. + ```cpp + #define HAVE_STRUCT_TIMESPEC + ``` + * Add `..../pthread/dll` to the PATH environment variable +* Install cupy depending on your CUDA Toolkit version. 
+* `pip install ffcv` + ## Citation If you use FFCV, please cite it as: diff --git a/ffcv/.DS_Store b/ffcv/.DS_Store new file mode 100644 index 00000000..f1fad9b3 Binary files /dev/null and b/ffcv/.DS_Store differ diff --git a/ffcv/fields/__init__.py b/ffcv/fields/__init__.py index 90987fe7..79b76d52 100644 --- a/ffcv/fields/__init__.py +++ b/ffcv/fields/__init__.py @@ -2,8 +2,8 @@ from .basics import FloatField, IntField from .rgb_image import RGBImageField from .bytes import BytesField -from .ndarray import NDArrayField +from .ndarray import NDArrayField, TorchTensorField from .json import JSONField __all__ = ['Field', 'BytesField', 'IntField', 'FloatField', 'RGBImageField', - 'NDArrayField', 'JSONField'] \ No newline at end of file + 'NDArrayField', 'JSONField', 'TorchTensorField'] \ No newline at end of file diff --git a/ffcv/fields/ndarray.py b/ffcv/fields/ndarray.py index 740f8bf5..df347d43 100644 --- a/ffcv/fields/ndarray.py +++ b/ffcv/fields/ndarray.py @@ -1,8 +1,10 @@ from typing import Callable, TYPE_CHECKING, Tuple, Type +import warnings import json from dataclasses import replace import numpy as np +import torch as ch from .base import Field, ARG_TYPE from ..pipeline.operation import Operation @@ -55,6 +57,10 @@ def __init__(self, dtype:np.dtype, shape:Tuple[int, ...]): self.dtype = dtype self.shape = shape self.element_size = dtype.itemsize * np.prod(shape) + if dtype == np.uint16: + warnings.warn("Pytorch currently doesn't support uint16" + "we recommend storing as int16 and reinterpret your data later" + "in your pipeline") @property def metadata_type(self) -> np.dtype: @@ -93,4 +99,21 @@ def encode(self, destination, field, malloc): data_region[:] = field.reshape(-1).view(' Type[Operation]: - return NDArrayDecoder \ No newline at end of file + return NDArrayDecoder + + +class TorchTensorField(NDArrayField): + """A subclass of :class:`~ffcv.fields.Field` supporting + multi-dimensional fixed size matrices of any torch type. 
+ """ + def __init__(self, dtype:ch.dtype, shape:Tuple[int, ...]): + self.dtype = dtype + self.shape = shape + dtype = ch.zeros(0, dtype=dtype).numpy().dtype + + super().__init__(dtype, shape) + + + def encode(self, destination, field, malloc): + field = field.numpy() + return super().encode(destination, field, malloc) diff --git a/ffcv/libffcv.py b/ffcv/libffcv.py index 52219f3c..693269f6 100644 --- a/ffcv/libffcv.py +++ b/ffcv/libffcv.py @@ -1,13 +1,18 @@ import ctypes from numba import njit import numpy as np +import platform from ctypes import CDLL, c_int64, c_uint8, c_uint64, POINTER, c_void_p, c_uint32, c_bool, cdll import ffcv._libffcv lib = CDLL(ffcv._libffcv.__file__) -libc = cdll.LoadLibrary('libc.so.6') +if platform.system() == "Windows": + libc = cdll.msvcrt + read_c = libc._read +else: + libc = cdll.LoadLibrary('libc.so.6') + read_c = libc.pread -read_c = libc.pread read_c.argtypes = [c_uint32, c_void_p, c_uint64, c_uint64] def read(fileno:int, destination:np.ndarray, offset:int): @@ -47,5 +52,5 @@ def imdecode(source: np.ndarray, dst: np.ndarray, ctypes_memcopy.argtypes = [c_void_p, c_void_p, c_uint64] def memcpy(source: np.ndarray, dest: np.ndarray): - return ctypes_memcopy(source.ctypes.data, dest.ctypes.data, source.size) + return ctypes_memcopy(source.ctypes.data, dest.ctypes.data, source.size*source.itemsize) diff --git a/ffcv/loader/epoch_iterator.py b/ffcv/loader/epoch_iterator.py index 658ff86c..1be7cdfa 100644 --- a/ffcv/loader/epoch_iterator.py +++ b/ffcv/loader/epoch_iterator.py @@ -19,6 +19,17 @@ (`OrderOption.QUASI_RANDOM`) in the dataloader constructor's `order` argument. 
''' +def select_buffer(buffer, batch_slot, count): + """Util function to select the relevent subpart of a buffer for a given + batch_slot and batch size""" + if buffer is None: + return None + if isinstance(buffer, tuple): + return tuple(select_buffer(x, batch_slot, count) for x in buffer) + + return buffer[batch_slot][:count] + + class EpochIterator(Thread): def __init__(self, loader: 'Loader', order: Sequence[int]): super().__init__(daemon=True) @@ -33,6 +44,10 @@ def __init__(self, loader: 'Loader', order: Sequence[int]): self.terminate_event = Event() self.memory_context = self.loader.memory_manager.schedule_epoch( batches) + + if IS_CUDA: + self.current_stream = ch.cuda.current_stream() + try: self.memory_context.__enter__() except MemoryError as e: @@ -44,23 +59,13 @@ def __init__(self, loader: 'Loader', order: Sequence[int]): self.storage_state = self.memory_context.state - self.memory_bank_per_stage = defaultdict(list) - self.cuda_streams = [(ch.cuda.Stream() if IS_CUDA else None) for _ in range(self.loader.batches_ahead + 2)] - # Allocate all the memory - memory_allocations = {} - for (p_id, p) in self.loader.pipelines.items(): - memory_allocations[p_id] = p.allocate_memory(self.loader.batch_size, - self.loader.batches_ahead + 2) - - # Assign each memory bank to the pipeline stage it belongs to - for s_ix, banks in self.loader.memory_bank_keys_per_stage.items(): - for (pipeline_name, op_id) in banks: - self.memory_bank_per_stage[s_ix].append( - memory_allocations[pipeline_name][op_id] - ) + self.memory_allocations = self.loader.graph.allocate_memory( + self.loader.batch_size, + self.loader.batches_ahead + 2 + ) self.start() @@ -77,6 +82,7 @@ def run(self): self.current_batch_slot = ( slot + 1) % (self.loader.batches_ahead + 2) result = self.run_pipeline(b_ix, ixes, slot, events[slot]) + # print("RES", b_ix, "ready") to_output = (slot, result) while True: try: @@ -88,15 +94,17 @@ def run(self): if self.terminate_event.is_set(): return if IS_CUDA: + # 
print("SUB", b_ix) # We were able to submit this batch # Therefore it means that the user must have entered the for loop for # (batch_slot - batch_ahead + 1) % (batches ahead + 2) # Therefore batch_slot - batch_ahead must have all it's work submitted # We will record an event of all the work submitted on the main stream # and make sure no one overwrite the data until they are done - just_finished_slot = (slot - self.loader.batches_ahead) % (self.loader.batches_ahead + 2) + just_finished_slot = (slot - self.loader.batches_ahead - 1) % (self.loader.batches_ahead + 2) + # print("JFS", just_finished_slot) event = ch.cuda.Event() - event.record(ch.cuda.default_stream()) + event.record(self.current_stream) events[just_finished_slot] = event b_ix += 1 @@ -104,7 +112,6 @@ def run(self): self.output_queue.put(None) def run_pipeline(self, b_ix, batch_indices, batch_slot, cuda_event): - # print(b_ix, batch_indices) self.memory_context.start_batch(b_ix) args = [] if IS_CUDA: @@ -114,28 +121,35 @@ def run_pipeline(self, b_ix, batch_indices, batch_slot, cuda_event): ctx = nullcontext() first_stage = False + + code, outputs = self.loader.code with ctx: if IS_CUDA: if cuda_event: cuda_event.wait() - for stage, banks in self.memory_bank_per_stage.items(): - args.insert(0, batch_indices) - for bank in banks: - if bank is not None: - if isinstance(bank, tuple): - bank = tuple(x[batch_slot] for x in bank) - else: - bank = bank[batch_slot] - args.append(bank) - args.append(self.metadata) - args.append(self.storage_state) - code = self.loader.code_per_stage[stage] - result = code(*args) - args = list(result) - if first_stage: - first_stage = False - self.memory_context.end_batch(b_ix) - return tuple(x[:len(batch_indices)] for x in args) + + args = { + 'batch_indices': batch_indices, + 'storage_state': self.storage_state, + 'metadata': self.metadata, + **{ + f'memory_{k}':select_buffer(v, batch_slot, len(batch_indices)) + for (k, v) in self.memory_allocations['operation'].items() + }, + 
**{ + f'shared_memory_{k}': select_buffer(v, batch_slot, len(batch_indices)) + for (k, v) in self.memory_allocations['shared'].items() + } + } + + for stage_code, define_outputs in code: + results = stage_code(**args) + for node_id, result in zip(define_outputs, results): + args[f'result_{node_id}'] = result + pass + + result = tuple(args[f'result_{x}'] for x in outputs) + return result def __next__(self): result = self.output_queue.get() @@ -146,7 +160,7 @@ def __next__(self): if IS_CUDA: stream = self.cuda_streams[slot] # We wait for the copy to be done - ch.cuda.current_stream().wait_stream(stream) + self.current_stream.wait_stream(stream) return result def __iter__(self): diff --git a/ffcv/loader/loader.py b/ffcv/loader/loader.py index 21cdd104..6e4240ea 100644 --- a/ffcv/loader/loader.py +++ b/ffcv/loader/loader.py @@ -8,7 +8,9 @@ from re import sub from typing import Any, Callable, Mapping, Sequence, Type, Union, Literal from collections import defaultdict +from collections.abc import Collection from enum import Enum, unique, auto + from ffcv.fields.base import Field import torch as ch @@ -18,11 +20,9 @@ from ..reader import Reader from ..traversal_order.base import TraversalOrder from ..traversal_order import Random, Sequential, QuasiRandom -from ..pipeline import Pipeline -from ..pipeline.compiler import Compiler +from ..pipeline import Pipeline, PipelineSpec, Compiler from ..pipeline.operation import Operation -from ..transforms.ops import ToTensor -from ..transforms.module import ModuleWrapper +from ..pipeline.graph import Graph from ..memory_managers import ( ProcessCacheManager, OSCacheManager, MemoryManager ) @@ -63,8 +63,8 @@ class Loader: Number of workers used for data loading. Consider using the actual number of cores instead of the number of threads if you only use JITed augmentations as they usually don't benefit from hyper-threading. os_cache : bool Leverages the operating for caching purposes. 
This is beneficial when there is enough memory to cache the dataset and/or when multiple processes on the same machine training using the same dataset. See https://docs.ffcv.io/performance_guide.html for more information. - order : OrderOption - Traversal order, one of: SEQEUNTIAL, RANDOM, QUASI_RANDOM + order : Union[OrderOption, TraversalOrder] + Traversal order, one of: SEQEUNTIAL, RANDOM, QUASI_RANDOM, or a custom TraversalOrder QUASI_RANDOM is a random order that tries to be as uniform as possible while minimizing the amount of data read from the disk. Note that it is mostly useful when `os_cache=False`. Currently unavailable in distributed mode. distributed : bool @@ -91,7 +91,7 @@ def __init__(self, batch_size: int, num_workers: int = -1, os_cache: bool = DEFAULT_OS_CACHE, - order: ORDER_TYPE = OrderOption.SEQUENTIAL, + order: Union[ORDER_TYPE, TraversalOrder] = OrderOption.SEQUENTIAL, distributed: bool = False, seed: int = None, # For ordering of samples indices: Sequence[int] = None, # For subset selection @@ -135,7 +135,7 @@ def __init__(self, self.num_workers: int = num_workers self.drop_last: bool = drop_last self.distributed: bool = distributed - self.code_per_stage = None + self.code = None self.recompile = recompile if self.num_workers < 1: @@ -154,49 +154,61 @@ def __init__(self, self.memory_manager: MemoryManager = ProcessCacheManager( self.reader) - self.traversal_order: TraversalOrder = ORDER_MAP[order](self) + if order in ORDER_MAP: + self.traversal_order: TraversalOrder = ORDER_MAP[order](self) + elif isinstance(order, TraversalOrder): + self.traversal_order: TraversalOrder = order(self) + else: + raise ValueError(f"Order {order} is not a supported order type or a subclass of TraversalOrder") memory_read = self.memory_manager.compile_reader() self.next_epoch: int = 0 self.pipelines = {} + self.pipeline_specs = {} self.field_name_to_f_ix = {} - + + custom_pipeline_specs = {} + + # Creating PipelineSpec objects from the pipeline dict passed + # 
by the user + for output_name, spec in pipelines.items(): + if isinstance(spec, PipelineSpec): + pass + elif isinstance(spec, Sequence): + spec = PipelineSpec(output_name, decoder=None, transforms=spec) + elif spec is None: + continue # This is a disabled field + else: + msg = f"The pipeline for {output_name} has to be " + msg += f"either a PipelineSpec or a sequence of operations" + raise ValueError(msg) + custom_pipeline_specs[output_name] = spec + + # Adding the default pipelines for f_ix, (field_name, field) in enumerate(self.reader.handlers.items()): self.field_name_to_f_ix[field_name] = f_ix - DecoderClass = field.get_decoder_class() - try: - operations = pipelines[field_name] - # We check if the user disabled this field - if operations is None: - continue - if not isinstance(operations[0], DecoderClass): - msg = "The first operation of the pipeline for " - msg += f"'{field_name}' has to be a subclass of " - msg += f"{DecoderClass}" - raise ValueError(msg) - except KeyError: - try: - operations = [ - DecoderClass(), - ToTensor() - ] - except Exception: - msg = f"Impossible to create a default pipeline" - msg += f"{field_name}, please define one manually" - raise ValueError(msg) - - for i, op in enumerate(operations): - assert isinstance(op, (ch.nn.Module, Operation)), op - if isinstance(op, ch.nn.Module): - operations[i] = ModuleWrapper(op) - - for op in operations: - op.accept_field(field) - op.accept_globals(self.reader.metadata[f'f{f_ix}'], - memory_read) - - self.pipelines[field_name] = Pipeline(operations) + + if field_name not in custom_pipeline_specs: + # We add the default pipeline + if field_name not in pipelines: + self.pipeline_specs[field_name] = PipelineSpec(field_name) + else: + self.pipeline_specs[field_name] = custom_pipeline_specs[field_name] + + # We add the custom fields after the default ones + # This is to preserve backwards compatibility and make sure the order + # is intuitive + for field_name, spec in custom_pipeline_specs.items(): + 
if field_name not in self.pipeline_specs: + self.pipeline_specs[field_name] = spec + + self.graph = Graph(self.pipeline_specs, self.reader.handlers, + self.field_name_to_f_ix, self.reader.metadata, + memory_read) + + self.generate_code() + self.first_traversal_order = self.next_traversal_order() def next_traversal_order(self): return self.traversal_order.sample_order(self.next_epoch) @@ -208,7 +220,7 @@ def __iter__(self): self.next_epoch += 1 # Compile at the first epoch - if self.code_per_stage is None or self.recompile: + if self.code is None or self.recompile: self.generate_code() return EpochIterator(self, selected_order) @@ -251,123 +263,16 @@ def filter(self, field_name:str, condition: Callable[[Any], bool]) -> 'Loader': def __len__(self): - next_order = self.next_traversal_order() + next_order = self.first_traversal_order if self.drop_last: return len(next_order) // self.batch_size else: return int(np.ceil(len(next_order) / self.batch_size)) - def generate_function_call(self, pipeline_name, op_id, needs_indices): - p_ix = self.field_name_to_f_ix[pipeline_name] - pipeline_identifier = f'code_{pipeline_name}_{op_id}' - memory_identifier = f'memory_{pipeline_name}_{op_id}' - result_identifier = f'result_{pipeline_name}' - - arg_id = result_identifier - # This is the decoder so we pass the indices instead of the previous - # result - if op_id == 0: - arg_id = 'batch_indices' - - tree = ast.parse(f""" -{result_identifier} = {pipeline_identifier}({arg_id}, {memory_identifier}) - """).body[0] - - # This is the first call of the pipeline, we pass the metadata and - # storage state - if op_id == 0: - tree.value.args.extend([ - ast.Subscript(value=ast.Name(id='metadata', ctx=ast.Load()), - slice=ast.Index(value=ast.Constant(value=f'f{p_ix}', kind=None)), ctx=ast.Load()), - ast.Name(id='storage_state', ctx=ast.Load()), - ]) - if needs_indices: - tree.value.args.extend([ - ast.Name(id='batch_indices', ctx=ast.Load()), - ]) - return tree - - def 
generate_stage_code(self, stage, stage_ix, functions): - fun_name = f'stage_{stage_ix}' - base_code = ast.parse(f""" -def {fun_name}(): - pass - """).body[0] - - function_calls = [] - memory_banks = [] - memory_banks_id = [] - for p_ix, pipeline_name, op_id, needs_indices in stage: - function_calls.append(self.generate_function_call(pipeline_name, - op_id, needs_indices)) - arg = ast.arg(arg=f'memory_{pipeline_name}_{op_id}') - memory_banks.append(arg) - memory_banks_id.append((pipeline_name, op_id)) - - base_code.body.pop() - base_code.body.extend(function_calls) - - return_tuple = ast.Return(value=ast.Tuple(elts=[], ctx=ast.Load())) - - base_code.args.args.append(ast.arg(arg='batch_indices')) - - for p_id in self.pipelines.keys(): - r = f'result_{p_id}' - if stage_ix != 0: - base_code.args.args.append(ast.arg(arg=r)) - return_tuple.value.elts.append(ast.Name(id=r, ctx=ast.Load())) - - - base_code.body.append(return_tuple) - base_code.args.args.extend(memory_banks) - base_code.args.args.append(ast.arg(arg='metadata')) - base_code.args.args.append(ast.arg(arg='storage_state')) - - module = ast.fix_missing_locations( - ast.Module(body=[base_code], - type_ignores=[]) - ) - namespace = { - **functions, - } - - exec(compile(module, '', 'exec'), namespace) - final_code = namespace[fun_name] - if stage_ix % 2 == 0: - final_code = Compiler.compile(final_code) - return final_code, memory_banks_id def generate_code(self): - schedule = defaultdict(lambda: []) - compiled_functions = {} - for p_ix, (p_id, p) in enumerate(self.pipelines.items()): - stage = 0 - for jitted_block, block_content in p.operation_blocks: - # Even stages are jitted Odds are not - # If this doesn't match for this pipeline we - # shift the operations - if 1 - jitted_block % 2 != stage % 2: - stage += 1 - for op in block_content: - ops_code = p.compiled_ops[op] - - needs_indices = False - if hasattr(ops_code, 'with_indices'): - needs_indices = ops_code.with_indices - - if stage % 2 == 0: - ops_code = 
Compiler.compile(ops_code) - compiled_functions[f'code_{p_id}_{op}'] = ops_code - schedule[stage].append((p_ix, p_id, op, needs_indices)) - stage += 1 - - memory_bank_keys_per_stage = {} - self.code_per_stage = {} - for stage_ix, stage in schedule.items(): - code_for_stage, mem_banks_ids = self.generate_stage_code(stage, stage_ix, - compiled_functions) - self.code_per_stage[stage_ix] = code_for_stage - memory_bank_keys_per_stage[stage_ix] = mem_banks_ids - - self.memory_bank_keys_per_stage = memory_bank_keys_per_stage + queries, code = self.graph.collect_requirements() + self.code = self.graph.codegen_all(code) + + diff --git a/ffcv/pipeline/__init__.py b/ffcv/pipeline/__init__.py index 92eeaa47..444e71d9 100644 --- a/ffcv/pipeline/__init__.py +++ b/ffcv/pipeline/__init__.py @@ -1,3 +1,5 @@ from .pipeline import Pipeline +from .pipeline_spec import PipelineSpec +from .compiler import Compiler -__all__ = ['Pipeline'] \ No newline at end of file +__all__ = ['Pipeline', 'PipelineSpec', 'Compiler'] \ No newline at end of file diff --git a/ffcv/pipeline/allocation_query.py b/ffcv/pipeline/allocation_query.py index fd12b735..81a7f725 100644 --- a/ffcv/pipeline/allocation_query.py +++ b/ffcv/pipeline/allocation_query.py @@ -12,4 +12,31 @@ class AllocationQuery: device: Optional[ch.device] = None -Allocation = Union[AllocationQuery, Sequence[AllocationQuery]] \ No newline at end of file +Allocation = Union[AllocationQuery, Sequence[AllocationQuery]] + +def allocate_query(memory_allocation: AllocationQuery, batch_size: int, batches_ahead: int): + # We compute the total amount of memory needed for this + # operation + final_shape = [batches_ahead, + batch_size, *memory_allocation.shape] + if isinstance(memory_allocation.dtype, ch.dtype): + result = [] + for _ in range(final_shape[0]): + partial = ch.empty(*final_shape[1:], + dtype=memory_allocation.dtype, + device=memory_allocation.device) + try: + partial = partial.pin_memory() + except: + pass + result.append(partial) + 
else: + ch_dtype = ch.from_numpy(np.empty(0, dtype=memory_allocation.dtype)).dtype + result = ch.empty(*final_shape, + dtype=ch_dtype) + try: + result = result.pin_memory() + except: + pass + result = result.numpy() + return result \ No newline at end of file diff --git a/ffcv/pipeline/graph.py b/ffcv/pipeline/graph.py new file mode 100644 index 00000000..05da7cee --- /dev/null +++ b/ffcv/pipeline/graph.py @@ -0,0 +1,488 @@ +from distutils.log import warn +import warnings +import ast + +try: + # Useful for debugging + import astor +except ImportError: + pass + +from collections import defaultdict +from typing import Callable, Dict, List, Optional, Sequence, Set +from abc import ABC, abstractmethod +from ffcv.pipeline.allocation_query import AllocationQuery + +from ffcv.pipeline.pipeline_spec import PipelineSpec +from ffcv.pipeline.compiler import Compiler +from ffcv.pipeline.allocation_query import allocate_query +from .operation import Operation +from ..transforms import ModuleWrapper +from .state import State + +import torch as ch +import numpy as np + +# This is the starting state of the pipeline +INITIAL_STATE = State(jit_mode=True, + device=ch.device('cpu'), + dtype=np.dtype('u1'), + shape=None) + + +class Node(ABC): + last_node_id: int = 0 + def __init__(self): + self.id = Node.last_node_id + self._code = None + Node.last_node_id += 1 + + @property + @abstractmethod + def is_jitted(self): + raise NotImplemented() + + @property + @abstractmethod + def parent(self): + raise NotImplemented() + + @property + @abstractmethod + def arg_id(self): + raise NotImplemented() + + @property + @abstractmethod + def result_id(self): + raise NotImplemented() + + @property + @abstractmethod + def result_id(self): + raise NotImplemented() + + def get_shared_code_ast(self, done_ops): + return ast.Pass() + + @abstractmethod + def generate_code(self): + raise NotImplemented() + + def recompile(self): + self._code = self.generate_code() + + @property + def with_indices(self): + 
try: + return self.code.with_indices + except: + return False + + @property + def code(self): + if self._code is None: + self.recompile() + + return self._code + + @property + def func_call_ast(self): + pipeline_identifier = f'code_{self.id}' + memory_identifier = f'memory_{self.id}' + + tree = ast.parse(f""" +{self.result_id} = {pipeline_identifier}({self.arg_id}, {memory_identifier}) + """).body[0] + + if self.with_indices: + tree.value.args.extend([ + ast.Name(id='batch_indices', ctx=ast.Load()), + ]) + return tree + + +class DecoderNode(Node): + def __init__(self, field_name:str, decoder: Operation, f_ix:int): + super().__init__() + self.field_name = field_name + self.decoder = decoder + self.f_ix = f_ix + + @property + def is_jitted(self): + # Decoder have to jitted + return True + + @property + def parent(self): + return None + + @property + def arg_id(self): + return 'batch_indices' + + @property + def result_id(self): + return f"result_{self.id}" + + def generate_code(self): + return self.decoder.generate_code() + + @property + def func_call_ast(self): + tree = super().func_call_ast + tree.value.args.extend([ + ast.Subscript(value=ast.Name(id='metadata', ctx=ast.Load()), + slice=ast.Index(value=ast.Constant(value=f'f{self.f_ix}', kind=None)), ctx=ast.Load()), + ast.Name(id='storage_state', ctx=ast.Load()), + ]) + + return tree + + +class TransformNode(Node): + def __init__(self, parent:Node, operation: Operation): + super().__init__() + self._parent = parent + self.operation = operation + self.jitted = True + + def __repr__(self): + return f'TransformerNode({self.operation})' + + def generate_code(self): + return self.operation.generate_code() + + @property + def parent(self): + return self._parent + + @property + def is_jitted(self): + # Decoder have to jitted + return self.jitted + + @property + def arg_id(self): + return self.parent.result_id + + @property + def result_id(self): + return f"result_{self.id}" + + def get_shared_code_ast(self, done_ops): + 
if self.operation in done_ops: + return ast.Pass() + + done_ops[self.operation] = self.id + + pipeline_identifier = f'init_shared_state_code_{self.id}' + memory_identifier = f'shared_memory_{self.id}' + + tree = ast.parse(f"""{pipeline_identifier}({memory_identifier})""").body[0] + + + return tree + + +class RefNode(Node): + def __init__(self, ref_operation: Operation): + super().__init__() + self.ref_operation = ref_operation + self._parent = None + + def resolve_operation(self, operation_to_node: Dict[Operation, List[Node]]): + entries = operation_to_node[self.ref_operation] + if not entries: + raise ValueError(f"{self.ref_operation} not found in other pipelines") + if len(entries) > 1: + raise ValueError(f"Reference to {self.ref_operation} ambiguous") + + self._parent = entries[0] + + @property + def parent(self): + assert self._parent is not None + return self._parent + + @property + def is_jitted(self): + # RefNodes can be either jitted or not, + # whatever produces smaller pipelines + return None + + @property + def arg_id(self): + return None # Ref's don't have arguments + + def generate_code(self): + def nop(*args, **kwargs): + return None + + @property + def func_call_ast(self): + return ast.Pass() + + @property + def result_id(self): + return self.parent.result_id + + +class Graph: + + def __init__(self, pipeline_specs: Dict[str, PipelineSpec], handlers, + fieldname_to_fix, metadata, memory_read): + + self.memory_read = memory_read + self.handlers = handlers + self.fieldname_to_fix = fieldname_to_fix + self.metadata = metadata + self.pipeline_specs = pipeline_specs + self.nodes: List[Node] = [] + self.root_nodes: Dict[Node, str] = {} + self.leaf_nodes: Dict[str, Node] = {} + self.operation_to_node = defaultdict(list) + self.id_to_node = {} + self.node_to_id = {} + + # Filling the default decoders + for output_name, spec in pipeline_specs.items(): + if spec.source in self.handlers: + field = self.handlers[spec.source] + Decoder = field.get_decoder_class() 
+ spec.accept_decoder(Decoder, output_name) + + # registering nodes + for output_name, spec in pipeline_specs.items(): + if spec.source is None: + raise ValueError(f"Field {output_name} has no source") + + source = spec.source + # This pipeline starts with a decoder + if isinstance(source, str): + assert spec.decoder is not None + node = DecoderNode(source, spec.decoder, fieldname_to_fix[source]) + self.operation_to_node[spec.decoder].append(node) + self.root_nodes[node] = source + else: + node = RefNode(source) + assert spec.decoder is None + + self.nodes.append(node) + + for operation in spec.transforms: + node = TransformNode(node, operation) + self.operation_to_node[operation].append(node) + self.nodes.append(node) + + self.leaf_nodes[output_name] = node + + # resolve references + for node in self.nodes: + if isinstance(node, RefNode): + node.resolve_operation(self.operation_to_node) + + # Filling the adjacency list + self.adjacency_list = defaultdict(list) + for node in self.nodes: + self.id_to_node[node.id] = node + self.node_to_id[node] = node.id + if node.parent is not None: + self.adjacency_list[node.parent].append(node) + + + def collect_requirements(self, state=INITIAL_STATE, + current_node: Node = None, + allocations: Dict[int, Optional[AllocationQuery]] = None, + code: Dict[int, Optional[Callable]] = None, + source_field:str = None): + + if allocations is None: + allocations: Dict[int, Optional[AllocationQuery]] = { + 'shared': {}, + 'operation': {} + } + if code is None: + code: Dict[int, Optional[Callable]] = { + 'shared': {}, + 'operation': {} + } + next_state = state + if current_node is None: + next_nodes = self.root_nodes.keys() + else: + if not isinstance(current_node, RefNode): + if isinstance(current_node, TransformNode): + operation = current_node.operation + else: + operation = current_node.decoder + + if isinstance(current_node, DecoderNode): + source_field = current_node.field_name + + fix = self.fieldname_to_fix[source_field] + metadata = 
self.metadata[f'f{fix}'] + + operation.accept_field(self.handlers[source_field]) + operation.accept_globals(metadata, self.memory_read) + + next_state, allocation = operation.declare_state_and_memory(state) + state_allocation = operation.declare_shared_memory(state) + + if next_state.device.type != 'cuda' and isinstance(operation, + ModuleWrapper): + msg = ("Using a pytorch transform on the CPU is extremely" + "detrimental to the performance, consider moving the augmentation" + "on the GPU or using an FFCV native transform") + warnings.warn(msg, ResourceWarning) + + + if isinstance(current_node, TransformNode): + current_node.jitted = next_state.jit_mode + + allocations['operation'][current_node.id] = allocation + allocations['shared'][current_node.id] = state_allocation + code['operation'][current_node.id] = operation.generate_code() + code['shared'][current_node.id] = operation.generate_code_for_shared_state() + + next_nodes = self.adjacency_list[current_node] + + for node in next_nodes: + self.collect_requirements(next_state, node, allocations, code, source_field=source_field) + + return allocations, code + + def allocate_memory(self, batch_size, batches_ahead): + + memory_buffers = defaultdict(dict) + full_memory_requirements, _ = self.collect_requirements() + + for kind, requirements in full_memory_requirements.items(): + for node_id, memory_allocation in requirements.items(): + # If the operation didn't make a query we stop here + allocated_buffer = None + if isinstance(memory_allocation, AllocationQuery): + allocated_buffer = allocate_query(memory_allocation, + batch_size, + batches_ahead) + elif isinstance(memory_allocation, Sequence): + allocated_buffer = tuple( + allocate_query(q, batch_size, batches_ahead) for q in memory_allocation + ) + + memory_buffers[kind][node_id] = allocated_buffer + + return memory_buffers + + def group_operations(self): + current_front: Set[Node] = set() + next_front: Set[Node] = set() + stages = [] + + for node in 
self.root_nodes.keys(): + current_front.add(node) + + + while current_front: + current_stage = list() + jitted_stage = len(stages) % 2 == 0 + + while current_front: + node = current_front.pop() + if node.is_jitted == jitted_stage or node.is_jitted is None: + current_stage.append(self.node_to_id[node]) + current_front.update(set(self.adjacency_list[node])) + + else: + next_front.add(node) + + stages.append(current_stage) + current_front = next_front + + return stages + + def codegen_stage(self, stage:List[Node], s_ix:int, op_to_node, code, already_defined): + fun_name = f"stage_code_{s_ix}" + base_code = ast.parse(f""" +def {fun_name}(batch_indices, metadata, storage_state): + pass + """).body[0] + + + base_code.args.args.extend([ + ast.arg(arg=f'memory_{x}') for x in code['operation'] + ]) + + base_code.args.args.extend([ + ast.arg(arg=f'shared_memory_{x}') for x in code['shared'] + ]) + + base_code.args.args.extend([ + ast.arg(f'result_{x}') for x in already_defined + ]) + + return_tuple = ast.Return(value=ast.Tuple(elts=[], ctx=ast.Load())) + + defined_here = [] + + base_code.body.pop() + compiled_functions = {} + for node_id in stage: + node: Node = self.id_to_node[node_id] + has_shared_state = node_id in code['shared'] and code['shared'][node_id] is not None + + try: + compiled_functions[f'code_{node_id}'] = code['operation'][node_id] + except KeyError: + pass # No code for this node + + func_call_ast = node.func_call_ast + if has_shared_state: + fname = f'init_shared_state_code_{node_id}' + compiled_functions[fname] = code['shared'][node_id] + base_code.body.append(node.get_shared_code_ast(op_to_node)) + func_call_ast.value.args.extend([ + ast.Name(id=f'shared_memory_{op_to_node[node.operation]}', ctx=ast.Load()), + ]) + + base_code.body.append(func_call_ast) + return_tuple.value.elts.append(ast.Name(id=node.result_id, ctx=ast.Load())) + already_defined.append(node.id) + defined_here.append(node.id) + + # If the stage is even we are compiling it + if s_ix % 2 
== 0: + compiled_functions = {k: Compiler.compile(v) for (k, v) in compiled_functions.items()} + + base_code.body.append(return_tuple) + + module = ast.fix_missing_locations( + ast.Module(body=[base_code], + type_ignores=[]) + ) + + # print(astor.to_source(base_code)) + namespace = { + **compiled_functions + } + + exec(compile(module, '', 'exec'), namespace) + final_code = namespace[fun_name] + return final_code, defined_here + + + def codegen_all(self, code): + stages = self.group_operations() + code_stages = [] + already_defined = [] + + # Set of operations that already had their state initialized + # (We do not want to have their random state reset) + op_to_node = {} + + for s_ix, stage in enumerate(stages): + code_stages.append(self.codegen_stage(stage, s_ix, op_to_node, code, already_defined)) + + final_output = [x.id for x in self.leaf_nodes.values()] + return code_stages, final_output \ No newline at end of file diff --git a/ffcv/pipeline/operation.py b/ffcv/pipeline/operation.py index 0c763a99..8ad947e8 100644 --- a/ffcv/pipeline/operation.py +++ b/ffcv/pipeline/operation.py @@ -10,25 +10,32 @@ if TYPE_CHECKING: from ..fields.base import Field + class Operation(ABC): def __init__(self): self.metadata: np.ndarray = None self.memory_read: Callable[[np.uint64], np.ndarray] = None pass - - def accept_field(self, field:'Field'): + + def accept_field(self, field: 'Field'): self.field: 'Field' = field - + def accept_globals(self, metadata, memory_read): self.metadata = metadata self.memory_read = memory_read - + # Return the code to run this operation @abstractmethod def generate_code(self) -> Callable: raise NotImplementedError - + + def declare_shared_memory(self, previous_state: State) -> Optional[AllocationQuery]: + return None + + def generate_code_for_shared_state(self) -> Optional[Callable]: + return None + @abstractmethod - def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: + def 
declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: raise NotImplementedError diff --git a/ffcv/pipeline/pipeline_spec.py b/ffcv/pipeline/pipeline_spec.py new file mode 100644 index 00000000..71d66e75 --- /dev/null +++ b/ffcv/pipeline/pipeline_spec.py @@ -0,0 +1,51 @@ +import torch as ch + +from typing import List, Union +from .operation import Operation +from ..transforms.module import ModuleWrapper +from ..transforms import ToTensor + +class PipelineSpec: + + def __init__(self, source: Union[str, Operation], decoder: Operation = None, + transforms:List[Operation] = None ): + + self.source = source + self.decoder = decoder + if transforms is None: + transforms = [] + self.transforms = transforms + self.default_pipeline = (decoder is None + and not transforms + and isinstance(source, str)) + + def __repr__(self): + return repr((self.source, self.decoder, self.transforms)) + + def __str__(self): + return self.__repr__() + + def accept_decoder(self, Decoder, output_name): + if not isinstance(self.source, str) and self.decoder is not None: + raise ValueError("Source can't be a node and also have a decoder") + + if Decoder is not None: + # The first element of the operations is a decoder + if self.transforms and isinstance(self.transforms[0], Decoder): + self.decoder = self.transforms.pop(0) + + elif self.decoder is None: + try: + self.decoder = Decoder() + except Exception: + msg = f"Impossible to use default decoder for {output_name}," + msg += "make sure you specify one in your pipeline." 
+ raise ValueError(msg) + + if self.default_pipeline: + self.transforms.append(ToTensor()) + + for i, op in enumerate(self.transforms): + if isinstance(op, ch.nn.Module): + self.transforms[i] = ModuleWrapper(op) + \ No newline at end of file diff --git a/ffcv/transforms/__init__.py b/ffcv/transforms/__init__.py index bc8fa321..2636a447 100644 --- a/ffcv/transforms/__init__.py +++ b/ffcv/transforms/__init__.py @@ -9,6 +9,7 @@ from .translate import RandomTranslate from .mixup import ImageMixup, LabelMixup, MixupToOneHot from .module import ModuleWrapper +from .color_jitter import RandomBrightness, RandomContrast, RandomSaturation __all__ = ['ToTensor', 'ToDevice', 'ToTorchImage', 'NormalizeImage', @@ -16,4 +17,5 @@ 'RandomResizedCrop', 'RandomHorizontalFlip', 'RandomTranslate', 'Cutout', 'ImageMixup', 'LabelMixup', 'MixupToOneHot', 'Poison', 'ReplaceLabel', - 'ModuleWrapper'] \ No newline at end of file + 'ModuleWrapper', + 'RandomBrightness', 'RandomContrast', 'RandomSaturation'] diff --git a/ffcv/transforms/color_jitter.py b/ffcv/transforms/color_jitter.py new file mode 100644 index 00000000..a79b72fd --- /dev/null +++ b/ffcv/transforms/color_jitter.py @@ -0,0 +1,139 @@ +''' +Random color operations similar to torchvision.transforms.ColorJitter except not supporting hue +Reference : https://github.com/pytorch/vision/blob/main/torchvision/transforms/functional_tensor.py +''' + +import numpy as np + +from dataclasses import replace +from ..pipeline.allocation_query import AllocationQuery +from ..pipeline.operation import Operation +from ..pipeline.state import State +from ..pipeline.compiler import Compiler + + + +class RandomBrightness(Operation): + ''' + Randomly adjust image brightness. Operates on raw arrays (not tensors). 
+ + Parameters + ---------- + magnitude : float + randomly choose brightness enhancement factor on [max(0, 1-magnitude), 1+magnitude] + p : float + probability to apply brightness + ''' + def __init__(self, magnitude: float, p=0.5): + super().__init__() + self.p = p + self.magnitude = magnitude + + def generate_code(self): + my_range = Compiler.get_iterator() + p = self.p + magnitude = self.magnitude + + def brightness(images, *_): + def blend(img1, img2, ratio): return (ratio*img1 + (1-ratio)*img2).clip(0, 255).astype(img1.dtype) + + apply_bright = np.random.rand(images.shape[0]) < p + magnitudes = np.random.uniform(max(0, 1-magnitude), 1+magnitude, images.shape[0]) + for i in my_range(images.shape[0]): + if apply_bright[i]: + images[i] = blend(images[i], 0, magnitudes[i]) + + return images + + brightness.is_parallel = True + return brightness + + def declare_state_and_memory(self, previous_state): + return (replace(previous_state, jit_mode=True), AllocationQuery(previous_state.shape, previous_state.dtype)) + + + +class RandomContrast(Operation): + ''' + Randomly adjust image contrast. Operates on raw arrays (not tensors). 
+ + Parameters + ---------- + magnitude : float + randomly choose contrast enhancement factor on [max(0, 1-magnitude), 1+magnitude] + p : float + probability to apply contrast + ''' + def __init__(self, magnitude, p=0.5): + super().__init__() + self.p = p + self.magnitude = magnitude + + def generate_code(self): + my_range = Compiler.get_iterator() + p = self.p + magnitude = self.magnitude + + def contrast(images, *_): + def blend(img1, img2, ratio): return (ratio*img1 + (1-ratio)*img2).clip(0, 255).astype(img1.dtype) + + apply_contrast = np.random.rand(images.shape[0]) < p + magnitudes = np.random.uniform(max(0, 1-magnitude), 1+magnitude, images.shape[0]) + for i in my_range(images.shape[0]): + if apply_contrast[i]: + r, g, b = images[i,:,:,0], images[i,:,:,1], images[i,:,:,2] + l_img = (0.2989 * r + 0.587 * g + 0.114 * b).astype(images[i].dtype) + images[i] = blend(images[i], l_img.mean(), magnitudes[i]) + + return images + + contrast.is_parallel = True + return contrast + + def declare_state_and_memory(self, previous_state): + return (replace(previous_state, jit_mode=True), AllocationQuery(previous_state.shape, previous_state.dtype)) + + + +class RandomSaturation(Operation): + ''' + Randomly adjust image color balance. Operates on raw arrays (not tensors). 
+ + Parameters + ---------- + magnitude : float + randomly choose color balance enhancement factor on [max(0, 1-magnitude), 1+magnitude] + p : float + probability to apply saturation + ''' + def __init__(self, magnitude, p=0.5): + super().__init__() + self.p = p + self.magnitude = magnitude + + def generate_code(self): + my_range = Compiler.get_iterator() + p = self.p + magnitude = self.magnitude + + def saturation(images, *_): + def blend(img1, img2, ratio): return (ratio*img1 + (1-ratio)*img2).clip(0, 255).astype(img1.dtype) + + apply_saturation = np.random.rand(images.shape[0]) < p + magnitudes = np.random.uniform(max(0, 1-magnitude), 1+magnitude, images.shape[0]) + for i in my_range(images.shape[0]): + if apply_saturation[i]: + r, g, b = images[i,:,:,0], images[i,:,:,1], images[i,:,:,2] + l_img = (0.2989 * r + 0.587 * g + 0.114 * b).astype(images[i].dtype) + l_img3 = np.zeros_like(images[i]) + for j in my_range(images[i].shape[-1]): + l_img3[:,:,j] = l_img + images[i] = blend(images[i], l_img3, magnitudes[i]) + + return images + + saturation.is_parallel = True + return saturation + + def declare_state_and_memory(self, previous_state): + return (replace(previous_state, jit_mode=True), AllocationQuery(previous_state.shape, previous_state.dtype)) diff --git a/ffcv/transforms/ops.py b/ffcv/transforms/ops.py index cec29084..2b8bad5e 100644 --- a/ffcv/transforms/ops.py +++ b/ffcv/transforms/ops.py @@ -42,6 +42,8 @@ class ToDevice(Operation): def __init__(self, device, non_blocking=True): super().__init__() self.device = device + # assert isinstance(device, ch.device), \ + # f'Make sure device is a ch.device (not a {type(device)})' self.non_blocking = non_blocking def generate_code(self) -> Callable: @@ -155,4 +157,4 @@ def convert(inp, dst): return convert def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]: - return replace(previous_state, dtype=self.target_dtype, jit_mode=False), None \ No newline at end of file + 
return replace(previous_state, dtype=self.target_dtype, jit_mode=False), None diff --git a/ffcv/transforms/random_resized_crop.py b/ffcv/transforms/random_resized_crop.py index 24d2403a..5a7405c5 100644 --- a/ffcv/transforms/random_resized_crop.py +++ b/ffcv/transforms/random_resized_crop.py @@ -30,17 +30,24 @@ def __init__(self, scale: Tuple[float, float], ratio: Tuple[float, float], size: self.scale = scale self.ratio = ratio self.size = size + def generate_code(self) -> Callable: - scale, ratio = np.array(self.scale), np.array(self.ratio) + scale, ratio = self.scale, self.ratio + if isinstance(scale, tuple): + scale = np.array(scale) + if isinstance(ratio, tuple): + ratio = np.array(ratio) my_range = Compiler.get_iterator() - def random_resized_crop(im, dst): - n, h, w, _ = im.shape - for ind in my_range(n): - i, j, c_h, c_w = fast_crop.get_random_crop(h, w, scale, ratio) - fast_crop.resize_crop(im[ind], i, i + c_h, j, j + c_w, dst[ind]) + def random_resized_crop(images, dst): + for idx in my_range(images.shape[0]): + i, j, h, w = fast_crop.get_random_crop(images[idx].shape[0], + images[idx].shape[1], + scale, + ratio) + fast_crop.resize_crop(images[idx], i, i + h, j, j + w, dst[idx]) return dst - + random_resized_crop.is_parallel = True return random_resized_crop diff --git a/ffcv/transforms/translate.py b/ffcv/transforms/translate.py index a40890b6..f9efebf8 100644 --- a/ffcv/transforms/translate.py +++ b/ffcv/transforms/translate.py @@ -30,12 +30,14 @@ def __init__(self, padding: int, fill: Tuple[int, int, int] = (0, 0, 0)): def generate_code(self) -> Callable: my_range = Compiler.get_iterator() pad = self.padding + fill = self.fill def translate(images, dst): n, h, w, _ = images.shape + dst[:] = fill + dst[:, pad:pad+h, pad:pad+w] = images for i in my_range(n): dst[i] = 0 - dst[i, pad:pad+h, pad:pad+w] = images[i] y_coord = randint(low=0, high=2 * pad + 1) x_coord = randint(low=0, high=2 * pad + 1) images[i] = dst[i, y_coord:y_coord+h, x_coord:x_coord+w] 
diff --git a/ffcv/transforms/utils/fast_crop.py b/ffcv/transforms/utils/fast_crop.py index 3b3f2af3..34cb7835 100644 --- a/ffcv/transforms/utils/fast_crop.py +++ b/ffcv/transforms/utils/fast_crop.py @@ -48,4 +48,4 @@ def get_center_crop(height, width, ratio): delta_h = (height - c) // 2 delta_w = (width - c) // 2 - return delta_h, delta_w, c, c \ No newline at end of file + return delta_h, delta_w, c, c diff --git a/ffcv/traversal_order/base.py b/ffcv/traversal_order/base.py index 301d3658..74f1a70b 100644 --- a/ffcv/traversal_order/base.py +++ b/ffcv/traversal_order/base.py @@ -13,7 +13,8 @@ def __init__(self, loader: 'Loader'): self.indices = self.loader.indices self.seed = self.loader.seed self.distributed = loader.distributed + self.sampler = None @abstractmethod def sample_order(self, epoch:int) -> Sequence[int]: - raise NotImplemented() \ No newline at end of file + raise NotImplemented() diff --git a/ffcv/types.py b/ffcv/types.py index 2b669ebc..3b2123f0 100644 --- a/ffcv/types.py +++ b/ffcv/types.py @@ -6,7 +6,8 @@ from .fields.base import Field from .fields import ( FloatField, IntField, RGBImageField, - BytesField, NDArrayField, JSONField + BytesField, NDArrayField, JSONField, + TorchTensorField ) CURRENT_VERSION = 2 @@ -49,7 +50,8 @@ 2 : RGBImageField, 3 : BytesField, 4 : NDArrayField, - 5 : JSONField + 5 : JSONField, + 6 : TorchTensorField } # Parse the fields descriptors from the header of the dataset diff --git a/ffcv/writer.py b/ffcv/writer.py index a783d954..1b70f74f 100644 --- a/ffcv/writer.py +++ b/ffcv/writer.py @@ -17,6 +17,7 @@ MIN_PAGE_SIZE = 1 << 21 # 2MiB, which is the most common HugePage size +MAX_PAGE_SIZE = 1 << 32 # Biggest page size that will not overflow uint32 def from_shard(shard, pipeline): # We import webdataset here so that it desn't crash if it's not required @@ -148,6 +149,8 @@ def __init__(self, fname: str, fields: Mapping[str, Field], raise ValueError(f'page_size isnt a power of 2') if page_size < MIN_PAGE_SIZE: raise 
ValueError(f"page_size can't be lower than{MIN_PAGE_SIZE}") + if page_size >= MAX_PAGE_SIZE: + raise ValueError(f"page_size can't be bigger(or =) than{MAX_PAGE_SIZE}") self.page_size = page_size @@ -330,8 +333,8 @@ def finalize(self, allocations) : # Retrieve all the allocations from the workers # Turn them into a numpy array try: - allocation_table = np.concatenate([ - np.array(x).view(ALLOC_TABLE_TYPE) for x in allocations if len(x) + allocation_table = np.array([ + np.array(x, dtype=ALLOC_TABLE_TYPE) for x in allocations if len(x) ]) except: allocation_table = np.array([]).view(ALLOC_TABLE_TYPE) diff --git a/libffcv/libffcv.cpp b/libffcv/libffcv.cpp index 7bae23ba..db4798d1 100644 --- a/libffcv/libffcv.cpp +++ b/libffcv/libffcv.cpp @@ -8,6 +8,13 @@ #include #include #include +#ifdef _WIN32 + typedef unsigned __int32 __uint32_t; + typedef unsigned __int64 __uint64_t; + #define EXPORT __declspec(dllexport) +#else + #define EXPORT +#endif extern "C" { // a key use to point to the tjtransform instance @@ -23,7 +30,7 @@ extern "C" { pthread_key_create(&key_tj_transformer, NULL); } - void resize(int64_t cresizer, int64_t source_p, int64_t sx, int64_t sy, + EXPORT void resize(int64_t cresizer, int64_t source_p, int64_t sx, int64_t sy, int64_t start_row, int64_t end_row, int64_t start_col, int64_t end_col, int64_t dest_p, int64_t tx, int64_t ty) { // TODO use proper arguments type @@ -34,16 +41,16 @@ extern "C" { dest_matrix, dest_matrix.size(), 0, 0, cv::INTER_AREA); } - void my_memcpy(void *source, void* dst, uint64_t size) { + EXPORT void my_memcpy(void *source, void* dst, uint64_t size) { memcpy(dst, source, size); } - void my_fread(int64_t fp, int64_t offset, void *destination, int64_t size) { + EXPORT void my_fread(int64_t fp, int64_t offset, void *destination, int64_t size) { fseek((FILE *) fp, offset, SEEK_SET); fread(destination, 1, size, (FILE *) fp); } - int imdecode(unsigned char *input_buffer, __uint64_t input_size, + EXPORT int imdecode(unsigned char 
*input_buffer, __uint64_t input_size, __uint32_t source_height, __uint32_t source_width, unsigned char *output_buffer, diff --git a/setup.py b/setup.py index f8c09970..d40aa2ab 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,9 @@ from setuptools import find_packages import subprocess +from difflib import get_close_matches from glob import glob +import os +import platform from distutils.core import setup, Extension @@ -10,12 +13,65 @@ long_description = (this_directory / "README.md").read_text() +def find_pkg_dirs(package): + close_matches = get_close_matches(package.lower(), + os.environ["PATH"].lower().split(';'), + cutoff=0) + dll_dir = None + for close_match in close_matches: + if (os.path.exists(close_match) + and glob(os.path.join(close_match, '*.dll'))): + dll_dir = close_match + break + if dll_dir is None: + raise Exception( + f"Could not find required package: {package}. " + "Add directory containing .dll files to system environment path." + ) + dll_dir_split = dll_dir.replace('\\', '/').split('/') + root = get_close_matches(package.lower(), dll_dir_split, cutoff=0)[0] + root_dir = '/'.join(dll_dir_split[:dll_dir_split.index(root) + 1]) + return os.path.normpath(root_dir), os.path.normpath(dll_dir) + + +def pkgconfig_windows(package, kw): + is_x64 = platform.machine().endswith('64') + root_dir, dll_dir = find_pkg_dirs(package) + include_dir = None + library_dir = None + parent = None + while parent != root_dir: + parent = os.path.dirname(dll_dir if parent is None else parent) + if include_dir is None and os.path.exists(os.path.join(parent, 'include')): + include_dir = os.path.join(parent, 'include') + library_dirs = set() + libraries = glob(os.path.join(parent, '**', 'lib', '**', '*.lib'), + recursive=True) + for library in libraries: + if ((is_x64 and 'x86' in library) + or (not is_x64 and 'x64' in library)): + continue + library_dirs.add(os.path.dirname(library)) + if library_dir is None and library_dirs: + library_dir = sorted(library_dirs)[-1] + if 
include_dir and library_dir: + libraries = [os.path.splitext(library)[0] + for library in glob(os.path.join(library_dir, '*.lib'))] + break + if not include_dir or not library_dir: + raise Exception(f"Could not find required package: {package}.") + kw.setdefault('include_dirs', []).append(include_dir) + kw.setdefault('library_dirs', []).append(library_dir) + kw.setdefault('libraries', []).extend(libraries) + return kw + + def pkgconfig(package, kw): flag_map = {'-I': 'include_dirs', '-L': 'library_dirs', '-l': 'libraries'} output = subprocess.getoutput( 'pkg-config --cflags --libs {}'.format(package)) if 'not found' in output: - raise Exception(f"Could not find required package: {package}.") + raise RuntimeError(f"Could not find required package: {package}.") for token in output.strip().split(): kw.setdefault(flag_map.get(token[:2]), []).append(token[2:]) return kw @@ -27,20 +83,29 @@ def pkgconfig(package, kw): 'sources': sources, 'include_dirs': [] } -extension_kwargs = pkgconfig('opencv4', extension_kwargs) -extension_kwargs = pkgconfig('libturbojpeg', extension_kwargs) +if platform.system() == 'Windows': + extension_kwargs = pkgconfig_windows('opencv4', extension_kwargs) + extension_kwargs = pkgconfig_windows('libturbojpeg', extension_kwargs) + + extension_kwargs = pkgconfig_windows('pthread', extension_kwargs) +else: + try: + extension_kwargs = pkgconfig('opencv4', extension_kwargs) + except RuntimeError: + extension_kwargs = pkgconfig('opencv', extension_kwargs) + extension_kwargs = pkgconfig('libturbojpeg', extension_kwargs) -extension_kwargs['libraries'].append('pthread') + extension_kwargs['libraries'].append('pthread') libffcv = Extension('ffcv._libffcv', **extension_kwargs) setup(name='ffcv', - version='0.0.3rc1', + version='1.0.0', description=' FFCV: Fast Forward Computer Vision ', author='MadryLab', - author_email='leclerc@mit.edu', + author_email='ffcv@mit.edu', url='https://github.com/libffcv/ffcv', license_files = ('LICENSE.txt',), 
packages=find_packages(), @@ -49,14 +114,11 @@ def pkgconfig(package, kw): ext_modules=[libffcv], install_requires=[ 'terminaltables', - 'pytorch_pfn_extras', - 'fastargs', - 'matplotlib', - 'scikit-learn', - 'pandas', - 'assertpy', - 'tqdm', - 'psutil', - 'webdataset', - ] - ) + 'pytorch_pfn_extras', + 'fastargs', + 'opencv-python', + 'assertpy', + 'tqdm', + 'psutil', + 'numba', + ]) diff --git a/tests/test_array_field.py b/tests/test_array_field.py index 6cfe57a7..85cd1b14 100644 --- a/tests/test_array_field.py +++ b/tests/test_array_field.py @@ -4,18 +4,20 @@ from assertpy.assertpy import assert_that from multiprocessing import cpu_count +import torch as ch from assertpy import assert_that import numpy as np from torch.utils.data import Dataset from ffcv import DatasetWriter -from ffcv.fields import IntField, NDArrayField +from ffcv.fields import IntField, NDArrayField, TorchTensorField from ffcv import Loader class DummyActivationsDataset(Dataset): - def __init__(self, n_samples, shape): + def __init__(self, n_samples, shape, is_ch=False): self.n_samples = n_samples self.shape = shape + self.is_ch = is_ch def __len__(self): return self.n_samples @@ -24,7 +26,11 @@ def __getitem__(self, index): if index >= self.n_samples: raise IndexError() np.random.seed(index) - return index, np.random.randn(*self.shape).astype('