import pandas as pd
import pandas_market_calendars as mcal
from dateutil.relativedelta import relativedelta
import datetime
from datetime import date
import numpy as np
import copy
import math
import json
from typing import Union, List, Tuple, Callable, overload
from warnings import warn
import os
from IPython.display import clear_output, display, HTML, display_html
from dataclasses import dataclass
import multiprocessing
import threading
from collections.abc import MutableSequence
import uuid
import time
import tabulate
from .data_providers import (
DailyPriceProvider,
YahooFinanceProvider,
DailyDataProvider,
PriceUnavailableError,
DataProvider,
)
from .fees import NoFee, Fee, InsufficientCapitalError
from .metrics import (
MaxDrawdown,
AnnualReturn,
PortfolioValue,
DailyProfitLoss,
TotalValue,
TotalReturn,
Metric,
)
from .strategy import Strategy, BuyAndHold
from .exceptions import BacktestRunError, LongShortLiquidationError, NegativeValueError
from .utils import is_notebook, _cls
# matplotlib is not a strict requirement, only needed for live_plot
try:
import pylab as pl
import matplotlib.pyplot as plt
import matplotlib
plt_exists = True
except ImportError:
plt_exists = False
# tqdm is not a strict requirement
try:
from tqdm import tqdm
tqdm_exists = True
except ImportError:
tqdm_exists = False
# from https://stackoverflow.com/a/44923103, https://stackoverflow.com/a/50899244
def display_side_by_side(bts):
    """Display the log tables of several backtests next to each other.

    Each backtest's ``logs`` table is rendered to HTML, captioned with the
    backtest's name and wrapped in an inline-block ``<div>``; all tables are
    then emitted with a single :func:`IPython.display.display_html` call.
    Nothing is displayed when ``bts`` is empty.

    Args:
        bts: backtests exposing ``name`` and a ``logs`` object providing
            ``_repr_html_()`` (presumably a :class:`pandas.DataFrame`).
    """
    import html  # only needed to escape user-provided captions

    parts = []
    for bt in bts:
        caption = "" if bt.name is None else html.escape(str(bt.name))
        parts.append(
            "<div style='display:inline-block; vertical-align:top; "
            "margin-right:2em'>"
            f"<b>{caption}</b>{bt.logs._repr_html_()}"
            "</div>"
        )
    if parts:
        display_html("".join(parts), raw=True)
class StrategySequence:
    """A sequence of strategies that can be accessed by name or :class:`int` index.

    Returned by :py:obj:`.Backtester.strategies` and should not be used elsewhere.

    Examples:
        Access by :class:`str`::

            bt.strategies['Some Strategy Name']

        Access by :class:`int`::

            bt.strategies[0]

        Use as iterator::

            for strategy in bt.strategies:
                # do something
    """

    def __init__(self, bt):
        self.bt = bt
        # position of the next element returned by __next__
        self.i = 0

    def __getitem__(self, index: Union[str, int]):
        """Return a backtest by position or by name.

        Unnamed backtests (``name is None``) are addressed as
        ``"Backtest {position}"``. The returned backtest is flagged with
        ``_from_sequence = True``.

        Raises:
            IndexError: if the position is out of range, no backtest has the
                requested name, or the index type is unsupported.
        """
        bts = self.bt._get_bts()
        if isinstance(index, int):
            bt = bts[index]
            bt._from_sequence = True
            return bt
        if isinstance(index, str):
            for i, bt in enumerate(bts):
                name = bt.name if bt.name is not None else f"Backtest {i}"
                if name == index:
                    bt._from_sequence = True
                    return bt
            raise IndexError(f"no strategy named {index!r}")
        raise IndexError(f"unsupported strategy index {index!r}")

    def __iter__(self):
        # Restart on every iteration so the sequence can be looped over repeatedly.
        self.i = 0
        return self

    def __next__(self):
        # Signal exhaustion with StopIteration; an IndexError would escape a for loop.
        if self.i >= len(self):
            raise StopIteration
        i = self.i
        self.i += 1
        return self[i]

    def __len__(self):
        return len(self.bt._get_bts())
class Position:
    """Tracks a single position in a portfolio or trade history.

    The sign of the share count given at construction selects the side
    (negative = short, positive = long). When a slippage fraction is given,
    the entry price is adjusted against the position and the current price
    is adjusted in the opposite direction.
    """

    def __init__(
        self,
        bt: "Backtester",
        symbol: str,
        date: datetime.date,
        event: str,
        nshares: int,
        uid: str,
        fee: float,
        slippage: float = None,
    ):
        self._bt = bt
        self._uid = uid
        self._frozen = False
        self._slippage = slippage
        self._nshares_int = nshares
        self.symbol = symbol
        self.date = date
        self.event = event
        self.fee = fee
        entry = bt.price(symbol)
        if slippage is not None:
            if nshares < 0:
                entry *= 1 + slippage
            elif nshares > 0:
                entry *= 1 - slippage
        self.start_price = entry

    def _attr(self):
        # public (non-underscore) attribute names, used for repr and DataFrames
        return [name for name in dir(self) if name[:1] != "_"]

    def __repr__(self) -> str:
        rendered = {}
        for name in self._attr():
            field = getattr(self, name)
            rendered[name] = f"{field:.2f}" if isinstance(field, float) else str(field)
        return json.dumps(rendered, sort_keys=True, indent=2)

    @property
    def _short(self) -> bool:
        """True if this is a short position."""
        return self._nshares_int < 0

    @property
    def _long(self) -> bool:
        """True if this is a long position."""
        return self._nshares_int > 0

    @property
    def value(self) -> float:
        """Returns the current market value of the position."""
        if self._long:
            return self.nshares * self.price
        if self._short:
            entry_value = self.initial_value
            return entry_value + (entry_value - self.nshares * self.price)
        return None

    @property
    def price(self) -> float:
        """Returns the current price if the position is held in a portfolio.

        Returns the last price if the position was liquidated and is part of
        a trade history.
        """
        if self._frozen:
            quote = self._bt.prices[self.symbol, self.end_date][self.end_event]
        else:
            quote = self._bt.price(self.symbol)
        if self._slippage is None:
            return quote
        if self._short:
            return quote * (1 - self._slippage)
        if self._long:
            return quote * (1 + self._slippage)
        return quote

    @property
    def value_pershare(self) -> float:
        """Returns the value of the position per share."""
        if self._long:
            return self.price
        if self._short:
            return self.start_price + (self.start_price - self.price)
        return None

    @property
    def initial_value(self) -> float:
        """Returns the initial value of the position, including fees."""
        if self._short or self._long:
            return self.nshares * self.start_price + self.fee
        return None

    @property
    def profit_loss_pct(self) -> float:
        """Returns the profit/loss associated with the position (not including commission)
        in relative terms.
        """
        return self.value / self.initial_value - 1

    @property
    def profit_loss_abs(self) -> float:
        """Returns the profit/loss associated with the position (not including commission)
        in absolute terms.
        """
        return self.value - self.initial_value

    @property
    def nshares(self) -> int:
        """Returns the number of shares in the position."""
        return abs(self._nshares_int)

    @property
    def order_type(self) -> str:
        """Returns "long" or "short" based on the position type."""
        if self._long:
            return "long"
        if self._short:
            return "short"
        return None

    def _remove_shares(self, n):
        # Two independent checks on purpose: the side is re-evaluated after
        # the first adjustment, exactly as the state is updated.
        if self._short:
            self._nshares_int += n
        if self._long:
            self._nshares_int -= n

    def _freeze(self):
        # Pin the exit date/event so `price` keeps returning the exit price.
        self._frozen = True
        self.end_date = self._bt.current_date
        self.end_event = self._bt.event
