Merge pull request #3758 from jnoortheen/feat-threaded-prompt

feat: use threads to offload prompt sections
This commit is contained in:
Anthony Scopatz 2020-09-29 17:45:31 -07:00 committed by GitHub
commit 6b0a81a2e8
Failed to generate hash of commit
8 changed files with 354 additions and 42 deletions

View file

@ -0,0 +1,10 @@
.. _xonsh_ptk2_formatter:
******************************************************************
Prompt Toolkit 2+ Prompt Formatter (``xonsh.ptk_shell.formatter``)
******************************************************************
.. automodule:: xonsh.ptk_shell.formatter
:members:
:undoc-members:
:inherited-members:

View file

@ -0,0 +1,23 @@
**Added:**
* Now ptk_shell supports loading its sections in thread, speeding up the prompt. Enable it by setting ``$ENABLE_ASYNC_PROMPT=True``.
**Changed:**
* <news item>
**Deprecated:**
* <news item>
**Removed:**
* <news item>
**Fixed:**
* <news item>
**Security:**
* <news item>

View file

@ -724,6 +724,23 @@ def DEFAULT_VARS():
"Places the auto-suggest result as the first option in the completions. "
"This enables you to tab complete the auto-suggestion.",
),
"ASYNC_INVALIDATE_INTERVAL": Var(
is_float,
float,
str,
0.05,
"When ENABLE_ASYNC_PROMPT is True, it may call the redraw frequently. "
"This is to group such calls into one that happens within that timeframe. "
"The number is set in seconds.",
),
"ASYNC_PROMPT_THREAD_WORKERS": Var(
is_int,
int,
str,
None,
"Define the number of workers used by the ASYC_PROPMT's pool. "
"By default it is defined by Python's concurrent.futures.ThreadPoolExecutor",
),
"BASH_COMPLETIONS": Var(
is_env_path,
str_to_env_path,
@ -876,6 +893,14 @@ def DEFAULT_VARS():
"The string used to show a shortened directory in a shortened cwd, "
"e.g. ``''``.",
),
"ENABLE_ASYNC_PROMPT": Var(
is_bool,
to_bool,
bool_to_str,
False,
"When enabled the prompt is loaded from threads making the shell faster. "
"Sections that take long will be updated in the background. ",
),
"EXPAND_ENV_VARS": Var(
is_bool,
to_bool,

View file

@ -37,47 +37,53 @@ class PromptFormatter:
def __init__(self):
self.cache = {}
def __call__(self, template=DEFAULT_PROMPT, fields=None):
def __call__(self, template=DEFAULT_PROMPT, fields=None, **kwargs):
"""Formats a xonsh prompt template string."""
if fields is None:
self.fields = builtins.__xonsh__.env.get("PROMPT_FIELDS", PROMPT_FIELDS)
else:
self.fields = fields
try:
prompt = self._format_prompt(template=template)
prompt = self._format_prompt(template=template, **kwargs)
except Exception:
return _failover_template_format(template)
# keep cache only during building prompt
self.cache.clear()
return prompt
def _format_prompt(self, template=DEFAULT_PROMPT):
def _format_prompt(self, template=DEFAULT_PROMPT, **kwargs):
return "".join(self._get_tokens(template, **kwargs))
def _get_tokens(self, template, **kwargs):
template = template() if callable(template) else template
toks = []
for literal, field, spec, conv in xt.FORMATTER.parse(template):
toks.append(literal)
entry = self._format_field(field, spec, conv)
entry = self._format_field(field, spec, conv, idx=len(toks), **kwargs)
if entry is not None:
toks.append(entry)
return "".join(toks)
return toks
def _format_field(self, field, spec, conv):
def _format_field(self, field, spec, conv, **kwargs):
if field is None:
return
elif field.startswith("$"):
val = builtins.__xonsh__.env[field[1:]]
return _format_value(val, spec, conv)
elif field in self.fields:
val = self._get_field_value(field)
val = self._get_field_value(field, **kwargs)
return _format_value(val, spec, conv)
else:
# color or unknown field, return as is
return "{" + field + "}"
def _get_field_value(self, field):
def _get_field_value(self, field, **kwargs):
field_value = self.fields[field]
if field_value in self.cache:
return self.cache[field_value]
return self._no_cache_field_value(field, field_value, **kwargs)
def _no_cache_field_value(self, field, field_value, **_):
try:
value = field_value() if callable(field_value) else field_value
self.cache[field_value] = value

View file

@ -0,0 +1,63 @@
"""PTK specific PromptFormatter class."""
import functools
from prompt_toolkit import PromptSession
from xonsh.prompt.base import PromptFormatter, DEFAULT_PROMPT
from xonsh.ptk_shell.updator import PromptUpdator
class PTKPromptFormatter(PromptFormatter):
"""A subclass of PromptFormatter to support rendering prompt sections with/without threads."""
def __init__(self, session: PromptSession):
super().__init__()
self.session = session
def __call__(
self,
template=DEFAULT_PROMPT,
fields=None,
threaded=False,
prompt_name: str = None,
) -> str:
"""Formats a xonsh prompt template string."""
kwargs = {}
if threaded:
# init only for async prompts
if not hasattr(self, "updator"):
# updates an async prompt.
self.updator = PromptUpdator(self.session)
# set these attributes per call. one can enable/disable async-prompt inside a session.
kwargs["async_prompt"] = self.updator.add(prompt_name)
# in case of failure it returns a fail-over template. otherwise it returns list of tokens
prompt_or_tokens = super().__call__(template, fields, **kwargs)
if isinstance(prompt_or_tokens, list):
if threaded:
self.updator.set_tokens(prompt_name, prompt_or_tokens)
return "".join(prompt_or_tokens)
return prompt_or_tokens
def _format_prompt(self, template=DEFAULT_PROMPT, **kwargs):
return self._get_tokens(template, **kwargs)
def _no_cache_field_value(
self, field, field_value, idx=None, async_prompt=None, **_
):
"""This branch is created so that caching fields per prompt would still work."""
func = functools.partial(super()._no_cache_field_value, field, field_value)
if async_prompt is not None and callable(field_value):
# create a thread and return an intermediate result
return async_prompt.submit_section(func, field, idx)
return func()
def start_update(self):
"""Start listening on the prompt section futures."""
self.updator.start()

View file

@ -8,6 +8,7 @@ from types import MethodType
from xonsh.events import events
from xonsh.base_shell import BaseShell
from xonsh.ptk_shell.formatter import PTKPromptFormatter
from xonsh.shell import transform_command
from xonsh.tools import print_exception, carriage_return
from xonsh.platform import HAS_PYGMENTS, ON_WINDOWS, ON_POSIX
@ -108,6 +109,7 @@ class PromptToolkitShell(BaseShell):
self._first_prompt = True
self.history = ThreadedHistory(PromptToolkitHistory())
self.prompter = PromptSession(history=self.history)
self.prompt_formatter = PTKPromptFormatter(self.prompter)
self.pt_completer = PromptToolkitCompleter(self.completer, self.ctx, self)
self.key_bindings = load_xonsh_bindings()
@ -196,7 +198,7 @@ class PromptToolkitShell(BaseShell):
"refresh_interval": refresh_interval,
"complete_in_thread": complete_in_thread,
}
if builtins.__xonsh__.env.get("COLOR_INPUT"):
if env.get("COLOR_INPUT"):
if HAS_PYGMENTS:
prompt_args["lexer"] = PygmentsLexer(pyghooks.XonshLexer)
style = style_from_pygments_cls(pyghooks.xonsh_style_proxy(self.styler))
@ -213,7 +215,12 @@ class PromptToolkitShell(BaseShell):
except (AttributeError, TypeError, ValueError):
print_exception()
if env["ENABLE_ASYNC_PROMPT"]:
# once the prompt is done, update it in background as each future is completed
prompt_args["pre_run"] = self.prompt_formatter.start_update
line = self.prompter.prompt(**prompt_args)
events.on_post_prompt.fire()
return line
@ -258,61 +265,57 @@ class PromptToolkitShell(BaseShell):
else:
break
def prompt_tokens(self):
"""Returns a list of (token, str) tuples for the current prompt."""
p = builtins.__xonsh__.env.get("PROMPT")
def _get_prompt_tokens(self, env_name: str, prompt_name: str, **kwargs):
env = builtins.__xonsh__.env
p = env.get(env_name)
if not p and "default" in kwargs:
return kwargs.pop("default")
try:
p = self.prompt_formatter(p)
p = self.prompt_formatter(
template=p,
threaded=env["ENABLE_ASYNC_PROMPT"],
prompt_name=prompt_name,
)
except Exception: # pylint: disable=broad-except
print_exception()
p, osc_tokens = remove_ansi_osc(p)
if kwargs.get("handle_osc_tokens"):
# handle OSC tokens
for osc in osc_tokens:
if osc[2:4] == "0;":
env["TITLE"] = osc[4:-1]
else:
print(osc, file=sys.__stdout__, flush=True)
toks = partial_color_tokenize(p)
return tokenize_ansi(PygmentsTokens(toks))
def prompt_tokens(self):
"""Returns a list of (token, str) tuples for the current prompt."""
if self._first_prompt:
carriage_return()
self._first_prompt = False
# handle OSC tokens
for osc in osc_tokens:
if osc[2:4] == "0;":
builtins.__xonsh__.env["TITLE"] = osc[4:-1]
else:
print(osc, file=sys.__stdout__, flush=True)
tokens = self._get_prompt_tokens("PROMPT", "message", handle_osc_tokens=True)
self.settitle()
return tokenize_ansi(PygmentsTokens(toks))
return tokens
def rprompt_tokens(self):
"""Returns a list of (token, str) tuples for the current right
prompt.
"""
p = builtins.__xonsh__.env.get("RIGHT_PROMPT")
# self.prompt_formatter does handle empty strings properly,
# but this avoids descending into it in the common case of
# $RIGHT_PROMPT == ''.
if isinstance(p, str) and len(p) == 0:
return []
try:
p = self.prompt_formatter(p)
except Exception: # pylint: disable=broad-except
print_exception()
toks = partial_color_tokenize(p)
return tokenize_ansi(PygmentsTokens(toks))
return self._get_prompt_tokens("RIGHT_PROMPT", "rprompt", default=[])
def _bottom_toolbar_tokens(self):
"""Returns a list of (token, str) tuples for the current bottom
toolbar.
"""
p = builtins.__xonsh__.env.get("BOTTOM_TOOLBAR")
if not p:
return
try:
p = self.prompt_formatter(p)
except Exception: # pylint: disable=broad-except
print_exception()
toks = partial_color_tokenize(p)
return tokenize_ansi(PygmentsTokens(toks))
return self._get_prompt_tokens("BOTTOM_TOOLBAR", "bottom_toolbar", default=None)
@property
def bottom_toolbar_tokens(self):

