Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions fasttransform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

# %% ../nbs/01_transform.ipynb 1
from typing import Any
import inspect

from fastcore.imports import *
from fastcore.foundation import *
Expand Down Expand Up @@ -63,7 +64,11 @@ def __call__(cls, *args, **kwargs):
if not hasattr(cls, nm): setattr(cls, nm, Function(f).dispatch(f))
else: getattr(cls,nm).dispatch(f)
return cls
return super().__call__(*args, **kwargs)
obj = super().__call__(*args, **kwargs)
# _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable
# instances of cls, fix it
if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)
return obj


def __new__(cls, name, bases, namespace):
Expand All @@ -73,6 +78,8 @@ def __new__(cls, name, bases, namespace):
funcs = [getattr(new_cls, nm)] + [getattr(b, nm,None) for b in bases]
funcs = [f for f in funcs if f]
if funcs: setattr(new_cls, nm, _merge_funcs(*funcs))
# _TfmMeta.__call__ shadows the signature of inheriting classes, set it back
new_cls.__signature__ = inspect.signature(new_cls.__init__)
return new_cls

# %% ../nbs/01_transform.ipynb 14
Expand Down Expand Up @@ -130,21 +137,21 @@ def _do_call(self, nm, *args, **kwargs):

add_docs(Transform, decode="Delegate to decodes to undo transform", setup="Delegate to setups to set up transform")

# %% ../nbs/01_transform.ipynb 155
# %% ../nbs/01_transform.ipynb 156
class InplaceTransform(Transform):
"A `Transform` that modifies in-place and just returns whatever it's passed"
def _call(self, fn, *args, split_idx=None, **kwargs):
super()._call(fn,*args, split_idx=split_idx, **kwargs)
return args[0]

# %% ../nbs/01_transform.ipynb 159
# %% ../nbs/01_transform.ipynb 160
class DisplayedTransform(Transform):
"A transform with a `__repr__` that shows its attrs"

@property
def name(self): return f"{super().name} -- {getattr(self,'__stored_args__',{})}\n"

# %% ../nbs/01_transform.ipynb 165
# %% ../nbs/01_transform.ipynb 166
class ItemTransform(Transform):
"A transform that always take tuples as items"
_retain = True
Expand All @@ -158,21 +165,21 @@ def _call1(self, x, name, **kwargs):
return retain_type(y, x, Any)


# %% ../nbs/01_transform.ipynb 174
# %% ../nbs/01_transform.ipynb 175
def get_func(t, name, *args, **kwargs):
"Get the `t.name` (potentially partial-ized with `args` and `kwargs`) or `noop` if not defined"
f = nested_callable(t, name)
return f if not (args or kwargs) else partial(f, *args, **kwargs)

# %% ../nbs/01_transform.ipynb 178
# %% ../nbs/01_transform.ipynb 179
class Func():
"Basic wrapper around a `name` with `args` and `kwargs` to call on a given type"
def __init__(self, name, *args, **kwargs): self.name,self.args,self.kwargs = name,args,kwargs
def __repr__(self): return f'sig: {self.name}({self.args}, {self.kwargs})'
def _get(self, t): return get_func(t, self.name, *self.args, **self.kwargs)
def __call__(self,t): return mapped(self._get, t)

# %% ../nbs/01_transform.ipynb 181
# %% ../nbs/01_transform.ipynb 182
class _Sig():
def __getattr__(self,k):
def _inner(*args, **kwargs): return Func(k, *args, **kwargs)
Expand All @@ -181,7 +188,7 @@ def _inner(*args, **kwargs): return Func(k, *args, **kwargs)
Sig = _Sig()


# %% ../nbs/01_transform.ipynb 187
# %% ../nbs/01_transform.ipynb 188
def compose_tfms(x, tfms, is_enc=True, reverse=False, **kwargs):
"Apply all `func_nm` attribute of `tfms` on `x`, maybe in `reverse` order"
if reverse: tfms = reversed(tfms)
Expand All @@ -191,13 +198,13 @@ def compose_tfms(x, tfms, is_enc=True, reverse=False, **kwargs):
return x


# %% ../nbs/01_transform.ipynb 192
# %% ../nbs/01_transform.ipynb 193
def mk_transform(f):
"Convert function `f` to `Transform` if it isn't already one"
f = instantiate(f)
return f if isinstance(f,(Transform,Pipeline)) else Transform(f)

# %% ../nbs/01_transform.ipynb 193
# %% ../nbs/01_transform.ipynb 194
def gather_attrs(o, k, nm):
"Used in __getattr__ to collect all attrs `k` from `self.{nm}`"
if k.startswith('_') or k==nm: raise AttributeError(k)
Expand All @@ -206,12 +213,12 @@ def gather_attrs(o, k, nm):
if not res: raise AttributeError(k)
return res[0] if len(res)==1 else L(res)

# %% ../nbs/01_transform.ipynb 194
# %% ../nbs/01_transform.ipynb 195
def gather_attr_names(o, nm):
"Used in __dir__ to collect all attrs `k` from `self.{nm}`"
return L(getattr(o,nm)).map(dir).concat().unique()

# %% ../nbs/01_transform.ipynb 195
# %% ../nbs/01_transform.ipynb 196
class Pipeline:
"A pipeline of composed (for encode/decode) transforms, setup with types"
def __init__(self, funcs=None, split_idx=None):
Expand Down
189 changes: 186 additions & 3 deletions nbs/01_transform.ipynb

Large diffs are not rendered by default.

20 changes: 18 additions & 2 deletions nbs/fastcore_migration_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
"metadata": {},
"outputs": [],
"source": [
"#|eval: false\n",
"from fastcore.dispatch import typedispatch\n",
"\n",
"@typedispatch \n",
Expand Down Expand Up @@ -186,6 +187,7 @@
}
],
"source": [
"#|eval: false\n",
"from fastcore.dispatch import TypeDispatch\n",
"t_fc = TypeDispatch(fs)\n",
"t_fc"
Expand Down Expand Up @@ -265,6 +267,7 @@
}
],
"source": [
"#|eval: false\n",
"t_fc.add(lambda x: x**2)\n",
"t_fc"
]
Expand Down Expand Up @@ -373,6 +376,16 @@
"Before:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff9599da",
"metadata": {},
"outputs": [],
"source": [
"def f_str(x:str): return x+'1'"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -399,8 +412,7 @@
}
],
"source": [
"def f_str(x:str): return x+'1'\n",
"\n",
"#|eval: false\n",
"t_fc2 = TypeDispatch(f_str, bases=t_fc)\n",
"t_fc2"
]
Expand Down Expand Up @@ -539,6 +551,7 @@
}
],
"source": [
"#|eval: false\n",
"t_fc[int]"
]
},
Expand All @@ -559,6 +572,7 @@
"metadata": {},
"outputs": [],
"source": [
"#|eval: false\n",
"t_fc.returns(5)"
]
},
Expand Down Expand Up @@ -661,6 +675,7 @@
}
],
"source": [
"#|eval: false\n",
"@typedispatch\n",
"def f2_fc(x:int|float): return x+2\n",
"@typedispatch\n",
Expand Down Expand Up @@ -776,6 +791,7 @@
}
],
"source": [
"#|eval: false\n",
"# Before (subclassing required)\n",
"from fastcore.transform import Transform as FCTransform\n",
"\n",
Expand Down