class Portfolio(MutableSequence):
    """A portfolio is a collection of :class:`.Position` objects,
    and can be used to :meth:`.liquidate` a subset of them.

    Indexing (``int``, ``slice``, ticker ``str`` or boolean mask),
    :attr:`long`, :attr:`short` and :meth:`filter` return new
    :class:`Portfolio` "views" that share the underlying
    :class:`.Position` objects. Iterating a portfolio yields its
    :class:`.Position` objects.
    """

    def __init__(self, bt, positions: Union[List["Position"], None] = None):
        # A fresh list per instance: a shared mutable default would leak
        # positions between portfolios created without an explicit list.
        self.positions = [] if positions is None else positions
        self.bt = bt

    @property
    def total_value(self) -> float:
        """Returns the total value of the portfolio."""
        return sum(pos.value for pos in self.positions)

    @property
    def df(self) -> pd.DataFrame:
        """Returns a DataFrame with one row per position and one column per
        public attribute of :class:`.Position`."""
        pos_dict = {}
        for pos in self.positions:
            for col in pos._attr():
                if col not in pos_dict:
                    pos_dict[col] = []
                pos_dict[col].append(getattr(pos, col))
        return pd.DataFrame(pos_dict)

    def _get_by_uid(self, uid) -> "Position":
        """Return the position with the given uid, or ``None`` if absent."""
        for pos in self.positions:
            if pos._uid == uid:
                return pos
        return None

    def __repr__(self) -> str:
        return self.positions.__repr__()

    def liquidate(self, nshares: int = -1, _bt: "Backtester" = None):
        """Liquidate all positions in the current "view" of the portfolio.

        If no view is given using `['some_ticker']`, :meth:`.filter`,
        :meth:`.Portfolio.long` or :meth:`.Portfolio.short`,
        an attempt to liquidate all positions is made.
        Positions of the view that are no longer held are skipped.

        Args:
            nshares:
                The number of shares to be liquidated (``-1`` means all).
                This should only be used when a ticker is selected using `['some_ticker']`.

        Raises:
            LongShortLiquidationError: if the view mixes long and short positions.
            NegativeValueError: if liquidating would make available capital negative.

        Examples:
            Select all `MSFT` positions and liquidate them::

                bt.portfolio['MSFT'].liquidate()

            Liquidate 10 `MSFT` shares::

                bt.portfolio['MSFT'].liquidate(nshares=10)

            Liquidate all long positions::

                bt.portfolio.long.liquidate()

            Liquidate all positions that have lost more than 5% in value.
            We can either use :meth:`.filter` or the dataframe as indexer
            (in this case in combination with the pf shorthand)::

                bt.pf[bt.pf.df['profit_loss_pct'] < -0.05].liquidate()
                # or
                bt.pf.filter(lambda x: x.profit_loss_pct < -0.05).liquidate()
        """
        bt = _bt
        if bt is None:
            bt = self.bt
        # Mirror the liquidation on the slippage lower-bound twin first.
        if bt._slippage is not None:
            self.liquidate(nshares, bt.lower_bound)
        is_long = any(pos._long for pos in self.positions)
        is_short = any(pos._short for pos in self.positions)
        if is_long and is_short:
            bt._graceful_stop()
            raise LongShortLiquidationError(
                "liquidating a mix of long and short positions is not possible"
            )
        for original in list(self.positions):
            # Re-resolve by uid: the view may be stale, and the lower-bound
            # twin holds different Position objects carrying the same uid.
            pos = bt.pf._get_by_uid(original._uid)
            if pos is None:
                # already liquidated (e.g. the same view liquidated twice)
                continue
            if nshares == -1 or nshares >= pos.nshares:
                bt._available_capital += pos.value
                if bt._available_capital < 0:
                    bt._graceful_stop()
                    raise NegativeValueError(
                        f"Tried to liquidate position resulting in negative capital {bt._available_capital}."
                    )
                bt.portfolio._remove(pos)
                pos._freeze()
                bt.trades._add(copy.copy(pos))
                if nshares != -1:
                    nshares -= pos.nshares
            elif nshares > 0 and nshares < pos.nshares:
                # Partial liquidation: shrink the held position and record
                # the sold part as a separate frozen trade.
                bt._available_capital += pos.value_pershare * nshares
                pos._remove_shares(nshares)
                hist = copy.copy(pos)
                hist._freeze()
                if hist._short:
                    hist._nshares_int = (-1) * nshares
                if hist._long:
                    hist._nshares_int = nshares
                bt.trades._add(hist)
                break

    def _add(self, position):
        self.positions.append(position)

    def _remove(self, position):
        self.positions = [pos for pos in self.positions if pos._uid != position._uid]

    @overload
    def __getitem__(self, index: Union[int, slice]):
        ...

    @overload
    def __getitem__(self, index: str):
        ...

    @overload
    def __getitem__(self, index: Union[np.ndarray, pd.Series, List[bool]]):
        ...

    def __getitem__(self, index):
        """Return a :class:`Portfolio` view selected by ``index``.

        Raises:
            TypeError: for unsupported index types.
        """
        if isinstance(index, slice):
            # slicing a list already produces a new list
            return Portfolio(self.bt, self.positions[index])
        if isinstance(index, int):
            # a one-element view, so e.g. ``bt.pf[0].liquidate()`` works
            return Portfolio(self.bt, [self.positions[index]])
        if isinstance(index, str):
            return Portfolio(
                self.bt, [pos for pos in self.positions if pos.symbol == index]
            )
        if isinstance(index, (np.ndarray, pd.Series, list)):
            if len(index) == 0:
                return Portfolio(self.bt, [])
            # the mask (or integer positions) is relative to *this* view
            selected = np.array(self.positions, dtype=object)[np.asarray(index)]
            return Portfolio(self.bt, list(selected))
        raise TypeError(f"unsupported portfolio index type {type(index).__name__}")

    def __setitem__(self, index, value):
        self.positions[index] = value

    def __delitem__(self, index: Union[int, slice]) -> None:
        del self.positions[index]

    def __len__(self):
        return len(self.positions)

    def __bool__(self):
        return len(self) != 0

    # The MutableSequence mixins below would otherwise go through
    # __getitem__(int), which returns views rather than Position objects.
    def __iter__(self):
        return iter(self.positions)

    def __reversed__(self):
        return reversed(self.positions)

    def index(self, value, start=0, stop=None):
        if stop is None:
            return self.positions.index(value, start)
        return self.positions.index(value, start, stop)

    def pop(self, index=-1):
        return self.positions.pop(index)

    def reverse(self):
        self.positions.reverse()

    def insert(self, index: int, value: "Position") -> None:
        self.positions.insert(index, value)

    @property
    def short(self) -> "Portfolio":
        """Returns a view of the portfolio (which can be treated as its own :class:`.Portfolio`)
        containing all *short* positions.
        """
        return Portfolio(self.bt, [pos for pos in self.positions if pos._short])

    @property
    def long(self) -> "Portfolio":
        """Returns a view of the portfolio (which can be treated as its own :class:`.Portfolio`)
        containing all *long* positions.
        """
        return Portfolio(self.bt, [pos for pos in self.positions if pos._long])

    def filter(self, func: Callable[["Position"], bool]) -> "Portfolio":
        """Filters positions using any :class:`.Callable`

        Args:
            func: The function/callable to do the filtering.
        """
        return Portfolio(self.bt, [pos for pos in self.positions if func(pos)])

    def attr(self, attribute: str) -> List:
        """Get a list of values of a certain attribute for all positions.

        Deprecated; use ``portfolio.df[attribute]`` instead.

        Args:
            attribute:
                String name of the attribute to get.
                Can be any attribute of :class:`.Position`.

        Returns:
            ``None`` for an empty portfolio, the single value for one
            position, otherwise a list of values.
        """
        self.bt._warn.append(
            f"""
            .attr will be removed in 0.7
            you can use b.portfolio.df['{attribute}']
            instead of b.portfolio.attr('{attribute}')
            """
        )
        result = [getattr(pos, attribute) for pos in self.positions]
        if len(result) == 0:
            return None
        elif len(result) == 1:
            return result[0]
        else:
            return result