175
xonsh/ptk_shell/updator.py Normal file
View file

@ -0,0 +1,175 @@
import builtins
import concurrent.futures
import threading
from typing import Dict, List, Union, Callable, Optional
from prompt_toolkit import PromptSession
from prompt_toolkit.formatted_text import PygmentsTokens
from xonsh.style_tools import partial_color_tokenize, style_as_faded
class Executor:
"""Caches thread results across prompts."""
def __init__(self):
self.thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=builtins.__xonsh__.env["ASYNC_PROMPT_THREAD_WORKERS"]
)
# the attribute, .cache is cleared between calls.
# This caches results from callback alone by field name.
self.thread_results = {}
def submit(self, func: Callable, field: str):
future = self.thread_pool.submit(self._run_func, func, field)
place_holder = "{" + field + "}"
return (
future,
(
self.thread_results[field]
if field in self.thread_results
else place_holder
),
place_holder,
)
def _run_func(self, func, field):
"""Run the callback and store the result."""
result = func()
self.thread_results[field] = (
result if result is None else style_as_faded(result)
)
return result
class AsyncPrompt:
"""Represent an asynchronous prompt."""
def __init__(self, name: str, session: PromptSession, executor: Executor):
"""
Parameters
----------
name: str
what prompt to update. One of ['message', 'rprompt', 'bottom_toolbar']
session: PromptSession
current ptk session
"""
self.name = name
# list of tokens in that prompt. It could either be resolved or not resolved.
self.tokens: List[str] = []
self.timer = None
self.session = session
self.executor = executor
# (Key: the future object) that is created for the (value: index/field_name) in the tokens list
self.futures: Dict[concurrent.futures.Future, Union[int, str]] = {}
def start_update(self, on_complete):
"""Listen on futures and update the prompt as each one completed.
Timer is used to avoid clogging multiple calls at the same time.
Parameters
-----------
on_complete:
callback to notify after all the futures are completed
"""
for fut in concurrent.futures.as_completed(self.futures):
val = fut.result() or ""
if fut not in self.futures:
# rare case where the future is completed but the container is already cleared
# because new prompt is called
continue
token_index = self.futures[fut]
if isinstance(token_index, int):
self.tokens[token_index] = val
else: # when the function is called outside shell.
for idx, sect in enumerate(self.tokens):
if token_index in sect:
self.tokens[idx] = sect.replace(token_index, val)
# calling invalidate in less period is inefficient
self.invalidate()
on_complete(self.name)
def invalidate(self):
"""Create a timer to update the prompt. The timing can be configured through env variables.
threading.Timer is used to stop calling invalidate frequently.
"""
from xonsh.ptk_shell.shell import tokenize_ansi
if self.timer:
self.timer.cancel()
def _invalidate():
new_prompt = "".join(self.tokens)
formatted_tokens = tokenize_ansi(
PygmentsTokens(partial_color_tokenize(new_prompt))
)
setattr(self.session, self.name, formatted_tokens)
self.session.app.invalidate()
self.timer = threading.Timer(
builtins.__xonsh__.env["ASYNC_INVALIDATE_INTERVAL"], _invalidate
)
self.timer.start()
def stop(self):
"""Stop any running threads"""
for fut in self.futures:
fut.cancel()
self.futures.clear()
def submit_section(self, func: Callable, field: str, idx: int = None):
future, intermediate_value, placeholder = self.executor.submit(func, field)
self.futures[future] = placeholder if idx is None else idx
return intermediate_value
class PromptUpdator:
def __init__(self, session: PromptSession):
self.prompts: Dict[str, AsyncPrompt] = {}
self.prompter = session
self.executor = Executor()
def add(self, prompt_name: Optional[str]):
# clear out old futures from the same prompt
if prompt_name is None:
return
if prompt_name in self.prompts:
self.stop(prompt_name)
self.prompts[prompt_name] = AsyncPrompt(
prompt_name, self.prompter, self.executor
)
return self.prompts[prompt_name]
def start(self):
"""after ptk prompt is created, update it in background."""
threads = [
threading.Thread(target=prompt.start_update, args=[self.on_complete])
for pt_name, prompt in self.prompts.items()
]
for th in threads:
th.start()
def stop(self, prompt_name: str):
if prompt_name in self.prompts:
self.prompts[prompt_name].stop()
def on_complete(self, prompt_name):
self.prompts.pop(prompt_name, None)
def set_tokens(self, prompt_name, tokens: List[str]):
if prompt_name in self.prompts:
self.prompts[prompt_name].tokens = tokens

View file

@ -167,6 +167,13 @@ def norm_name(name):
return name.upper().replace("#", "HEX")
def style_as_faded(template: str) -> str:
"""Remove the colors from the template string and style as faded."""
tokens = partial_color_tokenize(template)
without_color = "".join([sect for _, sect in tokens])
return "{NO_COLOR}{#d3d3d3}" + without_color + "{NO_COLOR}"
DEFAULT_STYLE_DICT = LazyObject(
lambda: defaultdict(
lambda: "",