mirror of
https://github.com/saymrwulf/zipline.git
synced 2026-05-16 21:10:11 +00:00
579 lines
19 KiB
Python
579 lines
19 KiB
Python
from collections import OrderedDict
|
|
from itertools import permutations, product, islice
|
|
from operator import (
|
|
add,
|
|
ge,
|
|
gt,
|
|
le,
|
|
lt,
|
|
methodcaller,
|
|
mul,
|
|
ne,
|
|
sub,
|
|
)
|
|
from string import ascii_uppercase
|
|
from unittest import TestCase
|
|
|
|
import numpy
|
|
from numpy import (
|
|
arange,
|
|
array,
|
|
eye,
|
|
float64,
|
|
full,
|
|
isnan,
|
|
zeros,
|
|
)
|
|
from pandas import (
|
|
DataFrame,
|
|
date_range,
|
|
Int64Index,
|
|
)
|
|
|
|
from zipline.pipeline import Factor, Filter
|
|
from zipline.pipeline.factors.factor import NumExprFactor
|
|
from zipline.pipeline.expression import (
|
|
NUMEXPR_MATH_FUNCS,
|
|
NumericalExpression,
|
|
)
|
|
from zipline.testing import check_allclose, parameter_space
|
|
from zipline.utils.numpy_utils import datetime64ns_dtype, float64_dtype
|
|
|
|
|
|
class F(Factor):
    # Float64 fixture factor: declares no inputs and a zero-length window.
    # The test case supplies its raw values through ``fake_raw_data``.
    window_length = 0
    inputs = ()
    dtype = float64_dtype
|
|
|
|
|
|
class G(Factor):
    # Second float64 fixture factor, declared exactly like ``F`` but as a
    # distinct class so expressions can mix two different terms.
    window_length = 0
    inputs = ()
    dtype = float64_dtype
|
|
|
|
|
|
class H(Factor):
    # Third float64 fixture factor: no inputs, zero-length window.
    window_length = 0
    inputs = ()
    dtype = float64_dtype
|
|
|
|
|
|
class NonExprFilter(Filter):
    # Plain Filter subclass (not built from an expression string); used to
    # check that boolean operators work when one operand is an ordinary
    # filter.
    window_length = 0
    inputs = ()
|
|
|
|
|
|
class DateFactor(Factor):
    # Fixture factor with a datetime64[ns] dtype; used to exercise the
    # error paths for arithmetic on non-float factors.
    window_length = 0
    inputs = ()
    dtype = datetime64ns_dtype