class BacktesterBuilder:
    """
    Used to configure a :class:`.Backtester`
    and then creating it with :meth:`.build`

    Every configuration method works on a deep copy of the builder and
    returns that copy, so a partially configured builder can be reused
    as a template for several backtesters.

    Example:
        Create a new :class:`.Backtester` with 10,000 starting balance
        which runs on days the `NYSE`_ is open::

            bt = BacktesterBuilder().balance(10_000).calendar('NYSE').build()

    .. _NYSE:
        https://www.nyse.com/index
    """

    def __init__(self):
        # deep copy so the template shares no mutable state (such as a
        # Portfolio's default positions list) with other instances
        self.bt = copy.deepcopy(Backtester())

    def name(self, name: str) -> "BacktesterBuilder":
        """**Optional**, name will be set to "Backtest" if not specified.

        Set the name of the strategy run using the :class:`.Backtester` iterator.

        Args:
            name: The strategy name.
        """
        self = copy.deepcopy(self)
        self.bt.name = name
        return self

    def balance(self, amount: int) -> "BacktesterBuilder":
        """**Required**, set the starting balance for all :class:`.Strategy` objects
        run with the :class:`.Backtester`

        Args:
            amount: The starting balance.
        """
        self = copy.deepcopy(self)
        self.bt._capital = amount
        self.bt._available_capital = amount
        self.bt._start_capital = amount
        return self

    def prices(self, prices: DailyPriceProvider) -> "BacktesterBuilder":
        """**Optional**, set the :class:`.DailyPriceProvider` used to get prices during
        a backtest. If this is not called, :class:`.YahooFinanceProvider`
        is used.

        Args:
            prices: The price provider.
        """
        self = copy.deepcopy(self)
        self.bt.prices = prices
        self.bt.prices.bt = self.bt
        return self

    def data(self, data: DataProvider) -> "BacktesterBuilder":
        """**Optional**, add a :class:`.DataProvider` to use external data without time leaks.

        Args:
            data: The data provider.
        """
        self = copy.deepcopy(self)
        self.bt.data[data.name] = data
        data.bt = self.bt
        return self

    def trade_cost(
        self, trade_cost: Union[Fee, Callable[[float, float], Tuple[float, int]]]
    ) -> "BacktesterBuilder":
        """**Optional**, set a :class:`.Fee` to be applied when buying shares.
        When not set, :class:`.NoFee` is used.

        Args:
            trade_cost: one or more :class:`.Fee` objects or callables.
        """
        self = copy.deepcopy(self)
        self.bt._trade_cost = trade_cost
        return self

    def metrics(self, metrics: Union[Metric, List[Metric]]) -> "BacktesterBuilder":
        """**Optional**, set additional :class:`.Metric` objects to be used.

        Args:
            metrics: one or more :class:`.Metric` objects
        """
        self = copy.deepcopy(self)
        for m in metrics if isinstance(metrics, list) else [metrics]:
            m.bt = self.bt
            self.bt.metric[m.name] = m
        return self

    def clear_metrics(self) -> "BacktesterBuilder":
        """**Optional**, remove all default metrics,
        except :class:`.PortfolioValue`, which is needed internally.
        """
        self = copy.deepcopy(self)
        portfolio_value = PortfolioValue()
        portfolio_value.bt = self.bt
        self.bt.metric = {portfolio_value.name: portfolio_value}
        return self

    def calendar(self, calendar: str) -> "BacktesterBuilder":
        """**Optional**, set a `pandas market calendar`_ to be used.
        If not called, "NYSE" is used.

        Args:
            calendar: the calendar identifier

        .. _pandas market calendar:
            https://pandas-market-calendars.readthedocs.io/en/latest/calendars.html
        """
        self = copy.deepcopy(self)
        self.bt._calendar = calendar
        return self

    def live_metrics(self, every: int = 10) -> "BacktesterBuilder":
        """**Optional**, shows all metrics live in output. This can be useful
        when running simple-back from terminal.

        Args:
            every: how often metrics should be updated
                (in events, e.g. 10 = 5 days)
        """
        self = copy.deepcopy(self)
        if self.bt._live_plot:
            warn(
                """
                live plotting and metrics cannot be used together,
                setting live plotting to false
                """
            )
            self.bt._live_plot = False
        self.bt._live_metrics = True
        self.bt._live_metrics_every = every
        return self

    def no_live_metrics(self) -> "BacktesterBuilder":
        """Disables showing live metrics."""
        self = copy.deepcopy(self)
        self.bt._live_metrics = False
        return self

    def live_plot(
        self,
        every: int = None,
        metric: str = "Total Value",
        figsize: Tuple[float, float] = None,
        min_y: int = 0,
        blocking: bool = False,
    ) -> "BacktesterBuilder":
        """**Optional**, shows the backtest results live using matplotlib.
        Can only be used in notebooks unless ``blocking`` is set.

        Args:
            every:
                how often metrics should be updated
                (in events, e.g. 10 = 5 days)
                the regular default is 10,
                blocking default is 100
            metric: which metric to plot
            figsize: size of the plot
            min_y:
                minimum value on the y axis, set to `None`
                for no lower limit
            blocking:
                will disable threading for plots and
                allow live plotting in terminal,
                this will slow down the backtester
                significantly
        """
        self = copy.deepcopy(self)
        # Refuse before touching any other setting, so live metrics stay
        # enabled when no live plot will be shown.
        if not blocking and not is_notebook():
            warn(
                """
                live plots use threading which is not supported
                with matplotlib outside notebooks. to disable
                threading for live plots, you can call
                live_plot with ``blocking = True``.
                live_plot set to false.
                """
            )
            return self
        if self.bt._live_metrics:
            warn(
                """
                live metrics and plotting cannot be used together,
                setting live metrics to false
                """
            )
            self.bt._live_metrics = False
        if blocking:
            # synchronous drawing is slow, so plot less often by default
            self.bt._live_plot_blocking = True
            if every is None:
                every = 100
        elif every is None:
            every = 10
        self.bt._live_plot = True
        self.bt._live_plot_every = every
        self.bt._live_plot_metric = metric
        self.bt._live_plot_figsize = figsize
        self.bt._live_plot_min = min_y
        return self

    def no_live_plot(self) -> "BacktesterBuilder":
        """Disables showing live plots."""
        self = copy.deepcopy(self)
        self.bt._live_plot = False
        return self

    def live_progress(self, every: int = 10) -> "BacktesterBuilder":
        """**Optional**, shows a live progress bar using :class:`.tqdm`, either
        as part of a plot or as text output.

        Does nothing (with a warning) when tqdm is not installed.

        Args:
            every: how often the progress bar should be updated (in events)
        """
        self = copy.deepcopy(self)
        if not tqdm_exists:
            warn("tqdm not installed, live progress is disabled")
            return self
        self.bt._live_progress = True
        self.bt._live_progress_every = every
        return self

    def no_live_progress(self) -> "BacktesterBuilder":
        """Disables the live progress bar."""
        self = copy.deepcopy(self)
        self.bt._live_progress = False
        return self

    def compare(
        self,
        strategies: List[
            Union[Callable[["datetime.date", str, "Backtester"], None], Strategy, str]
        ],
    ):
        """**Optional**, alias for :meth:`.BacktesterBuilder.strategies`,
        should be used when comparing to :class:`.BuyAndHold` of a ticker instead of other strategies.

        Args:
            strategies:
                should be the string of the ticker to compare to,
                but :class:`.Strategy` objects can be passed as well
        """
        # strategies() already works on a deep copy
        return self.strategies(strategies)

    def strategies(
        self,
        strategies: List[
            Union[Callable[["datetime.date", str, "Backtester"], None], Strategy, str]
        ],
    ) -> "BacktesterBuilder":
        """**Optional**, sets :class:`.Strategy` objects to run.

        Args:
            strategies:
                list of :class:`.Strategy` objects or tickers to :class:`.BuyAndHold`
        """
        self = copy.deepcopy(self)
        self.bt._temp_strategies = [
            BuyAndHold(strat) if isinstance(strat, str) else strat
            for strat in strategies
        ]
        self.bt._has_strategies = True
        return self

    def slippage(self, slippage: float = 0.0005):
        """**Optional**, sets slippage which will create a (lower bound) strategy.
        The original strategies will run without slippage.

        Args:
            slippage:
                the slippage as a fraction of the base price,
                default is equivalent to quantopian default for US Equities
        """
        self = copy.deepcopy(self)
        self.bt._slippage = slippage
        return self

    def build(self) -> "Backtester":
        """Build a :class:`.Backtester` given the previous configuration."""
        self = copy.deepcopy(self)
        self.bt._builder = self
        return copy.deepcopy(self.bt)
