zipline/tests/test_security_list.py
2018-05-22 10:03:36 -04:00

282 lines
9.5 KiB
Python

from datetime import timedelta
import pandas as pd
from nose_parameterized import parameterized
from zipline.algorithm import TradingAlgorithm
from zipline.errors import TradingControlViolation
from zipline.testing import (
add_security_data,
security_list_copy,
)
from zipline.testing.fixtures import (
WithMakeAlgo,
ZiplineTestCase,
)
from zipline.utils import factory
from zipline.utils.security_list import (
SecurityListSet,
load_from_directory,
)
LEVERAGED_ETFS = load_from_directory('leveraged_etf_list')
class RestrictedAlgoWithCheck(TradingAlgorithm):
def initialize(self, symbol):
self.rl = SecurityListSet(self.get_datetime, self.asset_finder)
self.set_asset_restrictions(self.rl.restrict_leveraged_etfs)
self.order_count = 0
self.sid = self.symbol(symbol)
def handle_data(self, data):
if not self.order_count:
if self.sid not in \
self.rl.leveraged_etf_list.\
current_securities(self.get_datetime()):
self.order(self.sid, 100)
self.order_count += 1
class RestrictedAlgoWithoutCheck(TradingAlgorithm):
def initialize(self, symbol):
self.rl = SecurityListSet(self.get_datetime, self.asset_finder)
self.set_asset_restrictions(self.rl.restrict_leveraged_etfs)
self.order_count = 0
self.sid = self.symbol(symbol)
def handle_data(self, data):
self.order(self.sid, 100)
self.order_count += 1
class RestrictedAlgoWithoutCheckSetDoNotOrderList(TradingAlgorithm):
def initialize(self, symbol):
self.rl = SecurityListSet(self.get_datetime, self.asset_finder)
self.set_do_not_order_list(self.rl.leveraged_etf_list)
self.order_count = 0
self.sid = self.symbol(symbol)
def handle_data(self, data):
self.order(self.sid, 100)
self.order_count += 1
class IterateRLAlgo(TradingAlgorithm):
def initialize(self, symbol):
self.rl = SecurityListSet(self.get_datetime, self.asset_finder)
self.set_asset_restrictions(self.rl.restrict_leveraged_etfs)
self.order_count = 0
self.sid = self.symbol(symbol)
self.found = False
def handle_data(self, data):
for stock in self.rl.leveraged_etf_list.\
current_securities(self.get_datetime()):
if stock == self.sid:
self.found = True
class SecurityListTestCase(WithMakeAlgo, ZiplineTestCase):
# XXX: This suite uses way more than it probably needs.
START_DATE = pd.Timestamp('2002-01-03', tz='UTC')
assert START_DATE == sorted(list(LEVERAGED_ETFS.keys()))[0], \
"START_DATE should match start of LEVERAGED_ETF data."
END_DATE = pd.Timestamp('2015-02-17', tz='utc')
extra_knowledge_date = pd.Timestamp('2015-01-27', tz='utc')
trading_day_before_first_kd = pd.Timestamp('2015-01-23', tz='utc')
SIM_PARAMS_END = pd.Timestamp('2002-01-08', tz='UTC')
SIM_PARAMS_DATA_FREQUENCY = 'daily'
DATA_PORTAL_USE_MINUTE_DATA = False
ASSET_FINDER_EQUITY_SIDS = (1, 2, 3, 4, 5)
ASSET_FINDER_EQUITY_SYMBOLS = ('AAPL', 'GOOG', 'BZQ', 'URTY', 'JFT')
def test_iterate_over_restricted_list(self):
algo = self.make_algo(
algo_class=IterateRLAlgo,
symbol='BZQ',
)
algo.run()
self.assertTrue(algo.found)
def test_security_list(self):
# set the knowledge date to the first day of the
# leveraged etf knowledge date.
def get_datetime():
return self.START_DATE
rl = SecurityListSet(get_datetime, self.asset_finder)
# assert that a sample from the leveraged list are in restricted
should_exist = [
asset.sid for asset in
[self.asset_finder.lookup_symbol(
symbol,
as_of_date=self.extra_knowledge_date)
for symbol in ["BZQ", "URTY", "JFT"]]
]
for sid in should_exist:
self.assertIn(
sid, rl.leveraged_etf_list.current_securities(get_datetime()))
# assert that a sample of allowed stocks are not in restricted
shouldnt_exist = [
asset.sid for asset in
[self.asset_finder.lookup_symbol(
symbol,
as_of_date=self.extra_knowledge_date)
for symbol in ["AAPL", "GOOG"]]
]
for sid in shouldnt_exist:
self.assertNotIn(
sid, rl.leveraged_etf_list.current_securities(get_datetime()))
def test_security_add(self):
def get_datetime():
return pd.Timestamp("2015-01-27", tz='UTC')
with security_list_copy():
add_security_data(['AAPL', 'GOOG'], [])
rl = SecurityListSet(get_datetime, self.asset_finder)
should_exist = [
asset.sid for asset in
[self.asset_finder.lookup_symbol(
symbol,
as_of_date=self.extra_knowledge_date
) for symbol in ["AAPL", "GOOG", "BZQ", "URTY"]]
]
for sid in should_exist:
self.assertIn(
sid,
rl.leveraged_etf_list.current_securities(get_datetime())
)
def test_security_add_delete(self):
with security_list_copy():
def get_datetime():
return pd.Timestamp("2015-01-27", tz='UTC')
rl = SecurityListSet(get_datetime, self.asset_finder)
self.assertNotIn(
"BZQ",
rl.leveraged_etf_list.current_securities(get_datetime())
)
self.assertNotIn(
"URTY",
rl.leveraged_etf_list.current_securities(get_datetime())
)
def test_algo_without_rl_violation_via_check(self):
self.run_algorithm(algo_class=RestrictedAlgoWithCheck, symbol="BZQ")
def test_algo_without_rl_violation(self):
self.run_algorithm(
algo_class=RestrictedAlgoWithoutCheck, symbol="AAPL",
)
@parameterized.expand([
('using_set_do_not_order_list',
RestrictedAlgoWithoutCheckSetDoNotOrderList),
('using_set_restrictions', RestrictedAlgoWithoutCheck),
])
def test_algo_with_rl_violation(self, name, algo_class):
algo = self.make_algo(algo_class=algo_class, symbol='BZQ')
with self.assertRaises(TradingControlViolation) as ctx:
algo.run()
self.check_algo_exception(algo, ctx, 0)
# repeat with a symbol from a different lookup date
algo = self.make_algo(
algo_class=RestrictedAlgoWithoutCheck, symbol='JFT',
)
with self.assertRaises(TradingControlViolation) as ctx:
algo.run()
self.check_algo_exception(algo, ctx, 0)
def test_algo_with_rl_violation_after_knowledge_date(self):
start = self.START_DATE + timedelta(days=7)
end = start + self.trading_calendar.day * 4
algo = self.make_algo(
algo_class=RestrictedAlgoWithoutCheck,
symbol='BZQ',
sim_params=self.make_simparams(
start_session=start,
end_session=end,
)
)
with self.assertRaises(TradingControlViolation) as ctx:
algo.run()
self.check_algo_exception(algo, ctx, 0)
def test_algo_with_rl_violation_cumulative(self):
"""
Add a new restriction, run a test long after both
knowledge dates, make sure stock from original restriction
set is still disallowed.
"""
sim_params = factory.create_simulation_parameters(
start=self.START_DATE + timedelta(days=7),
num_days=4
)
with security_list_copy():
add_security_data(['AAPL'], [])
algo = self.make_algo(
algo_class=RestrictedAlgoWithoutCheck,
symbol='BZQ',
sim_params=sim_params,
)
with self.assertRaises(TradingControlViolation) as ctx:
algo.run()
self.check_algo_exception(algo, ctx, 0)
def test_algo_without_rl_violation_after_delete(self):
sim_params = factory.create_simulation_parameters(
start=self.extra_knowledge_date,
num_days=4,
)
with security_list_copy():
# add a delete statement removing bzq
# write a new delete statement file to disk
add_security_data([], ['BZQ'])
algo = self.make_algo(
algo_class=RestrictedAlgoWithoutCheck,
symbol='BZQ',
sim_params=sim_params,
)
algo.run()
def test_algo_with_rl_violation_after_add(self):
sim_params = factory.create_simulation_parameters(
start=self.trading_day_before_first_kd,
num_days=4,
)
with security_list_copy():
add_security_data(['AAPL'], [])
algo = self.make_algo(
algo_class=RestrictedAlgoWithoutCheck,
symbol='AAPL',
sim_params=sim_params,
)
with self.assertRaises(TradingControlViolation) as ctx:
algo.run()
self.check_algo_exception(algo, ctx, 2)
def check_algo_exception(self, algo, ctx, expected_order_count):
self.assertEqual(algo.order_count, expected_order_count)
exc = ctx.exception
self.assertEqual(TradingControlViolation, type(exc))
exc_msg = str(ctx.exception)
self.assertTrue("RestrictedListOrder" in exc_msg)