|
|
|
|
|
|
class NumericalExpressionTestCase(TestCase):
    """
    Tests for constructing and evaluating numerical pipeline expressions.

    Expressions are evaluated by calling ``_compute`` directly with the
    arrays stored in ``self.fake_raw_data`` (keyed by term), so individual
    tests can swap that mapping to change what each term "loads".
    """

    def setUp(self):
        """
        Build the shared fixtures: three float factors filled with 3, 2 and
        1, a datetime factor filled with 0, and an all-True 5x5 mask.
        """
        shape = (5, 5)

        self.dates = date_range('2014-01-01', periods=5, freq='D')
        self.assets = Int64Index(range(5))
        self.f = F()
        self.g = G()
        self.h = H()
        self.d = DateFactor()
        self.fake_raw_data = {
            self.f: full(shape, 3, float),
            self.g: full(shape, 2, float),
            self.h: full(shape, 1, float),
            self.d: full(shape, 0, dtype='datetime64[ns]'),
        }
        self.mask = DataFrame(True, index=self.dates, columns=self.assets)

    def check_output(self, expr, expected):
        """
        Compute ``expr`` from the fake raw data of its inputs and assert
        that the result is close to ``expected``.
        """
        mask = self.mask
        input_arrays = [self.fake_raw_data[term] for term in expr.inputs]
        actual = expr._compute(
            input_arrays,
            mask.index,
            mask.columns,
            mask.values,
        )
        check_allclose(actual, expected)

    def check_constant_output(self, expr, expected):
        """
        Assert that ``expr`` evaluates to a 5x5 array filled with the
        scalar ``expected``, which must not be NaN.
        """
        self.assertFalse(isnan(expected))
        constant = full((5, 5), expected, float)
        return self.check_output(expr, constant)

    def test_validate_good(self):
        # Each (expression string, inputs) pair must construct without
        # raising.
        f, g = self.f, self.g

        valid_cases = (
            ("x_0", (f,)),
            ("x_0 ", (f,)),
            ("x_0 + x_0", (f,)),
            ("x_0 + 2", (f,)),
            ("2 * x_0", (f,)),
            ("x_0 + x_1", (f, g)),
            ("x_0 + x_1 + x_0", (f, g)),
            ("x_0 + 1 + x_1", (f, g)),
        )
        for expr_string, inputs in valid_cases:
            NumExprFactor(expr_string, inputs, dtype=float64_dtype)

    def test_validate_bad(self):
        f, g, h = self.f, self.g, self.h

        invalid_cases = (
            # Too few inputs.
            ("x_0", ()),
            ("x_0 + x_1", (f,)),
            # Too many inputs.
            ("x_0", (f, g)),
            ("x_0 + x_1", (f, g, h)),
            # Invalid variable name.
            ("x_0x_1", (f,)),
            ("x_0x_1", (f, g)),
            # Variable index must start at 0.
            ("x_1", (f,)),
        )
        for expr_string, inputs in invalid_cases:
            with self.assertRaises(ValueError):
                NumExprFactor(expr_string, inputs, dtype=float64_dtype)

        # Each thunk is evaluated inside assertRaises so that the operator
        # call itself is what must raise.
        type_error_thunks = (
            # Scalar operands must be numeric.
            lambda: "2" + f,
            lambda: f + "2",
            lambda: f > "2",
            # Boolean binary operators must be between filters.
            lambda: f + (f > 2),
            lambda: (f > f) > f,
        )
        for thunk in type_error_thunks:
            with self.assertRaises(TypeError):
                thunk()

    @parameter_space(num_new_inputs=[1, 4])
    def test_many_inputs(self, num_new_inputs):
        """
        Test adding NumericalExpressions with >=32 (NPY_MAXARGS) inputs.
        """
        shape = (5, 5)

        # Start from a NumericalExpression made by adding a factor to
        # itself.
        f = self.f
        expr = f + f

        self.fake_raw_data = OrderedDict({f: full(shape, 0, float)})
        expected = 0

        # Alternate between adding and subtracting. Subtraction is not
        # commutative, so this verifies operands are combined in the
        # correct order.
        ops = (add, sub)

        names = islice(product(ascii_uppercase, ascii_uppercase), 64)
        for i, letters in enumerate(names):
            name = ''.join(letters)
            op = ops[i % 2]
            value = i + 1
            # new_expr_inputs[0] seeds the sum and is then added again with
            # every other input, giving num_new_inputs + 1 terms of ``value``.
            combined_value = value * (num_new_inputs + 1)

            new_expr_inputs = []
            for j in range(num_new_inputs):
                NewFactor = type(
                    name + str(j),
                    (Factor,),
                    dict(dtype=float64_dtype, inputs=(), window_length=0),
                )
                new_factor = NewFactor()
                self.fake_raw_data[new_factor] = full(shape, value, float)
                new_expr_inputs.append(new_factor)

            # Again we need a NumericalExpression, so sum the new factors.
            new_expr = new_expr_inputs[0]
            # NOTE(review): this repeats the entry already stored for
            # new_expr_inputs[0] in the loop above (same key, same fill).
            self.fake_raw_data[new_expr] = full(shape, value, float)
            for new_expr_input in new_expr_inputs:
                new_expr = new_expr + new_expr_input
            self.fake_raw_data[new_expr] = full(shape, combined_value, float)

            # This will grow the number of inputs by num_new_inputs. We
            # start at 1 (self.f). The num_new_inputs=4 case grows by 4 and
            # covers growing from 29 to 33 (> NPY_MAXARGS).
            expr = op(expr, new_expr)
            expected = op(expected, combined_value)
            self.fake_raw_data[expr] = full(shape, expected, float)

        # Only the composite expressions can be computed; the plain factors
        # in the mapping are just their inputs.
        for term, data in self.fake_raw_data.items():
            if isinstance(term, NumericalExpression):
                self.check_output(term, data)

    def test_combine_datetimes(self):
        with self.assertRaises(TypeError) as ctx:
            self.d + self.d
        message = ctx.exception.args[0]
        expected = (
            "Don't know how to compute datetime64[ns] + datetime64[ns].\n"
            "Arithmetic operators are only supported between Factors of dtype "
            "'float64'."
        )
        self.assertEqual(message, expected)

        # Confirm that * shows up in the error instead of +.
        with self.assertRaises(TypeError) as ctx:
            self.d * self.d
        message = ctx.exception.args[0]
        expected = (
            "Don't know how to compute datetime64[ns] * datetime64[ns].\n"
            "Arithmetic operators are only supported between Factors of dtype "
            "'float64'."
        )
        self.assertEqual(message, expected)

    def test_combine_datetime_with_float(self):
        # Test with both float-type factors and numeric values.
        # NOTE(review): ``float_value`` is never used below; every iteration
        # combines ``self.f`` with ``self.d``, so the scalar operands are
        # not actually exercised. Confirm the expected messages for scalar
        # operands before substituting ``float_value``.
        for float_value in (self.f, float64(1.0), 1.0):
            for op, sym in ((add, '+'), (mul, '*')):
                float_first = (
                    "Don't know how to compute float64 {sym} datetime64[ns].\n"
                    "Arithmetic operators are only supported between Factors"
                    " of dtype 'float64'."
                ).format(sym=sym)
                date_first = (
                    "Don't know how to compute datetime64[ns] {sym} float64.\n"
                    "Arithmetic operators are only supported between Factors"
                    " of dtype 'float64'."
                ).format(sym=sym)

                with self.assertRaises(TypeError) as ctx:
                    op(self.f, self.d)
                self.assertEqual(ctx.exception.args[0], float_first)

                with self.assertRaises(TypeError) as ctx:
                    op(self.d, self.f)
                self.assertEqual(ctx.exception.args[0], date_first)

    def test_negate_datetime(self):
        with self.assertRaises(TypeError) as ctx:
            -self.d

        expected = (
            "Can't apply unary operator '-' to instance of "
            "'DateFactor' with dtype 'datetime64[ns]'.\n"
            "'-' is only supported for Factors of dtype 'float64'."
        )
        self.assertEqual(ctx.exception.args[0], expected)

    def test_negate(self):
        # f is filled with 3 and g with 2 (see setUp).
        f, g = self.f, self.g
        check = self.check_constant_output

        check(-f, -3.0)
        check(--f, 3.0)
        check(---f, -3.0)

        check(-(f + f), -6.0)
        check(-f + -f, -6.0)
        check(-(-f + -f), 6.0)

        check(f + -g, 1.0)
        check(f - -g, 5.0)

        check(-(f + g) + (f + g), 0.0)
        check((f + g) + -(f + g), 0.0)
        check(-(f + g) + -(f + g), -10.0)

    def test_add(self):
        # f is filled with 3 and g with 2 (see setUp).
        f, g = self.f, self.g
        check = self.check_constant_output

        check(f + g, 5.0)

        check((1 + f) + g, 6.0)
        check(1 + (f + g), 6.0)
        check((f + 1) + g, 6.0)
        check(f + (1 + g), 6.0)
        check((f + g) + 1, 6.0)
        check(f + (g + 1), 6.0)

        check((f + f) + f, 9.0)
        check(f + (f + f), 9.0)

        check((f + g) + f, 8.0)
        check(f + (g + f), 8.0)

        check((f + g) + (f + g), 10.0)
        check((f + g) + (g + f), 10.0)
        check((g + f) + (f + g), 10.0)
        check((g + f) + (g + f), 10.0)

    def test_subtract(self):
        # f is filled with 3 and g with 2 (see setUp).
        f, g = self.f, self.g
        check = self.check_constant_output

        check(f - g, 1.0)  # 3 - 2

        check((1 - f) - g, -4.)  # (1 - 3) - 2
        check(1 - (f - g), 0.0)  # 1 - (3 - 2)
        check((f - 1) - g, 0.0)  # (3 - 1) - 2
        check(f - (1 - g), 4.0)  # 3 - (1 - 2)
        check((f - g) - 1, 0.0)  # (3 - 2) - 1
        check(f - (g - 1), 2.0)  # 3 - (2 - 1)

        check((f - f) - f, -3.)  # (3 - 3) - 3
        check(f - (f - f), 3.0)  # 3 - (3 - 3)

        check((f - g) - f, -2.)  # (3 - 2) - 3
        check(f - (g - f), 4.0)  # 3 - (2 - 3)

        check((f - g) - (f - g), 0.0)  # (3 - 2) - (3 - 2)
        check((f - g) - (g - f), 2.0)  # (3 - 2) - (2 - 3)
        check((g - f) - (f - g), -2.)  # (2 - 3) - (3 - 2)
        check((g - f) - (g - f), 0.0)  # (2 - 3) - (2 - 3)

    def test_multiply(self):
        # f is filled with 3 and g with 2 (see setUp).
        f, g = self.f, self.g
        check = self.check_constant_output

        check(f * g, 6.0)

        check((2 * f) * g, 12.0)
        check(2 * (f * g), 12.0)
        check((f * 2) * g, 12.0)
        check(f * (2 * g), 12.0)
        check((f * g) * 2, 12.0)
        check(f * (g * 2), 12.0)

        check((f * f) * f, 27.0)
        check(f * (f * f), 27.0)

        check((f * g) * f, 18.0)
        check(f * (g * f), 18.0)

        check((f * g) * (f * g), 36.0)
        check((f * g) * (g * f), 36.0)
        check((g * f) * (f * g), 36.0)
        check((g * f) * (g * f), 36.0)

        check(f * f * f * 0 * f * f, 0.0)

    def test_divide(self):
        # f is filled with 3 and g with 2 (see setUp); each expected value
        # mirrors the expression with those constants substituted.
        f, g = self.f, self.g
        check = self.check_constant_output

        check(f / g, 3.0 / 2.0)

        check((2 / f) / g, (2 / 3.0) / 2.0)
        check(2 / (f / g), 2 / (3.0 / 2.0))
        check((f / 2) / g, (3.0 / 2) / 2.0)
        check(f / (2 / g), 3.0 / (2 / 2.0))
        check((f / g) / 2, (3.0 / 2.0) / 2)
        check(f / (g / 2), 3.0 / (2.0 / 2))
        check((f / f) / f, (3.0 / 3.0) / 3.0)
        check(f / (f / f), 3.0 / (3.0 / 3.0))
        check((f / g) / f, (3.0 / 2.0) / 3.0)
        check(f / (g / f), 3.0 / (2.0 / 3.0))

        check((f / g) / (f / g), (3.0 / 2.0) / (3.0 / 2.0))
        check((f / g) / (g / f), (3.0 / 2.0) / (2.0 / 3.0))
        check((g / f) / (f / g), (2.0 / 3.0) / (3.0 / 2.0))
        check((g / f) / (g / f), (2.0 / 3.0) / (2.0 / 3.0))

    def test_pow(self):
        # f is filled with 3 and g with 2 (see setUp).
        f, g = self.f, self.g
        check = self.check_constant_output

        check(f ** g, 3.0 ** 2)
        check(2 ** f, 2.0 ** 3)
        check(f ** 2, 3.0 ** 2)

        check((f + g) ** 2, (3.0 + 2.0) ** 2)
        check(2 ** (f + g), 2 ** (3.0 + 2.0))

        check(f ** (f ** g), 3.0 ** (3.0 ** 2.0))
        check((f ** f) ** g, (3.0 ** 3.0) ** 2.0)

        check((f ** g) ** (f ** g), 9.0 ** 9.0)
        check((f ** g) ** (g ** f), 9.0 ** 8.0)
        check((g ** f) ** (f ** g), 8.0 ** 9.0)
        check((g ** f) ** (g ** f), 8.0 ** 8.0)

    def test_mod(self):
        # f is filled with 3 and g with 2 (see setUp).
        f, g = self.f, self.g
        check = self.check_constant_output

        check(f % g, 3.0 % 2.0)
        check(f % 2.0, 3.0 % 2.0)
        check(g % f, 2.0 % 3.0)

        check((f + g) % 2, (3.0 + 2.0) % 2)
        check(2 % (f + g), 2 % (3.0 + 2.0))

        check(f % (f % g), 3.0 % (3.0 % 2.0))
        check((f % f) % g, (3.0 % 3.0) % 2.0)

        check((f + g) % (f * g), 5.0 % 6.0)

    def test_math_functions(self):
        f, g = self.f, self.g
        # Bound method: it still reads self.fake_raw_data at call time, so
        # swapping the mapping inside the loop takes effect.
        check = self.check_constant_output

        default_raw_data = self.fake_raw_data
        small_raw_data = {
            self.f: full((5, 5), .5),
            self.g: full((5, 5), -.5),
        }
        # arcsin and arccos are only defined on [-1, 1] and arctanh on
        # (-1, 1), so the default inputs (3 and 2) fall outside their
        # domains; use +/- 0.5 for these instead.
        restricted_domain = ('arcsin', 'arccos', 'arctanh')

        for funcname in NUMEXPR_MATH_FUNCS:
            method = methodcaller(funcname)
            func = getattr(numpy, funcname)

            if funcname in restricted_domain:
                self.fake_raw_data = small_raw_data
            else:
                self.fake_raw_data = default_raw_data

            # Every cell holds the same value, so one cell is enough to
            # build the scalar expectation.
            f_val = self.fake_raw_data[f][0, 0]
            g_val = self.fake_raw_data[g][0, 0]

            check(method(f), func(f_val))
            check(method(g), func(g_val))

            check(method(f) + 1, func(f_val) + 1)
            check(1 + method(f), 1 + func(f_val))

            check(method(f + .25), func(f_val + .25))
            check(method(.25 + f), func(.25 + f_val))

            check(method(f) + method(g), func(f_val) + func(g_val))
            check(method(f + g), func(f_val + g_val))

    def test_comparisons(self):
        f, g, h = self.f, self.g, self.h

        # g equals f everywhere except the diagonal, where it is 1 smaller.
        f_data = arange(25, dtype=float).reshape(5, 5)
        g_data = arange(25, dtype=float).reshape(5, 5) - eye(5)
        self.fake_raw_data = {
            f: f_data,
            g: g_data,
            h: full((5, 5), 5, dtype=float),
        }

        # (lhs expression, rhs expression, lhs array, rhs array): applying
        # the same operator to the arrays gives the expected output.
        cases = [
            # Sanity Check with hand-computed values.
            (f, g, eye(5), zeros((5, 5))),
            (f, 10, f_data, 10),
            (10, f, 10, f_data),
            (f, f, f_data, f_data),
            (f + 1, f, f_data + 1, f_data),
            (1 + f, f, 1 + f_data, f_data),
            (f, g, f_data, g_data),
            (f + 1, g, f_data + 1, g_data),
            (f, g + 1, f_data, g_data + 1),
            (f + 1, g + 1, f_data + 1, g_data + 1),
            ((f + g) / 2, f ** 2, (f_data + g_data) / 2, f_data ** 2),
        ]
        for op, case in product((gt, ge, lt, le, ne), cases):
            expr_lhs, expr_rhs, expected_lhs, expected_rhs = case
            self.check_output(
                op(expr_lhs, expr_rhs),
                op(expected_lhs, expected_rhs),
            )

    def test_boolean_binops(self):
        f, g, h = self.f, self.g, self.h

        # Add a non-numexpr filter to ensure that we correctly handle
        # delegation to NumericalExpression.
        custom_filter = NonExprFilter()
        custom_filter_mask = array(
            [[0, 1, 0, 1, 0],
             [0, 0, 1, 0, 0],
             [1, 0, 0, 0, 0],
             [0, 0, 1, 1, 0],
             [0, 0, 0, 1, 0]],
            dtype=bool,
        )

        self.fake_raw_data = {
            f: arange(25, dtype=float).reshape(5, 5),
            g: arange(25, dtype=float).reshape(5, 5) - eye(5),
            h: full((5, 5), 5, dtype=float),
            custom_filter: custom_filter_mask,
        }

        # Should be True on the diagonal.
        eye_filter = (f > g)
        eye_mask = eye(5, dtype=bool)

        # Should be True in the first row only.
        first_row_filter = f < h
        first_row_mask = zeros((5, 5), dtype=bool)
        first_row_mask[0] = 1

        self.check_output(eye_filter, eye_mask)
        self.check_output(first_row_filter, first_row_mask)

        def gen_boolops(x, y, z):
            """
            Generate all possible interleavings of & and | between all
            possible orderings of x, y, and z.
            """
            for first, second, third in permutations([x, y, z]):
                yield (first & second) & third
                yield (first & second) | third
                yield (first | second) & third
                yield (first | second) | third
                yield first & (second & third)
                yield first & (second | third)
                yield first | (second & third)
                yield first | (second | third)

        # The same generator is applied to the filters and to their boolean
        # arrays, so the two streams line up pairwise.
        exprs = gen_boolops(eye_filter, custom_filter, first_row_filter)
        arrays = gen_boolops(eye_mask, custom_filter_mask, first_row_mask)

        for expr, expected in zip(exprs, arrays):
            self.check_output(expr, expected)
|