[docs]class Backtester:
"""The :class:`.Backtester` object is yielded alongside
the current day and event (open or close)
when it is called with a date range,
which can be of the following forms.
The :class:`.Backtester` object stores information
about the backtest after it has completed.
Examples:
Initialize with dates as strings::
bt['2010-1-1','2020-1-1'].run()
# or
for day, event, b in bt['2010-1-1','2020-1-1']:
...
Initialize with dates as :class:`.datetime.date` objects::
bt[datetime.date(2010,1,1),datetime.date(2020,1,1)]
Initialize with dates as :class:`.int`::
bt[-100:] # backtest 100 days into the past
"""
def __getitem__(self, date_range: slice) -> "Backtester":
    """Select the date range to backtest and build the trading schedule.

    ``date_range.start``/``date_range.stop`` may be date strings,
    :class:`datetime.date` objects, :class:`~dateutil.relativedelta.relativedelta`
    offsets from today, or ``int`` offsets in days from today
    (``bt[-100:]`` backtests the last 100 days). A missing stop defaults
    to yesterday (with a warning). When no calendar was configured, the
    NYSE calendar is assumed.

    Raises:
        TypeError: if ``date_range`` is not a slice.
        BacktestRunError: if this backtester has already been run.
        ValueError: if no start is given.
    """
    if not isinstance(date_range, slice):
        raise TypeError("select dates using bt[start_date:end_date]")
    if self._run_before:
        raise BacktestRunError(
            "Backtest has already run, build a new backtester to run again."
        )
    if date_range.start is None:
        raise ValueError("a date range without a start value is not allowed")
    # Only mark as run once the input is known to be valid.
    self._run_before = True
    # Fall back to NYSE when no calendar was configured; _init_iter warns
    # about this assumption when assume_nyse is set.
    if self.assume_nyse or getattr(self, "_calendar", None) is None:
        self.assume_nyse = True
        self._calendar = "NYSE"

    today = datetime.date.today()

    def _resolve(value):
        # relativedelta and int (days) are offsets from today
        if isinstance(value, relativedelta):
            return today + value
        if isinstance(value, int) and not isinstance(value, bool):
            return today + relativedelta(days=value)
        return value

    start_date = _resolve(date_range.start)
    if date_range.stop is None:
        self._warn.append(
            "backtests with no end date can lead to non-replicable results"
        )
        end_date = today - relativedelta(days=1)
    else:
        end_date = _resolve(date_range.stop)

    cal = mcal.get_calendar(self._calendar)
    sched = cal.schedule(start_date=start_date, end_date=end_date)
    self._schedule = sched
    self.dates = [d.date() for d in mcal.date_range(sched, frequency="1D")]
    # two timestamps per trading day: market open and market close
    self.datetimes = []
    for day in self.dates:
        row = sched.loc[day.isoformat()]
        self.datetimes += [row["market_open"], row["market_close"]]
    if self._has_strategies:
        self._set_strategies(self._temp_strategies)
    return self
def _init_slippage(self, bt=None):
    """Create the slippage "lower bound" twin of a backtest.

    The twin is a deep copy of ``bt`` (default ``self``) without
    sub-strategies or its own slippage, whose orders are priced with
    ``-self._slippage`` as slippage fraction. It is stored as
    ``bt.lower_bound`` and appended to ``self._strategies``.
    """
    target = self if bt is None else bt
    twin = copy.deepcopy(target)
    twin._strategies = []
    twin._set_self()
    twin.name = twin.name + " (lower bound)"
    twin._has_strategies = False
    twin._slippage_percent = -self._slippage
    twin._slippage = None
    twin._init_iter(twin)
    target.lower_bound = twin
    self._strategies.append(twin)
