mirror of
https://github.com/saymrwulf/zipline.git
synced 2026-05-16 21:10:11 +00:00
220 lines
6.7 KiB
Python
220 lines
6.7 KiB
Python
#
|
|
# Copyright 2014 Quantopian, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from datetime import timedelta
|
|
from functools import wraps
|
|
from itertools import product
|
|
from nose_parameterized import parameterized
|
|
import operator
|
|
import random
|
|
from six import itervalues
|
|
from six.moves import map
|
|
from unittest import TestCase
|
|
|
|
import numpy as np
|
|
from numpy.testing import assert_allclose
|
|
|
|
from zipline.finance.trading import TradingEnvironment
|
|
from zipline.algorithm import TradingAlgorithm
|
|
import zipline.utils.factory as factory
|
|
from zipline.api import add_transform, get_datetime
|
|
|
|
|
|
def handle_data_wrapper(f):
    """
    Turn a test body into a ``handle_data`` callback that also maintains
    the bookkeeping the transform tests compare against.

    On every call the wrapper:

    * tracks how many bars have been seen per calendar date in
      ``context.mins_for_days`` and decrements ``context.warmup`` on the
      first bar of each new date;
    * appends one price and one volume per sid (1, 2, 3) to
      ``context.price_bars`` / ``context.vol_bars``;
    * stores ``hist[sid][0]`` of a two-bar ``'1d'`` ``close_price`` history
      in ``context.last_close_prices`` (presumably the previous session's
      close -- confirm against the history API).

    ``f`` is only invoked, and its result returned, once ``context.warmup``
    has dropped below zero; before that the wrapper returns ``None``.
    """
    @wraps(f)
    def wrapper(context, data):
        now = get_datetime()
        today = now.date()

        if today == context.current_date:
            # Still on the same date: one more bar for the current day.
            context.mins_for_days[-1] += 1
        else:
            # First bar of a new date: burn one warmup day and start a
            # fresh per-day bar counter.
            context.warmup -= 1
            context.mins_for_days.append(1)
            context.current_date = today

        closes = context.history(2, '1d', 'close_price')
        for sid in (1, 2, 3):
            if sid in data:
                bar = data[sid]
                # A bar whose timestamp is not the current one contributes
                # its price but no volume.
                volume = bar.volume if bar.dt == now else 0
                price = bar.price
            else:
                # No bar at all for this sid: record an empty observation.
                volume = 0
                price = np.nan

            context.vol_bars[sid].append(volume)
            context.price_bars[sid].append(price)
            context.last_close_prices[sid] = closes[sid][0]

        if context.warmup >= 0:
            # Still warming up; skip the test body.
            return None
        return f(context, data)

    return wrapper
|
|
|
|
|
|
def initialize_with(test_case, tfm_name, days):
    """
    Build the ``initialize`` callback for an algorithm that exercises one
    transform.

    The returned function takes the algorithm ``context`` and:

    * stores ``test_case`` and ``days`` on it;
    * resets the per-day bar counters and the current date;
    * seeds the price/volume histories, which are tuples indexed directly
      by sid (slot 0 is a ``None`` placeholder, slots 1-3 each start as
      ``[nan]``), and the last-close list (four ``nan`` entries);
    * sets ``context.warmup`` to ``days + 1`` when ``days`` is truthy and
      to ``2`` otherwise;
    * finally registers the transform via ``add_transform(tfm_name, days)``.
    """
    def initialize(context):
        context.test_case = test_case
        context.days = days

        context.current_date = None
        context.mins_for_days = []

        # Each sid needs its own list, so build them individually rather
        # than repeating a single list object.
        context.price_bars = (None,) + tuple([np.nan] for _ in range(3))
        context.vol_bars = (None,) + tuple([np.nan] for _ in range(3))
        context.last_close_prices = [np.nan] * 4

        context.warmup = days + 1 if context.days else 2

        add_transform(tfm_name, days)

    return initialize
|
|
|
|
|
|
def windows_with_frequencies(*args):
    """
    Pair each data frequency with each requested window length.

    Returns an iterator of ``(frequency, window)`` tuples covering
    ``'daily'`` and ``'minute'`` crossed with ``args``.  When no window
    lengths are given, each frequency is paired with ``None`` instead.
    """
    window_lengths = args if args else (None,)
    frequencies = ('daily', 'minute')
    return product(frequencies, window_lengths)
