diff --git a/tests/aliases/__init__.py b/tests/aliases/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/aliases/test_source.py b/tests/aliases/test_source.py new file mode 100644 index 000000000..10ccc0c09 --- /dev/null +++ b/tests/aliases/test_source.py @@ -0,0 +1,38 @@ +import os.path +import pytest + +from contextlib import contextmanager +from unittest.mock import MagicMock +from xonsh.aliases import source_alias, builtins + + +@pytest.fixture +def mockopen(xonsh_builtins, monkeypatch): + @contextmanager + def mocked_open(fpath, *args, **kwargs): + yield MagicMock(read=lambda: fpath) + monkeypatch.setattr(builtins, 'open', mocked_open) + + +def test_source_current_dir(mockopen, monkeypatch): + checker = [] + + def mocked_execx(src, *args, **kwargs): + checker.append(src.strip()) + monkeypatch.setattr(builtins, 'execx', mocked_execx) + monkeypatch.setattr(os.path, 'isfile', lambda x: True) + source_alias(['foo', 'bar']) + assert checker == ['foo', 'bar'] + + +def test_source_path(mockopen, monkeypatch): + checker = [] + + def mocked_execx(src, *args, **kwargs): + checker.append(src.strip()) + monkeypatch.setattr(builtins, 'execx', mocked_execx) + source_alias(['foo', 'bar']) + path_foo = os.path.join('tests', 'bin', 'foo') + path_bar = os.path.join('tests', 'bin', 'bar') + assert checker[0].endswith(path_foo) + assert checker[1].endswith(path_bar) diff --git a/xonsh/aliases.py b/xonsh/aliases.py index 0b09add1f..bb5acf288 100644 --- a/xonsh/aliases.py +++ b/xonsh/aliases.py @@ -256,9 +256,13 @@ def source_alias(args, stdin=None): encoding = env.get('XONSH_ENCODING') errors = env.get('XONSH_ENCODING_ERRORS') for fname in args: - if not os.path.isfile(fname): - fname = locate_binary(fname) - with open(fname, 'r', encoding=encoding, errors=errors) as fp: + fpath = fname + if not os.path.isfile(fpath): + fpath = locate_binary(fname) + if fpath is None: + print('source: {}: No such file'.format(fname), file=sys.stderr) + continue + with open(fpath, 'r', encoding=encoding, errors=errors) as fp: src = fp.read() if not src.endswith('\n'): src += '\n'