def __init__(self):
    """Create an unconfigured backtester with default components.

    Prices come from :class:`.YahooFinanceProvider`, trades are free
    (:class:`.NoFee`) and the default metrics are registered. Starting
    capital, dates and calendar are supplied later (normally via
    :class:`.BacktesterBuilder` and ``bt[start:end]``).
    """
    self.dates = []
    self.assume_nyse = False
    self.prices = YahooFinanceProvider()
    self.prices.bt = self
    # Explicit fresh lists: the portfolios must not share a default list, and
    # ``trades`` must reference this backtester itself (deep-copying here would
    # point ``trades.bt`` at a detached clone of the half-built object).
    self.portfolio = Portfolio(self, [])
    self.trades = Portfolio(self, [])
    self._trade_cost = NoFee()
    self.metric = {}
    for m in (
        MaxDrawdown(),
        AnnualReturn(),
        PortfolioValue(),
        TotalValue(),
        TotalReturn(),
        DailyProfitLoss(),
    ):
        m.bt = self
        self.metric[m.name] = m
    self.data = {}
    # capital bookkeeping, set by BacktesterBuilder.balance
    self._start_capital = None
    self._available_capital = None
    self._capital = None
    # live output settings
    self._live_plot = False
    self._live_plot_metric = "Total Value"
    self._live_plot_figsize = None
    self._live_plot_min = None
    self._live_plot_axes = None
    self._live_plot_blocking = False
    self._live_metrics = False
    self._live_progress = False
    # sub-strategies
    self._strategies = []
    self._temp_strategies = []
    self._has_strategies = False
    self.name = "Backtest"
    self._no_iter = False
    self._schedule = None
    self._warn = []
    self._log = []
    self._add_metrics = {}
    self._add_metrics_lines = []
    self.datetimes = None
    self.add_metric_exists = False
    self._run_before = False
    self._last_thread = None
    self._from_sequence = False
    self._slippage = None
    self._slippage_percent = None
def _set_self(self, new_self=None):
    """Re-point the owned components at a backtester and reset its metrics.

    ``new_self`` (default: ``self``) becomes the ``bt`` of its portfolio,
    trade history and price provider; each of its metrics is re-initialised
    in place and attached to it.
    """
    target = self if new_self is None else new_self
    for component in (target.portfolio, target.trades, target.prices):
        component.bt = target
    for metric in target.metric.values():
        metric.__init__()
        metric.bt = target
def _init_iter(self, bt=None):
    """Validate the configuration and reset iteration state of ``bt``.

    Sets ``bt.i = -1`` and ``bt.event = "close"`` so the first step moves to
    the first trading day's open. Starts the live progress bar when enabled
    and tqdm is installed (otherwise live progress is switched off with a
    warning), and creates the slippage lower-bound twin when configured.

    Raises:
        ValueError: if no balance or no dates were configured.
    """
    global _live_progress_pbar
    if bt is None:
        bt = self
    if bt.assume_nyse:
        self._warn.append("no market calendar specified, assuming NYSE calendar")
    if bt._available_capital is None or bt._capital is None:
        raise ValueError(
            "initial balance not specified, you can do so using .balance"
        )
    if bt.dates is None or len(bt.dates) == 0:
        raise ValueError(
            "no dates selected, you can select dates using [start_date:end_date]"
        )
    bt.i = -1
    bt.event = "close"
    if self._live_progress:
        # tqdm is an optional dependency; don't crash when it is missing
        if tqdm_exists:
            _live_progress_pbar = tqdm(total=len(self))
            _cls()
        else:
            self._warn.append("tqdm not installed, setting live progress to false")
            self._live_progress = False
    if self._slippage is not None and not self._no_iter:
        self._init_slippage(self)
        self._has_strategies = True
    return self
