Skip to content

Commit

Permalink
no more safe-copy
Browse files Browse the repository at this point in the history
  • Loading branch information
j042 committed Jan 24, 2024
1 parent eed9cac commit 8fd896e
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 159 deletions.
5 changes: 2 additions & 3 deletions gustaf/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def __init__(
vertices=None,
edges=None,
elements=None,
copy=True,
):
"""Edges. It has vertices and edges. Also known as lines.
Expand All @@ -94,7 +93,7 @@ def __init__(
vertices: (n, d) np.ndarray
edges: (n, 2) np.ndarray
"""
super().__init__(vertices=vertices, copy=copy)
super().__init__(vertices=vertices)

if edges is not None:
self.edges = edges
Expand Down Expand Up @@ -132,7 +131,7 @@ def edges(self, es):
self._logd("setting edges")

self._edges = helpers.data.make_tracked_array(
es, settings.INT_DTYPE, self.setter_copies
es, settings.INT_DTYPE, copy=False
)

# shape check
Expand Down
9 changes: 3 additions & 6 deletions gustaf/faces.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def __init__(
vertices=None,
faces=None,
elements=None,
copy=True,
):
"""Faces. It has vertices and faces. Faces could be triangles or
quadrilaterals.
Expand All @@ -107,7 +106,7 @@ def __init__(
vertices: (n, d) np.ndarray
faces: (n, 3) or (n, 4) np.ndarray
"""
super().__init__(vertices=vertices, copy=copy)
super().__init__(vertices=vertices)
if faces is not None:
self.faces = faces

Expand Down Expand Up @@ -176,9 +175,7 @@ def whatareyou(cls, face_obj):
)

@property
def faces(
self,
):
def faces(self):
"""Returns faces.
Parameters
Expand Down Expand Up @@ -209,7 +206,7 @@ def faces(self, fs):
self._faces = helpers.data.make_tracked_array(
fs,
settings.INT_DTYPE,
self.setter_copies,
copy=False,
)
# shape check
if fs is not None:
Expand Down
197 changes: 105 additions & 92 deletions gustaf/helpers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,52 +13,66 @@


class TrackedArray(np.ndarray):
"""Taken from nice implementations of `trimesh` (see LICENSE.txt).
`https://github.com/mikedh/trimesh/blob/main/trimesh/caching.py`. Minor
adaption, since we don't have hashing functionalities.
All the inplace functions will set modified flag and if some operations
has potential to cause un-trackable behavior, writeable flags will be set
to False.
Note, if you really really want, it is possible to change the tracked
array without setting modified flag.
"""numpy array object that keeps mirroring inplace changes to the source.
Meant to help control_points.
"""

__slots__ = ("_modified", "_source")
__slots__ = (
"_super_arr",
"_modified",
)

def __array_finalize__(self, obj):
"""Sets default flags for any arrays that maybe generated based on
tracked array."""
physical space array. For more information,
see https://numpy.org/doc/stable/user/basics.subclassing.html"""
self._super_arr = None
self._modified = True
self._source = 0

# for arrays created based on this subclass
if isinstance(obj, type(self)):
if isinstance(obj._source, int):
self._source = obj
else:
self._source = obj._source
# this is copy. nothing to worry here
if self.base is None:
return None

# first child array
if self.base is obj:
# make sure this is not a recursively born child
# for example, `arr[[1,2]][:,2]`
# we should have set _super_arr to True
# if we made this array using `make_tracked_array`
if obj._super_arr is True:
self._super_arr = obj
return None

# multi generation child array
if obj._super_arr is not None and self.base is obj.base:
self._super_arr = obj._super_arr
return None

return None

@property
def mutable(self):
return self.flags["WRITEABLE"]
def modified(self):
"""
Modified flag getter
"""
# have super arr and self is not super_arr,
if self._super_arr is not None and self._super_arr is not True:
return self._super_arr._modified

@mutable.setter
def mutable(self, value):
self.flags.writeable = value
return self._modified

def _set_modified(self):
"""set modified flags to itself and to the source."""
self._modified = True
if isinstance(self._source, type(self)):
self._source._modified = True

def copy(self, *_args, **_kwargs):
"""copy gives np.ndarray.
@modified.setter
def modified(self, m):
if self._super_arr is not None and self._super_arr is not True:
self._super_arr._modified = m
else:
self._modified = m

no more tracking.
"""
return np.array(self, copy=True)
def copy(self, *args, **kwargs):
"""copy creates regular numpy array"""
return np.array(self, *args, copy=True, **kwargs)

