from __future__ import annotations
from nbdev.showdoc import *
from fastcore.test import *
from fastcore.nb_imports import *
Cast
This module contains some of fastcore.dispatch’s utility functions for type casting. We copy them over here because, with fasttransform’s release, those modules may be removed from fastcore.
The functions here have not been changed, except for retain_type, which has the same functionality but now accepts the type hints as Plum dispatch provides them instead of following fastcore.dispatch’s convention.
Type casting
Some objects may have a set_meta method, such as fastai.torch_core.Tensor. When casting these to another type we want to preserve metadata.
retain_meta
retain_meta (x, res, as_copy=False)
Call res.set_meta(x), if it exists
default_set_meta
default_set_meta (x, as_copy=False)
Copy over _meta from x to res (the object the method is called on), if it’s missing
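A minimal, stdlib-only sketch of this metadata protocol may help. The names default_set_meta_sketch, retain_meta_sketch and MetaCarrier are hypothetical stand-ins, not the library’s implementation; they assume only the documented behaviour: call set_meta if it exists, and copy _meta only when the target lacks it.

```python
import copy


def default_set_meta_sketch(self, x, as_copy=False):
    # Hypothetical: copy x._meta onto self only if x has one and self does not
    if hasattr(x, "_meta") and not hasattr(self, "_meta"):
        meta = x._meta
        if as_copy:
            meta = copy.copy(meta)  # shallow copy so the two objects don't share one dict
        self._meta = meta
    return self


def retain_meta_sketch(x, res, as_copy=False):
    # Hypothetical: delegate to res.set_meta(x) when res defines it; otherwise leave res alone
    if hasattr(res, "set_meta"):
        res.set_meta(x, as_copy=as_copy)
    return res


class MetaCarrier:
    # Hypothetical metadata-carrying class; set_meta is attached as a plain class attribute
    set_meta = default_set_meta_sketch

    def __init__(self, value):
        self.value = value


src = MetaCarrier(1)
src._meta = {"a": 2}

shared = retain_meta_sketch(src, MetaCarrier(5))               # shares src's dict
copied = retain_meta_sketch(src, MetaCarrier(6), as_copy=True)  # gets its own copy
plain = retain_meta_sketch(src, 3.0)                            # floats have no set_meta

kept = MetaCarrier(0)
kept._meta = {"b": 1}
retain_meta_sketch(src, kept)  # existing _meta is not overwritten
```

The as_copy flag only matters when the metadata dict is later mutated: without it, both objects see the same changes.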
cast
cast (x, typ)
cast x to type typ (may also change x in place)
This works both for plain Python classes…
mk_class('_T1', 'a')   # mk_class is a fastcore utility that constructs a class.
class _T2(_T1): pass
t = _T1(a=1)
t2 = cast(t, _T2)
assert t2 is t             # t2 refers to the same object as t
assert isinstance(t, _T2)  # t also changed in place
assert isinstance(t2, _T2)
test_eq_type(_T2(a=1), t2)
…as well as for arrays and tensors.
class _T1(ndarray): pass
t = array([1])
t2 = cast(t, _T1)
test_eq(array([1]), t2)
test_eq(_T1, type(t2))
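The in-place behaviour above comes from retyping the existing object rather than rebuilding it. The following stdlib-only sketch (cast_sketch, Point and LabeledPoint are hypothetical names) assumes a simple strategy: reassign __class__ when Python allows it, and fall back to calling the target type otherwise. Arrays and tensors need another route (NumPy’s ndarray.view(subclass), for instance, returns a view with the requested subclass), which this sketch leaves out.

```python
class Point:
    def __init__(self, a):
        self.a = a


class LabeledPoint(Point):
    pass


class FS(float):
    def __repr__(self):
        return f"FS({float(self)})"


def cast_sketch(x, typ):
    """Hypothetical cast: retype x in place when possible, else build typ(x)."""
    try:
        x.__class__ = typ  # ordinary instances accept this, so x itself changes type
        return x
    except TypeError:
        return typ(x)  # immutable builtins such as float refuse __class__ assignment


p = Point(a=1)
q = cast_sketch(p, LabeledPoint)  # same object, new class
f = cast_sketch(1.0, FS)          # a new FS object
```

Because __class__ assignment does not rerun __init__, attributes that only the subclass constructor would set are never created, which is consistent with other_attr not surviving the cast in the metadata example further down.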
Retain type
retain_type is a function that’s useful for postprocessing function outputs. It is used in the Transform class.
The conversion priorities are as follows:
- the function’s return type annotation ret_type, if given
- if there’s no return type annotation (i.e. ret_type=Any), then it will convert back to the input’s (old) type, but only if old’s type is a subclass of the returned value’s type.
- if the function has a return type annotation of None (ret_type=None), then no conversion will be done.
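The priority rules can be expressed as a short decision function. This is a hypothetical, stdlib-only sketch (retain_type_sketch is not the library function): it builds new objects with typ(new) instead of casting in place, skips metadata, and accepts either None or NoneType for the None annotation, since the examples on this page pass NoneType.

```python
from typing import Any

NoneType = type(None)


class FS(float):
    def __repr__(self):
        return f"FS({float(self)})"


def retain_type_sketch(new, old, ret_type=Any):
    """Hypothetical sketch of the documented conversion priorities."""
    if new is None or ret_type is None or ret_type is NoneType:
        return new  # None stays None; a None annotation disables conversion
    if ret_type is not Any:
        # Priority 1: an explicit return annotation wins (conversion errors propagate)
        return new if isinstance(new, ret_type) else ret_type(new)
    if old is None:
        return new  # nothing to take guidance from
    typ = type(old)
    if isinstance(new, typ):
        return new  # already an instance of old's type: return the original object
    if issubclass(typ, type(new)):
        return typ(new)  # Priority 2: old's type is a subclass of new's type
    return new  # unrelated types: keep new


x = FS(1.)
same = retain_type_sketch(x, FS(2.))  # no conversion needed, same object back
```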
retain_type
retain_type (new, old, ret_type=typing.Any, as_copy=False)
Cast new to ret_type if given, or old’s type if new is a superclass of old. No conversion is done if ret_type=None
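retain_type receives ret_type as a value rather than reading the annotation itself. As an illustration of where such a value could come from, here is a hypothetical helper built on the standard library’s typing.get_type_hints (not necessarily how Plum or fasttransform obtain it): a missing return annotation maps to Any, and an annotation of None comes back as NoneType, matching the NoneType used in the examples below.

```python
from typing import Any, get_type_hints


def ret_type_of(f):
    """Hypothetical helper: the return annotation of f, or Any when there is none."""
    # get_type_hints turns a "-> None" annotation into NoneType, i.e. type(None)
    return get_type_hints(f).get("return", Any)


def annotated(x) -> float:
    return x


def unannotated(x):
    return x


def returns_nothing(x) -> None:
    return None
```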
Return type annotation conversion
We try to convert new to the return type if it’s given.
class FS(float):
def __repr__(self): return f'FS({float(self)})'
test_eq(retain_type(1., 2., FS), FS(1.))
If the conversion can’t succeed, we let the exception be raised:
# Raise error if return type is not compatible with new
try: retain_type("a", 2., FS)
except ValueError as e: print(f"Expected error: {e}")
Expected error: could not convert string to float: 'a'
Old type conversion
If the return type is Any, then new looks at old for conversion guidance.
test_eq(retain_type(1., FS(2.), Any), FS(1.))
But if new’s type isn’t a superclass of old’s type, keep new:
test_eq(retain_type(FS(1.), 2.0, Any), FS(1.))
test_eq(retain_type("a", 2.0, Any), "a")
No casting is needed if new is already of old’s type. In that case we return the original object.
x = FS(1.)
test_is(retain_type(x, FS(2.), Any), x)
Edge cases with None
We don’t convert at all if the return type annotation is None:
test_eq(retain_type(1., FS(2.), NoneType), 1.)
None stays None:
test_eq(retain_type(None, FS(2.), Any), None)
If old was None then we just return new.
test_eq(retain_type(FS(1.), None, Any), FS(1.))
Metadata retention
If old has a _meta attribute, its content is passed on when casting new to the type of old. In the example below, only the attribute a is kept, not other_attr, because other_attr is not in _meta:
class _A():
    set_meta = default_set_meta
    def __init__(self, t): self.t=t
class _B1(_A):
    def __init__(self, t, a=1):
        super().__init__(t)
        self._meta = {'a':a}
        self.other_attr = 'Hello' # will not be kept after casting.
x = _B1(1, a=2)
b = _A(1)
c = retain_type(b, old=x)
test_eq(c._meta, {'a': 2})
assert not getattr(c, 'other_attr', None)
Retain types
retain_types
retain_types (new, old=None, typs=None)
Cast each item of new to type of matching item in old if it’s a superclass
class T(tuple): pass
t1,t2 = retain_types((1,(1,(1,1))), (2,T((2,T((3,4))))))
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))
t1,t2 = retain_types((1,(1,(1,1))), typs = {tuple: [int, {T: [int, {T: [int,int]}]}]})
test_eq_type(t1, 1)
test_eq_type(t2, T((1,T((1,1)))))
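The recursive idea behind retain_types can be sketched as follows. This hypothetical retain_types_sketch handles only the old argument (not typs), only recurses into tuples, and rebuilds containers with typ(items), so it assumes every tuple subclass involved accepts an iterable in its constructor.

```python
class T(tuple):
    pass


def retain_types_sketch(new, old):
    """Hypothetical: cast new to old's type where that type subclasses new's, recursing into tuples."""
    if isinstance(new, tuple) and isinstance(old, tuple):
        # Rebuild the container from recursively converted items (zip stops at the shorter one)
        items = tuple(retain_types_sketch(n, o) for n, o in zip(new, old))
    else:
        items = new
    typ = type(old)
    if not isinstance(items, typ) and issubclass(typ, type(items)):
        return typ(items)  # old's type is a subclass of new's type: upgrade
    return items  # already the right type, or unrelated (e.g. old is None)


t1, t2 = retain_types_sketch((1, (1, (1, 1))), (2, T((2, T((3, 4))))))
```

The result mirrors the first example above: the plain int stays an int, while both nested tuples come back as T.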
explode_types
explode_types (o)
Return the type of o, potentially in nested dictionaries for things that are listy
test_eq(explode_types((2,T((2,T((3,4)))))), {tuple: [int, {T: [int, {T: [int,int]}]}]})
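A hypothetical recursive sketch of explode_types follows. Here “listy” is assumed to mean tuple or list, which may be narrower than the library’s own check.

```python
class T(tuple):
    pass


def explode_types_sketch(o):
    """Hypothetical: type(o), or {type(o): [...]} recursing into tuples and lists."""
    if not isinstance(o, (tuple, list)):
        return type(o)
    return {type(o): [explode_types_sketch(item) for item in o]}


exploded = explode_types_sketch((2, T((2, T((3, 4))))))
```

The nested dictionary has the same shape as the typs argument passed to retain_types in the second retain_types example.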