zipline/tests/pipeline/test_term.py
2020-04-23 17:55:29 -04:00

821 lines
26 KiB
Python

"""
Tests for Term.
"""
from collections import Counter
from itertools import product
from unittest import TestCase
from toolz import assoc
import pandas as pd
from zipline.assets import Asset, ExchangeInfo
from zipline.errors import (
DTypeNotSpecified,
InvalidOutputName,
NonWindowSafeInput,
NotDType,
TermInputsNotSpecified,
NonPipelineInputs,
TermOutputsEmpty,
UnsupportedDType,
WindowLengthNotSpecified,
)
from zipline.pipeline import (
Classifier,
CustomClassifier,
CustomFactor,
Factor,
Filter,
ExecutionPlan,
)
from zipline.pipeline.data import Column, DataSet
from zipline.pipeline.data.testing import TestingDataSet
from zipline.pipeline.domain import US_EQUITIES
from zipline.pipeline.expression import NUMEXPR_MATH_FUNCS
from zipline.pipeline.factors import RecarrayField
from zipline.pipeline.sentinels import NotSpecified
from zipline.pipeline.term import AssetExists, LoadableTerm
from zipline.testing import parameter_space
from zipline.testing.fixtures import WithTradingSessions, ZiplineTestCase
from zipline.testing.predicates import (
assert_equal,
assert_raises,
assert_raises_regex,
assert_regex,
)
from zipline.utils.numpy_utils import (
bool_dtype,
categorical_dtype,
complex128_dtype,
datetime64ns_dtype,
float64_dtype,
int64_dtype,
NoDefaultMissingValue,
)
class SomeDataSet(DataSet):
    """
    Small dataset with three float64 columns, used as raw inputs by the
    terms defined in this module.
    """
    # All three columns share the same dtype.
    foo = Column(float64_dtype)
    bar = Column(float64_dtype)
    buzz = Column(float64_dtype)
class SubDataSet(SomeDataSet):
    """
    Subclass of SomeDataSet that declares no columns of its own.
    """
class SubDataSetNewCol(SomeDataSet):
    """
    Subclass of SomeDataSet that declares one additional float64 column.
    """
    # The only column declared directly on this subclass.
    qux = Column(float64_dtype)
class SomeFactor(Factor):
    """
    Float64 factor with a five-row window over SomeDataSet.foo and
    SomeDataSet.bar.
    """
    dtype = float64_dtype
    window_length = 5
    inputs = [
        SomeDataSet.foo,
        SomeDataSet.bar,
    ]


# Second module-level name bound to the very same class object.
SomeFactorAlias = SomeFactor
class SomeOtherFactor(Factor):
    """
    Float64 factor with a five-row window over SomeDataSet.bar and
    SomeDataSet.buzz.
    """
    dtype = float64_dtype
    window_length = 5
    inputs = [
        SomeDataSet.bar,
        SomeDataSet.buzz,
    ]
class DateFactor(Factor):
    """
    Factor declared with a datetime64[ns] dtype and a five-row window over
    SomeDataSet.bar and SomeDataSet.buzz.
    """
    dtype = datetime64ns_dtype
    window_length = 5
    inputs = [
        SomeDataSet.bar,
        SomeDataSet.buzz,
    ]
class NoLookbackFactor(Factor):
    """
    Float64 factor declared with a window length of zero and no inputs.
    """
    dtype = float64_dtype
    # Zero rows of lookback.
    window_length = 0
class GenericCustomFactor(CustomFactor):
    """
    Single-output CustomFactor with a five-row window over SomeDataSet.foo.
    """
    dtype = float64_dtype
    window_length = 5
    inputs = [
        SomeDataSet.foo,
    ]
class MultipleOutputs(CustomFactor):
    """
    CustomFactor declaring two named outputs, 'alpha' and 'beta', computed
    from a five-row window over SomeDataSet.foo and SomeDataSet.bar.
    """
    dtype = float64_dtype
    window_length = 5
    inputs = [
        SomeDataSet.foo,
        SomeDataSet.bar,
    ]
    outputs = [
        'alpha',
        'beta',
    ]

    def some_method(self):
        """
        Do nothing and return None.

        The tests in this module reuse this method's name as an output name.
        """
        return None