|
|
|
|
|
|
def with_algo(f):
    """
    Decorate a ``test_<transform>`` body so that it runs inside a
    ``TradingAlgorithm``.

    The transform name is whatever follows the ``test_`` prefix of
    ``f.__name__``.  The returned test method looks up the simulation
    parameters and trade source for ``data_frequency`` in
    ``self.sim_and_source``, builds an algorithm whose ``initialize``
    registers the transform and whose ``handle_data`` wraps ``f``, and
    runs it over the source.

    Raises ``ValueError`` if ``f`` is not named ``test_*``.
    """
    prefix = 'test_'
    test_name = f.__name__
    if not test_name.startswith(prefix):
        raise ValueError('This must decorate a test case')

    tfm_name = test_name[len(prefix):]

    @wraps(f)
    def wrapper(self, data_frequency, days=None):
        sim_params, source = self.sim_and_source[data_frequency]

        # The callbacks are built per run so every parameterized case
        # gets its own initializer and handler.
        algo_kwargs = dict(
            initialize=initialize_with(self, tfm_name, days),
            handle_data=handle_data_wrapper(f),
            sim_params=sim_params,
            env=self.env,
        )
        TradingAlgorithm(**algo_kwargs).run(source)

    return wrapper
|
|
|
|
|
|
class TransformTestCase(TestCase):
    """
    Exercises the simple transforms by running each one inside a zipline
    algorithm and comparing it with a value computed by hand.
    """

    @classmethod
    def setUpClass(cls):
        """
        Seed the RNG and build the shared fixtures: the sids, one set of
        simulation parameters and one trade source per data frequency,
        and the trading environment they run in.
        """
        random.seed(0)
        cls.sids = (1, 2, 3)

        minute_params = factory.create_simulation_parameters(
            num_days=3,
            data_frequency='minute',
            emission_rate='minute',
        )
        daily_params = factory.create_simulation_parameters(
            num_days=30,
            data_frequency='daily',
            emission_rate='daily',
        )

        cls.env = TradingEnvironment()
        cls.env.write_data(equities_identifiers=[1, 2, 3])

        minute_source = factory.create_minutely_trade_source(
            cls.sids,
            sim_params=minute_params,
            env=cls.env,
        )
        daily_source = factory.create_trade_source(
            cls.sids,
            trade_time_increment=timedelta(days=1),
            sim_params=daily_params,
            env=cls.env,
        )

        # Keyed by data frequency; each value is (sim_params, source).
        cls.sim_and_source = {
            'minute': (minute_params, minute_source),
            'daily': (daily_params, daily_source),
        }

    @classmethod
    def tearDownClass(cls):
        """
        Drop the class-level trading environment.
        """
        delattr(cls, 'env')

    def tearDown(self):
        """
        Every test consumes its trade source, so rewind all of them
        before the next test runs.
        """
        for _params, trade_source in itervalues(self.sim_and_source):
            trade_source.rewind()

    @parameterized.expand(windows_with_frequencies(1, 2, 3, 4))
    @with_algo
    def test_mavg(context, data):
        """
        Checks the mavg transform against the plain mean of the prices
        recorded naively over the same window.
        """
        # Number of recorded bars spanned by the last ``context.days`` days.
        window_bars = sum(context.mins_for_days[-context.days:])

        for sid in data:
            actual = data[sid].mavg(context.days)
            expected = np.mean(context.price_bars[sid][-window_bars:])
            assert_allclose(actual, expected)

    @parameterized.expand(windows_with_frequencies(2, 3, 4))
    @with_algo
    def test_stddev(context, data):
        """
        Checks the stddev transform against the standard deviation of
        the prices recorded naively over the same window.

        The expected value uses ``ddof=1``, i.e. the corrected estimator.
        """
        # Number of recorded bars spanned by the last ``context.days`` days.
        window_bars = sum(context.mins_for_days[-context.days:])

        for sid in data:
            actual = data[sid].stddev(context.days)
            expected = np.std(
                context.price_bars[sid][-window_bars:],
                ddof=1,
            )
            assert_allclose(actual, expected)

    @parameterized.expand(windows_with_frequencies(2, 3, 4))
    @with_algo
    def test_vwap(context, data):
        """
        Checks the vwap transform against a hand-rolled volume-weighted
        average of the prices and volumes recorded naively over the same
        window.
        """
        # Number of recorded bars spanned by the last ``context.days`` days.
        window_bars = sum(context.mins_for_days[-context.days:])

        for sid in data:
            # Missing prices were recorded as nan; nan_to_num turns them
            # into zeros so they drop out of the weighted sum.
            prices = np.nan_to_num(
                np.array(context.price_bars[sid][-window_bars:])
            )
            volumes = context.vol_bars[sid][-window_bars:]

            weighted_total = sum(
                price * volume for price, volume in zip(prices, volumes)
            )
            expected = weighted_total / sum(volumes)

            actual = data[sid].vwap(context.days)
            assert_allclose(actual, expected)

    @parameterized.expand(windows_with_frequencies())
    @with_algo
    def test_returns(context, data):
        """
        Checks the returns transform against the simple return of the
        current price over the recorded last close price.
        """
        for sid in data:
            prev_close = context.last_close_prices[sid]
            expected = (data[sid].price - prev_close) / prev_close

            assert_allclose(data[sid].returns(), expected)
|