Wiped out tests.

Stephen Diehl 2012-02-29 16:48:55 -05:00
parent a3f35444e9
commit 6edf17fb69
21 changed files with 1795 additions and 142 deletions

.gitignore (vendored, 3 lines changed)

@@ -40,3 +40,6 @@ nosetests.xml
 # Built documentation
 docs/_build/*
+# credentials and other uncheckinables
+host_settings.py

dataloader.py (new file, 39 lines)

@@ -0,0 +1,39 @@
import datetime
import sys
import zipline.util as qutil
from zipline.finance.data import DataLoader
def print_usage():
print """
Usage is:
python dataloader.py (pt | lt | lh | ld | ei | tr | bm | si | help)
pt - purge trade collection from the db
lt - load trades (minute bars) to the db
lh - load trades (hour bars) to the db
ld - load trades (daily close) to the db
ei - ensure all indexes on all collections in tick and algo db
tr - load treasury rates
bm - load benchmark data
si - load security info (sid, symbol, qualifier)
help - display this message
"""
if __name__ == "__main__":
if len(sys.argv) == 2:
qutil.configure_logging()
operation = sys.argv[1]
if(operation not in ['pt','lt','lh','ld','ei','si','tr','bm'] or operation == 'help'):
print_usage()
else:
ts = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
pidfile = "/tmp/loaddata-{stamp}.pid".format(stamp=ts)
daemon = DataLoader(pidfile,operation)
qutil.LOGGER.info("DataLoader starting.")
daemon.run()
sys.exit(0)
else:
print_usage()
sys.exit(2)
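# Example invocation (editor's note, following the usage text above):
#   python dataloader.py bm   #fetch and load the S&P 500 benchmark history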


@@ -28,17 +28,17 @@ via zeromq.
 DataSources
 --------------------
-A DataSource represents a historical event record, which will be played back during simulation. A simulation may have one or more DataSources, which will be combined in DataFeed. Generally, datasources read records from a persistent store (db, csv file, remote service), format the messages for downstream simulation components, and send them to a PUSH socket. See the base class for all datasources :py:class:`~zipline.sources.DataSource`
+A DataSource represents a historical event record, which will be played back during simulation. A simulation may have one or more DataSources, which will be combined in DataFeed. Generally, datasources read records from a persistent store (db, csv file, remote service), format the messages for downstream simulation components, and send them to a PUSH socket. See the base class for all datasources :py:class:`~zipline.messaging.DataSource` and the module holding all datasources :py:mod:`zipline.sources`
 DataFeed
 --------------------
-All simulations start with a collection of :py:class:`DataSources <zipline.sources.DataSource>`, which need to be fed to an algorithm. Each :py:class:`~zipline.sources.DataSource`can contain events of differing content (trades, quotes, corporate event) and frequency (quarterly, intraday). To simplify the process of managing the data sources, :py:class:`~zipline.core.DataFeed` can receive events from multiple :py:class:`DataSources <zipline.sources.DataSource>` and combine them into a serial chronological stream.
+All simulations start with a collection of :py:class:`DataSources <zipline.messaging.DataSource>`, which need to be fed to an algorithm. Each :py:class:`~zipline.sources.DataSource` can contain events of differing content (trades, quotes, corporate events) and frequency (quarterly, intraday). To simplify the process of managing the data sources, :py:class:`~zipline.core.DataFeed` can receive events from multiple :py:class:`DataSources <zipline.sources.DataSource>` and combine them into a serial chronological stream.
 Transforms
 --------------------
-Often, an algorithm will require a running calculation on top of a :py:class:`~zipline.sources.DataSource`, or on the consolidated feed. A simple example is a technical indicator or a moving average. Transforms can be described in :py:class:`~zipline.core.Simulator`'s configuration. Subclass :py:class:`~zipline.transforms.core.Transform` to add your own Transform. Transforms must hold their own state between events, and serialize their current values into messages.
+Often, an algorithm will require a running calculation on top of a :py:class:`~zipline.messaging.DataSource`, or on the consolidated feed. A simple example is a technical indicator such as a moving average. Transforms can be described in :py:class:`~zipline.core.Simulator`'s configuration. Subclass :py:class:`~zipline.transforms.core.Transform` to add your own Transform. Transforms must hold their own state between events, and serialize their current values into messages.
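Editor's sketch of that Transform contract (the ``transform`` hook name and the event dispatch are assumptions; only the base class path appears in this changeset)::

    import zipline.transforms.core as transforms

    class MovingAverage(transforms.Transform):
        """Hypothetical running mean of trade prices over the last `window` events."""
        def __init__(self, window):
            self.window = window
            self.prices = []

        def transform(self, event):
            #hold state between events, as required above
            self.prices.append(event.price)
            self.prices = self.prices[-self.window:]
            #the current value is what gets serialized into the outgoing message
            return sum(self.prices) / len(self.prices)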
Data Alignment

docs/zipline.finance.rst (new file, 27 lines)

@@ -0,0 +1,27 @@
finance Package
===============
:mod:`data` Module
------------------
.. automodule:: zipline.finance.data
:members:
:undoc-members:
:show-inheritance:
:mod:`risk` Module
------------------
.. automodule:: zipline.finance.risk
:members:
:undoc-members:
:show-inheritance:
:mod:`trading` Module
---------------------
.. automodule:: zipline.finance.trading
:members:
:undoc-members:
:show-inheritance:


@@ -8,6 +8,7 @@ if [ ! -d $WORKON_HOME ]; then
fi
source /usr/local/bin/virtualenvwrapper.sh
#create the scientific python virtualenv and copy to provide zipline base
mkvirtualenv --no-site-packages scientific_base
workon scientific_base
@@ -24,6 +25,9 @@ workon zipline
 # Show what we have installed
 pip freeze
+#copy the host_settings file into place
+cp /mnt/jenkins/zipline_host_settings.py ./host_settings.py
 #documentation output
 paver apidocs html


@@ -12,6 +12,6 @@ with-xunit=1
 # Drop into debugger on failure
-#pdb=0
-#pdb-failures=0
+pdb=0
+pdb-failures=0

zipline/daemon.py (new file, 143 lines)

@@ -0,0 +1,143 @@
"""
Daemon class, based on the excellent article:
http://www.jejik.com/articles/2007/02/a_simple_unix_linux_daemon_in_python/
"""
import sys, os, time, atexit, signal
from signal import SIGTERM, SIGINT
class Daemon:
"""
A generic daemon class.
Usage: subclass the Daemon class and override the run() method
"""
def __init__(self, pidfile, stdin='/dev/null', stdout='/dev/null', stderr='/dev/null'):
self.stdin = stdin
self.stdout = stdout
self.stderr = stderr
self.pidfile = pidfile
def daemonize(self):
"""
do the UNIX double-fork magic, see Stevens' "Advanced
Programming in the UNIX Environment" for details (ISBN 0201563177)
http://www.erlenstar.demon.co.uk/unix/faq_2.html#SEC16
"""
try:
pid = os.fork()
if pid > 0:
# exit first parent
sys.exit(0)
except OSError, e:
sys.stderr.write("fork #1 failed: %d (%s)\n" % (e.errno, e.strerror))
sys.exit(1)
# decouple from parent environment
os.chdir("/")
os.setsid()
os.umask(0)
# do second fork
try:
pid = os.fork()
if pid > 0:
# exit from second parent
sys.exit(0)
except OSError, e:
sys.stderr.write("fork #2 failed: %d (%s)\n" % (e.errno, e.strerror))
sys.exit(1)
# redirect standard file descriptors
sys.stdout.flush()
sys.stderr.flush()
si = file(self.stdin, 'r')
so = file(self.stdout, 'a+')
se = file(self.stderr, 'a+', 0)
os.dup2(si.fileno(), sys.stdin.fileno())
os.dup2(so.fileno(), sys.stdout.fileno())
os.dup2(se.fileno(), sys.stderr.fileno())
# write pidfile
atexit.register(self.delpid)
pid = str(os.getpid())
file(self.pidfile,'w+').write("%s\n" % pid)
def delpid(self):
os.remove(self.pidfile)
def start(self):
"""
Start the daemon
"""
# Check for a pidfile to see if the daemon already runs
try:
pf = file(self.pidfile,'r')
pid = int(pf.read().strip())
pf.close()
except IOError:
pid = None
if pid:
message = "pidfile %s already exist. Daemon already running?\n"
sys.stderr.write(message % self.pidfile)
sys.exit(1)
# Start the daemon
self.daemonize()
try:
#subclasses that need cleanup on interrupt are expected to define handle_kill
signal.signal(SIGINT, self.handle_kill)
except Exception, err:
print "Problem with SIGINT handler setup: " + str(err)
self.run()
def stop(self):
"""
Stop the daemon
"""
# Get the pid from the pidfile
try:
pf = file(self.pidfile,'r')
pid = int(pf.read().strip())
pf.close()
except IOError:
pid = None
if not pid:
message = "pidfile %s does not exist. Daemon not running?\n"
sys.stderr.write(message % self.pidfile)
return # not an error in a restart
# First signal the process that we need to interrupt, so it can do things like close child procs
try:
os.kill(pid, SIGINT)
time.sleep(2.0) #Give the process some time to kill...
except OSError, err:
print "Error trying to sigint the process" + str(err)
# Try killing the daemon process
try:
while 1:
os.kill(pid, SIGTERM)
time.sleep(0.1)
except OSError, err:
err = str(err)
if err.find("No such process") > 0:
if os.path.exists(self.pidfile):
os.remove(self.pidfile)
else:
print str(err)
sys.exit(1)
def restart(self):
"""
Restart the daemon
"""
self.stop()
self.start()
def run(self):
"""
You should override this method when you subclass Daemon. It will be called after the process has been
daemonized by start() or restart().
"""

zipline/db.py (new file, 76 lines)

@@ -0,0 +1,76 @@
import atexit
import pymongo
import zipline.util as qutil
class MongoOptions(object):
def __init__(self, host, port, dbname, user, password):
self.mongodb_host = host
self.mongodb_port = port
self.mongodb_dbname = dbname
self.mongodb_user = user
self.mongodb_password = password
class NoDatabase(Exception):
def __repr__(self):
return 'The database has not been set up yet.'
def setup_db(credentials):
"""
Setup the database. Has global side effects.
"""
qutil.LOGGER.info(dir(DbConnection))
if not DbConnection.initd:
connector = connect_db(credentials)
DbConnection.set(*connector)
def connect_db(options):
"""
Connect to pymongo, return a connection and database instance
as a tuple.
"""
connection = pymongo.Connection(options.mongodb_host, options.mongodb_port)
db = connection[options.mongodb_dbname]
db.authenticate(options.mongodb_user, options.mongodb_password)
def _gc_connection(): # pragma: no cover
connection.close()
atexit.register(_gc_connection)
return connection, db
class DbConnection(object):
"""
Hold the shared state of the database connection.
"""
initd = False
__shared = {}
def __init__(self):
self.__dict__ = self.__shared
@staticmethod
def set(conn, db):
DbConnection.__shared['conn'] = conn
DbConnection.__shared['db'] = db
DbConnection.initd = True
@staticmethod
def get():
return (
DbConnection.__shared['conn'],
DbConnection.__shared['db']
)
def __getattr__(self, key):
if not DbConnection.initd: #set() flips the class attribute; nothing stores 'initd' in __shared
raise NoDatabase()
else:
return DbConnection.__shared.get(key)
def destroy(self): # pragma: no cover
DbConnection.initd = False
self.conn.close()
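# Editor's usage sketch (host and credentials are placeholders):
#
#    options = MongoOptions('localhost', 27017, 'zipline', 'user', 'secret')
#    setup_db(options)
#    conn, db = DbConnection.get()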


zipline/finance/data.py (new file, 497 lines)

@@ -0,0 +1,497 @@
import sys
import os
import logging
import datetime
import time
import csv
import re
import copy
import pytz
import shutil
import urllib
import subprocess
import pymongo
from pymongo import ASCENDING, DESCENDING
from zipline.daemon import Daemon
import zipline.util as qutil
import zipline.db as db
import zipline.protocol as zp #namedict is used below as pymongo's as_class
import host_settings
class FinancialDataLoader():
"""
Load trade and quote data from tickdata extracts into the db.
Dates and times in the extracts must be in GMT.
All data extract files are expected to be in $HOME/fdl/. The expected directory layout is::
/benchmark.csv -- this will be created from yahoo data each time load_bench_marks is run
/interest_rates.csv --
"""
BATCH_SIZE = 100
def __init__(self):
self.conn, self.db = db.DbConnection.get()
self.data_file_path = os.environ['HOME'] + "/fdl/"
subprocess.call("mkdir {data_dir}".format(data_dir=self.data_file_path), shell=True)
self.last_bm_close = None
def load_bench_marks(self):
"""Fetches the S&P end of day pricing history from yahoo, loads it to db.bench_marks"""
start = time.time()
start_date = datetime.datetime(year=1950, month=1, day=3)
end_date = datetime.datetime.utcnow()
file_path = self.data_file_path + "benchmark.csv"
fp = open(file_path + ".tmp", "wb")
#create benchmark files
#^GSPC 19500103
query = {}
query['s'] = "^GSPC" #the s&p 500
query['d'] = end_date.month - 1 # end_date month, zero indexed
query['e'] = end_date.day # end_date day str(int(todate[6:8])) #day
query['f'] = end_date.year #end_date year str(int(todate[0:4]))
query['g'] = "d" #daily frequency
query['a'] = start_date.month - 1 #start_date month, zero indexed
query['b'] = start_date.day #start_date day
query['c'] = start_date.year #start_date year
#print query
params = urllib.urlencode(query)
params += "&ignore=.csv"
url = "http://ichart.yahoo.com/table.csv?%s" % params
qutil.LOGGER.info("fetching {url}".format(url=url))
f = urllib.urlopen(url)
fp.write(f.read())
fp.close()
qutil.LOGGER.info("fetched {url} Reversing.".format(url=url))
tmp_file = file_path + ".tmp"
reversed_tmp_file = file_path + ".rev"
rcode = subprocess.call("tac {oldfile} > {newfile}".format(oldfile=tmp_file, newfile=reversed_tmp_file), shell=True)
#on mac, there is no tac command, so use tail -r (which isn't available on debian)
if rcode != 0:
rcode = subprocess.call("tail -r {oldfile} > {newfile}".format(oldfile=tmp_file, newfile=reversed_tmp_file), shell=True)
#echo "date,open,..." > ~/fdl/benchmark.csv -- write the header row
subprocess.call("echo \"date,open,high,low,close,volume,adj_close\" > {result}".format(result=file_path), shell=True)
#sed '$d' < ~/fdl/benchmark.csv.rev >> ~/fdl/benchmark.csv -- append all rows but the trailing header of the reversed file
subprocess.call("sed '$d' < {newfile} >> {result}".format(newfile=reversed_tmp_file, result=file_path), shell=True)
#clean up working files
subprocess.call("rm {tmp} {reversed}".format(tmp=tmp_file, reversed=reversed_tmp_file), shell=True)
#load the records into mongodb
self.db.bench_marks.drop()
qutil.LOGGER.info("processing benchmark info")
self.parse_file(self.db.bench_marks,
self.bench_mark_cb,
file_path,
['date','open','high','low','close','volume','adj_close'],
None,
0)
qutil.LOGGER.info("benchmark info complete")
total = time.time() - start
qutil.LOGGER.info("%d seconds to load benchmark history" % total)
def load_treasuries(self):
"""fetches data from the treasury.gov yield curve website, and populates the treasury_curves table.
to explore data available from the treasury:
http://www.treasury.gov/resource-center/data-chart-center/interest-rates/Pages/TextView.aspx?data=yield
to fetch xml of all daily yield curves:
http://data.treasury.gov/feed.svc/DailyTreasuryYieldCurveRateData
"""
from xml.dom.minidom import parse
self.db.treasury_curves.drop()
path = os.path.join(self.data_file_path + "all_treasury_rates.xml")
#download all data to local filesystem
subprocess.call("curl http://data.treasury.gov/feed.svc/DailyTreasuryYieldCurveRateData > {path}".format(path=path), shell=True)
dom = parse(path)
entries = dom.getElementsByTagName("entry")
for entry in entries:
curve = {}
curve['tid'] = self.get_node_value(entry, "d:Id")
curve['date'] = self.get_treasury_date(self.get_node_value(entry, "d:NEW_DATE"))
curve['1month'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_1MONTH"))
curve['3month'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_3MONTH"))
curve['6month'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_6MONTH"))
curve['1year'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_1YEAR"))
curve['2year'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_2YEAR"))
curve['3year'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_3YEAR"))
curve['5year'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_5YEAR"))
curve['7year'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_7YEAR"))
curve['10year'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_10YEAR"))
curve['20year'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_20YEAR"))
curve['30year'] = self.get_treasury_rate(self.get_node_value(entry, "d:BC_30YEAR"))
self.db.treasury_curves.insert(curve, True)
def get_treasury_date(self, dstring):
return datetime.datetime.strptime(dstring.split("T")[0], '%Y-%m-%d')
def get_treasury_rate(self, string_val):
val = self.guarded_conversion(float, string_val, None)
if val != None:
val = round(val / 100.0, 4)
return val
def get_node_value(self, entry_node, tag_name):
return self.get_xml_text(entry_node.getElementsByTagName(tag_name)[0].childNodes)
def get_xml_text(self, nodelist):
rc = []
for node in nodelist:
if node.nodeType == node.TEXT_NODE:
rc.append(node.data)
return ''.join(rc)
def purge_quotes(self):
self.db.equity.quotes.drop()
def purge_trades(self):
self.db.equity.trades.drop()
def load_quotes(self):
start = time.time()
qutil.LOGGER.info("processing equity quotes")
self.load_events(self.db.equity.quotes,
self.quoteRowCallback,
self.data_file_path + "2008/Quotes/DATA",
['trade_date', 'trade_time','exchange_code','bid_price','ask_price', 'bid_size','ask_size'])
qutil.LOGGER.info("quotes complete")
total = time.time() - start
qutil.LOGGER.info("%d seconds to update equity quotes" % total)
def load_trades(self):
start = time.time()
qutil.LOGGER.info("processing equity minute bars")
self.load_events(self.db.equity.trades.minute,
self.trade_cb,
os.path.join(self.data_file_path, "2008/Trades/MINUTE_DATA"),
['trade_date','trade_time','price', 'volume'])
qutil.LOGGER.info("minute trades complete")
total = time.time() - start
qutil.LOGGER.info("%d seconds to recreate equity trades" % total)
def load_hourly_trades(self):
start = time.time()
qutil.LOGGER.info("processing equity hour bars")
self.load_events(self.db.equity.trades.hourly,
self.trade_cb,
os.path.join(self.data_file_path, "2008/Trades/HOURLY_DATA"),
['trade_date','trade_time','price','volume'])
qutil.LOGGER.info("hourly trades complete")
total = time.time() - start
qutil.LOGGER.info("%d seconds to recreate equity trades" % total)
def load_daily_close(self):
start = time.time()
qutil.LOGGER.info("processing equity daily close")
self.load_events(self.db.equity.trades.daily,
self.trade_cb,
os.path.join(self.data_file_path, "2008/Trades/DAILY_DATA"),
['trade_date','price', 'volume'])
qutil.LOGGER.info("daily close complete")
total = time.time() - start
qutil.LOGGER.info("%d seconds to recreate equity trades" % total)
def ensure_indexes(self):
#ensure indexes on minute trades
qutil.LOGGER.info("ensuring (+datetime, +sid) index on trades.minute")
self.db.equity.trades.minute.ensure_index([("dt",ASCENDING),("sid",ASCENDING)],background=True)
qutil.LOGGER.info("(+datetime, +sid) index on trades.minute ready")
#ensure indexes for hourly trades
qutil.LOGGER.info("ensuring (sid, +datetime) index on trades.hourly")
self.db.equity.trades.hourly.ensure_index([("dt",ASCENDING),("sid",ASCENDING)],background=True)
qutil.LOGGER.info("(sid, +datetime) index on trades.hourly ready")
#ensure indexes for daily trades
qutil.LOGGER.info("ensuring (+datetime,+sid) index on trades.daily")
self.db.equity.trades.daily.ensure_index([("dt",ASCENDING),("sid",ASCENDING)],background=True)
qutil.LOGGER.info("(+datetime,+sid) index on trades.daily ready")
#ensure indexes for orders and transactions
qutil.LOGGER.info("ensuring (+backtestid) index on orders")
self.db.orders.ensure_index([("back_test_run_id",ASCENDING)],background=True)
qutil.LOGGER.info("(+backtestid) index on orders ready")
qutil.LOGGER.info("ensuring (+backtestid, +datetime) index on orders")
self.db.orders.ensure_index([("back_test_run_id",ASCENDING),("dt",ASCENDING)],background=True)
qutil.LOGGER.info("(+backtestid, +datetime) index on orders ready")
qutil.LOGGER.info("ensuring (+backtestid) index on orders")
self.db.transactions.ensure_index([("back_test_run_id",ASCENDING)],background=True)
qutil.LOGGER.info("(+backtestid) index on orders ready")
qutil.LOGGER.info("ensuring (+backtestid) index on transactions")
self.db.transactions.ensure_index([("back_test_run_id",ASCENDING),("dt",ASCENDING)],background=True)
qutil.LOGGER.info("(+backtestid) index on transactions ready")
#indexes for benchmarks and treasuries
qutil.LOGGER.info("ensuring (+date) index on treasury_curves")
self.db.treasury_curves.ensure_index([("date",ASCENDING)],background=True)
qutil.LOGGER.info(" (+date) index on treasury_curves ready")
qutil.LOGGER.info("ensuring (-date) index on treasury_curves")
self.db.treasury_curves.ensure_index([("date",DESCENDING)],background=True)
qutil.LOGGER.info(" (-date) index on treasury_curves ready")
qutil.LOGGER.info("ensuring (+date) index on bench_marks")
self.db.bench_marks.ensure_index([("date",ASCENDING)],background=True)
qutil.LOGGER.info(" (+date) index on bench_marks ready")
qutil.LOGGER.info("ensuring (+symbol, +date) index on bench_marks")
self.db.bench_marks.ensure_index([("symbol",ASCENDING),("date",ASCENDING)],background=True)
qutil.LOGGER.info(" (+symbol, +date) index on bench_marks ready")
def load_security_info(self):
start = time.time()
qutil.LOGGER.info("processing company info")
sourceFile = os.path.join(self.data_file_path, "2008/Trades/MINUTE_DATA/CompanyInfo/CompanyInfo.asc")
self.db.securities.drop()
self.parse_file(self.db.securities,
self.security_cb,
sourceFile,
['symbol','file name','company name','CUSIP','exchange','industry code','first date','last date','company id'],
None,
0)
qutil.LOGGER.info("company info complete")
total = time.time() - start
qutil.LOGGER.info("%d seconds to recreate equity trades" % total)
def load_events(self, collection, rowCallBack, dataDirectory, csvFields):
id_counter = 0
listing = os.listdir(dataDirectory)
processedDir = os.path.join(dataDirectory,"processed")
if not os.path.exists(processedDir):
os.mkdir(processedDir)
for curFile in listing:
if os.path.isdir(os.path.join(dataDirectory,curFile)):
continue
start = time.time()
if id_counter == 0: #this is the first file we are processing, so we want to ensure we don't duplicate records
minDateTime = self.get_latest_entry_for_sid(self.get_sid_from_filename(curFile),collection)
else:
minDateTime = None #this isn't the first file, so don't bother querying
rowCount, totalCount = self.parse_file(collection, rowCallBack, os.path.join(dataDirectory,curFile), csvFields, minDateTime, id_counter)
id_counter = id_counter + rowCount
parseTime = time.time() - start
qutil.LOGGER.info("{time} seconds to parse and load {rowCount} records of {totalCount} from {file}. {rate} records/second".
format(time = parseTime, rowCount=rowCount, totalCount=totalCount, file=curFile, rate = rowCount/parseTime))
#we successfully processed the file without an exception, move it to the processed folder
#qutil.LOGGER.info("moving data file to {newpath}".format(newpath=os.path.join(processedDir,curFile)))
shutil.move(os.path.join(dataDirectory,curFile),os.path.join(processedDir,curFile))
def parse_file(self, collection, rowCallBack, curFile, pFieldnames, minDateTime, id_counter):
"""Parses the given file into the collection. Returns tuple of the rows committed, rows in csvfile"""
qutil.LOGGER.debug("processing {fn}".format(fn=curFile))
cur_id = id_counter
rowCount = 0
csvRowCount = 0
with open(curFile, 'rb') as f:
reader = csv.DictReader(f,fieldnames=pFieldnames)
header = False
if csv.Sniffer().has_header(f.read(1024)):
header = True
f.seek(0)
if header:
reader.next()
try:
rows = []
for row in reader:
#row['_id'] = cur_id
cur_id = cur_id + 1
csvRowCount += 1
utcDT, dt = self.get_event_datetime(row)
#only add rows that are after the mindate for the current sid.
if(minDateTime != None and dt <= minDateTime):
continue
if(dt != None):
row['dt'] = dt
if('company id' not in pFieldnames):
company_id = self.get_sid_from_filename(curFile)
if(company_id):
row['sid'] = int(company_id)
if not rowCallBack(curFile, row):
continue
rows.append(row)
rowCount+=1
if(len(rows) >= self.BATCH_SIZE):
collection.insert(rows, safe=True)
rows = []
if(len(rows) > 0):
collection.insert(rows, safe=True)
rows = None
except csv.Error, e:
sys.exit('file %s, line %d: %s' % (curFile, reader.line_num, e))
return rowCount, csvRowCount
def trade_cb(self, curFile, row):
row['price'] = self.guarded_conversion(float,row['price'])
row['volume'] = self.guarded_conversion(self.safe_int,row['volume'])
return True
def bench_mark_cb(self, curFile, row):
row['symbol'] = "GSPC"
row['volume'] = self.guarded_conversion(int,row['volume'])
row['open'] = self.guarded_conversion(float,row['open'])
row['high'] = self.guarded_conversion(float,row['high'])
row['low'] = self.guarded_conversion(float,row['low'])
row['close'] = self.guarded_conversion(float,row['close'])
row['adj_close'] = self.guarded_conversion(float,row['adj_close'])
row['date'] = datetime.datetime.strptime(row['date'], '%Y-%m-%d')
if self.last_bm_close == None:
row['returns'] = (row['close'] - row['open'])/row['open']
else:
row['returns'] = (row['close'] - self.last_bm_close) / self.last_bm_close
self.last_bm_close = row['close']
return True
def security_cb(self, curFile, row):
"""source columns: ['symbol','file name','company name','CUSIP','exchange','industry code','first date','last date','company id']"""
row['sid'] = self.guarded_conversion(int,row['company id'])
del(row['company id'])
row['start_date'] = self.guarded_conversion(self.date_conversion, row['first date'])
del(row['first date'])
row['end_date'] = self.guarded_conversion(self.date_conversion, row['last date'])
del(row['last date'])
row['symbol'] = self.verify_symbol_in_filename(row['symbol'], row['file name'])
del(row['file name'])
row['company_name'] = row['company name']
del(row['company name'])
return True
def guarded_conversion(self, conversion, strVal, default = None):
if(strVal == None or strVal == ""):
return default
return conversion(strVal)
def safe_int(self, strval):
"""casts the string to a float first, to handle the occasional decimal point in int fields from data providers."""
f = float(strval)
i = int(f)
return i
def date_conversion(self, dateStr):
dt = datetime.datetime.strptime(dateStr, '%m/%d/%Y')
dt = dt.replace (tzinfo = pytz.utc)
return dt
def verify_symbol_in_filename(self, symbol, file_name):
if(symbol == file_name):
return symbol
parts = file_name.split('_')
if(len(parts) == 2):
return file_name
else:
raise Exception("found a mismatch between symbol and filename, but no underscore.")
def get_event_datetime(self, row):
"""python 2.5 doesn't support %f for setting the microseconds, so this override is necessary.
a significant side effect - the trade date and trade time elements are removed from this dictionary. done to
avoid storing the source fields in the db.
"""
if row.has_key('trade_date') and row.has_key('trade_time'):
value = row['trade_date'] + "-" + row['trade_time']
dt = datetime.datetime.strptime(value.split(".")[0], '%m/%d/%Y-%H:%M:%S')
dt = dt.replace(microsecond=int(value.split(".")[1]+"000"))
del row['trade_date']
del row['trade_time']
elif row.has_key('trade_date'):
dt = datetime.datetime.strptime(row['trade_date'],'%m/%d/%Y')
del row['trade_date']
else:
return None, None
utcDT = dt.replace(tzinfo=pytz.utc) #extracts are in GMT per the class docstring; stands in for the undefined quantoenv.getUTCFromExchangeTime
return utcDT, dt
def get_sid_from_filename(self, filename):
regexp = r"(?P<company_id>[0-9]+)([.]csv)"
result = re.search(regexp,filename)
if(result):
companyID = int(result.group('company_id'))
return companyID
else:
return None
def get_latest_entry_for_sid(self, sid, collection):
"""checks given collection for the most recent record for the given sid."""
results = collection.find(fields=["dt"],
spec={"sid":sid},
sort=[("dt",DESCENDING)],
limit=1,
as_class=zp.namedict) #namedict implements __setitem__ for pymongo's as_class hook; replaces the undefined quantoenv.DocWrap
if(results.count() > 0):
return results[0].dt
else:
return datetime.datetime.min
class DataLoader(Daemon):
"""A daemon process that manages the data in the finance database."""
def __init__(self, pidfile, operation):
self.operation = operation
self.pidfile = pidfile
self.stdin = '/dev/null'
self.stdout = '/dev/null'
self.stderr = '/dev/null'
def run(self):
qutil.LOGGER.info("running operation: {op}".format(op=self.operation))
try:
fdl = FinancialDataLoader()
if(self.operation == 'pt'):
qutil.LOGGER.info("Purging trades from database!")
fdl.purge_trades()
elif(self.operation == 'ei'):
qutil.LOGGER.info("Ensuring indexes.")
fdl.ensure_indexes()
elif(self.operation == 'lt'):
qutil.LOGGER.info("Loading trades into database.")
fdl.load_trades()
elif(self.operation == 'lh'):
qutil.LOGGER.info("Loading trades into database.")
fdl.load_hourly_trades()
elif(self.operation == 'ld'):
qutil.LOGGER.info("Loading trades into database.")
fdl.load_daily_close()
elif(self.operation == 'si'):
qutil.LOGGER.info("Loading security info into database.")
fdl.load_security_info()
elif(self.operation == 'tr'):
qutil.LOGGER.info("Loading US Treasury rates into database.")
fdl.load_treasuries()
elif(self.operation == 'bm'):
qutil.LOGGER.info("loading benchmark data into database.")
fdl.load_bench_marks()
else:
qutil.LOGGER.warning("Unknown command for load data: {op}.".format(op=self.operation))
qutil.LOGGER.info("Finished.")
except:
qutil.LOGGER.exception("exiting load_data due to unexpected exception.")
finally:
logging.shutdown()

zipline/finance/risk.py (new file, 273 lines)

@@ -0,0 +1,273 @@
import datetime
import math
import pytz
import numpy as np
import numpy.linalg as la
import zipline.util as qutil
import zipline.db as db
import zipline.protocol as zp
from pymongo import ASCENDING, DESCENDING
class daily_return():
def __init__(self, date, returns):
self.date = date
self.returns = returns
class periodmetrics():
def __init__(self, start_date, end_date, returns, benchmark_returns):
self.db = db.DbConnection.get()[1]
self.start_date = start_date
self.end_date = end_date
self.trading_calendar = trading_calendar #module-level TradingCalendar built by get_benchmark_data() at import time
self.algorithm_period_returns, self.algorithm_returns = self.calculate_period_returns(returns)
self.benchmark_period_returns, self.benchmark_returns = self.calculate_period_returns(benchmark_returns)
if(len(self.benchmark_returns) != len(self.algorithm_returns)):
raise Exception("Mismatch between benchmark_returns ({bm_count}) and algorithm_returns ({algo_count}) in range {start} : {end}".format(
bm_count=len(self.benchmark_returns),
algo_count=len(self.algorithm_returns),
start=start_date,
end=end_date))
self.trading_days = len(self.benchmark_returns)
self.benchmark_volatility = self.calculate_volatility(self.benchmark_returns)
self.algorithm_volatility = self.calculate_volatility(self.algorithm_returns)
self.treasury_period_return = self.choose_treasury()
self.sharpe = self.calculate_sharpe()
self.beta, self.algorithm_covariance, self.benchmark_variance, self.condition_number, self.eigen_values = self.calculate_beta()
self.alpha = self.calculate_alpha()
self.excess_return = self.algorithm_period_returns - self.treasury_period_return
self.max_drawdown = self.calculate_max_drawdown()
def __repr__(self):
statements = []
for metric in ["algorithm_period_returns", "benchmark_period_returns", "excess_return", "trading_days", "benchmark_volatility", "algorithm_volatility", "sharpe", "algorithm_covariance", "benchmark_variance", "beta", "alpha", "max_drawdown", "algorithm_returns", "benchmark_returns", "condition_number", "eigen_values"]:
value = getattr(self, metric)
statements.append("{metric}:{value}".format(metric=metric, value=value))
return '\n'.join(statements)
def calculate_period_returns(self, daily_returns):
returns = [x.returns for x in daily_returns if x.date >= self.start_date and x.date <= self.end_date and self.trading_calendar.is_trading_day(x.date)]
#qutil.LOGGER.debug("using {count} daily returns out of {total}".format(count=len(returns),total=len(daily_returns)))
period_returns = 1.0
for r in returns:
period_returns = period_returns * (1.0 + r)
period_returns = period_returns - 1.0
return period_returns, returns
def calculate_volatility(self, daily_returns):
#qutil.LOGGER.debug("trading days {td}".format(td=self.trading_days))
return np.std(daily_returns, ddof=1) * math.sqrt(self.trading_days)
def calculate_sharpe(self):
return (self.algorithm_period_returns - self.treasury_period_return) / self.algorithm_volatility
def calculate_beta(self):
#qutil.LOGGER.debug("algorithm has {acount} days, benchmark has {bmcount} days".format(acount=len(self.algorithm_returns), bmcount=len(self.benchmark_returns)))
#it doesn't make much sense to calculate beta for less than two days, so return none.
if len(self.algorithm_returns) < 2:
return 0.0, 0.0, 0.0, 0.0, []
returns_matrix = np.vstack([self.algorithm_returns, self.benchmark_returns])
C = np.cov(returns_matrix)
eigen_values = la.eigvals(C)
condition_number = max(eigen_values) / min(eigen_values)
algorithm_covariance = C[0][1]
benchmark_variance = C[1][1]
beta = C[0][1] / C[1][1]
#qutil.LOGGER.debug("bm variance is {bmv}, returns matrix is {rm}, covariance is {c}, beta is {beta}".format(rm=returns_matrix, bmv=C[1][1], c=C, beta=beta))
return beta, algorithm_covariance, benchmark_variance, condition_number, eigen_values
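# For reference (editor's note), with C = cov([algorithm_returns, benchmark_returns]):
#   beta  = C[0][1] / C[1][1] = Cov(r_algo, r_bm) / Var(r_bm)
# and calculate_alpha below is the CAPM intercept:
#   alpha = R_algo - (R_f + beta * (R_bm - R_f))
# where R_f is the treasury (risk free) period return chosen by choose_treasury.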
def calculate_alpha(self):
return self.algorithm_period_returns - (self.treasury_period_return + self.beta * (self.benchmark_period_returns - self.treasury_period_return))
def calculate_max_drawdown(self):
compounded_returns = []
cur_return = 0.0
for r in self.algorithm_returns:
if(r != -1.0):
cur_return = math.log(1.0 + r) + cur_return
#this is a guard for a single day returning -100%
else:
qutil.LOGGER.warn("negative 100 percent return, zeroing the returns")
cur_return = 0.0
compounded_returns.append(cur_return)
#qutil.LOGGER.debug("compounded returns are {cr}".format(cr=compounded_returns))
cur_max = None
max_drawdown = None
for cur in compounded_returns:
if cur_max == None or cur > cur_max:
cur_max = cur
drawdown = (cur - cur_max)
if max_drawdown == None or drawdown < max_drawdown:
max_drawdown = drawdown
#qutil.LOGGER.debug("max drawdown is: {dd}".format(dd=max_drawdown))
if max_drawdown == None:
return 0.0
return 1.0 - math.exp(max_drawdown)
def choose_treasury(self):
td = self.end_date - self.start_date
if td.days <= 31:
self.treasury_duration = '1month'
elif td.days <= 93:
self.treasury_duration = '3month'
elif td.days <= 186:
self.treasury_duration = '6month'
elif td.days <= 366:
self.treasury_duration = '1year'
elif td.days <= 365 * 2 + 1:
self.treasury_duration = '2year'
elif td.days <= 365 * 3 + 1:
self.treasury_duration = '3year'
elif td.days <= 365 * 5 + 2:
self.treasury_duration = '5year'
elif td.days <= 365 * 7 + 2:
self.treasury_duration = '7year'
elif td.days <= 365 * 10 + 2:
self.treasury_duration = '10year'
else:
self.treasury_duration = '30year'
treasuryQS = self.db.treasury_curves.find(
spec={"date" : {"$lte" : self.end_date}},
sort=[("date",DESCENDING)],
limit=3,
slave_ok=True)
for curve in treasuryQS:
self.treasury_curve = curve
rate = self.treasury_curve[self.treasury_duration]
#1month note data begins in 8/2001, so we can use 3month instead.
if rate == None and self.treasury_duration == '1month':
rate = self.treasury_curve['3month']
if rate != None:
return rate * (td.days + 1) / 365
raise Exception("no rate for end date = {dt} and term = {term}, from {curve}. Using zero.".format(dt=self.end_date,
term=self.treasury_duration,
curve=self.treasury_curve['date']))
class riskmetrics():
def __init__(self, algorithm_returns):
"""algorithm_returns needs to be a list of daily_return objects sorted in date ascending order"""
self.db = db.DbConnection.get()[1]
self.algorithm_returns = algorithm_returns
self.bm_returns = [x for x in benchmark_returns if x.date >= self.algorithm_returns[0].date and x.date <= self.algorithm_returns[-1].date]
qutil.LOGGER.debug("#### {start} thru {end} with {count} trading_days of {total} possible".format(start=self.algorithm_returns[0].date,
end=self.algorithm_returns[-1].date,
count=len(self.bm_returns),
total=len(benchmark_returns)))
#calculate month ends
self.month_periods = self.periodsInRange(1, self.algorithm_returns[0].date, self.algorithm_returns[-1].date)
#calculate 3 month ends
self.three_month_periods = self.periodsInRange(3, self.algorithm_returns[0].date, self.algorithm_returns[-1].date)
#calculate 6 month ends
self.six_month_periods = self.periodsInRange(6, self.algorithm_returns[0].date, self.algorithm_returns[-1].date)
#calculate 1 year ends
self.year_periods = self.periodsInRange(12, self.algorithm_returns[0].date, self.algorithm_returns[-1].date)
#calculate 3 year ends
self.three_year_periods = self.periodsInRange(36, self.algorithm_returns[0].date, self.algorithm_returns[-1].date)
#calculate 5 year ends
self.five_year_periods = self.periodsInRange(60, self.algorithm_returns[0].date, self.algorithm_returns[-1].date)
def periodsInRange(self, months_per, start, end):
one_day = datetime.timedelta(days = 1)
ends = []
cur_start = start.replace(day=1)
#ensure that we have an end at the end of a calendar month, in case the return series ends mid-month...
the_end = advance_by_months(end.replace(day=1),1) - one_day
while True:
cur_end = advance_by_months(cur_start, months_per) - one_day
if(cur_end > the_end):
break
#qutil.LOGGER.debug("start: {start}, end: {end}".format(start=cur_start, end=cur_end))
cur_period_metrics = periodmetrics(start_date=cur_start, end_date=cur_end, returns=self.algorithm_returns, benchmark_returns=self.bm_returns)
ends.append(cur_period_metrics)
cur_start = advance_by_months(cur_start, 1)
return ends
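# Editor's illustration: periodsInRange(1, 2008-01-05, 2008-03-20) yields
# periodmetrics for 2008-01-01..01-31, 02-01..02-29 and 03-01..03-31; the final
# window is padded out to a calendar month end, as the comment above notes.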
def store_to_db(self, back_test_run_id):
col = self.db.risk_metrics
for period in self.month_periods:
for metric in ["algorithm_period_returns", "benchmark_period_returns", "excess_return", "trading_days", "benchmark_volatility", "algorithm_volatility", "sharpe", "beta", "alpha", "max_drawdown"]:
record = {'back_test_run_id':back_test_run_id}
record['ending_on'] = period.end_date
record['metric_name'] = metric
for dur in ["month", "three_month", "six_month", "year", "three_year", "five_year"]:
record[dur] = self.find_metric_by_end(period.end_date, dur, metric)
#qutil.LOGGER.debug("storing {val} for {metric} and {dur}".format(val=record[dur], metric=metric, dur=dur))
col.insert(record, safe=True)
def find_metric_by_end(self, end_date, duration, metric):
col = getattr(self, duration + "_periods")
col = [getattr(x, metric) for x in col if x.end_date == end_date]
if len(col) == 1:
return col[0]
return None
def advance_by_months(dt, jump_in_months):
month = dt.month + jump_in_months
years = month / 12
month = month % 12
#no remainder means that we are landing in december.
#modulo is, in a way, a zero indexed circular array.
#this is a way of converting to 1 indexed months. (in our modulo index, december is zeroth)
if(month == 0):
month = 12
years = years - 1
r = dt.replace(year = dt.year + years, month = month)
return r
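# Worked examples (editor's note; under python 2 integer division, 14 / 12 == 1):
#   advance_by_months(datetime.datetime(2008, 11, 15), 3) -> datetime.datetime(2009, 2, 15)
#   advance_by_months(datetime.datetime(2008, 1, 31), 11) -> datetime.datetime(2008, 12, 31)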
class TradingCalendar(object):
def __init__(self, benchmark_returns):
self.trading_days = []
self.trading_day_map = {}
for bm in benchmark_returns:
self.trading_days.append(bm.date)
self.trading_day_map[bm.date] = bm
def normalize_date(self, test_date):
return datetime.datetime(year=test_date.year, month=test_date.month, day=test_date.day, tzinfo=pytz.utc)
def is_trading_day(self, test_date):
dt = self.normalize_date(test_date)
return self.trading_day_map.has_key(dt)
def get_benchmark_daily_return(self, test_date):
date = self.normalize_date(test_date)
if self.trading_day_map.has_key(date):
return self.trading_day_map[date].returns
else:
return 0.0
def get_benchmark_data():
bmQS = db.DbConnection.get()[1].bench_marks.find(
spec={"symbol" : "GSPC"},
sort=[("date",ASCENDING)],
slave_ok=True,
as_class=zp.namedict)
bm_returns = []
for bm in bmQS:
bm_r = daily_return(date=bm.date.replace(tzinfo=pytz.utc), returns=bm.returns)
bm_returns.append(bm_r)
cal = TradingCalendar(bm_returns)
return bm_returns, cal
benchmark_returns, trading_calendar = get_benchmark_data()

zipline/finance/trading.py (new file, 155 lines)

@@ -0,0 +1,155 @@
import json
import datetime
from zmq.core.poll import select
import zipline.messaging as qmsg
import zipline.util as qutil
import zipline.protocol as zp
class TradeSimulationClient(qmsg.Component):
def __init__(self):
qmsg.Component.__init__(self)
self.received_count = 0
self.prev_dt = None
self.event_queue = []
@property
def get_id(self):
return "TRADING_CLIENT"
def open(self):
self.result_feed = self.connect_result()
self.order_socket = self.connect_order()
def do_work(self):
#next feed event
(rlist, wlist, xlist) = select([self.result_feed],
[],
[self.result_feed],
timeout=self.heartbeat_timeout/1000) #select timeout is in sec
#
#no more orders, should be an error condition
if len(rlist) == 0 or len(xlist) > 0:
raise Exception("unexpected end of feed stream")
message = rlist[0].recv()
if message == str(zp.CONTROL_PROTOCOL.DONE):
self.signal_done()
return #leave open orders hanging? client requests for orders?
event = zp.MERGE_UNFRAME(message)
self._handle_event(event)
def connect_order(self):
return self.connect_push_socket(self.addresses['order_address'])
def _handle_event(self, event):
self.event_queue.append(event)
if event.ALGO_TIME <= event.dt:
#event occurred in the present, send the queue to be processed
self.handle_events(self.event_queue)
self.order_socket.send(str(zp.CONTROL_PROTOCOL.DONE))
def handle_events(self, event_queue):
raise NotImplementedError
def order(self, sid, amount):
self.order_socket.send(zp.ORDER_FRAME(sid, amount))
class TradeSimulator(qmsg.BaseTransform):
def __init__(self, expected_orders):
qmsg.BaseTransform.__init__(self, "")
self.open_orders = {}
self.algo_time = None
self.event_start = None
self.last_event_time = None
self.last_iteration_duration = None
self.expected_orders = expected_orders
self.order_count = 0
self.trade_count = 0
@property
def get_id(self):
return "ALGO_TIME"
def open(self):
qmsg.BaseTransform.open(self)
self.order_socket = self.bind_order()
def bind_order(self):
return self.bind_pull_socket(self.addresses['order_address'])
def do_work(self):
"""
Pulls one message from the event feed, then
loops on orders until client sends DONE message.
"""
#next feed event
(rlist, wlist, xlist) = select([self.feed_socket],
[],
[self.feed_socket],
timeout=self.heartbeat_timeout/1000) #select timeout is in sec
self.trade_count += 1
#no more orders, should be an error condition
if len(rlist) == 0 or len(xlist) > 0:
raise Exception("unexpected end of feed stream")
message = rlist[0].recv()
if message == str(zp.CONTROL_PROTOCOL.DONE):
self.signal_done()
if(self.expected_orders > 0):
assert self.expected_orders == self.order_count
return #leave open orders hanging? client requests for orders?
event = zp.FEED_UNFRAME(message)
if self.last_iteration_duration != None:
self.algo_time = self.last_event_time + self.last_iteration_duration
else:
self.algo_time = event.dt #base case, first event we're transporting.
self.last_event_time = event.dt
if self.algo_time < self.last_event_time:
#compress time, move algo's clock to the time of this event
self.algo_time = self.last_event_time
#self.process_orders(event)
#mark the start time for client's processing of this event.
self.event_start = datetime.datetime.utcnow()
self.result_socket.send(zp.TRANSFORM_FRAME('ALGO_TIME', self.algo_time), self.zmq.NOBLOCK)
while True: #this loop should also poll for portfolio state req/rep
(rlist, wlist, xlist) = select([self.order_socket],
[],
[self.order_socket],
timeout=self.heartbeat_timeout/1000) #select timeout is in sec
#no more orders, should this be an error condition?
if len(rlist) == 0 or len(xlist) > 0:
continue
order_msg = rlist[0].recv()
if order_msg == str(zp.CONTROL_PROTOCOL.DONE):
qutil.LOGGER.info("order loop finished")
break
sid, amount = zp.ORDER_UNFRAME(order_msg)
self.add_open_order(sid, amount)
#end of order processing loop
self.last_iteration_duration = datetime.datetime.utcnow() - self.event_start
def add_open_order(self, sid, amount):
self.order_count = self.order_count + 1
def process_orders(self, event):
#TODO put real fill logic here, return a list of fills
return [{'sid':133, 'amount':-100}]

zipline/protocol.py

@@ -42,6 +42,10 @@ you have a strong desire to JSON encode ancient Sanskrit
 """
 import msgpack
 import numbers
+import datetime
+import pytz
+import zipline.util as qutil
 #import ujson
 #import ultrajson_numpy
@@ -65,10 +69,12 @@ def FrameExceptionFactory(name):
         def __init__(self, got):
             self.got = got
         def __str__(self):
-            return "Invalid {framcls} Frame: {got}".format(
+            return "Invalid {framecls} Frame: {got}".format(
                 framecls = name,
                 got = self.got,
             )
     return InvalidFrame
class namedict(object):
"""
@@ -81,8 +87,29 @@ class namedict(object):
     For more complex structs use collections.namedtuple:
     """
-    def __init__(self, dct):
-        self.__dict__.update(dct)
+    def __init__(self, dct=None):
+        if(dct):
+            self.__dict__.update(dct)
+    def __setitem__(self, key, value):
+        """Required for use by pymongo as_class parameter to find."""
+        if(key == '_id'):
+            self.__dict__['id'] = value
+        else:
+            self.__dict__[key] = value
+    def merge(self, other_nd):
+        assert isinstance(other_nd, namedict)
+        self.__dict__.update(other_nd.__dict__)
+    def __repr__(self):
+        return "namedict: " + str(self.__dict__)
+    def __eq__(self, other):
+        return self.__dict__ == other.__dict__
+    def has_attr(self, name):
+        return self.__dict__.has_key(name)
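# Editor's usage sketch for the additions above:
#
#    nd = namedict({'sid': 133})
#    nd['_id'] = 42                       #pymongo as_class hook; stored as nd.id
#    nd.merge(namedict({'price': 10.0}))
#    nd.has_attr('price')                 #True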
# ================
# Control Protocol
@@ -151,3 +178,258 @@ COMPONENT_STATE = Enum(
'DONE' , # 1
'EXCEPTION' , # 2
)
# ==================
# Datasource Protocol
# ==================
INVALID_DATASOURCE_FRAME = FrameExceptionFactory('DATASOURCE')
def DATASOURCE_FRAME(event):
"""
wraps any datasource payload with id and type, so that unpacking may choose the right
UNFRAME for the payload.
::ds_id:: an identifier that is unique to the datasource in the context of a component host (e.g. Simulator)
::ds_type:: a string denoting the datasource type. Must be one of::
TRADE
(others to follow soon)
::payload:: a msgpack string carrying the payload for the frame
"""
assert isinstance(event.source_id, basestring)
assert isinstance(event.type, basestring)
if(event.type == "TRADE"):
return msgpack.dumps(tuple([event.type, TRADE_FRAME(event)]))
else:
raise INVALID_DATASOURCE_FRAME(str(event))
def DATASOURCE_UNFRAME(msg):
"""
extracts payload, and calls correct UNFRAME method based on the datasource type passed along
returns a dict containing at least::
- source_id
- type
other properties are added based on the datasource type::
- TRADE::
- sid - int security identifier
- price - float
- volume - int
- dt - a datetime object
"""
try:
ds_type, payload = msgpack.loads(msg)
assert isinstance(ds_type, basestring)
if(ds_type == "TRADE"):
return TRADE_UNFRAME(payload)
else:
raise INVALID_DATASOURCE_FRAME(msg)
except TypeError:
raise INVALID_DATASOURCE_FRAME(msg)
except ValueError:
raise INVALID_DATASOURCE_FRAME(msg)
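# Editor's roundtrip sketch (values are placeholders; dt must be pytz.utc-aware,
# which PACK_DATE asserts inside TRADE_FRAME):
#
#    event = namedict({'source_id' : 'src1', 'type' : 'TRADE', 'sid' : 133,
#                      'price' : 10.0, 'volume' : 100,
#                      'dt' : datetime.datetime.now(pytz.utc)})
#    DATASOURCE_UNFRAME(DATASOURCE_FRAME(event)) #returns a namedict with dt restored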
# ==================
# Feed Protocol
# ==================
INVALID_FEED_FRAME = FrameExceptionFactory('FEED')
def FEED_FRAME(event):
"""
:event: a namedict with at least::
- source_id
- type
"""
assert isinstance(event, namedict)
source_id = event.source_id
ds_type = event.type
PACK_DATE(event)
payload = event.__dict__
return msgpack.dumps(payload)
def FEED_UNFRAME(msg):
try:
payload = msgpack.loads(msg)
#TODO: anything we can do to assert more about the content of the dict?
assert isinstance(payload, dict)
rval = namedict(payload)
UNPACK_DATE(rval)
return rval
except TypeError:
raise INVALID_FEED_FRAME(msg)
except ValueError:
raise INVALID_FEED_FRAME(msg)
# ==================
# Transform Protocol
# ==================
INVALID_TRANSFORM_FRAME = FrameExceptionFactory('TRANSFORM')
def TRANSFORM_FRAME(name, value):
"""
:event: a namedict with at least::
- source_id
- type
"""
assert isinstance(name, basestring)
assert value != None
if(name == 'ALGO_TIME'):
value = PACK_ALGO_DT(value)
return msgpack.dumps(tuple([name, value]))
def TRANSFORM_UNFRAME(msg):
"""
:rtype: namedict with <transform_name>:<transform_value>
"""
try:
name, value = msgpack.loads(msg)
#TODO: anything we can do to assert more about the content of the dict?
assert isinstance(name, basestring)
if(name == "PASSTHROUGH"):
value = FEED_UNFRAME(value)
elif(name == "ALGO_TIME"):
value = UNPACK_ALGO_DT(value)
return namedict({name : value})
except TypeError:
raise INVALID_TRANSFORM_FRAME(msg)
except ValueError:
raise INVALID_TRANSFORM_FRAME(msg)
def PACK_ALGO_DT(value):
value = namedict({'dt' : value})
PACK_DATE(value)
return value.__dict__
def UNPACK_ALGO_DT(value):
value = namedict(value)
UNPACK_DATE(value)
return value.dt
# ==================
# Merge Protocol
# ==================
INVALID_MERGE_FRAME = FrameExceptionFactory('MERGE')
def MERGE_FRAME(event):
"""
:event: a namedict with at least::
- source_id
- type
"""
assert isinstance(event, namedict)
assert isinstance(event.dt, datetime.datetime)
PACK_DATE(event)
if(event.has_attr('ALGO_TIME')):
event.ALGO_TIME = PACK_ALGO_DT(event.ALGO_TIME)
payload = event.__dict__
return msgpack.dumps(payload)
def MERGE_UNFRAME(msg):
try:
payload = msgpack.loads(msg)
#TODO: anything we can do to assert more about the content of the dict?
assert isinstance(payload, dict)
payload = namedict(payload)
if(payload.has_attr('ALGO_TIME')):
payload.ALGO_TIME = UNPACK_ALGO_DT(payload.ALGO_TIME)
assert isinstance(payload.epoch, numbers.Integral)
assert isinstance(payload.micros, numbers.Integral)
UNPACK_DATE(payload)
return payload
except TypeError:
raise INVALID_MERGE_FRAME(msg)
except ValueError:
raise INVALID_MERGE_FRAME(msg)
# ==================
# Finance Protocol
# ==================
INVALID_ORDER_FRAME = FrameExceptionFactory('ORDER')
INVALID_TRADE_FRAME = FrameExceptionFactory('TRADE')
# ==================
# Trades
# ==================
def TRADE_FRAME(event):
""":event: should be a namedict with::
- ds_id -- the datasource id sending this trade out
- sid -- the security id
- price -- float of the price printed for the trade
- volume -- int for shares in the trade
- dt -- datetime for the trade
"""
assert isinstance(event, namedict)
assert isinstance(event.source_id, basestring)
assert event.type == "TRADE"
assert isinstance(event.sid, int)
assert isinstance(event.price, float)
assert isinstance(event.volume, int)
PACK_DATE(event)
return msgpack.dumps(tuple([event.sid, event.price, event.volume, event.epoch, event.micros, event.type, event.source_id]))
def TRADE_UNFRAME(msg):
try:
sid, price, volume, epoch, micros, source_type, source_id = msgpack.loads(msg)
assert isinstance(sid, int)
assert isinstance(price, float)
assert isinstance(volume, int)
assert isinstance(epoch, numbers.Integral)
assert isinstance(micros, numbers.Integral)
rval = namedict({'sid' : sid, 'price' : price, 'volume' : volume, 'epoch' : epoch, 'micros' : micros, 'type' : source_type, 'source_id' : source_id})
UNPACK_DATE(rval)
return rval
except TypeError:
raise INVALID_TRADE_FRAME(msg)
except ValueError:
raise INVALID_TRADE_FRAME(msg)
# =========
# Orders
# =========
def ORDER_FRAME(sid, amount):
assert isinstance(sid, int)
assert isinstance(amount, int) #no partial shares...
return msgpack.dumps(tuple([sid, amount]))
def ORDER_UNFRAME(msg):
try:
sid, amount = msgpack.loads(msg)
assert isinstance(sid, int)
assert isinstance(amount, int)
return sid, amount
except TypeError:
raise INVALID_ORDER_FRAME(msg)
except ValueError:
raise INVALID_ORDER_FRAME(msg)
# =================
# Date Helpers
# =================
def PACK_DATE(event):
assert isinstance(event.dt, datetime.datetime)
assert event.dt.tzinfo == pytz.utc #utc only please
epoch = long(event.dt.strftime('%s'))
event['epoch'] = epoch
event['micros'] = event.dt.microsecond
del(event.__dict__['dt'])
return event
def UNPACK_DATE(payload):
assert isinstance(payload.epoch, numbers.Integral)
assert isinstance(payload.micros, numbers.Integral)
dt = datetime.datetime.fromtimestamp(payload.epoch)
dt = dt.replace(microsecond = payload.micros, tzinfo = pytz.utc)
del(payload.__dict__['epoch'])
del(payload.__dict__['micros'])
payload['dt'] = dt
return payload
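# Editor's note on the helpers above: PACK_DATE splits event.dt into integer
# epoch/micros fields so the event can be msgpacked; UNPACK_DATE rebuilds dt.
# Both strftime('%s') and datetime.fromtimestamp use the host's local clock, so
# packing and unpacking are only guaranteed to roundtrip on the same host.
#
#    e = PACK_DATE(namedict({'dt' : datetime.datetime(2012, 2, 29, tzinfo=pytz.utc)}))
#    UNPACK_DATE(e).dt #an equivalent utc datetime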

zipline/sources.py

@@ -3,36 +3,78 @@ Provides data handlers that can push messages to a zipline.core.DataFeed
 """
 import datetime
 import random
+import pytz
 import zipline.util as qutil
-import zipline.messaging as qmsg
+import zipline.messaging as zm
 import zipline.protocol as zp
-class RandomEquityTrades(qmsg.DataSource):
+class TradeDataSource(zm.DataSource):
+    def send(self, event):
+        """ :param dict event: is a trade event with data as per :py:func:`zipline.protocol.TRADE_FRAME`
+        :rtype: None
+        """
+        event.source_id = self.get_id
+        message = zp.DATASOURCE_FRAME(event)
+        self.data_socket.send(message)
+class RandomEquityTrades(TradeDataSource):
     """Generates a random stream of trades for testing."""
     def __init__(self, sid, source_id, count):
-        qmsg.DataSource.__init__(self, source_id)
+        zm.DataSource.__init__(self, source_id)
         self.count = count
         self.incr = 0
         self.sid = sid
-        self.trade_start = datetime.datetime.now()
+        self.trade_start = datetime.datetime.now().replace(tzinfo=pytz.utc)
         self.minute = datetime.timedelta(minutes=1)
         self.price = random.uniform(5.0, 50.0)
     def get_type(self):
         return 'equity_trade'
+    def do_work(self):
+        if(self.incr == self.count):
+            self.signal_done()
+            return
+        self.price = self.price + random.uniform(-0.05, 0.05)
+        self._send(self.sid, self.price, random.randrange(100,10000,100), self.trade_start + (self.minute * self.incr))
+        self.incr += 1
+    def _send(self, sid, price, volume, dt):
+        event = zp.namedict({'source_id': self.get_id, "type" : "TRADE", "sid":sid, "price":price, "volume":volume, "dt":dt})
+        self.send(event)
+class SpecificEquityTrades(TradeDataSource):
+    """Replays a caller-supplied list of trade events for testing."""
+    def __init__(self, source_id, event_list):
+        """
+        :event_list: should be a chronologically ordered list of dictionaries in the following form:
+        event = {
+            'sid' : an integer for security id,
+            'dt' : datetime object,
+            'price' : float for price,
+            'volume' : integer for volume
+        }
+        """
+        zm.DataSource.__init__(self, source_id)
+        self.event_list = event_list
+    def get_type(self):
+        return 'equity_trade'
     def do_work(self):
-        if(self.incr == self.count):
+        if(len(self.event_list) == 0):
             self.signal_done()
             return
-        self.price = self.price + random.uniform(-0.05, 0.05)
-        event = {
-            'sid' : self.sid,
-            'dt' : qutil.format_date(self.trade_start + (self.minute * self.incr)),
-            'price' : self.price,
-            'volume' : random.randrange(100,10000,100)
-        }
+        event = self.event_list.pop(0)
+        self.send(zp.namedict(event))
-        self.send(event)
-        self.incr += 1

zipline/test/dummy.py (new file, 0 lines)

zipline/test/factory.py (new file, 102 lines)

@@ -0,0 +1,102 @@
import datetime
import random
import pytz
import zipline.util as qutil
import zipline.finance.risk as risk
def createReturns(daycount, start):
i = 0
test_range = []
current = start.replace(tzinfo=pytz.utc)
one_day = datetime.timedelta(days = 1)
while i < daycount:
i += 1
r = risk.daily_return(current, random.random())
test_range.append(r)
current = current + one_day
return [ x for x in test_range if(risk.trading_calendar.is_trading_day(x.date)) ]
def createReturnsFromRange(start, end):
current = start.replace(tzinfo=pytz.utc)
end = end.replace(tzinfo=pytz.utc)
one_day = datetime.timedelta(days = 1)
test_range = []
i = 0
while current <= end:
current = current + one_day
if(not risk.trading_calendar.is_trading_day(current)):
continue
r = risk.daily_return(current, random.random())
i += 1
test_range.append(r)
return test_range
def createReturnsFromList(returns, start):
current = start.replace(tzinfo=pytz.utc)
one_day = datetime.timedelta(days = 1)
test_range = []
i = 0
while len(test_range) < len(returns):
if(risk.trading_calendar.is_trading_day(current)):
r = risk.daily_return(current, returns[i])
i += 1
test_range.append(r)
current = current + one_day
return test_range
def createAlgo(filename):
algo = Algorithm()
algo.code = getCodeFromFile(filename)
algo.title = filename
algo._id = pymongo.objectid.ObjectId()
hostedAlgo = HostedAlgorithm(algo)
return hostedAlgo
def getCodeFromFile(filename):
rVal = None
with open('./test/algo_samples/' + filename, 'r') as f:
rVal = f.read()
return rVal
def create_trade(sid, price, amount, dt):
row = {}
row['source_id'] = "test_factory"
row['type'] = "TRADE"
row['sid'] = sid
row['dt'] = dt
row['price'] = price
row['volume'] = amount
return row
def create_trade_history(sid, prices, amounts, start_time, interval):
i = 0
trades = []
current = start_time.replace(tzinfo = pytz.utc)
while i < len(prices):
if(risk.trading_calendar.is_trading_day(current)):
trades.append(create_trade(sid, prices[i], amounts[i], current))
current = current + interval
i += 1
else:
current = current + datetime.timedelta(days=1)
return trades
def createTxn(sid, price, amount, dt, btrid=None):
txn = Transaction(sid=sid, amount=amount, dt=dt,
price=price, transaction_cost=-1*price*amount)
return txn
def createTxnHistory(sid, priceList, amtList, startTime, interval):
i = 0
txns = []
current = startTime
while i < len(priceList):
if(risk.trading_calendar.is_trading_day(current)):
txns.append(createTxn(sid,priceList[i],amtList[i], current))
current = current + interval
i += 1
else:
current = current + datetime.timedelta(days=1)
return txns


@@ -1,87 +0,0 @@
"""
Dummy simulator backported from Qexec for development on Zipline.
"""
import threading
import mock
from unittest2 import TestCase
from zipline.test.test_messaging import SimulatorTestCase
from zipline.monitor import Controller
from zipline.messaging import ComponentHost
import zipline.util as qutil
class DummyAllocator(object):
def __init__(self, ns):
self.idx = 0
self.sockets = [
'tcp://127.0.0.1:%s' % (10000 + n)
for n in xrange(ns)
]
def lease(self, n):
sockets = self.sockets[self.idx:self.idx+n]
self.idx += n
return sockets
def reaquire(self, *conn):
pass
class SimulatorBase(ComponentHost):
"""
Simulator coordinates the launch and communication of source, feed, transform, and merge components.
"""
def __init__(self, addresses, gevent_needed=False):
"""
"""
ComponentHost.__init__(self, addresses, gevent_needed)
def simulate(self):
self.run()
def get_id(self):
return "Simulator"
class ThreadSimulator(SimulatorBase):
def __init__(self, addresses):
SimulatorBase.__init__(self, addresses)
def launch_controller(self):
thread = threading.Thread(target=self.controller.run)
thread.start()
self.cuc = thread
return thread
def launch_component(self, component):
thread = threading.Thread(target=component.run)
thread.start()
return thread
class ThreadPoolExecutor(SimulatorTestCase, TestCase):
allocator = DummyAllocator(100)
def setup_logging(self):
qutil.configure_logging()
# lazy import by design
self.logger = mock.Mock()
def setup_allocator(self):
pass
def get_simulator(self, addresses):
return ThreadSimulator(addresses)
def get_controller(self):
# Allocate two more sockets
controller_sockets = self.allocate_sockets(2)
return Controller(
controller_sockets[0],
controller_sockets[1],
logging = self.logger,
)
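DummyAllocator simply hands out consecutive localhost ports, with each lease picking up where the previous one stopped. A usage sketch matching how the test cases lease sockets:

alloc = DummyAllocator(100)
addresses = alloc.lease(6)   # tcp://127.0.0.1:10000 ... :10005
more = alloc.lease(2)        # tcp://127.0.0.1:10006, :10007
assert addresses[0] == 'tcp://127.0.0.1:10000'
assert more[0] == 'tcp://127.0.0.1:10006'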

View file

@@ -0,0 +1,119 @@
"""Tests for the zipline.finance package"""
import datetime
import mock
import pytz
import host_settings
from unittest2 import TestCase
import zipline.test.factory as factory
import zipline.util as qutil
import zipline.db as db
import zipline.finance.risk as risk
import zipline.protocol as zp
from zipline.test.client import TestTradingClient
from zipline.test.dummy import ThreadPoolExecutorMixin
from zipline.sources import SpecificEquityTrades
from zipline.finance.trading import TradeSimulator
class FinanceTestCase(ThreadPoolExecutorMixin, TestCase):
def test_trade_feed_protocol(self):
trades = factory.create_trade_history(133,
[10.0,10.0,10.0,10.0],
[100,100,100,100],
datetime.datetime.strptime("02/15/2012","%m/%d/%Y"),
datetime.timedelta(days=1))
for trade in trades:
#simulate data source sending frame
msg = zp.DATASOURCE_FRAME(zp.namedict(trade))
#feed unpacking frame
recovered_trade = zp.DATASOURCE_UNFRAME(msg)
#feed sending frame
feed_msg = zp.FEED_FRAME(recovered_trade)
#transform unframing
recovered_feed = zp.FEED_UNFRAME(feed_msg)
#do a transform
trans_msg = zp.TRANSFORM_FRAME('helloworld', 2345.6)
#simulate passthrough transform -- passthrough shouldn't even unpack the msg, just resend.
passthrough_msg = zp.TRANSFORM_FRAME('PASSTHROUGH', feed_msg)
#merge unframes transform and passthrough
trans_recovered = zp.TRANSFORM_UNFRAME(trans_msg)
pt_recovered = zp.TRANSFORM_UNFRAME(passthrough_msg)
#simulated merge
pt_recovered.PASSTHROUGH.merge(trans_recovered)
#frame the merged event
merged_msg = zp.MERGE_FRAME(pt_recovered.PASSTHROUGH)
#unframe the merge and validate values
event = zp.MERGE_UNFRAME(merged_msg)
#check the transformed value, should only be in event, not trade.
self.assertTrue(event.helloworld == 2345.6)
del(event.__dict__['helloworld'])
self.assertEqual(zp.namedict(trade), event)
def test_order_protocol(self):
order_msg = zp.ORDER_FRAME(133, 100)
sid, amount = zp.ORDER_UNFRAME(order_msg)
self.assertEqual(sid, 133)
self.assertEqual(amount, 100)
def test_trading_calendar(self):
known_trading_day = datetime.datetime.strptime("02/24/2012","%m/%d/%Y")
        known_holiday = datetime.datetime.strptime("02/20/2012", "%m/%d/%Y") # Presidents' Day
saturday = datetime.datetime.strptime("02/25/2012", "%m/%d/%Y")
self.assertTrue(risk.trading_calendar.is_trading_day(known_trading_day))
self.assertFalse(risk.trading_calendar.is_trading_day(known_holiday))
self.assertFalse(risk.trading_calendar.is_trading_day(saturday))
def test_orders(self):
# Just verify sending and receiving orders.
# --------------
# Allocate sockets for the simulator components
sockets = self.allocate_sockets(6)
addresses = {
'sync_address' : sockets[0],
'data_address' : sockets[1],
'feed_address' : sockets[2],
'merge_address' : sockets[3],
'result_address' : sockets[4],
'order_address' : sockets[5]
}
sim = self.get_simulator(addresses)
con = self.get_controller()
# Simulation Components
# ---------------------
        set1 = SpecificEquityTrades("flat-133", factory.create_trade_history(
            133,
            [10.0] * 16,
            [100] * 16,
            datetime.datetime.strptime("02/1/2012", "%m/%d/%Y"),
            datetime.timedelta(days=1)
        ))
client = TestTradingClient(10)
order_sim = TradeSimulator(expected_orders=10)
sim.register_components([client, order_sim, set1])
sim.register_controller( con )
# Simulation
# ----------
sim.simulate()
# Stop Running
# ------------
# TODO: less abrupt later, just shove a StopIteration
# down the pipe to make it stop spinning
sim.cuc._Thread__stop()
self.assertEqual(sim.feed.pending_messages(), 0,
"The feed should be drained of all messages, found {n} remaining."
.format(n=sim.feed.pending_messages())
)
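test_order_protocol above pins down the ORDER_FRAME/ORDER_UNFRAME contract without showing the wire format. One plausible minimal implementation, purely hypothetical and not the actual zipline.protocol code:

import json

def ORDER_FRAME(sid, amount):
    # serialize an order as a compact JSON pair
    return json.dumps([sid, amount])

def ORDER_UNFRAME(msg):
    sid, amount = json.loads(msg)
    return sid, amount

# the round trip exercised by test_order_protocol
assert ORDER_UNFRAME(ORDER_FRAME(133, 100)) == (133, 100)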

View file

@@ -13,6 +13,7 @@ from gevent_zeromq import zmq
ctx = zmq.Context()
#TODO: disabled by prefixing the test methods with a d
class TestControlProtocol(TestCase):
def setUpController(self):
@@ -40,7 +41,7 @@ class TestControlProtocol(TestCase):
msg.join()
self.assertEqual(msg.value, message)
    def dtest_control_message(self):
sub = self.controller.message_listener(context=ctx)
message = gevent.spawn(self.asyncMessage, sub)
@@ -55,7 +56,7 @@ class TestControlProtocol(TestCase):
sub.close()
push.close()
    def dtest_control_delivery(self):
# Assert that the number of messages sent on the wire is
# the number of messages received, ie we don't drop any.
        # This is of course dependent on the topology of the

View file

@@ -31,16 +31,16 @@ class MovingAverage(BaseTransform):
        """
        self.events.append(event)
        self.current_total += event.price
        event_date = event.dt

        # events arrive in chronological order, so drop stale events from
        # the front of the list until the remainder fit inside the window
        while self.events and (event_date - self.events[0].dt) >= self.window:
            stale = self.events.pop(0)
            self.current_total -= stale.price
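To illustrate the pruning behavior in isolation (Event is a stand-in for a namedict-style event with dt and price attributes; the 33.0 running total and two-day window are made up for the example):

import datetime
from collections import namedtuple

Event = namedtuple('Event', 'dt price')

def prune(events, current_total, event_date, window):
    # drop events older than `window` relative to the newest event
    while events and (event_date - events[0].dt) >= window:
        current_total -= events.pop(0).price
    return current_total

one_day = datetime.timedelta(days=1)
t0 = datetime.datetime(2012, 2, 1)
events = [Event(t0, 10.0), Event(t0 + one_day, 11.0), Event(t0 + 2 * one_day, 12.0)]
total = prune(events, 33.0, t0 + 2 * one_day, window=datetime.timedelta(days=2))
assert total == 23.0 and len(events) == 2   # only the oldest event fell out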

View file

@@ -6,8 +6,9 @@ and other common operations.
import datetime
import pytz
import logging
import logging.handlers
LOGGER = logging.getLogger('ZiplineLogger')
def configure_logging(loglevel=logging.DEBUG):
"""
@@ -25,27 +26,3 @@ def configure_logging(loglevel=logging.DEBUG):
)
LOGGER.addHandler(handler)
LOGGER.info("logging started...")
def parse_date(dt_str):
    """
    Parse strings according to the same format as generated by
    format_date.
    """
    if dt_str is None:
        return None
    parts = dt_str.split(".")
    dt = datetime.datetime.strptime(parts[0], '%Y/%m/%d-%H:%M:%S')
    return dt.replace(microsecond=int(parts[1] + "000"), tzinfo=pytz.utc)
def format_date(dt):
    """
    Format the date with millisecond resolution, such that
    string/alphabetical sorting is equivalent to datetime sorting.
    """
    if dt is None:
        return None
    return dt.strftime('%Y/%m/%d-%H:%M:%S') + "." + str(dt.microsecond / 1000)
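These two helpers are removed in this commit, though the sources hunk above still calls qutil.format_date. For reference, the round-trip contract they implemented (precision below one millisecond is truncated):

import datetime
import pytz

dt = datetime.datetime(2012, 2, 29, 16, 48, 55, 123000, tzinfo=pytz.utc)
s = format_date(dt)
assert s == '2012/02/29-16:48:55.123'   # lexical sort == chronological sort
assert parse_date(s) == dt              # round-trips at millisecond resolution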
return dt_str