class GenericFilter(Filter):
    """
    Boolean filter declared with no inputs and a window length of zero.
    """
    dtype = bool_dtype
    # No lookback and no inputs.
    window_length = 0
    inputs = []
class GenericClassifier(Classifier):
    """
    Categorical classifier declared with no inputs and a window length of
    zero.
    """
    dtype = categorical_dtype
    # No lookback and no inputs.
    window_length = 0
    inputs = []
def gen_equivalent_factors():
    """
    Yield SomeFactor instances built through different but equivalent
    constructor spellings.

    Every yielded value is expected to be the same object.
    """
    class_inputs = SomeFactor.inputs
    class_window = SomeFactor.window_length

    # No arguments at all, or inputs explicitly left unspecified.
    yield SomeFactor()
    yield SomeFactor(inputs=NotSpecified)

    # The class-level inputs, passed positionally, by keyword, and as a
    # freshly built list.
    yield SomeFactor(class_inputs)
    yield SomeFactor(inputs=class_inputs)
    yield SomeFactor([SomeDataSet.foo, SomeDataSet.bar])

    # The class-level window length, passed explicitly or left unspecified.
    yield SomeFactor(window_length=class_window)
    yield SomeFactor(window_length=NotSpecified)

    # Inputs and window length supplied together.
    yield SomeFactor(
        [SomeDataSet.foo, SomeDataSet.bar],
        window_length=NotSpecified,
    )
    yield SomeFactor(
        [SomeDataSet.foo, SomeDataSet.bar],
        window_length=class_window,
    )

    # Construction through the module-level alias of the class.
    yield SomeFactorAlias()
def to_dict(l):
    """
    Build a dict from a sequence, keyed by each element's position rendered
    as a string ('0', '1', '2', ...).

    Examples
    --------
    >>> to_dict([2, 3, 4])  # doctest: +SKIP
    {'0': 2, '1': 3, '2': 4}
    """
    position_keys = [str(position) for position in range(len(l))]
    return dict(zip(position_keys, l))