def view(self, *args, **kwargs):
"""Set writeable flags to False for the view."""
Expand All @@ -67,89 +81,88 @@ def view(self, *args, **kwargs):
return v

def __iadd__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__iadd__(*args, **kwargs)
sr = super(self.__class__, self).__iadd__(*args, **kwargs)
self.modified = True
return sr

def __isub__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__isub__(*args, **kwargs)
sr = super(self.__class__, self).__isub__(*args, **kwargs)
self.modified = True
return sr

def __imul__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__imul__(*args, **kwargs)
sr = super(self.__class__, self).__imul__(*args, **kwargs)
self.modified = True
return sr

def __idiv__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__idiv__(*args, **kwargs)
sr = super(self.__class__, self).__idiv__(*args, **kwargs)
self.modified = True
return sr

def __itruediv__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__itruediv__(*args, **kwargs)
sr = super(self.__class__, self).__itruediv__(*args, **kwargs)
self.modified = True
return sr

def __imatmul__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__imatmul__(*args, **kwargs)
sr = super(self.__class__, self).__imatmul__(*args, **kwargs)
self.modified = True
return sr

def __ipow__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__ipow__(*args, **kwargs)
sr = super(self.__class__, self).__ipow__(*args, **kwargs)
self.modified = True
return sr

def __imod__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__imod__(*args, **kwargs)
sr = super(self.__class__, self).__imod__(*args, **kwargs)
self.modified = True
return sr

def __ifloordiv__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__ifloordiv__(*args, **kwargs)
sr = super(self.__class__, self).__ifloordiv__(*args, **kwargs)
self.modified = True
return sr

def __ilshift__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__ilshift__(*args, **kwargs)
sr = super(self.__class__, self).__ilshift__(*args, **kwargs)
self.modified = True
return sr

def __irshift__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__irshift__(*args, **kwargs)
sr = super(self.__class__, self).__irshift__(*args, **kwargs)
self.modified = True
return sr

def __iand__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__iand__(*args, **kwargs)
sr = super(self.__class__, self).__iand__(*args, **kwargs)
self.modified = True
return sr

def __ixor__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__ixor__(*args, **kwargs)
sr = super(self.__class__, self).__ixor__(*args, **kwargs)
self.modified = True
return sr

def __ior__(self, *args, **kwargs):
self._set_modified()
return super(self.__class__, self).__ior__(*args, **kwargs)

def __setitem__(self, *args, **kwargs):
self._set_modified()
super(self.__class__, self).__setitem__(*args, **kwargs)

def __setslice__(self, *args, **kwargs):
self._set_modified()
super(self.__class__, self).__setslice__(*args, **kwargs)
sr = super(self.__class__, self).__ior__(*args, **kwargs)
self.modified = True
return sr

def __getslice__(self, *args, **kwargs):
self._set_modified()
"""
return slices I am pretty sure np.ndarray does not have __*slice__
"""
slices = super(self.__class__, self).__getitem__(*args, **kwargs)
if isinstance(slices, np.ndarray):
slices.flags.writeable = False
return slices
def __setitem__(self, key, value):
# set first. invalid setting will cause error
sr = super(self.__class__, self).__setitem__(key, value)
self.modified = True
return sr


def make_tracked_array(array, dtype=None, copy=True):
"""Taken from nice implementations of `trimesh` (see LICENSE.txt).
"""Motivated by nice implementations of `trimesh` (see LICENSE.txt).
`https://github.com/mikedh/trimesh/blob/main/trimesh/caching.py`.
``Properly subclass a numpy ndarray to track changes.
Avoids some pitfalls of subclassing by forcing contiguous
arrays and does a view into a TrackedArray.``
Factory-like wrapper function for TrackedArray.
If you want to use TrackedArray, it is recommended to use this function.
Parameters
------------
Expand All @@ -168,16 +181,16 @@ def make_tracked_array(array, dtype=None, copy=True):
# if someone passed us None, just create an empty array
if array is None:
array = []
# make sure it is contiguous then view it as our subclass
tracked = np.ascontiguousarray(array, dtype=dtype)
tracked = (
tracked.copy().view(TrackedArray)
if copy
else tracked.view(TrackedArray)
)

# should always be contiguous here
assert tracked.flags["C_CONTIGUOUS"]
if copy:
array = np.array(array, dtype=dtype)
else:
array = np.asanyarray(array, dtype=dtype)

tracked = array.view(TrackedArray)

# this marks original array
tracked._super_arr = True

return tracked

Expand Down
Loading

0 comments on commit 8fd896e

Please sign in to comment.