def _next_iter(self, bt=None):
    """Advance ``bt`` (default: ``self``) by one event.

    Events alternate ``close`` -> ``open`` (next trading day) -> ``close``;
    ``bt.i`` counts events, so the current trading day is
    ``bt.dates[bt.i // 2]``. Once the dates are exhausted, single-valued
    metrics of ``bt`` and of all sub-strategies are written, a final plot is
    drawn and ``StopIteration`` is raised.

    Returns:
        ``(current_date, event, bt)``, with ``current_date`` as an ISO string.
    """
    if bt is None:
        bt = self
    # NOTE(review): _init_iter starts i at -1 and the close branch below
    # rolls i back on exhaustion, so i never seems to reach len(self) in
    # this flow — confirm when this restart branch is expected to fire.
    if bt.i == len(self):
        bt._init_iter()
    if bt.event == "open":
        # same trading day: open -> close
        bt.event = "close"
        bt.i += 1
    elif bt.event == "close":
        try:
            bt.i += 1
            # IndexError here means the last trading day has been processed
            bt.current_date = bt.dates[bt.i // 2].isoformat()
            bt.event = "open"
        except IndexError:
            bt.i -= 1  # undo the advance so i points at the last event
            for metric in bt.metric.values():
                if metric._single:
                    metric(write=True)
            if self._has_strategies:
                for strat in self._strategies:
                    for metric in strat.metric.values():
                        if metric._single:
                            metric(write=True)
            self._plot(self._get_bts(), last=True)
            raise StopIteration
    bt._update()
    return bt.current_date, bt.event, bt
def add_metric(self, key: str, value: float):
    """Called inside the backtest, adds a metric that is visually tracked.

    The first call for a key allocates a value array (NaN-filled) and a
    mask array (all ``True``, i.e. hidden) with one slot per event; each
    call stores ``value`` in slot ``self.i`` and unmasks it.

    Args:
        key: the metric name
        value: the numerical value of the metric
    """
    if key not in self._add_metrics:
        n_events = len(self)
        self._add_metrics[key] = (
            np.repeat(np.nan, n_events),
            np.repeat(True, n_events),
        )
    values, hidden = self._add_metrics[key]
    values[self.i] = value
    hidden[self.i] = False
def add_line(self, **kwargs):
    """Adds a vertical line on the plot on the current date + event.

    ``kwargs`` are stored with the current :attr:`timestamp`; they are
    later forwarded to ``plt.axvline``.
    """
    entry = (self.timestamp, kwargs)
    self._add_metrics_lines.append(entry)
def log(self, text: str):
    """Adds a log text on the current day and event that can be accessed using :obj:`.logs`
    after the backtest has completed.

    Args:
        text: text to log
    """
    entry = [self.current_date, self.event, text]
    self._log.append(entry)
@property
def timestamp(self):
    """Returns the current timestamp, which includes the correct open/close time,
    depending on the calendar that was set using :meth:`.BacktesterBuilder.calendar`

    Looked up in the ``market_open``/``market_close`` column of the
    schedule row for the current date.
    """
    column = "market_" + self.event
    row = self._schedule.loc[self.current_date]
    return row[column]
def __iter__(self):
    # iteration state is (re)initialised by _init_iter, which returns self
    iterator = self._init_iter()
    return iterator
def __next__(self):
    """Advance one event, run sub-strategies if any, and refresh live output."""
    result = self._next_iter()
    if self._strategies:
        self._run_once()
        shown = [self] + self._strategies
    else:
        shown = [self]
    self._plot(shown)
    return result
def __len__(self):
    # two events (open and close) per trading day
    return 2 * len(self.dates)
def _show_live_metrics(self, bts=None):
    """Clear the console and print a table of all metrics.

    There is one column per backtest in ``bts``; ``self`` is prepended when
    iterating and not already included. There is one row per metric key of
    ``self.metric``, and every cell is padded to 20 characters. When ``bts``
    is ``None``, only ``self`` is shown and no header row is printed. The
    live progress bar is appended when enabled.
    """
    _cls()
    rows = []
    if bts is None:
        bts = [self]
    else:
        if not self._no_iter and self not in bts:
            bts = [self] + bts
        headers = []
        for idx, run in enumerate(bts):
            label = f"Backtest {idx}" if run.name is None else run.name
            label = f"{label:20}"
            if len(label) > 20:
                # keep columns aligned: 18 characters plus ".."
                label = label[:18] + ".."
            headers.append(label)
        rows.append(f"{'':20} {''.join(headers)}")
    for mkey in self.metric:
        cells = []
        for run in bts:
            value = run.metric[mkey]
            # presumably str() == "None" means no cached value, so the
            # metric is computed by calling it — TODO confirm against Metric
            if str(value) == "None":
                value = f"{value():.2f}"
            cells.append(f"{str(value):20}")
        rows.append(f"{mkey:20} {''.join(cells)}")
    for row in rows:
        print(row)
    if self._live_progress:
        print()
        print(self._show_live_progress())
def _show_live_plot(self, bts=None, start_end=None):
    """Draw the live plot of ``self._live_plot_metric`` for ``bts``.

    Backtests without slippage are drawn as lines. When slippage is
    configured, the area between each line and its "(lower bound)" twin is
    shaded. Values recorded with :meth:`add_metric` go into a second subplot
    together with any :meth:`add_line` markers. Logs are shown when present,
    and queued warnings are emitted.

    Args:
        bts: backtests to draw (default: ``[self]``).
        start_end: optional ``(start, end)`` x-axis limits; defaults to the
            first and last backtest date.
    """
    if not plt_exists:
        self._warn.append(
            "matplotlib not installed, setting live plotting to false"
        )
        self._live_plot = False
        return None
    if bts is None:
        bts = [self]
    plot_df = pd.DataFrame()
    plot_df["Date"] = self.datetimes
    plot_df = plot_df.set_index("Date")
    plot_add_df = plot_df.copy()
    main_col = []
    for i, bt in enumerate(bts):
        name = bt.name if bt.name is not None else f"Backtest {i}"
        # lower-bound (slippage) runs are only used for the shaded band
        if bt._slippage_percent is None:
            main_col.append(name)
        plot_df[name] = bt.metric[self._live_plot_metric].values
        for mkey, (values, hidden) in bt._add_metrics.items():
            plot_add_df[mkey] = np.ma.masked_where(hidden, values)
            self.add_metric_exists = True
    if self._live_plot_figsize is None:
        # taller figure when the second (add_metric) subplot is shown
        self._live_plot_figsize = (10, 13) if self.add_metric_exists else (10, 6.5)
    if self.add_metric_exists:
        fig, axes = plt.subplots(
            2, 1, sharex=True, figsize=self._live_plot_figsize, num=0
        )
    else:
        fig, axes = plt.subplots(
            1, 1, sharex=True, figsize=self._live_plot_figsize, num=0
        )
        axes = [axes]
    if self._live_progress:
        axes[0].set_title(str(self._show_live_progress()))
    plot_df[main_col].plot(ax=axes[0])
    try:
        if self._slippage is not None:
            for col in main_col:
                axes[0].fill_between(
                    plot_df.index,
                    plot_df[f"{col} (lower bound)"],
                    plot_df[f"{col}"],
                    alpha=0.1,
                )
    except KeyError:
        # a lower-bound column may not exist (yet) for every backtest
        pass
    if self._live_plot_min is not None:
        axes[0].set_ylim(bottom=self._live_plot_min)
    plt.tight_layout()
    if self.add_metric_exists:
        try:
            interp_df = plot_add_df.interpolate(method="linear")
            interp_df.plot(ax=axes[1], cmap="Accent")
            for bt in bts:
                for line in bt._add_metrics_lines:
                    plt.axvline(line[0], **line[1])
        except TypeError:
            pass
    fig.autofmt_xdate()
    if start_end is not None:
        plt.xlim([start_end[0], start_end[1]])
    else:
        plt.xlim([self.dates[0], self.dates[-1]])
    clear_output(wait=True)
    plt.draw()
    plt.pause(0.001)
    if self._live_plot_blocking:
        plt.clf()  # needed to prevent overlapping tick labels
    if any(len(bt._log) > 0 for bt in bts):
        display_side_by_side(bts)
    for w in self._warn:
        warn(w)
def _show_live_progress(self):
    """Sync the module-level progress bar with the current event and return it."""
    bar = _live_progress_pbar
    bar.n = self.i + 1
    return bar
def _update(self):
    """Write all series metrics for the current event and refresh ``_capital``.

    If a metric raises ``PriceUnavailableError``, the symbol is treated as
    discontinued. The backtest is temporarily rewound to the previous
    trading day, keeping the current event. The symbol's positions are
    liquidated at that day's price and the metric is written while rewound.
    ``i`` and ``current_date`` are then restored.
    """
    for metric in self.metric.values():
        if metric._series:
            try:
                metric(write=True)
            except PriceUnavailableError as e:
                # Rewind to the previous trading day (i // 2 drops by one).
                # NOTE(review): on the first trading day this makes i
                # negative, and dates[i // 2] then selects the LAST date
                # (a look-ahead) — consider guarding i < 0.
                if self.event == "close":
                    self.i -= 2
                if self.event == "open":
                    self.i -= 1
                self._warn.append(
                    f"{e.symbol} discontinued on {self.current_date}, liquidating at previous day's {self.event} price"
                )
                self.current_date = self.dates[(self.i // 2)].isoformat()
                self.portfolio[e.symbol].liquidate()
                # NOTE(review): written while i is rewound; whether this
                # overwrites the previous event's slot depends on
                # Metric.__call__ (not visible here) — confirm.
                metric(write=True)
                # restore the index and date of the current event
                if self.event == "close":
                    self.i += 2
                if self.event == "open":
                    self.i += 1
                self.current_date = self.dates[(self.i // 2)].isoformat()
    # total capital = free cash + last recorded portfolio value
    self._capital = self._available_capital + self.metric["Portfolio Value"][-1]
def _graceful_stop(self):
    """Wait for a pending live-plot thread, then draw the final plot.

    Called before raising errors so live output ends in a consistent state.
    Safe to call repeatedly: the thread reference is cleared instead of
    deleted, so later accesses to ``_last_thread`` (including the one in
    ``_plot``) still work.
    """
    thread = getattr(self, "_last_thread", None)
    if thread is not None:
        thread.join()
    self._last_thread = None
    self._plot(self._get_bts(), last=True)
def _order(
    self,
    symbol,
    capital,
    as_percent=False,
    as_percent_available=False,
    shares=None,
    uid=None,
):
    """Open a new position in ``symbol``.

    Args:
        symbol: the ticker to order.
        capital: order size. A negative value opens a short position. It is
            an absolute amount by default, a fraction of total value
            (cash + positions) when ``as_percent`` is set, or a fraction of
            available cash when ``as_percent_available`` is set.
        as_percent: interpret ``capital`` as a fraction of total value.
        as_percent_available: interpret ``capital`` as a fraction of
            available cash.
        shares: when given, order this number of shares. ``capital`` then
            only determines the direction.
        uid: position id. It is generated when omitted and shared with the
            slippage lower-bound twin so their positions can be matched.

    Raises:
        InsufficientCapitalError: if the order exceeds the available capital.
        Exception: if the order would buy zero shares. Errors from the
            trade-cost callable are re-raised.
    """
    if uid is None:
        uid = uuid.uuid4()
    # mirror the order on the slippage lower-bound twin with the same uid
    if self._slippage is not None:
        self.lower_bound._order(
            symbol, capital, as_percent, as_percent_available, shares, uid
        )
    self._capital = self._available_capital + self.metric["Portfolio Value"]()
    if capital < 0:
        short = True
        capital = (-1) * capital
    else:
        short = False
    if not as_percent and not as_percent_available:
        if capital > self._available_capital:
            self._graceful_stop()
            raise InsufficientCapitalError("not enough capital available")
    elif as_percent:
        if abs(capital * self._capital) > self._available_capital:
            # tolerate float rounding when ordering exactly the available cash
            if not math.isclose(capital * self._capital, self._available_capital):
                self._graceful_stop()
                raise InsufficientCapitalError(
                    f"""
                    not enough capital available:
                    ordered {capital} * {self._capital}
                    with only {self._available_capital} available
                    """
                )
    elif as_percent_available:
        if abs(capital * self._available_capital) > self._available_capital:
            if not math.isclose(
                capital * self._available_capital, self._available_capital
            ):
                self._graceful_stop()
                raise InsufficientCapitalError(
                    f"""
                    not enough capital available:
                    ordered {capital} * {self._available_capital}
                    with only {self._available_capital} available
                    """
                )
    current_price = self.price(symbol)
    # slippage moves the execution price against the order direction
    if self._slippage_percent is not None:
        if short:
            current_price *= 1 + self._slippage_percent
        else:
            current_price *= 1 - self._slippage_percent
    if as_percent:
        capital = capital * self._capital
    if as_percent_available:
        capital = capital * self._available_capital
    try:
        if shares is None:
            fee_dict = self._trade_cost(current_price, capital)
        else:
            fee_dict = self._trade_cost(
                current_price, self._available_capital, nshares=shares
            )
        nshares, total, fee = (
            fee_dict["nshares"],
            fee_dict["total"],
            fee_dict["fee"],
        )
    except Exception:
        self._graceful_stop()
        raise
    if short:
        nshares *= -1
    if nshares != 0:
        self._available_capital -= total
        pos = Position(
            self,
            symbol,
            self.current_date,
            self.event,
            nshares,
            uid,
            fee,
            self._slippage_percent,
        )
        self.portfolio._add(pos)
    else:
        # stop live output like every other failure path in this method
        self._graceful_stop()
        raise Exception(
            f"""
            not enough capital specified to order a single share of {symbol}:
            tried to order {capital} of {symbol}
            with {symbol} price at {current_price}
            """
        )
def long(self, symbol: str, **kwargs):
    """Enter a long position of the given symbol.

    Args:
        symbol: the ticker to buy
        kwargs:
            one of either
            "percent" as a percentage of total value (cash + positions),
            "absolute" as an absolute value,
            "percent_available" as a percentage of remaining funds (excluding positions)
            "nshares" as a number of shares

    Raises:
        TypeError: if an unknown keyword is passed.
        ValueError: if none of the keywords above is passed.
    """
    allowed = ("percent", "absolute", "percent_available", "nshares")
    unknown = sorted(set(kwargs) - set(allowed))
    if unknown:
        self._graceful_stop()
        raise TypeError(
            f"long() got unexpected order argument(s) {unknown}, expected one of {list(allowed)}"
        )
    if not kwargs:
        self._graceful_stop()
        raise ValueError(f"long() needs one of {list(allowed)} to size the order")
    if "percent" in kwargs:
        self._order(symbol, kwargs["percent"], as_percent=True)
    if "absolute" in kwargs:
        self._order(symbol, kwargs["absolute"])
    if "percent_available" in kwargs:
        self._order(symbol, kwargs["percent_available"], as_percent_available=True)
    if "nshares" in kwargs:
        # capital 1 only marks the direction; the size comes from nshares
        self._order(symbol, 1, shares=kwargs["nshares"])
def short(self, symbol: str, **kwargs):
    """Enter a short position of the given symbol.

    Args:
        symbol: the ticker to short
        kwargs:
            one of either
            "percent" as a percentage of total value (cash + positions),
            "absolute" as an absolute value,
            "percent_available" as a percentage of remaining funds (excluding positions)
            "nshares" as a number of shares

    Raises:
        TypeError: if an unknown keyword is passed.
        ValueError: if none of the keywords above is passed.
    """
    allowed = ("percent", "absolute", "percent_available", "nshares")
    unknown = sorted(set(kwargs) - set(allowed))
    if unknown:
        self._graceful_stop()
        raise TypeError(
            f"short() got unexpected order argument(s) {unknown}, expected one of {list(allowed)}"
        )
    if not kwargs:
        self._graceful_stop()
        raise ValueError(f"short() needs one of {list(allowed)} to size the order")
    # negative capital tells _order to open a short position
    if "percent" in kwargs:
        self._order(symbol, -kwargs["percent"], as_percent=True)
    if "absolute" in kwargs:
        self._order(symbol, -kwargs["absolute"])
    if "percent_available" in kwargs:
        self._order(symbol, -kwargs["percent_available"], as_percent_available=True)
    if "nshares" in kwargs:
        self._order(symbol, -1, shares=kwargs["nshares"])
def price(self, symbol: str) -> float:
    """Get the current price of a given symbol.

    The price is read from the price provider for the current date and
    event (``"open"`` or ``"close"``).

    Args:
        symbol: the ticker

    Raises:
        PriceUnavailableError: if no price exists, or it is ``None``/NaN.
    """
    try:
        price = self.prices[symbol, self.current_date][self.event]
    except KeyError:
        raise PriceUnavailableError(
            symbol,
            self.current_date,
            f"Price for {symbol} on {self.current_date} could not be found.",
        )
    # check None first: math.isnan(None) would raise TypeError
    if price is None or math.isnan(price):
        self._graceful_stop()
        raise PriceUnavailableError(
            symbol,
            self.current_date,
            f"Price for {symbol} on {self.current_date} is nan or None.",
        )
    return price
@property
def balance(self) -> "Balance":
    """Snapshot of the starting and currently available capital.

    Returns:
        A small dataclass instance with ``start`` (the starting capital) and
        ``current`` (the capital still available).

    Examples:
        Get the current balance::
            bt.balance.current
        Get the starting balance::
            bt.balance.start
    """
    opening = self._start_capital
    remaining = self._available_capital

    @dataclass
    class Balance:
        start: float = opening
        current: float = remaining

    snapshot = Balance()
    return snapshot
def _get_bts(self):
bts = [self]
if self._has_strategies:
if self._no_iter:
bts = self._strategies
else:
bts = bts + self._strategies
return bts
@property
def metrics(self) -> pd.DataFrame:
    """Collect every metric recorded during the backtest(s) into one table.

    Each backtest contributes ``len(bt)`` rows alternating ``open``/``close``
    events, two per date. Series metrics supply their per-event values, and
    single-value metrics are repeated on every row.

    Returns:
        A DataFrame indexed by ``(Backtest, Date, Event)`` when sub-strategies
        exist (unnamed backtests are labelled ``"Backtest <i>"``), otherwise
        by ``(Date, Event)``.
    """
    index_cols = ["Date", "Event"]
    if self._has_strategies:
        index_cols = ["Backtest"] + index_cols
    frames = []
    for idx, bt in enumerate(self._get_bts()):
        n_rows = len(bt)
        frame = pd.DataFrame()
        frame["Event"] = np.tile(["open", "close"], n_rows // 2 + 1)[:n_rows]
        frame["Date"] = np.repeat(bt.dates, 2)
        if self._has_strategies:
            label = bt.name if bt.name is not None else f"Backtest {idx}"
            frame["Backtest"] = np.repeat(label, n_rows)
        for key, metric in bt.metric.items():
            if metric._series:
                frame[key] = metric.values
            if metric._single:
                frame[key] = np.repeat(metric.value, n_rows)
        frames.append(frame)
    return pd.concat(frames).set_index(index_cols)
@property
def summary(self) -> pd.DataFrame:
    """Summarise each backtest in a single row for quick comparison.

    Series metrics contribute their final value in a column named
    ``"<metric> (Last Value)"``, and single-value metrics contribute their
    value in a column named ``"<metric>"``.

    Returns:
        With sub-strategies: one row per backtest, indexed by ``"Backtest"``
        (its name, or ``"Backtest <i>"`` when unnamed). Without: the one-row
        frame for this backtester.
    """
    with_strategies = self._has_strategies
    rows = []
    for position, bt in enumerate(self._get_bts()):
        row = pd.DataFrame()
        if with_strategies:
            row["Backtest"] = [f"Backtest {position}" if bt.name is None else bt.name]
        for key, metric in bt.metric.items():
            if metric._series:
                row[f"{key} (Last Value)"] = [metric[-1]]
            if metric._single:
                row[key] = [metric.value]
        rows.append(row)
    if not with_strategies:
        return rows[-1]
    return pd.concat(rows).set_index(["Backtest"])
@property
def strategies(self):
    """Access the sub-strategies as a :class:`.StrategySequence`.

    Returns:
        A :class:`.StrategySequence` wrapping this backtester, or ``None``
        when no sub-strategies have been configured.
    """
    if not self._has_strategies:
        return None
    return StrategySequence(self)
@property
def pf(self) -> Portfolio:
    """Alias of ``portfolio``.

    Returns:
        The same object as ``self.portfolio``.
    """
    portfolio = self.portfolio
    return portfolio
def _set_strategies(
self, strategies: List[Callable[["Date", str, "Backtester"], None]]
):
self._strategies_call = copy.deepcopy(strategies)
for strat in strategies:
new_bt = copy.deepcopy(self)
new_bt._set_self()
new_bt.name = strat.name
new_bt._has_strategies = False
if self._slippage is not None:
self._init_slippage(new_bt)
# this is bad but not bad enough to
# do anything other than this hotfix
self._no_iter = True
self._init_iter(new_bt)
self._no_iter = False
self._strategies.append(new_bt)
def _run_once(self):
no_slip_strats = [
strat for strat in self._strategies if strat._slippage_percent is None
]
slip_strats = [
strat for strat in self._strategies if strat._slippage_percent is not None
]
for bt in slip_strats:
self._next_iter(bt)
for i, bt in enumerate(no_slip_strats):
self._strategies_call[i](*self._next_iter(bt))
def _plot(self, bts, last=False):
try:
if self._live_plot and (self.i % self._live_plot_every == 0 or last):
if not self._live_plot_blocking:
if self._last_thread is None or not self._last_thread.is_alive():
thr = threading.Thread(target=self._show_live_plot, args=(bts,))
thr.start()
self._last_thread = thr
if last:
self._last_thread.join()
self._show_live_plot(bts)
else:
self._show_live_plot(bts)
if self._live_metrics and (self.i % self._live_metrics_every == 0 or last):
self._show_live_metrics(bts)
if (
not (self._live_metrics or self._live_plot)
and self._live_progress
and (self.i % self._live_progress_every == 0 or last)
):
_cls()
print(self._show_live_progress())
for l in self._log[-20:]:
print(l)
if len(self._log) > 20:
print("... more logs stored in Backtester.logs")
for w in self._warn:
warn(w)
except:
pass
@property
def logs(self) -> pd.DataFrame:
    """Return the log messages recorded during the backtest.

    Returns:
        A :class:`pd.DataFrame` with a single ``log`` column, indexed by
        ``(date, event)``.
    """
    index_columns = ["date", "event"]
    table = pd.DataFrame(self._log, columns=index_columns + ["log"])
    return table.set_index(index_columns)
def show(self, start=None, end=None):
    """Show the backtester as a plot.

    Args:
        start: optional start date
        end: optional end date. When either bound is given, both are passed
            to ``_show_live_plot`` as ``[start, end]``.

    Raises:
        ImportError: if matplotlib is not installed. Without this check,
            ``plt`` is unbound and the call fails with a confusing NameError.
    """
    if not plt_exists:
        raise ImportError(
            "matplotlib is required for Backtester.show(); install it with `pip install matplotlib`"
        )
    # When _from_sequence is set (presumably because this backtester was
    # obtained through a StrategySequence), plot only this backtester.
    bts = [self] if self._from_sequence else self._get_bts()
    if start is None and end is None:
        self._show_live_plot(bts)
    else:
        self._show_live_plot(bts, [start, end])
    if not is_notebook():
        plt.show()
def run(self):
    """Run every sub-strategy to completion without a user-driven iterator.

    This requires strategies to have been registered through
    :meth:`.BacktesterBuilder.strategies`. Each step:

    - calls ``_run_once``;
    - copies the last sub-backtester's counter into ``self.i``;
    - refreshes the live output.

    A final refresh is forced at the end. Then every single-value metric of
    every sub-strategy is evaluated with ``write=True``.
    """
    self._no_iter = True
    self._init_iter()
    total_steps = len(self)
    for _step in range(total_steps):
        self._run_once()
        # keep the parent's counter in sync so _plot's "every n" checks work
        self.i = self._strategies[-1].i
        self._plot(self._strategies)
    self._plot(self._strategies, last=True)
    pending = (
        metric
        for strat in self._strategies
        for metric in strat.metric.values()
        if metric._single
    )
    for metric in pending:
        metric(write=True)