class DependencyResolutionTestCase(WithTradingSessions, ZiplineTestCase):
    """
    Tests for the order in which an ExecutionPlan emits terms and for the
    number of extra rows it records for them.
    """

    TRADING_CALENDAR_STRS = ('NYSE',)
    START_DATE = pd.Timestamp('2014-01-02', tz='UTC')
    END_DATE = pd.Timestamp('2014-12-31', tz='UTC')

    # Date range passed to every ExecutionPlan built by these tests.
    execution_plan_start = pd.Timestamp('2014-06-01', tz='UTC')
    execution_plan_end = pd.Timestamp('2014-06-30', tz='UTC')

    DOMAIN = US_EQUITIES

    def check_dependency_order(self, ordered_terms):
        """
        Assert that every term in ``ordered_terms`` appears only after all
        of its dependencies.

        Parameters
        ----------
        ordered_terms : iterable of terms
            Terms in the order produced by an execution plan.
        """
        seen = set()

        for term in ordered_terms:
            for dep in term.dependencies:
                # LoadableTerms should be specialized to the domain of
                # execution when emitted by an execution plan.
                if isinstance(dep, LoadableTerm):
                    self.assertIn(dep.specialize(self.DOMAIN), seen)
                else:
                    self.assertIn(dep, seen)

            seen.add(term)

    def make_execution_plan(self, terms):
        """
        Build an ExecutionPlan for ``terms`` over this test case's domain
        and execution date range.

        Parameters
        ----------
        terms : dict
            Mapping from name to term, forwarded as ``terms``.

        Returns
        -------
        ExecutionPlan
        """
        return ExecutionPlan(
            domain=self.DOMAIN,
            terms=terms,
            start_date=self.execution_plan_start,
            end_date=self.execution_plan_end,
        )

    def test_single_factor(self):
        """
        Test dependency resolution for a single factor.
        """
        def check_output(graph):
            resolution_order = list(graph.ordered())

            # Loadable terms should get specialized during graph construction.
            specialized_foo = SomeDataSet.foo.specialize(self.DOMAIN)
            # Previously built from ``foo`` as well, which left the ``bar``
            # input of SomeFactor completely unchecked.
            specialized_bar = SomeDataSet.bar.specialize(self.DOMAIN)

            # SomeFactor, its two inputs, and AssetExists().
            self.assertEqual(len(resolution_order), 4)
            self.check_dependency_order(resolution_order)
            self.assertIn(AssetExists(), resolution_order)
            self.assertIn(specialized_foo, resolution_order)
            self.assertIn(specialized_bar, resolution_order)
            self.assertIn(SomeFactor(), resolution_order)

            self.assertEqual(
                graph.graph.node[specialized_foo]['extra_rows'], 4,
            )
            self.assertEqual(
                graph.graph.node[specialized_bar]['extra_rows'], 4,
            )

        for foobar in gen_equivalent_factors():
            check_output(self.make_execution_plan(to_dict([foobar])))

    def test_single_factor_instance_args(self):
        """
        Test dependency resolution for a single factor with arguments passed to
        the constructor.
        """
        bar, buzz = SomeDataSet.bar, SomeDataSet.buzz

        factor = SomeFactor([bar, buzz], window_length=5)
        graph = self.make_execution_plan(to_dict([factor]))

        resolution_order = list(graph.ordered())

        # SomeFactor, its inputs, and AssetExists()
        self.assertEqual(len(resolution_order), 4)
        self.check_dependency_order(resolution_order)
        self.assertIn(AssetExists(), resolution_order)
        self.assertEqual(graph.extra_rows[AssetExists()], 4)

        # LoadableTerms should be specialized to our domain in the execution
        # order.
        self.assertIn(bar.specialize(self.DOMAIN), resolution_order)
        self.assertIn(buzz.specialize(self.DOMAIN), resolution_order)

        # ComputableTerms don't yet have a notion of specialization, so they
        # should appear unchanged in the execution order.
        self.assertIn(SomeFactor([bar, buzz], window_length=5),
                      resolution_order)

        # Check both inputs; the second assertion used to repeat ``bar``.
        self.assertEqual(graph.extra_rows[bar.specialize(self.DOMAIN)], 4)
        self.assertEqual(graph.extra_rows[buzz.specialize(self.DOMAIN)], 4)

    def test_reuse_loadable_terms(self):
        """
        Test that raw inputs only show up in the dependency graph once.
        """
        f1 = SomeFactor([SomeDataSet.foo, SomeDataSet.bar])
        f2 = SomeOtherFactor([SomeDataSet.bar, SomeDataSet.buzz])

        graph = self.make_execution_plan(to_dict([f1, f2]))
        resolution_order = list(graph.ordered())

        # bar should only appear once.
        self.assertEqual(len(resolution_order), 6)
        self.assertEqual(len(set(resolution_order)), 6)
        self.check_dependency_order(resolution_order)

    def test_disallow_recursive_lookback(self):
        """
        Test that a windowed SomeFactor cannot take another SomeFactor as an
        input.
        """
        with self.assertRaises(NonWindowSafeInput):
            SomeFactor(inputs=[SomeFactor(), SomeDataSet.foo])

    def test_window_safety_one_window_length(self):
        """
        Test that window safety problems are only raised if
        the parent factor has window length greater than 1
        """
        with self.assertRaises(NonWindowSafeInput):
            SomeFactor(inputs=[SomeOtherFactor()])

        # With a window length of 1 the same construction must not raise.
        SomeFactor(inputs=[SomeOtherFactor()], window_length=1)
class ObjectIdentityTestCase(TestCase):
    """
    Tests that equivalent term constructions return the identical object,
    that differing constructions do not, and that invalid term definitions
    raise the expected errors.
    """

    def assertSameObject(self, *objs):
        """
        Assert that all of ``objs`` are one and the same object.

        Passes trivially when called with no objects.
        """
        if not objs:
            return

        first = objs[0]
        for obj in objs[1:]:
            self.assertIs(first, obj)

    def assertDifferentObjects(self, *objs):
        """
        Assert that no object appears more than once in ``objs``.

        Passes trivially when called with no objects.
        """
        if not objs:
            # most_common(1) is empty for no input, which would otherwise
            # surface as an unrelated unpacking ValueError.
            return

        id_counts = Counter(map(id, objs))
        ((most_common_id, count),) = id_counts.most_common(1)
        if count > 1:
            dupe = next(o for o in objs if id(o) == most_common_id)
            self.fail("%s appeared %d times in %s" % (dupe, count, objs))

    def test_instance_caching(self):
        """
        Constructing a factor twice with the same arguments yields the same
        object.
        """
        self.assertSameObject(*gen_equivalent_factors())
        self.assertIs(
            SomeFactor(window_length=SomeFactor.window_length + 1),
            SomeFactor(window_length=SomeFactor.window_length + 1),
        )

        self.assertIs(
            SomeFactor(dtype=float64_dtype),
            SomeFactor(dtype=float64_dtype),
        )

        self.assertIs(
            SomeFactor(inputs=[SomeFactor.inputs[1], SomeFactor.inputs[0]]),
            SomeFactor(inputs=[SomeFactor.inputs[1], SomeFactor.inputs[0]]),
        )

        mask = SomeFactor() + SomeOtherFactor()
        self.assertIs(SomeFactor(mask=mask), SomeFactor(mask=mask))

    def test_instance_caching_multiple_outputs(self):
        """
        Multi-output factors built with the same outputs are the same
        object, and so are the output fields obtained from them.
        """
        self.assertIs(MultipleOutputs(), MultipleOutputs())
        self.assertIs(
            MultipleOutputs(),
            MultipleOutputs(outputs=MultipleOutputs.outputs),
        )
        self.assertIs(
            MultipleOutputs(
                outputs=[
                    MultipleOutputs.outputs[1], MultipleOutputs.outputs[0],
                ],
            ),
            MultipleOutputs(
                outputs=[
                    MultipleOutputs.outputs[1], MultipleOutputs.outputs[0],
                ],
            ),
        )

        # Ensure that both methods of accessing our outputs return the same
        # things.
        multiple_outputs = MultipleOutputs()
        alpha, beta = MultipleOutputs()
        self.assertIs(alpha, multiple_outputs.alpha)
        self.assertIs(beta, multiple_outputs.beta)

    def test_instance_caching_of_slices(self):
        """
        Slicing a term by asset yields the same object as constructing the
        slice type directly from an equal term and the same asset.
        """
        my_asset = Asset(
            1,
            exchange_info=ExchangeInfo('TEST FULL', 'TEST', 'US'),
        )

        f = GenericCustomFactor()
        f_slice = f[my_asset]
        self.assertIs(f_slice, type(f_slice)(GenericCustomFactor(), my_asset))

        filt = GenericFilter()
        filt_slice = filt[my_asset]
        self.assertIs(filt_slice, type(filt_slice)(GenericFilter(), my_asset))

        c = GenericClassifier()
        c_slice = c[my_asset]
        self.assertIs(c_slice, type(c_slice)(GenericClassifier(), my_asset))

    def test_instance_non_caching(self):
        """
        Factors built with differing arguments are different objects.
        """
        f = SomeFactor()

        # Different window_length.
        self.assertIsNot(
            f,
            SomeFactor(window_length=SomeFactor.window_length + 1),
        )

        # Different dtype
        self.assertIsNot(
            f,
            SomeFactor(dtype=datetime64ns_dtype)
        )

        # Reordering inputs changes semantics.
        self.assertIsNot(
            f,
            SomeFactor(inputs=[SomeFactor.inputs[1], SomeFactor.inputs[0]]),
        )

    def test_instance_non_caching_redefine_class(self):
        """
        An instance of a redefined class with the same name and attributes
        is not the instance of the original class.
        """
        orig_foobar_instance = SomeFactorAlias()

        # Deliberately shadows the module-level SomeFactor.
        class SomeFactor(Factor):
            dtype = float64_dtype
            window_length = 5
            inputs = [SomeDataSet.foo, SomeDataSet.bar]

        self.assertIsNot(orig_foobar_instance, SomeFactor())

    def test_instance_non_caching_multiple_outputs(self):
        """
        Multi-output factors with different or reordered outputs are
        different objects, and so are their same-named output fields.
        """
        multiple_outputs = MultipleOutputs()

        # Different outputs.
        self.assertIsNot(
            MultipleOutputs(), MultipleOutputs(outputs=['beta', 'gamma']),
        )

        # Reordering outputs.
        self.assertIsNot(
            multiple_outputs,
            MultipleOutputs(
                outputs=[
                    MultipleOutputs.outputs[1], MultipleOutputs.outputs[0],
                ],
            ),
        )

        # Different factors sharing an output name should produce different
        # RecarrayField factors.
        orig_beta = multiple_outputs.beta
        beta, _ = MultipleOutputs(outputs=['beta', 'gamma'])
        self.assertIsNot(beta, orig_beta)

    def test_instance_caching_binops(self):
        """
        Repeating a binary arithmetic expression yields the same object.
        """
        f = SomeFactor()
        g = SomeOtherFactor()
        for lhs, rhs in product([f, g], [f, g]):
            self.assertIs((lhs + rhs), (lhs + rhs))
            self.assertIs((lhs - rhs), (lhs - rhs))
            self.assertIs((lhs * rhs), (lhs * rhs))
            self.assertIs((lhs / rhs), (lhs / rhs))
            self.assertIs((lhs ** rhs), (lhs ** rhs))

            self.assertIs((1 + rhs), (1 + rhs))
            self.assertIs((rhs + 1), (rhs + 1))

            self.assertIs((1 - rhs), (1 - rhs))
            self.assertIs((rhs - 1), (rhs - 1))

            self.assertIs((2 * rhs), (2 * rhs))
            self.assertIs((rhs * 2), (rhs * 2))

            self.assertIs((2 / rhs), (2 / rhs))
            self.assertIs((rhs / 2), (rhs / 2))

            self.assertIs((2 ** rhs), (2 ** rhs))
            self.assertIs((rhs ** 2), (rhs ** 2))

        # Does not depend on lhs/rhs, so it is checked once outside the loop.
        self.assertIs((f + g) + (f + g), (f + g) + (f + g))

    def test_instance_caching_unary_ops(self):
        """
        Repeating a (possibly nested) negation yields the same object.
        """
        f = SomeFactor()
        self.assertIs(-f, -f)
        self.assertIs(--f, --f)
        self.assertIs(---f, ---f)

    def test_instance_caching_math_funcs(self):
        """
        Calling the same NUMEXPR_MATH_FUNCS method twice yields the same
        object.
        """
        f = SomeFactor()
        for funcname in NUMEXPR_MATH_FUNCS:
            method = getattr(f, funcname)
            self.assertIs(method(), method())

    def test_instance_caching_grouped_transforms(self):
        """
        demean, zscore and rank return the same object when called twice
        with the same groupby/mask arguments.
        """
        f = SomeFactor()
        c = GenericClassifier()
        m = GenericFilter()

        for meth in f.demean, f.zscore, f.rank:
            self.assertIs(meth(), meth())
            self.assertIs(meth(groupby=c), meth(groupby=c))
            self.assertIs(meth(mask=m), meth(mask=m))
            self.assertIs(meth(groupby=c, mask=m), meth(groupby=c, mask=m))

    class SomeFactorParameterized(SomeFactor):
        # Names of the keyword parameters supplied at construction time.
        params = ('a', 'b')

    def test_parameterized_term(self):
        """
        Parameter values take part in instance identity, and are exposed
        through ``params`` alongside the inherited window length and inputs.
        """
        f = self.SomeFactorParameterized(a=1, b=2)
        self.assertEqual(f.params, {'a': 1, 'b': 2})

        g = self.SomeFactorParameterized(a=1, b=3)
        h = self.SomeFactorParameterized(a=2, b=2)
        self.assertDifferentObjects(f, g, h)

        # Keyword order must not matter.
        f2 = self.SomeFactorParameterized(a=1, b=2)
        f3 = self.SomeFactorParameterized(b=2, a=1)
        self.assertSameObject(f, f2, f3)

        self.assertEqual(f.params['a'], 1)
        self.assertEqual(f.params['b'], 2)
        self.assertEqual(f.window_length, SomeFactor.window_length)
        self.assertEqual(f.inputs, tuple(SomeFactor.inputs))

    def test_parameterized_term_non_hashable_arg(self):
        """
        Passing an unhashable parameter value raises a TypeError naming the
        offending parameter.
        """
        with assert_raises(TypeError) as e:
            self.SomeFactorParameterized(a=[], b=1)
        assert_equal(
            str(e.exception),
            "SomeFactorParameterized expected a hashable value for parameter"
            " 'a', but got [] instead.",
        )

        with assert_raises(TypeError) as e:
            self.SomeFactorParameterized(a=1, b=[])
        assert_equal(
            str(e.exception),
            "SomeFactorParameterized expected a hashable value for parameter"
            " 'b', but got [] instead.",
        )

        # When both are unhashable, either parameter may be reported.
        with assert_raises(TypeError) as e:
            self.SomeFactorParameterized(a=[], b=[])
        assert_regex(
            str(e.exception),
            r"SomeFactorParameterized expected a hashable value for parameter"
            r" '(a|b)', but got \[\] instead\.",
        )

    def test_parameterized_term_default_value(self):
        """
        Parameters declared as a dict supply default values that keyword
        arguments override individually.
        """
        defaults = {'a': 'default for a', 'b': 'default for b'}

        class F(Factor):
            params = defaults

            inputs = (SomeDataSet.foo,)
            dtype = 'f8'
            window_length = 5

        assert_equal(F().params, defaults)
        assert_equal(F(a='new a').params, assoc(defaults, 'a', 'new a'))
        assert_equal(F(b='new b').params, assoc(defaults, 'b', 'new b'))
        assert_equal(
            F(a='new a', b='new b').params,
            {'a': 'new a', 'b': 'new b'},
        )

    def test_parameterized_term_default_value_with_not_specified(self):
        """
        A parameter whose default is NotSpecified must be passed explicitly;
        omitting it raises a TypeError.
        """
        defaults = {'a': 'default for a', 'b': NotSpecified}

        class F(Factor):
            params = defaults

            inputs = (SomeDataSet.foo,)
            dtype = 'f8'
            window_length = 5

        pattern = r"F expected a keyword parameter 'b'\."
        with assert_raises_regex(TypeError, pattern):
            F()
        with assert_raises_regex(TypeError, pattern):
            F(a='new a')

        assert_equal(F(b='new b').params, assoc(defaults, 'b', 'new b'))
        assert_equal(
            F(a='new a', b='new b').params,
            {'a': 'new a', 'b': 'new b'},
        )

    def test_bad_input(self):
        """
        Each kind of incomplete or invalid term definition raises its
        dedicated error type at construction time.
        """
        # Deliberately shadows the module-level SomeFactor; the subclasses
        # below derive from this local, partially specified class.
        class SomeFactor(Factor):
            dtype = float64_dtype

        class SomeFactorDefaultInputs(SomeFactor):
            inputs = (SomeDataSet.foo, SomeDataSet.bar)

        class SomeFactorDefaultLength(SomeFactor):
            window_length = 10

        class SomeFactorNoDType(SomeFactor):
            window_length = 10
            inputs = (SomeDataSet.foo,)
            dtype = NotSpecified

        with self.assertRaises(TermInputsNotSpecified):
            SomeFactor(window_length=1)

        with self.assertRaises(TermInputsNotSpecified):
            SomeFactorDefaultLength()

        with self.assertRaises(NonPipelineInputs):
            SomeFactor(window_length=1, inputs=[2])

        with self.assertRaises(WindowLengthNotSpecified):
            SomeFactor(inputs=(SomeDataSet.foo,))

        with self.assertRaises(WindowLengthNotSpecified):
            SomeFactorDefaultInputs()

        with self.assertRaises(DTypeNotSpecified):
            SomeFactorNoDType()

        with self.assertRaises(NotDType):
            SomeFactor(dtype=1)

        with self.assertRaises(NoDefaultMissingValue):
            SomeFactor(dtype=int64_dtype)

        with self.assertRaises(UnsupportedDType):
            SomeFactor(dtype=complex128_dtype)

        with self.assertRaises(TermOutputsEmpty):
            MultipleOutputs(outputs=[])

    def test_bad_output_access(self):
        """
        Accessing unknown outputs, unpacking a single-output factor, and
        declaring disallowed output names all fail with specific errors.
        """
        with self.assertRaises(AttributeError) as e:
            SomeFactor().not_an_attr

        errmsg = str(e.exception)
        self.assertEqual(
            errmsg, "'SomeFactor' object has no attribute 'not_an_attr'",
        )

        mo = MultipleOutputs()
        with self.assertRaises(AttributeError) as e:
            mo.not_an_attr

        errmsg = str(e.exception)
        expected = (
            "Instance of MultipleOutputs has no output named 'not_an_attr'."
            " Possible choices are: ('alpha', 'beta')."
        )
        self.assertEqual(errmsg, expected)

        # The unpacking itself is what is expected to raise.
        with self.assertRaises(ValueError) as e:
            alpha, beta = GenericCustomFactor()

        errmsg = str(e.exception)
        self.assertEqual(
            errmsg, "GenericCustomFactor does not have multiple outputs.",
        )

        # Public method, user-defined method.
        # Accessing these attributes should return the output, not the method.
        conflicting_output_names = ['zscore', 'some_method']
        mo = MultipleOutputs(outputs=conflicting_output_names)
        for name in conflicting_output_names:
            self.assertIsInstance(getattr(mo, name), RecarrayField)

        # Non-callable attribute, private method, special method.
        disallowed_output_names = ['inputs', '_init', '__add__']
        for name in disallowed_output_names:
            with self.assertRaises(InvalidOutputName):
                GenericCustomFactor(outputs=[name])

    def test_require_super_call_in_validate(self):
        """
        Overriding ``_validate`` without calling super() is detected at
        construction time.
        """
        class MyFactor(Factor):
            inputs = ()
            dtype = float64_dtype
            window_length = 0

            def _validate(self):
                "Woops, I didn't call super()!"

        with self.assertRaises(AssertionError) as e:
            MyFactor()

        errmsg = str(e.exception)
        self.assertEqual(
            errmsg,
            "Term._validate() was not called.\n"
            "This probably means that you overrode _validate"
            " without calling super()."
        )

    def test_latest_on_different_dtypes(self):
        """
        ``column.latest`` is a Filter, Classifier or Factor depending on the
        column dtype, and shares the column's missing value.
        """
        factor_dtypes = (float64_dtype, datetime64ns_dtype)
        for column in TestingDataSet.columns:
            if column.dtype == bool_dtype:
                self.assertIsInstance(column.latest, Filter)
            elif (column.dtype == int64_dtype
                  or column.dtype.kind in ('O', 'S', 'U')):
                self.assertIsInstance(column.latest, Classifier)
            elif column.dtype in factor_dtypes:
                self.assertIsInstance(column.latest, Factor)
            else:
                self.fail(
                    "Unknown dtype %s for column %s" % (column.dtype, column)
                )
            # These should be the same value, plus this has the convenient
            # property of correctly handling `NaN`.
            self.assertIs(column.missing_value, column.latest.missing_value)

    def test_failure_timing_on_bad_dtypes(self):
        """
        A Column with a bad dtype only fails once it is placed on a DataSet,
        not when the Column itself is constructed.
        """
        # Just constructing a bad column shouldn't fail.
        Column(dtype=int64_dtype)
        with self.assertRaises(NoDefaultMissingValue) as e:
            class BadDataSet(DataSet):
                bad_column = Column(dtype=int64_dtype)
                float_column = Column(dtype=float64_dtype)
                int_column = Column(dtype=int64_dtype, missing_value=3)

        self.assertTrue(
            str(e.exception.args[0]).startswith(
                "Failed to create Column with name 'bad_column'"
            )
        )

        # Same pattern for a dtype that is not supported at all.
        Column(dtype=complex128_dtype)
        with self.assertRaises(UnsupportedDType):
            class BadDataSetComplex(DataSet):
                bad_column = Column(dtype=complex128_dtype)
                float_column = Column(dtype=float64_dtype)
                int_column = Column(dtype=int64_dtype, missing_value=3)
class SubDataSetTestCase(TestCase):
    """
    Tests for columns on DataSet subclasses, for rejecting classifiers with
    custom outputs, and for rejecting missing values that do not fit a
    term's dtype.
    """

    def test_subdataset(self):
        """
        A subclass that adds nothing has the parent's column names and
        dtypes, but its column objects are distinct from the parent's.
        """
        some_dataset_map = {
            column.name: column for column in SomeDataSet.columns
        }
        sub_dataset_map = {
            column.name: column for column in SubDataSet.columns
        }
        self.assertEqual(
            {column.name for column in SomeDataSet.columns},
            {column.name for column in SubDataSet.columns},
        )
        for k, some_dataset_column in some_dataset_map.items():
            sub_dataset_column = sub_dataset_map[k]
            self.assertIsNot(
                some_dataset_column,
                sub_dataset_column,
                'subclass column %r should not have the same identity as'
                ' the parent' % k,
            )
            self.assertEqual(
                some_dataset_column.dtype,
                sub_dataset_column.dtype,
                'subclass column %r should have the same dtype as the parent' %
                k,
            )

    def test_add_column(self):
        """
        A subclass that adds a column has that column plus the parent's
        column names and dtypes, with distinct column objects.
        """
        some_dataset_map = {
            column.name: column for column in SomeDataSet.columns
        }
        sub_dataset_new_col_map = {
            column.name: column for column in SubDataSetNewCol.columns
        }
        sub_col_names = {column.name for column in SubDataSetNewCol.columns}

        # check our extra col
        self.assertIn('qux', sub_col_names)
        self.assertEqual(
            sub_dataset_new_col_map['qux'].dtype,
            float64_dtype,
        )

        self.assertEqual(
            {column.name for column in SomeDataSet.columns},
            sub_col_names - {'qux'},
        )
        for k, some_dataset_column in some_dataset_map.items():
            sub_dataset_column = sub_dataset_new_col_map[k]
            self.assertIsNot(
                some_dataset_column,
                sub_dataset_column,
                'subclass column %r should not have the same identity as'
                ' the parent' % k,
            )
            self.assertEqual(
                some_dataset_column.dtype,
                sub_dataset_column.dtype,
                'subclass column %r should have the same dtype as the parent' %
                k,
            )

    @parameter_space(
        dtype_=[categorical_dtype, int64_dtype],
        outputs_=[('a',), ('a', 'b')],
    )
    def test_reject_multi_output_classifiers(self, dtype_, outputs_):
        """
        Multi-output CustomClassifiers don't work because they use special
        output allocation for string arrays.
        """

        class SomeClassifier(CustomClassifier):
            dtype = dtype_
            window_length = 5
            inputs = [SomeDataSet.foo, SomeDataSet.bar]
            outputs = outputs_
            missing_value = dtype_.type('123')

        expected_error = (
            "SomeClassifier does not support custom outputs, "
            "but received custom outputs={outputs}.".format(outputs=outputs_)
        )

        with self.assertRaises(ValueError) as e:
            SomeClassifier()
        self.assertEqual(str(e.exception), expected_error)

        # NOTE(review): the construction is attempted a second time;
        # presumably to confirm that the failure repeats identically on a
        # later attempt -- confirm the intent before removing.
        with self.assertRaises(ValueError) as e:
            SomeClassifier()
        self.assertEqual(str(e.exception), expected_error)

    def test_unreasonable_missing_values(self):
        """
        A missing value that does not fit the term's dtype raises a
        TypeError whose message starts with the expected prefix.
        """
        for base_type, dtype_, bad_mv in ((Factor, float64_dtype, 'ayy'),
                                          (Filter, bool_dtype, 'lmao'),
                                          (Classifier, int64_dtype, 'lolwut'),
                                          (Classifier, categorical_dtype, 7)):
            class SomeTerm(base_type):
                inputs = ()
                window_length = 0
                missing_value = bad_mv
                dtype = dtype_

            with self.assertRaises(TypeError) as e:
                SomeTerm()

            prefix = (
                "^Missing value {mv!r} is not a valid choice "
                "for term SomeTerm with dtype {dtype}.\n\n"
                "Coercion attempt failed with:"
            ).format(mv=bad_mv, dtype=dtype_)
            # assert_regex is the helper this module already imports and
            # uses elsewhere; TestCase.assertRegexpMatches is a deprecated
            # alias that newer unittest versions no longer provide.
            assert_regex(str(e.exception), prefix)